From 734e1c4cc83fe74de310b16961befcc2006ab11f Mon Sep 17 00:00:00 2001 From: zhenyi <434836402@qq.com> Date: Mon, 1 Jun 2026 22:04:25 +0800 Subject: [PATCH] refactor: update infrastructure libs (config, db, cache, queue, storage, migrate) --- lib/cache/app.rs | 319 ++++++++++++++ lib/cache/lib.rs | 220 +--------- lib/config/app.rs | 7 + lib/config/app_config.rs | 39 ++ lib/config/lib.rs | 42 +- lib/config/logs.rs | 2 +- lib/db/database.rs | 143 +++++- lib/migrate/lib.rs | 267 ++++++++++++ .../sql/room/room_attachment_down_02.sql | 2 + .../sql/room/room_attachment_up_02.sql | 2 + lib/migrate/sql/room/room_mention_down_02.sql | 1 + lib/migrate/sql/room/room_mention_up_02.sql | 1 + lib/migrate/src/main.rs | 407 +----------------- lib/queue/consumer.rs | 108 ++++- lib/queue/producer.rs | 61 ++- lib/storage/lib.rs | 155 +++++-- 16 files changed, 1049 insertions(+), 727 deletions(-) create mode 100644 lib/cache/app.rs create mode 100644 lib/config/app_config.rs create mode 100644 lib/migrate/lib.rs create mode 100644 lib/migrate/sql/room/room_attachment_down_02.sql create mode 100644 lib/migrate/sql/room/room_attachment_up_02.sql create mode 100644 lib/migrate/sql/room/room_mention_down_02.sql create mode 100644 lib/migrate/sql/room/room_mention_up_02.sql diff --git a/lib/cache/app.rs b/lib/cache/app.rs new file mode 100644 index 0000000..253b2ab --- /dev/null +++ b/lib/cache/app.rs @@ -0,0 +1,319 @@ +use std::time::Duration; + +use track::CounterVec; + +use crate::{ + cluster::{ClusterCache, ClusterCacheConfig}, + error::{CacheError, CacheResult}, + local::{LocalCacheConfig, MokaCache}, +}; + +// ============================================================================ +// Configuration +// ============================================================================ + +#[derive(Clone, Debug)] +pub struct AppCacheConfig { + pub local: LocalCacheConfig, + pub cluster: Option, + pub default_ttl: Option, + pub cluster_write_through: bool, +} + +impl Default for AppCacheConfig { + fn default() -> Self { + Self { + local: LocalCacheConfig::default(), + cluster: None, + default_ttl: Some(Duration::from_secs(300)), + cluster_write_through: true, + } + } +} + +impl TryFrom<&config::AppConfig> for AppCacheConfig { + type Error = CacheError; + + fn try_from(config: &config::AppConfig) -> Result { + let local = LocalCacheConfig { + max_capacity: config + .cache_local_max_capacity() + .map_err(|error| CacheError::Config(error.to_string()))?, + time_to_live: config + .cache_local_ttl() + .map_err(|error| CacheError::Config(error.to_string()))?, + time_to_idle: config + .cache_local_tti() + .map_err(|error| CacheError::Config(error.to_string()))?, + }; + + let cluster = if config + .cache_cluster_enabled() + .map_err(|error| CacheError::Config(error.to_string()))? + { + Some(ClusterCacheConfig { + urls: config + .redis_urls() + .map_err(|error| CacheError::Config(error.to_string()))?, + key_prefix: config.cache_cluster_key_prefix(), + command_timeout: config + .cache_cluster_command_timeout() + .map_err(|error| CacheError::Config(error.to_string()))?, + }) + } else { + None + }; + + Ok(Self { + local, + cluster, + default_ttl: config + .cache_default_ttl() + .map_err(|error| CacheError::Config(error.to_string()))?, + cluster_write_through: config + .cache_cluster_write_through() + .map_err(|error| CacheError::Config(error.to_string()))?, + }) + } +} + +// ============================================================================ +// AppCache +// ============================================================================ + +#[derive(Clone)] +pub struct AppCache { + pub local: MokaCache, + pub cluster: Option, + default_ttl: Option, + cluster_write_through: bool, + metrics: Option, +} + +impl AppCache { + #[tracing::instrument(skip(config))] + pub async fn init(config: AppCacheConfig) -> CacheResult { + let local = MokaCache::with_config(config.local); + let cluster = match config.cluster { + Some(cluster) => Some(match ClusterCache::connect(cluster).await { + Ok(cluster) => cluster, + Err(e) => { + tracing::error!(error = %e, "failed to connect to cache cluster"); + return Err(e); + } + }), + None => None, + }; + + tracing::info!(has_cluster = cluster.is_some(), "cache initialized"); + Ok(Self { + local, + cluster, + default_ttl: config.default_ttl, + cluster_write_through: config.cluster_write_through, + metrics: None, + }) + } + + pub fn local_only(local: MokaCache) -> Self { + Self { + local, + cluster: None, + default_ttl: None, + cluster_write_through: false, + metrics: None, + } + } + + /// Attach a metrics registry for recording cache counters. + pub fn set_metrics(&mut self, registry: track::MetricsRegistry) { + self.metrics = Some(registry); + } + + #[tracing::instrument(skip(self), fields(cache.key = %key))] + pub async fn get(&self, key: &str) -> CacheResult> + where + T: serde::Serialize + serde::de::DeserializeOwned, + { + if let Some(value) = self.local.get(key).await? { + tracing::debug!("cache hit (local)"); + self.record_hit("local"); + return Ok(Some(value)); + } + + let Some(cluster) = &self.cluster else { + tracing::debug!("cache miss"); + self.record_miss(); + return Ok(None); + }; + + let value = cluster.get::(key).await?; + if let Some(value) = &value { + self.local.set(key, value).await?; + tracing::debug!("cache hit (cluster)"); + self.record_hit("cluster"); + } else { + tracing::debug!("cache miss"); + self.record_miss(); + } + Ok(value) + } + + #[tracing::instrument(skip(self, value), fields(cache.key = %key))] + pub async fn set(&self, key: &str, value: &T) -> CacheResult<()> + where + T: serde::Serialize + ?Sized, + { + self.local.set(key, value).await?; + if self.cluster_write_through + && let Some(cluster) = &self.cluster + { + cluster.set(key, value, self.default_ttl).await?; + } + self.record_set(); + Ok(()) + } + + pub async fn set_with_ttl( + &self, + key: &str, + value: &T, + ttl: std::time::Duration, + ) -> CacheResult<()> + where + T: serde::Serialize + ?Sized, + { + self.local.set(key, value).await?; + if self.cluster_write_through + && let Some(cluster) = &self.cluster + { + cluster.set(key, value, Some(ttl)).await?; + } + Ok(()) + } + + #[tracing::instrument(skip(self), fields(cache.key = %key))] + pub async fn remove(&self, key: &str) -> CacheResult<()> { + self.local.remove(key).await; + if let Some(cluster) = &self.cluster { + cluster.remove(key).await?; + } + self.record_remove(); + Ok(()) + } + + fn record_hit(&self, tier: &str) { + if let Some(reg) = &self.metrics { + cache_hits_vec(reg).with_label_values(&[tier]).inc(); + } + } + + fn record_miss(&self) { + if let Some(reg) = &self.metrics { + cache_misses_vec(reg).with_label_values(&[]).inc(); + } + } + + fn record_set(&self) { + if let Some(reg) = &self.metrics { + cache_sets_vec(reg).with_label_values(&[]).inc(); + } + } + + fn record_remove(&self) { + if let Some(reg) = &self.metrics { + cache_removes_vec(reg).with_label_values(&[]).inc(); + } + } + + pub async fn delete_pattern(&self, pattern: &str) -> CacheResult { + let pattern = pattern.to_string(); + let local_pattern = pattern.clone(); + self.local.invalidate_entries_if(move |key| { + simple_glob_match(&local_pattern, key) + }); + + let mut removed = 0u64; + if let Some(cluster) = &self.cluster { + removed = cluster.delete_pattern(&pattern).await?; + } + Ok(removed) + } + + pub async fn ping_cluster(&self) -> CacheResult<()> { + if let Some(cluster) = &self.cluster { + cluster.ping().await?; + } + Ok(()) + } + + pub fn conn(&self) -> Option { + self.cluster.as_ref().map(|c| c.conn()) + } +} + +fn cache_hits_vec(registry: &track::MetricsRegistry) -> CounterVec { + registry + .register_counter_vec("cache_hits_total", "Total cache hits", &["tier"]) + .expect("failed to register cache_hits_total") +} + +fn cache_misses_vec(registry: &track::MetricsRegistry) -> CounterVec { + registry + .register_counter_vec("cache_misses_total", "Total cache misses", &[]) + .expect("failed to register cache_misses_total") +} + +fn cache_sets_vec(registry: &track::MetricsRegistry) -> CounterVec { + registry + .register_counter_vec( + "cache_sets_total", + "Total cache set operations", + &[], + ) + .expect("failed to register cache_sets_total") +} + +fn cache_removes_vec(registry: &track::MetricsRegistry) -> CounterVec { + registry + .register_counter_vec( + "cache_removes_total", + "Total cache remove operations", + &[], + ) + .expect("failed to register cache_removes_total") +} + +// ============================================================================ +// Helpers +// ============================================================================ + +fn simple_glob_match(pattern: &str, key: &str) -> bool { + let p = pattern.as_bytes(); + let k = key.as_bytes(); + let (mut pi, mut ki) = (0usize, 0usize); + + let mut backtrack_p: Option = None; + let mut backtrack_k: usize = 0; + + loop { + if pi < p.len() && ki < k.len() && (p[pi] == b'?' || p[pi] == k[ki]) { + pi += 1; + ki += 1; + } else if pi < p.len() && p[pi] == b'*' { + backtrack_p = Some(pi); + backtrack_k = ki; + pi += 1; + } else if let Some(saved_pi) = backtrack_p { + backtrack_k += 1; + ki = backtrack_k; + pi = saved_pi + 1; + } else { + return pi == p.len() && ki == k.len(); + } + + if pi == p.len() && ki == k.len() { + return true; + } + } +} diff --git a/lib/cache/lib.rs b/lib/cache/lib.rs index 6d2cc38..8800318 100644 --- a/lib/cache/lib.rs +++ b/lib/cache/lib.rs @@ -1,227 +1,11 @@ +pub mod app; pub mod cluster; pub mod error; pub mod local; -use std::time::Duration; - pub use crate::{ + app::{AppCache, AppCacheConfig}, cluster::{ClusterCache, ClusterCacheConfig}, error::{CacheError, CacheResult}, local::{LocalCacheConfig, MokaCache}, }; - -#[derive(Clone, Debug)] -pub struct AppCacheConfig { - pub local: LocalCacheConfig, - pub cluster: Option, - pub default_ttl: Option, - pub cluster_write_through: bool, -} - -impl Default for AppCacheConfig { - fn default() -> Self { - Self { - local: LocalCacheConfig::default(), - cluster: None, - default_ttl: Some(Duration::from_secs(300)), - cluster_write_through: true, - } - } -} - -#[derive(Clone)] -pub struct AppCache { - pub local: MokaCache, - pub cluster: Option, - default_ttl: Option, - cluster_write_through: bool, -} - -impl AppCache { - pub async fn init(config: AppCacheConfig) -> CacheResult { - let local = MokaCache::with_config(config.local); - let cluster = match config.cluster { - Some(cluster) => Some(match ClusterCache::connect(cluster).await { - Ok(cluster) => cluster, - Err(e) => { - println!("cache:init:error with: {}", e); - return Err(e); - } - }), - None => None, - }; - - Ok(Self { - local, - cluster, - default_ttl: config.default_ttl, - cluster_write_through: config.cluster_write_through, - }) - } - - pub fn local_only(local: MokaCache) -> Self { - Self { - local, - cluster: None, - default_ttl: None, - cluster_write_through: false, - } - } - - pub async fn get(&self, key: &str) -> CacheResult> - where - T: serde::Serialize + serde::de::DeserializeOwned, - { - if let Some(value) = self.local.get(key).await? { - return Ok(Some(value)); - } - - let Some(cluster) = &self.cluster else { - return Ok(None); - }; - - let value = cluster.get::(key).await?; - if let Some(value) = &value { - self.local.set(key, value).await?; - } - Ok(value) - } - - pub async fn set(&self, key: &str, value: &T) -> CacheResult<()> - where - T: serde::Serialize + ?Sized, - { - self.local.set(key, value).await?; - if self.cluster_write_through - && let Some(cluster) = &self.cluster - { - cluster.set(key, value, self.default_ttl).await?; - } - Ok(()) - } - - pub async fn set_with_ttl( - &self, - key: &str, - value: &T, - ttl: std::time::Duration, - ) -> CacheResult<()> - where - T: serde::Serialize + ?Sized, - { - self.local.set(key, value).await?; - if self.cluster_write_through - && let Some(cluster) = &self.cluster - { - cluster.set(key, value, Some(ttl)).await?; - } - Ok(()) - } - - pub async fn remove(&self, key: &str) -> CacheResult<()> { - self.local.remove(key).await; - if let Some(cluster) = &self.cluster { - cluster.remove(key).await?; - } - Ok(()) - } - pub async fn delete_pattern(&self, pattern: &str) -> CacheResult { - let pattern = pattern.to_string(); - let local_pattern = pattern.clone(); - self.local.invalidate_entries_if(move |key| { - simple_glob_match(&local_pattern, key) - }); - - let mut removed = 0u64; - if let Some(cluster) = &self.cluster { - removed = cluster.delete_pattern(&pattern).await?; - } - Ok(removed) - } - - pub async fn ping_cluster(&self) -> CacheResult<()> { - if let Some(cluster) = &self.cluster { - cluster.ping().await?; - } - Ok(()) - } - - pub fn conn(&self) -> Option { - self.cluster.as_ref().map(|c| c.conn()) - } -} - -impl TryFrom<&config::AppConfig> for AppCacheConfig { - type Error = CacheError; - - fn try_from(config: &config::AppConfig) -> Result { - let local = LocalCacheConfig { - max_capacity: config - .cache_local_max_capacity() - .map_err(|error| CacheError::Config(error.to_string()))?, - time_to_live: config - .cache_local_ttl() - .map_err(|error| CacheError::Config(error.to_string()))?, - time_to_idle: config - .cache_local_tti() - .map_err(|error| CacheError::Config(error.to_string()))?, - }; - - let cluster = if config - .cache_cluster_enabled() - .map_err(|error| CacheError::Config(error.to_string()))? - { - Some(ClusterCacheConfig { - urls: config - .redis_urls() - .map_err(|error| CacheError::Config(error.to_string()))?, - key_prefix: config.cache_cluster_key_prefix(), - command_timeout: config - .cache_cluster_command_timeout() - .map_err(|error| CacheError::Config(error.to_string()))?, - }) - } else { - None - }; - - Ok(Self { - local, - cluster, - default_ttl: config - .cache_default_ttl() - .map_err(|error| CacheError::Config(error.to_string()))?, - cluster_write_through: config - .cache_cluster_write_through() - .map_err(|error| CacheError::Config(error.to_string()))?, - }) - } -} -fn simple_glob_match(pattern: &str, key: &str) -> bool { - let p = pattern.as_bytes(); - let k = key.as_bytes(); - let (mut pi, mut ki) = (0usize, 0usize); - - let mut backtrack_p: Option = None; - let mut backtrack_k: usize = 0; - - loop { - if pi < p.len() && ki < k.len() && (p[pi] == b'?' || p[pi] == k[ki]) { - pi += 1; - ki += 1; - } else if pi < p.len() && p[pi] == b'*' { - backtrack_p = Some(pi); - backtrack_k = ki; - pi += 1; - } else if let Some(saved_pi) = backtrack_p { - backtrack_k += 1; - ki = backtrack_k; - pi = saved_pi + 1; - } else { - return pi == p.len() && ki == k.len(); - } - - if pi == p.len() && ki == k.len() { - return true; - } - } -} diff --git a/lib/config/app.rs b/lib/config/app.rs index 0335865..6d9777f 100644 --- a/lib/config/app.rs +++ b/lib/config/app.rs @@ -28,6 +28,13 @@ impl AppConfig { Ok(8080) } + pub fn email_health_port(&self) -> u16 { + self.env + .get("APP_EMAIL_HEALTH_PORT") + .and_then(|port| port.parse::().ok()) + .unwrap_or(8083) + } + pub fn session_secret(&self) -> anyhow::Result { if let Some(secret) = self.env.get("APP_SESSION_SECRET") { return Ok(secret.to_string()); diff --git a/lib/config/app_config.rs b/lib/config/app_config.rs new file mode 100644 index 0000000..f3a1405 --- /dev/null +++ b/lib/config/app_config.rs @@ -0,0 +1,39 @@ +use std::{collections::HashMap, sync::OnceLock}; + +pub static GLOBAL_CONFIG: OnceLock = OnceLock::new(); + +#[derive(Clone, Debug)] +pub struct AppConfig { + pub env: HashMap, +} + +impl AppConfig { + const ENV_FILES: &'static [&'static str] = &[".env", ".env.local"]; + + pub fn load() -> AppConfig { + let mut env = HashMap::new(); + for env_file in AppConfig::ENV_FILES { + if let Err(e) = dotenvy::from_path(env_file) { + tracing::debug!(file = %env_file, error = %e, "dotenv load skipped"); + } + if let Ok(env_file_content) = std::fs::read_to_string(env_file) { + for line in env_file_content.lines() { + if let Some((key, value)) = line.split_once('=') { + env.insert(key.to_string(), value.to_string()); + } + } + } + } + env = env.into_iter().chain(std::env::vars()).collect(); + let this = AppConfig { env }; + if let Some(config) = GLOBAL_CONFIG.get() { + config.clone() + } else { + let _ = GLOBAL_CONFIG.set(this); + GLOBAL_CONFIG + .get() + .expect("global config should be set after load") + .clone() + } + } +} diff --git a/lib/config/lib.rs b/lib/config/lib.rs index 5426616..0f030ad 100644 --- a/lib/config/lib.rs +++ b/lib/config/lib.rs @@ -1,44 +1,6 @@ -use std::{collections::HashMap, sync::OnceLock}; - -pub static GLOBAL_CONFIG: OnceLock = OnceLock::new(); - -#[derive(Clone, Debug)] -pub struct AppConfig { - pub env: HashMap, -} - -impl AppConfig { - const ENV_FILES: &'static [&'static str] = &[".env", ".env.local"]; - pub fn load() -> AppConfig { - let mut env = HashMap::new(); - for env_file in AppConfig::ENV_FILES { - if let Err(e) = dotenvy::from_path(env_file) { - tracing::debug!(file = %env_file, error = %e, "dotenv load skipped"); - } - if let Ok(env_file_content) = std::fs::read_to_string(env_file) { - for line in env_file_content.lines() { - if let Some((key, value)) = line.split_once('=') { - env.insert(key.to_string(), value.to_string()); - } - } - } - } - env = env.into_iter().chain(std::env::vars()).collect(); - let this = AppConfig { env }; - if GLOBAL_CONFIG.get().is_some() { - GLOBAL_CONFIG.get().unwrap().clone() - } else { - let _ = GLOBAL_CONFIG.set(this); - GLOBAL_CONFIG - .get() - .expect("global config should be set after load") - .clone() - } - } -} - pub mod ai; pub mod app; +pub mod app_config; pub mod auth; pub mod avatar; pub mod cache; @@ -57,3 +19,5 @@ pub mod redis; pub mod smtp; pub mod ssh; pub mod storage; + +pub use app_config::{AppConfig, GLOBAL_CONFIG}; diff --git a/lib/config/logs.rs b/lib/config/logs.rs index b3e5b98..9746fd6 100644 --- a/lib/config/logs.rs +++ b/lib/config/logs.rs @@ -61,7 +61,7 @@ impl AppConfig { if let Some(endpoint) = self.env.get("APP_OTEL_ENDPOINT") { return Ok(endpoint.to_string()); } - Ok("http://localhost:5080/api/default/v1/traces".to_string()) + Ok("http://localhost:4318".to_string()) } pub fn otel_service_name(&self) -> anyhow::Result { diff --git a/lib/db/database.rs b/lib/db/database.rs index 7c25bd7..ce10a05 100644 --- a/lib/db/database.rs +++ b/lib/db/database.rs @@ -7,6 +7,7 @@ use sqlx::{ PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow, }, }; +use track::{CounterVec, HistogramVec}; use crate::{ route::{SqlRoute, route_sql}, @@ -17,9 +18,11 @@ use crate::{ pub struct AppDatabase { db_write: PgPool, db_read: Option, + metrics: Option, } impl AppDatabase { + #[tracing::instrument(skip(cfg))] pub async fn init(cfg: &AppConfig) -> anyhow::Result { let db_url = cfg.database_url()?; let max_connections = cfg.database_max_connections()?; @@ -69,7 +72,15 @@ impl AppDatabase { None }; - Ok(Self { db_write, db_read }) + Ok(Self { + db_write, + db_read, + metrics: None, + }) + } + + pub fn set_metrics(&mut self, registry: track::MetricsRegistry) { + self.metrics = Some(registry); } pub fn writer(&self) -> &PgPool { @@ -87,11 +98,14 @@ impl AppDatabase { } } + #[tracing::instrument(skip(self), fields(sql.route = "write"))] pub async fn begin(&self) -> Result, sqlx::Error> { let txn = self.db_write.begin().await?; + tracing::debug!("db transaction started"); Ok(AppTransaction { inner: txn }) } + #[tracing::instrument(skip(self), fields(sql.route = "read"))] pub async fn begin_read_only( &self, ) -> Result, sqlx::Error> { @@ -101,9 +115,11 @@ impl AppDatabase { .execute(&mut *txn) .await?; + tracing::debug!("db read-only transaction started"); Ok(AppTransaction { inner: txn }) } + #[tracing::instrument(skip(self, sql), fields(sql.kind = "execute"))] pub async fn execute( &self, sql: &str, @@ -111,18 +127,38 @@ impl AppDatabase { self.execute_with_args(sql, PgArguments::default()).await } + #[tracing::instrument(skip(self, sql, args), fields(sql.kind = "execute"))] pub async fn execute_with_args( &self, sql: &str, args: PgArguments, ) -> Result { let pool = self.route_pool(sql); + let start = std::time::Instant::now(); - sqlx::query_with(AssertSqlSafe(sql.to_owned()), args) + let result = sqlx::query_with(AssertSqlSafe(sql.to_owned()), args) .execute(pool) - .await + .await; + + let kind = if sql.trim_start().to_uppercase().starts_with("INSERT") { + "insert" + } else if sql.trim_start().to_uppercase().starts_with("UPDATE") { + "update" + } else if sql.trim_start().to_uppercase().starts_with("DELETE") { + "delete" + } else { + "execute" + }; + self.record_query( + kind, + self.route_label(sql), + start.elapsed(), + result.is_ok(), + ); + result } + #[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_one"))] pub async fn fetch_one(&self, sql: &str) -> Result where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, @@ -130,6 +166,7 @@ impl AppDatabase { self.fetch_one_with_args(sql, PgArguments::default()).await } + #[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_one"))] pub async fn fetch_one_with_args( &self, sql: &str, @@ -139,12 +176,23 @@ impl AppDatabase { for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { let pool = self.route_pool(sql); + let start = std::time::Instant::now(); - sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) - .fetch_one(pool) - .await + let result = + sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) + .fetch_one(pool) + .await; + + self.record_query( + "select", + self.route_label(sql), + start.elapsed(), + result.is_ok(), + ); + result } + #[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_optional"))] pub async fn fetch_optional( &self, sql: &str, @@ -156,6 +204,7 @@ impl AppDatabase { .await } + #[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_optional"))] pub async fn fetch_optional_with_args( &self, sql: &str, @@ -165,12 +214,23 @@ impl AppDatabase { for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { let pool = self.route_pool(sql); + let start = std::time::Instant::now(); - sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) - .fetch_optional(pool) - .await + let result = + sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) + .fetch_optional(pool) + .await; + + self.record_query( + "select", + self.route_label(sql), + start.elapsed(), + result.is_ok(), + ); + result } + #[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_all"))] pub async fn fetch_all(&self, sql: &str) -> Result, sqlx::Error> where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, @@ -178,6 +238,7 @@ impl AppDatabase { self.fetch_all_with_args(sql, PgArguments::default()).await } + #[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_all"))] pub async fn fetch_all_with_args( &self, sql: &str, @@ -187,10 +248,45 @@ impl AppDatabase { for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { let pool = self.route_pool(sql); + let start = std::time::Instant::now(); - sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) - .fetch_all(pool) - .await + let result = + sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) + .fetch_all(pool) + .await; + + self.record_query( + "select", + self.route_label(sql), + start.elapsed(), + result.is_ok(), + ); + result + } + + fn route_label(&self, sql: &str) -> &str { + match route_sql(sql) { + SqlRoute::Write => "write", + SqlRoute::Read => "read", + } + } + + fn record_query( + &self, + kind: &str, + route: &str, + duration: Duration, + success: bool, + ) { + if let Some(reg) = &self.metrics { + let status = if success { "success" } else { "error" }; + db_queries_vec(reg) + .with_label_values(&[kind, route, status]) + .inc(); + db_query_duration_vec(reg) + .with_label_values(&[kind, route]) + .observe(duration.as_secs_f64()); + } } } @@ -230,3 +326,26 @@ async fn build_pool( pool_options.connect_with(options).await } + +fn db_queries_vec(registry: &track::MetricsRegistry) -> CounterVec { + registry + .register_counter_vec( + "db_queries_total", + "Total database queries", + &["kind", "route", "status"], + ) + .expect("failed to register db_queries_total") +} + +fn db_query_duration_vec(registry: &track::MetricsRegistry) -> HistogramVec { + registry + .register_histogram_vec( + "db_query_duration_seconds", + "DB query duration in seconds", + &["kind", "route"], + vec![ + 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, + ], + ) + .expect("failed to register db_query_duration_seconds") +} diff --git a/lib/migrate/lib.rs b/lib/migrate/lib.rs new file mode 100644 index 0000000..0ff3cb4 --- /dev/null +++ b/lib/migrate/lib.rs @@ -0,0 +1,267 @@ +use std::collections::{BTreeMap, HashMap, VecDeque}; +use std::path::{Path, PathBuf}; + +use anyhow::{Context, Result, bail}; +use sqlx::PgPool; + +#[derive(Debug, Clone, PartialEq, Eq)] +struct Migration { + domain: String, + table: String, + version: u32, + direction: MigrationDir, + path: PathBuf, + depends_on: Vec, +} + +impl Ord for Migration { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + (&self.domain, &self.table, self.version, &self.direction).cmp(&( + &other.domain, + &other.table, + other.version, + &other.direction, + )) + } +} + +impl PartialOrd for Migration { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +enum MigrationDir { + Up, + Down, +} + +pub async fn run_up(pool: &PgPool) -> Result<()> { + let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql"); + ensure_migrations_table(pool).await?; + let all = discover_migrations(&sql_root)?; + let applied = applied_set(pool).await?; + + let mut up_migrations: Vec<_> = all + .into_iter() + .filter(|m| m.direction == MigrationDir::Up) + .filter(|m| { + !applied.contains_key(&( + m.domain.clone(), + m.table.clone(), + m.version, + )) + }) + .collect(); + + if up_migrations.is_empty() { + tracing::info!("All migrations are already applied."); + return Ok(()); + } + + topo_sort(&mut up_migrations)?; + + for m in &up_migrations { + let sql = std::fs::read_to_string(&m.path) + .context(format!("Failed to read {:?}", m.path))?; + let checksum = compute_checksum(&sql); + + tracing::info!(domain = %m.domain, table = %m.table, version = m.version, "applying migration"); + exec_sql(pool, &sql).await?; + record_migration(pool, m, &checksum).await?; + } + + tracing::info!("Applied {} migration(s).", up_migrations.len()); + Ok(()) +} + +fn discover_migrations(sql_root: &Path) -> Result> { + let mut migrations = Vec::new(); + + if !sql_root.exists() { + bail!("SQL directory not found: {}", sql_root.display()); + } + + for dir_entry in std::fs::read_dir(sql_root)? { + let dir = dir_entry?; + if !dir.file_type()?.is_dir() { + continue; + } + let domain = dir.file_name().to_string_lossy().to_string(); + + for file_entry in std::fs::read_dir(dir.path())? { + let file = file_entry?; + let path = file.path(); + if path.extension().and_then(|e| e.to_str()) != Some("sql") { + continue; + } + + let stem = path + .file_stem() + .and_then(|s| s.to_str()) + .context("Invalid filename")?; + + let (table, direction, version) = parse_migration_stem(stem)?; + + let content = std::fs::read_to_string(&path) + .context(format!("Failed to read {path:?}"))?; + let depends_on = parse_depends_on(&content); + + migrations.push(Migration { + domain: domain.clone(), + table, + version, + direction, + path, + depends_on, + }); + } + } + + migrations.sort(); + Ok(migrations) +} + +fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> { + if let Some(pos) = stem.rfind("_up_") { + let table = stem[..pos].to_string(); + let version = stem[pos + 4..] + .parse::() + .context("Invalid version number")?; + Ok((table, MigrationDir::Up, version)) + } else if let Some(pos) = stem.rfind("_down_") { + let table = stem[..pos].to_string(); + let version = stem[pos + 6..] + .parse::() + .context("Invalid version number")?; + Ok((table, MigrationDir::Down, version)) + } else { + bail!("Migration filename must contain _up_ or _down_: {stem}"); + } +} + +async fn ensure_migrations_table(pool: &PgPool) -> Result<()> { + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS _sql_migrations ( + domain TEXT NOT NULL, + table_name TEXT NOT NULL, + version INTEGER NOT NULL, + applied_at TIMESTAMPTZ NOT NULL DEFAULT now(), + checksum TEXT NOT NULL DEFAULT '', + PRIMARY KEY (domain, table_name, version) + ) + "#, + ) + .execute(pool) + .await?; + Ok(()) +} + +async fn applied_set( + pool: &PgPool, +) -> Result> { + let rows: Vec<(String, String, i32, String)> = sqlx::query_as( + "SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version", + ) + .fetch_all(pool) + .await?; + + Ok(rows + .into_iter() + .map(|(d, t, v, c)| ((d, t, v as u32), c)) + .collect()) +} + +async fn record_migration( + pool: &PgPool, + m: &Migration, + checksum: &str, +) -> Result<()> { + sqlx::query( + "INSERT INTO _sql_migrations (domain, table_name, version, checksum) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING", + ) + .bind(&m.domain) + .bind(&m.table) + .bind(m.version as i32) + .bind(checksum) + .execute(pool) + .await?; + Ok(()) +} + +fn compute_checksum(content: &str) -> String { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + content.hash(&mut hasher); + format!("{:x}", hasher.finish()) +} + +fn parse_depends_on(content: &str) -> Vec { + content + .lines() + .filter_map(|line| { + let line = line.trim(); + line.strip_prefix("-- depends_on:").map(|deps| { + deps.split(',') + .map(|d| d.trim().to_string()) + .filter(|d| !d.is_empty()) + .collect::>() + }) + }) + .flatten() + .collect() +} + +fn topo_sort(migrations: &mut [Migration]) -> Result<()> { + let table_to_idx: HashMap = migrations + .iter() + .enumerate() + .map(|(i, m)| (m.table.clone(), i)) + .collect(); + + let n = migrations.len(); + let mut in_degree = vec![0u32; n]; + let mut adj: Vec> = vec![Vec::new(); n]; + + for (i, m) in migrations.iter().enumerate() { + for dep in &m.depends_on { + if let Some(&j) = table_to_idx.get(dep) { + adj[j].push(i); + in_degree[i] += 1; + } + } + } + + let mut queue: VecDeque = + (0..n).filter(|&i| in_degree[i] == 0).collect(); + let mut order = Vec::with_capacity(n); + while let Some(i) = queue.pop_front() { + order.push(i); + for &next in &adj[i] { + in_degree[next] -= 1; + if in_degree[next] == 0 { + queue.push_back(next); + } + } + } + + if order.len() != n { + bail!("Circular dependency detected among migrations"); + } + + let original: Vec = migrations.iter().cloned().collect(); + for (slot, &idx) in order.iter().enumerate() { + migrations[slot] = original[idx].clone(); + } + + Ok(()) +} + +async fn exec_sql(pool: &PgPool, sql: &str) -> Result<()> { + let s: &'static str = Box::leak(sql.to_owned().into_boxed_str()); + sqlx::raw_sql(s).execute(pool).await?; + Ok(()) +} diff --git a/lib/migrate/sql/room/room_attachment_down_02.sql b/lib/migrate/sql/room/room_attachment_down_02.sql new file mode 100644 index 0000000..1714a9b --- /dev/null +++ b/lib/migrate/sql/room/room_attachment_down_02.sql @@ -0,0 +1,2 @@ +ALTER TABLE room_attachment ALTER COLUMN message SET NOT NULL; +ALTER TABLE room_attachment ALTER COLUMN seq SET NOT NULL; diff --git a/lib/migrate/sql/room/room_attachment_up_02.sql b/lib/migrate/sql/room/room_attachment_up_02.sql new file mode 100644 index 0000000..1bf49f5 --- /dev/null +++ b/lib/migrate/sql/room/room_attachment_up_02.sql @@ -0,0 +1,2 @@ +ALTER TABLE room_attachment ALTER COLUMN message DROP NOT NULL; +ALTER TABLE room_attachment ALTER COLUMN seq DROP NOT NULL; diff --git a/lib/migrate/sql/room/room_mention_down_02.sql b/lib/migrate/sql/room/room_mention_down_02.sql new file mode 100644 index 0000000..8c126f9 --- /dev/null +++ b/lib/migrate/sql/room/room_mention_down_02.sql @@ -0,0 +1 @@ +ALTER TABLE room_mention ALTER COLUMN target_id TYPE UUID USING target_id::UUID; diff --git a/lib/migrate/sql/room/room_mention_up_02.sql b/lib/migrate/sql/room/room_mention_up_02.sql new file mode 100644 index 0000000..017758e --- /dev/null +++ b/lib/migrate/sql/room/room_mention_up_02.sql @@ -0,0 +1 @@ +ALTER TABLE room_mention ALTER COLUMN target_id TYPE TEXT; diff --git a/lib/migrate/src/main.rs b/lib/migrate/src/main.rs index 82dca78..e6f3a4a 100644 --- a/lib/migrate/src/main.rs +++ b/lib/migrate/src/main.rs @@ -1,59 +1,12 @@ -use anyhow::{Context, Result, bail}; -use clap::{Parser, Subcommand}; +use anyhow::Result; +use clap::Parser; use sqlx::postgres::PgPoolOptions; -use std::collections::{BTreeMap, HashMap, VecDeque}; -use std::path::{Path, PathBuf}; -use tracing::info; #[derive(Parser)] #[command(name = "migrate", about = "Database migration tool")] struct Cli { - #[arg(short, long)] + #[arg(short, long, env = "DATABASE_URL")] database_url: String, - - #[command(subcommand)] - command: Command, -} - -#[derive(Subcommand)] -enum Command { - Up, - Down, - Fresh, - List, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct Migration { - domain: String, - table: String, - version: u32, - direction: MigrationDir, - path: PathBuf, - depends_on: Vec, -} - -impl Ord for Migration { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - (&self.domain, &self.table, self.version, &self.direction).cmp(&( - &other.domain, - &other.table, - other.version, - &other.direction, - )) - } -} - -impl PartialOrd for Migration { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -enum MigrationDir { - Up, - Down, } #[tokio::main] @@ -64,358 +17,10 @@ async fn main() -> Result<()> { let cli = Cli::parse(); - let database_url = std::env::var("DATABASE_URL") - .context("DATABASE_URL must be set or provided via --database-url")?; - let pool = PgPoolOptions::new() .max_connections(1) - .connect(&database_url) - .await - .context("Failed to connect to database")?; - - let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql"); - - match cli.command { - Command::Up => run_up(&pool, &sql_root).await, - Command::Down => run_down(&pool, &sql_root).await, - Command::Fresh => run_fresh(&pool, &sql_root).await, - Command::List => run_list(&pool, &sql_root).await, - } -} - -fn discover_migrations(sql_root: &Path) -> Result> { - let mut migrations = Vec::new(); - - if !sql_root.exists() { - bail!("SQL directory not found: {}", sql_root.display()); - } - - for dir_entry in std::fs::read_dir(sql_root)? { - let dir = dir_entry?; - if !dir.file_type()?.is_dir() { - continue; - } - let domain = dir.file_name().to_string_lossy().to_string(); - - for file_entry in std::fs::read_dir(dir.path())? { - let file = file_entry?; - let path = file.path(); - if path.extension().and_then(|e| e.to_str()) != Some("sql") { - continue; - } - - let stem = path - .file_stem() - .and_then(|s| s.to_str()) - .context("Invalid filename")?; - - let (table, direction, version) = parse_migration_stem(stem)?; - - let content = std::fs::read_to_string(&path) - .context(format!("Failed to read {path:?}"))?; - let depends_on = parse_depends_on(&content); - - migrations.push(Migration { - domain: domain.clone(), - table, - version, - direction, - path, - depends_on, - }); - } - } - - migrations.sort(); - Ok(migrations) -} - -fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> { - if let Some(pos) = stem.rfind("_up_") { - let table = stem[..pos].to_string(); - let ver_str = &stem[pos + 4..]; - let version = - ver_str.parse::().context("Invalid version number")?; - Ok((table, MigrationDir::Up, version)) - } else if let Some(pos) = stem.rfind("_down_") { - let table = stem[..pos].to_string(); - let ver_str = &stem[pos + 6..]; - let version = - ver_str.parse::().context("Invalid version number")?; - Ok((table, MigrationDir::Down, version)) - } else { - bail!("Migration filename must contain _up_ or _down_: {stem}"); - } -} - -async fn ensure_migrations_table(pool: &sqlx::PgPool) -> Result<()> { - sqlx::query( - r#" - CREATE TABLE IF NOT EXISTS _sql_migrations ( - domain TEXT NOT NULL, - table_name TEXT NOT NULL, - version INTEGER NOT NULL, - applied_at TIMESTAMPTZ NOT NULL DEFAULT now(), - checksum TEXT NOT NULL DEFAULT '', - PRIMARY KEY (domain, table_name, version) - ) - "#, - ) - .execute(pool) - .await?; - Ok(()) -} - -async fn applied_set( - pool: &sqlx::PgPool, -) -> Result> { - let rows: Vec<(String, String, i32, String)> = - sqlx::query_as("SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version") - .fetch_all(pool) - .await?; - - Ok(rows - .into_iter() - .map(|(d, t, v, c)| ((d, t, v as u32), c)) - .collect()) -} - -async fn record_migration( - pool: &sqlx::PgPool, - m: &Migration, - checksum: &str, -) -> Result<()> { - sqlx::query( - r#" - INSERT INTO _sql_migrations (domain, table_name, version, checksum) - VALUES ($1, $2, $3, $4) - ON CONFLICT DO NOTHING - "#, - ) - .bind(&m.domain) - .bind(&m.table) - .bind(m.version as i32) - .bind(checksum) - .execute(pool) - .await?; - Ok(()) -} - -async fn delete_migration(pool: &sqlx::PgPool, m: &Migration) -> Result<()> { - sqlx::query( - "DELETE FROM _sql_migrations WHERE domain = $1 AND table_name = $2 AND version = $3", - ) - .bind(&m.domain) - .bind(&m.table) - .bind(m.version as i32) - .execute(pool) - .await?; - Ok(()) -} - -fn compute_checksum(content: &str) -> String { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - let mut hasher = DefaultHasher::new(); - content.hash(&mut hasher); - format!("{:x}", hasher.finish()) -} - -fn parse_depends_on(content: &str) -> Vec { - content - .lines() - .filter_map(|line| { - let line = line.trim(); - line.strip_prefix("-- depends_on:").map(|deps| { - deps.split(',') - .map(|d| d.trim().to_string()) - .filter(|d| !d.is_empty()) - .collect::>() - }) - }) - .flatten() - .collect() -} - -fn topo_sort(migrations: &mut [Migration]) -> Result<()> { - let table_to_idx: HashMap = migrations - .iter() - .enumerate() - .map(|(i, m)| (m.table.clone(), i)) - .collect(); - - let n = migrations.len(); - let mut in_degree = vec![0u32; n]; - let mut adj: Vec> = vec![Vec::new(); n]; - - for (i, m) in migrations.iter().enumerate() { - for dep in &m.depends_on { - if let Some(&j) = table_to_idx.get(dep) { - adj[j].push(i); - in_degree[i] += 1; - } - } - } - - let mut queue: VecDeque = - (0..n).filter(|&i| in_degree[i] == 0).collect(); - - let mut order = Vec::with_capacity(n); - while let Some(i) = queue.pop_front() { - order.push(i); - for &next in &adj[i] { - in_degree[next] -= 1; - if in_degree[next] == 0 { - queue.push_back(next); - } - } - } - - if order.len() != n { - bail!("Circular dependency detected among migrations"); - } - - let original: Vec = migrations.iter().cloned().collect(); - for (slot, &idx) in order.iter().enumerate() { - migrations[slot] = original[idx].clone(); - } - - Ok(()) -} -fn into_static(s: String) -> &'static str { - Box::leak(s.into_boxed_str()) -} - -async fn exec_sql(pool: &sqlx::PgPool, sql: &str) -> Result<()> { - sqlx::raw_sql(into_static(sql.to_owned())) - .execute(pool) + .connect(&cli.database_url) .await?; - Ok(()) -} - -async fn run_up(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> { - ensure_migrations_table(pool).await?; - let all = discover_migrations(sql_root)?; - let applied = applied_set(pool).await?; - - let mut up_migrations: Vec<_> = all - .into_iter() - .filter(|m| m.direction == MigrationDir::Up) - .filter(|m| { - !applied.contains_key(&( - m.domain.clone(), - m.table.clone(), - m.version, - )) - }) - .collect(); - - if up_migrations.is_empty() { - info!("All migrations are already applied."); - return Ok(()); - } - - topo_sort(&mut up_migrations)?; - - for m in &up_migrations { - let sql = std::fs::read_to_string(&m.path) - .context(format!("Failed to read {:?}", m.path))?; - let checksum = compute_checksum(&sql); - - info!("Applying {}/{}/v{}", m.domain, m.table, m.version); - exec_sql(pool, &sql).await?; - record_migration(pool, m, &checksum).await?; - } - - info!("Applied {} migration(s).", up_migrations.len()); - Ok(()) -} - -async fn run_down(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> { - ensure_migrations_table(pool).await?; - let all = discover_migrations(sql_root)?; - let applied = applied_set(pool).await?; - - let mut down_targets: Vec<_> = all - .into_iter() - .filter(|m| m.direction == MigrationDir::Down) - .filter(|m| { - applied.contains_key(&( - m.domain.clone(), - m.table.clone(), - m.version, - )) - }) - .collect(); - down_targets.sort(); - - if down_targets.is_empty() { - info!("No migrations to roll back."); - return Ok(()); - } - - let m = &down_targets[down_targets.len() - 1]; - let sql = std::fs::read_to_string(&m.path)?; - - info!("Rolling back {}/{}/v{}", m.domain, m.table, m.version); - exec_sql(pool, &sql).await?; - delete_migration(pool, m).await?; - - info!("Rolled back 1 migration."); - Ok(()) -} - -async fn run_fresh(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> { - info!("Dropping all tables and re-applying migrations..."); - - exec_sql(pool, "DROP TABLE IF EXISTS _sql_migrations CASCADE").await?; - - let all = discover_migrations(sql_root)?; - let down_migrations: Vec<_> = all - .into_iter() - .filter(|m| m.direction == MigrationDir::Down) - .collect(); - - let mut drops: Vec<_> = down_migrations.iter().collect(); - drops.sort(); - drops.reverse(); - - for m in &drops { - let sql = std::fs::read_to_string(&m.path)?; - let _ = exec_sql(pool, &sql).await; - } - - run_up(pool, sql_root).await -} - -async fn run_list(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> { - ensure_migrations_table(pool).await?; - let all = discover_migrations(sql_root)?; - let applied = applied_set(pool).await?; - - let up_migrations: Vec<_> = all - .into_iter() - .filter(|m| m.direction == MigrationDir::Up) - .collect(); - - println!( - "{:<20} {:<30} {:>8} {}", - "Domain", "Table", "Version", "Status" - ); - println!("{:-<20} {:-<30} {:-<8} {:-<10}", "", "", "", ""); - - for m in &up_migrations { - let key = (m.domain.clone(), m.table.clone(), m.version); - let status = if applied.contains_key(&key) { - "Applied" - } else { - "Pending" - }; - println!( - "{:<20} {:<30} {:>8} {}", - m.domain, m.table, m.version, status - ); - } - - Ok(()) + + migrate::run_up(&pool).await } diff --git a/lib/queue/consumer.rs b/lib/queue/consumer.rs index 77fa31b..d2f7892 100644 --- a/lib/queue/consumer.rs +++ b/lib/queue/consumer.rs @@ -4,6 +4,7 @@ use async_nats::{HeaderMap, jetstream}; use config::AppConfig; use futures_util::StreamExt; use tracing::{error, info, warn}; +use track::CounterVec; use crate::{ handler::{AckAction, MessageHandler}, @@ -16,6 +17,7 @@ pub struct NatsConsumer { max_deliver: i64, retry_delay_secs: u64, durable_name: String, + metrics: Option, } impl NatsConsumer { @@ -33,9 +35,15 @@ impl NatsConsumer { max_deliver: config.nats_max_deliver(), retry_delay_secs: config.nats_retry_delay_secs(), durable_name: durable_name(group_id), + metrics: None, }) } + pub fn set_metrics(&mut self, registry: track::MetricsRegistry) { + self.producer.set_metrics(registry.clone()); + self.metrics = Some(registry); + } + pub async fn start_consuming( &self, topics: &[&str], @@ -61,9 +69,10 @@ impl NatsConsumer { ) .await?; - info!("NATS consumer started subscribing to: {:?}", topics_owned); + info!(topics = ?topics_owned, durable = %self.durable_name, "NATS consumer started"); let producer = self.producer.clone(); + let metrics = self.metrics.clone(); let max_deliver = self.max_deliver; let retry_delay_secs = self.retry_delay_secs; let handler = Arc::new(handler); @@ -73,10 +82,7 @@ impl NatsConsumer { let mut messages = match messages { Ok(messages) => messages, Err(error) => { - error!( - "NATS error while opening consumer stream: {:?}", - error - ); + error!(error = %error, "NATS error while opening consumer stream"); return; } }; @@ -86,6 +92,7 @@ impl NatsConsumer { Ok(message) => { handle_message( &producer, + metrics.as_ref(), max_deliver, retry_delay_secs, handler.as_ref(), @@ -94,7 +101,7 @@ impl NatsConsumer { .await; } Err(error) => { - error!("NATS error while consuming: {:?}", error); + error!(error = %error, "NATS error while consuming"); } } } @@ -106,6 +113,7 @@ impl NatsConsumer { async fn handle_message( producer: &NatsProducer, + metrics: Option<&track::MetricsRegistry>, max_deliver: i64, retry_delay_secs: u64, handler: &H, @@ -116,12 +124,17 @@ async fn handle_message( let subject = message.subject.to_string(); let payload = message.payload.clone(); let delivered = message.info().map(|info| info.delivered).unwrap_or(1); + record_queue_message(metrics, &subject, "received"); match handler.handle(&subject, &payload).await { - AckAction::Ack => ack_message(&message, &subject, "message").await, + AckAction::Ack => { + ack_message(metrics, &message, &subject, "message").await + } AckAction::Nack => { + record_queue_message(metrics, &subject, "nack"); if let Err(error) = handle_nack( producer, + metrics, &message, &subject, &payload, @@ -131,9 +144,13 @@ async fn handle_message( ) .await { + record_queue_message(metrics, &subject, "error"); error!( - "Failed to route NACKed message from subject {}: {:?}", - subject, error + subject = %subject, + delivered, + max_deliver, + error = %error, + "failed to route NACKed message" ); } } @@ -142,6 +159,7 @@ async fn handle_message( async fn handle_nack( producer: &NatsProducer, + metrics: Option<&track::MetricsRegistry>, message: &jetstream::Message, subject: &str, payload: &[u8], @@ -151,8 +169,11 @@ async fn handle_nack( ) -> anyhow::Result<()> { if delivered < max_deliver { warn!( - "Message in subject {} failed (NACK). Retrying delivery {}/{} in {} seconds", - subject, delivered, max_deliver, retry_delay_secs + subject, + delivered, + max_deliver, + retry_delay_secs, + "message NACKed, scheduling retry" ); message .ack_with(jetstream::AckKind::Nak(Some(Duration::from_secs( @@ -162,13 +183,17 @@ async fn handle_nack( .map_err(|error| { anyhow::anyhow!("failed to nack message: {error}") })?; + record_queue_message(metrics, subject, "retry"); return Ok(()); } let dlq_subject = format!("{subject}.dlq"); error!( - "Message in subject {} exceeded max deliver attempts ({}). Routing to DLQ: {}", - subject, max_deliver, dlq_subject + subject, + dlq_subject = %dlq_subject, + delivered, + max_deliver, + "message exceeded max deliver attempts, routing to DLQ" ); let mut headers = HeaderMap::new(); @@ -185,22 +210,69 @@ async fn handle_nack( message.ack().await.map_err(|error| { anyhow::anyhow!("failed to ack DLQ message: {error}") })?; + record_queue_message(metrics, subject, "dlq"); + record_queue_dlq(metrics, subject); Ok(()) } async fn ack_message( + metrics: Option<&track::MetricsRegistry>, message: &jetstream::Message, subject: &str, description: &str, ) { - if let Err(error) = message.ack().await { - error!( - "Failed to ack {} in subject {}: {:?}", - description, subject, error - ); + match message.ack().await { + Ok(()) => record_queue_message(metrics, subject, "ack"), + Err(error) => { + record_queue_message(metrics, subject, "ack_error"); + error!( + subject, + description, + error = %error, + "failed to ack message" + ); + } } } +fn record_queue_message( + metrics: Option<&track::MetricsRegistry>, + topic: &str, + status: &str, +) { + if let Some(metrics) = metrics { + queue_messages_vec(metrics) + .with_label_values(&[topic, status]) + .inc(); + } +} + +fn record_queue_dlq(metrics: Option<&track::MetricsRegistry>, topic: &str) { + if let Some(metrics) = metrics { + queue_dlq_vec(metrics).with_label_values(&[topic]).inc(); + } +} + +fn queue_messages_vec(registry: &track::MetricsRegistry) -> CounterVec { + registry + .register_counter_vec( + "queue_messages_total", + "Total queue messages", + &["topic", "status"], + ) + .expect("failed to register queue_messages_total") +} + +fn queue_dlq_vec(registry: &track::MetricsRegistry) -> CounterVec { + registry + .register_counter_vec( + "queue_dlq_total", + "Total messages routed to DLQ", + &["topic"], + ) + .expect("failed to register queue_dlq_total") +} + fn durable_name(name: &str) -> String { name.replace('.', "-") } diff --git a/lib/queue/producer.rs b/lib/queue/producer.rs index 1d65d27..e67fd29 100644 --- a/lib/queue/producer.rs +++ b/lib/queue/producer.rs @@ -3,10 +3,12 @@ use std::time::Duration; use async_nats::{HeaderMap, jetstream}; use config::AppConfig; use serde::Serialize; +use track::CounterVec; #[derive(Clone)] pub struct NatsProducer { jetstream: jetstream::Context, + metrics: Option, } impl NatsProducer { @@ -14,7 +16,14 @@ impl NatsProducer { let jetstream = connect_jetstream(config).await?; ensure_stream(config, &jetstream).await?; - Ok(Self { jetstream }) + Ok(Self { + jetstream, + metrics: None, + }) + } + + pub fn set_metrics(&mut self, registry: track::MetricsRegistry) { + self.metrics = Some(registry); } pub async fn send( @@ -44,19 +53,37 @@ impl NatsProducer { } let subject = subject.to_string(); - let publish = if headers.is_empty() { - self.jetstream - .publish(subject.clone(), payload.to_vec().into()) - .await? - } else { - self.jetstream - .publish_with_headers(subject, headers, payload.to_vec().into()) - .await? - }; + let publish_result: anyhow::Result<()> = async { + let publish = if headers.is_empty() { + self.jetstream + .publish(subject.clone(), payload.to_vec().into()) + .await? + } else { + self.jetstream + .publish_with_headers( + subject.clone(), + headers, + payload.to_vec().into(), + ) + .await? + }; - tokio::time::timeout(Duration::from_secs(5), publish).await??; + tokio::time::timeout(Duration::from_secs(5), publish).await??; + Ok(()) + } + .await; - Ok(()) + self.record_published(&subject, publish_result.is_ok()); + publish_result + } + + fn record_published(&self, topic: &str, success: bool) { + if let Some(reg) = &self.metrics { + let status = if success { "published" } else { "error" }; + queue_messages_vec(reg) + .with_label_values(&[topic, status]) + .inc(); + } } } @@ -88,3 +115,13 @@ pub async fn ensure_stream( }) .await?) } + +fn queue_messages_vec(registry: &track::MetricsRegistry) -> CounterVec { + registry + .register_counter_vec( + "queue_messages_total", + "Total queue messages", + &["topic", "status"], + ) + .expect("failed to register queue_messages_total") +} diff --git a/lib/storage/lib.rs b/lib/storage/lib.rs index 7dd630f..6e30294 100644 --- a/lib/storage/lib.rs +++ b/lib/storage/lib.rs @@ -10,6 +10,7 @@ use aws_sdk_s3::primitives::ByteStreamError; pub use error::{StorageError, StorageResult}; pub use local::{LocalStorage, LocalStorageConfig}; pub use s3::{S3Storage, S3StorageConfig}; +use track::CounterVec; #[derive(Clone, Debug)] pub enum AppStorageConfig { @@ -18,11 +19,60 @@ pub enum AppStorageConfig { } #[derive(Clone)] -pub enum AppStorage { +pub struct AppStorage { + inner: StorageBackend, + metrics: Option, +} + +#[derive(Clone)] +enum StorageBackend { Local(LocalStorage), S3(S3Storage), } +impl AppStorage { + pub fn set_metrics(&mut self, registry: track::MetricsRegistry) { + self.metrics = Some(registry); + } + + fn backend_name(&self) -> &str { + match &self.inner { + StorageBackend::Local(_) => "local", + StorageBackend::S3(_) => "s3", + } + } + + fn record_upload(&self, bytes: usize) { + if let Some(reg) = &self.metrics { + storage_ops_vec(reg) + .with_label_values(&["upload", self.backend_name()]) + .inc(); + storage_bytes_vec(reg) + .with_label_values(&["upload"]) + .inc_by(bytes as f64); + } + } + + fn record_download(&self, bytes: usize) { + if let Some(reg) = &self.metrics { + storage_ops_vec(reg) + .with_label_values(&["download", self.backend_name()]) + .inc(); + storage_bytes_vec(reg) + .with_label_values(&["download"]) + .inc_by(bytes as f64); + } + } + + fn record_delete(&self) { + if let Some(reg) = &self.metrics { + storage_ops_vec(reg) + .with_label_values(&["delete", self.backend_name()]) + .inc(); + } + } +} + #[derive(Clone, Debug, Default)] pub struct PutObjectOptions { pub content_type: Option, @@ -87,76 +137,109 @@ pub trait ObjectStorage: Send + Sync { } impl AppStorage { + #[tracing::instrument(skip(config))] pub async fn init(config: AppStorageConfig) -> StorageResult { - match config { + let inner = match config { AppStorageConfig::Local(config) => { - Ok(Self::Local(LocalStorage::connect(config).await?)) + tracing::info!("initializing local storage"); + StorageBackend::Local(LocalStorage::connect(config).await?) } AppStorageConfig::S3(config) => { - Ok(Self::S3(S3Storage::connect(config).await?)) + tracing::info!(bucket = %config.bucket, region = %config.region, "initializing S3 storage"); + StorageBackend::S3(S3Storage::connect(config).await?) } - } + }; + Ok(Self { + inner, + metrics: None, + }) } } #[async_trait] impl ObjectStorage for AppStorage { + #[tracing::instrument(skip(self, body), fields(storage.key = %key))] async fn put_stream( &self, key: &str, body: ByteStream, options: PutObjectOptions, ) -> StorageResult { - match self { - Self::Local(storage) => { + let result = match &self.inner { + StorageBackend::Local(storage) => { storage.put_stream(key, body, options).await } - Self::S3(storage) => storage.put_stream(key, body, options).await, + StorageBackend::S3(storage) => { + storage.put_stream(key, body, options).await + } + }; + if result.is_ok() { + self.record_upload(0); } + result } + #[tracing::instrument(skip(self, bytes), fields(storage.key = %key, storage.size = bytes.len()))] async fn put_bytes( &self, key: &str, bytes: Vec, options: PutObjectOptions, ) -> StorageResult { - match self { - Self::Local(storage) => { + let size = bytes.len(); + let result = match &self.inner { + StorageBackend::Local(storage) => { storage.put_bytes(key, bytes, options).await } - Self::S3(storage) => storage.put_bytes(key, bytes, options).await, + StorageBackend::S3(storage) => { + storage.put_bytes(key, bytes, options).await + } + }; + if result.is_ok() { + self.record_upload(size); } + result } + #[tracing::instrument(skip(self), fields(storage.key = %key))] async fn get_stream( &self, key: &str, ) -> StorageResult { - match self { - Self::Local(storage) => storage.get_stream(key).await, - Self::S3(storage) => storage.get_stream(key).await, + match &self.inner { + StorageBackend::Local(storage) => storage.get_stream(key).await, + StorageBackend::S3(storage) => storage.get_stream(key).await, } } + #[tracing::instrument(skip(self), fields(storage.key = %key))] async fn get_bytes(&self, key: &str) -> StorageResult { - match self { - Self::Local(storage) => storage.get_bytes(key).await, - Self::S3(storage) => storage.get_bytes(key).await, + let result = match &self.inner { + StorageBackend::Local(storage) => storage.get_bytes(key).await, + StorageBackend::S3(storage) => storage.get_bytes(key).await, + }; + if let Ok(obj) = &result { + self.record_download(obj.bytes.len()); } + result } + #[tracing::instrument(skip(self), fields(storage.key = %key))] async fn delete(&self, key: &str) -> StorageResult<()> { - match self { - Self::Local(storage) => storage.delete(key).await, - Self::S3(storage) => storage.delete(key).await, + let result = match &self.inner { + StorageBackend::Local(storage) => storage.delete(key).await, + StorageBackend::S3(storage) => storage.delete(key).await, + }; + if result.is_ok() { + self.record_delete(); } + result } fn public_url(&self, key: &str) -> StorageResult> { - match self { - Self::Local(storage) => storage.public_url(key), - Self::S3(storage) => storage.public_url(key), + match &self.inner { + StorageBackend::Local(storage) => storage.public_url(key), + StorageBackend::S3(storage) => storage.public_url(key), } } @@ -165,17 +248,37 @@ impl ObjectStorage for AppStorage { key: &str, expires_in: Duration, ) -> StorageResult { - match self { - Self::Local(storage) => { + match &self.inner { + StorageBackend::Local(storage) => { storage.presigned_get_url(key, expires_in).await } - Self::S3(storage) => { + StorageBackend::S3(storage) => { storage.presigned_get_url(key, expires_in).await } } } } +fn storage_ops_vec(registry: &track::MetricsRegistry) -> CounterVec { + registry + .register_counter_vec( + "storage_operations_total", + "Total storage operations", + &["operation", "backend"], + ) + .expect("failed to register storage_operations_total") +} + +fn storage_bytes_vec(registry: &track::MetricsRegistry) -> CounterVec { + registry + .register_counter_vec( + "storage_bytes_total", + "Total bytes transferred", + &["operation"], + ) + .expect("failed to register storage_bytes_total") +} + pub async fn collect_byte_stream( body: ByteStream, ) -> Result, ByteStreamError> {