feat: create sqlx.toml format (#3383)

* feat: create `sqlx.toml` format

* feat: add support for ignored_chars config to sqlx_core::migrate

* chore: test ignored_chars with `U+FEFF` (ZWNBSP/BOM)

https://en.wikipedia.org/wiki/Byte_order_mark

* refactor: make `Config` always compiled

simplifies usage while still making parsing optional for less generated code

* refactor: add origin information to `Column`

* feat(macros): implement `type_override` and `column_override` from `sqlx.toml`

* refactor(sqlx.toml): make all keys kebab-case, create `macros.preferred-crates`

* feat: make macros aware of `macros.preferred-crates`

* feat: make `sqlx-cli` aware of `database-url-var`

* feat: teach macros about `migrate.table-name`, `migrations-dir`

* feat: teach macros about `migrate.ignored-chars`

* chore: delete unused source file `sqlx-cli/src/migration.rs`

* feat: teach `sqlx-cli` about `migrate.defaults`

* feat: teach `sqlx-cli` about `migrate.migrations-dir`

* feat: teach `sqlx-cli` about `migrate.table-name`

* feat: introduce `migrate.create-schemas`

* WIP feat: create multi-tenant database example

* fix(postgres): don't fetch `ColumnOrigin` for transparently-prepared statements

* feat: progress on axum-multi-tenant example

* feat(config): better errors for mislabeled fields

* WIP feat: filling out axum-multi-tenant example

* feat: multi-tenant example

No longer Axum-based because filling out the request routes would have distracted from the purpose of the example.

* chore(ci): test multi-tenant example

* fixup after merge

* fix(ci): enable `sqlx-toml` in CLI build for examples

* fix: CI, README for `multi-tenant`

* fix: clippy warnings

* fix: multi-tenant README

* fix: sequential versioning inference for migrations

* fix: migration versioning with explicit overrides

* fix: only warn on ambiguous crates if the invocation relies on it

* fix: remove unused imports

* fix: doctest

* fix: `sqlx mig add` behavior and tests

* fix: restore original type-checking order

* fix: deprecation warning in `tests/postgres/macros.rs`

* feat: create postgres/multi-database example

* fix: examples/postgres/multi-database

* fix: cargo fmt

* chore: add tests for config `migrate.defaults`

* fix: sqlx-cli/tests/add.rs

* feat(cli): add `--config` override to all relevant commands

* chore: run `sqlx mig add` test with `RUST_BACKTRACE=1`

* fix: properly canonicalize config path for `sqlx mig add` test

* fix: get `sqlx mig add` test passing

* fix(cli): test `migrate.ignored-chars`, fix bugs

* feat: create `macros.preferred-crates` example

* fix(examples): use workspace `sqlx`

* fix: examples

* fix(sqlite): unexpected feature flags in `type_checking.rs`

* fix: run `cargo fmt`

* fix: more example fixes

* fix(ci): preferred-crates setup

* fix(examples): enable default-features for workspace `sqlx`

* fix(examples): issues in `preferred-crates`

* chore: adjust error message for missing param type in `query!()`

* doc: mention new `sqlx.toml` configuration

* chore: add `CHANGELOG` entry

Normally I generate these when cutting the release, but I wanted to take time to editorialize this one.

* doc: fix new example titles

* refactor: make `sqlx-toml` feature non-default, improve errors

* refactor: eliminate panics in `Config` read path

* chore: remove unused `axum` dependency from new examples

* fix(config): restore fallback to default config for macros

* chore(config): remove use of `once_cell` (to match `main`)
This commit is contained in:
Austin Bonander 2025-06-30 16:34:46 -07:00 committed by GitHub
parent 764ae2f702
commit 25cbeedab4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
127 changed files with 6443 additions and 1138 deletions

View File

@ -27,7 +27,7 @@ jobs:
--bin sqlx
--release
--no-default-features
--features mysql,postgres,sqlite
--features mysql,postgres,sqlite,sqlx-toml
- uses: actions/upload-artifact@v4
with:
@ -175,6 +175,49 @@ jobs:
DATABASE_URL: postgres://postgres:password@localhost:5432/mockable-todos
run: cargo run -p sqlx-example-postgres-mockable-todos
- name: Multi-Database (Setup)
working-directory: examples/postgres/multi-database
env:
DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database
ACCOUNTS_DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database-accounts
PAYMENTS_DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database-payments
run: |
(cd accounts && sqlx db setup)
(cd payments && sqlx db setup)
sqlx db setup
- name: Multi-Database (Run)
env:
DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database
ACCOUNTS_DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database-accounts
PAYMENTS_DATABASE_URL: postgres://postgres:password@localhost:5432/multi-database-payments
run: cargo run -p sqlx-example-postgres-multi-database
- name: Multi-Tenant (Setup)
working-directory: examples/postgres/multi-tenant
env:
DATABASE_URL: postgres://postgres:password@localhost:5432/multi-tenant
run: |
(cd accounts && sqlx db setup)
(cd payments && sqlx migrate run)
sqlx migrate run
- name: Multi-Tenant (Run)
env:
DATABASE_URL: postgres://postgres:password@localhost:5432/multi-tenant
run: cargo run -p sqlx-example-postgres-multi-tenant
- name: Preferred-Crates (Setup)
working-directory: examples/postgres/preferred-crates
env:
DATABASE_URL: postgres://postgres:password@localhost:5432/preferred-crates
run: sqlx db setup
- name: Preferred-Crates (Run)
env:
DATABASE_URL: postgres://postgres:password@localhost:5432/preferred-crates
run: cargo run -p sqlx-example-postgres-preferred-crates
- name: TODOs (Setup)
working-directory: examples/postgres/todos
env:

View File

@ -13,16 +13,42 @@ This section will be replaced in subsequent alpha releases. See the Git history
### Breaking
* [[#3821]] Groundwork for 0.9.0-alpha.1
* Increased MSRV to 1.86 and set rust-version [@abonander]
* [[#3821]]: Groundwork for 0.9.0-alpha.1 [[@abonander]]
* Increased MSRV to 1.86 and set rust-version
* Deleted deprecated combination runtime+TLS features (e.g. `runtime-tokio-native-tls`)
* Deleted re-export of unstable `TransactionManager` trait in `sqlx`.
* Not technically a breaking change because it's `#[doc(hidden)]`,
but [it _will_ break SeaORM][seaorm-2600] if not proactively fixed.
* [[#3383]]: feat: create `sqlx.toml` format [[@abonander]]
* SQLx and `sqlx-cli` now support per-crate configuration files (`sqlx.toml`)
* New functionality includes, but is not limited to:
* Rename `DATABASE_URL` for a crate (for multi-database workspaces)
* Set global type overrides for the macros (supporting custom types)
* Rename or relocate the `_sqlx_migrations` table (for multiple crates using the same database)
* Set characters to ignore when hashing migrations (e.g. ignore whitespace)
* More to be implemented in future releases.
* Enable feature `sqlx-toml` to use.
* `sqlx-cli` has it enabled by default, but `sqlx` does **not**.
* Default features of library crates can be hard to completely turn off because of [feature unification],
so it's better to keep the default feature set as limited as possible.
[This is something we learned the hard way.][preferred-crates]
* Guide: see `sqlx::_config` module in documentation.
* Reference: [[Link](sqlx-core/src/config/reference.toml)]
* Examples (written for Postgres but can be adapted to other databases; PRs welcome!):
* Multiple databases using `DATABASE_URL` renaming and global type overrides: [[Link](examples/postgres/multi-database)]
* Multi-tenant database using `_sqlx_migrations` renaming and multiple schemas: [[Link](examples/postgres/multi-tenant)]
* Force use of `chrono` when `time` is enabled (e.g. when using `tower-sessions-sqlx-store`): [[Link][preferred-crates]]
* Forcing `bigdecimal` when `rust_decimal` is enabled is also shown, but problems with `chrono`/`time` are more common.
* **Breaking changes**:
* Significant changes to the `Migrate` trait
* `sqlx::migrate::resolve_blocking()` is now `#[doc(hidden)]` and thus SemVer-exempt.
[seaorm-2600]: https://github.com/SeaQL/sea-orm/issues/2600
[feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification
[preferred-crates]: examples/postgres/preferred-crates
[#3821]: https://github.com/launchbadge/sqlx/pull/3830
[#3821]: https://github.com/launchbadge/sqlx/pull/3821
[#3383]: https://github.com/launchbadge/sqlx/pull/3383
## 0.8.6 - 2025-05-19

315
Cargo.lock generated
View File

@ -4,18 +4,18 @@ version = 4
[[package]]
name = "addr2line"
version = "0.24.2"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1"
checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb"
dependencies = [
"gimli",
]
[[package]]
name = "adler2"
version = "2.0.0"
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "ahash"
@ -127,7 +127,19 @@ checksum = "db4ce4441f99dbd377ca8a8f57b698c44d0d6e712d8329b5040da5a64aa1ce73"
dependencies = [
"base64ct",
"blake2",
"password-hash",
"password-hash 0.4.2",
]
[[package]]
name = "argon2"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072"
dependencies = [
"base64ct",
"blake2",
"cpufeatures",
"password-hash 0.5.0",
]
[[package]]
@ -438,17 +450,17 @@ dependencies = [
[[package]]
name = "backtrace"
version = "0.3.74"
version = "0.3.71"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a"
checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"libc",
"miniz_oxide",
"object",
"rustc-demangle",
"windows-targets 0.52.6",
]
[[package]]
@ -742,8 +754,10 @@ checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-targets 0.52.6",
]
@ -844,6 +858,33 @@ dependencies = [
"cc",
]
[[package]]
name = "color-eyre"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55146f5e46f237f7423d74111267d4597b59b0dad0ffaf7303bce9945d843ad5"
dependencies = [
"backtrace",
"color-spantrace",
"eyre",
"indenter",
"once_cell",
"owo-colors",
"tracing-error",
]
[[package]]
name = "color-spantrace"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd6be1b2a7e382e2b98b43b2adcca6bb0e465af0bdd38123873ae61eb17a72c2"
dependencies = [
"once_cell",
"owo-colors",
"tracing-core",
"tracing-error",
]
[[package]]
name = "colorchoice"
version = "1.0.3"
@ -1276,6 +1317,16 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "eyre"
version = "0.6.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec"
dependencies = [
"indenter",
"once_cell",
]
[[package]]
name = "fastrand"
version = "1.9.0"
@ -1526,9 +1577,9 @@ dependencies = [
[[package]]
name = "gimli"
version = "0.31.1"
version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "glob"
@ -1897,6 +1948,12 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed"
[[package]]
name = "indenter"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683"
[[package]]
name = "indexmap"
version = "1.9.3"
@ -2198,11 +2255,11 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.8.2"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394"
checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
dependencies = [
"adler2",
"adler",
]
[[package]]
@ -2301,6 +2358,16 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be"
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]]
name = "num-bigint"
version = "0.4.6"
@ -2366,9 +2433,9 @@ dependencies = [
[[package]]
name = "object"
version = "0.36.7"
version = "0.32.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87"
checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441"
dependencies = [
"memchr",
]
@ -2439,6 +2506,18 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "owo-colors"
version = "3.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f"
[[package]]
name = "parking"
version = "2.2.1"
@ -2479,6 +2558,17 @@ dependencies = [
"subtle",
]
[[package]]
name = "password-hash"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
dependencies = [
"base64ct",
"rand_core",
"subtle",
]
[[package]]
name = "paste"
version = "1.0.15"
@ -3165,18 +3255,18 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.217"
version = "1.0.218"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70"
checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.217"
version = "1.0.218"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0"
checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b"
dependencies = [
"proc-macro2",
"quote",
@ -3275,6 +3365,15 @@ dependencies = [
"digest",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]]
name = "shell-words"
version = "1.1.0"
@ -3487,6 +3586,7 @@ dependencies = [
"time",
"tokio",
"tokio-stream",
"toml",
"tracing",
"url",
"uuid",
@ -3508,7 +3608,7 @@ name = "sqlx-example-postgres-axum-social"
version = "0.1.0"
dependencies = [
"anyhow",
"argon2",
"argon2 0.4.1",
"axum",
"dotenvy",
"rand",
@ -3582,6 +3682,123 @@ dependencies = [
"tokio",
]
[[package]]
name = "sqlx-example-postgres-multi-database"
version = "0.9.0-alpha.1"
dependencies = [
"color-eyre",
"dotenvy",
"rand",
"rust_decimal",
"sqlx",
"sqlx-example-postgres-multi-database-accounts",
"sqlx-example-postgres-multi-database-payments",
"tokio",
"tracing-subscriber",
]
[[package]]
name = "sqlx-example-postgres-multi-database-accounts"
version = "0.1.0"
dependencies = [
"argon2 0.5.3",
"password-hash 0.5.0",
"rand",
"serde",
"sqlx",
"thiserror 1.0.69",
"time",
"tokio",
"uuid",
]
[[package]]
name = "sqlx-example-postgres-multi-database-payments"
version = "0.1.0"
dependencies = [
"rust_decimal",
"sqlx",
"sqlx-example-postgres-multi-database-accounts",
"time",
"uuid",
]
[[package]]
name = "sqlx-example-postgres-multi-tenant"
version = "0.9.0-alpha.1"
dependencies = [
"color-eyre",
"dotenvy",
"rand",
"rust_decimal",
"sqlx",
"sqlx-example-postgres-multi-tenant-accounts",
"sqlx-example-postgres-multi-tenant-payments",
"tokio",
"tracing-subscriber",
]
[[package]]
name = "sqlx-example-postgres-multi-tenant-accounts"
version = "0.1.0"
dependencies = [
"argon2 0.5.3",
"password-hash 0.5.0",
"rand",
"serde",
"sqlx",
"thiserror 1.0.69",
"time",
"tokio",
"uuid",
]
[[package]]
name = "sqlx-example-postgres-multi-tenant-payments"
version = "0.1.0"
dependencies = [
"rust_decimal",
"sqlx",
"sqlx-example-postgres-multi-tenant-accounts",
"time",
"uuid",
]
[[package]]
name = "sqlx-example-postgres-preferred-crates"
version = "0.9.0-alpha.1"
dependencies = [
"anyhow",
"chrono",
"dotenvy",
"serde",
"sqlx",
"sqlx-example-postgres-preferred-crates-uses-rust-decimal",
"sqlx-example-postgres-preferred-crates-uses-time",
"tokio",
"uuid",
]
[[package]]
name = "sqlx-example-postgres-preferred-crates-uses-rust-decimal"
version = "0.9.0-alpha.1"
dependencies = [
"chrono",
"rust_decimal",
"sqlx",
"uuid",
]
[[package]]
name = "sqlx-example-postgres-preferred-crates-uses-time"
version = "0.9.0-alpha.1"
dependencies = [
"serde",
"sqlx",
"time",
"uuid",
]
[[package]]
name = "sqlx-example-postgres-todos"
version = "0.1.0"
@ -4047,6 +4264,16 @@ dependencies = [
"syn 2.0.96",
]
[[package]]
name = "thread_local"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
dependencies = [
"cfg-if",
"once_cell",
]
[[package]]
name = "time"
version = "0.3.37"
@ -4264,6 +4491,42 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-error"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b1581020d7a273442f5b45074a6a57d5757ad0a47dac0e9f0bd57b81936f3db"
dependencies = [
"tracing",
"tracing-subscriber",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
dependencies = [
"nu-ansi-term",
"sharded-slab",
"smallvec",
"thread_local",
"tracing-core",
"tracing-log",
]
[[package]]
@ -4392,9 +4655,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.11.1"
version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b913a3b5fe84142e269d63cc62b64319ccaf89b748fc31fe025177f767a756c4"
checksum = "e0f540e3240398cce6128b64ba83fdbdd86129c16a3aa1a3a252efd66eb3d587"
dependencies = [
"serde",
]
@ -4441,6 +4704,12 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "valuable"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "value-bag"
version = "1.10.0"

View File

@ -16,8 +16,11 @@ members = [
"examples/postgres/files",
"examples/postgres/json",
"examples/postgres/listen",
"examples/postgres/todos",
"examples/postgres/mockable-todos",
"examples/postgres/multi-database",
"examples/postgres/multi-tenant",
"examples/postgres/preferred-crates",
"examples/postgres/todos",
"examples/postgres/transaction",
"examples/sqlite/todos",
]
@ -51,7 +54,7 @@ repository.workspace = true
rust-version.workspace = true
[package.metadata.docs.rs]
features = ["all-databases", "_unstable-all-types", "sqlite-preupdate-hook"]
features = ["all-databases", "_unstable-all-types", "_unstable-doc", "sqlite-preupdate-hook"]
rustdoc-args = ["--cfg", "docsrs"]
[features]
@ -61,6 +64,9 @@ derive = ["sqlx-macros/derive"]
macros = ["derive", "sqlx-macros/macros"]
migrate = ["sqlx-core/migrate", "sqlx-macros?/migrate", "sqlx-mysql?/migrate", "sqlx-postgres?/migrate", "sqlx-sqlite?/migrate"]
# Enable parsing of `sqlx.toml` for configuring macros and migrations.
sqlx-toml = ["sqlx-core/sqlx-toml", "sqlx-macros?/sqlx-toml"]
# intended mainly for CI and docs
all-databases = ["mysql", "sqlite", "postgres", "any"]
_unstable-all-types = [
@ -76,6 +82,8 @@ _unstable-all-types = [
"bit-vec",
"bstr"
]
# Render documentation that wouldn't otherwise be shown (e.g. `sqlx_core::config`).
_unstable-doc = []
# Base runtime features without TLS
runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std"]
@ -132,7 +140,7 @@ sqlx-postgres = { version = "=0.9.0-alpha.1", path = "sqlx-postgres" }
sqlx-sqlite = { version = "=0.9.0-alpha.1", path = "sqlx-sqlite" }
# Facade crate (for reference from sqlx-cli)
sqlx = { version = "=0.9.0-alpha.1", path = ".", default-features = false }
sqlx = { version = "=0.9.0-alpha.1", path = "." }
# Common type integrations shared by multiple driver crates.
# These are optional unless enabled in a workspace crate.

View File

@ -0,0 +1,36 @@
[package]
name = "sqlx-example-postgres-multi-database"
version.workspace = true
license.workspace = true
edition.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
authors.workspace = true
[dependencies]
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
color-eyre = "0.6.3"
dotenvy = "0.15.7"
tracing-subscriber = "0.3.19"
rust_decimal = "1.36.0"
rand = "0.8.5"
[dependencies.sqlx]
# version = "0.9.0"
workspace = true
features = ["runtime-tokio", "postgres", "migrate", "sqlx-toml"]
[dependencies.accounts]
path = "accounts"
package = "sqlx-example-postgres-multi-database-accounts"
[dependencies.payments]
path = "payments"
package = "sqlx-example-postgres-multi-database-payments"
[lints]
workspace = true

View File

@ -0,0 +1,62 @@
# Using Multiple Databases with `sqlx.toml`
This example project involves three crates, each owning a different schema in one database,
with their own set of migrations.
* The main crate, a simple binary simulating the action of a REST API.
* Owns the `public` schema (tables are referenced unqualified).
* Migrations are moved to `src/migrations` using config key `migrate.migrations-dir`
to visually separate them from the subcrate folders.
* `accounts`: a subcrate simulating a reusable account-management crate.
* Owns schema `accounts`.
* `payments`: a subcrate simulating a wrapper for a payments API.
* Owns schema `payments`.
## Note: Schema-Qualified Names
This example uses schema-qualified names everywhere for clarity.
It can be tempting to change the `search_path` of the connection (MySQL, Postgres) to eliminate the need for schema
prefixes, but this can cause some really confusing issues when names conflict.
This example will generate a `_sqlx_migrations` table in three different schemas; if `search_path` is set
to `public,accounts,payments` and the migrator for the main application attempts to reference the table unqualified,
it would throw an error.
# Setup
This example requires running three different sets of migrations.
Ensure `sqlx-cli` is installed with Postgres and `sqlx.toml` support:
```
cargo install sqlx-cli --features postgres,sqlx-toml
```
Start a Postgres server (shown here using Docker, `run` command also works with `podman`):
```
docker run -d -e POSTGRES_PASSWORD=password -p 5432:5432 --name postgres postgres:latest
```
Create `.env` with the various database URLs or set them in your shell environment:
```
DATABASE_URL=postgres://postgres:password@localhost/example-multi-database
ACCOUNTS_DATABASE_URL=postgres://postgres:password@localhost/example-multi-database-accounts
PAYMENTS_DATABASE_URL=postgres://postgres:password@localhost/example-multi-database-payments
```
Run the following commands:
```
(cd accounts && sqlx db setup)
(cd payments && sqlx db setup)
sqlx db setup
```
It is an open question how to make this more convenient; `sqlx-cli` could gain a `--recursive` flag that checks
subdirectories for `sqlx.toml` files, but that would only work for crates within the same workspace. If the `accounts`
and `payments` crates were instead crates.io dependencies, we would need Cargo's help to resolve that information.
An issue has been opened for discussion: <https://github.com/launchbadge/sqlx/issues/3761>

View File

@ -0,0 +1,22 @@
[package]
name = "sqlx-example-postgres-multi-database-accounts"
version = "0.1.0"
edition = "2021"
[dependencies]
sqlx = { workspace = true, features = ["postgres", "time", "uuid", "macros", "sqlx-toml"] }
tokio = { version = "1", features = ["rt", "sync"] }
argon2 = { version = "0.5.3", features = ["password-hash"] }
password-hash = { version = "0.5", features = ["std"] }
uuid = { version = "1", features = ["serde"] }
thiserror = "1"
rand = "0.8"
time = { version = "0.3.37", features = ["serde"] }
serde = { version = "1.0.218", features = ["derive"] }
[dev-dependencies]
sqlx = { workspace = true, features = ["runtime-tokio"] }

View File

@ -0,0 +1,30 @@
-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging
-- and auditing.
--
-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which
-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do
--
-- select trigger_updated_at('<table name>');
--
-- after a `CREATE TABLE`.
create or replace function set_updated_at()
returns trigger as
$$
begin
NEW.updated_at = now();
return NEW;
end;
$$ language plpgsql;
create or replace function trigger_updated_at(tablename regclass)
returns void as
$$
begin
execute format('CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON %s
FOR EACH ROW
WHEN (OLD is distinct from NEW)
EXECUTE FUNCTION set_updated_at();', tablename);
end;
$$ language plpgsql;

View File

@ -0,0 +1,10 @@
create table account
(
account_id uuid primary key default gen_random_uuid(),
email text unique not null,
password_hash text not null,
created_at timestamptz not null default now(),
updated_at timestamptz
);
select trigger_updated_at('account');

View File

@ -0,0 +1,6 @@
create table session
(
session_token text primary key, -- random alphanumeric string
account_id uuid not null references account (account_id),
created_at timestamptz not null default now()
);

View File

@ -0,0 +1,10 @@
[common]
database-url-var = "ACCOUNTS_DATABASE_URL"
[macros.table-overrides.'account']
'account_id' = "crate::AccountId"
'password_hash' = "sqlx::types::Text<password_hash::PasswordHashString>"
[macros.table-overrides.'session']
'session_token' = "crate::SessionToken"
'account_id' = "crate::AccountId"

View File

@ -0,0 +1,293 @@
use argon2::{password_hash, Argon2, PasswordHasher, PasswordVerifier};
use password_hash::PasswordHashString;
use rand::distributions::{Alphanumeric, DistString};
use sqlx::PgPool;
use std::sync::Arc;
use uuid::Uuid;
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
use tokio::sync::Semaphore;
#[derive(sqlx::Type, Copy, Clone, Debug, serde::Deserialize, serde::Serialize)]
#[sqlx(transparent)]
pub struct AccountId(pub Uuid);
#[derive(sqlx::Type, Clone, Debug, serde::Deserialize, serde::Serialize)]
#[sqlx(transparent)]
pub struct SessionToken(pub String);
pub struct Session {
pub account_id: AccountId,
pub session_token: SessionToken,
}
#[derive(Clone)]
pub struct AccountsManager {
/// To prevent confusion, each crate manages its own database connection pool.
pool: PgPool,
/// Controls how many blocking tasks are allowed to run concurrently for Argon2 hashing.
///
/// ### Motivation
/// Tokio blocking tasks are generally not designed for CPU-bound work.
///
/// If no threads are idle, Tokio will automatically spawn new ones to handle
/// new blocking tasks up to a very high limit--512 by default.
///
/// This is because blocking tasks are expected to spend their time *blocked*, e.g. on
/// blocking I/O, and thus not consume CPU resources or require a lot of context switching.
///
/// This strategy is not the most efficient way to use threads for CPU-bound work, which
/// should schedule work to a fixed number of threads to minimize context switching
/// and memory usage (each new thread needs significant space allocated for its stack).
///
/// We can work around this by using a purpose-designed thread-pool, like Rayon,
/// but we still have the problem that those APIs usually are not designed to support `async`,
/// so we end up needing blocking tasks anyway, or implementing our own work queue using
/// channels. Rayon also does not shut down idle worker threads.
///
/// `block_in_place` is not a silver bullet, either, as it simply uses `spawn_blocking`
/// internally to take over from the current thread while it is executing blocking work.
/// This also prevents futures from being polled concurrently in the current task.
///
/// We can lower the limit for blocking threads when creating the runtime, but this risks
/// starving other blocking tasks that are being created by the application or the Tokio
/// runtime itself
/// (which are used for `tokio::fs`, stdio, resolving of hostnames by `ToSocketAddrs`, etc.).
///
/// Instead, we can just use a Semaphore to limit how many blocking tasks are spawned at once,
/// emulating the behavior of a thread pool like Rayon without needing any additional crates.
hashing_semaphore: Arc<Semaphore>,
}
#[derive(Debug, thiserror::Error)]
pub enum CreateAccountError {
#[error("error creating account: email in-use")]
EmailInUse,
#[error("error creating account")]
General(
#[source]
#[from]
GeneralError,
),
}
#[derive(Debug, thiserror::Error)]
pub enum CreateSessionError {
#[error("unknown email")]
UnknownEmail,
#[error("invalid password")]
InvalidPassword,
#[error("authentication error")]
General(
#[source]
#[from]
GeneralError,
),
}
#[derive(Debug, thiserror::Error)]
pub enum GeneralError {
#[error("database error")]
Sqlx(
#[source]
#[from]
sqlx::Error,
),
#[error("error hashing password")]
PasswordHash(
#[source]
#[from]
password_hash::Error,
),
#[error("task panicked")]
Task(
#[source]
#[from]
tokio::task::JoinError,
),
}
impl AccountsManager {
    /// Connect to the accounts database, apply pending migrations, and build the manager.
    ///
    /// `max_hashing_threads` caps how many Argon2 hashing jobs may occupy Tokio
    /// blocking threads at once (see the docs on `hashing_semaphore`).
    pub async fn setup(
        opts: PgConnectOptions,
        max_hashing_threads: usize,
    ) -> Result<Self, GeneralError> {
        // This should be configurable by the caller, but for simplicity, it's not.
        let pool = PgPoolOptions::new()
            .max_connections(5)
            .connect_with(opts)
            .await?;

        sqlx::migrate!()
            .run(&pool)
            .await
            .map_err(sqlx::Error::from)?;

        Ok(AccountsManager {
            pool,
            hashing_semaphore: Semaphore::new(max_hashing_threads).into(),
        })
    }

    /// Hash `password` with Argon2 on a blocking thread, gated by the semaphore.
    async fn hash_password(&self, password: String) -> Result<PasswordHashString, GeneralError> {
        let guard = self
            .hashing_semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("BUG: this semaphore should not be closed");

        // We transfer ownership to the blocking task and back to ensure Tokio doesn't spawn
        // excess threads.
        let (_guard, res) = tokio::task::spawn_blocking(move || {
            let salt = password_hash::SaltString::generate(rand::thread_rng());
            (
                guard,
                Argon2::default()
                    .hash_password(password.as_bytes(), &salt)
                    .map(|hash| hash.serialize()),
            )
        })
        .await?;

        Ok(res?)
    }

    /// Verify `password` against `hash`, distinguishing a wrong password
    /// (`InvalidPassword`) from operational failures (`General`).
    async fn verify_password(
        &self,
        password: String,
        hash: PasswordHashString,
    ) -> Result<(), CreateSessionError> {
        let guard = self
            .hashing_semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("BUG: this semaphore should not be closed");

        let (_guard, res) = tokio::task::spawn_blocking(move || {
            (
                guard,
                Argon2::default().verify_password(password.as_bytes(), &hash.password_hash()),
            )
        })
        .await
        .map_err(GeneralError::from)?;

        if let Err(password_hash::Error::Password) = res {
            return Err(CreateSessionError::InvalidPassword);
        }

        // Any other `password_hash` error is an operational failure, not a mismatch.
        res.map_err(GeneralError::from)?;

        Ok(())
    }

    /// Create a new account, returning its generated ID.
    ///
    /// Returns [`CreateAccountError::EmailInUse`] if `email` is already registered.
    pub async fn create(
        &self,
        email: &str,
        password: String,
    ) -> Result<AccountId, CreateAccountError> {
        // Hash password whether the account exists or not to make it harder
        // to tell the difference in the timing.
        let hash = self.hash_password(password).await?;

        // Thanks to `sqlx.toml`, `account_id` maps to `AccountId`
        sqlx::query_scalar!(
            // language=PostgreSQL
            "insert into account(email, password_hash) \
             values ($1, $2) \
             returning account_id",
            email,
            hash.as_str(),
        )
        .fetch_one(&self.pool)
        .await
        .map_err(|e| {
            // FIX: a duplicate email trips the UNIQUE constraint on `email`, which
            // Postgres names `account_email_key` (`<table>_<column>_key`). The previous
            // check against `account_account_id_key` could never match (the PK constraint
            // is named `account_pkey`), so duplicate emails surfaced as `General` errors
            // instead of `EmailInUse`.
            if e.as_database_error().and_then(|dbe| dbe.constraint()) == Some("account_email_key") {
                CreateAccountError::EmailInUse
            } else {
                GeneralError::from(e).into()
            }
        })
    }

    /// Verify credentials and create a session, returning the new token.
    pub async fn create_session(
        &self,
        email: &str,
        password: String,
    ) -> Result<Session, CreateSessionError> {
        let mut txn = self.pool.begin().await.map_err(GeneralError::from)?;

        // To save a round-trip to the database, we'll speculatively insert the session token
        // at the same time as we're looking up the password hash.
        //
        // This does nothing until the transaction is actually committed.
        let session_token = SessionToken::generate();

        // Thanks to `sqlx.toml`:
        // * `account_id` maps to `AccountId`
        // * `password_hash` maps to `Text<PasswordHashString>`
        // * `session_token` maps to `SessionToken`
        let maybe_account = sqlx::query!(
            // language=PostgreSQL
            "with account as (
                select account_id, password_hash \
                from account \
                where email = $1
            ), session as (
                insert into session(session_token, account_id)
                select $2, account_id
                from account
            )
            select account.account_id, account.password_hash from account",
            email,
            session_token.0
        )
        .fetch_optional(&mut *txn)
        .await
        .map_err(GeneralError::from)?;

        let Some(account) = maybe_account else {
            // Hash the password whether the account exists or not to hide the difference in timing.
            self.hash_password(password)
                .await
                .map_err(GeneralError::from)?;
            return Err(CreateSessionError::UnknownEmail);
        };

        self.verify_password(password, account.password_hash.into_inner())
            .await?;

        txn.commit().await.map_err(GeneralError::from)?;

        Ok(Session {
            account_id: account.account_id,
            session_token,
        })
    }

    /// Resolve a session token to its account ID, or `None` if no such session exists.
    pub async fn auth_session(
        &self,
        session_token: &str,
    ) -> Result<Option<AccountId>, GeneralError> {
        sqlx::query_scalar!(
            "select account_id from session where session_token = $1",
            session_token
        )
        .fetch_optional(&self.pool)
        .await
        .map_err(GeneralError::from)
    }
}
impl SessionToken {
const LEN: usize = 32;
fn generate() -> Self {
SessionToken(Alphanumeric.sample_string(&mut rand::thread_rng(), Self::LEN))
}
}

View File

@ -0,0 +1,20 @@
[package]
name = "sqlx-example-postgres-multi-database-payments"
version = "0.1.0"
edition = "2021"
[dependencies]
sqlx = { workspace = true, features = ["postgres", "time", "uuid", "rust_decimal", "sqlx-toml"] }
rust_decimal = "1.36.0"
time = "0.3.37"
uuid = "1.12.1"
[dependencies.accounts]
path = "../accounts"
package = "sqlx-example-postgres-multi-database-accounts"
[dev-dependencies]
sqlx = { workspace = true, features = ["runtime-tokio"] }

View File

@ -0,0 +1,30 @@
-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging
-- and auditing.
--
-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which
-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do
--
-- select trigger_updated_at('<table name>');
--
-- after a `CREATE TABLE`.
-- Trigger function: stamps `updated_at` with the current timestamp on every UPDATE.
create or replace function set_updated_at()
    returns trigger as
$$
begin
    NEW.updated_at = now();
    return NEW;
end;
$$ language plpgsql;

-- Convenience helper: attaches the `set_updated_at` trigger to the given table.
-- The WHEN clause skips no-op updates so `updated_at` only changes on real modifications.
create or replace function trigger_updated_at(tablename regclass)
    returns void as
$$
begin
    execute format('CREATE TRIGGER set_updated_at
        BEFORE UPDATE
        ON %s
        FOR EACH ROW
        WHEN (OLD is distinct from NEW)
        EXECUTE FUNCTION set_updated_at();', tablename);
end;
$$ language plpgsql;

View File

@ -0,0 +1,59 @@
-- `payments::PaymentStatus`
--
-- Historically at LaunchBadge we preferred not to define enums on the database side because it can be annoying
-- and error-prone to keep them in-sync with the application.
-- Instead, we let the application define the enum and just have the database store a compact representation of it.
-- This is mostly a matter of taste, however.
--
-- For the purposes of this example, we're using an in-database enum because this is a common use-case
-- for needing type overrides.
create type payment_status as enum (
'pending',
'created',
'success',
'failed'
);
create table payment
(
payment_id uuid primary key default gen_random_uuid(),
-- Since `account` is in a separate database, we can't foreign-key to it.
account_id uuid not null,
status payment_status not null,
-- ISO 4217 currency code (https://en.wikipedia.org/wiki/ISO_4217#List_of_ISO_4217_currency_codes)
--
-- This *could* be an ENUM of currency codes, but constraining this to a set of known values in the database
-- would be annoying to keep up to date as support for more currencies is added.
--
-- Consider also if support for cryptocurrencies is desired; those are not covered by ISO 4217.
--
-- Though ISO 4217 is a three-character code, `TEXT`, `VARCHAR` and `CHAR(N)`
-- all use the same storage format in Postgres. Any constraint against the length of this field
-- would purely be a sanity check.
currency text not null,
-- There's an endless debate about what type should be used to represent currency amounts.
--
-- Postgres has the `MONEY` type, but the fractional precision depends on a C locale setting and the type is mostly
-- optimized for storing USD, or other currencies with a minimum fraction of 1 cent.
--
-- NEVER use `FLOAT` or `DOUBLE`. IEEE-754 floating point has round-off and precision errors that make it wholly
-- unsuitable for representing real money amounts.
--
-- `NUMERIC`, being an arbitrary-precision decimal format, is a safe default choice that can support any currency,
-- and so is what we've chosen here.
amount NUMERIC not null,
-- Payments almost always take place through a third-party vendor (e.g. PayPal, Stripe, etc.),
-- so imagine this is an identifier string for this payment in such a vendor's systems.
--
-- For privacy and security reasons, payment and personally-identifying information
-- (e.g. credit card numbers, bank account numbers, billing addresses) should only be stored with the vendor
-- unless there is a good reason otherwise.
external_payment_id text,
created_at timestamptz not null default now(),
updated_at timestamptz
);
select trigger_updated_at('payment');

View File

@ -0,0 +1,9 @@
[common]
database-url-var = "PAYMENTS_DATABASE_URL"
[macros.table-overrides.'payment']
'payment_id' = "crate::PaymentId"
'account_id' = "accounts::AccountId"
[macros.type-overrides]
'payment_status' = "crate::PaymentStatus"

View File

@ -0,0 +1,127 @@
use accounts::{AccountId, AccountsManager};
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
use sqlx::{Acquire, PgConnection, PgPool, Postgres};
use time::OffsetDateTime;
use uuid::Uuid;
/// Strongly-typed wrapper for `payment.payment_id`; mapped via `sqlx.toml` table-overrides.
#[derive(sqlx::Type, Copy, Clone, Debug)]
#[sqlx(transparent)]
pub struct PaymentId(pub Uuid);
/// Mirrors the Postgres enum `payment_status`; mapped via `sqlx.toml` type-overrides.
#[derive(sqlx::Type, Copy, Clone, Debug)]
#[sqlx(type_name = "payment_status")]
#[sqlx(rename_all = "snake_case")]
pub enum PaymentStatus {
    Pending,
    Created,
    Success,
    Failed,
}
// Users often assume that they need `#[derive(FromRow)]` to use `query_as!()`,
// then are surprised when the derive's control attributes have no effect.
// The macros currently do *not* use the `FromRow` trait at all.
// Support for `FromRow` is planned, but would require significant changes to the macros.
// See https://github.com/launchbadge/sqlx/issues/514 for details.
/// One row of the `payment` table.
#[derive(Clone, Debug)]
pub struct Payment {
    pub payment_id: PaymentId,
    /// ID of the owning account; not a foreign key because `account` lives in a separate database.
    pub account_id: AccountId,
    pub status: PaymentStatus,
    /// ISO 4217 currency code (not constrained to a known set; see the migration notes).
    pub currency: String,
    // `rust_decimal::Decimal` has more than enough precision for any real-world amount of money.
    pub amount: rust_decimal::Decimal,
    /// This payment's identifier in the external vendor's system, once known.
    pub external_payment_id: Option<String>,
    pub created_at: OffsetDateTime,
    pub updated_at: Option<OffsetDateTime>,
}
/// Wraps the connection pool for the payments database.
pub struct PaymentsManager {
    pool: PgPool,
}
impl PaymentsManager {
    /// Connect to the payments database, run this crate's migrations, and build the manager.
    pub async fn setup(opts: PgConnectOptions) -> sqlx::Result<Self> {
        let pool = PgPoolOptions::new()
            .max_connections(5)
            .connect_with(opts)
            .await?;

        sqlx::migrate!().run(&pool).await?;

        Ok(Self { pool })
    }

    /// Record a new payment for `account_id` and return the created row.
    ///
    /// # Note
    /// For simplicity, this does not ensure that `account_id` actually exists.
    pub async fn create(
        &self,
        account_id: AccountId,
        currency: &str,
        amount: rust_decimal::Decimal,
    ) -> sqlx::Result<Payment> {
        // Check-out a connection to avoid paying the overhead of acquiring one for each call.
        let mut conn = self.pool.acquire().await?;

        // Imagine this method does more than just create a record in the database;
        // maybe it actually initiates the payment with a third-party vendor, like Stripe.
        //
        // We need to ensure that we can link the payment in the vendor's systems back to a record
        // in ours, even if any of the following happens:
        // * The application dies before storing the external payment ID in the database
        // * We lose the connection to the database while trying to commit a transaction
        // * The database server dies while committing the transaction
        //
        // Thus, we create the payment in three atomic phases:
        // * We create the payment record in our system and commit it.
        // * We create the payment in the vendor's system with our payment ID attached.
        // * We update our payment record with the vendor's payment ID.
        let payment_id = sqlx::query_scalar!(
            "insert into payment(account_id, status, currency, amount) \
             values ($1, $2, $3, $4) \
             returning payment_id",
            // The database doesn't give us enough information to correctly typecheck `AccountId` here.
            // We have to insert the UUID directly.
            account_id.0,
            PaymentStatus::Pending,
            currency,
            amount,
        )
        .fetch_one(&mut *conn)
        .await?;

        // We then create the record with the payment vendor...
        let external_payment_id = "foobar1234";

        // Then we store the external payment ID and update the payment status.
        //
        // NOTE: use caution with `select *` or `returning *`;
        // the order of columns gets baked into the binary, so if it changes between compile time and
        // run-time, you may run into errors.
        let payment = sqlx::query_as!(
            Payment,
            "update payment \
             set status = $1, external_payment_id = $2 \
             where payment_id = $3 \
             returning *",
            PaymentStatus::Created,
            external_payment_id,
            payment_id.0,
        )
        .fetch_one(&mut *conn)
        .await?;

        Ok(payment)
    }

    /// Fetch a payment by ID, or `None` if it does not exist.
    pub async fn get(&self, payment_id: PaymentId) -> sqlx::Result<Option<Payment>> {
        sqlx::query_as!(
            Payment,
            // see note above about `select *`
            "select * from payment where payment_id = $1",
            payment_id.0
        )
        .fetch_optional(&self.pool)
        .await
    }
}

View File

@ -0,0 +1,3 @@
[migrate]
# Move `migrations/` to under `src/` to separate it from subcrates.
migrations-dir = "src/migrations"

View File

@ -0,0 +1,120 @@
use accounts::AccountsManager;
use color_eyre::eyre;
use color_eyre::eyre::{Context, OptionExt};
use payments::PaymentsManager;
use rand::distributions::{Alphanumeric, DistString};
use sqlx::Connection;
#[tokio::main]
async fn main() -> eyre::Result<()> {
    color_eyre::install()?;

    // Best-effort load of `.env`; variables may instead come from the process environment.
    let _ = dotenvy::dotenv();

    tracing_subscriber::fmt::init();

    // Connection for the main (purchases) database; accounts and payments each use their own.
    let mut conn = sqlx::PgConnection::connect(
        // `env::var()` doesn't include the variable name in the error.
        &dotenvy::var("DATABASE_URL").wrap_err("DATABASE_URL must be set")?,
    )
    .await
    .wrap_err("could not connect to database")?;

    let accounts = AccountsManager::setup(
        dotenvy::var("ACCOUNTS_DATABASE_URL")
            .wrap_err("ACCOUNTS_DATABASE_URL must be set")?
            .parse()
            .wrap_err("error parsing ACCOUNTS_DATABASE_URL")?,
        1,
    )
    .await
    .wrap_err("error initializing AccountsManager")?;

    let payments = PaymentsManager::setup(
        dotenvy::var("PAYMENTS_DATABASE_URL")
            .wrap_err("PAYMENTS_DATABASE_URL must be set")?
            .parse()
            .wrap_err("error parsing PAYMENTS_DATABASE_URL")?,
    )
    .await
    .wrap_err("error initializing PaymentsManager")?;

    // For simplicity's sake, imagine each of these might be invoked by different request routes
    // in a web application.

    // POST /account
    let user_email = format!("user{}@example.com", rand::random::<u32>());
    let user_password = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);

    // Requires an externally managed transaction in case any application-specific records
    // should be created after the actual account record.
    //
    // NOTE(review): in this variant `accounts.create()` takes no transaction argument,
    // so this begin/commit pair brackets no statements — presumably kept for parity with
    // the multi-tenant example. TODO: confirm whether it can be removed.
    let mut txn = conn.begin().await?;

    let account_id = accounts
        // Takes ownership of the password string because it's sent to another thread for hashing.
        .create(&user_email, user_password.clone())
        .await
        .wrap_err("error creating account")?;

    txn.commit().await?;

    println!(
        "created account ID: {}, email: {user_email:?}, password: {user_password:?}",
        account_id.0
    );

    // POST /session
    // Log the user in.
    let session = accounts
        .create_session(&user_email, user_password.clone())
        .await
        .wrap_err("error creating session")?;

    // After this, session.session_token should then be returned to the client,
    // either in the response body or a `Set-Cookie` header.
    println!("created session token: {}", session.session_token.0);

    // POST /purchase
    // The client would then pass the session token to authenticated routes.
    // In this route, they're making some kind of purchase.

    // First, we need to ensure the session is valid.
    // `session.session_token` would be passed by the client in whatever way is appropriate.
    //
    // For a pure REST API, consider an `Authorization: Bearer` header instead of the request body.
    // With Axum, you can create a reusable extractor that reads the header and validates the session
    // by implementing `FromRequestParts`.
    //
    // For APIs where the browser is intended to be the primary client, using a session cookie
    // may be easier for the frontend. By setting the cookie with `HttpOnly: true`,
    // it's impossible for malicious Javascript on the client to access and steal the session token.
    let account_id = accounts
        .auth_session(&session.session_token.0)
        .await
        .wrap_err("error authenticating session")?
        .ok_or_eyre("session does not exist")?;

    let purchase_amount: rust_decimal::Decimal = "12.34".parse().unwrap();

    // Then, because the user is making a purchase, we record a payment.
    let payment = payments
        .create(account_id, "USD", purchase_amount)
        .await
        .wrap_err("error creating payment")?;

    println!("created payment: {payment:?}");

    // Finally, link the payment to a purchase in the main database.
    let purchase_id = sqlx::query_scalar!(
        "insert into purchase(account_id, payment_id, amount) values ($1, $2, $3) returning purchase_id",
        account_id.0,
        payment.payment_id.0,
        purchase_amount
    )
    .fetch_one(&mut conn)
    .await
    .wrap_err("error creating purchase")?;

    println!("created purchase: {purchase_id}");

    conn.close().await?;

    Ok(())
}

View File

@ -0,0 +1,30 @@
-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging
-- and auditing.
--
-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which
-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do
--
-- select trigger_updated_at('<table name>');
--
-- after a `CREATE TABLE`.
create or replace function set_updated_at()
returns trigger as
$$
begin
NEW.updated_at = now();
return NEW;
end;
$$ language plpgsql;
create or replace function trigger_updated_at(tablename regclass)
returns void as
$$
begin
execute format('CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON %s
FOR EACH ROW
WHEN (OLD is distinct from NEW)
EXECUTE FUNCTION set_updated_at();', tablename);
end;
$$ language plpgsql;

View File

@ -0,0 +1,11 @@
create table purchase
(
purchase_id uuid primary key default gen_random_uuid(),
account_id uuid not null,
payment_id uuid not null,
amount numeric not null,
created_at timestamptz not null default now(),
updated_at timestamptz
);
select trigger_updated_at('purchase');

View File

@ -0,0 +1,36 @@
[package]
name = "sqlx-example-postgres-multi-tenant"
version.workspace = true
license.workspace = true
edition.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
authors.workspace = true
[dependencies]
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
color-eyre = "0.6.3"
dotenvy = "0.15.7"
tracing-subscriber = "0.3.19"
rust_decimal = "1.36.0"
rand = "0.8.5"
[dependencies.sqlx]
# version = "0.9.0"
workspace = true
features = ["runtime-tokio", "postgres", "migrate", "sqlx-toml"]
[dependencies.accounts]
path = "accounts"
package = "sqlx-example-postgres-multi-tenant-accounts"
[dependencies.payments]
path = "payments"
package = "sqlx-example-postgres-multi-tenant-payments"
[lints]
workspace = true

View File

@ -0,0 +1,60 @@
# Multi-tenant Databases with `sqlx.toml`
This example project involves three crates, each owning a different schema in one database,
with their own set of migrations.
* The main crate, a simple binary simulating the action of a REST API.
* Owns the `public` schema (tables are referenced unqualified).
* Migrations are moved to `src/migrations` using config key `migrate.migrations-dir`
to visually separate them from the subcrate folders.
* `accounts`: a subcrate simulating a reusable account-management crate.
* Owns schema `accounts`.
* `payments`: a subcrate simulating a wrapper for a payments API.
* Owns schema `payments`.
## Note: Schema-Qualified Names
This example uses schema-qualified names everywhere for clarity.
It can be tempting to change the `search_path` of the connection (MySQL, Postgres) to eliminate the need for schema
prefixes, but this can cause some really confusing issues when names conflict.
This example will generate a `_sqlx_migrations` table in three different schemas; if `search_path` is set
to `public,accounts,payments` and the migrator for the main application attempts to reference the table unqualified,
it would throw an error.
# Setup
This example requires running three different sets of migrations.
Ensure `sqlx-cli` is installed with Postgres and `sqlx.toml` support:
```
cargo install sqlx-cli --features postgres,sqlx-toml
```
Start a Postgres server (shown here using Docker, `run` command also works with `podman`):
```
docker run -d -e POSTGRES_PASSWORD=password -p 5432:5432 --name postgres postgres:latest
```
Create `.env` with `DATABASE_URL` or set the variable in your shell environment:
```
DATABASE_URL=postgres://postgres:password@localhost/example-multi-tenant
```
Run the following commands:
```
(cd accounts && sqlx db setup)
(cd payments && sqlx migrate run)
sqlx migrate run
```
It is an open question how to make this more convenient; `sqlx-cli` could gain a `--recursive` flag that checks
subdirectories for `sqlx.toml` files, but that would only work for crates within the same workspace. If the `accounts`
and `payments` crates were instead crates.io dependencies, we would need Cargo's help to resolve that information.
An issue has been opened for discussion: <https://github.com/launchbadge/sqlx/issues/3761>

View File

@ -0,0 +1,26 @@
[package]
name = "sqlx-example-postgres-multi-tenant-accounts"
version = "0.1.0"
edition = "2021"
[dependencies]
tokio = { version = "1", features = ["rt", "sync"] }
argon2 = { version = "0.5.3", features = ["password-hash"] }
password-hash = { version = "0.5", features = ["std"] }
uuid = { version = "1", features = ["serde"] }
thiserror = "1"
rand = "0.8"
time = { version = "0.3.37", features = ["serde"] }
serde = { version = "1.0.218", features = ["derive"] }
[dependencies.sqlx]
# version = "0.9.0"
workspace = true
features = ["postgres", "time", "uuid", "macros", "sqlx-toml", "migrate"]
[dev-dependencies]
sqlx = { workspace = true, features = ["runtime-tokio"] }

View File

@ -0,0 +1,30 @@
-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging
-- and auditing.
--
-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which
-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do
--
-- select accounts.trigger_updated_at('<table name>');
--
-- after a `CREATE TABLE`.
create or replace function accounts.set_updated_at()
returns trigger as
$$
begin
NEW.updated_at = now();
return NEW;
end;
$$ language plpgsql;
create or replace function accounts.trigger_updated_at(tablename regclass)
returns void as
$$
begin
execute format('CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON %s
FOR EACH ROW
WHEN (OLD is distinct from NEW)
EXECUTE FUNCTION accounts.set_updated_at();', tablename);
end;
$$ language plpgsql;

View File

@ -0,0 +1,10 @@
create table accounts.account
(
account_id uuid primary key default gen_random_uuid(),
email text unique not null,
password_hash text not null,
created_at timestamptz not null default now(),
updated_at timestamptz
);
select accounts.trigger_updated_at('accounts.account');

View File

@ -0,0 +1,6 @@
create table accounts.session
(
session_token text primary key, -- random alphanumeric string
account_id uuid not null references accounts.account (account_id),
created_at timestamptz not null default now()
);

View File

@ -0,0 +1,11 @@
[migrate]
create-schemas = ["accounts"]
table-name = "accounts._sqlx_migrations"
[macros.table-overrides.'accounts.account']
'account_id' = "crate::AccountId"
'password_hash' = "sqlx::types::Text<password_hash::PasswordHashString>"
[macros.table-overrides.'accounts.session']
'session_token' = "crate::SessionToken"
'account_id' = "crate::AccountId"

View File

@ -0,0 +1,284 @@
use argon2::{password_hash, Argon2, PasswordHasher, PasswordVerifier};
use password_hash::PasswordHashString;
use rand::distributions::{Alphanumeric, DistString};
use sqlx::{Acquire, Executor, PgTransaction, Postgres};
use std::sync::Arc;
use uuid::Uuid;
use tokio::sync::Semaphore;
/// Strongly-typed ID for `accounts.account.account_id`; mapped via `sqlx.toml` table-overrides.
#[derive(sqlx::Type, Copy, Clone, Debug, serde::Deserialize, serde::Serialize)]
#[sqlx(transparent)]
pub struct AccountId(pub Uuid);
/// Opaque session token for `accounts.session.session_token`; mapped via `sqlx.toml` table-overrides.
#[derive(sqlx::Type, Clone, Debug, serde::Deserialize, serde::Serialize)]
#[sqlx(transparent)]
pub struct SessionToken(pub String);
/// Result of a successful login: the account and its newly issued session token.
pub struct Session {
    pub account_id: AccountId,
    pub session_token: SessionToken,
}
/// Manages account records: creation, password hashing/verification, and session authentication.
pub struct AccountsManager {
    /// Controls how many blocking tasks are allowed to run concurrently for Argon2 hashing.
    ///
    /// ### Motivation
    /// Tokio blocking tasks are generally not designed for CPU-bound work.
    ///
    /// If no threads are idle, Tokio will automatically spawn new ones to handle
    /// new blocking tasks up to a very high limit--512 by default.
    ///
    /// This is because blocking tasks are expected to spend their time *blocked*, e.g. on
    /// blocking I/O, and thus not consume CPU resources or require a lot of context switching.
    ///
    /// This strategy is not the most efficient way to use threads for CPU-bound work, which
    /// should schedule work to a fixed number of threads to minimize context switching
    /// and memory usage (each new thread needs significant space allocated for its stack).
    ///
    /// We can work around this by using a purpose-designed thread-pool, like Rayon,
    /// but we still have the problem that those APIs usually are not designed to support `async`,
    /// so we end up needing blocking tasks anyway, or implementing our own work queue using
    /// channels. Rayon also does not shut down idle worker threads.
    ///
    /// `block_in_place` is not a silver bullet, either, as it simply uses `spawn_blocking`
    /// internally to take over from the current thread while it is executing blocking work.
    /// This also prevents futures from being polled concurrently in the current task.
    ///
    /// We can lower the limit for blocking threads when creating the runtime, but this risks
    /// starving other blocking tasks that are being created by the application or the Tokio
    /// runtime itself
    /// (which are used for `tokio::fs`, stdio, resolving of hostnames by `ToSocketAddrs`, etc.).
    ///
    /// Instead, we can just use a Semaphore to limit how many blocking tasks are spawned at once,
    /// emulating the behavior of a thread pool like Rayon without needing any additional crates.
    hashing_semaphore: Arc<Semaphore>,
}
/// Errors returned by [`AccountsManager::create`].
#[derive(Debug, thiserror::Error)]
pub enum CreateAccountError {
    /// The `email` column has a UNIQUE constraint; another account already uses this address.
    #[error("error creating account: email in-use")]
    EmailInUse,
    /// Any database, hashing, or task-join failure; see [`GeneralError`].
    #[error("error creating account")]
    General(
        #[source]
        #[from]
        GeneralError,
    ),
}
/// Errors returned by [`AccountsManager::create_session`].
#[derive(Debug, thiserror::Error)]
pub enum CreateSessionError {
    /// No account exists with the given email address.
    #[error("unknown email")]
    UnknownEmail,
    /// The account exists but the supplied password does not match its stored hash.
    #[error("invalid password")]
    InvalidPassword,
    /// Any database, hashing, or task-join failure; see [`GeneralError`].
    #[error("authentication error")]
    General(
        #[source]
        #[from]
        GeneralError,
    ),
}
/// Failures not specific to a single operation; wrapped by the per-operation error enums.
#[derive(Debug, thiserror::Error)]
pub enum GeneralError {
    /// Underlying database/driver error from sqlx.
    #[error("database error")]
    Sqlx(
        #[source]
        #[from]
        sqlx::Error,
    ),
    /// Argon2 / PHC-string error while hashing or verifying a password.
    #[error("error hashing password")]
    PasswordHash(
        #[source]
        #[from]
        password_hash::Error,
    ),
    /// A `spawn_blocking` hashing task panicked or was cancelled.
    #[error("task panicked")]
    Task(
        #[source]
        #[from]
        tokio::task::JoinError,
    ),
}
impl AccountsManager {
    /// Apply this crate's migrations and construct the manager.
    ///
    /// `max_hashing_threads` caps how many Argon2 hashing jobs may occupy Tokio
    /// blocking threads at once (see the docs on `hashing_semaphore`).
    pub async fn setup(
        pool: impl Acquire<'_, Database = Postgres>,
        max_hashing_threads: usize,
    ) -> Result<Self, GeneralError> {
        sqlx::migrate!()
            .run(pool)
            .await
            .map_err(sqlx::Error::from)?;

        Ok(AccountsManager {
            hashing_semaphore: Semaphore::new(max_hashing_threads).into(),
        })
    }

    /// Hash `password` with Argon2 on a blocking thread, gated by the semaphore.
    async fn hash_password(&self, password: String) -> Result<PasswordHashString, GeneralError> {
        let guard = self
            .hashing_semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("BUG: this semaphore should not be closed");

        // We transfer ownership to the blocking task and back to ensure Tokio doesn't spawn
        // excess threads.
        let (_guard, res) = tokio::task::spawn_blocking(move || {
            let salt = password_hash::SaltString::generate(rand::thread_rng());
            (
                guard,
                Argon2::default()
                    .hash_password(password.as_bytes(), &salt)
                    .map(|hash| hash.serialize()),
            )
        })
        .await?;

        Ok(res?)
    }

    /// Verify `password` against `hash`, distinguishing a wrong password
    /// (`InvalidPassword`) from operational failures (`General`).
    async fn verify_password(
        &self,
        password: String,
        hash: PasswordHashString,
    ) -> Result<(), CreateSessionError> {
        let guard = self
            .hashing_semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("BUG: this semaphore should not be closed");

        let (_guard, res) = tokio::task::spawn_blocking(move || {
            (
                guard,
                Argon2::default().verify_password(password.as_bytes(), &hash.password_hash()),
            )
        })
        .await
        .map_err(GeneralError::from)?;

        if let Err(password_hash::Error::Password) = res {
            return Err(CreateSessionError::InvalidPassword);
        }

        // Any other `password_hash` error is an operational failure, not a mismatch.
        res.map_err(GeneralError::from)?;

        Ok(())
    }

    /// Create a new account inside the caller's transaction, returning its ID.
    ///
    /// Returns [`CreateAccountError::EmailInUse`] if `email` is already registered.
    pub async fn create(
        &self,
        txn: &mut PgTransaction<'_>,
        email: &str,
        password: String,
    ) -> Result<AccountId, CreateAccountError> {
        // Hash password whether the account exists or not to make it harder
        // to tell the difference in the timing.
        let hash = self.hash_password(password).await?;

        // Thanks to `sqlx.toml`, `account_id` maps to `AccountId`
        sqlx::query_scalar!(
            // language=PostgreSQL
            "insert into accounts.account(email, password_hash) \
             values ($1, $2) \
             returning account_id",
            email,
            hash.as_str(),
        )
        .fetch_one(&mut **txn)
        .await
        .map_err(|e| {
            // FIX: a duplicate email trips the UNIQUE constraint on `email`, which Postgres
            // names `account_email_key` (`<table>_<column>_key`; the schema is not part of
            // the constraint name). The previous check against `account_account_id_key`
            // could never match (the PK constraint is named `account_pkey`), so duplicate
            // emails surfaced as `General` errors instead of `EmailInUse`.
            if e.as_database_error().and_then(|dbe| dbe.constraint()) == Some("account_email_key") {
                CreateAccountError::EmailInUse
            } else {
                GeneralError::from(e).into()
            }
        })
    }

    /// Verify credentials and create a session, returning the new token.
    pub async fn create_session(
        &self,
        db: impl Acquire<'_, Database = Postgres>,
        email: &str,
        password: String,
    ) -> Result<Session, CreateSessionError> {
        let mut txn = db.begin().await.map_err(GeneralError::from)?;

        // To save a round-trip to the database, we'll speculatively insert the session token
        // at the same time as we're looking up the password hash.
        //
        // This does nothing until the transaction is actually committed.
        let session_token = SessionToken::generate();

        // Thanks to `sqlx.toml`:
        // * `account_id` maps to `AccountId`
        // * `password_hash` maps to `Text<PasswordHashString>`
        // * `session_token` maps to `SessionToken`
        let maybe_account = sqlx::query!(
            // language=PostgreSQL
            "with account as (
                select account_id, password_hash \
                from accounts.account \
                where email = $1
            ), session as (
                insert into accounts.session(session_token, account_id)
                select $2, account_id
                from account
            )
            select account.account_id, account.password_hash from account",
            email,
            session_token.0
        )
        .fetch_optional(&mut *txn)
        .await
        .map_err(GeneralError::from)?;

        let Some(account) = maybe_account else {
            // Hash the password whether the account exists or not to hide the difference in timing.
            self.hash_password(password)
                .await
                .map_err(GeneralError::from)?;
            return Err(CreateSessionError::UnknownEmail);
        };

        self.verify_password(password, account.password_hash.into_inner())
            .await?;

        txn.commit().await.map_err(GeneralError::from)?;

        Ok(Session {
            account_id: account.account_id,
            session_token,
        })
    }

    /// Resolve a session token to its account ID, or `None` if no such session exists.
    pub async fn auth_session(
        &self,
        db: impl Executor<'_, Database = Postgres>,
        session_token: &str,
    ) -> Result<Option<AccountId>, GeneralError> {
        sqlx::query_scalar!(
            "select account_id from accounts.session where session_token = $1",
            session_token
        )
        .fetch_optional(db)
        .await
        .map_err(GeneralError::from)
    }
}
impl SessionToken {
    // Number of alphanumeric characters in a generated token.
    const LEN: usize = 32;

    /// Produce a fresh random token of `Self::LEN` alphanumeric characters.
    fn generate() -> Self {
        SessionToken(Alphanumeric.sample_string(&mut rand::thread_rng(), Self::LEN))
    }
}

View File

@ -0,0 +1,23 @@
[package]
name = "sqlx-example-postgres-multi-tenant-payments"
version = "0.1.0"
edition = "2021"
[dependencies]
rust_decimal = "1.36.0"
time = "0.3.37"
uuid = "1.12.1"
[dependencies.sqlx]
# version = "0.9.0"
workspace = true
features = ["postgres", "time", "uuid", "rust_decimal", "sqlx-toml", "migrate"]
[dependencies.accounts]
path = "../accounts"
package = "sqlx-example-postgres-multi-tenant-accounts"
[dev-dependencies]
sqlx = { workspace = true, features = ["runtime-tokio"] }

View File

@ -0,0 +1,30 @@
-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging
-- and auditing.
--
-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which
-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do
--
-- select payments.trigger_updated_at('<table name>');
--
-- after a `CREATE TABLE`.
create or replace function payments.set_updated_at()
returns trigger as
$$
begin
NEW.updated_at = now();
return NEW;
end;
$$ language plpgsql;
create or replace function payments.trigger_updated_at(tablename regclass)
returns void as
$$
begin
execute format('CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON %s
FOR EACH ROW
WHEN (OLD is distinct from NEW)
EXECUTE FUNCTION payments.set_updated_at();', tablename);
end;
$$ language plpgsql;

View File

@ -0,0 +1,59 @@
-- `payments::PaymentStatus`
--
-- Historically at LaunchBadge we preferred not to define enums on the database side because it can be annoying
-- and error-prone to keep them in-sync with the application.
-- Instead, we let the application define the enum and just have the database store a compact representation of it.
-- This is mostly a matter of taste, however.
--
-- For the purposes of this example, we're using an in-database enum because this is a common use-case
-- for needing type overrides.
create type payments.payment_status as enum (
'pending',
'created',
'success',
'failed'
);
create table payments.payment
(
payment_id uuid primary key default gen_random_uuid(),
-- This cross-schema reference means migrations for the `accounts` crate should be run first.
account_id uuid not null references accounts.account (account_id),
status payments.payment_status not null,
-- ISO 4217 currency code (https://en.wikipedia.org/wiki/ISO_4217#List_of_ISO_4217_currency_codes)
--
-- This *could* be an ENUM of currency codes, but constraining this to a set of known values in the database
-- would be annoying to keep up to date as support for more currencies is added.
--
-- Consider also if support for cryptocurrencies is desired; those are not covered by ISO 4217.
--
-- Though ISO 4217 is a three-character code, `TEXT`, `VARCHAR` and `CHAR(N)`
-- all use the same storage format in Postgres. Any constraint against the length of this field
-- would purely be a sanity check.
currency text not null,
-- There's an endless debate about what type should be used to represent currency amounts.
--
-- Postgres has the `MONEY` type, but the fractional precision depends on a C locale setting and the type is mostly
-- optimized for storing USD, or other currencies with a minimum fraction of 1 cent.
--
-- NEVER use `FLOAT` or `DOUBLE`. IEEE-754 floating point has round-off and precision errors that make it wholly
-- unsuitable for representing real money amounts.
--
-- `NUMERIC`, being an arbitrary-precision decimal format, is a safe default choice that can support any currency,
-- and so is what we've chosen here.
amount NUMERIC not null,
-- Payments almost always take place through a third-party vendor (e.g. PayPal, Stripe, etc.),
-- so imagine this is an identifier string for this payment in such a vendor's systems.
--
-- For privacy and security reasons, payment and personally-identifying information
-- (e.g. credit card numbers, bank account numbers, billing addresses) should only be stored with the vendor
-- unless there is a good reason otherwise.
external_payment_id text,
created_at timestamptz not null default now(),
updated_at timestamptz
);
select payments.trigger_updated_at('payments.payment');

View File

@ -0,0 +1,10 @@
[migrate]
# Create the `payments` schema before running migrations,
# since the migrations themselves define objects under `payments.*`.
create-schemas = ["payments"]
# Keep this crate's migration history in a schema-qualified table so it
# doesn't collide with other crates' `_sqlx_migrations` tables.
table-name = "payments._sqlx_migrations"
# Column-level overrides: map these columns to crate-local newtypes in the macros.
[macros.table-overrides.'payments.payment']
'payment_id' = "crate::PaymentId"
'account_id' = "accounts::AccountId"
# Type-level override: map the SQL enum to a crate-local Rust enum.
[macros.type-overrides]
'payments.payment_status' = "crate::PaymentStatus"

View File

@ -0,0 +1,110 @@
use accounts::AccountId;
use sqlx::{Acquire, PgConnection, Postgres};
use time::OffsetDateTime;
use uuid::Uuid;
/// Strongly-typed wrapper around a payment's UUID primary key.
///
/// `#[sqlx(transparent)]` makes this encode/decode exactly like the inner `Uuid`.
#[derive(sqlx::Type, Copy, Clone, Debug)]
#[sqlx(transparent)]
pub struct PaymentId(pub Uuid);
/// Lifecycle state of a payment, mapped to the SQL enum `payments.payment_status`.
///
/// `rename_all = "snake_case"` matches the lowercase labels used on the SQL side.
#[derive(sqlx::Type, Copy, Clone, Debug)]
#[sqlx(type_name = "payments.payment_status")]
#[sqlx(rename_all = "snake_case")]
pub enum PaymentStatus {
    Pending,
    Created,
    Success,
    Failed,
}
// Users often assume that they need `#[derive(FromRow)]` to use `query_as!()`,
// then are surprised when the derive's control attributes have no effect.
// The macros currently do *not* use the `FromRow` trait at all.
// Support for `FromRow` is planned, but would require significant changes to the macros.
// See https://github.com/launchbadge/sqlx/issues/514 for details.
/// A row of the `payments.payment` table.
#[derive(Clone, Debug)]
pub struct Payment {
    pub payment_id: PaymentId,
    pub account_id: AccountId,
    pub status: PaymentStatus,
    // ISO 4217 currency code (e.g. "USD"); see the table definition for rationale.
    pub currency: String,
    // `rust_decimal::Decimal` has more than enough precision for any real-world amount of money.
    pub amount: rust_decimal::Decimal,
    // The payment's identifier in the external vendor's system, once known.
    pub external_payment_id: Option<String>,
    pub created_at: OffsetDateTime,
    pub updated_at: Option<OffsetDateTime>,
}
/// Apply this crate's embedded migrations.
///
/// Accepting `impl Acquire` keeps this generic over `Pool`, `Connection`,
/// and `Transaction` alike.
pub async fn migrate(db: impl Acquire<'_, Database = Postgres>) -> sqlx::Result<()> {
    Ok(sqlx::migrate!().run(db).await?)
}
/// Record a new payment for `account_id` in the given `currency` and return the stored row.
///
/// The call to the external payment vendor is simulated here with a hard-coded ID.
pub async fn create(
    conn: &mut PgConnection,
    account_id: AccountId,
    currency: &str,
    amount: rust_decimal::Decimal,
) -> sqlx::Result<Payment> {
    // Imagine this method does more than just create a record in the database;
    // maybe it actually initiates the payment with a third-party vendor, like Stripe.
    //
    // We need to ensure that we can link the payment in the vendor's systems back to a record
    // in ours, even if any of the following happens:
    // * The application dies before storing the external payment ID in the database
    // * We lose the connection to the database while trying to commit a transaction
    // * The database server dies while committing the transaction
    //
    // Thus, we create the payment in three atomic phases:
    // * We create the payment record in our system and commit it.
    // * We create the payment in the vendor's system with our payment ID attached.
    // * We update our payment record with the vendor's payment ID.
    let payment_id = sqlx::query_scalar!(
        "insert into payments.payment(account_id, status, currency, amount) \
         values ($1, $2, $3, $4) \
         returning payment_id",
        // The database doesn't give us enough information to correctly typecheck `AccountId` here.
        // We have to insert the UUID directly.
        account_id.0,
        PaymentStatus::Pending,
        currency,
        amount,
    )
    .fetch_one(&mut *conn)
    .await?;

    // We then create the record with the payment vendor...
    let external_payment_id = "foobar1234";

    // Then we store the external payment ID and update the payment status.
    //
    // NOTE: use caution with `select *` or `returning *`;
    // the order of columns gets baked into the binary, so if it changes between compile time and
    // run-time, you may run into errors.
    let payment = sqlx::query_as!(
        Payment,
        "update payments.payment \
         set status = $1, external_payment_id = $2 \
         where payment_id = $3 \
         returning *",
        PaymentStatus::Created,
        external_payment_id,
        payment_id.0,
    )
    .fetch_one(&mut *conn)
    .await?;

    Ok(payment)
}
/// Fetch a single payment by its ID, or `None` if no such row exists.
pub async fn get(db: &mut PgConnection, payment_id: PaymentId) -> sqlx::Result<Option<Payment>> {
    // Caution: `select *` bakes the column order into the compiled binary;
    // a schema change between compile time and run time can cause errors.
    let row = sqlx::query_as!(
        Payment,
        "select * from payments.payment where payment_id = $1",
        payment_id.0
    )
    .fetch_optional(db)
    .await?;

    Ok(row)
}

View File

@ -0,0 +1,3 @@
[migrate]
# Move `migrations/` to under `src/` to separate it from subcrates.
migrations-dir = "src/migrations"

View File

@ -0,0 +1,108 @@
use accounts::AccountsManager;
use color_eyre::eyre;
use color_eyre::eyre::{Context, OptionExt};
use rand::distributions::{Alphanumeric, DistString};
use sqlx::Connection;
/// End-to-end demo driver; each commented section imagines a separate
/// request route in a web application.
#[tokio::main]
async fn main() -> eyre::Result<()> {
    color_eyre::install()?;

    let _ = dotenvy::dotenv();

    tracing_subscriber::fmt::init();

    let mut conn = sqlx::PgConnection::connect(
        // `env::var()` doesn't include the variable name in the error.
        &dotenvy::var("DATABASE_URL").wrap_err("DATABASE_URL must be set")?,
    )
    .await
    .wrap_err("could not connect to database")?;

    // Runs migration for `accounts` internally.
    // NOTE(review): the `1` is presumably a concurrency budget for password
    // hashing (passwords are hashed on another thread below) — confirm
    // against the `accounts` crate's documentation.
    let accounts = AccountsManager::setup(&mut conn, 1)
        .await
        .wrap_err("error initializing AccountsManager")?;

    payments::migrate(&mut conn)
        .await
        .wrap_err("error running payments migrations")?;

    // For simplicity's sake, imagine each of these might be invoked by different request routes
    // in a web application.

    // POST /account
    let user_email = format!("user{}@example.com", rand::random::<u32>());
    let user_password = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);

    // Requires an externally managed transaction in case any application-specific records
    // should be created after the actual account record.
    let mut txn = conn.begin().await?;

    let account_id = accounts
        // Takes ownership of the password string because it's sent to another thread for hashing.
        .create(&mut txn, &user_email, user_password.clone())
        .await
        .wrap_err("error creating account")?;

    txn.commit().await?;

    println!(
        "created account ID: {}, email: {user_email:?}, password: {user_password:?}",
        account_id.0
    );

    // POST /session
    // Log the user in.
    let session = accounts
        .create_session(&mut conn, &user_email, user_password.clone())
        .await
        .wrap_err("error creating session")?;

    // After this, session.session_token should then be returned to the client,
    // either in the response body or a `Set-Cookie` header.
    println!("created session token: {}", session.session_token.0);

    // POST /purchase
    // The client would then pass the session token to authenticated routes.
    // In this route, they're making some kind of purchase.

    // First, we need to ensure the session is valid.
    // `session.session_token` would be passed by the client in whatever way is appropriate.
    //
    // For a pure REST API, consider an `Authorization: Bearer` header instead of the request body.
    // With Axum, you can create a reusable extractor that reads the header and validates the session
    // by implementing `FromRequestParts`.
    //
    // For APIs where the browser is intended to be the primary client, using a session cookie
    // may be easier for the frontend. By setting the cookie with `HttpOnly: true`,
    // it's impossible for malicious Javascript on the client to access and steal the session token.
    let account_id = accounts
        .auth_session(&mut conn, &session.session_token.0)
        .await
        .wrap_err("error authenticating session")?
        .ok_or_eyre("session does not exist")?;

    let purchase_amount: rust_decimal::Decimal = "12.34".parse().unwrap();

    // Then, because the user is making a purchase, we record a payment.
    let payment = payments::create(&mut conn, account_id, "USD", purchase_amount)
        .await
        .wrap_err("error creating payment")?;

    println!("created payment: {payment:?}");

    // `purchase` is a table owned by this root crate's own migrations.
    let purchase_id = sqlx::query_scalar!(
        "insert into purchase(account_id, payment_id, amount) values ($1, $2, $3) returning purchase_id",
        account_id.0,
        payment.payment_id.0,
        purchase_amount
    )
    .fetch_one(&mut conn)
    .await
    .wrap_err("error creating purchase")?;

    println!("created purchase: {purchase_id}");

    conn.close().await?;

    Ok(())
}

View File

@ -0,0 +1,30 @@
-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging
-- and auditing.
--
-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which
-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do
--
-- select trigger_updated_at('<table name>');
--
-- after a `CREATE TABLE`.
-- Trigger function: stamps `updated_at` with the current time on each UPDATE.
create or replace function set_updated_at()
    returns trigger as
$$
begin
    NEW.updated_at = now();
    return NEW;
end;
$$ language plpgsql;
-- Attaches the `set_updated_at` trigger to the given table.
-- The `regclass` parameter type validates that the argument names an existing relation.
-- The WHEN clause skips no-op updates where the row did not actually change.
create or replace function trigger_updated_at(tablename regclass)
    returns void as
$$
begin
execute format('CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON %s
FOR EACH ROW
WHEN (OLD is distinct from NEW)
EXECUTE FUNCTION set_updated_at();', tablename);
end;
$$ language plpgsql;

View File

@ -0,0 +1,11 @@
-- A purchase links an account (owned by the `accounts` crate's migrations)
-- to a payment (owned by the `payments` crate's migrations).
create table purchase
(
    purchase_id uuid primary key default gen_random_uuid(),
    -- Cross-schema foreign keys into the subcrates' tables.
    account_id uuid not null references accounts.account (account_id),
    payment_id uuid not null references payments.payment (payment_id),
    amount numeric not null,
    created_at timestamptz not null default now(),
    updated_at timestamptz
);
select trigger_updated_at('purchase');

View File

@ -0,0 +1,37 @@
[package]
name = "sqlx-example-postgres-preferred-crates"
version.workspace = true
license.workspace = true
edition.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
authors.workspace = true
[dependencies]
dotenvy.workspace = true
anyhow = "1"
chrono = "0.4"
serde = { version = "1", features = ["derive"] }
uuid = { version = "1", features = ["serde"] }
[dependencies.tokio]
workspace = true
features = ["rt-multi-thread", "macros"]
[dependencies.sqlx]
# version = "0.9.0"
workspace = true
features = ["runtime-tokio", "postgres", "bigdecimal", "chrono", "derive", "migrate", "sqlx-toml"]
[dependencies.uses-rust-decimal]
path = "uses-rust-decimal"
package = "sqlx-example-postgres-preferred-crates-uses-rust-decimal"
[dependencies.uses-time]
path = "uses-time"
package = "sqlx-example-postgres-preferred-crates-uses-time"
[lints]
workspace = true

View File

@ -0,0 +1,55 @@
# Usage of `macros.preferred-crates` in `sqlx.toml`
## The Problem
SQLx has many optional features that enable integrations for external crates to map from/to SQL types.
In some cases, more than one optional feature applies to the same set of types:
* The `chrono` and `time` features enable mapping SQL date/time types to those in these crates.
* Similarly, `bigdecimal` and `rust_decimal` enable mapping for the SQL `NUMERIC` type.
Throughout its existence, the `query!()` family of macros has inferred which crate to use based on which optional
feature was enabled. If multiple features are enabled, one takes precedence over the other: `time` over `chrono`,
`rust_decimal` over `bigdecimal`, etc. The ordering is purely the result of historical happenstance and
does not indicate any specific preference for one crate over another. They each have their tradeoffs.
This works fine when only one crate in the dependency graph depends on SQLx, but can break down if another crate
in the dependency graph also depends on SQLx. Because of Cargo's [feature unification], any features enabled
by this other crate are also forced on for all other crates that depend on the same version of SQLx in the same project.
This is intentional design on Cargo's part; features are meant to be purely additive, so it can build each transitive
dependency just once no matter how many crates depend on it. Otherwise, this could result in combinatorial explosion.
Unfortunately for us, this means that if your project depends on SQLx and enables the `chrono` feature, but also depends
on another crate that enables the `time` feature, the `query!()` macros will end up thinking that _you_ want to use
the `time` crate, because they don't know any better.
Fixing this has historically required patching the dependency, which is annoying to maintain long-term.
[feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification
## The Solution
However, as of 0.9.0, SQLx has gained the ability to configure the macros through the use of a `sqlx.toml` file.
This includes the ability to tell the macros which crate you prefer, overriding the inference.
See the [`sqlx.toml`](./sqlx.toml) file in this directory for details.
A full reference `sqlx.toml` is also available as `sqlx-core/src/config/reference.toml`.
## This Example
This example exists both to showcase the macro configuration and also serve as a test for the functionality.
It consists of three crates:
* The root crate, which depends on SQLx and enables the `chrono` and `bigdecimal` features,
* `uses-rust-decimal`, a dependency which also depends on SQLx and enables the `rust_decimal` feature,
* and `uses-time`, a dependency which also depends on SQLx and enables the `time` feature.
* This serves as a stand-in for `tower-sessions-sqlx-store`, which is [one of the culprits for this issue](https://github.com/launchbadge/sqlx/issues/3412#issuecomment-2277377597).
Given that both dependencies enable features with higher precedence, they would historically have interfered
with the usage in the root crate. (Pretend that they're published to crates.io and cannot be easily changed.)
However, because the root crate uses a `sqlx.toml`, the macros know exactly which crates it wants to use and everyone's happy.

View File

@ -0,0 +1,9 @@
[migrate]
# Move `migrations/` to under `src/` to separate it from subcrates.
migrations-dir = "src/migrations"
[macros.preferred-crates]
# Keeps `time` from taking precedence even though it's enabled by a dependency.
date-time = "chrono"
# Same thing with `rust_decimal`
numeric = "bigdecimal"

View File

@ -0,0 +1,70 @@
use anyhow::Context;
use chrono::{DateTime, Utc};
use sqlx::{Connection, PgConnection};
use std::time::Duration;
use uuid::Uuid;
/// Example session payload, stored as JSON in the `sessions` table.
#[derive(serde::Serialize, serde::Deserialize, PartialEq, Eq, Debug)]
struct SessionData {
    user_id: Uuid,
}
/// A row of the `users` table (see this crate's migrations).
#[derive(sqlx::FromRow, Debug)]
struct User {
    id: Uuid,
    username: String,
    password_hash: String,
    // Because `time` is enabled by a transitive dependency, we previously would have needed
    // a type override in the query to get types from `chrono`.
    created_at: DateTime<Utc>,
    updated_at: Option<DateTime<Utc>>,
}
/// How long a newly created session remains valid.
const SESSION_DURATION: Duration = Duration::from_secs(60 * 60); // 1 hour
/// Demo/test driver showing that `macros.preferred-crates` in `sqlx.toml`
/// keeps the macros on `chrono`/`bigdecimal` despite dependencies enabling
/// the `time` and `rust_decimal` features.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut conn =
        PgConnection::connect(&dotenvy::var("DATABASE_URL").context("DATABASE_URL must be set")?)
            .await
            .context("failed to connect to DATABASE_URL")?;

    // Non-default migrations path; matches `migrate.migrations-dir` in `sqlx.toml`.
    sqlx::migrate!("./src/migrations").run(&mut conn).await?;

    // The subcrates create their own tables imperatively.
    uses_rust_decimal::create_table(&mut conn).await?;
    uses_time::create_table(&mut conn).await?;

    // NOTE(review): re-running this against the same database will violate
    // the unique index on lower(username) — confirm intended for a demo.
    let user_id = sqlx::query_scalar!(
        "insert into users(username, password_hash) values($1, $2) returning id",
        "user_foo",
        "<pretend this is a password hash>",
    )
    .fetch_one(&mut conn)
    .await?;

    // The timestamp columns map to `chrono` types here thanks to the
    // `date-time = "chrono"` preference in `sqlx.toml`.
    let user = sqlx::query_as!(User, "select * from users where id = $1", user_id)
        .fetch_one(&mut conn)
        .await?;

    println!("Created user: {user:?}");

    let session =
        uses_time::create_session(&mut conn, SessionData { user_id }, SESSION_DURATION).await?;

    let session_from_id = uses_time::get_session::<SessionData>(&mut conn, session.id)
        .await?
        .expect("expected session");

    assert_eq!(session, session_from_id);

    let purchase_id =
        uses_rust_decimal::create_purchase(&mut conn, user_id, 1234u32.into(), "Rent").await?;

    let purchase = uses_rust_decimal::get_purchase(&mut conn, purchase_id)
        .await?
        .expect("expected purchase");

    println!("Created purchase: {purchase:?}");

    Ok(())
}

View File

@ -0,0 +1,30 @@
-- We try to ensure every table has `created_at` and `updated_at` columns, which can help immensely with debugging
-- and auditing.
--
-- While `created_at` can just be `default now()`, setting `updated_at` on update requires a trigger which
-- is a lot of boilerplate. These two functions save us from writing that every time as instead we can just do
--
-- select trigger_updated_at('<table name>');
--
-- after a `CREATE TABLE`.
-- Trigger function: stamps `updated_at` with the current time on each UPDATE.
create or replace function set_updated_at()
    returns trigger as
$$
begin
    NEW.updated_at = now();
    return NEW;
end;
$$ language plpgsql;
-- Attaches the `set_updated_at` trigger to the given table.
-- The `regclass` parameter type validates that the argument names an existing relation.
-- The WHEN clause skips no-op updates where the row did not actually change.
create or replace function trigger_updated_at(tablename regclass)
    returns void as
$$
begin
execute format('CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON %s
FOR EACH ROW
WHEN (OLD is distinct from NEW)
EXECUTE FUNCTION set_updated_at();', tablename);
end;
$$ language plpgsql;

View File

@ -0,0 +1,11 @@
create table users(
    id uuid primary key default gen_random_uuid(),
    username text not null,
    password_hash text not null,
    created_at timestamptz not null default now(),
    updated_at timestamptz
);
-- Enforce case-insensitive uniqueness by indexing on lower(username).
create unique index users_username_unique on users(lower(username));
select trigger_updated_at('users');

View File

@ -0,0 +1,21 @@
[package]
name = "sqlx-example-postgres-preferred-crates-uses-rust-decimal"
version.workspace = true
license.workspace = true
edition.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
authors.workspace = true
[dependencies]
chrono = "0.4"
rust_decimal = "1"
uuid = "1"
[dependencies.sqlx]
workspace = true
features = ["runtime-tokio", "postgres", "rust_decimal", "chrono", "uuid"]
[lints]
workspace = true

View File

@ -0,0 +1,55 @@
use chrono::{DateTime, Utc};
use sqlx::PgExecutor;
// Re-export so callers don't need a direct `rust_decimal` dependency.
pub use rust_decimal::Decimal;

use uuid::Uuid;

/// A row of the `purchases` table created by [`create_table`].
#[derive(sqlx::FromRow, Debug)]
pub struct Purchase {
    pub id: Uuid,
    pub user_id: Uuid,
    pub amount: Decimal,
    pub description: String,
    pub created_at: DateTime<Utc>,
}
/// Create the `purchases` table if it does not already exist (idempotent).
///
/// This crate manages its table imperatively rather than via embedded migrations.
pub async fn create_table(e: impl PgExecutor<'_>) -> sqlx::Result<()> {
    sqlx::raw_sql(
        // language=PostgreSQL
        "create table if not exists purchases( \
            id uuid primary key default gen_random_uuid(), \
            user_id uuid not null, \
            amount numeric not null check(amount > 0), \
            description text not null, \
            created_at timestamptz not null default now() \
        );
        ",
    )
    .execute(e)
    .await?;

    Ok(())
}
/// Insert a purchase row and return its generated ID.
pub async fn create_purchase(
    e: impl PgExecutor<'_>,
    user_id: Uuid,
    amount: Decimal,
    description: &str,
) -> sqlx::Result<Uuid> {
    let insert =
        "insert into purchases(user_id, amount, description) values ($1, $2, $3) returning id";

    // Bind order corresponds to the $1..$3 placeholders above.
    sqlx::query_scalar(insert)
        .bind(user_id)
        .bind(amount)
        .bind(description)
        .fetch_one(e)
        .await
}
/// Look up a purchase by ID, returning `None` if no row matches.
pub async fn get_purchase(e: impl PgExecutor<'_>, id: Uuid) -> sqlx::Result<Option<Purchase>> {
    let query = sqlx::query_as("select * from purchases where id = $1").bind(id);
    query.fetch_optional(e).await
}

View File

@ -0,0 +1,21 @@
[package]
name = "sqlx-example-postgres-preferred-crates-uses-time"
version.workspace = true
license.workspace = true
edition.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
authors.workspace = true
[dependencies]
serde = "1"
time = "0.3"
uuid = "1"
[dependencies.sqlx]
workspace = true
features = ["runtime-tokio", "postgres", "time", "json", "uuid"]
[lints]
workspace = true

View File

@ -0,0 +1,75 @@
use serde::de::DeserializeOwned;
use serde::Serialize;
use sqlx::PgExecutor;
use std::time::Duration;
use time::OffsetDateTime;
use sqlx::types::Json;
use uuid::Uuid;
/// A session row; `data` is stored in a `jsonb` column and
/// (de)serialized via serde.
#[derive(sqlx::FromRow, PartialEq, Eq, Debug)]
pub struct Session<D> {
    pub id: Uuid,
    // `#[sqlx(json)]` tells the `FromRow` derive to decode this column as JSON into `D`.
    #[sqlx(json)]
    pub data: D,
    pub created_at: OffsetDateTime,
    pub expires_at: OffsetDateTime,
}
/// Create the `sessions` table if it does not already exist (idempotent).
pub async fn create_table(e: impl PgExecutor<'_>) -> sqlx::Result<()> {
    sqlx::raw_sql(
        // language=PostgreSQL
        "create table if not exists sessions( \
            id uuid primary key default gen_random_uuid(), \
            data jsonb not null,
            created_at timestamptz not null default now(),
            expires_at timestamptz not null
        )",
    )
    .execute(e)
    .await?;

    Ok(())
}
/// Insert a new session carrying the JSON-serializable `data`,
/// valid for `valid_duration` from now, and return the stored session.
pub async fn create_session<D: Serialize>(
    e: impl PgExecutor<'_>,
    data: D,
    valid_duration: Duration,
) -> sqlx::Result<Session<D>> {
    // Round down to the nearest second because
    // Postgres doesn't support precision higher than 1 microsecond anyway.
    // (Truncating also means the timestamps we return should compare equal
    // to the values read back from the database.)
    let created_at = OffsetDateTime::now_utc()
        .replace_nanosecond(0)
        .expect("0 nanoseconds should be in range");

    let expires_at = created_at + valid_duration;

    let id: Uuid = sqlx::query_scalar(
        "insert into sessions(data, created_at, expires_at) \
         values ($1, $2, $3) \
         returning id",
    )
    .bind(Json(&data))
    .bind(created_at)
    .bind(expires_at)
    .fetch_one(e)
    .await?;

    Ok(Session {
        id,
        data,
        created_at,
        expires_at,
    })
}
/// Look up a session by ID, decoding its JSON payload into `D`.
///
/// Returns `None` if no session with that ID exists.
pub async fn get_session<D: DeserializeOwned + Send + Unpin + 'static>(
    e: impl PgExecutor<'_>,
    id: Uuid,
) -> sqlx::Result<Option<Session<D>>> {
    let query =
        sqlx::query_as("select id, data, created_at, expires_at from sessions where id = $1");

    query.bind(id).fetch_optional(e).await
}

View File

@ -28,11 +28,6 @@ path = "src/bin/cargo-sqlx.rs"
[dependencies]
dotenvy = "0.15.0"
tokio = { version = "1.15.0", features = ["macros", "rt", "rt-multi-thread", "signal"] }
sqlx = { workspace = true, default-features = false, features = [
"runtime-tokio",
"migrate",
"any",
] }
futures-util = { version = "0.3.19", features = ["alloc"] }
clap = { version = "4.3.10", features = ["derive", "env", "wrap_help"] }
clap_complete = { version = "4.3.1", optional = true }
@ -48,8 +43,18 @@ filetime = "0.2"
backoff = { version = "0.4.0", features = ["futures", "tokio"] }
[dependencies.sqlx]
workspace = true
default-features = false
features = [
"runtime-tokio",
"migrate",
"any",
]
[features]
default = ["postgres", "sqlite", "mysql", "native-tls", "completions"]
default = ["postgres", "sqlite", "mysql", "native-tls", "completions", "sqlx-toml"]
rustls = ["sqlx/tls-rustls"]
native-tls = ["sqlx/tls-native-tls"]
@ -64,6 +69,8 @@ openssl-vendored = ["openssl/vendored"]
completions = ["dep:clap_complete"]
sqlx-toml = ["sqlx/sqlx-toml"]
# Conditional compilation only
_sqlite = []

View File

@ -1,5 +1,5 @@
use crate::migrate;
use crate::opt::ConnectOpts;
use crate::opt::{ConnectOpts, MigrationSourceOpt};
use crate::{migrate, Config};
use console::{style, Term};
use dialoguer::Confirm;
use sqlx::any::Any;
@ -19,14 +19,14 @@ pub async fn create(connect_opts: &ConnectOpts) -> anyhow::Result<()> {
std::sync::atomic::Ordering::Release,
);
Any::create_database(connect_opts.required_db_url()?).await?;
Any::create_database(connect_opts.expect_db_url()?).await?;
}
Ok(())
}
pub async fn drop(connect_opts: &ConnectOpts, confirm: bool, force: bool) -> anyhow::Result<()> {
if confirm && !ask_to_continue_drop(connect_opts.required_db_url()?.to_owned()).await {
if confirm && !ask_to_continue_drop(connect_opts.expect_db_url()?.to_owned()).await {
return Ok(());
}
@ -36,9 +36,9 @@ pub async fn drop(connect_opts: &ConnectOpts, confirm: bool, force: bool) -> any
if exists {
if force {
Any::force_drop_database(connect_opts.required_db_url()?).await?;
Any::force_drop_database(connect_opts.expect_db_url()?).await?;
} else {
Any::drop_database(connect_opts.required_db_url()?).await?;
Any::drop_database(connect_opts.expect_db_url()?).await?;
}
}
@ -46,18 +46,23 @@ pub async fn drop(connect_opts: &ConnectOpts, confirm: bool, force: bool) -> any
}
pub async fn reset(
migration_source: &str,
config: &Config,
migration_source: &MigrationSourceOpt,
connect_opts: &ConnectOpts,
confirm: bool,
force: bool,
) -> anyhow::Result<()> {
drop(connect_opts, confirm, force).await?;
setup(migration_source, connect_opts).await
setup(config, migration_source, connect_opts).await
}
pub async fn setup(migration_source: &str, connect_opts: &ConnectOpts) -> anyhow::Result<()> {
pub async fn setup(
config: &Config,
migration_source: &MigrationSourceOpt,
connect_opts: &ConnectOpts,
) -> anyhow::Result<()> {
create(connect_opts).await?;
migrate::run(migration_source, connect_opts, false, false, None).await
migrate::run(config, migration_source, connect_opts, false, false, None).await
}
async fn ask_to_continue_drop(db_url: String) -> bool {

View File

@ -2,7 +2,6 @@ use std::future::Future;
use std::io;
use std::time::Duration;
use anyhow::Result;
use futures_util::TryFutureExt;
use sqlx::{AnyConnection, Connection};
@ -22,6 +21,8 @@ mod prepare;
pub use crate::opt::Opt;
pub use sqlx::_unstable::config::{self, Config};
/// Check arguments for `--no-dotenv` _before_ Clap parsing, and apply `.env` if not set.
pub fn maybe_apply_dotenv() {
if std::env::args().any(|arg| arg == "--no-dotenv") {
@ -31,7 +32,7 @@ pub fn maybe_apply_dotenv() {
dotenvy::dotenv().ok();
}
pub async fn run(opt: Opt) -> Result<()> {
pub async fn run(opt: Opt) -> anyhow::Result<()> {
// This `select!` is here so that when the process receives a `SIGINT` (CTRL + C),
// the futures currently running on this task get dropped before the program exits.
// This is currently necessary for the consumers of the `dialoguer` crate to restore
@ -51,24 +52,24 @@ pub async fn run(opt: Opt) -> Result<()> {
}
}
async fn do_run(opt: Opt) -> Result<()> {
async fn do_run(opt: Opt) -> anyhow::Result<()> {
match opt.command {
Command::Migrate(migrate) => match migrate.command {
MigrateCommand::Add {
source,
description,
reversible,
sequential,
timestamp,
} => migrate::add(&source, &description, reversible, sequential, timestamp).await?,
MigrateCommand::Add(opts) => migrate::add(opts).await?,
MigrateCommand::Run {
source,
config,
dry_run,
ignore_missing,
connect_opts,
mut connect_opts,
target_version,
} => {
let config = config.load_config().await?;
connect_opts.populate_db_url(&config)?;
migrate::run(
&config,
&source,
&connect_opts,
dry_run,
@ -79,12 +80,18 @@ async fn do_run(opt: Opt) -> Result<()> {
}
MigrateCommand::Revert {
source,
config,
dry_run,
ignore_missing,
connect_opts,
mut connect_opts,
target_version,
} => {
let config = config.load_config().await?;
connect_opts.populate_db_url(&config)?;
migrate::revert(
&config,
&source,
&connect_opts,
dry_run,
@ -95,37 +102,83 @@ async fn do_run(opt: Opt) -> Result<()> {
}
MigrateCommand::Info {
source,
connect_opts,
} => migrate::info(&source, &connect_opts).await?,
MigrateCommand::BuildScript { source, force } => migrate::build_script(&source, force)?,
config,
mut connect_opts,
} => {
let config = config.load_config().await?;
connect_opts.populate_db_url(&config)?;
migrate::info(&config, &source, &connect_opts).await?
}
MigrateCommand::BuildScript {
source,
config,
force,
} => {
let config = config.load_config().await?;
migrate::build_script(&config, &source, force)?
}
},
Command::Database(database) => match database.command {
DatabaseCommand::Create { connect_opts } => database::create(&connect_opts).await?,
DatabaseCommand::Create {
config,
mut connect_opts,
} => {
let config = config.load_config().await?;
connect_opts.populate_db_url(&config)?;
database::create(&connect_opts).await?
}
DatabaseCommand::Drop {
confirmation,
connect_opts,
config,
mut connect_opts,
force,
} => database::drop(&connect_opts, !confirmation.yes, force).await?,
} => {
let config = config.load_config().await?;
connect_opts.populate_db_url(&config)?;
database::drop(&connect_opts, !confirmation.yes, force).await?
}
DatabaseCommand::Reset {
confirmation,
source,
connect_opts,
config,
mut connect_opts,
force,
} => database::reset(&source, &connect_opts, !confirmation.yes, force).await?,
} => {
let config = config.load_config().await?;
connect_opts.populate_db_url(&config)?;
database::reset(&config, &source, &connect_opts, !confirmation.yes, force).await?
}
DatabaseCommand::Setup {
source,
connect_opts,
} => database::setup(&source, &connect_opts).await?,
config,
mut connect_opts,
} => {
let config = config.load_config().await?;
connect_opts.populate_db_url(&config)?;
database::setup(&config, &source, &connect_opts).await?
}
},
Command::Prepare {
check,
all,
workspace,
connect_opts,
mut connect_opts,
args,
} => prepare::run(check, all, workspace, connect_opts, args).await?,
config,
} => {
let config = config.load_config().await?;
connect_opts.populate_db_url(&config)?;
prepare::run(check, all, workspace, connect_opts, args).await?
}
#[cfg(feature = "completions")]
Command::Completions { shell } => completions::run(shell),
@ -153,7 +206,7 @@ where
{
sqlx::any::install_default_drivers();
let db_url = opts.required_db_url()?;
let db_url = opts.expect_db_url()?;
backoff::future::retry(
backoff::ExponentialBackoffBuilder::new()

View File

@ -1,6 +1,6 @@
use crate::opt::ConnectOpts;
use crate::config::Config;
use crate::opt::{AddMigrationOpts, ConnectOpts, MigrationSourceOpt};
use anyhow::{bail, Context};
use chrono::Utc;
use console::style;
use sqlx::migrate::{AppliedMigration, Migrate, MigrateError, MigrationType, Migrator};
use sqlx::Connection;
@ -11,142 +11,47 @@ use std::fs::{self, File};
use std::path::Path;
use std::time::Duration;
fn create_file(
migration_source: &str,
file_prefix: &str,
description: &str,
migration_type: MigrationType,
) -> anyhow::Result<()> {
use std::path::PathBuf;
pub async fn add(opts: AddMigrationOpts) -> anyhow::Result<()> {
let config = opts.config.load_config().await?;
let mut file_name = file_prefix.to_string();
file_name.push('_');
file_name.push_str(&description.replace(' ', "_"));
file_name.push_str(migration_type.suffix());
let source = opts.source.resolve_path(&config);
let mut path = PathBuf::new();
path.push(migration_source);
path.push(&file_name);
fs::create_dir_all(source).context("Unable to create migrations directory")?;
println!("Creating {}", style(path.display()).cyan());
let migrator = opts.source.resolve(&config).await?;
let mut file = File::create(&path).context("Failed to create migration file")?;
let version_prefix = opts.version_prefix(&config, &migrator);
std::io::Write::write_all(&mut file, migration_type.file_content().as_bytes())?;
Ok(())
}
enum MigrationOrdering {
Timestamp(String),
Sequential(String),
}
impl MigrationOrdering {
fn timestamp() -> MigrationOrdering {
Self::Timestamp(Utc::now().format("%Y%m%d%H%M%S").to_string())
}
fn sequential(version: i64) -> MigrationOrdering {
Self::Sequential(format!("{version:04}"))
}
fn file_prefix(&self) -> &str {
match self {
MigrationOrdering::Timestamp(prefix) => prefix,
MigrationOrdering::Sequential(prefix) => prefix,
}
}
fn infer(sequential: bool, timestamp: bool, migrator: &Migrator) -> Self {
match (timestamp, sequential) {
(true, true) => panic!("Impossible to specify both timestamp and sequential mode"),
(true, false) => MigrationOrdering::timestamp(),
(false, true) => MigrationOrdering::sequential(
migrator
.iter()
.last()
.map_or(1, |last_migration| last_migration.version + 1),
),
(false, false) => {
// inferring the naming scheme
let migrations = migrator
.iter()
.filter(|migration| migration.migration_type.is_up_migration())
.rev()
.take(2)
.collect::<Vec<_>>();
if let [last, pre_last] = &migrations[..] {
// there are at least two migrations, compare the last two
if last.version - pre_last.version == 1 {
// their version numbers differ by 1, infer sequential
MigrationOrdering::sequential(last.version + 1)
} else {
MigrationOrdering::timestamp()
}
} else if let [last] = &migrations[..] {
// there is only one existing migration
if last.version == 0 || last.version == 1 {
// infer sequential if the version number is 0 or 1
MigrationOrdering::sequential(last.version + 1)
} else {
MigrationOrdering::timestamp()
}
} else {
MigrationOrdering::timestamp()
}
}
}
}
}
pub async fn add(
migration_source: &str,
description: &str,
reversible: bool,
sequential: bool,
timestamp: bool,
) -> anyhow::Result<()> {
fs::create_dir_all(migration_source).context("Unable to create migrations directory")?;
let migrator = Migrator::new(Path::new(migration_source)).await?;
// Type of newly created migration will be the same as the first one
// or reversible flag if this is the first migration
let migration_type = MigrationType::infer(&migrator, reversible);
let ordering = MigrationOrdering::infer(sequential, timestamp, &migrator);
let file_prefix = ordering.file_prefix();
if migration_type.is_reversible() {
if opts.reversible(&config, &migrator) {
create_file(
migration_source,
file_prefix,
description,
source,
&version_prefix,
&opts.description,
MigrationType::ReversibleUp,
)?;
create_file(
migration_source,
file_prefix,
description,
source,
&version_prefix,
&opts.description,
MigrationType::ReversibleDown,
)?;
} else {
create_file(
migration_source,
file_prefix,
description,
source,
&version_prefix,
&opts.description,
MigrationType::Simple,
)?;
}
// if the migrations directory is empty
let has_existing_migrations = fs::read_dir(migration_source)
let has_existing_migrations = fs::read_dir(source)
.map(|mut dir| dir.next().is_some())
.unwrap_or(false);
if !has_existing_migrations {
let quoted_source = if migration_source != "migrations" {
format!("{migration_source:?}")
let quoted_source = if opts.source.source.is_some() {
format!("{source:?}")
} else {
"".to_string()
};
@ -184,6 +89,32 @@ See: https://docs.rs/sqlx/{version}/sqlx/macro.migrate.html
Ok(())
}
/// Write a single migration file into `migration_source`.
///
/// The file name is `<file_prefix>_<description-with-underscores><type suffix>`
/// and its contents are the template for `migration_type`.
fn create_file(
    migration_source: &str,
    file_prefix: &str,
    description: &str,
    migration_type: MigrationType,
) -> anyhow::Result<()> {
    use std::io::Write;
    use std::path::Path;

    // Spaces in the description become underscores so the name is shell-friendly.
    let file_name = format!(
        "{file_prefix}_{}{}",
        description.replace(' ', "_"),
        migration_type.suffix()
    );

    let path = Path::new(migration_source).join(&file_name);

    println!("Creating {}", style(path.display()).cyan());

    let mut file = File::create(&path).context("Failed to create migration file")?;
    file.write_all(migration_type.file_content().as_bytes())?;

    Ok(())
}
fn short_checksum(checksum: &[u8]) -> String {
let mut s = String::with_capacity(checksum.len() * 2);
for b in checksum {
@ -192,14 +123,25 @@ fn short_checksum(checksum: &[u8]) -> String {
s
}
pub async fn info(migration_source: &str, connect_opts: &ConnectOpts) -> anyhow::Result<()> {
let migrator = Migrator::new(Path::new(migration_source)).await?;
pub async fn info(
config: &Config,
migration_source: &MigrationSourceOpt,
connect_opts: &ConnectOpts,
) -> anyhow::Result<()> {
let migrator = migration_source.resolve(config).await?;
let mut conn = crate::connect(connect_opts).await?;
conn.ensure_migrations_table().await?;
// FIXME: we shouldn't actually be creating anything here
for schema_name in &config.migrate.create_schemas {
conn.create_schema_if_not_exists(schema_name).await?;
}
conn.ensure_migrations_table(config.migrate.table_name())
.await?;
let applied_migrations: HashMap<_, _> = conn
.list_applied_migrations()
.list_applied_migrations(config.migrate.table_name())
.await?
.into_iter()
.map(|m| (m.version, m))
@ -272,13 +214,15 @@ fn validate_applied_migrations(
}
pub async fn run(
migration_source: &str,
config: &Config,
migration_source: &MigrationSourceOpt,
connect_opts: &ConnectOpts,
dry_run: bool,
ignore_missing: bool,
target_version: Option<i64>,
) -> anyhow::Result<()> {
let migrator = Migrator::new(Path::new(migration_source)).await?;
let migrator = migration_source.resolve(config).await?;
if let Some(target_version) = target_version {
if !migrator.version_exists(target_version) {
bail!(MigrateError::VersionNotPresent(target_version));
@ -287,14 +231,21 @@ pub async fn run(
let mut conn = crate::connect(connect_opts).await?;
conn.ensure_migrations_table().await?;
for schema_name in &config.migrate.create_schemas {
conn.create_schema_if_not_exists(schema_name).await?;
}
let version = conn.dirty_version().await?;
conn.ensure_migrations_table(config.migrate.table_name())
.await?;
let version = conn.dirty_version(config.migrate.table_name()).await?;
if let Some(version) = version {
bail!(MigrateError::Dirty(version));
}
let applied_migrations = conn.list_applied_migrations().await?;
let applied_migrations = conn
.list_applied_migrations(config.migrate.table_name())
.await?;
validate_applied_migrations(&applied_migrations, &migrator, ignore_missing)?;
let latest_version = applied_migrations
@ -332,7 +283,7 @@ pub async fn run(
let elapsed = if dry_run || skip {
Duration::new(0, 0)
} else {
conn.apply(migration).await?
conn.apply(config.migrate.table_name(), migration).await?
};
let text = if skip {
"Skipped"
@ -365,13 +316,15 @@ pub async fn run(
}
pub async fn revert(
migration_source: &str,
config: &Config,
migration_source: &MigrationSourceOpt,
connect_opts: &ConnectOpts,
dry_run: bool,
ignore_missing: bool,
target_version: Option<i64>,
) -> anyhow::Result<()> {
let migrator = Migrator::new(Path::new(migration_source)).await?;
let migrator = migration_source.resolve(config).await?;
if let Some(target_version) = target_version {
if target_version != 0 && !migrator.version_exists(target_version) {
bail!(MigrateError::VersionNotPresent(target_version));
@ -380,14 +333,22 @@ pub async fn revert(
let mut conn = crate::connect(connect_opts).await?;
conn.ensure_migrations_table().await?;
// FIXME: we should not be creating anything here if it doesn't exist
for schema_name in &config.migrate.create_schemas {
conn.create_schema_if_not_exists(schema_name).await?;
}
let version = conn.dirty_version().await?;
conn.ensure_migrations_table(config.migrate.table_name())
.await?;
let version = conn.dirty_version(config.migrate.table_name()).await?;
if let Some(version) = version {
bail!(MigrateError::Dirty(version));
}
let applied_migrations = conn.list_applied_migrations().await?;
let applied_migrations = conn
.list_applied_migrations(config.migrate.table_name())
.await?;
validate_applied_migrations(&applied_migrations, &migrator, ignore_missing)?;
let latest_version = applied_migrations
@ -421,7 +382,7 @@ pub async fn revert(
let elapsed = if dry_run || skip {
Duration::new(0, 0)
} else {
conn.revert(migration).await?
conn.revert(config.migrate.table_name(), migration).await?
};
let text = if skip {
"Skipped"
@ -458,7 +419,13 @@ pub async fn revert(
Ok(())
}
pub fn build_script(migration_source: &str, force: bool) -> anyhow::Result<()> {
pub fn build_script(
config: &Config,
migration_source: &MigrationSourceOpt,
force: bool,
) -> anyhow::Result<()> {
let source = migration_source.resolve_path(config);
anyhow::ensure!(
Path::new("Cargo.toml").exists(),
"must be run in a Cargo project root"
@ -473,7 +440,7 @@ pub fn build_script(migration_source: &str, force: bool) -> anyhow::Result<()> {
r#"// generated by `sqlx migrate build-script`
fn main() {{
// trigger recompilation when a new migration is added
println!("cargo:rerun-if-changed={migration_source}");
println!("cargo:rerun-if-changed={source}");
}}
"#,
);

View File

@ -1,187 +0,0 @@
use anyhow::{bail, Context};
use console::style;
use std::fs::{self, File};
use std::io::{Read, Write};
const MIGRATION_FOLDER: &str = "migrations";
/// A migration script loaded from disk (legacy `migrations/` layout).
pub struct Migration {
    /// File name of the migration, including its timestamp prefix and `.sql` extension.
    pub name: String,
    /// Raw SQL contents of the migration file.
    pub sql: String,
}
/// Create a new, empty migration script under [`MIGRATION_FOLDER`].
///
/// The file is named `<UTC timestamp>_<name>.sql` and pre-filled with a
/// placeholder comment. Returns an error if the directory or file cannot
/// be created or written.
pub fn add_file(name: &str) -> anyhow::Result<()> {
    use chrono::prelude::*;
    use std::path::Path;

    fs::create_dir_all(MIGRATION_FOLDER).context("Unable to create migrations directory")?;

    // Timestamp prefix keeps migrations lexicographically sorted in creation order.
    let dt = Utc::now();
    let mut file_name = dt.format("%Y-%m-%d_%H-%M-%S").to_string();
    // Clippy `single_char_add_str`: push a char, not a one-char &str.
    file_name.push('_');
    file_name.push_str(name);
    file_name.push_str(".sql");

    let path = Path::new(MIGRATION_FOLDER).join(&file_name);

    let mut file = File::create(path).context("Failed to create file")?;
    file.write_all(b"-- Add migration script here")
        .context("Could not write to file")?;

    println!("Created migration: '{file_name}'");

    Ok(())
}
/// Apply every pending migration script found on disk (legacy path).
///
/// Each migration runs inside its own transaction; migrations already
/// recorded in the migration table are skipped.
pub async fn run() -> anyhow::Result<()> {
    let migrator = crate::migrator::get()?;

    if !migrator.can_migrate_database() {
        bail!(
            "Database migrations not supported for {}",
            migrator.database_type()
        );
    }

    migrator.create_migration_table().await?;

    let migrations = load_migrations()?;

    for migration in &migrations {
        let mut txn = migrator.begin_migration().await?;

        if txn.check_if_applied(&migration.name).await? {
            // Dropping `txn` without committing rolls it back, which is fine:
            // nothing was executed for an already-applied migration.
            println!("Already applied migration: '{}'", migration.name);
            continue;
        }

        println!("Applying migration: '{}'", migration.name);

        txn.execute_migration(&migration.sql)
            .await
            .with_context(|| format!("Failed to run migration {:?}", &migration.name))?;

        txn.save_applied_migration(&migration.name)
            .await
            .context("Failed to insert migration")?;

        txn.commit().await.context("Failed")?;
    }

    Ok(())
}
/// List migrations on disk and, when the database is reachable, whether each
/// has been applied; also reports applied migrations with no matching file.
///
/// Falls back to a plain file listing when the database does not exist.
pub async fn list() -> anyhow::Result<()> {
    let migrator = crate::migrator::get()?;

    if !migrator.can_migrate_database() {
        bail!(
            "Database migrations not supported for {}",
            migrator.database_type()
        );
    }

    let file_migrations = load_migrations()?;

    if migrator
        .check_if_database_exists(&migrator.get_database_name()?)
        .await?
    {
        // Best-effort: fall back to an empty list if the migration table is unreadable.
        let applied_migrations = migrator.get_migrations().await.unwrap_or_else(|_| {
            println!("Could not retrieve data from migration table");
            Vec::new()
        });

        // Column width for aligned status output (0 when there are no files).
        let width = file_migrations
            .iter()
            .map(|mig| mig.name.len())
            .max()
            .unwrap_or(0);

        for mig in file_migrations.iter() {
            // Clippy `search_is_some`: `.any()` instead of `.find(..).is_some()`.
            let status = if applied_migrations.iter().any(|m| mig.name == *m) {
                style("Applied").green()
            } else {
                style("Not Applied").yellow()
            };

            println!("{:width$}\t{}", mig.name, status, width = width);
        }

        let orphans = check_for_orphans(file_migrations, applied_migrations);

        if let Some(orphans) = orphans {
            println!("\nFound migrations applied in the database that does not have a corresponding migration file:");
            for name in orphans {
                println!("{:width$}\t{}", name, style("Orphan").red(), width = width);
            }
        }
    } else {
        println!("No database found, listing migrations");

        for mig in file_migrations {
            println!("{}", mig.name);
        }
    }

    Ok(())
}
/// Load every `*.sql` file in [`MIGRATION_FOLDER`], sorted by file name.
///
/// Unreadable directory entries are silently skipped (best-effort listing);
/// non-files and files with the wrong extension are skipped as well.
fn load_migrations() -> anyhow::Result<Vec<Migration>> {
    // `read_dir` takes `impl AsRef<Path>`; no need to borrow the constant.
    let entries = fs::read_dir(MIGRATION_FOLDER).context("Could not find 'migrations' dir")?;

    let mut migrations = Vec::new();

    // `.flatten()` drops entries whose metadata read failed, matching the
    // original best-effort behavior.
    for e in entries.flatten() {
        if let Ok(meta) = e.metadata() {
            if !meta.is_file() {
                continue;
            }

            if let Some(ext) = e.path().extension() {
                if ext != "sql" {
                    println!("Wrong ext: {ext:?}");
                    continue;
                }
            } else {
                continue;
            }

            let mut file = File::open(e.path())
                .with_context(|| format!("Failed to open: '{:?}'", e.file_name()))?;
            let mut contents = String::new();
            file.read_to_string(&mut contents)
                .with_context(|| format!("Failed to read: '{:?}'", e.file_name()))?;

            migrations.push(Migration {
                name: e.file_name().to_str().unwrap().to_string(),
                sql: contents,
            });
        }
    }

    // `String` is totally ordered: use `cmp` instead of `partial_cmp(..).unwrap()`.
    migrations.sort_by(|a, b| a.name.cmp(&b.name));

    Ok(migrations)
}
/// Return the applied migration names that have no corresponding file on disk,
/// or `None` if every applied migration is accounted for.
fn check_for_orphans(
    file_migrations: Vec<Migration>,
    applied_migrations: Vec<String>,
) -> Option<Vec<String>> {
    let orphans: Vec<String> = applied_migrations
        .iter()
        .filter(|m| !file_migrations.iter().any(|fm| fm.name == **m))
        .cloned()
        .collect();

    // Clippy `len_zero`: prefer `is_empty()`; `then_some` replaces the manual
    // if/else over `Option`.
    (!orphans.is_empty()).then_some(orphans)
}

View File

@ -1,11 +1,17 @@
use std::ops::{Deref, Not};
use crate::config::migrate::{DefaultMigrationType, DefaultVersioning};
use crate::config::Config;
use anyhow::Context;
use chrono::Utc;
use clap::{
builder::{styling::AnsiColor, Styles},
Args, Parser,
};
#[cfg(feature = "completions")]
use clap_complete::Shell;
use sqlx::migrate::{MigrateError, Migrator, ResolveWith};
use std::env;
use std::ops::{Deref, Not};
use std::path::PathBuf;
const HELP_STYLES: Styles = Styles::styled()
.header(AnsiColor::Blue.on_default().bold())
@ -62,6 +68,9 @@ pub enum Command {
#[clap(flatten)]
connect_opts: ConnectOpts,
#[clap(flatten)]
config: ConfigOpt,
},
#[clap(alias = "mig")]
@ -85,6 +94,9 @@ pub enum DatabaseCommand {
Create {
#[clap(flatten)]
connect_opts: ConnectOpts,
#[clap(flatten)]
config: ConfigOpt,
},
/// Drops the database specified in your DATABASE_URL.
@ -92,6 +104,9 @@ pub enum DatabaseCommand {
#[clap(flatten)]
confirmation: Confirmation,
#[clap(flatten)]
config: ConfigOpt,
#[clap(flatten)]
connect_opts: ConnectOpts,
@ -106,7 +121,10 @@ pub enum DatabaseCommand {
confirmation: Confirmation,
#[clap(flatten)]
source: Source,
source: MigrationSourceOpt,
#[clap(flatten)]
config: ConfigOpt,
#[clap(flatten)]
connect_opts: ConnectOpts,
@ -119,7 +137,10 @@ pub enum DatabaseCommand {
/// Creates the database specified in your DATABASE_URL and runs any pending migrations.
Setup {
#[clap(flatten)]
source: Source,
source: MigrationSourceOpt,
#[clap(flatten)]
config: ConfigOpt,
#[clap(flatten)]
connect_opts: ConnectOpts,
@ -137,8 +158,55 @@ pub struct MigrateOpt {
pub enum MigrateCommand {
/// Create a new migration with the given description.
///
/// --------------------------------
///
/// Migrations may either be simple, or reversible.
///
/// Reversible migrations can be reverted with `sqlx migrate revert`, simple migrations cannot.
///
/// Reversible migrations are created as a pair of two files with the same filename but
/// extensions `.up.sql` and `.down.sql` for the up-migration and down-migration, respectively.
///
/// The up-migration should contain the commands to be used when applying the migration,
/// while the down-migration should contain the commands to reverse the changes made by the
/// up-migration.
///
/// When writing down-migrations, care should be taken to ensure that they
/// do not leave the database in an inconsistent state.
///
/// Simple migrations have just `.sql` for their extension and represent an up-migration only.
///
/// Note that reverting a migration is **destructive** and will likely result in data loss.
/// Reverting a migration will not restore any data discarded by commands in the up-migration.
///
/// It is recommended to always back up the database before running migrations.
///
/// --------------------------------
///
/// For convenience, this command attempts to detect if reversible migrations are in-use.
///
/// If the latest existing migration is reversible, the new migration will also be reversible.
///
/// Otherwise, a simple migration is created.
///
/// This behavior can be overridden by `--simple` or `--reversible`, respectively.
///
/// The default type to use can also be set in `sqlx.toml`.
///
/// --------------------------------
///
/// A version number will be automatically assigned to the migration.
///
/// Migrations are applied in ascending order by version number.
/// Version numbers do not need to be strictly consecutive.
///
/// The migration process will abort if SQLx encounters a migration with a version number
/// less than _any_ previously applied migration.
///
/// Migrations should only be created with increasing version number.
///
/// --------------------------------
///
/// For convenience, this command will attempt to detect if sequential versioning is in use,
/// and if so, continue the sequence.
///
@ -148,33 +216,20 @@ pub enum MigrateCommand {
///
/// * only one migration exists and its version number is either 0 or 1.
///
/// Otherwise timestamp versioning is assumed.
/// Otherwise, timestamp versioning (`YYYYMMDDHHMMSS`) is assumed.
///
/// This behavior can overridden by `--sequential` or `--timestamp`, respectively.
Add {
description: String,
#[clap(flatten)]
source: Source,
/// If true, creates a pair of up and down migration files with same version
/// else creates a single sql file
#[clap(short)]
reversible: bool,
/// If set, use timestamp versioning for the new migration. Conflicts with `--sequential`.
#[clap(short, long)]
timestamp: bool,
/// If set, use sequential versioning for the new migration. Conflicts with `--timestamp`.
#[clap(short, long, conflicts_with = "timestamp")]
sequential: bool,
},
/// This behavior can be overridden by `--timestamp` or `--sequential`, respectively.
///
/// The default versioning to use can also be set in `sqlx.toml`.
Add(AddMigrationOpts),
/// Run all pending migrations.
Run {
#[clap(flatten)]
source: Source,
source: MigrationSourceOpt,
#[clap(flatten)]
config: ConfigOpt,
/// List all the migrations to be run without applying
#[clap(long)]
@ -195,7 +250,10 @@ pub enum MigrateCommand {
/// Revert the latest migration with a down file.
Revert {
#[clap(flatten)]
source: Source,
source: MigrationSourceOpt,
#[clap(flatten)]
config: ConfigOpt,
/// List the migration to be reverted without applying
#[clap(long)]
@ -217,7 +275,10 @@ pub enum MigrateCommand {
/// List all available migrations.
Info {
#[clap(flatten)]
source: Source,
source: MigrationSourceOpt,
#[clap(flatten)]
config: ConfigOpt,
#[clap(flatten)]
connect_opts: ConnectOpts,
@ -228,7 +289,10 @@ pub enum MigrateCommand {
/// Must be run in a Cargo project root.
BuildScript {
#[clap(flatten)]
source: Source,
source: MigrationSourceOpt,
#[clap(flatten)]
config: ConfigOpt,
/// Overwrite the build script if it already exists.
#[clap(long)]
@ -236,19 +300,62 @@ pub enum MigrateCommand {
},
}
/// Argument for the migration scripts source.
#[derive(Args, Debug)]
pub struct Source {
/// Path to folder containing migrations.
#[clap(long, default_value = "migrations")]
source: String,
pub struct AddMigrationOpts {
pub description: String,
#[clap(flatten)]
pub source: MigrationSourceOpt,
#[clap(flatten)]
pub config: ConfigOpt,
/// If set, create an up-migration only. Conflicts with `--reversible`.
#[clap(long, conflicts_with = "reversible")]
simple: bool,
/// If set, create a pair of up and down migration files with same version.
///
/// Conflicts with `--simple`.
#[clap(short, long, conflicts_with = "simple")]
reversible: bool,
/// If set, use timestamp versioning for the new migration. Conflicts with `--sequential`.
///
/// Timestamp format: `YYYYMMDDHHMMSS`
#[clap(short, long, conflicts_with = "sequential")]
timestamp: bool,
/// If set, use sequential versioning for the new migration. Conflicts with `--timestamp`.
#[clap(short, long, conflicts_with = "timestamp")]
sequential: bool,
}
impl Deref for Source {
type Target = String;
/// Argument for the migration scripts source.
#[derive(Args, Debug)]
pub struct MigrationSourceOpt {
    /// Path to folder containing migrations.
    ///
    /// Defaults to `migrations/` if not specified, but a different default may be set by `sqlx.toml`.
    // `None` means "use the configured default"; resolution happens in
    // `MigrationSourceOpt::resolve_path`.
    #[clap(long)]
    pub source: Option<String>,
}
fn deref(&self) -> &Self::Target {
&self.source
impl MigrationSourceOpt {
    /// The migrations directory to use: the `--source` flag when given,
    /// otherwise the directory configured (or defaulted) by `sqlx.toml`.
    pub fn resolve_path<'a>(&'a self, config: &'a Config) -> &'a str {
        self.source
            .as_deref()
            .unwrap_or_else(|| config.migrate.migrations_dir())
    }

    /// Build a [`Migrator`] from the resolved path, applying the config's
    /// resolve-time settings (e.g. ignored characters).
    pub async fn resolve(&self, config: &Config) -> Result<Migrator, MigrateError> {
        let source = ResolveWith(self.resolve_path(config), config.migrate.to_resolve_config());
        Migrator::new(source).await
    }
}
@ -259,7 +366,7 @@ pub struct ConnectOpts {
pub no_dotenv: NoDotenvOpt,
/// Location of the DB, by default will be read from the DATABASE_URL env var or `.env` files.
#[clap(long, short = 'D', env)]
#[clap(long, short = 'D')]
pub database_url: Option<String>,
/// The maximum time, in seconds, to try connecting to the database server before
@ -290,15 +397,85 @@ pub struct NoDotenvOpt {
pub no_dotenv: bool,
}
// Shared CLI flag group for locating the optional `sqlx.toml` config file;
// flattened into each subcommand that reads configuration.
#[derive(Args, Debug)]
pub struct ConfigOpt {
    /// Override the path to the config file.
    ///
    /// Defaults to `sqlx.toml` in the current directory, if it exists.
    ///
    /// Configuration file loading may be bypassed with `--config=/dev/null` on Linux,
    /// or `--config=NUL` on Windows.
    ///
    /// Config file loading is enabled by the `sqlx-toml` feature.
    #[clap(long)]
    pub config: Option<PathBuf>,
}
impl ConnectOpts {
/// Require a database URL to be provided, otherwise
/// return an error.
pub fn required_db_url(&self) -> anyhow::Result<&str> {
self.database_url.as_deref().ok_or_else(
|| anyhow::anyhow!(
"the `--database-url` option or the `DATABASE_URL` environment variable must be provided"
)
)
pub fn expect_db_url(&self) -> anyhow::Result<&str> {
self.database_url
.as_deref()
.context("BUG: database_url not populated")
}
/// Populate `database_url` from the environment, if not set.
///
/// Reads the variable named by `common.database-url-var` in `sqlx.toml`
/// (defaulting to `DATABASE_URL`) and errors if it is missing or not UTF-8.
pub fn populate_db_url(&mut self, config: &Config) -> anyhow::Result<()> {
    // An explicit `--database-url` always wins; nothing to do.
    if self.database_url.is_some() {
        return Ok(());
    }
    let var = config.common.database_url_var();
    // Mention the config key in messages only when a non-default var is in use.
    let context = if var != "DATABASE_URL" {
        " (`common.database-url-var` in `sqlx.toml`)"
    } else {
        ""
    };
    match env::var(var) {
        Ok(url) => {
            if !context.is_empty() {
                // Tell the user which (non-standard) variable was consulted.
                eprintln!("Read database url from `{var}`{context}");
            }
            self.database_url = Some(url)
        }
        Err(env::VarError::NotPresent) => {
            anyhow::bail!("`--database-url` or `{var}`{context} must be set")
        }
        Err(env::VarError::NotUnicode(_)) => {
            anyhow::bail!("`{var}`{context} is not valid UTF-8");
        }
    }
    Ok(())
}
}
impl ConfigOpt {
pub async fn load_config(&self) -> anyhow::Result<Config> {
let path = self.config.clone();
// Tokio does file I/O on a background task anyway
tokio::task::spawn_blocking(|| {
if let Some(path) = path {
let err_str = format!("error reading config from {path:?}");
Config::try_from_path(path).context(err_str)
} else {
let path = PathBuf::from("sqlx.toml");
if path.exists() {
eprintln!("Found `sqlx.toml` in current directory; reading...");
Ok(Config::try_from_path(path)?)
} else {
Ok(Config::default())
}
}
})
.await
.context("unexpected error loading config")?
}
}
@ -334,3 +511,67 @@ impl Not for IgnoreMissing {
!self.ignore_missing
}
}
impl AddMigrationOpts {
    /// Decide whether the new migration should be reversible (up/down pair).
    ///
    /// Explicit `--reversible`/`--simple` flags win; otherwise the
    /// `migrate.defaults.migration-type` setting from `sqlx.toml` applies,
    /// which may in turn infer the type from the latest existing migration.
    pub fn reversible(&self, config: &Config, migrator: &Migrator) -> bool {
        if self.reversible {
            return true;
        }
        if self.simple {
            return false;
        }
        match config.migrate.defaults.migration_type {
            // No configured preference: mirror the type of the latest migration,
            // defaulting to simple when no migrations exist yet.
            DefaultMigrationType::Inferred => migrator
                .iter()
                .last()
                .is_some_and(|m| m.migration_type.is_reversible()),
            DefaultMigrationType::Simple => false,
            DefaultMigrationType::Reversible => true,
        }
    }
    /// Compute the version prefix (timestamp or zero-padded sequence number)
    /// for the new migration's file name.
    ///
    /// Explicit `--timestamp`/`--sequential` flags win; otherwise the
    /// `migrate.defaults.migration-versioning` setting applies, which may
    /// infer the scheme from the existing migrations.
    pub fn version_prefix(&self, config: &Config, migrator: &Migrator) -> String {
        let default_versioning = &config.migrate.defaults.migration_versioning;
        match (self.timestamp, self.sequential, default_versioning) {
            (true, false, _) | (false, false, DefaultVersioning::Timestamp) => next_timestamp(),
            // Forced sequential: continue from the highest existing version,
            // or start at 1 when there are no migrations.
            (false, true, _) | (false, false, DefaultVersioning::Sequential) => fmt_sequential(
                migrator
                    .migrations
                    .last()
                    .map_or(1, |migration| migration.version + 1),
            ),
            (false, false, DefaultVersioning::Inferred) => {
                // Inspect the last one or two migrations to guess the scheme.
                migrator
                    .migrations
                    .rchunks(2)
                    .next()
                    .and_then(|migrations| {
                        match migrations {
                            [previous, latest] => {
                                // If the latest two versions differ by 1, infer sequential.
                                (latest.version - previous.version == 1)
                                    .then_some(latest.version + 1)
                            }
                            [latest] => {
                                // If only one migration exists and its version is 0 or 1, infer sequential
                                matches!(latest.version, 0 | 1).then_some(latest.version + 1)
                            }
                            // `rchunks(2)` never yields an empty or longer chunk.
                            _ => unreachable!(),
                        }
                    })
                    // No migrations, or versions don't look sequential: timestamp.
                    .map_or_else(next_timestamp, fmt_sequential)
            }
            (true, true, _) => unreachable!("BUG: Clap should have rejected this case"),
        }
    }
}
/// Current UTC time formatted as `YYYYMMDDHHMMSS`, for timestamp versioning.
fn next_timestamp() -> String {
    format!("{}", Utc::now().format("%Y%m%d%H%M%S"))
}
/// Format a sequential version number, zero-padded to at least four digits.
fn fmt_sequential(version: i64) -> String {
    format!("{:04}", version)
}

View File

@ -1,20 +1,11 @@
use anyhow::Context;
use assert_cmd::Command;
use std::cmp::Ordering;
use std::fs::read_dir;
use std::ops::Index;
use std::path::{Path, PathBuf};
use tempfile::TempDir;
/// Conflicting `--timestamp` and `--sequential` flags must make the command
/// fail without creating any migration files.
#[test]
fn add_migration_ambiguous() -> anyhow::Result<()> {
    for reversible in [true, false] {
        let files = AddMigrations::new()?
            // timestamp = true, sequential = true, expect_success = false
            .run("hello world", reversible, true, true, false)?
            .fs_output()?;
        // No files should have been created on failure.
        assert_eq!(files.0, Vec::<FileName>::new());
    }
    Ok(())
}
#[derive(Debug, PartialEq, Eq)]
struct FileName {
id: u64,
@ -34,11 +25,6 @@ impl PartialOrd<Self> for FileName {
impl FileName {
fn assert_is_timestamp(&self) {
//if the library is still used in 2050, this will need bumping ^^
assert!(
self.id < 20500101000000,
"{self:?} is too high for a timestamp"
);
assert!(
self.id > 20200101000000,
"{self:?} is too low for a timestamp"
@ -59,6 +45,154 @@ impl From<PathBuf> for FileName {
}
}
}
/// The migration files found on disk after running `sqlx migrate add`.
struct AddMigrationsResult(Vec<FileName>);

impl AddMigrationsResult {
    /// Number of files created.
    fn len(&self) -> usize {
        self.0.len()
    }

    /// Assert the files form matched `.up.sql`/`.down.sql` pairs with the
    /// expected description.
    fn assert_is_reversible(&self) {
        let mut ups = 0;
        let mut downs = 0;

        for file in &self.0 {
            if file.suffix == "down.sql" {
                downs += 1;
            } else if file.suffix == "up.sql" {
                ups += 1;
            } else {
                panic!("unknown suffix for {file:?}");
            }

            assert!(file.description.starts_with("hello_world"));
        }

        assert_eq!(ups, downs);
    }

    /// Assert every file is a simple (non-reversible) `.sql` migration.
    fn assert_is_not_reversible(&self) {
        for file in &self.0 {
            assert_eq!(file.suffix, "sql");
            assert!(file.description.starts_with("hello_world"));
        }
    }
}

impl Index<usize> for AddMigrationsResult {
    type Output = FileName;

    fn index(&self, index: usize) -> &Self::Output {
        &self.0[index]
    }
}
/// Harness for invoking `cargo sqlx migrate add` inside a temporary directory.
struct AddMigrations {
    tempdir: TempDir,
    // Extra `--config=<path>` argument, set by `with_config`.
    config_arg: Option<String>,
}

impl AddMigrations {
    fn new() -> anyhow::Result<Self> {
        anyhow::Ok(Self {
            tempdir: TempDir::new()?,
            config_arg: None,
        })
    }

    /// Use the named config file from `tests/assets` for subsequent runs.
    fn with_config(mut self, filename: &str) -> anyhow::Result<Self> {
        // Bug fix: build the path from the `filename` parameter; previously a
        // hardcoded placeholder string was used and the argument was ignored.
        let path = format!("./tests/assets/{filename}");

        let path = std::fs::canonicalize(&path)
            .with_context(|| format!("error canonicalizing path {path:?}"))?;

        let path = path
            .to_str()
            .with_context(|| format!("canonicalized version of path {path:?} is not UTF-8"))?;

        self.config_arg = Some(format!("--config={path}"));

        Ok(self)
    }

    /// Run `sqlx migrate add` with the given flags and assert the expected
    /// exit status.
    fn run(
        &self,
        description: &str,
        reversible: bool,
        timestamp: bool,
        sequential: bool,
        expect_success: bool,
    ) -> anyhow::Result<&'_ Self> {
        let cmd_result = Command::cargo_bin("cargo-sqlx")?
            .current_dir(&self.tempdir)
            .args(
                [
                    vec!["sqlx", "migrate", "add", description],
                    self.config_arg.as_deref().map_or(vec![], |arg| vec![arg]),
                    match reversible {
                        true => vec!["-r"],
                        false => vec![],
                    },
                    match timestamp {
                        true => vec!["--timestamp"],
                        false => vec![],
                    },
                    match sequential {
                        true => vec!["--sequential"],
                        false => vec![],
                    },
                ]
                .concat(),
            )
            .env("RUST_BACKTRACE", "1")
            .assert();

        if expect_success {
            cmd_result.success();
        } else {
            cmd_result.failure();
        }

        anyhow::Ok(self)
    }

    /// Collect the files created under the temp dir, as relative paths.
    fn fs_output(&self) -> anyhow::Result<AddMigrationsResult> {
        let files = recurse_files(&self.tempdir)?;

        let mut fs_paths = Vec::with_capacity(files.len());

        for path in files {
            let relative_path = path.strip_prefix(self.tempdir.path())?.to_path_buf();
            fs_paths.push(FileName::from(relative_path));
        }

        Ok(AddMigrationsResult(fs_paths))
    }
}
/// Recursively collect every file path under `path`, sorted.
fn recurse_files(path: impl AsRef<Path>) -> anyhow::Result<Vec<PathBuf>> {
    let mut files = Vec::new();

    for entry in read_dir(path)? {
        let entry = entry?;
        let meta = entry.metadata()?;

        // A path is either a dir or a file, never both; anything else
        // (e.g. a dangling symlink) is skipped.
        if meta.is_dir() {
            files.extend(recurse_files(entry.path())?);
        } else if meta.is_file() {
            files.push(entry.path());
        }
    }

    files.sort();
    Ok(files)
}
/// Conflicting versioning flags must make the command fail cleanly.
#[test]
fn add_migration_error_ambiguous() -> anyhow::Result<()> {
    for reversible in [true, false] {
        let files = AddMigrations::new()?
            // Passing both `--timestamp` and `--sequential` should result in an error.
            .run("hello world", reversible, true, true, false)?
            .fs_output()?;
        // Assert that no files are created
        assert_eq!(files.0, []);
    }
    Ok(())
}
#[test]
fn add_migration_sequential() -> anyhow::Result<()> {
{
@ -74,10 +208,12 @@ fn add_migration_sequential() -> anyhow::Result<()> {
.run("hello world1", false, false, true, true)?
.run("hello world2", true, false, true, true)?
.fs_output()?;
assert_eq!(files.len(), 2);
files.assert_is_not_reversible();
assert_eq!(files.len(), 3);
assert_eq!(files.0[0].id, 1);
assert_eq!(files.0[1].id, 2);
assert_eq!(files.0[1].suffix, "down.sql");
assert_eq!(files.0[2].id, 2);
assert_eq!(files.0[2].suffix, "up.sql");
}
Ok(())
}
@ -126,146 +262,145 @@ fn add_migration_timestamp() -> anyhow::Result<()> {
.run("hello world1", false, true, false, true)?
.run("hello world2", true, false, true, true)?
.fs_output()?;
assert_eq!(files.len(), 2);
files.assert_is_not_reversible();
assert_eq!(files.len(), 3);
files.0[0].assert_is_timestamp();
// sequential -> timestamp is one way
files.0[1].assert_is_timestamp();
files.0[2].assert_is_timestamp();
}
Ok(())
}
#[test]
fn add_migration_timestamp_reversible() -> anyhow::Result<()> {
{
let files = AddMigrations::new()?
.run("hello world", true, false, false, true)?
.fs_output()?;
assert_eq!(files.len(), 2);
files.assert_is_reversible();
files.0[0].assert_is_timestamp();
files.0[1].assert_is_timestamp();
// .up.sql and .down.sql
files[0].assert_is_timestamp();
assert_eq!(files[1].id, files[0].id);
}
{
let files = AddMigrations::new()?
.run("hello world", true, true, false, true)?
.fs_output()?;
assert_eq!(files.len(), 2);
files.assert_is_reversible();
files.0[0].assert_is_timestamp();
files.0[1].assert_is_timestamp();
// .up.sql and .down.sql
files[0].assert_is_timestamp();
assert_eq!(files[1].id, files[0].id);
}
{
let files = AddMigrations::new()?
.run("hello world1", true, true, false, true)?
.run("hello world2", true, false, true, true)?
// Reversible should be inferred, but sequential should be forced
.run("hello world2", false, false, true, true)?
.fs_output()?;
assert_eq!(files.len(), 4);
files.assert_is_reversible();
files.0[0].assert_is_timestamp();
files.0[1].assert_is_timestamp();
files.0[2].assert_is_timestamp();
files.0[3].assert_is_timestamp();
// First pair: .up.sql and .down.sql
files[0].assert_is_timestamp();
assert_eq!(files[1].id, files[0].id);
// Second pair; we set `--sequential` so this version should be one higher
assert_eq!(files[2].id, files[1].id + 1);
assert_eq!(files[3].id, files[1].id + 1);
}
Ok(())
}
struct AddMigrationsResult(Vec<FileName>);
impl AddMigrationsResult {
fn len(&self) -> usize {
self.0.len()
}
fn assert_is_reversible(&self) {
let mut up_cnt = 0;
let mut down_cnt = 0;
for file in self.0.iter() {
if file.suffix == "down.sql" {
down_cnt += 1;
} else if file.suffix == "up.sql" {
up_cnt += 1;
} else {
panic!("unknown suffix for {file:?}");
}
assert!(file.description.starts_with("hello_world"));
}
assert_eq!(up_cnt, down_cnt);
}
fn assert_is_not_reversible(&self) {
for file in self.0.iter() {
assert_eq!(file.suffix, "sql");
assert!(file.description.starts_with("hello_world"));
}
}
}
struct AddMigrations(TempDir);
#[test]
fn add_migration_config_default_type_reversible() -> anyhow::Result<()> {
let files = AddMigrations::new()?
.with_config("config_default_type_reversible.toml")?
// Type should default to reversible without any flags
.run("hello world", false, false, false, true)?
.run("hello world2", false, false, false, true)?
.run("hello world3", false, false, false, true)?
.fs_output()?;
impl AddMigrations {
fn new() -> anyhow::Result<Self> {
anyhow::Ok(Self(TempDir::new()?))
}
fn run(
self,
description: &str,
revesible: bool,
timestamp: bool,
sequential: bool,
expect_success: bool,
) -> anyhow::Result<Self> {
let cmd_result = Command::cargo_bin("cargo-sqlx")?
.current_dir(&self.0)
.args(
[
vec!["sqlx", "migrate", "add", description],
match revesible {
true => vec!["-r"],
false => vec![],
},
match timestamp {
true => vec!["--timestamp"],
false => vec![],
},
match sequential {
true => vec!["--sequential"],
false => vec![],
},
]
.concat(),
)
.assert();
if expect_success {
cmd_result.success();
} else {
cmd_result.failure();
}
anyhow::Ok(self)
}
fn fs_output(&self) -> anyhow::Result<AddMigrationsResult> {
let files = recurse_files(&self.0)?;
let mut fs_paths = Vec::with_capacity(files.len());
for path in files {
let relative_path = path.strip_prefix(self.0.path())?.to_path_buf();
fs_paths.push(FileName::from(relative_path));
}
Ok(AddMigrationsResult(fs_paths))
}
assert_eq!(files.len(), 6);
files.assert_is_reversible();
files[0].assert_is_timestamp();
assert_eq!(files[1].id, files[0].id);
files[2].assert_is_timestamp();
assert_eq!(files[3].id, files[2].id);
files[4].assert_is_timestamp();
assert_eq!(files[5].id, files[4].id);
Ok(())
}
fn recurse_files(path: impl AsRef<Path>) -> anyhow::Result<Vec<PathBuf>> {
let mut buf = vec![];
let entries = read_dir(path)?;
#[test]
fn add_migration_config_default_versioning_sequential() -> anyhow::Result<()> {
let files = AddMigrations::new()?
.with_config("config_default_versioning_sequential.toml")?
// Versioning should default to timestamp without any flags
.run("hello world", false, false, false, true)?
.run("hello world2", false, false, false, true)?
.run("hello world3", false, false, false, true)?
.fs_output()?;
for entry in entries {
let entry = entry?;
let meta = entry.metadata()?;
assert_eq!(files.len(), 3);
files.assert_is_not_reversible();
if meta.is_dir() {
let mut subdir = recurse_files(entry.path())?;
buf.append(&mut subdir);
}
assert_eq!(files[0].id, 1);
assert_eq!(files[1].id, 2);
assert_eq!(files[2].id, 3);
if meta.is_file() {
buf.push(entry.path());
}
}
buf.sort();
Ok(buf)
Ok(())
}
/// With the default config, versioning is inferred (sequential here); after
/// switching to a config with `migration-versioning = "timestamp"`, new
/// migrations get timestamp versions while the old sequential files remain.
#[test]
fn add_migration_config_default_versioning_timestamp() -> anyhow::Result<()> {
    let migrations = AddMigrations::new()?;
    migrations
        .run("hello world", false, false, true, true)?
        // Default config should infer sequential even without passing `--sequential`
        .run("hello world2", false, false, false, true)?
        .run("hello world3", false, false, false, true)?;
    let files = migrations.fs_output()?;
    assert_eq!(files.len(), 3);
    files.assert_is_not_reversible();
    assert_eq!(files[0].id, 1);
    assert_eq!(files[1].id, 2);
    assert_eq!(files[2].id, 3);
    // Now set a config that uses `default-versioning = "timestamp"`
    let migrations = migrations.with_config("config_default_versioning_timestamp.toml")?;
    // Now the default should be a timestamp
    migrations
        .run("hello world4", false, false, false, true)?
        .run("hello world5", false, false, false, true)?;
    let files = migrations.fs_output()?;
    // The three sequential files persist; only the new two are timestamps.
    assert_eq!(files.len(), 5);
    files.assert_is_not_reversible();
    assert_eq!(files[0].id, 1);
    assert_eq!(files[1].id, 2);
    assert_eq!(files[2].id, 3);
    files[3].assert_is_timestamp();
    files[4].assert_is_timestamp();
    Ok(())
}

View File

@ -0,0 +1,2 @@
[migrate.defaults]
migration-type = "reversible"

View File

@ -0,0 +1,2 @@
[migrate.defaults]
migration-versioning = "sequential"

View File

@ -0,0 +1,2 @@
[migrate.defaults]
migration-versioning = "timestamp"

View File

@ -1,25 +1,41 @@
use assert_cmd::{assert::Assert, Command};
use sqlx::_unstable::config::Config;
use sqlx::{migrate::Migrate, Connection, SqliteConnection};
use std::{
env::temp_dir,
fs::remove_file,
env, fs,
path::{Path, PathBuf},
};
pub struct TestDatabase {
file_path: PathBuf,
migrations: String,
migrations_path: PathBuf,
pub config_path: Option<PathBuf>,
}
impl TestDatabase {
pub fn new(name: &str, migrations: &str) -> Self {
let migrations_path = Path::new("tests").join(migrations);
let file_path = Path::new(&temp_dir()).join(format!("test-{}.db", name));
let ret = Self {
// Note: only set when _building_
let temp_dir = option_env!("CARGO_TARGET_TMPDIR").map_or_else(env::temp_dir, PathBuf::from);
let test_dir = temp_dir.join("migrate");
fs::create_dir_all(&test_dir)
.unwrap_or_else(|e| panic!("error creating directory: {test_dir:?}: {e}"));
let file_path = test_dir.join(format!("test-{name}.db"));
if file_path.exists() {
fs::remove_file(&file_path)
.unwrap_or_else(|e| panic!("error deleting test database {file_path:?}: {e}"));
}
let this = Self {
file_path,
migrations: String::from(migrations_path.to_str().unwrap()),
migrations_path: Path::new("tests").join(migrations),
config_path: None,
};
Command::cargo_bin("cargo-sqlx")
.unwrap()
.args([
@ -27,11 +43,15 @@ impl TestDatabase {
"database",
"create",
"--database-url",
&ret.connection_string(),
&this.connection_string(),
])
.assert()
.success();
ret
this
}
/// Point this test database at a different migrations directory under `tests/`.
pub fn set_migrations(&mut self, migrations: &str) {
self.migrations_path = Path::new("tests").join(migrations);
}
pub fn connection_string(&self) -> String {
@ -39,55 +59,77 @@ impl TestDatabase {
}
/// Run (`revert == false`) or revert (`revert == true`) migrations via the `sqlx` CLI.
///
/// Passes `--target-version` when `version` is given, `--dry-run` when requested,
/// and `--config` when a config path has been set on this instance.
// Note: the pre-`migrations_path` implementation (the `cargo-sqlx` invocation with
// a concatenated `vec![]` of args and the `ver` binding) was dead residue and is removed.
pub fn run_migration(&self, revert: bool, version: Option<i64>, dry_run: bool) -> Assert {
    let mut command = Command::cargo_bin("sqlx").unwrap();

    command
        .args([
            "migrate",
            if revert { "revert" } else { "run" },
            "--database-url",
            &self.connection_string(),
            "--source",
        ])
        .arg(&self.migrations_path);

    if let Some(config_path) = &self.config_path {
        command.arg("--config").arg(config_path);
    }

    if let Some(version) = version {
        command.arg("--target-version").arg(version.to_string());
    }

    if dry_run {
        command.arg("--dry-run");
    }

    command.assert()
}
/// Returns the versions of all applied migrations, in the order reported by the database.
// Note: the residual no-argument `list_applied_migrations()` call from the old
// implementation is removed; the method now requires the migrations table name.
pub async fn applied_migrations(&self) -> Vec<i64> {
    let mut conn = SqliteConnection::connect(&self.connection_string())
        .await
        .unwrap();

    // Use the default config to resolve the migrations table name
    // (`_sqlx_migrations` unless overridden).
    let config = Config::default();

    conn.list_applied_migrations(config.migrate.table_name())
        .await
        .unwrap()
        .iter()
        .map(|m| m.version)
        .collect()
}
/// Invoke `sqlx migrate info` for this database and return the command assertion.
pub fn migrate_info(&self) -> Assert {
    let mut cmd = Command::cargo_bin("sqlx").unwrap();

    cmd.arg("migrate")
        .arg("info")
        .arg("--database-url")
        .arg(self.connection_string())
        .arg("--source")
        .arg(&self.migrations_path);

    // Forward the config file if one was set for this test database.
    if let Some(config_path) = &self.config_path {
        cmd.arg("--config").arg(config_path);
    }

    cmd.assert()
}
}
impl Drop for TestDatabase {
    fn drop(&mut self) {
        // Only remove the database if there isn't a failure, so the file can be
        // inspected after a panicking test.
        // Note: the residual unconditional `remove_file(...).unwrap()` line from
        // the old implementation is removed.
        if !std::thread::panicking() {
            fs::remove_file(&self.file_path).unwrap_or_else(|e| {
                panic!("error deleting test database {:?}: {e}", self.file_path)
            });
        }
    }
}

View File

@ -0,0 +1 @@
*.sql text eol=lf

View File

@ -0,0 +1,6 @@
create table user
(
-- integer primary keys are the most efficient in SQLite
user_id integer primary key,
username text unique not null
);

View File

@ -0,0 +1,10 @@
create table post
(
post_id integer primary key,
user_id integer not null references user (user_id),
content text not null,
-- Defaults have to be wrapped in parenthesis
created_at datetime default (datetime('now'))
);
create index post_created_at on post (created_at desc);

View File

@ -0,0 +1,10 @@
create table comment
(
comment_id integer primary key,
post_id integer not null references post (post_id),
user_id integer not null references "user" (user_id),
content text not null,
created_at datetime default (datetime('now'))
);
create index comment_created_at on comment (created_at desc);

View File

@ -0,0 +1 @@
*.sql text eol=crlf

View File

@ -0,0 +1,6 @@
create table user
(
-- integer primary keys are the most efficient in SQLite
user_id integer primary key,
username text unique not null
);

View File

@ -0,0 +1,10 @@
create table post
(
post_id integer primary key,
user_id integer not null references user (user_id),
content text not null,
-- Defaults have to be wrapped in parenthesis
created_at datetime default (datetime('now'))
);
create index post_created_at on post (created_at desc);

View File

@ -0,0 +1,10 @@
create table comment
(
comment_id integer primary key,
post_id integer not null references post (post_id),
user_id integer not null references "user" (user_id),
content text not null,
created_at datetime default (datetime('now'))
);
create index comment_created_at on comment (created_at desc);

View File

@ -0,0 +1 @@
*.sql text eol=lf

View File

@ -0,0 +1,6 @@
create table user
(
-- integer primary keys are the most efficient in SQLite
user_id integer primary key,
username text unique not null
);

View File

@ -0,0 +1,10 @@
create table post
(
post_id integer primary key,
user_id integer not null references user (user_id),
content text not null,
-- Defaults have to be wrapped in parenthesis
created_at datetime default (datetime('now'))
);
create index post_created_at on post (created_at desc);

View File

@ -0,0 +1,10 @@
create table comment
(
comment_id integer primary key,
post_id integer not null references post (post_id),
user_id integer not null references "user" (user_id),
content text not null,
created_at datetime default (datetime('now'))
);
create index comment_created_at on comment (created_at desc);

View File

@ -0,0 +1 @@
*.sql text eol=lf

View File

@ -0,0 +1,6 @@
create table user
(
-- integer primary keys are the most efficient in SQLite
user_id integer primary key,
username text unique not null
);

View File

@ -0,0 +1,10 @@
create table post
(
post_id integer primary key,
user_id integer not null references user (user_id),
content text not null,
-- Defaults have to be wrapped in parenthesis
created_at datetime default (datetime('now'))
);
create index post_created_at on post (created_at desc);

View File

@ -0,0 +1,10 @@
create table comment
(
comment_id integer primary key,
post_id integer not null references post (post_id),
user_id integer not null references "user" (user_id),
content text not null,
created_at datetime default (datetime('now'))
);
create index comment_created_at on comment (created_at desc);

View File

@ -0,0 +1,7 @@
[migrate]
# Ignore common whitespace characters (beware syntactically significant whitespace!)
# Space, tab, CR, LF, zero-width non-breaking space (U+FEFF)
#
# U+FEFF is added by some editors as a magic number at the beginning of a text file indicating it is UTF-8 encoded,
# where it is known as a byte-order mark (BOM): https://en.wikipedia.org/wiki/Byte_order_mark
ignored-chars = [" ", "\t", "\r", "\n", "\uFEFF"]

View File

@ -13,16 +13,13 @@ async fn run_reversible_migrations() {
];
// Without --target-version specified.
{
let db = TestDatabase::new("migrate_run_reversible_latest", "migrations_reversible");
let db = TestDatabase::new("run_reversible_latest", "migrations_reversible");
db.run_migration(false, None, false).success();
assert_eq!(db.applied_migrations().await, all_migrations);
}
// With --target-version specified.
{
let db = TestDatabase::new(
"migrate_run_reversible_latest_explicit",
"migrations_reversible",
);
let db = TestDatabase::new("run_reversible_latest_explicit", "migrations_reversible");
// Move to latest, explicitly specified.
db.run_migration(false, Some(20230501000000), false)
@ -41,10 +38,7 @@ async fn run_reversible_migrations() {
}
// With --target-version, incrementally upgrade.
{
let db = TestDatabase::new(
"migrate_run_reversible_incremental",
"migrations_reversible",
);
let db = TestDatabase::new("run_reversible_incremental", "migrations_reversible");
// First version
db.run_migration(false, Some(20230101000000), false)
@ -92,7 +86,7 @@ async fn revert_migrations() {
// Without --target-version
{
let db = TestDatabase::new("migrate_revert_incremental", "migrations_reversible");
let db = TestDatabase::new("revert_incremental", "migrations_reversible");
db.run_migration(false, None, false).success();
// Dry-run
@ -109,7 +103,7 @@ async fn revert_migrations() {
}
// With --target-version
{
let db = TestDatabase::new("migrate_revert_incremental", "migrations_reversible");
let db = TestDatabase::new("revert_incremental", "migrations_reversible");
db.run_migration(false, None, false).success();
// Dry-run downgrade to version 3.
@ -142,6 +136,32 @@ async fn revert_migrations() {
// Downgrade to zero.
db.run_migration(true, Some(0), false).success();
assert_eq!(db.applied_migrations().await, vec![] as Vec<i64>);
assert_eq!(db.applied_migrations().await, Vec::<i64>::new());
}
}
#[tokio::test]
async fn ignored_chars() {
    let mut db = TestDatabase::new("ignored-chars", "ignored-chars/LF");
    db.config_path = Some("tests/ignored-chars/sqlx.toml".into());

    db.run_migration(false, None, false).success();

    let expected = "1/installed user\n2/installed post\n3/installed comment\n";

    // Each variant directory differs only in characters covered by
    // `migrate.ignored-chars`, so every one must hash to the same checksums:
    // `migrate info` reports them as installed, and re-running is a no-op.
    for source in [
        "ignored-chars/CRLF",
        "ignored-chars/BOM",
        "ignored-chars/oops-all-tabs",
    ] {
        db.set_migrations(source);
        db.migrate_info().success().stdout(expected);
        db.run_migration(false, None, false).success().stdout("");
    }
}

View File

@ -32,6 +32,14 @@ _tls-none = []
# support offline/decoupled building (enables serialization of `Describe`)
offline = ["serde", "either/serde"]
# Enable parsing of `sqlx.toml`.
# For simplicity, the `config` module is always enabled,
# but disabling this disables the `serde` derives and the `toml` crate,
# which is a good bit less code to compile if the feature isn't being used.
sqlx-toml = ["serde", "toml/parse"]
_unstable-doc = ["sqlx-toml"]
[dependencies]
# Runtimes
async-std = { workspace = true, optional = true }
@ -71,6 +79,7 @@ percent-encoding = "2.1.0"
regex = { version = "1.5.5", optional = true }
serde = { version = "1.0.132", features = ["derive", "rc"], optional = true }
serde_json = { version = "1.0.73", features = ["raw_value"], optional = true }
toml = { version = "0.8.16", optional = true }
sha2 = { version = "0.10.0", default-features = false, optional = true }
#sqlformat = "0.2.0"
thiserror = "2.0.0"

View File

@ -44,18 +44,44 @@ impl MigrateDatabase for Any {
}
impl Migrate for AnyConnection {
fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
Box::pin(async { self.get_migrate()?.ensure_migrations_table().await })
fn create_schema_if_not_exists<'e>(
&'e mut self,
schema_name: &'e str,
) -> BoxFuture<'e, Result<(), MigrateError>> {
Box::pin(async {
self.get_migrate()?
.create_schema_if_not_exists(schema_name)
.await
})
}
fn dirty_version(&mut self) -> BoxFuture<'_, Result<Option<i64>, MigrateError>> {
Box::pin(async { self.get_migrate()?.dirty_version().await })
fn ensure_migrations_table<'e>(
&'e mut self,
table_name: &'e str,
) -> BoxFuture<'e, Result<(), MigrateError>> {
Box::pin(async {
self.get_migrate()?
.ensure_migrations_table(table_name)
.await
})
}
fn list_applied_migrations(
&mut self,
) -> BoxFuture<'_, Result<Vec<AppliedMigration>, MigrateError>> {
Box::pin(async { self.get_migrate()?.list_applied_migrations().await })
fn dirty_version<'e>(
&'e mut self,
table_name: &'e str,
) -> BoxFuture<'e, Result<Option<i64>, MigrateError>> {
Box::pin(async { self.get_migrate()?.dirty_version(table_name).await })
}
fn list_applied_migrations<'e>(
&'e mut self,
table_name: &'e str,
) -> BoxFuture<'e, Result<Vec<AppliedMigration>, MigrateError>> {
Box::pin(async {
self.get_migrate()?
.list_applied_migrations(table_name)
.await
})
}
fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> {
@ -66,17 +92,19 @@ impl Migrate for AnyConnection {
Box::pin(async { self.get_migrate()?.unlock().await })
}
fn apply<'e: 'm, 'm>(
fn apply<'e>(
&'e mut self,
migration: &'m Migration,
) -> BoxFuture<'m, Result<Duration, MigrateError>> {
Box::pin(async { self.get_migrate()?.apply(migration).await })
table_name: &'e str,
migration: &'e Migration,
) -> BoxFuture<'e, Result<Duration, MigrateError>> {
Box::pin(async { self.get_migrate()?.apply(table_name, migration).await })
}
fn revert<'e: 'm, 'm>(
fn revert<'e>(
&'e mut self,
migration: &'m Migration,
) -> BoxFuture<'m, Result<Duration, MigrateError>> {
Box::pin(async { self.get_migrate()?.revert(migration).await })
table_name: &'e str,
migration: &'e Migration,
) -> BoxFuture<'e, Result<Duration, MigrateError>> {
Box::pin(async { self.get_migrate()?.revert(table_name, migration).await })
}
}

View File

@ -2,6 +2,7 @@ use crate::database::Database;
use crate::error::Error;
use std::fmt::Debug;
use std::sync::Arc;
pub trait Column: 'static + Send + Sync + Debug {
type Database: Database<Column = Self>;
@ -20,6 +21,61 @@ pub trait Column: 'static + Send + Sync + Debug {
/// Gets the type information for the column.
fn type_info(&self) -> &<Self::Database as Database>::TypeInfo;
/// If this column comes from a table, return the table and original column name.
///
/// Returns [`ColumnOrigin::Expression`] if the column is the result of an expression
/// or else the source table could not be determined.
///
/// Returns [`ColumnOrigin::Unknown`] if the database driver does not have that information,
/// or has not overridden this method.
// This method returns an owned value instead of a reference,
// to give the implementor more flexibility.
fn origin(&self) -> ColumnOrigin {
ColumnOrigin::Unknown
}
}
/// A [`Column`] that originates from a table.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct TableColumn {
    /// The name of the table (optionally schema-qualified) that the column comes from.
    pub table: Arc<str>,
    /// The original name of the column.
    pub name: Arc<str>,
}

/// The possible statuses for our knowledge of the origin of a [`Column`].
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub enum ColumnOrigin {
    /// The column is known to originate from a table.
    ///
    /// Included is the table name and original column name.
    Table(TableColumn),
    /// The column originates from an expression, or else its origin could not be determined.
    Expression,
    /// The database driver does not know the column origin at this time.
    ///
    /// This may happen if:
    /// * The connection is in the middle of executing a query,
    ///   and cannot query the catalog to fetch this information.
    /// * The connection does not have access to the database catalog.
    /// * The implementation of [`Column`] did not override [`Column::origin()`].
    #[default]
    Unknown,
}

impl ColumnOrigin {
    /// Returns the true column origin, if known.
    pub fn table_column(&self) -> Option<&TableColumn> {
        match self {
            Self::Table(table_column) => Some(table_column),
            Self::Expression | Self::Unknown => None,
        }
    }
}
/// A type that can be used to index into a [`Row`] or [`Statement`].

View File

@ -0,0 +1,49 @@
/// Configuration shared by multiple components.
#[derive(Debug, Default)]
#[cfg_attr(
    feature = "sqlx-toml",
    derive(serde::Deserialize),
    serde(default, rename_all = "kebab-case", deny_unknown_fields)
)]
pub struct Config {
    /// Override the database URL environment variable.
    ///
    /// This is used by both the macros and `sqlx-cli`.
    ///
    /// Case-sensitive. Defaults to `DATABASE_URL`.
    ///
    /// Example: Multi-Database Project
    /// -------
    /// You can use multiple databases in the same project by breaking it up into multiple crates,
    /// then using a different environment variable for each.
    ///
    /// For example, with two crates in the workspace named `foo` and `bar`:
    ///
    /// #### `foo/sqlx.toml`
    /// ```toml
    /// [common]
    /// database-url-var = "FOO_DATABASE_URL"
    /// ```
    ///
    /// #### `bar/sqlx.toml`
    /// ```toml
    /// [common]
    /// database-url-var = "BAR_DATABASE_URL"
    /// ```
    ///
    /// #### `.env`
    /// ```text
    /// FOO_DATABASE_URL=postgres://postgres@localhost:5432/foo
    /// BAR_DATABASE_URL=postgres://postgres@localhost:5432/bar
    /// ```
    ///
    /// The query macros used in `foo` will use `FOO_DATABASE_URL`,
    /// and the ones used in `bar` will use `BAR_DATABASE_URL`.
    pub database_url_var: Option<String>,
}

impl Config {
    /// Returns the configured environment variable name,
    /// or `DATABASE_URL` if none was set.
    pub fn database_url_var(&self) -> &str {
        match &self.database_url_var {
            Some(var) => var,
            None => "DATABASE_URL",
        }
    }
}

View File

@ -0,0 +1,418 @@
use std::collections::BTreeMap;
/// Configuration for the `query!()` family of macros.
///
/// See also [`common::Config`][crate::config::common::Config] for renaming `DATABASE_URL`.
#[derive(Debug, Default)]
#[cfg_attr(
feature = "sqlx-toml",
derive(serde::Deserialize),
serde(default, rename_all = "kebab-case", deny_unknown_fields)
)]
pub struct Config {
/// Specify which crates' types to use when types from multiple crates apply.
///
/// See [`PreferredCrates`] for details.
pub preferred_crates: PreferredCrates,
/// Specify global overrides for mapping SQL type names to Rust type names.
///
/// Default type mappings are defined by the database driver.
/// Refer to the `sqlx::types` module for details.
///
/// ## Note: Case-Sensitive
/// Currently, the case of the type name MUST match the name SQLx knows it by.
/// Built-in types are spelled in all-uppercase to match SQL convention.
///
/// However, user-created types in Postgres are all-lowercase unless quoted.
///
/// ## Note: Orthogonal to Nullability
/// These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>`
/// or not. They only override the inner type used.
///
/// ## Note: Schema Qualification (Postgres)
/// Type names may be schema-qualified in Postgres. If so, the schema should be part
/// of the type string, e.g. `'foo.bar'` to reference type `bar` in schema `foo`.
///
/// The schema and/or type name may additionally be quoted in the string
/// for a quoted identifier (see next section).
///
/// Schema qualification should not be used for types in the search path.
///
/// ## Note: Quoted Identifiers (Postgres)
/// Type names using [quoted identifiers in Postgres] must also be specified with quotes here.
///
/// Note, however, that the TOML format parses away the outer pair of quotes,
/// so for quoted names in Postgres, double-quoting is necessary,
/// e.g. `'"Foo"'` for SQL type `"Foo"`.
///
/// To reference a schema-qualified type with a quoted name, use double-quotes after the
/// dot, e.g. `'foo."Bar"'` to reference type `"Bar"` of schema `foo`, and vice versa for
/// quoted schema names.
///
/// We recommend wrapping all type names in single quotes, as shown below,
/// to avoid confusion.
///
/// MySQL/MariaDB and SQLite do not support custom types, so quoting type names should
/// never be necessary.
///
/// [quoted identifiers in Postgres]: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
// Note: we wanted to be able to handle this intelligently,
// but the `toml` crate authors weren't interested: https://github.com/toml-rs/toml/issues/761
//
// We decided to just encourage always quoting type names instead.
/// Example: Custom Wrapper Types
/// -------
/// Does SQLx not support a type that you need? Do you want additional semantics not
/// implemented on the built-in types? You can create a custom wrapper,
/// or use an external crate.
///
/// #### `sqlx.toml`
/// ```toml
/// [macros.type-overrides]
/// # Override a built-in type
/// 'UUID' = "crate::types::MyUuid"
///
/// # Support an external or custom wrapper type (e.g. from the `isn` Postgres extension)
/// # (NOTE: FOR DOCUMENTATION PURPOSES ONLY; THIS CRATE/TYPE DOES NOT EXIST AS OF WRITING)
/// 'isbn13' = "isn_rs::sqlx::ISBN13"
/// ```
///
/// Example: Custom Types in Postgres
/// -------
/// If you have a custom type in Postgres that you want to map without needing to use
/// the type override syntax in `sqlx::query!()` every time, you can specify a global
/// override here.
///
/// For example, a custom enum type `foo`:
///
/// #### Migration or Setup SQL (e.g. `migrations/0_setup.sql`)
/// ```sql
/// CREATE TYPE foo AS ENUM ('Bar', 'Baz');
/// ```
///
/// #### `src/types.rs`
/// ```rust,no_run
/// #[derive(sqlx::Type)]
/// pub enum Foo {
/// Bar,
/// Baz
/// }
/// ```
///
/// If you're not using `PascalCase` in your enum variants then you'll want to use
/// `#[sqlx(rename_all = "<strategy>")]` on your enum.
/// See [`Type`][crate::type::Type] for details.
///
/// #### `sqlx.toml`
/// ```toml
/// [macros.type-overrides]
/// # Map SQL type `foo` to `crate::types::Foo`
/// 'foo' = "crate::types::Foo"
/// ```
///
/// Example: Schema-Qualified Types
/// -------
/// (See `Note` section above for details.)
///
/// ```toml
/// [macros.type-overrides]
/// # Map SQL type `foo.foo` to `crate::types::Foo`
/// 'foo.foo' = "crate::types::Foo"
/// ```
///
/// Example: Quoted Identifiers
/// -------
/// If a type or schema uses quoted identifiers,
/// it must be wrapped in quotes _twice_ for SQLx to know the difference:
///
/// ```toml
/// [macros.type-overrides]
/// # `"Foo"` in SQLx
/// '"Foo"' = "crate::types::Foo"
/// # **NOT** `"Foo"` in SQLx (parses as just `Foo`)
/// "Foo" = "crate::types::Foo"
///
/// # Schema-qualified
/// '"foo".foo' = "crate::types::Foo"
/// 'foo."Foo"' = "crate::types::Foo"
/// '"foo"."Foo"' = "crate::types::Foo"
/// ```
///
/// (See `Note` section above for details.)
// TODO: allow specifying different types for input vs output
// e.g. to accept `&[T]` on input but output `Vec<T>`
pub type_overrides: BTreeMap<SqlType, RustType>,
/// Specify per-table and per-column overrides for mapping SQL types to Rust types.
///
/// Default type mappings are defined by the database driver.
/// Refer to the `sqlx::types` module for details.
///
/// The supported syntax is similar to [`type_overrides`][Self::type_overrides],
/// (with the same caveat for quoted names!) but column names must be qualified
/// by a separately quoted table name, which may optionally be schema-qualified.
///
/// Multiple columns for the same SQL table may be written in the same table in TOML
/// (see examples below).
///
/// ## Note: Orthogonal to Nullability
/// These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>`
/// or not. They only override the inner type used.
///
/// ## Note: Schema Qualification
/// Table names may be schema-qualified. If so, the schema should be part
/// of the table name string, e.g. `'foo.bar'` to reference table `bar` in schema `foo`.
///
/// The schema and/or type name may additionally be quoted in the string
/// for a quoted identifier (see next section).
///
/// Postgres users: schema qualification should not be used for tables in the search path.
///
/// ## Note: Quoted Identifiers
/// Schema, table, or column names using quoted identifiers ([MySQL], [Postgres], [SQLite])
/// in SQL must also be specified with quotes here.
///
/// Postgres and SQLite use double-quotes (`"Foo"`) while MySQL uses backticks (`\`Foo\`).
///
/// Note, however, that the TOML format parses away the outer pair of quotes,
/// so for quoted names in Postgres, double-quoting is necessary,
/// e.g. `'"Foo"'` for SQL name `"Foo"`.
///
/// To reference a schema-qualified table with a quoted name, use the appropriate quotation
/// characters after the dot, e.g. `'foo."Bar"'` to reference table `"Bar"` of schema `foo`,
/// and vice versa for quoted schema names.
///
/// We recommend wrapping all table and column names in single quotes, as shown below,
/// to avoid confusion.
///
/// [MySQL]: https://dev.mysql.com/doc/refman/8.4/en/identifiers.html
/// [Postgres]: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
/// [SQLite]: https://sqlite.org/lang_keywords.html
// Note: we wanted to be able to handle this intelligently,
// but the `toml` crate authors weren't interested: https://github.com/toml-rs/toml/issues/761
//
// We decided to just encourage always quoting type names instead.
///
/// Example
/// -------
///
/// #### `sqlx.toml`
/// ```toml
/// [macros.table-overrides.'foo']
/// # Map column `bar` of table `foo` to Rust type `crate::types::Foo`:
/// 'bar' = "crate::types::Bar"
///
/// # Quoted column name
/// # Note: same quoting requirements as `macros.type_overrides`
/// '"Bar"' = "crate::types::Bar"
///
/// # Note: will NOT work (parses as `Bar`)
/// # "Bar" = "crate::types::Bar"
///
/// # Table name may be quoted (note the wrapping single-quotes)
/// [macros.table-overrides.'"Foo"']
/// 'bar' = "crate::types::Bar"
/// '"Bar"' = "crate::types::Bar"
///
/// # Table name may also be schema-qualified.
/// # Note how the dot is inside the quotes.
/// [macros.table-overrides.'my_schema.my_table']
/// 'my_column' = "crate::types::MyType"
///
/// # Quoted schema, table, and column names
/// [macros.table-overrides.'"My Schema"."My Table"']
/// '"My Column"' = "crate::types::MyType"
/// ```
pub table_overrides: BTreeMap<TableName, BTreeMap<ColumnName, RustType>>,
}
#[derive(Debug, Default)]
#[cfg_attr(
feature = "sqlx-toml",
derive(serde::Deserialize),
serde(default, rename_all = "kebab-case")
)]
pub struct PreferredCrates {
/// Specify the crate to use for mapping date/time types to Rust.
///
/// The default behavior is to use whatever crate is enabled,
/// [`chrono`] or [`time`] (the latter takes precedent).
///
/// [`chrono`]: crate::types::chrono
/// [`time`]: crate::types::time
///
/// Example: Always Use Chrono
/// -------
/// Thanks to Cargo's [feature unification], a crate in the dependency graph may enable
/// the `time` feature of SQLx which will force it on for all crates using SQLx,
/// which will result in problems if your crate wants to use types from [`chrono`].
///
/// You can use the type override syntax (see `sqlx::query!` for details),
/// or you can force an override globally by setting this option.
///
/// #### `sqlx.toml`
/// ```toml
/// [macros.preferred-crates]
/// date-time = "chrono"
/// ```
///
/// [feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification
pub date_time: DateTimeCrate,
/// Specify the crate to use for mapping `NUMERIC` types to Rust.
///
/// The default behavior is to use whatever crate is enabled,
/// [`bigdecimal`] or [`rust_decimal`] (the latter takes precedent).
///
/// [`bigdecimal`]: crate::types::bigdecimal
/// [`rust_decimal`]: crate::types::rust_decimal
///
/// Example: Always Use `bigdecimal`
/// -------
/// Thanks to Cargo's [feature unification], a crate in the dependency graph may enable
/// the `rust_decimal` feature of SQLx which will force it on for all crates using SQLx,
/// which will result in problems if your crate wants to use types from [`bigdecimal`].
///
/// You can use the type override syntax (see `sqlx::query!` for details),
/// or you can force an override globally by setting this option.
///
/// #### `sqlx.toml`
/// ```toml
/// [macros.preferred-crates]
/// numeric = "bigdecimal"
/// ```
///
/// [feature unification]: https://doc.rust-lang.org/cargo/reference/features.html#feature-unification
pub numeric: NumericCrate,
}
/// The preferred crate to use for mapping date/time types to Rust.
#[derive(Debug, Default, PartialEq, Eq)]
#[cfg_attr(
feature = "sqlx-toml",
derive(serde::Deserialize),
serde(rename_all = "snake_case")
)]
pub enum DateTimeCrate {
/// Use whichever crate is enabled (`time` then `chrono`).
#[default]
Inferred,
/// Always use types from [`chrono`][crate::types::chrono].
///
/// ```toml
/// [macros.preferred-crates]
/// date-time = "chrono"
/// ```
Chrono,
/// Always use types from [`time`][crate::types::time].
///
/// ```toml
/// [macros.preferred-crates]
/// date-time = "time"
/// ```
Time,
}
/// The preferred crate to use for mapping `NUMERIC` types to Rust.
#[derive(Debug, Default, PartialEq, Eq)]
#[cfg_attr(
feature = "sqlx-toml",
derive(serde::Deserialize),
serde(rename_all = "snake_case")
)]
pub enum NumericCrate {
/// Use whichever crate is enabled (`rust_decimal` then `bigdecimal`).
#[default]
Inferred,
/// Always use types from [`bigdecimal`][crate::types::bigdecimal].
///
/// ```toml
/// [macros.preferred-crates]
/// numeric = "bigdecimal"
/// ```
#[cfg_attr(feature = "sqlx-toml", serde(rename = "bigdecimal"))]
BigDecimal,
/// Always use types from [`rust_decimal`][crate::types::rust_decimal].
///
/// ```toml
/// [macros.preferred-crates]
/// numeric = "rust_decimal"
/// ```
RustDecimal,
}
/// A SQL type name; may optionally be schema-qualified.
///
/// See [`macros.type-overrides`][Config::type_overrides] for usages.
pub type SqlType = Box<str>;
/// A SQL table name; may optionally be schema-qualified.
///
/// See [`macros.table-overrides`][Config::table_overrides] for usages.
pub type TableName = Box<str>;
/// A column in a SQL table.
///
/// See [`macros.table-overrides`][Config::table_overrides] for usages.
pub type ColumnName = Box<str>;
/// A Rust type name or path.
///
/// Should be a global path (not relative).
pub type RustType = Box<str>;
/// Internal getter methods.
impl Config {
    /// Get the override for a given type name (optionally schema-qualified).
    pub fn type_override(&self, type_name: &str) -> Option<&str> {
        // TODO: make this case-insensitive
        self.type_overrides.get(type_name).map(|rust_type| &**rust_type)
    }

    /// Get the override for a given column and table name (optionally schema-qualified).
    pub fn column_override(&self, table: &str, column: &str) -> Option<&str> {
        let columns = self.table_overrides.get(table)?;
        let rust_type = columns.get(column)?;
        Some(&**rust_type)
    }
}
impl DateTimeCrate {
    /// Returns `true` if the crate choice is left to inference (`self == Self::Inferred`).
    #[inline(always)]
    pub fn is_inferred(&self) -> bool {
        matches!(self, Self::Inferred)
    }

    /// Returns the crate name for an explicit choice, or `None` for `Inferred`.
    #[inline(always)]
    pub fn crate_name(&self) -> Option<&str> {
        match self {
            Self::Chrono => Some("chrono"),
            Self::Time => Some("time"),
            Self::Inferred => None,
        }
    }
}
impl NumericCrate {
    /// Returns `true` if the crate choice is left to inference (`self == Self::Inferred`).
    #[inline(always)]
    pub fn is_inferred(&self) -> bool {
        matches!(self, Self::Inferred)
    }

    /// Returns the crate name for an explicit choice, or `None` for `Inferred`.
    #[inline(always)]
    pub fn crate_name(&self) -> Option<&str> {
        match self {
            Self::BigDecimal => Some("bigdecimal"),
            Self::RustDecimal => Some("rust_decimal"),
            Self::Inferred => None,
        }
    }
}

View File

@ -0,0 +1,212 @@
use std::collections::BTreeSet;
/// Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`.
///
/// ### Note
/// A manually constructed [`Migrator`][crate::migrate::Migrator] will not be aware of these
/// configuration options. We recommend using `sqlx::migrate!()` instead.
///
/// ### Warning: Potential Data Loss or Corruption!
/// Many of these options, if changed after migrations are set up,
/// can result in data loss or corruption of a production database
/// if the proper precautions are not taken.
///
/// Be sure you know what you are doing and that you read all relevant documentation _thoroughly_.
// `Deserialize` is only derived when the `sqlx-toml` feature is enabled, so this type
// still compiles (with all-default values) without pulling in `serde`/`toml`.
#[derive(Debug, Default)]
#[cfg_attr(
    feature = "sqlx-toml",
    derive(serde::Deserialize),
    serde(default, rename_all = "kebab-case", deny_unknown_fields)
)]
pub struct Config {
    /// Specify the names of schemas to create if they don't already exist.
    ///
    /// This is done before checking the existence of the migrations table
    /// (`_sqlx_migrations` or overridden `table_name` below) so that it may be placed in
    /// one of these schemas.
    ///
    /// ### Example
    /// `sqlx.toml`:
    /// ```toml
    /// [migrate]
    /// create-schemas = ["foo"]
    /// ```
    pub create_schemas: BTreeSet<Box<str>>,

    /// Override the name of the table used to track executed migrations.
    ///
    /// May be schema-qualified and/or contain quotes. Defaults to `_sqlx_migrations`.
    ///
    /// Potentially useful for multi-tenant databases.
    ///
    /// ### Warning: Potential Data Loss or Corruption!
    /// Changing this option for a production database will likely result in data loss or corruption
    /// as the migration machinery will no longer be aware of what migrations have been applied
    /// and will attempt to re-run them.
    ///
    /// You should create the new table as a copy of the existing migrations table (with contents!),
    /// and be sure all instances of your application have been migrated to the new
    /// table before deleting the old one.
    ///
    /// ### Example
    /// `sqlx.toml`:
    /// ```toml
    /// [migrate]
    /// # Put `_sqlx_migrations` in schema `foo`
    /// table-name = "foo._sqlx_migrations"
    /// ```
    pub table_name: Option<Box<str>>,

    /// Override the directory used for migrations files.
    ///
    /// Relative to the crate root for `sqlx::migrate!()`, or the current directory for `sqlx-cli`.
    pub migrations_dir: Option<Box<str>>,

    /// Specify characters that should be ignored when hashing migrations.
    ///
    /// Any characters contained in the given array will be dropped when a migration is hashed.
    ///
    /// ### Warning: May Change Hashes for Existing Migrations
    /// Changing the characters considered in hashing migrations will likely
    /// change the output of the hash.
    ///
    /// This may require manual rectification for deployed databases.
    ///
    /// ### Example: Ignore Carriage Return (`<CR>` | `\r`)
    /// Line ending differences between platforms can result in migrations having non-repeatable
    /// hashes. The most common culprit is the carriage return (`<CR>` | `\r`), which Windows
    /// uses in its line endings alongside line feed (`<LF>` | `\n`), often written `CRLF` or `\r\n`,
    /// whereas Linux and macOS use only line feeds.
    ///
    /// `sqlx.toml`:
    /// ```toml
    /// [migrate]
    /// ignored-chars = ["\r"]
    /// ```
    ///
    /// For projects using Git, this can also be addressed using [`.gitattributes`]:
    ///
    /// ```text
    /// # Force newlines in migrations to be line feeds on all platforms
    /// migrations/*.sql text eol=lf
    /// ```
    ///
    /// This may require resetting or re-checking out the migrations files to take effect.
    ///
    /// [`.gitattributes`]: https://git-scm.com/docs/gitattributes
    ///
    /// ### Example: Ignore all Whitespace Characters
    /// To make your migrations amenable to reformatting, you may wish to tell SQLx to ignore
    /// _all_ whitespace characters in migrations.
    ///
    /// ##### Warning: Beware Syntactically Significant Whitespace!
    /// If your migrations use string literals or quoted identifiers which contain whitespace,
    /// this configuration will cause the migration machinery to ignore some changes to these.
    /// This may result in a mismatch between the development and production versions of
    /// your database.
    ///
    /// `sqlx.toml`:
    /// ```toml
    /// [migrate]
    /// # Ignore common whitespace characters when hashing
    /// ignored-chars = [" ", "\t", "\r", "\n"] # Space, tab, CR, LF
    /// ```
    // Likely lower overhead for small sets than `HashSet`.
    pub ignored_chars: BTreeSet<char>,

    /// Specify default options for new migrations created with `sqlx migrate add`.
    pub defaults: MigrationDefaults,
}
// NOTE(review): unlike the `Config` structs, this struct does not use
// `deny_unknown_fields`, so typos in `[migrate.defaults]` keys are silently
// ignored — confirm whether that is intentional.
#[derive(Debug, Default)]
#[cfg_attr(
    feature = "sqlx-toml",
    derive(serde::Deserialize),
    serde(default, rename_all = "kebab-case")
)]
pub struct MigrationDefaults {
    /// Specify the default type of migration that `sqlx migrate add` should create by default.
    ///
    /// ### Example: Use Reversible Migrations by Default
    /// `sqlx.toml`:
    /// ```toml
    /// [migrate.defaults]
    /// migration-type = "reversible"
    /// ```
    pub migration_type: DefaultMigrationType,

    /// Specify the default scheme that `sqlx migrate add` should use for version integers.
    ///
    /// ### Example: Use Sequential Versioning by Default
    /// `sqlx.toml`:
    /// ```toml
    /// [migrate.defaults]
    /// migration-versioning = "sequential"
    /// ```
    pub migration_versioning: DefaultVersioning,
}
/// The type of migration that `sqlx migrate add` should create by default.
///
/// In `sqlx.toml`, values are spelled in snake_case: `"inferred"`, `"simple"`, `"reversible"`.
#[derive(Debug, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "sqlx-toml",
    derive(serde::Deserialize),
    serde(rename_all = "snake_case")
)]
pub enum DefaultMigrationType {
    /// Create the same migration type as that of the latest existing migration,
    /// or `Simple` otherwise.
    #[default]
    Inferred,
    /// Create non-reversible migrations (`<VERSION>_<DESCRIPTION>.sql`) by default.
    Simple,
    /// Create reversible migrations (`<VERSION>_<DESCRIPTION>.up.sql` and `[...].down.sql`) by default.
    Reversible,
}
/// The default scheme that `sqlx migrate add` should use for version integers.
///
/// In `sqlx.toml`, values are spelled in snake_case: `"inferred"`, `"timestamp"`, `"sequential"`.
#[derive(Debug, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "sqlx-toml",
    derive(serde::Deserialize),
    serde(rename_all = "snake_case")
)]
pub enum DefaultVersioning {
    /// Infer the versioning scheme from existing migrations:
    ///
    /// * If the versions of the last two migrations differ by `1`, infer `Sequential`.
    /// * If only one migration exists and has version `1`, infer `Sequential`.
    /// * Otherwise, infer `Timestamp`.
    #[default]
    Inferred,
    /// Use UTC timestamps for migration versions.
    ///
    /// This is the recommended versioning format as it's less likely to collide when multiple
    /// developers are creating migrations on different branches.
    ///
    /// The exact timestamp format is unspecified.
    Timestamp,
    /// Use sequential integers for migration versions.
    Sequential,
}
#[cfg(feature = "migrate")]
impl Config {
    /// Return the configured migrations directory, or the default (`"migrations"`) if unset.
    pub fn migrations_dir(&self) -> &str {
        match &self.migrations_dir {
            Some(dir) => dir,
            None => "migrations",
        }
    }

    /// Return the configured migrations table name, or the default (`"_sqlx_migrations"`) if unset.
    pub fn table_name(&self) -> &str {
        match &self.table_name {
            Some(name) => name,
            None => "_sqlx_migrations",
        }
    }

    /// Build a [`ResolveConfig`][crate::migrate::ResolveConfig] carrying this config's
    /// `ignored-chars` set.
    pub fn to_resolve_config(&self) -> crate::migrate::ResolveConfig {
        let mut resolve = crate::migrate::ResolveConfig::new();
        resolve.ignore_chars(self.ignored_chars.iter().copied());
        resolve
    }
}

207
sqlx-core/src/config/mod.rs Normal file
View File

@ -0,0 +1,207 @@
//! (Exported for documentation only) Guide and reference for `sqlx.toml` files.
//!
//! To use, create a `sqlx.toml` file in your crate root (the same directory as your `Cargo.toml`).
//! The configuration in a `sqlx.toml` configures SQLx *only* for the current crate.
//!
//! Requires the `sqlx-toml` feature (not enabled by default).
//!
//! `sqlx-cli` will also read `sqlx.toml` when running migrations.
//!
//! See the [`Config`] type and its fields for individual configuration options.
//!
//! See the [reference][`_reference`] for the full `sqlx.toml` file.
use std::error::Error;
use std::fmt::Debug;
use std::io;
use std::path::{Path, PathBuf};
/// Configuration shared by multiple components.
///
/// See [`common::Config`] for details.
pub mod common;
/// Configuration for the `query!()` family of macros.
///
/// See [`macros::Config`] for details.
pub mod macros;
/// Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`.
///
/// See [`migrate::Config`] for details.
pub mod migrate;
/// Reference for `sqlx.toml` files
///
/// Source: `sqlx-core/src/config/reference.toml`
///
/// ```toml
#[doc = include_str!("reference.toml")]
/// ```
pub mod _reference {}
#[cfg(all(test, feature = "sqlx-toml"))]
mod tests;
/// The parsed structure of a `sqlx.toml` file.
// `Deserialize` is only derived when the `sqlx-toml` feature is enabled;
// `deny_unknown_fields` rejects mislabeled top-level tables with a parse error.
#[derive(Debug, Default)]
#[cfg_attr(
    feature = "sqlx-toml",
    derive(serde::Deserialize),
    serde(default, rename_all = "kebab-case", deny_unknown_fields)
)]
pub struct Config {
    /// Configuration shared by multiple components.
    ///
    /// See [`common::Config`] for details.
    pub common: common::Config,

    /// Configuration for the `query!()` family of macros.
    ///
    /// See [`macros::Config`] for details.
    pub macros: macros::Config,

    /// Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`.
    ///
    /// See [`migrate::Config`] for details.
    pub migrate: migrate::Config,
}
/// Error returned from various methods of [`Config`].
// `Display` output for each variant comes from the `#[error]` attributes (`thiserror`).
#[derive(thiserror::Error, Debug)]
pub enum ConfigError {
    /// The loading method expected `CARGO_MANIFEST_DIR` to be set and it wasn't.
    ///
    /// This is necessary to locate the root of the crate currently being compiled.
    ///
    /// See [the "Environment Variables" page of the Cargo Book][cargo-env] for details.
    ///
    /// [cargo-env]: https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-crates
    #[error("environment variable `CARGO_MANIFEST_DIR` must be set and valid")]
    Env(
        #[from]
        #[source]
        std::env::VarError,
    ),

    /// No configuration file was found. Not necessarily fatal.
    #[error("config file {path:?} not found")]
    NotFound { path: PathBuf },

    /// An I/O error occurred while attempting to read the config file at `path`.
    ///
    /// If the error is [`io::ErrorKind::NotFound`], [`Self::NotFound`] is returned instead.
    #[error("error reading config file {path:?}")]
    Io {
        path: PathBuf,
        #[source]
        error: io::Error,
    },

    /// An error in the TOML was encountered while parsing the config file at `path`.
    ///
    /// The error gives line numbers and context when printed with `Display`/`ToString`.
    ///
    /// Only returned if the `sqlx-toml` feature is enabled.
    #[error("error parsing config file {path:?}")]
    Parse {
        path: PathBuf,
        /// Type-erased [`toml::de::Error`].
        // Boxed so this variant compiles (and stays small) when the `toml` crate is absent.
        #[source]
        error: Box<dyn Error + Send + Sync + 'static>,
    },

    /// A `sqlx.toml` file was found or specified, but the `sqlx-toml` feature is not enabled.
    #[error("SQLx found config file at {path:?} but the `sqlx-toml` feature was not enabled")]
    ParseDisabled { path: PathBuf },
}
impl ConfigError {
    /// Build a [`ConfigError`] out of a [`std::io::Error`].
    ///
    /// A "not found" error becomes [`Self::NotFound`]; anything else becomes [`Self::Io`].
    pub fn from_io(path: impl Into<PathBuf>, error: io::Error) -> Self {
        let path = path.into();

        match error.kind() {
            io::ErrorKind::NotFound => Self::NotFound { path },
            _ => Self::Io { path, error },
        }
    }

    /// If this error means the file was not found, return the path that was attempted.
    pub fn not_found_path(&self) -> Option<&Path> {
        match self {
            Self::NotFound { path } => Some(path.as_path()),
            _ => None,
        }
    }
}
/// Internal methods for loading a `Config`.
#[allow(clippy::result_large_err)]
impl Config {
    /// Read `$CARGO_MANIFEST_DIR/sqlx.toml`, falling back to `Config::default()` if the
    /// file does not exist.
    ///
    /// Errors if `CARGO_MANIFEST_DIR` is not set, or if the config file exists but could
    /// not be read or parsed.
    ///
    /// NOTE(review): earlier docs claimed the result was cached in a `static`; no caching
    /// happens in this function — any memoization is the caller's responsibility.
    pub fn try_from_crate_or_default() -> Result<Self, ConfigError> {
        Self::read_from(get_crate_path()?).or_else(|e| {
            // A missing file is not fatal here; all other errors propagate.
            if let ConfigError::NotFound { .. } = e {
                Ok(Config::default())
            } else {
                Err(e)
            }
        })
    }

    /// Attempt to read a `Config` from the path given.
    ///
    /// Errors if the config file does not exist, or could not be read or parsed.
    pub fn try_from_path(path: PathBuf) -> Result<Self, ConfigError> {
        Self::read_from(path)
    }

    // Real implementation: read the file and deserialize it with `toml`.
    #[cfg(feature = "sqlx-toml")]
    fn read_from(path: PathBuf) -> Result<Self, ConfigError> {
        // The `toml` crate doesn't provide an incremental reader.
        let toml_s = match std::fs::read_to_string(&path) {
            Ok(toml) => toml,
            Err(error) => {
                return Err(ConfigError::from_io(path, error));
            }
        };

        // TODO: parse and lint TOML structure before deserializing
        // Motivation: https://github.com/toml-rs/toml/issues/761
        tracing::debug!("read config TOML from {path:?}:\n{toml_s}");

        toml::from_str(&toml_s).map_err(|error| ConfigError::Parse {
            path,
            error: Box::new(error),
        })
    }

    // Fallback when `sqlx-toml` is disabled: error if a config file exists,
    // so it is never silently ignored.
    #[cfg(not(feature = "sqlx-toml"))]
    fn read_from(path: PathBuf) -> Result<Self, ConfigError> {
        match path.try_exists() {
            Ok(true) => Err(ConfigError::ParseDisabled { path }),
            Ok(false) => Err(ConfigError::NotFound { path }),
            Err(e) => Err(ConfigError::from_io(path, e)),
        }
    }
}
fn get_crate_path() -> Result<PathBuf, ConfigError> {
let mut path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR")?);
path.push("sqlx.toml");
Ok(path)
}

View File

@ -0,0 +1,194 @@
# `sqlx.toml` reference.
#
# Note: shown values are *not* defaults.
# They are explicitly set to non-default values to test parsing.
# Refer to the comment for a given option for its default value.
###############################################################################################
# Configuration shared by multiple components.
[common]
# Change the environment variable to get the database URL.
#
# This is used by both the macros and `sqlx-cli`.
#
# If not specified, defaults to `DATABASE_URL`
database-url-var = "FOO_DATABASE_URL"
###############################################################################################
# Configuration for the `query!()` family of macros.
[macros]
[macros.preferred-crates]
# Force the macros to use the `chrono` crate for date/time types, even if `time` is enabled.
#
# Defaults to "inferred": use whichever crate is enabled (`time` takes precedence over `chrono`).
date-time = "chrono"
# Or, ensure the macros always prefer `time`
# in case new date/time crates are added in the future:
# date-time = "time"
# Force the macros to use the `rust_decimal` crate for `NUMERIC`, even if `bigdecimal` is enabled.
#
# Defaults to "inferred": use whichever crate is enabled (`bigdecimal` takes precedence over `rust_decimal`).
numeric = "rust_decimal"
# Or, ensure the macros always prefer `bigdecimal`
# in case new decimal crates are added in the future:
# numeric = "bigdecimal"
# Set global overrides for mapping SQL types to Rust types.
#
# Default type mappings are defined by the database driver.
# Refer to the `sqlx::types` module for details.
#
# Postgres users: schema qualification should not be used for types in the search path.
#
# ### Note: Orthogonal to Nullability
# These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>`
# or not. They only override the inner type used.
[macros.type-overrides]
# Override a built-in type (map all `UUID` columns to `crate::types::MyUuid`)
# Note: currently, the case of the type name MUST match.
# Built-in types are spelled in all-uppercase to match SQL convention.
'UUID' = "crate::types::MyUuid"
# Support an external or custom wrapper type (e.g. from the `isn` Postgres extension)
# (NOTE: FOR DOCUMENTATION PURPOSES ONLY; THIS CRATE/TYPE DOES NOT EXIST AS OF WRITING)
'isbn13' = "isn_rs::isbn::ISBN13"
# SQL type `foo` to Rust type `crate::types::Foo`:
'foo' = "crate::types::Foo"
# SQL type `"Bar"` to Rust type `crate::types::Bar`; notice the extra pair of quotes:
'"Bar"' = "crate::types::Bar"
# Will NOT work (the first pair of quotes are parsed by TOML)
# "Bar" = "crate::types::Bar"
# Schema qualified
'foo.bar' = "crate::types::Bar"
# Schema qualified and quoted
'foo."Bar"' = "crate::schema::foo::Bar"
# Quoted schema name
'"Foo".bar' = "crate::schema::foo::Bar"
# Quoted schema and type name
'"Foo"."Bar"' = "crate::schema::foo::Bar"
# Set per-table and per-column overrides for mapping SQL types to Rust types.
#
# Note: table name is required in the header.
#
# Postgres users: schema qualification should not be used for types in the search path.
#
# ### Note: Orthogonal to Nullability
# These overrides do not affect whether `query!()` decides to wrap a column in `Option<_>`
# or not. They only override the inner type used.
[macros.table-overrides.'foo']
# Map column `bar` of table `foo` to Rust type `crate::types::Foo`:
'bar' = "crate::types::Bar"
# Quoted column name
# Note: same quoting requirements as `macros.type-overrides`
'"Bar"' = "crate::types::Bar"
# Note: will NOT work (parses as `Bar`)
# "Bar" = "crate::types::Bar"
# Table name may be quoted (note the wrapping single-quotes)
[macros.table-overrides.'"Foo"']
'bar' = "crate::types::Bar"
'"Bar"' = "crate::types::Bar"
# Table name may also be schema-qualified.
# Note how the dot is inside the quotes.
[macros.table-overrides.'my_schema.my_table']
'my_column' = "crate::types::MyType"
# Quoted schema, table, and column names
[macros.table-overrides.'"My Schema"."My Table"']
'"My Column"' = "crate::types::MyType"
###############################################################################################
# Configuration for migrations when executed using `sqlx::migrate!()` or through `sqlx-cli`.
#
# ### Note
# A manually constructed [`Migrator`][crate::migrate::Migrator] will not be aware of these
# configuration options. We recommend using `sqlx::migrate!()` instead.
#
# ### Warning: Potential Data Loss or Corruption!
# Many of these options, if changed after migrations are set up,
# can result in data loss or corruption of a production database
# if the proper precautions are not taken.
#
# Be sure you know what you are doing and that you read all relevant documentation _thoroughly_.
[migrate]
# Override the name of the table used to track executed migrations.
#
# May be schema-qualified and/or contain quotes. Defaults to `_sqlx_migrations`.
#
# Potentially useful for multi-tenant databases.
#
# ### Warning: Potential Data Loss or Corruption!
# Changing this option for a production database will likely result in data loss or corruption
# as the migration machinery will no longer be aware of what migrations have been applied
# and will attempt to re-run them.
#
# You should create the new table as a copy of the existing migrations table (with contents!),
# and be sure all instances of your application have been migrated to the new
# table before deleting the old one.
table-name = "foo._sqlx_migrations"
# Override the directory used for migrations files.
#
# Relative to the crate root for `sqlx::migrate!()`, or the current directory for `sqlx-cli`.
migrations-dir = "foo/migrations"
# Specify characters that should be ignored when hashing migrations.
#
# Any characters contained in the given set will be dropped when a migration is hashed.
#
# Defaults to an empty array (don't drop any characters).
#
# ### Warning: May Change Hashes for Existing Migrations
# Changing the characters considered in hashing migrations will likely
# change the output of the hash.
#
# This may require manual rectification for deployed databases.
# ignored-chars = []
# Ignore Carriage Returns (`<CR>` | `\r`)
# Note that the TOML format requires double-quoted strings to process escapes.
# ignored-chars = ["\r"]
# Ignore common whitespace characters (beware syntactically significant whitespace!)
# Space, tab, CR, LF, zero-width non-breaking space (U+FEFF)
#
# U+FEFF is added by some editors as a magic number at the beginning of a text file indicating it is UTF-8 encoded,
# where it is known as a byte-order mark (BOM): https://en.wikipedia.org/wiki/Byte_order_mark
ignored-chars = [" ", "\t", "\r", "\n", "\uFEFF"]
# Set default options for new migrations.
[migrate.defaults]
# Specify reversible migrations by default (for `sqlx migrate add`).
#
# Defaults to "inferred": uses the type of the last migration, or "simple" otherwise.
migration-type = "reversible"
# Specify simple (non-reversible) migrations by default.
# migration-type = "simple"
# Specify sequential versioning by default (for `sqlx migrate add`).
#
# Defaults to "inferred": guesses the versioning scheme from the latest migrations,
# or "timestamp" otherwise.
migration-versioning = "sequential"
# Specify timestamp versioning by default.
# migration-versioning = "timestamp"

View File

@ -0,0 +1,93 @@
use crate::config::{self, Config};
use std::collections::BTreeSet;
#[test]
fn reference_parses_as_config() {
    let parsed: Result<Config, _> = toml::from_str(include_str!("reference.toml"));

    // The `Display` impl of `toml::Error` is *actually* more useful than `Debug`
    let config =
        parsed.unwrap_or_else(|e| panic!("expected reference.toml to parse as Config: {e}"));

    assert_common_config(&config.common);
    assert_macros_config(&config.macros);
    assert_migrate_config(&config.migrate);
}
/// Check that `[common]` parsed with the values set in `reference.toml`.
fn assert_common_config(config: &config::common::Config) {
    let url_var = config.database_url_var.as_deref();
    assert_eq!(url_var, Some("FOO_DATABASE_URL"));
}
/// Check that `[macros]` parsed with the values set in `reference.toml`.
fn assert_macros_config(config: &config::macros::Config) {
    use config::macros::*;

    assert_eq!(config.preferred_crates.date_time, DateTimeCrate::Chrono);
    assert_eq!(config.preferred_crates.numeric, NumericCrate::RustDecimal);

    // Type overrides
    // Don't need to cover everything, just some important canaries.
    let type_cases = [
        ("UUID", "crate::types::MyUuid"),
        ("foo", "crate::types::Foo"),
        (r#""Bar""#, "crate::types::Bar"),
        (r#""Foo".bar"#, "crate::schema::foo::Bar"),
        (r#""Foo"."Bar""#, "crate::schema::foo::Bar"),
    ];

    for (sql_type, rust_type) in type_cases {
        assert_eq!(config.type_override(sql_type), Some(rust_type));
    }

    // Column overrides
    let column_cases = [
        ("foo", "bar", "crate::types::Bar"),
        ("foo", r#""Bar""#, "crate::types::Bar"),
        (r#""Foo""#, "bar", "crate::types::Bar"),
        (r#""Foo""#, r#""Bar""#, "crate::types::Bar"),
        ("my_schema.my_table", "my_column", "crate::types::MyType"),
        (
            r#""My Schema"."My Table""#,
            r#""My Column""#,
            "crate::types::MyType",
        ),
    ];

    for (table, column, rust_type) in column_cases {
        assert_eq!(config.column_override(table, column), Some(rust_type));
    }
}
/// Check that `[migrate]` parsed with the values set in `reference.toml`.
fn assert_migrate_config(config: &config::migrate::Config) {
    use config::migrate::*;

    assert_eq!(config.table_name.as_deref(), Some("foo._sqlx_migrations"));
    assert_eq!(config.migrations_dir.as_deref(), Some("foo/migrations"));

    let expected_chars: BTreeSet<char> =
        [' ', '\t', '\r', '\n', '\u{FEFF}'].into_iter().collect();
    assert_eq!(config.ignored_chars, expected_chars);

    assert_eq!(
        config.defaults.migration_type,
        DefaultMigrationType::Reversible
    );
    assert_eq!(
        config.defaults.migration_versioning,
        DefaultVersioning::Sequential
    );
}

View File

@ -91,6 +91,8 @@ pub mod any;
#[cfg(feature = "migrate")]
pub mod testing;
pub mod config;
pub use error::{Error, Result};
pub use either::Either;

View File

@ -39,4 +39,7 @@ pub enum MigrateError {
"migration {0} is partially applied; fix and remove row from `_sqlx_migrations` table"
)]
Dirty(i64),
#[error("database driver does not support creation of schemas at migrate time: {0}")]
CreateSchemasNotSupported(String),
}

View File

@ -25,18 +25,31 @@ pub trait MigrateDatabase {
// 'e = Executor
pub trait Migrate {
/// Create a database schema with the given name if it does not already exist.
fn create_schema_if_not_exists<'e>(
&'e mut self,
schema_name: &'e str,
) -> BoxFuture<'e, Result<(), MigrateError>>;
// ensure migrations table exists
// will create or migrate it if needed
fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>>;
fn ensure_migrations_table<'e>(
&'e mut self,
table_name: &'e str,
) -> BoxFuture<'e, Result<(), MigrateError>>;
// Return the version on which the database is dirty or None otherwise.
// "dirty" means there is a partially applied migration that failed.
fn dirty_version(&mut self) -> BoxFuture<'_, Result<Option<i64>, MigrateError>>;
fn dirty_version<'e>(
&'e mut self,
table_name: &'e str,
) -> BoxFuture<'e, Result<Option<i64>, MigrateError>>;
// Return the ordered list of applied migrations
fn list_applied_migrations(
&mut self,
) -> BoxFuture<'_, Result<Vec<AppliedMigration>, MigrateError>>;
fn list_applied_migrations<'e>(
&'e mut self,
table_name: &'e str,
) -> BoxFuture<'e, Result<Vec<AppliedMigration>, MigrateError>>;
// Should acquire a database lock so that only one migration process
// can run at a time. [`Migrate`] will call this function before applying
@ -50,16 +63,18 @@ pub trait Migrate {
// run SQL from migration in a DDL transaction
// insert new row to [_migrations] table on completion (success or failure)
// returns the time taking to run the migration SQL
fn apply<'e: 'm, 'm>(
fn apply<'e>(
&'e mut self,
migration: &'m Migration,
) -> BoxFuture<'m, Result<Duration, MigrateError>>;
table_name: &'e str,
migration: &'e Migration,
) -> BoxFuture<'e, Result<Duration, MigrateError>>;
// run a revert SQL from migration in a DDL transaction
// deletes the row in [_migrations] table with specified migration version on completion (success or failure)
// returns the time taking to run the migration SQL
fn revert<'e: 'm, 'm>(
fn revert<'e>(
&'e mut self,
migration: &'m Migration,
) -> BoxFuture<'m, Result<Duration, MigrateError>>;
table_name: &'e str,
migration: &'e Migration,
) -> BoxFuture<'e, Result<Duration, MigrateError>>;
}

View File

@ -1,6 +1,5 @@
use std::borrow::Cow;
use sha2::{Digest, Sha384};
use std::borrow::Cow;
use super::MigrationType;
@ -22,8 +21,26 @@ impl Migration {
sql: Cow<'static, str>,
no_tx: bool,
) -> Self {
let checksum = Cow::Owned(Vec::from(Sha384::digest(sql.as_bytes()).as_slice()));
let checksum = checksum(&sql);
Self::with_checksum(
version,
description,
migration_type,
sql,
checksum.into(),
no_tx,
)
}
pub(crate) fn with_checksum(
version: i64,
description: Cow<'static, str>,
migration_type: MigrationType,
sql: Cow<'static, str>,
checksum: Cow<'static, [u8]>,
no_tx: bool,
) -> Self {
Migration {
version,
description,
@ -40,3 +57,39 @@ pub struct AppliedMigration {
pub version: i64,
pub checksum: Cow<'static, [u8]>,
}
/// Compute the SHA-384 digest of a migration's SQL.
pub fn checksum(sql: &str) -> Vec<u8> {
    Sha384::digest(sql).as_slice().to_vec()
}
/// Compute the SHA-384 digest of a migration's SQL supplied as a sequence of fragments.
///
/// Produces the same digest as [`checksum`] over the concatenation of the fragments.
pub fn checksum_fragments<'a>(fragments: impl Iterator<Item = &'a str>) -> Vec<u8> {
    let mut hasher = Sha384::new();
    fragments.for_each(|fragment| hasher.update(fragment));
    hasher.finalize().to_vec()
}
#[test]
fn fragments_checksum_equals_full_checksum() {
    // Feeding the SQL to the hasher piecewise must yield the same digest as
    // hashing the whole string in one call.
    // Copied from `examples/postgres/axum-social-with-tests/migrations/3_comment.sql`
    let sql = "\
        \u{FEFF}create table comment (\r\n\
        \tcomment_id uuid primary key default gen_random_uuid(),\r\n\
        \tpost_id uuid not null references post(post_id),\r\n\
        \tuser_id uuid not null references \"user\"(user_id),\r\n\
        \tcontent text not null,\r\n\
        \tcreated_at timestamptz not null default now()\r\n\
        );\r\n\
        \r\n\
        create index on comment(post_id, created_at);\r\n\
    ";

    // `split("")` yields each character as its own fragment (plus an empty string at
    // each end, which should not contribute to the hash).
    let fragments_checksum = checksum_fragments(sql.split(""));
    let full_checksum = checksum(sql);

    assert_eq!(fragments_checksum, full_checksum);
}

View File

@ -74,8 +74,9 @@ impl MigrationType {
}
}
#[deprecated = "unused"]
pub fn infer(migrator: &Migrator, reversible: bool) -> MigrationType {
match migrator.iter().next() {
match migrator.iter().last() {
Some(first_migration) => first_migration.migration_type,
None => {
if reversible {

View File

@ -23,25 +23,11 @@ pub struct Migrator {
pub locking: bool,
#[doc(hidden)]
pub no_tx: bool,
}
#[doc(hidden)]
pub table_name: Cow<'static, str>,
fn validate_applied_migrations(
applied_migrations: &[AppliedMigration],
migrator: &Migrator,
) -> Result<(), MigrateError> {
if migrator.ignore_missing {
return Ok(());
}
let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect();
for applied_migration in applied_migrations {
if !migrations.contains(&applied_migration.version) {
return Err(MigrateError::VersionMissing(applied_migration.version));
}
}
Ok(())
#[doc(hidden)]
pub create_schemas: Cow<'static, [Cow<'static, str>]>,
}
impl Migrator {
@ -51,6 +37,8 @@ impl Migrator {
ignore_missing: false,
no_tx: false,
locking: true,
table_name: Cow::Borrowed("_sqlx_migrations"),
create_schemas: Cow::Borrowed(&[]),
};
/// Creates a new instance with the given source.
@ -81,6 +69,38 @@ impl Migrator {
})
}
/// Override the name of the table used to track executed migrations.
///
/// May be schema-qualified and/or contain quotes. Defaults to `_sqlx_migrations`.
///
/// Potentially useful for multi-tenant databases.
///
/// ### Warning: Potential Data Loss or Corruption!
/// Changing this option for a production database will likely result in data loss or corruption
/// as the migration machinery will no longer be aware of what migrations have been applied
/// and will attempt to re-run them.
///
/// You should create the new table as a copy of the existing migrations table (with contents!),
/// and be sure all instances of your application have been migrated to the new
/// table before deleting the old one.
// NOTE: returns `&Self` (not `&mut Self`), matching `set_ignore_missing()` below.
pub fn dangerous_set_table_name(&mut self, table_name: impl Into<Cow<'static, str>>) -> &Self {
    self.table_name = table_name.into();
    self
}
/// Add a schema name to be created if it does not already exist.
///
/// May be used with [`Self::dangerous_set_table_name()`] to place the migrations table
/// in a new schema without requiring it to exist first.
///
/// ### Note: Support Depends on Database
/// SQLite cannot create new schemas without attaching them to a database file,
/// the path of which must be specified separately in an [`ATTACH DATABASE`](https://www.sqlite.org/lang_attach.html) command.
pub fn create_schema(&mut self, schema_name: impl Into<Cow<'static, str>>) -> &Self {
    // `Cow::to_mut` clones the (initially borrowed) slice into a `Vec` on first use.
    self.create_schemas.to_mut().push(schema_name.into());
    self
}
/// Specify whether applied migrations that are missing from the resolved migrations should be ignored.
pub fn set_ignore_missing(&mut self, ignore_missing: bool) -> &Self {
self.ignore_missing = ignore_missing;
@ -134,12 +154,21 @@ impl Migrator {
<A::Connection as Deref>::Target: Migrate,
{
let mut conn = migrator.acquire().await?;
self.run_direct(&mut *conn).await
self.run_direct(None, &mut *conn).await
}
pub async fn run_to<'a, A>(&self, target: i64, migrator: A) -> Result<(), MigrateError>
where
A: Acquire<'a>,
<A::Connection as Deref>::Target: Migrate,
{
let mut conn = migrator.acquire().await?;
self.run_direct(Some(target), &mut *conn).await
}
// Getting around the annoying "implementation of `Acquire` is not general enough" error
#[doc(hidden)]
pub async fn run_direct<C>(&self, conn: &mut C) -> Result<(), MigrateError>
pub async fn run_direct<C>(&self, target: Option<i64>, conn: &mut C) -> Result<(), MigrateError>
where
C: Migrate,
{
@ -148,16 +177,20 @@ impl Migrator {
conn.lock().await?;
}
for schema_name in self.create_schemas.iter() {
conn.create_schema_if_not_exists(schema_name).await?;
}
// creates [_migrations] table only if needed
// eventually this will likely migrate previous versions of the table
conn.ensure_migrations_table().await?;
conn.ensure_migrations_table(&self.table_name).await?;
let version = conn.dirty_version().await?;
let version = conn.dirty_version(&self.table_name).await?;
if let Some(version) = version {
return Err(MigrateError::Dirty(version));
}
let applied_migrations = conn.list_applied_migrations().await?;
let applied_migrations = conn.list_applied_migrations(&self.table_name).await?;
validate_applied_migrations(&applied_migrations, self)?;
let applied_migrations: HashMap<_, _> = applied_migrations
@ -166,6 +199,11 @@ impl Migrator {
.collect();
for migration in self.iter() {
if target.is_some_and(|target| target < migration.version) {
// Target version reached
break;
}
if migration.migration_type.is_down_migration() {
continue;
}
@ -177,7 +215,7 @@ impl Migrator {
}
}
None => {
conn.apply(migration).await?;
conn.apply(&self.table_name, migration).await?;
}
}
}
@ -222,14 +260,14 @@ impl Migrator {
// creates [_migrations] table only if needed
// eventually this will likely migrate previous versions of the table
conn.ensure_migrations_table().await?;
conn.ensure_migrations_table(&self.table_name).await?;
let version = conn.dirty_version().await?;
let version = conn.dirty_version(&self.table_name).await?;
if let Some(version) = version {
return Err(MigrateError::Dirty(version));
}
let applied_migrations = conn.list_applied_migrations().await?;
let applied_migrations = conn.list_applied_migrations(&self.table_name).await?;
validate_applied_migrations(&applied_migrations, self)?;
let applied_migrations: HashMap<_, _> = applied_migrations
@ -244,7 +282,7 @@ impl Migrator {
.filter(|m| applied_migrations.contains_key(&m.version))
.filter(|m| m.version > target)
{
conn.revert(migration).await?;
conn.revert(&self.table_name, migration).await?;
}
// unlock the migrator to allow other migrators to run
@ -256,3 +294,22 @@ impl Migrator {
Ok(())
}
}
/// Check that every migration recorded as applied in the database is also
/// known to the local `migrator`.
///
/// Returns [`MigrateError::VersionMissing`] for the first applied migration
/// whose version has no matching local source, unless the migrator was
/// configured with `ignore_missing`, in which case validation is skipped.
fn validate_applied_migrations(
    applied_migrations: &[AppliedMigration],
    migrator: &Migrator,
) -> Result<(), MigrateError> {
    // Validation is explicitly opt-out.
    if migrator.ignore_missing {
        return Ok(());
    }

    let known_versions: HashSet<_> = migrator.iter().map(|m| m.version).collect();

    // Surface the first applied version that is unknown to the local source.
    match applied_migrations
        .iter()
        .find(|applied| !known_versions.contains(&applied.version))
    {
        Some(missing) => Err(MigrateError::VersionMissing(missing.version)),
        None => Ok(()),
    }
}

View File

@ -11,7 +11,7 @@ pub use migrate::{Migrate, MigrateDatabase};
pub use migration::{AppliedMigration, Migration};
pub use migration_type::MigrationType;
pub use migrator::Migrator;
pub use source::MigrationSource;
pub use source::{MigrationSource, ResolveConfig, ResolveWith};
#[doc(hidden)]
pub use source::resolve_blocking;
pub use source::{resolve_blocking, resolve_blocking_with_config};

View File

@ -1,8 +1,9 @@
use crate::error::BoxDynError;
use crate::migrate::{Migration, MigrationType};
use crate::migrate::{migration, Migration, MigrationType};
use futures_core::future::BoxFuture;
use std::borrow::Cow;
use std::collections::BTreeSet;
use std::fmt::Debug;
use std::fs;
use std::io;
@ -28,19 +29,48 @@ pub trait MigrationSource<'s>: Debug {
impl<'s> MigrationSource<'s> for &'s Path {
fn resolve(self) -> BoxFuture<'s, Result<Vec<Migration>, BoxDynError>> {
Box::pin(async move {
let canonical = self.canonicalize()?;
let migrations_with_paths =
crate::rt::spawn_blocking(move || resolve_blocking(&canonical)).await?;
Ok(migrations_with_paths.into_iter().map(|(m, _p)| m).collect())
})
// Behavior changed from previous because `canonicalize()` is potentially blocking
// since it might require going to disk to fetch filesystem data.
self.to_owned().resolve()
}
}
impl MigrationSource<'static> for PathBuf {
fn resolve(self) -> BoxFuture<'static, Result<Vec<Migration>, BoxDynError>> {
Box::pin(async move { self.as_path().resolve().await })
// Technically this could just be `Box::pin(spawn_blocking(...))`
// but that would actually be a breaking behavior change because it would call
// `spawn_blocking()` on the current thread
Box::pin(async move {
crate::rt::spawn_blocking(move || {
let migrations_with_paths = resolve_blocking(&self)?;
Ok(migrations_with_paths.into_iter().map(|(m, _p)| m).collect())
})
.await
})
}
}
/// A [`MigrationSource`] implementation with configurable resolution.
///
/// `S` may be `PathBuf`, `&Path` or any type that implements `Into<PathBuf>`.
///
/// The wrapped path (field `.0`) is resolved using the settings in the
/// attached [`ResolveConfig`] (field `.1`), such as characters to ignore
/// when calculating migration checksums.
///
/// See [`ResolveConfig`] for details.
#[derive(Debug)]
pub struct ResolveWith<S>(pub S, pub ResolveConfig);
impl<'s, S: Debug + Into<PathBuf> + Send + 's> MigrationSource<'s> for ResolveWith<S> {
fn resolve(self) -> BoxFuture<'s, Result<Vec<Migration>, BoxDynError>> {
Box::pin(async move {
let path = self.0.into();
let config = self.1;
let migrations_with_paths =
crate::rt::spawn_blocking(move || resolve_blocking_with_config(&path, &config))
.await?;
Ok(migrations_with_paths.into_iter().map(|(m, _p)| m).collect())
})
}
}
@ -52,11 +82,87 @@ pub struct ResolveError {
source: Option<io::Error>,
}
/// Configuration for migration resolution using [`ResolveWith`].
#[derive(Debug, Default)]
pub struct ResolveConfig {
    // Characters excluded from migration checksums.
    // `BTreeSet` keeps the set deduplicated with a deterministic iteration order.
    ignored_chars: BTreeSet<char>,
}

impl ResolveConfig {
    /// Return a default, empty configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Ignore a single character when hashing migrations.
    ///
    /// The migration SQL string itself still contains the character;
    /// it is merely excluded from the checksum calculation.
    ///
    /// This can be used to ignore whitespace characters so changing formatting
    /// does not change the checksum.
    ///
    /// Adding the same `char` more than once is a no-op.
    ///
    /// ### Note: Changes Migration Checksum
    /// This will change the checksum of resolved migrations,
    /// which may cause problems with existing deployments.
    ///
    /// **Use at your own risk.**
    pub fn ignore_char(&mut self, c: char) -> &mut Self {
        // Delegate to the batch form; inserting one element is the same operation.
        self.ignore_chars([c])
    }

    /// Ignore one or more characters when hashing migrations.
    ///
    /// The migration SQL string itself still contains these characters;
    /// they are merely excluded from the checksum calculation.
    ///
    /// This can be used to ignore whitespace characters so changing formatting
    /// does not change the checksum.
    ///
    /// Adding the same `char` more than once is a no-op.
    ///
    /// ### Note: Changes Migration Checksum
    /// This will change the checksum of resolved migrations,
    /// which may cause problems with existing deployments.
    ///
    /// **Use at your own risk.**
    pub fn ignore_chars(&mut self, chars: impl IntoIterator<Item = char>) -> &mut Self {
        self.ignored_chars.extend(chars);
        self
    }

    /// Iterate over the set of ignored characters, in ascending order.
    ///
    /// Duplicate `char`s are not included.
    pub fn ignored_chars(&self) -> impl Iterator<Item = char> + '_ {
        self.ignored_chars.iter().copied()
    }
}
// FIXME: paths should just be part of `Migration` but we can't add a field backwards compatibly
// since it's `#[non_exhaustive]`.
#[doc(hidden)]
pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, ResolveError> {
let s = fs::read_dir(path).map_err(|e| ResolveError {
message: format!("error reading migration directory {}: {e}", path.display()),
resolve_blocking_with_config(path, &ResolveConfig::new())
}
#[doc(hidden)]
pub fn resolve_blocking_with_config(
path: &Path,
config: &ResolveConfig,
) -> Result<Vec<(Migration, PathBuf)>, ResolveError> {
let path = path.canonicalize().map_err(|e| ResolveError {
message: format!("error canonicalizing path {}", path.display()),
source: Some(e),
})?;
let s = fs::read_dir(&path).map_err(|e| ResolveError {
message: format!("error reading migration directory {}", path.display()),
source: Some(e),
})?;
@ -65,7 +171,7 @@ pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, Resolv
for res in s {
let entry = res.map_err(|e| ResolveError {
message: format!(
"error reading contents of migration directory {}: {e}",
"error reading contents of migration directory {}",
path.display()
),
source: Some(e),
@ -126,12 +232,15 @@ pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, Resolv
// opt-out of migration transaction
let no_tx = sql.starts_with("-- no-transaction");
let checksum = checksum_with(&sql, &config.ignored_chars);
migrations.push((
Migration::new(
Migration::with_checksum(
version,
Cow::Owned(description),
migration_type,
Cow::Owned(sql),
checksum.into(),
no_tx,
),
entry_path,
@ -143,3 +252,47 @@ pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, Resolv
Ok(migrations)
}
/// Compute the checksum of `sql`, excluding any characters in `ignored_chars`.
fn checksum_with(sql: &str, ignored_chars: &BTreeSet<char>) -> Vec<u8> {
    if ignored_chars.is_empty() {
        // Fast path: hashing the raw bytes avoids UTF-8 decoding entirely.
        migration::checksum(sql)
    } else {
        // Split around every ignored char and hash the surviving fragments,
        // which is equivalent to hashing the string with those chars removed.
        migration::checksum_fragments(sql.split(|c| ignored_chars.contains(&c)))
    }
}
// Regression test: hashing with `ignored_chars` must produce the same digest
// as hashing the equivalent string with those characters physically removed.
#[test]
fn checksum_with_ignored_chars() {
    // Ensure that `checksum_with` returns the same digest for a given set of ignored chars
    // as the equivalent string with the characters removed.
    let ignored_chars = [
        ' ', '\t', '\r', '\n',
        // Zero-width non-breaking space (ZWNBSP), often added as a magic-number at the beginning
        // of UTF-8 encoded files as a byte-order mark (BOM):
        // https://en.wikipedia.org/wiki/Byte_order_mark
        '\u{FEFF}',
    ];

    // Copied from `examples/postgres/axum-social-with-tests/migrations/3_comment.sql`
    // Note: deliberately starts with a BOM and uses CRLF line endings so every
    // ignored character above actually appears in the input.
    let sql = "\
        \u{FEFF}create table comment (\r\n\
        \tcomment_id uuid primary key default gen_random_uuid(),\r\n\
        \tpost_id uuid not null references post(post_id),\r\n\
        \tuser_id uuid not null references \"user\"(user_id),\r\n\
        \tcontent text not null,\r\n\
        \tcreated_at timestamptz not null default now()\r\n\
        );\r\n\
        \r\n\
        create index on comment(post_id, created_at);\r\n\
    ";

    // Baseline digest: remove the ignored chars from the string, hash normally.
    let stripped_sql = sql.replace(&ignored_chars[..], "");
    // Shadow the array as a set, matching `checksum_with`'s parameter type.
    let ignored_chars = BTreeSet::from(ignored_chars);

    let digest_ignored = checksum_with(sql, &ignored_chars);
    let digest_stripped = migration::checksum(&stripped_sql);

    assert_eq!(digest_ignored, digest_stripped);
}

View File

@ -256,7 +256,7 @@ async fn setup_test_db<DB: Database>(
if let Some(migrator) = args.migrator {
migrator
.run_direct(&mut conn)
.run_direct(None, &mut conn)
.await
.expect("failed to apply migrations");
}

View File

@ -1,3 +1,4 @@
use crate::config::macros::PreferredCrates;
use crate::database::Database;
use crate::decode::Decode;
use crate::type_info::TypeInfo;
@ -26,12 +27,18 @@ pub trait TypeChecking: Database {
///
/// If the type has a borrowed equivalent suitable for query parameters,
/// this is that borrowed type.
fn param_type_for_id(id: &Self::TypeInfo) -> Option<&'static str>;
fn param_type_for_id(
id: &Self::TypeInfo,
preferred_crates: &PreferredCrates,
) -> Result<&'static str, Error>;
/// Get the full path of the Rust type that corresponds to the given `TypeInfo`, if applicable.
///
/// Always returns the owned version of the type, suitable for decoding from `Row`.
fn return_type_for_id(id: &Self::TypeInfo) -> Option<&'static str>;
fn return_type_for_id(
id: &Self::TypeInfo,
preferred_crates: &PreferredCrates,
) -> Result<&'static str, Error>;
/// Get the name of the Cargo feature gate that must be enabled to process the given `TypeInfo`,
/// if applicable.
@ -43,6 +50,22 @@ pub trait TypeChecking: Database {
fn fmt_value_debug(value: &<Self as Database>::Value) -> FmtValue<'_, Self>;
}
/// Result type returned by [`TypeChecking`] methods; the error type defaults to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Errors that can occur while mapping a SQL type to a Rust type for the macros.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// No built-in Rust type corresponds to the SQL type; the user may need a
    /// type override (e.g. in `sqlx.toml`).
    #[error("no built-in mapping found for SQL type; a type override may be required")]
    NoMappingFound,

    /// `macros.preferred-crates.date-time` names a crate whose Cargo feature is not enabled.
    #[error("Cargo feature for configured `macros.preferred-crates.date-time` not enabled")]
    DateTimeCrateFeatureNotEnabled,

    /// `macros.preferred-crates.numeric` names a crate whose Cargo feature is not enabled.
    #[error("Cargo feature for configured `macros.preferred-crates.numeric` not enabled")]
    NumericCrateFeatureNotEnabled,

    /// More than one enabled date-time crate could represent the type;
    /// `fallback` is the mapping that will be used if the caller proceeds.
    #[error("multiple date-time types are possible; falling back to `{fallback}`")]
    AmbiguousDateTimeType { fallback: &'static str },

    /// More than one enabled numeric crate could represent the type;
    /// `fallback` is the mapping that will be used if the caller proceeds.
    #[error("multiple numeric types are possible; falling back to `{fallback}`")]
    AmbiguousNumericType { fallback: &'static str },
}
/// An adapter for [`Value`] which attempts to decode the value and format it when printed using [`Debug`].
pub struct FmtValue<'v, DB>
where
@ -140,36 +163,304 @@ macro_rules! impl_type_checking {
},
ParamChecking::$param_checking:ident,
feature-types: $ty_info:ident => $get_gate:expr,
datetime-types: {
chrono: {
$($chrono_ty:ty $(| $chrono_input:ty)?),*$(,)?
},
time: {
$($time_ty:ty $(| $time_input:ty)?),*$(,)?
},
},
numeric-types: {
bigdecimal: {
$($bigdecimal_ty:ty $(| $bigdecimal_input:ty)?),*$(,)?
},
rust_decimal: {
$($rust_decimal_ty:ty $(| $rust_decimal_input:ty)?),*$(,)?
},
},
) => {
impl $crate::type_checking::TypeChecking for $database {
const PARAM_CHECKING: $crate::type_checking::ParamChecking = $crate::type_checking::ParamChecking::$param_checking;
fn param_type_for_id(info: &Self::TypeInfo) -> Option<&'static str> {
match () {
fn param_type_for_id(
info: &Self::TypeInfo,
preferred_crates: &$crate::config::macros::PreferredCrates,
) -> Result<&'static str, $crate::type_checking::Error> {
use $crate::config::macros::{DateTimeCrate, NumericCrate};
use $crate::type_checking::Error;
// Check non-special types
// ---------------------
$(
$(#[$meta])?
if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info {
return Ok($crate::select_input_type!($ty $(, $input)?));
}
)*
$(
$(#[$meta])?
if <$ty as sqlx_core::types::Type<$database>>::compatible(info) {
return Ok($crate::select_input_type!($ty $(, $input)?));
}
)*
// Check `macros.preferred-crates.date-time`
//
// Due to legacy reasons, `time` takes precedent over `chrono` if both are enabled.
// Any crates added later should be _lower_ priority than `chrono` to avoid breakages.
// ----------------------------------------
#[cfg(feature = "time")]
if matches!(preferred_crates.date_time, DateTimeCrate::Time | DateTimeCrate::Inferred) {
$(
$(#[$meta])?
_ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => Some($crate::select_input_type!($ty $(, $input)?)),
if <$time_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
if cfg!(feature = "chrono") {
return Err($crate::type_checking::Error::AmbiguousDateTimeType {
fallback: $crate::select_input_type!($time_ty $(, $time_input)?),
});
}
return Ok($crate::select_input_type!($time_ty $(, $time_input)?));
}
)*
$(
$(#[$meta])?
_ if <$ty as sqlx_core::types::Type<$database>>::compatible(info) => Some($crate::select_input_type!($ty $(, $input)?)),
if <$time_ty as sqlx_core::types::Type<$database>>::compatible(info) {
if cfg!(feature = "chrono") {
return Err($crate::type_checking::Error::AmbiguousDateTimeType {
fallback: $crate::select_input_type!($time_ty $(, $time_input)?),
});
}
return Ok($crate::select_input_type!($time_ty $(, $time_input)?));
}
)*
_ => None
}
#[cfg(not(feature = "time"))]
if preferred_crates.date_time == DateTimeCrate::Time {
return Err(Error::DateTimeCrateFeatureNotEnabled);
}
#[cfg(feature = "chrono")]
if matches!(preferred_crates.date_time, DateTimeCrate::Chrono | DateTimeCrate::Inferred) {
$(
if <$chrono_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
return Ok($crate::select_input_type!($chrono_ty $(, $chrono_input)?));
}
)*
$(
if <$chrono_ty as sqlx_core::types::Type<$database>>::compatible(info) {
return Ok($crate::select_input_type!($chrono_ty $(, $chrono_input)?));
}
)*
}
#[cfg(not(feature = "chrono"))]
if preferred_crates.date_time == DateTimeCrate::Chrono {
return Err(Error::DateTimeCrateFeatureNotEnabled);
}
// Check `macros.preferred-crates.numeric`
//
// Due to legacy reasons, `bigdecimal` takes precedent over `rust_decimal` if
// both are enabled.
// ----------------------------------------
#[cfg(feature = "bigdecimal")]
if matches!(preferred_crates.numeric, NumericCrate::BigDecimal | NumericCrate::Inferred) {
$(
if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
if cfg!(feature = "rust_decimal") {
return Err($crate::type_checking::Error::AmbiguousNumericType {
fallback: $crate::select_input_type!($bigdecimal_ty $(, $bigdecimal_input)?),
});
}
return Ok($crate::select_input_type!($bigdecimal_ty $(, $bigdecimal_input)?));
}
)*
$(
if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::compatible(info) {
if cfg!(feature = "rust_decimal") {
return Err($crate::type_checking::Error::AmbiguousNumericType {
fallback: $crate::select_input_type!($bigdecimal_ty $(, $bigdecimal_input)?),
});
}
return Ok($crate::select_input_type!($bigdecimal_ty $(, $bigdecimal_input)?));
}
)*
}
#[cfg(not(feature = "bigdecimal"))]
if preferred_crates.numeric == NumericCrate::BigDecimal {
return Err(Error::NumericCrateFeatureNotEnabled);
}
#[cfg(feature = "rust_decimal")]
if matches!(preferred_crates.numeric, NumericCrate::RustDecimal | NumericCrate::Inferred) {
$(
if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?));
}
)*
$(
if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::compatible(info) {
return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?));
}
)*
}
#[cfg(not(feature = "rust_decimal"))]
if preferred_crates.numeric == NumericCrate::RustDecimal {
return Err(Error::NumericCrateFeatureNotEnabled);
}
Err(Error::NoMappingFound)
}
fn return_type_for_id(info: &Self::TypeInfo) -> Option<&'static str> {
match () {
fn return_type_for_id(
info: &Self::TypeInfo,
preferred_crates: &$crate::config::macros::PreferredCrates,
) -> Result<&'static str, $crate::type_checking::Error> {
use $crate::config::macros::{DateTimeCrate, NumericCrate};
use $crate::type_checking::Error;
// Check non-special types
// ---------------------
$(
$(#[$meta])?
if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info {
return Ok(stringify!($ty));
}
)*
$(
$(#[$meta])?
if <$ty as sqlx_core::types::Type<$database>>::compatible(info) {
return Ok(stringify!($ty));
}
)*
// Check `macros.preferred-crates.date-time`
//
// Due to legacy reasons, `time` takes precedent over `chrono` if both are enabled.
// Any crates added later should be _lower_ priority than `chrono` to avoid breakages.
// ----------------------------------------
#[cfg(feature = "time")]
if matches!(preferred_crates.date_time, DateTimeCrate::Time | DateTimeCrate::Inferred) {
$(
$(#[$meta])?
_ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => Some(stringify!($ty)),
if <$time_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
if cfg!(feature = "chrono") {
return Err($crate::type_checking::Error::AmbiguousDateTimeType {
fallback: stringify!($time_ty),
});
}
return Ok(stringify!($time_ty));
}
)*
$(
$(#[$meta])?
_ if <$ty as sqlx_core::types::Type<$database>>::compatible(info) => Some(stringify!($ty)),
if <$time_ty as sqlx_core::types::Type<$database>>::compatible(info) {
if cfg!(feature = "chrono") {
return Err($crate::type_checking::Error::AmbiguousDateTimeType {
fallback: stringify!($time_ty),
});
}
return Ok(stringify!($time_ty));
}
)*
_ => None
}
#[cfg(not(feature = "time"))]
if preferred_crates.date_time == DateTimeCrate::Time {
return Err(Error::DateTimeCrateFeatureNotEnabled);
}
#[cfg(feature = "chrono")]
if matches!(preferred_crates.date_time, DateTimeCrate::Chrono | DateTimeCrate::Inferred) {
$(
if <$chrono_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
return Ok(stringify!($chrono_ty));
}
)*
$(
if <$chrono_ty as sqlx_core::types::Type<$database>>::compatible(info) {
return Ok(stringify!($chrono_ty));
}
)*
}
#[cfg(not(feature = "chrono"))]
if preferred_crates.date_time == DateTimeCrate::Chrono {
return Err(Error::DateTimeCrateFeatureNotEnabled);
}
// Check `macros.preferred-crates.numeric`
//
// Due to legacy reasons, `bigdecimal` takes precedent over `rust_decimal` if
// both are enabled.
// ----------------------------------------
#[cfg(feature = "bigdecimal")]
if matches!(preferred_crates.numeric, NumericCrate::BigDecimal | NumericCrate::Inferred) {
$(
if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
if cfg!(feature = "rust_decimal") {
return Err($crate::type_checking::Error::AmbiguousNumericType {
fallback: stringify!($bigdecimal_ty),
});
}
return Ok(stringify!($bigdecimal_ty));
}
)*
$(
if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::compatible(info) {
if cfg!(feature = "rust_decimal") {
return Err($crate::type_checking::Error::AmbiguousNumericType {
fallback: stringify!($bigdecimal_ty),
});
}
return Ok(stringify!($bigdecimal_ty));
}
)*
}
#[cfg(not(feature = "bigdecimal"))]
if preferred_crates.numeric == NumericCrate::BigDecimal {
return Err(Error::NumericCrateFeatureNotEnabled);
}
#[cfg(feature = "rust_decimal")]
if matches!(preferred_crates.numeric, NumericCrate::RustDecimal | NumericCrate::Inferred) {
$(
if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::type_info() == *info {
return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?));
}
)*
$(
if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::compatible(info) {
return Ok($crate::select_input_type!($rust_decimal_ty $(, $rust_decimal_input)?));
}
)*
}
#[cfg(not(feature = "rust_decimal"))]
if preferred_crates.numeric == NumericCrate::RustDecimal {
return Err(Error::NumericCrateFeatureNotEnabled);
}
Err(Error::NoMappingFound)
}
fn get_feature_gate($ty_info: &Self::TypeInfo) -> Option<&'static str> {
@ -181,13 +472,50 @@ macro_rules! impl_type_checking {
let info = value.type_info();
match () {
#[cfg(feature = "time")]
{
$(
$(#[$meta])?
_ if <$ty as sqlx_core::types::Type<$database>>::compatible(&info) => $crate::type_checking::FmtValue::debug::<$ty>(value),
if <$time_ty as sqlx_core::types::Type<$database>>::compatible(&info) {
return $crate::type_checking::FmtValue::debug::<$time_ty>(value);
}
)*
_ => $crate::type_checking::FmtValue::unknown(value),
}
#[cfg(feature = "chrono")]
{
$(
if <$chrono_ty as sqlx_core::types::Type<$database>>::compatible(&info) {
return $crate::type_checking::FmtValue::debug::<$chrono_ty>(value);
}
)*
}
#[cfg(feature = "bigdecimal")]
{
$(
if <$bigdecimal_ty as sqlx_core::types::Type<$database>>::compatible(&info) {
return $crate::type_checking::FmtValue::debug::<$bigdecimal_ty>(value);
}
)*
}
#[cfg(feature = "rust_decimal")]
{
$(
if <$rust_decimal_ty as sqlx_core::types::Type<$database>>::compatible(&info) {
return $crate::type_checking::FmtValue::debug::<$rust_decimal_ty>(value);
}
)*
}
$(
$(#[$meta])?
if <$ty as sqlx_core::types::Type<$database>>::compatible(&info) {
return $crate::type_checking::FmtValue::debug::<$ty>(value);
}
)*
$crate::type_checking::FmtValue::unknown(value)
}
}
};

View File

@ -27,6 +27,8 @@ derive = []
macros = []
migrate = ["sqlx-core/migrate"]
sqlx-toml = ["sqlx-core/sqlx-toml"]
# database
mysql = ["sqlx-mysql"]
postgres = ["sqlx-postgres"]

View File

@ -3,11 +3,13 @@ extern crate proc_macro;
use std::path::{Path, PathBuf};
use proc_macro2::TokenStream;
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens, TokenStreamExt};
use sqlx_core::config::Config;
use sqlx_core::migrate::{Migration, MigrationType};
use syn::LitStr;
use sqlx_core::migrate::{Migration, MigrationType};
pub const DEFAULT_PATH: &str = "./migrations";
pub struct QuoteMigrationType(MigrationType);
@ -81,20 +83,26 @@ impl ToTokens for QuoteMigration {
}
}
pub fn expand_migrator_from_lit_dir(dir: LitStr) -> crate::Result<TokenStream> {
expand_migrator_from_dir(&dir.value(), dir.span())
/// Resolve the migrations directory from `config`, falling back to [`DEFAULT_PATH`]
/// when `migrate.migrations-dir` is not set in `sqlx.toml`.
pub fn default_path(config: &Config) -> &str {
    match config.migrate.migrations_dir.as_deref() {
        Some(dir) => dir,
        None => DEFAULT_PATH,
    }
}
pub(crate) fn expand_migrator_from_dir(
dir: &str,
err_span: proc_macro2::Span,
) -> crate::Result<TokenStream> {
let path = crate::common::resolve_path(dir, err_span)?;
pub fn expand(path_arg: Option<LitStr>) -> crate::Result<TokenStream> {
let config = Config::try_from_crate_or_default()?;
expand_migrator(&path)
let path = match path_arg {
Some(path_arg) => crate::common::resolve_path(path_arg.value(), path_arg.span())?,
None => { crate::common::resolve_path(default_path(&config), Span::call_site()) }?,
};
expand_with_path(&config, &path)
}
pub(crate) fn expand_migrator(path: &Path) -> crate::Result<TokenStream> {
pub fn expand_with_path(config: &Config, path: &Path) -> crate::Result<TokenStream> {
let path = path.canonicalize().map_err(|e| {
format!(
"error canonicalizing migration directory {}: {e}",
@ -102,11 +110,19 @@ pub(crate) fn expand_migrator(path: &Path) -> crate::Result<TokenStream> {
)
})?;
let resolve_config = config.migrate.to_resolve_config();
// Use the same code path to resolve migrations at compile time and runtime.
let migrations = sqlx_core::migrate::resolve_blocking(&path)?
let migrations = sqlx_core::migrate::resolve_blocking_with_config(&path, &resolve_config)?
.into_iter()
.map(|(migration, path)| QuoteMigration { migration, path });
let table_name = config.migrate.table_name();
let create_schemas = config.migrate.create_schemas.iter().map(|schema_name| {
quote! { ::std::borrow::Cow::Borrowed(#schema_name) }
});
#[cfg(any(sqlx_macros_unstable, procmacro2_semver_exempt))]
{
let path = path.to_str().ok_or_else(|| {
@ -124,6 +140,8 @@ pub(crate) fn expand_migrator(path: &Path) -> crate::Result<TokenStream> {
migrations: ::std::borrow::Cow::Borrowed(&[
#(#migrations),*
]),
create_schemas: ::std::borrow::Cow::Borrowed(&[#(#create_schemas),*]),
table_name: ::std::borrow::Cow::Borrowed(#table_name),
..::sqlx::migrate::Migrator::DEFAULT
}
})

View File

@ -1,9 +1,12 @@
use crate::database::DatabaseExt;
use crate::query::QueryMacroInput;
use crate::query::{QueryMacroInput, Warnings};
use either::Either;
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned};
use sqlx_core::config::Config;
use sqlx_core::describe::Describe;
use sqlx_core::type_checking;
use sqlx_core::type_info::TypeInfo;
use syn::spanned::Spanned;
use syn::{Expr, ExprCast, ExprGroup, Type};
@ -11,6 +14,8 @@ use syn::{Expr, ExprCast, ExprGroup, Type};
/// and binds them to `DB::Arguments` with the ident `query_args`.
pub fn quote_args<DB: DatabaseExt>(
input: &QueryMacroInput,
config: &Config,
warnings: &mut Warnings,
info: &Describe<DB>,
) -> crate::Result<TokenStream> {
let db_path = DB::db_path();
@ -55,27 +60,7 @@ pub fn quote_args<DB: DatabaseExt>(
return Ok(quote!());
}
let param_ty =
DB::param_type_for_id(param_ty)
.ok_or_else(|| {
if let Some(feature_gate) = DB::get_feature_gate(param_ty) {
format!(
"optional sqlx feature `{}` required for type {} of param #{}",
feature_gate,
param_ty,
i + 1,
)
} else {
format!(
"no built in mapping found for type {} for param #{}; \
a type override may be required, see documentation for details",
param_ty,
i + 1
)
}
})?
.parse::<TokenStream>()
.map_err(|_| format!("Rust type mapping for {param_ty} not parsable"))?;
let param_ty = get_param_type::<DB>(param_ty, config, warnings, i)?;
Ok(quote_spanned!(expr.span() =>
// this shouldn't actually run
@ -120,6 +105,77 @@ pub fn quote_args<DB: DatabaseExt>(
})
}
fn get_param_type<DB: DatabaseExt>(
param_ty: &DB::TypeInfo,
config: &Config,
warnings: &mut Warnings,
i: usize,
) -> crate::Result<TokenStream> {
if let Some(type_override) = config.macros.type_override(param_ty.name()) {
return Ok(type_override.parse()?);
}
let err = match DB::param_type_for_id(param_ty, &config.macros.preferred_crates) {
Ok(t) => return Ok(t.parse()?),
Err(e) => e,
};
let param_num = i + 1;
let message = match err {
type_checking::Error::NoMappingFound => {
if let Some(feature_gate) = DB::get_feature_gate(param_ty) {
format!(
"optional sqlx feature `{feature_gate}` required for type {param_ty} of param #{param_num}",
)
} else {
format!(
"no built-in mapping for type {param_ty} of param #{param_num}; \
a type override may be required, see documentation for details"
)
}
}
type_checking::Error::DateTimeCrateFeatureNotEnabled => {
let feature_gate = config
.macros
.preferred_crates
.date_time
.crate_name()
.expect("BUG: got feature-not-enabled error for DateTimeCrate::Inferred");
format!(
"SQLx feature `{feature_gate}` required for type {param_ty} of param #{param_num} \
(configured by `macros.preferred-crates.date-time` in sqlx.toml)",
)
}
type_checking::Error::NumericCrateFeatureNotEnabled => {
let feature_gate = config
.macros
.preferred_crates
.numeric
.crate_name()
.expect("BUG: got feature-not-enabled error for NumericCrate::Inferred");
format!(
"SQLx feature `{feature_gate}` required for type {param_ty} of param #{param_num} \
(configured by `macros.preferred-crates.numeric` in sqlx.toml)",
)
}
type_checking::Error::AmbiguousDateTimeType { fallback } => {
warnings.ambiguous_datetime = true;
return Ok(fallback.parse()?);
}
type_checking::Error::AmbiguousNumericType { fallback } => {
warnings.ambiguous_numeric = true;
return Ok(fallback.parse()?);
}
};
Err(message.into())
}
fn get_type_override(expr: &Expr) -> Option<&Type> {
match expr {
Expr::Group(group) => get_type_override(&group.expr),

View File

@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{hash_map, HashMap};
use std::path::{Path, PathBuf};
use std::sync::{Arc, LazyLock, Mutex};
use std::{fs, io};
@ -15,6 +15,7 @@ use crate::database::DatabaseExt;
use crate::query::data::{hash_string, DynQueryData, QueryData};
use crate::query::input::RecordType;
use either::Either;
use sqlx_core::config::Config;
use url::Url;
mod args;
@ -26,7 +27,7 @@ mod output;
pub struct QueryDriver {
db_name: &'static str,
url_schemes: &'static [&'static str],
expand: fn(QueryMacroInput, QueryDataSource) -> crate::Result<TokenStream>,
expand: fn(&Config, QueryMacroInput, QueryDataSource) -> crate::Result<TokenStream>,
}
impl QueryDriver {
@ -74,6 +75,7 @@ struct Metadata {
offline: bool,
database_url: Option<String>,
offline_dir: Option<String>,
config: Config,
workspace_root: Arc<Mutex<Option<PathBuf>>>,
}
@ -111,7 +113,7 @@ static METADATA: LazyLock<Mutex<HashMap<String, Metadata>>> = LazyLock::new(Defa
// If we are in a workspace, lookup `workspace_root` since `CARGO_MANIFEST_DIR` won't
// reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946
fn init_metadata(manifest_dir: &String) -> Metadata {
fn init_metadata(manifest_dir: &String) -> crate::Result<Metadata> {
let manifest_dir: PathBuf = manifest_dir.into();
let (database_url, offline, offline_dir) = load_dot_env(&manifest_dir);
@ -122,15 +124,18 @@ fn init_metadata(manifest_dir: &String) -> Metadata {
.map(|s| s.eq_ignore_ascii_case("true") || s == "1")
.unwrap_or(false);
let database_url = env("DATABASE_URL").ok().or(database_url);
let config = Config::try_from_crate_or_default()?;
Metadata {
let database_url = env(config.common.database_url_var()).ok().or(database_url);
Ok(Metadata {
manifest_dir,
offline,
database_url,
offline_dir,
config,
workspace_root: Arc::new(Mutex::new(None)),
}
})
}
pub fn expand_input<'a>(
@ -148,9 +153,13 @@ pub fn expand_input<'a>(
guard
});
let metadata = metadata_lock
.entry(manifest_dir)
.or_insert_with_key(init_metadata);
let metadata = match metadata_lock.entry(manifest_dir) {
hash_map::Entry::Occupied(occupied) => occupied.into_mut(),
hash_map::Entry::Vacant(vacant) => {
let metadata = init_metadata(vacant.key())?;
vacant.insert(metadata)
}
};
let data_source = match &metadata {
Metadata {
@ -188,7 +197,7 @@ pub fn expand_input<'a>(
for driver in drivers {
if data_source.matches_driver(driver) {
return (driver.expand)(input, data_source);
return (driver.expand)(&metadata.config, input, data_source);
}
}
@ -210,6 +219,7 @@ pub fn expand_input<'a>(
}
fn expand_with<DB: DatabaseExt>(
config: &Config,
input: QueryMacroInput,
data_source: QueryDataSource,
) -> crate::Result<TokenStream>
@ -224,7 +234,7 @@ where
}
};
expand_with_data(input, query_data, offline)
expand_with_data(config, input, query_data, offline)
}
// marker trait for `Describe` that lets us conditionally require it to be `Serialize + Deserialize`
@ -235,7 +245,14 @@ impl<DB: Database> DescribeExt for Describe<DB> where
{
}
/// Flags recording ambiguous type mappings encountered while expanding a query;
/// each set flag causes a corresponding warning call to be emitted into the
/// generated code.
#[derive(Default)]
struct Warnings {
    // Set when a date-time mapping had to fall back because multiple
    // date-time crates' features are enabled and none was configured.
    ambiguous_datetime: bool,
    // Set when a numeric mapping had to fall back because multiple
    // numeric crates' features are enabled and none was configured.
    ambiguous_numeric: bool,
}
fn expand_with_data<DB: DatabaseExt>(
config: &Config,
input: QueryMacroInput,
data: QueryData<DB>,
offline: bool,
@ -259,7 +276,9 @@ where
}
}
let args_tokens = args::quote_args(&input, &data.describe)?;
let mut warnings = Warnings::default();
let args_tokens = args::quote_args(&input, config, &mut warnings, &data.describe)?;
let query_args = format_ident!("query_args");
@ -278,7 +297,7 @@ where
} else {
match input.record_type {
RecordType::Generated => {
let columns = output::columns_to_rust::<DB>(&data.describe)?;
let columns = output::columns_to_rust::<DB>(&data.describe, config, &mut warnings)?;
let record_name: Type = syn::parse_str("Record").unwrap();
@ -314,22 +333,44 @@ where
record_tokens
}
RecordType::Given(ref out_ty) => {
let columns = output::columns_to_rust::<DB>(&data.describe)?;
let columns = output::columns_to_rust::<DB>(&data.describe, config, &mut warnings)?;
output::quote_query_as::<DB>(&input, out_ty, &query_args, &columns)
}
RecordType::Scalar => {
output::quote_query_scalar::<DB>(&input, &query_args, &data.describe)?
}
RecordType::Scalar => output::quote_query_scalar::<DB>(
&input,
config,
&mut warnings,
&query_args,
&data.describe,
)?,
}
};
let mut warnings_out = TokenStream::new();
if warnings.ambiguous_datetime {
// Warns if the date-time crate is inferred but both `chrono` and `time` are enabled
warnings_out.extend(quote! {
::sqlx::warn_on_ambiguous_inferred_date_time_crate();
});
}
if warnings.ambiguous_numeric {
// Warns if the numeric crate is inferred but both `bigdecimal` and `rust_decimal` are enabled
warnings_out.extend(quote! {
::sqlx::warn_on_ambiguous_inferred_numeric_crate();
});
}
let ret_tokens = quote! {
{
#[allow(clippy::all)]
{
use ::sqlx::Arguments as _;
#warnings_out
#args_tokens
#output

Some files were not shown because too many files have changed in this diff Show More