//! Adapter to bridge our ToolRegistry with rig's Tool system. //! //! This module provides adapters that wrap our custom ToolHandler/Registry //! to implement rig's ToolDyn trait, enabling integration with rig's Agent. use std::collections::HashMap; use std::time::{Duration, Instant}; use futures::FutureExt; use rig::completion::ToolDefinition; use rig::tool::{ToolDyn, ToolError, ToolSet}; use super::context::ToolContext; use super::definition::ToolDefinition as AgentToolDefinition; use super::recorder::{ToolCallRecord, ToolCallRecorder}; use super::registry::{ToolHandler, ToolRegistry}; use queue::MessageProducer; /// Returns true if the tool error message indicates a transient failure that can be retried. pub fn is_retryable_tool_error(msg: &str) -> bool { let lower = msg.to_lowercase(); lower.contains("retry") || lower.contains("timeout") || lower.contains("rate limit") || lower.contains("too many requests") || lower.contains("unavailable") || lower.contains("connection refused") || lower.contains("5") || lower.contains("try again") } /// Wraps a ToolDyn with automatic retry and tool call recording. /// /// Used by the rig Agent path to replace the custom ReAct executor closure. 
///
/// A call is attempted up to `max_retries + 1` (= 4) times. A failed attempt is retried
/// only when `is_retryable_tool_error` classifies its message as transient, sleeping
/// 100 ms, 200 ms, then 400 ms between attempts. `recorder.record` is invoked exactly
/// once per call: on the first success, or on the first failure that is not retried.
pub struct RecordingTool {
    /// The wrapped rig tool whose calls are retried and recorded.
    inner: Box,
    /// Database handle used to build the `ToolCallRecorder` for each call.
    db: db::database::AppDatabase,
    /// Session every `ToolCallRecord` is attributed to.
    session_id: uuid::Uuid,
    /// Stored in the `caller` field of every `ToolCallRecord`.
    caller: uuid::Uuid,
}

impl RecordingTool {
    /// Wraps `inner` so its calls are retried and recorded under `session_id`,
    /// attributed to `caller`.
    pub fn new(
        inner: Box,
        db: db::database::AppDatabase,
        session_id: uuid::Uuid,
        caller: uuid::Uuid,
    ) -> Self {
        Self { inner, db, session_id, caller }
    }
}

impl ToolDyn for RecordingTool {
    /// Delegates to the wrapped tool's name.
    fn name(&self) -> String {
        self.inner.name()
    }

    /// Delegates to the wrapped tool's definition; wrapping does not alter the schema.
    fn definition<'a>(
        &'a self,
        prompt: String,
    ) -> std::pin::Pin + Send + 'a>> {
        self.inner.definition(prompt)
    }

    /// Calls the wrapped tool, retrying transient failures and recording the outcome.
    ///
    /// Returns the wrapped tool's `Ok` value on success, or the wrapped tool's error from
    /// the last attempt made otherwise.
    fn call<'a>(
        &'a self,
        args: String,
    ) -> std::pin::Pin> + Send + 'a>> {
        // Borrowed for 'a; the owned values below are moved into the future.
        let inner: &'a Box = &self.inner;
        let db = self.db.clone();
        let session_id = self.session_id;
        let caller = self.caller;
        let tool_name = inner.name();
        Box::pin(async move {
            let recorder = ToolCallRecorder::with_session(db.clone(), session_id);
            // Number of retries after the first attempt.
            let max_retries = 3u32;
            // NOTE(review): only read by the fallback after the loop, which is unreachable.
            let mut last_err = String::new();
            let start = Instant::now();
            for attempt in 0..=max_retries {
                let attempt_start = Instant::now();
                let attempt_args = args.clone();
                let attempt_result = inner.call(attempt_args).await;
                // Duration of this attempt only (excludes earlier attempts and backoff sleeps).
                let elapsed_ms = attempt_start.elapsed().as_millis() as i64;
                // Unparsable arguments are recorded as `Value::default()` (JSON null).
                // NOTE(review): `args` never changes, so this parse could be hoisted out of the loop.
                let args_json: serde_json::Value = serde_json::from_str(&args).unwrap_or_default();
                match attempt_result {
                    Ok(value) => {
                        recorder.record(ToolCallRecord {
                            // NOTE(review): the tool *name* is stored as the call id, so repeated
                            // calls to the same tool share one id — confirm whether a unique
                            // per-call id is expected here.
                            tool_call_id: tool_name.clone(),
                            session_id,
                            tool_name: tool_name.clone(),
                            caller,
                            arguments: args_json,
                            status: models::ai::ToolCallStatus::Success,
                            execution_time_ms: Some(elapsed_ms),
                            error_message: None,
                            error_stack: None,
                            // Number of failed attempts that preceded this one.
                            retry_count: attempt as i32,
                        });
                        return Ok(value);
                    }
                    Err(e) => {
                        let err_msg = e.to_string();
                        // Retry only transient errors, and only while attempts remain.
                        if attempt < max_retries && is_retryable_tool_error(&err_msg) {
                            last_err = err_msg;
                            // Exponential backoff: 100 ms, 200 ms, 400 ms.
                            let backoff_ms = 100u64.saturating_mul(2u64.pow(attempt as u32));
                            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
                            continue;
                        }
                        // Non-retryable error, or the final attempt failed: record and surface it.
                        recorder.record(ToolCallRecord {
                            tool_call_id: tool_name.clone(),
                            session_id,
                            tool_name: tool_name.clone(),
                            caller,
                            arguments: args_json,
                            status: models::ai::ToolCallStatus::Failed,
                            execution_time_ms: Some(elapsed_ms),
                            error_message: Some(err_msg.clone()),
                            error_stack: None,
                            retry_count: attempt as i32,
                        });
                        return Err(e);
                    }
                }
            }
            // Fallback: record failure after all retries exhausted
            // NOTE(review): unreachable in practice — on the final attempt
            // (`attempt == max_retries`) every branch above returns, so the loop never
            // completes normally and this code never runs.
            let elapsed_ms = start.elapsed().as_millis() as i64;
            let args_json: serde_json::Value = serde_json::from_str(&args).unwrap_or_default();
            recorder.record(ToolCallRecord {
                tool_call_id: tool_name.clone(),
                session_id,
                tool_name: tool_name.clone(),
                caller,
                arguments: args_json,
                status: models::ai::ToolCallStatus::Failed,
                execution_time_ms: Some(elapsed_ms),
                error_message: Some(last_err),
                error_stack: None,
                retry_count: max_retries as i32,
            });
            Err(ToolError::ToolCallError(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                "max retries exceeded",
            ))))
        })
    }
}

/// A wrapper that converts our ToolRegistry to rig's ToolSet.
pub struct RigToolSet {
    /// The rig ToolSet, holding one `RigToolAdapter` per tool that has a registered handler.
    inner: ToolSet,
    /// Every registry definition keyed by tool name, including tools without a handler.
    definitions: HashMap,
}

impl RigToolSet {
    /// Create a new RigToolSet from our ToolRegistry.
    ///
    /// Every definition in `registry` is copied into `definitions`. Each definition whose
    /// name also resolves via `registry.get` is wrapped in a `RigToolAdapter` and added to
    /// the rig `ToolSet`. The remaining arguments are the context each adapter uses to
    /// build a `ToolContext` per call; they are cloned/copied into every adapter.
    pub fn from_registry(
        registry: &ToolRegistry,
        db: db::database::AppDatabase,
        cache: db::cache::AppCache,
        config: config::AppConfig,
        room_id: uuid::Uuid,
        sender_id: Option,
        project_id: uuid::Uuid,
        message_producer: Option,
        ai_model_id: Option,
        ai_model_name: Option,
        sent_in_turn: std::sync::Arc>>,
    ) -> Self {
        let mut toolset = ToolSet::default();
        let mut definitions = HashMap::new();
        // NOTE(review): collecting the names and then re-`find`ing each definition is
        // O(n^2) in the number of tools. Because every name comes from
        // `registry.definitions()`, the `AgentToolDefinition::new` fallback looks
        // unreachable. Iterating the definitions directly would be simpler.
        for name in registry.definitions().map(|d| d.name.clone()).collect::>() {
            let def = registry.definitions().find(|d| d.name == name).cloned().unwrap_or_else(|| {
                AgentToolDefinition::new(&name)
            });
            // Inserted even when no handler exists below, so `definitions()` and
            // `to_openai_tools()` can list tools that the rig `ToolSet` cannot call.
            definitions.insert(name.clone(), def.clone());
            let handler = registry.get(&name).cloned();
            if let Some(handler) = handler {
                let adapter = RigToolAdapter {
                    handler,
                    definition: def,
                    db: db.clone(),
                    cache: cache.clone(),
                    config: config.clone(),
                    room_id,
                    sender_id,
                    project_id,
                    message_producer: message_producer.clone(),
                    ai_model_id,
                    ai_model_name: ai_model_name.clone(),
                    sent_in_turn: sent_in_turn.clone(),
                };
                toolset.add_tool(adapter);
            }
        }
        Self { inner: toolset, definitions }
    }

    /// Get the inner rig ToolSet
    pub fn inner(&self) -> &ToolSet {
        &self.inner
    }

    /// Get the tool definitions, keyed by tool name
    pub fn definitions(&self) -> &HashMap {
        &self.definitions
    }

    /// Convert to JSON tool definitions for non-rig paths.
    ///
    /// The order follows `HashMap` iteration and is therefore unspecified.
    pub fn to_openai_tools(&self) -> Vec {
        self.definitions.values()
            .map(|d| d.to_openai_tool())
            .collect()
    }
}

/// Adapter that wraps our ToolHandler to implement rig's ToolDyn.
///
/// Besides the handler and its definition, it stores the context used to build a
/// `ToolContext` for every call.
pub struct RigToolAdapter {
    /// Handler whose `execute` runs the tool.
    handler: ToolHandler,
    /// Our definition; its name is the tool name reported to rig.
    definition: AgentToolDefinition,
    db: db::database::AppDatabase,
    cache: db::cache::AppCache,
    config: config::AppConfig,
    room_id: uuid::Uuid,
    sender_id: Option,
    /// Applied via `ToolContext::with_project`.
    project_id: uuid::Uuid,
    /// Applied via `ToolContext::with_message_producer` when present.
    message_producer: Option,
    /// Applied via `ToolContext::with_ai_model` when present, together with `ai_model_name`.
    ai_model_id: Option,
    ai_model_name: Option,
    /// Shared handle applied via `ToolContext::with_sent_in_turn`.
    sent_in_turn: std::sync::Arc>>,
}

impl RigToolAdapter {
    /// Create a new RigToolAdapter with all required context.
    pub fn new(
        handler: ToolHandler,
        definition: AgentToolDefinition,
        db: db::database::AppDatabase,
        cache: db::cache::AppCache,
        config: config::AppConfig,
        room_id: uuid::Uuid,
        sender_id: Option,
        project_id: uuid::Uuid,
        message_producer: Option,
        ai_model_id: Option,
        ai_model_name: Option,
        sent_in_turn: std::sync::Arc>>,
    ) -> Self {
        Self { handler, definition, db, cache, config, room_id, sender_id, project_id, message_producer, ai_model_id, ai_model_name, sent_in_turn }
    }
}

impl ToolDyn for RigToolAdapter {
    /// The tool name, taken from the stored definition.
    fn name(&self) -> String {
        self.definition.name.clone()
    }

    /// Builds rig's `ToolDefinition` from the stored definition; the prompt is ignored.
    ///
    /// A missing description becomes an empty string. Missing parameters, or parameters
    /// that fail to serialize, become `{}`.
    fn definition<'a>(&'a self, _prompt: String) -> std::pin::Pin + Send + 'a>> {
        let def = self.definition.clone();
        Box::pin(async move {
            ToolDefinition {
                name: def.name.clone(),
                description: def.description.unwrap_or_default(),
                parameters: def.parameters
                    .as_ref()
                    .map(|p| serde_json::to_value(p).unwrap_or(serde_json::json!({})))
                    .unwrap_or(serde_json::json!({})),
            }
        })
    }

    /// Runs the handler with a fresh `ToolContext` and returns its result as a JSON string.
    ///
    /// Errors:
    /// - `args` that are not valid JSON yield `ToolError::JsonError` before the handler runs;
    /// - a handler result that fails to serialize yields `ToolError::JsonError`;
    /// - a handler error is flattened to its message and returned as
    ///   `ToolError::ToolCallError` wrapping an `io::Error` of kind `Other`.
    fn call<'a>(&'a self, args: String) -> std::pin::Pin> + Send + 'a>> {
        // Clone/copy everything the future needs so it owns its context.
        let handler = self.handler.clone();
        let db = self.db.clone();
        let cache = self.cache.clone();
        let config = self.config.clone();
        let room_id = self.room_id;
        let sender_id = self.sender_id;
        let project_id = self.project_id;
        let message_producer = self.message_producer.clone();
        let ai_model_id = self.ai_model_id;
        let ai_model_name = self.ai_model_name.clone();
        let sent_in_turn = self.sent_in_turn.clone();
        async move {
            let mut ctx = ToolContext::new(
                db,
                cache,
                config,
                room_id,
                sender_id,
            )
            .with_project(project_id)
            .with_sent_in_turn(sent_in_turn);
            if let Some(mp) = message_producer {
                ctx = ctx.with_message_producer(mp);
            }
            // The model name is attached only when a model id is set; it defaults to empty.
            if let Some(mid) = ai_model_id {
                ctx = ctx.with_ai_model(mid, ai_model_name.unwrap_or_default());
            }
            let args_json: serde_json::Value = serde_json::from_str(&args)
                .map_err(|e| ToolError::JsonError(e))?;
            let result = handler.execute(ctx, args_json).await;
            match result {
                Ok(value) => {
                    serde_json::to_string(&value)
                        .map_err(|e| ToolError::JsonError(e))
                }
                Err(e) => {
                    // NOTE(review): for NotFound / ParseError / ExecutionError / Internal only
                    // the payload string is kept, so the variant kind is lost. If this adapter
                    // is wrapped by `RecordingTool`, its retry decision sees only this text.
                    let error_msg = match e {
                        super::call::ToolError::NotFound(n) => n,
                        super::call::ToolError::ParseError(p) => p,
                        super::call::ToolError::ExecutionError(e) => e,
                        super::call::ToolError::RecursionLimitExceeded { max_depth } => {
                            format!("recursion limit exceeded (max depth: {})", max_depth)
                        }
                        super::call::ToolError::MaxToolCallsExceeded(n) => {
                            format!("max tool calls exceeded: {}", n)
                        }
                        super::call::ToolError::Internal(i) => i,
                    };
                    Err(ToolError::ToolCallError(Box::new(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        error_msg,
                    ))))
                }
            }
        }.boxed()
    }
}