gitdataai/libs/service/agent/model.rs
2026-04-15 09:08:09 +08:00

198 lines
6.4 KiB
Rust

use crate::AppService;
use crate::error::AppError;
use chrono::Utc;
use models::agents::model;
use models::agents::{
ModelCapability, ModelModality, ModelStatus,
model::{Column as MColumn, Entity as MEntity},
model_provider::Entity as ProviderEntity,
};
use sea_orm::*;
use serde::{Deserialize, Serialize};
use session::Session;
use utoipa::ToSchema;
use uuid::Uuid;
use super::provider::require_system_caller;
/// Request body for creating a model under an existing provider.
#[derive(Debug, Clone, Deserialize, ToSchema)]
pub struct CreateModelRequest {
    /// Provider that owns the model; `agent_model_create` rejects ids with no matching provider.
    pub provider_id: Uuid,
    /// Model name, stored as given.
    pub name: String,
    /// Modality string; `agent_model_create` requires it to parse as `ModelModality`.
    pub modality: String,
    /// Capability string; `agent_model_create` requires it to parse as `ModelCapability`.
    pub capability: String,
    /// Context window length (NOTE(review): presumably in tokens — confirm).
    pub context_length: i64,
    /// Optional maximum output length (NOTE(review): presumably in tokens — confirm).
    pub max_output_tokens: Option<i64>,
    /// Optional training-data cutoff timestamp (UTC).
    pub training_cutoff: Option<chrono::DateTime<Utc>>,
    /// Whether the model is open source; defaults to `false` when omitted from the body.
    #[serde(default)]
    pub is_open_source: bool,
}
/// Partial-update request body for an existing model.
///
/// Every field is optional; a missing (`None`) field leaves the stored column unchanged. As a
/// consequence, nullable columns such as `max_output_tokens` and `training_cutoff` cannot be
/// cleared through this request.
#[derive(Debug, Clone, Deserialize, ToSchema)]
pub struct UpdateModelRequest {
    /// NOTE(review): accepted but not applied by `agent_model_update` as written — confirm intent.
    pub display_name: Option<String>,
    /// New modality; must parse as `ModelModality`.
    pub modality: Option<String>,
    /// New capability; must parse as `ModelCapability`.
    pub capability: Option<String>,
    /// New context window length.
    pub context_length: Option<i64>,
    /// New maximum output length.
    pub max_output_tokens: Option<i64>,
    /// New training-data cutoff timestamp (UTC).
    pub training_cutoff: Option<chrono::DateTime<Utc>>,
    /// New open-source flag.
    pub is_open_source: Option<bool>,
    /// New status string (presumably a `ModelStatus` value — confirm).
    pub status: Option<String>,
}
/// API representation of a stored model row, built via `From<model::Model>`.
#[derive(Debug, Clone, Serialize, ToSchema)]
pub struct ModelResponse {
    /// Model id.
    pub id: Uuid,
    /// Owning provider id.
    pub provider_id: Uuid,
    /// Model name.
    pub name: String,
    /// Modality string as stored.
    pub modality: String,
    /// Capability string as stored.
    pub capability: String,
    /// Context window length.
    pub context_length: i64,
    /// Optional maximum output length.
    pub max_output_tokens: Option<i64>,
    /// Optional training-data cutoff timestamp (UTC).
    pub training_cutoff: Option<chrono::DateTime<Utc>>,
    /// Whether the model is open source.
    pub is_open_source: bool,
    /// Status string as stored.
    pub status: String,
    /// Row creation timestamp (UTC).
    pub created_at: chrono::DateTime<Utc>,
    /// Last update timestamp (UTC).
    pub updated_at: chrono::DateTime<Utc>,
}
/// Converts a database model row into its API response shape by moving each exposed column.
impl From<model::Model> for ModelResponse {
    fn from(row: model::Model) -> Self {
        // Destructure the row once so every column is moved, not cloned.
        let model::Model {
            id,
            provider_id,
            name,
            modality,
            capability,
            context_length,
            max_output_tokens,
            training_cutoff,
            is_open_source,
            status,
            created_at,
            updated_at,
            ..
        } = row;

        Self {
            id,
            provider_id,
            name,
            modality,
            capability,
            context_length,
            max_output_tokens,
            training_cutoff,
            is_open_source,
            status,
            created_at,
            updated_at,
        }
    }
}
impl AppService {
pub async fn agent_model_list(
&self,
provider_id: Option<Uuid>,
_ctx: &Session,
) -> Result<Vec<ModelResponse>, AppError> {
let mut query = MEntity::find().order_by_asc(MColumn::Name);
if let Some(pid) = provider_id {
query = query.filter(MColumn::ProviderId.eq(pid));
}
let models = query.all(&self.db).await?;
Ok(models.into_iter().map(ModelResponse::from).collect())
}
pub async fn agent_model_get(
&self,
id: Uuid,
_ctx: &Session,
) -> Result<ModelResponse, AppError> {
let model = MEntity::find_by_id(id)
.one(&self.db)
.await?
.ok_or(AppError::NotFound("Model not found".to_string()))?;
Ok(ModelResponse::from(model))
}
pub async fn agent_model_create(
&self,
request: CreateModelRequest,
ctx: &Session,
) -> Result<ModelResponse, AppError> {
require_system_caller(ctx)?;
ProviderEntity::find_by_id(request.provider_id)
.one(&self.db)
.await?
.ok_or(AppError::NotFound("Provider not found".to_string()))?;
let _ = request
.modality
.parse::<ModelModality>()
.map_err(|_| AppError::BadRequest("Invalid modality".to_string()))?;
let _ = request
.capability
.parse::<ModelCapability>()
.map_err(|_| AppError::BadRequest("Invalid capability".to_string()))?;
let now = Utc::now();
let active = model::ActiveModel {
id: Set(Uuid::now_v7()),
provider_id: Set(request.provider_id),
name: Set(request.name),
modality: Set(request.modality),
capability: Set(request.capability),
context_length: Set(request.context_length),
max_output_tokens: Set(request.max_output_tokens),
training_cutoff: Set(request.training_cutoff),
is_open_source: Set(request.is_open_source),
status: Set(ModelStatus::Active.to_string()),
created_at: Set(now),
updated_at: Set(now),
..Default::default()
};
let model = active.insert(&self.db).await?;
Ok(ModelResponse::from(model))
}
pub async fn agent_model_update(
&self,
id: Uuid,
request: UpdateModelRequest,
ctx: &Session,
) -> Result<ModelResponse, AppError> {
require_system_caller(ctx)?;
let model = MEntity::find_by_id(id)
.one(&self.db)
.await?
.ok_or(AppError::NotFound("Model not found".to_string()))?;
let mut active: model::ActiveModel = model.into();
if let Some(modality) = request.modality {
let _ = modality
.parse::<ModelModality>()
.map_err(|_| AppError::BadRequest("Invalid modality".to_string()))?;
active.modality = Set(modality);
}
if let Some(capability) = request.capability {
let _ = capability
.parse::<ModelCapability>()
.map_err(|_| AppError::BadRequest("Invalid capability".to_string()))?;
active.capability = Set(capability);
}
if let Some(context_length) = request.context_length {
active.context_length = Set(context_length);
}
if let Some(max_output_tokens) = request.max_output_tokens {
active.max_output_tokens = Set(Some(max_output_tokens));
}
if let Some(training_cutoff) = request.training_cutoff {
active.training_cutoff = Set(Some(training_cutoff));
}
if let Some(is_open_source) = request.is_open_source {
active.is_open_source = Set(is_open_source);
}
if let Some(status) = request.status {
active.status = Set(status);
}
active.updated_at = Set(Utc::now());
let model = active.update(&self.db).await?;
Ok(ModelResponse::from(model))
}
pub async fn agent_model_delete(&self, id: Uuid, ctx: &Session) -> Result<(), AppError> {
require_system_caller(ctx)?;
MEntity::delete_by_id(id).exec(&self.db).await?;
Ok(())
}
}