diff --git a/deploy/configmap.yaml b/deploy/configmap.yaml index af9427b..6981c72 100644 --- a/deploy/configmap.yaml +++ b/deploy/configmap.yaml @@ -11,6 +11,7 @@ data: # App Info APP_NAME: "gitdata" APP_VERSION: "0.1.0" + APP_DOMAIN_URL: "https://gitdata.ai" APP_STATIC_DOMAIN: "https://static.gitdata.ai" APP_MEDIA_DOMAIN: "https://static.gitdata.ai" APP_GIT_HTTP_DOMAIN: "https://git.gitdata.ai" diff --git a/libs/api/room/ws.rs b/libs/api/room/ws.rs index 8d79481..ab4366b 100644 --- a/libs/api/room/ws.rs +++ b/libs/api/room/ws.rs @@ -231,20 +231,57 @@ pub struct WsOutEvent { pub(crate) fn validate_origin(req: &HttpRequest) -> bool { static ALLOWED_ORIGINS: LazyLock> = LazyLock::new(|| { + // Build default origins from localhost + APP_DOMAIN_URL + let domain = std::env::var("APP_DOMAIN_URL") + .unwrap_or_else(|_| "http://127.0.0.1".to_string()); + + // Normalize: strip trailing slash, derive https/wss variants + let domain = domain.trim_end_matches('/'); + let https_domain = domain.replace("http://", "https://").replace("ws://", "wss://"); + let ws_domain = domain.replace("https://", "ws://").replace("http://", "ws://"); + + let mut defaults = vec![ + "http://localhost".to_string(), + "https://localhost".to_string(), + "http://127.0.0.1".to_string(), + "https://127.0.0.1".to_string(), + "ws://localhost".to_string(), + "wss://localhost".to_string(), + "ws://127.0.0.1".to_string(), + "wss://127.0.0.1".to_string(), + ]; + + // Always include APP_DOMAIN_URL and APP_STATIC_DOMAIN origins + let mut add_origin = |origin: &str| { + let origin = origin.trim_end_matches('/'); + let https_v = origin.replace("http://", "https://").replace("ws://", "wss://"); + let ws_v = origin.replace("https://", "ws://").replace("http://", "ws://"); + for v in [origin, &https_v, &ws_v] { + if !defaults.contains(&v.to_string()) && v != domain { + defaults.push(v.to_string()); + } + } + }; + if let Ok(static_domain) = std::env::var("APP_STATIC_DOMAIN") { + add_origin(&static_domain); + } + if !defaults.contains(&domain.to_string()) { + defaults.push(domain.to_string()); + } + if !defaults.contains(&https_domain) && https_domain != domain { + defaults.push(https_domain.clone()); + } + if !defaults.contains(&ws_domain) && ws_domain != domain && ws_domain != https_domain { + defaults.push(ws_domain); + } + std::env::var("WS_ALLOWED_ORIGINS") - .map(|v| v.split(',').map(|s| s.trim().to_string()).collect()) - .unwrap_or_else(|_| { - vec![ - "http://localhost".to_string(), - "https://localhost".to_string(), - "http://127.0.0.1".to_string(), - "https://127.0.0.1".to_string(), - "ws://localhost".to_string(), - "wss://localhost".to_string(), - "ws://127.0.0.1".to_string(), - "wss://127.0.0.1".to_string(), - ] + .map(|v| { + let mut origins = defaults.clone(); + origins.extend(v.split(',').map(|s| s.trim().to_string())); + origins }) + .unwrap_or_else(|_| defaults) }); let Some(origin) = req.headers().get("origin") else {