gitdataai/lib/ai/rag/config.rs
2026-05-30 01:38:40 +08:00

135 lines
3.7 KiB
Rust

use std::time::Duration;
use config::AppConfig;
use qdrant_client::qdrant::Distance;
use crate::error::{AiError, AiResult};
/// Connection and tuning settings for the Qdrant-backed RAG store.
///
/// Build with [`RagConfig::new`] (which validates) or [`RagConfig::from_app_config`].
/// Fields are public and the `with_*` setters only assign, so a value changed after
/// construction is re-checked only when [`RagConfig::validate`] is called again.
#[derive(Clone, Debug)]
pub struct RagConfig {
/// Base URL of the Qdrant server; `validate` requires an `http://` or `https://` prefix.
pub url: String,
/// Optional Qdrant API key; `new` leaves this as `None`.
pub api_key: Option<String>,
/// Name of the Qdrant collection; `validate` rejects a blank name.
pub collection_name: String,
/// Dimensionality of stored vectors; must be greater than 0.
/// `from_app_config` takes it from the configured embedding model's dimensions.
pub vector_size: u64,
/// Vector distance metric; `new` defaults it to `Distance::Cosine`.
pub distance: Distance,
/// Timeout for Qdrant operations; `new` defaults it to 10 seconds.
/// NOTE(review): how it is applied (per request vs. connect) is decided by the code that
/// consumes this config — confirm there.
pub timeout: Duration,
/// Upsert batch size; `new` defaults it to 64 and `validate` requires it to be > 0.
/// Presumably the number of points sent per upsert call — confirm at the call site.
pub upsert_batch_size: usize,
/// Default search result limit; `new` defaults it to 8 and `validate` requires it to be > 0.
/// Presumably used when a search caller does not pass an explicit limit — verify.
pub default_search_limit: u64,
/// `new` defaults this to `true`.
/// NOTE(review): presumably requests exact (non-approximate) search for session-scoped
/// queries — confirm against the search implementation.
pub exact_session_search: bool,
}
impl RagConfig {
    /// Creates a configuration with default tuning values and validates it.
    ///
    /// Defaults: no API key, [`Distance::Cosine`], a 10 second timeout, an upsert batch
    /// size of 64, a default search limit of 8, and `exact_session_search` enabled.
    ///
    /// Leading and trailing whitespace is stripped from `url` before it is stored. The value
    /// later handed to the client is therefore exactly the value [`RagConfig::validate`]
    /// accepted.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::Config`] if the configuration fails [`RagConfig::validate`].
    pub fn new(
        url: impl Into<String>,
        collection_name: impl Into<String>,
        vector_size: u64,
    ) -> AiResult<Self> {
        let url: String = url.into();
        let config = Self {
            // `validate` inspects the trimmed URL, so store that same trimmed form. This
            // stops " http://host " from passing validation and then reaching the client
            // with stray whitespace.
            url: url.trim().to_owned(),
            api_key: None,
            collection_name: collection_name.into(),
            vector_size,
            distance: Distance::Cosine,
            timeout: Duration::from_secs(10),
            upsert_batch_size: 64,
            default_search_limit: 8,
            exact_session_search: true,
        };
        config.validate()?;
        Ok(config)
    }

    /// Sets (or clears, with `None`) the Qdrant API key.
    #[must_use]
    pub fn with_api_key(mut self, api_key: Option<String>) -> Self {
        self.api_key = api_key;
        self
    }

    /// Sets the vector distance metric.
    #[must_use]
    pub fn with_distance(mut self, distance: Distance) -> Self {
        self.distance = distance;
        self
    }

    /// Sets the Qdrant timeout.
    ///
    /// This setter does not validate. A zero duration is rejected only when
    /// [`RagConfig::validate`] is called afterwards.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets the upsert batch size.
    ///
    /// This setter does not validate. `0` is rejected only when [`RagConfig::validate`]
    /// is called afterwards.
    #[must_use]
    pub fn with_upsert_batch_size(mut self, upsert_batch_size: usize) -> Self {
        self.upsert_batch_size = upsert_batch_size;
        self
    }

    /// Sets the default search limit.
    ///
    /// This setter does not validate. `0` is rejected only when [`RagConfig::validate`]
    /// is called afterwards.
    #[must_use]
    pub fn with_default_search_limit(mut self, default_search_limit: u64) -> Self {
        self.default_search_limit = default_search_limit;
        self
    }

    /// Sets the `exact_session_search` flag.
    #[must_use]
    pub fn with_exact_session_search(mut self, exact_session_search: bool) -> Self {
        self.exact_session_search = exact_session_search;
        self
    }

    /// Checks that the configuration is usable.
    ///
    /// Rules:
    /// - The URL, ignoring surrounding whitespace, must be non-empty.
    /// - The URL must start with `http://` or `https://`.
    /// - The URL must name a host after the scheme.
    /// - `collection_name` must not be blank.
    /// - `vector_size`, `upsert_batch_size`, `default_search_limit` and `timeout` must all
    ///   be non-zero.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::Config`] describing the first check that fails.
    pub fn validate(&self) -> AiResult<()> {
        let url = self.url.trim();
        if url.is_empty() {
            return Err(AiError::Config("qdrant url is required".to_string()));
        }
        let host = match url
            .strip_prefix("http://")
            .or_else(|| url.strip_prefix("https://"))
        {
            Some(host) => host,
            None => {
                return Err(AiError::Config(
                    "qdrant url must start with http:// or https://".to_string(),
                ));
            }
        };
        // "http://" or "http:///path" carry a scheme but no authority, so the client would
        // have nothing to connect to.
        if host.is_empty() || host.starts_with('/') {
            return Err(AiError::Config(
                "qdrant url must include a host after the scheme".to_string(),
            ));
        }
        if self.collection_name.trim().is_empty() {
            return Err(AiError::Config(
                "qdrant collection_name is required".to_string(),
            ));
        }
        if self.vector_size == 0 {
            return Err(AiError::Config(
                "qdrant vector_size must be greater than 0".to_string(),
            ));
        }
        if self.upsert_batch_size == 0 {
            return Err(AiError::Config(
                "qdrant upsert_batch_size must be greater than 0".to_string(),
            ));
        }
        if self.default_search_limit == 0 {
            return Err(AiError::Config(
                "qdrant default_search_limit must be greater than 0".to_string(),
            ));
        }
        // A zero timeout would make every request time out immediately.
        if self.timeout.is_zero() {
            return Err(AiError::Config(
                "qdrant timeout must be greater than 0".to_string(),
            ));
        }
        Ok(())
    }
}
impl RagConfig {
pub fn from_app_config(
config: &AppConfig,
collection_name: impl Into<String>,
) -> AiResult<Self> {
Ok(Self::new(
config
.qdrant_url()
.map_err(|error| AiError::Config(error.to_string()))?,
collection_name,
config
.get_embed_model_dimensions()
.map_err(|error| AiError::Config(error.to_string()))?,
)?
.with_api_key(
config
.qdrant_api_key()
.map_err(|error| AiError::Config(error.to_string()))?,
))
}
}