gitdataai/lib/ai/sync.rs
2026-05-30 01:38:40 +08:00

127 lines
3.8 KiB
Rust

use std::error::Error;
use std::sync::LazyLock;
use tracing::{debug, warn};
use crate::{
client::EndpointConfig,
error::{AiError, AiResult},
};
/// Envelope shape of a `GET /v1/models` response: `{"data": [ ... ]}`.
#[derive(serde::Deserialize, Debug)]
struct ModelsListResponse {
    /// Model entries carried in the `data` array.
    data: Vec<UpstreamModel>,
}
/// One model entry as reported by an upstream models endpoint.
///
/// Only `id` is required. Every other field is optional and deserializes to
/// `None` when the upstream omits it.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct UpstreamModel {
    /// Model identifier as reported upstream.
    pub id: String,
    /// Display name, if the upstream provides one.
    #[serde(default)]
    pub name: Option<String>,
    /// Value of the upstream `owned_by` field, if present.
    #[serde(default)]
    pub owned_by: Option<String>,
    /// Context window size as reported upstream (presumably in tokens — unverified).
    #[serde(default)]
    pub context_length: Option<i32>,
    /// Maximum output size as reported upstream (presumably in tokens — unverified).
    #[serde(default)]
    pub max_output_tokens: Option<i32>,
    /// Optional capability flags.
    #[serde(default)]
    pub capabilities: Option<UpstreamCapabilities>,
    /// Optional pricing information.
    #[serde(default)]
    pub pricing: Option<UpstreamPricing>,
}
/// Capability flags an upstream may attach to a model entry.
///
/// Each flag is `None` when the upstream does not report it, as distinct from
/// an explicit `false`.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct UpstreamCapabilities {
    /// Whether the model accepts image input, if reported.
    #[serde(default)]
    pub vision: Option<bool>,
    /// Whether the model supports tool/function calling, if reported.
    #[serde(default)]
    pub tool_call: Option<bool>,
    /// Whether the model is flagged as a reasoning model, if reported.
    #[serde(default)]
    pub reasoning: Option<bool>,
}
/// Pricing metadata attached to an upstream model entry.
///
/// Two layouts are accepted side by side:
/// - string-valued `prompt`/`completion` prices;
/// - numeric `input`/`output`/`cache_read` prices with optional `unit` and `currency`.
///
/// Any field the upstream omits is `None`.
#[derive(Clone, Debug, serde::Deserialize)]
pub struct UpstreamPricing {
    /// Prompt price as a string, if reported.
    #[serde(default)]
    pub prompt: Option<String>,
    /// Completion price as a string, if reported.
    #[serde(default)]
    pub completion: Option<String>,
    /// Numeric input price, if reported (unit given by `unit`, when present).
    #[serde(default)]
    pub input: Option<f64>,
    /// Numeric output price, if reported (unit given by `unit`, when present).
    #[serde(default)]
    pub output: Option<f64>,
    /// Numeric cache-read price, if reported.
    #[serde(default)]
    pub cache_read: Option<f64>,
    /// Pricing unit as reported upstream (format not validated here).
    #[serde(default)]
    pub unit: Option<String>,
    /// Currency code as reported upstream (format not validated here).
    #[serde(default)]
    pub currency: Option<String>,
}
/// Process-wide HTTP client used for upstream model-listing requests.
///
/// Built lazily on first use.
///
/// Proxy settings are read once from the environment. The first of
/// `HTTPS_PROXY`, `https_proxy`, `HTTP_PROXY`, `http_proxy` that holds a
/// non-empty value (after trimming whitespace and surrounding quotes) is used.
/// It is applied to every scheme via [`reqwest::Proxy::all`]. An unparseable
/// proxy URL is logged and ignored, so the client still builds.
///
/// Connect and total-request timeouts are set so an unresponsive upstream
/// cannot stall a caller indefinitely.
static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
    // Upper bound on establishing the connection.
    const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
    // Upper bound on a whole request, from connecting until the body is read.
    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);

    let mut builder = reqwest::Client::builder()
        .connect_timeout(CONNECT_TIMEOUT)
        .timeout(REQUEST_TIMEOUT);

    // A variable that is set but empty (e.g. `HTTPS_PROXY=`) is skipped, so a
    // later variable can still apply. Otherwise "" would be parsed and rejected.
    let proxy_url = ["HTTPS_PROXY", "https_proxy", "HTTP_PROXY", "http_proxy"]
        .iter()
        .filter_map(|name| std::env::var(name).ok())
        .map(|raw| raw.trim().trim_matches('"').trim_matches('\'').to_owned())
        .find(|url| !url.is_empty());

    if let Some(url) = proxy_url {
        match reqwest::Proxy::all(url.as_str()) {
            Ok(proxy) => {
                debug!(proxy_url = %url, "sync: using proxy");
                builder = builder.proxy(proxy);
            }
            Err(e) => {
                warn!(proxy_url = %url, error = %e, "sync: invalid proxy URL, skipping");
            }
        }
    }

    // Building fails only when the HTTP/TLS backend cannot be initialised. That is
    // an environment problem with no usable fallback, so panicking is deliberate.
    #[allow(clippy::expect_used)]
    builder.build().expect("failed to build reqwest HTTP client — check system TLS configuration")
});
pub async fn list_models(
config: &EndpointConfig,
) -> AiResult<Vec<UpstreamModel>> {
let base = config.base_url.trim_end_matches('/');
let url = if base.ends_with("/v1") {
format!("{}/models", base)
} else {
format!("{}/v1/models", base)
};
debug!(url = %url, "listing models from upstream");
let resp = HTTP_CLIENT
.get(&url)
.header("Authorization", format!("Bearer {}", config.api_key.trim()))
.send()
.await
.map_err(|e| {
tracing::error!(
error = %e,
source = ?e.source(),
"list_models: request failed with full cause chain"
);
AiError::Response(format!("failed to list models: {}", e))
})?;
let body = resp
.text()
.await
.map_err(|e| AiError::Response(format!("failed to read models body: {}", e)))?;
if let Ok(parsed) = serde_json::from_str::<ModelsListResponse>(&body) {
debug!(count = parsed.data.len(), "parsed models in standard format");
return Ok(parsed.data);
}
if let Ok(parsed) = serde_json::from_str::<Vec<UpstreamModel>>(&body) {
debug!(count = parsed.len(), "parsed models in array format");
return Ok(parsed);
}
warn!(
body = %body.chars().take(500).collect::<String>(),
"list_models: unknown response format"
);
Err(AiError::Response(format!(
"unexpected /v1/models response format (first 200 chars): {}",
body.chars().take(200).collect::<String>()
)))
}