gitdataai/libs/agent/model/model_entry.rs
ZhenYi 10c0cc007b refactor(agent): split into submodules and add Qdrant embedding
- Split agent crate into client/, model/, agent/ subdirs
- Add billing.rs for token usage recording
- Add sync.rs for upstream model sync
- EmbedService: Qdrant-backed vector memory for semantic search
- ChatService: wire EmbedService for memory lookup, passive skill awareness
- ReAct loop: streamline with tokio::select! and proper error handling
2026-04-25 20:09:33 +08:00

201 lines
6.2 KiB
Rust

//! AI model management — CRUD.
//!
//! All functions take `&AppDatabase` instead of `&AppService`.
use chrono::Utc;
use db::database::AppDatabase;
use models::agents::model;
use models::agents::{
ModelCapability, ModelModality, ModelStatus,
model::{Column as MColumn, Entity as MEntity},
model_provider::Entity as ProviderEntity,
};
use sea_orm::*;
use uuid::Uuid;
use crate::error::AgentError;
/// Request payload consumed by `create_model`.
///
/// `modality` and `capability` are carried as strings; `create_model` rejects
/// values that do not parse as `ModelModality` / `ModelCapability` and stores
/// the strings unchanged otherwise.
#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)]
pub struct CreateModelRequest {
/// ID of the provider the model belongs to; `create_model` looks it up and
/// fails with `NotFound` when no such provider exists.
pub provider_id: Uuid,
/// Name stored on the new model record.
pub name: String,
/// Modality string; must parse as `ModelModality`.
pub modality: String,
/// Capability string; must parse as `ModelCapability`.
pub capability: String,
/// Context length stored as-is (presumably measured in tokens; no range
/// check is applied in this module).
pub context_length: i64,
/// Optional output-token limit, stored as-is.
pub max_output_tokens: Option<i64>,
/// Optional training cutoff timestamp (UTC).
pub training_cutoff: Option<chrono::DateTime<Utc>>,
/// Whether the model is open source; falls back to `false` when the field
/// is absent from the payload.
#[serde(default)]
pub is_open_source: bool,
}
/// Partial-update payload consumed by `update_model`.
///
/// Every field is optional: `update_model` only overwrites a column when the
/// corresponding field is `Some`. Because `None` means "leave unchanged",
/// the nullable columns (`max_output_tokens`, `training_cutoff`) cannot be
/// reset to null through this request.
#[derive(Debug, Clone, serde::Deserialize, utoipa::ToSchema)]
pub struct UpdateModelRequest {
/// NOTE(review): `update_model` never reads this field, so sending it
/// currently has no effect — confirm whether it should map to a column.
pub display_name: Option<String>,
/// New modality string; must parse as `ModelModality` when present.
pub modality: Option<String>,
/// New capability string; must parse as `ModelCapability` when present.
pub capability: Option<String>,
/// New context length, stored as-is when present.
pub context_length: Option<i64>,
/// New output-token limit, stored as-is when present.
pub max_output_tokens: Option<i64>,
/// New training cutoff timestamp (UTC) when present.
pub training_cutoff: Option<chrono::DateTime<Utc>>,
/// New open-source flag when present.
pub is_open_source: Option<bool>,
/// New status string when present (presumably one of the `ModelStatus`
/// string forms — confirm against the enum).
pub status: Option<String>,
}
/// Serializable view of a model record returned by the CRUD functions in this
/// module. Built from `model::Model` through the `From` impl in this file,
/// which copies each field across unchanged.
#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)]
pub struct ModelResponse {
/// Model ID.
pub id: Uuid,
/// ID of the provider the model belongs to.
pub provider_id: Uuid,
/// Model name.
pub name: String,
/// Modality string as stored on the record.
pub modality: String,
/// Capability string as stored on the record.
pub capability: String,
/// Context length as stored on the record.
pub context_length: i64,
/// Output-token limit, if one is stored.
pub max_output_tokens: Option<i64>,
/// Training cutoff timestamp (UTC), if one is stored.
pub training_cutoff: Option<chrono::DateTime<Utc>>,
/// Whether the model is flagged as open source.
pub is_open_source: bool,
/// Status string as stored on the record.
pub status: String,
/// Creation timestamp (UTC).
pub created_at: chrono::DateTime<Utc>,
/// Last-update timestamp (UTC).
pub updated_at: chrono::DateTime<Utc>,
}
impl From<model::Model> for ModelResponse {
    /// Turn a database model row into its API response form.
    ///
    /// Each response field is moved straight out of the row without any
    /// transformation; any row fields not listed here are dropped.
    fn from(record: model::Model) -> Self {
        let model::Model {
            id,
            provider_id,
            name,
            modality,
            capability,
            context_length,
            max_output_tokens,
            training_cutoff,
            is_open_source,
            status,
            created_at,
            updated_at,
            ..
        } = record;

        Self {
            id,
            provider_id,
            name,
            modality,
            capability,
            context_length,
            max_output_tokens,
            training_cutoff,
            is_open_source,
            status,
            created_at,
            updated_at,
        }
    }
}
/// Fetch models sorted by name in ascending order.
///
/// When `provider_id` is `Some`, only models whose provider ID equals it are
/// returned; with `None` no provider filter is applied.
///
/// # Errors
///
/// A failure of the database query is converted into `AgentError` and
/// returned.
pub async fn list_models(
    db: &AppDatabase,
    provider_id: Option<Uuid>,
) -> Result<Vec<ModelResponse>, AgentError> {
    let sorted = MEntity::find().order_by_asc(MColumn::Name);
    let select = match provider_id {
        Some(pid) => sorted.filter(MColumn::ProviderId.eq(pid)),
        None => sorted,
    };

    let rows = select.all(db).await?;
    let mut responses = Vec::with_capacity(rows.len());
    for row in rows {
        responses.push(ModelResponse::from(row));
    }
    Ok(responses)
}
/// Look up one model by its ID.
///
/// # Errors
///
/// Returns `AgentError::NotFound` when no model has the given ID; a failure
/// of the database query is converted into `AgentError` and returned.
pub async fn get_model(db: &AppDatabase, id: Uuid) -> Result<ModelResponse, AgentError> {
    let lookup = MEntity::find_by_id(id).one(db).await?;
    match lookup {
        Some(found) => Ok(ModelResponse::from(found)),
        None => Err(AgentError::NotFound(format!("Model not found: {}", id))),
    }
}
/// Create a new model.
pub async fn create_model(
db: &AppDatabase,
request: CreateModelRequest,
) -> Result<ModelResponse, AgentError> {
ProviderEntity::find_by_id(request.provider_id)
.one(db)
.await?
.ok_or_else(|| AgentError::NotFound("Provider not found".to_string()))?;
let _ = request
.modality
.parse::<ModelModality>()
.map_err(|_| AgentError::InvalidInput {
field: "modality".into(),
reason: "Invalid modality".into(),
})?;
let _ = request
.capability
.parse::<ModelCapability>()
.map_err(|_| AgentError::InvalidInput {
field: "capability".into(),
reason: "Invalid capability".into(),
})?;
let now = Utc::now();
let active = model::ActiveModel {
id: Set(Uuid::now_v7()),
provider_id: Set(request.provider_id),
name: Set(request.name),
modality: Set(request.modality),
capability: Set(request.capability),
context_length: Set(request.context_length),
max_output_tokens: Set(request.max_output_tokens),
training_cutoff: Set(request.training_cutoff),
is_open_source: Set(request.is_open_source),
status: Set(ModelStatus::Active.to_string()),
created_at: Set(now),
updated_at: Set(now),
..Default::default()
};
let model = active.insert(db).await?;
Ok(ModelResponse::from(model))
}
/// Update an existing model.
pub async fn update_model(
db: &AppDatabase,
id: Uuid,
request: UpdateModelRequest,
) -> Result<ModelResponse, AgentError> {
let model = MEntity::find_by_id(id)
.one(db)
.await?
.ok_or_else(|| AgentError::NotFound(format!("Model not found: {}", id)))?;
let mut active: model::ActiveModel = model.into();
if let Some(modality) = request.modality {
let _ = modality
.parse::<ModelModality>()
.map_err(|_| AgentError::InvalidInput {
field: "modality".into(),
reason: "Invalid modality".into(),
})?;
active.modality = Set(modality);
}
if let Some(capability) = request.capability {
let _ = capability
.parse::<ModelCapability>()
.map_err(|_| AgentError::InvalidInput {
field: "capability".into(),
reason: "Invalid capability".into(),
})?;
active.capability = Set(capability);
}
if let Some(context_length) = request.context_length {
active.context_length = Set(context_length);
}
if let Some(max_output_tokens) = request.max_output_tokens {
active.max_output_tokens = Set(Some(max_output_tokens));
}
if let Some(training_cutoff) = request.training_cutoff {
active.training_cutoff = Set(Some(training_cutoff));
}
if let Some(is_open_source) = request.is_open_source {
active.is_open_source = Set(is_open_source);
}
if let Some(status) = request.status {
active.status = Set(status);
}
active.updated_at = Set(Utc::now());
let model = active.update(db).await?;
Ok(ModelResponse::from(model))
}
/// Issue a delete for the model with the given ID.
///
/// The outcome of the delete statement is not inspected, so this function
/// does not itself report a missing ID as `NotFound`.
///
/// # Errors
///
/// A failure of the delete statement is converted into `AgentError` and
/// returned.
pub async fn delete_model(db: &AppDatabase, id: Uuid) -> Result<(), AgentError> {
    let deletion = MEntity::delete_by_id(id);
    let _outcome = deletion.exec(db).await?;
    Ok(())
}