604 lines
18 KiB
Rust
604 lines
18 KiB
Rust
use std::time::Duration;
|
|
|
|
use ai::client::EndpointConfig;
|
|
use ai::sync::{UpstreamModel, UpstreamPricing};
|
|
use chrono::Utc;
|
|
use db::sqlx::{self, types::Decimal};
|
|
use model::ai::{AiModelModel, AiModelVersionModel, AiProviderModel};
|
|
use serde::Serialize;
|
|
use tokio::time::interval;
|
|
use utoipa::ToSchema;
|
|
use uuid::Uuid;
|
|
|
|
use crate::{AppService, error::AppError};
|
|
|
|
#[derive(Debug, Clone, Serialize, ToSchema)]
|
|
pub struct SyncModelsResponse {
|
|
pub models_created: i64,
|
|
pub models_updated: i64,
|
|
pub models_offline: i64,
|
|
pub models_deactivated: i64,
|
|
pub versions_created: i64,
|
|
pub pricing_created: i64,
|
|
pub pricing_updated: i64,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct SyncResult {
|
|
pub provider_id: Uuid,
|
|
pub provider_name: String,
|
|
pub total: u32,
|
|
pub synced: u32,
|
|
pub skipped: u32,
|
|
}
|
|
|
|
fn extract_provider_name(model: &UpstreamModel) -> String {
|
|
if let Some(owned) = &model.owned_by {
|
|
if !owned.is_empty() {
|
|
return normalize_provider_name(owned);
|
|
}
|
|
}
|
|
normalize_provider_name(model.id.split('/').next().unwrap_or("unknown"))
|
|
}
|
|
|
|
/// Maps an upstream provider slug onto the canonical slug used locally.
///
/// Known aliases collapse onto a single name; every other slug (including the canonical
/// names themselves) is returned unchanged. Matching is exact and case-sensitive.
fn normalize_provider_name(slug: &str) -> String {
    // (alias, canonical) pairs. Slugs that are already canonical need no entry because
    // unmatched input falls through unchanged.
    const ALIASES: [(&str, &str); 6] = [
        ("google-ai", "google"),
        ("mistralai", "mistral"),
        ("meta-llama", "meta"),
        ("azure-openai", "azure"),
        ("x-ai", "xai"),
        ("alibaba", "qwen"),
    ];

    ALIASES
        .iter()
        .find(|(alias, _)| *alias == slug)
        .map_or(slug, |(_, canonical)| *canonical)
        .to_string()
}
|
|
|
|
/// Returns the human-readable label for a canonical provider slug.
///
/// Slugs without a known label are returned unchanged. Matching is exact and
/// case-sensitive.
fn provider_display_name(name: &str) -> String {
    // (canonical slug, display label)
    const DISPLAY_NAMES: [(&str, &str); 10] = [
        ("openai", "OpenAI"),
        ("anthropic", "Anthropic"),
        ("google", "Google DeepMind"),
        ("mistral", "Mistral AI"),
        ("meta", "Meta"),
        ("deepseek", "DeepSeek"),
        ("azure", "Microsoft Azure"),
        ("xai", "xAI"),
        ("moonshot", "Moonshot AI"),
        ("qwen", "Alibaba Qwen"),
    ];

    DISPLAY_NAMES
        .iter()
        .find(|entry| entry.0 == name)
        .map(|entry| entry.1)
        .unwrap_or(name)
        .to_string()
}
|
|
|
|
fn infer_modality(model: &UpstreamModel) -> &'static str {
|
|
if let Some(caps) = &model.capabilities {
|
|
if caps.vision == Some(true) {
|
|
return "multimodal";
|
|
}
|
|
}
|
|
let lower = model.id.to_lowercase();
|
|
if lower.contains("vision")
|
|
|| lower.contains("dall-e")
|
|
|| lower.contains("gpt-image")
|
|
|| lower.contains("gpt-4o")
|
|
{
|
|
"multimodal"
|
|
} else if lower.contains("embedding") {
|
|
"text"
|
|
} else if lower.contains("whisper") || lower.contains("audio") {
|
|
"audio"
|
|
} else {
|
|
"text"
|
|
}
|
|
}
|
|
|
|
async fn upsert_provider(
|
|
db: &db::AppDatabase,
|
|
slug: &str,
|
|
) -> Result<AiProviderModel, AppError> {
|
|
let _display = provider_display_name(slug);
|
|
let now = Utc::now();
|
|
|
|
if let Some(existing) = sqlx::query_as::<_, AiProviderModel>(
|
|
"SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \
|
|
FROM ai_provider WHERE name = $1",
|
|
)
|
|
.bind(slug)
|
|
.fetch_optional(db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
{
|
|
sqlx::query(
|
|
"UPDATE ai_provider SET updated_at = $1 WHERE id = $2",
|
|
)
|
|
.bind(now)
|
|
.bind(existing.id)
|
|
.execute(db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
Ok(existing)
|
|
} else {
|
|
let id = Uuid::now_v7();
|
|
sqlx::query(
|
|
"INSERT INTO ai_provider (id, name, enabled, created_at, updated_at) \
|
|
VALUES ($1, $2, true, $3, $3)",
|
|
)
|
|
.bind(id)
|
|
.bind(slug)
|
|
.bind(now)
|
|
.execute(db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(AiProviderModel {
|
|
id,
|
|
name: slug.to_string(),
|
|
base_url: None,
|
|
website_url: None,
|
|
logo_url: None,
|
|
enabled: true,
|
|
created_at: now,
|
|
updated_at: now,
|
|
})
|
|
}
|
|
}
|
|
|
|
async fn upsert_model(
|
|
db: &db::AppDatabase,
|
|
provider_id: Uuid,
|
|
model: &UpstreamModel,
|
|
) -> Result<(AiModelModel, bool), AppError> {
|
|
let now = Utc::now();
|
|
let name = &model.id;
|
|
let modality = infer_modality(model);
|
|
let ctx = model.context_length.map(|c| c as i32);
|
|
let max_out = model.max_output_tokens.map(|v| v as i32);
|
|
|
|
if let Some(existing) = sqlx::query_as::<_, AiModelModel>(
|
|
"SELECT id, provider, name, display_name, description, modality, \
|
|
context_window, input_token_limit, output_token_limit, \
|
|
enabled, public, created_at, updated_at, deleted_at \
|
|
FROM ai_model WHERE name = $1 AND deleted_at IS NULL",
|
|
)
|
|
.bind(name)
|
|
.fetch_optional(db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
{
|
|
sqlx::query(
|
|
"UPDATE ai_model SET provider = $1, context_window = $2, \
|
|
output_token_limit = $3, enabled = true, updated_at = $4 \
|
|
WHERE id = $5",
|
|
)
|
|
.bind(provider_id)
|
|
.bind(ctx)
|
|
.bind(max_out)
|
|
.bind(now)
|
|
.bind(existing.id)
|
|
.execute(db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok((existing, false))
|
|
} else {
|
|
let id = Uuid::now_v7();
|
|
sqlx::query(
|
|
"INSERT INTO ai_model (id, provider, name, display_name, modality, \
|
|
context_window, input_token_limit, output_token_limit, enabled, \
|
|
public, created_at, updated_at) \
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, true, $9, $9)",
|
|
)
|
|
.bind(id)
|
|
.bind(provider_id)
|
|
.bind(name)
|
|
.bind(name)
|
|
.bind(modality)
|
|
.bind(ctx)
|
|
.bind(ctx)
|
|
.bind(max_out)
|
|
.bind(now)
|
|
.execute(db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
let inserted = sqlx::query_as::<_, AiModelModel>(
|
|
"SELECT id, provider, name, display_name, description, modality, \
|
|
context_window, input_token_limit, output_token_limit, \
|
|
enabled, public, created_at, updated_at, deleted_at \
|
|
FROM ai_model WHERE id = $1",
|
|
)
|
|
.bind(id)
|
|
.fetch_one(db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok((inserted, true))
|
|
}
|
|
}
|
|
|
|
async fn upsert_version(
|
|
db: &db::AppDatabase,
|
|
model_id: Uuid,
|
|
provider_model_name: &str,
|
|
) -> Result<(AiModelVersionModel, bool), AppError> {
|
|
let now = Utc::now();
|
|
|
|
if let Some(existing) = sqlx::query_as::<_, AiModelVersionModel>(
|
|
"SELECT id, model, version, provider_model_name, \
|
|
input_price_per_million, output_price_per_million, cached_input_price_per_million, \
|
|
training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \
|
|
FROM ai_model_version WHERE model = $1 AND provider_model_name = $2",
|
|
)
|
|
.bind(model_id)
|
|
.bind(provider_model_name)
|
|
.fetch_optional(db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
{
|
|
Ok((existing, false))
|
|
} else {
|
|
let id = Uuid::now_v7();
|
|
sqlx::query(
|
|
"INSERT INTO ai_model_version (id, model, version, provider_model_name, \
|
|
enabled, created_at, updated_at) \
|
|
VALUES ($1, $2, 'latest', $3, true, $4, $4)",
|
|
)
|
|
.bind(id)
|
|
.bind(model_id)
|
|
.bind(provider_model_name)
|
|
.bind(now)
|
|
.execute(db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
let inserted = sqlx::query_as::<_, AiModelVersionModel>(
|
|
"SELECT id, model, version, provider_model_name, \
|
|
input_price_per_million, output_price_per_million, cached_input_price_per_million, \
|
|
training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \
|
|
FROM ai_model_version WHERE id = $1",
|
|
)
|
|
.bind(id)
|
|
.fetch_one(db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok((inserted, true))
|
|
}
|
|
}
|
|
|
|
async fn upsert_pricing(
|
|
db: &db::AppDatabase,
|
|
version_id: Uuid,
|
|
pricing: Option<&UpstreamPricing>,
|
|
) -> Result<PricingResult, AppError> {
|
|
let Some(p) = pricing else {
|
|
return Ok(PricingResult::Skipped);
|
|
};
|
|
let input_million: Option<Decimal> = p
|
|
.prompt
|
|
.as_deref()
|
|
.and_then(parse_token_price_decimal)
|
|
.map(|per_token| per_token * Decimal::from(1_000_000u64))
|
|
.or_else(|| {
|
|
p.input
|
|
.filter(|v| *v > 0.0)
|
|
.map(|v| Decimal::try_from(v).unwrap_or_default())
|
|
});
|
|
|
|
let output_million: Option<Decimal> = p
|
|
.completion
|
|
.as_deref()
|
|
.and_then(parse_token_price_decimal)
|
|
.map(|per_token| per_token * Decimal::from(1_000_000u64))
|
|
.or_else(|| {
|
|
p.output
|
|
.filter(|v| *v > 0.0)
|
|
.map(|v| Decimal::try_from(v).unwrap_or_default())
|
|
});
|
|
|
|
let cache_input: Option<Decimal> = p
|
|
.cache_read
|
|
.filter(|v| *v > 0.0)
|
|
.map(|v| Decimal::try_from(v).unwrap_or_default());
|
|
|
|
if input_million.is_none() && output_million.is_none() {
|
|
return Ok(PricingResult::Skipped);
|
|
}
|
|
|
|
let existing = sqlx::query_scalar::<_, Uuid>(
|
|
"SELECT id FROM ai_model_version WHERE id = $1",
|
|
)
|
|
.bind(version_id)
|
|
.fetch_optional(db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
if existing.is_none() {
|
|
return Ok(PricingResult::Skipped);
|
|
}
|
|
|
|
let count = sqlx::query_scalar::<_, i64>(
|
|
"SELECT COUNT(*) FROM ai_model_version \
|
|
WHERE id = $1 AND input_price_per_million IS NOT NULL",
|
|
)
|
|
.bind(version_id)
|
|
.fetch_one(db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
sqlx::query(
|
|
"UPDATE ai_model_version SET \
|
|
input_price_per_million = COALESCE($1, input_price_per_million), \
|
|
output_price_per_million = COALESCE($2, output_price_per_million), \
|
|
cached_input_price_per_million = COALESCE($3, cached_input_price_per_million), \
|
|
updated_at = $4 \
|
|
WHERE id = $5",
|
|
)
|
|
.bind(&input_million)
|
|
.bind(&output_million)
|
|
.bind(&cache_input)
|
|
.bind(Utc::now())
|
|
.bind(version_id)
|
|
.execute(db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
if count > 0 {
|
|
Ok(PricingResult::Updated)
|
|
} else {
|
|
Ok(PricingResult::Created)
|
|
}
|
|
}
|
|
|
|
/// Outcome of a single `upsert_pricing` call.
///
/// Derives `Debug`/`PartialEq` so the outcome can be logged and compared; it is a plain
/// fieldless enum, so `Copy` is free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PricingResult {
    /// Pricing was written to a version that had no input price yet.
    Created,
    /// Pricing was written to a version that already had an input price.
    Updated,
    /// Nothing was written.
    Skipped,
}
|
|
|
|
fn parse_token_price_decimal(s: &str) -> Option<Decimal> {
|
|
use std::str::FromStr;
|
|
Decimal::from_str(s).ok()
|
|
}
|
|
|
|
async fn disable_all_models(db: &db::AppDatabase) -> Result<i64, AppError> {
|
|
let result = sqlx::query(
|
|
"UPDATE ai_model SET enabled = false, updated_at = $1 WHERE enabled = true",
|
|
)
|
|
.bind(Utc::now())
|
|
.execute(db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(result.rows_affected() as i64)
|
|
}
|
|
|
|
async fn deactivate_orphaned_models(
|
|
db: &db::AppDatabase,
|
|
) -> Result<i64, AppError> {
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"UPDATE ai_model_version SET enabled = false, updated_at = $1 \
|
|
WHERE model IN (SELECT id FROM ai_model WHERE enabled = false AND deleted_at IS NULL)",
|
|
)
|
|
.bind(now)
|
|
.execute(db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
let result = sqlx::query(
|
|
"UPDATE ai_model SET deleted_at = $1, updated_at = $1 \
|
|
WHERE enabled = false AND deleted_at IS NULL",
|
|
)
|
|
.bind(now)
|
|
.execute(db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(result.rows_affected() as i64)
|
|
}
|
|
|
|
async fn sync_models_from_upstream(
|
|
db: &db::AppDatabase,
|
|
upstream_models: Vec<UpstreamModel>,
|
|
) -> SyncModelsResponse {
|
|
let models_offline = disable_all_models(db).await.unwrap_or(0);
|
|
|
|
tracing::info!(
|
|
upstream_total = upstream_models.len(),
|
|
"syncing models from upstream"
|
|
);
|
|
|
|
let mut models_created = 0i64;
|
|
let mut models_updated = 0i64;
|
|
let mut versions_created = 0i64;
|
|
let mut pricing_created = 0i64;
|
|
let mut pricing_updated = 0i64;
|
|
|
|
for model in &upstream_models {
|
|
let provider_slug = extract_provider_name(model);
|
|
let provider = match upsert_provider(db, &provider_slug).await {
|
|
Ok(p) => p,
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
provider = %provider_slug,
|
|
error = %e,
|
|
"sync: upsert_provider error"
|
|
);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let (model_record, _is_new) =
|
|
match upsert_model(db, provider.id, model).await {
|
|
Ok((m, created)) => {
|
|
if created {
|
|
models_created += 1;
|
|
} else {
|
|
models_updated += 1;
|
|
}
|
|
(m, created)
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
model = %model.id,
|
|
error = %e,
|
|
"sync: upsert_model error"
|
|
);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let (version_record, version_is_new) =
|
|
match upsert_version(db, model_record.id, &model.id).await {
|
|
Ok(v) => v,
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
model = %model.id,
|
|
error = %e,
|
|
"sync: upsert_version error"
|
|
);
|
|
continue;
|
|
}
|
|
};
|
|
if version_is_new {
|
|
versions_created += 1;
|
|
}
|
|
|
|
match upsert_pricing(db, version_record.id, model.pricing.as_ref())
|
|
.await
|
|
{
|
|
Ok(PricingResult::Created) => pricing_created += 1,
|
|
Ok(PricingResult::Updated) => pricing_updated += 1,
|
|
Ok(PricingResult::Skipped) => {}
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
model = %model.id,
|
|
error = %e,
|
|
"sync: upsert_pricing error"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
let deactivated = deactivate_orphaned_models(db).await.unwrap_or(0);
|
|
|
|
SyncModelsResponse {
|
|
models_created,
|
|
models_updated,
|
|
models_offline,
|
|
models_deactivated: deactivated,
|
|
versions_created,
|
|
pricing_created,
|
|
pricing_updated,
|
|
}
|
|
}
|
|
|
|
impl AppService {
|
|
pub async fn sync_upstream_models(
|
|
&self,
|
|
) -> Result<SyncModelsResponse, AppError> {
|
|
let api_key = self.config.ai_api_key().map_err(|e| {
|
|
AppError::InternalServerError(format!(
|
|
"AI API key not configured: {}",
|
|
e
|
|
))
|
|
})?;
|
|
|
|
let base_url = self.config.ai_basic_url().unwrap_or_default();
|
|
|
|
let config = EndpointConfig::new(&base_url, &api_key)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
let upstream_models = ai::sync::list_models(&config)
|
|
.await
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
tracing::info!(
|
|
model_count = upstream_models.len(),
|
|
"sync_upstream_models: {} models from upstream",
|
|
upstream_models.len()
|
|
);
|
|
|
|
let result = sync_models_from_upstream(&self.db, upstream_models).await;
|
|
|
|
tracing::info!(
|
|
models_created = result.models_created,
|
|
models_updated = result.models_updated,
|
|
versions_created = result.versions_created,
|
|
pricing_created = result.pricing_created,
|
|
pricing_updated = result.pricing_updated,
|
|
"sync_upstream_models: complete"
|
|
);
|
|
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
pub fn spawn_model_sync_loop(
|
|
service: AppService,
|
|
) -> tokio::task::JoinHandle<()> {
|
|
let db = service.db.clone();
|
|
let config = service.config.clone();
|
|
|
|
tokio::spawn(async move {
|
|
sync_once(&db, &config).await;
|
|
|
|
let mut tick = interval(Duration::from_secs(60 * 10));
|
|
loop {
|
|
tick.tick().await;
|
|
sync_once(&db, &config).await;
|
|
}
|
|
})
|
|
}
|
|
|
|
async fn sync_once(db: &db::AppDatabase, config: &config::AppConfig) {
|
|
let api_key = match config.ai_api_key() {
|
|
Ok(k) => k,
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "Model sync: AI API key not configured");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let base_url = config.ai_basic_url().unwrap_or_default();
|
|
|
|
let endpoint_config = match EndpointConfig::new(&base_url, &api_key) {
|
|
Ok(c) => c,
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "Model sync: invalid endpoint config");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let upstream_models = match ai::sync::list_models(&endpoint_config).await {
|
|
Ok(models) => models,
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "Model sync: failed to list upstream models");
|
|
return;
|
|
}
|
|
};
|
|
|
|
tracing::info!(
|
|
model_count = upstream_models.len(),
|
|
"Model sync: {} models from upstream",
|
|
upstream_models.len()
|
|
);
|
|
|
|
let result = sync_models_from_upstream(db, upstream_models).await;
|
|
|
|
tracing::info!(
|
|
models_created = result.models_created,
|
|
models_updated = result.models_updated,
|
|
versions_created = result.versions_created,
|
|
pricing_created = result.pricing_created,
|
|
pricing_updated = result.pricing_updated,
|
|
"Model sync complete"
|
|
);
|
|
}
|