gitdataai/lib/socketio/adapter.rs
2026-05-30 01:38:40 +08:00

202 lines
5.3 KiB
Rust

use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use async_trait::async_trait;
use tokio::sync::RwLock;
use crate::{error::Result, packet::Packet};
/// Describes which sockets a broadcast should reach.
///
/// `MemoryAdapter::sockets` applies the filters in this order: start from the
/// members of `rooms` (or every socket in the namespace when `rooms` is
/// empty), drop the members of every `except` room, then drop `skip_sid`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct BroadcastOptions {
    /// Namespace the broadcast is addressed to.
    ///
    /// NOTE(review): `Adapter::sockets` also takes a `namespace` argument and
    /// `MemoryAdapter` filters by that argument, not by this field — confirm
    /// which of the two is authoritative.
    pub namespace: String,
    /// Rooms whose members are targeted (their union); empty means every
    /// socket in the namespace.
    pub rooms: HashSet<String>,
    /// Rooms whose members are excluded, even if they are also in `rooms`.
    pub except: HashSet<String>,
    /// A single socket id to exclude, e.g. the sender of the broadcast.
    pub skip_sid: Option<String>,
}
/// Membership backend for socket.io namespaces and rooms.
///
/// An adapter records which socket ids are registered in each namespace and
/// which rooms they have joined. It also resolves the target set of a
/// broadcast described by [`BroadcastOptions`]. The `Send + Sync` bound lets a
/// single adapter be shared by many tasks.
#[async_trait]
pub trait Adapter: Send + Sync {
    /// Registers socket `sid` in `namespace`.
    async fn add_socket(&self, namespace: &str, sid: &str) -> Result<()>;

    /// Unregisters socket `sid` from `namespace`.
    ///
    /// [`MemoryAdapter`] also removes the socket from every room of that
    /// namespace.
    async fn remove_socket(&self, namespace: &str, sid: &str) -> Result<()>;

    /// Adds socket `sid` to `room` inside `namespace`.
    async fn add_to_room(
        &self,
        namespace: &str,
        sid: &str,
        room: &str,
    ) -> Result<()>;

    /// Removes socket `sid` from `room` inside `namespace`.
    ///
    /// In [`MemoryAdapter`] the socket stays registered in the namespace.
    async fn remove_from_room(
        &self,
        namespace: &str,
        sid: &str,
        room: &str,
    ) -> Result<()>;

    /// Returns the ids of the sockets in `namespace` selected by `opts`.
    ///
    /// [`MemoryAdapter`] returns them in no particular order.
    /// NOTE(review): `opts.namespace` duplicates `namespace`; `MemoryAdapter`
    /// only consults the `namespace` argument.
    async fn sockets(
        &self,
        namespace: &str,
        opts: &BroadcastOptions,
    ) -> Result<Vec<String>>;

    /// Hands `packet` to the adapter for the sockets selected by `opts`.
    ///
    /// [`MemoryAdapter`] implements this as a no-op. NOTE(review): presumably
    /// multi-node adapters forward the packet to their peers here — confirm.
    async fn publish(
        &self,
        packet: &Packet,
        opts: &BroadcastOptions,
    ) -> Result<()>;
}
/// [`Adapter`] that keeps all membership state in this process's memory.
///
/// All state lives behind a single `tokio::sync::RwLock`. Queries take the
/// read lock and membership changes take the write lock.
#[derive(Default)]
pub struct MemoryAdapter {
    /// Membership per namespace, keyed by namespace name.
    state: RwLock<HashMap<String, NamespaceRooms>>,
}
/// Membership bookkeeping for one namespace.
#[derive(Default)]
struct NamespaceRooms {
    /// Every socket id registered in the namespace, whether or not it has
    /// joined a room.
    sockets: HashSet<String>,
    /// Room name -> ids of the sockets that joined that room.
    rooms: HashMap<String, HashSet<String>>,
}
impl MemoryAdapter {
    /// Creates an adapter with no namespaces, sockets or rooms.
    ///
    /// The adapter comes back already wrapped in an [`Arc`] so it can be
    /// shared.
    pub fn new() -> Arc<Self> {
        Arc::default()
    }
}
#[async_trait]
impl Adapter for MemoryAdapter {
    /// Registers `sid` in `namespace`, creating the namespace entry on first
    /// use.
    async fn add_socket(&self, namespace: &str, sid: &str) -> Result<()> {
        let mut state = self.state.write().await;
        state
            .entry(namespace.to_owned())
            .or_default()
            .sockets
            .insert(sid.to_owned());
        Ok(())
    }

    /// Removes `sid` from `namespace` and from every room in it.
    ///
    /// Rooms left without members are dropped. The namespace entry itself is
    /// also dropped once it has neither sockets nor rooms, so the map does not
    /// keep every namespace that was ever used.
    async fn remove_socket(&self, namespace: &str, sid: &str) -> Result<()> {
        let mut state = self.state.write().await;
        let Some(ns) = state.get_mut(namespace) else {
            return Ok(());
        };
        ns.sockets.remove(sid);
        ns.rooms.retain(|_, members| {
            members.remove(sid);
            !members.is_empty()
        });
        // A missing namespace and an empty one behave identically in
        // `sockets`, so pruning is not observable through the trait.
        if ns.sockets.is_empty() && ns.rooms.is_empty() {
            state.remove(namespace);
        }
        Ok(())
    }

    /// Adds `sid` to `room` in `namespace`.
    ///
    /// This also registers `sid` in the namespace itself, so joining a room
    /// implies namespace membership.
    async fn add_to_room(
        &self,
        namespace: &str,
        sid: &str,
        room: &str,
    ) -> Result<()> {
        let mut state = self.state.write().await;
        let ns = state.entry(namespace.to_owned()).or_default();
        ns.sockets.insert(sid.to_owned());
        ns.rooms
            .entry(room.to_owned())
            .or_default()
            .insert(sid.to_owned());
        Ok(())
    }

    /// Removes `sid` from `room` in `namespace`; the socket stays registered
    /// in the namespace.
    ///
    /// A room that becomes empty is dropped, matching `remove_socket`.
    /// Previously, empty member sets accumulated for every room ever joined.
    async fn remove_from_room(
        &self,
        namespace: &str,
        sid: &str,
        room: &str,
    ) -> Result<()> {
        let mut state = self.state.write().await;
        let Some(ns) = state.get_mut(namespace) else {
            return Ok(());
        };
        if let Some(members) = ns.rooms.get_mut(room) {
            members.remove(sid);
            if members.is_empty() {
                ns.rooms.remove(room);
            }
        }
        Ok(())
    }

    /// Returns the ids of the sockets in `namespace` selected by `opts`.
    ///
    /// Selection steps:
    /// - Start from the union of the members of `opts.rooms`, or from every
    ///   socket in the namespace when `opts.rooms` is empty.
    /// - Remove the members of every `opts.except` room.
    /// - Remove `opts.skip_sid`.
    ///
    /// An unknown namespace yields an empty list. The order of the returned
    /// ids is unspecified (hash-set iteration order). `opts.namespace` is not
    /// consulted.
    async fn sockets(
        &self,
        namespace: &str,
        opts: &BroadcastOptions,
    ) -> Result<Vec<String>> {
        let state = self.state.read().await;
        let Some(ns) = state.get(namespace) else {
            return Ok(Vec::new());
        };
        // Filter on borrowed ids and only allocate owned strings for the
        // sockets that survive every filter.
        let mut selected: HashSet<&str> = if opts.rooms.is_empty() {
            ns.sockets.iter().map(String::as_str).collect()
        } else {
            opts.rooms
                .iter()
                .filter_map(|room| ns.rooms.get(room))
                .flatten()
                .map(String::as_str)
                .collect()
        };
        for members in opts.except.iter().filter_map(|room| ns.rooms.get(room))
        {
            for sid in members {
                selected.remove(sid.as_str());
            }
        }
        if let Some(skip_sid) = opts.skip_sid.as_deref() {
            selected.remove(skip_sid);
        }
        Ok(selected.into_iter().map(str::to_owned).collect())
    }

    /// No-op: this adapter keeps state only in this process and forwards
    /// packets nowhere.
    ///
    /// NOTE(review): presumably local delivery is done by the caller using
    /// `sockets` — confirm.
    async fn publish(
        &self,
        _packet: &Packet,
        _opts: &BroadcastOptions,
    ) -> Result<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds options addressed to `namespace` with no filters set.
    fn opts_for(namespace: &str) -> BroadcastOptions {
        BroadcastOptions {
            namespace: namespace.to_owned(),
            ..BroadcastOptions::default()
        }
    }

    /// Returns the selected socket ids sorted, because `sockets` yields them
    /// in hash-set iteration order.
    async fn sorted_sockets(
        adapter: &MemoryAdapter,
        namespace: &str,
        opts: &BroadcastOptions,
    ) -> Vec<String> {
        let mut sids = adapter.sockets(namespace, opts).await.unwrap();
        sids.sort();
        sids
    }

    #[tokio::test]
    async fn memory_adapter_filters_rooms_and_except() {
        let adapter = MemoryAdapter::new();
        adapter.add_socket("/", "s1").await.unwrap();
        adapter.add_socket("/", "s2").await.unwrap();
        adapter.add_socket("/", "s3").await.unwrap();
        adapter.add_to_room("/", "s1", "room-a").await.unwrap();
        adapter.add_to_room("/", "s2", "room-a").await.unwrap();
        adapter.add_to_room("/", "s2", "muted").await.unwrap();
        let opts = BroadcastOptions {
            namespace: "/".to_owned(),
            rooms: HashSet::from(["room-a".to_owned()]),
            except: HashSet::from(["muted".to_owned()]),
            skip_sid: None,
        };
        assert_eq!(adapter.sockets("/", &opts).await.unwrap(), vec!["s1"]);
    }

    #[tokio::test]
    async fn memory_adapter_can_skip_sender() {
        let adapter = MemoryAdapter::new();
        adapter.add_socket("/", "s1").await.unwrap();
        adapter.add_socket("/", "s2").await.unwrap();
        let opts = BroadcastOptions {
            namespace: "/".to_owned(),
            skip_sid: Some("s1".to_owned()),
            ..BroadcastOptions::default()
        };
        let sockets = adapter.sockets("/", &opts).await.unwrap();
        assert_eq!(sockets, vec!["s2"]);
    }

    #[tokio::test]
    async fn memory_adapter_targets_union_of_rooms() {
        let adapter = MemoryAdapter::new();
        adapter.add_to_room("/", "s1", "room-a").await.unwrap();
        adapter.add_to_room("/", "s2", "room-b").await.unwrap();
        adapter.add_socket("/", "s3").await.unwrap();
        let opts = BroadcastOptions {
            rooms: HashSet::from(["room-a".to_owned(), "room-b".to_owned()]),
            ..opts_for("/")
        };
        let sids = sorted_sockets(&adapter, "/", &opts).await;
        assert_eq!(sids, vec!["s1", "s2"]);
    }

    #[tokio::test]
    async fn memory_adapter_remove_socket_leaves_every_room() {
        let adapter = MemoryAdapter::new();
        adapter.add_to_room("/", "s1", "room-a").await.unwrap();
        adapter.add_to_room("/", "s1", "room-b").await.unwrap();
        adapter.add_to_room("/", "s2", "room-a").await.unwrap();
        adapter.remove_socket("/", "s1").await.unwrap();
        let in_rooms = BroadcastOptions {
            rooms: HashSet::from(["room-a".to_owned(), "room-b".to_owned()]),
            ..opts_for("/")
        };
        let sids = sorted_sockets(&adapter, "/", &in_rooms).await;
        assert_eq!(sids, vec!["s2"]);
        let all = sorted_sockets(&adapter, "/", &opts_for("/")).await;
        assert_eq!(all, vec!["s2"]);
    }

    #[tokio::test]
    async fn memory_adapter_remove_from_room_keeps_namespace_membership() {
        let adapter = MemoryAdapter::new();
        adapter.add_to_room("/", "s1", "room-a").await.unwrap();
        adapter.remove_from_room("/", "s1", "room-a").await.unwrap();
        let in_room = BroadcastOptions {
            rooms: HashSet::from(["room-a".to_owned()]),
            ..opts_for("/")
        };
        assert!(sorted_sockets(&adapter, "/", &in_room).await.is_empty());
        let all = sorted_sockets(&adapter, "/", &opts_for("/")).await;
        assert_eq!(all, vec!["s1"]);
    }

    #[tokio::test]
    async fn memory_adapter_unknown_namespace_is_empty() {
        let adapter = MemoryAdapter::new();
        adapter.add_socket("/", "s1").await.unwrap();
        adapter.remove_socket("/chat", "s1").await.unwrap();
        adapter.remove_from_room("/chat", "s1", "room-a").await.unwrap();
        let chat = sorted_sockets(&adapter, "/chat", &opts_for("/chat")).await;
        assert!(chat.is_empty());
        let root = sorted_sockets(&adapter, "/", &opts_for("/")).await;
        assert_eq!(root, vec!["s1"]);
    }

    #[tokio::test]
    async fn memory_adapter_namespace_is_reusable_after_emptied() {
        let adapter = MemoryAdapter::new();
        adapter.add_to_room("/", "s1", "room-a").await.unwrap();
        adapter.remove_socket("/", "s1").await.unwrap();
        assert!(sorted_sockets(&adapter, "/", &opts_for("/")).await.is_empty());
        adapter.add_to_room("/", "s2", "room-a").await.unwrap();
        let in_room = BroadcastOptions {
            rooms: HashSet::from(["room-a".to_owned()]),
            ..opts_for("/")
        };
        let sids = sorted_sockets(&adapter, "/", &in_room).await;
        assert_eq!(sids, vec!["s2"]);
    }
}