diff --git a/libs/room/src/service/ai_nonstreaming.rs b/libs/room/src/service/ai_nonstreaming.rs index e7b59eb..e8a12b3 100644 --- a/libs/room/src/service/ai_nonstreaming.rs +++ b/libs/room/src/service/ai_nonstreaming.rs @@ -75,7 +75,7 @@ pub async fn process_message_ai_nonstreaming( room_id, project_id, Uuid::now_v7(), - format!("[AI error: {}]", e), + "[AI 处理发生错误,请稍后再试]".to_string(), model_id, Some(model_display_name), ) diff --git a/libs/room/src/service/ai_react_nonstreaming.rs b/libs/room/src/service/ai_react_nonstreaming.rs index 53d5d4b..d558d47 100644 --- a/libs/room/src/service/ai_react_nonstreaming.rs +++ b/libs/room/src/service/ai_react_nonstreaming.rs @@ -29,7 +29,7 @@ pub async fn process_message_ai_react_nonstreaming( let model_display_name = request.model.name.clone(); let final_answer = chat_service - .process_react(&request, |_step| {}) + .process_react(&request, |_step| async move {}) .await; match final_answer { @@ -77,7 +77,7 @@ pub async fn process_message_ai_react_nonstreaming( room_id, project_id, Uuid::now_v7(), - format!("[AI error: {}]", e), + "[AI 处理发生错误,请稍后再试]".to_string(), model_id, Some(model_display_name), ) diff --git a/libs/room/src/service/ai_react_streaming.rs b/libs/room/src/service/ai_react_streaming.rs index 535bacf..fcb5d15 100644 --- a/libs/room/src/service/ai_react_streaming.rs +++ b/libs/room/src/service/ai_react_streaming.rs @@ -44,9 +44,41 @@ pub async fn process_message_ai_react_streaming( tokio::spawn(async move { let _lock_guard = lock_guard; + let cancel = room_manager.register_stream_cancel(room_id_inner).await; + + let ai_typing_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001") + .expect("constant UUID should always parse"); + + let typing_start = queue::TypingEvent { + room_id: room_id_inner, + user_id: ai_typing_id, + username: ai_display_name.clone(), + avatar_url: None, + action: "start".to_string(), + sender_type: Some("ai".to_string()), + }; + room_manager.broadcast_typing(room_id_inner, typing_start.clone()).await; + + let (typing_cancel_tx, mut typing_cancel_rx) = tokio::sync::oneshot::channel::<()>(); + let typing_renew_handle = tokio::spawn({ + let mut interval = tokio::time::interval(std::time::Duration::from_secs(30)); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + let mgr = room_manager.clone(); + let rid = room_id_inner; + let evt = typing_start.clone(); + async move { + tokio::select! { + _ = &mut typing_cancel_rx => {} + _ = async { + loop { + interval.tick().await; + mgr.broadcast_typing(rid, evt.clone()).await; + } + } => {} + } + } + }); - // Collect ordered steps for storage and streaming. - // Using poison-recovering guards to prevent Mutex poisoning from killing the room. let steps: std::sync::Arc>> = std::sync::Arc::new(std::sync::Mutex::new(Vec::new())); let last_action_name: std::sync::Arc> = @@ -54,15 +86,17 @@ pub async fn process_message_ai_react_streaming( let answer_buffer: std::sync::Arc> = std::sync::Arc::new(std::sync::Mutex::new(String::new())); let step_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); - let chunk_seq: std::sync::Arc = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)); + let chunk_seq: std::sync::Arc = + std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)); - // Helper: recover from poison instead of panicking. fn lock_or_recover(mutex: &std::sync::Mutex) -> std::sync::MutexGuard<'_, T> { mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner()) } let on_step = { let room_manager = room_manager.clone(); + let queue = queue.clone(); + let cancel = cancel.clone(); let streaming_msg_id = streaming_msg_id; let room_id = room_id_inner; let step_count = step_count.clone(); @@ -73,6 +107,8 @@ pub async fn process_message_ai_react_streaming( let last_action_name = last_action_name.clone(); move |step: ReactStep| { let room_manager = room_manager.clone(); + let queue = queue.clone(); + let cancel = cancel.clone(); let (chunk_type, content) = match &step { ReactStep::Thought { step: _, thought } => { ("thinking".to_string(), thought.clone()) @@ -100,8 +136,6 @@ pub async fn process_message_ai_react_streaming( step_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); } - // Record ordered step for storage — merge consecutive same-type chunks - // to ensure strict think→answer→think→answer alternation. { let mut s = lock_or_recover(&steps); if let Some(last) = s.last_mut() { @@ -122,43 +156,47 @@ pub async fn process_message_ai_react_streaming( let done = false; let ai_name = ai_display_name_for_step.clone(); let current_seq = chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - tokio::spawn(async move { - let event = RoomMessageStreamChunkEvent { - message_id: streaming_msg_id, - room_id, - seq: current_seq, - content: content.clone(), - done, - error: None, - display_name: Some((*ai_name).clone()), - chunk_type: Some(chunk_type), - }; + let event = RoomMessageStreamChunkEvent { + message_id: streaming_msg_id, + room_id, + seq: current_seq, + content: content.clone(), + done, + error: None, + display_name: Some((*ai_name).clone()), + chunk_type: Some(chunk_type), + }; + + async move { + if cancel.load(std::sync::atomic::Ordering::Acquire) { + return; + } + queue.publish_stream_chunk(&event).await; room_manager.broadcast_stream_chunk(event).await; - }); + } } }; let result = chat_service.process_react(&request, on_step).await; - // Broadcast final done=true event to close the streaming channel on frontend. let final_stream_content = lock_or_recover(&answer_buffer).clone(); - room_manager - .broadcast_stream_chunk(RoomMessageStreamChunkEvent { - message_id: streaming_msg_id, - room_id: room_id_inner, - seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed), - content: final_stream_content.clone(), - done: true, - error: None, - display_name: Some(ai_display_name.clone()), - chunk_type: Some("answer".to_string()), - }) - .await; + let final_event = RoomMessageStreamChunkEvent { + message_id: streaming_msg_id, + room_id: room_id_inner, + seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed), + content: final_stream_content.clone(), + done: true, + error: None, + display_name: Some(ai_display_name.clone()), + chunk_type: Some("answer".to_string()), + }; + queue.publish_stream_chunk(&final_event).await; + room_manager.broadcast_stream_chunk(final_event).await; let (final_content, _input_tokens, _output_tokens, err_msg, _should_log) = match result { Ok((content, input, output)) => (content, input, output, None, false), Err(e) => { - let msg = format!("[Agent error: {}]", e); + let msg = "[AI 处理发生错误,请稍后再试]".to_string(); tracing::error!(error = %e, "ReAct streaming failed"); (String::new(), 0, 0, Some(msg), true) } @@ -183,12 +221,7 @@ pub async fn process_message_ai_react_streaming( String::from("[No output from reasoning agent]") }; let content_to_persist = if let Some(msg) = &err_msg { - format!( - "{}\n[Error during reasoning: {}]", - content_to_persist.trim_end(), - msg.trim_start_matches("[Agent error: ") - .trim_end_matches("]") - ) + format!("{}\n[Error during reasoning: {}]", content_to_persist.trim_end(), msg) } else { content_to_persist }; @@ -198,7 +231,6 @@ pub async fn process_message_ai_react_streaming( return; } - // Serialize ordered steps as JSON for ordered replay. let thinking_content_serialized = { let steps = lock_or_recover(&steps); if steps.is_empty() { @@ -250,7 +282,6 @@ pub async fn process_message_ai_react_streaming( tracing::warn!(error = %e, "Failed to update room_ai call stats"); } - // Billing handled internally by chat_service.process_react via record_ai_session let msg_event = queue::RoomMessageEvent { id: streaming_msg_id, room_id: room_id_inner, @@ -284,6 +315,19 @@ pub async fn process_message_ai_react_streaming( .await; } + let _ = typing_cancel_tx.send(()); + typing_renew_handle.abort(); + let typing_stop = queue::TypingEvent { + room_id: room_id_inner, + user_id: ai_typing_id, + username: ai_display_name.clone(), + avatar_url: None, + action: "stop".to_string(), + sender_type: Some("ai".to_string()), + }; + room_manager.broadcast_typing(room_id_inner, typing_stop).await; + + room_manager.unregister_stream_cancel(room_id_inner).await; room_manager.close_stream_channel(streaming_msg_id).await; }); } diff --git a/libs/room/src/service/ai_streaming.rs b/libs/room/src/service/ai_streaming.rs index 74bfdb2..822b00f 100644 --- a/libs/room/src/service/ai_streaming.rs +++ b/libs/room/src/service/ai_streaming.rs @@ -60,6 +60,7 @@ pub async fn process_message_ai_streaming( tokio::spawn(async move { let _lock_guard = lock_guard; + let cancel = room_manager.register_stream_cancel(room_id_inner).await; let ai_typing_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001") .expect("constant UUID should always parse"); let ai_display_name_for_chunk = ai_display_name.clone(); @@ -67,15 +68,22 @@ pub async fn process_message_ai_streaming( let chunk_count = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)); let room_manager_cb = room_manager.clone(); + let queue_for_chunk = queue.clone(); let on_chunk = move |chunk: agent::chat::AiStreamChunk| { Box::pin({ let room_manager = room_manager_cb.clone(); + let queue = queue_for_chunk.clone(); let streaming_msg_id = streaming_msg_id; let room_id = room_id_inner; let chunk_count = chunk_count.clone(); let ai_display_name_for_chunk = ai_display_name_for_chunk.clone(); + let cancel = cancel.clone(); async move { + if cancel.load(std::sync::atomic::Ordering::Acquire) { + // Stream was cancelled — drop this chunk + return; + } let chunk_type_str = match chunk.chunk_type { agent::chat::AiChunkType::Thinking => "thinking", agent::chat::AiChunkType::Answer => "answer", @@ -93,6 +101,7 @@ pub async fn process_message_ai_streaming( display_name: Some(ai_display_name_for_chunk), chunk_type: Some(chunk_type_str.to_string()), }; + queue.publish_stream_chunk(&event).await; room_manager.broadcast_stream_chunk(event).await; } }) as Pin + Send>> @@ -257,14 +266,16 @@ pub async fn process_message_ai_streaming( seq: 0, content: String::new(), done: true, - error: Some(e.to_string()), + error: Some("AI 处理发生错误,请稍后再试".to_string()), display_name: Some(ai_display_name.clone()), chunk_type: None, }; + queue.publish_stream_chunk(&event).await; room_manager.broadcast_stream_chunk(event).await; } } + room_manager.unregister_stream_cancel(room_id_inner).await; room_manager.close_stream_channel(streaming_msg_id).await; }); } diff --git a/libs/room/src/service/mod.rs b/libs/room/src/service/mod.rs index e632362..0b0c198 100644 --- a/libs/room/src/service/mod.rs +++ b/libs/room/src/service/mod.rs @@ -1,18 +1,23 @@ mod access; mod ai_common; +mod ai_mode_dispatch; +mod ai_mode_streaming; mod ai_nonstreaming; mod ai_react_nonstreaming; mod ai_react_streaming; +mod ai_service; mod ai_streaming; mod history; mod mentions; mod notifications; mod patterns; +pub use patterns::{mention_bracket_re, mention_tag_re, user_mention_re}; mod sequence; mod workers; pub use access::{check_room_access, check_project_member, require_room_member, find_room_or_404}; pub use ai_common::create_and_publish_ai_message; +pub use ai_service::RoomAiService; pub use ai_nonstreaming::process_message_ai_nonstreaming; pub use ai_react_nonstreaming::process_message_ai_react_nonstreaming; pub use ai_react_streaming::process_message_ai_react_streaming; @@ -41,8 +46,6 @@ use agent::embed::EmbedService; use agent::TaskService; use models::agent_task::AgentType; -use crate::service::patterns::{mention_bracket_re, mention_tag_re}; - const DEFAULT_MAX_CONCURRENT_WORKERS: usize = 1024; #[derive(Clone)] @@ -57,6 +60,7 @@ pub struct RoomService { pub task_service: Option>, pub embed_service: Option>, pub push_fn: Option, + pub ai_service: RoomAiService, worker_semaphore: Arc, dedup_cache: DedupCache, } @@ -77,6 +81,14 @@ impl RoomService { ) -> Self { let dedup_cache: DedupCache = Arc::new(dashmap::DashMap::with_capacity_and_hasher(10000, Default::default())); + let ai_service = RoomAiService::new( + db.clone(), + cache.clone(), + config.clone(), + queue.clone(), + room_manager.clone(), + chat_service.clone(), + ); Self { db, cache, @@ -87,6 +99,7 @@ impl RoomService { chat_service, task_service, embed_service, + ai_service, worker_semaphore: Arc::new(tokio::sync::Semaphore::new( max_concurrent_workers.unwrap_or(DEFAULT_MAX_CONCURRENT_WORKERS), )), @@ -258,34 +271,7 @@ impl RoomService { } pub async fn should_ai_respond(&self, room_id: Uuid, content: &str) -> Result { - let ai_configs = history::get_room_ai_configs(&self.db, room_id).await?; - if ai_configs.is_empty() { - return Ok(false); - } - - // Collect all model IDs in this room - let model_ids: std::collections::HashSet = ai_configs - .iter() - .map(|c| c.model.to_string()) - .collect(); - - for cap in mention_bracket_re().captures_iter(content) { - if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) { - if type_m.as_str() == "ai" && model_ids.contains(id_m.as_str().trim()) { - return Ok(true); - } - } - } - - for cap in mention_tag_re().captures_iter(content) { - if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) { - if type_m.as_str() == "ai" && model_ids.contains(id_m.as_str().trim()) { - return Ok(true); - } - } - } - - Ok(false) + self.ai_service.should_respond(room_id, content).await } pub async fn get_room_ai_config( @@ -421,66 +407,72 @@ impl RoomService { }; let use_streaming = ai_config.stream; - let is_react = ai_config.agent_type.as_deref() == Some("react"); - if is_react { - if use_streaming { - ai_react_streaming::process_message_ai_react_streaming( - chat_service.clone(), - request, - room_id, - room.project, - model_id, - lock_guard, - self.db.clone(), - self.cache.clone(), - self.queue.clone(), - self.room_manager.clone(), - ) - .await; - } else { - ai_react_nonstreaming::process_message_ai_react_nonstreaming( - chat_service.clone(), - request, - room_id, - room.project, - model_id, - lock_guard, - self.db.clone(), - self.cache.clone(), - self.queue.clone(), - self.room_manager.clone(), - ) - .await; + match ai_config.agent_type.as_deref() { + Some("cot") => { + if use_streaming { + ai_mode_dispatch::dispatch_cot( + chat_service.clone(), request, room_id, room.project, model_id, + lock_guard, self.db.clone(), self.cache.clone(), + self.queue.clone(), self.room_manager.clone(), + ).await; + } else { + if let Ok((content, _in, _out)) = chat_service.process_cot(&request, |_step| async {}).await { + let _ = create_and_publish_ai_message( + &self.db, &self.cache, &self.queue, &self.room_manager, + room_id, room.project, uuid::Uuid::new_v4(), content, model_id, + Some(request.model.name.clone()), + ).await; + } + } + } + Some("rewoo") => { + if use_streaming { + ai_mode_dispatch::dispatch_rewoo( + chat_service.clone(), request, room_id, room.project, model_id, + lock_guard, self.db.clone(), self.cache.clone(), + self.queue.clone(), self.room_manager.clone(), + ).await; + } + } + Some("reflexion") => { + if use_streaming { + ai_mode_dispatch::dispatch_reflexion( + chat_service.clone(), request, room_id, room.project, model_id, + lock_guard, self.db.clone(), self.cache.clone(), + self.queue.clone(), self.room_manager.clone(), + ).await; + } + } + Some("react") | _ => { + if ai_config.agent_type.as_deref() == Some("react") { + if use_streaming { + ai_react_streaming::process_message_ai_react_streaming( + chat_service.clone(), request, room_id, room.project, model_id, + lock_guard, self.db.clone(), self.cache.clone(), + self.queue.clone(), self.room_manager.clone(), + ).await; + } else { + ai_react_nonstreaming::process_message_ai_react_nonstreaming( + chat_service.clone(), request, room_id, room.project, model_id, + lock_guard, self.db.clone(), self.cache.clone(), + self.queue.clone(), self.room_manager.clone(), + ).await; + } + } else if use_streaming { + ai_streaming::process_message_ai_streaming( + chat_service.clone(), request, room_id, room.project, model_id, + lock_guard, self.db.clone(), self.cache.clone(), + self.queue.clone(), self.room_manager.clone(), + ).await; + } else { + ai_nonstreaming::process_message_ai_nonstreaming( + chat_service.clone(), request, room_id, room.project, model_id, + lock_guard, self.db.clone(), self.cache.clone(), + self.queue.clone(), self.room_manager.clone(), + ).await; + } } - } else if use_streaming { - ai_streaming::process_message_ai_streaming( - chat_service.clone(), - request, - room_id, - room.project, - model_id, - lock_guard, - self.db.clone(), - self.cache.clone(), - self.queue.clone(), - self.room_manager.clone(), - ) - .await; - } else { - ai_nonstreaming::process_message_ai_nonstreaming( - chat_service.clone(), - request, - room_id, - room.project, - model_id, - lock_guard, - self.db.clone(), - self.cache.clone(), - self.queue.clone(), - self.room_manager.clone(), - ) - .await; } Ok(()) diff --git a/libs/room/src/service/sequence.rs b/libs/room/src/service/sequence.rs index 965978e..e00eed2 100644 --- a/libs/room/src/service/sequence.rs +++ b/libs/room/src/service/sequence.rs @@ -6,6 +6,21 @@ use uuid::Uuid; use crate::error::RoomError; +/// Redis Lua script that atomically INCRs the sequence number and +/// reconciles with the database max seq every 1000 increments. +/// Returns the final assigned seq (guaranteed > any existing message seq). +const ATOMIC_INCR_SCRIPT: &str = r#" +local seq = redis.call('INCR', KEYS[1]) +if seq % 1000 == 0 then + local db_seq = tonumber(ARGV[1]) or 0 + if db_seq >= seq then + redis.call('SET', KEYS[1], db_seq + 1) + return db_seq + 1 + end +end +return seq +"#; + pub async fn next_room_message_seq_internal( room_id: Uuid, db: &AppDatabase, @@ -16,34 +31,24 @@ pub async fn next_room_message_seq_internal( RoomError::Internal(format!("failed to get redis connection for seq: {}", e)) })?; - let seq: i64 = redis::cmd("INCR") - .arg(&seq_key) - .query_async(&mut conn) + let db_seq: i64 = RoomMessage::find() + .filter(RmCol::Room.eq(room_id)) + .select_only() + .column_as(RmCol::Seq.max(), "max_seq") + .into_tuple::>>() + .one(db) + .await? + .flatten() + .flatten() + .unwrap_or(0); + + let script = redis::Script::new(ATOMIC_INCR_SCRIPT); + let seq: i64 = script + .key(&seq_key) + .arg(db_seq) + .invoke_async(&mut conn) .await - .map_err(|e| RoomError::Internal(format!("seq INCR: {}", e)))?; - - // DB reconciliation: only check every 1000 messages - if seq % 1000 == 0 { - let db_seq: Option>> = RoomMessage::find() - .filter(RmCol::Room.eq(room_id)) - .select_only() - .column_as(RmCol::Seq.max(), "max_seq") - .into_tuple::>>() - .one(db) - .await? - .map(|r| r); - let db_seq = db_seq.flatten().flatten().unwrap_or(0); - - if db_seq >= seq { - let _: String = redis::cmd("SET") - .arg(&seq_key) - .arg(db_seq + 1) - .query_async(&mut conn) - .await - .map_err(|e| RoomError::Internal(format!("seq SET: {}", e)))?; - return Ok(db_seq + 1); - } - } + .map_err(|e| RoomError::Internal(format!("seq atomic INCR: {}", e)))?; Ok(seq) } diff --git a/libs/room/src/service/workers.rs b/libs/room/src/service/workers.rs index 44e4602..61a19e4 100644 --- a/libs/room/src/service/workers.rs +++ b/libs/room/src/service/workers.rs @@ -5,7 +5,7 @@ use db::cache::AppCache; use db::database::AppDatabase; use models::rooms::room; use queue::{AgentTaskEvent, MessageProducer}; -use sea_orm::{EntityTrait, QuerySelect}; +use sea_orm::EntityTrait; use uuid::Uuid; use crate::connection::{ @@ -28,11 +28,10 @@ pub async fn start_workers( mut shutdown_rx: tokio::sync::broadcast::Receiver<()>, embed_service: Option>, ) -> anyhow::Result<()> { - // Load rooms with a reasonable cap to prevent resource exhaustion on large instances. - // Rooms beyond this limit will be activated on-demand when first accessed. - const MAX_INITIAL_ROOMS: u64 = 1000; + // Load all rooms. For large deployments with thousands of rooms, + // consider implementing distributed worker sharding (consistent hashing) + // to avoid all rooms being handled by a single instance. let rooms: Vec = room::Entity::find() - .limit(MAX_INITIAL_ROOMS) .all(&db) .await?; let room_ids: Vec = rooms.iter().map(|r| r.id).collect(); @@ -62,6 +61,7 @@ pub async fn start_workers( extract_get_redis(queue.clone()); let worker_room_ids = room_ids.clone(); + let stream_chunk_room_ids = room_ids.clone(); let worker_shutdown = shutdown_rx.resubscribe(); let worker_handle = tokio::spawn({ let get_redis = get_redis.clone(); @@ -92,6 +92,25 @@ pub async fn start_workers( }) .collect(); + let stream_chunk_handles: Vec<_> = stream_chunk_room_ids + .into_iter() + .map(|room_id| { + let manager = manager.clone(); + let redis_url = redis_url_clone.clone(); + let shutdown_rx = shutdown_rx.resubscribe(); + tokio::spawn(async move { + crate::connection::subscribe_room_stream_chunk_events( + redis_url, + manager, + room_id, + shutdown_rx, + ) + .await; + }) + }) + .collect(); + handles.extend(stream_chunk_handles); + let project_handles: Vec<_> = project_ids .into_iter() .map(|project_id| { @@ -289,15 +308,17 @@ pub fn spawn_room_workers( Default::default(), ), ), - embed_service, + embed_service.clone(), ); let get_redis: Arc queue::worker::RedisFuture + Send + Sync> = extract_get_redis(queue.clone()); let manager1 = room_manager.clone(); let manager2 = room_manager.clone(); let manager3 = room_manager.clone(); + let manager4 = room_manager.clone(); let redis_url_clone = redis_url.clone(); let redis_url3 = redis_url.clone(); + let redis_url4 = redis_url.clone(); let semaphore = worker_semaphore.clone(); tokio::spawn(async move { @@ -350,4 +371,15 @@ pub fn spawn_room_workers( ) .await; }); + + tokio::spawn(async move { + let shutdown_rx = manager4.register_room(room_id).await; + crate::connection::subscribe_room_stream_chunk_events( + redis_url4, + manager4, + room_id, + shutdown_rx, + ) + .await; + }); }