refactor(fctool): add descriptions to tools and simplify model sync
- Add description field to all fctool file and git tools - Simplify extract_model_name in sync.rs (use upstream id directly)
This commit is contained in:
parent
da2853d0ec
commit
7b43f55f41
@ -16,6 +16,7 @@ pub mod call;
|
||||
pub mod context;
|
||||
pub mod definition;
|
||||
pub mod executor;
|
||||
pub mod recorder;
|
||||
pub mod registry;
|
||||
|
||||
#[cfg(feature = "rig")]
|
||||
@ -28,4 +29,8 @@ pub use call::{ToolCall, ToolCallResult, ToolError, ToolResult};
|
||||
pub use context::ToolContext;
|
||||
pub use definition::{ToolDefinition, ToolParam, ToolSchema};
|
||||
pub use executor::ToolExecutor;
|
||||
pub use recorder::{ToolCallRecord, ToolCallRecorder};
|
||||
pub use registry::{ToolHandler, ToolRegistry};
|
||||
|
||||
#[cfg(feature = "rig")]
|
||||
pub use rig_adapter::{is_retryable_tool_error, RecordingTool, RigToolAdapter, RigToolSet};
|
||||
|
||||
@ -53,6 +53,8 @@ async fn read_csv_exec(
|
||||
|
||||
let commit_oid = if rev.len() == 40 && rev.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
git::commit::types::CommitOid::new(&rev)
|
||||
} else if let Ok(Some(oid)) = domain.ref_target(&rev) {
|
||||
oid
|
||||
} else {
|
||||
domain
|
||||
.commit_get_prefix(&rev)
|
||||
|
||||
@ -69,6 +69,8 @@ async fn git_grep_exec(
|
||||
// Resolve revision to commit oid
|
||||
let commit_oid = if rev.len() == 40 && rev.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
git::commit::types::CommitOid::new(&rev)
|
||||
} else if let Ok(Some(oid)) = domain.ref_target(&rev) {
|
||||
oid
|
||||
} else {
|
||||
domain
|
||||
.commit_get_prefix(&rev)
|
||||
|
||||
@ -132,6 +132,8 @@ async fn read_json_exec(
|
||||
|
||||
let commit_oid = if rev.len() == 40 && rev.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
git::commit::types::CommitOid::new(&rev)
|
||||
} else if let Ok(Some(oid)) = domain.ref_target(&rev) {
|
||||
oid
|
||||
} else {
|
||||
domain
|
||||
.commit_get_prefix(&rev)
|
||||
|
||||
@ -43,6 +43,8 @@ async fn read_markdown_exec(
|
||||
|
||||
let commit_oid = if rev.len() == 40 && rev.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
git::commit::types::CommitOid::new(&rev)
|
||||
} else if let Ok(Some(oid)) = domain.ref_target(&rev) {
|
||||
oid
|
||||
} else {
|
||||
domain
|
||||
.commit_get_prefix(&rev)
|
||||
|
||||
@ -35,6 +35,8 @@ async fn read_sql_exec(
|
||||
|
||||
let commit_oid = if rev.len() == 40 && rev.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
git::commit::types::CommitOid::new(&rev)
|
||||
} else if let Ok(Some(oid)) = domain.ref_target(&rev) {
|
||||
oid
|
||||
} else {
|
||||
domain
|
||||
.commit_get_prefix(&rev)
|
||||
|
||||
@ -115,6 +115,8 @@ fn resolve_oid(
|
||||
) -> Result<git::commit::types::CommitOid, String> {
|
||||
if rev.len() == 40 && rev.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
Ok(git::commit::types::CommitOid::new(rev))
|
||||
} else if let Ok(Some(oid)) = domain.ref_target(rev) {
|
||||
Ok(oid)
|
||||
} else {
|
||||
domain.commit_get_prefix(rev).map_err(|e| e.to_string()).map(|m| m.oid)
|
||||
}
|
||||
|
||||
@ -48,10 +48,12 @@ async fn git_log_exec(ctx: GitToolCtx, args: serde_json::Value) -> Result<serde_
|
||||
}
|
||||
|
||||
/// Resolve a rev string to commit metadata. Tries full OID first (exactly 40 hex chars),
|
||||
/// falls back to prefix lookup (branch, tag, short hash).
|
||||
/// then reference name resolution (branch, tag, HEAD), then hex prefix lookup.
|
||||
fn resolve_commit(domain: &git::GitDomain, rev: &str) -> Result<git::commit::types::CommitMeta, String> {
|
||||
if rev.len() == 40 && rev.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
domain.commit_get(&git::commit::types::CommitOid::new(rev)).map_err(|e| e.to_string())
|
||||
} else if let Ok(Some(oid)) = domain.ref_target(rev) {
|
||||
domain.commit_get(&oid).map_err(|e| e.to_string())
|
||||
} else {
|
||||
domain.commit_get_prefix(rev).map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
@ -19,6 +19,8 @@ async fn git_diff_exec(ctx: GitToolCtx, args: serde_json::Value) -> Result<serde
|
||||
let resolve = |rev: &str| -> Result<git::commit::types::CommitOid, String> {
|
||||
if rev.len() == 40 && rev.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
Ok(git::commit::types::CommitOid::new(rev))
|
||||
} else if let Ok(Some(oid)) = domain.ref_target(rev) {
|
||||
Ok(oid)
|
||||
} else {
|
||||
domain.commit_get_prefix(rev).map_err(|e| e.to_string()).map(|m| m.oid)
|
||||
}
|
||||
@ -45,12 +47,13 @@ async fn git_diff_exec(ctx: GitToolCtx, args: serde_json::Value) -> Result<serde
|
||||
if domain.repo().head().is_err() {
|
||||
return Err("No commits found in repository".into());
|
||||
}
|
||||
let head_meta = domain.commit_get_prefix("HEAD").map_err(|e| e.to_string())?;
|
||||
let head_oid = domain.ref_target("HEAD").map_err(|e| e.to_string())?
|
||||
.ok_or_else(|| "HEAD reference not found".to_string())?;
|
||||
// Bare repos have no working tree — use tree-to-tree diff instead
|
||||
if domain.repo().is_bare() {
|
||||
domain.diff_tree_to_tree(None, Some(&head_meta.oid), opts).map_err(|e| e.to_string())?
|
||||
domain.diff_tree_to_tree(None, Some(&head_oid), opts).map_err(|e| e.to_string())?
|
||||
} else {
|
||||
domain.diff_commit_to_workdir(&head_meta.oid, opts).map_err(|e| e.to_string())?
|
||||
domain.diff_commit_to_workdir(&head_oid, opts).map_err(|e| e.to_string())?
|
||||
}
|
||||
}
|
||||
(Some(base), None) => {
|
||||
@ -96,6 +99,8 @@ async fn git_diff_stats_exec(ctx: GitToolCtx, args: serde_json::Value) -> Result
|
||||
let resolve = |rev: &str| -> Result<git::commit::types::CommitOid, String> {
|
||||
if rev.len() == 40 && rev.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
Ok(git::commit::types::CommitOid::new(rev))
|
||||
} else if let Ok(Some(oid)) = domain.ref_target(rev) {
|
||||
Ok(oid)
|
||||
} else {
|
||||
domain.commit_get_prefix(rev).map_err(|e| e.to_string()).map(|m| m.oid)
|
||||
}
|
||||
@ -123,6 +128,8 @@ async fn git_blame_exec(ctx: GitToolCtx, args: serde_json::Value) -> Result<serd
|
||||
let domain = ctx.open_repo(project_name, repo_name).await?;
|
||||
let oid = if rev.len() == 40 && rev.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
Ok(git::commit::types::CommitOid::new(&rev))
|
||||
} else if let Ok(Some(oid)) = domain.ref_target(&rev) {
|
||||
Ok(oid)
|
||||
} else {
|
||||
domain.commit_get_prefix(&rev).map_err(|e| e.to_string()).map(|m| m.oid)
|
||||
}?;
|
||||
|
||||
@ -6,10 +6,12 @@ use base64::Engine;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Resolve a rev string to a commit OID. Tries full OID first (exactly 40 hex chars),
|
||||
/// falls back to prefix lookup (branch, tag, short hash).
|
||||
/// then reference name resolution (branch, tag, HEAD), then hex prefix lookup.
|
||||
fn resolve_commit_oid(domain: &git::GitDomain, rev: &str) -> Result<git::commit::types::CommitOid, String> {
|
||||
if rev.len() == 40 && rev.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
Ok(git::commit::types::CommitOid::new(rev))
|
||||
} else if let Ok(Some(oid)) = domain.ref_target(rev) {
|
||||
Ok(oid)
|
||||
} else {
|
||||
domain.commit_get_prefix(rev).map_err(|e| e.to_string()).map(|m| m.oid)
|
||||
}
|
||||
|
||||
@ -29,7 +29,7 @@ use models::agents::model_provider::Model as ProviderModel;
|
||||
use models::agents::model_version::Entity as VersionEntity;
|
||||
use models::agents::{CapabilityType, ModelCapability, ModelModality, ModelStatus};
|
||||
use sea_orm::prelude::*;
|
||||
use sea_orm::{QueryOrder, Set};
|
||||
use sea_orm::Set;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use session::Session;
|
||||
@ -240,10 +240,9 @@ async fn upsert_provider(db: &AppDatabase, slug: &str) -> Result<ProviderModel,
|
||||
}
|
||||
}
|
||||
|
||||
/// Upserts a model by name only (deduplication key), ignoring provider.
|
||||
/// This ensures each model name exists only once, regardless of how many
|
||||
/// providers offer it. The first provider encountered is kept.
|
||||
async fn upsert_model_by_name(
|
||||
/// Upserts a model by upstream ID as the deduplication key.
|
||||
/// Each upstream model ID maps to exactly one row in the local `ai_model` table.
|
||||
async fn upsert_model_by_id(
|
||||
db: &AppDatabase,
|
||||
provider_id: Uuid,
|
||||
model: &UpstreamModel,
|
||||
@ -253,11 +252,11 @@ async fn upsert_model_by_name(
|
||||
let capability = infer_capability(model);
|
||||
let ctx = context_length(model);
|
||||
let max_out = max_output_tokens(model);
|
||||
let model_name = extract_model_name(model);
|
||||
let model_id_str = extract_model_name(model);
|
||||
|
||||
use models::agents::model::Column as MCol;
|
||||
if let Some(existing) = ModelEntity::find()
|
||||
.filter(MCol::Name.eq(&model_name))
|
||||
.filter(MCol::Name.eq(&model_id_str))
|
||||
.one(db)
|
||||
.await?
|
||||
{
|
||||
@ -276,7 +275,7 @@ async fn upsert_model_by_name(
|
||||
let active = models::agents::model::ActiveModel {
|
||||
id: Set(Uuid::now_v7()),
|
||||
provider_id: Set(provider_id),
|
||||
name: Set(model_name),
|
||||
name: Set(model_id_str),
|
||||
modality: Set(modality.to_string()),
|
||||
capability: Set(capability.to_string()),
|
||||
context_length: Set(ctx),
|
||||
@ -431,61 +430,16 @@ async fn upsert_parameter_profile(
|
||||
|
||||
// Core sync logic ------------------------------------------------------------
|
||||
|
||||
/// Extracts the base model name from an upstream model ID.
|
||||
/// e.g., "openai/gpt-4o-mini" -> "gpt-4o-mini", "anthropic/claude-3.5-sonnet" -> "claude-3.5-sonnet"
|
||||
/// Extracts the API model identifier from an upstream model.
|
||||
/// Uses the upstream `id` field directly (e.g. "kimi-k2.6") as the model name
|
||||
/// stored in the database, since this is what AI API calls use as the `model` parameter.
|
||||
fn extract_model_name(model: &UpstreamModel) -> String {
|
||||
// Use the name field if available, otherwise extract from id
|
||||
if let Some(name) = &model.name {
|
||||
if !name.is_empty() {
|
||||
return name.clone();
|
||||
}
|
||||
}
|
||||
// Extract from id: "provider/model-name" -> "model-name"
|
||||
model.id.split('/').last().unwrap_or(&model.id).to_string()
|
||||
model.id.clone()
|
||||
}
|
||||
|
||||
/// Deduplicates existing models in the database by name.
|
||||
/// For models with the same name from different providers, keeps the newest one
|
||||
/// and deletes the older duplicates.
|
||||
async fn deduplicate_existing_models(db: &AppDatabase) -> Result<i64, AppError> {
|
||||
use models::agents::model::Entity as MEntity;
|
||||
use models::agents::model::Column as MCol;
|
||||
|
||||
// Find all models grouped by name, ordered by creation time
|
||||
let all_models = MEntity::find()
|
||||
.order_by_asc(MCol::CreatedAt)
|
||||
.all(db.reader())
|
||||
.await
|
||||
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
||||
|
||||
// Group by name
|
||||
let mut name_to_ids: std::collections::HashMap<String, Vec<uuid::Uuid>> =
|
||||
std::collections::HashMap::new();
|
||||
for model in &all_models {
|
||||
name_to_ids
|
||||
.entry(model.name.clone())
|
||||
.or_default()
|
||||
.push(model.id);
|
||||
}
|
||||
|
||||
// Delete duplicates, keeping the first (oldest) for each name
|
||||
let mut deleted_count = 0i64;
|
||||
for (_, ids) in name_to_ids {
|
||||
if ids.len() > 1 {
|
||||
// Keep the first (oldest), delete the rest
|
||||
for id_to_delete in ids.into_iter().skip(1) {
|
||||
MEntity::delete_by_id(id_to_delete)
|
||||
.exec(db.writer())
|
||||
.await
|
||||
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
||||
deleted_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(deleted_count)
|
||||
}
|
||||
|
||||
async fn mark_all_models_offline(db: &AppDatabase) -> Result<i64, AppError> {
|
||||
use models::agents::model::Entity as MEntity;
|
||||
use models::agents::model::Column as MCol;
|
||||
@ -509,32 +463,12 @@ async fn sync_models_from_upstream(
|
||||
db: &AppDatabase,
|
||||
upstream_models: Vec<UpstreamModel>,
|
||||
) -> SyncModelsResponse {
|
||||
// Step 0: Deduplicate existing models in the database by name
|
||||
let existing_deduped = deduplicate_existing_models(db).await.unwrap_or(0);
|
||||
if existing_deduped > 0 {
|
||||
tracing::info!(
|
||||
deleted = existing_deduped,
|
||||
"sync_models_from_upstream: cleaned up existing duplicate models"
|
||||
);
|
||||
}
|
||||
|
||||
// Step 1: Mark all existing models as offline
|
||||
let models_offline = mark_all_models_offline(db).await.unwrap_or(0);
|
||||
|
||||
// Step 2: Deduplicate upstream models by name, keeping the first occurrence
|
||||
let mut seen_names: std::collections::HashSet<String> = std::collections::HashSet::new();
|
||||
let deduplicated_models: Vec<&UpstreamModel> = upstream_models
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
let name = extract_model_name(m);
|
||||
seen_names.insert(name)
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!(
|
||||
upstream_total = upstream_models.len(),
|
||||
deduplicated_count = deduplicated_models.len(),
|
||||
"sync_models_from_upstream: deduplicated upstream models"
|
||||
"sync_models_from_upstream: syncing models"
|
||||
);
|
||||
|
||||
let mut models_created = 0i64;
|
||||
@ -545,7 +479,7 @@ async fn sync_models_from_upstream(
|
||||
let mut capabilities_created = 0i64;
|
||||
let mut profiles_created = 0i64;
|
||||
|
||||
for model in deduplicated_models {
|
||||
for model in &upstream_models {
|
||||
let provider_slug = extract_provider_name(model);
|
||||
let provider = match upsert_provider(db, &provider_slug).await {
|
||||
Ok(p) => p,
|
||||
@ -559,7 +493,7 @@ async fn sync_models_from_upstream(
|
||||
}
|
||||
};
|
||||
|
||||
let (model_record, _is_new) = match upsert_model_by_name(db, provider.id, model).await {
|
||||
let (model_record, _is_new) = match upsert_model_by_id(db, provider.id, model).await {
|
||||
Ok((m, created)) => {
|
||||
if created {
|
||||
models_created += 1;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user