99 lines
3.7 KiB
Rust
99 lines
3.7 KiB
Rust
use actix_web::web::Bytes;
|
|
use actix_web::{HttpRequest, HttpResponse, web};
|
|
use tokio_stream::StreamExt;
|
|
use tokio_stream::wrappers::{BroadcastStream, errors::BroadcastStreamRecvError};
|
|
use uuid::Uuid;
|
|
|
|
use queue::RoomMessageStreamChunkEvent;
|
|
use service::AppService;
|
|
|
|
/// SSE endpoint: GET /ws/ai-stream/{room_id}/{message_id}
|
|
pub async fn ws_ai_stream(
|
|
service: web::Data<AppService>,
|
|
req: HttpRequest,
|
|
path: web::Path<(Uuid, Uuid)>,
|
|
) -> Result<HttpResponse, actix_web::Error> {
|
|
let (room_id, message_id) = path.into_inner();
|
|
|
|
// Prefer Authorization header over query parameter
|
|
let user_id = if let Some(auth_header) = req.headers().get("Authorization") {
|
|
if let Ok(auth_str) = auth_header.to_str() {
|
|
if let Some(token) = auth_str.strip_prefix("Bearer ") {
|
|
match service.ws_token.validate_token(token).await {
|
|
Ok(uid) => uid,
|
|
Err(_) => return Err(actix_web::error::ErrorUnauthorized("invalid token")),
|
|
}
|
|
} else {
|
|
return Err(actix_web::error::ErrorUnauthorized("invalid auth header"));
|
|
}
|
|
} else {
|
|
return Err(actix_web::error::ErrorUnauthorized("invalid auth header"));
|
|
}
|
|
} else if let Some(token) = req.uri().query().and_then(|q| {
|
|
q.split('&')
|
|
.find(|p| p.starts_with("token="))
|
|
.and_then(|p| p.split('=').nth(1))
|
|
}) {
|
|
match service.ws_token.validate_token(token).await {
|
|
Ok(uid) => uid,
|
|
Err(_) => return Err(actix_web::error::ErrorUnauthorized("invalid token")),
|
|
}
|
|
} else {
|
|
return Err(actix_web::error::ErrorUnauthorized("no auth provided"));
|
|
};
|
|
|
|
if let Err(e) = service.room.check_room_access(room_id, user_id).await {
|
|
tracing::warn!(user_id = %user_id, room_id = %room_id, error = ?e, "AI SSE: access denied");
|
|
return Err(actix_web::error::ErrorForbidden("access denied").into());
|
|
}
|
|
|
|
tracing::info!(user_id = %user_id, room_id = %room_id, message_id = %message_id, "AI SSE stream opened");
|
|
|
|
let manager = service.room.room_manager.clone();
|
|
|
|
let stream_rx = manager.subscribe_stream(message_id).await;
|
|
let stream_rx = match stream_rx {
|
|
Some(rx) => rx,
|
|
None => return Err(actix_web::error::ErrorNotFound("stream not found").into()),
|
|
};
|
|
|
|
let sse_stream = BroadcastStream::new(stream_rx).map(move |result| match result {
|
|
Ok(chunk) => {
|
|
let data = format_sse_chunk(&chunk);
|
|
if chunk.done {
|
|
Ok::<_, std::io::Error>(Bytes::from(format!("{}event: done\ndata: \n\n", data)))
|
|
} else {
|
|
Ok::<_, std::io::Error>(Bytes::from(data))
|
|
}
|
|
}
|
|
Err(BroadcastStreamRecvError::Lagged(_)) => {
|
|
tracing::warn!(message_id = %message_id, "SSE subscriber lagged");
|
|
Err(std::io::Error::new(
|
|
std::io::ErrorKind::TimedOut,
|
|
"stream lagged",
|
|
))
|
|
}
|
|
});
|
|
|
|
Ok(HttpResponse::Ok()
|
|
.content_type("text/event-stream")
|
|
.append_header(("Cache-Control", "no-cache"))
|
|
.append_header(("Connection", "keep-alive"))
|
|
.append_header(("X-Accel-Buffering", "no"))
|
|
.streaming(sse_stream))
|
|
}
|
|
|
|
fn format_sse_chunk(chunk: &RoomMessageStreamChunkEvent) -> String {
|
|
let json = serde_json::json!({
|
|
"message_id": chunk.message_id,
|
|
"room_id": chunk.room_id,
|
|
"seq": chunk.seq,
|
|
"content": chunk.content,
|
|
"done": chunk.done,
|
|
"error": chunk.error,
|
|
"display_name": chunk.display_name,
|
|
"chunk_type": chunk.chunk_type,
|
|
});
|
|
format!("event: chunk\ndata: {}\n\n", json)
|
|
}
|