diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 33d3b295d..f6470c366 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -98,7 +98,7 @@ jobs: strategy: matrix: - os: [ubuntu-latest, windows-latest, macOS-latest] + os: [ubuntu-latest, windows-latest]#, macOS-latest] include: - os: ubuntu-latest target: x86_64-unknown-linux-musl @@ -107,9 +107,10 @@ jobs: - os: windows-latest target: x86_64-pc-windows-msvc bin: target/debug/cargo-sqlx.exe - - os: macOS-latest - target: x86_64-apple-darwin - bin: target/debug/cargo-sqlx + # FIXME: macOS build fails because of missing pin-project-internal +# - os: macOS-latest +# target: x86_64-apple-darwin +# bin: target/debug/cargo-sqlx steps: - uses: actions/checkout@v2 diff --git a/Cargo.lock b/Cargo.lock index ff85040eb..e1027e792 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -42,9 +42,13 @@ dependencies = [ [[package]] name = "ahash" -version = "0.3.8" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8fd72866655d1904d6b0997d0b07ba561047d070fbe29de039031c641b61217" +checksum = "f166b31431056f04477a03e281aed5655a3fb751c67cd82f70761fe062896d37" +dependencies = [ + "getrandom 0.2.0", + "lazy_static", +] [[package]] name = "aho-corasick" @@ -202,12 +206,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3441f0f7b02788e948e47f457ca01f1d7e6d92c693bc132c22b087d3141c03ff" [[package]] -name = "bigdecimal" -version = "0.1.2" +name = "base64" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1374191e2dd25f9ae02e3aa95041ed5d747fc77b3c102b49fe2dd9a8117a6244" +checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" + +[[package]] +name = "bigdecimal" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc403c26e6b03005522e6e8053384c4e881dfe5b2bf041c0c2c49be33d64a539" dependencies = [ - "num-bigint", + "num-bigint 0.3.0", "num-integer", "num-traits", ] @@ -343,16 +353,16 @@ dependencies = [ "atty", "bitflags", "strsim 0.8.0", - "textwrap", + "textwrap 0.11.0", "unicode-width", "vec_map", ] [[package]] name = "clap" -version = "3.0.0-beta.1" +version = "3.0.0-beta.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "860643c53f980f0d38a5e25dfab6c3c93b2cb3aa1fe192643d17a293c6c41936" +checksum = "4bd1061998a501ee7d4b6d449020df3266ca3124b941ec56cf2005c3779ca142" dependencies = [ "atty", "bitflags", @@ -362,19 +372,19 @@ dependencies = [ "os_str_bytes", "strsim 0.10.0", "termcolor", - "textwrap", + "textwrap 0.12.1", "unicode-width", "vec_map", ] [[package]] name = "clap_derive" -version = "3.0.0-beta.1" +version = "3.0.0-beta.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb51c9e75b94452505acd21d929323f5a5c6c4735a852adbd39ef5fb1b014f30" +checksum = "370f715b81112975b1b69db93e0b56ea4cd4e5002ac43b2da8474106a54096a1" dependencies = [ "heck", - "proc-macro-error 0.4.12", + "proc-macro-error", "proc-macro2", "quote", "syn", @@ -415,6 +425,22 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "console" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a50aab2529019abfabfa93f1e6c41ef392f91fbf179b347a7e96abb524884a08" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "regex", + "terminal_size", + "unicode-width", + "winapi 0.3.9", + "winapi-util", +] + [[package]] name = "copyless" version = "0.1.5" @@ -548,9 +574,9 @@ dependencies = [ [[package]] name = "crypto-mac" -version = "0.8.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" +checksum = "58bcd97a54c7ca5ce2f6eb16f6bede5b0ab5f0055fedc17d2f0b4466e21671ca" dependencies = [ "generic-array", "subtle", @@ -591,13 +617,14 @@ dependencies = [ [[package]] name = "dialoguer" -version = "0.6.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4aa86af7b19b40ef9cbef761ed411a49f0afa06b7b6dcd3dfe2f96a3c546138" +checksum = "70f807b2943dc90f9747497d9d65d7e92472149be0b88bf4ce1201b4ac979c26" dependencies = [ - "console", + "console 0.13.0", "lazy_static", "tempfile", + "zeroize 0.9.3", ] [[package]] @@ -848,6 +875,17 @@ dependencies = [ "wasi", ] +[[package]] +name = "getrandom" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee8025cf36f917e6a52cce185b7c7177689b838b7ec138364e50cc2277a56cf4" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "glob" version = "0.3.0" @@ -875,13 +913,9 @@ checksum = "d36fab90f82edc3c747f9d438e06cf0a491055896f2a279638bb5beed6c40177" [[package]] name = "hashbrown" -version = "0.8.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34f595585f103464d8d2f6e9864682d74c1601fed5e07d62b1c9058dba8246fb" -dependencies = [ - "ahash", - "autocfg 1.0.0", -] +checksum = "00d63df3d41950fb462ed38308eea019113ad1508da725bbedcd0fa5a85ef5f7" [[package]] name = "heck" @@ -909,9 +943,9 @@ checksum = "644f9158b2f133fd50f5fb3242878846d9eb792e445c893805ff0e3824006e35" [[package]] name = "hmac" -version = "0.8.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "126888268dcc288495a26bf004b38c5fdbb31682f992c84ceb046a1f0fe38840" +checksum = "deae6d9dbb35ec2c502d62b8f7b1c000a0822c3b0794ba36b3149c0a1c840dff" dependencies = [ "crypto-mac", "digest", @@ -939,9 +973,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b88cd59ee5f71fea89a62248fc8f387d44400cefe05ef548466d61ced9029a7" +checksum = "55e2e4c765aa53a0424761bf9f41aa7a6ac1efa87238f59560640e27fca028f2" dependencies = [ "autocfg 1.0.0", "hashbrown", @@ -1034,9 +1068,9 @@ checksum = "c7d73b3f436185384286bd8098d17ec07c9a7d2388a6599f824d8502b529702a" [[package]] name = "libsqlite3-sys" -version = "0.18.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e704a02bcaecd4a08b93a23f6be59d0bd79cd161e0963e9499165a0a35df7bd" +checksum = "64d31059f22935e6c31830db5249ba2b7ecd54fd73a9909286f0a67aa55c2fbd" dependencies = [ "cc", "pkg-config", @@ -1224,6 +1258,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-bigint" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f3fc75e3697059fb1bc465e3d8cca6cf92f56854f201158b3f9c77d5a3cfa0" +dependencies = [ + "autocfg 1.0.0", + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint-dig" version = "0.6.0" @@ -1240,7 +1285,7 @@ dependencies = [ "rand", "serde", "smallvec", - "zeroize", + "zeroize 1.1.0", ] [[package]] @@ -1384,22 +1429,9 @@ dependencies = [ [[package]] name = "paste" -version = "0.1.18" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880" -dependencies = [ - "paste-impl", - "proc-macro-hack", -] - -[[package]] -name = "paste-impl" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6" -dependencies = [ - "proc-macro-hack", -] +checksum = "0520af26d4cf99643dbbe093a61507922b57232d9978d8491fdc8f7b44573c8c" [[package]] name = "paw" @@ -1434,7 +1466,7 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59698ea79df9bf77104aefd39cc3ec990cb9693fb59c3b0a70ddf2646fdffb4b" dependencies = [ - "base64", + "base64 0.12.3", "once_cell", "regex", ] @@ -1501,45 +1533,19 @@ version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "237a5ed80e274dbc66f86bd59c1e25edc039660be53194b5fe0a482e0f2612ea" -[[package]] -name = "proc-macro-error" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7" -dependencies = [ - "proc-macro-error-attr 0.4.12", - "proc-macro2", - "quote", - "syn", - "version_check", -] - [[package]] name = "proc-macro-error" version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc175e9777c3116627248584e8f8b3e2987405cabe1c0adf7d1dd28f09dc7880" dependencies = [ - "proc-macro-error-attr 1.0.3", + "proc-macro-error-attr", "proc-macro2", "quote", "syn", "version_check", ] -[[package]] -name = "proc-macro-error-attr" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a5b4b77fdb63c1eca72173d68d24501c54ab1269409f6b672c85deb18af69de" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "syn-mid", - "version_check", -] - [[package]] name = "proc-macro-error-attr" version = "1.0.3" @@ -1595,7 +1601,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" dependencies = [ - "getrandom", + "getrandom 0.1.14", "libc", "rand_chacha", "rand_core", @@ -1618,7 +1624,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" dependencies = [ - "getrandom", + "getrandom 0.1.14", ] [[package]] @@ -1632,9 +1638,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.3.1" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f02856753d04e03e26929f820d0a0a337ebe71f849801eea335d464b349080" +checksum = "dcf6960dc9a5b4ee8d3e4c5787b4a112a8818e0290a42ff664ad60692fdf2032" dependencies = [ "autocfg 1.0.0", "crossbeam-deque", @@ -1644,12 +1650,12 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.7.1" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e92e15d89083484e11353891f1af602cc661426deb9564c298b270c726973280" +checksum = "e8c4fec834fb6e6d2dd5eece3c7b432a52f0ba887cf40e595190c4107edc08bf" dependencies = [ + "crossbeam-channel", "crossbeam-deque", - "crossbeam-queue", "crossbeam-utils", "lazy_static", "num_cpus", @@ -1697,6 +1703,18 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "remove_dir_all" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f43c8c593a759eb8eae137a1ad7b9ed881453a942ac3ce58b87b6e5c2364779" +dependencies = [ + "log", + "num_cpus", + "rayon", + "winapi 0.3.9", +] + [[package]] name = "rsa" version = "0.3.0" @@ -1716,7 +1734,7 @@ dependencies = [ "simple_asn1", "subtle", "thiserror", - "zeroize", + "zeroize 1.1.0", ] [[package]] @@ -1911,7 +1929,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "692ca13de57ce0613a363c8c2f1de925adebc81b04c923ac60c5488bb44abe4b" dependencies = [ "chrono", - "num-bigint", + "num-bigint 0.2.6", "num-traits", ] @@ -2018,13 +2036,14 @@ dependencies = [ "async-trait", "cargo_metadata", "chrono", - "clap 3.0.0-beta.1", - "console", + "clap 3.0.0-beta.2", + "console 0.11.3", "dialoguer", "dotenv", "futures", "glob", "openssl", + "remove_dir_all 0.6.0", "serde", "serde_json", "sqlx", @@ -2036,8 +2055,9 @@ dependencies = [ name = "sqlx-core" version = "0.4.0-beta.1" dependencies = [ + "ahash", "atoi", - "base64", + "base64 0.13.0", "bigdecimal", "bit-vec", "bitflags", @@ -2055,7 +2075,6 @@ dependencies = [ "futures-core", "futures-util", "generic-array", - "hashbrown", "hex", "hmac", "ipnetwork", @@ -2066,7 +2085,7 @@ dependencies = [ "lru-cache", "md-5", "memchr", - "num-bigint", + "num-bigint 0.3.0", "once_cell", "parking_lot", "percent-encoding", @@ -2110,6 +2129,19 @@ dependencies = [ "sqlx", ] +[[package]] +name = "sqlx-example-postgres-todos" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-std", + "dotenv", + "futures", + "paw", + "sqlx", + "structopt", +] + [[package]] name = "sqlx-example-sqlite-todos" version = "0.1.0" @@ -2266,7 +2298,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "510413f9de616762a4fbeab62509bf15c729603b72d7cd71280fbca431b1c118" dependencies = [ "heck", - "proc-macro-error 1.0.3", + "proc-macro-error", "proc-macro2", "quote", "syn", @@ -2322,7 +2354,7 @@ dependencies = [ "libc", "rand", "redox_syscall", - "remove_dir_all", + "remove_dir_all 0.5.3", "winapi 0.3.9", ] @@ -2363,6 +2395,15 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "textwrap" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "203008d98caf094106cfaba70acfed15e18ed3ddb7d94e49baec153a2b462789" +dependencies = [ + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.20" @@ -2777,6 +2818,12 @@ dependencies = [ "winapi-build", ] +[[package]] +name = "zeroize" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45af6a010d13e4cf5b54c94ba5a2b2eba5596b9e46bf5875612d332a1f2b3f86" + [[package]] name = "zeroize" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index 9969fd913..11729527f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "sqlx-bench", "examples/mysql/todos", "examples/postgres/listen", + "examples/postgres/todos", "examples/sqlite/todos", ] @@ -88,7 +89,7 @@ dotenv = "0.15.0" trybuild = "1.0.24" sqlx-rt = { path = "./sqlx-rt" } sqlx-test = { path = "./sqlx-test" } -paste = "0.1.16" +paste = "1.0.1" serde = { version = "1.0.111", features = [ "derive" ] } serde_json = "1.0.53" url = "2.1.1" diff --git a/README.md b/README.md index 6d7585272..72b7ec24a 100644 --- a/README.md +++ b/README.md @@ -124,11 +124,21 @@ sqlx = "0.4.0-beta.1" sqlx = { version = "0.4.0-beta.1", default-features = false, features = [ "runtime-tokio", "macros" ] } ``` +**actix** + +```toml +# Cargo.toml +[dependencies] +sqlx = { version = "0.4.0-beta.1", default-features = false, features = [ "runtime-actix", "macros" ] } +``` + #### Cargo Feature Flags * `runtime-async-std` (on by default): Use the `async-std` runtime. - * `runtime-tokio`: Use the `tokio` runtime. Mutually exclusive with the `runtime-async-std` feature. + * `runtime-tokio`: Use the `tokio` runtime. Mutually exclusive to all other runtimes. + + * `runtime-actix`: Use the `actix_rt` runtime. Mutually exclusive to all other runtimes. * `postgres`: Add support for the Postgres database server. @@ -140,6 +150,10 @@ sqlx = { version = "0.4.0-beta.1", default-features = false, features = [ "runti * `any`: Add support for the `Any` database driver, which can proxy to a database driver at runtime. + * `macros`: Add support for the `query*!` macros, which allow compile-time checked queries. + + * `migrate`: Add support for the migration management and `migrate!` macro, which allow compile-time embedded migrations. + * `uuid`: Add support for UUID (in Postgres). * `chrono`: Add support for date and time types from `chrono`. @@ -333,7 +347,7 @@ WHERE organization = ? ", organization ) - .fetch_all() // -> Vec + .fetch_all(&pool) // -> Vec .await?; // countries[0].country diff --git a/sqlx-cli/Cargo.toml b/sqlx-cli/Cargo.toml index b94ff73d0..a6e248fda 100644 --- a/sqlx-cli/Cargo.toml +++ b/sqlx-cli/Cargo.toml @@ -29,18 +29,20 @@ dotenv = "0.15" tokio = { version = "0.2", features = ["macros"] } sqlx = { version = "0.4.0-beta.1", path = "..", default-features = false, features = [ "runtime-async-std", "migrate", "any", "offline" ] } futures = "0.3" -clap = "3.0.0-beta.1" +clap = "=3.0.0-beta.2" chrono = "0.4" anyhow = "1.0" url = { version = "2.1.1", default-features = false } async-trait = "0.1.30" console = "0.11.3" -dialoguer = "0.6.2" +dialoguer = "0.7.1" serde_json = { version = "1.0.53", features = ["preserve_order"] } serde = "1.0.110" glob = "0.3.0" cargo_metadata = "0.10.0" openssl = { version = "0.10.30", optional = true } +# workaround for https://github.com/rust-lang/rust/issues/29497 +remove_dir_all = "0.6.0" [features] default = [ "postgres", "sqlite", "mysql" ] diff --git a/sqlx-cli/README.md b/sqlx-cli/README.md index 9276451b3..80e7f81e9 100644 --- a/sqlx-cli/README.md +++ b/sqlx-cli/README.md @@ -80,4 +80,4 @@ database schema and queries in the project. Intended for use in Continuous Integ To make sure an accidentally-present `DATABASE_URL` environment variable or `.env` file does not result in `cargo build` (trying to) access the database, you can set the `SQLX_OFFLINE` environment -variable. +variable to `true`. diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index 725584817..901c44c63 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -1,3 +1,4 @@ +use crate::migrate; use console::style; use dialoguer::Confirm; use sqlx::any::Any; @@ -18,6 +19,7 @@ pub async fn drop(uri: &str, confirm: bool) -> anyhow::Result<()> { "\nAre you sure you want to drop the database at {}?", style(uri).cyan() )) + .wait_for_newline(true) .default(false) .interact()? { @@ -30,3 +32,13 @@ pub async fn drop(uri: &str, confirm: bool) -> anyhow::Result<()> { Ok(()) } + +pub async fn reset(uri: &str, confirm: bool) -> anyhow::Result<()> { + drop(uri, confirm).await?; + setup(uri).await +} + +pub async fn setup(uri: &str) -> anyhow::Result<()> { + create(uri).await?; + migrate::run(uri).await +} diff --git a/sqlx-cli/src/lib.rs b/sqlx-cli/src/lib.rs index 1f1ca0598..3da90e033 100644 --- a/sqlx-cli/src/lib.rs +++ b/sqlx-cli/src/lib.rs @@ -39,6 +39,8 @@ hint: This command only works in the manifest directory of a Cargo package."# Command::Database(database) => match database.command { DatabaseCommand::Create => database::create(&database_url).await?, DatabaseCommand::Drop { yes } => database::drop(&database_url, !yes).await?, + DatabaseCommand::Reset { yes } => database::reset(&database_url, yes).await?, + DatabaseCommand::Setup => database::setup(&database_url).await?, }, Command::Prepare { check: false, args } => prepare::run(&database_url, args)?, diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index a90b67e3b..a0f5afacb 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -5,7 +5,7 @@ pub struct Opt { #[clap(subcommand)] pub command: Command, - #[clap(short = "D", long)] + #[clap(short = 'D', long)] pub database_url: Option, } @@ -57,6 +57,15 @@ pub enum DatabaseCommand { #[clap(short)] yes: bool, }, + /// Drops the database specified in your DATABASE_URL, re-creates it, and runs any pending migrations. + Reset { + /// Automatic confirmation. Without this option, you will be prompted before dropping + /// your database. + #[clap(short)] + yes: bool, + }, + /// Creates the database specified in your DATABASE_URL and runs any pending migrations. + Setup, } /// Group of commands for creating and running migrations. diff --git a/sqlx-cli/src/prepare.rs b/sqlx-cli/src/prepare.rs index b38ddc1cc..8a588ba1d 100644 --- a/sqlx-cli/src/prepare.rs +++ b/sqlx-cli/src/prepare.rs @@ -1,6 +1,7 @@ use anyhow::{bail, Context}; use cargo_metadata::MetadataCommand; use console::style; +use remove_dir_all::remove_dir_all; use sqlx::any::{AnyConnectOptions, AnyKind}; use std::collections::BTreeMap; use std::fs::File; @@ -25,7 +26,7 @@ pub fn run(url: &str, cargo_args: Vec) -> anyhow::Result<()> { if data.is_empty() { println!( - "{} no queries found; do you have the `offline` feature enabled", + "{} no queries found; do you have the `offline` feature enabled in sqlx?", style("warning:").yellow() ); } @@ -80,7 +81,16 @@ pub fn check(url: &str, cargo_args: Vec) -> anyhow::Result<()> { fn run_prepare_step(cargo_args: Vec) -> anyhow::Result { // path to the Cargo executable let cargo = env::var("CARGO") - .context("`prepare` subcommand may only be invoked as `cargo sqlx prepare``")?; + .context("`prepare` subcommand may only be invoked as `cargo sqlx prepare`")?; + + let metadata = MetadataCommand::new() + .cargo_path(&cargo) + .exec() + .context("failed to execute `cargo metadata`")?; + + // try removing the target/sqlx directory before running, as stale files + // have repeatedly caused issues in the past. + let _ = remove_dir_all(metadata.target_directory.join("sqlx")); let check_status = Command::new(&cargo) .arg("rustc") @@ -94,17 +104,13 @@ fn run_prepare_step(cargo_args: Vec) -> anyhow::Result { "__sqlx_recompile_trigger=\"{}\"", SystemTime::UNIX_EPOCH.elapsed()?.as_millis() )) + .env("SQLX_OFFLINE", "false") .status()?; if !check_status.success() { bail!("`cargo check` failed with status: {}", check_status); } - let metadata = MetadataCommand::new() - .cargo_path(cargo) - .exec() - .context("failed to execute `cargo metadata`")?; - let pattern = metadata.target_directory.join("sqlx/query-*.json"); let mut data = BTreeMap::new(); diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 32cd74f25..e8053610b 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -12,6 +12,9 @@ authors = [ "Daniel Akhterov ", ] +[package.metadata.docs.rs] +features = ["all-databases", "all-types", "offline"] + [features] default = [ "runtime-async-std", "migrate" ] migrate = [ "sha2", "crc" ] @@ -39,10 +42,11 @@ runtime-actix = [ "sqlx-rt/runtime-actix" ] offline = [ "serde", "either/serde" ] [dependencies] +ahash = "0.5" atoi = "0.3.2" sqlx-rt = { path = "../sqlx-rt", version = "0.1.1" } -base64 = { version = "0.12.1", default-features = false, optional = true, features = [ "std" ] } -bigdecimal_ = { version = "0.1.0", optional = true, package = "bigdecimal" } +base64 = { version = "0.13.0", default-features = false, optional = true, features = [ "std" ] } +bigdecimal_ = { version = "0.2.0", optional = true, package = "bigdecimal" } rust_decimal = { version = "1.7.0", optional = true } bit-vec = { version = "0.6.2", optional = true } bitflags = { version = "1.2.1", default-features = false } @@ -60,17 +64,16 @@ futures-channel = { version = "0.3.5", default-features = false, features = [ "s futures-core = { version = "0.3.5", default-features = false } futures-util = { version = "0.3.5", features = [ "sink" ] } generic-array = { version = "0.14.2", default-features = false, optional = true } -hashbrown = "0.8.0" hex = "0.4.2" -hmac = { version = "0.8.0", default-features = false, optional = true } +hmac = { version = "0.9.0", default-features = false, optional = true } itoa = "0.4.5" ipnetwork = { version = "0.17.0", default-features = false, optional = true } libc = "0.2.71" -libsqlite3-sys = { version = "0.18.0", optional = true, default-features = false, features = [ "pkg-config", "vcpkg", "bundled" ] } +libsqlite3-sys = { version = "0.20.1", optional = true, default-features = false, features = [ "pkg-config", "vcpkg", "bundled" ] } log = { version = "0.4.8", default-features = false } md-5 = { version = "0.9.0", default-features = false, optional = true } memchr = { version = "2.3.3", default-features = false } -num-bigint = { version = "0.2.0", default-features = false, optional = true, features = [ "std" ] } +num-bigint = { version = "0.3.0", default-features = false, optional = true, features = [ "std" ] } once_cell = "1.4.0" percent-encoding = "2.1.0" parking_lot = "0.11.0" diff --git a/sqlx-core/src/any/migrate.rs b/sqlx-core/src/any/migrate.rs index a2729c110..f87fbab55 100644 --- a/sqlx-core/src/any/migrate.rs +++ b/sqlx-core/src/any/migrate.rs @@ -76,7 +76,7 @@ impl Migrate for AnyConnection { AnyConnectionKind::MySql(conn) => conn.ensure_migrations_table(), #[cfg(feature = "mssql")] - AnyConnectionKind::Mssql(conn) => unimplemented!(), + AnyConnectionKind::Mssql(_conn) => unimplemented!(), } } @@ -92,7 +92,7 @@ impl Migrate for AnyConnection { AnyConnectionKind::MySql(conn) => conn.version(), #[cfg(feature = "mssql")] - AnyConnectionKind::Mssql(conn) => unimplemented!(), + AnyConnectionKind::Mssql(_conn) => unimplemented!(), } } @@ -108,7 +108,7 @@ impl Migrate for AnyConnection { AnyConnectionKind::MySql(conn) => conn.lock(), #[cfg(feature = "mssql")] - AnyConnectionKind::Mssql(conn) => unimplemented!(), + AnyConnectionKind::Mssql(_conn) => unimplemented!(), } } @@ -124,7 +124,7 @@ impl Migrate for AnyConnection { AnyConnectionKind::MySql(conn) => conn.unlock(), #[cfg(feature = "mssql")] - AnyConnectionKind::Mssql(conn) => unimplemented!(), + AnyConnectionKind::Mssql(_conn) => unimplemented!(), } } @@ -143,7 +143,10 @@ impl Migrate for AnyConnection { AnyConnectionKind::MySql(conn) => conn.validate(migration), #[cfg(feature = "mssql")] - AnyConnectionKind::Mssql(conn) => unimplemented!(), + AnyConnectionKind::Mssql(_conn) => { + let _ = migration; + unimplemented!() + } } } @@ -162,7 +165,10 @@ impl Migrate for AnyConnection { AnyConnectionKind::MySql(conn) => conn.apply(migration), #[cfg(feature = "mssql")] - AnyConnectionKind::Mssql(conn) => unimplemented!(), + AnyConnectionKind::Mssql(_conn) => { + let _ = migration; + unimplemented!() + } } } } diff --git a/sqlx-core/src/any/options.rs b/sqlx-core/src/any/options.rs index 0978b5778..33a9845eb 100644 --- a/sqlx-core/src/any/options.rs +++ b/sqlx-core/src/any/options.rs @@ -60,6 +60,34 @@ pub(crate) enum AnyConnectOptionsKind { Mssql(MssqlConnectOptions), } +#[cfg(feature = "postgres")] +impl From for AnyConnectOptions { + fn from(options: PgConnectOptions) -> Self { + Self(AnyConnectOptionsKind::Postgres(options)) + } +} + +#[cfg(feature = "mysql")] +impl From for AnyConnectOptions { + fn from(options: MySqlConnectOptions) -> Self { + Self(AnyConnectOptionsKind::MySql(options)) + } +} + +#[cfg(feature = "sqlite")] +impl From for AnyConnectOptions { + fn from(options: SqliteConnectOptions) -> Self { + Self(AnyConnectOptionsKind::Sqlite(options)) + } +} + +#[cfg(feature = "mssql")] +impl From for AnyConnectOptions { + fn from(options: MssqlConnectOptions) -> Self { + Self(AnyConnectOptionsKind::Mssql(options)) + } +} + impl FromStr for AnyConnectOptions { type Err = Error; diff --git a/sqlx-core/src/any/statement.rs b/sqlx-core/src/any/statement.rs index 19b712a3a..0c283c2e5 100644 --- a/sqlx-core/src/any/statement.rs +++ b/sqlx-core/src/any/statement.rs @@ -3,8 +3,8 @@ use crate::column::ColumnIndex; use crate::error::Error; use crate::ext::ustr::UStr; use crate::statement::Statement; +use crate::HashMap; use either::Either; -use hashbrown::HashMap; use std::borrow::Cow; use std::sync::Arc; diff --git a/sqlx-core/src/common/statement_cache.rs b/sqlx-core/src/common/statement_cache.rs index d5695a7cb..2ae097203 100644 --- a/sqlx-core/src/common/statement_cache.rs +++ b/sqlx-core/src/common/statement_cache.rs @@ -65,6 +65,7 @@ impl StatementCache { } /// Returns true if the cache capacity is more than 0. + #[allow(dead_code)] // Only used for some `cfg`s pub fn is_enabled(&self) -> bool { self.capacity() > 0 } diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index ab164cd4b..ff9b05edd 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -92,6 +92,12 @@ pub enum Error { #[error("attempted to acquire a connection on a closed pool")] PoolClosed, + /// A background worker (e.g. [`StatementWorker`]) has crashed. + /// + /// [`StatementWorker`]: crate::sqlite::StatementWorker + #[error("attempted to communicate with a crashed background worker")] + WorkerCrashed, + #[cfg(feature = "migrate")] #[error("{0}")] Migrate(#[source] Box), diff --git a/sqlx-core/src/from_row.rs b/sqlx-core/src/from_row.rs index 973c4b330..2c2ce3a70 100644 --- a/sqlx-core/src/from_row.rs +++ b/sqlx-core/src/from_row.rs @@ -47,6 +47,25 @@ use crate::row::Row; /// /// will read the content of the column `description` into the field `about_me`. /// +/// #### `rename_all` +/// By default, field names are expected verbatim (with the exception of the raw identifier prefix `r#`, if present). +/// Placed at the struct level, this attribute changes how the field name is mapped to its SQL column name: +/// +/// ```rust,ignore +/// #[derive(sqlx::FromRow)] +/// #[sqlx(rename_all = "camelCase")] +/// struct UserPost { +/// id: i32, +/// // remapped to "userId" +/// user_id: i32, +/// contents: String +/// } +/// ``` +/// +/// The supported values are `snake_case` (available if you have non-snake-case field names for some +/// reason), `lowercase`, `UPPERCASE`, `camelCase`, `SCREAMING_SNAKE_CASE` and `kebab-case`. +/// The styling of each option is intended to be an example of its behavior. +/// /// #### `default` /// /// When your struct contains a field that is not present in your query, diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index bf8124684..6e18f16fc 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -99,3 +99,7 @@ pub mod mysql; #[cfg(feature = "mssql")] #[cfg_attr(docsrs, doc(cfg(feature = "mssql")))] pub mod mssql; + +/// sqlx uses ahash for increased performance, at the cost of reduced DoS resistance. +use ahash::AHashMap as HashMap; +//type HashMap = std::collections::HashMap; diff --git a/sqlx-core/src/mssql/connection/stream.rs b/sqlx-core/src/mssql/connection/stream.rs index d7d3604f5..1ce061d50 100644 --- a/sqlx-core/src/mssql/connection/stream.rs +++ b/sqlx-core/src/mssql/connection/stream.rs @@ -20,7 +20,7 @@ use crate::mssql::protocol::return_value::ReturnValue; use crate::mssql::protocol::row::Row; use crate::mssql::{MssqlColumn, MssqlConnectOptions, MssqlDatabaseError}; use crate::net::MaybeTlsStream; -use hashbrown::HashMap; +use crate::HashMap; use std::sync::Arc; pub(crate) struct MssqlStream { diff --git a/sqlx-core/src/mssql/protocol/col_meta_data.rs b/sqlx-core/src/mssql/protocol/col_meta_data.rs index ce8cfe0d8..a91a590a9 100644 --- a/sqlx-core/src/mssql/protocol/col_meta_data.rs +++ b/sqlx-core/src/mssql/protocol/col_meta_data.rs @@ -6,7 +6,7 @@ use crate::ext::ustr::UStr; use crate::mssql::io::MssqlBufExt; use crate::mssql::protocol::type_info::TypeInfo; use crate::mssql::MssqlColumn; -use hashbrown::HashMap; +use crate::HashMap; #[derive(Debug)] pub(crate) struct ColMetaData; diff --git a/sqlx-core/src/mssql/row.rs b/sqlx-core/src/mssql/row.rs index 6b78a3944..08f3ec639 100644 --- a/sqlx-core/src/mssql/row.rs +++ b/sqlx-core/src/mssql/row.rs @@ -4,7 +4,7 @@ use crate::ext::ustr::UStr; use crate::mssql::protocol::row::Row as ProtocolRow; use crate::mssql::{Mssql, MssqlColumn, MssqlValueRef}; use crate::row::Row; -use hashbrown::HashMap; +use crate::HashMap; use std::sync::Arc; pub struct MssqlRow { diff --git a/sqlx-core/src/mssql/statement.rs b/sqlx-core/src/mssql/statement.rs index eb90d4274..3bba4906b 100644 --- a/sqlx-core/src/mssql/statement.rs +++ b/sqlx-core/src/mssql/statement.rs @@ -3,8 +3,8 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::mssql::{Mssql, MssqlArguments, MssqlColumn, MssqlTypeInfo}; use crate::statement::Statement; +use crate::HashMap; use either::Either; -use hashbrown::HashMap; use std::borrow::Cow; use std::sync::Arc; diff --git a/sqlx-core/src/mysql/connection/executor.rs b/sqlx-core/src/mysql/connection/executor.rs index 509ab0d1c..8de32e7ca 100644 --- a/sqlx-core/src/mysql/connection/executor.rs +++ b/sqlx-core/src/mysql/connection/executor.rs @@ -16,12 +16,12 @@ use crate::mysql::{ MySql, MySqlArguments, MySqlColumn, MySqlConnection, MySqlDone, MySqlRow, MySqlTypeInfo, MySqlValueFormat, }; +use crate::HashMap; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; -use hashbrown::HashMap; use std::{borrow::Cow, sync::Arc}; impl MySqlConnection { diff --git a/sqlx-core/src/mysql/options/mod.rs b/sqlx-core/src/mysql/options/mod.rs index eb55e6586..ea7862fd2 100644 --- a/sqlx-core/src/mysql/options/mod.rs +++ b/sqlx-core/src/mysql/options/mod.rs @@ -8,7 +8,7 @@ pub use ssl_mode::MySqlSslMode; /// Options and flags which can be used to configure a MySQL connection. /// -/// A value of `PgConnectOptions` can be parsed from a connection URI, +/// A value of `MySqlConnectOptions` can be parsed from a connection URI, /// as described by [MySQL](https://dev.mysql.com/doc/connector-j/8.0/en/connector-j-reference-jdbc-url-format.html). /// /// The generic format of the connection URL: diff --git a/sqlx-core/src/mysql/protocol/statement/execute.rs b/sqlx-core/src/mysql/protocol/statement/execute.rs index 5eb67f3f4..47d7e2089 100644 --- a/sqlx-core/src/mysql/protocol/statement/execute.rs +++ b/sqlx-core/src/mysql/protocol/statement/execute.rs @@ -16,7 +16,7 @@ impl<'q> Encode<'_, Capabilities> for Execute<'q> { buf.push(0x17); // COM_STMT_EXECUTE buf.extend(&self.statement.to_le_bytes()); buf.push(0); // NO_CURSOR - buf.extend(&0_u32.to_le_bytes()); // iterations (always 1): int<4> + buf.extend(&1_u32.to_le_bytes()); // iterations (always 1): int<4> if !self.arguments.types.is_empty() { buf.extend(&*self.arguments.null_bitmap); diff --git a/sqlx-core/src/mysql/protocol/text/column.rs b/sqlx-core/src/mysql/protocol/text/column.rs index 9c92696e5..0a539bb1c 100644 --- a/sqlx-core/src/mysql/protocol/text/column.rs +++ b/sqlx-core/src/mysql/protocol/text/column.rs @@ -167,6 +167,7 @@ impl ColumnType { ) -> &'static str { let is_binary = char_set == 63; let is_unsigned = flags.contains(ColumnFlags::UNSIGNED); + let is_enum = flags.contains(ColumnFlags::ENUM); match self { ColumnType::Tiny if max_size == Some(1) => "BOOLEAN", @@ -196,6 +197,7 @@ impl ColumnType { ColumnType::Json => "JSON", ColumnType::String if is_binary => "BINARY", + ColumnType::String if is_enum => "ENUM", ColumnType::VarChar | ColumnType::VarString if is_binary => "VARBINARY", ColumnType::String => "CHAR", diff --git a/sqlx-core/src/mysql/row.rs b/sqlx-core/src/mysql/row.rs index f67c24f6a..f910ded68 100644 --- a/sqlx-core/src/mysql/row.rs +++ b/sqlx-core/src/mysql/row.rs @@ -3,7 +3,7 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::mysql::{protocol, MySql, MySqlColumn, MySqlValueFormat, MySqlValueRef}; use crate::row::Row; -use hashbrown::HashMap; +use crate::HashMap; use std::sync::Arc; /// Implementation of [`Row`] for MySQL. diff --git a/sqlx-core/src/mysql/statement.rs b/sqlx-core/src/mysql/statement.rs index b9ba75f62..b6de92fa4 100644 --- a/sqlx-core/src/mysql/statement.rs +++ b/sqlx-core/src/mysql/statement.rs @@ -4,8 +4,8 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::mysql::{MySql, MySqlArguments, MySqlTypeInfo}; use crate::statement::Statement; +use crate::HashMap; use either::Either; -use hashbrown::HashMap; use std::borrow::Cow; use std::sync::Arc; diff --git a/sqlx-core/src/mysql/types/str.rs b/sqlx-core/src/mysql/types/str.rs index 044338c1c..6f0c1a338 100644 --- a/sqlx-core/src/mysql/types/str.rs +++ b/sqlx-core/src/mysql/types/str.rs @@ -6,8 +6,10 @@ use crate::mysql::protocol::text::{ColumnFlags, ColumnType}; use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef}; use crate::types::Type; +const COLLATE_UTF8_GENERAL_CI: u16 = 33; const COLLATE_UTF8_UNICODE_CI: u16 = 192; const COLLATE_UTF8MB4_UNICODE_CI: u16 = 224; +const COLLATE_UTF8MB4_BIN: u16 = 46; impl Type for str { fn type_info() -> MySqlTypeInfo { @@ -31,8 +33,13 @@ impl Type for str { | ColumnType::String | ColumnType::VarString | ColumnType::Enum - ) && (ty.char_set == COLLATE_UTF8MB4_UNICODE_CI as u16 - || ty.char_set == COLLATE_UTF8_UNICODE_CI as u16) + ) && matches!( + ty.char_set, + COLLATE_UTF8MB4_UNICODE_CI + | COLLATE_UTF8_UNICODE_CI + | COLLATE_UTF8_GENERAL_CI + | COLLATE_UTF8MB4_BIN + ) } } diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index e17903c4b..aef5983e8 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -13,12 +13,13 @@ use std::mem; use std::ptr; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::Arc; +use std::task::Context; use std::time::Instant; pub(crate) struct SharedPool { pub(super) connect_options: ::Options, pub(super) idle_conns: ArrayQueue>, - waiters: SegQueue, + waiters: SegQueue>, pub(super) size: AtomicU32, is_closed: AtomicBool, pub(super) options: PoolOptions, @@ -122,19 +123,22 @@ impl SharedPool { return Err(Error::PoolClosed); } - let mut waker_pushed = false; + let mut waiter = None; timeout( deadline_as_timeout::(deadline)?, // `poll_fn` gets us easy access to a `Waker` that we can push to our queue - future::poll_fn(|ctx| -> Poll<()> { - if !waker_pushed { - // only push the waker once - self.waiters.push(ctx.waker().to_owned()); - waker_pushed = true; - Poll::Pending - } else { + future::poll_fn(|cx| -> Poll<()> { + let waiter = waiter.get_or_insert_with(|| { + let waiter = Waiter::new(cx); + self.waiters.push(waiter.clone()); + waiter + }); + + if waiter.is_woken() { Poll::Ready(()) + } else { + Poll::Pending } }), ) @@ -346,7 +350,7 @@ fn spawn_reaper(pool: &Arc>) { /// (where the pool thinks it has more connections than it does). pub(in crate::pool) struct DecrementSizeGuard<'a> { size: &'a AtomicU32, - waiters: &'a SegQueue, + waiters: &'a SegQueue>, dropped: bool, } @@ -379,3 +383,26 @@ impl Drop for DecrementSizeGuard<'_> { } } } + +struct Waiter { + woken: AtomicBool, + waker: Waker, +} + +impl Waiter { + fn new(cx: &mut Context<'_>) -> Arc { + Arc::new(Self { + woken: AtomicBool::new(false), + waker: cx.waker().clone(), + }) + } + + fn wake(&self) { + self.woken.store(true, Ordering::Release); + self.waker.wake_by_ref(); + } + + fn is_woken(&self) -> bool { + self.woken.load(Ordering::Acquire) + } +} diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index a5b8b4688..1458dd131 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -145,6 +145,20 @@ impl Pool { pub fn is_closed(&self) -> bool { self.0.is_closed() } + + /// Returns the number of connections currently active. This includes idle connections. + pub fn size(&self) -> u32 { + self.0.size() + } + + /// Returns the number of connections active and idle (not in use). + /// + /// This will block until the number of connections stops changing for at + /// least 2 atomic accesses in a row. If the number of idle connections is + /// changing rapidly, this may run indefinitely. + pub fn num_idle(&self) -> usize { + self.0.num_idle() + } } /// Returns a new [Pool] tied to the same shared connection pool. diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index a8d5d168c..ff12968dc 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -142,6 +142,26 @@ impl PoolOptions { self } + /// Perform an action after connecting to the database. + /// + /// # Example + /// + /// ```no_run + /// # async fn f() -> Result<(), Box> { + /// use sqlx_core::executor::Executor; + /// use sqlx_core::postgres::PgPoolOptions; + /// // PostgreSQL + /// let pool = PgPoolOptions::new() + /// .after_connect(|conn| Box::pin(async move { + /// conn.execute("SET application_name = 'your_app';").await?; + /// conn.execute("SET search_path = 'my_schema';").await?; + /// + /// Ok(()) + /// })) + /// .connect("postgres:// â€Ķ").await?; + /// # Ok(()) + /// # } + /// ``` pub fn after_connect(mut self, callback: F) -> Self where for<'c> F: diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index d91c9c8dd..97ce8734f 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -5,8 +5,8 @@ use crate::postgres::type_info::{PgCustomType, PgType, PgTypeKind}; use crate::postgres::{PgArguments, PgColumn, PgConnection, PgTypeInfo}; use crate::query_as::query_as; use crate::query_scalar::{query_scalar, query_scalar_with}; +use crate::HashMap; use futures_core::future::BoxFuture; -use hashbrown::HashMap; use std::fmt::Write; use std::sync::Arc; diff --git a/sqlx-core/src/postgres/connection/establish.rs b/sqlx-core/src/postgres/connection/establish.rs index 9218a2bc5..e3de0c6d6 100644 --- a/sqlx-core/src/postgres/connection/establish.rs +++ b/sqlx-core/src/postgres/connection/establish.rs @@ -1,4 +1,4 @@ -use hashbrown::HashMap; +use crate::HashMap; use crate::common::StatementCache; use crate::error::Error; diff --git a/sqlx-core/src/postgres/connection/mod.rs b/sqlx-core/src/postgres/connection/mod.rs index c8b60abfc..b8b75ea3e 100644 --- a/sqlx-core/src/postgres/connection/mod.rs +++ b/sqlx-core/src/postgres/connection/mod.rs @@ -1,9 +1,9 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; +use crate::HashMap; use futures_core::future::BoxFuture; use futures_util::{FutureExt, TryFutureExt}; -use hashbrown::HashMap; use crate::common::StatementCache; use crate::connection::Connection; diff --git a/sqlx-core/src/postgres/message/response.rs b/sqlx-core/src/postgres/message/response.rs index 8da3e10d9..767dd7673 100644 --- a/sqlx-core/src/postgres/message/response.rs +++ b/sqlx-core/src/postgres/message/response.rs @@ -26,6 +26,29 @@ impl PgSeverity { } } +impl std::convert::TryFrom<&str> for PgSeverity { + type Error = Error; + + fn try_from(s: &str) -> Result { + let result = match s { + "PANIC" => PgSeverity::Panic, + "FATAL" => PgSeverity::Fatal, + "ERROR" => PgSeverity::Error, + "WARNING" => PgSeverity::Warning, + "NOTICE" => PgSeverity::Notice, + "DEBUG" => PgSeverity::Debug, + "INFO" => PgSeverity::Info, + "LOG" => PgSeverity::Log, + + severity => { + return Err(err_protocol!("unknown severity: {:?}", severity)); + } + }; + + Ok(result) + } +} + #[derive(Debug)] pub struct Notice { storage: Bytes, @@ -84,7 +107,12 @@ impl Notice { impl Decode<'_> for Notice { fn decode_with(buf: Bytes, _: ()) -> Result { - let mut severity = PgSeverity::Log; + // In order to support PostgreSQL 9.5 and older we need to parse the localized S field. + // Newer versions additionally come with the V field that is guaranteed to be in English. + // We thus read both versions and prefer the unlocalized one if available. + const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log; + let mut severity_v = None; + let mut severity_s = None; let mut message = (0, 0); let mut code = (0, 0); @@ -103,23 +131,24 @@ impl Decode<'_> for Notice { break; } + use std::convert::TryInto; match field { - b'S' | b'V' => { - // unwrap: impossible to fail at this point - severity = match from_utf8(&buf[v.0 as usize..v.1 as usize]).unwrap() { - "PANIC" => PgSeverity::Panic, - "FATAL" => PgSeverity::Fatal, - "ERROR" => PgSeverity::Error, - "WARNING" => PgSeverity::Warning, - "NOTICE" => PgSeverity::Notice, - "DEBUG" => PgSeverity::Debug, - "INFO" => PgSeverity::Info, - "LOG" => PgSeverity::Log, + b'S' => { + // Discard potential errors, because the message might be localized + severity_s = from_utf8(&buf[v.0 as usize..v.1 as usize]) + .unwrap() + .try_into() + .ok(); + } - severity => { - return Err(err_protocol!("unknown severity: {:?}", severity)); - } - }; + b'V' => { + // Propagate errors here, because V is not localized and thus we are missing a possible + // variant. + severity_v = Some( + from_utf8(&buf[v.0 as usize..v.1 as usize]) + .unwrap() + .try_into()?, + ); } b'M' => { @@ -135,7 +164,7 @@ impl Decode<'_> for Notice { } Ok(Self { - severity, + severity: severity_v.or(severity_s).unwrap_or(DEFAULT_SEVERITY), message, code, storage: buf, diff --git a/sqlx-core/src/postgres/options/mod.rs b/sqlx-core/src/postgres/options/mod.rs index d7849cb84..6a3e21d0c 100644 --- a/sqlx-core/src/postgres/options/mod.rs +++ b/sqlx-core/src/postgres/options/mod.rs @@ -26,6 +26,12 @@ pub use ssl_mode::PgSslMode; /// | `sslrootcert` | `None` | Sets the name of a file containing a list of trusted SSL Certificate Authorities. | /// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. Set to `0` to disable. | /// | `host` | `None` | Path to the directory containing a PostgreSQL unix domain socket, which will be used instead of TCP if set. | +/// | `hostaddr` | `None` | Same as `host`, but only accepts IP addresses. | +/// | `application-name` | `None` | The name will be displayed in the pg_stat_activity view and included in CSV log entries. | +/// | `user` | result of `whoami` | PostgreSQL user name to connect as. | +/// | `password` | `None` | Password to be used if the server demands password authentication. | +/// | `port` | `5432` | Port number to connect to at the server host, or socket file name extension for Unix-domain connections. | +/// | `dbname` | `None` | The database name. | /// /// The URI scheme designator can be either `postgresql://` or `postgres://`. /// Each of the URI parts is optional. @@ -37,6 +43,7 @@ pub use ssl_mode::PgSslMode; /// postgresql://localhost/mydb /// postgresql://user@localhost /// postgresql://user:secret@localhost +/// postgresql://localhost?dbname=mydb&user=postgres&password=postgres /// ``` /// /// # Example diff --git a/sqlx-core/src/postgres/options/parse.rs b/sqlx-core/src/postgres/options/parse.rs index 3d1cb258d..5c5cd71ee 100644 --- a/sqlx-core/src/postgres/options/parse.rs +++ b/sqlx-core/src/postgres/options/parse.rs @@ -1,6 +1,7 @@ use crate::error::Error; use crate::postgres::PgConnectOptions; use percent_encoding::percent_decode_str; +use std::net::IpAddr; use std::str::FromStr; use url::Url; @@ -13,7 +14,11 @@ impl FromStr for PgConnectOptions { let mut options = Self::default(); if let Some(host) = url.host_str() { - options = options.host(host); + let host_decoded = percent_decode_str(host); + options = match host_decoded.clone().next() { + Some(b'/') => options.socket(&*host_decoded.decode_utf8().map_err(Error::config)?), + _ => options.host(host), + } } if let Some(port) = url.port() { @@ -65,11 +70,22 @@ impl FromStr for PgConnectOptions { } } - "application_name" => { - options = options.application_name(&*value); + "hostaddr" => { + value.parse::().map_err(Error::config)?; + options = options.host(&*value) } - _ => {} + "port" => options = options.port(value.parse().map_err(Error::config)?), + + "dbname" => options = options.database(&*value), + + "user" => options = options.username(&*value), + + "password" => options = options.password(&*value), + + "application_name" => options = options.application_name(&*value), + + _ => log::warn!("ignoring unrecognized connect parameter: {}={}", key, value), } } @@ -94,6 +110,51 @@ fn it_parses_host_correctly_from_parameter() { assert_eq!("google.database.com", &opts.host); } +#[test] +fn it_parses_hostaddr_correctly_from_parameter() { + let uri = "postgres:///?hostaddr=8.8.8.8"; + let opts = PgConnectOptions::from_str(uri).unwrap(); + + assert_eq!(None, opts.socket); + assert_eq!("8.8.8.8", &opts.host); +} + +#[test] +fn it_parses_port_correctly_from_parameter() { + let uri = "postgres:///?port=1234"; + let opts = PgConnectOptions::from_str(uri).unwrap(); + + assert_eq!(None, opts.socket); + assert_eq!(1234, opts.port); +} + +#[test] +fn it_parses_dbname_correctly_from_parameter() { + let uri = "postgres:///?dbname=some_db"; + let opts = PgConnectOptions::from_str(uri).unwrap(); + + assert_eq!(None, opts.socket); + assert_eq!(Some("some_db"), opts.database.as_deref()); +} + +#[test] +fn it_parses_user_correctly_from_parameter() { + let uri = "postgres:///?user=some_user"; + let opts = PgConnectOptions::from_str(uri).unwrap(); + + assert_eq!(None, opts.socket); + assert_eq!("some_user", opts.username); +} + +#[test] +fn it_parses_password_correctly_from_parameter() { + let uri = "postgres:///?password=some_pass"; + let opts = PgConnectOptions::from_str(uri).unwrap(); + + assert_eq!(None, opts.socket); + assert_eq!(Some("some_pass"), opts.password.as_deref()); +} + #[test] fn it_parses_application_name_correctly_from_parameter() { let uri = "postgres:///?application_name=some_name"; @@ -117,3 +178,20 @@ fn it_parses_password_with_non_ascii_chars_correctly() { assert_eq!(Some("p@ssw0rd".into()), opts.password); } + +#[test] +fn it_parses_socket_correctly_percent_encoded() { + let uri = "postgres://%2Fvar%2Flib%2Fpostgres/database"; + let opts = PgConnectOptions::from_str(uri).unwrap(); + + assert_eq!(Some("/var/lib/postgres/".into()), opts.socket); +} +#[test] +fn it_parses_socket_correctly_with_username_percent_encoded() { + let uri = "postgres://some_user@%2Fvar%2Flib%2Fpostgres/database"; + let opts = PgConnectOptions::from_str(uri).unwrap(); + + assert_eq!("some_user", opts.username); + assert_eq!(Some("/var/lib/postgres/".into()), opts.socket); + assert_eq!(Some("database"), opts.database.as_deref()); +} diff --git a/sqlx-core/src/postgres/statement.rs b/sqlx-core/src/postgres/statement.rs index 06aead1f5..4c01b9156 100644 --- a/sqlx-core/src/postgres/statement.rs +++ b/sqlx-core/src/postgres/statement.rs @@ -4,8 +4,8 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::postgres::{PgArguments, Postgres}; use crate::statement::Statement; +use crate::HashMap; use either::Either; -use hashbrown::HashMap; use std::borrow::Cow; use std::sync::Arc; diff --git a/sqlx-core/src/postgres/types/bigdecimal.rs b/sqlx-core/src/postgres/types/bigdecimal.rs index 28b2603c0..617c19c5f 100644 --- a/sqlx-core/src/postgres/types/bigdecimal.rs +++ b/sqlx-core/src/postgres/types/bigdecimal.rs @@ -1,7 +1,7 @@ use std::cmp; use std::convert::{TryFrom, TryInto}; -use bigdecimal::BigDecimal; +use bigdecimal::{BigDecimal, ToPrimitive, Zero}; use num_bigint::{BigInt, Sign}; use crate::decode::Decode; @@ -77,65 +77,64 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric { type Error = BoxDynError; fn try_from(decimal: &BigDecimal) -> Result { - let base_10_to_10000 = |chunk: &[u8]| chunk.iter().fold(0i16, |a, &d| a * 10 + d as i16); + if decimal.is_zero() { + return Ok(PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![], + }); + } // NOTE: this unfortunately copies the BigInt internally let (integer, exp) = decimal.as_bigint_and_exponent(); - // this routine is specifically optimized for base-10 - // FIXME: is there a way to iterate over the digits to avoid the Vec allocation - let (sign, base_10) = integer.to_radix_be(10); - - // weight is positive power of 10000 - // exp is the negative power of 10 - let weight_10 = base_10.len() as i64 - exp; - // scale is only nonzero when we have fractional digits // since `exp` is the _negative_ decimal exponent, it tells us // exactly what our scale should be let scale: i16 = cmp::max(0, exp).try_into()?; - // there's an implicit +1 offset in the interpretation - let weight: i16 = if weight_10 <= 0 { - weight_10 / 4 - 1 - } else { - // the `-1` is a fix for an off by 1 error (4 digits should still be 0 weight) - (weight_10 - 1) / 4 - } - .try_into()?; + let (sign, uint) = integer.into_parts(); + let mut mantissa = uint.to_u128().unwrap(); - let digits_len = if base_10.len() % 4 != 0 { - base_10.len() / 4 + 1 - } else { - base_10.len() / 4 - }; + // If our scale is not a multiple of 4, we need to go to the next + // multiple. + let groups_diff = scale % 4; + if groups_diff > 0 { + let remainder = 4 - groups_diff as u32; + let power = 10u32.pow(remainder as u32) as u128; - let offset = weight_10.rem_euclid(4) as usize; - - let mut digits = Vec::with_capacity(digits_len); - - if let Some(first) = base_10.get(..offset) { - if offset != 0 { - digits.push(base_10_to_10000(first)); - } + mantissa = mantissa * power; } - if let Some(rest) = base_10.get(offset..) { - digits.extend( - rest.chunks(4) - .map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)), - ); + // Array to store max mantissa of Decimal in Postgres decimal format. + let mut digits = Vec::with_capacity(8); + + // Convert to base-10000. + while mantissa != 0 { + digits.push((mantissa % 10_000) as i16); + mantissa /= 10_000; } + // Change the endianness. + digits.reverse(); + + // Weight is number of digits on the left side of the decimal. + let digits_after_decimal = (scale + 3) as u16 / 4; + let weight = digits.len() as i16 - digits_after_decimal as i16 - 1; + + // Remove non-significant zeroes. while let Some(&0) = digits.last() { digits.pop(); } + let sign = match sign { + Sign::Plus | Sign::NoSign => PgNumericSign::Positive, + Sign::Minus => PgNumericSign::Negative, + }; + Ok(PgNumeric::Number { - sign: match sign { - Sign::Plus | Sign::NoSign => PgNumericSign::Positive, - Sign::Minus => PgNumericSign::Negative, - }, + sign, scale, weight, digits, diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index b5b373e40..a29104542 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -251,7 +251,7 @@ where impl<'q, DB, F, O, A> Map<'q, DB, F, A> where DB: Database, - F: Send + Sync + Fn(DB::Row) -> Result, + F: TryMapRow, O: Send + Unpin, A: 'q + Send + IntoArguments<'q, DB>, { @@ -277,7 +277,7 @@ where /// Execute multiple queries and return the generated results as a stream /// from each query, in a stream. pub fn fetch_many<'e, 'c: 'e, E>( - self, + mut self, executor: E, ) -> BoxStream<'e, Result, Error>> where @@ -294,7 +294,7 @@ where r#yield!(match v { Either::Left(v) => Either::Left(v), Either::Right(row) => { - Either::Right((self.mapper)(row)?) + Either::Right(self.mapper.try_map_row(row)?) } }); } @@ -333,7 +333,7 @@ where } /// Execute the query and returns at most one row. - pub async fn fetch_optional<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> + pub async fn fetch_optional<'e, 'c: 'e, E>(mut self, executor: E) -> Result, Error> where 'q: 'e, E: 'e + Executor<'c, Database = DB>, @@ -344,7 +344,7 @@ where let row = executor.fetch_optional(self.inner).await?; if let Some(row) = row { - (self.mapper)(row).map(Some) + self.mapper.try_map_row(row).map(Some) } else { Ok(None) } @@ -356,13 +356,13 @@ where // // See https://github.com/rust-lang/rust/issues/62529 -pub trait TryMapRow { +pub trait TryMapRow: Send { type Output: Unpin; fn try_map_row(&mut self, row: DB::Row) -> Result; } -pub trait MapRow { +pub trait MapRow: Send { type Output: Unpin; fn map_row(&mut self, row: DB::Row) -> Self::Output; @@ -449,7 +449,7 @@ macro_rules! impl_map_row { ($DB:ident, $R:ident) => { impl crate::query::MapRow<$DB> for F where - F: FnMut($R) -> O, + F: Send + FnMut($R) -> O, { type Output = O; @@ -460,7 +460,7 @@ macro_rules! impl_map_row { impl crate::query::TryMapRow<$DB> for F where - F: FnMut($R) -> Result, + F: Send + FnMut($R) -> Result, { type Output = O; diff --git a/sqlx-core/src/sqlite/connection/describe.rs b/sqlx-core/src/sqlite/connection/describe.rs index e79c54177..97e74efae 100644 --- a/sqlx-core/src/sqlite/connection/describe.rs +++ b/sqlx-core/src/sqlite/connection/describe.rs @@ -25,12 +25,12 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>( let mut statement = statement?; // we start by finding the first statement that *can* return results - while let Some((statement, ..)) = statement.prepare(&mut conn.handle)? { - num_params += statement.bind_parameter_count(); + while let Some((stmt, ..)) = statement.prepare(&mut conn.handle)? { + num_params += stmt.bind_parameter_count(); let mut stepped = false; - let num = statement.column_count(); + let num = stmt.column_count(); if num == 0 { // no columns in this statement; skip continue; @@ -44,7 +44,7 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>( // to [column_decltype] // if explain.. fails, ignore the failure and we'll have no fallback - let (fallback, fallback_nullable) = match explain(conn, statement.sql()).await { + let (fallback, fallback_nullable) = match explain(conn, stmt.sql()).await { Ok(v) => v, Err(err) => { log::debug!("describe: explain introspection failed: {}", err); @@ -54,24 +54,20 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>( }; for col in 0..num { - let name = statement.column_name(col).to_owned(); + let name = stmt.column_name(col).to_owned(); - let type_info = if let Some(ty) = statement.column_decltype(col) { + let type_info = if let Some(ty) = stmt.column_decltype(col) { ty } else { // if that fails, we back up and attempt to step the statement // once *if* its read-only and then use [column_type] as a // fallback to [column_decltype] - if !stepped && statement.read_only() { + if !stepped && stmt.read_only() { stepped = true; - - conn.worker.execute(statement); - conn.worker.wake(); - - let _ = conn.worker.step(statement).await; + let _ = conn.worker.step(*stmt).await; } - let mut ty = statement.column_type_info(col); + let mut ty = stmt.column_type_info(col); if ty.0 == DataType::Null { if let Some(fallback) = fallback.get(col).cloned() { @@ -82,7 +78,7 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>( ty }; - nullable.push(statement.column_nullable(col)?.or_else(|| { + nullable.push(stmt.column_nullable(col)?.or_else(|| { // if we do not *know* if this is nullable, check the EXPLAIN fallback fallback_nullable.get(col).copied().and_then(identity) })); diff --git a/sqlx-core/src/sqlite/connection/establish.rs b/sqlx-core/src/sqlite/connection/establish.rs index 954cac4f4..3311019c0 100644 --- a/sqlx-core/src/sqlite/connection/establish.rs +++ b/sqlx-core/src/sqlite/connection/establish.rs @@ -8,7 +8,7 @@ use crate::{ use libsqlite3_sys::{ sqlite3_busy_timeout, sqlite3_extended_result_codes, sqlite3_open_v2, SQLITE_OK, SQLITE_OPEN_CREATE, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX, SQLITE_OPEN_PRIVATECACHE, - SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, + SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE, }; use sqlx_rt::blocking; use std::io; @@ -35,7 +35,7 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result Result Executor<'c> for &'c mut SqliteConnection { handle: ref mut conn, ref mut statements, ref mut statement, - ref worker, + ref mut worker, .. } = self; @@ -91,25 +91,18 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // keep track of how many arguments we have bound let mut num_arguments = 0; - while let Some((handle, columns, column_names, last_row_values)) = stmt.prepare(conn)? { + while let Some((stmt, columns, column_names, last_row_values)) = stmt.prepare(conn)? { // bind values to the statement - num_arguments += bind(handle, &arguments, num_arguments)?; - - // tell the worker about the new statement - worker.execute(handle); - - // wake up the worker if needed - // the worker parks its thread on async-std when not in use - worker.wake(); + num_arguments += bind(stmt, &arguments, num_arguments)?; loop { // save the rows from the _current_ position on the statement // and send them to the still-live row object - SqliteRow::inflate_if_needed(handle, &*columns, last_row_values.take()); + SqliteRow::inflate_if_needed(stmt, &*columns, last_row_values.take()); // invoke [sqlite3_step] on the dedicated worker thread // this will move us forward one row or finish the statement - let s = worker.step(handle).await?; + let s = worker.step(*stmt).await?; match s { Either::Left(changes) => { @@ -129,7 +122,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { Either::Right(()) => { let (row, weak_values_ref) = SqliteRow::current( - *handle, + *stmt, columns, column_names ); diff --git a/sqlx-core/src/sqlite/connection/explain.rs b/sqlx-core/src/sqlite/connection/explain.rs index a19cd9399..bcec602b3 100644 --- a/sqlx-core/src/sqlite/connection/explain.rs +++ b/sqlx-core/src/sqlite/connection/explain.rs @@ -2,7 +2,7 @@ use crate::error::Error; use crate::query_as::query_as; use crate::sqlite::type_info::DataType; use crate::sqlite::{SqliteConnection, SqliteTypeInfo}; -use hashbrown::HashMap; +use crate::HashMap; use std::str::from_utf8; // affinity @@ -136,7 +136,8 @@ pub(super) async fn explain( } else if let Some(v) = r.get(&p2).copied() { // r[p3] = AGG ( r[p2] ) r.insert(p3, v); - n.insert(p3, n.get(&p2).copied().unwrap_or(true)); + let val = n.get(&p2).copied().unwrap_or(true); + n.insert(p3, val); } } @@ -151,7 +152,8 @@ pub(super) async fn explain( // r[p2] = r[p1] if let Some(v) = r.get(&p1).copied() { r.insert(p2, v); - n.insert(p2, n.get(&p1).copied().unwrap_or(true)); + let val = n.get(&p1).copied().unwrap_or(true); + n.insert(p2, val); } } @@ -165,7 +167,8 @@ pub(super) async fn explain( // r[p2] = NOT r[p1] if let Some(a) = r.get(&p1).copied() { r.insert(p2, a); - n.insert(p2, n.get(&p1).copied().unwrap_or(true)); + let val = n.get(&p1).copied().unwrap_or(true); + n.insert(p2, val); } } diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index 5625367e5..e4c04f46d 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -106,8 +106,5 @@ impl Drop for SqliteConnection { // we must explicitly drop the statements as the drop-order in a struct is undefined self.statements.clear(); self.statement.take(); - - // we then explicitly close the worker - self.worker.close(); } } diff --git a/sqlx-core/src/sqlite/options/mod.rs b/sqlx-core/src/sqlite/options/mod.rs index fce099672..824e0bcb3 100644 --- a/sqlx-core/src/sqlite/options/mod.rs +++ b/sqlx-core/src/sqlite/options/mod.rs @@ -48,6 +48,7 @@ pub struct SqliteConnectOptions { pub(crate) create_if_missing: bool, pub(crate) journal_mode: SqliteJournalMode, pub(crate) foreign_keys: bool, + pub(crate) shared_cache: bool, pub(crate) statement_cache_capacity: usize, pub(crate) busy_timeout: Duration, } @@ -66,6 +67,7 @@ impl SqliteConnectOptions { read_only: false, create_if_missing: false, foreign_keys: true, + shared_cache: false, statement_cache_capacity: 100, journal_mode: SqliteJournalMode::Wal, busy_timeout: Duration::from_secs(5), diff --git a/sqlx-core/src/sqlite/options/parse.rs b/sqlx-core/src/sqlite/options/parse.rs index 119d7c3ba..7c21adf46 100644 --- a/sqlx-core/src/sqlite/options/parse.rs +++ b/sqlx-core/src/sqlite/options/parse.rs @@ -2,11 +2,14 @@ use crate::error::Error; use crate::sqlite::SqliteConnectOptions; use percent_encoding::percent_decode_str; use std::borrow::Cow; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::str::FromStr; +use std::sync::atomic::{AtomicUsize, Ordering}; // https://www.sqlite.org/uri.html +static IN_MEMORY_DB_SEQ: AtomicUsize = AtomicUsize::new(0); + impl FromStr for SqliteConnectOptions { type Err = Error; @@ -24,6 +27,9 @@ impl FromStr for SqliteConnectOptions { if database == ":memory:" { options.in_memory = true; + options.shared_cache = true; + let seqno = IN_MEMORY_DB_SEQ.fetch_add(1, Ordering::Relaxed); + options.filename = Cow::Owned(PathBuf::from(format!("file:sqlx-in-memory-{}", seqno))); } else { // % decode to allow for `?` or `#` in the filename options.filename = Cow::Owned( @@ -58,6 +64,7 @@ impl FromStr for SqliteConnectOptions { "memory" => { options.in_memory = true; + options.shared_cache = true; } _ => { @@ -68,6 +75,25 @@ impl FromStr for SqliteConnectOptions { } } + // The cache query parameter specifies the cache behaviour across multiple + // connections to the same database within the process. A shared cache is + // essential for persisting data across connections to an in-memory database. + "cache" => match &*value { + "private" => { + options.shared_cache = false; + } + + "shared" => { + options.shared_cache = true; + } + + _ => { + return Err(Error::Configuration( + format!("unknown value {:?} for `cache`", value).into(), + )); + } + }, + _ => { return Err(Error::Configuration( format!( @@ -89,12 +115,19 @@ impl FromStr for SqliteConnectOptions { fn test_parse_in_memory() -> Result<(), Error> { let options: SqliteConnectOptions = "sqlite::memory:".parse()?; assert!(options.in_memory); + assert!(options.shared_cache); let options: SqliteConnectOptions = "sqlite://?mode=memory".parse()?; assert!(options.in_memory); + assert!(options.shared_cache); let options: SqliteConnectOptions = "sqlite://:memory:".parse()?; assert!(options.in_memory); + assert!(options.shared_cache); + + let options: SqliteConnectOptions = "sqlite://?mode=memory&cache=private".parse()?; + assert!(options.in_memory); + assert!(!options.shared_cache); Ok(()) } @@ -107,3 +140,12 @@ fn test_parse_read_only() -> Result<(), Error> { Ok(()) } + +#[test] +fn test_parse_shared_in_memory() -> Result<(), Error> { + let options: SqliteConnectOptions = "sqlite://a.db?cache=shared".parse()?; + assert!(options.shared_cache); + assert_eq!(&*options.filename.to_string_lossy(), "a.db"); + + Ok(()) +} diff --git a/sqlx-core/src/sqlite/row.rs b/sqlx-core/src/sqlite/row.rs index 809e60ba5..84e5c0358 100644 --- a/sqlx-core/src/sqlite/row.rs +++ b/sqlx-core/src/sqlite/row.rs @@ -3,7 +3,7 @@ use std::slice; use std::sync::atomic::{AtomicPtr, Ordering}; use std::sync::{Arc, Weak}; -use hashbrown::HashMap; +use crate::HashMap; use crate::column::ColumnIndex; use crate::error::Error; diff --git a/sqlx-core/src/sqlite/statement/mod.rs b/sqlx-core/src/sqlite/statement/mod.rs index dfb3ce554..3ac3f8276 100644 --- a/sqlx-core/src/sqlite/statement/mod.rs +++ b/sqlx-core/src/sqlite/statement/mod.rs @@ -3,8 +3,8 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::sqlite::{Sqlite, SqliteArguments, SqliteColumn, SqliteTypeInfo}; use crate::statement::Statement; +use crate::HashMap; use either::Either; -use hashbrown::HashMap; use std::borrow::Cow; use std::sync::Arc; diff --git a/sqlx-core/src/sqlite/statement/virtual.rs b/sqlx-core/src/sqlite/statement/virtual.rs index 47f97e3e6..805e0e0b9 100644 --- a/sqlx-core/src/sqlite/statement/virtual.rs +++ b/sqlx-core/src/sqlite/statement/virtual.rs @@ -3,8 +3,8 @@ use crate::ext::ustr::UStr; use crate::sqlite::connection::ConnectionHandle; use crate::sqlite::statement::StatementHandle; use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue}; +use crate::HashMap; use bytes::{Buf, Bytes}; -use hashbrown::HashMap; use libsqlite3_sys::{ sqlite3, sqlite3_clear_bindings, sqlite3_finalize, sqlite3_prepare_v3, sqlite3_reset, sqlite3_stmt, SQLITE_OK, SQLITE_PREPARE_PERSISTENT, diff --git a/sqlx-core/src/sqlite/statement/worker.rs b/sqlx-core/src/sqlite/statement/worker.rs index 1d1b3085d..8b1d22997 100644 --- a/sqlx-core/src/sqlite/statement/worker.rs +++ b/sqlx-core/src/sqlite/statement/worker.rs @@ -1,19 +1,10 @@ use crate::error::Error; use crate::sqlite::statement::StatementHandle; +use crossbeam_channel::{unbounded, Sender}; use either::Either; -use libsqlite3_sys::sqlite3_stmt; +use futures_channel::oneshot; use libsqlite3_sys::{sqlite3_step, SQLITE_DONE, SQLITE_ROW}; -use sqlx_rt::yield_now; -use std::ptr::null_mut; -use std::sync::atomic::{spin_loop_hint, AtomicI32, AtomicPtr, Ordering}; -use std::sync::Arc; -use std::thread::{self, park, spawn, JoinHandle}; - -const STATE_CLOSE: i32 = -1; - -const STATE_READY: i32 = 0; - -const STATE_INITIAL: i32 = 1; +use std::thread; // Each SQLite connection has a dedicated thread. @@ -21,131 +12,52 @@ const STATE_INITIAL: i32 = 1; // OS resource usage. Low priority because a high concurrent load for SQLite3 is very // unlikely. -// TODO: Reduce atomic complexity. There must be a simpler way to do this that doesn't -// compromise performance. - pub(crate) struct StatementWorker { - statement: Arc>, - status: Arc, - handle: Option>, + tx: Sender, +} + +enum StatementWorkerCommand { + Step { + statement: StatementHandle, + tx: oneshot::Sender, Error>>, + }, } impl StatementWorker { pub(crate) fn new() -> Self { - let statement = Arc::new(AtomicPtr::new(null_mut::())); - let status = Arc::new(AtomicI32::new(STATE_INITIAL)); + let (tx, rx) = unbounded(); - let handle = spawn({ - let statement = Arc::clone(&statement); - let status = Arc::clone(&status); + thread::spawn(move || { + for cmd in rx { + match cmd { + StatementWorkerCommand::Step { statement, tx } => { + let status = unsafe { sqlite3_step(statement.0.as_ptr()) }; - move || { - // wait for the first command - park(); + let resp = match status { + SQLITE_ROW => Ok(Either::Right(())), + SQLITE_DONE => Ok(Either::Left(statement.changes())), + _ => Err(statement.last_error().into()), + }; - 'run: while status.load(Ordering::Acquire) >= 0 { - 'statement: loop { - match status.load(Ordering::Acquire) { - STATE_CLOSE => { - // worker has been dropped; get out - break 'run; - } - - STATE_READY => { - let statement = statement.load(Ordering::Acquire); - if statement.is_null() { - // we do not have the statement handle yet - thread::yield_now(); - continue; - } - - let v = unsafe { sqlite3_step(statement) }; - - status.store(v, Ordering::Release); - - if v == SQLITE_DONE { - // when a statement is _done_, we park the thread until - // we need it again - park(); - break 'statement; - } - } - - _ => { - // waits for the receiving end to be ready to receive the rows - // this should take less than 1 microsecond under most conditions - spin_loop_hint(); - } - } + let _ = tx.send(resp); } } } }); - Self { - handle: Some(handle), - statement, - status, - } + Self { tx } } - pub(crate) fn wake(&self) { - if let Some(handle) = &self.handle { - handle.thread().unpark(); - } - } + pub(crate) async fn step( + &mut self, + statement: StatementHandle, + ) -> Result, Error> { + let (tx, rx) = oneshot::channel(); - pub(crate) fn execute(&self, statement: &StatementHandle) { - // readies the worker to execute the statement - // for async-std, this unparks our dedicated thread + self.tx + .send(StatementWorkerCommand::Step { statement, tx }) + .map_err(|_| Error::WorkerCrashed)?; - self.statement - .store(statement.0.as_ptr(), Ordering::Release); - } - - pub(crate) async fn step(&self, statement: &StatementHandle) -> Result, Error> { - // storing <0> as a terminal in status releases the worker - // to proceed to the next [sqlite3_step] invocation - self.status.store(STATE_READY, Ordering::Release); - - // we then use a spin loop to wait for this to finish - // 99% of the time this should be < 1 Ξs - let status = loop { - let status = self - .status - .compare_and_swap(STATE_READY, STATE_READY, Ordering::AcqRel); - - if status != STATE_READY { - break status; - } - - yield_now().await; - }; - - match status { - // a row was found - SQLITE_ROW => Ok(Either::Right(())), - - // reached the end of the query results, - // emit the # of changes - SQLITE_DONE => Ok(Either::Left(statement.changes())), - - _ => Err(statement.last_error().into()), - } - } - - pub(crate) fn close(&mut self) { - self.status.store(STATE_CLOSE, Ordering::Release); - - if let Some(handle) = self.handle.take() { - handle.thread().unpark(); - handle.join().unwrap(); - } - } -} - -impl Drop for StatementWorker { - fn drop(&mut self) { - self.close(); + rx.await.map_err(|_| Error::WorkerCrashed)? } } diff --git a/sqlx-core/src/sqlite/types/chrono.rs b/sqlx-core/src/sqlite/types/chrono.rs index c94ed8164..cd01c3bde 100644 --- a/sqlx-core/src/sqlite/types/chrono.rs +++ b/sqlx-core/src/sqlite/types/chrono.rs @@ -169,7 +169,7 @@ impl<'r> Decode<'r, Sqlite> for NaiveDateTime { impl<'r> Decode<'r, Sqlite> for NaiveDate { fn decode(value: SqliteValueRef<'r>) -> Result { - Ok(NaiveDate::parse_from_str("%F", value.text()?)?) + Ok(NaiveDate::parse_from_str(value.text()?, "%F")?) } } diff --git a/sqlx-core/src/sqlite/types/json.rs b/sqlx-core/src/sqlite/types/json.rs new file mode 100644 index 000000000..8f6bf6dc8 --- /dev/null +++ b/sqlx-core/src/sqlite/types/json.rs @@ -0,0 +1,44 @@ +use serde::{Deserialize, Serialize}; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::sqlite::{ + type_info::DataType, Sqlite, SqliteArgumentValue, SqliteTypeInfo, SqliteValueRef, +}; +use crate::types::{Json, Type}; + +impl Type for Json { + fn type_info() -> SqliteTypeInfo { + SqliteTypeInfo(DataType::Text) + } + + fn compatible(ty: &SqliteTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } +} + +impl Encode<'_, Sqlite> for Json +where + T: Serialize, +{ + fn encode_by_ref(&self, buf: &mut Vec>) -> IsNull { + let json_string_value = + serde_json::to_string(&self.0).expect("serde_json failed to convert to string"); + + Encode::::encode(json_string_value, buf) + } +} + +impl<'r, T> Decode<'r, Sqlite> for Json +where + T: 'r + Deserialize<'r>, +{ + fn decode(value: SqliteValueRef<'r>) -> Result { + let string_value = <&str as Decode>::decode(value)?; + + serde_json::from_str(&string_value) + .map(Json) + .map_err(Into::into) + } +} diff --git a/sqlx-core/src/sqlite/types/mod.rs b/sqlx-core/src/sqlite/types/mod.rs index 9553d5a26..84ca61549 100644 --- a/sqlx-core/src/sqlite/types/mod.rs +++ b/sqlx-core/src/sqlite/types/mod.rs @@ -35,4 +35,6 @@ mod bytes; mod chrono; mod float; mod int; +#[cfg(feature = "json")] +mod json; mod str; diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 92a512ac0..4e21d115d 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -98,6 +98,14 @@ pub use json::Json; /// struct UserId(i64); /// ``` /// +/// ##### Attributes +/// +/// * `#[sqlx(rename = "")]` on struct definition: instead of inferring the SQL type name from the inner +/// field (in the above case, `BIGINT`), explicitly set it to `` instead. May trigger +/// errors or unexpected behavior if the encoding of the given type is different than that of the +/// inferred type (e.g. if you rename the above to `VARCHAR`). +/// Affects Postgres only. +/// /// ### Enumeration /// /// Enumerations may be defined in Rust and can match SQL by diff --git a/sqlx-macros/src/database/postgres.rs b/sqlx-macros/src/database/postgres.rs index a29b56cc2..049d39644 100644 --- a/sqlx-macros/src/database/postgres.rs +++ b/sqlx-macros/src/database/postgres.rs @@ -88,9 +88,8 @@ impl_database_ext! { #[cfg(feature = "chrono")] Vec | &[sqlx::types::chrono::NaiveDateTime], - // TODO - // #[cfg(feature = "chrono")] - // Vec> | &[sqlx::types::chrono::DateTime<_>], + #[cfg(feature = "chrono")] + Vec> | &[sqlx::types::chrono::DateTime<_>], #[cfg(feature = "time")] Vec | &[sqlx::types::time::Time], @@ -113,6 +112,69 @@ impl_database_ext! { #[cfg(feature = "json")] Vec | &[serde_json::Value], + // Ranges + + sqlx::postgres::types::PgRange, + sqlx::postgres::types::PgRange, + + #[cfg(feature = "bigdecimal")] + sqlx::postgres::types::PgRange, + + #[cfg(feature = "chrono")] + sqlx::postgres::types::PgRange, + + #[cfg(feature = "chrono")] + sqlx::postgres::types::PgRange, + + #[cfg(feature = "chrono")] + sqlx::postgres::types::PgRange> | + sqlx::postgres::types::PgRange>, + + #[cfg(feature = "time")] + sqlx::postgres::types::PgRange, + + #[cfg(feature = "time")] + sqlx::postgres::types::PgRange, + + #[cfg(feature = "time")] + sqlx::postgres::types::PgRange, + + // Range arrays + + Vec> | &[sqlx::postgres::types::PgRange], + Vec> | &[sqlx::postgres::types::PgRange], + + #[cfg(feature = "bigdecimal")] + Vec> | + &[sqlx::postgres::types::PgRange], + + #[cfg(feature = "chrono")] + Vec> | + &[sqlx::postgres::types::PgRange], + + #[cfg(feature = "chrono")] + Vec> | + &[sqlx::postgres::types::PgRange], + + #[cfg(feature = "chrono")] + Vec>> | + Vec>>, + + #[cfg(feature = "chrono")] + &[sqlx::postgres::types::PgRange>] | + &[sqlx::postgres::types::PgRange>], + + #[cfg(feature = "time")] + Vec> | + &[sqlx::postgres::types::PgRange], + + #[cfg(feature = "time")] + Vec> | + &[sqlx::postgres::types::PgRange], + + #[cfg(feature = "time")] + Vec> | + &[sqlx::postgres::types::PgRange], }, ParamChecking::Strong, feature-types: info => info.__type_feature_gate(), diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs index f65f99e16..4d2dadbb0 100644 --- a/sqlx-macros/src/derives/attributes.rs +++ b/sqlx-macros/src/derives/attributes.rs @@ -32,6 +32,7 @@ pub enum RenameAll { SnakeCase, UpperCase, ScreamingSnakeCase, + KebabCase, } pub struct SqlxContainerAttributes { @@ -75,6 +76,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result RenameAll::SnakeCase, "UPPERCASE" => RenameAll::UpperCase, "SCREAMING_SNAKE_CASE" => RenameAll::ScreamingSnakeCase, + "kebab-case" => RenameAll::KebabCase, _ => fail!(meta, "unexpected value for rename_all"), }; @@ -121,13 +123,13 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result { + Meta::List(list) => { for value in list.nested.iter() { match value { NestedMeta::Meta(meta) => match meta { diff --git a/sqlx-macros/src/derives/mod.rs b/sqlx-macros/src/derives/mod.rs index 51435ce45..0785c894f 100644 --- a/sqlx-macros/src/derives/mod.rs +++ b/sqlx-macros/src/derives/mod.rs @@ -10,7 +10,7 @@ pub(crate) use r#type::expand_derive_type; pub(crate) use row::expand_derive_from_row; use self::attributes::RenameAll; -use heck::{ShoutySnakeCase, SnakeCase}; +use heck::{KebabCase, ShoutySnakeCase, SnakeCase}; use std::iter::FromIterator; use syn::DeriveInput; @@ -34,5 +34,6 @@ pub(crate) fn rename_all(s: &str, pattern: RenameAll) -> String { RenameAll::SnakeCase => s.to_snake_case(), RenameAll::UpperCase => s.to_uppercase(), RenameAll::ScreamingSnakeCase => s.to_shouty_snake_case(), + RenameAll::KebabCase => s.to_kebab_case(), } } diff --git a/sqlx-macros/src/derives/row.rs b/sqlx-macros/src/derives/row.rs index 8f03e6875..2409c7fbf 100644 --- a/sqlx-macros/src/derives/row.rs +++ b/sqlx-macros/src/derives/row.rs @@ -2,7 +2,7 @@ use proc_macro2::Span; use quote::quote; use syn::{ parse_quote, punctuated::Punctuated, token::Comma, Data, DataStruct, DeriveInput, Field, - Fields, FieldsNamed, Lifetime, Stmt, + Fields, FieldsNamed, FieldsUnnamed, Lifetime, Stmt, }; use super::attributes::parse_child_attributes; @@ -15,12 +15,9 @@ pub fn expand_derive_from_row(input: &DeriveInput) -> syn::Result expand_derive_from_row_struct(input, named), Data::Struct(DataStruct { - fields: Fields::Unnamed(_), + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. - }) => Err(syn::Error::new_spanned( - input, - "tuple structs are not supported", - )), + }) => expand_derive_from_row_struct_unnamed(input, unnamed), Data::Struct(DataStruct { fields: Fields::Unit, @@ -111,3 +108,55 @@ fn expand_derive_from_row_struct( } )) } + +fn expand_derive_from_row_struct_unnamed( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let ident = &input.ident; + + let generics = &input.generics; + + let (lifetime, provided) = generics + .lifetimes() + .next() + .map(|def| (def.lifetime.clone(), false)) + .unwrap_or_else(|| (Lifetime::new("'a", Span::call_site()), true)); + + let (_, ty_generics, _) = generics.split_for_impl(); + + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(R: sqlx::Row)); + + if provided { + generics.params.insert(0, parse_quote!(#lifetime)); + } + + let predicates = &mut generics.make_where_clause().predicates; + + predicates.push(parse_quote!(usize: sqlx::ColumnIndex)); + + for field in fields { + let ty = &field.ty; + + predicates.push(parse_quote!(#ty: sqlx::decode::Decode<#lifetime, R::Database>)); + predicates.push(parse_quote!(#ty: sqlx::types::Type)); + } + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let gets = fields + .iter() + .enumerate() + .map(|(idx, _)| quote!(row.try_get(#idx)?)); + + Ok(quote!( + impl #impl_generics sqlx::FromRow<#lifetime, R> for #ident #ty_generics #where_clause { + fn from_row(row: &#lifetime R) -> sqlx::Result { + Ok(#ident ( + #(#gets),* + )) + } + } + )) +} diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs index de5b8164b..eb568b6f5 100644 --- a/sqlx-macros/src/query/mod.rs +++ b/sqlx-macros/src/query/mod.rs @@ -36,7 +36,9 @@ pub fn expand_input(input: QueryMacroInput) -> crate::Result { // if `dotenv` wasn't initialized by the above we make sure to do it here match ( - dotenv::var("SQLX_OFFLINE").is_ok(), + dotenv::var("SQLX_OFFLINE") + .map(|s| s.to_lowercase() == "true") + .unwrap_or(false), dotenv::var("DATABASE_URL"), ) { (false, Ok(db_url)) => expand_from_db(input, &db_url), @@ -80,7 +82,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result crate::Result crate::Result crate::Result expand_with_data( input, QueryData::::from_dyn_data(query_data)?, + true, ), #[cfg(feature = "mysql")] sqlx_core::mysql::MySql::NAME => expand_with_data( input, QueryData::::from_dyn_data(query_data)?, + true, ), #[cfg(feature = "sqlite")] sqlx_core::sqlite::Sqlite::NAME => expand_with_data( input, QueryData::::from_dyn_data(query_data)?, + true, ), _ => Err(format!( "found query data for {} but the feature for that database was not enabled", @@ -182,6 +187,7 @@ impl DescribeExt for Describe {} fn expand_with_data( input: QueryMacroInput, data: QueryData, + #[allow(unused_variables)] offline: bool, ) -> crate::Result where Describe: DescribeExt, @@ -273,8 +279,10 @@ where #output }}; + // Store query metadata only if offline support is enabled but the current build is online. + // If the build is offline, the cache is our input so it's pointless to also write data for it. #[cfg(feature = "offline")] - { + if !offline { let mut save_dir = std::path::PathBuf::from( env::var("CARGO_TARGET_DIR").unwrap_or_else(|_| "target/".into()), ); diff --git a/sqlx-rt/src/lib.rs b/sqlx-rt/src/lib.rs index 811b7b6e2..4f784e69f 100644 --- a/sqlx-rt/src/lib.rs +++ b/sqlx-rt/src/lib.rs @@ -16,15 +16,15 @@ compile_error!( "only one of 'runtime-actix', 'runtime-async-std' or 'runtime-tokio' features can be enabled" ); -pub use native_tls; +pub use native_tls::{self, Error as TlsError}; // // Actix *OR* Tokio // #[cfg(all( - not(feature = "runtime-async-std"), any(feature = "runtime-tokio", feature = "runtime-actix"), + not(feature = "runtime-async-std"), ))] pub use tokio::{ self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, net::TcpStream, @@ -33,11 +33,14 @@ pub use tokio::{ #[cfg(all( unix, - not(feature = "runtime-async-std"), any(feature = "runtime-tokio", feature = "runtime-actix"), + not(feature = "runtime-async-std"), ))] pub use tokio::net::UnixStream; +#[cfg(all(feature = "tokio-native-tls", not(feature = "async-native-tls")))] +pub use tokio_native_tls::{TlsConnector, TlsStream}; + // // tokio // @@ -53,12 +56,6 @@ macro_rules! blocking { }; } -#[cfg(all(feature = "tokio-native-tls", not(feature = "async-native-tls")))] -pub use tokio_native_tls::{TlsConnector, TlsStream}; - -#[cfg(all(feature = "tokio-native-tls", not(feature = "async-native-tls")))] -pub use native_tls::Error as TlsError; - // // actix // @@ -113,7 +110,7 @@ macro_rules! blocking { pub use async_std::os::unix::net::UnixStream; #[cfg(all(feature = "async-native-tls", not(feature = "tokio-native-tls")))] -pub use async_native_tls::{Error as TlsError, TlsConnector, TlsStream}; +pub use async_native_tls::{TlsConnector, TlsStream}; #[cfg(all( feature = "runtime-async-std", @@ -155,7 +152,6 @@ mod tokio_runtime { .expect("failed to initialize Tokio runtime") }); - #[cfg(any(feature = "runtime-tokio", feature = "runtime-actix"))] pub fn block_on(future: F) -> F::Output { RUNTIME.enter(|| RUNTIME.handle().block_on(future)) } diff --git a/src/macros.rs b/src/macros.rs index c165a9525..5b843dca5 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -604,7 +604,7 @@ macro_rules! query_file_as_unchecked ( /// ```rust,ignore /// use sqlx::migrate::Migrator; /// -/// static MIGRATOR: Migrator = sqlx::migrate!(); // defaults to "migrations" +/// static MIGRATOR: Migrator = sqlx::migrate!(); // defaults to "./migrations" /// ``` /// /// The directory must be relative to the project root (the directory containing `Cargo.toml`), @@ -618,6 +618,6 @@ macro_rules! migrate { }}; () => {{ - $crate::sqlx_macros::migrate!("migrations") + $crate::sqlx_macros::migrate!("./migrations") }}; } diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index 1bcd1c102..fa69ee391 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -333,3 +333,57 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> { Ok(()) } + +// repro is more reliable with the basic scheduler used by `#[tokio::test]` +#[cfg(feature = "runtime-tokio")] +#[tokio::test] +async fn test_issue_622() -> anyhow::Result<()> { + use std::time::Instant; + + setup_if_needed(); + + let pool = MySqlPoolOptions::new() + .max_connections(1) // also fails with higher counts, e.g. 5 + .connect(&std::env::var("DATABASE_URL").unwrap()) + .await?; + + println!("pool state: {:?}", pool); + + let mut handles = vec![]; + + // given repro spawned 100 tasks but I found it reliably reproduced with 3 + for i in 0..3 { + let pool = pool.clone(); + + handles.push(sqlx_rt::spawn(async move { + { + let mut conn = pool.acquire().await.unwrap(); + + let _ = sqlx::query("SELECT 1").fetch_one(&mut conn).await.unwrap(); + + // conn gets dropped here and should be returned to the pool + } + + // (do some other work here without holding on to a connection) + // this actually fixes the issue, depending on the timeout used + // sqlx_rt::sleep(Duration::from_millis(500)).await; + + { + let start = Instant::now(); + match pool.acquire().await { + Ok(conn) => { + println!("{} acquire took {:?}", i, start.elapsed()); + drop(conn); + } + Err(e) => panic!("{} acquire returned error: {} pool state: {:?}", i, e, pool), + } + } + + Result::<(), anyhow::Error>::Ok(()) + })); + } + + futures::future::try_join_all(handles).await?; + + Ok(()) +} diff --git a/tests/postgres/derives.rs b/tests/postgres/derives.rs index bd45d8433..a218db3ef 100644 --- a/tests/postgres/derives.rs +++ b/tests/postgres/derives.rs @@ -66,6 +66,14 @@ enum ColorScreamingSnake { BlueBlack, } +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(rename = "color-kebab-case")] +#[sqlx(rename_all = "kebab-case")] +enum ColorKebabCase { + RedGreen, + BlueBlack, +} + // "Strong" enum can map to a custom type #[derive(PartialEq, Debug, sqlx::Type)] #[sqlx(rename = "mood")] @@ -133,11 +141,13 @@ DROP TYPE IF EXISTS color_lower CASCADE; DROP TYPE IF EXISTS color_snake CASCADE; DROP TYPE IF EXISTS color_upper CASCADE; DROP TYPE IF EXISTS color_screaming_snake CASCADE; +DROP TYPE IF EXISTS "color-kebab-case" CASCADE; CREATE TYPE color_lower AS ENUM ( 'red', 'green', 'blue' ); CREATE TYPE color_snake AS ENUM ( 'red_green', 'blue_black' ); CREATE TYPE color_upper AS ENUM ( 'RED', 'GREEN', 'BLUE' ); CREATE TYPE color_screaming_snake AS ENUM ( 'RED_GREEN', 'BLUE_BLACK' ); +CREATE TYPE "color-kebab-case" AS ENUM ( 'red-green', 'blue-black' ); CREATE TABLE people ( id serial PRIMARY KEY, @@ -264,6 +274,18 @@ SELECT id, mood FROM people WHERE id = $1 assert!(rec.0); assert_eq!(rec.1, ColorScreamingSnake::RedGreen); + let rec: (bool, ColorKebabCase) = sqlx::query_as( + " + SELECT $1 = 'red-green'::\"color-kebab-case\", $1 + ", + ) + .bind(&ColorKebabCase::RedGreen) + .fetch_one(&mut conn) + .await?; + + assert!(rec.0); + assert_eq!(rec.1, ColorKebabCase::RedGreen); + Ok(()) } @@ -404,6 +426,44 @@ async fn test_from_row_with_rename() -> anyhow::Result<()> { Ok(()) } +#[cfg(feature = "macros")] +#[sqlx_macros::test] +async fn test_from_row_tuple() -> anyhow::Result<()> { + let mut conn = new::().await?; + + #[derive(Debug, sqlx::FromRow)] + struct Account(i32, String); + + let account: Account = sqlx::query_as( + "SELECT * from (VALUES (1, 'Herp Derpinson')) accounts(id, name) where id = $1", + ) + .bind(1_i32) + .fetch_one(&mut conn) + .await?; + + assert_eq!(account.0, 1); + assert_eq!(account.1, "Herp Derpinson"); + + // A _single_ lifetime may be used but only when using the lowest-level API currently (Query::fetch) + + #[derive(sqlx::FromRow)] + struct RefAccount<'a>(i32, &'a str); + + let mut cursor = sqlx::query( + "SELECT * from (VALUES (1, 'Herp Derpinson')) accounts(id, name) where id = $1", + ) + .bind(1_i32) + .fetch(&mut conn); + + let row = cursor.try_next().await?.unwrap(); + let account = RefAccount::from_row(&row)?; + + assert_eq!(account.0, 1); + assert_eq!(account.1, "Herp Derpinson"); + + Ok(()) +} + #[cfg(feature = "macros")] #[sqlx_macros::test] async fn test_default() -> anyhow::Result<()> { diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index bb7ba78fb..f64a6b0a5 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -709,3 +709,57 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> { Ok(()) } + +// repro is more reliable with the basic scheduler used by `#[tokio::test]` +#[cfg(feature = "runtime-tokio")] +#[tokio::test] +async fn test_issue_622() -> anyhow::Result<()> { + use std::time::Instant; + + setup_if_needed(); + + let pool = PgPoolOptions::new() + .max_connections(1) // also fails with higher counts, e.g. 5 + .connect(&std::env::var("DATABASE_URL").unwrap()) + .await?; + + println!("pool state: {:?}", pool); + + let mut handles = vec![]; + + // given repro spawned 100 tasks but I found it reliably reproduced with 3 + for i in 0..3 { + let pool = pool.clone(); + + handles.push(sqlx_rt::spawn(async move { + { + let mut conn = pool.acquire().await.unwrap(); + + let _ = sqlx::query("SELECT 1").fetch_one(&mut conn).await.unwrap(); + + // conn gets dropped here and should be returned to the pool + } + + // (do some other work here without holding on to a connection) + // this actually fixes the issue, depending on the timeout used + // sqlx_rt::sleep(Duration::from_millis(500)).await; + + { + let start = Instant::now(); + match pool.acquire().await { + Ok(conn) => { + println!("{} acquire took {:?}", i, start.elapsed()); + drop(conn); + } + Err(e) => panic!("{} acquire returned error: {} pool state: {:?}", i, e, pool), + } + } + + Result::<(), anyhow::Error>::Ok(()) + })); + } + + futures::future::try_join_all(handles).await?; + + Ok(()) +} diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 9526f357d..a0aa64eb6 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -395,7 +395,15 @@ test_type!(bigdecimal(Postgres, "1::numeric" == "1".parse::().unwrap(), "10000::numeric" == "10000".parse::().unwrap(), "0.1::numeric" == "0.1".parse::().unwrap(), + "0.01::numeric" == "0.01".parse::().unwrap(), + "0.012::numeric" == "0.012".parse::().unwrap(), + "0.0123::numeric" == "0.0123".parse::().unwrap(), "0.01234::numeric" == "0.01234".parse::().unwrap(), + "0.012345::numeric" == "0.012345".parse::().unwrap(), + "0.0123456::numeric" == "0.0123456".parse::().unwrap(), + "0.01234567::numeric" == "0.01234567".parse::().unwrap(), + "0.012345678::numeric" == "0.012345678".parse::().unwrap(), + "0.0123456789::numeric" == "0.0123456789".parse::().unwrap(), "12.34::numeric" == "12.34".parse::().unwrap(), "12345.6789::numeric" == "12345.6789".parse::().unwrap(), )); diff --git a/tests/sqlite/types.rs b/tests/sqlite/types.rs index a64572ea5..61cd13a5e 100644 --- a/tests/sqlite/types.rs +++ b/tests/sqlite/types.rs @@ -1,4 +1,6 @@ -use sqlx::sqlite::Sqlite; +use sqlx::sqlite::{Sqlite, SqliteRow}; +use sqlx_core::row::Row; +use sqlx_test::new; use sqlx_test::test_type; test_type!(null>(Sqlite, @@ -32,6 +34,61 @@ test_type!(bytes>(Sqlite, == vec![0_u8, 0, 0, 0, 0x52] )); +#[cfg(feature = "json")] +mod json_tests { + use super::*; + use serde_json::{json, Value as JsonValue}; + use sqlx::types::Json; + use sqlx_test::test_type; + + test_type!(json( + Sqlite, + "'\"Hello, World\"'" == json!("Hello, World"), + "'\"😎\"'" == json!("😎"), + "'\"🙋‍♀ïļ\"'" == json!("🙋‍♀ïļ"), + "'[\"Hello\",\"World!\"]'" == json!(["Hello", "World!"]) + )); + + #[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)] + struct Friend { + name: String, + age: u32, + } + + test_type!(json_struct>( + Sqlite, + "\'{\"name\":\"Joe\",\"age\":33}\'" == Json(Friend { name: "Joe".to_string(), age: 33 }) + )); + + // NOTE: This is testing recursive (and transparent) usage of the `Json` wrapper. You don't + // need to wrap the Vec in Json<_> to make the example work. + + #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)] + struct Customer { + json_column: Json>, + } + + test_type!(json_struct_json_column>( + Sqlite, + "\'{\"json_column\":[1,2]}\'" == Json(Customer { json_column: Json(vec![1, 2]) }) + )); + + #[sqlx_macros::test] + async fn it_json_extracts() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let value = sqlx::query("select JSON_EXTRACT(JSON('{ \"number\": 42 }'), '$.number') = ?1") + .bind(42_i32) + .try_map(|row: SqliteRow| row.try_get::(0)) + .fetch_one(&mut conn) + .await?; + + assert_eq!(true, value); + + Ok(()) + } +} + #[cfg(feature = "chrono")] mod chrono { use super::*;