gitdataai/libs/agent/compact/service.rs
2026-04-14 19:02:01 +08:00

468 lines
16 KiB
Rust

use async_openai::Client;
use async_openai::config::OpenAIConfig;
use async_openai::types::chat::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
CreateChatCompletionResponse,
};
use chrono::Utc;
use models::ColumnTrait;
use models::rooms::room_message::{
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
};
use models::users::user::{Column as UserCol, Entity as User};
use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder};
use serde_json::Value;
use uuid::Uuid;
use crate::AgentError;
use crate::compact::helpers::summary_content;
use crate::compact::types::{
CompactConfig, CompactLevel, CompactSummary, MessageSummary, ThresholdResult,
};
use crate::tokent::{TokenUsage, resolve_usage};
/// Conversation-history compaction service.
///
/// Summarizes the older portion of a room's or session's message history via
/// an OpenAI-compatible chat model so the recent context stays small, while
/// retaining the most recent messages verbatim.
#[derive(Clone)]
pub struct CompactService {
    /// Database handle used to load room/session messages and user names.
    db: DatabaseConnection,
    /// OpenAI-compatible client used to generate summaries.
    openai: Client<OpenAIConfig>,
    /// Chat model name sent with every summarization request.
    model: String,
}
impl CompactService {
pub fn new(db: DatabaseConnection, openai: Client<OpenAIConfig>, model: String) -> Self {
Self { db, openai, model }
}
pub async fn compact_room(
&self,
room_id: Uuid,
level: CompactLevel,
user_names: Option<std::collections::HashMap<Uuid, String>>,
) -> Result<CompactSummary, AgentError> {
let messages = self.fetch_room_messages(room_id).await?;
let user_ids: Vec<Uuid> = messages
.iter()
.filter_map(|m| m.sender_id)
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
if messages.len() <= level.retain_count() {
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary: String::new(),
compacted_at: Utc::now(),
messages_compressed: 0,
usage: None,
});
}
let retain_count = level.retain_count();
let split_index = messages.len().saturating_sub(retain_count);
let (to_summarize, retained_messages) = messages.split_at(split_index);
let retained: Vec<MessageSummary> = retained_messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;
// Build text of what was summarized (for tiktoken fallback)
let summarized_text = to_summarize
.iter()
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n");
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary,
compacted_at: Utc::now(),
messages_compressed: to_summarize.len(),
usage: Some(usage),
})
}
pub async fn compact_session(
&self,
session_id: Uuid,
level: CompactLevel,
user_names: Option<std::collections::HashMap<Uuid, String>>,
) -> Result<CompactSummary, AgentError> {
let messages: Vec<RoomMessageModel> = RoomMessage::find()
.filter(RmCol::Room.eq(session_id))
.order_by_asc(RmCol::Seq)
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
if messages.is_empty() {
return Err(AgentError::Internal("session has no messages".into()));
}
let user_ids: Vec<Uuid> = messages
.iter()
.filter_map(|m| m.sender_id)
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
if messages.len() <= level.retain_count() {
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id,
room_id: Uuid::nil(),
retained,
summary: String::new(),
compacted_at: Utc::now(),
messages_compressed: 0,
usage: None,
});
}
let retain_count = level.retain_count();
let split_index = messages.len().saturating_sub(retain_count);
let (to_summarize, retained_messages) = messages.split_at(split_index);
let retained: Vec<MessageSummary> = retained_messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
// Summarize the earlier messages
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;
// Build text of what was summarized (for tiktoken fallback)
let summarized_text = to_summarize
.iter()
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n");
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
Ok(CompactSummary {
session_id,
room_id: Uuid::nil(),
retained,
summary,
compacted_at: Utc::now(),
messages_compressed: to_summarize.len(),
usage: Some(usage),
})
}
pub fn summary_as_system_message(summary: &CompactSummary) -> ChatCompletionRequestMessage {
let content = summary_content(summary);
ChatCompletionRequestMessage::System(
async_openai::types::chat::ChatCompletionRequestSystemMessage {
content: async_openai::types::chat::ChatCompletionRequestSystemMessageContent::Text(
content,
),
..Default::default()
},
)
}
/// Check if the message history for a room exceeds the token threshold.
/// Returns `ThresholdResult::Skip` if below threshold, `Compact` if above.
///
/// This method fetches messages and estimates their token count using tiktoken.
/// Call this before deciding whether to run full compaction.
pub async fn check_threshold(
&self,
room_id: Uuid,
config: CompactConfig,
) -> Result<ThresholdResult, AgentError> {
let messages = self.fetch_room_messages(room_id).await?;
let tokens = self.estimate_message_tokens(&messages);
if tokens < config.token_threshold {
return Ok(ThresholdResult::Skip {
estimated_tokens: tokens,
});
}
let level = if config.auto_level {
CompactLevel::auto_select(tokens, config.token_threshold)
} else {
config.default_level
};
Ok(ThresholdResult::Compact {
estimated_tokens: tokens,
level,
})
}
/// Auto-compact a room: estimates token count, only compresses if over threshold.
///
/// This is the recommended entry point for automatic compaction.
/// - If tokens < threshold → returns a no-op summary (empty summary, no compression)
/// - If tokens >= threshold → compresses with auto-selected level
pub async fn compact_room_auto(
&self,
room_id: Uuid,
user_names: Option<std::collections::HashMap<Uuid, String>>,
config: CompactConfig,
) -> Result<CompactSummary, AgentError> {
let threshold_result = self.check_threshold(room_id, config).await?;
match threshold_result {
ThresholdResult::Skip { .. } => {
// Below threshold — no compaction needed, return empty summary
let messages = self.fetch_room_messages(room_id).await?;
let user_ids: Vec<Uuid> = messages.iter().filter_map(|m| m.sender_id).collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary: String::new(),
compacted_at: Utc::now(),
messages_compressed: 0,
usage: None,
});
}
ThresholdResult::Compact { level, .. } => {
// Above threshold — compress with selected level
return self
.compact_room_with_level(room_id, level, user_names)
.await;
}
}
}
/// Compact a room with a specific level (bypassing threshold check).
/// Use this when the caller has already decided compaction is needed.
async fn compact_room_with_level(
&self,
room_id: Uuid,
level: CompactLevel,
user_names: Option<std::collections::HashMap<Uuid, String>>,
) -> Result<CompactSummary, AgentError> {
let messages = self.fetch_room_messages(room_id).await?;
let user_ids: Vec<Uuid> = messages.iter().filter_map(|m| m.sender_id).collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
if messages.len() <= level.retain_count() {
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary: String::new(),
compacted_at: Utc::now(),
messages_compressed: 0,
usage: None,
});
}
let retain_count = level.retain_count();
let split_index = messages.len().saturating_sub(retain_count);
let (to_summarize, retained_messages) = messages.split_at(split_index);
let retained: Vec<MessageSummary> = retained_messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;
let summarized_text = to_summarize
.iter()
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n");
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary,
compacted_at: Utc::now(),
messages_compressed: to_summarize.len(),
usage: Some(usage),
})
}
/// Estimate total token count of a message list using tiktoken.
fn estimate_message_tokens(&self, messages: &[RoomMessageModel]) -> usize {
let total_chars: usize = messages.iter().map(|m| m.content.len()).sum();
// Rough estimate: ~4 chars per token (safe upper bound)
total_chars / 4
}
fn message_to_summary(
m: &RoomMessageModel,
user_name_map: &std::collections::HashMap<Uuid, String>,
) -> MessageSummary {
let sender_name = m
.sender_id
.and_then(|id| user_name_map.get(&id).cloned())
.unwrap_or_else(|| m.sender_type.to_string());
MessageSummary {
id: m.id,
sender_type: m.sender_type.clone(),
sender_id: m.sender_id,
sender_name,
content: m.content.clone(),
content_type: m.content_type.clone(),
tool_call_id: Self::extract_tool_call_id(&m.content),
send_at: m.send_at,
}
}
fn extract_tool_call_id(content: &str) -> Option<String> {
let content = content.trim();
if let Ok(v) = serde_json::from_str::<Value>(content) {
v.get("tool_call_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
} else {
None
}
}
async fn fetch_room_messages(
&self,
room_id: Uuid,
) -> Result<Vec<RoomMessageModel>, AgentError> {
let messages: Vec<RoomMessageModel> = RoomMessage::find()
.filter(RmCol::Room.eq(room_id))
.order_by_asc(RmCol::Seq)
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
Ok(messages)
}
async fn get_user_name_map(
&self,
user_ids: &[Uuid],
) -> Result<std::collections::HashMap<Uuid, String>, AgentError> {
use std::collections::HashMap;
let mut map = HashMap::new();
if !user_ids.is_empty() {
let users = User::find()
.filter(UserCol::Uid.is_in(user_ids.to_vec()))
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
for user in users {
map.insert(user.uid, user.username);
}
}
Ok(map)
}
async fn summarize_messages(
&self,
messages: &[RoomMessageModel],
) -> Result<(String, Option<TokenUsage>), AgentError> {
// Collect distinct user IDs
let user_ids: Vec<Uuid> = messages
.iter()
.filter_map(|m| m.sender_id)
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
// Query usernames
let user_name_map = self.get_user_name_map(&user_ids).await?;
// Define sender mapper
let sender_mapper = |m: &RoomMessageModel| {
if let Some(user_id) = m.sender_id {
if let Some(username) = user_name_map.get(&user_id) {
return username.clone();
}
}
m.sender_type.to_string()
};
let body = crate::compact::helpers::messages_to_text(messages, sender_mapper);
let user_msg = ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
content: async_openai::types::chat::ChatCompletionRequestUserMessageContent::Text(
format!(
"Summarise the following conversation concisely, preserving all key facts, \
decisions, and any pending or in-progress work. \
Use this format:\n\n\
**Summary:** <one-paragraph overview>\n\
**Key decisions:** <bullet list or 'none'>\n\
**Open items:** <bullet list or 'none'>\n\n\
Conversation:\n\n{}",
body
),
),
..Default::default()
});
let request = CreateChatCompletionRequest {
model: self.model.clone(),
messages: vec![user_msg],
stream: Some(false),
..Default::default()
};
let response: CreateChatCompletionResponse = self
.openai
.chat()
.create(request)
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
let text = response
.choices
.first()
.and_then(|c| c.message.content.clone())
.unwrap_or_default();
// Prefer remote usage; fall back to None (caller will use tiktoken via resolve_usage)
let remote_usage = response
.usage
.as_ref()
.and_then(|u| TokenUsage::from_remote(u.prompt_tokens, u.completion_tokens));
Ok((text, remote_usage))
}
}