gitdataai/libs/agent/tool/rig_adapter.rs
ZhenYi 10c0cc007b refactor(agent): split into submodules and add Qdrant embedding
- Split agent crate into client/, model/, agent/ subdirs
- Add billing.rs for token usage recording
- Add sync.rs for upstream model sync
- EmbedService: Qdrant-backed vector memory for semantic search
- ChatService: wire EmbedService for memory lookup, passive skill awareness
- ReAct loop: streamline with tokio::select! and proper error handling
2026-04-25 20:09:33 +08:00

158 lines
5.3 KiB
Rust

//! Adapter to bridge our ToolRegistry with rig's Tool system.
//!
//! This module provides adapters that wrap our custom ToolHandler/Registry
//! to implement rig's ToolDyn trait, enabling integration with rig's Agent.
use std::collections::HashMap;
use futures::FutureExt;
use rig::completion::ToolDefinition;
use rig::tool::{ToolDyn, ToolError, ToolSet};
use super::context::ToolContext;
use super::definition::ToolDefinition as AgentToolDefinition;
use super::registry::{ToolHandler, ToolRegistry};
/// A wrapper that converts our ToolRegistry to rig's ToolSet.
///
/// Holds the rig-side tool set together with a copy of our own tool
/// definitions, so the same registry contents can be handed to rig or
/// rendered as plain JSON tool descriptions.
pub struct RigToolSet {
/// The rig ToolSet. Receives one `RigToolAdapter` for each definition whose
/// name resolves to a handler in the source registry.
inner: ToolSet,
/// Tool definitions for converting back, keyed by tool name. Populated for
/// every definition in the source registry, including ones without a handler.
definitions: HashMap<String, AgentToolDefinition>,
}
impl RigToolSet {
    /// Create a new RigToolSet from our ToolRegistry.
    ///
    /// Every definition exposed by `registry` is copied into the definitions
    /// map. A definition whose name also resolves to a handler additionally
    /// gets a [`RigToolAdapter`] added to the rig `ToolSet`; a definition
    /// without a handler is kept in the map only.
    ///
    /// `db`, `cache`, `config`, `room_id` and `sender_id` are stored in each
    /// adapter (the first three are cloned per adapter).
    pub fn from_registry(
        registry: &ToolRegistry,
        db: db::database::AppDatabase,
        cache: db::cache::AppCache,
        config: config::AppConfig,
        room_id: uuid::Uuid,
        sender_id: Option<uuid::Uuid>,
    ) -> Self {
        let mut toolset = ToolSet::default();
        let mut definitions = HashMap::new();

        // Single pass over the registry: each definition is used directly
        // instead of collecting the names and searching for it again.
        for def in registry.definitions() {
            // Should two definitions ever share a name, the first one wins and
            // later ones are ignored.
            if definitions.contains_key(&def.name) {
                continue;
            }
            definitions.insert(def.name.clone(), def.clone());

            if let Some(handler) = registry.get(&def.name).cloned() {
                toolset.add_tool(RigToolAdapter {
                    handler,
                    definition: def.clone(),
                    db: db.clone(),
                    cache: cache.clone(),
                    config: config.clone(),
                    room_id,
                    sender_id,
                });
            }
        }

        Self { inner: toolset, definitions }
    }

    /// Get the inner rig ToolSet
    pub fn inner(&self) -> &ToolSet {
        &self.inner
    }

    /// Get the tool definitions, keyed by tool name
    pub fn definitions(&self) -> &HashMap<String, AgentToolDefinition> {
        &self.definitions
    }

    /// Convert to JSON tool definitions for non-rig paths
    ///
    /// Returns one value per stored definition. The definitions live in a
    /// `HashMap`, so the order of the returned vector is unspecified.
    pub fn to_openai_tools(&self) -> Vec<serde_json::Value> {
        self.definitions
            .values()
            .map(|d| d.to_openai_tool())
            .collect()
    }
}
/// Adapter that wraps our ToolHandler to implement rig's ToolDyn.
///
/// Besides the handler and its definition, the adapter stores the values
/// needed to build a fresh `ToolContext` on every call.
pub struct RigToolAdapter {
/// Handler executed with the parsed JSON arguments when the tool is called.
handler: ToolHandler,
/// Our definition of the tool; supplies the name, description and
/// parameters reported through `ToolDyn`.
definition: AgentToolDefinition,
/// Passed to `ToolContext::new` on each call.
db: db::database::AppDatabase,
/// Passed to `ToolContext::new` on each call.
cache: db::cache::AppCache,
/// Passed to `ToolContext::new` on each call.
config: config::AppConfig,
/// Room identifier passed to `ToolContext::new` on each call.
room_id: uuid::Uuid,
/// Optional sender identifier passed to `ToolContext::new` on each call.
sender_id: Option<uuid::Uuid>,
}
impl ToolDyn for RigToolAdapter {
fn name(&self) -> String {
self.definition.name.clone()
}
fn definition<'a>(&'a self, _prompt: String) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolDefinition> + Send + 'a>> {
let def = self.definition.clone();
Box::pin(async move {
ToolDefinition {
name: def.name.clone(),
description: def.description.unwrap_or_default(),
parameters: def.parameters
.as_ref()
.map(|p| serde_json::to_value(p).unwrap_or(serde_json::json!({})))
.unwrap_or(serde_json::json!({})),
}
})
}
fn call<'a>(&'a self, args: String) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String, ToolError>> + Send + 'a>> {
let handler = self.handler.clone();
let db = self.db.clone();
let cache = self.cache.clone();
let config = self.config.clone();
let room_id = self.room_id;
let sender_id = self.sender_id;
async move {
let ctx = ToolContext::new(
db,
cache,
config,
room_id,
sender_id,
);
let args_json: serde_json::Value = serde_json::from_str(&args)
.map_err(|e| ToolError::JsonError(e))?;
let result = handler.execute(ctx, args_json).await;
match result {
Ok(value) => {
serde_json::to_string(&value)
.map_err(|e| ToolError::JsonError(e))
}
Err(e) => {
let error_msg = match e {
super::call::ToolError::NotFound(n) => n,
super::call::ToolError::ParseError(p) => p,
super::call::ToolError::ExecutionError(e) => e,
super::call::ToolError::RecursionLimitExceeded { max_depth } => {
format!("recursion limit exceeded (max depth: {})", max_depth)
}
super::call::ToolError::MaxToolCallsExceeded(n) => {
format!("max tool calls exceeded: {}", n)
}
super::call::ToolError::Internal(i) => i,
};
Err(ToolError::ToolCallError(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
error_msg,
))))
}
}
}.boxed()
}
}