// File viewer metadata: 393 lines, 12 KiB, Rust.
use std::sync::Arc;
|
|
use std::sync::atomic::Ordering;
|
|
|
|
use actix_web::{
|
|
Error, HttpRequest, HttpResponse,
|
|
error::{ErrorBadRequest, ErrorInternalServerError, ErrorNotFound},
|
|
http::header,
|
|
web::{self, Bytes, Data, Payload, ServiceConfig},
|
|
};
|
|
use futures_util::StreamExt;
|
|
use serde_json::json;
|
|
use session::SessionExt;
|
|
|
|
use crate::{
|
|
engine_packet::{
|
|
EnginePacket, SocketPayload, decode_engine_payload,
|
|
decode_engine_text_packet, encode_engine_packet, encode_engine_payload,
|
|
},
|
|
error::SocketIoError,
|
|
server::SocketIo,
|
|
session::{Session, Transport},
|
|
socket::DisconnectReason,
|
|
};
|
|
/// Drop guard that clears a session's "request in flight" flag.
///
/// A polling handler first claims the flag with `compare_exchange(false, true)`
/// and then wraps the same flag in this guard, so the flag is reset on every
/// exit path of the handler, including early returns through `?`.
struct ActiveGuard(Arc<std::sync::atomic::AtomicBool>);

impl Drop for ActiveGuard {
    /// Marks the guarded flag as free again.
    fn drop(&mut self) {
        let Self(flag) = self;
        // `Release` pairs with the `Acquire` ordering used when the flag is claimed.
        flag.store(false, Ordering::Release);
    }
}
|
|
|
|
#[derive(Debug, serde::Deserialize)]
|
|
struct EngineQuery {
|
|
#[serde(rename = "EIO")]
|
|
eio: Option<String>,
|
|
transport: Option<String>,
|
|
sid: Option<String>,
|
|
}
|
|
|
|
pub fn configure(cfg: &mut ServiceConfig, io: SocketIo) {
|
|
let path = io.config().path.clone();
|
|
configure_at(cfg, path, io);
|
|
}
|
|
|
|
pub fn configure_at(
|
|
cfg: &mut ServiceConfig,
|
|
path: impl Into<String>,
|
|
io: SocketIo,
|
|
) {
|
|
cfg.app_data(Data::new(io)).service(
|
|
web::resource(path.into())
|
|
.route(web::get().to(engine_get))
|
|
.route(web::post().to(engine_post)),
|
|
);
|
|
}
|
|
|
|
async fn engine_get(
|
|
io: Data<SocketIo>,
|
|
req: HttpRequest,
|
|
stream: Payload,
|
|
query: web::Query<EngineQuery>,
|
|
) -> Result<HttpResponse, Error> {
|
|
validate_eio(&query)?;
|
|
|
|
match query.transport.as_deref() {
|
|
Some("polling") if query.sid.is_none() => polling_open(io, &req).await,
|
|
Some("polling") => polling_get(io, query.sid.as_deref()).await,
|
|
Some("websocket") => {
|
|
websocket_open(io, req, stream, query.sid.clone()).await
|
|
}
|
|
_ => Err(ErrorBadRequest("unsupported transport")),
|
|
}
|
|
}
|
|
|
|
async fn engine_post(
|
|
io: Data<SocketIo>,
|
|
query: web::Query<EngineQuery>,
|
|
body: Bytes,
|
|
) -> Result<HttpResponse, Error> {
|
|
validate_eio(&query)?;
|
|
if query.transport.as_deref() != Some("polling") {
|
|
return Err(ErrorBadRequest("unsupported transport"));
|
|
}
|
|
if body.len() > io.config().max_payload {
|
|
return Err(ErrorBadRequest("payload too large"));
|
|
}
|
|
|
|
let sid = query
|
|
.sid
|
|
.as_deref()
|
|
.ok_or_else(|| ErrorBadRequest("missing sid"))?;
|
|
validate_sid(sid)?;
|
|
let session = io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?;
|
|
if session.post_active.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed).is_err() {
|
|
// Another POST is in progress — return error without destroying session
|
|
return Err(ErrorBadRequest("concurrent polling request"));
|
|
}
|
|
|
|
let _guard = ActiveGuard(session.post_active.clone());
|
|
handle_polling_body(&io, session, &body).await?;
|
|
|
|
Ok(HttpResponse::Ok()
|
|
.insert_header((header::CONTENT_TYPE, "text/plain; charset=UTF-8"))
|
|
.body("ok"))
|
|
}
|
|
|
|
async fn handle_polling_body(
|
|
io: &SocketIo,
|
|
session: Arc<Session>,
|
|
body: &Bytes,
|
|
) -> Result<(), Error> {
|
|
touch_session(&session).await;
|
|
let payload = std::str::from_utf8(body).map_err(ErrorBadRequest)?;
|
|
let packets = decode_engine_payload(payload).map_err(map_socket_error)?;
|
|
for packet in packets {
|
|
handle_engine_packet(io, session.clone(), packet).await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn polling_open(
|
|
io: Data<SocketIo>,
|
|
req: &HttpRequest,
|
|
) -> Result<HttpResponse, Error> {
|
|
let session = Session::new(req.get_session().user());
|
|
let sid = session.engine_id.clone();
|
|
io.insert_session(session).await;
|
|
|
|
let open = json!({
|
|
"sid": sid,
|
|
"upgrades": ["websocket"],
|
|
"pingInterval": io.config().ping_interval.as_millis(),
|
|
"pingTimeout": io.config().ping_timeout.as_millis(),
|
|
"maxPayload": io.config().max_payload
|
|
});
|
|
|
|
Ok(text_response(encode_engine_packet(
|
|
&EnginePacket::Open(open),
|
|
true,
|
|
)))
|
|
}
|
|
|
|
async fn polling_get(
|
|
io: Data<SocketIo>,
|
|
sid: Option<&str>,
|
|
) -> Result<HttpResponse, Error> {
|
|
let sid = sid.ok_or_else(|| ErrorBadRequest("missing sid"))?;
|
|
validate_sid(sid)?;
|
|
let session = io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?;
|
|
if session.get_active.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed).is_err() {
|
|
// Another GET is in progress — return error without destroying session
|
|
return Err(ErrorBadRequest("concurrent polling request"));
|
|
}
|
|
|
|
let _guard = ActiveGuard(session.get_active.clone());
|
|
loop {
|
|
let packets = session.drain().await;
|
|
if !packets.is_empty() {
|
|
return Ok(text_response(encode_engine_payload(&packets, true)));
|
|
}
|
|
tokio::select! {
|
|
() = session.notify.notified() => {
|
|
}
|
|
() = tokio::time::sleep(io.config().ping_interval) => {
|
|
return Ok(text_response(encode_engine_payload(
|
|
&[EnginePacket::Ping(None)], true,
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn websocket_open(
|
|
io: Data<SocketIo>,
|
|
req: HttpRequest,
|
|
stream: Payload,
|
|
sid: Option<String>,
|
|
) -> Result<HttpResponse, Error> {
|
|
let direct_websocket = sid.is_none();
|
|
let session = match sid {
|
|
Some(ref sid) => {
|
|
validate_sid(sid)?;
|
|
io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?
|
|
}
|
|
None => {
|
|
let session = Session::new(req.get_session().user());
|
|
io.insert_session(session.clone()).await;
|
|
session
|
|
}
|
|
};
|
|
|
|
let (response, mut ws_session, messages) = actix_ws::handle(&req, stream)?;
|
|
let io = io.get_ref().clone();
|
|
|
|
actix_web::rt::spawn(async move {
|
|
if direct_websocket {
|
|
*session.transport.lock().await = Transport::WebSocket;
|
|
let open = json!({
|
|
"sid": session.engine_id,
|
|
"upgrades": [],
|
|
"pingInterval": io.config().ping_interval.as_millis(),
|
|
"pingTimeout": io.config().ping_timeout.as_millis(),
|
|
"maxPayload": io.config().max_payload
|
|
});
|
|
if ws_session
|
|
.text(encode_engine_packet(&EnginePacket::Open(open), false))
|
|
.await
|
|
.is_err()
|
|
{
|
|
io.remove_session(&session, DisconnectReason::TransportClosed)
|
|
.await;
|
|
return;
|
|
}
|
|
}
|
|
|
|
websocket_loop(io, session, ws_session, messages, direct_websocket)
|
|
.await;
|
|
});
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
async fn websocket_loop(
|
|
io: SocketIo,
|
|
session: Arc<Session>,
|
|
mut ws_session: actix_ws::Session,
|
|
mut messages: actix_ws::MessageStream,
|
|
mut upgraded: bool,
|
|
) {
|
|
let mut heartbeat = tokio::time::interval_at(
|
|
tokio::time::Instant::now() + io.config().ping_interval,
|
|
io.config().ping_interval,
|
|
);
|
|
|
|
loop {
|
|
tokio::select! {
|
|
message = messages.next() => {
|
|
match message {
|
|
Some(Ok(actix_ws::Message::Text(text))) => {
|
|
touch_session(&session).await;
|
|
match decode_engine_text_packet(text.as_ref()) {
|
|
Ok(EnginePacket::Ping(Some(value))) if value == "probe" && !upgraded => {
|
|
if ws_session.text("3probe").await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Ok(EnginePacket::Upgrade) if !upgraded => {
|
|
*session.transport.lock().await = Transport::WebSocket;
|
|
upgraded = true;
|
|
}
|
|
Ok(packet) => {
|
|
if handle_engine_packet(&io, session.clone(), packet).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
Some(Ok(actix_ws::Message::Binary(bytes))) => {
|
|
touch_session(&session).await;
|
|
if handle_engine_packet(
|
|
&io,
|
|
session.clone(),
|
|
EnginePacket::Message(SocketPayload::Binary(bytes.to_vec())),
|
|
)
|
|
.await
|
|
.is_err()
|
|
{
|
|
break;
|
|
}
|
|
}
|
|
Some(Ok(actix_ws::Message::Ping(bytes))) => {
|
|
touch_session(&session).await;
|
|
if ws_session.pong(&bytes).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Some(Ok(actix_ws::Message::Pong(_))) => {
|
|
touch_session(&session).await;
|
|
}
|
|
Some(Ok(actix_ws::Message::Close(reason))) => {
|
|
let _ = ws_session.close(reason).await;
|
|
break;
|
|
}
|
|
Some(Ok(actix_ws::Message::Continuation(_))) => {}
|
|
Some(Ok(actix_ws::Message::Nop)) => {}
|
|
Some(Err(_)) | None => break,
|
|
}
|
|
}
|
|
() = session.notify.notified() => {
|
|
for packet in session.drain().await {
|
|
if send_ws_packet(&mut ws_session, packet).await.is_err() {
|
|
io.remove_session(&session, DisconnectReason::TransportClosed).await;
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
_ = heartbeat.tick() => {
|
|
if session.last_pong.lock().await.elapsed()
|
|
> io.config().ping_interval + io.config().ping_timeout
|
|
{
|
|
io.remove_session(&session, DisconnectReason::PingTimeout).await;
|
|
return;
|
|
}
|
|
if ws_session.text("2").await.is_err() {
|
|
io.remove_session(&session, DisconnectReason::TransportClosed).await;
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
io.remove_session(&session, DisconnectReason::TransportClosed)
|
|
.await;
|
|
}
|
|
|
|
async fn touch_session(session: &Session) {
|
|
*session.last_pong.lock().await = std::time::Instant::now();
|
|
}
|
|
|
|
async fn send_ws_packet(
|
|
ws_session: &mut actix_ws::Session,
|
|
packet: EnginePacket,
|
|
) -> std::result::Result<(), actix_ws::Closed> {
|
|
match packet {
|
|
EnginePacket::Message(SocketPayload::Binary(bytes)) => {
|
|
ws_session.binary(bytes).await
|
|
}
|
|
packet => ws_session.text(encode_engine_packet(&packet, false)).await,
|
|
}
|
|
}
|
|
|
|
async fn handle_engine_packet(
|
|
io: &SocketIo,
|
|
session: Arc<Session>,
|
|
packet: EnginePacket,
|
|
) -> Result<(), Error> {
|
|
match packet {
|
|
EnginePacket::Ping(data) => {
|
|
session.enqueue(EnginePacket::Pong(data)).await;
|
|
}
|
|
EnginePacket::Pong(_) => {
|
|
*session.last_pong.lock().await = std::time::Instant::now();
|
|
}
|
|
EnginePacket::Message(payload) => {
|
|
io.handle_socket_payload(session, payload)
|
|
.await
|
|
.map_err(map_socket_error)?;
|
|
}
|
|
EnginePacket::Close => {
|
|
io.remove_session(&session, DisconnectReason::Client).await;
|
|
}
|
|
EnginePacket::Open(_) => {
|
|
tracing::warn!("client sent unexpected Open packet");
|
|
}
|
|
EnginePacket::Upgrade | EnginePacket::Noop => {}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn validate_eio(query: &EngineQuery) -> Result<(), Error> {
|
|
match query.eio.as_deref() {
|
|
Some("4") => Ok(()),
|
|
_ => Err(ErrorBadRequest("unsupported EIO version")),
|
|
}
|
|
}
|
|
|
|
fn validate_sid(sid: &str) -> Result<(), Error> {
|
|
if sid.is_empty()
|
|
|| sid.len() > 128
|
|
|| !sid.bytes().all(|byte| byte.is_ascii_graphic())
|
|
{
|
|
return Err(ErrorBadRequest("invalid sid"));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn text_response(body: String) -> HttpResponse {
|
|
HttpResponse::Ok()
|
|
.insert_header((header::CONTENT_TYPE, "text/plain; charset=UTF-8"))
|
|
.body(body)
|
|
}
|
|
|
|
fn map_socket_error(err: SocketIoError) -> Error {
|
|
match err {
|
|
SocketIoError::UnknownSession | SocketIoError::UnknownNamespace(_) => {
|
|
ErrorNotFound(err)
|
|
}
|
|
SocketIoError::InvalidPacket(_) => ErrorBadRequest(err),
|
|
_ => ErrorInternalServerError(err),
|
|
}
|
|
}
|