gitdataai/libs/agent/react/loop_core.rs
ZhenYi 10c0cc007b refactor(agent): split into submodules and add Qdrant embedding
- Split agent crate into client/, model/, agent/ subdirs
- Add billing.rs for token usage recording
- Add sync.rs for upstream model sync
- EmbedService: Qdrant-backed vector memory for semantic search
- ChatService: wire EmbedService for memory lookup, passive skill awareness
- ReAct loop: streamline with tokio::select! and proper error handling
2026-04-25 20:09:33 +08:00

414 lines
14 KiB
Rust

//! ReAct (Reasoning + Acting) agent core.
use uuid::Uuid;
use std::sync::Arc;
use crate::call_with_params;
use crate::client::types::ChatRequestMessage;
use crate::error::{AgentError, Result};
use crate::react::hooks::{Hook, HookAction, NoopHook, ToolCallAction};
use crate::react::types::{Action, ReactConfig, ReactStep};
pub use crate::react::types::{ReactConfig as ReActConfig, ReactStep as ReActStep};
/// A ReAct agent that performs multi-step tool-augmented reasoning.
///
/// The agent owns the running conversation and sends it to the model on every step,
/// appending tool calls and their observations until the model produces an answer.
#[derive(Clone)]
pub struct ReactAgent {
    /// Conversation history sent to the model on each step; starts with the system prompt.
    messages: Vec<ChatRequestMessage>,
    /// Tool schemas (raw JSON) forwarded to the model API; when empty, no tools are sent.
    tool_definitions: Vec<serde_json::Value>,
    /// Loop settings, including the step limit and the optional tool executor.
    config: ReactConfig,
    /// Steps taken so far. It is never reset, so the step limit spans all calls to `run`.
    step_count: usize,
    /// Hook consulted at each stage of the loop; clones of the agent share the same hook.
    hook: Arc<dyn Hook>,
}
impl ReactAgent {
/// Create a new agent with a system prompt and tool definitions (as JSON values).
pub fn new(
system_prompt: &str,
tools: Vec<serde_json::Value>,
config: ReactConfig,
) -> Self {
let messages = vec![ChatRequestMessage::system(system_prompt)];
Self {
messages,
tool_definitions: tools,
config,
step_count: 0,
hook: Arc::new(NoopHook),
}
}
/// Add an initial user message to the conversation.
pub fn add_user_message(&mut self, content: &str) {
self.messages.push(ChatRequestMessage::user(content));
}
/// Attach a hook to observe and control the agent loop.
///
/// Hooks can log steps, filter content, inject custom tool results,
/// or terminate the loop early. Multiple `.with_hook()` calls replace
/// the previous hook.
pub fn with_hook<H: Hook + 'static>(mut self, hook: H) -> Self {
self.hook = Arc::new(hook);
self
}
/// Run the ReAct loop until a final answer is produced or `max_steps` is reached.
pub async fn run<C>(
&mut self,
model_name: &str,
client_config: &crate::client::AiClientConfig,
mut on_chunk: C,
) -> Result<String>
where
C: FnMut(ReactStep) + Send,
{
loop {
if self.step_count >= self.config.max_steps {
let msg = format!(
"Agent reached maximum reasoning steps ({}) without producing a final answer.",
self.config.max_steps
);
on_chunk(ReactStep::Answer {
step: self.step_count,
answer: msg.clone(),
});
return Ok(msg);
}
self.step_count += 1;
let step = self.step_count;
// For ReAct we force text-only responses so the model follows our JSON-in-text format.
let tool_choice_str = if self.tool_definitions.is_empty() {
None
} else {
Some("none")
};
let response = call_with_params(
&self.messages,
model_name,
client_config,
0.2, // temperature
4096, // max output tokens
None,
if self.tool_definitions.is_empty() {
None
} else {
Some(&self.tool_definitions)
},
tool_choice_str,
)
.await?;
let parsed = parse_react_response(&response.content);
let answer = parsed.answer.clone();
let action = parsed.action.clone();
on_chunk(ReactStep::Thought {
step,
thought: parsed.thought.clone(),
});
match self.hook.on_thought(step, &parsed.thought).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at thought step: {}",
reason
)));
}
HookAction::Skip => {}
HookAction::Continue => {}
}
// Final answer — emit and return.
if let Some(ans) = answer {
on_chunk(ReactStep::Answer {
step,
answer: ans.clone(),
});
match self.hook.on_answer(step, &ans).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at answer step: {}",
reason
)));
}
_ => {}
}
return Ok(ans);
}
// No answer — either do a tool call or fall back.
let Some(act) = action else {
let content = response.content.clone();
on_chunk(ReactStep::Answer {
step,
answer: content.clone(),
});
match self.hook.on_answer(step, &content).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at fallback answer: {}",
reason
)));
}
_ => {}
}
return Ok(content);
};
on_chunk(ReactStep::Action {
step,
action: act.clone(),
});
let args_json = serde_json::to_string(&act.args).unwrap_or_else(|_| "{}".to_string());
match self.hook.on_tool_call(step, &act.name, &args_json).await {
ToolCallAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at tool call: {}",
reason
)));
}
ToolCallAction::Skip(injected_result) => {
let observation = injected_result;
on_chunk(ReactStep::Observation {
step,
observation: observation.clone(),
});
match self.hook.on_observation(step, &observation).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at observation (injected): {}",
reason
)));
}
_ => {}
}
// Append assistant message with tool_calls.
let assistant_msg = build_tool_call_message(&act);
self.messages.push(assistant_msg);
// Append observation as a tool message.
self.messages.push(ChatRequestMessage::tool(&act.id, observation));
continue;
}
ToolCallAction::Continue => {}
}
// Append the assistant message with tool_calls.
let assistant_msg = build_tool_call_message(&act);
self.messages.push(assistant_msg);
// Execute the tool.
let observation = match &self.config.tool_executor {
Some(exec) => {
let result = exec(act.name.clone(), act.args.clone()).await;
match result {
Ok(v) => serde_json::to_string(&v).unwrap_or_else(|_| "null".to_string()),
Err(e) => serde_json::json!({ "error": e }).to_string(),
}
}
None => serde_json::json!({
"error": format!("no tool executor registered for '{}'", act.name)
})
.to_string(),
};
on_chunk(ReactStep::Observation {
step,
observation: observation.clone(),
});
match self.hook.on_observation(step, &observation).await {
HookAction::Terminate(reason) => {
return Err(AgentError::Internal(format!(
"hook terminated at observation step: {}",
reason
)));
}
_ => {}
}
// Append observation as a tool message.
self.messages.push(ChatRequestMessage::tool(&act.id, observation));
}
}
/// Returns the number of steps executed so far.
pub fn steps(&self) -> usize {
self.step_count
}
}
// ---------------------------------------------------------------------------
// Response parsing
// ---------------------------------------------------------------------------
/// Structured content of one model reply, as produced by `parse_react_response`.
struct ParsedReActResponse {
    /// Reasoning text for this step. Holds the raw reply when it could not be parsed,
    /// or `"Thinking..."` when the parsed JSON had no `thought` field.
    thought: String,
    /// Tool call requested by the model, if any.
    action: Option<Action>,
    /// Final answer, if the model produced one.
    answer: Option<String>,
}
fn parse_react_response(content: &str) -> ParsedReActResponse {
let json_str = extract_json(content).unwrap_or_else(|| content.trim().to_string());
#[derive(serde::Deserialize)]
struct RawStep {
#[serde(default)]
thought: Option<String>,
#[serde(default)]
action: Option<RawAction>,
#[serde(default)]
answer: Option<String>,
#[serde(default)]
name: Option<String>,
#[serde(default, rename = "arguments")]
args: Option<serde_json::Value>,
}
#[derive(serde::Deserialize)]
struct RawAction {
#[serde(default)]
name: Option<String>,
#[serde(default, rename = "arguments")]
args: Option<serde_json::Value>,
}
match serde_json::from_str::<RawStep>(&json_str) {
Ok(raw) => {
let thought = raw.thought.unwrap_or_else(|| "Thinking...".to_string());
let answer = raw.answer;
let action = raw.action.map(|a| Action {
id: Uuid::new_v4().to_string(),
name: a.name.unwrap_or_default(),
args: a.args.unwrap_or(serde_json::Value::Null),
});
let action = action.or_else(|| {
if raw.name.is_some() || raw.args.is_some() {
Some(Action {
id: Uuid::new_v4().to_string(),
name: raw.name.unwrap_or_default(),
args: raw.args.unwrap_or(serde_json::Value::Null),
})
} else {
None
}
});
ParsedReActResponse {
thought,
action,
answer,
}
}
Err(_) => ParsedReActResponse {
thought: content.to_string(),
action: None,
answer: None,
},
}
}
/// Extract the first JSON object or array embedded in a model reply.
///
/// Candidates are tried in this order, and the first one that parses as JSON wins:
/// 1. the whole trimmed text, when it starts with `{` or `[`;
/// 2. the body of each Markdown code fence, in order (any language tag, or none);
/// 3. each bracket-balanced `{...}` / `[...]` span, scanning left to right.
///
/// In step 3, brackets directly preceded by an alphanumeric character, `_`, `"` or `'`
/// are skipped. Such brackets are more likely part of surrounding prose or a quoted
/// string than the start of a standalone JSON value.
///
/// Returns `None` when no candidate is valid JSON; `parse_react_response` then falls back
/// to the raw text.
fn extract_json(s: &str) -> Option<String> {
    let trimmed = s.trim();
    if (trimmed.starts_with('{') || trimmed.starts_with('[')) && is_valid_json(trimmed) {
        return Some(trimmed.to_string());
    }
    extract_fenced_json(trimmed).or_else(|| extract_balanced_json(trimmed))
}

/// True when `s` is a syntactically valid JSON document (surrounding whitespace allowed).
fn is_valid_json(s: &str) -> bool {
    serde_json::from_str::<serde_json::Value>(s).is_ok()
}

/// Return the trimmed body of the first Markdown code fence whose contents parse as JSON.
///
/// Any line starting with three backticks opens a fence, whatever its language tag.
/// A line that is exactly three backticks closes it; an unterminated fence runs to the
/// end of the text.
fn extract_fenced_json(text: &str) -> Option<String> {
    let mut lines = text.lines().map(str::trim);
    while lines.any(|line| line.starts_with("```")) {
        // `take_while` also consumes the closing fence, so the next search starts after it.
        let body: Vec<&str> = lines.by_ref().take_while(|line| *line != "```").collect();
        let joined = body.join("\n");
        let candidate = joined.trim();
        if !candidate.is_empty() && is_valid_json(candidate) {
            return Some(candidate.to_string());
        }
    }
    None
}

/// Return the first bracket-balanced `{...}` / `[...]` span in `text` that parses as JSON.
fn extract_balanced_json(text: &str) -> Option<String> {
    text.char_indices()
        .filter(|&(_, c)| c == '{' || c == '[')
        .filter(|&(start, _)| {
            // Skip brackets glued to an identifier or a quote, e.g. `foo{` or `"{`.
            let prev = text[..start].chars().next_back();
            !matches!(prev, Some(p) if p.is_alphanumeric() || p == '_' || p == '"' || p == '\'')
        })
        .find_map(|(start, _)| {
            let tail = &text[start..];
            let candidate = &tail[..balanced_len(tail)?];
            is_valid_json(candidate).then(|| candidate.to_string())
        })
}

/// Byte length of the shortest bracket-balanced prefix of `s`, which must start with `{`
/// or `[`. Brackets inside JSON string literals are ignored. Returns `None` if the
/// brackets never balance.
///
/// A valid JSON value starting at `s` ends exactly where this depth first returns to
/// zero, so only that one prefix needs to be parsed.
fn balanced_len(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;
    for (i, c) in s.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '"' => in_string = true,
            '{' | '[' => depth += 1,
            '}' | ']' => {
                // `s` starts with an opening bracket, so depth is at least 1 here.
                depth -= 1;
                if depth == 0 {
                    return Some(i + c.len_utf8());
                }
            }
            _ => {}
        }
    }
    None
}
/// Build the assistant message that records `action` as a single function-type tool call.
///
/// The message content is a short `Action: <name>` label. The tool call reuses the
/// action's id so a later `tool` message can reference it. Arguments are serialized to a
/// JSON string, falling back to `"{}"` if serialization fails.
fn build_tool_call_message(action: &Action) -> ChatRequestMessage {
    use crate::client::types::{ToolCall, ToolCallFunction};

    let call = ToolCall {
        id: action.id.clone(),
        type_: "function".into(),
        function: ToolCallFunction {
            name: action.name.clone(),
            arguments: serde_json::to_string(&action.args).unwrap_or_else(|_| "{}".to_string()),
        },
    };

    ChatRequestMessage {
        role: "assistant".into(),
        content: Some(format!("Action: {}", action.name)),
        name: None,
        tool_call_id: None,
        tool_calls: Some(vec![call]),
    }
}