440 lines
15 KiB
Rust
440 lines
15 KiB
Rust
//! ReAct (Reasoning + Acting) agent core.
|
|
|
|
use async_openai::types::chat::FunctionCall;
|
|
use async_openai::types::chat::{
|
|
ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
|
|
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
|
|
ChatCompletionRequestMessage, ChatCompletionRequestToolMessage,
|
|
ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessage,
|
|
ChatCompletionRequestUserMessageContent,
|
|
};
|
|
use uuid::Uuid;
|
|
|
|
use std::sync::Arc;
|
|
|
|
use crate::call_with_params;
|
|
use crate::error::{AgentError, Result};
|
|
use crate::react::hooks::{Hook, HookAction, NoopHook, ToolCallAction};
|
|
use crate::react::types::{Action, ReactConfig, ReactStep};
|
|
|
|
pub use crate::react::types::{ReactConfig as ReActConfig, ReactStep as ReActStep};
|
|
|
|
/// A ReAct agent that performs multi-step tool-augmented reasoning.
|
|
#[derive(Clone)]
|
|
pub struct ReactAgent {
|
|
messages: Vec<ChatCompletionRequestMessage>,
|
|
#[allow(dead_code)]
|
|
tool_definitions: Vec<async_openai::types::chat::ChatCompletionTool>,
|
|
config: ReactConfig,
|
|
step_count: usize,
|
|
hook: Arc<dyn Hook>,
|
|
}
|
|
|
|
impl ReactAgent {
|
|
/// Create a new agent with a system prompt and OpenAI tool definitions.
|
|
pub fn new(
|
|
system_prompt: &str,
|
|
tools: Vec<async_openai::types::chat::ChatCompletionTool>,
|
|
config: ReactConfig,
|
|
) -> Self {
|
|
let messages = vec![ChatCompletionRequestMessage::User(
|
|
ChatCompletionRequestUserMessage {
|
|
content: ChatCompletionRequestUserMessageContent::Text(system_prompt.to_string()),
|
|
..Default::default()
|
|
},
|
|
)];
|
|
Self {
|
|
messages,
|
|
tool_definitions: tools,
|
|
config,
|
|
step_count: 0,
|
|
hook: Arc::new(NoopHook),
|
|
}
|
|
}
|
|
|
|
/// Add an initial user message to the conversation.
|
|
pub fn add_user_message(&mut self, content: &str) {
|
|
self.messages.push(ChatCompletionRequestMessage::User(
|
|
ChatCompletionRequestUserMessage {
|
|
content: ChatCompletionRequestUserMessageContent::Text(content.to_string()),
|
|
..Default::default()
|
|
},
|
|
));
|
|
}
|
|
|
|
/// Attach a hook to observe and control the agent loop.
|
|
///
|
|
/// Hooks can log steps, filter content, inject custom tool results,
|
|
/// or terminate the loop early. Multiple `.with_hook()` calls replace
|
|
/// the previous hook.
|
|
///
|
|
/// # Example
|
|
///
|
|
/// ```ignore
|
|
/// #[derive(Clone)]
|
|
/// struct MyLogger;
|
|
///
|
|
/// impl Hook for MyLogger {
|
|
/// async fn on_thought(&self, step: usize, thought: &str) -> HookAction {
|
|
/// eprintln!("[step {}] thought: {}", step, thought);
|
|
/// HookAction::Continue
|
|
/// }
|
|
/// }
|
|
///
|
|
/// let agent = ReactAgent::new(prompt, tools, config).with_hook(MyLogger);
|
|
/// ```
|
|
pub fn with_hook<H: Hook + 'static>(mut self, hook: H) -> Self {
|
|
self.hook = Arc::new(hook);
|
|
self
|
|
}
|
|
|
|
/// Run the ReAct loop until a final answer is produced or `max_steps` is reached.
|
|
///
|
|
/// Yields streaming chunks via `on_chunk`. Each step produces:
|
|
/// - A `ReactStep::Thought` chunk when the AI emits reasoning
|
|
/// - A `ReactStep::Action` chunk when the AI emits a tool call
|
|
/// - A `ReactStep::Observation` chunk after the tool executes
|
|
/// - A `ReactStep::Answer` chunk when the loop terminates with a final answer
|
|
///
|
|
/// Hooks are called at each phase (see [Hook]). Return [HookAction::Terminate]
|
|
/// from any hook to stop the loop early.
|
|
pub async fn run<C>(
|
|
&mut self,
|
|
model_name: &str,
|
|
client_config: &crate::client::AiClientConfig,
|
|
mut on_chunk: C,
|
|
) -> Result<String>
|
|
where
|
|
C: FnMut(ReactStep) + Send,
|
|
{
|
|
loop {
|
|
if self.step_count >= self.config.max_steps {
|
|
return Err(AgentError::Internal(format!(
|
|
"ReAct agent reached max steps ({})",
|
|
self.config.max_steps
|
|
)));
|
|
}
|
|
|
|
self.step_count += 1;
|
|
let step = self.step_count;
|
|
|
|
let response = call_with_params(
|
|
&self.messages,
|
|
model_name,
|
|
client_config,
|
|
0.2, // temperature
|
|
4096, // max output tokens
|
|
None,
|
|
if self.tool_definitions.is_empty() {
|
|
None
|
|
} else {
|
|
Some(self.tool_definitions.as_slice())
|
|
},
|
|
)
|
|
.await?;
|
|
|
|
let parsed = parse_react_response(&response.content);
|
|
let answer = parsed.answer.clone();
|
|
let action = parsed.action.clone();
|
|
|
|
// Emit thought step.
|
|
on_chunk(ReactStep::Thought {
|
|
step,
|
|
thought: parsed.thought.clone(),
|
|
});
|
|
|
|
// Hook: thought
|
|
match self.hook.on_thought(step, &parsed.thought).await {
|
|
HookAction::Terminate(reason) => {
|
|
return Err(AgentError::Internal(format!(
|
|
"hook terminated at thought step: {}",
|
|
reason
|
|
)));
|
|
}
|
|
HookAction::Skip => {
|
|
// Skip this step, go directly to answer if present
|
|
}
|
|
HookAction::Continue => {}
|
|
}
|
|
|
|
// Final answer — emit and return.
|
|
if let Some(ans) = answer {
|
|
on_chunk(ReactStep::Answer {
|
|
step,
|
|
answer: ans.clone(),
|
|
});
|
|
|
|
// Hook: answer
|
|
match self.hook.on_answer(step, &ans).await {
|
|
HookAction::Terminate(reason) => {
|
|
return Err(AgentError::Internal(format!(
|
|
"hook terminated at answer step: {}",
|
|
reason
|
|
)));
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
return Ok(ans);
|
|
}
|
|
|
|
// No answer — either do a tool call or fall back.
|
|
let Some(act) = action else {
|
|
let content = response.content.clone();
|
|
on_chunk(ReactStep::Answer {
|
|
step,
|
|
answer: content.clone(),
|
|
});
|
|
|
|
// Hook: answer (fallback)
|
|
match self.hook.on_answer(step, &content).await {
|
|
HookAction::Terminate(reason) => {
|
|
return Err(AgentError::Internal(format!(
|
|
"hook terminated at fallback answer: {}",
|
|
reason
|
|
)));
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
return Ok(content);
|
|
};
|
|
|
|
on_chunk(ReactStep::Action {
|
|
step,
|
|
action: act.clone(),
|
|
});
|
|
|
|
let args_json = serde_json::to_string(&act.args).unwrap_or_else(|_| "{}".to_string());
|
|
|
|
// Hook: tool call — can skip or terminate
|
|
match self.hook.on_tool_call(step, &act.name, &args_json).await {
|
|
ToolCallAction::Terminate(reason) => {
|
|
return Err(AgentError::Internal(format!(
|
|
"hook terminated at tool call: {}",
|
|
reason
|
|
)));
|
|
}
|
|
ToolCallAction::Skip(injected_result) => {
|
|
// Skip actual execution, inject the provided result
|
|
let observation = injected_result;
|
|
on_chunk(ReactStep::Observation {
|
|
step,
|
|
observation: observation.clone(),
|
|
});
|
|
|
|
// Hook: observation (injected)
|
|
match self.hook.on_observation(step, &observation).await {
|
|
HookAction::Terminate(reason) => {
|
|
return Err(AgentError::Internal(format!(
|
|
"hook terminated at observation (injected): {}",
|
|
reason
|
|
)));
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
// Append observation as a tool message so the model sees it in context.
|
|
self.messages.push(ChatCompletionRequestMessage::Tool(
|
|
ChatCompletionRequestToolMessage {
|
|
tool_call_id: act.id.clone(),
|
|
content: ChatCompletionRequestToolMessageContent::Text(observation),
|
|
},
|
|
));
|
|
|
|
continue;
|
|
}
|
|
ToolCallAction::Continue => {}
|
|
}
|
|
|
|
// Append the assistant message with tool_calls to history.
|
|
let assistant_msg = build_tool_call_message(&act);
|
|
self.messages.push(assistant_msg);
|
|
|
|
// Execute the tool.
|
|
let observation = match &self.config.tool_executor {
|
|
Some(exec) => {
|
|
let result = exec(act.name.clone(), act.args.clone()).await;
|
|
match result {
|
|
Ok(v) => serde_json::to_string(&v).unwrap_or_else(|_| "null".to_string()),
|
|
Err(e) => serde_json::json!({ "error": e }).to_string(),
|
|
}
|
|
}
|
|
None => serde_json::json!({
|
|
"error": format!("no tool executor registered for '{}'", act.name)
|
|
})
|
|
.to_string(),
|
|
};
|
|
|
|
on_chunk(ReactStep::Observation {
|
|
step,
|
|
observation: observation.clone(),
|
|
});
|
|
|
|
// Hook: observation
|
|
match self.hook.on_observation(step, &observation).await {
|
|
HookAction::Terminate(reason) => {
|
|
return Err(AgentError::Internal(format!(
|
|
"hook terminated at observation step: {}",
|
|
reason
|
|
)));
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
// Append observation as a tool message so the model sees it in context.
|
|
self.messages.push(ChatCompletionRequestMessage::Tool(
|
|
ChatCompletionRequestToolMessage {
|
|
tool_call_id: act.id.clone(),
|
|
content: ChatCompletionRequestToolMessageContent::Text(observation),
|
|
},
|
|
));
|
|
}
|
|
}
|
|
|
|
/// Returns the number of steps executed so far.
|
|
pub fn steps(&self) -> usize {
|
|
self.step_count
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Response parsing
|
|
// ---------------------------------------------------------------------------
|
|
|
|
struct ParsedReActResponse {
|
|
thought: String,
|
|
action: Option<Action>,
|
|
answer: Option<String>,
|
|
}
|
|
|
|
/// Parse the AI's text response into a ReAct step.
|
|
///
|
|
/// The AI is prompted (via system prompt in `ReactAgent::new`) to respond with
|
|
/// JSON in one of these forms:
|
|
///
|
|
/// ```json
|
|
/// { "thought": "...", "action": { "name": "tool_name", "arguments": {...} } }
|
|
/// { "thought": "...", "answer": "final answer text" }
|
|
/// ```
|
|
fn parse_react_response(content: &str) -> ParsedReActResponse {
|
|
let json_str = extract_json(content).unwrap_or_else(|| content.trim().to_string());
|
|
|
|
#[derive(serde::Deserialize)]
|
|
struct RawStep {
|
|
#[serde(default)]
|
|
thought: Option<String>,
|
|
#[serde(default)]
|
|
action: Option<RawAction>,
|
|
#[serde(default)]
|
|
answer: Option<String>,
|
|
#[serde(default)]
|
|
name: Option<String>,
|
|
#[serde(default, rename = "arguments")]
|
|
args: Option<serde_json::Value>,
|
|
}
|
|
|
|
#[derive(serde::Deserialize)]
|
|
struct RawAction {
|
|
#[serde(default)]
|
|
name: Option<String>,
|
|
#[serde(default, rename = "arguments")]
|
|
args: Option<serde_json::Value>,
|
|
}
|
|
|
|
match serde_json::from_str::<RawStep>(&json_str) {
|
|
Ok(raw) => {
|
|
let thought = raw.thought.unwrap_or_else(|| "Thinking...".to_string());
|
|
let answer = raw.answer;
|
|
let action = raw.action.map(|a| Action {
|
|
id: Uuid::new_v4().to_string(),
|
|
name: a.name.unwrap_or_default(),
|
|
args: a.args.unwrap_or(serde_json::Value::Null),
|
|
});
|
|
// Handle flat format: { "name": "...", "arguments": {...} }
|
|
let action = action.or_else(|| {
|
|
if raw.name.is_some() || raw.args.is_some() {
|
|
Some(Action {
|
|
id: Uuid::new_v4().to_string(),
|
|
name: raw.name.unwrap_or_default(),
|
|
args: raw.args.unwrap_or(serde_json::Value::Null),
|
|
})
|
|
} else {
|
|
None
|
|
}
|
|
});
|
|
|
|
ParsedReActResponse {
|
|
thought,
|
|
action,
|
|
answer,
|
|
}
|
|
}
|
|
Err(_) => ParsedReActResponse {
|
|
thought: content.to_string(),
|
|
action: None,
|
|
answer: None,
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Pull a candidate JSON payload out of a model reply.
///
/// Two shapes are recognised, checked in this order:
///
/// 1. The trimmed text already starts with `{` or `[`: it is returned unchanged
///    (apart from trimming).
/// 2. The text contains Markdown code fences. The body of the first fenced block
///    whose content is non-empty is returned. Each body line is trimmed, and the
///    lines are joined with `\n`.
///    - An opening fence is any line starting with three backticks, whatever
///      language tag follows (`json`, `JSON`, none, ...).
///    - The block ends at the next line that is exactly three backticks, or at
///      end of input if the fence is never closed.
///
/// Returns `None` when neither shape applies. The result is not validated; the
/// caller still has to parse it.
fn extract_json(s: &str) -> Option<String> {
    let trimmed = s.trim();
    if trimmed.starts_with('{') || trimmed.starts_with('[') {
        return Some(trimmed.to_string());
    }

    let mut lines = trimmed.lines().map(str::trim);
    // Each iteration consumes one fenced block: its opening fence (`any`), then
    // its body plus the closing fence (`take_while` also eats the stopping line).
    while lines.any(|line| line.starts_with("```")) {
        let body: Vec<&str> = lines.by_ref().take_while(|line| *line != "```").collect();
        let body = body.join("\n");
        let body = body.trim();
        if !body.is_empty() {
            return Some(body.to_string());
        }
    }
    None
}
|
|
|
|
/// Build an assistant message with tool_calls from an Action.
|
|
#[allow(deprecated)]
|
|
fn build_tool_call_message(action: &Action) -> ChatCompletionRequestMessage {
|
|
let fn_arg_str = serde_json::to_string(&action.args).unwrap_or_else(|_| "{}".to_string());
|
|
|
|
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
|
|
content: Some(ChatCompletionRequestAssistantMessageContent::Text(format!(
|
|
"Action: {}",
|
|
action.name
|
|
))),
|
|
name: None,
|
|
refusal: None,
|
|
audio: None,
|
|
tool_calls: Some(vec![ChatCompletionMessageToolCalls::Function(
|
|
ChatCompletionMessageToolCall {
|
|
id: action.id.clone(),
|
|
function: FunctionCall {
|
|
name: action.name.clone(),
|
|
arguments: fn_arg_str,
|
|
},
|
|
},
|
|
)]),
|
|
function_call: None,
|
|
})
|
|
}
|