483 lines
19 KiB
Rust
483 lines
19 KiB
Rust
use models::projects::project_skill;
|
||
use models::rooms::room_ai;
|
||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||
use std::pin::Pin;
|
||
use std::sync::Arc;
|
||
use uuid::Uuid;
|
||
|
||
use super::message_builder::MessageBuilder;
|
||
use super::service::StreamResult;
|
||
use super::session_recording::record_ai_session;
|
||
use super::{AiChunkType, AiRequest, AiStreamChunk, StreamCallback};
|
||
use crate::client::AiClientConfig;
|
||
use crate::client::types::{ChatRequestMessage, ToolCall};
|
||
use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream};
|
||
use crate::error::Result;
|
||
use crate::perception::{SkillEntry, ToolCallEvent};
|
||
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor};
|
||
|
||
type SharedCallback = Arc<
|
||
dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync,
|
||
>;
|
||
|
||
pub async fn execute_process_stream(
|
||
request: AiRequest,
|
||
on_chunk: StreamCallback,
|
||
message_builder: &MessageBuilder,
|
||
tool_registry: &Option<crate::tool::registry::ToolRegistry>,
|
||
ai_base_url: Option<String>,
|
||
ai_api_key: Option<String>,
|
||
) -> Result<StreamResult> {
|
||
let on_chunk: SharedCallback = Arc::from(on_chunk);
|
||
let tools: Vec<serde_json::Value> = request.tools.clone().unwrap_or_default();
|
||
let tools_enabled = !tools.is_empty();
|
||
let max_tool_depth = request.max_tool_depth;
|
||
|
||
let mut messages = message_builder.build_messages(&request).await?;
|
||
|
||
let room_ai_config = room_ai::Entity::find()
|
||
.filter(room_ai::Column::Room.eq(request.room.id))
|
||
.filter(room_ai::Column::Model.eq(request.model.id))
|
||
.one(&request.db)
|
||
.await?;
|
||
|
||
let model_name = request.model.name.clone();
|
||
let temperature = room_ai_config
|
||
.as_ref()
|
||
.and_then(|r| r.temperature.map(|v| v as f32))
|
||
.unwrap_or(request.temperature as f32);
|
||
let max_tokens = room_ai_config
|
||
.as_ref()
|
||
.and_then(|r| r.max_tokens.map(|v| v as u32))
|
||
.unwrap_or(request.max_tokens as u32);
|
||
let mut tool_depth = 0;
|
||
let mut total_input_tokens = 0i64;
|
||
let mut total_output_tokens = 0i64;
|
||
let session_id = Uuid::now_v7();
|
||
let session_start = std::time::Instant::now();
|
||
let version_id = room_ai_config.as_ref().and_then(|r| r.version);
|
||
|
||
let config = AiClientConfig::new(ai_api_key.unwrap_or_default())
|
||
.with_base_url(ai_base_url.unwrap_or_else(|| "https://api.openai.com".into()));
|
||
|
||
let mut full_content = String::new();
|
||
let mut all_chunks: Vec<StreamChunk> = Vec::new();
|
||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
|
||
|
||
loop {
|
||
let on_chunk_cb = on_chunk.clone();
|
||
let on_chunk_cb2 = on_chunk.clone();
|
||
let tx_arc = Arc::new(tx.clone());
|
||
let tx_arc2 = tx_arc.clone();
|
||
let response = call_stream(
|
||
&messages,
|
||
&model_name,
|
||
&config,
|
||
temperature,
|
||
max_tokens,
|
||
if tools_enabled { Some(&tools) } else { None },
|
||
None,
|
||
Arc::new(move |delta| {
|
||
let content = delta.to_string();
|
||
let fut = on_chunk_cb(AiStreamChunk {
|
||
content,
|
||
done: false,
|
||
chunk_type: AiChunkType::Answer,
|
||
metadata: None,
|
||
});
|
||
fut
|
||
}),
|
||
Arc::new(move |delta| {
|
||
let fut = on_chunk_cb2(AiStreamChunk {
|
||
content: delta.to_string(),
|
||
done: false,
|
||
chunk_type: AiChunkType::Thinking,
|
||
metadata: None,
|
||
});
|
||
fut
|
||
}),
|
||
Arc::new(move |tc: &StreamedToolCall| {
|
||
let tx = tx_arc2.clone();
|
||
let tc_owned = tc.clone();
|
||
Box::pin(async move {
|
||
let _ = tx.send(tc_owned);
|
||
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
|
||
}),
|
||
)
|
||
.await?;
|
||
|
||
total_input_tokens += response.input_tokens;
|
||
total_output_tokens += response.output_tokens;
|
||
all_chunks.extend(response.chunks.clone());
|
||
|
||
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
|
||
if !has_tool_calls {
|
||
return handle_final_answer(
|
||
response,
|
||
all_chunks,
|
||
&request,
|
||
session_id,
|
||
version_id,
|
||
total_input_tokens,
|
||
total_output_tokens,
|
||
session_start,
|
||
)
|
||
.await;
|
||
}
|
||
|
||
full_content.push_str(&response.content);
|
||
|
||
let tool_calls: Vec<ToolCall> = response
|
||
.tool_calls
|
||
.iter()
|
||
.map(|tc| ToolCall {
|
||
id: tc.id.clone(),
|
||
type_: "function".into(),
|
||
function: crate::client::types::ToolCallFunction {
|
||
name: tc.name.clone(),
|
||
arguments: tc.arguments.clone(),
|
||
},
|
||
})
|
||
.collect();
|
||
|
||
messages.push(ChatRequestMessage::assistant(
|
||
Some(response.content.clone()),
|
||
Some(tool_calls.clone()),
|
||
));
|
||
|
||
drain_tool_call_notifications(&mut rx, &on_chunk, &mut all_chunks).await;
|
||
|
||
let calls: Vec<AgentToolCall> = response
|
||
.tool_calls
|
||
.iter()
|
||
.map(|tc| AgentToolCall {
|
||
id: tc.id.clone(),
|
||
name: tc.name.clone(),
|
||
arguments: tc.arguments.clone(),
|
||
})
|
||
.collect();
|
||
|
||
let tool_messages = execute_streaming_tools(
|
||
&request,
|
||
&calls,
|
||
session_id,
|
||
&on_chunk,
|
||
&mut all_chunks,
|
||
tool_registry,
|
||
message_builder,
|
||
)
|
||
.await;
|
||
|
||
messages.extend(tool_messages);
|
||
inject_passive_skills_stream(
|
||
&request,
|
||
message_builder,
|
||
&response.tool_calls,
|
||
&mut messages,
|
||
)
|
||
.await;
|
||
|
||
tool_depth += 1;
|
||
if tool_depth >= max_tool_depth {
|
||
let max_depth_text = format!(
|
||
"[AI reached maximum tool depth ({}) — no final answer produced]",
|
||
max_tool_depth
|
||
);
|
||
on_chunk(AiStreamChunk {
|
||
content: max_depth_text.clone(),
|
||
done: true,
|
||
chunk_type: AiChunkType::Answer,
|
||
metadata: None,
|
||
})
|
||
.await;
|
||
all_chunks.push(StreamChunk {
|
||
chunk_type: StreamChunkType::Answer,
|
||
content: max_depth_text,
|
||
});
|
||
record_ai_session(
|
||
&request.cache,
|
||
&request.db,
|
||
request.project.id,
|
||
request.sender.uid,
|
||
session_id,
|
||
request.room.id,
|
||
request.model.id,
|
||
version_id.unwrap_or_default(),
|
||
total_input_tokens,
|
||
total_output_tokens,
|
||
session_start.elapsed().as_millis() as i64,
|
||
)
|
||
.await;
|
||
return Ok(StreamResult {
|
||
content: full_content,
|
||
reasoning_content: String::new(),
|
||
input_tokens: 0,
|
||
output_tokens: 0,
|
||
chunks: all_chunks,
|
||
});
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn drain_tool_call_notifications(
|
||
rx: &mut tokio::sync::mpsc::UnboundedReceiver<StreamedToolCall>,
|
||
on_chunk: &SharedCallback,
|
||
all_chunks: &mut Vec<StreamChunk>,
|
||
) {
|
||
loop {
|
||
match rx.try_recv() {
|
||
Ok(tc) => {
|
||
let args_display = if tc.arguments.len() > 100 {
|
||
let end = tc
|
||
.arguments
|
||
.char_indices()
|
||
.map(|(i, _)| i)
|
||
.take_while(|&i| i <= 100)
|
||
.last()
|
||
.unwrap_or(100);
|
||
format!("{}...", &tc.arguments[..end])
|
||
} else {
|
||
tc.arguments.clone()
|
||
};
|
||
let tool_display = format!("🔧 {}({})", tc.name, args_display);
|
||
// Parse arguments JSON for structured metadata
|
||
let args_json =
|
||
serde_json::from_str(&tc.arguments).unwrap_or(serde_json::json!({}));
|
||
let metadata = serde_json::json!({
|
||
"tool": tc.name,
|
||
"args": args_json,
|
||
"display": tool_display.clone(),
|
||
});
|
||
on_chunk(AiStreamChunk {
|
||
content: tool_display.clone(),
|
||
done: false,
|
||
chunk_type: AiChunkType::ToolCall,
|
||
metadata: Some(metadata),
|
||
})
|
||
.await;
|
||
all_chunks.push(StreamChunk {
|
||
chunk_type: StreamChunkType::ToolCall,
|
||
content: tool_display,
|
||
});
|
||
}
|
||
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
|
||
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn execute_streaming_tools(
|
||
request: &AiRequest,
|
||
calls: &[AgentToolCall],
|
||
session_id: Uuid,
|
||
on_chunk: &SharedCallback,
|
||
all_chunks: &mut Vec<StreamChunk>,
|
||
tool_registry: &Option<crate::tool::registry::ToolRegistry>,
|
||
message_builder: &MessageBuilder,
|
||
) -> Vec<ChatRequestMessage> {
|
||
let mut tool_messages = Vec::new();
|
||
let mut ctx = ToolContext::new(
|
||
request.db.clone(),
|
||
request.cache.clone(),
|
||
request.config.clone(),
|
||
request.room.id,
|
||
Some(request.sender.uid),
|
||
)
|
||
.with_project(request.project.id);
|
||
if let Some(es) = &message_builder.embed_service {
|
||
ctx = ctx.with_embed_service(es.clone());
|
||
}
|
||
if let Some(registry) = tool_registry {
|
||
ctx.registry_mut().merge(registry.clone());
|
||
}
|
||
|
||
let recorder =
|
||
crate::tool::recorder::ToolCallRecorder::with_session(request.db.clone(), session_id);
|
||
let mut join_set = tokio::task::JoinSet::new();
|
||
|
||
for call in calls {
|
||
let call_clone = call.clone();
|
||
let mut ctx_clone = ctx.clone();
|
||
let sender_uid = request.sender.uid;
|
||
let recorder_clone = recorder.clone();
|
||
|
||
join_set.spawn(async move {
|
||
let start = std::time::Instant::now();
|
||
let executor = ToolExecutor::new();
|
||
let res = executor
|
||
.execute_batch(vec![call_clone.clone()], &mut ctx_clone)
|
||
.await;
|
||
(call_clone, res, start.elapsed(), sender_uid, recorder_clone)
|
||
});
|
||
}
|
||
|
||
let heartbeat_dur = std::time::Duration::from_secs(10);
|
||
while !join_set.is_empty() {
|
||
tokio::select! {
|
||
Some(res) = join_set.join_next() => {
|
||
if let Ok((call, results, elapsed, sender_uid, recorder)) = res {
|
||
match results {
|
||
Ok(results) => {
|
||
for result in &results {
|
||
let text = match &result.result { crate::tool::ToolResult::Ok(v) => v.to_string(), crate::tool::ToolResult::Error(msg) => msg.clone() };
|
||
let preview = if text.len() > 300 {
|
||
let end = text.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
|
||
format!("{}...", &text[..end])
|
||
} else { text.clone() };
|
||
tracing::debug!("tool_result: {} — {}", call.name, preview);
|
||
|
||
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
|
||
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
|
||
recorder.record(crate::tool::recorder::ToolCallRecord {
|
||
tool_call_id: call.id.clone(),
|
||
session_id: recorder.session_id(),
|
||
tool_name: call.name.clone(),
|
||
caller: sender_uid,
|
||
arguments: call.arguments_json().unwrap_or_default(),
|
||
status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success },
|
||
execution_time_ms: Some(elapsed.as_millis() as i64),
|
||
error_message: error_msg,
|
||
error_stack: None,
|
||
retry_count: 0
|
||
});
|
||
}
|
||
let success_display = format!("✅ {}", call.name);
|
||
let result_preview: Vec<String> = results.iter().map(|r| {
|
||
match &r.result { crate::tool::ToolResult::Ok(v) => v.to_string(), crate::tool::ToolResult::Error(msg) => msg.clone() }
|
||
}).collect();
|
||
let metadata = serde_json::json!({
|
||
"tool": call.name,
|
||
"status": "ok",
|
||
"result": result_preview.join("\n").chars().take(500).collect::<String>(),
|
||
"display": success_display.clone(),
|
||
});
|
||
on_chunk(AiStreamChunk {
|
||
content: success_display.clone(),
|
||
done: false,
|
||
chunk_type: AiChunkType::ToolResult,
|
||
metadata: Some(metadata),
|
||
}).await;
|
||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
|
||
let msgs = ToolExecutor::to_tool_messages(&results);
|
||
tool_messages.extend(msgs);
|
||
}
|
||
Err(e) => {
|
||
recorder.record(crate::tool::recorder::ToolCallRecord {
|
||
tool_call_id: call.id.clone(),
|
||
session_id: recorder.session_id(),
|
||
tool_name: call.name.clone(),
|
||
caller: sender_uid,
|
||
arguments: call.arguments_json().unwrap_or_default(),
|
||
status: models::ai::ToolCallStatus::Failed,
|
||
execution_time_ms: Some(elapsed.as_millis() as i64),
|
||
error_message: Some(e.to_string()),
|
||
error_stack: None,
|
||
retry_count: 0
|
||
});
|
||
let err_text = format!("[Tool call failed: {}]", e);
|
||
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
|
||
let err_display = format!("❌ {} (failed)", call.name);
|
||
let metadata = serde_json::json!({
|
||
"tool": call.name,
|
||
"status": "error",
|
||
"result": e.to_string(),
|
||
"display": err_display.clone(),
|
||
});
|
||
on_chunk(AiStreamChunk {
|
||
content: err_display.clone(),
|
||
done: false,
|
||
chunk_type: AiChunkType::ToolResult,
|
||
metadata: Some(metadata),
|
||
}).await;
|
||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
|
||
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
|
||
}
|
||
}
|
||
}
|
||
},
|
||
_ = tokio::time::sleep(heartbeat_dur) => {
|
||
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall, metadata: None }).await;
|
||
}
|
||
}
|
||
}
|
||
tool_messages
|
||
}
|
||
|
||
async fn handle_final_answer(
|
||
response: crate::client::StreamResponse,
|
||
all_chunks: Vec<StreamChunk>,
|
||
request: &AiRequest,
|
||
session_id: Uuid,
|
||
version_id: Option<Uuid>,
|
||
total_input_tokens: i64,
|
||
total_output_tokens: i64,
|
||
session_start: std::time::Instant,
|
||
) -> Result<StreamResult> {
|
||
let full_content = response.content.clone();
|
||
// Don't push full content as a chunk — incremental deltas in
|
||
// response.chunks (already accumulated above) sum to the same text.
|
||
// merge_consecutive_blocks would concatenate delta_sum + full =
|
||
// 2× full, causing duplicate content in DB persistence.
|
||
record_ai_session(
|
||
&request.cache,
|
||
&request.db,
|
||
request.project.id,
|
||
request.sender.uid,
|
||
session_id,
|
||
request.room.id,
|
||
request.model.id,
|
||
version_id.unwrap_or_default(),
|
||
total_input_tokens,
|
||
total_output_tokens,
|
||
session_start.elapsed().as_millis() as i64,
|
||
)
|
||
.await;
|
||
Ok(StreamResult {
|
||
content: full_content,
|
||
reasoning_content: response.reasoning_content,
|
||
input_tokens: total_input_tokens,
|
||
output_tokens: total_output_tokens,
|
||
chunks: all_chunks,
|
||
})
|
||
}
|
||
|
||
async fn inject_passive_skills_stream(
|
||
request: &AiRequest,
|
||
message_builder: &MessageBuilder,
|
||
tool_calls: &[StreamedToolCall],
|
||
messages: &mut Vec<ChatRequestMessage>,
|
||
) {
|
||
if let Ok(skills) = project_skill::Entity::find()
|
||
.filter(project_skill::Column::ProjectUuid.eq(request.project.id))
|
||
.filter(project_skill::Column::Enabled.eq(true))
|
||
.all(&request.db)
|
||
.await
|
||
{
|
||
let skill_entries: Vec<SkillEntry> = skills
|
||
.into_iter()
|
||
.map(|s| SkillEntry {
|
||
slug: s.slug,
|
||
name: s.name,
|
||
description: s.description,
|
||
content: s.content,
|
||
})
|
||
.collect();
|
||
let tool_events: Vec<ToolCallEvent> = tool_calls
|
||
.iter()
|
||
.map(|tc| ToolCallEvent {
|
||
tool_name: tc.name.clone(),
|
||
arguments: tc.arguments.clone(),
|
||
})
|
||
.collect();
|
||
for event in &tool_events {
|
||
if let Some(ctx) = message_builder
|
||
.perception_service
|
||
.passive
|
||
.detect(event, &skill_entries)
|
||
{
|
||
messages.push(ctx.to_system_message());
|
||
}
|
||
}
|
||
}
|
||
}
|