gitdataai/lib/service/agent/run.rs
2026-06-01 22:04:38 +08:00

288 lines
9.4 KiB
Rust

use std::sync::Arc;
use std::time::Duration;
use ai::{agent::RigAgent, tool::register::ToolRegister};
use cache::AppCache;
use db::AppDatabase;
use tonic::transport::Channel;
use tracing::{info, warn};
use uuid::Uuid;
use super::types::{AgentRunRequest, AgentRunResponse, AgentUsageInfo};
use crate::AppService;
use crate::error::AppError;
/// Service handles made available to the agent's git-aware tooling.
///
/// `agent_run` builds this from the service's own git channel, database and
/// cache handles. Presumably the git/workspace/issue tools reach it through
/// `AppAgentContext::git` — confirm against the tool implementations.
#[derive(Clone)]
pub struct GitAgentContext {
    /// tonic transport channel (cloned from the service's `git` field).
    pub channel: Channel,
    /// Database handle (cloned from the service's `db` field).
    pub db: AppDatabase,
    /// Cache handle (cloned from the service's `cache` field).
    pub cache: AppCache,
}
/// Per-run state shared between `agent_run` and the registered agent tools.
///
/// `agent_run` creates one value per invocation and hands it to the tool set
/// behind an `Arc<tokio::sync::Mutex<_>>`. After the run it takes the value
/// back to apply the side effects collected in the `pending_*` fields.
#[derive(Clone)]
pub struct AppAgentContext {
    /// User on whose behalf the agent is running.
    pub user_id: Uuid,
    /// Agent session this run belongs to.
    pub session_id: Uuid,
    /// Conversation this run writes into.
    pub conversation_id: Uuid,
    /// Title proposed during the run (presumably by `SetTitleTool`).
    /// Applied after a successful run when it is not blank.
    pub pending_title: Option<String>,
    /// Memories queued during the run (presumably by `SaveMemoryTool`).
    /// Persisted after a successful run.
    pub pending_memories: Vec<super::memory::PendingMemory>,
    /// Git/database/cache handles; `agent_run` always populates this.
    pub git: Option<GitAgentContext>,
}
impl AppService {
pub async fn agent_run(
&self,
user_id: Uuid,
req: AgentRunRequest,
) -> Result<AgentRunResponse, AppError> {
let ctx = self.agent_session_context(req.session_id, user_id).await?;
let conversation_id = req.conversation_id.ok_or_else(|| {
AppError::BadRequest("conversation_id is required".to_string())
})?;
let conversation = self
.agent_require_conversation_access(user_id, conversation_id)
.await?;
if conversation.session != ctx.session_id {
return Err(AppError::BadRequest(
"conversation does not belong to session".to_string(),
));
}
let ai_client =
self.agent_build_ai_client(ctx.model_version_id).await?;
let agent_config = self.agent_build_config(&ctx, req.max_steps);
self.agent_maybe_compact(
&ai_client,
&ctx.provider_model_name,
conversation_id,
)
.await
.unwrap_or_else(|e| {
warn!(error = %e, "compaction check failed, continuing");
});
let mut tools: ToolRegister<AppAgentContext> = ToolRegister::new();
if conversation.title == "New Chat"
|| conversation.title.trim().is_empty()
{
tools.register(super::tools::SetTitleTool::new());
}
tools.register(super::memory::SaveMemoryTool::new());
let (memories_text, _memory_rows) =
self.agent_load_memories(ctx.session_id).await?;
if !memories_text.is_empty() {
tools.register(super::memory::RecallMemoryTool::new(memories_text));
}
super::git_tools::register_git_tools(&mut tools);
super::workspace_tools::register_workspace_tools(&mut tools);
super::issue_tools::register_issue_tools(&mut tools);
let agent_ctx = AppAgentContext {
user_id,
session_id: ctx.session_id,
conversation_id,
pending_title: None,
pending_memories: Vec::new(),
git: Some(GitAgentContext {
channel: self.git.clone(),
db: self.db.clone(),
cache: self.cache.clone(),
}),
};
let shared_ctx = Arc::new(tokio::sync::Mutex::new(agent_ctx));
let mut tool_set =
ai::agent::RigToolSet::from_register(&tools, shared_ctx);
let rig_tools = tool_set.take_tools();
let agent = RigAgent::new(ai_client.clone(), agent_config)
.map_err(|e| AppError::AiError(e))?;
let timeout_secs = req.timeout_secs.unwrap_or(300);
let agent_request = self
.agent_build_request(
&ai_client,
&ctx,
req.conversation_id,
req.input.clone(),
Some(timeout_secs),
)
.await?;
let invocation_id = Uuid::now_v7();
let model_name = ctx.provider_model_name.clone();
info!(
invocation_id = %invocation_id,
session_id = %ctx.session_id,
user_id = %user_id,
billing_target = ?ctx.billing_target,
"agent run starting"
);
let user_message_id = self
.persist_user_message(conversation_id, user_id, &req.input)
.await?;
let result = match tokio::time::timeout(
Duration::from_secs(timeout_secs),
agent.chat(agent_request, rig_tools),
)
.await
{
Ok(Ok(output)) => Ok(output),
Ok(Err(e)) => Err(e),
Err(_) => Err(ai::error::AiError::Timeout {
seconds: timeout_secs,
}),
};
let agent_ctx = tool_set.into_context();
match result {
Ok(output) => {
self.metrics.record_ai_run(&model_name, "completed");
let message_id = self
.persist_assistant_message(
conversation_id,
ctx.session_id,
&output,
None,
invocation_id,
)
.await?;
let cost_info = self
.persist_billing_and_deduct(
&ctx,
invocation_id,
0, // input_tokens not tracked in chat() mode
0, // output_tokens not tracked in chat() mode
)
.await?;
self.agent_record_invocation(
invocation_id,
ctx.session_id,
Some(conversation_id),
Some(message_id),
ctx.model_version_id,
"completed",
None,
)
.await?;
self.update_conversation_timestamp(conversation_id).await?;
let title = agent_ctx
.pending_title
.filter(|t| !t.trim().is_empty())
.or_else(|| {
// Only auto-set title from input when still default.
if conversation.title == "New Chat"
|| conversation.title.trim().is_empty()
{
let first_line =
req.input.lines().next().unwrap_or(&req.input);
let truncated: String =
first_line.chars().take(50).collect();
if truncated.trim().is_empty() {
None
} else if first_line.len() > 50 {
Some(format!("{}", truncated.trim_end()))
} else {
Some(truncated.trim().to_string())
}
} else {
None
}
});
if let Some(new_title) = &title {
if let Err(e) = self
.update_conversation_title(conversation_id, new_title)
.await
{
warn!(
conversation_id = %conversation_id,
error = %e,
"failed to update conversation title"
);
}
}
if let Err(e) = self
.agent_persist_memories(
ctx.session_id,
&agent_ctx.pending_memories,
)
.await
{
warn!(
invocation_id = %invocation_id,
error = %e,
"failed to persist agent memories"
);
}
info!(
invocation_id = %invocation_id,
message_id = %message_id,
"agent run completed successfully"
);
Ok(AgentRunResponse {
message_id,
conversation_id,
output,
steps: Vec::new(),
usage: AgentUsageInfo {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
},
cost: cost_info,
})
}
Err(e) => {
self.metrics.record_ai_run(&model_name, "failed");
warn!(
invocation_id = %invocation_id,
error = %e,
"agent run failed"
);
let error_content = format!(
"I encountered an error while processing your request: {e}"
);
let _ = self
.persist_assistant_message(
conversation_id,
ctx.session_id,
&error_content,
None,
invocation_id,
)
.await;
self.agent_record_invocation(
invocation_id,
ctx.session_id,
Some(conversation_id),
Some(user_message_id),
ctx.model_version_id,
"failed",
Some(&e.to_string()),
)
.await?;
Err(AppError::AiError(e))
}
}
}
}