198 lines
6.4 KiB
Rust
198 lines
6.4 KiB
Rust
use crate::AppService;
|
|
use crate::error::AppError;
|
|
use chrono::Utc;
|
|
use models::agents::model;
|
|
use models::agents::{
|
|
ModelCapability, ModelModality, ModelStatus,
|
|
model::{Column as MColumn, Entity as MEntity},
|
|
model_provider::Entity as ProviderEntity,
|
|
};
|
|
use sea_orm::*;
|
|
use serde::{Deserialize, Serialize};
|
|
use session::Session;
|
|
use utoipa::ToSchema;
|
|
use uuid::Uuid;
|
|
|
|
use super::provider::require_system_caller;
|
|
|
|
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
|
pub struct CreateModelRequest {
|
|
pub provider_id: Uuid,
|
|
pub name: String,
|
|
pub modality: String,
|
|
pub capability: String,
|
|
pub context_length: i64,
|
|
pub max_output_tokens: Option<i64>,
|
|
pub training_cutoff: Option<chrono::DateTime<Utc>>,
|
|
#[serde(default)]
|
|
pub is_open_source: bool,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
|
pub struct UpdateModelRequest {
|
|
pub display_name: Option<String>,
|
|
pub modality: Option<String>,
|
|
pub capability: Option<String>,
|
|
pub context_length: Option<i64>,
|
|
pub max_output_tokens: Option<i64>,
|
|
pub training_cutoff: Option<chrono::DateTime<Utc>>,
|
|
pub is_open_source: Option<bool>,
|
|
pub status: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, ToSchema)]
|
|
pub struct ModelResponse {
|
|
pub id: Uuid,
|
|
pub provider_id: Uuid,
|
|
pub name: String,
|
|
pub modality: String,
|
|
pub capability: String,
|
|
pub context_length: i64,
|
|
pub max_output_tokens: Option<i64>,
|
|
pub training_cutoff: Option<chrono::DateTime<Utc>>,
|
|
pub is_open_source: bool,
|
|
pub status: String,
|
|
pub created_at: chrono::DateTime<Utc>,
|
|
pub updated_at: chrono::DateTime<Utc>,
|
|
}
|
|
|
|
/// Maps a database row onto the API response, moving each exposed column
/// across unchanged.
impl From<model::Model> for ModelResponse {
    fn from(row: model::Model) -> Self {
        // Destructure once so every column is moved out by name; `..` ignores
        // any columns the response does not expose.
        let model::Model {
            id,
            provider_id,
            name,
            modality,
            capability,
            context_length,
            max_output_tokens,
            training_cutoff,
            is_open_source,
            status,
            created_at,
            updated_at,
            ..
        } = row;

        Self {
            id,
            provider_id,
            name,
            modality,
            capability,
            context_length,
            max_output_tokens,
            training_cutoff,
            is_open_source,
            status,
            created_at,
            updated_at,
        }
    }
}
|
|
|
|
impl AppService {
|
|
pub async fn agent_model_list(
|
|
&self,
|
|
provider_id: Option<Uuid>,
|
|
_ctx: &Session,
|
|
) -> Result<Vec<ModelResponse>, AppError> {
|
|
let mut query = MEntity::find().order_by_asc(MColumn::Name);
|
|
if let Some(pid) = provider_id {
|
|
query = query.filter(MColumn::ProviderId.eq(pid));
|
|
}
|
|
let models = query.all(&self.db).await?;
|
|
Ok(models.into_iter().map(ModelResponse::from).collect())
|
|
}
|
|
|
|
pub async fn agent_model_get(
|
|
&self,
|
|
id: Uuid,
|
|
_ctx: &Session,
|
|
) -> Result<ModelResponse, AppError> {
|
|
let model = MEntity::find_by_id(id)
|
|
.one(&self.db)
|
|
.await?
|
|
.ok_or(AppError::NotFound("Model not found".to_string()))?;
|
|
Ok(ModelResponse::from(model))
|
|
}
|
|
|
|
pub async fn agent_model_create(
|
|
&self,
|
|
request: CreateModelRequest,
|
|
ctx: &Session,
|
|
) -> Result<ModelResponse, AppError> {
|
|
require_system_caller(ctx)?;
|
|
|
|
ProviderEntity::find_by_id(request.provider_id)
|
|
.one(&self.db)
|
|
.await?
|
|
.ok_or(AppError::NotFound("Provider not found".to_string()))?;
|
|
|
|
let _ = request
|
|
.modality
|
|
.parse::<ModelModality>()
|
|
.map_err(|_| AppError::BadRequest("Invalid modality".to_string()))?;
|
|
let _ = request
|
|
.capability
|
|
.parse::<ModelCapability>()
|
|
.map_err(|_| AppError::BadRequest("Invalid capability".to_string()))?;
|
|
|
|
let now = Utc::now();
|
|
let active = model::ActiveModel {
|
|
id: Set(Uuid::now_v7()),
|
|
provider_id: Set(request.provider_id),
|
|
name: Set(request.name),
|
|
modality: Set(request.modality),
|
|
capability: Set(request.capability),
|
|
context_length: Set(request.context_length),
|
|
max_output_tokens: Set(request.max_output_tokens),
|
|
training_cutoff: Set(request.training_cutoff),
|
|
is_open_source: Set(request.is_open_source),
|
|
status: Set(ModelStatus::Active.to_string()),
|
|
created_at: Set(now),
|
|
updated_at: Set(now),
|
|
..Default::default()
|
|
};
|
|
let model = active.insert(&self.db).await?;
|
|
Ok(ModelResponse::from(model))
|
|
}
|
|
|
|
pub async fn agent_model_update(
|
|
&self,
|
|
id: Uuid,
|
|
request: UpdateModelRequest,
|
|
ctx: &Session,
|
|
) -> Result<ModelResponse, AppError> {
|
|
require_system_caller(ctx)?;
|
|
|
|
let model = MEntity::find_by_id(id)
|
|
.one(&self.db)
|
|
.await?
|
|
.ok_or(AppError::NotFound("Model not found".to_string()))?;
|
|
|
|
let mut active: model::ActiveModel = model.into();
|
|
if let Some(modality) = request.modality {
|
|
let _ = modality
|
|
.parse::<ModelModality>()
|
|
.map_err(|_| AppError::BadRequest("Invalid modality".to_string()))?;
|
|
active.modality = Set(modality);
|
|
}
|
|
if let Some(capability) = request.capability {
|
|
let _ = capability
|
|
.parse::<ModelCapability>()
|
|
.map_err(|_| AppError::BadRequest("Invalid capability".to_string()))?;
|
|
active.capability = Set(capability);
|
|
}
|
|
if let Some(context_length) = request.context_length {
|
|
active.context_length = Set(context_length);
|
|
}
|
|
if let Some(max_output_tokens) = request.max_output_tokens {
|
|
active.max_output_tokens = Set(Some(max_output_tokens));
|
|
}
|
|
if let Some(training_cutoff) = request.training_cutoff {
|
|
active.training_cutoff = Set(Some(training_cutoff));
|
|
}
|
|
if let Some(is_open_source) = request.is_open_source {
|
|
active.is_open_source = Set(is_open_source);
|
|
}
|
|
if let Some(status) = request.status {
|
|
active.status = Set(status);
|
|
}
|
|
active.updated_at = Set(Utc::now());
|
|
|
|
let model = active.update(&self.db).await?;
|
|
Ok(ModelResponse::from(model))
|
|
}
|
|
|
|
pub async fn agent_model_delete(&self, id: Uuid, ctx: &Session) -> Result<(), AppError> {
|
|
require_system_caller(ctx)?;
|
|
MEntity::delete_by_id(id).exec(&self.db).await?;
|
|
Ok(())
|
|
}
|
|
}
|