gitdataai/libs/agent/embed/qdrant.rs
2026-04-14 19:02:01 +08:00

313 lines
9.8 KiB
Rust

use qdrant_client::Qdrant;
use qdrant_client::qdrant::{
Condition, CreateCollectionBuilder, DeletePointsBuilder, Distance, FieldCondition, Filter,
Match, PointStruct, SearchPointsBuilder, UpsertPointsBuilder, VectorParamsBuilder, Vectors,
condition::ConditionOneOf, r#match::MatchValue, point_id::PointIdOptions, value,
};
use std::collections::HashMap;
use std::sync::Arc;
use super::client::{EmbedPayload, SearchResult};
use crate::embed::client::EmbedVector;
/// Handle to a Qdrant server used to store and search embeddings.
///
/// Cloning is cheap: all clones share the same underlying [`Qdrant`] client
/// through an [`Arc`].
#[derive(Clone)]
pub struct QdrantClient {
    inner: Arc<Qdrant>,
}
impl QdrantClient {
pub async fn new(url: &str, api_key: Option<&str>) -> crate::Result<Self> {
let mut builder = Qdrant::from_url(url);
if let Some(key) = api_key {
builder = builder.api_key(key);
}
let inner = builder
.build()
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
Ok(Self {
inner: Arc::new(inner),
})
}
fn collection_name(entity_type: &str) -> String {
format!("embed_{}", entity_type)
}
pub async fn ensure_collection(&self, entity_type: &str, dimensions: u64) -> crate::Result<()> {
let name = Self::collection_name(entity_type);
let exists = self
.inner
.collection_exists(&name)
.await
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
if exists {
return Ok(());
}
let create_collection = CreateCollectionBuilder::new(name)
.vectors_config(VectorParamsBuilder::new(dimensions, Distance::Cosine))
.build();
self.inner
.create_collection(create_collection)
.await
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
Ok(())
}
pub async fn upsert_points(&self, points: Vec<EmbedVector>) -> crate::Result<()> {
if points.is_empty() {
return Ok(());
}
let collection_name = Self::collection_name(&points[0].payload.entity_type);
let qdrant_points: Vec<PointStruct> = points
.into_iter()
.map(|p| {
let mut payload: HashMap<String, qdrant_client::qdrant::Value> = HashMap::new();
payload.insert("entity_type".to_string(), p.payload.entity_type.into());
payload.insert("entity_id".to_string(), p.payload.entity_id.into());
payload.insert("text".to_string(), p.payload.text.into());
if let Some(extra) = p.payload.extra {
let extra_str = serde_json::to_string(&extra).unwrap_or_default();
payload.insert(
"extra".to_string(),
qdrant_client::qdrant::Value {
kind: Some(
qdrant_client::qdrant::value::Kind::StringValue(extra_str),
),
},
);
}
PointStruct::new(p.id, Vectors::from(p.vector), payload)
})
.collect();
let upsert = UpsertPointsBuilder::new(collection_name, qdrant_points).build();
self.inner
.upsert_points(upsert)
.await
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
Ok(())
}
fn extract_string(value: &qdrant_client::qdrant::Value) -> String {
match &value.kind {
Some(value::Kind::StringValue(s)) => s.clone(),
_ => String::new(),
}
}
pub async fn search(
&self,
vector: &[f32],
entity_type: &str,
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
let collection_name = Self::collection_name(entity_type);
let search = SearchPointsBuilder::new(collection_name, vector.to_vec(), limit as u64)
.with_payload(true)
.build();
let results = self
.inner
.search_points(search)
.await
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
Ok(results
.result
.into_iter()
.filter_map(|p| {
let entity_type = p
.payload
.get(&"entity_type".to_string())
.map(Self::extract_string)
.unwrap_or_default();
let entity_id = p
.payload
.get(&"entity_id".to_string())
.map(Self::extract_string)
.unwrap_or_default();
let text = p
.payload
.get(&"text".to_string())
.map(Self::extract_string)
.unwrap_or_default();
let extra = p
.payload
.get(&"extra".to_string())
.and_then(|v| Some(Self::extract_string(v)))
.and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok());
let id =
p.id.and_then(|id| id.point_id_options)
.map(|opts| match opts {
PointIdOptions::Uuid(s) => s,
PointIdOptions::Num(n) => n.to_string(),
})
.unwrap_or_default();
Some(SearchResult {
id,
score: p.score,
payload: EmbedPayload {
entity_type,
entity_id,
text,
extra,
},
})
})
.collect())
}
pub async fn search_with_filter(
&self,
vector: &[f32],
entity_type: &str,
limit: usize,
filter: Filter,
) -> crate::Result<Vec<SearchResult>> {
let collection_name = Self::collection_name(entity_type);
let search = SearchPointsBuilder::new(collection_name, vector.to_vec(), limit as u64)
.with_payload(true)
.filter(filter)
.build();
let results = self
.inner
.search_points(search)
.await
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
Ok(results
.result
.into_iter()
.filter_map(|p| {
let entity_type = p
.payload
.get(&"entity_type".to_string())
.map(Self::extract_string)
.unwrap_or_default();
let entity_id = p
.payload
.get(&"entity_id".to_string())
.map(Self::extract_string)
.unwrap_or_default();
let text = p
.payload
.get(&"text".to_string())
.map(Self::extract_string)
.unwrap_or_default();
let extra = p
.payload
.get(&"extra".to_string())
.and_then(|v| Some(Self::extract_string(v)))
.and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok());
let id =
p.id.and_then(|id| id.point_id_options)
.map(|opts| match opts {
PointIdOptions::Uuid(s) => s,
PointIdOptions::Num(n) => n.to_string(),
})
.unwrap_or_default();
Some(SearchResult {
id,
score: p.score,
payload: EmbedPayload {
entity_type,
entity_id,
text,
extra,
},
})
})
.collect())
}
pub async fn delete_by_filter(&self, entity_type: &str, entity_id: &str) -> crate::Result<()> {
let collection_name = Self::collection_name(entity_type);
let filter = Filter {
must: vec![Condition {
condition_one_of: Some(ConditionOneOf::Field(FieldCondition {
key: "entity_id".to_string(),
r#match: Some(Match {
match_value: Some(MatchValue::Keyword(entity_id.to_string())),
}),
..Default::default()
})),
}],
..Default::default()
};
let delete = DeletePointsBuilder::new(collection_name)
.points(filter)
.build();
self.inner
.delete_points(delete)
.await
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
Ok(())
}
pub async fn delete_collection(&self, entity_type: &str) -> crate::Result<()> {
let name = Self::collection_name(entity_type);
self.inner
.delete_collection(name)
.await
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
Ok(())
}
pub async fn ensure_memory_collection(&self, dimensions: u64) -> crate::Result<()> {
self.ensure_collection("memory", dimensions).await
}
pub async fn ensure_skill_collection(&self, dimensions: u64) -> crate::Result<()> {
self.ensure_collection("skill", dimensions).await
}
pub async fn search_memory(
&self,
vector: &[f32],
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
self.search(vector, "memory", limit).await
}
pub async fn search_skill(
&self,
vector: &[f32],
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
self.search(vector, "skill", limit).await
}
}