use std::{ collections::{HashMap, HashSet}, future::Future, pin::Pin, sync::{ Arc, atomic::{AtomicU64, Ordering}, }, }; use serde::Serialize; use serde_json::{Value, json}; use tokio::sync::RwLock; use crate::{ adapter::{Adapter, BroadcastOptions, MemoryAdapter}, config::SocketIoConfig, engine_packet::SocketPayload, error::{Result, SocketIoError}, packet::{EventPayload, Packet, PacketType}, session::{PendingBinary, Session, SocketState}, socket::{AckSender, DisconnectReason, Socket}, }; pub type BoxFuture = Pin + Send>>; pub type ConnectHandler = Arc BoxFuture + Send + Sync>; pub type DisconnectHandler = Arc BoxFuture + Send + Sync>; pub type EventHandler = Arc BoxFuture + Send + Sync>; pub type Middleware = Arc< dyn Fn( Socket, Option, ) -> Pin> + Send>> + Send + Sync, >; #[derive(Clone)] pub struct SocketIo { pub inner: Arc, } #[derive(Clone)] pub struct Namespace { pub io: SocketIo, pub name: String, } pub struct SocketIoBuilder { pub config: SocketIoConfig, pub adapter: Arc, } pub struct Inner { pub config: SocketIoConfig, pub sessions: RwLock>>, pub namespaces: RwLock>>, pub adapter: Arc, pub next_ack_id: AtomicU64, } pub struct NamespaceState { pub connect_handler: RwLock>, pub disconnect_handler: RwLock>, pub event_handlers: RwLock>, pub middleware: RwLock>, } impl Default for SocketIo { fn default() -> Self { Self::new() } } impl SocketIo { pub fn builder() -> SocketIoBuilder { SocketIoBuilder { config: SocketIoConfig::default(), adapter: MemoryAdapter::new(), } } pub fn new() -> Self { Self::builder().build() } pub fn config(&self) -> &SocketIoConfig { &self.inner.config } pub async fn namespace(&self, name: impl Into) -> Namespace { let name = normalize_namespace(name.into()); self.ensure_namespace(&name).await; Namespace { io: self.clone(), name, } } pub async fn on_connect(&self, handler: F) where F: Fn(Socket) -> Fut + Send + Sync + 'static, Fut: Future + Send + 'static, { self.namespace("/").await.on_connect(handler).await; } pub async fn on_disconnect(&self, handler: F) where F: Fn(Socket, DisconnectReason) -> Fut + Send + Sync + 'static, Fut: Future + Send + 'static, { self.namespace("/").await.on_disconnect(handler).await; } pub async fn on(&self, event: impl Into, handler: F) where F: Fn(Socket, EventPayload) -> Fut + Send + Sync + 'static, Fut: Future + Send + 'static, { self.namespace("/").await.on(event, handler).await; } pub async fn emit(&self, event: &str, data: T) -> Result<()> { self.namespace("/").await.emit(event, data).await } pub async fn emit_to_room( &self, room: &str, event: &str, data: T, ) -> Result<()> { self.namespace("/") .await .to(room.to_owned()) .emit(event, data) .await } pub async fn session(&self, sid: &str) -> Option> { self.inner.sessions.read().await.get(sid).cloned() } pub async fn insert_session(&self, session: Arc) { self.inner .sessions .write() .await .insert(session.engine_id.clone(), session); } pub async fn remove_session( &self, session: &Arc, reason: DisconnectReason, ) { self.inner.sessions.write().await.remove(&session.engine_id); let namespaces = session .namespaces .lock() .await .keys() .cloned() .collect::>(); for namespace in namespaces { let _ = self .disconnect_socket(&namespace, session, reason.clone()) .await; } } pub async fn handle_socket_payload( &self, session: Arc, payload: SocketPayload, ) -> Result<()> { match payload { SocketPayload::Text(text) => { let packet = Packet::decode(&text)?; if packet.expected_attachments == 0 { self.handle_socket_packet(session, packet).await } else { *session.pending_binary.lock().await = Some(PendingBinary { packet }); Ok(()) } } SocketPayload::Binary(bytes) => { let mut pending = session.pending_binary.lock().await; let Some(mut pending_binary) = pending.take() else { return Err(SocketIoError::InvalidPacket( "unexpected binary attachment".to_owned(), )); }; pending_binary.packet.attachments.push(bytes); if pending_binary.packet.attachments.len() == pending_binary.packet.expected_attachments { drop(pending); self.handle_socket_packet(session, pending_binary.packet) .await } else { *pending = Some(pending_binary); Ok(()) } } } } async fn handle_socket_packet( &self, session: Arc, packet: Packet, ) -> Result<()> { match packet.packet_type { PacketType::Connect => { self.connect_namespace(session, packet).await } PacketType::Disconnect => { self.disconnect_socket( &packet.namespace, &session, DisconnectReason::Client, ) .await } PacketType::Event | PacketType::BinaryEvent => { self.dispatch_event(session, packet).await } PacketType::Ack | PacketType::BinaryAck => { self.resolve_ack(session, packet).await } PacketType::ConnectError => Ok(()), } } async fn connect_namespace( &self, session: Arc, packet: Packet, ) -> Result<()> { let namespace = normalize_namespace(packet.namespace); let state = self.namespace_state(&namespace).await?; let sid = uuid::Uuid::new_v4().to_string(); let socket = Socket { io: self.clone(), session: session.clone(), namespace: namespace.clone(), sid: sid.clone(), }; for middleware in state.middleware.read().await.iter().cloned() { if let Err(err) = middleware(socket.clone(), packet.data.clone()).await { session .enqueue_socket_packet(Packet::connect_error( &namespace, err.to_string(), )) .await; return Ok(()); } } self.inner .adapter .add_socket(&namespace, &session.engine_id) .await?; session.namespaces.lock().await.insert( namespace.clone(), SocketState { sid, rooms: HashSet::new(), auth: packet.data, }, ); session .enqueue_socket_packet(Packet::connect( &namespace, Some(json!({ "sid": socket.sid })), )) .await; if let Some(handler) = state.connect_handler.read().await.clone() { handler(socket).await; } Ok(()) } pub async fn disconnect_socket( &self, namespace: &str, session: &Arc, reason: DisconnectReason, ) -> Result<()> { let namespace = normalize_namespace(namespace.to_owned()); let removed = session.namespaces.lock().await.remove(&namespace); if let Some(socket_state) = removed { self.inner .adapter .remove_socket(&namespace, &session.engine_id) .await?; if let Ok(state) = self.namespace_state(&namespace).await && let Some(handler) = state.disconnect_handler.read().await.clone() { handler( Socket { io: self.clone(), session: session.clone(), namespace, sid: socket_state.sid, }, reason, ) .await; } } Ok(()) } async fn dispatch_event( &self, session: Arc, packet: Packet, ) -> Result<()> { let namespace = normalize_namespace(packet.namespace.clone()); let state = self.namespace_state(&namespace).await?; let socket_state = session .namespaces .lock() .await .get(&namespace) .map(|state| state.sid.clone()) .ok_or_else(|| { SocketIoError::UnknownNamespace(namespace.clone()) })?; let ack = packet .id .map(|id| AckSender::new(session.clone(), namespace.clone(), id)); let payload = packet.into_event_payload(ack)?; let handler = state .event_handlers .read() .await .get(&payload.event) .cloned(); if let Some(handler) = handler { handler( Socket { io: self.clone(), session, namespace, sid: socket_state, }, payload, ) .await; } Ok(()) } async fn resolve_ack( &self, session: Arc, packet: Packet, ) -> Result<()> { let Some(id) = packet.id else { return Ok(()); }; let args = match packet.data { Some(Value::Array(values)) => values, Some(value) => vec![value], None => Vec::new(), }; if let Some(sender) = session .ack_waiters .lock() .await .remove(&(normalize_namespace(packet.namespace), id)) { let _ = sender.send(args); } Ok(()) } pub async fn join( &self, namespace: &str, engine_id: &str, room: String, ) -> Result<()> { let namespace = normalize_namespace(namespace.to_owned()); let session = self .session(engine_id) .await .ok_or(SocketIoError::UnknownSession)?; if let Some(state) = session.namespaces.lock().await.get_mut(&namespace) { state.rooms.insert(room.clone()); } self.inner .adapter .add_to_room(&namespace, engine_id, &room) .await } pub async fn leave( &self, namespace: &str, engine_id: &str, room: &str, ) -> Result<()> { let namespace = normalize_namespace(namespace.to_owned()); let session = self .session(engine_id) .await .ok_or(SocketIoError::UnknownSession)?; if let Some(state) = session.namespaces.lock().await.get_mut(&namespace) { state.rooms.remove(room); } self.inner .adapter .remove_from_room(&namespace, engine_id, room) .await } pub async fn emit_to_sid( &self, namespace: &str, engine_id: &str, event: &str, data: T, ) -> Result<()> { let args = value_to_args(serde_json::to_value(data)?); self.emit_packet_to_sid( engine_id, Packet::event(namespace, event, args), ) .await } pub async fn emit_binary_to_sid( &self, namespace: &str, engine_id: &str, event: &str, args: Vec, binary: Vec>, ) -> Result<()> { self.emit_packet_to_sid( engine_id, Packet::event(namespace, event, args).with_binary(binary), ) .await } pub async fn emit_to_sid_with_ack( &self, namespace: &str, engine_id: &str, event: &str, data: T, ) -> Result> { let session = self .session(engine_id) .await .ok_or(SocketIoError::UnknownSession)?; let id = self.inner.next_ack_id.fetch_add(1, Ordering::Relaxed); let (tx, rx) = tokio::sync::oneshot::channel(); session .ack_waiters .lock() .await .insert((normalize_namespace(namespace.to_owned()), id), tx); let mut packet = Packet::event( namespace, event, value_to_args(serde_json::to_value(data)?), ); packet.id = Some(id); session.enqueue_socket_packet(packet).await; match tokio::time::timeout(self.inner.config.ack_timeout, rx).await { Ok(Ok(values)) => Ok(values), _ => { session .ack_waiters .lock() .await .remove(&(normalize_namespace(namespace.to_owned()), id)); Err(SocketIoError::AckTimeout) } } } pub async fn broadcast_with_opts( &self, mut opts: BroadcastOptions, event: &str, data: T, ) -> Result<()> { opts.namespace = normalize_namespace(opts.namespace); let packet = Packet::event( &opts.namespace, event, value_to_args(serde_json::to_value(data)?), ); self.broadcast_packet(opts, packet).await } async fn broadcast_packet( &self, opts: BroadcastOptions, packet: Packet, ) -> Result<()> { let sockets = self.inner.adapter.sockets(&opts.namespace, &opts).await?; let mut failures = Vec::new(); for engine_id in sockets { if let Err(err) = self.emit_packet_to_sid(&engine_id, packet.clone()).await { failures.push(format!("{engine_id}: {err}")); } } if let Err(err) = self.inner.adapter.publish(&packet, &opts).await { failures.push(format!("adapter publish: {err}")); } if failures.is_empty() { Ok(()) } else { Err(SocketIoError::Adapter(format!( "broadcast partially failed: {}", failures.join("; ") ))) } } pub async fn deliver_remote_packet( &self, opts: BroadcastOptions, packet: Packet, ) -> Result<()> { let sockets = self.inner.adapter.sockets(&opts.namespace, &opts).await?; let mut failures = Vec::new(); for engine_id in sockets { if let Err(err) = self.emit_packet_to_sid(&engine_id, packet.clone()).await { failures.push(format!("{engine_id}: {err}")); } } if failures.is_empty() { Ok(()) } else { Err(SocketIoError::Adapter(format!( "remote broadcast partially failed: {}", failures.join("; ") ))) } } async fn emit_packet_to_sid( &self, engine_id: &str, packet: Packet, ) -> Result<()> { let session = self .session(engine_id) .await .ok_or(SocketIoError::UnknownSession)?; session.enqueue_socket_packet(packet).await; Ok(()) } pub async fn ensure_namespace( &self, namespace: &str, ) -> Arc { let mut namespaces = self.inner.namespaces.write().await; namespaces .entry(namespace.to_owned()) .or_insert_with(|| Arc::new(NamespaceState::default())) .clone() } async fn namespace_state( &self, namespace: &str, ) -> Result> { self.inner .namespaces .read() .await .get(namespace) .cloned() .ok_or_else(|| { SocketIoError::UnknownNamespace(namespace.to_owned()) }) } } fn value_to_args(value: Value) -> Vec { match value { Value::Array(values) => values, value => vec![value], } } fn normalize_namespace(namespace: String) -> String { if namespace.is_empty() || namespace == "/" { "/".to_owned() } else if namespace.starts_with('/') { namespace } else { format!("/{namespace}") } }