use actix_web::{web, HttpRequest, HttpResponse}; use actix_web::web::Bytes; use tokio_stream::StreamExt; use tokio_stream::wrappers::{BroadcastStream, errors::BroadcastStreamRecvError}; use uuid::Uuid; use service::AppService; use queue::RoomMessageStreamChunkEvent; /// SSE endpoint: GET /ws/ai-stream/{room_id}/{message_id} pub async fn ws_ai_stream( service: web::Data, req: HttpRequest, path: web::Path<(Uuid, Uuid)>, ) -> Result { let (room_id, message_id) = path.into_inner(); // Prefer Authorization header over query parameter let user_id = if let Some(auth_header) = req.headers().get("Authorization") { if let Ok(auth_str) = auth_header.to_str() { if let Some(token) = auth_str.strip_prefix("Bearer ") { match service.ws_token.validate_token(token).await { Ok(uid) => uid, Err(_) => return Err(actix_web::error::ErrorUnauthorized("invalid token")), } } else { return Err(actix_web::error::ErrorUnauthorized("invalid auth header")); } } else { return Err(actix_web::error::ErrorUnauthorized("invalid auth header")); } } else if let Some(token) = req.uri().query().and_then(|q| { q.split('&').find(|p| p.starts_with("token=")).and_then(|p| p.split('=').nth(1)) }) { match service.ws_token.validate_token(token).await { Ok(uid) => uid, Err(_) => return Err(actix_web::error::ErrorUnauthorized("invalid token")), } } else { return Err(actix_web::error::ErrorUnauthorized("no auth provided")); }; if let Err(e) = service.room.check_room_access(room_id, user_id).await { tracing::warn!(user_id = %user_id, room_id = %room_id, error = ?e, "AI SSE: access denied"); return Err(actix_web::error::ErrorForbidden("access denied").into()); } tracing::info!(user_id = %user_id, room_id = %room_id, message_id = %message_id, "AI SSE stream opened"); let manager = service.room.room_manager.clone(); let stream_rx = manager.subscribe_stream(message_id).await; let stream_rx = match stream_rx { Some(rx) => rx, None => return Err(actix_web::error::ErrorNotFound("stream not found").into()), }; let sse_stream = BroadcastStream::new(stream_rx) .map(move |result| match result { Ok(chunk) => { let data = format_sse_chunk(&chunk); if chunk.done { Ok::<_, std::io::Error>(Bytes::from(format!("{}event: done\ndata: \n\n", data))) } else { Ok::<_, std::io::Error>(Bytes::from(data)) } } Err(BroadcastStreamRecvError::Lagged(_)) => { tracing::warn!(message_id = %message_id, "SSE subscriber lagged"); Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "stream lagged")) } }); Ok(HttpResponse::Ok() .content_type("text/event-stream") .append_header(("Cache-Control", "no-cache")) .append_header(("Connection", "keep-alive")) .append_header(("X-Accel-Buffering", "no")) .streaming(sse_stream)) } fn format_sse_chunk(chunk: &RoomMessageStreamChunkEvent) -> String { let json = serde_json::json!({ "message_id": chunk.message_id, "room_id": chunk.room_id, "seq": chunk.seq, "content": chunk.content, "done": chunk.done, "error": chunk.error, "display_name": chunk.display_name, "chunk_type": chunk.chunk_type, }); format!("event: chunk\ndata: {}\n\n", json) }