use std::collections::{BTreeMap, HashMap, VecDeque}; use std::path::{Path, PathBuf}; use anyhow::{Context, Result, bail}; use sqlx::PgPool; #[derive(Debug, Clone, PartialEq, Eq)] struct Migration { domain: String, table: String, version: u32, direction: MigrationDir, path: PathBuf, depends_on: Vec, } impl Ord for Migration { fn cmp(&self, other: &Self) -> std::cmp::Ordering { (&self.domain, &self.table, self.version, &self.direction).cmp(&( &other.domain, &other.table, other.version, &other.direction, )) } } impl PartialOrd for Migration { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] enum MigrationDir { Up, Down, } pub async fn run_up(pool: &PgPool) -> Result<()> { let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql"); ensure_migrations_table(pool).await?; let all = discover_migrations(&sql_root)?; let applied = applied_set(pool).await?; let mut up_migrations: Vec<_> = all .into_iter() .filter(|m| m.direction == MigrationDir::Up) .filter(|m| { !applied.contains_key(&( m.domain.clone(), m.table.clone(), m.version, )) }) .collect(); if up_migrations.is_empty() { tracing::info!("All migrations are already applied."); return Ok(()); } topo_sort(&mut up_migrations)?; for m in &up_migrations { let sql = std::fs::read_to_string(&m.path) .context(format!("Failed to read {:?}", m.path))?; let checksum = compute_checksum(&sql); tracing::info!(domain = %m.domain, table = %m.table, version = m.version, "applying migration"); exec_sql(pool, &sql).await?; record_migration(pool, m, &checksum).await?; } tracing::info!("Applied {} migration(s).", up_migrations.len()); Ok(()) } fn discover_migrations(sql_root: &Path) -> Result> { let mut migrations = Vec::new(); if !sql_root.exists() { bail!("SQL directory not found: {}", sql_root.display()); } for dir_entry in std::fs::read_dir(sql_root)? { let dir = dir_entry?; if !dir.file_type()?.is_dir() { continue; } let domain = dir.file_name().to_string_lossy().to_string(); for file_entry in std::fs::read_dir(dir.path())? { let file = file_entry?; let path = file.path(); if path.extension().and_then(|e| e.to_str()) != Some("sql") { continue; } let stem = path .file_stem() .and_then(|s| s.to_str()) .context("Invalid filename")?; let (table, direction, version) = parse_migration_stem(stem)?; let content = std::fs::read_to_string(&path) .context(format!("Failed to read {path:?}"))?; let depends_on = parse_depends_on(&content); migrations.push(Migration { domain: domain.clone(), table, version, direction, path, depends_on, }); } } migrations.sort(); Ok(migrations) } fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> { if let Some(pos) = stem.rfind("_up_") { let table = stem[..pos].to_string(); let version = stem[pos + 4..] .parse::() .context("Invalid version number")?; Ok((table, MigrationDir::Up, version)) } else if let Some(pos) = stem.rfind("_down_") { let table = stem[..pos].to_string(); let version = stem[pos + 6..] .parse::() .context("Invalid version number")?; Ok((table, MigrationDir::Down, version)) } else { bail!("Migration filename must contain _up_ or _down_: {stem}"); } } async fn ensure_migrations_table(pool: &PgPool) -> Result<()> { sqlx::query( r#" CREATE TABLE IF NOT EXISTS _sql_migrations ( domain TEXT NOT NULL, table_name TEXT NOT NULL, version INTEGER NOT NULL, applied_at TIMESTAMPTZ NOT NULL DEFAULT now(), checksum TEXT NOT NULL DEFAULT '', PRIMARY KEY (domain, table_name, version) ) "#, ) .execute(pool) .await?; Ok(()) } async fn applied_set( pool: &PgPool, ) -> Result> { let rows: Vec<(String, String, i32, String)> = sqlx::query_as( "SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version", ) .fetch_all(pool) .await?; Ok(rows .into_iter() .map(|(d, t, v, c)| ((d, t, v as u32), c)) .collect()) } async fn record_migration( pool: &PgPool, m: &Migration, checksum: &str, ) -> Result<()> { sqlx::query( "INSERT INTO _sql_migrations (domain, table_name, version, checksum) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING", ) .bind(&m.domain) .bind(&m.table) .bind(m.version as i32) .bind(checksum) .execute(pool) .await?; Ok(()) } fn compute_checksum(content: &str) -> String { use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; let mut hasher = DefaultHasher::new(); content.hash(&mut hasher); format!("{:x}", hasher.finish()) } fn parse_depends_on(content: &str) -> Vec { content .lines() .filter_map(|line| { let line = line.trim(); line.strip_prefix("-- depends_on:").map(|deps| { deps.split(',') .map(|d| d.trim().to_string()) .filter(|d| !d.is_empty()) .collect::>() }) }) .flatten() .collect() } fn topo_sort(migrations: &mut [Migration]) -> Result<()> { let table_to_idx: HashMap = migrations .iter() .enumerate() .map(|(i, m)| (m.table.clone(), i)) .collect(); let n = migrations.len(); let mut in_degree = vec![0u32; n]; let mut adj: Vec> = vec![Vec::new(); n]; for (i, m) in migrations.iter().enumerate() { for dep in &m.depends_on { if let Some(&j) = table_to_idx.get(dep) { adj[j].push(i); in_degree[i] += 1; } } } let mut queue: VecDeque = (0..n).filter(|&i| in_degree[i] == 0).collect(); let mut order = Vec::with_capacity(n); while let Some(i) = queue.pop_front() { order.push(i); for &next in &adj[i] { in_degree[next] -= 1; if in_degree[next] == 0 { queue.push_back(next); } } } if order.len() != n { bail!("Circular dependency detected among migrations"); } let original: Vec = migrations.iter().cloned().collect(); for (slot, &idx) in order.iter().enumerate() { migrations[slot] = original[idx].clone(); } Ok(()) } async fn exec_sql(pool: &PgPool, sql: &str) -> Result<()> { let s: &'static str = Box::leak(sql.to_owned().into_boxed_str()); sqlx::raw_sql(s).execute(pool).await?; Ok(()) }