use async_openai::Client;
use async_openai::config::OpenAIConfig;
use async_openai::types::chat::{
    ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
    CreateChatCompletionResponse,
};
use chrono::Utc;
use models::ColumnTrait;
use models::rooms::room_message::{
    Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
};
use models::users::user::{Column as UserCol, Entity as User};
use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder};
use serde_json::Value;
use uuid::Uuid;

use crate::AgentError;
use crate::compact::helpers::summary_content;
use crate::compact::types::{
    CompactConfig, CompactLevel, CompactSummary, MessageSummary, ThresholdResult,
};
use crate::tokent::{TokenUsage, resolve_usage};

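/// Compacts chat history by summarizing older messages with the configured chat
/// model while keeping the most recent messages verbatim.
///
/// Illustrative construction (not from the original source; the model name is an
/// arbitrary placeholder and `db` / `openai` are assumed to be set up elsewhere):
///
/// ```ignore
/// let service = CompactService::new(db, openai, "gpt-4o-mini".to_string());
/// ```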
#[derive(Clone)]
pub struct CompactService {
    db: DatabaseConnection,
    openai: Client<OpenAIConfig>,
    model: String,
}

impl CompactService {
    pub fn new(db: DatabaseConnection, openai: Client<OpenAIConfig>, model: String) -> Self {
        Self { db, openai, model }
    }

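    /// Compact a room's message history at the given level.
    ///
    /// Messages older than the last `level.retain_count()` entries are summarized
    /// by the chat model; the most recent messages are returned verbatim in
    /// `retained`. If the history already fits within the retain count, no
    /// summarization is performed and the summary string is left empty.
    ///
    /// `user_names` lets the caller pass a pre-fetched ID-to-name map; when
    /// `None`, names are looked up from the database.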
    pub async fn compact_room(
        &self,
        room_id: Uuid,
        level: CompactLevel,
        user_names: Option<std::collections::HashMap<Uuid, String>>,
    ) -> Result<CompactSummary, AgentError> {
        let messages = self.fetch_room_messages(room_id).await?;

        let user_ids: Vec<Uuid> = messages
            .iter()
            .filter_map(|m| m.sender_id)
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        let user_name_map = match user_names {
            Some(map) => map,
            None => self.get_user_name_map(&user_ids).await?,
        };

        if messages.len() <= level.retain_count() {
            let retained: Vec<MessageSummary> = messages
                .iter()
                .map(|m| Self::message_to_summary(m, &user_name_map))
                .collect();
            return Ok(CompactSummary {
                session_id: Uuid::new_v4(),
                room_id,
                retained,
                summary: String::new(),
                compacted_at: Utc::now(),
                messages_compressed: 0,
                usage: None,
            });
        }

        let retain_count = level.retain_count();
        let split_index = messages.len().saturating_sub(retain_count);
        let (to_summarize, retained_messages) = messages.split_at(split_index);

        let retained: Vec<MessageSummary> = retained_messages
            .iter()
            .map(|m| Self::message_to_summary(m, &user_name_map))
            .collect();

        let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;

        // Build text of what was summarized (for the tiktoken fallback in `resolve_usage`)
        let summarized_text = to_summarize
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join("\n");
        let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);

        Ok(CompactSummary {
            session_id: Uuid::new_v4(),
            room_id,
            retained,
            summary,
            compacted_at: Utc::now(),
            messages_compressed: to_summarize.len(),
            usage: Some(usage),
        })
    }

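    /// Compact a single session's message history at the given level.
    ///
    /// Behaves like [`Self::compact_room`], except that messages are looked up by
    /// `session_id`, an error is returned when the session has no messages, and
    /// the resulting summary carries a nil `room_id`.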
    pub async fn compact_session(
        &self,
        session_id: Uuid,
        level: CompactLevel,
        user_names: Option<std::collections::HashMap<Uuid, String>>,
    ) -> Result<CompactSummary, AgentError> {
        let messages: Vec<RoomMessageModel> = RoomMessage::find()
            .filter(RmCol::Room.eq(session_id))
            .order_by_asc(RmCol::Seq)
            .all(&self.db)
            .await
            .map_err(|e| AgentError::Internal(e.to_string()))?;

        if messages.is_empty() {
            return Err(AgentError::Internal("session has no messages".into()));
        }

        let user_ids: Vec<Uuid> = messages
            .iter()
            .filter_map(|m| m.sender_id)
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        let user_name_map = match user_names {
            Some(map) => map,
            None => self.get_user_name_map(&user_ids).await?,
        };

        if messages.len() <= level.retain_count() {
            let retained: Vec<MessageSummary> = messages
                .iter()
                .map(|m| Self::message_to_summary(m, &user_name_map))
                .collect();
            return Ok(CompactSummary {
                session_id,
                room_id: Uuid::nil(),
                retained,
                summary: String::new(),
                compacted_at: Utc::now(),
                messages_compressed: 0,
                usage: None,
            });
        }

        let retain_count = level.retain_count();
        let split_index = messages.len().saturating_sub(retain_count);
        let (to_summarize, retained_messages) = messages.split_at(split_index);

        let retained: Vec<MessageSummary> = retained_messages
            .iter()
            .map(|m| Self::message_to_summary(m, &user_name_map))
            .collect();

        // Summarize the earlier messages
        let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;

        // Build text of what was summarized (for the tiktoken fallback in `resolve_usage`)
        let summarized_text = to_summarize
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join("\n");
        let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);

        Ok(CompactSummary {
            session_id,
            room_id: Uuid::nil(),
            retained,
            summary,
            compacted_at: Utc::now(),
            messages_compressed: to_summarize.len(),
            usage: Some(usage),
        })
    }

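    /// Render a `CompactSummary` as a system message so it can be injected into a
    /// later chat completion request in place of the compacted history.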
    pub fn summary_as_system_message(summary: &CompactSummary) -> ChatCompletionRequestMessage {
        let content = summary_content(summary);
        ChatCompletionRequestMessage::System(
            async_openai::types::chat::ChatCompletionRequestSystemMessage {
                content: async_openai::types::chat::ChatCompletionRequestSystemMessageContent::Text(
                    content,
                ),
                ..Default::default()
            },
        )
    }

    /// Check if the message history for a room exceeds the token threshold.
    /// Returns `ThresholdResult::Skip` if below threshold, `Compact` if above.
    ///
    /// This method fetches the room's messages and estimates their token count
    /// with the rough character-based heuristic in `estimate_message_tokens`.
    /// Call this before deciding whether to run full compaction.
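    ///
    /// Illustrative usage (not from the original source; assumes an existing
    /// `service: CompactService`, a `room_id`, and a populated `CompactConfig`):
    ///
    /// ```ignore
    /// match service.check_threshold(room_id, config).await? {
    ///     ThresholdResult::Compact { level, .. } => {
    ///         let summary = service.compact_room(room_id, level, None).await?;
    ///         // persist or apply `summary` here
    ///     }
    ///     ThresholdResult::Skip { .. } => { /* under the threshold, nothing to do */ }
    /// }
    /// ```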
    pub async fn check_threshold(
        &self,
        room_id: Uuid,
        config: CompactConfig,
    ) -> Result<ThresholdResult, AgentError> {
        let messages = self.fetch_room_messages(room_id).await?;
        let tokens = self.estimate_message_tokens(&messages);

        if tokens < config.token_threshold {
            return Ok(ThresholdResult::Skip {
                estimated_tokens: tokens,
            });
        }

        let level = if config.auto_level {
            CompactLevel::auto_select(tokens, config.token_threshold)
        } else {
            config.default_level
        };

        Ok(ThresholdResult::Compact {
            estimated_tokens: tokens,
            level,
        })
    }

    /// Auto-compact a room: estimates the token count and only compresses if over threshold.
    ///
    /// This is the recommended entry point for automatic compaction.
    /// - If tokens < threshold → returns a no-op summary (empty summary, no compression)
    /// - If tokens >= threshold → compresses with an auto-selected level
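    ///
    /// Illustrative call (not from the original source; assumes an existing
    /// `service`, `room_id`, and `config`):
    ///
    /// ```ignore
    /// let summary = service.compact_room_auto(room_id, None, config).await?;
    /// if summary.messages_compressed > 0 {
    ///     // older history was summarized; `summary.summary` holds the digest
    /// }
    /// ```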
    pub async fn compact_room_auto(
        &self,
        room_id: Uuid,
        user_names: Option<std::collections::HashMap<Uuid, String>>,
        config: CompactConfig,
    ) -> Result<CompactSummary, AgentError> {
        let threshold_result = self.check_threshold(room_id, config).await?;

        match threshold_result {
            ThresholdResult::Skip { .. } => {
                // Below threshold: no compaction needed, return an empty summary
                let messages = self.fetch_room_messages(room_id).await?;
                let user_ids: Vec<Uuid> = messages.iter().filter_map(|m| m.sender_id).collect();
                let user_name_map = match user_names {
                    Some(map) => map,
                    None => self.get_user_name_map(&user_ids).await?,
                };
                let retained: Vec<MessageSummary> = messages
                    .iter()
                    .map(|m| Self::message_to_summary(m, &user_name_map))
                    .collect();

                return Ok(CompactSummary {
                    session_id: Uuid::new_v4(),
                    room_id,
                    retained,
                    summary: String::new(),
                    compacted_at: Utc::now(),
                    messages_compressed: 0,
                    usage: None,
                });
            }
            ThresholdResult::Compact { level, .. } => {
                // Above threshold: compress with the selected level
                return self
                    .compact_room_with_level(room_id, level, user_names)
                    .await;
            }
        }
    }

    /// Compact a room with a specific level (bypassing the threshold check).
    /// Use this when the caller has already decided compaction is needed.
    async fn compact_room_with_level(
        &self,
        room_id: Uuid,
        level: CompactLevel,
        user_names: Option<std::collections::HashMap<Uuid, String>>,
    ) -> Result<CompactSummary, AgentError> {
        let messages = self.fetch_room_messages(room_id).await?;

        let user_ids: Vec<Uuid> = messages.iter().filter_map(|m| m.sender_id).collect();
        let user_name_map = match user_names {
            Some(map) => map,
            None => self.get_user_name_map(&user_ids).await?,
        };

        if messages.len() <= level.retain_count() {
            let retained: Vec<MessageSummary> = messages
                .iter()
                .map(|m| Self::message_to_summary(m, &user_name_map))
                .collect();
            return Ok(CompactSummary {
                session_id: Uuid::new_v4(),
                room_id,
                retained,
                summary: String::new(),
                compacted_at: Utc::now(),
                messages_compressed: 0,
                usage: None,
            });
        }

        let retain_count = level.retain_count();
        let split_index = messages.len().saturating_sub(retain_count);
        let (to_summarize, retained_messages) = messages.split_at(split_index);

        let retained: Vec<MessageSummary> = retained_messages
            .iter()
            .map(|m| Self::message_to_summary(m, &user_name_map))
            .collect();

        let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;

        let summarized_text = to_summarize
            .iter()
            .map(|m| m.content.as_str())
            .collect::<Vec<_>>()
            .join("\n");
        let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);

        Ok(CompactSummary {
            session_id: Uuid::new_v4(),
            room_id,
            retained,
            summary,
            compacted_at: Utc::now(),
            messages_compressed: to_summarize.len(),
            usage: Some(usage),
        })
    }

    /// Estimate the total token count of a message list.
    ///
    /// Uses a rough character-based heuristic (~4 characters per token) rather
    /// than a real tokenizer.
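    ///
    /// For example, a history of about 40,000 characters is estimated at roughly
    /// 10,000 tokens.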
    fn estimate_message_tokens(&self, messages: &[RoomMessageModel]) -> usize {
        let total_chars: usize = messages.iter().map(|m| m.content.len()).sum();
        // Rough estimate: ~4 chars per token. This is a heuristic, not an exact
        // count; dense content may tokenize to more than this.
        total_chars / 4
    }

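    /// Convert a stored room message into a `MessageSummary`, resolving the
    /// sender's display name from `user_name_map` and falling back to the
    /// sender type when no name is known.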
    fn message_to_summary(
        m: &RoomMessageModel,
        user_name_map: &std::collections::HashMap<Uuid, String>,
    ) -> MessageSummary {
        let sender_name = m
            .sender_id
            .and_then(|id| user_name_map.get(&id).cloned())
            .unwrap_or_else(|| m.sender_type.to_string());
        MessageSummary {
            id: m.id,
            sender_type: m.sender_type.clone(),
            sender_id: m.sender_id,
            sender_name,
            content: m.content.clone(),
            content_type: m.content_type.clone(),
            tool_call_id: Self::extract_tool_call_id(&m.content),
            send_at: m.send_at,
        }
    }

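    /// Extract a `tool_call_id` from message content when the content is a JSON
    /// object carrying that field; returns `None` for plain text or other JSON.
    ///
    /// Illustrative input (hypothetical payload, not from the original source):
    /// `{"tool_call_id": "call_123", "output": "..."}` yields `Some("call_123")`.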
    fn extract_tool_call_id(content: &str) -> Option<String> {
        let content = content.trim();
        if let Ok(v) = serde_json::from_str::<Value>(content) {
            v.get("tool_call_id")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string())
        } else {
            None
        }
    }

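    /// Load all messages for a room, ordered by sequence number.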
    async fn fetch_room_messages(
        &self,
        room_id: Uuid,
    ) -> Result<Vec<RoomMessageModel>, AgentError> {
        let messages: Vec<RoomMessageModel> = RoomMessage::find()
            .filter(RmCol::Room.eq(room_id))
            .order_by_asc(RmCol::Seq)
            .all(&self.db)
            .await
            .map_err(|e| AgentError::Internal(e.to_string()))?;
        Ok(messages)
    }

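    /// Map user IDs to usernames. IDs that do not resolve to a user are simply
    /// absent from the returned map.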
    async fn get_user_name_map(
        &self,
        user_ids: &[Uuid],
    ) -> Result<std::collections::HashMap<Uuid, String>, AgentError> {
        use std::collections::HashMap;
        let mut map = HashMap::new();
        if !user_ids.is_empty() {
            let users = User::find()
                .filter(UserCol::Uid.is_in(user_ids.to_vec()))
                .all(&self.db)
                .await
                .map_err(|e| AgentError::Internal(e.to_string()))?;
            for user in users {
                map.insert(user.uid, user.username);
            }
        }
        Ok(map)
    }

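    /// Ask the chat model to summarize a batch of messages.
    ///
    /// Returns the summary text together with the provider-reported token usage,
    /// if any; when the provider omits usage, callers fall back to local
    /// estimation via `resolve_usage`.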
    async fn summarize_messages(
        &self,
        messages: &[RoomMessageModel],
    ) -> Result<(String, Option<TokenUsage>), AgentError> {
        // Collect distinct user IDs
        let user_ids: Vec<Uuid> = messages
            .iter()
            .filter_map(|m| m.sender_id)
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();

        // Query usernames
        let user_name_map = self.get_user_name_map(&user_ids).await?;

        // Define sender mapper
        let sender_mapper = |m: &RoomMessageModel| {
            if let Some(user_id) = m.sender_id {
                if let Some(username) = user_name_map.get(&user_id) {
                    return username.clone();
                }
            }
            m.sender_type.to_string()
        };

        let body = crate::compact::helpers::messages_to_text(messages, sender_mapper);

        let user_msg = ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
            content: async_openai::types::chat::ChatCompletionRequestUserMessageContent::Text(
                format!(
                    "Summarise the following conversation concisely, preserving all key facts, \
                     decisions, and any pending or in-progress work. \
                     Use this format:\n\n\
                     **Summary:** <one-paragraph overview>\n\
                     **Key decisions:** <bullet list or 'none'>\n\
                     **Open items:** <bullet list or 'none'>\n\n\
                     Conversation:\n\n{}",
                    body
                ),
            ),
            ..Default::default()
        });

        let request = CreateChatCompletionRequest {
            model: self.model.clone(),
            messages: vec![user_msg],
            stream: Some(false),
            ..Default::default()
        };

        let response: CreateChatCompletionResponse = self
            .openai
            .chat()
            .create(request)
            .await
            .map_err(|e| AgentError::OpenAi(e.to_string()))?;

        let text = response
            .choices
            .first()
            .and_then(|c| c.message.content.clone())
            .unwrap_or_default();

        // Prefer remote usage; fall back to None (caller will use tiktoken via resolve_usage)
        let remote_usage = response
            .usage
            .as_ref()
            .and_then(|u| TokenUsage::from_remote(u.prompt_tokens, u.completion_tokens));

        Ok((text, remote_usage))
    }
}