555 lines
20 KiB
Rust
555 lines
20 KiB
Rust
use std::sync::Arc;
|
|
|
|
use ai::agent::{RigAgent, RigStreamChunk, RigToolSet};
|
|
use ai::tool::register::ToolRegister;
|
|
use serde_json::{Value, json};
|
|
use tokio::sync::mpsc;
|
|
use tracing::{error, info, warn};
|
|
use uuid::Uuid;
|
|
|
|
use super::run::AppAgentContext;
|
|
use super::types::AgentRunRequest;
|
|
use crate::AppService;
|
|
use crate::error::AppError;
|
|
|
|
impl AppService {
|
|
pub async fn agent_run_streaming(
|
|
&self,
|
|
user_id: Uuid,
|
|
req: AgentRunRequest,
|
|
) -> Result<mpsc::UnboundedReceiver<String>, AppError> {
|
|
let ctx = self.agent_session_context(req.session_id, user_id).await?;
|
|
let conversation_id = req.conversation_id.ok_or_else(|| {
|
|
AppError::BadRequest("conversation_id is required".to_string())
|
|
})?;
|
|
let conversation = self
|
|
.agent_require_conversation_access(user_id, conversation_id)
|
|
.await?;
|
|
if conversation.session != ctx.session_id {
|
|
return Err(AppError::BadRequest(
|
|
"conversation does not belong to session".to_string(),
|
|
));
|
|
}
|
|
|
|
let ai_client =
|
|
self.agent_build_ai_client(ctx.model_version_id).await?;
|
|
let agent_config = self.agent_build_config(&ctx, req.max_steps);
|
|
|
|
self.agent_maybe_compact(
|
|
&ai_client,
|
|
&ctx.provider_model_name,
|
|
conversation_id,
|
|
)
|
|
.await
|
|
.unwrap_or_else(|e| {
|
|
warn!(error = %e, "compaction check failed, continuing");
|
|
});
|
|
|
|
let mut tools: ToolRegister<AppAgentContext> = ToolRegister::new();
|
|
if conversation.title == "New Chat"
|
|
|| conversation.title.trim().is_empty()
|
|
{
|
|
tools.register(super::tools::SetTitleTool::new());
|
|
}
|
|
tools.register(super::memory::SaveMemoryTool::new());
|
|
let (memories_text, _memory_rows) =
|
|
self.agent_load_memories(ctx.session_id).await?;
|
|
if !memories_text.is_empty() {
|
|
tools.register(super::memory::RecallMemoryTool::new(memories_text));
|
|
}
|
|
|
|
// Git RPC tools
|
|
super::git_tools::register_git_tools(&mut tools);
|
|
super::workspace_tools::register_workspace_tools(&mut tools);
|
|
super::issue_tools::register_issue_tools(&mut tools);
|
|
|
|
let (tx, rx) = mpsc::unbounded_channel::<String>();
|
|
|
|
let agent = RigAgent::new(ai_client.clone(), agent_config)
|
|
.map_err(|e| AppError::AiError(e))?;
|
|
|
|
let timeout_secs = req.timeout_secs.unwrap_or(300);
|
|
|
|
let agent_request = self
|
|
.agent_build_request(
|
|
&ai_client,
|
|
&ctx,
|
|
req.conversation_id,
|
|
req.input.clone(),
|
|
Some(timeout_secs),
|
|
)
|
|
.await?;
|
|
|
|
let invocation_id = Uuid::now_v7();
|
|
let ctx_clone = ctx.clone();
|
|
let model_name = ctx.provider_model_name.clone();
|
|
let self_clone = self.clone();
|
|
|
|
info!(
|
|
invocation_id = %invocation_id,
|
|
session_id = %ctx.session_id,
|
|
user_id = %user_id,
|
|
"agent sse stream starting"
|
|
);
|
|
|
|
if let Err(e) = self
|
|
.cache
|
|
.set::<Uuid>(
|
|
&format!("agent:stream:active:{}", conversation_id),
|
|
&invocation_id,
|
|
)
|
|
.await
|
|
{
|
|
warn!(error = %e, "agent sse: failed to mark stream active");
|
|
}
|
|
|
|
let user_message_id = match self
|
|
.persist_user_message(conversation_id, user_id, &req.input)
|
|
.await
|
|
{
|
|
Ok(id) => Some(id),
|
|
Err(e) => {
|
|
let _ = tx.send(super::persistence::stream_error(
|
|
"failed to persist user message",
|
|
));
|
|
let _ = self
|
|
.cache
|
|
.remove(&format!("agent:stream:active:{}", conversation_id))
|
|
.await;
|
|
return Err(e);
|
|
}
|
|
};
|
|
|
|
let first_input = req.input.clone();
|
|
|
|
let shared_ctx = Arc::new(tokio::sync::Mutex::new(AppAgentContext {
|
|
user_id,
|
|
session_id: ctx.session_id,
|
|
conversation_id,
|
|
pending_title: None,
|
|
pending_memories: Vec::new(),
|
|
git: Some(super::run::GitAgentContext {
|
|
channel: self.git.clone(),
|
|
db: self.db.clone(),
|
|
cache: self.cache.clone(),
|
|
}),
|
|
}));
|
|
|
|
let mut tool_set = RigToolSet::from_register(&tools, shared_ctx);
|
|
let rig_tools = tool_set.take_tools();
|
|
let (mut chunk_rx, handle) = agent.run(agent_request, rig_tools);
|
|
|
|
let trace_svc = self.clone();
|
|
|
|
tokio::spawn(async move {
|
|
let mut tracer = super::trace::TraceAccumulator::new(
|
|
trace_svc,
|
|
invocation_id,
|
|
conversation_id,
|
|
);
|
|
let mut phase: &str = "think";
|
|
while let Some(chunk) = chunk_rx.recv().await {
|
|
let (new_phase, sse_event) =
|
|
process_chunk_with_phase(&chunk, phase, &mut tracer).await;
|
|
if new_phase != phase {
|
|
phase = new_phase;
|
|
let _ = tx.send(phase_sse(phase));
|
|
}
|
|
if let Some(sse) = sse_event {
|
|
let _ = tx.send(sse);
|
|
}
|
|
}
|
|
let agent_result = match tokio::time::timeout(
|
|
std::time::Duration::from_secs(timeout_secs),
|
|
handle,
|
|
)
|
|
.await
|
|
{
|
|
Ok(Ok(inner)) => inner,
|
|
Ok(Err(e)) => Err(ai::error::AiError::Response(e.to_string())),
|
|
Err(_) => Err(ai::error::AiError::Timeout {
|
|
seconds: timeout_secs,
|
|
}),
|
|
};
|
|
|
|
let _ = self_clone
|
|
.cache
|
|
.remove(&format!("agent:stream:active:{}", conversation_id))
|
|
.await;
|
|
|
|
let agent_ctx = tool_set.into_context();
|
|
|
|
match agent_result {
|
|
Ok(result) => {
|
|
self_clone.metrics.record_ai_run(&model_name, "completed");
|
|
self_clone.metrics.record_ai_token_usage(
|
|
&model_name,
|
|
result.input_tokens,
|
|
result.output_tokens,
|
|
);
|
|
for step in &result.steps {
|
|
for tc in &step.tool_calls {
|
|
let status = if tc.error.is_some() {
|
|
"error"
|
|
} else {
|
|
"success"
|
|
};
|
|
self_clone
|
|
.metrics
|
|
.record_ai_tool_call(&tc.name, status);
|
|
}
|
|
}
|
|
|
|
let reasoning_content: Option<String> = {
|
|
let collected: Vec<String> = result
|
|
.steps
|
|
.iter()
|
|
.filter_map(|step| step.reasoning_content.clone())
|
|
.collect();
|
|
if collected.is_empty() {
|
|
None
|
|
} else {
|
|
Some(collected.join("\n\n"))
|
|
}
|
|
};
|
|
|
|
match self_clone
|
|
.persist_assistant_message(
|
|
conversation_id,
|
|
ctx_clone.session_id,
|
|
&result.output,
|
|
reasoning_content.as_deref(),
|
|
invocation_id,
|
|
)
|
|
.await
|
|
{
|
|
Ok(msg_id) => {
|
|
for step in &result.steps {
|
|
for tc in &step.tool_calls {
|
|
let _ = self_clone
|
|
.agent_record_tool_call(
|
|
invocation_id,
|
|
ctx_clone.session_id,
|
|
Some(conversation_id),
|
|
Some(msg_id),
|
|
&tc.id,
|
|
&tc.name,
|
|
Some(&tc.arguments.to_string()),
|
|
tc.output
|
|
.as_ref()
|
|
.map(|v| v.to_string())
|
|
.as_deref(),
|
|
tc.error.as_deref(),
|
|
if tc.error.is_some() {
|
|
"error"
|
|
} else {
|
|
"success"
|
|
},
|
|
tc.elapsed_ms,
|
|
)
|
|
.await;
|
|
}
|
|
}
|
|
|
|
let _ = self_clone
|
|
.persist_billing_and_deduct(
|
|
&ctx_clone,
|
|
invocation_id,
|
|
result.input_tokens,
|
|
result.output_tokens,
|
|
)
|
|
.await;
|
|
|
|
let _ = self_clone
|
|
.agent_record_invocation(
|
|
invocation_id,
|
|
ctx_clone.session_id,
|
|
Some(conversation_id),
|
|
Some(msg_id),
|
|
ctx_clone.model_version_id,
|
|
"completed",
|
|
None,
|
|
)
|
|
.await;
|
|
|
|
let _ = self_clone
|
|
.update_conversation_timestamp(conversation_id)
|
|
.await;
|
|
|
|
let title = agent_ctx
|
|
.pending_title
|
|
.filter(|t| !t.trim().is_empty())
|
|
.or_else(|| {
|
|
// Only auto-set title from input when still default.
|
|
if conversation.title == "New Chat"
|
|
|| conversation.title.trim().is_empty()
|
|
{
|
|
let first_line = first_input
|
|
.lines()
|
|
.next()
|
|
.unwrap_or(&first_input);
|
|
let truncated: String = first_line
|
|
.chars()
|
|
.take(50)
|
|
.collect();
|
|
if truncated.trim().is_empty() {
|
|
None
|
|
} else {
|
|
Some(if first_line.len() > 50 {
|
|
format!(
|
|
"{}…",
|
|
truncated.trim_end()
|
|
)
|
|
} else {
|
|
truncated.trim().to_string()
|
|
})
|
|
}
|
|
} else {
|
|
None
|
|
}
|
|
});
|
|
if let Some(new_title) = &title {
|
|
if self_clone
|
|
.update_conversation_title(
|
|
conversation_id,
|
|
new_title,
|
|
)
|
|
.await
|
|
.is_ok()
|
|
{
|
|
let title_event = serde_json::json!({
|
|
"type": "title_updated",
|
|
"conversation_id": conversation_id.to_string(),
|
|
"title": new_title,
|
|
});
|
|
let _ = tx.send(format!(
|
|
"data: {}\n\n",
|
|
title_event
|
|
));
|
|
}
|
|
}
|
|
|
|
if !agent_ctx.pending_memories.is_empty() {
|
|
let _ = self_clone
|
|
.agent_persist_memories(
|
|
ctx_clone.session_id,
|
|
&agent_ctx.pending_memories,
|
|
)
|
|
.await;
|
|
}
|
|
|
|
let _ = tx.send(done_sse_with_phase(
|
|
msg_id,
|
|
&result.output,
|
|
"summarize",
|
|
));
|
|
info!(invocation_id = %invocation_id, message_id = %msg_id, "agent sse stream completed");
|
|
}
|
|
Err(e) => {
|
|
error!(error = %e, "sse: failed to persist assistant message");
|
|
let _ = tx.send(super::persistence::stream_error(
|
|
"persistence failed",
|
|
));
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
self_clone.metrics.record_ai_run(&model_name, "failed");
|
|
warn!(invocation_id = %invocation_id, error = %e, "agent sse stream failed");
|
|
let _ = tx
|
|
.send(super::persistence::stream_error(&e.to_string()));
|
|
let error_content = format!(
|
|
"I encountered an error while processing your request: {}",
|
|
e
|
|
);
|
|
let _ = self_clone
|
|
.persist_assistant_message(
|
|
conversation_id,
|
|
ctx_clone.session_id,
|
|
&error_content,
|
|
None,
|
|
invocation_id,
|
|
)
|
|
.await;
|
|
let _ = self_clone
|
|
.agent_record_invocation(
|
|
invocation_id,
|
|
ctx_clone.session_id,
|
|
Some(conversation_id),
|
|
user_message_id,
|
|
ctx_clone.model_version_id,
|
|
"failed",
|
|
Some(&e.to_string()),
|
|
)
|
|
.await;
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(rx)
|
|
}
|
|
}
|
|
async fn process_chunk_with_phase(
|
|
chunk: &RigStreamChunk,
|
|
_current_phase: &str,
|
|
tracer: &mut super::trace::TraceAccumulator,
|
|
) -> (&'static str, Option<String>) {
|
|
match chunk {
|
|
RigStreamChunk::Thinking { content, .. } => {
|
|
tracer.feed_thinking(content).await;
|
|
("think", Some(format_chunk_sse(chunk)))
|
|
}
|
|
RigStreamChunk::TextDelta { content, .. } => {
|
|
tracer.feed_text(content).await;
|
|
("answer", Some(format_chunk_sse(chunk)))
|
|
}
|
|
RigStreamChunk::ToolCallStarted {
|
|
tool_call_id,
|
|
tool_name,
|
|
arguments,
|
|
} => {
|
|
let args_val: Value =
|
|
serde_json::from_str(arguments).unwrap_or(Value::Null);
|
|
tracer
|
|
.feed_tool_call(tool_call_id, tool_name, &args_val)
|
|
.await;
|
|
("act", Some(format_chunk_sse(chunk)))
|
|
}
|
|
RigStreamChunk::ToolCallFinished {
|
|
tool_call_id,
|
|
tool_name,
|
|
output,
|
|
error,
|
|
} => {
|
|
let out_val: Option<Value> = match output {
|
|
o if o.is_empty() => None,
|
|
o => serde_json::from_str(o).ok(),
|
|
};
|
|
tracer
|
|
.feed_tool_result(
|
|
tool_call_id,
|
|
tool_name,
|
|
out_val.as_ref(),
|
|
error.as_deref(),
|
|
0,
|
|
)
|
|
.await;
|
|
("act", Some(format_chunk_sse(chunk)))
|
|
}
|
|
RigStreamChunk::Final {
|
|
content,
|
|
input_tokens,
|
|
output_tokens,
|
|
} => {
|
|
tracer
|
|
.finish(content, *input_tokens as i64, *output_tokens as i64)
|
|
.await;
|
|
("summarize", Some(format_chunk_sse(chunk)))
|
|
}
|
|
RigStreamChunk::Failed { .. } => ("summarize", None),
|
|
RigStreamChunk::SubagentStarted { .. } => {
|
|
("act", Some(format_chunk_sse(chunk)))
|
|
}
|
|
RigStreamChunk::SubagentCompleted { .. } => {
|
|
("act", Some(format_chunk_sse(chunk)))
|
|
}
|
|
RigStreamChunk::SubagentFailed { .. } => {
|
|
("summarize", Some(format_chunk_sse(chunk)))
|
|
}
|
|
}
|
|
}
|
|
|
|
fn format_chunk_sse(chunk: &RigStreamChunk) -> String {
|
|
let payload = match chunk {
|
|
RigStreamChunk::TextDelta { index, content } => json!({
|
|
"type": "delta", "index": index, "content": content,
|
|
}),
|
|
RigStreamChunk::Thinking { index, content } => json!({
|
|
"type": "thinking", "index": index, "content": content,
|
|
}),
|
|
RigStreamChunk::ToolCallStarted {
|
|
tool_call_id,
|
|
tool_name,
|
|
arguments,
|
|
} => json!({
|
|
"type": "tool_call_started",
|
|
"tool_call_id": tool_call_id,
|
|
"tool_name": tool_name,
|
|
"arguments": arguments,
|
|
}),
|
|
RigStreamChunk::ToolCallFinished {
|
|
tool_call_id,
|
|
tool_name,
|
|
output,
|
|
error,
|
|
} => json!({
|
|
"type": "tool_call_finished",
|
|
"tool_call_id": tool_call_id,
|
|
"tool_name": tool_name,
|
|
"output": output,
|
|
"error": error,
|
|
}),
|
|
RigStreamChunk::SubagentStarted {
|
|
subagent_id,
|
|
role,
|
|
task,
|
|
} => json!({
|
|
"type": "subagent_started",
|
|
"subagent_id": subagent_id,
|
|
"role": role,
|
|
"task": task,
|
|
}),
|
|
RigStreamChunk::SubagentCompleted {
|
|
subagent_id,
|
|
role,
|
|
task,
|
|
output,
|
|
} => json!({
|
|
"type": "subagent_completed",
|
|
"subagent_id": subagent_id,
|
|
"role": role,
|
|
"task": task,
|
|
"output": output,
|
|
}),
|
|
RigStreamChunk::SubagentFailed { error } => json!({
|
|
"type": "subagent_failed",
|
|
"error": error,
|
|
}),
|
|
RigStreamChunk::Final { .. } | RigStreamChunk::Failed { .. } => {
|
|
return String::new();
|
|
}
|
|
};
|
|
format!("data: {}\n\n", payload)
|
|
}
|
|
|
|
fn phase_sse(phase: &str) -> String {
|
|
let payload = json!({
|
|
"type": "phase_change",
|
|
"phase": phase,
|
|
"label": phase_label(phase),
|
|
});
|
|
format!("data: {}\n\n", payload)
|
|
}
|
|
|
|
/// Maps a phase identifier to its display label.
///
/// Unknown phases are returned unchanged.
fn phase_label(phase: &str) -> &str {
    const LABELS: [(&str, &str); 4] = [
        ("think", "Thinking"),
        ("answer", "Answering"),
        ("act", "Acting"),
        ("summarize", "Summarizing"),
    ];
    LABELS
        .iter()
        .find(|(key, _)| *key == phase)
        .map_or(phase, |&(_, label)| label)
}
|
|
|
|
fn done_sse_with_phase(message_id: Uuid, output: &str, phase: &str) -> String {
|
|
let payload = json!({
|
|
"type": "done",
|
|
"message_id": message_id.to_string(),
|
|
"status": "completed",
|
|
"phase": phase,
|
|
"label": phase_label(phase),
|
|
"output": output,
|
|
});
|
|
format!("data: {}\n\n", payload)
|
|
}
|