use socketio::{EventPayload, Socket}; use uuid::Uuid; use crate::{ChannelBus, ChannelError, ChannelResult}; use super::handler::WsHandler; use super::out_event::{WsError, WsOutEvent}; use super::types::WsInMessage; const CHANNEL_EVENT: &str = "channel.message"; pub async fn register_message_handler( bus: &ChannelBus, ) -> crate::ChannelResult<()> { let namespace = bus.inner.io.namespace(&bus.inner.config.namespace).await; let bus_clone = bus.clone(); namespace .on(CHANNEL_EVENT, move |socket, data: EventPayload| { let bus = bus_clone.clone(); async move { handle_inbound(&bus, &socket, data).await; } }) .await; Ok(()) } async fn handle_inbound(bus: &ChannelBus, socket: &Socket, data: EventPayload) { let user_id = match socket.session_user() { Some(id) => id, None => { tracing::warn!("channel message from unauthenticated socket"); send_error(socket, ChannelError::Unauthorized.to_ws_error()).await; return; } }; let payload = match data.args.first() { Some(v) => v, None => { tracing::warn!("channel message with empty args"); return; } }; let parsed = payload; let text = serde_json::to_string(payload).unwrap_or_default(); if parsed .get("type") .and_then(|t| t.as_str()) == Some("ping") { let pong = WsOutEvent::Pong { protocol_version: super::types::WS_PROTOCOL_VERSION, }; send_event(socket, &pong).await.ok(); return; } if !check_rate_limit(bus, user_id).await { send_error(socket, ChannelError::RateLimitExceeded.to_ws_error()).await; return; } if text.len() > super::handler::MAX_TEXT_LEN { send_error( socket, WsError { code: 422, error: "message_too_long".to_string(), message: "message exceeds maximum length".to_string(), }, ) .await; return; } let request_id: Option = parsed .get("_request_id") .and_then(|r| serde_json::from_value(r.clone()).ok()); match serde_json::from_value::(payload.clone()) { Ok(in_msg) => match WsHandler::handle(bus, user_id, in_msg).await { Ok(Some(event)) => { let rid = request_id.unwrap_or(Uuid::nil()); let resp = WsOutEvent::Response { request_id: rid, data: serde_json::to_value(&event).unwrap_or_default(), }; send_event(socket, &resp).await.ok(); } Ok(None) => { let rid = request_id.unwrap_or(Uuid::nil()); let ack = WsOutEvent::Response { request_id: rid, data: serde_json::json!({"ok": true}), }; send_event(socket, &ack).await.ok(); } Err(e) => { tracing::warn!(user_id = %user_id, error = %e, "WS message processing failed"); let rid = request_id.unwrap_or(Uuid::nil()); let err_resp = WsOutEvent::Response { request_id: rid, data: serde_json::to_value(&e.to_ws_error()) .unwrap_or_default(), }; send_event(socket, &err_resp).await.ok(); } }, Err(e) => { tracing::warn!(error = %e, "WS transport parse error"); send_error( socket, WsError { code: 400, error: "parse_error".to_string(), message: e.to_string(), }, ) .await; } } } async fn check_rate_limit(bus: &ChannelBus, user_id: Uuid) -> bool { bus.inner .rate_limiter .check_rate_limit(user_id, "ws_message") .await .unwrap_or(true) } async fn send_event(socket: &Socket, event: &WsOutEvent) -> ChannelResult<()> { let json = serde_json::to_string(event)?; socket .emit(CHANNEL_EVENT, &json) .await .map_err(|e| { tracing::warn!(error = %e, "WS send failed"); ChannelError::SocketIo(e) }) } async fn send_error(socket: &Socket, error: WsError) { let json = serde_json::json!({ "type": "error", "code": error.code, "error": error.error, "message": error.message, }); if let Err(e) = socket.emit(CHANNEL_EVENT, json.to_string()).await { tracing::warn!(error = %e, "WS error send failed"); } }