gitdataai/libs/agent/compact/service.rs

328 lines
11 KiB
Rust

use chrono::Utc;
use models::ColumnTrait;
use models::rooms::room_message::{
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
};
use models::users::user::{Column as UserCol, Entity as User};
use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
use uuid::Uuid;
use crate::client::types::ChatRequestMessage;
use crate::client::AiClientConfig;
use crate::client::call_with_params;
use crate::AgentError;
use crate::compact::types::{CompactLevel, CompactSummary, MessageSummary};
use crate::tokent::{TokenUsage, resolve_usage};
/// Compacts long room/session histories into an LLM-written summary of the older
/// messages plus a verbatim copy of the most recent ones.
#[derive(Clone)]
pub struct CompactService {
    /// Connection used to load messages, users and room membership/access rows.
    db: DatabaseConnection,
    /// Client configuration forwarded to `call_with_params`.
    ai_client_config: AiClientConfig,
    /// Model name used for the summarisation call and for token accounting.
    model: String,
}

impl CompactService {
    /// Maximum number of messages loaded for a single compaction.
    ///
    /// Queries fetch the *newest* messages up to this limit, so an over-long history
    /// loses its oldest messages rather than its most recent ones.
    const MAX_FETCHED_MESSAGES: u64 = 10_000;

    /// Creates a service that reads from `db` and summarises with `model`.
    pub fn new(db: DatabaseConnection, ai_client_config: AiClientConfig, model: String) -> Self {
        Self { db, ai_client_config, model }
    }

    /// Compacts the history of `room_id` on behalf of `requester_id`.
    ///
    /// Only messages the requester may see are loaded. The requester must have a
    /// `room_user_state` row or a `room_access` row for the room. The newest
    /// `level.retain_count()` messages are kept verbatim. Older messages are summarised
    /// within a budget of `context_window_tokens * compaction_max_summary_ratio` tokens.
    ///
    /// `user_names` optionally supplies display names keyed by user id. When it is
    /// `None`, names are loaded from the user table.
    ///
    /// # Errors
    ///
    /// Returns `AgentError::Internal` in three cases:
    /// - the room does not exist;
    /// - the requester can see no messages (no access, or the room is empty);
    /// - a database query fails.
    ///
    /// Returns `AgentError::OpenAi` when the summarisation call fails.
    pub async fn compact_room(
        &self,
        room_id: Uuid,
        level: CompactLevel,
        user_names: Option<std::collections::HashMap<Uuid, String>>,
        requester_id: Uuid,
        context_window_tokens: i32,
        compaction_max_summary_ratio: f32,
    ) -> Result<CompactSummary, AgentError> {
        // Access control is enforced inside the query itself. An unauthorized
        // requester gets an empty result instead of the room's messages.
        let messages = self.fetch_room_messages_secure(room_id, requester_id).await?;
        if messages.is_empty() {
            // Distinguish "no such room" from "nothing visible" for the error message.
            // NOTE(review): this reveals whether a room id exists to callers without
            // access — confirm that is acceptable.
            let room_exists = models::rooms::room::Entity::find_by_id(room_id)
                .one(&self.db)
                .await
                .map_err(|e| AgentError::Internal(e.to_string()))?
                .is_some();
            let reason = if room_exists {
                "Access denied or room empty"
            } else {
                "Room not found"
            };
            return Err(AgentError::Internal(reason.into()));
        }

        let max_summary_tokens =
            Self::summary_token_budget(context_window_tokens, compaction_max_summary_ratio);
        self.compact_messages(
            &messages,
            level,
            user_names,
            max_summary_tokens,
            Uuid::new_v4(),
            room_id,
        )
        .await
    }

    /// Compacts the messages stored under `session_id` (matched against the message
    /// `room` column).
    ///
    /// The result carries `session_id` and a nil `room_id`.
    ///
    /// NOTE(review): unlike `compact_room`, this performs no access check. Callers
    /// must authorize the request themselves.
    ///
    /// # Errors
    ///
    /// Returns `AgentError::Internal` when no messages are found or a database query
    /// fails, and `AgentError::OpenAi` when the summarisation call fails.
    pub async fn compact_session(
        &self,
        session_id: Uuid,
        level: CompactLevel,
        user_names: Option<std::collections::HashMap<Uuid, String>>,
        context_window_tokens: i32,
        compaction_max_summary_ratio: f32,
    ) -> Result<CompactSummary, AgentError> {
        let mut messages: Vec<RoomMessageModel> = RoomMessage::find()
            .filter(RmCol::Room.eq(session_id))
            // Newest first, so the limit drops the oldest messages, not the newest.
            .order_by_desc(RmCol::Seq)
            .limit(Self::MAX_FETCHED_MESSAGES)
            .all(&self.db)
            .await
            .map_err(|e| AgentError::Internal(e.to_string()))?;
        if messages.is_empty() {
            return Err(AgentError::Internal("session has no messages".into()));
        }
        // Restore chronological (ascending `seq`) order for splitting and summarising.
        messages.reverse();

        let max_summary_tokens =
            Self::summary_token_budget(context_window_tokens, compaction_max_summary_ratio);
        self.compact_messages(
            &messages,
            level,
            user_names,
            max_summary_tokens,
            session_id,
            Uuid::nil(),
        )
        .await
    }

    /// Shared compaction core for rooms and sessions.
    ///
    /// `messages` must be in chronological order. The last `level.retain_count()`
    /// messages are kept verbatim. Anything older is sent to the model for
    /// summarisation. When nothing is older than the retained window, no model call
    /// is made and the summary is empty.
    async fn compact_messages(
        &self,
        messages: &[RoomMessageModel],
        level: CompactLevel,
        user_names: Option<std::collections::HashMap<Uuid, String>>,
        max_summary_tokens: usize,
        session_id: Uuid,
        room_id: Uuid,
    ) -> Result<CompactSummary, AgentError> {
        let user_name_map = match user_names {
            Some(map) => map,
            None => {
                self.get_user_name_map(&Self::distinct_sender_ids(messages))
                    .await?
            }
        };

        let split_index = messages.len().saturating_sub(level.retain_count());
        let (to_summarize, retained_messages) = messages.split_at(split_index);
        let retained: Vec<MessageSummary> = retained_messages
            .iter()
            .map(|m| Self::message_to_summary(m, &user_name_map))
            .collect();

        if to_summarize.is_empty() {
            // The whole history fits in the retained window; nothing to summarise.
            return Ok(CompactSummary {
                session_id,
                room_id,
                retained,
                summary: String::new(),
                compacted_at: Utc::now(),
                messages_compressed: 0,
                usage: None,
            });
        }

        // The same name map drives both the retained messages and the summary prompt.
        let (summary, remote_usage) = self
            .summarize_messages(to_summarize, &user_name_map, max_summary_tokens)
            .await?;
        // Text of what was summarised, used for the tiktoken fallback when the remote
        // usage is unavailable.
        let summarized_text = to_summarize
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join("\n");
        let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);

        Ok(CompactSummary {
            session_id,
            room_id,
            retained,
            summary,
            compacted_at: Utc::now(),
            messages_compressed: to_summarize.len(),
            usage: Some(usage),
        })
    }

    /// Token budget for a summary: `context_window_tokens * ratio`, truncated toward
    /// zero.
    ///
    /// Float-to-int `as` casts saturate, so a negative or NaN product yields 0.
    fn summary_token_budget(context_window_tokens: i32, ratio: f32) -> usize {
        (context_window_tokens as f32 * ratio) as usize
    }

    /// Returns each distinct non-`None` sender id in `messages`, in unspecified order.
    fn distinct_sender_ids(messages: &[RoomMessageModel]) -> Vec<Uuid> {
        messages
            .iter()
            .filter_map(|m| m.sender_id)
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect()
    }

    /// Loads the newest messages of `room_id` (at most `MAX_FETCHED_MESSAGES`),
    /// returned in chronological order.
    ///
    /// Rows are returned only if `requester_id` has a `room_user_state` row or a
    /// `room_access` row for the room. Otherwise the result is empty.
    async fn fetch_room_messages_secure(
        &self,
        room_id: Uuid,
        requester_id: Uuid,
    ) -> Result<Vec<RoomMessageModel>, AgentError> {
        use models::rooms::{RoomAccess, RoomUserState};
        use sea_orm::sea_query::Expr;
        use sea_orm::QueryTrait;

        // Membership: the requester has a room_user_state row for this room.
        let is_member = Expr::exists(
            RoomUserState::find()
                .filter(models::rooms::room_user_state::Column::Room.eq(room_id))
                .filter(models::rooms::room_user_state::Column::User.eq(requester_id))
                .into_query(),
        );
        // Explicit grant: the requester has a room_access row for this room.
        let has_grant = Expr::exists(
            RoomAccess::find()
                .filter(models::rooms::room_access::Column::Room.eq(room_id))
                .filter(models::rooms::room_access::Column::User.eq(requester_id))
                .into_query(),
        );

        let mut messages = RoomMessage::find()
            .filter(RmCol::Room.eq(room_id))
            .filter(sea_orm::Condition::any().add(is_member).add(has_grant))
            // Newest first, so the limit drops the oldest messages, not the newest.
            .order_by_desc(RmCol::Seq)
            .limit(Self::MAX_FETCHED_MESSAGES)
            .all(&self.db)
            .await
            .map_err(|e| AgentError::Internal(e.to_string()))?;
        // Restore chronological (ascending `seq`) order.
        messages.reverse();
        Ok(messages)
    }

    /// Display name for a message's sender.
    ///
    /// Uses the name mapped to `sender_id` when one exists. Otherwise falls back to
    /// the textual form of `sender_type`.
    fn sender_name(
        m: &RoomMessageModel,
        user_name_map: &std::collections::HashMap<Uuid, String>,
    ) -> String {
        m.sender_id
            .and_then(|user_id| user_name_map.get(&user_id).cloned())
            .unwrap_or_else(|| m.sender_type.to_string())
    }

    /// Converts a stored message into the `MessageSummary` kept in a compaction result.
    fn message_to_summary(
        m: &RoomMessageModel,
        user_name_map: &std::collections::HashMap<Uuid, String>,
    ) -> MessageSummary {
        MessageSummary {
            id: m.id,
            sender_type: m.sender_type.clone(),
            sender_id: m.sender_id,
            sender_name: Self::sender_name(m, user_name_map),
            content: m.content.clone(),
            content_type: m.content_type.clone(),
            tool_call_id: None,
            send_at: m.send_at,
        }
    }

    /// Loads `uid -> username` for the given user ids.
    ///
    /// Ids with no matching user are absent from the map. No query is issued when
    /// `user_ids` is empty.
    async fn get_user_name_map(
        &self,
        user_ids: &[Uuid],
    ) -> Result<std::collections::HashMap<Uuid, String>, AgentError> {
        if user_ids.is_empty() {
            return Ok(std::collections::HashMap::new());
        }
        let users = User::find()
            .filter(UserCol::Uid.is_in(user_ids.iter().copied()))
            .all(&self.db)
            .await
            .map_err(|e| AgentError::Internal(e.to_string()))?;
        Ok(users
            .into_iter()
            .map(|user| (user.uid, user.username))
            .collect())
    }

    /// Asks the model to summarise `messages`, labelling senders via `user_name_map`.
    ///
    /// Returns the summary text and the usage reported by the remote API, if any.
    async fn summarize_messages(
        &self,
        messages: &[RoomMessageModel],
        user_name_map: &std::collections::HashMap<Uuid, String>,
        max_summary_tokens: usize,
    ) -> Result<(String, Option<TokenUsage>), AgentError> {
        let sender_mapper = |m: &RoomMessageModel| Self::sender_name(m, user_name_map);
        let body = crate::compact::helpers::messages_to_text(messages, sender_mapper);
        let user_msg = ChatRequestMessage::user(format!(
            "Summarise the following conversation concisely, preserving all key facts, \
             decisions, and any pending or in-progress work. \
             The summary MUST NOT exceed {} tokens. \
             Use this format:\n\n\
             **Summary:** <one-paragraph overview>\n\
             **Key decisions:** <bullet list or 'none'>\n\
             **Open items:** <bullet list or 'none'>\n\n\
             Conversation:\n\n{}",
            max_summary_tokens, body
        ));
        // NOTE(review): the 2048 output cap is independent of `max_summary_tokens`.
        // A larger budget is only enforced through the prompt.
        let response = call_with_params(
            &[user_msg],
            &self.model,
            &self.ai_client_config,
            0.3,
            2048,
            None,
            None,
            None,
        )
        .await
        .map_err(|e| AgentError::OpenAi(e.to_string()))?;
        let remote_usage =
            TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);
        Ok((response.content, remote_usage))
    }
}