202 lines
5.3 KiB
Rust
202 lines
5.3 KiB
Rust
use std::{
|
|
collections::{HashMap, HashSet},
|
|
sync::Arc,
|
|
};
|
|
|
|
use async_trait::async_trait;
|
|
use tokio::sync::RwLock;
|
|
|
|
use crate::{error::Result, packet::Packet};
|
|
|
|
#[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize)]
|
|
pub struct BroadcastOptions {
|
|
pub namespace: String,
|
|
pub rooms: HashSet<String>,
|
|
pub except: HashSet<String>,
|
|
pub skip_sid: Option<String>,
|
|
}
|
|
|
|
#[async_trait]
|
|
pub trait Adapter: Send + Sync {
|
|
async fn add_socket(&self, namespace: &str, sid: &str) -> Result<()>;
|
|
async fn remove_socket(&self, namespace: &str, sid: &str) -> Result<()>;
|
|
async fn add_to_room(
|
|
&self,
|
|
namespace: &str,
|
|
sid: &str,
|
|
room: &str,
|
|
) -> Result<()>;
|
|
async fn remove_from_room(
|
|
&self,
|
|
namespace: &str,
|
|
sid: &str,
|
|
room: &str,
|
|
) -> Result<()>;
|
|
async fn sockets(
|
|
&self,
|
|
namespace: &str,
|
|
opts: &BroadcastOptions,
|
|
) -> Result<Vec<String>>;
|
|
async fn publish(
|
|
&self,
|
|
packet: &Packet,
|
|
opts: &BroadcastOptions,
|
|
) -> Result<()>;
|
|
}
|
|
|
|
#[derive(Default)]
|
|
pub struct MemoryAdapter {
|
|
state: RwLock<HashMap<String, NamespaceRooms>>,
|
|
}
|
|
|
|
/// Membership tables for a single namespace.
#[derive(Debug, Default)]
struct NamespaceRooms {
    /// Ids of every socket registered in the namespace.
    sockets: HashSet<String>,
    /// Room name -> ids of the sockets that joined that room.
    rooms: HashMap<String, HashSet<String>>,
}
|
|
|
|
impl MemoryAdapter {
|
|
pub fn new() -> Arc<Self> {
|
|
Arc::new(Self::default())
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Adapter for MemoryAdapter {
|
|
async fn add_socket(&self, namespace: &str, sid: &str) -> Result<()> {
|
|
let mut state = self.state.write().await;
|
|
state
|
|
.entry(namespace.to_owned())
|
|
.or_default()
|
|
.sockets
|
|
.insert(sid.to_owned());
|
|
Ok(())
|
|
}
|
|
|
|
async fn remove_socket(&self, namespace: &str, sid: &str) -> Result<()> {
|
|
let mut state = self.state.write().await;
|
|
if let Some(ns) = state.get_mut(namespace) {
|
|
ns.sockets.remove(sid);
|
|
ns.rooms.retain(|_, sockets| {
|
|
sockets.remove(sid);
|
|
!sockets.is_empty()
|
|
});
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn add_to_room(
|
|
&self,
|
|
namespace: &str,
|
|
sid: &str,
|
|
room: &str,
|
|
) -> Result<()> {
|
|
let mut state = self.state.write().await;
|
|
let ns = state.entry(namespace.to_owned()).or_default();
|
|
ns.sockets.insert(sid.to_owned());
|
|
ns.rooms
|
|
.entry(room.to_owned())
|
|
.or_default()
|
|
.insert(sid.to_owned());
|
|
Ok(())
|
|
}
|
|
|
|
async fn remove_from_room(
|
|
&self,
|
|
namespace: &str,
|
|
sid: &str,
|
|
room: &str,
|
|
) -> Result<()> {
|
|
let mut state = self.state.write().await;
|
|
if let Some(ns) = state.get_mut(namespace)
|
|
&& let Some(sockets) = ns.rooms.get_mut(room)
|
|
{
|
|
sockets.remove(sid);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn sockets(
|
|
&self,
|
|
namespace: &str,
|
|
opts: &BroadcastOptions,
|
|
) -> Result<Vec<String>> {
|
|
let state = self.state.read().await;
|
|
let Some(ns) = state.get(namespace) else {
|
|
return Ok(Vec::new());
|
|
};
|
|
|
|
let mut selected: HashSet<String> = if opts.rooms.is_empty() {
|
|
ns.sockets.iter().cloned().collect()
|
|
} else {
|
|
opts.rooms
|
|
.iter()
|
|
.filter_map(|room| ns.rooms.get(room))
|
|
.flat_map(|sockets| sockets.iter().cloned())
|
|
.collect()
|
|
};
|
|
|
|
for room in &opts.except {
|
|
if let Some(excluded) = ns.rooms.get(room) {
|
|
for sid in excluded {
|
|
selected.remove(sid);
|
|
}
|
|
}
|
|
}
|
|
if let Some(skip_sid) = &opts.skip_sid {
|
|
selected.remove(skip_sid);
|
|
}
|
|
|
|
Ok(selected.into_iter().collect())
|
|
}
|
|
|
|
async fn publish(
|
|
&self,
|
|
_packet: &Packet,
|
|
_opts: &BroadcastOptions,
|
|
) -> Result<()> {
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Options that target exactly one room of the root namespace.
    fn room_opts(room: &str) -> BroadcastOptions {
        BroadcastOptions {
            namespace: "/".to_owned(),
            rooms: HashSet::from([room.to_owned()]),
            ..BroadcastOptions::default()
        }
    }

    #[tokio::test]
    async fn memory_adapter_filters_rooms_and_except() {
        let adapter = MemoryAdapter::new();
        adapter.add_socket("/", "s1").await.unwrap();
        adapter.add_socket("/", "s2").await.unwrap();
        adapter.add_socket("/", "s3").await.unwrap();
        adapter.add_to_room("/", "s1", "room-a").await.unwrap();
        adapter.add_to_room("/", "s2", "room-a").await.unwrap();
        adapter.add_to_room("/", "s2", "muted").await.unwrap();

        let opts = BroadcastOptions {
            namespace: "/".to_owned(),
            rooms: HashSet::from(["room-a".to_owned()]),
            except: HashSet::from(["muted".to_owned()]),
            skip_sid: None,
        };

        assert_eq!(adapter.sockets("/", &opts).await.unwrap(), vec!["s1"]);
    }

    #[tokio::test]
    async fn memory_adapter_can_skip_sender() {
        let adapter = MemoryAdapter::new();
        adapter.add_socket("/", "s1").await.unwrap();
        adapter.add_socket("/", "s2").await.unwrap();

        let opts = BroadcastOptions {
            namespace: "/".to_owned(),
            skip_sid: Some("s1".to_owned()),
            ..BroadcastOptions::default()
        };
        let sockets = adapter.sockets("/", &opts).await.unwrap();

        assert_eq!(sockets, vec!["s2"]);
    }

    #[tokio::test]
    async fn memory_adapter_unknown_namespace_is_empty() {
        let adapter = MemoryAdapter::new();
        adapter.add_socket("/", "s1").await.unwrap();

        let sockets = adapter
            .sockets("/missing", &BroadcastOptions::default())
            .await
            .unwrap();

        assert!(sockets.is_empty());
    }

    #[tokio::test]
    async fn memory_adapter_remove_socket_leaves_its_rooms() {
        let adapter = MemoryAdapter::new();
        adapter.add_to_room("/", "s1", "room-a").await.unwrap();
        adapter.add_to_room("/", "s2", "room-a").await.unwrap();

        adapter.remove_socket("/", "s1").await.unwrap();

        let sockets = adapter.sockets("/", &room_opts("room-a")).await.unwrap();
        assert_eq!(sockets, vec!["s2"]);
    }

    #[tokio::test]
    async fn memory_adapter_remove_from_room_keeps_socket_in_namespace() {
        let adapter = MemoryAdapter::new();
        adapter.add_to_room("/", "s1", "room-a").await.unwrap();

        adapter.remove_from_room("/", "s1", "room-a").await.unwrap();

        let in_room = adapter.sockets("/", &room_opts("room-a")).await.unwrap();
        assert!(in_room.is_empty());

        let in_namespace = adapter
            .sockets("/", &BroadcastOptions::default())
            .await
            .unwrap();
        assert_eq!(in_namespace, vec!["s1"]);
    }
}
|