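//! Toolsets (named groups of tools gated on required environment variables), well-known
//! toolset names, and a registry that resolves enabled/disabled toolsets into a sorted list
//! of tool names.
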
use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct Toolset {
|
|
pub name: String,
|
|
pub description: String,
|
|
pub tools: Vec<String>,
|
|
pub requires_env: Vec<String>,
|
|
}
|
|
|
|
impl Toolset {
|
|
pub fn new(
|
|
name: impl Into<String>,
|
|
description: impl Into<String>,
|
|
) -> Self {
|
|
Self {
|
|
name: name.into(),
|
|
description: description.into(),
|
|
tools: Vec::new(),
|
|
requires_env: Vec::new(),
|
|
}
|
|
}
|
|
|
|
pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
|
|
self.tools.push(tool_name.into());
|
|
self
|
|
}
|
|
|
|
pub fn with_tools(mut self, tool_names: impl IntoIterator<Item = impl Into<String>>) -> Self {
|
|
self.tools.extend(tool_names.into_iter().map(Into::into));
|
|
self
|
|
}
|
|
|
|
pub fn with_required_env(
|
|
mut self,
|
|
env_vars: impl IntoIterator<Item = impl Into<String>>,
|
|
) -> Self {
|
|
self.requires_env.extend(env_vars.into_iter().map(Into::into));
|
|
self
|
|
}
|
|
pub fn is_available(&self) -> bool {
|
|
for env_var in &self.requires_env {
|
|
if std::env::var(env_var).is_err() {
|
|
return false;
|
|
}
|
|
}
|
|
true
|
|
}
|
|
|
|
pub fn contains(&self, tool_name: &str) -> bool {
|
|
self.tools.iter().any(|t| t == tool_name)
|
|
}
|
|
}
|
|
/// Well-known toolset names, so callers can refer to toolsets without repeating literals.
pub mod toolset_names {
    /// Name of the `core` toolset.
    pub const CORE: &str = "core";
    /// Name of the `terminal` toolset.
    pub const TERMINAL: &str = "terminal";
    /// Name of the `web` toolset.
    pub const WEB: &str = "web";
    /// Name of the `file` toolset.
    pub const FILE: &str = "file";
    /// Name of the `memory` toolset.
    pub const MEMORY: &str = "memory";
    /// Name of the `vision` toolset.
    pub const VISION: &str = "vision";
    /// Name of the `search` toolset.
    pub const SEARCH: &str = "search";
    /// Name of the `browser` toolset.
    pub const BROWSER: &str = "browser";
    /// Name of the `code_execution` toolset.
    pub const CODE_EXECUTION: &str = "code_execution";
    /// Name of the `delegation` toolset.
    pub const DELEGATION: &str = "delegation";
}

#[derive(Clone, Debug, Default)]
|
|
pub struct ToolsetRegistry {
|
|
toolsets: HashMap<String, Toolset>,
|
|
tool_index: HashMap<String, String>,
|
|
}
|
|
|
|
impl ToolsetRegistry {
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
pub fn register(&mut self, toolset: Toolset) {
|
|
let name = toolset.name.clone();
|
|
for tool in &toolset.tools {
|
|
self.tool_index.insert(tool.clone(), name.clone());
|
|
}
|
|
self.toolsets.insert(name, toolset);
|
|
}
|
|
|
|
pub fn get(&self, name: &str) -> Option<&Toolset> {
|
|
self.toolsets.get(name)
|
|
}
|
|
pub fn toolset_for(&self, tool_name: &str) -> Option<&str> {
|
|
self.tool_index.get(tool_name).map(String::as_str)
|
|
}
|
|
pub fn resolve_tool_names(
|
|
&self,
|
|
enabled: &[String],
|
|
disabled: &[String],
|
|
default_all: bool,
|
|
) -> Vec<String> {
|
|
let mut names = HashSet::new();
|
|
let mut denied = HashSet::new();
|
|
|
|
for ts_name in disabled {
|
|
if let Some(ts) = self.toolsets.get(ts_name) {
|
|
for tool in &ts.tools {
|
|
denied.insert(tool.clone());
|
|
}
|
|
}
|
|
}
|
|
|
|
if enabled.is_empty() && default_all {
|
|
for ts in self.toolsets.values() {
|
|
if !disabled.contains(&ts.name) && ts.is_available() {
|
|
for tool in &ts.tools {
|
|
if !denied.contains(tool) {
|
|
names.insert(tool.clone());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
for ts_name in enabled {
|
|
if let Some(ts) = self.toolsets.get(ts_name) {
|
|
if ts.is_available() {
|
|
for tool in &ts.tools {
|
|
if !denied.contains(tool) {
|
|
names.insert(tool.clone());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut sorted: Vec<String> = names.into_iter().collect();
|
|
sorted.sort();
|
|
sorted
|
|
}
|
|
|
|
pub fn iter(&self) -> impl Iterator<Item = &Toolset> {
|
|
self.toolsets.values()
|
|
}
|
|
|
|
pub fn all_tool_names(&self) -> Vec<String> {
|
|
let mut names: Vec<String> = self.tool_index.keys().cloned().collect();
|
|
names.sort();
|
|
names
|
|
}
|
|
}
|