use std::{str::FromStr, time::Duration}; use config::AppConfig; use sqlx::{ AssertSqlSafe, ConnectOptions, FromRow, PgPool, postgres::{ PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow, }, }; use track::{CounterVec, HistogramVec}; use crate::{ route::{SqlRoute, route_sql}, transaction::AppTransaction, }; #[derive(Clone)] pub struct AppDatabase { db_write: PgPool, db_read: Option, metrics: Option, } impl AppDatabase { #[tracing::instrument(skip(cfg))] pub async fn init(cfg: &AppConfig) -> anyhow::Result { let db_url = cfg.database_url()?; let max_connections = cfg.database_max_connections()?; let min_connections = cfg.database_min_connections()?; let idle_timeout = cfg.database_idle_timeout()?; let max_lifetime = cfg.database_max_lifetime()?; let connection_timeout = cfg.database_connection_timeout()?; let schema_search_path = cfg.database_schema_search_path()?; let read_replica = cfg.database_read_replicas()?; let write_options = build_pg_options(&db_url, &schema_search_path)?; let db_write = build_pool( write_options, max_connections, min_connections, idle_timeout, max_lifetime, connection_timeout, ) .await?; sqlx::query(AssertSqlSafe("SELECT 1".to_owned())) .execute(&db_write) .await?; let db_read = if let Some(replica_url) = read_replica { let read_options = build_pg_options(&replica_url, &schema_search_path)?; let pool = build_pool( read_options, max_connections, min_connections, idle_timeout, max_lifetime, connection_timeout, ) .await?; sqlx::query(AssertSqlSafe("SELECT 1".to_owned())) .execute(&pool) .await?; Some(pool) } else { None }; Ok(Self { db_write, db_read, metrics: None, }) } pub fn set_metrics(&mut self, registry: track::MetricsRegistry) { self.metrics = Some(registry); } pub fn writer(&self) -> &PgPool { &self.db_write } pub fn reader(&self) -> &PgPool { self.db_read.as_ref().unwrap_or(&self.db_write) } pub fn route_pool(&self, sql: &str) -> &PgPool { match route_sql(sql) { SqlRoute::Write => self.writer(), SqlRoute::Read => self.reader(), } } #[tracing::instrument(skip(self), fields(sql.route = "write"))] pub async fn begin(&self) -> Result, sqlx::Error> { let txn = self.db_write.begin().await?; tracing::debug!("db transaction started"); Ok(AppTransaction { inner: txn }) } #[tracing::instrument(skip(self), fields(sql.route = "read"))] pub async fn begin_read_only( &self, ) -> Result, sqlx::Error> { let mut txn = self.reader().begin().await?; sqlx::query(AssertSqlSafe("SET TRANSACTION READ ONLY".to_owned())) .execute(&mut *txn) .await?; tracing::debug!("db read-only transaction started"); Ok(AppTransaction { inner: txn }) } #[tracing::instrument(skip(self, sql), fields(sql.kind = "execute"))] pub async fn execute( &self, sql: &str, ) -> Result { self.execute_with_args(sql, PgArguments::default()).await } #[tracing::instrument(skip(self, sql, args), fields(sql.kind = "execute"))] pub async fn execute_with_args( &self, sql: &str, args: PgArguments, ) -> Result { let pool = self.route_pool(sql); let start = std::time::Instant::now(); let result = sqlx::query_with(AssertSqlSafe(sql.to_owned()), args) .execute(pool) .await; let kind = if sql.trim_start().to_uppercase().starts_with("INSERT") { "insert" } else if sql.trim_start().to_uppercase().starts_with("UPDATE") { "update" } else if sql.trim_start().to_uppercase().starts_with("DELETE") { "delete" } else { "execute" }; self.record_query( kind, self.route_label(sql), start.elapsed(), result.is_ok(), ); result } #[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_one"))] pub async fn fetch_one(&self, sql: &str) -> Result where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { self.fetch_one_with_args(sql, PgArguments::default()).await } #[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_one"))] pub async fn fetch_one_with_args( &self, sql: &str, args: PgArguments, ) -> Result where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { let pool = self.route_pool(sql); let start = std::time::Instant::now(); let result = sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) .fetch_one(pool) .await; self.record_query( "select", self.route_label(sql), start.elapsed(), result.is_ok(), ); result } #[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_optional"))] pub async fn fetch_optional( &self, sql: &str, ) -> Result, sqlx::Error> where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { self.fetch_optional_with_args(sql, PgArguments::default()) .await } #[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_optional"))] pub async fn fetch_optional_with_args( &self, sql: &str, args: PgArguments, ) -> Result, sqlx::Error> where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { let pool = self.route_pool(sql); let start = std::time::Instant::now(); let result = sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) .fetch_optional(pool) .await; self.record_query( "select", self.route_label(sql), start.elapsed(), result.is_ok(), ); result } #[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_all"))] pub async fn fetch_all(&self, sql: &str) -> Result, sqlx::Error> where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { self.fetch_all_with_args(sql, PgArguments::default()).await } #[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_all"))] pub async fn fetch_all_with_args( &self, sql: &str, args: PgArguments, ) -> Result, sqlx::Error> where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { let pool = self.route_pool(sql); let start = std::time::Instant::now(); let result = sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) .fetch_all(pool) .await; self.record_query( "select", self.route_label(sql), start.elapsed(), result.is_ok(), ); result } fn route_label(&self, sql: &str) -> &str { match route_sql(sql) { SqlRoute::Write => "write", SqlRoute::Read => "read", } } fn record_query( &self, kind: &str, route: &str, duration: Duration, success: bool, ) { if let Some(reg) = &self.metrics { let status = if success { "success" } else { "error" }; db_queries_vec(reg) .with_label_values(&[kind, route, status]) .inc(); db_query_duration_vec(reg) .with_label_values(&[kind, route]) .observe(duration.as_secs_f64()); } } } fn build_pg_options( db_url: &str, schema_search_path: &str, ) -> anyhow::Result { let options = PgConnectOptions::from_str(db_url)? .options([("search_path", schema_search_path)]) .disable_statement_logging(); Ok(options) } async fn build_pool( options: PgConnectOptions, max_connections: u32, min_connections: u32, idle_timeout_secs: u64, max_lifetime_secs: u64, connection_timeout_secs: u64, ) -> Result { let mut pool_options = PgPoolOptions::new() .max_connections(max_connections) .min_connections(min_connections) .acquire_timeout(Duration::from_secs(connection_timeout_secs.max(1))); if idle_timeout_secs > 0 { pool_options = pool_options.idle_timeout(Duration::from_secs(idle_timeout_secs)); } if max_lifetime_secs > 0 { pool_options = pool_options.max_lifetime(Duration::from_secs(max_lifetime_secs)); } pool_options.connect_with(options).await } fn db_queries_vec(registry: &track::MetricsRegistry) -> CounterVec { registry .register_counter_vec( "db_queries_total", "Total database queries", &["kind", "route", "status"], ) .expect("failed to register db_queries_total") } fn db_query_duration_vec(registry: &track::MetricsRegistry) -> HistogramVec { registry .register_histogram_vec( "db_query_duration_seconds", "DB query duration in seconds", &["kind", "route"], vec![ 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, ], ) .expect("failed to register db_query_duration_seconds") }