//! Executes tool calls and converts results to OpenAI `tool` messages. use futures::StreamExt; use futures::stream; use async_openai::types::chat::{ ChatCompletionRequestMessage, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, }; use super::call::{ToolCall, ToolCallResult, ToolError, ToolResult}; use super::context::ToolContext; pub struct ToolExecutor { max_tool_calls: usize, max_depth: u32, max_concurrency: usize, } impl Default for ToolExecutor { fn default() -> Self { Self { max_tool_calls: 128, max_depth: 5, max_concurrency: 8, } } } impl ToolExecutor { pub fn new() -> Self { Self::default() } pub fn with_max_tool_calls(mut self, max: usize) -> Self { self.max_tool_calls = max; self } pub fn with_max_depth(mut self, depth: u32) -> Self { self.max_depth = depth; self } /// Set the maximum number of tool calls executed concurrently. /// Defaults to 8. Set to 1 for strictly sequential execution. pub fn with_max_concurrency(mut self, n: usize) -> Self { self.max_concurrency = n; self } /// # Errors /// /// Returns `ToolError::MaxToolCallsExceeded` if the total number of tool calls /// exceeds `max_tool_calls`. pub async fn execute_batch( &self, calls: Vec, ctx: &mut ToolContext, ) -> Result, ToolError> { if ctx.tool_calls_exceeded() { return Err(ToolError::MaxToolCallsExceeded(ctx.tool_call_count())); } if ctx.recursion_exceeded() { return Err(ToolError::RecursionLimitExceeded { max_depth: ctx.depth(), }); } ctx.increment_tool_calls(); let concurrency = self.max_concurrency; use std::sync::Mutex; let results: Mutex> = Mutex::new(Vec::with_capacity(calls.len())); stream::iter(calls.into_iter().map(|call| { let child_ctx = ctx.child_context(); async move { self.execute_one(call, child_ctx).await } })) .buffer_unordered(concurrency) .for_each_concurrent( concurrency, |result: Result| async { let r = result.unwrap_or_else(|e| { ToolCallResult::error( ToolCall { id: String::new(), name: String::new(), arguments: String::new(), }, e.to_string(), ) }); results.lock().unwrap().push(r); }, ) .await; Ok(results.into_inner().unwrap()) } async fn execute_one( &self, call: ToolCall, ctx: ToolContext, ) -> Result { let handler = ctx .registry() .get(&call.name) .ok_or_else(|| ToolError::NotFound(call.name.clone()))? .clone(); let args = call.arguments_json()?; match handler.execute(ctx, args).await { Ok(value) => Ok(ToolCallResult::ok(call, value)), Err(e) => Ok(ToolCallResult::error(call, e.to_string())), } } pub fn to_tool_messages(results: &[ToolCallResult]) -> Vec { results .iter() .map(|r| { let content = match &r.result { ToolResult::Ok(v) => { serde_json::to_string(v).unwrap_or_else(|_| "null".to_string()) } ToolResult::Error(msg) => serde_json::to_string(&serde_json::json!({ "error": msg })) .unwrap_or_else(|_| r#"{"error":"unknown error"}"#.to_string()), }; ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage { tool_call_id: r.call.id.clone(), content: ChatCompletionRequestToolMessageContent::Text(content), }) }) .collect() } }