gitdataai/libs/agent/react/loop_core.rs
2026-04-14 19:02:01 +08:00

440 lines
15 KiB
Rust

//! ReAct (Reasoning + Acting) agent core.
use async_openai::types::chat::FunctionCall;
use async_openai::types::chat::{
ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
ChatCompletionRequestMessage, ChatCompletionRequestToolMessage,
ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent,
};
use uuid::Uuid;
use std::sync::Arc;
use crate::call_with_params;
use crate::error::{AgentError, Result};
use crate::react::hooks::{Hook, HookAction, NoopHook, ToolCallAction};
use crate::react::types::{Action, ReactConfig, ReactStep};
pub use crate::react::types::{ReactConfig as ReActConfig, ReactStep as ReActStep};
/// A ReAct agent that performs multi-step tool-augmented reasoning.
///
/// The agent owns its conversation history and appends to it as the loop runs.
/// Cloning an agent copies that history; the hook is shared through an `Arc`.
#[derive(Clone)]
pub struct ReactAgent {
    /// Conversation history sent to the model on every step.
    messages: Vec<ChatCompletionRequestMessage>,
    /// Tool definitions advertised to the model (omitted from the request when empty).
    tool_definitions: Vec<async_openai::types::chat::ChatCompletionTool>,
    /// Loop configuration, including `max_steps` and the optional `tool_executor`.
    config: ReactConfig,
    /// Number of model calls made so far; never reset, so it spans multiple `run` calls.
    step_count: usize,
    /// Observer/controller invoked at each loop phase; defaults to `NoopHook`.
    hook: Arc<dyn Hook>,
}
impl ReactAgent {
/// Create a new agent with a system prompt and OpenAI tool definitions.
pub fn new(
system_prompt: &str,
tools: Vec<async_openai::types::chat::ChatCompletionTool>,
config: ReactConfig,
) -> Self {
let messages = vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(system_prompt.to_string()),
..Default::default()
},
)];
Self {
messages,
tool_definitions: tools,
config,
step_count: 0,
hook: Arc::new(NoopHook),
}
}
/// Add an initial user message to the conversation.
pub fn add_user_message(&mut self, content: &str) {
self.messages.push(ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(content.to_string()),
..Default::default()
},
));
}
/// Attach a hook to observe and control the agent loop.
///
/// Hooks can log steps, filter content, inject custom tool results,
/// or terminate the loop early. Multiple `.with_hook()` calls replace
/// the previous hook.
///
/// # Example
///
/// ```ignore
/// #[derive(Clone)]
/// struct MyLogger;
///
/// impl Hook for MyLogger {
/// async fn on_thought(&self, step: usize, thought: &str) -> HookAction {
/// eprintln!("[step {}] thought: {}", step, thought);
/// HookAction::Continue
/// }
/// }
///
/// let agent = ReactAgent::new(prompt, tools, config).with_hook(MyLogger);
/// ```
pub fn with_hook<H: Hook + 'static>(mut self, hook: H) -> Self {
self.hook = Arc::new(hook);
self
}
/// Run the ReAct loop until a final answer is produced or `max_steps` is reached.
///
/// Yields streaming chunks via `on_chunk`. Each step produces:
/// - A `ReactStep::Thought` chunk when the AI emits reasoning
/// - A `ReactStep::Action` chunk when the AI emits a tool call
/// - A `ReactStep::Observation` chunk after the tool executes
/// - A `ReactStep::Answer` chunk when the loop terminates with a final answer
///
/// Hooks are called at each phase (see [Hook]). Return [HookAction::Terminate]
/// from any hook to stop the loop early.
pub async fn run<C>(
&mut self,
model_name: &str,
client_config: &crate::client::AiClientConfig,
mut on_chunk: C,
) -> Result<String>
where
C: FnMut(ReactStep) + Send,
{
loop {
if self.step_count >= self.config.max_steps {
return Err(AgentError::Internal(format!(
"ReAct agent reached max steps ({})",
self.config.max_steps
)));
}
self.step_count += 1;
let step = self.step_count;
let response = call_with_params(
&self.messages,
model_name,
client_config,
0.2, // temperature
4096, // max output tokens
None,
if self.tool_definitions.is_empty() {
None
} else {
Some(self.tool_definitions.as_slice())
},
)
.await?;
let parsed = parse_react_response(&response.content);
let answer = parsed.answer.clone();
let action = parsed.action.clone();
// Emit thought step.
on_chunk(ReactStep::Thought {
step,
thought: parsed.thought.clone(),
});
// Hook: thought
match self.hook.on_thought(step, &parsed.thought).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at thought step: {}",
reason
)));
}
HookAction::Skip => {
// Skip this step, go directly to answer if present
}
HookAction::Continue => {}
}
// Final answer — emit and return.
if let Some(ans) = answer {
on_chunk(ReactStep::Answer {
step,
answer: ans.clone(),
});
// Hook: answer
match self.hook.on_answer(step, &ans).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at answer step: {}",
reason
)));
}
_ => {}
}
return Ok(ans);
}
// No answer — either do a tool call or fall back.
let Some(act) = action else {
let content = response.content.clone();
on_chunk(ReactStep::Answer {
step,
answer: content.clone(),
});
// Hook: answer (fallback)
match self.hook.on_answer(step, &content).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at fallback answer: {}",
reason
)));
}
_ => {}
}
return Ok(content);
};
on_chunk(ReactStep::Action {
step,
action: act.clone(),
});
let args_json = serde_json::to_string(&act.args).unwrap_or_else(|_| "{}".to_string());
// Hook: tool call — can skip or terminate
match self.hook.on_tool_call(step, &act.name, &args_json).await {
ToolCallAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at tool call: {}",
reason
)));
}
ToolCallAction::Skip(injected_result) => {
// Skip actual execution, inject the provided result
let observation = injected_result;
on_chunk(ReactStep::Observation {
step,
observation: observation.clone(),
});
// Hook: observation (injected)
match self.hook.on_observation(step, &observation).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at observation (injected): {}",
reason
)));
}
_ => {}
}
// Append observation as a tool message so the model sees it in context.
self.messages.push(ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage {
tool_call_id: act.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(observation),
},
));
continue;
}
ToolCallAction::Continue => {}
}
// Append the assistant message with tool_calls to history.
let assistant_msg = build_tool_call_message(&act);
self.messages.push(assistant_msg);
// Execute the tool.
let observation = match &self.config.tool_executor {
Some(exec) => {
let result = exec(act.name.clone(), act.args.clone()).await;
match result {
Ok(v) => serde_json::to_string(&v).unwrap_or_else(|_| "null".to_string()),
Err(e) => serde_json::json!({ "error": e }).to_string(),
}
}
None => serde_json::json!({
"error": format!("no tool executor registered for '{}'", act.name)
})
.to_string(),
};
on_chunk(ReactStep::Observation {
step,
observation: observation.clone(),
});
// Hook: observation
match self.hook.on_observation(step, &observation).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at observation step: {}",
reason
)));
}
_ => {}
}
// Append observation as a tool message so the model sees it in context.
self.messages.push(ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage {
tool_call_id: act.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(observation),
},
));
}
}
/// Returns the number of steps executed so far.
pub fn steps(&self) -> usize {
self.step_count
}
}
// ---------------------------------------------------------------------------
// Response parsing
// ---------------------------------------------------------------------------
/// Structured view of a single model reply, as produced by `parse_react_response`.
struct ParsedReActResponse {
    /// Final answer text, if the reply contained one.
    answer: Option<String>,
    /// Tool call requested by the reply, if any.
    action: Option<Action>,
    /// The model's reasoning; a placeholder or the raw reply when none was given.
    thought: String,
}
/// Parse the AI's text response into a ReAct step.
///
/// The AI is prompted (via system prompt in `ReactAgent::new`) to respond with
/// JSON in one of these forms:
///
/// ```json
/// { "thought": "...", "action": { "name": "tool_name", "arguments": {...} } }
/// { "thought": "...", "answer": "final answer text" }
/// ```
fn parse_react_response(content: &str) -> ParsedReActResponse {
let json_str = extract_json(content).unwrap_or_else(|| content.trim().to_string());
#[derive(serde::Deserialize)]
struct RawStep {
#[serde(default)]
thought: Option<String>,
#[serde(default)]
action: Option<RawAction>,
#[serde(default)]
answer: Option<String>,
#[serde(default)]
name: Option<String>,
#[serde(default, rename = "arguments")]
args: Option<serde_json::Value>,
}
#[derive(serde::Deserialize)]
struct RawAction {
#[serde(default)]
name: Option<String>,
#[serde(default, rename = "arguments")]
args: Option<serde_json::Value>,
}
match serde_json::from_str::<RawStep>(&json_str) {
Ok(raw) => {
let thought = raw.thought.unwrap_or_else(|| "Thinking...".to_string());
let answer = raw.answer;
let action = raw.action.map(|a| Action {
id: Uuid::new_v4().to_string(),
name: a.name.unwrap_or_default(),
args: a.args.unwrap_or(serde_json::Value::Null),
});
// Handle flat format: { "name": "...", "arguments": {...} }
let action = action.or_else(|| {
if raw.name.is_some() || raw.args.is_some() {
Some(Action {
id: Uuid::new_v4().to_string(),
name: raw.name.unwrap_or_default(),
args: raw.args.unwrap_or(serde_json::Value::Null),
})
} else {
None
}
});
ParsedReActResponse {
thought,
action,
answer,
}
}
Err(_) => ParsedReActResponse {
thought: content.to_string(),
action: None,
answer: None,
},
}
}
/// Extract the JSON payload from a model reply.
///
/// - If the trimmed text starts with `{` or `[`, the whole trimmed text is
///   returned. Trailing prose is not stripped.
/// - Otherwise the first markdown code fence whose body starts with `{` or `[`
///   is returned. An opening fence is any line beginning with three backticks,
///   whatever its info string (`json`, `JSON`, `javascript`, ...). The block ends
///   at a line consisting of exactly three backticks, or at the end of the text if
///   the fence is never closed. Body lines are trimmed and re-joined with `\n`.
///
/// Returns `None` when no JSON object or array can be located.
fn extract_json(s: &str) -> Option<String> {
    let trimmed = s.trim();
    if trimmed.starts_with('{') || trimmed.starts_with('[') {
        return Some(trimmed.to_string());
    }

    let mut lines = trimmed.lines().map(str::trim);
    // `any` consumes lines up to and including the next opening fence.
    while lines.any(|line| line.starts_with("```")) {
        // `take_while` also consumes (but does not yield) the closing fence.
        let body = lines
            .by_ref()
            .take_while(|line| *line != "```")
            .collect::<Vec<_>>()
            .join("\n");
        let body = body.trim();
        // Skip fences holding non-JSON content (e.g. a code sample) and keep looking.
        if body.starts_with('{') || body.starts_with('[') {
            return Some(body.to_string());
        }
    }
    None
}
/// Wrap an [`Action`] in an assistant message carrying exactly one function tool call.
///
/// The message text is `"Action: <name>"`. The tool call reuses `action.id`, and its
/// arguments are `action.args` serialised to JSON (falling back to `"{}"` if
/// serialisation fails).
// The deprecation allowance presumably covers the legacy `function_call` field,
// which must still be spelled out in this struct literal.
#[allow(deprecated)]
fn build_tool_call_message(action: &Action) -> ChatCompletionRequestMessage {
    let arguments = match serde_json::to_string(&action.args) {
        Ok(json) => json,
        Err(_) => String::from("{}"),
    };
    let call = ChatCompletionMessageToolCall {
        id: action.id.clone(),
        function: FunctionCall {
            name: action.name.clone(),
            arguments,
        },
    };
    let text = format!("Action: {}", action.name);

    ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
        content: Some(ChatCompletionRequestAssistantMessageContent::Text(text)),
        tool_calls: Some(vec![ChatCompletionMessageToolCalls::Function(call)]),
        name: None,
        refusal: None,
        audio: None,
        function_call: None,
    })
}