gitdataai/libs/agent/chat/chat_execution.rs
ZhenYi 8d144ac139 feat(agent): add architect, debugger, implementer, tester, security sub-agent roles
Extend delegation system with 5 new specialized roles alongside
researcher/analyst/reviewer. Each role has curated tool access.
Refactor profile lookup to use profile_for_role_name and update
compact/summarizer and tool context accordingly.
2026-05-18 20:42:57 +08:00

1257 lines
48 KiB
Rust

use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
use super::agent_profile::profile_for_role_name;
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream};
use crate::embed::EmbedService;
use crate::error::Result;
use crate::tool::registry::ToolRegistry;
use crate::tool::{
ToolCall as AgentToolCall, ToolContext, ToolDefinition, ToolExecutor, ToolHandler, ToolParam,
};
use sea_orm::{ActiveModelTrait, EntityTrait, Set};
use super::service::StreamResult;
use super::{AiChunkType, AiStreamChunk, StreamCallback};
/// Outcome of a single sub-agent run, as produced by `call_sub_agent_stream`.
struct SubAgentRunResult {
    /// Final answer text; on early exits (error, timeout, cancellation, tool-depth
    /// exhaustion) this is the partial or fallback text instead.
    output: String,
    /// Input (prompt) tokens reported for this run's model calls.
    input_tokens: i64,
    /// Output (completion) tokens reported for this run's model calls.
    output_tokens: i64,
    /// `true` when the run was stopped through the cache cancellation flag.
    cancelled: bool,
    /// Failure description (model error, timeout, or tool-depth exhaustion); `None` on success.
    error: Option<String>,
}
/// Persist a sub-agent session record to the database.
///
/// Inserts one `ai_subagent_session` row (fresh v7 id, current UTC `created_at`) through
/// the writer connection. An insert failure is only logged at `warn` level, so a
/// persistence problem never fails the caller.
///
/// NOTE(review): `message_id` is always stored as the nil UUID — confirm whether the
/// triggering message id should be linked here.
async fn persist_sub_agent_session(
    db: &db::database::AppDatabase,
    conversation_id: Uuid,
    children_id: &str,
    role: &str,
    task: &str,
    output: &str,
    input_tokens: i64,
    output_tokens: i64,
    model_name: &str,
    status: &str,
    error_message: Option<String>,
) {
    use models::ai::ai_subagent_session::ActiveModel as SubAgentSessionRow;
    use sea_orm::{ActiveModelTrait, Set};

    let insert_result = SubAgentSessionRow {
        id: Set(Uuid::now_v7()),
        conversation_id: Set(conversation_id),
        message_id: Set(Uuid::nil()),
        children_id: Set(children_id.to_owned()),
        role: Set(role.to_owned()),
        task: Set(task.to_owned()),
        output: Set(output.to_owned()),
        input_tokens: Set(input_tokens),
        output_tokens: Set(output_tokens),
        model_name: Set(Some(model_name.to_owned())),
        status: Set(status.to_owned()),
        error_message: Set(error_message),
        created_at: Set(chrono::Utc::now()),
    }
    .insert(db.writer())
    .await;

    if let Err(e) = insert_result {
        tracing::warn!(
            error = %e,
            children_id = %children_id,
            "failed to persist sub-agent session"
        );
    }
}
/// Execute a sub-agent call with streaming output via NATS.
///
/// The sub-agent output is streamed to NATS JetStream subject
/// `chat.subagent.chunk.{conversation_id}.{children_id}` so the frontend
/// can subscribe via the `/api/ai/subagent/{conversation_id}/{children_id}/stream` endpoint.
///
/// Returns the full or partial output after the sub-agent completes or is cancelled.
async fn call_sub_agent_stream(
messages: &[ChatRequestMessage],
model_name: &str,
config: &AiClientConfig,
temperature: f32,
max_tokens: u32,
max_tool_depth: usize,
tools: Option<&[serde_json::Value]>,
tool_registry: Option<ToolRegistry>,
db: db::database::AppDatabase,
app_config: config::AppConfig,
project_id: Uuid,
sender_uid: Uuid,
embed_service: Option<EmbedService>,
children_id: &str,
conversation_id: Option<uuid::Uuid>,
cache: db::cache::AppCache,
queue_producer: Option<&queue::MessageProducer>,
) -> Result<SubAgentRunResult> {
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::mpsc;
let conversation_id = conversation_id.unwrap_or_default();
let seq = Arc::new(AtomicU64::new(0));
let children_id_owned = children_id.to_string();
let queue_ref = queue_producer.cloned();
let partial_output = Arc::new(tokio::sync::Mutex::new(String::new()));
let (delta_tx, mut delta_rx) = mpsc::unbounded_channel::<(&'static str, String)>();
cache
.clear_sub_agent_cancelled(conversation_id, &children_id_owned)
.await;
let stream_fut = async {
let mut messages = messages.to_vec();
let mut total_input_tokens = 0i64;
let mut total_output_tokens = 0i64;
let mut last_content = String::new();
let mut tool_depth = 0usize;
loop {
let response = call_stream(
&messages,
model_name,
config,
temperature,
max_tokens,
tools,
None,
Arc::new({
let partial_output = partial_output.clone();
let delta_tx = delta_tx.clone();
move |delta| {
let content = delta.to_string();
let partial_output = partial_output.clone();
let delta_tx = delta_tx.clone();
Box::pin(async move {
partial_output.lock().await.push_str(&content);
let _ = delta_tx.send(("token", content));
})
as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}
}),
Arc::new({
let delta_tx = delta_tx.clone();
move |delta| {
let content = delta.to_string();
let delta_tx = delta_tx.clone();
Box::pin(async move {
let _ = delta_tx.send(("thinking", content));
})
as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}
}),
Arc::new(move |_tc: &StreamedToolCall| {
Box::pin(async move {}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}),
)
.await?;
total_input_tokens += response.input_tokens;
total_output_tokens += response.output_tokens;
if !response.content.is_empty() {
last_content = response.content.clone();
}
if response.tool_calls.is_empty() {
return Ok::<SubAgentRunResult, crate::error::AgentError>(SubAgentRunResult {
output: response.content,
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
cancelled: false,
error: None,
});
}
if tool_depth >= max_tool_depth {
let fallback_output = if last_content.is_empty() {
"Sub-agent reached maximum tool depth before producing a final summary."
.to_string()
} else {
last_content
};
return Ok::<SubAgentRunResult, crate::error::AgentError>(SubAgentRunResult {
output: fallback_output,
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
cancelled: false,
error: Some(format!(
"sub-agent reached maximum tool depth ({max_tool_depth}) before final summary"
)),
});
}
let assistant_tool_calls: Vec<ToolCall> = response
.tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.id.clone(),
type_: "function".into(),
function: crate::client::types::ToolCallFunction {
name: tc.name.clone(),
arguments: tc.arguments.clone(),
},
})
.collect();
messages.push(ChatRequestMessage::assistant(
Some(response.content.clone()),
Some(assistant_tool_calls),
));
let agent_tool_calls: Vec<AgentToolCall> = response
.tool_calls
.iter()
.map(|tc| AgentToolCall {
id: tc.id.clone(),
name: tc.name.clone(),
arguments: tc.arguments.clone(),
})
.collect();
let tool_messages = execute_sub_agent_tools(
&agent_tool_calls,
&db,
&cache,
&app_config,
project_id,
sender_uid,
tool_registry.as_ref(),
embed_service.as_ref(),
)
.await;
messages.extend(tool_messages);
messages.push(ChatRequestMessage::user(
"Use the tool results above to produce your final concise findings. Do not call another tool unless it is strictly necessary.",
));
tool_depth += 1;
}
};
let children_id_for_cancel = children_id_owned.clone();
let cancel_fut = async {
let mut interval = tokio::time::interval(std::time::Duration::from_millis(100));
loop {
interval.tick().await;
if cache
.is_sub_agent_cancelled(conversation_id, &children_id_for_cancel)
.await
{
break;
}
}
};
let timeout_fut = tokio::time::sleep(std::time::Duration::from_secs(60));
let flush_queue = queue_ref.clone();
let flush_children_id = children_id_owned.clone();
let flush_seq = seq.clone();
let mut flush_handle = tokio::spawn(async move {
let Some(queue) = flush_queue else {
while delta_rx.recv().await.is_some() {}
return;
};
let mut token_buf = String::new();
let mut thinking_buf = String::new();
let mut interval = tokio::time::interval(std::time::Duration::from_millis(50));
async fn flush(
queue: &queue::MessageProducer,
conversation_id: Uuid,
children_id: &str,
seq: &Arc<AtomicU64>,
chunk_type: &str,
buffer: &mut String,
) {
if buffer.is_empty() {
return;
}
let event = queue::types::SubAgentStreamChunkEvent {
conversation_id,
children_id: children_id.to_string(),
seq: seq.fetch_add(1, Ordering::Relaxed),
content: std::mem::take(buffer),
done: false,
error: None,
chunk_type: Some(chunk_type.to_string()),
role: String::new(),
task: String::new(),
};
queue.publish_sub_agent_chunk_realtime(&event).await;
}
loop {
tokio::select! {
Some((kind, content)) = delta_rx.recv() => {
let target = if kind == "thinking" { &mut thinking_buf } else { &mut token_buf };
target.push_str(&content);
if target.len() >= 240 {
flush(&queue, conversation_id, &flush_children_id, &flush_seq, kind, target).await;
}
}
_ = interval.tick() => {
flush(&queue, conversation_id, &flush_children_id, &flush_seq, "thinking", &mut thinking_buf).await;
flush(&queue, conversation_id, &flush_children_id, &flush_seq, "token", &mut token_buf).await;
}
else => break,
}
}
flush(
&queue,
conversation_id,
&flush_children_id,
&flush_seq,
"thinking",
&mut thinking_buf,
)
.await;
flush(
&queue,
conversation_id,
&flush_children_id,
&flush_seq,
"token",
&mut token_buf,
)
.await;
});
let response = tokio::select! {
result = stream_fut => {
match result {
Ok(response) => Some(response),
Err(e) => Some(SubAgentRunResult {
output: partial_output.lock().await.clone(),
input_tokens: 0,
output_tokens: 0,
cancelled: false,
error: Some(e.to_string()),
}),
}
}
_ = cancel_fut => None,
_ = timeout_fut => Some(SubAgentRunResult {
output: partial_output.lock().await.clone(),
input_tokens: 0,
output_tokens: 0,
cancelled: false,
error: Some("sub-agent timed out after 60 seconds".to_string()),
}),
};
drop(delta_tx);
if tokio::time::timeout(std::time::Duration::from_secs(2), &mut flush_handle)
.await
.is_err()
{
flush_handle.abort();
tracing::warn!(
children_id = %children_id,
"sub-agent stream flush timed out; continuing with terminal event"
);
}
let cancelled = response.is_none();
let (total_content, total_input_tokens, total_output_tokens, terminal_error) = match response {
Some(response) => (
response.output.clone(),
response.input_tokens,
response.output_tokens,
response.error.clone(),
),
None => (partial_output.lock().await.clone(), 0, 0, None),
};
// Send final done/stopped chunk.
let final_seq = seq.load(Ordering::Relaxed);
let event = queue::types::SubAgentStreamChunkEvent {
conversation_id,
children_id: children_id_owned,
seq: final_seq,
content: String::new(),
done: true,
error: terminal_error.clone(),
chunk_type: Some(
if terminal_error.is_some() {
"error"
} else if cancelled {
"stopped"
} else {
"done"
}
.to_string(),
),
role: String::new(),
task: String::new(),
};
if let Some(q) = queue_ref {
if tokio::time::timeout(
std::time::Duration::from_secs(1),
q.publish_sub_agent_chunk(&event),
)
.await
.is_err()
{
tracing::warn!(
children_id = %event.children_id,
"sub-agent terminal event publish timed out"
);
}
}
Ok(SubAgentRunResult {
output: total_content,
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
cancelled,
error: terminal_error,
})
}
// Keyword-extraction-based title generator: reads the newest user message of a
// conversation, extracts meaningful words, and updates the conversation record with a
// short title.
//
// Returns `{ "conversation_id": ..., "title": ... }` on success.
// Errors:
// - `AgentError::NotFound` when the conversation does not exist;
// - `AgentError::ToolExecutionFailed` on database errors or when the conversation has no
//   user messages.
async fn generate_title_for_conversation(
    ctx: &ToolContext,
    conversation_id: Uuid,
) -> Result<serde_json::Value> {
    use models::ai::{AiMessage, ai_conversation, ai_message};
    use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder};
    let db_reader = ctx.db().reader();
    let db_writer = ctx.db().writer();
    let conv = ai_conversation::Entity::find_by_id(conversation_id)
        .one(db_reader)
        .await
        .map_err(|e| crate::error::AgentError::ToolExecutionFailed {
            tool: "generate_title".into(),
            cause: format!("db error: {}", e),
        })?
        .ok_or_else(|| crate::error::AgentError::NotFound("Conversation not found".into()))?;
    // Only the newest user message feeds the title, so fetch just that row.
    let latest_user_message = AiMessage::find()
        .filter(ai_message::Column::ConversationId.eq(conversation_id))
        .filter(ai_message::Column::Role.eq("user"))
        .order_by_desc(ai_message::Column::CreatedAt)
        .one(db_reader)
        .await
        .map_err(|e| crate::error::AgentError::ToolExecutionFailed {
            tool: "generate_title".into(),
            cause: format!("db error: {}", e),
        })?;
    let Some(latest_user_message) = latest_user_message else {
        return Err(crate::error::AgentError::ToolExecutionFailed {
            tool: "generate_title".into(),
            cause: "No user messages found".into(),
        });
    };
    // Text of the first content part; anything missing or non-textual counts as empty.
    let content = latest_user_message
        .content
        .as_array()
        .and_then(|arr| arr.first())
        .and_then(|v| v.get("content"))
        .and_then(|c| c.as_str())
        .unwrap_or("");
    let title = keyword_title_from_text(content);
    let mut active: ai_conversation::ActiveModel = conv.into();
    active.title = Set(Some(title.clone()));
    active.updated_at = Set(chrono::Utc::now());
    active
        .update(db_writer)
        .await
        .map_err(|e| crate::error::AgentError::ToolExecutionFailed {
            tool: "generate_title".into(),
            cause: format!("failed to update title: {}", e),
        })?;
    Ok(serde_json::json!({ "conversation_id": conversation_id.to_string(), "title": title }))
}
/// Builds a short title from free text.
///
/// Steps:
/// 1. Split on whitespace and strip leading/trailing non-alphanumeric characters, so
///    "help?" is still recognised as a stop word.
/// 2. Keep up to five words longer than two bytes that are not stop words.
/// 3. Join them with spaces and cap the result at 60 characters.
///
/// Returns `"New Chat"` when no word qualifies.
fn keyword_title_from_text(content: &str) -> String {
    const MAX_WORDS: usize = 5;
    // Text without spaces (e.g. CJK) can be one huge "word"; keep the title short anyway.
    const MAX_CHARS: usize = 60;
    let words: Vec<&str> = content
        .split_whitespace()
        .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
        // `len()` counts bytes, so short multi-byte words (e.g. two CJK characters) still pass.
        .filter(|w| w.len() > 2 && !is_stop_word(w))
        .take(MAX_WORDS)
        .collect();
    if words.is_empty() {
        return "New Chat".to_string();
    }
    let title = words.join(" ");
    // Cut on a char boundary so multi-byte text never panics on slicing.
    match title.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => title[..cut].trim_end().to_string(),
        None => title,
    }
}
/// Reports whether `w` is a filler word that the keyword title generator skips.
///
/// Matching is case-insensitive via Unicode lowercasing (`str::to_lowercase`), so
/// "The", "THE" and "the" all match.
fn is_stop_word(w: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "this", "that", "what", "which", "when", "where", "why", "how", "can", "could",
        "would", "should", "please", "help", "thanks", "thank", "you", "your", "have", "has",
        "had", "with", "for", "from", "into", "about", "also", "just", "now", "very", "really",
    ];
    let lowered = w.to_lowercase();
    STOP_WORDS.iter().any(|stop| *stop == lowered.as_str())
}
/// Reference-counted form of the chat stream callback.
///
/// It receives each `AiStreamChunk` and returns a boxed future that the caller awaits.
/// `execute_chat_stream` converts its `StreamCallback` argument into this type with
/// `Arc::from`. That lets the streaming closures and `execute_tools` share and clone it.
type SharedCallback = Arc<
    dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync,
>;
/// Simplified ReAct execution for Chat API.
///
/// Unlike `execute_process_stream` (which requires `AiRequest` with room-specific data),
/// this function takes messages and tools directly. It does NOT record AI sessions to
/// the `ai_session` table -the caller is responsible for persisting results.
pub async fn execute_chat_stream(
messages: Vec<ChatRequestMessage>,
tools: Vec<serde_json::Value>,
model_name: &str,
config: &AiClientConfig,
temperature: f32,
max_tokens: u32,
max_tool_depth: usize,
tool_registry: Option<&ToolRegistry>,
db: db::database::AppDatabase,
cache: db::cache::AppCache,
app_config: config::AppConfig,
project_id: Uuid,
sender_uid: Uuid,
embed_service: Option<EmbedService>,
on_chunk: StreamCallback,
conversation_id: Option<uuid::Uuid>,
queue_producer: Option<queue::MessageProducer>,
) -> Result<StreamResult> {
let on_chunk: SharedCallback = Arc::from(on_chunk);
let tools_enabled = !tools.is_empty();
let mut messages = messages;
let mut tool_depth = 0;
let mut total_input_tokens = 0i64;
let mut total_output_tokens = 0i64;
let mut full_content = String::new();
let mut all_chunks: Vec<StreamChunk> = Vec::new();
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
// Conditionally inject chat_generate_title tool if conversation has no title
let (tools, _tools_injected) = if let Some(conv_id) = conversation_id {
if let Some(registry) = tool_registry {
let db_reader = db.reader();
let has_title = models::ai::ai_conversation::Entity::find_by_id(conv_id)
.one(db_reader)
.await
.map(|c| c.map(|m| m.title.is_some()).unwrap_or(false))
.unwrap_or(false);
if !has_title {
let mut reg = registry.clone();
reg.register(
ToolDefinition::new("chat_generate_title")
.description(
"Generate a concise title (5 words or fewer) for the current conversation \
based on its message history, and save it to the conversation record. \
Call this tool at the start of a new conversation if it has no title.",
)
.parameters(crate::tool::ToolSchema {
schema_type: "object".into(),
properties: Some({
let mut p = std::collections::HashMap::new();
p.insert("conversation_id".into(), ToolParam {
name: "conversation_id".into(),
param_type: "string".into(),
description: Some("The UUID of the conversation (required).".into()),
required: true,
properties: None,
items: None,
});
p
}),
required: Some(vec!["conversation_id".into()]),
}),
ToolHandler::new(|ctx, args| {
let conv_id = args.get("conversation_id")
.and_then(|v| v.as_str())
.and_then(|s| Uuid::parse_str(s).ok());
Box::pin(async move {
match conv_id {
Some(id) => generate_title_for_conversation(&ctx, id).await
.map_err(|e| crate::tool::ToolError::ExecutionError(e.to_string())),
None => Err(crate::tool::ToolError::ExecutionError("conversation_id missing".into())),
}
})
}),
);
// Prepend system message instructing the model to generate title first
messages.insert(0, ChatRequestMessage::system(
"IMPORTANT: If the conversation has no title, you MUST call chat_generate_title \
with the conversation_id immediately before answering any user question. \
The title must be 5 words or fewer and should summarize the user's intent.".to_string(),
));
(reg.to_openai_tools(), true)
} else {
(tools.clone(), false)
}
} else {
(tools.clone(), false)
}
} else {
(tools.clone(), false)
};
// Add call_sub_agent tool for chat orchestration when tools are available
let tools = if tools_enabled {
let mut t = tools;
t.push(serde_json::json!({
"type": "function",
"function": {
"name": "call_sub_agent",
"description": "Delegate a task to a specialist sub-agent and receive its output.\nAvailable roles:\n- researcher: Gathers facts, evidence, and data. Best for finding information and searching code.\n- analyst: Builds explanations, highlights causal links and tradeoffs. Best for reasoning about implications.\n- reviewer: Stress-tests proposals, identifies risks and contradictions. Best for quality checks.\n- architect: Maps systems, dependencies, boundaries, and design tradeoffs. Best for architecture decisions.\n- debugger: Finds root causes, suspect changes, and validation paths. Best for bugs and regressions.\n- implementer: Converts requirements into concrete implementation steps. Best for execution planning.\n- tester: Designs validation and regression coverage. Best for test strategy.\n- security: Reviews auth, data exposure, injection, dependency, and abuse risks. Best for sensitive changes.\nProvide a clear, focused task description so the sub-agent knows exactly what to investigate.",
"parameters": {
"type": "object",
"properties": {
"role": {
"type": "string",
"description": "The sub-agent role to delegate to: researcher, analyst, reviewer, architect, debugger, implementer, tester, or security."
},
"task": {
"type": "string",
"description": "The specific task or question for the sub-agent. Be precise and focused."
}
},
"required": ["role", "task"]
}
}
}));
t
} else {
tools
};
loop {
let on_chunk_cb = on_chunk.clone();
let on_chunk_cb2 = on_chunk.clone();
let tx_arc = Arc::new(tx.clone());
let tx_arc2 = tx_arc.clone();
let response = call_stream(
&messages,
model_name,
config,
temperature,
max_tokens,
if tools_enabled { Some(&tools) } else { None },
None,
Arc::new(move |delta| {
let content = delta.to_string();
let fut = on_chunk_cb(AiStreamChunk {
content,
done: false,
chunk_type: AiChunkType::Answer,
metadata: None,
children_id: None,
});
fut
}),
Arc::new(move |delta| {
let fut = on_chunk_cb2(AiStreamChunk {
content: delta.to_string(),
done: false,
chunk_type: AiChunkType::Thinking,
metadata: None,
children_id: None,
});
fut
}),
Arc::new(move |tc: &StreamedToolCall| {
let tx = tx_arc2.clone();
let tc_owned = tc.clone();
Box::pin(async move {
let _ = tx.send(tc_owned);
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}),
)
.await?;
total_input_tokens += response.input_tokens;
total_output_tokens += response.output_tokens;
all_chunks.extend(response.chunks.clone());
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
if !has_tool_calls {
let final_content = response.content.clone();
// Don't push full content as a chunk -incremental deltas in
// response.chunks (already added above) sum to the same text.
// merge_consecutive_blocks would concatenate delta_sum + full =
// 2x full, causing duplicate content in DB persistence.
return Ok(StreamResult {
content: final_content,
reasoning_content: response.reasoning_content,
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
chunks: all_chunks,
});
}
full_content.push_str(&response.content);
let tool_calls: Vec<ToolCall> = response
.tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.id.clone(),
type_: "function".into(),
function: crate::client::types::ToolCallFunction {
name: tc.name.clone(),
arguments: tc.arguments.clone(),
},
})
.collect();
messages.push(ChatRequestMessage::assistant(
Some(response.content.clone()),
Some(tool_calls.clone()),
));
// Drain tool call notifications
loop {
match rx.try_recv() {
Ok(tc) => {
let args_display = if tc.arguments.len() > 100 {
let end = tc
.arguments
.char_indices()
.map(|(i, _)| i)
.take_while(|&i| i <= 100)
.last()
.unwrap_or(100);
format!("{}...", &tc.arguments[..end])
} else {
tc.arguments.clone()
};
let tool_display = format!("[tool] {}({})", tc.name, args_display);
on_chunk(AiStreamChunk {
content: tool_display.clone(),
done: false,
chunk_type: AiChunkType::ToolCall,
metadata: None,
children_id: None,
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: tool_display,
});
}
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
}
}
let (sub_agent_calls, regular_calls): (Vec<AgentToolCall>, Vec<AgentToolCall>) = response
.tool_calls
.iter()
.map(|tc| AgentToolCall {
id: tc.id.clone(),
name: tc.name.clone(),
arguments: tc.arguments.clone(),
})
.collect::<Vec<_>>()
.into_iter()
.partition(|c| c.name == "call_sub_agent");
let mut tool_messages = Vec::new();
let mut sub_agent_tasks = tokio::task::JoinSet::new();
let mut sub_agent_ids: Vec<String> = Vec::new();
// Handle call_sub_agent calls inline -stream sub-agent output via NATS
for sub_call in sub_agent_calls {
let args: serde_json::Value = match serde_json::from_str(&sub_call.arguments) {
Ok(v) => v,
Err(_) => {
tool_messages.push(ChatRequestMessage::tool(
&sub_call.id,
"Failed to parse call_sub_agent arguments",
));
continue;
}
};
let role = args
.get("role")
.and_then(|v| v.as_str())
.unwrap_or("researcher");
let task = args.get("task").and_then(|v| v.as_str()).unwrap_or("");
let profile = profile_for_role_name(role);
// Generate children_id BEFORE starting sub-agent execution
let sub_agent_id = format!("sub-agent-{}", Uuid::now_v7());
sub_agent_ids.push(sub_agent_id.clone());
// Emit tool_call chunk immediately with children_id so frontend can start watching
let call_display =
format!("[tool] call_sub_agent({role}) - delegating to {role} agent...");
on_chunk(AiStreamChunk {
content: call_display.clone(),
done: false,
chunk_type: AiChunkType::ToolCall,
metadata: Some(serde_json::json!({
"tool": "call_sub_agent",
"args": { "role": role.to_string(), "task": task.to_string() },
"display": call_display,
})),
children_id: Some(sub_agent_id.clone()),
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: call_display,
});
let sub_system = profile.system_prompt.clone().unwrap_or_default();
let sub_messages = vec![
ChatRequestMessage::system(sub_system),
ChatRequestMessage::user(format!(
"Sub-agent role: {role}\n\nTask:\n{task}\n\nFocus only on your assigned task. Return concise, evidence-backed findings."
)),
];
// Filter tools for the sub-agent: only include tools in the profile's allowed list,
// always excluding call_sub_agent and chat_generate_title
let sub_tools: Vec<serde_json::Value> = if let Some(ref allowed) = profile.allowed_tools
{
tools
.iter()
.filter(|t| {
let name = t
.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str())
.unwrap_or("");
allowed.contains(&name.to_string())
&& name != "call_sub_agent"
&& name != "chat_generate_title"
})
.cloned()
.collect()
} else {
tools
.iter()
.filter(|t| {
let name = t
.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str())
.unwrap_or("");
name != "call_sub_agent" && name != "chat_generate_title"
})
.cloned()
.collect()
};
let call_id = sub_call.id.clone();
let role_owned = role.to_string();
let task_owned = task.to_string();
let sub_agent_id_owned = sub_agent_id.clone();
let model_name_owned = model_name.to_string();
let config_owned = config.clone();
let cache_owned = cache.clone();
let db_owned = db.clone();
let app_config_owned = app_config.clone();
let embed_service_owned = embed_service.clone();
let tool_registry_owned = tool_registry.cloned();
let queue_owned = queue_producer.clone();
let conversation_id_owned = conversation_id;
let temperature = profile.temperature.unwrap_or(0.7) as f32;
let max_tokens = profile.max_tokens.unwrap_or(4096) as u32;
let sub_max_tool_depth = profile.max_tool_depth.unwrap_or(4) as usize;
sub_agent_tasks.spawn(async move {
let result = call_sub_agent_stream(
&sub_messages,
&model_name_owned,
&config_owned,
temperature,
max_tokens,
sub_max_tool_depth,
Some(&sub_tools),
tool_registry_owned,
db_owned,
app_config_owned,
project_id,
sender_uid,
embed_service_owned,
&sub_agent_id_owned,
conversation_id_owned,
cache_owned,
queue_owned.as_ref(),
)
.await;
(
call_id,
sub_agent_id_owned,
role_owned,
task_owned,
model_name_owned,
result,
)
});
}
let mut cancelled_batch = false;
while let Some(joined) = sub_agent_tasks.join_next().await {
let Ok((call_id, sub_agent_id, role, task, sub_model_name, result)) = joined else {
continue;
};
match result {
Ok(result) => {
if result.cancelled && !cancelled_batch {
cancelled_batch = true;
if let Some(conv_id) = conversation_id {
for id in &sub_agent_ids {
if id != &sub_agent_id {
cache.set_sub_agent_cancelled(conv_id, id).await;
}
}
}
}
let status = if result.error.is_some() {
"error"
} else if result.cancelled {
"stopped"
} else {
"ok"
};
let output = result.output.clone();
persist_sub_agent_session(
&db,
conversation_id.unwrap_or_default(),
&sub_agent_id,
&role,
&task,
&output,
result.input_tokens,
result.output_tokens,
&sub_model_name,
status,
result.error.clone(),
)
.await;
let display = if result.error.is_some() {
format!("Sub-agent failed ({role})")
} else if result.cancelled {
format!("Sub-agent stopped ({role})")
} else {
format!("Sub-agent completed ({role})")
};
on_chunk(AiStreamChunk {
content: display.clone(),
done: false,
chunk_type: AiChunkType::ToolResult,
metadata: Some(serde_json::json!({
"tool": "call_sub_agent",
"role": role.clone(),
"task": task.clone(),
"output": output.clone(),
"input_tokens": result.input_tokens,
"output_tokens": result.output_tokens,
"error": result.error.clone(),
"status": status,
"display": display.clone(),
})),
children_id: Some(sub_agent_id.clone()),
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolResult,
content: serde_json::json!({
"tool": "call_sub_agent",
"role": role.clone(),
"task": task.clone(),
"output": output.clone(),
"input_tokens": result.input_tokens,
"output_tokens": result.output_tokens,
"error": result.error.clone(),
"status": status,
"display": display.clone(),
"children_id": sub_agent_id.clone(),
})
.to_string(),
});
let tool_content = if let Some(err) = &result.error {
format!(
"{}\n\n[sub_agent_status={} input_tokens={} output_tokens={} error={}]",
output, status, result.input_tokens, result.output_tokens, err
)
} else {
format!(
"{}\n\n[sub_agent_status={} input_tokens={} output_tokens={}]",
output, status, result.input_tokens, result.output_tokens
)
};
tool_messages.push(ChatRequestMessage::tool(&call_id, tool_content));
}
Err(e) => {
let err_msg = format!("Sub-agent ({role}) failed: {}", e);
let display = format!("Sub-agent failed ({role})");
let result_json = serde_json::json!({
"tool": "call_sub_agent",
"role": role,
"status": "error",
"error": err_msg,
"display": display,
})
.to_string();
on_chunk(AiStreamChunk {
content: display.clone(),
done: false,
chunk_type: AiChunkType::ToolResult,
metadata: None,
children_id: Some(sub_agent_id),
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolResult,
content: result_json,
});
tool_messages.push(ChatRequestMessage::tool(&call_id, &err_msg));
}
}
}
// Handle regular tool calls via ToolExecutor
if !regular_calls.is_empty() {
let regular_tool_messages = execute_tools(
&regular_calls,
&db,
&cache,
&app_config,
project_id,
sender_uid,
tool_registry,
embed_service.as_ref(),
&on_chunk,
&mut all_chunks,
)
.await;
tool_messages.extend(regular_tool_messages);
}
messages.extend(tool_messages);
tool_depth += 1;
if tool_depth >= max_tool_depth {
let max_depth_text = format!(
"[AI reached maximum tool depth ({}) - no final answer produced]",
max_tool_depth
);
on_chunk(AiStreamChunk {
content: max_depth_text.clone(),
done: true,
chunk_type: AiChunkType::Answer,
metadata: None,
children_id: None,
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::Answer,
content: max_depth_text,
});
return Ok(StreamResult {
content: full_content,
reasoning_content: String::new(),
input_tokens: 0,
output_tokens: 0,
chunks: all_chunks,
});
}
}
}
async fn execute_sub_agent_tools(
calls: &[AgentToolCall],
db: &db::database::AppDatabase,
cache: &db::cache::AppCache,
app_config: &config::AppConfig,
project_id: Uuid,
sender_uid: Uuid,
tool_registry: Option<&ToolRegistry>,
embed_service: Option<&EmbedService>,
) -> Vec<ChatRequestMessage> {
let mut tool_messages = Vec::new();
let mut ctx = ToolContext::new(
db.clone(),
cache.clone(),
app_config.clone(),
Uuid::nil(),
Some(sender_uid),
)
.with_project(project_id);
if let Some(es) = embed_service {
ctx = ctx.with_embed_service(es.clone());
}
if let Some(registry) = tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let mut join_set = tokio::task::JoinSet::new();
for call in calls {
let call_clone = call.clone();
let mut ctx_clone = ctx.clone();
join_set.spawn(async move {
let executor = ToolExecutor::new();
let tool_name = call_clone.name.clone();
let res = match tokio::time::timeout(
std::time::Duration::from_secs(45),
executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone),
)
.await
{
Ok(res) => res,
Err(_) => Err(crate::tool::ToolError::ExecutionError(format!(
"tool '{}' timed out after 45 seconds",
tool_name
))),
};
(call_clone, res)
});
}
while let Some(res) = join_set.join_next().await {
let Ok((call, results)) = res else {
continue;
};
match results {
Ok(results) => tool_messages.extend(ToolExecutor::to_tool_messages(&results)),
Err(e) => {
let err_text = format!("[Sub-agent tool call failed: {}]", e);
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
}
}
}
tool_messages
}
async fn execute_tools(
calls: &[AgentToolCall],
db: &db::database::AppDatabase,
cache: &db::cache::AppCache,
app_config: &config::AppConfig,
project_id: Uuid,
sender_uid: Uuid,
tool_registry: Option<&ToolRegistry>,
embed_service: Option<&EmbedService>,
on_chunk: &SharedCallback,
all_chunks: &mut Vec<StreamChunk>,
) -> Vec<ChatRequestMessage> {
let mut tool_messages = Vec::new();
let mut ctx = ToolContext::new(
db.clone(),
cache.clone(),
app_config.clone(),
Uuid::nil(),
Some(sender_uid),
)
.with_project(project_id);
if let Some(es) = embed_service {
ctx = ctx.with_embed_service(es.clone());
}
if let Some(registry) = tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let mut join_set = tokio::task::JoinSet::new();
for call in calls {
let call_clone = call.clone();
let mut ctx_clone = ctx.clone();
join_set.spawn(async move {
let executor = ToolExecutor::new();
let tool_name = call_clone.name.clone();
let res = match tokio::time::timeout(
std::time::Duration::from_secs(45),
executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone),
)
.await
{
Ok(res) => res,
Err(_) => Err(crate::tool::ToolError::ExecutionError(format!(
"tool '{}' timed out after 45 seconds",
tool_name
))),
};
(call_clone, res)
});
}
let heartbeat_dur = std::time::Duration::from_secs(10);
while !join_set.is_empty() {
tokio::select! {
Some(res) = join_set.join_next() => {
if let Ok((call, results)) = res {
match results {
Ok(results) => {
for result in &results {
let preview = match &result.result {
crate::tool::ToolResult::Ok(v) => {
let t = v.to_string();
if t.len() > 300 {
let end = t.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
format!("{}...", &t[..end])
} else { t.clone() }
}
crate::tool::ToolResult::Error(msg) => msg.clone(),
};
tracing::debug!("tool_result: {} -{}", call.name, preview);
}
let success_display = format!("OK {}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None, children_id: None }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolResult, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
Err(e) => {
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_text = format!("[Tool call failed: {}]", e);
let err_display = format!("ERR {} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None, children_id: None }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolResult, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
}
}
}
},
_ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall, metadata: None, children_id: None }).await;
}
}
}
tool_messages
}