//! AI model management — CRUD. //! //! All functions take `&DatabaseConnection` instead of `&AppService`. use chrono::Utc; use db::database::AppDatabase; use models::agents::model; use models::agents::model_pricing; use models::agents::model_version; use models::agents::{ ModelCapability, ModelModality, ModelStatus, model::{Column as MColumn, Entity as MEntity}, model_provider::Entity as ProviderEntity, }; use sea_orm::*; use uuid::Uuid; use crate::error::AgentError; #[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)] pub struct CreateModelRequest { pub provider_id: Uuid, pub name: String, pub modality: String, pub capability: String, pub context_length: i64, pub max_output_tokens: Option, pub training_cutoff: Option>, #[serde(default)] pub is_open_source: bool, } #[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)] pub struct UpdateModelRequest { pub display_name: Option, pub modality: Option, pub capability: Option, pub context_length: Option, pub max_output_tokens: Option, pub training_cutoff: Option>, pub is_open_source: Option, pub status: Option, } #[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] pub struct ModelResponse { pub id: Uuid, pub provider_id: Uuid, pub name: String, pub modality: String, pub capability: String, pub context_length: i64, pub max_output_tokens: Option, pub training_cutoff: Option>, pub is_open_source: bool, pub status: String, pub created_at: chrono::DateTime, pub updated_at: chrono::DateTime, } impl From for ModelResponse { fn from(m: model::Model) -> Self { Self { id: m.id, provider_id: m.provider_id, name: m.name, modality: m.modality, capability: m.capability, context_length: m.context_length, max_output_tokens: m.max_output_tokens, training_cutoff: m.training_cutoff, is_open_source: m.is_open_source, status: m.status, created_at: m.created_at, updated_at: m.updated_at, } } } /// List models, optionally filtered by provider. pub async fn list_models( db: &AppDatabase, provider_id: Option, ) -> Result, AgentError> { let mut query = MEntity::find().order_by_asc(MColumn::Name); if let Some(pid) = provider_id { query = query.filter(MColumn::ProviderId.eq(pid)); } let models = query.all(db).await?; Ok(models.into_iter().map(ModelResponse::from).collect()) } #[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] pub struct ModelWithPricingResponse { pub id: Uuid, pub provider_id: Uuid, pub name: String, pub modality: String, pub capability: String, pub context_length: i64, pub max_output_tokens: Option, pub training_cutoff: Option>, pub is_open_source: bool, pub status: String, pub input_price: Option, pub output_price: Option, pub currency: Option, pub created_at: chrono::DateTime, pub updated_at: chrono::DateTime, } #[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] pub struct ModelListResponse { pub data: Vec, pub total: u64, pub page: u64, pub per_page: u64, } /// List models with pricing, pagination, search, and deprecation filter. pub async fn list_models_with_pricing( db: &AppDatabase, provider_id: Option, search: Option<&str>, page: u64, per_page: u64, ) -> Result { let mut query = MEntity::find() .filter(MColumn::Status.ne("deprecated")) .order_by_asc(MColumn::Name); if let Some(pid) = provider_id { query = query.filter(MColumn::ProviderId.eq(pid)); } if let Some(q) = search { if !q.is_empty() { query = query.filter(MColumn::Name.contains(q)); } } let total = query.clone().count(db).await? as u64; let offset = (page.saturating_sub(1)) * per_page; let models = query .offset(offset) .limit(per_page) .all(db) .await?; // Batch-fetch default versions for these models let model_ids: Vec = models.iter().map(|m| m.id).collect(); let versions = if model_ids.is_empty() { vec![] } else { model_version::Entity::find() .filter(model_version::Column::ModelId.is_in(model_ids)) .filter(model_version::Column::IsDefault.eq(true)) .all(db) .await .unwrap_or_default() }; let version_ids: Vec = versions.iter().map(|v| v.id).collect(); let pricings = if version_ids.is_empty() { vec![] } else { model_pricing::Entity::find() .filter(model_pricing::Column::ModelVersionId.is_in(version_ids)) .all(db) .await .unwrap_or_default() }; // Build lookup: model_id → latest pricing (by effective_from DESC) let mut pricing_map: std::collections::HashMap = std::collections::HashMap::new(); let version_to_model: std::collections::HashMap = versions .iter() .map(|v| (v.id, v.model_id)) .collect(); for p in &pricings { if let Some(model_id) = version_to_model.get(&p.model_version_id) { match pricing_map.get(model_id) { Some(existing) => { if p.effective_from > existing.effective_from { pricing_map.insert(*model_id, p); } } None => { pricing_map.insert(*model_id, p); } } } } let data = models .into_iter() .map(|m| { let pricing = pricing_map.get(&m.id); ModelWithPricingResponse { id: m.id, provider_id: m.provider_id, name: m.name, modality: m.modality, capability: m.capability, context_length: m.context_length, max_output_tokens: m.max_output_tokens, training_cutoff: m.training_cutoff, is_open_source: m.is_open_source, status: m.status, input_price: pricing.map(|p| p.input_price_per_1k_tokens.clone()), output_price: pricing.map(|p| p.output_price_per_1k_tokens.clone()), currency: pricing.map(|p| p.currency.clone()), created_at: m.created_at, updated_at: m.updated_at, } }) .collect(); Ok(ModelListResponse { data, total, page, per_page, }) } /// Get a single model by ID. pub async fn get_model(db: &AppDatabase, id: Uuid) -> Result { let model = MEntity::find_by_id(id) .one(db) .await? .ok_or_else(|| AgentError::NotFound(format!("Model not found: {}", id)))?; Ok(ModelResponse::from(model)) } /// Create a new model. pub async fn create_model( db: &AppDatabase, request: CreateModelRequest, ) -> Result { ProviderEntity::find_by_id(request.provider_id) .one(db) .await? .ok_or_else(|| AgentError::NotFound("Provider not found".to_string()))?; let _ = request .modality .parse::() .map_err(|_| AgentError::InvalidInput { field: "modality".into(), reason: "Invalid modality".into(), })?; let _ = request .capability .parse::() .map_err(|_| AgentError::InvalidInput { field: "capability".into(), reason: "Invalid capability".into(), })?; let now = Utc::now(); let active = model::ActiveModel { id: Set(Uuid::now_v7()), provider_id: Set(request.provider_id), name: Set(request.name), modality: Set(request.modality), capability: Set(request.capability), context_length: Set(request.context_length), max_output_tokens: Set(request.max_output_tokens), training_cutoff: Set(request.training_cutoff), is_open_source: Set(request.is_open_source), status: Set(ModelStatus::Active.to_string()), created_at: Set(now), updated_at: Set(now), ..Default::default() }; let model = active.insert(db).await?; Ok(ModelResponse::from(model)) } /// Update an existing model. pub async fn update_model( db: &AppDatabase, id: Uuid, request: UpdateModelRequest, ) -> Result { let model = MEntity::find_by_id(id) .one(db) .await? .ok_or_else(|| AgentError::NotFound(format!("Model not found: {}", id)))?; let mut active: model::ActiveModel = model.into(); if let Some(modality) = request.modality { let _ = modality .parse::() .map_err(|_| AgentError::InvalidInput { field: "modality".into(), reason: "Invalid modality".into(), })?; active.modality = Set(modality); } if let Some(capability) = request.capability { let _ = capability .parse::() .map_err(|_| AgentError::InvalidInput { field: "capability".into(), reason: "Invalid capability".into(), })?; active.capability = Set(capability); } if let Some(context_length) = request.context_length { active.context_length = Set(context_length); } if let Some(max_output_tokens) = request.max_output_tokens { active.max_output_tokens = Set(Some(max_output_tokens)); } if let Some(training_cutoff) = request.training_cutoff { active.training_cutoff = Set(Some(training_cutoff)); } if let Some(is_open_source) = request.is_open_source { active.is_open_source = Set(is_open_source); } if let Some(status) = request.status { active.status = Set(status); } active.updated_at = Set(Utc::now()); let model = active.update(db).await?; Ok(ModelResponse::from(model)) } /// Delete a model by ID. pub async fn delete_model(db: &AppDatabase, id: Uuid) -> Result<(), AgentError> { MEntity::delete_by_id(id).exec(db).await?; Ok(()) }