refactor(room): refactor AI service modules for cleaner separation

Simplify ai_streaming by delegating to ai_mode_streaming.
Extract sequence coordination into dedicated module.
Add worker pool management for concurrent AI task handling.
Refine ai_react_streaming for better delta chunk handling.
This commit is contained in:
ZhenYi 2026-04-30 19:16:23 +08:00
parent 4ba47370be
commit 5b81e7d774
7 changed files with 249 additions and 165 deletions

View File

@ -75,7 +75,7 @@ pub async fn process_message_ai_nonstreaming(
room_id,
project_id,
Uuid::now_v7(),
format!("[AI error: {}]", e),
"[AI 处理发生错误,请稍后再试]".to_string(),
model_id,
Some(model_display_name),
)

View File

@ -29,7 +29,7 @@ pub async fn process_message_ai_react_nonstreaming(
let model_display_name = request.model.name.clone();
let final_answer = chat_service
.process_react(&request, |_step| {})
.process_react(&request, |_step| async move {})
.await;
match final_answer {
@ -77,7 +77,7 @@ pub async fn process_message_ai_react_nonstreaming(
room_id,
project_id,
Uuid::now_v7(),
format!("[AI error: {}]", e),
"[AI 处理发生错误,请稍后再试]".to_string(),
model_id,
Some(model_display_name),
)

View File

@ -44,9 +44,41 @@ pub async fn process_message_ai_react_streaming(
tokio::spawn(async move {
let _lock_guard = lock_guard;
let cancel = room_manager.register_stream_cancel(room_id_inner).await;
let ai_typing_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001")
.expect("constant UUID should always parse");
let typing_start = queue::TypingEvent {
room_id: room_id_inner,
user_id: ai_typing_id,
username: ai_display_name.clone(),
avatar_url: None,
action: "start".to_string(),
sender_type: Some("ai".to_string()),
};
room_manager.broadcast_typing(room_id_inner, typing_start.clone()).await;
let (typing_cancel_tx, mut typing_cancel_rx) = tokio::sync::oneshot::channel::<()>();
let typing_renew_handle = tokio::spawn({
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mgr = room_manager.clone();
let rid = room_id_inner;
let evt = typing_start.clone();
async move {
tokio::select! {
_ = &mut typing_cancel_rx => {}
_ = async {
loop {
interval.tick().await;
mgr.broadcast_typing(rid, evt.clone()).await;
}
} => {}
}
}
});
// Collect ordered steps for storage and streaming.
// Using poison-recovering guards to prevent Mutex poisoning from killing the room.
let steps: std::sync::Arc<std::sync::Mutex<Vec<(String, String)>>> =
std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let last_action_name: std::sync::Arc<std::sync::Mutex<String>> =
@ -54,15 +86,17 @@ pub async fn process_message_ai_react_streaming(
let answer_buffer: std::sync::Arc<std::sync::Mutex<String>> =
std::sync::Arc::new(std::sync::Mutex::new(String::new()));
let step_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
let chunk_seq: std::sync::Arc<std::sync::atomic::AtomicU64> = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1));
let chunk_seq: std::sync::Arc<std::sync::atomic::AtomicU64> =
std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1));
// Helper: recover from poison instead of panicking.
fn lock_or_recover<T>(mutex: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> {
mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
}
let on_step = {
let room_manager = room_manager.clone();
let queue = queue.clone();
let cancel = cancel.clone();
let streaming_msg_id = streaming_msg_id;
let room_id = room_id_inner;
let step_count = step_count.clone();
@ -73,6 +107,8 @@ pub async fn process_message_ai_react_streaming(
let last_action_name = last_action_name.clone();
move |step: ReactStep| {
let room_manager = room_manager.clone();
let queue = queue.clone();
let cancel = cancel.clone();
let (chunk_type, content) = match &step {
ReactStep::Thought { step: _, thought } => {
("thinking".to_string(), thought.clone())
@ -100,8 +136,6 @@ pub async fn process_message_ai_react_streaming(
step_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
// Record ordered step for storage — merge consecutive same-type chunks
// to ensure strict think→answer→think→answer alternation.
{
let mut s = lock_or_recover(&steps);
if let Some(last) = s.last_mut() {
@ -122,43 +156,47 @@ pub async fn process_message_ai_react_streaming(
let done = false;
let ai_name = ai_display_name_for_step.clone();
let current_seq = chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
tokio::spawn(async move {
let event = RoomMessageStreamChunkEvent {
message_id: streaming_msg_id,
room_id,
seq: current_seq,
content: content.clone(),
done,
error: None,
display_name: Some((*ai_name).clone()),
chunk_type: Some(chunk_type),
};
let event = RoomMessageStreamChunkEvent {
message_id: streaming_msg_id,
room_id,
seq: current_seq,
content: content.clone(),
done,
error: None,
display_name: Some((*ai_name).clone()),
chunk_type: Some(chunk_type),
};
async move {
if cancel.load(std::sync::atomic::Ordering::Acquire) {
return;
}
queue.publish_stream_chunk(&event).await;
room_manager.broadcast_stream_chunk(event).await;
});
}
}
};
let result = chat_service.process_react(&request, on_step).await;
// Broadcast final done=true event to close the streaming channel on frontend.
let final_stream_content = lock_or_recover(&answer_buffer).clone();
room_manager
.broadcast_stream_chunk(RoomMessageStreamChunkEvent {
message_id: streaming_msg_id,
room_id: room_id_inner,
seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
content: final_stream_content.clone(),
done: true,
error: None,
display_name: Some(ai_display_name.clone()),
chunk_type: Some("answer".to_string()),
})
.await;
let final_event = RoomMessageStreamChunkEvent {
message_id: streaming_msg_id,
room_id: room_id_inner,
seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
content: final_stream_content.clone(),
done: true,
error: None,
display_name: Some(ai_display_name.clone()),
chunk_type: Some("answer".to_string()),
};
queue.publish_stream_chunk(&final_event).await;
room_manager.broadcast_stream_chunk(final_event).await;
let (final_content, _input_tokens, _output_tokens, err_msg, _should_log) = match result {
Ok((content, input, output)) => (content, input, output, None, false),
Err(e) => {
let msg = format!("[Agent error: {}]", e);
let msg = "[AI 处理发生错误,请稍后再试]".to_string();
tracing::error!(error = %e, "ReAct streaming failed");
(String::new(), 0, 0, Some(msg), true)
}
@ -183,12 +221,7 @@ pub async fn process_message_ai_react_streaming(
String::from("[No output from reasoning agent]")
};
let content_to_persist = if let Some(msg) = &err_msg {
format!(
"{}\n[Error during reasoning: {}]",
content_to_persist.trim_end(),
msg.trim_start_matches("[Agent error: ")
.trim_end_matches("]")
)
format!("{}\n[Error during reasoning: {}]", content_to_persist.trim_end(), msg)
} else {
content_to_persist
};
@ -198,7 +231,6 @@ pub async fn process_message_ai_react_streaming(
return;
}
// Serialize ordered steps as JSON for ordered replay.
let thinking_content_serialized = {
let steps = lock_or_recover(&steps);
if steps.is_empty() {
@ -250,7 +282,6 @@ pub async fn process_message_ai_react_streaming(
tracing::warn!(error = %e, "Failed to update room_ai call stats");
}
// Billing handled internally by chat_service.process_react via record_ai_session
let msg_event = queue::RoomMessageEvent {
id: streaming_msg_id,
room_id: room_id_inner,
@ -284,6 +315,19 @@ pub async fn process_message_ai_react_streaming(
.await;
}
let _ = typing_cancel_tx.send(());
typing_renew_handle.abort();
let typing_stop = queue::TypingEvent {
room_id: room_id_inner,
user_id: ai_typing_id,
username: ai_display_name.clone(),
avatar_url: None,
action: "stop".to_string(),
sender_type: Some("ai".to_string()),
};
room_manager.broadcast_typing(room_id_inner, typing_stop).await;
room_manager.unregister_stream_cancel(room_id_inner).await;
room_manager.close_stream_channel(streaming_msg_id).await;
});
}

View File

@ -60,6 +60,7 @@ pub async fn process_message_ai_streaming(
tokio::spawn(async move {
let _lock_guard = lock_guard;
let cancel = room_manager.register_stream_cancel(room_id_inner).await;
let ai_typing_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001")
.expect("constant UUID should always parse");
let ai_display_name_for_chunk = ai_display_name.clone();
@ -67,15 +68,22 @@ pub async fn process_message_ai_streaming(
let chunk_count = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0));
let room_manager_cb = room_manager.clone();
let queue_for_chunk = queue.clone();
let on_chunk = move |chunk: agent::chat::AiStreamChunk| {
Box::pin({
let room_manager = room_manager_cb.clone();
let queue = queue_for_chunk.clone();
let streaming_msg_id = streaming_msg_id;
let room_id = room_id_inner;
let chunk_count = chunk_count.clone();
let ai_display_name_for_chunk = ai_display_name_for_chunk.clone();
let cancel = cancel.clone();
async move {
if cancel.load(std::sync::atomic::Ordering::Acquire) {
// Stream was cancelled — drop this chunk
return;
}
let chunk_type_str = match chunk.chunk_type {
agent::chat::AiChunkType::Thinking => "thinking",
agent::chat::AiChunkType::Answer => "answer",
@ -93,6 +101,7 @@ pub async fn process_message_ai_streaming(
display_name: Some(ai_display_name_for_chunk),
chunk_type: Some(chunk_type_str.to_string()),
};
queue.publish_stream_chunk(&event).await;
room_manager.broadcast_stream_chunk(event).await;
}
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
@ -257,14 +266,16 @@ pub async fn process_message_ai_streaming(
seq: 0,
content: String::new(),
done: true,
error: Some(e.to_string()),
error: Some("AI 处理发生错误,请稍后再试".to_string()),
display_name: Some(ai_display_name.clone()),
chunk_type: None,
};
queue.publish_stream_chunk(&event).await;
room_manager.broadcast_stream_chunk(event).await;
}
}
room_manager.unregister_stream_cancel(room_id_inner).await;
room_manager.close_stream_channel(streaming_msg_id).await;
});
}

View File

@ -1,18 +1,23 @@
mod access;
mod ai_common;
mod ai_mode_dispatch;
mod ai_mode_streaming;
mod ai_nonstreaming;
mod ai_react_nonstreaming;
mod ai_react_streaming;
mod ai_service;
mod ai_streaming;
mod history;
mod mentions;
mod notifications;
mod patterns;
pub use patterns::{mention_bracket_re, mention_tag_re, user_mention_re};
mod sequence;
mod workers;
pub use access::{check_room_access, check_project_member, require_room_member, find_room_or_404};
pub use ai_common::create_and_publish_ai_message;
pub use ai_service::RoomAiService;
pub use ai_nonstreaming::process_message_ai_nonstreaming;
pub use ai_react_nonstreaming::process_message_ai_react_nonstreaming;
pub use ai_react_streaming::process_message_ai_react_streaming;
@ -41,8 +46,6 @@ use agent::embed::EmbedService;
use agent::TaskService;
use models::agent_task::AgentType;
use crate::service::patterns::{mention_bracket_re, mention_tag_re};
const DEFAULT_MAX_CONCURRENT_WORKERS: usize = 1024;
#[derive(Clone)]
@ -57,6 +60,7 @@ pub struct RoomService {
pub task_service: Option<Arc<TaskService>>,
pub embed_service: Option<Arc<EmbedService>>,
pub push_fn: Option<workers::PushNotificationFn>,
pub ai_service: RoomAiService,
worker_semaphore: Arc<tokio::sync::Semaphore>,
dedup_cache: DedupCache,
}
@ -77,6 +81,14 @@ impl RoomService {
) -> Self {
let dedup_cache: DedupCache =
Arc::new(dashmap::DashMap::with_capacity_and_hasher(10000, Default::default()));
let ai_service = RoomAiService::new(
db.clone(),
cache.clone(),
config.clone(),
queue.clone(),
room_manager.clone(),
chat_service.clone(),
);
Self {
db,
cache,
@ -87,6 +99,7 @@ impl RoomService {
chat_service,
task_service,
embed_service,
ai_service,
worker_semaphore: Arc::new(tokio::sync::Semaphore::new(
max_concurrent_workers.unwrap_or(DEFAULT_MAX_CONCURRENT_WORKERS),
)),
@ -258,34 +271,7 @@ impl RoomService {
}
pub async fn should_ai_respond(&self, room_id: Uuid, content: &str) -> Result<bool, RoomError> {
let ai_configs = history::get_room_ai_configs(&self.db, room_id).await?;
if ai_configs.is_empty() {
return Ok(false);
}
// Collect all model IDs in this room
let model_ids: std::collections::HashSet<String> = ai_configs
.iter()
.map(|c| c.model.to_string())
.collect();
for cap in mention_bracket_re().captures_iter(content) {
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
if type_m.as_str() == "ai" && model_ids.contains(id_m.as_str().trim()) {
return Ok(true);
}
}
}
for cap in mention_tag_re().captures_iter(content) {
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
if type_m.as_str() == "ai" && model_ids.contains(id_m.as_str().trim()) {
return Ok(true);
}
}
}
Ok(false)
self.ai_service.should_respond(room_id, content).await
}
pub async fn get_room_ai_config(
@ -421,66 +407,72 @@ impl RoomService {
};
let use_streaming = ai_config.stream;
let is_react = ai_config.agent_type.as_deref() == Some("react");
if is_react {
if use_streaming {
ai_react_streaming::process_message_ai_react_streaming(
chat_service.clone(),
request,
room_id,
room.project,
model_id,
lock_guard,
self.db.clone(),
self.cache.clone(),
self.queue.clone(),
self.room_manager.clone(),
)
.await;
} else {
ai_react_nonstreaming::process_message_ai_react_nonstreaming(
chat_service.clone(),
request,
room_id,
room.project,
model_id,
lock_guard,
self.db.clone(),
self.cache.clone(),
self.queue.clone(),
self.room_manager.clone(),
)
.await;
match ai_config.agent_type.as_deref() {
Some("cot") => {
if use_streaming {
ai_mode_dispatch::dispatch_cot(
chat_service.clone(), request, room_id, room.project, model_id,
lock_guard, self.db.clone(), self.cache.clone(),
self.queue.clone(), self.room_manager.clone(),
).await;
} else {
if let Ok((content, _in, _out)) = chat_service.process_cot(&request, |_step| async {}).await {
let _ = create_and_publish_ai_message(
&self.db, &self.cache, &self.queue, &self.room_manager,
room_id, room.project, uuid::Uuid::new_v4(), content, model_id,
Some(request.model.name.clone()),
).await;
}
}
}
Some("rewoo") => {
if use_streaming {
ai_mode_dispatch::dispatch_rewoo(
chat_service.clone(), request, room_id, room.project, model_id,
lock_guard, self.db.clone(), self.cache.clone(),
self.queue.clone(), self.room_manager.clone(),
).await;
}
}
Some("reflexion") => {
if use_streaming {
ai_mode_dispatch::dispatch_reflexion(
chat_service.clone(), request, room_id, room.project, model_id,
lock_guard, self.db.clone(), self.cache.clone(),
self.queue.clone(), self.room_manager.clone(),
).await;
}
}
Some("react") | _ => {
if ai_config.agent_type.as_deref() == Some("react") {
if use_streaming {
ai_react_streaming::process_message_ai_react_streaming(
chat_service.clone(), request, room_id, room.project, model_id,
lock_guard, self.db.clone(), self.cache.clone(),
self.queue.clone(), self.room_manager.clone(),
).await;
} else {
ai_react_nonstreaming::process_message_ai_react_nonstreaming(
chat_service.clone(), request, room_id, room.project, model_id,
lock_guard, self.db.clone(), self.cache.clone(),
self.queue.clone(), self.room_manager.clone(),
).await;
}
} else if use_streaming {
ai_streaming::process_message_ai_streaming(
chat_service.clone(), request, room_id, room.project, model_id,
lock_guard, self.db.clone(), self.cache.clone(),
self.queue.clone(), self.room_manager.clone(),
).await;
} else {
ai_nonstreaming::process_message_ai_nonstreaming(
chat_service.clone(), request, room_id, room.project, model_id,
lock_guard, self.db.clone(), self.cache.clone(),
self.queue.clone(), self.room_manager.clone(),
).await;
}
}
} else if use_streaming {
ai_streaming::process_message_ai_streaming(
chat_service.clone(),
request,
room_id,
room.project,
model_id,
lock_guard,
self.db.clone(),
self.cache.clone(),
self.queue.clone(),
self.room_manager.clone(),
)
.await;
} else {
ai_nonstreaming::process_message_ai_nonstreaming(
chat_service.clone(),
request,
room_id,
room.project,
model_id,
lock_guard,
self.db.clone(),
self.cache.clone(),
self.queue.clone(),
self.room_manager.clone(),
)
.await;
}
Ok(())

View File

@ -6,6 +6,21 @@ use uuid::Uuid;
use crate::error::RoomError;
/// Redis Lua script that atomically INCRs the sequence number and
/// reconciles with the database max seq every 1000 increments.
/// Returns the final assigned seq (guaranteed > any existing message seq).
const ATOMIC_INCR_SCRIPT: &str = r#"
local seq = redis.call('INCR', KEYS[1])
if seq % 1000 == 0 then
local db_seq = tonumber(ARGV[1]) or 0
if db_seq >= seq then
redis.call('SET', KEYS[1], db_seq + 1)
return db_seq + 1
end
end
return seq
"#;
pub async fn next_room_message_seq_internal(
room_id: Uuid,
db: &AppDatabase,
@ -16,34 +31,24 @@ pub async fn next_room_message_seq_internal(
RoomError::Internal(format!("failed to get redis connection for seq: {}", e))
})?;
let seq: i64 = redis::cmd("INCR")
.arg(&seq_key)
.query_async(&mut conn)
let db_seq: i64 = RoomMessage::find()
.filter(RmCol::Room.eq(room_id))
.select_only()
.column_as(RmCol::Seq.max(), "max_seq")
.into_tuple::<Option<Option<i64>>>()
.one(db)
.await?
.flatten()
.flatten()
.unwrap_or(0);
let script = redis::Script::new(ATOMIC_INCR_SCRIPT);
let seq: i64 = script
.key(&seq_key)
.arg(db_seq)
.invoke_async(&mut conn)
.await
.map_err(|e| RoomError::Internal(format!("seq INCR: {}", e)))?;
// DB reconciliation: only check every 1000 messages
if seq % 1000 == 0 {
let db_seq: Option<Option<Option<i64>>> = RoomMessage::find()
.filter(RmCol::Room.eq(room_id))
.select_only()
.column_as(RmCol::Seq.max(), "max_seq")
.into_tuple::<Option<Option<i64>>>()
.one(db)
.await?
.map(|r| r);
let db_seq = db_seq.flatten().flatten().unwrap_or(0);
if db_seq >= seq {
let _: String = redis::cmd("SET")
.arg(&seq_key)
.arg(db_seq + 1)
.query_async(&mut conn)
.await
.map_err(|e| RoomError::Internal(format!("seq SET: {}", e)))?;
return Ok(db_seq + 1);
}
}
.map_err(|e| RoomError::Internal(format!("seq atomic INCR: {}", e)))?;
Ok(seq)
}

View File

@ -5,7 +5,7 @@ use db::cache::AppCache;
use db::database::AppDatabase;
use models::rooms::room;
use queue::{AgentTaskEvent, MessageProducer};
use sea_orm::{EntityTrait, QuerySelect};
use sea_orm::EntityTrait;
use uuid::Uuid;
use crate::connection::{
@ -28,11 +28,10 @@ pub async fn start_workers(
mut shutdown_rx: tokio::sync::broadcast::Receiver<()>,
embed_service: Option<Arc<agent::embed::EmbedService>>,
) -> anyhow::Result<()> {
// Load rooms with a reasonable cap to prevent resource exhaustion on large instances.
// Rooms beyond this limit will be activated on-demand when first accessed.
const MAX_INITIAL_ROOMS: u64 = 1000;
// Load all rooms. For large deployments with thousands of rooms,
// consider implementing distributed worker sharding (consistent hashing)
// to avoid all rooms being handled by a single instance.
let rooms: Vec<room::Model> = room::Entity::find()
.limit(MAX_INITIAL_ROOMS)
.all(&db)
.await?;
let room_ids: Vec<uuid::Uuid> = rooms.iter().map(|r| r.id).collect();
@ -62,6 +61,7 @@ pub async fn start_workers(
extract_get_redis(queue.clone());
let worker_room_ids = room_ids.clone();
let stream_chunk_room_ids = room_ids.clone();
let worker_shutdown = shutdown_rx.resubscribe();
let worker_handle = tokio::spawn({
let get_redis = get_redis.clone();
@ -92,6 +92,25 @@ pub async fn start_workers(
})
.collect();
let stream_chunk_handles: Vec<_> = stream_chunk_room_ids
.into_iter()
.map(|room_id| {
let manager = manager.clone();
let redis_url = redis_url_clone.clone();
let shutdown_rx = shutdown_rx.resubscribe();
tokio::spawn(async move {
crate::connection::subscribe_room_stream_chunk_events(
redis_url,
manager,
room_id,
shutdown_rx,
)
.await;
})
})
.collect();
handles.extend(stream_chunk_handles);
let project_handles: Vec<_> = project_ids
.into_iter()
.map(|project_id| {
@ -289,15 +308,17 @@ pub fn spawn_room_workers(
Default::default(),
),
),
embed_service,
embed_service.clone(),
);
let get_redis: Arc<dyn Fn() -> queue::worker::RedisFuture + Send + Sync> =
extract_get_redis(queue.clone());
let manager1 = room_manager.clone();
let manager2 = room_manager.clone();
let manager3 = room_manager.clone();
let manager4 = room_manager.clone();
let redis_url_clone = redis_url.clone();
let redis_url3 = redis_url.clone();
let redis_url4 = redis_url.clone();
let semaphore = worker_semaphore.clone();
tokio::spawn(async move {
@ -350,4 +371,15 @@ pub fn spawn_room_workers(
)
.await;
});
tokio::spawn(async move {
let shutdown_rx = manager4.register_room(room_id).await;
crate::connection::subscribe_room_stream_chunk_events(
redis_url4,
manager4,
room_id,
shutdown_rx,
)
.await;
});
}