233 lines
6.2 KiB
Rust
233 lines
6.2 KiB
Rust
use std::{str::FromStr, time::Duration};
|
|
|
|
use config::AppConfig;
|
|
use sqlx::{
|
|
AssertSqlSafe, ConnectOptions, FromRow, PgPool,
|
|
postgres::{
|
|
PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow,
|
|
},
|
|
};
|
|
|
|
use crate::{
|
|
route::{SqlRoute, route_sql},
|
|
transaction::AppTransaction,
|
|
};
|
|
|
|
#[derive(Clone)]
|
|
pub struct AppDatabase {
|
|
db_write: PgPool,
|
|
db_read: Option<PgPool>,
|
|
}
|
|
|
|
impl AppDatabase {
|
|
pub async fn init(cfg: &AppConfig) -> anyhow::Result<Self> {
|
|
let db_url = cfg.database_url()?;
|
|
let max_connections = cfg.database_max_connections()?;
|
|
let min_connections = cfg.database_min_connections()?;
|
|
let idle_timeout = cfg.database_idle_timeout()?;
|
|
let max_lifetime = cfg.database_max_lifetime()?;
|
|
let connection_timeout = cfg.database_connection_timeout()?;
|
|
let schema_search_path = cfg.database_schema_search_path()?;
|
|
let read_replica = cfg.database_read_replicas()?;
|
|
|
|
let write_options = build_pg_options(&db_url, &schema_search_path)?;
|
|
|
|
let db_write = build_pool(
|
|
write_options,
|
|
max_connections,
|
|
min_connections,
|
|
idle_timeout,
|
|
max_lifetime,
|
|
connection_timeout,
|
|
)
|
|
.await?;
|
|
|
|
sqlx::query(AssertSqlSafe("SELECT 1".to_owned()))
|
|
.execute(&db_write)
|
|
.await?;
|
|
|
|
let db_read = if let Some(replica_url) = read_replica {
|
|
let read_options =
|
|
build_pg_options(&replica_url, &schema_search_path)?;
|
|
|
|
let pool = build_pool(
|
|
read_options,
|
|
max_connections,
|
|
min_connections,
|
|
idle_timeout,
|
|
max_lifetime,
|
|
connection_timeout,
|
|
)
|
|
.await?;
|
|
|
|
sqlx::query(AssertSqlSafe("SELECT 1".to_owned()))
|
|
.execute(&pool)
|
|
.await?;
|
|
|
|
Some(pool)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
Ok(Self { db_write, db_read })
|
|
}
|
|
|
|
pub fn writer(&self) -> &PgPool {
|
|
&self.db_write
|
|
}
|
|
|
|
pub fn reader(&self) -> &PgPool {
|
|
self.db_read.as_ref().unwrap_or(&self.db_write)
|
|
}
|
|
|
|
pub fn route_pool(&self, sql: &str) -> &PgPool {
|
|
match route_sql(sql) {
|
|
SqlRoute::Write => self.writer(),
|
|
SqlRoute::Read => self.reader(),
|
|
}
|
|
}
|
|
|
|
pub async fn begin(&self) -> Result<AppTransaction<'_>, sqlx::Error> {
|
|
let txn = self.db_write.begin().await?;
|
|
Ok(AppTransaction { inner: txn })
|
|
}
|
|
|
|
pub async fn begin_read_only(
|
|
&self,
|
|
) -> Result<AppTransaction<'_>, sqlx::Error> {
|
|
let mut txn = self.reader().begin().await?;
|
|
|
|
sqlx::query(AssertSqlSafe("SET TRANSACTION READ ONLY".to_owned()))
|
|
.execute(&mut *txn)
|
|
.await?;
|
|
|
|
Ok(AppTransaction { inner: txn })
|
|
}
|
|
|
|
pub async fn execute(
|
|
&self,
|
|
sql: &str,
|
|
) -> Result<PgQueryResult, sqlx::Error> {
|
|
self.execute_with_args(sql, PgArguments::default()).await
|
|
}
|
|
|
|
pub async fn execute_with_args(
|
|
&self,
|
|
sql: &str,
|
|
args: PgArguments,
|
|
) -> Result<PgQueryResult, sqlx::Error> {
|
|
let pool = self.route_pool(sql);
|
|
|
|
sqlx::query_with(AssertSqlSafe(sql.to_owned()), args)
|
|
.execute(pool)
|
|
.await
|
|
}
|
|
|
|
pub async fn fetch_one<T>(&self, sql: &str) -> Result<T, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
self.fetch_one_with_args(sql, PgArguments::default()).await
|
|
}
|
|
|
|
pub async fn fetch_one_with_args<T>(
|
|
&self,
|
|
sql: &str,
|
|
args: PgArguments,
|
|
) -> Result<T, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
let pool = self.route_pool(sql);
|
|
|
|
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
|
.fetch_one(pool)
|
|
.await
|
|
}
|
|
|
|
pub async fn fetch_optional<T>(
|
|
&self,
|
|
sql: &str,
|
|
) -> Result<Option<T>, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
self.fetch_optional_with_args(sql, PgArguments::default())
|
|
.await
|
|
}
|
|
|
|
pub async fn fetch_optional_with_args<T>(
|
|
&self,
|
|
sql: &str,
|
|
args: PgArguments,
|
|
) -> Result<Option<T>, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
let pool = self.route_pool(sql);
|
|
|
|
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
|
.fetch_optional(pool)
|
|
.await
|
|
}
|
|
|
|
pub async fn fetch_all<T>(&self, sql: &str) -> Result<Vec<T>, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
self.fetch_all_with_args(sql, PgArguments::default()).await
|
|
}
|
|
|
|
pub async fn fetch_all_with_args<T>(
|
|
&self,
|
|
sql: &str,
|
|
args: PgArguments,
|
|
) -> Result<Vec<T>, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
let pool = self.route_pool(sql);
|
|
|
|
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
|
.fetch_all(pool)
|
|
.await
|
|
}
|
|
}
|
|
|
|
fn build_pg_options(
|
|
db_url: &str,
|
|
schema_search_path: &str,
|
|
) -> anyhow::Result<PgConnectOptions> {
|
|
let options = PgConnectOptions::from_str(db_url)?
|
|
.options([("search_path", schema_search_path)])
|
|
.disable_statement_logging();
|
|
|
|
Ok(options)
|
|
}
|
|
|
|
async fn build_pool(
|
|
options: PgConnectOptions,
|
|
max_connections: u32,
|
|
min_connections: u32,
|
|
idle_timeout_secs: u64,
|
|
max_lifetime_secs: u64,
|
|
connection_timeout_secs: u64,
|
|
) -> Result<PgPool, sqlx::Error> {
|
|
let mut pool_options = PgPoolOptions::new()
|
|
.max_connections(max_connections)
|
|
.min_connections(min_connections)
|
|
.acquire_timeout(Duration::from_secs(connection_timeout_secs.max(1)));
|
|
|
|
if idle_timeout_secs > 0 {
|
|
pool_options =
|
|
pool_options.idle_timeout(Duration::from_secs(idle_timeout_secs));
|
|
}
|
|
|
|
if max_lifetime_secs > 0 {
|
|
pool_options =
|
|
pool_options.max_lifetime(Duration::from_secs(max_lifetime_secs));
|
|
}
|
|
|
|
pool_options.connect_with(options).await
|
|
}
|