use std::pin::Pin; use std::sync::Arc; use rig::completion::ToolDefinition as RigToolDefinition; use rig::tool::ToolDyn; use serde_json::Value; use tokio::sync::Mutex; use crate::tool::tools::FunctionCall; pub struct RigTool where C: Clone + Send + Sync + 'static, { context: Arc>, tool: Arc>, name: String, description: String, schema: Value, } impl RigTool where C: Clone + Send + Sync + 'static, { pub fn new(tool: Arc>, context: Arc>) -> Self { let name = tool.name().to_string(); let description = tool.description().to_string(); let schema = tool.schema(); Self { context, tool, name, description, schema, } } } impl ToolDyn for RigTool where C: Clone + Send + Sync + 'static, { fn name(&self) -> String { self.name.clone() } fn definition<'a>( &'a self, _prompt: String, ) -> Pin + Send + 'a>> { let name = self.name.clone(); let description = self.description.clone(); let params = self.schema.clone(); Box::pin(async move { RigToolDefinition { name, description, parameters: params, } }) } fn call<'a>( &'a self, args: String, ) -> Pin< Box> + Send + 'a>, > { let tool = self.tool.clone(); let context = self.context.clone(); Box::pin(async move { let args_value: Value = serde_json::from_str(&args).map_err(rig::tool::ToolError::JsonError)?; let mut ctx = context.lock().await; match tool.call(&mut *ctx, args_value).await { Ok(value) => serde_json::to_string(&value) .map_err(rig::tool::ToolError::JsonError), Err(ai_err) => Err(rig::tool::ToolError::ToolCallError(Box::new( std::io::Error::other(ai_err.to_string()), ))), } }) } } pub struct RigToolSet where C: Clone + Send + Sync + 'static, { tools: Vec>, context: Option>>, } impl RigToolSet where C: Clone + Send + Sync + 'static, { pub fn new() -> Self { Self { tools: Vec::new(), context: None, } } pub fn from_register( register: &crate::tool::register::ToolRegister, context: Arc>, ) -> Self { let mut tools: Vec> = Vec::with_capacity(register.len()); for tool_arc in ®ister.tools { tools.push(Box::new(RigTool::new(tool_arc.clone(), context.clone()))); } Self { tools, context: Some(context), } } pub fn is_empty(&self) -> bool { self.tools.is_empty() } pub fn len(&self) -> usize { self.tools.len() } pub fn context(&self) -> Option<&Arc>> { self.context.as_ref() } pub fn take_tools(&mut self) -> Vec> { std::mem::take(&mut self.tools) } pub fn into_context(mut self) -> C { self.context .take() .and_then(|arc| Arc::try_unwrap(arc).ok().map(|m| m.into_inner())) .unwrap_or_else(|| unreachable!("context must be available")) } } impl Default for RigToolSet where C: Clone + Send + Sync + 'static, { fn default() -> Self { Self::new() } }