550 lines
15 KiB
Rust
550 lines
15 KiB
Rust
use std::time::SystemTime;
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
use uuid::Uuid;
|
|
|
|
use crate::error::{AiError, AiResult};
|
|
|
|
/// Current session file format version.
///
/// Stamped into [`SessionHeader::version`] by [`SessionHeader::new`].
pub const SESSION_VERSION: u32 = 2;
|
|
|
|
/// Session metadata header.
///
/// Identifies a session and records its lineage. Normally built with
/// [`SessionHeader::new`] followed by the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionHeader {
    /// Session file format version; `new` sets it to [`SESSION_VERSION`].
    pub version: u32,
    /// Unique session id; `new` generates a random v4 UUID.
    pub id: Uuid,
    /// Creation time as a UTC `YYYY-MM-DDTHH:MM:SSZ` string (set by `new`).
    pub created_at: String,
    /// Id of the session this one was forked from (set by `Session::fork_from`), if any.
    pub parent_session: Option<Uuid>,
    /// Optional human-readable session name.
    pub name: Option<String>,
}
|
|
|
|
impl SessionHeader {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
version: SESSION_VERSION,
|
|
id: Uuid::new_v4(),
|
|
created_at: iso_now(),
|
|
parent_session: None,
|
|
name: None,
|
|
}
|
|
}
|
|
|
|
pub fn with_parent(mut self, parent: Uuid) -> Self {
|
|
self.parent_session = Some(parent);
|
|
self
|
|
}
|
|
|
|
pub fn with_name(mut self, name: impl Into<String>) -> Self {
|
|
self.name = Some(name.into());
|
|
self
|
|
}
|
|
}
|
|
|
|
/// Typed session entry — each entry in a session transcript is one of these variants.
///
/// Serialized as an internally tagged object: a `"type"` field holds the
/// snake_case variant name (e.g. `"message"`, `"branch_summary"`), alongside
/// the variant's fields. Every variant carries:
/// - `id`: this entry's unique id,
/// - `parent_id`: the entry it follows in the session tree (`None` for a root),
/// - `timestamp`: creation time as a UTC `YYYY-MM-DDTHH:MM:SSZ` string when
///   built through the `SessionEntry` constructors.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SessionEntry {
    /// A user or assistant message.
    Message {
        id: Uuid,
        parent_id: Option<Uuid>,
        timestamp: String,
        /// Who produced the message.
        role: SessionMessageRole,
        /// Message text.
        content: String,
        /// Tool invocations requested by this message; omitted from JSON when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<SessionToolCall>>,
        /// Result of a tool invocation carried by this message; omitted from JSON when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        tool_result: Option<SessionToolResult>,
    },

    /// A context compaction event (older messages summarized).
    Compaction {
        id: Uuid,
        parent_id: Option<Uuid>,
        timestamp: String,
        /// Summary text standing in for the compacted messages.
        summary: String,
        /// Id of the first entry that was kept rather than summarized.
        first_kept_entry_id: Uuid,
        /// Number of messages folded into `summary`.
        messages_compacted: usize,
        /// Token savings reported by the compaction (signed; units assumed to be LLM tokens — TODO confirm).
        tokens_saved: i64,
        /// Optional free-form extra data; omitted from JSON when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        details: Option<Value>,
    },

    /// A branch summary (created when forking from a different point in the tree).
    BranchSummary {
        id: Uuid,
        parent_id: Option<Uuid>,
        timestamp: String,
        /// Entry the summarized branch started from.
        from_entry_id: Uuid,
        /// Summary text of the abandoned branch.
        summary: String,
        /// Number of entries covered by `summary`.
        entries_summarized: usize,
        /// Optional display label; omitted from JSON when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        label: Option<String>,
    },

    /// Model change during a session.
    ModelChange {
        id: Uuid,
        parent_id: Option<Uuid>,
        timestamp: String,
        /// Provider name, as supplied by the caller.
        provider: String,
        /// Model identifier within that provider.
        model_id: String,
    },

    /// Thinking level change during a session.
    ThinkingLevelChange {
        id: Uuid,
        parent_id: Option<Uuid>,
        timestamp: String,
        /// New thinking level as a free-form string (no validation here).
        level: String,
    },

    /// Custom extension data (not sent to LLM).
    Custom {
        id: Uuid,
        parent_id: Option<Uuid>,
        timestamp: String,
        /// Extension-defined discriminator for `data`.
        custom_type: String,
        // NOTE(review): unlike the other optional fields, `data` has no
        // `skip_serializing_if`, so `None` is written out as `"data": null`.
        /// Extension payload, if any.
        data: Option<Value>,
    },
}
|
|
|
|
impl SessionEntry {
|
|
pub fn id(&self) -> Uuid {
|
|
match self {
|
|
Self::Message { id, .. }
|
|
| Self::Compaction { id, .. }
|
|
| Self::BranchSummary { id, .. }
|
|
| Self::ModelChange { id, .. }
|
|
| Self::ThinkingLevelChange { id, .. }
|
|
| Self::Custom { id, .. } => *id,
|
|
}
|
|
}
|
|
|
|
pub fn parent_id(&self) -> Option<Uuid> {
|
|
match self {
|
|
Self::Message { parent_id, .. }
|
|
| Self::Compaction { parent_id, .. }
|
|
| Self::BranchSummary { parent_id, .. }
|
|
| Self::ModelChange { parent_id, .. }
|
|
| Self::ThinkingLevelChange { parent_id, .. }
|
|
| Self::Custom { parent_id, .. } => *parent_id,
|
|
}
|
|
}
|
|
|
|
pub fn timestamp(&self) -> &str {
|
|
match self {
|
|
Self::Message { timestamp, .. }
|
|
| Self::Compaction { timestamp, .. }
|
|
| Self::BranchSummary { timestamp, .. }
|
|
| Self::ModelChange { timestamp, .. }
|
|
| Self::ThinkingLevelChange { timestamp, .. }
|
|
| Self::Custom { timestamp, .. } => timestamp,
|
|
}
|
|
}
|
|
|
|
/// Create a user message entry.
|
|
pub fn user_message(
|
|
parent_id: Option<Uuid>,
|
|
content: impl Into<String>,
|
|
) -> Self {
|
|
Self::Message {
|
|
id: Uuid::new_v4(),
|
|
parent_id,
|
|
timestamp: iso_now(),
|
|
role: SessionMessageRole::User,
|
|
content: content.into(),
|
|
tool_calls: None,
|
|
tool_result: None,
|
|
}
|
|
}
|
|
|
|
/// Create an assistant message entry.
|
|
pub fn assistant_message(
|
|
parent_id: Option<Uuid>,
|
|
content: impl Into<String>,
|
|
tool_calls: Option<Vec<SessionToolCall>>,
|
|
) -> Self {
|
|
Self::Message {
|
|
id: Uuid::new_v4(),
|
|
parent_id,
|
|
timestamp: iso_now(),
|
|
role: SessionMessageRole::Assistant,
|
|
content: content.into(),
|
|
tool_calls,
|
|
tool_result: None,
|
|
}
|
|
}
|
|
|
|
/// Create a compaction entry.
|
|
pub fn compaction(
|
|
parent_id: Option<Uuid>,
|
|
summary: impl Into<String>,
|
|
first_kept_entry_id: Uuid,
|
|
messages_compacted: usize,
|
|
tokens_saved: i64,
|
|
) -> Self {
|
|
Self::Compaction {
|
|
id: Uuid::new_v4(),
|
|
parent_id,
|
|
timestamp: iso_now(),
|
|
summary: summary.into(),
|
|
first_kept_entry_id,
|
|
messages_compacted,
|
|
tokens_saved,
|
|
details: None,
|
|
}
|
|
}
|
|
|
|
/// Create a branch summary entry.
|
|
pub fn branch_summary(
|
|
parent_id: Option<Uuid>,
|
|
from_entry_id: Uuid,
|
|
summary: impl Into<String>,
|
|
entries_summarized: usize,
|
|
label: Option<String>,
|
|
) -> Self {
|
|
Self::BranchSummary {
|
|
id: Uuid::new_v4(),
|
|
parent_id,
|
|
timestamp: iso_now(),
|
|
from_entry_id,
|
|
summary: summary.into(),
|
|
entries_summarized,
|
|
label,
|
|
}
|
|
}
|
|
|
|
/// Create a model change entry.
|
|
pub fn model_change(
|
|
parent_id: Option<Uuid>,
|
|
provider: impl Into<String>,
|
|
model_id: impl Into<String>,
|
|
) -> Self {
|
|
Self::ModelChange {
|
|
id: Uuid::new_v4(),
|
|
parent_id,
|
|
timestamp: iso_now(),
|
|
provider: provider.into(),
|
|
model_id: model_id.into(),
|
|
}
|
|
}
|
|
|
|
/// Create a custom extension entry.
|
|
pub fn custom(
|
|
parent_id: Option<Uuid>,
|
|
custom_type: impl Into<String>,
|
|
data: Option<Value>,
|
|
) -> Self {
|
|
Self::Custom {
|
|
id: Uuid::new_v4(),
|
|
parent_id,
|
|
timestamp: iso_now(),
|
|
custom_type: custom_type.into(),
|
|
data,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Author of a [`SessionEntry::Message`].
///
/// Serialized in snake_case: `"user"`, `"assistant"`, `"tool_result"`, `"system"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SessionMessageRole {
    User,
    Assistant,
    ToolResult,
    System,
}
|
|
|
|
/// A tool invocation recorded on an assistant message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionToolCall {
    /// Call id; presumably echoed back in `SessionToolResult::tool_call_id` — TODO confirm.
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Tool arguments as an arbitrary JSON value.
    pub arguments: Value,
}
|
|
|
|
/// The outcome of a tool invocation, carried on a message entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionToolResult {
    /// Id of the tool call this result answers (assumed to match `SessionToolCall::id`).
    pub tool_call_id: String,
    /// Name of the tool that produced the result.
    pub tool_name: String,
    /// Tool output text.
    pub content: String,
    /// Whether the tool reported a failure.
    pub is_error: bool,
}
|
|
|
|
/// A full session: header + ordered list of entries forming a tree.
///
/// The tree structure supports forking: entries share `parent_id` links,
/// and the "active branch" is found by starting at the leaf (the last
/// entry in `entries`) and following `parent_id` links back to the root.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Session metadata (version, id, lineage, name).
    pub header: SessionHeader,
    /// All entries in insertion order; may hold several branches.
    pub entries: Vec<SessionEntry>,
}
|
|
|
|
impl Session {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
header: SessionHeader::new(),
|
|
entries: Vec::new(),
|
|
}
|
|
}
|
|
|
|
pub fn with_name(mut self, name: impl Into<String>) -> Self {
|
|
self.header = self.header.with_name(name);
|
|
self
|
|
}
|
|
|
|
/// Append an entry to the session.
|
|
pub fn push(&mut self, entry: SessionEntry) {
|
|
self.entries.push(entry);
|
|
}
|
|
|
|
/// Get the last entry's id (used as parent_id for the next entry).
|
|
pub fn last_entry_id(&self) -> Option<Uuid> {
|
|
self.entries.last().map(|e| e.id())
|
|
}
|
|
|
|
/// Get all entries on the active branch (from root to leaf).
|
|
pub fn active_branch(&self) -> Vec<&SessionEntry> {
|
|
if self.entries.is_empty() {
|
|
return Vec::new();
|
|
}
|
|
|
|
let mut branch = Vec::new();
|
|
let mut current_id = Some(self.entries.last().unwrap().id());
|
|
|
|
while let Some(id) = current_id {
|
|
if let Some(entry) = self.entries.iter().find(|e| e.id() == id) {
|
|
branch.push(entry);
|
|
current_id = entry.parent_id();
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
|
|
branch.reverse();
|
|
branch
|
|
}
|
|
|
|
/// Get all message entries on the active branch (for LLM context).
|
|
pub fn active_messages(&self) -> Vec<&SessionEntry> {
|
|
self.active_branch()
|
|
.into_iter()
|
|
.filter(|e| {
|
|
matches!(
|
|
e,
|
|
SessionEntry::Message { .. }
|
|
| SessionEntry::Compaction { .. }
|
|
)
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
/// Find all children of a given entry (for tree navigation).
|
|
pub fn children_of(&self, parent_id: Uuid) -> Vec<&SessionEntry> {
|
|
self.entries
|
|
.iter()
|
|
.filter(|e| e.parent_id() == Some(parent_id))
|
|
.collect()
|
|
}
|
|
|
|
/// Get all leaf entries (entries with no children).
|
|
pub fn leaves(&self) -> Vec<&SessionEntry> {
|
|
let parent_ids: std::collections::HashSet<Uuid> =
|
|
self.entries.iter().filter_map(|e| e.parent_id()).collect();
|
|
|
|
self.entries
|
|
.iter()
|
|
.filter(|e| !parent_ids.contains(&e.id()))
|
|
.collect()
|
|
}
|
|
|
|
/// Count total entries.
|
|
pub fn entry_count(&self) -> usize {
|
|
self.entries.len()
|
|
}
|
|
|
|
/// Fork from a specific entry, creating entries that belong to a new branch.
|
|
/// Returns the entries that should be in the new branch (from root to fork point).
|
|
pub fn fork_from(&self, fork_entry_id: Uuid) -> AiResult<Session> {
|
|
let fork_idx = self
|
|
.entries
|
|
.iter()
|
|
.position(|e| e.id() == fork_entry_id)
|
|
.ok_or_else(|| {
|
|
AiError::Config(format!(
|
|
"fork entry {fork_entry_id} not found in session"
|
|
))
|
|
})?;
|
|
|
|
let mut new_session = Session::new();
|
|
new_session.header = new_session.header.with_parent(self.header.id);
|
|
|
|
// Copy entries up to and including the fork point
|
|
for entry in &self.entries[..=fork_idx] {
|
|
new_session.entries.push(entry.clone());
|
|
}
|
|
|
|
Ok(new_session)
|
|
}
|
|
|
|
/// Find the common ancestor of two entries.
|
|
pub fn common_ancestor(&self, id_a: Uuid, id_b: Uuid) -> Option<Uuid> {
|
|
let ancestors_a = self.ancestor_chain(id_a);
|
|
let ancestors_b: std::collections::HashSet<Uuid> =
|
|
self.ancestor_chain(id_b).into_iter().collect();
|
|
|
|
for ancestor in ancestors_a {
|
|
if ancestors_b.contains(&ancestor) {
|
|
return Some(ancestor);
|
|
}
|
|
}
|
|
None
|
|
}
|
|
|
|
/// Get the chain of ancestor IDs from an entry back to the root.
|
|
fn ancestor_chain(&self, entry_id: Uuid) -> Vec<Uuid> {
|
|
let mut chain = Vec::new();
|
|
let mut current_id = Some(entry_id);
|
|
|
|
while let Some(id) = current_id {
|
|
chain.push(id);
|
|
current_id = self
|
|
.entries
|
|
.iter()
|
|
.find(|e| e.id() == id)
|
|
.and_then(|e| e.parent_id());
|
|
}
|
|
|
|
chain
|
|
}
|
|
}
|
|
|
|
/// Options for session compaction.
#[derive(Debug, Clone)]
pub struct CompactionOptions {
    /// Custom instructions for the compaction LLM call (none by default).
    pub custom_instructions: Option<String>,
    /// Reserve this many tokens for the prompt + LLM response.
    pub reserve_tokens: i64,
    /// Keep this many recent message pairs untouched.
    pub keep_recent_pairs: usize,
    /// Whether to generate branch summaries for forked branches.
    pub branch_summarization: bool,
}

impl Default for CompactionOptions {
    /// No custom instructions, 16 KiB-worth (16 384) of reserved tokens,
    /// the 4 most recent message pairs kept, branch summarization enabled.
    fn default() -> Self {
        let reserve_tokens: i64 = 16 * 1024;
        let keep_recent_pairs: usize = 4;
        let branch_summarization = true;

        Self {
            custom_instructions: None,
            reserve_tokens,
            keep_recent_pairs,
            branch_summarization,
        }
    }
}
|
|
|
|
/// Current UTC time as an ISO 8601 string, e.g. `2024-02-29T13:05:09Z`.
///
/// Falls back to `1970-01-01T00:00:00Z` if the system clock reports a time
/// before the Unix epoch.
fn iso_now() -> String {
    let secs = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);
    format_iso8601_utc(secs)
}

/// Format whole seconds since the Unix epoch as `YYYY-MM-DDTHH:MM:SSZ` (UTC).
fn format_iso8601_utc(unix_secs: u64) -> String {
    let (year, month, day) = civil_from_days(unix_secs / 86_400);
    let secs_of_day = unix_secs % 86_400;
    let hour = secs_of_day / 3_600;
    let minute = (secs_of_day % 3_600) / 60;
    let second = secs_of_day % 60;
    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}Z")
}

/// Convert days since 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)`, with `month` in `1..=12` and `day` in `1..=31`.
///
/// Howard Hinnant's `civil_from_days`, specialised to non-negative inputs.
/// It works in a calendar whose years start on 1 March, so the leap day
/// falls at the very end of the year and month lengths follow a linear
/// formula.
fn civil_from_days(days_since_epoch: u64) -> (u64, u64, u64) {
    // Re-base from 1970-01-01 to 0000-03-01, which is 719_468 days earlier.
    let z = days_since_epoch + 719_468;
    // The Gregorian calendar repeats every 400 years = 146_097 days.
    let era = z / 146_097;
    let day_of_era = z - era * 146_097; // [0, 146_096]
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365; // [0, 399]
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100); // [0, 365], from 1 March
    let shifted_month = (5 * day_of_year + 2) / 153; // [0, 11], 0 = March
    let day = day_of_year - (153 * shifted_month + 2) / 5 + 1;
    let month = if shifted_month < 10 {
        shifted_month + 3
    } else {
        shifted_month - 9
    };
    // January and February belong to the next civil year in this scheme.
    let year = era * 400 + year_of_era + u64::from(month <= 2);
    (year, month, day)
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a session by pushing `entries` in order.
    fn session_with(entries: Vec<SessionEntry>) -> Session {
        let mut session = Session::new();
        for entry in entries {
            session.push(entry);
        }
        session
    }

    #[test]
    fn test_session_basic() {
        let hello = SessionEntry::user_message(None, "Hello");
        let hi = SessionEntry::assistant_message(Some(hello.id()), "Hi there!", None);
        let session = session_with(vec![hello, hi]);

        assert_eq!(session.entry_count(), 2);
        assert_eq!(session.active_branch().len(), 2);
    }

    #[test]
    fn test_session_fork() {
        let first = SessionEntry::user_message(None, "First");
        let reply = SessionEntry::assistant_message(Some(first.id()), "Reply 1", None);
        let fork_point = reply.id();
        let second = SessionEntry::user_message(Some(fork_point), "Second");
        let session = session_with(vec![first, reply, second]);

        // Forking at the reply leaves the trailing "Second" message behind.
        let forked = session.fork_from(fork_point).unwrap();
        assert_eq!(forked.entry_count(), 2);
        assert_eq!(forked.header.parent_session, Some(session.header.id));
    }

    #[test]
    fn test_session_leaves() {
        let root = SessionEntry::user_message(None, "Root");
        let root_id = root.id();
        // Two sibling replies hang off the same root, so there are two leaves.
        let branch_a = SessionEntry::assistant_message(Some(root_id), "Branch A", None);
        let branch_b = SessionEntry::assistant_message(Some(root_id), "Branch B", None);
        let session = session_with(vec![root, branch_a, branch_b]);

        assert_eq!(session.leaves().len(), 2);
    }
}
|