320 lines
12 KiB
Rust
320 lines
12 KiB
Rust
use std::collections::HashMap;
|
|
|
|
use super::agent_profile::{
|
|
analyst_profile, researcher_profile, reviewer_profile, should_enable_delegation,
|
|
supervisor_profile,
|
|
};
|
|
use super::message_builder::MessageBuilder;
|
|
use super::nonstreaming_execution::execute_process;
|
|
use super::service::{ProcessResult, StreamResult};
|
|
use super::{AiRequest, StreamCallback};
|
|
use crate::error::Result;
|
|
use crate::tool::call::ToolError;
|
|
use crate::tool::registry::ToolRegistry;
|
|
use crate::tool::{ToolDefinition, ToolHandler, ToolParam, ToolSchema};
|
|
|
|
pub async fn execute_orchestrated_process(
|
|
request: AiRequest,
|
|
message_builder: &MessageBuilder,
|
|
tool_registry: &Option<ToolRegistry>,
|
|
ai_base_url: Option<String>,
|
|
ai_api_key: Option<String>,
|
|
) -> Result<ProcessResult> {
|
|
if request
|
|
.execution_profile
|
|
.as_ref()
|
|
.is_some_and(|p| p.disable_orchestration)
|
|
{
|
|
return execute_process(
|
|
request,
|
|
message_builder,
|
|
tool_registry,
|
|
ai_base_url,
|
|
ai_api_key,
|
|
)
|
|
.await;
|
|
}
|
|
|
|
let tools = request.tools.clone().unwrap_or_default();
|
|
if !should_enable_delegation(&request.input, !tools.is_empty()) {
|
|
return execute_process(
|
|
request,
|
|
message_builder,
|
|
tool_registry,
|
|
ai_base_url,
|
|
ai_api_key,
|
|
)
|
|
.await;
|
|
}
|
|
|
|
let mut enhanced_registry = tool_registry.clone().unwrap_or_default();
|
|
register_call_sub_agent_tool(
|
|
&mut enhanced_registry,
|
|
&request,
|
|
message_builder,
|
|
tool_registry,
|
|
ai_base_url.clone(),
|
|
ai_api_key.clone(),
|
|
);
|
|
|
|
let mut supervisor_request = request.clone();
|
|
let profile = supervisor_profile();
|
|
supervisor_request.execution_profile = Some(profile.clone());
|
|
supervisor_request.tools = Some(enhanced_registry.to_openai_tools());
|
|
supervisor_request.temperature = profile.temperature.unwrap_or(request.temperature);
|
|
supervisor_request.max_tokens = profile.max_tokens.unwrap_or(request.max_tokens);
|
|
supervisor_request.top_p = profile.top_p.unwrap_or(request.top_p);
|
|
supervisor_request.frequency_penalty = profile
|
|
.frequency_penalty
|
|
.unwrap_or(request.frequency_penalty);
|
|
supervisor_request.presence_penalty = profile
|
|
.presence_penalty
|
|
.unwrap_or(request.presence_penalty);
|
|
|
|
execute_process(
|
|
supervisor_request,
|
|
message_builder,
|
|
&Some(enhanced_registry),
|
|
ai_base_url,
|
|
ai_api_key,
|
|
)
|
|
.await
|
|
}
|
|
|
|
pub async fn execute_orchestrated_stream(
|
|
request: AiRequest,
|
|
on_chunk: StreamCallback,
|
|
message_builder: &MessageBuilder,
|
|
tool_registry: &Option<ToolRegistry>,
|
|
ai_base_url: Option<String>,
|
|
ai_api_key: Option<String>,
|
|
) -> Result<StreamResult> {
|
|
if request
|
|
.execution_profile
|
|
.as_ref()
|
|
.is_some_and(|p| p.disable_orchestration)
|
|
{
|
|
return super::streaming_execution::execute_process_stream(
|
|
request,
|
|
on_chunk,
|
|
message_builder,
|
|
tool_registry,
|
|
ai_base_url,
|
|
ai_api_key,
|
|
)
|
|
.await;
|
|
}
|
|
|
|
let tools = request.tools.clone().unwrap_or_default();
|
|
if !should_enable_delegation(&request.input, !tools.is_empty()) {
|
|
return super::streaming_execution::execute_process_stream(
|
|
request,
|
|
on_chunk,
|
|
message_builder,
|
|
tool_registry,
|
|
ai_base_url,
|
|
ai_api_key,
|
|
)
|
|
.await;
|
|
}
|
|
|
|
let mut enhanced_registry = tool_registry.clone().unwrap_or_default();
|
|
register_call_sub_agent_tool(
|
|
&mut enhanced_registry,
|
|
&request,
|
|
message_builder,
|
|
tool_registry,
|
|
ai_base_url.clone(),
|
|
ai_api_key.clone(),
|
|
);
|
|
|
|
let mut supervisor_request = request.clone();
|
|
let profile = supervisor_profile();
|
|
supervisor_request.execution_profile = Some(profile.clone());
|
|
supervisor_request.tools = Some(enhanced_registry.to_openai_tools());
|
|
supervisor_request.temperature = profile.temperature.unwrap_or(request.temperature);
|
|
supervisor_request.max_tokens = profile.max_tokens.unwrap_or(request.max_tokens);
|
|
supervisor_request.top_p = profile.top_p.unwrap_or(request.top_p);
|
|
supervisor_request.frequency_penalty = profile
|
|
.frequency_penalty
|
|
.unwrap_or(request.frequency_penalty);
|
|
supervisor_request.presence_penalty = profile
|
|
.presence_penalty
|
|
.unwrap_or(request.presence_penalty);
|
|
|
|
super::streaming_execution::execute_process_stream(
|
|
supervisor_request,
|
|
on_chunk,
|
|
message_builder,
|
|
&Some(enhanced_registry),
|
|
ai_base_url,
|
|
ai_api_key,
|
|
)
|
|
.await
|
|
}
|
|
|
|
fn register_call_sub_agent_tool(
|
|
registry: &mut ToolRegistry,
|
|
request: &AiRequest,
|
|
message_builder: &MessageBuilder,
|
|
original_registry: &Option<ToolRegistry>,
|
|
ai_base_url: Option<String>,
|
|
ai_api_key: Option<String>,
|
|
) {
|
|
let captured_request = request.clone();
|
|
let captured_message_builder = message_builder.clone();
|
|
let captured_original_registry = original_registry.clone();
|
|
let captured_base_url = ai_base_url;
|
|
let captured_api_key = ai_api_key;
|
|
|
|
registry.register(
|
|
ToolDefinition::new("call_sub_agent")
|
|
.description(
|
|
"Delegate a task to a specialist sub-agent and receive its output.\n\
|
|
Available roles:\n\
|
|
- researcher: Gathers facts, evidence, and data. Best for finding information and searching code.\n\
|
|
- analyst: Builds explanations, highlights causal links and tradeoffs. Best for reasoning about implications.\n\
|
|
- reviewer: Stress-tests proposals, identifies risks and contradictions. Best for quality checks.\n\
|
|
Provide a clear, focused task description so the sub-agent knows exactly what to investigate.",
|
|
)
|
|
.parameters(ToolSchema {
|
|
schema_type: "object".into(),
|
|
properties: Some({
|
|
let mut p = HashMap::new();
|
|
p.insert(
|
|
"role".into(),
|
|
ToolParam {
|
|
name: "role".into(),
|
|
param_type: "string".into(),
|
|
description: Some(
|
|
"The sub-agent role to delegate to: researcher, analyst, or reviewer.".into(),
|
|
),
|
|
required: true,
|
|
properties: None,
|
|
items: None,
|
|
},
|
|
);
|
|
p.insert(
|
|
"task".into(),
|
|
ToolParam {
|
|
name: "task".into(),
|
|
param_type: "string".into(),
|
|
description: Some(
|
|
"The specific task or question for the sub-agent. Be precise and focused.".into(),
|
|
),
|
|
required: true,
|
|
properties: None,
|
|
items: None,
|
|
},
|
|
);
|
|
p
|
|
}),
|
|
required: Some(vec!["role".into(), "task".into()]),
|
|
}),
|
|
ToolHandler::new(move |_ctx, args| {
|
|
// Extract owned values from args before async move (avoid borrowing across boundary)
|
|
let role = args
|
|
.get("role")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("researcher")
|
|
.to_owned();
|
|
let task = args
|
|
.get("task")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("")
|
|
.to_owned();
|
|
|
|
let profile = match role.as_str() {
|
|
"researcher" => researcher_profile(),
|
|
"analyst" => analyst_profile(),
|
|
"reviewer" => reviewer_profile(),
|
|
_ => researcher_profile(),
|
|
};
|
|
|
|
let mut sub_request = captured_request.clone();
|
|
sub_request.input = format!(
|
|
"Sub-agent role: {role}\n\nTask:\n{task}\n\nOriginal user request:\n{}\n\nInstructions:\nFocus only on your assigned task. Return concise, evidence-backed findings.",
|
|
captured_request.input
|
|
);
|
|
sub_request.execution_profile = Some(profile.clone());
|
|
sub_request.tools = Some(filter_tools_for_sub_agent(
|
|
&captured_request.tools,
|
|
&profile.allowed_tools,
|
|
));
|
|
sub_request.max_tool_depth = profile
|
|
.max_tool_depth
|
|
.unwrap_or(captured_request.max_tool_depth);
|
|
sub_request.temperature = profile.temperature.unwrap_or(captured_request.temperature);
|
|
sub_request.max_tokens = profile.max_tokens.unwrap_or(captured_request.max_tokens);
|
|
sub_request.top_p = profile.top_p.unwrap_or(captured_request.top_p);
|
|
sub_request.frequency_penalty = profile
|
|
.frequency_penalty
|
|
.unwrap_or(captured_request.frequency_penalty);
|
|
sub_request.presence_penalty = profile
|
|
.presence_penalty
|
|
.unwrap_or(captured_request.presence_penalty);
|
|
|
|
// Clone captured values for this invocation so the Fn closure retains them
|
|
let mb = captured_message_builder.clone();
|
|
let sub_registry = captured_original_registry.clone();
|
|
let base = captured_base_url.clone();
|
|
let key = captured_api_key.clone();
|
|
|
|
Box::pin(async move {
|
|
let result = execute_process(sub_request, &mb, &sub_registry, base, key).await;
|
|
match result {
|
|
Ok(r) => Ok(serde_json::json!({
|
|
"role": role,
|
|
"output": r.content,
|
|
"input_tokens": r.input_tokens,
|
|
"output_tokens": r.output_tokens,
|
|
})),
|
|
Err(e) => Err(ToolError::ExecutionError(format!(
|
|
"Sub-agent '{}' execution failed: {}",
|
|
role, e
|
|
))),
|
|
}
|
|
})
|
|
}),
|
|
);
|
|
}
|
|
|
|
/// Filter the original tool definitions by the sub-agent's allowed list,
|
|
/// always excluding `call_sub_agent` to prevent recursive delegation.
|
|
fn filter_tools_for_sub_agent(
|
|
original_tools: &Option<Vec<serde_json::Value>>,
|
|
allowed_tools: &Option<Vec<String>>,
|
|
) -> Vec<serde_json::Value> {
|
|
let Some(tools) = original_tools else {
|
|
return Vec::new();
|
|
};
|
|
let allowed = allowed_tools
|
|
.as_ref()
|
|
.map(|list| list.iter().filter(|n| *n != "call_sub_agent").cloned().collect::<Vec<String>>());
|
|
|
|
match allowed {
|
|
Some(allowed_list) if !allowed_list.is_empty() => tools
|
|
.iter()
|
|
.filter(|tool| {
|
|
let name = tool
|
|
.get("function")
|
|
.and_then(|f| f.get("name"))
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("");
|
|
allowed_list.iter().any(|allowed| allowed == name)
|
|
})
|
|
.cloned()
|
|
.collect(),
|
|
_ => tools
|
|
.iter()
|
|
.filter(|tool| {
|
|
tool
|
|
.get("function")
|
|
.and_then(|f| f.get("name"))
|
|
.and_then(|v| v.as_str())
|
|
.is_some_and(|name| name != "call_sub_agent")
|
|
})
|
|
.cloned()
|
|
.collect(),
|
|
}
|
|
} |