346 lines
13 KiB
Rust
346 lines
13 KiB
Rust
//! Adapter to bridge our ToolRegistry with rig's Tool system.
//!
//! This module provides adapters that wrap our custom ToolHandler/Registry
//! to implement rig's ToolDyn trait, enabling integration with rig's Agent.
|
|
|
|
use std::collections::HashMap;
|
|
use std::time::{Duration, Instant};
|
|
|
|
use futures::FutureExt;
|
|
use rig::completion::ToolDefinition;
|
|
use rig::tool::{ToolDyn, ToolError, ToolSet};
|
|
|
|
use super::context::ToolContext;
|
|
use super::definition::ToolDefinition as AgentToolDefinition;
|
|
use super::recorder::{ToolCallRecord, ToolCallRecorder};
|
|
use super::registry::{ToolHandler, ToolRegistry};
|
|
use queue::MessageProducer;
|
|
|
|
/// Returns true if the tool error message indicates a transient failure that can be retried.
///
/// The check is case-insensitive. A message is considered retryable when it
/// contains one of the known transient markers (`"retry"`, `"timeout"`,
/// `"rate limit"`, `"too many requests"`, `"unavailable"`,
/// `"connection refused"`, `"try again"`) or a standalone HTTP 5xx status
/// code such as `500` or `503`.
///
/// A bare digit `5` is *not* enough: messages like `"expected 5 items"`,
/// port numbers, durations (`"512ms"`) or identifiers that merely contain a
/// five are not treated as server errors.
pub fn is_retryable_tool_error(msg: &str) -> bool {
    const TRANSIENT_MARKERS: [&str; 7] = [
        "retry",
        "timeout",
        "rate limit",
        "too many requests",
        "unavailable",
        "connection refused",
        "try again",
    ];

    let lower = msg.to_lowercase();
    TRANSIENT_MARKERS.iter().any(|marker| lower.contains(marker))
        || contains_http_5xx_status(&lower)
}

/// Returns true if `msg` contains a standalone three-digit token in the range 500-599.
///
/// Tokens are delimited by any non-alphanumeric ASCII character, so `"status=503,"`
/// matches while `"5432"`, `"512ms"` and `"550e8400"` do not.
fn contains_http_5xx_status(msg: &str) -> bool {
    msg.split(|c: char| !c.is_ascii_alphanumeric()).any(|token| {
        token.len() == 3
            && token.starts_with('5')
            && token.bytes().all(|b| b.is_ascii_digit())
    })
}
|
|
|
|
/// Wraps a ToolDyn with automatic retry and tool call recording.
|
|
///
|
|
/// Used by the rig Agent path to replace the custom ReAct executor closure.
|
|
pub struct RecordingTool {
|
|
inner: Box<dyn ToolDyn>,
|
|
db: db::database::AppDatabase,
|
|
session_id: uuid::Uuid,
|
|
caller: uuid::Uuid,
|
|
}
|
|
|
|
impl RecordingTool {
|
|
pub fn new(
|
|
inner: Box<dyn ToolDyn>,
|
|
db: db::database::AppDatabase,
|
|
session_id: uuid::Uuid,
|
|
caller: uuid::Uuid,
|
|
) -> Self {
|
|
Self { inner, db, session_id, caller }
|
|
}
|
|
}
|
|
|
|
impl ToolDyn for RecordingTool {
|
|
fn name(&self) -> String {
|
|
self.inner.name()
|
|
}
|
|
|
|
fn definition<'a>(
|
|
&'a self,
|
|
prompt: String,
|
|
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolDefinition> + Send + 'a>> {
|
|
self.inner.definition(prompt)
|
|
}
|
|
|
|
fn call<'a>(
|
|
&'a self,
|
|
args: String,
|
|
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String, ToolError>> + Send + 'a>> {
|
|
let inner: &'a Box<dyn ToolDyn> = &self.inner;
|
|
let db = self.db.clone();
|
|
let session_id = self.session_id;
|
|
let caller = self.caller;
|
|
let tool_name = inner.name();
|
|
|
|
Box::pin(async move {
|
|
let recorder = ToolCallRecorder::with_session(db.clone(), session_id);
|
|
let max_retries = 3u32;
|
|
let mut last_err = String::new();
|
|
let start = Instant::now();
|
|
|
|
for attempt in 0..=max_retries {
|
|
let attempt_start = Instant::now();
|
|
let attempt_args = args.clone();
|
|
let attempt_result = inner.call(attempt_args).await;
|
|
|
|
let elapsed_ms = attempt_start.elapsed().as_millis() as i64;
|
|
let args_json: serde_json::Value =
|
|
serde_json::from_str(&args).unwrap_or_default();
|
|
|
|
match attempt_result {
|
|
Ok(value) => {
|
|
recorder.record(ToolCallRecord {
|
|
tool_call_id: tool_name.clone(),
|
|
session_id,
|
|
tool_name: tool_name.clone(),
|
|
caller,
|
|
arguments: args_json,
|
|
status: models::ai::ToolCallStatus::Success,
|
|
execution_time_ms: Some(elapsed_ms),
|
|
error_message: None,
|
|
error_stack: None,
|
|
retry_count: attempt as i32,
|
|
});
|
|
return Ok(value);
|
|
}
|
|
Err(e) => {
|
|
let err_msg = e.to_string();
|
|
if attempt < max_retries && is_retryable_tool_error(&err_msg) {
|
|
last_err = err_msg;
|
|
let backoff_ms =
|
|
100u64.saturating_mul(2u64.pow(attempt as u32));
|
|
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
|
|
continue;
|
|
}
|
|
recorder.record(ToolCallRecord {
|
|
tool_call_id: tool_name.clone(),
|
|
session_id,
|
|
tool_name: tool_name.clone(),
|
|
caller,
|
|
arguments: args_json,
|
|
status: models::ai::ToolCallStatus::Failed,
|
|
execution_time_ms: Some(elapsed_ms),
|
|
error_message: Some(err_msg.clone()),
|
|
error_stack: None,
|
|
retry_count: attempt as i32,
|
|
});
|
|
return Err(e);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fallback: record failure after all retries exhausted
|
|
let elapsed_ms = start.elapsed().as_millis() as i64;
|
|
let args_json: serde_json::Value =
|
|
serde_json::from_str(&args).unwrap_or_default();
|
|
recorder.record(ToolCallRecord {
|
|
tool_call_id: tool_name.clone(),
|
|
session_id,
|
|
tool_name: tool_name.clone(),
|
|
caller,
|
|
arguments: args_json,
|
|
status: models::ai::ToolCallStatus::Failed,
|
|
execution_time_ms: Some(elapsed_ms),
|
|
error_message: Some(last_err),
|
|
error_stack: None,
|
|
retry_count: max_retries as i32,
|
|
});
|
|
Err(ToolError::ToolCallError(Box::new(std::io::Error::new(
|
|
std::io::ErrorKind::Other,
|
|
"max retries exceeded",
|
|
))))
|
|
})
|
|
}
|
|
}
|
|
|
|
/// A wrapper that converts our ToolRegistry to rig's ToolSet.
|
|
pub struct RigToolSet {
|
|
/// The rig ToolSet
|
|
inner: ToolSet,
|
|
/// Tool definitions for converting back
|
|
definitions: HashMap<String, AgentToolDefinition>,
|
|
}
|
|
|
|
impl RigToolSet {
|
|
/// Create a new RigToolSet from our ToolRegistry.
|
|
pub fn from_registry(
|
|
registry: &ToolRegistry,
|
|
db: db::database::AppDatabase,
|
|
cache: db::cache::AppCache,
|
|
config: config::AppConfig,
|
|
room_id: uuid::Uuid,
|
|
sender_id: Option<uuid::Uuid>,
|
|
project_id: uuid::Uuid,
|
|
message_producer: Option<MessageProducer>,
|
|
ai_model_id: Option<uuid::Uuid>,
|
|
ai_model_name: Option<String>,
|
|
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
|
|
) -> Self {
|
|
let mut toolset = ToolSet::default();
|
|
let mut definitions = HashMap::new();
|
|
|
|
for name in registry.definitions().map(|d| d.name.clone()).collect::<Vec<_>>() {
|
|
let def = registry.definitions().find(|d| d.name == name).cloned().unwrap_or_else(|| {
|
|
AgentToolDefinition::new(&name)
|
|
});
|
|
definitions.insert(name.clone(), def.clone());
|
|
|
|
let handler = registry.get(&name).cloned();
|
|
if let Some(handler) = handler {
|
|
let adapter = RigToolAdapter {
|
|
handler,
|
|
definition: def,
|
|
db: db.clone(),
|
|
cache: cache.clone(),
|
|
config: config.clone(),
|
|
room_id,
|
|
sender_id,
|
|
project_id,
|
|
message_producer: message_producer.clone(),
|
|
ai_model_id,
|
|
ai_model_name: ai_model_name.clone(),
|
|
sent_in_turn: sent_in_turn.clone(),
|
|
};
|
|
toolset.add_tool(adapter);
|
|
}
|
|
}
|
|
|
|
Self { inner: toolset, definitions }
|
|
}
|
|
|
|
/// Get the inner rig ToolSet
|
|
pub fn inner(&self) -> &ToolSet {
|
|
&self.inner
|
|
}
|
|
|
|
/// Get the tool definitions
|
|
pub fn definitions(&self) -> &HashMap<String, AgentToolDefinition> {
|
|
&self.definitions
|
|
}
|
|
|
|
/// Convert to JSON tool definitions for non-rig paths
|
|
pub fn to_openai_tools(&self) -> Vec<serde_json::Value> {
|
|
self.definitions.values()
|
|
.map(|d| d.to_openai_tool())
|
|
.collect()
|
|
}
|
|
}
|
|
|
|
/// Adapter that wraps our ToolHandler to implement rig's ToolDyn.
|
|
pub struct RigToolAdapter {
|
|
handler: ToolHandler,
|
|
definition: AgentToolDefinition,
|
|
db: db::database::AppDatabase,
|
|
cache: db::cache::AppCache,
|
|
config: config::AppConfig,
|
|
room_id: uuid::Uuid,
|
|
sender_id: Option<uuid::Uuid>,
|
|
project_id: uuid::Uuid,
|
|
message_producer: Option<MessageProducer>,
|
|
ai_model_id: Option<uuid::Uuid>,
|
|
ai_model_name: Option<String>,
|
|
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
|
|
}
|
|
|
|
impl RigToolAdapter {
|
|
/// Create a new RigToolAdapter with all required context.
|
|
pub fn new(
|
|
handler: ToolHandler,
|
|
definition: AgentToolDefinition,
|
|
db: db::database::AppDatabase,
|
|
cache: db::cache::AppCache,
|
|
config: config::AppConfig,
|
|
room_id: uuid::Uuid,
|
|
sender_id: Option<uuid::Uuid>,
|
|
project_id: uuid::Uuid,
|
|
message_producer: Option<MessageProducer>,
|
|
ai_model_id: Option<uuid::Uuid>,
|
|
ai_model_name: Option<String>,
|
|
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
|
|
) -> Self {
|
|
Self { handler, definition, db, cache, config, room_id, sender_id, project_id, message_producer, ai_model_id, ai_model_name, sent_in_turn }
|
|
}
|
|
}
|
|
|
|
impl ToolDyn for RigToolAdapter {
|
|
fn name(&self) -> String {
|
|
self.definition.name.clone()
|
|
}
|
|
|
|
fn definition<'a>(&'a self, _prompt: String) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolDefinition> + Send + 'a>> {
|
|
let def = self.definition.clone();
|
|
Box::pin(async move {
|
|
ToolDefinition {
|
|
name: def.name.clone(),
|
|
description: def.description.unwrap_or_default(),
|
|
parameters: def.parameters
|
|
.as_ref()
|
|
.map(|p| serde_json::to_value(p).unwrap_or(serde_json::json!({})))
|
|
.unwrap_or(serde_json::json!({})),
|
|
}
|
|
})
|
|
}
|
|
|
|
fn call<'a>(&'a self, args: String) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String, ToolError>> + Send + 'a>> {
|
|
let handler = self.handler.clone();
|
|
let db = self.db.clone();
|
|
let cache = self.cache.clone();
|
|
let config = self.config.clone();
|
|
let room_id = self.room_id;
|
|
let sender_id = self.sender_id;
|
|
let project_id = self.project_id;
|
|
let message_producer = self.message_producer.clone();
|
|
let ai_model_id = self.ai_model_id;
|
|
let ai_model_name = self.ai_model_name.clone();
|
|
let sent_in_turn = self.sent_in_turn.clone();
|
|
|
|
async move {
|
|
let mut ctx = ToolContext::new(
|
|
db,
|
|
cache,
|
|
config,
|
|
room_id,
|
|
sender_id,
|
|
)
|
|
.with_project(project_id)
|
|
.with_sent_in_turn(sent_in_turn);
|
|
if let Some(mp) = message_producer {
|
|
ctx = ctx.with_message_producer(mp);
|
|
}
|
|
if let Some(mid) = ai_model_id {
|
|
ctx = ctx.with_ai_model(mid, ai_model_name.unwrap_or_default());
|
|
}
|
|
|
|
let args_json: serde_json::Value = serde_json::from_str(&args)
|
|
.map_err(|e| ToolError::JsonError(e))?;
|
|
|
|
let result = handler.execute(ctx, args_json).await;
|
|
|
|
match result {
|
|
Ok(value) => {
|
|
serde_json::to_string(&value)
|
|
.map_err(|e| ToolError::JsonError(e))
|
|
}
|
|
Err(e) => {
|
|
let error_msg = match e {
|
|
super::call::ToolError::NotFound(n) => n,
|
|
super::call::ToolError::ParseError(p) => p,
|
|
super::call::ToolError::ExecutionError(e) => e,
|
|
super::call::ToolError::RecursionLimitExceeded { max_depth } => {
|
|
format!("recursion limit exceeded (max depth: {})", max_depth)
|
|
}
|
|
super::call::ToolError::MaxToolCallsExceeded(n) => {
|
|
format!("max tool calls exceeded: {}", n)
|
|
}
|
|
super::call::ToolError::Internal(i) => i,
|
|
};
|
|
Err(ToolError::ToolCallError(Box::new(std::io::Error::new(
|
|
std::io::ErrorKind::Other,
|
|
error_msg,
|
|
))))
|
|
}
|
|
}
|
|
}.boxed()
|
|
}
|
|
}
|