// gitdataai/lib/service/ai/model.rs
//
// 161 lines
// 5.5 KiB
// Rust

use crate::AppService;
use crate::ai::types::{AiModelFilter, AiModelListItem, AiModelResponse};
use crate::error::AppError;
use crate::{Pagination, session_user};
use db::sqlx;
use db::sqlx::AssertSqlSafe;
use model::ai::AiModelModel;
use model::ai::AiProviderModel;
use session::Session;
impl AppService {
    /// Lists public, non-deleted AI models for an authenticated session.
    ///
    /// Optional `filter` fields narrow the result set:
    /// - `enabled`, `provider` and `modality` are compared with `=`;
    /// - `name` is a case-insensitive substring match (`ILIKE '%name%'`). LIKE
    ///   metacharacters in the input are escaped so they match literally.
    ///
    /// Rows are ordered by `display_name`, then by `id`. The `id` tie-breaker
    /// keeps LIMIT/OFFSET pagination stable when display names repeat.
    ///
    /// # Errors
    /// - Propagates the error from `session_user` when the session has no user.
    /// - `AppError::DatabaseError` if the listing query fails.
    /// - `AppError::DatabaseError` if a model's provider row cannot be loaded.
    pub async fn ai_model_list(
        &self,
        ctx: &Session,
        filter: AiModelFilter,
        pagination: Pagination,
    ) -> Result<Vec<AiModelListItem>, AppError> {
        // Only the presence of a session user is checked; the uid is unused.
        let _user_uid = session_user(ctx)?;

        let mut conditions = vec![
            "m.public = true".to_string(),
            "m.deleted_at IS NULL".to_string(),
        ];
        // Placeholders are numbered in the same order as the `bind` calls
        // below. The two sequences must stay in sync.
        let mut param_idx = 1;
        if filter.enabled.is_some() {
            conditions.push(format!("m.enabled = ${param_idx}"));
            param_idx += 1;
        }
        if filter.provider.is_some() {
            conditions.push(format!("m.provider = ${param_idx}"));
            param_idx += 1;
        }
        if filter.modality.is_some() {
            conditions.push(format!("m.modality = ${param_idx}"));
            param_idx += 1;
        }
        if filter.name.is_some() {
            conditions.push(format!("m.name ILIKE ${param_idx}"));
            param_idx += 1;
        }
        let where_clause = conditions.join(" AND ");
        let limit_idx = param_idx;
        let offset_idx = param_idx + 1;
        // The interpolated SQL contains only fixed fragments and placeholder
        // numbers. All user-supplied values go through `bind`.
        let query = format!(
            "SELECT m.id, m.provider, m.name, m.display_name, m.description, m.modality, \
             m.context_window, m.input_token_limit, m.output_token_limit, \
             m.enabled, m.public, m.created_at, m.updated_at, m.deleted_at \
             FROM ai_model m WHERE {where_clause} \
             ORDER BY m.display_name ASC, m.id ASC LIMIT ${limit_idx} OFFSET ${offset_idx}"
        );

        let mut q = sqlx::query_as::<_, AiModelModel>(AssertSqlSafe(query));
        if let Some(enabled) = &filter.enabled {
            q = q.bind(enabled);
        }
        if let Some(provider) = &filter.provider {
            q = q.bind(provider);
        }
        if let Some(modality) = &filter.modality {
            q = q.bind(modality);
        }
        if let Some(name) = &filter.name {
            q = q.bind(format!("%{}%", escape_like_pattern(name)));
        }
        q = q
            .bind(pagination.limit() as i64)
            .bind(pagination.offset() as i64);

        let models = q
            .fetch_all(self.db.reader())
            .await
            .map_err(|e| AppError::DatabaseError(e.to_string()))?;

        // Load each distinct provider once, instead of one query per model row.
        let mut providers = std::collections::HashMap::new();
        for m in &models {
            if !providers.contains_key(&m.provider) {
                let provider = self.ai_provider_by_id(m.provider).await?;
                providers.insert(m.provider, provider);
            }
        }

        let results: Vec<AiModelListItem> = models
            .into_iter()
            .map(|m| {
                // The loop above inserted every provider id present in `models`.
                let provider: &AiProviderModel = &providers[&m.provider];
                AiModelListItem {
                    id: m.id,
                    name: m.name,
                    display_name: m.display_name,
                    description: crate::non_empty(m.description.unwrap_or_default()),
                    modality: m.modality,
                    provider_name: provider.name.clone(),
                    provider_logo_url: provider.logo_url.clone(),
                    context_window: m.context_window,
                    input_token_limit: m.input_token_limit,
                    output_token_limit: m.output_token_limit,
                    enabled: m.enabled,
                    created_at: m.created_at,
                    updated_at: m.updated_at,
                }
            })
            .collect();
        Ok(results)
    }

    /// Fetches one model by `id`, together with its provider, card, versions,
    /// tags and like count.
    ///
    /// Only models that are public and not soft-deleted are returned. This is
    /// the same visibility rule that `ai_model_list` applies.
    ///
    /// # Errors
    /// - Propagates the error from `session_user` when the session has no user.
    /// - `AppError::NotFound` if no public, non-deleted model has this `id`.
    /// - `AppError::DatabaseError` if the model query fails.
    /// - Any error from the provider, card, version, tag or like-count lookups.
    pub async fn ai_model_get(
        &self,
        ctx: &Session,
        id: uuid::Uuid,
    ) -> Result<AiModelResponse, AppError> {
        // Only the presence of a session user is checked; the uid is unused.
        let _user_uid = session_user(ctx)?;
        let m = sqlx::query_as::<_, AiModelModel>(
            "SELECT id, provider, name, display_name, description, modality, \
             context_window, input_token_limit, output_token_limit, \
             enabled, public, created_at, updated_at, deleted_at \
             FROM ai_model WHERE id = $1 AND public = true AND deleted_at IS NULL",
        )
        .bind(id)
        .fetch_optional(self.db.reader())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?
        .ok_or_else(|| AppError::NotFound("model not found".to_string()))?;

        let provider = self.ai_provider_by_id(m.provider).await?;
        let card = self.ai_card_get_inner(m.id).await?;
        let versions = self.ai_version_list_inner(m.id).await?;
        let tags = self.ai_tag_list_inner(m.id).await?;
        let like_count = self.ai_like_count_inner(m.id).await?;

        Ok(AiModelResponse {
            id: m.id,
            name: m.name,
            display_name: m.display_name,
            description: crate::non_empty(m.description.unwrap_or_default()),
            modality: m.modality,
            context_window: m.context_window,
            input_token_limit: m.input_token_limit,
            output_token_limit: m.output_token_limit,
            enabled: m.enabled,
            public: m.public,
            provider_name: provider.name,
            provider_logo_url: provider.logo_url,
            card,
            versions,
            tags,
            like_count,
            created_at: m.created_at,
            updated_at: m.updated_at,
        })
    }

    /// Loads a single `ai_provider` row by primary key.
    ///
    /// # Errors
    /// Returns `AppError::DatabaseError` if the query fails, including when no
    /// row matches. `fetch_one` reports a missing row as an error.
    pub async fn ai_provider_by_id(
        &self,
        id: uuid::Uuid,
    ) -> Result<AiProviderModel, AppError> {
        sqlx::query_as::<_, AiProviderModel>(
            "SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \
             FROM ai_provider WHERE id = $1",
        )
        .bind(id)
        .fetch_one(self.db.reader())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))
    }
}

/// Escapes the `LIKE`/`ILIKE` metacharacters `\`, `%` and `_` so that `raw`
/// matches literally when embedded in a pattern.
///
/// Backslash is used as the escape character because it is PostgreSQL's
/// default `LIKE` escape.
fn escape_like_pattern(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}