gitdataai/lib/migrate/lib.rs

268 lines
7.4 KiB
Rust

use std::collections::{BTreeMap, HashMap, VecDeque};
use std::path::{Path, PathBuf};
use anyhow::{Context, Result, bail};
use sqlx::PgPool;
/// A single SQL migration file discovered on disk.
///
/// Ordering is primarily by `(domain, table, version, direction)`. The
/// remaining fields (`path`, `depends_on`) are used only as tie-breakers so
/// that `Ord` agrees with the derived `PartialEq`/`Eq`, which compare every
/// field.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Migration {
    /// Name of the directory (directly under the SQL root) containing the file.
    domain: String,
    /// Table name: the part of the file stem before `_up_` / `_down_`.
    table: String,
    /// Version number: the part of the file stem after `_up_` / `_down_`.
    version: u32,
    /// Whether this file is the `Up` or the `Down` script.
    direction: MigrationDir,
    /// Location of the `.sql` file.
    path: PathBuf,
    /// Table names collected from the file's `-- depends_on:` lines.
    depends_on: Vec<String>,
}

impl Ord for Migration {
    /// Compares by `(domain, table, version, direction)`, then by `path` and
    /// `depends_on`.
    ///
    /// The tie-breakers guarantee that `cmp` returns `Equal` only for values
    /// that are also `==`, and make sorting deterministic when two files map
    /// to the same key.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        (&self.domain, &self.table, self.version, &self.direction)
            .cmp(&(
                &other.domain,
                &other.table,
                other.version,
                &other.direction,
            ))
            .then_with(|| self.path.cmp(&other.path))
            .then_with(|| self.depends_on.cmp(&other.depends_on))
    }
}

impl PartialOrd for Migration {
    /// Delegates to [`Ord::cmp`]; the ordering is total.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Direction of a migration script.
///
/// The derived ordering follows declaration order, so `Up` sorts before `Down`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum MigrationDir {
    /// Script taken from a `*_up_<version>.sql` file.
    Up,
    /// Script taken from a `*_down_<version>.sql` file.
    Down,
}
pub async fn run_up(pool: &PgPool) -> Result<()> {
let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql");
ensure_migrations_table(pool).await?;
let all = discover_migrations(&sql_root)?;
let applied = applied_set(pool).await?;
let mut up_migrations: Vec<_> = all
.into_iter()
.filter(|m| m.direction == MigrationDir::Up)
.filter(|m| {
!applied.contains_key(&(
m.domain.clone(),
m.table.clone(),
m.version,
))
})
.collect();
if up_migrations.is_empty() {
tracing::info!("All migrations are already applied.");
return Ok(());
}
topo_sort(&mut up_migrations)?;
for m in &up_migrations {
let sql = std::fs::read_to_string(&m.path)
.context(format!("Failed to read {:?}", m.path))?;
let checksum = compute_checksum(&sql);
tracing::info!(domain = %m.domain, table = %m.table, version = m.version, "applying migration");
exec_sql(pool, &sql).await?;
record_migration(pool, m, &checksum).await?;
}
tracing::info!("Applied {} migration(s).", up_migrations.len());
Ok(())
}
fn discover_migrations(sql_root: &Path) -> Result<Vec<Migration>> {
let mut migrations = Vec::new();
if !sql_root.exists() {
bail!("SQL directory not found: {}", sql_root.display());
}
for dir_entry in std::fs::read_dir(sql_root)? {
let dir = dir_entry?;
if !dir.file_type()?.is_dir() {
continue;
}
let domain = dir.file_name().to_string_lossy().to_string();
for file_entry in std::fs::read_dir(dir.path())? {
let file = file_entry?;
let path = file.path();
if path.extension().and_then(|e| e.to_str()) != Some("sql") {
continue;
}
let stem = path
.file_stem()
.and_then(|s| s.to_str())
.context("Invalid filename")?;
let (table, direction, version) = parse_migration_stem(stem)?;
let content = std::fs::read_to_string(&path)
.context(format!("Failed to read {path:?}"))?;
let depends_on = parse_depends_on(&content);
migrations.push(Migration {
domain: domain.clone(),
table,
version,
direction,
path,
depends_on,
});
}
}
migrations.sort();
Ok(migrations)
}
/// Splits a migration file stem into `(table, direction, version)`.
///
/// The stem must look like `<table>_up_<version>` or `<table>_down_<version>`.
/// When both markers occur (for example a table called `sign_up_events` with a
/// `_down_` script), the right-most marker is the one that separates the table
/// name from the version.
///
/// # Errors
///
/// Fails if neither marker is present, or if the text after the chosen marker
/// is not a valid `u32`.
fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> {
    let up = stem
        .rfind("_up_")
        .map(|pos| (pos, "_up_".len(), MigrationDir::Up));
    let down = stem
        .rfind("_down_")
        .map(|pos| (pos, "_down_".len(), MigrationDir::Down));

    // The two markers can never start at the same offset, so a strict
    // comparison is enough to pick the right-most one.
    let (pos, marker_len, direction) = match (up, down) {
        (Some(u), Some(d)) => {
            if u.0 > d.0 {
                u
            } else {
                d
            }
        }
        (Some(only), None) | (None, Some(only)) => only,
        (None, None) => bail!("Migration filename must contain _up_ or _down_: {stem}"),
    };

    let version = stem[pos + marker_len..]
        .parse::<u32>()
        .context("Invalid version number")?;
    Ok((stem[..pos].to_string(), direction, version))
}
/// Creates the `_sql_migrations` bookkeeping table when it does not exist yet.
///
/// The table is keyed by `(domain, table_name, version)` and stores the time
/// of application and a checksum for each recorded migration.
///
/// # Errors
///
/// Propagates any error returned by the database.
async fn ensure_migrations_table(pool: &PgPool) -> Result<()> {
    const CREATE_TABLE: &str = r#"
CREATE TABLE IF NOT EXISTS _sql_migrations (
domain TEXT NOT NULL,
table_name TEXT NOT NULL,
version INTEGER NOT NULL,
applied_at TIMESTAMPTZ NOT NULL DEFAULT now(),
checksum TEXT NOT NULL DEFAULT '',
PRIMARY KEY (domain, table_name, version)
)
"#;

    sqlx::query(CREATE_TABLE).execute(pool).await?;
    Ok(())
}
/// Loads every recorded migration from `_sql_migrations`.
///
/// Returns a map from `(domain, table_name, version)` to the stored checksum.
///
/// # Errors
///
/// Propagates any error returned by the database.
async fn applied_set(
    pool: &PgPool,
) -> Result<BTreeMap<(String, String, u32), String>> {
    let rows: Vec<(String, String, i32, String)> = sqlx::query_as(
        "SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version",
    )
    .fetch_all(pool)
    .await?;

    let mut applied = BTreeMap::new();
    for (domain, table_name, version, checksum) in rows {
        // The column is a signed INTEGER; `as u32` reinterprets the bits and is
        // the inverse of the `as i32` cast used when the row is written.
        applied.insert((domain, table_name, version as u32), checksum);
    }
    Ok(applied)
}
/// Records `m` as applied in `_sql_migrations`, together with `checksum`.
///
/// If a row with the same key already exists the insert is a no-op
/// (`ON CONFLICT DO NOTHING`), so the stored checksum is left untouched.
///
/// # Errors
///
/// Propagates any error returned by the database.
async fn record_migration(
    pool: &PgPool,
    m: &Migration,
    checksum: &str,
) -> Result<()> {
    const INSERT: &str = "INSERT INTO _sql_migrations (domain, table_name, version, checksum) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING";

    // Bit-for-bit cast into the signed column type; reading the row back
    // applies the inverse `as u32` cast.
    let version = m.version as i32;

    let insert = sqlx::query(INSERT)
        .bind(&m.domain)
        .bind(&m.table)
        .bind(version)
        .bind(checksum);
    insert.execute(pool).await?;
    Ok(())
}
/// Returns a lowercase hexadecimal digest of `content`.
///
/// The digest is the 64-bit output of the standard library's `DefaultHasher`
/// fed with the string, formatted without leading zeros.
///
/// NOTE(review): the standard library documents `DefaultHasher`'s algorithm as
/// unspecified and subject to change between releases. Since these values are
/// persisted, confirm that is acceptable before relying on them for
/// verification.
fn compute_checksum(content: &str) -> String {
    use std::hash::{Hash, Hasher};

    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(content, &mut state);
    let digest: u64 = state.finish();
    format!("{digest:x}")
}
/// Collects the names listed on `-- depends_on:` lines of a migration file.
///
/// Each line is trimmed before it is checked for the `-- depends_on:` prefix.
/// The remainder is split on commas; entries are trimmed and empty entries are
/// dropped. Names are returned in order of appearance, duplicates included.
fn parse_depends_on(content: &str) -> Vec<String> {
    let mut names = Vec::new();
    for raw_line in content.lines() {
        if let Some(list) = raw_line.trim().strip_prefix("-- depends_on:") {
            names.extend(
                list.split(',')
                    .map(str::trim)
                    .filter(|name| !name.is_empty())
                    .map(String::from),
            );
        }
    }
    names
}
/// Reorders `migrations` in place so that prerequisites come first.
///
/// Two kinds of ordering constraints are applied:
///
/// * **Version order** – within one `(domain, table)` pair, migrations are
///   applied in ascending `version`.
/// * **Declared dependencies** – a migration listing table `T` in its
///   `depends_on` is placed after *every* migration in the slice whose `table`
///   is `T` (whatever its domain or version). Names that match nothing in the
///   slice are ignored.
///
/// Migrations with no constraint between them keep their relative input order
/// as far as the queue-based traversal allows.
///
/// # Errors
///
/// Fails with "Circular dependency detected among migrations" if the
/// constraints cannot all be satisfied. A migration that lists its own table
/// as a dependency is reported this way as well.
fn topo_sort(migrations: &mut [Migration]) -> Result<()> {
    let n = migrations.len();

    // Table name -> every pending migration touching that table. Keeping all
    // indices (not just the last one seen) is what lets a dependent wait for
    // each version of the table it depends on.
    let mut by_table: HashMap<&str, Vec<usize>> = HashMap::new();
    for (i, m) in migrations.iter().enumerate() {
        by_table.entry(m.table.as_str()).or_default().push(i);
    }

    let mut in_degree = vec![0u32; n];
    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];

    // Chain consecutive versions of the same (domain, table). A sorted index
    // list keeps edge insertion order deterministic.
    let mut by_version: Vec<usize> = (0..n).collect();
    by_version.sort_by(|&a, &b| {
        let (x, y) = (&migrations[a], &migrations[b]);
        (&x.domain, &x.table, x.version).cmp(&(&y.domain, &y.table, y.version))
    });
    for pair in by_version.windows(2) {
        let (prev, next) = (pair[0], pair[1]);
        if migrations[prev].domain == migrations[next].domain
            && migrations[prev].table == migrations[next].table
        {
            adj[prev].push(next);
            in_degree[next] += 1;
        }
    }

    // Declared dependencies: one edge from each provider to the dependent.
    for (i, m) in migrations.iter().enumerate() {
        for dep in &m.depends_on {
            if let Some(providers) = by_table.get(dep.as_str()) {
                for &j in providers {
                    adj[j].push(i);
                    in_degree[i] += 1;
                }
            }
        }
    }

    // Kahn's algorithm: repeatedly emit nodes with no unmet prerequisites.
    let mut queue: VecDeque<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
    let mut order = Vec::with_capacity(n);
    while let Some(i) = queue.pop_front() {
        order.push(i);
        for &next in &adj[i] {
            in_degree[next] -= 1;
            if in_degree[next] == 0 {
                queue.push_back(next);
            }
        }
    }

    // Nodes on a cycle never reach in-degree zero and are therefore missing.
    if order.len() != n {
        bail!("Circular dependency detected among migrations");
    }

    let sorted: Vec<Migration> = order.iter().map(|&i| migrations[i].clone()).collect();
    for (slot, m) in migrations.iter_mut().zip(sorted) {
        *slot = m;
    }
    Ok(())
}
/// Executes `sql` against the pool via `sqlx::raw_sql`.
///
/// The text is copied into a heap allocation that is deliberately leaked to
/// obtain a `&'static str`; that memory is never reclaimed, so each call
/// leaks one copy of the script.
///
/// NOTE(review): presumably the leak exists to satisfy a `'static` requirement
/// of `raw_sql` in the sqlx version in use — confirm, and drop the leak if the
/// API accepts a borrowed or owned string instead.
///
/// # Errors
///
/// Propagates any error returned by the database.
async fn exec_sql(pool: &PgPool, sql: &str) -> Result<()> {
    let script: &'static str = Box::leak(Box::<str>::from(sql));
    sqlx::raw_sql(script).execute(pool).await?;
    Ok(())
}