172 lines
3.9 KiB
Rust
172 lines
3.9 KiB
Rust
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
|
|
use rig::completion::ToolDefinition as RigToolDefinition;
|
|
use rig::tool::ToolDyn;
|
|
use serde_json::Value;
|
|
use tokio::sync::Mutex;
|
|
|
|
use crate::tool::tools::FunctionCall;
|
|
|
|
pub struct RigTool<C>
|
|
where
|
|
C: Clone + Send + Sync + 'static,
|
|
{
|
|
context: Arc<Mutex<C>>,
|
|
tool: Arc<dyn FunctionCall<Context = C>>,
|
|
name: String,
|
|
description: String,
|
|
schema: Value,
|
|
}
|
|
|
|
impl<C> RigTool<C>
|
|
where
|
|
C: Clone + Send + Sync + 'static,
|
|
{
|
|
pub fn new(
|
|
tool: Arc<dyn FunctionCall<Context = C>>,
|
|
context: Arc<Mutex<C>>,
|
|
) -> Self {
|
|
let name = tool.name().to_string();
|
|
let description = tool.description().to_string();
|
|
let schema = tool.schema();
|
|
|
|
Self {
|
|
context,
|
|
tool,
|
|
name,
|
|
description,
|
|
schema,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<C> ToolDyn for RigTool<C>
|
|
where
|
|
C: Clone + Send + Sync + 'static,
|
|
{
|
|
fn name(&self) -> String {
|
|
self.name.clone()
|
|
}
|
|
|
|
fn definition<'a>(
|
|
&'a self,
|
|
_prompt: String,
|
|
) -> Pin<Box<dyn std::future::Future<Output = RigToolDefinition> + Send + 'a>>
|
|
{
|
|
let name = self.name.clone();
|
|
let description = self.description.clone();
|
|
let params = self.schema.clone();
|
|
|
|
Box::pin(async move {
|
|
RigToolDefinition {
|
|
name,
|
|
description,
|
|
parameters: params,
|
|
}
|
|
})
|
|
}
|
|
|
|
fn call<'a>(
|
|
&'a self,
|
|
args: String,
|
|
) -> Pin<
|
|
Box<
|
|
dyn std::future::Future<
|
|
Output = Result<String, rig::tool::ToolError>,
|
|
> + Send
|
|
+ 'a,
|
|
>,
|
|
> {
|
|
let tool = self.tool.clone();
|
|
let context = self.context.clone();
|
|
|
|
Box::pin(async move {
|
|
let args_value: Value = serde_json::from_str(&args)
|
|
.map_err(rig::tool::ToolError::JsonError)?;
|
|
|
|
let mut ctx = context.lock().await;
|
|
|
|
match tool.call(&mut *ctx, args_value).await {
|
|
Ok(value) => serde_json::to_string(&value)
|
|
.map_err(rig::tool::ToolError::JsonError),
|
|
Err(ai_err) => Err(rig::tool::ToolError::ToolCallError(
|
|
Box::new(std::io::Error::other(ai_err.to_string())),
|
|
)),
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
pub struct RigToolSet<C>
|
|
where
|
|
C: Clone + Send + Sync + 'static,
|
|
{
|
|
tools: Vec<Box<dyn ToolDyn + 'static>>,
|
|
context: Option<Arc<Mutex<C>>>,
|
|
}
|
|
|
|
impl<C> RigToolSet<C>
|
|
where
|
|
C: Clone + Send + Sync + 'static,
|
|
{
|
|
pub fn new() -> Self {
|
|
Self {
|
|
tools: Vec::new(),
|
|
context: None,
|
|
}
|
|
}
|
|
|
|
pub fn from_register(
|
|
register: &crate::tool::register::ToolRegister<C>,
|
|
context: Arc<Mutex<C>>,
|
|
) -> Self {
|
|
let mut tools: Vec<Box<dyn ToolDyn + 'static>> =
|
|
Vec::with_capacity(register.len());
|
|
|
|
for tool_arc in ®ister.tools {
|
|
tools.push(Box::new(RigTool::new(
|
|
tool_arc.clone(),
|
|
context.clone(),
|
|
)));
|
|
}
|
|
|
|
Self {
|
|
tools,
|
|
context: Some(context),
|
|
}
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.tools.is_empty()
|
|
}
|
|
|
|
pub fn len(&self) -> usize {
|
|
self.tools.len()
|
|
}
|
|
|
|
pub fn context(&self) -> Option<&Arc<Mutex<C>>> {
|
|
self.context.as_ref()
|
|
}
|
|
|
|
pub fn take_tools(&mut self) -> Vec<Box<dyn ToolDyn + 'static>> {
|
|
std::mem::take(&mut self.tools)
|
|
}
|
|
|
|
pub fn into_context(mut self) -> C {
|
|
self.context
|
|
.take()
|
|
.and_then(|arc| Arc::try_unwrap(arc).ok().map(|m| m.into_inner()))
|
|
.unwrap_or_else(|| unreachable!("context must be available"))
|
|
}
|
|
}
|
|
|
|
impl<C> Default for RigToolSet<C>
|
|
where
|
|
C: Clone + Send + Sync + 'static,
|
|
{
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|