//! Adapter to bridge our ToolRegistry with rig's Tool system. //! //! This module provides adapters that wrap our custom ToolHandler/Registry //! to implement rig's ToolDyn trait, enabling integration with rig's Agent. use std::collections::HashMap; use futures::FutureExt; use rig::completion::ToolDefinition; use rig::tool::{ToolDyn, ToolError, ToolSet}; use super::context::ToolContext; use super::definition::ToolDefinition as AgentToolDefinition; use super::registry::{ToolHandler, ToolRegistry}; /// A wrapper that converts our ToolRegistry to rig's ToolSet. pub struct RigToolSet { /// The rig ToolSet inner: ToolSet, /// Tool definitions for converting back definitions: HashMap, } impl RigToolSet { /// Create a new RigToolSet from our ToolRegistry. pub fn from_registry( registry: &ToolRegistry, db: db::database::AppDatabase, cache: db::cache::AppCache, config: config::AppConfig, room_id: uuid::Uuid, sender_id: Option, ) -> Self { let mut toolset = ToolSet::default(); let mut definitions = HashMap::new(); for name in registry.definitions().map(|d| d.name.clone()).collect::>() { let def = registry.definitions().find(|d| d.name == name).cloned().unwrap_or_else(|| { AgentToolDefinition::new(&name) }); definitions.insert(name.clone(), def.clone()); let handler = registry.get(&name).cloned(); if let Some(handler) = handler { let adapter = RigToolAdapter { handler, definition: def, db: db.clone(), cache: cache.clone(), config: config.clone(), room_id, sender_id, }; toolset.add_tool(adapter); } } Self { inner: toolset, definitions } } /// Get the inner rig ToolSet pub fn inner(&self) -> &ToolSet { &self.inner } /// Get the tool definitions pub fn definitions(&self) -> &HashMap { &self.definitions } /// Convert to JSON tool definitions for non-rig paths pub fn to_openai_tools(&self) -> Vec { self.definitions.values() .map(|d| d.to_openai_tool()) .collect() } } /// Adapter that wraps our ToolHandler to implement rig's ToolDyn. pub struct RigToolAdapter { handler: ToolHandler, definition: AgentToolDefinition, db: db::database::AppDatabase, cache: db::cache::AppCache, config: config::AppConfig, room_id: uuid::Uuid, sender_id: Option, } impl ToolDyn for RigToolAdapter { fn name(&self) -> String { self.definition.name.clone() } fn definition<'a>(&'a self, _prompt: String) -> std::pin::Pin + Send + 'a>> { let def = self.definition.clone(); Box::pin(async move { ToolDefinition { name: def.name.clone(), description: def.description.unwrap_or_default(), parameters: def.parameters .as_ref() .map(|p| serde_json::to_value(p).unwrap_or(serde_json::json!({}))) .unwrap_or(serde_json::json!({})), } }) } fn call<'a>(&'a self, args: String) -> std::pin::Pin> + Send + 'a>> { let handler = self.handler.clone(); let db = self.db.clone(); let cache = self.cache.clone(); let config = self.config.clone(); let room_id = self.room_id; let sender_id = self.sender_id; async move { let ctx = ToolContext::new( db, cache, config, room_id, sender_id, ); let args_json: serde_json::Value = serde_json::from_str(&args) .map_err(|e| ToolError::JsonError(e))?; let result = handler.execute(ctx, args_json).await; match result { Ok(value) => { serde_json::to_string(&value) .map_err(|e| ToolError::JsonError(e)) } Err(e) => { let error_msg = match e { super::call::ToolError::NotFound(n) => n, super::call::ToolError::ParseError(p) => p, super::call::ToolError::ExecutionError(e) => e, super::call::ToolError::RecursionLimitExceeded { max_depth } => { format!("recursion limit exceeded (max depth: {})", max_depth) } super::call::ToolError::MaxToolCallsExceeded(n) => { format!("max tool calls exceeded: {}", n) } super::call::ToolError::Internal(i) => i, }; Err(ToolError::ToolCallError(Box::new(std::io::Error::new( std::io::ErrorKind::Other, error_msg, )))) } } }.boxed() } }