268 lines
7.4 KiB
Rust
268 lines
7.4 KiB
Rust
use std::collections::{BTreeMap, HashMap, VecDeque};
|
|
use std::path::{Path, PathBuf};
|
|
|
|
use anyhow::{Context, Result, bail};
|
|
use sqlx::PgPool;
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
struct Migration {
|
|
domain: String,
|
|
table: String,
|
|
version: u32,
|
|
direction: MigrationDir,
|
|
path: PathBuf,
|
|
depends_on: Vec<String>,
|
|
}
|
|
|
|
impl Ord for Migration {
|
|
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
|
(&self.domain, &self.table, self.version, &self.direction).cmp(&(
|
|
&other.domain,
|
|
&other.table,
|
|
other.version,
|
|
&other.direction,
|
|
))
|
|
}
|
|
}
|
|
|
|
impl PartialOrd for Migration {
|
|
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
|
Some(self.cmp(other))
|
|
}
|
|
}
|
|
|
|
/// Direction of a migration script.
///
/// Variant order is significant: the derived `Ord` ranks `Up` before `Down`, and
/// `direction` is part of the sort key of `Migration`.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum MigrationDir {
    /// A `<table>_up_<version>.sql` script.
    Up,
    /// A `<table>_down_<version>.sql` script.
    Down,
}
|
|
|
|
pub async fn run_up(pool: &PgPool) -> Result<()> {
|
|
let sql_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("sql");
|
|
ensure_migrations_table(pool).await?;
|
|
let all = discover_migrations(&sql_root)?;
|
|
let applied = applied_set(pool).await?;
|
|
|
|
let mut up_migrations: Vec<_> = all
|
|
.into_iter()
|
|
.filter(|m| m.direction == MigrationDir::Up)
|
|
.filter(|m| {
|
|
!applied.contains_key(&(
|
|
m.domain.clone(),
|
|
m.table.clone(),
|
|
m.version,
|
|
))
|
|
})
|
|
.collect();
|
|
|
|
if up_migrations.is_empty() {
|
|
tracing::info!("All migrations are already applied.");
|
|
return Ok(());
|
|
}
|
|
|
|
topo_sort(&mut up_migrations)?;
|
|
|
|
for m in &up_migrations {
|
|
let sql = std::fs::read_to_string(&m.path)
|
|
.context(format!("Failed to read {:?}", m.path))?;
|
|
let checksum = compute_checksum(&sql);
|
|
|
|
tracing::info!(domain = %m.domain, table = %m.table, version = m.version, "applying migration");
|
|
exec_sql(pool, &sql).await?;
|
|
record_migration(pool, m, &checksum).await?;
|
|
}
|
|
|
|
tracing::info!("Applied {} migration(s).", up_migrations.len());
|
|
Ok(())
|
|
}
|
|
|
|
fn discover_migrations(sql_root: &Path) -> Result<Vec<Migration>> {
|
|
let mut migrations = Vec::new();
|
|
|
|
if !sql_root.exists() {
|
|
bail!("SQL directory not found: {}", sql_root.display());
|
|
}
|
|
|
|
for dir_entry in std::fs::read_dir(sql_root)? {
|
|
let dir = dir_entry?;
|
|
if !dir.file_type()?.is_dir() {
|
|
continue;
|
|
}
|
|
let domain = dir.file_name().to_string_lossy().to_string();
|
|
|
|
for file_entry in std::fs::read_dir(dir.path())? {
|
|
let file = file_entry?;
|
|
let path = file.path();
|
|
if path.extension().and_then(|e| e.to_str()) != Some("sql") {
|
|
continue;
|
|
}
|
|
|
|
let stem = path
|
|
.file_stem()
|
|
.and_then(|s| s.to_str())
|
|
.context("Invalid filename")?;
|
|
|
|
let (table, direction, version) = parse_migration_stem(stem)?;
|
|
|
|
let content = std::fs::read_to_string(&path)
|
|
.context(format!("Failed to read {path:?}"))?;
|
|
let depends_on = parse_depends_on(&content);
|
|
|
|
migrations.push(Migration {
|
|
domain: domain.clone(),
|
|
table,
|
|
version,
|
|
direction,
|
|
path,
|
|
depends_on,
|
|
});
|
|
}
|
|
}
|
|
|
|
migrations.sort();
|
|
Ok(migrations)
|
|
}
|
|
|
|
fn parse_migration_stem(stem: &str) -> Result<(String, MigrationDir, u32)> {
|
|
if let Some(pos) = stem.rfind("_up_") {
|
|
let table = stem[..pos].to_string();
|
|
let version = stem[pos + 4..]
|
|
.parse::<u32>()
|
|
.context("Invalid version number")?;
|
|
Ok((table, MigrationDir::Up, version))
|
|
} else if let Some(pos) = stem.rfind("_down_") {
|
|
let table = stem[..pos].to_string();
|
|
let version = stem[pos + 6..]
|
|
.parse::<u32>()
|
|
.context("Invalid version number")?;
|
|
Ok((table, MigrationDir::Down, version))
|
|
} else {
|
|
bail!("Migration filename must contain _up_ or _down_: {stem}");
|
|
}
|
|
}
|
|
|
|
async fn ensure_migrations_table(pool: &PgPool) -> Result<()> {
|
|
sqlx::query(
|
|
r#"
|
|
CREATE TABLE IF NOT EXISTS _sql_migrations (
|
|
domain TEXT NOT NULL,
|
|
table_name TEXT NOT NULL,
|
|
version INTEGER NOT NULL,
|
|
applied_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
|
checksum TEXT NOT NULL DEFAULT '',
|
|
PRIMARY KEY (domain, table_name, version)
|
|
)
|
|
"#,
|
|
)
|
|
.execute(pool)
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn applied_set(
|
|
pool: &PgPool,
|
|
) -> Result<BTreeMap<(String, String, u32), String>> {
|
|
let rows: Vec<(String, String, i32, String)> = sqlx::query_as(
|
|
"SELECT domain, table_name, version, checksum FROM _sql_migrations ORDER BY domain, table_name, version",
|
|
)
|
|
.fetch_all(pool)
|
|
.await?;
|
|
|
|
Ok(rows
|
|
.into_iter()
|
|
.map(|(d, t, v, c)| ((d, t, v as u32), c))
|
|
.collect())
|
|
}
|
|
|
|
async fn record_migration(
|
|
pool: &PgPool,
|
|
m: &Migration,
|
|
checksum: &str,
|
|
) -> Result<()> {
|
|
sqlx::query(
|
|
"INSERT INTO _sql_migrations (domain, table_name, version, checksum) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING",
|
|
)
|
|
.bind(&m.domain)
|
|
.bind(&m.table)
|
|
.bind(m.version as i32)
|
|
.bind(checksum)
|
|
.execute(pool)
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Returns a stable fingerprint of a migration script's text.
///
/// The hash is 64-bit FNV-1a over the UTF-8 bytes, rendered as 16 lowercase hex digits.
/// Unlike `std::collections::hash_map::DefaultHasher`, whose algorithm std leaves
/// unspecified, FNV-1a gives the same value for the same input on every build and
/// toolchain. That stability matters because the checksum is persisted in the database.
/// It is not cryptographic; it only detects accidental edits.
///
/// NOTE(review): rows recorded by earlier builds hold values from the previous hasher
/// and will not match these.
fn compute_checksum(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let hash = content.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{hash:016x}")
}
|
|
|
|
/// Collects the table names declared on `-- depends_on:` comment lines of a script.
///
/// Parsing rules:
/// - each line is trimmed before the marker is matched;
/// - the text after the marker is split on `,`, each entry is trimmed, and empty
///   entries are dropped.
///
/// Names from several marker lines are concatenated in file order; duplicates are kept.
fn parse_depends_on(content: &str) -> Vec<String> {
    const MARKER: &str = "-- depends_on:";

    let mut tables = Vec::new();
    for raw_line in content.lines() {
        if let Some(list) = raw_line.trim().strip_prefix(MARKER) {
            tables.extend(
                list.split(',')
                    .map(str::trim)
                    .filter(|name| !name.is_empty())
                    .map(String::from),
            );
        }
    }
    tables
}
|
|
|
|
fn topo_sort(migrations: &mut [Migration]) -> Result<()> {
|
|
let table_to_idx: HashMap<String, usize> = migrations
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, m)| (m.table.clone(), i))
|
|
.collect();
|
|
|
|
let n = migrations.len();
|
|
let mut in_degree = vec![0u32; n];
|
|
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
|
|
|
|
for (i, m) in migrations.iter().enumerate() {
|
|
for dep in &m.depends_on {
|
|
if let Some(&j) = table_to_idx.get(dep) {
|
|
adj[j].push(i);
|
|
in_degree[i] += 1;
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut queue: VecDeque<usize> =
|
|
(0..n).filter(|&i| in_degree[i] == 0).collect();
|
|
let mut order = Vec::with_capacity(n);
|
|
while let Some(i) = queue.pop_front() {
|
|
order.push(i);
|
|
for &next in &adj[i] {
|
|
in_degree[next] -= 1;
|
|
if in_degree[next] == 0 {
|
|
queue.push_back(next);
|
|
}
|
|
}
|
|
}
|
|
|
|
if order.len() != n {
|
|
bail!("Circular dependency detected among migrations");
|
|
}
|
|
|
|
let original: Vec<Migration> = migrations.iter().cloned().collect();
|
|
for (slot, &idx) in order.iter().enumerate() {
|
|
migrations[slot] = original[idx].clone();
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn exec_sql(pool: &PgPool, sql: &str) -> Result<()> {
|
|
let s: &'static str = Box::leak(sql.to_owned().into_boxed_str());
|
|
sqlx::raw_sql(s).execute(pool).await?;
|
|
Ok(())
|
|
}
|