refactor: update infrastructure libs (config, db, cache, queue, storage, migrate)
This commit is contained in:
parent
e44f3d13c4
commit
734e1c4cc8
319
lib/cache/app.rs
vendored
Normal file
319
lib/cache/app.rs
vendored
Normal file
@ -0,0 +1,319 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use track::CounterVec;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
cluster::{ClusterCache, ClusterCacheConfig},
|
||||||
|
error::{CacheError, CacheResult},
|
||||||
|
local::{LocalCacheConfig, MokaCache},
|
||||||
|
};
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Configuration
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct AppCacheConfig {
|
||||||
|
pub local: LocalCacheConfig,
|
||||||
|
pub cluster: Option<ClusterCacheConfig>,
|
||||||
|
pub default_ttl: Option<Duration>,
|
||||||
|
pub cluster_write_through: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AppCacheConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
local: LocalCacheConfig::default(),
|
||||||
|
cluster: None,
|
||||||
|
default_ttl: Some(Duration::from_secs(300)),
|
||||||
|
cluster_write_through: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<&config::AppConfig> for AppCacheConfig {
|
||||||
|
type Error = CacheError;
|
||||||
|
|
||||||
|
fn try_from(config: &config::AppConfig) -> Result<Self, Self::Error> {
|
||||||
|
let local = LocalCacheConfig {
|
||||||
|
max_capacity: config
|
||||||
|
.cache_local_max_capacity()
|
||||||
|
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||||
|
time_to_live: config
|
||||||
|
.cache_local_ttl()
|
||||||
|
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||||
|
time_to_idle: config
|
||||||
|
.cache_local_tti()
|
||||||
|
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let cluster = if config
|
||||||
|
.cache_cluster_enabled()
|
||||||
|
.map_err(|error| CacheError::Config(error.to_string()))?
|
||||||
|
{
|
||||||
|
Some(ClusterCacheConfig {
|
||||||
|
urls: config
|
||||||
|
.redis_urls()
|
||||||
|
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||||
|
key_prefix: config.cache_cluster_key_prefix(),
|
||||||
|
command_timeout: config
|
||||||
|
.cache_cluster_command_timeout()
|
||||||
|
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
local,
|
||||||
|
cluster,
|
||||||
|
default_ttl: config
|
||||||
|
.cache_default_ttl()
|
||||||
|
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||||
|
cluster_write_through: config
|
||||||
|
.cache_cluster_write_through()
|
||||||
|
.map_err(|error| CacheError::Config(error.to_string()))?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// AppCache
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppCache {
|
||||||
|
pub local: MokaCache,
|
||||||
|
pub cluster: Option<ClusterCache>,
|
||||||
|
default_ttl: Option<Duration>,
|
||||||
|
cluster_write_through: bool,
|
||||||
|
metrics: Option<track::MetricsRegistry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppCache {
|
||||||
|
#[tracing::instrument(skip(config))]
|
||||||
|
pub async fn init(config: AppCacheConfig) -> CacheResult<Self> {
|
||||||
|
let local = MokaCache::with_config(config.local);
|
||||||
|
let cluster = match config.cluster {
|
||||||
|
Some(cluster) => Some(match ClusterCache::connect(cluster).await {
|
||||||
|
Ok(cluster) => cluster,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(error = %e, "failed to connect to cache cluster");
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::info!(has_cluster = cluster.is_some(), "cache initialized");
|
||||||
|
Ok(Self {
|
||||||
|
local,
|
||||||
|
cluster,
|
||||||
|
default_ttl: config.default_ttl,
|
||||||
|
cluster_write_through: config.cluster_write_through,
|
||||||
|
metrics: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn local_only(local: MokaCache) -> Self {
|
||||||
|
Self {
|
||||||
|
local,
|
||||||
|
cluster: None,
|
||||||
|
default_ttl: None,
|
||||||
|
cluster_write_through: false,
|
||||||
|
metrics: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Attach a metrics registry for recording cache counters.
|
||||||
|
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
|
||||||
|
self.metrics = Some(registry);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self), fields(cache.key = %key))]
|
||||||
|
pub async fn get<T>(&self, key: &str) -> CacheResult<Option<T>>
|
||||||
|
where
|
||||||
|
T: serde::Serialize + serde::de::DeserializeOwned,
|
||||||
|
{
|
||||||
|
if let Some(value) = self.local.get(key).await? {
|
||||||
|
tracing::debug!("cache hit (local)");
|
||||||
|
self.record_hit("local");
|
||||||
|
return Ok(Some(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(cluster) = &self.cluster else {
|
||||||
|
tracing::debug!("cache miss");
|
||||||
|
self.record_miss();
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
let value = cluster.get::<T>(key).await?;
|
||||||
|
if let Some(value) = &value {
|
||||||
|
self.local.set(key, value).await?;
|
||||||
|
tracing::debug!("cache hit (cluster)");
|
||||||
|
self.record_hit("cluster");
|
||||||
|
} else {
|
||||||
|
tracing::debug!("cache miss");
|
||||||
|
self.record_miss();
|
||||||
|
}
|
||||||
|
Ok(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self, value), fields(cache.key = %key))]
|
||||||
|
pub async fn set<T>(&self, key: &str, value: &T) -> CacheResult<()>
|
||||||
|
where
|
||||||
|
T: serde::Serialize + ?Sized,
|
||||||
|
{
|
||||||
|
self.local.set(key, value).await?;
|
||||||
|
if self.cluster_write_through
|
||||||
|
&& let Some(cluster) = &self.cluster
|
||||||
|
{
|
||||||
|
cluster.set(key, value, self.default_ttl).await?;
|
||||||
|
}
|
||||||
|
self.record_set();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn set_with_ttl<T>(
|
||||||
|
&self,
|
||||||
|
key: &str,
|
||||||
|
value: &T,
|
||||||
|
ttl: std::time::Duration,
|
||||||
|
) -> CacheResult<()>
|
||||||
|
where
|
||||||
|
T: serde::Serialize + ?Sized,
|
||||||
|
{
|
||||||
|
self.local.set(key, value).await?;
|
||||||
|
if self.cluster_write_through
|
||||||
|
&& let Some(cluster) = &self.cluster
|
||||||
|
{
|
||||||
|
cluster.set(key, value, Some(ttl)).await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self), fields(cache.key = %key))]
|
||||||
|
pub async fn remove(&self, key: &str) -> CacheResult<()> {
|
||||||
|
self.local.remove(key).await;
|
||||||
|
if let Some(cluster) = &self.cluster {
|
||||||
|
cluster.remove(key).await?;
|
||||||
|
}
|
||||||
|
self.record_remove();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_hit(&self, tier: &str) {
|
||||||
|
if let Some(reg) = &self.metrics {
|
||||||
|
cache_hits_vec(reg).with_label_values(&[tier]).inc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_miss(&self) {
|
||||||
|
if let Some(reg) = &self.metrics {
|
||||||
|
cache_misses_vec(reg).with_label_values(&[]).inc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_set(&self) {
|
||||||
|
if let Some(reg) = &self.metrics {
|
||||||
|
cache_sets_vec(reg).with_label_values(&[]).inc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_remove(&self) {
|
||||||
|
if let Some(reg) = &self.metrics {
|
||||||
|
cache_removes_vec(reg).with_label_values(&[]).inc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_pattern(&self, pattern: &str) -> CacheResult<u64> {
|
||||||
|
let pattern = pattern.to_string();
|
||||||
|
let local_pattern = pattern.clone();
|
||||||
|
self.local.invalidate_entries_if(move |key| {
|
||||||
|
simple_glob_match(&local_pattern, key)
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut removed = 0u64;
|
||||||
|
if let Some(cluster) = &self.cluster {
|
||||||
|
removed = cluster.delete_pattern(&pattern).await?;
|
||||||
|
}
|
||||||
|
Ok(removed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn ping_cluster(&self) -> CacheResult<()> {
|
||||||
|
if let Some(cluster) = &self.cluster {
|
||||||
|
cluster.ping().await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn conn(&self) -> Option<redis::cluster_async::ClusterConnection> {
|
||||||
|
self.cluster.as_ref().map(|c| c.conn())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cache_hits_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||||
|
registry
|
||||||
|
.register_counter_vec("cache_hits_total", "Total cache hits", &["tier"])
|
||||||
|
.expect("failed to register cache_hits_total")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cache_misses_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||||
|
registry
|
||||||
|
.register_counter_vec("cache_misses_total", "Total cache misses", &[])
|
||||||
|
.expect("failed to register cache_misses_total")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cache_sets_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||||
|
registry
|
||||||
|
.register_counter_vec(
|
||||||
|
"cache_sets_total",
|
||||||
|
"Total cache set operations",
|
||||||
|
&[],
|
||||||
|
)
|
||||||
|
.expect("failed to register cache_sets_total")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cache_removes_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||||
|
registry
|
||||||
|
.register_counter_vec(
|
||||||
|
"cache_removes_total",
|
||||||
|
"Total cache remove operations",
|
||||||
|
&[],
|
||||||
|
)
|
||||||
|
.expect("failed to register cache_removes_total")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Helpers
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
fn simple_glob_match(pattern: &str, key: &str) -> bool {
|
||||||
|
let p = pattern.as_bytes();
|
||||||
|
let k = key.as_bytes();
|
||||||
|
let (mut pi, mut ki) = (0usize, 0usize);
|
||||||
|
|
||||||
|
let mut backtrack_p: Option<usize> = None;
|
||||||
|
let mut backtrack_k: usize = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if pi < p.len() && ki < k.len() && (p[pi] == b'?' || p[pi] == k[ki]) {
|
||||||
|
pi += 1;
|
||||||
|
ki += 1;
|
||||||
|
} else if pi < p.len() && p[pi] == b'*' {
|
||||||
|
backtrack_p = Some(pi);
|
||||||
|
backtrack_k = ki;
|
||||||
|
pi += 1;
|
||||||
|
} else if let Some(saved_pi) = backtrack_p {
|
||||||
|
backtrack_k += 1;
|
||||||
|
ki = backtrack_k;
|
||||||
|
pi = saved_pi + 1;
|
||||||
|
} else {
|
||||||
|
return pi == p.len() && ki == k.len();
|
||||||
|
}
|
||||||
|
|
||||||
|
if pi == p.len() && ki == k.len() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
220
lib/cache/lib.rs
vendored
220
lib/cache/lib.rs
vendored
@ -1,227 +1,11 @@
|
|||||||
|
pub mod app;
|
||||||
pub mod cluster;
|
pub mod cluster;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod local;
|
pub mod local;
|
||||||
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
pub use crate::{
|
pub use crate::{
|
||||||
|
app::{AppCache, AppCacheConfig},
|
||||||
cluster::{ClusterCache, ClusterCacheConfig},
|
cluster::{ClusterCache, ClusterCacheConfig},
|
||||||
error::{CacheError, CacheResult},
|
error::{CacheError, CacheResult},
|
||||||
local::{LocalCacheConfig, MokaCache},
|
local::{LocalCacheConfig, MokaCache},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct AppCacheConfig {
|
|
||||||
pub local: LocalCacheConfig,
|
|
||||||
pub cluster: Option<ClusterCacheConfig>,
|
|
||||||
pub default_ttl: Option<Duration>,
|
|
||||||
pub cluster_write_through: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for AppCacheConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
local: LocalCacheConfig::default(),
|
|
||||||
cluster: None,
|
|
||||||
default_ttl: Some(Duration::from_secs(300)),
|
|
||||||
cluster_write_through: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct AppCache {
|
|
||||||
pub local: MokaCache,
|
|
||||||
pub cluster: Option<ClusterCache>,
|
|
||||||
default_ttl: Option<Duration>,
|
|
||||||
cluster_write_through: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AppCache {
|
|
||||||
pub async fn init(config: AppCacheConfig) -> CacheResult<Self> {
|
|
||||||
let local = MokaCache::with_config(config.local);
|
|
||||||
let cluster = match config.cluster {
|
|
||||||
Some(cluster) => Some(match ClusterCache::connect(cluster).await {
|
|
||||||
Ok(cluster) => cluster,
|
|
||||||
Err(e) => {
|
|
||||||
println!("cache:init:error with: {}", e);
|
|
||||||
return Err(e);
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
local,
|
|
||||||
cluster,
|
|
||||||
default_ttl: config.default_ttl,
|
|
||||||
cluster_write_through: config.cluster_write_through,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn local_only(local: MokaCache) -> Self {
|
|
||||||
Self {
|
|
||||||
local,
|
|
||||||
cluster: None,
|
|
||||||
default_ttl: None,
|
|
||||||
cluster_write_through: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get<T>(&self, key: &str) -> CacheResult<Option<T>>
|
|
||||||
where
|
|
||||||
T: serde::Serialize + serde::de::DeserializeOwned,
|
|
||||||
{
|
|
||||||
if let Some(value) = self.local.get(key).await? {
|
|
||||||
return Ok(Some(value));
|
|
||||||
}
|
|
||||||
|
|
||||||
let Some(cluster) = &self.cluster else {
|
|
||||||
return Ok(None);
|
|
||||||
};
|
|
||||||
|
|
||||||
let value = cluster.get::<T>(key).await?;
|
|
||||||
if let Some(value) = &value {
|
|
||||||
self.local.set(key, value).await?;
|
|
||||||
}
|
|
||||||
Ok(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn set<T>(&self, key: &str, value: &T) -> CacheResult<()>
|
|
||||||
where
|
|
||||||
T: serde::Serialize + ?Sized,
|
|
||||||
{
|
|
||||||
self.local.set(key, value).await?;
|
|
||||||
if self.cluster_write_through
|
|
||||||
&& let Some(cluster) = &self.cluster
|
|
||||||
{
|
|
||||||
cluster.set(key, value, self.default_ttl).await?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn set_with_ttl<T>(
|
|
||||||
&self,
|
|
||||||
key: &str,
|
|
||||||
value: &T,
|
|
||||||
ttl: std::time::Duration,
|
|
||||||
) -> CacheResult<()>
|
|
||||||
where
|
|
||||||
T: serde::Serialize + ?Sized,
|
|
||||||
{
|
|
||||||
self.local.set(key, value).await?;
|
|
||||||
if self.cluster_write_through
|
|
||||||
&& let Some(cluster) = &self.cluster
|
|
||||||
{
|
|
||||||
cluster.set(key, value, Some(ttl)).await?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn remove(&self, key: &str) -> CacheResult<()> {
|
|
||||||
self.local.remove(key).await;
|
|
||||||
if let Some(cluster) = &self.cluster {
|
|
||||||
cluster.remove(key).await?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
pub async fn delete_pattern(&self, pattern: &str) -> CacheResult<u64> {
|
|
||||||
let pattern = pattern.to_string();
|
|
||||||
let local_pattern = pattern.clone();
|
|
||||||
self.local.invalidate_entries_if(move |key| {
|
|
||||||
simple_glob_match(&local_pattern, key)
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut removed = 0u64;
|
|
||||||
if let Some(cluster) = &self.cluster {
|
|
||||||
removed = cluster.delete_pattern(&pattern).await?;
|
|
||||||
}
|
|
||||||
Ok(removed)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ping_cluster(&self) -> CacheResult<()> {
|
|
||||||
if let Some(cluster) = &self.cluster {
|
|
||||||
cluster.ping().await?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn conn(&self) -> Option<redis::cluster_async::ClusterConnection> {
|
|
||||||
self.cluster.as_ref().map(|c| c.conn())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<&config::AppConfig> for AppCacheConfig {
|
|
||||||
type Error = CacheError;
|
|
||||||
|
|
||||||
fn try_from(config: &config::AppConfig) -> Result<Self, Self::Error> {
|
|
||||||
let local = LocalCacheConfig {
|
|
||||||
max_capacity: config
|
|
||||||
.cache_local_max_capacity()
|
|
||||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
|
||||||
time_to_live: config
|
|
||||||
.cache_local_ttl()
|
|
||||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
|
||||||
time_to_idle: config
|
|
||||||
.cache_local_tti()
|
|
||||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
|
||||||
};
|
|
||||||
|
|
||||||
let cluster = if config
|
|
||||||
.cache_cluster_enabled()
|
|
||||||
.map_err(|error| CacheError::Config(error.to_string()))?
|
|
||||||
{
|
|
||||||
Some(ClusterCacheConfig {
|
|
||||||
urls: config
|
|
||||||
.redis_urls()
|
|
||||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
|
||||||
key_prefix: config.cache_cluster_key_prefix(),
|
|
||||||
command_timeout: config
|
|
||||||
.cache_cluster_command_timeout()
|
|
||||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
local,
|
|
||||||
cluster,
|
|
||||||
default_ttl: config
|
|
||||||
.cache_default_ttl()
|
|
||||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
|
||||||
cluster_write_through: config
|
|
||||||
.cache_cluster_write_through()
|
|
||||||
.map_err(|error| CacheError::Config(error.to_string()))?,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fn simple_glob_match(pattern: &str, key: &str) -> bool {
|
|
||||||
let p = pattern.as_bytes();
|
|
||||||
let k = key.as_bytes();
|
|
||||||
let (mut pi, mut ki) = (0usize, 0usize);
|
|
||||||
|
|
||||||
let mut backtrack_p: Option<usize> = None;
|
|
||||||
let mut backtrack_k: usize = 0;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
if pi < p.len() && ki < k.len() && (p[pi] == b'?' || p[pi] == k[ki]) {
|
|
||||||
pi += 1;
|
|
||||||
ki += 1;
|
|
||||||
} else if pi < p.len() && p[pi] == b'*' {
|
|
||||||
backtrack_p = Some(pi);
|
|
||||||
backtrack_k = ki;
|
|
||||||
pi += 1;
|
|
||||||
} else if let Some(saved_pi) = backtrack_p {
|
|
||||||
backtrack_k += 1;
|
|
||||||
ki = backtrack_k;
|
|
||||||
pi = saved_pi + 1;
|
|
||||||
} else {
|
|
||||||
return pi == p.len() && ki == k.len();
|
|
||||||
}
|
|
||||||
|
|
||||||
if pi == p.len() && ki == k.len() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -28,6 +28,13 @@ impl AppConfig {
|
|||||||
Ok(8080)
|
Ok(8080)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn email_health_port(&self) -> u16 {
|
||||||
|
self.env
|
||||||
|
.get("APP_EMAIL_HEALTH_PORT")
|
||||||
|
.and_then(|port| port.parse::<u16>().ok())
|
||||||
|
.unwrap_or(8083)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn session_secret(&self) -> anyhow::Result<String> {
|
pub fn session_secret(&self) -> anyhow::Result<String> {
|
||||||
if let Some(secret) = self.env.get("APP_SESSION_SECRET") {
|
if let Some(secret) = self.env.get("APP_SESSION_SECRET") {
|
||||||
return Ok(secret.to_string());
|
return Ok(secret.to_string());
|
||||||
|
|||||||
39
lib/config/app_config.rs
Normal file
39
lib/config/app_config.rs
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
use std::{collections::HashMap, sync::OnceLock};
|
||||||
|
|
||||||
|
pub static GLOBAL_CONFIG: OnceLock<AppConfig> = OnceLock::new();
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct AppConfig {
|
||||||
|
pub env: HashMap<String, String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppConfig {
|
||||||
|
const ENV_FILES: &'static [&'static str] = &[".env", ".env.local"];
|
||||||
|
|
||||||
|
pub fn load() -> AppConfig {
|
||||||
|
let mut env = HashMap::new();
|
||||||
|
for env_file in AppConfig::ENV_FILES {
|
||||||
|
if let Err(e) = dotenvy::from_path(env_file) {
|
||||||
|
tracing::debug!(file = %env_file, error = %e, "dotenv load skipped");
|
||||||
|
}
|
||||||
|
if let Ok(env_file_content) = std::fs::read_to_string(env_file) {
|
||||||
|
for line in env_file_content.lines() {
|
||||||
|
if let Some((key, value)) = line.split_once('=') {
|
||||||
|
env.insert(key.to_string(), value.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
env = env.into_iter().chain(std::env::vars()).collect();
|
||||||
|
let this = AppConfig { env };
|
||||||
|
if let Some(config) = GLOBAL_CONFIG.get() {
|
||||||
|
config.clone()
|
||||||
|
} else {
|
||||||
|
let _ = GLOBAL_CONFIG.set(this);
|
||||||
|
GLOBAL_CONFIG
|
||||||
|
.get()
|
||||||
|
.expect("global config should be set after load")
|
||||||
|
.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,44 +1,6 @@
|
|||||||
use std::{collections::HashMap, sync::OnceLock};
|
|
||||||
|
|
||||||
pub static GLOBAL_CONFIG: OnceLock<AppConfig> = OnceLock::new();
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct AppConfig {
|
|
||||||
pub env: HashMap<String, String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AppConfig {
|
|
||||||
const ENV_FILES: &'static [&'static str] = &[".env", ".env.local"];
|
|
||||||
pub fn load() -> AppConfig {
|
|
||||||
let mut env = HashMap::new();
|
|
||||||
for env_file in AppConfig::ENV_FILES {
|
|
||||||
if let Err(e) = dotenvy::from_path(env_file) {
|
|
||||||
tracing::debug!(file = %env_file, error = %e, "dotenv load skipped");
|
|
||||||
}
|
|
||||||
if let Ok(env_file_content) = std::fs::read_to_string(env_file) {
|
|
||||||
for line in env_file_content.lines() {
|
|
||||||
if let Some((key, value)) = line.split_once('=') {
|
|
||||||
env.insert(key.to_string(), value.to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
env = env.into_iter().chain(std::env::vars()).collect();
|
|
||||||
let this = AppConfig { env };
|
|
||||||
if GLOBAL_CONFIG.get().is_some() {
|
|
||||||
GLOBAL_CONFIG.get().unwrap().clone()
|
|
||||||
} else {
|
|
||||||
let _ = GLOBAL_CONFIG.set(this);
|
|
||||||
GLOBAL_CONFIG
|
|
||||||
.get()
|
|
||||||
.expect("global config should be set after load")
|
|
||||||
.clone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod ai;
|
pub mod ai;
|
||||||
pub mod app;
|
pub mod app;
|
||||||
|
pub mod app_config;
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
pub mod avatar;
|
pub mod avatar;
|
||||||
pub mod cache;
|
pub mod cache;
|
||||||
@ -57,3 +19,5 @@ pub mod redis;
|
|||||||
pub mod smtp;
|
pub mod smtp;
|
||||||
pub mod ssh;
|
pub mod ssh;
|
||||||
pub mod storage;
|
pub mod storage;
|
||||||
|
|
||||||
|
pub use app_config::{AppConfig, GLOBAL_CONFIG};
|
||||||
|
|||||||
@ -61,7 +61,7 @@ impl AppConfig {
|
|||||||
if let Some(endpoint) = self.env.get("APP_OTEL_ENDPOINT") {
|
if let Some(endpoint) = self.env.get("APP_OTEL_ENDPOINT") {
|
||||||
return Ok(endpoint.to_string());
|
return Ok(endpoint.to_string());
|
||||||
}
|
}
|
||||||
Ok("http://localhost:5080/api/default/v1/traces".to_string())
|
Ok("http://localhost:4318".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn otel_service_name(&self) -> anyhow::Result<String> {
|
pub fn otel_service_name(&self) -> anyhow::Result<String> {
|
||||||
|
|||||||
@ -7,6 +7,7 @@ use sqlx::{
|
|||||||
PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow,
|
PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
use track::{CounterVec, HistogramVec};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
route::{SqlRoute, route_sql},
|
route::{SqlRoute, route_sql},
|
||||||
@ -17,9 +18,11 @@ use crate::{
|
|||||||
pub struct AppDatabase {
|
pub struct AppDatabase {
|
||||||
db_write: PgPool,
|
db_write: PgPool,
|
||||||
db_read: Option<PgPool>,
|
db_read: Option<PgPool>,
|
||||||
|
metrics: Option<track::MetricsRegistry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppDatabase {
|
impl AppDatabase {
|
||||||
|
#[tracing::instrument(skip(cfg))]
|
||||||
pub async fn init(cfg: &AppConfig) -> anyhow::Result<Self> {
|
pub async fn init(cfg: &AppConfig) -> anyhow::Result<Self> {
|
||||||
let db_url = cfg.database_url()?;
|
let db_url = cfg.database_url()?;
|
||||||
let max_connections = cfg.database_max_connections()?;
|
let max_connections = cfg.database_max_connections()?;
|
||||||
@ -69,7 +72,15 @@ impl AppDatabase {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Self { db_write, db_read })
|
Ok(Self {
|
||||||
|
db_write,
|
||||||
|
db_read,
|
||||||
|
metrics: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
|
||||||
|
self.metrics = Some(registry);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn writer(&self) -> &PgPool {
|
pub fn writer(&self) -> &PgPool {
|
||||||
@ -87,11 +98,14 @@ impl AppDatabase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self), fields(sql.route = "write"))]
|
||||||
pub async fn begin(&self) -> Result<AppTransaction<'_>, sqlx::Error> {
|
pub async fn begin(&self) -> Result<AppTransaction<'_>, sqlx::Error> {
|
||||||
let txn = self.db_write.begin().await?;
|
let txn = self.db_write.begin().await?;
|
||||||
|
tracing::debug!("db transaction started");
|
||||||
Ok(AppTransaction { inner: txn })
|
Ok(AppTransaction { inner: txn })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self), fields(sql.route = "read"))]
|
||||||
pub async fn begin_read_only(
|
pub async fn begin_read_only(
|
||||||
&self,
|
&self,
|
||||||
) -> Result<AppTransaction<'_>, sqlx::Error> {
|
) -> Result<AppTransaction<'_>, sqlx::Error> {
|
||||||
@ -101,9 +115,11 @@ impl AppDatabase {
|
|||||||
.execute(&mut *txn)
|
.execute(&mut *txn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
tracing::debug!("db read-only transaction started");
|
||||||
Ok(AppTransaction { inner: txn })
|
Ok(AppTransaction { inner: txn })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self, sql), fields(sql.kind = "execute"))]
|
||||||
pub async fn execute(
|
pub async fn execute(
|
||||||
&self,
|
&self,
|
||||||
sql: &str,
|
sql: &str,
|
||||||
@ -111,18 +127,38 @@ impl AppDatabase {
|
|||||||
self.execute_with_args(sql, PgArguments::default()).await
|
self.execute_with_args(sql, PgArguments::default()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "execute"))]
|
||||||
pub async fn execute_with_args(
|
pub async fn execute_with_args(
|
||||||
&self,
|
&self,
|
||||||
sql: &str,
|
sql: &str,
|
||||||
args: PgArguments,
|
args: PgArguments,
|
||||||
) -> Result<PgQueryResult, sqlx::Error> {
|
) -> Result<PgQueryResult, sqlx::Error> {
|
||||||
let pool = self.route_pool(sql);
|
let pool = self.route_pool(sql);
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
sqlx::query_with(AssertSqlSafe(sql.to_owned()), args)
|
let result = sqlx::query_with(AssertSqlSafe(sql.to_owned()), args)
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await
|
.await;
|
||||||
|
|
||||||
|
let kind = if sql.trim_start().to_uppercase().starts_with("INSERT") {
|
||||||
|
"insert"
|
||||||
|
} else if sql.trim_start().to_uppercase().starts_with("UPDATE") {
|
||||||
|
"update"
|
||||||
|
} else if sql.trim_start().to_uppercase().starts_with("DELETE") {
|
||||||
|
"delete"
|
||||||
|
} else {
|
||||||
|
"execute"
|
||||||
|
};
|
||||||
|
self.record_query(
|
||||||
|
kind,
|
||||||
|
self.route_label(sql),
|
||||||
|
start.elapsed(),
|
||||||
|
result.is_ok(),
|
||||||
|
);
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_one"))]
|
||||||
pub async fn fetch_one<T>(&self, sql: &str) -> Result<T, sqlx::Error>
|
pub async fn fetch_one<T>(&self, sql: &str) -> Result<T, sqlx::Error>
|
||||||
where
|
where
|
||||||
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
||||||
@ -130,6 +166,7 @@ impl AppDatabase {
|
|||||||
self.fetch_one_with_args(sql, PgArguments::default()).await
|
self.fetch_one_with_args(sql, PgArguments::default()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_one"))]
|
||||||
pub async fn fetch_one_with_args<T>(
|
pub async fn fetch_one_with_args<T>(
|
||||||
&self,
|
&self,
|
||||||
sql: &str,
|
sql: &str,
|
||||||
@ -139,12 +176,23 @@ impl AppDatabase {
|
|||||||
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
||||||
{
|
{
|
||||||
let pool = self.route_pool(sql);
|
let pool = self.route_pool(sql);
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
let result =
|
||||||
.fetch_one(pool)
|
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
||||||
.await
|
.fetch_one(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
self.record_query(
|
||||||
|
"select",
|
||||||
|
self.route_label(sql),
|
||||||
|
start.elapsed(),
|
||||||
|
result.is_ok(),
|
||||||
|
);
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_optional"))]
|
||||||
pub async fn fetch_optional<T>(
|
pub async fn fetch_optional<T>(
|
||||||
&self,
|
&self,
|
||||||
sql: &str,
|
sql: &str,
|
||||||
@ -156,6 +204,7 @@ impl AppDatabase {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_optional"))]
|
||||||
pub async fn fetch_optional_with_args<T>(
|
pub async fn fetch_optional_with_args<T>(
|
||||||
&self,
|
&self,
|
||||||
sql: &str,
|
sql: &str,
|
||||||
@ -165,12 +214,23 @@ impl AppDatabase {
|
|||||||
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
||||||
{
|
{
|
||||||
let pool = self.route_pool(sql);
|
let pool = self.route_pool(sql);
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
let result =
|
||||||
.fetch_optional(pool)
|
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
||||||
.await
|
.fetch_optional(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
self.record_query(
|
||||||
|
"select",
|
||||||
|
self.route_label(sql),
|
||||||
|
start.elapsed(),
|
||||||
|
result.is_ok(),
|
||||||
|
);
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_all"))]
|
||||||
pub async fn fetch_all<T>(&self, sql: &str) -> Result<Vec<T>, sqlx::Error>
|
pub async fn fetch_all<T>(&self, sql: &str) -> Result<Vec<T>, sqlx::Error>
|
||||||
where
|
where
|
||||||
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
||||||
@ -178,6 +238,7 @@ impl AppDatabase {
|
|||||||
self.fetch_all_with_args(sql, PgArguments::default()).await
|
self.fetch_all_with_args(sql, PgArguments::default()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_all"))]
|
||||||
pub async fn fetch_all_with_args<T>(
|
pub async fn fetch_all_with_args<T>(
|
||||||
&self,
|
&self,
|
||||||
sql: &str,
|
sql: &str,
|
||||||
@ -187,10 +248,45 @@ impl AppDatabase {
|
|||||||
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
||||||
{
|
{
|
||||||
let pool = self.route_pool(sql);
|
let pool = self.route_pool(sql);
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
let result =
|
||||||
.fetch_all(pool)
|
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
||||||
.await
|
.fetch_all(pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
self.record_query(
|
||||||
|
"select",
|
||||||
|
self.route_label(sql),
|
||||||
|
start.elapsed(),
|
||||||
|
result.is_ok(),
|
||||||
|
);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn route_label(&self, sql: &str) -> &str {
|
||||||
|
match route_sql(sql) {
|
||||||
|
SqlRoute::Write => "write",
|
||||||
|
SqlRoute::Read => "read",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_query(
|
||||||
|
&self,
|
||||||
|
kind: &str,
|
||||||
|
route: &str,
|
||||||
|
duration: Duration,
|
||||||
|
success: bool,
|
||||||
|
) {
|
||||||
|
if let Some(reg) = &self.metrics {
|
||||||
|
let status = if success { "success" } else { "error" };
|
||||||
|
db_queries_vec(reg)
|
||||||
|
.with_label_values(&[kind, route, status])
|
||||||
|
.inc();
|
||||||
|
db_query_duration_vec(reg)
|
||||||
|
.with_label_values(&[kind, route])
|
||||||
|
.observe(duration.as_secs_f64());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,3 +326,26 @@ async fn build_pool(
|
|||||||
|
|
||||||
pool_options.connect_with(options).await
|
pool_options.connect_with(options).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn db_queries_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||||
|
registry
|
||||||
|
.register_counter_vec(
|
||||||
|
"db_queries_total",
|
||||||
|
"Total database queries",
|
||||||
|
&["kind", "route", "status"],
|
||||||
|
)
|
||||||
|
.expect("failed to register db_queries_total")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn db_query_duration_vec(registry: &track::MetricsRegistry) -> HistogramVec {
|
||||||
|
registry
|
||||||
|
.register_histogram_vec(
|
||||||
|
"db_query_duration_seconds",
|
||||||
|
"DB query duration in seconds",
|
||||||
|
&["kind", "route"],
|
||||||
|
vec![
|
||||||
|
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.expect("failed to register db_query_duration_seconds")
|
||||||
|
}
|
||||||
|
|||||||
267
lib/migrate/lib.rs
Normal file
267
lib/migrate/lib.rs
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
use std::collections::{BTreeMap, HashMap, VecDeque};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, bail};
|
||||||
|
use sqlx::PgPool;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
struct Migration {
|
||||||
|
domain: String,
|
||||||
|
table: String,
|
||||||
|
version: u32,
|
||||||
|
direction: MigrationDir,
|
||||||
|
path: PathBuf,
|
||||||
|
depends_on: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Ord for Migration {
|
||||||
|
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||||
|
(&self.domain, &self.table, self.version, &self.direction).cmp(&(
|
||||||
|
&other.domain,
|
||||||
|
&other.table,
|
||||||
|
other.version,
|
||||||
|
&other.direction,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialOrd for Migration {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||||
|
Some(self.cmp(other))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||||
|
enum MigrationDir {
|
||||||
|
Up,
|
||||||
|
Down,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_up(pool: &PgPool) -> Result<()> {
|
||||||
|
let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql");
|
||||||
|
ensure_migrations_table(pool).await?;
|
||||||
|
let all = discover_migrations(&sql_root)?;
|
||||||
|
let applied = applied_set(pool).await?;
|
||||||
|
|
||||||
|
let mut up_migrations: Vec<_> = all
|
||||||
|
.into_iter()
|
||||||
|
.filter(|m| m.direction == MigrationDir::Up)
|
||||||
|
.filter(|m| {
|
||||||
|
!applied.contains_key(&(
|
||||||
|
m.domain.clone(),
|
||||||
|
m.table.clone(),
|
||||||
|
m.version,
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if up_migrations.is_empty() {
|
||||||
|
tracing::info!("All migrations are already applied.");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
topo_sort(&mut up_migrations)?;
|
||||||
|
|
||||||
|
for m in &up_migrations {
|
||||||
|
let sql = std::fs::read_to_string(&m.path)
|
||||||
|
.context(format!("Failed to read {:?}", m.path))?;
|
||||||
|
let checksum = compute_checksum(&sql);
|
||||||
|
|
||||||
|
tracing::info!(domain = %m.domain, table = %m.table, version = m.version, "applying migration");
|
||||||
|
exec_sql(pool, &sql).await?;
|
||||||
|
record_migration(pool, m, &checksum).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("Applied {} migration(s).", up_migrations.len());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn discover_migrations(sql_root: &Path) -> Result<Vec<Migration>> {
|
||||||
|
let mut migrations = Vec::new();
|
||||||
|
|
||||||
|
if !sql_root.exists() {
|
||||||
|
bail!("SQL directory not found: {}", sql_root.display());
|
||||||
|
}
|
||||||
|
|
||||||
|
for dir_entry in std::fs::read_dir(sql_root)? {
|
||||||
|
let dir = dir_entry?;
|
||||||
|
if !dir.file_type()?.is_dir() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let domain = dir.file_name().to_string_lossy().to_string();
|
||||||
|
|
||||||
|
for file_entry in std::fs::read_dir(dir.path())? {
|
||||||
|
let file = file_entry?;
|
||||||
|
let path = file.path();
|
||||||
|
if path.extension().and_then(|e| e.to_str()) != Some("sql") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let stem = path
|
||||||
|
.file_stem()
|
||||||
|
.and_then(|s| s.to_str())
|
||||||
|
.context("Invalid filename")?;
|
||||||
|
|
||||||
|
let (table, direction, version) = parse_migration_stem(stem)?;
|
||||||
|
|
||||||
|
let content = std::fs::read_to_string(&path)
|
||||||
|
.context(format!("Failed to read {path:?}"))?;
|
||||||
|
let depends_on = parse_depends_on(&content);
|
||||||
|
|
||||||
|
migrations.push(Migration {
|
||||||
|
domain: domain.clone(),
|
||||||
|
table,
|
||||||
|
version,
|
||||||
|
direction,
|
||||||
|
path,
|
||||||
|
depends_on,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
migrations.sort();
|
||||||
|
Ok(migrations)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> {
|
||||||
|
if let Some(pos) = stem.rfind("_up_") {
|
||||||
|
let table = stem[..pos].to_string();
|
||||||
|
let version = stem[pos + 4..]
|
||||||
|
.parse::<u32>()
|
||||||
|
.context("Invalid version number")?;
|
||||||
|
Ok((table, MigrationDir::Up, version))
|
||||||
|
} else if let Some(pos) = stem.rfind("_down_") {
|
||||||
|
let table = stem[..pos].to_string();
|
||||||
|
let version = stem[pos + 6..]
|
||||||
|
.parse::<u32>()
|
||||||
|
.context("Invalid version number")?;
|
||||||
|
Ok((table, MigrationDir::Down, version))
|
||||||
|
} else {
|
||||||
|
bail!("Migration filename must contain _up_ or _down_: {stem}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn ensure_migrations_table(pool: &PgPool) -> Result<()> {
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS _sql_migrations (
|
||||||
|
domain TEXT NOT NULL,
|
||||||
|
table_name TEXT NOT NULL,
|
||||||
|
version INTEGER NOT NULL,
|
||||||
|
applied_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
checksum TEXT NOT NULL DEFAULT '',
|
||||||
|
PRIMARY KEY (domain, table_name, version)
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(pool)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn applied_set(
|
||||||
|
pool: &PgPool,
|
||||||
|
) -> Result<BTreeMap<(String, String, u32), String>> {
|
||||||
|
let rows: Vec<(String, String, i32, String)> = sqlx::query_as(
|
||||||
|
"SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version",
|
||||||
|
)
|
||||||
|
.fetch_all(pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|(d, t, v, c)| ((d, t, v as u32), c))
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn record_migration(
|
||||||
|
pool: &PgPool,
|
||||||
|
m: &Migration,
|
||||||
|
checksum: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO _sql_migrations (domain, table_name, version, checksum) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING",
|
||||||
|
)
|
||||||
|
.bind(&m.domain)
|
||||||
|
.bind(&m.table)
|
||||||
|
.bind(m.version as i32)
|
||||||
|
.bind(checksum)
|
||||||
|
.execute(pool)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_checksum(content: &str) -> String {
|
||||||
|
use std::collections::hash_map::DefaultHasher;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
content.hash(&mut hasher);
|
||||||
|
format!("{:x}", hasher.finish())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_depends_on(content: &str) -> Vec<String> {
|
||||||
|
content
|
||||||
|
.lines()
|
||||||
|
.filter_map(|line| {
|
||||||
|
let line = line.trim();
|
||||||
|
line.strip_prefix("-- depends_on:").map(|deps| {
|
||||||
|
deps.split(',')
|
||||||
|
.map(|d| d.trim().to_string())
|
||||||
|
.filter(|d| !d.is_empty())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.flatten()
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn topo_sort(migrations: &mut [Migration]) -> Result<()> {
|
||||||
|
let table_to_idx: HashMap<String, usize> = migrations
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, m)| (m.table.clone(), i))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let n = migrations.len();
|
||||||
|
let mut in_degree = vec![0u32; n];
|
||||||
|
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
|
||||||
|
|
||||||
|
for (i, m) in migrations.iter().enumerate() {
|
||||||
|
for dep in &m.depends_on {
|
||||||
|
if let Some(&j) = table_to_idx.get(dep) {
|
||||||
|
adj[j].push(i);
|
||||||
|
in_degree[i] += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut queue: VecDeque<usize> =
|
||||||
|
(0..n).filter(|&i| in_degree[i] == 0).collect();
|
||||||
|
let mut order = Vec::with_capacity(n);
|
||||||
|
while let Some(i) = queue.pop_front() {
|
||||||
|
order.push(i);
|
||||||
|
for &next in &adj[i] {
|
||||||
|
in_degree[next] -= 1;
|
||||||
|
if in_degree[next] == 0 {
|
||||||
|
queue.push_back(next);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if order.len() != n {
|
||||||
|
bail!("Circular dependency detected among migrations");
|
||||||
|
}
|
||||||
|
|
||||||
|
let original: Vec<Migration> = migrations.iter().cloned().collect();
|
||||||
|
for (slot, &idx) in order.iter().enumerate() {
|
||||||
|
migrations[slot] = original[idx].clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn exec_sql(pool: &PgPool, sql: &str) -> Result<()> {
|
||||||
|
let s: &'static str = Box::leak(sql.to_owned().into_boxed_str());
|
||||||
|
sqlx::raw_sql(s).execute(pool).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
2
lib/migrate/sql/room/room_attachment_down_02.sql
Normal file
2
lib/migrate/sql/room/room_attachment_down_02.sql
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
ALTER TABLE room_attachment ALTER COLUMN message SET NOT NULL;
|
||||||
|
ALTER TABLE room_attachment ALTER COLUMN seq SET NOT NULL;
|
||||||
2
lib/migrate/sql/room/room_attachment_up_02.sql
Normal file
2
lib/migrate/sql/room/room_attachment_up_02.sql
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
ALTER TABLE room_attachment ALTER COLUMN message DROP NOT NULL;
|
||||||
|
ALTER TABLE room_attachment ALTER COLUMN seq DROP NOT NULL;
|
||||||
1
lib/migrate/sql/room/room_mention_down_02.sql
Normal file
1
lib/migrate/sql/room/room_mention_down_02.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
ALTER TABLE room_mention ALTER COLUMN target_id TYPE UUID USING target_id::UUID;
|
||||||
1
lib/migrate/sql/room/room_mention_up_02.sql
Normal file
1
lib/migrate/sql/room/room_mention_up_02.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
ALTER TABLE room_mention ALTER COLUMN target_id TYPE TEXT;
|
||||||
@ -1,59 +1,12 @@
|
|||||||
use anyhow::{Context, Result, bail};
|
use anyhow::Result;
|
||||||
use clap::{Parser, Subcommand};
|
use clap::Parser;
|
||||||
use sqlx::postgres::PgPoolOptions;
|
use sqlx::postgres::PgPoolOptions;
|
||||||
use std::collections::{BTreeMap, HashMap, VecDeque};
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "migrate", about = "Database migration tool")]
|
#[command(name = "migrate", about = "Database migration tool")]
|
||||||
struct Cli {
|
struct Cli {
|
||||||
#[arg(short, long)]
|
#[arg(short, long, env = "DATABASE_URL")]
|
||||||
database_url: String,
|
database_url: String,
|
||||||
|
|
||||||
#[command(subcommand)]
|
|
||||||
command: Command,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
|
||||||
enum Command {
|
|
||||||
Up,
|
|
||||||
Down,
|
|
||||||
Fresh,
|
|
||||||
List,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
struct Migration {
|
|
||||||
domain: String,
|
|
||||||
table: String,
|
|
||||||
version: u32,
|
|
||||||
direction: MigrationDir,
|
|
||||||
path: PathBuf,
|
|
||||||
depends_on: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Ord for Migration {
|
|
||||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
|
||||||
(&self.domain, &self.table, self.version, &self.direction).cmp(&(
|
|
||||||
&other.domain,
|
|
||||||
&other.table,
|
|
||||||
other.version,
|
|
||||||
&other.direction,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialOrd for Migration {
|
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
|
||||||
Some(self.cmp(other))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
|
||||||
enum MigrationDir {
|
|
||||||
Up,
|
|
||||||
Down,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@ -64,358 +17,10 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
|
|
||||||
let database_url = std::env::var("DATABASE_URL")
|
|
||||||
.context("DATABASE_URL must be set or provided via --database-url")?;
|
|
||||||
|
|
||||||
let pool = PgPoolOptions::new()
|
let pool = PgPoolOptions::new()
|
||||||
.max_connections(1)
|
.max_connections(1)
|
||||||
.connect(&database_url)
|
.connect(&cli.database_url)
|
||||||
.await
|
|
||||||
.context("Failed to connect to database")?;
|
|
||||||
|
|
||||||
let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql");
|
|
||||||
|
|
||||||
match cli.command {
|
|
||||||
Command::Up => run_up(&pool, &sql_root).await,
|
|
||||||
Command::Down => run_down(&pool, &sql_root).await,
|
|
||||||
Command::Fresh => run_fresh(&pool, &sql_root).await,
|
|
||||||
Command::List => run_list(&pool, &sql_root).await,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn discover_migrations(sql_root: &Path) -> Result<Vec<Migration>> {
|
|
||||||
let mut migrations = Vec::new();
|
|
||||||
|
|
||||||
if !sql_root.exists() {
|
|
||||||
bail!("SQL directory not found: {}", sql_root.display());
|
|
||||||
}
|
|
||||||
|
|
||||||
for dir_entry in std::fs::read_dir(sql_root)? {
|
|
||||||
let dir = dir_entry?;
|
|
||||||
if !dir.file_type()?.is_dir() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let domain = dir.file_name().to_string_lossy().to_string();
|
|
||||||
|
|
||||||
for file_entry in std::fs::read_dir(dir.path())? {
|
|
||||||
let file = file_entry?;
|
|
||||||
let path = file.path();
|
|
||||||
if path.extension().and_then(|e| e.to_str()) != Some("sql") {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let stem = path
|
|
||||||
.file_stem()
|
|
||||||
.and_then(|s| s.to_str())
|
|
||||||
.context("Invalid filename")?;
|
|
||||||
|
|
||||||
let (table, direction, version) = parse_migration_stem(stem)?;
|
|
||||||
|
|
||||||
let content = std::fs::read_to_string(&path)
|
|
||||||
.context(format!("Failed to read {path:?}"))?;
|
|
||||||
let depends_on = parse_depends_on(&content);
|
|
||||||
|
|
||||||
migrations.push(Migration {
|
|
||||||
domain: domain.clone(),
|
|
||||||
table,
|
|
||||||
version,
|
|
||||||
direction,
|
|
||||||
path,
|
|
||||||
depends_on,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
migrations.sort();
|
|
||||||
Ok(migrations)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> {
|
|
||||||
if let Some(pos) = stem.rfind("_up_") {
|
|
||||||
let table = stem[..pos].to_string();
|
|
||||||
let ver_str = &stem[pos + 4..];
|
|
||||||
let version =
|
|
||||||
ver_str.parse::<u32>().context("Invalid version number")?;
|
|
||||||
Ok((table, MigrationDir::Up, version))
|
|
||||||
} else if let Some(pos) = stem.rfind("_down_") {
|
|
||||||
let table = stem[..pos].to_string();
|
|
||||||
let ver_str = &stem[pos + 6..];
|
|
||||||
let version =
|
|
||||||
ver_str.parse::<u32>().context("Invalid version number")?;
|
|
||||||
Ok((table, MigrationDir::Down, version))
|
|
||||||
} else {
|
|
||||||
bail!("Migration filename must contain _up_ or _down_: {stem}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn ensure_migrations_table(pool: &sqlx::PgPool) -> Result<()> {
|
|
||||||
sqlx::query(
|
|
||||||
r#"
|
|
||||||
CREATE TABLE IF NOT EXISTS _sql_migrations (
|
|
||||||
domain TEXT NOT NULL,
|
|
||||||
table_name TEXT NOT NULL,
|
|
||||||
version INTEGER NOT NULL,
|
|
||||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
|
||||||
checksum TEXT NOT NULL DEFAULT '',
|
|
||||||
PRIMARY KEY (domain, table_name, version)
|
|
||||||
)
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.execute(pool)
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn applied_set(
|
|
||||||
pool: &sqlx::PgPool,
|
|
||||||
) -> Result<BTreeMap<(String, String, u32), String>> {
|
|
||||||
let rows: Vec<(String, String, i32, String)> =
|
|
||||||
sqlx::query_as("SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version")
|
|
||||||
.fetch_all(pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(rows
|
|
||||||
.into_iter()
|
|
||||||
.map(|(d, t, v, c)| ((d, t, v as u32), c))
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn record_migration(
|
|
||||||
pool: &sqlx::PgPool,
|
|
||||||
m: &Migration,
|
|
||||||
checksum: &str,
|
|
||||||
) -> Result<()> {
|
|
||||||
sqlx::query(
|
|
||||||
r#"
|
|
||||||
INSERT INTO _sql_migrations (domain, table_name, version, checksum)
|
|
||||||
VALUES ($1, $2, $3, $4)
|
|
||||||
ON CONFLICT DO NOTHING
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(&m.domain)
|
|
||||||
.bind(&m.table)
|
|
||||||
.bind(m.version as i32)
|
|
||||||
.bind(checksum)
|
|
||||||
.execute(pool)
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn delete_migration(pool: &sqlx::PgPool, m: &Migration) -> Result<()> {
|
|
||||||
sqlx::query(
|
|
||||||
"DELETE FROM _sql_migrations WHERE domain = $1 AND table_name = $2 AND version = $3",
|
|
||||||
)
|
|
||||||
.bind(&m.domain)
|
|
||||||
.bind(&m.table)
|
|
||||||
.bind(m.version as i32)
|
|
||||||
.execute(pool)
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compute_checksum(content: &str) -> String {
|
|
||||||
use std::collections::hash_map::DefaultHasher;
|
|
||||||
use std::hash::{Hash, Hasher};
|
|
||||||
let mut hasher = DefaultHasher::new();
|
|
||||||
content.hash(&mut hasher);
|
|
||||||
format!("{:x}", hasher.finish())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_depends_on(content: &str) -> Vec<String> {
|
|
||||||
content
|
|
||||||
.lines()
|
|
||||||
.filter_map(|line| {
|
|
||||||
let line = line.trim();
|
|
||||||
line.strip_prefix("-- depends_on:").map(|deps| {
|
|
||||||
deps.split(',')
|
|
||||||
.map(|d| d.trim().to_string())
|
|
||||||
.filter(|d| !d.is_empty())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.flatten()
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn topo_sort(migrations: &mut [Migration]) -> Result<()> {
|
|
||||||
let table_to_idx: HashMap<String, usize> = migrations
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.map(|(i, m)| (m.table.clone(), i))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let n = migrations.len();
|
|
||||||
let mut in_degree = vec![0u32; n];
|
|
||||||
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
|
|
||||||
|
|
||||||
for (i, m) in migrations.iter().enumerate() {
|
|
||||||
for dep in &m.depends_on {
|
|
||||||
if let Some(&j) = table_to_idx.get(dep) {
|
|
||||||
adj[j].push(i);
|
|
||||||
in_degree[i] += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut queue: VecDeque<usize> =
|
|
||||||
(0..n).filter(|&i| in_degree[i] == 0).collect();
|
|
||||||
|
|
||||||
let mut order = Vec::with_capacity(n);
|
|
||||||
while let Some(i) = queue.pop_front() {
|
|
||||||
order.push(i);
|
|
||||||
for &next in &adj[i] {
|
|
||||||
in_degree[next] -= 1;
|
|
||||||
if in_degree[next] == 0 {
|
|
||||||
queue.push_back(next);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if order.len() != n {
|
|
||||||
bail!("Circular dependency detected among migrations");
|
|
||||||
}
|
|
||||||
|
|
||||||
let original: Vec<Migration> = migrations.iter().cloned().collect();
|
|
||||||
for (slot, &idx) in order.iter().enumerate() {
|
|
||||||
migrations[slot] = original[idx].clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
fn into_static(s: String) -> &'static str {
|
|
||||||
Box::leak(s.into_boxed_str())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn exec_sql(pool: &sqlx::PgPool, sql: &str) -> Result<()> {
|
|
||||||
sqlx::raw_sql(into_static(sql.to_owned()))
|
|
||||||
.execute(pool)
|
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
|
||||||
}
|
migrate::run_up(&pool).await
|
||||||
|
|
||||||
async fn run_up(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
|
|
||||||
ensure_migrations_table(pool).await?;
|
|
||||||
let all = discover_migrations(sql_root)?;
|
|
||||||
let applied = applied_set(pool).await?;
|
|
||||||
|
|
||||||
let mut up_migrations: Vec<_> = all
|
|
||||||
.into_iter()
|
|
||||||
.filter(|m| m.direction == MigrationDir::Up)
|
|
||||||
.filter(|m| {
|
|
||||||
!applied.contains_key(&(
|
|
||||||
m.domain.clone(),
|
|
||||||
m.table.clone(),
|
|
||||||
m.version,
|
|
||||||
))
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if up_migrations.is_empty() {
|
|
||||||
info!("All migrations are already applied.");
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
topo_sort(&mut up_migrations)?;
|
|
||||||
|
|
||||||
for m in &up_migrations {
|
|
||||||
let sql = std::fs::read_to_string(&m.path)
|
|
||||||
.context(format!("Failed to read {:?}", m.path))?;
|
|
||||||
let checksum = compute_checksum(&sql);
|
|
||||||
|
|
||||||
info!("Applying {}/{}/v{}", m.domain, m.table, m.version);
|
|
||||||
exec_sql(pool, &sql).await?;
|
|
||||||
record_migration(pool, m, &checksum).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
info!("Applied {} migration(s).", up_migrations.len());
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run_down(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
|
|
||||||
ensure_migrations_table(pool).await?;
|
|
||||||
let all = discover_migrations(sql_root)?;
|
|
||||||
let applied = applied_set(pool).await?;
|
|
||||||
|
|
||||||
let mut down_targets: Vec<_> = all
|
|
||||||
.into_iter()
|
|
||||||
.filter(|m| m.direction == MigrationDir::Down)
|
|
||||||
.filter(|m| {
|
|
||||||
applied.contains_key(&(
|
|
||||||
m.domain.clone(),
|
|
||||||
m.table.clone(),
|
|
||||||
m.version,
|
|
||||||
))
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
down_targets.sort();
|
|
||||||
|
|
||||||
if down_targets.is_empty() {
|
|
||||||
info!("No migrations to roll back.");
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let m = &down_targets[down_targets.len() - 1];
|
|
||||||
let sql = std::fs::read_to_string(&m.path)?;
|
|
||||||
|
|
||||||
info!("Rolling back {}/{}/v{}", m.domain, m.table, m.version);
|
|
||||||
exec_sql(pool, &sql).await?;
|
|
||||||
delete_migration(pool, m).await?;
|
|
||||||
|
|
||||||
info!("Rolled back 1 migration.");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run_fresh(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
|
|
||||||
info!("Dropping all tables and re-applying migrations...");
|
|
||||||
|
|
||||||
exec_sql(pool, "DROP TABLE IF EXISTS _sql_migrations CASCADE").await?;
|
|
||||||
|
|
||||||
let all = discover_migrations(sql_root)?;
|
|
||||||
let down_migrations: Vec<_> = all
|
|
||||||
.into_iter()
|
|
||||||
.filter(|m| m.direction == MigrationDir::Down)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let mut drops: Vec<_> = down_migrations.iter().collect();
|
|
||||||
drops.sort();
|
|
||||||
drops.reverse();
|
|
||||||
|
|
||||||
for m in &drops {
|
|
||||||
let sql = std::fs::read_to_string(&m.path)?;
|
|
||||||
let _ = exec_sql(pool, &sql).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
run_up(pool, sql_root).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run_list(pool: &sqlx::PgPool, sql_root: &Path) -> Result<()> {
|
|
||||||
ensure_migrations_table(pool).await?;
|
|
||||||
let all = discover_migrations(sql_root)?;
|
|
||||||
let applied = applied_set(pool).await?;
|
|
||||||
|
|
||||||
let up_migrations: Vec<_> = all
|
|
||||||
.into_iter()
|
|
||||||
.filter(|m| m.direction == MigrationDir::Up)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
println!(
|
|
||||||
"{:<20} {:<30} {:>8} {}",
|
|
||||||
"Domain", "Table", "Version", "Status"
|
|
||||||
);
|
|
||||||
println!("{:-<20} {:-<30} {:-<8} {:-<10}", "", "", "", "");
|
|
||||||
|
|
||||||
for m in &up_migrations {
|
|
||||||
let key = (m.domain.clone(), m.table.clone(), m.version);
|
|
||||||
let status = if applied.contains_key(&key) {
|
|
||||||
"Applied"
|
|
||||||
} else {
|
|
||||||
"Pending"
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"{:<20} {:<30} {:>8} {}",
|
|
||||||
m.domain, m.table, m.version, status
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,6 +4,7 @@ use async_nats::{HeaderMap, jetstream};
|
|||||||
use config::AppConfig;
|
use config::AppConfig;
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
|
use track::CounterVec;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
handler::{AckAction, MessageHandler},
|
handler::{AckAction, MessageHandler},
|
||||||
@ -16,6 +17,7 @@ pub struct NatsConsumer {
|
|||||||
max_deliver: i64,
|
max_deliver: i64,
|
||||||
retry_delay_secs: u64,
|
retry_delay_secs: u64,
|
||||||
durable_name: String,
|
durable_name: String,
|
||||||
|
metrics: Option<track::MetricsRegistry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl NatsConsumer {
|
impl NatsConsumer {
|
||||||
@ -33,9 +35,15 @@ impl NatsConsumer {
|
|||||||
max_deliver: config.nats_max_deliver(),
|
max_deliver: config.nats_max_deliver(),
|
||||||
retry_delay_secs: config.nats_retry_delay_secs(),
|
retry_delay_secs: config.nats_retry_delay_secs(),
|
||||||
durable_name: durable_name(group_id),
|
durable_name: durable_name(group_id),
|
||||||
|
metrics: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
|
||||||
|
self.producer.set_metrics(registry.clone());
|
||||||
|
self.metrics = Some(registry);
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn start_consuming<H>(
|
pub async fn start_consuming<H>(
|
||||||
&self,
|
&self,
|
||||||
topics: &[&str],
|
topics: &[&str],
|
||||||
@ -61,9 +69,10 @@ impl NatsConsumer {
|
|||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
info!("NATS consumer started subscribing to: {:?}", topics_owned);
|
info!(topics = ?topics_owned, durable = %self.durable_name, "NATS consumer started");
|
||||||
|
|
||||||
let producer = self.producer.clone();
|
let producer = self.producer.clone();
|
||||||
|
let metrics = self.metrics.clone();
|
||||||
let max_deliver = self.max_deliver;
|
let max_deliver = self.max_deliver;
|
||||||
let retry_delay_secs = self.retry_delay_secs;
|
let retry_delay_secs = self.retry_delay_secs;
|
||||||
let handler = Arc::new(handler);
|
let handler = Arc::new(handler);
|
||||||
@ -73,10 +82,7 @@ impl NatsConsumer {
|
|||||||
let mut messages = match messages {
|
let mut messages = match messages {
|
||||||
Ok(messages) => messages,
|
Ok(messages) => messages,
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
error!(
|
error!(error = %error, "NATS error while opening consumer stream");
|
||||||
"NATS error while opening consumer stream: {:?}",
|
|
||||||
error
|
|
||||||
);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -86,6 +92,7 @@ impl NatsConsumer {
|
|||||||
Ok(message) => {
|
Ok(message) => {
|
||||||
handle_message(
|
handle_message(
|
||||||
&producer,
|
&producer,
|
||||||
|
metrics.as_ref(),
|
||||||
max_deliver,
|
max_deliver,
|
||||||
retry_delay_secs,
|
retry_delay_secs,
|
||||||
handler.as_ref(),
|
handler.as_ref(),
|
||||||
@ -94,7 +101,7 @@ impl NatsConsumer {
|
|||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
error!("NATS error while consuming: {:?}", error);
|
error!(error = %error, "NATS error while consuming");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -106,6 +113,7 @@ impl NatsConsumer {
|
|||||||
|
|
||||||
async fn handle_message<H>(
|
async fn handle_message<H>(
|
||||||
producer: &NatsProducer,
|
producer: &NatsProducer,
|
||||||
|
metrics: Option<&track::MetricsRegistry>,
|
||||||
max_deliver: i64,
|
max_deliver: i64,
|
||||||
retry_delay_secs: u64,
|
retry_delay_secs: u64,
|
||||||
handler: &H,
|
handler: &H,
|
||||||
@ -116,12 +124,17 @@ async fn handle_message<H>(
|
|||||||
let subject = message.subject.to_string();
|
let subject = message.subject.to_string();
|
||||||
let payload = message.payload.clone();
|
let payload = message.payload.clone();
|
||||||
let delivered = message.info().map(|info| info.delivered).unwrap_or(1);
|
let delivered = message.info().map(|info| info.delivered).unwrap_or(1);
|
||||||
|
record_queue_message(metrics, &subject, "received");
|
||||||
|
|
||||||
match handler.handle(&subject, &payload).await {
|
match handler.handle(&subject, &payload).await {
|
||||||
AckAction::Ack => ack_message(&message, &subject, "message").await,
|
AckAction::Ack => {
|
||||||
|
ack_message(metrics, &message, &subject, "message").await
|
||||||
|
}
|
||||||
AckAction::Nack => {
|
AckAction::Nack => {
|
||||||
|
record_queue_message(metrics, &subject, "nack");
|
||||||
if let Err(error) = handle_nack(
|
if let Err(error) = handle_nack(
|
||||||
producer,
|
producer,
|
||||||
|
metrics,
|
||||||
&message,
|
&message,
|
||||||
&subject,
|
&subject,
|
||||||
&payload,
|
&payload,
|
||||||
@ -131,9 +144,13 @@ async fn handle_message<H>(
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
record_queue_message(metrics, &subject, "error");
|
||||||
error!(
|
error!(
|
||||||
"Failed to route NACKed message from subject {}: {:?}",
|
subject = %subject,
|
||||||
subject, error
|
delivered,
|
||||||
|
max_deliver,
|
||||||
|
error = %error,
|
||||||
|
"failed to route NACKed message"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -142,6 +159,7 @@ async fn handle_message<H>(
|
|||||||
|
|
||||||
async fn handle_nack(
|
async fn handle_nack(
|
||||||
producer: &NatsProducer,
|
producer: &NatsProducer,
|
||||||
|
metrics: Option<&track::MetricsRegistry>,
|
||||||
message: &jetstream::Message,
|
message: &jetstream::Message,
|
||||||
subject: &str,
|
subject: &str,
|
||||||
payload: &[u8],
|
payload: &[u8],
|
||||||
@ -151,8 +169,11 @@ async fn handle_nack(
|
|||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
if delivered < max_deliver {
|
if delivered < max_deliver {
|
||||||
warn!(
|
warn!(
|
||||||
"Message in subject {} failed (NACK). Retrying delivery {}/{} in {} seconds",
|
subject,
|
||||||
subject, delivered, max_deliver, retry_delay_secs
|
delivered,
|
||||||
|
max_deliver,
|
||||||
|
retry_delay_secs,
|
||||||
|
"message NACKed, scheduling retry"
|
||||||
);
|
);
|
||||||
message
|
message
|
||||||
.ack_with(jetstream::AckKind::Nak(Some(Duration::from_secs(
|
.ack_with(jetstream::AckKind::Nak(Some(Duration::from_secs(
|
||||||
@ -162,13 +183,17 @@ async fn handle_nack(
|
|||||||
.map_err(|error| {
|
.map_err(|error| {
|
||||||
anyhow::anyhow!("failed to nack message: {error}")
|
anyhow::anyhow!("failed to nack message: {error}")
|
||||||
})?;
|
})?;
|
||||||
|
record_queue_message(metrics, subject, "retry");
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let dlq_subject = format!("{subject}.dlq");
|
let dlq_subject = format!("{subject}.dlq");
|
||||||
error!(
|
error!(
|
||||||
"Message in subject {} exceeded max deliver attempts ({}). Routing to DLQ: {}",
|
subject,
|
||||||
subject, max_deliver, dlq_subject
|
dlq_subject = %dlq_subject,
|
||||||
|
delivered,
|
||||||
|
max_deliver,
|
||||||
|
"message exceeded max deliver attempts, routing to DLQ"
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
@ -185,22 +210,69 @@ async fn handle_nack(
|
|||||||
message.ack().await.map_err(|error| {
|
message.ack().await.map_err(|error| {
|
||||||
anyhow::anyhow!("failed to ack DLQ message: {error}")
|
anyhow::anyhow!("failed to ack DLQ message: {error}")
|
||||||
})?;
|
})?;
|
||||||
|
record_queue_message(metrics, subject, "dlq");
|
||||||
|
record_queue_dlq(metrics, subject);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn ack_message(
|
async fn ack_message(
|
||||||
|
metrics: Option<&track::MetricsRegistry>,
|
||||||
message: &jetstream::Message,
|
message: &jetstream::Message,
|
||||||
subject: &str,
|
subject: &str,
|
||||||
description: &str,
|
description: &str,
|
||||||
) {
|
) {
|
||||||
if let Err(error) = message.ack().await {
|
match message.ack().await {
|
||||||
error!(
|
Ok(()) => record_queue_message(metrics, subject, "ack"),
|
||||||
"Failed to ack {} in subject {}: {:?}",
|
Err(error) => {
|
||||||
description, subject, error
|
record_queue_message(metrics, subject, "ack_error");
|
||||||
);
|
error!(
|
||||||
|
subject,
|
||||||
|
description,
|
||||||
|
error = %error,
|
||||||
|
"failed to ack message"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn record_queue_message(
|
||||||
|
metrics: Option<&track::MetricsRegistry>,
|
||||||
|
topic: &str,
|
||||||
|
status: &str,
|
||||||
|
) {
|
||||||
|
if let Some(metrics) = metrics {
|
||||||
|
queue_messages_vec(metrics)
|
||||||
|
.with_label_values(&[topic, status])
|
||||||
|
.inc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_queue_dlq(metrics: Option<&track::MetricsRegistry>, topic: &str) {
|
||||||
|
if let Some(metrics) = metrics {
|
||||||
|
queue_dlq_vec(metrics).with_label_values(&[topic]).inc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn queue_messages_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||||
|
registry
|
||||||
|
.register_counter_vec(
|
||||||
|
"queue_messages_total",
|
||||||
|
"Total queue messages",
|
||||||
|
&["topic", "status"],
|
||||||
|
)
|
||||||
|
.expect("failed to register queue_messages_total")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn queue_dlq_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||||
|
registry
|
||||||
|
.register_counter_vec(
|
||||||
|
"queue_dlq_total",
|
||||||
|
"Total messages routed to DLQ",
|
||||||
|
&["topic"],
|
||||||
|
)
|
||||||
|
.expect("failed to register queue_dlq_total")
|
||||||
|
}
|
||||||
|
|
||||||
fn durable_name(name: &str) -> String {
|
fn durable_name(name: &str) -> String {
|
||||||
name.replace('.', "-")
|
name.replace('.', "-")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,10 +3,12 @@ use std::time::Duration;
|
|||||||
use async_nats::{HeaderMap, jetstream};
|
use async_nats::{HeaderMap, jetstream};
|
||||||
use config::AppConfig;
|
use config::AppConfig;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
use track::CounterVec;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct NatsProducer {
|
pub struct NatsProducer {
|
||||||
jetstream: jetstream::Context,
|
jetstream: jetstream::Context,
|
||||||
|
metrics: Option<track::MetricsRegistry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl NatsProducer {
|
impl NatsProducer {
|
||||||
@ -14,7 +16,14 @@ impl NatsProducer {
|
|||||||
let jetstream = connect_jetstream(config).await?;
|
let jetstream = connect_jetstream(config).await?;
|
||||||
ensure_stream(config, &jetstream).await?;
|
ensure_stream(config, &jetstream).await?;
|
||||||
|
|
||||||
Ok(Self { jetstream })
|
Ok(Self {
|
||||||
|
jetstream,
|
||||||
|
metrics: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
|
||||||
|
self.metrics = Some(registry);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send<T>(
|
pub async fn send<T>(
|
||||||
@ -44,19 +53,37 @@ impl NatsProducer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let subject = subject.to_string();
|
let subject = subject.to_string();
|
||||||
let publish = if headers.is_empty() {
|
let publish_result: anyhow::Result<()> = async {
|
||||||
self.jetstream
|
let publish = if headers.is_empty() {
|
||||||
.publish(subject.clone(), payload.to_vec().into())
|
self.jetstream
|
||||||
.await?
|
.publish(subject.clone(), payload.to_vec().into())
|
||||||
} else {
|
.await?
|
||||||
self.jetstream
|
} else {
|
||||||
.publish_with_headers(subject, headers, payload.to_vec().into())
|
self.jetstream
|
||||||
.await?
|
.publish_with_headers(
|
||||||
};
|
subject.clone(),
|
||||||
|
headers,
|
||||||
|
payload.to_vec().into(),
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
tokio::time::timeout(Duration::from_secs(5), publish).await??;
|
tokio::time::timeout(Duration::from_secs(5), publish).await??;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
.await;
|
||||||
|
|
||||||
Ok(())
|
self.record_published(&subject, publish_result.is_ok());
|
||||||
|
publish_result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_published(&self, topic: &str, success: bool) {
|
||||||
|
if let Some(reg) = &self.metrics {
|
||||||
|
let status = if success { "published" } else { "error" };
|
||||||
|
queue_messages_vec(reg)
|
||||||
|
.with_label_values(&[topic, status])
|
||||||
|
.inc();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -88,3 +115,13 @@ pub async fn ensure_stream(
|
|||||||
})
|
})
|
||||||
.await?)
|
.await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn queue_messages_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||||
|
registry
|
||||||
|
.register_counter_vec(
|
||||||
|
"queue_messages_total",
|
||||||
|
"Total queue messages",
|
||||||
|
&["topic", "status"],
|
||||||
|
)
|
||||||
|
.expect("failed to register queue_messages_total")
|
||||||
|
}
|
||||||
|
|||||||
@ -10,6 +10,7 @@ use aws_sdk_s3::primitives::ByteStreamError;
|
|||||||
pub use error::{StorageError, StorageResult};
|
pub use error::{StorageError, StorageResult};
|
||||||
pub use local::{LocalStorage, LocalStorageConfig};
|
pub use local::{LocalStorage, LocalStorageConfig};
|
||||||
pub use s3::{S3Storage, S3StorageConfig};
|
pub use s3::{S3Storage, S3StorageConfig};
|
||||||
|
use track::CounterVec;
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub enum AppStorageConfig {
|
pub enum AppStorageConfig {
|
||||||
@ -18,11 +19,60 @@ pub enum AppStorageConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub enum AppStorage {
|
pub struct AppStorage {
|
||||||
|
inner: StorageBackend,
|
||||||
|
metrics: Option<track::MetricsRegistry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
enum StorageBackend {
|
||||||
Local(LocalStorage),
|
Local(LocalStorage),
|
||||||
S3(S3Storage),
|
S3(S3Storage),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AppStorage {
|
||||||
|
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
|
||||||
|
self.metrics = Some(registry);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn backend_name(&self) -> &str {
|
||||||
|
match &self.inner {
|
||||||
|
StorageBackend::Local(_) => "local",
|
||||||
|
StorageBackend::S3(_) => "s3",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_upload(&self, bytes: usize) {
|
||||||
|
if let Some(reg) = &self.metrics {
|
||||||
|
storage_ops_vec(reg)
|
||||||
|
.with_label_values(&["upload", self.backend_name()])
|
||||||
|
.inc();
|
||||||
|
storage_bytes_vec(reg)
|
||||||
|
.with_label_values(&["upload"])
|
||||||
|
.inc_by(bytes as f64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_download(&self, bytes: usize) {
|
||||||
|
if let Some(reg) = &self.metrics {
|
||||||
|
storage_ops_vec(reg)
|
||||||
|
.with_label_values(&["download", self.backend_name()])
|
||||||
|
.inc();
|
||||||
|
storage_bytes_vec(reg)
|
||||||
|
.with_label_values(&["download"])
|
||||||
|
.inc_by(bytes as f64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_delete(&self) {
|
||||||
|
if let Some(reg) = &self.metrics {
|
||||||
|
storage_ops_vec(reg)
|
||||||
|
.with_label_values(&["delete", self.backend_name()])
|
||||||
|
.inc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default)]
|
#[derive(Clone, Debug, Default)]
|
||||||
pub struct PutObjectOptions {
|
pub struct PutObjectOptions {
|
||||||
pub content_type: Option<String>,
|
pub content_type: Option<String>,
|
||||||
@ -87,76 +137,109 @@ pub trait ObjectStorage: Send + Sync {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AppStorage {
|
impl AppStorage {
|
||||||
|
#[tracing::instrument(skip(config))]
|
||||||
pub async fn init(config: AppStorageConfig) -> StorageResult<Self> {
|
pub async fn init(config: AppStorageConfig) -> StorageResult<Self> {
|
||||||
match config {
|
let inner = match config {
|
||||||
AppStorageConfig::Local(config) => {
|
AppStorageConfig::Local(config) => {
|
||||||
Ok(Self::Local(LocalStorage::connect(config).await?))
|
tracing::info!("initializing local storage");
|
||||||
|
StorageBackend::Local(LocalStorage::connect(config).await?)
|
||||||
}
|
}
|
||||||
AppStorageConfig::S3(config) => {
|
AppStorageConfig::S3(config) => {
|
||||||
Ok(Self::S3(S3Storage::connect(config).await?))
|
tracing::info!(bucket = %config.bucket, region = %config.region, "initializing S3 storage");
|
||||||
|
StorageBackend::S3(S3Storage::connect(config).await?)
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
Ok(Self {
|
||||||
|
inner,
|
||||||
|
metrics: None,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl ObjectStorage for AppStorage {
|
impl ObjectStorage for AppStorage {
|
||||||
|
#[tracing::instrument(skip(self, body), fields(storage.key = %key))]
|
||||||
async fn put_stream(
|
async fn put_stream(
|
||||||
&self,
|
&self,
|
||||||
key: &str,
|
key: &str,
|
||||||
body: ByteStream,
|
body: ByteStream,
|
||||||
options: PutObjectOptions,
|
options: PutObjectOptions,
|
||||||
) -> StorageResult<StoredObject> {
|
) -> StorageResult<StoredObject> {
|
||||||
match self {
|
let result = match &self.inner {
|
||||||
Self::Local(storage) => {
|
StorageBackend::Local(storage) => {
|
||||||
storage.put_stream(key, body, options).await
|
storage.put_stream(key, body, options).await
|
||||||
}
|
}
|
||||||
Self::S3(storage) => storage.put_stream(key, body, options).await,
|
StorageBackend::S3(storage) => {
|
||||||
|
storage.put_stream(key, body, options).await
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if result.is_ok() {
|
||||||
|
self.record_upload(0);
|
||||||
}
|
}
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self, bytes), fields(storage.key = %key, storage.size = bytes.len()))]
|
||||||
async fn put_bytes(
|
async fn put_bytes(
|
||||||
&self,
|
&self,
|
||||||
key: &str,
|
key: &str,
|
||||||
bytes: Vec<u8>,
|
bytes: Vec<u8>,
|
||||||
options: PutObjectOptions,
|
options: PutObjectOptions,
|
||||||
) -> StorageResult<StoredObject> {
|
) -> StorageResult<StoredObject> {
|
||||||
match self {
|
let size = bytes.len();
|
||||||
Self::Local(storage) => {
|
let result = match &self.inner {
|
||||||
|
StorageBackend::Local(storage) => {
|
||||||
storage.put_bytes(key, bytes, options).await
|
storage.put_bytes(key, bytes, options).await
|
||||||
}
|
}
|
||||||
Self::S3(storage) => storage.put_bytes(key, bytes, options).await,
|
StorageBackend::S3(storage) => {
|
||||||
|
storage.put_bytes(key, bytes, options).await
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if result.is_ok() {
|
||||||
|
self.record_upload(size);
|
||||||
}
|
}
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self), fields(storage.key = %key))]
|
||||||
async fn get_stream(
|
async fn get_stream(
|
||||||
&self,
|
&self,
|
||||||
key: &str,
|
key: &str,
|
||||||
) -> StorageResult<StorageObjectStream> {
|
) -> StorageResult<StorageObjectStream> {
|
||||||
match self {
|
match &self.inner {
|
||||||
Self::Local(storage) => storage.get_stream(key).await,
|
StorageBackend::Local(storage) => storage.get_stream(key).await,
|
||||||
Self::S3(storage) => storage.get_stream(key).await,
|
StorageBackend::S3(storage) => storage.get_stream(key).await,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self), fields(storage.key = %key))]
|
||||||
async fn get_bytes(&self, key: &str) -> StorageResult<StorageObject> {
|
async fn get_bytes(&self, key: &str) -> StorageResult<StorageObject> {
|
||||||
match self {
|
let result = match &self.inner {
|
||||||
Self::Local(storage) => storage.get_bytes(key).await,
|
StorageBackend::Local(storage) => storage.get_bytes(key).await,
|
||||||
Self::S3(storage) => storage.get_bytes(key).await,
|
StorageBackend::S3(storage) => storage.get_bytes(key).await,
|
||||||
|
};
|
||||||
|
if let Ok(obj) = &result {
|
||||||
|
self.record_download(obj.bytes.len());
|
||||||
}
|
}
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self), fields(storage.key = %key))]
|
||||||
async fn delete(&self, key: &str) -> StorageResult<()> {
|
async fn delete(&self, key: &str) -> StorageResult<()> {
|
||||||
match self {
|
let result = match &self.inner {
|
||||||
Self::Local(storage) => storage.delete(key).await,
|
StorageBackend::Local(storage) => storage.delete(key).await,
|
||||||
Self::S3(storage) => storage.delete(key).await,
|
StorageBackend::S3(storage) => storage.delete(key).await,
|
||||||
|
};
|
||||||
|
if result.is_ok() {
|
||||||
|
self.record_delete();
|
||||||
}
|
}
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
fn public_url(&self, key: &str) -> StorageResult<Option<String>> {
|
fn public_url(&self, key: &str) -> StorageResult<Option<String>> {
|
||||||
match self {
|
match &self.inner {
|
||||||
Self::Local(storage) => storage.public_url(key),
|
StorageBackend::Local(storage) => storage.public_url(key),
|
||||||
Self::S3(storage) => storage.public_url(key),
|
StorageBackend::S3(storage) => storage.public_url(key),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -165,17 +248,37 @@ impl ObjectStorage for AppStorage {
|
|||||||
key: &str,
|
key: &str,
|
||||||
expires_in: Duration,
|
expires_in: Duration,
|
||||||
) -> StorageResult<String> {
|
) -> StorageResult<String> {
|
||||||
match self {
|
match &self.inner {
|
||||||
Self::Local(storage) => {
|
StorageBackend::Local(storage) => {
|
||||||
storage.presigned_get_url(key, expires_in).await
|
storage.presigned_get_url(key, expires_in).await
|
||||||
}
|
}
|
||||||
Self::S3(storage) => {
|
StorageBackend::S3(storage) => {
|
||||||
storage.presigned_get_url(key, expires_in).await
|
storage.presigned_get_url(key, expires_in).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn storage_ops_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||||
|
registry
|
||||||
|
.register_counter_vec(
|
||||||
|
"storage_operations_total",
|
||||||
|
"Total storage operations",
|
||||||
|
&["operation", "backend"],
|
||||||
|
)
|
||||||
|
.expect("failed to register storage_operations_total")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_bytes_vec(registry: &track::MetricsRegistry) -> CounterVec {
|
||||||
|
registry
|
||||||
|
.register_counter_vec(
|
||||||
|
"storage_bytes_total",
|
||||||
|
"Total bytes transferred",
|
||||||
|
&["operation"],
|
||||||
|
)
|
||||||
|
.expect("failed to register storage_bytes_total")
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn collect_byte_stream(
|
pub async fn collect_byte_stream(
|
||||||
body: ByteStream,
|
body: ByteStream,
|
||||||
) -> Result<Vec<u8>, ByteStreamError> {
|
) -> Result<Vec<u8>, ByteStreamError> {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user