gitdataai/libs/agent/model/parameter_profile.rs

143 lines
4.6 KiB
Rust

//! Model parameter profile management — CRUD.
use db::database::AppDatabase;
use models::agents::model_parameter_profile;
use sea_orm::*;
use uuid::Uuid;
use crate::error::AgentError;
/// Request payload for creating a model parameter profile.
///
/// Deserialized with serde. All numeric fields are required; the two
/// `*_supported` flags are optional and fall back to `false` when omitted.
#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)]
pub struct CreateModelParameterProfileRequest {
/// Identifier of the model version the new profile is stored under.
pub model_version_id: Uuid,
/// Value stored as the profile's `temperature_min`
/// (presumably the lower bound of the accepted temperature range — confirm).
pub temperature_min: f64,
/// Value stored as the profile's `temperature_max`
/// (presumably the upper bound of the accepted temperature range — confirm).
pub temperature_max: f64,
/// Value stored as the profile's `top_p_min`
/// (presumably the lower bound of the accepted top-p range — confirm).
pub top_p_min: f64,
/// Value stored as the profile's `top_p_max`
/// (presumably the upper bound of the accepted top-p range — confirm).
pub top_p_max: f64,
/// Frequency-penalty support flag; `false` when absent from the payload.
#[serde(default)]
pub frequency_penalty_supported: bool,
/// Presence-penalty support flag; `false` when absent from the payload.
#[serde(default)]
pub presence_penalty_supported: bool,
}
/// Request payload for partially updating a model parameter profile.
///
/// Every field is optional: `Some(value)` replaces the stored value, while
/// `None` leaves the corresponding column untouched.
#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)]
pub struct UpdateModelParameterProfileRequest {
/// Replacement for the stored `temperature_min`, if provided.
pub temperature_min: Option<f64>,
/// Replacement for the stored `temperature_max`, if provided.
pub temperature_max: Option<f64>,
/// Replacement for the stored `top_p_min`, if provided.
pub top_p_min: Option<f64>,
/// Replacement for the stored `top_p_max`, if provided.
pub top_p_max: Option<f64>,
/// Replacement for the stored frequency-penalty support flag, if provided.
pub frequency_penalty_supported: Option<bool>,
/// Replacement for the stored presence-penalty support flag, if provided.
pub presence_penalty_supported: Option<bool>,
}
/// Serializable representation of a stored model parameter profile.
///
/// Built from `model_parameter_profile::Model` through the `From`
/// implementation in this file and returned by the CRUD functions.
#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)]
pub struct ModelParameterProfileResponse {
/// Identifier of the profile row (the key used for lookups by id).
pub id: i64,
/// Identifier of the model version the profile is stored under.
pub model_version_id: Uuid,
/// Stored `temperature_min` value.
pub temperature_min: f64,
/// Stored `temperature_max` value.
pub temperature_max: f64,
/// Stored `top_p_min` value.
pub top_p_min: f64,
/// Stored `top_p_max` value.
pub top_p_max: f64,
/// Stored frequency-penalty support flag.
pub frequency_penalty_supported: bool,
/// Stored presence-penalty support flag.
pub presence_penalty_supported: bool,
}
/// Converts a stored profile row into its response representation.
impl From<model_parameter_profile::Model> for ModelParameterProfileResponse {
    /// Moves the response fields out of the entity model; any additional
    /// columns the model may carry are ignored.
    fn from(model: model_parameter_profile::Model) -> Self {
        let model_parameter_profile::Model {
            id,
            model_version_id,
            temperature_min,
            temperature_max,
            top_p_min,
            top_p_max,
            frequency_penalty_supported,
            presence_penalty_supported,
            ..
        } = model;

        Self {
            id,
            model_version_id,
            temperature_min,
            temperature_max,
            top_p_min,
            top_p_max,
            frequency_penalty_supported,
            presence_penalty_supported,
        }
    }
}
/// Lists every parameter profile stored for the given model version.
///
/// Rows are sorted by ascending profile id so that repeated calls return them
/// in a stable order; without an explicit `ORDER BY` the database is free to
/// return rows in any order.
///
/// # Errors
///
/// Propagates any failure of the database query, converted into
/// [`AgentError`] by `?`.
pub async fn list_parameter_profiles(
    db: &AppDatabase,
    model_version_id: Uuid,
) -> Result<Vec<ModelParameterProfileResponse>, AgentError> {
    let profiles = model_parameter_profile::Entity::find()
        .filter(model_parameter_profile::Column::ModelVersionId.eq(model_version_id))
        .order_by_asc(model_parameter_profile::Column::Id)
        .all(db)
        .await?;

    Ok(profiles
        .into_iter()
        .map(ModelParameterProfileResponse::from)
        .collect())
}
/// Fetches a single parameter profile by its id.
///
/// # Errors
///
/// Returns [`AgentError::NotFound`] when no profile with `id` exists, and
/// propagates any failure of the database query.
pub async fn get_parameter_profile(
    db: &AppDatabase,
    id: i64,
) -> Result<ModelParameterProfileResponse, AgentError> {
    let lookup = model_parameter_profile::Entity::find_by_id(id)
        .one(db)
        .await?;

    match lookup {
        Some(row) => Ok(row.into()),
        None => Err(AgentError::NotFound(format!(
            "Parameter profile not found: {}",
            id
        ))),
    }
}
pub async fn create_parameter_profile(
db: &AppDatabase,
request: CreateModelParameterProfileRequest,
) -> Result<ModelParameterProfileResponse, AgentError> {
let active = model_parameter_profile::ActiveModel {
model_version_id: Set(request.model_version_id),
temperature_min: Set(request.temperature_min),
temperature_max: Set(request.temperature_max),
top_p_min: Set(request.top_p_min),
top_p_max: Set(request.top_p_max),
frequency_penalty_supported: Set(request.frequency_penalty_supported),
presence_penalty_supported: Set(request.presence_penalty_supported),
..Default::default()
};
let profile = active.insert(db).await?;
Ok(ModelParameterProfileResponse::from(profile))
}
pub async fn update_parameter_profile(
db: &AppDatabase,
id: i64,
request: UpdateModelParameterProfileRequest,
) -> Result<ModelParameterProfileResponse, AgentError> {
let profile = model_parameter_profile::Entity::find_by_id(id)
.one(db)
.await?
.ok_or_else(|| AgentError::NotFound(format!("Parameter profile not found: {}", id)))?;
let mut active: model_parameter_profile::ActiveModel = profile.into();
if let Some(v) = request.temperature_min {
active.temperature_min = Set(v);
}
if let Some(v) = request.temperature_max {
active.temperature_max = Set(v);
}
if let Some(v) = request.top_p_min {
active.top_p_min = Set(v);
}
if let Some(v) = request.top_p_max {
active.top_p_max = Set(v);
}
if let Some(v) = request.frequency_penalty_supported {
active.frequency_penalty_supported = Set(v);
}
if let Some(v) = request.presence_penalty_supported {
active.presence_penalty_supported = Set(v);
}
let profile = active.update(db).await?;
Ok(ModelParameterProfileResponse::from(profile))
}
pub async fn delete_parameter_profile(db: &AppDatabase, id: i64) -> Result<(), AgentError> {
model_parameter_profile::Entity::delete_by_id(id)
.exec(db)
.await?;
Ok(())
}