267 lines
9.1 KiB
Rust
267 lines
9.1 KiB
Rust
use std::time::Duration;
|
|
|
|
use ai::{
|
|
agent::request::{
|
|
AgentContextChunk, AgentMessage, AgentRequest,
|
|
},
|
|
client::AiClient,
|
|
rag::{
|
|
RagClient, RagConfig, RagDocument,
|
|
},
|
|
};
|
|
use db::sqlx;
|
|
use uuid::Uuid;
|
|
|
|
use super::types::SessionContext;
|
|
use crate::error::AppError;
|
|
use crate::AppService;
|
|
/// Maximum number of stored conversation messages fetched as verbatim history.
const MAX_HISTORY_MESSAGES: u32 = 50;

/// Budget for the combined length of history message bodies, measured with
/// `str::len` (i.e. in bytes, not Unicode scalar values).
const MAX_HISTORY_CHARS: usize = 500 * 1_000;

/// Budget for the estimated token count of the verbatim history.
const MAX_HISTORY_ESTIMATED_TOKENS: u64 = 64 * 1_000;
|
|
|
|
impl AppService {
|
|
pub(crate) async fn agent_build_request(
|
|
&self,
|
|
ai_client: &AiClient,
|
|
ctx: &SessionContext,
|
|
conversation_id: Option<Uuid>,
|
|
input: String,
|
|
timeout_secs: Option<u64>,
|
|
) -> Result<AgentRequest, AppError> {
|
|
let mut request = AgentRequest::new(input.clone());
|
|
if let Some(secs) = timeout_secs {
|
|
request = request.with_timeout(Duration::from_secs(secs));
|
|
}
|
|
if let Some(conv_id) = conversation_id {
|
|
let mut all_messages = Vec::new();
|
|
let compacted: Option<String> = sqlx::query_scalar(
|
|
"SELECT compacted_summary FROM agent_conversation WHERE id = $1",
|
|
)
|
|
.bind(conv_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.flatten();
|
|
|
|
if let Some(summary) = compacted {
|
|
all_messages.push(AgentMessage::User(format!(
|
|
"[Previous conversation summary]\n{}\n[End of summary — messages below are the most recent verbatim exchanges]",
|
|
summary
|
|
)));
|
|
}
|
|
|
|
let messages = self
|
|
.agent_load_conversation_messages(conv_id)
|
|
.await?;
|
|
all_messages.extend(messages);
|
|
|
|
request = request.with_messages(all_messages);
|
|
}
|
|
let kb_context = self
|
|
.agent_load_knowledge_base(ai_client, ctx, &input)
|
|
.await?;
|
|
let (memories_text, _memory_rows) = self.agent_load_memories(ctx.session_id).await?;
|
|
|
|
let mut all_context = kb_context;
|
|
if !memories_text.is_empty() {
|
|
all_context.push(AgentContextChunk::new(
|
|
"long_term_memory",
|
|
memories_text,
|
|
));
|
|
}
|
|
if !all_context.is_empty() {
|
|
request = request.with_context(all_context);
|
|
}
|
|
|
|
Ok(request)
|
|
}
|
|
pub(crate) async fn agent_load_conversation_messages(
|
|
&self,
|
|
conversation_id: Uuid,
|
|
) -> Result<Vec<AgentMessage>, AppError> {
|
|
let rows: Vec<(String, String)> = sqlx::query_as(
|
|
"SELECT m.role, m.content \
|
|
FROM agent_message m \
|
|
INNER JOIN agent_conversation c ON c.id = m.conversation \
|
|
WHERE m.conversation = $1 \
|
|
AND m.deleted_at IS NULL \
|
|
AND m.status = 'completed' \
|
|
AND c.deleted_at IS NULL \
|
|
ORDER BY m.created_at ASC \
|
|
LIMIT $2",
|
|
)
|
|
.bind(conversation_id)
|
|
.bind(MAX_HISTORY_MESSAGES as i64)
|
|
.fetch_all(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
let messages: Vec<AgentMessage> = rows
|
|
.into_iter()
|
|
.map(|(role, content)| match role.as_str() {
|
|
"assistant" => AgentMessage::Assistant(content),
|
|
_ => AgentMessage::User(content),
|
|
})
|
|
.collect();
|
|
let mut total_chars: usize = messages
|
|
.iter()
|
|
.map(|m| match m {
|
|
AgentMessage::User(c) | AgentMessage::Assistant(c) => c.len(),
|
|
})
|
|
.sum();
|
|
|
|
let mut result = messages;
|
|
while total_chars > MAX_HISTORY_CHARS && !result.is_empty() {
|
|
let removed = result.remove(0);
|
|
let removed_len = match &removed {
|
|
AgentMessage::User(c) | AgentMessage::Assistant(c) => c.len(),
|
|
};
|
|
total_chars = total_chars.saturating_sub(removed_len);
|
|
tracing::debug!(
|
|
removed_chars = removed_len,
|
|
remaining_chars = total_chars,
|
|
"trimmed oldest message to fit context window"
|
|
);
|
|
}
|
|
let mut estimated_tokens: u64 = result
|
|
.iter()
|
|
.map(|m| match m {
|
|
AgentMessage::User(c) | AgentMessage::Assistant(c) => {
|
|
ai::agent::helpers::estimate_tokens(c)
|
|
}
|
|
})
|
|
.sum();
|
|
let mut trimmed_for_tokens = 0usize;
|
|
while estimated_tokens > MAX_HISTORY_ESTIMATED_TOKENS && !result.is_empty() {
|
|
let removed = result.remove(0);
|
|
estimated_tokens -= match &removed {
|
|
AgentMessage::User(c) | AgentMessage::Assistant(c) => {
|
|
ai::agent::helpers::estimate_tokens(c)
|
|
}
|
|
};
|
|
trimmed_for_tokens += 1;
|
|
}
|
|
if trimmed_for_tokens > 0 {
|
|
tracing::info!(
|
|
trimmed = trimmed_for_tokens,
|
|
estimated_tokens = estimated_tokens,
|
|
"trimmed oldest messages to stay within token budget"
|
|
);
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
pub(crate) async fn agent_load_knowledge_base(
|
|
&self,
|
|
ai_client: &AiClient,
|
|
ctx: &SessionContext,
|
|
query: &str,
|
|
) -> Result<Vec<AgentContextChunk>, AppError> {
|
|
let knowledge_base_ids: Option<String> = sqlx::query_scalar(
|
|
"SELECT knowledge_base_ids FROM agent_session WHERE id = $1",
|
|
)
|
|
.bind(ctx.session_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.flatten();
|
|
|
|
let kb_id_str = match knowledge_base_ids {
|
|
Some(ref s) if !s.trim().is_empty() => s.clone(),
|
|
_ => return Ok(Vec::new()),
|
|
};
|
|
|
|
let kb_ids: Vec<Uuid> = kb_id_str
|
|
.split(',')
|
|
.filter_map(|s| Uuid::parse_str(s.trim()).ok())
|
|
.collect();
|
|
|
|
if kb_ids.is_empty() {
|
|
return Ok(Vec::new());
|
|
}
|
|
|
|
let qdrant_url = self
|
|
.config
|
|
.qdrant_url()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
let vector_size = self
|
|
.config
|
|
.get_embed_model_dimensions()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
let rag_config = RagConfig::new(qdrant_url, "agent_knowledge", vector_size)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?
|
|
.with_api_key(
|
|
self.config
|
|
.qdrant_api_key()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?,
|
|
);
|
|
|
|
let rag = RagClient::connect(ai_client, rag_config)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
let mut all_hits: Vec<AgentContextChunk> = Vec::new();
|
|
for kb_id in &kb_ids {
|
|
let session_key = format!("kb:{kb_id}");
|
|
match rag.search_session(&session_key, query).await {
|
|
Ok(hits) => {
|
|
for hit in hits {
|
|
all_hits.push(AgentContextChunk::from(ai::rag::RagSearchHit {
|
|
id: hit.id,
|
|
session_id: hit.session_id,
|
|
score: hit.score,
|
|
content: hit.content,
|
|
metadata: hit.metadata,
|
|
}));
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
kb_id = %kb_id,
|
|
error = %e,
|
|
"agent: RAG search failed for knowledge base, skipping"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(all_hits)
|
|
}
|
|
#[allow(dead_code)]
|
|
pub(crate) async fn agent_upsert_knowledge(
|
|
&self,
|
|
ai_client: &AiClient,
|
|
kb_id: Uuid,
|
|
documents: Vec<RagDocument>,
|
|
) -> Result<(), AppError> {
|
|
let qdrant_url = self
|
|
.config
|
|
.qdrant_url()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
let vector_size = self
|
|
.config
|
|
.get_embed_model_dimensions()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
let rag_config = RagConfig::new(qdrant_url, "agent_knowledge", vector_size)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?
|
|
.with_api_key(
|
|
self.config
|
|
.qdrant_api_key()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?,
|
|
);
|
|
|
|
let rag = RagClient::connect(ai_client, rag_config)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
rag.ensure_collection()
|
|
.await
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
let session_key = format!("kb:{kb_id}");
|
|
rag.upsert_documents(&session_key, documents)
|
|
.await
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|