//! ReAct (Reasoning + Acting) agent core. use async_openai::types::chat::FunctionCall; use async_openai::types::chat::{ ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, }; use uuid::Uuid; use std::sync::Arc; use crate::call_with_params; use crate::error::{AgentError, Result}; use crate::react::hooks::{Hook, HookAction, NoopHook, ToolCallAction}; use crate::react::types::{Action, ReactConfig, ReactStep}; pub use crate::react::types::{ReactConfig as ReActConfig, ReactStep as ReActStep}; /// A ReAct agent that performs multi-step tool-augmented reasoning. #[derive(Clone)] pub struct ReactAgent { messages: Vec, #[allow(dead_code)] tool_definitions: Vec, config: ReactConfig, step_count: usize, hook: Arc, } impl ReactAgent { /// Create a new agent with a system prompt and OpenAI tool definitions. pub fn new( system_prompt: &str, tools: Vec, config: ReactConfig, ) -> Self { let messages = vec![ChatCompletionRequestMessage::User( ChatCompletionRequestUserMessage { content: ChatCompletionRequestUserMessageContent::Text(system_prompt.to_string()), ..Default::default() }, )]; Self { messages, tool_definitions: tools, config, step_count: 0, hook: Arc::new(NoopHook), } } /// Add an initial user message to the conversation. pub fn add_user_message(&mut self, content: &str) { self.messages.push(ChatCompletionRequestMessage::User( ChatCompletionRequestUserMessage { content: ChatCompletionRequestUserMessageContent::Text(content.to_string()), ..Default::default() }, )); } /// Attach a hook to observe and control the agent loop. /// /// Hooks can log steps, filter content, inject custom tool results, /// or terminate the loop early. Multiple `.with_hook()` calls replace /// the previous hook. 
/// /// # Example /// /// ```ignore /// #[derive(Clone)] /// struct MyLogger; /// /// impl Hook for MyLogger { /// async fn on_thought(&self, step: usize, thought: &str) -> HookAction { /// eprintln!("[step {}] thought: {}", step, thought); /// HookAction::Continue /// } /// } /// /// let agent = ReactAgent::new(prompt, tools, config).with_hook(MyLogger); /// ``` pub fn with_hook(mut self, hook: H) -> Self { self.hook = Arc::new(hook); self } /// Run the ReAct loop until a final answer is produced or `max_steps` is reached. /// /// Yields streaming chunks via `on_chunk`. Each step produces: /// - A `ReactStep::Thought` chunk when the AI emits reasoning /// - A `ReactStep::Action` chunk when the AI emits a tool call /// - A `ReactStep::Observation` chunk after the tool executes /// - A `ReactStep::Answer` chunk when the loop terminates with a final answer /// /// Hooks are called at each phase (see [Hook]). Return [HookAction::Terminate] /// from any hook to stop the loop early. pub async fn run( &mut self, model_name: &str, client_config: &crate::client::AiClientConfig, mut on_chunk: C, ) -> Result where C: FnMut(ReactStep) + Send, { loop { if self.step_count >= self.config.max_steps { return Err(AgentError::Internal(format!( "ReAct agent reached max steps ({})", self.config.max_steps ))); } self.step_count += 1; let step = self.step_count; let response = call_with_params( &self.messages, model_name, client_config, 0.2, // temperature 4096, // max output tokens None, if self.tool_definitions.is_empty() { None } else { Some(self.tool_definitions.as_slice()) }, ) .await?; let parsed = parse_react_response(&response.content); let answer = parsed.answer.clone(); let action = parsed.action.clone(); // Emit thought step. 
on_chunk(ReactStep::Thought { step, thought: parsed.thought.clone(), }); // Hook: thought match self.hook.on_thought(step, &parsed.thought).await { HookAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at thought step: {}", reason ))); } HookAction::Skip => { // Skip this step, go directly to answer if present } HookAction::Continue => {} } // Final answer — emit and return. if let Some(ans) = answer { on_chunk(ReactStep::Answer { step, answer: ans.clone(), }); // Hook: answer match self.hook.on_answer(step, &ans).await { HookAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at answer step: {}", reason ))); } _ => {} } return Ok(ans); } // No answer — either do a tool call or fall back. let Some(act) = action else { let content = response.content.clone(); on_chunk(ReactStep::Answer { step, answer: content.clone(), }); // Hook: answer (fallback) match self.hook.on_answer(step, &content).await { HookAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at fallback answer: {}", reason ))); } _ => {} } return Ok(content); }; on_chunk(ReactStep::Action { step, action: act.clone(), }); let args_json = serde_json::to_string(&act.args).unwrap_or_else(|_| "{}".to_string()); // Hook: tool call — can skip or terminate match self.hook.on_tool_call(step, &act.name, &args_json).await { ToolCallAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at tool call: {}", reason ))); } ToolCallAction::Skip(injected_result) => { // Skip actual execution, inject the provided result let observation = injected_result; on_chunk(ReactStep::Observation { step, observation: observation.clone(), }); // Hook: observation (injected) match self.hook.on_observation(step, &observation).await { HookAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at observation (injected): {}", reason ))); } _ => {} } // Append 
observation as a tool message so the model sees it in context. self.messages.push(ChatCompletionRequestMessage::Tool( ChatCompletionRequestToolMessage { tool_call_id: act.id.clone(), content: ChatCompletionRequestToolMessageContent::Text(observation), }, )); continue; } ToolCallAction::Continue => {} } // Append the assistant message with tool_calls to history. let assistant_msg = build_tool_call_message(&act); self.messages.push(assistant_msg); // Execute the tool. let observation = match &self.config.tool_executor { Some(exec) => { let result = exec(act.name.clone(), act.args.clone()).await; match result { Ok(v) => serde_json::to_string(&v).unwrap_or_else(|_| "null".to_string()), Err(e) => serde_json::json!({ "error": e }).to_string(), } } None => serde_json::json!({ "error": format!("no tool executor registered for '{}'", act.name) }) .to_string(), }; on_chunk(ReactStep::Observation { step, observation: observation.clone(), }); // Hook: observation match self.hook.on_observation(step, &observation).await { HookAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at observation step: {}", reason ))); } _ => {} } // Append observation as a tool message so the model sees it in context. self.messages.push(ChatCompletionRequestMessage::Tool( ChatCompletionRequestToolMessage { tool_call_id: act.id.clone(), content: ChatCompletionRequestToolMessageContent::Text(observation), }, )); } } /// Returns the number of steps executed so far. pub fn steps(&self) -> usize { self.step_count } } // --------------------------------------------------------------------------- // Response parsing // --------------------------------------------------------------------------- struct ParsedReActResponse { thought: String, action: Option, answer: Option, } /// Parse the AI's text response into a ReAct step. 
///
/// The system prompt passed to [`ReactAgent::new`] is expected to instruct the
/// model to reply with JSON — bare, or inside a markdown code fence — in one of
/// these forms:
///
/// ```json
/// { "thought": "...", "action": { "name": "tool_name", "arguments": {...} } }
/// { "thought": "...", "answer": "final answer text" }
/// { "thought": "...", "name": "tool_name", "arguments": {...} }
/// ```
///
/// Lenient handling of common model deviations:
/// - `"args"` is accepted in place of `"arguments"`.
/// - `"arguments"` sent as a string that itself encodes a JSON object (the
///   encoding OpenAI uses for native tool calls) is decoded to that object;
///   any other string is kept as-is.
/// - A nested `"action"` wins over the flat `name`/`arguments` form.
/// - Missing fields default: thought → `"Thinking..."`, tool name → `""`,
///   arguments → `null`. Every action gets a fresh UUID v4 id.
///
/// When the (extracted) JSON cannot be deserialized into this shape, the whole
/// reply becomes the thought and neither an action nor an answer is returned.
fn parse_react_response(content: &str) -> ParsedReActResponse {
    #[derive(serde::Deserialize)]
    struct RawAction {
        #[serde(default)]
        name: Option<String>,
        #[serde(default, rename = "arguments", alias = "args")]
        args: Option<serde_json::Value>,
    }

    #[derive(serde::Deserialize)]
    struct RawStep {
        #[serde(default)]
        thought: Option<String>,
        #[serde(default)]
        action: Option<RawAction>,
        #[serde(default)]
        answer: Option<String>,
        // Flat tool-call form: `name` / `arguments` at the top level.
        #[serde(default)]
        name: Option<String>,
        #[serde(default, rename = "arguments", alias = "args")]
        args: Option<serde_json::Value>,
    }

    // Turn raw tool-call fields into an `Action` with a fresh unique id.
    fn to_action(name: Option<String>, args: Option<serde_json::Value>) -> Action {
        let args = match args {
            // OpenAI-style string-encoded arguments: decode only when the
            // string holds a JSON object, otherwise keep the string verbatim.
            Some(serde_json::Value::String(encoded)) => {
                match serde_json::from_str::<serde_json::Value>(&encoded) {
                    Ok(decoded @ serde_json::Value::Object(_)) => decoded,
                    _ => serde_json::Value::String(encoded),
                }
            }
            Some(value) => value,
            None => serde_json::Value::Null,
        };
        Action {
            id: Uuid::new_v4().to_string(),
            name: name.unwrap_or_default(),
            args,
        }
    }

    let json_str = extract_json(content).unwrap_or_else(|| content.trim().to_string());

    let Ok(raw) = serde_json::from_str::<RawStep>(&json_str) else {
        // Not a structured step: surface the whole reply as free-form reasoning.
        return ParsedReActResponse {
            thought: content.to_string(),
            action: None,
            answer: None,
        };
    };

    let has_flat_call = raw.name.is_some() || raw.args.is_some();
    let action = match raw.action {
        Some(nested) => Some(to_action(nested.name, nested.args)),
        None if has_flat_call => Some(to_action(raw.name, raw.args)),
        None => None,
    };

    ParsedReActResponse {
        thought: raw.thought.unwrap_or_else(|| "Thinking...".to_string()),
        action,
        answer: raw.answer,
    }
}

/// Extract the first JSON object or array from a string, handling markdown fences.
fn extract_json(s: &str) -> Option { let trimmed = s.trim(); if trimmed.starts_with('{') || trimmed.starts_with('[') { return Some(trimmed.to_string()); } for line in trimmed.lines() { let line = line.trim(); if line.starts_with("```json") || line.starts_with("```") { let mut buf = String::new(); let mut found_start = false; for l in trimmed.lines() { let l = l.trim(); if !found_start && (l == "```json" || l == "```") { found_start = true; continue; } if found_start && l == "```" { break; } if found_start { buf.push_str(l); buf.push('\n'); } } let result = buf.trim().to_string(); if !result.is_empty() { return Some(result); } } } None } /// Build an assistant message with tool_calls from an Action. #[allow(deprecated)] fn build_tool_call_message(action: &Action) -> ChatCompletionRequestMessage { let fn_arg_str = serde_json::to_string(&action.args).unwrap_or_else(|_| "{}".to_string()); ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { content: Some(ChatCompletionRequestAssistantMessageContent::Text(format!( "Action: {}", action.name ))), name: None, refusal: None, audio: None, tool_calls: Some(vec![ChatCompletionMessageToolCalls::Function( ChatCompletionMessageToolCall { id: action.id.clone(), function: FunctionCall { name: action.name.clone(), arguments: fn_arg_str, }, }, )]), function_call: None, }) }