gitdataai/lib/service/agent/git_tools/helpers.rs

116 lines
4.0 KiB
Rust

use ai::error::{AiError, AiResult};
/// A file path paired with the bytes that file should contain.
///
/// `content` is a raw byte vector rather than a `String`, so it can carry
/// non-UTF-8 data.
#[derive(Clone, Debug)]
pub struct FileChange {
    /// Path of the file.
    pub path: String,
    /// Complete file contents as raw bytes.
    pub content: Vec<u8>,
}
use ai::tool::register::ToolRegister;
use db::sqlx;
use model::repos::RepoModel;
use serde_json::Value;
use uuid::Uuid;
use crate::agent::run::{AppAgentContext, GitAgentContext};
/// Registers every git tool provided by the sibling submodules on `tools`.
///
/// Registration order is fixed and grouped by submodule:
/// commit, branch, tree, diff, blame, tag, merge.
pub fn register_git_tools(tools: &mut ToolRegister<AppAgentContext>) {
    // Tools from `super::commit`.
    tools.register(super::commit::GitCommitHistoryTool::new());
    tools.register(super::commit::GitCommitInfoTool::new());
    tools.register(super::commit::GitCommitExistsTool::new());
    tools.register(super::commit::GitCherryPickTool::new());
    tools.register(super::commit::GitCommitCreateTool::new());

    // Tools from `super::branch`.
    tools.register(super::branch::GitBranchListTool::new());
    tools.register(super::branch::GitBranchInfoTool::new());
    tools.register(super::branch::GitBranchAheadBehindTool::new());
    tools.register(super::branch::GitCreateBranchTool::new());
    tools.register(super::branch::GitBranchDeleteTool::new());

    // Tools from `super::tree`.
    tools.register(super::tree::GitTreeEntriesTool::new());
    tools.register(super::tree::GitFileContentTool::new());

    // Tools from `super::diff`.
    tools.register(super::diff::GitDiffStatsTool::new());
    tools.register(super::diff::GitDiffPatchTool::new());

    // Tools from `super::blame`.
    tools.register(super::blame::GitBlameTool::new());

    // Tools from `super::tag`.
    tools.register(super::tag::GitTagListTool::new());
    tools.register(super::tag::GitTagInfoTool::new());
    tools.register(super::tag::GitCreateTagTool::new());
    tools.register(super::tag::GitDeleteTagTool::new());

    // Tools from `super::merge`.
    tools.register(super::merge::GitMergeBaseTool::new());
    tools.register(super::merge::GitMergeAnalysisTool::new());
    tools.register(super::merge::GitMergeIsConflictedTool::new());
}
/// Resolves `repo_name` inside the workspace named `workspace_name`, but only
/// if `user_id` is an active member of that workspace.
///
/// All three lookups (workspace, membership, repo) are read through
/// `git.db.reader()`. A membership row only counts while its `leave_at` is
/// NULL, and repos with a non-NULL `deleted_at` are never returned.
///
/// # Errors
///
/// - [`AiError::Database`] if any of the queries fails.
/// - [`AiError::Config`] if the workspace does not exist, the user is not an
///   active member of it, or no matching repo (with NULL `deleted_at`)
///   exists in it.
pub(super) async fn require_repo_member(
    git: &GitAgentContext,
    user_id: Uuid,
    workspace_name: &str,
    repo_name: &str,
) -> AiResult<RepoModel> {
    // NOTE(review): this "not found" error is returned before membership is
    // checked, so a caller can probe which workspace names exist.
    let wk_id: Uuid = sqlx::query_scalar("SELECT id FROM workspace WHERE name = $1")
        .bind(workspace_name)
        .fetch_optional(git.db.reader())
        .await
        .map_err(AiError::Database)?
        .ok_or_else(|| AiError::Config(format!("workspace '{workspace_name}' not found")))?;

    // EXISTS stops at the first matching row and decodes straight to a bool,
    // instead of counting every matching row and comparing the count to 0.
    let is_member: bool = sqlx::query_scalar(
        "SELECT EXISTS (SELECT 1 FROM wk_member \
         WHERE wk = $1 AND \"user\" = $2 AND leave_at IS NULL)",
    )
    .bind(wk_id)
    .bind(user_id)
    .fetch_one(git.db.reader())
    .await
    .map_err(AiError::Database)?;

    if !is_member {
        return Err(AiError::Config(format!(
            "user is not a member of workspace '{workspace_name}'"
        )));
    }

    let repo: RepoModel = sqlx::query_as(
        "SELECT id, wk, name, description, default_branch, visibility, \
         size_bytes, is_archived, is_template, is_mirror, created_by, \
         storage_path, created_at, updated_at, deleted_at \
         FROM repo WHERE wk = $1 AND name = $2 AND deleted_at IS NULL",
    )
    .bind(wk_id)
    .bind(repo_name)
    .fetch_optional(git.db.reader())
    .await
    .map_err(AiError::Database)?
    .ok_or_else(|| AiError::Config(format!("repo '{repo_name}' not found")))?;

    Ok(repo)
}
/// Borrows the git context attached to this agent session.
///
/// # Errors
///
/// Returns [`AiError::Config`] when `ctx.git` is `None`, i.e. the session was
/// started without git support.
pub(super) fn git_ctx(ctx: &AppAgentContext) -> AiResult<&GitAgentContext> {
    match ctx.git.as_ref() {
        Some(git) => Ok(git),
        None => Err(AiError::Config(
            "git tools are not available in this session".to_string(),
        )),
    }
}
/// Converts a failed git gRPC call into an [`AiError::Api`].
///
/// The gRPC status code is included alongside the message. A status may carry
/// an empty message, and the code (e.g. `NotFound` vs `Unavailable`) is then
/// the only indication of what went wrong.
pub(super) fn rpc_err(status: tonic::Status) -> AiError {
    AiError::Api(format!(
        "git rpc error ({:?}): {}",
        status.code(),
        status.message()
    ))
}
pub(super) fn arg_str<'a>(args: &'a Value, key: &str) -> AiResult<&'a str> {
args.get(key).and_then(|v| v.as_str()).ok_or_else(|| {
AiError::Config(format!("'{key}' parameter is required"))
})
}
/// Returns the optional string argument `key`.
///
/// Yields `None` when `key` is absent, when its value is not a JSON string,
/// or when `args` is not a JSON object.
pub(super) fn arg_opt_str<'a>(args: &'a Value, key: &str) -> Option<&'a str> {
    let value = args.get(key)?;
    value.as_str()
}
pub(super) fn arg_u64(args: &Value, key: &str, default: u64) -> u64 {
args.get(key).and_then(|v| v.as_u64()).unwrap_or(default)
}