gitdataai/lib/ai/rag/client.rs
2026-05-30 01:38:40 +08:00

264 lines
7.1 KiB
Rust

use config::AppConfig;
use qdrant_client::qdrant::{
CreateCollectionBuilder, CreateFieldIndexCollectionBuilder,
DeletePointsBuilder, FieldType, PointStruct, QueryPointsBuilder,
SearchParamsBuilder, UpsertPointsBuilder, VectorParamsBuilder,
};
use qdrant_client::{Qdrant, QdrantError};
use super::{
config::RagConfig,
document::{RagDocument, RagSearchHit},
payload::{
SESSION_ID_KEY, document_payload, hit_from_scored_point, point_id,
},
search::RagSearchOptions,
session::{session_filter, validate_session_id},
};
use crate::{
client::AiClient,
embed::{AiClientEmbedExt, EmbedClient},
error::{AiError, AiResult},
};
/// Session-scoped RAG (retrieval-augmented generation) client.
///
/// Combines a Qdrant connection, an [`EmbedClient`] that turns text into
/// vectors, and the [`RagConfig`] naming the collection and tuning
/// search/upsert behaviour. Reads and writes are scoped to a session id:
/// points are keyed through `point_id(session_id, ..)` and queries/deletes are
/// restricted with `session_filter(session_id)`.
#[derive(Clone)]
pub struct RagClient {
    qdrant: Qdrant,
    embedder: EmbedClient,
    config: RagConfig,
}

impl RagClient {
    /// Builds a client from an existing Qdrant handle and embedder.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `RagConfig::validate` if `config` is
    /// invalid.
    pub fn new(
        qdrant: Qdrant,
        embedder: EmbedClient,
        config: RagConfig,
    ) -> AiResult<Self> {
        config.validate()?;
        Ok(Self {
            qdrant,
            embedder,
            config,
        })
    }

    /// Connects to Qdrant at `config.url` and takes the embedder from
    /// `ai_client`.
    ///
    /// The URL and API key are trimmed; a missing or blank API key means the
    /// Qdrant client is built without one.
    ///
    /// # Errors
    ///
    /// Fails if `config` is invalid, the Qdrant client cannot be built, or
    /// `ai_client` cannot provide an embedder.
    pub fn connect(
        ai_client: &AiClient,
        config: RagConfig,
    ) -> AiResult<Self> {
        // Validate before the URL is handed to the builder.
        config.validate()?;
        let mut builder =
            Qdrant::from_url(config.url.trim()).timeout(config.timeout);
        // Trim the key just like the URL: keys loaded from env files or secret
        // stores frequently carry a trailing newline, which would otherwise be
        // sent to the server verbatim.
        if let Some(api_key) = config
            .api_key
            .as_deref()
            .map(str::trim)
            .filter(|api_key| !api_key.is_empty())
        {
            builder = builder.api_key(api_key);
        }
        Self::new(builder.build()?, ai_client.embedder()?, config)
    }

    /// Builds a [`RagConfig`] from the application config for
    /// `collection_name` and connects with it (see [`RagClient::connect`]).
    ///
    /// # Errors
    ///
    /// Fails if the RAG config cannot be derived from `config`, or for any
    /// reason [`RagClient::connect`] fails.
    pub fn from_app_config(
        ai_client: &AiClient,
        config: &AppConfig,
        collection_name: impl Into<String>,
    ) -> AiResult<Self> {
        Self::connect(
            ai_client,
            RagConfig::from_app_config(config, collection_name)?,
        )
    }

    /// Underlying Qdrant client.
    pub fn qdrant(&self) -> &Qdrant {
        &self.qdrant
    }

    /// Embedder used for documents and queries.
    pub fn embedder(&self) -> &EmbedClient {
        &self.embedder
    }

    /// Validated configuration this client was built with.
    pub fn config(&self) -> &RagConfig {
        &self.config
    }

    /// Creates the configured collection if it does not exist, then requests
    /// a keyword payload index on the session-id field.
    ///
    /// # Errors
    ///
    /// Propagates Qdrant errors from the existence check and collection
    /// creation. For index creation, only non-`ResponseError` failures are
    /// propagated (see the note below).
    pub async fn ensure_collection(&self) -> AiResult<()> {
        // NOTE(review): check-then-create is not atomic; two concurrent callers
        // may both observe "missing" and one create call may then fail.
        if !self
            .qdrant
            .collection_exists(&self.config.collection_name)
            .await?
        {
            self.qdrant
                .create_collection(
                    CreateCollectionBuilder::new(&self.config.collection_name)
                        .vectors_config(VectorParamsBuilder::new(
                            self.config.vector_size,
                            self.config.distance,
                        )),
                )
                .await?;
        }
        match self
            .qdrant
            .create_field_index(CreateFieldIndexCollectionBuilder::new(
                &self.config.collection_name,
                SESSION_ID_KEY,
                FieldType::Keyword,
            ))
            .await
        {
            Ok(_) => Ok(()),
            // NOTE(review): presumably meant to tolerate "index already
            // exists" on re-runs, but this accepts *every* server-side
            // response error — confirm this is intended.
            Err(QdrantError::ResponseError { .. }) => Ok(()),
            Err(error) => Err(error.into()),
        }
    }

    /// Embeds and stores a single document for `session_id`.
    ///
    /// Equivalent to [`RagClient::upsert_documents`] with a one-element batch.
    ///
    /// # Errors
    ///
    /// Same as [`RagClient::upsert_documents`].
    pub async fn upsert_document(
        &self,
        session_id: impl AsRef<str>,
        document: RagDocument,
    ) -> AiResult<()> {
        self.upsert_documents(session_id, vec![document]).await
    }

    /// Embeds `documents` and upserts them as points scoped to `session_id`.
    ///
    /// Texts are embedded in chunks of `config.upsert_batch_size`. Point ids
    /// come from `point_id(session_id, &document.id)`; presumably these are
    /// deterministic, so re-upserting the same id overwrites the point
    /// (TODO confirm). The call waits for Qdrant to apply the write
    /// (`wait(true)`).
    ///
    /// # Errors
    ///
    /// Fails on an invalid session id, an empty batch, a document with a
    /// blank id or content, an embedding failure, an embedder that returns a
    /// different number of vectors than documents, a payload build failure,
    /// or a Qdrant error.
    pub async fn upsert_documents(
        &self,
        session_id: impl AsRef<str>,
        documents: Vec<RagDocument>,
    ) -> AiResult<()> {
        let session_id = session_id.as_ref();
        validate_session_id(session_id)?;
        validate_documents(&documents)?;
        let texts: Vec<String> = documents
            .iter()
            .map(|document| document.content.clone())
            .collect();
        let vectors = self
            .embedder
            .embed_texts_chunked(texts, self.config.upsert_batch_size)
            .await?;
        // `zip` stops at the shorter side; without this check a short
        // embedding response would silently drop the trailing documents.
        // NOTE(review): no dedicated embedding-error variant is visible here,
        // so `AiError::Config` is reused.
        if vectors.len() != documents.len() {
            return Err(AiError::Config(format!(
                "embedder returned {} vectors for {} rag documents",
                vectors.len(),
                documents.len()
            )));
        }
        let points = documents
            .iter()
            .zip(vectors)
            .map(|(document, vector)| {
                Ok(PointStruct::new(
                    point_id(session_id, &document.id),
                    vector,
                    document_payload(session_id, document)?,
                ))
            })
            .collect::<AiResult<Vec<_>>>()?;
        self.qdrant
            .upsert_points(
                UpsertPointsBuilder::new(&self.config.collection_name, points)
                    .wait(true),
            )
            .await?;
        Ok(())
    }

    /// Searches `session_id` for `query` using the configured default limit
    /// and exact-search flag.
    ///
    /// # Errors
    ///
    /// Same as [`RagClient::search_session_with_options`].
    pub async fn search_session(
        &self,
        session_id: impl AsRef<str>,
        query: impl Into<String>,
    ) -> AiResult<Vec<RagSearchHit>> {
        let options = RagSearchOptions {
            limit: self.config.default_search_limit,
            exact: self.config.exact_session_search,
        };
        self.search_session_with_options(session_id, query, options)
            .await
    }

    /// Embeds `query` and searches `session_id` with explicit `options`.
    ///
    /// The session id and limit are validated *before* the query is embedded,
    /// so invalid input never costs an embedding call.
    ///
    /// # Errors
    ///
    /// Fails on an invalid session id, a zero limit, an embedding failure, or
    /// a Qdrant error.
    pub async fn search_session_with_options(
        &self,
        session_id: impl AsRef<str>,
        query: impl Into<String>,
        options: RagSearchOptions,
    ) -> AiResult<Vec<RagSearchHit>> {
        let session_id = session_id.as_ref();
        Self::validate_search(session_id, &options)?;
        let vector = self.embedder.embed_text(query.into()).await?;
        self.search_session_by_vector(session_id, vector, options)
            .await
    }

    /// Runs a nearest-neighbour query with a precomputed `vector`, restricted
    /// to points matching `session_filter(session_id)`.
    ///
    /// Hits carry their payload and are returned in the order Qdrant returns
    /// them.
    ///
    /// # Errors
    ///
    /// Fails on an invalid session id, a zero limit, or a Qdrant error.
    pub async fn search_session_by_vector(
        &self,
        session_id: impl AsRef<str>,
        vector: Vec<f32>,
        options: RagSearchOptions,
    ) -> AiResult<Vec<RagSearchHit>> {
        let session_id = session_id.as_ref();
        Self::validate_search(session_id, &options)?;
        let response = self
            .qdrant
            .query(
                QueryPointsBuilder::new(&self.config.collection_name)
                    .query(vector)
                    .limit(options.limit)
                    .filter(session_filter(session_id))
                    .with_payload(true)
                    .params(
                        SearchParamsBuilder::default().exact(options.exact),
                    ),
            )
            .await?;
        Ok(response
            .result
            .into_iter()
            .map(hit_from_scored_point)
            .collect())
    }

    /// Deletes every point in the collection that belongs to `session_id`,
    /// waiting for Qdrant to apply the deletion (`wait(true)`).
    ///
    /// # Errors
    ///
    /// Fails on an invalid session id or a Qdrant error.
    pub async fn clear_session(
        &self,
        session_id: impl AsRef<str>,
    ) -> AiResult<()> {
        let session_id = session_id.as_ref();
        validate_session_id(session_id)?;
        self.qdrant
            .delete_points(
                DeletePointsBuilder::new(&self.config.collection_name)
                    .points(session_filter(session_id))
                    .wait(true),
            )
            .await?;
        Ok(())
    }

    /// Shared precondition check for the search entry points: a valid
    /// session id and a non-zero result limit.
    fn validate_search(
        session_id: &str,
        options: &RagSearchOptions,
    ) -> AiResult<()> {
        validate_session_id(session_id)?;
        if options.limit == 0 {
            return Err(AiError::Config(
                "rag search limit must be greater than 0".to_string(),
            ));
        }
        Ok(())
    }
}
/// Checks that a batch of documents is non-empty and that every document has
/// a non-blank id and non-blank content (whitespace-only counts as blank).
///
/// Documents are checked in order, id before content, and the first problem
/// found is reported.
///
/// # Errors
///
/// Returns `AiError::Config` with a message naming the missing piece.
fn validate_documents(documents: &[RagDocument]) -> AiResult<()> {
    if documents.is_empty() {
        return Err(AiError::Config("rag documents are required".to_string()));
    }
    documents.iter().try_for_each(|doc| {
        let problem = if doc.id.trim().is_empty() {
            "rag document id is required"
        } else if doc.content.trim().is_empty() {
            "rag document content is required"
        } else {
            return Ok(());
        };
        Err(AiError::Config(problem.to_string()))
    })
}