544 lines
20 KiB
Rust
544 lines
20 KiB
Rust
use futures::StreamExt;
|
|
use rig::agent::AgentBuilder;
|
|
use rig::client::CompletionClient;
|
|
use rig::streaming::StreamingPrompt;
|
|
use rig::tool::ToolDyn;
|
|
use tokio::sync::mpsc;
|
|
use tokio_util::sync::CancellationToken;
|
|
use tracing::{info, warn};
|
|
|
|
use super::config::AgentConfig;
|
|
use super::helpers::{build_input_string, check_token_budget, estimate_tokens};
|
|
use super::hooks::{HookChain, HookLlmResponse, HookMessage, HookToolDef, ToolCallOutcome, ToolGuardrailDecision};
|
|
use super::persistence::ActiveAgentRun;
|
|
use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord};
|
|
use super::subagent::run_experts;
|
|
use super::RigStreamChunk;
|
|
use crate::client::AiClient;
|
|
use crate::error::{AiError, AiResult};
|
|
|
|
/// An LLM agent: the client used to reach the model, the run configuration,
/// and the hook chain consulted while a run executes.
pub struct RigAgent {
    /// Client used to build the completion model and to run experts (subagents).
    pub client: AiClient,
    /// Run configuration. `RigAgent::new` checks it with `AgentConfig::validate`.
    /// The field is public, so later edits are not re-validated.
    pub config: AgentConfig,
    /// Lifecycle / guardrail hooks. It starts empty in `RigAgent::new` and is
    /// replaced via `RigAgent::with_hooks`.
    pub hooks: HookChain,
}
|
|
|
|
impl RigAgent {
|
|
pub fn new(client: AiClient, config: AgentConfig) -> AiResult<Self> {
|
|
config.validate()?;
|
|
Ok(Self {
|
|
client,
|
|
config,
|
|
hooks: HookChain::empty(),
|
|
})
|
|
}
|
|
|
|
pub fn with_hooks(mut self, hooks: HookChain) -> Self {
|
|
self.hooks = hooks;
|
|
self
|
|
}
|
|
|
|
pub fn config(&self) -> &AgentConfig {
|
|
&self.config
|
|
}
|
|
|
|
pub async fn chat(
|
|
&self,
|
|
request: AgentRequest,
|
|
tools: Vec<Box<dyn ToolDyn>>,
|
|
) -> AiResult<String> {
|
|
let (mut rx, handle) = self.run(request, tools);
|
|
tokio::spawn(async move {
|
|
while rx.recv().await.is_some() {}
|
|
});
|
|
let result = handle.await.map_err(|_| {
|
|
AiError::Response("agent task panicked".to_string())
|
|
})?;
|
|
result.map(|r| r.output)
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines)]
|
|
pub fn run(
|
|
&self,
|
|
request: AgentRequest,
|
|
tools: Vec<Box<dyn ToolDyn>>,
|
|
) -> (
|
|
tokio::sync::mpsc::Receiver<RigStreamChunk>,
|
|
tokio::task::JoinHandle<AiResult<AgentResult>>,
|
|
) {
|
|
let (tx, rx) = mpsc::channel::<RigStreamChunk>(256);
|
|
|
|
let model_name = self.config.model.clone();
|
|
let max_iterations = self.config.max_iterations;
|
|
let client = self.client.llm_client().clone();
|
|
let ai_client = self.client.clone();
|
|
let agent_config = self.config.clone();
|
|
let system_prompt = self.config.system_prompt.clone();
|
|
let temperature = self.config.temperature;
|
|
let max_completion_tokens = self.config.max_completion_tokens;
|
|
let max_total_tokens = self.config.max_total_tokens_per_run;
|
|
let cancellation = request.cancellation_token.clone();
|
|
let timeout = request.timeout;
|
|
let hooks = self.hooks.clone();
|
|
|
|
let filtered_tools: Vec<Box<dyn ToolDyn>> = tools
|
|
.into_iter()
|
|
.filter(|tool| self.config.is_tool_exposed(&tool.name()))
|
|
.collect();
|
|
|
|
let handle = tokio::spawn(async move {
|
|
execute_agent_run(
|
|
client,
|
|
model_name,
|
|
system_prompt,
|
|
request,
|
|
filtered_tools,
|
|
max_iterations,
|
|
ai_client,
|
|
agent_config,
|
|
temperature,
|
|
max_completion_tokens,
|
|
max_total_tokens,
|
|
cancellation,
|
|
timeout,
|
|
hooks,
|
|
tx,
|
|
)
|
|
.await
|
|
});
|
|
|
|
(rx, handle)
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
|
|
async fn execute_agent_run(
|
|
client: rig::providers::openai::Client,
|
|
model_name: String,
|
|
system_prompt: String,
|
|
request: AgentRequest,
|
|
tools: Vec<Box<dyn ToolDyn>>,
|
|
max_iterations: usize,
|
|
ai_client: AiClient,
|
|
agent_config: AgentConfig,
|
|
temperature: Option<f64>,
|
|
max_completion_tokens: Option<u64>,
|
|
max_total_tokens: Option<i64>,
|
|
cancellation: Option<CancellationToken>,
|
|
timeout: Option<std::time::Duration>,
|
|
hooks: HookChain,
|
|
tx: mpsc::Sender<RigStreamChunk>,
|
|
) -> AiResult<AgentResult> {
|
|
if let Some(ref ctx) = request.run_context {
|
|
let _ = hooks.run_session_start(ctx).await;
|
|
}
|
|
|
|
let model = client.completion_model(&model_name);
|
|
let mut agent_builder = AgentBuilder::new(model)
|
|
.preamble(&system_prompt)
|
|
.tools(tools)
|
|
.default_max_turns(max_iterations);
|
|
|
|
if let Some(temp) = temperature {
|
|
agent_builder = agent_builder.temperature(temp);
|
|
}
|
|
if let Some(mt) = max_completion_tokens {
|
|
agent_builder = agent_builder.max_tokens(mt);
|
|
}
|
|
|
|
let agent = agent_builder.build();
|
|
let mut input = build_input_string(&request);
|
|
|
|
// ---- SubAgent execution ----
|
|
let expert_outputs = if !request.experts.is_empty() {
|
|
let run = ActiveAgentRun {
|
|
conversation_id: request.run_context.as_ref().and_then(|c| c.conversation_id),
|
|
message_id: None,
|
|
invocation_id: request.run_context.as_ref().and_then(|c| c.invocation_id),
|
|
session_id: request.run_context.as_ref().and_then(|c| c.session_id),
|
|
user_id: request.run_context.as_ref().and_then(|c| c.user_id),
|
|
started_at: std::time::Instant::now(),
|
|
current_step: 0,
|
|
};
|
|
let realtime = request.run_context.as_ref().and_then(|c| c.realtime.as_ref());
|
|
|
|
// Notify frontend that subagents are starting.
|
|
for expert in &request.experts {
|
|
let _ = tx
|
|
.send(RigStreamChunk::SubagentStarted {
|
|
subagent_id: expert.id.clone(),
|
|
role: expert.role.clone(),
|
|
task: expert.task.clone(),
|
|
})
|
|
.await;
|
|
}
|
|
|
|
match run_experts(&ai_client, &agent_config, &request.experts, realtime, &run).await {
|
|
Ok(outputs) => {
|
|
for out in &outputs {
|
|
let _ = tx
|
|
.send(RigStreamChunk::SubagentCompleted {
|
|
subagent_id: out.id.clone(),
|
|
role: out.role.clone(),
|
|
task: out.task.clone(),
|
|
output: out.output.clone(),
|
|
})
|
|
.await;
|
|
input.push_str(&format!(
|
|
"\n--- Subagent: {} (role: {}) ---\nTask: {}\nResult: {}\n",
|
|
out.id, out.role, out.task, out.output
|
|
));
|
|
}
|
|
outputs
|
|
}
|
|
Err(e) => {
|
|
warn!(error = %e, "subagent execution failed, continuing without expert inputs");
|
|
let _ = tx
|
|
.send(RigStreamChunk::SubagentFailed {
|
|
error: e.to_string(),
|
|
})
|
|
.await;
|
|
Vec::new()
|
|
}
|
|
}
|
|
} else {
|
|
Vec::new()
|
|
};
|
|
|
|
let estimated_input_tokens = estimate_tokens(&input);
|
|
|
|
if let Some(limit) = max_total_tokens
|
|
&& estimated_input_tokens > limit as u64
|
|
{
|
|
return Err(AiError::TokenBudgetExceeded {
|
|
estimated: estimated_input_tokens,
|
|
limit,
|
|
});
|
|
}
|
|
|
|
if !hooks.is_empty() {
|
|
let hook_messages: Vec<HookMessage> = request
|
|
.messages
|
|
.iter()
|
|
.map(|m| HookMessage {
|
|
role: match m {
|
|
super::request::AgentMessage::User(_) => "user".to_string(),
|
|
super::request::AgentMessage::Assistant(_) => {
|
|
"assistant".to_string()
|
|
}
|
|
},
|
|
content: match m {
|
|
super::request::AgentMessage::User(c) => Some(c.clone()),
|
|
super::request::AgentMessage::Assistant(c) => {
|
|
Some(c.clone())
|
|
}
|
|
},
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
})
|
|
.collect();
|
|
let hook_tools: Vec<HookToolDef> = Vec::new();
|
|
let _ = hooks.run_pre_llm_call(&hook_messages, &hook_tools).await;
|
|
}
|
|
|
|
let stream_future = agent
|
|
.stream_prompt(&input)
|
|
.with_history(Vec::<rig::completion::Message>::new())
|
|
.multi_turn(max_iterations);
|
|
|
|
let stream = if let Some(dur) = timeout {
|
|
match tokio::time::timeout(dur, stream_future).await {
|
|
Ok(stream) => stream,
|
|
Err(_elapsed) => {
|
|
let _ = tx
|
|
.send(RigStreamChunk::Failed {
|
|
error: format!("agent timed out after {}s", dur.as_secs()),
|
|
})
|
|
.await;
|
|
return Err(AiError::Timeout {
|
|
seconds: dur.as_secs(),
|
|
});
|
|
}
|
|
}
|
|
} else {
|
|
stream_future.await
|
|
};
|
|
|
|
tokio::pin!(stream);
|
|
|
|
let mut steps = Vec::new();
|
|
let mut delta_index = 0usize;
|
|
let mut current_step_tool_calls: Vec<ToolCallRecord> = Vec::new();
|
|
let mut current_step_assistant = String::new();
|
|
let mut current_step_reasoning = String::new();
|
|
let mut accumulated_output_chars: usize = 0;
|
|
|
|
while let Some(item) = stream.next().await {
|
|
if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
|
|
let _ = tx
|
|
.send(RigStreamChunk::Failed {
|
|
error: "cancelled".to_string(),
|
|
})
|
|
.await;
|
|
return Err(AiError::Response("agent run cancelled".to_string()));
|
|
}
|
|
|
|
if let Some(limit) = max_total_tokens
|
|
&& check_token_budget(estimated_input_tokens, accumulated_output_chars, limit)
|
|
{
|
|
let _ = tx
|
|
.send(RigStreamChunk::Failed {
|
|
error: format!("token budget exceeded: limit {limit}"),
|
|
})
|
|
.await;
|
|
return Err(AiError::TokenBudgetExceeded {
|
|
estimated: estimated_input_tokens
|
|
+ (accumulated_output_chars as f64 / 2.5).ceil() as u64,
|
|
limit,
|
|
});
|
|
}
|
|
|
|
match item {
|
|
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
|
rig::streaming::StreamedAssistantContent::Text(text),
|
|
)) => {
|
|
accumulated_output_chars += text.text.chars().count();
|
|
current_step_assistant.push_str(&text.text);
|
|
let _ = tx
|
|
.send(RigStreamChunk::TextDelta {
|
|
index: delta_index,
|
|
content: text.text.clone(),
|
|
})
|
|
.await;
|
|
delta_index += 1;
|
|
}
|
|
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
|
rig::streaming::StreamedAssistantContent::Reasoning(reasoning),
|
|
)) => {
|
|
for part in &reasoning.content {
|
|
if let rig::completion::message::ReasoningContent::Text {
|
|
text, ..
|
|
} = part
|
|
{
|
|
accumulated_output_chars += text.chars().count();
|
|
current_step_reasoning.push_str(text);
|
|
let _ = tx
|
|
.send(RigStreamChunk::Thinking {
|
|
index: delta_index,
|
|
content: text.clone(),
|
|
})
|
|
.await;
|
|
delta_index += 1;
|
|
}
|
|
}
|
|
}
|
|
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
|
rig::streaming::StreamedAssistantContent::ReasoningDelta {
|
|
reasoning, ..
|
|
},
|
|
)) => {
|
|
accumulated_output_chars += reasoning.chars().count();
|
|
current_step_reasoning.push_str(&reasoning);
|
|
let _ = tx
|
|
.send(RigStreamChunk::Thinking {
|
|
index: delta_index,
|
|
content: reasoning.clone(),
|
|
})
|
|
.await;
|
|
delta_index += 1;
|
|
}
|
|
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
|
rig::streaming::StreamedAssistantContent::ToolCall {
|
|
tool_call,
|
|
internal_call_id: _,
|
|
},
|
|
)) => {
|
|
let args = match &tool_call.function.arguments {
|
|
serde_json::Value::String(s) => s.clone(),
|
|
v => serde_json::to_string(v).unwrap_or_default(),
|
|
};
|
|
accumulated_output_chars += args.chars().count();
|
|
|
|
let tool_name = tool_call.function.name.clone();
|
|
let tool_args: serde_json::Value =
|
|
serde_json::from_str(&args).unwrap_or_default();
|
|
|
|
if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await {
|
|
match decision {
|
|
ToolGuardrailDecision::Allow => {}
|
|
ToolGuardrailDecision::Block { reason } => {
|
|
let _ = tx
|
|
.send(RigStreamChunk::ToolCallFinished {
|
|
tool_call_id: tool_call.id.clone(),
|
|
tool_name: tool_name.clone(),
|
|
output: format!("blocked: {reason}"),
|
|
error: Some(reason),
|
|
})
|
|
.await;
|
|
current_step_tool_calls.push(ToolCallRecord {
|
|
id: tool_call.id.clone(),
|
|
name: tool_name.clone(),
|
|
arguments: tool_args.clone(),
|
|
output: None,
|
|
error: Some("blocked by guardrail".to_string()),
|
|
elapsed_ms: None,
|
|
});
|
|
continue;
|
|
}
|
|
ToolGuardrailDecision::RequireApproval { message } => {
|
|
let _ = tx
|
|
.send(RigStreamChunk::ToolCallFinished {
|
|
tool_call_id: tool_call.id.clone(),
|
|
tool_name: tool_name.clone(),
|
|
output: format!("awaiting approval: {message}"),
|
|
error: None,
|
|
})
|
|
.await;
|
|
current_step_tool_calls.push(ToolCallRecord {
|
|
id: tool_call.id.clone(),
|
|
name: tool_name.clone(),
|
|
arguments: tool_args.clone(),
|
|
output: None,
|
|
error: Some(format!("requires approval: {message}")),
|
|
elapsed_ms: None,
|
|
});
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
|
|
let _ = tx
|
|
.send(RigStreamChunk::ToolCallStarted {
|
|
tool_call_id: tool_call.id.clone(),
|
|
tool_name: tool_name.clone(),
|
|
arguments: args.clone(),
|
|
})
|
|
.await;
|
|
current_step_tool_calls.push(ToolCallRecord {
|
|
id: tool_call.id.clone(),
|
|
name: tool_name.clone(),
|
|
arguments: tool_args.clone(),
|
|
output: None,
|
|
error: None,
|
|
elapsed_ms: None,
|
|
});
|
|
}
|
|
Ok(rig::agent::MultiTurnStreamItem::StreamUserItem(
|
|
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
|
|
)) => {
|
|
let content =
|
|
super::helpers::tool_result_content_to_string(&tool_result.content);
|
|
accumulated_output_chars += content.chars().count();
|
|
|
|
if let Some(last) = current_step_tool_calls.last_mut()
|
|
&& last.id == tool_result.id
|
|
{
|
|
last.output = Some(serde_json::from_str(&content).unwrap_or_default());
|
|
}
|
|
|
|
let tool_name = current_step_tool_calls
|
|
.last()
|
|
.map(|tc| tc.name.clone())
|
|
.unwrap_or_default();
|
|
|
|
let _ = tx
|
|
.send(RigStreamChunk::ToolCallFinished {
|
|
tool_call_id: tool_result.id.clone(),
|
|
tool_name,
|
|
output: content.clone(),
|
|
error: None,
|
|
})
|
|
.await;
|
|
|
|
if !hooks.is_empty() {
|
|
let outcome = ToolCallOutcome {
|
|
name: tool_result.id.clone(),
|
|
arguments: serde_json::Value::Null,
|
|
output: Some(serde_json::Value::String(content)),
|
|
error: None,
|
|
elapsed_ms: 0,
|
|
};
|
|
let _ = hooks.run_post_tool_call(&outcome).await;
|
|
}
|
|
}
|
|
Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
|
|
let usage = resp.usage();
|
|
|
|
if !current_step_tool_calls.is_empty() || !current_step_assistant.is_empty() {
|
|
let reasoning = (!current_step_reasoning.is_empty())
|
|
.then_some(std::mem::take(&mut current_step_reasoning));
|
|
steps.push(AgentStep {
|
|
index: steps.len(),
|
|
assistant: (!current_step_assistant.is_empty())
|
|
.then_some(std::mem::take(&mut current_step_assistant)),
|
|
reasoning_content: reasoning,
|
|
tool_calls: std::mem::take(&mut current_step_tool_calls),
|
|
reflection: None,
|
|
});
|
|
}
|
|
let output = steps
|
|
.last()
|
|
.and_then(|s| s.assistant.clone())
|
|
.unwrap_or_default();
|
|
|
|
if !hooks.is_empty() {
|
|
let hook_response = HookLlmResponse {
|
|
content: Some(output.clone()),
|
|
tool_calls: None,
|
|
input_tokens: usage.input_tokens,
|
|
output_tokens: usage.output_tokens,
|
|
finish_reason: None,
|
|
};
|
|
let _ = hooks.run_post_llm_call(&hook_response).await;
|
|
}
|
|
|
|
info!(
|
|
steps = steps.len(),
|
|
input_tokens = usage.input_tokens,
|
|
output_tokens = usage.output_tokens,
|
|
"agent run completed"
|
|
);
|
|
|
|
let _ = tx
|
|
.send(RigStreamChunk::Final {
|
|
content: output.clone(),
|
|
input_tokens: usage.input_tokens,
|
|
output_tokens: usage.output_tokens,
|
|
})
|
|
.await;
|
|
|
|
if let Some(ref ctx) = request.run_context {
|
|
let _ = hooks.run_session_end(ctx, true).await;
|
|
}
|
|
|
|
return Ok(AgentResult {
|
|
output,
|
|
steps,
|
|
expert_outputs,
|
|
input_tokens: usage.input_tokens as i64,
|
|
output_tokens: usage.output_tokens as i64,
|
|
});
|
|
}
|
|
Err(e) => {
|
|
let err = format!("{e}");
|
|
warn!(error = %err, "agent stream error");
|
|
let _ = tx.send(RigStreamChunk::Failed { error: err }).await;
|
|
|
|
if let Some(ref ctx) = request.run_context {
|
|
let _ = hooks.run_session_end(ctx, false).await;
|
|
}
|
|
return Err(AiError::Api(format!("{e}")));
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
Err(AiError::Response("agent stream ended without final response".to_string()))
|
|
}
|
|
|
|
// NOTE(review): this `Clone` does NOT copy any hooks. Every clone is a fresh
// `HookChain::empty()`. `RigAgent::run` passes `self.hooks.clone()` to the
// spawned run, so hooks installed with `RigAgent::with_hooks` appear never to
// reach `execute_agent_run`. Confirm whether this is intended. If not, clone
// the underlying hook list (e.g. via shared `Arc` handles) instead.
impl Clone for HookChain {
    /// Returns an empty chain; the hooks held by `self` are not carried over.
    fn clone(&self) -> Self {
        HookChain::empty()
    }
}
|