refactor(agent): replace custom ReAct loop with rig::agent::Agent

- Use AgentBuilder for native tool-calling with stream_prompt()
- Add RecordingTool wrapper preserving retry + DB recording
- Fix tool_choice bug in do_completion (same as call_stream_once)
- Add seq field to RoomMessageStreamChunkEvent for strict ordering
- Map streaming events: Text→Answer, Reasoning→Thought, ToolCall→Action
- Only final event has done=true, removed premature stream ending
- Store __chunks__ JSON in thinking_content for ordered replay
This commit is contained in:
ZhenYi 2026-04-28 09:42:36 +08:00
parent 2bd40aee1b
commit 5b3a6700be
11 changed files with 828 additions and 752 deletions

View File

@ -1,22 +1,31 @@
use std::pin::Pin; use futures::StreamExt;
use std::sync::Arc;
use std::time::Duration;
use models::projects::project_skill; use models::projects::project_skill;
use models::rooms::room_ai; use models::rooms::room_ai;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; use rig::agent::{AgentBuilder, MultiTurnStreamItem};
use rig::client::CompletionClient;
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
use sea_orm::*;
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
use super::context::RoomMessageContext; use super::context::RoomMessageContext;
use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback}; use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback};
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::AiClientConfig; use crate::client::AiClientConfig;
use crate::client::{call_stream, call_with_params, StreamChunk, StreamChunkType, StreamedToolCall}; use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{
StreamChunk, StreamChunkType, StreamedToolCall, call_stream, call_with_params,
};
use crate::compact::{CompactConfig, CompactService}; use crate::compact::{CompactConfig, CompactService};
use crate::embed::EmbedService; use crate::embed::EmbedService;
use crate::error::{AgentError, Result}; use crate::error::{AgentError, Result};
use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent}; use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent};
use crate::react::{ReactAgent, ReactConfig, DEFAULT_SYSTEM_PROMPT}; use crate::react::{DEFAULT_SYSTEM_PROMPT, ReactStep};
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor, ToolResult, registry::ToolRegistry}; use crate::react::types::Action as ReactAction;
use crate::tool::{
RecordingTool, ToolCall as AgentToolCall, ToolContext, ToolExecutor,
registry::ToolRegistry,
};
/// Result from streaming AI response. /// Result from streaming AI response.
pub struct StreamResult { pub struct StreamResult {
@ -104,9 +113,12 @@ impl ChatService {
config: config::AppConfig, config: config::AppConfig,
room_id: uuid::Uuid, room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>, sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid,
) -> Option<crate::RigToolSet> { ) -> Option<crate::RigToolSet> {
self.tool_registry.as_ref().map(|registry| { self.tool_registry.as_ref().map(|registry| {
crate::RigToolSet::from_registry(registry, db, cache, config, room_id, sender_id) crate::RigToolSet::from_registry(
registry, db, cache, config, room_id, sender_id, project_id,
)
}) })
} }
@ -140,11 +152,16 @@ impl ChatService {
let mut tool_depth = 0; let mut tool_depth = 0;
let mut input_tokens = 0i64; let mut input_tokens = 0i64;
let mut output_tokens = 0i64; let mut output_tokens = 0i64;
let session_id = Uuid::new_v4();
let session_start = std::time::Instant::now();
let version_id = room_ai.as_ref().and_then(|r| r.version);
let config = AiClientConfig::new( let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
self.ai_api_key.clone().unwrap_or_default(), .with_base_url(
) self.ai_base_url
.with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into())); .clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
loop { loop {
let response = call_with_params( let response = call_with_params(
@ -183,9 +200,10 @@ impl ChatService {
}) })
.collect(); .collect();
messages.push( messages.push(ChatRequestMessage::assistant(
ChatRequestMessage::assistant(Some(text.clone()), Some(tool_call_messages.clone())) Some(text.clone()),
); Some(tool_call_messages.clone()),
));
// Create ToolCall list for executor (we need real IDs and args) // Create ToolCall list for executor (we need real IDs and args)
// Since we can't get args from streaming, use name matching from the text // Since we can't get args from streaming, use name matching from the text
@ -210,15 +228,69 @@ impl ChatService {
if let Some(ref registry) = self.tool_registry { if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone()); ctx.registry_mut().merge(registry.clone());
} }
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
request.db.clone(),
session_id,
);
let start = std::time::Instant::now();
let executor = ToolExecutor::new(); let executor = ToolExecutor::new();
match executor.execute_batch(calls, &mut ctx).await { match executor.execute_batch(calls, &mut ctx).await {
Ok(results) => ToolExecutor::to_tool_messages(&results), Ok(results) => {
for (call, result) in
response.tool_calls_finished.iter().zip(results.iter())
{
let elapsed = start.elapsed().as_millis() as i64;
let is_error =
matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result {
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
_ => None,
};
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.clone(),
session_id: recorder.session_id(),
tool_name: call.clone(),
caller: request.sender.uid,
arguments: serde_json::Value::Null,
status: if is_error {
models::ai::ToolCallStatus::Failed
} else {
models::ai::ToolCallStatus::Success
},
execution_time_ms: Some(elapsed),
error_message: error_msg,
error_stack: None,
retry_count: 0,
});
}
crate::tool::ToolExecutor::to_tool_messages(&results)
}
Err(e) => { Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
for call_name in &response.tool_calls_finished {
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: Uuid::new_v4().to_string(),
session_id: recorder.session_id(),
tool_name: call_name.clone(),
caller: request.sender.uid,
arguments: serde_json::Value::Null,
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed),
error_message: Some(e.to_string()),
error_stack: None,
retry_count: 0,
});
}
let err_msg = format!("[Tool call failed: {}]", e); let err_msg = format!("[Tool call failed: {}]", e);
response response
.tool_calls_finished .tool_calls_finished
.iter() .iter()
.map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg)) .map(|_| {
ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg)
})
.collect() .collect()
} }
} }
@ -250,8 +322,10 @@ impl ChatService {
}) })
.collect(); .collect();
for event in &tool_events { for event in &tool_events {
if let Some(ctx) = if let Some(ctx) = self
self.perception_service.passive.detect(event, &skill_entries) .perception_service
.passive
.detect(event, &skill_entries)
{ {
messages.push(ctx.to_system_message()); messages.push(ctx.to_system_message());
} }
@ -268,16 +342,62 @@ impl ChatService {
} else { } else {
text text
}; };
return Ok(ProcessResult { content, input_tokens, output_tokens }); // Record session
let _ = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(request.room.id),
model: Set(request.model.id),
version: Set(version_id.unwrap_or_default()),
token_input: Set(input_tokens),
token_output: Set(output_tokens),
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await;
return Ok(ProcessResult {
content,
input_tokens,
output_tokens,
});
} }
continue; continue;
} }
return Ok(ProcessResult { content: text, input_tokens, output_tokens }); // Record session
let _ = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(request.room.id),
model: Set(request.model.id),
version: Set(version_id.unwrap_or_default()),
token_input: Set(input_tokens),
token_output: Set(output_tokens),
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await;
return Ok(ProcessResult {
content: text,
input_tokens,
output_tokens,
});
} }
} }
pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result<StreamResult> { pub async fn process_stream(
&self,
request: AiRequest,
on_chunk: StreamCallback,
) -> Result<StreamResult> {
// Wrap on_chunk in Arc so it can be shared across loop iterations // Wrap on_chunk in Arc so it can be shared across loop iterations
let on_chunk = Arc::new(on_chunk); let on_chunk = Arc::new(on_chunk);
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default(); let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
@ -302,11 +422,19 @@ impl ChatService {
.and_then(|r| r.max_tokens.map(|v| v as u32)) .and_then(|r| r.max_tokens.map(|v| v as u32))
.unwrap_or(request.max_tokens as u32); .unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0; let mut tool_depth = 0;
let mut total_input_tokens = 0i64;
let mut total_output_tokens = 0i64;
let session_id = Uuid::new_v4();
let session_start = std::time::Instant::now();
let config = AiClientConfig::new( let version_id = room_ai.as_ref().and_then(|r| r.version);
self.ai_api_key.clone().unwrap_or_default(),
) let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
.with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into())); .with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
let mut full_content = String::new(); let mut full_content = String::new();
let mut all_chunks: Vec<StreamChunk> = Vec::new(); let mut all_chunks: Vec<StreamChunk> = Vec::new();
@ -325,6 +453,7 @@ impl ChatService {
temperature, temperature,
max_tokens, max_tokens,
if tools_enabled { Some(&tools) } else { None }, if tools_enabled { Some(&tools) } else { None },
None, // tool_choice — auto (let model decide)
Arc::new(move |delta| { Arc::new(move |delta| {
let fut = on_chunk_cb(AiStreamChunk { let fut = on_chunk_cb(AiStreamChunk {
content: delta.to_string(), content: delta.to_string(),
@ -351,6 +480,9 @@ impl ChatService {
) )
.await?; .await?;
total_input_tokens += response.input_tokens;
total_output_tokens += response.output_tokens;
// Collect chunks from this streaming iteration in order. // Collect chunks from this streaming iteration in order.
all_chunks.extend(response.chunks); all_chunks.extend(response.chunks);
@ -425,23 +557,44 @@ impl ChatService {
request.config.clone(), request.config.clone(),
request.room.id, request.room.id,
Some(request.sender.uid), Some(request.sender.uid),
); )
.with_project(request.project.id);
if let Some(ref registry) = self.tool_registry { if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone()); ctx.registry_mut().merge(registry.clone());
} }
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
request.db.clone(),
session_id,
);
for call in &calls { for call in &calls {
let start = std::time::Instant::now();
let executor = crate::tool::ToolExecutor::new(); let executor = crate::tool::ToolExecutor::new();
let results = match executor.execute_batch(vec![call.clone()], &mut ctx).await { let results = match executor.execute_batch(vec![call.clone()], &mut ctx).await {
Ok(r) => r, Ok(r) => r,
Err(e) => { Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: request.sender.uid,
arguments: call.arguments_json().unwrap_or_default(),
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed),
error_message: Some(e.to_string()),
error_stack: None,
retry_count: 0,
});
let err_text = format!("[Tool call failed: {}]", e); let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, error = %e, "tool_call_failed"); tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
// Do NOT emit tool_result chunks to frontend — show error via tool_call instead
let err_display = format!("{} (failed)", call.name); let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk { on_chunk(AiStreamChunk {
content: err_display.clone(), content: err_display.clone(),
done: false, done: false,
chunk_type: AiChunkType::ToolCall, chunk_type: AiChunkType::ToolResult,
}) })
.await; .await;
all_chunks.push(StreamChunk { all_chunks.push(StreamChunk {
@ -464,6 +617,29 @@ impl ChatService {
text.clone() text.clone()
}; };
tracing::debug!("tool_result: {} — {}", call.name, preview); tracing::debug!("tool_result: {} — {}", call.name, preview);
let elapsed = start.elapsed().as_millis() as i64;
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result {
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
_ => None,
};
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: request.sender.uid,
arguments: call.arguments_json().unwrap_or_default(),
status: if is_error {
models::ai::ToolCallStatus::Failed
} else {
models::ai::ToolCallStatus::Success
},
execution_time_ms: Some(elapsed),
error_message: error_msg,
error_stack: None,
retry_count: 0,
});
// Do NOT emit tool_result chunks to frontend — raw output may contain sensitive data. // Do NOT emit tool_result chunks to frontend — raw output may contain sensitive data.
// Log server-side only; frontend sees tool_call status via on_chunk below. // Log server-side only; frontend sees tool_call status via on_chunk below.
} }
@ -471,7 +647,7 @@ impl ChatService {
on_chunk(AiStreamChunk { on_chunk(AiStreamChunk {
content: success_display.clone(), content: success_display.clone(),
done: false, done: false,
chunk_type: AiChunkType::ToolCall, chunk_type: AiChunkType::ToolResult,
}) })
.await; .await;
all_chunks.push(StreamChunk { all_chunks.push(StreamChunk {
@ -509,8 +685,10 @@ impl ChatService {
}) })
.collect(); .collect();
for event in &tool_events { for event in &tool_events {
if let Some(ctx) = if let Some(ctx) = self
self.perception_service.passive.detect(event, &skill_entries) .perception_service
.passive
.detect(event, &skill_entries)
{ {
messages.push(ctx.to_system_message()); messages.push(ctx.to_system_message());
} }
@ -533,6 +711,23 @@ impl ChatService {
chunk_type: StreamChunkType::Answer, chunk_type: StreamChunkType::Answer,
content: max_depth_text, content: max_depth_text,
}); });
// Record session
let _ = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(request.room.id),
model: Set(request.model.id),
version: Set(version_id.unwrap_or_default()),
token_input: Set(total_input_tokens),
token_output: Set(total_output_tokens),
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await;
return Ok(StreamResult { return Ok(StreamResult {
content: full_content, content: full_content,
reasoning_content: String::new(), reasoning_content: String::new(),
@ -557,6 +752,23 @@ impl ChatService {
chunk_type: StreamChunkType::Answer, chunk_type: StreamChunkType::Answer,
content: response.content.clone(), content: response.content.clone(),
}); });
// Record session
let _ = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(request.room.id),
model: Set(request.model.id),
version: Set(version_id.unwrap_or_default()),
token_input: Set(total_input_tokens),
token_output: Set(total_output_tokens),
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await;
return Ok(StreamResult { return Ok(StreamResult {
content: full_content, content: full_content,
reasoning_content: response.reasoning_content, reasoning_content: response.reasoning_content,
@ -616,7 +828,10 @@ impl ChatService {
parts.push(format!("Description: {}", desc)); parts.push(format!("Description: {}", desc));
} }
parts.push(format!("Default branch: {}", repo.default_branch)); parts.push(format!("Default branch: {}", repo.default_branch));
parts.push(format!("Private: {}", if repo.is_private { "yes" } else { "no" })); parts.push(format!(
"Private: {}",
if repo.is_private { "yes" } else { "no" }
));
parts.push(format!("Created: {}", repo.created_at.format("%Y-%m-%d"))); parts.push(format!("Created: {}", repo.created_at.format("%Y-%m-%d")));
messages.push(ChatRequestMessage::system(format!( messages.push(ChatRequestMessage::system(format!(
"Mentioned repository:\n{}", "Mentioned repository:\n{}",
@ -692,7 +907,11 @@ impl ChatService {
"Current Project:\n{}\nDescription: {}\nPublic: {}", "Current Project:\n{}\nDescription: {}\nPublic: {}",
request.project.display_name, request.project.display_name,
request.project.description.as_deref().unwrap_or("(none)"), request.project.description.as_deref().unwrap_or("(none)"),
if request.project.is_public { "yes" } else { "no" } if request.project.is_public {
"yes"
} else {
"no"
}
))); )));
let mut sender_parts = vec![format!("**Sender:** {}", request.sender.username)]; let mut sender_parts = vec![format!("**Sender:** {}", request.sender.username)];
@ -773,7 +992,11 @@ impl ChatService {
if let Some(embed_service) = &self.embed_service { if let Some(embed_service) = &self.embed_service {
let awareness = crate::perception::VectorActiveAwareness::default(); let awareness = crate::perception::VectorActiveAwareness::default();
vector_skills = awareness vector_skills = awareness
.detect(embed_service, &request.input, &request.project.id.to_string()) .detect(
embed_service,
&request.input,
&request.project.id.to_string(),
)
.await; .await;
} }
@ -813,32 +1036,14 @@ impl ChatService {
.await .await
} }
fn is_retryable_tool_error(msg: &str) -> bool { pub async fn process_react<C>(&self, request: &AiRequest, mut on_chunk: C) -> Result<String>
let msg_lower = msg.to_lowercase();
msg_lower.contains("connection")
|| msg_lower.contains("timeout")
|| msg_lower.contains("timed out")
|| msg_lower.contains("rate limit")
|| msg_lower.contains("too many")
|| msg_lower.contains("unavailable")
|| msg_lower.contains("service unavailable")
|| msg_lower.contains("temporarily")
|| msg_lower.contains("refused")
|| msg_lower.contains("reset")
|| msg_lower.contains("broken pipe")
|| msg_lower.contains("deadline exceeded")
|| msg_lower.contains("try again")
}
pub async fn process_react<C>(
&self,
request: &AiRequest,
mut on_chunk: C,
) -> Result<String>
where where
C: FnMut(crate::react::ReactStep) + Send, C: FnMut(crate::react::ReactStep) + Send,
{ {
let base_url = self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into()); let base_url = self
.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into());
let api_key = self.ai_api_key.clone().unwrap_or_default(); let api_key = self.ai_api_key.clone().unwrap_or_default();
let client_config = AiClientConfig::new(api_key).with_base_url(base_url); let client_config = AiClientConfig::new(api_key).with_base_url(base_url);
@ -848,104 +1053,176 @@ impl ChatService {
let db = request.db.clone(); let db = request.db.clone();
let cache = request.cache.clone(); let cache = request.cache.clone();
let config = request.config.clone(); let cfg = request.config.clone();
let room_id = request.room.id; let room_id = request.room.id;
let project_id = Some(request.project.id); let sender_uid = request.sender.uid;
let sender_uid = Some(request.sender.uid); let project_id = request.project.id;
let registry = registry.clone(); let session_id = Uuid::new_v4();
let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await
.ok()
.flatten()
.and_then(|r| r.version);
let executor: std::sync::Arc< // Build rig tools with recording wrapper directly from registry
dyn Fn(String, serde_json::Value) -> Pin<Box<dyn std::future::Future<Output = std::result::Result<serde_json::Value, String>> + Send>> let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
+ Send for def in registry.definitions() {
+ Sync, let name = def.name.clone();
> = std::sync::Arc::new(move |name: String, args: serde_json::Value| { if let Some(handler) = registry.get(&name) {
let db = db.clone(); let adapter = crate::tool::RigToolAdapter::new(
let cache = cache.clone(); handler.clone(),
let config = config.clone(); def.clone(),
let room_id = room_id; db.clone(),
let project_id = project_id; cache.clone(),
let sender_uid = sender_uid; cfg.clone(),
let registry = registry.clone(); room_id,
Some(sender_uid),
Box::pin(async move { project_id,
let max_retries = 3;
let mut last_err = String::new();
for attempt in 0..=max_retries {
let mut ctx = ToolContext::new(db.clone(), cache.clone(), config.clone(), room_id, sender_uid);
if let Some(pid) = project_id {
ctx = ctx.with_project(pid);
}
ctx.registry_mut().merge(registry.clone());
let tool_executor = ToolExecutor::new();
let call = AgentToolCall {
id: Uuid::new_v4().to_string(),
name: name.clone(),
arguments: serde_json::to_string(&args).unwrap_or_else(|_| "{}".into()),
};
match tool_executor.execute_batch(vec![call], &mut ctx).await {
Ok(results) => {
let result = results.into_iter().next()
.ok_or_else(|| "no tool result returned".to_string())?;
match result.result {
ToolResult::Ok(v) => return Ok(v),
ToolResult::Error(msg) => {
if attempt < max_retries && Self::is_retryable_tool_error(&msg) {
last_err = msg;
let backoff_ms = 100u64.saturating_mul(2u64.pow(attempt as u32));
tracing::warn!(
tool = %name,
attempt = attempt + 1,
backoff_ms = backoff_ms,
error = %last_err,
"tool_execute_retry"
); );
tokio::time::sleep(Duration::from_millis(backoff_ms)).await; tools.push(Box::new(RecordingTool::new(
continue; Box::new(adapter),
} db.clone(),
return Err(msg); session_id,
sender_uid,
)));
} }
} }
// Build rig agent (handles multi-turn tool calls natively)
let rig_client = client_config.build_rig_client();
let model = rig_client.completion_model(&request.model.name);
let agent = AgentBuilder::new(model)
.preamble(DEFAULT_SYSTEM_PROMPT)
.tools(tools)
.default_max_turns(request.max_tool_depth)
.build();
let stream = agent
.stream_prompt(&request.input)
.with_history(Vec::new())
.multi_turn(request.max_tool_depth)
.await;
tokio::pin!(stream);
let mut step_count = 0usize;
let mut final_content = String::new();
let mut total_input_tokens: i64 = 0;
let mut total_output_tokens: i64 = 0;
while let Some(item) = stream.next().await {
match item {
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
step_count += 1;
let t = text.text;
on_chunk(ReactStep::Answer {
step: step_count,
answer: t.clone(),
});
final_content.push_str(&t);
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Reasoning(reasoning),
)) => {
let reasoning_text = reasoning.reasoning.join("");
if !reasoning_text.is_empty() {
step_count += 1;
on_chunk(ReactStep::Thought {
step: step_count,
thought: reasoning_text,
});
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ReasoningDelta { reasoning, .. },
)) => {
if !reasoning.is_empty() {
step_count += 1;
on_chunk(ReactStep::Thought {
step: step_count,
thought: reasoning,
});
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ToolCall { tool_call, .. },
)) => {
step_count += 1;
let args: serde_json::Value = match &tool_call.function.arguments {
serde_json::Value::String(s) => {
serde_json::from_str(s).unwrap_or(serde_json::Value::Null)
}
v => v.clone(),
};
on_chunk(ReactStep::Action {
step: step_count,
action: ReactAction::new(&tool_call.function.name, args),
});
}
Ok(MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
)) => {
step_count += 1;
let obs = tool_result_content_to_string(&tool_result.content);
on_chunk(ReactStep::Observation {
step: step_count,
observation: obs,
});
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
total_input_tokens = usage.input_tokens as i64;
total_output_tokens = usage.output_tokens as i64;
// Text was already streamed incrementally via Answer events.
} }
Err(e) => { Err(e) => {
last_err = e.to_string(); let err_msg = format!("rig agent stream error: {}", e);
if attempt < max_retries && Self::is_retryable_tool_error(&last_err) { return Err(AgentError::OpenAi(err_msg));
let backoff_ms = 100u64.saturating_mul(2u64.pow(attempt as u32));
tracing::warn!(
tool = %name,
attempt = attempt + 1,
backoff_ms = backoff_ms,
error = %last_err,
"tool_execute_retry"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
continue;
}
return Err(last_err);
} }
_ => {}
} }
} }
Err(last_err) let elapsed_ms = session_start.elapsed().as_millis() as i64;
}) as Pin<Box<dyn std::future::Future<Output = std::result::Result<serde_json::Value, String>> + Send>> let _ = models::ai::ai_session::ActiveModel {
}); id: Set(session_id),
room: Set(request.room.id),
model: Set(request.model.id),
version: Set(version_id.unwrap_or_default()),
token_input: Set(total_input_tokens),
token_output: Set(total_output_tokens),
latency_ms: Set(Some(elapsed_ms)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await;
let tools = self.tools(); Ok(final_content)
let config = ReactConfig { }
max_steps: request.max_tool_depth, }
stop_sequences: Vec::new(),
tool_executor: Some(executor),
};
let mut agent = ReactAgent::new(DEFAULT_SYSTEM_PROMPT, tools, config); /// Extract text from rig's ToolResultContent, ignoring images.
agent.add_user_message(&request.input); fn tool_result_content_to_string(content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>) -> String {
use rig::completion::message::ToolResultContent;
agent content
.run(&request.model.name, &client_config, |step| { .iter()
on_chunk(step); .filter_map(|item| {
if let ToolResultContent::Text(t) = item {
Some(t.text.clone())
} else {
None
}
}) })
.await .collect::<Vec<_>>()
} .join("\n")
} }

View File

@ -287,14 +287,6 @@ where
.map(|ts| ts.iter().filter_map(to_rig_tool_def).collect()) .map(|ts| ts.iter().filter_map(to_rig_tool_def).collect())
.unwrap_or_default(); .unwrap_or_default();
let tc = match tool_choice {
Some("none") => rig::completion::message::ToolChoice::None,
Some("auto") | None => rig::completion::message::ToolChoice::Auto,
Some(s) => rig::completion::message::ToolChoice::Specific {
function_names: vec![s.to_string()],
},
};
let mut builder = model.completion_request(""); let mut builder = model.completion_request("");
if !preamble.is_empty() { if !preamble.is_empty() {
@ -317,7 +309,24 @@ where
builder = builder.tools(tool_defs); builder = builder.tools(tool_defs);
} }
builder = builder.tool_choice(tc); // Only set tool_choice when explicitly provided (mirrors call_stream_once logic)
if let Some(tc) = tool_choice {
match tc {
"none" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::None);
}
"auto" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto);
}
s => {
builder = builder.tool_choice(
rig::completion::message::ToolChoice::Specific {
function_names: vec![s.to_string()],
},
);
}
}
}
let response = builder.send().await.map_err(|e| AgentError::OpenAi(e.to_string()))?; let response = builder.send().await.map_err(|e| AgentError::OpenAi(e.to_string()))?;
@ -498,6 +507,7 @@ pub async fn call_stream(
temperature: f32, temperature: f32,
max_tokens: u32, max_tokens: u32,
tools: Option<&[serde_json::Value]>, tools: Option<&[serde_json::Value]>,
tool_choice: Option<&str>,
on_text_delta: StreamTextCb, on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb, on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb, on_tool_call: StreamToolCallCb,
@ -506,7 +516,7 @@ pub async fn call_stream(
loop { loop {
let result = call_stream_once( let result = call_stream_once(
messages, model_name, config, temperature, max_tokens, tools, messages, model_name, config, temperature, max_tokens, tools, tool_choice,
on_text_delta.clone(), on_reasoning_delta.clone(), on_tool_call.clone(), on_text_delta.clone(), on_reasoning_delta.clone(), on_tool_call.clone(),
) )
.await; .await;
@ -542,6 +552,7 @@ async fn call_stream_once(
temperature: f32, temperature: f32,
max_tokens: u32, max_tokens: u32,
tools: Option<&[serde_json::Value]>, tools: Option<&[serde_json::Value]>,
tool_choice: Option<&str>,
on_text_delta: StreamTextCb, on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb, on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb, on_tool_call: StreamToolCallCb,
@ -581,6 +592,24 @@ async fn call_stream_once(
builder = builder.tools(tool_defs); builder = builder.tools(tool_defs);
} }
if let Some(tc) = tool_choice {
match tc {
"none" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::None);
}
"auto" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto);
}
s => {
builder = builder.tool_choice(
rig::completion::message::ToolChoice::Specific {
function_names: vec![s.to_string()],
},
);
}
}
}
let stream_fut = async { let stream_fut = async {
let mut stream = builder let mut stream = builder
.stream() .stream()
@ -592,6 +621,10 @@ async fn call_stream_once(
let mut tool_calls: Vec<StreamedToolCall> = Vec::new(); let mut tool_calls: Vec<StreamedToolCall> = Vec::new();
let mut chunks: Vec<StreamChunk> = Vec::new(); let mut chunks: Vec<StreamChunk> = Vec::new();
// Some models (e.g. GLM) ignore tool_choice="none" and still emit tool_calls.
// Filter them out so they don't cause spurious tool execution attempts.
let skip_tool_calls = tool_choice == Some("none");
use std::collections::HashMap; use std::collections::HashMap;
let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new(); let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new();
let mut stream_finished = false; let mut stream_finished = false;
@ -612,6 +645,10 @@ async fn call_stream_once(
tool_call, tool_call,
internal_call_id, internal_call_id,
}) => { }) => {
if skip_tool_calls {
partial_tool_calls.remove(&internal_call_id);
continue;
}
let arguments = match &tool_call.function.arguments { let arguments = match &tool_call.function.arguments {
serde_json::Value::String(s) => s.clone(), serde_json::Value::String(s) => s.clone(),
other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()), other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()),
@ -638,6 +675,9 @@ async fn call_stream_once(
internal_call_id, internal_call_id,
content: delta_content, content: delta_content,
}) => { }) => {
if skip_tool_calls {
continue;
}
use rig::streaming::ToolCallDeltaContent; use rig::streaming::ToolCallDeltaContent;
match delta_content { match delta_content {
ToolCallDeltaContent::Name(name) => { ToolCallDeltaContent::Name(name) => {
@ -677,9 +717,13 @@ async fn call_stream_once(
} }
Ok(StreamedAssistantContent::Final(response)) => { Ok(StreamedAssistantContent::Final(response)) => {
stream_finished = true; stream_finished = true;
if !skip_tool_calls {
for (_, tc) in partial_tool_calls.drain() { for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc); tool_calls.push(tc);
} }
} else {
partial_tool_calls.drain();
}
if let Some(usage) = response.token_usage() { if let Some(usage) = response.token_usage() {
let in_toks = usage.input_tokens as i64; let in_toks = usage.input_tokens as i64;
let out_toks = usage.output_tokens as i64; let out_toks = usage.output_tokens as i64;
@ -700,7 +744,7 @@ async fn call_stream_once(
} }
// Flush any remaining partial tool calls (if stream ended without Final or Final had no usage) // Flush any remaining partial tool calls (if stream ended without Final or Final had no usage)
if !stream_finished { if !stream_finished && !skip_tool_calls {
for (_, tc) in partial_tool_calls.drain() { for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc); tool_calls.push(tc);
} }

View File

@ -31,16 +31,13 @@ pub use client::types::ChatRequestMessage;
pub use compact::{CompactConfig, CompactLevel, CompactService, CompactSummary, MessageSummary}; pub use compact::{CompactConfig, CompactLevel, CompactService, CompactSummary, MessageSummary};
pub use embed::{new_embed_client, EmbedClient, EmbedService, QdrantClient, SearchResult}; pub use embed::{new_embed_client, EmbedClient, EmbedService, QdrantClient, SearchResult};
pub use error::{AgentError, Result}; pub use error::{AgentError, Result};
pub use react::{ pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT};
Hook, HookAction, NoopHook, ReactAgent, ReactConfig, ReactStep, ToolCallAction, TracingHook,
DEFAULT_SYSTEM_PROMPT,
};
pub use tool::{ pub use tool::{
ToolCall, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam, ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,
ToolRegistry, ToolResult, ToolSchema, ToolRegistry, ToolResult, ToolSchema,
}; };
#[cfg(feature = "rig")] #[cfg(feature = "rig")]
pub use agent::RigAgentService; pub use agent::RigAgentService;
#[cfg(feature = "rig")] #[cfg(feature = "rig")]
pub use tool::rig_adapter::RigToolSet; pub use tool::{RigToolSet, RecordingTool, is_retryable_tool_error};

View File

@ -1,130 +0,0 @@
//! Observability hooks for the ReAct agent loop.
//!
//! Hooks allow injecting custom behavior (logging, tracing, filtering, termination)
//! at each step of the reasoning loop without coupling to the core agent logic.
//!
//! Inspired by rig's `PromptHook` trait.
//!
//! # Example
//!
//! ```ignore
//! #[derive(Clone)]
//! struct MyHook;
//!
//! impl Hook for MyHook {
//! async fn on_thought(&self, step: usize, thought: &str) -> HookAction {
//! tracing::info!("[step {}] thinking: {}", step, thought);
//! HookAction::Continue
//! }
//! }
//!
//! let agent = ReactAgent::new(prompt, tools, config).with_hook(MyHook);
//! ```
use async_trait::async_trait;
/// Controls whether the agent loop continues after a hook callback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
/// Continue processing normally.
Continue,
/// Skip the current step and continue.
Skip,
/// Terminate the loop immediately with the given reason.
Terminate(&'static str),
}
/// Controls behavior after a tool call hook callback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallAction {
/// Execute the tool normally.
Continue,
/// Skip tool execution and inject a custom result.
Skip(String),
/// Terminate the loop with the given reason.
Terminate(&'static str),
}
/// Default no-op hook that does nothing.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopHook;
impl Hook for NoopHook {}
impl Hook for () {}
/// A hook that logs everything to stderr using `eprintln`.
/// No external dependencies required.
#[derive(Debug, Clone, Copy, Default)]
pub struct TracingHook;
impl TracingHook {
pub fn new() -> Self {
Self
}
}
#[async_trait]
impl Hook for TracingHook {
async fn on_thought(&self, step: usize, thought: &str) -> HookAction {
eprintln!("[step {}] thought: {}", step, thought);
HookAction::Continue
}
async fn on_tool_call(&self, step: usize, name: &str, args_json: &str) -> ToolCallAction {
eprintln!("[step {}] tool_call: {}({})", step, name, args_json);
ToolCallAction::Continue
}
async fn on_observation(&self, step: usize, observation: &str) -> HookAction {
eprintln!("[step {}] observation: {}", step, observation);
HookAction::Continue
}
async fn on_answer(&self, step: usize, answer: &str) -> HookAction {
eprintln!("[step {}] answer: {}", step, answer);
HookAction::Continue
}
}
/// Hook trait for observing and controlling the ReAct agent loop.
///
/// Implement this trait to inject custom behavior at each step:
/// - Log thoughts, tool calls, observations, and final answers
/// - Filter or redact sensitive data
/// - Dynamically terminate the loop based on content
/// - Inject custom tool results (e.g., for testing or sandboxing)
///
/// All methods have default no-op implementations, so you only need to
/// override the ones you care about.
///
/// The hook is called synchronously during the agent loop. Keep hook
/// callbacks fast — avoid blocking I/O. For heavy work, spawn a task
/// and return immediately.
#[async_trait]
pub trait Hook: Send + Sync {
/// Called when the agent emits a thought/reasoning step.
///
/// Return `HookAction::Terminate` to stop the loop early.
async fn on_thought(&self, _step: usize, _thought: &str) -> HookAction {
HookAction::Continue
}
/// Called just before a tool is executed.
///
/// Return `ToolCallAction::Skip(result)` to skip execution and inject `result` instead.
/// Return `ToolCallAction::Terminate` to stop the loop without executing the tool.
async fn on_tool_call(&self, _step: usize, _name: &str, _args_json: &str) -> ToolCallAction {
ToolCallAction::Continue
}
/// Called after a tool returns an observation.
async fn on_observation(&self, _step: usize, _observation: &str) -> HookAction {
HookAction::Continue
}
/// Called when the agent produces a final answer.
async fn on_answer(&self, _step: usize, _answer: &str) -> HookAction {
HookAction::Continue
}
}

View File

@ -1,413 +0,0 @@
//! ReAct (Reasoning + Acting) agent core.
use uuid::Uuid;
use std::sync::Arc;
use crate::call_with_params;
use crate::client::types::ChatRequestMessage;
use crate::error::{AgentError, Result};
use crate::react::hooks::{Hook, HookAction, NoopHook, ToolCallAction};
use crate::react::types::{Action, ReactConfig, ReactStep};
pub use crate::react::types::{ReactConfig as ReActConfig, ReactStep as ReActStep};
/// A ReAct agent that performs multi-step tool-augmented reasoning.
#[derive(Clone)]
pub struct ReactAgent {
messages: Vec<ChatRequestMessage>,
#[allow(dead_code)]
tool_definitions: Vec<serde_json::Value>,
config: ReactConfig,
step_count: usize,
hook: Arc<dyn Hook>,
}
impl ReactAgent {
/// Create a new agent with a system prompt and tool definitions (as JSON values).
pub fn new(
system_prompt: &str,
tools: Vec<serde_json::Value>,
config: ReactConfig,
) -> Self {
let messages = vec![ChatRequestMessage::system(system_prompt)];
Self {
messages,
tool_definitions: tools,
config,
step_count: 0,
hook: Arc::new(NoopHook),
}
}
/// Add an initial user message to the conversation.
pub fn add_user_message(&mut self, content: &str) {
self.messages.push(ChatRequestMessage::user(content));
}
/// Attach a hook to observe and control the agent loop.
///
/// Hooks can log steps, filter content, inject custom tool results,
/// or terminate the loop early. Multiple `.with_hook()` calls replace
/// the previous hook.
pub fn with_hook<H: Hook + 'static>(mut self, hook: H) -> Self {
self.hook = Arc::new(hook);
self
}
/// Run the ReAct loop until a final answer is produced or `max_steps` is reached.
pub async fn run<C>(
&mut self,
model_name: &str,
client_config: &crate::client::AiClientConfig,
mut on_chunk: C,
) -> Result<String>
where
C: FnMut(ReactStep) + Send,
{
loop {
if self.step_count >= self.config.max_steps {
let msg = format!(
"Agent reached maximum reasoning steps ({}) without producing a final answer.",
self.config.max_steps
);
on_chunk(ReactStep::Answer {
step: self.step_count,
answer: msg.clone(),
});
return Ok(msg);
}
self.step_count += 1;
let step = self.step_count;
// For ReAct we force text-only responses so the model follows our JSON-in-text format.
let tool_choice_str = if self.tool_definitions.is_empty() {
None
} else {
Some("none")
};
let response = call_with_params(
&self.messages,
model_name,
client_config,
0.2, // temperature
4096, // max output tokens
None,
if self.tool_definitions.is_empty() {
None
} else {
Some(&self.tool_definitions)
},
tool_choice_str,
)
.await?;
let parsed = parse_react_response(&response.content);
let answer = parsed.answer.clone();
let action = parsed.action.clone();
on_chunk(ReactStep::Thought {
step,
thought: parsed.thought.clone(),
});
match self.hook.on_thought(step, &parsed.thought).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at thought step: {}",
reason
)));
}
HookAction::Skip => {}
HookAction::Continue => {}
}
// Final answer — emit and return.
if let Some(ans) = answer {
on_chunk(ReactStep::Answer {
step,
answer: ans.clone(),
});
match self.hook.on_answer(step, &ans).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at answer step: {}",
reason
)));
}
_ => {}
}
return Ok(ans);
}
// No answer — either do a tool call or fall back.
let Some(act) = action else {
let content = response.content.clone();
on_chunk(ReactStep::Answer {
step,
answer: content.clone(),
});
match self.hook.on_answer(step, &content).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at fallback answer: {}",
reason
)));
}
_ => {}
}
return Ok(content);
};
on_chunk(ReactStep::Action {
step,
action: act.clone(),
});
let args_json = serde_json::to_string(&act.args).unwrap_or_else(|_| "{}".to_string());
match self.hook.on_tool_call(step, &act.name, &args_json).await {
ToolCallAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at tool call: {}",
reason
)));
}
ToolCallAction::Skip(injected_result) => {
let observation = injected_result;
on_chunk(ReactStep::Observation {
step,
observation: observation.clone(),
});
match self.hook.on_observation(step, &observation).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at observation (injected): {}",
reason
)));
}
_ => {}
}
// Append assistant message with tool_calls.
let assistant_msg = build_tool_call_message(&act);
self.messages.push(assistant_msg);
// Append observation as a tool message.
self.messages.push(ChatRequestMessage::tool(&act.id, observation));
continue;
}
ToolCallAction::Continue => {}
}
// Append the assistant message with tool_calls.
let assistant_msg = build_tool_call_message(&act);
self.messages.push(assistant_msg);
// Execute the tool.
let observation = match &self.config.tool_executor {
Some(exec) => {
let result = exec(act.name.clone(), act.args.clone()).await;
match result {
Ok(v) => serde_json::to_string(&v).unwrap_or_else(|_| "null".to_string()),
Err(e) => serde_json::json!({ "error": e }).to_string(),
}
}
None => serde_json::json!({
"error": format!("no tool executor registered for '{}'", act.name)
})
.to_string(),
};
on_chunk(ReactStep::Observation {
step,
observation: observation.clone(),
});
match self.hook.on_observation(step, &observation).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at observation step: {}",
reason
)));
}
_ => {}
}
// Append observation as a tool message.
self.messages.push(ChatRequestMessage::tool(&act.id, observation));
}
}
/// Returns the number of steps executed so far.
pub fn steps(&self) -> usize {
self.step_count
}
}
// ---------------------------------------------------------------------------
// Response parsing
// ---------------------------------------------------------------------------
struct ParsedReActResponse {
thought: String,
action: Option<Action>,
answer: Option<String>,
}
fn parse_react_response(content: &str) -> ParsedReActResponse {
let json_str = extract_json(content).unwrap_or_else(|| content.trim().to_string());
#[derive(serde::Deserialize)]
struct RawStep {
#[serde(default)]
thought: Option<String>,
#[serde(default)]
action: Option<RawAction>,
#[serde(default)]
answer: Option<String>,
#[serde(default)]
name: Option<String>,
#[serde(default, rename = "arguments")]
args: Option<serde_json::Value>,
}
#[derive(serde::Deserialize)]
struct RawAction {
#[serde(default)]
name: Option<String>,
#[serde(default, rename = "arguments")]
args: Option<serde_json::Value>,
}
match serde_json::from_str::<RawStep>(&json_str) {
Ok(raw) => {
let thought = raw.thought.unwrap_or_else(|| "Thinking...".to_string());
let answer = raw.answer;
let action = raw.action.map(|a| Action {
id: Uuid::new_v4().to_string(),
name: a.name.unwrap_or_default(),
args: a.args.unwrap_or(serde_json::Value::Null),
});
let action = action.or_else(|| {
if raw.name.is_some() || raw.args.is_some() {
Some(Action {
id: Uuid::new_v4().to_string(),
name: raw.name.unwrap_or_default(),
args: raw.args.unwrap_or(serde_json::Value::Null),
})
} else {
None
}
});
ParsedReActResponse {
thought,
action,
answer,
}
}
Err(_) => ParsedReActResponse {
thought: content.to_string(),
action: None,
answer: None,
},
}
}
fn extract_json(s: &str) -> Option<String> {
let trimmed = s.trim();
if trimmed.starts_with('{') || trimmed.starts_with('[') {
return Some(trimmed.to_string());
}
for line in trimmed.lines() {
let line = line.trim();
if line.starts_with("```json") || line == "```" {
let mut buf = String::new();
let mut found_start = false;
for l in trimmed.lines() {
let l = l.trim();
if !found_start && (l == "```json" || l == "```") {
found_start = true;
continue;
}
if found_start && l == "```" {
break;
}
if found_start {
buf.push_str(l);
buf.push('\n');
}
}
let result = buf.trim().to_string();
if !result.is_empty() {
return Some(result);
}
}
}
let chars: Vec<char> = trimmed.chars().collect();
for i in 0..chars.len() {
let c = chars[i];
if (c == '{' || c == '[') && i > 0 {
let prev = chars[i - 1];
if prev.is_alphanumeric() || prev == '_' || prev == '"' || prev == '\'' {
continue;
}
let candidate: String = chars[i..].iter().collect();
if serde_json::from_str::<serde_json::Value>(&candidate).is_ok() {
return Some(candidate.trim_end().to_string());
}
let mut depth = 0isize;
let mut in_string = false;
let mut escaped = false;
for (j, c) in candidate.char_indices() {
if escaped { escaped = false; continue; }
if c == '\\' { escaped = true; continue; }
if c == '"' { in_string = !in_string; continue; }
if in_string { continue; }
if c == '{' || c == '[' { depth += 1; }
if c == '}' || c == ']' { depth -= 1; }
if depth == 0 {
let json_end = j + c.len_utf8();
let trimmed_candidate = &candidate[..json_end];
if serde_json::from_str::<serde_json::Value>(trimmed_candidate).is_ok() {
return Some(trimmed_candidate.to_string());
}
}
}
}
}
None
}
/// Build an assistant message with tool_calls from an Action.
fn build_tool_call_message(action: &Action) -> ChatRequestMessage {
let fn_arg_str = serde_json::to_string(&action.args).unwrap_or_else(|_| "{}".to_string());
ChatRequestMessage {
role: "assistant".into(),
content: Some(format!("Action: {}", action.name)),
name: None,
tool_call_id: None,
tool_calls: Some(vec![crate::client::types::ToolCall {
id: action.id.clone(),
type_: "function".into(),
function: crate::client::types::ToolCallFunction {
name: action.name.clone(),
arguments: fn_arg_str,
},
}]),
}
}

View File

@ -1,18 +1,13 @@
//! ReAct (Reason + Act) agent loop for structured tool use. //! ReAct (Reason + Act) agent types.
//! //!
//! The agent alternates between a **thought** phase (reasoning about what to do) //! Provides the step types used by the ReAct callback interface.
//! and an **action** phase (calling tools). Observations from tool results feed //! The actual agent loop is handled by rig's built-in Agent.
//! back into the next thought, enabling multi-step reasoning.
pub mod hooks;
pub mod loop_core;
pub mod types; pub mod types;
pub use hooks::{Hook, HookAction, NoopHook, ToolCallAction, TracingHook};
pub use loop_core::ReactAgent;
pub use types::{ReactConfig, ReactStep}; pub use types::{ReactConfig, ReactStep};
/// Default system prompt for the ReAct agent. /// Default system prompt for the ReAct agent (used with rig's native tool-calling).
/// ///
/// The agent is instructed to prioritize querying local repository data /// The agent is instructed to prioritize querying local repository data
/// (issues, pull requests, repositories, documentation, etc.) before /// (issues, pull requests, repositories, documentation, etc.) before
@ -25,26 +20,6 @@ Always query the platform's local data before guessing or referring to external
If local data does not contain the answer, state that clearly before considering external information. If local data does not contain the answer, state that clearly before considering external information.
## Response Format
Respond as JSON:
1. When you need to look up data:
```json
{
"thought": "What you need to find and why.",
"action": { "name": "tool_name", "arguments": { ... } }
}
```
2. When you have enough information to answer:
```json
{
"thought": "How you arrived at the answer.",
"answer": "Your final answer."
}
```
## Tool Use ## Tool Use
- Use the tools provided by the system to search and retrieve platform data. - Use the tools provided by the system to search and retrieve platform data.

131
libs/agent/tool/recorder.rs Normal file
View File

@ -0,0 +1,131 @@
//! Batch tool call recorder — persists tool call records to `ai_tool_call` table.
//!
//! Uses an mpsc channel + background flush loop to batch-insert records,
//! reducing DB pressure from individual inserts.
//!
//! Flush triggers:
//! - Buffer reaches `BATCH_SIZE` (default 50)
//! - `FLUSH_INTERVAL` (default 5s) elapses with non-empty buffer
//! - Sender is dropped (remaining records flushed on channel close)
use std::time::Duration;
use db::database::AppDatabase;
use models::ai::ai_tool_call;
use models::ai::ToolCallStatus;
use sea_orm::*;
use tokio::sync::mpsc;
use uuid::Uuid;
const FLUSH_INTERVAL: Duration = Duration::from_secs(5);
const BATCH_SIZE: usize = 50;
/// A single tool call record to be persisted.
#[derive(Debug, Clone)]
pub struct ToolCallRecord {
pub tool_call_id: String,
pub session_id: Uuid,
pub tool_name: String,
pub caller: Uuid,
pub arguments: serde_json::Value,
pub status: ToolCallStatus,
pub execution_time_ms: Option<i64>,
pub error_message: Option<String>,
pub error_stack: Option<String>,
pub retry_count: i32,
}
/// Channel-based batched recorder. Cheap to clone — all clones share the same sender.
#[derive(Clone)]
pub struct ToolCallRecorder {
tx: mpsc::UnboundedSender<ToolCallRecord>,
session_id: Uuid,
}
impl ToolCallRecorder {
/// Create a new recorder with an auto-generated session ID
/// and spawn a background flush loop.
pub fn new(db: AppDatabase) -> Self {
Self::with_session(db, Uuid::new_v4())
}
/// Create a new recorder with a specific session ID
/// (so tool call records can be linked to an `AiSession`).
pub fn with_session(db: AppDatabase, session_id: Uuid) -> Self {
let (tx, rx) = mpsc::unbounded_channel();
tokio::spawn(flush_loop(db, rx));
Self { tx, session_id }
}
/// The session ID shared by all tool calls recorded through this instance.
pub fn session_id(&self) -> Uuid {
self.session_id
}
/// Enqueue a tool call record for batch persistence.
pub fn record(&self, record: ToolCallRecord) {
let _ = self.tx.send(record);
}
}
async fn flush_loop(db: AppDatabase, mut rx: mpsc::UnboundedReceiver<ToolCallRecord>) {
let mut buffer = Vec::with_capacity(BATCH_SIZE);
let mut ticker = tokio::time::interval(FLUSH_INTERVAL);
ticker.tick().await; // skip first immediate tick
loop {
tokio::select! {
Some(record) = rx.recv() => {
buffer.push(record);
if buffer.len() >= BATCH_SIZE {
flush(&db, &mut buffer).await;
}
}
_ = ticker.tick() => {
if !buffer.is_empty() {
flush(&db, &mut buffer).await;
}
}
else => {
// Channel closed — flush remaining and exit
if !buffer.is_empty() {
flush(&db, &mut buffer).await;
}
break;
}
}
}
}
async fn flush(db: &AppDatabase, buffer: &mut Vec<ToolCallRecord>) {
let now = chrono::Utc::now();
let models: Vec<ai_tool_call::ActiveModel> = buffer
.iter()
.map(|r| {
let status = r.status.to_string();
ai_tool_call::ActiveModel {
tool_call_id: Set(r.tool_call_id.clone()),
session: Set(r.session_id),
tool_name: Set(r.tool_name.clone()),
caller: Set(r.caller),
arguments: Set(r.arguments.clone()),
result: Set(serde_json::Value::Null),
status: Set(status),
execution_time_ms: Set(r.execution_time_ms),
error_message: Set(r.error_message.clone()),
error_stack: Set(r.error_stack.clone()),
retry_count: Set(r.retry_count),
created_at: Set(now),
completed_at: Set(Some(now)),
updated_at: Set(now),
}
})
.collect();
let count = models.len();
if let Err(e) = ai_tool_call::Entity::insert_many(models).exec(db).await {
tracing::warn!(error = %e, count, "failed_to_flush_tool_call_records");
}
buffer.clear();
}

View File

@ -4,6 +4,7 @@
//! to implement rig's ToolDyn trait, enabling integration with rig's Agent. //! to implement rig's ToolDyn trait, enabling integration with rig's Agent.
use std::collections::HashMap; use std::collections::HashMap;
use std::time::{Duration, Instant};
use futures::FutureExt; use futures::FutureExt;
use rig::completion::ToolDefinition; use rig::completion::ToolDefinition;
@ -11,8 +12,146 @@ use rig::tool::{ToolDyn, ToolError, ToolSet};
use super::context::ToolContext; use super::context::ToolContext;
use super::definition::ToolDefinition as AgentToolDefinition; use super::definition::ToolDefinition as AgentToolDefinition;
use super::recorder::{ToolCallRecord, ToolCallRecorder};
use super::registry::{ToolHandler, ToolRegistry}; use super::registry::{ToolHandler, ToolRegistry};
/// Returns true if the tool error message indicates a transient failure that can be retried.
pub fn is_retryable_tool_error(msg: &str) -> bool {
let lower = msg.to_lowercase();
lower.contains("retry")
|| lower.contains("timeout")
|| lower.contains("rate limit")
|| lower.contains("too many requests")
|| lower.contains("unavailable")
|| lower.contains("connection refused")
|| lower.contains("5")
|| lower.contains("try again")
}
/// Wraps a ToolDyn with automatic retry and tool call recording.
///
/// Used by the rig Agent path to replace the custom ReAct executor closure.
pub struct RecordingTool {
inner: Box<dyn ToolDyn>,
db: db::database::AppDatabase,
session_id: uuid::Uuid,
caller: uuid::Uuid,
}
impl RecordingTool {
pub fn new(
inner: Box<dyn ToolDyn>,
db: db::database::AppDatabase,
session_id: uuid::Uuid,
caller: uuid::Uuid,
) -> Self {
Self { inner, db, session_id, caller }
}
}
impl ToolDyn for RecordingTool {
fn name(&self) -> String {
self.inner.name()
}
fn definition<'a>(
&'a self,
prompt: String,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolDefinition> + Send + 'a>> {
self.inner.definition(prompt)
}
fn call<'a>(
&'a self,
args: String,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String, ToolError>> + Send + 'a>> {
let inner: &'a Box<dyn ToolDyn> = &self.inner;
let db = self.db.clone();
let session_id = self.session_id;
let caller = self.caller;
let tool_name = inner.name();
Box::pin(async move {
let recorder = ToolCallRecorder::with_session(db.clone(), session_id);
let max_retries = 3u32;
let mut last_err = String::new();
let start = Instant::now();
for attempt in 0..=max_retries {
let attempt_start = Instant::now();
let attempt_args = args.clone();
let attempt_result = inner.call(attempt_args).await;
let elapsed_ms = attempt_start.elapsed().as_millis() as i64;
let args_json: serde_json::Value =
serde_json::from_str(&args).unwrap_or_default();
match attempt_result {
Ok(value) => {
recorder.record(ToolCallRecord {
tool_call_id: tool_name.clone(),
session_id,
tool_name: tool_name.clone(),
caller,
arguments: args_json,
status: models::ai::ToolCallStatus::Success,
execution_time_ms: Some(elapsed_ms),
error_message: None,
error_stack: None,
retry_count: attempt as i32,
});
return Ok(value);
}
Err(e) => {
let err_msg = e.to_string();
if attempt < max_retries && is_retryable_tool_error(&err_msg) {
last_err = err_msg;
let backoff_ms =
100u64.saturating_mul(2u64.pow(attempt as u32));
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
continue;
}
recorder.record(ToolCallRecord {
tool_call_id: tool_name.clone(),
session_id,
tool_name: tool_name.clone(),
caller,
arguments: args_json,
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed_ms),
error_message: Some(err_msg.clone()),
error_stack: None,
retry_count: attempt as i32,
});
return Err(e);
}
}
}
// Fallback: record failure after all retries exhausted
let elapsed_ms = start.elapsed().as_millis() as i64;
let args_json: serde_json::Value =
serde_json::from_str(&args).unwrap_or_default();
recorder.record(ToolCallRecord {
tool_call_id: tool_name.clone(),
session_id,
tool_name: tool_name.clone(),
caller,
arguments: args_json,
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed_ms),
error_message: Some(last_err),
error_stack: None,
retry_count: max_retries as i32,
});
Err(ToolError::ToolCallError(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
"max retries exceeded",
))))
})
}
}
/// A wrapper that converts our ToolRegistry to rig's ToolSet. /// A wrapper that converts our ToolRegistry to rig's ToolSet.
pub struct RigToolSet { pub struct RigToolSet {
/// The rig ToolSet /// The rig ToolSet
@ -30,6 +169,7 @@ impl RigToolSet {
config: config::AppConfig, config: config::AppConfig,
room_id: uuid::Uuid, room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>, sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid,
) -> Self { ) -> Self {
let mut toolset = ToolSet::default(); let mut toolset = ToolSet::default();
let mut definitions = HashMap::new(); let mut definitions = HashMap::new();
@ -50,6 +190,7 @@ impl RigToolSet {
config: config.clone(), config: config.clone(),
room_id, room_id,
sender_id, sender_id,
project_id,
}; };
toolset.add_tool(adapter); toolset.add_tool(adapter);
} }
@ -85,6 +226,23 @@ pub struct RigToolAdapter {
config: config::AppConfig, config: config::AppConfig,
room_id: uuid::Uuid, room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>, sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid,
}
impl RigToolAdapter {
/// Create a new RigToolAdapter with all required context.
pub fn new(
handler: ToolHandler,
definition: AgentToolDefinition,
db: db::database::AppDatabase,
cache: db::cache::AppCache,
config: config::AppConfig,
room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid,
) -> Self {
Self { handler, definition, db, cache, config, room_id, sender_id, project_id }
}
} }
impl ToolDyn for RigToolAdapter { impl ToolDyn for RigToolAdapter {
@ -113,6 +271,7 @@ impl ToolDyn for RigToolAdapter {
let config = self.config.clone(); let config = self.config.clone();
let room_id = self.room_id; let room_id = self.room_id;
let sender_id = self.sender_id; let sender_id = self.sender_id;
let project_id = self.project_id;
async move { async move {
let ctx = ToolContext::new( let ctx = ToolContext::new(
@ -121,7 +280,8 @@ impl ToolDyn for RigToolAdapter {
config, config,
room_id, room_id,
sender_id, sender_id,
); )
.with_project(project_id);
let args_json: serde_json::Value = serde_json::from_str(&args) let args_json: serde_json::Value = serde_json::from_str(&args)
.map_err(|e| ToolError::JsonError(e))?; .map_err(|e| ToolError::JsonError(e))?;

View File

@ -272,6 +272,7 @@ pub async fn ws_universal(
"data": { "data": {
"message_id": chunk.message_id, "message_id": chunk.message_id,
"room_id": chunk.room_id, "room_id": chunk.room_id,
"seq": chunk.seq,
"content": chunk.content, "content": chunk.content,
"done": chunk.done, "done": chunk.done,
"error": chunk.error, "error": chunk.error,

View File

@ -110,6 +110,9 @@ pub struct ProjectRoomEvent {
pub struct RoomMessageStreamChunkEvent { pub struct RoomMessageStreamChunkEvent {
pub message_id: Uuid, pub message_id: Uuid,
pub room_id: Uuid, pub room_id: Uuid,
/// Monotonically increasing sequence number for ordering within this stream.
#[serde(default)]
pub seq: u64,
pub content: String, pub content: String,
pub done: bool, pub done: bool,
pub error: Option<String>, pub error: Option<String>,

View File

@ -54,6 +54,7 @@ pub async fn process_message_ai_react_streaming(
let answer_buffer: std::sync::Arc<std::sync::Mutex<String>> = let answer_buffer: std::sync::Arc<std::sync::Mutex<String>> =
std::sync::Arc::new(std::sync::Mutex::new(String::new())); std::sync::Arc::new(std::sync::Mutex::new(String::new()));
let step_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); let step_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
let chunk_seq: std::sync::Arc<std::sync::atomic::AtomicU64> = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1));
// Helper: recover from poison instead of panicking. // Helper: recover from poison instead of panicking.
fn lock_or_recover<T>(mutex: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> { fn lock_or_recover<T>(mutex: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> {
@ -65,6 +66,7 @@ pub async fn process_message_ai_react_streaming(
let streaming_msg_id = streaming_msg_id; let streaming_msg_id = streaming_msg_id;
let room_id = room_id_inner; let room_id = room_id_inner;
let step_count = step_count.clone(); let step_count = step_count.clone();
let chunk_seq = chunk_seq.clone();
let ai_display_name_for_step = std::sync::Arc::new(ai_display_name.clone()); let ai_display_name_for_step = std::sync::Arc::new(ai_display_name.clone());
let steps = steps.clone(); let steps = steps.clone();
let answer_buffer = answer_buffer.clone(); let answer_buffer = answer_buffer.clone();
@ -73,18 +75,20 @@ pub async fn process_message_ai_react_streaming(
let room_manager = room_manager.clone(); let room_manager = room_manager.clone();
let (chunk_type, content) = match &step { let (chunk_type, content) = match &step {
ReactStep::Thought { step: _, thought } => { ReactStep::Thought { step: _, thought } => {
("thinking".to_string(), format!("[Thinking] {}", thought)) ("thinking".to_string(), thought.clone())
} }
ReactStep::Action { step: _, action } => { ReactStep::Action { step: _, action } => {
*lock_or_recover(&last_action_name) = action.name.clone(); *lock_or_recover(&last_action_name) = action.name.clone();
("tool_call".to_string(), format!("[Action] Calling `{}` with {:?}", action.name, action.args)) ("tool_call".to_string(), serde_json::json!({
"name": action.name,
"arguments": action.args,
}).to_string())
} }
ReactStep::Observation { ReactStep::Observation {
step: _, step: _,
observation: _, observation,
} => { } => {
let action_name = lock_or_recover(&last_action_name).clone(); ("tool_result".to_string(), observation.clone())
("tool_call".to_string(), format!("[Observation] {} (completed)", action_name))
} }
ReactStep::Answer { step: _, answer } => { ReactStep::Answer { step: _, answer } => {
("answer".to_string(), answer.clone()) ("answer".to_string(), answer.clone())
@ -96,22 +100,33 @@ pub async fn process_message_ai_react_streaming(
step_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); step_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
} }
// Record ordered step for storage // Record ordered step for storage — merge consecutive same-type chunks
// to ensure strict think→answer→think→answer alternation.
{ {
let mut s = lock_or_recover(&steps); let mut s = lock_or_recover(&steps);
if let Some(last) = s.last_mut() {
if last.0 == chunk_type {
last.1.push_str(&content);
} else {
s.push((chunk_type.clone(), content.clone())); s.push((chunk_type.clone(), content.clone()));
} }
} else {
s.push((chunk_type.clone(), content.clone()));
}
}
if is_answer { if is_answer {
let mut ab = lock_or_recover(&answer_buffer); let mut ab = lock_or_recover(&answer_buffer);
ab.push_str(&content); ab.push_str(&content);
} }
let done = is_answer; let done = false;
let ai_name = ai_display_name_for_step.clone(); let ai_name = ai_display_name_for_step.clone();
let current_seq = chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tokio::spawn(async move { tokio::spawn(async move {
let event = RoomMessageStreamChunkEvent { let event = RoomMessageStreamChunkEvent {
message_id: streaming_msg_id, message_id: streaming_msg_id,
room_id, room_id,
seq: current_seq,
content: content.clone(), content: content.clone(),
done, done,
error: None, error: None,
@ -125,6 +140,21 @@ pub async fn process_message_ai_react_streaming(
let result = chat_service.process_react(&request, on_step).await; let result = chat_service.process_react(&request, on_step).await;
// Broadcast final done=true event to close the streaming channel on frontend.
let final_stream_content = lock_or_recover(&answer_buffer).clone();
room_manager
.broadcast_stream_chunk(RoomMessageStreamChunkEvent {
message_id: streaming_msg_id,
room_id: room_id_inner,
seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
content: final_stream_content.clone(),
done: true,
error: None,
display_name: Some(ai_display_name.clone()),
chunk_type: Some("answer".to_string()),
})
.await;
let final_content = lock_or_recover(&answer_buffer).clone(); let final_content = lock_or_recover(&answer_buffer).clone();
let all_steps = lock_or_recover(&steps).clone(); let all_steps = lock_or_recover(&steps).clone();
let reasoning_chain: String = all_steps let reasoning_chain: String = all_steps
@ -172,7 +202,7 @@ pub async fn process_message_ai_react_streaming(
} }
// Serialize ordered steps as JSON for ordered replay. // Serialize ordered steps as JSON for ordered replay.
let thinking_content = { let thinking_content_serialized = {
let steps = lock_or_recover(&steps); let steps = lock_or_recover(&steps);
if steps.is_empty() { if steps.is_empty() {
None None
@ -186,6 +216,7 @@ pub async fn process_message_ai_react_streaming(
Some(chunks_json.to_string()) Some(chunks_json.to_string())
} }
}; };
let thinking_content_for_event = thinking_content_serialized.clone();
let envelope = RoomMessageEnvelope { let envelope = RoomMessageEnvelope {
id: streaming_msg_id, id: streaming_msg_id,
@ -197,7 +228,7 @@ pub async fn process_message_ai_react_streaming(
thread_id: None, thread_id: None,
content: persist_content.clone(), content: persist_content.clone(),
content_type: "text".to_string(), content_type: "text".to_string(),
thinking_content, thinking_content: thinking_content_serialized,
send_at: now, send_at: now,
seq, seq,
in_reply_to: None, in_reply_to: None,
@ -244,7 +275,7 @@ pub async fn process_message_ai_react_streaming(
thread_id: None, thread_id: None,
content: persist_content, content: persist_content,
content_type: "text".to_string(), content_type: "text".to_string(),
thinking_content: None, thinking_content: thinking_content_for_event,
send_at: now, send_at: now,
seq, seq,
display_name: Some(ai_display_name.clone()), display_name: Some(ai_display_name.clone()),