refactor: update infrastructure libs (config, db, cache, queue, storage, migrate)

This commit is contained in:
zhenyi 2026-06-01 22:04:25 +08:00
parent e44f3d13c4
commit 734e1c4cc8
16 changed files with 1049 additions and 727 deletions

319
lib/cache/app.rs vendored Normal file
View File

@ -0,0 +1,319 @@
use std::time::Duration;
use track::CounterVec;
use crate::{
cluster::{ClusterCache, ClusterCacheConfig},
error::{CacheError, CacheResult},
local::{LocalCacheConfig, MokaCache},
};
// ============================================================================
// Configuration
// ============================================================================
#[derive(Clone, Debug)]
pub struct AppCacheConfig {
pub local: LocalCacheConfig,
pub cluster: Option<ClusterCacheConfig>,
pub default_ttl: Option<Duration>,
pub cluster_write_through: bool,
}
impl Default for AppCacheConfig {
fn default() -> Self {
Self {
local: LocalCacheConfig::default(),
cluster: None,
default_ttl: Some(Duration::from_secs(300)),
cluster_write_through: true,
}
}
}
impl TryFrom<&config::AppConfig> for AppCacheConfig {
type Error = CacheError;
fn try_from(config: &config::AppConfig) -> Result<Self, Self::Error> {
let local = LocalCacheConfig {
max_capacity: config
.cache_local_max_capacity()
.map_err(|error| CacheError::Config(error.to_string()))?,
time_to_live: config
.cache_local_ttl()
.map_err(|error| CacheError::Config(error.to_string()))?,
time_to_idle: config
.cache_local_tti()
.map_err(|error| CacheError::Config(error.to_string()))?,
};
let cluster = if config
.cache_cluster_enabled()
.map_err(|error| CacheError::Config(error.to_string()))?
{
Some(ClusterCacheConfig {
urls: config
.redis_urls()
.map_err(|error| CacheError::Config(error.to_string()))?,
key_prefix: config.cache_cluster_key_prefix(),
command_timeout: config
.cache_cluster_command_timeout()
.map_err(|error| CacheError::Config(error.to_string()))?,
})
} else {
None
};
Ok(Self {
local,
cluster,
default_ttl: config
.cache_default_ttl()
.map_err(|error| CacheError::Config(error.to_string()))?,
cluster_write_through: config
.cache_cluster_write_through()
.map_err(|error| CacheError::Config(error.to_string()))?,
})
}
}
// ============================================================================
// AppCache
// ============================================================================
#[derive(Clone)]
pub struct AppCache {
pub local: MokaCache,
pub cluster: Option<ClusterCache>,
default_ttl: Option<Duration>,
cluster_write_through: bool,
metrics: Option<track::MetricsRegistry>,
}
impl AppCache {
#[tracing::instrument(skip(config))]
pub async fn init(config: AppCacheConfig) -> CacheResult<Self> {
let local = MokaCache::with_config(config.local);
let cluster = match config.cluster {
Some(cluster) => Some(match ClusterCache::connect(cluster).await {
Ok(cluster) => cluster,
Err(e) => {
tracing::error!(error = %e, "failed to connect to cache cluster");
return Err(e);
}
}),
None => None,
};
tracing::info!(has_cluster = cluster.is_some(), "cache initialized");
Ok(Self {
local,
cluster,
default_ttl: config.default_ttl,
cluster_write_through: config.cluster_write_through,
metrics: None,
})
}
pub fn local_only(local: MokaCache) -> Self {
Self {
local,
cluster: None,
default_ttl: None,
cluster_write_through: false,
metrics: None,
}
}
/// Attach a metrics registry for recording cache counters.
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
self.metrics = Some(registry);
}
#[tracing::instrument(skip(self), fields(cache.key = %key))]
pub async fn get<T>(&self, key: &str) -> CacheResult<Option<T>>
where
T: serde::Serialize + serde::de::DeserializeOwned,
{
if let Some(value) = self.local.get(key).await? {
tracing::debug!("cache hit (local)");
self.record_hit("local");
return Ok(Some(value));
}
let Some(cluster) = &self.cluster else {
tracing::debug!("cache miss");
self.record_miss();
return Ok(None);
};
let value = cluster.get::<T>(key).await?;
if let Some(value) = &value {
self.local.set(key, value).await?;
tracing::debug!("cache hit (cluster)");
self.record_hit("cluster");
} else {
tracing::debug!("cache miss");
self.record_miss();
}
Ok(value)
}
#[tracing::instrument(skip(self, value), fields(cache.key = %key))]
pub async fn set<T>(&self, key: &str, value: &T) -> CacheResult<()>
where
T: serde::Serialize + ?Sized,
{
self.local.set(key, value).await?;
if self.cluster_write_through
&& let Some(cluster) = &self.cluster
{
cluster.set(key, value, self.default_ttl).await?;
}
self.record_set();
Ok(())
}
pub async fn set_with_ttl<T>(
&self,
key: &str,
value: &T,
ttl: std::time::Duration,
) -> CacheResult<()>
where
T: serde::Serialize + ?Sized,
{
self.local.set(key, value).await?;
if self.cluster_write_through
&& let Some(cluster) = &self.cluster
{
cluster.set(key, value, Some(ttl)).await?;
}
Ok(())
}
#[tracing::instrument(skip(self), fields(cache.key = %key))]
pub async fn remove(&self, key: &str) -> CacheResult<()> {
self.local.remove(key).await;
if let Some(cluster) = &self.cluster {
cluster.remove(key).await?;
}
self.record_remove();
Ok(())
}
fn record_hit(&self, tier: &str) {
if let Some(reg) = &self.metrics {
cache_hits_vec(reg).with_label_values(&[tier]).inc();
}
}
fn record_miss(&self) {
if let Some(reg) = &self.metrics {
cache_misses_vec(reg).with_label_values(&[]).inc();
}
}
fn record_set(&self) {
if let Some(reg) = &self.metrics {
cache_sets_vec(reg).with_label_values(&[]).inc();
}
}
fn record_remove(&self) {
if let Some(reg) = &self.metrics {
cache_removes_vec(reg).with_label_values(&[]).inc();
}
}
pub async fn delete_pattern(&self, pattern: &str) -> CacheResult<u64> {
let pattern = pattern.to_string();
let local_pattern = pattern.clone();
self.local.invalidate_entries_if(move |key| {
simple_glob_match(&local_pattern, key)
});
let mut removed = 0u64;
if let Some(cluster) = &self.cluster {
removed = cluster.delete_pattern(&pattern).await?;
}
Ok(removed)
}
pub async fn ping_cluster(&self) -> CacheResult<()> {
if let Some(cluster) = &self.cluster {
cluster.ping().await?;
}
Ok(())
}
pub fn conn(&self) -> Option<redis::cluster_async::ClusterConnection> {
self.cluster.as_ref().map(|c| c.conn())
}
}
fn cache_hits_vec(registry: &track::MetricsRegistry) -> CounterVec {
registry
.register_counter_vec("cache_hits_total", "Total cache hits", &["tier"])
.expect("failed to register cache_hits_total")
}
fn cache_misses_vec(registry: &track::MetricsRegistry) -> CounterVec {
registry
.register_counter_vec("cache_misses_total", "Total cache misses", &[])
.expect("failed to register cache_misses_total")
}
fn cache_sets_vec(registry: &track::MetricsRegistry) -> CounterVec {
registry
.register_counter_vec(
"cache_sets_total",
"Total cache set operations",
&[],
)
.expect("failed to register cache_sets_total")
}
fn cache_removes_vec(registry: &track::MetricsRegistry) -> CounterVec {
registry
.register_counter_vec(
"cache_removes_total",
"Total cache remove operations",
&[],
)
.expect("failed to register cache_removes_total")
}
// ============================================================================
// Helpers
// ============================================================================
fn simple_glob_match(pattern: &str, key: &str) -> bool {
let p = pattern.as_bytes();
let k = key.as_bytes();
let (mut pi, mut ki) = (0usize, 0usize);
let mut backtrack_p: Option<usize> = None;
let mut backtrack_k: usize = 0;
loop {
if pi < p.len() && ki < k.len() && (p[pi] == b'?' || p[pi] == k[ki]) {
pi += 1;
ki += 1;
} else if pi < p.len() && p[pi] == b'*' {
backtrack_p = Some(pi);
backtrack_k = ki;
pi += 1;
} else if let Some(saved_pi) = backtrack_p {
backtrack_k += 1;
ki = backtrack_k;
pi = saved_pi + 1;
} else {
return pi == p.len() && ki == k.len();
}
if pi == p.len() && ki == k.len() {
return true;
}
}
}

220
lib/cache/lib.rs vendored
View File

@ -1,227 +1,11 @@
pub mod app;
pub mod cluster;
pub mod error;
pub mod local;
use std::time::Duration;
pub use crate::{
app::{AppCache, AppCacheConfig},
cluster::{ClusterCache, ClusterCacheConfig},
error::{CacheError, CacheResult},
local::{LocalCacheConfig, MokaCache},
};
#[derive(Clone, Debug)]
pub struct AppCacheConfig {
pub local: LocalCacheConfig,
pub cluster: Option<ClusterCacheConfig>,
pub default_ttl: Option<Duration>,
pub cluster_write_through: bool,
}
impl Default for AppCacheConfig {
fn default() -> Self {
Self {
local: LocalCacheConfig::default(),
cluster: None,
default_ttl: Some(Duration::from_secs(300)),
cluster_write_through: true,
}
}
}
#[derive(Clone)]
pub struct AppCache {
pub local: MokaCache,
pub cluster: Option<ClusterCache>,
default_ttl: Option<Duration>,
cluster_write_through: bool,
}
impl AppCache {
pub async fn init(config: AppCacheConfig) -> CacheResult<Self> {
let local = MokaCache::with_config(config.local);
let cluster = match config.cluster {
Some(cluster) => Some(match ClusterCache::connect(cluster).await {
Ok(cluster) => cluster,
Err(e) => {
println!("cache:init:error with: {}", e);
return Err(e);
}
}),
None => None,
};
Ok(Self {
local,
cluster,
default_ttl: config.default_ttl,
cluster_write_through: config.cluster_write_through,
})
}
pub fn local_only(local: MokaCache) -> Self {
Self {
local,
cluster: None,
default_ttl: None,
cluster_write_through: false,
}
}
pub async fn get<T>(&self, key: &str) -> CacheResult<Option<T>>
where
T: serde::Serialize + serde::de::DeserializeOwned,
{
if let Some(value) = self.local.get(key).await? {
return Ok(Some(value));
}
let Some(cluster) = &self.cluster else {
return Ok(None);
};
let value = cluster.get::<T>(key).await?;
if let Some(value) = &value {
self.local.set(key, value).await?;
}
Ok(value)
}
pub async fn set<T>(&self, key: &str, value: &T) -> CacheResult<()>
where
T: serde::Serialize + ?Sized,
{
self.local.set(key, value).await?;
if self.cluster_write_through
&& let Some(cluster) = &self.cluster
{
cluster.set(key, value, self.default_ttl).await?;
}
Ok(())
}
pub async fn set_with_ttl<T>(
&self,
key: &str,
value: &T,
ttl: std::time::Duration,
) -> CacheResult<()>
where
T: serde::Serialize + ?Sized,
{
self.local.set(key, value).await?;
if self.cluster_write_through
&& let Some(cluster) = &self.cluster
{
cluster.set(key, value, Some(ttl)).await?;
}
Ok(())
}
pub async fn remove(&self, key: &str) -> CacheResult<()> {
self.local.remove(key).await;
if let Some(cluster) = &self.cluster {
cluster.remove(key).await?;
}
Ok(())
}
pub async fn delete_pattern(&self, pattern: &str) -> CacheResult<u64> {
let pattern = pattern.to_string();
let local_pattern = pattern.clone();
self.local.invalidate_entries_if(move |key| {
simple_glob_match(&local_pattern, key)
});
let mut removed = 0u64;
if let Some(cluster) = &self.cluster {
removed = cluster.delete_pattern(&pattern).await?;
}
Ok(removed)
}
pub async fn ping_cluster(&self) -> CacheResult<()> {
if let Some(cluster) = &self.cluster {
cluster.ping().await?;
}
Ok(())
}
pub fn conn(&self) -> Option<redis::cluster_async::ClusterConnection> {
self.cluster.as_ref().map(|c| c.conn())
}
}
impl TryFrom<&config::AppConfig> for AppCacheConfig {
type Error = CacheError;
fn try_from(config: &config::AppConfig) -> Result<Self, Self::Error> {
let local = LocalCacheConfig {
max_capacity: config
.cache_local_max_capacity()
.map_err(|error| CacheError::Config(error.to_string()))?,
time_to_live: config
.cache_local_ttl()
.map_err(|error| CacheError::Config(error.to_string()))?,
time_to_idle: config
.cache_local_tti()
.map_err(|error| CacheError::Config(error.to_string()))?,
};
let cluster = if config
.cache_cluster_enabled()
.map_err(|error| CacheError::Config(error.to_string()))?
{
Some(ClusterCacheConfig {
urls: config
.redis_urls()
.map_err(|error| CacheError::Config(error.to_string()))?,
key_prefix: config.cache_cluster_key_prefix(),
command_timeout: config
.cache_cluster_command_timeout()
.map_err(|error| CacheError::Config(error.to_string()))?,
})
} else {
None
};
Ok(Self {
local,
cluster,
default_ttl: config
.cache_default_ttl()
.map_err(|error| CacheError::Config(error.to_string()))?,
cluster_write_through: config
.cache_cluster_write_through()
.map_err(|error| CacheError::Config(error.to_string()))?,
})
}
}
fn simple_glob_match(pattern: &str, key: &str) -> bool {
let p = pattern.as_bytes();
let k = key.as_bytes();
let (mut pi, mut ki) = (0usize, 0usize);
let mut backtrack_p: Option<usize> = None;
let mut backtrack_k: usize = 0;
loop {
if pi < p.len() && ki < k.len() && (p[pi] == b'?' || p[pi] == k[ki]) {
pi += 1;
ki += 1;
} else if pi < p.len() && p[pi] == b'*' {
backtrack_p = Some(pi);
backtrack_k = ki;
pi += 1;
} else if let Some(saved_pi) = backtrack_p {
backtrack_k += 1;
ki = backtrack_k;
pi = saved_pi + 1;
} else {
return pi == p.len() && ki == k.len();
}
if pi == p.len() && ki == k.len() {
return true;
}
}
}

View File

@ -28,6 +28,13 @@ impl AppConfig {
Ok(8080)
}
pub fn email_health_port(&self) -> u16 {
self.env
.get("APP_EMAIL_HEALTH_PORT")
.and_then(|port| port.parse::<u16>().ok())
.unwrap_or(8083)
}
pub fn session_secret(&self) -> anyhow::Result<String> {
if let Some(secret) = self.env.get("APP_SESSION_SECRET") {
return Ok(secret.to_string());

39
lib/config/app_config.rs Normal file
View File

@ -0,0 +1,39 @@
use std::{collections::HashMap, sync::OnceLock};
pub static GLOBAL_CONFIG: OnceLock<AppConfig> = OnceLock::new();
#[derive(Clone, Debug)]
pub struct AppConfig {
pub env: HashMap<String, String>,
}
impl AppConfig {
const ENV_FILES: &'static [&'static str] = &[".env", ".env.local"];
pub fn load() -> AppConfig {
let mut env = HashMap::new();
for env_file in AppConfig::ENV_FILES {
if let Err(e) = dotenvy::from_path(env_file) {
tracing::debug!(file = %env_file, error = %e, "dotenv load skipped");
}
if let Ok(env_file_content) = std::fs::read_to_string(env_file) {
for line in env_file_content.lines() {
if let Some((key, value)) = line.split_once('=') {
env.insert(key.to_string(), value.to_string());
}
}
}
}
env = env.into_iter().chain(std::env::vars()).collect();
let this = AppConfig { env };
if let Some(config) = GLOBAL_CONFIG.get() {
config.clone()
} else {
let _ = GLOBAL_CONFIG.set(this);
GLOBAL_CONFIG
.get()
.expect("global config should be set after load")
.clone()
}
}
}

View File

@ -1,44 +1,6 @@
use std::{collections::HashMap, sync::OnceLock};
pub static GLOBAL_CONFIG: OnceLock<AppConfig> = OnceLock::new();
#[derive(Clone, Debug)]
pub struct AppConfig {
pub env: HashMap<String, String>,
}
impl AppConfig {
const ENV_FILES: &'static [&'static str] = &[".env", ".env.local"];
pub fn load() -> AppConfig {
let mut env = HashMap::new();
for env_file in AppConfig::ENV_FILES {
if let Err(e) = dotenvy::from_path(env_file) {
tracing::debug!(file = %env_file, error = %e, "dotenv load skipped");
}
if let Ok(env_file_content) = std::fs::read_to_string(env_file) {
for line in env_file_content.lines() {
if let Some((key, value)) = line.split_once('=') {
env.insert(key.to_string(), value.to_string());
}
}
}
}
env = env.into_iter().chain(std::env::vars()).collect();
let this = AppConfig { env };
if GLOBAL_CONFIG.get().is_some() {
GLOBAL_CONFIG.get().unwrap().clone()
} else {
let _ = GLOBAL_CONFIG.set(this);
GLOBAL_CONFIG
.get()
.expect("global config should be set after load")
.clone()
}
}
}
pub mod ai;
pub mod app;
pub mod app_config;
pub mod auth;
pub mod avatar;
pub mod cache;
@ -57,3 +19,5 @@ pub mod redis;
pub mod smtp;
pub mod ssh;
pub mod storage;
pub use app_config::{AppConfig, GLOBAL_CONFIG};

View File

@ -61,7 +61,7 @@ impl AppConfig {
if let Some(endpoint) = self.env.get("APP_OTEL_ENDPOINT") {
return Ok(endpoint.to_string());
}
Ok("http://localhost:5080/api/default/v1/traces".to_string())
Ok("http://localhost:4318".to_string())
}
pub fn otel_service_name(&self) -> anyhow::Result<String> {

View File

@ -7,6 +7,7 @@ use sqlx::{
PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow,
},
};
use track::{CounterVec, HistogramVec};
use crate::{
route::{SqlRoute, route_sql},
@ -17,9 +18,11 @@ use crate::{
pub struct AppDatabase {
db_write: PgPool,
db_read: Option<PgPool>,
metrics: Option<track::MetricsRegistry>,
}
impl AppDatabase {
#[tracing::instrument(skip(cfg))]
pub async fn init(cfg: &AppConfig) -> anyhow::Result<Self> {
let db_url = cfg.database_url()?;
let max_connections = cfg.database_max_connections()?;
@ -69,7 +72,15 @@ impl AppDatabase {
None
};
Ok(Self { db_write, db_read })
Ok(Self {
db_write,
db_read,
metrics: None,
})
}
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
self.metrics = Some(registry);
}
pub fn writer(&self) -> &PgPool {
@ -87,11 +98,14 @@ impl AppDatabase {
}
}
#[tracing::instrument(skip(self), fields(sql.route = "write"))]
pub async fn begin(&self) -> Result<AppTransaction<'_>, sqlx::Error> {
let txn = self.db_write.begin().await?;
tracing::debug!("db transaction started");
Ok(AppTransaction { inner: txn })
}
#[tracing::instrument(skip(self), fields(sql.route = "read"))]
pub async fn begin_read_only(
&self,
) -> Result<AppTransaction<'_>, sqlx::Error> {
@ -101,9 +115,11 @@ impl AppDatabase {
.execute(&mut *txn)
.await?;
tracing::debug!("db read-only transaction started");
Ok(AppTransaction { inner: txn })
}
#[tracing::instrument(skip(self, sql), fields(sql.kind = "execute"))]
pub async fn execute(
&self,
sql: &str,
@ -111,18 +127,38 @@ impl AppDatabase {
self.execute_with_args(sql, PgArguments::default()).await
}
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "execute"))]
pub async fn execute_with_args(
&self,
sql: &str,
args: PgArguments,
) -> Result<PgQueryResult, sqlx::Error> {
let pool = self.route_pool(sql);
let start = std::time::Instant::now();
sqlx::query_with(AssertSqlSafe(sql.to_owned()), args)
let result = sqlx::query_with(AssertSqlSafe(sql.to_owned()), args)
.execute(pool)
.await
.await;
let kind = if sql.trim_start().to_uppercase().starts_with("INSERT") {
"insert"
} else if sql.trim_start().to_uppercase().starts_with("UPDATE") {
"update"
} else if sql.trim_start().to_uppercase().starts_with("DELETE") {
"delete"
} else {
"execute"
};
self.record_query(
kind,
self.route_label(sql),
start.elapsed(),
result.is_ok(),
);
result
}
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_one"))]
pub async fn fetch_one<T>(&self, sql: &str) -> Result<T, sqlx::Error>
where
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
@ -130,6 +166,7 @@ impl AppDatabase {
self.fetch_one_with_args(sql, PgArguments::default()).await
}
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_one"))]
pub async fn fetch_one_with_args<T>(
&self,
sql: &str,
@ -139,12 +176,23 @@ impl AppDatabase {
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
{
let pool = self.route_pool(sql);
let start = std::time::Instant::now();
let result =
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
.fetch_one(pool)
.await
.await;
self.record_query(
"select",
self.route_label(sql),
start.elapsed(),
result.is_ok(),
);
result
}
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_optional"))]
pub async fn fetch_optional<T>(
&self,
sql: &str,
@ -156,6 +204,7 @@ impl AppDatabase {
.await
}
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_optional"))]
pub async fn fetch_optional_with_args<T>(
&self,
sql: &str,
@ -165,12 +214,23 @@ impl AppDatabase {
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
{
let pool = self.route_pool(sql);
let start = std::time::Instant::now();
let result =
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
.fetch_optional(pool)
.await
.await;
self.record_query(
"select",
self.route_label(sql),
start.elapsed(),
result.is_ok(),
);
result
}
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_all"))]
pub async fn fetch_all<T>(&self, sql: &str) -> Result<Vec<T>, sqlx::Error>
where
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
@ -178,6 +238,7 @@ impl AppDatabase {
self.fetch_all_with_args(sql, PgArguments::default()).await
}
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_all"))]
pub async fn fetch_all_with_args<T>(
&self,
sql: &str,
@ -187,10 +248,45 @@ impl AppDatabase {
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
{
let pool = self.route_pool(sql);
let start = std::time::Instant::now();
let result =
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
.fetch_all(pool)
.await
.await;
self.record_query(
"select",
self.route_label(sql),
start.elapsed(),
result.is_ok(),
);
result
}
fn route_label(&self, sql: &str) -> &str {
match route_sql(sql) {
SqlRoute::Write => "write",
SqlRoute::Read => "read",
}
}
fn record_query(
&self,
kind: &str,
route: &str,
duration: Duration,
success: bool,
) {
if let Some(reg) = &self.metrics {
let status = if success { "success" } else { "error" };
db_queries_vec(reg)
.with_label_values(&[kind, route, status])
.inc();
db_query_duration_vec(reg)
.with_label_values(&[kind, route])
.observe(duration.as_secs_f64());
}
}
}
@ -230,3 +326,26 @@ async fn build_pool(
pool_options.connect_with(options).await
}
fn db_queries_vec(registry: &track::MetricsRegistry) -> CounterVec {
registry
.register_counter_vec(
"db_queries_total",
"Total database queries",
&["kind", "route", "status"],
)
.expect("failed to register db_queries_total")
}
fn db_query_duration_vec(registry: &track::MetricsRegistry) -> HistogramVec {
registry
.register_histogram_vec(
"db_query_duration_seconds",
"DB query duration in seconds",
&["kind", "route"],
vec![
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0,
],
)
.expect("failed to register db_query_duration_seconds")
}

267
lib/migrate/lib.rs Normal file
View File

@ -0,0 +1,267 @@
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::path::{Path, PathBuf};
use anyhow::{Context, Result, bail};
use sqlx::PgPool;
#[derive(Debug, Clone, PartialEq, Eq)]
struct Migration {
domain: String,
table: String,
version: u32,
direction: MigrationDir,
path: PathBuf,
depends_on: Vec<String>,
}
impl Ord for Migration {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
(&self.domain, &self.table, self.version, &self.direction).cmp(&(
&other.domain,
&other.table,
other.version,
&other.direction,
))
}
}
impl PartialOrd for Migration {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum MigrationDir {
Up,
Down,
}
pub async fn run_up(pool: &PgPool) -> Result<()> {
let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql");
ensure_migrations_table(pool).await?;
let all = discover_migrations(&sql_root)?;
let applied = applied_set(pool).await?;
let mut up_migrations: Vec<_> = all
.into_iter()
.filter(|m| m.direction == MigrationDir::Up)
.filter(|m| {
!applied.contains_key(&(
m.domain.clone(),
m.table.clone(),
m.version,
))
})
.collect();
if up_migrations.is_empty() {
tracing::info!("All migrations are already applied.");
return Ok(());
}
topo_sort(&mut up_migrations)?;
for m in &up_migrations {
let sql = std::fs::read_to_string(&m.path)
.context(format!("Failed to read {:?}", m.path))?;
let checksum = compute_checksum(&sql);
tracing::info!(domain = %m.domain, table = %m.table, version = m.version, "applying migration");
exec_sql(pool, &sql).await?;
record_migration(pool, m, &checksum).await?;
}
tracing::info!("Applied {} migration(s).", up_migrations.len());
Ok(())
}
fn discover_migrations(sql_root: &Path) -> Result<Vec<Migration>> {
let mut migrations = Vec::new();
if !sql_root.exists() {
bail!("SQL directory not found: {}", sql_root.display());
}
for dir_entry in std::fs::read_dir(sql_root)? {
let dir = dir_entry?;
if !dir.file_type()?.is_dir() {
continue;
}
let domain = dir.file_name().to_string_lossy().to_string();
for file_entry in std::fs::read_dir(dir.path())? {
let file = file_entry?;
let path = file.path();
if path.extension().and_then(|e| e.to_str()) != Some("sql") {
continue;
}
let stem = path
.file_stem()
.and_then(|s| s.to_str())
.context("Invalid filename")?;
let (table, direction, version) = parse_migration_stem(stem)?;
let content = std::fs::read_to_string(&path)
.context(format!("Failed to read {path:?}"))?;
let depends_on = parse_depends_on(&content);
migrations.push(Migration {
domain: domain.clone(),
table,
version,
direction,
path,
depends_on,
});
}
}
migrations.sort();
Ok(migrations)
}
fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> {
if let Some(pos) = stem.rfind("_up_") {
let table = stem[..pos].to_string();
let version = stem[pos + 4..]
.parse::<u32>()
.context("Invalid version number")?;
Ok((table, MigrationDir::Up, version))
} else if let Some(pos) = stem.rfind("_down_") {
let table = stem[..pos].to_string();
let version = stem[pos + 6..]
.parse::<u32>()
.context("Invalid version number")?;
Ok((table, MigrationDir::Down, version))
} else {
bail!("Migration filename must contain _up_ or _down_: {stem}");
}
}
async fn ensure_migrations_table(pool: &PgPool) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS _sql_migrations (
domain TEXT NOT NULL,
table_name TEXT NOT NULL,
version INTEGER NOT NULL,
applied_at TIMESTAMPTZ NOT NULL DEFAULT now(),
checksum TEXT NOT NULL DEFAULT '',
PRIMARY KEY (domain, table_name, version)
)
"#,
)
.execute(pool)
.await?;
Ok(())
}
async fn applied_set(
pool: &PgPool,
) -> Result<BTreeMap<(String, String, u32), String>> {
let rows: Vec<(String, String, i32, String)> = sqlx::query_as(
"SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version",
)
.fetch_all(pool)
.await?;
Ok(rows
.into_iter()
.map(|(d, t, v, c)| ((d, t, v as u32), c))
.collect())
}
async fn record_migration(
pool: &PgPool,
m: &Migration,
checksum: &str,
) -> Result<()> {
sqlx::query(
"INSERT INTO _sql_migrations (domain, table_name, version, checksum) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING",
)
.bind(&m.domain)
.bind(&m.table)
.bind(m.version as i32)
.bind(checksum)
.execute(pool)
.await?;
Ok(())
}
fn compute_checksum(content: &str) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
content.hash(&mut hasher);
format!("{:x}", hasher.finish())
}
fn parse_depends_on(content: &str) -> Vec<String> {
content
.lines()
.filter_map(|line| {
let line = line.trim();
line.strip_prefix("-- depends_on:").map(|deps| {
deps.split(',')
.map(|d| d.trim().to_string())
.filter(|d| !d.is_empty())
.collect::<Vec<_>>()
})
})
.flatten()
.collect()
}
fn topo_sort(migrations: &mut [Migration]) -> Result<()> {
let table_to_idx: HashMap<String, usize> = migrations
.iter()
.enumerate()
.map(|(i, m)| (m.table.clone(), i))
.collect();
let n = migrations.len();
let mut in_degree = vec![0u32; n];
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
for (i, m) in migrations.iter().enumerate() {
for dep in &m.depends_on {
if let Some(&j) = table_to_idx.get(dep) {
adj[j].push(i);
in_degree[i] += 1;
}
}
}
let mut queue: VecDeque<usize> =
(0..n).filter(|&i| in_degree[i] == 0).collect();
let mut order = Vec::with_capacity(n);
while let Some(i) = queue.pop_front() {
order.push(i);
for &next in &adj[i] {
in_degree[next] -= 1;
if in_degree[next] == 0 {
queue.push_back(next);
}
}
}
if order.len() != n {
bail!("Circular dependency detected among migrations");
}
let original: Vec<Migration> = migrations.iter().cloned().collect();
for (slot, &idx) in order.iter().enumerate() {
migrations[slot] = original[idx].clone();
}
Ok(())
}
async fn exec_sql(pool: &PgPool, sql: &str) -> Result<()> {
let s: &'static str = Box::leak(sql.to_owned().into_boxed_str());
sqlx::raw_sql(s).execute(pool).await?;
Ok(())
}

View File

@ -0,0 +1,2 @@
ALTER TABLE room_attachment ALTER COLUMN message SET NOT NULL;
ALTER TABLE room_attachment ALTER COLUMN seq SET NOT NULL;

View File

@ -0,0 +1,2 @@
ALTER TABLE room_attachment ALTER COLUMN message DROP NOT NULL;
ALTER TABLE room_attachment ALTER COLUMN seq DROP NOT NULL;

View File

@ -0,0 +1 @@
ALTER TABLE room_mention ALTER COLUMN target_id TYPE UUID USING target_id::UUID;

View File

@ -0,0 +1 @@
ALTER TABLE room_mention ALTER COLUMN target_id TYPE TEXT;

View File

@ -1,59 +1,12 @@
use anyhow::{Context, Result, bail};
use clap::{Parser, Subcommand};
use anyhow::Result;
use clap::Parser;
use sqlx::postgres::PgPoolOptions;
use std::collections::{BTreeMap, HashMap, VecDeque};
use std::path::{Path, PathBuf};
use tracing::info;
#[derive(Parser)]
#[command(name = "migrate", about = "Database migration tool")]
struct Cli {
#[arg(short, long)]
#[arg(short, long, env = "DATABASE_URL")]
database_url: String,
#[command(subcommand)]
command: Command,
}
#[derive(Subcommand)]
enum Command {
Up,
Down,
Fresh,
List,
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct Migration {
domain: String,
table: String,
version: u32,
direction: MigrationDir,
path: PathBuf,
depends_on: Vec<String>,
}
impl Ord for Migration {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
(&self.domain, &self.table, self.version, &self.direction).cmp(&(
&other.domain,
&other.table,
other.version,
&other.direction,
))
}
}
impl PartialOrd for Migration {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum MigrationDir {
Up,
Down,
}
#[tokio::main]
@ -64,358 +17,10 @@ async fn main() -> Result<()> {
let cli = Cli::parse();
let database_url = std::env::var("DATABASE_URL")
.context("DATABASE_URL must be set or provided via --database-url")?;
let pool = PgPoolOptions::new()
.max_connections(1)
.connect(&database_url)
.await
.context("Failed to connect to database")?;
let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql");
match cli.command {
Command::Up => run_up(&pool, &sql_root).await,
Command::Down => run_down(&pool, &sql_root).await,
Command::Fresh => run_fresh(&pool, &sql_root).await,
Command::List => run_list(&pool, &sql_root).await,
}
}
fn discover_migrations(sql_root: &Path) -> Result<Vec<Migration>> {
let mut migrations = Vec::new();
if !sql_root.exists() {
bail!("SQL directory not found: {}", sql_root.display());
}
for dir_entry in std::fs::read_dir(sql_root)? {
let dir = dir_entry?;
if !dir.file_type()?.is_dir() {
continue;
}
let domain = dir.file_name().to_string_lossy().to_string();
for file_entry in std::fs::read_dir(dir.path())? {
let file = file_entry?;
let path = file.path();
if path.extension().and_then(|e| e.to_str()) != Some("sql") {
continue;
}
let stem = path
.file_stem()
.and_then(|s| s.to_str())
.context("Invalid filename")?;
let (table, direction, version) = parse_migration_stem(stem)?;
let content = std::fs::read_to_string(&path)
.context(format!("Failed to read {path:?}"))?;
let depends_on = parse_depends_on(&content);
migrations.push(Migration {
domain: domain.clone(),
table,
version,
direction,
path,
depends_on,
});
}
}
migrations.sort();
Ok(migrations)
}
fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> {
if let Some(pos) = stem.rfind("_up_") {
let table = stem[..pos].to_string();
let ver_str = &stem[pos + 4..];
let version =
ver_str.parse::<u32>().context("Invalid version number")?;
Ok((table, MigrationDir::Up, version))
} else if let Some(pos) = stem.rfind("_down_") {
let table = stem[..pos].to_string();
let ver_str = &stem[pos + 6..];
let version =
ver_str.parse::<u32>().context("Invalid version number")?;
Ok((table, MigrationDir::Down, version))
} else {
bail!("Migration filename must contain _up_ or _down_: {stem}");
}
}
async fn ensure_migrations_table(pool: &sqlx::PgPool) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS _sql_migrations (
domain TEXT NOT NULL,
table_name TEXT NOT NULL,
version INTEGER NOT NULL,
applied_at TIMESTAMPTZ NOT NULL DEFAULT now(),
checksum TEXT NOT NULL DEFAULT '',
PRIMARY KEY (domain, table_name, version)
)
"#,
)
.execute(pool)
.await?;
Ok(())
}
async fn applied_set(
pool: &sqlx::PgPool,
) -> Result<BTreeMap<(String, String, u32), String>> {
let rows: Vec<(String, String, i32, String)> =
sqlx::query_as("SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version")
.fetch_all(pool)
.connect(&cli.database_url)
.await?;
Ok(rows
.into_iter()
.map(|(d, t, v, c)| ((d, t, v as u32), c))
.collect())
}
async fn record_migration(
pool: &sqlx::PgPool,
m: &Migration,
checksum: &str,
) -> Result<()> {
sqlx::query(
r#"
INSERT INTO _sql_migrations (domain, table_name, version, checksum)
VALUES ($1, $2, $3, $4)
ON CONFLICT DO NOTHING
"#,
)
.bind(&m.domain)
.bind(&m.table)
.bind(m.version as i32)
.bind(checksum)
.execute(pool)
.await?;
Ok(())
}
async fn delete_migration(pool: &sqlx::PgPool, m: &Migration) -> Result<()> {
sqlx::query(
"DELETE FROM _sql_migrations WHERE domain = $1 AND table_name = $2 AND version = $3",
)
.bind(&m.domain)
.bind(&m.table)
.bind(m.version as i32)
.execute(pool)
.await?;
Ok(())
}
fn compute_checksum(content: &str) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
content.hash(&mut hasher);
format!("{:x}", hasher.finish())
}
fn parse_depends_on(content: &str) -> Vec<String> {
content
.lines()
.filter_map(|line| {
let line = line.trim();
line.strip_prefix("-- depends_on:").map(|deps| {
deps.split(',')
.map(|d| d.trim().to_string())
.filter(|d| !d.is_empty())
.collect::<Vec<_>>()
})
})
.flatten()
.collect()
}
fn topo_sort(migrations: &mut [Migration]) -> Result<()> {
let table_to_idx: HashMap<String, usize> = migrations
.iter()
.enumerate()
.map(|(i, m)| (m.table.clone(), i))
.collect();
let n = migrations.len();
let mut in_degree = vec![0u32; n];
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
for (i, m) in migrations.iter().enumerate() {
for dep in &m.depends_on {
if let Some(&j) = table_to_idx.get(dep) {
adj[j].push(i);
in_degree[i] += 1;
}
}
}
let mut queue: VecDeque<usize> =
(0..n).filter(|&i| in_degree[i] == 0).collect();
let mut order = Vec::with_capacity(n);
while let Some(i) = queue.pop_front() {
order.push(i);
for &next in &adj[i] {
in_degree[next] -= 1;
if in_degree[next] == 0 {
queue.push_back(next);
}
}
}
if order.len() != n {
bail!("Circular dependency detected among migrations");
}
let original: Vec<Migration> = migrations.iter().cloned().collect();
for (slot, &idx) in order.iter().enumerate() {
migrations[slot] = original[idx].clone();
}
Ok(())
}
fn into_static(s: String) -> &'static str {
Box::leak(s.into_boxed_str())
}
async fn exec_sql(pool: &sqlx::PgPool, sql: &str) -> Result<()> {
sqlx::raw_sql(into_static(sql.to_owned()))
.execute(pool)
.await?;
Ok(())
}
async fn run_up(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
ensure_migrations_table(pool).await?;
let all = discover_migrations(sql_root)?;
let applied = applied_set(pool).await?;
let mut up_migrations: Vec<_> = all
.into_iter()
.filter(|m| m.direction == MigrationDir::Up)
.filter(|m| {
!applied.contains_key(&(
m.domain.clone(),
m.table.clone(),
m.version,
))
})
.collect();
if up_migrations.is_empty() {
info!("All migrations are already applied.");
return Ok(());
}
topo_sort(&mut up_migrations)?;
for m in &up_migrations {
let sql = std::fs::read_to_string(&m.path)
.context(format!("Failed to read {:?}", m.path))?;
let checksum = compute_checksum(&sql);
info!("Applying {}/{}/v{}", m.domain, m.table, m.version);
exec_sql(pool, &sql).await?;
record_migration(pool, m, &checksum).await?;
}
info!("Applied {} migration(s).", up_migrations.len());
Ok(())
}
async fn run_down(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
ensure_migrations_table(pool).await?;
let all = discover_migrations(sql_root)?;
let applied = applied_set(pool).await?;
let mut down_targets: Vec<_> = all
.into_iter()
.filter(|m| m.direction == MigrationDir::Down)
.filter(|m| {
applied.contains_key(&(
m.domain.clone(),
m.table.clone(),
m.version,
))
})
.collect();
down_targets.sort();
if down_targets.is_empty() {
info!("No migrations to roll back.");
return Ok(());
}
let m = &down_targets[down_targets.len() - 1];
let sql = std::fs::read_to_string(&m.path)?;
info!("Rolling back {}/{}/v{}", m.domain, m.table, m.version);
exec_sql(pool, &sql).await?;
delete_migration(pool, m).await?;
info!("Rolled back 1 migration.");
Ok(())
}
async fn run_fresh(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
info!("Dropping all tables and re-applying migrations...");
exec_sql(pool, "DROP TABLE IF EXISTS _sql_migrations CASCADE").await?;
let all = discover_migrations(sql_root)?;
let down_migrations: Vec<_> = all
.into_iter()
.filter(|m| m.direction == MigrationDir::Down)
.collect();
let mut drops: Vec<_> = down_migrations.iter().collect();
drops.sort();
drops.reverse();
for m in &drops {
let sql = std::fs::read_to_string(&m.path)?;
let _ = exec_sql(pool, &sql).await;
}
run_up(pool, sql_root).await
}
async fn run_list(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
ensure_migrations_table(pool).await?;
let all = discover_migrations(sql_root)?;
let applied = applied_set(pool).await?;
let up_migrations: Vec<_> = all
.into_iter()
.filter(|m| m.direction == MigrationDir::Up)
.collect();
println!(
"{:<20} {:<30} {:>8} {}",
"Domain", "Table", "Version", "Status"
);
println!("{:-<20} {:-<30} {:-<8} {:-<10}", "", "", "", "");
for m in &up_migrations {
let key = (m.domain.clone(), m.table.clone(), m.version);
let status = if applied.contains_key(&key) {
"Applied"
} else {
"Pending"
};
println!(
"{:<20} {:<30} {:>8} {}",
m.domain, m.table, m.version, status
);
}
Ok(())
migrate::run_up(&pool).await
}

View File

@ -4,6 +4,7 @@ use async_nats::{HeaderMap, jetstream};
use config::AppConfig;
use futures_util::StreamExt;
use tracing::{error, info, warn};
use track::CounterVec;
use crate::{
handler::{AckAction, MessageHandler},
@ -16,6 +17,7 @@ pub struct NatsConsumer {
max_deliver: i64,
retry_delay_secs: u64,
durable_name: String,
metrics: Option<track::MetricsRegistry>,
}
impl NatsConsumer {
@ -33,9 +35,15 @@ impl NatsConsumer {
max_deliver: config.nats_max_deliver(),
retry_delay_secs: config.nats_retry_delay_secs(),
durable_name: durable_name(group_id),
metrics: None,
})
}
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
self.producer.set_metrics(registry.clone());
self.metrics = Some(registry);
}
pub async fn start_consuming<H>(
&self,
topics: &[&str],
@ -61,9 +69,10 @@ impl NatsConsumer {
)
.await?;
info!("NATS consumer started subscribing to: {:?}", topics_owned);
info!(topics = ?topics_owned, durable = %self.durable_name, "NATS consumer started");
let producer = self.producer.clone();
let metrics = self.metrics.clone();
let max_deliver = self.max_deliver;
let retry_delay_secs = self.retry_delay_secs;
let handler = Arc::new(handler);
@ -73,10 +82,7 @@ impl NatsConsumer {
let mut messages = match messages {
Ok(messages) => messages,
Err(error) => {
error!(
"NATS error while opening consumer stream: {:?}",
error
);
error!(error = %error, "NATS error while opening consumer stream");
return;
}
};
@ -86,6 +92,7 @@ impl NatsConsumer {
Ok(message) => {
handle_message(
&producer,
metrics.as_ref(),
max_deliver,
retry_delay_secs,
handler.as_ref(),
@ -94,7 +101,7 @@ impl NatsConsumer {
.await;
}
Err(error) => {
error!("NATS error while consuming: {:?}", error);
error!(error = %error, "NATS error while consuming");
}
}
}
@ -106,6 +113,7 @@ impl NatsConsumer {
async fn handle_message<H>(
producer: &NatsProducer,
metrics: Option<&track::MetricsRegistry>,
max_deliver: i64,
retry_delay_secs: u64,
handler: &H,
@ -116,12 +124,17 @@ async fn handle_message<H>(
let subject = message.subject.to_string();
let payload = message.payload.clone();
let delivered = message.info().map(|info| info.delivered).unwrap_or(1);
record_queue_message(metrics, &subject, "received");
match handler.handle(&subject, &payload).await {
AckAction::Ack => ack_message(&message, &subject, "message").await,
AckAction::Ack => {
ack_message(metrics, &message, &subject, "message").await
}
AckAction::Nack => {
record_queue_message(metrics, &subject, "nack");
if let Err(error) = handle_nack(
producer,
metrics,
&message,
&subject,
&payload,
@ -131,9 +144,13 @@ async fn handle_message<H>(
)
.await
{
record_queue_message(metrics, &subject, "error");
error!(
"Failed to route NACKed message from subject {}: {:?}",
subject, error
subject = %subject,
delivered,
max_deliver,
error = %error,
"failed to route NACKed message"
);
}
}
@ -142,6 +159,7 @@ async fn handle_message<H>(
async fn handle_nack(
producer: &NatsProducer,
metrics: Option<&track::MetricsRegistry>,
message: &jetstream::Message,
subject: &str,
payload: &[u8],
@ -151,8 +169,11 @@ async fn handle_nack(
) -> anyhow::Result<()> {
if delivered < max_deliver {
warn!(
"Message in subject {} failed (NACK). Retrying delivery {}/{} in {} seconds",
subject, delivered, max_deliver, retry_delay_secs
subject,
delivered,
max_deliver,
retry_delay_secs,
"message NACKed, scheduling retry"
);
message
.ack_with(jetstream::AckKind::Nak(Some(Duration::from_secs(
@ -162,13 +183,17 @@ async fn handle_nack(
.map_err(|error| {
anyhow::anyhow!("failed to nack message: {error}")
})?;
record_queue_message(metrics, subject, "retry");
return Ok(());
}
let dlq_subject = format!("{subject}.dlq");
error!(
"Message in subject {} exceeded max deliver attempts ({}). Routing to DLQ: {}",
subject, max_deliver, dlq_subject
subject,
dlq_subject = %dlq_subject,
delivered,
max_deliver,
"message exceeded max deliver attempts, routing to DLQ"
);
let mut headers = HeaderMap::new();
@ -185,21 +210,68 @@ async fn handle_nack(
message.ack().await.map_err(|error| {
anyhow::anyhow!("failed to ack DLQ message: {error}")
})?;
record_queue_message(metrics, subject, "dlq");
record_queue_dlq(metrics, subject);
Ok(())
}
async fn ack_message(
metrics: Option<&track::MetricsRegistry>,
message: &jetstream::Message,
subject: &str,
description: &str,
) {
if let Err(error) = message.ack().await {
match message.ack().await {
Ok(()) => record_queue_message(metrics, subject, "ack"),
Err(error) => {
record_queue_message(metrics, subject, "ack_error");
error!(
"Failed to ack {} in subject {}: {:?}",
description, subject, error
subject,
description,
error = %error,
"failed to ack message"
);
}
}
}
fn record_queue_message(
metrics: Option<&track::MetricsRegistry>,
topic: &str,
status: &str,
) {
if let Some(metrics) = metrics {
queue_messages_vec(metrics)
.with_label_values(&[topic, status])
.inc();
}
}
fn record_queue_dlq(metrics: Option<&track::MetricsRegistry>, topic: &str) {
if let Some(metrics) = metrics {
queue_dlq_vec(metrics).with_label_values(&[topic]).inc();
}
}
fn queue_messages_vec(registry: &track::MetricsRegistry) -> CounterVec {
registry
.register_counter_vec(
"queue_messages_total",
"Total queue messages",
&["topic", "status"],
)
.expect("failed to register queue_messages_total")
}
fn queue_dlq_vec(registry: &track::MetricsRegistry) -> CounterVec {
registry
.register_counter_vec(
"queue_dlq_total",
"Total messages routed to DLQ",
&["topic"],
)
.expect("failed to register queue_dlq_total")
}
fn durable_name(name: &str) -> String {
name.replace('.', "-")

View File

@ -3,10 +3,12 @@ use std::time::Duration;
use async_nats::{HeaderMap, jetstream};
use config::AppConfig;
use serde::Serialize;
use track::CounterVec;
#[derive(Clone)]
pub struct NatsProducer {
jetstream: jetstream::Context,
metrics: Option<track::MetricsRegistry>,
}
impl NatsProducer {
@ -14,7 +16,14 @@ impl NatsProducer {
let jetstream = connect_jetstream(config).await?;
ensure_stream(config, &jetstream).await?;
Ok(Self { jetstream })
Ok(Self {
jetstream,
metrics: None,
})
}
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
self.metrics = Some(registry);
}
pub async fn send<T>(
@ -44,20 +53,38 @@ impl NatsProducer {
}
let subject = subject.to_string();
let publish_result: anyhow::Result<()> = async {
let publish = if headers.is_empty() {
self.jetstream
.publish(subject.clone(), payload.to_vec().into())
.await?
} else {
self.jetstream
.publish_with_headers(subject, headers, payload.to_vec().into())
.publish_with_headers(
subject.clone(),
headers,
payload.to_vec().into(),
)
.await?
};
tokio::time::timeout(Duration::from_secs(5), publish).await??;
Ok(())
}
.await;
self.record_published(&subject, publish_result.is_ok());
publish_result
}
fn record_published(&self, topic: &str, success: bool) {
if let Some(reg) = &self.metrics {
let status = if success { "published" } else { "error" };
queue_messages_vec(reg)
.with_label_values(&[topic, status])
.inc();
}
}
}
pub async fn connect_jetstream(
@ -88,3 +115,13 @@ pub async fn ensure_stream(
})
.await?)
}
fn queue_messages_vec(registry: &track::MetricsRegistry) -> CounterVec {
registry
.register_counter_vec(
"queue_messages_total",
"Total queue messages",
&["topic", "status"],
)
.expect("failed to register queue_messages_total")
}

View File

@ -10,6 +10,7 @@ use aws_sdk_s3::primitives::ByteStreamError;
pub use error::{StorageError, StorageResult};
pub use local::{LocalStorage, LocalStorageConfig};
pub use s3::{S3Storage, S3StorageConfig};
use track::CounterVec;
#[derive(Clone, Debug)]
pub enum AppStorageConfig {
@ -18,11 +19,60 @@ pub enum AppStorageConfig {
}
#[derive(Clone)]
pub enum AppStorage {
pub struct AppStorage {
inner: StorageBackend,
metrics: Option<track::MetricsRegistry>,
}
#[derive(Clone)]
enum StorageBackend {
Local(LocalStorage),
S3(S3Storage),
}
impl AppStorage {
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
self.metrics = Some(registry);
}
fn backend_name(&self) -> &str {
match &self.inner {
StorageBackend::Local(_) => "local",
StorageBackend::S3(_) => "s3",
}
}
fn record_upload(&self, bytes: usize) {
if let Some(reg) = &self.metrics {
storage_ops_vec(reg)
.with_label_values(&["upload", self.backend_name()])
.inc();
storage_bytes_vec(reg)
.with_label_values(&["upload"])
.inc_by(bytes as f64);
}
}
fn record_download(&self, bytes: usize) {
if let Some(reg) = &self.metrics {
storage_ops_vec(reg)
.with_label_values(&["download", self.backend_name()])
.inc();
storage_bytes_vec(reg)
.with_label_values(&["download"])
.inc_by(bytes as f64);
}
}
fn record_delete(&self) {
if let Some(reg) = &self.metrics {
storage_ops_vec(reg)
.with_label_values(&["delete", self.backend_name()])
.inc();
}
}
}
#[derive(Clone, Debug, Default)]
pub struct PutObjectOptions {
pub content_type: Option<String>,
@ -87,76 +137,109 @@ pub trait ObjectStorage: Send + Sync {
}
impl AppStorage {
#[tracing::instrument(skip(config))]
pub async fn init(config: AppStorageConfig) -> StorageResult<Self> {
match config {
let inner = match config {
AppStorageConfig::Local(config) => {
Ok(Self::Local(LocalStorage::connect(config).await?))
tracing::info!("initializing local storage");
StorageBackend::Local(LocalStorage::connect(config).await?)
}
AppStorageConfig::S3(config) => {
Ok(Self::S3(S3Storage::connect(config).await?))
}
tracing::info!(bucket = %config.bucket, region = %config.region, "initializing S3 storage");
StorageBackend::S3(S3Storage::connect(config).await?)
}
};
Ok(Self {
inner,
metrics: None,
})
}
}
#[async_trait]
impl ObjectStorage for AppStorage {
#[tracing::instrument(skip(self, body), fields(storage.key = %key))]
async fn put_stream(
&self,
key: &str,
body: ByteStream,
options: PutObjectOptions,
) -> StorageResult<StoredObject> {
match self {
Self::Local(storage) => {
let result = match &self.inner {
StorageBackend::Local(storage) => {
storage.put_stream(key, body, options).await
}
Self::S3(storage) => storage.put_stream(key, body, options).await,
StorageBackend::S3(storage) => {
storage.put_stream(key, body, options).await
}
};
if result.is_ok() {
self.record_upload(0);
}
result
}
#[tracing::instrument(skip(self, bytes), fields(storage.key = %key, storage.size = bytes.len()))]
async fn put_bytes(
&self,
key: &str,
bytes: Vec<u8>,
options: PutObjectOptions,
) -> StorageResult<StoredObject> {
match self {
Self::Local(storage) => {
let size = bytes.len();
let result = match &self.inner {
StorageBackend::Local(storage) => {
storage.put_bytes(key, bytes, options).await
}
Self::S3(storage) => storage.put_bytes(key, bytes, options).await,
StorageBackend::S3(storage) => {
storage.put_bytes(key, bytes, options).await
}
};
if result.is_ok() {
self.record_upload(size);
}
result
}
#[tracing::instrument(skip(self), fields(storage.key = %key))]
async fn get_stream(
&self,
key: &str,
) -> StorageResult<StorageObjectStream> {
match self {
Self::Local(storage) => storage.get_stream(key).await,
Self::S3(storage) => storage.get_stream(key).await,
match &self.inner {
StorageBackend::Local(storage) => storage.get_stream(key).await,
StorageBackend::S3(storage) => storage.get_stream(key).await,
}
}
#[tracing::instrument(skip(self), fields(storage.key = %key))]
async fn get_bytes(&self, key: &str) -> StorageResult<StorageObject> {
match self {
Self::Local(storage) => storage.get_bytes(key).await,
Self::S3(storage) => storage.get_bytes(key).await,
let result = match &self.inner {
StorageBackend::Local(storage) => storage.get_bytes(key).await,
StorageBackend::S3(storage) => storage.get_bytes(key).await,
};
if let Ok(obj) = &result {
self.record_download(obj.bytes.len());
}
result
}
#[tracing::instrument(skip(self), fields(storage.key = %key))]
async fn delete(&self, key: &str) -> StorageResult<()> {
match self {
Self::Local(storage) => storage.delete(key).await,
Self::S3(storage) => storage.delete(key).await,
let result = match &self.inner {
StorageBackend::Local(storage) => storage.delete(key).await,
StorageBackend::S3(storage) => storage.delete(key).await,
};
if result.is_ok() {
self.record_delete();
}
result
}
fn public_url(&self, key: &str) -> StorageResult<Option<String>> {
match self {
Self::Local(storage) => storage.public_url(key),
Self::S3(storage) => storage.public_url(key),
match &self.inner {
StorageBackend::Local(storage) => storage.public_url(key),
StorageBackend::S3(storage) => storage.public_url(key),
}
}
@ -165,17 +248,37 @@ impl ObjectStorage for AppStorage {
key: &str,
expires_in: Duration,
) -> StorageResult<String> {
match self {
Self::Local(storage) => {
match &self.inner {
StorageBackend::Local(storage) => {
storage.presigned_get_url(key, expires_in).await
}
Self::S3(storage) => {
StorageBackend::S3(storage) => {
storage.presigned_get_url(key, expires_in).await
}
}
}
}
fn storage_ops_vec(registry: &track::MetricsRegistry) -> CounterVec {
registry
.register_counter_vec(
"storage_operations_total",
"Total storage operations",
&["operation", "backend"],
)
.expect("failed to register storage_operations_total")
}
fn storage_bytes_vec(registry: &track::MetricsRegistry) -> CounterVec {
registry
.register_counter_vec(
"storage_bytes_total",
"Total bytes transferred",
&["operation"],
)
.expect("failed to register storage_bytes_total")
}
pub async fn collect_byte_stream(
body: ByteStream,
) -> Result<Vec<u8>, ByteStreamError> {