fix(ws): allow APP_DOMAIN_URL and APP_STATIC_DOMAIN origins
validate_origin() only allowed localhost origins by default, causing production WebSocket connections to be rejected. Now it reads APP_DOMAIN_URL and APP_STATIC_DOMAIN from env and automatically adds their http/https/ws/wss variants to the allowed origins list. Also add APP_DOMAIN_URL to the production configmap.
This commit is contained in:
parent
89deebced6
commit
431f40063f
@ -11,6 +11,7 @@ data:
|
|||||||
# App Info
|
# App Info
|
||||||
APP_NAME: "gitdata"
|
APP_NAME: "gitdata"
|
||||||
APP_VERSION: "0.1.0"
|
APP_VERSION: "0.1.0"
|
||||||
|
APP_DOMAIN_URL: "https://gitdata.ai"
|
||||||
APP_STATIC_DOMAIN: "https://static.gitdata.ai"
|
APP_STATIC_DOMAIN: "https://static.gitdata.ai"
|
||||||
APP_MEDIA_DOMAIN: "https://static.gitdata.ai"
|
APP_MEDIA_DOMAIN: "https://static.gitdata.ai"
|
||||||
APP_GIT_HTTP_DOMAIN: "https://git.gitdata.ai"
|
APP_GIT_HTTP_DOMAIN: "https://git.gitdata.ai"
|
||||||
|
|||||||
@ -231,20 +231,57 @@ pub struct WsOutEvent {
|
|||||||
|
|
||||||
pub(crate) fn validate_origin(req: &HttpRequest) -> bool {
|
pub(crate) fn validate_origin(req: &HttpRequest) -> bool {
|
||||||
static ALLOWED_ORIGINS: LazyLock<Vec<String>> = LazyLock::new(|| {
|
static ALLOWED_ORIGINS: LazyLock<Vec<String>> = LazyLock::new(|| {
|
||||||
|
// Build default origins from localhost + APP_DOMAIN_URL
|
||||||
|
let domain = std::env::var("APP_DOMAIN_URL")
|
||||||
|
.unwrap_or_else(|_| "http://127.0.0.1".to_string());
|
||||||
|
|
||||||
|
// Normalize: strip trailing slash, derive https/wss variants
|
||||||
|
let domain = domain.trim_end_matches('/');
|
||||||
|
let https_domain = domain.replace("http://", "https://").replace("ws://", "wss://");
|
||||||
|
let ws_domain = domain.replace("https://", "ws://").replace("http://", "ws://");
|
||||||
|
|
||||||
|
let mut defaults = vec![
|
||||||
|
"http://localhost".to_string(),
|
||||||
|
"https://localhost".to_string(),
|
||||||
|
"http://127.0.0.1".to_string(),
|
||||||
|
"https://127.0.0.1".to_string(),
|
||||||
|
"ws://localhost".to_string(),
|
||||||
|
"wss://localhost".to_string(),
|
||||||
|
"ws://127.0.0.1".to_string(),
|
||||||
|
"wss://127.0.0.1".to_string(),
|
||||||
|
];
|
||||||
|
|
||||||
|
// Always include APP_DOMAIN_URL and APP_STATIC_DOMAIN origins
|
||||||
|
let mut add_origin = |origin: &str| {
|
||||||
|
let origin = origin.trim_end_matches('/');
|
||||||
|
let https_v = origin.replace("http://", "https://").replace("ws://", "wss://");
|
||||||
|
let ws_v = origin.replace("https://", "ws://").replace("http://", "ws://");
|
||||||
|
for v in [origin, &https_v, &ws_v] {
|
||||||
|
if !defaults.contains(&v.to_string()) && v != domain {
|
||||||
|
defaults.push(v.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if let Ok(static_domain) = std::env::var("APP_STATIC_DOMAIN") {
|
||||||
|
add_origin(&static_domain);
|
||||||
|
}
|
||||||
|
if !defaults.contains(&domain.to_string()) {
|
||||||
|
defaults.push(domain.to_string());
|
||||||
|
}
|
||||||
|
if !defaults.contains(&https_domain) && https_domain != domain {
|
||||||
|
defaults.push(https_domain.clone());
|
||||||
|
}
|
||||||
|
if !defaults.contains(&ws_domain) && ws_domain != domain && ws_domain != https_domain {
|
||||||
|
defaults.push(ws_domain);
|
||||||
|
}
|
||||||
|
|
||||||
std::env::var("WS_ALLOWED_ORIGINS")
|
std::env::var("WS_ALLOWED_ORIGINS")
|
||||||
.map(|v| v.split(',').map(|s| s.trim().to_string()).collect())
|
.map(|v| {
|
||||||
.unwrap_or_else(|_| {
|
let mut origins = defaults.clone();
|
||||||
vec![
|
origins.extend(v.split(',').map(|s| s.trim().to_string()));
|
||||||
"http://localhost".to_string(),
|
origins
|
||||||
"https://localhost".to_string(),
|
|
||||||
"http://127.0.0.1".to_string(),
|
|
||||||
"https://127.0.0.1".to_string(),
|
|
||||||
"ws://localhost".to_string(),
|
|
||||||
"wss://localhost".to_string(),
|
|
||||||
"ws://127.0.0.1".to_string(),
|
|
||||||
"wss://127.0.0.1".to_string(),
|
|
||||||
]
|
|
||||||
})
|
})
|
||||||
|
.unwrap_or_else(|_| defaults)
|
||||||
});
|
});
|
||||||
|
|
||||||
let Some(origin) = req.headers().get("origin") else {
|
let Some(origin) = req.headers().get("origin") else {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user