use std::{str::FromStr, time::Duration}; use config::AppConfig; use sqlx::{ AssertSqlSafe, ConnectOptions, FromRow, PgPool, postgres::{ PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow, }, }; use crate::{ route::{SqlRoute, route_sql}, transaction::AppTransaction, }; #[derive(Clone)] pub struct AppDatabase { db_write: PgPool, db_read: Option, } impl AppDatabase { pub async fn init(cfg: &AppConfig) -> anyhow::Result { let db_url = cfg.database_url()?; let max_connections = cfg.database_max_connections()?; let min_connections = cfg.database_min_connections()?; let idle_timeout = cfg.database_idle_timeout()?; let max_lifetime = cfg.database_max_lifetime()?; let connection_timeout = cfg.database_connection_timeout()?; let schema_search_path = cfg.database_schema_search_path()?; let read_replica = cfg.database_read_replicas()?; let write_options = build_pg_options(&db_url, &schema_search_path)?; let db_write = build_pool( write_options, max_connections, min_connections, idle_timeout, max_lifetime, connection_timeout, ) .await?; sqlx::query(AssertSqlSafe("SELECT 1".to_owned())) .execute(&db_write) .await?; let db_read = if let Some(replica_url) = read_replica { let read_options = build_pg_options(&replica_url, &schema_search_path)?; let pool = build_pool( read_options, max_connections, min_connections, idle_timeout, max_lifetime, connection_timeout, ) .await?; sqlx::query(AssertSqlSafe("SELECT 1".to_owned())) .execute(&pool) .await?; Some(pool) } else { None }; Ok(Self { db_write, db_read }) } pub fn writer(&self) -> &PgPool { &self.db_write } pub fn reader(&self) -> &PgPool { self.db_read.as_ref().unwrap_or(&self.db_write) } pub fn route_pool(&self, sql: &str) -> &PgPool { match route_sql(sql) { SqlRoute::Write => self.writer(), SqlRoute::Read => self.reader(), } } pub async fn begin(&self) -> Result, sqlx::Error> { let txn = self.db_write.begin().await?; Ok(AppTransaction { inner: txn }) } pub async fn begin_read_only( &self, ) -> Result, sqlx::Error> { let mut txn = self.reader().begin().await?; sqlx::query(AssertSqlSafe("SET TRANSACTION READ ONLY".to_owned())) .execute(&mut *txn) .await?; Ok(AppTransaction { inner: txn }) } pub async fn execute( &self, sql: &str, ) -> Result { self.execute_with_args(sql, PgArguments::default()).await } pub async fn execute_with_args( &self, sql: &str, args: PgArguments, ) -> Result { let pool = self.route_pool(sql); sqlx::query_with(AssertSqlSafe(sql.to_owned()), args) .execute(pool) .await } pub async fn fetch_one(&self, sql: &str) -> Result where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { self.fetch_one_with_args(sql, PgArguments::default()).await } pub async fn fetch_one_with_args( &self, sql: &str, args: PgArguments, ) -> Result where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { let pool = self.route_pool(sql); sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) .fetch_one(pool) .await } pub async fn fetch_optional( &self, sql: &str, ) -> Result, sqlx::Error> where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { self.fetch_optional_with_args(sql, PgArguments::default()) .await } pub async fn fetch_optional_with_args( &self, sql: &str, args: PgArguments, ) -> Result, sqlx::Error> where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { let pool = self.route_pool(sql); sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) .fetch_optional(pool) .await } pub async fn fetch_all(&self, sql: &str) -> Result, sqlx::Error> where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { self.fetch_all_with_args(sql, PgArguments::default()).await } pub async fn fetch_all_with_args( &self, sql: &str, args: PgArguments, ) -> Result, sqlx::Error> where for<'r> T: FromRow<'r, PgRow> + Send + Unpin, { let pool = self.route_pool(sql); sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args) .fetch_all(pool) .await } } fn build_pg_options( db_url: &str, schema_search_path: &str, ) -> anyhow::Result { let options = PgConnectOptions::from_str(db_url)? .options([("search_path", schema_search_path)]) .disable_statement_logging(); Ok(options) } async fn build_pool( options: PgConnectOptions, max_connections: u32, min_connections: u32, idle_timeout_secs: u64, max_lifetime_secs: u64, connection_timeout_secs: u64, ) -> Result { let mut pool_options = PgPoolOptions::new() .max_connections(max_connections) .min_connections(min_connections) .acquire_timeout(Duration::from_secs(connection_timeout_secs.max(1))); if idle_timeout_secs > 0 { pool_options = pool_options.idle_timeout(Duration::from_secs(idle_timeout_secs)); } if max_lifetime_secs > 0 { pool_options = pool_options.max_lifetime(Duration::from_secs(max_lifetime_secs)); } pool_options.connect_with(options).await }