161 lines
5.5 KiB
Rust
161 lines
5.5 KiB
Rust
use crate::AppService;
|
|
use crate::ai::types::{AiModelFilter, AiModelListItem, AiModelResponse};
|
|
use crate::error::AppError;
|
|
use crate::{Pagination, session_user};
|
|
use db::sqlx;
|
|
use db::sqlx::AssertSqlSafe;
|
|
use model::ai::AiModelModel;
|
|
use model::ai::AiProviderModel;
|
|
use session::Session;
|
|
|
|
impl AppService {
|
|
pub async fn ai_model_list(
|
|
&self,
|
|
ctx: &Session,
|
|
filter: AiModelFilter,
|
|
pagination: Pagination,
|
|
) -> Result<Vec<AiModelListItem>, AppError> {
|
|
let _user_uid = session_user(ctx)?;
|
|
|
|
let mut conditions = vec![
|
|
"m.public = true".to_string(),
|
|
"m.deleted_at IS NULL".to_string(),
|
|
];
|
|
let mut param_idx = 1;
|
|
|
|
if filter.enabled.is_some() {
|
|
conditions.push(format!("m.enabled = ${param_idx}"));
|
|
param_idx += 1;
|
|
}
|
|
if filter.provider.is_some() {
|
|
conditions.push(format!("m.provider = ${param_idx}"));
|
|
param_idx += 1;
|
|
}
|
|
if filter.modality.is_some() {
|
|
conditions.push(format!("m.modality = ${param_idx}"));
|
|
param_idx += 1;
|
|
}
|
|
if filter.name.is_some() {
|
|
conditions.push(format!("m.name ILIKE ${param_idx}"));
|
|
param_idx += 1;
|
|
}
|
|
|
|
let where_clause = conditions.join(" AND ");
|
|
let limit_idx = param_idx;
|
|
let offset_idx = param_idx + 1;
|
|
|
|
let query = format!(
|
|
"SELECT m.id, m.provider, m.name, m.display_name, m.description, m.modality, \
|
|
m.context_window, m.input_token_limit, m.output_token_limit, \
|
|
m.enabled, m.public, m.created_at, m.updated_at, m.deleted_at \
|
|
FROM ai_model m WHERE {where_clause} \
|
|
ORDER BY m.display_name ASC LIMIT ${limit_idx} OFFSET ${offset_idx}"
|
|
);
|
|
|
|
let mut q = sqlx::query_as::<_, AiModelModel>(AssertSqlSafe(query));
|
|
if let Some(enabled) = &filter.enabled {
|
|
q = q.bind(enabled);
|
|
}
|
|
if let Some(provider) = &filter.provider {
|
|
q = q.bind(provider);
|
|
}
|
|
if let Some(modality) = &filter.modality {
|
|
q = q.bind(modality);
|
|
}
|
|
if let Some(name) = &filter.name {
|
|
q = q.bind(format!("%{}%", name));
|
|
}
|
|
q = q
|
|
.bind(pagination.limit() as i64)
|
|
.bind(pagination.offset() as i64);
|
|
|
|
let models = q
|
|
.fetch_all(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
let mut results = Vec::new();
|
|
for m in models {
|
|
let provider = self.ai_provider_by_id(m.provider).await?;
|
|
results.push(AiModelListItem {
|
|
id: m.id,
|
|
name: m.name,
|
|
display_name: m.display_name,
|
|
description: crate::non_empty(
|
|
m.description.unwrap_or_default(),
|
|
),
|
|
modality: m.modality,
|
|
provider_name: provider.name,
|
|
provider_logo_url: provider.logo_url,
|
|
context_window: m.context_window,
|
|
input_token_limit: m.input_token_limit,
|
|
output_token_limit: m.output_token_limit,
|
|
enabled: m.enabled,
|
|
created_at: m.created_at,
|
|
updated_at: m.updated_at,
|
|
});
|
|
}
|
|
Ok(results)
|
|
}
|
|
pub async fn ai_model_get(
|
|
&self,
|
|
ctx: &Session,
|
|
id: uuid::Uuid,
|
|
) -> Result<AiModelResponse, AppError> {
|
|
let _user_uid = session_user(ctx)?;
|
|
|
|
let m = sqlx::query_as::<_, AiModelModel>(
|
|
"SELECT id, provider, name, display_name, description, modality, \
|
|
context_window, input_token_limit, output_token_limit, \
|
|
enabled, public, created_at, updated_at, deleted_at \
|
|
FROM ai_model WHERE id = $1 AND (public = true OR deleted_at IS NULL)",
|
|
)
|
|
.bind(id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.ok_or(AppError::NotFound("model not found".to_string()))?;
|
|
|
|
let provider = self.ai_provider_by_id(m.provider).await?;
|
|
let card = self.ai_card_get_inner(m.id).await?;
|
|
let versions = self.ai_version_list_inner(m.id).await?;
|
|
let tags = self.ai_tag_list_inner(m.id).await?;
|
|
let like_count = self.ai_like_count_inner(m.id).await?;
|
|
|
|
Ok(AiModelResponse {
|
|
id: m.id,
|
|
name: m.name,
|
|
display_name: m.display_name,
|
|
description: crate::non_empty(m.description.unwrap_or_default()),
|
|
modality: m.modality,
|
|
context_window: m.context_window,
|
|
input_token_limit: m.input_token_limit,
|
|
output_token_limit: m.output_token_limit,
|
|
enabled: m.enabled,
|
|
public: m.public,
|
|
provider_name: provider.name,
|
|
provider_logo_url: provider.logo_url,
|
|
card,
|
|
versions,
|
|
tags,
|
|
like_count,
|
|
created_at: m.created_at,
|
|
updated_at: m.updated_at,
|
|
})
|
|
}
|
|
|
|
pub async fn ai_provider_by_id(
|
|
&self,
|
|
id: uuid::Uuid,
|
|
) -> Result<AiProviderModel, AppError> {
|
|
sqlx::query_as::<_, AiProviderModel>(
|
|
"SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \
|
|
FROM ai_provider WHERE id = $1",
|
|
)
|
|
.bind(id)
|
|
.fetch_one(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))
|
|
}
|
|
}
|