gitdataai/lib/service/agent/context.rs

459 lines
15 KiB
Rust

use std::time::Duration;
use ai::{
agent::request::{
AgentContextChunk, AgentMessage, AgentRequest,
},
client::AiClient,
rag::{
RagClient, RagConfig, RagDocument,
},
};
use db::sqlx;
use model::repos::RepoModel;
use uuid::Uuid;
use super::types::SessionContext;
use crate::error::AppError;
use crate::AppService;
const MAX_HISTORY_MESSAGES: u32 = 50;
const MAX_HISTORY_CHARS: usize = 500_000;
const MAX_HISTORY_ESTIMATED_TOKENS: u64 = 64_000;
impl AppService {
    /// Builds the [`AgentRequest`] for a single agent turn.
    ///
    /// Assembly order:
    /// 1. The request is created from `input`, with `timeout_secs` applied when given.
    /// 2. When `conversation_id` is `Some`, the conversation's compacted summary (if one
    ///    is stored) is prepended as a user message, followed by the recent history
    ///    returned by [`Self::agent_load_conversation_messages`].
    /// 3. Context chunks are gathered from the session's knowledge bases (searched with
    ///    `input`), from long-term memories, and — only when the session has a
    ///    workspace — from `@[repo:...]` mentions found in `input`.
    ///
    /// # Errors
    /// Returns [`AppError::DatabaseError`] when the summary or history query fails, and
    /// propagates errors from knowledge-base and memory loading. Failures while resolving
    /// mentioned repos are logged and otherwise ignored.
    pub(crate) async fn agent_build_request(
        &self,
        ai_client: &AiClient,
        ctx: &SessionContext,
        conversation_id: Option<Uuid>,
        input: String,
        timeout_secs: Option<u64>,
    ) -> Result<AgentRequest, AppError> {
        // `input` is still needed below as the RAG query and for mention parsing.
        let mut request = AgentRequest::new(input.clone());
        if let Some(secs) = timeout_secs {
            request = request.with_timeout(Duration::from_secs(secs));
        }

        if let Some(conv_id) = conversation_id {
            let mut all_messages = Vec::new();
            // Decoded as a nullable scalar: a missing conversation row and a NULL summary
            // both end up as `None` after `flatten`.
            let compacted: Option<String> = sqlx::query_scalar(
                "SELECT compacted_summary FROM agent_conversation WHERE id = $1",
            )
            .bind(conv_id)
            .fetch_optional(self.db.reader())
            .await
            .map_err(|e| AppError::DatabaseError(e.to_string()))?
            .flatten();
            if let Some(summary) = compacted {
                all_messages.push(AgentMessage::User(format!(
                    "[Previous conversation summary]\n{}\n[End of summary — messages below are the most recent verbatim exchanges]",
                    summary
                )));
            }
            let messages = self.agent_load_conversation_messages(conv_id).await?;
            all_messages.extend(messages);
            request = request.with_messages(all_messages);
        }

        let mut all_context = self
            .agent_load_knowledge_base(ai_client, ctx, &input)
            .await?;
        let (memories_text, _memory_rows) = self.agent_load_memories(ctx.session_id).await?;
        if !memories_text.is_empty() {
            all_context.push(AgentContextChunk::new("long_term_memory", memories_text));
        }

        // Inject repo context from @[repo:...] mentions in the input. Best-effort: a
        // lookup failure must not fail the whole request.
        if let Some(workspace_id) = ctx.workspace_id {
            let repo_context = self
                .agent_resolve_mentioned_repos(workspace_id, &input)
                .await
                .unwrap_or_else(|e| {
                    tracing::warn!(error = %e, "failed to resolve mentioned repos");
                    Vec::new()
                });
            all_context.extend(repo_context);
        }

        if !all_context.is_empty() {
            request = request.with_context(all_context);
        }
        Ok(request)
    }

    /// Loads the most recent completed messages of a conversation, oldest first.
    ///
    /// At most [`MAX_HISTORY_MESSAGES`] messages are fetched. These are the *newest*
    /// ones, so a long conversation keeps its latest exchanges rather than its opening
    /// turns. The result is then trimmed from the front (oldest first) until it fits
    /// [`MAX_HISTORY_CHARS`] (measured as UTF-8 byte length) and then
    /// [`MAX_HISTORY_ESTIMATED_TOKENS`].
    ///
    /// A row whose role is `"assistant"` becomes [`AgentMessage::Assistant`]; every other
    /// role becomes [`AgentMessage::User`]. Soft-deleted messages, non-`completed`
    /// messages, and messages of a soft-deleted conversation are excluded.
    ///
    /// # Errors
    /// Returns [`AppError::DatabaseError`] if the query fails.
    pub(crate) async fn agent_load_conversation_messages(
        &self,
        conversation_id: Uuid,
    ) -> Result<Vec<AgentMessage>, AppError> {
        /// Text body of a history message, regardless of its role.
        fn text_of(message: &AgentMessage) -> &str {
            match message {
                AgentMessage::User(text) | AgentMessage::Assistant(text) => text,
            }
        }

        // The inner query picks the newest N messages; the outer query restores
        // chronological order for the model.
        let rows: Vec<(String, String)> = sqlx::query_as(
            "SELECT recent.role, recent.content \
             FROM ( \
                 SELECT m.role, m.content, m.created_at \
                 FROM agent_message m \
                 INNER JOIN agent_conversation c ON c.id = m.conversation \
                 WHERE m.conversation = $1 \
                   AND m.deleted_at IS NULL \
                   AND m.status = 'completed' \
                   AND c.deleted_at IS NULL \
                 ORDER BY m.created_at DESC \
                 LIMIT $2 \
             ) AS recent \
             ORDER BY recent.created_at ASC",
        )
        .bind(conversation_id)
        .bind(i64::from(MAX_HISTORY_MESSAGES))
        .fetch_all(self.db.reader())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;

        let mut messages: Vec<AgentMessage> = rows
            .into_iter()
            .map(|(role, content)| match role.as_str() {
                "assistant" => AgentMessage::Assistant(content),
                _ => AgentMessage::User(content),
            })
            .collect();

        // Number of leading (oldest) messages to drop. Dropped once at the end with
        // `drain` instead of repeated `remove(0)`, which would shift the vector each time.
        let mut drop_count = 0usize;

        // Pass 1: size budget.
        let mut total_chars: usize = messages.iter().map(|m| text_of(m).len()).sum();
        while total_chars > MAX_HISTORY_CHARS && drop_count < messages.len() {
            let removed_len = text_of(&messages[drop_count]).len();
            total_chars = total_chars.saturating_sub(removed_len);
            drop_count += 1;
            tracing::debug!(
                removed_chars = removed_len,
                remaining_chars = total_chars,
                "trimmed oldest message to fit context window"
            );
        }

        // Pass 2: estimated token budget over whatever survived pass 1.
        let mut estimated_tokens: u64 = messages[drop_count..]
            .iter()
            .map(|m| ai::agent::helpers::estimate_tokens(text_of(m)))
            .sum();
        let mut trimmed_for_tokens = 0usize;
        while estimated_tokens > MAX_HISTORY_ESTIMATED_TOKENS && drop_count < messages.len() {
            let removed_tokens = ai::agent::helpers::estimate_tokens(text_of(&messages[drop_count]));
            estimated_tokens = estimated_tokens.saturating_sub(removed_tokens);
            drop_count += 1;
            trimmed_for_tokens += 1;
        }
        if trimmed_for_tokens > 0 {
            tracing::info!(
                trimmed = trimmed_for_tokens,
                estimated_tokens = estimated_tokens,
                "trimmed oldest messages to stay within token budget"
            );
        }

        messages.drain(..drop_count);
        Ok(messages)
    }

    /// Searches every knowledge base attached to the session for chunks relevant to
    /// `query`.
    ///
    /// Knowledge-base ids are read from the comma-separated
    /// `agent_session.knowledge_base_ids` column. Entries that are not valid UUIDs are
    /// ignored, and duplicate ids are searched only once (first occurrence wins). Each
    /// knowledge base is searched in the `agent_knowledge` RAG collection under the
    /// session key `kb:{id}`. A failed search for one knowledge base is logged and skipped.
    ///
    /// # Errors
    /// Returns [`AppError::DatabaseError`] if the session lookup fails and
    /// [`AppError::InternalServerError`] if the RAG client cannot be configured or
    /// connected.
    pub(crate) async fn agent_load_knowledge_base(
        &self,
        ai_client: &AiClient,
        ctx: &SessionContext,
        query: &str,
    ) -> Result<Vec<AgentContextChunk>, AppError> {
        let knowledge_base_ids: Option<String> = sqlx::query_scalar(
            "SELECT knowledge_base_ids FROM agent_session WHERE id = $1",
        )
        .bind(ctx.session_id)
        .fetch_optional(self.db.reader())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?
        .flatten();

        // A missing, NULL, or blank column parses to no ids and returns early below.
        let mut seen = std::collections::HashSet::new();
        let kb_ids: Vec<Uuid> = knowledge_base_ids
            .as_deref()
            .unwrap_or("")
            .split(',')
            .filter_map(|s| Uuid::parse_str(s.trim()).ok())
            .filter(|id| seen.insert(*id))
            .collect();
        if kb_ids.is_empty() {
            return Ok(Vec::new());
        }

        let rag = RagClient::connect(ai_client, self.agent_knowledge_rag_config()?)
            .map_err(|e| AppError::InternalServerError(e.to_string()))?;

        let mut all_hits: Vec<AgentContextChunk> = Vec::new();
        for kb_id in &kb_ids {
            let session_key = format!("kb:{kb_id}");
            match rag.search_session(&session_key, query).await {
                Ok(hits) => {
                    // NOTE(review): if `search_session` already yields `RagSearchHit`,
                    // this re-packing can become `AgentContextChunk::from(hit)`.
                    all_hits.extend(hits.into_iter().map(|hit| {
                        AgentContextChunk::from(ai::rag::RagSearchHit {
                            id: hit.id,
                            session_id: hit.session_id,
                            score: hit.score,
                            content: hit.content,
                            metadata: hit.metadata,
                        })
                    }));
                }
                Err(e) => {
                    tracing::warn!(
                        kb_id = %kb_id,
                        error = %e,
                        "agent: RAG search failed for knowledge base, skipping"
                    );
                }
            }
        }
        Ok(all_hits)
    }

    /// Parses `@[repo:name:label]` mentions from the input and resolves them to
    /// [`AgentContextChunk`]s carrying repo metadata from the database.
    ///
    /// Each chunk is keyed `repo:{name}`. Repos that are not found, or whose lookup
    /// fails, are logged and skipped. As written, this function therefore always
    /// returns `Ok`.
    pub(crate) async fn agent_resolve_mentioned_repos(
        &self,
        workspace_id: Uuid,
        input: &str,
    ) -> Result<Vec<AgentContextChunk>, AppError> {
        let repo_names = extract_repo_mentions(input);
        if repo_names.is_empty() {
            return Ok(Vec::new());
        }
        let mut chunks = Vec::with_capacity(repo_names.len());
        for name in &repo_names {
            match self.resolve_repo_by_name(workspace_id, name).await {
                Ok(Some(repo)) => {
                    let content = format_repo_context(&repo);
                    chunks.push(AgentContextChunk::new(
                        format!("repo:{}", repo.name),
                        content,
                    ));
                    tracing::info!(
                        repo = %repo.name,
                        "injected repo context from @mention"
                    );
                }
                Ok(None) => {
                    tracing::debug!(
                        repo_name = %name,
                        "mentioned repo not found, skipping"
                    );
                }
                Err(e) => {
                    tracing::warn!(
                        repo_name = %name,
                        error = %e,
                        "failed to look up mentioned repo"
                    );
                }
            }
        }
        Ok(chunks)
    }

    /// Looks up a single non-deleted repo by workspace id and exact name.
    ///
    /// Returns `Ok(None)` when no matching row exists.
    ///
    /// # Errors
    /// Returns [`AppError::DatabaseError`] if the query fails.
    async fn resolve_repo_by_name(
        &self,
        workspace_id: Uuid,
        name: &str,
    ) -> Result<Option<RepoModel>, AppError> {
        let repo = sqlx::query_as::<_, RepoModel>(
            "SELECT id, wk, name, description, default_branch, visibility, \
             size_bytes, is_archived, is_template, is_mirror, created_by, \
             storage_path, created_at, updated_at, deleted_at \
             FROM repo WHERE wk = $1 AND name = $2 AND deleted_at IS NULL",
        )
        .bind(workspace_id)
        .bind(name)
        .fetch_optional(self.db.reader())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        Ok(repo)
    }

    /// Builds the RAG configuration for the `agent_knowledge` collection from the
    /// service's Qdrant URL, embedding-model dimensions, and Qdrant API key.
    ///
    /// # Errors
    /// Returns [`AppError::InternalServerError`] if any config value is unavailable or
    /// the configuration is rejected.
    fn agent_knowledge_rag_config(&self) -> Result<RagConfig, AppError> {
        let qdrant_url = self
            .config
            .qdrant_url()
            .map_err(|e| AppError::InternalServerError(e.to_string()))?;
        let vector_size = self
            .config
            .get_embed_model_dimensions()
            .map_err(|e| AppError::InternalServerError(e.to_string()))?;
        let rag_config = RagConfig::new(qdrant_url, "agent_knowledge", vector_size)
            .map_err(|e| AppError::InternalServerError(e.to_string()))?
            .with_api_key(
                self.config
                    .qdrant_api_key()
                    .map_err(|e| AppError::InternalServerError(e.to_string()))?,
            );
        Ok(rag_config)
    }

    /// Ensures the `agent_knowledge` collection exists, then upserts `documents` into
    /// knowledge base `kb_id` (session key `kb:{kb_id}`).
    ///
    /// # Errors
    /// Returns [`AppError::InternalServerError`] if configuration, connection, collection
    /// setup, or the upsert fails.
    // No caller is visible in this file; presumably reserved for knowledge-base
    // ingestion — NOTE(review): confirm and drop the allow once it is wired up.
    #[allow(dead_code)]
    pub(crate) async fn agent_upsert_knowledge(
        &self,
        ai_client: &AiClient,
        kb_id: Uuid,
        documents: Vec<RagDocument>,
    ) -> Result<(), AppError> {
        let rag = RagClient::connect(ai_client, self.agent_knowledge_rag_config()?)
            .map_err(|e| AppError::InternalServerError(e.to_string()))?;
        rag.ensure_collection()
            .await
            .map_err(|e| AppError::InternalServerError(e.to_string()))?;
        let session_key = format!("kb:{kb_id}");
        rag.upsert_documents(&session_key, documents)
            .await
            .map_err(|e| AppError::InternalServerError(e.to_string()))?;
        Ok(())
    }
}
// ---------------------------------------------------------------------------
// Repo mention helpers
// ---------------------------------------------------------------------------
/// Extract unique repo names from `@[repo:name:label]` mentions in the input.
/// Extracts unique repo names from `@[repo:name:label]` mentions in `input`.
///
/// The name is the text between `@[repo:` and the first `:` *inside the mention*, or
/// the whole mention body when there is no label (`@[repo:name]`). Only text up to the
/// mention's closing `]` is considered, so colons later in the message (e.g. `10:30`)
/// never leak into a name. Mentions without a closing `]` and mentions with an empty
/// name are ignored. Names are returned in first-seen order, without duplicates.
fn extract_repo_mentions(input: &str) -> Vec<String> {
    const PREFIX: &str = "@[repo:";

    let mut names = Vec::new();
    let mut seen = std::collections::HashSet::new();
    let mut rest = input;

    while let Some(pos) = rest.find(PREFIX) {
        let after_prefix = &rest[pos + PREFIX.len()..];
        // No closing bracket anywhere after this point means no further complete
        // mentions can exist either.
        let close = match after_prefix.find(']') {
            Some(close) => close,
            None => break,
        };
        let body = &after_prefix[..close];
        // `split` always yields at least one item, so the fallback is never used.
        let name = body.split(':').next().unwrap_or(body);
        if !name.is_empty() && seen.insert(name) {
            names.push(name.to_string());
        }
        rest = &after_prefix[close + 1..];
    }
    names
}
/// Format a RepoModel into a concise context string for the AI.
fn format_repo_context(repo: &RepoModel) -> String {
let mut s = format!(
"Repository: {} (id: {})\n",
repo.name, repo.id
);
if let Some(ref desc) = repo.description {
if !desc.trim().is_empty() {
s.push_str(&format!("Description: {}\n", desc.trim()));
}
}
s.push_str(&format!("Default branch: {}\n", repo.default_branch));
s.push_str(&format!("Visibility: {}\n", repo.visibility));
if repo.is_archived {
s.push_str("Status: archived\n");
}
if repo.is_template {
s.push_str("Kind: template repository\n");
}
s.push_str(&format!(
"Created: {}\n",
repo.created_at.format("%Y-%m-%d")
));
s
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand so each case reads as `input → expected names`.
    fn mentions(input: &str) -> Vec<String> {
        extract_repo_mentions(input)
    }

    #[test]
    fn test_extract_repo_mentions_single() {
        assert_eq!(
            mentions("@[repo:my-repo:my-repo] what's the latest commit?"),
            ["my-repo"]
        );
    }

    #[test]
    fn test_extract_repo_mentions_multiple() {
        assert_eq!(
            mentions("compare @[repo:backend:backend] with @[repo:frontend:frontend]"),
            ["backend", "frontend"]
        );
    }

    #[test]
    fn test_extract_repo_mentions_dedupe() {
        assert_eq!(mentions("look at @[repo:a:a] and also @[repo:a:a] please"), ["a"]);
    }

    #[test]
    fn test_extract_repo_mentions_none() {
        assert_eq!(mentions("hello world no mentions here"), Vec::<String>::new());
    }

    #[test]
    fn test_extract_ignores_other_mention_types() {
        assert_eq!(mentions("@[user:abc:John] and @[repo:myrepo:myrepo]"), ["myrepo"]);
    }
}