gitdataai/libs/agent/model/model_entry.rs

333 lines
10 KiB
Rust

//! AI model management — CRUD.
//!
//! All functions take a database handle (`&AppDatabase`) instead of `&AppService`.
use chrono::Utc;
use db::database::AppDatabase;
use models::agents::model;
use models::agents::model_pricing;
use models::agents::model_version;
use models::agents::{
ModelCapability, ModelModality, ModelStatus,
model::{Column as MColumn, Entity as MEntity},
model_provider::Entity as ProviderEntity,
};
use sea_orm::*;
use uuid::Uuid;
use crate::error::AgentError;
/// Request payload for registering a new model under an existing provider.
///
/// `modality` and `capability` are carried as plain strings; `create_model`
/// rejects values that do not parse as `ModelModality` / `ModelCapability`.
#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)]
pub struct CreateModelRequest {
    /// Identifier of the provider the model belongs to.
    pub provider_id: Uuid,
    /// Name stored for the model.
    pub name: String,
    /// Textual form of a `ModelModality`.
    pub modality: String,
    /// Textual form of a `ModelCapability`.
    pub capability: String,
    /// Context length of the model.
    pub context_length: i64,
    /// Optional cap on output tokens.
    pub max_output_tokens: Option<i64>,
    /// Optional training cutoff timestamp.
    pub training_cutoff: Option<chrono::DateTime<Utc>>,
    /// Whether the model is open source; `false` when the field is omitted.
    #[serde(default)]
    pub is_open_source: bool,
}
/// Partial-update payload for an existing model.
///
/// Every field is optional. `update_model` only writes the fields that are
/// `Some`, so an absent field leaves the stored value untouched.
#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)]
pub struct UpdateModelRequest {
    // NOTE(review): `update_model` in this file never reads `display_name`, so
    // sending it currently has no effect — confirm the intended target column.
    /// Display name for the model.
    pub display_name: Option<String>,
    /// Textual form of a `ModelModality`.
    pub modality: Option<String>,
    /// Textual form of a `ModelCapability`.
    pub capability: Option<String>,
    /// New context length.
    pub context_length: Option<i64>,
    /// New cap on output tokens.
    pub max_output_tokens: Option<i64>,
    /// New training cutoff timestamp.
    pub training_cutoff: Option<chrono::DateTime<Utc>>,
    /// New open-source flag.
    pub is_open_source: Option<bool>,
    /// New status value.
    pub status: Option<String>,
}
/// API representation of a stored model.
///
/// Built from a `model::Model` row by copying the fields listed here one to one.
#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)]
pub struct ModelResponse {
    /// Model identifier.
    pub id: Uuid,
    /// Identifier of the owning provider.
    pub provider_id: Uuid,
    /// Model name.
    pub name: String,
    /// Modality, as stored.
    pub modality: String,
    /// Capability, as stored.
    pub capability: String,
    /// Context length.
    pub context_length: i64,
    /// Cap on output tokens, if one is stored.
    pub max_output_tokens: Option<i64>,
    /// Training cutoff timestamp, if one is stored.
    pub training_cutoff: Option<chrono::DateTime<Utc>>,
    /// Whether the model is flagged as open source.
    pub is_open_source: bool,
    /// Status, as stored.
    pub status: String,
    /// Creation timestamp of the row.
    pub created_at: chrono::DateTime<Utc>,
    /// Timestamp of the last update to the row.
    pub updated_at: chrono::DateTime<Utc>,
}
/// Converts a persisted model row into its API representation.
impl From<model::Model> for ModelResponse {
    /// Moves each exposed column out of the row; any other column is dropped.
    fn from(m: model::Model) -> Self {
        let model::Model {
            id,
            provider_id,
            name,
            modality,
            capability,
            context_length,
            max_output_tokens,
            training_cutoff,
            is_open_source,
            status,
            created_at,
            updated_at,
            ..
        } = m;

        Self {
            id,
            provider_id,
            name,
            modality,
            capability,
            context_length,
            max_output_tokens,
            training_cutoff,
            is_open_source,
            status,
            created_at,
            updated_at,
        }
    }
}
/// Returns all models ordered by name, optionally restricted to one provider.
///
/// # Arguments
/// * `db` - database handle the query runs against.
/// * `provider_id` - when `Some`, only models of that provider are returned.
///
/// # Errors
/// Propagates any error raised by the database query.
pub async fn list_models(
    db: &AppDatabase,
    provider_id: Option<Uuid>,
) -> Result<Vec<ModelResponse>, AgentError> {
    let ordered = MEntity::find().order_by_asc(MColumn::Name);
    let select = match provider_id {
        Some(pid) => ordered.filter(MColumn::ProviderId.eq(pid)),
        None => ordered,
    };

    let rows = select.all(db).await?;

    let mut responses: Vec<ModelResponse> = Vec::with_capacity(rows.len());
    for row in rows {
        responses.push(ModelResponse::from(row));
    }
    Ok(responses)
}
/// A model together with the pricing selected for it by
/// `list_models_with_pricing`.
///
/// The three pricing fields are `None` when no pricing row was found for the
/// model.
#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)]
pub struct ModelWithPricingResponse {
    /// Model identifier.
    pub id: Uuid,
    /// Identifier of the owning provider.
    pub provider_id: Uuid,
    /// Model name.
    pub name: String,
    /// Modality, as stored.
    pub modality: String,
    /// Capability, as stored.
    pub capability: String,
    /// Context length.
    pub context_length: i64,
    /// Cap on output tokens, if one is stored.
    pub max_output_tokens: Option<i64>,
    /// Training cutoff timestamp, if one is stored.
    pub training_cutoff: Option<chrono::DateTime<Utc>>,
    /// Whether the model is flagged as open source.
    pub is_open_source: bool,
    /// Status, as stored.
    pub status: String,
    /// Input price per 1k tokens, taken from the selected pricing row.
    pub input_price: Option<String>,
    /// Output price per 1k tokens, taken from the selected pricing row.
    pub output_price: Option<String>,
    /// Currency of the selected pricing row.
    pub currency: Option<String>,
    /// Creation timestamp of the model row.
    pub created_at: chrono::DateTime<Utc>,
    /// Timestamp of the last update to the model row.
    pub updated_at: chrono::DateTime<Utc>,
}
/// One page of models with pricing, plus the paging information it was
/// produced with.
#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)]
pub struct ModelListResponse {
    /// Models on this page.
    pub data: Vec<ModelWithPricingResponse>,
    /// Number of models matching the filters across all pages.
    pub total: u64,
    /// Page number that was requested.
    pub page: u64,
    /// Page size that was requested.
    pub per_page: u64,
}
/// List models with pricing, pagination, search, and deprecation filter.
pub async fn list_models_with_pricing(
db: &AppDatabase,
provider_id: Option<Uuid>,
search: Option<&str>,
page: u64,
per_page: u64,
) -> Result<ModelListResponse, AgentError> {
let mut query = MEntity::find()
.filter(MColumn::Status.ne("deprecated"))
.order_by_asc(MColumn::Name);
if let Some(pid) = provider_id {
query = query.filter(MColumn::ProviderId.eq(pid));
}
if let Some(q) = search {
if !q.is_empty() {
query = query.filter(MColumn::Name.contains(q));
}
}
let total = query.clone().count(db).await? as u64;
let offset = (page.saturating_sub(1)) * per_page;
let models = query.offset(offset).limit(per_page).all(db).await?;
// Batch-fetch default versions for these models
let model_ids: Vec<Uuid> = models.iter().map(|m| m.id).collect();
let versions = if model_ids.is_empty() {
vec![]
} else {
model_version::Entity::find()
.filter(model_version::Column::ModelId.is_in(model_ids))
.filter(model_version::Column::IsDefault.eq(true))
.all(db)
.await
.unwrap_or_default()
};
let version_ids: Vec<Uuid> = versions.iter().map(|v| v.id).collect();
let pricings = if version_ids.is_empty() {
vec![]
} else {
model_pricing::Entity::find()
.filter(model_pricing::Column::ModelVersionId.is_in(version_ids))
.all(db)
.await
.unwrap_or_default()
};
// Build lookup: model_id → latest pricing (by effective_from DESC)
let mut pricing_map: std::collections::HashMap<Uuid, &model_pricing::Model> =
std::collections::HashMap::new();
let version_to_model: std::collections::HashMap<Uuid, Uuid> =
versions.iter().map(|v| (v.id, v.model_id)).collect();
for p in &pricings {
if let Some(model_id) = version_to_model.get(&p.model_version_id) {
match pricing_map.get(model_id) {
Some(existing) => {
if p.effective_from > existing.effective_from {
pricing_map.insert(*model_id, p);
}
}
None => {
pricing_map.insert(*model_id, p);
}
}
}
}
let data = models
.into_iter()
.map(|m| {
let pricing = pricing_map.get(&m.id);
ModelWithPricingResponse {
id: m.id,
provider_id: m.provider_id,
name: m.name,
modality: m.modality,
capability: m.capability,
context_length: m.context_length,
max_output_tokens: m.max_output_tokens,
training_cutoff: m.training_cutoff,
is_open_source: m.is_open_source,
status: m.status,
input_price: pricing.map(|p| p.input_price_per_1k_tokens.clone()),
output_price: pricing.map(|p| p.output_price_per_1k_tokens.clone()),
currency: pricing.map(|p| p.currency.clone()),
created_at: m.created_at,
updated_at: m.updated_at,
}
})
.collect();
Ok(ModelListResponse {
data,
total,
page,
per_page,
})
}
/// Fetches one model by its identifier.
///
/// # Errors
/// * `AgentError::NotFound` when no row has the given `id`.
/// * Any database error raised by the lookup.
pub async fn get_model(db: &AppDatabase, id: Uuid) -> Result<ModelResponse, AgentError> {
    match MEntity::find_by_id(id).one(db).await? {
        Some(found) => Ok(ModelResponse::from(found)),
        None => Err(AgentError::NotFound(format!("Model not found: {}", id))),
    }
}
/// Create a new model.
pub async fn create_model(
db: &AppDatabase,
request: CreateModelRequest,
) -> Result<ModelResponse, AgentError> {
ProviderEntity::find_by_id(request.provider_id)
.one(db)
.await?
.ok_or_else(|| AgentError::NotFound("Provider not found".to_string()))?;
let _ = request
.modality
.parse::<ModelModality>()
.map_err(|_| AgentError::InvalidInput {
field: "modality".into(),
reason: "Invalid modality".into(),
})?;
let _ =
request
.capability
.parse::<ModelCapability>()
.map_err(|_| AgentError::InvalidInput {
field: "capability".into(),
reason: "Invalid capability".into(),
})?;
let now = Utc::now();
let active = model::ActiveModel {
id: Set(Uuid::now_v7()),
provider_id: Set(request.provider_id),
name: Set(request.name),
modality: Set(request.modality),
capability: Set(request.capability),
context_length: Set(request.context_length),
max_output_tokens: Set(request.max_output_tokens),
training_cutoff: Set(request.training_cutoff),
is_open_source: Set(request.is_open_source),
status: Set(ModelStatus::Active.to_string()),
created_at: Set(now),
updated_at: Set(now),
..Default::default()
};
let model = active.insert(db).await?;
Ok(ModelResponse::from(model))
}
/// Update an existing model.
pub async fn update_model(
db: &AppDatabase,
id: Uuid,
request: UpdateModelRequest,
) -> Result<ModelResponse, AgentError> {
let model = MEntity::find_by_id(id)
.one(db)
.await?
.ok_or_else(|| AgentError::NotFound(format!("Model not found: {}", id)))?;
let mut active: model::ActiveModel = model.into();
if let Some(modality) = request.modality {
let _ = modality
.parse::<ModelModality>()
.map_err(|_| AgentError::InvalidInput {
field: "modality".into(),
reason: "Invalid modality".into(),
})?;
active.modality = Set(modality);
}
if let Some(capability) = request.capability {
let _ = capability
.parse::<ModelCapability>()
.map_err(|_| AgentError::InvalidInput {
field: "capability".into(),
reason: "Invalid capability".into(),
})?;
active.capability = Set(capability);
}
if let Some(context_length) = request.context_length {
active.context_length = Set(context_length);
}
if let Some(max_output_tokens) = request.max_output_tokens {
active.max_output_tokens = Set(Some(max_output_tokens));
}
if let Some(training_cutoff) = request.training_cutoff {
active.training_cutoff = Set(Some(training_cutoff));
}
if let Some(is_open_source) = request.is_open_source {
active.is_open_source = Set(is_open_source);
}
if let Some(status) = request.status {
active.status = Set(status);
}
active.updated_at = Set(Utc::now());
let model = active.update(db).await?;
Ok(ModelResponse::from(model))
}
/// Deletes the model with the given identifier.
///
/// The value returned by the delete statement is discarded, so this function
/// does not tell the caller whether a row was actually removed.
///
/// # Errors
/// Propagates any database error raised by the delete statement.
pub async fn delete_model(db: &AppDatabase, id: Uuid) -> Result<(), AgentError> {
    let _outcome = MEntity::delete_by_id(id).exec(db).await?;
    Ok(())
}