292 lines
9.1 KiB
Rust
292 lines
9.1 KiB
Rust
use rig::client::EmbeddingsClient;
|
|
use rig::embeddings::EmbeddingModel;
|
|
use rig::providers::openai::Client as OpenAiClient;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::embed::qdrant::QdrantClient;
|
|
|
|
pub struct EmbedClient {
|
|
openai: OpenAiClient,
|
|
qdrant: QdrantClient,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct EmbedVector {
|
|
pub id: String,
|
|
pub vector: Vec<f32>,
|
|
pub payload: EmbedPayload,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct EmbedPayload {
|
|
pub entity_type: String,
|
|
pub entity_id: String,
|
|
pub text: String,
|
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
pub extra: Option<serde_json::Value>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SearchResult {
|
|
pub id: String,
|
|
pub score: f32,
|
|
pub payload: EmbedPayload,
|
|
}
|
|
|
|
impl EmbedClient {
|
|
pub fn new(openai: OpenAiClient, qdrant: QdrantClient) -> Self {
|
|
Self { openai, qdrant }
|
|
}
|
|
|
|
pub async fn embed_text(&self, text: &str, model: &str) -> crate::Result<Vec<f32>> {
|
|
let model = self.openai.embedding_model(model);
|
|
let embeddings = model
|
|
.embed_texts(vec![text.to_string()])
|
|
.await
|
|
.map_err(|e| crate::AgentError::OpenAi(format!("embedding failed: {}", e)))?;
|
|
|
|
embeddings
|
|
.first()
|
|
.map(|e| e.vec.iter().map(|v| *v as f32).collect())
|
|
.ok_or_else(|| crate::AgentError::OpenAi("no embedding returned".into()))
|
|
}
|
|
|
|
pub async fn embed_batch(&self, texts: &[String], model: &str) -> crate::Result<Vec<Vec<f32>>> {
|
|
let model = self.openai.embedding_model(model);
|
|
let embeddings = model
|
|
.embed_texts(texts.to_vec())
|
|
.await
|
|
.map_err(|e| crate::AgentError::OpenAi(format!("embedding batch failed: {}", e)))?;
|
|
|
|
tracing::debug!(
|
|
input_count = texts.len(),
|
|
returned_count = embeddings.len(),
|
|
"embed_batch: API returned"
|
|
);
|
|
|
|
let mut result = vec![Vec::new(); texts.len()];
|
|
for (idx, embedding) in embeddings.into_iter().enumerate() {
|
|
if idx < result.len() {
|
|
result[idx] = embedding.vec.iter().map(|v| *v as f32).collect();
|
|
continue;
|
|
}
|
|
tracing::warn!(
|
|
idx,
|
|
"embed_batch: provider returned more embeddings than requested"
|
|
);
|
|
break;
|
|
}
|
|
|
|
// Check for empty results
|
|
let empty_count = result.iter().filter(|v| v.is_empty()).count();
|
|
if empty_count > 0 {
|
|
tracing::warn!(
|
|
empty_count = empty_count,
|
|
total = texts.len(),
|
|
"embed_batch: some embeddings returned empty vectors"
|
|
);
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
pub async fn upsert(&self, points: Vec<EmbedVector>) -> crate::Result<()> {
|
|
self.qdrant.upsert_points(points).await
|
|
}
|
|
|
|
/// Upsert points into a named collection (bypasses entity_type routing).
|
|
pub async fn upsert_to_collection(
|
|
&self,
|
|
collection_name: &str,
|
|
points: Vec<EmbedVector>,
|
|
) -> crate::Result<()> {
|
|
self.qdrant
|
|
.upsert_to_collection(collection_name, points)
|
|
.await
|
|
}
|
|
|
|
pub async fn search(
|
|
&self,
|
|
query: &str,
|
|
entity_type: &str,
|
|
model: &str,
|
|
limit: usize,
|
|
) -> crate::Result<Vec<SearchResult>> {
|
|
let vector = self.embed_text(query, model).await?;
|
|
self.qdrant.search(&vector, entity_type, limit).await
|
|
}
|
|
|
|
pub async fn search_with_filter(
|
|
&self,
|
|
query: &str,
|
|
entity_type: &str,
|
|
model: &str,
|
|
limit: usize,
|
|
filter: qdrant_client::qdrant::Filter,
|
|
) -> crate::Result<Vec<SearchResult>> {
|
|
let vector = self.embed_text(query, model).await?;
|
|
self.qdrant
|
|
.search_with_filter(&vector, entity_type, limit, filter)
|
|
.await
|
|
}
|
|
|
|
pub async fn delete_by_entity_id(
|
|
&self,
|
|
entity_type: &str,
|
|
entity_id: &str,
|
|
) -> crate::Result<()> {
|
|
self.qdrant.delete_by_filter(entity_type, entity_id).await
|
|
}
|
|
|
|
pub async fn ensure_collection(&self, entity_type: &str, dimensions: u64) -> crate::Result<()> {
|
|
self.qdrant.ensure_collection(entity_type, dimensions).await
|
|
}
|
|
|
|
pub async fn ensure_skill_collection(&self, dimensions: u64) -> crate::Result<()> {
|
|
self.qdrant.ensure_skill_collection(dimensions).await
|
|
}
|
|
|
|
/// Ensure a room-specific memory collection exists.
|
|
pub async fn ensure_room_memory_collection(
|
|
&self,
|
|
project_name: &str,
|
|
room_id: &str,
|
|
dimensions: u64,
|
|
) -> crate::Result<()> {
|
|
self.qdrant
|
|
.ensure_room_memory_collection(project_name, room_id, dimensions)
|
|
.await
|
|
}
|
|
|
|
/// Embed and store a conversation memory (message) in Qdrant.
|
|
/// Uses per-room collection: `room:{project_name}:{room_id}`.
|
|
pub async fn embed_memory(
|
|
&self,
|
|
id: &str,
|
|
text: &str,
|
|
project_name: &str,
|
|
room_id: &str,
|
|
user_id: Option<&str>,
|
|
model: &str,
|
|
) -> crate::Result<()> {
|
|
// Compute embedding first to know dimensions
|
|
let vector = self.embed_text(text, model).await?;
|
|
let collection =
|
|
crate::embed::qdrant::QdrantClient::room_memory_collection_name(project_name, room_id);
|
|
// Auto-create the room collection with correct dimensions
|
|
self.qdrant
|
|
.ensure_room_memory_collection(project_name, room_id, vector.len() as u64)
|
|
.await?;
|
|
let point = EmbedVector {
|
|
id: id.to_string(),
|
|
vector,
|
|
payload: EmbedPayload {
|
|
entity_type: "memory".to_string(),
|
|
entity_id: room_id.to_string(),
|
|
text: text.to_string(),
|
|
extra: serde_json::json!({ "user_id": user_id }).into(),
|
|
},
|
|
};
|
|
self.qdrant
|
|
.upsert_to_collection(&collection, vec![point])
|
|
.await
|
|
}
|
|
|
|
/// Search memory embeddings by semantic similarity within a room.
|
|
/// Searches the per-room collection directly — no post-filtering needed.
|
|
pub async fn search_memories(
|
|
&self,
|
|
query: &str,
|
|
model: &str,
|
|
project_name: &str,
|
|
room_id: &str,
|
|
limit: usize,
|
|
dimensions: u64,
|
|
) -> crate::Result<Vec<SearchResult>> {
|
|
let vector = self.embed_text(query, model).await?;
|
|
let collection =
|
|
crate::embed::qdrant::QdrantClient::room_memory_collection_name(project_name, room_id);
|
|
// Ensure collection exists (will be no-op if already created)
|
|
self.qdrant
|
|
.ensure_room_memory_collection(project_name, room_id, dimensions)
|
|
.await?;
|
|
self.qdrant
|
|
.search_collection(&collection, &vector, limit)
|
|
.await
|
|
}
|
|
|
|
pub async fn search_memories_after_seq(
|
|
&self,
|
|
query: &str,
|
|
model: &str,
|
|
project_name: &str,
|
|
room_id: &str,
|
|
limit: usize,
|
|
dimensions: u64,
|
|
after_seq: Option<i64>,
|
|
) -> crate::Result<Vec<SearchResult>> {
|
|
let fetch_limit = if after_seq.is_some() {
|
|
limit.saturating_mul(4).max(limit)
|
|
} else {
|
|
limit
|
|
};
|
|
let mut results = self
|
|
.search_memories(query, model, project_name, room_id, fetch_limit, dimensions)
|
|
.await?;
|
|
|
|
if let Some(cutoff) = after_seq {
|
|
results.retain(|r| {
|
|
r.payload
|
|
.extra
|
|
.as_ref()
|
|
.and_then(|v| v.get("seq"))
|
|
.and_then(|v| v.as_i64())
|
|
.map(|seq| seq > cutoff)
|
|
.unwrap_or(false)
|
|
});
|
|
}
|
|
results.truncate(limit);
|
|
Ok(results)
|
|
}
|
|
|
|
/// Embed and store a skill in Qdrant.
|
|
pub async fn embed_skill(
|
|
&self,
|
|
id: &str,
|
|
name: &str,
|
|
description: &str,
|
|
content: &str,
|
|
project_uuid: &str,
|
|
model: &str,
|
|
) -> crate::Result<()> {
|
|
let text = format!("{}: {} {}", name, description, content);
|
|
let vector = self.embed_text(&text, model).await?;
|
|
let point = EmbedVector {
|
|
id: id.to_string(),
|
|
vector,
|
|
payload: EmbedPayload {
|
|
entity_type: "skill".to_string(),
|
|
entity_id: project_uuid.to_string(),
|
|
text,
|
|
extra: serde_json::json!({ "name": name, "description": description }).into(),
|
|
},
|
|
};
|
|
self.qdrant.upsert_points(vec![point]).await
|
|
}
|
|
|
|
/// Search skill embeddings by semantic similarity within a project.
|
|
pub async fn search_skills(
|
|
&self,
|
|
query: &str,
|
|
model: &str,
|
|
project_uuid: &str,
|
|
limit: usize,
|
|
) -> crate::Result<Vec<SearchResult>> {
|
|
let vector = self.embed_text(query, model).await?;
|
|
let mut results = self.qdrant.search_skill(&vector, limit + 1).await?;
|
|
results.retain(|r| r.payload.entity_id == project_uuid);
|
|
results.truncate(limit);
|
|
Ok(results)
|
|
}
|
|
}
|