gitdataai/lib/service/ai/sync.rs

604 lines
18 KiB
Rust

use std::time::Duration;
use ai::client::EndpointConfig;
use ai::sync::{UpstreamModel, UpstreamPricing};
use chrono::Utc;
use db::sqlx::{self, types::Decimal};
use model::ai::{AiModelModel, AiModelVersionModel, AiProviderModel};
use serde::Serialize;
use tokio::time::interval;
use utoipa::ToSchema;
use uuid::Uuid;
use crate::{AppService, error::AppError};
/// Counters summarising one model-catalog sync run.
///
/// Returned by `AppService::sync_upstream_models`. Derives `Serialize` and `ToSchema`
/// (presumably so an HTTP handler can return it and document it in OpenAPI — confirm).
#[derive(Debug, Clone, Serialize, ToSchema)]
pub struct SyncModelsResponse {
/// Upstream models inserted as new `ai_model` rows.
pub models_created: i64,
/// Upstream models that matched an existing non-deleted `ai_model` row by name and were updated.
pub models_updated: i64,
/// `ai_model` rows flipped from enabled to disabled at the start of the run, before
/// models still present upstream are re-enabled.
pub models_offline: i64,
/// Models soft-deleted (`deleted_at` set) at the end of the run because they were
/// still disabled after the upsert pass.
pub models_deactivated: i64,
/// New `ai_model_version` rows inserted.
pub versions_created: i64,
/// Versions whose pricing was written while they had no input price yet.
pub pricing_created: i64,
/// Versions whose pricing was written while an input price already existed.
pub pricing_updated: i64,
}
/// Appears to be a per-provider sync summary, judging by its fields.
///
/// NOTE(review): nothing in this file constructs or returns this type — confirm it is
/// still used elsewhere before extending it.
#[derive(Debug, Clone)]
pub struct SyncResult {
pub provider_id: Uuid,
pub provider_name: String,
// Presumably the number of upstream models seen for this provider — confirm.
pub total: u32,
pub synced: u32,
pub skipped: u32,
}
/// Derive the canonical provider slug for an upstream model.
///
/// Sources are tried in order:
/// 1. The `owned_by` field, if it is not blank.
/// 2. The vendor prefix of the model id (`"openai/gpt-4o"` → `"openai"`; an id without a
///    `/` is used whole).
///
/// Both are trimmed. If neither yields a non-empty slug, `"unknown"` is used. The result
/// is passed through `normalize_provider_name`.
fn extract_provider_name(model: &UpstreamModel) -> String {
    let owned_by = model
        .owned_by
        .as_deref()
        .map(str::trim)
        .filter(|s| !s.is_empty());
    // `split` always yields a first item, so an empty id (or a leading '/') must be
    // filtered explicitly for the "unknown" fallback to ever apply.
    let id_prefix = || {
        model
            .id
            .split('/')
            .next()
            .map(str::trim)
            .filter(|s| !s.is_empty())
    };
    normalize_provider_name(owned_by.or_else(id_prefix).unwrap_or("unknown"))
}
/// Map a vendor slug alias onto the canonical provider slug.
///
/// Only aliases that differ from their canonical form need an arm.
/// Canonical slugs (`openai`, `anthropic`, `google`, `mistral`, `meta`, `deepseek`,
/// `azure`, `xai`, `moonshot`, `qwen`) pass through unchanged, as does any slug not
/// listed here. Matching is exact and case-sensitive.
fn normalize_provider_name(slug: &str) -> String {
    let canonical = match slug {
        "google-ai" => "google",
        "mistralai" => "mistral",
        "meta-llama" => "meta",
        "azure-openai" => "azure",
        "x-ai" => "xai",
        "alibaba" => "qwen",
        unchanged => unchanged,
    };
    String::from(canonical)
}
/// Human-readable label for a canonical provider slug.
///
/// Slugs without a known label are returned unchanged.
fn provider_display_name(name: &str) -> String {
    const LABELS: [(&str, &str); 10] = [
        ("openai", "OpenAI"),
        ("anthropic", "Anthropic"),
        ("google", "Google DeepMind"),
        ("mistral", "Mistral AI"),
        ("meta", "Meta"),
        ("deepseek", "DeepSeek"),
        ("azure", "Microsoft Azure"),
        ("xai", "xAI"),
        ("moonshot", "Moonshot AI"),
        ("qwen", "Alibaba Qwen"),
    ];
    let label: &str = LABELS
        .iter()
        .find(|(slug, _)| *slug == name)
        .map_or(name, |(_, label)| *label);
    label.to_owned()
}
/// Guess a model's modality (`"multimodal"`, `"audio"` or `"text"`).
///
/// The model is `"multimodal"` if upstream reports vision support or the id contains a
/// known multimodal marker. Otherwise the lower-cased id is checked in this order:
/// - contains `"embedding"` → `"text"`;
/// - contains `"whisper"` or `"audio"` → `"audio"`;
/// - anything else → `"text"`.
fn infer_modality(model: &UpstreamModel) -> &'static str {
    if matches!(&model.capabilities, Some(caps) if caps.vision == Some(true)) {
        return "multimodal";
    }
    const MULTIMODAL_MARKERS: [&str; 4] = ["vision", "dall-e", "gpt-image", "gpt-4o"];
    let id = model.id.to_lowercase();
    if MULTIMODAL_MARKERS.iter().any(|marker| id.contains(*marker)) {
        return "multimodal";
    }
    // Checked before the audio markers so an embedding id that also mentions audio
    // is still classified as text.
    if id.contains("embedding") {
        return "text";
    }
    if id.contains("whisper") || id.contains("audio") {
        return "audio";
    }
    "text"
}
async fn upsert_provider(
db: &db::AppDatabase,
slug: &str,
) -> Result<AiProviderModel, AppError> {
let _display = provider_display_name(slug);
let now = Utc::now();
if let Some(existing) = sqlx::query_as::<_, AiProviderModel>(
"SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \
FROM ai_provider WHERE name = $1",
)
.bind(slug)
.fetch_optional(db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?
{
sqlx::query(
"UPDATE ai_provider SET updated_at = $1 WHERE id = $2",
)
.bind(now)
.bind(existing.id)
.execute(db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
Ok(existing)
} else {
let id = Uuid::now_v7();
sqlx::query(
"INSERT INTO ai_provider (id, name, enabled, created_at, updated_at) \
VALUES ($1, $2, true, $3, $3)",
)
.bind(id)
.bind(slug)
.bind(now)
.execute(db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
Ok(AiProviderModel {
id,
name: slug.to_string(),
base_url: None,
website_url: None,
logo_url: None,
enabled: true,
created_at: now,
updated_at: now,
})
}
}
/// Insert or refresh the `ai_model` row for an upstream model, matched by name among
/// non-deleted rows.
///
/// Existing rows get `provider`, `enabled = true` and `updated_at` overwritten. Token
/// limits are refreshed only when upstream supplies them (`COALESCE`, as
/// `upsert_pricing` does), so a response that omits a limit does not erase the stored
/// value.
///
/// New rows store:
/// - the upstream id as both `name` and `display_name`;
/// - the inferred modality;
/// - `context_window` also as `input_token_limit`;
/// - `enabled = true` and `public = true`.
///
/// Returns the row as stored after the write, plus `true` if it was inserted. All
/// statements use the writer pool so the returned row is read from the same place it was
/// written.
///
/// # Errors
/// Returns `AppError::DatabaseError` if any statement fails.
async fn upsert_model(
    db: &db::AppDatabase,
    provider_id: Uuid,
    model: &UpstreamModel,
) -> Result<(AiModelModel, bool), AppError> {
    let now = Utc::now();
    let name = &model.id;
    let modality = infer_modality(model);
    // Limits are bound as i32. Out-of-range upstream values become NULL instead of
    // silently wrapping as an `as` cast would.
    let ctx: Option<i32> = model.context_length.and_then(|c| i32::try_from(c).ok());
    let max_out: Option<i32> = model
        .max_output_tokens
        .and_then(|v| i32::try_from(v).ok());

    let existing_id = sqlx::query_scalar::<_, Uuid>(
        "SELECT id FROM ai_model WHERE name = $1 AND deleted_at IS NULL",
    )
    .bind(name)
    .fetch_optional(db.writer())
    .await
    .map_err(|e| AppError::DatabaseError(e.to_string()))?;

    match existing_id {
        Some(id) => {
            let updated = sqlx::query_as::<_, AiModelModel>(
                "UPDATE ai_model SET provider = $1, \
                 context_window = COALESCE($2, context_window), \
                 output_token_limit = COALESCE($3, output_token_limit), \
                 enabled = true, updated_at = $4 \
                 WHERE id = $5 \
                 RETURNING id, provider, name, display_name, description, modality, \
                 context_window, input_token_limit, output_token_limit, \
                 enabled, public, created_at, updated_at, deleted_at",
            )
            .bind(provider_id)
            .bind(ctx)
            .bind(max_out)
            .bind(now)
            .bind(id)
            .fetch_one(db.writer())
            .await
            .map_err(|e| AppError::DatabaseError(e.to_string()))?;
            Ok((updated, false))
        }
        None => {
            let inserted = sqlx::query_as::<_, AiModelModel>(
                "INSERT INTO ai_model (id, provider, name, display_name, modality, \
                 context_window, input_token_limit, output_token_limit, enabled, \
                 public, created_at, updated_at) \
                 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, true, $9, $9) \
                 RETURNING id, provider, name, display_name, description, modality, \
                 context_window, input_token_limit, output_token_limit, \
                 enabled, public, created_at, updated_at, deleted_at",
            )
            .bind(Uuid::now_v7())
            .bind(provider_id)
            .bind(name)
            .bind(name)
            .bind(modality)
            .bind(ctx)
            .bind(ctx)
            .bind(max_out)
            .bind(now)
            .fetch_one(db.writer())
            .await
            .map_err(|e| AppError::DatabaseError(e.to_string()))?;
            Ok((inserted, true))
        }
    }
}
/// Ensure an `ai_model_version` row exists for `model_id` / `provider_model_name`.
///
/// Existing versions are returned unchanged with `false`. Otherwise a new enabled
/// `'latest'` version is inserted and returned with `true`.
///
/// Statements use the writer pool. The insert returns the stored row via `RETURNING`,
/// instead of re-reading it from the reader pool, where replica lag could make a
/// `fetch_one` fail with "row not found".
///
/// # Errors
/// Returns `AppError::DatabaseError` if any statement fails.
async fn upsert_version(
    db: &db::AppDatabase,
    model_id: Uuid,
    provider_model_name: &str,
) -> Result<(AiModelVersionModel, bool), AppError> {
    let existing = sqlx::query_as::<_, AiModelVersionModel>(
        "SELECT id, model, version, provider_model_name, \
         input_price_per_million, output_price_per_million, cached_input_price_per_million, \
         training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \
         FROM ai_model_version WHERE model = $1 AND provider_model_name = $2",
    )
    .bind(model_id)
    .bind(provider_model_name)
    .fetch_optional(db.writer())
    .await
    .map_err(|e| AppError::DatabaseError(e.to_string()))?;
    if let Some(version) = existing {
        return Ok((version, false));
    }

    let now = Utc::now();
    let inserted = sqlx::query_as::<_, AiModelVersionModel>(
        "INSERT INTO ai_model_version (id, model, version, provider_model_name, \
         enabled, created_at, updated_at) \
         VALUES ($1, $2, 'latest', $3, true, $4, $4) \
         RETURNING id, model, version, provider_model_name, \
         input_price_per_million, output_price_per_million, cached_input_price_per_million, \
         training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at",
    )
    .bind(Uuid::now_v7())
    .bind(model_id)
    .bind(provider_model_name)
    .bind(now)
    .fetch_one(db.writer())
    .await
    .map_err(|e| AppError::DatabaseError(e.to_string()))?;
    Ok((inserted, true))
}
async fn upsert_pricing(
db: &db::AppDatabase,
version_id: Uuid,
pricing: Option<&UpstreamPricing>,
) -> Result<PricingResult, AppError> {
let Some(p) = pricing else {
return Ok(PricingResult::Skipped);
};
let input_million: Option<Decimal> = p
.prompt
.as_deref()
.and_then(parse_token_price_decimal)
.map(|per_token| per_token * Decimal::from(1_000_000u64))
.or_else(|| {
p.input
.filter(|v| *v > 0.0)
.map(|v| Decimal::try_from(v).unwrap_or_default())
});
let output_million: Option<Decimal> = p
.completion
.as_deref()
.and_then(parse_token_price_decimal)
.map(|per_token| per_token * Decimal::from(1_000_000u64))
.or_else(|| {
p.output
.filter(|v| *v > 0.0)
.map(|v| Decimal::try_from(v).unwrap_or_default())
});
let cache_input: Option<Decimal> = p
.cache_read
.filter(|v| *v > 0.0)
.map(|v| Decimal::try_from(v).unwrap_or_default());
if input_million.is_none() && output_million.is_none() {
return Ok(PricingResult::Skipped);
}
let existing = sqlx::query_scalar::<_, Uuid>(
"SELECT id FROM ai_model_version WHERE id = $1",
)
.bind(version_id)
.fetch_optional(db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
if existing.is_none() {
return Ok(PricingResult::Skipped);
}
let count = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM ai_model_version \
WHERE id = $1 AND input_price_per_million IS NOT NULL",
)
.bind(version_id)
.fetch_one(db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
sqlx::query(
"UPDATE ai_model_version SET \
input_price_per_million = COALESCE($1, input_price_per_million), \
output_price_per_million = COALESCE($2, output_price_per_million), \
cached_input_price_per_million = COALESCE($3, cached_input_price_per_million), \
updated_at = $4 \
WHERE id = $5",
)
.bind(&input_million)
.bind(&output_million)
.bind(&cache_input)
.bind(Utc::now())
.bind(version_id)
.execute(db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
if count > 0 {
Ok(PricingResult::Updated)
} else {
Ok(PricingResult::Created)
}
}
/// Outcome of `upsert_pricing` for one model version.
enum PricingResult {
// Pricing was written to a version that had no input price before.
Created,
// Pricing was written to a version that already had an input price.
Updated,
// Nothing was written (no usable pricing, or the version row was not found).
Skipped,
}
fn parse_token_price_decimal(s: &str) -> Option<Decimal> {
use std::str::FromStr;
Decimal::from_str(s).ok()
}
/// Mark every currently enabled `ai_model` row as disabled and bump its `updated_at`.
///
/// Returns the number of rows that were flipped.
///
/// # Errors
/// Returns `AppError::DatabaseError` if the update fails.
async fn disable_all_models(db: &db::AppDatabase) -> Result<i64, AppError> {
    let stamp = Utc::now();
    sqlx::query("UPDATE ai_model SET enabled = false, updated_at = $1 WHERE enabled = true")
        .bind(stamp)
        .execute(db.writer())
        .await
        .map(|outcome| outcome.rows_affected() as i64)
        .map_err(|e| AppError::DatabaseError(e.to_string()))
}
async fn deactivate_orphaned_models(
db: &db::AppDatabase,
) -> Result<i64, AppError> {
let now = Utc::now();
sqlx::query(
"UPDATE ai_model_version SET enabled = false, updated_at = $1 \
WHERE model IN (SELECT id FROM ai_model WHERE enabled = false AND deleted_at IS NULL)",
)
.bind(now)
.execute(db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
let result = sqlx::query(
"UPDATE ai_model SET deleted_at = $1, updated_at = $1 \
WHERE enabled = false AND deleted_at IS NULL",
)
.bind(now)
.execute(db.writer())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
Ok(result.rows_affected() as i64)
}
async fn sync_models_from_upstream(
db: &db::AppDatabase,
upstream_models: Vec<UpstreamModel>,
) -> SyncModelsResponse {
let models_offline = disable_all_models(db).await.unwrap_or(0);
tracing::info!(
upstream_total = upstream_models.len(),
"syncing models from upstream"
);
let mut models_created = 0i64;
let mut models_updated = 0i64;
let mut versions_created = 0i64;
let mut pricing_created = 0i64;
let mut pricing_updated = 0i64;
for model in &upstream_models {
let provider_slug = extract_provider_name(model);
let provider = match upsert_provider(db, &provider_slug).await {
Ok(p) => p,
Err(e) => {
tracing::warn!(
provider = %provider_slug,
error = %e,
"sync: upsert_provider error"
);
continue;
}
};
let (model_record, _is_new) =
match upsert_model(db, provider.id, model).await {
Ok((m, created)) => {
if created {
models_created += 1;
} else {
models_updated += 1;
}
(m, created)
}
Err(e) => {
tracing::warn!(
model = %model.id,
error = %e,
"sync: upsert_model error"
);
continue;
}
};
let (version_record, version_is_new) =
match upsert_version(db, model_record.id, &model.id).await {
Ok(v) => v,
Err(e) => {
tracing::warn!(
model = %model.id,
error = %e,
"sync: upsert_version error"
);
continue;
}
};
if version_is_new {
versions_created += 1;
}
match upsert_pricing(db, version_record.id, model.pricing.as_ref())
.await
{
Ok(PricingResult::Created) => pricing_created += 1,
Ok(PricingResult::Updated) => pricing_updated += 1,
Ok(PricingResult::Skipped) => {}
Err(e) => {
tracing::warn!(
model = %model.id,
error = %e,
"sync: upsert_pricing error"
);
}
}
}
let deactivated = deactivate_orphaned_models(db).await.unwrap_or(0);
SyncModelsResponse {
models_created,
models_updated,
models_offline,
models_deactivated: deactivated,
versions_created,
pricing_created,
pricing_updated,
}
}
impl AppService {
    /// Fetch the upstream model list and reconcile the local catalog with it.
    ///
    /// # Errors
    /// Returns `AppError::InternalServerError` when:
    /// - the AI API key is not configured;
    /// - the endpoint configuration is rejected;
    /// - listing upstream models fails.
    ///
    /// Database errors during reconciliation are logged per model and do not fail the call.
    pub async fn sync_upstream_models(
        &self,
    ) -> Result<SyncModelsResponse, AppError> {
        let api_key = self.config.ai_api_key().map_err(|e| {
            AppError::InternalServerError(format!(
                "AI API key not configured: {}",
                e
            ))
        })?;
        let base_url = self.config.ai_basic_url().unwrap_or_default();
        let config = EndpointConfig::new(&base_url, &api_key)
            .map_err(|e| AppError::InternalServerError(e.to_string()))?;
        let upstream_models = ai::sync::list_models(&config)
            .await
            .map_err(|e| AppError::InternalServerError(e.to_string()))?;
        tracing::info!(
            model_count = upstream_models.len(),
            "sync_upstream_models: {} models from upstream",
            upstream_models.len()
        );
        let result = sync_models_from_upstream(&self.db, upstream_models).await;
        // Include offline/deactivated counts: they show how much of the catalog this run
        // disabled or soft-deleted.
        tracing::info!(
            models_created = result.models_created,
            models_updated = result.models_updated,
            models_offline = result.models_offline,
            models_deactivated = result.models_deactivated,
            versions_created = result.versions_created,
            pricing_created = result.pricing_created,
            pricing_updated = result.pricing_updated,
            "sync_upstream_models: complete"
        );
        Ok(result)
    }
}
/// Spawn the background task that keeps the model catalog in sync with upstream.
///
/// The first sync runs immediately after spawning, then every ten minutes. If a sync
/// overruns the interval, the next one is delayed rather than fired in a catch-up burst.
/// The returned handle belongs to a task that loops forever.
pub fn spawn_model_sync_loop(
    service: AppService,
) -> tokio::task::JoinHandle<()> {
    const SYNC_INTERVAL: Duration = Duration::from_secs(60 * 10);
    let db = service.db.clone();
    let config = service.config.clone();
    tokio::spawn(async move {
        let mut tick = interval(SYNC_INTERVAL);
        tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        loop {
            // `interval`'s first tick completes immediately, so the first iteration is the
            // startup sync. An extra explicit call before the loop would run it twice.
            tick.tick().await;
            sync_once(&db, &config).await;
        }
    })
}
/// Run one best-effort background sync using the AI endpoint settings from `config`.
///
/// A missing API key, a rejected endpoint config, or a failure to list upstream models
/// is logged as a warning and ends this run only.
async fn sync_once(db: &db::AppDatabase, config: &config::AppConfig) {
    let api_key = match config.ai_api_key() {
        Ok(k) => k,
        Err(e) => {
            tracing::warn!(error = %e, "Model sync: AI API key not configured");
            return;
        }
    };
    let base_url = config.ai_basic_url().unwrap_or_default();
    let endpoint_config = match EndpointConfig::new(&base_url, &api_key) {
        Ok(c) => c,
        Err(e) => {
            tracing::warn!(error = %e, "Model sync: invalid endpoint config");
            return;
        }
    };
    let upstream_models = match ai::sync::list_models(&endpoint_config).await {
        Ok(models) => models,
        Err(e) => {
            tracing::warn!(error = %e, "Model sync: failed to list upstream models");
            return;
        }
    };
    tracing::info!(
        model_count = upstream_models.len(),
        "Model sync: {} models from upstream",
        upstream_models.len()
    );
    let result = sync_models_from_upstream(db, upstream_models).await;
    // Include offline/deactivated counts: they show how much of the catalog this run
    // disabled or soft-deleted.
    tracing::info!(
        models_created = result.models_created,
        models_updated = result.models_updated,
        models_offline = result.models_offline,
        models_deactivated = result.models_deactivated,
        versions_created = result.versions_created,
        pricing_created = result.pricing_created,
        pricing_updated = result.pricing_updated,
        "Model sync complete"
    );
}