313 lines
9.8 KiB
Rust
313 lines
9.8 KiB
Rust
use qdrant_client::Qdrant;
|
|
use qdrant_client::qdrant::{
|
|
Condition, CreateCollectionBuilder, DeletePointsBuilder, Distance, FieldCondition, Filter,
|
|
Match, PointStruct, SearchPointsBuilder, UpsertPointsBuilder, VectorParamsBuilder, Vectors,
|
|
condition::ConditionOneOf, r#match::MatchValue, point_id::PointIdOptions, value,
|
|
};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
|
|
use super::client::{EmbedPayload, SearchResult};
|
|
use crate::embed::client::EmbedVector;
|
|
|
|
pub struct QdrantClient {
|
|
inner: Arc<Qdrant>,
|
|
}
|
|
|
|
impl Clone for QdrantClient {
|
|
fn clone(&self) -> Self {
|
|
Self {
|
|
inner: self.inner.clone(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl QdrantClient {
|
|
pub async fn new(url: &str, api_key: Option<&str>) -> crate::Result<Self> {
|
|
let mut builder = Qdrant::from_url(url);
|
|
if let Some(key) = api_key {
|
|
builder = builder.api_key(key);
|
|
}
|
|
let inner = builder
|
|
.build()
|
|
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
|
|
Ok(Self {
|
|
inner: Arc::new(inner),
|
|
})
|
|
}
|
|
|
|
fn collection_name(entity_type: &str) -> String {
|
|
format!("embed_{}", entity_type)
|
|
}
|
|
|
|
pub async fn ensure_collection(&self, entity_type: &str, dimensions: u64) -> crate::Result<()> {
|
|
let name = Self::collection_name(entity_type);
|
|
let exists = self
|
|
.inner
|
|
.collection_exists(&name)
|
|
.await
|
|
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
|
|
|
|
if exists {
|
|
return Ok(());
|
|
}
|
|
|
|
let create_collection = CreateCollectionBuilder::new(name)
|
|
.vectors_config(VectorParamsBuilder::new(dimensions, Distance::Cosine))
|
|
.build();
|
|
|
|
self.inner
|
|
.create_collection(create_collection)
|
|
.await
|
|
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn upsert_points(&self, points: Vec<EmbedVector>) -> crate::Result<()> {
|
|
if points.is_empty() {
|
|
return Ok(());
|
|
}
|
|
|
|
let collection_name = Self::collection_name(&points[0].payload.entity_type);
|
|
|
|
let qdrant_points: Vec<PointStruct> = points
|
|
.into_iter()
|
|
.map(|p| {
|
|
let mut payload: HashMap<String, qdrant_client::qdrant::Value> = HashMap::new();
|
|
payload.insert("entity_type".to_string(), p.payload.entity_type.into());
|
|
payload.insert("entity_id".to_string(), p.payload.entity_id.into());
|
|
payload.insert("text".to_string(), p.payload.text.into());
|
|
if let Some(extra) = p.payload.extra {
|
|
let extra_str = serde_json::to_string(&extra).unwrap_or_default();
|
|
payload.insert(
|
|
"extra".to_string(),
|
|
qdrant_client::qdrant::Value {
|
|
kind: Some(
|
|
qdrant_client::qdrant::value::Kind::StringValue(extra_str),
|
|
),
|
|
},
|
|
);
|
|
}
|
|
|
|
PointStruct::new(p.id, Vectors::from(p.vector), payload)
|
|
})
|
|
.collect();
|
|
|
|
let upsert = UpsertPointsBuilder::new(collection_name, qdrant_points).build();
|
|
|
|
self.inner
|
|
.upsert_points(upsert)
|
|
.await
|
|
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn extract_string(value: &qdrant_client::qdrant::Value) -> String {
|
|
match &value.kind {
|
|
Some(value::Kind::StringValue(s)) => s.clone(),
|
|
_ => String::new(),
|
|
}
|
|
}
|
|
|
|
pub async fn search(
|
|
&self,
|
|
vector: &[f32],
|
|
entity_type: &str,
|
|
limit: usize,
|
|
) -> crate::Result<Vec<SearchResult>> {
|
|
let collection_name = Self::collection_name(entity_type);
|
|
|
|
let search = SearchPointsBuilder::new(collection_name, vector.to_vec(), limit as u64)
|
|
.with_payload(true)
|
|
.build();
|
|
|
|
let results = self
|
|
.inner
|
|
.search_points(search)
|
|
.await
|
|
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
|
|
|
|
Ok(results
|
|
.result
|
|
.into_iter()
|
|
.filter_map(|p| {
|
|
let entity_type = p
|
|
.payload
|
|
.get(&"entity_type".to_string())
|
|
.map(Self::extract_string)
|
|
.unwrap_or_default();
|
|
|
|
let entity_id = p
|
|
.payload
|
|
.get(&"entity_id".to_string())
|
|
.map(Self::extract_string)
|
|
.unwrap_or_default();
|
|
|
|
let text = p
|
|
.payload
|
|
.get(&"text".to_string())
|
|
.map(Self::extract_string)
|
|
.unwrap_or_default();
|
|
|
|
let extra = p
|
|
.payload
|
|
.get(&"extra".to_string())
|
|
.and_then(|v| Some(Self::extract_string(v)))
|
|
.and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok());
|
|
|
|
let id =
|
|
p.id.and_then(|id| id.point_id_options)
|
|
.map(|opts| match opts {
|
|
PointIdOptions::Uuid(s) => s,
|
|
PointIdOptions::Num(n) => n.to_string(),
|
|
})
|
|
.unwrap_or_default();
|
|
|
|
Some(SearchResult {
|
|
id,
|
|
score: p.score,
|
|
payload: EmbedPayload {
|
|
entity_type,
|
|
entity_id,
|
|
text,
|
|
extra,
|
|
},
|
|
})
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
pub async fn search_with_filter(
|
|
&self,
|
|
vector: &[f32],
|
|
entity_type: &str,
|
|
limit: usize,
|
|
filter: Filter,
|
|
) -> crate::Result<Vec<SearchResult>> {
|
|
let collection_name = Self::collection_name(entity_type);
|
|
|
|
let search = SearchPointsBuilder::new(collection_name, vector.to_vec(), limit as u64)
|
|
.with_payload(true)
|
|
.filter(filter)
|
|
.build();
|
|
|
|
let results = self
|
|
.inner
|
|
.search_points(search)
|
|
.await
|
|
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
|
|
|
|
Ok(results
|
|
.result
|
|
.into_iter()
|
|
.filter_map(|p| {
|
|
let entity_type = p
|
|
.payload
|
|
.get(&"entity_type".to_string())
|
|
.map(Self::extract_string)
|
|
.unwrap_or_default();
|
|
|
|
let entity_id = p
|
|
.payload
|
|
.get(&"entity_id".to_string())
|
|
.map(Self::extract_string)
|
|
.unwrap_or_default();
|
|
|
|
let text = p
|
|
.payload
|
|
.get(&"text".to_string())
|
|
.map(Self::extract_string)
|
|
.unwrap_or_default();
|
|
|
|
let extra = p
|
|
.payload
|
|
.get(&"extra".to_string())
|
|
.and_then(|v| Some(Self::extract_string(v)))
|
|
.and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok());
|
|
|
|
let id =
|
|
p.id.and_then(|id| id.point_id_options)
|
|
.map(|opts| match opts {
|
|
PointIdOptions::Uuid(s) => s,
|
|
PointIdOptions::Num(n) => n.to_string(),
|
|
})
|
|
.unwrap_or_default();
|
|
|
|
Some(SearchResult {
|
|
id,
|
|
score: p.score,
|
|
payload: EmbedPayload {
|
|
entity_type,
|
|
entity_id,
|
|
text,
|
|
extra,
|
|
},
|
|
})
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
pub async fn delete_by_filter(&self, entity_type: &str, entity_id: &str) -> crate::Result<()> {
|
|
let collection_name = Self::collection_name(entity_type);
|
|
|
|
let filter = Filter {
|
|
must: vec![Condition {
|
|
condition_one_of: Some(ConditionOneOf::Field(FieldCondition {
|
|
key: "entity_id".to_string(),
|
|
r#match: Some(Match {
|
|
match_value: Some(MatchValue::Keyword(entity_id.to_string())),
|
|
}),
|
|
..Default::default()
|
|
})),
|
|
}],
|
|
..Default::default()
|
|
};
|
|
|
|
let delete = DeletePointsBuilder::new(collection_name)
|
|
.points(filter)
|
|
.build();
|
|
|
|
self.inner
|
|
.delete_points(delete)
|
|
.await
|
|
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn delete_collection(&self, entity_type: &str) -> crate::Result<()> {
|
|
let name = Self::collection_name(entity_type);
|
|
self.inner
|
|
.delete_collection(name)
|
|
.await
|
|
.map_err(|e| crate::AgentError::Qdrant(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn ensure_memory_collection(&self, dimensions: u64) -> crate::Result<()> {
|
|
self.ensure_collection("memory", dimensions).await
|
|
}
|
|
|
|
pub async fn ensure_skill_collection(&self, dimensions: u64) -> crate::Result<()> {
|
|
self.ensure_collection("skill", dimensions).await
|
|
}
|
|
|
|
pub async fn search_memory(
|
|
&self,
|
|
vector: &[f32],
|
|
limit: usize,
|
|
) -> crate::Result<Vec<SearchResult>> {
|
|
self.search(vector, "memory", limit).await
|
|
}
|
|
|
|
pub async fn search_skill(
|
|
&self,
|
|
vector: &[f32],
|
|
limit: usize,
|
|
) -> crate::Result<Vec<SearchResult>> {
|
|
self.search(vector, "skill", limit).await
|
|
}
|
|
}
|