127 lines
3.8 KiB
Rust
127 lines
3.8 KiB
Rust
use std::error::Error;
|
|
use std::sync::LazyLock;
|
|
|
|
use tracing::{debug, warn};
|
|
|
|
use crate::{
|
|
client::EndpointConfig,
|
|
error::{AiError, AiResult},
|
|
};
|
|
#[derive(Debug, serde::Deserialize)]
|
|
struct ModelsListResponse {
|
|
data: Vec<UpstreamModel>,
|
|
}
|
|
#[derive(Debug, Clone, serde::Deserialize)]
|
|
pub struct UpstreamModel {
|
|
pub id: String,
|
|
#[serde(default)]
|
|
pub name: Option<String>,
|
|
#[serde(default)]
|
|
pub owned_by: Option<String>,
|
|
#[serde(default)]
|
|
pub context_length: Option<i32>,
|
|
#[serde(default)]
|
|
pub max_output_tokens: Option<i32>,
|
|
#[serde(default)]
|
|
pub capabilities: Option<UpstreamCapabilities>,
|
|
#[serde(default)]
|
|
pub pricing: Option<UpstreamPricing>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Deserialize)]
|
|
pub struct UpstreamCapabilities {
|
|
#[serde(default)]
|
|
pub vision: Option<bool>,
|
|
#[serde(default)]
|
|
pub tool_call: Option<bool>,
|
|
#[serde(default)]
|
|
pub reasoning: Option<bool>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Deserialize)]
|
|
pub struct UpstreamPricing {
|
|
#[serde(default)]
|
|
pub prompt: Option<String>,
|
|
#[serde(default)]
|
|
pub completion: Option<String>,
|
|
#[serde(default)]
|
|
pub input: Option<f64>,
|
|
#[serde(default)]
|
|
pub output: Option<f64>,
|
|
#[serde(default)]
|
|
pub cache_read: Option<f64>,
|
|
#[serde(default)]
|
|
pub unit: Option<String>,
|
|
#[serde(default)]
|
|
pub currency: Option<String>,
|
|
}
|
|
static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
|
let mut builder = reqwest::Client::builder();
|
|
let proxy_url = std::env::var("HTTPS_PROXY")
|
|
.or_else(|_| std::env::var("https_proxy"))
|
|
.or_else(|_| std::env::var("HTTP_PROXY"))
|
|
.or_else(|_| std::env::var("http_proxy"))
|
|
.ok();
|
|
if let Some(raw) = &proxy_url {
|
|
let url = raw.trim().trim_matches('"').trim_matches('\'');
|
|
match reqwest::Proxy::all(url) {
|
|
Ok(proxy) => {
|
|
debug!(proxy_url = %url, "sync: using proxy");
|
|
builder = builder.proxy(proxy);
|
|
}
|
|
Err(e) => {
|
|
warn!(proxy_url = %url, error = %e, "sync: invalid proxy URL, skipping");
|
|
}
|
|
}
|
|
}
|
|
#[allow(clippy::expect_used)]
|
|
builder.build().expect("failed to build reqwest HTTP client — check system TLS configuration")
|
|
});
|
|
pub async fn list_models(
|
|
config: &EndpointConfig,
|
|
) -> AiResult<Vec<UpstreamModel>> {
|
|
let base = config.base_url.trim_end_matches('/');
|
|
let url = if base.ends_with("/v1") {
|
|
format!("{}/models", base)
|
|
} else {
|
|
format!("{}/v1/models", base)
|
|
};
|
|
|
|
debug!(url = %url, "listing models from upstream");
|
|
let resp = HTTP_CLIENT
|
|
.get(&url)
|
|
.header("Authorization", format!("Bearer {}", config.api_key.trim()))
|
|
.send()
|
|
.await
|
|
.map_err(|e| {
|
|
tracing::error!(
|
|
error = %e,
|
|
source = ?e.source(),
|
|
"list_models: request failed with full cause chain"
|
|
);
|
|
AiError::Response(format!("failed to list models: {}", e))
|
|
})?;
|
|
|
|
let body = resp
|
|
.text()
|
|
.await
|
|
.map_err(|e| AiError::Response(format!("failed to read models body: {}", e)))?;
|
|
if let Ok(parsed) = serde_json::from_str::<ModelsListResponse>(&body) {
|
|
debug!(count = parsed.data.len(), "parsed models in standard format");
|
|
return Ok(parsed.data);
|
|
}
|
|
if let Ok(parsed) = serde_json::from_str::<Vec<UpstreamModel>>(&body) {
|
|
debug!(count = parsed.len(), "parsed models in array format");
|
|
return Ok(parsed);
|
|
}
|
|
|
|
warn!(
|
|
body = %body.chars().take(500).collect::<String>(),
|
|
"list_models: unknown response format"
|
|
);
|
|
Err(AiError::Response(format!(
|
|
"unexpected /v1/models response format (first 200 chars): {}",
|
|
body.chars().take(200).collect::<String>()
|
|
)))
|
|
}
|