diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 8d25e96f8..5b8718347 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -14,12 +14,12 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Use latest Rust - run: rustup override set stable + - name: Setup Rust + run: | + rustup show active-toolchain || rustup toolchain install + rustup override set stable - uses: Swatinem/rust-cache@v2 - with: - key: sqlx-cli - run: > cargo build @@ -63,9 +63,10 @@ jobs: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: mysql-examples - name: Todos (Setup) working-directory: examples/mysql/todos @@ -106,9 +107,8 @@ jobs: - uses: actions/checkout@v4 - - uses: Swatinem/rust-cache@v2 - with: - key: pg-examples + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install - name: Axum Social with Tests (Setup) working-directory: examples/postgres/axum-social-with-tests @@ -231,9 +231,10 @@ jobs: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: sqlite-examples - name: TODOs (Setup) env: diff --git a/.github/workflows/sqlx-cli.yml b/.github/workflows/sqlx-cli.yml index 3aeb3d7d3..2250e0bfc 100644 --- a/.github/workflows/sqlx-cli.yml +++ b/.github/workflows/sqlx-cli.yml @@ -15,8 +15,9 @@ jobs: steps: - uses: actions/checkout@v4 - - run: | - rustup update + - name: Setup Rust + run: | + rustup show active-toolchain || rustup toolchain install rustup component add clippy rustup toolchain install beta rustup component add --toolchain beta clippy @@ -40,18 +41,19 @@ jobs: matrix: # Note: macOS-latest uses M1 Silicon (ARM64) os: - - ubuntu-latest - # FIXME: migrations tests fail on Windows for whatever reason - # - windows-latest - - macOS-13 - - macOS-latest + - ubuntu-latest + # FIXME: migrations tests fail on Windows for whatever reason + # - windows-latest + - macOS-13 + - macOS-latest steps: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: ${{ runner.os }}-test - run: cargo test --manifest-path sqlx-cli/Cargo.toml @@ -85,12 +87,12 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Use latest Rust - run: rustup override set stable + - name: Setup Rust + run: | + rustup show active-toolchain || rustup toolchain install + rustup override set stable - uses: Swatinem/rust-cache@v2 - with: - key: ${{ runner.os }}-cli - run: cargo build --manifest-path sqlx-cli/Cargo.toml --bin cargo-sqlx ${{ matrix.args }} diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 3f1f44d39..7f573a634 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -10,7 +10,7 @@ on: jobs: format: name: Format - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: rustup component add rustfmt @@ -18,24 +18,25 @@ jobs: check: name: Check - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: - runtime: [async-std, tokio] - tls: [native-tls, rustls, none] + runtime: [ async-std, tokio ] + tls: [ native-tls, rustls, none ] steps: - uses: actions/checkout@v4 - - uses: Swatinem/rust-cache@v2 - with: - key: "${{ runner.os }}-check-${{ matrix.runtime }}-${{ matrix.tls }}" - - - run: | - rustup update + # Swatinem/rust-cache recommends setting up the rust toolchain first because it's used in cache keys + - name: Setup Rust + # https://blog.rust-lang.org/2025/03/02/Rustup-1.28.0.html + run: | + rustup show active-toolchain || rustup toolchain install rustup component add clippy rustup toolchain install beta rustup component add --toolchain beta clippy + - uses: Swatinem/rust-cache@v2 + - run: > cargo clippy --no-default-features @@ -52,26 +53,27 @@ jobs: check-minimal-versions: name: Check build using minimal versions - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - - run: rustup update - - run: rustup toolchain install nightly + - name: Setup Rust + run: | + rustup show active-toolchain || rustup toolchain install + rustup toolchain install nightly - run: cargo +nightly generate-lockfile -Z minimal-versions - run: cargo build --all-features test: name: Unit Tests - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - - uses: Swatinem/rust-cache@v2 - with: - key: ${{ runner.os }}-test + # https://blog.rust-lang.org/2025/03/02/Rustup-1.28.0.html + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install - - name: Install Rust - run: rustup update + - uses: Swatinem/rust-cache@v2 - name: Test sqlx-core run: > @@ -113,20 +115,22 @@ jobs: sqlite: name: SQLite - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: - runtime: [async-std, tokio] - linking: [sqlite, sqlite-unbundled] + runtime: [ async-std, tokio ] + linking: [ sqlite, sqlite-unbundled ] needs: check steps: - uses: actions/checkout@v4 - run: mkdir /tmp/sqlite3-lib && wget -O /tmp/sqlite3-lib/ipaddr.so https://github.com/nalgeon/sqlean/releases/download/0.15.2/ipaddr.so + # https://blog.rust-lang.org/2025/03/02/Rustup-1.28.0.html + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: "${{ runner.os }}-${{ matrix.linking }}-${{ matrix.runtime }}-${{ matrix.tls }}" - name: Install system sqlite library if: ${{ matrix.linking == 'sqlite-unbundled' }} @@ -179,19 +183,20 @@ jobs: postgres: name: Postgres - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: - postgres: [17, 13] - runtime: [async-std, tokio] - tls: [native-tls, rustls-aws-lc-rs, rustls-ring, none] + postgres: [ 17, 13 ] + runtime: [ async-std, tokio ] + tls: [ native-tls, rustls-aws-lc-rs, rustls-ring, none ] needs: check steps: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: "${{ runner.os }}-postgres-${{ matrix.runtime }}-${{ matrix.tls }}" - env: # FIXME: needed to disable `ltree` tests in Postgres 9.6 @@ -279,19 +284,20 @@ jobs: mysql: name: MySQL - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: - mysql: [8] - runtime: [async-std, tokio] - tls: [native-tls, rustls-aws-lc-rs, rustls-ring, none] + mysql: [ 8 ] + runtime: [ async-std, tokio ] + tls: [ native-tls, rustls-aws-lc-rs, rustls-ring, none ] needs: check steps: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: "${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ matrix.tls }}" - run: cargo build --features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} @@ -367,19 +373,20 @@ jobs: mariadb: name: MariaDB - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: - mariadb: [verylatest, 11_4, 10_11, 10_4] - runtime: [async-std, tokio] - tls: [native-tls, rustls-aws-lc-rs, rustls-ring, none] + mariadb: [ verylatest, 11_4, 10_11, 10_4 ] + runtime: [ async-std, tokio ] + tls: [ native-tls, rustls-aws-lc-rs, rustls-ring, none ] needs: check steps: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: "${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ matrix.tls }}" - run: cargo build --features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} diff --git a/Cargo.lock b/Cargo.lock index ab37075bc..9307cee6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2129,6 +2129,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ipnet" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" + [[package]] name = "ipnetwork" version = "0.20.0" @@ -3691,6 +3697,7 @@ dependencies = [ "hashbrown 0.15.2", "hashlink", "indexmap 2.7.0", + "ipnet", "ipnetwork", "log", "mac_address", @@ -3989,6 +3996,7 @@ dependencies = [ "hkdf", "hmac", "home", + "ipnet", "ipnetwork", "itoa", "log", diff --git a/Cargo.toml b/Cargo.toml index 4d74d38c4..50ef662d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,6 +72,7 @@ _unstable-all-types = [ "json", "time", "chrono", + "ipnet", "ipnetwork", "mac_address", "uuid", @@ -123,6 +124,7 @@ json = ["sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sq bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros?/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] bit-vec = ["sqlx-core/bit-vec", "sqlx-macros?/bit-vec", "sqlx-postgres?/bit-vec"] chrono = ["sqlx-core/chrono", "sqlx-macros?/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] +ipnet = ["sqlx-core/ipnet", "sqlx-macros?/ipnet", "sqlx-postgres?/ipnet"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-macros?/ipnetwork", "sqlx-postgres?/ipnetwork"] mac_address = ["sqlx-core/mac_address", "sqlx-macros?/mac_address", "sqlx-postgres?/mac_address"] rust_decimal = ["sqlx-core/rust_decimal", "sqlx-macros?/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] @@ -150,6 +152,7 @@ sqlx = { version = "=0.8.3", path = ".", default-features = false } bigdecimal = "0.4.0" bit-vec = "0.6.3" chrono = { version = "0.4.34", default-features = false, features = ["std", "clock"] } +ipnet = "2.3.0" ipnetwork = "0.20.0" mac_address = "1.1.5" rust_decimal = { version = "1.26.1", default-features = false, features = ["std"] } @@ -195,6 +198,7 @@ rand_xoshiro = "0.6.0" hex = "0.4.3" tempfile = "3.10.1" criterion = { version = "0.5.1", features = ["async_tokio"] } +libsqlite3-sys = { version = "0.30.1" } # If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`, # even when SQLite isn't the intended test target, and fail if the build environment is not set up for compiling C code. diff --git a/README.md b/README.md index c3b501ca4..cc0ecf2e6 100644 --- a/README.md +++ b/README.md @@ -220,6 +220,8 @@ be removed in the future. - `rust_decimal`: Add support for `NUMERIC` using the `rust_decimal` crate. +- `ipnet`: Add support for `INET` and `CIDR` (in postgres) using the `ipnet` crate. + - `ipnetwork`: Add support for `INET` and `CIDR` (in postgres) using the `ipnetwork` crate. - `json`: Add support for `JSON` and `JSONB` (in postgres) using the `serde_json` crate. diff --git a/ci.db b/ci.db deleted file mode 100644 index cc158a728..000000000 Binary files a/ci.db and /dev/null differ diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 97c2a8b7e..b87a69d2d 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -56,6 +56,7 @@ bit-vec = { workspace = true, optional = true } bigdecimal = { workspace = true, optional = true } rust_decimal = { workspace = true, optional = true } time = { workspace = true, optional = true } +ipnet = { workspace = true, optional = true } ipnetwork = { workspace = true, optional = true } mac_address = { workspace = true, optional = true } uuid = { workspace = true, optional = true } diff --git a/sqlx-core/src/acquire.rs b/sqlx-core/src/acquire.rs index c9d7fb215..59bac9fa5 100644 --- a/sqlx-core/src/acquire.rs +++ b/sqlx-core/src/acquire.rs @@ -93,7 +93,7 @@ impl<'a, DB: Database> Acquire<'a> for &'_ Pool { let conn = self.acquire(); Box::pin(async move { - Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?)).await + Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?), None).await }) } } @@ -121,7 +121,7 @@ macro_rules! impl_acquire { 'c, Result<$crate::transaction::Transaction<'c, $DB>, $crate::error::Error>, > { - $crate::transaction::Transaction::begin(self) + $crate::transaction::Transaction::begin(self, None) } } }; diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index b30cbe83f..6c84c1d8c 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -3,6 +3,7 @@ use crate::describe::Describe; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; +use std::borrow::Cow; use std::fmt::Debug; pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { @@ -26,7 +27,13 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn ping(&mut self) -> BoxFuture<'_, crate::Result<()>>; /// Begin a new transaction or establish a savepoint within the active transaction. - fn begin(&mut self) -> BoxFuture<'_, crate::Result<()>>; + /// + /// If this is a new transaction, `statement` may be used instead of the + /// default "BEGIN" statement. + /// + /// If we are already inside a transaction and `statement.is_some()`, then + /// `Error::InvalidSavePoint` is returned without running any statements. + fn begin(&mut self, statement: Option>) -> BoxFuture<'_, crate::Result<()>>; fn commit(&mut self) -> BoxFuture<'_, crate::Result<()>>; @@ -34,6 +41,26 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn start_rollback(&mut self); + /// Returns the current transaction depth. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&self) -> usize { + unimplemented!("get_transaction_depth() is not implemented for this backend. This is a provided method to avoid a breaking change, but it will become a required method in version 0.9 and later."); + } + + /// Checks if the connection is currently in a transaction. + /// + /// This method returns `true` if the current transaction depth is greater than 0, + /// indicating that a transaction is active. It returns `false` if the transaction depth is 0, + /// meaning no transaction is active. + #[inline] + fn is_in_transaction(&self) -> bool { + self.get_transaction_depth() != 0 + } + /// The number of statements currently cached in the connection. fn cached_statements_size(&self) -> usize { 0 diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index b6f795848..8cf8fc510 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; use crate::any::{Any, AnyConnectOptions}; use crate::connection::{ConnectOptions, Connection}; @@ -87,7 +88,17 @@ impl Connection for AnyConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-core/src/any/transaction.rs b/sqlx-core/src/any/transaction.rs index fce417562..a553cda92 100644 --- a/sqlx-core/src/any/transaction.rs +++ b/sqlx-core/src/any/transaction.rs @@ -1,6 +1,8 @@ use futures_util::future::BoxFuture; +use std::borrow::Cow; use crate::any::{Any, AnyConnection}; +use crate::database::Database; use crate::error::Error; use crate::transaction::TransactionManager; @@ -9,8 +11,11 @@ pub struct AnyTransactionManager; impl TransactionManager for AnyTransactionManager { type Database = Any; - fn begin(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> { - conn.backend.begin() + fn begin<'conn>( + conn: &'conn mut AnyConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { + conn.backend.begin(statement) } fn commit(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> { @@ -24,4 +29,8 @@ impl TransactionManager for AnyTransactionManager { fn start_rollback(conn: &mut AnyConnection) { conn.backend.start_rollback() } + + fn get_transaction_depth(conn: &::Connection) -> usize { + conn.backend.get_transaction_depth() + } } diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index ce2aa6c62..74e8cd3e8 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -1,9 +1,10 @@ use crate::database::{Database, HasStatementCache}; use crate::error::Error; -use crate::transaction::Transaction; +use crate::transaction::{Transaction, TransactionManager}; use futures_core::future::BoxFuture; use log::LevelFilter; +use std::borrow::Cow; use std::fmt::Debug; use std::str::FromStr; use std::time::Duration; @@ -49,6 +50,33 @@ pub trait Connection: Send { where Self: Sized; + /// Begin a new transaction with a custom statement. + /// + /// Returns a [`Transaction`] for controlling and tracking the new transaction. + /// + /// Returns an error if the connection is already in a transaction or if + /// `statement` does not put the connection into a transaction. + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) + } + + /// Returns `true` if the connection is currently in a transaction. + /// + /// # Note: Automatic Rollbacks May Not Be Counted + /// Certain database errors (such as a serializable isolation failure) + /// can cause automatic rollbacks of a transaction + /// which may not be indicated in the return value of this method. + #[inline] + fn is_in_transaction(&self) -> bool { + ::TransactionManager::get_transaction_depth(self) != 0 + } + /// Execute the function inside a transaction. /// /// If the function returns an error, the transaction will be rolled back. If it does not diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 17774addd..9ad5eff46 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -34,6 +34,12 @@ pub enum Error { #[error("error with configuration: {0}")] Configuration(#[source] BoxDynError), + /// One or more of the arguments to the called function was invalid. + /// + /// The string contains more information. + #[error("{0}")] + InvalidArgument(String), + /// Error returned from the database. #[error("error returned from database: {0}")] Database(#[source] Box), @@ -79,7 +85,7 @@ pub enum Error { }, /// Error occured while encoding a value. - #[error("error occured while encoding a value: {0}")] + #[error("error occurred while encoding a value: {0}")] Encode(#[source] BoxDynError), /// Error occurred while decoding a value. @@ -111,6 +117,12 @@ pub enum Error { #[cfg(feature = "migrate")] #[error("{0}")] Migrate(#[source] Box), + + #[error("attempted to call begin_with at non-zero transaction depth")] + InvalidSavePointStatement, + + #[error("got unexpected connection status after attempting to begin transaction")] + BeginFailed, } impl StdError for Box {} @@ -136,6 +148,12 @@ impl Error { Error::Protocol(err.to_string()) } + #[doc(hidden)] + #[inline] + pub fn database(err: impl DatabaseError) -> Self { + Error::Database(Box::new(err)) + } + #[doc(hidden)] #[inline] pub fn config(err: impl StdError + Send + Sync + 'static) -> Self { diff --git a/sqlx-core/src/ext/async_stream.rs b/sqlx-core/src/ext/async_stream.rs index a83aabed1..56777ca4d 100644 --- a/sqlx-core/src/ext/async_stream.rs +++ b/sqlx-core/src/ext/async_stream.rs @@ -121,7 +121,7 @@ impl<'a, T> Stream for TryAsyncStream<'a, T> { #[macro_export] macro_rules! try_stream { ($($block:tt)*) => { - $crate::ext::async_stream::TryAsyncStream::new(move |yielder| async move { + $crate::ext::async_stream::TryAsyncStream::new(move |yielder| ::tracing::Instrument::in_current_span(async move { // Anti-footgun: effectively pins `yielder` to this future to prevent any accidental // move to another task, which could deadlock. let yielder = &yielder; @@ -133,6 +133,6 @@ macro_rules! try_stream { } $($block)* - }) + })) } } diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index bf3a6d4b1..c029fec6e 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -191,7 +191,7 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection futures_core::future::BoxFuture<'c, Result, Error>> { - crate::transaction::Transaction::begin(&mut **self) + crate::transaction::Transaction::begin(&mut **self, None) } } diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 042bc5c7b..d85bce246 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -54,6 +54,7 @@ //! [`Pool::acquire`] or //! [`Pool::begin`]. +use std::borrow::Cow; use std::fmt; use std::future::Future; use std::pin::{pin, Pin}; @@ -109,7 +110,8 @@ mod options; /// application/daemon/web server/etc. and then shared with all tasks throughout the process' /// lifetime. How best to accomplish this depends on your program architecture. /// -/// In Actix-Web, for example, you can share a single pool with all request handlers using [web::Data]. +/// In Actix-Web, for example, you can efficiently share a single pool with all request handlers +/// using [web::ThinData]. /// /// Cloning `Pool` is cheap as it is simply a reference-counted handle to the inner pool state. /// When the last remaining handle to the pool is dropped, the connections owned by the pool are @@ -131,7 +133,7 @@ mod options; /// * [PgPool][crate::postgres::PgPool] (PostgreSQL) /// * [SqlitePool][crate::sqlite::SqlitePool] (SQLite) /// -/// [web::Data]: https://docs.rs/actix-web/3/actix_web/web/struct.Data.html +/// [web::ThinData]: https://docs.rs/actix-web/4.9.0/actix_web/web/struct.ThinData.html /// /// ### Note: Drop Behavior /// Due to a lack of async `Drop`, dropping the last `Pool` handle may not immediately clean @@ -367,13 +369,17 @@ impl Pool { /// Retrieves a connection and immediately begins a new transaction. pub async fn begin(&self) -> Result, Error> { - Transaction::begin(MaybePoolConnection::PoolConnection(self.acquire().await?)).await + Transaction::begin( + MaybePoolConnection::PoolConnection(self.acquire().await?), + None, + ) + .await } /// Attempts to retrieve a connection and immediately begins a new transaction if successful. pub async fn try_begin(&self) -> Result>, Error> { match self.try_acquire() { - Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn)) + Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn), None) .await .map(Some), @@ -381,6 +387,36 @@ impl Pool { } } + /// Retrieves a connection and immediately begins a new transaction using `statement`. + pub async fn begin_with( + &self, + statement: impl Into>, + ) -> Result, Error> { + Transaction::begin( + MaybePoolConnection::PoolConnection(self.acquire().await?), + Some(statement.into()), + ) + .await + } + + /// Attempts to retrieve a connection and, if successful, immediately begins a new + /// transaction using `statement`. + pub async fn try_begin_with( + &self, + statement: impl Into>, + ) -> Result>, Error> { + match self.try_acquire() { + Some(conn) => Transaction::begin( + MaybePoolConnection::PoolConnection(conn), + Some(statement.into()), + ) + .await + .map(Some), + + None => Ok(None), + } + } + /// Shut down the connection pool, immediately waking all tasks waiting for a connection. /// /// Upon calling this method, any currently waiting or subsequent calls to [`Pool::acquire`] and diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 9cd38aab3..2a84ff655 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -16,9 +16,16 @@ pub trait TransactionManager { type Database: Database; /// Begin a new transaction or establish a savepoint within the active transaction. - fn begin( - conn: &mut ::Connection, - ) -> BoxFuture<'_, Result<(), Error>>; + /// + /// If this is a new transaction, `statement` may be used instead of the + /// default "BEGIN" statement. + /// + /// If we are already inside a transaction and `statement.is_some()`, then + /// `Error::InvalidSavePoint` is returned without running any statements. + fn begin<'conn>( + conn: &'conn mut ::Connection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>>; /// Commit the active transaction or release the most recent savepoint. fn commit( @@ -32,6 +39,14 @@ pub trait TransactionManager { /// Starts to abort the active transaction or restore from the most recent snapshot. fn start_rollback(conn: &mut ::Connection); + + /// Returns the current transaction depth. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(conn: &::Connection) -> usize; } /// An in-progress database transaction or savepoint. @@ -83,11 +98,12 @@ where #[doc(hidden)] pub fn begin( conn: impl Into>, + statement: Option>, ) -> BoxFuture<'c, Result> { let mut conn = conn.into(); Box::pin(async move { - DB::TransactionManager::begin(&mut conn).await?; + DB::TransactionManager::begin(&mut conn, statement).await?; Ok(Self { connection: conn, @@ -237,7 +253,7 @@ impl<'c, 't, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<' #[inline] fn begin(self) -> BoxFuture<'t, Result, Error>> { - Transaction::begin(&mut **self) + Transaction::begin(&mut **self, None) } } diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 909dd4927..b00427daa 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -67,6 +67,13 @@ pub use bigdecimal::BigDecimal; #[doc(no_inline)] pub use rust_decimal::Decimal; +#[cfg(feature = "ipnet")] +#[cfg_attr(docsrs, doc(cfg(feature = "ipnet")))] +pub mod ipnet { + #[doc(no_inline)] + pub use ipnet::{IpNet, Ipv4Net, Ipv6Net}; +} + #[cfg(feature = "ipnetwork")] #[cfg_attr(docsrs, doc(cfg(feature = "ipnetwork")))] pub mod ipnetwork { diff --git a/sqlx-macros-core/Cargo.toml b/sqlx-macros-core/Cargo.toml index ad1a8e18e..1b534d96b 100644 --- a/sqlx-macros-core/Cargo.toml +++ b/sqlx-macros-core/Cargo.toml @@ -40,6 +40,7 @@ json = ["sqlx-core/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlit bigdecimal = ["sqlx-core/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] bit-vec = ["sqlx-core/bit-vec", "sqlx-postgres?/bit-vec"] chrono = ["sqlx-core/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] +ipnet = ["sqlx-core/ipnet", "sqlx-postgres?/ipnet"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-postgres?/ipnetwork"] mac_address = ["sqlx-core/mac_address", "sqlx-postgres?/mac_address"] rust_decimal = ["sqlx-core/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 6792af6ec..e6436986d 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -39,6 +39,7 @@ sqlite-unbundled = ["sqlx-macros-core/sqlite-unbundled"] bigdecimal = ["sqlx-macros-core/bigdecimal"] bit-vec = ["sqlx-macros-core/bit-vec"] chrono = ["sqlx-macros-core/chrono"] +ipnet = ["sqlx-macros-core/ipnet"] ipnetwork = ["sqlx-macros-core/ipnetwork"] mac_address = ["sqlx-macros-core/mac_address"] rust_decimal = ["sqlx-macros-core/rust_decimal"] diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index e01e41d68..19b3a6f27 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -16,6 +16,7 @@ use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; +use std::borrow::Cow; use std::{future, pin::pin}; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); @@ -37,8 +38,11 @@ impl AnyConnectionBackend for MySqlConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - MySqlTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + MySqlTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { @@ -53,6 +57,10 @@ impl AnyConnectionBackend for MySqlConnection { MySqlTransactionManager::start_rollback(self) } + fn get_transaction_depth(&self) -> usize { + MySqlTransactionManager::get_transaction_depth(self) + } + fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index 0623a0556..85a9d84f9 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -27,6 +27,7 @@ impl MySqlConnection { inner: Box::new(MySqlConnectionInner { stream, transaction_depth: 0, + status_flags: Default::default(), cache_statement: StatementCache::new(options.statement_cache_capacity), log_settings: options.log_settings.clone(), }), diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index bc8d0b620..4ad507b90 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -167,6 +167,8 @@ impl MySqlConnection { // this indicates either a successful query with no rows at all or a failed query let ok = packet.ok()?; + self.inner.status_flags = ok.status; + let rows_affected = ok.affected_rows; logger.increase_rows_affected(rows_affected); let done = MySqlQueryResult { @@ -209,6 +211,8 @@ impl MySqlConnection { if packet[0] == 0xfe && packet.len() < 9 { let eof = packet.eof(self.inner.stream.capabilities)?; + self.inner.status_flags = eof.status; + r#yield!(Either::Left(MySqlQueryResult { rows_affected: 0, last_insert_id: 0, diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index c4978a770..0a2f5fb83 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use futures_core::future::BoxFuture; @@ -7,6 +8,7 @@ pub(crate) use stream::{MySqlStream, Waiting}; use crate::common::StatementCache; use crate::error::Error; +use crate::protocol::response::Status; use crate::protocol::statement::StmtClose; use crate::protocol::text::{Ping, Quit}; use crate::statement::MySqlStatementMetadata; @@ -34,6 +36,7 @@ pub(crate) struct MySqlConnectionInner { // transaction status pub(crate) transaction_depth: usize, + status_flags: Status, // cache by query string to the statement id and metadata cache_statement: StatementCache<(u32, MySqlStatementMetadata)>, @@ -41,6 +44,14 @@ pub(crate) struct MySqlConnectionInner { log_settings: LogSettings, } +impl MySqlConnection { + pub(crate) fn in_transaction(&self) -> bool { + self.inner + .status_flags + .intersects(Status::SERVER_STATUS_IN_TRANS) + } +} + impl Debug for MySqlConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("MySqlConnection").finish() @@ -111,7 +122,17 @@ impl Connection for MySqlConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn shrink_buffers(&mut self) { diff --git a/sqlx-mysql/src/options/mod.rs b/sqlx-mysql/src/options/mod.rs index db2b20c19..87732cb40 100644 --- a/sqlx-mysql/src/options/mod.rs +++ b/sqlx-mysql/src/options/mod.rs @@ -448,7 +448,7 @@ impl MySqlConnectOptions { self.socket.as_ref() } - /// Get the server's port. + /// Get the current username. /// /// # Example /// diff --git a/sqlx-mysql/src/protocol/response/status.rs b/sqlx-mysql/src/protocol/response/status.rs index bf5013dee..4a8bb0375 100644 --- a/sqlx-mysql/src/protocol/response/status.rs +++ b/sqlx-mysql/src/protocol/response/status.rs @@ -1,7 +1,7 @@ // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad // https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status bitflags::bitflags! { - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct Status: u16 { // Is raised when a multi-statement transaction has been started, either explicitly, // by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index d8538cc2b..545cb5f4f 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use futures_core::future::BoxFuture; use crate::connection::Waiting; @@ -14,12 +16,24 @@ pub struct MySqlTransactionManager; impl TransactionManager for MySqlTransactionManager { type Database = MySql; - fn begin(conn: &mut MySqlConnection) -> BoxFuture<'_, Result<(), Error>> { + fn begin<'conn>( + conn: &'conn mut MySqlConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; - - conn.execute(&*begin_ansi_transaction_sql(depth)).await?; - conn.inner.transaction_depth = depth + 1; + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in a transaction + // (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; + conn.execute(&*statement).await?; + if !conn.in_transaction() { + return Err(Error::BeginFailed); + } + conn.inner.transaction_depth += 1; Ok(()) }) @@ -65,4 +79,8 @@ impl TransactionManager for MySqlTransactionManager { conn.inner.transaction_depth = depth - 1; } } + + fn get_transaction_depth(conn: &MySqlConnection) -> usize { + conn.inner.transaction_depth + } } diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 174a73b3f..818aadbab 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -19,6 +19,7 @@ offline = ["sqlx-core/offline"] bigdecimal = ["dep:bigdecimal", "dep:num-bigint", "sqlx-core/bigdecimal"] bit-vec = ["dep:bit-vec", "sqlx-core/bit-vec"] chrono = ["dep:chrono", "sqlx-core/chrono"] +ipnet = ["dep:ipnet", "sqlx-core/ipnet"] ipnetwork = ["dep:ipnetwork", "sqlx-core/ipnetwork"] mac_address = ["dep:mac_address", "sqlx-core/mac_address"] rust_decimal = ["dep:rust_decimal", "rust_decimal/maths", "sqlx-core/rust_decimal"] @@ -43,6 +44,7 @@ sha2 = { version = "0.10.0", default-features = false } bigdecimal = { workspace = true, optional = true } bit-vec = { workspace = true, optional = true } chrono = { workspace = true, optional = true } +ipnet = { workspace = true, optional = true } ipnetwork = { workspace = true, optional = true } mac_address = { workspace = true, optional = true } rust_decimal = { workspace = true, optional = true } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index a7b30fb65..762f53e5d 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -5,6 +5,7 @@ use crate::{ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; +use std::borrow::Cow; use std::{future, pin::pin}; use sqlx_core::any::{ @@ -39,8 +40,11 @@ impl AnyConnectionBackend for PgConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - PgTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + PgTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { @@ -55,6 +59,10 @@ impl AnyConnectionBackend for PgConnection { PgTransactionManager::start_rollback(self) } + fn get_transaction_depth(&self) -> usize { + PgTransactionManager::get_transaction_depth(self) + } + fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 3cb9ecaf6..26d87fda6 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::collections::BTreeMap; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -135,6 +136,13 @@ impl PgConnection { Ok(()) } + + pub(crate) fn in_transaction(&self) -> bool { + match self.inner.transaction_status { + TransactionStatus::Transaction => true, + TransactionStatus::Error | TransactionStatus::Idle => false, + } + } } impl Debug for PgConnection { @@ -187,7 +195,17 @@ impl Connection for PgConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index b96f8d829..17a46a916 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -9,6 +9,7 @@ use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::acquire::Acquire; use sqlx_core::transaction::Transaction; use sqlx_core::Either; +use tracing::Instrument; use crate::describe::Describe; use crate::error::Error; @@ -366,7 +367,7 @@ impl Drop for PgListener { }; // Unregister any listeners before returning the connection to the pool. - crate::rt::spawn(fut); + crate::rt::spawn(fut.in_current_span()); } } } diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index e7c78488e..23352a8dc 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -1,4 +1,6 @@ use futures_core::future::BoxFuture; +use sqlx_core::database::Database; +use std::borrow::Cow; use crate::error::Error; use crate::executor::Executor; @@ -13,13 +15,27 @@ pub struct PgTransactionManager; impl TransactionManager for PgTransactionManager { type Database = Postgres; - fn begin(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { + fn begin<'conn>( + conn: &'conn mut PgConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { + let depth = conn.inner.transaction_depth; + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in + // a transaction (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; + let rollback = Rollback::new(conn); - let query = begin_ansi_transaction_sql(rollback.conn.inner.transaction_depth); - rollback.conn.queue_simple_query(&query)?; - rollback.conn.inner.transaction_depth += 1; + rollback.conn.queue_simple_query(&statement)?; rollback.conn.wait_until_ready().await?; + if !rollback.conn.in_transaction() { + return Err(Error::BeginFailed); + } + rollback.conn.inner.transaction_depth += 1; rollback.defuse(); Ok(()) @@ -62,6 +78,10 @@ impl TransactionManager for PgTransactionManager { conn.inner.transaction_depth -= 1; } } + + fn get_transaction_depth(conn: &::Connection) -> usize { + conn.inner.transaction_depth + } } struct Rollback<'c> { diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index f89690b27..8f63cf97f 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -40,12 +40,21 @@ impl_type_checking!( sqlx::postgres::types::PgBox, + sqlx::postgres::types::PgPath, + + sqlx::postgres::types::PgPolygon, + + sqlx::postgres::types::PgCircle, + #[cfg(feature = "uuid")] sqlx::types::Uuid, #[cfg(feature = "ipnetwork")] sqlx::types::ipnetwork::IpNetwork, + #[cfg(feature = "ipnet")] + sqlx::types::ipnet::IpNet, + #[cfg(feature = "mac_address")] sqlx::types::mac_address::MacAddress, @@ -77,6 +86,9 @@ impl_type_checking!( #[cfg(feature = "ipnetwork")] Vec | &[sqlx::types::ipnetwork::IpNetwork], + #[cfg(feature = "ipnet")] + Vec | &[sqlx::types::ipnet::IpNet], + #[cfg(feature = "mac_address")] Vec | &[sqlx::types::mac_address::MacAddress], diff --git a/sqlx-postgres/src/types/geometry/box.rs b/sqlx-postgres/src/types/geometry/box.rs index 988c028ed..28016b278 100644 --- a/sqlx-postgres/src/types/geometry/box.rs +++ b/sqlx-postgres/src/types/geometry/box.rs @@ -23,7 +23,10 @@ const ERROR: &str = "error decoding BOX"; /// where `(upper_right_x,upper_right_y) and (lower_left_x,lower_left_y)` are any two opposite corners of the box. /// Any two opposite corners can be supplied on input, but the values will be reordered as needed to store the upper right and lower left corners, in that order. /// -/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-BOXES +/// See [Postgres Manual, Section 8.8.4: Geometric Types - Boxes][PG.S.8.8.4] for details. +/// +/// [PG.S.8.8.4]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-GEOMETRIC-BOXES +/// #[derive(Debug, Clone, PartialEq)] pub struct PgBox { pub upper_right_x: f64, diff --git a/sqlx-postgres/src/types/geometry/circle.rs b/sqlx-postgres/src/types/geometry/circle.rs new file mode 100644 index 000000000..dde54dd27 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/circle.rs @@ -0,0 +1,250 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use sqlx_core::Error; +use std::str::FromStr; + +const ERROR: &str = "error decoding CIRCLE"; + +/// ## Postgres Geometric Circle type +/// +/// Description: Circle +/// Representation: `< (x, y), radius >` (center point and radius) +/// +/// ```text +/// < ( x , y ) , radius > +/// ( ( x , y ) , radius ) +/// ( x , y ) , radius +/// x , y , radius +/// ``` +/// where `(x,y)` is the center point. +/// +/// See [Postgres Manual, Section 8.8.7, Geometric Types - Circles][PG.S.8.8.7] for details. +/// +/// [PG.S.8.8.7]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-CIRCLE +/// +#[derive(Debug, Clone, PartialEq)] +pub struct PgCircle { + pub x: f64, + pub y: f64, + pub radius: f64, +} + +impl Type for PgCircle { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("circle") + } +} + +impl PgHasArrayType for PgCircle { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_circle") + } +} + +impl<'r> Decode<'r, Postgres> for PgCircle { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgCircle::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgCircle::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgCircle { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("circle")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgCircle { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['<', '>', '(', ')', ' '], ""); + let mut parts = sanitised.split(','); + + let x = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get x from {}", ERROR, s))?; + + let y = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get y from {}", ERROR, s))?; + + let radius = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get radius from {}", ERROR, s))?; + + if parts.next().is_some() { + return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into()); + } + + if radius < 0. { + return Err(format!("{}: cannot have negative radius: {}", ERROR, s).into()); + } + + Ok(PgCircle { x, y, radius }) + } +} + +impl PgCircle { + fn from_bytes(mut bytes: &[u8]) -> Result { + let x = bytes.get_f64(); + let y = bytes.get_f64(); + let r = bytes.get_f64(); + Ok(PgCircle { x, y, radius: r }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), Error> { + buff.extend_from_slice(&self.x.to_be_bytes()); + buff.extend_from_slice(&self.y.to_be_bytes()); + buff.extend_from_slice(&self.radius.to_be_bytes()); + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +#[cfg(test)] +mod circle_tests { + + use std::str::FromStr; + + use super::PgCircle; + + const CIRCLE_BYTES: &[u8] = &[ + 63, 241, 153, 153, 153, 153, 153, 154, 64, 1, 153, 153, 153, 153, 153, 154, 64, 10, 102, + 102, 102, 102, 102, 102, + ]; + + #[test] + fn can_deserialise_circle_type_bytes() { + let circle = PgCircle::from_bytes(CIRCLE_BYTES).unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.1, + y: 2.2, + radius: 3.3 + } + ) + } + + #[test] + fn can_deserialise_circle_type_str() { + let circle = PgCircle::from_str("<(1, 2), 3 >").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn can_deserialise_circle_type_str_second_syntax() { + let circle = PgCircle::from_str("((1, 2), 3 )").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn can_deserialise_circle_type_str_third_syntax() { + let circle = PgCircle::from_str("(1, 2), 3 ").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn can_deserialise_circle_type_str_fourth_syntax() { + let circle = PgCircle::from_str("1, 2, 3 ").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn cannot_deserialise_circle_invalid_numbers() { + let input_str = "1, 2, Three"; + let circle = PgCircle::from_str(input_str); + assert!(circle.is_err()); + if let Err(err) = circle { + assert_eq!( + err.to_string(), + format!("error decoding CIRCLE: could not get radius from {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_circle_negative_radius() { + let input_str = "1, 2, -3"; + let circle = PgCircle::from_str(input_str); + assert!(circle.is_err()); + if let Err(err) = circle { + assert_eq!( + err.to_string(), + format!("error decoding CIRCLE: cannot have negative radius: {input_str}") + ) + } + } + + #[test] + fn can_deserialise_circle_type_str_float() { + let circle = PgCircle::from_str("<(1.1, 2.2), 3.3>").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.1, + y: 2.2, + radius: 3.3 + } + ); + } + + #[test] + fn can_serialise_circle_type() { + let circle = PgCircle { + x: 1.1, + y: 2.2, + radius: 3.3, + }; + assert_eq!(circle.serialize_to_vec(), CIRCLE_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/geometry/line.rs b/sqlx-postgres/src/types/geometry/line.rs index 43f93c1c3..8f08c949e 100644 --- a/sqlx-postgres/src/types/geometry/line.rs +++ b/sqlx-postgres/src/types/geometry/line.rs @@ -15,7 +15,10 @@ const ERROR: &str = "error decoding LINE"; /// /// Lines are represented by the linear equation Ax + By + C = 0, where A and B are not both zero. /// -/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LINE +/// See [Postgres Manual, Section 8.8.2, Geometric Types - Lines][PG.S.8.8.2] for details. +/// +/// [PG.S.8.8.2]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-LINE +/// #[derive(Debug, Clone, PartialEq)] pub struct PgLine { pub a: f64, diff --git a/sqlx-postgres/src/types/geometry/line_segment.rs b/sqlx-postgres/src/types/geometry/line_segment.rs index 5dc5efc74..cd08e4da4 100644 --- a/sqlx-postgres/src/types/geometry/line_segment.rs +++ b/sqlx-postgres/src/types/geometry/line_segment.rs @@ -23,7 +23,10 @@ const ERROR: &str = "error decoding LSEG"; /// ``` /// where `(start_x,start_y) and (end_x,end_y)` are the end points of the line segment. /// -/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LSEG +/// See [Postgres Manual, Section 8.8.3, Geometric Types - Line Segments][PG.S.8.8.3] for details. +/// +/// [PG.S.8.8.3]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-LSEG +/// #[doc(alias = "line segment")] #[derive(Debug, Clone, PartialEq)] pub struct PgLSeg { diff --git a/sqlx-postgres/src/types/geometry/mod.rs b/sqlx-postgres/src/types/geometry/mod.rs index 7fe2898fc..c3142145e 100644 --- a/sqlx-postgres/src/types/geometry/mod.rs +++ b/sqlx-postgres/src/types/geometry/mod.rs @@ -1,4 +1,7 @@ pub mod r#box; +pub mod circle; pub mod line; pub mod line_segment; +pub mod path; pub mod point; +pub mod polygon; diff --git a/sqlx-postgres/src/types/geometry/path.rs b/sqlx-postgres/src/types/geometry/path.rs new file mode 100644 index 000000000..6799289fa --- /dev/null +++ b/sqlx-postgres/src/types/geometry/path.rs @@ -0,0 +1,375 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::{PgPoint, Type}; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use sqlx_core::Error; +use std::mem; +use std::str::FromStr; + +const BYTE_WIDTH: usize = mem::size_of::(); + +/// ## Postgres Geometric Path type +/// +/// Description: Open path or Closed path (similar to polygon) +/// Representation: Open `[(x1,y1),...]`, Closed `((x1,y1),...)` +/// +/// Paths are represented by lists of connected points. Paths can be open, where the first and last points in the list are considered not connected, or closed, where the first and last points are considered connected. +/// Values of type path are specified using any of the following syntaxes: +/// ```text +/// [ ( x1 , y1 ) , ... , ( xn , yn ) ] +/// ( ( x1 , y1 ) , ... , ( xn , yn ) ) +/// ( x1 , y1 ) , ... , ( xn , yn ) +/// ( x1 , y1 , ... , xn , yn ) +/// x1 , y1 , ... , xn , yn +/// ``` +/// where the points are the end points of the line segments comprising the path. Square brackets `([])` indicate an open path, while parentheses `(())` indicate a closed path. +/// When the outermost parentheses are omitted, as in the third through fifth syntaxes, a closed path is assumed. +/// +/// See [Postgres Manual, Section 8.8.5, Geometric Types - Paths][PG.S.8.8.5] for details. +/// +/// [PG.S.8.8.5]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-GEOMETRIC-PATHS +/// +#[derive(Debug, Clone, PartialEq)] +pub struct PgPath { + pub closed: bool, + pub points: Vec, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +struct Header { + is_closed: bool, + length: usize, +} + +impl Type for PgPath { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("path") + } +} + +impl PgHasArrayType for PgPath { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_path") + } +} + +impl<'r> Decode<'r, Postgres> for PgPath { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgPath::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgPath::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgPath { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("path")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgPath { + type Err = Error; + + fn from_str(s: &str) -> Result { + let closed = !s.contains('['); + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let parts = sanitised.split(',').collect::>(); + + let mut points = vec![]; + + if parts.len() % 2 != 0 { + return Err(Error::Decode( + format!("Unmatched pair in PATH: {}", s).into(), + )); + } + + for chunk in parts.chunks_exact(2) { + if let [x_str, y_str] = chunk { + let x = parse_float_from_str(x_str, "could not get x")?; + let y = parse_float_from_str(y_str, "could not get y")?; + + let point = PgPoint { x, y }; + points.push(point); + } + } + + if !points.is_empty() { + return Ok(PgPath { points, closed }); + } + + Err(Error::Decode( + format!("could not get path from {}", s).into(), + )) + } +} + +impl PgPath { + fn header(&self) -> Header { + Header { + is_closed: self.closed, + length: self.points.len(), + } + } + + fn from_bytes(mut bytes: &[u8]) -> Result { + let header = Header::try_read(&mut bytes)?; + + if bytes.len() != header.data_size() { + return Err(format!( + "expected {} bytes after header, got {}", + header.data_size(), + bytes.len() + ) + .into()); + } + + if bytes.len() % BYTE_WIDTH * 2 != 0 { + return Err(format!( + "data length not divisible by pairs of {BYTE_WIDTH}: {}", + bytes.len() + ) + .into()); + } + + let mut out_points = Vec::with_capacity(bytes.len() / (BYTE_WIDTH * 2)); + + while bytes.has_remaining() { + let point = PgPoint { + x: bytes.get_f64(), + y: bytes.get_f64(), + }; + out_points.push(point) + } + Ok(PgPath { + closed: header.is_closed, + points: out_points, + }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { + let header = self.header(); + buff.reserve(header.data_size()); + header.try_write(buff)?; + + for point in &self.points { + buff.extend_from_slice(&point.x.to_be_bytes()); + buff.extend_from_slice(&point.y.to_be_bytes()); + } + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +impl Header { + const HEADER_WIDTH: usize = mem::size_of::() + mem::size_of::(); + + fn data_size(&self) -> usize { + self.length * BYTE_WIDTH * 2 + } + + fn try_read(buf: &mut &[u8]) -> Result { + if buf.len() < Self::HEADER_WIDTH { + return Err(format!( + "expected PATH data to contain at least {} bytes, got {}", + Self::HEADER_WIDTH, + buf.len() + )); + } + + let is_closed = buf.get_i8(); + let length = buf.get_i32(); + + let length = usize::try_from(length).ok().ok_or_else(|| { + format!( + "received PATH data length: {length}. Expected length between 0 and {}", + usize::MAX + ) + })?; + + Ok(Self { + is_closed: is_closed != 0, + length, + }) + } + + fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + let is_closed = self.is_closed as i8; + + let length = i32::try_from(self.length).map_err(|_| { + format!( + "PATH length exceeds allowed maximum ({} > {})", + self.length, + i32::MAX + ) + })?; + + buff.extend(is_closed.to_be_bytes()); + buff.extend(length.to_be_bytes()); + + Ok(()) + } +} + +fn parse_float_from_str(s: &str, error_msg: &str) -> Result { + s.parse().map_err(|_| Error::Decode(error_msg.into())) +} + +#[cfg(test)] +mod path_tests { + + use std::str::FromStr; + + use crate::types::PgPoint; + + use super::PgPath; + + const PATH_CLOSED_BYTES: &[u8] = &[ + 1, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, + 64, 16, 0, 0, 0, 0, 0, 0, + ]; + + const PATH_OPEN_BYTES: &[u8] = &[ + 0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, + 64, 16, 0, 0, 0, 0, 0, 0, + ]; + + const PATH_UNEVEN_POINTS: &[u8] = &[ + 0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, + 64, 16, 0, 0, + ]; + + #[test] + fn can_deserialise_path_type_bytes_closed() { + let path = PgPath::from_bytes(PATH_CLOSED_BYTES).unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }] + } + ) + } + + #[test] + fn cannot_deserialise_path_type_uneven_point_bytes() { + let path = PgPath::from_bytes(PATH_UNEVEN_POINTS); + assert!(path.is_err()); + + if let Err(err) = path { + assert_eq!( + err.to_string(), + format!("expected 32 bytes after header, got 28") + ) + } + } + + #[test] + fn can_deserialise_path_type_bytes_open() { + let path = PgPath::from_bytes(PATH_OPEN_BYTES).unwrap(); + assert_eq!( + path, + PgPath { + closed: false, + points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }] + } + ) + } + + #[test] + fn can_deserialise_path_type_str_first_syntax() { + let path = PgPath::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + path, + PgPath { + closed: false, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn cannot_deserialise_path_type_str_uneven_points_first_syntax() { + let input_str = "[( 1, 2), (3)]"; + let path = PgPath::from_str(input_str); + + assert!(path.is_err()); + + if let Err(err) = path { + assert_eq!( + err.to_string(), + format!("error occurred while decoding: Unmatched pair in PATH: {input_str}") + ) + } + } + + #[test] + fn can_deserialise_path_type_str_second_syntax() { + let path = PgPath::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_path_type_str_third_syntax() { + let path = PgPath::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_path_type_str_fourth_syntax() { + let path = PgPath::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_path_type_str_float() { + let path = PgPath::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1.1, y: 2.2 }, PgPoint { x: 3.3, y: 4.4 }] + } + ); + } + + #[test] + fn can_serialise_path_type() { + let path = PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }], + }; + assert_eq!(path.serialize_to_vec(), PATH_CLOSED_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/geometry/point.rs b/sqlx-postgres/src/types/geometry/point.rs index cc1067295..83b7c24d0 100644 --- a/sqlx-postgres/src/types/geometry/point.rs +++ b/sqlx-postgres/src/types/geometry/point.rs @@ -19,7 +19,10 @@ use std::str::FromStr; /// ```` /// where x and y are the respective coordinates, as floating-point numbers. /// -/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-POINTS +/// See [Postgres Manual, Section 8.8.1, Geometric Types - Points][PG.S.8.8.1] for details. +/// +/// [PG.S.8.8.1]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-GEOMETRIC-POINTS +/// #[derive(Debug, Clone, PartialEq)] pub struct PgPoint { pub x: f64, diff --git a/sqlx-postgres/src/types/geometry/polygon.rs b/sqlx-postgres/src/types/geometry/polygon.rs new file mode 100644 index 000000000..a5a203c68 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/polygon.rs @@ -0,0 +1,366 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::{PgPoint, Type}; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use sqlx_core::Error; +use std::mem; +use std::str::FromStr; + +const BYTE_WIDTH: usize = mem::size_of::(); + +/// ## Postgres Geometric Polygon type +/// +/// Description: Polygon (similar to closed polygon) +/// Representation: `((x1,y1),...)` +/// +/// Polygons are represented by lists of points (the vertexes of the polygon). Polygons are very similar to closed paths; the essential semantic difference is that a polygon is considered to include the area within it, while a path is not. +/// An important implementation difference between polygons and paths is that the stored representation of a polygon includes its smallest bounding box. This speeds up certain search operations, although computing the bounding box adds overhead while constructing new polygons. +/// Values of type polygon are specified using any of the following syntaxes: +/// +/// ```text +/// ( ( x1 , y1 ) , ... , ( xn , yn ) ) +/// ( x1 , y1 ) , ... , ( xn , yn ) +/// ( x1 , y1 , ... , xn , yn ) +/// x1 , y1 , ... , xn , yn +/// ``` +/// +/// where the points are the end points of the line segments comprising the boundary of the polygon. +/// +/// See [Postgres Manual, Section 8.8.6, Geometric Types - Polygons][PG.S.8.8.6] for details. +/// +/// [PG.S.8.8.6]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-POLYGON +/// +#[derive(Debug, Clone, PartialEq)] +pub struct PgPolygon { + pub points: Vec, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +struct Header { + length: usize, +} + +impl Type for PgPolygon { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("polygon") + } +} + +impl PgHasArrayType for PgPolygon { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_polygon") + } +} + +impl<'r> Decode<'r, Postgres> for PgPolygon { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgPolygon::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgPolygon::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgPolygon { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("polygon")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgPolygon { + type Err = Error; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let parts = sanitised.split(',').collect::>(); + + let mut points = vec![]; + + if parts.len() % 2 != 0 { + return Err(Error::Decode( + format!("Unmatched pair in POLYGON: {}", s).into(), + )); + } + + for chunk in parts.chunks_exact(2) { + if let [x_str, y_str] = chunk { + let x = parse_float_from_str(x_str, "could not get x")?; + let y = parse_float_from_str(y_str, "could not get y")?; + + let point = PgPoint { x, y }; + points.push(point); + } + } + + if !points.is_empty() { + return Ok(PgPolygon { points }); + } + + Err(Error::Decode( + format!("could not get polygon from {}", s).into(), + )) + } +} + +impl PgPolygon { + fn header(&self) -> Header { + Header { + length: self.points.len(), + } + } + + fn from_bytes(mut bytes: &[u8]) -> Result { + let header = Header::try_read(&mut bytes)?; + + if bytes.len() != header.data_size() { + return Err(format!( + "expected {} bytes after header, got {}", + header.data_size(), + bytes.len() + ) + .into()); + } + + if bytes.len() % BYTE_WIDTH * 2 != 0 { + return Err(format!( + "data length not divisible by pairs of {BYTE_WIDTH}: {}", + bytes.len() + ) + .into()); + } + + let mut out_points = Vec::with_capacity(bytes.len() / (BYTE_WIDTH * 2)); + while bytes.has_remaining() { + let point = PgPoint { + x: bytes.get_f64(), + y: bytes.get_f64(), + }; + out_points.push(point) + } + Ok(PgPolygon { points: out_points }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { + let header = self.header(); + buff.reserve(header.data_size()); + header.try_write(buff)?; + + for point in &self.points { + buff.extend_from_slice(&point.x.to_be_bytes()); + buff.extend_from_slice(&point.y.to_be_bytes()); + } + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +impl Header { + const HEADER_WIDTH: usize = mem::size_of::() + mem::size_of::(); + + fn data_size(&self) -> usize { + self.length * BYTE_WIDTH * 2 + } + + fn try_read(buf: &mut &[u8]) -> Result { + if buf.len() < Self::HEADER_WIDTH { + return Err(format!( + "expected polygon data to contain at least {} bytes, got {}", + Self::HEADER_WIDTH, + buf.len() + )); + } + + let length = buf.get_i32(); + + let length = usize::try_from(length).ok().ok_or_else(|| { + format!( + "received polygon with length: {length}. Expected length between 0 and {}", + usize::MAX + ) + })?; + + Ok(Self { length }) + } + + fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + let length = i32::try_from(self.length).map_err(|_| { + format!( + "polygon length exceeds allowed maximum ({} > {})", + self.length, + i32::MAX + ) + })?; + + buff.extend(length.to_be_bytes()); + + Ok(()) + } +} + +fn parse_float_from_str(s: &str, error_msg: &str) -> Result { + s.parse().map_err(|_| Error::Decode(error_msg.into())) +} + +#[cfg(test)] +mod polygon_tests { + + use std::str::FromStr; + + use crate::types::PgPoint; + + use super::PgPolygon; + + const POLYGON_BYTES: &[u8] = &[ + 0, 0, 0, 12, 192, 0, 0, 0, 0, 0, 0, 0, 192, 8, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, + 0, 192, 8, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 63, + 240, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, + 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 192, + 8, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 192, 8, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 191, + 240, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, + 0, 0, 0, + ]; + + #[test] + fn can_deserialise_polygon_type_bytes() { + let polygon = PgPolygon::from_bytes(POLYGON_BYTES).unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![ + PgPoint { x: -2., y: -3. }, + PgPoint { x: -1., y: -3. }, + PgPoint { x: -1., y: -1. }, + PgPoint { x: 1., y: 1. }, + PgPoint { x: 1., y: 3. }, + PgPoint { x: 2., y: 3. }, + PgPoint { x: 2., y: -3. }, + PgPoint { x: 1., y: -3. }, + PgPoint { x: 1., y: 0. }, + PgPoint { x: -1., y: 0. }, + PgPoint { x: -1., y: -2. }, + PgPoint { x: -2., y: -2. } + ] + } + ) + } + + #[test] + fn can_deserialise_polygon_type_str_first_syntax() { + let polygon = PgPolygon::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_polygon_type_str_second_syntax() { + let polygon = PgPolygon::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn cannot_deserialise_polygon_type_str_uneven_points_first_syntax() { + let input_str = "[( 1, 2), (3)]"; + let polygon = PgPolygon::from_str(input_str); + + assert!(polygon.is_err()); + + if let Err(err) = polygon { + assert_eq!( + err.to_string(), + format!("error occurred while decoding: Unmatched pair in POLYGON: {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_polygon_type_str_invalid_numbers() { + let input_str = "[( 1, 2), (2, three)]"; + let polygon = PgPolygon::from_str(input_str); + + assert!(polygon.is_err()); + + if let Err(err) = polygon { + assert_eq!( + err.to_string(), + format!("error occurred while decoding: could not get y") + ) + } + } + + #[test] + fn can_deserialise_polygon_type_str_third_syntax() { + let polygon = PgPolygon::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_polygon_type_str_fourth_syntax() { + let polygon = PgPolygon::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_polygon_type_str_float() { + let polygon = PgPolygon::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1.1, y: 2.2 }, PgPoint { x: 3.3, y: 4.4 }] + } + ); + } + + #[test] + fn can_serialise_polygon_type() { + let polygon = PgPolygon { + points: vec![ + PgPoint { x: -2., y: -3. }, + PgPoint { x: -1., y: -3. }, + PgPoint { x: -1., y: -1. }, + PgPoint { x: 1., y: 1. }, + PgPoint { x: 1., y: 3. }, + PgPoint { x: 2., y: 3. }, + PgPoint { x: 2., y: -3. }, + PgPoint { x: 1., y: -3. }, + PgPoint { x: 1., y: 0. }, + PgPoint { x: -1., y: 0. }, + PgPoint { x: -1., y: -2. }, + PgPoint { x: -2., y: -2. }, + ], + }; + assert_eq!(polygon.serialize_to_vec(), POLYGON_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/ipnet/ipaddr.rs b/sqlx-postgres/src/types/ipnet/ipaddr.rs new file mode 100644 index 000000000..b157eff3c --- /dev/null +++ b/sqlx-postgres/src/types/ipnet/ipaddr.rs @@ -0,0 +1,62 @@ +use std::net::IpAddr; + +use ipnet::IpNet; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres}; + +impl Type for IpAddr +where + IpNet: Type, +{ + fn type_info() -> PgTypeInfo { + IpNet::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + IpNet::compatible(ty) + } +} + +impl PgHasArrayType for IpAddr { + fn array_type_info() -> PgTypeInfo { + ::array_type_info() + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + ::array_compatible(ty) + } +} + +impl<'db> Encode<'db, Postgres> for IpAddr +where + IpNet: Encode<'db, Postgres>, +{ + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + IpNet::from(*self).encode_by_ref(buf) + } + + fn size_hint(&self) -> usize { + IpNet::from(*self).size_hint() + } +} + +impl<'db> Decode<'db, Postgres> for IpAddr +where + IpNet: Decode<'db, Postgres>, +{ + fn decode(value: PgValueRef<'db>) -> Result { + let ipnet = IpNet::decode(value)?; + + if matches!(ipnet, IpNet::V4(net) if net.prefix_len() != 32) + || matches!(ipnet, IpNet::V6(net) if net.prefix_len() != 128) + { + Err("lossy decode from inet/cidr")? + } + + Ok(ipnet.addr()) + } +} diff --git a/sqlx-postgres/src/types/ipnet/ipnet.rs b/sqlx-postgres/src/types/ipnet/ipnet.rs new file mode 100644 index 000000000..1f986174b --- /dev/null +++ b/sqlx-postgres/src/types/ipnet/ipnet.rs @@ -0,0 +1,130 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + +#[cfg(feature = "ipnet")] +use ipnet::{IpNet, Ipv4Net, Ipv6Net}; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; + +// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/include/utils/inet.h#L39 + +// Technically this is a magic number here but it doesn't make sense to drag in the whole of `libc` +// just for one constant. +const PGSQL_AF_INET: u8 = 2; // AF_INET +const PGSQL_AF_INET6: u8 = PGSQL_AF_INET + 1; + +impl Type for IpNet { + fn type_info() -> PgTypeInfo { + PgTypeInfo::INET + } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::CIDR || *ty == PgTypeInfo::INET + } +} + +impl PgHasArrayType for IpNet { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::INET_ARRAY + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::CIDR_ARRAY || *ty == PgTypeInfo::INET_ARRAY + } +} + +impl Encode<'_, Postgres> for IpNet { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L293 + // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L271 + + match self { + IpNet::V4(net) => { + buf.push(PGSQL_AF_INET); // ip_family + buf.push(net.prefix_len()); // ip_bits + buf.push(0); // is_cidr + buf.push(4); // nb (number of bytes) + buf.extend_from_slice(&net.addr().octets()) // address + } + + IpNet::V6(net) => { + buf.push(PGSQL_AF_INET6); // ip_family + buf.push(net.prefix_len()); // ip_bits + buf.push(0); // is_cidr + buf.push(16); // nb (number of bytes) + buf.extend_from_slice(&net.addr().octets()); // address + } + } + + Ok(IsNull::No) + } + + fn size_hint(&self) -> usize { + match self { + IpNet::V4(_) => 8, + IpNet::V6(_) => 20, + } + } +} + +impl Decode<'_, Postgres> for IpNet { + fn decode(value: PgValueRef<'_>) -> Result { + let bytes = match value.format() { + PgValueFormat::Binary => value.as_bytes()?, + PgValueFormat::Text => { + let s = value.as_str()?; + println!("{s}"); + if s.contains('/') { + return Ok(s.parse()?); + } + // IpNet::from_str doesn't handle conversion from IpAddr to IpNet + let addr: IpAddr = s.parse()?; + return Ok(addr.into()); + } + }; + + if bytes.len() >= 8 { + let family = bytes[0]; + let prefix = bytes[1]; + let _is_cidr = bytes[2] != 0; + let len = bytes[3]; + + match family { + PGSQL_AF_INET => { + if bytes.len() == 8 && len == 4 { + let inet = Ipv4Net::new( + Ipv4Addr::new(bytes[4], bytes[5], bytes[6], bytes[7]), + prefix, + )?; + + return Ok(IpNet::V4(inet)); + } + } + + PGSQL_AF_INET6 => { + if bytes.len() == 20 && len == 16 { + let inet = Ipv6Net::new( + Ipv6Addr::from([ + bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9], + bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], + bytes[16], bytes[17], bytes[18], bytes[19], + ]), + prefix, + )?; + + return Ok(IpNet::V6(inet)); + } + } + + _ => { + return Err(format!("unknown ip family {family}").into()); + } + } + } + + Err("invalid data received when expecting an INET".into()) + } +} diff --git a/sqlx-postgres/src/types/ipnet/mod.rs b/sqlx-postgres/src/types/ipnet/mod.rs new file mode 100644 index 000000000..cd40cf30d --- /dev/null +++ b/sqlx-postgres/src/types/ipnet/mod.rs @@ -0,0 +1,7 @@ +// Prefer `ipnetwork` over `ipnet` because it was implemented first (want to avoid breaking change). +#[cfg(not(feature = "ipnetwork"))] +mod ipaddr; + +// Parent module is named after the `ipnet` crate, this is named after the `IpNet` type. +#[allow(clippy::module_inception)] +mod ipnet; diff --git a/sqlx-postgres/src/types/ipaddr.rs b/sqlx-postgres/src/types/ipnetwork/ipaddr.rs similarity index 100% rename from sqlx-postgres/src/types/ipaddr.rs rename to sqlx-postgres/src/types/ipnetwork/ipaddr.rs diff --git a/sqlx-postgres/src/types/ipnetwork.rs b/sqlx-postgres/src/types/ipnetwork/ipnetwork.rs similarity index 100% rename from sqlx-postgres/src/types/ipnetwork.rs rename to sqlx-postgres/src/types/ipnetwork/ipnetwork.rs diff --git a/sqlx-postgres/src/types/ipnetwork/mod.rs b/sqlx-postgres/src/types/ipnetwork/mod.rs new file mode 100644 index 000000000..de40244c6 --- /dev/null +++ b/sqlx-postgres/src/types/ipnetwork/mod.rs @@ -0,0 +1,5 @@ +mod ipaddr; + +// Parent module is named after the `ipnetwork` crate, this is named after the `IpNetwork` type. +#[allow(clippy::module_inception)] +mod ipnetwork; diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index a5fd70836..0faefbb48 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -25,6 +25,9 @@ //! | [`PgLine`] | LINE | //! | [`PgLSeg`] | LSEG | //! | [`PgBox`] | BOX | +//! | [`PgPath`] | PATH | +//! | [`PgPolygon`] | POLYGON | +//! | [`PgCircle`] | CIRCLE | //! | [`PgHstore`] | HSTORE | //! //! 1 SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc., @@ -84,7 +87,7 @@ //! //! ### [`ipnetwork`](https://crates.io/crates/ipnetwork) //! -//! Requires the `ipnetwork` Cargo feature flag. +//! Requires the `ipnetwork` Cargo feature flag (takes precedence over `ipnet` if both are used). //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| @@ -97,6 +100,17 @@ //! //! `IpNetwork` does not have this limitation. //! +//! ### [`ipnet`](https://crates.io/crates/ipnet) +//! +//! Requires the `ipnet` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `ipnet::IpNet` | INET, CIDR | +//! | `std::net::IpAddr` | INET, CIDR | +//! +//! The same `IpAddr` limitation for smaller network prefixes applies as with `ipnet`. +//! //! ### [`mac_address`](https://crates.io/crates/mac_address) //! //! Requires the `mac_address` Cargo feature flag. @@ -245,11 +259,11 @@ mod time; #[cfg(feature = "uuid")] mod uuid; -#[cfg(feature = "ipnetwork")] -mod ipnetwork; +#[cfg(feature = "ipnet")] +mod ipnet; #[cfg(feature = "ipnetwork")] -mod ipaddr; +mod ipnetwork; #[cfg(feature = "mac_address")] mod mac_address; @@ -260,9 +274,12 @@ mod bit_vec; pub use array::PgHasArrayType; pub use citext::PgCiText; pub use cube::PgCube; +pub use geometry::circle::PgCircle; pub use geometry::line::PgLine; pub use geometry::line_segment::PgLSeg; +pub use geometry::path::PgPath; pub use geometry::point::PgPoint; +pub use geometry::polygon::PgPolygon; pub use geometry::r#box::PgBox; pub use hstore::PgHstore; pub use interval::PgInterval; diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 2cc585540..c72370d0f 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use crate::{ Either, Sqlite, SqliteArgumentValue, SqliteArguments, SqliteColumn, SqliteConnectOptions, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteTransactionManager, SqliteTypeInfo, @@ -38,8 +40,11 @@ impl AnyConnectionBackend for SqliteConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - SqliteTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + SqliteTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { @@ -54,6 +59,10 @@ impl AnyConnectionBackend for SqliteConnection { SqliteTransactionManager::start_rollback(self) } + fn get_transaction_depth(&self) -> usize { + SqliteTransactionManager::get_transaction_depth(self) + } + fn shrink_buffers(&mut self) { // NO-OP. } diff --git a/sqlx-sqlite/src/connection/collation.rs b/sqlx-sqlite/src/connection/collation.rs index 573a9af89..e7422138b 100644 --- a/sqlx-sqlite/src/connection/collation.rs +++ b/sqlx-sqlite/src/connection/collation.rs @@ -10,7 +10,6 @@ use libsqlite3_sys::{sqlite3_create_collation_v2, SQLITE_OK, SQLITE_UTF8}; use crate::connection::handle::ConnectionHandle; use crate::error::Error; -use crate::SqliteError; #[derive(Clone)] pub struct Collation { @@ -67,7 +66,7 @@ impl Collation { } else { // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails. drop(unsafe { Arc::from_raw(raw_f) }); - Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))) + Err(handle.expect_error().into()) } } } @@ -112,7 +111,7 @@ where } else { // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails. drop(unsafe { Box::from_raw(boxed_f) }); - Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))) + Err(handle.expect_error().into()) } } diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 5b8aa01b6..c5d2450fb 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -204,10 +204,10 @@ impl EstablishParams { // SAFE: tested for NULL just above // This allows any returns below to close this handle with RAII - let handle = unsafe { ConnectionHandle::new(handle) }; + let mut handle = unsafe { ConnectionHandle::new(handle) }; if status != SQLITE_OK { - return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); + return Err(Error::Database(Box::new(handle.expect_error()))); } // Enable extended result codes @@ -226,33 +226,29 @@ impl EstablishParams { for ext in self.extensions.iter() { // `sqlite3_load_extension` is unusual as it returns its errors via an out-pointer // rather than by calling `sqlite3_errmsg` - let mut error = null_mut(); + let mut error_msg = null_mut(); status = unsafe { sqlite3_load_extension( handle.as_ptr(), ext.0.as_ptr(), ext.1.as_ref().map_or(null(), |e| e.as_ptr()), - addr_of_mut!(error), + addr_of_mut!(error_msg), ) }; if status != SQLITE_OK { + let mut e = handle.expect_error(); + // SAFETY: We become responsible for any memory allocation at `&error`, so test // for null and take an RAII version for returns - let err_msg = if !error.is_null() { - unsafe { - let e = CStr::from_ptr(error).into(); - sqlite3_free(error as *mut c_void); - e - } - } else { - CString::new("Unknown error when loading extension") - .expect("text should be representable as a CString") - }; - return Err(Error::Database(Box::new(SqliteError::extension( - handle.as_ptr(), - &err_msg, - )))); + if !error_msg.is_null() { + e = e.with_message(unsafe { + let msg = CStr::from_ptr(error_msg).to_string_lossy().into(); + sqlite3_free(error_msg as *mut c_void); + msg + }); + } + return Err(Error::Database(Box::new(e))); } } // Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION // on by disabling the flag again once we've loaded all the requested modules. @@ -271,7 +267,7 @@ impl EstablishParams { // configure a `regexp` function for sqlite, it does not come with one by default let status = crate::regexp::register(handle.as_ptr()); if status != SQLITE_OK { - return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); + return Err(Error::Database(Box::new(handle.expect_error()))); } } @@ -286,13 +282,12 @@ impl EstablishParams { status = unsafe { sqlite3_busy_timeout(handle.as_ptr(), ms) }; if status != SQLITE_OK { - return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); + return Err(Error::Database(Box::new(handle.expect_error()))); } Ok(ConnectionState { handle, statements: Statements::new(self.statement_cache_capacity), - transaction_depth: 0, log_settings: self.log_settings.clone(), progress_handler_callback: None, update_hook_callback: None, diff --git a/sqlx-sqlite/src/connection/handle.rs b/sqlx-sqlite/src/connection/handle.rs index aaf5b74ea..60fbe17dc 100644 --- a/sqlx-sqlite/src/connection/handle.rs +++ b/sqlx-sqlite/src/connection/handle.rs @@ -46,6 +46,17 @@ impl ConnectionHandle { unsafe { sqlite3_last_insert_rowid(self.as_ptr()) } } + pub(crate) fn last_error(&mut self) -> Option { + // SAFETY: we have exclusive access to the database handle + unsafe { SqliteError::try_new(self.as_ptr()) } + } + + #[track_caller] + pub(crate) fn expect_error(&mut self) -> SqliteError { + self.last_error() + .expect("expected error code to be set in current context") + } + pub(crate) fn exec(&mut self, query: impl Into) -> Result<(), Error> { let query = query.into(); let query = CString::new(query).map_err(|_| err_protocol!("query contains nul bytes"))?; diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 7412eef12..b94ad91c4 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::cmp::Ordering; use std::ffi::CStr; use std::fmt::Write; @@ -11,8 +12,8 @@ use futures_core::future::BoxFuture; use futures_intrusive::sync::MutexGuard; use futures_util::future; use libsqlite3_sys::{ - sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, - sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, + sqlite3, sqlite3_commit_hook, sqlite3_get_autocommit, sqlite3_progress_handler, + sqlite3_rollback_hook, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, }; #[cfg(feature = "preupdate-hook")] pub use preupdate_hook::*; @@ -40,6 +41,7 @@ mod handle; pub(crate) mod intmap; #[cfg(feature = "preupdate-hook")] mod preupdate_hook; +pub(crate) mod serialize; mod worker; @@ -105,9 +107,6 @@ unsafe impl Send for RollbackHookHandler {} pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, - // transaction status - pub(crate) transaction_depth: usize, - pub(crate) statements: Statements, log_settings: LogSettings, @@ -252,14 +251,21 @@ impl Connection for SqliteConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { - self.worker - .shared - .cached_statements_size - .load(std::sync::atomic::Ordering::Acquire) + self.worker.shared.get_cached_statements_size() } fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { @@ -544,7 +550,12 @@ impl LockedSqliteHandle<'_> { } pub fn last_error(&mut self) -> Option { - SqliteError::try_new(self.guard.handle.as_ptr()) + self.guard.handle.last_error() + } + + pub(crate) fn in_transaction(&mut self) -> bool { + let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; + ret == 0 } } diff --git a/sqlx-sqlite/src/connection/serialize.rs b/sqlx-sqlite/src/connection/serialize.rs new file mode 100644 index 000000000..c8835093d --- /dev/null +++ b/sqlx-sqlite/src/connection/serialize.rs @@ -0,0 +1,297 @@ +use super::ConnectionState; +use crate::{error::Error, SqliteConnection, SqliteError}; +use libsqlite3_sys::{ + sqlite3_deserialize, sqlite3_free, sqlite3_malloc64, sqlite3_serialize, + SQLITE_DESERIALIZE_FREEONCLOSE, SQLITE_DESERIALIZE_READONLY, SQLITE_DESERIALIZE_RESIZEABLE, + SQLITE_NOMEM, SQLITE_OK, +}; +use std::ffi::c_char; +use std::fmt::Debug; +use std::{ + ops::{Deref, DerefMut}, + ptr, + ptr::NonNull, +}; + +impl SqliteConnection { + /// Serialize the given SQLite database schema using [`sqlite3_serialize()`]. + /// + /// The returned buffer is a SQLite managed allocation containing the equivalent data + /// as writing the database to disk. It is freed on-drop. + /// + /// To serialize the primary, unqualified schema (`main`), pass `None` for the schema name. + /// + /// # Errors + /// * [`Error::InvalidArgument`] if the schema name contains a zero/NUL byte (`\0`). + /// * [`Error::Database`] if the schema does not exist or another error occurs. + /// + /// [`sqlite3_serialize()`]: https://sqlite.org/c3ref/serialize.html + pub async fn serialize(&mut self, schema: Option<&str>) -> Result { + let schema = schema.map(SchemaName::try_from).transpose()?; + + self.worker.serialize(schema).await + } + + /// Deserialize a SQLite database from a buffer into the specified schema using [`sqlite3_deserialize()`]. + /// + /// The given schema will be disconnected and re-connected as an in-memory database + /// backed by `data`, which should be the serialized form of a database previously returned + /// by a call to [`Self::serialize()`], documented as being equivalent to + /// the contents of the database file on disk. + /// + /// An error will be returned if a schema with the given name is not already attached. + /// You can use `ATTACH ':memory' as ""` to create an empty schema first. + /// + /// Pass `None` to deserialize to the primary, unqualified schema (`main`). + /// + /// The SQLite connection will take ownership of `data` and will free it when the connection + /// is closed or the schema is detached ([`SQLITE_DESERIALIZE_FREEONCLOSE`][deserialize-flags]). + /// + /// If `read_only` is `true`, the schema is opened as read-only ([`SQLITE_DESERIALIZE_READONLY`][deserialize-flags]). + /// If `false`, the schema is marked as resizable ([`SQLITE_DESERIALIZE_RESIZABLE`][deserialize-flags]). + /// + /// If the database is in WAL mode, an error is returned. + /// See [`sqlite3_deserialize()`] for details. + /// + /// # Errors + /// * [`Error::InvalidArgument`] if the schema name contains a zero/NUL byte (`\0`). + /// * [`Error::Database`] if an error occurs during deserialization. + /// + /// [`sqlite3_deserialize()`]: https://sqlite.org/c3ref/deserialize.html + /// [deserialize-flags]: https://sqlite.org/c3ref/c_deserialize_freeonclose.html + pub async fn deserialize( + &mut self, + schema: Option<&str>, + data: SqliteOwnedBuf, + read_only: bool, + ) -> Result<(), Error> { + let schema = schema.map(SchemaName::try_from).transpose()?; + + self.worker.deserialize(schema, data, read_only).await + } +} + +pub(crate) fn serialize( + conn: &mut ConnectionState, + schema: Option, +) -> Result { + let mut size = 0; + + let buf = unsafe { + let ptr = sqlite3_serialize( + conn.handle.as_ptr(), + schema.as_ref().map_or(ptr::null(), SchemaName::as_ptr), + &mut size, + 0, + ); + + // looking at the source, `sqlite3_serialize` actually sets `size = -1` on error: + // https://github.com/sqlite/sqlite/blob/da5f81387843f92652128087a8f8ecef0b79461d/src/memdb.c#L776 + usize::try_from(size) + .ok() + .and_then(|size| SqliteOwnedBuf::from_raw(ptr, size)) + }; + + if let Some(buf) = buf { + return Ok(buf); + } + + if let Some(error) = conn.handle.last_error() { + return Err(error.into()); + } + + if size > 0 { + // If `size` is positive but `sqlite3_serialize` still returned NULL, + // the most likely culprit is an out-of-memory condition. + return Err(SqliteError::from_code(SQLITE_NOMEM).into()); + } + + // Otherwise, the schema was probably not found. + // We return the equivalent error as when you try to execute `PRAGMA .page_count` + // against a non-existent schema. + Err(SqliteError::generic(format!( + "database {} does not exist", + schema.as_ref().map_or("main", SchemaName::as_str) + )) + .into()) +} + +pub(crate) fn deserialize( + conn: &mut ConnectionState, + schema: Option, + data: SqliteOwnedBuf, + read_only: bool, +) -> Result<(), Error> { + // SQLITE_DESERIALIZE_FREEONCLOSE causes SQLite to take ownership of the buffer + let mut flags = SQLITE_DESERIALIZE_FREEONCLOSE; + if read_only { + flags |= SQLITE_DESERIALIZE_READONLY; + } else { + flags |= SQLITE_DESERIALIZE_RESIZEABLE; + } + + let (buf, size) = data.into_raw(); + + let rc = unsafe { + sqlite3_deserialize( + conn.handle.as_ptr(), + schema.as_ref().map_or(ptr::null(), SchemaName::as_ptr), + buf, + i64::try_from(size).unwrap(), + i64::try_from(size).unwrap(), + flags, + ) + }; + + match rc { + SQLITE_OK => Ok(()), + SQLITE_NOMEM => Err(SqliteError::from_code(SQLITE_NOMEM).into()), + // SQLite unfortunately doesn't set any specific message for deserialization errors. + _ => Err(SqliteError::generic("an error occurred during deserialization").into()), + } +} + +/// Memory buffer owned and allocated by SQLite. Freed on drop. +/// +/// Intended primarily for use with [`SqliteConnection::serialize()`] and [`SqliteConnection::deserialize()`]. +/// +/// Can be created from `&[u8]` using the `TryFrom` impl. The slice must not be empty. +#[derive(Debug)] +pub struct SqliteOwnedBuf { + ptr: NonNull, + size: usize, +} + +unsafe impl Send for SqliteOwnedBuf {} +unsafe impl Sync for SqliteOwnedBuf {} + +impl Drop for SqliteOwnedBuf { + fn drop(&mut self) { + unsafe { + sqlite3_free(self.ptr.as_ptr().cast()); + } + } +} + +impl SqliteOwnedBuf { + /// Uses `sqlite3_malloc` to allocate a buffer and returns a pointer to it. + /// + /// # Safety + /// The allocated buffer is uninitialized. + unsafe fn with_capacity(size: usize) -> Option { + let ptr = sqlite3_malloc64(u64::try_from(size).unwrap()).cast::(); + Self::from_raw(ptr, size) + } + + /// Creates a new mem buffer from a pointer that has been created with sqlite_malloc + /// + /// # Safety: + /// * The pointer must point to a valid allocation created by `sqlite3_malloc()`, or `NULL`. + unsafe fn from_raw(ptr: *mut u8, size: usize) -> Option { + Some(Self { + ptr: NonNull::new(ptr)?, + size, + }) + } + + fn into_raw(self) -> (*mut u8, usize) { + let raw = (self.ptr.as_ptr(), self.size); + // this is used in sqlite_deserialize and + // underlying buffer must not be freed + std::mem::forget(self); + raw + } +} + +/// # Errors +/// Returns [`Error::InvalidArgument`] if the slice is empty. +impl TryFrom<&[u8]> for SqliteOwnedBuf { + type Error = Error; + + fn try_from(bytes: &[u8]) -> Result { + unsafe { + // SAFETY: `buf` is not initialized until `ptr::copy_nonoverlapping` completes. + let mut buf = Self::with_capacity(bytes.len()).ok_or_else(|| { + Error::InvalidArgument("SQLite owned buffer cannot be empty".to_string()) + })?; + ptr::copy_nonoverlapping(bytes.as_ptr(), buf.ptr.as_mut(), buf.size); + Ok(buf) + } + } +} + +impl Deref for SqliteOwnedBuf { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) } + } +} + +impl DerefMut for SqliteOwnedBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { std::slice::from_raw_parts_mut(self.ptr.as_mut(), self.size) } + } +} + +impl AsRef<[u8]> for SqliteOwnedBuf { + fn as_ref(&self) -> &[u8] { + self.deref() + } +} + +impl AsMut<[u8]> for SqliteOwnedBuf { + fn as_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} + +/// Checked schema name to pass to SQLite. +/// +/// # Safety: +/// * Valid UTF-8 (not guaranteed by `CString`) +/// * No internal zero bytes (`\0`) (not guaranteed by `String`) +/// * Terminated with a zero byte (`\0`) (not guaranteed by `String`) +#[derive(Debug)] +pub(crate) struct SchemaName(Box); + +impl SchemaName { + /// Get the schema name as a string without the zero byte terminator. + pub fn as_str(&self) -> &str { + &self.0[..self.0.len() - 1] + } + + /// Get a pointer to the string data, suitable for passing as C's `*const char`. + /// + /// # Safety + /// The string data is guaranteed to be terminated with a zero byte. + pub fn as_ptr(&self) -> *const c_char { + self.0.as_ptr() as *const c_char + } +} + +impl<'a> TryFrom<&'a str> for SchemaName { + type Error = Error; + + fn try_from(name: &'a str) -> Result { + // SAFETY: we must ensure that the string does not contain an internal NULL byte + if let Some(pos) = name.as_bytes().iter().position(|&b| b == 0) { + return Err(Error::InvalidArgument(format!( + "schema name {name:?} contains a zero byte at index {pos}" + ))); + } + + let capacity = name.len().checked_add(1).unwrap(); + + let mut s = String::new(); + // `String::with_capacity()` does not guarantee that it will not overallocate, + // which might mean an unnecessary reallocation to make `capacity == len` + // in the conversion to `Box`. + s.reserve_exact(capacity); + + s.push_str(name); + s.push('\0'); + + Ok(SchemaName(s.into())) + } +} diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index c1c67636f..00a4c2999 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -21,6 +21,8 @@ use crate::connection::execute; use crate::connection::ConnectionState; use crate::{Sqlite, SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement}; +use super::serialize::{deserialize, serialize, SchemaName, SqliteOwnedBuf}; + // Each SQLite connection has a dedicated thread. // TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce @@ -34,10 +36,21 @@ pub(crate) struct ConnectionWorker { } pub(crate) struct WorkerSharedState { - pub(crate) cached_statements_size: AtomicUsize, + transaction_depth: AtomicUsize, + cached_statements_size: AtomicUsize, pub(crate) conn: Mutex, } +impl WorkerSharedState { + pub(crate) fn get_transaction_depth(&self) -> usize { + self.transaction_depth.load(Ordering::Acquire) + } + + pub(crate) fn get_cached_statements_size(&self) -> usize { + self.cached_statements_size.load(Ordering::Acquire) + } +} + enum Command { Prepare { query: Box, @@ -54,8 +67,19 @@ enum Command { tx: flume::Sender, Error>>, limit: Option, }, + Serialize { + schema: Option, + tx: oneshot::Sender>, + }, + Deserialize { + schema: Option, + data: SqliteOwnedBuf, + read_only: bool, + tx: oneshot::Sender>, + }, Begin { tx: rendezvous_oneshot::Sender>, + statement: Option>, }, Commit { tx: rendezvous_oneshot::Sender>, @@ -93,6 +117,7 @@ impl ConnectionWorker { }; let shared = Arc::new(WorkerSharedState { + transaction_depth: AtomicUsize::new(0), cached_statements_size: AtomicUsize::new(0), // note: must be fair because in `Command::UnlockDb` we unlock the mutex // and then immediately try to relock it; an unfair mutex would immediately @@ -182,13 +207,27 @@ impl ConnectionWorker { update_cached_statements_size(&conn, &shared.cached_statements_size); } - Command::Begin { tx } => { - let depth = conn.transaction_depth; + Command::Begin { tx, statement } => { + let depth = shared.transaction_depth.load(Ordering::Acquire); + + let statement = match statement { + // custom `BEGIN` statements are not allowed if + // we're already in a transaction (we need to + // issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => { + if tx.blocking_send(Err(Error::InvalidSavePointStatement)).is_err() { + break; + } + continue; + }, + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; let res = conn.handle - .exec(begin_ansi_transaction_sql(depth)) + .exec(statement) .map(|_| { - conn.transaction_depth += 1; + shared.transaction_depth.fetch_add(1, Ordering::Release); }); let res_ok = res.is_ok(); @@ -201,7 +240,7 @@ impl ConnectionWorker { .handle .exec(rollback_ansi_transaction_sql(depth + 1)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) { // The rollback failed. To prevent leaving the connection @@ -213,13 +252,13 @@ impl ConnectionWorker { } } Command::Commit { tx } => { - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = if depth > 0 { conn.handle .exec(commit_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) } else { Ok(()) @@ -239,13 +278,13 @@ impl ConnectionWorker { continue; } - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = if depth > 0 { conn.handle .exec(rollback_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) } else { Ok(()) @@ -263,6 +302,12 @@ impl ConnectionWorker { } } } + Command::Serialize { schema, tx } => { + tx.send(serialize(&mut conn, schema)).ok(); + } + Command::Deserialize { schema, data, read_only, tx } => { + tx.send(deserialize(&mut conn, schema, data, read_only)).ok(); + } Command::ClearCache { tx } => { conn.statements.clear(); update_cached_statements_size(&conn, &shared.cached_statements_size); @@ -333,8 +378,11 @@ impl ConnectionWorker { Ok(rx) } - pub(crate) async fn begin(&mut self) -> Result<(), Error> { - self.oneshot_cmd_with_ack(|tx| Command::Begin { tx }) + pub(crate) async fn begin( + &mut self, + statement: Option>, + ) -> Result<(), Error> { + self.oneshot_cmd_with_ack(|tx| Command::Begin { tx, statement }) .await? } @@ -358,6 +406,29 @@ impl ConnectionWorker { self.oneshot_cmd(|tx| Command::Ping { tx }).await } + pub(crate) async fn deserialize( + &mut self, + schema: Option, + data: SqliteOwnedBuf, + read_only: bool, + ) -> Result<(), Error> { + self.oneshot_cmd(|tx| Command::Deserialize { + schema, + data, + read_only, + tx, + }) + .await? + } + + pub(crate) async fn serialize( + &mut self, + schema: Option, + ) -> Result { + self.oneshot_cmd(|tx| Command::Serialize { schema, tx }) + .await? + } + async fn oneshot_cmd(&mut self, command: F) -> Result where F: FnOnce(oneshot::Sender) -> Command, diff --git a/sqlx-sqlite/src/error.rs b/sqlx-sqlite/src/error.rs index 0d34bc102..eee2e8b1a 100644 --- a/sqlx-sqlite/src/error.rs +++ b/sqlx-sqlite/src/error.rs @@ -2,12 +2,12 @@ use std::error::Error as StdError; use std::ffi::CStr; use std::fmt::{self, Display, Formatter}; use std::os::raw::c_int; -use std::{borrow::Cow, str::from_utf8_unchecked}; +use std::{borrow::Cow, str}; use libsqlite3_sys::{ - sqlite3, sqlite3_errmsg, sqlite3_extended_errcode, SQLITE_CONSTRAINT_CHECK, + sqlite3, sqlite3_errmsg, sqlite3_errstr, sqlite3_extended_errcode, SQLITE_CONSTRAINT_CHECK, SQLITE_CONSTRAINT_FOREIGNKEY, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY, - SQLITE_CONSTRAINT_UNIQUE, + SQLITE_CONSTRAINT_UNIQUE, SQLITE_ERROR, }; pub(crate) use sqlx_core::error::*; @@ -18,15 +18,15 @@ pub(crate) use sqlx_core::error::*; #[derive(Debug)] pub struct SqliteError { code: c_int, - message: String, + message: Cow<'static, str>, } impl SqliteError { - pub(crate) fn new(handle: *mut sqlite3) -> Self { + pub(crate) unsafe fn new(handle: *mut sqlite3) -> Self { Self::try_new(handle).expect("There should be an error") } - pub(crate) fn try_new(handle: *mut sqlite3) -> Option { + pub(crate) unsafe fn try_new(handle: *mut sqlite3) -> Option { // returns the extended result code even when extended result codes are disabled let code: c_int = unsafe { sqlite3_extended_errcode(handle) }; @@ -39,20 +39,44 @@ impl SqliteError { let msg = sqlite3_errmsg(handle); debug_assert!(!msg.is_null()); - from_utf8_unchecked(CStr::from_ptr(msg).to_bytes()) + str::from_utf8_unchecked(CStr::from_ptr(msg).to_bytes()).to_owned() }; Some(Self { code, - message: message.to_owned(), + message: message.into(), }) } /// For errors during extension load, the error message is supplied via a separate pointer - pub(crate) fn extension(handle: *mut sqlite3, error_msg: &CStr) -> Self { - let mut err = Self::new(handle); - err.message = unsafe { from_utf8_unchecked(error_msg.to_bytes()).to_owned() }; - err + pub(crate) fn with_message(mut self, error_msg: String) -> Self { + self.message = error_msg.into(); + self + } + + pub(crate) fn from_code(code: c_int) -> Self { + let message = unsafe { + let errstr = sqlite3_errstr(code); + + if !errstr.is_null() { + // SAFETY: `errstr` is guaranteed to be UTF-8 + // The lifetime of the string is "internally managed"; + // the implementation just selects from an array of static strings. + // We copy to an owned buffer in case `libsqlite3` is dynamically loaded somehow. + Cow::Owned(str::from_utf8_unchecked(CStr::from_ptr(errstr).to_bytes()).into()) + } else { + Cow::Borrowed("") + } + }; + + SqliteError { code, message } + } + + pub(crate) fn generic(message: impl Into>) -> Self { + Self { + code: SQLITE_ERROR, + message: message.into(), + } } } diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index f1a45c3d3..e4a122b6b 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -46,6 +46,7 @@ use std::sync::atomic::AtomicBool; pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; +pub use connection::serialize::SqliteOwnedBuf; #[cfg(feature = "preupdate-hook")] pub use connection::PreupdateHookResult; pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; diff --git a/sqlx-sqlite/src/statement/handle.rs b/sqlx-sqlite/src/statement/handle.rs index e6962ae54..e3a757868 100644 --- a/sqlx-sqlite/src/statement/handle.rs +++ b/sqlx-sqlite/src/statement/handle.rs @@ -84,8 +84,8 @@ impl StatementHandle { } #[inline] - pub(crate) fn last_error(&self) -> SqliteError { - SqliteError::new(unsafe { self.db_handle() }) + pub(crate) fn last_error(&mut self) -> SqliteError { + unsafe { SqliteError::new(self.db_handle()) } } #[inline] diff --git a/sqlx-sqlite/src/statement/virtual.rs b/sqlx-sqlite/src/statement/virtual.rs index 345af307a..b25aa69e4 100644 --- a/sqlx-sqlite/src/statement/virtual.rs +++ b/sqlx-sqlite/src/statement/virtual.rs @@ -185,7 +185,7 @@ fn prepare( }; if status != SQLITE_OK { - return Err(SqliteError::new(conn).into()); + return Err(unsafe { SqliteError::new(conn).into() }); } // tail should point to the first byte past the end of the first SQL diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index 24eaca51b..55a80ab9f 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,17 +1,33 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; -use crate::{Sqlite, SqliteConnection}; use sqlx_core::error::Error; use sqlx_core::transaction::TransactionManager; +use crate::{Sqlite, SqliteConnection}; + /// Implementation of [`TransactionManager`] for SQLite. pub struct SqliteTransactionManager; impl TransactionManager for SqliteTransactionManager { type Database = Sqlite; - fn begin(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(conn.worker.begin()) + fn begin<'conn>( + conn: &'conn mut SqliteConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { + Box::pin(async { + let is_custom_statement = statement.is_some(); + conn.worker.begin(statement).await?; + if is_custom_statement { + // Check that custom statement actually put the connection into a transaction. + let mut handle = conn.lock_handle().await?; + if !handle.in_transaction() { + return Err(Error::BeginFailed); + } + } + Ok(()) + }) } fn commit(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { @@ -25,4 +41,8 @@ impl TransactionManager for SqliteTransactionManager { fn start_rollback(conn: &mut SqliteConnection) { conn.worker.start_rollback().ok(); } + + fn get_transaction_depth(conn: &SqliteConnection) -> usize { + conn.worker.shared.get_transaction_depth() + } } diff --git a/tests/mysql/error.rs b/tests/mysql/error.rs index 7c84266c3..3ee1024fc 100644 --- a/tests/mysql/error.rs +++ b/tests/mysql/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, mysql::MySql, Connection}; +use sqlx::{error::ErrorKind, mysql::MySql, Connection, Error}; use sqlx_test::new; #[sqlx_macros::test] @@ -74,3 +74,29 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/postgres/error.rs b/tests/postgres/error.rs index d6f78140d..32bf81477 100644 --- a/tests/postgres/error.rs +++ b/tests/postgres/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, postgres::Postgres, Connection}; +use sqlx::{error::ErrorKind, postgres::Postgres, Connection, Error}; use sqlx_test::new; #[sqlx_macros::test] @@ -74,3 +74,29 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 7de4a9cdc..fc7108bf4 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -515,6 +515,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { #[sqlx_macros::test] async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { let mut conn = new::().await?; + assert!(!conn.is_in_transaction()); conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_2523 (id INTEGER PRIMARY KEY)") .await?; @@ -523,6 +524,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin let mut tx = conn.begin().await?; // transaction + assert!(tx.is_in_transaction()); // insert a user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -532,6 +534,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin once more let mut tx2 = tx.begin().await?; // savepoint + assert!(tx2.is_in_transaction()); // insert another user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -541,6 +544,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // never mind, rollback tx2.rollback().await?; // roll that one back + assert!(tx.is_in_transaction()); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") @@ -551,6 +555,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // actually, commit tx.commit().await?; + assert!(!conn.is_in_transaction()); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index ccf88b109..d5d34bc1b 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -2,6 +2,7 @@ extern crate time_ as time; use std::net::SocketAddr; use std::ops::Bound; +use std::str::FromStr; use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange}; use sqlx::postgres::Postgres; @@ -9,7 +10,6 @@ use sqlx_test::{new, test_decode_type, test_prepared_type, test_type}; use sqlx_core::executor::Executor; use sqlx_core::types::Text; -use std::str::FromStr; test_type!(null>(Postgres, "NULL::int2" == None:: @@ -171,6 +171,38 @@ test_type!(uuid_vec>(Postgres, ] )); +#[cfg(feature = "ipnet")] +test_type!(ipnet(Postgres, + "'127.0.0.1'::inet" + == "127.0.0.1/32" + .parse::() + .unwrap(), + "'8.8.8.8/24'::inet" + == "8.8.8.8/24" + .parse::() + .unwrap(), + "'10.1.1/24'::inet" + == "10.1.1.0/24" + .parse::() + .unwrap(), + "'::ffff:1.2.3.0'::inet" + == "::ffff:1.2.3.0/128" + .parse::() + .unwrap(), + "'2001:4f8:3:ba::/64'::inet" + == "2001:4f8:3:ba::/64" + .parse::() + .unwrap(), + "'192.168'::cidr" + == "192.168.0.0/24" + .parse::() + .unwrap(), + "'::ffff:1.2.3.0/120'::cidr" + == "::ffff:1.2.3.0/120" + .parse::() + .unwrap(), +)); + #[cfg(feature = "ipnetwork")] test_type!(ipnetwork(Postgres, "'127.0.0.1'::inet" @@ -232,6 +264,15 @@ test_type!(bitvec( }, )); +#[cfg(feature = "ipnet")] +test_type!(ipnet_vec>(Postgres, + "'{127.0.0.1,8.8.8.8/24}'::inet[]" + == vec![ + "127.0.0.1/32".parse::().unwrap(), + "8.8.8.8/24".parse::().unwrap() + ] +)); + #[cfg(feature = "ipnetwork")] test_type!(ipnetwork_vec>(Postgres, "'{127.0.0.1,8.8.8.8/24}'::inet[]" @@ -524,6 +565,29 @@ test_type!(_box>(Postgres, "array[box('1,2,3,4'),box('((1.1, 2.2), (3.3, 4.4))')]" @= vec![sqlx::postgres::types::PgBox { upper_right_x: 3., upper_right_y: 4., lower_left_x: 1., lower_left_y: 2. }, sqlx::postgres::types::PgBox { upper_right_x: 3.3, upper_right_y: 4.4, lower_left_x: 1.1, lower_left_y: 2.2 }], )); +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(path(Postgres, + "path('((1.0, 2.0), (3.0,4.0))')" == sqlx::postgres::types::PgPath { closed: true, points: vec![ sqlx::postgres::types::PgPoint { x: 1., y: 2. }, sqlx::postgres::types::PgPoint { x: 3. , y: 4. } ]}, + "path('[(1.0, 2.0), (3.0,4.0)]')" == sqlx::postgres::types::PgPath { closed: false, points: vec![ sqlx::postgres::types::PgPoint { x: 1., y: 2. }, sqlx::postgres::types::PgPoint { x: 3. , y: 4. } ]}, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(polygon(Postgres, + "polygon('((-2,-3),(-1,-3),(-1,-1),(1,1),(1,3),(2,3),(2,-3),(1,-3),(1,0),(-1,0),(-1,-2),(-2,-2))')" ~= sqlx::postgres::types::PgPolygon { points: vec![ + sqlx::postgres::types::PgPoint { x: -2., y: -3. }, sqlx::postgres::types::PgPoint { x: -1., y: -3. }, sqlx::postgres::types::PgPoint { x: -1., y: -1. }, sqlx::postgres::types::PgPoint { x: 1., y: 1. }, + sqlx::postgres::types::PgPoint { x: 1., y: 3. }, sqlx::postgres::types::PgPoint { x: 2., y: 3. }, sqlx::postgres::types::PgPoint { x: 2., y: -3. }, sqlx::postgres::types::PgPoint { x: 1., y: -3. }, + sqlx::postgres::types::PgPoint { x: 1., y: 0. }, sqlx::postgres::types::PgPoint { x: -1., y: 0. }, sqlx::postgres::types::PgPoint { x: -1., y: -2. }, sqlx::postgres::types::PgPoint { x: -2., y: -2. }, + ]}, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(circle(Postgres, + "circle('<(1.1, -2.2), 3.3>')" ~= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, + "circle('((1.1, -2.2), 3.3)')" ~= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, + "circle('(1.1, -2.2), 3.3')" ~= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, + "circle('1.1, -2.2, 3.3')" ~= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, +)); + #[cfg(feature = "rust_decimal")] test_type!(decimal(Postgres, "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(), diff --git a/tests/sqlite/error.rs b/tests/sqlite/error.rs index 1f6b797e6..8729842b7 100644 --- a/tests/sqlite/error.rs +++ b/tests/sqlite/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, sqlite::Sqlite, Connection, Executor}; +use sqlx::{error::ErrorKind, sqlite::Sqlite, Connection, Error, Executor}; use sqlx_test::new; #[sqlx_macros::test] @@ -70,3 +70,29 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 16b4b2d9f..c23c4fc9e 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -2,12 +2,11 @@ use futures::TryStreamExt; use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions}; -use sqlx::Decode; use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, }; -use sqlx::{Value, ValueRef}; +use sqlx_sqlite::LockedSqliteHandle; use sqlx_test::new; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -271,7 +270,7 @@ async fn it_handles_empty_queries() -> anyhow::Result<()> { } #[sqlx_macros::test] -fn it_binds_parameters() -> anyhow::Result<()> { +async fn it_binds_parameters() -> anyhow::Result<()> { let mut conn = new::().await?; let v: i32 = sqlx::query_scalar("SELECT ?") @@ -293,7 +292,7 @@ fn it_binds_parameters() -> anyhow::Result<()> { } #[sqlx_macros::test] -fn it_binds_dollar_parameters() -> anyhow::Result<()> { +async fn it_binds_dollar_parameters() -> anyhow::Result<()> { let mut conn = new::().await?; let v: (i32, i32) = sqlx::query_as("SELECT $1, $2") @@ -973,6 +972,8 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res #[cfg(feature = "sqlite-preupdate-hook")] #[sqlx_macros::test] async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { + use sqlx::Decode; + let mut conn = new::().await?; static CALLED: AtomicBool = AtomicBool::new(false); // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. @@ -1021,6 +1022,8 @@ async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { #[cfg(feature = "sqlite-preupdate-hook")] #[sqlx_macros::test] async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { + use sqlx::Decode; + let mut conn = new::().await?; let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 5, 'Hello, World' )") .execute(&mut conn) @@ -1064,6 +1067,9 @@ async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { #[cfg(feature = "sqlite-preupdate-hook")] #[sqlx_macros::test] async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> { + use sqlx::Decode; + use sqlx::{Value, ValueRef}; + let mut conn = new::().await?; let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 6, 'Hello, World' )") .execute(&mut conn) @@ -1193,3 +1199,171 @@ async fn test_get_last_error() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn test_serialize_deserialize() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + sqlx::raw_sql("create table foo(bar integer not null, baz text not null)") + .execute(&mut conn) + .await?; + + sqlx::query("insert into foo(bar, baz) values (1234, 'Lorem ipsum'), (5678, 'dolor sit amet')") + .execute(&mut conn) + .await?; + + let serialized = conn.serialize(None).await?; + + // Close and open a new connection to ensure cleanliness. + conn.close().await?; + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + conn.deserialize(None, serialized, false).await?; + + let rows = sqlx::query_as::<_, (i32, String)>("select bar, baz from foo") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows.len(), 2); + + assert_eq!(rows[0].0, 1234); + assert_eq!(rows[0].1, "Lorem ipsum"); + + assert_eq!(rows[1].0, 5678); + assert_eq!(rows[1].1, "dolor sit amet"); + + Ok(()) +} +#[sqlx_macros::test] +async fn test_serialize_deserialize_with_schema() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + sqlx::raw_sql( + "attach ':memory:' as foo; create table foo.foo(bar integer not null, baz text not null)", + ) + .execute(&mut conn) + .await?; + + sqlx::query( + "insert into foo.foo(bar, baz) values (1234, 'Lorem ipsum'), (5678, 'dolor sit amet')", + ) + .execute(&mut conn) + .await?; + + let serialized = conn.serialize(Some("foo")).await?; + + // Close and open a new connection to ensure cleanliness. + conn.close().await?; + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + // Unexpected quirk: the schema must exist before deserialization. + sqlx::raw_sql("attach ':memory:' as foo") + .execute(&mut conn) + .await?; + + conn.deserialize(Some("foo"), serialized, false).await?; + + let rows = sqlx::query_as::<_, (i32, String)>("select bar, baz from foo.foo") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows.len(), 2); + + assert_eq!(rows[0].0, 1234); + assert_eq!(rows[0].1, "Lorem ipsum"); + + assert_eq!(rows[1].0, 5678); + assert_eq!(rows[1].1, "dolor sit amet"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_serialize_nonexistent_schema() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + let err = conn + .serialize(Some("foobar")) + .await + .expect_err("an error should have been returned"); + + let sqlx::Error::Database(dbe) = err else { + panic!("expected DatabaseError: {err:?}") + }; + + assert_eq!(dbe.code().as_deref(), Some("1")); + assert_eq!(dbe.message(), "database foobar does not exist"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_serialize_invalid_schema() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + let err = conn + .serialize(Some("foo\0bar")) + .await + .expect_err("an error should have been returned"); + + let sqlx::Error::InvalidArgument(msg) = err else { + panic!("expected InvalidArgument: {err:?}") + }; + + assert_eq!( + msg, + "schema name \"foo\\0bar\" contains a zero byte at index 3" + ); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_use_transaction_options() -> anyhow::Result<()> { + async fn check_txn_state(conn: &mut SqliteConnection, expected: SqliteTransactionState) { + let state = transaction_state(&mut conn.lock_handle().await.unwrap()); + assert_eq!(state, expected); + } + + let mut conn = SqliteConnectOptions::new() + .in_memory(true) + .connect() + .await + .unwrap(); + + check_txn_state(&mut conn, SqliteTransactionState::None).await; + + let mut tx = conn.begin_with("BEGIN DEFERRED").await?; + check_txn_state(&mut tx, SqliteTransactionState::None).await; + drop(tx); + + let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await; + drop(tx); + + let mut tx = conn.begin_with("BEGIN EXCLUSIVE").await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await; + drop(tx); + + Ok(()) +} + +fn transaction_state(handle: &mut LockedSqliteHandle) -> SqliteTransactionState { + use libsqlite3_sys::{sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE}; + + let unchecked_state = + unsafe { sqlite3_txn_state(handle.as_raw_handle().as_ptr(), std::ptr::null()) }; + match unchecked_state { + SQLITE_TXN_NONE => SqliteTransactionState::None, + SQLITE_TXN_READ => SqliteTransactionState::Read, + SQLITE_TXN_WRITE => SqliteTransactionState::Write, + _ => panic!("unknown txn state: {unchecked_state}"), + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum SqliteTransactionState { + None, + Read, + Write, +} diff --git a/tests/ui-tests.rs b/tests/ui-tests.rs index f74694b87..4a5ca240e 100644 --- a/tests/ui-tests.rs +++ b/tests/ui-tests.rs @@ -17,7 +17,7 @@ fn ui_tests() { t.compile_fail("tests/ui/postgres/gated/uuid.rs"); } - if cfg!(not(feature = "ipnetwork")) { + if cfg!(not(feature = "ipnet")) && cfg!(not(feature = "ipnetwork")) { t.compile_fail("tests/ui/postgres/gated/ipnetwork.rs"); } }