gitdataai/lib/ai/agent/tool.rs
2026-05-30 01:38:40 +08:00

159 lines
3.7 KiB
Rust

use std::pin::Pin;
use std::sync::Arc;
use rig::completion::ToolDefinition as RigToolDefinition;
use rig::tool::ToolDyn;
use serde_json::Value;
use tokio::sync::Mutex;
use crate::tool::tools::FunctionCall;
/// Adapter that exposes one crate-level [`FunctionCall`] tool through rig's
/// [`ToolDyn`] interface.
///
/// The tool's name, description and parameter schema are read once at
/// construction time and cached, so advertising the tool to rig never calls
/// back into the underlying `FunctionCall`.
pub struct RigTool<C>
where
    C: Clone + Send + Sync + 'static,
{
    /// The wrapped tool implementation that is invoked on every call.
    tool: Arc<dyn FunctionCall<Context = C>>,
    /// Context shared by every tool of a set; locked for the duration of a call.
    context: Arc<Mutex<C>>,
    /// Cached copy of `tool.name()`.
    name: String,
    /// Cached copy of `tool.description()`.
    description: String,
    /// Cached copy of `tool.schema()`, advertised as the tool's parameters.
    schema: Value,
}
impl<C> RigTool<C>
where
    C: Clone + Send + Sync + 'static,
{
    /// Wraps `tool` so that rig can call it, sharing `context` with any other
    /// wrappers built from the same `Arc`.
    ///
    /// The tool's name, description and schema are captured here and cached.
    pub fn new(tool: Arc<dyn FunctionCall<Context = C>>, context: Arc<Mutex<C>>) -> Self {
        // Struct-literal fields are evaluated in the order written. The metadata
        // accessors therefore run (name, description, schema) before `tool`
        // is moved into the struct.
        Self {
            name: tool.name().to_string(),
            description: tool.description().to_string(),
            schema: tool.schema(),
            context,
            tool,
        }
    }
}
impl<C> ToolDyn for RigTool<C>
where
C: Clone + Send + Sync + 'static,
{
fn name(&self) -> String {
self.name.clone()
}
fn definition<'a>(
&'a self,
_prompt: String,
) -> Pin<Box<dyn std::future::Future<Output = RigToolDefinition> + Send + 'a>> {
let name = self.name.clone();
let description = self.description.clone();
let params = self.schema.clone();
Box::pin(async move {
RigToolDefinition {
name,
description,
parameters: params,
}
})
}
fn call<'a>(
&'a self,
args: String,
) -> Pin<
Box<dyn std::future::Future<Output = Result<String, rig::tool::ToolError>> + Send + 'a>,
> {
let tool = self.tool.clone();
let context = self.context.clone();
Box::pin(async move {
let args_value: Value =
serde_json::from_str(&args).map_err(rig::tool::ToolError::JsonError)?;
let mut ctx = context.lock().await;
match tool.call(&mut *ctx, args_value).await {
Ok(value) => serde_json::to_string(&value)
.map_err(rig::tool::ToolError::JsonError),
Err(ai_err) => Err(rig::tool::ToolError::ToolCallError(Box::new(
std::io::Error::other(ai_err.to_string()),
))),
}
})
}
}
/// A collection of rig-compatible tools that all share one context value.
///
/// Built either empty (`new`/`default`, no context) or from a tool register
/// together with the shared context.
pub struct RigToolSet<C>
where
    C: Clone + Send + Sync + 'static,
{
    /// The shared context, if this set was built with one.
    context: Option<Arc<Mutex<C>>>,
    /// Type-erased rig tools; may be moved out with `take_tools`.
    tools: Vec<Box<dyn ToolDyn + 'static>>,
}
impl<C> RigToolSet<C>
where
    C: Clone + Send + Sync + 'static,
{
    /// Creates an empty tool set with no tools and no shared context.
    pub fn new() -> Self {
        Self {
            tools: Vec::new(),
            context: None,
        }
    }

    /// Wraps every tool in `register` as a [`RigTool`] that shares `context`.
    ///
    /// Each created tool holds its own `Arc` clone of `context`. The set keeps
    /// one more clone so the context can later be retrieved via `context` or
    /// `into_context`.
    pub fn from_register(
        register: &crate::tool::register::ToolRegister<C>,
        context: Arc<Mutex<C>>,
    ) -> Self {
        let mut tools: Vec<Box<dyn ToolDyn + 'static>> = Vec::with_capacity(register.len());
        for tool_arc in &register.tools {
            tools.push(Box::new(RigTool::new(tool_arc.clone(), context.clone())));
        }
        Self {
            tools,
            context: Some(context),
        }
    }

    /// Returns `true` if the set currently holds no tools.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Returns the number of tools currently held (0 after `take_tools`).
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns the shared context, or `None` for a set built with `new`/`default`.
    pub fn context(&self) -> Option<&Arc<Mutex<C>>> {
        self.context.as_ref()
    }

    /// Moves all tools out of the set, leaving it empty. The context is kept.
    pub fn take_tools(&mut self) -> Vec<Box<dyn ToolDyn + 'static>> {
        std::mem::take(&mut self.tools)
    }

    /// Consumes the set and returns the shared context value.
    ///
    /// If this set holds the only reference to the context, the value is moved
    /// out without cloning. Otherwise a clone of the current value is returned.
    /// This is the case, for example, when tools obtained from `take_tools` are
    /// still alive, because each of them owns an `Arc` to the context. Changes
    /// made through those tools afterwards are not reflected in the returned
    /// snapshot.
    ///
    /// # Panics
    /// - If the set has no context (it was built with `new`/`default`).
    /// - If the context is shared and is currently locked, for example by an
    ///   in-flight tool call, so no consistent snapshot can be taken.
    pub fn into_context(mut self) -> C {
        let shared = self.context.take().expect(
            "RigToolSet::into_context: tool set has no context (built via `new`/`default`)",
        );
        match Arc::try_unwrap(shared) {
            // Sole owner: move the value out of the mutex without cloning.
            Ok(mutex) => mutex.into_inner(),
            // Other holders (e.g. live `RigTool`s) exist: fall back to a snapshot.
            // This relies on the `C: Clone` bound instead of treating the
            // situation as unreachable.
            Err(shared) => shared
                .try_lock()
                .map(|guard| (*guard).clone())
                .expect("RigToolSet::into_context: shared context is locked by an in-flight tool call"),
        }
    }
}
impl<C> Default for RigToolSet<C>
where
C: Clone + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}