gitdataai/libs/agent/tool/executor.rs
2026-04-14 19:02:01 +08:00

143 lines
4.2 KiB
Rust

//! Executes tool calls and converts results to OpenAI `tool` messages.
use futures::StreamExt;
use futures::stream;
use async_openai::types::chat::{
ChatCompletionRequestMessage, ChatCompletionRequestToolMessage,
ChatCompletionRequestToolMessageContent,
};
use super::call::{ToolCall, ToolCallResult, ToolError, ToolResult};
use super::context::ToolContext;
/// Runs the tool calls requested by the model and converts their results into
/// chat `tool` messages.
///
/// Build one with [`ToolExecutor::new`] (or `Default`) and tune it with the
/// `with_*` builder methods.
///
/// NOTE(review): `max_tool_calls` and `max_depth` are stored but not read by
/// the executor's methods in this file; limit enforcement goes through
/// `ToolContext`. Confirm whether these fields are meant to be wired in.
pub struct ToolExecutor {
    /// Configured cap on the number of tool calls (default: 128).
    max_tool_calls: usize,
    /// Configured cap on nested tool-call depth (default: 5).
    max_depth: u32,
    /// Upper bound on tool calls in flight at the same time (default: 8).
    max_concurrency: usize,
}
impl Default for ToolExecutor {
    /// Builds an executor with a 128 tool-call cap, a depth cap of 5, and at
    /// most 8 concurrent tool calls.
    fn default() -> Self {
        let max_tool_calls = 128;
        let max_depth = 5;
        let max_concurrency = 8;
        Self {
            max_tool_calls,
            max_depth,
            max_concurrency,
        }
    }
}
impl ToolExecutor {
pub fn new() -> Self {
Self::default()
}
pub fn with_max_tool_calls(mut self, max: usize) -> Self {
self.max_tool_calls = max;
self
}
pub fn with_max_depth(mut self, depth: u32) -> Self {
self.max_depth = depth;
self
}
/// Set the maximum number of tool calls executed concurrently.
/// Defaults to 8. Set to 1 for strictly sequential execution.
pub fn with_max_concurrency(mut self, n: usize) -> Self {
self.max_concurrency = n;
self
}
/// # Errors
///
/// Returns `ToolError::MaxToolCallsExceeded` if the total number of tool calls
/// exceeds `max_tool_calls`.
pub async fn execute_batch(
&self,
calls: Vec<ToolCall>,
ctx: &mut ToolContext,
) -> Result<Vec<ToolCallResult>, ToolError> {
if ctx.tool_calls_exceeded() {
return Err(ToolError::MaxToolCallsExceeded(ctx.tool_call_count()));
}
if ctx.recursion_exceeded() {
return Err(ToolError::RecursionLimitExceeded {
max_depth: ctx.depth(),
});
}
ctx.increment_tool_calls();
let concurrency = self.max_concurrency;
use std::sync::Mutex;
let results: Mutex<Vec<ToolCallResult>> = Mutex::new(Vec::with_capacity(calls.len()));
stream::iter(calls.into_iter().map(|call| {
let child_ctx = ctx.child_context();
async move { self.execute_one(call, child_ctx).await }
}))
.buffer_unordered(concurrency)
.for_each_concurrent(
concurrency,
|result: Result<ToolCallResult, ToolError>| async {
let r = result.unwrap_or_else(|e| {
ToolCallResult::error(
ToolCall {
id: String::new(),
name: String::new(),
arguments: String::new(),
},
e.to_string(),
)
});
results.lock().unwrap().push(r);
},
)
.await;
Ok(results.into_inner().unwrap())
}
async fn execute_one(
&self,
call: ToolCall,
ctx: ToolContext,
) -> Result<ToolCallResult, ToolError> {
let handler = ctx
.registry()
.get(&call.name)
.ok_or_else(|| ToolError::NotFound(call.name.clone()))?
.clone();
let args = call.arguments_json()?;
match handler.execute(ctx, args).await {
Ok(value) => Ok(ToolCallResult::ok(call, value)),
Err(e) => Ok(ToolCallResult::error(call, e.to_string())),
}
}
pub fn to_tool_messages(results: &[ToolCallResult]) -> Vec<ChatCompletionRequestMessage> {
results
.iter()
.map(|r| {
let content = match &r.result {
ToolResult::Ok(v) => {
serde_json::to_string(v).unwrap_or_else(|_| "null".to_string())
}
ToolResult::Error(msg) => serde_json::to_string(&serde_json::json!({
"error": msg
}))
.unwrap_or_else(|_| r#"{"error":"unknown error"}"#.to_string()),
};
ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
tool_call_id: r.call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(content),
})
})
.collect()
}
}