From 8950332ca325945611aafb62393cd4bcc64d04a8 Mon Sep 17 00:00:00 2001 From: Evgeny Nosov Date: Tue, 26 Jan 2021 13:17:58 +0300 Subject: [PATCH] Rewrite migration algorithm in order to apply also unapplied migrations --- sqlx-cli/src/migrate.rs | 137 ++++++++++++++++++----------- sqlx-core/src/any/migrate.rs | 50 +++++------ sqlx-core/src/migrate/migrate.rs | 18 ++-- sqlx-core/src/migrate/migration.rs | 6 ++ sqlx-core/src/migrate/migrator.rs | 44 +++++++-- sqlx-core/src/migrate/mod.rs | 2 +- sqlx-core/src/mysql/migrate.rs | 56 ++++++------ sqlx-core/src/postgres/migrate.rs | 56 ++++++------ sqlx-core/src/sqlite/migrate.rs | 57 ++++++------ 9 files changed, 243 insertions(+), 183 deletions(-) diff --git a/sqlx-cli/src/migrate.rs b/sqlx-cli/src/migrate.rs index f73aa1c35..d75746ddc 100644 --- a/sqlx-cli/src/migrate.rs +++ b/sqlx-cli/src/migrate.rs @@ -1,8 +1,9 @@ use anyhow::{bail, Context}; use chrono::Utc; use console::style; -use sqlx::migrate::{Migrate, MigrateError, MigrationType, Migrator}; +use sqlx::migrate::{AppliedMigration, Migrate, MigrateError, MigrationType, Migrator}; use sqlx::{AnyConnection, Connection}; +use std::collections::{HashMap, HashSet}; use std::fs::{self, File}; use std::io::Write; use std::path::Path; @@ -82,13 +83,18 @@ pub async fn info(migration_source: &str, uri: &str) -> anyhow::Result<()> { conn.ensure_migrations_table().await?; - let (version, _) = conn.version().await?.unwrap_or((0, false)); + let applied_migrations: HashMap<_, _> = conn + .list_applied_migrations() + .await? + .into_iter() + .map(|m| (m.version, m)) + .collect(); for migration in migrator.iter() { println!( "{}/{} {}", style(migration.version).cyan(), - if version >= migration.version { + if applied_migrations.contains_key(&migration.version) { style("installed").green() } else { style("pending").yellow() @@ -100,41 +106,69 @@ pub async fn info(migration_source: &str, uri: &str) -> anyhow::Result<()> { Ok(()) } +fn validate_applied_migrations( + applied_migrations: &[AppliedMigration], + migrator: &Migrator, +) -> Result<(), MigrateError> { + let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect(); + + for applied_migration in applied_migrations { + if !migrations.contains(&applied_migration.version) { + return Err(MigrateError::VersionMissing(applied_migration.version)); + } + } + + Ok(()) +} + pub async fn run(migration_source: &str, uri: &str, dry_run: bool) -> anyhow::Result<()> { let migrator = Migrator::new(Path::new(migration_source)).await?; let mut conn = AnyConnection::connect(uri).await?; conn.ensure_migrations_table().await?; - let (version, dirty) = conn.version().await?.unwrap_or((0, false)); - - if dirty { + let version = conn.dirty_version().await?; + if let Some(version) = version { bail!(MigrateError::Dirty(version)); } + let applied_migrations = conn.list_applied_migrations().await?; + validate_applied_migrations(&applied_migrations, &migrator)?; + + let applied_migrations: HashMap<_, _> = applied_migrations + .into_iter() + .map(|m| (m.version, m)) + .collect(); + for migration in migrator.iter() { if migration.migration_type.is_down_migration() { // Skipping down migrations continue; } - if migration.version > version { - let elapsed = if dry_run { - Duration::new(0, 0) - } else { - conn.apply(migration).await? - }; - let text = if dry_run { "Can apply" } else { "Applied" }; - println!( - "{} {}/{} {} {}", - text, - style(migration.version).cyan(), - style(migration.migration_type.label()).green(), - migration.description, - style(format!("({:?})", elapsed)).dim() - ); - } else { - conn.validate(migration).await?; + match applied_migrations.get(&migration.version) { + Some(applied_migration) => { + if migration.checksum != applied_migration.checksum { + bail!(MigrateError::VersionMismatch(migration.version)); + } + } + None => { + let elapsed = if dry_run { + Duration::new(0, 0) + } else { + conn.apply(migration).await? + }; + let text = if dry_run { "Can apply" } else { "Applied" }; + + println!( + "{} {}/{} {} {}", + text, + style(migration.version).cyan(), + style(migration.migration_type.label()).green(), + migration.description, + style(format!("({:?})", elapsed)).dim() + ); + } } } @@ -147,12 +181,19 @@ pub async fn revert(migration_source: &str, uri: &str, dry_run: bool) -> anyhow: conn.ensure_migrations_table().await?; - let (version, dirty) = conn.version().await?.unwrap_or((0, false)); - - if dirty { + let version = conn.dirty_version().await?; + if let Some(version) = version { bail!(MigrateError::Dirty(version)); } + let applied_migrations = conn.list_applied_migrations().await?; + validate_applied_migrations(&applied_migrations, &migrator)?; + + let applied_migrations: HashMap<_, _> = applied_migrations + .into_iter() + .map(|m| (m.version, m)) + .collect(); + let mut is_applied = false; for migration in migrator.iter().rev() { if !migration.migration_type.is_down_migration() { @@ -160,30 +201,28 @@ pub async fn revert(migration_source: &str, uri: &str, dry_run: bool) -> anyhow: // This will skip any simple or up migration file continue; } - if migration.version > version { - // Skipping unapplied migrations - continue; + + if applied_migrations.contains_key(&migration.version) { + let elapsed = if dry_run { + Duration::new(0, 0) + } else { + conn.revert(migration).await? + }; + let text = if dry_run { "Can apply" } else { "Applied" }; + + println!( + "{} {}/{} {} {}", + text, + style(migration.version).cyan(), + style(migration.migration_type.label()).green(), + migration.description, + style(format!("({:?})", elapsed)).dim() + ); + + is_applied = true; + // Only a single migration will be reverted at a time, so we break + break; } - - let elapsed = if dry_run { - Duration::new(0, 0) - } else { - conn.revert(migration).await? - }; - let text = if dry_run { "Can apply" } else { "Applied" }; - - println!( - "{} {}/{} {} {}", - text, - style(migration.version).cyan(), - style(migration.migration_type.label()).green(), - migration.description, - style(format!("({:?})", elapsed)).dim() - ); - - is_applied = true; - // Only a single migration will be reverted at a time, so we break - break; } if !is_applied { println!("No migrations available to revert"); diff --git a/sqlx-core/src/any/migrate.rs b/sqlx-core/src/any/migrate.rs index 8af3e6f0c..764991b5f 100644 --- a/sqlx-core/src/any/migrate.rs +++ b/sqlx-core/src/any/migrate.rs @@ -2,7 +2,7 @@ use crate::any::connection::AnyConnectionKind; use crate::any::kind::AnyKind; use crate::any::{Any, AnyConnection}; use crate::error::Error; -use crate::migrate::{Migrate, MigrateDatabase, MigrateError, Migration}; +use crate::migrate::{AppliedMigration, Migrate, MigrateDatabase, MigrateError, Migration}; use futures_core::future::BoxFuture; use std::str::FromStr; use std::time::Duration; @@ -80,16 +80,34 @@ impl Migrate for AnyConnection { } } - fn version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { + fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { match &mut self.0 { #[cfg(feature = "postgres")] - AnyConnectionKind::Postgres(conn) => conn.version(), + AnyConnectionKind::Postgres(conn) => conn.dirty_version(), #[cfg(feature = "sqlite")] - AnyConnectionKind::Sqlite(conn) => conn.version(), + AnyConnectionKind::Sqlite(conn) => conn.dirty_version(), #[cfg(feature = "mysql")] - AnyConnectionKind::MySql(conn) => conn.version(), + AnyConnectionKind::MySql(conn) => conn.dirty_version(), + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(_conn) => unimplemented!(), + } + } + + fn list_applied_migrations( + &mut self, + ) -> BoxFuture<'_, Result, MigrateError>> { + match &mut self.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => conn.list_applied_migrations(), + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => conn.list_applied_migrations(), + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => conn.list_applied_migrations(), #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(_conn) => unimplemented!(), @@ -128,28 +146,6 @@ impl Migrate for AnyConnection { } } - fn validate<'e: 'm, 'm>( - &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result<(), MigrateError>> { - match &mut self.0 { - #[cfg(feature = "postgres")] - AnyConnectionKind::Postgres(conn) => conn.validate(migration), - - #[cfg(feature = "sqlite")] - AnyConnectionKind::Sqlite(conn) => conn.validate(migration), - - #[cfg(feature = "mysql")] - AnyConnectionKind::MySql(conn) => conn.validate(migration), - - #[cfg(feature = "mssql")] - AnyConnectionKind::Mssql(_conn) => { - let _ = migration; - unimplemented!() - } - } - } - fn apply<'e: 'm, 'm>( &'e mut self, migration: &'m Migration, diff --git a/sqlx-core/src/migrate/migrate.rs b/sqlx-core/src/migrate/migrate.rs index 31690d929..a3958e157 100644 --- a/sqlx-core/src/migrate/migrate.rs +++ b/sqlx-core/src/migrate/migrate.rs @@ -1,5 +1,5 @@ use crate::error::Error; -use crate::migrate::{MigrateError, Migration}; +use crate::migrate::{AppliedMigration, MigrateError, Migration}; use futures_core::future::BoxFuture; use std::time::Duration; @@ -23,9 +23,14 @@ pub trait Migrate { // will create or migrate it if needed fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>>; - // Return the current version and if the database is "dirty". + // Return the version on which the database is dirty or None otherwise. // "dirty" means there is a partially applied migration that failed. - fn version(&mut self) -> BoxFuture<'_, Result, MigrateError>>; + fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>>; + + // Return the ordered list of applied migrations + fn list_applied_migrations( + &mut self, + ) -> BoxFuture<'_, Result, MigrateError>>; // Should acquire a database lock so that only one migration process // can run at a time. [`Migrate`] will call this function before applying @@ -36,13 +41,6 @@ pub trait Migrate { // migrations have been run. fn unlock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>>; - // validate the migration - // checks that it does exist on the database and that the checksum matches - fn validate<'e: 'm, 'm>( - &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result<(), MigrateError>>; - // run SQL from migration in a DDL transaction // insert new row to [_migrations] table on completion (success or failure) // returns the time taking to run the migration SQL diff --git a/sqlx-core/src/migrate/migration.rs b/sqlx-core/src/migrate/migration.rs index ed362da24..69290c089 100644 --- a/sqlx-core/src/migrate/migration.rs +++ b/sqlx-core/src/migrate/migration.rs @@ -31,3 +31,9 @@ impl Migration { } } } + +#[derive(Debug, Clone)] +pub struct AppliedMigration { + pub version: i64, + pub checksum: Cow<'static, [u8]>, +} diff --git a/sqlx-core/src/migrate/migrator.rs b/sqlx-core/src/migrate/migrator.rs index 9f23fe679..dc77fe7c8 100644 --- a/sqlx-core/src/migrate/migrator.rs +++ b/sqlx-core/src/migrate/migrator.rs @@ -1,6 +1,7 @@ use crate::acquire::Acquire; -use crate::migrate::{Migrate, MigrateError, Migration, MigrationSource}; +use crate::migrate::{AppliedMigration, Migrate, MigrateError, Migration, MigrationSource}; use std::borrow::Cow; +use std::collections::{HashMap, HashSet}; use std::ops::Deref; use std::slice; @@ -9,6 +10,21 @@ pub struct Migrator { pub migrations: Cow<'static, [Migration]>, } +fn validate_applied_migrations( + applied_migrations: &[AppliedMigration], + migrator: &Migrator, +) -> Result<(), MigrateError> { + let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect(); + + for applied_migration in applied_migrations { + if !migrations.contains(&applied_migration.version) { + return Err(MigrateError::VersionMissing(applied_migration.version)); + } + } + + Ok(()) +} + impl Migrator { /// Creates a new instance with the given source. /// @@ -73,17 +89,29 @@ impl Migrator { // eventually this will likely migrate previous versions of the table conn.ensure_migrations_table().await?; - let (version, dirty) = conn.version().await?.unwrap_or((0, false)); - - if dirty { + let version = conn.dirty_version().await?; + if let Some(version) = version { return Err(MigrateError::Dirty(version)); } + let applied_migrations = conn.list_applied_migrations().await?; + validate_applied_migrations(&applied_migrations, self)?; + + let applied_migrations: HashMap<_, _> = applied_migrations + .into_iter() + .map(|m| (m.version, m)) + .collect(); + for migration in self.iter() { - if migration.version > version { - conn.apply(migration).await?; - } else { - conn.validate(migration).await?; + match applied_migrations.get(&migration.version) { + Some(applied_migration) => { + if migration.checksum != applied_migration.checksum { + return Err(MigrateError::VersionMismatch(migration.version)); + } + } + None => { + conn.apply(migration).await?; + } } } diff --git a/sqlx-core/src/migrate/mod.rs b/sqlx-core/src/migrate/mod.rs index 6d72e8cdc..a1095fb8b 100644 --- a/sqlx-core/src/migrate/mod.rs +++ b/sqlx-core/src/migrate/mod.rs @@ -8,7 +8,7 @@ mod source; pub use error::MigrateError; pub use migrate::{Migrate, MigrateDatabase}; -pub use migration::Migration; +pub use migration::{AppliedMigration, Migration}; pub use migration_type::MigrationType; pub use migrator::Migrator; pub use source::MigrationSource; diff --git a/sqlx-core/src/mysql/migrate.rs b/sqlx-core/src/mysql/migrate.rs index 0aedac28d..66894b9c9 100644 --- a/sqlx-core/src/mysql/migrate.rs +++ b/sqlx-core/src/mysql/migrate.rs @@ -2,7 +2,7 @@ use crate::connection::ConnectOptions; use crate::error::Error; use crate::executor::Executor; use crate::migrate::MigrateError; -use crate::migrate::Migration; +use crate::migrate::{AppliedMigration, Migration}; use crate::migrate::{Migrate, MigrateDatabase}; use crate::mysql::{MySql, MySqlConnectOptions, MySqlConnection}; use crate::query::query; @@ -97,16 +97,38 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { + fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { Box::pin(async move { // language=SQL - let row = query_as( - "SELECT version, NOT success FROM _sqlx_migrations ORDER BY version DESC LIMIT 1", + let row: Option<(i64,)> = query_as( + "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1", ) .fetch_optional(self) .await?; - Ok(row) + Ok(row.map(|r| r.0)) + }) + } + + fn list_applied_migrations( + &mut self, + ) -> BoxFuture<'_, Result, MigrateError>> { + Box::pin(async move { + // language=SQL + let rows: Vec<(i64, Vec)> = + query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version") + .fetch_all(self) + .await?; + + let migrations = rows + .into_iter() + .map(|(version, checksum)| AppliedMigration { + version, + checksum: checksum.into(), + }) + .collect(); + + Ok(migrations) }) } @@ -146,30 +168,6 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn validate<'e: 'm, 'm>( - &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result<(), MigrateError>> { - Box::pin(async move { - // language=SQL - let checksum: Option> = - query_scalar("SELECT checksum FROM _sqlx_migrations WHERE version = ?") - .bind(migration.version) - .fetch_optional(self) - .await?; - - if let Some(checksum) = checksum { - return if checksum == &*migration.checksum { - Ok(()) - } else { - Err(MigrateError::VersionMismatch(migration.version)) - }; - } else { - Err(MigrateError::VersionMissing(migration.version)) - } - }) - } - fn apply<'e: 'm, 'm>( &'e mut self, migration: &'m Migration, diff --git a/sqlx-core/src/postgres/migrate.rs b/sqlx-core/src/postgres/migrate.rs index d9adcea16..b24966d2c 100644 --- a/sqlx-core/src/postgres/migrate.rs +++ b/sqlx-core/src/postgres/migrate.rs @@ -2,7 +2,7 @@ use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; use crate::executor::Executor; use crate::migrate::MigrateError; -use crate::migrate::Migration; +use crate::migrate::{AppliedMigration, Migration}; use crate::migrate::{Migrate, MigrateDatabase}; use crate::postgres::{PgConnectOptions, PgConnection, Postgres}; use crate::query::query; @@ -107,16 +107,38 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { + fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { Box::pin(async move { // language=SQL - let row = query_as( - "SELECT version, NOT success FROM _sqlx_migrations ORDER BY version DESC LIMIT 1", + let row: Option<(i64,)> = query_as( + "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1", ) .fetch_optional(self) .await?; - Ok(row) + Ok(row.map(|r| r.0)) + }) + } + + fn list_applied_migrations( + &mut self, + ) -> BoxFuture<'_, Result, MigrateError>> { + Box::pin(async move { + // language=SQL + let rows: Vec<(i64, Vec)> = + query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version") + .fetch_all(self) + .await?; + + let migrations = rows + .into_iter() + .map(|(version, checksum)| AppliedMigration { + version, + checksum: checksum.into(), + }) + .collect(); + + Ok(migrations) }) } @@ -156,30 +178,6 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn validate<'e: 'm, 'm>( - &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result<(), MigrateError>> { - Box::pin(async move { - // language=SQL - let checksum: Option> = - query_scalar("SELECT checksum FROM _sqlx_migrations WHERE version = $1") - .bind(migration.version) - .fetch_optional(self) - .await?; - - if let Some(checksum) = checksum { - return if checksum == &*migration.checksum { - Ok(()) - } else { - Err(MigrateError::VersionMismatch(migration.version)) - }; - } else { - Err(MigrateError::VersionMissing(migration.version)) - } - }) - } - fn apply<'e: 'm, 'm>( &'e mut self, migration: &'m Migration, diff --git a/sqlx-core/src/sqlite/migrate.rs b/sqlx-core/src/sqlite/migrate.rs index 3eae07cec..315c93454 100644 --- a/sqlx-core/src/sqlite/migrate.rs +++ b/sqlx-core/src/sqlite/migrate.rs @@ -2,11 +2,10 @@ use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; use crate::executor::Executor; use crate::migrate::MigrateError; -use crate::migrate::Migration; +use crate::migrate::{AppliedMigration, Migration}; use crate::migrate::{Migrate, MigrateDatabase}; use crate::query::query; use crate::query_as::query_as; -use crate::query_scalar::query_scalar; use crate::sqlite::{Sqlite, SqliteConnectOptions, SqliteConnection}; use futures_core::future::BoxFuture; use sqlx_rt::fs; @@ -74,16 +73,38 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( }) } - fn version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { + fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { Box::pin(async move { // language=SQLite - let row = query_as( - "SELECT version, NOT success FROM _sqlx_migrations ORDER BY version DESC LIMIT 1", + let row: Option<(i64,)> = query_as( + "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1", ) .fetch_optional(self) .await?; - Ok(row) + Ok(row.map(|r| r.0)) + }) + } + + fn list_applied_migrations( + &mut self, + ) -> BoxFuture<'_, Result, MigrateError>> { + Box::pin(async move { + // language=SQLite + let rows: Vec<(i64, Vec)> = + query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version") + .fetch_all(self) + .await?; + + let migrations = rows + .into_iter() + .map(|(version, checksum)| AppliedMigration { + version, + checksum: checksum.into(), + }) + .collect(); + + Ok(migrations) }) } @@ -95,30 +116,6 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( Box::pin(async move { Ok(()) }) } - fn validate<'e: 'm, 'm>( - &'e mut self, - migration: &'m Migration, - ) -> BoxFuture<'m, Result<(), MigrateError>> { - Box::pin(async move { - // language=SQL - let checksum: Option> = - query_scalar("SELECT checksum FROM _sqlx_migrations WHERE version = ?1") - .bind(migration.version) - .fetch_optional(self) - .await?; - - if let Some(checksum) = checksum { - if checksum == &*migration.checksum { - Ok(()) - } else { - Err(MigrateError::VersionMismatch(migration.version)) - } - } else { - Err(MigrateError::VersionMissing(migration.version)) - } - }) - } - fn apply<'e: 'm, 'm>( &'e mut self, migration: &'m Migration,