gitdataai/libs/agent/chat/state.rs

218 lines
6.3 KiB
Rust

//! Agent state machine — tracks lifecycle of a single AI agent invocation.
//!
//! States: Idle → Thinking → ToolCall → Thinking → ... → Answering | Error
//! The Thinking ↔ ToolCall cycle repeats until the agent produces a final answer;
//! once the configured maximum tool depth is reached, further tool calls are rejected.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Current phase of an agent's execution lifecycle.
///
/// `Answering` and `Error` are the terminal phases (see [`AgentState::is_terminal`]).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AgentState {
    /// Agent is idle, waiting for input.
    Idle,
    /// Agent is reasoning/thinking (may produce thinking chunks).
    Thinking {
        /// Timestamp (UTC) at which this thinking phase began.
        started_at: DateTime<Utc>,
        /// Number of tool calls that had completed when this phase was entered.
        tool_depth: u32,
    },
    /// Agent is executing a tool call.
    ToolCall {
        /// Name of the tool, as passed to [`AgentRuntime::start_tool_call`].
        tool_name: String,
        /// Timestamp (UTC) at which the tool call began.
        started_at: DateTime<Utc>,
    },
    /// Agent is returning the final answer (terminal).
    Answering {
        /// Running length of the answer text appended so far through
        /// [`AgentRuntime::append_answer`]; starts at 0. This is a count, not the
        /// content itself.
        content_chars: u64,
        /// Timestamp (UTC) at which answering began.
        started_at: DateTime<Utc>,
    },
    /// Agent encountered a non-recoverable error (terminal).
    Error {
        /// Failure description supplied to [`AgentRuntime::fail`].
        message: String,
        /// Number of tool calls that had completed when the failure was recorded.
        tool_depth: u32,
    },
}
impl AgentState {
pub fn is_terminal(&self) -> bool {
matches!(
self,
AgentState::Answering { .. } | AgentState::Error { .. }
)
}
pub fn is_idle(&self) -> bool {
matches!(self, AgentState::Idle)
}
pub fn current_phase(&self) -> &'static str {
match self {
AgentState::Idle => "idle",
AgentState::Thinking { .. } => "thinking",
AgentState::ToolCall { .. } => "tool_call",
AgentState::Answering { .. } => "answering",
AgentState::Error { .. } => "error",
}
}
}
/// State machine driving a single agent invocation through [`AgentState`] transitions.
///
/// Besides the current state, it tracks how many tool calls have completed so
/// that new tool calls can be rejected once `max_tool_depth` is reached.
#[derive(Debug)]
pub struct AgentRuntime {
    /// Current lifecycle phase.
    state: AgentState,
    /// Upper bound on the number of tool calls allowed to complete.
    max_tool_depth: u32,
    /// Number of tool calls completed since the last `start_thinking`.
    current_depth: u32,
}
impl AgentRuntime {
    /// Creates a runtime in the [`AgentState::Idle`] state.
    ///
    /// `max_tool_depth` is the number of tool calls that may complete before
    /// [`start_tool_call`](Self::start_tool_call) starts rejecting new ones.
    pub fn new(max_tool_depth: u32) -> Self {
        Self {
            state: AgentState::Idle,
            max_tool_depth,
            current_depth: 0,
        }
    }

    /// Returns the current lifecycle state.
    pub fn state(&self) -> &AgentState {
        &self.state
    }

    /// Transitions `Idle → Thinking` and resets the tool depth to zero.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the runtime is not `Idle`. The check is a
    /// `debug_assert!`, so release builds skip it and overwrite whatever state is current.
    pub fn start_thinking(&mut self) {
        debug_assert!(self.state.is_idle(), "must be Idle to start thinking");
        self.current_depth = 0;
        self.state = AgentState::Thinking {
            started_at: Utc::now(),
            tool_depth: 0,
        };
    }

    /// Transitions `Thinking → ToolCall`.
    ///
    /// # Errors
    ///
    /// Returns `Err` and leaves the state unchanged in two cases:
    /// - the runtime is not `Thinking`;
    /// - `max_tool_depth` tool calls have already completed.
    pub fn start_tool_call(&mut self, tool_name: String) -> Result<(), &'static str> {
        if !matches!(self.state, AgentState::Thinking { .. }) {
            return Err("must be Thinking to start tool call");
        }
        if self.current_depth >= self.max_tool_depth {
            return Err("max tool depth reached");
        }
        self.state = AgentState::ToolCall {
            tool_name,
            started_at: Utc::now(),
        };
        Ok(())
    }

    /// Transitions `ToolCall → Thinking` once the tool result is available and
    /// increments the tool depth.
    ///
    /// # Errors
    ///
    /// Returns `Err`, leaving the state unchanged, if the runtime is not in `ToolCall`.
    pub fn complete_tool_call(&mut self) -> Result<(), &'static str> {
        if !matches!(self.state, AgentState::ToolCall { .. }) {
            return Err("must be ToolCall to complete");
        }
        // Cannot overflow. `ToolCall` is only entered through `start_tool_call`,
        // which requires `current_depth < max_tool_depth <= u32::MAX`.
        self.current_depth += 1;
        self.state = AgentState::Thinking {
            started_at: Utc::now(),
            tool_depth: self.current_depth,
        };
        Ok(())
    }

    /// Transitions to `Answering` (terminal) with an answer length of zero.
    ///
    /// No state check is performed. Calling this from any state, including an
    /// existing `Answering` or `Error`, replaces that state.
    pub fn start_answer(&mut self) {
        self.state = AgentState::Answering {
            content_chars: 0,
            started_at: Utc::now(),
        };
    }

    /// Adds the length of `content` to the running answer length.
    ///
    /// Length is counted in `char`s (Unicode scalar values), which matches the
    /// `content_chars` field name. `str::len` would count UTF-8 bytes and
    /// over-count any non-ASCII text.
    ///
    /// Does nothing unless the runtime is `Answering`.
    pub fn append_answer(&mut self, content: &str) {
        if let AgentState::Answering { content_chars, .. } = &mut self.state {
            // usize -> u64 is lossless on all supported (<= 64-bit) targets.
            *content_chars += content.chars().count() as u64;
        }
    }

    /// Transitions to `Error` (terminal) from any state, recording `message`
    /// and the number of tool calls completed so far.
    pub fn fail(&mut self, message: String) {
        self.state = AgentState::Error {
            message,
            tool_depth: self.current_depth,
        };
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_starts_idle() {
        let rt = AgentRuntime::new(10);
        assert!(rt.state().is_idle());
        assert_eq!(rt.state().current_phase(), "idle");
    }

    #[test]
    fn test_idle_to_thinking() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        assert_eq!(rt.state().current_phase(), "thinking");
        assert!(!rt.state().is_terminal());
    }

    #[test]
    fn test_thinking_to_tool_call_and_back() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_tool_call("search".into()).unwrap();
        assert_eq!(rt.state().current_phase(), "tool_call");
        rt.complete_tool_call().unwrap();
        assert_eq!(rt.state().current_phase(), "thinking");
    }

    #[test]
    fn test_tool_depth_recorded_in_thinking() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_tool_call("search".into()).unwrap();
        rt.complete_tool_call().unwrap();
        match rt.state() {
            AgentState::Thinking { tool_depth, .. } => assert_eq!(*tool_depth, 1),
            other => panic!("expected Thinking state, got {:?}", other),
        }
    }

    #[test]
    fn test_thinking_to_answer() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_answer();
        assert_eq!(rt.state().current_phase(), "answering");
        assert!(rt.state().is_terminal());
    }

    #[test]
    fn test_append_answer_tracks_chars() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_answer();
        rt.append_answer("hello");
        if let AgentState::Answering { content_chars, .. } = rt.state() {
            assert_eq!(*content_chars, 5);
        } else {
            panic!("expected Answering state");
        }
    }

    #[test]
    fn test_append_answer_accumulates() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_answer();
        rt.append_answer("hello");
        rt.append_answer(" world");
        match rt.state() {
            AgentState::Answering { content_chars, .. } => assert_eq!(*content_chars, 11),
            other => panic!("expected Answering state, got {:?}", other),
        }
    }

    #[test]
    fn test_append_answer_ignored_outside_answering() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.append_answer("ignored");
        assert_eq!(rt.state().current_phase(), "thinking");
    }

    #[test]
    fn test_error_is_terminal() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.fail("something broke".into());
        assert_eq!(rt.state().current_phase(), "error");
        assert!(rt.state().is_terminal());
    }

    #[test]
    fn test_error_records_tool_depth() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_tool_call("tool".into()).unwrap();
        rt.complete_tool_call().unwrap();
        rt.fail("boom".into());
        assert_eq!(
            rt.state(),
            &AgentState::Error {
                message: "boom".into(),
                tool_depth: 1,
            }
        );
    }

    #[test]
    fn test_transition_from_wrong_state() {
        let mut rt = AgentRuntime::new(10);
        // Can't start tool call from Idle
        assert!(rt.start_tool_call("tool".into()).is_err());
        // Can't complete tool call from Idle
        assert!(rt.complete_tool_call().is_err());
    }

    #[test]
    fn test_tool_call_rejected_while_in_tool_call() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_tool_call("a".into()).unwrap();
        assert_eq!(
            rt.start_tool_call("b".into()),
            Err("must be Thinking to start tool call")
        );
        // The rejected call must not disturb the in-flight tool call.
        assert_eq!(rt.state().current_phase(), "tool_call");
    }

    #[test]
    fn test_complete_rejected_while_thinking() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        assert_eq!(rt.complete_tool_call(), Err("must be ToolCall to complete"));
        assert_eq!(rt.state().current_phase(), "thinking");
    }

    #[test]
    fn test_max_depth_rejected() {
        let mut rt = AgentRuntime::new(2);
        rt.start_thinking();
        rt.start_tool_call("tool1".into()).unwrap();
        rt.complete_tool_call().unwrap();
        rt.start_tool_call("tool2".into()).unwrap();
        rt.complete_tool_call().unwrap();
        assert!(rt.start_tool_call("tool3".into()).is_err());
    }

    #[test]
    fn test_max_depth_rejection_keeps_state() {
        let mut rt = AgentRuntime::new(1);
        rt.start_thinking();
        rt.start_tool_call("tool1".into()).unwrap();
        rt.complete_tool_call().unwrap();
        assert_eq!(
            rt.start_tool_call("tool2".into()),
            Err("max tool depth reached")
        );
        assert_eq!(rt.state().current_phase(), "thinking");
    }

    #[test]
    fn test_zero_max_depth_rejects_first_tool_call() {
        let mut rt = AgentRuntime::new(0);
        rt.start_thinking();
        assert_eq!(
            rt.start_tool_call("tool".into()),
            Err("max tool depth reached")
        );
    }
}