220 lines
5.8 KiB
Rust
220 lines
5.8 KiB
Rust
use std::fmt;
|
|
use std::sync::Arc;
|
|
|
|
use config::AppConfig;
|
|
use rig::providers::openai;
|
|
|
|
use crate::error::{AiError, AiResult};
|
|
|
|
fn validate_required(scope: &str, field: &str, value: &str) -> AiResult<()> {
|
|
if value.trim().is_empty() {
|
|
return Err(AiError::Config(format!("{scope} {field} is required")));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn config_error(error: impl fmt::Display) -> AiError {
|
|
AiError::Config(error.to_string())
|
|
}
|
|
|
|
/// Connection settings for one OpenAI-compatible endpoint.
///
/// `Debug` is implemented by hand for this type so that `api_key` is never printed.
#[derive(Clone)]
pub struct EndpointConfig {
    /// Root URL of the API; `EndpointConfig::new` requires an `http://` or `https://` prefix.
    pub base_url: String,
    /// Secret credential handed to the rig client builder; must not be blank.
    pub api_key: String,
}
|
|
|
|
impl EndpointConfig {
|
|
pub fn new(base_url: impl Into<String>, api_key: impl Into<String>) -> AiResult<Self> {
|
|
let config = Self {
|
|
base_url: base_url.into(),
|
|
api_key: api_key.into(),
|
|
};
|
|
config.validate("endpoint")?;
|
|
Ok(config)
|
|
}
|
|
|
|
fn validate(&self, scope: &str) -> AiResult<()> {
|
|
validate_required(scope, "base_url", &self.base_url)?;
|
|
validate_required(scope, "api_key", &self.api_key)?;
|
|
if !self.base_url.trim().starts_with("http://")
|
|
&& !self.base_url.trim().starts_with("https://")
|
|
{
|
|
return Err(AiError::Config(format!(
|
|
"{scope} base_url must start with http:// or https://"
|
|
)));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub fn build_client(&self) -> AiResult<openai::Client> {
|
|
openai::Client::builder()
|
|
.api_key(&self.api_key)
|
|
.base_url(self.base_url.trim())
|
|
.build()
|
|
.map_err(|e| AiError::Config(format!("failed to build rig OpenAI client: {e}")))
|
|
}
|
|
}
|
|
|
|
impl fmt::Debug for EndpointConfig {
|
|
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
formatter
|
|
.debug_struct("EndpointConfig")
|
|
.field("base_url", &self.base_url)
|
|
.field("api_key", &"<redacted>")
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct EmbedConfig {
|
|
pub endpoint: EndpointConfig,
|
|
pub model: String,
|
|
pub dimensions: u64,
|
|
}
|
|
|
|
impl EmbedConfig {
|
|
pub fn new(
|
|
endpoint: EndpointConfig,
|
|
model: impl Into<String>,
|
|
dimensions: u64,
|
|
) -> AiResult<Self> {
|
|
let config = Self {
|
|
endpoint,
|
|
model: model.into(),
|
|
dimensions,
|
|
};
|
|
config.validate()?;
|
|
Ok(config)
|
|
}
|
|
|
|
fn validate(&self) -> AiResult<()> {
|
|
self.endpoint.validate("embed endpoint")?;
|
|
validate_required("embed", "model", &self.model)?;
|
|
if self.dimensions == 0 {
|
|
return Err(AiError::Config(
|
|
"embed dimensions must be greater than 0".to_string(),
|
|
));
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct AiClientConfig {
|
|
pub llm: EndpointConfig,
|
|
pub embed: EmbedConfig,
|
|
}
|
|
|
|
impl AiClientConfig {
|
|
pub fn new(llm: EndpointConfig, embed: EmbedConfig) -> AiResult<Self> {
|
|
let config = Self { llm, embed };
|
|
config.validate()?;
|
|
Ok(config)
|
|
}
|
|
|
|
pub fn validate(&self) -> AiResult<()> {
|
|
self.llm.validate("llm endpoint")?;
|
|
self.embed.validate()?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl TryFrom<&AppConfig> for AiClientConfig {
|
|
type Error = AiError;
|
|
|
|
fn try_from(config: &AppConfig) -> Result<Self, Self::Error> {
|
|
let llm = EndpointConfig::new(
|
|
config.ai_basic_url().map_err(config_error)?,
|
|
config.ai_api_key().map_err(config_error)?,
|
|
)?;
|
|
|
|
let embed_endpoint = EndpointConfig::new(
|
|
config.get_embed_model_base_url().map_err(config_error)?,
|
|
config.get_embed_model_api_key().map_err(config_error)?,
|
|
)?;
|
|
|
|
let embed = EmbedConfig::new(
|
|
embed_endpoint,
|
|
config.get_embed_model_name().map_err(config_error)?,
|
|
config.get_embed_model_dimensions().map_err(config_error)?,
|
|
)?;
|
|
|
|
Self::new(llm, embed)
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct AiClient {
|
|
pub(super) llm_client: openai::Client,
|
|
pub(super) embed_client: openai::Client,
|
|
pub(super) config: Arc<AiClientConfig>,
|
|
}
|
|
|
|
impl AiClient {
|
|
pub fn new(config: AiClientConfig) -> AiResult<Self> {
|
|
config.validate()?;
|
|
|
|
Ok(Self {
|
|
llm_client: config.llm.build_client()?,
|
|
embed_client: config.embed.endpoint.build_client()?,
|
|
config: Arc::new(config),
|
|
})
|
|
}
|
|
|
|
pub fn from_app_config(config: &AppConfig) -> AiResult<Self> {
|
|
Self::new(AiClientConfig::try_from(config)?)
|
|
}
|
|
|
|
pub fn llm_client(&self) -> &openai::Client {
|
|
&self.llm_client
|
|
}
|
|
|
|
pub fn embed_client(&self) -> &openai::Client {
|
|
&self.embed_client
|
|
}
|
|
|
|
pub fn config(&self) -> &AiClientConfig {
|
|
&self.config
|
|
}
|
|
|
|
pub fn llm_config(&self) -> &EndpointConfig {
|
|
&self.config.llm
|
|
}
|
|
|
|
pub fn embed_config(&self) -> &EmbedConfig {
|
|
&self.config.embed
|
|
}
|
|
|
|
pub fn embed_model(&self) -> &str {
|
|
self.config.embed.model.as_str()
|
|
}
|
|
|
|
pub fn embed_dimensions(&self) -> u64 {
|
|
self.config.embed.dimensions
|
|
}
|
|
|
|
pub fn embed_dimensions_u32(&self) -> u32 {
|
|
u32::try_from(self.config.embed.dimensions).unwrap_or(u32::MAX)
|
|
}
|
|
}
|
|
|
|
pub fn build_http_client() -> Result<reqwest::Client, AiError> {
|
|
let mut builder = reqwest::Client::builder();
|
|
|
|
if let Ok(proxy_url) = std::env::var("HTTPS_PROXY")
|
|
.or_else(|_| std::env::var("https_proxy"))
|
|
.or_else(|_| std::env::var("HTTP_PROXY"))
|
|
.or_else(|_| std::env::var("http_proxy"))
|
|
{
|
|
let proxy_url = proxy_url.trim().trim_matches('"').trim_matches('\'');
|
|
let proxy = reqwest::Proxy::all(proxy_url).map_err(|e| {
|
|
AiError::Config(format!("Invalid proxy URL '{}': {}", proxy_url, e))
|
|
})?;
|
|
builder = builder.proxy(proxy);
|
|
}
|
|
|
|
builder.build().map_err(|e| {
|
|
AiError::Config(format!("Failed to build HTTP client: {}", e))
|
|
})
|
|
}
|