use models::projects::project_skill; use models::rooms::room_ai; use sea_orm::{EntityTrait, ColumnTrait, QueryFilter}; use std::pin::Pin; use std::sync::Arc; use uuid::Uuid; use super::service::StreamResult; use super::{AiChunkType, AiRequest, AiStreamChunk, StreamCallback}; use crate::client::AiClientConfig; use crate::client::types::{ChatRequestMessage, ToolCall}; use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream}; use crate::error::Result; use crate::perception::{SkillEntry, ToolCallEvent}; use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolExecutor}; use super::message_builder::MessageBuilder; use super::session_recording::record_ai_session; type SharedCallback = Arc Pin + Send>> + Send + Sync>; pub async fn execute_process_stream( request: AiRequest, on_chunk: StreamCallback, message_builder: &MessageBuilder, tool_registry: &Option, ai_base_url: Option, ai_api_key: Option, ) -> Result { let on_chunk: SharedCallback = Arc::from(on_chunk); let tools: Vec = request.tools.clone().unwrap_or_default(); let tools_enabled = !tools.is_empty(); let max_tool_depth = request.max_tool_depth; let mut messages = message_builder.build_messages(&request).await?; let room_ai_config = room_ai::Entity::find() .filter(room_ai::Column::Room.eq(request.room.id)) .filter(room_ai::Column::Model.eq(request.model.id)) .one(&request.db).await?; let model_name = request.model.name.clone(); let temperature = room_ai_config.as_ref().and_then(|r| r.temperature.map(|v| v as f32)).unwrap_or(request.temperature as f32); let max_tokens = room_ai_config.as_ref().and_then(|r| r.max_tokens.map(|v| v as u32)).unwrap_or(request.max_tokens as u32); let mut tool_depth = 0; let mut total_input_tokens = 0i64; let mut total_output_tokens = 0i64; let session_id = Uuid::now_v7(); let session_start = std::time::Instant::now(); let version_id = room_ai_config.as_ref().and_then(|r| r.version); let config = AiClientConfig::new(ai_api_key.unwrap_or_default()) .with_base_url(ai_base_url.unwrap_or_else(|| "https://api.openai.com".into())); let mut full_content = String::new(); let mut all_chunks: Vec = Vec::new(); let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); loop { let on_chunk_cb = on_chunk.clone(); let on_chunk_cb2 = on_chunk.clone(); let tx_arc = Arc::new(tx.clone()); let tx_arc2 = tx_arc.clone(); let response = call_stream( &messages, &model_name, &config, temperature, max_tokens, if tools_enabled { Some(&tools) } else { None }, None, Arc::new(move |delta| { let content = delta.to_string(); let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer }); fut }), Arc::new(move |delta| { let fut = on_chunk_cb2(AiStreamChunk { content: delta.to_string(), done: false, chunk_type: AiChunkType::Thinking }); fut }), Arc::new(move |tc: &StreamedToolCall| { let tx = tx_arc2.clone(); let tc_owned = tc.clone(); Box::pin(async move { let _ = tx.send(tc_owned); }) as Pin + Send>> }), ).await?; total_input_tokens += response.input_tokens; total_output_tokens += response.output_tokens; all_chunks.extend(response.chunks.clone()); let has_tool_calls = tools_enabled && !response.tool_calls.is_empty(); if !has_tool_calls { return handle_final_answer(response, all_chunks, &request, session_id, version_id, total_input_tokens, total_output_tokens, session_start).await; } full_content.push_str(&response.content); let tool_calls: Vec = response.tool_calls.iter().map(|tc| ToolCall { id: tc.id.clone(), type_: "function".into(), function: crate::client::types::ToolCallFunction { name: tc.name.clone(), arguments: tc.arguments.clone() }, }).collect(); messages.push(ChatRequestMessage::assistant(Some(response.content.clone()), Some(tool_calls.clone()))); drain_tool_call_notifications(&mut rx, &on_chunk, &mut all_chunks).await; let calls: Vec = response.tool_calls.iter().map(|tc| AgentToolCall { id: tc.id.clone(), name: tc.name.clone(), arguments: tc.arguments.clone(), }).collect(); let tool_messages = execute_streaming_tools( &request, &calls, session_id, &on_chunk, &mut all_chunks, tool_registry, message_builder, ).await; messages.extend(tool_messages); inject_passive_skills_stream(&request, message_builder, &response.tool_calls, &mut messages).await; tool_depth += 1; if tool_depth >= max_tool_depth { let max_depth_text = format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth); on_chunk(AiStreamChunk { content: max_depth_text.clone(), done: true, chunk_type: AiChunkType::Answer }).await; all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: max_depth_text }); record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await; return Ok(StreamResult { content: full_content, reasoning_content: String::new(), input_tokens: 0, output_tokens: 0, chunks: all_chunks }); } } } async fn drain_tool_call_notifications( rx: &mut tokio::sync::mpsc::UnboundedReceiver, on_chunk: &SharedCallback, all_chunks: &mut Vec, ) { loop { match rx.try_recv() { Ok(tc) => { let args_display = if tc.arguments.len() > 100 { let end = tc.arguments.char_indices().map(|(i, _)| i).take_while(|&i| i <= 100).last().unwrap_or(100); format!("{}...", &tc.arguments[..end]) } else { tc.arguments.clone() }; let tool_display = format!("🔧 {}({})", tc.name, args_display); on_chunk(AiStreamChunk { content: tool_display.clone(), done: false, chunk_type: AiChunkType::ToolCall }).await; all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: tool_display }); } Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break, Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break, } } } async fn execute_streaming_tools( request: &AiRequest, calls: &[AgentToolCall], session_id: Uuid, on_chunk: &SharedCallback, all_chunks: &mut Vec, tool_registry: &Option, message_builder: &MessageBuilder, ) -> Vec { let mut tool_messages = Vec::new(); let mut ctx = ToolContext::new(request.db.clone(), request.cache.clone(), request.config.clone(), request.room.id, Some(request.sender.uid)).with_project(request.project.id); if let Some(es) = &message_builder.embed_service { ctx = ctx.with_embed_service(es.clone()); } if let Some(registry) = tool_registry { ctx.registry_mut().merge(registry.clone()); } let recorder = crate::tool::recorder::ToolCallRecorder::with_session(request.db.clone(), session_id); let mut join_set = tokio::task::JoinSet::new(); for call in calls { let call_clone = call.clone(); let mut ctx_clone = ctx.clone(); let sender_uid = request.sender.uid; let recorder_clone = recorder.clone(); join_set.spawn(async move { let start = std::time::Instant::now(); let executor = ToolExecutor::new(); let res = executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone).await; (call_clone, res, start.elapsed(), sender_uid, recorder_clone) }); } let heartbeat_dur = std::time::Duration::from_secs(10); while !join_set.is_empty() { tokio::select! { Some(res) = join_set.join_next() => { if let Ok((call, results, elapsed, sender_uid, recorder)) = res { match results { Ok(results) => { for result in &results { let text = match &result.result { crate::tool::ToolResult::Ok(v) => v.to_string(), crate::tool::ToolResult::Error(msg) => msg.clone() }; let preview = if text.len() > 300 { let end = text.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300); format!("{}...", &text[..end]) } else { text.clone() }; tracing::debug!("tool_result: {} — {}", call.name, preview); let is_error = matches!(result.result, crate::tool::ToolResult::Error(_)); let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None }; recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.id.clone(), session_id: recorder.session_id(), tool_name: call.name.clone(), caller: sender_uid, arguments: call.arguments_json().unwrap_or_default(), status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success }, execution_time_ms: Some(elapsed.as_millis() as i64), error_message: error_msg, error_stack: None, retry_count: 0 }); } let success_display = format!("✅ {}", call.name); on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await; all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display }); let msgs = ToolExecutor::to_tool_messages(&results); tool_messages.extend(msgs); } Err(e) => { recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.id.clone(), session_id: recorder.session_id(), tool_name: call.name.clone(), caller: sender_uid, arguments: call.arguments_json().unwrap_or_default(), status: models::ai::ToolCallStatus::Failed, execution_time_ms: Some(elapsed.as_millis() as i64), error_message: Some(e.to_string()), error_stack: None, retry_count: 0 }); let err_text = format!("[Tool call failed: {}]", e); tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed"); let err_display = format!("❌ {} (failed)", call.name); on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await; all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display }); tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text)); } } } }, _ = tokio::time::sleep(heartbeat_dur) => { on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await; } } } tool_messages } async fn handle_final_answer( response: crate::client::StreamResponse, all_chunks: Vec, request: &AiRequest, session_id: Uuid, version_id: Option, total_input_tokens: i64, total_output_tokens: i64, session_start: std::time::Instant, ) -> Result { let full_content = response.content.clone(); // Don't push full content as a chunk — incremental deltas in // response.chunks (already accumulated above) sum to the same text. // merge_consecutive_blocks would concatenate delta_sum + full = // 2× full, causing duplicate content in DB persistence. record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await; Ok(StreamResult { content: full_content, reasoning_content: response.reasoning_content, input_tokens: total_input_tokens, output_tokens: total_output_tokens, chunks: all_chunks }) } async fn inject_passive_skills_stream( request: &AiRequest, message_builder: &MessageBuilder, tool_calls: &[StreamedToolCall], messages: &mut Vec, ) { if let Ok(skills) = project_skill::Entity::find() .filter(project_skill::Column::ProjectUuid.eq(request.project.id)) .filter(project_skill::Column::Enabled.eq(true)).all(&request.db).await { let skill_entries: Vec = skills.into_iter().map(|s| SkillEntry { slug: s.slug, name: s.name, description: s.description, content: s.content }).collect(); let tool_events: Vec = tool_calls.iter().map(|tc| ToolCallEvent { tool_name: tc.name.clone(), arguments: tc.arguments.clone() }).collect(); for event in &tool_events { if let Some(ctx) = message_builder.perception_service.passive.detect(event, &skill_entries) { messages.push(ctx.to_system_message()); } } } }