288 lines
9.4 KiB
Rust
288 lines
9.4 KiB
Rust
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use ai::{agent::RigAgent, tool::register::ToolRegister};
|
|
use cache::AppCache;
|
|
use db::AppDatabase;
|
|
use tonic::transport::Channel;
|
|
use tracing::{info, warn};
|
|
use uuid::Uuid;
|
|
|
|
use super::types::{AgentRunRequest, AgentRunResponse, AgentUsageInfo};
|
|
use crate::AppService;
|
|
use crate::error::AppError;
|
|
|
|
#[derive(Clone)]
|
|
pub struct GitAgentContext {
|
|
pub channel: Channel,
|
|
pub db: AppDatabase,
|
|
pub cache: AppCache,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct AppAgentContext {
|
|
pub user_id: Uuid,
|
|
pub session_id: Uuid,
|
|
pub conversation_id: Uuid,
|
|
pub pending_title: Option<String>,
|
|
pub pending_memories: Vec<super::memory::PendingMemory>,
|
|
pub git: Option<GitAgentContext>,
|
|
}
|
|
|
|
impl AppService {
|
|
pub async fn agent_run(
|
|
&self,
|
|
user_id: Uuid,
|
|
req: AgentRunRequest,
|
|
) -> Result<AgentRunResponse, AppError> {
|
|
let ctx = self.agent_session_context(req.session_id, user_id).await?;
|
|
|
|
let conversation_id = req.conversation_id.ok_or_else(|| {
|
|
AppError::BadRequest("conversation_id is required".to_string())
|
|
})?;
|
|
let conversation = self
|
|
.agent_require_conversation_access(user_id, conversation_id)
|
|
.await?;
|
|
if conversation.session != ctx.session_id {
|
|
return Err(AppError::BadRequest(
|
|
"conversation does not belong to session".to_string(),
|
|
));
|
|
}
|
|
|
|
let ai_client =
|
|
self.agent_build_ai_client(ctx.model_version_id).await?;
|
|
|
|
let agent_config = self.agent_build_config(&ctx, req.max_steps);
|
|
|
|
self.agent_maybe_compact(
|
|
&ai_client,
|
|
&ctx.provider_model_name,
|
|
conversation_id,
|
|
)
|
|
.await
|
|
.unwrap_or_else(|e| {
|
|
warn!(error = %e, "compaction check failed, continuing");
|
|
});
|
|
|
|
let mut tools: ToolRegister<AppAgentContext> = ToolRegister::new();
|
|
if conversation.title == "New Chat"
|
|
|| conversation.title.trim().is_empty()
|
|
{
|
|
tools.register(super::tools::SetTitleTool::new());
|
|
}
|
|
tools.register(super::memory::SaveMemoryTool::new());
|
|
|
|
let (memories_text, _memory_rows) =
|
|
self.agent_load_memories(ctx.session_id).await?;
|
|
if !memories_text.is_empty() {
|
|
tools.register(super::memory::RecallMemoryTool::new(memories_text));
|
|
}
|
|
|
|
super::git_tools::register_git_tools(&mut tools);
|
|
super::workspace_tools::register_workspace_tools(&mut tools);
|
|
super::issue_tools::register_issue_tools(&mut tools);
|
|
|
|
let agent_ctx = AppAgentContext {
|
|
user_id,
|
|
session_id: ctx.session_id,
|
|
conversation_id,
|
|
pending_title: None,
|
|
pending_memories: Vec::new(),
|
|
git: Some(GitAgentContext {
|
|
channel: self.git.clone(),
|
|
db: self.db.clone(),
|
|
cache: self.cache.clone(),
|
|
}),
|
|
};
|
|
|
|
let shared_ctx = Arc::new(tokio::sync::Mutex::new(agent_ctx));
|
|
let mut tool_set =
|
|
ai::agent::RigToolSet::from_register(&tools, shared_ctx);
|
|
let rig_tools = tool_set.take_tools();
|
|
|
|
let agent = RigAgent::new(ai_client.clone(), agent_config)
|
|
.map_err(|e| AppError::AiError(e))?;
|
|
|
|
let timeout_secs = req.timeout_secs.unwrap_or(300);
|
|
|
|
let agent_request = self
|
|
.agent_build_request(
|
|
&ai_client,
|
|
&ctx,
|
|
req.conversation_id,
|
|
req.input.clone(),
|
|
Some(timeout_secs),
|
|
)
|
|
.await?;
|
|
|
|
let invocation_id = Uuid::now_v7();
|
|
let model_name = ctx.provider_model_name.clone();
|
|
info!(
|
|
invocation_id = %invocation_id,
|
|
session_id = %ctx.session_id,
|
|
user_id = %user_id,
|
|
billing_target = ?ctx.billing_target,
|
|
"agent run starting"
|
|
);
|
|
|
|
let user_message_id = self
|
|
.persist_user_message(conversation_id, user_id, &req.input)
|
|
.await?;
|
|
let result = match tokio::time::timeout(
|
|
Duration::from_secs(timeout_secs),
|
|
agent.chat(agent_request, rig_tools),
|
|
)
|
|
.await
|
|
{
|
|
Ok(Ok(output)) => Ok(output),
|
|
Ok(Err(e)) => Err(e),
|
|
Err(_) => Err(ai::error::AiError::Timeout {
|
|
seconds: timeout_secs,
|
|
}),
|
|
};
|
|
|
|
let agent_ctx = tool_set.into_context();
|
|
|
|
match result {
|
|
Ok(output) => {
|
|
self.metrics.record_ai_run(&model_name, "completed");
|
|
let message_id = self
|
|
.persist_assistant_message(
|
|
conversation_id,
|
|
ctx.session_id,
|
|
&output,
|
|
None,
|
|
invocation_id,
|
|
)
|
|
.await?;
|
|
|
|
let cost_info = self
|
|
.persist_billing_and_deduct(
|
|
&ctx,
|
|
invocation_id,
|
|
0, // input_tokens not tracked in chat() mode
|
|
0, // output_tokens not tracked in chat() mode
|
|
)
|
|
.await?;
|
|
|
|
self.agent_record_invocation(
|
|
invocation_id,
|
|
ctx.session_id,
|
|
Some(conversation_id),
|
|
Some(message_id),
|
|
ctx.model_version_id,
|
|
"completed",
|
|
None,
|
|
)
|
|
.await?;
|
|
|
|
self.update_conversation_timestamp(conversation_id).await?;
|
|
|
|
let title = agent_ctx
|
|
.pending_title
|
|
.filter(|t| !t.trim().is_empty())
|
|
.or_else(|| {
|
|
// Only auto-set title from input when still default.
|
|
if conversation.title == "New Chat"
|
|
|| conversation.title.trim().is_empty()
|
|
{
|
|
let first_line =
|
|
req.input.lines().next().unwrap_or(&req.input);
|
|
let truncated: String =
|
|
first_line.chars().take(50).collect();
|
|
if truncated.trim().is_empty() {
|
|
None
|
|
} else if first_line.len() > 50 {
|
|
Some(format!("{}…", truncated.trim_end()))
|
|
} else {
|
|
Some(truncated.trim().to_string())
|
|
}
|
|
} else {
|
|
None
|
|
}
|
|
});
|
|
|
|
if let Some(new_title) = &title {
|
|
if let Err(e) = self
|
|
.update_conversation_title(conversation_id, new_title)
|
|
.await
|
|
{
|
|
warn!(
|
|
conversation_id = %conversation_id,
|
|
error = %e,
|
|
"failed to update conversation title"
|
|
);
|
|
}
|
|
}
|
|
|
|
if let Err(e) = self
|
|
.agent_persist_memories(
|
|
ctx.session_id,
|
|
&agent_ctx.pending_memories,
|
|
)
|
|
.await
|
|
{
|
|
warn!(
|
|
invocation_id = %invocation_id,
|
|
error = %e,
|
|
"failed to persist agent memories"
|
|
);
|
|
}
|
|
|
|
info!(
|
|
invocation_id = %invocation_id,
|
|
message_id = %message_id,
|
|
"agent run completed successfully"
|
|
);
|
|
|
|
Ok(AgentRunResponse {
|
|
message_id,
|
|
conversation_id,
|
|
output,
|
|
steps: Vec::new(),
|
|
usage: AgentUsageInfo {
|
|
input_tokens: 0,
|
|
output_tokens: 0,
|
|
total_tokens: 0,
|
|
},
|
|
cost: cost_info,
|
|
})
|
|
}
|
|
Err(e) => {
|
|
self.metrics.record_ai_run(&model_name, "failed");
|
|
warn!(
|
|
invocation_id = %invocation_id,
|
|
error = %e,
|
|
"agent run failed"
|
|
);
|
|
|
|
let error_content = format!(
|
|
"I encountered an error while processing your request: {e}"
|
|
);
|
|
let _ = self
|
|
.persist_assistant_message(
|
|
conversation_id,
|
|
ctx.session_id,
|
|
&error_content,
|
|
None,
|
|
invocation_id,
|
|
)
|
|
.await;
|
|
|
|
self.agent_record_invocation(
|
|
invocation_id,
|
|
ctx.session_id,
|
|
Some(conversation_id),
|
|
Some(user_message_id),
|
|
ctx.model_version_id,
|
|
"failed",
|
|
Some(&e.to_string()),
|
|
)
|
|
.await?;
|
|
|
|
Err(AppError::AiError(e))
|
|
}
|
|
}
|
|
}
|
|
}
|