143 lines
4.2 KiB
Rust
143 lines
4.2 KiB
Rust
//! Executes tool calls and converts results to OpenAI `tool` messages.
|
|
|
|
use futures::StreamExt;
|
|
use futures::stream;
|
|
|
|
use async_openai::types::chat::{
|
|
ChatCompletionRequestMessage, ChatCompletionRequestToolMessage,
|
|
ChatCompletionRequestToolMessageContent,
|
|
};
|
|
|
|
use super::call::{ToolCall, ToolCallResult, ToolError, ToolResult};
|
|
use super::context::ToolContext;
|
|
|
|
/// Executes batches of tool calls and enforces configurable limits.
///
/// The limits are set via the `with_*` builder methods; defaults come from
/// the `Default` impl (128 calls, depth 5, concurrency 8).
pub struct ToolExecutor {
    // Maximum total number of tool calls (default 128). Not read in this
    // file beyond configuration — presumably propagated into the
    // `ToolContext` limit checks; TODO confirm at the construction site.
    max_tool_calls: usize,
    // Maximum recursion depth for nested tool invocations (default 5).
    // Same note as above: enforcement happens via `ctx.recursion_exceeded()`.
    max_depth: u32,
    // Upper bound on tool calls executed concurrently in a batch (default 8).
    max_concurrency: usize,
}
|
|
|
|
impl Default for ToolExecutor {
|
|
fn default() -> Self {
|
|
Self {
|
|
max_tool_calls: 128,
|
|
max_depth: 5,
|
|
max_concurrency: 8,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl ToolExecutor {
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
pub fn with_max_tool_calls(mut self, max: usize) -> Self {
|
|
self.max_tool_calls = max;
|
|
self
|
|
}
|
|
|
|
pub fn with_max_depth(mut self, depth: u32) -> Self {
|
|
self.max_depth = depth;
|
|
self
|
|
}
|
|
|
|
/// Set the maximum number of tool calls executed concurrently.
|
|
/// Defaults to 8. Set to 1 for strictly sequential execution.
|
|
pub fn with_max_concurrency(mut self, n: usize) -> Self {
|
|
self.max_concurrency = n;
|
|
self
|
|
}
|
|
|
|
/// # Errors
|
|
///
|
|
/// Returns `ToolError::MaxToolCallsExceeded` if the total number of tool calls
|
|
/// exceeds `max_tool_calls`.
|
|
pub async fn execute_batch(
|
|
&self,
|
|
calls: Vec<ToolCall>,
|
|
ctx: &mut ToolContext,
|
|
) -> Result<Vec<ToolCallResult>, ToolError> {
|
|
if ctx.tool_calls_exceeded() {
|
|
return Err(ToolError::MaxToolCallsExceeded(ctx.tool_call_count()));
|
|
}
|
|
if ctx.recursion_exceeded() {
|
|
return Err(ToolError::RecursionLimitExceeded {
|
|
max_depth: ctx.depth(),
|
|
});
|
|
}
|
|
|
|
ctx.increment_tool_calls();
|
|
|
|
let concurrency = self.max_concurrency;
|
|
use std::sync::Mutex;
|
|
let results: Mutex<Vec<ToolCallResult>> = Mutex::new(Vec::with_capacity(calls.len()));
|
|
|
|
stream::iter(calls.into_iter().map(|call| {
|
|
let child_ctx = ctx.child_context();
|
|
async move { self.execute_one(call, child_ctx).await }
|
|
}))
|
|
.buffer_unordered(concurrency)
|
|
.for_each_concurrent(
|
|
concurrency,
|
|
|result: Result<ToolCallResult, ToolError>| async {
|
|
let r = result.unwrap_or_else(|e| {
|
|
ToolCallResult::error(
|
|
ToolCall {
|
|
id: String::new(),
|
|
name: String::new(),
|
|
arguments: String::new(),
|
|
},
|
|
e.to_string(),
|
|
)
|
|
});
|
|
results.lock().unwrap().push(r);
|
|
},
|
|
)
|
|
.await;
|
|
|
|
Ok(results.into_inner().unwrap())
|
|
}
|
|
|
|
async fn execute_one(
|
|
&self,
|
|
call: ToolCall,
|
|
ctx: ToolContext,
|
|
) -> Result<ToolCallResult, ToolError> {
|
|
let handler = ctx
|
|
.registry()
|
|
.get(&call.name)
|
|
.ok_or_else(|| ToolError::NotFound(call.name.clone()))?
|
|
.clone();
|
|
|
|
let args = call.arguments_json()?;
|
|
|
|
match handler.execute(ctx, args).await {
|
|
Ok(value) => Ok(ToolCallResult::ok(call, value)),
|
|
Err(e) => Ok(ToolCallResult::error(call, e.to_string())),
|
|
}
|
|
}
|
|
|
|
pub fn to_tool_messages(results: &[ToolCallResult]) -> Vec<ChatCompletionRequestMessage> {
|
|
results
|
|
.iter()
|
|
.map(|r| {
|
|
let content = match &r.result {
|
|
ToolResult::Ok(v) => {
|
|
serde_json::to_string(v).unwrap_or_else(|_| "null".to_string())
|
|
}
|
|
ToolResult::Error(msg) => serde_json::to_string(&serde_json::json!({
|
|
"error": msg
|
|
}))
|
|
.unwrap_or_else(|_| r#"{"error":"unknown error"}"#.to_string()),
|
|
};
|
|
|
|
ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
|
|
tool_call_id: r.call.id.clone(),
|
|
content: ChatCompletionRequestToolMessageContent::Text(content),
|
|
})
|
|
})
|
|
.collect()
|
|
}
|
|
}
|