352 lines
10 KiB
Rust
352 lines
10 KiB
Rust
use std::{str::FromStr, time::Duration};
|
|
|
|
use config::AppConfig;
|
|
use sqlx::{
|
|
AssertSqlSafe, ConnectOptions, FromRow, PgPool,
|
|
postgres::{
|
|
PgArguments, PgConnectOptions, PgPoolOptions, PgQueryResult, PgRow,
|
|
},
|
|
};
|
|
use track::{CounterVec, HistogramVec};
|
|
|
|
use crate::{
|
|
route::{SqlRoute, route_sql},
|
|
transaction::AppTransaction,
|
|
};
|
|
|
|
#[derive(Clone)]
|
|
pub struct AppDatabase {
|
|
db_write: PgPool,
|
|
db_read: Option<PgPool>,
|
|
metrics: Option<track::MetricsRegistry>,
|
|
}
|
|
|
|
impl AppDatabase {
|
|
#[tracing::instrument(skip(cfg))]
|
|
pub async fn init(cfg: &AppConfig) -> anyhow::Result<Self> {
|
|
let db_url = cfg.database_url()?;
|
|
let max_connections = cfg.database_max_connections()?;
|
|
let min_connections = cfg.database_min_connections()?;
|
|
let idle_timeout = cfg.database_idle_timeout()?;
|
|
let max_lifetime = cfg.database_max_lifetime()?;
|
|
let connection_timeout = cfg.database_connection_timeout()?;
|
|
let schema_search_path = cfg.database_schema_search_path()?;
|
|
let read_replica = cfg.database_read_replicas()?;
|
|
|
|
let write_options = build_pg_options(&db_url, &schema_search_path)?;
|
|
|
|
let db_write = build_pool(
|
|
write_options,
|
|
max_connections,
|
|
min_connections,
|
|
idle_timeout,
|
|
max_lifetime,
|
|
connection_timeout,
|
|
)
|
|
.await?;
|
|
|
|
sqlx::query(AssertSqlSafe("SELECT 1".to_owned()))
|
|
.execute(&db_write)
|
|
.await?;
|
|
|
|
let db_read = if let Some(replica_url) = read_replica {
|
|
let read_options =
|
|
build_pg_options(&replica_url, &schema_search_path)?;
|
|
|
|
let pool = build_pool(
|
|
read_options,
|
|
max_connections,
|
|
min_connections,
|
|
idle_timeout,
|
|
max_lifetime,
|
|
connection_timeout,
|
|
)
|
|
.await?;
|
|
|
|
sqlx::query(AssertSqlSafe("SELECT 1".to_owned()))
|
|
.execute(&pool)
|
|
.await?;
|
|
|
|
Some(pool)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
Ok(Self {
|
|
db_write,
|
|
db_read,
|
|
metrics: None,
|
|
})
|
|
}
|
|
|
|
pub fn set_metrics(&mut self, registry: track::MetricsRegistry) {
|
|
self.metrics = Some(registry);
|
|
}
|
|
|
|
pub fn writer(&self) -> &PgPool {
|
|
&self.db_write
|
|
}
|
|
|
|
pub fn reader(&self) -> &PgPool {
|
|
self.db_read.as_ref().unwrap_or(&self.db_write)
|
|
}
|
|
|
|
pub fn route_pool(&self, sql: &str) -> &PgPool {
|
|
match route_sql(sql) {
|
|
SqlRoute::Write => self.writer(),
|
|
SqlRoute::Read => self.reader(),
|
|
}
|
|
}
|
|
|
|
#[tracing::instrument(skip(self), fields(sql.route = "write"))]
|
|
pub async fn begin(&self) -> Result<AppTransaction<'_>, sqlx::Error> {
|
|
let txn = self.db_write.begin().await?;
|
|
tracing::debug!("db transaction started");
|
|
Ok(AppTransaction { inner: txn })
|
|
}
|
|
|
|
#[tracing::instrument(skip(self), fields(sql.route = "read"))]
|
|
pub async fn begin_read_only(
|
|
&self,
|
|
) -> Result<AppTransaction<'_>, sqlx::Error> {
|
|
let mut txn = self.reader().begin().await?;
|
|
|
|
sqlx::query(AssertSqlSafe("SET TRANSACTION READ ONLY".to_owned()))
|
|
.execute(&mut *txn)
|
|
.await?;
|
|
|
|
tracing::debug!("db read-only transaction started");
|
|
Ok(AppTransaction { inner: txn })
|
|
}
|
|
|
|
#[tracing::instrument(skip(self, sql), fields(sql.kind = "execute"))]
|
|
pub async fn execute(
|
|
&self,
|
|
sql: &str,
|
|
) -> Result<PgQueryResult, sqlx::Error> {
|
|
self.execute_with_args(sql, PgArguments::default()).await
|
|
}
|
|
|
|
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "execute"))]
|
|
pub async fn execute_with_args(
|
|
&self,
|
|
sql: &str,
|
|
args: PgArguments,
|
|
) -> Result<PgQueryResult, sqlx::Error> {
|
|
let pool = self.route_pool(sql);
|
|
let start = std::time::Instant::now();
|
|
|
|
let result = sqlx::query_with(AssertSqlSafe(sql.to_owned()), args)
|
|
.execute(pool)
|
|
.await;
|
|
|
|
let kind = if sql.trim_start().to_uppercase().starts_with("INSERT") {
|
|
"insert"
|
|
} else if sql.trim_start().to_uppercase().starts_with("UPDATE") {
|
|
"update"
|
|
} else if sql.trim_start().to_uppercase().starts_with("DELETE") {
|
|
"delete"
|
|
} else {
|
|
"execute"
|
|
};
|
|
self.record_query(
|
|
kind,
|
|
self.route_label(sql),
|
|
start.elapsed(),
|
|
result.is_ok(),
|
|
);
|
|
result
|
|
}
|
|
|
|
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_one"))]
|
|
pub async fn fetch_one<T>(&self, sql: &str) -> Result<T, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
self.fetch_one_with_args(sql, PgArguments::default()).await
|
|
}
|
|
|
|
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_one"))]
|
|
pub async fn fetch_one_with_args<T>(
|
|
&self,
|
|
sql: &str,
|
|
args: PgArguments,
|
|
) -> Result<T, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
let pool = self.route_pool(sql);
|
|
let start = std::time::Instant::now();
|
|
|
|
let result =
|
|
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
|
.fetch_one(pool)
|
|
.await;
|
|
|
|
self.record_query(
|
|
"select",
|
|
self.route_label(sql),
|
|
start.elapsed(),
|
|
result.is_ok(),
|
|
);
|
|
result
|
|
}
|
|
|
|
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_optional"))]
|
|
pub async fn fetch_optional<T>(
|
|
&self,
|
|
sql: &str,
|
|
) -> Result<Option<T>, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
self.fetch_optional_with_args(sql, PgArguments::default())
|
|
.await
|
|
}
|
|
|
|
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_optional"))]
|
|
pub async fn fetch_optional_with_args<T>(
|
|
&self,
|
|
sql: &str,
|
|
args: PgArguments,
|
|
) -> Result<Option<T>, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
let pool = self.route_pool(sql);
|
|
let start = std::time::Instant::now();
|
|
|
|
let result =
|
|
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
|
.fetch_optional(pool)
|
|
.await;
|
|
|
|
self.record_query(
|
|
"select",
|
|
self.route_label(sql),
|
|
start.elapsed(),
|
|
result.is_ok(),
|
|
);
|
|
result
|
|
}
|
|
|
|
#[tracing::instrument(skip(self, sql), fields(sql.kind = "fetch_all"))]
|
|
pub async fn fetch_all<T>(&self, sql: &str) -> Result<Vec<T>, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
self.fetch_all_with_args(sql, PgArguments::default()).await
|
|
}
|
|
|
|
#[tracing::instrument(skip(self, sql, args), fields(sql.kind = "fetch_all"))]
|
|
pub async fn fetch_all_with_args<T>(
|
|
&self,
|
|
sql: &str,
|
|
args: PgArguments,
|
|
) -> Result<Vec<T>, sqlx::Error>
|
|
where
|
|
for<'r> T: FromRow<'r, PgRow> + Send + Unpin,
|
|
{
|
|
let pool = self.route_pool(sql);
|
|
let start = std::time::Instant::now();
|
|
|
|
let result =
|
|
sqlx::query_as_with::<_, T, _>(AssertSqlSafe(sql.to_owned()), args)
|
|
.fetch_all(pool)
|
|
.await;
|
|
|
|
self.record_query(
|
|
"select",
|
|
self.route_label(sql),
|
|
start.elapsed(),
|
|
result.is_ok(),
|
|
);
|
|
result
|
|
}
|
|
|
|
fn route_label(&self, sql: &str) -> &str {
|
|
match route_sql(sql) {
|
|
SqlRoute::Write => "write",
|
|
SqlRoute::Read => "read",
|
|
}
|
|
}
|
|
|
|
fn record_query(
|
|
&self,
|
|
kind: &str,
|
|
route: &str,
|
|
duration: Duration,
|
|
success: bool,
|
|
) {
|
|
if let Some(reg) = &self.metrics {
|
|
let status = if success { "success" } else { "error" };
|
|
db_queries_vec(reg)
|
|
.with_label_values(&[kind, route, status])
|
|
.inc();
|
|
db_query_duration_vec(reg)
|
|
.with_label_values(&[kind, route])
|
|
.observe(duration.as_secs_f64());
|
|
}
|
|
}
|
|
}
|
|
|
|
fn build_pg_options(
|
|
db_url: &str,
|
|
schema_search_path: &str,
|
|
) -> anyhow::Result<PgConnectOptions> {
|
|
let options = PgConnectOptions::from_str(db_url)?
|
|
.options([("search_path", schema_search_path)])
|
|
.disable_statement_logging();
|
|
|
|
Ok(options)
|
|
}
|
|
|
|
async fn build_pool(
|
|
options: PgConnectOptions,
|
|
max_connections: u32,
|
|
min_connections: u32,
|
|
idle_timeout_secs: u64,
|
|
max_lifetime_secs: u64,
|
|
connection_timeout_secs: u64,
|
|
) -> Result<PgPool, sqlx::Error> {
|
|
let mut pool_options = PgPoolOptions::new()
|
|
.max_connections(max_connections)
|
|
.min_connections(min_connections)
|
|
.acquire_timeout(Duration::from_secs(connection_timeout_secs.max(1)));
|
|
|
|
if idle_timeout_secs > 0 {
|
|
pool_options =
|
|
pool_options.idle_timeout(Duration::from_secs(idle_timeout_secs));
|
|
}
|
|
|
|
if max_lifetime_secs > 0 {
|
|
pool_options =
|
|
pool_options.max_lifetime(Duration::from_secs(max_lifetime_secs));
|
|
}
|
|
|
|
pool_options.connect_with(options).await
|
|
}
|
|
|
|
/// Returns the `db_queries_total` counter vector from `registry`. Its labels
/// are `kind`, `route` and `status`.
///
/// # Panics
///
/// Panics if `register_counter_vec` returns an error.
///
/// NOTE(review): this runs for every recorded query, so it relies on
/// `register_counter_vec` being safe to call repeatedly with the same name
/// (e.g. returning the already-registered vector). Confirm against `track`.
fn db_queries_vec(registry: &track::MetricsRegistry) -> CounterVec {
    const NAME: &str = "db_queries_total";
    const HELP: &str = "Total database queries";
    const LABELS: [&str; 3] = ["kind", "route", "status"];

    let registered = registry.register_counter_vec(NAME, HELP, &LABELS);
    registered.expect("failed to register db_queries_total")
}
|
|
|
|
/// Returns the `db_query_duration_seconds` histogram vector from `registry`.
/// Its labels are `kind` and `route`; buckets run from 1 ms to 5 s.
///
/// # Panics
///
/// Panics if `register_histogram_vec` returns an error.
///
/// NOTE(review): like `db_queries_vec`, this runs for every recorded query and
/// assumes repeated registration of the same name is tolerated. Confirm
/// against `track`.
fn db_query_duration_vec(registry: &track::MetricsRegistry) -> HistogramVec {
    // Upper bucket bounds, in seconds.
    let buckets = vec![
        0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0,
    ];

    let registered = registry.register_histogram_vec(
        "db_query_duration_seconds",
        "DB query duration in seconds",
        &["kind", "route"],
        buckets,
    );
    registered.expect("failed to register db_query_duration_seconds")
}
|