164 lines
5.3 KiB
Rust
164 lines
5.3 KiB
Rust
use crate::AppService;
|
|
use crate::error::AppError;
|
|
use models::agents::model_parameter_profile;
|
|
use sea_orm::*;
|
|
use serde::{Deserialize, Serialize};
|
|
use session::Session;
|
|
use utoipa::ToSchema;
|
|
use uuid::Uuid;
|
|
|
|
use super::provider::require_system_caller;
|
|
|
|
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
|
pub struct CreateModelParameterProfileRequest {
|
|
pub model_version_id: Uuid,
|
|
pub temperature_min: f64,
|
|
pub temperature_max: f64,
|
|
pub top_p_min: f64,
|
|
pub top_p_max: f64,
|
|
#[serde(default)]
|
|
pub frequency_penalty_supported: bool,
|
|
#[serde(default)]
|
|
pub presence_penalty_supported: bool,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
|
pub struct UpdateModelParameterProfileRequest {
|
|
pub temperature_min: Option<f64>,
|
|
pub temperature_max: Option<f64>,
|
|
pub top_p_min: Option<f64>,
|
|
pub top_p_max: Option<f64>,
|
|
pub frequency_penalty_supported: Option<bool>,
|
|
pub presence_penalty_supported: Option<bool>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, ToSchema)]
|
|
pub struct ModelParameterProfileResponse {
|
|
pub id: i64,
|
|
pub model_version_id: Uuid,
|
|
pub temperature_min: f64,
|
|
pub temperature_max: f64,
|
|
pub top_p_min: f64,
|
|
pub top_p_max: f64,
|
|
pub frequency_penalty_supported: bool,
|
|
pub presence_penalty_supported: bool,
|
|
}
|
|
|
|
/// Copies the persisted profile columns into the API response shape.
impl From<model_parameter_profile::Model> for ModelParameterProfileResponse {
    fn from(model: model_parameter_profile::Model) -> Self {
        // Pull out exactly the columns the response exposes; any other
        // columns on the row are ignored via the rest pattern.
        let model_parameter_profile::Model {
            id,
            model_version_id,
            temperature_min,
            temperature_max,
            top_p_min,
            top_p_max,
            frequency_penalty_supported,
            presence_penalty_supported,
            ..
        } = model;

        Self {
            id,
            model_version_id,
            temperature_min,
            temperature_max,
            top_p_min,
            top_p_max,
            frequency_penalty_supported,
            presence_penalty_supported,
        }
    }
}
|
|
|
|
impl AppService {
|
|
pub async fn agent_model_parameter_profile_list(
|
|
&self,
|
|
model_version_id: Uuid,
|
|
_ctx: &Session,
|
|
) -> Result<Vec<ModelParameterProfileResponse>, AppError> {
|
|
let profiles = model_parameter_profile::Entity::find()
|
|
.filter(model_parameter_profile::Column::ModelVersionId.eq(model_version_id))
|
|
.all(&self.db)
|
|
.await?;
|
|
Ok(profiles
|
|
.into_iter()
|
|
.map(ModelParameterProfileResponse::from)
|
|
.collect())
|
|
}
|
|
|
|
pub async fn agent_model_parameter_profile_get(
|
|
&self,
|
|
id: i64,
|
|
_ctx: &Session,
|
|
) -> Result<ModelParameterProfileResponse, AppError> {
|
|
let profile = model_parameter_profile::Entity::find_by_id(id)
|
|
.one(&self.db)
|
|
.await?
|
|
.ok_or(AppError::NotFound(
|
|
"Parameter profile not found".to_string(),
|
|
))?;
|
|
Ok(ModelParameterProfileResponse::from(profile))
|
|
}
|
|
|
|
pub async fn agent_model_parameter_profile_create(
|
|
&self,
|
|
request: CreateModelParameterProfileRequest,
|
|
ctx: &Session,
|
|
) -> Result<ModelParameterProfileResponse, AppError> {
|
|
require_system_caller(ctx)?;
|
|
|
|
let active = model_parameter_profile::ActiveModel {
|
|
model_version_id: Set(request.model_version_id),
|
|
temperature_min: Set(request.temperature_min),
|
|
temperature_max: Set(request.temperature_max),
|
|
top_p_min: Set(request.top_p_min),
|
|
top_p_max: Set(request.top_p_max),
|
|
frequency_penalty_supported: Set(request.frequency_penalty_supported),
|
|
presence_penalty_supported: Set(request.presence_penalty_supported),
|
|
..Default::default()
|
|
};
|
|
let profile = active.insert(&self.db).await?;
|
|
Ok(ModelParameterProfileResponse::from(profile))
|
|
}
|
|
|
|
pub async fn agent_model_parameter_profile_update(
|
|
&self,
|
|
id: i64,
|
|
request: UpdateModelParameterProfileRequest,
|
|
ctx: &Session,
|
|
) -> Result<ModelParameterProfileResponse, AppError> {
|
|
require_system_caller(ctx)?;
|
|
|
|
let profile = model_parameter_profile::Entity::find_by_id(id)
|
|
.one(&self.db)
|
|
.await?
|
|
.ok_or(AppError::NotFound(
|
|
"Parameter profile not found".to_string(),
|
|
))?;
|
|
|
|
let mut active: model_parameter_profile::ActiveModel = profile.into();
|
|
if let Some(v) = request.temperature_min {
|
|
active.temperature_min = Set(v);
|
|
}
|
|
if let Some(v) = request.temperature_max {
|
|
active.temperature_max = Set(v);
|
|
}
|
|
if let Some(v) = request.top_p_min {
|
|
active.top_p_min = Set(v);
|
|
}
|
|
if let Some(v) = request.top_p_max {
|
|
active.top_p_max = Set(v);
|
|
}
|
|
if let Some(v) = request.frequency_penalty_supported {
|
|
active.frequency_penalty_supported = Set(v);
|
|
}
|
|
if let Some(v) = request.presence_penalty_supported {
|
|
active.presence_penalty_supported = Set(v);
|
|
}
|
|
|
|
let profile = active.update(&self.db).await?;
|
|
Ok(ModelParameterProfileResponse::from(profile))
|
|
}
|
|
|
|
pub async fn agent_model_parameter_profile_delete(
|
|
&self,
|
|
id: i64,
|
|
ctx: &Session,
|
|
) -> Result<(), AppError> {
|
|
require_system_caller(ctx)?;
|
|
model_parameter_profile::Entity::delete_by_id(id)
|
|
.exec(&self.db)
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
}
|