934 lines
34 KiB
Rust
934 lines
34 KiB
Rust
use std::future::Future;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
|
|
use futures::StreamExt;
|
|
use rig::agent::AgentBuilder;
|
|
use rig::client::CompletionClient;
|
|
use rig::streaming::StreamingPrompt;
|
|
use rig::tool::ToolDyn;
|
|
use tokio::sync::mpsc;
|
|
use tokio_util::sync::CancellationToken;
|
|
use tracing::{info, warn};
|
|
|
|
use super::RigStreamChunk;
|
|
use super::config::AgentConfig;
|
|
use super::error_classifier::{
|
|
classify_error, retry_policy_for, should_switch_to_fallback,
|
|
};
|
|
use super::events::{AgentEvent, EventSink};
|
|
use super::helpers::{build_input_string, estimate_tokens};
|
|
use super::hooks::{
|
|
HookChain, HookLlmResponse, HookMessage, ToolCallOutcome,
|
|
ToolGuardrailDecision,
|
|
};
|
|
use super::iteration_budget::IterationBudget;
|
|
use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord};
|
|
use crate::client::AiClient;
|
|
use crate::error::{AiError, AiResult};
|
|
|
|
/// How tool calls from a single assistant turn are executed.
///
/// The default mode is [`ToolExecutionMode::Parallel`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ToolExecutionMode {
    /// Execute tool calls one at a time.
    Sequential,
    /// Execute tool calls concurrently (after sequential preflight).
    #[default]
    Parallel,
}
|
|
|
|
/// Boxed, sendable future that resolves to a batch of messages.
///
/// Shared by [`SteeringFn`] and [`FollowUpFn`], which have the same shape.
type MessageBatchFuture = Pin<Box<dyn Future<Output = Vec<String>> + Send>>;

/// Callback type for steering messages (injected mid-run).
///
/// Each call returns a future yielding the messages to inject; an empty
/// vector means there is nothing to inject.
pub type SteeringFn = Arc<dyn Fn() -> MessageBatchFuture + Send + Sync>;

/// Callback type for follow-up messages (injected after agent would stop).
///
/// Each call returns a future yielding the messages to inject; an empty
/// vector means there is nothing to inject.
pub type FollowUpFn = Arc<dyn Fn() -> MessageBatchFuture + Send + Sync>;
|
|
|
|
/// Callback to decide whether the agent should stop after a turn.
|
|
pub type ShouldStopFn = Arc<dyn Fn(&TurnContext) -> bool + Send + Sync>;
|
|
|
|
/// Callback to prepare/modify state before the next turn.
|
|
pub type PrepareNextTurnFn = Arc<
|
|
dyn Fn(
|
|
&TurnContext,
|
|
) -> Pin<Box<dyn Future<Output = Option<TurnUpdate>> + Send>>
|
|
+ Send
|
|
+ Sync,
|
|
>;
|
|
|
|
/// Context passed to `should_stop` and `prepare_next_turn` callbacks.
///
/// Describes the turn that has just completed, plus running totals.
#[derive(Debug, Clone)]
pub struct TurnContext {
    /// Index of the turn this context describes.
    pub turn_index: usize,
    /// Assistant text produced during the turn (may be empty).
    pub assistant_text: String,
    /// Number of tool calls recorded during the turn.
    pub tool_call_count: usize,
    /// Input tokens accumulated over the run so far.
    pub total_input_tokens: u64,
    /// Output tokens accumulated over the run so far.
    pub total_output_tokens: u64,
    /// Name of the model configured when the context was built.
    pub model_name: String,
}
|
|
|
|
/// Replacement state for the next turn (returned by `prepare_next_turn`).
///
/// Every field is optional; `None` means "no replacement for this setting".
/// `TurnUpdate::default()` is therefore an update that changes nothing, which
/// makes partial updates easy to write:
/// `TurnUpdate { model: Some(name), ..Default::default() }`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TurnUpdate {
    /// Model name to use from the next turn on.
    pub model: Option<String>,
    /// Sampling temperature to use from the next turn on.
    pub temperature: Option<f64>,
    /// Completion-token cap to use from the next turn on.
    pub max_completion_tokens: Option<u64>,
}
|
|
|
|
/// Extended agent loop configuration, adding steering/follow-up/lifecycle
|
|
/// hooks on top of the base `AgentConfig`.
|
|
pub struct AgentLoopConfig {
|
|
pub config: AgentConfig,
|
|
pub tool_execution_mode: ToolExecutionMode,
|
|
pub get_steering_messages: Option<SteeringFn>,
|
|
pub get_follow_up_messages: Option<FollowUpFn>,
|
|
pub should_stop_after_turn: Option<ShouldStopFn>,
|
|
pub prepare_next_turn: Option<PrepareNextTurnFn>,
|
|
pub event_sink: Option<EventSink>,
|
|
}
|
|
|
|
impl AgentLoopConfig {
|
|
pub fn new(config: AgentConfig) -> Self {
|
|
Self {
|
|
config,
|
|
tool_execution_mode: ToolExecutionMode::default(),
|
|
get_steering_messages: None,
|
|
get_follow_up_messages: None,
|
|
should_stop_after_turn: None,
|
|
prepare_next_turn: None,
|
|
event_sink: None,
|
|
}
|
|
}
|
|
|
|
pub fn with_tool_execution_mode(mut self, mode: ToolExecutionMode) -> Self {
|
|
self.tool_execution_mode = mode;
|
|
self
|
|
}
|
|
|
|
pub fn with_steering_messages(mut self, f: SteeringFn) -> Self {
|
|
self.get_steering_messages = Some(f);
|
|
self
|
|
}
|
|
|
|
pub fn with_follow_up_messages(mut self, f: FollowUpFn) -> Self {
|
|
self.get_follow_up_messages = Some(f);
|
|
self
|
|
}
|
|
|
|
pub fn with_should_stop(mut self, f: ShouldStopFn) -> Self {
|
|
self.should_stop_after_turn = Some(f);
|
|
self
|
|
}
|
|
|
|
pub fn with_prepare_next_turn(mut self, f: PrepareNextTurnFn) -> Self {
|
|
self.prepare_next_turn = Some(f);
|
|
self
|
|
}
|
|
|
|
pub fn with_event_sink(mut self, sink: EventSink) -> Self {
|
|
self.event_sink = Some(sink);
|
|
self
|
|
}
|
|
}
|
|
|
|
/// Enhanced agent with loop controls (steering, follow-up, model switching).
|
|
pub struct EnhancedAgent {
|
|
pub client: AiClient,
|
|
pub loop_config: AgentLoopConfig,
|
|
pub hooks: HookChain,
|
|
}
|
|
|
|
impl EnhancedAgent {
|
|
pub fn new(
|
|
client: AiClient,
|
|
loop_config: AgentLoopConfig,
|
|
) -> AiResult<Self> {
|
|
loop_config.config.validate()?;
|
|
Ok(Self {
|
|
client,
|
|
loop_config,
|
|
hooks: HookChain::empty(),
|
|
})
|
|
}
|
|
|
|
pub fn with_hooks(mut self, hooks: HookChain) -> Self {
|
|
self.hooks = hooks;
|
|
self
|
|
}
|
|
|
|
pub fn config(&self) -> &AgentConfig {
|
|
&self.loop_config.config
|
|
}
|
|
|
|
/// Run the enhanced agent loop, returning a chunk receiver and a join handle.
|
|
#[allow(clippy::too_many_lines)]
|
|
pub fn run(
|
|
&self,
|
|
request: AgentRequest,
|
|
tools: Vec<Box<dyn ToolDyn>>,
|
|
) -> (
|
|
mpsc::Receiver<RigStreamChunk>,
|
|
tokio::task::JoinHandle<AiResult<AgentResult>>,
|
|
) {
|
|
let (tx, rx) = mpsc::channel::<RigStreamChunk>(256);
|
|
|
|
let config = self.loop_config.config.clone();
|
|
let tool_execution_mode = self.loop_config.tool_execution_mode;
|
|
let steering_fn = self.loop_config.get_steering_messages.clone();
|
|
let follow_up_fn = self.loop_config.get_follow_up_messages.clone();
|
|
let should_stop = self.loop_config.should_stop_after_turn.clone();
|
|
let prepare_next = self.loop_config.prepare_next_turn.clone();
|
|
let event_sink = self.loop_config.event_sink.clone();
|
|
let client = self.client.llm_client().clone();
|
|
let hooks = self.hooks.clone();
|
|
|
|
let filtered_tools: Vec<Box<dyn ToolDyn>> = tools
|
|
.into_iter()
|
|
.filter(|tool| config.is_tool_exposed(&tool.name()))
|
|
.collect();
|
|
|
|
let handle = tokio::spawn(async move {
|
|
run_enhanced_loop(
|
|
client,
|
|
config,
|
|
request,
|
|
filtered_tools,
|
|
tool_execution_mode,
|
|
steering_fn,
|
|
follow_up_fn,
|
|
should_stop,
|
|
prepare_next,
|
|
event_sink,
|
|
hooks,
|
|
tx,
|
|
)
|
|
.await
|
|
});
|
|
|
|
(rx, handle)
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
|
|
async fn run_enhanced_loop(
|
|
client: rig::providers::openai::Client,
|
|
mut config: AgentConfig,
|
|
request: AgentRequest,
|
|
tools: Vec<Box<dyn ToolDyn>>,
|
|
_tool_execution_mode: ToolExecutionMode,
|
|
steering_fn: Option<SteeringFn>,
|
|
follow_up_fn: Option<FollowUpFn>,
|
|
should_stop: Option<ShouldStopFn>,
|
|
prepare_next: Option<PrepareNextTurnFn>,
|
|
event_sink: Option<EventSink>,
|
|
hooks: HookChain,
|
|
tx: mpsc::Sender<RigStreamChunk>,
|
|
) -> AiResult<AgentResult> {
|
|
let cancellation = request.cancellation_token.clone();
|
|
let timeout = request.timeout;
|
|
let mut budget = IterationBudget::new(config.iteration_budget);
|
|
let mut all_steps: Vec<AgentStep> = Vec::new();
|
|
let mut total_input_tokens: u64 = 0;
|
|
let mut total_output_tokens: u64 = 0;
|
|
let mut turn_index: usize = 0;
|
|
|
|
// Session start hook
|
|
if let Some(ctx) = &request.run_context {
|
|
let _ = hooks.run_session_start(ctx).await;
|
|
}
|
|
|
|
// Emit agent start event
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::AgentStart);
|
|
}
|
|
|
|
// Build the initial input
|
|
let input = build_input_string(&request);
|
|
let mut current_input = input.clone();
|
|
let estimated_input_tokens = estimate_tokens(¤t_input);
|
|
|
|
if let Some(limit) = config.max_total_tokens_per_run
|
|
&& estimated_input_tokens > limit as u64
|
|
{
|
|
return Err(AiError::TokenBudgetExceeded {
|
|
estimated: estimated_input_tokens,
|
|
limit,
|
|
});
|
|
}
|
|
|
|
// Outer loop: handles follow-up messages after agent would stop
|
|
loop {
|
|
// Inner loop: tool call turns + steering messages
|
|
let mut pending_steering: Vec<String> = if let Some(f) = &steering_fn {
|
|
f().await
|
|
} else {
|
|
Vec::new()
|
|
};
|
|
|
|
loop {
|
|
// Check cancellation
|
|
if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
|
|
let _ = tx
|
|
.send(RigStreamChunk::Failed {
|
|
error: "cancelled".to_string(),
|
|
})
|
|
.await;
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::ErrorClassified {
|
|
category: "cancelled".to_string(),
|
|
message: "cancelled by caller".to_string(),
|
|
will_retry: false,
|
|
retry_delay_ms: None,
|
|
});
|
|
}
|
|
return Err(AiError::Response(
|
|
"agent run cancelled".to_string(),
|
|
));
|
|
}
|
|
|
|
// Inject steering messages if any
|
|
if !pending_steering.is_empty() {
|
|
let count = pending_steering.len();
|
|
for msg in &pending_steering {
|
|
current_input.push_str(&format!("\nUser: {msg}\n"));
|
|
}
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::SteeringMessagesInjected { count });
|
|
}
|
|
pending_steering.clear();
|
|
}
|
|
|
|
// Emit turn start
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::TurnStart { turn_index });
|
|
}
|
|
let _ = tx
|
|
.send(RigStreamChunk::TextDelta {
|
|
index: 0,
|
|
content: String::new(), // placeholder for turn boundary detection
|
|
})
|
|
.await;
|
|
|
|
// Run one LLM turn with retry
|
|
let turn_result = run_single_turn(
|
|
&client,
|
|
&config,
|
|
¤t_input,
|
|
&tools,
|
|
&mut budget,
|
|
&cancellation,
|
|
timeout,
|
|
&hooks,
|
|
&event_sink,
|
|
&tx,
|
|
)
|
|
.await;
|
|
|
|
match turn_result {
|
|
Ok(turn_output) => {
|
|
total_input_tokens += turn_output.input_tokens;
|
|
total_output_tokens += turn_output.output_tokens;
|
|
|
|
// Collect step
|
|
let tool_call_count = turn_output.tool_calls.len();
|
|
if !turn_output.tool_calls.is_empty()
|
|
|| !turn_output.assistant_text.is_empty()
|
|
{
|
|
all_steps.push(AgentStep {
|
|
index: all_steps.len(),
|
|
assistant: (!turn_output.assistant_text.is_empty())
|
|
.then_some(turn_output.assistant_text.clone()),
|
|
reasoning_content: None,
|
|
tool_calls: turn_output.tool_calls,
|
|
reflection: None,
|
|
});
|
|
}
|
|
|
|
// Emit turn end
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::TurnEnd {
|
|
turn_index,
|
|
assistant_text: Some(
|
|
turn_output.assistant_text.clone(),
|
|
),
|
|
tool_call_count,
|
|
});
|
|
}
|
|
|
|
// Check should_stop
|
|
let turn_ctx = TurnContext {
|
|
turn_index,
|
|
assistant_text: turn_output.assistant_text.clone(),
|
|
tool_call_count,
|
|
total_input_tokens,
|
|
total_output_tokens,
|
|
model_name: config.model.clone(),
|
|
};
|
|
|
|
if let Some(stop_fn) = &should_stop {
|
|
if stop_fn(&turn_ctx) {
|
|
info!(
|
|
turn_index,
|
|
"agent stopped by should_stop callback"
|
|
);
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Prepare next turn (may switch model)
|
|
if let Some(prep_fn) = &prepare_next {
|
|
if let Some(update) = prep_fn(&turn_ctx).await {
|
|
if let Some(new_model) = update.model {
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::ModelSwitched {
|
|
from_model: config.model.clone(),
|
|
to_model: new_model.clone(),
|
|
reason: "prepare_next_turn".to_string(),
|
|
});
|
|
}
|
|
config.model = new_model;
|
|
}
|
|
if let Some(temp) = update.temperature {
|
|
config.temperature = Some(temp);
|
|
}
|
|
if let Some(max_tok) = update.max_completion_tokens
|
|
{
|
|
config.max_completion_tokens = Some(max_tok);
|
|
}
|
|
}
|
|
}
|
|
|
|
turn_index += 1;
|
|
|
|
// If no tool calls, this turn is done
|
|
if tool_call_count == 0 {
|
|
break;
|
|
}
|
|
|
|
// Otherwise, continue with tool results as new input
|
|
current_input = turn_output.assistant_text.clone();
|
|
}
|
|
Err(e) => {
|
|
// Error classification and retry with fallback
|
|
let category = classify_error(&e, None);
|
|
let policy = retry_policy_for(
|
|
&category,
|
|
config.retry_max_attempts,
|
|
config.retry_base_delay_ms,
|
|
);
|
|
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::ErrorClassified {
|
|
category: format!("{category:?}"),
|
|
message: e.to_string(),
|
|
will_retry: policy.switch_to_fallback
|
|
|| policy.max_attempts > 0,
|
|
retry_delay_ms: Some(
|
|
policy.base_delay.as_millis() as u64
|
|
),
|
|
});
|
|
}
|
|
|
|
if should_switch_to_fallback(&category) {
|
|
if let Some(fallback_model) = &config.fallback_model {
|
|
info!(
|
|
from_model = %config.model,
|
|
to_model = %fallback_model,
|
|
"switching to fallback model due to error"
|
|
);
|
|
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::ModelSwitched {
|
|
from_model: config.model.clone(),
|
|
to_model: fallback_model.clone(),
|
|
reason: format!("fallback: {category:?}"),
|
|
});
|
|
}
|
|
|
|
config.model = fallback_model.clone();
|
|
|
|
// Retry with the fallback model
|
|
let retry_result = run_single_turn(
|
|
&client,
|
|
&config,
|
|
¤t_input,
|
|
&tools,
|
|
&mut budget,
|
|
&cancellation,
|
|
timeout,
|
|
&hooks,
|
|
&event_sink,
|
|
&tx,
|
|
)
|
|
.await;
|
|
|
|
match retry_result {
|
|
Ok(turn_output) => {
|
|
total_input_tokens +=
|
|
turn_output.input_tokens;
|
|
total_output_tokens +=
|
|
turn_output.output_tokens;
|
|
let tc_count = turn_output.tool_calls.len();
|
|
let has_tools = tc_count > 0;
|
|
let has_text =
|
|
!turn_output.assistant_text.is_empty();
|
|
let assistant = turn_output.assistant_text;
|
|
if has_tools || has_text {
|
|
all_steps.push(AgentStep {
|
|
index: all_steps.len(),
|
|
assistant: has_text
|
|
.then_some(assistant.clone()),
|
|
reasoning_content: None,
|
|
tool_calls: turn_output.tool_calls,
|
|
reflection: None,
|
|
});
|
|
}
|
|
turn_index += 1;
|
|
if !has_tools {
|
|
break;
|
|
}
|
|
current_input = assistant;
|
|
continue;
|
|
}
|
|
Err(retry_err) => {
|
|
let _ = tx
|
|
.send(RigStreamChunk::Failed {
|
|
error: retry_err.to_string(),
|
|
})
|
|
.await;
|
|
if let Some(ctx) = &request.run_context {
|
|
let _ = hooks
|
|
.run_session_end(ctx, false)
|
|
.await;
|
|
}
|
|
return Err(retry_err);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Non-retryable or no fallback
|
|
let _ = tx
|
|
.send(RigStreamChunk::Failed {
|
|
error: e.to_string(),
|
|
})
|
|
.await;
|
|
if let Some(ctx) = &request.run_context {
|
|
let _ = hooks.run_session_end(ctx, false).await;
|
|
}
|
|
return Err(e);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check for follow-up messages
|
|
let follow_ups: Vec<String> = if let Some(f) = &follow_up_fn {
|
|
f().await
|
|
} else {
|
|
Vec::new()
|
|
};
|
|
|
|
if follow_ups.is_empty() {
|
|
break;
|
|
}
|
|
|
|
// Inject follow-up messages and continue the outer loop
|
|
let count = follow_ups.len();
|
|
for msg in &follow_ups {
|
|
current_input.push_str(&format!("\nUser: {msg}\n"));
|
|
}
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::FollowUpMessagesInjected { count });
|
|
}
|
|
}
|
|
|
|
// Build final output
|
|
let output = all_steps
|
|
.last()
|
|
.and_then(|s| s.assistant.clone())
|
|
.unwrap_or_default();
|
|
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::AgentEnd {
|
|
messages: Vec::new(),
|
|
total_input_tokens,
|
|
total_output_tokens,
|
|
});
|
|
}
|
|
|
|
let _ = tx
|
|
.send(RigStreamChunk::Final {
|
|
content: output.clone(),
|
|
input_tokens: total_input_tokens,
|
|
output_tokens: total_output_tokens,
|
|
})
|
|
.await;
|
|
|
|
if let Some(ctx) = &request.run_context {
|
|
let _ = hooks.run_session_end(ctx, true).await;
|
|
}
|
|
|
|
info!(
|
|
turns = turn_index,
|
|
steps = all_steps.len(),
|
|
total_input_tokens,
|
|
total_output_tokens,
|
|
"enhanced agent loop completed"
|
|
);
|
|
|
|
Ok(AgentResult {
|
|
output,
|
|
steps: all_steps,
|
|
expert_outputs: Vec::new(),
|
|
input_tokens: total_input_tokens as i64,
|
|
output_tokens: total_output_tokens as i64,
|
|
})
|
|
}
|
|
|
|
/// Output from a single LLM turn (one assistant response + its tool calls).
|
|
struct TurnOutput {
|
|
assistant_text: String,
|
|
tool_calls: Vec<ToolCallRecord>,
|
|
input_tokens: u64,
|
|
output_tokens: u64,
|
|
}
|
|
|
|
/// Run a single LLM turn with streaming, handling the stream parsing and
|
|
/// tool call collection.
|
|
#[allow(clippy::too_many_arguments)]
|
|
async fn run_single_turn(
|
|
client: &rig::providers::openai::Client,
|
|
config: &AgentConfig,
|
|
input: &str,
|
|
_tools: &[Box<dyn ToolDyn>],
|
|
budget: &mut IterationBudget,
|
|
cancellation: &Option<CancellationToken>,
|
|
timeout: Option<std::time::Duration>,
|
|
hooks: &HookChain,
|
|
event_sink: &Option<EventSink>,
|
|
tx: &mpsc::Sender<RigStreamChunk>,
|
|
) -> AiResult<TurnOutput> {
|
|
if !budget.consume() {
|
|
return Err(AiError::Response(
|
|
"iteration budget exhausted".to_string(),
|
|
));
|
|
}
|
|
|
|
let model = client.completion_model(&config.model);
|
|
let mut agent_builder = AgentBuilder::new(model)
|
|
.preamble(&config.system_prompt)
|
|
.default_max_turns(1); // Single turn, we manage the loop
|
|
|
|
// Note: we can't easily pass tools here for single-turn since
|
|
// rig's multi_turn handles tool execution internally.
|
|
// For the enhanced loop, we rely on rig's built-in tool execution
|
|
// within a single turn. The parallel/sequential mode is controlled
|
|
// by the event-level hooks.
|
|
|
|
if let Some(temp) = config.temperature {
|
|
agent_builder = agent_builder.temperature(temp);
|
|
}
|
|
if let Some(mt) = config.max_completion_tokens {
|
|
agent_builder = agent_builder.max_tokens(mt);
|
|
}
|
|
|
|
let agent = agent_builder.build();
|
|
|
|
// Pre-LLM hook
|
|
if !hooks.is_empty() {
|
|
let hook_messages = vec![HookMessage {
|
|
role: "user".to_string(),
|
|
content: Some(input.to_string()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
}];
|
|
let _ = hooks.run_pre_llm_call(&hook_messages, &[]).await;
|
|
}
|
|
|
|
let stream_future = agent
|
|
.stream_prompt(input)
|
|
.with_history(Vec::<rig::completion::Message>::new())
|
|
.multi_turn(config.max_iterations);
|
|
|
|
let stream = if let Some(dur) = timeout {
|
|
match tokio::time::timeout(dur, stream_future).await {
|
|
Ok(stream) => stream,
|
|
Err(_) => {
|
|
return Err(AiError::Timeout {
|
|
seconds: dur.as_secs(),
|
|
});
|
|
}
|
|
}
|
|
} else {
|
|
stream_future.await
|
|
};
|
|
|
|
tokio::pin!(stream);
|
|
|
|
let mut assistant_text = String::new();
|
|
let mut tool_calls: Vec<ToolCallRecord> = Vec::new();
|
|
let mut delta_index = 0usize;
|
|
let mut _accumulated_output_chars: usize = 0;
|
|
let mut input_tokens: u64 = 0;
|
|
let mut output_tokens: u64 = 0;
|
|
|
|
while let Some(item) = stream.next().await {
|
|
if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
|
|
return Err(AiError::Response("cancelled".to_string()));
|
|
}
|
|
|
|
match item {
|
|
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
|
rig::streaming::StreamedAssistantContent::Text(text),
|
|
)) => {
|
|
_accumulated_output_chars += text.text.chars().count();
|
|
assistant_text.push_str(&text.text);
|
|
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::MessageTextDelta {
|
|
index: delta_index,
|
|
delta: text.text.clone(),
|
|
});
|
|
}
|
|
|
|
let _ = tx
|
|
.send(RigStreamChunk::TextDelta {
|
|
index: delta_index,
|
|
content: text.text.clone(),
|
|
})
|
|
.await;
|
|
delta_index += 1;
|
|
}
|
|
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
|
rig::streaming::StreamedAssistantContent::Reasoning(reasoning),
|
|
)) => {
|
|
for part in &reasoning.content {
|
|
if let rig::completion::message::ReasoningContent::Text {
|
|
text,
|
|
..
|
|
} = part
|
|
{
|
|
_accumulated_output_chars += text.chars().count();
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::MessageThinkingDelta {
|
|
index: delta_index,
|
|
delta: text.clone(),
|
|
});
|
|
}
|
|
let _ = tx
|
|
.send(RigStreamChunk::Thinking {
|
|
index: delta_index,
|
|
content: text.clone(),
|
|
})
|
|
.await;
|
|
delta_index += 1;
|
|
}
|
|
}
|
|
}
|
|
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
|
rig::streaming::StreamedAssistantContent::ReasoningDelta {
|
|
reasoning,
|
|
..
|
|
},
|
|
)) => {
|
|
_accumulated_output_chars += reasoning.chars().count();
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::MessageThinkingDelta {
|
|
index: delta_index,
|
|
delta: reasoning.clone(),
|
|
});
|
|
}
|
|
let _ = tx
|
|
.send(RigStreamChunk::Thinking {
|
|
index: delta_index,
|
|
content: reasoning.clone(),
|
|
})
|
|
.await;
|
|
delta_index += 1;
|
|
}
|
|
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
|
rig::streaming::StreamedAssistantContent::ToolCall {
|
|
tool_call,
|
|
..
|
|
},
|
|
)) => {
|
|
let args = match &tool_call.function.arguments {
|
|
serde_json::Value::String(s) => s.clone(),
|
|
v => serde_json::to_string(v).unwrap_or_default(),
|
|
};
|
|
_accumulated_output_chars += args.chars().count();
|
|
|
|
let tool_name = tool_call.function.name.clone();
|
|
let tool_args: serde_json::Value =
|
|
serde_json::from_str(&args).unwrap_or_default();
|
|
|
|
// Pre-tool-call guardrail hook
|
|
if let Ok(Some(decision)) =
|
|
hooks.run_pre_tool_call(&tool_name, &tool_args).await
|
|
{
|
|
match decision {
|
|
ToolGuardrailDecision::Allow => {}
|
|
ToolGuardrailDecision::Block { reason } => {
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::ToolExecutionEnd {
|
|
tool_call_id: tool_call.id.clone(),
|
|
tool_name: tool_name.clone(),
|
|
output: None,
|
|
error: Some(reason.clone()),
|
|
elapsed_ms: 0,
|
|
});
|
|
}
|
|
let _ = tx
|
|
.send(RigStreamChunk::ToolCallFinished {
|
|
tool_call_id: tool_call.id.clone(),
|
|
tool_name: tool_name.clone(),
|
|
output: format!("blocked: {reason}"),
|
|
error: Some(reason),
|
|
})
|
|
.await;
|
|
tool_calls.push(ToolCallRecord {
|
|
id: tool_call.id.clone(),
|
|
name: tool_name,
|
|
arguments: tool_args,
|
|
output: None,
|
|
error: Some("blocked by guardrail".to_string()),
|
|
elapsed_ms: None,
|
|
});
|
|
continue;
|
|
}
|
|
ToolGuardrailDecision::RequireApproval { message } => {
|
|
tool_calls.push(ToolCallRecord {
|
|
id: tool_call.id.clone(),
|
|
name: tool_name.clone(),
|
|
arguments: tool_args,
|
|
output: None,
|
|
error: Some(format!(
|
|
"requires approval: {message}"
|
|
)),
|
|
elapsed_ms: None,
|
|
});
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::ToolExecutionStart {
|
|
tool_call_id: tool_call.id.clone(),
|
|
tool_name: tool_name.clone(),
|
|
arguments: tool_args.clone(),
|
|
});
|
|
}
|
|
|
|
let _ = tx
|
|
.send(RigStreamChunk::ToolCallStarted {
|
|
tool_call_id: tool_call.id.clone(),
|
|
tool_name: tool_name.clone(),
|
|
arguments: args.clone(),
|
|
})
|
|
.await;
|
|
tool_calls.push(ToolCallRecord {
|
|
id: tool_call.id.clone(),
|
|
name: tool_name,
|
|
arguments: tool_args,
|
|
output: None,
|
|
error: None,
|
|
elapsed_ms: None,
|
|
});
|
|
}
|
|
Ok(rig::agent::MultiTurnStreamItem::StreamUserItem(
|
|
rig::streaming::StreamedUserContent::ToolResult {
|
|
tool_result,
|
|
..
|
|
},
|
|
)) => {
|
|
let content = super::helpers::tool_result_content_to_string(
|
|
&tool_result.content,
|
|
);
|
|
_accumulated_output_chars += content.chars().count();
|
|
|
|
let tool_name = tool_calls
|
|
.last()
|
|
.map(|tc| tc.name.clone())
|
|
.unwrap_or_default();
|
|
|
|
if let Some(last) = tool_calls.last_mut()
|
|
&& last.id == tool_result.id
|
|
{
|
|
last.output = Some(
|
|
serde_json::from_str(&content).unwrap_or_default(),
|
|
);
|
|
}
|
|
|
|
if let Some(sink) = &event_sink {
|
|
sink.emit(AgentEvent::ToolExecutionEnd {
|
|
tool_call_id: tool_result.id.clone(),
|
|
tool_name: tool_name.clone(),
|
|
output: Some(serde_json::Value::String(
|
|
content.clone(),
|
|
)),
|
|
error: None,
|
|
elapsed_ms: 0,
|
|
});
|
|
}
|
|
|
|
let _ = tx
|
|
.send(RigStreamChunk::ToolCallFinished {
|
|
tool_call_id: tool_result.id.clone(),
|
|
tool_name,
|
|
output: content.clone(),
|
|
error: None,
|
|
})
|
|
.await;
|
|
|
|
if !hooks.is_empty() {
|
|
let outcome = ToolCallOutcome {
|
|
name: tool_result.id.clone(),
|
|
arguments: serde_json::Value::Null,
|
|
output: Some(serde_json::Value::String(content)),
|
|
error: None,
|
|
elapsed_ms: 0,
|
|
};
|
|
let _ = hooks.run_post_tool_call(&outcome).await;
|
|
}
|
|
}
|
|
Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
|
|
let usage = resp.usage();
|
|
input_tokens = usage.input_tokens;
|
|
output_tokens = usage.output_tokens;
|
|
|
|
if !hooks.is_empty() {
|
|
let hook_response = HookLlmResponse {
|
|
content: Some(assistant_text.clone()),
|
|
tool_calls: None,
|
|
input_tokens,
|
|
output_tokens,
|
|
finish_reason: None,
|
|
};
|
|
let _ = hooks.run_post_llm_call(&hook_response).await;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
warn!(error = %e, "turn stream error");
|
|
return Err(AiError::Api(format!("{e}")));
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
Ok(TurnOutput {
|
|
assistant_text,
|
|
tool_calls,
|
|
input_tokens,
|
|
output_tokens,
|
|
})
|
|
}
|