use std::time::Duration; use ai::sync::{UpstreamModel, UpstreamPricing}; use ai::client::EndpointConfig; use chrono::Utc; use db::sqlx::{self, types::Decimal}; use model::ai::{AiModelModel, AiModelVersionModel, AiProviderModel}; use serde::Serialize; use tokio::time::interval; use utoipa::ToSchema; use uuid::Uuid; use crate::{AppService, error::AppError}; #[derive(Debug, Clone, Serialize, ToSchema)] pub struct SyncModelsResponse { pub models_created: i64, pub models_updated: i64, pub models_offline: i64, pub models_deactivated: i64, pub versions_created: i64, pub pricing_created: i64, pub pricing_updated: i64, } #[derive(Debug, Clone)] pub struct SyncResult { pub provider_id: Uuid, pub provider_name: String, pub total: u32, pub synced: u32, pub skipped: u32, } fn extract_provider_name(model: &UpstreamModel) -> String { if let Some(owned) = &model.owned_by { if !owned.is_empty() { return normalize_provider_name(owned); } } normalize_provider_name(model.id.split('/').next().unwrap_or("unknown")) } fn normalize_provider_name(slug: &str) -> String { match slug { "openai" => "openai", "anthropic" => "anthropic", "google" | "google-ai" => "google", "mistralai" => "mistral", "meta-llama" | "meta" => "meta", "deepseek" => "deepseek", "azure" | "azure-openai" => "azure", "x-ai" | "xai" => "xai", "moonshot" => "moonshot", "alibaba" | "qwen" => "qwen", s => s, } .to_string() } fn provider_display_name(name: &str) -> String { match name { "openai" => "OpenAI", "anthropic" => "Anthropic", "google" => "Google DeepMind", "mistral" => "Mistral AI", "meta" => "Meta", "deepseek" => "DeepSeek", "azure" => "Microsoft Azure", "xai" => "xAI", "moonshot" => "Moonshot AI", "qwen" => "Alibaba Qwen", s => s, } .to_string() } fn infer_modality(model: &UpstreamModel) -> &'static str { if let Some(caps) = &model.capabilities { if caps.vision == Some(true) { return "multimodal"; } } let lower = model.id.to_lowercase(); if lower.contains("vision") || lower.contains("dall-e") || lower.contains("gpt-image") || lower.contains("gpt-4o") { "multimodal" } else if lower.contains("embedding") { "text" } else if lower.contains("whisper") || lower.contains("audio") { "audio" } else { "text" } } async fn upsert_provider( db: &db::AppDatabase, slug: &str, ) -> Result { let _display = provider_display_name(slug); let now = Utc::now(); if let Some(existing) = sqlx::query_as::<_, AiProviderModel>( "SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \ FROM ai_provider WHERE name = $1", ) .bind(slug) .fetch_optional(db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))? { sqlx::query( "UPDATE ai_provider SET updated_at = $1 WHERE id = $2", ) .bind(now) .bind(existing.id) .execute(db.writer()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok(existing) } else { let id = Uuid::now_v7(); sqlx::query( "INSERT INTO ai_provider (id, name, enabled, created_at, updated_at) \ VALUES ($1, $2, true, $3, $3)", ) .bind(id) .bind(slug) .bind(now) .execute(db.writer()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok(AiProviderModel { id, name: slug.to_string(), base_url: None, website_url: None, logo_url: None, enabled: true, created_at: now, updated_at: now, }) } } async fn upsert_model( db: &db::AppDatabase, provider_id: Uuid, model: &UpstreamModel, ) -> Result<(AiModelModel, bool), AppError> { let now = Utc::now(); let name = &model.id; let modality = infer_modality(model); let ctx = model.context_length.map(|c| c as i32); let max_out = model.max_output_tokens.map(|v| v as i32); if let Some(existing) = sqlx::query_as::<_, AiModelModel>( "SELECT id, provider, name, display_name, description, modality, \ context_window, input_token_limit, output_token_limit, \ enabled, public, created_at, updated_at, deleted_at \ FROM ai_model WHERE name = $1 AND deleted_at IS NULL", ) .bind(name) .fetch_optional(db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))? { sqlx::query( "UPDATE ai_model SET provider = $1, context_window = $2, \ output_token_limit = $3, enabled = true, updated_at = $4 \ WHERE id = $5", ) .bind(provider_id) .bind(ctx) .bind(max_out) .bind(now) .bind(existing.id) .execute(db.writer()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok((existing, false)) } else { let id = Uuid::now_v7(); sqlx::query( "INSERT INTO ai_model (id, provider, name, display_name, modality, \ context_window, input_token_limit, output_token_limit, enabled, \ public, created_at, updated_at) \ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, true, $9, $9)", ) .bind(id) .bind(provider_id) .bind(name) .bind(name) .bind(modality) .bind(ctx) .bind(ctx) .bind(max_out) .bind(now) .execute(db.writer()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; let inserted = sqlx::query_as::<_, AiModelModel>( "SELECT id, provider, name, display_name, description, modality, \ context_window, input_token_limit, output_token_limit, \ enabled, public, created_at, updated_at, deleted_at \ FROM ai_model WHERE id = $1", ) .bind(id) .fetch_one(db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok((inserted, true)) } } async fn upsert_version( db: &db::AppDatabase, model_id: Uuid, provider_model_name: &str, ) -> Result<(AiModelVersionModel, bool), AppError> { let now = Utc::now(); if let Some(existing) = sqlx::query_as::<_, AiModelVersionModel>( "SELECT id, model, version, provider_model_name, \ input_price_per_million, output_price_per_million, cached_input_price_per_million, \ training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \ FROM ai_model_version WHERE model = $1 AND provider_model_name = $2", ) .bind(model_id) .bind(provider_model_name) .fetch_optional(db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))? { Ok((existing, false)) } else { let id = Uuid::now_v7(); sqlx::query( "INSERT INTO ai_model_version (id, model, version, provider_model_name, \ enabled, created_at, updated_at) \ VALUES ($1, $2, 'latest', $3, true, $4, $4)", ) .bind(id) .bind(model_id) .bind(provider_model_name) .bind(now) .execute(db.writer()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; let inserted = sqlx::query_as::<_, AiModelVersionModel>( "SELECT id, model, version, provider_model_name, \ input_price_per_million, output_price_per_million, cached_input_price_per_million, \ training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \ FROM ai_model_version WHERE id = $1", ) .bind(id) .fetch_one(db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok((inserted, true)) } } async fn upsert_pricing( db: &db::AppDatabase, version_id: Uuid, pricing: Option<&UpstreamPricing>, ) -> Result { let Some(p) = pricing else { return Ok(PricingResult::Skipped); }; let input_million: Option = p.prompt.as_deref().and_then(parse_token_price_decimal) .map(|per_token| per_token * Decimal::from(1_000_000u64)) .or_else(|| { p.input .filter(|v| *v > 0.0) .map(|v| Decimal::try_from(v).unwrap_or_default()) }); let output_million: Option = p.completion.as_deref().and_then(parse_token_price_decimal) .map(|per_token| per_token * Decimal::from(1_000_000u64)) .or_else(|| { p.output .filter(|v| *v > 0.0) .map(|v| Decimal::try_from(v).unwrap_or_default()) }); let cache_input: Option = p.cache_read .filter(|v| *v > 0.0) .map(|v| Decimal::try_from(v).unwrap_or_default()); if input_million.is_none() && output_million.is_none() { return Ok(PricingResult::Skipped); } let existing = sqlx::query_scalar::<_, Uuid>( "SELECT id FROM ai_model_version WHERE id = $1", ) .bind(version_id) .fetch_optional(db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; if existing.is_none() { return Ok(PricingResult::Skipped); } let count = sqlx::query_scalar::<_, i64>( "SELECT COUNT(*) FROM ai_model_version \ WHERE id = $1 AND input_price_per_million IS NOT NULL", ) .bind(version_id) .fetch_one(db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; sqlx::query( "UPDATE ai_model_version SET \ input_price_per_million = COALESCE($1, input_price_per_million), \ output_price_per_million = COALESCE($2, output_price_per_million), \ cached_input_price_per_million = COALESCE($3, cached_input_price_per_million), \ updated_at = $4 \ WHERE id = $5", ) .bind(&input_million) .bind(&output_million) .bind(&cache_input) .bind(Utc::now()) .bind(version_id) .execute(db.writer()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; if count > 0 { Ok(PricingResult::Updated) } else { Ok(PricingResult::Created) } } enum PricingResult { Created, Updated, Skipped, } fn parse_token_price_decimal(s: &str) -> Option { use std::str::FromStr; Decimal::from_str(s).ok() } async fn disable_all_models(db: &db::AppDatabase) -> Result { let result = sqlx::query( "UPDATE ai_model SET enabled = false, updated_at = $1 WHERE enabled = true", ) .bind(Utc::now()) .execute(db.writer()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok(result.rows_affected() as i64) } async fn deactivate_orphaned_models(db: &db::AppDatabase) -> Result { let now = Utc::now(); sqlx::query( "UPDATE ai_model_version SET enabled = false, updated_at = $1 \ WHERE model IN (SELECT id FROM ai_model WHERE enabled = false AND deleted_at IS NULL)", ) .bind(now) .execute(db.writer()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; let result = sqlx::query( "UPDATE ai_model SET deleted_at = $1, updated_at = $1 \ WHERE enabled = false AND deleted_at IS NULL", ) .bind(now) .execute(db.writer()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))?; Ok(result.rows_affected() as i64) } async fn sync_models_from_upstream( db: &db::AppDatabase, upstream_models: Vec, ) -> SyncModelsResponse { let models_offline = disable_all_models(db).await.unwrap_or(0); tracing::info!( upstream_total = upstream_models.len(), "syncing models from upstream" ); let mut models_created = 0i64; let mut models_updated = 0i64; let mut versions_created = 0i64; let mut pricing_created = 0i64; let mut pricing_updated = 0i64; for model in &upstream_models { let provider_slug = extract_provider_name(model); let provider = match upsert_provider(db, &provider_slug).await { Ok(p) => p, Err(e) => { tracing::warn!( provider = %provider_slug, error = %e, "sync: upsert_provider error" ); continue; } }; let (model_record, _is_new) = match upsert_model(db, provider.id, model).await { Ok((m, created)) => { if created { models_created += 1; } else { models_updated += 1; } (m, created) } Err(e) => { tracing::warn!( model = %model.id, error = %e, "sync: upsert_model error" ); continue; } }; let (version_record, version_is_new) = match upsert_version(db, model_record.id, &model.id).await { Ok(v) => v, Err(e) => { tracing::warn!( model = %model.id, error = %e, "sync: upsert_version error" ); continue; } }; if version_is_new { versions_created += 1; } match upsert_pricing(db, version_record.id, model.pricing.as_ref()).await { Ok(PricingResult::Created) => pricing_created += 1, Ok(PricingResult::Updated) => pricing_updated += 1, Ok(PricingResult::Skipped) => {} Err(e) => { tracing::warn!( model = %model.id, error = %e, "sync: upsert_pricing error" ); } } } let deactivated = deactivate_orphaned_models(db).await.unwrap_or(0); SyncModelsResponse { models_created, models_updated, models_offline, models_deactivated: deactivated, versions_created, pricing_created, pricing_updated, } } impl AppService { pub async fn sync_upstream_models(&self) -> Result { let api_key = self .config .ai_api_key() .map_err(|e| AppError::InternalServerError(format!("AI API key not configured: {}", e)))?; let base_url = self.config.ai_basic_url().unwrap_or_default(); let config = EndpointConfig::new(&base_url, &api_key) .map_err(|e| AppError::InternalServerError(e.to_string()))?; let upstream_models = ai::sync::list_models(&config) .await .map_err(|e| AppError::InternalServerError(e.to_string()))?; tracing::info!( model_count = upstream_models.len(), "sync_upstream_models: {} models from upstream", upstream_models.len() ); let result = sync_models_from_upstream(&self.db, upstream_models).await; tracing::info!( models_created = result.models_created, models_updated = result.models_updated, versions_created = result.versions_created, pricing_created = result.pricing_created, pricing_updated = result.pricing_updated, "sync_upstream_models: complete" ); Ok(result) } } pub fn spawn_model_sync_loop(service: AppService) -> tokio::task::JoinHandle<()> { let db = service.db.clone(); let config = service.config.clone(); tokio::spawn(async move { sync_once(&db, &config).await; let mut tick = interval(Duration::from_secs(60 * 10)); loop { tick.tick().await; sync_once(&db, &config).await; } }) } async fn sync_once(db: &db::AppDatabase, config: &config::AppConfig) { let api_key = match config.ai_api_key() { Ok(k) => k, Err(e) => { tracing::warn!(error = %e, "Model sync: AI API key not configured"); return; } }; let base_url = config.ai_basic_url().unwrap_or_default(); let endpoint_config = match EndpointConfig::new(&base_url, &api_key) { Ok(c) => c, Err(e) => { tracing::warn!(error = %e, "Model sync: invalid endpoint config"); return; } }; let upstream_models = match ai::sync::list_models(&endpoint_config).await { Ok(models) => models, Err(e) => { tracing::warn!(error = %e, "Model sync: failed to list upstream models"); return; } }; tracing::info!( model_count = upstream_models.len(), "Model sync: {} models from upstream", upstream_models.len() ); let result = sync_models_from_upstream(db, upstream_models).await; tracing::info!( models_created = result.models_created, models_updated = result.models_updated, versions_created = result.versions_created, pricing_created = result.pricing_created, pricing_updated = result.pricing_updated, "Model sync complete" ); }