//! ReAct (Reasoning + Acting) agent core. use uuid::Uuid; use std::sync::Arc; use crate::call_with_params; use crate::client::types::ChatRequestMessage; use crate::error::{AgentError, Result}; use crate::react::hooks::{Hook, HookAction, NoopHook, ToolCallAction}; use crate::react::types::{Action, ReactConfig, ReactStep}; pub use crate::react::types::{ReactConfig as ReActConfig, ReactStep as ReActStep}; /// A ReAct agent that performs multi-step tool-augmented reasoning. #[derive(Clone)] pub struct ReactAgent { messages: Vec, #[allow(dead_code)] tool_definitions: Vec, config: ReactConfig, step_count: usize, hook: Arc, } impl ReactAgent { /// Create a new agent with a system prompt and tool definitions (as JSON values). pub fn new( system_prompt: &str, tools: Vec, config: ReactConfig, ) -> Self { let messages = vec![ChatRequestMessage::system(system_prompt)]; Self { messages, tool_definitions: tools, config, step_count: 0, hook: Arc::new(NoopHook), } } /// Add an initial user message to the conversation. pub fn add_user_message(&mut self, content: &str) { self.messages.push(ChatRequestMessage::user(content)); } /// Attach a hook to observe and control the agent loop. /// /// Hooks can log steps, filter content, inject custom tool results, /// or terminate the loop early. Multiple `.with_hook()` calls replace /// the previous hook. pub fn with_hook(mut self, hook: H) -> Self { self.hook = Arc::new(hook); self } /// Run the ReAct loop until a final answer is produced or `max_steps` is reached. pub async fn run( &mut self, model_name: &str, client_config: &crate::client::AiClientConfig, mut on_chunk: C, ) -> Result where C: FnMut(ReactStep) + Send, { loop { if self.step_count >= self.config.max_steps { let msg = format!( "Agent reached maximum reasoning steps ({}) without producing a final answer.", self.config.max_steps ); on_chunk(ReactStep::Answer { step: self.step_count, answer: msg.clone(), }); return Ok(msg); } self.step_count += 1; let step = self.step_count; // For ReAct we force text-only responses so the model follows our JSON-in-text format. let tool_choice_str = if self.tool_definitions.is_empty() { None } else { Some("none") }; let response = call_with_params( &self.messages, model_name, client_config, 0.2, // temperature 4096, // max output tokens None, if self.tool_definitions.is_empty() { None } else { Some(&self.tool_definitions) }, tool_choice_str, ) .await?; let parsed = parse_react_response(&response.content); let answer = parsed.answer.clone(); let action = parsed.action.clone(); on_chunk(ReactStep::Thought { step, thought: parsed.thought.clone(), }); match self.hook.on_thought(step, &parsed.thought).await { HookAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at thought step: {}", reason ))); } HookAction::Skip => {} HookAction::Continue => {} } // Final answer — emit and return. if let Some(ans) = answer { on_chunk(ReactStep::Answer { step, answer: ans.clone(), }); match self.hook.on_answer(step, &ans).await { HookAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at answer step: {}", reason ))); } _ => {} } return Ok(ans); } // No answer — either do a tool call or fall back. let Some(act) = action else { let content = response.content.clone(); on_chunk(ReactStep::Answer { step, answer: content.clone(), }); match self.hook.on_answer(step, &content).await { HookAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at fallback answer: {}", reason ))); } _ => {} } return Ok(content); }; on_chunk(ReactStep::Action { step, action: act.clone(), }); let args_json = serde_json::to_string(&act.args).unwrap_or_else(|_| "{}".to_string()); match self.hook.on_tool_call(step, &act.name, &args_json).await { ToolCallAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at tool call: {}", reason ))); } ToolCallAction::Skip(injected_result) => { let observation = injected_result; on_chunk(ReactStep::Observation { step, observation: observation.clone(), }); match self.hook.on_observation(step, &observation).await { HookAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at observation (injected): {}", reason ))); } _ => {} } // Append assistant message with tool_calls. let assistant_msg = build_tool_call_message(&act); self.messages.push(assistant_msg); // Append observation as a tool message. self.messages.push(ChatRequestMessage::tool(&act.id, observation)); continue; } ToolCallAction::Continue => {} } // Append the assistant message with tool_calls. let assistant_msg = build_tool_call_message(&act); self.messages.push(assistant_msg); // Execute the tool. let observation = match &self.config.tool_executor { Some(exec) => { let result = exec(act.name.clone(), act.args.clone()).await; match result { Ok(v) => serde_json::to_string(&v).unwrap_or_else(|_| "null".to_string()), Err(e) => serde_json::json!({ "error": e }).to_string(), } } None => serde_json::json!({ "error": format!("no tool executor registered for '{}'", act.name) }) .to_string(), }; on_chunk(ReactStep::Observation { step, observation: observation.clone(), }); match self.hook.on_observation(step, &observation).await { HookAction::Terminate(reason) => { return Err(AgentError::Internal(format!( "hook terminated at observation step: {}", reason ))); } _ => {} } // Append observation as a tool message. self.messages.push(ChatRequestMessage::tool(&act.id, observation)); } } /// Returns the number of steps executed so far. pub fn steps(&self) -> usize { self.step_count } } // --------------------------------------------------------------------------- // Response parsing // --------------------------------------------------------------------------- struct ParsedReActResponse { thought: String, action: Option, answer: Option, } fn parse_react_response(content: &str) -> ParsedReActResponse { let json_str = extract_json(content).unwrap_or_else(|| content.trim().to_string()); #[derive(serde::Deserialize)] struct RawStep { #[serde(default)] thought: Option, #[serde(default)] action: Option, #[serde(default)] answer: Option, #[serde(default)] name: Option, #[serde(default, rename = "arguments")] args: Option, } #[derive(serde::Deserialize)] struct RawAction { #[serde(default)] name: Option, #[serde(default, rename = "arguments")] args: Option, } match serde_json::from_str::(&json_str) { Ok(raw) => { let thought = raw.thought.unwrap_or_else(|| "Thinking...".to_string()); let answer = raw.answer; let action = raw.action.map(|a| Action { id: Uuid::new_v4().to_string(), name: a.name.unwrap_or_default(), args: a.args.unwrap_or(serde_json::Value::Null), }); let action = action.or_else(|| { if raw.name.is_some() || raw.args.is_some() { Some(Action { id: Uuid::new_v4().to_string(), name: raw.name.unwrap_or_default(), args: raw.args.unwrap_or(serde_json::Value::Null), }) } else { None } }); ParsedReActResponse { thought, action, answer, } } Err(_) => ParsedReActResponse { thought: content.to_string(), action: None, answer: None, }, } } fn extract_json(s: &str) -> Option { let trimmed = s.trim(); if trimmed.starts_with('{') || trimmed.starts_with('[') { return Some(trimmed.to_string()); } for line in trimmed.lines() { let line = line.trim(); if line.starts_with("```json") || line == "```" { let mut buf = String::new(); let mut found_start = false; for l in trimmed.lines() { let l = l.trim(); if !found_start && (l == "```json" || l == "```") { found_start = true; continue; } if found_start && l == "```" { break; } if found_start { buf.push_str(l); buf.push('\n'); } } let result = buf.trim().to_string(); if !result.is_empty() { return Some(result); } } } let chars: Vec = trimmed.chars().collect(); for i in 0..chars.len() { let c = chars[i]; if (c == '{' || c == '[') && i > 0 { let prev = chars[i - 1]; if prev.is_alphanumeric() || prev == '_' || prev == '"' || prev == '\'' { continue; } let candidate: String = chars[i..].iter().collect(); if serde_json::from_str::(&candidate).is_ok() { return Some(candidate.trim_end().to_string()); } let mut depth = 0isize; let mut in_string = false; let mut escaped = false; for (j, c) in candidate.char_indices() { if escaped { escaped = false; continue; } if c == '\\' { escaped = true; continue; } if c == '"' { in_string = !in_string; continue; } if in_string { continue; } if c == '{' || c == '[' { depth += 1; } if c == '}' || c == ']' { depth -= 1; } if depth == 0 { let json_end = j + c.len_utf8(); let trimmed_candidate = &candidate[..json_end]; if serde_json::from_str::(trimmed_candidate).is_ok() { return Some(trimmed_candidate.to_string()); } } } } } None } /// Build an assistant message with tool_calls from an Action. fn build_tool_call_message(action: &Action) -> ChatRequestMessage { let fn_arg_str = serde_json::to_string(&action.args).unwrap_or_else(|_| "{}".to_string()); ChatRequestMessage { role: "assistant".into(), content: Some(format!("Action: {}", action.name)), name: None, tool_call_id: None, tool_calls: Some(vec![crate::client::types::ToolCall { id: action.id.clone(), type_: "function".into(), function: crate::client::types::ToolCallFunction { name: action.name.clone(), arguments: fn_arg_str, }, }]), } }