754 lines
25 KiB
Rust
754 lines
25 KiB
Rust
#![allow(dead_code)]
|
|
//! Synchronizes AI model metadata from the upstream AI endpoint
|
|
//! (`GET /v1/models`) into the local database.
|
|
//!
|
|
//! Flow:
|
|
//! 1. Call `GET /v1/models` with the configured AI API key.
|
|
//! 2. Parse the rich response (name, context_length, max_output_tokens,
|
|
//! capabilities, pricing, owned_by) — no external metadata source needed.
|
|
//! 3. Upsert provider / model / version / pricing / capability / profile
|
|
//! records for all accessible models.
|
|
//!
|
|
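//! Usage: call `start_sync_task()` to launch a background task that syncs
//! immediately and then every 10 minutes. The task runs asynchronously and
//! `start_sync_task()` returns at once, so if the catalog must be populated
//! before accepting traffic, also await `AppService::sync_upstream_models()`
//! during startup.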
|
|
|
|
use std::time::Duration;
|
|
use tokio::task::JoinHandle;
|
|
use tokio::time::interval;
|
|
|
|
use crate::AppService;
|
|
use crate::error::AppError;
|
|
use chrono::Utc;
|
|
use db::database::AppDatabase;
|
|
use models::agents::model::Entity as ModelEntity;
|
|
use models::agents::model_capability::Entity as CapabilityEntity;
|
|
use models::agents::model_parameter_profile::Entity as ProfileEntity;
|
|
use models::agents::model_provider::Entity as ProviderEntity;
|
|
use models::agents::model_provider::Model as ProviderModel;
|
|
use models::agents::model_version::Entity as VersionEntity;
|
|
use models::agents::{CapabilityType, ModelCapability, ModelModality, ModelStatus};
|
|
use sea_orm::Set;
|
|
use sea_orm::prelude::*;
|
|
use serde::Deserialize;
|
|
use serde::Serialize;
|
|
use session::Session;
|
|
use utoipa::ToSchema;
|
|
use uuid::Uuid;
|
|
|
|
// Upstream /v1/models response types -----------------------------------------
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
struct ModelsListResponse {
|
|
data: Vec<UpstreamModel>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
struct UpstreamModel {
|
|
id: String,
|
|
#[serde(default)]
|
|
name: Option<String>,
|
|
#[serde(default)]
|
|
owned_by: Option<String>,
|
|
#[serde(default)]
|
|
context_length: Option<u64>,
|
|
#[serde(default)]
|
|
max_output_tokens: Option<u64>,
|
|
#[serde(default)]
|
|
capabilities: Option<UpstreamCapabilities>,
|
|
#[serde(default)]
|
|
pricing: Option<UpstreamPricing>,
|
|
#[serde(default)]
|
|
r#type: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
struct UpstreamCapabilities {
|
|
#[serde(default)]
|
|
vision: Option<bool>,
|
|
#[serde(default)]
|
|
tool_call: Option<bool>,
|
|
#[serde(default)]
|
|
reasoning: Option<bool>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
struct UpstreamPricing {
|
|
#[serde(default)]
|
|
input: Option<f64>,
|
|
#[serde(default)]
|
|
output: Option<f64>,
|
|
#[serde(default)]
|
|
cache_read: Option<f64>,
|
|
#[serde(default)]
|
|
cache_write: Option<f64>,
|
|
#[serde(default)]
|
|
unit: Option<String>,
|
|
#[serde(default)]
|
|
currency: Option<String>,
|
|
}
|
|
|
|
// Response type --------------------------------------------------------------
|
|
|
|
#[derive(Debug, Clone, Serialize, ToSchema)]
|
|
pub struct SyncModelsResponse {
|
|
pub models_created: i64,
|
|
pub models_updated: i64,
|
|
pub models_offline: i64,
|
|
pub models_deactivated: i64,
|
|
pub versions_created: i64,
|
|
pub pricing_created: i64,
|
|
pub capabilities_created: i64,
|
|
pub profiles_created: i64,
|
|
}
|
|
|
|
// Mapping helpers ------------------------------------------------------------
|
|
|
|
fn infer_modality(model: &UpstreamModel) -> ModelModality {
|
|
if let Some(caps) = &model.capabilities {
|
|
if caps.vision == Some(true) {
|
|
return ModelModality::Multimodal;
|
|
}
|
|
}
|
|
let lower = model.id.to_lowercase();
|
|
if lower.contains("vision")
|
|
|| lower.contains("dall-e")
|
|
|| lower.contains("gpt-image")
|
|
|| lower.contains("gpt-4o")
|
|
{
|
|
ModelModality::Multimodal
|
|
} else if lower.contains("embedding") {
|
|
ModelModality::Text
|
|
} else if lower.contains("whisper") || lower.contains("audio") {
|
|
ModelModality::Audio
|
|
} else {
|
|
ModelModality::Text
|
|
}
|
|
}
|
|
|
|
fn infer_capability(model: &UpstreamModel) -> ModelCapability {
|
|
let lower = model.id.to_lowercase();
|
|
if lower.contains("embedding") {
|
|
ModelCapability::Embedding
|
|
} else {
|
|
ModelCapability::Chat
|
|
}
|
|
}
|
|
|
|
fn context_length(model: &UpstreamModel) -> i64 {
|
|
model.context_length.map(|c| c as i64).unwrap_or(8_192)
|
|
}
|
|
|
|
fn max_output_tokens(model: &UpstreamModel) -> Option<i64> {
|
|
model.max_output_tokens.map(|v| v as i64)
|
|
}
|
|
|
|
fn capability_list(model: &UpstreamModel) -> Vec<(CapabilityType, bool)> {
|
|
let mut caps = Vec::new();
|
|
|
|
// Function call / tool use
|
|
if let Some(u) = &model.capabilities {
|
|
if u.tool_call == Some(true) {
|
|
caps.push((CapabilityType::ToolUse, true));
|
|
}
|
|
if u.vision == Some(true) {
|
|
caps.push((CapabilityType::Vision, true));
|
|
}
|
|
}
|
|
|
|
// Always mark function call as supported by default for chat models
|
|
if caps.is_empty() {
|
|
caps.push((CapabilityType::FunctionCall, true));
|
|
}
|
|
|
|
caps
|
|
}
|
|
|
|
// Provider helpers -----------------------------------------------------------
|
|
|
|
fn extract_provider_name(model: &UpstreamModel) -> String {
|
|
if let Some(owned) = &model.owned_by {
|
|
if !owned.is_empty() {
|
|
return normalize_provider_name(owned);
|
|
}
|
|
}
|
|
normalize_provider_name(model.id.split('/').next().unwrap_or("unknown"))
|
|
}
|
|
|
|
/// Maps provider aliases onto one canonical slug.
///
/// Only real aliases need an arm: `google-ai`→`google`, `mistralai`→`mistral`,
/// `meta-llama`→`meta`, `azure-openai`→`azure`, `x-ai`→`xai`, `alibaba`→`qwen`.
/// Any other input, including slugs that are already canonical, is returned unchanged.
fn normalize_provider_name(slug: &str) -> String {
    let canonical = match slug {
        "google-ai" => "google",
        "mistralai" => "mistral",
        "meta-llama" => "meta",
        "azure-openai" => "azure",
        "x-ai" => "xai",
        "alibaba" => "qwen",
        other => other,
    };
    canonical.to_owned()
}
|
|
|
|
/// Human-readable label for a canonical provider slug.
///
/// Unknown slugs are returned verbatim.
fn provider_display_name(name: &str) -> String {
    const LABELS: &[(&str, &str)] = &[
        ("openai", "OpenAI"),
        ("anthropic", "Anthropic"),
        ("google", "Google DeepMind"),
        ("mistral", "Mistral AI"),
        ("meta", "Meta"),
        ("deepseek", "DeepSeek"),
        ("azure", "Microsoft Azure"),
        ("xai", "xAI"),
        ("moonshot", "Moonshot AI"),
        ("zai", "Zhipu AI"),
        ("minimax", "MiniMax"),
        ("qwen", "Alibaba Qwen"),
    ];

    LABELS
        .iter()
        .find(|&&(slug, _)| slug == name)
        .map_or(name, |&(_, label)| label)
        .to_string()
}
|
|
|
|
// Upsert helpers -------------------------------------------------------------
|
|
|
|
async fn upsert_provider(db: &AppDatabase, slug: &str) -> Result<ProviderModel, AppError> {
|
|
let display = provider_display_name(slug);
|
|
let now = Utc::now();
|
|
|
|
use models::agents::model_provider::Column as PCol;
|
|
if let Some(existing) = ProviderEntity::find()
|
|
.filter(PCol::Name.eq(slug))
|
|
.one(db)
|
|
.await?
|
|
{
|
|
let mut active: models::agents::model_provider::ActiveModel = existing.into();
|
|
active.updated_at = Set(now);
|
|
active.update(db).await.map_err(AppError::from)
|
|
} else {
|
|
let active = models::agents::model_provider::ActiveModel {
|
|
id: Set(Uuid::now_v7()),
|
|
name: Set(slug.to_string()),
|
|
display_name: Set(display.to_string()),
|
|
website: Set(None),
|
|
status: Set(ModelStatus::Active.to_string()),
|
|
created_at: Set(now),
|
|
updated_at: Set(now),
|
|
};
|
|
active.insert(db).await.map_err(AppError::from)
|
|
}
|
|
}
|
|
|
|
/// Upserts a model by upstream ID as the deduplication key.
|
|
/// Each upstream model ID maps to exactly one row in the local `ai_model` table.
|
|
async fn upsert_model_by_id(
|
|
db: &AppDatabase,
|
|
provider_id: Uuid,
|
|
model: &UpstreamModel,
|
|
) -> Result<(models::agents::model::Model, bool), AppError> {
|
|
let now = Utc::now();
|
|
let modality = infer_modality(model);
|
|
let capability = infer_capability(model);
|
|
let ctx = context_length(model);
|
|
let max_out = max_output_tokens(model);
|
|
let model_id_str = extract_model_name(model);
|
|
|
|
use models::agents::model::Column as MCol;
|
|
if let Some(existing) = ModelEntity::find()
|
|
.filter(MCol::Name.eq(&model_id_str))
|
|
.one(db)
|
|
.await?
|
|
{
|
|
// Update existing model (deduplicated by name)
|
|
let mut active: models::agents::model::ActiveModel = existing.clone().into();
|
|
// Update provider if changed (first provider wins)
|
|
active.provider_id = Set(provider_id);
|
|
active.context_length = Set(ctx);
|
|
active.max_output_tokens = Set(max_out);
|
|
active.status = Set(ModelStatus::Active.to_string());
|
|
active.updated_at = Set(now);
|
|
let updated = active.update(db).await?;
|
|
Ok((updated, false))
|
|
} else {
|
|
// Create new model
|
|
let active = models::agents::model::ActiveModel {
|
|
id: Set(Uuid::now_v7()),
|
|
provider_id: Set(provider_id),
|
|
name: Set(model_id_str),
|
|
modality: Set(modality.to_string()),
|
|
capability: Set(capability.to_string()),
|
|
context_length: Set(ctx),
|
|
max_output_tokens: Set(max_out),
|
|
training_cutoff: Set(None),
|
|
is_open_source: Set(false),
|
|
status: Set(ModelStatus::Active.to_string()),
|
|
created_at: Set(now),
|
|
updated_at: Set(now),
|
|
..Default::default()
|
|
};
|
|
let inserted = active.insert(db).await.map_err(AppError::from)?;
|
|
Ok((inserted, true))
|
|
}
|
|
}
|
|
|
|
async fn upsert_version(
|
|
db: &AppDatabase,
|
|
model_uuid: Uuid,
|
|
) -> Result<(models::agents::model_version::Model, bool), AppError> {
|
|
use models::agents::model_version::Column as VCol;
|
|
let now = Utc::now();
|
|
if let Some(existing) = VersionEntity::find()
|
|
.filter(VCol::ModelId.eq(model_uuid))
|
|
.filter(VCol::IsDefault.eq(true))
|
|
.one(db)
|
|
.await?
|
|
{
|
|
Ok((existing, false))
|
|
} else {
|
|
let active = models::agents::model_version::ActiveModel {
|
|
id: Set(Uuid::now_v7()),
|
|
model_id: Set(model_uuid),
|
|
version: Set("1".to_string()),
|
|
release_date: Set(None),
|
|
change_log: Set(None),
|
|
is_default: Set(true),
|
|
status: Set(ModelStatus::Active.to_string()),
|
|
created_at: Set(now),
|
|
};
|
|
let inserted = active.insert(db).await.map_err(AppError::from)?;
|
|
Ok((inserted, true))
|
|
}
|
|
}
|
|
|
|
async fn upsert_pricing(
|
|
db: &AppDatabase,
|
|
version_uuid: Uuid,
|
|
pricing: Option<&UpstreamPricing>,
|
|
) -> Result<bool, AppError> {
|
|
use models::agents::model_pricing::Column as PCol;
|
|
use models::agents::model_pricing::Entity as PricingEntity;
|
|
let existing = PricingEntity::find()
|
|
.filter(PCol::ModelVersionId.eq(version_uuid))
|
|
.one(db)
|
|
.await?;
|
|
if existing.is_some() {
|
|
return Ok(false);
|
|
}
|
|
|
|
let (input_price, output_price) = if let Some(p) = pricing {
|
|
(
|
|
format!("{:.2}", p.input.unwrap_or(0.0)),
|
|
format!("{:.2}", p.output.unwrap_or(0.0)),
|
|
)
|
|
} else {
|
|
("0.00".to_string(), "0.00".to_string())
|
|
};
|
|
|
|
let currency = pricing
|
|
.and_then(|p| p.currency.clone())
|
|
.unwrap_or_else(|| "USD".to_string());
|
|
|
|
let active = models::agents::model_pricing::ActiveModel {
|
|
id: Set(Utc::now().timestamp_millis()),
|
|
model_version_id: Set(version_uuid),
|
|
input_price_per_1k_tokens: Set(input_price),
|
|
output_price_per_1k_tokens: Set(output_price),
|
|
currency: Set(currency),
|
|
effective_from: Set(Utc::now()),
|
|
};
|
|
active.insert(db).await.map_err(AppError::from)?;
|
|
Ok(true)
|
|
}
|
|
|
|
async fn upsert_capabilities(
|
|
db: &AppDatabase,
|
|
version_uuid: Uuid,
|
|
model: &UpstreamModel,
|
|
) -> Result<i64, AppError> {
|
|
use models::agents::model_capability::Column as CCol;
|
|
let caps = capability_list(model);
|
|
let now = Utc::now();
|
|
let mut created = 0i64;
|
|
|
|
for (cap_type, supported) in caps {
|
|
let exists = CapabilityEntity::find()
|
|
.filter(CCol::ModelVersionId.eq(version_uuid))
|
|
.filter(CCol::Capability.eq(cap_type.to_string()))
|
|
.one(db)
|
|
.await?;
|
|
if exists.is_some() {
|
|
continue;
|
|
}
|
|
let active = models::agents::model_capability::ActiveModel {
|
|
// FIXME: i64 primary key loses entropy from UUID. Use UUID type in schema.
|
|
id: Set(Uuid::now_v7().as_u128() as i64),
|
|
// FIXME: version_uuid truncated from 128-bit UUID to i64.
|
|
// Schema must be migrated: ALTER COLUMN model_version_id TO UUID type.
|
|
model_version_id: Set(version_uuid.as_u128() as i64),
|
|
capability: Set(cap_type.to_string()),
|
|
is_supported: Set(supported),
|
|
created_at: Set(now),
|
|
};
|
|
active.insert(db).await.map_err(AppError::from)?;
|
|
created += 1;
|
|
}
|
|
Ok(created)
|
|
}
|
|
|
|
async fn upsert_parameter_profile(
|
|
db: &AppDatabase,
|
|
version_uuid: Uuid,
|
|
model: &UpstreamModel,
|
|
) -> Result<bool, AppError> {
|
|
use models::agents::model_parameter_profile::Column as PCol;
|
|
let existing = ProfileEntity::find()
|
|
.filter(PCol::ModelVersionId.eq(version_uuid))
|
|
.one(db)
|
|
.await?;
|
|
if existing.is_some() {
|
|
return Ok(false);
|
|
}
|
|
|
|
let lower = model.id.to_lowercase();
|
|
let (t_min, t_max) = if lower.contains("o1") || lower.contains("o3") {
|
|
(1.0, 1.0)
|
|
} else {
|
|
(0.0, 2.0)
|
|
};
|
|
|
|
let active = models::agents::model_parameter_profile::ActiveModel {
|
|
id: Set(Utc::now().timestamp_millis()),
|
|
model_version_id: Set(version_uuid),
|
|
temperature_min: Set(t_min),
|
|
temperature_max: Set(t_max),
|
|
top_p_min: Set(0.0),
|
|
top_p_max: Set(1.0),
|
|
frequency_penalty_supported: Set(true),
|
|
presence_penalty_supported: Set(true),
|
|
};
|
|
active.insert(db).await.map_err(AppError::from)?;
|
|
Ok(true)
|
|
}
|
|
|
|
// Core sync logic ------------------------------------------------------------
|
|
|
|
/// Extracts the API model identifier from an upstream model.
|
|
/// Uses the upstream `id` field directly (e.g. "kimi-k2.6") as the model name
|
|
/// stored in the database, since this is what AI API calls use as the `model` parameter.
|
|
fn extract_model_name(model: &UpstreamModel) -> String {
|
|
model.id.clone()
|
|
}
|
|
|
|
/// Deduplicates existing models in the database by name.
|
|
/// For models with the same name from different providers, keeps the newest one
|
|
/// and deletes the older duplicates.
|
|
async fn mark_all_models_offline(db: &AppDatabase) -> Result<i64, AppError> {
|
|
use models::agents::model::Column as MCol;
|
|
use models::agents::model::Entity as MEntity;
|
|
|
|
let now = Utc::now();
|
|
let updated = MEntity::update_many()
|
|
.set(models::agents::model::ActiveModel {
|
|
status: Set(ModelStatus::Offline.to_string()),
|
|
updated_at: Set(now),
|
|
..Default::default()
|
|
})
|
|
.filter(MCol::Status.eq(ModelStatus::Active.to_string()))
|
|
.exec(db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(updated.rows_affected as i64)
|
|
}
|
|
|
|
async fn sync_models_from_upstream(
|
|
db: &AppDatabase,
|
|
upstream_models: Vec<UpstreamModel>,
|
|
) -> SyncModelsResponse {
|
|
// Step 1: Mark all existing models as offline
|
|
let models_offline = mark_all_models_offline(db).await.unwrap_or(0);
|
|
|
|
tracing::info!(
|
|
upstream_total = upstream_models.len(),
|
|
"sync_models_from_upstream: syncing models"
|
|
);
|
|
|
|
let mut models_created = 0i64;
|
|
let mut models_updated = 0i64;
|
|
let models_deactivated: i64;
|
|
let mut versions_created = 0i64;
|
|
let mut pricing_created = 0i64;
|
|
let mut capabilities_created = 0i64;
|
|
let mut profiles_created = 0i64;
|
|
|
|
for model in &upstream_models {
|
|
let provider_slug = extract_provider_name(model);
|
|
let provider = match upsert_provider(db, &provider_slug).await {
|
|
Ok(p) => p,
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
provider = %provider_slug,
|
|
error = ?e,
|
|
"sync_models_from_upstream: upsert_provider error"
|
|
);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let (model_record, _is_new) = match upsert_model_by_id(db, provider.id, model).await {
|
|
Ok((m, created)) => {
|
|
if created {
|
|
models_created += 1;
|
|
} else {
|
|
models_updated += 1;
|
|
}
|
|
(m, created)
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
model = %model.id,
|
|
error = ?e,
|
|
"sync_models_from_upstream: upsert_model error"
|
|
);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let (version_record, version_is_new) = match upsert_version(db, model_record.id).await {
|
|
Ok(v) => v,
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
model = %model.id,
|
|
error = ?e,
|
|
"sync_models_from_upstream: upsert_version error"
|
|
);
|
|
continue;
|
|
}
|
|
};
|
|
if version_is_new {
|
|
versions_created += 1;
|
|
}
|
|
|
|
if upsert_pricing(db, version_record.id, model.pricing.as_ref())
|
|
.await
|
|
.unwrap_or(false)
|
|
{
|
|
pricing_created += 1;
|
|
}
|
|
|
|
capabilities_created += upsert_capabilities(db, version_record.id, model)
|
|
.await
|
|
.unwrap_or(0);
|
|
|
|
if upsert_parameter_profile(db, version_record.id, model)
|
|
.await
|
|
.unwrap_or(false)
|
|
{
|
|
profiles_created += 1;
|
|
}
|
|
}
|
|
|
|
// Step 3: Deactivate models that were offline before but exist locally
|
|
// (These are models that were added manually and are no longer in sync)
|
|
let deactivated = deactivate_orphaned_models(db).await.unwrap_or(0);
|
|
models_deactivated = deactivated;
|
|
|
|
SyncModelsResponse {
|
|
models_created,
|
|
models_updated,
|
|
models_offline,
|
|
models_deactivated,
|
|
versions_created,
|
|
pricing_created,
|
|
capabilities_created,
|
|
profiles_created,
|
|
}
|
|
}
|
|
|
|
/// Deactivates models that were previously marked offline and are not in any active sync.
|
|
/// These are manually added models that are no longer needed.
|
|
async fn deactivate_orphaned_models(db: &AppDatabase) -> Result<i64, AppError> {
|
|
use models::agents::model::Column as MCol;
|
|
use models::agents::model::Entity as MEntity;
|
|
|
|
let now = Utc::now();
|
|
let updated = MEntity::update_many()
|
|
.set(models::agents::model::ActiveModel {
|
|
status: Set(ModelStatus::Deprecated.to_string()),
|
|
updated_at: Set(now),
|
|
..Default::default()
|
|
})
|
|
.filter(MCol::Status.eq(ModelStatus::Offline.to_string()))
|
|
.exec(db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(updated.rows_affected as i64)
|
|
}
|
|
|
|
// HTTP helpers ---------------------------------------------------------------
|
|
|
|
/// List models from the upstream AI endpoint (`GET /v1/models`).
|
|
async fn list_upstream_models(
|
|
client: &reqwest::Client,
|
|
base_url: &str,
|
|
api_key: &str,
|
|
) -> Result<Vec<UpstreamModel>, AppError> {
|
|
let base = base_url.trim_end_matches('/');
|
|
let url = if base.ends_with("/v1") {
|
|
format!("{}/models", base)
|
|
} else {
|
|
format!("{}/v1/models", base)
|
|
};
|
|
let resp = client
|
|
.get(&url)
|
|
.header("Authorization", format!("Bearer {}", api_key))
|
|
.send()
|
|
.await
|
|
.map_err(|e| AppError::InternalServerError(format!("failed to list models: {}", e)))?;
|
|
|
|
let body = resp
|
|
.text()
|
|
.await
|
|
.map_err(|e| AppError::InternalServerError(format!("failed to read models body: {}", e)))?;
|
|
|
|
// Try standard OpenAI-compatible format: { "data": [{...}, ...] }
|
|
if let Ok(parsed) = serde_json::from_str::<ModelsListResponse>(&body) {
|
|
return Ok(parsed.data);
|
|
}
|
|
|
|
// Try raw array: [{...}, ...]
|
|
if let Ok(parsed) = serde_json::from_str::<Vec<UpstreamModel>>(&body) {
|
|
return Ok(parsed);
|
|
}
|
|
|
|
tracing::warn!(
|
|
body = %body.chars().take(500).collect::<String>(),
|
|
"list_upstream_models: unknown response format"
|
|
);
|
|
Err(AppError::InternalServerError(format!(
|
|
"unexpected /v1/models response format (first 200 chars): {}",
|
|
body.chars().take(200).collect::<String>()
|
|
)))
|
|
}
|
|
|
|
fn build_ai_client(
|
|
config: &config::AppConfig,
|
|
) -> Result<(reqwest::Client, String, String), AppError> {
|
|
let api_key = config
|
|
.ai_api_key()
|
|
.map_err(|e| AppError::InternalServerError(format!("AI API key not configured: {}", e)))?;
|
|
|
|
let base_url = config
|
|
.ai_basic_url()
|
|
.unwrap_or_else(|_| "https://api.openai.com".into());
|
|
|
|
Ok((reqwest::Client::new(), base_url, api_key))
|
|
}
|
|
|
|
fn build_ai_client_from_parts(
|
|
api_key: Option<String>,
|
|
base_url: Option<String>,
|
|
) -> Result<(reqwest::Client, String, String), String> {
|
|
let api_key = api_key.ok_or_else(|| "AI API key not configured".to_string())?;
|
|
let base_url = base_url.unwrap_or_else(|| "https://api.openai.com".into());
|
|
Ok((reqwest::Client::new(), base_url, api_key))
|
|
}
|
|
|
|
// Public API -----------------------------------------------------------------
|
|
|
|
impl AppService {
|
|
/// Sync model metadata from the upstream AI endpoint (`GET /v1/models`).
|
|
///
|
|
/// Parses the full response (name, context_length, max_output_tokens,
|
|
/// capabilities, pricing, owned_by) and upserts all related records.
|
|
pub async fn sync_upstream_models(
|
|
&self,
|
|
_ctx: &Session,
|
|
) -> Result<SyncModelsResponse, AppError> {
|
|
let (http_client, base_url, api_key) = build_ai_client(&self.config)?;
|
|
let upstream_models = list_upstream_models(&http_client, &base_url, &api_key).await?;
|
|
|
|
tracing::info!(
|
|
model_count = upstream_models.len(),
|
|
"sync_upstream_models: {} models from upstream endpoint",
|
|
upstream_models.len()
|
|
);
|
|
|
|
let result = sync_models_from_upstream(&self.db, upstream_models).await;
|
|
|
|
tracing::info!(
|
|
models_created = result.models_created,
|
|
models_updated = result.models_updated,
|
|
versions_created = result.versions_created,
|
|
pricing_created = result.pricing_created,
|
|
capabilities_created = result.capabilities_created,
|
|
profiles_created = result.profiles_created,
|
|
"sync_upstream_models: complete"
|
|
);
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Spawn a background task that syncs model metadata immediately
|
|
/// and then every 10 minutes. Returns the `JoinHandle`.
|
|
///
|
|
/// Failures are logged but do not stop the task — it keeps retrying.
|
|
pub fn start_sync_task(self) -> JoinHandle<()> {
|
|
let db = self.db.clone();
|
|
let ai_api_key = self.config.ai_api_key().ok();
|
|
let ai_base_url = self.config.ai_basic_url().ok();
|
|
|
|
tokio::spawn(async move {
|
|
// Run once immediately on startup before taking traffic.
|
|
Self::sync_once(&db, ai_api_key.clone(), ai_base_url.clone()).await;
|
|
|
|
let mut tick = interval(Duration::from_secs(60 * 10));
|
|
loop {
|
|
tick.tick().await;
|
|
Self::sync_once(&db, ai_api_key.clone(), ai_base_url.clone()).await;
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Perform a single sync pass. Errors are logged and silently swallowed
|
|
/// so the periodic task never stops.
|
|
async fn sync_once(db: &AppDatabase, ai_api_key: Option<String>, ai_base_url: Option<String>) {
|
|
let (http_client, base_url, api_key) =
|
|
match build_ai_client_from_parts(ai_api_key, ai_base_url) {
|
|
Ok(c) => c,
|
|
Err(msg) => {
|
|
tracing::warn!(error = %msg, "Model sync: AI client config error");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let upstream_models = match list_upstream_models(&http_client, &base_url, &api_key).await {
|
|
Ok(models) => models,
|
|
Err(e) => {
|
|
tracing::warn!(error = ?e, "Model sync: failed to list upstream models");
|
|
return;
|
|
}
|
|
};
|
|
|
|
tracing::info!(
|
|
model_count = upstream_models.len(),
|
|
"Model sync: {} models from upstream",
|
|
upstream_models.len()
|
|
);
|
|
|
|
let result = sync_models_from_upstream(db, upstream_models).await;
|
|
|
|
tracing::info!(
|
|
models_created = result.models_created,
|
|
models_updated = result.models_updated,
|
|
versions_created = result.versions_created,
|
|
pricing_created = result.pricing_created,
|
|
capabilities_created = result.capabilities_created,
|
|
profiles_created = result.profiles_created,
|
|
"Model sync complete"
|
|
);
|
|
}
|
|
}
|