259 lines
7.1 KiB
Rust
259 lines
7.1 KiB
Rust
use config::AppConfig;
|
|
use qdrant_client::qdrant::{
|
|
CreateCollectionBuilder, CreateFieldIndexCollectionBuilder,
|
|
DeletePointsBuilder, FieldType, PointStruct, QueryPointsBuilder,
|
|
SearchParamsBuilder, UpsertPointsBuilder, VectorParamsBuilder,
|
|
};
|
|
use qdrant_client::{Qdrant, QdrantError};
|
|
|
|
use super::{
|
|
config::RagConfig,
|
|
document::{RagDocument, RagSearchHit},
|
|
payload::{
|
|
SESSION_ID_KEY, document_payload, hit_from_scored_point, point_id,
|
|
},
|
|
search::RagSearchOptions,
|
|
session::{session_filter, validate_session_id},
|
|
};
|
|
use crate::{
|
|
client::AiClient,
|
|
embed::{AiClientEmbedExt, EmbedClient},
|
|
error::{AiError, AiResult},
|
|
};
|
|
|
|
#[derive(Clone)]
|
|
pub struct RagClient {
|
|
qdrant: Qdrant,
|
|
embedder: EmbedClient,
|
|
config: RagConfig,
|
|
}
|
|
|
|
impl RagClient {
|
|
pub fn new(
|
|
qdrant: Qdrant,
|
|
embedder: EmbedClient,
|
|
config: RagConfig,
|
|
) -> AiResult<Self> {
|
|
config.validate()?;
|
|
Ok(Self {
|
|
qdrant,
|
|
embedder,
|
|
config,
|
|
})
|
|
}
|
|
|
|
pub fn connect(ai_client: &AiClient, config: RagConfig) -> AiResult<Self> {
|
|
config.validate()?;
|
|
let mut builder =
|
|
Qdrant::from_url(config.url.trim()).timeout(config.timeout);
|
|
if let Some(api_key) = config
|
|
.api_key
|
|
.as_deref()
|
|
.filter(|api_key| !api_key.trim().is_empty())
|
|
{
|
|
builder = builder.api_key(api_key);
|
|
}
|
|
|
|
Self::new(builder.build()?, ai_client.embedder()?, config)
|
|
}
|
|
|
|
pub fn from_app_config(
|
|
ai_client: &AiClient,
|
|
config: &AppConfig,
|
|
collection_name: impl Into<String>,
|
|
) -> AiResult<Self> {
|
|
Self::connect(
|
|
ai_client,
|
|
RagConfig::from_app_config(config, collection_name)?,
|
|
)
|
|
}
|
|
|
|
pub fn qdrant(&self) -> &Qdrant {
|
|
&self.qdrant
|
|
}
|
|
|
|
pub fn embedder(&self) -> &EmbedClient {
|
|
&self.embedder
|
|
}
|
|
|
|
pub fn config(&self) -> &RagConfig {
|
|
&self.config
|
|
}
|
|
|
|
pub async fn ensure_collection(&self) -> AiResult<()> {
|
|
if !self
|
|
.qdrant
|
|
.collection_exists(&self.config.collection_name)
|
|
.await?
|
|
{
|
|
self.qdrant
|
|
.create_collection(
|
|
CreateCollectionBuilder::new(&self.config.collection_name)
|
|
.vectors_config(VectorParamsBuilder::new(
|
|
self.config.vector_size,
|
|
self.config.distance,
|
|
)),
|
|
)
|
|
.await?;
|
|
}
|
|
|
|
match self
|
|
.qdrant
|
|
.create_field_index(CreateFieldIndexCollectionBuilder::new(
|
|
&self.config.collection_name,
|
|
SESSION_ID_KEY,
|
|
FieldType::Keyword,
|
|
))
|
|
.await
|
|
{
|
|
Ok(_) => Ok(()),
|
|
Err(QdrantError::ResponseError { .. }) => Ok(()),
|
|
Err(error) => Err(error.into()),
|
|
}
|
|
}
|
|
|
|
pub async fn upsert_document(
|
|
&self,
|
|
session_id: impl AsRef<str>,
|
|
document: RagDocument,
|
|
) -> AiResult<()> {
|
|
self.upsert_documents(session_id, vec![document]).await
|
|
}
|
|
|
|
pub async fn upsert_documents(
|
|
&self,
|
|
session_id: impl AsRef<str>,
|
|
documents: Vec<RagDocument>,
|
|
) -> AiResult<()> {
|
|
let session_id = session_id.as_ref();
|
|
validate_session_id(session_id)?;
|
|
validate_documents(&documents)?;
|
|
|
|
let texts: Vec<String> =
|
|
documents.iter().map(|d| d.content.clone()).collect();
|
|
let vectors = self
|
|
.embedder
|
|
.embed_texts_chunked(texts, self.config.upsert_batch_size)
|
|
.await?;
|
|
|
|
let points = documents
|
|
.iter()
|
|
.zip(vectors)
|
|
.map(|(document, vector)| {
|
|
Ok(PointStruct::new(
|
|
point_id(session_id, &document.id),
|
|
vector,
|
|
document_payload(session_id, document)?,
|
|
))
|
|
})
|
|
.collect::<AiResult<Vec<_>>>()?;
|
|
|
|
self.qdrant
|
|
.upsert_points(
|
|
UpsertPointsBuilder::new(&self.config.collection_name, points)
|
|
.wait(true),
|
|
)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn search_session(
|
|
&self,
|
|
session_id: impl AsRef<str>,
|
|
query: impl Into<String>,
|
|
) -> AiResult<Vec<RagSearchHit>> {
|
|
let options = RagSearchOptions {
|
|
limit: self.config.default_search_limit,
|
|
exact: self.config.exact_session_search,
|
|
};
|
|
self.search_session_with_options(session_id, query, options)
|
|
.await
|
|
}
|
|
|
|
pub async fn search_session_with_options(
|
|
&self,
|
|
session_id: impl AsRef<str>,
|
|
query: impl Into<String>,
|
|
options: RagSearchOptions,
|
|
) -> AiResult<Vec<RagSearchHit>> {
|
|
let vector = self.embedder.embed_text(query.into()).await?;
|
|
self.search_session_by_vector(session_id, vector, options)
|
|
.await
|
|
}
|
|
|
|
pub async fn search_session_by_vector(
|
|
&self,
|
|
session_id: impl AsRef<str>,
|
|
vector: Vec<f32>,
|
|
options: RagSearchOptions,
|
|
) -> AiResult<Vec<RagSearchHit>> {
|
|
let session_id = session_id.as_ref();
|
|
validate_session_id(session_id)?;
|
|
if options.limit == 0 {
|
|
return Err(AiError::Config(
|
|
"rag search limit must be greater than 0".to_string(),
|
|
));
|
|
}
|
|
|
|
let response = self
|
|
.qdrant
|
|
.query(
|
|
QueryPointsBuilder::new(&self.config.collection_name)
|
|
.query(vector)
|
|
.limit(options.limit)
|
|
.filter(session_filter(session_id))
|
|
.with_payload(true)
|
|
.params(
|
|
SearchParamsBuilder::default().exact(options.exact),
|
|
),
|
|
)
|
|
.await?;
|
|
|
|
Ok(response
|
|
.result
|
|
.into_iter()
|
|
.map(hit_from_scored_point)
|
|
.collect())
|
|
}
|
|
|
|
pub async fn clear_session(
|
|
&self,
|
|
session_id: impl AsRef<str>,
|
|
) -> AiResult<()> {
|
|
let session_id = session_id.as_ref();
|
|
validate_session_id(session_id)?;
|
|
|
|
self.qdrant
|
|
.delete_points(
|
|
DeletePointsBuilder::new(&self.config.collection_name)
|
|
.points(session_filter(session_id))
|
|
.wait(true),
|
|
)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
fn validate_documents(documents: &[RagDocument]) -> AiResult<()> {
|
|
if documents.is_empty() {
|
|
return Err(AiError::Config("rag documents are required".to_string()));
|
|
}
|
|
|
|
for document in documents {
|
|
if document.id.trim().is_empty() {
|
|
return Err(AiError::Config(
|
|
"rag document id is required".to_string(),
|
|
));
|
|
}
|
|
if document.content.trim().is_empty() {
|
|
return Err(AiError::Config(
|
|
"rag document content is required".to_string(),
|
|
));
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|