gitdataai/lib/service/agent/context.rs
2026-05-30 01:38:40 +08:00

267 lines
9.1 KiB
Rust

use std::time::Duration;
use ai::{
agent::request::{
AgentContextChunk, AgentMessage, AgentRequest,
},
client::AiClient,
rag::{
RagClient, RagConfig, RagDocument,
},
};
use db::sqlx;
use uuid::Uuid;
use super::types::SessionContext;
use crate::error::AppError;
use crate::AppService;
/// Maximum number of completed messages fetched from a conversation as history
/// (bound to the SQL `LIMIT` in `agent_load_conversation_messages`).
const MAX_HISTORY_MESSAGES: u32 = 50;
/// Budget for the combined content length of the loaded history. Measured with
/// `str::len`, i.e. in UTF-8 bytes rather than characters; the oldest messages are
/// dropped until the total fits.
const MAX_HISTORY_CHARS: usize = 500_000;
/// Budget for the history size as estimated by `ai::agent::helpers::estimate_tokens`;
/// the oldest messages are dropped until the estimate fits.
const MAX_HISTORY_ESTIMATED_TOKENS: u64 = 64_000;
impl AppService {
pub(crate) async fn agent_build_request(
&self,
ai_client: &AiClient,
ctx: &SessionContext,
conversation_id: Option<Uuid>,
input: String,
timeout_secs: Option<u64>,
) -> Result<AgentRequest, AppError> {
let mut request = AgentRequest::new(input.clone());
if let Some(secs) = timeout_secs {
request = request.with_timeout(Duration::from_secs(secs));
}
if let Some(conv_id) = conversation_id {
let mut all_messages = Vec::new();
let compacted: Option<String> = sqlx::query_scalar(
"SELECT compacted_summary FROM agent_conversation WHERE id = $1",
)
.bind(conv_id)
.fetch_optional(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?
.flatten();
if let Some(summary) = compacted {
all_messages.push(AgentMessage::User(format!(
"[Previous conversation summary]\n{}\n[End of summary — messages below are the most recent verbatim exchanges]",
summary
)));
}
let messages = self
.agent_load_conversation_messages(conv_id)
.await?;
all_messages.extend(messages);
request = request.with_messages(all_messages);
}
let kb_context = self
.agent_load_knowledge_base(ai_client, ctx, &input)
.await?;
let (memories_text, _memory_rows) = self.agent_load_memories(ctx.session_id).await?;
let mut all_context = kb_context;
if !memories_text.is_empty() {
all_context.push(AgentContextChunk::new(
"long_term_memory",
memories_text,
));
}
if !all_context.is_empty() {
request = request.with_context(all_context);
}
Ok(request)
}
pub(crate) async fn agent_load_conversation_messages(
&self,
conversation_id: Uuid,
) -> Result<Vec<AgentMessage>, AppError> {
let rows: Vec<(String, String)> = sqlx::query_as(
"SELECT m.role, m.content \
FROM agent_message m \
INNER JOIN agent_conversation c ON c.id = m.conversation \
WHERE m.conversation = $1 \
AND m.deleted_at IS NULL \
AND m.status = 'completed' \
AND c.deleted_at IS NULL \
ORDER BY m.created_at ASC \
LIMIT $2",
)
.bind(conversation_id)
.bind(MAX_HISTORY_MESSAGES as i64)
.fetch_all(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
let messages: Vec<AgentMessage> = rows
.into_iter()
.map(|(role, content)| match role.as_str() {
"assistant" => AgentMessage::Assistant(content),
_ => AgentMessage::User(content),
})
.collect();
let mut total_chars: usize = messages
.iter()
.map(|m| match m {
AgentMessage::User(c) | AgentMessage::Assistant(c) => c.len(),
})
.sum();
let mut result = messages;
while total_chars > MAX_HISTORY_CHARS && !result.is_empty() {
let removed = result.remove(0);
let removed_len = match &removed {
AgentMessage::User(c) | AgentMessage::Assistant(c) => c.len(),
};
total_chars = total_chars.saturating_sub(removed_len);
tracing::debug!(
removed_chars = removed_len,
remaining_chars = total_chars,
"trimmed oldest message to fit context window"
);
}
let mut estimated_tokens: u64 = result
.iter()
.map(|m| match m {
AgentMessage::User(c) | AgentMessage::Assistant(c) => {
ai::agent::helpers::estimate_tokens(c)
}
})
.sum();
let mut trimmed_for_tokens = 0usize;
while estimated_tokens > MAX_HISTORY_ESTIMATED_TOKENS && !result.is_empty() {
let removed = result.remove(0);
estimated_tokens -= match &removed {
AgentMessage::User(c) | AgentMessage::Assistant(c) => {
ai::agent::helpers::estimate_tokens(c)
}
};
trimmed_for_tokens += 1;
}
if trimmed_for_tokens > 0 {
tracing::info!(
trimmed = trimmed_for_tokens,
estimated_tokens = estimated_tokens,
"trimmed oldest messages to stay within token budget"
);
}
Ok(result)
}
pub(crate) async fn agent_load_knowledge_base(
&self,
ai_client: &AiClient,
ctx: &SessionContext,
query: &str,
) -> Result<Vec<AgentContextChunk>, AppError> {
let knowledge_base_ids: Option<String> = sqlx::query_scalar(
"SELECT knowledge_base_ids FROM agent_session WHERE id = $1",
)
.bind(ctx.session_id)
.fetch_optional(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?
.flatten();
let kb_id_str = match knowledge_base_ids {
Some(ref s) if !s.trim().is_empty() => s.clone(),
_ => return Ok(Vec::new()),
};
let kb_ids: Vec<Uuid> = kb_id_str
.split(',')
.filter_map(|s| Uuid::parse_str(s.trim()).ok())
.collect();
if kb_ids.is_empty() {
return Ok(Vec::new());
}
let qdrant_url = self
.config
.qdrant_url()
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
let vector_size = self
.config
.get_embed_model_dimensions()
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
let rag_config = RagConfig::new(qdrant_url, "agent_knowledge", vector_size)
.map_err(|e| AppError::InternalServerError(e.to_string()))?
.with_api_key(
self.config
.qdrant_api_key()
.map_err(|e| AppError::InternalServerError(e.to_string()))?,
);
let rag = RagClient::connect(ai_client, rag_config)
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
let mut all_hits: Vec<AgentContextChunk> = Vec::new();
for kb_id in &kb_ids {
let session_key = format!("kb:{kb_id}");
match rag.search_session(&session_key, query).await {
Ok(hits) => {
for hit in hits {
all_hits.push(AgentContextChunk::from(ai::rag::RagSearchHit {
id: hit.id,
session_id: hit.session_id,
score: hit.score,
content: hit.content,
metadata: hit.metadata,
}));
}
}
Err(e) => {
tracing::warn!(
kb_id = %kb_id,
error = %e,
"agent: RAG search failed for knowledge base, skipping"
);
}
}
}
Ok(all_hits)
}
#[allow(dead_code)]
pub(crate) async fn agent_upsert_knowledge(
&self,
ai_client: &AiClient,
kb_id: Uuid,
documents: Vec<RagDocument>,
) -> Result<(), AppError> {
let qdrant_url = self
.config
.qdrant_url()
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
let vector_size = self
.config
.get_embed_model_dimensions()
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
let rag_config = RagConfig::new(qdrant_url, "agent_knowledge", vector_size)
.map_err(|e| AppError::InternalServerError(e.to_string()))?
.with_api_key(
self.config
.qdrant_api_key()
.map_err(|e| AppError::InternalServerError(e.to_string()))?,
);
let rag = RagClient::connect(ai_client, rag_config)
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
rag.ensure_collection()
.await
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
let session_key = format!("kb:{kb_id}");
rag.upsert_documents(&session_key, documents)
.await
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
Ok(())
}
}