// gitdataai/lib/ai/tool/toolset.rs
//
// 151 lines
// 4.2 KiB
// Rust

use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
/// A named group of tools that is enabled or disabled as a unit, optionally
/// gated on a set of environment variables being present.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Toolset {
    /// Key under which the toolset is registered and referenced in
    /// enable/disable lists; re-registering the same name replaces it.
    pub name: String,
    /// Free-form description; not used by the logic in this file.
    pub description: String,
    /// Names of the tools in this toolset, in insertion order (duplicates are
    /// not removed).
    pub tools: Vec<String>,
    /// Environment variables that must all be readable for the toolset to be
    /// considered available (checked by `Toolset::is_available`).
    pub requires_env: Vec<String>,
}
impl Toolset {
    /// Builds a toolset with the given name and description and no tools or
    /// environment requirements; chain the `with_*` methods to fill it in.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        let name = name.into();
        let description = description.into();
        Self {
            name,
            description,
            tools: vec![],
            requires_env: vec![],
        }
    }

    /// Appends a single tool name. Duplicates are kept as-is.
    pub fn with_tool(self, tool_name: impl Into<String>) -> Self {
        self.with_tools(std::iter::once(tool_name))
    }

    /// Appends every tool name yielded by `tool_names`, preserving order.
    pub fn with_tools(mut self, tool_names: impl IntoIterator<Item = impl Into<String>>) -> Self {
        for tool_name in tool_names {
            self.tools.push(tool_name.into());
        }
        self
    }

    /// Appends environment variable names that must all be set for this
    /// toolset to count as available.
    pub fn with_required_env(
        mut self,
        env_vars: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        for env_var in env_vars {
            self.requires_env.push(env_var.into());
        }
        self
    }

    /// Returns `true` when every variable in `requires_env` can be read with
    /// `std::env::var`. A variable that is unset, or whose value is not valid
    /// Unicode, makes the toolset unavailable. A toolset with no requirements
    /// is always available.
    pub fn is_available(&self) -> bool {
        self.requires_env
            .iter()
            .all(|env_var| std::env::var(env_var).is_ok())
    }

    /// Returns `true` if `tool_name` exactly (case-sensitively) matches one of
    /// this toolset's tools.
    pub fn contains(&self, tool_name: &str) -> bool {
        self.tools
            .iter()
            .map(String::as_str)
            .any(|tool| tool == tool_name)
    }
}
/// String constants naming the standard toolsets.
///
/// NOTE(review): nothing in this file references these constants; presumably
/// they are meant as `Toolset::name` values and as entries of the
/// `enabled`/`disabled` lists given to `ToolsetRegistry::resolve_tool_names`
/// — confirm with callers.
pub mod toolset_names {
    pub const CORE: &str = "core";
    pub const TERMINAL: &str = "terminal";
    pub const WEB: &str = "web";
    pub const FILE: &str = "file";
    pub const MEMORY: &str = "memory";
    pub const VISION: &str = "vision";
    pub const SEARCH: &str = "search";
    pub const BROWSER: &str = "browser";
    pub const CODE_EXECUTION: &str = "code_execution";
    pub const DELEGATION: &str = "delegation";
}
/// Registry of [`Toolset`]s keyed by toolset name, with a reverse index from
/// tool name to the toolset it is indexed under.
#[derive(Clone, Debug, Default)]
pub struct ToolsetRegistry {
    /// Registered toolsets, keyed by `Toolset::name`.
    toolsets: HashMap<String, Toolset>,
    /// Tool name -> name of the toolset it is indexed under. When several
    /// toolsets list the same tool, registration overwrites the entry, so the
    /// most recently registered one is recorded.
    tool_index: HashMap<String, String>,
}
impl ToolsetRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `toolset`, replacing any toolset previously registered under
    /// the same name.
    ///
    /// Every tool listed by `toolset` is indexed to its name. When a tool is
    /// listed by several toolsets, the most recently registered one wins.
    ///
    /// On replacement, the old version's tools may have pointed at this name
    /// even though the new version no longer lists them. Those index entries
    /// are re-pointed to another registered toolset that still lists the tool
    /// (the lexicographically smallest name, for determinism). If no other
    /// toolset lists the tool, the entry is removed. This keeps
    /// `toolset_for` and `all_tool_names` from reporting tools that no
    /// registered toolset contains.
    pub fn register(&mut self, toolset: Toolset) {
        let name = toolset.name.clone();
        // Take the old version out first so the fallback search below only
        // considers *other* toolsets.
        if let Some(previous) = self.toolsets.remove(&name) {
            for tool in previous
                .tools
                .iter()
                .filter(|t| !toolset.contains(t.as_str()))
            {
                // A different toolset claimed this tool more recently; leave it.
                if self.tool_index.get(tool) != Some(&name) {
                    continue;
                }
                let fallback = self
                    .toolsets
                    .iter()
                    .filter(|&(_, ts)| ts.contains(tool))
                    .map(|(owner, _)| owner)
                    .min()
                    .cloned();
                match fallback {
                    Some(owner) => {
                        self.tool_index.insert(tool.clone(), owner);
                    }
                    None => {
                        self.tool_index.remove(tool);
                    }
                }
            }
        }
        for tool in &toolset.tools {
            self.tool_index.insert(tool.clone(), name.clone());
        }
        self.toolsets.insert(name, toolset);
    }

    /// Looks up a registered toolset by name.
    pub fn get(&self, name: &str) -> Option<&Toolset> {
        self.toolsets.get(name)
    }

    /// Returns the name of the toolset that `tool_name` is indexed under, or
    /// `None` if no registered toolset lists it.
    pub fn toolset_for(&self, tool_name: &str) -> Option<&str> {
        self.tool_index.get(tool_name).map(String::as_str)
    }

    /// Resolves toolset selections into a sorted, de-duplicated list of tool names.
    ///
    /// - Every tool of every toolset named in `disabled` is excluded, even if
    ///   another selected toolset also lists it (deny wins).
    /// - If `enabled` is empty and `default_all` is `true`, tools come from all
    ///   registered toolsets that are not in `disabled`.
    /// - Otherwise, tools come only from the toolsets named in `enabled`.
    ///   This means an empty `enabled` with `default_all == false` yields
    ///   nothing.
    /// - Toolset names that are not registered are ignored.
    /// - Only toolsets whose `is_available()` returns `true` contribute tools.
    pub fn resolve_tool_names(
        &self,
        enabled: &[String],
        disabled: &[String],
        default_all: bool,
    ) -> Vec<String> {
        let denied: HashSet<&str> = disabled
            .iter()
            .filter_map(|ts_name| self.toolsets.get(ts_name))
            .flat_map(|ts| ts.tools.iter().map(String::as_str))
            .collect();

        let candidates: Vec<&Toolset> = if enabled.is_empty() && default_all {
            self.toolsets
                .values()
                .filter(|ts| !disabled.contains(&ts.name))
                .collect()
        } else {
            enabled
                .iter()
                .filter_map(|ts_name| self.toolsets.get(ts_name))
                .collect()
        };

        let names: HashSet<&str> = candidates
            .into_iter()
            .filter(|ts| ts.is_available())
            .flat_map(|ts| ts.tools.iter().map(String::as_str))
            .filter(|tool| !denied.contains(tool))
            .collect();

        let mut sorted: Vec<String> = names.into_iter().map(String::from).collect();
        // Entries are unique, so an unstable sort gives the same result as a stable one.
        sorted.sort_unstable();
        sorted
    }

    /// Iterates over the registered toolsets in unspecified (hash map) order.
    pub fn iter(&self) -> impl Iterator<Item = &Toolset> {
        self.toolsets.values()
    }

    /// Returns the sorted names of every tool indexed by the registry.
    pub fn all_tool_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.tool_index.keys().cloned().collect();
        names.sort_unstable();
        names
    }
}