gitdataai/lib/ai/client.rs

227 lines
5.9 KiB
Rust

use std::fmt;
use std::sync::Arc;
use config::AppConfig;
use rig::providers::openai;
use crate::error::{AiError, AiResult};
/// Ensures a mandatory configuration value is not blank.
///
/// `scope` and `field` are only used to build the error message
/// (`"<scope> <field> is required"`); `value` is the setting being checked.
///
/// # Errors
/// Returns [`AiError::Config`] when `value` is empty or consists solely of
/// whitespace.
fn validate_required(scope: &str, field: &str, value: &str) -> AiResult<()> {
    match value.trim() {
        "" => Err(AiError::Config(format!("{scope} {field} is required"))),
        _ => Ok(()),
    }
}
/// Converts any displayable error into an [`AiError::Config`] carrying the
/// error's `Display` text.
fn config_error(error: impl fmt::Display) -> AiError {
    let message = error.to_string();
    AiError::Config(message)
}
/// Connection settings for a single HTTP API endpoint.
///
/// `Debug` is intentionally not derived: the manual implementation in this
/// module prints a placeholder instead of the API key.
#[derive(Clone)]
pub struct EndpointConfig {
    /// Base URL of the endpoint; validation requires an `http://` or
    /// `https://` prefix.
    pub base_url: String,
    /// Secret key used to authenticate against the endpoint.
    pub api_key: String,
}
impl EndpointConfig {
    /// Creates an endpoint configuration and validates it under the generic
    /// scope name `"endpoint"`.
    ///
    /// # Errors
    /// Returns [`AiError::Config`] when `base_url` or `api_key` is blank, or
    /// when `base_url` does not start with `http://` or `https://`.
    pub fn new(
        base_url: impl Into<String>,
        api_key: impl Into<String>,
    ) -> AiResult<Self> {
        let config = Self {
            base_url: base_url.into(),
            api_key: api_key.into(),
        };
        config.validate("endpoint")?;
        Ok(config)
    }

    /// Checks that both fields are present and that the URL uses an HTTP(S)
    /// scheme.
    ///
    /// `scope` is a human-readable prefix (for example `"llm endpoint"`) that
    /// is embedded in every error message so the caller can tell which
    /// endpoint is misconfigured.
    fn validate(&self, scope: &str) -> AiResult<()> {
        validate_required(scope, "base_url", &self.base_url)?;
        validate_required(scope, "api_key", &self.api_key)?;

        // Surrounding whitespace is ignored here and stripped again when the
        // client is built, so validate the value that will actually be used.
        let base_url = self.base_url.trim();
        if !base_url.starts_with("http://") && !base_url.starts_with("https://") {
            return Err(AiError::Config(format!(
                "{scope} base_url must start with http:// or https://"
            )));
        }
        Ok(())
    }

    /// Builds a rig OpenAI client that targets this endpoint.
    ///
    /// Both the base URL and the API key are passed without surrounding
    /// whitespace. Validation already treats such whitespace as
    /// insignificant, and a key carrying a stray newline or space (typical
    /// for values read from env files or mounted secrets) would otherwise
    /// produce a malformed authorization header.
    ///
    /// # Errors
    /// Returns [`AiError::Config`] when the underlying client builder fails.
    pub fn build_client(&self) -> AiResult<openai::Client> {
        openai::Client::builder()
            .api_key(self.api_key.trim())
            .base_url(self.base_url.trim())
            .build()
            .map_err(|e| {
                AiError::Config(format!(
                    "failed to build rig OpenAI client: {e}"
                ))
            })
    }
}
/// Debug output that never reveals the API key: the key is replaced by the
/// fixed placeholder `<redacted>`, while the base URL is printed as-is.
impl fmt::Debug for EndpointConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("EndpointConfig");
        repr.field("base_url", &self.base_url);
        repr.field("api_key", &"<redacted>");
        repr.finish()
    }
}
/// Settings for the embedding model: where to reach it, which model to
/// request, and the size of the vectors it produces.
#[derive(Debug, Clone)]
pub struct EmbedConfig {
    /// Endpoint that serves the embedding model.
    pub endpoint: EndpointConfig,
    /// Name of the embedding model; must not be blank.
    pub model: String,
    /// Embedding dimensions; validation rejects `0`.
    pub dimensions: u64,
}
impl EmbedConfig {
    /// Assembles an embedding configuration and validates it before handing
    /// it back.
    ///
    /// # Errors
    /// Returns [`AiError::Config`] when the endpoint is invalid, the model
    /// name is blank, or `dimensions` is zero.
    pub fn new(
        endpoint: EndpointConfig,
        model: impl Into<String>,
        dimensions: u64,
    ) -> AiResult<Self> {
        let model = model.into();
        let candidate = Self {
            endpoint,
            model,
            dimensions,
        };
        candidate.validate().map(|()| candidate)
    }

    /// Validates the endpoint (reported under the `"embed endpoint"` scope),
    /// then the model name, then the dimension count.
    fn validate(&self) -> AiResult<()> {
        self.endpoint.validate("embed endpoint")?;
        validate_required("embed", "model", &self.model)?;
        match self.dimensions {
            0 => Err(AiError::Config(
                "embed dimensions must be greater than 0".to_string(),
            )),
            _ => Ok(()),
        }
    }
}
/// Complete configuration for the AI client: one endpoint for the LLM and a
/// separate configuration for embeddings.
#[derive(Debug, Clone)]
pub struct AiClientConfig {
    /// Endpoint used for LLM requests.
    pub llm: EndpointConfig,
    /// Endpoint, model and dimensions used for embeddings.
    pub embed: EmbedConfig,
}
impl AiClientConfig {
    /// Combines the LLM endpoint and the embedding configuration, returning
    /// the result only if it passes [`AiClientConfig::validate`].
    ///
    /// # Errors
    /// Returns [`AiError::Config`] when either part is invalid.
    pub fn new(llm: EndpointConfig, embed: EmbedConfig) -> AiResult<Self> {
        let candidate = Self { llm, embed };
        candidate.validate().map(|()| candidate)
    }

    /// Re-checks the whole configuration: the LLM endpoint first (reported
    /// under the `"llm endpoint"` scope), then the embedding settings.
    ///
    /// # Errors
    /// Returns the first [`AiError::Config`] encountered.
    pub fn validate(&self) -> AiResult<()> {
        self.llm.validate("llm endpoint")?;
        self.embed.validate()
    }
}
impl TryFrom<&AppConfig> for AiClientConfig {
type Error = AiError;
fn try_from(config: &AppConfig) -> Result<Self, Self::Error> {
let llm = EndpointConfig::new(
config.ai_basic_url().map_err(config_error)?,
config.ai_api_key().map_err(config_error)?,
)?;
let embed_endpoint = EndpointConfig::new(
config.get_embed_model_base_url().map_err(config_error)?,
config.get_embed_model_api_key().map_err(config_error)?,
)?;
let embed = EmbedConfig::new(
embed_endpoint,
config.get_embed_model_name().map_err(config_error)?,
config.get_embed_model_dimensions().map_err(config_error)?,
)?;
Self::new(llm, embed)
}
}
/// Bundle of the two rig OpenAI clients (LLM and embeddings) together with
/// the configuration they were created from.
#[derive(Debug, Clone)]
pub struct AiClient {
    /// Client built from the LLM endpoint configuration.
    pub(super) llm_client: openai::Client,
    /// Client built from the embedding endpoint configuration.
    pub(super) embed_client: openai::Client,
    /// Configuration kept behind an `Arc`, so cloning the client does not
    /// copy the configuration itself.
    pub(super) config: Arc<AiClientConfig>,
}
impl AiClient {
    /// Validates `config`, then builds the LLM client followed by the
    /// embedding client.
    ///
    /// # Errors
    /// Returns [`AiError::Config`] when validation fails or either client
    /// cannot be built.
    pub fn new(config: AiClientConfig) -> AiResult<Self> {
        config.validate()?;
        let llm_client = config.llm.build_client()?;
        let embed_client = config.embed.endpoint.build_client()?;
        let config = Arc::new(config);
        Ok(Self {
            llm_client,
            embed_client,
            config,
        })
    }

    /// Convenience constructor: converts the application configuration into
    /// an [`AiClientConfig`] and passes it to [`AiClient::new`].
    ///
    /// # Errors
    /// Propagates any error from the conversion or from [`AiClient::new`].
    pub fn from_app_config(config: &AppConfig) -> AiResult<Self> {
        AiClientConfig::try_from(config).and_then(Self::new)
    }

    /// Client built from the LLM endpoint.
    pub fn llm_client(&self) -> &openai::Client {
        &self.llm_client
    }

    /// Client built from the embedding endpoint.
    pub fn embed_client(&self) -> &openai::Client {
        &self.embed_client
    }

    /// Full configuration this client was created with.
    pub fn config(&self) -> &AiClientConfig {
        self.config.as_ref()
    }

    /// LLM endpoint part of the configuration.
    pub fn llm_config(&self) -> &EndpointConfig {
        &self.config.llm
    }

    /// Embedding part of the configuration.
    pub fn embed_config(&self) -> &EmbedConfig {
        &self.config.embed
    }

    /// Name of the configured embedding model.
    pub fn embed_model(&self) -> &str {
        &self.config.embed.model
    }

    /// Configured embedding dimensions.
    pub fn embed_dimensions(&self) -> u64 {
        self.config.embed.dimensions
    }

    /// Configured embedding dimensions as `u32`; values that do not fit are
    /// clamped to `u32::MAX`.
    pub fn embed_dimensions_u32(&self) -> u32 {
        match u32::try_from(self.config.embed.dimensions) {
            Ok(dimensions) => dimensions,
            Err(_) => u32::MAX,
        }
    }
}
/// Builds a `reqwest` HTTP client, routing all traffic through a proxy when
/// one is configured in the environment.
///
/// The variables `HTTPS_PROXY`, `https_proxy`, `HTTP_PROXY` and `http_proxy`
/// are consulted in that order and the first one holding a usable value wins.
/// Surrounding whitespace and a pair of wrapping single or double quotes are
/// stripped from the value. A variable that is set but blank is treated as
/// unset, so `HTTPS_PROXY=""` no longer shadows the remaining variables or
/// makes client construction fail on an empty proxy URL.
///
/// # Errors
/// Returns [`AiError::Config`] when the selected proxy URL is rejected by
/// `reqwest` or when the client itself cannot be built.
pub fn build_http_client() -> Result<reqwest::Client, AiError> {
    const PROXY_ENV_VARS: [&str; 4] =
        ["HTTPS_PROXY", "https_proxy", "HTTP_PROXY", "http_proxy"];

    let proxy_url = PROXY_ENV_VARS.iter().find_map(|name| {
        let raw = std::env::var(name).ok()?;
        // The final `trim` handles whitespace left inside the quotes, so a
        // value such as `" "` counts as blank as well.
        let cleaned = raw.trim().trim_matches('"').trim_matches('\'').trim();
        (!cleaned.is_empty()).then(|| cleaned.to_owned())
    });

    let mut builder = reqwest::Client::builder();
    if let Some(proxy_url) = proxy_url {
        let proxy = reqwest::Proxy::all(proxy_url.as_str()).map_err(|e| {
            AiError::Config(format!("Invalid proxy URL '{}': {}", proxy_url, e))
        })?;
        builder = builder.proxy(proxy);
    }

    builder.build().map_err(|e| {
        AiError::Config(format!("Failed to build HTTP client: {}", e))
    })
}