use std::{ collections::{HashMap, HashSet}, sync::Arc, }; use async_trait::async_trait; use tokio::sync::RwLock; use crate::{error::Result, packet::Packet}; #[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize)] pub struct BroadcastOptions { pub namespace: String, pub rooms: HashSet, pub except: HashSet, pub skip_sid: Option, } #[async_trait] pub trait Adapter: Send + Sync { async fn add_socket(&self, namespace: &str, sid: &str) -> Result<()>; async fn remove_socket(&self, namespace: &str, sid: &str) -> Result<()>; async fn add_to_room( &self, namespace: &str, sid: &str, room: &str, ) -> Result<()>; async fn remove_from_room( &self, namespace: &str, sid: &str, room: &str, ) -> Result<()>; async fn sockets( &self, namespace: &str, opts: &BroadcastOptions, ) -> Result>; async fn publish( &self, packet: &Packet, opts: &BroadcastOptions, ) -> Result<()>; } #[derive(Default)] pub struct MemoryAdapter { state: RwLock>, } #[derive(Default)] struct NamespaceRooms { sockets: HashSet, rooms: HashMap>, } impl MemoryAdapter { pub fn new() -> Arc { Arc::new(Self::default()) } } #[async_trait] impl Adapter for MemoryAdapter { async fn add_socket(&self, namespace: &str, sid: &str) -> Result<()> { let mut state = self.state.write().await; state .entry(namespace.to_owned()) .or_default() .sockets .insert(sid.to_owned()); Ok(()) } async fn remove_socket(&self, namespace: &str, sid: &str) -> Result<()> { let mut state = self.state.write().await; if let Some(ns) = state.get_mut(namespace) { ns.sockets.remove(sid); ns.rooms.retain(|_, sockets| { sockets.remove(sid); !sockets.is_empty() }); } Ok(()) } async fn add_to_room( &self, namespace: &str, sid: &str, room: &str, ) -> Result<()> { let mut state = self.state.write().await; let ns = state.entry(namespace.to_owned()).or_default(); ns.sockets.insert(sid.to_owned()); ns.rooms .entry(room.to_owned()) .or_default() .insert(sid.to_owned()); Ok(()) } async fn remove_from_room( &self, namespace: &str, sid: &str, room: &str, ) -> Result<()> { let mut state = self.state.write().await; if let Some(ns) = state.get_mut(namespace) && let Some(sockets) = ns.rooms.get_mut(room) { sockets.remove(sid); } Ok(()) } async fn sockets( &self, namespace: &str, opts: &BroadcastOptions, ) -> Result> { let state = self.state.read().await; let Some(ns) = state.get(namespace) else { return Ok(Vec::new()); }; let mut selected: HashSet = if opts.rooms.is_empty() { ns.sockets.iter().cloned().collect() } else { opts.rooms .iter() .filter_map(|room| ns.rooms.get(room)) .flat_map(|sockets| sockets.iter().cloned()) .collect() }; for room in &opts.except { if let Some(excluded) = ns.rooms.get(room) { for sid in excluded { selected.remove(sid); } } } if let Some(skip_sid) = &opts.skip_sid { selected.remove(skip_sid); } Ok(selected.into_iter().collect()) } async fn publish( &self, _packet: &Packet, _opts: &BroadcastOptions, ) -> Result<()> { Ok(()) } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn memory_adapter_filters_rooms_and_except() { let adapter = MemoryAdapter::new(); adapter.add_socket("/", "s1").await.unwrap(); adapter.add_socket("/", "s2").await.unwrap(); adapter.add_socket("/", "s3").await.unwrap(); adapter.add_to_room("/", "s1", "room-a").await.unwrap(); adapter.add_to_room("/", "s2", "room-a").await.unwrap(); adapter.add_to_room("/", "s2", "muted").await.unwrap(); let opts = BroadcastOptions { namespace: "/".to_owned(), rooms: HashSet::from(["room-a".to_owned()]), except: HashSet::from(["muted".to_owned()]), skip_sid: None, }; assert_eq!(adapter.sockets("/", &opts).await.unwrap(), vec!["s1"]); } #[tokio::test] async fn memory_adapter_can_skip_sender() { let adapter = MemoryAdapter::new(); adapter.add_socket("/", "s1").await.unwrap(); adapter.add_socket("/", "s2").await.unwrap(); let opts = BroadcastOptions { namespace: "/".to_owned(), skip_sid: Some("s1".to_owned()), ..BroadcastOptions::default() }; let sockets = adapter.sockets("/", &opts).await.unwrap(); assert_eq!(sockets, vec!["s2"]); } }