gitdataai/lib/db/database.rs
2026-05-30 01:38:40 +08:00

233 lines
6.2 KiB
Rust

use std::{str::FromStr, time::Duration};
use config::AppConfig;
use sqlx::{
AssertSqlSafe, ConnectOptions, FromRow, PgPool,
postgres::{
PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow,
},
};
use crate::{
route::{SqlRoute, route_sql},
transaction::AppTransaction,
};
/// The application's PostgreSQL connection pools.
///
/// Writes and transactions started with `begin` always go to `db_write`.
/// Reads go to `db_read` when a read replica was configured, and otherwise
/// fall back to `db_write`.
#[derive(Clone)]
pub struct AppDatabase {
    /// Pool built from the configured `database_url`; used for all writes.
    db_write: PgPool,
    /// Pool for the configured read replica, or `None` when there is none.
    db_read: Option<PgPool>,
}
impl AppDatabase {
pub async fn init(cfg: &AppConfig) -> anyhow::Result<Self> {
let db_url = cfg.database_url()?;
let max_connections = cfg.database_max_connections()?;
let min_connections = cfg.database_min_connections()?;
let idle_timeout = cfg.database_idle_timeout()?;
let max_lifetime = cfg.database_max_lifetime()?;
let connection_timeout = cfg.database_connection_timeout()?;
let schema_search_path = cfg.database_schema_search_path()?;
let read_replica = cfg.database_read_replicas()?;
let write_options = build_pg_options(&db_url, &schema_search_path)?;
let db_write = build_pool(
write_options,
max_connections,
min_connections,
idle_timeout,
max_lifetime,
connection_timeout,
)
.await?;
sqlx::query(AssertSqlSafe("SELECT 1".to_owned()))
.execute(&db_write)
.await?;
let db_read = if let Some(replica_url) = read_replica {
let read_options =
build_pg_options(&replica_url, &schema_search_path)?;
let pool = build_pool(
read_options,
max_connections,
min_connections,
idle_timeout,
max_lifetime,
connection_timeout,
)
.await?;
sqlx::query(AssertSqlSafe("SELECT 1".to_owned()))
.execute(&pool)
.await?;
Some(pool)
} else {
None
};
Ok(Self { db_write, db_read })
}
pub fn writer(&self) -> &PgPool {
&self.db_write
}
pub fn reader(&self) -> &PgPool {
self.db_read.as_ref().unwrap_or(&self.db_write)
}
pub fn route_pool(&self, sql: &str) -> &PgPool {
match route_sql(sql) {
SqlRoute::Write => self.writer(),
SqlRoute::Read => self.reader(),
}
}
pub async fn begin(&self) -> Result<AppTransaction<'_>, sqlx::Error> {
let txn = self.db_write.begin().await?;
Ok(AppTransaction { inner: txn })
}
pub async fn begin_read_only(
&self,
) -> Result<AppTransaction<'_>, sqlx::Error> {
let mut txn = self.reader().begin().await?;
sqlx::query(AssertSqlSafe("SET TRANSACTION READ ONLY".to_owned()))
.execute(&mut *txn)
.await?;
Ok(AppTransaction { inner: txn })
}
pub async fn execute(
&self,
sql: &str,
) -> Result<PgQueryResult, sqlx::Error> {
self.execute_with_args(sql, PgArguments::default()).await
}
pub async fn execute_with_args(
&self,
sql: &str,
args: PgArguments,
) -> Result<PgQueryResult, sqlx::Error> {
let pool = self.route_pool(sql);
sqlx::query_with(AssertSqlSafe(sql.to_owned()), args)
.execute(pool)
.await
}
pub async fn fetch_one<T>(&self, sql: &str) -> Result<T, sqlx::Error>
where
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
{
self.fetch_one_with_args(sql, PgArguments::default()).await
}
pub async fn fetch_one_with_args<T>(
&self,
sql: &str,
args: PgArguments,
) -> Result<T, sqlx::Error>
where
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
{
let pool = self.route_pool(sql);
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
.fetch_one(pool)
.await
}
pub async fn fetch_optional<T>(
&self,
sql: &str,
) -> Result<Option<T>, sqlx::Error>
where
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
{
self.fetch_optional_with_args(sql, PgArguments::default())
.await
}
pub async fn fetch_optional_with_args<T>(
&self,
sql: &str,
args: PgArguments,
) -> Result<Option<T>, sqlx::Error>
where
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
{
let pool = self.route_pool(sql);
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
.fetch_optional(pool)
.await
}
pub async fn fetch_all<T>(&self, sql: &str) -> Result<Vec<T>, sqlx::Error>
where
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
{
self.fetch_all_with_args(sql, PgArguments::default()).await
}
pub async fn fetch_all_with_args<T>(
&self,
sql: &str,
args: PgArguments,
) -> Result<Vec<T>, sqlx::Error>
where
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
{
let pool = self.route_pool(sql);
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
.fetch_all(pool)
.await
}
}
/// Parses `db_url` into PostgreSQL connection options.
///
/// `schema_search_path` is attached as the `search_path` startup option via
/// `PgConnectOptions::options`. It is forwarded as given, even when empty.
/// Statement logging is disabled on the returned options.
///
/// # Errors
///
/// Returns an error when `db_url` cannot be parsed as `PgConnectOptions`.
fn build_pg_options(
    db_url: &str,
    schema_search_path: &str,
) -> anyhow::Result<PgConnectOptions> {
    let parsed = PgConnectOptions::from_str(db_url)?;
    let startup_settings = [("search_path", schema_search_path)];
    Ok(parsed
        .options(startup_settings)
        .disable_statement_logging())
}
/// Opens a `PgPool` for `options` with the given sizing and timeouts.
///
/// All timeouts are in seconds:
/// - `connection_timeout_secs` becomes the pool's acquire timeout. It is
///   clamped to at least one second.
/// - `idle_timeout_secs` and `max_lifetime_secs` are applied only when
///   non-zero. A zero leaves whatever `PgPoolOptions::new()` configured for
///   that setting.
///
/// # Errors
///
/// Propagates the error returned by `PgPoolOptions::connect_with`.
async fn build_pool(
    options: PgConnectOptions,
    max_connections: u32,
    min_connections: u32,
    idle_timeout_secs: u64,
    max_lifetime_secs: u64,
    connection_timeout_secs: u64,
) -> Result<PgPool, sqlx::Error> {
    let acquire_timeout = Duration::from_secs(connection_timeout_secs.max(1));
    let sized = PgPoolOptions::new()
        .max_connections(max_connections)
        .min_connections(min_connections)
        .acquire_timeout(acquire_timeout);
    let sized = match idle_timeout_secs {
        0 => sized,
        secs => sized.idle_timeout(Duration::from_secs(secs)),
    };
    let sized = match max_lifetime_secs {
        0 => sized,
        secs => sized.max_lifetime(Duration::from_secs(secs)),
    };
    sized.connect_with(options).await
}