gitdataai/apps/app/src/main.rs

351 lines
12 KiB
Rust

use actix_cors::Cors;
use actix_web::cookie::time::Duration;
use actix_web::dev::{Service, ServiceRequest, ServiceResponse};
use actix_web::{App, HttpResponse, HttpServer, cookie::Key, web};
use api::{robots, sidemap};
use clap::Parser;
use db::cache::AppCache;
use db::database::AppDatabase;
use futures::future::LocalBoxFuture;
use observability::{
HttpMetrics, HttpSnapshotGuard, MetricsMiddleware, TracingSpanMiddleware,
init_tracing_subscriber, install_recorder, prometheus_handler, push::MetricsPusher,
spawn_http_metrics_poller,
};
use sea_orm::ConnectionTrait;
use service::AppService;
use session::SessionMiddleware;
use session::config::{PersistentSession, SessionLifecycle, TtlExtensionPolicy};
use session::storage::RedisClusterSessionStore;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
mod args;
use args::ServerArgs;
use config::AppConfig;
use migrate::{Migrator, MigratorTrait};
/// Shared application state registered as actix `app_data` in `main` and
/// consumed by handlers in this file such as `health_check`.
///
/// The struct is `Clone` so that every worker's `App` factory invocation can
/// build its own copy from the handles captured by the factory closure.
#[derive(Clone)]
pub struct AppState {
/// Database handle; `db_ping` probes both its `writer()` and `reader()` connections.
pub db: AppDatabase,
/// Cache handle; `cache_ping` checks that `conn()` can be obtained from it.
pub cache: AppCache,
}
/// Request-logging middleware factory.
///
/// Wrapping an `App` with this type installs a [`RequestLoggerService`] around
/// the inner service; that service emits one log event per completed request,
/// skipping a fixed set of noisy paths.
struct RequestLogger;

impl<S, B> actix_web::dev::Transform<S, ServiceRequest> for RequestLogger
where
    B: 'static,
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
    S::Future: 'static,
{
    type Error = actix_web::Error;
    type InitError = ();
    type Response = ServiceResponse<B>;
    type Transform = RequestLoggerService<S>;
    type Future = futures::future::Ready<Result<Self::Transform, Self::InitError>>;

    /// Builds the logging wrapper around `service`. Construction cannot fail,
    /// so the returned future is immediately ready with `Ok`.
    fn new_transform(&self, service: S) -> Self::Future {
        let wrapped = RequestLoggerService {
            service,
            _marker: std::marker::PhantomData,
        };
        futures::future::ready(Ok(wrapped))
    }
}
/// Service produced by `RequestLogger::new_transform`; it delegates every
/// request to `service` and logs the outcome (see its `Service` impl).
struct RequestLoggerService<S> {
// The wrapped inner service that actually handles the request.
service: S,
// Zero-sized marker naming the request type; it stores no data.
_marker: std::marker::PhantomData<fn(ServiceRequest)>,
}
/// Logging wrapper around the inner service.
///
/// Every request is forwarded unchanged. When the inner service resolves to a
/// response, one `http_request` info event is emitted with method, path,
/// status code and elapsed time — unless the path is `/health`, `/metrics`,
/// or starts with `/ws` or `/assets`. Errors returned by the inner service are
/// propagated with `?` and are not logged here.
impl<S, B> actix_web::dev::Service<ServiceRequest> for RequestLoggerService<S>
where
    S: actix_web::dev::Service<
        ServiceRequest,
        Response = ServiceResponse<B>,
        Error = actix_web::Error,
    >,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    /// Forwards `req` to the wrapped service and logs the completed response.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        let request_path = req.path();
        let is_noisy = request_path == "/health"
            || request_path == "/metrics"
            || request_path.starts_with("/ws")
            || request_path.starts_with("/assets");

        // Only copy method/path out of the request when the event will really
        // be emitted; noisy endpoints skip both allocations.
        let log_fields = if is_noisy {
            None
        } else {
            Some((req.method().to_string(), request_path.to_string()))
        };

        let start = Instant::now();
        let fut = self.service.call(req);

        Box::pin(async move {
            let res = fut.await?;
            if let Some((method, path)) = log_fields {
                // Sample status and elapsed once so the structured fields and
                // the human-readable message always report the same values.
                let status = res.status().as_u16();
                let elapsed = start.elapsed();
                tracing::info!(
                    target: "http_request",
                    method = %method,
                    path = %path,
                    status = status,
                    elapsed = ?elapsed,
                    "{} {} {} {:?}",
                    method,
                    path,
                    status,
                    elapsed
                );
            }
            Ok(res)
        })
    }
}
/// Produces the cookie key used by the session middleware.
///
/// * `APP_SESSION_SECRET` missing from `cfg.env` → a freshly generated key
///   (a warning is logged).
/// * Secret shorter than 32 bytes → a freshly generated key (a warning with
///   the actual length is logged).
/// * Otherwise → 64 bytes of key material derived from the secret with
///   HKDF-SHA256.
///
/// # Errors
/// Returns an error if the HKDF expand step fails.
fn build_session_key(cfg: &AppConfig) -> anyhow::Result<Key> {
    use hkdf::Hkdf;
    use sha2::Sha256;

    match cfg.env.get("APP_SESSION_SECRET") {
        None => {
            tracing::warn!(
                "APP_SESSION_SECRET not set, using generated key (sessions invalidated on restart)"
            );
            Ok(Key::generate())
        }
        Some(secret) if secret.len() < 32 => {
            tracing::warn!(
                secret_len = secret.len(),
                "APP_SESSION_SECRET is too short (<32 bytes), using generated key instead"
            );
            Ok(Key::generate())
        }
        Some(secret) => {
            // HKDF-SHA256 with a fixed salt and an info string for domain separation.
            let kdf = Hkdf::<Sha256>::new(Some(b"session-cookie-key"), secret.as_bytes());
            let mut key_material = [0u8; 64];
            kdf.expand(b"actix-session-signing-key", &mut key_material)
                .map_err(|e| anyhow::anyhow!("HKDF expand failed: {}", e))?;
            Ok(Key::from(&key_material))
        }
    }
}
/// Application entry point.
///
/// Startup sequence: load configuration, parse CLI arguments, initialise
/// tracing, connect to the database / Redis session store / cache, run
/// migrations, build the session key and `AppService`, start background tasks
/// (model sync, room workers), set up optional OTLP tracing and metrics, then
/// run the HTTP server until it stops. Afterwards the room workers are told to
/// shut down and awaited.
///
/// # Errors
/// Returns the first failure from any initialisation step, from binding the
/// listen address, or from the HTTP server itself.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let cfg = AppConfig::load();
    // Parse CLI arguments before touching any external system, so `--help` or
    // an invalid flag exits without opening connections or running migrations.
    let args = ServerArgs::parse();

    let log_level = cfg.log_level().unwrap_or_else(|_| "info".to_string());
    let otel_enabled = cfg.otel_enabled().unwrap_or(false);
    init_tracing_subscriber(&log_level, false);
    tracing::info!(
        app_name = %cfg.app_name().unwrap_or_default(),
        app_version = %cfg.app_version().unwrap_or_default(),
        "Starting application"
    );

    let db = AppDatabase::init(&cfg).await?;
    tracing::info!("Database connected");
    let redis_urls = cfg.redis_urls()?;
    let store: RedisClusterSessionStore = RedisClusterSessionStore::new(redis_urls).await?;
    tracing::info!("Redis connected");
    let cache = AppCache::init(&cfg).await?;
    tracing::info!("Cache initialized");
    run_migrations(&db).await?;

    let session_key = build_session_key(&cfg)?;
    let service = AppService::new(cfg.clone()).await?;
    tracing::info!("AppService initialized");
    let _model_sync_handle = service.clone().start_sync_task();
    // TODO: workspace module not yet wired — billing alert task pending
    // let _billing_alert_handle = service.clone().start_billing_alert_task();

    // Room workers run on their own task and receive the shutdown signal sent
    // after the HTTP server has stopped.
    let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
    let worker_service = service.clone();
    let worker_handle =
        tokio::spawn(async move { worker_service.start_room_workers(shutdown_rx).await });

    // Bound to a local so the value returned by `init_otlp` lives until `main` returns.
    let _otel_guard = if otel_enabled {
        let endpoint = cfg
            .otel_endpoint()
            .unwrap_or_else(|_| "http://localhost:4317".to_string());
        let service_name = cfg
            .otel_service_name()
            .unwrap_or_else(|_| "app".to_string());
        let service_version = cfg
            .otel_service_version()
            .unwrap_or_else(|_| "0.1.0".to_string());
        tracing::info!(endpoint = %endpoint, service = %service_name, "OTLP tracing enabled");
        observability::init_otlp(&endpoint, &service_name, &service_version, &log_level)
            .map_err(|e| anyhow::anyhow!("OTLP init failed: {}", e))?
    } else {
        None
    };

    let prometheus_handle = install_recorder();
    let prometheus_handle_data = web::Data::new(prometheus_handle.clone());
    let http_metrics = Arc::new(HttpMetrics::new());
    let http_snapshot: HttpSnapshotGuard = Arc::new(std::sync::RwLock::new(
        observability::HttpMetricsSnapshot::default(),
    ));
    spawn_http_metrics_poller(
        http_metrics.clone(),
        http_snapshot.clone(),
        std::time::Duration::from_secs(15),
    );
    let http_snapshot_data = web::Data::new(http_snapshot);

    // Metrics pusher: periodically push all metrics to apps/metrics aggregator
    if let Ok(push_url) = std::env::var("METRICS_PUSH_URL") {
        let pusher = MetricsPusher::new(&push_url, "app");
        pusher.spawn(
            http_metrics.clone(),
            Arc::new(prometheus_handle.clone()),
            std::time::Duration::from_secs(15),
        );
        tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
    }

    let bind_addr = args.bind.unwrap_or_else(|| "127.0.0.1:8080".to_string());
    tracing::info!(bind_addr = %bind_addr, "Listening");
    let http_metrics_server = http_metrics.clone();

    // Comma-separated origin list; blank entries are dropped. Falls back to the
    // local dev origin only when the variable is not set at all.
    let cors_origins: Vec<String> = cfg
        .env
        .get("CORS_ORIGINS")
        .map(|s| {
            s.split(',')
                .map(|s| s.trim().to_string())
                .filter(|s| !s.is_empty())
                .collect()
        })
        .unwrap_or_else(|| vec!["http://localhost:5173".to_string()]);

    // Secure cookies are the default; only the exact value "false" disables them.
    let cookie_secure = cfg
        .env
        .get("APP_COOKIE_SECURE")
        .map(|s| s != "false")
        .unwrap_or(true);
    tracing::info!(cookie_secure = cookie_secure, "Cookie secure mode");

    // Keep the server result instead of returning early with `?`, so the room
    // workers are still asked to stop when the server exits with an error.
    let server_result = HttpServer::new(move || {
        let mut cors = Cors::default();
        for origin in &cors_origins {
            cors = cors.allowed_origin(origin);
        }
        let cors = cors
            .allowed_methods(["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
            .allowed_headers([
                "Content-Type",
                "Authorization",
                "X-Requested-With",
                "Accept",
                "Origin",
            ])
            .supports_credentials()
            .max_age(3600);
        let security_headers = actix_web::middleware::DefaultHeaders::new()
            .add(("X-Content-Type-Options", "nosniff"))
            .add(("X-Frame-Options", "DENY"))
            .add(("Referrer-Policy", "strict-origin-when-cross-origin"));
        let session_mw = SessionMiddleware::builder(store.clone(), session_key.clone())
            .cookie_name("id".to_string())
            .cookie_path("/".to_string())
            .cookie_secure(cookie_secure)
            .cookie_http_only(true)
            .session_lifecycle(SessionLifecycle::PersistentSession(
                PersistentSession::default()
                    .session_ttl(Duration::days(30))
                    .session_ttl_extension_policy(TtlExtensionPolicy::OnEveryRequest),
            ))
            .build();
        let metrics_mw = MetricsMiddleware::new(http_metrics_server.clone());

        // NOTE: actix-web runs `wrap`ped middleware in reverse registration
        // order, so the tracing span middleware sees each request first and
        // CORS sees it last.
        App::new()
            .wrap(cors)
            .wrap(security_headers)
            .wrap(session_mw)
            .wrap(RequestLogger)
            .wrap(metrics_mw)
            .wrap(TracingSpanMiddleware::new())
            .app_data(web::Data::new(AppState {
                db: db.clone(),
                cache: cache.clone(),
            }))
            .app_data(web::Data::new(service.clone()))
            .app_data(web::Data::new(cfg.clone()))
            .app_data(web::Data::new(db.clone()))
            .app_data(web::Data::new(cache.clone()))
            .app_data(http_snapshot_data.clone())
            .app_data(prometheus_handle_data.clone())
            .route("/robots.txt", web::get().to(robots::robots))
            .route("/sitemap.xml", web::get().to(sidemap::sitemap))
            .service(
                web::scope("/sidemap")
                    .route("", web::get().to(sidemap::sitemap))
                    .route("/static", web::get().to(sidemap::sitemap_static))
                    .route("/users", web::get().to(sidemap::sitemap_users))
                    .route("/projects", web::get().to(sidemap::sitemap_projects))
                    .route("/repos", web::get().to(sidemap::sitemap_repos)),
            )
            .route("/health", web::get().to(health_check))
            .route("/metrics", web::get().to(prometheus_handler))
            .configure(api::route::init_routes)
    })
    .bind(&bind_addr)?
    .run()
    .await;

    tracing::info!("Server stopped, shutting down room workers");
    // Best effort: a send error only means no worker is listening any more.
    let _ = shutdown_tx.send(());
    let _ = worker_handle.await;
    tracing::info!("Room workers stopped");

    // Surface a server failure only after the workers have been stopped.
    server_result?;
    Ok(())
}
/// Applies all pending migrations through the writer connection of `db`,
/// logging before and after.
///
/// # Errors
/// Returns an error carrying the debug representation of the migration failure.
async fn run_migrations(db: &AppDatabase) -> anyhow::Result<()> {
    tracing::info!("Running database migrations...");
    if let Err(e) = Migrator::up(db.writer(), None).await {
        return Err(anyhow::anyhow!("Migration failed: {:?}", e));
    }
    tracing::info!("Migrations completed");
    Ok(())
}
/// `GET /health` handler.
///
/// Probes the database and the cache (both probes always run) and answers
/// `200` with `"status": "ok"` when both succeed, otherwise `503` with
/// `"status": "unhealthy"`. The body reports `"ok"` or `"error"` for each of
/// `db` and `cache`.
async fn health_check(state: web::Data<AppState>) -> HttpResponse {
    let db_ok = db_ping(&state.db).await;
    let cache_ok = cache_ping(&state.cache).await;

    let label = |ok: bool| if ok { "ok" } else { "error" };
    let (mut response, status) = match (db_ok, cache_ok) {
        (true, true) => (HttpResponse::Ok(), "ok"),
        _ => (HttpResponse::ServiceUnavailable(), "unhealthy"),
    };

    response.json(serde_json::json!({
        "status": status,
        "db": label(db_ok),
        "cache": label(cache_ok),
    }))
}
/// Returns `true` only when `SELECT 1` succeeds on both the writer and the
/// reader connection. The reader is probed even if the writer probe failed.
async fn db_ping(db: &AppDatabase) -> bool {
    let writer_probe = db.writer().execute_unprepared("SELECT 1").await;
    let reader_probe = db.reader().execute_unprepared("SELECT 1").await;
    matches!((writer_probe, reader_probe), (Ok(_), Ok(_)))
}
/// Returns `true` when a connection can be obtained from the cache.
async fn cache_ping(cache: &AppCache) -> bool {
    matches!(cache.conn().await, Ok(_))
}