// gitdataai/lib/channel/http/ws.rs
//
// 160 lines
// 5.1 KiB
// Rust

use socketio::{EventPayload, Socket};
use uuid::Uuid;
use crate::{ChannelBus, ChannelError};
use super::handler::WsHandler;
use super::out_event::{WsError, WsOutEvent};
use super::types::WsInMessage;
/// Socket.IO event name used in this module both for listening to inbound
/// channel messages and for emitting every outbound reply (responses, pongs
/// and error frames).
const CHANNEL_EVENT: &str = "channel.message";
/// Subscribes `handle_inbound` to [`CHANNEL_EVENT`] on the Socket.IO namespace
/// named by `bus.inner.config.namespace`.
///
/// The registered listener owns its own clone of `bus` and hands each event a
/// further clone, so every per-event future is self-contained.
///
/// # Errors
/// Never returns `Err` in the current implementation.
pub async fn register_message_handler(
    bus: &ChannelBus,
) -> crate::ChannelResult<()> {
    let listener_bus = bus.clone();
    let ns = bus.inner.io.namespace(&bus.inner.config.namespace).await;
    ns.on(CHANNEL_EVENT, move |socket, data: EventPayload| {
        // The listener may fire many times; give each invocation its own handle.
        let bus = listener_bus.clone();
        async move {
            handle_inbound(&bus, &socket, data).await;
        }
    })
    .await;
    Ok(())
}
async fn handle_inbound(bus: &ChannelBus, socket: &Socket, data: EventPayload) {
let user_id = match socket.session_user() {
Some(id) => id,
None => {
tracing::warn!("channel message from unauthenticated socket");
send_error(socket, ChannelError::Unauthorized.to_ws_error()).await;
return;
}
};
let payload = match data.args.first() {
Some(v) => v,
None => {
tracing::warn!("channel message with empty args");
return;
}
};
let parsed = payload;
let text = serde_json::to_string(payload).unwrap_or_default();
if parsed.get("type").and_then(|t| t.as_str()) == Some("ping") {
let pong = WsOutEvent::Pong {
protocol_version: super::types::WS_PROTOCOL_VERSION,
};
send_event(socket, &pong).await;
return;
}
if !check_rate_limit(bus, user_id).await {
send_error(socket, ChannelError::RateLimitExceeded.to_ws_error()).await;
return;
}
if text.len() > super::handler::MAX_TEXT_LEN {
send_error(
socket,
WsError {
code: 422,
error: "message_too_long".to_string(),
message: "message exceeds maximum length".to_string(),
},
)
.await;
return;
}
let request_id: Option<Uuid> = parsed
.get("_request_id")
.and_then(|r| serde_json::from_value(r.clone()).ok());
match serde_json::from_value::<WsInMessage>(payload.clone()) {
Ok(in_msg) => match WsHandler::handle(bus, user_id, in_msg).await {
Ok(Some(event)) => {
let rid = request_id.unwrap_or(Uuid::nil());
let resp = WsOutEvent::Response {
request_id: rid,
data: serde_json::to_value(&event).unwrap_or_default(),
};
send_event(socket, &resp).await;
}
Ok(None) => {
let rid = request_id.unwrap_or(Uuid::nil());
let ack = WsOutEvent::Response {
request_id: rid,
data: serde_json::json!({"ok": true}),
};
send_event(socket, &ack).await;
}
Err(e) => {
tracing::warn!(user_id = %user_id, error = %e, "WS message processing failed");
let rid = request_id.unwrap_or(Uuid::nil());
let err_resp = WsOutEvent::Response {
request_id: rid,
data: serde_json::to_value(&e.to_ws_error())
.unwrap_or_default(),
};
send_event(socket, &err_resp).await;
}
},
Err(e) => {
tracing::warn!(error = %e, "WS transport parse error");
if let Some(rid) = request_id {
let err_resp = WsOutEvent::Response {
request_id: rid,
data: serde_json::to_value(&WsError {
code: 400,
error: "parse_error".to_string(),
message: e.to_string(),
})
.unwrap_or_default(),
};
send_event(socket, &err_resp).await;
} else {
send_error(
socket,
WsError {
code: 400,
error: "parse_error".to_string(),
message: e.to_string(),
},
)
.await;
}
}
}
}
/// Returns `true` when `user_id` may send another WS message now.
///
/// Asks the bus rate limiter, passing `"ws_message"`. This check fails open:
/// when the limiter yields no verdict (the `unwrap_or(true)` fallback), the
/// message is allowed.
async fn check_rate_limit(bus: &ChannelBus, user_id: Uuid) -> bool {
    let verdict = bus
        .inner
        .rate_limiter
        .check_rate_limit(user_id, "ws_message")
        .await;
    verdict.unwrap_or(true)
}
async fn send_event(socket: &Socket, event: &WsOutEvent) {
// Serialize to Value (not String) to avoid double JSON encoding
let value = serde_json::to_value(event).unwrap_or_default();
if let Err(e) = socket.emit(CHANNEL_EVENT, value).await {
tracing::warn!(error = %e, "WS response send failed");
}
}
async fn send_error(socket: &Socket, error: WsError) {
let value = serde_json::json!({
"type": "error",
"code": error.code,
"error": error.error,
"message": error.message,
});
if let Err(e) = socket.emit(CHANNEL_EVENT, value).await {
tracing::warn!(error = %e, "WS error send failed");
}
}