use config::AppConfig; use qdrant_client::qdrant::{ CreateCollectionBuilder, CreateFieldIndexCollectionBuilder, DeletePointsBuilder, FieldType, PointStruct, QueryPointsBuilder, SearchParamsBuilder, UpsertPointsBuilder, VectorParamsBuilder, }; use qdrant_client::{Qdrant, QdrantError}; use super::{ config::RagConfig, document::{RagDocument, RagSearchHit}, payload::{ SESSION_ID_KEY, document_payload, hit_from_scored_point, point_id, }, search::RagSearchOptions, session::{session_filter, validate_session_id}, }; use crate::{ client::AiClient, embed::{AiClientEmbedExt, EmbedClient}, error::{AiError, AiResult}, }; #[derive(Clone)] pub struct RagClient { qdrant: Qdrant, embedder: EmbedClient, config: RagConfig, } impl RagClient { pub fn new( qdrant: Qdrant, embedder: EmbedClient, config: RagConfig, ) -> AiResult { config.validate()?; Ok(Self { qdrant, embedder, config, }) } pub fn connect( ai_client: &AiClient, config: RagConfig, ) -> AiResult { config.validate()?; let mut builder = Qdrant::from_url(config.url.trim()).timeout(config.timeout); if let Some(api_key) = config .api_key .as_deref() .filter(|api_key| !api_key.trim().is_empty()) { builder = builder.api_key(api_key); } Self::new(builder.build()?, ai_client.embedder()?, config) } pub fn from_app_config( ai_client: &AiClient, config: &AppConfig, collection_name: impl Into, ) -> AiResult { Self::connect( ai_client, RagConfig::from_app_config(config, collection_name)?, ) } pub fn qdrant(&self) -> &Qdrant { &self.qdrant } pub fn embedder(&self) -> &EmbedClient { &self.embedder } pub fn config(&self) -> &RagConfig { &self.config } pub async fn ensure_collection(&self) -> AiResult<()> { if !self .qdrant .collection_exists(&self.config.collection_name) .await? { self.qdrant .create_collection( CreateCollectionBuilder::new(&self.config.collection_name) .vectors_config(VectorParamsBuilder::new( self.config.vector_size, self.config.distance, )), ) .await?; } match self .qdrant .create_field_index(CreateFieldIndexCollectionBuilder::new( &self.config.collection_name, SESSION_ID_KEY, FieldType::Keyword, )) .await { Ok(_) => Ok(()), Err(QdrantError::ResponseError { .. }) => Ok(()), Err(error) => Err(error.into()), } } pub async fn upsert_document( &self, session_id: impl AsRef, document: RagDocument, ) -> AiResult<()> { self.upsert_documents(session_id, vec![document]).await } pub async fn upsert_documents( &self, session_id: impl AsRef, documents: Vec, ) -> AiResult<()> { let session_id = session_id.as_ref(); validate_session_id(session_id)?; validate_documents(&documents)?; let texts: Vec = documents .iter() .map(|d| d.content.clone()) .collect(); let vectors = self .embedder .embed_texts_chunked(texts, self.config.upsert_batch_size) .await?; let points = documents .iter() .zip(vectors) .map(|(document, vector)| { Ok(PointStruct::new( point_id(session_id, &document.id), vector, document_payload(session_id, document)?, )) }) .collect::>>()?; self.qdrant .upsert_points( UpsertPointsBuilder::new(&self.config.collection_name, points) .wait(true), ) .await?; Ok(()) } pub async fn search_session( &self, session_id: impl AsRef, query: impl Into, ) -> AiResult> { let options = RagSearchOptions { limit: self.config.default_search_limit, exact: self.config.exact_session_search, }; self.search_session_with_options(session_id, query, options) .await } pub async fn search_session_with_options( &self, session_id: impl AsRef, query: impl Into, options: RagSearchOptions, ) -> AiResult> { let vector = self.embedder.embed_text(query.into()).await?; self.search_session_by_vector(session_id, vector, options) .await } pub async fn search_session_by_vector( &self, session_id: impl AsRef, vector: Vec, options: RagSearchOptions, ) -> AiResult> { let session_id = session_id.as_ref(); validate_session_id(session_id)?; if options.limit == 0 { return Err(AiError::Config( "rag search limit must be greater than 0".to_string(), )); } let response = self .qdrant .query( QueryPointsBuilder::new(&self.config.collection_name) .query(vector) .limit(options.limit) .filter(session_filter(session_id)) .with_payload(true) .params( SearchParamsBuilder::default().exact(options.exact), ), ) .await?; Ok(response .result .into_iter() .map(hit_from_scored_point) .collect()) } pub async fn clear_session( &self, session_id: impl AsRef, ) -> AiResult<()> { let session_id = session_id.as_ref(); validate_session_id(session_id)?; self.qdrant .delete_points( DeletePointsBuilder::new(&self.config.collection_name) .points(session_filter(session_id)) .wait(true), ) .await?; Ok(()) } } fn validate_documents(documents: &[RagDocument]) -> AiResult<()> { if documents.is_empty() { return Err(AiError::Config("rag documents are required".to_string())); } for document in documents { if document.id.trim().is_empty() { return Err(AiError::Config( "rag document id is required".to_string(), )); } if document.content.trim().is_empty() { return Err(AiError::Config( "rag document content is required".to_string(), )); } } Ok(()) }