// gitdataai/lib/socketio/actix.rs
//
// 401 lines
// 13 KiB
// Rust

use std::sync::Arc;
use std::sync::atomic::Ordering;
use actix_web::{
Error, HttpRequest, HttpResponse,
error::{ErrorBadRequest, ErrorInternalServerError, ErrorNotFound},
http::header,
web::{self, Bytes, Data, Payload, ServiceConfig},
};
use futures_util::StreamExt;
use serde_json::json;
use session::SessionExt;
use crate::{
engine_packet::{
EnginePacket, SocketPayload, decode_engine_payload,
decode_engine_text_packet, encode_engine_packet, encode_engine_payload,
},
error::SocketIoError,
server::SocketIo,
session::{Session, Transport},
socket::DisconnectReason,
};
/// RAII guard that clears a shared "request in flight" flag when dropped.
///
/// The polling handlers claim the flag with `compare_exchange(false, true, ..)`
/// before constructing this guard; dropping it on any exit path (including an
/// early `?` return) stores `false` again so the next request can claim it.
struct ActiveGuard(Arc<std::sync::atomic::AtomicBool>);

impl Drop for ActiveGuard {
    fn drop(&mut self) {
        let Self(flag) = self;
        // Release pairs with the Acquire used when the flag is claimed.
        flag.store(false, Ordering::Release);
    }
}
/// Query-string parameters of an Engine.IO request.
///
/// Every field is optional at the deserialization level; the handlers
/// validate them afterwards (`validate_eio`, `validate_sid`, transport
/// dispatch in `engine_get` / `engine_post`).
#[derive(serde::Deserialize, Debug)]
struct EngineQuery {
    /// Protocol revision from the `EIO` parameter; `validate_eio` accepts only `"4"`.
    #[serde(rename = "EIO")]
    eio: Option<String>,
    /// Requested transport; the handlers recognise `"polling"` and `"websocket"`.
    transport: Option<String>,
    /// Engine.IO session id; absent when a new session is being opened.
    sid: Option<String>,
}
/// Mounts the Engine.IO endpoint at the path taken from `io`'s own config.
///
/// Convenience wrapper around [`configure_at`].
pub fn configure(cfg: &mut ServiceConfig, io: SocketIo) {
    configure_at(cfg, io.config().path.clone(), io);
}
/// Mounts the Engine.IO endpoint for `io` at `path` and registers `io` as
/// application data.
///
/// `GET` serves polling handshakes, long-polls and WebSocket connections;
/// `POST` accepts polling payloads.
///
/// The resource's request-body limit is set to the configured `max_payload`.
/// Without it, actix-web's `Bytes` extractor applies its own default limit
/// (256 KiB). Polling bodies between that default and `max_payload`, which the
/// handshake advertises to clients as acceptable, would then be rejected
/// before ever reaching `engine_post`. Bodies above `max_payload` are now
/// refused by the extractor itself with 413 Payload Too Large.
pub fn configure_at(
    cfg: &mut ServiceConfig,
    path: impl Into<String>,
    io: SocketIo,
) {
    // Read the limit before `io` is moved into `Data`.
    let max_payload = io.config().max_payload;
    cfg.app_data(Data::new(io)).service(
        web::resource(path.into())
            // Scoped to this resource only, so other routes keep their limits.
            .app_data(web::PayloadConfig::new(max_payload))
            .route(web::get().to(engine_get))
            .route(web::post().to(engine_post)),
    );
}
/// `GET` entry point: dispatches on the `transport` and `sid` query parameters.
///
/// * `polling` without `sid` opens a new polling session (handshake).
/// * `polling` with `sid` long-polls that session for queued packets.
/// * `websocket` opens a WebSocket, upgrading session `sid` when given.
///
/// # Errors
/// 400 for an unsupported `EIO` version or transport; otherwise whatever the
/// selected transport handler returns.
async fn engine_get(
    io: Data<SocketIo>,
    req: HttpRequest,
    stream: Payload,
    query: web::Query<EngineQuery>,
) -> Result<HttpResponse, Error> {
    validate_eio(&query)?;
    let sid = query.sid.as_deref();
    match (query.transport.as_deref(), sid) {
        (Some("polling"), None) => polling_open(io, &req).await,
        (Some("polling"), Some(_)) => polling_get(io, sid).await,
        (Some("websocket"), _) => {
            websocket_open(io, req, stream, query.sid.clone()).await
        }
        _ => Err(ErrorBadRequest("unsupported transport")),
    }
}
async fn engine_post(
io: Data<SocketIo>,
query: web::Query<EngineQuery>,
body: Bytes,
) -> Result<HttpResponse, Error> {
validate_eio(&query)?;
if query.transport.as_deref() != Some("polling") {
return Err(ErrorBadRequest("unsupported transport"));
}
if body.len() > io.config().max_payload {
return Err(ErrorBadRequest("payload too large"));
}
let sid = query
.sid
.as_deref()
.ok_or_else(|| ErrorBadRequest("missing sid"))?;
validate_sid(sid)?;
let session = io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?;
if session
.post_active
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_err()
{
// Another POST is in progress — return error without destroying session
return Err(ErrorBadRequest("concurrent polling request"));
}
let _guard = ActiveGuard(session.post_active.clone());
handle_polling_body(&io, session, &body).await?;
Ok(HttpResponse::Ok()
.insert_header((header::CONTENT_TYPE, "text/plain; charset=UTF-8"))
.body("ok"))
}
/// Processes the body of a polling `POST`.
///
/// Refreshes the session's liveness timestamp, decodes `body` as a UTF-8
/// Engine.IO payload, and handles each packet in order, stopping at the first
/// error.
///
/// # Errors
/// 400 for non-UTF-8 bodies, the mapped `SocketIoError` for undecodable
/// payloads, or the first error returned by `handle_engine_packet`.
async fn handle_polling_body(
    io: &SocketIo,
    session: Arc<Session>,
    body: &Bytes,
) -> Result<(), Error> {
    touch_session(&session).await;
    let text = std::str::from_utf8(&body[..]).map_err(ErrorBadRequest)?;
    let decoded = decode_engine_payload(text).map_err(map_socket_error)?;
    for engine_packet in decoded {
        handle_engine_packet(io, Arc::clone(&session), engine_packet).await?;
    }
    Ok(())
}
/// Opens a new polling session and answers with the Engine.IO `open` packet.
///
/// The session is built from the HTTP session's user, registered with `io`,
/// and advertised with `websocket` as an available upgrade together with the
/// configured ping interval/timeout (in milliseconds) and `maxPayload`.
async fn polling_open(
    io: Data<SocketIo>,
    req: &HttpRequest,
) -> Result<HttpResponse, Error> {
    let session = Session::new(req.get_session().user());
    let sid = session.engine_id.clone();
    io.insert_session(session).await;
    let config = io.config();
    let handshake = json!({
        "sid": sid,
        "upgrades": ["websocket"],
        "pingInterval": config.ping_interval.as_millis(),
        "pingTimeout": config.ping_timeout.as_millis(),
        "maxPayload": config.max_payload
    });
    let encoded = encode_engine_packet(&EnginePacket::Open(handshake), true);
    Ok(text_response(encoded))
}
/// Serves a long-polling `GET` for an existing session.
///
/// If packets are already queued they are returned at once. Otherwise the
/// request waits on the session's `notify` until something is queued, and
/// responds with a single ping packet if nothing arrives within one
/// `ping_interval`.
///
/// Only one `GET` may be in flight per session. A concurrent request is
/// refused without destroying the session, and the in-flight flag is released
/// by an [`ActiveGuard`] however this function exits.
///
/// # Errors
/// 400 for a missing/invalid `sid` or a concurrent poll; 404 for an unknown `sid`.
async fn polling_get(
    io: Data<SocketIo>,
    sid: Option<&str>,
) -> Result<HttpResponse, Error> {
    let sid = sid.ok_or_else(|| ErrorBadRequest("missing sid"))?;
    validate_sid(sid)?;
    let session = io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?;
    if session
        .get_active
        .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
        .is_err()
    {
        // Another GET is in progress; return an error without destroying the session.
        return Err(ErrorBadRequest("concurrent polling request"));
    }
    let _guard = ActiveGuard(session.get_active.clone());
    // One deadline for the whole request. A wake-up that finds the queue empty
    // must not push the ping further out.
    let deadline = tokio::time::sleep(io.config().ping_interval);
    tokio::pin!(deadline);
    loop {
        // Create the `Notified` future before draining. Tokio guarantees it
        // observes `notify_waiters()` calls made after its creation, so a
        // packet enqueued between the drain and the wait still wakes us.
        // (With `notify_one()` a stored permit covers this case anyway.)
        let notified = session.notify.notified();
        let packets = session.drain().await;
        if !packets.is_empty() {
            return Ok(text_response(encode_engine_payload(&packets, true)));
        }
        tokio::select! {
            () = notified => {}
            () = &mut deadline => {
                return Ok(text_response(encode_engine_payload(
                    &[EnginePacket::Ping(None)], true,
                )));
            }
        }
    }
}
/// Accepts a WebSocket connection and spawns its driver task.
///
/// Without `sid` this is a direct WebSocket session: a new session is created
/// and the `open` packet (no further upgrades) is sent over the socket. With
/// `sid` the socket is an upgrade probe for that existing polling session.
///
/// The WebSocket handshake runs before a new session is registered, so a
/// failed handshake does not leave an orphaned session in `io`. The incoming
/// frame limit is raised to the advertised `maxPayload`, because the actix-ws
/// default (64 KiB) would otherwise reject messages that the handshake told
/// the client were allowed.
///
/// # Errors
/// 400 for a malformed `sid`, 404 for an unknown one, or the error returned by
/// the WebSocket handshake.
async fn websocket_open(
    io: Data<SocketIo>,
    req: HttpRequest,
    stream: Payload,
    sid: Option<String>,
) -> Result<HttpResponse, Error> {
    // Resolve the upgrade target first so an unknown sid is rejected before the
    // connection is switched to WebSocket.
    let existing = match sid.as_deref() {
        Some(sid) => {
            validate_sid(sid)?;
            Some(io.session(sid).await.ok_or_else(|| ErrorNotFound("sid"))?)
        }
        None => None,
    };
    let (response, mut ws_session, messages) = actix_ws::handle(&req, stream)?;
    let messages = messages.max_frame_size(io.config().max_payload);
    let direct_websocket = existing.is_none();
    let session = match existing {
        Some(session) => session,
        None => {
            // Only register the session once the handshake has succeeded.
            let session = Session::new(req.get_session().user());
            io.insert_session(session.clone()).await;
            session
        }
    };
    let io = io.get_ref().clone();
    actix_web::rt::spawn(async move {
        if direct_websocket {
            *session.transport.lock().await = Transport::WebSocket;
            let open = json!({
                "sid": session.engine_id,
                "upgrades": [],
                "pingInterval": io.config().ping_interval.as_millis(),
                "pingTimeout": io.config().ping_timeout.as_millis(),
                "maxPayload": io.config().max_payload
            });
            if ws_session
                .text(encode_engine_packet(&EnginePacket::Open(open), false))
                .await
                .is_err()
            {
                io.remove_session(&session, DisconnectReason::TransportClosed)
                    .await;
                return;
            }
        }
        websocket_loop(io, session, ws_session, messages, direct_websocket)
            .await;
    });
    Ok(response)
}
async fn websocket_loop(
io: SocketIo,
session: Arc<Session>,
mut ws_session: actix_ws::Session,
mut messages: actix_ws::MessageStream,
mut upgraded: bool,
) {
let mut heartbeat = tokio::time::interval_at(
tokio::time::Instant::now() + io.config().ping_interval,
io.config().ping_interval,
);
loop {
tokio::select! {
message = messages.next() => {
match message {
Some(Ok(actix_ws::Message::Text(text))) => {
touch_session(&session).await;
match decode_engine_text_packet(text.as_ref()) {
Ok(EnginePacket::Ping(Some(value))) if value == "probe" && !upgraded => {
if ws_session.text("3probe").await.is_err() {
break;
}
}
Ok(EnginePacket::Upgrade) if !upgraded => {
*session.transport.lock().await = Transport::WebSocket;
upgraded = true;
}
Ok(packet) => {
if handle_engine_packet(&io, session.clone(), packet).await.is_err() {
break;
}
}
Err(_) => break,
}
}
Some(Ok(actix_ws::Message::Binary(bytes))) => {
touch_session(&session).await;
if handle_engine_packet(
&io,
session.clone(),
EnginePacket::Message(SocketPayload::Binary(bytes.to_vec())),
)
.await
.is_err()
{
break;
}
}
Some(Ok(actix_ws::Message::Ping(bytes))) => {
touch_session(&session).await;
if ws_session.pong(&bytes).await.is_err() {
break;
}
}
Some(Ok(actix_ws::Message::Pong(_))) => {
touch_session(&session).await;
}
Some(Ok(actix_ws::Message::Close(reason))) => {
let _ = ws_session.close(reason).await;
break;
}
Some(Ok(actix_ws::Message::Continuation(_))) => {}
Some(Ok(actix_ws::Message::Nop)) => {}
Some(Err(_)) | None => break,
}
}
() = session.notify.notified() => {
for packet in session.drain().await {
if send_ws_packet(&mut ws_session, packet).await.is_err() {
io.remove_session(&session, DisconnectReason::TransportClosed).await;
return;
}
}
}
_ = heartbeat.tick() => {
if session.last_pong.lock().await.elapsed()
> io.config().ping_interval + io.config().ping_timeout
{
io.remove_session(&session, DisconnectReason::PingTimeout).await;
return;
}
if ws_session.text("2").await.is_err() {
io.remove_session(&session, DisconnectReason::TransportClosed).await;
return;
}
}
}
}
io.remove_session(&session, DisconnectReason::TransportClosed)
.await;
}
/// Records "now" as the last time the peer was seen alive.
async fn touch_session(session: &Session) {
    // Timestamp first, then lock: same order as evaluating `*lock = now()`.
    let now = std::time::Instant::now();
    let mut last_seen = session.last_pong.lock().await;
    *last_seen = now;
}
/// Writes one Engine.IO packet to a WebSocket.
///
/// Binary message payloads go out as raw binary frames; every other packet is
/// text-encoded with `encode_engine_packet(.., false)` and sent as a text frame.
async fn send_ws_packet(
    ws_session: &mut actix_ws::Session,
    packet: EnginePacket,
) -> std::result::Result<(), actix_ws::Closed> {
    match packet {
        EnginePacket::Message(SocketPayload::Binary(raw)) => ws_session.binary(raw).await,
        other => {
            let frame = encode_engine_packet(&other, false);
            ws_session.text(frame).await
        }
    }
}
async fn handle_engine_packet(
io: &SocketIo,
session: Arc<Session>,
packet: EnginePacket,
) -> Result<(), Error> {
match packet {
EnginePacket::Ping(data) => {
session.enqueue(EnginePacket::Pong(data)).await;
}
EnginePacket::Pong(_) => {
*session.last_pong.lock().await = std::time::Instant::now();
}
EnginePacket::Message(payload) => {
io.handle_socket_payload(session, payload)
.await
.map_err(map_socket_error)?;
}
EnginePacket::Close => {
io.remove_session(&session, DisconnectReason::Client).await;
}
EnginePacket::Open(_) => {
tracing::warn!("client sent unexpected Open packet");
}
EnginePacket::Upgrade | EnginePacket::Noop => {}
}
Ok(())
}
/// Accepts only Engine.IO protocol revision 4 (`EIO=4`); anything else is a 400.
fn validate_eio(query: &EngineQuery) -> Result<(), Error> {
    if query.eio.as_deref() == Some("4") {
        Ok(())
    } else {
        Err(ErrorBadRequest("unsupported EIO version"))
    }
}
/// Checks that `sid` is 1 to 128 bytes of printable, non-space ASCII
/// (`!`..=`~`); otherwise returns a 400.
fn validate_sid(sid: &str) -> Result<(), Error> {
    let well_formed = (1..=128).contains(&sid.len())
        && sid.bytes().all(|b| b.is_ascii_graphic());
    if well_formed {
        Ok(())
    } else {
        Err(ErrorBadRequest("invalid sid"))
    }
}
/// Builds a `200 OK` response with a UTF-8 `text/plain` body.
fn text_response(body: String) -> HttpResponse {
    let mut builder = HttpResponse::Ok();
    builder.insert_header((header::CONTENT_TYPE, "text/plain; charset=UTF-8"));
    builder.body(body)
}
/// Converts a `SocketIoError` into an HTTP error.
///
/// Unknown sessions/namespaces become 404, invalid packets become 400, and
/// every other error becomes 500.
fn map_socket_error(err: SocketIoError) -> Error {
    if matches!(
        err,
        SocketIoError::UnknownSession | SocketIoError::UnknownNamespace(_)
    ) {
        return ErrorNotFound(err);
    }
    if matches!(err, SocketIoError::InvalidPacket(_)) {
        ErrorBadRequest(err)
    } else {
        ErrorInternalServerError(err)
    }
}