459 lines
15 KiB
Rust
459 lines
15 KiB
Rust
use std::time::Duration;
|
|
|
|
use ai::{
|
|
agent::request::{
|
|
AgentContextChunk, AgentMessage, AgentRequest,
|
|
},
|
|
client::AiClient,
|
|
rag::{
|
|
RagClient, RagConfig, RagDocument,
|
|
},
|
|
};
|
|
use db::sqlx;
|
|
use model::repos::RepoModel;
|
|
use uuid::Uuid;
|
|
|
|
use super::types::SessionContext;
|
|
use crate::error::AppError;
|
|
use crate::AppService;
|
|
/// Maximum number of completed messages loaded from a conversation's history
/// before the character and token budgets below are applied.
const MAX_HISTORY_MESSAGES: u32 = 50;

/// Upper bound on the combined byte length (`str::len`) of replayed history
/// messages; the oldest messages are dropped until the total fits.
const MAX_HISTORY_CHARS: usize = 500_000;

/// Upper bound on the estimated token count of replayed history messages, as
/// measured by `ai::agent::helpers::estimate_tokens`; the oldest messages are
/// dropped until the estimate fits.
const MAX_HISTORY_ESTIMATED_TOKENS: u64 = 64_000;
|
|
|
|
impl AppService {
|
|
pub(crate) async fn agent_build_request(
|
|
&self,
|
|
ai_client: &AiClient,
|
|
ctx: &SessionContext,
|
|
conversation_id: Option<Uuid>,
|
|
input: String,
|
|
timeout_secs: Option<u64>,
|
|
) -> Result<AgentRequest, AppError> {
|
|
let mut request = AgentRequest::new(input.clone());
|
|
if let Some(secs) = timeout_secs {
|
|
request = request.with_timeout(Duration::from_secs(secs));
|
|
}
|
|
if let Some(conv_id) = conversation_id {
|
|
let mut all_messages = Vec::new();
|
|
let compacted: Option<String> = sqlx::query_scalar(
|
|
"SELECT compacted_summary FROM agent_conversation WHERE id = $1",
|
|
)
|
|
.bind(conv_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.flatten();
|
|
|
|
if let Some(summary) = compacted {
|
|
all_messages.push(AgentMessage::User(format!(
|
|
"[Previous conversation summary]\n{}\n[End of summary — messages below are the most recent verbatim exchanges]",
|
|
summary
|
|
)));
|
|
}
|
|
|
|
let messages = self
|
|
.agent_load_conversation_messages(conv_id)
|
|
.await?;
|
|
all_messages.extend(messages);
|
|
|
|
request = request.with_messages(all_messages);
|
|
}
|
|
let kb_context = self
|
|
.agent_load_knowledge_base(ai_client, ctx, &input)
|
|
.await?;
|
|
let (memories_text, _memory_rows) = self.agent_load_memories(ctx.session_id).await?;
|
|
|
|
let mut all_context = kb_context;
|
|
if !memories_text.is_empty() {
|
|
all_context.push(AgentContextChunk::new(
|
|
"long_term_memory",
|
|
memories_text,
|
|
));
|
|
}
|
|
|
|
// Inject repo context from @[repo:...] mentions in the input.
|
|
if let Some(workspace_id) = ctx.workspace_id {
|
|
let repo_context = self
|
|
.agent_resolve_mentioned_repos(workspace_id, &input)
|
|
.await
|
|
.unwrap_or_else(|e| {
|
|
tracing::warn!(error = %e, "failed to resolve mentioned repos");
|
|
Vec::new()
|
|
});
|
|
all_context.extend(repo_context);
|
|
}
|
|
|
|
if !all_context.is_empty() {
|
|
request = request.with_context(all_context);
|
|
}
|
|
|
|
Ok(request)
|
|
}
|
|
pub(crate) async fn agent_load_conversation_messages(
|
|
&self,
|
|
conversation_id: Uuid,
|
|
) -> Result<Vec<AgentMessage>, AppError> {
|
|
let rows: Vec<(String, String)> = sqlx::query_as(
|
|
"SELECT m.role, m.content \
|
|
FROM agent_message m \
|
|
INNER JOIN agent_conversation c ON c.id = m.conversation \
|
|
WHERE m.conversation = $1 \
|
|
AND m.deleted_at IS NULL \
|
|
AND m.status = 'completed' \
|
|
AND c.deleted_at IS NULL \
|
|
ORDER BY m.created_at ASC \
|
|
LIMIT $2",
|
|
)
|
|
.bind(conversation_id)
|
|
.bind(MAX_HISTORY_MESSAGES as i64)
|
|
.fetch_all(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
let messages: Vec<AgentMessage> = rows
|
|
.into_iter()
|
|
.map(|(role, content)| match role.as_str() {
|
|
"assistant" => AgentMessage::Assistant(content),
|
|
_ => AgentMessage::User(content),
|
|
})
|
|
.collect();
|
|
let mut total_chars: usize = messages
|
|
.iter()
|
|
.map(|m| match m {
|
|
AgentMessage::User(c) | AgentMessage::Assistant(c) => c.len(),
|
|
})
|
|
.sum();
|
|
|
|
let mut result = messages;
|
|
while total_chars > MAX_HISTORY_CHARS && !result.is_empty() {
|
|
let removed = result.remove(0);
|
|
let removed_len = match &removed {
|
|
AgentMessage::User(c) | AgentMessage::Assistant(c) => c.len(),
|
|
};
|
|
total_chars = total_chars.saturating_sub(removed_len);
|
|
tracing::debug!(
|
|
removed_chars = removed_len,
|
|
remaining_chars = total_chars,
|
|
"trimmed oldest message to fit context window"
|
|
);
|
|
}
|
|
let mut estimated_tokens: u64 = result
|
|
.iter()
|
|
.map(|m| match m {
|
|
AgentMessage::User(c) | AgentMessage::Assistant(c) => {
|
|
ai::agent::helpers::estimate_tokens(c)
|
|
}
|
|
})
|
|
.sum();
|
|
let mut trimmed_for_tokens = 0usize;
|
|
while estimated_tokens > MAX_HISTORY_ESTIMATED_TOKENS && !result.is_empty() {
|
|
let removed = result.remove(0);
|
|
estimated_tokens -= match &removed {
|
|
AgentMessage::User(c) | AgentMessage::Assistant(c) => {
|
|
ai::agent::helpers::estimate_tokens(c)
|
|
}
|
|
};
|
|
trimmed_for_tokens += 1;
|
|
}
|
|
if trimmed_for_tokens > 0 {
|
|
tracing::info!(
|
|
trimmed = trimmed_for_tokens,
|
|
estimated_tokens = estimated_tokens,
|
|
"trimmed oldest messages to stay within token budget"
|
|
);
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
pub(crate) async fn agent_load_knowledge_base(
|
|
&self,
|
|
ai_client: &AiClient,
|
|
ctx: &SessionContext,
|
|
query: &str,
|
|
) -> Result<Vec<AgentContextChunk>, AppError> {
|
|
let knowledge_base_ids: Option<String> = sqlx::query_scalar(
|
|
"SELECT knowledge_base_ids FROM agent_session WHERE id = $1",
|
|
)
|
|
.bind(ctx.session_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.flatten();
|
|
|
|
let kb_id_str = match knowledge_base_ids {
|
|
Some(ref s) if !s.trim().is_empty() => s.clone(),
|
|
_ => return Ok(Vec::new()),
|
|
};
|
|
|
|
let kb_ids: Vec<Uuid> = kb_id_str
|
|
.split(',')
|
|
.filter_map(|s| Uuid::parse_str(s.trim()).ok())
|
|
.collect();
|
|
|
|
if kb_ids.is_empty() {
|
|
return Ok(Vec::new());
|
|
}
|
|
|
|
let qdrant_url = self
|
|
.config
|
|
.qdrant_url()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
let vector_size = self
|
|
.config
|
|
.get_embed_model_dimensions()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
let rag_config = RagConfig::new(qdrant_url, "agent_knowledge", vector_size)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?
|
|
.with_api_key(
|
|
self.config
|
|
.qdrant_api_key()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?,
|
|
);
|
|
|
|
let rag = RagClient::connect(ai_client, rag_config)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
let mut all_hits: Vec<AgentContextChunk> = Vec::new();
|
|
for kb_id in &kb_ids {
|
|
let session_key = format!("kb:{kb_id}");
|
|
match rag.search_session(&session_key, query).await {
|
|
Ok(hits) => {
|
|
for hit in hits {
|
|
all_hits.push(AgentContextChunk::from(ai::rag::RagSearchHit {
|
|
id: hit.id,
|
|
session_id: hit.session_id,
|
|
score: hit.score,
|
|
content: hit.content,
|
|
metadata: hit.metadata,
|
|
}));
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
kb_id = %kb_id,
|
|
error = %e,
|
|
"agent: RAG search failed for knowledge base, skipping"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(all_hits)
|
|
}
|
|
/// Parse `@[repo:name:label]` mentions from the input and resolve them to
|
|
/// AgentContextChunks with repo metadata from the database.
|
|
pub(crate) async fn agent_resolve_mentioned_repos(
|
|
&self,
|
|
workspace_id: Uuid,
|
|
input: &str,
|
|
) -> Result<Vec<AgentContextChunk>, AppError> {
|
|
let repo_names = extract_repo_mentions(input);
|
|
if repo_names.is_empty() {
|
|
return Ok(Vec::new());
|
|
}
|
|
|
|
let mut chunks = Vec::with_capacity(repo_names.len());
|
|
for name in &repo_names {
|
|
match self.resolve_repo_by_name(workspace_id, name).await {
|
|
Ok(Some(repo)) => {
|
|
let content = format_repo_context(&repo);
|
|
chunks.push(AgentContextChunk::new(
|
|
format!("repo:{}", repo.name),
|
|
content,
|
|
));
|
|
tracing::info!(
|
|
repo = %repo.name,
|
|
"injected repo context from @mention"
|
|
);
|
|
}
|
|
Ok(None) => {
|
|
tracing::debug!(
|
|
repo_name = %name,
|
|
"mentioned repo not found, skipping"
|
|
);
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
repo_name = %name,
|
|
error = %e,
|
|
"failed to look up mentioned repo"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(chunks)
|
|
}
|
|
|
|
/// Look up a single repo by workspace + name.
|
|
async fn resolve_repo_by_name(
|
|
&self,
|
|
workspace_id: Uuid,
|
|
name: &str,
|
|
) -> Result<Option<RepoModel>, AppError> {
|
|
let repo = sqlx::query_as::<_, RepoModel>(
|
|
"SELECT id, wk, name, description, default_branch, visibility, \
|
|
size_bytes, is_archived, is_template, is_mirror, created_by, \
|
|
storage_path, created_at, updated_at, deleted_at \
|
|
FROM repo WHERE wk = $1 AND name = $2 AND deleted_at IS NULL",
|
|
)
|
|
.bind(workspace_id)
|
|
.bind(name)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(repo)
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub(crate) async fn agent_upsert_knowledge(
|
|
&self,
|
|
ai_client: &AiClient,
|
|
kb_id: Uuid,
|
|
documents: Vec<RagDocument>,
|
|
) -> Result<(), AppError> {
|
|
let qdrant_url = self
|
|
.config
|
|
.qdrant_url()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
let vector_size = self
|
|
.config
|
|
.get_embed_model_dimensions()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
let rag_config = RagConfig::new(qdrant_url, "agent_knowledge", vector_size)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?
|
|
.with_api_key(
|
|
self.config
|
|
.qdrant_api_key()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?,
|
|
);
|
|
|
|
let rag = RagClient::connect(ai_client, rag_config)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
rag.ensure_collection()
|
|
.await
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
let session_key = format!("kb:{kb_id}");
|
|
rag.upsert_documents(&session_key, documents)
|
|
.await
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Repo mention helpers
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Extract unique repo names from `@[repo:name:label]` mentions in the input.
///
/// Returns names in first-seen order, with duplicates removed.
///
/// Parsing rules:
/// - The name is the text between `@[repo:` and the first `:` *inside* the
///   brackets, trimmed of surrounding whitespace.
/// - A mention without a label (`@[repo:name]`) yields the whole bracket body.
/// - Mentions with no closing `]` are ignored.
/// - Empty names are skipped.
/// - Other mention kinds (e.g. `@[user:...]`) are not matched.
fn extract_repo_mentions(input: &str) -> Vec<String> {
    const PREFIX: &str = "@[repo:";

    let mut names = Vec::new();
    let mut seen = std::collections::HashSet::new();
    let mut rest = input;

    while let Some(pos) = rest.find(PREFIX) {
        let after_prefix = &rest[pos + PREFIX.len()..];
        // Bound the mention by its closing bracket first, so the name can never
        // run into surrounding text (e.g. "@[repo:foo] note: x" must give "foo").
        let close = match after_prefix.find(']') {
            Some(close) => close,
            // No later ']' at all, so no further mention can be complete either.
            None => break,
        };
        let body = &after_prefix[..close];
        let name = body.split_once(':').map_or(body, |(name, _label)| name).trim();
        if !name.is_empty() && seen.insert(name) {
            names.push(name.to_string());
        }
        rest = &after_prefix[close + 1..];
    }

    names
}
|
|
|
|
/// Format a RepoModel into a concise context string for the AI.
|
|
fn format_repo_context(repo: &RepoModel) -> String {
|
|
let mut s = format!(
|
|
"Repository: {} (id: {})\n",
|
|
repo.name, repo.id
|
|
);
|
|
if let Some(ref desc) = repo.description {
|
|
if !desc.trim().is_empty() {
|
|
s.push_str(&format!("Description: {}\n", desc.trim()));
|
|
}
|
|
}
|
|
s.push_str(&format!("Default branch: {}\n", repo.default_branch));
|
|
s.push_str(&format!("Visibility: {}\n", repo.visibility));
|
|
if repo.is_archived {
|
|
s.push_str("Status: archived\n");
|
|
}
|
|
if repo.is_template {
|
|
s.push_str("Kind: template repository\n");
|
|
}
|
|
s.push_str(&format!(
|
|
"Created: {}\n",
|
|
repo.created_at.format("%Y-%m-%d")
|
|
));
|
|
s
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Thin alias for the parser under test, so each case reads as one line.
    fn mentions(input: &str) -> Vec<String> {
        extract_repo_mentions(input)
    }

    #[test]
    fn test_extract_repo_mentions_single() {
        assert_eq!(
            mentions("@[repo:my-repo:my-repo] what's the latest commit?"),
            ["my-repo"]
        );
    }

    #[test]
    fn test_extract_repo_mentions_multiple() {
        assert_eq!(
            mentions("compare @[repo:backend:backend] with @[repo:frontend:frontend]"),
            ["backend", "frontend"]
        );
    }

    #[test]
    fn test_extract_repo_mentions_dedupe() {
        assert_eq!(
            mentions("look at @[repo:a:a] and also @[repo:a:a] please"),
            ["a"]
        );
    }

    #[test]
    fn test_extract_repo_mentions_none() {
        assert_eq!(
            mentions("hello world no mentions here"),
            Vec::<String>::new()
        );
    }

    #[test]
    fn test_extract_ignores_other_mention_types() {
        assert_eq!(
            mentions("@[user:abc:John] and @[repo:myrepo:myrepo]"),
            ["myrepo"]
        );
    }
}
|