195 lines
5.7 KiB
Rust
195 lines
5.7 KiB
Rust
use config::AppConfig;
|
|
use rand::random_range;
|
|
use sea_orm::prelude::async_trait::async_trait;
|
|
use sea_orm::{
|
|
ConnectionTrait, Database, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
|
|
ExecResult, QueryResult, Statement, TransactionTrait,
|
|
};
|
|
use std::time::Duration;
|
|
|
|
#[derive(Clone)]
|
|
pub struct AppDatabase {
|
|
db_write: DatabaseConnection,
|
|
db_read: Vec<DatabaseConnection>,
|
|
}
|
|
|
|
impl AppDatabase {
|
|
pub async fn init(cfg: &AppConfig) -> anyhow::Result<Self> {
|
|
let db_url = cfg.database_url()?;
|
|
let max_connections = cfg.database_max_connections()?;
|
|
let min_connections = cfg.database_min_connections()?;
|
|
let idle_timeout = cfg.database_idle_timeout()?;
|
|
let max_lifetime = cfg.database_max_lifetime()?;
|
|
let connection_timeout = cfg.database_connection_timeout()?;
|
|
let schema_search_path = cfg.database_schema_search_path()?;
|
|
let read_replicas = cfg.database_read_replicas()?;
|
|
|
|
let conn_cfg = sea_orm::ConnectOptions::new(db_url)
|
|
.max_connections(max_connections)
|
|
.min_connections(min_connections)
|
|
.idle_timeout(Duration::from_secs(idle_timeout))
|
|
.max_lifetime(Duration::from_secs(max_lifetime))
|
|
.connect_timeout(Duration::from_secs(connection_timeout))
|
|
.set_schema_search_path(schema_search_path)
|
|
.sqlx_logging(false)
|
|
.to_owned();
|
|
|
|
let db_write = Database::connect(conn_cfg).await?;
|
|
|
|
let mut db_read = vec![];
|
|
for replica in read_replicas {
|
|
let conn_cfg = sea_orm::ConnectOptions::new(replica.clone())
|
|
.max_connections(max_connections)
|
|
.min_connections(min_connections)
|
|
.idle_timeout(Duration::from_secs(idle_timeout))
|
|
.max_lifetime(Duration::from_secs(max_lifetime))
|
|
.connect_timeout(Duration::from_secs(connection_timeout))
|
|
.to_owned();
|
|
|
|
let conn = Database::connect(conn_cfg).await?;
|
|
db_read.push(conn);
|
|
}
|
|
|
|
Ok(Self { db_write, db_read })
|
|
}
|
|
|
|
pub fn writer(&self) -> &DatabaseConnection {
|
|
&self.db_write
|
|
}
|
|
|
|
pub fn reader(&self) -> &DatabaseConnection {
|
|
if self.db_read.is_empty() {
|
|
return &self.db_write;
|
|
}
|
|
|
|
&self.db_read[random_range(0..self.db_read.len())]
|
|
}
|
|
|
|
pub async fn begin(&self) -> Result<AppTransaction, DbErr> {
|
|
let txn = self.db_write.begin().await?;
|
|
Ok(AppTransaction { inner: txn })
|
|
}
|
|
}
|
|
|
|
pub struct AppTransaction {
|
|
inner: DatabaseTransaction,
|
|
}
|
|
|
|
impl AppTransaction {
|
|
pub async fn commit(self) -> Result<(), DbErr> {
|
|
self.inner.commit().await
|
|
}
|
|
pub async fn rollback(self) -> Result<(), DbErr> {
|
|
self.inner.rollback().await
|
|
}
|
|
}
|
|
#[async_trait]
|
|
impl ConnectionTrait for AppTransaction {
|
|
fn get_database_backend(&self) -> DbBackend {
|
|
self.inner.get_database_backend()
|
|
}
|
|
|
|
async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
|
|
self.inner.execute_raw(stmt).await
|
|
}
|
|
|
|
async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
|
|
self.inner.execute_unprepared(sql).await
|
|
}
|
|
|
|
async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
|
|
self.inner.query_one_raw(stmt).await
|
|
}
|
|
|
|
async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
|
|
self.inner.query_all_raw(stmt).await
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl ConnectionTrait for AppDatabase {
|
|
fn get_database_backend(&self) -> DbBackend {
|
|
self.db_write.get_database_backend()
|
|
}
|
|
|
|
async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
|
|
if is_force_write(&stmt.sql) {
|
|
return self.db_write.execute_raw(stmt).await;
|
|
}
|
|
|
|
if is_read_query(&stmt.sql) {
|
|
return self.reader().execute_raw(stmt).await;
|
|
}
|
|
|
|
self.db_write.execute_raw(stmt).await
|
|
}
|
|
|
|
async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
|
|
if is_read_query(sql) {
|
|
self.reader().execute_unprepared(sql).await
|
|
} else {
|
|
self.db_write.execute_unprepared(sql).await
|
|
}
|
|
}
|
|
|
|
async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
|
|
if is_force_write(&stmt.sql) {
|
|
return self.db_write.query_one_raw(stmt).await;
|
|
}
|
|
|
|
if is_read_query(&stmt.sql) {
|
|
return self.reader().query_one_raw(stmt).await;
|
|
}
|
|
|
|
self.db_write.query_one_raw(stmt).await
|
|
}
|
|
|
|
async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
|
|
if is_force_write(&stmt.sql) {
|
|
return self.db_write.query_all_raw(stmt).await;
|
|
}
|
|
|
|
if is_read_query(&stmt.sql) {
|
|
return self.reader().query_all_raw(stmt).await;
|
|
}
|
|
|
|
self.db_write.query_all_raw(stmt).await
|
|
}
|
|
}
|
|
|
|
/// Reports whether `sql` contains the exact `/*+ write */` hint that pins it to the primary.
fn is_force_write(sql: &str) -> bool {
    const WRITE_HINT: &str = "/*+ write */";
    sql.contains(WRITE_HINT)
}
|
|
|
|
/// Reports whether `sql` contains the exact `/*+ read */` hint that allows a read replica.
fn is_force_read(sql: &str) -> bool {
    const READ_HINT: &str = "/*+ read */";
    sql.contains(READ_HINT)
}
|
|
|
|
fn is_read_query(sql: &str) -> bool {
|
|
if is_force_write(sql) {
|
|
return false;
|
|
}
|
|
if is_force_read(sql) {
|
|
return true;
|
|
}
|
|
let sql = strip_comments(sql).to_lowercase();
|
|
if sql.contains("for update") || sql.contains("for share") {
|
|
return false;
|
|
}
|
|
|
|
match sql.split_whitespace().next() {
|
|
Some("select") | Some("show") | Some("desc") | Some("describe") | Some("explain") => true,
|
|
_ => false,
|
|
}
|
|
}
|
|
|
|
/// Removes SQL comments from `sql` and collapses whitespace, for statement classification.
///
/// Comment handling:
/// - `-- ...` line comments are removed up to the end of the line.
/// - `/* ... */` block comments are removed, including nested ones.
/// - Comment markers inside single-quoted strings, double-quoted identifiers and backtick
///   identifiers are left alone.
///
/// Every removed comment acts as a word separator. The result has all whitespace runs
/// collapsed to one space and is trimmed. It is meant only for classifying a statement,
/// never for execution. Backslash escapes inside quotes (e.g. PostgreSQL `E'..'`) and
/// dollar-quoted strings are not understood.
fn strip_comments(sql: &str) -> String {
    let mut kept = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // Quote character of the literal/identifier currently open, if any.
    let mut open_quote: Option<char> = None;

    while let Some(c) = chars.next() {
        if let Some(quote) = open_quote {
            kept.push(c);
            // A doubled quote (`''`) closes and immediately reopens, which is equivalent.
            if c == quote {
                open_quote = None;
            }
            continue;
        }

        match c {
            '\'' | '"' | '`' => {
                open_quote = Some(c);
                kept.push(c);
            }
            '-' if chars.peek() == Some(&'-') => {
                // Line comment: drop everything up to and including the newline.
                for skipped in chars.by_ref() {
                    if skipped == '\n' {
                        break;
                    }
                }
                kept.push(' ');
            }
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                // Track nesting so `/* a /* b */ c */` is removed entirely.
                let mut depth = 1usize;
                while depth > 0 {
                    match chars.next() {
                        Some('/') if chars.peek() == Some(&'*') => {
                            chars.next();
                            depth += 1;
                        }
                        Some('*') if chars.peek() == Some(&'/') => {
                            chars.next();
                            depth -= 1;
                        }
                        Some(_) => {}
                        // Unterminated comment: the rest of the input is comment.
                        None => break,
                    }
                }
                kept.push(' ');
            }
            _ => kept.push(c),
        }
    }

    kept.split_whitespace().collect::<Vec<_>>().join(" ")
}
|