231 lines
7.0 KiB
Rust
231 lines
7.0 KiB
Rust
use ai::error::{AiError, AiResult};
|
|
use ai::tool::tools::FunctionCall;
|
|
use async_trait::async_trait;
|
|
use git::rpc::proto as p;
|
|
use git::rpc::proto::blob_service_client::BlobServiceClient;
|
|
use git::rpc::proto::commit_service_client::CommitServiceClient;
|
|
use git::rpc::proto::tree_service_client::TreeServiceClient;
|
|
use serde_json::{Value, json};
|
|
|
|
use super::helpers::{
|
|
arg_opt_str, arg_str, git_ctx, require_repo_member, rpc_err,
|
|
};
|
|
use crate::agent::run::AppAgentContext;
|
|
|
|
/// Agent tool registered as `git_tree_entries`: lists the entries of a
/// directory inside a commit's tree.
///
/// The type carries no state; every input arrives through the call
/// arguments.
pub struct GitTreeEntriesTool;

impl GitTreeEntriesTool {
    /// Creates the (stateless) tool instance.
    pub fn new() -> Self {
        GitTreeEntriesTool
    }
}

impl Default for GitTreeEntriesTool {
    /// Equivalent to [`GitTreeEntriesTool::new`].
    fn default() -> Self {
        GitTreeEntriesTool::new()
    }
}
|
|
|
|
#[async_trait]
|
|
impl FunctionCall for GitTreeEntriesTool {
|
|
type Context = AppAgentContext;
|
|
|
|
fn name(&self) -> &'static str {
|
|
"git_tree_entries"
|
|
}
|
|
|
|
fn description(&self) -> &'static str {
|
|
"List files and subdirectories at a given path in a commit's tree. Use this to explore repo structure."
|
|
}
|
|
|
|
fn schema(&self) -> Value {
|
|
json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"workspace": { "type": "string", "description": "Workspace name" },
|
|
"repo": { "type": "string", "description": "Repository name" },
|
|
"commit_oid": { "type": "string", "description": "Commit OID to read the tree from" },
|
|
"path": { "type": "string", "description": "Directory path (empty string for root)" }
|
|
},
|
|
"required": ["workspace", "repo", "commit_oid"]
|
|
})
|
|
}
|
|
|
|
async fn call(
|
|
&self,
|
|
ctx: &mut AppAgentContext,
|
|
args: Value,
|
|
) -> AiResult<Value> {
|
|
let git = git_ctx(ctx)?;
|
|
let workspace = arg_str(&args, "workspace")?;
|
|
let repo_name = arg_str(&args, "repo")?;
|
|
let commit_oid = arg_str(&args, "commit_oid")?;
|
|
let path = arg_opt_str(&args, "path").unwrap_or("");
|
|
|
|
let repo =
|
|
require_repo_member(git, ctx.user_id, workspace, repo_name).await?;
|
|
|
|
let mut commit_client = CommitServiceClient::new(git.channel.clone());
|
|
let commit_resp = commit_client
|
|
.commit_info(p::CommitInfoRequest {
|
|
repo_id: repo.id.to_string(),
|
|
oid: Some(p::ObjectId {
|
|
value: commit_oid.to_string(),
|
|
}),
|
|
})
|
|
.await
|
|
.map_err(rpc_err)?
|
|
.into_inner();
|
|
|
|
let tree_oid =
|
|
commit_resp.commit.and_then(|c| c.tree_id).ok_or_else(|| {
|
|
AiError::Response("commit has no tree".to_string())
|
|
})?;
|
|
|
|
let mut client = TreeServiceClient::new(git.channel.clone());
|
|
let resp = client
|
|
.tree_entries(p::TreeEntriesRequest {
|
|
repo_id: repo.id.to_string(),
|
|
oid: Some(tree_oid),
|
|
base_path: path.to_string(),
|
|
last: false,
|
|
})
|
|
.await
|
|
.map_err(rpc_err)?
|
|
.into_inner();
|
|
|
|
let entries: Vec<Value> = resp
|
|
.entries
|
|
.iter()
|
|
.map(|e| {
|
|
let kind = match p::TreeKind::try_from(e.kind) {
|
|
Ok(p::TreeKind::Blob) => "file",
|
|
Ok(p::TreeKind::Tree) => "dir",
|
|
Ok(p::TreeKind::LfsPointer) => "lfs",
|
|
_ => "unknown",
|
|
};
|
|
json!({
|
|
"name": e.name,
|
|
"oid": e.oid.as_ref().map(|o| &o.value),
|
|
"kind": kind,
|
|
"is_binary": e.is_binary,
|
|
"is_lfs": e.is_lfs,
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
Ok(json!({ "entries": entries, "count": entries.len() }))
|
|
}
|
|
}
|
|
|
|
/// Agent tool registered as `git_file_content`: reads one file from a
/// specific commit.
///
/// The type carries no state; every input arrives through the call
/// arguments.
pub struct GitFileContentTool;

impl GitFileContentTool {
    /// Creates the (stateless) tool instance.
    pub fn new() -> Self {
        GitFileContentTool
    }
}

impl Default for GitFileContentTool {
    /// Equivalent to [`GitFileContentTool::new`].
    fn default() -> Self {
        GitFileContentTool::new()
    }
}
|
|
|
|
#[async_trait]
|
|
impl FunctionCall for GitFileContentTool {
|
|
type Context = AppAgentContext;
|
|
|
|
fn name(&self) -> &'static str {
|
|
"git_file_content"
|
|
}
|
|
|
|
fn description(&self) -> &'static str {
|
|
"Read the content of a file at a given path from a specific commit. Returns the file content as text."
|
|
}
|
|
|
|
fn schema(&self) -> Value {
|
|
json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"workspace": { "type": "string", "description": "Workspace name" },
|
|
"repo": { "type": "string", "description": "Repository name" },
|
|
"commit_oid": { "type": "string", "description": "Commit OID" },
|
|
"path": { "type": "string", "description": "File path in the repo" }
|
|
},
|
|
"required": ["workspace", "repo", "commit_oid", "path"]
|
|
})
|
|
}
|
|
|
|
async fn call(
|
|
&self,
|
|
ctx: &mut AppAgentContext,
|
|
args: Value,
|
|
) -> AiResult<Value> {
|
|
let git = git_ctx(ctx)?;
|
|
let workspace = arg_str(&args, "workspace")?;
|
|
let repo_name = arg_str(&args, "repo")?;
|
|
let commit_oid = arg_str(&args, "commit_oid")?;
|
|
let path = arg_str(&args, "path")?;
|
|
|
|
let repo =
|
|
require_repo_member(git, ctx.user_id, workspace, repo_name).await?;
|
|
|
|
let mut tree_client = TreeServiceClient::new(git.channel.clone());
|
|
let entry_resp = tree_client
|
|
.tree_entry_by_path_from_commit(
|
|
p::TreeEntryByPathFromCommitRequest {
|
|
repo_id: repo.id.to_string(),
|
|
commit_oid: Some(p::ObjectId {
|
|
value: commit_oid.to_string(),
|
|
}),
|
|
path: path.to_string(),
|
|
},
|
|
)
|
|
.await
|
|
.map_err(rpc_err)?
|
|
.into_inner();
|
|
|
|
let entry = entry_resp.entry.ok_or_else(|| {
|
|
AiError::Config(format!("file not found: {path}"))
|
|
})?;
|
|
|
|
if entry.kind == p::TreeKind::Tree as i32 {
|
|
return Err(AiError::Config(format!(
|
|
"'{path}' is a directory, not a file"
|
|
)));
|
|
}
|
|
|
|
let blob_oid = entry
|
|
.oid
|
|
.ok_or_else(|| AiError::Response("entry has no oid".to_string()))?;
|
|
|
|
let mut blob_client = BlobServiceClient::new(git.channel.clone());
|
|
let blob_resp = blob_client
|
|
.blob_load(p::BlobLoadRequest {
|
|
repo_id: repo.id.to_string(),
|
|
id: Some(blob_oid),
|
|
path: path.to_string(),
|
|
})
|
|
.await
|
|
.map_err(rpc_err)?
|
|
.into_inner();
|
|
|
|
let content = String::from_utf8_lossy(&blob_resp.blob).to_string();
|
|
|
|
let truncated = content.len() > 64_000;
|
|
let content = if truncated {
|
|
format!("{}...(truncated)", &content[..64_000])
|
|
} else {
|
|
content
|
|
};
|
|
|
|
Ok(json!({
|
|
"path": path,
|
|
"content": content,
|
|
"size": blob_resp.blob.len(),
|
|
"truncated": truncated,
|
|
}))
|
|
}
|
|
}
|