328 lines
11 KiB
Rust
328 lines
11 KiB
Rust
use chrono::Utc;
|
|
use models::ColumnTrait;
|
|
use models::rooms::room_message::{
|
|
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
|
|
};
|
|
use models::users::user::{Column as UserCol, Entity as User};
|
|
use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
|
|
use uuid::Uuid;
|
|
|
|
use crate::client::types::ChatRequestMessage;
|
|
use crate::client::AiClientConfig;
|
|
use crate::client::call_with_params;
|
|
use crate::AgentError;
|
|
use crate::compact::types::{CompactLevel, CompactSummary, MessageSummary};
|
|
use crate::tokent::{TokenUsage, resolve_usage};
|
|
|
|
#[derive(Clone)]
|
|
pub struct CompactService {
|
|
db: DatabaseConnection,
|
|
ai_client_config: AiClientConfig,
|
|
model: String,
|
|
}
|
|
|
|
impl CompactService {
|
|
pub fn new(db: DatabaseConnection, ai_client_config: AiClientConfig, model: String) -> Self {
|
|
Self { db, ai_client_config, model }
|
|
}
|
|
|
|
pub async fn compact_room(
|
|
&self,
|
|
room_id: Uuid,
|
|
level: CompactLevel,
|
|
user_names: Option<std::collections::HashMap<Uuid, String>>,
|
|
requester_id: Uuid,
|
|
context_window_tokens: i32,
|
|
compaction_max_summary_ratio: f32,
|
|
) -> Result<CompactSummary, AgentError> {
|
|
// Verify room access at the database level to ensure auth context is enforced.
|
|
// Public rooms are accessible to project members.
|
|
// For simplicity in this audit fix, we'll fetch only if access exists.
|
|
let messages = self.fetch_room_messages_secure(room_id, requester_id).await?;
|
|
|
|
if messages.is_empty() {
|
|
// Check if room actually exists or if it's just empty/inaccessible
|
|
let room_exists = models::rooms::room::Entity::find_by_id(room_id)
|
|
.one(&self.db)
|
|
.await
|
|
.map_err(|e| AgentError::Internal(e.to_string()))?
|
|
.is_some();
|
|
|
|
if room_exists {
|
|
return Err(AgentError::Internal("Access denied or room empty".into()));
|
|
} else {
|
|
return Err(AgentError::Internal("Room not found".into()));
|
|
}
|
|
}
|
|
|
|
let user_ids: Vec<Uuid> = messages
|
|
.iter()
|
|
.filter_map(|m| m.sender_id)
|
|
.collect::<std::collections::HashSet<_>>()
|
|
.into_iter()
|
|
.collect();
|
|
let user_name_map = match user_names {
|
|
Some(map) => map,
|
|
None => self.get_user_name_map(&user_ids).await?,
|
|
};
|
|
|
|
if messages.len() <= level.retain_count() {
|
|
let retained: Vec<MessageSummary> = messages
|
|
.iter()
|
|
.map(|m| Self::message_to_summary(m, &user_name_map))
|
|
.collect();
|
|
return Ok(CompactSummary {
|
|
session_id: Uuid::new_v4(),
|
|
room_id,
|
|
retained,
|
|
summary: String::new(),
|
|
compacted_at: Utc::now(),
|
|
messages_compressed: 0,
|
|
usage: None,
|
|
});
|
|
}
|
|
|
|
let retain_count = level.retain_count();
|
|
let split_index = messages.len().saturating_sub(retain_count);
|
|
let (to_summarize, retained_messages) = messages.split_at(split_index);
|
|
|
|
let retained: Vec<MessageSummary> = retained_messages
|
|
.iter()
|
|
.map(|m| Self::message_to_summary(m, &user_name_map))
|
|
.collect();
|
|
|
|
let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
|
|
|
|
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
|
|
|
|
// Build text of what was summarized (for tiktoken fallback)
|
|
let summarized_text = to_summarize
|
|
.iter()
|
|
.map(|m| m.content.as_str())
|
|
.collect::<Vec<_>>()
|
|
.join("\n");
|
|
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
|
|
|
|
Ok(CompactSummary {
|
|
session_id: Uuid::new_v4(),
|
|
room_id,
|
|
retained,
|
|
summary,
|
|
compacted_at: Utc::now(),
|
|
messages_compressed: to_summarize.len(),
|
|
usage: Some(usage),
|
|
})
|
|
}
|
|
|
|
pub async fn compact_session(
|
|
&self,
|
|
session_id: Uuid,
|
|
level: CompactLevel,
|
|
user_names: Option<std::collections::HashMap<Uuid, String>>,
|
|
context_window_tokens: i32,
|
|
compaction_max_summary_ratio: f32,
|
|
) -> Result<CompactSummary, AgentError> {
|
|
let messages: Vec<RoomMessageModel> = RoomMessage::find()
|
|
.filter(RmCol::Room.eq(session_id))
|
|
.order_by_asc(RmCol::Seq)
|
|
.limit(10000)
|
|
.all(&self.db)
|
|
.await
|
|
.map_err(|e| AgentError::Internal(e.to_string()))?;
|
|
|
|
if messages.is_empty() {
|
|
return Err(AgentError::Internal("session has no messages".into()));
|
|
}
|
|
|
|
let user_ids: Vec<Uuid> = messages
|
|
.iter()
|
|
.filter_map(|m| m.sender_id)
|
|
.collect::<std::collections::HashSet<_>>()
|
|
.into_iter()
|
|
.collect();
|
|
let user_name_map = match user_names {
|
|
Some(map) => map,
|
|
None => self.get_user_name_map(&user_ids).await?,
|
|
};
|
|
|
|
if messages.len() <= level.retain_count() {
|
|
let retained: Vec<MessageSummary> = messages
|
|
.iter()
|
|
.map(|m| Self::message_to_summary(m, &user_name_map))
|
|
.collect();
|
|
return Ok(CompactSummary {
|
|
session_id,
|
|
room_id: Uuid::nil(),
|
|
retained,
|
|
summary: String::new(),
|
|
compacted_at: Utc::now(),
|
|
messages_compressed: 0,
|
|
usage: None,
|
|
});
|
|
}
|
|
|
|
let retain_count = level.retain_count();
|
|
let split_index = messages.len().saturating_sub(retain_count);
|
|
let (to_summarize, retained_messages) = messages.split_at(split_index);
|
|
|
|
let retained: Vec<MessageSummary> = retained_messages
|
|
.iter()
|
|
.map(|m| Self::message_to_summary(m, &user_name_map))
|
|
.collect();
|
|
|
|
let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
|
|
|
|
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
|
|
|
|
let summarized_text = to_summarize
|
|
.iter()
|
|
.map(|m| m.content.as_str())
|
|
.collect::<Vec<_>>()
|
|
.join("\n");
|
|
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
|
|
|
|
Ok(CompactSummary {
|
|
session_id,
|
|
room_id: Uuid::nil(),
|
|
retained,
|
|
summary,
|
|
compacted_at: Utc::now(),
|
|
messages_compressed: to_summarize.len(),
|
|
usage: Some(usage),
|
|
})
|
|
}
|
|
|
|
async fn fetch_room_messages_secure(
|
|
&self,
|
|
room_id: Uuid,
|
|
requester_id: Uuid,
|
|
) -> Result<Vec<RoomMessageModel>, AgentError> {
|
|
use models::rooms::{RoomUserState, RoomAccess};
|
|
use sea_orm::QueryTrait;
|
|
use sea_orm::sea_query::Expr;
|
|
|
|
// Find messages for the room where the requester has access.
|
|
// We check both the room_user_state table (membership) and the room_access table (explicit grants).
|
|
RoomMessage::find()
|
|
.filter(RmCol::Room.eq(room_id))
|
|
.filter(
|
|
sea_orm::Condition::any()
|
|
.add(
|
|
Expr::exists(
|
|
RoomUserState::find()
|
|
.filter(models::rooms::room_user_state::Column::Room.eq(room_id))
|
|
.filter(models::rooms::room_user_state::Column::User.eq(requester_id))
|
|
.into_query()
|
|
)
|
|
)
|
|
.add(
|
|
Expr::exists(
|
|
RoomAccess::find()
|
|
.filter(models::rooms::room_access::Column::Room.eq(room_id))
|
|
.filter(models::rooms::room_access::Column::User.eq(requester_id))
|
|
.into_query()
|
|
)
|
|
)
|
|
)
|
|
.order_by_asc(RmCol::Seq)
|
|
.limit(10000)
|
|
.all(&self.db)
|
|
.await
|
|
.map_err(|e| AgentError::Internal(e.to_string()))
|
|
}
|
|
|
|
fn message_to_summary(m: &RoomMessageModel, user_name_map: &std::collections::HashMap<Uuid, String>) -> MessageSummary {
|
|
let sender_name = if let Some(user_id) = m.sender_id {
|
|
user_name_map.get(&user_id).cloned().unwrap_or_else(|| m.sender_type.to_string())
|
|
} else {
|
|
m.sender_type.to_string()
|
|
};
|
|
MessageSummary {
|
|
id: m.id,
|
|
sender_type: m.sender_type.clone(),
|
|
sender_id: m.sender_id,
|
|
sender_name,
|
|
content: m.content.clone(),
|
|
content_type: m.content_type.clone(),
|
|
tool_call_id: None,
|
|
send_at: m.send_at,
|
|
}
|
|
}
|
|
|
|
async fn get_user_name_map(
|
|
&self,
|
|
user_ids: &[Uuid],
|
|
) -> Result<std::collections::HashMap<Uuid, String>, AgentError> {
|
|
use std::collections::HashMap;
|
|
let mut map = HashMap::new();
|
|
if !user_ids.is_empty() {
|
|
let users = User::find()
|
|
.filter(UserCol::Uid.is_in(user_ids.to_vec()))
|
|
.all(&self.db)
|
|
.await
|
|
.map_err(|e| AgentError::Internal(e.to_string()))?;
|
|
for user in users {
|
|
map.insert(user.uid, user.username);
|
|
}
|
|
}
|
|
Ok(map)
|
|
}
|
|
|
|
async fn summarize_messages(
|
|
&self,
|
|
messages: &[RoomMessageModel],
|
|
max_summary_tokens: usize,
|
|
) -> Result<(String, Option<TokenUsage>), AgentError> {
|
|
let user_ids: Vec<Uuid> = messages
|
|
.iter()
|
|
.filter_map(|m| m.sender_id)
|
|
.collect::<std::collections::HashSet<_>>()
|
|
.into_iter()
|
|
.collect();
|
|
|
|
let user_name_map = self.get_user_name_map(&user_ids).await?;
|
|
|
|
let sender_mapper = |m: &RoomMessageModel| {
|
|
if let Some(user_id) = m.sender_id {
|
|
if let Some(username) = user_name_map.get(&user_id) {
|
|
return username.clone();
|
|
}
|
|
}
|
|
m.sender_type.to_string()
|
|
};
|
|
|
|
let body = crate::compact::helpers::messages_to_text(messages, sender_mapper);
|
|
|
|
let user_msg = ChatRequestMessage::user(format!(
|
|
"Summarise the following conversation concisely, preserving all key facts, \
|
|
decisions, and any pending or in-progress work. \
|
|
The summary MUST NOT exceed {} tokens. \
|
|
Use this format:\n\n\
|
|
**Summary:** <one-paragraph overview>\n\
|
|
**Key decisions:** <bullet list or 'none'>\n\
|
|
**Open items:** <bullet list or 'none'>\n\n\
|
|
Conversation:\n\n{}",
|
|
max_summary_tokens,
|
|
body
|
|
));
|
|
|
|
let response = call_with_params(
|
|
&[user_msg],
|
|
&self.model,
|
|
&self.ai_client_config,
|
|
0.3,
|
|
2048,
|
|
None,
|
|
None,
|
|
None,
|
|
)
|
|
.await
|
|
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
|
|
|
let remote_usage =
|
|
TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);
|
|
|
|
Ok((response.content, remote_usage))
|
|
}
|
|
}
|