From 5b3a6700be89360d2f95c31e13bfc59053b5fb31 Mon Sep 17 00:00:00 2001 From: ZhenYi <434836402@qq.com> Date: Tue, 28 Apr 2026 09:42:36 +0800 Subject: [PATCH] refactor(agent): replace custom ReAct loop with rig::agent::Agent MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use AgentBuilder for native tool-calling with stream_prompt() - Add RecordingTool wrapper preserving retry + DB recording - Fix tool_choice bug in do_completion (same as call_stream_once) - Add seq field to RoomMessageStreamChunkEvent for strict ordering - Map streaming events: Text→Answer, Reasoning→Thought, ToolCall→Action - Only final event has done=true, removed premature stream ending - Store __chunks__ JSON in thinking_content for ordered replay --- libs/agent/chat/service.rs | 575 +++++++++++++++----- libs/agent/client/mod.rs | 70 ++- libs/agent/lib.rs | 9 +- libs/agent/react/hooks.rs | 130 ----- libs/agent/react/loop_core.rs | 413 -------------- libs/agent/react/mod.rs | 33 +- libs/agent/tool/recorder.rs | 131 +++++ libs/agent/tool/rig_adapter.rs | 162 +++++- libs/api/room/ws_universal.rs | 1 + libs/queue/types.rs | 3 + libs/room/src/service/ai_react_streaming.rs | 53 +- 11 files changed, 828 insertions(+), 752 deletions(-) delete mode 100644 libs/agent/react/hooks.rs delete mode 100644 libs/agent/react/loop_core.rs create mode 100644 libs/agent/tool/recorder.rs diff --git a/libs/agent/chat/service.rs b/libs/agent/chat/service.rs index 255df1f..b0e41e1 100644 --- a/libs/agent/chat/service.rs +++ b/libs/agent/chat/service.rs @@ -1,22 +1,31 @@ -use std::pin::Pin; -use std::sync::Arc; -use std::time::Duration; +use futures::StreamExt; use models::projects::project_skill; use models::rooms::room_ai; -use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; +use rig::agent::{AgentBuilder, MultiTurnStreamItem}; +use rig::client::CompletionClient; +use rig::streaming::{StreamedAssistantContent, StreamingPrompt}; +use sea_orm::*; +use std::pin::Pin; +use std::sync::Arc; use uuid::Uuid; use super::context::RoomMessageContext; use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback}; -use crate::client::types::{ChatRequestMessage, ToolCall}; use crate::client::AiClientConfig; -use crate::client::{call_stream, call_with_params, StreamChunk, StreamChunkType, StreamedToolCall}; +use crate::client::types::{ChatRequestMessage, ToolCall}; +use crate::client::{ + StreamChunk, StreamChunkType, StreamedToolCall, call_stream, call_with_params, +}; use crate::compact::{CompactConfig, CompactService}; use crate::embed::EmbedService; use crate::error::{AgentError, Result}; use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent}; -use crate::react::{ReactAgent, ReactConfig, DEFAULT_SYSTEM_PROMPT}; -use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor, ToolResult, registry::ToolRegistry}; +use crate::react::{DEFAULT_SYSTEM_PROMPT, ReactStep}; +use crate::react::types::Action as ReactAction; +use crate::tool::{ + RecordingTool, ToolCall as AgentToolCall, ToolContext, ToolExecutor, + registry::ToolRegistry, +}; /// Result from streaming AI response. pub struct StreamResult { @@ -104,9 +113,12 @@ impl ChatService { config: config::AppConfig, room_id: uuid::Uuid, sender_id: Option, + project_id: uuid::Uuid, ) -> Option { self.tool_registry.as_ref().map(|registry| { - crate::RigToolSet::from_registry(registry, db, cache, config, room_id, sender_id) + crate::RigToolSet::from_registry( + registry, db, cache, config, room_id, sender_id, project_id, + ) }) } @@ -140,11 +152,16 @@ impl ChatService { let mut tool_depth = 0; let mut input_tokens = 0i64; let mut output_tokens = 0i64; + let session_id = Uuid::new_v4(); + let session_start = std::time::Instant::now(); + let version_id = room_ai.as_ref().and_then(|r| r.version); - let config = AiClientConfig::new( - self.ai_api_key.clone().unwrap_or_default(), - ) - .with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into())); + let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default()) + .with_base_url( + self.ai_base_url + .clone() + .unwrap_or_else(|| "https://api.openai.com".into()), + ); loop { let response = call_with_params( @@ -183,9 +200,10 @@ impl ChatService { }) .collect(); - messages.push( - ChatRequestMessage::assistant(Some(text.clone()), Some(tool_call_messages.clone())) - ); + messages.push(ChatRequestMessage::assistant( + Some(text.clone()), + Some(tool_call_messages.clone()), + )); // Create ToolCall list for executor (we need real IDs and args) // Since we can't get args from streaming, use name matching from the text @@ -210,15 +228,69 @@ impl ChatService { if let Some(ref registry) = self.tool_registry { ctx.registry_mut().merge(registry.clone()); } + + let recorder = crate::tool::recorder::ToolCallRecorder::with_session( + request.db.clone(), + session_id, + ); + let start = std::time::Instant::now(); + let executor = ToolExecutor::new(); match executor.execute_batch(calls, &mut ctx).await { - Ok(results) => ToolExecutor::to_tool_messages(&results), + Ok(results) => { + for (call, result) in + response.tool_calls_finished.iter().zip(results.iter()) + { + let elapsed = start.elapsed().as_millis() as i64; + let is_error = + matches!(result.result, crate::tool::ToolResult::Error(_)); + let error_msg = match &result.result { + crate::tool::ToolResult::Error(msg) => Some(msg.clone()), + _ => None, + }; + recorder.record(crate::tool::recorder::ToolCallRecord { + tool_call_id: call.clone(), + session_id: recorder.session_id(), + tool_name: call.clone(), + caller: request.sender.uid, + arguments: serde_json::Value::Null, + status: if is_error { + models::ai::ToolCallStatus::Failed + } else { + models::ai::ToolCallStatus::Success + }, + execution_time_ms: Some(elapsed), + error_message: error_msg, + error_stack: None, + retry_count: 0, + }); + } + crate::tool::ToolExecutor::to_tool_messages(&results) + } Err(e) => { + let elapsed = start.elapsed().as_millis() as i64; + for call_name in &response.tool_calls_finished { + recorder.record(crate::tool::recorder::ToolCallRecord { + tool_call_id: Uuid::new_v4().to_string(), + session_id: recorder.session_id(), + tool_name: call_name.clone(), + caller: request.sender.uid, + arguments: serde_json::Value::Null, + status: models::ai::ToolCallStatus::Failed, + execution_time_ms: Some(elapsed), + error_message: Some(e.to_string()), + error_stack: None, + retry_count: 0, + }); + } + let err_msg = format!("[Tool call failed: {}]", e); response .tool_calls_finished .iter() - .map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg)) + .map(|_| { + ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg) + }) .collect() } } @@ -250,8 +322,10 @@ impl ChatService { }) .collect(); for event in &tool_events { - if let Some(ctx) = - self.perception_service.passive.detect(event, &skill_entries) + if let Some(ctx) = self + .perception_service + .passive + .detect(event, &skill_entries) { messages.push(ctx.to_system_message()); } @@ -268,16 +342,62 @@ impl ChatService { } else { text }; - return Ok(ProcessResult { content, input_tokens, output_tokens }); + // Record session + let _ = models::ai::ai_session::ActiveModel { + id: Set(session_id), + room: Set(request.room.id), + model: Set(request.model.id), + version: Set(version_id.unwrap_or_default()), + token_input: Set(input_tokens), + token_output: Set(output_tokens), + latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), + cost: Set(None), + currency: Set(None), + error_message: Set(None), + error_code: Set(None), + created_at: Set(chrono::Utc::now()), + } + .insert(&request.db) + .await; + return Ok(ProcessResult { + content, + input_tokens, + output_tokens, + }); } continue; } - return Ok(ProcessResult { content: text, input_tokens, output_tokens }); + // Record session + let _ = models::ai::ai_session::ActiveModel { + id: Set(session_id), + room: Set(request.room.id), + model: Set(request.model.id), + version: Set(version_id.unwrap_or_default()), + token_input: Set(input_tokens), + token_output: Set(output_tokens), + latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), + cost: Set(None), + currency: Set(None), + error_message: Set(None), + error_code: Set(None), + created_at: Set(chrono::Utc::now()), + } + .insert(&request.db) + .await; + return Ok(ProcessResult { + content: text, + input_tokens, + output_tokens, + }); } } - pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result { + pub async fn process_stream( + &self, + request: AiRequest, + on_chunk: StreamCallback, + ) -> Result { // Wrap on_chunk in Arc so it can be shared across loop iterations let on_chunk = Arc::new(on_chunk); let tools: Vec = request.tools.clone().unwrap_or_default(); @@ -302,11 +422,19 @@ impl ChatService { .and_then(|r| r.max_tokens.map(|v| v as u32)) .unwrap_or(request.max_tokens as u32); let mut tool_depth = 0; + let mut total_input_tokens = 0i64; + let mut total_output_tokens = 0i64; + let session_id = Uuid::new_v4(); + let session_start = std::time::Instant::now(); - let config = AiClientConfig::new( - self.ai_api_key.clone().unwrap_or_default(), - ) - .with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into())); + let version_id = room_ai.as_ref().and_then(|r| r.version); + + let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default()) + .with_base_url( + self.ai_base_url + .clone() + .unwrap_or_else(|| "https://api.openai.com".into()), + ); let mut full_content = String::new(); let mut all_chunks: Vec = Vec::new(); @@ -325,6 +453,7 @@ impl ChatService { temperature, max_tokens, if tools_enabled { Some(&tools) } else { None }, + None, // tool_choice — auto (let model decide) Arc::new(move |delta| { let fut = on_chunk_cb(AiStreamChunk { content: delta.to_string(), @@ -351,6 +480,9 @@ impl ChatService { ) .await?; + total_input_tokens += response.input_tokens; + total_output_tokens += response.output_tokens; + // Collect chunks from this streaming iteration in order. all_chunks.extend(response.chunks); @@ -425,23 +557,44 @@ impl ChatService { request.config.clone(), request.room.id, Some(request.sender.uid), - ); + ) + .with_project(request.project.id); if let Some(ref registry) = self.tool_registry { ctx.registry_mut().merge(registry.clone()); } + + let recorder = crate::tool::recorder::ToolCallRecorder::with_session( + request.db.clone(), + session_id, + ); + for call in &calls { + let start = std::time::Instant::now(); let executor = crate::tool::ToolExecutor::new(); let results = match executor.execute_batch(vec![call.clone()], &mut ctx).await { Ok(r) => r, Err(e) => { + let elapsed = start.elapsed().as_millis() as i64; + recorder.record(crate::tool::recorder::ToolCallRecord { + tool_call_id: call.id.clone(), + session_id: recorder.session_id(), + tool_name: call.name.clone(), + caller: request.sender.uid, + arguments: call.arguments_json().unwrap_or_default(), + status: models::ai::ToolCallStatus::Failed, + execution_time_ms: Some(elapsed), + error_message: Some(e.to_string()), + error_stack: None, + retry_count: 0, + }); + let err_text = format!("[Tool call failed: {}]", e); - tracing::warn!(tool = %call.name, error = %e, "tool_call_failed"); - // Do NOT emit tool_result chunks to frontend — show error via tool_call instead + tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed"); let err_display = format!("❌ {} (failed)", call.name); on_chunk(AiStreamChunk { content: err_display.clone(), done: false, - chunk_type: AiChunkType::ToolCall, + chunk_type: AiChunkType::ToolResult, }) .await; all_chunks.push(StreamChunk { @@ -464,6 +617,29 @@ impl ChatService { text.clone() }; tracing::debug!("tool_result: {} — {}", call.name, preview); + + let elapsed = start.elapsed().as_millis() as i64; + let is_error = matches!(result.result, crate::tool::ToolResult::Error(_)); + let error_msg = match &result.result { + crate::tool::ToolResult::Error(msg) => Some(msg.clone()), + _ => None, + }; + recorder.record(crate::tool::recorder::ToolCallRecord { + tool_call_id: call.id.clone(), + session_id: recorder.session_id(), + tool_name: call.name.clone(), + caller: request.sender.uid, + arguments: call.arguments_json().unwrap_or_default(), + status: if is_error { + models::ai::ToolCallStatus::Failed + } else { + models::ai::ToolCallStatus::Success + }, + execution_time_ms: Some(elapsed), + error_message: error_msg, + error_stack: None, + retry_count: 0, + }); // Do NOT emit tool_result chunks to frontend — raw output may contain sensitive data. // Log server-side only; frontend sees tool_call status via on_chunk below. } @@ -471,7 +647,7 @@ impl ChatService { on_chunk(AiStreamChunk { content: success_display.clone(), done: false, - chunk_type: AiChunkType::ToolCall, + chunk_type: AiChunkType::ToolResult, }) .await; all_chunks.push(StreamChunk { @@ -509,8 +685,10 @@ impl ChatService { }) .collect(); for event in &tool_events { - if let Some(ctx) = - self.perception_service.passive.detect(event, &skill_entries) + if let Some(ctx) = self + .perception_service + .passive + .detect(event, &skill_entries) { messages.push(ctx.to_system_message()); } @@ -533,6 +711,23 @@ impl ChatService { chunk_type: StreamChunkType::Answer, content: max_depth_text, }); + // Record session + let _ = models::ai::ai_session::ActiveModel { + id: Set(session_id), + room: Set(request.room.id), + model: Set(request.model.id), + version: Set(version_id.unwrap_or_default()), + token_input: Set(total_input_tokens), + token_output: Set(total_output_tokens), + latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), + cost: Set(None), + currency: Set(None), + error_message: Set(None), + error_code: Set(None), + created_at: Set(chrono::Utc::now()), + } + .insert(&request.db) + .await; return Ok(StreamResult { content: full_content, reasoning_content: String::new(), @@ -557,6 +752,23 @@ impl ChatService { chunk_type: StreamChunkType::Answer, content: response.content.clone(), }); + // Record session + let _ = models::ai::ai_session::ActiveModel { + id: Set(session_id), + room: Set(request.room.id), + model: Set(request.model.id), + version: Set(version_id.unwrap_or_default()), + token_input: Set(total_input_tokens), + token_output: Set(total_output_tokens), + latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)), + cost: Set(None), + currency: Set(None), + error_message: Set(None), + error_code: Set(None), + created_at: Set(chrono::Utc::now()), + } + .insert(&request.db) + .await; return Ok(StreamResult { content: full_content, reasoning_content: response.reasoning_content, @@ -616,7 +828,10 @@ impl ChatService { parts.push(format!("Description: {}", desc)); } parts.push(format!("Default branch: {}", repo.default_branch)); - parts.push(format!("Private: {}", if repo.is_private { "yes" } else { "no" })); + parts.push(format!( + "Private: {}", + if repo.is_private { "yes" } else { "no" } + )); parts.push(format!("Created: {}", repo.created_at.format("%Y-%m-%d"))); messages.push(ChatRequestMessage::system(format!( "Mentioned repository:\n{}", @@ -692,7 +907,11 @@ impl ChatService { "Current Project:\n{}\nDescription: {}\nPublic: {}", request.project.display_name, request.project.description.as_deref().unwrap_or("(none)"), - if request.project.is_public { "yes" } else { "no" } + if request.project.is_public { + "yes" + } else { + "no" + } ))); let mut sender_parts = vec![format!("**Sender:** {}", request.sender.username)]; @@ -773,7 +992,11 @@ impl ChatService { if let Some(embed_service) = &self.embed_service { let awareness = crate::perception::VectorActiveAwareness::default(); vector_skills = awareness - .detect(embed_service, &request.input, &request.project.id.to_string()) + .detect( + embed_service, + &request.input, + &request.project.id.to_string(), + ) .await; } @@ -813,32 +1036,14 @@ impl ChatService { .await } - fn is_retryable_tool_error(msg: &str) -> bool { - let msg_lower = msg.to_lowercase(); - msg_lower.contains("connection") - || msg_lower.contains("timeout") - || msg_lower.contains("timed out") - || msg_lower.contains("rate limit") - || msg_lower.contains("too many") - || msg_lower.contains("unavailable") - || msg_lower.contains("service unavailable") - || msg_lower.contains("temporarily") - || msg_lower.contains("refused") - || msg_lower.contains("reset") - || msg_lower.contains("broken pipe") - || msg_lower.contains("deadline exceeded") - || msg_lower.contains("try again") - } - - pub async fn process_react( - &self, - request: &AiRequest, - mut on_chunk: C, - ) -> Result + pub async fn process_react(&self, request: &AiRequest, mut on_chunk: C) -> Result where C: FnMut(crate::react::ReactStep) + Send, { - let base_url = self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into()); + let base_url = self + .ai_base_url + .clone() + .unwrap_or_else(|| "https://api.openai.com".into()); let api_key = self.ai_api_key.clone().unwrap_or_default(); let client_config = AiClientConfig::new(api_key).with_base_url(base_url); @@ -848,104 +1053,176 @@ impl ChatService { let db = request.db.clone(); let cache = request.cache.clone(); - let config = request.config.clone(); + let cfg = request.config.clone(); let room_id = request.room.id; - let project_id = Some(request.project.id); - let sender_uid = Some(request.sender.uid); - let registry = registry.clone(); + let sender_uid = request.sender.uid; + let project_id = request.project.id; + let session_id = Uuid::new_v4(); + let session_start = std::time::Instant::now(); + let version_id = room_ai::Entity::find() + .filter(room_ai::Column::Room.eq(request.room.id)) + .filter(room_ai::Column::Model.eq(request.model.id)) + .one(&request.db) + .await + .ok() + .flatten() + .and_then(|r| r.version); - let executor: std::sync::Arc< - dyn Fn(String, serde_json::Value) -> Pin> + Send>> - + Send - + Sync, - > = std::sync::Arc::new(move |name: String, args: serde_json::Value| { - let db = db.clone(); - let cache = cache.clone(); - let config = config.clone(); - let room_id = room_id; - let project_id = project_id; - let sender_uid = sender_uid; - let registry = registry.clone(); + // Build rig tools with recording wrapper directly from registry + let mut tools: Vec> = Vec::new(); + for def in registry.definitions() { + let name = def.name.clone(); + if let Some(handler) = registry.get(&name) { + let adapter = crate::tool::RigToolAdapter::new( + handler.clone(), + def.clone(), + db.clone(), + cache.clone(), + cfg.clone(), + room_id, + Some(sender_uid), + project_id, + ); + tools.push(Box::new(RecordingTool::new( + Box::new(adapter), + db.clone(), + session_id, + sender_uid, + ))); + } + } - Box::pin(async move { - let max_retries = 3; - let mut last_err = String::new(); + // Build rig agent (handles multi-turn tool calls natively) + let rig_client = client_config.build_rig_client(); + let model = rig_client.completion_model(&request.model.name); + let agent = AgentBuilder::new(model) + .preamble(DEFAULT_SYSTEM_PROMPT) + .tools(tools) + .default_max_turns(request.max_tool_depth) + .build(); - for attempt in 0..=max_retries { - let mut ctx = ToolContext::new(db.clone(), cache.clone(), config.clone(), room_id, sender_uid); - if let Some(pid) = project_id { - ctx = ctx.with_project(pid); - } - ctx.registry_mut().merge(registry.clone()); + let stream = agent + .stream_prompt(&request.input) + .with_history(Vec::new()) + .multi_turn(request.max_tool_depth) + .await; - let tool_executor = ToolExecutor::new(); - let call = AgentToolCall { - id: Uuid::new_v4().to_string(), - name: name.clone(), - arguments: serde_json::to_string(&args).unwrap_or_else(|_| "{}".into()), - }; + tokio::pin!(stream); - match tool_executor.execute_batch(vec![call], &mut ctx).await { - Ok(results) => { - let result = results.into_iter().next() - .ok_or_else(|| "no tool result returned".to_string())?; - match result.result { - ToolResult::Ok(v) => return Ok(v), - ToolResult::Error(msg) => { - if attempt < max_retries && Self::is_retryable_tool_error(&msg) { - last_err = msg; - let backoff_ms = 100u64.saturating_mul(2u64.pow(attempt as u32)); - tracing::warn!( - tool = %name, - attempt = attempt + 1, - backoff_ms = backoff_ms, - error = %last_err, - "tool_execute_retry" - ); - tokio::time::sleep(Duration::from_millis(backoff_ms)).await; - continue; - } - return Err(msg); - } - } - } - Err(e) => { - last_err = e.to_string(); - if attempt < max_retries && Self::is_retryable_tool_error(&last_err) { - let backoff_ms = 100u64.saturating_mul(2u64.pow(attempt as u32)); - tracing::warn!( - tool = %name, - attempt = attempt + 1, - backoff_ms = backoff_ms, - error = %last_err, - "tool_execute_retry" - ); - tokio::time::sleep(Duration::from_millis(backoff_ms)).await; - continue; - } - return Err(last_err); - } + let mut step_count = 0usize; + let mut final_content = String::new(); + let mut total_input_tokens: i64 = 0; + let mut total_output_tokens: i64 = 0; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::Text(text), + )) => { + step_count += 1; + let t = text.text; + on_chunk(ReactStep::Answer { + step: step_count, + answer: t.clone(), + }); + final_content.push_str(&t); + } + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::Reasoning(reasoning), + )) => { + let reasoning_text = reasoning.reasoning.join(""); + if !reasoning_text.is_empty() { + step_count += 1; + on_chunk(ReactStep::Thought { + step: step_count, + thought: reasoning_text, + }); } } + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ReasoningDelta { reasoning, .. }, + )) => { + if !reasoning.is_empty() { + step_count += 1; + on_chunk(ReactStep::Thought { + step: step_count, + thought: reasoning, + }); + } + } + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCall { tool_call, .. }, + )) => { + step_count += 1; + let args: serde_json::Value = match &tool_call.function.arguments { + serde_json::Value::String(s) => { + serde_json::from_str(s).unwrap_or(serde_json::Value::Null) + } + v => v.clone(), + }; + on_chunk(ReactStep::Action { + step: step_count, + action: ReactAction::new(&tool_call.function.name, args), + }); + } + Ok(MultiTurnStreamItem::StreamUserItem( + rig::streaming::StreamedUserContent::ToolResult { tool_result, .. }, + )) => { + step_count += 1; + let obs = tool_result_content_to_string(&tool_result.content); + on_chunk(ReactStep::Observation { + step: step_count, + observation: obs, + }); + } + Ok(MultiTurnStreamItem::FinalResponse(resp)) => { + let usage = resp.usage(); + total_input_tokens = usage.input_tokens as i64; + total_output_tokens = usage.output_tokens as i64; + // Text was already streamed incrementally via Answer events. + } + Err(e) => { + let err_msg = format!("rig agent stream error: {}", e); + return Err(AgentError::OpenAi(err_msg)); + } + _ => {} + } + } - Err(last_err) - }) as Pin> + Send>> - }); + let elapsed_ms = session_start.elapsed().as_millis() as i64; + let _ = models::ai::ai_session::ActiveModel { + id: Set(session_id), + room: Set(request.room.id), + model: Set(request.model.id), + version: Set(version_id.unwrap_or_default()), + token_input: Set(total_input_tokens), + token_output: Set(total_output_tokens), + latency_ms: Set(Some(elapsed_ms)), + cost: Set(None), + currency: Set(None), + error_message: Set(None), + error_code: Set(None), + created_at: Set(chrono::Utc::now()), + } + .insert(&request.db) + .await; - let tools = self.tools(); - let config = ReactConfig { - max_steps: request.max_tool_depth, - stop_sequences: Vec::new(), - tool_executor: Some(executor), - }; - - let mut agent = ReactAgent::new(DEFAULT_SYSTEM_PROMPT, tools, config); - agent.add_user_message(&request.input); - - agent - .run(&request.model.name, &client_config, |step| { - on_chunk(step); - }) - .await + Ok(final_content) } } + +/// Extract text from rig's ToolResultContent, ignoring images. +fn tool_result_content_to_string(content: &rig::one_or_many::OneOrMany) -> String { + use rig::completion::message::ToolResultContent; + content + .iter() + .filter_map(|item| { + if let ToolResultContent::Text(t) = item { + Some(t.text.clone()) + } else { + None + } + }) + .collect::>() + .join("\n") +} diff --git a/libs/agent/client/mod.rs b/libs/agent/client/mod.rs index d7b9c2c..bd69d11 100644 --- a/libs/agent/client/mod.rs +++ b/libs/agent/client/mod.rs @@ -287,14 +287,6 @@ where .map(|ts| ts.iter().filter_map(to_rig_tool_def).collect()) .unwrap_or_default(); - let tc = match tool_choice { - Some("none") => rig::completion::message::ToolChoice::None, - Some("auto") | None => rig::completion::message::ToolChoice::Auto, - Some(s) => rig::completion::message::ToolChoice::Specific { - function_names: vec![s.to_string()], - }, - }; - let mut builder = model.completion_request(""); if !preamble.is_empty() { @@ -317,7 +309,24 @@ where builder = builder.tools(tool_defs); } - builder = builder.tool_choice(tc); + // Only set tool_choice when explicitly provided (mirrors call_stream_once logic) + if let Some(tc) = tool_choice { + match tc { + "none" => { + builder = builder.tool_choice(rig::completion::message::ToolChoice::None); + } + "auto" => { + builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto); + } + s => { + builder = builder.tool_choice( + rig::completion::message::ToolChoice::Specific { + function_names: vec![s.to_string()], + }, + ); + } + } + } let response = builder.send().await.map_err(|e| AgentError::OpenAi(e.to_string()))?; @@ -498,6 +507,7 @@ pub async fn call_stream( temperature: f32, max_tokens: u32, tools: Option<&[serde_json::Value]>, + tool_choice: Option<&str>, on_text_delta: StreamTextCb, on_reasoning_delta: StreamReasoningCb, on_tool_call: StreamToolCallCb, @@ -506,7 +516,7 @@ pub async fn call_stream( loop { let result = call_stream_once( - messages, model_name, config, temperature, max_tokens, tools, + messages, model_name, config, temperature, max_tokens, tools, tool_choice, on_text_delta.clone(), on_reasoning_delta.clone(), on_tool_call.clone(), ) .await; @@ -542,6 +552,7 @@ async fn call_stream_once( temperature: f32, max_tokens: u32, tools: Option<&[serde_json::Value]>, + tool_choice: Option<&str>, on_text_delta: StreamTextCb, on_reasoning_delta: StreamReasoningCb, on_tool_call: StreamToolCallCb, @@ -581,6 +592,24 @@ async fn call_stream_once( builder = builder.tools(tool_defs); } + if let Some(tc) = tool_choice { + match tc { + "none" => { + builder = builder.tool_choice(rig::completion::message::ToolChoice::None); + } + "auto" => { + builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto); + } + s => { + builder = builder.tool_choice( + rig::completion::message::ToolChoice::Specific { + function_names: vec![s.to_string()], + }, + ); + } + } + } + let stream_fut = async { let mut stream = builder .stream() @@ -592,6 +621,10 @@ async fn call_stream_once( let mut tool_calls: Vec = Vec::new(); let mut chunks: Vec = Vec::new(); + // Some models (e.g. GLM) ignore tool_choice="none" and still emit tool_calls. + // Filter them out so they don't cause spurious tool execution attempts. + let skip_tool_calls = tool_choice == Some("none"); + use std::collections::HashMap; let mut partial_tool_calls: HashMap = HashMap::new(); let mut stream_finished = false; @@ -612,6 +645,10 @@ async fn call_stream_once( tool_call, internal_call_id, }) => { + if skip_tool_calls { + partial_tool_calls.remove(&internal_call_id); + continue; + } let arguments = match &tool_call.function.arguments { serde_json::Value::String(s) => s.clone(), other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()), @@ -638,6 +675,9 @@ async fn call_stream_once( internal_call_id, content: delta_content, }) => { + if skip_tool_calls { + continue; + } use rig::streaming::ToolCallDeltaContent; match delta_content { ToolCallDeltaContent::Name(name) => { @@ -677,8 +717,12 @@ async fn call_stream_once( } Ok(StreamedAssistantContent::Final(response)) => { stream_finished = true; - for (_, tc) in partial_tool_calls.drain() { - tool_calls.push(tc); + if !skip_tool_calls { + for (_, tc) in partial_tool_calls.drain() { + tool_calls.push(tc); + } + } else { + partial_tool_calls.drain(); } if let Some(usage) = response.token_usage() { let in_toks = usage.input_tokens as i64; @@ -700,7 +744,7 @@ async fn call_stream_once( } // Flush any remaining partial tool calls (if stream ended without Final or Final had no usage) - if !stream_finished { + if !stream_finished && !skip_tool_calls { for (_, tc) in partial_tool_calls.drain() { tool_calls.push(tc); } diff --git a/libs/agent/lib.rs b/libs/agent/lib.rs index 40c4eed..34c4c94 100644 --- a/libs/agent/lib.rs +++ b/libs/agent/lib.rs @@ -31,16 +31,13 @@ pub use client::types::ChatRequestMessage; pub use compact::{CompactConfig, CompactLevel, CompactService, CompactSummary, MessageSummary}; pub use embed::{new_embed_client, EmbedClient, EmbedService, QdrantClient, SearchResult}; pub use error::{AgentError, Result}; -pub use react::{ - Hook, HookAction, NoopHook, ReactAgent, ReactConfig, ReactStep, ToolCallAction, TracingHook, - DEFAULT_SYSTEM_PROMPT, -}; +pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT}; pub use tool::{ - ToolCall, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam, + ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam, ToolRegistry, ToolResult, ToolSchema, }; #[cfg(feature = "rig")] pub use agent::RigAgentService; #[cfg(feature = "rig")] -pub use tool::rig_adapter::RigToolSet; +pub use tool::{RigToolSet, RecordingTool, is_retryable_tool_error}; diff --git a/libs/agent/react/hooks.rs b/libs/agent/react/hooks.rs deleted file mode 100644 index b61d405..0000000 --- a/libs/agent/react/hooks.rs +++ /dev/null @@ -1,130 +0,0 @@ -//! Observability hooks for the ReAct agent loop. -//! -//! Hooks allow injecting custom behavior (logging, tracing, filtering, termination) -//! at each step of the reasoning loop without coupling to the core agent logic. -//! -//! Inspired by rig's `PromptHook` trait. -//! -//! # Example -//! -//! ```ignore -//! #[derive(Clone)] -//! struct MyHook; -//! -//! impl Hook for MyHook { -//! async fn on_thought(&self, step: usize, thought: &str) -> HookAction { -//! tracing::info!("[step {}] thinking: {}", step, thought); -//! HookAction::Continue -//! } -//! } -//! -//! let agent = ReactAgent::new(prompt, tools, config).with_hook(MyHook); -//! ``` - -use async_trait::async_trait; - -/// Controls whether the agent loop continues after a hook callback. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum HookAction { - /// Continue processing normally. - Continue, - /// Skip the current step and continue. - Skip, - /// Terminate the loop immediately with the given reason. - Terminate(&'static str), -} - -/// Controls behavior after a tool call hook callback. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ToolCallAction { - /// Execute the tool normally. - Continue, - /// Skip tool execution and inject a custom result. - Skip(String), - /// Terminate the loop with the given reason. - Terminate(&'static str), -} - -/// Default no-op hook that does nothing. -#[derive(Debug, Clone, Copy, Default)] -pub struct NoopHook; - -impl Hook for NoopHook {} - -impl Hook for () {} - -/// A hook that logs everything to stderr using `eprintln`. -/// No external dependencies required. -#[derive(Debug, Clone, Copy, Default)] -pub struct TracingHook; - -impl TracingHook { - pub fn new() -> Self { - Self - } -} - -#[async_trait] -impl Hook for TracingHook { - async fn on_thought(&self, step: usize, thought: &str) -> HookAction { - eprintln!("[step {}] thought: {}", step, thought); - HookAction::Continue - } - - async fn on_tool_call(&self, step: usize, name: &str, args_json: &str) -> ToolCallAction { - eprintln!("[step {}] tool_call: {}({})", step, name, args_json); - ToolCallAction::Continue - } - - async fn on_observation(&self, step: usize, observation: &str) -> HookAction { - eprintln!("[step {}] observation: {}", step, observation); - HookAction::Continue - } - - async fn on_answer(&self, step: usize, answer: &str) -> HookAction { - eprintln!("[step {}] answer: {}", step, answer); - HookAction::Continue - } -} - -/// Hook trait for observing and controlling the ReAct agent loop. -/// -/// Implement this trait to inject custom behavior at each step: -/// - Log thoughts, tool calls, observations, and final answers -/// - Filter or redact sensitive data -/// - Dynamically terminate the loop based on content -/// - Inject custom tool results (e.g., for testing or sandboxing) -/// -/// All methods have default no-op implementations, so you only need to -/// override the ones you care about. -/// -/// The hook is called synchronously during the agent loop. Keep hook -/// callbacks fast — avoid blocking I/O. For heavy work, spawn a task -/// and return immediately. -#[async_trait] -pub trait Hook: Send + Sync { - /// Called when the agent emits a thought/reasoning step. - /// - /// Return `HookAction::Terminate` to stop the loop early. - async fn on_thought(&self, _step: usize, _thought: &str) -> HookAction { - HookAction::Continue - } - - /// Called just before a tool is executed. - /// - /// Return `ToolCallAction::Skip(result)` to skip execution and inject `result` instead. - /// Return `ToolCallAction::Terminate` to stop the loop without executing the tool. - async fn on_tool_call(&self, _step: usize, _name: &str, _args_json: &str) -> ToolCallAction { - ToolCallAction::Continue - } - - /// Called after a tool returns an observation. - async fn on_observation(&self, _step: usize, _observation: &str) -> HookAction { - HookAction::Continue - } - - /// Called when the agent produces a final answer. - async fn on_answer(&self, _step: usize, _answer: &str) -> HookAction { - HookAction::Continue - } -} diff --git a/libs/agent/react/loop_core.rs b/libs/agent/react/loop_core.rs deleted file mode 100644 index 9edd6cb..0000000 --- a/libs/agent/react/loop_core.rs +++ /dev/null @@ -1,413 +0,0 @@ -//! ReAct (Reasoning + Acting) agent core. - -use uuid::Uuid; - -use std::sync::Arc; - -use crate::call_with_params; -use crate::client::types::ChatRequestMessage; -use crate::error::{AgentError, Result}; -use crate::react::hooks::{Hook, HookAction, NoopHook, ToolCallAction}; -use crate::react::types::{Action, ReactConfig, ReactStep}; - -pub use crate::react::types::{ReactConfig as ReActConfig, ReactStep as ReActStep}; - -/// A ReAct agent that performs multi-step tool-augmented reasoning. -#[derive(Clone)] -pub struct ReactAgent { - messages: Vec, - #[allow(dead_code)] - tool_definitions: Vec, - config: ReactConfig, - step_count: usize, - hook: Arc, -} - -impl ReactAgent { - /// Create a new agent with a system prompt and tool definitions (as JSON values). - pub fn new( - system_prompt: &str, - tools: Vec, - config: ReactConfig, - ) -> Self { - let messages = vec![ChatRequestMessage::system(system_prompt)]; - Self { - messages, - tool_definitions: tools, - config, - step_count: 0, - hook: Arc::new(NoopHook), - } - } - - /// Add an initial user message to the conversation. - pub fn add_user_message(&mut self, content: &str) { - self.messages.push(ChatRequestMessage::user(content)); - } - - /// Attach a hook to observe and control the agent loop. - /// - /// Hooks can log steps, filter content, inject custom tool results, - /// or terminate the loop early. Multiple `.with_hook()` calls replace - /// the previous hook. - pub fn with_hook(mut self, hook: H) -> Self { - self.hook = Arc::new(hook); - self - } - - /// Run the ReAct loop until a final answer is produced or `max_steps` is reached. - pub async fn run( - &mut self, - model_name: &str, - client_config: &crate::client::AiClientConfig, - mut on_chunk: C, - ) -> Result - where - C: FnMut(ReactStep) + Send, - { - loop { - if self.step_count >= self.config.max_steps { - let msg = format!( - "Agent reached maximum reasoning steps ({}) without producing a final answer.", - self.config.max_steps - ); - on_chunk(ReactStep::Answer { - step: self.step_count, - answer: msg.clone(), - }); - return Ok(msg); - } - - self.step_count += 1; - let step = self.step_count; - - // For ReAct we force text-only responses so the model follows our JSON-in-text format. - let tool_choice_str = if self.tool_definitions.is_empty() { - None - } else { - Some("none") - }; - - let response = call_with_params( - &self.messages, - model_name, - client_config, - 0.2, // temperature - 4096, // max output tokens - None, - if self.tool_definitions.is_empty() { - None - } else { - Some(&self.tool_definitions) - }, - tool_choice_str, - ) - .await?; - - let parsed = parse_react_response(&response.content); - let answer = parsed.answer.clone(); - let action = parsed.action.clone(); - - on_chunk(ReactStep::Thought { - step, - thought: parsed.thought.clone(), - }); - - match self.hook.on_thought(step, &parsed.thought).await { - HookAction::Terminate(reason) => { - return Err(AgentError::Internal(format!( - "hook terminated at thought step: {}", - reason - ))); - } - HookAction::Skip => {} - HookAction::Continue => {} - } - - // Final answer — emit and return. - if let Some(ans) = answer { - on_chunk(ReactStep::Answer { - step, - answer: ans.clone(), - }); - - match self.hook.on_answer(step, &ans).await { - HookAction::Terminate(reason) => { - return Err(AgentError::Internal(format!( - "hook terminated at answer step: {}", - reason - ))); - } - _ => {} - } - - return Ok(ans); - } - - // No answer — either do a tool call or fall back. - let Some(act) = action else { - let content = response.content.clone(); - on_chunk(ReactStep::Answer { - step, - answer: content.clone(), - }); - - match self.hook.on_answer(step, &content).await { - HookAction::Terminate(reason) => { - return Err(AgentError::Internal(format!( - "hook terminated at fallback answer: {}", - reason - ))); - } - _ => {} - } - - return Ok(content); - }; - - on_chunk(ReactStep::Action { - step, - action: act.clone(), - }); - - let args_json = serde_json::to_string(&act.args).unwrap_or_else(|_| "{}".to_string()); - - match self.hook.on_tool_call(step, &act.name, &args_json).await { - ToolCallAction::Terminate(reason) => { - return Err(AgentError::Internal(format!( - "hook terminated at tool call: {}", - reason - ))); - } - ToolCallAction::Skip(injected_result) => { - let observation = injected_result; - on_chunk(ReactStep::Observation { - step, - observation: observation.clone(), - }); - - match self.hook.on_observation(step, &observation).await { - HookAction::Terminate(reason) => { - return Err(AgentError::Internal(format!( - "hook terminated at observation (injected): {}", - reason - ))); - } - _ => {} - } - - // Append assistant message with tool_calls. - let assistant_msg = build_tool_call_message(&act); - self.messages.push(assistant_msg); - - // Append observation as a tool message. - self.messages.push(ChatRequestMessage::tool(&act.id, observation)); - - continue; - } - ToolCallAction::Continue => {} - } - - // Append the assistant message with tool_calls. - let assistant_msg = build_tool_call_message(&act); - self.messages.push(assistant_msg); - - // Execute the tool. - let observation = match &self.config.tool_executor { - Some(exec) => { - let result = exec(act.name.clone(), act.args.clone()).await; - match result { - Ok(v) => serde_json::to_string(&v).unwrap_or_else(|_| "null".to_string()), - Err(e) => serde_json::json!({ "error": e }).to_string(), - } - } - None => serde_json::json!({ - "error": format!("no tool executor registered for '{}'", act.name) - }) - .to_string(), - }; - - on_chunk(ReactStep::Observation { - step, - observation: observation.clone(), - }); - - match self.hook.on_observation(step, &observation).await { - HookAction::Terminate(reason) => { - return Err(AgentError::Internal(format!( - "hook terminated at observation step: {}", - reason - ))); - } - _ => {} - } - - // Append observation as a tool message. - self.messages.push(ChatRequestMessage::tool(&act.id, observation)); - } - } - - /// Returns the number of steps executed so far. - pub fn steps(&self) -> usize { - self.step_count - } -} - -// --------------------------------------------------------------------------- -// Response parsing -// --------------------------------------------------------------------------- - -struct ParsedReActResponse { - thought: String, - action: Option, - answer: Option, -} - -fn parse_react_response(content: &str) -> ParsedReActResponse { - let json_str = extract_json(content).unwrap_or_else(|| content.trim().to_string()); - - #[derive(serde::Deserialize)] - struct RawStep { - #[serde(default)] - thought: Option, - #[serde(default)] - action: Option, - #[serde(default)] - answer: Option, - #[serde(default)] - name: Option, - #[serde(default, rename = "arguments")] - args: Option, - } - - #[derive(serde::Deserialize)] - struct RawAction { - #[serde(default)] - name: Option, - #[serde(default, rename = "arguments")] - args: Option, - } - - match serde_json::from_str::(&json_str) { - Ok(raw) => { - let thought = raw.thought.unwrap_or_else(|| "Thinking...".to_string()); - let answer = raw.answer; - let action = raw.action.map(|a| Action { - id: Uuid::new_v4().to_string(), - name: a.name.unwrap_or_default(), - args: a.args.unwrap_or(serde_json::Value::Null), - }); - let action = action.or_else(|| { - if raw.name.is_some() || raw.args.is_some() { - Some(Action { - id: Uuid::new_v4().to_string(), - name: raw.name.unwrap_or_default(), - args: raw.args.unwrap_or(serde_json::Value::Null), - }) - } else { - None - } - }); - - ParsedReActResponse { - thought, - action, - answer, - } - } - Err(_) => ParsedReActResponse { - thought: content.to_string(), - action: None, - answer: None, - }, - } -} - -fn extract_json(s: &str) -> Option { - let trimmed = s.trim(); - - if trimmed.starts_with('{') || trimmed.starts_with('[') { - return Some(trimmed.to_string()); - } - - for line in trimmed.lines() { - let line = line.trim(); - if line.starts_with("```json") || line == "```" { - let mut buf = String::new(); - let mut found_start = false; - for l in trimmed.lines() { - let l = l.trim(); - if !found_start && (l == "```json" || l == "```") { - found_start = true; - continue; - } - if found_start && l == "```" { - break; - } - if found_start { - buf.push_str(l); - buf.push('\n'); - } - } - let result = buf.trim().to_string(); - if !result.is_empty() { - return Some(result); - } - } - } - - let chars: Vec = trimmed.chars().collect(); - for i in 0..chars.len() { - let c = chars[i]; - if (c == '{' || c == '[') && i > 0 { - let prev = chars[i - 1]; - if prev.is_alphanumeric() || prev == '_' || prev == '"' || prev == '\'' { - continue; - } - let candidate: String = chars[i..].iter().collect(); - if serde_json::from_str::(&candidate).is_ok() { - return Some(candidate.trim_end().to_string()); - } - let mut depth = 0isize; - let mut in_string = false; - let mut escaped = false; - for (j, c) in candidate.char_indices() { - if escaped { escaped = false; continue; } - if c == '\\' { escaped = true; continue; } - if c == '"' { in_string = !in_string; continue; } - if in_string { continue; } - if c == '{' || c == '[' { depth += 1; } - if c == '}' || c == ']' { depth -= 1; } - if depth == 0 { - let json_end = j + c.len_utf8(); - let trimmed_candidate = &candidate[..json_end]; - if serde_json::from_str::(trimmed_candidate).is_ok() { - return Some(trimmed_candidate.to_string()); - } - } - } - } - } - - None -} - -/// Build an assistant message with tool_calls from an Action. -fn build_tool_call_message(action: &Action) -> ChatRequestMessage { - let fn_arg_str = serde_json::to_string(&action.args).unwrap_or_else(|_| "{}".to_string()); - - ChatRequestMessage { - role: "assistant".into(), - content: Some(format!("Action: {}", action.name)), - name: None, - tool_call_id: None, - tool_calls: Some(vec![crate::client::types::ToolCall { - id: action.id.clone(), - type_: "function".into(), - function: crate::client::types::ToolCallFunction { - name: action.name.clone(), - arguments: fn_arg_str, - }, - }]), - } -} diff --git a/libs/agent/react/mod.rs b/libs/agent/react/mod.rs index 140b5b3..935ce1e 100644 --- a/libs/agent/react/mod.rs +++ b/libs/agent/react/mod.rs @@ -1,18 +1,13 @@ -//! ReAct (Reason + Act) agent loop for structured tool use. +//! ReAct (Reason + Act) agent types. //! -//! The agent alternates between a **thought** phase (reasoning about what to do) -//! and an **action** phase (calling tools). Observations from tool results feed -//! back into the next thought, enabling multi-step reasoning. +//! Provides the step types used by the ReAct callback interface. +//! The actual agent loop is handled by rig's built-in Agent. -pub mod hooks; -pub mod loop_core; pub mod types; -pub use hooks::{Hook, HookAction, NoopHook, ToolCallAction, TracingHook}; -pub use loop_core::ReactAgent; pub use types::{ReactConfig, ReactStep}; -/// Default system prompt for the ReAct agent. +/// Default system prompt for the ReAct agent (used with rig's native tool-calling). /// /// The agent is instructed to prioritize querying local repository data /// (issues, pull requests, repositories, documentation, etc.) before @@ -25,26 +20,6 @@ Always query the platform's local data before guessing or referring to external If local data does not contain the answer, state that clearly before considering external information. -## Response Format - -Respond as JSON: - -1. When you need to look up data: -```json -{ - "thought": "What you need to find and why.", - "action": { "name": "tool_name", "arguments": { ... } } -} -``` - -2. When you have enough information to answer: -```json -{ - "thought": "How you arrived at the answer.", - "answer": "Your final answer." -} -``` - ## Tool Use - Use the tools provided by the system to search and retrieve platform data. diff --git a/libs/agent/tool/recorder.rs b/libs/agent/tool/recorder.rs new file mode 100644 index 0000000..f02df35 --- /dev/null +++ b/libs/agent/tool/recorder.rs @@ -0,0 +1,131 @@ +//! Batch tool call recorder — persists tool call records to `ai_tool_call` table. +//! +//! Uses an mpsc channel + background flush loop to batch-insert records, +//! reducing DB pressure from individual inserts. +//! +//! Flush triggers: +//! - Buffer reaches `BATCH_SIZE` (default 50) +//! - `FLUSH_INTERVAL` (default 5s) elapses with non-empty buffer +//! - Sender is dropped (remaining records flushed on channel close) + +use std::time::Duration; + +use db::database::AppDatabase; +use models::ai::ai_tool_call; +use models::ai::ToolCallStatus; +use sea_orm::*; +use tokio::sync::mpsc; +use uuid::Uuid; + +const FLUSH_INTERVAL: Duration = Duration::from_secs(5); +const BATCH_SIZE: usize = 50; + +/// A single tool call record to be persisted. +#[derive(Debug, Clone)] +pub struct ToolCallRecord { + pub tool_call_id: String, + pub session_id: Uuid, + pub tool_name: String, + pub caller: Uuid, + pub arguments: serde_json::Value, + pub status: ToolCallStatus, + pub execution_time_ms: Option, + pub error_message: Option, + pub error_stack: Option, + pub retry_count: i32, +} + +/// Channel-based batched recorder. Cheap to clone — all clones share the same sender. +#[derive(Clone)] +pub struct ToolCallRecorder { + tx: mpsc::UnboundedSender, + session_id: Uuid, +} + +impl ToolCallRecorder { + /// Create a new recorder with an auto-generated session ID + /// and spawn a background flush loop. + pub fn new(db: AppDatabase) -> Self { + Self::with_session(db, Uuid::new_v4()) + } + + /// Create a new recorder with a specific session ID + /// (so tool call records can be linked to an `AiSession`). + pub fn with_session(db: AppDatabase, session_id: Uuid) -> Self { + let (tx, rx) = mpsc::unbounded_channel(); + tokio::spawn(flush_loop(db, rx)); + Self { tx, session_id } + } + + /// The session ID shared by all tool calls recorded through this instance. + pub fn session_id(&self) -> Uuid { + self.session_id + } + + /// Enqueue a tool call record for batch persistence. + pub fn record(&self, record: ToolCallRecord) { + let _ = self.tx.send(record); + } +} + +async fn flush_loop(db: AppDatabase, mut rx: mpsc::UnboundedReceiver) { + let mut buffer = Vec::with_capacity(BATCH_SIZE); + let mut ticker = tokio::time::interval(FLUSH_INTERVAL); + ticker.tick().await; // skip first immediate tick + + loop { + tokio::select! { + Some(record) = rx.recv() => { + buffer.push(record); + if buffer.len() >= BATCH_SIZE { + flush(&db, &mut buffer).await; + } + } + _ = ticker.tick() => { + if !buffer.is_empty() { + flush(&db, &mut buffer).await; + } + } + else => { + // Channel closed — flush remaining and exit + if !buffer.is_empty() { + flush(&db, &mut buffer).await; + } + break; + } + } + } +} + +async fn flush(db: &AppDatabase, buffer: &mut Vec) { + let now = chrono::Utc::now(); + let models: Vec = buffer + .iter() + .map(|r| { + let status = r.status.to_string(); + ai_tool_call::ActiveModel { + tool_call_id: Set(r.tool_call_id.clone()), + session: Set(r.session_id), + tool_name: Set(r.tool_name.clone()), + caller: Set(r.caller), + arguments: Set(r.arguments.clone()), + result: Set(serde_json::Value::Null), + status: Set(status), + execution_time_ms: Set(r.execution_time_ms), + error_message: Set(r.error_message.clone()), + error_stack: Set(r.error_stack.clone()), + retry_count: Set(r.retry_count), + created_at: Set(now), + completed_at: Set(Some(now)), + updated_at: Set(now), + } + }) + .collect(); + + let count = models.len(); + if let Err(e) = ai_tool_call::Entity::insert_many(models).exec(db).await { + tracing::warn!(error = %e, count, "failed_to_flush_tool_call_records"); + } + + buffer.clear(); +} diff --git a/libs/agent/tool/rig_adapter.rs b/libs/agent/tool/rig_adapter.rs index c356c9b..9c47d3b 100644 --- a/libs/agent/tool/rig_adapter.rs +++ b/libs/agent/tool/rig_adapter.rs @@ -4,6 +4,7 @@ //! to implement rig's ToolDyn trait, enabling integration with rig's Agent. use std::collections::HashMap; +use std::time::{Duration, Instant}; use futures::FutureExt; use rig::completion::ToolDefinition; @@ -11,8 +12,146 @@ use rig::tool::{ToolDyn, ToolError, ToolSet}; use super::context::ToolContext; use super::definition::ToolDefinition as AgentToolDefinition; +use super::recorder::{ToolCallRecord, ToolCallRecorder}; use super::registry::{ToolHandler, ToolRegistry}; +/// Returns true if the tool error message indicates a transient failure that can be retried. +pub fn is_retryable_tool_error(msg: &str) -> bool { + let lower = msg.to_lowercase(); + lower.contains("retry") + || lower.contains("timeout") + || lower.contains("rate limit") + || lower.contains("too many requests") + || lower.contains("unavailable") + || lower.contains("connection refused") + || lower.contains("5") + || lower.contains("try again") +} + +/// Wraps a ToolDyn with automatic retry and tool call recording. +/// +/// Used by the rig Agent path to replace the custom ReAct executor closure. +pub struct RecordingTool { + inner: Box, + db: db::database::AppDatabase, + session_id: uuid::Uuid, + caller: uuid::Uuid, +} + +impl RecordingTool { + pub fn new( + inner: Box, + db: db::database::AppDatabase, + session_id: uuid::Uuid, + caller: uuid::Uuid, + ) -> Self { + Self { inner, db, session_id, caller } + } +} + +impl ToolDyn for RecordingTool { + fn name(&self) -> String { + self.inner.name() + } + + fn definition<'a>( + &'a self, + prompt: String, + ) -> std::pin::Pin + Send + 'a>> { + self.inner.definition(prompt) + } + + fn call<'a>( + &'a self, + args: String, + ) -> std::pin::Pin> + Send + 'a>> { + let inner: &'a Box = &self.inner; + let db = self.db.clone(); + let session_id = self.session_id; + let caller = self.caller; + let tool_name = inner.name(); + + Box::pin(async move { + let recorder = ToolCallRecorder::with_session(db.clone(), session_id); + let max_retries = 3u32; + let mut last_err = String::new(); + let start = Instant::now(); + + for attempt in 0..=max_retries { + let attempt_start = Instant::now(); + let attempt_args = args.clone(); + let attempt_result = inner.call(attempt_args).await; + + let elapsed_ms = attempt_start.elapsed().as_millis() as i64; + let args_json: serde_json::Value = + serde_json::from_str(&args).unwrap_or_default(); + + match attempt_result { + Ok(value) => { + recorder.record(ToolCallRecord { + tool_call_id: tool_name.clone(), + session_id, + tool_name: tool_name.clone(), + caller, + arguments: args_json, + status: models::ai::ToolCallStatus::Success, + execution_time_ms: Some(elapsed_ms), + error_message: None, + error_stack: None, + retry_count: attempt as i32, + }); + return Ok(value); + } + Err(e) => { + let err_msg = e.to_string(); + if attempt < max_retries && is_retryable_tool_error(&err_msg) { + last_err = err_msg; + let backoff_ms = + 100u64.saturating_mul(2u64.pow(attempt as u32)); + tokio::time::sleep(Duration::from_millis(backoff_ms)).await; + continue; + } + recorder.record(ToolCallRecord { + tool_call_id: tool_name.clone(), + session_id, + tool_name: tool_name.clone(), + caller, + arguments: args_json, + status: models::ai::ToolCallStatus::Failed, + execution_time_ms: Some(elapsed_ms), + error_message: Some(err_msg.clone()), + error_stack: None, + retry_count: attempt as i32, + }); + return Err(e); + } + } + } + + // Fallback: record failure after all retries exhausted + let elapsed_ms = start.elapsed().as_millis() as i64; + let args_json: serde_json::Value = + serde_json::from_str(&args).unwrap_or_default(); + recorder.record(ToolCallRecord { + tool_call_id: tool_name.clone(), + session_id, + tool_name: tool_name.clone(), + caller, + arguments: args_json, + status: models::ai::ToolCallStatus::Failed, + execution_time_ms: Some(elapsed_ms), + error_message: Some(last_err), + error_stack: None, + retry_count: max_retries as i32, + }); + Err(ToolError::ToolCallError(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "max retries exceeded", + )))) + }) + } +} + /// A wrapper that converts our ToolRegistry to rig's ToolSet. pub struct RigToolSet { /// The rig ToolSet @@ -30,6 +169,7 @@ impl RigToolSet { config: config::AppConfig, room_id: uuid::Uuid, sender_id: Option, + project_id: uuid::Uuid, ) -> Self { let mut toolset = ToolSet::default(); let mut definitions = HashMap::new(); @@ -50,6 +190,7 @@ impl RigToolSet { config: config.clone(), room_id, sender_id, + project_id, }; toolset.add_tool(adapter); } @@ -85,6 +226,23 @@ pub struct RigToolAdapter { config: config::AppConfig, room_id: uuid::Uuid, sender_id: Option, + project_id: uuid::Uuid, +} + +impl RigToolAdapter { + /// Create a new RigToolAdapter with all required context. + pub fn new( + handler: ToolHandler, + definition: AgentToolDefinition, + db: db::database::AppDatabase, + cache: db::cache::AppCache, + config: config::AppConfig, + room_id: uuid::Uuid, + sender_id: Option, + project_id: uuid::Uuid, + ) -> Self { + Self { handler, definition, db, cache, config, room_id, sender_id, project_id } + } } impl ToolDyn for RigToolAdapter { @@ -113,6 +271,7 @@ impl ToolDyn for RigToolAdapter { let config = self.config.clone(); let room_id = self.room_id; let sender_id = self.sender_id; + let project_id = self.project_id; async move { let ctx = ToolContext::new( @@ -121,7 +280,8 @@ impl ToolDyn for RigToolAdapter { config, room_id, sender_id, - ); + ) + .with_project(project_id); let args_json: serde_json::Value = serde_json::from_str(&args) .map_err(|e| ToolError::JsonError(e))?; diff --git a/libs/api/room/ws_universal.rs b/libs/api/room/ws_universal.rs index 9cc8813..a2711c3 100644 --- a/libs/api/room/ws_universal.rs +++ b/libs/api/room/ws_universal.rs @@ -272,6 +272,7 @@ pub async fn ws_universal( "data": { "message_id": chunk.message_id, "room_id": chunk.room_id, + "seq": chunk.seq, "content": chunk.content, "done": chunk.done, "error": chunk.error, diff --git a/libs/queue/types.rs b/libs/queue/types.rs index ee3f6e4..d49f002 100644 --- a/libs/queue/types.rs +++ b/libs/queue/types.rs @@ -110,6 +110,9 @@ pub struct ProjectRoomEvent { pub struct RoomMessageStreamChunkEvent { pub message_id: Uuid, pub room_id: Uuid, + /// Monotonically increasing sequence number for ordering within this stream. + #[serde(default)] + pub seq: u64, pub content: String, pub done: bool, pub error: Option, diff --git a/libs/room/src/service/ai_react_streaming.rs b/libs/room/src/service/ai_react_streaming.rs index fd2d783..db4e153 100644 --- a/libs/room/src/service/ai_react_streaming.rs +++ b/libs/room/src/service/ai_react_streaming.rs @@ -54,6 +54,7 @@ pub async fn process_message_ai_react_streaming( let answer_buffer: std::sync::Arc> = std::sync::Arc::new(std::sync::Mutex::new(String::new())); let step_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let chunk_seq: std::sync::Arc = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)); // Helper: recover from poison instead of panicking. fn lock_or_recover(mutex: &std::sync::Mutex) -> std::sync::MutexGuard<'_, T> { @@ -65,6 +66,7 @@ pub async fn process_message_ai_react_streaming( let streaming_msg_id = streaming_msg_id; let room_id = room_id_inner; let step_count = step_count.clone(); + let chunk_seq = chunk_seq.clone(); let ai_display_name_for_step = std::sync::Arc::new(ai_display_name.clone()); let steps = steps.clone(); let answer_buffer = answer_buffer.clone(); @@ -73,18 +75,20 @@ pub async fn process_message_ai_react_streaming( let room_manager = room_manager.clone(); let (chunk_type, content) = match &step { ReactStep::Thought { step: _, thought } => { - ("thinking".to_string(), format!("[Thinking] {}", thought)) + ("thinking".to_string(), thought.clone()) } ReactStep::Action { step: _, action } => { *lock_or_recover(&last_action_name) = action.name.clone(); - ("tool_call".to_string(), format!("[Action] Calling `{}` with {:?}", action.name, action.args)) + ("tool_call".to_string(), serde_json::json!({ + "name": action.name, + "arguments": action.args, + }).to_string()) } ReactStep::Observation { step: _, - observation: _, + observation, } => { - let action_name = lock_or_recover(&last_action_name).clone(); - ("tool_call".to_string(), format!("[Observation] {} (completed)", action_name)) + ("tool_result".to_string(), observation.clone()) } ReactStep::Answer { step: _, answer } => { ("answer".to_string(), answer.clone()) @@ -96,22 +100,33 @@ pub async fn process_message_ai_react_streaming( step_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); } - // Record ordered step for storage + // Record ordered step for storage — merge consecutive same-type chunks + // to ensure strict think→answer→think→answer alternation. { let mut s = lock_or_recover(&steps); - s.push((chunk_type.clone(), content.clone())); + if let Some(last) = s.last_mut() { + if last.0 == chunk_type { + last.1.push_str(&content); + } else { + s.push((chunk_type.clone(), content.clone())); + } + } else { + s.push((chunk_type.clone(), content.clone())); + } } if is_answer { let mut ab = lock_or_recover(&answer_buffer); ab.push_str(&content); } - let done = is_answer; + let done = false; let ai_name = ai_display_name_for_step.clone(); + let current_seq = chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed); tokio::spawn(async move { let event = RoomMessageStreamChunkEvent { message_id: streaming_msg_id, room_id, + seq: current_seq, content: content.clone(), done, error: None, @@ -125,6 +140,21 @@ pub async fn process_message_ai_react_streaming( let result = chat_service.process_react(&request, on_step).await; + // Broadcast final done=true event to close the streaming channel on frontend. + let final_stream_content = lock_or_recover(&answer_buffer).clone(); + room_manager + .broadcast_stream_chunk(RoomMessageStreamChunkEvent { + message_id: streaming_msg_id, + room_id: room_id_inner, + seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed), + content: final_stream_content.clone(), + done: true, + error: None, + display_name: Some(ai_display_name.clone()), + chunk_type: Some("answer".to_string()), + }) + .await; + let final_content = lock_or_recover(&answer_buffer).clone(); let all_steps = lock_or_recover(&steps).clone(); let reasoning_chain: String = all_steps @@ -172,7 +202,7 @@ pub async fn process_message_ai_react_streaming( } // Serialize ordered steps as JSON for ordered replay. - let thinking_content = { + let thinking_content_serialized = { let steps = lock_or_recover(&steps); if steps.is_empty() { None @@ -186,6 +216,7 @@ pub async fn process_message_ai_react_streaming( Some(chunks_json.to_string()) } }; + let thinking_content_for_event = thinking_content_serialized.clone(); let envelope = RoomMessageEnvelope { id: streaming_msg_id, @@ -197,7 +228,7 @@ pub async fn process_message_ai_react_streaming( thread_id: None, content: persist_content.clone(), content_type: "text".to_string(), - thinking_content, + thinking_content: thinking_content_serialized, send_at: now, seq, in_reply_to: None, @@ -244,7 +275,7 @@ pub async fn process_message_ai_react_streaming( thread_id: None, content: persist_content, content_type: "text".to_string(), - thinking_content: None, + thinking_content: thinking_content_for_event, send_at: now, seq, display_name: Some(ai_display_name.clone()),