diff --git a/cargo-sqlx/src/database_migrator.rs b/cargo-sqlx/src/database_migrator.rs index 4b5d7b46..9c74e15f 100644 --- a/cargo-sqlx/src/database_migrator.rs +++ b/cargo-sqlx/src/database_migrator.rs @@ -1,17 +1,32 @@ +use anyhow::Result; use async_trait::async_trait; -use anyhow::{Result}; + +#[async_trait] +pub trait MigTrans { + async fn commit(self: Box) -> Result<()>; + async fn rollback(self: Box) -> Result<()>; + async fn check_if_applied(&mut self, migration: &str) -> Result; + async fn execute_migration(&mut self, migration_sql: &str) -> Result<()>; + async fn save_applied_migration(&mut self, migration_name: &str) -> Result<()>; +} #[async_trait] pub trait DatabaseMigrator { + // Misc info fn database_type(&self) -> String; - fn get_database_name(&self) -> Result; + // Features fn can_migrate_database(&self) -> bool; fn can_create_database(&self) -> bool; fn can_drop_database(&self) -> bool; + // Database creation async fn check_if_database_exists(&self, db_name: &str) -> Result; async fn create_database(&self, db_name: &str) -> Result<()>; async fn drop_database(&self, db_name: &str) -> Result<()>; + + // Migration + async fn create_migration_table(&self) -> Result<()>; + async fn begin_migration(&self) -> Result>; } diff --git a/cargo-sqlx/src/main.rs b/cargo-sqlx/src/main.rs index 7d1e9e5a..038d1280 100644 --- a/cargo-sqlx/src/main.rs +++ b/cargo-sqlx/src/main.rs @@ -6,23 +6,17 @@ use url::Url; use dotenv::dotenv; -use sqlx::postgres::PgRow; -use sqlx::Executor; -use sqlx::PgConnection; -use sqlx::PgPool; -use sqlx::Row; - use structopt::StructOpt; use anyhow::{anyhow, Context, Result}; mod database_migrator; mod postgres; -mod sqlite; +// mod sqlite; use database_migrator::DatabaseMigrator; use postgres::Postgres; -use sqlite::Sqlite; +// use sqlite::Sqlite; const MIGRATION_FOLDER: &'static str = "migrations"; @@ -69,13 +63,14 @@ async fn main() -> Result<()> { // This code is taken from: https://github.com/launchbadge/sqlx/blob/master/sqlx-macros/src/lib.rs#L63 match db_url.scheme() { #[cfg(feature = "sqlite")] - "sqlite" => run_command(&Sqlite { db_url: &db_url_raw }).await?, + // "sqlite" => run_command(&Sqlite { db_url: &db_url_raw }).await?, + "sqlite" => return Err(anyhow!("error")), #[cfg(not(feature = "sqlite"))] "sqlite" => return Err(anyhow!("Not implemented. DATABASE_URL {} has the scheme of a SQLite database but the `sqlite` feature of sqlx was not enabled", db_url)), #[cfg(feature = "postgres")] - "postgresql" | "postgres" => run_command(&Postgres { db_url: &db_url_raw }).await?, + "postgresql" | "postgres" => run_command(&Postgres::new(db_url_raw)).await?, #[cfg(not(feature = "postgres"))] "postgresql" | "postgres" => Err(anyhow!("DATABASE_URL {} has the scheme of a Postgres database but the `postgres` feature of sqlx was not enabled", db_url)), @@ -101,7 +96,7 @@ async fn run_command(db_creator: &dyn DatabaseMigrator) -> Result<()> { match opt { Opt::Migrate(command) => match command { MigrationCommand::Add { name } => add_migration_file(&name)?, - MigrationCommand::Run => run_migrations().await?, + MigrationCommand::Run => run_migrations(db_creator).await?, }, Opt::Database(command) => match command { DatabaseCommand::Create => run_create_database(db_creator).await?, @@ -221,80 +216,37 @@ fn load_migrations() -> Result> { Ok(migrations) } -async fn run_migrations() -> Result<()> { - dotenv().ok(); - let db_url = env::var("DATABASE_URL").context("Failed to find 'DATABASE_URL'")?; +async fn run_migrations(db_creator: &dyn DatabaseMigrator) -> Result<()> { + if !db_creator.can_migrate_database() { + return Err(anyhow!( + "Database migrations not implemented for {}", + db_creator.database_type() + )); + } - // if !db_creator.can_create_database() { - // return Err(anyhow!( - // "Database drop is not implemented for {}", - // db_creator.database_type() - // )); - // } - - let mut pool = PgPool::new(&db_url) - .await - .context("Failed to connect to pool")?; - - create_migration_table(&mut pool).await?; + db_creator.create_migration_table().await?; let migrations = load_migrations()?; for mig in migrations.iter() { - let mut tx = pool.begin().await?; + let mut tx = db_creator.begin_migration().await?; - if check_if_applied(&mut tx, &mig.name).await? { + if tx.check_if_applied(&mig.name).await? { println!("Already applied migration: '{}'", mig.name); continue; } println!("Applying migration: '{}'", mig.name); - tx.execute(&*mig.sql) + tx.execute_migration(&mig.sql) .await .with_context(|| format!("Failed to run migration {:?}", &mig.name))?; - save_applied_migration(&mut tx, &mig.name).await?; + tx.save_applied_migration(&mig.name) + .await + .context("Failed to insert migration")?; tx.commit().await.context("Failed")?; } Ok(()) } - -async fn create_migration_table(mut pool: &PgPool) -> Result<()> { - pool.execute( - r#" -CREATE TABLE IF NOT EXISTS __migrations ( - migration VARCHAR (255) PRIMARY KEY, - created TIMESTAMP NOT NULL DEFAULT current_timestamp -); - "#, - ) - .await - .context("Failed to create migration table")?; - - Ok(()) -} - -async fn check_if_applied(connection: &mut PgConnection, migration: &str) -> Result { - let result = sqlx::query( - "select exists(select migration from __migrations where migration = $1) as exists", - ) - .bind(migration.to_string()) - .try_map(|row: PgRow| row.try_get("exists")) - .fetch_one(connection) - .await - .context("Failed to check migration table")?; - - Ok(result) -} - -async fn save_applied_migration(pool: &mut PgConnection, migration: &str) -> Result<()> { - sqlx::query("insert into __migrations (migration) values ($1)") - .bind(migration.to_string()) - .execute(pool) - .await - .context("Failed to insert migration")?; - - Ok(()) -} diff --git a/cargo-sqlx/src/postgres.rs b/cargo-sqlx/src/postgres.rs index 4739bab6..fd5bcc93 100644 --- a/cargo-sqlx/src/postgres.rs +++ b/cargo-sqlx/src/postgres.rs @@ -1,15 +1,26 @@ +use sqlx::pool::PoolConnection; use sqlx::postgres::PgRow; use sqlx::Connect; use sqlx::PgConnection; +use sqlx::PgPool; +use sqlx::Executor; use sqlx::Row; -use async_trait::async_trait; use anyhow::{anyhow, Context, Result}; +use async_trait::async_trait; -use crate::database_migrator::DatabaseMigrator; +use crate::database_migrator::{DatabaseMigrator, MigTrans}; -pub struct Postgres<'a> { - pub db_url: &'a str, +pub struct Postgres { + pub db_url: String, +} + +impl Postgres { + pub fn new(db_url: String) -> Self { + Postgres { + db_url: db_url.clone(), + } + } } struct DbUrl<'a> { @@ -30,9 +41,8 @@ fn get_base_url<'a>(db_url: &'a str) -> Result { Ok(DbUrl { base_url, db_name }) } - #[async_trait] -impl DatabaseMigrator for Postgres<'_> { +impl DatabaseMigrator for Postgres { fn database_type(&self) -> String { "Postgres".to_string() } @@ -50,12 +60,12 @@ impl DatabaseMigrator for Postgres<'_> { } fn get_database_name(&self) -> Result { - let db_url = get_base_url(self.db_url)?; + let db_url = get_base_url(&self.db_url)?; Ok(db_url.db_name.to_string()) } async fn check_if_database_exists(&self, db_name: &str) -> Result { - let db_url = get_base_url(self.db_url)?; + let db_url = get_base_url(&self.db_url)?; let base_url = db_url.base_url; @@ -73,7 +83,7 @@ impl DatabaseMigrator for Postgres<'_> { } async fn create_database(&self, db_name: &str) -> Result<()> { - let db_url = get_base_url(self.db_url)?; + let db_url = get_base_url(&self.db_url)?; let base_url = db_url.base_url; @@ -88,7 +98,7 @@ impl DatabaseMigrator for Postgres<'_> { } async fn drop_database(&self, db_name: &str) -> Result<()> { - let db_url = get_base_url(self.db_url)?; + let db_url = get_base_url(&self.db_url)?; let base_url = db_url.base_url; @@ -101,4 +111,76 @@ impl DatabaseMigrator for Postgres<'_> { Ok(()) } + + async fn create_migration_table(&self) -> Result<()> { + let mut conn = PgConnection::connect(&self.db_url).await?; + + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS __migrations ( + migration VARCHAR (255) PRIMARY KEY, + created TIMESTAMP NOT NULL DEFAULT current_timestamp + ); + "#, + ) + .execute(&mut conn) + .await + .context("Failed to create migration table")?; + + Ok(()) + } + + async fn begin_migration(&self) -> Result> { + let pool = PgPool::new(&self.db_url) + .await + .context("Failed to connect to pool")?; + + let tx = pool.begin().await?; + + Ok(Box::new(MigTransaction { transaction: tx })) + } +} + +pub struct MigTransaction { + pub transaction: sqlx::Transaction>, +} + +#[async_trait] +impl MigTrans for MigTransaction { + async fn commit(self: Box) -> Result<()> { + self.transaction.commit().await?; + Ok(()) + } + + async fn rollback(self: Box) -> Result<()> { + self.transaction.rollback().await?; + Ok(()) + } + + async fn check_if_applied(&mut self, migration_name: &str) -> Result { + let result = sqlx::query( + "select exists(select migration from __migrations where migration = $1) as exists", + ) + .bind(migration_name.to_string()) + .try_map(|row: PgRow| row.try_get("exists")) + .fetch_one(&mut self.transaction) + .await + .context("Failed to check migration table")?; + + Ok(result) + } + + async fn execute_migration(&mut self, migration_sql: &str) -> Result<()> { + self.transaction.execute(migration_sql).await?; + Ok(()) + } + + async fn save_applied_migration(&mut self, migration_name: &str) -> Result<()> { + sqlx::query("insert into __migrations (migration) values ($1)") + .bind(migration_name.to_string()) + .execute(&mut self.transaction) + .await + .context("Failed to insert migration")?; + Ok(()) + } } diff --git a/cargo-sqlx/src/sqlite.rs b/cargo-sqlx/src/sqlite.rs index 5c2b7366..aa54de20 100644 --- a/cargo-sqlx/src/sqlite.rs +++ b/cargo-sqlx/src/sqlite.rs @@ -31,74 +31,78 @@ fn get_base_url<'a>(db_url: &'a str) -> Result { } -#[async_trait] -impl DatabaseMigrator for Sqlite<'_> { - fn database_type(&self) -> String { - "Sqlite".to_string() - } +// #[async_trait] +// impl DatabaseMigrator for Sqlite<'_> { +// fn database_type(&self) -> String { +// "Sqlite".to_string() +// } - fn can_migrate_database(&self) -> bool { - false - } +// fn can_migrate_database(&self) -> bool { +// false +// } - fn can_create_database(&self) -> bool { - false - } +// fn can_create_database(&self) -> bool { +// false +// } - fn can_drop_database(&self) -> bool { - false - } +// fn can_drop_database(&self) -> bool { +// false +// } - fn get_database_name(&self) -> Result { - let db_url = get_base_url(self.db_url)?; - Ok(db_url.db_name.to_string()) - } +// fn get_database_name(&self) -> Result { +// let db_url = get_base_url(self.db_url)?; +// Ok(db_url.db_name.to_string()) +// } - async fn check_if_database_exists(&self, db_name: &str) -> Result { - let db_url = get_base_url(self.db_url)?; +// async fn check_if_database_exists(&self, db_name: &str) -> Result { +// let db_url = get_base_url(self.db_url)?; - let base_url = db_url.base_url; +// let base_url = db_url.base_url; - let mut conn = PgConnection::connect(base_url).await?; +// let mut conn = PgConnection::connect(base_url).await?; - let result: bool = - sqlx::query("select exists(SELECT 1 from pg_database WHERE datname = $1) as exists") - .bind(db_name) - .try_map(|row: PgRow| row.try_get("exists")) - .fetch_one(&mut conn) - .await - .context("Failed to check if database exists")?; +// let result: bool = +// sqlx::query("select exists(SELECT 1 from pg_database WHERE datname = $1) as exists") +// .bind(db_name) +// .try_map(|row: PgRow| row.try_get("exists")) +// .fetch_one(&mut conn) +// .await +// .context("Failed to check if database exists")?; - Ok(result) - } +// Ok(result) +// } - async fn create_database(&self, db_name: &str) -> Result<()> { - let db_url = get_base_url(self.db_url)?; +// async fn create_database(&self, db_name: &str) -> Result<()> { +// let db_url = get_base_url(self.db_url)?; - let base_url = db_url.base_url; +// let base_url = db_url.base_url; - let mut conn = PgConnection::connect(base_url).await?; +// let mut conn = PgConnection::connect(base_url).await?; - sqlx::query(&format!("CREATE DATABASE {}", db_name)) - .execute(&mut conn) - .await - .with_context(|| format!("Failed to create database: {}", db_name))?; +// sqlx::query(&format!("CREATE DATABASE {}", db_name)) +// .execute(&mut conn) +// .await +// .with_context(|| format!("Failed to create database: {}", db_name))?; - Ok(()) - } +// Ok(()) +// } - async fn drop_database(&self, db_name: &str) -> Result<()> { - let db_url = get_base_url(self.db_url)?; +// async fn drop_database(&self, db_name: &str) -> Result<()> { +// let db_url = get_base_url(self.db_url)?; - let base_url = db_url.base_url; +// let base_url = db_url.base_url; - let mut conn = PgConnection::connect(base_url).await?; +// let mut conn = PgConnection::connect(base_url).await?; - sqlx::query(&format!("DROP DATABASE {}", db_name)) - .execute(&mut conn) - .await - .with_context(|| format!("Failed to create database: {}", db_name))?; +// sqlx::query(&format!("DROP DATABASE {}", db_name)) +// .execute(&mut conn) +// .await +// .with_context(|| format!("Failed to create database: {}", db_name))?; - Ok(()) - } -} +// Ok(()) +// } + +// async fn create_migration_table(&self) -> Result<()> { +// Err(anyhow!("Not supported")) +// } +// }