Add RoomAiService as the central dispatcher that selects the execution path based on the mode (react/chat/cot/reflexion/rewoo) and the streaming vs. non-streaming preference. Replace the monolithic ai_streaming with mode-aware dispatch and a dedicated streaming implementation.
218 lines
6.3 KiB
Rust
218 lines
6.3 KiB
Rust
//! Agent state machine — tracks lifecycle of a single AI agent invocation.
|
|
//!
|
|
//! States: Idle → Thinking → ToolCall → Thinking → ... → Answering | Error
|
|
//! The Thinking ↔ ToolCall cycle repeats until max tool depth or final answer.
|
|
|
|
use chrono::{DateTime, Utc};
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
/// Current phase of an agent's execution lifecycle.
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
pub enum AgentState {
|
|
/// Agent is idle, waiting for input
|
|
Idle,
|
|
/// Agent is reasoning/thinking (may produce thinking chunks)
|
|
Thinking {
|
|
started_at: DateTime<Utc>,
|
|
tool_depth: u32,
|
|
},
|
|
/// Agent is executing a tool call
|
|
ToolCall {
|
|
tool_name: String,
|
|
started_at: DateTime<Utc>,
|
|
},
|
|
/// Agent is returning the final answer
|
|
Answering {
|
|
/// Accumulated answer content so far
|
|
content_chars: u64,
|
|
started_at: DateTime<Utc>,
|
|
},
|
|
/// Agent encountered a non-recoverable error
|
|
Error {
|
|
message: String,
|
|
tool_depth: u32,
|
|
},
|
|
}
|
|
|
|
impl AgentState {
|
|
pub fn is_terminal(&self) -> bool {
|
|
matches!(self, AgentState::Answering { .. } | AgentState::Error { .. })
|
|
}
|
|
|
|
pub fn is_idle(&self) -> bool {
|
|
matches!(self, AgentState::Idle)
|
|
}
|
|
|
|
pub fn current_phase(&self) -> &'static str {
|
|
match self {
|
|
AgentState::Idle => "idle",
|
|
AgentState::Thinking { .. } => "thinking",
|
|
AgentState::ToolCall { .. } => "tool_call",
|
|
AgentState::Answering { .. } => "answering",
|
|
AgentState::Error { .. } => "error",
|
|
}
|
|
}
|
|
}
|
|
|
|
/// State machine for agent lifecycle transitions.
|
|
pub struct AgentRuntime {
|
|
state: AgentState,
|
|
max_tool_depth: u32,
|
|
current_depth: u32,
|
|
}
|
|
|
|
impl AgentRuntime {
|
|
pub fn new(max_tool_depth: u32) -> Self {
|
|
Self {
|
|
state: AgentState::Idle,
|
|
max_tool_depth,
|
|
current_depth: 0,
|
|
}
|
|
}
|
|
|
|
pub fn state(&self) -> &AgentState {
|
|
&self.state
|
|
}
|
|
|
|
/// Transition from Idle → Thinking
|
|
pub fn start_thinking(&mut self) {
|
|
debug_assert!(self.state.is_idle(), "must be Idle to start thinking");
|
|
self.current_depth = 0;
|
|
self.state = AgentState::Thinking {
|
|
started_at: Utc::now(),
|
|
tool_depth: 0,
|
|
};
|
|
}
|
|
|
|
/// Transition from Thinking → ToolCall (increments tool depth)
|
|
pub fn start_tool_call(&mut self, tool_name: String) -> Result<(), &'static str> {
|
|
if !matches!(self.state, AgentState::Thinking { .. }) {
|
|
return Err("must be Thinking to start tool call");
|
|
}
|
|
if self.current_depth >= self.max_tool_depth {
|
|
return Err("max tool depth reached");
|
|
}
|
|
self.state = AgentState::ToolCall {
|
|
tool_name,
|
|
started_at: Utc::now(),
|
|
};
|
|
Ok(())
|
|
}
|
|
|
|
/// Transition from ToolCall → Thinking (back after tool result)
|
|
pub fn complete_tool_call(&mut self) -> Result<(), &'static str> {
|
|
if !matches!(self.state, AgentState::ToolCall { .. }) {
|
|
return Err("must be ToolCall to complete");
|
|
}
|
|
self.current_depth += 1;
|
|
self.state = AgentState::Thinking {
|
|
started_at: Utc::now(),
|
|
tool_depth: self.current_depth,
|
|
};
|
|
Ok(())
|
|
}
|
|
|
|
/// Transition to Answering (terminal)
|
|
pub fn start_answer(&mut self) {
|
|
self.state = AgentState::Answering {
|
|
content_chars: 0,
|
|
started_at: Utc::now(),
|
|
};
|
|
}
|
|
|
|
pub fn append_answer(&mut self, content: &str) {
|
|
if let AgentState::Answering { content_chars, .. } = &mut self.state {
|
|
*content_chars += content.len() as u64;
|
|
}
|
|
}
|
|
|
|
/// Transition to Error (terminal)
|
|
pub fn fail(&mut self, message: String) {
|
|
self.state = AgentState::Error {
|
|
message,
|
|
tool_depth: self.current_depth,
|
|
};
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_starts_idle() {
        let rt = AgentRuntime::new(10);
        assert!(rt.state().is_idle());
        assert_eq!(rt.state().current_phase(), "idle");
    }

    #[test]
    fn test_idle_to_thinking() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        assert_eq!(rt.state().current_phase(), "thinking");
        assert!(!rt.state().is_terminal());
    }

    #[test]
    fn test_thinking_to_tool_call_and_back() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_tool_call("search".into()).unwrap();
        assert_eq!(rt.state().current_phase(), "tool_call");
        rt.complete_tool_call().unwrap();
        assert_eq!(rt.state().current_phase(), "thinking");
    }

    #[test]
    fn test_tool_call_records_name_and_depth() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_tool_call("search".into()).unwrap();
        assert!(matches!(
            rt.state(),
            AgentState::ToolCall { tool_name, .. } if tool_name.as_str() == "search"
        ));
        rt.complete_tool_call().unwrap();
        // Depth is incremented only once the tool call completes.
        assert!(matches!(rt.state(), AgentState::Thinking { tool_depth: 1, .. }));
    }

    #[test]
    fn test_thinking_to_answer() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_answer();
        assert_eq!(rt.state().current_phase(), "answering");
        assert!(rt.state().is_terminal());
    }

    #[test]
    fn test_append_answer_tracks_chars() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_answer();
        rt.append_answer("hello");
        if let AgentState::Answering { content_chars, .. } = rt.state() {
            assert_eq!(*content_chars, 5);
        } else {
            panic!("expected Answering state");
        }
    }

    #[test]
    fn test_append_answer_ignored_outside_answering() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.append_answer("ignored");
        assert!(matches!(rt.state(), AgentState::Thinking { tool_depth: 0, .. }));
    }

    #[test]
    fn test_error_is_terminal() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.fail("something broke".into());
        assert_eq!(rt.state().current_phase(), "error");
        assert!(rt.state().is_terminal());
    }

    #[test]
    fn test_error_records_message_and_depth() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        rt.start_tool_call("tool".into()).unwrap();
        rt.complete_tool_call().unwrap();
        rt.fail("boom".into());
        assert!(matches!(
            rt.state(),
            AgentState::Error { message, tool_depth: 1 } if message.as_str() == "boom"
        ));
    }

    #[test]
    fn test_transition_from_wrong_state() {
        let mut rt = AgentRuntime::new(10);
        // Can't start tool call from Idle
        assert!(rt.start_tool_call("tool".into()).is_err());
        // Can't complete tool call from Idle
        assert!(rt.complete_tool_call().is_err());
        // Rejected transitions leave the state untouched.
        assert!(rt.state().is_idle());
    }

    #[test]
    fn test_complete_tool_call_requires_tool_call_state() {
        let mut rt = AgentRuntime::new(10);
        rt.start_thinking();
        assert!(rt.complete_tool_call().is_err());
        assert!(matches!(rt.state(), AgentState::Thinking { tool_depth: 0, .. }));
    }

    #[test]
    fn test_max_depth_rejected() {
        let mut rt = AgentRuntime::new(2);
        rt.start_thinking();
        rt.start_tool_call("tool1".into()).unwrap();
        rt.complete_tool_call().unwrap();
        rt.start_tool_call("tool2".into()).unwrap();
        rt.complete_tool_call().unwrap();
        assert!(rt.start_tool_call("tool3".into()).is_err());
        // The rejected call must not leave the runtime in ToolCall.
        assert!(matches!(rt.state(), AgentState::Thinking { tool_depth: 2, .. }));
    }

    #[test]
    fn test_zero_max_depth_rejects_first_tool_call() {
        let mut rt = AgentRuntime::new(0);
        rt.start_thinking();
        assert!(rt.start_tool_call("tool".into()).is_err());
    }
}
|