feat(agent): add orchestrator, agent profile and message favorites

This commit is contained in:
ZhenYi 2026-05-17 16:37:30 +08:00
parent 51a1fb8c6c
commit 131c1cca2f
30 changed files with 3333 additions and 444 deletions

View File

@ -15,15 +15,13 @@ use db::database::AppDatabase;
use models::agents::model_pricing;
use models::ai::billing_error;
use models::projects::{project, project_billing, project_billing_history};
use models::users::user_billing;
use models::users::{user_billing, user_billing_history};
use rust_decimal::Decimal;
use sea_orm::*;
use uuid::Uuid;
use crate::error::AgentError;
// ── Constants ──
fn default_user_balance() -> Decimal {
Decimal::new(100_000, 4)
} // $10.0000
@ -32,8 +30,6 @@ fn first_project_credit() -> Decimal {
} // $20.0000
const SUBSEQUENT_PROJECT_BALANCE: Decimal = Decimal::ZERO;
// ── Types ──
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, utoipa::ToSchema)]
pub struct BillingRecord {
pub cost: f64,
@ -49,8 +45,6 @@ pub enum BillingResult {
InsufficientBalance { message: String },
}
// ── Core deduction: AI usage ──
/// Record AI usage: deduct from project balance first, fall through to user balance.
///
/// Returns `InsufficientBalance` if neither account can cover the cost.
@ -167,6 +161,74 @@ pub async fn record_ai_usage(
}
}
/// Record personal AI usage against the user's own billing balance.
pub async fn record_user_ai_usage(
db: &AppDatabase,
user_uid: Uuid,
model_id: Uuid,
input_tokens: i64,
output_tokens: i64,
) -> Result<BillingResult, AgentError> {
let total_cost = compute_cost(db, model_id, input_tokens, output_tokens).await?;
let currency = get_currency(db, model_id).await?;
match deduct_from_user_personal(
db,
user_uid,
total_cost,
&currency,
model_id,
input_tokens,
output_tokens,
)
.await
{
Ok(()) => {
let cost_f64 = decimal_to_f64(total_cost);
tracing::info!(
user_id = %user_uid,
model_id = %model_id,
input_tokens, output_tokens,
cost = %cost_f64,
currency = %currency,
deducted_from = "user",
scope = "personal",
"ai_usage_recorded"
);
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
deducted_from: "user".to_string(),
}))
}
Err(insufficient_msg) => {
persist_billing_error(
db,
"user",
user_uid,
"insufficient_balance",
&insufficient_msg,
Some(serde_json::json!({
"user_id": user_uid.to_string(),
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"cost": decimal_to_f64(total_cost),
"currency": currency,
"scope": "personal",
})),
)
.await?;
Ok(BillingResult::InsufficientBalance {
message: insufficient_msg,
})
}
}
}
/// Check whether a project + user has sufficient combined balance for a potential AI call.
/// Called before starting AI processing to avoid wasted compute.
pub async fn check_balance(
@ -190,6 +252,26 @@ pub async fn check_balance(
Ok(project_balance + user_balance >= estimated_cost)
}
/// Check whether a user's personal balance can cover a potential AI call.
pub async fn check_user_balance(
db: &AppDatabase,
user_uid: Uuid,
model_id: Uuid,
estimated_input_tokens: i64,
estimated_output_tokens: i64,
) -> Result<bool, AgentError> {
let estimated_cost = compute_cost(
db,
model_id,
estimated_input_tokens,
estimated_output_tokens,
)
.await?;
let user_balance = get_user_balance(db, user_uid).await;
Ok(user_balance >= estimated_cost)
}
// ── Initialization ──
/// Initialize a user billing account with the default $10 balance.
@ -226,6 +308,7 @@ pub async fn initialize_project_billing(
// Check how many projects this user has already created
let existing_count = project::Entity::find()
.filter(project::Column::CreatedBy.eq(creator_uid))
.filter(project::Column::Id.ne(project_uid))
.count(db)
.await
.map_err(|e| AgentError::Internal(format!("failed to count user projects: {}", e)))?;
@ -315,9 +398,17 @@ async fn compute_cost(
.parse()
.map_err(|e| AgentError::Internal(format!("Invalid output price: {}", e)))?;
let thousand = Decimal::from(1000);
Ok((Decimal::from(input_tokens) / thousand) * input_price
+ (Decimal::from(output_tokens) / thousand) * output_price)
if input_price <= Decimal::ZERO && output_price <= Decimal::ZERO {
return Err(AgentError::Internal(
"Model pricing is not configured or is zero. Please configure non-zero AI model pricing first."
.into(),
));
}
// DB stores per-1M-token prices; divide tokens by 1M to compute cost.
let million = Decimal::from(1_000_000);
Ok((Decimal::from(input_tokens) / million) * input_price
+ (Decimal::from(output_tokens) / million) * output_price)
}
async fn get_currency(db: &AppDatabase, model_id: Uuid) -> Result<String, AgentError> {
@ -481,6 +572,71 @@ async fn deduct_from_user(
Ok(())
}
async fn deduct_from_user_personal(
db: &AppDatabase,
user_uid: Uuid,
cost: Decimal,
currency: &str,
model_id: Uuid,
input_tokens: i64,
output_tokens: i64,
) -> Result<(), String> {
let txn = db
.begin()
.await
.map_err(|e| format!("db txn error: {}", e))?;
let billing = user_billing::Entity::find_by_id(user_uid)
.lock_exclusive()
.one(&txn)
.await
.map_err(|e| format!("db error: {}", e))?
.ok_or_else(|| "User billing account not found".to_string())?;
if billing.balance < cost {
txn.rollback().await.ok();
return Err(format!(
"Insufficient balance. User: {:.4} {}. Required: {:.4} {}",
billing.balance, billing.currency, cost, billing.currency
));
}
let now = chrono::Utc::now();
user_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
user: Set(user_uid),
amount: Set(-cost),
currency: Set(currency.to_string()),
reason: Set("ai_usage_personal".to_string()),
extra: Set(Some(serde_json::json!({
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"deducted_from": "user",
"scope": "personal",
}))),
created_at: Set(now),
..Default::default()
}
.insert(&txn)
.await
.map_err(|e| format!("failed to insert user history: {}", e))?;
let mut updated: user_billing::ActiveModel = billing.into();
updated.balance = Set(updated.balance.unwrap() - cost);
updated.updated_at = Set(now);
updated
.update(&txn)
.await
.map_err(|e| format!("failed to update user balance: {}", e))?;
txn.commit()
.await
.map_err(|e| format!("commit error: {}", e))?;
Ok(())
}
pub async fn persist_billing_error(
db: &AppDatabase,
scope: &str,

View File

@ -0,0 +1,208 @@
use crate::chat::{AgentExecutionProfile, AgentRole};
/// Tools available to every agent role (shared baseline).
fn shared_tools() -> Vec<String> {
vec![
// Conversation management
"chat_generate_title".into(),
// File parsing / search
"git_grep".into(),
"read_csv".into(),
"read_json".into(),
"read_sql".into(),
"read_markdown".into(),
// File content retrieval
"git_file_content".into(),
"git_blob_get".into(),
"git_blob_content".into(),
]
}
/// Researcher-specific tools (search & discovery).
fn researcher_tools() -> Vec<String> {
vec![
// Search
"git_search_commits".into(),
"repo_search".into(),
"repo_doc_search".into(),
"project_arxiv_search".into(),
"project_curl".into(),
// Index / overview
"repo_overview".into(),
"repo_readme".into(),
"repo_file_tree".into(),
"repo_doc_index".into(),
"repo_doc_read".into(),
// Lists
"project_list_repos".into(),
"project_list_members".into(),
"project_list_issues".into(),
"project_list_labels".into(),
"project_list_boards".into(),
// Log / history
"git_log".into(),
"git_reflog".into(),
"git_graph".into(),
"git_commit_info".into(),
"repo_commit_log".into(),
]
}
/// Analyst-specific tools (deep analysis & explanation).
fn analyst_tools() -> Vec<String> {
vec![
// Deep inspection
"git_show".into(),
"git_diff".into(),
"git_diff_stats".into(),
"git_blame".into(),
// Structural analysis
"repo_languages".into(),
"repo_dependencies".into(),
"repo_diff_summary".into(),
"repo_contributors".into(),
// Branch comparison
"git_branch_list".into(),
"git_branch_info".into(),
"git_branch_diff".into(),
// Project data
"project_list_issues".into(),
"project_list_repos".into(),
]
}
/// Reviewer-specific tools (evaluation & risk detection).
fn reviewer_tools() -> Vec<String> {
vec![
// Change inspection
"git_diff".into(),
"git_diff_stats".into(),
"git_blame".into(),
// Merge status
"git_branches_merged".into(),
"git_branch_info".into(),
// Tracking
"project_list_issues".into(),
"project_update_issue".into(),
// Boards
"project_list_boards".into(),
"project_update_board_card".into(),
]
}
/// Supervisor-specific tools (delegation & synthesis).
fn supervisor_tools() -> Vec<String> {
vec![
// Delegation
"call_sub_agent".into(),
]
}
/// Returns the complete tool set for a given agent role (shared + role-specific).
pub fn tools_for_role(role: &AgentRole) -> Vec<String> {
let mut tools = shared_tools();
match role {
AgentRole::Researcher => tools.extend(researcher_tools()),
AgentRole::Analyst => tools.extend(analyst_tools()),
AgentRole::Reviewer => tools.extend(reviewer_tools()),
AgentRole::Supervisor => tools.extend(supervisor_tools()),
AgentRole::Default => {} // Default role gets only shared tools
}
tools
}
pub fn supervisor_profile() -> AgentExecutionProfile {
AgentExecutionProfile {
role: AgentRole::Supervisor,
system_prompt: Some(
"You are the supervisor agent. You coordinate specialist sub-agents to produce the best answer for the user.\n\
\n\
## Delegation Strategy\n\
- Use the `call_sub_agent` tool to delegate tasks to specialist agents.\n\
- Available roles:\n\
- **researcher**: Gathers concrete facts, evidence, and data from tools and context. Best for finding information, searching code, and discovering evidence.\n\
- **analyst**: Builds coherent explanations, highlights causal links, edge cases, and tradeoffs. Best for explaining findings and reasoning about implications.\n\
- **reviewer**: Stress-tests proposals, identifies contradictions, missing assumptions, regressions, and risks. Best for quality checks and risk assessment.\n\
- Provide a clear, focused task description for each sub-agent.\n\
- You may call multiple sub-agents in sequence (call one, review its output, then decide to call another).\n\
- You may also call the same role twice with different tasks if needed.\n\
\n\
## Decision Guide\n\
- Simple factual questions: call researcher only.\n\
- Questions requiring explanation: call researcher then analyst.\n\
- Design/architecture reviews: call researcher, analyst, then reviewer.\n\
- If a sub-agent's output is insufficient, call another sub-agent for clarification.\n\
\n\
## Output Rules\n\
- After gathering all sub-agent outputs, synthesize them into one final answer.\n\
- Resolve conflicts between sub-agent outputs prefer evidence over speculation.\n\
- Call out any remaining uncertainty explicitly.\n\
- Do not assume facts not present in sub-agent outputs.".to_string(),
),
temperature: Some(0.2),
max_tokens: Some(4000),
top_p: Some(1.0),
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
max_tool_depth: Some(8),
allowed_tools: Some(tools_for_role(&AgentRole::Supervisor)),
disable_orchestration: true,
}
}
pub fn researcher_profile() -> AgentExecutionProfile {
AgentExecutionProfile {
role: AgentRole::Researcher,
system_prompt: Some(
"You are the researcher agent. Your job is to gather concrete facts from available tools and context. Prefer direct evidence over inference. Return structured findings, relevant code or data references, and unresolved gaps.".to_string(),
),
temperature: Some(0.1),
max_tokens: Some(1800),
top_p: Some(1.0),
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
max_tool_depth: Some(6),
allowed_tools: Some(tools_for_role(&AgentRole::Researcher)),
disable_orchestration: true,
}
}
pub fn analyst_profile() -> AgentExecutionProfile {
AgentExecutionProfile {
role: AgentRole::Analyst,
system_prompt: Some(
"You are the analyst agent. Build a coherent explanation from the available evidence. Highlight causal links, edge cases, and tradeoffs. If evidence is weak, say so explicitly.".to_string(),
),
temperature: Some(0.2),
max_tokens: Some(1800),
top_p: Some(1.0),
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
max_tool_depth: Some(4),
allowed_tools: Some(tools_for_role(&AgentRole::Analyst)),
disable_orchestration: true,
}
}
pub fn reviewer_profile() -> AgentExecutionProfile {
AgentExecutionProfile {
role: AgentRole::Reviewer,
system_prompt: Some(
"You are the reviewer agent. Stress-test the proposed answer. Look for contradictions, missing assumptions, regressions, and risks. Output only high-signal critiques and corrections.".to_string(),
),
temperature: Some(0.1),
max_tokens: Some(1600),
top_p: Some(1.0),
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
max_tool_depth: Some(4),
allowed_tools: Some(tools_for_role(&AgentRole::Reviewer)),
disable_orchestration: true,
}
}
/// Whether to enable multi-agent delegation for this request.
/// Simplified from keyword-based gating: delegation is enabled when tools are available.
pub fn should_enable_delegation(_input: &str, tools_available: bool) -> bool {
tools_available
}

View File

@ -2,6 +2,7 @@ use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
use super::agent_profile::{analyst_profile, researcher_profile, reviewer_profile};
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream};
@ -16,6 +17,398 @@ use sea_orm::{ActiveModelTrait, EntityTrait, Set};
use super::service::StreamResult;
use super::{AiChunkType, AiStreamChunk, StreamCallback};
struct SubAgentRunResult {
output: String,
input_tokens: i64,
output_tokens: i64,
cancelled: bool,
error: Option<String>,
}
/// Persist a sub-agent session record to the database.
async fn persist_sub_agent_session(
db: &db::database::AppDatabase,
conversation_id: Uuid,
children_id: &str,
role: &str,
task: &str,
output: &str,
input_tokens: i64,
output_tokens: i64,
model_name: &str,
status: &str,
error_message: Option<String>,
) {
use models::ai::ai_subagent_session;
use sea_orm::{ActiveModelTrait, Set};
let record = ai_subagent_session::ActiveModel {
id: Set(Uuid::now_v7()),
conversation_id: Set(conversation_id),
message_id: Set(Uuid::nil()),
children_id: Set(children_id.to_string()),
role: Set(role.to_string()),
task: Set(task.to_string()),
output: Set(output.to_string()),
input_tokens: Set(input_tokens),
output_tokens: Set(output_tokens),
model_name: Set(Some(model_name.to_string())),
status: Set(status.to_string()),
error_message: Set(error_message),
created_at: Set(chrono::Utc::now()),
};
if let Err(e) = record.insert(db.writer()).await {
tracing::warn!(error = %e, children_id = %children_id, "failed to persist sub-agent session");
}
}
/// Execute a sub-agent call with streaming output via NATS.
///
/// The sub-agent output is streamed to NATS JetStream subject
/// `chat.subagent.chunk.{conversation_id}.{children_id}` so the frontend
/// can subscribe via the `/api/ai/subagent/{conversation_id}/{children_id}/stream` endpoint.
///
/// Returns the full or partial output after the sub-agent completes or is cancelled.
async fn call_sub_agent_stream(
messages: &[ChatRequestMessage],
model_name: &str,
config: &AiClientConfig,
temperature: f32,
max_tokens: u32,
max_tool_depth: usize,
tools: Option<&[serde_json::Value]>,
tool_registry: Option<ToolRegistry>,
db: db::database::AppDatabase,
app_config: config::AppConfig,
project_id: Uuid,
sender_uid: Uuid,
embed_service: Option<EmbedService>,
children_id: &str,
conversation_id: Option<uuid::Uuid>,
cache: db::cache::AppCache,
queue_producer: Option<&queue::MessageProducer>,
) -> Result<SubAgentRunResult> {
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::mpsc;
let conversation_id = conversation_id.unwrap_or_default();
let seq = Arc::new(AtomicU64::new(0));
let children_id_owned = children_id.to_string();
let queue_ref = queue_producer.cloned();
let partial_output = Arc::new(tokio::sync::Mutex::new(String::new()));
let (delta_tx, mut delta_rx) = mpsc::unbounded_channel::<(&'static str, String)>();
cache
.clear_sub_agent_cancelled(conversation_id, &children_id_owned)
.await;
let stream_fut = async {
let mut messages = messages.to_vec();
let mut total_input_tokens = 0i64;
let mut total_output_tokens = 0i64;
let mut last_content = String::new();
let mut tool_depth = 0usize;
loop {
let response = call_stream(
&messages,
model_name,
config,
temperature,
max_tokens,
tools,
None,
Arc::new({
let partial_output = partial_output.clone();
let delta_tx = delta_tx.clone();
move |delta| {
let content = delta.to_string();
let partial_output = partial_output.clone();
let delta_tx = delta_tx.clone();
Box::pin(async move {
partial_output.lock().await.push_str(&content);
let _ = delta_tx.send(("token", content));
})
as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}
}),
Arc::new({
let delta_tx = delta_tx.clone();
move |delta| {
let content = delta.to_string();
let delta_tx = delta_tx.clone();
Box::pin(async move {
let _ = delta_tx.send(("thinking", content));
})
as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}
}),
Arc::new(move |_tc: &StreamedToolCall| {
Box::pin(async move {}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}),
)
.await?;
total_input_tokens += response.input_tokens;
total_output_tokens += response.output_tokens;
if !response.content.is_empty() {
last_content = response.content.clone();
}
if response.tool_calls.is_empty() {
return Ok::<SubAgentRunResult, crate::error::AgentError>(SubAgentRunResult {
output: response.content,
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
cancelled: false,
error: None,
});
}
if tool_depth >= max_tool_depth {
let fallback_output = if last_content.is_empty() {
"Sub-agent reached maximum tool depth before producing a final summary."
.to_string()
} else {
last_content
};
return Ok::<SubAgentRunResult, crate::error::AgentError>(SubAgentRunResult {
output: fallback_output,
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
cancelled: false,
error: Some(format!(
"sub-agent reached maximum tool depth ({max_tool_depth}) before final summary"
)),
});
}
let assistant_tool_calls: Vec<ToolCall> = response
.tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.id.clone(),
type_: "function".into(),
function: crate::client::types::ToolCallFunction {
name: tc.name.clone(),
arguments: tc.arguments.clone(),
},
})
.collect();
messages.push(ChatRequestMessage::assistant(
Some(response.content.clone()),
Some(assistant_tool_calls),
));
let agent_tool_calls: Vec<AgentToolCall> = response
.tool_calls
.iter()
.map(|tc| AgentToolCall {
id: tc.id.clone(),
name: tc.name.clone(),
arguments: tc.arguments.clone(),
})
.collect();
let tool_messages = execute_sub_agent_tools(
&agent_tool_calls,
&db,
&cache,
&app_config,
project_id,
sender_uid,
tool_registry.as_ref(),
embed_service.as_ref(),
)
.await;
messages.extend(tool_messages);
messages.push(ChatRequestMessage::user(
"Use the tool results above to produce your final concise findings. Do not call another tool unless it is strictly necessary.",
));
tool_depth += 1;
}
};
let children_id_for_cancel = children_id_owned.clone();
let cancel_fut = async {
let mut interval = tokio::time::interval(std::time::Duration::from_millis(100));
loop {
interval.tick().await;
if cache
.is_sub_agent_cancelled(conversation_id, &children_id_for_cancel)
.await
{
break;
}
}
};
let timeout_fut = tokio::time::sleep(std::time::Duration::from_secs(60));
let flush_queue = queue_ref.clone();
let flush_children_id = children_id_owned.clone();
let flush_seq = seq.clone();
let mut flush_handle = tokio::spawn(async move {
let Some(queue) = flush_queue else {
while delta_rx.recv().await.is_some() {}
return;
};
let mut token_buf = String::new();
let mut thinking_buf = String::new();
let mut interval = tokio::time::interval(std::time::Duration::from_millis(50));
async fn flush(
queue: &queue::MessageProducer,
conversation_id: Uuid,
children_id: &str,
seq: &Arc<AtomicU64>,
chunk_type: &str,
buffer: &mut String,
) {
if buffer.is_empty() {
return;
}
let event = queue::types::SubAgentStreamChunkEvent {
conversation_id,
children_id: children_id.to_string(),
seq: seq.fetch_add(1, Ordering::Relaxed),
content: std::mem::take(buffer),
done: false,
error: None,
chunk_type: Some(chunk_type.to_string()),
role: String::new(),
task: String::new(),
};
queue.publish_sub_agent_chunk_realtime(&event).await;
}
loop {
tokio::select! {
Some((kind, content)) = delta_rx.recv() => {
let target = if kind == "thinking" { &mut thinking_buf } else { &mut token_buf };
target.push_str(&content);
if target.len() >= 240 {
flush(&queue, conversation_id, &flush_children_id, &flush_seq, kind, target).await;
}
}
_ = interval.tick() => {
flush(&queue, conversation_id, &flush_children_id, &flush_seq, "thinking", &mut thinking_buf).await;
flush(&queue, conversation_id, &flush_children_id, &flush_seq, "token", &mut token_buf).await;
}
else => break,
}
}
flush(
&queue,
conversation_id,
&flush_children_id,
&flush_seq,
"thinking",
&mut thinking_buf,
)
.await;
flush(
&queue,
conversation_id,
&flush_children_id,
&flush_seq,
"token",
&mut token_buf,
)
.await;
});
let response = tokio::select! {
result = stream_fut => {
match result {
Ok(response) => Some(response),
Err(e) => Some(SubAgentRunResult {
output: partial_output.lock().await.clone(),
input_tokens: 0,
output_tokens: 0,
cancelled: false,
error: Some(e.to_string()),
}),
}
}
_ = cancel_fut => None,
_ = timeout_fut => Some(SubAgentRunResult {
output: partial_output.lock().await.clone(),
input_tokens: 0,
output_tokens: 0,
cancelled: false,
error: Some("sub-agent timed out after 60 seconds".to_string()),
}),
};
drop(delta_tx);
if tokio::time::timeout(std::time::Duration::from_secs(2), &mut flush_handle)
.await
.is_err()
{
flush_handle.abort();
tracing::warn!(
children_id = %children_id,
"sub-agent stream flush timed out; continuing with terminal event"
);
}
let cancelled = response.is_none();
let (total_content, total_input_tokens, total_output_tokens, terminal_error) = match response {
Some(response) => (
response.output.clone(),
response.input_tokens,
response.output_tokens,
response.error.clone(),
),
None => (partial_output.lock().await.clone(), 0, 0, None),
};
// Send final done/stopped chunk.
let final_seq = seq.load(Ordering::Relaxed);
let event = queue::types::SubAgentStreamChunkEvent {
conversation_id,
children_id: children_id_owned,
seq: final_seq,
content: String::new(),
done: true,
error: terminal_error.clone(),
chunk_type: Some(
if terminal_error.is_some() {
"error"
} else if cancelled {
"stopped"
} else {
"done"
}
.to_string(),
),
role: String::new(),
task: String::new(),
};
if let Some(q) = queue_ref {
if tokio::time::timeout(
std::time::Duration::from_secs(1),
q.publish_sub_agent_chunk(&event),
)
.await
.is_err()
{
tracing::warn!(
children_id = %event.children_id,
"sub-agent terminal event publish timed out"
);
}
}
Ok(SubAgentRunResult {
output: total_content,
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
cancelled,
error: terminal_error,
})
}
// Keyword-extraction-based title generator: reads conversation messages, extracts
// meaningful words, and updates the conversation record with a short title.
async fn generate_title_for_conversation(
@ -135,7 +528,7 @@ type SharedCallback = Arc<
///
/// Unlike `execute_process_stream` (which requires `AiRequest` with room-specific data),
/// this function takes messages and tools directly. It does NOT record AI sessions to
/// the `ai_session` table the caller is responsible for persisting results.
/// the `ai_session` table -the caller is responsible for persisting results.
pub async fn execute_chat_stream(
messages: Vec<ChatRequestMessage>,
tools: Vec<serde_json::Value>,
@ -153,6 +546,7 @@ pub async fn execute_chat_stream(
embed_service: Option<EmbedService>,
on_chunk: StreamCallback,
conversation_id: Option<uuid::Uuid>,
queue_producer: Option<queue::MessageProducer>,
) -> Result<StreamResult> {
let on_chunk: SharedCallback = Arc::from(on_chunk);
let tools_enabled = !tools.is_empty();
@ -228,6 +622,35 @@ pub async fn execute_chat_stream(
(tools.clone(), false)
};
// Add call_sub_agent tool for chat orchestration when tools are available
let tools = if tools_enabled {
let mut t = tools;
t.push(serde_json::json!({
"type": "function",
"function": {
"name": "call_sub_agent",
"description": "Delegate a task to a specialist sub-agent and receive its output.\nAvailable roles:\n- researcher: Gathers facts, evidence, and data. Best for finding information and searching code.\n- analyst: Builds explanations, highlights causal links and tradeoffs. Best for reasoning about implications.\n- reviewer: Stress-tests proposals, identifies risks and contradictions. Best for quality checks.\nProvide a clear, focused task description so the sub-agent knows exactly what to investigate.",
"parameters": {
"type": "object",
"properties": {
"role": {
"type": "string",
"description": "The sub-agent role to delegate to: researcher, analyst, or reviewer."
},
"task": {
"type": "string",
"description": "The specific task or question for the sub-agent. Be precise and focused."
}
},
"required": ["role", "task"]
}
}
}));
t
} else {
tools
};
loop {
let on_chunk_cb = on_chunk.clone();
let on_chunk_cb2 = on_chunk.clone();
@ -249,6 +672,7 @@ pub async fn execute_chat_stream(
done: false,
chunk_type: AiChunkType::Answer,
metadata: None,
children_id: None,
});
fut
}),
@ -258,6 +682,7 @@ pub async fn execute_chat_stream(
done: false,
chunk_type: AiChunkType::Thinking,
metadata: None,
children_id: None,
});
fut
}),
@ -278,10 +703,10 @@ pub async fn execute_chat_stream(
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
if !has_tool_calls {
let final_content = response.content.clone();
// Don't push full content as a chunk incremental deltas in
// Don't push full content as a chunk -incremental deltas in
// response.chunks (already added above) sum to the same text.
// merge_consecutive_blocks would concatenate delta_sum + full =
// 2× full, causing duplicate content in DB persistence.
// 2x full, causing duplicate content in DB persistence.
return Ok(StreamResult {
content: final_content,
reasoning_content: response.reasoning_content,
@ -327,12 +752,13 @@ pub async fn execute_chat_stream(
} else {
tc.arguments.clone()
};
let tool_display = format!("🔧 {}({})", tc.name, args_display);
let tool_display = format!("[tool] {}({})", tc.name, args_display);
on_chunk(AiStreamChunk {
content: tool_display.clone(),
done: false,
chunk_type: AiChunkType::ToolCall,
metadata: None,
children_id: None,
})
.await;
all_chunks.push(StreamChunk {
@ -345,7 +771,7 @@ pub async fn execute_chat_stream(
}
}
let calls: Vec<AgentToolCall> = response
let (sub_agent_calls, regular_calls): (Vec<AgentToolCall>, Vec<AgentToolCall>) = response
.tool_calls
.iter()
.map(|tc| AgentToolCall {
@ -353,28 +779,300 @@ pub async fn execute_chat_stream(
name: tc.name.clone(),
arguments: tc.arguments.clone(),
})
.collect();
.collect::<Vec<_>>()
.into_iter()
.partition(|c| c.name == "call_sub_agent");
let tool_messages = execute_tools(
&calls,
&db,
&cache,
&app_config,
project_id,
sender_uid,
tool_registry,
embed_service.as_ref(),
&on_chunk,
&mut all_chunks,
)
.await;
let mut tool_messages = Vec::new();
let mut sub_agent_tasks = tokio::task::JoinSet::new();
let mut sub_agent_ids: Vec<String> = Vec::new();
// Handle call_sub_agent calls inline -stream sub-agent output via NATS
for sub_call in sub_agent_calls {
let args: serde_json::Value = match serde_json::from_str(&sub_call.arguments) {
Ok(v) => v,
Err(_) => {
tool_messages.push(ChatRequestMessage::tool(
&sub_call.id,
"Failed to parse call_sub_agent arguments",
));
continue;
}
};
let role = args
.get("role")
.and_then(|v| v.as_str())
.unwrap_or("researcher");
let task = args.get("task").and_then(|v| v.as_str()).unwrap_or("");
let profile = match role {
"analyst" => analyst_profile(),
"reviewer" => reviewer_profile(),
_ => researcher_profile(),
};
// Generate children_id BEFORE starting sub-agent execution
let sub_agent_id = format!("sub-agent-{}", Uuid::now_v7());
sub_agent_ids.push(sub_agent_id.clone());
// Emit tool_call chunk immediately with children_id so frontend can start watching
let call_display =
format!("[tool] call_sub_agent({role}) - delegating to {role} agent...");
on_chunk(AiStreamChunk {
content: call_display.clone(),
done: false,
chunk_type: AiChunkType::ToolCall,
metadata: Some(serde_json::json!({
"tool": "call_sub_agent",
"args": { "role": role.to_string(), "task": task.to_string() },
"display": call_display,
})),
children_id: Some(sub_agent_id.clone()),
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: call_display,
});
let sub_system = profile.system_prompt.clone().unwrap_or_default();
let sub_messages = vec![
ChatRequestMessage::system(sub_system),
ChatRequestMessage::user(format!(
"Sub-agent role: {role}\n\nTask:\n{task}\n\nFocus only on your assigned task. Return concise, evidence-backed findings."
)),
];
// Filter tools for the sub-agent: only include tools in the profile's allowed list,
// always excluding call_sub_agent and chat_generate_title
let sub_tools: Vec<serde_json::Value> = if let Some(ref allowed) = profile.allowed_tools
{
tools
.iter()
.filter(|t| {
let name = t
.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str())
.unwrap_or("");
allowed.contains(&name.to_string())
&& name != "call_sub_agent"
&& name != "chat_generate_title"
})
.cloned()
.collect()
} else {
tools
.iter()
.filter(|t| {
let name = t
.get("function")
.and_then(|f| f.get("name"))
.and_then(|n| n.as_str())
.unwrap_or("");
name != "call_sub_agent" && name != "chat_generate_title"
})
.cloned()
.collect()
};
let call_id = sub_call.id.clone();
let role_owned = role.to_string();
let task_owned = task.to_string();
let sub_agent_id_owned = sub_agent_id.clone();
let model_name_owned = model_name.to_string();
let config_owned = config.clone();
let cache_owned = cache.clone();
let db_owned = db.clone();
let app_config_owned = app_config.clone();
let embed_service_owned = embed_service.clone();
let tool_registry_owned = tool_registry.cloned();
let queue_owned = queue_producer.clone();
let conversation_id_owned = conversation_id;
let temperature = profile.temperature.unwrap_or(0.7) as f32;
let max_tokens = profile.max_tokens.unwrap_or(4096) as u32;
let sub_max_tool_depth = profile.max_tool_depth.unwrap_or(4) as usize;
sub_agent_tasks.spawn(async move {
let result = call_sub_agent_stream(
&sub_messages,
&model_name_owned,
&config_owned,
temperature,
max_tokens,
sub_max_tool_depth,
Some(&sub_tools),
tool_registry_owned,
db_owned,
app_config_owned,
project_id,
sender_uid,
embed_service_owned,
&sub_agent_id_owned,
conversation_id_owned,
cache_owned,
queue_owned.as_ref(),
)
.await;
(
call_id,
sub_agent_id_owned,
role_owned,
task_owned,
model_name_owned,
result,
)
});
}
let mut cancelled_batch = false;
while let Some(joined) = sub_agent_tasks.join_next().await {
let Ok((call_id, sub_agent_id, role, task, sub_model_name, result)) = joined else {
continue;
};
match result {
Ok(result) => {
if result.cancelled && !cancelled_batch {
cancelled_batch = true;
if let Some(conv_id) = conversation_id {
for id in &sub_agent_ids {
if id != &sub_agent_id {
cache.set_sub_agent_cancelled(conv_id, id).await;
}
}
}
}
let status = if result.error.is_some() {
"error"
} else if result.cancelled {
"stopped"
} else {
"ok"
};
let output = result.output.clone();
persist_sub_agent_session(
&db,
conversation_id.unwrap_or_default(),
&sub_agent_id,
&role,
&task,
&output,
result.input_tokens,
result.output_tokens,
&sub_model_name,
status,
result.error.clone(),
)
.await;
let display = if result.error.is_some() {
format!("Sub-agent failed ({role})")
} else if result.cancelled {
format!("Sub-agent stopped ({role})")
} else {
format!("Sub-agent completed ({role})")
};
on_chunk(AiStreamChunk {
content: display.clone(),
done: false,
chunk_type: AiChunkType::ToolResult,
metadata: Some(serde_json::json!({
"tool": "call_sub_agent",
"role": role.clone(),
"task": task.clone(),
"output": output.clone(),
"input_tokens": result.input_tokens,
"output_tokens": result.output_tokens,
"error": result.error.clone(),
"status": status,
"display": display.clone(),
})),
children_id: Some(sub_agent_id.clone()),
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolResult,
content: serde_json::json!({
"tool": "call_sub_agent",
"role": role.clone(),
"task": task.clone(),
"output": output.clone(),
"input_tokens": result.input_tokens,
"output_tokens": result.output_tokens,
"error": result.error.clone(),
"status": status,
"display": display.clone(),
"children_id": sub_agent_id.clone(),
})
.to_string(),
});
let tool_content = if let Some(err) = &result.error {
format!(
"{}\n\n[sub_agent_status={} input_tokens={} output_tokens={} error={}]",
output, status, result.input_tokens, result.output_tokens, err
)
} else {
format!(
"{}\n\n[sub_agent_status={} input_tokens={} output_tokens={}]",
output, status, result.input_tokens, result.output_tokens
)
};
tool_messages.push(ChatRequestMessage::tool(&call_id, tool_content));
}
Err(e) => {
let err_msg = format!("Sub-agent ({role}) failed: {}", e);
let display = format!("Sub-agent failed ({role})");
let result_json = serde_json::json!({
"tool": "call_sub_agent",
"role": role,
"status": "error",
"error": err_msg,
"display": display,
})
.to_string();
on_chunk(AiStreamChunk {
content: display.clone(),
done: false,
chunk_type: AiChunkType::ToolResult,
metadata: None,
children_id: Some(sub_agent_id),
})
.await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolResult,
content: result_json,
});
tool_messages.push(ChatRequestMessage::tool(&call_id, &err_msg));
}
}
}
// Handle regular tool calls via ToolExecutor
if !regular_calls.is_empty() {
let regular_tool_messages = execute_tools(
&regular_calls,
&db,
&cache,
&app_config,
project_id,
sender_uid,
tool_registry,
embed_service.as_ref(),
&on_chunk,
&mut all_chunks,
)
.await;
tool_messages.extend(regular_tool_messages);
}
messages.extend(tool_messages);
tool_depth += 1;
if tool_depth >= max_tool_depth {
let max_depth_text = format!(
"[AI reached maximum tool depth ({}) — no final answer produced]",
"[AI reached maximum tool depth ({}) - no final answer produced]",
max_tool_depth
);
on_chunk(AiStreamChunk {
@ -382,6 +1080,7 @@ pub async fn execute_chat_stream(
done: true,
chunk_type: AiChunkType::Answer,
metadata: None,
children_id: None,
})
.await;
all_chunks.push(StreamChunk {
@ -399,6 +1098,71 @@ pub async fn execute_chat_stream(
}
}
async fn execute_sub_agent_tools(
calls: &[AgentToolCall],
db: &db::database::AppDatabase,
cache: &db::cache::AppCache,
app_config: &config::AppConfig,
project_id: Uuid,
sender_uid: Uuid,
tool_registry: Option<&ToolRegistry>,
embed_service: Option<&EmbedService>,
) -> Vec<ChatRequestMessage> {
let mut tool_messages = Vec::new();
let mut ctx = ToolContext::new(
db.clone(),
cache.clone(),
app_config.clone(),
Uuid::nil(),
Some(sender_uid),
)
.with_project(project_id);
if let Some(es) = embed_service {
ctx = ctx.with_embed_service(es.clone());
}
if let Some(registry) = tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let mut join_set = tokio::task::JoinSet::new();
for call in calls {
let call_clone = call.clone();
let mut ctx_clone = ctx.clone();
join_set.spawn(async move {
let executor = ToolExecutor::new();
let tool_name = call_clone.name.clone();
let res = match tokio::time::timeout(
std::time::Duration::from_secs(45),
executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone),
)
.await
{
Ok(res) => res,
Err(_) => Err(crate::tool::ToolError::ExecutionError(format!(
"tool '{}' timed out after 45 seconds",
tool_name
))),
};
(call_clone, res)
});
}
while let Some(res) = join_set.join_next().await {
let Ok((call, results)) = res else {
continue;
};
match results {
Ok(results) => tool_messages.extend(ToolExecutor::to_tool_messages(&results)),
Err(e) => {
let err_text = format!("[Sub-agent tool call failed: {}]", e);
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
}
}
}
tool_messages
}
async fn execute_tools(
calls: &[AgentToolCall],
db: &db::database::AppDatabase,
@ -433,9 +1197,19 @@ async fn execute_tools(
let mut ctx_clone = ctx.clone();
join_set.spawn(async move {
let executor = ToolExecutor::new();
let res = executor
.execute_batch(vec![call_clone.clone()], &mut ctx_clone)
.await;
let tool_name = call_clone.name.clone();
let res = match tokio::time::timeout(
std::time::Duration::from_secs(45),
executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone),
)
.await
{
Ok(res) => res,
Err(_) => Err(crate::tool::ToolError::ExecutionError(format!(
"tool '{}' timed out after 45 seconds",
tool_name
))),
};
(call_clone, res)
});
}
@ -458,27 +1232,27 @@ async fn execute_tools(
}
crate::tool::ToolResult::Error(msg) => msg.clone(),
};
tracing::debug!("tool_result: {} {}", call.name, preview);
tracing::debug!("tool_result: {} -{}", call.name, preview);
}
let success_display = format!(" {}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
let success_display = format!("OK {}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None, children_id: None }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolResult, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
Err(e) => {
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_text = format!("[Tool call failed: {}]", e);
let err_display = format!(" {} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
let err_display = format!("ERR {} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None, children_id: None }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolResult, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
}
}
}
},
_ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall, metadata: None }).await;
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall, metadata: None, children_id: None }).await;
}
}
}

View File

@ -9,6 +9,7 @@ use crate::embed::EmbedService;
use crate::error::Result;
use crate::perception::{PerceptionService, SkillEntry};
#[derive(Clone)]
pub struct MessageBuilder {
pub compact_service: Option<CompactService>,
pub embed_service: Option<EmbedService>,
@ -42,6 +43,10 @@ impl MessageBuilder {
pub async fn build_messages(&self, request: &AiRequest) -> Result<Vec<ChatRequestMessage>> {
let mut messages = Vec::new();
if let Some(ref preamble) = request.room_preamble {
messages.push(ChatRequestMessage::system(preamble.clone()));
}
messages.push(ChatRequestMessage::system(
"When receiving a question or problem, follow this reasoning process:\n\
1. ANALYZE: Break down the question. Identify what is being asked, what context is available, and what information is missing.\n\
@ -51,104 +56,49 @@ impl MessageBuilder {
\n\
Do NOT guess or assume when tools can provide concrete answers. Always verify claims against actual code or documentation.".to_string()
));
if let Some(system_prompt) = request
.execution_profile
.as_ref()
.and_then(|p| p.system_prompt.as_ref())
{
messages.push(ChatRequestMessage::system(system_prompt.clone()));
}
let mut processed_history = Vec::new();
if let Some(compact_service) = &self.compact_service {
let compact_cache_key = format!("ai:compact:{}", request.room.id);
let cached_summary: Option<String> = match request.cache.conn().await {
Ok(mut conn) => redis::cmd("GET")
.arg(&compact_cache_key)
.query_async::<Option<String>>(&mut conn)
.await
.unwrap_or(None),
Err(e) => {
tracing::warn!(error = %e, "compact cache: conn failed");
None
}
};
let compact_config = request
.context_setting
.as_ref()
.map(|s| {
crate::compact::CompactConfig::from_project_setting(
s.context_window_tokens,
s.compaction_threshold,
s.compaction_max_summary_ratio,
)
})
.unwrap_or_default();
if let Some(cached_json) = cached_summary {
if let Ok(summary) =
serde_json::from_str::<crate::compact::CompactSummary>(&cached_json)
{
if !summary.summary.is_empty() {
match compact_service
.for_model_entry(&request.model)
.prepare_room_compact_context(
request.room.id,
request.sender.uid,
Some(request.user_names.clone()),
compact_config,
)
.await
{
Ok(context) => {
if let Some(summary) = context.summary.filter(|s| !s.trim().is_empty()) {
messages.push(ChatRequestMessage::system(format!(
"Conversation summary:\n{}",
summary.summary
summary
)));
}
processed_history = summary.retained;
processed_history = context.retained;
}
}
if processed_history.is_empty() {
let compact_config = request
.context_setting
.as_ref()
.map(|s| {
crate::compact::CompactConfig::from_project_setting(
s.context_window_tokens,
s.compaction_threshold,
s.compaction_max_summary_ratio,
)
})
.unwrap_or_default();
let history_text = request
.history
.iter()
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n");
let estimated_tokens =
crate::tokent::count_message_text(&history_text, &request.model.name)
.unwrap_or_else(|_| history_text.len() / 4);
if estimated_tokens < compact_config.token_threshold {
tracing::debug!(
estimated_tokens,
threshold = compact_config.token_threshold,
"conversation compaction skipped below threshold"
);
} else {
match compact_service
.for_model(&request.model.name)
.compact_room(
request.room.id,
compact_config.default_level,
Some(request.user_names.clone()),
request.sender.uid,
request
.context_setting
.as_ref()
.map(|s| s.context_window_tokens)
.unwrap_or(128000),
request
.context_setting
.as_ref()
.map(|s| s.compaction_max_summary_ratio)
.unwrap_or(0.2),
)
.await
{
Ok(compact_summary) => {
if !compact_summary.summary.is_empty() {
messages.push(ChatRequestMessage::system(format!(
"Conversation summary:\n{}",
compact_summary.summary
)));
}
if let Ok(json) = serde_json::to_string(&compact_summary) {
if let Ok(mut conn) = request.cache.conn().await {
let _ = redis::cmd("SETEX").arg(&compact_cache_key).arg(300u64).arg(&json).query_async::<()>(&mut conn).await
.inspect_err(|e| { tracing::warn!(error = %e, "compact cache: SETEX failed"); });
}
}
processed_history = compact_summary.retained;
}
Err(e) => {
tracing::warn!(error = %e, "conversation compaction failed, using full history")
}
}
Err(e) => {
tracing::warn!(error = %e, "conversation compaction failed, using full history");
}
}
}
@ -159,7 +109,7 @@ impl MessageBuilder {
messages.push(ctx.to_message());
}
} else {
for msg in &request.history {
for msg in Self::history_in_chronological_order(&request.history) {
let ctx = RoomMessageContext::from_model_with_names(msg, &request.user_names);
messages.push(ctx.to_message());
}
@ -273,6 +223,14 @@ impl MessageBuilder {
Ok(messages)
}
fn history_in_chronological_order(
history: &[models::rooms::room_message::Model],
) -> Vec<&models::rooms::room_message::Model> {
let mut ordered = history.iter().collect::<Vec<_>>();
ordered.sort_by_key(|m| (m.seq, m.send_at));
ordered
}
pub async fn build_room_optimized_context_text(
&self,
request: &AiRequest,
@ -294,7 +252,7 @@ impl MessageBuilder {
.unwrap_or_default();
let context = match compact_service
.for_model(&request.model.name)
.for_model_entry(&request.model)
.prepare_room_compact_context(
request.room.id,
request.sender.uid,
@ -332,10 +290,11 @@ impl MessageBuilder {
fn recent_history_text(request: &AiRequest, cutoff_seq: Option<i64>) -> String {
let mut lines = Vec::new();
for msg in request
.history
.iter()
for msg in Self::history_in_chronological_order(&request.history)
.into_iter()
.filter(|m| cutoff_seq.map(|seq| m.seq > seq).unwrap_or(true))
.collect::<Vec<_>>()
.into_iter()
.rev()
.take(20)
.collect::<Vec<_>>()
@ -489,3 +448,57 @@ impl MessageBuilder {
}
}
}
#[cfg(test)]
mod tests {
use super::MessageBuilder;
use chrono::{Duration, Utc};
use models::rooms::{MessageContentType, MessageSenderType, room_message};
use uuid::Uuid;
#[test]
fn history_is_sorted_chronologically() {
let now = Utc::now();
let later = room_message::Model {
id: Uuid::new_v4(),
seq: 20,
room: Uuid::new_v4(),
sender_type: MessageSenderType::User,
sender_id: Some(Uuid::new_v4()),
model_id: None,
thread: None,
content: "later".into(),
content_type: MessageContentType::Text,
thinking_content: None,
edited_at: None,
send_at: now,
revoked: None,
revoked_by: None,
in_reply_to: None,
content_tsv: None,
};
let earlier = room_message::Model {
id: Uuid::new_v4(),
seq: 10,
room: later.room,
sender_type: MessageSenderType::User,
sender_id: later.sender_id,
model_id: None,
thread: None,
content: "earlier".into(),
content_type: MessageContentType::Text,
thinking_content: None,
edited_at: None,
send_at: now - Duration::seconds(10),
revoked: None,
revoked_by: None,
in_reply_to: None,
content_tsv: None,
};
let history = [later, earlier];
let ordered = MessageBuilder::history_in_chronological_order(&history);
assert_eq!(ordered[0].content, "earlier");
assert_eq!(ordered[1].content, "later");
}
}

View File

@ -27,6 +27,8 @@ pub struct AiStreamChunk {
/// tool_call: {"tool": "...", "args": {...}}
/// tool_result: {"tool": "...", "status": "ok|error", "result": "..."}
pub metadata: Option<serde_json::Value>,
/// Optional ID of a child process/agent, sent to frontend via SSE.
pub children_id: Option<String>,
}
/// Type of streaming chunk, used by the frontend for rendering.
@ -78,6 +80,36 @@ pub type StreamCallback = Box<
dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync,
>;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentRole {
Default,
Supervisor,
Researcher,
Analyst,
Reviewer,
}
#[derive(Debug, Clone, Default)]
pub struct AgentExecutionProfile {
pub role: AgentRole,
pub system_prompt: Option<String>,
pub temperature: Option<f64>,
pub max_tokens: Option<i32>,
pub top_p: Option<f64>,
pub frequency_penalty: Option<f64>,
pub presence_penalty: Option<f64>,
pub max_tool_depth: Option<usize>,
pub allowed_tools: Option<Vec<String>>,
pub disable_orchestration: bool,
}
impl Default for AgentRole {
fn default() -> Self {
Self::Default
}
}
#[derive(Clone)]
pub struct AiRequest {
pub db: AppDatabase,
pub cache: AppCache,
@ -100,17 +132,22 @@ pub struct AiRequest {
pub think: bool,
pub tools: Option<Vec<serde_json::Value>>,
pub max_tool_depth: usize,
pub execution_profile: Option<AgentExecutionProfile>,
pub room_preamble: Option<String>,
}
#[derive(Clone)]
pub enum Mention {
User(user::Model),
Repo(repo::Model),
}
pub mod agent_profile;
pub mod chat_execution;
pub mod context;
pub mod message_builder;
pub mod nonstreaming_execution;
pub mod orchestrator;
pub mod react_execution;
pub mod service;
pub mod session_recording;

View File

@ -8,7 +8,7 @@ use super::message_builder::MessageBuilder;
use super::service::ProcessResult;
use super::session_recording::record_ai_session;
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::types::ChatRequestMessage;
use crate::error::Result;
use crate::perception::{SkillEntry, ToolCallEvent};
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor};
@ -33,13 +33,22 @@ pub async fn execute_process(
.await?;
let model_name = request.model.name.clone();
let temperature = room_ai_config
.as_ref()
.and_then(|r| r.temperature.map(|v| v as f32))
let profile = request.execution_profile.as_ref();
let temperature = profile
.and_then(|p| p.temperature.map(|v| v as f32))
.or_else(|| {
room_ai_config
.as_ref()
.and_then(|r| r.temperature.map(|v| v as f32))
})
.unwrap_or(request.temperature as f32);
let max_tokens = room_ai_config
.as_ref()
.and_then(|r| r.max_tokens.map(|v| v as u32))
let max_tokens = profile
.and_then(|p| p.max_tokens.map(|v| v as u32))
.or_else(|| {
room_ai_config
.as_ref()
.and_then(|r| r.max_tokens.map(|v| v as u32))
})
.unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0;
let mut input_tokens = 0i64;
@ -68,39 +77,27 @@ pub async fn execute_process(
input_tokens += response.input_tokens;
output_tokens += response.output_tokens;
if tools_enabled && !response.tool_calls_finished.is_empty() {
let tool_call_messages: Vec<_> = response
.tool_calls_finished
.iter()
.map(|name| ToolCall {
id: Uuid::new_v4().to_string(),
type_: "function".into(),
function: crate::client::types::ToolCallFunction {
name: name.clone(),
arguments: "{}".into(),
},
})
.collect();
if tools_enabled && !response.tool_calls.is_empty() {
messages.push(ChatRequestMessage::assistant(
Some(text.clone()),
Some(tool_call_messages.clone()),
Some(response.tool_calls.clone()),
));
let calls: Vec<AgentToolCall> = tool_call_messages
.into_iter()
let calls: Vec<AgentToolCall> = response
.tool_calls
.iter()
.map(|tc| AgentToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments: tc.function.arguments.clone(),
})
.collect();
let tool_names: Vec<String> = calls.iter().map(|call| call.name.clone()).collect();
let tool_messages = execute_tools(
&request,
&calls,
session_id,
&response.tool_calls_finished,
tool_registry,
message_builder,
)
@ -109,7 +106,7 @@ pub async fn execute_process(
inject_passive_skills(
&request,
message_builder,
&response.tool_calls_finished,
&tool_names,
&mut messages,
)
.await;
@ -173,7 +170,6 @@ async fn execute_tools(
request: &AiRequest,
calls: &[AgentToolCall],
session_id: Uuid,
tool_names: &[String],
tool_registry: &Option<crate::tool::registry::ToolRegistry>,
message_builder: &MessageBuilder,
) -> Vec<ChatRequestMessage> {
@ -198,7 +194,7 @@ async fn execute_tools(
let executor = ToolExecutor::new();
match executor.execute_batch(calls.to_vec(), &mut ctx).await {
Ok(results) => {
for (call, result) in tool_names.iter().zip(results.iter()) {
for (call, result) in calls.iter().zip(results.iter()) {
let elapsed = start.elapsed().as_millis() as i64;
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result {
@ -208,9 +204,11 @@ async fn execute_tools(
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: Uuid::new_v4().to_string(),
session_id: recorder.session_id(),
tool_name: call.clone(),
tool_name: call.name.clone(),
caller: request.sender.uid,
arguments: serde_json::Value::Null,
arguments: call
.arguments_json()
.unwrap_or_else(|_| serde_json::Value::Null),
status: if is_error {
models::ai::ToolCallStatus::Failed
} else {
@ -226,13 +224,15 @@ async fn execute_tools(
}
Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
for call_name in tool_names {
for call in calls {
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: Uuid::new_v4().to_string(),
session_id: recorder.session_id(),
tool_name: call_name.clone(),
tool_name: call.name.clone(),
caller: request.sender.uid,
arguments: serde_json::Value::Null,
arguments: call
.arguments_json()
.unwrap_or_else(|_| serde_json::Value::Null),
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed),
error_message: Some(e.to_string()),
@ -241,7 +241,7 @@ async fn execute_tools(
});
}
let err_msg = format!("[Tool call failed: {}]", e);
tool_names
calls
.iter()
.map(|_| ChatRequestMessage::tool(Uuid::new_v4().to_string(), &err_msg))
.collect()

View File

@ -0,0 +1,320 @@
use std::collections::HashMap;
use super::agent_profile::{
analyst_profile, researcher_profile, reviewer_profile, should_enable_delegation,
supervisor_profile,
};
use super::message_builder::MessageBuilder;
use super::nonstreaming_execution::execute_process;
use super::service::{ProcessResult, StreamResult};
use super::{AiRequest, StreamCallback};
use crate::error::Result;
use crate::tool::call::ToolError;
use crate::tool::registry::ToolRegistry;
use crate::tool::{ToolDefinition, ToolHandler, ToolParam, ToolSchema};
pub async fn execute_orchestrated_process(
request: AiRequest,
message_builder: &MessageBuilder,
tool_registry: &Option<ToolRegistry>,
ai_base_url: Option<String>,
ai_api_key: Option<String>,
) -> Result<ProcessResult> {
if request
.execution_profile
.as_ref()
.is_some_and(|p| p.disable_orchestration)
{
return execute_process(
request,
message_builder,
tool_registry,
ai_base_url,
ai_api_key,
)
.await;
}
let tools = request.tools.clone().unwrap_or_default();
if !should_enable_delegation(&request.input, !tools.is_empty()) {
return execute_process(
request,
message_builder,
tool_registry,
ai_base_url,
ai_api_key,
)
.await;
}
let mut enhanced_registry = tool_registry.clone().unwrap_or_default();
register_call_sub_agent_tool(
&mut enhanced_registry,
&request,
message_builder,
tool_registry,
ai_base_url.clone(),
ai_api_key.clone(),
);
let mut supervisor_request = request.clone();
let profile = supervisor_profile();
supervisor_request.execution_profile = Some(profile.clone());
supervisor_request.tools = Some(enhanced_registry.to_openai_tools());
supervisor_request.temperature = profile.temperature.unwrap_or(request.temperature);
supervisor_request.max_tokens = profile.max_tokens.unwrap_or(request.max_tokens);
supervisor_request.top_p = profile.top_p.unwrap_or(request.top_p);
supervisor_request.frequency_penalty = profile
.frequency_penalty
.unwrap_or(request.frequency_penalty);
supervisor_request.presence_penalty = profile
.presence_penalty
.unwrap_or(request.presence_penalty);
execute_process(
supervisor_request,
message_builder,
&Some(enhanced_registry),
ai_base_url,
ai_api_key,
)
.await
}
pub async fn execute_orchestrated_stream(
request: AiRequest,
on_chunk: StreamCallback,
message_builder: &MessageBuilder,
tool_registry: &Option<ToolRegistry>,
ai_base_url: Option<String>,
ai_api_key: Option<String>,
) -> Result<StreamResult> {
if request
.execution_profile
.as_ref()
.is_some_and(|p| p.disable_orchestration)
{
return super::streaming_execution::execute_process_stream(
request,
on_chunk,
message_builder,
tool_registry,
ai_base_url,
ai_api_key,
)
.await;
}
let tools = request.tools.clone().unwrap_or_default();
if !should_enable_delegation(&request.input, !tools.is_empty()) {
return super::streaming_execution::execute_process_stream(
request,
on_chunk,
message_builder,
tool_registry,
ai_base_url,
ai_api_key,
)
.await;
}
let mut enhanced_registry = tool_registry.clone().unwrap_or_default();
register_call_sub_agent_tool(
&mut enhanced_registry,
&request,
message_builder,
tool_registry,
ai_base_url.clone(),
ai_api_key.clone(),
);
let mut supervisor_request = request.clone();
let profile = supervisor_profile();
supervisor_request.execution_profile = Some(profile.clone());
supervisor_request.tools = Some(enhanced_registry.to_openai_tools());
supervisor_request.temperature = profile.temperature.unwrap_or(request.temperature);
supervisor_request.max_tokens = profile.max_tokens.unwrap_or(request.max_tokens);
supervisor_request.top_p = profile.top_p.unwrap_or(request.top_p);
supervisor_request.frequency_penalty = profile
.frequency_penalty
.unwrap_or(request.frequency_penalty);
supervisor_request.presence_penalty = profile
.presence_penalty
.unwrap_or(request.presence_penalty);
super::streaming_execution::execute_process_stream(
supervisor_request,
on_chunk,
message_builder,
&Some(enhanced_registry),
ai_base_url,
ai_api_key,
)
.await
}
fn register_call_sub_agent_tool(
registry: &mut ToolRegistry,
request: &AiRequest,
message_builder: &MessageBuilder,
original_registry: &Option<ToolRegistry>,
ai_base_url: Option<String>,
ai_api_key: Option<String>,
) {
let captured_request = request.clone();
let captured_message_builder = message_builder.clone();
let captured_original_registry = original_registry.clone();
let captured_base_url = ai_base_url;
let captured_api_key = ai_api_key;
registry.register(
ToolDefinition::new("call_sub_agent")
.description(
"Delegate a task to a specialist sub-agent and receive its output.\n\
Available roles:\n\
- researcher: Gathers facts, evidence, and data. Best for finding information and searching code.\n\
- analyst: Builds explanations, highlights causal links and tradeoffs. Best for reasoning about implications.\n\
- reviewer: Stress-tests proposals, identifies risks and contradictions. Best for quality checks.\n\
Provide a clear, focused task description so the sub-agent knows exactly what to investigate.",
)
.parameters(ToolSchema {
schema_type: "object".into(),
properties: Some({
let mut p = HashMap::new();
p.insert(
"role".into(),
ToolParam {
name: "role".into(),
param_type: "string".into(),
description: Some(
"The sub-agent role to delegate to: researcher, analyst, or reviewer.".into(),
),
required: true,
properties: None,
items: None,
},
);
p.insert(
"task".into(),
ToolParam {
name: "task".into(),
param_type: "string".into(),
description: Some(
"The specific task or question for the sub-agent. Be precise and focused.".into(),
),
required: true,
properties: None,
items: None,
},
);
p
}),
required: Some(vec!["role".into(), "task".into()]),
}),
ToolHandler::new(move |_ctx, args| {
// Extract owned values from args before async move (avoid borrowing across boundary)
let role = args
.get("role")
.and_then(|v| v.as_str())
.unwrap_or("researcher")
.to_owned();
let task = args
.get("task")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_owned();
let profile = match role.as_str() {
"researcher" => researcher_profile(),
"analyst" => analyst_profile(),
"reviewer" => reviewer_profile(),
_ => researcher_profile(),
};
let mut sub_request = captured_request.clone();
sub_request.input = format!(
"Sub-agent role: {role}\n\nTask:\n{task}\n\nOriginal user request:\n{}\n\nInstructions:\nFocus only on your assigned task. Return concise, evidence-backed findings.",
captured_request.input
);
sub_request.execution_profile = Some(profile.clone());
sub_request.tools = Some(filter_tools_for_sub_agent(
&captured_request.tools,
&profile.allowed_tools,
));
sub_request.max_tool_depth = profile
.max_tool_depth
.unwrap_or(captured_request.max_tool_depth);
sub_request.temperature = profile.temperature.unwrap_or(captured_request.temperature);
sub_request.max_tokens = profile.max_tokens.unwrap_or(captured_request.max_tokens);
sub_request.top_p = profile.top_p.unwrap_or(captured_request.top_p);
sub_request.frequency_penalty = profile
.frequency_penalty
.unwrap_or(captured_request.frequency_penalty);
sub_request.presence_penalty = profile
.presence_penalty
.unwrap_or(captured_request.presence_penalty);
// Clone captured values for this invocation so the Fn closure retains them
let mb = captured_message_builder.clone();
let sub_registry = captured_original_registry.clone();
let base = captured_base_url.clone();
let key = captured_api_key.clone();
Box::pin(async move {
let result = execute_process(sub_request, &mb, &sub_registry, base, key).await;
match result {
Ok(r) => Ok(serde_json::json!({
"role": role,
"output": r.content,
"input_tokens": r.input_tokens,
"output_tokens": r.output_tokens,
})),
Err(e) => Err(ToolError::ExecutionError(format!(
"Sub-agent '{}' execution failed: {}",
role, e
))),
}
})
}),
);
}
/// Filter the original tool definitions by the sub-agent's allowed list,
/// always excluding `call_sub_agent` to prevent recursive delegation.
fn filter_tools_for_sub_agent(
original_tools: &Option<Vec<serde_json::Value>>,
allowed_tools: &Option<Vec<String>>,
) -> Vec<serde_json::Value> {
let Some(tools) = original_tools else {
return Vec::new();
};
let allowed = allowed_tools
.as_ref()
.map(|list| list.iter().filter(|n| *n != "call_sub_agent").cloned().collect::<Vec<String>>());
match allowed {
Some(allowed_list) if !allowed_list.is_empty() => tools
.iter()
.filter(|tool| {
let name = tool
.get("function")
.and_then(|f| f.get("name"))
.and_then(|v| v.as_str())
.unwrap_or("");
allowed_list.iter().any(|allowed| allowed == name)
})
.cloned()
.collect(),
_ => tools
.iter()
.filter(|tool| {
tool
.get("function")
.and_then(|f| f.get("name"))
.and_then(|v| v.as_str())
.is_some_and(|name| name != "call_sub_agent")
})
.cloned()
.collect(),
}
}

View File

@ -128,7 +128,7 @@ impl ChatService {
/// Process AI request without streaming (tool-call loop with non-streaming API).
pub async fn process(&self, request: AiRequest) -> Result<ProcessResult> {
super::nonstreaming_execution::execute_process(
super::orchestrator::execute_orchestrated_process(
request,
&self.message_builder,
&self.tool_registry,
@ -144,7 +144,7 @@ impl ChatService {
request: AiRequest,
on_chunk: StreamCallback,
) -> Result<StreamResult> {
super::streaming_execution::execute_process_stream(
super::orchestrator::execute_orchestrated_stream(
request,
on_chunk,
&self.message_builder,
@ -155,6 +155,59 @@ impl ChatService {
.await
}
/// Process AI request for room context — direct execution path (bypasses orchestrator).
///
/// Room AI uses a fast single-agent loop: all tools available, no multi-agent delegation.
/// Merges `room_tools` (send_message, retract_message) into the base registry,
/// then runs `execute_process` / `execute_process_stream` directly.
pub async fn process_room(
&self,
request: AiRequest,
room_tools: ToolRegistry,
) -> Result<ProcessResult> {
let mut merged = self
.tool_registry
.clone()
.unwrap_or_default();
merged.merge(room_tools);
super::nonstreaming_execution::execute_process(
request,
&self.message_builder,
&Some(merged),
self.ai_base_url.clone(),
self.ai_api_key.clone(),
)
.await
}
/// Process AI request for room context with streaming — direct execution path.
///
/// Same as `process_room` but with streaming response. Bypasses orchestrator,
/// gives the room AI all tools (base + room) for fast single-agent execution.
pub async fn process_room_stream(
&self,
request: AiRequest,
on_chunk: StreamCallback,
room_tools: ToolRegistry,
) -> Result<StreamResult> {
let mut merged = self
.tool_registry
.clone()
.unwrap_or_default();
merged.merge(room_tools);
super::streaming_execution::execute_process_stream(
request,
on_chunk,
&self.message_builder,
&Some(merged),
self.ai_base_url.clone(),
self.ai_api_key.clone(),
)
.await
}
/// Process AI request via rig-based ReAct streaming loop.
pub async fn process_react<C, Fut>(
&self,

View File

@ -9,9 +9,9 @@ use super::message_builder::MessageBuilder;
use super::service::StreamResult;
use super::session_recording::record_ai_session;
use super::{AiChunkType, AiRequest, AiStreamChunk, StreamCallback};
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream};
use crate::client::AiClientConfig;
use crate::client::{call_stream, StreamChunk, StreamChunkType, StreamedToolCall};
use crate::error::Result;
use crate::perception::{SkillEntry, ToolCallEvent};
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor};
@ -42,13 +42,22 @@ pub async fn execute_process_stream(
.await?;
let model_name = request.model.name.clone();
let temperature = room_ai_config
.as_ref()
.and_then(|r| r.temperature.map(|v| v as f32))
let profile = request.execution_profile.as_ref();
let temperature = profile
.and_then(|p| p.temperature.map(|v| v as f32))
.or_else(|| {
room_ai_config
.as_ref()
.and_then(|r| r.temperature.map(|v| v as f32))
})
.unwrap_or(request.temperature as f32);
let max_tokens = room_ai_config
.as_ref()
.and_then(|r| r.max_tokens.map(|v| v as u32))
let max_tokens = profile
.and_then(|p| p.max_tokens.map(|v| v as u32))
.or_else(|| {
room_ai_config
.as_ref()
.and_then(|r| r.max_tokens.map(|v| v as u32))
})
.unwrap_or(request.max_tokens as u32);
let mut tool_depth = 0;
let mut total_input_tokens = 0i64;
@ -84,6 +93,7 @@ pub async fn execute_process_stream(
done: false,
chunk_type: AiChunkType::Answer,
metadata: None,
children_id: None,
});
fut
}),
@ -93,6 +103,7 @@ pub async fn execute_process_stream(
done: false,
chunk_type: AiChunkType::Thinking,
metadata: None,
children_id: None,
});
fut
}),
@ -188,6 +199,7 @@ pub async fn execute_process_stream(
done: true,
chunk_type: AiChunkType::Answer,
metadata: None,
children_id: None,
})
.await;
all_chunks.push(StreamChunk {
@ -253,8 +265,8 @@ async fn drain_tool_call_notifications(
done: false,
chunk_type: AiChunkType::ToolCall,
metadata: Some(metadata),
})
.await;
children_id: None,
}).await;
all_chunks.push(StreamChunk {
chunk_type: StreamChunkType::ToolCall,
content: tool_display,
@ -356,6 +368,7 @@ async fn execute_streaming_tools(
done: false,
chunk_type: AiChunkType::ToolResult,
metadata: Some(metadata),
children_id: Some(call.id.clone()),
}).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
@ -388,6 +401,7 @@ async fn execute_streaming_tools(
done: false,
chunk_type: AiChunkType::ToolResult,
metadata: Some(metadata),
children_id: None,
}).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
@ -396,7 +410,7 @@ async fn execute_streaming_tools(
}
},
_ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall, metadata: None }).await;
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall, metadata: None, children_id: None }).await;
}
}
}

View File

@ -89,6 +89,7 @@ pub struct AiCallResponse {
pub input_tokens: i64,
pub output_tokens: i64,
pub latency_ms: i64,
pub tool_calls: Vec<ClientToolCall>,
pub tool_calls_finished: Vec<String>,
}
@ -275,7 +276,7 @@ async fn do_completion<M>(
max_tokens: Option<u32>,
tools: Option<&[serde_json::Value]>,
tool_choice: Option<&str>,
) -> Result<(String, u64, u64, Vec<String>)>
) -> Result<(String, u64, u64, Vec<ClientToolCall>, Vec<String>)>
where
M: CompletionModel<Client = openai::Client>,
{
@ -342,6 +343,7 @@ where
let mut content = String::new();
let mut tool_names: Vec<String> = Vec::new();
let mut tool_calls: Vec<ClientToolCall> = Vec::new();
for item in response.choice {
match item {
AssistantContent::Text(t) => {
@ -349,6 +351,15 @@ where
}
AssistantContent::ToolCall(tc) => {
tool_names.push(tc.function.name.clone());
tool_calls.push(ClientToolCall {
id: tc.id,
type_: "function".into(),
function: types::ToolCallFunction {
name: tc.function.name,
arguments: serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
},
});
}
AssistantContent::Reasoning(_) => {}
AssistantContent::Image(_) => {}
@ -358,7 +369,7 @@ where
let input_tokens = response.usage.input_tokens;
let output_tokens = response.usage.output_tokens;
Ok((content, input_tokens, output_tokens, tool_names))
Ok((content, input_tokens, output_tokens, tool_calls, tool_names))
}
// ── Public API ───────────────────────────────────────────────────────────────
@ -380,7 +391,7 @@ pub async fn call_with_retry(
let result = do_completion(&model, messages, None, None, None, None).await;
match result {
Ok((content, input_tokens, output_tokens, tool_names)) => {
Ok((content, input_tokens, output_tokens, tool_calls, tool_names)) => {
let latency_ms = start.elapsed().as_millis() as i64;
let has_function_call = !tool_names.is_empty();
ai_metrics().record_success(
@ -393,6 +404,7 @@ pub async fn call_with_retry(
input_tokens: input_tokens as i64,
output_tokens: output_tokens as i64,
latency_ms,
tool_calls,
tool_calls_finished: tool_names,
});
}
@ -446,7 +458,7 @@ pub async fn call_with_params(
.await;
match result {
Ok((content, input_tokens, output_tokens, tool_names)) => {
Ok((content, input_tokens, output_tokens, tool_calls, tool_names)) => {
let latency_ms = start.elapsed().as_millis() as i64;
let has_function_call = !tool_names.is_empty();
ai_metrics().record_success(
@ -459,6 +471,7 @@ pub async fn call_with_params(
input_tokens: input_tokens as i64,
output_tokens: output_tokens as i64,
latency_ms,
tool_calls,
tool_calls_finished: tool_names,
});
}
@ -500,6 +513,7 @@ pub enum StreamChunkType {
Thinking,
Answer,
ToolCall,
ToolResult,
}
/// A single chunk from the streaming response in arrival order.

View File

@ -18,6 +18,7 @@ pub struct CompactService {
db: DatabaseConnection,
ai_client_config: crate::client::AiClientConfig,
model: String,
model_context_limit: Option<usize>,
}
impl CompactService {
@ -30,6 +31,7 @@ impl CompactService {
db,
ai_client_config,
model,
model_context_limit: None,
}
}
@ -38,6 +40,17 @@ impl CompactService {
db: self.db.clone(),
ai_client_config: self.ai_client_config.clone(),
model: model.into(),
model_context_limit: self.model_context_limit,
}
}
pub fn with_model_context_limit(mut self, model_context_limit: Option<usize>) -> Self {
self.model_context_limit = model_context_limit.filter(|limit| *limit > 0);
self
}
pub fn for_model_entry(&self, model: &models::agents::model::Model) -> Self {
self.for_model(model.name.clone())
.with_model_context_limit(Some(model.context_length.max(0) as usize))
}
}

View File

@ -141,6 +141,15 @@ impl super::CompactService {
.collect()
}
fn resolve_retain_count(config: CompactConfig, estimated_tokens: usize) -> usize {
let level = if config.auto_level {
CompactLevel::auto_select(estimated_tokens, config.token_threshold)
} else {
config.default_level
};
level.retain_count()
}
pub async fn prepare_room_compact_context(
&self,
room_id: uuid::Uuid,
@ -188,10 +197,10 @@ impl super::CompactService {
let estimated_tokens = crate::tokent::count_message_text(&estimate_input, &self.model)
.unwrap_or_else(|_| estimate_input.len() / 4);
let retain_count = Self::resolve_retain_count(config, estimated_tokens);
if estimated_tokens >= config.token_threshold
&& messages.len() > config.default_level.retain_count()
&& messages.len() > retain_count
{
let retain_count = config.default_level.retain_count();
let split_index = messages.len().saturating_sub(retain_count);
let (to_summarize, retained_messages) = messages.split_at(split_index);
let from_seq = to_summarize
@ -303,8 +312,10 @@ impl super::CompactService {
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
let max_summary_tokens =
(context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
let max_summary_tokens = CompactConfig::summary_token_budget(
context_window_tokens.max(0) as usize,
compaction_max_summary_ratio,
);
let (summary, remote_usage) = self
.summarize_messages(to_summarize, max_summary_tokens)
@ -384,8 +395,10 @@ impl super::CompactService {
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
let max_summary_tokens =
(context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
let max_summary_tokens = CompactConfig::summary_token_budget(
context_window_tokens.max(0) as usize,
compaction_max_summary_ratio,
);
let (summary, remote_usage) = self
.summarize_messages(to_summarize, max_summary_tokens)

View File

@ -2,11 +2,22 @@ use models::rooms::room_message::Model as RoomMessageModel;
use models::users::user::{Column as UserCol, Entity as User};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use crate::AgentError;
use crate::client::call_with_params;
use crate::client::types::ChatRequestMessage;
use crate::compact::types::MessageSummary;
use crate::tokent::TokenUsage;
use crate::compact::types::{CompactConfig, MessageSummary};
use crate::tokent::{TokenUsage, count_message_text};
use crate::AgentError;
const DEFAULT_MODEL_CONTEXT_LIMIT: usize = 128_000;
const MODEL_INPUT_RATIO_NUMERATOR: usize = 85;
const MODEL_INPUT_RATIO_DENOMINATOR: usize = 100;
const MIN_ROUND_SUMMARY_TOKENS: usize = 64;
#[derive(Clone, Copy)]
enum SummaryKind {
Conversation,
RoomIncrement,
}
impl super::CompactService {
pub async fn summarize_room_increment(
@ -23,52 +34,28 @@ impl super::CompactService {
.collect();
let user_name_map = self.get_user_name_map(&user_ids).await?;
let sender_mapper = |m: &RoomMessageModel| {
if let Some(user_id) = m.sender_id {
if let Some(username) = user_name_map.get(&user_id) {
return username.clone();
}
}
m.sender_type.to_string()
};
let blocks = messages
.iter()
.map(|m| {
let sender = if let Some(user_id) = m.sender_id {
user_name_map
.get(&user_id)
.cloned()
.unwrap_or_else(|| m.sender_type.to_string())
} else {
m.sender_type.to_string()
};
format!("[{}] {}: {}", m.send_at, sender, m.content)
})
.collect::<Vec<_>>();
let body = crate::compact::helpers::messages_to_text(messages, sender_mapper);
let previous = previous_summary
.filter(|s| !s.trim().is_empty())
.map(|s| format!("Previous compressed room summary:\n{}\n\n", s.trim()))
.unwrap_or_default();
let user_msg = ChatRequestMessage::user(format!(
"Create an incremental room summary. Start from the previous summary if present, \
then merge the new messages below. Deduplicate repeated messages, clean noise, \
keep chronological order, and preserve decisions, facts, assignments/owners, \
unresolved questions, and concrete next steps. The result MUST NOT exceed {} tokens.\n\n\
Format:\n\
**Summary:** <compact overview>\n\
**Decisions:** <bullets or 'none'>\n\
**Owners:** <bullets with owner -> task or 'none'>\n\
**Open items:** <bullets or 'none'>\n\n\
{}New messages:\n\n{}",
max_summary_tokens, previous, body
));
let response = call_with_params(
&[user_msg],
&self.model,
&self.ai_client_config,
0.2,
max_summary_tokens.min(4096) as u32,
None,
None,
None,
self.summarize_blocks_with_optional_previous(
blocks,
previous_summary,
max_summary_tokens,
SummaryKind::RoomIncrement,
)
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
let remote_usage =
TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);
Ok((response.content, remote_usage))
}
pub async fn summarize_messages(
@ -84,47 +71,28 @@ impl super::CompactService {
.collect();
let user_name_map = self.get_user_name_map(&user_ids).await?;
let blocks = messages
.iter()
.map(|m| {
let sender = if let Some(user_id) = m.sender_id {
user_name_map
.get(&user_id)
.cloned()
.unwrap_or_else(|| m.sender_type.to_string())
} else {
m.sender_type.to_string()
};
format!("[{}] {}: {}", m.send_at, sender, m.content)
})
.collect::<Vec<_>>();
let sender_mapper = |m: &RoomMessageModel| {
if let Some(user_id) = m.sender_id {
if let Some(username) = user_name_map.get(&user_id) {
return username.clone();
}
}
m.sender_type.to_string()
};
let body = crate::compact::helpers::messages_to_text(messages, sender_mapper);
let user_msg = ChatRequestMessage::user(format!(
"Summarise the following conversation concisely, preserving all key facts, \
decisions, and any pending or in-progress work. \
The summary MUST NOT exceed {} tokens. \
Use this format:\n\n\
**Summary:** <one-paragraph overview>\n\
**Key decisions:** <bullet list or 'none'>\n\
**Open items:** <bullet list or 'none'>\n\n\
Conversation:\n\n{}",
max_summary_tokens, body
));
let response = call_with_params(
&[user_msg],
&self.model,
&self.ai_client_config,
0.3,
2048,
None,
None,
self.summarize_blocks_with_optional_previous(
blocks,
None,
max_summary_tokens,
SummaryKind::Conversation,
)
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
let remote_usage =
TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);
Ok((response.content, remote_usage))
}
pub fn message_to_summary(
@ -169,4 +137,377 @@ impl super::CompactService {
}
Ok(map)
}
async fn summarize_blocks_with_optional_previous(
&self,
blocks: Vec<String>,
previous_summary: Option<&str>,
max_summary_tokens: usize,
kind: SummaryKind,
) -> Result<(String, Option<TokenUsage>), AgentError> {
let final_budget = Self::final_summary_budget(max_summary_tokens);
let input_budget = self.safe_model_input_budget();
let round_budget = Self::round_summary_budget(final_budget, input_budget);
let mut total_usage = TokenUsage::default();
let mut has_usage = false;
let fitted_chunks =
self.split_blocks_to_fit(blocks, input_budget, round_budget, kind, false)?;
let mut partial_summaries = Vec::new();
for chunk in fitted_chunks {
let prompt = self.build_prompt(kind, false, &chunk, round_budget);
let (summary, usage) = self
.invoke_summary_prompt(&prompt, round_budget, Self::temperature_for(kind))
.await?;
Self::accumulate_usage(&mut total_usage, &mut has_usage, usage);
partial_summaries.push(summary);
}
if let Some(previous) = previous_summary
.map(str::trim)
.filter(|summary| !summary.is_empty())
{
partial_summaries.insert(0, previous.to_string());
}
if partial_summaries.is_empty() {
return Ok((String::new(), None));
}
if partial_summaries.len() == 1 && previous_summary.is_none() {
return Ok((
partial_summaries.remove(0),
if has_usage { Some(total_usage) } else { None },
));
}
let final_summary = self
.merge_summary_rounds(
partial_summaries,
final_budget,
round_budget,
kind,
&mut total_usage,
&mut has_usage,
)
.await?;
Ok((final_summary, if has_usage { Some(total_usage) } else { None }))
}
async fn merge_summary_rounds(
&self,
mut summaries: Vec<String>,
final_budget: usize,
round_budget: usize,
kind: SummaryKind,
total_usage: &mut TokenUsage,
has_usage: &mut bool,
) -> Result<String, AgentError> {
let input_budget = self.safe_model_input_budget();
while summaries.len() > 1 {
let current_budget = if summaries.len() <= 2 {
final_budget
} else {
round_budget
};
let mut next_round = Vec::new();
let mut idx = 0usize;
while idx < summaries.len() {
if idx + 1 >= summaries.len() {
next_round.push(summaries[idx].clone());
idx += 1;
continue;
}
let pair = vec![summaries[idx].clone(), summaries[idx + 1].clone()];
let fitted_pairs =
self.split_blocks_to_fit(pair, input_budget, current_budget, kind, true)?;
for pair_text in fitted_pairs {
let prompt = self.build_prompt(kind, true, &pair_text, current_budget);
let (summary, usage) = self
.invoke_summary_prompt(
&prompt,
current_budget,
Self::temperature_for(kind),
)
.await?;
Self::accumulate_usage(total_usage, has_usage, usage);
next_round.push(summary);
}
idx += 2;
}
summaries = next_round;
}
summaries
.pop()
.ok_or_else(|| AgentError::Internal("summary merge produced no output".into()))
}
async fn invoke_summary_prompt(
&self,
prompt: &str,
max_summary_tokens: usize,
temperature: f32,
) -> Result<(String, Option<TokenUsage>), AgentError> {
let response = call_with_params(
&[ChatRequestMessage::user(prompt.to_string())],
&self.model,
&self.ai_client_config,
temperature,
max_summary_tokens as u32,
None,
None,
None,
)
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
let usage =
TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);
Ok((response.content, usage))
}
fn split_blocks_to_fit(
&self,
blocks: Vec<String>,
input_budget: usize,
max_summary_tokens: usize,
kind: SummaryKind,
is_merge: bool,
) -> Result<Vec<String>, AgentError> {
let mut chunks = Vec::new();
self.collect_fitting_chunks(
blocks,
input_budget,
max_summary_tokens,
kind,
is_merge,
&mut chunks,
)?;
Ok(chunks)
}
fn collect_fitting_chunks(
&self,
blocks: Vec<String>,
input_budget: usize,
max_summary_tokens: usize,
kind: SummaryKind,
is_merge: bool,
chunks: &mut Vec<String>,
) -> Result<(), AgentError> {
let body = Self::join_blocks(&blocks, is_merge);
let prompt = self.build_prompt(kind, is_merge, &body, max_summary_tokens);
if self.estimate_tokens(&prompt) <= input_budget {
chunks.push(body);
return Ok(());
}
if blocks.len() > 1 {
let mid = blocks.len() / 2;
self.collect_fitting_chunks(
blocks[..mid].to_vec(),
input_budget,
max_summary_tokens,
kind,
is_merge,
chunks,
)?;
self.collect_fitting_chunks(
blocks[mid..].to_vec(),
input_budget,
max_summary_tokens,
kind,
is_merge,
chunks,
)?;
return Ok(());
}
let single = blocks
.into_iter()
.next()
.ok_or_else(|| AgentError::Internal("cannot split empty summary block".into()))?;
let (left, right) = Self::split_text_in_half(&single)?;
self.collect_fitting_chunks(
vec![left],
input_budget,
max_summary_tokens,
kind,
is_merge,
chunks,
)?;
self.collect_fitting_chunks(
vec![right],
input_budget,
max_summary_tokens,
kind,
is_merge,
chunks,
)?;
Ok(())
}
fn build_prompt(
&self,
kind: SummaryKind,
is_merge: bool,
body: &str,
max_summary_tokens: usize,
) -> String {
match (kind, is_merge) {
(SummaryKind::Conversation, false) => format!(
"Summarise the following conversation concisely, preserving all key facts, \
decisions, and any pending or in-progress work. \
The summary MUST NOT exceed {} tokens. \
Use this format:\n\n\
**Summary:** <one-paragraph overview>\n\
**Key decisions:** <bullet list or 'none'>\n\
**Open items:** <bullet list or 'none'>\n\n\
Conversation:\n\n{}",
max_summary_tokens, body
),
(SummaryKind::Conversation, true) => format!(
"Merge the following partial conversation summaries into a single concise summary. \
Deduplicate overlap, preserve chronology, and keep all concrete decisions, \
status updates, and unresolved work. The summary MUST NOT exceed {} tokens. \
Use this format:\n\n\
**Summary:** <one-paragraph overview>\n\
**Key decisions:** <bullet list or 'none'>\n\
**Open items:** <bullet list or 'none'>\n\n\
Partial summaries:\n\n{}",
max_summary_tokens, body
),
(SummaryKind::RoomIncrement, false) => format!(
"Create an incremental room summary from the new messages below. \
Deduplicate repeated messages, clean noise, keep chronological order, and preserve \
decisions, facts, assignments/owners, unresolved questions, and concrete next steps. \
The result MUST NOT exceed {} tokens.\n\n\
Format:\n\
**Summary:** <compact overview>\n\
**Decisions:** <bullets or 'none'>\n\
**Owners:** <bullets with owner -> task or 'none'>\n\
**Open items:** <bullets or 'none'>\n\n\
New messages:\n\n{}",
max_summary_tokens, body
),
(SummaryKind::RoomIncrement, true) => format!(
"Merge the following partial room summaries into one room summary. Deduplicate overlap, \
keep chronology, preserve decisions, facts, assignments/owners, unresolved questions, \
and concrete next steps. The result MUST NOT exceed {} tokens.\n\n\
Format:\n\
**Summary:** <compact overview>\n\
**Decisions:** <bullets or 'none'>\n\
**Owners:** <bullets with owner -> task or 'none'>\n\
**Open items:** <bullets or 'none'>\n\n\
Partial summaries:\n\n{}",
max_summary_tokens, body
),
}
}
fn join_blocks(blocks: &[String], is_merge: bool) -> String {
if is_merge {
blocks
.iter()
.enumerate()
.map(|(index, block)| format!("### Partial Summary {}\n{}", index + 1, block))
.collect::<Vec<_>>()
.join("\n\n")
} else {
blocks.join("\n")
}
}
fn split_text_in_half(text: &str) -> Result<(String, String), AgentError> {
if text.chars().count() < 2 {
return Err(AgentError::Internal(
"single summary block exceeds input budget and cannot be split".into(),
));
}
let midpoint = text.len() / 2;
let mut split_at = text.floor_char_boundary(midpoint);
if split_at == 0 || split_at >= text.len() {
split_at = text.ceil_char_boundary(midpoint);
}
if split_at == 0 || split_at >= text.len() {
return Err(AgentError::Internal(
"failed to split oversized summary block".into(),
));
}
Ok((text[..split_at].to_string(), text[split_at..].to_string()))
}
fn estimate_tokens(&self, text: &str) -> usize {
count_message_text(text, &self.model).unwrap_or_else(|_| (text.len() / 4).max(1))
}
fn safe_model_input_budget(&self) -> usize {
Self::safe_model_input_budget_from_limit(self.model_context_limit)
}
fn final_summary_budget(max_summary_tokens: usize) -> usize {
max_summary_tokens.clamp(
CompactConfig::MIN_SUMMARY_TOKENS,
CompactConfig::MAX_SUMMARY_TOKENS,
)
}
fn round_summary_budget(final_budget: usize, input_budget: usize) -> usize {
final_budget.min((input_budget / 8).max(MIN_ROUND_SUMMARY_TOKENS))
}
fn temperature_for(kind: SummaryKind) -> f32 {
match kind {
SummaryKind::Conversation => 0.3,
SummaryKind::RoomIncrement => 0.2,
}
}
fn safe_model_input_budget_from_limit(model_context_limit: Option<usize>) -> usize {
let context_limit = model_context_limit.unwrap_or(DEFAULT_MODEL_CONTEXT_LIMIT).max(1);
context_limit
.saturating_mul(MODEL_INPUT_RATIO_NUMERATOR)
.saturating_div(MODEL_INPUT_RATIO_DENOMINATOR)
.max(1)
}
fn accumulate_usage(
total: &mut TokenUsage,
has_usage: &mut bool,
usage: Option<TokenUsage>,
) {
if let Some(usage) = usage {
total.input_tokens += usage.input_tokens;
total.output_tokens += usage.output_tokens;
*has_usage = true;
}
}
}
#[cfg(test)]
mod tests {
use super::super::CompactService;
#[test]
fn room_summary_uses_eighty_five_percent_input_budget() {
assert_eq!(CompactService::safe_model_input_budget_from_limit(Some(1000)), 850);
}
#[test]
fn oversized_text_is_split_in_half() {
let (left, right) = CompactService::split_text_in_half("abcdefgh").unwrap();
assert_eq!(format!("{}{}", left, right), "abcdefgh");
assert!(!left.is_empty());
assert!(!right.is_empty());
}
}

View File

@ -112,19 +112,39 @@ impl Default for CompactConfig {
}
impl CompactConfig {
pub const MIN_SUMMARY_TOKENS: usize = 256;
pub const MAX_SUMMARY_TOKENS: usize = 4096;
/// Build config from project context settings.
pub fn from_project_setting(
context_window_tokens: i32,
compaction_threshold: f32,
compaction_max_summary_ratio: f32,
) -> Self {
let threshold = (context_window_tokens as f32 * compaction_threshold) as usize;
let context_window_tokens = context_window_tokens.max(0) as usize;
let threshold = (context_window_tokens as f32 * compaction_threshold.max(0.0)) as usize;
Self {
token_threshold: threshold,
auto_level: true,
default_level: CompactLevel::Light,
max_summary_tokens: (context_window_tokens as f32 * compaction_max_summary_ratio)
as usize,
max_summary_tokens: Self::summary_token_budget(
context_window_tokens,
compaction_max_summary_ratio,
),
}
}
pub fn summary_token_budget(
context_window_tokens: usize,
compaction_max_summary_ratio: f32,
) -> usize {
let ratio = compaction_max_summary_ratio.max(0.0);
let raw_budget = (context_window_tokens as f32 * ratio) as usize;
if raw_budget == 0 {
Self::MIN_SUMMARY_TOKENS
} else {
raw_budget.clamp(Self::MIN_SUMMARY_TOKENS, Self::MAX_SUMMARY_TOKENS)
}
}
}
@ -170,3 +190,20 @@ impl MessageSummary {
}
}
}
#[cfg(test)]
mod tests {
use super::CompactConfig;
#[test]
fn summary_budget_has_minimum_floor() {
assert_eq!(CompactConfig::summary_token_budget(0, 0.0), 256);
assert_eq!(CompactConfig::summary_token_budget(128_000, 0.0), 256);
assert_eq!(CompactConfig::summary_token_budget(1_000, 0.01), 256);
}
#[test]
fn summary_budget_is_capped() {
assert_eq!(CompactConfig::summary_token_budget(128_000, 0.2), 4096);
}
}

View File

@ -15,12 +15,12 @@ pub mod task;
pub mod tokent;
pub mod tool;
pub use billing::{
BillingRecord, BillingResult, check_balance, initialize_project_billing,
initialize_user_billing, persist_billing_error, record_ai_usage,
BillingRecord, BillingResult, check_balance, check_user_balance, initialize_project_billing,
initialize_user_billing, persist_billing_error, record_ai_usage, record_user_ai_usage,
};
pub use chat::{
AiContextSenderType, AiRequest, AiStreamChunk, ChatService, Mention, RoomMessageContext,
StreamCallback,
AgentExecutionProfile, AgentRole, AiContextSenderType, AiRequest, AiStreamChunk,
ChatService, Mention, RoomMessageContext, StreamCallback,
};
pub use client::types::ChatRequestMessage;
pub use client::{AiCallResponse, AiClientConfig, call_with_params, call_with_retry};

View File

@ -4,6 +4,7 @@
//! request metadata, and the tool registry. Cheap to clone via `Arc`.
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use config::AppConfig;
use db::cache::AppCache;
@ -40,7 +41,7 @@ struct Inner {
pub sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<Uuid>>>,
depth: u32,
max_depth: u32,
tool_call_count: usize,
tool_call_count: Arc<AtomicUsize>,
max_tool_calls: usize,
}
@ -68,7 +69,7 @@ impl ToolContext {
sent_in_turn: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
depth: 0,
max_depth: 5,
tool_call_count: 0,
tool_call_count: Arc::new(AtomicUsize::new(0)),
max_tool_calls: 128,
}),
}
@ -145,7 +146,7 @@ impl ToolContext {
}
pub fn tool_calls_exceeded(&self) -> bool {
self.inner.tool_call_count >= self.inner.max_tool_calls
self.inner.tool_call_count.load(Ordering::Relaxed) >= self.inner.max_tool_calls
}
/// Current recursion depth.
@ -155,12 +156,37 @@ impl ToolContext {
/// Current tool call count.
pub fn tool_call_count(&self) -> usize {
self.inner.tool_call_count
self.inner.tool_call_count.load(Ordering::Relaxed)
}
/// Increments the tool call count.
pub(crate) fn increment_tool_calls(&mut self) {
Arc::make_mut(&mut self.inner).tool_call_count += 1;
/// Reserves a number of tool calls for this shared execution context.
pub(crate) fn reserve_tool_calls(&self, additional: usize) -> Result<(), usize> {
if additional == 0 {
return Ok(());
}
let current = &self.inner.tool_call_count;
let max_tool_calls = self.inner.max_tool_calls;
loop {
let existing = current.load(Ordering::Relaxed);
let Some(next) = existing.checked_add(additional) else {
return Err(existing);
};
if next > max_tool_calls {
return Err(existing);
}
match current.compare_exchange(
existing,
next,
Ordering::AcqRel,
Ordering::Relaxed,
) {
Ok(_) => return Ok(()),
Err(_) => continue,
}
}
}
/// Returns a child context for a recursive tool call (depth + 1).

View File

@ -55,16 +55,20 @@ impl ToolExecutor {
calls: Vec<ToolCall>,
ctx: &mut ToolContext,
) -> Result<Vec<ToolCallResult>, ToolError> {
if ctx.tool_calls_exceeded() {
return Err(ToolError::MaxToolCallsExceeded(ctx.tool_call_count()));
}
let ctx = ctx
.clone()
.with_max_tool_calls(self.max_tool_calls)
.with_max_depth(self.max_depth);
if ctx.recursion_exceeded() {
return Err(ToolError::RecursionLimitExceeded {
max_depth: ctx.depth(),
max_depth: self.max_depth,
});
}
ctx.increment_tool_calls();
if let Err(current) = ctx.reserve_tool_calls(calls.len()) {
return Err(ToolError::MaxToolCallsExceeded(current));
}
let concurrency = self.max_concurrency;
let calls_clone: Vec<ToolCall> = calls.clone();

View File

@ -226,9 +226,9 @@ pub async fn message_stream(
let user_id = get_user_id(&session)?;
let (conversation_id, message_id) = path.into_inner();
// Verify user owns the conversation
// Streaming triggers AI execution and billing, so view-only access is not enough.
let conv = service
.find_conversation_owned(conversation_id, user_id)
.find_conversation_full_access(conversation_id, user_id)
.await?;
let model = conv.model;

View File

@ -2,9 +2,19 @@ use actix_web::web;
pub mod handlers;
pub mod stream;
pub mod subagent;
pub mod watch;
pub fn init_chat_routes(cfg: &mut web::ServiceConfig) {
cfg.route(
"/ai/subagent/{conversation_id}/{children_id}/stream",
web::get().to(subagent::subagent_stream_watch),
)
.route(
"/ai/subagent/{conversation_id}/{children_id}/stop",
web::post().to(subagent::subagent_stop),
);
cfg.service(
web::scope("/ai/conversations")
.route(
@ -28,6 +38,14 @@ pub fn init_chat_routes(cfg: &mut web::ServiceConfig) {
"/{conversation_id}/watch",
web::get().to(watch::conversation_watch),
)
.route(
"/subagent/{conversation_id}/{children_id}/stream",
web::get().to(subagent::subagent_stream_watch),
)
.route(
"/subagent/{conversation_id}/{children_id}/stop",
web::post().to(subagent::subagent_stop),
)
.route(
"/{conversation_id}/share",
web::post().to(handlers::share::conversation_share),

View File

@ -1,8 +1,8 @@
use agent::chat::chat_execution;
use agent::chat::{AiChunkType, AiStreamChunk, normalize_thinking_content};
use agent::client::AiClientConfig;
use agent::client::StreamChunkType;
use agent::client::types::ChatRequestMessage;
use agent::client::{StreamChunk, StreamChunkType};
use agent::react::PERSONAL_CONTEXT_PROMPT;
use futures::StreamExt;
use models::agents::{model, model_version};
@ -22,7 +22,7 @@ use uuid::Uuid;
///
/// Also publishes chat messages and stream chunks via NATS JetStream for
/// multi-viewer support. The requesting client receives SSE events, while
/// other viewers receive chunks via NATS WebSocket broadcast.
/// other viewers receive chunks via NATS -> WebSocket broadcast.
pub fn create_chat_sse_stream(
service: AppService,
conversation_id: Uuid,
@ -155,45 +155,103 @@ pub fn create_chat_sse_stream(
messages
};
// Pre-flight balance check: verify project + user can afford at least a minimal AI call
if !is_personal {
let balance_ok = agent::billing::check_balance(
let (model_record, billing_version_id) = match model::Entity::find()
.filter(model::Column::Name.eq(&model_name))
.one(service.db.reader())
.await
{
Ok(Some(m)) => {
let version_id = model_version::Entity::find()
.filter(model_version::Column::ModelId.eq(m.id))
.filter(model_version::Column::Status.eq("active"))
.order_by_desc(model_version::Column::IsDefault)
.order_by_desc(model_version::Column::ReleaseDate)
.one(service.db.reader())
.await
.ok()
.flatten()
.map(|v| v.id);
match version_id {
Some(version_id) => (m, version_id),
None => {
let error_msg = "AI model version is not configured. Please configure an active model version before using AI.";
let payload = serde_json::json!({"event":"billing_error","data":error_msg});
let _ = tx.send(format!("data: {}\n\n", payload)).await;
let _ = tx
.send(
"data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n"
.to_string(),
)
.await;
return;
}
}
}
_ => {
let error_msg = "AI model is not configured. Please sync or configure the model before using AI.";
let payload = serde_json::json!({"event":"billing_error","data":error_msg});
let _ = tx.send(format!("data: {}\n\n", payload)).await;
let _ = tx
.send("data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n".to_string())
.await;
return;
}
};
// Pre-flight balance check: verify the selected account can afford a minimal AI call.
let balance_ok = if is_personal {
agent::billing::check_user_balance(&service.db, user_id, billing_version_id, 500, 250)
.await
} else {
agent::billing::check_balance(
&service.db,
project_id,
user_id,
Uuid::nil(),
billing_version_id,
500,
250,
)
.await;
.await
};
match balance_ok {
Ok(true) => {}
Ok(false) => {
tracing::warn!(project_id = %project_id, user_id = %user_id, "Insufficient balance for chat AI call");
match balance_ok {
Ok(true) => {}
Ok(false) => {
tracing::warn!(project_id = %project_id, user_id = %user_id, personal = is_personal, "Insufficient balance for chat AI call");
let _ = agent::billing::persist_billing_error(
&service.db, "user", user_id, "insufficient_balance",
&format!("Insufficient balance. Your account does not have enough funds for this AI request."),
Some(serde_json::json!({
"user_id": user_id.to_string(),
"project_id": project_id.to_string(),
})),
).await;
let (scope, scope_id) = if is_personal {
("user", user_id)
} else {
("project", project_id)
};
let _ = agent::billing::persist_billing_error(
&service.db, scope, scope_id, "insufficient_balance",
"Insufficient balance. Your account does not have enough funds for this AI request.",
Some(serde_json::json!({
"user_id": user_id.to_string(),
"project_id": if is_personal { None } else { Some(project_id.to_string()) },
"model_version_id": billing_version_id.to_string(),
})),
).await;
let error_msg = "Insufficient balance. Your account does not have enough funds to process this AI request. Please add credits to continue.";
let payload = serde_json::json!({"event":"billing_error","data":error_msg});
let _ = tx.send(format!("data: {}\n\n", payload)).await;
let _ = tx
.send(
"data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n".to_string(),
)
.await;
return;
}
Err(e) => {
tracing::warn!(error = %e, "Balance check failed, proceeding without pre-flight check");
}
let error_msg = "Insufficient balance. Your account does not have enough funds to process this AI request. Please add credits to continue.";
let payload = serde_json::json!({"event":"billing_error","data":error_msg});
let _ = tx.send(format!("data: {}\n\n", payload)).await;
let _ = tx
.send("data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n".to_string())
.await;
return;
}
Err(e) => {
tracing::warn!(error = %e, "Balance check failed");
let error_msg = format!("Billing check failed: {}", e);
let payload = serde_json::json!({"event":"billing_error","data":error_msg});
let _ = tx.send(format!("data: {}\n\n", payload)).await;
let _ = tx
.send("data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n".to_string())
.await;
return;
}
}
@ -224,9 +282,10 @@ pub fn create_chat_sse_stream(
// Clear any stale cancel flag before starting
let _ = cache.clear_chat_stream_cancelled(conversation_id).await;
// Cancellation token checked in on_chunk and by a periodic poller
// Cancellation token checked in on_chunk and by a periodic poller.
let cancelled = Arc::new(std::sync::atomic::AtomicBool::new(false));
let cancelled_for_on_chunk = cancelled.clone();
let recorded_chunks = Arc::new(tokio::sync::Mutex::new(Vec::<StreamChunk>::new()));
let on_chunk_tx = tx.clone();
let on_chunk_queue = queue.clone();
@ -234,6 +293,7 @@ pub fn create_chat_sse_stream(
let on_chunk_conv_id = conversation_id;
let on_chunk_msg_id = user_message_id;
let on_chunk_model = model_name.clone();
let on_chunk_recorded = recorded_chunks.clone();
let on_chunk: agent::chat::StreamCallback = Box::new(move |chunk: AiStreamChunk| {
let tx = on_chunk_tx.clone();
@ -243,28 +303,30 @@ pub fn create_chat_sse_stream(
let msg_id = on_chunk_msg_id;
let model = on_chunk_model.clone();
let cancelled = cancelled_for_on_chunk.clone();
let recorded = on_chunk_recorded.clone();
Box::pin(async move {
// Check if stream has been cancelled
if cancelled.load(Ordering::Acquire) {
return;
}
let event = match chunk.chunk_type {
let chunk_type = chunk.chunk_type.clone();
let event = match &chunk_type {
AiChunkType::Thinking => "thinking",
AiChunkType::Answer => "token",
AiChunkType::ToolCall => "tool_call",
AiChunkType::ToolResult => "tool_result",
};
let content = match chunk.chunk_type {
let content = match &chunk_type {
AiChunkType::Thinking => normalize_thinking_content(&chunk.content),
_ => chunk.content.clone(),
};
// Build structured data payload based on chunk type
let data_json = match chunk.chunk_type {
let data_json = match &chunk_type {
AiChunkType::ToolCall | AiChunkType::ToolResult => {
// Use structured metadata if available
if let Some(meta) = chunk.metadata {
if let Some(meta) = chunk.metadata.clone() {
meta
} else {
// Fallback: wrap raw content as display text
@ -273,14 +335,38 @@ pub fn create_chat_sse_stream(
}
_ => {
// thinking / answer: send plain text content
serde_json::Value::String(content)
serde_json::Value::String(content.clone())
}
};
let persisted_content = match &chunk_type {
AiChunkType::ToolCall | AiChunkType::ToolResult => data_json.to_string(),
_ => content.clone(),
};
let persisted_type = match &chunk_type {
AiChunkType::Thinking => StreamChunkType::Thinking,
AiChunkType::Answer => StreamChunkType::Answer,
AiChunkType::ToolCall => StreamChunkType::ToolCall,
AiChunkType::ToolResult => StreamChunkType::ToolResult,
};
recorded.lock().await.push(StreamChunk {
chunk_type: persisted_type,
content: persisted_content,
});
let mut sse_json = serde_json::json!({
"event": event,
"data": data_json,
});
if let Some(children_id) = chunk.children_id {
sse_json.as_object_mut().unwrap().insert(
"children_id".to_string(),
serde_json::Value::String(children_id),
);
}
let sse = format!(
"data: {{\"event\":\"{}\",\"data\":{}}}\n\n",
event,
serde_json::to_string(&data_json).unwrap_or_default()
"data: {}\n\n",
serde_json::to_string(&sse_json).unwrap_or_default()
);
let _ = tx.send(sse).await;
@ -299,36 +385,30 @@ pub fn create_chat_sse_stream(
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
});
let cancelled_for_check = cancelled.clone();
let cache_for_check = cache.clone();
let conv_id_for_check = conversation_id;
let (done_tx, mut done_rx) = tokio::sync::oneshot::channel::<()>();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
loop {
tokio::select! {
_ = interval.tick() => {
if cache_for_check.is_chat_stream_cancelled(conv_id_for_check).await {
cancelled_for_check.store(true, Ordering::Release);
break;
}
let cancel_wait = {
let cache_for_check = cache.clone();
let conv_id_for_check = conversation_id;
async move {
let mut interval = tokio::time::interval(std::time::Duration::from_millis(250));
loop {
interval.tick().await;
if cache_for_check
.is_chat_stream_cancelled(conv_id_for_check)
.await
{
break;
}
_ = &mut done_rx => break,
}
}
});
// Resolve max_tokens from model config (unlimited if not set)
let max_tokens = match model::Entity::find()
.filter(model::Column::Name.eq(&model_name))
.one(service.db.reader())
.await
{
Ok(Some(m)) => m.max_output_tokens.map(|v| v as u32).unwrap_or(u32::MAX),
_ => u32::MAX,
};
let result = chat_execution::execute_chat_stream(
// Resolve max_tokens from model config (unlimited if not set)
let max_tokens = model_record
.max_output_tokens
.map(|v| v as u32)
.unwrap_or(u32::MAX);
let execution = chat_execution::execute_chat_stream(
messages,
tools,
&model_name,
@ -341,41 +421,70 @@ pub fn create_chat_sse_stream(
service.cache.clone(),
service.config.clone(),
project_id,
Uuid::nil(), // sender_uid unknown in Chat API context
Uuid::nil(), // sender_uid 閳?unknown in Chat API context
embed_service,
on_chunk,
Some(conversation_id),
)
.await;
Some(service.queue_producer.clone()),
);
let result = tokio::select! {
result = execution => Some(result),
_ = cancel_wait => {
cancelled.store(true, Ordering::Release);
None
}
};
// Clear stream active state and cancel flag (streaming finished)
let _ = cache.clear_chat_stream_active(conversation_id).await;
let _ = cache.clear_chat_stream_cancelled(conversation_id).await;
let was_cancelled = cancelled.load(Ordering::Acquire);
let _ = done_tx.send(());
match result {
Ok(stream_result) => {
Some(Ok(stream_result)) => {
if was_cancelled {
let partial_chunks = recorded_chunks.lock().await.clone();
if let Some(msg) = persist_assistant_message_from_chunks(
&service,
conversation_id,
user_message_id,
assistant_msg_id,
&model_name,
&partial_chunks,
&stream_result.content,
stream_result.input_tokens,
stream_result.output_tokens,
"cancelled",
)
.await
{
update_conversation_after_response(&service, conversation_id, &msg).await;
}
let _ = tx
.send("data: {\"event\":\"done\",\"data\":\"stopped\"}\n\n".to_string())
.await;
return;
}
// Build ordered content blocks from stream chunks, merging
// consecutive blocks of the same role (thinking/assistant).
// consecutive blocks of the same role (thinking/assistant/tool_call/tool_result).
let raw_blocks: Vec<(String, String)> = stream_result
.chunks
.iter()
.filter(|c| {
matches!(
c.chunk_type,
StreamChunkType::Thinking | StreamChunkType::Answer
StreamChunkType::Thinking
| StreamChunkType::Answer
| StreamChunkType::ToolCall
| StreamChunkType::ToolResult
)
})
.map(|chunk| {
let role = match chunk.chunk_type {
StreamChunkType::Thinking => "thinking",
StreamChunkType::ToolCall => "tool_call",
StreamChunkType::ToolResult => "tool_result",
_ => "assistant",
};
(role.to_string(), chunk.content.clone())
@ -384,7 +493,7 @@ pub fn create_chat_sse_stream(
let merged_blocks = merge_consecutive_blocks(raw_blocks);
// Apply thinking normalization to the fully merged thinking
// blocks per-token normalization is meaningless since each
// blocks 閳?per-token normalization is meaningless since each
// chunk is a single token.
let normalized_blocks: Vec<(String, String)> = merged_blocks
.into_iter()
@ -495,7 +604,7 @@ pub fn create_chat_sse_stream(
}
}
} else if let Some(title) = &existing_title {
// Title already set (e.g. by AI tool) emit it
// Title already set (e.g. by AI tool) 閳?emit it
let title_payload = serde_json::json!({"title": title}).to_string();
let _ = tx
.send(format!(
@ -508,57 +617,56 @@ pub fn create_chat_sse_stream(
}
// Record billing after successful AI response
let billing_version_id = match model::Entity::find()
.filter(model::Column::Name.eq(&model_name))
.one(service.db.reader())
.await
.ok()
.flatten()
{
Some(m) => {
let reader = service.db.reader();
model_version::Entity::find()
.filter(model_version::Column::ModelId.eq(m.id))
.filter(model_version::Column::Status.eq("active"))
.order_by_desc(model_version::Column::IsDefault)
.order_by_desc(model_version::Column::ReleaseDate)
.one(reader)
.await
.ok()
.flatten()
.map(|v| v.id)
}
None => None,
};
if let (Some(version_id), Some(_)) = (billing_version_id, conv_project_id) {
match agent::billing::record_ai_usage(
let billing_result = if is_personal {
agent::billing::record_user_ai_usage(
&service.db,
project_id,
user_id,
version_id,
billing_version_id,
stream_result.input_tokens,
stream_result.output_tokens,
)
.await
{
Ok(agent::billing::BillingResult::Success(record)) => {
tracing::info!(
cost = record.cost,
deducted_from = record.deducted_from.as_str(),
"chat_billing_deducted"
);
}
Ok(agent::billing::BillingResult::InsufficientBalance { .. }) => {
tracing::warn!(
project_id = %project_id,
user_id = %user_id,
"chat_billing_insufficient_balance"
);
}
Err(e) => {
tracing::error!(error = %e, "chat_billing_error");
}
} else {
agent::billing::record_ai_usage(
&service.db,
project_id,
user_id,
billing_version_id,
stream_result.input_tokens,
stream_result.output_tokens,
)
.await
};
let mut billing_failed = false;
match billing_result {
Ok(agent::billing::BillingResult::Success(record)) => {
tracing::info!(
cost = record.cost,
deducted_from = record.deducted_from.as_str(),
personal = is_personal,
"chat_billing_deducted"
);
}
Ok(agent::billing::BillingResult::InsufficientBalance { message }) => {
billing_failed = true;
tracing::warn!(
project_id = %project_id,
user_id = %user_id,
personal = is_personal,
"chat_billing_insufficient_balance"
);
let payload = serde_json::json!({"event":"billing_error","data":message});
let _ = tx.send(format!("data: {}\n\n", payload)).await;
}
Err(e) => {
billing_failed = true;
tracing::error!(error = %e, "chat_billing_error");
let payload = serde_json::json!({
"event":"billing_error",
"data": format!("Billing failed: {}", e),
});
let _ = tx.send(format!("data: {}\n\n", payload)).await;
}
}
@ -578,13 +686,89 @@ pub fn create_chat_sse_stream(
let _ = queue.publish_chat_message(&final_msg).await;
// Send final SSE done event
let done_data = if billing_failed {
"billing_error"
} else {
"ok"
};
let _ = tx
.send("data: {\"event\":\"done\",\"data\":\"ok\"}\n\n".to_string())
.send(format!(
"data: {{\"event\":\"done\",\"data\":\"{}\"}}\n\n",
done_data
))
.await;
}
Err(e) => {
None => {
let partial_chunks = recorded_chunks.lock().await.clone();
if let Some(msg) = persist_assistant_message_from_chunks(
&service,
conversation_id,
user_message_id,
assistant_msg_id,
&model_name,
&partial_chunks,
"",
0,
0,
"cancelled",
)
.await
{
update_conversation_after_response(&service, conversation_id, &msg).await;
let final_msg = ChatMessageEvent {
message_id: assistant_msg_id,
conversation_id,
project_id: conv_project_id,
sender_id: Uuid::nil(),
role: "assistant".to_string(),
content: assistant_plain_text(&msg.content),
model: Some(model_name.clone()),
input_tokens: msg.input_tokens,
output_tokens: msg.output_tokens,
timestamp: chrono::Utc::now(),
};
let _ = queue.publish_chat_message(&final_msg).await;
}
let _ = tx
.send("data: {\"event\":\"done\",\"data\":\"stopped\"}\n\n".to_string())
.await;
}
Some(Err(e)) => {
let partial_chunks = recorded_chunks.lock().await.clone();
if let Some(msg) = persist_assistant_message_from_chunks(
&service,
conversation_id,
user_message_id,
assistant_msg_id,
&model_name,
&partial_chunks,
"",
0,
0,
"error",
)
.await
{
update_conversation_after_response(&service, conversation_id, &msg).await;
let final_msg = ChatMessageEvent {
message_id: assistant_msg_id,
conversation_id,
project_id: conv_project_id,
sender_id: Uuid::nil(),
role: "assistant".to_string(),
content: assistant_plain_text(&msg.content),
model: Some(model_name.clone()),
input_tokens: msg.input_tokens,
output_tokens: msg.output_tokens,
timestamp: chrono::Utc::now(),
};
let _ = queue.publish_chat_message(&final_msg).await;
}
let payload = serde_json::json!({"event":"error","data": e.to_string()});
let _ = tx.send(format!("data: {}\n\n", payload)).await;
let _ = tx
.send("data: {\"event\":\"done\",\"data\":\"error\"}\n\n".to_string())
.await;
}
}
});
@ -592,6 +776,112 @@ pub fn create_chat_sse_stream(
Box::pin(ReceiverStream::new(rx).map(|msg| Ok(actix_web::web::Bytes::from(msg))))
}
fn content_value_from_chunks(chunks: &[StreamChunk], fallback: &str) -> Option<serde_json::Value> {
let raw_blocks: Vec<(String, String)> = chunks
.iter()
.filter(|c| {
matches!(
c.chunk_type,
StreamChunkType::Thinking
| StreamChunkType::Answer
| StreamChunkType::ToolCall
| StreamChunkType::ToolResult
)
})
.map(|chunk| {
let role = match chunk.chunk_type {
StreamChunkType::Thinking => "thinking",
StreamChunkType::ToolCall => "tool_call",
StreamChunkType::ToolResult => "tool_result",
_ => "assistant",
};
(role.to_string(), chunk.content.clone())
})
.collect();
let merged_blocks = merge_consecutive_blocks(raw_blocks);
let normalized_blocks: Vec<(String, String)> = merged_blocks
.into_iter()
.map(|(role, content)| {
if role == "thinking" {
(role, normalize_thinking_content(&content))
} else {
(role, content)
}
})
.filter(|(_, content)| !content.is_empty())
.collect();
if normalized_blocks.is_empty() && fallback.is_empty() {
return None;
}
let content_blocks: Vec<serde_json::Value> = normalized_blocks
.iter()
.map(|(role, content)| serde_json::json!({ "role": role, "content": content }))
.collect();
Some(if content_blocks.is_empty() {
serde_json::json!([{ "role": "assistant", "content": fallback }])
} else {
serde_json::json!(content_blocks)
})
}
fn assistant_plain_text(content: &serde_json::Value) -> String {
match content {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => arr
.iter()
.filter(|item| item.get("role").and_then(|r| r.as_str()) != Some("thinking"))
.filter_map(|item| item.get("content").and_then(|c| c.as_str()))
.collect::<Vec<_>>()
.join("\n"),
other => other.to_string(),
}
}
async fn persist_assistant_message_from_chunks(
service: &AppService,
conversation_id: Uuid,
user_message_id: Uuid,
assistant_msg_id: Uuid,
model_name: &str,
chunks: &[StreamChunk],
fallback: &str,
input_tokens: i64,
output_tokens: i64,
stop_reason: &str,
) -> Option<ai_message::Model> {
let content = content_value_from_chunks(chunks, fallback)?;
let assistant_msg = ai_message::ActiveModel {
id: Set(assistant_msg_id),
conversation_id: Set(conversation_id),
parent_message_id: Set(Some(user_message_id)),
role: Set("assistant".to_string()),
content: Set(content),
model: Set(Some(model_name.to_string())),
is_fork_origin: Set(false),
stop_reason: Set(Some(stop_reason.to_string())),
input_tokens: Set(Some(input_tokens as i32)),
output_tokens: Set(Some(output_tokens as i32)),
latency_ms: Set(None),
metadata: Set(None),
room_id: Set(None),
version_group_id: Set(Some(assistant_msg_id)),
version_number: Set(1),
is_latest: Set(true),
created_at: Set(chrono::Utc::now()),
};
match assistant_msg.insert(service.db.writer()).await {
Ok(msg) => Some(msg),
Err(e) => {
tracing::warn!(error = %e, conversation_id = %conversation_id, "failed to persist partial assistant message");
None
}
}
}
/// Update conversation metadata after an AI assistant message is saved.
async fn update_conversation_after_response(
service: &AppService,
@ -629,6 +919,12 @@ async fn build_messages_from_history(
service: &AppService,
conversation_id: Uuid,
) -> Result<Vec<ChatRequestMessage>, String> {
let conversation = service
.find_conversation(conversation_id)
.await
.map_err(|e| format!("conversation lookup error: {}", e))?;
let project_id = conversation.project_id;
let msgs = AiMessage::find()
.filter(ai_message::Column::ConversationId.eq(conversation_id))
.filter(ai_message::Column::IsLatest.eq(true))
@ -668,6 +964,27 @@ async fn build_messages_from_history(
other => other.to_string(),
};
if role == "user" {
match service
.build_message_context_prompts(project_id, msg.metadata.as_ref())
.await
{
Ok(prompts) => {
for prompt in prompts {
chat_messages.push(ChatRequestMessage::system(prompt));
}
}
Err(error) => {
tracing::warn!(
conversation_id = %conversation_id,
message_id = %msg.id,
error = %error,
"failed to build chat message context prompts"
);
}
}
}
match role {
"user" => chat_messages.push(ChatRequestMessage::user(content)),
"assistant" => chat_messages.push(ChatRequestMessage::assistant(Some(content), None)),
@ -681,8 +998,8 @@ async fn build_messages_from_history(
/// Merge consecutive content blocks of the same role into single blocks.
/// This transforms many small per-chunk blocks into clean interleaved segments:
/// [thinking, thinking, assistant, assistant] [thinking, assistant]
/// Per-token chunks are concatenated directly the model sends \n inside
/// [thinking, thinking, assistant, assistant] -> [thinking, assistant]
/// Per-token chunks are concatenated directly; the model sends \n inside
/// the token content where needed, not between tokens.
fn merge_consecutive_blocks(blocks: Vec<(String, String)>) -> Vec<(String, String)> {
let mut merged: Vec<(String, String)> = Vec::new();
@ -691,7 +1008,7 @@ fn merge_consecutive_blocks(blocks: Vec<(String, String)>) -> Vec<(String, Strin
continue;
}
if let Some(last) = merged.last_mut() {
if last.0 == role {
if last.0 == role && role != "tool_call" && role != "tool_result" {
last.1.push_str(&content);
continue;
}

376
libs/api/chat/subagent.rs Normal file
View File

@ -0,0 +1,376 @@
//! SSE endpoint for watching a specific sub-agent's stream output.
//!
//! `GET /api/ai/subagent/{conversation_id}/{children_id}/stream`
//!
//! Prefers Redis PubSub for low-latency live delivery and falls back to NATS
//! if Redis is unavailable. The subject/channel is
//! `chat.subagent.chunk.{conversation_id}.{children_id}`.
use actix_web::{HttpResponse, Result, web};
use futures::StreamExt;
use models::ai::ai_subagent_session;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder};
use service::AppService;
use uuid::Uuid;
use crate::error::ApiError;
async fn find_subagent_session(
service: &AppService,
conversation_id: Uuid,
children_id: &str,
) -> Option<ai_subagent_session::Model> {
ai_subagent_session::Entity::find()
.filter(ai_subagent_session::Column::ConversationId.eq(conversation_id))
.filter(ai_subagent_session::Column::ChildrenId.eq(children_id))
.order_by_desc(ai_subagent_session::Column::CreatedAt)
.one(service.db.reader())
.await
.ok()
.flatten()
}
fn terminal_event_for_status(status: &str) -> &'static str {
match status {
"stopped" | "cancelled" => "stopped",
"error" => "error",
_ => "done",
}
}
async fn send_session_snapshot(
tx: &tokio::sync::mpsc::Sender<String>,
session: &ai_subagent_session::Model,
include_output: bool,
) -> bool {
if include_output && !session.output.is_empty() {
let payload = serde_json::json!({
"event": "token",
"data": {
"content": session.output,
"children_id": session.children_id,
},
});
if tx.send(format!("data: {}\n\n", payload)).await.is_err() {
return false;
}
}
let event = terminal_event_for_status(&session.status);
let payload = serde_json::json!({
"event": event,
"data": {
"content": "",
"children_id": session.children_id,
"error": session.error_message,
},
});
tx.send(format!("data: {}\n\n", payload)).await.is_ok()
}
/// Create an SSE stream for watching a sub-agent's stream output from Redis PubSub.
pub fn create_subagent_sse_stream_redis(
service: AppService,
conversation_id: Uuid,
children_id: String,
) -> std::pin::Pin<
Box<dyn futures::Stream<Item = Result<actix_web::web::Bytes, actix_web::Error>> + Send>,
> {
let (tx, rx) = tokio::sync::mpsc::channel::<String>(200);
tokio::spawn(async move {
if let Some(session) = find_subagent_session(&service, conversation_id, &children_id).await
{
let _ = send_session_snapshot(&tx, &session, true).await;
return;
}
let redis_url = match service.config.redis_url() {
Ok(url) => url,
Err(e) => {
let _ = tx
.send(format!(
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
serde_json::to_string(&e.to_string()).unwrap_or_default()
))
.await;
return;
}
};
let client = match redis::Client::open(redis_url) {
Ok(client) => client,
Err(e) => {
let _ = tx
.send(format!(
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
serde_json::to_string(&e.to_string()).unwrap_or_default()
))
.await;
return;
}
};
let pubsub = match client.get_async_pubsub().await {
Ok(pubsub) => pubsub,
Err(e) => {
let _ = tx
.send(format!(
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
serde_json::to_string(&e.to_string()).unwrap_or_default()
))
.await;
return;
}
};
let (mut sink, mut stream) = pubsub.split();
let channel = format!("chat.subagent.chunk.{}.{}", conversation_id, children_id);
if let Err(e) = sink.subscribe(channel.as_str()).await {
let _ = tx
.send(format!(
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
serde_json::to_string(&e.to_string()).unwrap_or_default()
))
.await;
return;
}
let _ = tx.send(":ok\n\n".to_string()).await;
let mut session_poll = tokio::time::interval(std::time::Duration::from_millis(500));
let stream_started = tokio::time::Instant::now();
let mut sent_content = false;
loop {
tokio::select! {
Some(msg) = stream.next() => {
let payload: String = match msg.get_payload() {
Ok(v) => v,
Err(e) => {
let _ = tx.send(format!("data: {{\"event\":\"error\",\"data\":{}}}\n\n", serde_json::to_string(&e.to_string()).unwrap_or_default())).await;
break;
}
};
let event_type = if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&payload) {
parsed
.get("chunk_type")
.and_then(|v| v.as_str())
.unwrap_or("chunk")
.to_string()
} else {
"chunk".to_string()
};
if matches!(event_type.as_str(), "token" | "thinking") {
sent_content = true;
}
let sse = format!("data: {{\"event\":\"{}\",\"data\":{}}}\n\n", event_type, payload);
if tx.send(sse).await.is_err() {
break;
}
if matches!(event_type.as_str(), "done" | "stopped" | "error") {
break;
}
}
_ = session_poll.tick() => {
if let Some(session) = find_subagent_session(&service, conversation_id, &children_id).await {
let _ = send_session_snapshot(&tx, &session, !sent_content).await;
break;
}
if stream_started.elapsed() > std::time::Duration::from_secs(90) {
let payload = serde_json::json!({
"event": "error",
"data": {
"content": "sub-agent stream timed out waiting for terminal state",
"children_id": children_id,
},
});
let _ = tx.send(format!("data: {}\n\n", payload)).await;
break;
}
}
}
}
});
Box::pin(
tokio_stream::wrappers::ReceiverStream::new(rx).map(|s| Ok(actix_web::web::Bytes::from(s))),
)
}
/// Create an SSE stream for watching a sub-agent's stream output via NATS fallback.
pub fn create_subagent_sse_stream_nats(
service: AppService,
conversation_id: Uuid,
children_id: String,
) -> std::pin::Pin<
Box<dyn futures::Stream<Item = Result<actix_web::web::Bytes, actix_web::Error>> + Send>,
> {
let (tx, rx) = tokio::sync::mpsc::channel::<String>(200);
tokio::spawn(async move {
let nats = match &service.queue_producer.nats {
Some(n) => n.clone(),
None => {
let _ = tx
.send(format!(
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
serde_json::to_string("NATS not available").unwrap_or_default()
))
.await;
return;
}
};
let subject = format!("chat.subagent.chunk.{}.{}", conversation_id, children_id);
let mut sub = match nats.subscribe(&subject).await {
Ok(s) => s,
Err(e) => {
let _ = tx
.send(format!(
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
serde_json::to_string(&e.to_string()).unwrap_or_default()
))
.await;
return;
}
};
let _ = tx.send(":ok\n\n".to_string()).await;
loop {
match sub.next().await {
Some(msg) => {
let payload = String::from_utf8_lossy(&msg.payload);
let event_type =
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&payload) {
parsed
.get("chunk_type")
.and_then(|v| v.as_str())
.unwrap_or("chunk")
.to_string()
} else {
"chunk".to_string()
};
let sse = format!(
"data: {{\"event\":\"{}\",\"data\":{}}}\n\n",
event_type, payload
);
if tx.send(sse).await.is_err() {
break;
}
}
None => break,
}
}
});
Box::pin(
tokio_stream::wrappers::ReceiverStream::new(rx).map(|s| Ok(actix_web::web::Bytes::from(s))),
)
}
/// SSE endpoint for watching a sub-agent's stream output.
///
/// `GET /api/ai/subagent/{conversation_id}/{children_id}/stream`
#[utoipa::path(
get,
path = "/api/ai/subagent/{conversation_id}/{children_id}/stream",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("children_id" = String, Path, description = "Sub-agent children ID"),
),
responses(
(status = 200, description = "SSE stream of sub-agent events"),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn subagent_stream_watch(
service: web::Data<AppService>,
session: session::Session,
path: web::Path<(Uuid, String)>,
) -> Result<HttpResponse, ApiError> {
let user_id = session
.user()
.ok_or_else(|| ApiError::from(service::error::AppError::Unauthorized))?;
let (conversation_id, children_id) = path.into_inner();
// Verify access to the conversation
let _conv = service
.find_conversation_owned(conversation_id, user_id)
.await?;
let redis_stream = create_subagent_sse_stream_redis(
service.get_ref().clone(),
conversation_id,
children_id.clone(),
);
let response = HttpResponse::Ok()
.content_type("text/event-stream")
.insert_header(("Cache-Control", "no-cache"))
.insert_header(("X-Accel-Buffering", "no"))
.streaming(redis_stream);
Ok(response.into())
}
/// NATS fallback retained for deployments where Redis PubSub is unavailable.
pub async fn subagent_stream_watch_nats(
service: web::Data<AppService>,
session: session::Session,
path: web::Path<(Uuid, String)>,
) -> Result<HttpResponse, ApiError> {
let user_id = session
.user()
.ok_or_else(|| ApiError::from(service::error::AppError::Unauthorized))?;
let (conversation_id, children_id) = path.into_inner();
let _conv = service
.find_conversation_owned(conversation_id, user_id)
.await?;
let response = HttpResponse::Ok()
.content_type("text/event-stream")
.insert_header(("Cache-Control", "no-cache"))
.insert_header(("X-Accel-Buffering", "no"))
.streaming(create_subagent_sse_stream_nats(
service.get_ref().clone(),
conversation_id,
children_id,
));
Ok(response.into())
}
#[utoipa::path(
post,
path = "/api/ai/subagent/{conversation_id}/{children_id}/stop",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("children_id" = String, Path, description = "Sub-agent children ID"),
),
responses(
(status = 200, description = "Sub-agent stop requested"),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn subagent_stop(
service: web::Data<AppService>,
session: session::Session,
path: web::Path<(Uuid, String)>,
) -> Result<HttpResponse, ApiError> {
let user_id = session
.user()
.ok_or_else(|| ApiError::from(service::error::AppError::Unauthorized))?;
let (conversation_id, children_id) = path.into_inner();
let _conv = service
.find_conversation_full_access(conversation_id, user_id)
.await?;
service
.cache
.set_sub_agent_cancelled(conversation_id, &children_id)
.await;
Ok(crate::api_success())
}

View File

@ -322,6 +322,9 @@ use utoipa::OpenApi;
crate::project::audit::project_log_audit,
crate::project::activity::project_activities,
crate::project::activity::project_log_activity,
crate::project::message_favorite::project_message_favorites,
crate::project::message_favorite::project_message_favorite_add,
crate::project::message_favorite::project_message_favorite_remove,
crate::project::stats::project_stats,
crate::project::billing::project_billing,
crate::project::billing::project_billing_history,
@ -564,6 +567,9 @@ use utoipa::OpenApi;
service::project::activity::ActivityLogResponse,
service::project::activity::ActivityLogParams,
service::project::activity::ActivityLogListResponse,
service::project::message_favorite::ProjectMessageFavoriteQuery,
service::project::message_favorite::ProjectMessageFavoriteItem,
service::project::message_favorite::ProjectMessageFavoriteResponse,
service::project::stats::ProjectStatsResponse,
service::project::stats::ActivityBreakdownItem,
service::project::stats::ProjectStatsActivityItem,

View File

@ -0,0 +1,89 @@
use crate::{ApiResponse, error::ApiError};
use actix_web::{HttpResponse, Result, web};
use service::AppService;
use service::project::message_favorite::{
ProjectMessageFavoriteQuery, ProjectMessageFavoriteResponse,
};
use session::Session;
use uuid::Uuid;
#[utoipa::path(
get,
path = "/api/projects/{project_name}/message-favorites",
params(
("project_name" = String, Path),
ProjectMessageFavoriteQuery,
),
responses(
(status = 200, description = "List current user's project message favorites", body = ApiResponse<ProjectMessageFavoriteResponse>),
(status = 401, description = "Unauthorized"),
(status = 403, description = "Forbidden"),
(status = 404, description = "Not found"),
),
tag = "Project"
)]
pub async fn project_message_favorites(
service: web::Data<AppService>,
session: Session,
path: web::Path<String>,
query: web::Query<ProjectMessageFavoriteQuery>,
) -> Result<HttpResponse, ApiError> {
let resp = service
.project_message_favorites(&session, path.into_inner(), query.into_inner())
.await?;
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
post,
path = "/api/projects/{project_name}/messages/{message_id}/favorite",
params(
("project_name" = String, Path),
("message_id" = Uuid, Path),
),
responses(
(status = 200, description = "Favorite a project message", body = ApiResponse<service::project::message_favorite::ProjectMessageFavoriteItem>),
(status = 401, description = "Unauthorized"),
(status = 403, description = "Forbidden"),
(status = 404, description = "Not found"),
),
tag = "Project"
)]
pub async fn project_message_favorite_add(
service: web::Data<AppService>,
session: Session,
path: web::Path<(String, Uuid)>,
) -> Result<HttpResponse, ApiError> {
let (project_name, message_id) = path.into_inner();
let resp = service
.project_message_favorite_add(&session, project_name, message_id)
.await?;
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
delete,
path = "/api/projects/{project_name}/messages/{message_id}/favorite",
params(
("project_name" = String, Path),
("message_id" = Uuid, Path),
),
responses(
(status = 200, description = "Remove a project message favorite"),
(status = 401, description = "Unauthorized"),
(status = 403, description = "Forbidden"),
(status = 404, description = "Not found"),
),
tag = "Project"
)]
pub async fn project_message_favorite_remove(
service: web::Data<AppService>,
session: Session,
path: web::Path<(String, Uuid)>,
) -> Result<HttpResponse, ApiError> {
let (project_name, message_id) = path.into_inner();
service
.project_message_favorite_remove(&session, project_name, message_id)
.await?;
Ok(crate::api_success())
}

View File

@ -12,6 +12,7 @@ pub mod join_settings;
pub mod labels;
pub mod like;
pub mod members;
pub mod message_favorite;
pub mod repo;
pub mod settings;
pub mod stats;
@ -188,6 +189,18 @@ pub fn init_project_routes(cfg: &mut web::ServiceConfig) {
"/{project_name}/activities",
web::post().to(activity::project_log_activity),
)
.route(
"/{project_name}/message-favorites",
web::get().to(message_favorite::project_message_favorites),
)
.route(
"/{project_name}/messages/{message_id}/favorite",
web::post().to(message_favorite::project_message_favorite_add),
)
.route(
"/{project_name}/messages/{message_id}/favorite",
web::delete().to(message_favorite::project_message_favorite_remove),
)
.route(
"/{project_name}/avatar",
web::post().to(avatar::upload_project_avatar),

View File

@ -2,6 +2,7 @@ use crate::error::RoomError;
use crate::service::RoomService;
use crate::ws_context::WsUserContext;
use models::rooms::{room, room_message, room_user_state};
use redis::AsyncCommands;
use sea_orm::*;
use uuid::Uuid;
@ -151,6 +152,48 @@ impl RoomService {
}
pub(crate) async fn invalidate_room_list_cache(&self, project_id: Uuid) {
tracing::debug!(project_id = %project_id, "room_list cache: relying on TTL expiry");
self.invalidate_room_list_cache_for_prefix(&format!("room:list:{}:", project_id))
.await;
}
pub(crate) async fn invalidate_room_list_cache_for_user(
&self,
project_id: Uuid,
user_id: Uuid,
) {
self.invalidate_room_list_cache_for_prefix(&format!("room:list:{}:{}:", project_id, user_id))
.await;
}
async fn invalidate_room_list_cache_for_prefix(&self, prefix: &str) {
let pattern = format!("{}*", prefix);
if let Ok(mut conn) = self.cache.conn().await {
let mut cursor: u64 = 0;
loop {
match redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(&pattern)
.arg("COUNT")
.arg(100)
.query_async::<(u64, Vec<String>)>(&mut conn)
.await
{
Ok((next_cursor, keys)) => {
for key in &keys {
let _: () = conn.del(key).await.unwrap_or(());
}
if next_cursor == 0 {
break;
}
cursor = next_cursor;
}
Err(e) => {
tracing::debug!(pattern = %pattern, error = ?e, "room_list cache scan failed");
break;
}
}
}
}
}
}

View File

@ -118,6 +118,8 @@ impl RoomService {
let updated = active.update(&self.db).await?;
let room = self.find_room_or_404(room_id).await?;
self.invalidate_room_list_cache_for_user(room.project, ctx.user_id)
.await;
self.publish_room_event(
room.project,
RoomEventType::ReadReceipt,

View File

@ -11,6 +11,7 @@ use uuid::Uuid;
use super::ai_common::create_and_publish_ai_message;
use crate::connection::RoomConnectionManager;
use agent::chat::{AiRequest, ChatService};
use agent::tool::registry::ToolRegistry;
pub async fn process_message_ai_nonstreaming(
chat_service: Arc<ChatService>,
@ -23,13 +24,14 @@ pub async fn process_message_ai_nonstreaming(
cache: AppCache,
queue: MessageProducer,
room_manager: Arc<RoomConnectionManager>,
room_tools: ToolRegistry,
) {
let chat_service = chat_service.clone();
tokio::spawn(async move {
let _lock_guard = lock_guard;
let model_display_name = request.model.name.clone();
match chat_service.process(request).await {
match chat_service.process_room(request, room_tools).await {
Ok(result) => {
if let Err(e) = create_and_publish_ai_message(
&db,

View File

@ -11,8 +11,8 @@ use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use crate::connection::RoomConnectionManager;
use crate::error::RoomError;
use crate::service::ai_react_nonstreaming;
use crate::service::ai_react_streaming;
use crate::service::ai_streaming;
use crate::service::ai_nonstreaming;
use crate::service::history;
use crate::service::patterns::{mention_bracket_re, mention_tag_re};
use agent::chat::{AiRequest, ChatService};
@ -252,6 +252,8 @@ impl RoomAiService {
think: ai_config.think,
tools: Some(chat_service.tools()),
max_tool_depth: 1000,
execution_profile: None,
room_preamble: None,
};
let (optimized_history, cutoff_seq) = chat_service
@ -263,15 +265,14 @@ impl RoomAiService {
});
request.history_cutoff_seq = cutoff_seq;
// Build room preamble: room identity, sender info, permissions, optimized history
let room_preamble = build_room_preamble(
request.room_preamble = Some(build_room_preamble(
&room,
&project,
&model,
&sender,
&sender_role,
&optimized_history,
);
));
// Pre-flight balance check: verify the project + user can afford at least a minimal AI call
let balance_ok =
@ -328,9 +329,9 @@ impl RoomAiService {
let use_streaming = ai_config.stream;
// Dispatch to ReAct streaming or nonstreaming with room tools and preamble
// Dispatch to direct streaming or nonstreaming with room tools
if use_streaming {
ai_react_streaming::process_message_ai_react_streaming(
ai_streaming::process_message_ai_streaming(
chat_service,
request,
room_id,
@ -342,11 +343,10 @@ impl RoomAiService {
self.queue.clone(),
self.room_manager.clone(),
room_tools,
room_preamble,
)
.await;
} else {
ai_react_nonstreaming::process_message_ai_react_nonstreaming(
ai_nonstreaming::process_message_ai_nonstreaming(
chat_service,
request,
room_id,
@ -358,7 +358,6 @@ impl RoomAiService {
self.queue.clone(),
self.room_manager.clone(),
room_tools,
room_preamble,
)
.await;
}

View File

@ -14,6 +14,7 @@ use uuid::Uuid;
use super::sequence::next_room_message_seq_internal;
use crate::connection::RoomConnectionManager;
use agent::chat::{AiChunkType, AiRequest, ChatService, normalize_thinking_content};
use agent::tool::registry::ToolRegistry;
pub async fn process_message_ai_streaming(
chat_service: Arc<ChatService>,
@ -26,6 +27,7 @@ pub async fn process_message_ai_streaming(
cache: AppCache,
queue: MessageProducer,
room_manager: Arc<RoomConnectionManager>,
room_tools: ToolRegistry,
) {
use queue::RoomMessageStreamChunkEvent;
@ -245,7 +247,7 @@ pub async fn process_message_ai_streaming(
}
});
match chat_service.process_stream(request, stream_callback).await {
match chat_service.process_room_stream(request, stream_callback, room_tools).await {
Ok(result) => {
// Store ordered chunks as JSON in thinking_content for ordered replay.
// Uses {"__chunks__": [...]} marker so legacy plain-text still works.
@ -258,6 +260,7 @@ pub async fn process_message_ai_streaming(
agent::client::StreamChunkType::Thinking => "thinking",
agent::client::StreamChunkType::Answer => "answer",
agent::client::StreamChunkType::ToolCall => "tool_call",
agent::client::StreamChunkType::ToolResult => "tool_result",
};
let content = match c.chunk_type {
agent::client::StreamChunkType::Thinking => normalize_thinking_content(&c.content),

View File

@ -2,8 +2,8 @@ use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use uuid::Uuid;
use super::RoomService;
use super::ai_react_nonstreaming;
use super::ai_react_streaming;
use super::ai_nonstreaming;
use super::ai_streaming;
use super::history;
use crate::error::RoomError;
use crate::service::{mention_bracket_re, mention_tag_re};
@ -151,6 +151,8 @@ impl RoomService {
think: ai_config.think,
tools: Some(chat_service.tools()),
max_tool_depth: 1000,
execution_profile: None,
room_preamble: None,
};
let (optimized_history, cutoff_seq) = chat_service
@ -162,15 +164,13 @@ impl RoomService {
});
request.history_cutoff_seq = cutoff_seq;
// Build room preamble: room identity, sender info, permissions, optimized history
let room_preamble =
build_room_preamble(&room, &project, &sender, &sender_role, &optimized_history);
request.room_preamble = Some(build_room_preamble(&room, &project, &sender, &sender_role, &optimized_history));
let use_streaming = ai_config.stream;
// Dispatch to ReAct streaming or nonstreaming with room tools and preamble
// Dispatch to direct streaming or nonstreaming with room tools
if use_streaming {
ai_react_streaming::process_message_ai_react_streaming(
ai_streaming::process_message_ai_streaming(
chat_service.clone(),
request,
room_id,
@ -182,11 +182,10 @@ impl RoomService {
self.queue.clone(),
self.room_manager.clone(),
room_tools,
room_preamble,
)
.await;
} else {
ai_react_nonstreaming::process_message_ai_react_nonstreaming(
ai_nonstreaming::process_message_ai_nonstreaming(
chat_service.clone(),
request,
room_id,
@ -198,7 +197,6 @@ impl RoomService {
self.queue.clone(),
self.room_manager.clone(),
room_tools,
room_preamble,
)
.await;
}