// gitdataai/lib/socketio/server.rs
// (source-viewer metadata: 626 lines, 18 KiB, Rust)

use std::{
collections::{HashMap, HashSet},
future::Future,
pin::Pin,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
};
use serde::Serialize;
use serde_json::{Value, json};
use tokio::sync::RwLock;
use crate::{
adapter::{Adapter, BroadcastOptions, MemoryAdapter},
config::SocketIoConfig,
engine_packet::SocketPayload,
error::{Result, SocketIoError},
packet::{EventPayload, Packet, PacketType},
session::{PendingBinary, Session, SocketState},
socket::{AckSender, DisconnectReason, Socket},
};
/// Boxed, type-erased future returned by user callbacks that yield no value.
pub type BoxFuture = Pin<Box<dyn Future<Output = ()> + Send>>;

/// Callback run after a socket has been admitted to a namespace.
pub type ConnectHandler = Arc<dyn Fn(Socket) -> BoxFuture + Send + Sync>;

/// Callback run when a namespace socket goes away, along with the reason.
pub type DisconnectHandler = Arc<dyn Fn(Socket, DisconnectReason) -> BoxFuture + Send + Sync>;

/// Callback run for an incoming event registered under its name.
pub type EventHandler = Arc<dyn Fn(Socket, EventPayload) -> BoxFuture + Send + Sync>;

/// Connection guard run before a socket is admitted to a namespace.
///
/// Receives the prospective socket and the CONNECT packet payload; returning
/// `Err` refuses the connection, and the error's text is sent to the client
/// in a CONNECT_ERROR packet.
pub type Middleware = Arc<
    dyn Fn(Socket, Option<Value>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>>
        + Send
        + Sync,
>;
/// Cheaply cloneable handle to a Socket.IO server.
///
/// All clones share one [`Inner`] through an [`Arc`].
#[derive(Clone)]
pub struct SocketIo {
    /// Shared server state (config, sessions, namespaces, adapter).
    pub inner: Arc<Inner>,
}
/// Handle to a single namespace of a [`SocketIo`] server.
#[derive(Clone)]
pub struct Namespace {
    /// Server this namespace belongs to.
    pub io: SocketIo,
    /// Namespace name; [`SocketIo::namespace`] stores it normalized
    /// (e.g. `/` or `/chat`).
    pub name: String,
}
/// Builder for [`SocketIo`]; start one with [`SocketIo::builder`].
pub struct SocketIoBuilder {
    /// Server configuration ([`SocketIo::builder`] seeds it with the default).
    pub config: SocketIoConfig,
    /// Backend used for socket/room membership and publishing
    /// ([`SocketIo::builder`] seeds it with the in-memory adapter).
    pub adapter: Arc<dyn Adapter>,
}
/// State shared by every clone of a [`SocketIo`] handle.
pub struct Inner {
    /// Configuration the server was built with.
    pub config: SocketIoConfig,
    /// Live sessions, keyed by engine id.
    pub sessions: RwLock<HashMap<String, Arc<Session>>>,
    /// Registered namespaces, keyed by namespace name.
    pub namespaces: RwLock<HashMap<String, Arc<NamespaceState>>>,
    /// Backend used for membership lookups and `publish`.
    pub adapter: Arc<dyn Adapter>,
    /// Counter that hands out ids for server-initiated acknowledgements.
    pub next_ack_id: AtomicU64,
}
/// Callbacks and middleware registered on one namespace.
pub struct NamespaceState {
    /// Invoked after a socket has been admitted to the namespace.
    pub connect_handler: RwLock<Option<ConnectHandler>>,
    /// Invoked when a socket leaves the namespace.
    pub disconnect_handler: RwLock<Option<DisconnectHandler>>,
    /// Event callbacks, keyed by event name.
    pub event_handlers: RwLock<HashMap<String, EventHandler>>,
    /// Connection middleware, run in list order before admission.
    pub middleware: RwLock<Vec<Middleware>>,
}
impl Default for SocketIo {
fn default() -> Self {
Self::new()
}
}
impl SocketIo {
pub fn builder() -> SocketIoBuilder {
SocketIoBuilder {
config: SocketIoConfig::default(),
adapter: MemoryAdapter::new(),
}
}
/// Creates a server from an unmodified [`SocketIo::builder`].
pub fn new() -> Self {
    let builder = Self::builder();
    builder.build()
}
pub fn config(&self) -> &SocketIoConfig {
&self.inner.config
}
/// Returns a handle to namespace `name`, registering it if necessary.
///
/// The name is normalized first: empty becomes `/`, and a leading `/` is
/// added when missing.
pub async fn namespace(&self, name: impl Into<String>) -> Namespace {
    let normalized = normalize_namespace(name.into());
    self.ensure_namespace(&normalized).await;
    let io = self.clone();
    Namespace {
        io,
        name: normalized,
    }
}
/// Registers `handler` as the connect callback of the root namespace `/`.
pub async fn on_connect<F, Fut>(&self, handler: F)
where
    F: Fn(Socket) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    let root = self.namespace("/").await;
    root.on_connect(handler).await;
}
/// Registers `handler` as the disconnect callback of the root namespace `/`.
pub async fn on_disconnect<F, Fut>(&self, handler: F)
where
    F: Fn(Socket, DisconnectReason) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    let root = self.namespace("/").await;
    root.on_disconnect(handler).await;
}
/// Registers `handler` for `event` on the root namespace `/`.
pub async fn on<F, Fut>(&self, event: impl Into<String>, handler: F)
where
    F: Fn(Socket, EventPayload) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    let root = self.namespace("/").await;
    root.on(event, handler).await;
}
/// Delegates to [`Namespace::emit`] on the root namespace `/`.
pub async fn emit<T: Serialize>(&self, event: &str, data: T) -> Result<()> {
    let root = self.namespace("/").await;
    root.emit(event, data).await
}
/// Emits `event` with `data` to `room` on the root namespace `/`.
pub async fn emit_to_room<T: Serialize>(
    &self,
    room: &str,
    event: &str,
    data: T,
) -> Result<()> {
    let root = self.namespace("/").await;
    let target = root.to(room.to_owned());
    target.emit(event, data).await
}
/// Looks up a live session by its engine id.
pub async fn session(&self, sid: &str) -> Option<Arc<Session>> {
    let sessions = self.inner.sessions.read().await;
    sessions.get(sid).map(Arc::clone)
}
/// Registers `session` under its engine id, replacing any previous entry
/// that used the same id.
pub async fn insert_session(&self, session: Arc<Session>) {
    let key = session.engine_id.clone();
    let mut sessions = self.inner.sessions.write().await;
    sessions.insert(key, session);
}
/// Unregisters `session` and tears down every namespace socket it owns.
///
/// The session is removed from the session table first. Then
/// [`SocketIo::disconnect_socket`] runs for each namespace the session had
/// joined, passing `reason` to the disconnect handlers.
///
/// Teardown is best-effort: an error from one namespace (for example the
/// adapter failing to drop the socket) is ignored, so the remaining
/// namespaces are still disconnected.
///
/// Finally, outstanding acknowledgement waiters are dropped, so callers
/// blocked in [`SocketIo::emit_to_sid_with_ack`] get an error immediately
/// instead of waiting out the ack timeout. Any partially reassembled binary
/// packet is discarded as well.
pub async fn remove_session(
    &self,
    session: &Arc<Session>,
    reason: DisconnectReason,
) {
    self.inner.sessions.write().await.remove(&session.engine_id);

    // Snapshot the namespace names so the session lock is released before
    // `disconnect_socket` locks it again for each namespace.
    let namespaces: Vec<_> = session
        .namespaces
        .lock()
        .await
        .keys()
        .cloned()
        .collect();
    for namespace in namespaces {
        // Deliberately best-effort; see the doc comment above.
        let _ = self
            .disconnect_socket(&namespace, session, reason.clone())
            .await;
    }

    // Dropping the oneshot senders makes the waiting receivers resolve with
    // an error right away; no ack can arrive for a removed session.
    session.ack_waiters.lock().await.clear();
    *session.pending_binary.lock().await = None;
}
/// Entry point for one engine-level payload belonging to `session`.
///
/// Text frames are decoded as packets. A packet that announces binary
/// attachments is parked in `session.pending_binary`, replacing any packet
/// already parked there. Each later binary frame is appended to the parked
/// packet. Once its attachment count reaches the announced number, the
/// reassembled packet is dispatched.
///
/// # Errors
/// Fails if a text frame does not decode, or if a binary frame arrives while
/// no packet is waiting for attachments. Otherwise returns the result of
/// dispatching the packet.
pub async fn handle_socket_payload(
    &self,
    session: Arc<Session>,
    payload: SocketPayload,
) -> Result<()> {
    let bytes = match payload {
        SocketPayload::Binary(bytes) => bytes,
        SocketPayload::Text(text) => {
            let packet = Packet::decode(&text)?;
            if packet.expected_attachments > 0 {
                // Hold the header until its attachment frames arrive.
                *session.pending_binary.lock().await =
                    Some(PendingBinary { packet });
                return Ok(());
            }
            return self.handle_socket_packet(session, packet).await;
        }
    };

    let completed = {
        let mut slot = session.pending_binary.lock().await;
        let Some(mut pending_binary) = slot.take() else {
            return Err(SocketIoError::InvalidPacket(
                "unexpected binary attachment".to_owned(),
            ));
        };
        pending_binary.packet.attachments.push(bytes);
        let received = pending_binary.packet.attachments.len();
        if received != pending_binary.packet.expected_attachments {
            // Still incomplete: put it back and wait for the next frame.
            *slot = Some(pending_binary);
            return Ok(());
        }
        pending_binary.packet
    };
    // The pending-binary lock is released before dispatching.
    self.handle_socket_packet(session, completed).await
}
/// Routes a fully assembled packet to the handler for its packet type.
///
/// CONNECT_ERROR packets coming from the client are ignored.
async fn handle_socket_packet(
    &self,
    session: Arc<Session>,
    packet: Packet,
) -> Result<()> {
    match packet.packet_type {
        PacketType::ConnectError => Ok(()),
        PacketType::Ack | PacketType::BinaryAck => {
            self.resolve_ack(session, packet).await
        }
        PacketType::Event | PacketType::BinaryEvent => {
            self.dispatch_event(session, packet).await
        }
        PacketType::Connect => self.connect_namespace(session, packet).await,
        PacketType::Disconnect => {
            let reason = DisconnectReason::Client;
            self.disconnect_socket(&packet.namespace, &session, reason)
                .await
        }
    }
}
async fn connect_namespace(
&self,
session: Arc<Session>,
packet: Packet,
) -> Result<()> {
let namespace = normalize_namespace(packet.namespace);
let state = self.namespace_state(&namespace).await?;
let sid = uuid::Uuid::new_v4().to_string();
let socket = Socket {
io: self.clone(),
session: session.clone(),
namespace: namespace.clone(),
sid: sid.clone(),
};
for middleware in state.middleware.read().await.iter().cloned() {
if let Err(err) =
middleware(socket.clone(), packet.data.clone()).await
{
session
.enqueue_socket_packet(Packet::connect_error(
&namespace,
err.to_string(),
))
.await;
return Ok(());
}
}
self.inner
.adapter
.add_socket(&namespace, &session.engine_id)
.await?;
session.namespaces.lock().await.insert(
namespace.clone(),
SocketState {
sid,
rooms: HashSet::new(),
auth: packet.data,
},
);
session
.enqueue_socket_packet(Packet::connect(
&namespace,
Some(json!({ "sid": socket.sid })),
))
.await;
if let Some(handler) = state.connect_handler.read().await.clone() {
handler(socket).await;
}
Ok(())
}
/// Disconnects `session`'s socket from `namespace`, if it is connected.
///
/// The socket's local [`SocketState`] is removed and the adapter is asked
/// to forget the socket. The namespace's disconnect handler, if any, then
/// runs with `reason`. It runs even if the adapter call failed, because the
/// local state is already gone and application cleanup must still happen.
/// The adapter error, if any, is returned afterwards.
///
/// Returns `Ok(())` without side effects when the session is not connected
/// to `namespace`.
///
/// The handler is cloned out of its lock before being awaited, so no
/// namespace lock is held while user code runs.
pub async fn disconnect_socket(
    &self,
    namespace: &str,
    session: &Arc<Session>,
    reason: DisconnectReason,
) -> Result<()> {
    let namespace = normalize_namespace(namespace.to_owned());
    let removed = session.namespaces.lock().await.remove(&namespace);
    let Some(socket_state) = removed else {
        return Ok(());
    };

    // Do not short-circuit: the handler must still be told about the
    // disconnect.
    let adapter_result = self
        .inner
        .adapter
        .remove_socket(&namespace, &session.engine_id)
        .await;

    if let Ok(state) = self.namespace_state(&namespace).await {
        let handler = state.disconnect_handler.read().await.clone();
        if let Some(handler) = handler {
            let socket = Socket {
                io: self.clone(),
                session: session.clone(),
                namespace,
                sid: socket_state.sid,
            };
            handler(socket, reason).await;
        }
    }

    adapter_result?;
    Ok(())
}
/// Delivers an EVENT/BINARY_EVENT packet to the registered event handler.
///
/// Requires the session to be connected to the packet's namespace; otherwise
/// returns [`SocketIoError::UnknownNamespace`]. When the packet carries an
/// id, an [`AckSender`] is attached so the handler can reply. Events with no
/// registered handler are dropped silently.
async fn dispatch_event(
    &self,
    session: Arc<Session>,
    packet: Packet,
) -> Result<()> {
    let namespace = normalize_namespace(packet.namespace.clone());
    let state = self.namespace_state(&namespace).await?;

    let sid = {
        let sockets = session.namespaces.lock().await;
        match sockets.get(&namespace) {
            Some(socket_state) => socket_state.sid.clone(),
            None => {
                return Err(SocketIoError::UnknownNamespace(namespace.clone()));
            }
        }
    };

    let ack = packet
        .id
        .map(|id| AckSender::new(session.clone(), namespace.clone(), id));
    let payload = packet.into_event_payload(ack)?;

    let handler = {
        let handlers = state.event_handlers.read().await;
        handlers.get(&payload.event).cloned()
    };
    let Some(handler) = handler else {
        return Ok(());
    };

    let socket = Socket {
        io: self.clone(),
        session,
        namespace,
        sid,
    };
    handler(socket, payload).await;
    Ok(())
}
async fn resolve_ack(
&self,
session: Arc<Session>,
packet: Packet,
) -> Result<()> {
let Some(id) = packet.id else {
return Ok(());
};
let args = match packet.data {
Some(Value::Array(values)) => values,
Some(value) => vec![value],
None => Vec::new(),
};
if let Some(sender) = session
.ack_waiters
.lock()
.await
.remove(&(normalize_namespace(packet.namespace), id))
{
let _ = sender.send(args);
}
Ok(())
}
/// Adds the socket of session `engine_id` in `namespace` to `room`.
///
/// The room is recorded in the session's local [`SocketState`] when the
/// session is connected to that namespace. The adapter membership is updated
/// either way.
///
/// # Errors
/// [`SocketIoError::UnknownSession`] if no session has this engine id, or
/// whatever the adapter returns.
pub async fn join(
    &self,
    namespace: &str,
    engine_id: &str,
    room: String,
) -> Result<()> {
    let namespace = normalize_namespace(namespace.to_owned());
    let Some(session) = self.session(engine_id).await else {
        return Err(SocketIoError::UnknownSession);
    };
    {
        let mut sockets = session.namespaces.lock().await;
        if let Some(socket_state) = sockets.get_mut(&namespace) {
            socket_state.rooms.insert(room.clone());
        }
    }
    self.inner
        .adapter
        .add_to_room(&namespace, engine_id, &room)
        .await
}
/// Removes the socket of session `engine_id` in `namespace` from `room`.
///
/// The room is dropped from the local [`SocketState`] when present, and the
/// adapter membership is removed either way.
///
/// # Errors
/// [`SocketIoError::UnknownSession`] if no session has this engine id, or
/// whatever the adapter returns.
pub async fn leave(
    &self,
    namespace: &str,
    engine_id: &str,
    room: &str,
) -> Result<()> {
    let namespace = normalize_namespace(namespace.to_owned());
    let Some(session) = self.session(engine_id).await else {
        return Err(SocketIoError::UnknownSession);
    };
    {
        let mut sockets = session.namespaces.lock().await;
        if let Some(socket_state) = sockets.get_mut(&namespace) {
            socket_state.rooms.remove(room);
        }
    }
    self.inner
        .adapter
        .remove_from_room(&namespace, engine_id, room)
        .await
}
/// Serializes `data` and queues `event` for the single session `engine_id`.
///
/// A JSON array is spread into separate event arguments; any other value is
/// sent as one argument.
pub async fn emit_to_sid<T: Serialize>(
    &self,
    namespace: &str,
    engine_id: &str,
    event: &str,
    data: T,
) -> Result<()> {
    let value = serde_json::to_value(data)?;
    let packet = Packet::event(namespace, event, value_to_args(value));
    self.emit_packet_to_sid(engine_id, packet).await
}
/// Queues `event` with JSON `args` plus binary attachments for the single
/// session `engine_id`.
pub async fn emit_binary_to_sid(
    &self,
    namespace: &str,
    engine_id: &str,
    event: &str,
    args: Vec<Value>,
    binary: Vec<Vec<u8>>,
) -> Result<()> {
    let packet = Packet::event(namespace, event, args).with_binary(binary);
    self.emit_packet_to_sid(engine_id, packet).await
}
/// Sends `event` to session `engine_id` and waits for the client's ack.
///
/// A fresh ack id is allocated and a waiter is registered under
/// (normalized namespace, id) before the packet is queued. The packet is
/// addressed to the same normalized namespace so the reply matches the key.
/// Resolves with the ack's arguments.
///
/// `data` is serialized *before* the waiter is registered, so a
/// serialization error cannot leave a dangling waiter in the session.
///
/// # Errors
/// - [`SocketIoError::UnknownSession`] if no session has this engine id.
/// - A serialization error from `serde_json`.
/// - [`SocketIoError::AckTimeout`] if no ack arrives within
///   `config.ack_timeout`, or the waiter is dropped first (e.g. the session
///   is torn down). The waiter entry is removed in that case.
pub async fn emit_to_sid_with_ack<T: Serialize>(
    &self,
    namespace: &str,
    engine_id: &str,
    event: &str,
    data: T,
) -> Result<Vec<Value>> {
    let session = self
        .session(engine_id)
        .await
        .ok_or(SocketIoError::UnknownSession)?;
    let namespace = normalize_namespace(namespace.to_owned());
    let args = value_to_args(serde_json::to_value(data)?);

    let id = self.inner.next_ack_id.fetch_add(1, Ordering::Relaxed);
    let key = (namespace.clone(), id);
    let (tx, rx) = tokio::sync::oneshot::channel();
    session.ack_waiters.lock().await.insert(key.clone(), tx);

    let mut packet = Packet::event(&namespace, event, args);
    packet.id = Some(id);
    session.enqueue_socket_packet(packet).await;

    match tokio::time::timeout(self.inner.config.ack_timeout, rx).await {
        Ok(Ok(values)) => Ok(values),
        _ => {
            session.ack_waiters.lock().await.remove(&key);
            Err(SocketIoError::AckTimeout)
        }
    }
}
/// Serializes `data` and broadcasts `event` to the sockets selected by
/// `opts`, after normalizing `opts.namespace`.
pub async fn broadcast_with_opts<T: Serialize>(
    &self,
    mut opts: BroadcastOptions,
    event: &str,
    data: T,
) -> Result<()> {
    opts.namespace = normalize_namespace(opts.namespace);
    let args = value_to_args(serde_json::to_value(data)?);
    let packet = Packet::event(&opts.namespace, event, args);
    self.broadcast_packet(opts, packet).await
}
/// Queues `packet` for every local socket the adapter selects for `opts`,
/// then hands it to the adapter's `publish`.
///
/// Per-socket delivery failures and a publish failure are collected rather
/// than aborting early. If any occurred, a single
/// [`SocketIoError::Adapter`] listing all of them is returned.
async fn broadcast_packet(
    &self,
    opts: BroadcastOptions,
    packet: Packet,
) -> Result<()> {
    let targets = self.inner.adapter.sockets(&opts.namespace, &opts).await?;
    let mut failures: Vec<String> = Vec::new();
    for engine_id in targets {
        let delivered = self.emit_packet_to_sid(&engine_id, packet.clone()).await;
        if let Err(err) = delivered {
            failures.push(format!("{engine_id}: {err}"));
        }
    }
    let published = self.inner.adapter.publish(&packet, &opts).await;
    if let Err(err) = published {
        failures.push(format!("adapter publish: {err}"));
    }
    if !failures.is_empty() {
        return Err(SocketIoError::Adapter(format!(
            "broadcast partially failed: {}",
            failures.join("; ")
        )));
    }
    Ok(())
}
/// Queues `packet` for every local socket the adapter selects for `opts`,
/// without publishing it again.
///
/// Presumably called for packets that arrived through the adapter from
/// elsewhere (TODO confirm against the adapter implementation). Per-socket
/// failures are collected and reported together as
/// [`SocketIoError::Adapter`].
pub async fn deliver_remote_packet(
    &self,
    opts: BroadcastOptions,
    packet: Packet,
) -> Result<()> {
    let mut failures = Vec::new();
    for engine_id in self.inner.adapter.sockets(&opts.namespace, &opts).await? {
        let delivered = self.emit_packet_to_sid(&engine_id, packet.clone()).await;
        if let Err(err) = delivered {
            failures.push(format!("{engine_id}: {err}"));
        }
    }
    if !failures.is_empty() {
        return Err(SocketIoError::Adapter(format!(
            "remote broadcast partially failed: {}",
            failures.join("; ")
        )));
    }
    Ok(())
}
/// Queues `packet` on the session with engine id `engine_id`.
///
/// # Errors
/// [`SocketIoError::UnknownSession`] if no such session is registered.
async fn emit_packet_to_sid(
    &self,
    engine_id: &str,
    packet: Packet,
) -> Result<()> {
    match self.session(engine_id).await {
        None => Err(SocketIoError::UnknownSession),
        Some(session) => {
            session.enqueue_socket_packet(packet).await;
            Ok(())
        }
    }
}
pub async fn ensure_namespace(
&self,
namespace: &str,
) -> Arc<NamespaceState> {
let mut namespaces = self.inner.namespaces.write().await;
namespaces
.entry(namespace.to_owned())
.or_insert_with(|| Arc::new(NamespaceState::default()))
.clone()
}
/// Looks up an already registered namespace by its exact name.
///
/// # Errors
/// [`SocketIoError::UnknownNamespace`] when nothing is registered under
/// `namespace`.
async fn namespace_state(
    &self,
    namespace: &str,
) -> Result<Arc<NamespaceState>> {
    let namespaces = self.inner.namespaces.read().await;
    match namespaces.get(namespace) {
        Some(state) => Ok(Arc::clone(state)),
        None => Err(SocketIoError::UnknownNamespace(namespace.to_owned())),
    }
}
}
/// Turns a serialized payload into an event argument list: a JSON array
/// supplies one argument per element, and any other value is a single
/// argument.
fn value_to_args(value: Value) -> Vec<Value> {
    if let Value::Array(items) = value {
        items
    } else {
        vec![value]
    }
}
/// Canonicalizes a namespace name: empty becomes `/`, names already starting
/// with `/` (including `/` itself) are kept, and anything else gets a
/// leading `/`.
fn normalize_namespace(namespace: String) -> String {
    if namespace.is_empty() {
        return String::from("/");
    }
    if namespace.starts_with('/') {
        return namespace;
    }
    let mut prefixed = String::with_capacity(namespace.len() + 1);
    prefixed.push('/');
    prefixed.push_str(&namespace);
    prefixed
}