use std::sync::Arc; use std::sync::atomic::Ordering; use actix_web::{ Error, HttpRequest, HttpResponse, error::{ErrorBadRequest, ErrorInternalServerError, ErrorNotFound}, http::header, web::{self, Bytes, Data, Payload, ServiceConfig}, }; use futures_util::StreamExt; use serde_json::json; use session::SessionExt; use crate::{ engine_packet::{ EnginePacket, SocketPayload, decode_engine_payload, decode_engine_text_packet, encode_engine_packet, encode_engine_payload, }, error::SocketIoError, server::SocketIo, session::{Session, Transport}, socket::DisconnectReason, }; struct ActiveGuard(Arc); impl Drop for ActiveGuard { fn drop(&mut self) { self.0.store(false, Ordering::Release); } } #[derive(Debug, serde::Deserialize)] struct EngineQuery { #[serde(rename = "EIO")] eio: Option, transport: Option, sid: Option, } pub fn configure(cfg: &mut ServiceConfig, io: SocketIo) { let path = io.config().path.clone(); configure_at(cfg, path, io); } pub fn configure_at( cfg: &mut ServiceConfig, path: impl Into, io: SocketIo, ) { cfg.app_data(Data::new(io)).service( web::resource(path.into()) .route(web::get().to(engine_get)) .route(web::post().to(engine_post)), ); } async fn engine_get( io: Data, req: HttpRequest, stream: Payload, query: web::Query, ) -> Result { validate_eio(&query)?; match query.transport.as_deref() { Some("polling") if query.sid.is_none() => polling_open(io, &req).await, Some("polling") => polling_get(io, query.sid.as_deref()).await, Some("websocket") => { websocket_open(io, req, stream, query.sid.clone()).await } _ => Err(ErrorBadRequest("unsupported transport")), } } async fn engine_post( io: Data, query: web::Query, body: Bytes, ) -> Result { validate_eio(&query)?; if query.transport.as_deref() != Some("polling") { return Err(ErrorBadRequest("unsupported transport")); } if body.len() > io.config().max_payload { return Err(ErrorBadRequest("payload too large")); } let sid = query .sid .as_deref() .ok_or_else(|| ErrorBadRequest("missing sid"))?; validate_sid(sid)?; let session = io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?; if session.post_active.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed).is_err() { io.remove_session(&session, DisconnectReason::TransportClosed) .await; return Err(ErrorBadRequest("overlapping post request")); } let _guard = ActiveGuard(session.post_active.clone()); handle_polling_body(&io, session, &body).await?; Ok(HttpResponse::Ok() .insert_header((header::CONTENT_TYPE, "text/plain; charset=UTF-8")) .body("ok")) } async fn handle_polling_body( io: &SocketIo, session: Arc, body: &Bytes, ) -> Result<(), Error> { touch_session(&session).await; let payload = std::str::from_utf8(body).map_err(ErrorBadRequest)?; let packets = decode_engine_payload(payload).map_err(map_socket_error)?; for packet in packets { handle_engine_packet(io, session.clone(), packet).await?; } Ok(()) } async fn polling_open( io: Data, req: &HttpRequest, ) -> Result { let session = Session::new(req.get_session().user()); let sid = session.engine_id.clone(); io.insert_session(session).await; let open = json!({ "sid": sid, "upgrades": ["websocket"], "pingInterval": io.config().ping_interval.as_millis(), "pingTimeout": io.config().ping_timeout.as_millis(), "maxPayload": io.config().max_payload }); Ok(text_response(encode_engine_packet( &EnginePacket::Open(open), true, ))) } async fn polling_get( io: Data, sid: Option<&str>, ) -> Result { let sid = sid.ok_or_else(|| ErrorBadRequest("missing sid"))?; validate_sid(sid)?; let session = io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?; if session.get_active.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed).is_err() { io.remove_session(&session, DisconnectReason::TransportClosed) .await; return Err(ErrorBadRequest("overlapping get request")); } let _guard = ActiveGuard(session.get_active.clone()); loop { let packets = session.drain().await; if !packets.is_empty() { return Ok(text_response(encode_engine_payload(&packets, true))); } tokio::select! { () = session.notify.notified() => { } () = tokio::time::sleep(io.config().ping_interval) => { return Ok(text_response(encode_engine_payload( &[EnginePacket::Ping(None)], true, ))); } } } } async fn websocket_open( io: Data, req: HttpRequest, stream: Payload, sid: Option, ) -> Result { let direct_websocket = sid.is_none(); let session = match sid { Some(ref sid) => { validate_sid(sid)?; io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))? } None => { let session = Session::new(req.get_session().user()); io.insert_session(session.clone()).await; session } }; let (response, mut ws_session, messages) = actix_ws::handle(&req, stream)?; let io = io.get_ref().clone(); actix_web::rt::spawn(async move { if direct_websocket { *session.transport.lock().await = Transport::WebSocket; let open = json!({ "sid": session.engine_id, "upgrades": [], "pingInterval": io.config().ping_interval.as_millis(), "pingTimeout": io.config().ping_timeout.as_millis(), "maxPayload": io.config().max_payload }); if ws_session .text(encode_engine_packet(&EnginePacket::Open(open), false)) .await .is_err() { io.remove_session(&session, DisconnectReason::TransportClosed) .await; return; } } websocket_loop(io, session, ws_session, messages, direct_websocket) .await; }); Ok(response) } async fn websocket_loop( io: SocketIo, session: Arc, mut ws_session: actix_ws::Session, mut messages: actix_ws::MessageStream, mut upgraded: bool, ) { let mut heartbeat = tokio::time::interval_at( tokio::time::Instant::now() + io.config().ping_interval, io.config().ping_interval, ); loop { tokio::select! { message = messages.next() => { match message { Some(Ok(actix_ws::Message::Text(text))) => { touch_session(&session).await; match decode_engine_text_packet(text.as_ref()) { Ok(EnginePacket::Ping(Some(value))) if value == "probe" && !upgraded => { if ws_session.text("3probe").await.is_err() { break; } } Ok(EnginePacket::Upgrade) if !upgraded => { *session.transport.lock().await = Transport::WebSocket; upgraded = true; } Ok(packet) => { if handle_engine_packet(&io, session.clone(), packet).await.is_err() { break; } } Err(_) => break, } } Some(Ok(actix_ws::Message::Binary(bytes))) => { touch_session(&session).await; if handle_engine_packet( &io, session.clone(), EnginePacket::Message(SocketPayload::Binary(bytes.to_vec())), ) .await .is_err() { break; } } Some(Ok(actix_ws::Message::Ping(bytes))) => { touch_session(&session).await; if ws_session.pong(&bytes).await.is_err() { break; } } Some(Ok(actix_ws::Message::Pong(_))) => { touch_session(&session).await; } Some(Ok(actix_ws::Message::Close(reason))) => { let _ = ws_session.close(reason).await; break; } Some(Ok(actix_ws::Message::Continuation(_))) => {} Some(Ok(actix_ws::Message::Nop)) => {} Some(Err(_)) | None => break, } } () = session.notify.notified() => { for packet in session.drain().await { if send_ws_packet(&mut ws_session, packet).await.is_err() { io.remove_session(&session, DisconnectReason::TransportClosed).await; return; } } } _ = heartbeat.tick() => { if session.last_pong.lock().await.elapsed() > io.config().ping_interval + io.config().ping_timeout { io.remove_session(&session, DisconnectReason::PingTimeout).await; return; } if ws_session.text("2").await.is_err() { io.remove_session(&session, DisconnectReason::TransportClosed).await; return; } } } } io.remove_session(&session, DisconnectReason::TransportClosed) .await; } async fn touch_session(session: &Session) { *session.last_pong.lock().await = std::time::Instant::now(); } async fn send_ws_packet( ws_session: &mut actix_ws::Session, packet: EnginePacket, ) -> std::result::Result<(), actix_ws::Closed> { match packet { EnginePacket::Message(SocketPayload::Binary(bytes)) => { ws_session.binary(bytes).await } packet => ws_session.text(encode_engine_packet(&packet, false)).await, } } async fn handle_engine_packet( io: &SocketIo, session: Arc, packet: EnginePacket, ) -> Result<(), Error> { match packet { EnginePacket::Ping(data) => { session.enqueue(EnginePacket::Pong(data)).await; } EnginePacket::Pong(_) => { *session.last_pong.lock().await = std::time::Instant::now(); } EnginePacket::Message(payload) => { io.handle_socket_payload(session, payload) .await .map_err(map_socket_error)?; } EnginePacket::Close => { io.remove_session(&session, DisconnectReason::Client).await; } EnginePacket::Open(_) => { tracing::warn!("client sent unexpected Open packet"); } EnginePacket::Upgrade | EnginePacket::Noop => {} } Ok(()) } fn validate_eio(query: &EngineQuery) -> Result<(), Error> { match query.eio.as_deref() { Some("4") => Ok(()), _ => Err(ErrorBadRequest("unsupported EIO version")), } } fn validate_sid(sid: &str) -> Result<(), Error> { if sid.is_empty() || sid.len() > 128 || !sid.bytes().all(|byte| byte.is_ascii_graphic()) { return Err(ErrorBadRequest("invalid sid")); } Ok(()) } fn text_response(body: String) -> HttpResponse { HttpResponse::Ok() .insert_header((header::CONTENT_TYPE, "text/plain; charset=UTF-8")) .body(body) } fn map_socket_error(err: SocketIoError) -> Error { match err { SocketIoError::UnknownSession | SocketIoError::UnknownNamespace(_) => { ErrorNotFound(err) } SocketIoError::InvalidPacket(_) => ErrorBadRequest(err), _ => ErrorInternalServerError(err), } }