160 lines
5.1 KiB
Rust
160 lines
5.1 KiB
Rust
use socketio::{EventPayload, Socket};
|
|
use uuid::Uuid;
|
|
|
|
use crate::{ChannelBus, ChannelError};
|
|
|
|
use super::handler::WsHandler;
|
|
use super::out_event::{WsError, WsOutEvent};
|
|
use super::types::WsInMessage;
|
|
|
|
/// Socket.IO event name used both for the inbound subscription made in
/// `register_message_handler` and for every event/error this module emits back.
const CHANNEL_EVENT: &str = "channel.message";
|
|
|
|
pub async fn register_message_handler(
|
|
bus: &ChannelBus,
|
|
) -> crate::ChannelResult<()> {
|
|
let namespace = bus.inner.io.namespace(&bus.inner.config.namespace).await;
|
|
|
|
let bus_clone = bus.clone();
|
|
namespace
|
|
.on(CHANNEL_EVENT, move |socket, data: EventPayload| {
|
|
let bus = bus_clone.clone();
|
|
async move {
|
|
handle_inbound(&bus, &socket, data).await;
|
|
}
|
|
})
|
|
.await;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_inbound(bus: &ChannelBus, socket: &Socket, data: EventPayload) {
|
|
let user_id = match socket.session_user() {
|
|
Some(id) => id,
|
|
None => {
|
|
tracing::warn!("channel message from unauthenticated socket");
|
|
send_error(socket, ChannelError::Unauthorized.to_ws_error()).await;
|
|
return;
|
|
}
|
|
};
|
|
let payload = match data.args.first() {
|
|
Some(v) => v,
|
|
None => {
|
|
tracing::warn!("channel message with empty args");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let parsed = payload;
|
|
|
|
let text = serde_json::to_string(payload).unwrap_or_default();
|
|
if parsed.get("type").and_then(|t| t.as_str()) == Some("ping") {
|
|
let pong = WsOutEvent::Pong {
|
|
protocol_version: super::types::WS_PROTOCOL_VERSION,
|
|
};
|
|
send_event(socket, &pong).await;
|
|
return;
|
|
}
|
|
if !check_rate_limit(bus, user_id).await {
|
|
send_error(socket, ChannelError::RateLimitExceeded.to_ws_error()).await;
|
|
return;
|
|
}
|
|
if text.len() > super::handler::MAX_TEXT_LEN {
|
|
send_error(
|
|
socket,
|
|
WsError {
|
|
code: 422,
|
|
error: "message_too_long".to_string(),
|
|
message: "message exceeds maximum length".to_string(),
|
|
},
|
|
)
|
|
.await;
|
|
return;
|
|
}
|
|
let request_id: Option<Uuid> = parsed
|
|
.get("_request_id")
|
|
.and_then(|r| serde_json::from_value(r.clone()).ok());
|
|
match serde_json::from_value::<WsInMessage>(payload.clone()) {
|
|
Ok(in_msg) => match WsHandler::handle(bus, user_id, in_msg).await {
|
|
Ok(Some(event)) => {
|
|
let rid = request_id.unwrap_or(Uuid::nil());
|
|
let resp = WsOutEvent::Response {
|
|
request_id: rid,
|
|
data: serde_json::to_value(&event).unwrap_or_default(),
|
|
};
|
|
send_event(socket, &resp).await;
|
|
}
|
|
Ok(None) => {
|
|
let rid = request_id.unwrap_or(Uuid::nil());
|
|
let ack = WsOutEvent::Response {
|
|
request_id: rid,
|
|
data: serde_json::json!({"ok": true}),
|
|
};
|
|
send_event(socket, &ack).await;
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(user_id = %user_id, error = %e, "WS message processing failed");
|
|
let rid = request_id.unwrap_or(Uuid::nil());
|
|
let err_resp = WsOutEvent::Response {
|
|
request_id: rid,
|
|
data: serde_json::to_value(&e.to_ws_error())
|
|
.unwrap_or_default(),
|
|
};
|
|
send_event(socket, &err_resp).await;
|
|
}
|
|
},
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "WS transport parse error");
|
|
if let Some(rid) = request_id {
|
|
let err_resp = WsOutEvent::Response {
|
|
request_id: rid,
|
|
data: serde_json::to_value(&WsError {
|
|
code: 400,
|
|
error: "parse_error".to_string(),
|
|
message: e.to_string(),
|
|
})
|
|
.unwrap_or_default(),
|
|
};
|
|
send_event(socket, &err_resp).await;
|
|
} else {
|
|
send_error(
|
|
socket,
|
|
WsError {
|
|
code: 400,
|
|
error: "parse_error".to_string(),
|
|
message: e.to_string(),
|
|
},
|
|
)
|
|
.await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn check_rate_limit(bus: &ChannelBus, user_id: Uuid) -> bool {
|
|
bus.inner
|
|
.rate_limiter
|
|
.check_rate_limit(user_id, "ws_message")
|
|
.await
|
|
.unwrap_or(true)
|
|
}
|
|
|
|
async fn send_event(socket: &Socket, event: &WsOutEvent) {
|
|
// Serialize to Value (not String) to avoid double JSON encoding
|
|
let value = serde_json::to_value(event).unwrap_or_default();
|
|
if let Err(e) = socket.emit(CHANNEL_EVENT, value).await {
|
|
tracing::warn!(error = %e, "WS response send failed");
|
|
}
|
|
}
|
|
|
|
async fn send_error(socket: &Socket, error: WsError) {
|
|
let value = serde_json::json!({
|
|
"type": "error",
|
|
"code": error.code,
|
|
"error": error.error,
|
|
"message": error.message,
|
|
});
|
|
if let Err(e) = socket.emit(CHANNEL_EVENT, value).await {
|
|
tracing::warn!(error = %e, "WS error send failed");
|
|
}
|
|
}
|