diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index a86c2873..6d193065 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,2 +1,50 @@ + + ### Does your PR solve an issue? -### Delete this text and add "fixes #(issue number)" +Delete this text and add "fixes #(issue number)". + +Do *not* just list issue numbers here as they will not be automatically closed on merging this pull request unless prefixed with "fixes" or "closes". + +### Is this a breaking change? +Delete this text and answer yes/no and explain. + +If yes, this pull request will need to wait for the next major release (`0.{x + 1}.0`) + +Behavior changes _can_ be breaking if significant enough. +Consider [Hyrum's Law](https://www.hyrumslaw.com/): + +> With a sufficient number of users of an API, +> it does not matter what you promise in the contract: +> all observable behaviors of your system +> will be depended on by somebody. diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 1a91e1fa..3f1f44d3 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -39,7 +39,7 @@ jobs: - run: > cargo clippy --no-default-features - --features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros + --features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros -- -D warnings # Run beta for new warnings but don't break the build. @@ -47,7 +47,7 @@ jobs: - run: > cargo +beta clippy --no-default-features - --features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros + --features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros --target-dir target/beta/ check-minimal-versions: @@ -140,7 +140,7 @@ jobs: - run: > cargo test --no-default-features - --features any,macros,${{ matrix.linking }},_unstable-all-types,runtime-${{ matrix.runtime }} + --features any,macros,${{ matrix.linking }},${{ matrix.linking == 'sqlite' && 'sqlite-preupdate-hook,' || ''}}_unstable-all-types,runtime-${{ matrix.runtime }} -- --test-threads=1 env: diff --git a/Cargo.lock b/Cargo.lock index 1f4674ab..07754e7c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,35 +2,20 @@ # It is not intended for manual editing. version = 4 -[[package]] -name = "accounts" -version = "0.1.0" -dependencies = [ - "argon2 0.5.3", - "password-hash 0.5.0", - "rand", - "serde", - "sqlx", - "thiserror 1.0.69", - "time", - "tokio", - "uuid", -] - [[package]] name = "addr2line" -version = "0.21.0" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] [[package]] -name = "adler" -version = "1.0.2" +name = "adler2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "ahash" @@ -142,19 +127,7 @@ checksum = "db4ce4441f99dbd377ca8a8f57b698c44d0d6e712d8329b5040da5a64aa1ce73" dependencies = [ "base64ct", "blake2", - "password-hash 0.4.2", -] - -[[package]] -name = "argon2" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" -dependencies = [ - "base64ct", - "blake2", - "cpufeatures", - "password-hash 0.5.0", + "password-hash", ] [[package]] @@ -396,16 +369,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43" dependencies = [ "async-trait", - "axum-core 0.2.9", - "axum-macros 0.2.3", + "axum-core", + "axum-macros", "bitflags 1.3.2", "bytes", "futures-util", - "http 0.2.12", - "http-body 0.4.6", - "hyper 0.14.32", + "http", + "http-body", + "hyper", "itoa", - "matchit 0.5.0", + "matchit", "memchr", "mime", "percent-encoding", @@ -413,49 +386,14 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper 0.1.2", + "sync_wrapper", "tokio", - "tower 0.4.13", + "tower", "tower-http", "tower-layer", "tower-service", ] -[[package]] -name = "axum" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" -dependencies = [ - "axum-core 0.5.0", - "axum-macros 0.5.0", - "bytes", - "form_urlencoded", - "futures-util", - "http 1.2.0", - "http-body 1.0.1", - "http-body-util", - "hyper 1.6.0", - "hyper-util", - "itoa", - "matchit 0.8.4", - "memchr", - "mime", - "percent-encoding", - "pin-project-lite", - "rustversion", - "serde", - "serde_json", - "serde_path_to_error", - "serde_urlencoded", - "sync_wrapper 1.0.2", - "tokio", - "tower 0.5.2", - "tower-layer", - "tower-service", - "tracing", -] - [[package]] name = "axum-core" version = "0.2.9" @@ -465,33 +403,13 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http 0.2.12", - "http-body 0.4.6", + "http", + "http-body", "mime", "tower-layer", "tower-service", ] -[[package]] -name = "axum-core" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" -dependencies = [ - "bytes", - "futures-util", - "http 1.2.0", - "http-body 1.0.1", - "http-body-util", - "mime", - "pin-project-lite", - "rustversion", - "sync_wrapper 1.0.2", - "tower-layer", - "tower-service", - "tracing", -] - [[package]] name = "axum-macros" version = "0.2.3" @@ -504,17 +422,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "axum-macros" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.96", -] - [[package]] name = "backoff" version = "0.4.0" @@ -531,17 +438,17 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.71" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ "addr2line", - "cc", "cfg-if", "libc", "miniz_oxide", "object", "rustc-demangle", + "windows-targets 0.52.6", ] [[package]] @@ -880,9 +787,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.30" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d" +checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" dependencies = [ "clap_builder", "clap_derive", @@ -890,9 +797,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.30" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c" +checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" dependencies = [ "anstream", "anstyle", @@ -911,9 +818,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.28" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed" +checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -927,17 +834,6 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" -[[package]] -name = "clipboard-win" -version = "4.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7191c27c2357d9b7ef96baac1773290d4ca63b24205b82a3fd8a0637afcf0362" -dependencies = [ - "error-code", - "str-buf", - "winapi", -] - [[package]] name = "cmake" version = "0.1.52" @@ -947,33 +843,6 @@ dependencies = [ "cc", ] -[[package]] -name = "color-eyre" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55146f5e46f237f7423d74111267d4597b59b0dad0ffaf7303bce9945d843ad5" -dependencies = [ - "backtrace", - "color-spantrace", - "eyre", - "indenter", - "once_cell", - "owo-colors", - "tracing-error", -] - -[[package]] -name = "color-spantrace" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd6be1b2a7e382e2b98b43b2adcca6bb0e465af0bdd38123873ae61eb17a72c2" -dependencies = [ - "once_cell", - "owo-colors", - "tracing-core", - "tracing-error", -] - [[package]] name = "colorchoice" version = "1.0.3" @@ -1240,6 +1109,17 @@ dependencies = [ "serde", ] +[[package]] +name = "dialoguer" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +dependencies = [ + "console", + "shell-words", + "thiserror 1.0.69", +] + [[package]] name = "difflib" version = "0.4.0" @@ -1258,27 +1138,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "dirs-next" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" -dependencies = [ - "cfg-if", - "dirs-sys-next", -] - -[[package]] -name = "dirs-sys-next" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" -dependencies = [ - "libc", - "redox_users", - "winapi", -] - [[package]] name = "displaydoc" version = "0.2.5" @@ -1329,12 +1188,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" -[[package]] -name = "endian-type" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" - [[package]] name = "env_filter" version = "0.1.3" @@ -1384,16 +1237,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "error-code" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64f18991e7bf11e7ffee451b5318b5c1a73c52d0d0ada6e5a3017c8c1ced6a21" -dependencies = [ - "libc", - "str-buf", -] - [[package]] name = "etcetera" version = "0.8.0" @@ -1432,16 +1275,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "eyre" -version = "0.6.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" -dependencies = [ - "indenter", - "once_cell", -] - [[package]] name = "fastrand" version = "1.9.0" @@ -1457,17 +1290,6 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" -[[package]] -name = "fd-lock" -version = "3.0.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5" -dependencies = [ - "cfg-if", - "rustix 0.38.43", - "windows-sys 0.48.0", -] - [[package]] name = "filetime" version = "0.2.25" @@ -1705,9 +1527,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.1" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" @@ -1834,17 +1656,6 @@ dependencies = [ "itoa", ] -[[package]] -name = "http" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - [[package]] name = "http-body" version = "0.4.6" @@ -1852,30 +1663,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http 0.2.12", - "pin-project-lite", -] - -[[package]] -name = "http-body" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" -dependencies = [ - "bytes", - "http 1.2.0", -] - -[[package]] -name = "http-body-util" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" -dependencies = [ - "bytes", - "futures-util", - "http 1.2.0", - "http-body 1.0.1", + "http", "pin-project-lite", ] @@ -1913,8 +1701,8 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http 0.2.12", - "http-body 0.4.6", + "http", + "http-body", "httparse", "httpdate", "itoa", @@ -1926,41 +1714,6 @@ dependencies = [ "want", ] -[[package]] -name = "hyper" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" -dependencies = [ - "bytes", - "futures-channel", - "futures-util", - "http 1.2.0", - "http-body 1.0.1", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "smallvec", - "tokio", -] - -[[package]] -name = "hyper-util" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" -dependencies = [ - "bytes", - "futures-util", - "http 1.2.0", - "http-body 1.0.1", - "hyper 1.6.0", - "pin-project-lite", - "tokio", - "tower-service", -] - [[package]] name = "iana-time-zone" version = "0.1.61" @@ -2145,12 +1898,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed" -[[package]] -name = "indenter" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" - [[package]] name = "indexmap" version = "1.9.3" @@ -2391,7 +2138,7 @@ version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8836fae9d0d4be2c8b4efcdd79e828a2faa058a90d005abf42f91cac5493a08e" dependencies = [ - "nix 0.28.0", + "nix", "winapi", ] @@ -2401,12 +2148,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" -[[package]] -name = "matchit" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" - [[package]] name = "md-5" version = "0.10.6" @@ -2423,15 +2164,6 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "memoffset" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" -dependencies = [ - "autocfg", -] - [[package]] name = "memoffset" version = "0.9.1" @@ -2455,11 +2187,11 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.4" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" +checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" dependencies = [ - "adler", + "adler2", ] [[package]] @@ -2529,28 +2261,6 @@ dependencies = [ "tempfile", ] -[[package]] -name = "nibble_vec" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" -dependencies = [ - "smallvec", -] - -[[package]] -name = "nix" -version = "0.23.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3790c00a0150112de0f4cd161e3d7fc4b2d8a5542ffc35f099a2562aecb35c" -dependencies = [ - "bitflags 1.3.2", - "cc", - "cfg-if", - "libc", - "memoffset 0.6.5", -] - [[package]] name = "nix" version = "0.28.0" @@ -2561,7 +2271,7 @@ dependencies = [ "cfg-if", "cfg_aliases 0.1.1", "libc", - "memoffset 0.9.1", + "memoffset", ] [[package]] @@ -2580,16 +2290,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" -[[package]] -name = "nu-ansi-term" -version = "0.46.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" -dependencies = [ - "overload", - "winapi", -] - [[package]] name = "num-bigint" version = "0.4.6" @@ -2655,9 +2355,9 @@ dependencies = [ [[package]] name = "object" -version = "0.32.2" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "memchr", ] @@ -2728,18 +2428,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - -[[package]] -name = "owo-colors" -version = "3.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" - [[package]] name = "parking" version = "2.2.1" @@ -2780,34 +2468,12 @@ dependencies = [ "subtle", ] -[[package]] -name = "password-hash" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" -dependencies = [ - "base64ct", - "rand_core", - "subtle", -] - [[package]] name = "paste" version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" -[[package]] -name = "payments" -version = "0.1.0" -dependencies = [ - "accounts", - "rust_decimal", - "sqlx", - "time", - "uuid", -] - [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -3060,15 +2726,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "promptly" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9acbc6c5a5b029fe58342f58445acb00ccfe24624e538894bc2f04ce112980ba" -dependencies = [ - "rustyline", -] - [[package]] name = "ptr_meta" version = "0.1.4" @@ -3104,16 +2761,6 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" -[[package]] -name = "radix_trie" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" -dependencies = [ - "endian-type", - "nibble_vec", -] - [[package]] name = "rand" version = "0.8.5" @@ -3203,17 +2850,6 @@ dependencies = [ "bitflags 2.7.0", ] -[[package]] -name = "redox_users" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" -dependencies = [ - "getrandom", - "libredox", - "thiserror 1.0.69", -] - [[package]] name = "regex" version = "1.11.1" @@ -3398,15 +3034,6 @@ dependencies = [ "security-framework 3.2.0", ] -[[package]] -name = "rustls-pemfile" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "rustls-pki-types" version = "1.10.1" @@ -3431,30 +3058,6 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" -[[package]] -name = "rustyline" -version = "9.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db7826789c0e25614b03e5a54a0717a86f9ff6e6e5247f92b369472869320039" -dependencies = [ - "bitflags 1.3.2", - "cfg-if", - "clipboard-win", - "dirs-next", - "fd-lock", - "libc", - "log", - "memchr", - "nix 0.23.2", - "radix_trie", - "scopeguard", - "smallvec", - "unicode-segmentation", - "unicode-width 0.1.14", - "utf8parse", - "winapi", -] - [[package]] name = "ryu" version = "1.0.18" @@ -3538,18 +3141,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.218" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.218" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", @@ -3577,16 +3180,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_path_to_error" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6" -dependencies = [ - "itoa", - "serde", -] - [[package]] name = "serde_spanned" version = "0.6.8" @@ -3659,13 +3252,10 @@ dependencies = [ ] [[package]] -name = "sharded-slab" -version = "0.1.7" +name = "shell-words" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] +checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" [[package]] name = "shlex" @@ -3812,19 +3402,18 @@ version = "0.8.3" dependencies = [ "anyhow", "assert_cmd", - "async-trait", "backoff", "cargo_metadata", "chrono", "clap", "clap_complete", "console", + "dialoguer", "dotenvy", "filetime", "futures", "glob", "openssl", - "promptly", "serde_json", "sqlx", "tempfile", @@ -3837,6 +3426,7 @@ version = "0.8.3" dependencies = [ "async-io 1.13.0", "async-std", + "base64 0.22.1", "bigdecimal", "bit-vec", "bstr", @@ -3864,7 +3454,6 @@ dependencies = [ "rust_decimal", "rustls", "rustls-native-certs", - "rustls-pemfile", "serde", "serde_json", "sha2", @@ -3874,7 +3463,6 @@ dependencies = [ "time", "tokio", "tokio-stream", - "toml", "tracing", "url", "uuid", @@ -3897,8 +3485,8 @@ name = "sqlx-example-postgres-axum-social" version = "0.1.0" dependencies = [ "anyhow", - "argon2 0.4.1", - "axum 0.5.17", + "argon2", + "axum", "dotenvy", "once_cell", "rand", @@ -3910,7 +3498,7 @@ dependencies = [ "thiserror 2.0.11", "time", "tokio", - "tower 0.4.13", + "tower", "tracing", "uuid", "validator", @@ -3975,22 +3563,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "sqlx-example-postgres-multi-tenant" -version = "0.8.3" -dependencies = [ - "accounts", - "axum 0.8.1", - "color-eyre", - "dotenvy", - "payments", - "rand", - "rust_decimal", - "sqlx", - "tokio", - "tracing-subscriber", -] - [[package]] name = "sqlx-example-postgres-todos" version = "0.1.0" @@ -4170,6 +3742,7 @@ dependencies = [ "serde_urlencoded", "sqlx", "sqlx-core", + "thiserror 2.0.11", "time", "tracing", "url", @@ -4208,12 +3781,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "str-buf" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e08d8363704e6c71fc928674353e6b7c23dcea9d82d7012c8faf2a3a025f8d0" - [[package]] name = "stringprep" version = "0.1.5" @@ -4365,12 +3932,6 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" -[[package]] -name = "sync_wrapper" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" - [[package]] name = "synstructure" version = "0.13.1" @@ -4463,16 +4024,6 @@ dependencies = [ "syn 2.0.96", ] -[[package]] -name = "thread_local" -version = "1.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" -dependencies = [ - "cfg-if", - "once_cell", -] - [[package]] name = "time" version = "0.3.37" @@ -4629,22 +4180,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "tower" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" -dependencies = [ - "futures-core", - "futures-util", - "pin-project-lite", - "sync_wrapper 1.0.2", - "tokio", - "tower-layer", - "tower-service", - "tracing", -] - [[package]] name = "tower-http" version = "0.3.5" @@ -4655,11 +4190,11 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http 0.2.12", - "http-body 0.4.6", + "http", + "http-body", "http-range-header", "pin-project-lite", - "tower 0.4.13", + "tower", "tower-layer", "tower-service", ] @@ -4706,42 +4241,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", - "valuable", -] - -[[package]] -name = "tracing-error" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b1581020d7a273442f5b45074a6a57d5757ad0a47dac0e9f0bd57b81936f3db" -dependencies = [ - "tracing", - "tracing-subscriber", -] - -[[package]] -name = "tracing-log" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" -dependencies = [ - "log", - "once_cell", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" -dependencies = [ - "nu-ansi-term", - "sharded-slab", - "smallvec", - "thread_local", - "tracing-core", - "tracing-log", ] [[package]] @@ -4870,9 +4369,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.12.1" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" +checksum = "b913a3b5fe84142e269d63cc62b64319ccaf89b748fc31fe025177f767a756c4" dependencies = [ "serde", ] @@ -4919,12 +4418,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "valuable" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" - [[package]] name = "value-bag" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index 8769b56e..4d74d38c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,9 +76,10 @@ _unstable-all-types = [ "mac_address", "uuid", "bit-vec", + "bstr" ] # Render documentation that wouldn't otherwise be shown (e.g. `sqlx_core::config`). -_unstable-doc = [] +_unstable-doc = ["sqlite-preupdate-hook"] # Base runtime features without TLS runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std"] @@ -114,6 +115,7 @@ postgres = ["sqlx-postgres", "sqlx-macros?/postgres"] mysql = ["sqlx-mysql", "sqlx-macros?/mysql"] sqlite = ["_sqlite", "sqlx-sqlite/bundled", "sqlx-macros?/sqlite"] sqlite-unbundled = ["_sqlite", "sqlx-sqlite/unbundled", "sqlx-macros?/sqlite-unbundled"] +sqlite-preupdate-hook = ["sqlx-sqlite/preupdate-hook"] # types json = ["sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"] @@ -127,6 +129,7 @@ rust_decimal = ["sqlx-core/rust_decimal", "sqlx-macros?/rust_decimal", "sqlx-mys time = ["sqlx-core/time", "sqlx-macros?/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"] uuid = ["sqlx-core/uuid", "sqlx-macros?/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"] regexp = ["sqlx-sqlite?/regexp"] +bstr = ["sqlx-core/bstr"] [workspace.dependencies] # Core Crates diff --git a/FAQ.md b/FAQ.md index f0bccd3c..cf13cf73 100644 --- a/FAQ.md +++ b/FAQ.md @@ -36,6 +36,62 @@ as they can often be a whole year or more out-of-date. [`rust-version`]: https://doc.rust-lang.org/stable/cargo/reference/manifest.html#the-rust-version-field +---------------------------------------------------------------- + +### Can SQLx Add Support for New Databases? + +We are always open to discuss adding support for new databases, but as of writing, have no plans to in the short term. + +Implementing support for a new database in SQLx is a _huge_ lift. Expecting this work to be done for free is highly unrealistic. +In all likelihood, the implementation would need to be written from scratch. +Even if Rust bindings exist, they may not support `async`. +Even if they support `async`, they may only support either Tokio or `async-std`, and not both. +Even if they support Tokio and `async-std`, the API may not be flexible enough or provide sufficient information (e.g. for implementing the macros). + +If we have to write the implementation from scratch, is the protocol publicly documented, and stable? + +Even if everything is supported on the client side, how will we run tests against the database? Is it open-source, or proprietary? Will it require a paid license? + +For example, Oracle Database's protocol is proprietary and only supported through their own libraries, which do not support Rust, and only have blocking APIs (see: [Oracle Call Interface for C](https://docs.oracle.com/en/database/oracle/oracle-database/23/lnoci/index.html)). +This makes it a poor candidate for an async-native crate like SQLx--though we support SQLite, which also only has a blocking API, that's the exception and not the rule. Wrapping blocking APIs is not very scalable. + +We still have plans to bring back the MSSQL driver, but this is not feasible as of writing with the current maintenance workload. Should this change, an announcement will be made on Github as well as our [Discord server](https://discord.gg/uuruzJ7). + +### What If I'm Willing to Contribute the Implementation? + +Being willing to contribute an implementation for a new database is one thing, but there's also the ongoing maintenance burden to consider. + +Are you willing to provide support long-term? +Will there be enough users that we can rely on outside contributions? +Or is support going to fall to the current maintainer(s)? + +This is the kind of thing that will need to be supported in SQLx _long_ after the initial implementation, or else later need to be removed. +If you don't have plans for how to support a new driver long-term, then it doesn't belong as part of SQLx itself. + +However, drivers don't necessarily need to live _in_ SQLx anymore. Since 0.7.0, drivers don't need to be compiled-in to be functional. +Support for third-party drivers in `sqlx-cli` and the `query!()` macros is pending, as well as documenting the process of writing a driver, but contributions are welcome in this regard. + +For example, see [sqlx-exasol](https://crates.io/crates/sqlx-exasol). + +---------------------------------------------------------------- +### Can SQLx Add Support for New Data-Type Crates (e.g. Jiff in addition to `chrono` and `time`)? + +This has a lot of the same considerations as adding support for new databases (see above), but with one big additional problem: Semantic Versioning. + +When we add trait implementations for types from an external crate, that crate then becomes part of our public API. We become beholden to its release cycle. + +If the crate's API is still evolving, meaning they are making breaking changes frequently, and thus releasing new major versions frequently, that then becomes a burden on us to upgrade and release a new major version as well so everyone _else_ can upgrade. + +We don't have the maintainer bandwidth to support multiple major versions simultaneously (we have no Long-Term Support policy), so this means that users who want to keep up-to-date are forced to make frequent manual upgrades as well. + +Thus, it is best that we stick to only supporting crates which have a stable API, and which are not making new major releases frequently. + +Conversely, adding support for SQLx _in_ these crates may not be desirable either, since SQLx is a large dependency and a higher-level crate. In this case, the SemVer problem gets pushed onto the other crate. + +There isn't a satisfying answer to this problem, but one option is to have an intermediate wrapper crate. +For example, [`jiff-sqlx`](https://crates.io/crates/jiff-sqlx), which is maintained by the author of Jiff. +API changes to SQLx are pending to make this pattern easier to use. + ---------------------------------------------------------------- ### I'm getting `HandshakeFailure` or `CorruptMessage` when trying to connect to a server over TLS using RusTLS. What gives? diff --git a/README.md b/README.md index 4d4a2338..c3b501ca 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,10 @@ be removed in the future. * May result in link errors if the SQLite version is too old. Version `3.20.0` or newer is recommended. * Can increase build time due to the use of bindgen. +- `sqlite-preupdate-hook`: enables SQLite's [preupdate hook](https://sqlite.org/c3ref/preupdate_count.html) API. + * Exposed as a separate feature because it's generally not enabled by default. + * Using this feature with `sqlite-unbundled` may cause linker failures if the system SQLite version does not support it. + - `any`: Add support for the `Any` database driver, which can proxy to a database driver at runtime. - `derive`: Add support for the derive family macros, those are `FromRow`, `Type`, `Encode`, `Decode`. @@ -204,7 +208,7 @@ be removed in the future. - `migrate`: Add support for the migration management and `migrate!` macro, which allow compile-time embedded migrations. -- `uuid`: Add support for UUID (in Postgres). +- `uuid`: Add support for UUID. - `chrono`: Add support for date and time types from `chrono`. diff --git a/ci.db b/ci.db new file mode 100644 index 00000000..cc158a72 Binary files /dev/null and b/ci.db differ diff --git a/sqlx-cli/Cargo.toml b/sqlx-cli/Cargo.toml index 4ece2263..ea281c0a 100644 --- a/sqlx-cli/Cargo.toml +++ b/sqlx-cli/Cargo.toml @@ -26,7 +26,7 @@ path = "src/bin/cargo-sqlx.rs" [dependencies] dotenvy = "0.15.0" -tokio = { version = "1.15.0", features = ["macros", "rt", "rt-multi-thread"] } +tokio = { version = "1.15.0", features = ["macros", "rt", "rt-multi-thread", "signal"] } sqlx = { workspace = true, default-features = false, features = [ "runtime-tokio", "migrate", @@ -37,9 +37,8 @@ clap = { version = "4.3.10", features = ["derive", "env"] } clap_complete = { version = "4.3.1", optional = true } chrono = { version = "0.4.19", default-features = false, features = ["clock"] } anyhow = "1.0.52" -async-trait = "0.1.52" console = "0.15.0" -promptly = "0.3.0" +dialoguer = { version = "0.11", default-features = false } serde_json = "1.0.73" glob = "0.3.0" openssl = { version = "0.10.38", optional = true } diff --git a/sqlx-cli/src/bin/cargo-sqlx.rs b/sqlx-cli/src/bin/cargo-sqlx.rs index 58f7b345..c87147b6 100644 --- a/sqlx-cli/src/bin/cargo-sqlx.rs +++ b/sqlx-cli/src/bin/cargo-sqlx.rs @@ -13,9 +13,12 @@ enum Cli { #[tokio::main] async fn main() { - dotenvy::dotenv().ok(); let Cli::Sqlx(opt) = Cli::parse(); + if !opt.no_dotenv { + dotenvy::dotenv().ok(); + } + if let Err(error) = sqlx_cli::run(opt).await { println!("{} {}", style("error:").bold().red(), error); process::exit(1); diff --git a/sqlx-cli/src/bin/sqlx.rs b/sqlx-cli/src/bin/sqlx.rs index 59025cd7..c19b61f3 100644 --- a/sqlx-cli/src/bin/sqlx.rs +++ b/sqlx-cli/src/bin/sqlx.rs @@ -4,9 +4,14 @@ use sqlx_cli::Opt; #[tokio::main] async fn main() { - dotenvy::dotenv().ok(); + let opt = Opt::parse(); + + if !opt.no_dotenv { + dotenvy::dotenv().ok(); + } + // no special handling here - if let Err(error) = sqlx_cli::run(Opt::parse()).await { + if let Err(error) = sqlx_cli::run(opt).await { println!("{} {}", style("error:").bold().red(), error); std::process::exit(1); } diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index a0af55d6..eaba46ee 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -1,9 +1,11 @@ use crate::opt::{ConnectOpts, MigrationSourceOpt}; use crate::{migrate, Config}; -use console::style; -use promptly::{prompt, ReadlineError}; +use console::{style, Term}; +use dialoguer::Confirm; use sqlx::any::Any; use sqlx::migrate::MigrateDatabase; +use std::{io, mem}; +use tokio::task; pub async fn create(connect_opts: &ConnectOpts) -> anyhow::Result<()> { // NOTE: only retry the idempotent action. @@ -24,7 +26,7 @@ pub async fn create(connect_opts: &ConnectOpts) -> anyhow::Result<()> { } pub async fn drop(connect_opts: &ConnectOpts, confirm: bool, force: bool) -> anyhow::Result<()> { - if confirm && !ask_to_continue_drop(connect_opts.expect_db_url()?) { + if confirm && !ask_to_continue_drop(connect_opts.expect_db_url()?.to_owned()).await { return Ok(()); } @@ -63,27 +65,46 @@ pub async fn setup( migrate::run(config, migration_source, connect_opts, false, false, None).await } -fn ask_to_continue_drop(db_url: &str) -> bool { - loop { - let r: Result = - prompt(format!("Drop database at {}? (y/n)", style(db_url).cyan())); - match r { - Ok(response) => { - if response == "n" || response == "N" { - return false; - } else if response == "y" || response == "Y" { - return true; - } else { - println!( - "Response not recognized: {}\nPlease type 'y' or 'n' and press enter.", - response - ); - } - } - Err(e) => { - println!("{e}"); - return false; +async fn ask_to_continue_drop(db_url: String) -> bool { + // If the setup operation is cancelled while we are waiting for the user to decide whether + // or not to drop the database, this will restore the terminal's cursor to its normal state. + struct RestoreCursorGuard { + disarmed: bool, + } + + impl Drop for RestoreCursorGuard { + fn drop(&mut self) { + if !self.disarmed { + Term::stderr().show_cursor().unwrap() } } } + + let mut guard = RestoreCursorGuard { disarmed: false }; + + let decision_result = task::spawn_blocking(move || { + Confirm::new() + .with_prompt(format!("Drop database at {}?", style(&db_url).cyan())) + .wait_for_newline(true) + .default(false) + .show_default(true) + .interact() + }) + .await + .expect("Confirm thread panicked"); + match decision_result { + Ok(decision) => { + guard.disarmed = true; + decision + } + Err(dialoguer::Error::IO(err)) if err.kind() == io::ErrorKind::Interrupted => { + // Sometimes CTRL + C causes this error to be returned + mem::drop(guard); + false + } + Err(err) => { + mem::drop(guard); + panic!("Confirm dialog failed with {err}") + } + } } diff --git a/sqlx-cli/src/lib.rs b/sqlx-cli/src/lib.rs index 43b301e4..67b1ef4a 100644 --- a/sqlx-cli/src/lib.rs +++ b/sqlx-cli/src/lib.rs @@ -6,6 +6,7 @@ use anyhow::{Context, Result}; use futures::{Future, TryFutureExt}; use sqlx::{AnyConnection, Connection}; +use tokio::{select, signal}; use crate::opt::{Command, ConnectOpts, DatabaseCommand, MigrateCommand}; @@ -24,6 +25,26 @@ pub use crate::opt::Opt; pub use sqlx::_unstable::config::{self, Config}; pub async fn run(opt: Opt) -> Result<()> { + // This `select!` is here so that when the process receives a `SIGINT` (CTRL + C), + // the futures currently running on this task get dropped before the program exits. + // This is currently necessary for the consumers of the `dialoguer` crate to restore + // the user's terminal if the process is interrupted while a dialog is being displayed. + + let ctrlc_fut = signal::ctrl_c(); + let do_run_fut = do_run(opt); + + select! { + biased; + _ = ctrlc_fut => { + Ok(()) + }, + do_run_outcome = do_run_fut => { + do_run_outcome + } + } +} + +async fn do_run(opt: Opt) -> Result<()> { let config = config_from_current_dir().await?; match opt.command { diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index 9716303c..3230148e 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -12,6 +12,10 @@ use std::ops::{Deref, Not}; #[derive(Parser, Debug)] #[clap(version, about, author)] pub struct Opt { + /// Do not automatically load `.env` files. + #[clap(long)] + pub no_dotenv: bool, + #[clap(subcommand)] pub command: Command, } diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index ee6e344e..97c2a8b7 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -25,7 +25,7 @@ _tls-native-tls = ["native-tls"] _tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"] _tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"] _tls-rustls-ring-native-roots = ["_tls-rustls", "rustls/ring", "rustls-native-certs"] -_tls-rustls = ["rustls", "rustls-pemfile"] +_tls-rustls = ["rustls"] _tls-none = [] # support offline/decoupled building (enables serialization of `Describe`) @@ -47,8 +47,7 @@ tokio = { workspace = true, optional = true } # TLS native-tls = { version = "0.2.10", optional = true } -rustls = { version = "0.23.11", default-features = false, features = ["std", "tls12"], optional = true } -rustls-pemfile = { version = "2", optional = true } +rustls = { version = "0.23.15", default-features = false, features = ["std", "tls12"], optional = true } webpki-roots = { version = "0.26", optional = true } rustls-native-certs = { version = "0.8.0", optional = true } @@ -62,6 +61,7 @@ mac_address = { workspace = true, optional = true } uuid = { workspace = true, optional = true } async-io = { version = "1.9.0", optional = true } +base64 = { version = "0.22.0", default-features = false, features = ["std"] } bytes = "1.1.0" chrono = { version = "0.4.34", default-features = false, features = ["clock"], optional = true } crc = { version = "3", optional = true } diff --git a/sqlx-core/src/from_row.rs b/sqlx-core/src/from_row.rs index 9c647d37..ecd5847f 100644 --- a/sqlx-core/src/from_row.rs +++ b/sqlx-core/src/from_row.rs @@ -111,7 +111,8 @@ use crate::{error::Error, row::Row}; /// different placeholder values, if applicable. /// /// This is similar to how `#[serde(default)]` behaves. -/// ### `flatten` +/// +/// #### `flatten` /// /// If you want to handle a field that implements [`FromRow`], /// you can use the `flatten` attribute to specify that you want @@ -177,33 +178,6 @@ use crate::{error::Error, row::Row}; /// assert!(user.addresses.is_empty()); /// ``` /// -/// ## Manual implementation -/// -/// You can also implement the [`FromRow`] trait by hand. This can be useful if you -/// have a struct with a field that needs manual decoding: -/// -/// -/// ```rust,ignore -/// use sqlx::{FromRow, sqlite::SqliteRow, sqlx::Row}; -/// struct MyCustomType { -/// custom: String, -/// } -/// -/// struct Foo { -/// bar: MyCustomType, -/// } -/// -/// impl FromRow<'_, SqliteRow> for Foo { -/// fn from_row(row: &SqliteRow) -> sqlx::Result { -/// Ok(Self { -/// bar: MyCustomType { -/// custom: row.try_get("custom")? -/// } -/// }) -/// } -/// } -/// ``` -/// /// #### `try_from` /// /// When your struct contains a field whose type is not matched with the database type, @@ -271,6 +245,59 @@ use crate::{error::Error, row::Row}; /// } /// } /// ``` +/// +/// By default the `#[sqlx(json)]` attribute will assume that the underlying database row is +/// _not_ NULL. This can cause issues when your field type is an `Option` because this would be +/// represented as the _not_ NULL (in terms of DB) JSON value of `null`. +/// +/// If you wish to describe a database row which _is_ NULLable but _cannot_ contain the JSON value `null`, +/// use the `#[sqlx(json(nullable))]` attrubute. +/// +/// For example +/// ```rust,ignore +/// #[derive(serde::Deserialize)] +/// struct Data { +/// field1: String, +/// field2: u64 +/// } +/// +/// #[derive(sqlx::FromRow)] +/// struct User { +/// id: i32, +/// name: String, +/// #[sqlx(json(nullable))] +/// metadata: Option +/// } +/// ``` +/// Would describe a database field which _is_ NULLable but if it exists it must be the JSON representation of `Data` +/// and cannot be the JSON value `null` +/// +/// ## Manual implementation +/// +/// You can also implement the [`FromRow`] trait by hand. This can be useful if you +/// have a struct with a field that needs manual decoding: +/// +/// +/// ```rust,ignore +/// use sqlx::{FromRow, sqlite::SqliteRow, sqlx::Row}; +/// struct MyCustomType { +/// custom: String, +/// } +/// +/// struct Foo { +/// bar: MyCustomType, +/// } +/// +/// impl FromRow<'_, SqliteRow> for Foo { +/// fn from_row(row: &SqliteRow) -> sqlx::Result { +/// Ok(Self { +/// bar: MyCustomType { +/// custom: row.try_get("custom")? +/// } +/// }) +/// } +/// } +/// ``` pub trait FromRow<'r, R: Row>: Sized { fn from_row(row: &'r R) -> Result; } @@ -286,7 +313,7 @@ where } // implement FromRow for tuples of types that implement Decode -// up to tuples of 9 values +// up to tuples of 16 values macro_rules! impl_from_row_for_tuple { ($( ($idx:tt) -> $T:ident );+;) => { diff --git a/sqlx-core/src/io/write_and_flush.rs b/sqlx-core/src/io/write_and_flush.rs index 9e7824af..8a0db312 100644 --- a/sqlx-core/src/io/write_and_flush.rs +++ b/sqlx-core/src/io/write_and_flush.rs @@ -1,10 +1,9 @@ use crate::error::Error; -use futures_core::Future; -use futures_util::ready; use sqlx_rt::AsyncWrite; +use std::future::Future; use std::io::{BufRead, Cursor}; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; // Atomic operation that writes the full buffer to the stream, flushes the stream, and then // clears the buffer (even if either of the two previous operations failed). diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index 6b09d318..d11f1588 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -2,10 +2,9 @@ use std::future::Future; use std::io; use std::path::Path; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use bytes::BufMut; -use futures_core::ready; pub use buffered::{BufferedSocket, WriteBuffer}; diff --git a/sqlx-core/src/net/tls/tls_rustls.rs b/sqlx-core/src/net/tls/tls_rustls.rs index d5685980..1a85cf0f 100644 --- a/sqlx-core/src/net/tls/tls_rustls.rs +++ b/sqlx-core/src/net/tls/tls_rustls.rs @@ -1,5 +1,5 @@ use futures_util::future; -use std::io::{self, BufReader, Cursor, Read, Write}; +use std::io::{self, Read, Write}; use std::sync::Arc; use std::task::{Context, Poll}; @@ -9,7 +9,10 @@ use rustls::{ WebPkiServerVerifier, }, crypto::{verify_tls12_signature, verify_tls13_signature, CryptoProvider}, - pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}, + pki_types::{ + pem::{self, PemObject}, + CertificateDer, PrivateKeyDer, ServerName, UnixTime, + }, CertificateError, ClientConfig, ClientConnection, Error as TlsError, RootCertStore, }; @@ -141,9 +144,8 @@ where if let Some(ca) = tls_config.root_cert_path { let data = ca.data().await?; - let mut cursor = Cursor::new(data); - for result in rustls_pemfile::certs(&mut cursor) { + for result in CertificateDer::pem_slice_iter(&data) { let Ok(cert) = result else { return Err(Error::Tls(format!("Invalid certificate {ca}").into())); }; @@ -196,19 +198,15 @@ where } fn certs_from_pem(pem: Vec) -> Result>, Error> { - let cur = Cursor::new(pem); - let mut reader = BufReader::new(cur); - rustls_pemfile::certs(&mut reader) + CertificateDer::pem_slice_iter(&pem) .map(|result| result.map_err(|err| Error::Tls(err.into()))) .collect() } fn private_key_from_pem(pem: Vec) -> Result, Error> { - let cur = Cursor::new(pem); - let mut reader = BufReader::new(cur); - match rustls_pemfile::private_key(&mut reader) { - Ok(Some(key)) => Ok(key), - Ok(None) => Err(Error::Configuration("no keys found pem file".into())), + match PrivateKeyDer::from_pem_slice(&pem) { + Ok(key) => Ok(key), + Err(pem::Error::NoItemsFound) => Err(Error::Configuration("no keys found pem file".into())), Err(e) => Err(Error::Configuration(e.to_string().into())), } } diff --git a/sqlx-core/src/net/tls/util.rs b/sqlx-core/src/net/tls/util.rs index 02a16ef5..ddbc7a58 100644 --- a/sqlx-core/src/net/tls/util.rs +++ b/sqlx-core/src/net/tls/util.rs @@ -1,9 +1,8 @@ use crate::net::Socket; use std::io::{self, Read, Write}; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; -use futures_core::ready; use futures_util::future; pub struct StdSocket { diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index bbcc4313..2066364a 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -10,6 +10,7 @@ use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser}; use std::cmp; use std::future::Future; +use std::pin::pin; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}; use std::sync::{Arc, RwLock}; use std::task::Poll; @@ -130,19 +131,12 @@ impl PoolInner { // This is just going to cause unnecessary churn in `acquire()`. .filter(|_| self.size() < self.options.max_connections); - let acquire_self = self.semaphore.acquire(1).fuse(); - let mut close_event = self.close_event(); + let mut acquire_self = pin!(self.semaphore.acquire(1).fuse()); + let mut close_event = pin!(self.close_event()); if let Some(parent) = parent { - let acquire_parent = parent.0.semaphore.acquire(1); - let parent_close_event = parent.0.close_event(); - - futures_util::pin_mut!( - acquire_parent, - acquire_self, - close_event, - parent_close_event - ); + let mut acquire_parent = pin!(parent.0.semaphore.acquire(1)); + let mut parent_close_event = pin!(parent.0.close_event()); let mut poll_parent = false; diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index e9986184..042bc5c7 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -56,7 +56,7 @@ use std::fmt; use std::future::Future; -use std::pin::Pin; +use std::pin::{pin, Pin}; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; @@ -565,11 +565,11 @@ impl CloseEvent { .await .map_or(Ok(()), |_| Err(Error::PoolClosed))?; - futures_util::pin_mut!(fut); + let mut fut = pin!(fut); // I find that this is clearer in intent than `futures_util::future::select()` // or `futures_util::select_biased!{}` (which isn't enabled anyway). - futures_util::future::poll_fn(|cx| { + std::future::poll_fn(|cx| { // Poll `fut` first as the wakeup event is more likely for it than `self`. if let Poll::Ready(ret) = fut.as_mut().poll(cx) { return Poll::Ready(Ok(ret)); diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index 96dbf8ee..3d048f17 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -484,7 +484,7 @@ impl PoolOptions { /// .await?; /// /// // Close the connection if the backend memory usage exceeds 256 MiB. - /// Ok(total_memory_usage <= (2 << 28)) + /// Ok(total_memory_usage <= (1 << 28)) /// })) /// .connect("postgres:// …").await?; /// # Ok(()) diff --git a/sqlx-core/src/query_builder.rs b/sqlx-core/src/query_builder.rs index 0d02048d..b242bf7b 100644 --- a/sqlx-core/src/query_builder.rs +++ b/sqlx-core/src/query_builder.rs @@ -323,6 +323,11 @@ where separated.push_unseparated(")"); } + debug_assert!( + separated.push_separator, + "No value being pushed. QueryBuilder may not build correct sql query!" + ); + separated.query_builder } diff --git a/sqlx-core/src/testing/mod.rs b/sqlx-core/src/testing/mod.rs index 9db65e9d..d683fdf8 100644 --- a/sqlx-core/src/testing/mod.rs +++ b/sqlx-core/src/testing/mod.rs @@ -3,7 +3,9 @@ use std::time::Duration; use futures_core::future::BoxFuture; +use base64::{engine::general_purpose::URL_SAFE, Engine as _}; pub use fixtures::FixtureSnapshot; +use sha2::{Digest, Sha512}; use crate::connection::{ConnectOptions, Connection}; use crate::database::Database; @@ -41,6 +43,17 @@ pub trait TestSupport: Database { /// This snapshot can then be used to generate test fixtures. fn snapshot(conn: &mut Self::Connection) -> BoxFuture<'_, Result, Error>>; + + /// Generate a unique database name for the given test path. + fn db_name(args: &TestArgs) -> String { + let mut hasher = Sha512::new(); + hasher.update(args.test_path.as_bytes()); + let hash = hasher.finalize(); + let hash = URL_SAFE.encode(&hash[..39]); + let db_name = format!("_sqlx_test_{}", hash).replace('-', "_"); + debug_assert!(db_name.len() == 63); + db_name + } } pub struct TestFixture { @@ -217,7 +230,7 @@ where let res = test_fn(test_context.pool_opts, test_context.connect_opts).await; if res.is_success() { - if let Err(e) = DB::cleanup_test(&test_context.db_name).await { + if let Err(e) = DB::cleanup_test(&DB::db_name(&args)).await { eprintln!( "failed to delete database {:?}: {}", test_context.db_name, e diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 25837b1e..909dd492 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -85,6 +85,9 @@ pub mod mac_address { pub use json::{Json, JsonRawValue, JsonValue}; pub use text::Text; +#[cfg(feature = "bstr")] +pub use bstr::{BStr, BString}; + /// Indicates that a SQL type is supported for a database. /// /// ## Compile-time verification diff --git a/sqlx-macros-core/src/derives/attributes.rs b/sqlx-macros-core/src/derives/attributes.rs index cf18cffc..c6968790 100644 --- a/sqlx-macros-core/src/derives/attributes.rs +++ b/sqlx-macros-core/src/derives/attributes.rs @@ -1,8 +1,8 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::quote_spanned; use syn::{ - punctuated::Punctuated, token::Comma, Attribute, DeriveInput, Field, LitStr, Meta, Token, Type, - Variant, + parenthesized, punctuated::Punctuated, token::Comma, Attribute, DeriveInput, Field, LitStr, + Meta, Token, Type, Variant, }; macro_rules! assert_attribute { @@ -61,13 +61,18 @@ pub struct SqlxContainerAttributes { pub default: bool, } +pub enum JsonAttribute { + NonNullable, + Nullable, +} + pub struct SqlxChildAttributes { pub rename: Option, pub default: bool, pub flatten: bool, pub try_from: Option, pub skip: bool, - pub json: bool, + pub json: Option, } pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result { @@ -144,7 +149,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result syn::Result - (false, None, false) => { + (false, None, None) => { predicates .push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(#ty: ::sqlx::types::Type)); @@ -107,12 +107,12 @@ fn expand_derive_from_row_struct( parse_quote!(__row.try_get(#id_s)) } // Flatten - (true, None, false) => { + (true, None, None) => { predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>)); parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(__row)) } // Flatten + Try from - (true, Some(try_from), false) => { + (true, Some(try_from), None) => { predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>)); parse_quote!( <#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(__row) @@ -130,11 +130,11 @@ fn expand_derive_from_row_struct( ) } // Flatten + Json - (true, _, true) => { + (true, _, Some(_)) => { panic!("Cannot use both flatten and json") } // Try from - (false, Some(try_from), false) => { + (false, Some(try_from), None) => { predicates .push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(#try_from: ::sqlx::types::Type)); @@ -154,8 +154,8 @@ fn expand_derive_from_row_struct( }) ) } - // Try from + Json - (false, Some(try_from), true) => { + // Try from + Json mandatory + (false, Some(try_from), Some(JsonAttribute::NonNullable)) => { predicates .push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::types::Type)); @@ -175,14 +175,25 @@ fn expand_derive_from_row_struct( }) ) }, + // Try from + Json nullable + (false, Some(_), Some(JsonAttribute::Nullable)) => { + panic!("Cannot use both try from and json nullable") + }, // Json - (false, None, true) => { + (false, None, Some(JsonAttribute::NonNullable)) => { predicates .push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::types::Type)); parse_quote!(__row.try_get::<::sqlx::types::Json<_>, _>(#id_s).map(|x| x.0)) }, + (false, None, Some(JsonAttribute::Nullable)) => { + predicates + .push(parse_quote!(::core::option::Option<::sqlx::types::Json<#ty>>: ::sqlx::decode::Decode<#lifetime, R::Database>)); + predicates.push(parse_quote!(::core::option::Option<::sqlx::types::Json<#ty>>: ::sqlx::types::Type)); + + parse_quote!(__row.try_get::<::core::option::Option<::sqlx::types::Json<_>>, _>(#id_s).map(|x| x.and_then(|y| y.0))) + }, }; if attributes.default { diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index 0466bfc0..e01e41d6 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -16,7 +16,7 @@ use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; -use std::future; +use std::{future, pin::pin}; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); @@ -113,8 +113,7 @@ impl AnyConnectionBackend for MySqlConnection { Box::pin(async move { let arguments = arguments?; - let stream = self.run(query, arguments, persistent).await?; - futures_util::pin_mut!(stream); + let mut stream = pin!(self.run(query, arguments, persistent).await?); while let Some(result) = stream.try_next().await? { if let Either::Right(row) = result { diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index d0d9cf18..bc8d0b62 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -21,9 +21,9 @@ use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; -use futures_util::{pin_mut, TryStreamExt}; +use futures_util::TryStreamExt; use sqlx_core::column::{ColumnOrigin, TableColumn}; -use std::{borrow::Cow, sync::Arc}; +use std::{borrow::Cow, pin::pin, sync::Arc}; impl MySqlConnection { async fn prepare_statement<'c>( @@ -112,7 +112,7 @@ impl MySqlConnection { self.inner.stream.wait_until_ready().await?; self.inner.stream.waiting.push_back(Waiting::Result); - Ok(Box::pin(try_stream! { + Ok(try_stream! { // make a slot for the shared column data // as long as a reference to a row is not held past one iteration, this enables us // to re-use this memory freely between result sets @@ -241,7 +241,7 @@ impl MySqlConnection { r#yield!(v); } } - })) + }) } } @@ -264,8 +264,7 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { Box::pin(try_stream! { let arguments = arguments?; - let s = self.run(sql, arguments, persistent).await?; - pin_mut!(s); + let mut s = pin!(self.run(sql, arguments, persistent).await?); while let Some(v) = s.try_next().await? { r#yield!(v); diff --git a/sqlx-mysql/src/testing/mod.rs b/sqlx-mysql/src/testing/mod.rs index 2a9216d1..1981cf73 100644 --- a/sqlx-mysql/src/testing/mod.rs +++ b/sqlx-mysql/src/testing/mod.rs @@ -1,29 +1,25 @@ -use std::fmt::Write; use std::ops::Deref; use std::str::FromStr; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::time::{Duration, SystemTime}; +use std::time::Duration; use futures_core::future::BoxFuture; use once_cell::sync::OnceCell; - -use crate::connection::Connection; +use sqlx_core::connection::Connection; +use sqlx_core::query_builder::QueryBuilder; +use sqlx_core::query_scalar::query_scalar; +use std::fmt::Write; use crate::error::Error; use crate::executor::Executor; use crate::pool::{Pool, PoolOptions}; use crate::query::query; -use crate::query_builder::QueryBuilder; -use crate::query_scalar::query_scalar; use crate::{MySql, MySqlConnectOptions, MySqlConnection}; pub(crate) use sqlx_core::testing::*; // Using a blocking `OnceCell` here because the critical sections are short. static MASTER_POOL: OnceCell> = OnceCell::new(); -// Automatically delete any databases created before the start of the test binary. -static DO_CLEANUP: AtomicBool = AtomicBool::new(true); impl TestSupport for MySql { fn test_context(args: &TestArgs) -> BoxFuture<'_, Result, Error>> { @@ -34,21 +30,11 @@ impl TestSupport for MySql { Box::pin(async move { let mut conn = MASTER_POOL .get() - .expect("cleanup_test() invoked outside `#[sqlx::test]") + .expect("cleanup_test() invoked outside `#[sqlx::test]`") .acquire() .await?; - let db_id = db_id(db_name); - - conn.execute(&format!("drop database if exists {db_name};")[..]) - .await?; - - query("delete from _sqlx_test_databases where db_id = ?") - .bind(db_id) - .execute(&mut *conn) - .await?; - - Ok(()) + do_cleanup(&mut conn, db_name).await }) } @@ -58,13 +44,55 @@ impl TestSupport for MySql { let mut conn = MySqlConnection::connect(&url).await?; - let now = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap(); + let delete_db_names: Vec = + query_scalar("select db_name from _sqlx_test_databases") + .fetch_all(&mut conn) + .await?; + + if delete_db_names.is_empty() { + return Ok(None); + } + + let mut deleted_db_names = Vec::with_capacity(delete_db_names.len()); + + let mut command = String::new(); + + for db_name in &delete_db_names { + command.clear(); + + let db_name = format!("_sqlx_test_database_{db_name}"); + + writeln!(command, "drop database if exists {db_name:?};").ok(); + match conn.execute(&*command).await { + Ok(_deleted) => { + deleted_db_names.push(db_name); + } + // Assume a database error just means the DB is still in use. + Err(Error::Database(dbe)) => { + eprintln!("could not clean test database {db_name:?}: {dbe}") + } + // Bubble up other errors + Err(e) => return Err(e), + } + } + + if deleted_db_names.is_empty() { + return Ok(None); + } + + let mut query = + QueryBuilder::new("delete from _sqlx_test_databases where db_name in ("); + + let mut separated = query.separated(","); + + for db_name in &deleted_db_names { + separated.push_bind(db_name); + } + + query.push(")").build().execute(&mut conn).await?; - let num_deleted = do_cleanup(&mut conn, now).await?; let _ = conn.close().await; - Ok(Some(num_deleted)) + Ok(Some(delete_db_names.len())) }) } @@ -117,7 +145,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { conn.execute( r#" create table if not exists _sqlx_test_databases ( - db_id bigint unsigned primary key auto_increment, + db_name text primary key, test_path text not null, created_at timestamp not null default current_timestamp ); @@ -125,34 +153,19 @@ async fn test_context(args: &TestArgs) -> Result, Error> { ) .await?; - // Record the current time _before_ we acquire the `DO_CLEANUP` permit. This - // prevents the first test thread from accidentally deleting new test dbs - // created by other test threads if we're a bit slow. - let now = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap(); + let db_name = MySql::db_name(args); + do_cleanup(&mut conn, &db_name).await?; - // Only run cleanup if the test binary just started. - if DO_CLEANUP.swap(false, Ordering::SeqCst) { - do_cleanup(&mut conn, now).await?; - } - - query("insert into _sqlx_test_databases(test_path) values (?)") + query("insert into _sqlx_test_databases(db_name, test_path) values (?, ?)") + .bind(&db_name) .bind(args.test_path) .execute(&mut *conn) .await?; - // MySQL doesn't have `INSERT ... RETURNING` - let new_db_id: u64 = query_scalar("select last_insert_id()") - .fetch_one(&mut *conn) + conn.execute(&format!("create database {db_name:?}")[..]) .await?; - let new_db_name = db_name(new_db_id); - - conn.execute(&format!("create database {new_db_name}")[..]) - .await?; - - eprintln!("created database {new_db_name}"); + eprintln!("created database {db_name}"); Ok(TestContext { pool_opts: PoolOptions::new() @@ -167,74 +180,18 @@ async fn test_context(args: &TestArgs) -> Result, Error> { .connect_options() .deref() .clone() - .database(&new_db_name), - db_name: new_db_name, + .database(&db_name), + db_name, }) } -async fn do_cleanup(conn: &mut MySqlConnection, created_before: Duration) -> Result { - // since SystemTime is not monotonic we added a little margin here to avoid race conditions with other threads - let created_before_as_secs = created_before.as_secs() - 2; - let delete_db_ids: Vec = query_scalar( - "select db_id from _sqlx_test_databases \ - where created_at < from_unixtime(?)", - ) - .bind(created_before_as_secs) - .fetch_all(&mut *conn) - .await?; +async fn do_cleanup(conn: &mut MySqlConnection, db_name: &str) -> Result<(), Error> { + let delete_db_command = format!("drop database if exists {db_name:?};"); + conn.execute(&*delete_db_command).await?; + query("delete from _sqlx_test.databases where db_name = $1::text") + .bind(db_name) + .execute(&mut *conn) + .await?; - if delete_db_ids.is_empty() { - return Ok(0); - } - - let mut deleted_db_ids = Vec::with_capacity(delete_db_ids.len()); - - let mut command = String::new(); - - for db_id in delete_db_ids { - command.clear(); - - let db_name = db_name(db_id); - - writeln!(command, "drop database if exists {db_name}").ok(); - match conn.execute(&*command).await { - Ok(_deleted) => { - deleted_db_ids.push(db_id); - } - // Assume a database error just means the DB is still in use. - Err(Error::Database(dbe)) => { - eprintln!("could not clean test database {db_id:?}: {dbe}") - } - // Bubble up other errors - Err(e) => return Err(e), - } - } - - let mut query = QueryBuilder::new("delete from _sqlx_test_databases where db_id in ("); - - let mut separated = query.separated(","); - - for db_id in &deleted_db_ids { - separated.push_bind(db_id); - } - - query.push(")").build().execute(&mut *conn).await?; - - Ok(deleted_db_ids.len()) -} - -fn db_name(id: u64) -> String { - format!("_sqlx_test_database_{id}") -} - -fn db_id(name: &str) -> u64 { - name.trim_start_matches("_sqlx_test_database_") - .parse() - .unwrap_or_else(|_1| panic!("failed to parse ID from database name {name:?}")) -} - -#[test] -fn test_db_name_id() { - assert_eq!(db_name(12345), "_sqlx_test_database_12345"); - assert_eq!(db_id("_sqlx_test_database_12345"), 12345); + Ok(()) } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index efa9a044..a7b30fb6 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -5,7 +5,7 @@ use crate::{ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; -use std::future; +use std::{future, pin::pin}; use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, @@ -115,8 +115,7 @@ impl AnyConnectionBackend for PgConnection { Box::pin(async move { let arguments = arguments?; - let stream = self.run(query, arguments, 1, persistent, None).await?; - futures_util::pin_mut!(stream); + let mut stream = pin!(self.run(query, arguments, 1, persistent, None).await?); if let Some(Either::Right(row)) = stream.try_next().await? { return Ok(Some(AnyRow::try_from(&row)?)); diff --git a/sqlx-postgres/src/arguments.rs b/sqlx-postgres/src/arguments.rs index bc7e861c..62a227e5 100644 --- a/sqlx-postgres/src/arguments.rs +++ b/sqlx-postgres/src/arguments.rs @@ -22,7 +22,7 @@ use sqlx_core::error::BoxDynError; // that has a patch, we then apply the patch which should write to &mut Vec, // backtrack and update the prefixed-len, then write until the next patch offset -#[derive(Default)] +#[derive(Default, Debug, Clone)] pub struct PgArgumentBuffer { buffer: Vec, @@ -46,20 +46,32 @@ pub struct PgArgumentBuffer { type_holes: Vec<(usize, HoleKind)>, // Vec<{ offset, type_name }> } +#[derive(Debug, Clone)] enum HoleKind { Type { name: UStr }, Array(Arc), } +#[derive(Clone)] struct Patch { buf_offset: usize, arg_index: usize, #[allow(clippy::type_complexity)] - callback: Box, + callback: Arc, +} + +impl fmt::Debug for Patch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Patch") + .field("buf_offset", &self.buf_offset) + .field("arg_index", &self.arg_index) + .field("callback", &"") + .finish() + } } /// Implementation of [`Arguments`] for PostgreSQL. -#[derive(Default)] +#[derive(Default, Debug, Clone)] pub struct PgArguments { // Types of each bind parameter pub(crate) types: Vec, @@ -194,7 +206,7 @@ impl PgArgumentBuffer { self.patches.push(Patch { buf_offset: offset, arg_index, - callback: Box::new(callback), + callback: Arc::new(callback), }); } diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index 9b5fd2a3..28e5e72e 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -15,10 +15,10 @@ use crate::{ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; -use futures_util::{pin_mut, TryStreamExt}; +use futures_util::TryStreamExt; use sqlx_core::arguments::Arguments; use sqlx_core::Either; -use std::{borrow::Cow, sync::Arc}; +use std::{borrow::Cow, pin::pin, sync::Arc}; async fn prepare( conn: &mut PgConnection, @@ -395,8 +395,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { Box::pin(try_stream! { let arguments = arguments?; - let s = self.run(sql, arguments, 0, persistent, metadata).await?; - pin_mut!(s); + let mut s = pin!(self.run(sql, arguments, 0, persistent, metadata).await?); while let Some(v) = s.try_next().await? { r#yield!(v); @@ -422,8 +421,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { Box::pin(async move { let arguments = arguments?; - let s = self.run(sql, arguments, 1, persistent, metadata).await?; - pin_mut!(s); + let mut s = pin!(self.run(sql, arguments, 1, persistent, metadata).await?); // With deferred constraints we need to check all responses as we // could get a OK response (with uncommitted data), only to get an diff --git a/sqlx-postgres/src/copy.rs b/sqlx-postgres/src/copy.rs index ddc187e9..1315ea0e 100644 --- a/sqlx-postgres/src/copy.rs +++ b/sqlx-postgres/src/copy.rs @@ -129,6 +129,9 @@ impl PgPoolCopyExt for Pool { } } +// (1 GiB - 1) - 1 - length prefix (4 bytes) +pub const PG_COPY_MAX_DATA_LEN: usize = 0x3fffffff - 1 - 4; + /// A connection in streaming `COPY FROM STDIN` mode. /// /// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw]. @@ -186,15 +189,20 @@ impl> PgCopyIn { /// Send a chunk of `COPY` data. /// + /// The data is sent in chunks if it exceeds the maximum length of a `CopyData` message (1 GiB - 6 + /// bytes) and may be partially sent if this call is cancelled. + /// /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead. pub async fn send(&mut self, data: impl Deref) -> Result<&mut Self> { - self.conn - .as_deref_mut() - .expect("send_data: conn taken") - .inner - .stream - .send(CopyData(data)) - .await?; + for chunk in data.deref().chunks(PG_COPY_MAX_DATA_LEN) { + self.conn + .as_deref_mut() + .expect("send_data: conn taken") + .inner + .stream + .send(CopyData(chunk)) + .await?; + } Ok(self) } @@ -230,10 +238,10 @@ impl> PgCopyIn { } // Write the length - let read32 = u32::try_from(read) - .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?; + let read32 = i32::try_from(read) + .map_err(|_| err_protocol!("number of bytes read exceeds 2^31 - 1: {}", read))?; - (&mut buf.get_mut()[1..]).put_u32(read32 + 4); + (&mut buf.get_mut()[1..]).put_i32(read32 + 4); conn.inner.stream.flush().await?; } diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index 792f8bbd..bded7549 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -34,6 +34,9 @@ mod value; #[doc(hidden)] pub mod any; +#[doc(hidden)] +pub use copy::PG_COPY_MAX_DATA_LEN; + #[cfg(feature = "migrate")] mod migrate; diff --git a/sqlx-postgres/src/testing/mod.rs b/sqlx-postgres/src/testing/mod.rs index fb36ab41..af20fe87 100644 --- a/sqlx-postgres/src/testing/mod.rs +++ b/sqlx-postgres/src/testing/mod.rs @@ -1,20 +1,18 @@ use std::fmt::Write; use std::ops::Deref; use std::str::FromStr; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::time::{Duration, SystemTime}; +use std::time::Duration; use futures_core::future::BoxFuture; use once_cell::sync::OnceCell; - -use crate::connection::Connection; +use sqlx_core::connection::Connection; +use sqlx_core::query_scalar::query_scalar; use crate::error::Error; use crate::executor::Executor; use crate::pool::{Pool, PoolOptions}; use crate::query::query; -use crate::query_scalar::query_scalar; use crate::{PgConnectOptions, PgConnection, Postgres}; pub(crate) use sqlx_core::testing::*; @@ -22,7 +20,6 @@ pub(crate) use sqlx_core::testing::*; // Using a blocking `OnceCell` here because the critical sections are short. static MASTER_POOL: OnceCell> = OnceCell::new(); // Automatically delete any databases created before the start of the test binary. -static DO_CLEANUP: AtomicBool = AtomicBool::new(true); impl TestSupport for Postgres { fn test_context(args: &TestArgs) -> BoxFuture<'_, Result, Error>> { @@ -33,19 +30,11 @@ impl TestSupport for Postgres { Box::pin(async move { let mut conn = MASTER_POOL .get() - .expect("cleanup_test() invoked outside `#[sqlx::test]") + .expect("cleanup_test() invoked outside `#[sqlx::test]`") .acquire() .await?; - conn.execute(&format!("drop database if exists {db_name:?};")[..]) - .await?; - - query("delete from _sqlx_test.databases where db_name = $1") - .bind(db_name) - .execute(&mut *conn) - .await?; - - Ok(()) + do_cleanup(&mut conn, db_name).await }) } @@ -55,13 +44,42 @@ impl TestSupport for Postgres { let mut conn = PgConnection::connect(&url).await?; - let now = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap(); + let delete_db_names: Vec = + query_scalar("select db_name from _sqlx_test.databases") + .fetch_all(&mut conn) + .await?; + + if delete_db_names.is_empty() { + return Ok(None); + } + + let mut deleted_db_names = Vec::with_capacity(delete_db_names.len()); + + let mut command = String::new(); + + for db_name in &delete_db_names { + command.clear(); + writeln!(command, "drop database if exists {db_name:?};").ok(); + match conn.execute(&*command).await { + Ok(_deleted) => { + deleted_db_names.push(db_name); + } + // Assume a database error just means the DB is still in use. + Err(Error::Database(dbe)) => { + eprintln!("could not clean test database {db_name:?}: {dbe}") + } + // Bubble up other errors + Err(e) => return Err(e), + } + } + + query("delete from _sqlx_test.databases where db_name = any($1::text[])") + .bind(&deleted_db_names) + .execute(&mut conn) + .await?; - let num_deleted = do_cleanup(&mut conn, now).await?; let _ = conn.close().await; - Ok(Some(num_deleted)) + Ok(Some(delete_db_names.len())) }) } @@ -116,8 +134,9 @@ async fn test_context(args: &TestArgs) -> Result, Error> { // I couldn't find a bug on the mailing list for `CREATE SCHEMA` specifically, // but a clearly related bug with `CREATE TABLE` has been known since 2007: // https://www.postgresql.org/message-id/200710222037.l9MKbCJZ098744%40wwwmaster.postgresql.org + // magic constant 8318549251334697844 is just 8 ascii bytes 'sqlxtest'. r#" - lock table pg_catalog.pg_namespace in share row exclusive mode; + select pg_advisory_xact_lock(8318549251334697844); create schema if not exists _sqlx_test; @@ -135,31 +154,22 @@ async fn test_context(args: &TestArgs) -> Result, Error> { ) .await?; - // Record the current time _before_ we acquire the `DO_CLEANUP` permit. This - // prevents the first test thread from accidentally deleting new test dbs - // created by other test threads if we're a bit slow. - let now = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap(); + let db_name = Postgres::db_name(args); + do_cleanup(&mut conn, &db_name).await?; - // Only run cleanup if the test binary just started. - if DO_CLEANUP.swap(false, Ordering::SeqCst) { - do_cleanup(&mut conn, now).await?; - } - - let new_db_name: String = query_scalar( + query( r#" - insert into _sqlx_test.databases(db_name, test_path) - select '_sqlx_test_' || nextval('_sqlx_test.database_ids'), $1 - returning db_name + insert into _sqlx_test.databases(db_name, test_path) values ($1, $2) "#, ) + .bind(&db_name) .bind(args.test_path) - .fetch_one(&mut *conn) + .execute(&mut *conn) .await?; - conn.execute(&format!("create database {new_db_name:?}")[..]) - .await?; + let create_command = format!("create database {db_name:?}"); + debug_assert!(create_command.starts_with("create database \"")); + conn.execute(&(create_command)[..]).await?; Ok(TestContext { pool_opts: PoolOptions::new() @@ -174,52 +184,18 @@ async fn test_context(args: &TestArgs) -> Result, Error> { .connect_options() .deref() .clone() - .database(&new_db_name), - db_name: new_db_name, + .database(&db_name), + db_name, }) } -async fn do_cleanup(conn: &mut PgConnection, created_before: Duration) -> Result { - // since SystemTime is not monotonic we added a little margin here to avoid race conditions with other threads - let created_before = i64::try_from(created_before.as_secs()).unwrap() - 2; - - let delete_db_names: Vec = query_scalar( - "select db_name from _sqlx_test.databases \ - where created_at < (to_timestamp($1) at time zone 'UTC')", - ) - .bind(created_before) - .fetch_all(&mut *conn) - .await?; - - if delete_db_names.is_empty() { - return Ok(0); - } - - let mut deleted_db_names = Vec::with_capacity(delete_db_names.len()); - let delete_db_names = delete_db_names.into_iter(); - - let mut command = String::new(); - - for db_name in delete_db_names { - command.clear(); - writeln!(command, "drop database if exists {db_name:?};").ok(); - match conn.execute(&*command).await { - Ok(_deleted) => { - deleted_db_names.push(db_name); - } - // Assume a database error just means the DB is still in use. - Err(Error::Database(dbe)) => { - eprintln!("could not clean test database {db_name:?}: {dbe}") - } - // Bubble up other errors - Err(e) => return Err(e), - } - } - - query("delete from _sqlx_test.databases where db_name = any($1::text[])") - .bind(&deleted_db_names) +async fn do_cleanup(conn: &mut PgConnection, db_name: &str) -> Result<(), Error> { + let delete_db_command = format!("drop database if exists {db_name:?};"); + conn.execute(&*delete_db_command).await?; + query("delete from _sqlx_test.databases where db_name = $1::text") + .bind(db_name) .execute(&mut *conn) .await?; - Ok(deleted_db_names.len()) + Ok(()) } diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index 41661a84..f89690b2 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -36,6 +36,10 @@ impl_type_checking!( sqlx::postgres::types::PgLine, + sqlx::postgres::types::PgLSeg, + + sqlx::postgres::types::PgBox, + #[cfg(feature = "uuid")] sqlx::types::Uuid, diff --git a/sqlx-postgres/src/type_info.rs b/sqlx-postgres/src/type_info.rs index 3d948f73..28c56758 100644 --- a/sqlx-postgres/src/type_info.rs +++ b/sqlx-postgres/src/type_info.rs @@ -185,7 +185,7 @@ pub enum PgTypeKind { Range(PgTypeInfo), } -#[derive(Debug)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct PgArrayOf { pub(crate) elem_name: UStr, diff --git a/sqlx-postgres/src/types/cube.rs b/sqlx-postgres/src/types/cube.rs index f39d8265..cc2a0160 100644 --- a/sqlx-postgres/src/types/cube.rs +++ b/sqlx-postgres/src/types/cube.rs @@ -20,7 +20,7 @@ const IS_POINT_FLAG: u32 = 1 << 31; #[derive(Debug, Clone, PartialEq)] pub enum PgCube { /// A one-dimensional point. - // FIXME: `Point1D(f64) + // FIXME: `Point1D(f64)` Point(f64), /// An N-dimensional point ("represented internally as a zero-volume cube"). // FIXME: `PointND(f64)` @@ -32,7 +32,7 @@ pub enum PgCube { // FIXME: add `Cube3D { lower_left: [f64; 3], upper_right: [f64; 3] }`? /// An N-dimensional cube with points representing lower-left and upper-right corners, respectively. - // FIXME: CubeND { lower_left: Vec, upper_right: Vec }` + // FIXME: `CubeND { lower_left: Vec, upper_right: Vec }` MultiDimension(Vec>), } diff --git a/sqlx-postgres/src/types/geometry/box.rs b/sqlx-postgres/src/types/geometry/box.rs new file mode 100644 index 00000000..988c028e --- /dev/null +++ b/sqlx-postgres/src/types/geometry/box.rs @@ -0,0 +1,321 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use std::str::FromStr; + +const ERROR: &str = "error decoding BOX"; + +/// ## Postgres Geometric Box type +/// +/// Description: Rectangular box +/// Representation: `((upper_right_x,upper_right_y),(lower_left_x,lower_left_y))` +/// +/// Boxes are represented by pairs of points that are opposite corners of the box. Values of type box are specified using any of the following syntaxes: +/// +/// ```text +/// ( ( upper_right_x , upper_right_y ) , ( lower_left_x , lower_left_y ) ) +/// ( upper_right_x , upper_right_y ) , ( lower_left_x , lower_left_y ) +/// upper_right_x , upper_right_y , lower_left_x , lower_left_y +/// ``` +/// where `(upper_right_x,upper_right_y) and (lower_left_x,lower_left_y)` are any two opposite corners of the box. +/// Any two opposite corners can be supplied on input, but the values will be reordered as needed to store the upper right and lower left corners, in that order. +/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-BOXES +#[derive(Debug, Clone, PartialEq)] +pub struct PgBox { + pub upper_right_x: f64, + pub upper_right_y: f64, + pub lower_left_x: f64, + pub lower_left_y: f64, +} + +impl Type for PgBox { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("box") + } +} + +impl PgHasArrayType for PgBox { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_box") + } +} + +impl<'r> Decode<'r, Postgres> for PgBox { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgBox::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgBox::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgBox { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("box")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgBox { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let mut parts = sanitised.split(','); + + let upper_right_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get upper_right_x from {}", ERROR, s))?; + + let upper_right_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get upper_right_y from {}", ERROR, s))?; + + let lower_left_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get lower_left_x from {}", ERROR, s))?; + + let lower_left_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get lower_left_y from {}", ERROR, s))?; + + if parts.next().is_some() { + return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into()); + } + + Ok(PgBox { + upper_right_x, + upper_right_y, + lower_left_x, + lower_left_y, + }) + } +} + +impl PgBox { + fn from_bytes(mut bytes: &[u8]) -> Result { + let upper_right_x = bytes.get_f64(); + let upper_right_y = bytes.get_f64(); + let lower_left_x = bytes.get_f64(); + let lower_left_y = bytes.get_f64(); + + Ok(PgBox { + upper_right_x, + upper_right_y, + lower_left_x, + lower_left_y, + }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + let min_x = &self.upper_right_x.min(self.lower_left_x); + let min_y = &self.upper_right_y.min(self.lower_left_y); + let max_x = &self.upper_right_x.max(self.lower_left_x); + let max_y = &self.upper_right_y.max(self.lower_left_y); + + buff.extend_from_slice(&max_x.to_be_bytes()); + buff.extend_from_slice(&max_y.to_be_bytes()); + buff.extend_from_slice(&min_x.to_be_bytes()); + buff.extend_from_slice(&min_y.to_be_bytes()); + + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +#[cfg(test)] +mod box_tests { + + use std::str::FromStr; + + use super::PgBox; + + const BOX_BYTES: &[u8] = &[ + 64, 0, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, + 0, 0, 0, 0, + ]; + + #[test] + fn can_deserialise_box_type_bytes_in_order() { + let pg_box = PgBox::from_bytes(BOX_BYTES).unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 2., + upper_right_y: 2., + lower_left_x: -2., + lower_left_y: -2. + } + ) + } + + #[test] + fn can_deserialise_box_type_str_first_syntax() { + let pg_box = PgBox::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + #[test] + fn can_deserialise_box_type_str_second_syntax() { + let pg_box = PgBox::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + + #[test] + fn can_deserialise_box_type_str_third_syntax() { + let pg_box = PgBox::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + + #[test] + fn can_deserialise_box_type_str_fourth_syntax() { + let pg_box = PgBox::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + + #[test] + fn cannot_deserialise_too_many_numbers() { + let input_str = "1, 2, 3, 4, 5"; + let pg_box = PgBox::from_str(input_str); + assert!(pg_box.is_err()); + if let Err(err) = pg_box { + assert_eq!( + err.to_string(), + format!("error decoding BOX: too many numbers inputted in {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_too_few_numbers() { + let input_str = "1, 2, 3 "; + let pg_box = PgBox::from_str(input_str); + assert!(pg_box.is_err()); + if let Err(err) = pg_box { + assert_eq!( + err.to_string(), + format!("error decoding BOX: could not get lower_left_y from {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_invalid_numbers() { + let input_str = "1, 2, 3, FOUR"; + let pg_box = PgBox::from_str(input_str); + assert!(pg_box.is_err()); + if let Err(err) = pg_box { + assert_eq!( + err.to_string(), + format!("error decoding BOX: could not get lower_left_y from {input_str}") + ) + } + } + + #[test] + fn can_deserialise_box_type_str_float() { + let pg_box = PgBox::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1.1, + upper_right_y: 2.2, + lower_left_x: 3.3, + lower_left_y: 4.4 + } + ); + } + + #[test] + fn can_serialise_box_type_in_order() { + let pg_box = PgBox { + upper_right_x: 2., + lower_left_x: -2., + upper_right_y: -2., + lower_left_y: 2., + }; + assert_eq!(pg_box.serialize_to_vec(), BOX_BYTES,) + } + + #[test] + fn can_serialise_box_type_out_of_order() { + let pg_box = PgBox { + upper_right_x: -2., + lower_left_x: 2., + upper_right_y: 2., + lower_left_y: -2., + }; + assert_eq!(pg_box.serialize_to_vec(), BOX_BYTES,) + } + + #[test] + fn can_order_box() { + let pg_box = PgBox { + upper_right_x: -2., + lower_left_x: 2., + upper_right_y: 2., + lower_left_y: -2., + }; + let bytes = pg_box.serialize_to_vec(); + + let pg_box = PgBox::from_bytes(&bytes).unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 2., + upper_right_y: 2., + lower_left_x: -2., + lower_left_y: -2. + } + ) + } +} diff --git a/sqlx-postgres/src/types/geometry/line_segment.rs b/sqlx-postgres/src/types/geometry/line_segment.rs new file mode 100644 index 00000000..5dc5efc7 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/line_segment.rs @@ -0,0 +1,283 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use std::str::FromStr; + +const ERROR: &str = "error decoding LSEG"; + +/// ## Postgres Geometric Line Segment type +/// +/// Description: Finite line segment +/// Representation: `((start_x,start_y),(end_x,end_y))` +/// +/// +/// Line segments are represented by pairs of points that are the endpoints of the segment. Values of type lseg are specified using any of the following syntaxes: +/// ```text +/// [ ( start_x , start_y ) , ( end_x , end_y ) ] +/// ( ( start_x , start_y ) , ( end_x , end_y ) ) +/// ( start_x , start_y ) , ( end_x , end_y ) +/// start_x , start_y , end_x , end_y +/// ``` +/// where `(start_x,start_y) and (end_x,end_y)` are the end points of the line segment. +/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LSEG +#[doc(alias = "line segment")] +#[derive(Debug, Clone, PartialEq)] +pub struct PgLSeg { + pub start_x: f64, + pub start_y: f64, + pub end_x: f64, + pub end_y: f64, +} + +impl Type for PgLSeg { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("lseg") + } +} + +impl PgHasArrayType for PgLSeg { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_lseg") + } +} + +impl<'r> Decode<'r, Postgres> for PgLSeg { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgLSeg::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgLSeg::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgLSeg { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("lseg")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgLSeg { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let mut parts = sanitised.split(','); + + let start_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get start_x from {}", ERROR, s))?; + + let start_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get start_y from {}", ERROR, s))?; + + let end_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get end_x from {}", ERROR, s))?; + + let end_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get end_y from {}", ERROR, s))?; + + if parts.next().is_some() { + return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into()); + } + + Ok(PgLSeg { + start_x, + start_y, + end_x, + end_y, + }) + } +} + +impl PgLSeg { + fn from_bytes(mut bytes: &[u8]) -> Result { + let start_x = bytes.get_f64(); + let start_y = bytes.get_f64(); + let end_x = bytes.get_f64(); + let end_y = bytes.get_f64(); + + Ok(PgLSeg { + start_x, + start_y, + end_x, + end_y, + }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { + buff.extend_from_slice(&self.start_x.to_be_bytes()); + buff.extend_from_slice(&self.start_y.to_be_bytes()); + buff.extend_from_slice(&self.end_x.to_be_bytes()); + buff.extend_from_slice(&self.end_y.to_be_bytes()); + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +#[cfg(test)] +mod lseg_tests { + + use std::str::FromStr; + + use super::PgLSeg; + + const LINE_SEGMENT_BYTES: &[u8] = &[ + 63, 241, 153, 153, 153, 153, 153, 154, 64, 1, 153, 153, 153, 153, 153, 154, 64, 10, 102, + 102, 102, 102, 102, 102, 64, 17, 153, 153, 153, 153, 153, 154, + ]; + + #[test] + fn can_deserialise_lseg_type_bytes() { + let lseg = PgLSeg::from_bytes(LINE_SEGMENT_BYTES).unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1.1, + start_y: 2.2, + end_x: 3.3, + end_y: 4.4 + } + ) + } + + #[test] + fn can_deserialise_lseg_type_str_first_syntax() { + let lseg = PgLSeg::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + #[test] + fn can_deserialise_lseg_type_str_second_syntax() { + let lseg = PgLSeg::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + + #[test] + fn can_deserialise_lseg_type_str_third_syntax() { + let lseg = PgLSeg::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + + #[test] + fn can_deserialise_lseg_type_str_fourth_syntax() { + let lseg = PgLSeg::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1., + start_y: 2., + end_x: 3., + end_y: 4. + } + ); + } + + #[test] + fn can_deserialise_too_many_numbers() { + let input_str = "1, 2, 3, 4, 5"; + let lseg = PgLSeg::from_str(input_str); + assert!(lseg.is_err()); + if let Err(err) = lseg { + assert_eq!( + err.to_string(), + format!("error decoding LSEG: too many numbers inputted in {input_str}") + ) + } + } + + #[test] + fn can_deserialise_too_few_numbers() { + let input_str = "1, 2, 3"; + let lseg = PgLSeg::from_str(input_str); + assert!(lseg.is_err()); + if let Err(err) = lseg { + assert_eq!( + err.to_string(), + format!("error decoding LSEG: could not get end_y from {input_str}") + ) + } + } + + #[test] + fn can_deserialise_invalid_numbers() { + let input_str = "1, 2, 3, FOUR"; + let lseg = PgLSeg::from_str(input_str); + assert!(lseg.is_err()); + if let Err(err) = lseg { + assert_eq!( + err.to_string(), + format!("error decoding LSEG: could not get end_y from {input_str}") + ) + } + } + + #[test] + fn can_deserialise_lseg_type_str_float() { + let lseg = PgLSeg::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + lseg, + PgLSeg { + start_x: 1.1, + start_y: 2.2, + end_x: 3.3, + end_y: 4.4 + } + ); + } + + #[test] + fn can_serialise_lseg_type() { + let lseg = PgLSeg { + start_x: 1.1, + start_y: 2.2, + end_x: 3.3, + end_y: 4.4, + }; + assert_eq!(lseg.serialize_to_vec(), LINE_SEGMENT_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/geometry/mod.rs b/sqlx-postgres/src/types/geometry/mod.rs index daf9f1de..7fe2898f 100644 --- a/sqlx-postgres/src/types/geometry/mod.rs +++ b/sqlx-postgres/src/types/geometry/mod.rs @@ -1,2 +1,4 @@ +pub mod r#box; pub mod line; +pub mod line_segment; pub mod point; diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index 74734551..a5fd7083 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -21,8 +21,10 @@ //! | [`PgLQuery`] | LQUERY | //! | [`PgCiText`] | CITEXT1 | //! | [`PgCube`] | CUBE | -//! | [`PgPoint] | POINT | -//! | [`PgLine] | LINE | +//! | [`PgPoint`] | POINT | +//! | [`PgLine`] | LINE | +//! | [`PgLSeg`] | LSEG | +//! | [`PgBox`] | BOX | //! | [`PgHstore`] | HSTORE | //! //! 1 SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc., @@ -259,7 +261,9 @@ pub use array::PgHasArrayType; pub use citext::PgCiText; pub use cube::PgCube; pub use geometry::line::PgLine; +pub use geometry::line_segment::PgLSeg; pub use geometry::point::PgPoint; +pub use geometry::r#box::PgBox; pub use hstore::PgHstore; pub use interval::PgInterval; pub use lquery::PgLQuery; diff --git a/sqlx-postgres/src/types/record.rs b/sqlx-postgres/src/types/record.rs index c4eb6393..6e37182c 100644 --- a/sqlx-postgres/src/types/record.rs +++ b/sqlx-postgres/src/types/record.rs @@ -41,13 +41,13 @@ impl<'a> PgRecordEncoder<'a> { { let ty = value.produces().unwrap_or_else(T::type_info); - if let PgType::DeclareWithName(name) = ty.0 { + match ty.0 { // push a hole for this type ID // to be filled in on query execution - self.buf.patch_type_by_name(&name); - } else { + PgType::DeclareWithName(name) => self.buf.patch_type_by_name(&name), + PgType::DeclareArrayOf(array) => self.buf.patch_array_type(array), // write type id - self.buf.extend(&ty.0.oid().0.to_be_bytes()); + pg_type => self.buf.extend(&pg_type.oid().0.to_be_bytes()), } self.buf.encode(value)?; diff --git a/sqlx-sqlite/Cargo.toml b/sqlx-sqlite/Cargo.toml index 391bf452..5ad57546 100644 --- a/sqlx-sqlite/Cargo.toml +++ b/sqlx-sqlite/Cargo.toml @@ -23,6 +23,8 @@ uuid = ["dep:uuid", "sqlx-core/uuid"] regexp = ["dep:regex"] +preupdate-hook = ["libsqlite3-sys/preupdate_hook"] + bundled = ["libsqlite3-sys/bundled"] unbundled = ["libsqlite3-sys/buildtime_bindgen"] @@ -48,6 +50,7 @@ atoi = "2.0" log = "0.4.18" tracing = { version = "0.1.37", features = ["log"] } +thiserror = "2.0.0" serde = { version = "1.0.145", features = ["derive"], optional = true } regex = { version = "1.5.5", optional = true } diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 01600d99..2cc58554 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -17,6 +17,7 @@ use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; +use std::pin::pin; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Sqlite); @@ -105,12 +106,12 @@ impl AnyConnectionBackend for SqliteConnection { let args = arguments.map(map_arguments); Box::pin(async move { - let stream = self - .worker - .execute(query, args, self.row_channel_size, persistent, Some(1)) - .map_ok(flume::Receiver::into_stream) - .await?; - futures_util::pin_mut!(stream); + let mut stream = pin!( + self.worker + .execute(query, args, self.row_channel_size, persistent, Some(1)) + .map_ok(flume::Receiver::into_stream) + .await? + ); if let Some(Either::Right(row)) = stream.try_next().await? { return Ok(Some(AnyRow::try_from(&row)?)); diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 40f9b4c3..5b8aa01b 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -296,6 +296,8 @@ impl EstablishParams { log_settings: self.log_settings.clone(), progress_handler_callback: None, update_hook_callback: None, + #[cfg(feature = "preupdate-hook")] + preupdate_hook_callback: None, commit_hook_callback: None, rollback_hook_callback: None, }) diff --git a/sqlx-sqlite/src/connection/executor.rs b/sqlx-sqlite/src/connection/executor.rs index 541a4f7d..1f6ce772 100644 --- a/sqlx-sqlite/src/connection/executor.rs +++ b/sqlx-sqlite/src/connection/executor.rs @@ -8,7 +8,7 @@ use sqlx_core::describe::Describe; use sqlx_core::error::Error; use sqlx_core::executor::{Execute, Executor}; use sqlx_core::Either; -use std::future; +use std::{future, pin::pin}; impl<'c> Executor<'c> for &'c mut SqliteConnection { type Database = Sqlite; @@ -56,13 +56,11 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { let persistent = query.persistent() && arguments.is_some(); Box::pin(async move { - let stream = self + let mut stream = pin!(self .worker .execute(sql, arguments, self.row_channel_size, persistent, Some(1)) .map_ok(flume::Receiver::into_stream) - .try_flatten_stream(); - - futures_util::pin_mut!(stream); + .try_flatten_stream()); while let Some(res) = stream.try_next().await? { if let Either::Right(row) = res { diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index a579b8a6..7412eef1 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -14,6 +14,8 @@ use libsqlite3_sys::{ sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, }; +#[cfg(feature = "preupdate-hook")] +pub use preupdate_hook::*; pub(crate) use handle::ConnectionHandle; use sqlx_core::common::StatementCache; @@ -26,7 +28,7 @@ use crate::connection::establish::EstablishParams; use crate::connection::worker::ConnectionWorker; use crate::options::OptimizeOnClose; use crate::statement::VirtualStatement; -use crate::{Sqlite, SqliteConnectOptions}; +use crate::{Sqlite, SqliteConnectOptions, SqliteError}; pub(crate) mod collation; pub(crate) mod describe; @@ -36,6 +38,8 @@ mod executor; mod explain; mod handle; pub(crate) mod intmap; +#[cfg(feature = "preupdate-hook")] +mod preupdate_hook; mod worker; @@ -88,6 +92,7 @@ pub struct UpdateHookResult<'a> { pub table: &'a str, pub rowid: i64, } + pub(crate) struct UpdateHookHandler(NonNull); unsafe impl Send for UpdateHookHandler {} @@ -112,6 +117,8 @@ pub(crate) struct ConnectionState { progress_handler_callback: Option, update_hook_callback: Option, + #[cfg(feature = "preupdate-hook")] + preupdate_hook_callback: Option, commit_hook_callback: Option, @@ -138,6 +145,16 @@ impl ConnectionState { } } + #[cfg(feature = "preupdate-hook")] + pub(crate) fn remove_preupdate_hook(&mut self) { + if let Some(mut handler) = self.preupdate_hook_callback.take() { + unsafe { + libsqlite3_sys::sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut()); + let _ = { Box::from_raw(handler.0.as_mut()) }; + } + } + } + pub(crate) fn remove_commit_hook(&mut self) { if let Some(mut handler) = self.commit_hook_callback.take() { unsafe { @@ -421,6 +438,34 @@ impl LockedSqliteHandle<'_> { } } + /// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table. + /// At most one preupdate hook may be registered at a time on a single database connection. + /// + /// The preupdate hook only fires for changes to real database tables; + /// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1. + /// + /// See https://sqlite.org/c3ref/preupdate_count.html + #[cfg(feature = "preupdate-hook")] + pub fn set_preupdate_hook(&mut self, callback: F) + where + F: FnMut(PreupdateHookResult) + Send + 'static, + { + unsafe { + let callback_boxed = Box::new(callback); + // SAFETY: `Box::into_raw()` always returns a non-null pointer. + let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed)); + let handler = callback.as_ptr() as *mut _; + self.guard.remove_preupdate_hook(); + self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback)); + + libsqlite3_sys::sqlite3_preupdate_hook( + self.as_raw_handle().as_mut(), + Some(preupdate_hook::), + handler, + ); + } + } + /// Sets a commit hook that is invoked whenever a transaction is committed. If the commit hook callback /// returns `false`, then the operation is turned into a ROLLBACK. /// @@ -485,6 +530,11 @@ impl LockedSqliteHandle<'_> { self.guard.remove_update_hook(); } + #[cfg(feature = "preupdate-hook")] + pub fn remove_preupdate_hook(&mut self) { + self.guard.remove_preupdate_hook(); + } + pub fn remove_commit_hook(&mut self) { self.guard.remove_commit_hook(); } @@ -492,6 +542,10 @@ impl LockedSqliteHandle<'_> { pub fn remove_rollback_hook(&mut self) { self.guard.remove_rollback_hook(); } + + pub fn last_error(&mut self) -> Option { + SqliteError::try_new(self.guard.handle.as_ptr()) + } } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/connection/preupdate_hook.rs b/sqlx-sqlite/src/connection/preupdate_hook.rs new file mode 100644 index 00000000..edcb0781 --- /dev/null +++ b/sqlx-sqlite/src/connection/preupdate_hook.rs @@ -0,0 +1,160 @@ +use super::SqliteOperation; +use crate::type_info::DataType; +use crate::{SqliteError, SqliteTypeInfo, SqliteValueRef}; + +use libsqlite3_sys::{ + sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_new, + sqlite3_preupdate_old, sqlite3_value, sqlite3_value_type, SQLITE_OK, +}; +use std::ffi::CStr; +use std::marker::PhantomData; +use std::os::raw::{c_char, c_int, c_void}; +use std::panic::catch_unwind; +use std::ptr; +use std::ptr::NonNull; + +#[derive(Debug, thiserror::Error)] +pub enum PreupdateError { + /// Error returned from the database. + #[error("error returned from database: {0}")] + Database(#[source] SqliteError), + /// Index is not within the valid column range + #[error("{0} is not within the valid column range")] + ColumnIndexOutOfBounds(i32), + /// Column value accessor was invoked from an invalid operation + #[error("column value accessor was invoked from an invalid operation")] + InvalidOperation, +} + +pub(crate) struct PreupdateHookHandler( + pub(super) NonNull, +); +unsafe impl Send for PreupdateHookHandler {} + +#[derive(Debug)] +pub struct PreupdateHookResult<'a> { + pub operation: SqliteOperation, + pub database: &'a str, + pub table: &'a str, + db: *mut sqlite3, + // The database pointer should not be usable after the preupdate hook. + // The lifetime on this struct needs to ensure it cannot outlive the callback. + _db_lifetime: PhantomData<&'a ()>, + old_row_id: i64, + new_row_id: i64, +} + +impl<'a> PreupdateHookResult<'a> { + /// Gets the amount of columns in the row being inserted, deleted, or updated. + pub fn get_column_count(&self) -> i32 { + unsafe { sqlite3_preupdate_count(self.db) } + } + + /// Gets the depth of the query that triggered the preupdate hook. + /// Returns 0 if the preupdate callback was invoked as a result of + /// a direct insert, update, or delete operation; + /// 1 for inserts, updates, or deletes invoked by top-level triggers; + /// 2 for changes resulting from triggers called by top-level triggers; and so forth. + pub fn get_query_depth(&self) -> i32 { + unsafe { sqlite3_preupdate_depth(self.db) } + } + + /// Gets the row id of the row being updated/deleted. + /// Returns an error if called from an insert operation. + pub fn get_old_row_id(&self) -> Result { + if self.operation == SqliteOperation::Insert { + return Err(PreupdateError::InvalidOperation); + } + Ok(self.old_row_id) + } + + /// Gets the row id of the row being inserted/updated. + /// Returns an error if called from a delete operation. + pub fn get_new_row_id(&self) -> Result { + if self.operation == SqliteOperation::Delete { + return Err(PreupdateError::InvalidOperation); + } + Ok(self.new_row_id) + } + + /// Gets the value of the row being updated/deleted at the specified index. + /// Returns an error if called from an insert operation or the index is out of bounds. + pub fn get_old_column_value(&self, i: i32) -> Result, PreupdateError> { + if self.operation == SqliteOperation::Insert { + return Err(PreupdateError::InvalidOperation); + } + self.validate_column_index(i)?; + + let mut p_value: *mut sqlite3_value = ptr::null_mut(); + unsafe { + let ret = sqlite3_preupdate_old(self.db, i, &mut p_value); + self.get_value(ret, p_value) + } + } + + /// Gets the value of the row being inserted/updated at the specified index. + /// Returns an error if called from a delete operation or the index is out of bounds. + pub fn get_new_column_value(&self, i: i32) -> Result, PreupdateError> { + if self.operation == SqliteOperation::Delete { + return Err(PreupdateError::InvalidOperation); + } + self.validate_column_index(i)?; + + let mut p_value: *mut sqlite3_value = ptr::null_mut(); + unsafe { + let ret = sqlite3_preupdate_new(self.db, i, &mut p_value); + self.get_value(ret, p_value) + } + } + + fn validate_column_index(&self, i: i32) -> Result<(), PreupdateError> { + if i < 0 || i >= self.get_column_count() { + return Err(PreupdateError::ColumnIndexOutOfBounds(i)); + } + Ok(()) + } + + unsafe fn get_value( + &self, + ret: i32, + p_value: *mut sqlite3_value, + ) -> Result, PreupdateError> { + if ret != SQLITE_OK { + return Err(PreupdateError::Database(SqliteError::new(self.db))); + } + let data_type = DataType::from_code(sqlite3_value_type(p_value)); + // SAFETY: SQLite will free the sqlite3_value when the callback returns + Ok(SqliteValueRef::borrowed(p_value, SqliteTypeInfo(data_type))) + } +} + +pub(super) extern "C" fn preupdate_hook( + callback: *mut c_void, + db: *mut sqlite3, + op_code: c_int, + database: *const c_char, + table: *const c_char, + old_row_id: i64, + new_row_id: i64, +) where + F: FnMut(PreupdateHookResult) + Send + 'static, +{ + unsafe { + let _ = catch_unwind(|| { + let callback: *mut F = callback.cast::(); + let operation: SqliteOperation = op_code.into(); + let database = CStr::from_ptr(database).to_str().unwrap_or_default(); + let table = CStr::from_ptr(table).to_str().unwrap_or_default(); + + (*callback)(PreupdateHookResult { + operation, + database, + table, + old_row_id, + new_row_id, + db, + _db_lifetime: PhantomData, + }) + }); + } +} diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index a01de241..c1c67636 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -151,7 +151,8 @@ impl ConnectionWorker { match limit { None => { for res in iter { - if tx.send(res).is_err() { + let has_error = res.is_err(); + if tx.send(res).is_err() || has_error { break; } } @@ -171,7 +172,8 @@ impl ConnectionWorker { } } } - if tx.send(res).is_err() { + let has_error = res.is_err(); + if tx.send(res).is_err() || has_error { break; } } diff --git a/sqlx-sqlite/src/error.rs b/sqlx-sqlite/src/error.rs index c00374fe..0d34bc10 100644 --- a/sqlx-sqlite/src/error.rs +++ b/sqlx-sqlite/src/error.rs @@ -23,9 +23,17 @@ pub struct SqliteError { impl SqliteError { pub(crate) fn new(handle: *mut sqlite3) -> Self { + Self::try_new(handle).expect("There should be an error") + } + + pub(crate) fn try_new(handle: *mut sqlite3) -> Option { // returns the extended result code even when extended result codes are disabled let code: c_int = unsafe { sqlite3_extended_errcode(handle) }; + if code == 0 { + return None; + } + // return English-language text that describes the error let message = unsafe { let msg = sqlite3_errmsg(handle); @@ -34,10 +42,10 @@ impl SqliteError { from_utf8_unchecked(CStr::from_ptr(msg).to_bytes()) }; - Self { + Some(Self { code, message: message.to_owned(), - } + }) } /// For errors during extension load, the error message is supplied via a separate pointer diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index 3bcb6d14..f1a45c3d 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -46,6 +46,8 @@ use std::sync::atomic::AtomicBool; pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; +#[cfg(feature = "preupdate-hook")] +pub use connection::PreupdateHookResult; pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; pub use database::Sqlite; pub use error::SqliteError; diff --git a/sqlx-sqlite/src/testing/mod.rs b/sqlx-sqlite/src/testing/mod.rs index 3398c6b4..324b5191 100644 --- a/sqlx-sqlite/src/testing/mod.rs +++ b/sqlx-sqlite/src/testing/mod.rs @@ -30,6 +30,10 @@ impl TestSupport for Sqlite { ) -> BoxFuture<'_, Result, Error>> { todo!() } + + fn db_name(args: &TestArgs) -> String { + convert_path(args.test_path) + } } async fn test_context(args: &TestArgs) -> Result, Error> { diff --git a/sqlx-sqlite/src/value.rs b/sqlx-sqlite/src/value.rs index 967b3f74..469c4e70 100644 --- a/sqlx-sqlite/src/value.rs +++ b/sqlx-sqlite/src/value.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::marker::PhantomData; use std::ptr::NonNull; use std::slice::from_raw_parts; use std::str::from_utf8; @@ -17,6 +18,7 @@ use crate::{Sqlite, SqliteTypeInfo}; enum SqliteValueData<'r> { Value(&'r SqliteValue), + BorrowedHandle(ValueHandle<'r>), } pub struct SqliteValueRef<'r>(SqliteValueData<'r>); @@ -26,31 +28,44 @@ impl<'r> SqliteValueRef<'r> { Self(SqliteValueData::Value(value)) } + // SAFETY: The supplied sqlite3_value must not be null and SQLite must free it. It will not be freed on drop. + // The lifetime on this struct should tie it to whatever scope it's valid for before SQLite frees it. + #[allow(unused)] + pub(crate) unsafe fn borrowed(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self { + debug_assert!(!value.is_null()); + let handle = ValueHandle::new_borrowed(NonNull::new_unchecked(value), type_info); + Self(SqliteValueData::BorrowedHandle(handle)) + } + // NOTE: `int()` is deliberately omitted because it will silently truncate a wider value, // which is likely to cause bugs: // https://github.com/launchbadge/sqlx/issues/3179 // (Similar bug in Postgres): https://github.com/launchbadge/sqlx/issues/3161 pub(super) fn int64(&self) -> i64 { - match self.0 { - SqliteValueData::Value(v) => v.int64(), + match &self.0 { + SqliteValueData::Value(v) => v.0.int64(), + SqliteValueData::BorrowedHandle(v) => v.int64(), } } pub(super) fn double(&self) -> f64 { - match self.0 { - SqliteValueData::Value(v) => v.double(), + match &self.0 { + SqliteValueData::Value(v) => v.0.double(), + SqliteValueData::BorrowedHandle(v) => v.double(), } } pub(super) fn blob(&self) -> &'r [u8] { - match self.0 { - SqliteValueData::Value(v) => v.blob(), + match &self.0 { + SqliteValueData::Value(v) => v.0.blob(), + SqliteValueData::BorrowedHandle(v) => v.blob(), } } pub(super) fn text(&self) -> Result<&'r str, BoxDynError> { - match self.0 { - SqliteValueData::Value(v) => v.text(), + match &self.0 { + SqliteValueData::Value(v) => v.0.text(), + SqliteValueData::BorrowedHandle(v) => v.text(), } } } @@ -59,50 +74,66 @@ impl<'r> ValueRef<'r> for SqliteValueRef<'r> { type Database = Sqlite; fn to_owned(&self) -> SqliteValue { - match self.0 { - SqliteValueData::Value(v) => v.clone(), + match &self.0 { + SqliteValueData::Value(v) => (*v).clone(), + SqliteValueData::BorrowedHandle(v) => unsafe { + SqliteValue::new(v.value.as_ptr(), v.type_info.clone()) + }, } } fn type_info(&self) -> Cow<'_, SqliteTypeInfo> { - match self.0 { + match &self.0 { SqliteValueData::Value(v) => v.type_info(), + SqliteValueData::BorrowedHandle(v) => v.type_info(), } } fn is_null(&self) -> bool { - match self.0 { + match &self.0 { SqliteValueData::Value(v) => v.is_null(), + SqliteValueData::BorrowedHandle(v) => v.is_null(), } } } #[derive(Clone)] -pub struct SqliteValue { - pub(crate) handle: Arc, - pub(crate) type_info: SqliteTypeInfo, +pub struct SqliteValue(Arc>); + +pub(crate) struct ValueHandle<'a> { + value: NonNull, + type_info: SqliteTypeInfo, + free_on_drop: bool, + _sqlite_value_lifetime: PhantomData<&'a ()>, } -pub(crate) struct ValueHandle(NonNull); - // SAFE: only protected value objects are stored in SqliteValue -unsafe impl Send for ValueHandle {} -unsafe impl Sync for ValueHandle {} - -impl SqliteValue { - pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self { - debug_assert!(!value.is_null()); +unsafe impl<'a> Send for ValueHandle<'a> {} +unsafe impl<'a> Sync for ValueHandle<'a> {} +impl ValueHandle<'static> { + fn new_owned(value: NonNull, type_info: SqliteTypeInfo) -> Self { Self { + value, type_info, - handle: Arc::new(ValueHandle(NonNull::new_unchecked(sqlite3_value_dup( - value, - )))), + free_on_drop: true, + _sqlite_value_lifetime: PhantomData, + } + } +} + +impl<'a> ValueHandle<'a> { + fn new_borrowed(value: NonNull, type_info: SqliteTypeInfo) -> Self { + Self { + value, + type_info, + free_on_drop: false, + _sqlite_value_lifetime: PhantomData, } } fn type_info_opt(&self) -> Option { - let dt = DataType::from_code(unsafe { sqlite3_value_type(self.handle.0.as_ptr()) }); + let dt = DataType::from_code(unsafe { sqlite3_value_type(self.value.as_ptr()) }); if let DataType::Null = dt { None @@ -112,15 +143,15 @@ impl SqliteValue { } fn int64(&self) -> i64 { - unsafe { sqlite3_value_int64(self.handle.0.as_ptr()) } + unsafe { sqlite3_value_int64(self.value.as_ptr()) } } fn double(&self) -> f64 { - unsafe { sqlite3_value_double(self.handle.0.as_ptr()) } + unsafe { sqlite3_value_double(self.value.as_ptr()) } } - fn blob(&self) -> &[u8] { - let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) }; + fn blob<'b>(&self) -> &'b [u8] { + let len = unsafe { sqlite3_value_bytes(self.value.as_ptr()) }; // This likely means UB in SQLite itself or our usage of it; // signed integer overflow is UB in the C standard. @@ -133,15 +164,45 @@ impl SqliteValue { return &[]; } - let ptr = unsafe { sqlite3_value_blob(self.handle.0.as_ptr()) } as *const u8; + let ptr = unsafe { sqlite3_value_blob(self.value.as_ptr()) } as *const u8; debug_assert!(!ptr.is_null()); unsafe { from_raw_parts(ptr, len) } } - fn text(&self) -> Result<&str, BoxDynError> { + fn text<'b>(&self) -> Result<&'b str, BoxDynError> { Ok(from_utf8(self.blob())?) } + + fn type_info(&self) -> Cow<'_, SqliteTypeInfo> { + self.type_info_opt() + .map(Cow::Owned) + .unwrap_or(Cow::Borrowed(&self.type_info)) + } + + fn is_null(&self) -> bool { + unsafe { sqlite3_value_type(self.value.as_ptr()) == SQLITE_NULL } + } +} + +impl<'a> Drop for ValueHandle<'a> { + fn drop(&mut self) { + if self.free_on_drop { + unsafe { + sqlite3_value_free(self.value.as_ptr()); + } + } + } +} + +impl SqliteValue { + // SAFETY: The sqlite3_value must be non-null and SQLite must not free it. It will be freed on drop. + pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self { + debug_assert!(!value.is_null()); + let handle = + ValueHandle::new_owned(NonNull::new_unchecked(sqlite3_value_dup(value)), type_info); + Self(Arc::new(handle)) + } } impl Value for SqliteValue { @@ -152,21 +213,11 @@ impl Value for SqliteValue { } fn type_info(&self) -> Cow<'_, SqliteTypeInfo> { - self.type_info_opt() - .map(Cow::Owned) - .unwrap_or(Cow::Borrowed(&self.type_info)) + self.0.type_info() } fn is_null(&self) -> bool { - unsafe { sqlite3_value_type(self.handle.0.as_ptr()) == SQLITE_NULL } - } -} - -impl Drop for ValueHandle { - fn drop(&mut self) { - unsafe { - sqlite3_value_free(self.0.as_ptr()); - } + self.0.is_null() } } diff --git a/src/lib.rs b/src/lib.rs index aaa0e819..191c3564 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,14 @@ #![cfg_attr(docsrs, feature(doc_cfg))] #![doc = include_str!("lib.md")] +#[cfg(all( + feature = "sqlite-preupdate-hook", + not(any(feature = "sqlite", feature = "sqlite-unbundled")) +))] +compile_error!( + "sqlite-preupdate-hook requires either 'sqlite' or 'sqlite-unbundled' to be enabled" +); + pub use sqlx_core::acquire::Acquire; pub use sqlx_core::arguments::{Arguments, IntoArguments}; pub use sqlx_core::column::Column; diff --git a/tests/mysql/macros.rs b/tests/mysql/macros.rs index f6bc7595..8187f6d8 100644 --- a/tests/mysql/macros.rs +++ b/tests/mysql/macros.rs @@ -494,6 +494,31 @@ async fn test_from_row_json_attr() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn test_from_row_json_attr_nullable() -> anyhow::Result<()> { + #[derive(serde::Deserialize)] + #[allow(dead_code)] + struct J { + a: u32, + b: u32, + } + + #[derive(sqlx::FromRow)] + struct Record { + #[sqlx(json(nullable))] + j: Option, + } + + let mut conn = new::().await?; + + let record = sqlx::query_as::<_, Record>("select NULL as j") + .fetch_one(&mut conn) + .await?; + + assert!(record.j.is_none()); + Ok(()) +} + #[sqlx_macros::test] async fn test_from_row_json_try_from_attr() -> anyhow::Result<()> { #[derive(serde::Deserialize)] diff --git a/tests/postgres/derives.rs b/tests/postgres/derives.rs index dada74fe..13f9bf1d 100644 --- a/tests/postgres/derives.rs +++ b/tests/postgres/derives.rs @@ -810,3 +810,69 @@ async fn test_custom_pg_array() -> anyhow::Result<()> { } Ok(()) } + +#[sqlx_macros::test] +async fn test_record_array_type() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.execute( + r#" +DROP TABLE IF EXISTS responses; + +DROP TYPE IF EXISTS http_response CASCADE; +DROP TYPE IF EXISTS header_pair CASCADE; + +CREATE TYPE header_pair AS ( + name TEXT, + value TEXT +); + +CREATE TYPE http_response AS ( + headers header_pair[] +); + +CREATE TABLE responses ( + response http_response NOT NULL +); + "#, + ) + .await?; + + #[derive(Debug, sqlx::Type)] + #[sqlx(type_name = "http_response")] + struct HttpResponseRecord { + headers: Vec, + } + + #[derive(Debug, sqlx::Type)] + #[sqlx(type_name = "header_pair")] + struct HeaderPairRecord { + name: String, + value: String, + } + + let value = HttpResponseRecord { + headers: vec![ + HeaderPairRecord { + name: "Content-Type".to_owned(), + value: "text/html; charset=utf-8".to_owned(), + }, + HeaderPairRecord { + name: "Cache-Control".to_owned(), + value: "max-age=0".to_owned(), + }, + ], + }; + + sqlx::query( + " +INSERT INTO responses (response) +VALUES ($1) + ", + ) + .bind(&value) + .execute(&mut conn) + .await?; + + Ok(()) +} diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 87a18db5..7de4a9cd 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -3,13 +3,13 @@ use futures::{Stream, StreamExt, TryStreamExt}; use sqlx::postgres::types::Oid; use sqlx::postgres::{ PgAdvisoryLock, PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgListener, - PgPoolOptions, PgRow, PgSeverity, Postgres, + PgPoolOptions, PgRow, PgSeverity, Postgres, PG_COPY_MAX_DATA_LEN, }; use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo}; use sqlx_core::{bytes::Bytes, error::BoxDynError}; use sqlx_test::{new, pool, setup_if_needed}; use std::env; -use std::pin::Pin; +use std::pin::{pin, Pin}; use std::sync::Arc; use std::time::Duration; @@ -637,8 +637,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> { let pool = pool.clone(); sqlx_core::rt::spawn(async move { while !pool.is_closed() { - let acquire = pool.acquire(); - futures::pin_mut!(acquire); + let mut acquire = pin!(pool.acquire()); // poll the acquire future once to put the waiter in the queue future::poll_fn(move |cx| { @@ -2042,3 +2041,78 @@ async fn test_issue_3052() { "expected encode error, got {too_large_error:?}", ); } + +#[sqlx_macros::test] +async fn test_pg_copy_chunked() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut row = "1".repeat(PG_COPY_MAX_DATA_LEN / 10 - 1); + row.push_str("\n"); + + // creates a payload with COPY_MAX_DATA_LEN + 1 as size + let mut payload = row.repeat(10); + payload.push_str("12345678\n"); + + assert_eq!(payload.len(), PG_COPY_MAX_DATA_LEN + 1); + + let mut copy = conn.copy_in_raw("COPY products(name) FROM STDIN").await?; + + assert!(copy.send(payload.as_bytes()).await.is_ok()); + assert!(copy.finish().await.is_ok()); + Ok(()) +} + +async fn test_copy_in_error_case(query: &str, expected_error: &str) -> anyhow::Result<()> { + let mut conn = new::().await?; + conn.execute("CREATE TEMPORARY TABLE IF NOT EXISTS invalid_copy_target (id int4)") + .await?; + // Try the COPY operation + match conn.copy_in_raw(query).await { + Ok(_) => anyhow::bail!("expected error"), + Err(e) => assert!( + e.to_string().contains(expected_error), + "expected error to contain: {expected_error}, got: {e:?}" + ), + } + // Verify connection is still usable + let value = sqlx::query("select 1 + 1") + .try_map(|row: PgRow| row.try_get::(0)) + .fetch_one(&mut conn) + .await?; + assert_eq!(2i32, value); + Ok(()) +} +#[sqlx_macros::test] +async fn it_can_recover_from_copy_in_to_missing_table() -> anyhow::Result<()> { + test_copy_in_error_case( + r#" + COPY nonexistent_table (id) FROM STDIN WITH (FORMAT CSV, HEADER); + "#, + "does not exist", + ) + .await +} +#[sqlx_macros::test] +async fn it_can_recover_from_copy_in_empty_query() -> anyhow::Result<()> { + test_copy_in_error_case("", "EmptyQuery").await +} +#[sqlx_macros::test] +async fn it_can_recover_from_copy_in_syntax_error() -> anyhow::Result<()> { + test_copy_in_error_case( + r#" + COPY FROM STDIN WITH (FORMAT CSV); + "#, + "syntax error", + ) + .await +} +#[sqlx_macros::test] +async fn it_can_recover_from_copy_in_invalid_params() -> anyhow::Result<()> { + test_copy_in_error_case( + r#" + COPY invalid_copy_target FROM STDIN WITH (FORMAT CSV, INVALID_PARAM true); + "#, + "invalid_param", + ) + .await +} diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 3f6c3620..ccf88b10 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -509,6 +509,21 @@ test_type!(line(Postgres, "line('((0.0, 0.0), (1.0,1.0))')" == sqlx::postgres::types::PgLine { a: 1., b: -1., c: 0. }, )); +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(lseg(Postgres, + "lseg('((1.0, 2.0), (3.0,4.0))')" == sqlx::postgres::types::PgLSeg { start_x: 1., start_y: 2., end_x: 3. , end_y: 4.}, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(box(Postgres, + "box('((1.0, 2.0), (3.0,4.0))')" == sqlx::postgres::types::PgBox { upper_right_x: 3., upper_right_y: 4., lower_left_x: 1. , lower_left_y: 2.}, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(_box>(Postgres, + "array[box('1,2,3,4'),box('((1.1, 2.2), (3.3, 4.4))')]" @= vec![sqlx::postgres::types::PgBox { upper_right_x: 3., upper_right_y: 4., lower_left_x: 1., lower_left_y: 2. }, sqlx::postgres::types::PgBox { upper_right_x: 3.3, upper_right_y: 4.4, lower_left_x: 1.1, lower_left_y: 2.2 }], +)); + #[cfg(feature = "rust_decimal")] test_type!(decimal(Postgres, "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(), diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index b733ccbb..16b4b2d9 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -2,11 +2,14 @@ use futures::TryStreamExt; use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions}; +use sqlx::Decode; use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, }; +use sqlx::{Value, ValueRef}; use sqlx_test::new; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; #[sqlx_macros::test] @@ -798,7 +801,7 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow:: #[sqlx_macros::test] async fn test_query_with_update_hook() -> anyhow::Result<()> { let mut conn = new::().await?; - + static CALLED: AtomicBool = AtomicBool::new(false); // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); conn.lock_handle().await?.set_update_hook(move |result| { @@ -807,11 +810,13 @@ async fn test_query_with_update_hook() -> anyhow::Result<()> { assert_eq!(result.database, "main"); assert_eq!(result.table, "tweet"); assert_eq!(result.rowid, 2); + CALLED.store(true, Ordering::Relaxed); }); let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )") .execute(&mut conn) .await?; + assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } @@ -852,10 +857,11 @@ async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Resul #[sqlx_macros::test] async fn test_query_with_commit_hook() -> anyhow::Result<()> { let mut conn = new::().await?; - + static CALLED: AtomicBool = AtomicBool::new(false); // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); conn.lock_handle().await?.set_commit_hook(move || { + CALLED.store(true, Ordering::Relaxed); assert_eq!(state, "test"); false }); @@ -870,7 +876,7 @@ async fn test_query_with_commit_hook() -> anyhow::Result<()> { } _ => panic!("expected an error"), } - + assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } @@ -916,8 +922,10 @@ async fn test_query_with_rollback_hook() -> anyhow::Result<()> { // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. let state = format!("test"); + static CALLED: AtomicBool = AtomicBool::new(false); conn.lock_handle().await?.set_rollback_hook(move || { assert_eq!(state, "test"); + CALLED.store(true, Ordering::Relaxed); }); let mut tx = conn.begin().await?; @@ -925,6 +933,7 @@ async fn test_query_with_rollback_hook() -> anyhow::Result<()> { .execute(&mut *tx) .await?; tx.rollback().await?; + assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } @@ -960,3 +969,227 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res assert_eq!(1, Arc::strong_count(&ref_counted_object)); Ok(()) } + +#[cfg(feature = "sqlite-preupdate-hook")] +#[sqlx_macros::test] +async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { + let mut conn = new::().await?; + static CALLED: AtomicBool = AtomicBool::new(false); + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_preupdate_hook({ + move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Insert); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); + + assert_eq!(4, result.get_column_count()); + assert_eq!(2, result.get_new_row_id().unwrap()); + assert_eq!(0, result.get_query_depth()); + assert_eq!( + 4, + >::decode(result.get_new_column_value(0).unwrap()).unwrap() + ); + assert_eq!( + "Hello, World", + >::decode(result.get_new_column_value(1).unwrap()) + .unwrap() + ); + // out of bounds access should return an error + assert!(result.get_new_column_value(4).is_err()); + // old values aren't available for inserts + assert!(result.get_old_column_value(0).is_err()); + assert!(result.get_old_row_id().is_err()); + + CALLED.store(true, Ordering::Relaxed); + } + }); + + let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 4, 'Hello, World' )") + .execute(&mut conn) + .await?; + + assert!(CALLED.load(Ordering::Relaxed)); + conn.lock_handle().await?.remove_preupdate_hook(); + let _ = sqlx::query("DELETE FROM tweet where id = 4") + .execute(&mut conn) + .await?; + Ok(()) +} + +#[cfg(feature = "sqlite-preupdate-hook")] +#[sqlx_macros::test] +async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { + let mut conn = new::().await?; + let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 5, 'Hello, World' )") + .execute(&mut conn) + .await?; + static CALLED: AtomicBool = AtomicBool::new(false); + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_preupdate_hook(move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Delete); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); + + assert_eq!(4, result.get_column_count()); + assert_eq!(2, result.get_old_row_id().unwrap()); + assert_eq!(0, result.get_query_depth()); + assert_eq!( + 5, + >::decode(result.get_old_column_value(0).unwrap()).unwrap() + ); + assert_eq!( + "Hello, World", + >::decode(result.get_old_column_value(1).unwrap()).unwrap() + ); + // out of bounds access should return an error + assert!(result.get_old_column_value(4).is_err()); + // new values aren't available for deletes + assert!(result.get_new_column_value(0).is_err()); + assert!(result.get_new_row_id().is_err()); + + CALLED.store(true, Ordering::Relaxed); + }); + + let _ = sqlx::query("DELETE FROM tweet WHERE id = 5") + .execute(&mut conn) + .await?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) +} + +#[cfg(feature = "sqlite-preupdate-hook")] +#[sqlx_macros::test] +async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> { + let mut conn = new::().await?; + let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 6, 'Hello, World' )") + .execute(&mut conn) + .await?; + static CALLED: AtomicBool = AtomicBool::new(false); + let sqlite_value_stored: Arc>> = Default::default(); + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_preupdate_hook({ + let sqlite_value_stored = sqlite_value_stored.clone(); + move |result| { + assert_eq!(state, "test"); + assert_eq!(result.operation, SqliteOperation::Update); + assert_eq!(result.database, "main"); + assert_eq!(result.table, "tweet"); + + assert_eq!(4, result.get_column_count()); + assert_eq!(4, result.get_column_count()); + + assert_eq!(2, result.get_old_row_id().unwrap()); + assert_eq!(2, result.get_new_row_id().unwrap()); + + assert_eq!(0, result.get_query_depth()); + assert_eq!(0, result.get_query_depth()); + + assert_eq!( + 6, + >::decode(result.get_old_column_value(0).unwrap()).unwrap() + ); + assert_eq!( + 6, + >::decode(result.get_new_column_value(0).unwrap()).unwrap() + ); + + assert_eq!( + "Hello, World", + >::decode(result.get_old_column_value(1).unwrap()) + .unwrap() + ); + assert_eq!( + "Hello, World2", + >::decode(result.get_new_column_value(1).unwrap()) + .unwrap() + ); + *sqlite_value_stored.lock().unwrap() = + Some(result.get_old_column_value(0).unwrap().to_owned()); + + // out of bounds access should return an error + assert!(result.get_old_column_value(4).is_err()); + assert!(result.get_new_column_value(4).is_err()); + + CALLED.store(true, Ordering::Relaxed); + } + }); + + let _ = sqlx::query("UPDATE tweet SET text = 'Hello, World2' WHERE id = 6") + .execute(&mut conn) + .await?; + + assert!(CALLED.load(Ordering::Relaxed)); + conn.lock_handle().await?.remove_preupdate_hook(); + let _ = sqlx::query("DELETE FROM tweet where id = 6") + .execute(&mut conn) + .await?; + // Ensure that taking an owned SqliteValue maintains a valid reference after the callback returns + assert_eq!( + 6, + >::decode( + sqlite_value_stored.lock().unwrap().take().unwrap().as_ref() + ) + .unwrap() + ); + Ok(()) +} + +#[cfg(feature = "sqlite-preupdate-hook")] +#[sqlx_macros::test] +async fn test_multiple_set_preupdate_hook_calls_drop_old_handler() -> anyhow::Result<()> { + let ref_counted_object = Arc::new(0); + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + + { + let mut conn = new::().await?; + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_preupdate_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_preupdate_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_preupdate_hook(move |_| { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + conn.lock_handle().await?.remove_preupdate_hook(); + } + + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + Ok(()) +} + +#[sqlx_macros::test] +async fn test_get_last_error() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let _ = sqlx::query("select 1").fetch_one(&mut conn).await?; + + { + let mut handle = conn.lock_handle().await?; + assert!(handle.last_error().is_none()); + } + + let _ = sqlx::query("invalid statement").fetch_one(&mut conn).await; + + { + let mut handle = conn.lock_handle().await?; + assert!(handle.last_error().is_some()); + } + + Ok(()) +}