feat(room): add mode-dispatched AI service orchestration
Add RoomAiService as the central dispatcher that selects the execution path based on the agent mode (react/chat/cot/reflexion/rewoo) and the streaming vs. non-streaming preference. Replace the monolithic ai_streaming with mode-aware dispatch and a dedicated streaming implementation.
This commit is contained in:
parent
27b9d3e4bd
commit
4ba47370be
217
libs/agent/chat/state.rs
Normal file
217
libs/agent/chat/state.rs
Normal file
@ -0,0 +1,217 @@
|
||||
//! Agent state machine — tracks lifecycle of a single AI agent invocation.
|
||||
//!
|
||||
//! States: Idle → Thinking → ToolCall → Thinking → ... → Answering | Error
|
||||
//! The Thinking ↔ ToolCall cycle repeats until max tool depth or final answer.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Current phase of an agent's execution lifecycle.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum AgentState {
|
||||
/// Agent is idle, waiting for input
|
||||
Idle,
|
||||
/// Agent is reasoning/thinking (may produce thinking chunks)
|
||||
Thinking {
|
||||
started_at: DateTime<Utc>,
|
||||
tool_depth: u32,
|
||||
},
|
||||
/// Agent is executing a tool call
|
||||
ToolCall {
|
||||
tool_name: String,
|
||||
started_at: DateTime<Utc>,
|
||||
},
|
||||
/// Agent is returning the final answer
|
||||
Answering {
|
||||
/// Accumulated answer content so far
|
||||
content_chars: u64,
|
||||
started_at: DateTime<Utc>,
|
||||
},
|
||||
/// Agent encountered a non-recoverable error
|
||||
Error {
|
||||
message: String,
|
||||
tool_depth: u32,
|
||||
},
|
||||
}
|
||||
|
||||
impl AgentState {
|
||||
pub fn is_terminal(&self) -> bool {
|
||||
matches!(self, AgentState::Answering { .. } | AgentState::Error { .. })
|
||||
}
|
||||
|
||||
pub fn is_idle(&self) -> bool {
|
||||
matches!(self, AgentState::Idle)
|
||||
}
|
||||
|
||||
pub fn current_phase(&self) -> &'static str {
|
||||
match self {
|
||||
AgentState::Idle => "idle",
|
||||
AgentState::Thinking { .. } => "thinking",
|
||||
AgentState::ToolCall { .. } => "tool_call",
|
||||
AgentState::Answering { .. } => "answering",
|
||||
AgentState::Error { .. } => "error",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// State machine for agent lifecycle transitions.
|
||||
pub struct AgentRuntime {
|
||||
state: AgentState,
|
||||
max_tool_depth: u32,
|
||||
current_depth: u32,
|
||||
}
|
||||
|
||||
impl AgentRuntime {
|
||||
pub fn new(max_tool_depth: u32) -> Self {
|
||||
Self {
|
||||
state: AgentState::Idle,
|
||||
max_tool_depth,
|
||||
current_depth: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn state(&self) -> &AgentState {
|
||||
&self.state
|
||||
}
|
||||
|
||||
/// Transition from Idle → Thinking
|
||||
pub fn start_thinking(&mut self) {
|
||||
debug_assert!(self.state.is_idle(), "must be Idle to start thinking");
|
||||
self.current_depth = 0;
|
||||
self.state = AgentState::Thinking {
|
||||
started_at: Utc::now(),
|
||||
tool_depth: 0,
|
||||
};
|
||||
}
|
||||
|
||||
/// Transition from Thinking → ToolCall (increments tool depth)
|
||||
pub fn start_tool_call(&mut self, tool_name: String) -> Result<(), &'static str> {
|
||||
if !matches!(self.state, AgentState::Thinking { .. }) {
|
||||
return Err("must be Thinking to start tool call");
|
||||
}
|
||||
if self.current_depth >= self.max_tool_depth {
|
||||
return Err("max tool depth reached");
|
||||
}
|
||||
self.state = AgentState::ToolCall {
|
||||
tool_name,
|
||||
started_at: Utc::now(),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Transition from ToolCall → Thinking (back after tool result)
|
||||
pub fn complete_tool_call(&mut self) -> Result<(), &'static str> {
|
||||
if !matches!(self.state, AgentState::ToolCall { .. }) {
|
||||
return Err("must be ToolCall to complete");
|
||||
}
|
||||
self.current_depth += 1;
|
||||
self.state = AgentState::Thinking {
|
||||
started_at: Utc::now(),
|
||||
tool_depth: self.current_depth,
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Transition to Answering (terminal)
|
||||
pub fn start_answer(&mut self) {
|
||||
self.state = AgentState::Answering {
|
||||
content_chars: 0,
|
||||
started_at: Utc::now(),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn append_answer(&mut self, content: &str) {
|
||||
if let AgentState::Answering { content_chars, .. } = &mut self.state {
|
||||
*content_chars += content.len() as u64;
|
||||
}
|
||||
}
|
||||
|
||||
/// Transition to Error (terminal)
|
||||
pub fn fail(&mut self, message: String) {
|
||||
self.state = AgentState::Error {
|
||||
message,
|
||||
tool_depth: self.current_depth,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a runtime that has already moved from `Idle` to `Thinking`.
    fn thinking_runtime(max_tool_depth: u32) -> AgentRuntime {
        let mut runtime = AgentRuntime::new(max_tool_depth);
        runtime.start_thinking();
        runtime
    }

    #[test]
    fn test_starts_idle() {
        let runtime = AgentRuntime::new(10);
        let state = runtime.state();
        assert!(state.is_idle());
        assert_eq!(state.current_phase(), "idle");
    }

    #[test]
    fn test_idle_to_thinking() {
        let runtime = thinking_runtime(10);
        assert_eq!(runtime.state().current_phase(), "thinking");
        assert!(!runtime.state().is_terminal());
    }

    #[test]
    fn test_thinking_to_tool_call_and_back() {
        let mut runtime = thinking_runtime(10);

        runtime
            .start_tool_call(String::from("search"))
            .expect("tool call should start from Thinking");
        assert_eq!(runtime.state().current_phase(), "tool_call");

        runtime
            .complete_tool_call()
            .expect("tool call should complete from ToolCall");
        assert_eq!(runtime.state().current_phase(), "thinking");
    }

    #[test]
    fn test_thinking_to_answer() {
        let mut runtime = thinking_runtime(10);
        runtime.start_answer();
        assert_eq!(runtime.state().current_phase(), "answering");
        assert!(runtime.state().is_terminal());
    }

    #[test]
    fn test_append_answer_tracks_chars() {
        let mut runtime = thinking_runtime(10);
        runtime.start_answer();
        runtime.append_answer("hello");

        match runtime.state() {
            AgentState::Answering { content_chars, .. } => assert_eq!(*content_chars, 5),
            _ => panic!("expected Answering state"),
        }
    }

    #[test]
    fn test_error_is_terminal() {
        let mut runtime = thinking_runtime(10);
        runtime.fail(String::from("something broke"));
        assert_eq!(runtime.state().current_phase(), "error");
        assert!(runtime.state().is_terminal());
    }

    #[test]
    fn test_transition_from_wrong_state() {
        let mut runtime = AgentRuntime::new(10);

        // Neither tool-call transition is valid while the runtime is Idle.
        let started = runtime.start_tool_call(String::from("tool"));
        assert!(started.is_err());
        let completed = runtime.complete_tool_call();
        assert!(completed.is_err());
    }

    #[test]
    fn test_max_depth_rejected() {
        let mut runtime = thinking_runtime(2);

        // Two complete tool-call round trips use up the whole depth budget.
        for tool in ["tool1", "tool2"] {
            runtime.start_tool_call(tool.into()).unwrap();
            runtime.complete_tool_call().unwrap();
        }

        assert!(runtime.start_tool_call("tool3".into()).is_err());
    }
}
|
||||
183
libs/room/src/service/ai_mode_dispatch.rs
Normal file
183
libs/room/src/service/ai_mode_dispatch.rs
Normal file
@ -0,0 +1,183 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use db::cache::AppCache;
|
||||
use db::database::AppDatabase;
|
||||
use queue::MessageProducer;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::ai_mode_streaming::run_mode_streaming;
|
||||
use crate::connection::RoomConnectionManager;
|
||||
use agent::chat::{AiRequest, ChatService};
|
||||
|
||||
// ── CoT ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn dispatch_cot(
|
||||
chat_service: Arc<ChatService>,
|
||||
request: AiRequest,
|
||||
room_id: Uuid,
|
||||
project_id: Uuid,
|
||||
model_id: Uuid,
|
||||
lock_guard: crate::room_ai_queue::RoomAiLockGuard,
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
) {
|
||||
let req = request;
|
||||
run_mode_streaming(
|
||||
chat_service,
|
||||
req,
|
||||
room_id,
|
||||
project_id,
|
||||
model_id,
|
||||
lock_guard,
|
||||
db,
|
||||
cache,
|
||||
queue,
|
||||
room_manager,
|
||||
"cot",
|
||||
Box::new(|chat_svc, ai_req, on_chunk| {
|
||||
Box::pin(async move {
|
||||
let result = chat_svc
|
||||
.process_cot(&ai_req, |step| {
|
||||
let on_chunk = on_chunk.clone();
|
||||
async move {
|
||||
let (ct, content, is_final) = match step {
|
||||
agent::modes::cot::CotStep::Thought(t) => {
|
||||
("thinking".to_string(), t, false)
|
||||
}
|
||||
agent::modes::cot::CotStep::Action { name, args } => {
|
||||
("tool_call".to_string(),
|
||||
serde_json::json!({"name": name, "arguments": args}).to_string(),
|
||||
false)
|
||||
}
|
||||
agent::modes::cot::CotStep::Observation(o) => {
|
||||
("tool_result".to_string(), o, false)
|
||||
}
|
||||
agent::modes::cot::CotStep::Answer(a) => {
|
||||
("answer".to_string(), a, true)
|
||||
}
|
||||
};
|
||||
on_chunk(ct, content, is_final).await;
|
||||
}
|
||||
})
|
||||
.await;
|
||||
result
|
||||
})
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// ── ReWOO ────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn dispatch_rewoo(
|
||||
chat_service: Arc<ChatService>,
|
||||
request: AiRequest,
|
||||
room_id: Uuid,
|
||||
project_id: Uuid,
|
||||
model_id: Uuid,
|
||||
lock_guard: crate::room_ai_queue::RoomAiLockGuard,
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
) {
|
||||
let req = request;
|
||||
run_mode_streaming(
|
||||
chat_service.clone(),
|
||||
req,
|
||||
room_id,
|
||||
project_id,
|
||||
model_id,
|
||||
lock_guard,
|
||||
db,
|
||||
cache,
|
||||
queue,
|
||||
room_manager,
|
||||
"rewoo",
|
||||
Box::new(|chat_svc, ai_req, on_chunk| {
|
||||
Box::pin(async move {
|
||||
let result = chat_svc
|
||||
.process_rewoo(&ai_req, |step| {
|
||||
let on_chunk = on_chunk.clone();
|
||||
async move {
|
||||
let (ct, content, is_final) = match step {
|
||||
agent::modes::rewoo::ReWooStep::Plan { raw, .. } => {
|
||||
("tool_call".to_string(), raw, false)
|
||||
}
|
||||
agent::modes::rewoo::ReWooStep::Execution { tool_name, result } => {
|
||||
("tool_result".to_string(), format!("{}: {}", tool_name, result), false)
|
||||
}
|
||||
agent::modes::rewoo::ReWooStep::Synthesis(s) => {
|
||||
("answer".to_string(), s, true)
|
||||
}
|
||||
};
|
||||
on_chunk(ct, content, is_final).await;
|
||||
}
|
||||
})
|
||||
.await;
|
||||
result
|
||||
})
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// ── Reflexion ────────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn dispatch_reflexion(
|
||||
chat_service: Arc<ChatService>,
|
||||
request: AiRequest,
|
||||
room_id: Uuid,
|
||||
project_id: Uuid,
|
||||
model_id: Uuid,
|
||||
lock_guard: crate::room_ai_queue::RoomAiLockGuard,
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
) {
|
||||
let req = request;
|
||||
run_mode_streaming(
|
||||
chat_service.clone(),
|
||||
req,
|
||||
room_id,
|
||||
project_id,
|
||||
model_id,
|
||||
lock_guard,
|
||||
db,
|
||||
cache,
|
||||
queue,
|
||||
room_manager,
|
||||
"reflexion",
|
||||
Box::new(|chat_svc, ai_req, on_chunk| {
|
||||
Box::pin(async move {
|
||||
let result = chat_svc
|
||||
.process_reflexion(&ai_req, |step| {
|
||||
let on_chunk = on_chunk.clone();
|
||||
async move {
|
||||
let (ct, content, is_final) = match step {
|
||||
agent::modes::reflexion::ReflexionStep::Generate(s) => {
|
||||
("thinking".to_string(), s, false)
|
||||
}
|
||||
agent::modes::reflexion::ReflexionStep::Critique(s) => {
|
||||
("tool_call".to_string(), s, false)
|
||||
}
|
||||
agent::modes::reflexion::ReflexionStep::Revise(s) => {
|
||||
("thinking".to_string(), s, false)
|
||||
}
|
||||
agent::modes::reflexion::ReflexionStep::Final(s) => {
|
||||
("answer".to_string(), s, true)
|
||||
}
|
||||
};
|
||||
on_chunk(ct, content, is_final).await;
|
||||
}
|
||||
}, 3)
|
||||
.await;
|
||||
result
|
||||
})
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
324
libs/room/src/service/ai_mode_streaming.rs
Normal file
324
libs/room/src/service/ai_mode_streaming.rs
Normal file
@ -0,0 +1,324 @@
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::Utc;
|
||||
use db::cache::AppCache;
|
||||
use db::database::AppDatabase;
|
||||
use models::rooms::room_ai;
|
||||
use queue::{MessageProducer, ProjectRoomEvent, RoomMessageEnvelope};
|
||||
use sea_orm::{sea_query::Expr, ColumnTrait, EntityTrait, ExprTrait, QueryFilter};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::sequence::next_room_message_seq_internal;
|
||||
use crate::connection::RoomConnectionManager;
|
||||
use agent::chat::{AiRequest, ChatService};
|
||||
|
||||
pub type RunModeFn = Box<
|
||||
dyn FnOnce(
|
||||
Arc<ChatService>,
|
||||
AiRequest,
|
||||
Arc<dyn Fn(String, String, bool) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>,
|
||||
) -> Pin<Box<dyn std::future::Future<Output = Result<(String, i64, i64), agent::AgentError>> + Send>>
|
||||
+ Send,
|
||||
>;
|
||||
|
||||
pub async fn run_mode_streaming(
|
||||
chat_service: Arc<ChatService>,
|
||||
request: AiRequest,
|
||||
room_id: Uuid,
|
||||
project_id: Uuid,
|
||||
model_id: Uuid,
|
||||
lock_guard: crate::room_ai_queue::RoomAiLockGuard,
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
mode_name_str: &str,
|
||||
run: RunModeFn,
|
||||
) {
|
||||
let mode_name = mode_name_str.to_string();
|
||||
use queue::RoomMessageStreamChunkEvent;
|
||||
|
||||
let streaming_msg_id = Uuid::now_v7();
|
||||
let seq = match next_room_message_seq_internal(room_id, &db, &cache).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to get seq for {} streaming", mode_name);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let _ = room_manager
|
||||
.register_stream_channel(streaming_msg_id)
|
||||
.await;
|
||||
|
||||
let initial_event = RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id,
|
||||
seq: 0,
|
||||
content: String::new(),
|
||||
done: false,
|
||||
error: None,
|
||||
display_name: Some(request.model.name.clone()),
|
||||
chunk_type: Some("thinking".to_string()),
|
||||
};
|
||||
room_manager.broadcast_stream_chunk(initial_event).await;
|
||||
|
||||
let room_id_inner = room_id;
|
||||
let project_id_inner = project_id;
|
||||
let now = Utc::now();
|
||||
let sender_type = "ai".to_string();
|
||||
let ai_display_name = request.model.name.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _lock_guard = lock_guard;
|
||||
let cancel = room_manager.register_stream_cancel(room_id_inner).await;
|
||||
|
||||
let ai_typing_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001")
|
||||
.expect("constant UUID should always parse");
|
||||
|
||||
let typing_start = queue::TypingEvent {
|
||||
room_id: room_id_inner,
|
||||
user_id: ai_typing_id,
|
||||
username: ai_display_name.clone(),
|
||||
avatar_url: None,
|
||||
action: "start".to_string(),
|
||||
sender_type: Some("ai".to_string()),
|
||||
};
|
||||
room_manager.broadcast_typing(room_id_inner, typing_start.clone()).await;
|
||||
|
||||
let (typing_cancel_tx, typing_cancel_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
let typing_renew_handle = tokio::spawn({
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
let mgr = room_manager.clone();
|
||||
let rid = room_id_inner;
|
||||
let evt = typing_start.clone();
|
||||
async move {
|
||||
tokio::select! {
|
||||
_ = typing_cancel_rx => {}
|
||||
_ = async {
|
||||
loop {
|
||||
interval.tick().await;
|
||||
mgr.broadcast_typing(rid, evt.clone()).await;
|
||||
}
|
||||
} => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let chunk_seq: Arc<std::sync::atomic::AtomicU64> =
|
||||
Arc::new(std::sync::atomic::AtomicU64::new(1));
|
||||
let all_chunks: Arc<std::sync::Mutex<Vec<(String, String)>>> =
|
||||
Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
let answer_buffer: Arc<std::sync::Mutex<String>> =
|
||||
Arc::new(std::sync::Mutex::new(String::new()));
|
||||
|
||||
fn lock_or_recover<T>(mutex: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> {
|
||||
mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
let on_chunk = {
|
||||
let room_manager = room_manager.clone();
|
||||
let queue = queue.clone();
|
||||
let cancel = cancel.clone();
|
||||
let streaming_msg_id = streaming_msg_id;
|
||||
let room_id = room_id_inner;
|
||||
let chunk_seq = chunk_seq.clone();
|
||||
let ai_display_name = ai_display_name.clone();
|
||||
let all_chunks = all_chunks.clone();
|
||||
let answer_buffer = answer_buffer.clone();
|
||||
|
||||
Arc::new(move |chunk_type: String, content: String, is_answer: bool| {
|
||||
let room_manager = room_manager.clone();
|
||||
let queue = queue.clone();
|
||||
let cancel = cancel.clone();
|
||||
let chunk_seq = chunk_seq.clone();
|
||||
let ai_display_name = ai_display_name.clone();
|
||||
let all_chunks = all_chunks.clone();
|
||||
let answer_buffer = answer_buffer.clone();
|
||||
|
||||
{
|
||||
let mut chunks = lock_or_recover(&all_chunks);
|
||||
chunks.push((chunk_type.clone(), content.clone()));
|
||||
}
|
||||
if is_answer {
|
||||
let mut ab = lock_or_recover(&answer_buffer);
|
||||
ab.push_str(&content);
|
||||
}
|
||||
|
||||
let current_seq = chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
let event = RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id,
|
||||
seq: current_seq,
|
||||
content: content.clone(),
|
||||
done: false,
|
||||
error: None,
|
||||
display_name: Some(ai_display_name),
|
||||
chunk_type: Some(chunk_type),
|
||||
};
|
||||
|
||||
Box::pin(async move {
|
||||
if cancel.load(std::sync::atomic::Ordering::Acquire) {
|
||||
return;
|
||||
}
|
||||
queue.publish_stream_chunk(&event).await;
|
||||
room_manager.broadcast_stream_chunk(event).await;
|
||||
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
|
||||
})
|
||||
};
|
||||
|
||||
let result = run(chat_service, request, on_chunk).await;
|
||||
|
||||
let final_stream_content = lock_or_recover(&answer_buffer).clone();
|
||||
let final_event = RoomMessageStreamChunkEvent {
|
||||
message_id: streaming_msg_id,
|
||||
room_id: room_id_inner,
|
||||
seq: chunk_seq.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
||||
content: final_stream_content.clone(),
|
||||
done: true,
|
||||
error: None,
|
||||
display_name: Some(ai_display_name.clone()),
|
||||
chunk_type: Some("answer".to_string()),
|
||||
};
|
||||
queue.publish_stream_chunk(&final_event).await;
|
||||
room_manager.broadcast_stream_chunk(final_event).await;
|
||||
|
||||
let (final_content, err_msg) = match result {
|
||||
Ok((content, _input, _output)) => (content, None),
|
||||
Err(e) => {
|
||||
let msg = "[AI 处理发生错误,请稍后再试]".to_string();
|
||||
tracing::error!(error = %e, "{} streaming failed", mode_name);
|
||||
(String::new(), Some(msg))
|
||||
}
|
||||
};
|
||||
|
||||
let all_chunks_data = lock_or_recover(&all_chunks).clone();
|
||||
let reasoning_chain: String = all_chunks_data
|
||||
.iter()
|
||||
.filter(|(t, _)| t != "answer")
|
||||
.map(|(_, c)| c.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
let content_to_persist = if !final_content.is_empty() {
|
||||
final_content.clone()
|
||||
} else if !reasoning_chain.trim().is_empty() {
|
||||
format!(
|
||||
"[Agent ran through reasoning steps but did not produce a final answer.]\n{}",
|
||||
reasoning_chain.trim_end()
|
||||
)
|
||||
} else {
|
||||
String::from("[No output from reasoning agent]")
|
||||
};
|
||||
let content_to_persist = if let Some(msg) = &err_msg {
|
||||
format!("{}\n[Error: {}]", content_to_persist.trim_end(), msg)
|
||||
} else {
|
||||
content_to_persist
|
||||
};
|
||||
|
||||
let persist_content = content_to_persist.trim().to_string();
|
||||
if persist_content.is_empty() {
|
||||
let _ = typing_cancel_tx.send(());
|
||||
typing_renew_handle.abort();
|
||||
room_manager.unregister_stream_cancel(room_id_inner).await;
|
||||
room_manager.close_stream_channel(streaming_msg_id).await;
|
||||
return;
|
||||
}
|
||||
|
||||
let thinking_content_serialized = {
|
||||
let chunks = lock_or_recover(&all_chunks);
|
||||
if chunks.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let chunks_json = serde_json::json!({
|
||||
"__chunks__": chunks.iter().map(|(t, c)| serde_json::json!({
|
||||
"type": t,
|
||||
"content": c,
|
||||
})).collect::<Vec<_>>(),
|
||||
});
|
||||
Some(chunks_json.to_string())
|
||||
}
|
||||
};
|
||||
let thinking_content_for_event = thinking_content_serialized.clone();
|
||||
|
||||
let envelope = RoomMessageEnvelope {
|
||||
id: streaming_msg_id,
|
||||
dedup_key: Some(format!("{}:{}", room_id_inner, streaming_msg_id)),
|
||||
room_id: room_id_inner,
|
||||
sender_type: sender_type.clone(),
|
||||
sender_id: None,
|
||||
model_id: Some(model_id),
|
||||
thread_id: None,
|
||||
content: persist_content.clone(),
|
||||
content_type: "text".to_string(),
|
||||
thinking_content: thinking_content_serialized,
|
||||
send_at: now,
|
||||
seq,
|
||||
in_reply_to: None,
|
||||
display_name: Some(ai_display_name.clone()),
|
||||
};
|
||||
|
||||
if let Err(e) = queue.publish(room_id_inner, envelope).await {
|
||||
tracing::error!(error = %e, "Failed to publish {} streaming message", mode_name);
|
||||
} else {
|
||||
let now = Utc::now();
|
||||
if let Err(e) = room_ai::Entity::update_many()
|
||||
.col_expr(room_ai::Column::CallCount, Expr::col(room_ai::Column::CallCount).add(1))
|
||||
.col_expr(room_ai::Column::LastCallAt, Expr::value(Some(now)))
|
||||
.filter(room_ai::Column::Room.eq(room_id_inner))
|
||||
.filter(room_ai::Column::Model.eq(model_id))
|
||||
.exec(&db)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(error = %e, "Failed to update room_ai call stats");
|
||||
}
|
||||
|
||||
let msg_event = queue::RoomMessageEvent {
|
||||
id: streaming_msg_id,
|
||||
room_id: room_id_inner,
|
||||
sender_type: sender_type.clone(),
|
||||
sender_id: None,
|
||||
thread_id: None,
|
||||
content: persist_content,
|
||||
content_type: "text".to_string(),
|
||||
thinking_content: thinking_content_for_event,
|
||||
send_at: now,
|
||||
seq,
|
||||
display_name: Some(ai_display_name.clone()),
|
||||
in_reply_to: None,
|
||||
reactions: None,
|
||||
message_id: None,
|
||||
};
|
||||
room_manager.broadcast(room_id_inner, msg_event).await;
|
||||
room_manager.metrics.messages_sent.increment(1);
|
||||
|
||||
let event = ProjectRoomEvent {
|
||||
event_type: crate::RoomEventType::NewMessage.as_str().into(),
|
||||
project_id: project_id_inner,
|
||||
room_id: Some(room_id_inner),
|
||||
category_id: None,
|
||||
message_id: Some(streaming_msg_id),
|
||||
seq: Some(seq),
|
||||
timestamp: now,
|
||||
};
|
||||
queue.publish_project_room_event(project_id_inner, event).await;
|
||||
}
|
||||
|
||||
let _ = typing_cancel_tx.send(());
|
||||
typing_renew_handle.abort();
|
||||
let typing_stop = queue::TypingEvent {
|
||||
room_id: room_id_inner,
|
||||
user_id: ai_typing_id,
|
||||
username: ai_display_name.clone(),
|
||||
avatar_url: None,
|
||||
action: "stop".to_string(),
|
||||
sender_type: Some("ai".to_string()),
|
||||
};
|
||||
room_manager.broadcast_typing(room_id_inner, typing_stop).await;
|
||||
|
||||
room_manager.unregister_stream_cancel(room_id_inner).await;
|
||||
room_manager.close_stream_channel(streaming_msg_id).await;
|
||||
});
|
||||
}
|
||||
292
libs/room/src/service/ai_service.rs
Normal file
292
libs/room/src/service/ai_service.rs
Normal file
@ -0,0 +1,292 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use db::cache::AppCache;
|
||||
use db::database::AppDatabase;
|
||||
use models::rooms::room_ai;
|
||||
use queue::MessageProducer;
|
||||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||||
|
||||
use crate::connection::RoomConnectionManager;
|
||||
use crate::error::RoomError;
|
||||
use crate::service::ai_common::create_and_publish_ai_message;
|
||||
use crate::service::ai_mode_dispatch;
|
||||
use crate::service::ai_nonstreaming;
|
||||
use crate::service::ai_react_nonstreaming;
|
||||
use crate::service::ai_react_streaming;
|
||||
use crate::service::ai_streaming;
|
||||
use crate::service::history;
|
||||
use crate::service::patterns::{mention_bracket_re, mention_tag_re};
|
||||
use agent::chat::{AiRequest, ChatService};
|
||||
|
||||
/// Service responsible for AI message generation orchestration.
|
||||
/// Decides which execution path to use (streaming/nonstreaming, ReAct/chat)
|
||||
/// and dispatches accordingly.
|
||||
#[derive(Clone)]
|
||||
pub struct RoomAiService {
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
config: config::AppConfig,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
chat_service: Option<Arc<ChatService>>,
|
||||
}
|
||||
|
||||
impl RoomAiService {
|
||||
pub fn new(
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
config: config::AppConfig,
|
||||
queue: MessageProducer,
|
||||
room_manager: Arc<RoomConnectionManager>,
|
||||
chat_service: Option<Arc<ChatService>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
db,
|
||||
cache,
|
||||
config,
|
||||
queue,
|
||||
room_manager,
|
||||
chat_service,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the message content mentions an AI model configured in this room.
|
||||
pub async fn should_respond(&self, room_id: Uuid, content: &str) -> Result<bool, RoomError> {
|
||||
let ai_configs = history::get_room_ai_configs(&self.db, room_id).await?;
|
||||
if ai_configs.is_empty() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let model_ids: std::collections::HashSet<String> = ai_configs
|
||||
.iter()
|
||||
.map(|c| c.model.to_string())
|
||||
.collect();
|
||||
|
||||
for cap in mention_bracket_re().captures_iter(content) {
|
||||
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
|
||||
if type_m.as_str() == "ai" && model_ids.contains(id_m.as_str().trim()) {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for cap in mention_tag_re().captures_iter(content) {
|
||||
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
|
||||
if type_m.as_str() == "ai" && model_ids.contains(id_m.as_str().trim()) {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Extract the mentioned AI model ID from message content.
|
||||
fn extract_mentioned_model_id(content: &str) -> Option<Uuid> {
|
||||
for cap in mention_bracket_re().captures_iter(content) {
|
||||
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
|
||||
if type_m.as_str() == "ai" {
|
||||
if let Ok(uuid) = Uuid::parse_str(id_m.as_str().trim()) {
|
||||
return Some(uuid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for cap in mention_tag_re().captures_iter(content) {
|
||||
if let (Some(type_m), Some(id_m)) = (cap.get(1), cap.get(2)) {
|
||||
if type_m.as_str() == "ai" {
|
||||
if let Ok(uuid) = Uuid::parse_str(id_m.as_str().trim()) {
|
||||
return Some(uuid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Main entry point: process an AI message request.
|
||||
/// Handles model lookup, history building, locking, and dispatch.
|
||||
pub async fn process(
|
||||
&self,
|
||||
room_id: Uuid,
|
||||
sender_id: Uuid,
|
||||
content: &str,
|
||||
) -> Result<(), RoomError> {
|
||||
let chat_service = match &self.chat_service {
|
||||
Some(cs) => cs.clone(),
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
let model_id = match Self::extract_mentioned_model_id(content) {
|
||||
Some(id) => id,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
let ai_config = match room_ai::Entity::find()
|
||||
.filter(room_ai::Column::Room.eq(room_id))
|
||||
.filter(room_ai::Column::Model.eq(model_id))
|
||||
.one(&self.db)
|
||||
.await?
|
||||
{
|
||||
Some(c) => c,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
// Idempotency check: skip if this content already triggered AI within 60s
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = DefaultHasher::new();
|
||||
content.hash(&mut hasher);
|
||||
let idemp_key = format!("ai:idempot:{}:{}", room_id, hasher.finish());
|
||||
{
|
||||
let mut conn = self.cache.conn().await.map_err(|e| {
|
||||
RoomError::Internal(format!("cache conn: {}", e))
|
||||
})?;
|
||||
let exists = redis::cmd("SET")
|
||||
.arg(&idemp_key)
|
||||
.arg("1")
|
||||
.arg("NX")
|
||||
.arg("EX")
|
||||
.arg(60u64)
|
||||
.query_async::<Option<String>>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| RoomError::Internal(format!("idemp SET: {}", e)))?
|
||||
.is_some();
|
||||
if !exists {
|
||||
tracing::debug!(room_id = %room_id, "AI idempotency hit, skipping");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let lock_guard =
|
||||
match crate::room_ai_queue::acquire_room_ai_lock(&self.cache, room_id).await? {
|
||||
Some(g) => g,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
let room = crate::service::find_room_or_404(&self.db, room_id).await?;
|
||||
|
||||
let project = models::projects::project::Entity::find_by_id(room.project)
|
||||
.one(&self.db)
|
||||
.await?
|
||||
.ok_or_else(|| RoomError::NotFound("Project not found".to_string()))?;
|
||||
|
||||
let model = models::agents::model::Entity::find_by_id(model_id)
|
||||
.one(&self.db)
|
||||
.await?
|
||||
.ok_or_else(|| RoomError::NotFound("AI model not found".to_string()))?;
|
||||
|
||||
let sender = models::users::User::find_by_id(sender_id)
|
||||
.one(&self.db)
|
||||
.await?
|
||||
.ok_or_else(|| RoomError::NotFound("Sender not found".to_string()))?;
|
||||
|
||||
let history_messages = history::get_room_history(&self.db, room_id, 50).await?;
|
||||
|
||||
let user_ids: Vec<Uuid> = history_messages
|
||||
.iter()
|
||||
.filter_map(|m| m.sender_id)
|
||||
.chain(std::iter::once(sender_id))
|
||||
.collect();
|
||||
|
||||
let user_names = history::get_user_names(&self.db, &user_ids).await;
|
||||
|
||||
let mentions =
|
||||
history::extract_mention_context(&self.db, room.project, content).await;
|
||||
|
||||
let request = AiRequest {
|
||||
db: self.db.clone(),
|
||||
cache: self.cache.clone(),
|
||||
config: self.config.clone(),
|
||||
model,
|
||||
project: project.clone(),
|
||||
sender,
|
||||
room: room.clone(),
|
||||
input: content.to_string(),
|
||||
mention: mentions,
|
||||
history: history_messages,
|
||||
user_names,
|
||||
temperature: ai_config.temperature.unwrap_or(0.7),
|
||||
max_tokens: ai_config.max_tokens.unwrap_or(4096) as i32,
|
||||
top_p: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
think: ai_config.think,
|
||||
tools: Some(chat_service.tools()),
|
||||
max_tool_depth: 1000,
|
||||
};
|
||||
|
||||
let use_streaming = ai_config.stream;
|
||||
|
||||
match ai_config.agent_type.as_deref() {
|
||||
Some("cot") => {
|
||||
if use_streaming {
|
||||
ai_mode_dispatch::dispatch_cot(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
} else {
|
||||
if let Ok((content, _in, _out)) = chat_service.process_cot(&request, |_step| async {}).await {
|
||||
let _ = create_and_publish_ai_message(
|
||||
&self.db, &self.cache, &self.queue, &self.room_manager,
|
||||
room_id, room.project, uuid::Uuid::new_v4(), content, model_id,
|
||||
Some(request.model.name.clone()),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some("rewoo") => {
|
||||
if use_streaming {
|
||||
ai_mode_dispatch::dispatch_rewoo(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
Some("reflexion") => {
|
||||
if use_streaming {
|
||||
ai_mode_dispatch::dispatch_reflexion(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
Some("react") | _ => {
|
||||
if ai_config.agent_type.as_deref() == Some("react") {
|
||||
if use_streaming {
|
||||
ai_react_streaming::process_message_ai_react_streaming(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
} else {
|
||||
ai_react_nonstreaming::process_message_ai_react_nonstreaming(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
} else if use_streaming {
|
||||
ai_streaming::process_message_ai_streaming(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
} else {
|
||||
ai_nonstreaming::process_message_ai_nonstreaming(
|
||||
chat_service, request, room_id, room.project, model_id,
|
||||
lock_guard, self.db.clone(), self.cache.clone(),
|
||||
self.queue.clone(), self.room_manager.clone(),
|
||||
).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user