use std::pin::Pin; use std::sync::Arc; use uuid::Uuid; use crate::client::AiClientConfig; use crate::client::types::{ChatRequestMessage, ToolCall}; use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream}; use crate::embed::EmbedService; use crate::error::Result; use crate::tool::registry::ToolRegistry; use crate::tool::{ ToolCall as AgentToolCall, ToolContext, ToolDefinition, ToolExecutor, ToolHandler, ToolParam, }; use sea_orm::{ActiveModelTrait, EntityTrait, Set}; use super::service::StreamResult; use super::{AiChunkType, AiStreamChunk, StreamCallback}; // Keyword-extraction-based title generator: reads conversation messages, extracts // meaningful words, and updates the conversation record with a short title. async fn generate_title_for_conversation( ctx: &ToolContext, conversation_id: Uuid, ) -> Result { use models::ai::{AiMessage, ai_conversation, ai_message}; use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect}; let db_reader = ctx.db().reader(); let db_writer = ctx.db().writer(); let conv = ai_conversation::Entity::find_by_id(conversation_id) .one(db_reader) .await .map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("db error: {}", e), })? .ok_or_else(|| crate::error::AgentError::NotFound("Conversation not found".into()))?; let recent_messages = AiMessage::find() .filter(ai_message::Column::ConversationId.eq(conversation_id)) .filter(ai_message::Column::Role.eq("user")) .order_by_desc(ai_message::Column::CreatedAt) .limit(3) .all(db_reader) .await .map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("db error: {}", e), })?; if recent_messages.is_empty() { return Err(crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: "No user messages found".into(), }); } let content = recent_messages .first() .and_then(|m| m.content.as_array()) .and_then(|arr| arr.first()) .and_then(|v| v.get("content")) .and_then(|c| c.as_str()) .unwrap_or(""); let words: Vec<&str> = content .split_whitespace() .filter(|w| w.len() > 2 && !is_stop_word(w)) .take(5) .collect(); let title = if words.is_empty() { "New Chat".to_string() } else { words.join(" ") }; let mut active: ai_conversation::ActiveModel = conv.into(); active.title = Set(Some(title.clone())); active.updated_at = Set(chrono::Utc::now()); active .update(db_writer) .await .map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("failed to update title: {}", e), })?; Ok(serde_json::json!({ "conversation_id": conversation_id.to_string(), "title": title })) } fn is_stop_word(w: &str) -> bool { matches!( w.to_lowercase().as_str(), "the" | "this" | "that" | "what" | "which" | "when" | "where" | "why" | "how" | "can" | "could" | "would" | "should" | "please" | "help" | "thanks" | "thank" | "you" | "your" | "have" | "has" | "had" | "with" | "for" | "from" | "into" | "about" | "also" | "just" | "now" | "very" | "really" ) } type SharedCallback = Arc< dyn Fn(AiStreamChunk) -> Pin + Send>> + Send + Sync, >; /// Simplified ReAct execution for Chat API. /// /// Unlike `execute_process_stream` (which requires `AiRequest` with room-specific data), /// this function takes messages and tools directly. It does NOT record AI sessions to /// the `ai_session` table — the caller is responsible for persisting results. pub async fn execute_chat_stream( messages: Vec, tools: Vec, model_name: &str, config: &AiClientConfig, temperature: f32, max_tokens: u32, max_tool_depth: usize, tool_registry: Option<&ToolRegistry>, db: db::database::AppDatabase, cache: db::cache::AppCache, app_config: config::AppConfig, project_id: Uuid, sender_uid: Uuid, embed_service: Option, on_chunk: StreamCallback, conversation_id: Option, ) -> Result { let on_chunk: SharedCallback = Arc::from(on_chunk); let tools_enabled = !tools.is_empty(); let mut messages = messages; let mut tool_depth = 0; let mut total_input_tokens = 0i64; let mut total_output_tokens = 0i64; let mut full_content = String::new(); let mut all_chunks: Vec = Vec::new(); let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); // Conditionally inject chat_generate_title tool if conversation has no title let (tools, _tools_injected) = if let Some(conv_id) = conversation_id { if let Some(registry) = tool_registry { let db_reader = db.reader(); let has_title = models::ai::ai_conversation::Entity::find_by_id(conv_id) .one(db_reader) .await .map(|c| c.map(|m| m.title.is_some()).unwrap_or(false)) .unwrap_or(false); if !has_title { let mut reg = registry.clone(); reg.register( ToolDefinition::new("chat_generate_title") .description( "Generate a concise title (5 words or fewer) for the current conversation \ based on its message history, and save it to the conversation record. \ Call this tool at the start of a new conversation if it has no title.", ) .parameters(crate::tool::ToolSchema { schema_type: "object".into(), properties: Some({ let mut p = std::collections::HashMap::new(); p.insert("conversation_id".into(), ToolParam { name: "conversation_id".into(), param_type: "string".into(), description: Some("The UUID of the conversation (required).".into()), required: true, properties: None, items: None, }); p }), required: Some(vec!["conversation_id".into()]), }), ToolHandler::new(|ctx, args| { let conv_id = args.get("conversation_id") .and_then(|v| v.as_str()) .and_then(|s| Uuid::parse_str(s).ok()); Box::pin(async move { match conv_id { Some(id) => generate_title_for_conversation(&ctx, id).await .map_err(|e| crate::tool::ToolError::ExecutionError(e.to_string())), None => Err(crate::tool::ToolError::ExecutionError("conversation_id missing".into())), } }) }), ); // Prepend system message instructing the model to generate title first messages.insert(0, ChatRequestMessage::system( "IMPORTANT: If the conversation has no title, you MUST call chat_generate_title \ with the conversation_id immediately before answering any user question. \ The title must be 5 words or fewer and should summarize the user's intent.".to_string(), )); (reg.to_openai_tools(), true) } else { (tools.clone(), false) } } else { (tools.clone(), false) } } else { (tools.clone(), false) }; loop { let on_chunk_cb = on_chunk.clone(); let on_chunk_cb2 = on_chunk.clone(); let tx_arc = Arc::new(tx.clone()); let tx_arc2 = tx_arc.clone(); let response = call_stream( &messages, model_name, config, temperature, max_tokens, if tools_enabled { Some(&tools) } else { None }, None, Arc::new(move |delta| { let content = delta.to_string(); let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer, metadata: None, }); fut }), Arc::new(move |delta| { let fut = on_chunk_cb2(AiStreamChunk { content: delta.to_string(), done: false, chunk_type: AiChunkType::Thinking, metadata: None, }); fut }), Arc::new(move |tc: &StreamedToolCall| { let tx = tx_arc2.clone(); let tc_owned = tc.clone(); Box::pin(async move { let _ = tx.send(tc_owned); }) as Pin + Send>> }), ) .await?; total_input_tokens += response.input_tokens; total_output_tokens += response.output_tokens; all_chunks.extend(response.chunks.clone()); let has_tool_calls = tools_enabled && !response.tool_calls.is_empty(); if !has_tool_calls { let final_content = response.content.clone(); // Don't push full content as a chunk — incremental deltas in // response.chunks (already added above) sum to the same text. // merge_consecutive_blocks would concatenate delta_sum + full = // 2× full, causing duplicate content in DB persistence. return Ok(StreamResult { content: final_content, reasoning_content: response.reasoning_content, input_tokens: total_input_tokens, output_tokens: total_output_tokens, chunks: all_chunks, }); } full_content.push_str(&response.content); let tool_calls: Vec = response .tool_calls .iter() .map(|tc| ToolCall { id: tc.id.clone(), type_: "function".into(), function: crate::client::types::ToolCallFunction { name: tc.name.clone(), arguments: tc.arguments.clone(), }, }) .collect(); messages.push(ChatRequestMessage::assistant( Some(response.content.clone()), Some(tool_calls.clone()), )); // Drain tool call notifications loop { match rx.try_recv() { Ok(tc) => { let args_display = if tc.arguments.len() > 100 { let end = tc .arguments .char_indices() .map(|(i, _)| i) .take_while(|&i| i <= 100) .last() .unwrap_or(100); format!("{}...", &tc.arguments[..end]) } else { tc.arguments.clone() }; let tool_display = format!("🔧 {}({})", tc.name, args_display); on_chunk(AiStreamChunk { content: tool_display.clone(), done: false, chunk_type: AiChunkType::ToolCall, metadata: None, }) .await; all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: tool_display, }); } Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break, Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break, } } let calls: Vec = response .tool_calls .iter() .map(|tc| AgentToolCall { id: tc.id.clone(), name: tc.name.clone(), arguments: tc.arguments.clone(), }) .collect(); let tool_messages = execute_tools( &calls, &db, &cache, &app_config, project_id, sender_uid, tool_registry, embed_service.as_ref(), &on_chunk, &mut all_chunks, ) .await; messages.extend(tool_messages); tool_depth += 1; if tool_depth >= max_tool_depth { let max_depth_text = format!( "[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth ); on_chunk(AiStreamChunk { content: max_depth_text.clone(), done: true, chunk_type: AiChunkType::Answer, metadata: None, }) .await; all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: max_depth_text, }); return Ok(StreamResult { content: full_content, reasoning_content: String::new(), input_tokens: 0, output_tokens: 0, chunks: all_chunks, }); } } } async fn execute_tools( calls: &[AgentToolCall], db: &db::database::AppDatabase, cache: &db::cache::AppCache, app_config: &config::AppConfig, project_id: Uuid, sender_uid: Uuid, tool_registry: Option<&ToolRegistry>, embed_service: Option<&EmbedService>, on_chunk: &SharedCallback, all_chunks: &mut Vec, ) -> Vec { let mut tool_messages = Vec::new(); let mut ctx = ToolContext::new( db.clone(), cache.clone(), app_config.clone(), Uuid::nil(), Some(sender_uid), ) .with_project(project_id); if let Some(es) = embed_service { ctx = ctx.with_embed_service(es.clone()); } if let Some(registry) = tool_registry { ctx.registry_mut().merge(registry.clone()); } let mut join_set = tokio::task::JoinSet::new(); for call in calls { let call_clone = call.clone(); let mut ctx_clone = ctx.clone(); join_set.spawn(async move { let executor = ToolExecutor::new(); let res = executor .execute_batch(vec![call_clone.clone()], &mut ctx_clone) .await; (call_clone, res) }); } let heartbeat_dur = std::time::Duration::from_secs(10); while !join_set.is_empty() { tokio::select! { Some(res) = join_set.join_next() => { if let Ok((call, results)) = res { match results { Ok(results) => { for result in &results { let preview = match &result.result { crate::tool::ToolResult::Ok(v) => { let t = v.to_string(); if t.len() > 300 { let end = t.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300); format!("{}...", &t[..end]) } else { t.clone() } } crate::tool::ToolResult::Error(msg) => msg.clone(), }; tracing::debug!("tool_result: {} — {}", call.name, preview); } let success_display = format!("✅ {}", call.name); on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None }).await; all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display }); let msgs = ToolExecutor::to_tool_messages(&results); tool_messages.extend(msgs); } Err(e) => { tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed"); let err_text = format!("[Tool call failed: {}]", e); let err_display = format!("❌ {} (failed)", call.name); on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None }).await; all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display }); tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text)); } } } }, _ = tokio::time::sleep(heartbeat_dur) => { on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall, metadata: None }).await; } } } tool_messages }