gitdataai/lib/service/agent/sse.rs

384 lines
16 KiB
Rust

use std::sync::Arc;
use ai::agent::{RigAgent, RigStreamChunk, RigToolSet};
use ai::tool::register::ToolRegister;
use serde_json::{json, Value};
use tokio::sync::mpsc;
use tracing::{error, info, warn};
use uuid::Uuid;
use super::run::AppAgentContext;
use super::types::AgentRunRequest;
use crate::error::AppError;
use crate::AppService;
impl AppService {
pub async fn agent_run_streaming(
&self,
user_id: Uuid,
req: AgentRunRequest,
) -> Result<mpsc::UnboundedReceiver<String>, AppError> {
let ctx = self.agent_session_context(req.session_id, user_id).await?;
let conversation_id = req
.conversation_id
.ok_or_else(|| AppError::BadRequest("conversation_id is required".to_string()))?;
let conversation = self
.agent_require_conversation_access(user_id, conversation_id)
.await?;
if conversation.session != ctx.session_id {
return Err(AppError::BadRequest(
"conversation does not belong to session".to_string(),
));
}
let ai_client = self.agent_build_ai_client(ctx.model_version_id).await?;
let agent_config = self.agent_build_config(&ctx, req.max_steps);
self.agent_maybe_compact(&ai_client, &ctx.provider_model_name, conversation_id)
.await
.unwrap_or_else(|e| {
warn!(error = %e, "compaction check failed, continuing");
});
let mut tools: ToolRegister<AppAgentContext> = ToolRegister::new();
if conversation.title == "New Chat" || conversation.title.trim().is_empty() {
tools.register(super::tools::SetTitleTool::new());
}
tools.register(super::memory::SaveMemoryTool::new());
let (memories_text, _memory_rows) =
self.agent_load_memories(ctx.session_id).await?;
if !memories_text.is_empty() {
tools.register(super::memory::RecallMemoryTool::new(memories_text));
}
// Git RPC tools
super::git_tools::register_git_tools(&mut tools);
super::workspace_tools::register_workspace_tools(&mut tools);
super::issue_tools::register_issue_tools(&mut tools);
let (tx, rx) = mpsc::unbounded_channel::<String>();
let agent = RigAgent::new(ai_client.clone(), agent_config)
.map_err(|e| AppError::AiError(e))?;
let timeout_secs = req.timeout_secs.unwrap_or(300);
let agent_request = self
.agent_build_request(
&ai_client,
&ctx,
req.conversation_id,
req.input.clone(),
Some(timeout_secs),
)
.await?;
let invocation_id = Uuid::now_v7();
let ctx_clone = ctx.clone();
let self_clone = self.clone();
info!(
invocation_id = %invocation_id,
session_id = %ctx.session_id,
user_id = %user_id,
"agent sse stream starting"
);
if let Err(e) = self.cache.set::<Uuid>(
&format!("agent:stream:active:{}", conversation_id),
&invocation_id,
).await {
warn!(error = %e, "agent sse: failed to mark stream active");
}
let user_message_id = match self
.persist_user_message(conversation_id, user_id, &req.input)
.await
{
Ok(id) => Some(id),
Err(e) => {
let _ = tx.send(super::persistence::stream_error("failed to persist user message"));
let _ = self.cache
.remove(&format!("agent:stream:active:{}", conversation_id))
.await;
return Err(e);
}
};
let first_input = req.input.clone();
let shared_ctx = Arc::new(tokio::sync::Mutex::new(AppAgentContext {
user_id,
session_id: ctx.session_id,
conversation_id,
pending_title: None,
pending_memories: Vec::new(),
git: Some(super::run::GitAgentContext {
channel: self.git.clone(),
db: self.db.clone(),
cache: self.cache.clone(),
}),
}));
let mut tool_set = RigToolSet::from_register(&tools, shared_ctx);
let rig_tools = tool_set.take_tools();
let (mut chunk_rx, handle) = agent.run(agent_request, rig_tools);
let trace_svc = self.clone();
tokio::spawn(async move {
let mut tracer = super::trace::TraceAccumulator::new(
trace_svc, invocation_id, conversation_id,
);
let mut phase: &str = "think";
while let Some(chunk) = chunk_rx.recv().await {
let (new_phase, sse_event) = process_chunk_with_phase(&chunk, phase, &mut tracer).await;
if new_phase != phase {
phase = new_phase;
let _ = tx.send(phase_sse(phase));
}
if let Some(sse) = sse_event {
let _ = tx.send(sse);
}
}
let agent_result = match tokio::time::timeout(
std::time::Duration::from_secs(timeout_secs),
handle,
)
.await
{
Ok(Ok(inner)) => inner,
Ok(Err(e)) => Err(ai::error::AiError::Response(e.to_string())),
Err(_) => Err(ai::error::AiError::Timeout { seconds: timeout_secs }),
};
let _ = self_clone.cache
.remove(&format!("agent:stream:active:{}", conversation_id))
.await;
let agent_ctx = tool_set.into_context();
match agent_result {
Ok(result) => {
let reasoning_content: Option<String> = {
let collected: Vec<String> = result
.steps
.iter()
.filter_map(|step| step.reasoning_content.clone())
.collect();
if collected.is_empty() { None } else { Some(collected.join("\n\n")) }
};
match self_clone
.persist_assistant_message(
conversation_id,
ctx_clone.session_id,
&result.output,
reasoning_content.as_deref(),
invocation_id,
)
.await
{
Ok(msg_id) => {
for step in &result.steps {
for tc in &step.tool_calls {
let _ = self_clone
.agent_record_tool_call(
invocation_id, ctx_clone.session_id,
Some(conversation_id), Some(msg_id),
&tc.id, &tc.name,
Some(&tc.arguments.to_string()),
tc.output.as_ref().map(|v| v.to_string()).as_deref(),
tc.error.as_deref(),
if tc.error.is_some() { "error" } else { "success" },
tc.elapsed_ms,
)
.await;
}
}
let _ = self_clone.persist_billing_and_deduct(
&ctx_clone, invocation_id,
result.input_tokens, result.output_tokens,
).await;
let _ = self_clone.agent_record_invocation(
invocation_id, ctx_clone.session_id,
Some(conversation_id), Some(msg_id),
ctx_clone.model_version_id, "completed", None,
).await;
let _ = self_clone.update_conversation_timestamp(conversation_id).await;
let title = agent_ctx.pending_title
.filter(|t| !t.trim().is_empty())
.or_else(|| {
// Only auto-set title from input when still default.
if conversation.title == "New Chat" || conversation.title.trim().is_empty() {
let first_line = first_input.lines().next().unwrap_or(&first_input);
let truncated: String = first_line.chars().take(50).collect();
if truncated.trim().is_empty() { None }
else { Some(if first_line.len() > 50 { format!("{}", truncated.trim_end()) } else { truncated.trim().to_string() }) }
} else {
None
}
});
if let Some(new_title) = &title {
if self_clone.update_conversation_title(conversation_id, new_title).await.is_ok() {
let title_event = serde_json::json!({
"type": "title_updated",
"conversation_id": conversation_id.to_string(),
"title": new_title,
});
let _ = tx.send(format!("data: {}\n\n", title_event));
}
}
if !agent_ctx.pending_memories.is_empty() {
let _ = self_clone.agent_persist_memories(
ctx_clone.session_id, &agent_ctx.pending_memories,
).await;
}
let _ = tx.send(done_sse_with_phase(msg_id, &result.output, "summarize"));
info!(invocation_id = %invocation_id, message_id = %msg_id, "agent sse stream completed");
}
Err(e) => {
error!(error = %e, "sse: failed to persist assistant message");
let _ = tx.send(super::persistence::stream_error("persistence failed"));
}
}
}
Err(e) => {
warn!(invocation_id = %invocation_id, error = %e, "agent sse stream failed");
let _ = tx.send(super::persistence::stream_error(&e.to_string()));
let error_content = format!(
"I encountered an error while processing your request: {}", e
);
let _ = self_clone.persist_assistant_message(
conversation_id, ctx_clone.session_id, &error_content, None, invocation_id,
).await;
let _ = self_clone.agent_record_invocation(
invocation_id, ctx_clone.session_id,
Some(conversation_id), user_message_id,
ctx_clone.model_version_id, "failed", Some(&e.to_string()),
).await;
}
}
});
Ok(rx)
}
}
/// Feeds one agent stream chunk into the trace accumulator and maps it to a UI phase
/// plus the SSE frame (if any) that should be forwarded to the client.
///
/// Phase mapping:
/// * `Thinking` → `"think"`, `TextDelta` → `"answer"`;
/// * tool-call and subagent start/completion chunks → `"act"`;
/// * `Final`, `Failed` and `SubagentFailed` → `"summarize"`.
///
/// `Final` and `Failed` yield no frame. `format_chunk_sse` renders them as an empty string,
/// and the terminal `done` / error frame is produced separately from the run's result.
///
/// `_current_phase` is currently unused and kept for call-site compatibility.
async fn process_chunk_with_phase(
    chunk: &RigStreamChunk,
    _current_phase: &str,
    tracer: &mut super::trace::TraceAccumulator,
) -> (&'static str, Option<String>) {
    match chunk {
        RigStreamChunk::Thinking { content, .. } => {
            tracer.feed_thinking(content).await;
            ("think", Some(format_chunk_sse(chunk)))
        }
        RigStreamChunk::TextDelta { content, .. } => {
            tracer.feed_text(content).await;
            ("answer", Some(format_chunk_sse(chunk)))
        }
        RigStreamChunk::ToolCallStarted { tool_call_id, tool_name, arguments } => {
            // Unparseable argument JSON is traced as `null` rather than failing the stream.
            let args_val: Value = serde_json::from_str(arguments).unwrap_or(Value::Null);
            tracer.feed_tool_call(tool_call_id, tool_name, &args_val).await;
            ("act", Some(format_chunk_sse(chunk)))
        }
        RigStreamChunk::ToolCallFinished { tool_call_id, tool_name, output, error } => {
            // Empty or non-JSON output is traced as absent.
            let out_val: Option<Value> = match output {
                o if o.is_empty() => None,
                o => serde_json::from_str(o).ok(),
            };
            // This chunk variant carries no timing, so elapsed time is recorded as 0.
            tracer
                .feed_tool_result(tool_call_id, tool_name, out_val.as_ref(), error.as_deref(), 0)
                .await;
            ("act", Some(format_chunk_sse(chunk)))
        }
        RigStreamChunk::Final { content, input_tokens, output_tokens } => {
            tracer
                .finish(content, *input_tokens as i64, *output_tokens as i64)
                .await;
            // Do not forward `format_chunk_sse`'s empty rendering of `Final` to the client.
            ("summarize", None)
        }
        RigStreamChunk::Failed { .. } => ("summarize", None),
        RigStreamChunk::SubagentStarted { .. } => ("act", Some(format_chunk_sse(chunk))),
        RigStreamChunk::SubagentCompleted { .. } => ("act", Some(format_chunk_sse(chunk))),
        RigStreamChunk::SubagentFailed { .. } => ("summarize", Some(format_chunk_sse(chunk))),
    }
}
fn format_chunk_sse(chunk: &RigStreamChunk) -> String {
let payload = match chunk {
RigStreamChunk::TextDelta { index, content } => json!({
"type": "delta", "index": index, "content": content,
}),
RigStreamChunk::Thinking { index, content } => json!({
"type": "thinking", "index": index, "content": content,
}),
RigStreamChunk::ToolCallStarted { tool_call_id, tool_name, arguments } => json!({
"type": "tool_call_started",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"arguments": arguments,
}),
RigStreamChunk::ToolCallFinished { tool_call_id, tool_name, output, error } => json!({
"type": "tool_call_finished",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"output": output,
"error": error,
}),
RigStreamChunk::SubagentStarted { subagent_id, role, task } => json!({
"type": "subagent_started",
"subagent_id": subagent_id,
"role": role,
"task": task,
}),
RigStreamChunk::SubagentCompleted { subagent_id, role, task, output } => json!({
"type": "subagent_completed",
"subagent_id": subagent_id,
"role": role,
"task": task,
"output": output,
}),
RigStreamChunk::SubagentFailed { error } => json!({
"type": "subagent_failed",
"error": error,
}),
RigStreamChunk::Final { .. } | RigStreamChunk::Failed { .. } => return String::new(),
};
format!("data: {}\n\n", payload)
}
fn phase_sse(phase: &str) -> String {
let payload = json!({
"type": "phase_change",
"phase": phase,
"label": phase_label(phase),
});
format!("data: {}\n\n", payload)
}
/// Maps an internal phase id (`think`, `answer`, `act`, `summarize`) to its display label.
///
/// Unknown ids are returned unchanged.
fn phase_label(phase: &str) -> &str {
    const LABELS: [(&str, &str); 4] = [
        ("think", "Thinking"),
        ("answer", "Answering"),
        ("act", "Acting"),
        ("summarize", "Summarizing"),
    ];
    LABELS
        .iter()
        .find(|&&(id, _)| id == phase)
        .map_or(phase, |&(_, label)| label)
}
fn done_sse_with_phase(message_id: Uuid, output: &str, phase: &str) -> String {
let payload = json!({
"type": "done",
"message_id": message_id.to_string(),
"status": "completed",
"phase": phase,
"label": phase_label(phase),
"output": output,
});
format!("data: {}\n\n", payload)
}