use futures::StreamExt; use rig::agent::AgentBuilder; use rig::client::CompletionClient; use rig::streaming::StreamingPrompt; use rig::tool::ToolDyn; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use tracing::{info, warn}; use super::RigStreamChunk; use super::config::AgentConfig; use super::helpers::{build_input_string, check_token_budget, estimate_tokens}; use super::hooks::{ HookChain, HookLlmResponse, HookMessage, HookToolDef, ToolCallOutcome, ToolGuardrailDecision, }; use super::persistence::ActiveAgentRun; use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord}; use super::subagent::run_experts; use crate::client::AiClient; use crate::error::{AiError, AiResult}; pub struct RigAgent { pub client: AiClient, pub config: AgentConfig, pub hooks: HookChain, } impl RigAgent { pub fn new(client: AiClient, config: AgentConfig) -> AiResult { config.validate()?; Ok(Self { client, config, hooks: HookChain::empty(), }) } pub fn with_hooks(mut self, hooks: HookChain) -> Self { self.hooks = hooks; self } pub fn config(&self) -> &AgentConfig { &self.config } #[tracing::instrument(skip(self, tools), fields(model = %self.config.model))] pub async fn chat( &self, request: AgentRequest, tools: Vec>, ) -> AiResult { let (mut rx, handle) = self.run(request, tools); tokio::spawn(async move { while rx.recv().await.is_some() {} }); let result = handle.await.map_err(|_| { AiError::Response("agent task panicked".to_string()) })?; result.map(|r| r.output) } #[allow(clippy::too_many_lines)] pub fn run( &self, request: AgentRequest, tools: Vec>, ) -> ( tokio::sync::mpsc::Receiver, tokio::task::JoinHandle>, ) { let (tx, rx) = mpsc::channel::(256); let model_name = self.config.model.clone(); let max_iterations = self.config.max_iterations; let client = self.client.llm_client().clone(); let ai_client = self.client.clone(); let agent_config = self.config.clone(); let system_prompt = self.config.system_prompt.clone(); let temperature = self.config.temperature; let max_completion_tokens = self.config.max_completion_tokens; let max_total_tokens = self.config.max_total_tokens_per_run; let cancellation = request.cancellation_token.clone(); let timeout = request.timeout; let hooks = self.hooks.clone(); let filtered_tools: Vec> = tools .into_iter() .filter(|tool| self.config.is_tool_exposed(&tool.name())) .collect(); let handle = tokio::spawn(async move { execute_agent_run( client, model_name, system_prompt, request, filtered_tools, max_iterations, ai_client, agent_config, temperature, max_completion_tokens, max_total_tokens, cancellation, timeout, hooks, tx, ) .await }); (rx, handle) } } #[allow(clippy::too_many_lines, clippy::too_many_arguments)] async fn execute_agent_run( client: rig::providers::openai::Client, model_name: String, system_prompt: String, request: AgentRequest, tools: Vec>, max_iterations: usize, ai_client: AiClient, agent_config: AgentConfig, temperature: Option, max_completion_tokens: Option, max_total_tokens: Option, cancellation: Option, timeout: Option, hooks: HookChain, tx: mpsc::Sender, ) -> AiResult { if let Some(ref ctx) = request.run_context { let _ = hooks.run_session_start(ctx).await; } let model = client.completion_model(&model_name); let mut agent_builder = AgentBuilder::new(model) .preamble(&system_prompt) .tools(tools) .default_max_turns(max_iterations); if let Some(temp) = temperature { agent_builder = agent_builder.temperature(temp); } if let Some(mt) = max_completion_tokens { agent_builder = agent_builder.max_tokens(mt); } let agent = agent_builder.build(); let mut input = build_input_string(&request); // ---- SubAgent execution ---- let expert_outputs = if !request.experts.is_empty() { let run = ActiveAgentRun { conversation_id: request .run_context .as_ref() .and_then(|c| c.conversation_id), message_id: None, invocation_id: request .run_context .as_ref() .and_then(|c| c.invocation_id), session_id: request.run_context.as_ref().and_then(|c| c.session_id), user_id: request.run_context.as_ref().and_then(|c| c.user_id), started_at: std::time::Instant::now(), current_step: 0, }; let realtime = request .run_context .as_ref() .and_then(|c| c.realtime.as_ref()); // Notify frontend that subagents are starting. for expert in &request.experts { let _ = tx .send(RigStreamChunk::SubagentStarted { subagent_id: expert.id.clone(), role: expert.role.clone(), task: expert.task.clone(), }) .await; } match run_experts( &ai_client, &agent_config, &request.experts, realtime, &run, ) .await { Ok(outputs) => { for out in &outputs { let _ = tx .send(RigStreamChunk::SubagentCompleted { subagent_id: out.id.clone(), role: out.role.clone(), task: out.task.clone(), output: out.output.clone(), }) .await; input.push_str(&format!( "\n--- Subagent: {} (role: {}) ---\nTask: {}\nResult: {}\n", out.id, out.role, out.task, out.output )); } outputs } Err(e) => { warn!(error = %e, "subagent execution failed, continuing without expert inputs"); let _ = tx .send(RigStreamChunk::SubagentFailed { error: e.to_string(), }) .await; Vec::new() } } } else { Vec::new() }; let estimated_input_tokens = estimate_tokens(&input); if let Some(limit) = max_total_tokens && estimated_input_tokens > limit as u64 { return Err(AiError::TokenBudgetExceeded { estimated: estimated_input_tokens, limit, }); } if !hooks.is_empty() { let hook_messages: Vec = request .messages .iter() .map(|m| HookMessage { role: match m { super::request::AgentMessage::User(_) => "user".to_string(), super::request::AgentMessage::Assistant(_) => { "assistant".to_string() } }, content: match m { super::request::AgentMessage::User(c) => Some(c.clone()), super::request::AgentMessage::Assistant(c) => { Some(c.clone()) } }, tool_calls: None, tool_call_id: None, }) .collect(); let hook_tools: Vec = Vec::new(); let _ = hooks.run_pre_llm_call(&hook_messages, &hook_tools).await; } let stream_future = agent .stream_prompt(&input) .with_history(Vec::::new()) .multi_turn(max_iterations); let stream = if let Some(dur) = timeout { match tokio::time::timeout(dur, stream_future).await { Ok(stream) => stream, Err(_elapsed) => { let _ = tx .send(RigStreamChunk::Failed { error: format!( "agent timed out after {}s", dur.as_secs() ), }) .await; return Err(AiError::Timeout { seconds: dur.as_secs(), }); } } } else { stream_future.await }; tokio::pin!(stream); let mut steps = Vec::new(); let mut delta_index = 0usize; let mut current_step_tool_calls: Vec = Vec::new(); let mut current_step_assistant = String::new(); let mut current_step_reasoning = String::new(); let mut accumulated_output_chars: usize = 0; while let Some(item) = stream.next().await { if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) { let _ = tx .send(RigStreamChunk::Failed { error: "cancelled".to_string(), }) .await; return Err(AiError::Response("agent run cancelled".to_string())); } if let Some(limit) = max_total_tokens && check_token_budget( estimated_input_tokens, accumulated_output_chars, limit, ) { let _ = tx .send(RigStreamChunk::Failed { error: format!("token budget exceeded: limit {limit}"), }) .await; return Err(AiError::TokenBudgetExceeded { estimated: estimated_input_tokens + (accumulated_output_chars as f64 / 2.5).ceil() as u64, limit, }); } match item { Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( rig::streaming::StreamedAssistantContent::Text(text), )) => { accumulated_output_chars += text.text.chars().count(); current_step_assistant.push_str(&text.text); let _ = tx .send(RigStreamChunk::TextDelta { index: delta_index, content: text.text.clone(), }) .await; delta_index += 1; } Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( rig::streaming::StreamedAssistantContent::Reasoning(reasoning), )) => { for part in &reasoning.content { if let rig::completion::message::ReasoningContent::Text { text, .. } = part { accumulated_output_chars += text.chars().count(); current_step_reasoning.push_str(text); let _ = tx .send(RigStreamChunk::Thinking { index: delta_index, content: text.clone(), }) .await; delta_index += 1; } } } Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( rig::streaming::StreamedAssistantContent::ReasoningDelta { reasoning, .. }, )) => { accumulated_output_chars += reasoning.chars().count(); current_step_reasoning.push_str(&reasoning); let _ = tx .send(RigStreamChunk::Thinking { index: delta_index, content: reasoning.clone(), }) .await; delta_index += 1; } Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem( rig::streaming::StreamedAssistantContent::ToolCall { tool_call, internal_call_id: _, }, )) => { let args = match &tool_call.function.arguments { serde_json::Value::String(s) => s.clone(), v => serde_json::to_string(v).unwrap_or_default(), }; accumulated_output_chars += args.chars().count(); let tool_name = tool_call.function.name.clone(); let tool_args: serde_json::Value = serde_json::from_str(&args).unwrap_or_default(); if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await { match decision { ToolGuardrailDecision::Allow => {} ToolGuardrailDecision::Block { reason } => { let _ = tx .send(RigStreamChunk::ToolCallFinished { tool_call_id: tool_call.id.clone(), tool_name: tool_name.clone(), output: format!("blocked: {reason}"), error: Some(reason), }) .await; current_step_tool_calls.push(ToolCallRecord { id: tool_call.id.clone(), name: tool_name.clone(), arguments: tool_args.clone(), output: None, error: Some("blocked by guardrail".to_string()), elapsed_ms: None, }); continue; } ToolGuardrailDecision::RequireApproval { message } => { let _ = tx .send(RigStreamChunk::ToolCallFinished { tool_call_id: tool_call.id.clone(), tool_name: tool_name.clone(), output: format!( "awaiting approval: {message}" ), error: None, }) .await; current_step_tool_calls.push(ToolCallRecord { id: tool_call.id.clone(), name: tool_name.clone(), arguments: tool_args.clone(), output: None, error: Some(format!( "requires approval: {message}" )), elapsed_ms: None, }); continue; } } } let _ = tx .send(RigStreamChunk::ToolCallStarted { tool_call_id: tool_call.id.clone(), tool_name: tool_name.clone(), arguments: args.clone(), }) .await; current_step_tool_calls.push(ToolCallRecord { id: tool_call.id.clone(), name: tool_name.clone(), arguments: tool_args.clone(), output: None, error: None, elapsed_ms: None, }); } Ok(rig::agent::MultiTurnStreamItem::StreamUserItem( rig::streaming::StreamedUserContent::ToolResult { tool_result, .. }, )) => { let content = super::helpers::tool_result_content_to_string( &tool_result.content, ); accumulated_output_chars += content.chars().count(); if let Some(last) = current_step_tool_calls.last_mut() && last.id == tool_result.id { last.output = Some( serde_json::from_str(&content).unwrap_or_default(), ); } let tool_name = current_step_tool_calls .last() .map(|tc| tc.name.clone()) .unwrap_or_default(); let _ = tx .send(RigStreamChunk::ToolCallFinished { tool_call_id: tool_result.id.clone(), tool_name, output: content.clone(), error: None, }) .await; if !hooks.is_empty() { let outcome = ToolCallOutcome { name: tool_result.id.clone(), arguments: serde_json::Value::Null, output: Some(serde_json::Value::String(content)), error: None, elapsed_ms: 0, }; let _ = hooks.run_post_tool_call(&outcome).await; } } Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => { let usage = resp.usage(); if !current_step_tool_calls.is_empty() || !current_step_assistant.is_empty() { let reasoning = (!current_step_reasoning.is_empty()) .then_some(std::mem::take(&mut current_step_reasoning)); steps.push(AgentStep { index: steps.len(), assistant: (!current_step_assistant.is_empty()) .then_some(std::mem::take( &mut current_step_assistant, )), reasoning_content: reasoning, tool_calls: std::mem::take( &mut current_step_tool_calls, ), reflection: None, }); } let output = steps .last() .and_then(|s| s.assistant.clone()) .unwrap_or_default(); if !hooks.is_empty() { let hook_response = HookLlmResponse { content: Some(output.clone()), tool_calls: None, input_tokens: usage.input_tokens, output_tokens: usage.output_tokens, finish_reason: None, }; let _ = hooks.run_post_llm_call(&hook_response).await; } info!( steps = steps.len(), input_tokens = usage.input_tokens, output_tokens = usage.output_tokens, "agent run completed" ); let _ = tx .send(RigStreamChunk::Final { content: output.clone(), input_tokens: usage.input_tokens, output_tokens: usage.output_tokens, }) .await; if let Some(ref ctx) = request.run_context { let _ = hooks.run_session_end(ctx, true).await; } return Ok(AgentResult { output, steps, expert_outputs, input_tokens: usage.input_tokens as i64, output_tokens: usage.output_tokens as i64, }); } Err(e) => { let err = format!("{e}"); warn!(error = %err, "agent stream error"); let _ = tx.send(RigStreamChunk::Failed { error: err }).await; if let Some(ref ctx) = request.run_context { let _ = hooks.run_session_end(ctx, false).await; } return Err(AiError::Api(format!("{e}"))); } _ => {} } } Err(AiError::Response( "agent stream ended without final response".to_string(), )) } impl Clone for HookChain { fn clone(&self) -> Self { HookChain::empty() } }