gitdataai/libs/agent/tool/executor.rs

146 lines
4.5 KiB
Rust

//! Executes tool calls and converts results to OpenAI `tool` messages.
use futures::StreamExt;
use futures::stream;
use crate::client::ChatRequestMessage;
use super::call::{ToolCall, ToolCallResult, ToolError, ToolResult};
use super::context::ToolContext;
/// Runs batches of tool calls against a [`ToolContext`] with configurable limits.
///
/// Construct one with [`ToolExecutor::new`] (or `Default`) and adjust it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ToolExecutor {
    /// Tool-call budget installed on the context at the start of every batch.
    max_tool_calls: usize,
    /// Recursion depth limit installed on the context at the start of every batch.
    max_depth: u32,
    /// Upper bound on how many tool calls of one batch run concurrently.
    max_concurrency: usize,
}
impl Default for ToolExecutor {
fn default() -> Self {
Self {
max_tool_calls: 128,
max_depth: 5,
max_concurrency: 8,
}
}
}
impl ToolExecutor {
pub fn new() -> Self {
Self::default()
}
pub fn with_max_tool_calls(mut self, max: usize) -> Self {
self.max_tool_calls = max;
self
}
pub fn with_max_depth(mut self, depth: u32) -> Self {
self.max_depth = depth;
self
}
/// Set the maximum number of tool calls executed concurrently.
/// Defaults to 8. Set to 1 for strictly sequential execution.
pub fn with_max_concurrency(mut self, n: usize) -> Self {
self.max_concurrency = n;
self
}
/// # Errors
///
/// Returns `ToolError::MaxToolCallsExceeded` if the total number of tool calls
/// exceeds `max_tool_calls`.
pub async fn execute_batch(
&self,
calls: Vec<ToolCall>,
ctx: &mut ToolContext,
) -> Result<Vec<ToolCallResult>, ToolError> {
let ctx = ctx
.clone()
.with_max_tool_calls(self.max_tool_calls)
.with_max_depth(self.max_depth);
if ctx.recursion_exceeded() {
return Err(ToolError::RecursionLimitExceeded {
max_depth: self.max_depth,
});
}
if let Err(current) = ctx.reserve_tool_calls(calls.len()) {
return Err(ToolError::MaxToolCallsExceeded(current));
}
let concurrency = self.max_concurrency;
let calls_clone: Vec<ToolCall> = calls.clone();
// Execute tool calls concurrently but preserve input order for ID matching.
// buffer_unordered returns results in *completion* order, which mispairs IDs
// on concurrent errors. Instead, track each result with its original index.
let indexed_results: Vec<(usize, Result<ToolCallResult, ToolError>)> =
stream::iter(calls.into_iter().enumerate().map(|(i, call)| {
let child_ctx = ctx.child_context();
async move { (i, self.execute_one(call, child_ctx).await) }
}))
.buffer_unordered(concurrency)
.collect()
.await;
// Re-sort by original index to restore input order, then pair with original calls.
let mut result_map: std::collections::HashMap<usize, Result<ToolCallResult, ToolError>> =
indexed_results.into_iter().collect();
let results: Vec<ToolCallResult> = calls_clone
.into_iter()
.enumerate()
.map(|(i, call)| {
let r = result_map.remove(&i).unwrap_or_else(|| {
Err(ToolError::ExecutionError(
"missing result for tool call".into(),
))
});
r.unwrap_or_else(|e: ToolError| ToolCallResult::error(call, e.to_string()))
})
.collect();
Ok(results)
}
async fn execute_one(
&self,
call: ToolCall,
ctx: ToolContext,
) -> Result<ToolCallResult, ToolError> {
let handler = ctx
.registry()
.get(&call.name)
.ok_or_else(|| ToolError::NotFound(call.name.clone()))?
.clone();
let args = call.arguments_json()?;
match handler.execute(ctx, args).await {
Ok(value) => Ok(ToolCallResult::ok(call, value)),
Err(e) => Ok(ToolCallResult::error(call, e.to_string())),
}
}
pub fn to_tool_messages(results: &[ToolCallResult]) -> Vec<ChatRequestMessage> {
results
.iter()
.map(|r| {
let content = match &r.result {
ToolResult::Ok(v) => {
serde_json::to_string(v).unwrap_or_else(|_| "null".to_string())
}
ToolResult::Error(msg) => serde_json::to_string(&serde_json::json!({
"error": msg
}))
.unwrap_or_else(|_| r#"{"error":"unknown error"}"#.to_string()),
};
ChatRequestMessage::tool(&r.call.id, &content)
})
.collect()
}
}