// gitdataai/lib/ai/agent/loop.rs
// 2026-05-30 01:38:40 +08:00
//
// 877 lines
// 32 KiB
// Rust

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use futures::StreamExt;
use rig::agent::AgentBuilder;
use rig::client::CompletionClient;
use rig::streaming::StreamingPrompt;
use rig::tool::ToolDyn;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use super::config::AgentConfig;
use super::error_classifier::{
classify_error, retry_policy_for, should_switch_to_fallback,
};
use super::events::{AgentEvent, EventSink};
use super::helpers::{build_input_string, estimate_tokens};
use super::hooks::{HookChain, HookLlmResponse, HookMessage, ToolCallOutcome, ToolGuardrailDecision};
use super::iteration_budget::IterationBudget;
use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord};
use super::RigStreamChunk;
use crate::client::AiClient;
use crate::error::{AiError, AiResult};
/// Strategy for executing the tool calls emitted in a single assistant turn.
///
/// Defaults to [`ToolExecutionMode::Parallel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ToolExecutionMode {
    /// Execute tool calls one at a time.
    Sequential,
    /// Execute tool calls concurrently (after sequential preflight).
    #[default]
    Parallel,
}
/// Boxed `Send` future that resolves to a batch of user messages.
type BoxedMessageBatch = Pin<Box<dyn Future<Output = Vec<String>> + Send>>;

/// Callback polled for steering messages that are injected mid-run.
pub type SteeringFn = Arc<dyn Fn() -> BoxedMessageBatch + Send + Sync>;

/// Callback polled for follow-up messages once the agent would otherwise stop.
pub type FollowUpFn = Arc<dyn Fn() -> BoxedMessageBatch + Send + Sync>;

/// Predicate that decides, from the finished turn's [`TurnContext`], whether
/// the agent should stop.
pub type ShouldStopFn = Arc<dyn Fn(&TurnContext) -> bool + Send + Sync>;

/// Callback that may return a [`TurnUpdate`] (model / sampling overrides)
/// to apply before the next turn.
pub type PrepareNextTurnFn = Arc<
    dyn Fn(&TurnContext) -> Pin<Box<dyn Future<Output = Option<TurnUpdate>> + Send>>
        + Send
        + Sync,
>;
/// Snapshot of a just-finished turn, passed to the `should_stop` and
/// `prepare_next_turn` callbacks.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TurnContext {
    /// Zero-based index of the turn that just finished.
    pub turn_index: usize,
    /// Assistant text produced during the turn (empty if none).
    pub assistant_text: String,
    /// Number of tool calls recorded during the turn.
    pub tool_call_count: usize,
    /// Input tokens accumulated over the run so far.
    pub total_input_tokens: u64,
    /// Output tokens accumulated over the run so far.
    pub total_output_tokens: u64,
    /// Model that served the turn.
    pub model_name: String,
}

/// Overrides to apply before the next turn (returned by `prepare_next_turn`).
///
/// A `None` field leaves the current setting unchanged. `Default` is derived
/// so a callback can override a single field, e.g.
/// `TurnUpdate { model: Some(name), ..Default::default() }`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TurnUpdate {
    /// Model to switch to.
    pub model: Option<String>,
    /// Sampling temperature to use from now on.
    pub temperature: Option<f64>,
    /// Completion-token cap to use from now on.
    pub max_completion_tokens: Option<u64>,
}
/// Agent loop settings: the base [`AgentConfig`] plus optional steering,
/// follow-up, stop, next-turn and event-sink hooks.
///
/// Create with [`AgentLoopConfig::new`] and chain the `with_*` setters.
pub struct AgentLoopConfig {
    /// Base agent configuration (model, prompt, limits, ...).
    pub config: AgentConfig,
    /// Requested tool execution mode.
    pub tool_execution_mode: ToolExecutionMode,
    /// Source of steering messages injected while the agent runs.
    pub get_steering_messages: Option<SteeringFn>,
    /// Source of follow-up messages consulted when the agent would stop.
    pub get_follow_up_messages: Option<FollowUpFn>,
    /// Predicate deciding whether to stop after a turn.
    pub should_stop_after_turn: Option<ShouldStopFn>,
    /// Callback producing overrides for the next turn.
    pub prepare_next_turn: Option<PrepareNextTurnFn>,
    /// Receiver for [`AgentEvent`]s emitted during the run.
    pub event_sink: Option<EventSink>,
}

impl AgentLoopConfig {
    /// Wrap `config` with the default tool execution mode and no callbacks
    /// or event sink.
    pub fn new(config: AgentConfig) -> Self {
        let tool_execution_mode = ToolExecutionMode::default();
        Self {
            config,
            tool_execution_mode,
            event_sink: None,
            prepare_next_turn: None,
            should_stop_after_turn: None,
            get_follow_up_messages: None,
            get_steering_messages: None,
        }
    }

    /// Replace the tool execution mode.
    pub fn with_tool_execution_mode(self, mode: ToolExecutionMode) -> Self {
        Self {
            tool_execution_mode: mode,
            ..self
        }
    }

    /// Install the steering-message callback.
    pub fn with_steering_messages(self, f: SteeringFn) -> Self {
        Self {
            get_steering_messages: Some(f),
            ..self
        }
    }

    /// Install the follow-up-message callback.
    pub fn with_follow_up_messages(self, f: FollowUpFn) -> Self {
        Self {
            get_follow_up_messages: Some(f),
            ..self
        }
    }

    /// Install the stop-after-turn predicate.
    pub fn with_should_stop(self, f: ShouldStopFn) -> Self {
        Self {
            should_stop_after_turn: Some(f),
            ..self
        }
    }

    /// Install the next-turn preparation callback.
    pub fn with_prepare_next_turn(self, f: PrepareNextTurnFn) -> Self {
        Self {
            prepare_next_turn: Some(f),
            ..self
        }
    }

    /// Install the event sink.
    pub fn with_event_sink(self, sink: EventSink) -> Self {
        Self {
            event_sink: Some(sink),
            ..self
        }
    }
}
/// Enhanced agent with loop controls (steering, follow-up, model switching).
pub struct EnhancedAgent {
pub client: AiClient,
pub loop_config: AgentLoopConfig,
pub hooks: HookChain,
}
impl EnhancedAgent {
pub fn new(client: AiClient, loop_config: AgentLoopConfig) -> AiResult<Self> {
loop_config.config.validate()?;
Ok(Self {
client,
loop_config,
hooks: HookChain::empty(),
})
}
pub fn with_hooks(mut self, hooks: HookChain) -> Self {
self.hooks = hooks;
self
}
pub fn config(&self) -> &AgentConfig {
&self.loop_config.config
}
/// Run the enhanced agent loop, returning a chunk receiver and a join handle.
#[allow(clippy::too_many_lines)]
pub fn run(
&self,
request: AgentRequest,
tools: Vec<Box<dyn ToolDyn>>,
) -> (
mpsc::Receiver<RigStreamChunk>,
tokio::task::JoinHandle<AiResult<AgentResult>>,
) {
let (tx, rx) = mpsc::channel::<RigStreamChunk>(256);
let config = self.loop_config.config.clone();
let tool_execution_mode = self.loop_config.tool_execution_mode;
let steering_fn = self.loop_config.get_steering_messages.clone();
let follow_up_fn = self.loop_config.get_follow_up_messages.clone();
let should_stop = self.loop_config.should_stop_after_turn.clone();
let prepare_next = self.loop_config.prepare_next_turn.clone();
let event_sink = self.loop_config.event_sink.clone();
let client = self.client.llm_client().clone();
let hooks = self.hooks.clone();
let filtered_tools: Vec<Box<dyn ToolDyn>> = tools
.into_iter()
.filter(|tool| config.is_tool_exposed(&tool.name()))
.collect();
let handle = tokio::spawn(async move {
run_enhanced_loop(
client,
config,
request,
filtered_tools,
tool_execution_mode,
steering_fn,
follow_up_fn,
should_stop,
prepare_next,
event_sink,
hooks,
tx,
)
.await
});
(rx, handle)
}
}
/// Drive the enhanced agent loop for one request until it stops, fails or is
/// cancelled.
///
/// * The inner loop runs one LLM turn per iteration via [`run_single_turn`].
///   Before each turn it checks cancellation and polls `steering_fn`. After
///   each turn it records an [`AgentStep`], emits `TurnEnd`, consults
///   `should_stop` and applies any [`TurnUpdate`] from `prepare_next`. It
///   continues only while the turn produced tool calls.
/// * Once the inner loop ends, the outer loop polls `follow_up_fn` and starts
///   another round of turns if that returned messages.
///
/// When a turn fails, the error is classified. If the category calls for a
/// model switch and `config.fallback_model` is set, the turn is retried once
/// on the fallback model, which then stays selected. A turn that still fails
/// ends the run.
///
/// Every failing exit sends a [`RigStreamChunk::Failed`] chunk and runs the
/// session-end hook with `false`. A successful run sends
/// [`RigStreamChunk::Final`] and runs the session-end hook with `true`.
/// `_tool_execution_mode` is currently not consulted.
///
/// # Errors
/// * [`AiError::TokenBudgetExceeded`] if the estimated initial input exceeds
///   `config.max_total_tokens_per_run`.
/// * [`AiError::Response`] if the request's cancellation token fires between
///   turns.
/// * Otherwise, the error of the failing turn (after any fallback attempt).
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
async fn run_enhanced_loop(
    client: rig::providers::openai::Client,
    mut config: AgentConfig,
    request: AgentRequest,
    tools: Vec<Box<dyn ToolDyn>>,
    _tool_execution_mode: ToolExecutionMode,
    steering_fn: Option<SteeringFn>,
    follow_up_fn: Option<FollowUpFn>,
    should_stop: Option<ShouldStopFn>,
    prepare_next: Option<PrepareNextTurnFn>,
    event_sink: Option<EventSink>,
    hooks: HookChain,
    tx: mpsc::Sender<RigStreamChunk>,
) -> AiResult<AgentResult> {
    let cancellation = request.cancellation_token.clone();
    let timeout = request.timeout;
    let mut budget = IterationBudget::new(config.iteration_budget);
    let mut all_steps: Vec<AgentStep> = Vec::new();
    let mut total_input_tokens: u64 = 0;
    let mut total_output_tokens: u64 = 0;
    let mut turn_index: usize = 0;

    // Session start hook (best effort: a hook error does not abort the run).
    if let Some(ctx) = &request.run_context {
        let _ = hooks.run_session_start(ctx).await;
    }
    if let Some(sink) = &event_sink {
        sink.emit(AgentEvent::AgentStart);
    }

    // Transcript sent to the model. It only grows: steering, assistant and
    // follow-up text are appended as the run progresses.
    let mut current_input = build_input_string(&request);
    let estimated_input_tokens = estimate_tokens(&current_input);
    if let Some(limit) = config.max_total_tokens_per_run
        && estimated_input_tokens > limit as u64
    {
        let err = AiError::TokenBudgetExceeded {
            estimated: estimated_input_tokens,
            limit,
        };
        // The session was already started above; close it like any other failure.
        report_enhanced_run_failure(&request, &hooks, &tx, err.to_string()).await;
        return Err(err);
    }

    // Outer loop: keeps the run going while follow-up messages arrive.
    loop {
        // Inner loop: one LLM turn per iteration.
        loop {
            if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
                if let Some(sink) = &event_sink {
                    sink.emit(AgentEvent::ErrorClassified {
                        category: "cancelled".to_string(),
                        message: "cancelled by caller".to_string(),
                        will_retry: false,
                        retry_delay_ms: None,
                    });
                }
                report_enhanced_run_failure(&request, &hooks, &tx, "cancelled".to_string())
                    .await;
                return Err(AiError::Response("agent run cancelled".to_string()));
            }

            // Poll steering before every turn, not just once per outer round,
            // so messages queued while a turn was running reach the next turn.
            let steering = match &steering_fn {
                Some(f) => f().await,
                None => Vec::new(),
            };
            if !steering.is_empty() {
                append_user_messages(&mut current_input, &steering);
                if let Some(sink) = &event_sink {
                    sink.emit(AgentEvent::SteeringMessagesInjected {
                        count: steering.len(),
                    });
                }
            }

            if let Some(sink) = &event_sink {
                sink.emit(AgentEvent::TurnStart { turn_index });
            }
            let _ = tx
                .send(RigStreamChunk::TextDelta {
                    index: 0,
                    content: String::new(), // placeholder for turn boundary detection
                })
                .await;

            let first_attempt = run_single_turn(
                &client,
                &config,
                &current_input,
                &tools,
                &mut budget,
                &cancellation,
                timeout,
                &hooks,
                &event_sink,
                &tx,
            )
            .await;

            // On failure, classify the error and retry once on the fallback
            // model when both the category and the config allow it.
            let turn_result = match first_attempt {
                Ok(output) => Ok(output),
                Err(e) => {
                    let category = classify_error(&e, None);
                    let policy = retry_policy_for(
                        &category,
                        config.retry_max_attempts,
                        config.retry_base_delay_ms,
                    );
                    let fallback_model = if should_switch_to_fallback(&category) {
                        config.fallback_model.clone()
                    } else {
                        None
                    };
                    // The fallback attempt is the only retry this loop performs,
                    // so report `will_retry` only when it will actually happen.
                    let will_retry = fallback_model.is_some();
                    if let Some(sink) = &event_sink {
                        sink.emit(AgentEvent::ErrorClassified {
                            category: format!("{category:?}"),
                            message: e.to_string(),
                            will_retry,
                            retry_delay_ms: will_retry
                                .then(|| policy.base_delay.as_millis() as u64),
                        });
                    }
                    match fallback_model {
                        Some(fallback_model) => {
                            info!(
                                from_model = %config.model,
                                to_model = %fallback_model,
                                "switching to fallback model due to error"
                            );
                            if let Some(sink) = &event_sink {
                                sink.emit(AgentEvent::ModelSwitched {
                                    from_model: config.model.clone(),
                                    to_model: fallback_model.clone(),
                                    reason: format!("fallback: {category:?}"),
                                });
                            }
                            config.model = fallback_model;
                            run_single_turn(
                                &client,
                                &config,
                                &current_input,
                                &tools,
                                &mut budget,
                                &cancellation,
                                timeout,
                                &hooks,
                                &event_sink,
                                &tx,
                            )
                            .await
                        }
                        None => Err(e),
                    }
                }
            };

            let turn_output = match turn_result {
                Ok(output) => output,
                Err(e) => {
                    report_enhanced_run_failure(&request, &hooks, &tx, e.to_string()).await;
                    return Err(e);
                }
            };

            // A single success path for first attempts and fallback retries alike,
            // so both emit TurnEnd and consult should_stop / prepare_next.
            total_input_tokens += turn_output.input_tokens;
            total_output_tokens += turn_output.output_tokens;
            let tool_call_count = turn_output.tool_calls.len();
            let assistant_text = turn_output.assistant_text;
            if tool_call_count > 0 || !assistant_text.is_empty() {
                all_steps.push(AgentStep {
                    index: all_steps.len(),
                    assistant: (!assistant_text.is_empty()).then(|| assistant_text.clone()),
                    reasoning_content: None,
                    tool_calls: turn_output.tool_calls,
                    reflection: None,
                });
            }

            if let Some(sink) = &event_sink {
                sink.emit(AgentEvent::TurnEnd {
                    turn_index,
                    assistant_text: Some(assistant_text.clone()),
                    tool_call_count,
                });
            }

            let turn_ctx = TurnContext {
                turn_index,
                assistant_text: assistant_text.clone(),
                tool_call_count,
                total_input_tokens,
                total_output_tokens,
                model_name: config.model.clone(),
            };
            // Advance before any early exit so a follow-up round never reuses
            // the index of the turn that was just stopped.
            turn_index += 1;

            if let Some(stop_fn) = &should_stop
                && stop_fn(&turn_ctx)
            {
                info!(
                    turn_index = turn_ctx.turn_index,
                    "agent stopped by should_stop callback"
                );
                break;
            }

            // Prepare next turn (may switch model / sampling settings).
            if let Some(prep_fn) = &prepare_next
                && let Some(update) = prep_fn(&turn_ctx).await
            {
                if let Some(new_model) = update.model {
                    if let Some(sink) = &event_sink {
                        sink.emit(AgentEvent::ModelSwitched {
                            from_model: config.model.clone(),
                            to_model: new_model.clone(),
                            reason: "prepare_next_turn".to_string(),
                        });
                    }
                    config.model = new_model;
                }
                if let Some(temp) = update.temperature {
                    config.temperature = Some(temp);
                }
                if let Some(max_tok) = update.max_completion_tokens {
                    config.max_completion_tokens = Some(max_tok);
                }
            }

            // No tool calls: the agent is done with this round.
            if tool_call_count == 0 {
                break;
            }
            // Keep the whole transcript and append the reply. Previously the
            // transcript was replaced by the reply alone, which dropped the
            // original request and injected messages.
            append_assistant_message(&mut current_input, &assistant_text);
        }

        let follow_ups = match &follow_up_fn {
            Some(f) => f().await,
            None => Vec::new(),
        };
        if follow_ups.is_empty() {
            break;
        }
        append_user_messages(&mut current_input, &follow_ups);
        if let Some(sink) = &event_sink {
            sink.emit(AgentEvent::FollowUpMessagesInjected {
                count: follow_ups.len(),
            });
        }
    }

    // Final output is the last step's assistant text (empty if it had none).
    let output = all_steps
        .last()
        .and_then(|s| s.assistant.clone())
        .unwrap_or_default();
    if let Some(sink) = &event_sink {
        sink.emit(AgentEvent::AgentEnd {
            messages: Vec::new(),
            total_input_tokens,
            total_output_tokens,
        });
    }
    let _ = tx
        .send(RigStreamChunk::Final {
            content: output.clone(),
            input_tokens: total_input_tokens,
            output_tokens: total_output_tokens,
        })
        .await;
    if let Some(ctx) = &request.run_context {
        let _ = hooks.run_session_end(ctx, true).await;
    }
    info!(
        turns = turn_index,
        steps = all_steps.len(),
        total_input_tokens,
        total_output_tokens,
        "enhanced agent loop completed"
    );
    Ok(AgentResult {
        output,
        steps: all_steps,
        expert_outputs: Vec::new(),
        input_tokens: total_input_tokens as i64,
        output_tokens: total_output_tokens as i64,
    })
}

/// Append each message to `transcript` as a `\nUser: <msg>\n` line.
fn append_user_messages(transcript: &mut String, messages: &[String]) {
    for msg in messages {
        transcript.push_str("\nUser: ");
        transcript.push_str(msg);
        transcript.push('\n');
    }
}

/// Append a non-empty assistant reply to `transcript` as an `Assistant:` line,
/// mirroring the `User:` lines above.
/// NOTE(review): confirm this matches the labels `build_input_string` uses.
fn append_assistant_message(transcript: &mut String, text: &str) {
    if text.is_empty() {
        return;
    }
    transcript.push_str("\nAssistant: ");
    transcript.push_str(text);
    transcript.push('\n');
}

/// Close a failed run: send a `Failed` chunk carrying `error`, then run the
/// session-end hook with `false` when the request has a run context.
/// Send and hook errors are ignored (best effort).
async fn report_enhanced_run_failure(
    request: &AgentRequest,
    hooks: &HookChain,
    tx: &mpsc::Sender<RigStreamChunk>,
    error: String,
) {
    let _ = tx.send(RigStreamChunk::Failed { error }).await;
    if let Some(ctx) = &request.run_context {
        let _ = hooks.run_session_end(ctx, false).await;
    }
}
/// Result of one `run_single_turn` call: the streamed assistant text, the
/// tool calls seen and the token usage reported for the turn.
struct TurnOutput {
    /// Input tokens from the final response's usage (0 if none arrived).
    input_tokens: u64,
    /// Output tokens from the final response's usage (0 if none arrived).
    output_tokens: u64,
    /// Concatenation of all text deltas streamed during the turn.
    assistant_text: String,
    /// One record per tool call observed during the turn.
    tool_calls: Vec<ToolCallRecord>,
}
/// Run one LLM turn with streaming.
///
/// Builds a rig agent for `config.model`, streams the response, and forwards
/// text, reasoning and tool activity to `tx` and `event_sink`. Collects the
/// assistant text and per-call [`ToolCallRecord`]s along the way.
///
/// One unit of `budget` is consumed up front. When `timeout` is set, it bounds
/// the whole turn: stream setup and every item read. Cancellation is checked
/// each time a stream item arrives.
///
/// NOTE(review): `_tools` is never registered on the agent builder, so this
/// function does not offer any tools to the model. Guardrail decisions below
/// are recorded and reported but do not prevent rig from executing a call.
/// Confirm and wire tools through (this needs owned tools).
///
/// # Errors
/// * `AiError::Response` when the iteration budget is exhausted or the run is
///   cancelled.
/// * `AiError::Timeout` when `timeout` elapses.
/// * `AiError::Api` when the stream yields an error.
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
async fn run_single_turn(
    client: &rig::providers::openai::Client,
    config: &AgentConfig,
    input: &str,
    _tools: &[Box<dyn ToolDyn>],
    budget: &mut IterationBudget,
    cancellation: &Option<CancellationToken>,
    timeout: Option<std::time::Duration>,
    hooks: &HookChain,
    event_sink: &Option<EventSink>,
    tx: &mpsc::Sender<RigStreamChunk>,
) -> AiResult<TurnOutput> {
    if !budget.consume() {
        return Err(AiError::Response("iteration budget exhausted".to_string()));
    }
    // Deadline for the whole turn, paired with the configured duration for the
    // error message. Deltas arrive while the stream is polled, so bounding only
    // the stream construction (as before) never bounded the model traffic.
    let deadline = timeout.map(|dur| (tokio::time::Instant::now() + dur, dur));

    let model = client.completion_model(&config.model);
    // NOTE(review): `default_max_turns(1)` here versus `multi_turn(config.max_iterations)`
    // on the request below; confirm which one rig honours.
    let mut agent_builder = AgentBuilder::new(model)
        .preamble(&config.system_prompt)
        .default_max_turns(1);
    if let Some(temp) = config.temperature {
        agent_builder = agent_builder.temperature(temp);
    }
    if let Some(mt) = config.max_completion_tokens {
        agent_builder = agent_builder.max_tokens(mt);
    }
    let agent = agent_builder.build();

    // Pre-LLM hook (best effort).
    if !hooks.is_empty() {
        let hook_messages = vec![HookMessage {
            role: "user".to_string(),
            content: Some(input.to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];
        let _ = hooks.run_pre_llm_call(&hook_messages, &[]).await;
    }

    let stream_future = agent
        .stream_prompt(input)
        .with_history(Vec::<rig::completion::Message>::new())
        .multi_turn(config.max_iterations);
    let stream = match deadline {
        Some((_, dur)) => tokio::time::timeout(dur, stream_future)
            .await
            .map_err(|_| AiError::Timeout {
                seconds: dur.as_secs(),
            })?,
        None => stream_future.await,
    };
    tokio::pin!(stream);

    let mut assistant_text = String::new();
    let mut tool_calls: Vec<ToolCallRecord> = Vec::new();
    let mut delta_index = 0usize;
    let mut input_tokens: u64 = 0;
    let mut output_tokens: u64 = 0;

    loop {
        let next = match deadline {
            Some((at, dur)) => match tokio::time::timeout_at(at, stream.next()).await {
                Ok(next) => next,
                Err(_) => {
                    return Err(AiError::Timeout {
                        seconds: dur.as_secs(),
                    });
                }
            },
            None => stream.next().await,
        };
        let Some(item) = next else {
            break;
        };
        if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
            return Err(AiError::Response("cancelled".to_string()));
        }
        match item {
            Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
                rig::streaming::StreamedAssistantContent::Text(text),
            )) => {
                assistant_text.push_str(&text.text);
                if let Some(sink) = &event_sink {
                    sink.emit(AgentEvent::MessageTextDelta {
                        index: delta_index,
                        delta: text.text.clone(),
                    });
                }
                let _ = tx
                    .send(RigStreamChunk::TextDelta {
                        index: delta_index,
                        content: text.text.clone(),
                    })
                    .await;
                delta_index += 1;
            }
            Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
                rig::streaming::StreamedAssistantContent::Reasoning(reasoning),
            )) => {
                for part in &reasoning.content {
                    if let rig::completion::message::ReasoningContent::Text { text, .. } = part {
                        if let Some(sink) = &event_sink {
                            sink.emit(AgentEvent::MessageThinkingDelta {
                                index: delta_index,
                                delta: text.clone(),
                            });
                        }
                        let _ = tx
                            .send(RigStreamChunk::Thinking {
                                index: delta_index,
                                content: text.clone(),
                            })
                            .await;
                        delta_index += 1;
                    }
                }
            }
            Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
                rig::streaming::StreamedAssistantContent::ReasoningDelta { reasoning, .. },
            )) => {
                if let Some(sink) = &event_sink {
                    sink.emit(AgentEvent::MessageThinkingDelta {
                        index: delta_index,
                        delta: reasoning.clone(),
                    });
                }
                let _ = tx
                    .send(RigStreamChunk::Thinking {
                        index: delta_index,
                        content: reasoning.clone(),
                    })
                    .await;
                delta_index += 1;
            }
            Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
                rig::streaming::StreamedAssistantContent::ToolCall { tool_call, .. },
            )) => {
                // Arguments may arrive as a JSON string or as structured JSON.
                let args = match &tool_call.function.arguments {
                    serde_json::Value::String(s) => s.clone(),
                    v => serde_json::to_string(v).unwrap_or_default(),
                };
                let tool_name = tool_call.function.name.clone();
                let tool_args: serde_json::Value =
                    serde_json::from_str(&args).unwrap_or_default();
                // Pre-tool-call guardrail hook.
                if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await {
                    match decision {
                        ToolGuardrailDecision::Allow => {}
                        ToolGuardrailDecision::Block { reason } => {
                            if let Some(sink) = &event_sink {
                                sink.emit(AgentEvent::ToolExecutionEnd {
                                    tool_call_id: tool_call.id.clone(),
                                    tool_name: tool_name.clone(),
                                    output: None,
                                    error: Some(reason.clone()),
                                    elapsed_ms: 0,
                                });
                            }
                            let _ = tx
                                .send(RigStreamChunk::ToolCallFinished {
                                    tool_call_id: tool_call.id.clone(),
                                    tool_name: tool_name.clone(),
                                    output: format!("blocked: {reason}"),
                                    error: Some(reason),
                                })
                                .await;
                            tool_calls.push(ToolCallRecord {
                                id: tool_call.id.clone(),
                                name: tool_name,
                                arguments: tool_args,
                                output: None,
                                error: Some("blocked by guardrail".to_string()),
                                elapsed_ms: None,
                            });
                            continue;
                        }
                        ToolGuardrailDecision::RequireApproval { message } => {
                            tool_calls.push(ToolCallRecord {
                                id: tool_call.id.clone(),
                                name: tool_name.clone(),
                                arguments: tool_args,
                                output: None,
                                error: Some(format!("requires approval: {message}")),
                                elapsed_ms: None,
                            });
                            continue;
                        }
                    }
                }
                if let Some(sink) = &event_sink {
                    sink.emit(AgentEvent::ToolExecutionStart {
                        tool_call_id: tool_call.id.clone(),
                        tool_name: tool_name.clone(),
                        arguments: tool_args.clone(),
                    });
                }
                let _ = tx
                    .send(RigStreamChunk::ToolCallStarted {
                        tool_call_id: tool_call.id.clone(),
                        tool_name: tool_name.clone(),
                        arguments: args.clone(),
                    })
                    .await;
                tool_calls.push(ToolCallRecord {
                    id: tool_call.id.clone(),
                    name: tool_name,
                    arguments: tool_args,
                    output: None,
                    error: None,
                    elapsed_ms: None,
                });
            }
            Ok(rig::agent::MultiTurnStreamItem::StreamUserItem(
                rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
            )) => {
                let content =
                    super::helpers::tool_result_content_to_string(&tool_result.content);
                // Attribute the result to the call with the same id rather than
                // to the most recent call, so results of earlier calls land on
                // the right record.
                let matched = tool_calls.iter().rposition(|tc| tc.id == tool_result.id);
                let (tool_name, tool_arguments) = match matched {
                    Some(i) => {
                        let record = &mut tool_calls[i];
                        // Plain-text output is kept as a JSON string instead of
                        // collapsing to `null`.
                        record.output = Some(
                            serde_json::from_str::<serde_json::Value>(&content)
                                .unwrap_or_else(|_| serde_json::Value::String(content.clone())),
                        );
                        (record.name.clone(), record.arguments.clone())
                    }
                    // Unknown id: keep the best-effort attribution to the latest call's name.
                    None => (
                        tool_calls
                            .last()
                            .map(|tc| tc.name.clone())
                            .unwrap_or_default(),
                        serde_json::Value::Null,
                    ),
                };
                if let Some(sink) = &event_sink {
                    sink.emit(AgentEvent::ToolExecutionEnd {
                        tool_call_id: tool_result.id.clone(),
                        tool_name: tool_name.clone(),
                        output: Some(serde_json::Value::String(content.clone())),
                        error: None,
                        elapsed_ms: 0,
                    });
                }
                let _ = tx
                    .send(RigStreamChunk::ToolCallFinished {
                        tool_call_id: tool_result.id.clone(),
                        tool_name: tool_name.clone(),
                        output: content.clone(),
                        error: None,
                        })
                    .await;
                if !hooks.is_empty() {
                    // Report the tool's name and arguments (previously the call
                    // id was passed as the name and the arguments were dropped).
                    let outcome = ToolCallOutcome {
                        name: tool_name,
                        arguments: tool_arguments,
                        output: Some(serde_json::Value::String(content)),
                        error: None,
                        elapsed_ms: 0,
                    };
                    let _ = hooks.run_post_tool_call(&outcome).await;
                }
            }
            Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
                let usage = resp.usage();
                input_tokens = usage.input_tokens;
                output_tokens = usage.output_tokens;
                if !hooks.is_empty() {
                    let hook_response = HookLlmResponse {
                        content: Some(assistant_text.clone()),
                        tool_calls: None,
                        input_tokens,
                        output_tokens,
                        finish_reason: None,
                    };
                    let _ = hooks.run_post_llm_call(&hook_response).await;
                }
            }
            Err(e) => {
                warn!(error = %e, "turn stream error");
                return Err(AiError::Api(e.to_string()));
            }
            _ => {}
        }
    }

    Ok(TurnOutput {
        assistant_text,
        tool_calls,
        input_tokens,
        output_tokens,
    })
}