//! Executes tool calls and converts results to OpenAI `tool` messages. use futures::StreamExt; use futures::stream; use crate::client::ChatRequestMessage; use super::call::{ToolCall, ToolCallResult, ToolError, ToolResult}; use super::context::ToolContext; pub struct ToolExecutor { max_tool_calls: usize, max_depth: u32, max_concurrency: usize, } impl Default for ToolExecutor { fn default() -> Self { Self { max_tool_calls: 128, max_depth: 5, max_concurrency: 8, } } } impl ToolExecutor { pub fn new() -> Self { Self::default() } pub fn with_max_tool_calls(mut self, max: usize) -> Self { self.max_tool_calls = max; self } pub fn with_max_depth(mut self, depth: u32) -> Self { self.max_depth = depth; self } /// Set the maximum number of tool calls executed concurrently. /// Defaults to 8. Set to 1 for strictly sequential execution. pub fn with_max_concurrency(mut self, n: usize) -> Self { self.max_concurrency = n; self } /// # Errors /// /// Returns `ToolError::MaxToolCallsExceeded` if the total number of tool calls /// exceeds `max_tool_calls`. pub async fn execute_batch( &self, calls: Vec, ctx: &mut ToolContext, ) -> Result, ToolError> { if ctx.tool_calls_exceeded() { return Err(ToolError::MaxToolCallsExceeded(ctx.tool_call_count())); } if ctx.recursion_exceeded() { return Err(ToolError::RecursionLimitExceeded { max_depth: ctx.depth(), }); } ctx.increment_tool_calls(); let concurrency = self.max_concurrency; let calls_clone: Vec = calls.clone(); // Execute tool calls concurrently but preserve input order for ID matching. // buffer_unordered returns results in *completion* order, which mispairs IDs // on concurrent errors. Instead, track each result with its original index. let indexed_results: Vec<(usize, Result)> = stream::iter(calls.into_iter().enumerate().map(|(i, call)| { let child_ctx = ctx.child_context(); async move { (i, self.execute_one(call, child_ctx).await) } })) .buffer_unordered(concurrency) .collect() .await; // Re-sort by original index to restore input order, then pair with original calls. let mut result_map: std::collections::HashMap> = indexed_results.into_iter().collect(); let results: Vec = calls_clone .into_iter() .enumerate() .map(|(i, call)| { let r = result_map.remove(&i).unwrap_or_else(|| { Err(ToolError::ExecutionError( "missing result for tool call".into(), )) }); r.unwrap_or_else(|e: ToolError| ToolCallResult::error(call, e.to_string())) }) .collect(); Ok(results) } async fn execute_one( &self, call: ToolCall, ctx: ToolContext, ) -> Result { let handler = ctx .registry() .get(&call.name) .ok_or_else(|| ToolError::NotFound(call.name.clone()))? .clone(); let args = call.arguments_json()?; match handler.execute(ctx, args).await { Ok(value) => Ok(ToolCallResult::ok(call, value)), Err(e) => Ok(ToolCallResult::error(call, e.to_string())), } } pub fn to_tool_messages(results: &[ToolCallResult]) -> Vec { results .iter() .map(|r| { let content = match &r.result { ToolResult::Ok(v) => { serde_json::to_string(v).unwrap_or_else(|_| "null".to_string()) } ToolResult::Error(msg) => serde_json::to_string(&serde_json::json!({ "error": msg })) .unwrap_or_else(|_| r#"{"error":"unknown error"}"#.to_string()), }; ChatRequestMessage::tool(&r.call.id, &content) }) .collect() } }