333 lines
10 KiB
Rust
333 lines
10 KiB
Rust
//! AI model management — CRUD.
|
|
//!
|
|
//! All functions take `&AppDatabase` instead of `&AppService`.
|
|
|
|
use chrono::Utc;
|
|
use db::database::AppDatabase;
|
|
use models::agents::model;
|
|
use models::agents::model_pricing;
|
|
use models::agents::model_version;
|
|
use models::agents::{
|
|
ModelCapability, ModelModality, ModelStatus,
|
|
model::{Column as MColumn, Entity as MEntity},
|
|
model_provider::Entity as ProviderEntity,
|
|
};
|
|
use sea_orm::*;
|
|
use uuid::Uuid;
|
|
|
|
use crate::error::AgentError;
|
|
|
|
#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)]
|
|
pub struct CreateModelRequest {
|
|
pub provider_id: Uuid,
|
|
pub name: String,
|
|
pub modality: String,
|
|
pub capability: String,
|
|
pub context_length: i64,
|
|
pub max_output_tokens: Option<i64>,
|
|
pub training_cutoff: Option<chrono::DateTime<Utc>>,
|
|
#[serde(default)]
|
|
pub is_open_source: bool,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)]
|
|
pub struct UpdateModelRequest {
|
|
pub display_name: Option<String>,
|
|
pub modality: Option<String>,
|
|
pub capability: Option<String>,
|
|
pub context_length: Option<i64>,
|
|
pub max_output_tokens: Option<i64>,
|
|
pub training_cutoff: Option<chrono::DateTime<Utc>>,
|
|
pub is_open_source: Option<bool>,
|
|
pub status: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)]
|
|
pub struct ModelResponse {
|
|
pub id: Uuid,
|
|
pub provider_id: Uuid,
|
|
pub name: String,
|
|
pub modality: String,
|
|
pub capability: String,
|
|
pub context_length: i64,
|
|
pub max_output_tokens: Option<i64>,
|
|
pub training_cutoff: Option<chrono::DateTime<Utc>>,
|
|
pub is_open_source: bool,
|
|
pub status: String,
|
|
pub created_at: chrono::DateTime<Utc>,
|
|
pub updated_at: chrono::DateTime<Utc>,
|
|
}
|
|
|
|
/// Builds the API view of a model by moving the exposed columns out of the
/// database row one-to-one.
impl From<model::Model> for ModelResponse {
    fn from(row: model::Model) -> Self {
        // Take the exposed columns in a single pattern; any column that is
        // not part of the response is discarded by the trailing `..`.
        let model::Model {
            id,
            provider_id,
            name,
            modality,
            capability,
            context_length,
            max_output_tokens,
            training_cutoff,
            is_open_source,
            status,
            created_at,
            updated_at,
            ..
        } = row;

        Self {
            id,
            provider_id,
            name,
            modality,
            capability,
            context_length,
            max_output_tokens,
            training_cutoff,
            is_open_source,
            status,
            created_at,
            updated_at,
        }
    }
}
|
|
|
|
/// List models, optionally filtered by provider.
|
|
pub async fn list_models(
|
|
db: &AppDatabase,
|
|
provider_id: Option<Uuid>,
|
|
) -> Result<Vec<ModelResponse>, AgentError> {
|
|
let mut query = MEntity::find().order_by_asc(MColumn::Name);
|
|
if let Some(pid) = provider_id {
|
|
query = query.filter(MColumn::ProviderId.eq(pid));
|
|
}
|
|
let models = query.all(db).await?;
|
|
Ok(models.into_iter().map(ModelResponse::from).collect())
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)]
|
|
pub struct ModelWithPricingResponse {
|
|
pub id: Uuid,
|
|
pub provider_id: Uuid,
|
|
pub name: String,
|
|
pub modality: String,
|
|
pub capability: String,
|
|
pub context_length: i64,
|
|
pub max_output_tokens: Option<i64>,
|
|
pub training_cutoff: Option<chrono::DateTime<Utc>>,
|
|
pub is_open_source: bool,
|
|
pub status: String,
|
|
pub input_price: Option<String>,
|
|
pub output_price: Option<String>,
|
|
pub currency: Option<String>,
|
|
pub created_at: chrono::DateTime<Utc>,
|
|
pub updated_at: chrono::DateTime<Utc>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)]
|
|
pub struct ModelListResponse {
|
|
pub data: Vec<ModelWithPricingResponse>,
|
|
pub total: u64,
|
|
pub page: u64,
|
|
pub per_page: u64,
|
|
}
|
|
|
|
/// List models with pricing, pagination, search, and deprecation filter.
|
|
pub async fn list_models_with_pricing(
|
|
db: &AppDatabase,
|
|
provider_id: Option<Uuid>,
|
|
search: Option<&str>,
|
|
page: u64,
|
|
per_page: u64,
|
|
) -> Result<ModelListResponse, AgentError> {
|
|
let mut query = MEntity::find()
|
|
.filter(MColumn::Status.ne("deprecated"))
|
|
.order_by_asc(MColumn::Name);
|
|
|
|
if let Some(pid) = provider_id {
|
|
query = query.filter(MColumn::ProviderId.eq(pid));
|
|
}
|
|
if let Some(q) = search {
|
|
if !q.is_empty() {
|
|
query = query.filter(MColumn::Name.contains(q));
|
|
}
|
|
}
|
|
|
|
let total = query.clone().count(db).await? as u64;
|
|
let offset = (page.saturating_sub(1)) * per_page;
|
|
let models = query.offset(offset).limit(per_page).all(db).await?;
|
|
|
|
// Batch-fetch default versions for these models
|
|
let model_ids: Vec<Uuid> = models.iter().map(|m| m.id).collect();
|
|
let versions = if model_ids.is_empty() {
|
|
vec![]
|
|
} else {
|
|
model_version::Entity::find()
|
|
.filter(model_version::Column::ModelId.is_in(model_ids))
|
|
.filter(model_version::Column::IsDefault.eq(true))
|
|
.all(db)
|
|
.await
|
|
.unwrap_or_default()
|
|
};
|
|
|
|
let version_ids: Vec<Uuid> = versions.iter().map(|v| v.id).collect();
|
|
let pricings = if version_ids.is_empty() {
|
|
vec![]
|
|
} else {
|
|
model_pricing::Entity::find()
|
|
.filter(model_pricing::Column::ModelVersionId.is_in(version_ids))
|
|
.all(db)
|
|
.await
|
|
.unwrap_or_default()
|
|
};
|
|
|
|
// Build lookup: model_id → latest pricing (by effective_from DESC)
|
|
let mut pricing_map: std::collections::HashMap<Uuid, &model_pricing::Model> =
|
|
std::collections::HashMap::new();
|
|
let version_to_model: std::collections::HashMap<Uuid, Uuid> =
|
|
versions.iter().map(|v| (v.id, v.model_id)).collect();
|
|
|
|
for p in &pricings {
|
|
if let Some(model_id) = version_to_model.get(&p.model_version_id) {
|
|
match pricing_map.get(model_id) {
|
|
Some(existing) => {
|
|
if p.effective_from > existing.effective_from {
|
|
pricing_map.insert(*model_id, p);
|
|
}
|
|
}
|
|
None => {
|
|
pricing_map.insert(*model_id, p);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let data = models
|
|
.into_iter()
|
|
.map(|m| {
|
|
let pricing = pricing_map.get(&m.id);
|
|
ModelWithPricingResponse {
|
|
id: m.id,
|
|
provider_id: m.provider_id,
|
|
name: m.name,
|
|
modality: m.modality,
|
|
capability: m.capability,
|
|
context_length: m.context_length,
|
|
max_output_tokens: m.max_output_tokens,
|
|
training_cutoff: m.training_cutoff,
|
|
is_open_source: m.is_open_source,
|
|
status: m.status,
|
|
input_price: pricing.map(|p| p.input_price_per_1k_tokens.clone()),
|
|
output_price: pricing.map(|p| p.output_price_per_1k_tokens.clone()),
|
|
currency: pricing.map(|p| p.currency.clone()),
|
|
created_at: m.created_at,
|
|
updated_at: m.updated_at,
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
Ok(ModelListResponse {
|
|
data,
|
|
total,
|
|
page,
|
|
per_page,
|
|
})
|
|
}
|
|
|
|
/// Get a single model by ID.
|
|
pub async fn get_model(db: &AppDatabase, id: Uuid) -> Result<ModelResponse, AgentError> {
|
|
let model = MEntity::find_by_id(id)
|
|
.one(db)
|
|
.await?
|
|
.ok_or_else(|| AgentError::NotFound(format!("Model not found: {}", id)))?;
|
|
Ok(ModelResponse::from(model))
|
|
}
|
|
|
|
/// Create a new model.
|
|
pub async fn create_model(
|
|
db: &AppDatabase,
|
|
request: CreateModelRequest,
|
|
) -> Result<ModelResponse, AgentError> {
|
|
ProviderEntity::find_by_id(request.provider_id)
|
|
.one(db)
|
|
.await?
|
|
.ok_or_else(|| AgentError::NotFound("Provider not found".to_string()))?;
|
|
|
|
let _ = request
|
|
.modality
|
|
.parse::<ModelModality>()
|
|
.map_err(|_| AgentError::InvalidInput {
|
|
field: "modality".into(),
|
|
reason: "Invalid modality".into(),
|
|
})?;
|
|
let _ =
|
|
request
|
|
.capability
|
|
.parse::<ModelCapability>()
|
|
.map_err(|_| AgentError::InvalidInput {
|
|
field: "capability".into(),
|
|
reason: "Invalid capability".into(),
|
|
})?;
|
|
|
|
let now = Utc::now();
|
|
let active = model::ActiveModel {
|
|
id: Set(Uuid::now_v7()),
|
|
provider_id: Set(request.provider_id),
|
|
name: Set(request.name),
|
|
modality: Set(request.modality),
|
|
capability: Set(request.capability),
|
|
context_length: Set(request.context_length),
|
|
max_output_tokens: Set(request.max_output_tokens),
|
|
training_cutoff: Set(request.training_cutoff),
|
|
is_open_source: Set(request.is_open_source),
|
|
status: Set(ModelStatus::Active.to_string()),
|
|
created_at: Set(now),
|
|
updated_at: Set(now),
|
|
..Default::default()
|
|
};
|
|
let model = active.insert(db).await?;
|
|
Ok(ModelResponse::from(model))
|
|
}
|
|
|
|
/// Update an existing model.
|
|
pub async fn update_model(
|
|
db: &AppDatabase,
|
|
id: Uuid,
|
|
request: UpdateModelRequest,
|
|
) -> Result<ModelResponse, AgentError> {
|
|
let model = MEntity::find_by_id(id)
|
|
.one(db)
|
|
.await?
|
|
.ok_or_else(|| AgentError::NotFound(format!("Model not found: {}", id)))?;
|
|
|
|
let mut active: model::ActiveModel = model.into();
|
|
if let Some(modality) = request.modality {
|
|
let _ = modality
|
|
.parse::<ModelModality>()
|
|
.map_err(|_| AgentError::InvalidInput {
|
|
field: "modality".into(),
|
|
reason: "Invalid modality".into(),
|
|
})?;
|
|
active.modality = Set(modality);
|
|
}
|
|
if let Some(capability) = request.capability {
|
|
let _ = capability
|
|
.parse::<ModelCapability>()
|
|
.map_err(|_| AgentError::InvalidInput {
|
|
field: "capability".into(),
|
|
reason: "Invalid capability".into(),
|
|
})?;
|
|
active.capability = Set(capability);
|
|
}
|
|
if let Some(context_length) = request.context_length {
|
|
active.context_length = Set(context_length);
|
|
}
|
|
if let Some(max_output_tokens) = request.max_output_tokens {
|
|
active.max_output_tokens = Set(Some(max_output_tokens));
|
|
}
|
|
if let Some(training_cutoff) = request.training_cutoff {
|
|
active.training_cutoff = Set(Some(training_cutoff));
|
|
}
|
|
if let Some(is_open_source) = request.is_open_source {
|
|
active.is_open_source = Set(is_open_source);
|
|
}
|
|
if let Some(status) = request.status {
|
|
active.status = Set(status);
|
|
}
|
|
active.updated_at = Set(Utc::now());
|
|
|
|
let model = active.update(db).await?;
|
|
Ok(ModelResponse::from(model))
|
|
}
|
|
|
|
/// Delete a model by ID.
|
|
pub async fn delete_model(db: &AppDatabase, id: Uuid) -> Result<(), AgentError> {
|
|
MEntity::delete_by_id(id).exec(db).await?;
|
|
Ok(())
|
|
}
|