gitdataai/libs/room/src/service/ai_mode_streaming.rs
ZhenYi fdca1fbf86
Some checks are pending
CI / Rust Lint & Check (push) Waiting to run
CI / Rust Tests (push) Waiting to run
CI / Frontend Lint & Type Check (push) Waiting to run
CI / Frontend Build (push) Blocked by required conditions
feat(ai): add comprehensive AI streaming and non-streaming processing services
2026-05-01 00:54:24 +08:00

325 lines
12 KiB
Rust

use std::pin::Pin;
use std::sync::Arc;
use chrono::Utc;
use db::cache::AppCache;
use db::database::AppDatabase;
use models::rooms::room_ai;
use queue::{MessageProducer, ProjectRoomEvent, RoomMessageEnvelope};
use sea_orm::{sea_query::Expr, ColumnTrait, EntityTrait, ExprTrait, QueryFilter};
use uuid::Uuid;
use super::sequence::next_room_message_seq_internal;
use crate::connection::RoomConnectionManager;
use agent::chat::{AiRequest, ChatService};
/// One-shot driver for an AI "mode" run, used by [`run_mode_streaming`].
///
/// The boxed closure receives, in order:
/// 1. the shared chat service,
/// 2. the request to execute,
/// 3. a chunk callback invoked as `(chunk_type, content, is_answer)` whose returned future
///    forwards the chunk to listeners.
///
/// It resolves to `Ok((content, a, b))` on success, where `content` is the final text and the
/// two `i64` values are passed through unused by `run_mode_streaming`, or to an
/// `agent::AgentError` on failure.
pub type RunModeFn = Box<
    dyn FnOnce(
            Arc<ChatService>,
            AiRequest,
            Arc<
                dyn Fn(String, String, bool) -> Pin<Box<dyn core::future::Future<Output = ()> + Send>>
                    + Sync
                    + Send,
            >,
        ) -> Pin<
            Box<
                dyn core::future::Future<Output = Result<(String, i64, i64), agent::AgentError>>
                    + Send,
            >,
        > + Send,
>;
/// User id under which AI typing indicators are broadcast
/// (`00000000-0000-0000-0000-000000000001`).
const AI_TYPING_USER_ID: Uuid = Uuid::from_u128(1);

/// How often the AI "typing start" indicator is re-broadcast while a run is in progress.
const TYPING_RENEW_PERIOD: std::time::Duration = std::time::Duration::from_secs(30);

/// Runs an AI mode with chunk streaming, then publishes the resulting room message.
///
/// Before returning, this allocates the room message sequence number. If that fails, it
/// logs and returns without doing anything else. It then registers a stream channel for a
/// fresh message id and broadcasts an initial empty `"thinking"` chunk.
///
/// The rest of the work runs on a detached Tokio task that owns `lock_guard`, so the room
/// AI lock is released only after the run, publishing, and cleanup have finished. The task:
///
/// 1. broadcasts an AI "typing start" indicator and re-sends it every 30 seconds;
/// 2. invokes `run` with a callback that records every chunk and accumulates chunks flagged
///    `is_answer`. Unless the room's registered cancel flag is set, it also publishes each
///    chunk to the queue and broadcasts it to connected clients;
/// 3. sends a final `done: true` chunk carrying the accumulated answer and, if the run
///    failed, the error message;
/// 4. publishes the message envelope. If publishing succeeds, it also bumps the `room_ai`
///    call statistics and broadcasts the new message plus a project-level `NewMessage`
///    event;
/// 5. stops the typing indicator and releases the stream registrations.
///
/// `mode_name_str` is only used in log messages.
pub async fn run_mode_streaming(
    chat_service: Arc<ChatService>,
    request: AiRequest,
    room_id: Uuid,
    project_id: Uuid,
    model_id: Uuid,
    lock_guard: crate::room_ai_queue::RoomAiLockGuard,
    db: AppDatabase,
    cache: AppCache,
    queue: MessageProducer,
    room_manager: Arc<RoomConnectionManager>,
    mode_name_str: &str,
    run: RunModeFn,
) {
    use queue::RoomMessageStreamChunkEvent;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Mutex;

    let mode_name = mode_name_str.to_string();
    let streaming_msg_id = Uuid::now_v7();
    let seq = match next_room_message_seq_internal(room_id, &db, &cache).await {
        Ok(s) => s,
        Err(e) => {
            tracing::error!(error = %e, "Failed to get seq for {} streaming", mode_name);
            return;
        }
    };

    // Best effort: streaming proceeds even if the channel registration result is not usable.
    let _ = room_manager.register_stream_channel(streaming_msg_id).await;
    room_manager
        .broadcast_stream_chunk(RoomMessageStreamChunkEvent {
            message_id: streaming_msg_id,
            room_id,
            seq: 0,
            content: String::new(),
            done: false,
            error: None,
            display_name: Some(request.model.name.clone()),
            chunk_type: Some("thinking".to_string()),
        })
        .await;

    // Captured when `seq` is allocated. Both the published envelope and the live broadcast
    // use this value, so clients and the published message agree on `send_at`.
    let send_at = Utc::now();
    let sender_type = "ai".to_string();
    let ai_display_name = request.model.name.clone();

    tokio::spawn(async move {
        // Keep the room AI lock until this task has fully finished.
        let _lock_guard = lock_guard;
        let cancel = room_manager.register_stream_cancel(room_id).await;

        let typing_start = ai_typing_event(room_id, ai_display_name.clone(), "start");
        room_manager
            .broadcast_typing(room_id, typing_start.clone())
            .await;

        let (typing_cancel_tx, typing_cancel_rx) = tokio::sync::oneshot::channel::<()>();
        let typing_renew_handle = tokio::spawn({
            let mgr = room_manager.clone();
            async move {
                // `interval_at` delays the first tick by one period. A plain `interval` would
                // tick immediately and duplicate the "start" event that was just broadcast.
                let mut interval = tokio::time::interval_at(
                    tokio::time::Instant::now() + TYPING_RENEW_PERIOD,
                    TYPING_RENEW_PERIOD,
                );
                interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
                // The receiver also resolves if the sender is dropped (e.g. the parent task
                // unwinds), so this loop cannot outlive the run.
                tokio::select! {
                    _ = typing_cancel_rx => {}
                    _ = async {
                        loop {
                            interval.tick().await;
                            mgr.broadcast_typing(room_id, typing_start.clone()).await;
                        }
                    } => {}
                }
            }
        });

        // Chunk seq 0 was used by the initial "thinking" event above.
        let chunk_seq = Arc::new(AtomicU64::new(1));
        let all_chunks: Arc<Mutex<Vec<(String, String)>>> = Arc::default();
        let answer_buffer: Arc<Mutex<String>> = Arc::default();

        let on_chunk = {
            let room_manager = room_manager.clone();
            let queue = queue.clone();
            let cancel = cancel.clone();
            let chunk_seq = chunk_seq.clone();
            let ai_display_name = ai_display_name.clone();
            let all_chunks = all_chunks.clone();
            let answer_buffer = answer_buffer.clone();
            Arc::new(move |chunk_type: String, content: String, is_answer: bool| {
                lock_ignoring_poison(&all_chunks).push((chunk_type.clone(), content.clone()));
                if is_answer {
                    lock_ignoring_poison(&answer_buffer).push_str(&content);
                }
                // The seq is taken synchronously, so numbering follows callback call order
                // rather than the order in which the returned futures are polled.
                let event = RoomMessageStreamChunkEvent {
                    message_id: streaming_msg_id,
                    room_id,
                    seq: chunk_seq.fetch_add(1, Ordering::Relaxed),
                    content,
                    done: false,
                    error: None,
                    display_name: Some(ai_display_name.clone()),
                    chunk_type: Some(chunk_type),
                };
                // Only the handles the async delivery needs are moved into the future.
                let room_manager = room_manager.clone();
                let queue = queue.clone();
                let cancel = cancel.clone();
                Box::pin(async move {
                    if cancel.load(Ordering::Acquire) {
                        return;
                    }
                    queue.publish_stream_chunk(&event).await;
                    room_manager.broadcast_stream_chunk(event).await;
                }) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
            })
        };

        let result = run(chat_service, request, on_chunk).await;
        // Inspect the outcome first, so the final stream chunk can report a failure.
        let (final_content, err_msg) = match result {
            Ok((content, _input, _output)) => (content, None),
            Err(e) => {
                tracing::error!(error = ?e, "{} streaming failed", mode_name);
                (String::new(), Some(format!("AI 处理失败: {}", e)))
            }
        };

        let streamed_answer = std::mem::take(&mut *lock_ignoring_poison(&answer_buffer));
        let final_event = RoomMessageStreamChunkEvent {
            message_id: streaming_msg_id,
            room_id,
            seq: chunk_seq.fetch_add(1, Ordering::Relaxed),
            content: streamed_answer.clone(),
            done: true,
            error: err_msg.clone(),
            display_name: Some(ai_display_name.clone()),
            chunk_type: Some("answer".to_string()),
        };
        queue.publish_stream_chunk(&final_event).await;
        room_manager.broadcast_stream_chunk(final_event).await;

        let chunks = std::mem::take(&mut *lock_ignoring_poison(&all_chunks));
        let persist_content = compose_persist_content(
            &final_content,
            &streamed_answer,
            &chunks,
            err_msg.as_deref(),
        );

        // `compose_persist_content` always yields at least a placeholder, so this guard is
        // purely defensive. Cleanup below runs either way.
        if !persist_content.is_empty() {
            let thinking_content = serialize_chunks(&chunks);
            let envelope = RoomMessageEnvelope {
                id: streaming_msg_id,
                dedup_key: Some(format!("{}:{}", room_id, streaming_msg_id)),
                room_id,
                sender_type: sender_type.clone(),
                sender_id: None,
                model_id: Some(model_id),
                thread_id: None,
                content: persist_content.clone(),
                content_type: "text".to_string(),
                thinking_content: thinking_content.clone(),
                send_at,
                seq,
                in_reply_to: None,
                display_name: Some(ai_display_name.clone()),
            };
            if let Err(e) = queue.publish(room_id, envelope).await {
                tracing::error!(error = %e, "Failed to publish {} streaming message", mode_name);
            } else {
                let completed_at = Utc::now();
                if let Err(e) = room_ai::Entity::update_many()
                    .col_expr(
                        room_ai::Column::CallCount,
                        Expr::col(room_ai::Column::CallCount).add(1),
                    )
                    .col_expr(room_ai::Column::LastCallAt, Expr::value(Some(completed_at)))
                    .filter(room_ai::Column::Room.eq(room_id))
                    .filter(room_ai::Column::Model.eq(model_id))
                    .exec(&db)
                    .await
                {
                    tracing::warn!(error = %e, "Failed to update room_ai call stats");
                }

                let msg_event = queue::RoomMessageEvent {
                    id: streaming_msg_id,
                    room_id,
                    sender_type,
                    sender_id: None,
                    thread_id: None,
                    content: persist_content,
                    content_type: "text".to_string(),
                    thinking_content,
                    // Same timestamp as the published envelope.
                    send_at,
                    seq,
                    display_name: Some(ai_display_name.clone()),
                    in_reply_to: None,
                    reactions: None,
                    message_id: None,
                };
                room_manager.broadcast(room_id, msg_event).await;
                room_manager.metrics.messages_sent.increment(1);

                let event = ProjectRoomEvent {
                    event_type: crate::RoomEventType::NewMessage.as_str().into(),
                    project_id,
                    room_id: Some(room_id),
                    category_id: None,
                    message_id: Some(streaming_msg_id),
                    seq: Some(seq),
                    timestamp: completed_at,
                };
                queue.publish_project_room_event(project_id, event).await;
            }
        }

        // Unified cleanup. Every exit path stops the typing indicator and releases the
        // stream registrations.
        let _ = typing_cancel_tx.send(());
        typing_renew_handle.abort();
        room_manager
            .broadcast_typing(room_id, ai_typing_event(room_id, ai_display_name, "stop"))
            .await;
        room_manager.unregister_stream_cancel(room_id).await;
        room_manager.close_stream_channel(streaming_msg_id).await;
    });
}

/// Locks `mutex`, recovering the guarded data if a previous holder panicked.
fn lock_ignoring_poison<T>(mutex: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> {
    mutex
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Builds an AI typing-indicator event for `room_id` with the given `action`
/// (this file uses `"start"` and `"stop"`).
fn ai_typing_event(room_id: Uuid, username: String, action: &str) -> queue::TypingEvent {
    queue::TypingEvent {
        room_id,
        user_id: AI_TYPING_USER_ID,
        username,
        avatar_url: None,
        action: action.to_string(),
        sender_type: Some("ai".to_string()),
    }
}

/// Chooses the text published for a finished AI run.
///
/// The sources are tried in this order:
/// 1. the non-empty content returned by the run;
/// 2. otherwise, the non-blank answer text streamed to clients;
/// 3. otherwise, a placeholder embedding every non-`"answer"` chunk, joined with newlines;
/// 4. otherwise, a fixed "no output" placeholder.
///
/// If `err_msg` is present, it is appended as an `[Error: …]` line. The result is trimmed.
fn compose_persist_content(
    final_content: &str,
    streamed_answer: &str,
    chunks: &[(String, String)],
    err_msg: Option<&str>,
) -> String {
    let body = if !final_content.is_empty() {
        final_content.to_string()
    } else if !streamed_answer.trim().is_empty() {
        // The run returned nothing (e.g. it failed), but clients already received this text
        // in the final stream chunk. Keep the published message consistent with it.
        streamed_answer.to_string()
    } else {
        let reasoning_chain = chunks
            .iter()
            .filter(|(chunk_type, _)| chunk_type != "answer")
            .map(|(_, content)| content.as_str())
            .collect::<Vec<_>>()
            .join("\n");
        if reasoning_chain.trim().is_empty() {
            String::from("[No output from reasoning agent]")
        } else {
            format!(
                "[Agent ran through reasoning steps but did not produce a final answer.]\n{}",
                reasoning_chain.trim_end()
            )
        }
    };
    let body = match err_msg {
        Some(msg) => format!("{}\n[Error: {}]", body.trim_end(), msg),
        None => body,
    };
    body.trim().to_string()
}

/// Serializes every recorded `(chunk_type, content)` pair as
/// `{"__chunks__": [{"type": …, "content": …}, …]}`, or returns `None` when no chunk was
/// recorded.
fn serialize_chunks(chunks: &[(String, String)]) -> Option<String> {
    if chunks.is_empty() {
        return None;
    }
    let entries: Vec<serde_json::Value> = chunks
        .iter()
        .map(|(chunk_type, content)| {
            serde_json::json!({
                "type": chunk_type,
                "content": content,
            })
        })
        .collect();
    Some(serde_json::json!({ "__chunks__": entries }).to_string())
}