refactor(agent): replace custom ReAct loop with rig::agent::Agent

- Use AgentBuilder for native tool-calling with stream_prompt()
- Add RecordingTool wrapper preserving retry + DB recording
- Fix tool_choice bug in do_completion (same as call_stream_once)
- Add seq field to RoomMessageStreamChunkEvent for strict ordering
- Map streaming events: Text→Answer, Reasoning→Thought, ToolCall→Action
- Only final event has done=true, removed premature stream ending
- Store __chunks__ JSON in thinking_content for ordered replay
This commit is contained in:
ZhenYi 2026-04-28 09:42:36 +08:00
parent 2bd40aee1b
commit 5b3a6700be
11 changed files with 828 additions and 752 deletions

View File

@ -1,22 +1,31 @@
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use futures::StreamExt;
use models::projects::project_skill;
use models::rooms::room_ai;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use rig::agent::{AgentBuilder, MultiTurnStreamItem};
use rig::client::CompletionClient;
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
use sea_orm::*;
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
use super::context::RoomMessageContext;
use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback};
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::AiClientConfig;
use crate::client::{call_stream, call_with_params, StreamChunk, StreamChunkType, StreamedToolCall};
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{
StreamChunk, StreamChunkType, StreamedToolCall, call_stream, call_with_params,
};
use crate::compact::{CompactConfig, CompactService};
use crate::embed::EmbedService;
use crate::error::{AgentError, Result};
use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent};
use crate::react::{ReactAgent, ReactConfig, DEFAULT_SYSTEM_PROMPT};
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor, ToolResult, registry::ToolRegistry};
use crate::react::{DEFAULT_SYSTEM_PROMPT, ReactStep};
use crate::react::types::Action as ReactAction;
use crate::tool::{
RecordingTool, ToolCall as AgentToolCall, ToolContext, ToolExecutor,
registry::ToolRegistry,
};
/// Result from streaming AI response.
pub struct StreamResult {
@ -104,9 +113,12 @@ impl ChatService {
config: config::AppConfig,
room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid,
) -> Option<crate::RigToolSet> {
self.tool_registry.as_ref().map(|registry| {
crate::RigToolSet::from_registry(registry, db, cache, config, room_id, sender_id)
crate::RigToolSet::from_registry(
registry, db, cache, config, room_id, sender_id, project_id,
)
})
}
@ -140,11 +152,16 @@ impl ChatService {
let mut tool_depth = 0;
let mut input_tokens = 0i64;
let mut output_tokens = 0i64;
let session_id = Uuid::new_v4();
let session_start = std::time::Instant::now();
let version_id = room_ai.as_ref().and_then(|r| r.version);
let config = AiClientConfig::new(
self.ai_api_key.clone().unwrap_or_default(),
)
.with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into()));
let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
.with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
loop {
let response = call_with_params(
@ -183,9 +200,10 @@ impl ChatService {
})
.collect();
messages.push(
ChatRequestMessage::assistant(Some(text.clone()), Some(tool_call_messages.clone()))
);
messages.push(ChatRequestMessage::assistant(
Some(text.clone()),
Some(tool_call_messages.clone()),
));
// Create ToolCall list for executor (we need real IDs and args)
// Since we can't get args from streaming, use name matching from the text
@ -210,15 +228,69 @@ impl ChatService {
if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
request.db.clone(),
session_id,
);
let start = std::time::Instant::now();
let executor = ToolExecutor::new();
match executor.execute_batch(calls, &mut ctx).await {
Ok(results) => ToolExecutor::to_tool_messages(&results),
Ok(results) => {
for (call, result) in
response.tool_calls_finished.iter().zip(results.iter())
{
let elapsed = start.elapsed().as_millis() as i64;
let is_error =
matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result {
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
_ => None,
};
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.clone(),
session_id: recorder.session_id(),
tool_name: call.clone(),
caller: request.sender.uid,
arguments: serde_json::Value::Null,
status: if is_error {
models::ai::ToolCallStatus::Failed
} else {
models::ai::ToolCallStatus::Success
},
execution_time_ms: Some(elapsed),
error_message: error_msg,
error_stack: None,
retry_count: 0,
});
}
crate::tool::ToolExecutor::to_tool_messages(&results)
}
Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
for call_name in &response.tool_calls_finished {
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: Uuid::new_v4().to_string(),
session_id: recorder.session_id(),
tool_name: call_name.clone(),
caller: request.sender.uid,
arguments: serde_json::Value::Null,
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed),
error_message: Some(e.to_string()),
error_stack: None,
retry_count: 0,
});
}
let err_msg = format!("[Tool call failed: {}]", e);
response
.tool_calls_finished
.iter()
.map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg))
.map(|_| {
ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg)
})
.collect()
}
}
@ -250,8 +322,10 @@ impl ChatService {
})
.collect();
for event in &tool_events {
if let Some(ctx) =
self.perception_service.passive.detect(event, &skill_entries)
if let Some(ctx) = self
.perception_service
.passive
.detect(event, &skill_entries)
{
messages.push(ctx.to_system_message());
}
@ -268,16 +342,62 @@ impl ChatService {
} else {
text
};
return Ok(ProcessResult { content, input_tokens, output_tokens });
// Record session
let _ = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(request.room.id),
model: Set(request.model.id),
version: Set(version_id.unwrap_or_default()),
token_input: Set(input_tokens),
token_output: Set(output_tokens),
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await;
return Ok(ProcessResult {
content,
input_tokens,
output_tokens,
});
}
continue;
}
return Ok(ProcessResult { content: text, input_tokens, output_tokens });
// Record session
let _ = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(request.room.id),
model: Set(request.model.id),
version: Set(version_id.unwrap_or_default()),
token_input: Set(input_tokens),
token_output: Set(output_tokens),
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await;
return Ok(ProcessResult {
content: text,
input_tokens,
output_tokens,
});
}
}
pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result<StreamResult> {
pub async fn process_stream(
&self,
request: AiRequest,
on_chunk: StreamCallback,
) -> Result<StreamResult> {
// Wrap on_chunk in Arc so it can be shared across loop iterations
let on_chunk = Arc::new(on_chunk);
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
@ -302,11 +422,19 @@ impl ChatService {
.and_then(|r| r.max_tokens.map(|v| v as u32))
.unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0;
let mut total_input_tokens = 0i64;
let mut total_output_tokens = 0i64;
let session_id = Uuid::new_v4();
let session_start = std::time::Instant::now();
let config = AiClientConfig::new(
self.ai_api_key.clone().unwrap_or_default(),
)
.with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into()));
let version_id = room_ai.as_ref().and_then(|r| r.version);
let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
.with_base_url(
self.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into()),
);
let mut full_content = String::new();
let mut all_chunks: Vec<StreamChunk> = Vec::new();
@ -325,6 +453,7 @@ impl ChatService {
temperature,
max_tokens,
if tools_enabled { Some(&tools) } else { None },
None, // tool_choice — auto (let model decide)
Arc::new(move |delta| {
let fut = on_chunk_cb(AiStreamChunk {
content: delta.to_string(),
@ -351,6 +480,9 @@ impl ChatService {
)
.await?;
total_input_tokens += response.input_tokens;
total_output_tokens += response.output_tokens;
// Collect chunks from this streaming iteration in order.
all_chunks.extend(response.chunks);
@ -425,23 +557,44 @@ impl ChatService {
request.config.clone(),
request.room.id,
Some(request.sender.uid),
);
)
.with_project(request.project.id);
if let Some(ref registry) = self.tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
request.db.clone(),
session_id,
);
for call in &calls {
let start = std::time::Instant::now();
let executor = crate::tool::ToolExecutor::new();
let results = match executor.execute_batch(vec![call.clone()], &mut ctx).await {
Ok(r) => r,
Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: request.sender.uid,
arguments: call.arguments_json().unwrap_or_default(),
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed),
error_message: Some(e.to_string()),
error_stack: None,
retry_count: 0,
});
let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, error = %e, "tool_call_failed");
// Do NOT send the raw tool error to the frontend — only a short "<tool> (failed)" status line is emitted below.
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk {
content: err_display.clone(),
done: false,
chunk_type: AiChunkType::ToolCall,
chunk_type: AiChunkType::ToolResult,
})
.await;
all_chunks.push(StreamChunk {
@ -464,6 +617,29 @@ impl ChatService {
text.clone()
};
tracing::debug!("tool_result: {} — {}", call.name, preview);
let elapsed = start.elapsed().as_millis() as i64;
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result {
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
_ => None,
};
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: request.sender.uid,
arguments: call.arguments_json().unwrap_or_default(),
status: if is_error {
models::ai::ToolCallStatus::Failed
} else {
models::ai::ToolCallStatus::Success
},
execution_time_ms: Some(elapsed),
error_message: error_msg,
error_stack: None,
retry_count: 0,
});
// Do NOT send the raw tool output to the frontend — it may contain sensitive data.
// The output is logged server-side only; the frontend just receives a status chunk via on_chunk below.
}
@ -471,7 +647,7 @@ impl ChatService {
on_chunk(AiStreamChunk {
content: success_display.clone(),
done: false,
chunk_type: AiChunkType::ToolCall,
chunk_type: AiChunkType::ToolResult,
})
.await;
all_chunks.push(StreamChunk {
@ -509,8 +685,10 @@ impl ChatService {
})
.collect();
for event in &tool_events {
if let Some(ctx) =
self.perception_service.passive.detect(event, &skill_entries)
if let Some(ctx) = self
.perception_service
.passive
.detect(event, &skill_entries)
{
messages.push(ctx.to_system_message());
}
@ -533,6 +711,23 @@ impl ChatService {
chunk_type: StreamChunkType::Answer,
content: max_depth_text,
});
// Record session
let _ = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(request.room.id),
model: Set(request.model.id),
version: Set(version_id.unwrap_or_default()),
token_input: Set(total_input_tokens),
token_output: Set(total_output_tokens),
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await;
return Ok(StreamResult {
content: full_content,
reasoning_content: String::new(),
@ -557,6 +752,23 @@ impl ChatService {
chunk_type: StreamChunkType::Answer,
content: response.content.clone(),
});
// Record session
let _ = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(request.room.id),
model: Set(request.model.id),
version: Set(version_id.unwrap_or_default()),
token_input: Set(total_input_tokens),
token_output: Set(total_output_tokens),
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await;
return Ok(StreamResult {
content: full_content,
reasoning_content: response.reasoning_content,
@ -616,7 +828,10 @@ impl ChatService {
parts.push(format!("Description: {}", desc));
}
parts.push(format!("Default branch: {}", repo.default_branch));
parts.push(format!("Private: {}", if repo.is_private { "yes" } else { "no" }));
parts.push(format!(
"Private: {}",
if repo.is_private { "yes" } else { "no" }
));
parts.push(format!("Created: {}", repo.created_at.format("%Y-%m-%d")));
messages.push(ChatRequestMessage::system(format!(
"Mentioned repository:\n{}",
@ -692,7 +907,11 @@ impl ChatService {
"Current Project:\n{}\nDescription: {}\nPublic: {}",
request.project.display_name,
request.project.description.as_deref().unwrap_or("(none)"),
if request.project.is_public { "yes" } else { "no" }
if request.project.is_public {
"yes"
} else {
"no"
}
)));
let mut sender_parts = vec![format!("**Sender:** {}", request.sender.username)];
@ -773,7 +992,11 @@ impl ChatService {
if let Some(embed_service) = &self.embed_service {
let awareness = crate::perception::VectorActiveAwareness::default();
vector_skills = awareness
.detect(embed_service, &request.input, &request.project.id.to_string())
.detect(
embed_service,
&request.input,
&request.project.id.to_string(),
)
.await;
}
@ -813,32 +1036,14 @@ impl ChatService {
.await
}
/// Heuristically decide whether a tool error message describes a transient failure.
///
/// The message is lower-cased and then searched for a fixed list of substrings
/// (network, timeout, rate-limit and availability wording). Returns `true` as soon
/// as any of them occurs anywhere in the message, `false` otherwise.
fn is_retryable_tool_error(msg: &str) -> bool {
    // All markers are lowercase because they are matched against the lower-cased message.
    const TRANSIENT_MARKERS: &[&str] = &[
        "connection",
        "timeout",
        "timed out",
        "rate limit",
        "too many",
        "unavailable",
        "service unavailable",
        "temporarily",
        "refused",
        "reset",
        "broken pipe",
        "deadline exceeded",
        "try again",
    ];

    let haystack = msg.to_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|&marker| haystack.contains(marker))
}
pub async fn process_react<C>(
&self,
request: &AiRequest,
mut on_chunk: C,
) -> Result<String>
pub async fn process_react<C>(&self, request: &AiRequest, mut on_chunk: C) -> Result<String>
where
C: FnMut(crate::react::ReactStep) + Send,
{
let base_url = self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into());
let base_url = self
.ai_base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com".into());
let api_key = self.ai_api_key.clone().unwrap_or_default();
let client_config = AiClientConfig::new(api_key).with_base_url(base_url);
@ -848,104 +1053,176 @@ impl ChatService {
let db = request.db.clone();
let cache = request.cache.clone();
let config = request.config.clone();
let cfg = request.config.clone();
let room_id = request.room.id;
let project_id = Some(request.project.id);
let sender_uid = Some(request.sender.uid);
let registry = registry.clone();
let sender_uid = request.sender.uid;
let project_id = request.project.id;
let session_id = Uuid::new_v4();
let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find()
.filter(room_ai::Column::Room.eq(request.room.id))
.filter(room_ai::Column::Model.eq(request.model.id))
.one(&request.db)
.await
.ok()
.flatten()
.and_then(|r| r.version);
let executor: std::sync::Arc<
dyn Fn(String, serde_json::Value) -> Pin<Box<dyn std::future::Future<Output = std::result::Result<serde_json::Value, String>> + Send>>
+ Send
+ Sync,
> = std::sync::Arc::new(move |name: String, args: serde_json::Value| {
let db = db.clone();
let cache = cache.clone();
let config = config.clone();
let room_id = room_id;
let project_id = project_id;
let sender_uid = sender_uid;
let registry = registry.clone();
// Build rig tools with recording wrapper directly from registry
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
for def in registry.definitions() {
let name = def.name.clone();
if let Some(handler) = registry.get(&name) {
let adapter = crate::tool::RigToolAdapter::new(
handler.clone(),
def.clone(),
db.clone(),
cache.clone(),
cfg.clone(),
room_id,
Some(sender_uid),
project_id,
);
tools.push(Box::new(RecordingTool::new(
Box::new(adapter),
db.clone(),
session_id,
sender_uid,
)));
}
}
Box::pin(async move {
let max_retries = 3;
let mut last_err = String::new();
// Build rig agent (handles multi-turn tool calls natively)
let rig_client = client_config.build_rig_client();
let model = rig_client.completion_model(&request.model.name);
let agent = AgentBuilder::new(model)
.preamble(DEFAULT_SYSTEM_PROMPT)
.tools(tools)
.default_max_turns(request.max_tool_depth)
.build();
for attempt in 0..=max_retries {
let mut ctx = ToolContext::new(db.clone(), cache.clone(), config.clone(), room_id, sender_uid);
if let Some(pid) = project_id {
ctx = ctx.with_project(pid);
}
ctx.registry_mut().merge(registry.clone());
let stream = agent
.stream_prompt(&request.input)
.with_history(Vec::new())
.multi_turn(request.max_tool_depth)
.await;
let tool_executor = ToolExecutor::new();
let call = AgentToolCall {
id: Uuid::new_v4().to_string(),
name: name.clone(),
arguments: serde_json::to_string(&args).unwrap_or_else(|_| "{}".into()),
};
tokio::pin!(stream);
match tool_executor.execute_batch(vec![call], &mut ctx).await {
Ok(results) => {
let result = results.into_iter().next()
.ok_or_else(|| "no tool result returned".to_string())?;
match result.result {
ToolResult::Ok(v) => return Ok(v),
ToolResult::Error(msg) => {
if attempt < max_retries && Self::is_retryable_tool_error(&msg) {
last_err = msg;
let backoff_ms = 100u64.saturating_mul(2u64.pow(attempt as u32));
tracing::warn!(
tool = %name,
attempt = attempt + 1,
backoff_ms = backoff_ms,
error = %last_err,
"tool_execute_retry"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
continue;
}
return Err(msg);
}
}
}
Err(e) => {
last_err = e.to_string();
if attempt < max_retries && Self::is_retryable_tool_error(&last_err) {
let backoff_ms = 100u64.saturating_mul(2u64.pow(attempt as u32));
tracing::warn!(
tool = %name,
attempt = attempt + 1,
backoff_ms = backoff_ms,
error = %last_err,
"tool_execute_retry"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
continue;
}
return Err(last_err);
}
let mut step_count = 0usize;
let mut final_content = String::new();
let mut total_input_tokens: i64 = 0;
let mut total_output_tokens: i64 = 0;
while let Some(item) = stream.next().await {
match item {
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
step_count += 1;
let t = text.text;
on_chunk(ReactStep::Answer {
step: step_count,
answer: t.clone(),
});
final_content.push_str(&t);
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Reasoning(reasoning),
)) => {
let reasoning_text = reasoning.reasoning.join("");
if !reasoning_text.is_empty() {
step_count += 1;
on_chunk(ReactStep::Thought {
step: step_count,
thought: reasoning_text,
});
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ReasoningDelta { reasoning, .. },
)) => {
if !reasoning.is_empty() {
step_count += 1;
on_chunk(ReactStep::Thought {
step: step_count,
thought: reasoning,
});
}
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ToolCall { tool_call, .. },
)) => {
step_count += 1;
let args: serde_json::Value = match &tool_call.function.arguments {
serde_json::Value::String(s) => {
serde_json::from_str(s).unwrap_or(serde_json::Value::Null)
}
v => v.clone(),
};
on_chunk(ReactStep::Action {
step: step_count,
action: ReactAction::new(&tool_call.function.name, args),
});
}
Ok(MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
)) => {
step_count += 1;
let obs = tool_result_content_to_string(&tool_result.content);
on_chunk(ReactStep::Observation {
step: step_count,
observation: obs,
});
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
total_input_tokens = usage.input_tokens as i64;
total_output_tokens = usage.output_tokens as i64;
// Text was already streamed incrementally via Answer events.
}
Err(e) => {
let err_msg = format!("rig agent stream error: {}", e);
return Err(AgentError::OpenAi(err_msg));
}
_ => {}
}
}
Err(last_err)
}) as Pin<Box<dyn std::future::Future<Output = std::result::Result<serde_json::Value, String>> + Send>>
});
let elapsed_ms = session_start.elapsed().as_millis() as i64;
let _ = models::ai::ai_session::ActiveModel {
id: Set(session_id),
room: Set(request.room.id),
model: Set(request.model.id),
version: Set(version_id.unwrap_or_default()),
token_input: Set(total_input_tokens),
token_output: Set(total_output_tokens),
latency_ms: Set(Some(elapsed_ms)),
cost: Set(None),
currency: Set(None),
error_message: Set(None),
error_code: Set(None),
created_at: Set(chrono::Utc::now()),
}
.insert(&request.db)
.await;
let tools = self.tools();
let config = ReactConfig {
max_steps: request.max_tool_depth,
stop_sequences: Vec::new(),
tool_executor: Some(executor),
};
let mut agent = ReactAgent::new(DEFAULT_SYSTEM_PROMPT, tools, config);
agent.add_user_message(&request.input);
agent
.run(&request.model.name, &client_config, |step| {
on_chunk(step);
})
.await
Ok(final_content)
}
}
/// Flatten a rig tool result into plain text.
///
/// Only `ToolResultContent::Text` items contribute; every other variant
/// (e.g. images) is dropped. The surviving fragments keep their original
/// order and are joined with a single `\n`.
fn tool_result_content_to_string(
    content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>,
) -> String {
    use rig::completion::message::ToolResultContent;

    let mut fragments: Vec<String> = Vec::new();
    for part in content.iter() {
        // Non-text parts carry nothing that can be shown as an observation string.
        if let ToolResultContent::Text(t) = part {
            fragments.push(t.text.clone());
        }
    }
    fragments.join("\n")
}

View File

@ -287,14 +287,6 @@ where
.map(|ts| ts.iter().filter_map(to_rig_tool_def).collect())
.unwrap_or_default();
let tc = match tool_choice {
Some("none") => rig::completion::message::ToolChoice::None,
Some("auto") | None => rig::completion::message::ToolChoice::Auto,
Some(s) => rig::completion::message::ToolChoice::Specific {
function_names: vec![s.to_string()],
},
};
let mut builder = model.completion_request("");
if !preamble.is_empty() {
@ -317,7 +309,24 @@ where
builder = builder.tools(tool_defs);
}
builder = builder.tool_choice(tc);
// Only set tool_choice when explicitly provided (mirrors call_stream_once logic)
if let Some(tc) = tool_choice {
match tc {
"none" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::None);
}
"auto" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto);
}
s => {
builder = builder.tool_choice(
rig::completion::message::ToolChoice::Specific {
function_names: vec![s.to_string()],
},
);
}
}
}
let response = builder.send().await.map_err(|e| AgentError::OpenAi(e.to_string()))?;
@ -498,6 +507,7 @@ pub async fn call_stream(
temperature: f32,
max_tokens: u32,
tools: Option<&[serde_json::Value]>,
tool_choice: Option<&str>,
on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb,
@ -506,7 +516,7 @@ pub async fn call_stream(
loop {
let result = call_stream_once(
messages, model_name, config, temperature, max_tokens, tools,
messages, model_name, config, temperature, max_tokens, tools, tool_choice,
on_text_delta.clone(), on_reasoning_delta.clone(), on_tool_call.clone(),
)
.await;
@ -542,6 +552,7 @@ async fn call_stream_once(
temperature: f32,
max_tokens: u32,
tools: Option<&[serde_json::Value]>,
tool_choice: Option<&str>,
on_text_delta: StreamTextCb,
on_reasoning_delta: StreamReasoningCb,
on_tool_call: StreamToolCallCb,
@ -581,6 +592,24 @@ async fn call_stream_once(
builder = builder.tools(tool_defs);
}
if let Some(tc) = tool_choice {
match tc {
"none" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::None);
}
"auto" => {
builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto);
}
s => {
builder = builder.tool_choice(
rig::completion::message::ToolChoice::Specific {
function_names: vec![s.to_string()],
},
);
}
}
}
let stream_fut = async {
let mut stream = builder
.stream()
@ -592,6 +621,10 @@ async fn call_stream_once(
let mut tool_calls: Vec<StreamedToolCall> = Vec::new();
let mut chunks: Vec<StreamChunk> = Vec::new();
// Some models (e.g. GLM) ignore tool_choice="none" and still emit tool_calls.
// Filter them out so they don't cause spurious tool execution attempts.
let skip_tool_calls = tool_choice == Some("none");
use std::collections::HashMap;
let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new();
let mut stream_finished = false;
@ -612,6 +645,10 @@ async fn call_stream_once(
tool_call,
internal_call_id,
}) => {
if skip_tool_calls {
partial_tool_calls.remove(&internal_call_id);
continue;
}
let arguments = match &tool_call.function.arguments {
serde_json::Value::String(s) => s.clone(),
other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()),
@ -638,6 +675,9 @@ async fn call_stream_once(
internal_call_id,
content: delta_content,
}) => {
if skip_tool_calls {
continue;
}
use rig::streaming::ToolCallDeltaContent;
match delta_content {
ToolCallDeltaContent::Name(name) => {
@ -677,8 +717,12 @@ async fn call_stream_once(
}
Ok(StreamedAssistantContent::Final(response)) => {
stream_finished = true;
for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc);
if !skip_tool_calls {
for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc);
}
} else {
partial_tool_calls.drain();
}
if let Some(usage) = response.token_usage() {
let in_toks = usage.input_tokens as i64;
@ -700,7 +744,7 @@ async fn call_stream_once(
}
// Flush any remaining partial tool calls (if stream ended without Final or Final had no usage)
if !stream_finished {
if !stream_finished && !skip_tool_calls {
for (_, tc) in partial_tool_calls.drain() {
tool_calls.push(tc);
}

View File

@ -31,16 +31,13 @@ pub use client::types::ChatRequestMessage;
pub use compact::{CompactConfig, CompactLevel, CompactService, CompactSummary, MessageSummary};
pub use embed::{new_embed_client, EmbedClient, EmbedService, QdrantClient, SearchResult};
pub use error::{AgentError, Result};
pub use react::{
Hook, HookAction, NoopHook, ReactAgent, ReactConfig, ReactStep, ToolCallAction, TracingHook,
DEFAULT_SYSTEM_PROMPT,
};
pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT};
pub use tool::{
ToolCall, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,
ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,
ToolRegistry, ToolResult, ToolSchema,
};
#[cfg(feature = "rig")]
pub use agent::RigAgentService;
#[cfg(feature = "rig")]
pub use tool::rig_adapter::RigToolSet;
pub use tool::{RigToolSet, RecordingTool, is_retryable_tool_error};

View File

@ -1,130 +0,0 @@
//! Observability hooks for the ReAct agent loop.
//!
//! Hooks allow injecting custom behavior (logging, tracing, filtering, termination)
//! at each step of the reasoning loop without coupling to the core agent logic.
//!
//! Inspired by rig's `PromptHook` trait.
//!
//! # Example
//!
//! ```ignore
//! #[derive(Clone)]
//! struct MyHook;
//!
//! #[async_trait]
//! impl Hook for MyHook {
//! async fn on_thought(&self, step: usize, thought: &str) -> HookAction {
//! tracing::info!("[step {}] thinking: {}", step, thought);
//! HookAction::Continue
//! }
//! }
//!
//! let agent = ReactAgent::new(prompt, tools, config).with_hook(MyHook);
//! ```
use async_trait::async_trait;
/// Controls whether the agent loop continues after a hook callback.
///
/// This is the return type of [`Hook::on_thought`], [`Hook::on_observation`]
/// and [`Hook::on_answer`].
///
/// NOTE(review): how each variant is acted upon is decided by the agent's run
/// loop, which is not part of this file — confirm there before relying on the
/// exact semantics of `Skip`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
/// Continue processing normally.
Continue,
/// Skip the current step and continue.
Skip,
/// Terminate the loop immediately with the given reason.
///
/// The reason is a `&'static str`, so it has to be a literal or another
/// statically allocated string.
Terminate(&'static str),
}
/// Controls behavior after a tool call hook callback.
///
/// This is the return type of [`Hook::on_tool_call`], which is invoked just
/// before a tool is executed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallAction {
/// Execute the tool normally.
Continue,
/// Skip tool execution and inject a custom result.
///
/// The contained `String` is the result to use in place of the tool's output.
Skip(String),
/// Terminate the loop with the given reason.
///
/// The reason is a `&'static str`, so it has to be a literal or another
/// statically allocated string.
Terminate(&'static str),
}
/// Default no-op hook that does nothing.
///
/// Its `Hook` impl is empty, so every callback uses the trait's default
/// method body, each of which returns `Continue`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopHook;
impl Hook for NoopHook {}
// The unit type is also accepted as a hook; like `NoopHook` it only uses the
// trait's default (no-op) callbacks.
impl Hook for () {}
/// A hook that logs everything to stderr using `eprintln`.
/// No external dependencies required.
///
/// Despite its name this type does not use the `tracing` crate: each callback
/// prints one `[step N] <kind>: <payload>` line to stderr and then returns
/// `Continue`, so attaching it never changes what the agent does.
#[derive(Debug, Clone, Copy, Default)]
pub struct TracingHook;
impl TracingHook {
/// Create a new `TracingHook` (a unit struct, so this is the same value as
/// `TracingHook::default()`).
pub fn new() -> Self {
Self
}
}
#[async_trait]
impl Hook for TracingHook {
/// Print the thought for `step` to stderr and let the loop continue.
async fn on_thought(&self, step: usize, thought: &str) -> HookAction {
eprintln!("[step {}] thought: {}", step, thought);
HookAction::Continue
}
/// Print the tool name and its JSON arguments to stderr; the tool is then
/// executed normally.
async fn on_tool_call(&self, step: usize, name: &str, args_json: &str) -> ToolCallAction {
eprintln!("[step {}] tool_call: {}({})", step, name, args_json);
ToolCallAction::Continue
}
/// Print the tool observation for `step` to stderr and let the loop continue.
async fn on_observation(&self, step: usize, observation: &str) -> HookAction {
eprintln!("[step {}] observation: {}", step, observation);
HookAction::Continue
}
/// Print the final answer for `step` to stderr and let the loop continue.
async fn on_answer(&self, step: usize, answer: &str) -> HookAction {
eprintln!("[step {}] answer: {}", step, answer);
HookAction::Continue
}
}
/// Hook trait for observing and controlling the ReAct agent loop.
///
/// Implement this trait to inject custom behavior at each step:
/// - Log thoughts, tool calls, observations, and final answers
/// - Filter or redact sensitive data
/// - Dynamically terminate the loop based on content
/// - Inject custom tool results (e.g., for testing or sandboxing)
///
/// All methods have default no-op implementations, so you only need to
/// override the ones you care about.
///
/// The trait is declared with `#[async_trait]`, so implementations that
/// override a method must carry the `#[async_trait]` attribute as well.
///
/// The hook is called inline during the agent loop. Keep hook
/// callbacks fast — avoid blocking I/O. For heavy work, spawn a task
/// and return immediately.
///
/// NOTE(review): the call sites live in the agent's run loop, outside this
/// file — confirm there how `_step` is numbered and how each returned action
/// is handled.
#[async_trait]
pub trait Hook: Send + Sync {
/// Called when the agent emits a thought/reasoning step.
///
/// Return `HookAction::Terminate` to stop the loop early.
/// The default implementation returns `HookAction::Continue`.
async fn on_thought(&self, _step: usize, _thought: &str) -> HookAction {
HookAction::Continue
}
/// Called just before a tool is executed.
///
/// `_name` is the tool name and `_args_json` its arguments serialized as JSON text.
///
/// Return `ToolCallAction::Skip(result)` to skip execution and inject `result` instead.
/// Return `ToolCallAction::Terminate` to stop the loop without executing the tool.
/// The default implementation returns `ToolCallAction::Continue`.
async fn on_tool_call(&self, _step: usize, _name: &str, _args_json: &str) -> ToolCallAction {
ToolCallAction::Continue
}
/// Called after a tool returns an observation.
///
/// The default implementation returns `HookAction::Continue`.
async fn on_observation(&self, _step: usize, _observation: &str) -> HookAction {
HookAction::Continue
}
/// Called when the agent produces a final answer.
///
/// The default implementation returns `HookAction::Continue`.
async fn on_answer(&self, _step: usize, _answer: &str) -> HookAction {
HookAction::Continue
}
}

View File

@ -1,413 +0,0 @@
//! ReAct (Reasoning + Acting) agent core.
use uuid::Uuid;
use std::sync::Arc;
use crate::call_with_params;
use crate::client::types::ChatRequestMessage;
use crate::error::{AgentError, Result};
use crate::react::hooks::{Hook, HookAction, NoopHook, ToolCallAction};
use crate::react::types::{Action, ReactConfig, ReactStep};
pub use crate::react::types::{ReactConfig as ReActConfig, ReactStep as ReActStep};
/// A ReAct agent that performs multi-step tool-augmented reasoning.
///
/// Cloning copies the transcript and step counter; the hook is shared
/// between clones through its `Arc`.
#[derive(Clone)]
pub struct ReactAgent {
    /// Conversation transcript sent to the model on every step. It starts
    /// with the system prompt and grows as tool calls and observations are
    /// appended.
    messages: Vec<ChatRequestMessage>,
    /// Tool definitions (JSON values) forwarded to the completion call.
    /// `run` reads this field, so no dead-code suppression is needed.
    tool_definitions: Vec<serde_json::Value>,
    /// Loop configuration (`max_steps`, optional `tool_executor`).
    config: ReactConfig,
    /// Number of reasoning steps executed so far.
    step_count: usize,
    /// Observer/controller invoked at each stage of the loop.
    hook: Arc<dyn Hook>,
}
impl ReactAgent {
/// Create a new agent with a system prompt and tool definitions (as JSON values).
///
/// The transcript is seeded with a single system message, the step counter
/// starts at zero and the hook defaults to `NoopHook`.
pub fn new(
system_prompt: &str,
tools: Vec<serde_json::Value>,
config: ReactConfig,
) -> Self {
let messages = vec![ChatRequestMessage::system(system_prompt)];
Self {
messages,
tool_definitions: tools,
config,
step_count: 0,
hook: Arc::new(NoopHook),
}
}
/// Add an initial user message to the conversation.
///
/// The message is appended to the end of the transcript.
pub fn add_user_message(&mut self, content: &str) {
self.messages.push(ChatRequestMessage::user(content));
}
/// Attach a hook to observe and control the agent loop.
///
/// Hooks can log steps, filter content, inject custom tool results,
/// or terminate the loop early. Multiple `.with_hook()` calls replace
/// the previous hook.
pub fn with_hook<H: Hook + 'static>(mut self, hook: H) -> Self {
self.hook = Arc::new(hook);
self
}
/// Run the ReAct loop until a final answer is produced or `max_steps` is reached.
///
/// Each iteration requests a completion, parses it into thought / action /
/// answer, reports every stage through `on_chunk`, and consults the hook.
///
/// Returns the final answer text. When the step budget is exhausted the
/// explanatory message is emitted as an `Answer` step and returned as `Ok`.
/// When the response contains neither an answer nor an action, the raw
/// response content is returned as the answer.
///
/// # Errors
///
/// Propagates any error from `call_with_params`, and returns
/// `AgentError::Internal` when a hook asks to terminate.
pub async fn run<C>(
&mut self,
model_name: &str,
client_config: &crate::client::AiClientConfig,
mut on_chunk: C,
) -> Result<String>
where
C: FnMut(ReactStep) + Send,
{
loop {
// Step budget exhausted: report it as a regular answer rather than an error.
if self.step_count >= self.config.max_steps {
let msg = format!(
"Agent reached maximum reasoning steps ({}) without producing a final answer.",
self.config.max_steps
);
on_chunk(ReactStep::Answer {
step: self.step_count,
answer: msg.clone(),
});
return Ok(msg);
}
// The counter is bumped before use, so the first step is reported as 1.
self.step_count += 1;
let step = self.step_count;
// For ReAct we force text-only responses so the model follows our JSON-in-text format.
let tool_choice_str = if self.tool_definitions.is_empty() {
None
} else {
Some("none")
};
let response = call_with_params(
&self.messages,
model_name,
client_config,
0.2, // temperature
4096, // max output tokens
None,
if self.tool_definitions.is_empty() {
None
} else {
Some(&self.tool_definitions)
},
tool_choice_str,
)
.await?;
let parsed = parse_react_response(&response.content);
let answer = parsed.answer.clone();
let action = parsed.action.clone();
// The thought is always emitted, even when an answer or action follows.
on_chunk(ReactStep::Thought {
step,
thought: parsed.thought.clone(),
});
match self.hook.on_thought(step, &parsed.thought).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at thought step: {}",
reason
)));
}
// `Skip` has no special meaning for thoughts; it behaves like `Continue`.
HookAction::Skip => {}
HookAction::Continue => {}
}
// Final answer — emit and return.
if let Some(ans) = answer {
on_chunk(ReactStep::Answer {
step,
answer: ans.clone(),
});
match self.hook.on_answer(step, &ans).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at answer step: {}",
reason
)));
}
_ => {}
}
return Ok(ans);
}
// No answer — either do a tool call or fall back.
let Some(act) = action else {
// Neither answer nor action: treat the raw response text as the answer.
let content = response.content.clone();
on_chunk(ReactStep::Answer {
step,
answer: content.clone(),
});
match self.hook.on_answer(step, &content).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at fallback answer: {}",
reason
)));
}
_ => {}
}
return Ok(content);
};
on_chunk(ReactStep::Action {
step,
action: act.clone(),
});
// The hook receives the arguments as a JSON string; "{}" is used if serialization fails.
let args_json = serde_json::to_string(&act.args).unwrap_or_else(|_| "{}".to_string());
match self.hook.on_tool_call(step, &act.name, &args_json).await {
ToolCallAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at tool call: {}",
reason
)));
}
ToolCallAction::Skip(injected_result) => {
// The hook supplied the observation; the tool itself is not executed.
let observation = injected_result;
on_chunk(ReactStep::Observation {
step,
observation: observation.clone(),
});
match self.hook.on_observation(step, &observation).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at observation (injected): {}",
reason
)));
}
_ => {}
}
// Append assistant message with tool_calls.
let assistant_msg = build_tool_call_message(&act);
self.messages.push(assistant_msg);
// Append observation as a tool message.
self.messages.push(ChatRequestMessage::tool(&act.id, observation));
continue;
}
ToolCallAction::Continue => {}
}
// Append the assistant message with tool_calls.
let assistant_msg = build_tool_call_message(&act);
self.messages.push(assistant_msg);
// Execute the tool. Executor failures (or a missing executor) are turned into a
// JSON `{"error": ...}` observation instead of aborting the loop.
let observation = match &self.config.tool_executor {
Some(exec) => {
let result = exec(act.name.clone(), act.args.clone()).await;
match result {
Ok(v) => serde_json::to_string(&v).unwrap_or_else(|_| "null".to_string()),
Err(e) => serde_json::json!({ "error": e }).to_string(),
}
}
None => serde_json::json!({
"error": format!("no tool executor registered for '{}'", act.name)
})
.to_string(),
};
on_chunk(ReactStep::Observation {
step,
observation: observation.clone(),
});
match self.hook.on_observation(step, &observation).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at observation step: {}",
reason
)));
}
_ => {}
}
// Append observation as a tool message.
self.messages.push(ChatRequestMessage::tool(&act.id, observation));
}
}
/// Returns the number of steps executed so far.
pub fn steps(&self) -> usize {
self.step_count
}
}
// ---------------------------------------------------------------------------
// Response parsing
// ---------------------------------------------------------------------------
/// Result of parsing one model response in the ReAct JSON-in-text format.
struct ParsedReActResponse {
// Reasoning text; the parser substitutes a placeholder or the raw content when absent.
thought: String,
// Tool invocation requested by the model, if any.
action: Option<Action>,
// Final answer, if the model produced one.
answer: Option<String>,
}
/// Parse a raw model response into thought / action / answer parts.
///
/// The JSON payload is located with `extract_json`; when nothing is found the
/// trimmed content itself is tried. An action may be given either nested under
/// an `action` key or flattened as top-level `name` / `arguments` keys; the
/// nested form wins. Every parsed action gets a fresh UUID as its id.
///
/// When the payload is not valid JSON for the expected shape, the whole
/// `content` is returned as the thought, with no action and no answer.
fn parse_react_response(content: &str) -> ParsedReActResponse {
    /// Tool invocation nested under the `action` key.
    #[derive(serde::Deserialize)]
    struct RawAction {
        #[serde(default)]
        name: Option<String>,
        #[serde(default, rename = "arguments")]
        args: Option<serde_json::Value>,
    }

    /// Top-level response object; every key is optional.
    #[derive(serde::Deserialize)]
    struct RawStep {
        #[serde(default)]
        thought: Option<String>,
        #[serde(default)]
        action: Option<RawAction>,
        #[serde(default)]
        answer: Option<String>,
        #[serde(default)]
        name: Option<String>,
        #[serde(default, rename = "arguments")]
        args: Option<serde_json::Value>,
    }

    let payload = match extract_json(content) {
        Some(json) => json,
        None => content.trim().to_string(),
    };

    let Ok(step) = serde_json::from_str::<RawStep>(&payload) else {
        return ParsedReActResponse {
            thought: content.to_string(),
            action: None,
            answer: None,
        };
    };

    let RawStep {
        thought,
        action: nested,
        answer,
        name: flat_name,
        args: flat_args,
    } = step;

    let action = match nested {
        // Preferred shape: `{ "action": { "name": ..., "arguments": ... } }`.
        Some(inner) => Some(Action {
            id: Uuid::new_v4().to_string(),
            name: inner.name.unwrap_or_default(),
            args: inner.args.unwrap_or(serde_json::Value::Null),
        }),
        // Flattened shape: `name` and/or `arguments` at the top level.
        None if flat_name.is_some() || flat_args.is_some() => Some(Action {
            id: Uuid::new_v4().to_string(),
            name: flat_name.unwrap_or_default(),
            args: flat_args.unwrap_or(serde_json::Value::Null),
        }),
        None => None,
    };

    ParsedReActResponse {
        thought: thought.unwrap_or_else(|| "Thinking...".to_string()),
        action,
        answer,
    }
}
/// Locate a JSON payload inside free-form model output.
///
/// Strategies are tried in order:
/// 1. The trimmed text already starts with `{` or `[`: return it unchanged
///    (it is not validated here).
/// 2. The text contains a fenced block (a line starting with "```json" or a
///    bare "```" line): return the non-empty body between the first opening
///    fence and the next closing fence.
/// 3. Scan for a `{` / `[` that is not directly preceded by an identifier
///    character or a quote, and return the first suffix, or balanced prefix of
///    that suffix, that parses as JSON.
///
/// Returns `None` when no strategy yields a payload.
fn extract_json(s: &str) -> Option<String> {
    let trimmed = s.trim();
    if trimmed.starts_with('{') || trimmed.starts_with('[') {
        return Some(trimmed.to_string());
    }

    // The fenced body does not depend on which fence line triggered the scan,
    // so it is computed once instead of once per fence line.
    let has_fence = trimmed
        .lines()
        .map(str::trim)
        .any(|line| line.starts_with("```json") || line == "```");
    if has_fence {
        let mut body = String::new();
        let mut inside = false;
        for line in trimmed.lines().map(str::trim) {
            if !inside {
                if line == "```json" || line == "```" {
                    inside = true;
                }
                continue;
            }
            if line == "```" {
                break;
            }
            body.push_str(line);
            body.push('\n');
        }
        let body = body.trim();
        if !body.is_empty() {
            return Some(body.to_string());
        }
    }

    // Slice the original text at each candidate bracket instead of rebuilding
    // a fresh String from a char vector for every candidate.
    let mut previous: Option<char> = None;
    for (start, c) in trimmed.char_indices() {
        let before = previous.replace(c);
        if c != '{' && c != '[' {
            continue;
        }
        // A bracket at the very start was already handled by strategy 1.
        let Some(before) = before else { continue };
        if before.is_alphanumeric() || before == '_' || before == '"' || before == '\'' {
            continue;
        }

        let candidate = &trimmed[start..];
        if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
            return Some(candidate.trim_end().to_string());
        }
        if let Some(end) = balanced_prefix_len(candidate) {
            let prefix = &candidate[..end];
            if serde_json::from_str::<serde_json::Value>(prefix).is_ok() {
                return Some(prefix.to_string());
            }
        }
    }
    None
}

/// Byte length of the shortest prefix of `candidate` whose brackets balance,
/// ignoring brackets inside double-quoted strings.
///
/// Only the first balance point is reported: if that prefix is not valid JSON,
/// no longer prefix can be, because a valid JSON value closes at the first
/// point where its brackets balance.
fn balanced_prefix_len(candidate: &str) -> Option<usize> {
    let mut depth = 0isize;
    let mut in_string = false;
    let mut escaped = false;
    for (offset, c) in candidate.char_indices() {
        if escaped {
            escaped = false;
            continue;
        }
        match c {
            '\\' => {
                escaped = true;
                continue;
            }
            '"' => {
                in_string = !in_string;
                continue;
            }
            _ if in_string => continue,
            '{' | '[' => depth += 1,
            '}' | ']' => depth -= 1,
            _ => {}
        }
        if depth == 0 {
            return Some(offset + c.len_utf8());
        }
    }
    None
}
/// Build the assistant message that records a tool invocation for `action`.
///
/// The message carries a single `function` tool call whose id and name come
/// from the action and whose arguments are the action's args serialized as
/// JSON ("{}" if serialization fails). The text content is `Action: <name>`.
fn build_tool_call_message(action: &Action) -> ChatRequestMessage {
    let arguments = match serde_json::to_string(&action.args) {
        Ok(serialized) => serialized,
        Err(_) => "{}".to_string(),
    };

    let call = crate::client::types::ToolCall {
        id: action.id.clone(),
        type_: "function".into(),
        function: crate::client::types::ToolCallFunction {
            name: action.name.clone(),
            arguments,
        },
    };

    ChatRequestMessage {
        role: "assistant".into(),
        content: Some(format!("Action: {}", action.name)),
        name: None,
        tool_call_id: None,
        tool_calls: Some(vec![call]),
    }
}

View File

@ -1,18 +1,13 @@
//! ReAct (Reason + Act) agent loop for structured tool use.
//! ReAct (Reason + Act) agent types.
//!
//! The agent alternates between a **thought** phase (reasoning about what to do)
//! and an **action** phase (calling tools). Observations from tool results feed
//! back into the next thought, enabling multi-step reasoning.
//! Provides the step types used by the ReAct callback interface.
//! The actual agent loop is handled by rig's built-in Agent.
pub mod hooks;
pub mod loop_core;
pub mod types;
pub use hooks::{Hook, HookAction, NoopHook, ToolCallAction, TracingHook};
pub use loop_core::ReactAgent;
pub use types::{ReactConfig, ReactStep};
/// Default system prompt for the ReAct agent.
/// Default system prompt for the ReAct agent (used with rig's native tool-calling).
///
/// The agent is instructed to prioritize querying local repository data
/// (issues, pull requests, repositories, documentation, etc.) before
@ -25,26 +20,6 @@ Always query the platform's local data before guessing or referring to external
If local data does not contain the answer, state that clearly before considering external information.
## Response Format
Respond as JSON:
1. When you need to look up data:
```json
{
"thought": "What you need to find and why.",
"action": { "name": "tool_name", "arguments": { ... } }
}
```
2. When you have enough information to answer:
```json
{
"thought": "How you arrived at the answer.",
"answer": "Your final answer."
}
```
## Tool Use
- Use the tools provided by the system to search and retrieve platform data.

131
libs/agent/tool/recorder.rs Normal file
View File

@ -0,0 +1,131 @@
//! Batch tool call recorder — persists tool call records to `ai_tool_call` table.
//!
//! Uses an mpsc channel + background flush loop to batch-insert records,
//! reducing DB pressure from individual inserts.
//!
//! Flush triggers:
//! - Buffer reaches `BATCH_SIZE` (default 50)
//! - `FLUSH_INTERVAL` (default 5s) elapses with non-empty buffer
//! - Sender is dropped (remaining records flushed on channel close)
use std::time::Duration;
use db::database::AppDatabase;
use models::ai::ai_tool_call;
use models::ai::ToolCallStatus;
use sea_orm::*;
use tokio::sync::mpsc;
use uuid::Uuid;
/// Period of the background ticker that flushes a non-empty buffer.
const FLUSH_INTERVAL: Duration = Duration::from_secs(5);
/// Buffer length at which records are flushed immediately; also the buffer's initial capacity.
const BATCH_SIZE: usize = 50;
/// A single tool call record to be persisted.
///
/// Each field is written to the `ai_tool_call` column of the same name,
/// except `session_id`, which is stored in the `session` column.
#[derive(Debug, Clone)]
pub struct ToolCallRecord {
/// Identifier of the tool call.
pub tool_call_id: String,
/// Session this call belongs to.
pub session_id: Uuid,
/// Name of the invoked tool.
pub tool_name: String,
/// ID of the caller that triggered the tool.
pub caller: Uuid,
/// Arguments passed to the tool, as JSON.
pub arguments: serde_json::Value,
/// Outcome of the call; persisted via its string form.
pub status: ToolCallStatus,
/// Execution time in milliseconds, if measured.
pub execution_time_ms: Option<i64>,
/// Error message for failed calls.
pub error_message: Option<String>,
/// Error stack/trace for failed calls, if available.
pub error_stack: Option<String>,
/// Number of retries performed for this call.
pub retry_count: i32,
}
/// Channel-based batched recorder. Cheap to clone — all clones share the same sender.
#[derive(Clone)]
pub struct ToolCallRecorder {
/// Sending half of the unbounded channel drained by the background flush loop.
tx: mpsc::UnboundedSender<ToolCallRecord>,
/// Session ID exposed through `session_id()`.
session_id: Uuid,
}
impl ToolCallRecorder {
    /// Build a recorder bound to a freshly generated session ID.
    ///
    /// Delegates to [`ToolCallRecorder::with_session`], which spawns the
    /// background flush loop.
    pub fn new(db: AppDatabase) -> Self {
        let session_id = Uuid::new_v4();
        Self::with_session(db, session_id)
    }

    /// Build a recorder bound to `session_id`, so persisted tool calls can be
    /// linked to an `AiSession`.
    ///
    /// Spawns the background flush loop on the current Tokio runtime.
    pub fn with_session(db: AppDatabase, session_id: Uuid) -> Self {
        let (sender, receiver) = mpsc::unbounded_channel();
        tokio::spawn(flush_loop(db, receiver));
        Self {
            tx: sender,
            session_id,
        }
    }

    /// Session ID shared by every tool call recorded through this instance.
    pub fn session_id(&self) -> Uuid {
        self.session_id
    }

    /// Queue `record` for batched persistence.
    ///
    /// Best effort: a send failure (receiver gone) is ignored.
    pub fn record(&self, record: ToolCallRecord) {
        let _ = self.tx.send(record);
    }
}
async fn flush_loop(db: AppDatabase, mut rx: mpsc::UnboundedReceiver<ToolCallRecord>) {
let mut buffer = Vec::with_capacity(BATCH_SIZE);
let mut ticker = tokio::time::interval(FLUSH_INTERVAL);
ticker.tick().await; // skip first immediate tick
loop {
tokio::select! {
Some(record) = rx.recv() => {
buffer.push(record);
if buffer.len() >= BATCH_SIZE {
flush(&db, &mut buffer).await;
}
}
_ = ticker.tick() => {
if !buffer.is_empty() {
flush(&db, &mut buffer).await;
}
}
else => {
// Channel closed — flush remaining and exit
if !buffer.is_empty() {
flush(&db, &mut buffer).await;
}
break;
}
}
}
}
/// Persist every buffered record with a single `insert_many`, leaving `buffer` empty.
///
/// All rows of a batch share one timestamp for `created_at`, `completed_at`
/// and `updated_at`; `result` is always stored as JSON null. A failed insert
/// is logged and the batch is discarded (best effort, no retry).
async fn flush(db: &AppDatabase, buffer: &mut Vec<ToolCallRecord>) {
    // Nothing to persist; avoid issuing an insert with no rows.
    if buffer.is_empty() {
        return;
    }

    let now = chrono::Utc::now();
    // Drain instead of clone: the buffer is emptied either way, so the
    // records can be moved into the active models.
    let models: Vec<ai_tool_call::ActiveModel> = buffer
        .drain(..)
        .map(|r| ai_tool_call::ActiveModel {
            tool_call_id: Set(r.tool_call_id),
            session: Set(r.session_id),
            tool_name: Set(r.tool_name),
            caller: Set(r.caller),
            arguments: Set(r.arguments),
            result: Set(serde_json::Value::Null),
            status: Set(r.status.to_string()),
            execution_time_ms: Set(r.execution_time_ms),
            error_message: Set(r.error_message),
            error_stack: Set(r.error_stack),
            retry_count: Set(r.retry_count),
            created_at: Set(now),
            completed_at: Set(Some(now)),
            updated_at: Set(now),
        })
        .collect();

    let count = models.len();
    if let Err(e) = ai_tool_call::Entity::insert_many(models).exec(db).await {
        tracing::warn!(error = %e, count, "failed_to_flush_tool_call_records");
    }
}

View File

@ -4,6 +4,7 @@
//! to implement rig's ToolDyn trait, enabling integration with rig's Agent.
use std::collections::HashMap;
use std::time::{Duration, Instant};
use futures::FutureExt;
use rig::completion::ToolDefinition;
@ -11,8 +12,146 @@ use rig::tool::{ToolDyn, ToolError, ToolSet};
use super::context::ToolContext;
use super::definition::ToolDefinition as AgentToolDefinition;
use super::recorder::{ToolCallRecord, ToolCallRecorder};
use super::registry::{ToolHandler, ToolRegistry};
/// Returns true if the tool error message indicates a transient failure that can be retried.
///
/// Matching is case-insensitive. A message is retryable when it contains one
/// of the known transient phrases (retry, timeout, rate limit, ...) or
/// mentions an HTTP 5xx status code as a standalone token.
///
/// A lone digit `5` elsewhere in the message (counts, line numbers, durations)
/// does not make an error retryable.
pub fn is_retryable_tool_error(msg: &str) -> bool {
    const TRANSIENT_PHRASES: [&str; 7] = [
        "retry",
        "timeout",
        "rate limit",
        "too many requests",
        "unavailable",
        "connection refused",
        "try again",
    ];

    let lower = msg.to_lowercase();
    TRANSIENT_PHRASES.iter().any(|phrase| lower.contains(phrase))
        || mentions_server_error_status(&lower)
}

/// True when `lower` (already lower-cased) contains a standalone `5xx` token:
/// either three digits starting with `5` (500–599) or the literal `5xx`.
fn mentions_server_error_status(lower: &str) -> bool {
    lower
        .split(|c: char| !c.is_ascii_alphanumeric())
        .any(|token| {
            token == "5xx"
                || (token.len() == 3
                    && token.starts_with('5')
                    && token.bytes().all(|b| b.is_ascii_digit()))
        })
}
/// Wraps a ToolDyn with automatic retry and tool call recording.
///
/// Used by the rig Agent path to replace the custom ReAct executor closure.
pub struct RecordingTool {
/// The wrapped tool; `name`, `definition` and `call` delegate to it.
inner: Box<dyn ToolDyn>,
/// Database handle passed to the tool call recorder.
db: db::database::AppDatabase,
/// Session ID stamped on every recorded tool call.
session_id: uuid::Uuid,
/// Caller ID stamped on every recorded tool call.
caller: uuid::Uuid,
}
impl RecordingTool {
pub fn new(
inner: Box<dyn ToolDyn>,
db: db::database::AppDatabase,
session_id: uuid::Uuid,
caller: uuid::Uuid,
) -> Self {
Self { inner, db, session_id, caller }
}
}
impl ToolDyn for RecordingTool {
    /// Name of the wrapped tool.
    fn name(&self) -> String {
        self.inner.name()
    }

    /// Definition of the wrapped tool, forwarded unchanged.
    fn definition<'a>(
        &'a self,
        prompt: String,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolDefinition> + Send + 'a>> {
        self.inner.definition(prompt)
    }

    /// Call the wrapped tool, retrying transient failures and recording the outcome.
    ///
    /// Up to three retries are made when the error message is classified as
    /// transient by [`is_retryable_tool_error`], with exponential backoff
    /// (100 ms, 200 ms, 400 ms). Exactly one record is written per call,
    /// for the final attempt: `Success` with its duration, or `Failed` with
    /// the last error message. The wrapped tool's own error is returned
    /// unchanged on failure.
    fn call<'a>(
        &'a self,
        args: String,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String, ToolError>> + Send + 'a>>
    {
        const MAX_RETRIES: u32 = 3;
        const BASE_BACKOFF_MS: u64 = 100;

        let inner: &'a Box<dyn ToolDyn> = &self.inner;
        let db = self.db.clone();
        let session_id = self.session_id;
        let caller = self.caller;
        let tool_name = inner.name();

        Box::pin(async move {
            // The recorder (and its background flush task) lives for this call only;
            // dropping it when the future completes closes the channel.
            let recorder = ToolCallRecorder::with_session(db, session_id);

            // Arguments do not change between attempts, so parse them once.
            // Invalid JSON is recorded as JSON null.
            let args_json: serde_json::Value = serde_json::from_str(&args).unwrap_or_default();

            let mut attempt: u32 = 0;
            loop {
                let attempt_start = Instant::now();
                let outcome = inner.call(args.clone()).await;
                // Duration of this attempt only (backoff sleeps are excluded).
                let elapsed_ms =
                    i64::try_from(attempt_start.elapsed().as_millis()).unwrap_or(i64::MAX);
                let error_message = outcome.as_ref().err().map(|e| e.to_string());

                if let Some(msg) = error_message.as_deref() {
                    if attempt < MAX_RETRIES && is_retryable_tool_error(msg) {
                        let backoff_ms = BASE_BACKOFF_MS.saturating_mul(1u64 << attempt);
                        tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
                        attempt += 1;
                        continue;
                    }
                }

                // Terminal outcome: success, a non-retryable error, or retries exhausted.
                let status = if error_message.is_some() {
                    models::ai::ToolCallStatus::Failed
                } else {
                    models::ai::ToolCallStatus::Success
                };
                recorder.record(ToolCallRecord {
                    // NOTE(review): the tool name doubles as the call id, so ids are
                    // not unique per call. Confirm this matches the `ai_tool_call` schema.
                    tool_call_id: tool_name.clone(),
                    session_id,
                    tool_name,
                    caller,
                    arguments: args_json,
                    status,
                    execution_time_ms: Some(elapsed_ms),
                    error_message,
                    error_stack: None,
                    retry_count: i32::try_from(attempt).unwrap_or(i32::MAX),
                });
                return outcome;
            }
        })
    }
}
/// A wrapper that converts our ToolRegistry to rig's ToolSet.
pub struct RigToolSet {
/// The rig ToolSet
@ -30,6 +169,7 @@ impl RigToolSet {
config: config::AppConfig,
room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid,
) -> Self {
let mut toolset = ToolSet::default();
let mut definitions = HashMap::new();
@ -50,6 +190,7 @@ impl RigToolSet {
config: config.clone(),
room_id,
sender_id,
project_id,
};
toolset.add_tool(adapter);
}
@ -85,6 +226,23 @@ pub struct RigToolAdapter {
config: config::AppConfig,
room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid,
}
impl RigToolAdapter {
/// Create a new RigToolAdapter with all required context.
pub fn new(
handler: ToolHandler,
definition: AgentToolDefinition,
db: db::database::AppDatabase,
cache: db::cache::AppCache,
config: config::AppConfig,
room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid,
) -> Self {
Self { handler, definition, db, cache, config, room_id, sender_id, project_id }
}
}
impl ToolDyn for RigToolAdapter {
@ -113,6 +271,7 @@ impl ToolDyn for RigToolAdapter {
let config = self.config.clone();
let room_id = self.room_id;
let sender_id = self.sender_id;
let project_id = self.project_id;
async move {
let ctx = ToolContext::new(
@ -121,7 +280,8 @@ impl ToolDyn for RigToolAdapter {
config,
room_id,
sender_id,
);
)
.with_project(project_id);
let args_json: serde_json::Value = serde_json::from_str(&args)
.map_err(|e| ToolError::JsonError(e))?;

View File

@ -272,6 +272,7 @@ pub async fn ws_universal(
"data": {
"message_id": chunk.message_id,
"room_id": chunk.room_id,
"seq": chunk.seq,
"content": chunk.content,
"done": chunk.done,
"error": chunk.error,

View File

@ -110,6 +110,9 @@ pub struct ProjectRoomEvent {
pub struct RoomMessageStreamChunkEvent {
pub message_id: Uuid,
pub room_id: Uuid,
/// Monotonically increasing sequence number for ordering within this stream.
#[serde(default)]
pub seq: u64,
pub content: String,
pub done: bool,
pub error: Option<String>,

View File

@ -54,6 +54,7 @@ pub async fn process_message_ai_react_streaming(
let answer_buffer: std::sync::Arc<std::sync::Mutex<String>> =
std::sync::Arc::new(std::sync::Mutex::new(String::new()));
let step_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
let chunk_seq: std::sync::Arc<std::sync::atomic::AtomicU64> = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1));
// Helper: recover from poison instead of panicking.
fn lock_or_recover<T>(mutex: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> {
@ -65,6 +66,7 @@ pub async fn process_message_ai_react_streaming(
let streaming_msg_id = streaming_msg_id;
let room_id = room_id_inner;
let step_count = step_count.clone();
let chunk_seq = chunk_seq.clone();
let ai_display_name_for_step = std::sync::Arc::new(ai_display_name.clone());
let steps = steps.clone();
let answer_buffer = answer_buffer.clone();
@ -73,18 +75,20 @@ pub async fn process_message_ai_react_streaming(
let room_manager = room_manager.clone();
let (chunk_type, content) = match &step {
ReactStep::Thought { step: _, thought } => {
("thinking".to_string(), format!("[Thinking] {}", thought))
("thinking".to_string(), thought.clone())
}
ReactStep::Action { step: _, action } => {
*lock_or_recover(&last_action_name) = action.name.clone();
("tool_call".to_string(), format!("[Action] Calling `{}` with {:?}", action.name, action.args))
("tool_call".to_string(), serde_json::json!({
"name": action.name,
"arguments": action.args,
}).to_string())
}
ReactStep::Observation {
step: _,
observation: _,
observation,
} => {
let action_name = lock_or_recover(&last_action_name).clone();
("tool_call".to_string(), format!("[Observation] {} (completed)", action_name))
("tool_result".to_string(), observation.clone())
}
ReactStep::Answer { step: _, answer } => {
("answer".to_string(), answer.clone())
@ -96,22 +100,33 @@ pub async fn process_message_ai_react_streaming(
step_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
// Record ordered step for storage
// Record ordered step for storage — merge consecutive same-type chunks
// to ensure strict think→answer→think→answer alternation.
{
let mut s = lock_or_recover(&steps);
s.push((chunk_type.clone(), content.clone()));
if let Some(last) = s.last_mut() {
if last.0 == chunk_type {
last.1.push_str(&content);
} else {
s.push((chunk_type.clone(), content.clone()));
}
} else {
s.push((chunk_type.clone(), content.clone()));
}
}
if is_answer {
let mut ab = lock_or_recover(&answer_buffer);
ab.push_str(&content);
}
let done = is_answer;
let done = false;
let ai_name = ai_display_name_for_step.clone();
let current_seq = chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tokio::spawn(async move {
let event = RoomMessageStreamChunkEvent {
message_id: streaming_msg_id,
room_id,
seq: current_seq,
content: content.clone(),
done,
error: None,
@ -125,6 +140,21 @@ pub async fn process_message_ai_react_streaming(
let result = chat_service.process_react(&request, on_step).await;
// Broadcast final done=true event to close the streaming channel on frontend.
let final_stream_content = lock_or_recover(&answer_buffer).clone();
room_manager
.broadcast_stream_chunk(RoomMessageStreamChunkEvent {
message_id: streaming_msg_id,
room_id: room_id_inner,
seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
content: final_stream_content.clone(),
done: true,
error: None,
display_name: Some(ai_display_name.clone()),
chunk_type: Some("answer".to_string()),
})
.await;
let final_content = lock_or_recover(&answer_buffer).clone();
let all_steps = lock_or_recover(&steps).clone();
let reasoning_chain: String = all_steps
@ -172,7 +202,7 @@ pub async fn process_message_ai_react_streaming(
}
// Serialize ordered steps as JSON for ordered replay.
let thinking_content = {
let thinking_content_serialized = {
let steps = lock_or_recover(&steps);
if steps.is_empty() {
None
@ -186,6 +216,7 @@ pub async fn process_message_ai_react_streaming(
Some(chunks_json.to_string())
}
};
let thinking_content_for_event = thinking_content_serialized.clone();
let envelope = RoomMessageEnvelope {
id: streaming_msg_id,
@ -197,7 +228,7 @@ pub async fn process_message_ai_react_streaming(
thread_id: None,
content: persist_content.clone(),
content_type: "text".to_string(),
thinking_content,
thinking_content: thinking_content_serialized,
send_at: now,
seq,
in_reply_to: None,
@ -244,7 +275,7 @@ pub async fn process_message_ai_react_streaming(
thread_id: None,
content: persist_content,
content_type: "text".to_string(),
thinking_content: None,
thinking_content: thinking_content_for_event,
send_at: now,
seq,
display_name: Some(ai_display_name.clone()),