gitdataai/lib/ai/agent/agent.rs

591 lines
21 KiB
Rust

use futures::StreamExt;
use rig::agent::AgentBuilder;
use rig::client::CompletionClient;
use rig::streaming::StreamingPrompt;
use rig::tool::ToolDyn;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use super::RigStreamChunk;
use super::config::AgentConfig;
use super::helpers::{build_input_string, check_token_budget, estimate_tokens};
use super::hooks::{
HookChain, HookLlmResponse, HookMessage, HookToolDef, ToolCallOutcome,
ToolGuardrailDecision,
};
use super::persistence::ActiveAgentRun;
use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord};
use super::subagent::run_experts;
use crate::client::AiClient;
use crate::error::{AiError, AiResult};
/// Agent front-end that pairs an [`AiClient`] with an [`AgentConfig`] and a
/// [`HookChain`].
///
/// Constructed through [`RigAgent::new`] (which validates the config and starts
/// with `HookChain::empty()`); hooks can be swapped in with
/// [`RigAgent::with_hooks`].
pub struct RigAgent {
/// Client used by `run`: its `llm_client()` is cloned for the completion model,
/// and a clone of the whole client is also handed to the spawned run.
pub client: AiClient,
/// Run settings. `run` reads `model`, `system_prompt`, `max_iterations`,
/// `temperature`, `max_completion_tokens`, `max_total_tokens_per_run`, and uses
/// `is_tool_exposed` to filter the tools offered to the model.
pub config: AgentConfig,
/// Hook chain; `run` passes `hooks.clone()` to the spawned run.
pub hooks: HookChain,
}
impl RigAgent {
    /// Creates an agent from `client` and `config`, starting with
    /// `HookChain::empty()` as its hook chain.
    ///
    /// # Errors
    /// Returns whatever error `config.validate()` reports.
    pub fn new(client: AiClient, config: AgentConfig) -> AiResult<Self> {
        config.validate()?;
        let hooks = HookChain::empty();
        Ok(Self {
            client,
            config,
            hooks,
        })
    }

    /// Replaces the agent's hook chain and returns the agent (builder style).
    pub fn with_hooks(mut self, hooks: HookChain) -> Self {
        self.hooks = hooks;
        self
    }

    /// Borrows the agent's configuration.
    pub fn config(&self) -> &AgentConfig {
        &self.config
    }

    /// Runs the agent and returns only the final output text.
    ///
    /// The streamed chunks produced by [`RigAgent::run`] are received and
    /// discarded by a detached task so the run can keep sending while this
    /// method waits on the run's join handle.
    ///
    /// # Errors
    /// * `AiError::Response("agent task panicked")` when awaiting the join
    ///   handle fails.
    /// * Any error returned by the run itself.
    pub async fn chat(
        &self,
        request: AgentRequest,
        tools: Vec<Box<dyn ToolDyn>>,
    ) -> AiResult<String> {
        let (mut chunks, task) = self.run(request, tools);

        // Nobody consumes the chunks here; keep receiving until the sender side
        // is gone so the 256-slot channel never fills up.
        tokio::spawn(async move {
            while let Some(_chunk) = chunks.recv().await {}
        });

        match task.await {
            Ok(outcome) => outcome.map(|done| done.output),
            Err(_) => Err(AiError::Response("agent task panicked".to_string())),
        }
    }

    /// Starts an agent run on a spawned tokio task.
    ///
    /// Returns the receiving half of a 256-slot chunk channel together with the
    /// task's join handle, whose value is the run's `AiResult<AgentResult>`.
    /// Only tools for which `config.is_tool_exposed(name)` holds are passed on.
    #[allow(clippy::too_many_lines)]
    pub fn run(
        &self,
        request: AgentRequest,
        tools: Vec<Box<dyn ToolDyn>>,
    ) -> (
        tokio::sync::mpsc::Receiver<RigStreamChunk>,
        tokio::task::JoinHandle<AiResult<AgentResult>>,
    ) {
        let (tx, rx) = mpsc::channel::<RigStreamChunk>(256);

        // These two live on the request, which is moved into the task below.
        let cancellation = request.cancellation_token.clone();
        let timeout = request.timeout;

        let mut exposed_tools = tools;
        exposed_tools.retain(|tool| self.config.is_tool_exposed(&tool.name()));

        // NOTE(review): `Clone for HookChain` in this file returns
        // `HookChain::empty()`, so the chain handed to the task below is not a
        // copy of `self.hooks` — confirm whether that is intended.
        let task = tokio::spawn(execute_agent_run(
            self.client.llm_client().clone(),
            self.config.model.clone(),
            self.config.system_prompt.clone(),
            request,
            exposed_tools,
            self.config.max_iterations,
            self.client.clone(),
            self.config.clone(),
            self.config.temperature,
            self.config.max_completion_tokens,
            self.config.max_total_tokens_per_run,
            cancellation,
            timeout,
            self.hooks.clone(),
            tx,
        ));

        (rx, task)
    }
}
/// Drives a single agent run and streams its progress over `tx`.
///
/// Steps, in order:
/// 1. Runs the session-start hook when the request carries a run context.
/// 2. Builds the `rig` agent from the model, preamble, tools and optional
///    temperature / max-token settings.
/// 3. If the request lists experts, announces them, runs them via
///    `run_experts`, and appends each result to the prompt input. A failure
///    there is logged and reported as `SubagentFailed`, then the run continues
///    without expert input.
/// 4. Rejects the run when the estimated input tokens already exceed
///    `max_total_tokens`.
/// 5. Runs the pre-LLM hook (only when the hook chain is non-empty).
/// 6. Consumes the multi-turn stream, forwarding text / reasoning / tool events
///    as `RigStreamChunk`s and recording tool calls, until the final response.
/// 7. Runs the session-end hook exactly once, with `true` on success and
///    `false` on every failure exit.
///
/// Send errors on `tx` and hook errors are deliberately ignored (best effort).
///
/// # Errors
/// * `AiError::TokenBudgetExceeded` — input estimate or running total is over
///   `max_total_tokens`.
/// * `AiError::Timeout` — awaiting the stream future exceeded `timeout`.
/// * `AiError::Response` — the run was cancelled, or the stream ended without a
///   final response.
/// * `AiError::Api` — the stream yielded an error item.
///
/// Every failure exit also sends a `RigStreamChunk::Failed` before returning.
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
async fn execute_agent_run(
    client: rig::providers::openai::Client,
    model_name: String,
    system_prompt: String,
    request: AgentRequest,
    tools: Vec<Box<dyn ToolDyn>>,
    max_iterations: usize,
    ai_client: AiClient,
    agent_config: AgentConfig,
    temperature: Option<f64>,
    max_completion_tokens: Option<u64>,
    max_total_tokens: Option<i64>,
    cancellation: Option<CancellationToken>,
    timeout: Option<std::time::Duration>,
    hooks: HookChain,
    tx: mpsc::Sender<RigStreamChunk>,
) -> AiResult<AgentResult> {
    if let Some(ref ctx) = request.run_context {
        let _ = hooks.run_session_start(ctx).await;
    }

    // The whole run lives in one labelled block so that every exit — success or
    // failure — funnels through the single session-end call below. Previously
    // only two of the seven exits ran the session-end hook.
    let outcome: AiResult<AgentResult> = 'run: {
        let model = client.completion_model(&model_name);
        let mut agent_builder = AgentBuilder::new(model)
            .preamble(&system_prompt)
            .tools(tools)
            .default_max_turns(max_iterations);
        if let Some(temp) = temperature {
            agent_builder = agent_builder.temperature(temp);
        }
        if let Some(mt) = max_completion_tokens {
            agent_builder = agent_builder.max_tokens(mt);
        }
        let agent = agent_builder.build();
        let mut input = build_input_string(&request);

        // ---- SubAgent execution ----
        let expert_outputs = if request.experts.is_empty() {
            Vec::new()
        } else {
            let run_ctx = request.run_context.as_ref();
            let run = ActiveAgentRun {
                conversation_id: run_ctx.and_then(|c| c.conversation_id),
                message_id: None,
                invocation_id: run_ctx.and_then(|c| c.invocation_id),
                session_id: run_ctx.and_then(|c| c.session_id),
                user_id: run_ctx.and_then(|c| c.user_id),
                started_at: std::time::Instant::now(),
                current_step: 0,
            };
            let realtime = run_ctx.and_then(|c| c.realtime.as_ref());

            // Notify frontend that subagents are starting.
            for expert in &request.experts {
                let _ = tx
                    .send(RigStreamChunk::SubagentStarted {
                        subagent_id: expert.id.clone(),
                        role: expert.role.clone(),
                        task: expert.task.clone(),
                    })
                    .await;
            }

            match run_experts(
                &ai_client,
                &agent_config,
                &request.experts,
                realtime,
                &run,
            )
            .await
            {
                Ok(outputs) => {
                    for out in &outputs {
                        let _ = tx
                            .send(RigStreamChunk::SubagentCompleted {
                                subagent_id: out.id.clone(),
                                role: out.role.clone(),
                                task: out.task.clone(),
                                output: out.output.clone(),
                            })
                            .await;
                        // Expert results become part of the main prompt.
                        input.push_str(&format!(
                            "\n--- Subagent: {} (role: {}) ---\nTask: {}\nResult: {}\n",
                            out.id, out.role, out.task, out.output
                        ));
                    }
                    outputs
                }
                Err(e) => {
                    warn!(error = %e, "subagent execution failed, continuing without expert inputs");
                    let _ = tx
                        .send(RigStreamChunk::SubagentFailed {
                            error: e.to_string(),
                        })
                        .await;
                    Vec::new()
                }
            }
        };

        // Pre-flight budget check on the (possibly expert-extended) input.
        let estimated_input_tokens = estimate_tokens(&input);
        if let Some(limit) = max_total_tokens
            && estimated_input_tokens > limit as u64
        {
            // Report on the stream too, like every other failure exit does.
            let _ = tx
                .send(RigStreamChunk::Failed {
                    error: format!("token budget exceeded: limit {limit}"),
                })
                .await;
            break 'run Err(AiError::TokenBudgetExceeded {
                estimated: estimated_input_tokens,
                limit,
            });
        }

        if !hooks.is_empty() {
            let hook_messages: Vec<HookMessage> = request
                .messages
                .iter()
                .map(|m| {
                    let (role, content) = match m {
                        super::request::AgentMessage::User(c) => ("user", c.clone()),
                        super::request::AgentMessage::Assistant(c) => {
                            ("assistant", c.clone())
                        }
                    };
                    HookMessage {
                        role: role.to_string(),
                        content: Some(content),
                        tool_calls: None,
                        tool_call_id: None,
                    }
                })
                .collect();
            // Tool definitions are not forwarded to the hook; it gets an empty list.
            let hook_tools: Vec<HookToolDef> = Vec::new();
            let _ = hooks.run_pre_llm_call(&hook_messages, &hook_tools).await;
        }

        let stream_future = agent
            .stream_prompt(&input)
            .with_history(Vec::<rig::completion::Message>::new())
            .multi_turn(max_iterations);

        // NOTE(review): `timeout` only bounds awaiting `stream_future` (obtaining
        // the stream); the consumption loop below is not covered by it. Confirm
        // whether the timeout is meant to cap the whole run.
        let stream = if let Some(dur) = timeout {
            match tokio::time::timeout(dur, stream_future).await {
                Ok(stream) => stream,
                Err(_elapsed) => {
                    let _ = tx
                        .send(RigStreamChunk::Failed {
                            error: format!(
                                "agent timed out after {}s",
                                dur.as_secs()
                            ),
                        })
                        .await;
                    break 'run Err(AiError::Timeout {
                        seconds: dur.as_secs(),
                    });
                }
            }
        } else {
            stream_future.await
        };
        tokio::pin!(stream);

        let mut steps = Vec::new();
        let mut delta_index = 0usize;
        let mut current_step_tool_calls: Vec<ToolCallRecord> = Vec::new();
        let mut current_step_assistant = String::new();
        let mut current_step_reasoning = String::new();
        let mut accumulated_output_chars: usize = 0;

        while let Some(item) = stream.next().await {
            // Cancellation is polled once per received item, so it is only
            // noticed when the stream produces something.
            if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
                let _ = tx
                    .send(RigStreamChunk::Failed {
                        error: "cancelled".to_string(),
                    })
                    .await;
                break 'run Err(AiError::Response(
                    "agent run cancelled".to_string(),
                ));
            }

            if let Some(limit) = max_total_tokens
                && check_token_budget(
                    estimated_input_tokens,
                    accumulated_output_chars,
                    limit,
                )
            {
                let _ = tx
                    .send(RigStreamChunk::Failed {
                        error: format!("token budget exceeded: limit {limit}"),
                    })
                    .await;
                break 'run Err(AiError::TokenBudgetExceeded {
                    estimated: estimated_input_tokens
                        + (accumulated_output_chars as f64 / 2.5).ceil() as u64,
                    limit,
                });
            }

            match item {
                Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
                    rig::streaming::StreamedAssistantContent::Text(text),
                )) => {
                    accumulated_output_chars += text.text.chars().count();
                    current_step_assistant.push_str(&text.text);
                    let _ = tx
                        .send(RigStreamChunk::TextDelta {
                            index: delta_index,
                            content: text.text.clone(),
                        })
                        .await;
                    delta_index += 1;
                }
                Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
                    rig::streaming::StreamedAssistantContent::Reasoning(reasoning),
                )) => {
                    // Only textual reasoning parts are forwarded; others are skipped.
                    for part in &reasoning.content {
                        if let rig::completion::message::ReasoningContent::Text {
                            text,
                            ..
                        } = part
                        {
                            accumulated_output_chars += text.chars().count();
                            current_step_reasoning.push_str(text);
                            let _ = tx
                                .send(RigStreamChunk::Thinking {
                                    index: delta_index,
                                    content: text.clone(),
                                })
                                .await;
                            delta_index += 1;
                        }
                    }
                }
                Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
                    rig::streaming::StreamedAssistantContent::ReasoningDelta {
                        reasoning,
                        ..
                    },
                )) => {
                    accumulated_output_chars += reasoning.chars().count();
                    current_step_reasoning.push_str(&reasoning);
                    let _ = tx
                        .send(RigStreamChunk::Thinking {
                            index: delta_index,
                            content: reasoning.clone(),
                        })
                        .await;
                    delta_index += 1;
                }
                Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
                    rig::streaming::StreamedAssistantContent::ToolCall {
                        tool_call,
                        internal_call_id: _,
                    },
                )) => {
                    // Arguments may arrive as a JSON string or as a JSON value;
                    // normalise to both a string (for the stream) and a value
                    // (for hooks and the step record).
                    let args = match &tool_call.function.arguments {
                        serde_json::Value::String(s) => s.clone(),
                        v => serde_json::to_string(v).unwrap_or_default(),
                    };
                    accumulated_output_chars += args.chars().count();
                    let tool_name = tool_call.function.name.clone();
                    let tool_args: serde_json::Value =
                        serde_json::from_str(&args).unwrap_or_default();

                    // A hook error or `None` decision is treated like `Allow`.
                    // NOTE(review): `continue` below only skips this function's
                    // own "started" event and bookkeeping; nothing here tells
                    // the agent not to execute the tool. Verify that a blocked
                    // call is really prevented upstream.
                    if let Ok(Some(decision)) =
                        hooks.run_pre_tool_call(&tool_name, &tool_args).await
                    {
                        match decision {
                            ToolGuardrailDecision::Allow => {}
                            ToolGuardrailDecision::Block { reason } => {
                                let _ = tx
                                    .send(RigStreamChunk::ToolCallFinished {
                                        tool_call_id: tool_call.id.clone(),
                                        tool_name: tool_name.clone(),
                                        output: format!("blocked: {reason}"),
                                        error: Some(reason),
                                    })
                                    .await;
                                current_step_tool_calls.push(ToolCallRecord {
                                    id: tool_call.id.clone(),
                                    name: tool_name.clone(),
                                    arguments: tool_args.clone(),
                                    output: None,
                                    error: Some("blocked by guardrail".to_string()),
                                    elapsed_ms: None,
                                });
                                continue;
                            }
                            ToolGuardrailDecision::RequireApproval { message } => {
                                let _ = tx
                                    .send(RigStreamChunk::ToolCallFinished {
                                        tool_call_id: tool_call.id.clone(),
                                        tool_name: tool_name.clone(),
                                        output: format!(
                                            "awaiting approval: {message}"
                                        ),
                                        error: None,
                                    })
                                    .await;
                                current_step_tool_calls.push(ToolCallRecord {
                                    id: tool_call.id.clone(),
                                    name: tool_name.clone(),
                                    arguments: tool_args.clone(),
                                    output: None,
                                    error: Some(format!(
                                        "requires approval: {message}"
                                    )),
                                    elapsed_ms: None,
                                });
                                continue;
                            }
                        }
                    }

                    let _ = tx
                        .send(RigStreamChunk::ToolCallStarted {
                            tool_call_id: tool_call.id.clone(),
                            tool_name: tool_name.clone(),
                            arguments: args.clone(),
                        })
                        .await;
                    current_step_tool_calls.push(ToolCallRecord {
                        id: tool_call.id.clone(),
                        name: tool_name.clone(),
                        arguments: tool_args.clone(),
                        output: None,
                        error: None,
                        elapsed_ms: None,
                    });
                }
                Ok(rig::agent::MultiTurnStreamItem::StreamUserItem(
                    rig::streaming::StreamedUserContent::ToolResult {
                        tool_result,
                        ..
                    },
                )) => {
                    let content = super::helpers::tool_result_content_to_string(
                        &tool_result.content,
                    );
                    accumulated_output_chars += content.chars().count();

                    // Pair the result with the call that carries the same id
                    // (newest first) instead of assuming it belongs to the most
                    // recent call; with several pending calls the old "last
                    // only" check dropped outputs and mislabelled events.
                    let record_idx = current_step_tool_calls
                        .iter()
                        .rposition(|tc| tc.id == tool_result.id);
                    let (tool_name, tool_arguments) = match record_idx {
                        Some(idx) => {
                            let record = &mut current_step_tool_calls[idx];
                            // Keep plain-text output as a JSON string rather
                            // than collapsing it to `null`.
                            record.output = Some(
                                serde_json::from_str(&content).unwrap_or_else(|_| {
                                    serde_json::Value::String(content.clone())
                                }),
                            );
                            (record.name.clone(), record.arguments.clone())
                        }
                        // Unknown id: fall back to the latest call's name, as
                        // before, and report no arguments.
                        None => (
                            current_step_tool_calls
                                .last()
                                .map(|tc| tc.name.clone())
                                .unwrap_or_default(),
                            serde_json::Value::Null,
                        ),
                    };

                    let _ = tx
                        .send(RigStreamChunk::ToolCallFinished {
                            tool_call_id: tool_result.id.clone(),
                            tool_name: tool_name.clone(),
                            output: content.clone(),
                            error: None,
                        })
                        .await;

                    if !hooks.is_empty() {
                        // `name` is the tool's name (it used to receive the call
                        // id) and `arguments` are the recorded call arguments.
                        // Elapsed time is not measured here, hence 0.
                        let hook_outcome = ToolCallOutcome {
                            name: tool_name,
                            arguments: tool_arguments,
                            output: Some(serde_json::Value::String(content)),
                            error: None,
                            elapsed_ms: 0,
                        };
                        let _ = hooks.run_post_tool_call(&hook_outcome).await;
                    }
                }
                Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
                    let usage = resp.usage();

                    // Everything accumulated during the stream is flushed as a
                    // single step, provided there was text or a tool call.
                    if !current_step_tool_calls.is_empty()
                        || !current_step_assistant.is_empty()
                    {
                        let reasoning = (!current_step_reasoning.is_empty())
                            .then_some(std::mem::take(&mut current_step_reasoning));
                        steps.push(AgentStep {
                            index: steps.len(),
                            assistant: (!current_step_assistant.is_empty())
                                .then_some(std::mem::take(
                                    &mut current_step_assistant,
                                )),
                            reasoning_content: reasoning,
                            tool_calls: std::mem::take(
                                &mut current_step_tool_calls,
                            ),
                            reflection: None,
                        });
                    }

                    // The run's output is the assistant text of the last step,
                    // or empty when there is none.
                    let output = steps
                        .last()
                        .and_then(|s| s.assistant.clone())
                        .unwrap_or_default();

                    if !hooks.is_empty() {
                        let hook_response = HookLlmResponse {
                            content: Some(output.clone()),
                            tool_calls: None,
                            input_tokens: usage.input_tokens,
                            output_tokens: usage.output_tokens,
                            finish_reason: None,
                        };
                        let _ = hooks.run_post_llm_call(&hook_response).await;
                    }

                    info!(
                        steps = steps.len(),
                        input_tokens = usage.input_tokens,
                        output_tokens = usage.output_tokens,
                        "agent run completed"
                    );
                    let _ = tx
                        .send(RigStreamChunk::Final {
                            content: output.clone(),
                            input_tokens: usage.input_tokens,
                            output_tokens: usage.output_tokens,
                        })
                        .await;

                    break 'run Ok(AgentResult {
                        output,
                        steps,
                        expert_outputs,
                        input_tokens: usage.input_tokens as i64,
                        output_tokens: usage.output_tokens as i64,
                    });
                }
                Err(e) => {
                    let err = e.to_string();
                    warn!(error = %err, "agent stream error");
                    let _ = tx
                        .send(RigStreamChunk::Failed { error: err.clone() })
                        .await;
                    break 'run Err(AiError::Api(err));
                }
                // Any other stream item is ignored.
                _ => {}
            }
        }

        // The stream closed without ever yielding a final response.
        let _ = tx
            .send(RigStreamChunk::Failed {
                error: "agent stream ended without final response".to_string(),
            })
            .await;
        Err(AiError::Response(
            "agent stream ended without final response".to_string(),
        ))
    };

    if let Some(ref ctx) = request.run_context {
        let _ = hooks.run_session_end(ctx, outcome.is_ok()).await;
    }
    outcome
}
// NOTE(review): this `Clone` impl does not copy `self`. It always returns
// `HookChain::empty()`. `RigAgent::run` obtains the chain for the spawned run
// through `self.hooks.clone()`, so the run receives the result of `empty()`
// rather than the chain installed with `with_hooks`. Confirm whether this is a
// placeholder; a faithful clone needs access to `HookChain`'s fields, which are
// defined outside this file.
impl Clone for HookChain {
/// Returns `HookChain::empty()`; the contents of `self` are not consulted.
fn clone(&self) -> Self {
HookChain::empty()
}
}