refactor(agent): replace custom ReAct loop with rig::agent::Agent
- Use AgentBuilder for native tool-calling with stream_prompt() - Add RecordingTool wrapper preserving retry + DB recording - Fix tool_choice bug in do_completion (same as call_stream_once) - Add seq field to RoomMessageStreamChunkEvent for strict ordering - Map streaming events: Text→Answer, Reasoning→Thought, ToolCall→Action - Only final event has done=true, removed premature stream ending - Store __chunks__ JSON in thinking_content for ordered replay
This commit is contained in:
parent
2bd40aee1b
commit
5b3a6700be
@ -1,22 +1,31 @@
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use futures::StreamExt;
|
||||
use models::projects::project_skill;
|
||||
use models::rooms::room_ai;
|
||||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||||
use rig::agent::{AgentBuilder, MultiTurnStreamItem};
|
||||
use rig::client::CompletionClient;
|
||||
use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
|
||||
use sea_orm::*;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::context::RoomMessageContext;
|
||||
use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback};
|
||||
use crate::client::types::{ChatRequestMessage, ToolCall};
|
||||
use crate::client::AiClientConfig;
|
||||
use crate::client::{call_stream, call_with_params, StreamChunk, StreamChunkType, StreamedToolCall};
|
||||
use crate::client::types::{ChatRequestMessage, ToolCall};
|
||||
use crate::client::{
|
||||
StreamChunk, StreamChunkType, StreamedToolCall, call_stream, call_with_params,
|
||||
};
|
||||
use crate::compact::{CompactConfig, CompactService};
|
||||
use crate::embed::EmbedService;
|
||||
use crate::error::{AgentError, Result};
|
||||
use crate::perception::{PerceptionService, SkillEntry, ToolCallEvent};
|
||||
use crate::react::{ReactAgent, ReactConfig, DEFAULT_SYSTEM_PROMPT};
|
||||
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor, ToolResult, registry::ToolRegistry};
|
||||
use crate::react::{DEFAULT_SYSTEM_PROMPT, ReactStep};
|
||||
use crate::react::types::Action as ReactAction;
|
||||
use crate::tool::{
|
||||
RecordingTool, ToolCall as AgentToolCall, ToolContext, ToolExecutor,
|
||||
registry::ToolRegistry,
|
||||
};
|
||||
|
||||
/// Result from streaming AI response.
|
||||
pub struct StreamResult {
|
||||
@ -104,9 +113,12 @@ impl ChatService {
|
||||
config: config::AppConfig,
|
||||
room_id: uuid::Uuid,
|
||||
sender_id: Option<uuid::Uuid>,
|
||||
project_id: uuid::Uuid,
|
||||
) -> Option<crate::RigToolSet> {
|
||||
self.tool_registry.as_ref().map(|registry| {
|
||||
crate::RigToolSet::from_registry(registry, db, cache, config, room_id, sender_id)
|
||||
crate::RigToolSet::from_registry(
|
||||
registry, db, cache, config, room_id, sender_id, project_id,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@ -140,11 +152,16 @@ impl ChatService {
|
||||
let mut tool_depth = 0;
|
||||
let mut input_tokens = 0i64;
|
||||
let mut output_tokens = 0i64;
|
||||
let session_id = Uuid::new_v4();
|
||||
let session_start = std::time::Instant::now();
|
||||
let version_id = room_ai.as_ref().and_then(|r| r.version);
|
||||
|
||||
let config = AiClientConfig::new(
|
||||
self.ai_api_key.clone().unwrap_or_default(),
|
||||
)
|
||||
.with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into()));
|
||||
let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
|
||||
.with_base_url(
|
||||
self.ai_base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "https://api.openai.com".into()),
|
||||
);
|
||||
|
||||
loop {
|
||||
let response = call_with_params(
|
||||
@ -183,9 +200,10 @@ impl ChatService {
|
||||
})
|
||||
.collect();
|
||||
|
||||
messages.push(
|
||||
ChatRequestMessage::assistant(Some(text.clone()), Some(tool_call_messages.clone()))
|
||||
);
|
||||
messages.push(ChatRequestMessage::assistant(
|
||||
Some(text.clone()),
|
||||
Some(tool_call_messages.clone()),
|
||||
));
|
||||
|
||||
// Create ToolCall list for executor (we need real IDs and args)
|
||||
// Since we can't get args from streaming, use name matching from the text
|
||||
@ -210,15 +228,69 @@ impl ChatService {
|
||||
if let Some(ref registry) = self.tool_registry {
|
||||
ctx.registry_mut().merge(registry.clone());
|
||||
}
|
||||
|
||||
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
|
||||
request.db.clone(),
|
||||
session_id,
|
||||
);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let executor = ToolExecutor::new();
|
||||
match executor.execute_batch(calls, &mut ctx).await {
|
||||
Ok(results) => ToolExecutor::to_tool_messages(&results),
|
||||
Ok(results) => {
|
||||
for (call, result) in
|
||||
response.tool_calls_finished.iter().zip(results.iter())
|
||||
{
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
let is_error =
|
||||
matches!(result.result, crate::tool::ToolResult::Error(_));
|
||||
let error_msg = match &result.result {
|
||||
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
|
||||
_ => None,
|
||||
};
|
||||
recorder.record(crate::tool::recorder::ToolCallRecord {
|
||||
tool_call_id: call.clone(),
|
||||
session_id: recorder.session_id(),
|
||||
tool_name: call.clone(),
|
||||
caller: request.sender.uid,
|
||||
arguments: serde_json::Value::Null,
|
||||
status: if is_error {
|
||||
models::ai::ToolCallStatus::Failed
|
||||
} else {
|
||||
models::ai::ToolCallStatus::Success
|
||||
},
|
||||
execution_time_ms: Some(elapsed),
|
||||
error_message: error_msg,
|
||||
error_stack: None,
|
||||
retry_count: 0,
|
||||
});
|
||||
}
|
||||
crate::tool::ToolExecutor::to_tool_messages(&results)
|
||||
}
|
||||
Err(e) => {
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
for call_name in &response.tool_calls_finished {
|
||||
recorder.record(crate::tool::recorder::ToolCallRecord {
|
||||
tool_call_id: Uuid::new_v4().to_string(),
|
||||
session_id: recorder.session_id(),
|
||||
tool_name: call_name.clone(),
|
||||
caller: request.sender.uid,
|
||||
arguments: serde_json::Value::Null,
|
||||
status: models::ai::ToolCallStatus::Failed,
|
||||
execution_time_ms: Some(elapsed),
|
||||
error_message: Some(e.to_string()),
|
||||
error_stack: None,
|
||||
retry_count: 0,
|
||||
});
|
||||
}
|
||||
|
||||
let err_msg = format!("[Tool call failed: {}]", e);
|
||||
response
|
||||
.tool_calls_finished
|
||||
.iter()
|
||||
.map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg))
|
||||
.map(|_| {
|
||||
ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
@ -250,8 +322,10 @@ impl ChatService {
|
||||
})
|
||||
.collect();
|
||||
for event in &tool_events {
|
||||
if let Some(ctx) =
|
||||
self.perception_service.passive.detect(event, &skill_entries)
|
||||
if let Some(ctx) = self
|
||||
.perception_service
|
||||
.passive
|
||||
.detect(event, &skill_entries)
|
||||
{
|
||||
messages.push(ctx.to_system_message());
|
||||
}
|
||||
@ -268,16 +342,62 @@ impl ChatService {
|
||||
} else {
|
||||
text
|
||||
};
|
||||
return Ok(ProcessResult { content, input_tokens, output_tokens });
|
||||
// Record session
|
||||
let _ = models::ai::ai_session::ActiveModel {
|
||||
id: Set(session_id),
|
||||
room: Set(request.room.id),
|
||||
model: Set(request.model.id),
|
||||
version: Set(version_id.unwrap_or_default()),
|
||||
token_input: Set(input_tokens),
|
||||
token_output: Set(output_tokens),
|
||||
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
|
||||
cost: Set(None),
|
||||
currency: Set(None),
|
||||
error_message: Set(None),
|
||||
error_code: Set(None),
|
||||
created_at: Set(chrono::Utc::now()),
|
||||
}
|
||||
.insert(&request.db)
|
||||
.await;
|
||||
return Ok(ProcessResult {
|
||||
content,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
return Ok(ProcessResult { content: text, input_tokens, output_tokens });
|
||||
// Record session
|
||||
let _ = models::ai::ai_session::ActiveModel {
|
||||
id: Set(session_id),
|
||||
room: Set(request.room.id),
|
||||
model: Set(request.model.id),
|
||||
version: Set(version_id.unwrap_or_default()),
|
||||
token_input: Set(input_tokens),
|
||||
token_output: Set(output_tokens),
|
||||
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
|
||||
cost: Set(None),
|
||||
currency: Set(None),
|
||||
error_message: Set(None),
|
||||
error_code: Set(None),
|
||||
created_at: Set(chrono::Utc::now()),
|
||||
}
|
||||
.insert(&request.db)
|
||||
.await;
|
||||
return Ok(ProcessResult {
|
||||
content: text,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn process_stream(&self, request: AiRequest, on_chunk: StreamCallback) -> Result<StreamResult> {
|
||||
pub async fn process_stream(
|
||||
&self,
|
||||
request: AiRequest,
|
||||
on_chunk: StreamCallback,
|
||||
) -> Result<StreamResult> {
|
||||
// Wrap on_chunk in Arc so it can be shared across loop iterations
|
||||
let on_chunk = Arc::new(on_chunk);
|
||||
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
|
||||
@ -302,11 +422,19 @@ impl ChatService {
|
||||
.and_then(|r| r.max_tokens.map(|v| v as u32))
|
||||
.unwrap_or(request.max_tokens as u32);
|
||||
let mut tool_depth = 0;
|
||||
let mut total_input_tokens = 0i64;
|
||||
let mut total_output_tokens = 0i64;
|
||||
let session_id = Uuid::new_v4();
|
||||
let session_start = std::time::Instant::now();
|
||||
|
||||
let config = AiClientConfig::new(
|
||||
self.ai_api_key.clone().unwrap_or_default(),
|
||||
)
|
||||
.with_base_url(self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into()));
|
||||
let version_id = room_ai.as_ref().and_then(|r| r.version);
|
||||
|
||||
let config = AiClientConfig::new(self.ai_api_key.clone().unwrap_or_default())
|
||||
.with_base_url(
|
||||
self.ai_base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "https://api.openai.com".into()),
|
||||
);
|
||||
|
||||
let mut full_content = String::new();
|
||||
let mut all_chunks: Vec<StreamChunk> = Vec::new();
|
||||
@ -325,6 +453,7 @@ impl ChatService {
|
||||
temperature,
|
||||
max_tokens,
|
||||
if tools_enabled { Some(&tools) } else { None },
|
||||
None, // tool_choice — auto (let model decide)
|
||||
Arc::new(move |delta| {
|
||||
let fut = on_chunk_cb(AiStreamChunk {
|
||||
content: delta.to_string(),
|
||||
@ -351,6 +480,9 @@ impl ChatService {
|
||||
)
|
||||
.await?;
|
||||
|
||||
total_input_tokens += response.input_tokens;
|
||||
total_output_tokens += response.output_tokens;
|
||||
|
||||
// Collect chunks from this streaming iteration in order.
|
||||
all_chunks.extend(response.chunks);
|
||||
|
||||
@ -425,23 +557,44 @@ impl ChatService {
|
||||
request.config.clone(),
|
||||
request.room.id,
|
||||
Some(request.sender.uid),
|
||||
);
|
||||
)
|
||||
.with_project(request.project.id);
|
||||
if let Some(ref registry) = self.tool_registry {
|
||||
ctx.registry_mut().merge(registry.clone());
|
||||
}
|
||||
|
||||
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(
|
||||
request.db.clone(),
|
||||
session_id,
|
||||
);
|
||||
|
||||
for call in &calls {
|
||||
let start = std::time::Instant::now();
|
||||
let executor = crate::tool::ToolExecutor::new();
|
||||
let results = match executor.execute_batch(vec![call.clone()], &mut ctx).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
recorder.record(crate::tool::recorder::ToolCallRecord {
|
||||
tool_call_id: call.id.clone(),
|
||||
session_id: recorder.session_id(),
|
||||
tool_name: call.name.clone(),
|
||||
caller: request.sender.uid,
|
||||
arguments: call.arguments_json().unwrap_or_default(),
|
||||
status: models::ai::ToolCallStatus::Failed,
|
||||
execution_time_ms: Some(elapsed),
|
||||
error_message: Some(e.to_string()),
|
||||
error_stack: None,
|
||||
retry_count: 0,
|
||||
});
|
||||
|
||||
let err_text = format!("[Tool call failed: {}]", e);
|
||||
tracing::warn!(tool = %call.name, error = %e, "tool_call_failed");
|
||||
// Do NOT emit tool_result chunks to frontend — show error via tool_call instead
|
||||
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
|
||||
let err_display = format!("❌ {} (failed)", call.name);
|
||||
on_chunk(AiStreamChunk {
|
||||
content: err_display.clone(),
|
||||
done: false,
|
||||
chunk_type: AiChunkType::ToolCall,
|
||||
chunk_type: AiChunkType::ToolResult,
|
||||
})
|
||||
.await;
|
||||
all_chunks.push(StreamChunk {
|
||||
@ -464,6 +617,29 @@ impl ChatService {
|
||||
text.clone()
|
||||
};
|
||||
tracing::debug!("tool_result: {} — {}", call.name, preview);
|
||||
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
|
||||
let error_msg = match &result.result {
|
||||
crate::tool::ToolResult::Error(msg) => Some(msg.clone()),
|
||||
_ => None,
|
||||
};
|
||||
recorder.record(crate::tool::recorder::ToolCallRecord {
|
||||
tool_call_id: call.id.clone(),
|
||||
session_id: recorder.session_id(),
|
||||
tool_name: call.name.clone(),
|
||||
caller: request.sender.uid,
|
||||
arguments: call.arguments_json().unwrap_or_default(),
|
||||
status: if is_error {
|
||||
models::ai::ToolCallStatus::Failed
|
||||
} else {
|
||||
models::ai::ToolCallStatus::Success
|
||||
},
|
||||
execution_time_ms: Some(elapsed),
|
||||
error_message: error_msg,
|
||||
error_stack: None,
|
||||
retry_count: 0,
|
||||
});
|
||||
// Do NOT emit tool_result chunks to frontend — raw output may contain sensitive data.
|
||||
// Log server-side only; frontend sees tool_call status via on_chunk below.
|
||||
}
|
||||
@ -471,7 +647,7 @@ impl ChatService {
|
||||
on_chunk(AiStreamChunk {
|
||||
content: success_display.clone(),
|
||||
done: false,
|
||||
chunk_type: AiChunkType::ToolCall,
|
||||
chunk_type: AiChunkType::ToolResult,
|
||||
})
|
||||
.await;
|
||||
all_chunks.push(StreamChunk {
|
||||
@ -509,8 +685,10 @@ impl ChatService {
|
||||
})
|
||||
.collect();
|
||||
for event in &tool_events {
|
||||
if let Some(ctx) =
|
||||
self.perception_service.passive.detect(event, &skill_entries)
|
||||
if let Some(ctx) = self
|
||||
.perception_service
|
||||
.passive
|
||||
.detect(event, &skill_entries)
|
||||
{
|
||||
messages.push(ctx.to_system_message());
|
||||
}
|
||||
@ -533,6 +711,23 @@ impl ChatService {
|
||||
chunk_type: StreamChunkType::Answer,
|
||||
content: max_depth_text,
|
||||
});
|
||||
// Record session
|
||||
let _ = models::ai::ai_session::ActiveModel {
|
||||
id: Set(session_id),
|
||||
room: Set(request.room.id),
|
||||
model: Set(request.model.id),
|
||||
version: Set(version_id.unwrap_or_default()),
|
||||
token_input: Set(total_input_tokens),
|
||||
token_output: Set(total_output_tokens),
|
||||
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
|
||||
cost: Set(None),
|
||||
currency: Set(None),
|
||||
error_message: Set(None),
|
||||
error_code: Set(None),
|
||||
created_at: Set(chrono::Utc::now()),
|
||||
}
|
||||
.insert(&request.db)
|
||||
.await;
|
||||
return Ok(StreamResult {
|
||||
content: full_content,
|
||||
reasoning_content: String::new(),
|
||||
@ -557,6 +752,23 @@ impl ChatService {
|
||||
chunk_type: StreamChunkType::Answer,
|
||||
content: response.content.clone(),
|
||||
});
|
||||
// Record session
|
||||
let _ = models::ai::ai_session::ActiveModel {
|
||||
id: Set(session_id),
|
||||
room: Set(request.room.id),
|
||||
model: Set(request.model.id),
|
||||
version: Set(version_id.unwrap_or_default()),
|
||||
token_input: Set(total_input_tokens),
|
||||
token_output: Set(total_output_tokens),
|
||||
latency_ms: Set(Some(session_start.elapsed().as_millis() as i64)),
|
||||
cost: Set(None),
|
||||
currency: Set(None),
|
||||
error_message: Set(None),
|
||||
error_code: Set(None),
|
||||
created_at: Set(chrono::Utc::now()),
|
||||
}
|
||||
.insert(&request.db)
|
||||
.await;
|
||||
return Ok(StreamResult {
|
||||
content: full_content,
|
||||
reasoning_content: response.reasoning_content,
|
||||
@ -616,7 +828,10 @@ impl ChatService {
|
||||
parts.push(format!("Description: {}", desc));
|
||||
}
|
||||
parts.push(format!("Default branch: {}", repo.default_branch));
|
||||
parts.push(format!("Private: {}", if repo.is_private { "yes" } else { "no" }));
|
||||
parts.push(format!(
|
||||
"Private: {}",
|
||||
if repo.is_private { "yes" } else { "no" }
|
||||
));
|
||||
parts.push(format!("Created: {}", repo.created_at.format("%Y-%m-%d")));
|
||||
messages.push(ChatRequestMessage::system(format!(
|
||||
"Mentioned repository:\n{}",
|
||||
@ -692,7 +907,11 @@ impl ChatService {
|
||||
"Current Project:\n{}\nDescription: {}\nPublic: {}",
|
||||
request.project.display_name,
|
||||
request.project.description.as_deref().unwrap_or("(none)"),
|
||||
if request.project.is_public { "yes" } else { "no" }
|
||||
if request.project.is_public {
|
||||
"yes"
|
||||
} else {
|
||||
"no"
|
||||
}
|
||||
)));
|
||||
|
||||
let mut sender_parts = vec![format!("**Sender:** {}", request.sender.username)];
|
||||
@ -773,7 +992,11 @@ impl ChatService {
|
||||
if let Some(embed_service) = &self.embed_service {
|
||||
let awareness = crate::perception::VectorActiveAwareness::default();
|
||||
vector_skills = awareness
|
||||
.detect(embed_service, &request.input, &request.project.id.to_string())
|
||||
.detect(
|
||||
embed_service,
|
||||
&request.input,
|
||||
&request.project.id.to_string(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
@ -813,32 +1036,14 @@ impl ChatService {
|
||||
.await
|
||||
}
|
||||
|
||||
fn is_retryable_tool_error(msg: &str) -> bool {
|
||||
let msg_lower = msg.to_lowercase();
|
||||
msg_lower.contains("connection")
|
||||
|| msg_lower.contains("timeout")
|
||||
|| msg_lower.contains("timed out")
|
||||
|| msg_lower.contains("rate limit")
|
||||
|| msg_lower.contains("too many")
|
||||
|| msg_lower.contains("unavailable")
|
||||
|| msg_lower.contains("service unavailable")
|
||||
|| msg_lower.contains("temporarily")
|
||||
|| msg_lower.contains("refused")
|
||||
|| msg_lower.contains("reset")
|
||||
|| msg_lower.contains("broken pipe")
|
||||
|| msg_lower.contains("deadline exceeded")
|
||||
|| msg_lower.contains("try again")
|
||||
}
|
||||
|
||||
pub async fn process_react<C>(
|
||||
&self,
|
||||
request: &AiRequest,
|
||||
mut on_chunk: C,
|
||||
) -> Result<String>
|
||||
pub async fn process_react<C>(&self, request: &AiRequest, mut on_chunk: C) -> Result<String>
|
||||
where
|
||||
C: FnMut(crate::react::ReactStep) + Send,
|
||||
{
|
||||
let base_url = self.ai_base_url.clone().unwrap_or_else(|| "https://api.openai.com".into());
|
||||
let base_url = self
|
||||
.ai_base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "https://api.openai.com".into());
|
||||
let api_key = self.ai_api_key.clone().unwrap_or_default();
|
||||
let client_config = AiClientConfig::new(api_key).with_base_url(base_url);
|
||||
|
||||
@ -848,104 +1053,176 @@ impl ChatService {
|
||||
|
||||
let db = request.db.clone();
|
||||
let cache = request.cache.clone();
|
||||
let config = request.config.clone();
|
||||
let cfg = request.config.clone();
|
||||
let room_id = request.room.id;
|
||||
let project_id = Some(request.project.id);
|
||||
let sender_uid = Some(request.sender.uid);
|
||||
let registry = registry.clone();
|
||||
let sender_uid = request.sender.uid;
|
||||
let project_id = request.project.id;
|
||||
let session_id = Uuid::new_v4();
|
||||
let session_start = std::time::Instant::now();
|
||||
let version_id = room_ai::Entity::find()
|
||||
.filter(room_ai::Column::Room.eq(request.room.id))
|
||||
.filter(room_ai::Column::Model.eq(request.model.id))
|
||||
.one(&request.db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.and_then(|r| r.version);
|
||||
|
||||
let executor: std::sync::Arc<
|
||||
dyn Fn(String, serde_json::Value) -> Pin<Box<dyn std::future::Future<Output = std::result::Result<serde_json::Value, String>> + Send>>
|
||||
+ Send
|
||||
+ Sync,
|
||||
> = std::sync::Arc::new(move |name: String, args: serde_json::Value| {
|
||||
let db = db.clone();
|
||||
let cache = cache.clone();
|
||||
let config = config.clone();
|
||||
let room_id = room_id;
|
||||
let project_id = project_id;
|
||||
let sender_uid = sender_uid;
|
||||
let registry = registry.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let max_retries = 3;
|
||||
let mut last_err = String::new();
|
||||
|
||||
for attempt in 0..=max_retries {
|
||||
let mut ctx = ToolContext::new(db.clone(), cache.clone(), config.clone(), room_id, sender_uid);
|
||||
if let Some(pid) = project_id {
|
||||
ctx = ctx.with_project(pid);
|
||||
}
|
||||
ctx.registry_mut().merge(registry.clone());
|
||||
|
||||
let tool_executor = ToolExecutor::new();
|
||||
let call = AgentToolCall {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
name: name.clone(),
|
||||
arguments: serde_json::to_string(&args).unwrap_or_else(|_| "{}".into()),
|
||||
};
|
||||
|
||||
match tool_executor.execute_batch(vec![call], &mut ctx).await {
|
||||
Ok(results) => {
|
||||
let result = results.into_iter().next()
|
||||
.ok_or_else(|| "no tool result returned".to_string())?;
|
||||
match result.result {
|
||||
ToolResult::Ok(v) => return Ok(v),
|
||||
ToolResult::Error(msg) => {
|
||||
if attempt < max_retries && Self::is_retryable_tool_error(&msg) {
|
||||
last_err = msg;
|
||||
let backoff_ms = 100u64.saturating_mul(2u64.pow(attempt as u32));
|
||||
tracing::warn!(
|
||||
tool = %name,
|
||||
attempt = attempt + 1,
|
||||
backoff_ms = backoff_ms,
|
||||
error = %last_err,
|
||||
"tool_execute_retry"
|
||||
// Build rig tools with recording wrapper directly from registry
|
||||
let mut tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>> = Vec::new();
|
||||
for def in registry.definitions() {
|
||||
let name = def.name.clone();
|
||||
if let Some(handler) = registry.get(&name) {
|
||||
let adapter = crate::tool::RigToolAdapter::new(
|
||||
handler.clone(),
|
||||
def.clone(),
|
||||
db.clone(),
|
||||
cache.clone(),
|
||||
cfg.clone(),
|
||||
room_id,
|
||||
Some(sender_uid),
|
||||
project_id,
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
|
||||
continue;
|
||||
}
|
||||
return Err(msg);
|
||||
tools.push(Box::new(RecordingTool::new(
|
||||
Box::new(adapter),
|
||||
db.clone(),
|
||||
session_id,
|
||||
sender_uid,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Build rig agent (handles multi-turn tool calls natively)
|
||||
let rig_client = client_config.build_rig_client();
|
||||
let model = rig_client.completion_model(&request.model.name);
|
||||
let agent = AgentBuilder::new(model)
|
||||
.preamble(DEFAULT_SYSTEM_PROMPT)
|
||||
.tools(tools)
|
||||
.default_max_turns(request.max_tool_depth)
|
||||
.build();
|
||||
|
||||
let stream = agent
|
||||
.stream_prompt(&request.input)
|
||||
.with_history(Vec::new())
|
||||
.multi_turn(request.max_tool_depth)
|
||||
.await;
|
||||
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut step_count = 0usize;
|
||||
let mut final_content = String::new();
|
||||
let mut total_input_tokens: i64 = 0;
|
||||
let mut total_output_tokens: i64 = 0;
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
match item {
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::Text(text),
|
||||
)) => {
|
||||
step_count += 1;
|
||||
let t = text.text;
|
||||
on_chunk(ReactStep::Answer {
|
||||
step: step_count,
|
||||
answer: t.clone(),
|
||||
});
|
||||
final_content.push_str(&t);
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::Reasoning(reasoning),
|
||||
)) => {
|
||||
let reasoning_text = reasoning.reasoning.join("");
|
||||
if !reasoning_text.is_empty() {
|
||||
step_count += 1;
|
||||
on_chunk(ReactStep::Thought {
|
||||
step: step_count,
|
||||
thought: reasoning_text,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::ReasoningDelta { reasoning, .. },
|
||||
)) => {
|
||||
if !reasoning.is_empty() {
|
||||
step_count += 1;
|
||||
on_chunk(ReactStep::Thought {
|
||||
step: step_count,
|
||||
thought: reasoning,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::ToolCall { tool_call, .. },
|
||||
)) => {
|
||||
step_count += 1;
|
||||
let args: serde_json::Value = match &tool_call.function.arguments {
|
||||
serde_json::Value::String(s) => {
|
||||
serde_json::from_str(s).unwrap_or(serde_json::Value::Null)
|
||||
}
|
||||
v => v.clone(),
|
||||
};
|
||||
on_chunk(ReactStep::Action {
|
||||
step: step_count,
|
||||
action: ReactAction::new(&tool_call.function.name, args),
|
||||
});
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamUserItem(
|
||||
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
|
||||
)) => {
|
||||
step_count += 1;
|
||||
let obs = tool_result_content_to_string(&tool_result.content);
|
||||
on_chunk(ReactStep::Observation {
|
||||
step: step_count,
|
||||
observation: obs,
|
||||
});
|
||||
}
|
||||
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
|
||||
let usage = resp.usage();
|
||||
total_input_tokens = usage.input_tokens as i64;
|
||||
total_output_tokens = usage.output_tokens as i64;
|
||||
// Text was already streamed incrementally via Answer events.
|
||||
}
|
||||
Err(e) => {
|
||||
last_err = e.to_string();
|
||||
if attempt < max_retries && Self::is_retryable_tool_error(&last_err) {
|
||||
let backoff_ms = 100u64.saturating_mul(2u64.pow(attempt as u32));
|
||||
tracing::warn!(
|
||||
tool = %name,
|
||||
attempt = attempt + 1,
|
||||
backoff_ms = backoff_ms,
|
||||
error = %last_err,
|
||||
"tool_execute_retry"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
|
||||
continue;
|
||||
}
|
||||
return Err(last_err);
|
||||
let err_msg = format!("rig agent stream error: {}", e);
|
||||
return Err(AgentError::OpenAi(err_msg));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err)
|
||||
}) as Pin<Box<dyn std::future::Future<Output = std::result::Result<serde_json::Value, String>> + Send>>
|
||||
});
|
||||
let elapsed_ms = session_start.elapsed().as_millis() as i64;
|
||||
let _ = models::ai::ai_session::ActiveModel {
|
||||
id: Set(session_id),
|
||||
room: Set(request.room.id),
|
||||
model: Set(request.model.id),
|
||||
version: Set(version_id.unwrap_or_default()),
|
||||
token_input: Set(total_input_tokens),
|
||||
token_output: Set(total_output_tokens),
|
||||
latency_ms: Set(Some(elapsed_ms)),
|
||||
cost: Set(None),
|
||||
currency: Set(None),
|
||||
error_message: Set(None),
|
||||
error_code: Set(None),
|
||||
created_at: Set(chrono::Utc::now()),
|
||||
}
|
||||
.insert(&request.db)
|
||||
.await;
|
||||
|
||||
let tools = self.tools();
|
||||
let config = ReactConfig {
|
||||
max_steps: request.max_tool_depth,
|
||||
stop_sequences: Vec::new(),
|
||||
tool_executor: Some(executor),
|
||||
};
|
||||
Ok(final_content)
|
||||
}
|
||||
}
|
||||
|
||||
let mut agent = ReactAgent::new(DEFAULT_SYSTEM_PROMPT, tools, config);
|
||||
agent.add_user_message(&request.input);
|
||||
|
||||
agent
|
||||
.run(&request.model.name, &client_config, |step| {
|
||||
on_chunk(step);
|
||||
/// Extract text from rig's ToolResultContent, ignoring images.
|
||||
fn tool_result_content_to_string(content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>) -> String {
|
||||
use rig::completion::message::ToolResultContent;
|
||||
content
|
||||
.iter()
|
||||
.filter_map(|item| {
|
||||
if let ToolResultContent::Text(t) = item {
|
||||
Some(t.text.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
@ -287,14 +287,6 @@ where
|
||||
.map(|ts| ts.iter().filter_map(to_rig_tool_def).collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
let tc = match tool_choice {
|
||||
Some("none") => rig::completion::message::ToolChoice::None,
|
||||
Some("auto") | None => rig::completion::message::ToolChoice::Auto,
|
||||
Some(s) => rig::completion::message::ToolChoice::Specific {
|
||||
function_names: vec![s.to_string()],
|
||||
},
|
||||
};
|
||||
|
||||
let mut builder = model.completion_request("");
|
||||
|
||||
if !preamble.is_empty() {
|
||||
@ -317,7 +309,24 @@ where
|
||||
builder = builder.tools(tool_defs);
|
||||
}
|
||||
|
||||
builder = builder.tool_choice(tc);
|
||||
// Only set tool_choice when explicitly provided (mirrors call_stream_once logic)
|
||||
if let Some(tc) = tool_choice {
|
||||
match tc {
|
||||
"none" => {
|
||||
builder = builder.tool_choice(rig::completion::message::ToolChoice::None);
|
||||
}
|
||||
"auto" => {
|
||||
builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto);
|
||||
}
|
||||
s => {
|
||||
builder = builder.tool_choice(
|
||||
rig::completion::message::ToolChoice::Specific {
|
||||
function_names: vec![s.to_string()],
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response = builder.send().await.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
||||
|
||||
@ -498,6 +507,7 @@ pub async fn call_stream(
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
tool_choice: Option<&str>,
|
||||
on_text_delta: StreamTextCb,
|
||||
on_reasoning_delta: StreamReasoningCb,
|
||||
on_tool_call: StreamToolCallCb,
|
||||
@ -506,7 +516,7 @@ pub async fn call_stream(
|
||||
|
||||
loop {
|
||||
let result = call_stream_once(
|
||||
messages, model_name, config, temperature, max_tokens, tools,
|
||||
messages, model_name, config, temperature, max_tokens, tools, tool_choice,
|
||||
on_text_delta.clone(), on_reasoning_delta.clone(), on_tool_call.clone(),
|
||||
)
|
||||
.await;
|
||||
@ -542,6 +552,7 @@ async fn call_stream_once(
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
tool_choice: Option<&str>,
|
||||
on_text_delta: StreamTextCb,
|
||||
on_reasoning_delta: StreamReasoningCb,
|
||||
on_tool_call: StreamToolCallCb,
|
||||
@ -581,6 +592,24 @@ async fn call_stream_once(
|
||||
builder = builder.tools(tool_defs);
|
||||
}
|
||||
|
||||
if let Some(tc) = tool_choice {
|
||||
match tc {
|
||||
"none" => {
|
||||
builder = builder.tool_choice(rig::completion::message::ToolChoice::None);
|
||||
}
|
||||
"auto" => {
|
||||
builder = builder.tool_choice(rig::completion::message::ToolChoice::Auto);
|
||||
}
|
||||
s => {
|
||||
builder = builder.tool_choice(
|
||||
rig::completion::message::ToolChoice::Specific {
|
||||
function_names: vec![s.to_string()],
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let stream_fut = async {
|
||||
let mut stream = builder
|
||||
.stream()
|
||||
@ -592,6 +621,10 @@ async fn call_stream_once(
|
||||
let mut tool_calls: Vec<StreamedToolCall> = Vec::new();
|
||||
let mut chunks: Vec<StreamChunk> = Vec::new();
|
||||
|
||||
// Some models (e.g. GLM) ignore tool_choice="none" and still emit tool_calls.
|
||||
// Filter them out so they don't cause spurious tool execution attempts.
|
||||
let skip_tool_calls = tool_choice == Some("none");
|
||||
|
||||
use std::collections::HashMap;
|
||||
let mut partial_tool_calls: HashMap<String, StreamedToolCall> = HashMap::new();
|
||||
let mut stream_finished = false;
|
||||
@ -612,6 +645,10 @@ async fn call_stream_once(
|
||||
tool_call,
|
||||
internal_call_id,
|
||||
}) => {
|
||||
if skip_tool_calls {
|
||||
partial_tool_calls.remove(&internal_call_id);
|
||||
continue;
|
||||
}
|
||||
let arguments = match &tool_call.function.arguments {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
other => serde_json::to_string(other).unwrap_or_else(|_| "{}".to_string()),
|
||||
@ -638,6 +675,9 @@ async fn call_stream_once(
|
||||
internal_call_id,
|
||||
content: delta_content,
|
||||
}) => {
|
||||
if skip_tool_calls {
|
||||
continue;
|
||||
}
|
||||
use rig::streaming::ToolCallDeltaContent;
|
||||
match delta_content {
|
||||
ToolCallDeltaContent::Name(name) => {
|
||||
@ -677,9 +717,13 @@ async fn call_stream_once(
|
||||
}
|
||||
Ok(StreamedAssistantContent::Final(response)) => {
|
||||
stream_finished = true;
|
||||
if !skip_tool_calls {
|
||||
for (_, tc) in partial_tool_calls.drain() {
|
||||
tool_calls.push(tc);
|
||||
}
|
||||
} else {
|
||||
partial_tool_calls.drain();
|
||||
}
|
||||
if let Some(usage) = response.token_usage() {
|
||||
let in_toks = usage.input_tokens as i64;
|
||||
let out_toks = usage.output_tokens as i64;
|
||||
@ -700,7 +744,7 @@ async fn call_stream_once(
|
||||
}
|
||||
|
||||
// Flush any remaining partial tool calls (if stream ended without Final or Final had no usage)
|
||||
if !stream_finished {
|
||||
if !stream_finished && !skip_tool_calls {
|
||||
for (_, tc) in partial_tool_calls.drain() {
|
||||
tool_calls.push(tc);
|
||||
}
|
||||
|
||||
@ -31,16 +31,13 @@ pub use client::types::ChatRequestMessage;
|
||||
pub use compact::{CompactConfig, CompactLevel, CompactService, CompactSummary, MessageSummary};
|
||||
pub use embed::{new_embed_client, EmbedClient, EmbedService, QdrantClient, SearchResult};
|
||||
pub use error::{AgentError, Result};
|
||||
pub use react::{
|
||||
Hook, HookAction, NoopHook, ReactAgent, ReactConfig, ReactStep, ToolCallAction, TracingHook,
|
||||
DEFAULT_SYSTEM_PROMPT,
|
||||
};
|
||||
pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT};
|
||||
pub use tool::{
|
||||
ToolCall, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,
|
||||
ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,
|
||||
ToolRegistry, ToolResult, ToolSchema,
|
||||
};
|
||||
|
||||
#[cfg(feature = "rig")]
|
||||
pub use agent::RigAgentService;
|
||||
#[cfg(feature = "rig")]
|
||||
pub use tool::rig_adapter::RigToolSet;
|
||||
pub use tool::{RigToolSet, RecordingTool, is_retryable_tool_error};
|
||||
|
||||
@ -1,130 +0,0 @@
|
||||
//! Observability hooks for the ReAct agent loop.
|
||||
//!
|
||||
//! Hooks allow injecting custom behavior (logging, tracing, filtering, termination)
|
||||
//! at each step of the reasoning loop without coupling to the core agent logic.
|
||||
//!
|
||||
//! Inspired by rig's `PromptHook` trait.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```ignore
|
||||
//! #[derive(Clone)]
|
||||
//! struct MyHook;
|
||||
//!
|
||||
//! impl Hook for MyHook {
|
||||
//! async fn on_thought(&self, step: usize, thought: &str) -> HookAction {
|
||||
//! tracing::info!("[step {}] thinking: {}", step, thought);
|
||||
//! HookAction::Continue
|
||||
//! }
|
||||
//! }
|
||||
//!
|
||||
//! let agent = ReactAgent::new(prompt, tools, config).with_hook(MyHook);
|
||||
//! ```
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// Controls whether the agent loop continues after a hook callback.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum HookAction {
|
||||
/// Continue processing normally.
|
||||
Continue,
|
||||
/// Skip the current step and continue.
|
||||
Skip,
|
||||
/// Terminate the loop immediately with the given reason.
|
||||
Terminate(&'static str),
|
||||
}
|
||||
|
||||
/// Controls behavior after a tool call hook callback.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ToolCallAction {
|
||||
/// Execute the tool normally.
|
||||
Continue,
|
||||
/// Skip tool execution and inject a custom result.
|
||||
Skip(String),
|
||||
/// Terminate the loop with the given reason.
|
||||
Terminate(&'static str),
|
||||
}
|
||||
|
||||
/// Default no-op hook that does nothing.
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct NoopHook;
|
||||
|
||||
impl Hook for NoopHook {}
|
||||
|
||||
impl Hook for () {}
|
||||
|
||||
/// A hook that logs everything to stderr using `eprintln`.
|
||||
/// No external dependencies required.
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct TracingHook;
|
||||
|
||||
impl TracingHook {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hook for TracingHook {
|
||||
async fn on_thought(&self, step: usize, thought: &str) -> HookAction {
|
||||
eprintln!("[step {}] thought: {}", step, thought);
|
||||
HookAction::Continue
|
||||
}
|
||||
|
||||
async fn on_tool_call(&self, step: usize, name: &str, args_json: &str) -> ToolCallAction {
|
||||
eprintln!("[step {}] tool_call: {}({})", step, name, args_json);
|
||||
ToolCallAction::Continue
|
||||
}
|
||||
|
||||
async fn on_observation(&self, step: usize, observation: &str) -> HookAction {
|
||||
eprintln!("[step {}] observation: {}", step, observation);
|
||||
HookAction::Continue
|
||||
}
|
||||
|
||||
async fn on_answer(&self, step: usize, answer: &str) -> HookAction {
|
||||
eprintln!("[step {}] answer: {}", step, answer);
|
||||
HookAction::Continue
|
||||
}
|
||||
}
|
||||
|
||||
/// Hook trait for observing and controlling the ReAct agent loop.
|
||||
///
|
||||
/// Implement this trait to inject custom behavior at each step:
|
||||
/// - Log thoughts, tool calls, observations, and final answers
|
||||
/// - Filter or redact sensitive data
|
||||
/// - Dynamically terminate the loop based on content
|
||||
/// - Inject custom tool results (e.g., for testing or sandboxing)
|
||||
///
|
||||
/// All methods have default no-op implementations, so you only need to
|
||||
/// override the ones you care about.
|
||||
///
|
||||
/// The hook is called synchronously during the agent loop. Keep hook
|
||||
/// callbacks fast — avoid blocking I/O. For heavy work, spawn a task
|
||||
/// and return immediately.
|
||||
#[async_trait]
|
||||
pub trait Hook: Send + Sync {
|
||||
/// Called when the agent emits a thought/reasoning step.
|
||||
///
|
||||
/// Return `HookAction::Terminate` to stop the loop early.
|
||||
async fn on_thought(&self, _step: usize, _thought: &str) -> HookAction {
|
||||
HookAction::Continue
|
||||
}
|
||||
|
||||
/// Called just before a tool is executed.
|
||||
///
|
||||
/// Return `ToolCallAction::Skip(result)` to skip execution and inject `result` instead.
|
||||
/// Return `ToolCallAction::Terminate` to stop the loop without executing the tool.
|
||||
async fn on_tool_call(&self, _step: usize, _name: &str, _args_json: &str) -> ToolCallAction {
|
||||
ToolCallAction::Continue
|
||||
}
|
||||
|
||||
/// Called after a tool returns an observation.
|
||||
async fn on_observation(&self, _step: usize, _observation: &str) -> HookAction {
|
||||
HookAction::Continue
|
||||
}
|
||||
|
||||
/// Called when the agent produces a final answer.
|
||||
async fn on_answer(&self, _step: usize, _answer: &str) -> HookAction {
|
||||
HookAction::Continue
|
||||
}
|
||||
}
|
||||
@ -1,413 +0,0 @@
|
||||
//! ReAct (Reasoning + Acting) agent core.
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::call_with_params;
|
||||
use crate::client::types::ChatRequestMessage;
|
||||
use crate::error::{AgentError, Result};
|
||||
use crate::react::hooks::{Hook, HookAction, NoopHook, ToolCallAction};
|
||||
use crate::react::types::{Action, ReactConfig, ReactStep};
|
||||
|
||||
pub use crate::react::types::{ReactConfig as ReActConfig, ReactStep as ReActStep};
|
||||
|
||||
/// A ReAct agent that performs multi-step tool-augmented reasoning.
|
||||
#[derive(Clone)]
|
||||
pub struct ReactAgent {
|
||||
messages: Vec<ChatRequestMessage>,
|
||||
#[allow(dead_code)]
|
||||
tool_definitions: Vec<serde_json::Value>,
|
||||
config: ReactConfig,
|
||||
step_count: usize,
|
||||
hook: Arc<dyn Hook>,
|
||||
}
|
||||
|
||||
impl ReactAgent {
|
||||
/// Create a new agent with a system prompt and tool definitions (as JSON values).
|
||||
pub fn new(
|
||||
system_prompt: &str,
|
||||
tools: Vec<serde_json::Value>,
|
||||
config: ReactConfig,
|
||||
) -> Self {
|
||||
let messages = vec![ChatRequestMessage::system(system_prompt)];
|
||||
Self {
|
||||
messages,
|
||||
tool_definitions: tools,
|
||||
config,
|
||||
step_count: 0,
|
||||
hook: Arc::new(NoopHook),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an initial user message to the conversation.
|
||||
pub fn add_user_message(&mut self, content: &str) {
|
||||
self.messages.push(ChatRequestMessage::user(content));
|
||||
}
|
||||
|
||||
/// Attach a hook to observe and control the agent loop.
|
||||
///
|
||||
/// Hooks can log steps, filter content, inject custom tool results,
|
||||
/// or terminate the loop early. Multiple `.with_hook()` calls replace
|
||||
/// the previous hook.
|
||||
pub fn with_hook<H: Hook + 'static>(mut self, hook: H) -> Self {
|
||||
self.hook = Arc::new(hook);
|
||||
self
|
||||
}
|
||||
|
||||
/// Run the ReAct loop until a final answer is produced or `max_steps` is reached.
|
||||
pub async fn run<C>(
|
||||
&mut self,
|
||||
model_name: &str,
|
||||
client_config: &crate::client::AiClientConfig,
|
||||
mut on_chunk: C,
|
||||
) -> Result<String>
|
||||
where
|
||||
C: FnMut(ReactStep) + Send,
|
||||
{
|
||||
loop {
|
||||
if self.step_count >= self.config.max_steps {
|
||||
let msg = format!(
|
||||
"Agent reached maximum reasoning steps ({}) without producing a final answer.",
|
||||
self.config.max_steps
|
||||
);
|
||||
on_chunk(ReactStep::Answer {
|
||||
step: self.step_count,
|
||||
answer: msg.clone(),
|
||||
});
|
||||
return Ok(msg);
|
||||
}
|
||||
|
||||
self.step_count += 1;
|
||||
let step = self.step_count;
|
||||
|
||||
// For ReAct we force text-only responses so the model follows our JSON-in-text format.
|
||||
let tool_choice_str = if self.tool_definitions.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some("none")
|
||||
};
|
||||
|
||||
let response = call_with_params(
|
||||
&self.messages,
|
||||
model_name,
|
||||
client_config,
|
||||
0.2, // temperature
|
||||
4096, // max output tokens
|
||||
None,
|
||||
if self.tool_definitions.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(&self.tool_definitions)
|
||||
},
|
||||
tool_choice_str,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let parsed = parse_react_response(&response.content);
|
||||
let answer = parsed.answer.clone();
|
||||
let action = parsed.action.clone();
|
||||
|
||||
on_chunk(ReactStep::Thought {
|
||||
step,
|
||||
thought: parsed.thought.clone(),
|
||||
});
|
||||
|
||||
match self.hook.on_thought(step, &parsed.thought).await {
|
||||
HookAction::Terminate(reason) => {
|
||||
return Err(AgentError::Internal(format!(
|
||||
"hook terminated at thought step: {}",
|
||||
reason
|
||||
)));
|
||||
}
|
||||
HookAction::Skip => {}
|
||||
HookAction::Continue => {}
|
||||
}
|
||||
|
||||
// Final answer — emit and return.
|
||||
if let Some(ans) = answer {
|
||||
on_chunk(ReactStep::Answer {
|
||||
step,
|
||||
answer: ans.clone(),
|
||||
});
|
||||
|
||||
match self.hook.on_answer(step, &ans).await {
|
||||
HookAction::Terminate(reason) => {
|
||||
return Err(AgentError::Internal(format!(
|
||||
"hook terminated at answer step: {}",
|
||||
reason
|
||||
)));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
return Ok(ans);
|
||||
}
|
||||
|
||||
// No answer — either do a tool call or fall back.
|
||||
let Some(act) = action else {
|
||||
let content = response.content.clone();
|
||||
on_chunk(ReactStep::Answer {
|
||||
step,
|
||||
answer: content.clone(),
|
||||
});
|
||||
|
||||
match self.hook.on_answer(step, &content).await {
|
||||
HookAction::Terminate(reason) => {
|
||||
return Err(AgentError::Internal(format!(
|
||||
"hook terminated at fallback answer: {}",
|
||||
reason
|
||||
)));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
return Ok(content);
|
||||
};
|
||||
|
||||
on_chunk(ReactStep::Action {
|
||||
step,
|
||||
action: act.clone(),
|
||||
});
|
||||
|
||||
let args_json = serde_json::to_string(&act.args).unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
match self.hook.on_tool_call(step, &act.name, &args_json).await {
|
||||
ToolCallAction::Terminate(reason) => {
|
||||
return Err(AgentError::Internal(format!(
|
||||
"hook terminated at tool call: {}",
|
||||
reason
|
||||
)));
|
||||
}
|
||||
ToolCallAction::Skip(injected_result) => {
|
||||
let observation = injected_result;
|
||||
on_chunk(ReactStep::Observation {
|
||||
step,
|
||||
observation: observation.clone(),
|
||||
});
|
||||
|
||||
match self.hook.on_observation(step, &observation).await {
|
||||
HookAction::Terminate(reason) => {
|
||||
return Err(AgentError::Internal(format!(
|
||||
"hook terminated at observation (injected): {}",
|
||||
reason
|
||||
)));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Append assistant message with tool_calls.
|
||||
let assistant_msg = build_tool_call_message(&act);
|
||||
self.messages.push(assistant_msg);
|
||||
|
||||
// Append observation as a tool message.
|
||||
self.messages.push(ChatRequestMessage::tool(&act.id, observation));
|
||||
|
||||
continue;
|
||||
}
|
||||
ToolCallAction::Continue => {}
|
||||
}
|
||||
|
||||
// Append the assistant message with tool_calls.
|
||||
let assistant_msg = build_tool_call_message(&act);
|
||||
self.messages.push(assistant_msg);
|
||||
|
||||
// Execute the tool.
|
||||
let observation = match &self.config.tool_executor {
|
||||
Some(exec) => {
|
||||
let result = exec(act.name.clone(), act.args.clone()).await;
|
||||
match result {
|
||||
Ok(v) => serde_json::to_string(&v).unwrap_or_else(|_| "null".to_string()),
|
||||
Err(e) => serde_json::json!({ "error": e }).to_string(),
|
||||
}
|
||||
}
|
||||
None => serde_json::json!({
|
||||
"error": format!("no tool executor registered for '{}'", act.name)
|
||||
})
|
||||
.to_string(),
|
||||
};
|
||||
|
||||
on_chunk(ReactStep::Observation {
|
||||
step,
|
||||
observation: observation.clone(),
|
||||
});
|
||||
|
||||
match self.hook.on_observation(step, &observation).await {
|
||||
HookAction::Terminate(reason) => {
|
||||
return Err(AgentError::Internal(format!(
|
||||
"hook terminated at observation step: {}",
|
||||
reason
|
||||
)));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Append observation as a tool message.
|
||||
self.messages.push(ChatRequestMessage::tool(&act.id, observation));
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of steps executed so far.
|
||||
pub fn steps(&self) -> usize {
|
||||
self.step_count
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Response parsing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct ParsedReActResponse {
|
||||
thought: String,
|
||||
action: Option<Action>,
|
||||
answer: Option<String>,
|
||||
}
|
||||
|
||||
fn parse_react_response(content: &str) -> ParsedReActResponse {
|
||||
let json_str = extract_json(content).unwrap_or_else(|| content.trim().to_string());
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RawStep {
|
||||
#[serde(default)]
|
||||
thought: Option<String>,
|
||||
#[serde(default)]
|
||||
action: Option<RawAction>,
|
||||
#[serde(default)]
|
||||
answer: Option<String>,
|
||||
#[serde(default)]
|
||||
name: Option<String>,
|
||||
#[serde(default, rename = "arguments")]
|
||||
args: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RawAction {
|
||||
#[serde(default)]
|
||||
name: Option<String>,
|
||||
#[serde(default, rename = "arguments")]
|
||||
args: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<RawStep>(&json_str) {
|
||||
Ok(raw) => {
|
||||
let thought = raw.thought.unwrap_or_else(|| "Thinking...".to_string());
|
||||
let answer = raw.answer;
|
||||
let action = raw.action.map(|a| Action {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
name: a.name.unwrap_or_default(),
|
||||
args: a.args.unwrap_or(serde_json::Value::Null),
|
||||
});
|
||||
let action = action.or_else(|| {
|
||||
if raw.name.is_some() || raw.args.is_some() {
|
||||
Some(Action {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
name: raw.name.unwrap_or_default(),
|
||||
args: raw.args.unwrap_or(serde_json::Value::Null),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
ParsedReActResponse {
|
||||
thought,
|
||||
action,
|
||||
answer,
|
||||
}
|
||||
}
|
||||
Err(_) => ParsedReActResponse {
|
||||
thought: content.to_string(),
|
||||
action: None,
|
||||
answer: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_json(s: &str) -> Option<String> {
|
||||
let trimmed = s.trim();
|
||||
|
||||
if trimmed.starts_with('{') || trimmed.starts_with('[') {
|
||||
return Some(trimmed.to_string());
|
||||
}
|
||||
|
||||
for line in trimmed.lines() {
|
||||
let line = line.trim();
|
||||
if line.starts_with("```json") || line == "```" {
|
||||
let mut buf = String::new();
|
||||
let mut found_start = false;
|
||||
for l in trimmed.lines() {
|
||||
let l = l.trim();
|
||||
if !found_start && (l == "```json" || l == "```") {
|
||||
found_start = true;
|
||||
continue;
|
||||
}
|
||||
if found_start && l == "```" {
|
||||
break;
|
||||
}
|
||||
if found_start {
|
||||
buf.push_str(l);
|
||||
buf.push('\n');
|
||||
}
|
||||
}
|
||||
let result = buf.trim().to_string();
|
||||
if !result.is_empty() {
|
||||
return Some(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let chars: Vec<char> = trimmed.chars().collect();
|
||||
for i in 0..chars.len() {
|
||||
let c = chars[i];
|
||||
if (c == '{' || c == '[') && i > 0 {
|
||||
let prev = chars[i - 1];
|
||||
if prev.is_alphanumeric() || prev == '_' || prev == '"' || prev == '\'' {
|
||||
continue;
|
||||
}
|
||||
let candidate: String = chars[i..].iter().collect();
|
||||
if serde_json::from_str::<serde_json::Value>(&candidate).is_ok() {
|
||||
return Some(candidate.trim_end().to_string());
|
||||
}
|
||||
let mut depth = 0isize;
|
||||
let mut in_string = false;
|
||||
let mut escaped = false;
|
||||
for (j, c) in candidate.char_indices() {
|
||||
if escaped { escaped = false; continue; }
|
||||
if c == '\\' { escaped = true; continue; }
|
||||
if c == '"' { in_string = !in_string; continue; }
|
||||
if in_string { continue; }
|
||||
if c == '{' || c == '[' { depth += 1; }
|
||||
if c == '}' || c == ']' { depth -= 1; }
|
||||
if depth == 0 {
|
||||
let json_end = j + c.len_utf8();
|
||||
let trimmed_candidate = &candidate[..json_end];
|
||||
if serde_json::from_str::<serde_json::Value>(trimmed_candidate).is_ok() {
|
||||
return Some(trimmed_candidate.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Build an assistant message with tool_calls from an Action.
|
||||
fn build_tool_call_message(action: &Action) -> ChatRequestMessage {
|
||||
let fn_arg_str = serde_json::to_string(&action.args).unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
ChatRequestMessage {
|
||||
role: "assistant".into(),
|
||||
content: Some(format!("Action: {}", action.name)),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(vec![crate::client::types::ToolCall {
|
||||
id: action.id.clone(),
|
||||
type_: "function".into(),
|
||||
function: crate::client::types::ToolCallFunction {
|
||||
name: action.name.clone(),
|
||||
arguments: fn_arg_str,
|
||||
},
|
||||
}]),
|
||||
}
|
||||
}
|
||||
@ -1,18 +1,13 @@
|
||||
//! ReAct (Reason + Act) agent loop for structured tool use.
|
||||
//! ReAct (Reason + Act) agent types.
|
||||
//!
|
||||
//! The agent alternates between a **thought** phase (reasoning about what to do)
|
||||
//! and an **action** phase (calling tools). Observations from tool results feed
|
||||
//! back into the next thought, enabling multi-step reasoning.
|
||||
//! Provides the step types used by the ReAct callback interface.
|
||||
//! The actual agent loop is handled by rig's built-in Agent.
|
||||
|
||||
pub mod hooks;
|
||||
pub mod loop_core;
|
||||
pub mod types;
|
||||
|
||||
pub use hooks::{Hook, HookAction, NoopHook, ToolCallAction, TracingHook};
|
||||
pub use loop_core::ReactAgent;
|
||||
pub use types::{ReactConfig, ReactStep};
|
||||
|
||||
/// Default system prompt for the ReAct agent.
|
||||
/// Default system prompt for the ReAct agent (used with rig's native tool-calling).
|
||||
///
|
||||
/// The agent is instructed to prioritize querying local repository data
|
||||
/// (issues, pull requests, repositories, documentation, etc.) before
|
||||
@ -25,26 +20,6 @@ Always query the platform's local data before guessing or referring to external
|
||||
|
||||
If local data does not contain the answer, state that clearly before considering external information.
|
||||
|
||||
## Response Format
|
||||
|
||||
Respond as JSON:
|
||||
|
||||
1. When you need to look up data:
|
||||
```json
|
||||
{
|
||||
"thought": "What you need to find and why.",
|
||||
"action": { "name": "tool_name", "arguments": { ... } }
|
||||
}
|
||||
```
|
||||
|
||||
2. When you have enough information to answer:
|
||||
```json
|
||||
{
|
||||
"thought": "How you arrived at the answer.",
|
||||
"answer": "Your final answer."
|
||||
}
|
||||
```
|
||||
|
||||
## Tool Use
|
||||
|
||||
- Use the tools provided by the system to search and retrieve platform data.
|
||||
|
||||
131
libs/agent/tool/recorder.rs
Normal file
131
libs/agent/tool/recorder.rs
Normal file
@ -0,0 +1,131 @@
|
||||
//! Batch tool call recorder — persists tool call records to `ai_tool_call` table.
|
||||
//!
|
||||
//! Uses an mpsc channel + background flush loop to batch-insert records,
|
||||
//! reducing DB pressure from individual inserts.
|
||||
//!
|
||||
//! Flush triggers:
|
||||
//! - Buffer reaches `BATCH_SIZE` (default 50)
|
||||
//! - `FLUSH_INTERVAL` (default 5s) elapses with non-empty buffer
|
||||
//! - Sender is dropped (remaining records flushed on channel close)
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use db::database::AppDatabase;
|
||||
use models::ai::ai_tool_call;
|
||||
use models::ai::ToolCallStatus;
|
||||
use sea_orm::*;
|
||||
use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
const FLUSH_INTERVAL: Duration = Duration::from_secs(5);
|
||||
const BATCH_SIZE: usize = 50;
|
||||
|
||||
/// A single tool call record to be persisted.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCallRecord {
|
||||
pub tool_call_id: String,
|
||||
pub session_id: Uuid,
|
||||
pub tool_name: String,
|
||||
pub caller: Uuid,
|
||||
pub arguments: serde_json::Value,
|
||||
pub status: ToolCallStatus,
|
||||
pub execution_time_ms: Option<i64>,
|
||||
pub error_message: Option<String>,
|
||||
pub error_stack: Option<String>,
|
||||
pub retry_count: i32,
|
||||
}
|
||||
|
||||
/// Channel-based batched recorder. Cheap to clone — all clones share the same sender.
|
||||
#[derive(Clone)]
|
||||
pub struct ToolCallRecorder {
|
||||
tx: mpsc::UnboundedSender<ToolCallRecord>,
|
||||
session_id: Uuid,
|
||||
}
|
||||
|
||||
impl ToolCallRecorder {
|
||||
/// Create a new recorder with an auto-generated session ID
|
||||
/// and spawn a background flush loop.
|
||||
pub fn new(db: AppDatabase) -> Self {
|
||||
Self::with_session(db, Uuid::new_v4())
|
||||
}
|
||||
|
||||
/// Create a new recorder with a specific session ID
|
||||
/// (so tool call records can be linked to an `AiSession`).
|
||||
pub fn with_session(db: AppDatabase, session_id: Uuid) -> Self {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
tokio::spawn(flush_loop(db, rx));
|
||||
Self { tx, session_id }
|
||||
}
|
||||
|
||||
/// The session ID shared by all tool calls recorded through this instance.
|
||||
pub fn session_id(&self) -> Uuid {
|
||||
self.session_id
|
||||
}
|
||||
|
||||
/// Enqueue a tool call record for batch persistence.
|
||||
pub fn record(&self, record: ToolCallRecord) {
|
||||
let _ = self.tx.send(record);
|
||||
}
|
||||
}
|
||||
|
||||
async fn flush_loop(db: AppDatabase, mut rx: mpsc::UnboundedReceiver<ToolCallRecord>) {
|
||||
let mut buffer = Vec::with_capacity(BATCH_SIZE);
|
||||
let mut ticker = tokio::time::interval(FLUSH_INTERVAL);
|
||||
ticker.tick().await; // skip first immediate tick
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(record) = rx.recv() => {
|
||||
buffer.push(record);
|
||||
if buffer.len() >= BATCH_SIZE {
|
||||
flush(&db, &mut buffer).await;
|
||||
}
|
||||
}
|
||||
_ = ticker.tick() => {
|
||||
if !buffer.is_empty() {
|
||||
flush(&db, &mut buffer).await;
|
||||
}
|
||||
}
|
||||
else => {
|
||||
// Channel closed — flush remaining and exit
|
||||
if !buffer.is_empty() {
|
||||
flush(&db, &mut buffer).await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn flush(db: &AppDatabase, buffer: &mut Vec<ToolCallRecord>) {
|
||||
let now = chrono::Utc::now();
|
||||
let models: Vec<ai_tool_call::ActiveModel> = buffer
|
||||
.iter()
|
||||
.map(|r| {
|
||||
let status = r.status.to_string();
|
||||
ai_tool_call::ActiveModel {
|
||||
tool_call_id: Set(r.tool_call_id.clone()),
|
||||
session: Set(r.session_id),
|
||||
tool_name: Set(r.tool_name.clone()),
|
||||
caller: Set(r.caller),
|
||||
arguments: Set(r.arguments.clone()),
|
||||
result: Set(serde_json::Value::Null),
|
||||
status: Set(status),
|
||||
execution_time_ms: Set(r.execution_time_ms),
|
||||
error_message: Set(r.error_message.clone()),
|
||||
error_stack: Set(r.error_stack.clone()),
|
||||
retry_count: Set(r.retry_count),
|
||||
created_at: Set(now),
|
||||
completed_at: Set(Some(now)),
|
||||
updated_at: Set(now),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let count = models.len();
|
||||
if let Err(e) = ai_tool_call::Entity::insert_many(models).exec(db).await {
|
||||
tracing::warn!(error = %e, count, "failed_to_flush_tool_call_records");
|
||||
}
|
||||
|
||||
buffer.clear();
|
||||
}
|
||||
@ -4,6 +4,7 @@
|
||||
//! to implement rig's ToolDyn trait, enabling integration with rig's Agent.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use futures::FutureExt;
|
||||
use rig::completion::ToolDefinition;
|
||||
@ -11,8 +12,146 @@ use rig::tool::{ToolDyn, ToolError, ToolSet};
|
||||
|
||||
use super::context::ToolContext;
|
||||
use super::definition::ToolDefinition as AgentToolDefinition;
|
||||
use super::recorder::{ToolCallRecord, ToolCallRecorder};
|
||||
use super::registry::{ToolHandler, ToolRegistry};
|
||||
|
||||
/// Returns true if the tool error message indicates a transient failure that can be retried.
|
||||
pub fn is_retryable_tool_error(msg: &str) -> bool {
|
||||
let lower = msg.to_lowercase();
|
||||
lower.contains("retry")
|
||||
|| lower.contains("timeout")
|
||||
|| lower.contains("rate limit")
|
||||
|| lower.contains("too many requests")
|
||||
|| lower.contains("unavailable")
|
||||
|| lower.contains("connection refused")
|
||||
|| lower.contains("5")
|
||||
|| lower.contains("try again")
|
||||
}
|
||||
|
||||
/// Wraps a ToolDyn with automatic retry and tool call recording.
|
||||
///
|
||||
/// Used by the rig Agent path to replace the custom ReAct executor closure.
|
||||
pub struct RecordingTool {
|
||||
inner: Box<dyn ToolDyn>,
|
||||
db: db::database::AppDatabase,
|
||||
session_id: uuid::Uuid,
|
||||
caller: uuid::Uuid,
|
||||
}
|
||||
|
||||
impl RecordingTool {
|
||||
pub fn new(
|
||||
inner: Box<dyn ToolDyn>,
|
||||
db: db::database::AppDatabase,
|
||||
session_id: uuid::Uuid,
|
||||
caller: uuid::Uuid,
|
||||
) -> Self {
|
||||
Self { inner, db, session_id, caller }
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolDyn for RecordingTool {
|
||||
fn name(&self) -> String {
|
||||
self.inner.name()
|
||||
}
|
||||
|
||||
fn definition<'a>(
|
||||
&'a self,
|
||||
prompt: String,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolDefinition> + Send + 'a>> {
|
||||
self.inner.definition(prompt)
|
||||
}
|
||||
|
||||
fn call<'a>(
|
||||
&'a self,
|
||||
args: String,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String, ToolError>> + Send + 'a>> {
|
||||
let inner: &'a Box<dyn ToolDyn> = &self.inner;
|
||||
let db = self.db.clone();
|
||||
let session_id = self.session_id;
|
||||
let caller = self.caller;
|
||||
let tool_name = inner.name();
|
||||
|
||||
Box::pin(async move {
|
||||
let recorder = ToolCallRecorder::with_session(db.clone(), session_id);
|
||||
let max_retries = 3u32;
|
||||
let mut last_err = String::new();
|
||||
let start = Instant::now();
|
||||
|
||||
for attempt in 0..=max_retries {
|
||||
let attempt_start = Instant::now();
|
||||
let attempt_args = args.clone();
|
||||
let attempt_result = inner.call(attempt_args).await;
|
||||
|
||||
let elapsed_ms = attempt_start.elapsed().as_millis() as i64;
|
||||
let args_json: serde_json::Value =
|
||||
serde_json::from_str(&args).unwrap_or_default();
|
||||
|
||||
match attempt_result {
|
||||
Ok(value) => {
|
||||
recorder.record(ToolCallRecord {
|
||||
tool_call_id: tool_name.clone(),
|
||||
session_id,
|
||||
tool_name: tool_name.clone(),
|
||||
caller,
|
||||
arguments: args_json,
|
||||
status: models::ai::ToolCallStatus::Success,
|
||||
execution_time_ms: Some(elapsed_ms),
|
||||
error_message: None,
|
||||
error_stack: None,
|
||||
retry_count: attempt as i32,
|
||||
});
|
||||
return Ok(value);
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = e.to_string();
|
||||
if attempt < max_retries && is_retryable_tool_error(&err_msg) {
|
||||
last_err = err_msg;
|
||||
let backoff_ms =
|
||||
100u64.saturating_mul(2u64.pow(attempt as u32));
|
||||
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
|
||||
continue;
|
||||
}
|
||||
recorder.record(ToolCallRecord {
|
||||
tool_call_id: tool_name.clone(),
|
||||
session_id,
|
||||
tool_name: tool_name.clone(),
|
||||
caller,
|
||||
arguments: args_json,
|
||||
status: models::ai::ToolCallStatus::Failed,
|
||||
execution_time_ms: Some(elapsed_ms),
|
||||
error_message: Some(err_msg.clone()),
|
||||
error_stack: None,
|
||||
retry_count: attempt as i32,
|
||||
});
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: record failure after all retries exhausted
|
||||
let elapsed_ms = start.elapsed().as_millis() as i64;
|
||||
let args_json: serde_json::Value =
|
||||
serde_json::from_str(&args).unwrap_or_default();
|
||||
recorder.record(ToolCallRecord {
|
||||
tool_call_id: tool_name.clone(),
|
||||
session_id,
|
||||
tool_name: tool_name.clone(),
|
||||
caller,
|
||||
arguments: args_json,
|
||||
status: models::ai::ToolCallStatus::Failed,
|
||||
execution_time_ms: Some(elapsed_ms),
|
||||
error_message: Some(last_err),
|
||||
error_stack: None,
|
||||
retry_count: max_retries as i32,
|
||||
});
|
||||
Err(ToolError::ToolCallError(Box::new(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"max retries exceeded",
|
||||
))))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper that converts our ToolRegistry to rig's ToolSet.
|
||||
pub struct RigToolSet {
|
||||
/// The rig ToolSet
|
||||
@ -30,6 +169,7 @@ impl RigToolSet {
|
||||
config: config::AppConfig,
|
||||
room_id: uuid::Uuid,
|
||||
sender_id: Option<uuid::Uuid>,
|
||||
project_id: uuid::Uuid,
|
||||
) -> Self {
|
||||
let mut toolset = ToolSet::default();
|
||||
let mut definitions = HashMap::new();
|
||||
@ -50,6 +190,7 @@ impl RigToolSet {
|
||||
config: config.clone(),
|
||||
room_id,
|
||||
sender_id,
|
||||
project_id,
|
||||
};
|
||||
toolset.add_tool(adapter);
|
||||
}
|
||||
@ -85,6 +226,23 @@ pub struct RigToolAdapter {
|
||||
config: config::AppConfig,
|
||||
room_id: uuid::Uuid,
|
||||
sender_id: Option<uuid::Uuid>,
|
||||
project_id: uuid::Uuid,
|
||||
}
|
||||
|
||||
impl RigToolAdapter {
|
||||
/// Create a new RigToolAdapter with all required context.
|
||||
pub fn new(
|
||||
handler: ToolHandler,
|
||||
definition: AgentToolDefinition,
|
||||
db: db::database::AppDatabase,
|
||||
cache: db::cache::AppCache,
|
||||
config: config::AppConfig,
|
||||
room_id: uuid::Uuid,
|
||||
sender_id: Option<uuid::Uuid>,
|
||||
project_id: uuid::Uuid,
|
||||
) -> Self {
|
||||
Self { handler, definition, db, cache, config, room_id, sender_id, project_id }
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolDyn for RigToolAdapter {
|
||||
@ -113,6 +271,7 @@ impl ToolDyn for RigToolAdapter {
|
||||
let config = self.config.clone();
|
||||
let room_id = self.room_id;
|
||||
let sender_id = self.sender_id;
|
||||
let project_id = self.project_id;
|
||||
|
||||
async move {
|
||||
let ctx = ToolContext::new(
|
||||
@ -121,7 +280,8 @@ impl ToolDyn for RigToolAdapter {
|
||||
config,
|
||||
room_id,
|
||||
sender_id,
|
||||
);
|
||||
)
|
||||
.with_project(project_id);
|
||||
|
||||
let args_json: serde_json::Value = serde_json::from_str(&args)
|
||||
.map_err(|e| ToolError::JsonError(e))?;
|
||||
|
||||
@ -272,6 +272,7 @@ pub async fn ws_universal(
|
||||
"data": {
|
||||
"message_id": chunk.message_id,
|
||||
"room_id": chunk.room_id,
|
||||
"seq": chunk.seq,
|
||||
"content": chunk.content,
|
||||
"done": chunk.done,
|
||||
"error": chunk.error,
|
||||
|
||||
@ -110,6 +110,9 @@ pub struct ProjectRoomEvent {
|
||||
pub struct RoomMessageStreamChunkEvent {
|
||||
pub message_id: Uuid,
|
||||
pub room_id: Uuid,
|
||||
/// Monotonically increasing sequence number for ordering within this stream.
|
||||
#[serde(default)]
|
||||
pub seq: u64,
|
||||
pub content: String,
|
||||
pub done: bool,
|
||||
pub error: Option<String>,
|
||||
|
||||
@ -54,6 +54,7 @@ pub async fn process_message_ai_react_streaming(
|
||||
let answer_buffer: std::sync::Arc<std::sync::Mutex<String>> =
|
||||
std::sync::Arc::new(std::sync::Mutex::new(String::new()));
|
||||
let step_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
|
||||
let chunk_seq: std::sync::Arc<std::sync::atomic::AtomicU64> = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1));
|
||||
|
||||
// Helper: recover from poison instead of panicking.
|
||||
fn lock_or_recover<T>(mutex: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> {
|
||||
@ -65,6 +66,7 @@ pub async fn process_message_ai_react_streaming(
|
||||
let streaming_msg_id = streaming_msg_id;
|
||||
let room_id = room_id_inner;
|
||||
let step_count = step_count.clone();
|
||||
let chunk_seq = chunk_seq.clone();
|
||||
let ai_display_name_for_step = std::sync::Arc::new(ai_display_name.clone());
|
||||
let steps = steps.clone();
|
||||
let answer_buffer = answer_buffer.clone();
|
||||
@ -73,18 +75,20 @@ pub async fn process_message_ai_react_streaming(
|
||||
let room_manager = room_manager.clone();
|
||||
let (chunk_type, content) = match &step {
|
||||
ReactStep::Thought { step: _, thought } => {
|
||||
("thinking".to_string(), format!("[Thinking] {}", thought))
|
||||
("thinking".to_string(), thought.clone())
|
||||
}
|
||||
ReactStep::Action { step: _, action } => {
|
||||
*lock_or_recover(&last_action_name) = action.name.clone();
|
||||
("tool_call".to_string(), format!("[Action] Calling `{}` with {:?}", action.name, action.args))
|
||||
("tool_call".to_string(), serde_json::json!({
|
||||
"name": action.name,
|
||||
"arguments": action.args,
|
||||
}).to_string())
|
||||
}
|
||||
ReactStep::Observation {
|
||||
step: _,
|
||||
observation: _,
|
||||
observation,
|
||||
} => {
|
||||
let action_name = lock_or_recover(&last_action_name).clone();
|
||||
("tool_call".to_string(), format!("[Observation] {} (completed)", action_name))
|
||||
("tool_result".to_string(), observation.clone())
|
||||
}
|
||||
ReactStep::Answer { step: _, answer } => {
|
||||
("answer".to_string(), answer.clone())
|
||||
@ -96,22 +100,33 @@ pub async fn process_message_ai_react_streaming(
|
||||
step_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
// Record ordered step for storage
|
||||
// Record ordered step for storage — merge consecutive same-type chunks
|
||||
// to ensure strict think→answer→think→answer alternation.
|
||||
{
|
||||
let mut s = lock_or_recover(&steps);
|
||||
if let Some(last) = s.last_mut() {
|
||||
if last.0 == chunk_type {
|
||||
last.1.push_str(&content);
|
||||
} else {
|
||||
s.push((chunk_type.clone(), content.clone()));
|
||||
}
|
||||
} else {
|
||||
s.push((chunk_type.clone(), content.clone()));
|
||||
}
|
||||
}
|
||||
if is_answer {
|
||||
let mut ab = lock_or_recover(&answer_buffer);
|
||||
ab.push_str(&content);
|
||||
}
|
||||
|
||||
let done = is_answer;
|
||||
let done = false;
|
||||
let ai_name = ai_display_name_for_step.clone();
|
||||
let current_seq = chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
tokio::spawn(async move {
|
||||
let event = RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id,
|
||||
seq: current_seq,
|
||||
content: content.clone(),
|
||||
done,
|
||||
error: None,
|
||||
@ -125,6 +140,21 @@ pub async fn process_message_ai_react_streaming(
|
||||
|
||||
let result = chat_service.process_react(&request, on_step).await;
|
||||
|
||||
// Broadcast final done=true event to close the streaming channel on frontend.
|
||||
let final_stream_content = lock_or_recover(&answer_buffer).clone();
|
||||
room_manager
|
||||
.broadcast_stream_chunk(RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id: room_id_inner,
|
||||
seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
||||
content: final_stream_content.clone(),
|
||||
done: true,
|
||||
error: None,
|
||||
display_name: Some(ai_display_name.clone()),
|
||||
chunk_type: Some("answer".to_string()),
|
||||
})
|
||||
.await;
|
||||
|
||||
let final_content = lock_or_recover(&answer_buffer).clone();
|
||||
let all_steps = lock_or_recover(&steps).clone();
|
||||
let reasoning_chain: String = all_steps
|
||||
@ -172,7 +202,7 @@ pub async fn process_message_ai_react_streaming(
|
||||
}
|
||||
|
||||
// Serialize ordered steps as JSON for ordered replay.
|
||||
let thinking_content = {
|
||||
let thinking_content_serialized = {
|
||||
let steps = lock_or_recover(&steps);
|
||||
if steps.is_empty() {
|
||||
None
|
||||
@ -186,6 +216,7 @@ pub async fn process_message_ai_react_streaming(
|
||||
Some(chunks_json.to_string())
|
||||
}
|
||||
};
|
||||
let thinking_content_for_event = thinking_content_serialized.clone();
|
||||
|
||||
let envelope = RoomMessageEnvelope {
|
||||
id: streaming_msg_id,
|
||||
@ -197,7 +228,7 @@ pub async fn process_message_ai_react_streaming(
|
||||
thread_id: None,
|
||||
content: persist_content.clone(),
|
||||
content_type: "text".to_string(),
|
||||
thinking_content,
|
||||
thinking_content: thinking_content_serialized,
|
||||
send_at: now,
|
||||
seq,
|
||||
in_reply_to: None,
|
||||
@ -244,7 +275,7 @@ pub async fn process_message_ai_react_streaming(
|
||||
thread_id: None,
|
||||
content: persist_content,
|
||||
content_type: "text".to_string(),
|
||||
thinking_content: None,
|
||||
thinking_content: thinking_content_for_event,
|
||||
send_at: now,
|
||||
seq,
|
||||
display_name: Some(ai_display_name.clone()),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user