gitdataai/lib/ai/agent/session.rs
2026-05-30 01:38:40 +08:00

536 lines
15 KiB
Rust

use std::time::SystemTime;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
use crate::error::{AiError, AiResult};
/// Current session file format version.
///
/// `SessionHeader::new` stamps this value into the `version` field of every new header.
pub const SESSION_VERSION: u32 = 2;
/// Session metadata header.
///
/// Identifies a session and records where it came from; it is stored as the
/// `header` field of a `Session`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionHeader {
/// Session file format version; `SessionHeader::new` sets it to `SESSION_VERSION`.
pub version: u32,
/// Unique session identifier; `SessionHeader::new` generates a random v4 UUID.
pub id: Uuid,
/// Creation time as a string; `SessionHeader::new` fills it from `iso_now()`.
pub created_at: String,
/// Id of the session this one was forked from, if any (set via `with_parent`).
pub parent_session: Option<Uuid>,
/// Optional session name (set via `with_name`).
pub name: Option<String>,
}
impl SessionHeader {
    /// Builds the header for a brand-new session.
    ///
    /// The header gets the current `SESSION_VERSION`, a random v4 id and a
    /// creation timestamp taken from `iso_now()`; it has no parent and no name.
    pub fn new() -> Self {
        let id = Uuid::new_v4();
        let created_at = iso_now();
        Self {
            version: SESSION_VERSION,
            id,
            created_at,
            parent_session: None,
            name: None,
        }
    }

    /// Returns the header with `parent_session` set to `parent` (builder style).
    pub fn with_parent(self, parent: Uuid) -> Self {
        Self {
            parent_session: Some(parent),
            ..self
        }
    }

    /// Returns the header with `name` set to the given value (builder style).
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }
}
/// Typed session entry — each entry in a session transcript is one of these variants.
///
/// Every variant carries the same three bookkeeping fields:
/// - `id`: unique id of the entry,
/// - `parent_id`: id of the entry it hangs off in the tree (`None` ends the parent walk),
/// - `timestamp`: creation time string (the constructors fill it from `iso_now()`).
///
/// Serialized with an internal `"type"` tag holding the snake_case variant name
/// (e.g. `"message"`, `"branch_summary"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SessionEntry {
/// A conversation message; `role` identifies who produced it.
Message {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
/// Who produced the message.
role: SessionMessageRole,
/// Message text.
content: String,
/// Tool calls attached to the message; left out of the serialized form when `None`
/// and read back as `None` when the field is absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<SessionToolCall>>,
/// Tool result attached to the message; left out of the serialized form when `None`
/// and read back as `None` when the field is absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
tool_result: Option<SessionToolResult>,
},
/// A context compaction event (older messages summarized).
Compaction {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
/// Summary text standing in for the compacted messages.
summary: String,
/// NOTE(review): presumably the first entry that was kept verbatim — confirm against the compaction code.
first_kept_entry_id: Uuid,
/// Number of messages that were compacted.
messages_compacted: usize,
/// NOTE(review): presumably the token count reclaimed by the compaction; signed, so confirm whether negative values are expected.
tokens_saved: i64,
/// Optional free-form JSON details; left out of the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
details: Option<Value>,
},
/// A branch summary (created when forking from a different point in the tree).
BranchSummary {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
/// NOTE(review): presumably the entry the summarized branch started from — confirm.
from_entry_id: Uuid,
/// Summary text for the branch.
summary: String,
/// Number of entries covered by the summary.
entries_summarized: usize,
/// Optional label; left out of the serialized form when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
label: Option<String>,
},
/// Model change during a session.
ModelChange {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
/// Provider name of the newly selected model.
provider: String,
/// Identifier of the newly selected model.
model_id: String,
},
/// Thinking level change during a session.
ThinkingLevelChange {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
/// New thinking level, stored as a free-form string.
level: String,
},
/// Custom extension data (not sent to LLM).
Custom {
id: Uuid,
parent_id: Option<Uuid>,
timestamp: String,
/// Caller-chosen discriminator for the payload.
custom_type: String,
/// Optional JSON payload. Unlike the other optional fields it has no
/// `skip_serializing_if`, so `None` is written out explicitly.
data: Option<Value>,
},
}
impl SessionEntry {
    /// Returns this entry's unique id, whichever variant it is.
    pub fn id(&self) -> Uuid {
        // Every variant has an `id` field, so the or-pattern is irrefutable.
        let (Self::Message { id, .. }
        | Self::Compaction { id, .. }
        | Self::BranchSummary { id, .. }
        | Self::ModelChange { id, .. }
        | Self::ThinkingLevelChange { id, .. }
        | Self::Custom { id, .. }) = self;
        *id
    }

    /// Returns the id of the entry this one is attached to, if any.
    pub fn parent_id(&self) -> Option<Uuid> {
        let (Self::Message { parent_id, .. }
        | Self::Compaction { parent_id, .. }
        | Self::BranchSummary { parent_id, .. }
        | Self::ModelChange { parent_id, .. }
        | Self::ThinkingLevelChange { parent_id, .. }
        | Self::Custom { parent_id, .. }) = self;
        *parent_id
    }

    /// Returns the entry's creation timestamp string.
    pub fn timestamp(&self) -> &str {
        let (Self::Message { timestamp, .. }
        | Self::Compaction { timestamp, .. }
        | Self::BranchSummary { timestamp, .. }
        | Self::ModelChange { timestamp, .. }
        | Self::ThinkingLevelChange { timestamp, .. }
        | Self::Custom { timestamp, .. }) = self;
        timestamp.as_str()
    }

    /// Builds a `Message` entry with the `User` role, a fresh id and the
    /// current timestamp; no tool calls or tool result are attached.
    pub fn user_message(parent_id: Option<Uuid>, content: impl Into<String>) -> Self {
        let id = Uuid::new_v4();
        let timestamp = iso_now();
        Self::Message {
            id,
            parent_id,
            timestamp,
            role: SessionMessageRole::User,
            content: content.into(),
            tool_calls: None,
            tool_result: None,
        }
    }

    /// Builds a `Message` entry with the `Assistant` role, a fresh id and the
    /// current timestamp, carrying the given tool calls (if any).
    pub fn assistant_message(
        parent_id: Option<Uuid>,
        content: impl Into<String>,
        tool_calls: Option<Vec<SessionToolCall>>,
    ) -> Self {
        let id = Uuid::new_v4();
        let timestamp = iso_now();
        Self::Message {
            id,
            parent_id,
            timestamp,
            role: SessionMessageRole::Assistant,
            content: content.into(),
            tool_calls,
            tool_result: None,
        }
    }

    /// Builds a `Compaction` entry with a fresh id and the current timestamp;
    /// `details` starts out as `None`.
    pub fn compaction(
        parent_id: Option<Uuid>,
        summary: impl Into<String>,
        first_kept_entry_id: Uuid,
        messages_compacted: usize,
        tokens_saved: i64,
    ) -> Self {
        let id = Uuid::new_v4();
        let timestamp = iso_now();
        Self::Compaction {
            id,
            parent_id,
            timestamp,
            summary: summary.into(),
            first_kept_entry_id,
            messages_compacted,
            tokens_saved,
            details: None,
        }
    }

    /// Builds a `BranchSummary` entry with a fresh id and the current timestamp.
    pub fn branch_summary(
        parent_id: Option<Uuid>,
        from_entry_id: Uuid,
        summary: impl Into<String>,
        entries_summarized: usize,
        label: Option<String>,
    ) -> Self {
        let id = Uuid::new_v4();
        let timestamp = iso_now();
        Self::BranchSummary {
            id,
            parent_id,
            timestamp,
            from_entry_id,
            summary: summary.into(),
            entries_summarized,
            label,
        }
    }

    /// Builds a `ModelChange` entry with a fresh id and the current timestamp.
    pub fn model_change(
        parent_id: Option<Uuid>,
        provider: impl Into<String>,
        model_id: impl Into<String>,
    ) -> Self {
        let id = Uuid::new_v4();
        let timestamp = iso_now();
        Self::ModelChange {
            id,
            parent_id,
            timestamp,
            provider: provider.into(),
            model_id: model_id.into(),
        }
    }

    /// Builds a `Custom` extension entry with a fresh id and the current timestamp.
    pub fn custom(
        parent_id: Option<Uuid>,
        custom_type: impl Into<String>,
        data: Option<Value>,
    ) -> Self {
        let id = Uuid::new_v4();
        let timestamp = iso_now();
        Self::Custom {
            id,
            parent_id,
            timestamp,
            custom_type: custom_type.into(),
            data,
        }
    }
}
/// Who produced a `SessionEntry::Message`.
///
/// Serialized in snake_case: `"user"`, `"assistant"`, `"tool_result"`, `"system"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SessionMessageRole {
/// Role set by `SessionEntry::user_message`.
User,
/// Role set by `SessionEntry::assistant_message`.
Assistant,
/// NOTE(review): presumably marks a message carrying a tool's output — no constructor visible here sets it; confirm.
ToolResult,
/// NOTE(review): presumably a system prompt/instruction message — no constructor visible here sets it; confirm.
System,
}
/// A tool invocation recorded in the `tool_calls` list of a `SessionEntry::Message`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionToolCall {
/// Identifier of this call.
/// NOTE(review): presumably what `SessionToolResult::tool_call_id` refers back to — confirm.
pub id: String,
/// Name of the tool being called.
pub name: String,
/// Call arguments as an arbitrary JSON value.
pub arguments: Value,
}
/// The outcome of a tool invocation, recorded in the `tool_result` field of a
/// `SessionEntry::Message`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionToolResult {
/// Identifier of the call this result belongs to.
/// NOTE(review): presumably matches `SessionToolCall::id` — confirm.
pub tool_call_id: String,
/// Name of the tool that produced the result.
pub tool_name: String,
/// Result payload as text.
pub content: String,
/// Whether the result represents an error.
pub is_error: bool,
}
/// A full session: header + ordered list of entries forming a tree.
///
/// The tree structure supports forking: entries share `parent_id` links,
/// and the "active branch" is determined by following from the leaf
/// back to the root. The leaf used for that walk is the last element of
/// `entries`, i.e. the most recently pushed entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
/// Session metadata (id, version, creation time, parent session, name).
pub header: SessionHeader,
/// All entries in insertion order; `push` appends to the end.
pub entries: Vec<SessionEntry>,
}
impl Session {
    /// Creates an empty session with a freshly generated header.
    pub fn new() -> Self {
        Self {
            header: SessionHeader::new(),
            entries: Vec::new(),
        }
    }

    /// Returns the session with its header name set (builder style).
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.header = self.header.with_name(name);
        self
    }

    /// Append an entry to the session.
    pub fn push(&mut self, entry: SessionEntry) {
        self.entries.push(entry);
    }

    /// Get the last entry's id (used as parent_id for the next entry).
    ///
    /// Returns `None` when the session has no entries.
    pub fn last_entry_id(&self) -> Option<Uuid> {
        self.entries.last().map(SessionEntry::id)
    }

    /// Get all entries on the active branch (from root to leaf).
    ///
    /// The leaf is the most recently pushed entry. The walk follows
    /// `parent_id` links and stops at an entry without a parent, at a parent
    /// id that is not present in the session, or when a link leads back to an
    /// entry already visited. The last case only happens with corrupt data;
    /// stopping there keeps a malformed session file from hanging the caller.
    pub fn active_branch(&self) -> Vec<&SessionEntry> {
        let leaf_id = match self.entries.last() {
            Some(leaf) => leaf.id(),
            None => return Vec::new(),
        };

        let index = self.entry_index();
        let mut visited = std::collections::HashSet::new();
        let mut branch = Vec::new();
        let mut cursor = Some(leaf_id);
        while let Some(id) = cursor {
            // `insert` returns false for an id we have already walked: a cycle.
            if !visited.insert(id) {
                break;
            }
            match index.get(&id) {
                Some(&entry) => {
                    branch.push(entry);
                    cursor = entry.parent_id();
                }
                None => break,
            }
        }
        // Collected leaf-to-root; callers expect root-to-leaf.
        branch.reverse();
        branch
    }

    /// Get the entries on the active branch that feed the LLM context:
    /// `Message` and `Compaction` entries, in root-to-leaf order.
    pub fn active_messages(&self) -> Vec<&SessionEntry> {
        self.active_branch()
            .into_iter()
            .filter(|e| matches!(e, SessionEntry::Message { .. } | SessionEntry::Compaction { .. }))
            .collect()
    }

    /// Find all children of a given entry (for tree navigation).
    ///
    /// Children are returned in insertion order.
    pub fn children_of(&self, parent_id: Uuid) -> Vec<&SessionEntry> {
        self.entries
            .iter()
            .filter(|e| e.parent_id() == Some(parent_id))
            .collect()
    }

    /// Get all leaf entries (entries with no children).
    ///
    /// Leaves are returned in insertion order.
    pub fn leaves(&self) -> Vec<&SessionEntry> {
        let parent_ids: std::collections::HashSet<Uuid> = self
            .entries
            .iter()
            .filter_map(|e| e.parent_id())
            .collect();
        self.entries
            .iter()
            .filter(|e| !parent_ids.contains(&e.id()))
            .collect()
    }

    /// Count total entries.
    pub fn entry_count(&self) -> usize {
        self.entries.len()
    }

    /// Fork from a specific entry into a new session.
    ///
    /// The new session gets a fresh header whose `parent_session` is this
    /// session's id, plus a copy of every entry stored up to and including
    /// the fork point.
    ///
    /// NOTE(review): the copy is taken by insertion order, not by ancestry,
    /// so entries of sibling branches pushed before the fork point are
    /// copied too. `active_branch` on the fork still resolves to the
    /// root-to-fork-point path because the fork entry is the last one.
    ///
    /// # Errors
    ///
    /// Returns `AiError::Config` when `fork_entry_id` is not in this session.
    pub fn fork_from(&self, fork_entry_id: Uuid) -> AiResult<Session> {
        let fork_idx = self
            .entries
            .iter()
            .position(|e| e.id() == fork_entry_id)
            .ok_or_else(|| {
                AiError::Config(format!("fork entry {fork_entry_id} not found in session"))
            })?;

        let mut new_session = Session::new();
        new_session.header = new_session.header.with_parent(self.header.id);
        // Copy entries up to and including the fork point.
        new_session
            .entries
            .extend_from_slice(&self.entries[..=fork_idx]);
        Ok(new_session)
    }

    /// Find the common ancestor of two entries.
    ///
    /// Returns the first id on the chain from `id_a` towards the root that
    /// also lies on the chain of `id_b`. An entry counts as its own ancestor.
    /// Returns `None` when the chains do not meet.
    pub fn common_ancestor(&self, id_a: Uuid, id_b: Uuid) -> Option<Uuid> {
        // Build the lookup table once and reuse it for both walks.
        let index = self.entry_index();
        let ancestors_b: std::collections::HashSet<Uuid> =
            Self::chain_from(&index, id_b).into_iter().collect();
        Self::chain_from(&index, id_a)
            .into_iter()
            .find(|ancestor| ancestors_b.contains(ancestor))
    }

    /// Get the chain of ancestor IDs from an entry back to the root.
    ///
    /// The chain starts with `entry_id` itself.
    fn ancestor_chain(&self, entry_id: Uuid) -> Vec<Uuid> {
        Self::chain_from(&self.entry_index(), entry_id)
    }

    /// Maps every entry id to its entry in a single pass.
    ///
    /// If ids are duplicated the first occurrence wins, which matches what a
    /// front-to-back linear search over `entries` would find.
    fn entry_index(&self) -> std::collections::HashMap<Uuid, &SessionEntry> {
        let mut index = std::collections::HashMap::with_capacity(self.entries.len());
        for entry in &self.entries {
            index.entry(entry.id()).or_insert(entry);
        }
        index
    }

    /// Walks `parent_id` links starting at `entry_id` and collects the ids seen.
    ///
    /// An id is recorded even when no entry with that id exists (the walk
    /// simply ends there). The walk also ends when it meets an id it has
    /// already recorded, so cyclic parent links cannot loop forever.
    fn chain_from(
        index: &std::collections::HashMap<Uuid, &SessionEntry>,
        entry_id: Uuid,
    ) -> Vec<Uuid> {
        let mut seen = std::collections::HashSet::new();
        let mut chain = Vec::new();
        let mut cursor = Some(entry_id);
        while let Some(id) = cursor {
            if !seen.insert(id) {
                break;
            }
            chain.push(id);
            cursor = index.get(&id).and_then(|entry| entry.parent_id());
        }
        chain
    }
}
/// Tunable settings for session compaction.
#[derive(Debug, Clone)]
pub struct CompactionOptions {
    /// Custom instructions for the compaction LLM call.
    pub custom_instructions: Option<String>,
    /// Reserve this many tokens for the prompt + LLM response.
    pub reserve_tokens: i64,
    /// Keep this many recent message pairs untouched.
    pub keep_recent_pairs: usize,
    /// Whether to generate branch summaries for forked branches.
    pub branch_summarization: bool,
}

impl Default for CompactionOptions {
    /// Defaults: no custom instructions, 16 384 reserved tokens, the four
    /// most recent message pairs kept, branch summarization enabled.
    fn default() -> Self {
        const DEFAULT_RESERVE_TOKENS: i64 = 16_384;
        const DEFAULT_KEEP_RECENT_PAIRS: usize = 4;

        CompactionOptions {
            custom_instructions: None,
            reserve_tokens: DEFAULT_RESERVE_TOKENS,
            keep_recent_pairs: DEFAULT_KEEP_RECENT_PAIRS,
            branch_summarization: true,
        }
    }
}
/// Returns the current UTC time as an ISO 8601 timestamp with second
/// precision, formatted `YYYY-MM-DDTHH:MM:SSZ`.
///
/// If the system clock reports a time before the Unix epoch the function
/// falls back to `1970-01-01T00:00:00Z` instead of failing.
fn iso_now() -> String {
    SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|elapsed| format_unix_timestamp(elapsed.as_secs()))
        .unwrap_or_else(|_| "1970-01-01T00:00:00Z".to_string())
}

/// Formats whole seconds since the Unix epoch as `YYYY-MM-DDTHH:MM:SSZ` (UTC).
///
/// The calendar date uses the proleptic Gregorian "civil from days"
/// conversion. Leap years are derived from the real calendar year, so 2024
/// has a 29 February, while 1900 and 2100 do not.
fn format_unix_timestamp(secs: u64) -> String {
    const SECS_PER_DAY: u64 = 86_400;
    const DAYS_PER_ERA: u64 = 146_097; // days in one 400-year Gregorian cycle
    const EPOCH_SHIFT_DAYS: u64 = 719_468; // days from 0000-03-01 to 1970-01-01

    let days = secs / SECS_PER_DAY;
    let secs_of_day = secs % SECS_PER_DAY;

    // Count days from 0000-03-01: with years starting in March the leap day
    // is the last day of the year, which keeps the month arithmetic regular.
    let shifted = days + EPOCH_SHIFT_DAYS;
    let era = shifted / DAYS_PER_ERA;
    let day_of_era = shifted % DAYS_PER_ERA;
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March (0 = March, ..., 11 = February).
    let march_based_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_based_month + 2) / 5 + 1;
    let month = if march_based_month < 10 {
        march_based_month + 3
    } else {
        march_based_month - 9
    };
    // January and February belong to the following calendar year.
    let year = year_of_era + era * 400 + u64::from(month <= 2);

    let hour = secs_of_day / 3_600;
    let minute = (secs_of_day % 3_600) / 60;
    let second = secs_of_day % 60;

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year, month, day, hour, minute, second,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A linear two-message exchange is counted and lies entirely on the active branch.
    #[test]
    fn test_session_basic() {
        let mut session = Session::new();

        let greeting = SessionEntry::user_message(None, "Hello");
        let greeting_id = greeting.id();
        session.push(greeting);
        session.push(SessionEntry::assistant_message(
            Some(greeting_id),
            "Hi there!",
            None,
        ));

        assert_eq!(session.entry_count(), 2);
        assert_eq!(session.active_branch().len(), 2);
    }

    /// Forking at the second of three chained entries copies two entries and
    /// records the source session as the parent.
    #[test]
    fn test_session_fork() {
        let mut session = Session::new();

        let first = SessionEntry::user_message(None, "First");
        let first_id = first.id();
        session.push(first);

        let reply = SessionEntry::assistant_message(Some(first_id), "Reply 1", None);
        let reply_id = reply.id();
        session.push(reply);

        session.push(SessionEntry::user_message(Some(reply_id), "Second"));

        // Fork from the assistant reply.
        let forked = session.fork_from(reply_id).unwrap();
        assert_eq!(forked.entry_count(), 2);
        assert_eq!(forked.header.parent_session, Some(session.header.id));
    }

    /// Two replies hanging off the same root are both leaves.
    #[test]
    fn test_session_leaves() {
        let mut session = Session::new();

        let root = SessionEntry::user_message(None, "Root");
        let root_id = root.id();
        session.push(root);

        // Two children branching from root.
        for text in ["Branch A", "Branch B"] {
            session.push(SessionEntry::assistant_message(Some(root_id), text, None));
        }

        assert_eq!(session.leaves().len(), 2);
    }
}