mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-09-28 13:31:41 +00:00
Merge remote-tracking branch 'origin/main' into sqlx-toml
# Conflicts: # Cargo.lock # Cargo.toml # sqlx-cli/src/database.rs # sqlx-cli/src/lib.rs # sqlx-mysql/src/connection/executor.rs
This commit is contained in:
commit
8429f2e989
50
.github/pull_request_template.md
vendored
50
.github/pull_request_template.md
vendored
@ -1,2 +1,50 @@
|
||||
<!--
|
||||
PR AUTHOR INSTRUCTIONS; PLEASE READ.
|
||||
|
||||
Give your pull request an accurate and descriptive title. It should mention what component(s) or database driver(s) it touches.
|
||||
Pull requests with undescriptive or inaccurate titles *may* be closed or have their titles changed before merging.
|
||||
|
||||
Fill out the fields below.
|
||||
|
||||
All pull requests *must* pass CI to be merged. Check your pull request frequently for build failures until all checks pass.
|
||||
Address build failures by pushing new commits or amending existing ones. Feel free to ask for help if you get stuck.
|
||||
If a failure seems spurious (timeout or cache failure), you may push a new commit to re-run it.
|
||||
|
||||
After addressing review comments, re-request review to show that you are ready for your PR to be looked at again.
|
||||
|
||||
Pull requests which sit for a long time with broken CI or unaddressed review comments will be closed to clear the backlog.
|
||||
If this happens, you are welcome to open a new pull request, but please be sure to address the feedback you have received previously.
|
||||
|
||||
Bug fixes should include a regression test which fails before the fix and passes afterwards. If this is infeasible, please explain why.
|
||||
|
||||
New features *should* include unit or integration tests in the appropriate folders. Database specific tests should go in `tests/<database>`.
|
||||
|
||||
Note that unsolicited pull requests implementing large or complex changes may not be reviewed right away.
|
||||
Maintainer time and energy is limited and massive unsolicited pull requests require an outsized effort to review.
|
||||
|
||||
To make the best use of your time and ours, search for and participate in existing discussion on the issue tracker before opening a pull request.
|
||||
The solution you came up with may have already been rejected or postponed due to other work needing to be done first,
|
||||
or there may be a pending solution going down a different direction that you hadn't considered.
|
||||
|
||||
Pull requests that take existing discussion into account are the most likely to be merged.
|
||||
|
||||
Delete this block comment before submission to show that you have read and understand these instructions.
|
||||
-->
|
||||
|
||||
### Does your PR solve an issue?
|
||||
### Delete this text and add "fixes #(issue number)"
|
||||
Delete this text and add "fixes #(issue number)".
|
||||
|
||||
Do *not* just list issue numbers here as they will not be automatically closed on merging this pull request unless prefixed with "fixes" or "closes".
|
||||
|
||||
### Is this a breaking change?
|
||||
Delete this text and answer yes/no and explain.
|
||||
|
||||
If yes, this pull request will need to wait for the next major release (`0.{x + 1}.0`)
|
||||
|
||||
Behavior changes _can_ be breaking if significant enough.
|
||||
Consider [Hyrum's Law](https://www.hyrumslaw.com/):
|
||||
|
||||
> With a sufficient number of users of an API,
|
||||
> it does not matter what you promise in the contract:
|
||||
> all observable behaviors of your system
|
||||
> will be depended on by somebody.
|
||||
|
6
.github/workflows/sqlx.yml
vendored
6
.github/workflows/sqlx.yml
vendored
@ -39,7 +39,7 @@ jobs:
|
||||
- run: >
|
||||
cargo clippy
|
||||
--no-default-features
|
||||
--features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
|
||||
--features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
|
||||
-- -D warnings
|
||||
|
||||
# Run beta for new warnings but don't break the build.
|
||||
@ -47,7 +47,7 @@ jobs:
|
||||
- run: >
|
||||
cargo +beta clippy
|
||||
--no-default-features
|
||||
--features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
|
||||
--features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
|
||||
--target-dir target/beta/
|
||||
|
||||
check-minimal-versions:
|
||||
@ -140,7 +140,7 @@ jobs:
|
||||
- run: >
|
||||
cargo test
|
||||
--no-default-features
|
||||
--features any,macros,${{ matrix.linking }},_unstable-all-types,runtime-${{ matrix.runtime }}
|
||||
--features any,macros,${{ matrix.linking }},${{ matrix.linking == 'sqlite' && 'sqlite-preupdate-hook,' || ''}}_unstable-all-types,runtime-${{ matrix.runtime }}
|
||||
--
|
||||
--test-threads=1
|
||||
env:
|
||||
|
639
Cargo.lock
generated
639
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -76,9 +76,10 @@ _unstable-all-types = [
|
||||
"mac_address",
|
||||
"uuid",
|
||||
"bit-vec",
|
||||
"bstr"
|
||||
]
|
||||
# Render documentation that wouldn't otherwise be shown (e.g. `sqlx_core::config`).
|
||||
_unstable-doc = []
|
||||
_unstable-doc = ["sqlite-preupdate-hook"]
|
||||
|
||||
# Base runtime features without TLS
|
||||
runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std"]
|
||||
@ -114,6 +115,7 @@ postgres = ["sqlx-postgres", "sqlx-macros?/postgres"]
|
||||
mysql = ["sqlx-mysql", "sqlx-macros?/mysql"]
|
||||
sqlite = ["_sqlite", "sqlx-sqlite/bundled", "sqlx-macros?/sqlite"]
|
||||
sqlite-unbundled = ["_sqlite", "sqlx-sqlite/unbundled", "sqlx-macros?/sqlite-unbundled"]
|
||||
sqlite-preupdate-hook = ["sqlx-sqlite/preupdate-hook"]
|
||||
|
||||
# types
|
||||
json = ["sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"]
|
||||
@ -127,6 +129,7 @@ rust_decimal = ["sqlx-core/rust_decimal", "sqlx-macros?/rust_decimal", "sqlx-mys
|
||||
time = ["sqlx-core/time", "sqlx-macros?/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"]
|
||||
uuid = ["sqlx-core/uuid", "sqlx-macros?/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"]
|
||||
regexp = ["sqlx-sqlite?/regexp"]
|
||||
bstr = ["sqlx-core/bstr"]
|
||||
|
||||
[workspace.dependencies]
|
||||
# Core Crates
|
||||
|
56
FAQ.md
56
FAQ.md
@ -36,6 +36,62 @@ as they can often be a whole year or more out-of-date.
|
||||
|
||||
[`rust-version`]: https://doc.rust-lang.org/stable/cargo/reference/manifest.html#the-rust-version-field
|
||||
|
||||
----------------------------------------------------------------
|
||||
|
||||
### Can SQLx Add Support for New Databases?
|
||||
|
||||
We are always open to discuss adding support for new databases, but as of writing, have no plans to in the short term.
|
||||
|
||||
Implementing support for a new database in SQLx is a _huge_ lift. Expecting this work to be done for free is highly unrealistic.
|
||||
In all likelihood, the implementation would need to be written from scratch.
|
||||
Even if Rust bindings exist, they may not support `async`.
|
||||
Even if they support `async`, they may only support either Tokio or `async-std`, and not both.
|
||||
Even if they support Tokio and `async-std`, the API may not be flexible enough or provide sufficient information (e.g. for implementing the macros).
|
||||
|
||||
If we have to write the implementation from scratch, is the protocol publicly documented, and stable?
|
||||
|
||||
Even if everything is supported on the client side, how will we run tests against the database? Is it open-source, or proprietary? Will it require a paid license?
|
||||
|
||||
For example, Oracle Database's protocol is proprietary and only supported through their own libraries, which do not support Rust, and only have blocking APIs (see: [Oracle Call Interface for C](https://docs.oracle.com/en/database/oracle/oracle-database/23/lnoci/index.html)).
|
||||
This makes it a poor candidate for an async-native crate like SQLx--though we support SQLite, which also only has a blocking API, that's the exception and not the rule. Wrapping blocking APIs is not very scalable.
|
||||
|
||||
We still have plans to bring back the MSSQL driver, but this is not feasible as of writing with the current maintenance workload. Should this change, an announcement will be made on Github as well as our [Discord server](https://discord.gg/uuruzJ7).
|
||||
|
||||
### What If I'm Willing to Contribute the Implementation?
|
||||
|
||||
Being willing to contribute an implementation for a new database is one thing, but there's also the ongoing maintenance burden to consider.
|
||||
|
||||
Are you willing to provide support long-term?
|
||||
Will there be enough users that we can rely on outside contributions?
|
||||
Or is support going to fall to the current maintainer(s)?
|
||||
|
||||
This is the kind of thing that will need to be supported in SQLx _long_ after the initial implementation, or else later need to be removed.
|
||||
If you don't have plans for how to support a new driver long-term, then it doesn't belong as part of SQLx itself.
|
||||
|
||||
However, drivers don't necessarily need to live _in_ SQLx anymore. Since 0.7.0, drivers don't need to be compiled-in to be functional.
|
||||
Support for third-party drivers in `sqlx-cli` and the `query!()` macros is pending, as well as documenting the process of writing a driver, but contributions are welcome in this regard.
|
||||
|
||||
For example, see [sqlx-exasol](https://crates.io/crates/sqlx-exasol).
|
||||
|
||||
----------------------------------------------------------------
|
||||
### Can SQLx Add Support for New Data-Type Crates (e.g. Jiff in addition to `chrono` and `time`)?
|
||||
|
||||
This has a lot of the same considerations as adding support for new databases (see above), but with one big additional problem: Semantic Versioning.
|
||||
|
||||
When we add trait implementations for types from an external crate, that crate then becomes part of our public API. We become beholden to its release cycle.
|
||||
|
||||
If the crate's API is still evolving, meaning they are making breaking changes frequently, and thus releasing new major versions frequently, that then becomes a burden on us to upgrade and release a new major version as well so everyone _else_ can upgrade.
|
||||
|
||||
We don't have the maintainer bandwidth to support multiple major versions simultaneously (we have no Long-Term Support policy), so this means that users who want to keep up-to-date are forced to make frequent manual upgrades as well.
|
||||
|
||||
Thus, it is best that we stick to only supporting crates which have a stable API, and which are not making new major releases frequently.
|
||||
|
||||
Conversely, adding support for SQLx _in_ these crates may not be desirable either, since SQLx is a large dependency and a higher-level crate. In this case, the SemVer problem gets pushed onto the other crate.
|
||||
|
||||
There isn't a satisfying answer to this problem, but one option is to have an intermediate wrapper crate.
|
||||
For example, [`jiff-sqlx`](https://crates.io/crates/jiff-sqlx), which is maintained by the author of Jiff.
|
||||
API changes to SQLx are pending to make this pattern easier to use.
|
||||
|
||||
----------------------------------------------------------------
|
||||
### I'm getting `HandshakeFailure` or `CorruptMessage` when trying to connect to a server over TLS using RusTLS. What gives?
|
||||
|
||||
|
@ -196,6 +196,10 @@ be removed in the future.
|
||||
* May result in link errors if the SQLite version is too old. Version `3.20.0` or newer is recommended.
|
||||
* Can increase build time due to the use of bindgen.
|
||||
|
||||
- `sqlite-preupdate-hook`: enables SQLite's [preupdate hook](https://sqlite.org/c3ref/preupdate_count.html) API.
|
||||
* Exposed as a separate feature because it's generally not enabled by default.
|
||||
* Using this feature with `sqlite-unbundled` may cause linker failures if the system SQLite version does not support it.
|
||||
|
||||
- `any`: Add support for the `Any` database driver, which can proxy to a database driver at runtime.
|
||||
|
||||
- `derive`: Add support for the derive family macros, those are `FromRow`, `Type`, `Encode`, `Decode`.
|
||||
@ -204,7 +208,7 @@ be removed in the future.
|
||||
|
||||
- `migrate`: Add support for the migration management and `migrate!` macro, which allow compile-time embedded migrations.
|
||||
|
||||
- `uuid`: Add support for UUID (in Postgres).
|
||||
- `uuid`: Add support for UUID.
|
||||
|
||||
- `chrono`: Add support for date and time types from `chrono`.
|
||||
|
||||
|
@ -26,7 +26,7 @@ path = "src/bin/cargo-sqlx.rs"
|
||||
|
||||
[dependencies]
|
||||
dotenvy = "0.15.0"
|
||||
tokio = { version = "1.15.0", features = ["macros", "rt", "rt-multi-thread"] }
|
||||
tokio = { version = "1.15.0", features = ["macros", "rt", "rt-multi-thread", "signal"] }
|
||||
sqlx = { workspace = true, default-features = false, features = [
|
||||
"runtime-tokio",
|
||||
"migrate",
|
||||
@ -37,9 +37,8 @@ clap = { version = "4.3.10", features = ["derive", "env"] }
|
||||
clap_complete = { version = "4.3.1", optional = true }
|
||||
chrono = { version = "0.4.19", default-features = false, features = ["clock"] }
|
||||
anyhow = "1.0.52"
|
||||
async-trait = "0.1.52"
|
||||
console = "0.15.0"
|
||||
promptly = "0.3.0"
|
||||
dialoguer = { version = "0.11", default-features = false }
|
||||
serde_json = "1.0.73"
|
||||
glob = "0.3.0"
|
||||
openssl = { version = "0.10.38", optional = true }
|
||||
|
@ -13,9 +13,12 @@ enum Cli {
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
dotenvy::dotenv().ok();
|
||||
let Cli::Sqlx(opt) = Cli::parse();
|
||||
|
||||
if !opt.no_dotenv {
|
||||
dotenvy::dotenv().ok();
|
||||
}
|
||||
|
||||
if let Err(error) = sqlx_cli::run(opt).await {
|
||||
println!("{} {}", style("error:").bold().red(), error);
|
||||
process::exit(1);
|
||||
|
@ -4,9 +4,14 @@ use sqlx_cli::Opt;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
dotenvy::dotenv().ok();
|
||||
let opt = Opt::parse();
|
||||
|
||||
if !opt.no_dotenv {
|
||||
dotenvy::dotenv().ok();
|
||||
}
|
||||
|
||||
// no special handling here
|
||||
if let Err(error) = sqlx_cli::run(Opt::parse()).await {
|
||||
if let Err(error) = sqlx_cli::run(opt).await {
|
||||
println!("{} {}", style("error:").bold().red(), error);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
@ -1,9 +1,11 @@
|
||||
use crate::opt::{ConnectOpts, MigrationSourceOpt};
|
||||
use crate::{migrate, Config};
|
||||
use console::style;
|
||||
use promptly::{prompt, ReadlineError};
|
||||
use console::{style, Term};
|
||||
use dialoguer::Confirm;
|
||||
use sqlx::any::Any;
|
||||
use sqlx::migrate::MigrateDatabase;
|
||||
use std::{io, mem};
|
||||
use tokio::task;
|
||||
|
||||
pub async fn create(connect_opts: &ConnectOpts) -> anyhow::Result<()> {
|
||||
// NOTE: only retry the idempotent action.
|
||||
@ -24,7 +26,7 @@ pub async fn create(connect_opts: &ConnectOpts) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
pub async fn drop(connect_opts: &ConnectOpts, confirm: bool, force: bool) -> anyhow::Result<()> {
|
||||
if confirm && !ask_to_continue_drop(connect_opts.expect_db_url()?) {
|
||||
if confirm && !ask_to_continue_drop(connect_opts.expect_db_url()?.to_owned()).await {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@ -63,27 +65,46 @@ pub async fn setup(
|
||||
migrate::run(config, migration_source, connect_opts, false, false, None).await
|
||||
}
|
||||
|
||||
fn ask_to_continue_drop(db_url: &str) -> bool {
|
||||
loop {
|
||||
let r: Result<String, ReadlineError> =
|
||||
prompt(format!("Drop database at {}? (y/n)", style(db_url).cyan()));
|
||||
match r {
|
||||
Ok(response) => {
|
||||
if response == "n" || response == "N" {
|
||||
return false;
|
||||
} else if response == "y" || response == "Y" {
|
||||
return true;
|
||||
} else {
|
||||
println!(
|
||||
"Response not recognized: {}\nPlease type 'y' or 'n' and press enter.",
|
||||
response
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("{e}");
|
||||
return false;
|
||||
async fn ask_to_continue_drop(db_url: String) -> bool {
|
||||
// If the setup operation is cancelled while we are waiting for the user to decide whether
|
||||
// or not to drop the database, this will restore the terminal's cursor to its normal state.
|
||||
struct RestoreCursorGuard {
|
||||
disarmed: bool,
|
||||
}
|
||||
|
||||
impl Drop for RestoreCursorGuard {
|
||||
fn drop(&mut self) {
|
||||
if !self.disarmed {
|
||||
Term::stderr().show_cursor().unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut guard = RestoreCursorGuard { disarmed: false };
|
||||
|
||||
let decision_result = task::spawn_blocking(move || {
|
||||
Confirm::new()
|
||||
.with_prompt(format!("Drop database at {}?", style(&db_url).cyan()))
|
||||
.wait_for_newline(true)
|
||||
.default(false)
|
||||
.show_default(true)
|
||||
.interact()
|
||||
})
|
||||
.await
|
||||
.expect("Confirm thread panicked");
|
||||
match decision_result {
|
||||
Ok(decision) => {
|
||||
guard.disarmed = true;
|
||||
decision
|
||||
}
|
||||
Err(dialoguer::Error::IO(err)) if err.kind() == io::ErrorKind::Interrupted => {
|
||||
// Sometimes CTRL + C causes this error to be returned
|
||||
mem::drop(guard);
|
||||
false
|
||||
}
|
||||
Err(err) => {
|
||||
mem::drop(guard);
|
||||
panic!("Confirm dialog failed with {err}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ use anyhow::{Context, Result};
|
||||
use futures::{Future, TryFutureExt};
|
||||
|
||||
use sqlx::{AnyConnection, Connection};
|
||||
use tokio::{select, signal};
|
||||
|
||||
use crate::opt::{Command, ConnectOpts, DatabaseCommand, MigrateCommand};
|
||||
|
||||
@ -24,6 +25,26 @@ pub use crate::opt::Opt;
|
||||
pub use sqlx::_unstable::config::{self, Config};
|
||||
|
||||
pub async fn run(opt: Opt) -> Result<()> {
|
||||
// This `select!` is here so that when the process receives a `SIGINT` (CTRL + C),
|
||||
// the futures currently running on this task get dropped before the program exits.
|
||||
// This is currently necessary for the consumers of the `dialoguer` crate to restore
|
||||
// the user's terminal if the process is interrupted while a dialog is being displayed.
|
||||
|
||||
let ctrlc_fut = signal::ctrl_c();
|
||||
let do_run_fut = do_run(opt);
|
||||
|
||||
select! {
|
||||
biased;
|
||||
_ = ctrlc_fut => {
|
||||
Ok(())
|
||||
},
|
||||
do_run_outcome = do_run_fut => {
|
||||
do_run_outcome
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_run(opt: Opt) -> Result<()> {
|
||||
let config = config_from_current_dir().await?;
|
||||
|
||||
match opt.command {
|
||||
|
@ -12,6 +12,10 @@ use std::ops::{Deref, Not};
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(version, about, author)]
|
||||
pub struct Opt {
|
||||
/// Do not automatically load `.env` files.
|
||||
#[clap(long)]
|
||||
pub no_dotenv: bool,
|
||||
|
||||
#[clap(subcommand)]
|
||||
pub command: Command,
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ _tls-native-tls = ["native-tls"]
|
||||
_tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"]
|
||||
_tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"]
|
||||
_tls-rustls-ring-native-roots = ["_tls-rustls", "rustls/ring", "rustls-native-certs"]
|
||||
_tls-rustls = ["rustls", "rustls-pemfile"]
|
||||
_tls-rustls = ["rustls"]
|
||||
_tls-none = []
|
||||
|
||||
# support offline/decoupled building (enables serialization of `Describe`)
|
||||
@ -47,8 +47,7 @@ tokio = { workspace = true, optional = true }
|
||||
# TLS
|
||||
native-tls = { version = "0.2.10", optional = true }
|
||||
|
||||
rustls = { version = "0.23.11", default-features = false, features = ["std", "tls12"], optional = true }
|
||||
rustls-pemfile = { version = "2", optional = true }
|
||||
rustls = { version = "0.23.15", default-features = false, features = ["std", "tls12"], optional = true }
|
||||
webpki-roots = { version = "0.26", optional = true }
|
||||
rustls-native-certs = { version = "0.8.0", optional = true }
|
||||
|
||||
@ -62,6 +61,7 @@ mac_address = { workspace = true, optional = true }
|
||||
uuid = { workspace = true, optional = true }
|
||||
|
||||
async-io = { version = "1.9.0", optional = true }
|
||||
base64 = { version = "0.22.0", default-features = false, features = ["std"] }
|
||||
bytes = "1.1.0"
|
||||
chrono = { version = "0.4.34", default-features = false, features = ["clock"], optional = true }
|
||||
crc = { version = "3", optional = true }
|
||||
|
@ -111,7 +111,8 @@ use crate::{error::Error, row::Row};
|
||||
/// different placeholder values, if applicable.
|
||||
///
|
||||
/// This is similar to how `#[serde(default)]` behaves.
|
||||
/// ### `flatten`
|
||||
///
|
||||
/// #### `flatten`
|
||||
///
|
||||
/// If you want to handle a field that implements [`FromRow`],
|
||||
/// you can use the `flatten` attribute to specify that you want
|
||||
@ -177,33 +178,6 @@ use crate::{error::Error, row::Row};
|
||||
/// assert!(user.addresses.is_empty());
|
||||
/// ```
|
||||
///
|
||||
/// ## Manual implementation
|
||||
///
|
||||
/// You can also implement the [`FromRow`] trait by hand. This can be useful if you
|
||||
/// have a struct with a field that needs manual decoding:
|
||||
///
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use sqlx::{FromRow, sqlite::SqliteRow, sqlx::Row};
|
||||
/// struct MyCustomType {
|
||||
/// custom: String,
|
||||
/// }
|
||||
///
|
||||
/// struct Foo {
|
||||
/// bar: MyCustomType,
|
||||
/// }
|
||||
///
|
||||
/// impl FromRow<'_, SqliteRow> for Foo {
|
||||
/// fn from_row(row: &SqliteRow) -> sqlx::Result<Self> {
|
||||
/// Ok(Self {
|
||||
/// bar: MyCustomType {
|
||||
/// custom: row.try_get("custom")?
|
||||
/// }
|
||||
/// })
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// #### `try_from`
|
||||
///
|
||||
/// When your struct contains a field whose type is not matched with the database type,
|
||||
@ -271,6 +245,59 @@ use crate::{error::Error, row::Row};
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// By default the `#[sqlx(json)]` attribute will assume that the underlying database row is
|
||||
/// _not_ NULL. This can cause issues when your field type is an `Option<T>` because this would be
|
||||
/// represented as the _not_ NULL (in terms of DB) JSON value of `null`.
|
||||
///
|
||||
/// If you wish to describe a database row which _is_ NULLable but _cannot_ contain the JSON value `null`,
|
||||
/// use the `#[sqlx(json(nullable))]` attrubute.
|
||||
///
|
||||
/// For example
|
||||
/// ```rust,ignore
|
||||
/// #[derive(serde::Deserialize)]
|
||||
/// struct Data {
|
||||
/// field1: String,
|
||||
/// field2: u64
|
||||
/// }
|
||||
///
|
||||
/// #[derive(sqlx::FromRow)]
|
||||
/// struct User {
|
||||
/// id: i32,
|
||||
/// name: String,
|
||||
/// #[sqlx(json(nullable))]
|
||||
/// metadata: Option<Data>
|
||||
/// }
|
||||
/// ```
|
||||
/// Would describe a database field which _is_ NULLable but if it exists it must be the JSON representation of `Data`
|
||||
/// and cannot be the JSON value `null`
|
||||
///
|
||||
/// ## Manual implementation
|
||||
///
|
||||
/// You can also implement the [`FromRow`] trait by hand. This can be useful if you
|
||||
/// have a struct with a field that needs manual decoding:
|
||||
///
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use sqlx::{FromRow, sqlite::SqliteRow, sqlx::Row};
|
||||
/// struct MyCustomType {
|
||||
/// custom: String,
|
||||
/// }
|
||||
///
|
||||
/// struct Foo {
|
||||
/// bar: MyCustomType,
|
||||
/// }
|
||||
///
|
||||
/// impl FromRow<'_, SqliteRow> for Foo {
|
||||
/// fn from_row(row: &SqliteRow) -> sqlx::Result<Self> {
|
||||
/// Ok(Self {
|
||||
/// bar: MyCustomType {
|
||||
/// custom: row.try_get("custom")?
|
||||
/// }
|
||||
/// })
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub trait FromRow<'r, R: Row>: Sized {
|
||||
fn from_row(row: &'r R) -> Result<Self, Error>;
|
||||
}
|
||||
@ -286,7 +313,7 @@ where
|
||||
}
|
||||
|
||||
// implement FromRow for tuples of types that implement Decode
|
||||
// up to tuples of 9 values
|
||||
// up to tuples of 16 values
|
||||
|
||||
macro_rules! impl_from_row_for_tuple {
|
||||
($( ($idx:tt) -> $T:ident );+;) => {
|
||||
|
@ -1,10 +1,9 @@
|
||||
use crate::error::Error;
|
||||
use futures_core::Future;
|
||||
use futures_util::ready;
|
||||
use sqlx_rt::AsyncWrite;
|
||||
use std::future::Future;
|
||||
use std::io::{BufRead, Cursor};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::task::{ready, Context, Poll};
|
||||
|
||||
// Atomic operation that writes the full buffer to the stream, flushes the stream, and then
|
||||
// clears the buffer (even if either of the two previous operations failed).
|
||||
|
@ -2,10 +2,9 @@ use std::future::Future;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::task::{ready, Context, Poll};
|
||||
|
||||
use bytes::BufMut;
|
||||
use futures_core::ready;
|
||||
|
||||
pub use buffered::{BufferedSocket, WriteBuffer};
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
use futures_util::future;
|
||||
use std::io::{self, BufReader, Cursor, Read, Write};
|
||||
use std::io::{self, Read, Write};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
@ -9,7 +9,10 @@ use rustls::{
|
||||
WebPkiServerVerifier,
|
||||
},
|
||||
crypto::{verify_tls12_signature, verify_tls13_signature, CryptoProvider},
|
||||
pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime},
|
||||
pki_types::{
|
||||
pem::{self, PemObject},
|
||||
CertificateDer, PrivateKeyDer, ServerName, UnixTime,
|
||||
},
|
||||
CertificateError, ClientConfig, ClientConnection, Error as TlsError, RootCertStore,
|
||||
};
|
||||
|
||||
@ -141,9 +144,8 @@ where
|
||||
|
||||
if let Some(ca) = tls_config.root_cert_path {
|
||||
let data = ca.data().await?;
|
||||
let mut cursor = Cursor::new(data);
|
||||
|
||||
for result in rustls_pemfile::certs(&mut cursor) {
|
||||
for result in CertificateDer::pem_slice_iter(&data) {
|
||||
let Ok(cert) = result else {
|
||||
return Err(Error::Tls(format!("Invalid certificate {ca}").into()));
|
||||
};
|
||||
@ -196,19 +198,15 @@ where
|
||||
}
|
||||
|
||||
fn certs_from_pem(pem: Vec<u8>) -> Result<Vec<CertificateDer<'static>>, Error> {
|
||||
let cur = Cursor::new(pem);
|
||||
let mut reader = BufReader::new(cur);
|
||||
rustls_pemfile::certs(&mut reader)
|
||||
CertificateDer::pem_slice_iter(&pem)
|
||||
.map(|result| result.map_err(|err| Error::Tls(err.into())))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn private_key_from_pem(pem: Vec<u8>) -> Result<PrivateKeyDer<'static>, Error> {
|
||||
let cur = Cursor::new(pem);
|
||||
let mut reader = BufReader::new(cur);
|
||||
match rustls_pemfile::private_key(&mut reader) {
|
||||
Ok(Some(key)) => Ok(key),
|
||||
Ok(None) => Err(Error::Configuration("no keys found pem file".into())),
|
||||
match PrivateKeyDer::from_pem_slice(&pem) {
|
||||
Ok(key) => Ok(key),
|
||||
Err(pem::Error::NoItemsFound) => Err(Error::Configuration("no keys found pem file".into())),
|
||||
Err(e) => Err(Error::Configuration(e.to_string().into())),
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,8 @@
|
||||
use crate::net::Socket;
|
||||
|
||||
use std::io::{self, Read, Write};
|
||||
use std::task::{Context, Poll};
|
||||
use std::task::{ready, Context, Poll};
|
||||
|
||||
use futures_core::ready;
|
||||
use futures_util::future;
|
||||
|
||||
pub struct StdSocket<S> {
|
||||
|
@ -10,6 +10,7 @@ use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser};
|
||||
|
||||
use std::cmp;
|
||||
use std::future::Future;
|
||||
use std::pin::pin;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::task::Poll;
|
||||
@ -130,19 +131,12 @@ impl<DB: Database> PoolInner<DB> {
|
||||
// This is just going to cause unnecessary churn in `acquire()`.
|
||||
.filter(|_| self.size() < self.options.max_connections);
|
||||
|
||||
let acquire_self = self.semaphore.acquire(1).fuse();
|
||||
let mut close_event = self.close_event();
|
||||
let mut acquire_self = pin!(self.semaphore.acquire(1).fuse());
|
||||
let mut close_event = pin!(self.close_event());
|
||||
|
||||
if let Some(parent) = parent {
|
||||
let acquire_parent = parent.0.semaphore.acquire(1);
|
||||
let parent_close_event = parent.0.close_event();
|
||||
|
||||
futures_util::pin_mut!(
|
||||
acquire_parent,
|
||||
acquire_self,
|
||||
close_event,
|
||||
parent_close_event
|
||||
);
|
||||
let mut acquire_parent = pin!(parent.0.semaphore.acquire(1));
|
||||
let mut parent_close_event = pin!(parent.0.close_event());
|
||||
|
||||
let mut poll_parent = false;
|
||||
|
||||
|
@ -56,7 +56,7 @@
|
||||
|
||||
use std::fmt;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::pin::{pin, Pin};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::{Duration, Instant};
|
||||
@ -565,11 +565,11 @@ impl CloseEvent {
|
||||
.await
|
||||
.map_or(Ok(()), |_| Err(Error::PoolClosed))?;
|
||||
|
||||
futures_util::pin_mut!(fut);
|
||||
let mut fut = pin!(fut);
|
||||
|
||||
// I find that this is clearer in intent than `futures_util::future::select()`
|
||||
// or `futures_util::select_biased!{}` (which isn't enabled anyway).
|
||||
futures_util::future::poll_fn(|cx| {
|
||||
std::future::poll_fn(|cx| {
|
||||
// Poll `fut` first as the wakeup event is more likely for it than `self`.
|
||||
if let Poll::Ready(ret) = fut.as_mut().poll(cx) {
|
||||
return Poll::Ready(Ok(ret));
|
||||
|
@ -484,7 +484,7 @@ impl<DB: Database> PoolOptions<DB> {
|
||||
/// .await?;
|
||||
///
|
||||
/// // Close the connection if the backend memory usage exceeds 256 MiB.
|
||||
/// Ok(total_memory_usage <= (2 << 28))
|
||||
/// Ok(total_memory_usage <= (1 << 28))
|
||||
/// }))
|
||||
/// .connect("postgres:// …").await?;
|
||||
/// # Ok(())
|
||||
|
@ -323,6 +323,11 @@ where
|
||||
separated.push_unseparated(")");
|
||||
}
|
||||
|
||||
debug_assert!(
|
||||
separated.push_separator,
|
||||
"No value being pushed. QueryBuilder may not build correct sql query!"
|
||||
);
|
||||
|
||||
separated.query_builder
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,9 @@ use std::time::Duration;
|
||||
|
||||
use futures_core::future::BoxFuture;
|
||||
|
||||
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
|
||||
pub use fixtures::FixtureSnapshot;
|
||||
use sha2::{Digest, Sha512};
|
||||
|
||||
use crate::connection::{ConnectOptions, Connection};
|
||||
use crate::database::Database;
|
||||
@ -41,6 +43,17 @@ pub trait TestSupport: Database {
|
||||
/// This snapshot can then be used to generate test fixtures.
|
||||
fn snapshot(conn: &mut Self::Connection)
|
||||
-> BoxFuture<'_, Result<FixtureSnapshot<Self>, Error>>;
|
||||
|
||||
/// Generate a unique database name for the given test path.
|
||||
fn db_name(args: &TestArgs) -> String {
|
||||
let mut hasher = Sha512::new();
|
||||
hasher.update(args.test_path.as_bytes());
|
||||
let hash = hasher.finalize();
|
||||
let hash = URL_SAFE.encode(&hash[..39]);
|
||||
let db_name = format!("_sqlx_test_{}", hash).replace('-', "_");
|
||||
debug_assert!(db_name.len() == 63);
|
||||
db_name
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TestFixture {
|
||||
@ -217,7 +230,7 @@ where
|
||||
let res = test_fn(test_context.pool_opts, test_context.connect_opts).await;
|
||||
|
||||
if res.is_success() {
|
||||
if let Err(e) = DB::cleanup_test(&test_context.db_name).await {
|
||||
if let Err(e) = DB::cleanup_test(&DB::db_name(&args)).await {
|
||||
eprintln!(
|
||||
"failed to delete database {:?}: {}",
|
||||
test_context.db_name, e
|
||||
|
@ -85,6 +85,9 @@ pub mod mac_address {
|
||||
pub use json::{Json, JsonRawValue, JsonValue};
|
||||
pub use text::Text;
|
||||
|
||||
#[cfg(feature = "bstr")]
|
||||
pub use bstr::{BStr, BString};
|
||||
|
||||
/// Indicates that a SQL type is supported for a database.
|
||||
///
|
||||
/// ## Compile-time verification
|
||||
|
@ -1,8 +1,8 @@
|
||||
use proc_macro2::{Ident, Span, TokenStream};
|
||||
use quote::quote_spanned;
|
||||
use syn::{
|
||||
punctuated::Punctuated, token::Comma, Attribute, DeriveInput, Field, LitStr, Meta, Token, Type,
|
||||
Variant,
|
||||
parenthesized, punctuated::Punctuated, token::Comma, Attribute, DeriveInput, Field, LitStr,
|
||||
Meta, Token, Type, Variant,
|
||||
};
|
||||
|
||||
macro_rules! assert_attribute {
|
||||
@ -61,13 +61,18 @@ pub struct SqlxContainerAttributes {
|
||||
pub default: bool,
|
||||
}
|
||||
|
||||
pub enum JsonAttribute {
|
||||
NonNullable,
|
||||
Nullable,
|
||||
}
|
||||
|
||||
pub struct SqlxChildAttributes {
|
||||
pub rename: Option<String>,
|
||||
pub default: bool,
|
||||
pub flatten: bool,
|
||||
pub try_from: Option<Type>,
|
||||
pub skip: bool,
|
||||
pub json: bool,
|
||||
pub json: Option<JsonAttribute>,
|
||||
}
|
||||
|
||||
pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContainerAttributes> {
|
||||
@ -144,7 +149,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
|
||||
let mut try_from = None;
|
||||
let mut flatten = false;
|
||||
let mut skip: bool = false;
|
||||
let mut json = false;
|
||||
let mut json = None;
|
||||
|
||||
for attr in input.iter().filter(|a| a.path().is_ident("sqlx")) {
|
||||
attr.parse_nested_meta(|meta| {
|
||||
@ -163,13 +168,21 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
|
||||
} else if meta.path.is_ident("skip") {
|
||||
skip = true;
|
||||
} else if meta.path.is_ident("json") {
|
||||
json = true;
|
||||
if meta.input.peek(syn::token::Paren) {
|
||||
let content;
|
||||
parenthesized!(content in meta.input);
|
||||
let literal: Ident = content.parse()?;
|
||||
assert_eq!(literal.to_string(), "nullable", "Unrecognized `json` attribute. Valid values are `json` or `json(nullable)`");
|
||||
json = Some(JsonAttribute::Nullable);
|
||||
} else {
|
||||
json = Some(JsonAttribute::NonNullable);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
if json && flatten {
|
||||
if json.is_some() && flatten {
|
||||
fail!(
|
||||
attr,
|
||||
"Cannot use `json` and `flatten` together on the same field"
|
||||
|
@ -6,7 +6,7 @@ use syn::{
|
||||
};
|
||||
|
||||
use super::{
|
||||
attributes::{parse_child_attributes, parse_container_attributes},
|
||||
attributes::{parse_child_attributes, parse_container_attributes, JsonAttribute},
|
||||
rename_all,
|
||||
};
|
||||
|
||||
@ -99,7 +99,7 @@ fn expand_derive_from_row_struct(
|
||||
|
||||
let expr: Expr = match (attributes.flatten, attributes.try_from, attributes.json) {
|
||||
// <No attributes>
|
||||
(false, None, false) => {
|
||||
(false, None, None) => {
|
||||
predicates
|
||||
.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>));
|
||||
predicates.push(parse_quote!(#ty: ::sqlx::types::Type<R::Database>));
|
||||
@ -107,12 +107,12 @@ fn expand_derive_from_row_struct(
|
||||
parse_quote!(__row.try_get(#id_s))
|
||||
}
|
||||
// Flatten
|
||||
(true, None, false) => {
|
||||
(true, None, None) => {
|
||||
predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>));
|
||||
parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(__row))
|
||||
}
|
||||
// Flatten + Try from
|
||||
(true, Some(try_from), false) => {
|
||||
(true, Some(try_from), None) => {
|
||||
predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>));
|
||||
parse_quote!(
|
||||
<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(__row)
|
||||
@ -130,11 +130,11 @@ fn expand_derive_from_row_struct(
|
||||
)
|
||||
}
|
||||
// Flatten + Json
|
||||
(true, _, true) => {
|
||||
(true, _, Some(_)) => {
|
||||
panic!("Cannot use both flatten and json")
|
||||
}
|
||||
// Try from
|
||||
(false, Some(try_from), false) => {
|
||||
(false, Some(try_from), None) => {
|
||||
predicates
|
||||
.push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>));
|
||||
predicates.push(parse_quote!(#try_from: ::sqlx::types::Type<R::Database>));
|
||||
@ -154,8 +154,8 @@ fn expand_derive_from_row_struct(
|
||||
})
|
||||
)
|
||||
}
|
||||
// Try from + Json
|
||||
(false, Some(try_from), true) => {
|
||||
// Try from + Json mandatory
|
||||
(false, Some(try_from), Some(JsonAttribute::NonNullable)) => {
|
||||
predicates
|
||||
.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::decode::Decode<#lifetime, R::Database>));
|
||||
predicates.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::types::Type<R::Database>));
|
||||
@ -175,14 +175,25 @@ fn expand_derive_from_row_struct(
|
||||
})
|
||||
)
|
||||
},
|
||||
// Try from + Json nullable
|
||||
(false, Some(_), Some(JsonAttribute::Nullable)) => {
|
||||
panic!("Cannot use both try from and json nullable")
|
||||
},
|
||||
// Json
|
||||
(false, None, true) => {
|
||||
(false, None, Some(JsonAttribute::NonNullable)) => {
|
||||
predicates
|
||||
.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::decode::Decode<#lifetime, R::Database>));
|
||||
predicates.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::types::Type<R::Database>));
|
||||
|
||||
parse_quote!(__row.try_get::<::sqlx::types::Json<_>, _>(#id_s).map(|x| x.0))
|
||||
},
|
||||
(false, None, Some(JsonAttribute::Nullable)) => {
|
||||
predicates
|
||||
.push(parse_quote!(::core::option::Option<::sqlx::types::Json<#ty>>: ::sqlx::decode::Decode<#lifetime, R::Database>));
|
||||
predicates.push(parse_quote!(::core::option::Option<::sqlx::types::Json<#ty>>: ::sqlx::types::Type<R::Database>));
|
||||
|
||||
parse_quote!(__row.try_get::<::core::option::Option<::sqlx::types::Json<_>>, _>(#id_s).map(|x| x.and_then(|y| y.0)))
|
||||
},
|
||||
};
|
||||
|
||||
if attributes.default {
|
||||
|
@ -16,7 +16,7 @@ use sqlx_core::database::Database;
|
||||
use sqlx_core::describe::Describe;
|
||||
use sqlx_core::executor::Executor;
|
||||
use sqlx_core::transaction::TransactionManager;
|
||||
use std::future;
|
||||
use std::{future, pin::pin};
|
||||
|
||||
sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql);
|
||||
|
||||
@ -113,8 +113,7 @@ impl AnyConnectionBackend for MySqlConnection {
|
||||
|
||||
Box::pin(async move {
|
||||
let arguments = arguments?;
|
||||
let stream = self.run(query, arguments, persistent).await?;
|
||||
futures_util::pin_mut!(stream);
|
||||
let mut stream = pin!(self.run(query, arguments, persistent).await?);
|
||||
|
||||
while let Some(result) = stream.try_next().await? {
|
||||
if let Either::Right(row) = result {
|
||||
|
@ -21,9 +21,9 @@ use either::Either;
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_core::stream::BoxStream;
|
||||
use futures_core::Stream;
|
||||
use futures_util::{pin_mut, TryStreamExt};
|
||||
use futures_util::TryStreamExt;
|
||||
use sqlx_core::column::{ColumnOrigin, TableColumn};
|
||||
use std::{borrow::Cow, sync::Arc};
|
||||
use std::{borrow::Cow, pin::pin, sync::Arc};
|
||||
|
||||
impl MySqlConnection {
|
||||
async fn prepare_statement<'c>(
|
||||
@ -112,7 +112,7 @@ impl MySqlConnection {
|
||||
self.inner.stream.wait_until_ready().await?;
|
||||
self.inner.stream.waiting.push_back(Waiting::Result);
|
||||
|
||||
Ok(Box::pin(try_stream! {
|
||||
Ok(try_stream! {
|
||||
// make a slot for the shared column data
|
||||
// as long as a reference to a row is not held past one iteration, this enables us
|
||||
// to re-use this memory freely between result sets
|
||||
@ -241,7 +241,7 @@ impl MySqlConnection {
|
||||
r#yield!(v);
|
||||
}
|
||||
}
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -264,8 +264,7 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
|
||||
|
||||
Box::pin(try_stream! {
|
||||
let arguments = arguments?;
|
||||
let s = self.run(sql, arguments, persistent).await?;
|
||||
pin_mut!(s);
|
||||
let mut s = pin!(self.run(sql, arguments, persistent).await?);
|
||||
|
||||
while let Some(v) = s.try_next().await? {
|
||||
r#yield!(v);
|
||||
|
@ -1,29 +1,25 @@
|
||||
use std::fmt::Write;
|
||||
use std::ops::Deref;
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::{Duration, SystemTime};
|
||||
use std::time::Duration;
|
||||
|
||||
use futures_core::future::BoxFuture;
|
||||
|
||||
use once_cell::sync::OnceCell;
|
||||
|
||||
use crate::connection::Connection;
|
||||
use sqlx_core::connection::Connection;
|
||||
use sqlx_core::query_builder::QueryBuilder;
|
||||
use sqlx_core::query_scalar::query_scalar;
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::executor::Executor;
|
||||
use crate::pool::{Pool, PoolOptions};
|
||||
use crate::query::query;
|
||||
use crate::query_builder::QueryBuilder;
|
||||
use crate::query_scalar::query_scalar;
|
||||
use crate::{MySql, MySqlConnectOptions, MySqlConnection};
|
||||
|
||||
pub(crate) use sqlx_core::testing::*;
|
||||
|
||||
// Using a blocking `OnceCell` here because the critical sections are short.
|
||||
static MASTER_POOL: OnceCell<Pool<MySql>> = OnceCell::new();
|
||||
// Automatically delete any databases created before the start of the test binary.
|
||||
static DO_CLEANUP: AtomicBool = AtomicBool::new(true);
|
||||
|
||||
impl TestSupport for MySql {
|
||||
fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> {
|
||||
@ -34,21 +30,11 @@ impl TestSupport for MySql {
|
||||
Box::pin(async move {
|
||||
let mut conn = MASTER_POOL
|
||||
.get()
|
||||
.expect("cleanup_test() invoked outside `#[sqlx::test]")
|
||||
.expect("cleanup_test() invoked outside `#[sqlx::test]`")
|
||||
.acquire()
|
||||
.await?;
|
||||
|
||||
let db_id = db_id(db_name);
|
||||
|
||||
conn.execute(&format!("drop database if exists {db_name};")[..])
|
||||
.await?;
|
||||
|
||||
query("delete from _sqlx_test_databases where db_id = ?")
|
||||
.bind(db_id)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
do_cleanup(&mut conn, db_name).await
|
||||
})
|
||||
}
|
||||
|
||||
@ -58,13 +44,55 @@ impl TestSupport for MySql {
|
||||
|
||||
let mut conn = MySqlConnection::connect(&url).await?;
|
||||
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap();
|
||||
let delete_db_names: Vec<String> =
|
||||
query_scalar("select db_name from _sqlx_test_databases")
|
||||
.fetch_all(&mut conn)
|
||||
.await?;
|
||||
|
||||
if delete_db_names.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut deleted_db_names = Vec::with_capacity(delete_db_names.len());
|
||||
|
||||
let mut command = String::new();
|
||||
|
||||
for db_name in &delete_db_names {
|
||||
command.clear();
|
||||
|
||||
let db_name = format!("_sqlx_test_database_{db_name}");
|
||||
|
||||
writeln!(command, "drop database if exists {db_name:?};").ok();
|
||||
match conn.execute(&*command).await {
|
||||
Ok(_deleted) => {
|
||||
deleted_db_names.push(db_name);
|
||||
}
|
||||
// Assume a database error just means the DB is still in use.
|
||||
Err(Error::Database(dbe)) => {
|
||||
eprintln!("could not clean test database {db_name:?}: {dbe}")
|
||||
}
|
||||
// Bubble up other errors
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
if deleted_db_names.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut query =
|
||||
QueryBuilder::new("delete from _sqlx_test_databases where db_name in (");
|
||||
|
||||
let mut separated = query.separated(",");
|
||||
|
||||
for db_name in &deleted_db_names {
|
||||
separated.push_bind(db_name);
|
||||
}
|
||||
|
||||
query.push(")").build().execute(&mut conn).await?;
|
||||
|
||||
let num_deleted = do_cleanup(&mut conn, now).await?;
|
||||
let _ = conn.close().await;
|
||||
Ok(Some(num_deleted))
|
||||
Ok(Some(delete_db_names.len()))
|
||||
})
|
||||
}
|
||||
|
||||
@ -117,7 +145,7 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<MySql>, Error> {
|
||||
conn.execute(
|
||||
r#"
|
||||
create table if not exists _sqlx_test_databases (
|
||||
db_id bigint unsigned primary key auto_increment,
|
||||
db_name text primary key,
|
||||
test_path text not null,
|
||||
created_at timestamp not null default current_timestamp
|
||||
);
|
||||
@ -125,34 +153,19 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<MySql>, Error> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Record the current time _before_ we acquire the `DO_CLEANUP` permit. This
|
||||
// prevents the first test thread from accidentally deleting new test dbs
|
||||
// created by other test threads if we're a bit slow.
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap();
|
||||
let db_name = MySql::db_name(args);
|
||||
do_cleanup(&mut conn, &db_name).await?;
|
||||
|
||||
// Only run cleanup if the test binary just started.
|
||||
if DO_CLEANUP.swap(false, Ordering::SeqCst) {
|
||||
do_cleanup(&mut conn, now).await?;
|
||||
}
|
||||
|
||||
query("insert into _sqlx_test_databases(test_path) values (?)")
|
||||
query("insert into _sqlx_test_databases(db_name, test_path) values (?, ?)")
|
||||
.bind(&db_name)
|
||||
.bind(args.test_path)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
|
||||
// MySQL doesn't have `INSERT ... RETURNING`
|
||||
let new_db_id: u64 = query_scalar("select last_insert_id()")
|
||||
.fetch_one(&mut *conn)
|
||||
conn.execute(&format!("create database {db_name:?}")[..])
|
||||
.await?;
|
||||
|
||||
let new_db_name = db_name(new_db_id);
|
||||
|
||||
conn.execute(&format!("create database {new_db_name}")[..])
|
||||
.await?;
|
||||
|
||||
eprintln!("created database {new_db_name}");
|
||||
eprintln!("created database {db_name}");
|
||||
|
||||
Ok(TestContext {
|
||||
pool_opts: PoolOptions::new()
|
||||
@ -167,74 +180,18 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<MySql>, Error> {
|
||||
.connect_options()
|
||||
.deref()
|
||||
.clone()
|
||||
.database(&new_db_name),
|
||||
db_name: new_db_name,
|
||||
.database(&db_name),
|
||||
db_name,
|
||||
})
|
||||
}
|
||||
|
||||
async fn do_cleanup(conn: &mut MySqlConnection, created_before: Duration) -> Result<usize, Error> {
|
||||
// since SystemTime is not monotonic we added a little margin here to avoid race conditions with other threads
|
||||
let created_before_as_secs = created_before.as_secs() - 2;
|
||||
let delete_db_ids: Vec<u64> = query_scalar(
|
||||
"select db_id from _sqlx_test_databases \
|
||||
where created_at < from_unixtime(?)",
|
||||
)
|
||||
.bind(created_before_as_secs)
|
||||
.fetch_all(&mut *conn)
|
||||
.await?;
|
||||
async fn do_cleanup(conn: &mut MySqlConnection, db_name: &str) -> Result<(), Error> {
|
||||
let delete_db_command = format!("drop database if exists {db_name:?};");
|
||||
conn.execute(&*delete_db_command).await?;
|
||||
query("delete from _sqlx_test.databases where db_name = $1::text")
|
||||
.bind(db_name)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
|
||||
if delete_db_ids.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let mut deleted_db_ids = Vec::with_capacity(delete_db_ids.len());
|
||||
|
||||
let mut command = String::new();
|
||||
|
||||
for db_id in delete_db_ids {
|
||||
command.clear();
|
||||
|
||||
let db_name = db_name(db_id);
|
||||
|
||||
writeln!(command, "drop database if exists {db_name}").ok();
|
||||
match conn.execute(&*command).await {
|
||||
Ok(_deleted) => {
|
||||
deleted_db_ids.push(db_id);
|
||||
}
|
||||
// Assume a database error just means the DB is still in use.
|
||||
Err(Error::Database(dbe)) => {
|
||||
eprintln!("could not clean test database {db_id:?}: {dbe}")
|
||||
}
|
||||
// Bubble up other errors
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
let mut query = QueryBuilder::new("delete from _sqlx_test_databases where db_id in (");
|
||||
|
||||
let mut separated = query.separated(",");
|
||||
|
||||
for db_id in &deleted_db_ids {
|
||||
separated.push_bind(db_id);
|
||||
}
|
||||
|
||||
query.push(")").build().execute(&mut *conn).await?;
|
||||
|
||||
Ok(deleted_db_ids.len())
|
||||
}
|
||||
|
||||
fn db_name(id: u64) -> String {
|
||||
format!("_sqlx_test_database_{id}")
|
||||
}
|
||||
|
||||
fn db_id(name: &str) -> u64 {
|
||||
name.trim_start_matches("_sqlx_test_database_")
|
||||
.parse()
|
||||
.unwrap_or_else(|_1| panic!("failed to parse ID from database name {name:?}"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_db_name_id() {
|
||||
assert_eq!(db_name(12345), "_sqlx_test_database_12345");
|
||||
assert_eq!(db_id("_sqlx_test_database_12345"), 12345);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ use crate::{
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_core::stream::BoxStream;
|
||||
use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt};
|
||||
use std::future;
|
||||
use std::{future, pin::pin};
|
||||
|
||||
use sqlx_core::any::{
|
||||
Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow,
|
||||
@ -115,8 +115,7 @@ impl AnyConnectionBackend for PgConnection {
|
||||
|
||||
Box::pin(async move {
|
||||
let arguments = arguments?;
|
||||
let stream = self.run(query, arguments, 1, persistent, None).await?;
|
||||
futures_util::pin_mut!(stream);
|
||||
let mut stream = pin!(self.run(query, arguments, 1, persistent, None).await?);
|
||||
|
||||
if let Some(Either::Right(row)) = stream.try_next().await? {
|
||||
return Ok(Some(AnyRow::try_from(&row)?));
|
||||
|
@ -22,7 +22,7 @@ use sqlx_core::error::BoxDynError;
|
||||
// that has a patch, we then apply the patch which should write to &mut Vec<u8>,
|
||||
// backtrack and update the prefixed-len, then write until the next patch offset
|
||||
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct PgArgumentBuffer {
|
||||
buffer: Vec<u8>,
|
||||
|
||||
@ -46,20 +46,32 @@ pub struct PgArgumentBuffer {
|
||||
type_holes: Vec<(usize, HoleKind)>, // Vec<{ offset, type_name }>
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum HoleKind {
|
||||
Type { name: UStr },
|
||||
Array(Arc<PgArrayOf>),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Patch {
|
||||
buf_offset: usize,
|
||||
arg_index: usize,
|
||||
#[allow(clippy::type_complexity)]
|
||||
callback: Box<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
|
||||
callback: Arc<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for Patch {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Patch")
|
||||
.field("buf_offset", &self.buf_offset)
|
||||
.field("arg_index", &self.arg_index)
|
||||
.field("callback", &"<callback>")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of [`Arguments`] for PostgreSQL.
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct PgArguments {
|
||||
// Types of each bind parameter
|
||||
pub(crate) types: Vec<PgTypeInfo>,
|
||||
@ -194,7 +206,7 @@ impl PgArgumentBuffer {
|
||||
self.patches.push(Patch {
|
||||
buf_offset: offset,
|
||||
arg_index,
|
||||
callback: Box::new(callback),
|
||||
callback: Arc::new(callback),
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -15,10 +15,10 @@ use crate::{
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_core::stream::BoxStream;
|
||||
use futures_core::Stream;
|
||||
use futures_util::{pin_mut, TryStreamExt};
|
||||
use futures_util::TryStreamExt;
|
||||
use sqlx_core::arguments::Arguments;
|
||||
use sqlx_core::Either;
|
||||
use std::{borrow::Cow, sync::Arc};
|
||||
use std::{borrow::Cow, pin::pin, sync::Arc};
|
||||
|
||||
async fn prepare(
|
||||
conn: &mut PgConnection,
|
||||
@ -395,8 +395,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection {
|
||||
|
||||
Box::pin(try_stream! {
|
||||
let arguments = arguments?;
|
||||
let s = self.run(sql, arguments, 0, persistent, metadata).await?;
|
||||
pin_mut!(s);
|
||||
let mut s = pin!(self.run(sql, arguments, 0, persistent, metadata).await?);
|
||||
|
||||
while let Some(v) = s.try_next().await? {
|
||||
r#yield!(v);
|
||||
@ -422,8 +421,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection {
|
||||
|
||||
Box::pin(async move {
|
||||
let arguments = arguments?;
|
||||
let s = self.run(sql, arguments, 1, persistent, metadata).await?;
|
||||
pin_mut!(s);
|
||||
let mut s = pin!(self.run(sql, arguments, 1, persistent, metadata).await?);
|
||||
|
||||
// With deferred constraints we need to check all responses as we
|
||||
// could get a OK response (with uncommitted data), only to get an
|
||||
|
@ -129,6 +129,9 @@ impl PgPoolCopyExt for Pool<Postgres> {
|
||||
}
|
||||
}
|
||||
|
||||
// (1 GiB - 1) - 1 - length prefix (4 bytes)
|
||||
pub const PG_COPY_MAX_DATA_LEN: usize = 0x3fffffff - 1 - 4;
|
||||
|
||||
/// A connection in streaming `COPY FROM STDIN` mode.
|
||||
///
|
||||
/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
|
||||
@ -186,15 +189,20 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
|
||||
|
||||
/// Send a chunk of `COPY` data.
|
||||
///
|
||||
/// The data is sent in chunks if it exceeds the maximum length of a `CopyData` message (1 GiB - 6
|
||||
/// bytes) and may be partially sent if this call is cancelled.
|
||||
///
|
||||
/// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
|
||||
pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
|
||||
self.conn
|
||||
.as_deref_mut()
|
||||
.expect("send_data: conn taken")
|
||||
.inner
|
||||
.stream
|
||||
.send(CopyData(data))
|
||||
.await?;
|
||||
for chunk in data.deref().chunks(PG_COPY_MAX_DATA_LEN) {
|
||||
self.conn
|
||||
.as_deref_mut()
|
||||
.expect("send_data: conn taken")
|
||||
.inner
|
||||
.stream
|
||||
.send(CopyData(chunk))
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
@ -230,10 +238,10 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
|
||||
}
|
||||
|
||||
// Write the length
|
||||
let read32 = u32::try_from(read)
|
||||
.map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
|
||||
let read32 = i32::try_from(read)
|
||||
.map_err(|_| err_protocol!("number of bytes read exceeds 2^31 - 1: {}", read))?;
|
||||
|
||||
(&mut buf.get_mut()[1..]).put_u32(read32 + 4);
|
||||
(&mut buf.get_mut()[1..]).put_i32(read32 + 4);
|
||||
|
||||
conn.inner.stream.flush().await?;
|
||||
}
|
||||
|
@ -34,6 +34,9 @@ mod value;
|
||||
#[doc(hidden)]
|
||||
pub mod any;
|
||||
|
||||
#[doc(hidden)]
|
||||
pub use copy::PG_COPY_MAX_DATA_LEN;
|
||||
|
||||
#[cfg(feature = "migrate")]
|
||||
mod migrate;
|
||||
|
||||
|
@ -1,20 +1,18 @@
|
||||
use std::fmt::Write;
|
||||
use std::ops::Deref;
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::{Duration, SystemTime};
|
||||
use std::time::Duration;
|
||||
|
||||
use futures_core::future::BoxFuture;
|
||||
|
||||
use once_cell::sync::OnceCell;
|
||||
|
||||
use crate::connection::Connection;
|
||||
use sqlx_core::connection::Connection;
|
||||
use sqlx_core::query_scalar::query_scalar;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::executor::Executor;
|
||||
use crate::pool::{Pool, PoolOptions};
|
||||
use crate::query::query;
|
||||
use crate::query_scalar::query_scalar;
|
||||
use crate::{PgConnectOptions, PgConnection, Postgres};
|
||||
|
||||
pub(crate) use sqlx_core::testing::*;
|
||||
@ -22,7 +20,6 @@ pub(crate) use sqlx_core::testing::*;
|
||||
// Using a blocking `OnceCell` here because the critical sections are short.
|
||||
static MASTER_POOL: OnceCell<Pool<Postgres>> = OnceCell::new();
|
||||
// Automatically delete any databases created before the start of the test binary.
|
||||
static DO_CLEANUP: AtomicBool = AtomicBool::new(true);
|
||||
|
||||
impl TestSupport for Postgres {
|
||||
fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> {
|
||||
@ -33,19 +30,11 @@ impl TestSupport for Postgres {
|
||||
Box::pin(async move {
|
||||
let mut conn = MASTER_POOL
|
||||
.get()
|
||||
.expect("cleanup_test() invoked outside `#[sqlx::test]")
|
||||
.expect("cleanup_test() invoked outside `#[sqlx::test]`")
|
||||
.acquire()
|
||||
.await?;
|
||||
|
||||
conn.execute(&format!("drop database if exists {db_name:?};")[..])
|
||||
.await?;
|
||||
|
||||
query("delete from _sqlx_test.databases where db_name = $1")
|
||||
.bind(db_name)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
do_cleanup(&mut conn, db_name).await
|
||||
})
|
||||
}
|
||||
|
||||
@ -55,13 +44,42 @@ impl TestSupport for Postgres {
|
||||
|
||||
let mut conn = PgConnection::connect(&url).await?;
|
||||
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap();
|
||||
let delete_db_names: Vec<String> =
|
||||
query_scalar("select db_name from _sqlx_test.databases")
|
||||
.fetch_all(&mut conn)
|
||||
.await?;
|
||||
|
||||
if delete_db_names.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut deleted_db_names = Vec::with_capacity(delete_db_names.len());
|
||||
|
||||
let mut command = String::new();
|
||||
|
||||
for db_name in &delete_db_names {
|
||||
command.clear();
|
||||
writeln!(command, "drop database if exists {db_name:?};").ok();
|
||||
match conn.execute(&*command).await {
|
||||
Ok(_deleted) => {
|
||||
deleted_db_names.push(db_name);
|
||||
}
|
||||
// Assume a database error just means the DB is still in use.
|
||||
Err(Error::Database(dbe)) => {
|
||||
eprintln!("could not clean test database {db_name:?}: {dbe}")
|
||||
}
|
||||
// Bubble up other errors
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
query("delete from _sqlx_test.databases where db_name = any($1::text[])")
|
||||
.bind(&deleted_db_names)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
let num_deleted = do_cleanup(&mut conn, now).await?;
|
||||
let _ = conn.close().await;
|
||||
Ok(Some(num_deleted))
|
||||
Ok(Some(delete_db_names.len()))
|
||||
})
|
||||
}
|
||||
|
||||
@ -116,8 +134,9 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
|
||||
// I couldn't find a bug on the mailing list for `CREATE SCHEMA` specifically,
|
||||
// but a clearly related bug with `CREATE TABLE` has been known since 2007:
|
||||
// https://www.postgresql.org/message-id/200710222037.l9MKbCJZ098744%40wwwmaster.postgresql.org
|
||||
// magic constant 8318549251334697844 is just 8 ascii bytes 'sqlxtest'.
|
||||
r#"
|
||||
lock table pg_catalog.pg_namespace in share row exclusive mode;
|
||||
select pg_advisory_xact_lock(8318549251334697844);
|
||||
|
||||
create schema if not exists _sqlx_test;
|
||||
|
||||
@ -135,31 +154,22 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Record the current time _before_ we acquire the `DO_CLEANUP` permit. This
|
||||
// prevents the first test thread from accidentally deleting new test dbs
|
||||
// created by other test threads if we're a bit slow.
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap();
|
||||
let db_name = Postgres::db_name(args);
|
||||
do_cleanup(&mut conn, &db_name).await?;
|
||||
|
||||
// Only run cleanup if the test binary just started.
|
||||
if DO_CLEANUP.swap(false, Ordering::SeqCst) {
|
||||
do_cleanup(&mut conn, now).await?;
|
||||
}
|
||||
|
||||
let new_db_name: String = query_scalar(
|
||||
query(
|
||||
r#"
|
||||
insert into _sqlx_test.databases(db_name, test_path)
|
||||
select '_sqlx_test_' || nextval('_sqlx_test.database_ids'), $1
|
||||
returning db_name
|
||||
insert into _sqlx_test.databases(db_name, test_path) values ($1, $2)
|
||||
"#,
|
||||
)
|
||||
.bind(&db_name)
|
||||
.bind(args.test_path)
|
||||
.fetch_one(&mut *conn)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
|
||||
conn.execute(&format!("create database {new_db_name:?}")[..])
|
||||
.await?;
|
||||
let create_command = format!("create database {db_name:?}");
|
||||
debug_assert!(create_command.starts_with("create database \""));
|
||||
conn.execute(&(create_command)[..]).await?;
|
||||
|
||||
Ok(TestContext {
|
||||
pool_opts: PoolOptions::new()
|
||||
@ -174,52 +184,18 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
|
||||
.connect_options()
|
||||
.deref()
|
||||
.clone()
|
||||
.database(&new_db_name),
|
||||
db_name: new_db_name,
|
||||
.database(&db_name),
|
||||
db_name,
|
||||
})
|
||||
}
|
||||
|
||||
async fn do_cleanup(conn: &mut PgConnection, created_before: Duration) -> Result<usize, Error> {
|
||||
// since SystemTime is not monotonic we added a little margin here to avoid race conditions with other threads
|
||||
let created_before = i64::try_from(created_before.as_secs()).unwrap() - 2;
|
||||
|
||||
let delete_db_names: Vec<String> = query_scalar(
|
||||
"select db_name from _sqlx_test.databases \
|
||||
where created_at < (to_timestamp($1) at time zone 'UTC')",
|
||||
)
|
||||
.bind(created_before)
|
||||
.fetch_all(&mut *conn)
|
||||
.await?;
|
||||
|
||||
if delete_db_names.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let mut deleted_db_names = Vec::with_capacity(delete_db_names.len());
|
||||
let delete_db_names = delete_db_names.into_iter();
|
||||
|
||||
let mut command = String::new();
|
||||
|
||||
for db_name in delete_db_names {
|
||||
command.clear();
|
||||
writeln!(command, "drop database if exists {db_name:?};").ok();
|
||||
match conn.execute(&*command).await {
|
||||
Ok(_deleted) => {
|
||||
deleted_db_names.push(db_name);
|
||||
}
|
||||
// Assume a database error just means the DB is still in use.
|
||||
Err(Error::Database(dbe)) => {
|
||||
eprintln!("could not clean test database {db_name:?}: {dbe}")
|
||||
}
|
||||
// Bubble up other errors
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
query("delete from _sqlx_test.databases where db_name = any($1::text[])")
|
||||
.bind(&deleted_db_names)
|
||||
async fn do_cleanup(conn: &mut PgConnection, db_name: &str) -> Result<(), Error> {
|
||||
let delete_db_command = format!("drop database if exists {db_name:?};");
|
||||
conn.execute(&*delete_db_command).await?;
|
||||
query("delete from _sqlx_test.databases where db_name = $1::text")
|
||||
.bind(db_name)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
|
||||
Ok(deleted_db_names.len())
|
||||
Ok(())
|
||||
}
|
||||
|
@ -36,6 +36,10 @@ impl_type_checking!(
|
||||
|
||||
sqlx::postgres::types::PgLine,
|
||||
|
||||
sqlx::postgres::types::PgLSeg,
|
||||
|
||||
sqlx::postgres::types::PgBox,
|
||||
|
||||
#[cfg(feature = "uuid")]
|
||||
sqlx::types::Uuid,
|
||||
|
||||
|
@ -185,7 +185,7 @@ pub enum PgTypeKind {
|
||||
Range(PgTypeInfo),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct PgArrayOf {
|
||||
pub(crate) elem_name: UStr,
|
||||
|
@ -20,7 +20,7 @@ const IS_POINT_FLAG: u32 = 1 << 31;
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum PgCube {
|
||||
/// A one-dimensional point.
|
||||
// FIXME: `Point1D(f64)
|
||||
// FIXME: `Point1D(f64)`
|
||||
Point(f64),
|
||||
/// An N-dimensional point ("represented internally as a zero-volume cube").
|
||||
// FIXME: `PointND(f64)`
|
||||
@ -32,7 +32,7 @@ pub enum PgCube {
|
||||
|
||||
// FIXME: add `Cube3D { lower_left: [f64; 3], upper_right: [f64; 3] }`?
|
||||
/// An N-dimensional cube with points representing lower-left and upper-right corners, respectively.
|
||||
// FIXME: CubeND { lower_left: Vec<f64>, upper_right: Vec<f64> }`
|
||||
// FIXME: `CubeND { lower_left: Vec<f64>, upper_right: Vec<f64> }`
|
||||
MultiDimension(Vec<Vec<f64>>),
|
||||
}
|
||||
|
||||
|
321
sqlx-postgres/src/types/geometry/box.rs
Normal file
321
sqlx-postgres/src/types/geometry/box.rs
Normal file
@ -0,0 +1,321 @@
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::types::Type;
|
||||
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
|
||||
use sqlx_core::bytes::Buf;
|
||||
use std::str::FromStr;
|
||||
|
||||
const ERROR: &str = "error decoding BOX";
|
||||
|
||||
/// ## Postgres Geometric Box type
|
||||
///
|
||||
/// Description: Rectangular box
|
||||
/// Representation: `((upper_right_x,upper_right_y),(lower_left_x,lower_left_y))`
|
||||
///
|
||||
/// Boxes are represented by pairs of points that are opposite corners of the box. Values of type box are specified using any of the following syntaxes:
|
||||
///
|
||||
/// ```text
|
||||
/// ( ( upper_right_x , upper_right_y ) , ( lower_left_x , lower_left_y ) )
|
||||
/// ( upper_right_x , upper_right_y ) , ( lower_left_x , lower_left_y )
|
||||
/// upper_right_x , upper_right_y , lower_left_x , lower_left_y
|
||||
/// ```
|
||||
/// where `(upper_right_x,upper_right_y) and (lower_left_x,lower_left_y)` are any two opposite corners of the box.
|
||||
/// Any two opposite corners can be supplied on input, but the values will be reordered as needed to store the upper right and lower left corners, in that order.
|
||||
///
|
||||
/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-BOXES
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct PgBox {
|
||||
pub upper_right_x: f64,
|
||||
pub upper_right_y: f64,
|
||||
pub lower_left_x: f64,
|
||||
pub lower_left_y: f64,
|
||||
}
|
||||
|
||||
impl Type<Postgres> for PgBox {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::with_name("box")
|
||||
}
|
||||
}
|
||||
|
||||
impl PgHasArrayType for PgBox {
|
||||
fn array_type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::with_name("_box")
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> Decode<'r, Postgres> for PgBox {
|
||||
fn decode(value: PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match value.format() {
|
||||
PgValueFormat::Text => Ok(PgBox::from_str(value.as_str()?)?),
|
||||
PgValueFormat::Binary => Ok(PgBox::from_bytes(value.as_bytes()?)?),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'q> Encode<'q, Postgres> for PgBox {
|
||||
fn produces(&self) -> Option<PgTypeInfo> {
|
||||
Some(PgTypeInfo::with_name("box"))
|
||||
}
|
||||
|
||||
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
|
||||
self.serialize(buf)?;
|
||||
Ok(IsNull::No)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for PgBox {
|
||||
type Err = BoxDynError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let sanitised = s.replace(['(', ')', '[', ']', ' '], "");
|
||||
let mut parts = sanitised.split(',');
|
||||
|
||||
let upper_right_x = parts
|
||||
.next()
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
.ok_or_else(|| format!("{}: could not get upper_right_x from {}", ERROR, s))?;
|
||||
|
||||
let upper_right_y = parts
|
||||
.next()
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
.ok_or_else(|| format!("{}: could not get upper_right_y from {}", ERROR, s))?;
|
||||
|
||||
let lower_left_x = parts
|
||||
.next()
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
.ok_or_else(|| format!("{}: could not get lower_left_x from {}", ERROR, s))?;
|
||||
|
||||
let lower_left_y = parts
|
||||
.next()
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
.ok_or_else(|| format!("{}: could not get lower_left_y from {}", ERROR, s))?;
|
||||
|
||||
if parts.next().is_some() {
|
||||
return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into());
|
||||
}
|
||||
|
||||
Ok(PgBox {
|
||||
upper_right_x,
|
||||
upper_right_y,
|
||||
lower_left_x,
|
||||
lower_left_y,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PgBox {
|
||||
fn from_bytes(mut bytes: &[u8]) -> Result<PgBox, BoxDynError> {
|
||||
let upper_right_x = bytes.get_f64();
|
||||
let upper_right_y = bytes.get_f64();
|
||||
let lower_left_x = bytes.get_f64();
|
||||
let lower_left_y = bytes.get_f64();
|
||||
|
||||
Ok(PgBox {
|
||||
upper_right_x,
|
||||
upper_right_y,
|
||||
lower_left_x,
|
||||
lower_left_y,
|
||||
})
|
||||
}
|
||||
|
||||
fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> {
|
||||
let min_x = &self.upper_right_x.min(self.lower_left_x);
|
||||
let min_y = &self.upper_right_y.min(self.lower_left_y);
|
||||
let max_x = &self.upper_right_x.max(self.lower_left_x);
|
||||
let max_y = &self.upper_right_y.max(self.lower_left_y);
|
||||
|
||||
buff.extend_from_slice(&max_x.to_be_bytes());
|
||||
buff.extend_from_slice(&max_y.to_be_bytes());
|
||||
buff.extend_from_slice(&min_x.to_be_bytes());
|
||||
buff.extend_from_slice(&min_y.to_be_bytes());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn serialize_to_vec(&self) -> Vec<u8> {
|
||||
let mut buff = PgArgumentBuffer::default();
|
||||
self.serialize(&mut buff).unwrap();
|
||||
buff.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod box_tests {
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use super::PgBox;
|
||||
|
||||
const BOX_BYTES: &[u8] = &[
|
||||
64, 0, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_box_type_bytes_in_order() {
|
||||
let pg_box = PgBox::from_bytes(BOX_BYTES).unwrap();
|
||||
assert_eq!(
|
||||
pg_box,
|
||||
PgBox {
|
||||
upper_right_x: 2.,
|
||||
upper_right_y: 2.,
|
||||
lower_left_x: -2.,
|
||||
lower_left_y: -2.
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_box_type_str_first_syntax() {
|
||||
let pg_box = PgBox::from_str("[( 1, 2), (3, 4 )]").unwrap();
|
||||
assert_eq!(
|
||||
pg_box,
|
||||
PgBox {
|
||||
upper_right_x: 1.,
|
||||
upper_right_y: 2.,
|
||||
lower_left_x: 3.,
|
||||
lower_left_y: 4.
|
||||
}
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn can_deserialise_box_type_str_second_syntax() {
|
||||
let pg_box = PgBox::from_str("(( 1, 2), (3, 4 ))").unwrap();
|
||||
assert_eq!(
|
||||
pg_box,
|
||||
PgBox {
|
||||
upper_right_x: 1.,
|
||||
upper_right_y: 2.,
|
||||
lower_left_x: 3.,
|
||||
lower_left_y: 4.
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_box_type_str_third_syntax() {
|
||||
let pg_box = PgBox::from_str("(1, 2), (3, 4 )").unwrap();
|
||||
assert_eq!(
|
||||
pg_box,
|
||||
PgBox {
|
||||
upper_right_x: 1.,
|
||||
upper_right_y: 2.,
|
||||
lower_left_x: 3.,
|
||||
lower_left_y: 4.
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_box_type_str_fourth_syntax() {
|
||||
let pg_box = PgBox::from_str("1, 2, 3, 4").unwrap();
|
||||
assert_eq!(
|
||||
pg_box,
|
||||
PgBox {
|
||||
upper_right_x: 1.,
|
||||
upper_right_y: 2.,
|
||||
lower_left_x: 3.,
|
||||
lower_left_y: 4.
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cannot_deserialise_too_many_numbers() {
|
||||
let input_str = "1, 2, 3, 4, 5";
|
||||
let pg_box = PgBox::from_str(input_str);
|
||||
assert!(pg_box.is_err());
|
||||
if let Err(err) = pg_box {
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!("error decoding BOX: too many numbers inputted in {input_str}")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cannot_deserialise_too_few_numbers() {
|
||||
let input_str = "1, 2, 3 ";
|
||||
let pg_box = PgBox::from_str(input_str);
|
||||
assert!(pg_box.is_err());
|
||||
if let Err(err) = pg_box {
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!("error decoding BOX: could not get lower_left_y from {input_str}")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cannot_deserialise_invalid_numbers() {
|
||||
let input_str = "1, 2, 3, FOUR";
|
||||
let pg_box = PgBox::from_str(input_str);
|
||||
assert!(pg_box.is_err());
|
||||
if let Err(err) = pg_box {
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!("error decoding BOX: could not get lower_left_y from {input_str}")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_box_type_str_float() {
|
||||
let pg_box = PgBox::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap();
|
||||
assert_eq!(
|
||||
pg_box,
|
||||
PgBox {
|
||||
upper_right_x: 1.1,
|
||||
upper_right_y: 2.2,
|
||||
lower_left_x: 3.3,
|
||||
lower_left_y: 4.4
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_serialise_box_type_in_order() {
|
||||
let pg_box = PgBox {
|
||||
upper_right_x: 2.,
|
||||
lower_left_x: -2.,
|
||||
upper_right_y: -2.,
|
||||
lower_left_y: 2.,
|
||||
};
|
||||
assert_eq!(pg_box.serialize_to_vec(), BOX_BYTES,)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_serialise_box_type_out_of_order() {
|
||||
let pg_box = PgBox {
|
||||
upper_right_x: -2.,
|
||||
lower_left_x: 2.,
|
||||
upper_right_y: 2.,
|
||||
lower_left_y: -2.,
|
||||
};
|
||||
assert_eq!(pg_box.serialize_to_vec(), BOX_BYTES,)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_order_box() {
|
||||
let pg_box = PgBox {
|
||||
upper_right_x: -2.,
|
||||
lower_left_x: 2.,
|
||||
upper_right_y: 2.,
|
||||
lower_left_y: -2.,
|
||||
};
|
||||
let bytes = pg_box.serialize_to_vec();
|
||||
|
||||
let pg_box = PgBox::from_bytes(&bytes).unwrap();
|
||||
assert_eq!(
|
||||
pg_box,
|
||||
PgBox {
|
||||
upper_right_x: 2.,
|
||||
upper_right_y: 2.,
|
||||
lower_left_x: -2.,
|
||||
lower_left_y: -2.
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
283
sqlx-postgres/src/types/geometry/line_segment.rs
Normal file
283
sqlx-postgres/src/types/geometry/line_segment.rs
Normal file
@ -0,0 +1,283 @@
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::types::Type;
|
||||
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
|
||||
use sqlx_core::bytes::Buf;
|
||||
use std::str::FromStr;
|
||||
|
||||
const ERROR: &str = "error decoding LSEG";
|
||||
|
||||
/// ## Postgres Geometric Line Segment type
|
||||
///
|
||||
/// Description: Finite line segment
|
||||
/// Representation: `((start_x,start_y),(end_x,end_y))`
|
||||
///
|
||||
///
|
||||
/// Line segments are represented by pairs of points that are the endpoints of the segment. Values of type lseg are specified using any of the following syntaxes:
|
||||
/// ```text
|
||||
/// [ ( start_x , start_y ) , ( end_x , end_y ) ]
|
||||
/// ( ( start_x , start_y ) , ( end_x , end_y ) )
|
||||
/// ( start_x , start_y ) , ( end_x , end_y )
|
||||
/// start_x , start_y , end_x , end_y
|
||||
/// ```
|
||||
/// where `(start_x,start_y) and (end_x,end_y)` are the end points of the line segment.
|
||||
///
|
||||
/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LSEG
|
||||
#[doc(alias = "line segment")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct PgLSeg {
|
||||
pub start_x: f64,
|
||||
pub start_y: f64,
|
||||
pub end_x: f64,
|
||||
pub end_y: f64,
|
||||
}
|
||||
|
||||
impl Type<Postgres> for PgLSeg {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::with_name("lseg")
|
||||
}
|
||||
}
|
||||
|
||||
impl PgHasArrayType for PgLSeg {
|
||||
fn array_type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::with_name("_lseg")
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> Decode<'r, Postgres> for PgLSeg {
|
||||
fn decode(value: PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
match value.format() {
|
||||
PgValueFormat::Text => Ok(PgLSeg::from_str(value.as_str()?)?),
|
||||
PgValueFormat::Binary => Ok(PgLSeg::from_bytes(value.as_bytes()?)?),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'q> Encode<'q, Postgres> for PgLSeg {
|
||||
fn produces(&self) -> Option<PgTypeInfo> {
|
||||
Some(PgTypeInfo::with_name("lseg"))
|
||||
}
|
||||
|
||||
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
|
||||
self.serialize(buf)?;
|
||||
Ok(IsNull::No)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for PgLSeg {
|
||||
type Err = BoxDynError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let sanitised = s.replace(['(', ')', '[', ']', ' '], "");
|
||||
let mut parts = sanitised.split(',');
|
||||
|
||||
let start_x = parts
|
||||
.next()
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
.ok_or_else(|| format!("{}: could not get start_x from {}", ERROR, s))?;
|
||||
|
||||
let start_y = parts
|
||||
.next()
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
.ok_or_else(|| format!("{}: could not get start_y from {}", ERROR, s))?;
|
||||
|
||||
let end_x = parts
|
||||
.next()
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
.ok_or_else(|| format!("{}: could not get end_x from {}", ERROR, s))?;
|
||||
|
||||
let end_y = parts
|
||||
.next()
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
.ok_or_else(|| format!("{}: could not get end_y from {}", ERROR, s))?;
|
||||
|
||||
if parts.next().is_some() {
|
||||
return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into());
|
||||
}
|
||||
|
||||
Ok(PgLSeg {
|
||||
start_x,
|
||||
start_y,
|
||||
end_x,
|
||||
end_y,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PgLSeg {
|
||||
fn from_bytes(mut bytes: &[u8]) -> Result<PgLSeg, BoxDynError> {
|
||||
let start_x = bytes.get_f64();
|
||||
let start_y = bytes.get_f64();
|
||||
let end_x = bytes.get_f64();
|
||||
let end_y = bytes.get_f64();
|
||||
|
||||
Ok(PgLSeg {
|
||||
start_x,
|
||||
start_y,
|
||||
end_x,
|
||||
end_y,
|
||||
})
|
||||
}
|
||||
|
||||
fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> {
|
||||
buff.extend_from_slice(&self.start_x.to_be_bytes());
|
||||
buff.extend_from_slice(&self.start_y.to_be_bytes());
|
||||
buff.extend_from_slice(&self.end_x.to_be_bytes());
|
||||
buff.extend_from_slice(&self.end_y.to_be_bytes());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn serialize_to_vec(&self) -> Vec<u8> {
|
||||
let mut buff = PgArgumentBuffer::default();
|
||||
self.serialize(&mut buff).unwrap();
|
||||
buff.to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod lseg_tests {
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use super::PgLSeg;
|
||||
|
||||
const LINE_SEGMENT_BYTES: &[u8] = &[
|
||||
63, 241, 153, 153, 153, 153, 153, 154, 64, 1, 153, 153, 153, 153, 153, 154, 64, 10, 102,
|
||||
102, 102, 102, 102, 102, 64, 17, 153, 153, 153, 153, 153, 154,
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_lseg_type_bytes() {
|
||||
let lseg = PgLSeg::from_bytes(LINE_SEGMENT_BYTES).unwrap();
|
||||
assert_eq!(
|
||||
lseg,
|
||||
PgLSeg {
|
||||
start_x: 1.1,
|
||||
start_y: 2.2,
|
||||
end_x: 3.3,
|
||||
end_y: 4.4
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_lseg_type_str_first_syntax() {
|
||||
let lseg = PgLSeg::from_str("[( 1, 2), (3, 4 )]").unwrap();
|
||||
assert_eq!(
|
||||
lseg,
|
||||
PgLSeg {
|
||||
start_x: 1.,
|
||||
start_y: 2.,
|
||||
end_x: 3.,
|
||||
end_y: 4.
|
||||
}
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn can_deserialise_lseg_type_str_second_syntax() {
|
||||
let lseg = PgLSeg::from_str("(( 1, 2), (3, 4 ))").unwrap();
|
||||
assert_eq!(
|
||||
lseg,
|
||||
PgLSeg {
|
||||
start_x: 1.,
|
||||
start_y: 2.,
|
||||
end_x: 3.,
|
||||
end_y: 4.
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_lseg_type_str_third_syntax() {
|
||||
let lseg = PgLSeg::from_str("(1, 2), (3, 4 )").unwrap();
|
||||
assert_eq!(
|
||||
lseg,
|
||||
PgLSeg {
|
||||
start_x: 1.,
|
||||
start_y: 2.,
|
||||
end_x: 3.,
|
||||
end_y: 4.
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_lseg_type_str_fourth_syntax() {
|
||||
let lseg = PgLSeg::from_str("1, 2, 3, 4").unwrap();
|
||||
assert_eq!(
|
||||
lseg,
|
||||
PgLSeg {
|
||||
start_x: 1.,
|
||||
start_y: 2.,
|
||||
end_x: 3.,
|
||||
end_y: 4.
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_too_many_numbers() {
|
||||
let input_str = "1, 2, 3, 4, 5";
|
||||
let lseg = PgLSeg::from_str(input_str);
|
||||
assert!(lseg.is_err());
|
||||
if let Err(err) = lseg {
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!("error decoding LSEG: too many numbers inputted in {input_str}")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_too_few_numbers() {
|
||||
let input_str = "1, 2, 3";
|
||||
let lseg = PgLSeg::from_str(input_str);
|
||||
assert!(lseg.is_err());
|
||||
if let Err(err) = lseg {
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!("error decoding LSEG: could not get end_y from {input_str}")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_invalid_numbers() {
|
||||
let input_str = "1, 2, 3, FOUR";
|
||||
let lseg = PgLSeg::from_str(input_str);
|
||||
assert!(lseg.is_err());
|
||||
if let Err(err) = lseg {
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
format!("error decoding LSEG: could not get end_y from {input_str}")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_deserialise_lseg_type_str_float() {
|
||||
let lseg = PgLSeg::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap();
|
||||
assert_eq!(
|
||||
lseg,
|
||||
PgLSeg {
|
||||
start_x: 1.1,
|
||||
start_y: 2.2,
|
||||
end_x: 3.3,
|
||||
end_y: 4.4
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_serialise_lseg_type() {
|
||||
let lseg = PgLSeg {
|
||||
start_x: 1.1,
|
||||
start_y: 2.2,
|
||||
end_x: 3.3,
|
||||
end_y: 4.4,
|
||||
};
|
||||
assert_eq!(lseg.serialize_to_vec(), LINE_SEGMENT_BYTES,)
|
||||
}
|
||||
}
|
@ -1,2 +1,4 @@
|
||||
pub mod r#box;
|
||||
pub mod line;
|
||||
pub mod line_segment;
|
||||
pub mod point;
|
||||
|
@ -21,8 +21,10 @@
|
||||
//! | [`PgLQuery`] | LQUERY |
|
||||
//! | [`PgCiText`] | CITEXT<sup>1</sup> |
|
||||
//! | [`PgCube`] | CUBE |
|
||||
//! | [`PgPoint] | POINT |
|
||||
//! | [`PgLine] | LINE |
|
||||
//! | [`PgPoint`] | POINT |
|
||||
//! | [`PgLine`] | LINE |
|
||||
//! | [`PgLSeg`] | LSEG |
|
||||
//! | [`PgBox`] | BOX |
|
||||
//! | [`PgHstore`] | HSTORE |
|
||||
//!
|
||||
//! <sup>1</sup> SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc.,
|
||||
@ -259,7 +261,9 @@ pub use array::PgHasArrayType;
|
||||
pub use citext::PgCiText;
|
||||
pub use cube::PgCube;
|
||||
pub use geometry::line::PgLine;
|
||||
pub use geometry::line_segment::PgLSeg;
|
||||
pub use geometry::point::PgPoint;
|
||||
pub use geometry::r#box::PgBox;
|
||||
pub use hstore::PgHstore;
|
||||
pub use interval::PgInterval;
|
||||
pub use lquery::PgLQuery;
|
||||
|
@ -41,13 +41,13 @@ impl<'a> PgRecordEncoder<'a> {
|
||||
{
|
||||
let ty = value.produces().unwrap_or_else(T::type_info);
|
||||
|
||||
if let PgType::DeclareWithName(name) = ty.0 {
|
||||
match ty.0 {
|
||||
// push a hole for this type ID
|
||||
// to be filled in on query execution
|
||||
self.buf.patch_type_by_name(&name);
|
||||
} else {
|
||||
PgType::DeclareWithName(name) => self.buf.patch_type_by_name(&name),
|
||||
PgType::DeclareArrayOf(array) => self.buf.patch_array_type(array),
|
||||
// write type id
|
||||
self.buf.extend(&ty.0.oid().0.to_be_bytes());
|
||||
pg_type => self.buf.extend(&pg_type.oid().0.to_be_bytes()),
|
||||
}
|
||||
|
||||
self.buf.encode(value)?;
|
||||
|
@ -23,6 +23,8 @@ uuid = ["dep:uuid", "sqlx-core/uuid"]
|
||||
|
||||
regexp = ["dep:regex"]
|
||||
|
||||
preupdate-hook = ["libsqlite3-sys/preupdate_hook"]
|
||||
|
||||
bundled = ["libsqlite3-sys/bundled"]
|
||||
unbundled = ["libsqlite3-sys/buildtime_bindgen"]
|
||||
|
||||
@ -48,6 +50,7 @@ atoi = "2.0"
|
||||
|
||||
log = "0.4.18"
|
||||
tracing = { version = "0.1.37", features = ["log"] }
|
||||
thiserror = "2.0.0"
|
||||
|
||||
serde = { version = "1.0.145", features = ["derive"], optional = true }
|
||||
regex = { version = "1.5.5", optional = true }
|
||||
|
@ -17,6 +17,7 @@ use sqlx_core::database::Database;
|
||||
use sqlx_core::describe::Describe;
|
||||
use sqlx_core::executor::Executor;
|
||||
use sqlx_core::transaction::TransactionManager;
|
||||
use std::pin::pin;
|
||||
|
||||
sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Sqlite);
|
||||
|
||||
@ -105,12 +106,12 @@ impl AnyConnectionBackend for SqliteConnection {
|
||||
let args = arguments.map(map_arguments);
|
||||
|
||||
Box::pin(async move {
|
||||
let stream = self
|
||||
.worker
|
||||
.execute(query, args, self.row_channel_size, persistent, Some(1))
|
||||
.map_ok(flume::Receiver::into_stream)
|
||||
.await?;
|
||||
futures_util::pin_mut!(stream);
|
||||
let mut stream = pin!(
|
||||
self.worker
|
||||
.execute(query, args, self.row_channel_size, persistent, Some(1))
|
||||
.map_ok(flume::Receiver::into_stream)
|
||||
.await?
|
||||
);
|
||||
|
||||
if let Some(Either::Right(row)) = stream.try_next().await? {
|
||||
return Ok(Some(AnyRow::try_from(&row)?));
|
||||
|
@ -296,6 +296,8 @@ impl EstablishParams {
|
||||
log_settings: self.log_settings.clone(),
|
||||
progress_handler_callback: None,
|
||||
update_hook_callback: None,
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
preupdate_hook_callback: None,
|
||||
commit_hook_callback: None,
|
||||
rollback_hook_callback: None,
|
||||
})
|
||||
|
@ -8,7 +8,7 @@ use sqlx_core::describe::Describe;
|
||||
use sqlx_core::error::Error;
|
||||
use sqlx_core::executor::{Execute, Executor};
|
||||
use sqlx_core::Either;
|
||||
use std::future;
|
||||
use std::{future, pin::pin};
|
||||
|
||||
impl<'c> Executor<'c> for &'c mut SqliteConnection {
|
||||
type Database = Sqlite;
|
||||
@ -56,13 +56,11 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
|
||||
let persistent = query.persistent() && arguments.is_some();
|
||||
|
||||
Box::pin(async move {
|
||||
let stream = self
|
||||
let mut stream = pin!(self
|
||||
.worker
|
||||
.execute(sql, arguments, self.row_channel_size, persistent, Some(1))
|
||||
.map_ok(flume::Receiver::into_stream)
|
||||
.try_flatten_stream();
|
||||
|
||||
futures_util::pin_mut!(stream);
|
||||
.try_flatten_stream());
|
||||
|
||||
while let Some(res) = stream.try_next().await? {
|
||||
if let Either::Right(row) = res {
|
||||
|
@ -14,6 +14,8 @@ use libsqlite3_sys::{
|
||||
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
|
||||
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
|
||||
};
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
pub use preupdate_hook::*;
|
||||
|
||||
pub(crate) use handle::ConnectionHandle;
|
||||
use sqlx_core::common::StatementCache;
|
||||
@ -26,7 +28,7 @@ use crate::connection::establish::EstablishParams;
|
||||
use crate::connection::worker::ConnectionWorker;
|
||||
use crate::options::OptimizeOnClose;
|
||||
use crate::statement::VirtualStatement;
|
||||
use crate::{Sqlite, SqliteConnectOptions};
|
||||
use crate::{Sqlite, SqliteConnectOptions, SqliteError};
|
||||
|
||||
pub(crate) mod collation;
|
||||
pub(crate) mod describe;
|
||||
@ -36,6 +38,8 @@ mod executor;
|
||||
mod explain;
|
||||
mod handle;
|
||||
pub(crate) mod intmap;
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
mod preupdate_hook;
|
||||
|
||||
mod worker;
|
||||
|
||||
@ -88,6 +92,7 @@ pub struct UpdateHookResult<'a> {
|
||||
pub table: &'a str,
|
||||
pub rowid: i64,
|
||||
}
|
||||
|
||||
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
|
||||
unsafe impl Send for UpdateHookHandler {}
|
||||
|
||||
@ -112,6 +117,8 @@ pub(crate) struct ConnectionState {
|
||||
progress_handler_callback: Option<Handler>,
|
||||
|
||||
update_hook_callback: Option<UpdateHookHandler>,
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
preupdate_hook_callback: Option<preupdate_hook::PreupdateHookHandler>,
|
||||
|
||||
commit_hook_callback: Option<CommitHookHandler>,
|
||||
|
||||
@ -138,6 +145,16 @@ impl ConnectionState {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
pub(crate) fn remove_preupdate_hook(&mut self) {
|
||||
if let Some(mut handler) = self.preupdate_hook_callback.take() {
|
||||
unsafe {
|
||||
libsqlite3_sys::sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut());
|
||||
let _ = { Box::from_raw(handler.0.as_mut()) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn remove_commit_hook(&mut self) {
|
||||
if let Some(mut handler) = self.commit_hook_callback.take() {
|
||||
unsafe {
|
||||
@ -421,6 +438,34 @@ impl LockedSqliteHandle<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table.
|
||||
/// At most one preupdate hook may be registered at a time on a single database connection.
|
||||
///
|
||||
/// The preupdate hook only fires for changes to real database tables;
|
||||
/// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1.
|
||||
///
|
||||
/// See https://sqlite.org/c3ref/preupdate_count.html
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
pub fn set_preupdate_hook<F>(&mut self, callback: F)
|
||||
where
|
||||
F: FnMut(PreupdateHookResult) + Send + 'static,
|
||||
{
|
||||
unsafe {
|
||||
let callback_boxed = Box::new(callback);
|
||||
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
|
||||
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
|
||||
let handler = callback.as_ptr() as *mut _;
|
||||
self.guard.remove_preupdate_hook();
|
||||
self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback));
|
||||
|
||||
libsqlite3_sys::sqlite3_preupdate_hook(
|
||||
self.as_raw_handle().as_mut(),
|
||||
Some(preupdate_hook::<F>),
|
||||
handler,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets a commit hook that is invoked whenever a transaction is committed. If the commit hook callback
|
||||
/// returns `false`, then the operation is turned into a ROLLBACK.
|
||||
///
|
||||
@ -485,6 +530,11 @@ impl LockedSqliteHandle<'_> {
|
||||
self.guard.remove_update_hook();
|
||||
}
|
||||
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
pub fn remove_preupdate_hook(&mut self) {
|
||||
self.guard.remove_preupdate_hook();
|
||||
}
|
||||
|
||||
pub fn remove_commit_hook(&mut self) {
|
||||
self.guard.remove_commit_hook();
|
||||
}
|
||||
@ -492,6 +542,10 @@ impl LockedSqliteHandle<'_> {
|
||||
pub fn remove_rollback_hook(&mut self) {
|
||||
self.guard.remove_rollback_hook();
|
||||
}
|
||||
|
||||
pub fn last_error(&mut self) -> Option<SqliteError> {
|
||||
SqliteError::try_new(self.guard.handle.as_ptr())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ConnectionState {
|
||||
|
160
sqlx-sqlite/src/connection/preupdate_hook.rs
Normal file
160
sqlx-sqlite/src/connection/preupdate_hook.rs
Normal file
@ -0,0 +1,160 @@
|
||||
use super::SqliteOperation;
|
||||
use crate::type_info::DataType;
|
||||
use crate::{SqliteError, SqliteTypeInfo, SqliteValueRef};
|
||||
|
||||
use libsqlite3_sys::{
|
||||
sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_new,
|
||||
sqlite3_preupdate_old, sqlite3_value, sqlite3_value_type, SQLITE_OK,
|
||||
};
|
||||
use std::ffi::CStr;
|
||||
use std::marker::PhantomData;
|
||||
use std::os::raw::{c_char, c_int, c_void};
|
||||
use std::panic::catch_unwind;
|
||||
use std::ptr;
|
||||
use std::ptr::NonNull;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PreupdateError {
|
||||
/// Error returned from the database.
|
||||
#[error("error returned from database: {0}")]
|
||||
Database(#[source] SqliteError),
|
||||
/// Index is not within the valid column range
|
||||
#[error("{0} is not within the valid column range")]
|
||||
ColumnIndexOutOfBounds(i32),
|
||||
/// Column value accessor was invoked from an invalid operation
|
||||
#[error("column value accessor was invoked from an invalid operation")]
|
||||
InvalidOperation,
|
||||
}
|
||||
|
||||
pub(crate) struct PreupdateHookHandler(
|
||||
pub(super) NonNull<dyn FnMut(PreupdateHookResult) + Send + 'static>,
|
||||
);
|
||||
unsafe impl Send for PreupdateHookHandler {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PreupdateHookResult<'a> {
|
||||
pub operation: SqliteOperation,
|
||||
pub database: &'a str,
|
||||
pub table: &'a str,
|
||||
db: *mut sqlite3,
|
||||
// The database pointer should not be usable after the preupdate hook.
|
||||
// The lifetime on this struct needs to ensure it cannot outlive the callback.
|
||||
_db_lifetime: PhantomData<&'a ()>,
|
||||
old_row_id: i64,
|
||||
new_row_id: i64,
|
||||
}
|
||||
|
||||
impl<'a> PreupdateHookResult<'a> {
|
||||
/// Gets the amount of columns in the row being inserted, deleted, or updated.
|
||||
pub fn get_column_count(&self) -> i32 {
|
||||
unsafe { sqlite3_preupdate_count(self.db) }
|
||||
}
|
||||
|
||||
/// Gets the depth of the query that triggered the preupdate hook.
|
||||
/// Returns 0 if the preupdate callback was invoked as a result of
|
||||
/// a direct insert, update, or delete operation;
|
||||
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
|
||||
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
|
||||
pub fn get_query_depth(&self) -> i32 {
|
||||
unsafe { sqlite3_preupdate_depth(self.db) }
|
||||
}
|
||||
|
||||
/// Gets the row id of the row being updated/deleted.
|
||||
/// Returns an error if called from an insert operation.
|
||||
pub fn get_old_row_id(&self) -> Result<i64, PreupdateError> {
|
||||
if self.operation == SqliteOperation::Insert {
|
||||
return Err(PreupdateError::InvalidOperation);
|
||||
}
|
||||
Ok(self.old_row_id)
|
||||
}
|
||||
|
||||
/// Gets the row id of the row being inserted/updated.
|
||||
/// Returns an error if called from a delete operation.
|
||||
pub fn get_new_row_id(&self) -> Result<i64, PreupdateError> {
|
||||
if self.operation == SqliteOperation::Delete {
|
||||
return Err(PreupdateError::InvalidOperation);
|
||||
}
|
||||
Ok(self.new_row_id)
|
||||
}
|
||||
|
||||
/// Gets the value of the row being updated/deleted at the specified index.
|
||||
/// Returns an error if called from an insert operation or the index is out of bounds.
|
||||
pub fn get_old_column_value(&self, i: i32) -> Result<SqliteValueRef<'a>, PreupdateError> {
|
||||
if self.operation == SqliteOperation::Insert {
|
||||
return Err(PreupdateError::InvalidOperation);
|
||||
}
|
||||
self.validate_column_index(i)?;
|
||||
|
||||
let mut p_value: *mut sqlite3_value = ptr::null_mut();
|
||||
unsafe {
|
||||
let ret = sqlite3_preupdate_old(self.db, i, &mut p_value);
|
||||
self.get_value(ret, p_value)
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the value of the row being inserted/updated at the specified index.
|
||||
/// Returns an error if called from a delete operation or the index is out of bounds.
|
||||
pub fn get_new_column_value(&self, i: i32) -> Result<SqliteValueRef<'a>, PreupdateError> {
|
||||
if self.operation == SqliteOperation::Delete {
|
||||
return Err(PreupdateError::InvalidOperation);
|
||||
}
|
||||
self.validate_column_index(i)?;
|
||||
|
||||
let mut p_value: *mut sqlite3_value = ptr::null_mut();
|
||||
unsafe {
|
||||
let ret = sqlite3_preupdate_new(self.db, i, &mut p_value);
|
||||
self.get_value(ret, p_value)
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_column_index(&self, i: i32) -> Result<(), PreupdateError> {
|
||||
if i < 0 || i >= self.get_column_count() {
|
||||
return Err(PreupdateError::ColumnIndexOutOfBounds(i));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe fn get_value(
|
||||
&self,
|
||||
ret: i32,
|
||||
p_value: *mut sqlite3_value,
|
||||
) -> Result<SqliteValueRef<'a>, PreupdateError> {
|
||||
if ret != SQLITE_OK {
|
||||
return Err(PreupdateError::Database(SqliteError::new(self.db)));
|
||||
}
|
||||
let data_type = DataType::from_code(sqlite3_value_type(p_value));
|
||||
// SAFETY: SQLite will free the sqlite3_value when the callback returns
|
||||
Ok(SqliteValueRef::borrowed(p_value, SqliteTypeInfo(data_type)))
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) extern "C" fn preupdate_hook<F>(
|
||||
callback: *mut c_void,
|
||||
db: *mut sqlite3,
|
||||
op_code: c_int,
|
||||
database: *const c_char,
|
||||
table: *const c_char,
|
||||
old_row_id: i64,
|
||||
new_row_id: i64,
|
||||
) where
|
||||
F: FnMut(PreupdateHookResult) + Send + 'static,
|
||||
{
|
||||
unsafe {
|
||||
let _ = catch_unwind(|| {
|
||||
let callback: *mut F = callback.cast::<F>();
|
||||
let operation: SqliteOperation = op_code.into();
|
||||
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
|
||||
let table = CStr::from_ptr(table).to_str().unwrap_or_default();
|
||||
|
||||
(*callback)(PreupdateHookResult {
|
||||
operation,
|
||||
database,
|
||||
table,
|
||||
old_row_id,
|
||||
new_row_id,
|
||||
db,
|
||||
_db_lifetime: PhantomData,
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
@ -151,7 +151,8 @@ impl ConnectionWorker {
|
||||
match limit {
|
||||
None => {
|
||||
for res in iter {
|
||||
if tx.send(res).is_err() {
|
||||
let has_error = res.is_err();
|
||||
if tx.send(res).is_err() || has_error {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -171,7 +172,8 @@ impl ConnectionWorker {
|
||||
}
|
||||
}
|
||||
}
|
||||
if tx.send(res).is_err() {
|
||||
let has_error = res.is_err();
|
||||
if tx.send(res).is_err() || has_error {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -23,9 +23,17 @@ pub struct SqliteError {
|
||||
|
||||
impl SqliteError {
|
||||
pub(crate) fn new(handle: *mut sqlite3) -> Self {
|
||||
Self::try_new(handle).expect("There should be an error")
|
||||
}
|
||||
|
||||
pub(crate) fn try_new(handle: *mut sqlite3) -> Option<Self> {
|
||||
// returns the extended result code even when extended result codes are disabled
|
||||
let code: c_int = unsafe { sqlite3_extended_errcode(handle) };
|
||||
|
||||
if code == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// return English-language text that describes the error
|
||||
let message = unsafe {
|
||||
let msg = sqlite3_errmsg(handle);
|
||||
@ -34,10 +42,10 @@ impl SqliteError {
|
||||
from_utf8_unchecked(CStr::from_ptr(msg).to_bytes())
|
||||
};
|
||||
|
||||
Self {
|
||||
Some(Self {
|
||||
code,
|
||||
message: message.to_owned(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// For errors during extension load, the error message is supplied via a separate pointer
|
||||
|
@ -46,6 +46,8 @@ use std::sync::atomic::AtomicBool;
|
||||
|
||||
pub use arguments::{SqliteArgumentValue, SqliteArguments};
|
||||
pub use column::SqliteColumn;
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
pub use connection::PreupdateHookResult;
|
||||
pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult};
|
||||
pub use database::Sqlite;
|
||||
pub use error::SqliteError;
|
||||
|
@ -30,6 +30,10 @@ impl TestSupport for Sqlite {
|
||||
) -> BoxFuture<'_, Result<FixtureSnapshot<Self>, Error>> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn db_name(args: &TestArgs) -> String {
|
||||
convert_path(args.test_path)
|
||||
}
|
||||
}
|
||||
|
||||
async fn test_context(args: &TestArgs) -> Result<TestContext<Sqlite>, Error> {
|
||||
|
@ -1,4 +1,5 @@
|
||||
use std::borrow::Cow;
|
||||
use std::marker::PhantomData;
|
||||
use std::ptr::NonNull;
|
||||
use std::slice::from_raw_parts;
|
||||
use std::str::from_utf8;
|
||||
@ -17,6 +18,7 @@ use crate::{Sqlite, SqliteTypeInfo};
|
||||
|
||||
enum SqliteValueData<'r> {
|
||||
Value(&'r SqliteValue),
|
||||
BorrowedHandle(ValueHandle<'r>),
|
||||
}
|
||||
|
||||
pub struct SqliteValueRef<'r>(SqliteValueData<'r>);
|
||||
@ -26,31 +28,44 @@ impl<'r> SqliteValueRef<'r> {
|
||||
Self(SqliteValueData::Value(value))
|
||||
}
|
||||
|
||||
// SAFETY: The supplied sqlite3_value must not be null and SQLite must free it. It will not be freed on drop.
|
||||
// The lifetime on this struct should tie it to whatever scope it's valid for before SQLite frees it.
|
||||
#[allow(unused)]
|
||||
pub(crate) unsafe fn borrowed(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
|
||||
debug_assert!(!value.is_null());
|
||||
let handle = ValueHandle::new_borrowed(NonNull::new_unchecked(value), type_info);
|
||||
Self(SqliteValueData::BorrowedHandle(handle))
|
||||
}
|
||||
|
||||
// NOTE: `int()` is deliberately omitted because it will silently truncate a wider value,
|
||||
// which is likely to cause bugs:
|
||||
// https://github.com/launchbadge/sqlx/issues/3179
|
||||
// (Similar bug in Postgres): https://github.com/launchbadge/sqlx/issues/3161
|
||||
pub(super) fn int64(&self) -> i64 {
|
||||
match self.0 {
|
||||
SqliteValueData::Value(v) => v.int64(),
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.0.int64(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.int64(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn double(&self) -> f64 {
|
||||
match self.0 {
|
||||
SqliteValueData::Value(v) => v.double(),
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.0.double(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.double(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn blob(&self) -> &'r [u8] {
|
||||
match self.0 {
|
||||
SqliteValueData::Value(v) => v.blob(),
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.0.blob(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.blob(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn text(&self) -> Result<&'r str, BoxDynError> {
|
||||
match self.0 {
|
||||
SqliteValueData::Value(v) => v.text(),
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.0.text(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.text(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -59,50 +74,66 @@ impl<'r> ValueRef<'r> for SqliteValueRef<'r> {
|
||||
type Database = Sqlite;
|
||||
|
||||
fn to_owned(&self) -> SqliteValue {
|
||||
match self.0 {
|
||||
SqliteValueData::Value(v) => v.clone(),
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => (*v).clone(),
|
||||
SqliteValueData::BorrowedHandle(v) => unsafe {
|
||||
SqliteValue::new(v.value.as_ptr(), v.type_info.clone())
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
|
||||
match self.0 {
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.type_info(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.type_info(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
match self.0 {
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.is_null(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.is_null(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SqliteValue {
|
||||
pub(crate) handle: Arc<ValueHandle>,
|
||||
pub(crate) type_info: SqliteTypeInfo,
|
||||
pub struct SqliteValue(Arc<ValueHandle<'static>>);
|
||||
|
||||
pub(crate) struct ValueHandle<'a> {
|
||||
value: NonNull<sqlite3_value>,
|
||||
type_info: SqliteTypeInfo,
|
||||
free_on_drop: bool,
|
||||
_sqlite_value_lifetime: PhantomData<&'a ()>,
|
||||
}
|
||||
|
||||
pub(crate) struct ValueHandle(NonNull<sqlite3_value>);
|
||||
|
||||
// SAFE: only protected value objects are stored in SqliteValue
|
||||
unsafe impl Send for ValueHandle {}
|
||||
unsafe impl Sync for ValueHandle {}
|
||||
|
||||
impl SqliteValue {
|
||||
pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
|
||||
debug_assert!(!value.is_null());
|
||||
unsafe impl<'a> Send for ValueHandle<'a> {}
|
||||
unsafe impl<'a> Sync for ValueHandle<'a> {}
|
||||
|
||||
impl ValueHandle<'static> {
|
||||
fn new_owned(value: NonNull<sqlite3_value>, type_info: SqliteTypeInfo) -> Self {
|
||||
Self {
|
||||
value,
|
||||
type_info,
|
||||
handle: Arc::new(ValueHandle(NonNull::new_unchecked(sqlite3_value_dup(
|
||||
value,
|
||||
)))),
|
||||
free_on_drop: true,
|
||||
_sqlite_value_lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ValueHandle<'a> {
|
||||
fn new_borrowed(value: NonNull<sqlite3_value>, type_info: SqliteTypeInfo) -> Self {
|
||||
Self {
|
||||
value,
|
||||
type_info,
|
||||
free_on_drop: false,
|
||||
_sqlite_value_lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn type_info_opt(&self) -> Option<SqliteTypeInfo> {
|
||||
let dt = DataType::from_code(unsafe { sqlite3_value_type(self.handle.0.as_ptr()) });
|
||||
let dt = DataType::from_code(unsafe { sqlite3_value_type(self.value.as_ptr()) });
|
||||
|
||||
if let DataType::Null = dt {
|
||||
None
|
||||
@ -112,15 +143,15 @@ impl SqliteValue {
|
||||
}
|
||||
|
||||
fn int64(&self) -> i64 {
|
||||
unsafe { sqlite3_value_int64(self.handle.0.as_ptr()) }
|
||||
unsafe { sqlite3_value_int64(self.value.as_ptr()) }
|
||||
}
|
||||
|
||||
fn double(&self) -> f64 {
|
||||
unsafe { sqlite3_value_double(self.handle.0.as_ptr()) }
|
||||
unsafe { sqlite3_value_double(self.value.as_ptr()) }
|
||||
}
|
||||
|
||||
fn blob(&self) -> &[u8] {
|
||||
let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) };
|
||||
fn blob<'b>(&self) -> &'b [u8] {
|
||||
let len = unsafe { sqlite3_value_bytes(self.value.as_ptr()) };
|
||||
|
||||
// This likely means UB in SQLite itself or our usage of it;
|
||||
// signed integer overflow is UB in the C standard.
|
||||
@ -133,15 +164,45 @@ impl SqliteValue {
|
||||
return &[];
|
||||
}
|
||||
|
||||
let ptr = unsafe { sqlite3_value_blob(self.handle.0.as_ptr()) } as *const u8;
|
||||
let ptr = unsafe { sqlite3_value_blob(self.value.as_ptr()) } as *const u8;
|
||||
debug_assert!(!ptr.is_null());
|
||||
|
||||
unsafe { from_raw_parts(ptr, len) }
|
||||
}
|
||||
|
||||
fn text(&self) -> Result<&str, BoxDynError> {
|
||||
fn text<'b>(&self) -> Result<&'b str, BoxDynError> {
|
||||
Ok(from_utf8(self.blob())?)
|
||||
}
|
||||
|
||||
fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
|
||||
self.type_info_opt()
|
||||
.map(Cow::Owned)
|
||||
.unwrap_or(Cow::Borrowed(&self.type_info))
|
||||
}
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
unsafe { sqlite3_value_type(self.value.as_ptr()) == SQLITE_NULL }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for ValueHandle<'a> {
|
||||
fn drop(&mut self) {
|
||||
if self.free_on_drop {
|
||||
unsafe {
|
||||
sqlite3_value_free(self.value.as_ptr());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SqliteValue {
|
||||
// SAFETY: The sqlite3_value must be non-null and SQLite must not free it. It will be freed on drop.
|
||||
pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
|
||||
debug_assert!(!value.is_null());
|
||||
let handle =
|
||||
ValueHandle::new_owned(NonNull::new_unchecked(sqlite3_value_dup(value)), type_info);
|
||||
Self(Arc::new(handle))
|
||||
}
|
||||
}
|
||||
|
||||
impl Value for SqliteValue {
|
||||
@ -152,21 +213,11 @@ impl Value for SqliteValue {
|
||||
}
|
||||
|
||||
fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
|
||||
self.type_info_opt()
|
||||
.map(Cow::Owned)
|
||||
.unwrap_or(Cow::Borrowed(&self.type_info))
|
||||
self.0.type_info()
|
||||
}
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
unsafe { sqlite3_value_type(self.handle.0.as_ptr()) == SQLITE_NULL }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ValueHandle {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
sqlite3_value_free(self.0.as_ptr());
|
||||
}
|
||||
self.0.is_null()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,14 @@
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![doc = include_str!("lib.md")]
|
||||
|
||||
#[cfg(all(
|
||||
feature = "sqlite-preupdate-hook",
|
||||
not(any(feature = "sqlite", feature = "sqlite-unbundled"))
|
||||
))]
|
||||
compile_error!(
|
||||
"sqlite-preupdate-hook requires either 'sqlite' or 'sqlite-unbundled' to be enabled"
|
||||
);
|
||||
|
||||
pub use sqlx_core::acquire::Acquire;
|
||||
pub use sqlx_core::arguments::{Arguments, IntoArguments};
|
||||
pub use sqlx_core::column::Column;
|
||||
|
@ -494,6 +494,31 @@ async fn test_from_row_json_attr() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_from_row_json_attr_nullable() -> anyhow::Result<()> {
|
||||
#[derive(serde::Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct J {
|
||||
a: u32,
|
||||
b: u32,
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct Record {
|
||||
#[sqlx(json(nullable))]
|
||||
j: Option<J>,
|
||||
}
|
||||
|
||||
let mut conn = new::<MySql>().await?;
|
||||
|
||||
let record = sqlx::query_as::<_, Record>("select NULL as j")
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(record.j.is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_from_row_json_try_from_attr() -> anyhow::Result<()> {
|
||||
#[derive(serde::Deserialize)]
|
||||
|
@ -810,3 +810,69 @@ async fn test_custom_pg_array() -> anyhow::Result<()> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_record_array_type() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
conn.execute(
|
||||
r#"
|
||||
DROP TABLE IF EXISTS responses;
|
||||
|
||||
DROP TYPE IF EXISTS http_response CASCADE;
|
||||
DROP TYPE IF EXISTS header_pair CASCADE;
|
||||
|
||||
CREATE TYPE header_pair AS (
|
||||
name TEXT,
|
||||
value TEXT
|
||||
);
|
||||
|
||||
CREATE TYPE http_response AS (
|
||||
headers header_pair[]
|
||||
);
|
||||
|
||||
CREATE TABLE responses (
|
||||
response http_response NOT NULL
|
||||
);
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
|
||||
#[derive(Debug, sqlx::Type)]
|
||||
#[sqlx(type_name = "http_response")]
|
||||
struct HttpResponseRecord {
|
||||
headers: Vec<HeaderPairRecord>,
|
||||
}
|
||||
|
||||
#[derive(Debug, sqlx::Type)]
|
||||
#[sqlx(type_name = "header_pair")]
|
||||
struct HeaderPairRecord {
|
||||
name: String,
|
||||
value: String,
|
||||
}
|
||||
|
||||
let value = HttpResponseRecord {
|
||||
headers: vec![
|
||||
HeaderPairRecord {
|
||||
name: "Content-Type".to_owned(),
|
||||
value: "text/html; charset=utf-8".to_owned(),
|
||||
},
|
||||
HeaderPairRecord {
|
||||
name: "Cache-Control".to_owned(),
|
||||
value: "max-age=0".to_owned(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
sqlx::query(
|
||||
"
|
||||
INSERT INTO responses (response)
|
||||
VALUES ($1)
|
||||
",
|
||||
)
|
||||
.bind(&value)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -3,13 +3,13 @@ use futures::{Stream, StreamExt, TryStreamExt};
|
||||
use sqlx::postgres::types::Oid;
|
||||
use sqlx::postgres::{
|
||||
PgAdvisoryLock, PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgListener,
|
||||
PgPoolOptions, PgRow, PgSeverity, Postgres,
|
||||
PgPoolOptions, PgRow, PgSeverity, Postgres, PG_COPY_MAX_DATA_LEN,
|
||||
};
|
||||
use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo};
|
||||
use sqlx_core::{bytes::Bytes, error::BoxDynError};
|
||||
use sqlx_test::{new, pool, setup_if_needed};
|
||||
use std::env;
|
||||
use std::pin::Pin;
|
||||
use std::pin::{pin, Pin};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@ -637,8 +637,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
|
||||
let pool = pool.clone();
|
||||
sqlx_core::rt::spawn(async move {
|
||||
while !pool.is_closed() {
|
||||
let acquire = pool.acquire();
|
||||
futures::pin_mut!(acquire);
|
||||
let mut acquire = pin!(pool.acquire());
|
||||
|
||||
// poll the acquire future once to put the waiter in the queue
|
||||
future::poll_fn(move |cx| {
|
||||
@ -2042,3 +2041,78 @@ async fn test_issue_3052() {
|
||||
"expected encode error, got {too_large_error:?}",
|
||||
);
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_pg_copy_chunked() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
let mut row = "1".repeat(PG_COPY_MAX_DATA_LEN / 10 - 1);
|
||||
row.push_str("\n");
|
||||
|
||||
// creates a payload with COPY_MAX_DATA_LEN + 1 as size
|
||||
let mut payload = row.repeat(10);
|
||||
payload.push_str("12345678\n");
|
||||
|
||||
assert_eq!(payload.len(), PG_COPY_MAX_DATA_LEN + 1);
|
||||
|
||||
let mut copy = conn.copy_in_raw("COPY products(name) FROM STDIN").await?;
|
||||
|
||||
assert!(copy.send(payload.as_bytes()).await.is_ok());
|
||||
assert!(copy.finish().await.is_ok());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn test_copy_in_error_case(query: &str, expected_error: &str) -> anyhow::Result<()> {
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
conn.execute("CREATE TEMPORARY TABLE IF NOT EXISTS invalid_copy_target (id int4)")
|
||||
.await?;
|
||||
// Try the COPY operation
|
||||
match conn.copy_in_raw(query).await {
|
||||
Ok(_) => anyhow::bail!("expected error"),
|
||||
Err(e) => assert!(
|
||||
e.to_string().contains(expected_error),
|
||||
"expected error to contain: {expected_error}, got: {e:?}"
|
||||
),
|
||||
}
|
||||
// Verify connection is still usable
|
||||
let value = sqlx::query("select 1 + 1")
|
||||
.try_map(|row: PgRow| row.try_get::<i32, _>(0))
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
assert_eq!(2i32, value);
|
||||
Ok(())
|
||||
}
|
||||
#[sqlx_macros::test]
|
||||
async fn it_can_recover_from_copy_in_to_missing_table() -> anyhow::Result<()> {
|
||||
test_copy_in_error_case(
|
||||
r#"
|
||||
COPY nonexistent_table (id) FROM STDIN WITH (FORMAT CSV, HEADER);
|
||||
"#,
|
||||
"does not exist",
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[sqlx_macros::test]
|
||||
async fn it_can_recover_from_copy_in_empty_query() -> anyhow::Result<()> {
|
||||
test_copy_in_error_case("", "EmptyQuery").await
|
||||
}
|
||||
#[sqlx_macros::test]
|
||||
async fn it_can_recover_from_copy_in_syntax_error() -> anyhow::Result<()> {
|
||||
test_copy_in_error_case(
|
||||
r#"
|
||||
COPY FROM STDIN WITH (FORMAT CSV);
|
||||
"#,
|
||||
"syntax error",
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[sqlx_macros::test]
|
||||
async fn it_can_recover_from_copy_in_invalid_params() -> anyhow::Result<()> {
|
||||
test_copy_in_error_case(
|
||||
r#"
|
||||
COPY invalid_copy_target FROM STDIN WITH (FORMAT CSV, INVALID_PARAM true);
|
||||
"#,
|
||||
"invalid_param",
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
@ -509,6 +509,21 @@ test_type!(line<sqlx::postgres::types::PgLine>(Postgres,
|
||||
"line('((0.0, 0.0), (1.0,1.0))')" == sqlx::postgres::types::PgLine { a: 1., b: -1., c: 0. },
|
||||
));
|
||||
|
||||
#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))]
|
||||
test_type!(lseg<sqlx::postgres::types::PgLSeg>(Postgres,
|
||||
"lseg('((1.0, 2.0), (3.0,4.0))')" == sqlx::postgres::types::PgLSeg { start_x: 1., start_y: 2., end_x: 3. , end_y: 4.},
|
||||
));
|
||||
|
||||
#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))]
|
||||
test_type!(box<sqlx::postgres::types::PgBox>(Postgres,
|
||||
"box('((1.0, 2.0), (3.0,4.0))')" == sqlx::postgres::types::PgBox { upper_right_x: 3., upper_right_y: 4., lower_left_x: 1. , lower_left_y: 2.},
|
||||
));
|
||||
|
||||
#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))]
|
||||
test_type!(_box<Vec<sqlx::postgres::types::PgBox>>(Postgres,
|
||||
"array[box('1,2,3,4'),box('((1.1, 2.2), (3.3, 4.4))')]" @= vec![sqlx::postgres::types::PgBox { upper_right_x: 3., upper_right_y: 4., lower_left_x: 1., lower_left_y: 2. }, sqlx::postgres::types::PgBox { upper_right_x: 3.3, upper_right_y: 4.4, lower_left_x: 1.1, lower_left_y: 2.2 }],
|
||||
));
|
||||
|
||||
#[cfg(feature = "rust_decimal")]
|
||||
test_type!(decimal<sqlx::types::Decimal>(Postgres,
|
||||
"0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(),
|
||||
|
@ -2,11 +2,14 @@ use futures::TryStreamExt;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_xoshiro::Xoshiro256PlusPlus;
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions};
|
||||
use sqlx::Decode;
|
||||
use sqlx::{
|
||||
query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row,
|
||||
SqliteConnection, SqlitePool, Statement, TypeInfo,
|
||||
};
|
||||
use sqlx::{Value, ValueRef};
|
||||
use sqlx_test::new;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[sqlx_macros::test]
|
||||
@ -798,7 +801,7 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_update_hook() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_update_hook(move |result| {
|
||||
@ -807,11 +810,13 @@ async fn test_query_with_update_hook() -> anyhow::Result<()> {
|
||||
assert_eq!(result.database, "main");
|
||||
assert_eq!(result.table, "tweet");
|
||||
assert_eq!(result.rowid, 2);
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
});
|
||||
|
||||
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -852,10 +857,11 @@ async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Resul
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_commit_hook() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_commit_hook(move || {
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
assert_eq!(state, "test");
|
||||
false
|
||||
});
|
||||
@ -870,7 +876,7 @@ async fn test_query_with_commit_hook() -> anyhow::Result<()> {
|
||||
}
|
||||
_ => panic!("expected an error"),
|
||||
}
|
||||
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -916,8 +922,10 @@ async fn test_query_with_rollback_hook() -> anyhow::Result<()> {
|
||||
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
conn.lock_handle().await?.set_rollback_hook(move || {
|
||||
assert_eq!(state, "test");
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
});
|
||||
|
||||
let mut tx = conn.begin().await?;
|
||||
@ -925,6 +933,7 @@ async fn test_query_with_rollback_hook() -> anyhow::Result<()> {
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
tx.rollback().await?;
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -960,3 +969,227 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res
|
||||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite-preupdate-hook")]
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_preupdate_hook({
|
||||
move |result| {
|
||||
assert_eq!(state, "test");
|
||||
assert_eq!(result.operation, SqliteOperation::Insert);
|
||||
assert_eq!(result.database, "main");
|
||||
assert_eq!(result.table, "tweet");
|
||||
|
||||
assert_eq!(4, result.get_column_count());
|
||||
assert_eq!(2, result.get_new_row_id().unwrap());
|
||||
assert_eq!(0, result.get_query_depth());
|
||||
assert_eq!(
|
||||
4,
|
||||
<i64 as Decode<Sqlite>>::decode(result.get_new_column_value(0).unwrap()).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
"Hello, World",
|
||||
<String as Decode<Sqlite>>::decode(result.get_new_column_value(1).unwrap())
|
||||
.unwrap()
|
||||
);
|
||||
// out of bounds access should return an error
|
||||
assert!(result.get_new_column_value(4).is_err());
|
||||
// old values aren't available for inserts
|
||||
assert!(result.get_old_column_value(0).is_err());
|
||||
assert!(result.get_old_row_id().is_err());
|
||||
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
}
|
||||
});
|
||||
|
||||
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 4, 'Hello, World' )")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
conn.lock_handle().await?.remove_preupdate_hook();
|
||||
let _ = sqlx::query("DELETE FROM tweet where id = 4")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite-preupdate-hook")]
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 5, 'Hello, World' )")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_preupdate_hook(move |result| {
|
||||
assert_eq!(state, "test");
|
||||
assert_eq!(result.operation, SqliteOperation::Delete);
|
||||
assert_eq!(result.database, "main");
|
||||
assert_eq!(result.table, "tweet");
|
||||
|
||||
assert_eq!(4, result.get_column_count());
|
||||
assert_eq!(2, result.get_old_row_id().unwrap());
|
||||
assert_eq!(0, result.get_query_depth());
|
||||
assert_eq!(
|
||||
5,
|
||||
<i64 as Decode<Sqlite>>::decode(result.get_old_column_value(0).unwrap()).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
"Hello, World",
|
||||
<String as Decode<Sqlite>>::decode(result.get_old_column_value(1).unwrap()).unwrap()
|
||||
);
|
||||
// out of bounds access should return an error
|
||||
assert!(result.get_old_column_value(4).is_err());
|
||||
// new values aren't available for deletes
|
||||
assert!(result.get_new_column_value(0).is_err());
|
||||
assert!(result.get_new_row_id().is_err());
|
||||
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
});
|
||||
|
||||
let _ = sqlx::query("DELETE FROM tweet WHERE id = 5")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite-preupdate-hook")]
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 6, 'Hello, World' )")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
let sqlite_value_stored: Arc<std::sync::Mutex<Option<_>>> = Default::default();
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_preupdate_hook({
|
||||
let sqlite_value_stored = sqlite_value_stored.clone();
|
||||
move |result| {
|
||||
assert_eq!(state, "test");
|
||||
assert_eq!(result.operation, SqliteOperation::Update);
|
||||
assert_eq!(result.database, "main");
|
||||
assert_eq!(result.table, "tweet");
|
||||
|
||||
assert_eq!(4, result.get_column_count());
|
||||
assert_eq!(4, result.get_column_count());
|
||||
|
||||
assert_eq!(2, result.get_old_row_id().unwrap());
|
||||
assert_eq!(2, result.get_new_row_id().unwrap());
|
||||
|
||||
assert_eq!(0, result.get_query_depth());
|
||||
assert_eq!(0, result.get_query_depth());
|
||||
|
||||
assert_eq!(
|
||||
6,
|
||||
<i64 as Decode<Sqlite>>::decode(result.get_old_column_value(0).unwrap()).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
6,
|
||||
<i64 as Decode<Sqlite>>::decode(result.get_new_column_value(0).unwrap()).unwrap()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
"Hello, World",
|
||||
<String as Decode<Sqlite>>::decode(result.get_old_column_value(1).unwrap())
|
||||
.unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
"Hello, World2",
|
||||
<String as Decode<Sqlite>>::decode(result.get_new_column_value(1).unwrap())
|
||||
.unwrap()
|
||||
);
|
||||
*sqlite_value_stored.lock().unwrap() =
|
||||
Some(result.get_old_column_value(0).unwrap().to_owned());
|
||||
|
||||
// out of bounds access should return an error
|
||||
assert!(result.get_old_column_value(4).is_err());
|
||||
assert!(result.get_new_column_value(4).is_err());
|
||||
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
}
|
||||
});
|
||||
|
||||
let _ = sqlx::query("UPDATE tweet SET text = 'Hello, World2' WHERE id = 6")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
conn.lock_handle().await?.remove_preupdate_hook();
|
||||
let _ = sqlx::query("DELETE FROM tweet where id = 6")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
// Ensure that taking an owned SqliteValue maintains a valid reference after the callback returns
|
||||
assert_eq!(
|
||||
6,
|
||||
<i64 as Decode<Sqlite>>::decode(
|
||||
sqlite_value_stored.lock().unwrap().take().unwrap().as_ref()
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite-preupdate-hook")]
|
||||
#[sqlx_macros::test]
|
||||
async fn test_multiple_set_preupdate_hook_calls_drop_old_handler() -> anyhow::Result<()> {
|
||||
let ref_counted_object = Arc::new(0);
|
||||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
{
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_preupdate_hook(move |_| {
|
||||
println!("{o:?}");
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_preupdate_hook(move |_| {
|
||||
println!("{o:?}");
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_preupdate_hook(move |_| {
|
||||
println!("{o:?}");
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
conn.lock_handle().await?.remove_preupdate_hook();
|
||||
}
|
||||
|
||||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_get_last_error() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
let _ = sqlx::query("select 1").fetch_one(&mut conn).await?;
|
||||
|
||||
{
|
||||
let mut handle = conn.lock_handle().await?;
|
||||
assert!(handle.last_error().is_none());
|
||||
}
|
||||
|
||||
let _ = sqlx::query("invalid statement").fetch_one(&mut conn).await;
|
||||
|
||||
{
|
||||
let mut handle = conn.lock_handle().await?;
|
||||
assert!(handle.last_error().is_some());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user