gitdataai/libs/db/database.rs
2026-04-14 19:02:01 +08:00

195 lines
5.7 KiB
Rust

use config::AppConfig;
use rand::random_range;
use sea_orm::prelude::async_trait::async_trait;
use sea_orm::{
ConnectionTrait, Database, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
ExecResult, QueryResult, Statement, TransactionTrait,
};
use std::time::Duration;
/// Database handle made of one primary (writer) connection plus zero or more
/// read-replica connections.
///
/// Reads can be spread over the replicas via [`AppDatabase::reader`]; writes and
/// transactions always use the primary.
#[derive(Clone)]
pub struct AppDatabase {
    /// Read-replica connections; empty when no replicas are configured, in which case
    /// reads fall back to the primary.
    db_read: Vec<DatabaseConnection>,
    /// Primary connection: used for writes, transactions, and any statement that is
    /// not classified as a read.
    db_write: DatabaseConnection,
}
impl AppDatabase {
    /// Connects to the primary database and to every configured read replica.
    ///
    /// Every pool (primary and replicas alike) is built with the same pool sizing,
    /// timeouts, schema search path and sqlx-logging setting read from `cfg`. As a
    /// result, a statement routed to a replica resolves unqualified names exactly as it
    /// would on the primary. Replicas are connected one after another, in configuration
    /// order.
    ///
    /// # Errors
    ///
    /// Fails if any configuration value cannot be read, or if connecting to the primary
    /// or to any replica fails.
    pub async fn init(cfg: &AppConfig) -> anyhow::Result<Self> {
        let db_url = cfg.database_url()?;
        let max_connections = cfg.database_max_connections()?;
        let min_connections = cfg.database_min_connections()?;
        let idle_timeout = Duration::from_secs(cfg.database_idle_timeout()?);
        let max_lifetime = Duration::from_secs(cfg.database_max_lifetime()?);
        let connect_timeout = Duration::from_secs(cfg.database_connection_timeout()?);
        let schema_search_path = cfg.database_schema_search_path()?;
        let read_replicas = cfg.database_read_replicas()?;

        // Single source of truth for pool options. Replica pools used to skip the
        // schema search path (so replica reads could resolve names against a different
        // schema than the primary) and left sqlx logging at its default.
        let with_pool_settings = |mut opts: sea_orm::ConnectOptions| {
            opts.max_connections(max_connections)
                .min_connections(min_connections)
                .idle_timeout(idle_timeout)
                .max_lifetime(max_lifetime)
                .connect_timeout(connect_timeout)
                .set_schema_search_path(schema_search_path.clone())
                .sqlx_logging(false);
            opts
        };

        let db_write =
            Database::connect(with_pool_settings(sea_orm::ConnectOptions::new(db_url))).await?;

        let mut db_read = Vec::new();
        for replica in read_replicas {
            let opts = with_pool_settings(sea_orm::ConnectOptions::new(replica));
            db_read.push(Database::connect(opts).await?);
        }

        Ok(Self { db_write, db_read })
    }

    /// Returns the primary (writer) connection.
    pub fn writer(&self) -> &DatabaseConnection {
        &self.db_write
    }

    /// Returns a connection for read-only work: a randomly chosen read replica, or the
    /// primary when no replicas are configured.
    pub fn reader(&self) -> &DatabaseConnection {
        if self.db_read.is_empty() {
            return &self.db_write;
        }
        &self.db_read[random_range(0..self.db_read.len())]
    }

    /// Starts a transaction on the primary connection.
    ///
    /// # Errors
    ///
    /// Returns the [`DbErr`] reported by the primary connection if the transaction
    /// cannot be started.
    pub async fn begin(&self) -> Result<AppTransaction, DbErr> {
        let txn = self.db_write.begin().await?;
        Ok(AppTransaction { inner: txn })
    }
}
/// A transaction opened on the primary connection by [`AppDatabase::begin`].
///
/// Finish it explicitly with [`AppTransaction::commit`] or [`AppTransaction::rollback`];
/// both consume the handle. Marked `#[must_use]` so that accidentally discarding a
/// freshly opened transaction is reported by the compiler.
#[must_use = "a transaction should be used and then committed or rolled back"]
pub struct AppTransaction {
    /// Underlying SeaORM transaction that every call on this handle is delegated to.
    inner: DatabaseTransaction,
}
impl AppTransaction {
    /// Commits the transaction, consuming the handle.
    ///
    /// # Errors
    ///
    /// Returns the [`DbErr`] reported by the underlying transaction if the commit fails.
    pub async fn commit(self) -> Result<(), DbErr> {
        let Self { inner } = self;
        inner.commit().await
    }

    /// Rolls the transaction back, consuming the handle.
    ///
    /// # Errors
    ///
    /// Returns the [`DbErr`] reported by the underlying transaction if the rollback fails.
    pub async fn rollback(self) -> Result<(), DbErr> {
        let Self { inner } = self;
        inner.rollback().await
    }
}
/// Lets an [`AppTransaction`] be used anywhere SeaORM expects a connection; every call
/// is forwarded unchanged to the wrapped transaction.
#[async_trait]
impl ConnectionTrait for AppTransaction {
    /// Backend of the wrapped transaction.
    fn get_database_backend(&self) -> DbBackend {
        let Self { inner } = self;
        inner.get_database_backend()
    }

    /// Executes a prepared statement inside the transaction.
    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
        let Self { inner } = self;
        inner.execute_raw(stmt).await
    }

    /// Executes raw SQL text inside the transaction.
    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
        let Self { inner } = self;
        inner.execute_unprepared(sql).await
    }

    /// Runs a query inside the transaction and returns at most one row.
    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
        let Self { inner } = self;
        inner.query_one_raw(stmt).await
    }

    /// Runs a query inside the transaction and returns every row.
    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
        let Self { inner } = self;
        inner.query_all_raw(stmt).await
    }
}
/// Routes each statement either to a read replica or to the primary.
///
/// A statement goes to [`AppDatabase::reader`] only when it carries no `/*+ write */`
/// hint and `is_read_query` classifies it as a read; everything else runs on the
/// primary connection.
#[async_trait]
impl ConnectionTrait for AppDatabase {
    /// Backend of the primary connection.
    fn get_database_backend(&self) -> DbBackend {
        let Self { db_write, .. } = self;
        db_write.get_database_backend()
    }

    /// Executes a prepared statement on the connection chosen by the routing rules.
    async fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
        let conn = if !is_force_write(&stmt.sql) && is_read_query(&stmt.sql) {
            self.reader()
        } else {
            &self.db_write
        };
        conn.execute_raw(stmt).await
    }

    /// Executes raw SQL text on the connection chosen by the routing rules.
    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
        let conn = if !is_force_write(sql) && is_read_query(sql) {
            self.reader()
        } else {
            &self.db_write
        };
        conn.execute_unprepared(sql).await
    }

    /// Runs a query returning at most one row on the connection chosen by the routing rules.
    async fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
        let conn = if !is_force_write(&stmt.sql) && is_read_query(&stmt.sql) {
            self.reader()
        } else {
            &self.db_write
        };
        conn.query_one_raw(stmt).await
    }

    /// Runs a query returning all rows on the connection chosen by the routing rules.
    async fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
        let conn = if !is_force_write(&stmt.sql) && is_read_query(&stmt.sql) {
            self.reader()
        } else {
            &self.db_write
        };
        conn.query_all_raw(stmt).await
    }
}
/// Reports whether `sql` contains the exact `/*+ write */` routing hint, which pins the
/// statement to the primary connection.
fn is_force_write(sql: &str) -> bool {
    const WRITE_HINT: &str = "/*+ write */";
    sql.contains(WRITE_HINT)
}
/// Reports whether `sql` contains the exact `/*+ read */` routing hint, which marks the
/// statement as safe to run on a read replica.
fn is_force_read(sql: &str) -> bool {
    const READ_HINT: &str = "/*+ read */";
    sql.contains(READ_HINT)
}
/// Decides whether `sql` may be routed to a read replica.
///
/// Explicit hints win, and the `/*+ write */` hint (primary) takes precedence over
/// `/*+ read */` (replica). Otherwise, after comments are stripped, a statement is a
/// read only if:
/// - its first word is `select`, `show`, `desc`, `describe` or `explain`; and
/// - it contains none of the following constructs that need the primary:
///   - a row-locking clause (`FOR UPDATE`, `FOR NO KEY UPDATE`, `FOR SHARE`,
///     `FOR KEY SHARE`), in any whitespace layout;
///   - `SELECT ... INTO`;
///   - an `EXPLAIN`/`DESC`/`DESCRIBE` mentioning a data-modifying statement
///     (`EXPLAIN ANALYZE` executes it);
///   - more than one statement.
///
/// Anything unrecognised goes to the primary. Side-effecting functions called from a
/// `SELECT` (e.g. `nextval`) are not detected; use the `/*+ write */` hint for those.
fn is_read_query(sql: &str) -> bool {
    /// Row-locking clauses; a replica cannot take these locks.
    const LOCKING_CLAUSES: [&str; 4] = [
        "for update",
        "for no key update",
        "for share",
        "for key share",
    ];

    if is_force_write(sql) {
        return false;
    }
    if is_force_read(sql) {
        return true;
    }

    let lowered = strip_comments(sql).to_lowercase();
    let tokens: Vec<&str> = lowered.split_whitespace().collect();
    let Some(&first) = tokens.first() else {
        return false;
    };
    if !matches!(first, "select" | "show" | "desc" | "describe" | "explain") {
        return false;
    }

    // Re-join with single spaces so multi-word clauses match however the statement was
    // wrapped or indented (e.g. "FOR\n    UPDATE" or "for  no key update").
    let flat = tokens.join(" ");
    if LOCKING_CLAUSES.iter().any(|&clause| flat.contains(clause)) {
        return false;
    }
    // A `;` anywhere but the end means several statements, and only the first one's
    // keyword was inspected above.
    if flat.trim_end_matches(';').contains(';') {
        return false;
    }

    match first {
        // `SELECT ... INTO` creates a table (PostgreSQL) or writes variables/files (MySQL).
        "select" => !tokens.contains(&"into"),
        // `EXPLAIN ANALYZE` runs the explained statement, so explaining a write must
        // happen on the primary. DESC/DESCRIBE are treated as EXPLAIN synonyms.
        "explain" | "desc" | "describe" => !tokens
            .iter()
            .any(|&t| matches!(t, "insert" | "update" | "delete" | "merge")),
        _ => true,
    }
}
/// Returns `sql` with every SQL comment removed and whitespace collapsed. The result is
/// meant for statement classification only and is never executed.
///
/// - `--` line comments and `/* ... */` block comments are removed wherever they
///   appear, not just at the start of a line. Nested block comments are tracked, as
///   PostgreSQL allows them.
/// - Each removed comment becomes a single space, so neighbouring tokens do not fuse.
/// - Text inside `'...'`, `"..."` and `` `...` `` quotes is copied verbatim, so comment
///   markers inside literals or quoted identifiers are kept. Doubled quotes (`''`)
///   work naturally.
/// - Dollar-quoted strings and backslash escapes are not recognised.
/// - An unterminated block comment swallows the rest of the input.
/// - Runs of whitespace in the result are collapsed to a single space, and the result
///   has no leading or trailing whitespace.
fn strip_comments(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            '\'' | '"' | '`' => {
                // Copy the quoted run up to and including the closing quote. A doubled
                // quote closes this run and immediately opens a new one, which yields
                // the same verbatim text.
                out.push(c);
                for q in chars.by_ref() {
                    out.push(q);
                    if q == c {
                        break;
                    }
                }
            }
            '-' if chars.peek() == Some(&'-') => {
                // Line comment: skip through the end of the line (or of the input).
                for q in chars.by_ref() {
                    if q == '\n' {
                        break;
                    }
                }
                out.push(' ');
            }
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                let mut depth = 1usize;
                while depth > 0 {
                    match chars.next() {
                        Some('*') if chars.peek() == Some(&'/') => {
                            chars.next();
                            depth -= 1;
                        }
                        Some('/') if chars.peek() == Some(&'*') => {
                            chars.next();
                            depth += 1;
                        }
                        Some(_) => {}
                        None => break,
                    }
                }
                out.push(' ');
            }
            _ => out.push(c),
        }
    }

    out.split_whitespace().collect::<Vec<_>>().join(" ")
}