626 lines
18 KiB
Rust
626 lines
18 KiB
Rust
use std::{
|
|
collections::{HashMap, HashSet},
|
|
future::Future,
|
|
pin::Pin,
|
|
sync::{
|
|
Arc,
|
|
atomic::{AtomicU64, Ordering},
|
|
},
|
|
};
|
|
|
|
use serde::Serialize;
|
|
use serde_json::{Value, json};
|
|
use tokio::sync::RwLock;
|
|
|
|
use crate::{
|
|
adapter::{Adapter, BroadcastOptions, MemoryAdapter},
|
|
config::SocketIoConfig,
|
|
engine_packet::SocketPayload,
|
|
error::{Result, SocketIoError},
|
|
packet::{EventPayload, Packet, PacketType},
|
|
session::{PendingBinary, Session, SocketState},
|
|
socket::{AckSender, DisconnectReason, Socket},
|
|
};
|
|
|
|
pub type BoxFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
|
|
pub type ConnectHandler = Arc<dyn Fn(Socket) -> BoxFuture + Send + Sync>;
|
|
pub type DisconnectHandler =
|
|
Arc<dyn Fn(Socket, DisconnectReason) -> BoxFuture + Send + Sync>;
|
|
pub type EventHandler =
|
|
Arc<dyn Fn(Socket, EventPayload) -> BoxFuture + Send + Sync>;
|
|
pub type Middleware = Arc<
|
|
dyn Fn(
|
|
Socket,
|
|
Option<Value>,
|
|
) -> Pin<Box<dyn Future<Output = Result<()>> + Send>>
|
|
+ Send
|
|
+ Sync,
|
|
>;
|
|
|
|
#[derive(Clone)]
|
|
pub struct SocketIo {
|
|
pub inner: Arc<Inner>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct Namespace {
|
|
pub io: SocketIo,
|
|
pub name: String,
|
|
}
|
|
|
|
pub struct SocketIoBuilder {
|
|
pub config: SocketIoConfig,
|
|
pub adapter: Arc<dyn Adapter>,
|
|
}
|
|
|
|
pub struct Inner {
|
|
pub config: SocketIoConfig,
|
|
pub sessions: RwLock<HashMap<String, Arc<Session>>>,
|
|
pub namespaces: RwLock<HashMap<String, Arc<NamespaceState>>>,
|
|
pub adapter: Arc<dyn Adapter>,
|
|
pub next_ack_id: AtomicU64,
|
|
}
|
|
|
|
pub struct NamespaceState {
|
|
pub connect_handler: RwLock<Option<ConnectHandler>>,
|
|
pub disconnect_handler: RwLock<Option<DisconnectHandler>>,
|
|
pub event_handlers: RwLock<HashMap<String, EventHandler>>,
|
|
pub middleware: RwLock<Vec<Middleware>>,
|
|
}
|
|
|
|
impl Default for SocketIo {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl SocketIo {
|
|
pub fn builder() -> SocketIoBuilder {
|
|
SocketIoBuilder {
|
|
config: SocketIoConfig::default(),
|
|
adapter: MemoryAdapter::new(),
|
|
}
|
|
}
|
|
|
|
pub fn new() -> Self {
|
|
Self::builder().build()
|
|
}
|
|
|
|
pub fn config(&self) -> &SocketIoConfig {
|
|
&self.inner.config
|
|
}
|
|
|
|
pub async fn namespace(&self, name: impl Into<String>) -> Namespace {
|
|
let name = normalize_namespace(name.into());
|
|
self.ensure_namespace(&name).await;
|
|
Namespace {
|
|
io: self.clone(),
|
|
name,
|
|
}
|
|
}
|
|
|
|
pub async fn on_connect<F, Fut>(&self, handler: F)
|
|
where
|
|
F: Fn(Socket) -> Fut + Send + Sync + 'static,
|
|
Fut: Future<Output = ()> + Send + 'static,
|
|
{
|
|
self.namespace("/").await.on_connect(handler).await;
|
|
}
|
|
|
|
pub async fn on_disconnect<F, Fut>(&self, handler: F)
|
|
where
|
|
F: Fn(Socket, DisconnectReason) -> Fut + Send + Sync + 'static,
|
|
Fut: Future<Output = ()> + Send + 'static,
|
|
{
|
|
self.namespace("/").await.on_disconnect(handler).await;
|
|
}
|
|
|
|
pub async fn on<F, Fut>(&self, event: impl Into<String>, handler: F)
|
|
where
|
|
F: Fn(Socket, EventPayload) -> Fut + Send + Sync + 'static,
|
|
Fut: Future<Output = ()> + Send + 'static,
|
|
{
|
|
self.namespace("/").await.on(event, handler).await;
|
|
}
|
|
|
|
/// Emits `event` with `data` through the root namespace `/`.
///
/// # Errors
/// Returns whatever `Namespace::emit` reports.
pub async fn emit<T: Serialize>(&self, event: &str, data: T) -> Result<()> {
    let root = self.namespace("/").await;
    root.emit(event, data).await
}
|
|
|
|
pub async fn emit_to_room<T: Serialize>(
|
|
&self,
|
|
room: &str,
|
|
event: &str,
|
|
data: T,
|
|
) -> Result<()> {
|
|
self.namespace("/")
|
|
.await
|
|
.to(room.to_owned())
|
|
.emit(event, data)
|
|
.await
|
|
}
|
|
|
|
/// Looks up the session stored under `sid`, returning a new `Arc` handle to it
/// if present.
pub async fn session(&self, sid: &str) -> Option<Arc<Session>> {
    let sessions = self.inner.sessions.read().await;
    sessions.get(sid).cloned()
}
|
|
|
|
/// Stores `session` under its `engine_id`, replacing any session previously
/// stored under that id.
pub async fn insert_session(&self, session: Arc<Session>) {
    let mut sessions = self.inner.sessions.write().await;
    let key = session.engine_id.clone();
    sessions.insert(key, session);
}
|
|
|
|
pub async fn remove_session(
|
|
&self,
|
|
session: &Arc<Session>,
|
|
reason: DisconnectReason,
|
|
) {
|
|
self.inner.sessions.write().await.remove(&session.engine_id);
|
|
let namespaces = session
|
|
.namespaces
|
|
.lock()
|
|
.await
|
|
.keys()
|
|
.cloned()
|
|
.collect::<Vec<_>>();
|
|
for namespace in namespaces {
|
|
let _ = self
|
|
.disconnect_socket(&namespace, session, reason.clone())
|
|
.await;
|
|
}
|
|
}
|
|
|
|
pub async fn handle_socket_payload(
|
|
&self,
|
|
session: Arc<Session>,
|
|
payload: SocketPayload,
|
|
) -> Result<()> {
|
|
match payload {
|
|
SocketPayload::Text(text) => {
|
|
let packet = Packet::decode(&text)?;
|
|
if packet.expected_attachments == 0 {
|
|
self.handle_socket_packet(session, packet).await
|
|
} else {
|
|
*session.pending_binary.lock().await =
|
|
Some(PendingBinary { packet });
|
|
Ok(())
|
|
}
|
|
}
|
|
SocketPayload::Binary(bytes) => {
|
|
let mut pending = session.pending_binary.lock().await;
|
|
let Some(mut pending_binary) = pending.take() else {
|
|
return Err(SocketIoError::InvalidPacket(
|
|
"unexpected binary attachment".to_owned(),
|
|
));
|
|
};
|
|
pending_binary.packet.attachments.push(bytes);
|
|
if pending_binary.packet.attachments.len()
|
|
== pending_binary.packet.expected_attachments
|
|
{
|
|
drop(pending);
|
|
self.handle_socket_packet(session, pending_binary.packet)
|
|
.await
|
|
} else {
|
|
*pending = Some(pending_binary);
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_socket_packet(
|
|
&self,
|
|
session: Arc<Session>,
|
|
packet: Packet,
|
|
) -> Result<()> {
|
|
match packet.packet_type {
|
|
PacketType::Connect => {
|
|
self.connect_namespace(session, packet).await
|
|
}
|
|
PacketType::Disconnect => {
|
|
self.disconnect_socket(
|
|
&packet.namespace,
|
|
&session,
|
|
DisconnectReason::Client,
|
|
)
|
|
.await
|
|
}
|
|
PacketType::Event | PacketType::BinaryEvent => {
|
|
self.dispatch_event(session, packet).await
|
|
}
|
|
PacketType::Ack | PacketType::BinaryAck => {
|
|
self.resolve_ack(session, packet).await
|
|
}
|
|
PacketType::ConnectError => Ok(()),
|
|
}
|
|
}
|
|
|
|
async fn connect_namespace(
|
|
&self,
|
|
session: Arc<Session>,
|
|
packet: Packet,
|
|
) -> Result<()> {
|
|
let namespace = normalize_namespace(packet.namespace);
|
|
let state = self.namespace_state(&namespace).await?;
|
|
let sid = uuid::Uuid::new_v4().to_string();
|
|
let socket = Socket {
|
|
io: self.clone(),
|
|
session: session.clone(),
|
|
namespace: namespace.clone(),
|
|
sid: sid.clone(),
|
|
};
|
|
|
|
for middleware in state.middleware.read().await.iter().cloned() {
|
|
if let Err(err) =
|
|
middleware(socket.clone(), packet.data.clone()).await
|
|
{
|
|
session
|
|
.enqueue_socket_packet(Packet::connect_error(
|
|
&namespace,
|
|
err.to_string(),
|
|
))
|
|
.await;
|
|
return Ok(());
|
|
}
|
|
}
|
|
|
|
self.inner
|
|
.adapter
|
|
.add_socket(&namespace, &session.engine_id)
|
|
.await?;
|
|
session.namespaces.lock().await.insert(
|
|
namespace.clone(),
|
|
SocketState {
|
|
sid,
|
|
rooms: HashSet::new(),
|
|
auth: packet.data,
|
|
},
|
|
);
|
|
session
|
|
.enqueue_socket_packet(Packet::connect(
|
|
&namespace,
|
|
Some(json!({ "sid": socket.sid })),
|
|
))
|
|
.await;
|
|
|
|
if let Some(handler) = state.connect_handler.read().await.clone() {
|
|
handler(socket).await;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn disconnect_socket(
|
|
&self,
|
|
namespace: &str,
|
|
session: &Arc<Session>,
|
|
reason: DisconnectReason,
|
|
) -> Result<()> {
|
|
let namespace = normalize_namespace(namespace.to_owned());
|
|
let removed = session.namespaces.lock().await.remove(&namespace);
|
|
if let Some(socket_state) = removed {
|
|
self.inner
|
|
.adapter
|
|
.remove_socket(&namespace, &session.engine_id)
|
|
.await?;
|
|
if let Ok(state) = self.namespace_state(&namespace).await
|
|
&& let Some(handler) =
|
|
state.disconnect_handler.read().await.clone()
|
|
{
|
|
handler(
|
|
Socket {
|
|
io: self.clone(),
|
|
session: session.clone(),
|
|
namespace,
|
|
sid: socket_state.sid,
|
|
},
|
|
reason,
|
|
)
|
|
.await;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn dispatch_event(
|
|
&self,
|
|
session: Arc<Session>,
|
|
packet: Packet,
|
|
) -> Result<()> {
|
|
let namespace = normalize_namespace(packet.namespace.clone());
|
|
let state = self.namespace_state(&namespace).await?;
|
|
let socket_state = session
|
|
.namespaces
|
|
.lock()
|
|
.await
|
|
.get(&namespace)
|
|
.map(|state| state.sid.clone())
|
|
.ok_or_else(|| {
|
|
SocketIoError::UnknownNamespace(namespace.clone())
|
|
})?;
|
|
let ack = packet
|
|
.id
|
|
.map(|id| AckSender::new(session.clone(), namespace.clone(), id));
|
|
let payload = packet.into_event_payload(ack)?;
|
|
let handler = state
|
|
.event_handlers
|
|
.read()
|
|
.await
|
|
.get(&payload.event)
|
|
.cloned();
|
|
if let Some(handler) = handler {
|
|
handler(
|
|
Socket {
|
|
io: self.clone(),
|
|
session,
|
|
namespace,
|
|
sid: socket_state,
|
|
},
|
|
payload,
|
|
)
|
|
.await;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn resolve_ack(
|
|
&self,
|
|
session: Arc<Session>,
|
|
packet: Packet,
|
|
) -> Result<()> {
|
|
let Some(id) = packet.id else {
|
|
return Ok(());
|
|
};
|
|
let args = match packet.data {
|
|
Some(Value::Array(values)) => values,
|
|
Some(value) => vec![value],
|
|
None => Vec::new(),
|
|
};
|
|
if let Some(sender) = session
|
|
.ack_waiters
|
|
.lock()
|
|
.await
|
|
.remove(&(normalize_namespace(packet.namespace), id))
|
|
{
|
|
let _ = sender.send(args);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn join(
|
|
&self,
|
|
namespace: &str,
|
|
engine_id: &str,
|
|
room: String,
|
|
) -> Result<()> {
|
|
let namespace = normalize_namespace(namespace.to_owned());
|
|
let session = self
|
|
.session(engine_id)
|
|
.await
|
|
.ok_or(SocketIoError::UnknownSession)?;
|
|
if let Some(state) = session.namespaces.lock().await.get_mut(&namespace)
|
|
{
|
|
state.rooms.insert(room.clone());
|
|
}
|
|
self.inner
|
|
.adapter
|
|
.add_to_room(&namespace, engine_id, &room)
|
|
.await
|
|
}
|
|
|
|
pub async fn leave(
|
|
&self,
|
|
namespace: &str,
|
|
engine_id: &str,
|
|
room: &str,
|
|
) -> Result<()> {
|
|
let namespace = normalize_namespace(namespace.to_owned());
|
|
let session = self
|
|
.session(engine_id)
|
|
.await
|
|
.ok_or(SocketIoError::UnknownSession)?;
|
|
if let Some(state) = session.namespaces.lock().await.get_mut(&namespace)
|
|
{
|
|
state.rooms.remove(room);
|
|
}
|
|
self.inner
|
|
.adapter
|
|
.remove_from_room(&namespace, engine_id, room)
|
|
.await
|
|
}
|
|
|
|
pub async fn emit_to_sid<T: Serialize>(
|
|
&self,
|
|
namespace: &str,
|
|
engine_id: &str,
|
|
event: &str,
|
|
data: T,
|
|
) -> Result<()> {
|
|
let args = value_to_args(serde_json::to_value(data)?);
|
|
self.emit_packet_to_sid(
|
|
engine_id,
|
|
Packet::event(namespace, event, args),
|
|
)
|
|
.await
|
|
}
|
|
|
|
pub async fn emit_binary_to_sid(
|
|
&self,
|
|
namespace: &str,
|
|
engine_id: &str,
|
|
event: &str,
|
|
args: Vec<Value>,
|
|
binary: Vec<Vec<u8>>,
|
|
) -> Result<()> {
|
|
self.emit_packet_to_sid(
|
|
engine_id,
|
|
Packet::event(namespace, event, args).with_binary(binary),
|
|
)
|
|
.await
|
|
}
|
|
|
|
pub async fn emit_to_sid_with_ack<T: Serialize>(
|
|
&self,
|
|
namespace: &str,
|
|
engine_id: &str,
|
|
event: &str,
|
|
data: T,
|
|
) -> Result<Vec<Value>> {
|
|
let session = self
|
|
.session(engine_id)
|
|
.await
|
|
.ok_or(SocketIoError::UnknownSession)?;
|
|
let id = self.inner.next_ack_id.fetch_add(1, Ordering::Relaxed);
|
|
let (tx, rx) = tokio::sync::oneshot::channel();
|
|
session
|
|
.ack_waiters
|
|
.lock()
|
|
.await
|
|
.insert((normalize_namespace(namespace.to_owned()), id), tx);
|
|
let mut packet = Packet::event(
|
|
namespace,
|
|
event,
|
|
value_to_args(serde_json::to_value(data)?),
|
|
);
|
|
packet.id = Some(id);
|
|
session.enqueue_socket_packet(packet).await;
|
|
|
|
match tokio::time::timeout(self.inner.config.ack_timeout, rx).await {
|
|
Ok(Ok(values)) => Ok(values),
|
|
_ => {
|
|
session
|
|
.ack_waiters
|
|
.lock()
|
|
.await
|
|
.remove(&(normalize_namespace(namespace.to_owned()), id));
|
|
Err(SocketIoError::AckTimeout)
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn broadcast_with_opts<T: Serialize>(
|
|
&self,
|
|
mut opts: BroadcastOptions,
|
|
event: &str,
|
|
data: T,
|
|
) -> Result<()> {
|
|
opts.namespace = normalize_namespace(opts.namespace);
|
|
let packet = Packet::event(
|
|
&opts.namespace,
|
|
event,
|
|
value_to_args(serde_json::to_value(data)?),
|
|
);
|
|
self.broadcast_packet(opts, packet).await
|
|
}
|
|
|
|
async fn broadcast_packet(
|
|
&self,
|
|
opts: BroadcastOptions,
|
|
packet: Packet,
|
|
) -> Result<()> {
|
|
let sockets =
|
|
self.inner.adapter.sockets(&opts.namespace, &opts).await?;
|
|
let mut failures = Vec::new();
|
|
for engine_id in sockets {
|
|
if let Err(err) =
|
|
self.emit_packet_to_sid(&engine_id, packet.clone()).await
|
|
{
|
|
failures.push(format!("{engine_id}: {err}"));
|
|
}
|
|
}
|
|
if let Err(err) = self.inner.adapter.publish(&packet, &opts).await {
|
|
failures.push(format!("adapter publish: {err}"));
|
|
}
|
|
if failures.is_empty() {
|
|
Ok(())
|
|
} else {
|
|
Err(SocketIoError::Adapter(format!(
|
|
"broadcast partially failed: {}",
|
|
failures.join("; ")
|
|
)))
|
|
}
|
|
}
|
|
|
|
pub async fn deliver_remote_packet(
|
|
&self,
|
|
opts: BroadcastOptions,
|
|
packet: Packet,
|
|
) -> Result<()> {
|
|
let sockets =
|
|
self.inner.adapter.sockets(&opts.namespace, &opts).await?;
|
|
let mut failures = Vec::new();
|
|
for engine_id in sockets {
|
|
if let Err(err) =
|
|
self.emit_packet_to_sid(&engine_id, packet.clone()).await
|
|
{
|
|
failures.push(format!("{engine_id}: {err}"));
|
|
}
|
|
}
|
|
if failures.is_empty() {
|
|
Ok(())
|
|
} else {
|
|
Err(SocketIoError::Adapter(format!(
|
|
"remote broadcast partially failed: {}",
|
|
failures.join("; ")
|
|
)))
|
|
}
|
|
}
|
|
|
|
async fn emit_packet_to_sid(
|
|
&self,
|
|
engine_id: &str,
|
|
packet: Packet,
|
|
) -> Result<()> {
|
|
let session = self
|
|
.session(engine_id)
|
|
.await
|
|
.ok_or(SocketIoError::UnknownSession)?;
|
|
session.enqueue_socket_packet(packet).await;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn ensure_namespace(
|
|
&self,
|
|
namespace: &str,
|
|
) -> Arc<NamespaceState> {
|
|
let mut namespaces = self.inner.namespaces.write().await;
|
|
namespaces
|
|
.entry(namespace.to_owned())
|
|
.or_insert_with(|| Arc::new(NamespaceState::default()))
|
|
.clone()
|
|
}
|
|
|
|
async fn namespace_state(
|
|
&self,
|
|
namespace: &str,
|
|
) -> Result<Arc<NamespaceState>> {
|
|
self.inner
|
|
.namespaces
|
|
.read()
|
|
.await
|
|
.get(namespace)
|
|
.cloned()
|
|
.ok_or_else(|| {
|
|
SocketIoError::UnknownNamespace(namespace.to_owned())
|
|
})
|
|
}
|
|
}
|
|
|
|
fn value_to_args(value: Value) -> Vec<Value> {
|
|
match value {
|
|
Value::Array(values) => values,
|
|
value => vec![value],
|
|
}
|
|
}
|
|
|
|
/// Returns the canonical form of a namespace name.
///
/// * an empty name becomes the root namespace `/`;
/// * a name that already starts with `/` (including `/` itself) is returned
///   unchanged, reusing its allocation;
/// * any other name gets a `/` prepended.
///
/// The function is idempotent: normalizing an already normalized name returns
/// it as is.
fn normalize_namespace(namespace: String) -> String {
    if namespace.is_empty() {
        "/".to_owned()
    } else if namespace.starts_with('/') {
        namespace
    } else {
        format!("/{namespace}")
    }
}
|