gitdataai/lib/service/agent/sse.rs
2026-06-01 22:04:38 +08:00

555 lines
20 KiB
Rust

use std::sync::Arc;
use ai::agent::{RigAgent, RigStreamChunk, RigToolSet};
use ai::tool::register::ToolRegister;
use serde_json::{Value, json};
use tokio::sync::mpsc;
use tracing::{error, info, warn};
use uuid::Uuid;
use super::run::AppAgentContext;
use super::types::AgentRunRequest;
use crate::AppService;
use crate::error::AppError;
impl AppService {
pub async fn agent_run_streaming(
&self,
user_id: Uuid,
req: AgentRunRequest,
) -> Result<mpsc::UnboundedReceiver<String>, AppError> {
let ctx = self.agent_session_context(req.session_id, user_id).await?;
let conversation_id = req.conversation_id.ok_or_else(|| {
AppError::BadRequest("conversation_id is required".to_string())
})?;
let conversation = self
.agent_require_conversation_access(user_id, conversation_id)
.await?;
if conversation.session != ctx.session_id {
return Err(AppError::BadRequest(
"conversation does not belong to session".to_string(),
));
}
let ai_client =
self.agent_build_ai_client(ctx.model_version_id).await?;
let agent_config = self.agent_build_config(&ctx, req.max_steps);
self.agent_maybe_compact(
&ai_client,
&ctx.provider_model_name,
conversation_id,
)
.await
.unwrap_or_else(|e| {
warn!(error = %e, "compaction check failed, continuing");
});
let mut tools: ToolRegister<AppAgentContext> = ToolRegister::new();
if conversation.title == "New Chat"
|| conversation.title.trim().is_empty()
{
tools.register(super::tools::SetTitleTool::new());
}
tools.register(super::memory::SaveMemoryTool::new());
let (memories_text, _memory_rows) =
self.agent_load_memories(ctx.session_id).await?;
if !memories_text.is_empty() {
tools.register(super::memory::RecallMemoryTool::new(memories_text));
}
// Git RPC tools
super::git_tools::register_git_tools(&mut tools);
super::workspace_tools::register_workspace_tools(&mut tools);
super::issue_tools::register_issue_tools(&mut tools);
let (tx, rx) = mpsc::unbounded_channel::<String>();
let agent = RigAgent::new(ai_client.clone(), agent_config)
.map_err(|e| AppError::AiError(e))?;
let timeout_secs = req.timeout_secs.unwrap_or(300);
let agent_request = self
.agent_build_request(
&ai_client,
&ctx,
req.conversation_id,
req.input.clone(),
Some(timeout_secs),
)
.await?;
let invocation_id = Uuid::now_v7();
let ctx_clone = ctx.clone();
let model_name = ctx.provider_model_name.clone();
let self_clone = self.clone();
info!(
invocation_id = %invocation_id,
session_id = %ctx.session_id,
user_id = %user_id,
"agent sse stream starting"
);
if let Err(e) = self
.cache
.set::<Uuid>(
&format!("agent:stream:active:{}", conversation_id),
&invocation_id,
)
.await
{
warn!(error = %e, "agent sse: failed to mark stream active");
}
let user_message_id = match self
.persist_user_message(conversation_id, user_id, &req.input)
.await
{
Ok(id) => Some(id),
Err(e) => {
let _ = tx.send(super::persistence::stream_error(
"failed to persist user message",
));
let _ = self
.cache
.remove(&format!("agent:stream:active:{}", conversation_id))
.await;
return Err(e);
}
};
let first_input = req.input.clone();
let shared_ctx = Arc::new(tokio::sync::Mutex::new(AppAgentContext {
user_id,
session_id: ctx.session_id,
conversation_id,
pending_title: None,
pending_memories: Vec::new(),
git: Some(super::run::GitAgentContext {
channel: self.git.clone(),
db: self.db.clone(),
cache: self.cache.clone(),
}),
}));
let mut tool_set = RigToolSet::from_register(&tools, shared_ctx);
let rig_tools = tool_set.take_tools();
let (mut chunk_rx, handle) = agent.run(agent_request, rig_tools);
let trace_svc = self.clone();
tokio::spawn(async move {
let mut tracer = super::trace::TraceAccumulator::new(
trace_svc,
invocation_id,
conversation_id,
);
let mut phase: &str = "think";
while let Some(chunk) = chunk_rx.recv().await {
let (new_phase, sse_event) =
process_chunk_with_phase(&chunk, phase, &mut tracer).await;
if new_phase != phase {
phase = new_phase;
let _ = tx.send(phase_sse(phase));
}
if let Some(sse) = sse_event {
let _ = tx.send(sse);
}
}
let agent_result = match tokio::time::timeout(
std::time::Duration::from_secs(timeout_secs),
handle,
)
.await
{
Ok(Ok(inner)) => inner,
Ok(Err(e)) => Err(ai::error::AiError::Response(e.to_string())),
Err(_) => Err(ai::error::AiError::Timeout {
seconds: timeout_secs,
}),
};
let _ = self_clone
.cache
.remove(&format!("agent:stream:active:{}", conversation_id))
.await;
let agent_ctx = tool_set.into_context();
match agent_result {
Ok(result) => {
self_clone.metrics.record_ai_run(&model_name, "completed");
self_clone.metrics.record_ai_token_usage(
&model_name,
result.input_tokens,
result.output_tokens,
);
for step in &result.steps {
for tc in &step.tool_calls {
let status = if tc.error.is_some() {
"error"
} else {
"success"
};
self_clone
.metrics
.record_ai_tool_call(&tc.name, status);
}
}
let reasoning_content: Option<String> = {
let collected: Vec<String> = result
.steps
.iter()
.filter_map(|step| step.reasoning_content.clone())
.collect();
if collected.is_empty() {
None
} else {
Some(collected.join("\n\n"))
}
};
match self_clone
.persist_assistant_message(
conversation_id,
ctx_clone.session_id,
&result.output,
reasoning_content.as_deref(),
invocation_id,
)
.await
{
Ok(msg_id) => {
for step in &result.steps {
for tc in &step.tool_calls {
let _ = self_clone
.agent_record_tool_call(
invocation_id,
ctx_clone.session_id,
Some(conversation_id),
Some(msg_id),
&tc.id,
&tc.name,
Some(&tc.arguments.to_string()),
tc.output
.as_ref()
.map(|v| v.to_string())
.as_deref(),
tc.error.as_deref(),
if tc.error.is_some() {
"error"
} else {
"success"
},
tc.elapsed_ms,
)
.await;
}
}
let _ = self_clone
.persist_billing_and_deduct(
&ctx_clone,
invocation_id,
result.input_tokens,
result.output_tokens,
)
.await;
let _ = self_clone
.agent_record_invocation(
invocation_id,
ctx_clone.session_id,
Some(conversation_id),
Some(msg_id),
ctx_clone.model_version_id,
"completed",
None,
)
.await;
let _ = self_clone
.update_conversation_timestamp(conversation_id)
.await;
let title = agent_ctx
.pending_title
.filter(|t| !t.trim().is_empty())
.or_else(|| {
// Only auto-set title from input when still default.
if conversation.title == "New Chat"
|| conversation.title.trim().is_empty()
{
let first_line = first_input
.lines()
.next()
.unwrap_or(&first_input);
let truncated: String = first_line
.chars()
.take(50)
.collect();
if truncated.trim().is_empty() {
None
} else {
Some(if first_line.len() > 50 {
format!(
"{}",
truncated.trim_end()
)
} else {
truncated.trim().to_string()
})
}
} else {
None
}
});
if let Some(new_title) = &title {
if self_clone
.update_conversation_title(
conversation_id,
new_title,
)
.await
.is_ok()
{
let title_event = serde_json::json!({
"type": "title_updated",
"conversation_id": conversation_id.to_string(),
"title": new_title,
});
let _ = tx.send(format!(
"data: {}\n\n",
title_event
));
}
}
if !agent_ctx.pending_memories.is_empty() {
let _ = self_clone
.agent_persist_memories(
ctx_clone.session_id,
&agent_ctx.pending_memories,
)
.await;
}
let _ = tx.send(done_sse_with_phase(
msg_id,
&result.output,
"summarize",
));
info!(invocation_id = %invocation_id, message_id = %msg_id, "agent sse stream completed");
}
Err(e) => {
error!(error = %e, "sse: failed to persist assistant message");
let _ = tx.send(super::persistence::stream_error(
"persistence failed",
));
}
}
}
Err(e) => {
self_clone.metrics.record_ai_run(&model_name, "failed");
warn!(invocation_id = %invocation_id, error = %e, "agent sse stream failed");
let _ = tx
.send(super::persistence::stream_error(&e.to_string()));
let error_content = format!(
"I encountered an error while processing your request: {}",
e
);
let _ = self_clone
.persist_assistant_message(
conversation_id,
ctx_clone.session_id,
&error_content,
None,
invocation_id,
)
.await;
let _ = self_clone
.agent_record_invocation(
invocation_id,
ctx_clone.session_id,
Some(conversation_id),
user_message_id,
ctx_clone.model_version_id,
"failed",
Some(&e.to_string()),
)
.await;
}
}
});
Ok(rx)
}
}
/// Feeds one stream chunk into the trace accumulator and classifies it.
///
/// Returns the phase the chunk belongs to (`"think"`, `"answer"`, `"act"` or
/// `"summarize"`) together with the SSE frame to forward to the client, or
/// `None` when the chunk has no client-facing frame.
///
/// `Final` and `Failed` chunks yield `None`: `format_chunk_sse` has no payload
/// for them and returns an empty string, which must not be forwarded as a
/// frame.
///
/// `_current_phase` is accepted but not consulted; the phase is derived from
/// the chunk alone.
async fn process_chunk_with_phase(
    chunk: &RigStreamChunk,
    _current_phase: &str,
    tracer: &mut super::trace::TraceAccumulator,
) -> (&'static str, Option<String>) {
    match chunk {
        RigStreamChunk::Thinking { content, .. } => {
            tracer.feed_thinking(content).await;
            ("think", Some(format_chunk_sse(chunk)))
        }
        RigStreamChunk::TextDelta { content, .. } => {
            tracer.feed_text(content).await;
            ("answer", Some(format_chunk_sse(chunk)))
        }
        RigStreamChunk::ToolCallStarted {
            tool_call_id,
            tool_name,
            arguments,
        } => {
            // Arguments that are not valid JSON are traced as `null`.
            let args_val: Value =
                serde_json::from_str(arguments).unwrap_or(Value::Null);
            tracer
                .feed_tool_call(tool_call_id, tool_name, &args_val)
                .await;
            ("act", Some(format_chunk_sse(chunk)))
        }
        RigStreamChunk::ToolCallFinished {
            tool_call_id,
            tool_name,
            output,
            error,
        } => {
            // Empty or non-JSON output is traced as "no output".
            let out_val: Option<Value> = if output.is_empty() {
                None
            } else {
                serde_json::from_str(output).ok()
            };
            // NOTE(review): the last argument is a literal 0 — presumably an
            // elapsed time the chunk does not carry; confirm.
            tracer
                .feed_tool_result(
                    tool_call_id,
                    tool_name,
                    out_val.as_ref(),
                    error.as_deref(),
                    0,
                )
                .await;
            ("act", Some(format_chunk_sse(chunk)))
        }
        RigStreamChunk::Final {
            content,
            input_tokens,
            output_tokens,
        } => {
            tracer
                .finish(content, *input_tokens as i64, *output_tokens as i64)
                .await;
            // No frame here: the formatter produces nothing for `Final`.
            ("summarize", None)
        }
        RigStreamChunk::Failed { .. } => ("summarize", None),
        RigStreamChunk::SubagentStarted { .. }
        | RigStreamChunk::SubagentCompleted { .. } => {
            ("act", Some(format_chunk_sse(chunk)))
        }
        RigStreamChunk::SubagentFailed { .. } => {
            ("summarize", Some(format_chunk_sse(chunk)))
        }
    }
}
fn format_chunk_sse(chunk: &RigStreamChunk) -> String {
let payload = match chunk {
RigStreamChunk::TextDelta { index, content } => json!({
"type": "delta", "index": index, "content": content,
}),
RigStreamChunk::Thinking { index, content } => json!({
"type": "thinking", "index": index, "content": content,
}),
RigStreamChunk::ToolCallStarted {
tool_call_id,
tool_name,
arguments,
} => json!({
"type": "tool_call_started",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"arguments": arguments,
}),
RigStreamChunk::ToolCallFinished {
tool_call_id,
tool_name,
output,
error,
} => json!({
"type": "tool_call_finished",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"output": output,
"error": error,
}),
RigStreamChunk::SubagentStarted {
subagent_id,
role,
task,
} => json!({
"type": "subagent_started",
"subagent_id": subagent_id,
"role": role,
"task": task,
}),
RigStreamChunk::SubagentCompleted {
subagent_id,
role,
task,
output,
} => json!({
"type": "subagent_completed",
"subagent_id": subagent_id,
"role": role,
"task": task,
"output": output,
}),
RigStreamChunk::SubagentFailed { error } => json!({
"type": "subagent_failed",
"error": error,
}),
RigStreamChunk::Final { .. } | RigStreamChunk::Failed { .. } => {
return String::new();
}
};
format!("data: {}\n\n", payload)
}
fn phase_sse(phase: &str) -> String {
let payload = json!({
"type": "phase_change",
"phase": phase,
"label": phase_label(phase),
});
format!("data: {}\n\n", payload)
}
/// Maps a phase key to its display label.
///
/// Unknown keys are returned unchanged.
fn phase_label(phase: &str) -> &str {
    const LABELS: [(&str, &str); 4] = [
        ("think", "Thinking"),
        ("answer", "Answering"),
        ("act", "Acting"),
        ("summarize", "Summarizing"),
    ];
    LABELS
        .iter()
        .find(|(key, _)| *key == phase)
        .map_or(phase, |&(_, label)| label)
}
fn done_sse_with_phase(message_id: Uuid, output: &str, phase: &str) -> String {
let payload = json!({
"type": "done",
"message_id": message_id.to_string(),
"status": "completed",
"phase": phase,
"label": phase_label(phase),
"output": output,
});
format!("data: {}\n\n", payload)
}