135 lines
3.7 KiB
Rust
135 lines
3.7 KiB
Rust
use std::time::Duration;
|
|
|
|
use config::AppConfig;
|
|
use qdrant_client::qdrant::Distance;
|
|
|
|
use crate::error::{AiError, AiResult};
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct RagConfig {
|
|
pub url: String,
|
|
pub api_key: Option<String>,
|
|
pub collection_name: String,
|
|
pub vector_size: u64,
|
|
pub distance: Distance,
|
|
pub timeout: Duration,
|
|
pub upsert_batch_size: usize,
|
|
pub default_search_limit: u64,
|
|
pub exact_session_search: bool,
|
|
}
|
|
|
|
impl RagConfig {
|
|
pub fn new(
|
|
url: impl Into<String>,
|
|
collection_name: impl Into<String>,
|
|
vector_size: u64,
|
|
) -> AiResult<Self> {
|
|
let config = Self {
|
|
url: url.into(),
|
|
api_key: None,
|
|
collection_name: collection_name.into(),
|
|
vector_size,
|
|
distance: Distance::Cosine,
|
|
timeout: Duration::from_secs(10),
|
|
upsert_batch_size: 64,
|
|
default_search_limit: 8,
|
|
exact_session_search: true,
|
|
};
|
|
config.validate()?;
|
|
Ok(config)
|
|
}
|
|
|
|
pub fn with_api_key(mut self, api_key: Option<String>) -> Self {
|
|
self.api_key = api_key;
|
|
self
|
|
}
|
|
|
|
pub fn with_distance(mut self, distance: Distance) -> Self {
|
|
self.distance = distance;
|
|
self
|
|
}
|
|
|
|
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
|
self.timeout = timeout;
|
|
self
|
|
}
|
|
|
|
pub fn with_upsert_batch_size(mut self, upsert_batch_size: usize) -> Self {
|
|
self.upsert_batch_size = upsert_batch_size;
|
|
self
|
|
}
|
|
|
|
pub fn with_default_search_limit(
|
|
mut self,
|
|
default_search_limit: u64,
|
|
) -> Self {
|
|
self.default_search_limit = default_search_limit;
|
|
self
|
|
}
|
|
|
|
pub fn with_exact_session_search(
|
|
mut self,
|
|
exact_session_search: bool,
|
|
) -> Self {
|
|
self.exact_session_search = exact_session_search;
|
|
self
|
|
}
|
|
|
|
pub fn validate(&self) -> AiResult<()> {
|
|
if self.url.trim().is_empty() {
|
|
return Err(AiError::Config("qdrant url is required".to_string()));
|
|
}
|
|
if !self.url.trim().starts_with("http://")
|
|
&& !self.url.trim().starts_with("https://")
|
|
{
|
|
return Err(AiError::Config(
|
|
"qdrant url must start with http:// or https://".to_string(),
|
|
));
|
|
}
|
|
if self.collection_name.trim().is_empty() {
|
|
return Err(AiError::Config(
|
|
"qdrant collection_name is required".to_string(),
|
|
));
|
|
}
|
|
if self.vector_size == 0 {
|
|
return Err(AiError::Config(
|
|
"qdrant vector_size must be greater than 0".to_string(),
|
|
));
|
|
}
|
|
if self.upsert_batch_size == 0 {
|
|
return Err(AiError::Config(
|
|
"qdrant upsert_batch_size must be greater than 0".to_string(),
|
|
));
|
|
}
|
|
if self.default_search_limit == 0 {
|
|
return Err(AiError::Config(
|
|
"qdrant default_search_limit must be greater than 0"
|
|
.to_string(),
|
|
));
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl RagConfig {
|
|
pub fn from_app_config(
|
|
config: &AppConfig,
|
|
collection_name: impl Into<String>,
|
|
) -> AiResult<Self> {
|
|
Ok(Self::new(
|
|
config
|
|
.qdrant_url()
|
|
.map_err(|error| AiError::Config(error.to_string()))?,
|
|
collection_name,
|
|
config
|
|
.get_embed_model_dimensions()
|
|
.map_err(|error| AiError::Config(error.to_string()))?,
|
|
)?
|
|
.with_api_key(
|
|
config
|
|
.qdrant_api_key()
|
|
.map_err(|error| AiError::Config(error.to_string()))?,
|
|
))
|
|
}
|
|
}
|