use std::fmt; use std::sync::Arc; use config::AppConfig; use rig::providers::openai; use crate::error::{AiError, AiResult}; fn validate_required(scope: &str, field: &str, value: &str) -> AiResult<()> { if value.trim().is_empty() { return Err(AiError::Config(format!("{scope} {field} is required"))); } Ok(()) } fn config_error(error: impl fmt::Display) -> AiError { AiError::Config(error.to_string()) } #[derive(Clone)] pub struct EndpointConfig { pub base_url: String, pub api_key: String, } impl EndpointConfig { pub fn new(base_url: impl Into, api_key: impl Into) -> AiResult { let config = Self { base_url: base_url.into(), api_key: api_key.into(), }; config.validate("endpoint")?; Ok(config) } fn validate(&self, scope: &str) -> AiResult<()> { validate_required(scope, "base_url", &self.base_url)?; validate_required(scope, "api_key", &self.api_key)?; if !self.base_url.trim().starts_with("http://") && !self.base_url.trim().starts_with("https://") { return Err(AiError::Config(format!( "{scope} base_url must start with http:// or https://" ))); } Ok(()) } pub fn build_client(&self) -> AiResult { openai::Client::builder() .api_key(&self.api_key) .base_url(self.base_url.trim()) .build() .map_err(|e| AiError::Config(format!("failed to build rig OpenAI client: {e}"))) } } impl fmt::Debug for EndpointConfig { fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { formatter .debug_struct("EndpointConfig") .field("base_url", &self.base_url) .field("api_key", &"") .finish() } } #[derive(Clone, Debug)] pub struct EmbedConfig { pub endpoint: EndpointConfig, pub model: String, pub dimensions: u64, } impl EmbedConfig { pub fn new( endpoint: EndpointConfig, model: impl Into, dimensions: u64, ) -> AiResult { let config = Self { endpoint, model: model.into(), dimensions, }; config.validate()?; Ok(config) } fn validate(&self) -> AiResult<()> { self.endpoint.validate("embed endpoint")?; validate_required("embed", "model", &self.model)?; if self.dimensions == 0 { return Err(AiError::Config( "embed dimensions must be greater than 0".to_string(), )); } Ok(()) } } #[derive(Clone, Debug)] pub struct AiClientConfig { pub llm: EndpointConfig, pub embed: EmbedConfig, } impl AiClientConfig { pub fn new(llm: EndpointConfig, embed: EmbedConfig) -> AiResult { let config = Self { llm, embed }; config.validate()?; Ok(config) } pub fn validate(&self) -> AiResult<()> { self.llm.validate("llm endpoint")?; self.embed.validate()?; Ok(()) } } impl TryFrom<&AppConfig> for AiClientConfig { type Error = AiError; fn try_from(config: &AppConfig) -> Result { let llm = EndpointConfig::new( config.ai_basic_url().map_err(config_error)?, config.ai_api_key().map_err(config_error)?, )?; let embed_endpoint = EndpointConfig::new( config.get_embed_model_base_url().map_err(config_error)?, config.get_embed_model_api_key().map_err(config_error)?, )?; let embed = EmbedConfig::new( embed_endpoint, config.get_embed_model_name().map_err(config_error)?, config.get_embed_model_dimensions().map_err(config_error)?, )?; Self::new(llm, embed) } } #[derive(Clone, Debug)] pub struct AiClient { pub(super) llm_client: openai::Client, pub(super) embed_client: openai::Client, pub(super) config: Arc, } impl AiClient { pub fn new(config: AiClientConfig) -> AiResult { config.validate()?; Ok(Self { llm_client: config.llm.build_client()?, embed_client: config.embed.endpoint.build_client()?, config: Arc::new(config), }) } pub fn from_app_config(config: &AppConfig) -> AiResult { Self::new(AiClientConfig::try_from(config)?) } pub fn llm_client(&self) -> &openai::Client { &self.llm_client } pub fn embed_client(&self) -> &openai::Client { &self.embed_client } pub fn config(&self) -> &AiClientConfig { &self.config } pub fn llm_config(&self) -> &EndpointConfig { &self.config.llm } pub fn embed_config(&self) -> &EmbedConfig { &self.config.embed } pub fn embed_model(&self) -> &str { self.config.embed.model.as_str() } pub fn embed_dimensions(&self) -> u64 { self.config.embed.dimensions } pub fn embed_dimensions_u32(&self) -> u32 { u32::try_from(self.config.embed.dimensions).unwrap_or(u32::MAX) } } pub fn build_http_client() -> Result { let mut builder = reqwest::Client::builder(); if let Ok(proxy_url) = std::env::var("HTTPS_PROXY") .or_else(|_| std::env::var("https_proxy")) .or_else(|_| std::env::var("HTTP_PROXY")) .or_else(|_| std::env::var("http_proxy")) { let proxy_url = proxy_url.trim().trim_matches('"').trim_matches('\''); let proxy = reqwest::Proxy::all(proxy_url).map_err(|e| { AiError::Config(format!("Invalid proxy URL '{}': {}", proxy_url, e)) })?; builder = builder.proxy(proxy); } builder.build().map_err(|e| { AiError::Config(format!("Failed to build HTTP client: {}", e)) }) }