use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use futures::StreamExt;
use rig::agent::AgentBuilder;
use rig::client::CompletionClient;
use rig::streaming::StreamingPrompt;
use rig::tool::ToolDyn;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};

use super::config::AgentConfig;
use super::error_classifier::{
    classify_error, retry_policy_for, should_switch_to_fallback,
};
use super::events::{AgentEvent, EventSink};
use super::helpers::{build_input_string, estimate_tokens};
use super::hooks::{HookChain, HookLlmResponse, HookMessage, ToolCallOutcome, ToolGuardrailDecision};
use super::iteration_budget::IterationBudget;
use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord};
use super::RigStreamChunk;
use crate::client::AiClient;
use crate::error::{AiError, AiResult};

/// How tool calls from a single assistant turn are executed.
///
/// The default is [`ToolExecutionMode::Parallel`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ToolExecutionMode {
    /// Execute tool calls one at a time.
    Sequential,
    /// Execute tool calls concurrently (after sequential preflight).
    #[default]
    Parallel,
}

/// Callback type for steering messages (injected mid-run).
///
/// The returned future resolves to the messages to inject; each one is
/// appended to the running input as a `User:` line by the agent loop.
pub type SteeringFn = Arc<
    dyn Fn() -> Pin<Box<dyn Future<Output = Vec<String>> + Send>> + Send + Sync,
>;

/// Callback type for follow-up messages (injected after agent would stop).
///
/// An empty result lets the agent loop finish.
pub type FollowUpFn = Arc<
    dyn Fn() -> Pin<Box<dyn Future<Output = Vec<String>> + Send>> + Send + Sync,
>;

/// Callback to decide whether the agent should stop after a turn.
pub type ShouldStopFn = Arc<
    dyn Fn(&TurnContext) -> bool + Send + Sync,
>;

/// Callback to prepare/modify state before the next turn.
///
/// Resolving to `None` leaves the configuration untouched.
pub type PrepareNextTurnFn = Arc<
    dyn Fn(&TurnContext) -> Pin<Box<dyn Future<Output = Option<TurnUpdate>> + Send>> + Send + Sync,
>;

/// Context passed to `should_stop` and `prepare_next_turn` callbacks.
#[derive(Debug, Clone)]
pub struct TurnContext {
    /// Zero-based index of the turn that just finished.
    pub turn_index: usize,
    /// Assistant text produced by the turn (may be empty).
    pub assistant_text: String,
    /// Number of tool calls recorded during the turn.
    pub tool_call_count: usize,
    /// Input tokens accumulated over the run so far.
    pub total_input_tokens: u64,
    /// Output tokens accumulated over the run so far.
    pub total_output_tokens: u64,
    /// Model that was configured when the turn ran.
    pub model_name: String,
}

/// Replacement state for the next turn (returned by `prepare_next_turn`).
///
/// Every field is optional; `None` keeps the current configuration value.
#[derive(Debug, Clone, Default)]
pub struct TurnUpdate {
    /// Model to switch to for the following turns.
    pub model: Option<String>,
    /// Sampling temperature override.
    // NOTE(review): `f64` inferred from rig's `AgentBuilder::temperature` — confirm against `AgentConfig`.
    pub temperature: Option<f64>,
    /// Completion token cap override.
    // NOTE(review): `u64` inferred from rig's `AgentBuilder::max_tokens` — confirm against `AgentConfig`.
    pub max_completion_tokens: Option<u64>,
}

/// Extended agent loop configuration, adding steering/follow-up/lifecycle
/// hooks on top of the base `AgentConfig`.
pub struct AgentLoopConfig {
    /// Base agent configuration.
    pub config: AgentConfig,
    /// Requested tool execution strategy.
    pub tool_execution_mode: ToolExecutionMode,
    /// Source of steering messages, if any.
    pub get_steering_messages: Option<SteeringFn>,
    /// Source of follow-up messages, if any.
    pub get_follow_up_messages: Option<FollowUpFn>,
    /// Predicate evaluated after each turn to stop early.
    pub should_stop_after_turn: Option<ShouldStopFn>,
    /// Callback that may adjust model/temperature/token cap between turns.
    pub prepare_next_turn: Option<PrepareNextTurnFn>,
    /// Sink receiving `AgentEvent`s emitted by the loop.
    pub event_sink: Option<EventSink>,
}

impl AgentLoopConfig {
    /// Creates a loop configuration with the default tool execution mode and
    /// no callbacks or event sink.
    #[must_use]
    pub fn new(config: AgentConfig) -> Self {
        Self {
            config,
            tool_execution_mode: ToolExecutionMode::default(),
            get_steering_messages: None,
            get_follow_up_messages: None,
            should_stop_after_turn: None,
            prepare_next_turn: None,
            event_sink: None,
        }
    }

    /// Sets the tool execution mode.
    #[must_use]
    pub fn with_tool_execution_mode(mut self, mode: ToolExecutionMode) -> Self {
        self.tool_execution_mode = mode;
        self
    }

    /// Sets the steering-message callback.
    #[must_use]
    pub fn with_steering_messages(mut self, f: SteeringFn) -> Self {
        self.get_steering_messages = Some(f);
        self
    }

    /// Sets the follow-up-message callback.
    #[must_use]
    pub fn with_follow_up_messages(mut self, f: FollowUpFn) -> Self {
        self.get_follow_up_messages = Some(f);
        self
    }

    /// Sets the stop predicate evaluated after each turn.
    #[must_use]
    pub fn with_should_stop(mut self, f: ShouldStopFn) -> Self {
        self.should_stop_after_turn = Some(f);
        self
    }

    /// Sets the callback that prepares state for the next turn.
    #[must_use]
    pub fn with_prepare_next_turn(mut self, f: PrepareNextTurnFn) -> Self {
        self.prepare_next_turn = Some(f);
        self
    }

    /// Sets the event sink.
    #[must_use]
    pub fn with_event_sink(mut self, sink: EventSink) -> Self {
        self.event_sink = Some(sink);
        self
    }
}

/// Enhanced agent with loop controls (steering, follow-up, model switching).
pub struct EnhancedAgent { pub client: AiClient, pub loop_config: AgentLoopConfig, pub hooks: HookChain, } impl EnhancedAgent { pub fn new(client: AiClient, loop_config: AgentLoopConfig) -> AiResult { loop_config.config.validate()?; Ok(Self { client, loop_config, hooks: HookChain::empty(), }) } pub fn with_hooks(mut self, hooks: HookChain) -> Self { self.hooks = hooks; self } pub fn config(&self) -> &AgentConfig { &self.loop_config.config } /// Run the enhanced agent loop, returning a chunk receiver and a join handle. #[allow(clippy::too_many_lines)] pub fn run( &self, request: AgentRequest, tools: Vec>, ) -> ( mpsc::Receiver, tokio::task::JoinHandle>, ) { let (tx, rx) = mpsc::channel::(256); let config = self.loop_config.config.clone(); let tool_execution_mode = self.loop_config.tool_execution_mode; let steering_fn = self.loop_config.get_steering_messages.clone(); let follow_up_fn = self.loop_config.get_follow_up_messages.clone(); let should_stop = self.loop_config.should_stop_after_turn.clone(); let prepare_next = self.loop_config.prepare_next_turn.clone(); let event_sink = self.loop_config.event_sink.clone(); let client = self.client.llm_client().clone(); let hooks = self.hooks.clone(); let filtered_tools: Vec> = tools .into_iter() .filter(|tool| config.is_tool_exposed(&tool.name())) .collect(); let handle = tokio::spawn(async move { run_enhanced_loop( client, config, request, filtered_tools, tool_execution_mode, steering_fn, follow_up_fn, should_stop, prepare_next, event_sink, hooks, tx, ) .await }); (rx, handle) } } #[allow(clippy::too_many_lines, clippy::too_many_arguments)] async fn run_enhanced_loop( client: rig::providers::openai::Client, mut config: AgentConfig, request: AgentRequest, tools: Vec>, _tool_execution_mode: ToolExecutionMode, steering_fn: Option, follow_up_fn: Option, should_stop: Option, prepare_next: Option, event_sink: Option, hooks: HookChain, tx: mpsc::Sender, ) -> AiResult { let cancellation = request.cancellation_token.clone(); let 
timeout = request.timeout; let mut budget = IterationBudget::new(config.iteration_budget); let mut all_steps: Vec = Vec::new(); let mut total_input_tokens: u64 = 0; let mut total_output_tokens: u64 = 0; let mut turn_index: usize = 0; // Session start hook if let Some(ctx) = &request.run_context { let _ = hooks.run_session_start(ctx).await; } // Emit agent start event if let Some(sink) = &event_sink { sink.emit(AgentEvent::AgentStart); } // Build the initial input let input = build_input_string(&request); let mut current_input = input.clone(); let estimated_input_tokens = estimate_tokens(¤t_input); if let Some(limit) = config.max_total_tokens_per_run && estimated_input_tokens > limit as u64 { return Err(AiError::TokenBudgetExceeded { estimated: estimated_input_tokens, limit, }); } // Outer loop: handles follow-up messages after agent would stop loop { // Inner loop: tool call turns + steering messages let mut pending_steering: Vec = if let Some(f) = &steering_fn { f().await } else { Vec::new() }; loop { // Check cancellation if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) { let _ = tx.send(RigStreamChunk::Failed { error: "cancelled".to_string() }).await; if let Some(sink) = &event_sink { sink.emit(AgentEvent::ErrorClassified { category: "cancelled".to_string(), message: "cancelled by caller".to_string(), will_retry: false, retry_delay_ms: None, }); } return Err(AiError::Response("agent run cancelled".to_string())); } // Inject steering messages if any if !pending_steering.is_empty() { let count = pending_steering.len(); for msg in &pending_steering { current_input.push_str(&format!("\nUser: {msg}\n")); } if let Some(sink) = &event_sink { sink.emit(AgentEvent::SteeringMessagesInjected { count }); } pending_steering.clear(); } // Emit turn start if let Some(sink) = &event_sink { sink.emit(AgentEvent::TurnStart { turn_index }); } let _ = tx.send(RigStreamChunk::TextDelta { index: 0, content: String::new(), // placeholder for turn boundary detection 
}).await; // Run one LLM turn with retry let turn_result = run_single_turn( &client, &config, ¤t_input, &tools, &mut budget, &cancellation, timeout, &hooks, &event_sink, &tx, ) .await; match turn_result { Ok(turn_output) => { total_input_tokens += turn_output.input_tokens; total_output_tokens += turn_output.output_tokens; // Collect step let tool_call_count = turn_output.tool_calls.len(); if !turn_output.tool_calls.is_empty() || !turn_output.assistant_text.is_empty() { all_steps.push(AgentStep { index: all_steps.len(), assistant: (!turn_output.assistant_text.is_empty()) .then_some(turn_output.assistant_text.clone()), reasoning_content: None, tool_calls: turn_output.tool_calls, reflection: None, }); } // Emit turn end if let Some(sink) = &event_sink { sink.emit(AgentEvent::TurnEnd { turn_index, assistant_text: Some(turn_output.assistant_text.clone()), tool_call_count, }); } // Check should_stop let turn_ctx = TurnContext { turn_index, assistant_text: turn_output.assistant_text.clone(), tool_call_count, total_input_tokens, total_output_tokens, model_name: config.model.clone(), }; if let Some(stop_fn) = &should_stop { if stop_fn(&turn_ctx) { info!(turn_index, "agent stopped by should_stop callback"); break; } } // Prepare next turn (may switch model) if let Some(prep_fn) = &prepare_next { if let Some(update) = prep_fn(&turn_ctx).await { if let Some(new_model) = update.model { if let Some(sink) = &event_sink { sink.emit(AgentEvent::ModelSwitched { from_model: config.model.clone(), to_model: new_model.clone(), reason: "prepare_next_turn".to_string(), }); } config.model = new_model; } if let Some(temp) = update.temperature { config.temperature = Some(temp); } if let Some(max_tok) = update.max_completion_tokens { config.max_completion_tokens = Some(max_tok); } } } turn_index += 1; // If no tool calls, this turn is done if tool_call_count == 0 { break; } // Otherwise, continue with tool results as new input current_input = turn_output.assistant_text.clone(); } Err(e) => { 
// Error classification and retry with fallback let category = classify_error(&e, None); let policy = retry_policy_for(&category, config.retry_max_attempts, config.retry_base_delay_ms); if let Some(sink) = &event_sink { sink.emit(AgentEvent::ErrorClassified { category: format!("{category:?}"), message: e.to_string(), will_retry: policy.switch_to_fallback || policy.max_attempts > 0, retry_delay_ms: Some(policy.base_delay.as_millis() as u64), }); } if should_switch_to_fallback(&category) { if let Some(fallback_model) = &config.fallback_model { info!( from_model = %config.model, to_model = %fallback_model, "switching to fallback model due to error" ); if let Some(sink) = &event_sink { sink.emit(AgentEvent::ModelSwitched { from_model: config.model.clone(), to_model: fallback_model.clone(), reason: format!("fallback: {category:?}"), }); } config.model = fallback_model.clone(); // Retry with the fallback model let retry_result = run_single_turn( &client, &config, ¤t_input, &tools, &mut budget, &cancellation, timeout, &hooks, &event_sink, &tx, ) .await; match retry_result { Ok(turn_output) => { total_input_tokens += turn_output.input_tokens; total_output_tokens += turn_output.output_tokens; let tc_count = turn_output.tool_calls.len(); let has_tools = tc_count > 0; let has_text = !turn_output.assistant_text.is_empty(); let assistant = turn_output.assistant_text; if has_tools || has_text { all_steps.push(AgentStep { index: all_steps.len(), assistant: has_text.then_some(assistant.clone()), reasoning_content: None, tool_calls: turn_output.tool_calls, reflection: None, }); } turn_index += 1; if !has_tools { break; } current_input = assistant; continue; } Err(retry_err) => { let _ = tx .send(RigStreamChunk::Failed { error: retry_err.to_string(), }) .await; if let Some(ctx) = &request.run_context { let _ = hooks.run_session_end(ctx, false).await; } return Err(retry_err); } } } } // Non-retryable or no fallback let _ = tx .send(RigStreamChunk::Failed { error: e.to_string(), }) 
.await; if let Some(ctx) = &request.run_context { let _ = hooks.run_session_end(ctx, false).await; } return Err(e); } } } // Check for follow-up messages let follow_ups: Vec = if let Some(f) = &follow_up_fn { f().await } else { Vec::new() }; if follow_ups.is_empty() { break; } // Inject follow-up messages and continue the outer loop let count = follow_ups.len(); for msg in &follow_ups { current_input.push_str(&format!("\nUser: {msg}\n")); } if let Some(sink) = &event_sink { sink.emit(AgentEvent::FollowUpMessagesInjected { count }); } } // Build final output let output = all_steps .last() .and_then(|s| s.assistant.clone()) .unwrap_or_default(); if let Some(sink) = &event_sink { sink.emit(AgentEvent::AgentEnd { messages: Vec::new(), total_input_tokens, total_output_tokens, }); } let _ = tx .send(RigStreamChunk::Final { content: output.clone(), input_tokens: total_input_tokens, output_tokens: total_output_tokens, }) .await; if let Some(ctx) = &request.run_context { let _ = hooks.run_session_end(ctx, true).await; } info!( turns = turn_index, steps = all_steps.len(), total_input_tokens, total_output_tokens, "enhanced agent loop completed" ); Ok(AgentResult { output, steps: all_steps, expert_outputs: Vec::new(), input_tokens: total_input_tokens as i64, output_tokens: total_output_tokens as i64, }) } /// Output from a single LLM turn (one assistant response + its tool calls). struct TurnOutput { assistant_text: String, tool_calls: Vec, input_tokens: u64, output_tokens: u64, } /// Run a single LLM turn with streaming, handling the stream parsing and /// tool call collection. 
#[allow(clippy::too_many_arguments)] async fn run_single_turn( client: &rig::providers::openai::Client, config: &AgentConfig, input: &str, _tools: &[Box], budget: &mut IterationBudget, cancellation: &Option, timeout: Option, hooks: &HookChain, event_sink: &Option, tx: &mpsc::Sender, ) -> AiResult { if !budget.consume() { return Err(AiError::Response("iteration budget exhausted".to_string())); } let model = client.completion_model(&config.model); let mut agent_builder = AgentBuilder::new(model) .preamble(&config.system_prompt) .default_max_turns(1); // Single turn, we manage the loop // Note: we can't easily pass tools here for single-turn since // rig's multi_turn handles tool execution internally. // For the enhanced loop, we rely on rig's built-in tool execution // within a single turn. The parallel/sequential mode is controlled // by the event-level hooks. if let Some(temp) = config.temperature { agent_builder = agent_builder.temperature(temp); } if let Some(mt) = config.max_completion_tokens { agent_builder = agent_builder.max_tokens(mt); } let agent = agent_builder.build(); // Pre-LLM hook if !hooks.is_empty() { let hook_messages = vec![HookMessage { role: "user".to_string(), content: Some(input.to_string()), tool_calls: None, tool_call_id: None, }]; let _ = hooks.run_pre_llm_call(&hook_messages, &[]).await; } let stream_future = agent .stream_prompt(input) .with_history(Vec::::new()) .multi_turn(config.max_iterations); let stream = if let Some(dur) = timeout { match tokio::time::timeout(dur, stream_future).await { Ok(stream) => stream, Err(_) => { return Err(AiError::Timeout { seconds: dur.as_secs(), }); } } } else { stream_future.await }; tokio::pin!(stream); let mut assistant_text = String::new(); let mut tool_calls: Vec = Vec::new(); let mut delta_index = 0usize; let mut _accumulated_output_chars: usize = 0; let mut input_tokens: u64 = 0; let mut output_tokens: u64 = 0; while let Some(item) = stream.next().await { if cancellation.as_ref().is_some_and(|ct| 
ct.is_cancelled()) { return Err(AiError::Response("cancelled".to_string())); } match item { Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( rig::streaming::StreamedAssistantContent::Text(text), )) => { _accumulated_output_chars += text.text.chars().count(); assistant_text.push_str(&text.text); if let Some(sink) = &event_sink { sink.emit(AgentEvent::MessageTextDelta { index: delta_index, delta: text.text.clone(), }); } let _ = tx .send(RigStreamChunk::TextDelta { index: delta_index, content: text.text.clone(), }) .await; delta_index += 1; } Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( rig::streaming::StreamedAssistantContent::Reasoning(reasoning), )) => { for part in &reasoning.content { if let rig::completion::message::ReasoningContent::Text { text, .. } = part { _accumulated_output_chars += text.chars().count(); if let Some(sink) = &event_sink { sink.emit(AgentEvent::MessageThinkingDelta { index: delta_index, delta: text.clone(), }); } let _ = tx .send(RigStreamChunk::Thinking { index: delta_index, content: text.clone(), }) .await; delta_index += 1; } } } Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( rig::streaming::StreamedAssistantContent::ReasoningDelta { reasoning, .. }, )) => { _accumulated_output_chars += reasoning.chars().count(); if let Some(sink) = &event_sink { sink.emit(AgentEvent::MessageThinkingDelta { index: delta_index, delta: reasoning.clone(), }); } let _ = tx .send(RigStreamChunk::Thinking { index: delta_index, content: reasoning.clone(), }) .await; delta_index += 1; } Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( rig::streaming::StreamedAssistantContent::ToolCall { tool_call, .. 
}, )) => { let args = match &tool_call.function.arguments { serde_json::Value::String(s) => s.clone(), v => serde_json::to_string(v).unwrap_or_default(), }; _accumulated_output_chars += args.chars().count(); let tool_name = tool_call.function.name.clone(); let tool_args: serde_json::Value = serde_json::from_str(&args).unwrap_or_default(); // Pre-tool-call guardrail hook if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await { match decision { ToolGuardrailDecision::Allow => {} ToolGuardrailDecision::Block { reason } => { if let Some(sink) = &event_sink { sink.emit(AgentEvent::ToolExecutionEnd { tool_call_id: tool_call.id.clone(), tool_name: tool_name.clone(), output: None, error: Some(reason.clone()), elapsed_ms: 0, }); } let _ = tx .send(RigStreamChunk::ToolCallFinished { tool_call_id: tool_call.id.clone(), tool_name: tool_name.clone(), output: format!("blocked: {reason}"), error: Some(reason), }) .await; tool_calls.push(ToolCallRecord { id: tool_call.id.clone(), name: tool_name, arguments: tool_args, output: None, error: Some("blocked by guardrail".to_string()), elapsed_ms: None, }); continue; } ToolGuardrailDecision::RequireApproval { message } => { tool_calls.push(ToolCallRecord { id: tool_call.id.clone(), name: tool_name.clone(), arguments: tool_args, output: None, error: Some(format!("requires approval: {message}")), elapsed_ms: None, }); continue; } } } if let Some(sink) = &event_sink { sink.emit(AgentEvent::ToolExecutionStart { tool_call_id: tool_call.id.clone(), tool_name: tool_name.clone(), arguments: tool_args.clone(), }); } let _ = tx .send(RigStreamChunk::ToolCallStarted { tool_call_id: tool_call.id.clone(), tool_name: tool_name.clone(), arguments: args.clone(), }) .await; tool_calls.push(ToolCallRecord { id: tool_call.id.clone(), name: tool_name, arguments: tool_args, output: None, error: None, elapsed_ms: None, }); } Ok(rig::agent::MultiTurnStreamItem::StreamUserItem( rig::streaming::StreamedUserContent::ToolResult { 
tool_result, .. }, )) => { let content = super::helpers::tool_result_content_to_string(&tool_result.content); _accumulated_output_chars += content.chars().count(); let tool_name = tool_calls .last() .map(|tc| tc.name.clone()) .unwrap_or_default(); if let Some(last) = tool_calls.last_mut() && last.id == tool_result.id { last.output = Some(serde_json::from_str(&content).unwrap_or_default()); } if let Some(sink) = &event_sink { sink.emit(AgentEvent::ToolExecutionEnd { tool_call_id: tool_result.id.clone(), tool_name: tool_name.clone(), output: Some(serde_json::Value::String(content.clone())), error: None, elapsed_ms: 0, }); } let _ = tx .send(RigStreamChunk::ToolCallFinished { tool_call_id: tool_result.id.clone(), tool_name, output: content.clone(), error: None, }) .await; if !hooks.is_empty() { let outcome = ToolCallOutcome { name: tool_result.id.clone(), arguments: serde_json::Value::Null, output: Some(serde_json::Value::String(content)), error: None, elapsed_ms: 0, }; let _ = hooks.run_post_tool_call(&outcome).await; } } Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => { let usage = resp.usage(); input_tokens = usage.input_tokens; output_tokens = usage.output_tokens; if !hooks.is_empty() { let hook_response = HookLlmResponse { content: Some(assistant_text.clone()), tool_calls: None, input_tokens, output_tokens, finish_reason: None, }; let _ = hooks.run_post_llm_call(&hook_response).await; } } Err(e) => { warn!(error = %e, "turn stream error"); return Err(AiError::Api(format!("{e}"))); } _ => {} } } Ok(TurnOutput { assistant_text, tool_calls, input_tokens, output_tokens, }) }