use config::AppConfig; use rand::random_range; use sea_orm::prelude::async_trait::async_trait; use sea_orm::{ ConnectionTrait, Database, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionTrait, }; use std::time::Duration; #[derive(Clone)] pub struct AppDatabase { db_write: DatabaseConnection, db_read: Vec, } impl AppDatabase { pub async fn init(cfg: &AppConfig) -> anyhow::Result { let db_url = cfg.database_url()?; let max_connections = cfg.database_max_connections()?; let min_connections = cfg.database_min_connections()?; let idle_timeout = cfg.database_idle_timeout()?; let max_lifetime = cfg.database_max_lifetime()?; let connection_timeout = cfg.database_connection_timeout()?; let schema_search_path = cfg.database_schema_search_path()?; let read_replicas = cfg.database_read_replicas()?; let conn_cfg = sea_orm::ConnectOptions::new(db_url) .max_connections(max_connections) .min_connections(min_connections) .idle_timeout(Duration::from_secs(idle_timeout)) .max_lifetime(Duration::from_secs(max_lifetime)) .connect_timeout(Duration::from_secs(connection_timeout)) .set_schema_search_path(schema_search_path) .sqlx_logging(false) .to_owned(); let db_write = Database::connect(conn_cfg).await?; let mut db_read = vec![]; for replica in read_replicas { let conn_cfg = sea_orm::ConnectOptions::new(replica.clone()) .max_connections(max_connections) .min_connections(min_connections) .idle_timeout(Duration::from_secs(idle_timeout)) .max_lifetime(Duration::from_secs(max_lifetime)) .connect_timeout(Duration::from_secs(connection_timeout)) .to_owned(); let conn = Database::connect(conn_cfg).await?; db_read.push(conn); } Ok(Self { db_write, db_read }) } pub fn writer(&self) -> &DatabaseConnection { &self.db_write } pub fn reader(&self) -> &DatabaseConnection { if self.db_read.is_empty() { return &self.db_write; } &self.db_read[random_range(0..self.db_read.len())] } pub async fn begin(&self) -> Result { let txn = self.db_write.begin().await?; Ok(AppTransaction { inner: txn }) } } pub struct AppTransaction { inner: DatabaseTransaction, } impl AppTransaction { pub async fn commit(self) -> Result<(), DbErr> { self.inner.commit().await } pub async fn rollback(self) -> Result<(), DbErr> { self.inner.rollback().await } } #[async_trait] impl ConnectionTrait for AppTransaction { fn get_database_backend(&self) -> DbBackend { self.inner.get_database_backend() } async fn execute_raw(&self, stmt: Statement) -> Result { self.inner.execute_raw(stmt).await } async fn execute_unprepared(&self, sql: &str) -> Result { self.inner.execute_unprepared(sql).await } async fn query_one_raw(&self, stmt: Statement) -> Result, DbErr> { self.inner.query_one_raw(stmt).await } async fn query_all_raw(&self, stmt: Statement) -> Result, DbErr> { self.inner.query_all_raw(stmt).await } } #[async_trait] impl ConnectionTrait for AppDatabase { fn get_database_backend(&self) -> DbBackend { self.db_write.get_database_backend() } async fn execute_raw(&self, stmt: Statement) -> Result { if is_force_write(&stmt.sql) { return self.db_write.execute_raw(stmt).await; } if is_read_query(&stmt.sql) { return self.reader().execute_raw(stmt).await; } self.db_write.execute_raw(stmt).await } async fn execute_unprepared(&self, sql: &str) -> Result { if is_read_query(sql) { self.reader().execute_unprepared(sql).await } else { self.db_write.execute_unprepared(sql).await } } async fn query_one_raw(&self, stmt: Statement) -> Result, DbErr> { if is_force_write(&stmt.sql) { return self.db_write.query_one_raw(stmt).await; } if is_read_query(&stmt.sql) { return self.reader().query_one_raw(stmt).await; } self.db_write.query_one_raw(stmt).await } async fn query_all_raw(&self, stmt: Statement) -> Result, DbErr> { if is_force_write(&stmt.sql) { return self.db_write.query_all_raw(stmt).await; } if is_read_query(&stmt.sql) { return self.reader().query_all_raw(stmt).await; } self.db_write.query_all_raw(stmt).await } } fn is_force_write(sql: &str) -> bool { sql.contains("/*+ write */") } fn is_force_read(sql: &str) -> bool { sql.contains("/*+ read */") } fn is_read_query(sql: &str) -> bool { if is_force_write(sql) { return false; } if is_force_read(sql) { return true; } let sql = strip_comments(sql).to_lowercase(); if sql.contains("for update") || sql.contains("for share") { return false; } match sql.split_whitespace().next() { Some("select") | Some("show") | Some("desc") | Some("describe") | Some("explain") => true, _ => false, } } fn strip_comments(sql: &str) -> String { sql.lines() .filter(|l| { let l = l.trim_start(); !l.starts_with("--") && !l.starts_with("/*") }) .collect::>() .join(" ") }