refactor: update infrastructure libs (config, db, cache, queue, storage, migrate)
This commit is contained in:
parent
e44f3d13c4
commit
734e1c4cc8
319
lib/cache/app.rs
vendored
Normal file
319
lib/cache/app.rs
vendored
Normal file
@ -0,0 +1,319 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use track::CounterVec;
|
||||
|
||||
use crate::{
|
||||
cluster::{ClusterCache, ClusterCacheConfig},
|
||||
error::{CacheError, CacheResult},
|
||||
local::{LocalCacheConfig, MokaCache},
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Configuration
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AppCacheConfig {
|
||||
pub local: LocalCacheConfig,
|
||||
pub cluster: Option<ClusterCacheConfig>,
|
||||
pub default_ttl: Option<Duration>,
|
||||
pub cluster_write_through: bool,
|
||||
}
|
||||
|
||||
impl Default for AppCacheConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
local: LocalCacheConfig::default(),
|
||||
cluster: None,
|
||||
default_ttl: Some(Duration::from_secs(300)),
|
||||
cluster_write_through: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&config::AppConfig> for AppCacheConfig {
|
||||
type Error = CacheError;
|
||||
|
||||
fn try_from(config: &config::AppConfig) -> Result<Self, Self::Error> {
|
||||
let local = LocalCacheConfig {
|
||||
max_capacity: config
|
||||
.cache_local_max_capacity()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
time_to_live: config
|
||||
.cache_local_ttl()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
time_to_idle: config
|
||||
.cache_local_tti()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
};
|
||||
|
||||
let cluster = if config
|
||||
.cache_cluster_enabled()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?
|
||||
{
|
||||
Some(ClusterCacheConfig {
|
||||
urls: config
|
||||
.redis_urls()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
key_prefix: config.cache_cluster_key_prefix(),
|
||||
command_timeout: config
|
||||
.cache_cluster_command_timeout()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
local,
|
||||
cluster,
|
||||
default_ttl: config
|
||||
.cache_default_ttl()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
cluster_write_through: config
|
||||
.cache_cluster_write_through()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// AppCache
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppCache {
|
||||
pub local: MokaCache,
|
||||
pub cluster: Option<ClusterCache>,
|
||||
default_ttl: Option<Duration>,
|
||||
cluster_write_through: bool,
|
||||
metrics: Option<track::MetricsRegistry>,
|
||||
}
|
||||
|
||||
impl AppCache {
|
||||
#[tracing::instrument(skip(config))]
|
||||
pub async fn init(config: AppCacheConfig) -> CacheResult<Self> {
|
||||
let local = MokaCache::with_config(config.local);
|
||||
let cluster = match config.cluster {
|
||||
Some(cluster) => Some(match ClusterCache::connect(cluster).await {
|
||||
Ok(cluster) => cluster,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "failed to connect to cache cluster");
|
||||
return Err(e);
|
||||
}
|
||||
}),
|
||||
None => None,
|
||||
};
|
||||
|
||||
tracing::info!(has_cluster = cluster.is_some(), "cache initialized");
|
||||
Ok(Self {
|
||||
local,
|
||||
cluster,
|
||||
default_ttl: config.default_ttl,
|
||||
cluster_write_through: config.cluster_write_through,
|
||||
metrics: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn local_only(local: MokaCache) -> Self {
|
||||
Self {
|
||||
local,
|
||||
cluster: None,
|
||||
default_ttl: None,
|
||||
cluster_write_through: false,
|
||||
metrics: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Attach a metrics registry for recording cache counters.
|
||||
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
|
||||
self.metrics = Some(registry);
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), fields(cache.key = %key))]
|
||||
pub async fn get<T>(&self, key: &str) -> CacheResult<Option<T>>
|
||||
where
|
||||
T: serde::Serialize + serde::de::DeserializeOwned,
|
||||
{
|
||||
if let Some(value) = self.local.get(key).await? {
|
||||
tracing::debug!("cache hit (local)");
|
||||
self.record_hit("local");
|
||||
return Ok(Some(value));
|
||||
}
|
||||
|
||||
let Some(cluster) = &self.cluster else {
|
||||
tracing::debug!("cache miss");
|
||||
self.record_miss();
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let value = cluster.get::<T>(key).await?;
|
||||
if let Some(value) = &value {
|
||||
self.local.set(key, value).await?;
|
||||
tracing::debug!("cache hit (cluster)");
|
||||
self.record_hit("cluster");
|
||||
} else {
|
||||
tracing::debug!("cache miss");
|
||||
self.record_miss();
|
||||
}
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, value), fields(cache.key = %key))]
|
||||
pub async fn set<T>(&self, key: &str, value: &T) -> CacheResult<()>
|
||||
where
|
||||
T: serde::Serialize + ?Sized,
|
||||
{
|
||||
self.local.set(key, value).await?;
|
||||
if self.cluster_write_through
|
||||
&& let Some(cluster) = &self.cluster
|
||||
{
|
||||
cluster.set(key, value, self.default_ttl).await?;
|
||||
}
|
||||
self.record_set();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set_with_ttl<T>(
|
||||
&self,
|
||||
key: &str,
|
||||
value: &T,
|
||||
ttl: std::time::Duration,
|
||||
) -> CacheResult<()>
|
||||
where
|
||||
T: serde::Serialize + ?Sized,
|
||||
{
|
||||
self.local.set(key, value).await?;
|
||||
if self.cluster_write_through
|
||||
&& let Some(cluster) = &self.cluster
|
||||
{
|
||||
cluster.set(key, value, Some(ttl)).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), fields(cache.key = %key))]
|
||||
pub async fn remove(&self, key: &str) -> CacheResult<()> {
|
||||
self.local.remove(key).await;
|
||||
if let Some(cluster) = &self.cluster {
|
||||
cluster.remove(key).await?;
|
||||
}
|
||||
self.record_remove();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn record_hit(&self, tier: &str) {
|
||||
if let Some(reg) = &self.metrics {
|
||||
cache_hits_vec(reg).with_label_values(&[tier]).inc();
|
||||
}
|
||||
}
|
||||
|
||||
fn record_miss(&self) {
|
||||
if let Some(reg) = &self.metrics {
|
||||
cache_misses_vec(reg).with_label_values(&[]).inc();
|
||||
}
|
||||
}
|
||||
|
||||
fn record_set(&self) {
|
||||
if let Some(reg) = &self.metrics {
|
||||
cache_sets_vec(reg).with_label_values(&[]).inc();
|
||||
}
|
||||
}
|
||||
|
||||
fn record_remove(&self) {
|
||||
if let Some(reg) = &self.metrics {
|
||||
cache_removes_vec(reg).with_label_values(&[]).inc();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_pattern(&self, pattern: &str) -> CacheResult<u64> {
|
||||
let pattern = pattern.to_string();
|
||||
let local_pattern = pattern.clone();
|
||||
self.local.invalidate_entries_if(move |key| {
|
||||
simple_glob_match(&local_pattern, key)
|
||||
});
|
||||
|
||||
let mut removed = 0u64;
|
||||
if let Some(cluster) = &self.cluster {
|
||||
removed = cluster.delete_pattern(&pattern).await?;
|
||||
}
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
pub async fn ping_cluster(&self) -> CacheResult<()> {
|
||||
if let Some(cluster) = &self.cluster {
|
||||
cluster.ping().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn conn(&self) -> Option<redis::cluster_async::ClusterConnection> {
|
||||
self.cluster.as_ref().map(|c| c.conn())
|
||||
}
|
||||
}
|
||||
|
||||
fn cache_hits_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||
registry
|
||||
.register_counter_vec("cache_hits_total", "Total cache hits", &["tier"])
|
||||
.expect("failed to register cache_hits_total")
|
||||
}
|
||||
|
||||
fn cache_misses_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||
registry
|
||||
.register_counter_vec("cache_misses_total", "Total cache misses", &[])
|
||||
.expect("failed to register cache_misses_total")
|
||||
}
|
||||
|
||||
fn cache_sets_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||
registry
|
||||
.register_counter_vec(
|
||||
"cache_sets_total",
|
||||
"Total cache set operations",
|
||||
&[],
|
||||
)
|
||||
.expect("failed to register cache_sets_total")
|
||||
}
|
||||
|
||||
fn cache_removes_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||
registry
|
||||
.register_counter_vec(
|
||||
"cache_removes_total",
|
||||
"Total cache remove operations",
|
||||
&[],
|
||||
)
|
||||
.expect("failed to register cache_removes_total")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helpers
|
||||
// ============================================================================
|
||||
|
||||
fn simple_glob_match(pattern: &str, key: &str) -> bool {
|
||||
let p = pattern.as_bytes();
|
||||
let k = key.as_bytes();
|
||||
let (mut pi, mut ki) = (0usize, 0usize);
|
||||
|
||||
let mut backtrack_p: Option<usize> = None;
|
||||
let mut backtrack_k: usize = 0;
|
||||
|
||||
loop {
|
||||
if pi < p.len() && ki < k.len() && (p[pi] == b'?' || p[pi] == k[ki]) {
|
||||
pi += 1;
|
||||
ki += 1;
|
||||
} else if pi < p.len() && p[pi] == b'*' {
|
||||
backtrack_p = Some(pi);
|
||||
backtrack_k = ki;
|
||||
pi += 1;
|
||||
} else if let Some(saved_pi) = backtrack_p {
|
||||
backtrack_k += 1;
|
||||
ki = backtrack_k;
|
||||
pi = saved_pi + 1;
|
||||
} else {
|
||||
return pi == p.len() && ki == k.len();
|
||||
}
|
||||
|
||||
if pi == p.len() && ki == k.len() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
220
lib/cache/lib.rs
vendored
220
lib/cache/lib.rs
vendored
@ -1,227 +1,11 @@
|
||||
pub mod app;
|
||||
pub mod cluster;
|
||||
pub mod error;
|
||||
pub mod local;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
pub use crate::{
|
||||
app::{AppCache, AppCacheConfig},
|
||||
cluster::{ClusterCache, ClusterCacheConfig},
|
||||
error::{CacheError, CacheResult},
|
||||
local::{LocalCacheConfig, MokaCache},
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AppCacheConfig {
|
||||
pub local: LocalCacheConfig,
|
||||
pub cluster: Option<ClusterCacheConfig>,
|
||||
pub default_ttl: Option<Duration>,
|
||||
pub cluster_write_through: bool,
|
||||
}
|
||||
|
||||
impl Default for AppCacheConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
local: LocalCacheConfig::default(),
|
||||
cluster: None,
|
||||
default_ttl: Some(Duration::from_secs(300)),
|
||||
cluster_write_through: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppCache {
|
||||
pub local: MokaCache,
|
||||
pub cluster: Option<ClusterCache>,
|
||||
default_ttl: Option<Duration>,
|
||||
cluster_write_through: bool,
|
||||
}
|
||||
|
||||
impl AppCache {
|
||||
pub async fn init(config: AppCacheConfig) -> CacheResult<Self> {
|
||||
let local = MokaCache::with_config(config.local);
|
||||
let cluster = match config.cluster {
|
||||
Some(cluster) => Some(match ClusterCache::connect(cluster).await {
|
||||
Ok(cluster) => cluster,
|
||||
Err(e) => {
|
||||
println!("cache:init:error with: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
}),
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
local,
|
||||
cluster,
|
||||
default_ttl: config.default_ttl,
|
||||
cluster_write_through: config.cluster_write_through,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn local_only(local: MokaCache) -> Self {
|
||||
Self {
|
||||
local,
|
||||
cluster: None,
|
||||
default_ttl: None,
|
||||
cluster_write_through: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get<T>(&self, key: &str) -> CacheResult<Option<T>>
|
||||
where
|
||||
T: serde::Serialize + serde::de::DeserializeOwned,
|
||||
{
|
||||
if let Some(value) = self.local.get(key).await? {
|
||||
return Ok(Some(value));
|
||||
}
|
||||
|
||||
let Some(cluster) = &self.cluster else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let value = cluster.get::<T>(key).await?;
|
||||
if let Some(value) = &value {
|
||||
self.local.set(key, value).await?;
|
||||
}
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
pub async fn set<T>(&self, key: &str, value: &T) -> CacheResult<()>
|
||||
where
|
||||
T: serde::Serialize + ?Sized,
|
||||
{
|
||||
self.local.set(key, value).await?;
|
||||
if self.cluster_write_through
|
||||
&& let Some(cluster) = &self.cluster
|
||||
{
|
||||
cluster.set(key, value, self.default_ttl).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set_with_ttl<T>(
|
||||
&self,
|
||||
key: &str,
|
||||
value: &T,
|
||||
ttl: std::time::Duration,
|
||||
) -> CacheResult<()>
|
||||
where
|
||||
T: serde::Serialize + ?Sized,
|
||||
{
|
||||
self.local.set(key, value).await?;
|
||||
if self.cluster_write_through
|
||||
&& let Some(cluster) = &self.cluster
|
||||
{
|
||||
cluster.set(key, value, Some(ttl)).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove(&self, key: &str) -> CacheResult<()> {
|
||||
self.local.remove(key).await;
|
||||
if let Some(cluster) = &self.cluster {
|
||||
cluster.remove(key).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
pub async fn delete_pattern(&self, pattern: &str) -> CacheResult<u64> {
|
||||
let pattern = pattern.to_string();
|
||||
let local_pattern = pattern.clone();
|
||||
self.local.invalidate_entries_if(move |key| {
|
||||
simple_glob_match(&local_pattern, key)
|
||||
});
|
||||
|
||||
let mut removed = 0u64;
|
||||
if let Some(cluster) = &self.cluster {
|
||||
removed = cluster.delete_pattern(&pattern).await?;
|
||||
}
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
pub async fn ping_cluster(&self) -> CacheResult<()> {
|
||||
if let Some(cluster) = &self.cluster {
|
||||
cluster.ping().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn conn(&self) -> Option<redis::cluster_async::ClusterConnection> {
|
||||
self.cluster.as_ref().map(|c| c.conn())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&config::AppConfig> for AppCacheConfig {
|
||||
type Error = CacheError;
|
||||
|
||||
fn try_from(config: &config::AppConfig) -> Result<Self, Self::Error> {
|
||||
let local = LocalCacheConfig {
|
||||
max_capacity: config
|
||||
.cache_local_max_capacity()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
time_to_live: config
|
||||
.cache_local_ttl()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
time_to_idle: config
|
||||
.cache_local_tti()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
};
|
||||
|
||||
let cluster = if config
|
||||
.cache_cluster_enabled()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?
|
||||
{
|
||||
Some(ClusterCacheConfig {
|
||||
urls: config
|
||||
.redis_urls()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
key_prefix: config.cache_cluster_key_prefix(),
|
||||
command_timeout: config
|
||||
.cache_cluster_command_timeout()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
local,
|
||||
cluster,
|
||||
default_ttl: config
|
||||
.cache_default_ttl()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
cluster_write_through: config
|
||||
.cache_cluster_write_through()
|
||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||
})
|
||||
}
|
||||
}
|
||||
fn simple_glob_match(pattern: &str, key: &str) -> bool {
|
||||
let p = pattern.as_bytes();
|
||||
let k = key.as_bytes();
|
||||
let (mut pi, mut ki) = (0usize, 0usize);
|
||||
|
||||
let mut backtrack_p: Option<usize> = None;
|
||||
let mut backtrack_k: usize = 0;
|
||||
|
||||
loop {
|
||||
if pi < p.len() && ki < k.len() && (p[pi] == b'?' || p[pi] == k[ki]) {
|
||||
pi += 1;
|
||||
ki += 1;
|
||||
} else if pi < p.len() && p[pi] == b'*' {
|
||||
backtrack_p = Some(pi);
|
||||
backtrack_k = ki;
|
||||
pi += 1;
|
||||
} else if let Some(saved_pi) = backtrack_p {
|
||||
backtrack_k += 1;
|
||||
ki = backtrack_k;
|
||||
pi = saved_pi + 1;
|
||||
} else {
|
||||
return pi == p.len() && ki == k.len();
|
||||
}
|
||||
|
||||
if pi == p.len() && ki == k.len() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,6 +28,13 @@ impl AppConfig {
|
||||
Ok(8080)
|
||||
}
|
||||
|
||||
pub fn email_health_port(&self) -> u16 {
|
||||
self.env
|
||||
.get("APP_EMAIL_HEALTH_PORT")
|
||||
.and_then(|port| port.parse::<u16>().ok())
|
||||
.unwrap_or(8083)
|
||||
}
|
||||
|
||||
pub fn session_secret(&self) -> anyhow::Result<String> {
|
||||
if let Some(secret) = self.env.get("APP_SESSION_SECRET") {
|
||||
return Ok(secret.to_string());
|
||||
|
||||
39
lib/config/app_config.rs
Normal file
39
lib/config/app_config.rs
Normal file
@ -0,0 +1,39 @@
|
||||
use std::{collections::HashMap, sync::OnceLock};
|
||||
|
||||
pub static GLOBAL_CONFIG: OnceLock<AppConfig> = OnceLock::new();
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AppConfig {
|
||||
pub env: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
const ENV_FILES: &'static [&'static str] = &[".env", ".env.local"];
|
||||
|
||||
pub fn load() -> AppConfig {
|
||||
let mut env = HashMap::new();
|
||||
for env_file in AppConfig::ENV_FILES {
|
||||
if let Err(e) = dotenvy::from_path(env_file) {
|
||||
tracing::debug!(file = %env_file, error = %e, "dotenv load skipped");
|
||||
}
|
||||
if let Ok(env_file_content) = std::fs::read_to_string(env_file) {
|
||||
for line in env_file_content.lines() {
|
||||
if let Some((key, value)) = line.split_once('=') {
|
||||
env.insert(key.to_string(), value.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
env = env.into_iter().chain(std::env::vars()).collect();
|
||||
let this = AppConfig { env };
|
||||
if let Some(config) = GLOBAL_CONFIG.get() {
|
||||
config.clone()
|
||||
} else {
|
||||
let _ = GLOBAL_CONFIG.set(this);
|
||||
GLOBAL_CONFIG
|
||||
.get()
|
||||
.expect("global config should be set after load")
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,44 +1,6 @@
|
||||
use std::{collections::HashMap, sync::OnceLock};
|
||||
|
||||
pub static GLOBAL_CONFIG: OnceLock<AppConfig> = OnceLock::new();
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AppConfig {
|
||||
pub env: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
const ENV_FILES: &'static [&'static str] = &[".env", ".env.local"];
|
||||
pub fn load() -> AppConfig {
|
||||
let mut env = HashMap::new();
|
||||
for env_file in AppConfig::ENV_FILES {
|
||||
if let Err(e) = dotenvy::from_path(env_file) {
|
||||
tracing::debug!(file = %env_file, error = %e, "dotenv load skipped");
|
||||
}
|
||||
if let Ok(env_file_content) = std::fs::read_to_string(env_file) {
|
||||
for line in env_file_content.lines() {
|
||||
if let Some((key, value)) = line.split_once('=') {
|
||||
env.insert(key.to_string(), value.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
env = env.into_iter().chain(std::env::vars()).collect();
|
||||
let this = AppConfig { env };
|
||||
if GLOBAL_CONFIG.get().is_some() {
|
||||
GLOBAL_CONFIG.get().unwrap().clone()
|
||||
} else {
|
||||
let _ = GLOBAL_CONFIG.set(this);
|
||||
GLOBAL_CONFIG
|
||||
.get()
|
||||
.expect("global config should be set after load")
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod ai;
|
||||
pub mod app;
|
||||
pub mod app_config;
|
||||
pub mod auth;
|
||||
pub mod avatar;
|
||||
pub mod cache;
|
||||
@ -57,3 +19,5 @@ pub mod redis;
|
||||
pub mod smtp;
|
||||
pub mod ssh;
|
||||
pub mod storage;
|
||||
|
||||
pub use app_config::{AppConfig, GLOBAL_CONFIG};
|
||||
|
||||
@ -61,7 +61,7 @@ impl AppConfig {
|
||||
if let Some(endpoint) = self.env.get("APP_OTEL_ENDPOINT") {
|
||||
return Ok(endpoint.to_string());
|
||||
}
|
||||
Ok("http://localhost:5080/api/default/v1/traces".to_string())
|
||||
Ok("http://localhost:4318".to_string())
|
||||
}
|
||||
|
||||
pub fn otel_service_name(&self) -> anyhow::Result<String> {
|
||||
|
||||
@ -7,6 +7,7 @@ use sqlx::{
|
||||
PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow,
|
||||
},
|
||||
};
|
||||
use track::{CounterVec, HistogramVec};
|
||||
|
||||
use crate::{
|
||||
route::{SqlRoute, route_sql},
|
||||
@ -17,9 +18,11 @@ use crate::{
|
||||
pub struct AppDatabase {
|
||||
db_write: PgPool,
|
||||
db_read: Option<PgPool>,
|
||||
metrics: Option<track::MetricsRegistry>,
|
||||
}
|
||||
|
||||
impl AppDatabase {
|
||||
#[tracing::instrument(skip(cfg))]
|
||||
pub async fn init(cfg: &AppConfig) -> anyhow::Result<Self> {
|
||||
let db_url = cfg.database_url()?;
|
||||
let max_connections = cfg.database_max_connections()?;
|
||||
@ -69,7 +72,15 @@ impl AppDatabase {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self { db_write, db_read })
|
||||
Ok(Self {
|
||||
db_write,
|
||||
db_read,
|
||||
metrics: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
|
||||
self.metrics = Some(registry);
|
||||
}
|
||||
|
||||
pub fn writer(&self) -> &PgPool {
|
||||
@ -87,11 +98,14 @@ impl AppDatabase {
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), fields(sql.route = "write"))]
|
||||
pub async fn begin(&self) -> Result<AppTransaction<'_>, sqlx::Error> {
|
||||
let txn = self.db_write.begin().await?;
|
||||
tracing::debug!("db transaction started");
|
||||
Ok(AppTransaction { inner: txn })
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), fields(sql.route = "read"))]
|
||||
pub async fn begin_read_only(
|
||||
&self,
|
||||
) -> Result<AppTransaction<'_>, sqlx::Error> {
|
||||
@ -101,9 +115,11 @@ impl AppDatabase {
|
||||
.execute(&mut *txn)
|
||||
.await?;
|
||||
|
||||
tracing::debug!("db read-only transaction started");
|
||||
Ok(AppTransaction { inner: txn })
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sql), fields(sql.kind = "execute"))]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
sql: &str,
|
||||
@ -111,18 +127,38 @@ impl AppDatabase {
|
||||
self.execute_with_args(sql, PgArguments::default()).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "execute"))]
|
||||
pub async fn execute_with_args(
|
||||
&self,
|
||||
sql: &str,
|
||||
args: PgArguments,
|
||||
) -> Result<PgQueryResult, sqlx::Error> {
|
||||
let pool = self.route_pool(sql);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
sqlx::query_with(AssertSqlSafe(sql.to_owned()), args)
|
||||
let result = sqlx::query_with(AssertSqlSafe(sql.to_owned()), args)
|
||||
.execute(pool)
|
||||
.await
|
||||
.await;
|
||||
|
||||
let kind = if sql.trim_start().to_uppercase().starts_with("INSERT") {
|
||||
"insert"
|
||||
} else if sql.trim_start().to_uppercase().starts_with("UPDATE") {
|
||||
"update"
|
||||
} else if sql.trim_start().to_uppercase().starts_with("DELETE") {
|
||||
"delete"
|
||||
} else {
|
||||
"execute"
|
||||
};
|
||||
self.record_query(
|
||||
kind,
|
||||
self.route_label(sql),
|
||||
start.elapsed(),
|
||||
result.is_ok(),
|
||||
);
|
||||
result
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_one"))]
|
||||
pub async fn fetch_one<T>(&self, sql: &str) -> Result<T, sqlx::Error>
|
||||
where
|
||||
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
||||
@ -130,6 +166,7 @@ impl AppDatabase {
|
||||
self.fetch_one_with_args(sql, PgArguments::default()).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_one"))]
|
||||
pub async fn fetch_one_with_args<T>(
|
||||
&self,
|
||||
sql: &str,
|
||||
@ -139,12 +176,23 @@ impl AppDatabase {
|
||||
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
||||
{
|
||||
let pool = self.route_pool(sql);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
let result =
|
||||
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
self.record_query(
|
||||
"select",
|
||||
self.route_label(sql),
|
||||
start.elapsed(),
|
||||
result.is_ok(),
|
||||
);
|
||||
result
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_optional"))]
|
||||
pub async fn fetch_optional<T>(
|
||||
&self,
|
||||
sql: &str,
|
||||
@ -156,6 +204,7 @@ impl AppDatabase {
|
||||
.await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_optional"))]
|
||||
pub async fn fetch_optional_with_args<T>(
|
||||
&self,
|
||||
sql: &str,
|
||||
@ -165,12 +214,23 @@ impl AppDatabase {
|
||||
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
||||
{
|
||||
let pool = self.route_pool(sql);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
let result =
|
||||
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
||||
.fetch_optional(pool)
|
||||
.await;
|
||||
|
||||
self.record_query(
|
||||
"select",
|
||||
self.route_label(sql),
|
||||
start.elapsed(),
|
||||
result.is_ok(),
|
||||
);
|
||||
result
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_all"))]
|
||||
pub async fn fetch_all<T>(&self, sql: &str) -> Result<Vec<T>, sqlx::Error>
|
||||
where
|
||||
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
||||
@ -178,6 +238,7 @@ impl AppDatabase {
|
||||
self.fetch_all_with_args(sql, PgArguments::default()).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_all"))]
|
||||
pub async fn fetch_all_with_args<T>(
|
||||
&self,
|
||||
sql: &str,
|
||||
@ -187,10 +248,45 @@ impl AppDatabase {
|
||||
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
||||
{
|
||||
let pool = self.route_pool(sql);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
let result =
|
||||
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
self.record_query(
|
||||
"select",
|
||||
self.route_label(sql),
|
||||
start.elapsed(),
|
||||
result.is_ok(),
|
||||
);
|
||||
result
|
||||
}
|
||||
|
||||
fn route_label(&self, sql: &str) -> &str {
|
||||
match route_sql(sql) {
|
||||
SqlRoute::Write => "write",
|
||||
SqlRoute::Read => "read",
|
||||
}
|
||||
}
|
||||
|
||||
fn record_query(
|
||||
&self,
|
||||
kind: &str,
|
||||
route: &str,
|
||||
duration: Duration,
|
||||
success: bool,
|
||||
) {
|
||||
if let Some(reg) = &self.metrics {
|
||||
let status = if success { "success" } else { "error" };
|
||||
db_queries_vec(reg)
|
||||
.with_label_values(&[kind, route, status])
|
||||
.inc();
|
||||
db_query_duration_vec(reg)
|
||||
.with_label_values(&[kind, route])
|
||||
.observe(duration.as_secs_f64());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -230,3 +326,26 @@ async fn build_pool(
|
||||
|
||||
pool_options.connect_with(options).await
|
||||
}
|
||||
|
||||
fn db_queries_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||
registry
|
||||
.register_counter_vec(
|
||||
"db_queries_total",
|
||||
"Total database queries",
|
||||
&["kind", "route", "status"],
|
||||
)
|
||||
.expect("failed to register db_queries_total")
|
||||
}
|
||||
|
||||
fn db_query_duration_vec(registry: &track::MetricsRegistry) -> HistogramVec {
|
||||
registry
|
||||
.register_histogram_vec(
|
||||
"db_query_duration_seconds",
|
||||
"DB query duration in seconds",
|
||||
&["kind", "route"],
|
||||
vec![
|
||||
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0,
|
||||
],
|
||||
)
|
||||
.expect("failed to register db_query_duration_seconds")
|
||||
}
|
||||
|
||||
267
lib/migrate/lib.rs
Normal file
267
lib/migrate/lib.rs
Normal file
@ -0,0 +1,267 @@
|
||||
use std::collections::{BTreeMap, HashMap, VecDeque};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use sqlx::PgPool;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct Migration {
|
||||
domain: String,
|
||||
table: String,
|
||||
version: u32,
|
||||
direction: MigrationDir,
|
||||
path: PathBuf,
|
||||
depends_on: Vec<String>,
|
||||
}
|
||||
|
||||
impl Ord for Migration {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
(&self.domain, &self.table, self.version, &self.direction).cmp(&(
|
||||
&other.domain,
|
||||
&other.table,
|
||||
other.version,
|
||||
&other.direction,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for Migration {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
enum MigrationDir {
|
||||
Up,
|
||||
Down,
|
||||
}
|
||||
|
||||
pub async fn run_up(pool: &PgPool) -> Result<()> {
|
||||
let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql");
|
||||
ensure_migrations_table(pool).await?;
|
||||
let all = discover_migrations(&sql_root)?;
|
||||
let applied = applied_set(pool).await?;
|
||||
|
||||
let mut up_migrations: Vec<_> = all
|
||||
.into_iter()
|
||||
.filter(|m| m.direction == MigrationDir::Up)
|
||||
.filter(|m| {
|
||||
!applied.contains_key(&(
|
||||
m.domain.clone(),
|
||||
m.table.clone(),
|
||||
m.version,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
if up_migrations.is_empty() {
|
||||
tracing::info!("All migrations are already applied.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
topo_sort(&mut up_migrations)?;
|
||||
|
||||
for m in &up_migrations {
|
||||
let sql = std::fs::read_to_string(&m.path)
|
||||
.context(format!("Failed to read {:?}", m.path))?;
|
||||
let checksum = compute_checksum(&sql);
|
||||
|
||||
tracing::info!(domain = %m.domain, table = %m.table, version = m.version, "applying migration");
|
||||
exec_sql(pool, &sql).await?;
|
||||
record_migration(pool, m, &checksum).await?;
|
||||
}
|
||||
|
||||
tracing::info!("Applied {} migration(s).", up_migrations.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn discover_migrations(sql_root: &Path) -> Result<Vec<Migration>> {
|
||||
let mut migrations = Vec::new();
|
||||
|
||||
if !sql_root.exists() {
|
||||
bail!("SQL directory not found: {}", sql_root.display());
|
||||
}
|
||||
|
||||
for dir_entry in std::fs::read_dir(sql_root)? {
|
||||
let dir = dir_entry?;
|
||||
if !dir.file_type()?.is_dir() {
|
||||
continue;
|
||||
}
|
||||
let domain = dir.file_name().to_string_lossy().to_string();
|
||||
|
||||
for file_entry in std::fs::read_dir(dir.path())? {
|
||||
let file = file_entry?;
|
||||
let path = file.path();
|
||||
if path.extension().and_then(|e| e.to_str()) != Some("sql") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let stem = path
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.context("Invalid filename")?;
|
||||
|
||||
let (table, direction, version) = parse_migration_stem(stem)?;
|
||||
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.context(format!("Failed to read {path:?}"))?;
|
||||
let depends_on = parse_depends_on(&content);
|
||||
|
||||
migrations.push(Migration {
|
||||
domain: domain.clone(),
|
||||
table,
|
||||
version,
|
||||
direction,
|
||||
path,
|
||||
depends_on,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
migrations.sort();
|
||||
Ok(migrations)
|
||||
}
|
||||
|
||||
fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> {
|
||||
if let Some(pos) = stem.rfind("_up_") {
|
||||
let table = stem[..pos].to_string();
|
||||
let version = stem[pos + 4..]
|
||||
.parse::<u32>()
|
||||
.context("Invalid version number")?;
|
||||
Ok((table, MigrationDir::Up, version))
|
||||
} else if let Some(pos) = stem.rfind("_down_") {
|
||||
let table = stem[..pos].to_string();
|
||||
let version = stem[pos + 6..]
|
||||
.parse::<u32>()
|
||||
.context("Invalid version number")?;
|
||||
Ok((table, MigrationDir::Down, version))
|
||||
} else {
|
||||
bail!("Migration filename must contain _up_ or _down_: {stem}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn ensure_migrations_table(pool: &PgPool) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS _sql_migrations (
|
||||
domain TEXT NOT NULL,
|
||||
table_name TEXT NOT NULL,
|
||||
version INTEGER NOT NULL,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
checksum TEXT NOT NULL DEFAULT '',
|
||||
PRIMARY KEY (domain, table_name, version)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn applied_set(
|
||||
pool: &PgPool,
|
||||
) -> Result<BTreeMap<(String, String, u32), String>> {
|
||||
let rows: Vec<(String, String, i32, String)> = sqlx::query_as(
|
||||
"SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|(d, t, v, c)| ((d, t, v as u32), c))
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn record_migration(
|
||||
pool: &PgPool,
|
||||
m: &Migration,
|
||||
checksum: &str,
|
||||
) -> Result<()> {
|
||||
sqlx::query(
|
||||
"INSERT INTO _sql_migrations (domain, table_name, version, checksum) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING",
|
||||
)
|
||||
.bind(&m.domain)
|
||||
.bind(&m.table)
|
||||
.bind(m.version as i32)
|
||||
.bind(checksum)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_checksum(content: &str) -> String {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = DefaultHasher::new();
|
||||
content.hash(&mut hasher);
|
||||
format!("{:x}", hasher.finish())
|
||||
}
|
||||
|
||||
fn parse_depends_on(content: &str) -> Vec<String> {
|
||||
content
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let line = line.trim();
|
||||
line.strip_prefix("-- depends_on:").map(|deps| {
|
||||
deps.split(',')
|
||||
.map(|d| d.trim().to_string())
|
||||
.filter(|d| !d.is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
})
|
||||
.flatten()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn topo_sort(migrations: &mut [Migration]) -> Result<()> {
|
||||
let table_to_idx: HashMap<String, usize> = migrations
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, m)| (m.table.clone(), i))
|
||||
.collect();
|
||||
|
||||
let n = migrations.len();
|
||||
let mut in_degree = vec![0u32; n];
|
||||
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
|
||||
|
||||
for (i, m) in migrations.iter().enumerate() {
|
||||
for dep in &m.depends_on {
|
||||
if let Some(&j) = table_to_idx.get(dep) {
|
||||
adj[j].push(i);
|
||||
in_degree[i] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut queue: VecDeque<usize> =
|
||||
(0..n).filter(|&i| in_degree[i] == 0).collect();
|
||||
let mut order = Vec::with_capacity(n);
|
||||
while let Some(i) = queue.pop_front() {
|
||||
order.push(i);
|
||||
for &next in &adj[i] {
|
||||
in_degree[next] -= 1;
|
||||
if in_degree[next] == 0 {
|
||||
queue.push_back(next);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if order.len() != n {
|
||||
bail!("Circular dependency detected among migrations");
|
||||
}
|
||||
|
||||
let original: Vec<Migration> = migrations.iter().cloned().collect();
|
||||
for (slot, &idx) in order.iter().enumerate() {
|
||||
migrations[slot] = original[idx].clone();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn exec_sql(pool: &PgPool, sql: &str) -> Result<()> {
|
||||
let s: &'static str = Box::leak(sql.to_owned().into_boxed_str());
|
||||
sqlx::raw_sql(s).execute(pool).await?;
|
||||
Ok(())
|
||||
}
|
||||
2
lib/migrate/sql/room/room_attachment_down_02.sql
Normal file
2
lib/migrate/sql/room/room_attachment_down_02.sql
Normal file
@ -0,0 +1,2 @@
|
||||
ALTER TABLE room_attachment ALTER COLUMN message SET NOT NULL;
|
||||
ALTER TABLE room_attachment ALTER COLUMN seq SET NOT NULL;
|
||||
2
lib/migrate/sql/room/room_attachment_up_02.sql
Normal file
2
lib/migrate/sql/room/room_attachment_up_02.sql
Normal file
@ -0,0 +1,2 @@
|
||||
ALTER TABLE room_attachment ALTER COLUMN message DROP NOT NULL;
|
||||
ALTER TABLE room_attachment ALTER COLUMN seq DROP NOT NULL;
|
||||
1
lib/migrate/sql/room/room_mention_down_02.sql
Normal file
1
lib/migrate/sql/room/room_mention_down_02.sql
Normal file
@ -0,0 +1 @@
|
||||
ALTER TABLE room_mention ALTER COLUMN target_id TYPE UUID USING target_id::UUID;
|
||||
1
lib/migrate/sql/room/room_mention_up_02.sql
Normal file
1
lib/migrate/sql/room/room_mention_up_02.sql
Normal file
@ -0,0 +1 @@
|
||||
ALTER TABLE room_mention ALTER COLUMN target_id TYPE TEXT;
|
||||
@ -1,59 +1,12 @@
|
||||
use anyhow::{Context, Result, bail};
|
||||
use clap::{Parser, Subcommand};
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use std::collections::{BTreeMap, HashMap, VecDeque};
|
||||
use std::path::{Path, PathBuf};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "migrate", about = "Database migration tool")]
|
||||
struct Cli {
|
||||
#[arg(short, long)]
|
||||
#[arg(short, long, env = "DATABASE_URL")]
|
||||
database_url: String,
|
||||
|
||||
#[command(subcommand)]
|
||||
command: Command,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Command {
|
||||
Up,
|
||||
Down,
|
||||
Fresh,
|
||||
List,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct Migration {
|
||||
domain: String,
|
||||
table: String,
|
||||
version: u32,
|
||||
direction: MigrationDir,
|
||||
path: PathBuf,
|
||||
depends_on: Vec<String>,
|
||||
}
|
||||
|
||||
impl Ord for Migration {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
(&self.domain, &self.table, self.version, &self.direction).cmp(&(
|
||||
&other.domain,
|
||||
&other.table,
|
||||
other.version,
|
||||
&other.direction,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for Migration {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
enum MigrationDir {
|
||||
Up,
|
||||
Down,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@ -64,358 +17,10 @@ async fn main() -> Result<()> {
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
let database_url = std::env::var("DATABASE_URL")
|
||||
.context("DATABASE_URL must be set or provided via --database-url")?;
|
||||
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(1)
|
||||
.connect(&database_url)
|
||||
.await
|
||||
.context("Failed to connect to database")?;
|
||||
|
||||
let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql");
|
||||
|
||||
match cli.command {
|
||||
Command::Up => run_up(&pool, &sql_root).await,
|
||||
Command::Down => run_down(&pool, &sql_root).await,
|
||||
Command::Fresh => run_fresh(&pool, &sql_root).await,
|
||||
Command::List => run_list(&pool, &sql_root).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn discover_migrations(sql_root: &Path) -> Result<Vec<Migration>> {
|
||||
let mut migrations = Vec::new();
|
||||
|
||||
if !sql_root.exists() {
|
||||
bail!("SQL directory not found: {}", sql_root.display());
|
||||
}
|
||||
|
||||
for dir_entry in std::fs::read_dir(sql_root)? {
|
||||
let dir = dir_entry?;
|
||||
if !dir.file_type()?.is_dir() {
|
||||
continue;
|
||||
}
|
||||
let domain = dir.file_name().to_string_lossy().to_string();
|
||||
|
||||
for file_entry in std::fs::read_dir(dir.path())? {
|
||||
let file = file_entry?;
|
||||
let path = file.path();
|
||||
if path.extension().and_then(|e| e.to_str()) != Some("sql") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let stem = path
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.context("Invalid filename")?;
|
||||
|
||||
let (table, direction, version) = parse_migration_stem(stem)?;
|
||||
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.context(format!("Failed to read {path:?}"))?;
|
||||
let depends_on = parse_depends_on(&content);
|
||||
|
||||
migrations.push(Migration {
|
||||
domain: domain.clone(),
|
||||
table,
|
||||
version,
|
||||
direction,
|
||||
path,
|
||||
depends_on,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
migrations.sort();
|
||||
Ok(migrations)
|
||||
}
|
||||
|
||||
fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> {
|
||||
if let Some(pos) = stem.rfind("_up_") {
|
||||
let table = stem[..pos].to_string();
|
||||
let ver_str = &stem[pos + 4..];
|
||||
let version =
|
||||
ver_str.parse::<u32>().context("Invalid version number")?;
|
||||
Ok((table, MigrationDir::Up, version))
|
||||
} else if let Some(pos) = stem.rfind("_down_") {
|
||||
let table = stem[..pos].to_string();
|
||||
let ver_str = &stem[pos + 6..];
|
||||
let version =
|
||||
ver_str.parse::<u32>().context("Invalid version number")?;
|
||||
Ok((table, MigrationDir::Down, version))
|
||||
} else {
|
||||
bail!("Migration filename must contain _up_ or _down_: {stem}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn ensure_migrations_table(pool: &sqlx::PgPool) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS _sql_migrations (
|
||||
domain TEXT NOT NULL,
|
||||
table_name TEXT NOT NULL,
|
||||
version INTEGER NOT NULL,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
checksum TEXT NOT NULL DEFAULT '',
|
||||
PRIMARY KEY (domain, table_name, version)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn applied_set(
|
||||
pool: &sqlx::PgPool,
|
||||
) -> Result<BTreeMap<(String, String, u32), String>> {
|
||||
let rows: Vec<(String, String, i32, String)> =
|
||||
sqlx::query_as("SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version")
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|(d, t, v, c)| ((d, t, v as u32), c))
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn record_migration(
|
||||
pool: &sqlx::PgPool,
|
||||
m: &Migration,
|
||||
checksum: &str,
|
||||
) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO _sql_migrations (domain, table_name, version, checksum)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT DO NOTHING
|
||||
"#,
|
||||
)
|
||||
.bind(&m.domain)
|
||||
.bind(&m.table)
|
||||
.bind(m.version as i32)
|
||||
.bind(checksum)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_migration(pool: &sqlx::PgPool, m: &Migration) -> Result<()> {
|
||||
sqlx::query(
|
||||
"DELETE FROM _sql_migrations WHERE domain = $1 AND table_name = $2 AND version = $3",
|
||||
)
|
||||
.bind(&m.domain)
|
||||
.bind(&m.table)
|
||||
.bind(m.version as i32)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compute_checksum(content: &str) -> String {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = DefaultHasher::new();
|
||||
content.hash(&mut hasher);
|
||||
format!("{:x}", hasher.finish())
|
||||
}
|
||||
|
||||
fn parse_depends_on(content: &str) -> Vec<String> {
|
||||
content
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let line = line.trim();
|
||||
line.strip_prefix("-- depends_on:").map(|deps| {
|
||||
deps.split(',')
|
||||
.map(|d| d.trim().to_string())
|
||||
.filter(|d| !d.is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
})
|
||||
.flatten()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn topo_sort(migrations: &mut [Migration]) -> Result<()> {
|
||||
let table_to_idx: HashMap<String, usize> = migrations
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, m)| (m.table.clone(), i))
|
||||
.collect();
|
||||
|
||||
let n = migrations.len();
|
||||
let mut in_degree = vec![0u32; n];
|
||||
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
|
||||
|
||||
for (i, m) in migrations.iter().enumerate() {
|
||||
for dep in &m.depends_on {
|
||||
if let Some(&j) = table_to_idx.get(dep) {
|
||||
adj[j].push(i);
|
||||
in_degree[i] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut queue: VecDeque<usize> =
|
||||
(0..n).filter(|&i| in_degree[i] == 0).collect();
|
||||
|
||||
let mut order = Vec::with_capacity(n);
|
||||
while let Some(i) = queue.pop_front() {
|
||||
order.push(i);
|
||||
for &next in &adj[i] {
|
||||
in_degree[next] -= 1;
|
||||
if in_degree[next] == 0 {
|
||||
queue.push_back(next);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if order.len() != n {
|
||||
bail!("Circular dependency detected among migrations");
|
||||
}
|
||||
|
||||
let original: Vec<Migration> = migrations.iter().cloned().collect();
|
||||
for (slot, &idx) in order.iter().enumerate() {
|
||||
migrations[slot] = original[idx].clone();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
fn into_static(s: String) -> &'static str {
|
||||
Box::leak(s.into_boxed_str())
|
||||
}
|
||||
|
||||
async fn exec_sql(pool: &sqlx::PgPool, sql: &str) -> Result<()> {
|
||||
sqlx::raw_sql(into_static(sql.to_owned()))
|
||||
.execute(pool)
|
||||
.connect(&cli.database_url)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_up(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
|
||||
ensure_migrations_table(pool).await?;
|
||||
let all = discover_migrations(sql_root)?;
|
||||
let applied = applied_set(pool).await?;
|
||||
|
||||
let mut up_migrations: Vec<_> = all
|
||||
.into_iter()
|
||||
.filter(|m| m.direction == MigrationDir::Up)
|
||||
.filter(|m| {
|
||||
!applied.contains_key(&(
|
||||
m.domain.clone(),
|
||||
m.table.clone(),
|
||||
m.version,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
if up_migrations.is_empty() {
|
||||
info!("All migrations are already applied.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
topo_sort(&mut up_migrations)?;
|
||||
|
||||
for m in &up_migrations {
|
||||
let sql = std::fs::read_to_string(&m.path)
|
||||
.context(format!("Failed to read {:?}", m.path))?;
|
||||
let checksum = compute_checksum(&sql);
|
||||
|
||||
info!("Applying {}/{}/v{}", m.domain, m.table, m.version);
|
||||
exec_sql(pool, &sql).await?;
|
||||
record_migration(pool, m, &checksum).await?;
|
||||
}
|
||||
|
||||
info!("Applied {} migration(s).", up_migrations.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_down(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
|
||||
ensure_migrations_table(pool).await?;
|
||||
let all = discover_migrations(sql_root)?;
|
||||
let applied = applied_set(pool).await?;
|
||||
|
||||
let mut down_targets: Vec<_> = all
|
||||
.into_iter()
|
||||
.filter(|m| m.direction == MigrationDir::Down)
|
||||
.filter(|m| {
|
||||
applied.contains_key(&(
|
||||
m.domain.clone(),
|
||||
m.table.clone(),
|
||||
m.version,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
down_targets.sort();
|
||||
|
||||
if down_targets.is_empty() {
|
||||
info!("No migrations to roll back.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let m = &down_targets[down_targets.len() - 1];
|
||||
let sql = std::fs::read_to_string(&m.path)?;
|
||||
|
||||
info!("Rolling back {}/{}/v{}", m.domain, m.table, m.version);
|
||||
exec_sql(pool, &sql).await?;
|
||||
delete_migration(pool, m).await?;
|
||||
|
||||
info!("Rolled back 1 migration.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_fresh(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
|
||||
info!("Dropping all tables and re-applying migrations...");
|
||||
|
||||
exec_sql(pool, "DROP TABLE IF EXISTS _sql_migrations CASCADE").await?;
|
||||
|
||||
let all = discover_migrations(sql_root)?;
|
||||
let down_migrations: Vec<_> = all
|
||||
.into_iter()
|
||||
.filter(|m| m.direction == MigrationDir::Down)
|
||||
.collect();
|
||||
|
||||
let mut drops: Vec<_> = down_migrations.iter().collect();
|
||||
drops.sort();
|
||||
drops.reverse();
|
||||
|
||||
for m in &drops {
|
||||
let sql = std::fs::read_to_string(&m.path)?;
|
||||
let _ = exec_sql(pool, &sql).await;
|
||||
}
|
||||
|
||||
run_up(pool, sql_root).await
|
||||
}
|
||||
|
||||
async fn run_list(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
|
||||
ensure_migrations_table(pool).await?;
|
||||
let all = discover_migrations(sql_root)?;
|
||||
let applied = applied_set(pool).await?;
|
||||
|
||||
let up_migrations: Vec<_> = all
|
||||
.into_iter()
|
||||
.filter(|m| m.direction == MigrationDir::Up)
|
||||
.collect();
|
||||
|
||||
println!(
|
||||
"{:<20} {:<30} {:>8} {}",
|
||||
"Domain", "Table", "Version", "Status"
|
||||
);
|
||||
println!("{:-<20} {:-<30} {:-<8} {:-<10}", "", "", "", "");
|
||||
|
||||
for m in &up_migrations {
|
||||
let key = (m.domain.clone(), m.table.clone(), m.version);
|
||||
let status = if applied.contains_key(&key) {
|
||||
"Applied"
|
||||
} else {
|
||||
"Pending"
|
||||
};
|
||||
println!(
|
||||
"{:<20} {:<30} {:>8} {}",
|
||||
m.domain, m.table, m.version, status
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
migrate::run_up(&pool).await
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ use async_nats::{HeaderMap, jetstream};
|
||||
use config::AppConfig;
|
||||
use futures_util::StreamExt;
|
||||
use tracing::{error, info, warn};
|
||||
use track::CounterVec;
|
||||
|
||||
use crate::{
|
||||
handler::{AckAction, MessageHandler},
|
||||
@ -16,6 +17,7 @@ pub struct NatsConsumer {
|
||||
max_deliver: i64,
|
||||
retry_delay_secs: u64,
|
||||
durable_name: String,
|
||||
metrics: Option<track::MetricsRegistry>,
|
||||
}
|
||||
|
||||
impl NatsConsumer {
|
||||
@ -33,9 +35,15 @@ impl NatsConsumer {
|
||||
max_deliver: config.nats_max_deliver(),
|
||||
retry_delay_secs: config.nats_retry_delay_secs(),
|
||||
durable_name: durable_name(group_id),
|
||||
metrics: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
|
||||
self.producer.set_metrics(registry.clone());
|
||||
self.metrics = Some(registry);
|
||||
}
|
||||
|
||||
pub async fn start_consuming<H>(
|
||||
&self,
|
||||
topics: &[&str],
|
||||
@ -61,9 +69,10 @@ impl NatsConsumer {
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!("NATS consumer started subscribing to: {:?}", topics_owned);
|
||||
info!(topics = ?topics_owned, durable = %self.durable_name, "NATS consumer started");
|
||||
|
||||
let producer = self.producer.clone();
|
||||
let metrics = self.metrics.clone();
|
||||
let max_deliver = self.max_deliver;
|
||||
let retry_delay_secs = self.retry_delay_secs;
|
||||
let handler = Arc::new(handler);
|
||||
@ -73,10 +82,7 @@ impl NatsConsumer {
|
||||
let mut messages = match messages {
|
||||
Ok(messages) => messages,
|
||||
Err(error) => {
|
||||
error!(
|
||||
"NATS error while opening consumer stream: {:?}",
|
||||
error
|
||||
);
|
||||
error!(error = %error, "NATS error while opening consumer stream");
|
||||
return;
|
||||
}
|
||||
};
|
||||
@ -86,6 +92,7 @@ impl NatsConsumer {
|
||||
Ok(message) => {
|
||||
handle_message(
|
||||
&producer,
|
||||
metrics.as_ref(),
|
||||
max_deliver,
|
||||
retry_delay_secs,
|
||||
handler.as_ref(),
|
||||
@ -94,7 +101,7 @@ impl NatsConsumer {
|
||||
.await;
|
||||
}
|
||||
Err(error) => {
|
||||
error!("NATS error while consuming: {:?}", error);
|
||||
error!(error = %error, "NATS error while consuming");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -106,6 +113,7 @@ impl NatsConsumer {
|
||||
|
||||
async fn handle_message<H>(
|
||||
producer: &NatsProducer,
|
||||
metrics: Option<&track::MetricsRegistry>,
|
||||
max_deliver: i64,
|
||||
retry_delay_secs: u64,
|
||||
handler: &H,
|
||||
@ -116,12 +124,17 @@ async fn handle_message<H>(
|
||||
let subject = message.subject.to_string();
|
||||
let payload = message.payload.clone();
|
||||
let delivered = message.info().map(|info| info.delivered).unwrap_or(1);
|
||||
record_queue_message(metrics, &subject, "received");
|
||||
|
||||
match handler.handle(&subject, &payload).await {
|
||||
AckAction::Ack => ack_message(&message, &subject, "message").await,
|
||||
AckAction::Ack => {
|
||||
ack_message(metrics, &message, &subject, "message").await
|
||||
}
|
||||
AckAction::Nack => {
|
||||
record_queue_message(metrics, &subject, "nack");
|
||||
if let Err(error) = handle_nack(
|
||||
producer,
|
||||
metrics,
|
||||
&message,
|
||||
&subject,
|
||||
&payload,
|
||||
@ -131,9 +144,13 @@ async fn handle_message<H>(
|
||||
)
|
||||
.await
|
||||
{
|
||||
record_queue_message(metrics, &subject, "error");
|
||||
error!(
|
||||
"Failed to route NACKed message from subject {}: {:?}",
|
||||
subject, error
|
||||
subject = %subject,
|
||||
delivered,
|
||||
max_deliver,
|
||||
error = %error,
|
||||
"failed to route NACKed message"
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -142,6 +159,7 @@ async fn handle_message<H>(
|
||||
|
||||
async fn handle_nack(
|
||||
producer: &NatsProducer,
|
||||
metrics: Option<&track::MetricsRegistry>,
|
||||
message: &jetstream::Message,
|
||||
subject: &str,
|
||||
payload: &[u8],
|
||||
@ -151,8 +169,11 @@ async fn handle_nack(
|
||||
) -> anyhow::Result<()> {
|
||||
if delivered < max_deliver {
|
||||
warn!(
|
||||
"Message in subject {} failed (NACK). Retrying delivery {}/{} in {} seconds",
|
||||
subject, delivered, max_deliver, retry_delay_secs
|
||||
subject,
|
||||
delivered,
|
||||
max_deliver,
|
||||
retry_delay_secs,
|
||||
"message NACKed, scheduling retry"
|
||||
);
|
||||
message
|
||||
.ack_with(jetstream::AckKind::Nak(Some(Duration::from_secs(
|
||||
@ -162,13 +183,17 @@ async fn handle_nack(
|
||||
.map_err(|error| {
|
||||
anyhow::anyhow!("failed to nack message: {error}")
|
||||
})?;
|
||||
record_queue_message(metrics, subject, "retry");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let dlq_subject = format!("{subject}.dlq");
|
||||
error!(
|
||||
"Message in subject {} exceeded max deliver attempts ({}). Routing to DLQ: {}",
|
||||
subject, max_deliver, dlq_subject
|
||||
subject,
|
||||
dlq_subject = %dlq_subject,
|
||||
delivered,
|
||||
max_deliver,
|
||||
"message exceeded max deliver attempts, routing to DLQ"
|
||||
);
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
@ -185,22 +210,69 @@ async fn handle_nack(
|
||||
message.ack().await.map_err(|error| {
|
||||
anyhow::anyhow!("failed to ack DLQ message: {error}")
|
||||
})?;
|
||||
record_queue_message(metrics, subject, "dlq");
|
||||
record_queue_dlq(metrics, subject);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ack_message(
|
||||
metrics: Option<&track::MetricsRegistry>,
|
||||
message: &jetstream::Message,
|
||||
subject: &str,
|
||||
description: &str,
|
||||
) {
|
||||
if let Err(error) = message.ack().await {
|
||||
error!(
|
||||
"Failed to ack {} in subject {}: {:?}",
|
||||
description, subject, error
|
||||
);
|
||||
match message.ack().await {
|
||||
Ok(()) => record_queue_message(metrics, subject, "ack"),
|
||||
Err(error) => {
|
||||
record_queue_message(metrics, subject, "ack_error");
|
||||
error!(
|
||||
subject,
|
||||
description,
|
||||
error = %error,
|
||||
"failed to ack message"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn record_queue_message(
|
||||
metrics: Option<&track::MetricsRegistry>,
|
||||
topic: &str,
|
||||
status: &str,
|
||||
) {
|
||||
if let Some(metrics) = metrics {
|
||||
queue_messages_vec(metrics)
|
||||
.with_label_values(&[topic, status])
|
||||
.inc();
|
||||
}
|
||||
}
|
||||
|
||||
fn record_queue_dlq(metrics: Option<&track::MetricsRegistry>, topic: &str) {
|
||||
if let Some(metrics) = metrics {
|
||||
queue_dlq_vec(metrics).with_label_values(&[topic]).inc();
|
||||
}
|
||||
}
|
||||
|
||||
fn queue_messages_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||
registry
|
||||
.register_counter_vec(
|
||||
"queue_messages_total",
|
||||
"Total queue messages",
|
||||
&["topic", "status"],
|
||||
)
|
||||
.expect("failed to register queue_messages_total")
|
||||
}
|
||||
|
||||
fn queue_dlq_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||
registry
|
||||
.register_counter_vec(
|
||||
"queue_dlq_total",
|
||||
"Total messages routed to DLQ",
|
||||
&["topic"],
|
||||
)
|
||||
.expect("failed to register queue_dlq_total")
|
||||
}
|
||||
|
||||
fn durable_name(name: &str) -> String {
|
||||
name.replace('.', "-")
|
||||
}
|
||||
|
||||
@ -3,10 +3,12 @@ use std::time::Duration;
|
||||
use async_nats::{HeaderMap, jetstream};
|
||||
use config::AppConfig;
|
||||
use serde::Serialize;
|
||||
use track::CounterVec;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NatsProducer {
|
||||
jetstream: jetstream::Context,
|
||||
metrics: Option<track::MetricsRegistry>,
|
||||
}
|
||||
|
||||
impl NatsProducer {
|
||||
@ -14,7 +16,14 @@ impl NatsProducer {
|
||||
let jetstream = connect_jetstream(config).await?;
|
||||
ensure_stream(config, &jetstream).await?;
|
||||
|
||||
Ok(Self { jetstream })
|
||||
Ok(Self {
|
||||
jetstream,
|
||||
metrics: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
|
||||
self.metrics = Some(registry);
|
||||
}
|
||||
|
||||
pub async fn send<T>(
|
||||
@ -44,19 +53,37 @@ impl NatsProducer {
|
||||
}
|
||||
|
||||
let subject = subject.to_string();
|
||||
let publish = if headers.is_empty() {
|
||||
self.jetstream
|
||||
.publish(subject.clone(), payload.to_vec().into())
|
||||
.await?
|
||||
} else {
|
||||
self.jetstream
|
||||
.publish_with_headers(subject, headers, payload.to_vec().into())
|
||||
.await?
|
||||
};
|
||||
let publish_result: anyhow::Result<()> = async {
|
||||
let publish = if headers.is_empty() {
|
||||
self.jetstream
|
||||
.publish(subject.clone(), payload.to_vec().into())
|
||||
.await?
|
||||
} else {
|
||||
self.jetstream
|
||||
.publish_with_headers(
|
||||
subject.clone(),
|
||||
headers,
|
||||
payload.to_vec().into(),
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(5), publish).await??;
|
||||
tokio::time::timeout(Duration::from_secs(5), publish).await??;
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
self.record_published(&subject, publish_result.is_ok());
|
||||
publish_result
|
||||
}
|
||||
|
||||
fn record_published(&self, topic: &str, success: bool) {
|
||||
if let Some(reg) = &self.metrics {
|
||||
let status = if success { "published" } else { "error" };
|
||||
queue_messages_vec(reg)
|
||||
.with_label_values(&[topic, status])
|
||||
.inc();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -88,3 +115,13 @@ pub async fn ensure_stream(
|
||||
})
|
||||
.await?)
|
||||
}
|
||||
|
||||
fn queue_messages_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||
registry
|
||||
.register_counter_vec(
|
||||
"queue_messages_total",
|
||||
"Total queue messages",
|
||||
&["topic", "status"],
|
||||
)
|
||||
.expect("failed to register queue_messages_total")
|
||||
}
|
||||
|
||||
@ -10,6 +10,7 @@ use aws_sdk_s3::primitives::ByteStreamError;
|
||||
pub use error::{StorageError, StorageResult};
|
||||
pub use local::{LocalStorage, LocalStorageConfig};
|
||||
pub use s3::{S3Storage, S3StorageConfig};
|
||||
use track::CounterVec;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum AppStorageConfig {
|
||||
@ -18,11 +19,60 @@ pub enum AppStorageConfig {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum AppStorage {
|
||||
pub struct AppStorage {
|
||||
inner: StorageBackend,
|
||||
metrics: Option<track::MetricsRegistry>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum StorageBackend {
|
||||
Local(LocalStorage),
|
||||
S3(S3Storage),
|
||||
}
|
||||
|
||||
impl AppStorage {
|
||||
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
|
||||
self.metrics = Some(registry);
|
||||
}
|
||||
|
||||
fn backend_name(&self) -> &str {
|
||||
match &self.inner {
|
||||
StorageBackend::Local(_) => "local",
|
||||
StorageBackend::S3(_) => "s3",
|
||||
}
|
||||
}
|
||||
|
||||
fn record_upload(&self, bytes: usize) {
|
||||
if let Some(reg) = &self.metrics {
|
||||
storage_ops_vec(reg)
|
||||
.with_label_values(&["upload", self.backend_name()])
|
||||
.inc();
|
||||
storage_bytes_vec(reg)
|
||||
.with_label_values(&["upload"])
|
||||
.inc_by(bytes as f64);
|
||||
}
|
||||
}
|
||||
|
||||
fn record_download(&self, bytes: usize) {
|
||||
if let Some(reg) = &self.metrics {
|
||||
storage_ops_vec(reg)
|
||||
.with_label_values(&["download", self.backend_name()])
|
||||
.inc();
|
||||
storage_bytes_vec(reg)
|
||||
.with_label_values(&["download"])
|
||||
.inc_by(bytes as f64);
|
||||
}
|
||||
}
|
||||
|
||||
fn record_delete(&self) {
|
||||
if let Some(reg) = &self.metrics {
|
||||
storage_ops_vec(reg)
|
||||
.with_label_values(&["delete", self.backend_name()])
|
||||
.inc();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct PutObjectOptions {
|
||||
pub content_type: Option<String>,
|
||||
@ -87,76 +137,109 @@ pub trait ObjectStorage: Send + Sync {
|
||||
}
|
||||
|
||||
impl AppStorage {
|
||||
#[tracing::instrument(skip(config))]
|
||||
pub async fn init(config: AppStorageConfig) -> StorageResult<Self> {
|
||||
match config {
|
||||
let inner = match config {
|
||||
AppStorageConfig::Local(config) => {
|
||||
Ok(Self::Local(LocalStorage::connect(config).await?))
|
||||
tracing::info!("initializing local storage");
|
||||
StorageBackend::Local(LocalStorage::connect(config).await?)
|
||||
}
|
||||
AppStorageConfig::S3(config) => {
|
||||
Ok(Self::S3(S3Storage::connect(config).await?))
|
||||
tracing::info!(bucket = %config.bucket, region = %config.region, "initializing S3 storage");
|
||||
StorageBackend::S3(S3Storage::connect(config).await?)
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
inner,
|
||||
metrics: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ObjectStorage for AppStorage {
|
||||
#[tracing::instrument(skip(self, body), fields(storage.key = %key))]
|
||||
async fn put_stream(
|
||||
&self,
|
||||
key: &str,
|
||||
body: ByteStream,
|
||||
options: PutObjectOptions,
|
||||
) -> StorageResult<StoredObject> {
|
||||
match self {
|
||||
Self::Local(storage) => {
|
||||
let result = match &self.inner {
|
||||
StorageBackend::Local(storage) => {
|
||||
storage.put_stream(key, body, options).await
|
||||
}
|
||||
Self::S3(storage) => storage.put_stream(key, body, options).await,
|
||||
StorageBackend::S3(storage) => {
|
||||
storage.put_stream(key, body, options).await
|
||||
}
|
||||
};
|
||||
if result.is_ok() {
|
||||
self.record_upload(0);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, bytes), fields(storage.key = %key, storage.size = bytes.len()))]
|
||||
async fn put_bytes(
|
||||
&self,
|
||||
key: &str,
|
||||
bytes: Vec<u8>,
|
||||
options: PutObjectOptions,
|
||||
) -> StorageResult<StoredObject> {
|
||||
match self {
|
||||
Self::Local(storage) => {
|
||||
let size = bytes.len();
|
||||
let result = match &self.inner {
|
||||
StorageBackend::Local(storage) => {
|
||||
storage.put_bytes(key, bytes, options).await
|
||||
}
|
||||
Self::S3(storage) => storage.put_bytes(key, bytes, options).await,
|
||||
StorageBackend::S3(storage) => {
|
||||
storage.put_bytes(key, bytes, options).await
|
||||
}
|
||||
};
|
||||
if result.is_ok() {
|
||||
self.record_upload(size);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), fields(storage.key = %key))]
|
||||
async fn get_stream(
|
||||
&self,
|
||||
key: &str,
|
||||
) -> StorageResult<StorageObjectStream> {
|
||||
match self {
|
||||
Self::Local(storage) => storage.get_stream(key).await,
|
||||
Self::S3(storage) => storage.get_stream(key).await,
|
||||
match &self.inner {
|
||||
StorageBackend::Local(storage) => storage.get_stream(key).await,
|
||||
StorageBackend::S3(storage) => storage.get_stream(key).await,
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), fields(storage.key = %key))]
|
||||
async fn get_bytes(&self, key: &str) -> StorageResult<StorageObject> {
|
||||
match self {
|
||||
Self::Local(storage) => storage.get_bytes(key).await,
|
||||
Self::S3(storage) => storage.get_bytes(key).await,
|
||||
let result = match &self.inner {
|
||||
StorageBackend::Local(storage) => storage.get_bytes(key).await,
|
||||
StorageBackend::S3(storage) => storage.get_bytes(key).await,
|
||||
};
|
||||
if let Ok(obj) = &result {
|
||||
self.record_download(obj.bytes.len());
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), fields(storage.key = %key))]
|
||||
async fn delete(&self, key: &str) -> StorageResult<()> {
|
||||
match self {
|
||||
Self::Local(storage) => storage.delete(key).await,
|
||||
Self::S3(storage) => storage.delete(key).await,
|
||||
let result = match &self.inner {
|
||||
StorageBackend::Local(storage) => storage.delete(key).await,
|
||||
StorageBackend::S3(storage) => storage.delete(key).await,
|
||||
};
|
||||
if result.is_ok() {
|
||||
self.record_delete();
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn public_url(&self, key: &str) -> StorageResult<Option<String>> {
|
||||
match self {
|
||||
Self::Local(storage) => storage.public_url(key),
|
||||
Self::S3(storage) => storage.public_url(key),
|
||||
match &self.inner {
|
||||
StorageBackend::Local(storage) => storage.public_url(key),
|
||||
StorageBackend::S3(storage) => storage.public_url(key),
|
||||
}
|
||||
}
|
||||
|
||||
@ -165,17 +248,37 @@ impl ObjectStorage for AppStorage {
|
||||
key: &str,
|
||||
expires_in: Duration,
|
||||
) -> StorageResult<String> {
|
||||
match self {
|
||||
Self::Local(storage) => {
|
||||
match &self.inner {
|
||||
StorageBackend::Local(storage) => {
|
||||
storage.presigned_get_url(key, expires_in).await
|
||||
}
|
||||
Self::S3(storage) => {
|
||||
StorageBackend::S3(storage) => {
|
||||
storage.presigned_get_url(key, expires_in).await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn storage_ops_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||
registry
|
||||
.register_counter_vec(
|
||||
"storage_operations_total",
|
||||
"Total storage operations",
|
||||
&["operation", "backend"],
|
||||
)
|
||||
.expect("failed to register storage_operations_total")
|
||||
}
|
||||
|
||||
fn storage_bytes_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||
registry
|
||||
.register_counter_vec(
|
||||
"storage_bytes_total",
|
||||
"Total bytes transferred",
|
||||
&["operation"],
|
||||
)
|
||||
.expect("failed to register storage_bytes_total")
|
||||
}
|
||||
|
||||
pub async fn collect_byte_stream(
|
||||
body: ByteStream,
|
||||
) -> Result<Vec<u8>, ByteStreamError> {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user