From 5d6d6985cd2274dc90501a29dcb58b54befe91a1 Mon Sep 17 00:00:00 2001 From: Jonas Malaco Date: Fri, 28 Feb 2025 21:42:53 -0300 Subject: [PATCH 01/12] docs(pool): recommend actix-web ThinData over Data to avoid two Arcs (#3762) Both actix_web::web::Data and sqlx::PgPool internally wrap an Arc. Thus, using Data as an extractor in an actix-web route handler results in two Arcs wrapping the data of interest, which isn't ideal. Actix-web 4.9.0 introduced a new web::ThinData extractor for cases like this, where the data is already wrapped in an `Arc` (or is otherwise similarly cheap and sensible to simply clone), which doesn't wrap the inner value in a (second) Arc. Since the new extractor is better suited to the task, suggest it in place of web::Data when giving an example on how to share a pool. --- sqlx-core/src/pool/mod.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 042bc5c7b..8aa9041ab 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -109,7 +109,8 @@ mod options; /// application/daemon/web server/etc. and then shared with all tasks throughout the process' /// lifetime. How best to accomplish this depends on your program architecture. /// -/// In Actix-Web, for example, you can share a single pool with all request handlers using [web::Data]. +/// In Actix-Web, for example, you can efficiently share a single pool with all request handlers +/// using [web::ThinData]. /// /// Cloning `Pool` is cheap as it is simply a reference-counted handle to the inner pool state. 
/// When the last remaining handle to the pool is dropped, the connections owned by the pool are @@ -131,7 +132,7 @@ mod options; /// * [PgPool][crate::postgres::PgPool] (PostgreSQL) /// * [SqlitePool][crate::sqlite::SqlitePool] (SQLite) /// -/// [web::Data]: https://docs.rs/actix-web/3/actix_web/web/struct.Data.html +/// [web::ThinData]: https://docs.rs/actix-web/4.9.0/actix_web/web/struct.ThinData.html /// /// ### Note: Drop Behavior /// Due to a lack of async `Drop`, dropping the last `Pool` handle may not immediately clean From c5ea6c44355292e4a03f3ec8266b085278648e44 Mon Sep 17 00:00:00 2001 From: Mattia Righetti Date: Sun, 2 Mar 2025 22:29:29 +0000 Subject: [PATCH 02/12] feat: sqlx sqlite expose de/serialize (#3745) * feat: implement serialize no copy on lockedsqlitehandle * feat: implement serialize on sqliteconnection * feat: implement deserialize on sqliteconnection and add sqlitebuf wrapper type * refactor: misc sqlite type and deserialize refactoring * chore: misc clippy refactoring * fix: misc refactoring and fixes - pass non-owned byte slice to deserialize - `SqliteBufError` and better error handling - more impl for `SqliteOwnedBuf` so it can be used as a slice - default serialize for `SqliteConnection` * refactor: move serialize and deserialize on worker thread This implements `Command::Serialize` and `Command::Deserialize` and moves the serialize and deserialize logic to the worker thread. `Serialize` will need some more iterations as it's not clear whether it would need to wait for other write transactions before running.
* refactor: misc refactoring and changes - Merged deserialize module with serialize module - Moved `SqliteOwnedBuf` into serialize module - Fixed rustdocs * chore: API tweaks, better docs, tests * fix: unused import * fix: export `SqliteOwnedBuf`, docs and safety tweaks --------- Co-authored-by: Austin Bonander --- sqlx-core/src/error.rs | 14 +- sqlx-sqlite/src/connection/collation.rs | 5 +- sqlx-sqlite/src/connection/establish.rs | 36 ++- sqlx-sqlite/src/connection/handle.rs | 11 + sqlx-sqlite/src/connection/mod.rs | 3 +- sqlx-sqlite/src/connection/serialize.rs | 297 ++++++++++++++++++++++++ sqlx-sqlite/src/connection/worker.rs | 41 ++++ sqlx-sqlite/src/error.rs | 48 +++- sqlx-sqlite/src/lib.rs | 1 + sqlx-sqlite/src/statement/handle.rs | 4 +- sqlx-sqlite/src/statement/virtual.rs | 2 +- tests/sqlite/sqlite.rs | 131 ++++++++++- 12 files changed, 549 insertions(+), 44 deletions(-) create mode 100644 sqlx-sqlite/src/connection/serialize.rs diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 17774addd..98b42fbcd 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -34,6 +34,12 @@ pub enum Error { #[error("error with configuration: {0}")] Configuration(#[source] BoxDynError), + /// One or more of the arguments to the called function was invalid. + /// + /// The string contains more information. + #[error("{0}")] + InvalidArgument(String), + /// Error returned from the database. #[error("error returned from database: {0}")] Database(#[source] Box), @@ -79,7 +85,7 @@ pub enum Error { }, /// Error occured while encoding a value. - #[error("error occured while encoding a value: {0}")] + #[error("error occurred while encoding a value: {0}")] Encode(#[source] BoxDynError), /// Error occurred while decoding a value. 
@@ -136,6 +142,12 @@ impl Error { Error::Protocol(err.to_string()) } + #[doc(hidden)] + #[inline] + pub fn database(err: impl DatabaseError) -> Self { + Error::Database(Box::new(err)) + } + #[doc(hidden)] #[inline] pub fn config(err: impl StdError + Send + Sync + 'static) -> Self { diff --git a/sqlx-sqlite/src/connection/collation.rs b/sqlx-sqlite/src/connection/collation.rs index 573a9af89..e7422138b 100644 --- a/sqlx-sqlite/src/connection/collation.rs +++ b/sqlx-sqlite/src/connection/collation.rs @@ -10,7 +10,6 @@ use libsqlite3_sys::{sqlite3_create_collation_v2, SQLITE_OK, SQLITE_UTF8}; use crate::connection::handle::ConnectionHandle; use crate::error::Error; -use crate::SqliteError; #[derive(Clone)] pub struct Collation { @@ -67,7 +66,7 @@ impl Collation { } else { // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails. drop(unsafe { Arc::from_raw(raw_f) }); - Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))) + Err(handle.expect_error().into()) } } } @@ -112,7 +111,7 @@ where } else { // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails. 
drop(unsafe { Box::from_raw(boxed_f) }); - Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))) + Err(handle.expect_error().into()) } } diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 5b8aa01b6..334b1616a 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -204,10 +204,10 @@ impl EstablishParams { // SAFE: tested for NULL just above // This allows any returns below to close this handle with RAII - let handle = unsafe { ConnectionHandle::new(handle) }; + let mut handle = unsafe { ConnectionHandle::new(handle) }; if status != SQLITE_OK { - return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); + return Err(Error::Database(Box::new(handle.expect_error()))); } // Enable extended result codes @@ -226,33 +226,29 @@ impl EstablishParams { for ext in self.extensions.iter() { // `sqlite3_load_extension` is unusual as it returns its errors via an out-pointer // rather than by calling `sqlite3_errmsg` - let mut error = null_mut(); + let mut error_msg = null_mut(); status = unsafe { sqlite3_load_extension( handle.as_ptr(), ext.0.as_ptr(), ext.1.as_ref().map_or(null(), |e| e.as_ptr()), - addr_of_mut!(error), + addr_of_mut!(error_msg), ) }; if status != SQLITE_OK { + let mut e = handle.expect_error(); + // SAFETY: We become responsible for any memory allocation at `&error`, so test // for null and take an RAII version for returns - let err_msg = if !error.is_null() { - unsafe { - let e = CStr::from_ptr(error).into(); - sqlite3_free(error as *mut c_void); - e - } - } else { - CString::new("Unknown error when loading extension") - .expect("text should be representable as a CString") - }; - return Err(Error::Database(Box::new(SqliteError::extension( - handle.as_ptr(), - &err_msg, - )))); + if !error_msg.is_null() { + e = e.with_message(unsafe { + let msg = CStr::from_ptr(error_msg).to_string_lossy().into(); + sqlite3_free(error_msg as *mut c_void); 
+ msg + }); + } + return Err(Error::Database(Box::new(e))); } } // Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION // on by disabling the flag again once we've loaded all the requested modules. @@ -271,7 +267,7 @@ impl EstablishParams { // configure a `regexp` function for sqlite, it does not come with one by default let status = crate::regexp::register(handle.as_ptr()); if status != SQLITE_OK { - return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); + return Err(Error::Database(Box::new(handle.expect_error()))); } } @@ -286,7 +282,7 @@ impl EstablishParams { status = unsafe { sqlite3_busy_timeout(handle.as_ptr(), ms) }; if status != SQLITE_OK { - return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); + return Err(Error::Database(Box::new(handle.expect_error()))); } Ok(ConnectionState { diff --git a/sqlx-sqlite/src/connection/handle.rs b/sqlx-sqlite/src/connection/handle.rs index aaf5b74ea..60fbe17dc 100644 --- a/sqlx-sqlite/src/connection/handle.rs +++ b/sqlx-sqlite/src/connection/handle.rs @@ -46,6 +46,17 @@ impl ConnectionHandle { unsafe { sqlite3_last_insert_rowid(self.as_ptr()) } } + pub(crate) fn last_error(&mut self) -> Option { + // SAFETY: we have exclusive access to the database handle + unsafe { SqliteError::try_new(self.as_ptr()) } + } + + #[track_caller] + pub(crate) fn expect_error(&mut self) -> SqliteError { + self.last_error() + .expect("expected error code to be set in current context") + } + pub(crate) fn exec(&mut self, query: impl Into) -> Result<(), Error> { let query = query.into(); let query = CString::new(query).map_err(|_| err_protocol!("query contains nul bytes"))?; diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 7412eef12..3316ad40c 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -40,6 +40,7 @@ mod handle; pub(crate) mod intmap; #[cfg(feature = "preupdate-hook")] mod preupdate_hook; 
+pub(crate) mod serialize; mod worker; @@ -544,7 +545,7 @@ impl LockedSqliteHandle<'_> { } pub fn last_error(&mut self) -> Option { - SqliteError::try_new(self.guard.handle.as_ptr()) + self.guard.handle.last_error() } } diff --git a/sqlx-sqlite/src/connection/serialize.rs b/sqlx-sqlite/src/connection/serialize.rs new file mode 100644 index 000000000..c8835093d --- /dev/null +++ b/sqlx-sqlite/src/connection/serialize.rs @@ -0,0 +1,297 @@ +use super::ConnectionState; +use crate::{error::Error, SqliteConnection, SqliteError}; +use libsqlite3_sys::{ + sqlite3_deserialize, sqlite3_free, sqlite3_malloc64, sqlite3_serialize, + SQLITE_DESERIALIZE_FREEONCLOSE, SQLITE_DESERIALIZE_READONLY, SQLITE_DESERIALIZE_RESIZEABLE, + SQLITE_NOMEM, SQLITE_OK, +}; +use std::ffi::c_char; +use std::fmt::Debug; +use std::{ + ops::{Deref, DerefMut}, + ptr, + ptr::NonNull, +}; + +impl SqliteConnection { + /// Serialize the given SQLite database schema using [`sqlite3_serialize()`]. + /// + /// The returned buffer is a SQLite managed allocation containing the equivalent data + /// as writing the database to disk. It is freed on-drop. + /// + /// To serialize the primary, unqualified schema (`main`), pass `None` for the schema name. + /// + /// # Errors + /// * [`Error::InvalidArgument`] if the schema name contains a zero/NUL byte (`\0`). + /// * [`Error::Database`] if the schema does not exist or another error occurs. + /// + /// [`sqlite3_serialize()`]: https://sqlite.org/c3ref/serialize.html + pub async fn serialize(&mut self, schema: Option<&str>) -> Result { + let schema = schema.map(SchemaName::try_from).transpose()?; + + self.worker.serialize(schema).await + } + + /// Deserialize a SQLite database from a buffer into the specified schema using [`sqlite3_deserialize()`]. 
+ /// + /// The given schema will be disconnected and re-connected as an in-memory database + /// backed by `data`, which should be the serialized form of a database previously returned + /// by a call to [`Self::serialize()`], documented as being equivalent to + /// the contents of the database file on disk. + /// + /// An error will be returned if a schema with the given name is not already attached. + /// You can use `ATTACH ':memory:' as ""` to create an empty schema first. + /// + /// Pass `None` to deserialize to the primary, unqualified schema (`main`). + /// + /// The SQLite connection will take ownership of `data` and will free it when the connection + /// is closed or the schema is detached ([`SQLITE_DESERIALIZE_FREEONCLOSE`][deserialize-flags]). + /// + /// If `read_only` is `true`, the schema is opened as read-only ([`SQLITE_DESERIALIZE_READONLY`][deserialize-flags]). + /// If `false`, the schema is marked as resizable ([`SQLITE_DESERIALIZE_RESIZEABLE`][deserialize-flags]). + /// + /// If the database is in WAL mode, an error is returned. + /// See [`sqlite3_deserialize()`] for details. + /// + /// # Errors + /// * [`Error::InvalidArgument`] if the schema name contains a zero/NUL byte (`\0`). + /// * [`Error::Database`] if an error occurs during deserialization.
+ /// + /// [`sqlite3_deserialize()`]: https://sqlite.org/c3ref/deserialize.html + /// [deserialize-flags]: https://sqlite.org/c3ref/c_deserialize_freeonclose.html + pub async fn deserialize( + &mut self, + schema: Option<&str>, + data: SqliteOwnedBuf, + read_only: bool, + ) -> Result<(), Error> { + let schema = schema.map(SchemaName::try_from).transpose()?; + + self.worker.deserialize(schema, data, read_only).await + } +} + +pub(crate) fn serialize( + conn: &mut ConnectionState, + schema: Option, +) -> Result { + let mut size = 0; + + let buf = unsafe { + let ptr = sqlite3_serialize( + conn.handle.as_ptr(), + schema.as_ref().map_or(ptr::null(), SchemaName::as_ptr), + &mut size, + 0, + ); + + // looking at the source, `sqlite3_serialize` actually sets `size = -1` on error: + // https://github.com/sqlite/sqlite/blob/da5f81387843f92652128087a8f8ecef0b79461d/src/memdb.c#L776 + usize::try_from(size) + .ok() + .and_then(|size| SqliteOwnedBuf::from_raw(ptr, size)) + }; + + if let Some(buf) = buf { + return Ok(buf); + } + + if let Some(error) = conn.handle.last_error() { + return Err(error.into()); + } + + if size > 0 { + // If `size` is positive but `sqlite3_serialize` still returned NULL, + // the most likely culprit is an out-of-memory condition. + return Err(SqliteError::from_code(SQLITE_NOMEM).into()); + } + + // Otherwise, the schema was probably not found. + // We return the equivalent error as when you try to execute `PRAGMA .page_count` + // against a non-existent schema. 
+ Err(SqliteError::generic(format!( + "database {} does not exist", + schema.as_ref().map_or("main", SchemaName::as_str) + )) + .into()) +} + +pub(crate) fn deserialize( + conn: &mut ConnectionState, + schema: Option, + data: SqliteOwnedBuf, + read_only: bool, +) -> Result<(), Error> { + // SQLITE_DESERIALIZE_FREEONCLOSE causes SQLite to take ownership of the buffer + let mut flags = SQLITE_DESERIALIZE_FREEONCLOSE; + if read_only { + flags |= SQLITE_DESERIALIZE_READONLY; + } else { + flags |= SQLITE_DESERIALIZE_RESIZEABLE; + } + + let (buf, size) = data.into_raw(); + + let rc = unsafe { + sqlite3_deserialize( + conn.handle.as_ptr(), + schema.as_ref().map_or(ptr::null(), SchemaName::as_ptr), + buf, + i64::try_from(size).unwrap(), + i64::try_from(size).unwrap(), + flags, + ) + }; + + match rc { + SQLITE_OK => Ok(()), + SQLITE_NOMEM => Err(SqliteError::from_code(SQLITE_NOMEM).into()), + // SQLite unfortunately doesn't set any specific message for deserialization errors. + _ => Err(SqliteError::generic("an error occurred during deserialization").into()), + } +} + +/// Memory buffer owned and allocated by SQLite. Freed on drop. +/// +/// Intended primarily for use with [`SqliteConnection::serialize()`] and [`SqliteConnection::deserialize()`]. +/// +/// Can be created from `&[u8]` using the `TryFrom` impl. The slice must not be empty. +#[derive(Debug)] +pub struct SqliteOwnedBuf { + ptr: NonNull, + size: usize, +} + +unsafe impl Send for SqliteOwnedBuf {} +unsafe impl Sync for SqliteOwnedBuf {} + +impl Drop for SqliteOwnedBuf { + fn drop(&mut self) { + unsafe { + sqlite3_free(self.ptr.as_ptr().cast()); + } + } +} + +impl SqliteOwnedBuf { + /// Uses `sqlite3_malloc` to allocate a buffer and returns a pointer to it. + /// + /// # Safety + /// The allocated buffer is uninitialized. 
+ unsafe fn with_capacity(size: usize) -> Option { + let ptr = sqlite3_malloc64(u64::try_from(size).unwrap()).cast::(); + Self::from_raw(ptr, size) + } + + /// Creates a new mem buffer from a pointer that has been created with sqlite_malloc + /// + /// # Safety: + /// * The pointer must point to a valid allocation created by `sqlite3_malloc()`, or `NULL`. + unsafe fn from_raw(ptr: *mut u8, size: usize) -> Option { + Some(Self { + ptr: NonNull::new(ptr)?, + size, + }) + } + + fn into_raw(self) -> (*mut u8, usize) { + let raw = (self.ptr.as_ptr(), self.size); + // this is used in sqlite_deserialize and + // underlying buffer must not be freed + std::mem::forget(self); + raw + } +} + +/// # Errors +/// Returns [`Error::InvalidArgument`] if the slice is empty. +impl TryFrom<&[u8]> for SqliteOwnedBuf { + type Error = Error; + + fn try_from(bytes: &[u8]) -> Result { + unsafe { + // SAFETY: `buf` is not initialized until `ptr::copy_nonoverlapping` completes. + let mut buf = Self::with_capacity(bytes.len()).ok_or_else(|| { + Error::InvalidArgument("SQLite owned buffer cannot be empty".to_string()) + })?; + ptr::copy_nonoverlapping(bytes.as_ptr(), buf.ptr.as_mut(), buf.size); + Ok(buf) + } + } +} + +impl Deref for SqliteOwnedBuf { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) } + } +} + +impl DerefMut for SqliteOwnedBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { std::slice::from_raw_parts_mut(self.ptr.as_mut(), self.size) } + } +} + +impl AsRef<[u8]> for SqliteOwnedBuf { + fn as_ref(&self) -> &[u8] { + self.deref() + } +} + +impl AsMut<[u8]> for SqliteOwnedBuf { + fn as_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} + +/// Checked schema name to pass to SQLite. 
+/// +/// # Safety: +/// * Valid UTF-8 (not guaranteed by `CString`) +/// * No internal zero bytes (`\0`) (not guaranteed by `String`) +/// * Terminated with a zero byte (`\0`) (not guaranteed by `String`) +#[derive(Debug)] +pub(crate) struct SchemaName(Box); + +impl SchemaName { + /// Get the schema name as a string without the zero byte terminator. + pub fn as_str(&self) -> &str { + &self.0[..self.0.len() - 1] + } + + /// Get a pointer to the string data, suitable for passing as C's `*const char`. + /// + /// # Safety + /// The string data is guaranteed to be terminated with a zero byte. + pub fn as_ptr(&self) -> *const c_char { + self.0.as_ptr() as *const c_char + } +} + +impl<'a> TryFrom<&'a str> for SchemaName { + type Error = Error; + + fn try_from(name: &'a str) -> Result { + // SAFETY: we must ensure that the string does not contain an internal NULL byte + if let Some(pos) = name.as_bytes().iter().position(|&b| b == 0) { + return Err(Error::InvalidArgument(format!( + "schema name {name:?} contains a zero byte at index {pos}" + ))); + } + + let capacity = name.len().checked_add(1).unwrap(); + + let mut s = String::new(); + // `String::with_capacity()` does not guarantee that it will not overallocate, + // which might mean an unnecessary reallocation to make `capacity == len` + // in the conversion to `Box`. + s.reserve_exact(capacity); + + s.push_str(name); + s.push('\0'); + + Ok(SchemaName(s.into())) + } +} diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index c1c67636f..8a1d140b2 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -21,6 +21,8 @@ use crate::connection::execute; use crate::connection::ConnectionState; use crate::{Sqlite, SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement}; +use super::serialize::{deserialize, serialize, SchemaName, SqliteOwnedBuf}; + // Each SQLite connection has a dedicated thread. 
// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce @@ -54,6 +56,16 @@ enum Command { tx: flume::Sender, Error>>, limit: Option, }, + Serialize { + schema: Option, + tx: oneshot::Sender>, + }, + Deserialize { + schema: Option, + data: SqliteOwnedBuf, + read_only: bool, + tx: oneshot::Sender>, + }, Begin { tx: rendezvous_oneshot::Sender>, }, @@ -263,6 +275,12 @@ impl ConnectionWorker { } } } + Command::Serialize { schema, tx } => { + tx.send(serialize(&mut conn, schema)).ok(); + } + Command::Deserialize { schema, data, read_only, tx } => { + tx.send(deserialize(&mut conn, schema, data, read_only)).ok(); + } Command::ClearCache { tx } => { conn.statements.clear(); update_cached_statements_size(&conn, &shared.cached_statements_size); @@ -358,6 +376,29 @@ impl ConnectionWorker { self.oneshot_cmd(|tx| Command::Ping { tx }).await } + pub(crate) async fn deserialize( + &mut self, + schema: Option, + data: SqliteOwnedBuf, + read_only: bool, + ) -> Result<(), Error> { + self.oneshot_cmd(|tx| Command::Deserialize { + schema, + data, + read_only, + tx, + }) + .await? + } + + pub(crate) async fn serialize( + &mut self, + schema: Option, + ) -> Result { + self.oneshot_cmd(|tx| Command::Serialize { schema, tx }) + .await? 
+ } + async fn oneshot_cmd(&mut self, command: F) -> Result where F: FnOnce(oneshot::Sender) -> Command, diff --git a/sqlx-sqlite/src/error.rs b/sqlx-sqlite/src/error.rs index 0d34bc102..eee2e8b1a 100644 --- a/sqlx-sqlite/src/error.rs +++ b/sqlx-sqlite/src/error.rs @@ -2,12 +2,12 @@ use std::error::Error as StdError; use std::ffi::CStr; use std::fmt::{self, Display, Formatter}; use std::os::raw::c_int; -use std::{borrow::Cow, str::from_utf8_unchecked}; +use std::{borrow::Cow, str}; use libsqlite3_sys::{ - sqlite3, sqlite3_errmsg, sqlite3_extended_errcode, SQLITE_CONSTRAINT_CHECK, + sqlite3, sqlite3_errmsg, sqlite3_errstr, sqlite3_extended_errcode, SQLITE_CONSTRAINT_CHECK, SQLITE_CONSTRAINT_FOREIGNKEY, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY, - SQLITE_CONSTRAINT_UNIQUE, + SQLITE_CONSTRAINT_UNIQUE, SQLITE_ERROR, }; pub(crate) use sqlx_core::error::*; @@ -18,15 +18,15 @@ pub(crate) use sqlx_core::error::*; #[derive(Debug)] pub struct SqliteError { code: c_int, - message: String, + message: Cow<'static, str>, } impl SqliteError { - pub(crate) fn new(handle: *mut sqlite3) -> Self { + pub(crate) unsafe fn new(handle: *mut sqlite3) -> Self { Self::try_new(handle).expect("There should be an error") } - pub(crate) fn try_new(handle: *mut sqlite3) -> Option { + pub(crate) unsafe fn try_new(handle: *mut sqlite3) -> Option { // returns the extended result code even when extended result codes are disabled let code: c_int = unsafe { sqlite3_extended_errcode(handle) }; @@ -39,20 +39,44 @@ impl SqliteError { let msg = sqlite3_errmsg(handle); debug_assert!(!msg.is_null()); - from_utf8_unchecked(CStr::from_ptr(msg).to_bytes()) + str::from_utf8_unchecked(CStr::from_ptr(msg).to_bytes()).to_owned() }; Some(Self { code, - message: message.to_owned(), + message: message.into(), }) } /// For errors during extension load, the error message is supplied via a separate pointer - pub(crate) fn extension(handle: *mut sqlite3, error_msg: &CStr) -> Self { - let mut err = 
Self::new(handle); - err.message = unsafe { from_utf8_unchecked(error_msg.to_bytes()).to_owned() }; - err + pub(crate) fn with_message(mut self, error_msg: String) -> Self { + self.message = error_msg.into(); + self + } + + pub(crate) fn from_code(code: c_int) -> Self { + let message = unsafe { + let errstr = sqlite3_errstr(code); + + if !errstr.is_null() { + // SAFETY: `errstr` is guaranteed to be UTF-8 + // The lifetime of the string is "internally managed"; + // the implementation just selects from an array of static strings. + // We copy to an owned buffer in case `libsqlite3` is dynamically loaded somehow. + Cow::Owned(str::from_utf8_unchecked(CStr::from_ptr(errstr).to_bytes()).into()) + } else { + Cow::Borrowed("") + } + }; + + SqliteError { code, message } + } + + pub(crate) fn generic(message: impl Into>) -> Self { + Self { + code: SQLITE_ERROR, + message: message.into(), + } } } diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index f1a45c3d3..e4a122b6b 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -46,6 +46,7 @@ use std::sync::atomic::AtomicBool; pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; +pub use connection::serialize::SqliteOwnedBuf; #[cfg(feature = "preupdate-hook")] pub use connection::PreupdateHookResult; pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; diff --git a/sqlx-sqlite/src/statement/handle.rs b/sqlx-sqlite/src/statement/handle.rs index 2925d1a19..ccc299fcd 100644 --- a/sqlx-sqlite/src/statement/handle.rs +++ b/sqlx-sqlite/src/statement/handle.rs @@ -81,8 +81,8 @@ impl StatementHandle { } #[inline] - pub(crate) fn last_error(&self) -> SqliteError { - SqliteError::new(unsafe { self.db_handle() }) + pub(crate) fn last_error(&mut self) -> SqliteError { + unsafe { SqliteError::new(self.db_handle()) } } #[inline] diff --git a/sqlx-sqlite/src/statement/virtual.rs b/sqlx-sqlite/src/statement/virtual.rs index 
6be980c36..2817146bc 100644 --- a/sqlx-sqlite/src/statement/virtual.rs +++ b/sqlx-sqlite/src/statement/virtual.rs @@ -184,7 +184,7 @@ fn prepare( }; if status != SQLITE_OK { - return Err(SqliteError::new(conn).into()); + return Err(unsafe { SqliteError::new(conn).into() }); } // tail should point to the first byte past the end of the first SQL diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 16b4b2d9f..92a113873 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -2,12 +2,10 @@ use futures::TryStreamExt; use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions}; -use sqlx::Decode; use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, }; -use sqlx::{Value, ValueRef}; use sqlx_test::new; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -271,7 +269,7 @@ async fn it_handles_empty_queries() -> anyhow::Result<()> { } #[sqlx_macros::test] -fn it_binds_parameters() -> anyhow::Result<()> { +async fn it_binds_parameters() -> anyhow::Result<()> { let mut conn = new::().await?; let v: i32 = sqlx::query_scalar("SELECT ?") @@ -293,7 +291,7 @@ fn it_binds_parameters() -> anyhow::Result<()> { } #[sqlx_macros::test] -fn it_binds_dollar_parameters() -> anyhow::Result<()> { +async fn it_binds_dollar_parameters() -> anyhow::Result<()> { let mut conn = new::().await?; let v: (i32, i32) = sqlx::query_as("SELECT $1, $2") @@ -973,6 +971,8 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res #[cfg(feature = "sqlite-preupdate-hook")] #[sqlx_macros::test] async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { + use sqlx::Decode; + let mut conn = new::().await?; static CALLED: AtomicBool = AtomicBool::new(false); // Using this string as a canary to ensure the callback doesn't get called with the wrong 
data pointer. @@ -1021,6 +1021,8 @@ async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { #[cfg(feature = "sqlite-preupdate-hook")] #[sqlx_macros::test] async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { + use sqlx::Decode; + let mut conn = new::().await?; let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 5, 'Hello, World' )") .execute(&mut conn) @@ -1064,6 +1066,9 @@ async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { #[cfg(feature = "sqlite-preupdate-hook")] #[sqlx_macros::test] async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> { + use sqlx::Decode; + use sqlx::{Value, ValueRef}; + let mut conn = new::().await?; let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 6, 'Hello, World' )") .execute(&mut conn) @@ -1193,3 +1198,121 @@ async fn test_get_last_error() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn test_serialize_deserialize() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + sqlx::raw_sql("create table foo(bar integer not null, baz text not null)") + .execute(&mut conn) + .await?; + + sqlx::query("insert into foo(bar, baz) values (1234, 'Lorem ipsum'), (5678, 'dolor sit amet')") + .execute(&mut conn) + .await?; + + let serialized = conn.serialize(None).await?; + + // Close and open a new connection to ensure cleanliness. 
+ conn.close().await?; + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + conn.deserialize(None, serialized, false).await?; + + let rows = sqlx::query_as::<_, (i32, String)>("select bar, baz from foo") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows.len(), 2); + + assert_eq!(rows[0].0, 1234); + assert_eq!(rows[0].1, "Lorem ipsum"); + + assert_eq!(rows[1].0, 5678); + assert_eq!(rows[1].1, "dolor sit amet"); + + Ok(()) +} +#[sqlx_macros::test] +async fn test_serialize_deserialize_with_schema() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + sqlx::raw_sql( + "attach ':memory:' as foo; create table foo.foo(bar integer not null, baz text not null)", + ) + .execute(&mut conn) + .await?; + + sqlx::query( + "insert into foo.foo(bar, baz) values (1234, 'Lorem ipsum'), (5678, 'dolor sit amet')", + ) + .execute(&mut conn) + .await?; + + let serialized = conn.serialize(Some("foo")).await?; + + // Close and open a new connection to ensure cleanliness. + conn.close().await?; + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + // Unexpected quirk: the schema must exist before deserialization. 
+ sqlx::raw_sql("attach ':memory:' as foo") + .execute(&mut conn) + .await?; + + conn.deserialize(Some("foo"), serialized, false).await?; + + let rows = sqlx::query_as::<_, (i32, String)>("select bar, baz from foo.foo") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows.len(), 2); + + assert_eq!(rows[0].0, 1234); + assert_eq!(rows[0].1, "Lorem ipsum"); + + assert_eq!(rows[1].0, 5678); + assert_eq!(rows[1].1, "dolor sit amet"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_serialize_nonexistent_schema() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + let err = conn + .serialize(Some("foobar")) + .await + .expect_err("an error should have been returned"); + + let sqlx::Error::Database(dbe) = err else { + panic!("expected DatabaseError: {err:?}") + }; + + assert_eq!(dbe.code().as_deref(), Some("1")); + assert_eq!(dbe.message(), "database foobar does not exist"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_serialize_invalid_schema() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + let err = conn + .serialize(Some("foo\0bar")) + .await + .expect_err("an error should have been returned"); + + let sqlx::Error::InvalidArgument(msg) = err else { + panic!("expected InvalidArgument: {err:?}") + }; + + assert_eq!( + msg, + "schema name \"foo\\0bar\" contains a zero byte at index 3" + ); + + Ok(()) +} From 5c573e15eba7832f7aacd7cc2f71d4201f9c7a85 Mon Sep 17 00:00:00 2001 From: "James H." <32926722+jayy-lmao@users.noreply.github.com> Date: Mon, 3 Mar 2025 09:29:53 +1100 Subject: [PATCH 03/12] feat(postgres): add geometry path (#3716) * feat: add geometry path * fix: paths to pg point * test: remove array tests for path * Fix readme: uuid feature is gating for all repos (#3720) The readme previously stated that the uuid feature is only for postres but it actually also gates the functionality in mysql and sqlite. 
* Replace some futures_util APIs with std variants (#3721) * feat(sqlx-cli): Add flag to disable automatic loading of .env files (#3724) * Add flag to disable automatic loading of .env files * Update sqlx-cli/src/opt.rs Co-authored-by: Austin Bonander --------- Co-authored-by: Austin Bonander * chore: expose bstr feature (#3714) * chore: replace rustls-pemfile with rustls-pki-types (#3725) * QueryBuilder: add `debug_assert` when `push_values` is passed an empty set of tuples (#3734) * throw a warning in tracing so that the empty tuples would be noticed * use debug assertion to throw a panic in debug mode * fix: merge conflicts * chore(cli): remove unused async-trait crate from dependencies (#3754) * Update pull_request_template.md * Fix example calculation (#3741) * Avoid privilege requirements by using an advisory lock in test setup (postgres). (#3753) * feat(sqlx-postgres): use advisory lock to avoid setup race condition * fix(sqlx-postgres): numeric hex constants not supported before postgres 16 * Small doc correction. (#3755) When sqlx-core/src/from_row.rs was updated to implement FromRow for tuples of up to 16 values, a comment was left stating that it was implemented up to tuples of 9 values. * Update FAQ.md * refactor(cli): replace promptly with dialoguer (#3669) * docs(pool): recommend actix-web ThinData over Data to avoid two Arcs (#3762) Both actix_web::web::Data and sqlx::PgPool internally wrap an Arc. Thus, using Data as an extractor in an actix-web route handler results in two Arcs wrapping the data of interest, which isn't ideal. Actix-web 4.9.0 introduced a new web::ThinData extractor for cases like this, where the data is already wrapped in an `Arc` (or is otherwise similarly cheap and sensible to simply clone), which doesn't wrap the inner value in a (second) Arc. Since the new extractor is better suited to the task, suggest it in place of web::Data when giving an example on how to share a pool. 
* fix: merge conflicts * fix: use types mod from main * fix: merge conflicts * fix: merge conflicts * fix: merge conflicts * fix: ordering of types mod * fix: path import * test: no array test for path --------- Co-authored-by: Jon Thacker Co-authored-by: Paolo Barbolini Co-authored-by: Ben Wilber Co-authored-by: Austin Bonander Co-authored-by: joeydewaal <99046430+joeydewaal@users.noreply.github.com> Co-authored-by: tottoto Co-authored-by: Ethan Wang Co-authored-by: Stefan Schindler Co-authored-by: kildrens <5198060+kildrens@users.noreply.github.com> Co-authored-by: Marti Serra Co-authored-by: Jonas Malaco --- sqlx-postgres/src/type_checking.rs | 2 + sqlx-postgres/src/types/geometry/mod.rs | 1 + sqlx-postgres/src/types/geometry/path.rs | 372 +++++++++++++++++++++++ sqlx-postgres/src/types/mod.rs | 2 + tests/postgres/types.rs | 6 + 5 files changed, 383 insertions(+) create mode 100644 sqlx-postgres/src/types/geometry/path.rs diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index 68a4fcfef..5758c264a 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -40,6 +40,8 @@ impl_type_checking!( sqlx::postgres::types::PgBox, + sqlx::postgres::types::PgPath, + #[cfg(feature = "uuid")] sqlx::types::Uuid, diff --git a/sqlx-postgres/src/types/geometry/mod.rs b/sqlx-postgres/src/types/geometry/mod.rs index 7fe2898fc..f67846fef 100644 --- a/sqlx-postgres/src/types/geometry/mod.rs +++ b/sqlx-postgres/src/types/geometry/mod.rs @@ -1,4 +1,5 @@ pub mod r#box; pub mod line; pub mod line_segment; +pub mod path; pub mod point; diff --git a/sqlx-postgres/src/types/geometry/path.rs b/sqlx-postgres/src/types/geometry/path.rs new file mode 100644 index 000000000..87a3b3e8d --- /dev/null +++ b/sqlx-postgres/src/types/geometry/path.rs @@ -0,0 +1,372 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::{PgPoint, Type}; +use crate::{PgArgumentBuffer, 
PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use sqlx_core::Error; +use std::mem; +use std::str::FromStr; + +const BYTE_WIDTH: usize = mem::size_of::(); + +/// ## Postgres Geometric Path type +/// +/// Description: Open path or Closed path (similar to polygon) +/// Representation: Open `[(x1,y1),...]`, Closed `((x1,y1),...)` +/// +/// Paths are represented by lists of connected points. Paths can be open, where the first and last points in the list are considered not connected, or closed, where the first and last points are considered connected. +/// Values of type path are specified using any of the following syntaxes: +/// ```text +/// [ ( x1 , y1 ) , ... , ( xn , yn ) ] +/// ( ( x1 , y1 ) , ... , ( xn , yn ) ) +/// ( x1 , y1 ) , ... , ( xn , yn ) +/// ( x1 , y1 , ... , xn , yn ) +/// x1 , y1 , ... , xn , yn +/// ``` +/// where the points are the end points of the line segments comprising the path. Square brackets `([])` indicate an open path, while parentheses `(())` indicate a closed path. +/// When the outermost parentheses are omitted, as in the third through fifth syntaxes, a closed path is assumed. 
+/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-PATHS +#[derive(Debug, Clone, PartialEq)] +pub struct PgPath { + pub closed: bool, + pub points: Vec, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +struct Header { + is_closed: bool, + length: usize, +} + +impl Type for PgPath { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("path") + } +} + +impl PgHasArrayType for PgPath { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_path") + } +} + +impl<'r> Decode<'r, Postgres> for PgPath { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgPath::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgPath::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgPath { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("path")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgPath { + type Err = Error; + + fn from_str(s: &str) -> Result { + let closed = !s.contains('['); + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let parts = sanitised.split(',').collect::>(); + + let mut points = vec![]; + + if parts.len() % 2 != 0 { + return Err(Error::Decode( + format!("Unmatched pair in PATH: {}", s).into(), + )); + } + + for chunk in parts.chunks_exact(2) { + if let [x_str, y_str] = chunk { + let x = parse_float_from_str(x_str, "could not get x")?; + let y = parse_float_from_str(y_str, "could not get y")?; + + let point = PgPoint { x, y }; + points.push(point); + } + } + + if !points.is_empty() { + return Ok(PgPath { points, closed }); + } + + Err(Error::Decode( + format!("could not get path from {}", s).into(), + )) + } +} + +impl PgPath { + fn header(&self) -> Header { + Header { + is_closed: self.closed, + length: self.points.len(), + } + } + + fn from_bytes(mut bytes: &[u8]) -> Result { + let header = 
Header::try_read(&mut bytes)?; + + if bytes.len() != header.data_size() { + return Err(format!( + "expected {} bytes after header, got {}", + header.data_size(), + bytes.len() + ) + .into()); + } + + if bytes.len() % BYTE_WIDTH * 2 != 0 { + return Err(format!( + "data length not divisible by pairs of {BYTE_WIDTH}: {}", + bytes.len() + ) + .into()); + } + + let mut out_points = Vec::with_capacity(bytes.len() / (BYTE_WIDTH * 2)); + + while bytes.has_remaining() { + let point = PgPoint { + x: bytes.get_f64(), + y: bytes.get_f64(), + }; + out_points.push(point) + } + Ok(PgPath { + closed: header.is_closed, + points: out_points, + }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { + let header = self.header(); + buff.reserve(header.data_size()); + header.try_write(buff)?; + + for point in &self.points { + buff.extend_from_slice(&point.x.to_be_bytes()); + buff.extend_from_slice(&point.y.to_be_bytes()); + } + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +impl Header { + const HEADER_WIDTH: usize = mem::size_of::() + mem::size_of::(); + + fn data_size(&self) -> usize { + self.length * BYTE_WIDTH * 2 + } + + fn try_read(buf: &mut &[u8]) -> Result { + if buf.len() < Self::HEADER_WIDTH { + return Err(format!( + "expected PATH data to contain at least {} bytes, got {}", + Self::HEADER_WIDTH, + buf.len() + )); + } + + let is_closed = buf.get_i8(); + let length = buf.get_i32(); + + let length = usize::try_from(length).ok().ok_or_else(|| { + format!( + "received PATH data length: {length}. 
Expected length between 0 and {}", + usize::MAX + ) + })?; + + Ok(Self { + is_closed: is_closed != 0, + length, + }) + } + + fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + let is_closed = self.is_closed as i8; + + let length = i32::try_from(self.length).map_err(|_| { + format!( + "PATH length exceeds allowed maximum ({} > {})", + self.length, + i32::MAX + ) + })?; + + buff.extend(is_closed.to_be_bytes()); + buff.extend(length.to_be_bytes()); + + Ok(()) + } +} + +fn parse_float_from_str(s: &str, error_msg: &str) -> Result { + s.parse().map_err(|_| Error::Decode(error_msg.into())) +} + +#[cfg(test)] +mod path_tests { + + use std::str::FromStr; + + use crate::types::PgPoint; + + use super::PgPath; + + const PATH_CLOSED_BYTES: &[u8] = &[ + 1, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, + 64, 16, 0, 0, 0, 0, 0, 0, + ]; + + const PATH_OPEN_BYTES: &[u8] = &[ + 0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, + 64, 16, 0, 0, 0, 0, 0, 0, + ]; + + const PATH_UNEVEN_POINTS: &[u8] = &[ + 0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, + 64, 16, 0, 0, + ]; + + #[test] + fn can_deserialise_path_type_bytes_closed() { + let path = PgPath::from_bytes(PATH_CLOSED_BYTES).unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }] + } + ) + } + + #[test] + fn cannot_deserialise_path_type_uneven_point_bytes() { + let path = PgPath::from_bytes(PATH_UNEVEN_POINTS); + assert!(path.is_err()); + + if let Err(err) = path { + assert_eq!( + err.to_string(), + format!("expected 32 bytes after header, got 28") + ) + } + } + + #[test] + fn can_deserialise_path_type_bytes_open() { + let path = PgPath::from_bytes(PATH_OPEN_BYTES).unwrap(); + assert_eq!( + path, + PgPath { + closed: false, + points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }] + } + ) + } + 
+ #[test] + fn can_deserialise_path_type_str_first_syntax() { + let path = PgPath::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + path, + PgPath { + closed: false, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn cannot_deserialise_path_type_str_uneven_points_first_syntax() { + let input_str = "[( 1, 2), (3)]"; + let path = PgPath::from_str(input_str); + + assert!(path.is_err()); + + if let Err(err) = path { + assert_eq!( + err.to_string(), + format!("error occurred while decoding: Unmatched pair in PATH: {input_str}") + ) + } + } + + #[test] + fn can_deserialise_path_type_str_second_syntax() { + let path = PgPath::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_path_type_str_third_syntax() { + let path = PgPath::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_path_type_str_fourth_syntax() { + let path = PgPath::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_path_type_str_float() { + let path = PgPath::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + path, + PgPath { + closed: true, + points: vec![PgPoint { x: 1.1, y: 2.2 }, PgPoint { x: 3.3, y: 4.4 }] + } + ); + } + + #[test] + fn can_serialise_path_type() { + let path = PgPath { + closed: true, + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. 
}], + }; + assert_eq!(path.serialize_to_vec(), PATH_CLOSED_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index a5fd70836..5d684c969 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -25,6 +25,7 @@ //! | [`PgLine`] | LINE | //! | [`PgLSeg`] | LSEG | //! | [`PgBox`] | BOX | +//! | [`PgPath`] | PATH | //! | [`PgHstore`] | HSTORE | //! //! 1 SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc., @@ -262,6 +263,7 @@ pub use citext::PgCiText; pub use cube::PgCube; pub use geometry::line::PgLine; pub use geometry::line_segment::PgLSeg; +pub use geometry::path::PgPath; pub use geometry::point::PgPoint; pub use geometry::r#box::PgBox; pub use hstore::PgHstore; diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index ccf88b109..0d15caf8d 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -524,6 +524,12 @@ test_type!(_box>(Postgres, "array[box('1,2,3,4'),box('((1.1, 2.2), (3.3, 4.4))')]" @= vec![sqlx::postgres::types::PgBox { upper_right_x: 3., upper_right_y: 4., lower_left_x: 1., lower_left_y: 2. }, sqlx::postgres::types::PgBox { upper_right_x: 3.3, upper_right_y: 4.4, lower_left_x: 1.1, lower_left_y: 2.2 }], )); +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(path(Postgres, + "path('((1.0, 2.0), (3.0,4.0))')" == sqlx::postgres::types::PgPath { closed: true, points: vec![ sqlx::postgres::types::PgPoint { x: 1., y: 2. }, sqlx::postgres::types::PgPoint { x: 3. , y: 4. } ]}, + "path('[(1.0, 2.0), (3.0,4.0)]')" == sqlx::postgres::types::PgPath { closed: false, points: vec![ sqlx::postgres::types::PgPoint { x: 1., y: 2. }, sqlx::postgres::types::PgPoint { x: 3. , y: 4. 
} ]}, +)); + #[cfg(feature = "rust_decimal")] test_type!(decimal(Postgres, "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(), From 7af998c2abda9901c361bad85f124bbed9dde05e Mon Sep 17 00:00:00 2001 From: joeydewaal <99046430+joeydewaal@users.noreply.github.com> Date: Tue, 4 Mar 2025 21:56:08 +0100 Subject: [PATCH 04/12] chore(Sqlite): remove ci.db from repo (#3768) --- ci.db | Bin 36864 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 ci.db diff --git a/ci.db b/ci.db deleted file mode 100644 index cc158a72804c405a8716f792f5ca51ec421c6ae0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 36864 zcmeI*?{Cva7zgmXI6vw@Fd$8qCLvu`p~e#%tI)JpAkuJcBhqBJ?z+CROma|*U_NC8#&WYq);%-0a`BJb) zVf56K>=wC9C?$6pBZM^6qN^oORh22{1+}Fo@~34RTR+JzBHIzJnV;&9DGX2)a>e(ZWf=|2~2Z7}HZ zcE4^-544)OyiC7aO=FRVLP$3flktla+44d}Cw9z(TI+5ky!q)-j45lSPOyI+nMGDS z!Ls`BFi^uZYgHJ|{8`;cPuTjv=G~2ca@Mj|zQbZRU*|UOclb6-D_j0Z4epoas@YvW z?d;83rws1ui9%)CqBYYr>8qP_u5^E?Nc4Yj3a_u(fW;z^`Ebp7@A38aPS0VEy~E!r zRky5C)x2%~cMr3^Y3jXXDS1ChF7ivpkHq*zt&ku90SG_<0uX=z1Rwwb2tWV=5I7S8 zS4$OoT_@V67>~mnY&VR?qmOAt|4P@FlMe@(AFB_&lK)d~p$ z5P$##AOHafKmY;|fB*y_0D*HPP|+yWN}8tY$^1Vxej@6D1OW&@00Izz00bZa0SG_< z0uX?}*%DY$Q~vq`eNTA%O*WhOPkyH(Muqu*YP?nR|7WX%h!z46fB*y_009U<00Izz z00bZaf%g-b&;P$6#+&zRgQFk-0SG_<0uX=z1Rwwb2tWV=5cseKuF^8G^i2^>*gduK zgE$P5@Bgct{v$yE0uX=z1Rwwb2tWV=5P$##AaITalIwrX_?{TQ89%EH5(FRs0SG_< z0uX=z1Rwwb2tWV=AFjX^O(T}vS2y=p5054P0rX0#SVap}$^HK| Date: Tue, 4 Mar 2025 13:04:27 -0800 Subject: [PATCH 05/12] fix: CI * Fix breakage from Rustup 1.28 * Let `Swatinem/rust-cache` generate cache keys --- .github/workflows/examples.yml | 37 +++++++++-------- .github/workflows/sqlx-cli.yml | 28 +++++++------ .github/workflows/sqlx.yml | 75 +++++++++++++++++++--------------- 3 files changed, 75 insertions(+), 65 deletions(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 280d1fc4f..0dfbcbdf2 100644 --- a/.github/workflows/examples.yml 
+++ b/.github/workflows/examples.yml @@ -14,20 +14,20 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Use latest Rust - run: rustup override set stable + - name: Setup Rust + run: | + rustup show active-toolchain || rustup toolchain install + rustup override set stable - uses: Swatinem/rust-cache@v2 - with: - key: sqlx-cli - run: > - cargo build - -p sqlx-cli - --bin sqlx - --release - --no-default-features - --features mysql,postgres,sqlite + cargo build + -p sqlx-cli + --bin sqlx + --release + --no-default-features + --features mysql,postgres,sqlite - uses: actions/upload-artifact@v4 with: @@ -63,9 +63,10 @@ jobs: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: mysql-examples - name: Todos (Setup) working-directory: examples/mysql/todos @@ -98,7 +99,7 @@ jobs: name: sqlx-cli path: /home/runner/.local/bin - - run: | + - run: | ls -R /home/runner/.local/bin chmod +x $HOME/.local/bin/sqlx echo $HOME/.local/bin >> $GITHUB_PATH @@ -106,9 +107,8 @@ jobs: - uses: actions/checkout@v4 - - uses: Swatinem/rust-cache@v2 - with: - key: pg-examples + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install - name: Axum Social with Tests (Setup) working-directory: examples/postgres/axum-social-with-tests @@ -217,9 +217,10 @@ jobs: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: sqlite-examples - name: TODOs (Setup) env: diff --git a/.github/workflows/sqlx-cli.yml b/.github/workflows/sqlx-cli.yml index 3aeb3d7d3..2250e0bfc 100644 --- a/.github/workflows/sqlx-cli.yml +++ b/.github/workflows/sqlx-cli.yml @@ -15,8 +15,9 @@ jobs: steps: - uses: actions/checkout@v4 - - run: | - rustup update + - name: Setup Rust + run: | + rustup show active-toolchain || rustup toolchain install rustup component add clippy rustup toolchain install 
beta rustup component add --toolchain beta clippy @@ -40,18 +41,19 @@ jobs: matrix: # Note: macOS-latest uses M1 Silicon (ARM64) os: - - ubuntu-latest - # FIXME: migrations tests fail on Windows for whatever reason - # - windows-latest - - macOS-13 - - macOS-latest + - ubuntu-latest + # FIXME: migrations tests fail on Windows for whatever reason + # - windows-latest + - macOS-13 + - macOS-latest steps: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: ${{ runner.os }}-test - run: cargo test --manifest-path sqlx-cli/Cargo.toml @@ -85,12 +87,12 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Use latest Rust - run: rustup override set stable + - name: Setup Rust + run: | + rustup show active-toolchain || rustup toolchain install + rustup override set stable - uses: Swatinem/rust-cache@v2 - with: - key: ${{ runner.os }}-cli - run: cargo build --manifest-path sqlx-cli/Cargo.toml --bin cargo-sqlx ${{ matrix.args }} diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 3f1f44d39..8461dd491 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -21,21 +21,22 @@ jobs: runs-on: ubuntu-22.04 strategy: matrix: - runtime: [async-std, tokio] - tls: [native-tls, rustls, none] + runtime: [ async-std, tokio ] + tls: [ native-tls, rustls, none ] steps: - uses: actions/checkout@v4 - - uses: Swatinem/rust-cache@v2 - with: - key: "${{ runner.os }}-check-${{ matrix.runtime }}-${{ matrix.tls }}" - - - run: | - rustup update + # Swatinem/rust-cache recommends setting up the rust toolchain first because it's used in cache keys + - name: Setup Rust + # https://blog.rust-lang.org/2025/03/02/Rustup-1.28.0.html + run: | + rustup show active-toolchain || rustup toolchain install rustup component add clippy rustup toolchain install beta rustup component add --toolchain beta clippy + - uses: Swatinem/rust-cache@v2 + - run: > cargo clippy 
--no-default-features @@ -55,8 +56,10 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 - - run: rustup update - - run: rustup toolchain install nightly + - name: Setup Rust + run: | + rustup show active-toolchain || rustup toolchain install + rustup toolchain install nightly - run: cargo +nightly generate-lockfile -Z minimal-versions - run: cargo build --all-features @@ -66,12 +69,11 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: Swatinem/rust-cache@v2 - with: - key: ${{ runner.os }}-test + # https://blog.rust-lang.org/2025/03/02/Rustup-1.28.0.html + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install - - name: Install Rust - run: rustup update + - uses: Swatinem/rust-cache@v2 - name: Test sqlx-core run: > @@ -116,17 +118,19 @@ jobs: runs-on: ubuntu-22.04 strategy: matrix: - runtime: [async-std, tokio] - linking: [sqlite, sqlite-unbundled] + runtime: [ async-std, tokio ] + linking: [ sqlite, sqlite-unbundled ] needs: check steps: - uses: actions/checkout@v4 - run: mkdir /tmp/sqlite3-lib && wget -O /tmp/sqlite3-lib/ipaddr.so https://github.com/nalgeon/sqlean/releases/download/0.15.2/ipaddr.so + # https://blog.rust-lang.org/2025/03/02/Rustup-1.28.0.html + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: "${{ runner.os }}-${{ matrix.linking }}-${{ matrix.runtime }}-${{ matrix.tls }}" - name: Install system sqlite library if: ${{ matrix.linking == 'sqlite-unbundled' }} @@ -182,16 +186,17 @@ jobs: runs-on: ubuntu-22.04 strategy: matrix: - postgres: [17, 13] - runtime: [async-std, tokio] - tls: [native-tls, rustls-aws-lc-rs, rustls-ring, none] + postgres: [ 17, 13 ] + runtime: [ async-std, tokio ] + tls: [ native-tls, rustls-aws-lc-rs, rustls-ring, none ] needs: check steps: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: 
"${{ runner.os }}-postgres-${{ matrix.runtime }}-${{ matrix.tls }}" - env: # FIXME: needed to disable `ltree` tests in Postgres 9.6 @@ -282,16 +287,17 @@ jobs: runs-on: ubuntu-22.04 strategy: matrix: - mysql: [8] - runtime: [async-std, tokio] - tls: [native-tls, rustls-aws-lc-rs, rustls-ring, none] + mysql: [ 8 ] + runtime: [ async-std, tokio ] + tls: [ native-tls, rustls-aws-lc-rs, rustls-ring, none ] needs: check steps: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: "${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ matrix.tls }}" - run: cargo build --features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} @@ -370,16 +376,17 @@ jobs: runs-on: ubuntu-22.04 strategy: matrix: - mariadb: [verylatest, 11_4, 10_11, 10_4] - runtime: [async-std, tokio] - tls: [native-tls, rustls-aws-lc-rs, rustls-ring, none] + mariadb: [ verylatest, 11_4, 10_11, 10_4 ] + runtime: [ async-std, tokio ] + tls: [ native-tls, rustls-aws-lc-rs, rustls-ring, none ] needs: check steps: - uses: actions/checkout@v4 + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + - uses: Swatinem/rust-cache@v2 - with: - key: "${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ matrix.tls }}" - run: cargo build --features mysql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} From c3fd645409bd48fdd70b79734b27291aab6a3ec9 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 4 Mar 2025 13:51:45 -0800 Subject: [PATCH 06/12] fix(ci): upgrade Ubuntu image to 24.04 For some reason the `cargo +beta clippy` step is failing because `libsqlite3-sys` starts requiring Glibc >= 2.39 but I don't have time to figure out why and I can't reproduce it in a clean environment. 
--- .github/workflows/sqlx.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 8461dd491..7f573a634 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -10,7 +10,7 @@ on: jobs: format: name: Format - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: rustup component add rustfmt @@ -18,7 +18,7 @@ jobs: check: name: Check - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: runtime: [ async-std, tokio ] @@ -53,7 +53,7 @@ jobs: check-minimal-versions: name: Check build using minimal versions - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - name: Setup Rust @@ -65,7 +65,7 @@ jobs: test: name: Unit Tests - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 @@ -115,7 +115,7 @@ jobs: sqlite: name: SQLite - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: runtime: [ async-std, tokio ] @@ -183,7 +183,7 @@ jobs: postgres: name: Postgres - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: postgres: [ 17, 13 ] @@ -284,7 +284,7 @@ jobs: mysql: name: MySQL - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: mysql: [ 8 ] @@ -373,7 +373,7 @@ jobs: mariadb: name: MariaDB - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: mariadb: [ verylatest, 11_4, 10_11, 10_4 ] From a92626d6cc1aef88b0036d633ed370f6f9129c32 Mon Sep 17 00:00:00 2001 From: Chitoku Date: Tue, 4 Mar 2025 18:51:33 +0900 Subject: [PATCH 07/12] postgres: Use current tracing span when dropping PgListener --- sqlx-core/src/ext/async_stream.rs | 4 ++-- sqlx-postgres/src/listener.rs | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sqlx-core/src/ext/async_stream.rs b/sqlx-core/src/ext/async_stream.rs index a83aabed1..56777ca4d 100644 --- a/sqlx-core/src/ext/async_stream.rs +++ b/sqlx-core/src/ext/async_stream.rs @@ 
-121,7 +121,7 @@ impl<'a, T> Stream for TryAsyncStream<'a, T> { #[macro_export] macro_rules! try_stream { ($($block:tt)*) => { - $crate::ext::async_stream::TryAsyncStream::new(move |yielder| async move { + $crate::ext::async_stream::TryAsyncStream::new(move |yielder| ::tracing::Instrument::in_current_span(async move { // Anti-footgun: effectively pins `yielder` to this future to prevent any accidental // move to another task, which could deadlock. let yielder = &yielder; @@ -133,6 +133,6 @@ macro_rules! try_stream { } $($block)* - }) + })) } } diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index b96f8d829..17a46a916 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -9,6 +9,7 @@ use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::acquire::Acquire; use sqlx_core::transaction::Transaction; use sqlx_core::Either; +use tracing::Instrument; use crate::describe::Describe; use crate::error::Error; @@ -366,7 +367,7 @@ impl Drop for PgListener { }; // Unregister any listeners before returning the connection to the pool. - crate::rt::spawn(fut); + crate::rt::spawn(fut.in_current_span()); } } } From ca3a5090369238d156eb859d3fd699d86681f73f Mon Sep 17 00:00:00 2001 From: "James H." <32926722+jayy-lmao@users.noreply.github.com> Date: Fri, 7 Mar 2025 20:25:45 +1100 Subject: [PATCH 08/12] feat(postgres): add geometry polygon (#3769) * feat: add polygon * test: paths for pgpoints in polygon test * fix: import typo * chore(Sqlite): remove ci.db from repo (#3768) * fix: CI * Fix breakage from Rustup 1.28 * Let `Swatinem/rust-cache` generate cache keys * fix(ci): upgrade Ubuntu image to 24.04 For some reason the `cargo +beta clippy` step is failing because `libsqlite3-sys` starts requiring Glibc >= 2.39 but I don't have time to figure out why and I can't reproduce it in a clean environment. 
--------- Co-authored-by: joeydewaal <99046430+joeydewaal@users.noreply.github.com> Co-authored-by: Austin Bonander --- sqlx-postgres/src/type_checking.rs | 2 + sqlx-postgres/src/types/geometry/mod.rs | 1 + sqlx-postgres/src/types/geometry/polygon.rs | 363 ++++++++++++++++++++ sqlx-postgres/src/types/mod.rs | 2 + tests/postgres/types.rs | 9 + 5 files changed, 377 insertions(+) create mode 100644 sqlx-postgres/src/types/geometry/polygon.rs diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index 5758c264a..c82fd6218 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -42,6 +42,8 @@ impl_type_checking!( sqlx::postgres::types::PgPath, + sqlx::postgres::types::PgPolygon, + #[cfg(feature = "uuid")] sqlx::types::Uuid, diff --git a/sqlx-postgres/src/types/geometry/mod.rs b/sqlx-postgres/src/types/geometry/mod.rs index f67846fef..1437d72c5 100644 --- a/sqlx-postgres/src/types/geometry/mod.rs +++ b/sqlx-postgres/src/types/geometry/mod.rs @@ -3,3 +3,4 @@ pub mod line; pub mod line_segment; pub mod path; pub mod point; +pub mod polygon; diff --git a/sqlx-postgres/src/types/geometry/polygon.rs b/sqlx-postgres/src/types/geometry/polygon.rs new file mode 100644 index 000000000..500c9933e --- /dev/null +++ b/sqlx-postgres/src/types/geometry/polygon.rs @@ -0,0 +1,363 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::{PgPoint, Type}; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use sqlx_core::Error; +use std::mem; +use std::str::FromStr; + +const BYTE_WIDTH: usize = mem::size_of::(); + +/// ## Postgres Geometric Polygon type +/// +/// Description: Polygon (similar to closed path) +/// Representation: `((x1,y1),...)` +/// +/// Polygons are represented by lists of points (the vertexes of the polygon). 
Polygons are very similar to closed paths; the essential semantic difference is that a polygon is considered to include the area within it, while a path is not. +/// An important implementation difference between polygons and paths is that the stored representation of a polygon includes its smallest bounding box. This speeds up certain search operations, although computing the bounding box adds overhead while constructing new polygons. +/// Values of type polygon are specified using any of the following syntaxes: +/// +/// ```text +/// ( ( x1 , y1 ) , ... , ( xn , yn ) ) +/// ( x1 , y1 ) , ... , ( xn , yn ) +/// ( x1 , y1 , ... , xn , yn ) +/// x1 , y1 , ... , xn , yn +/// ``` +/// +/// where the points are the end points of the line segments comprising the boundary of the polygon. +/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-POLYGON +#[derive(Debug, Clone, PartialEq)] +pub struct PgPolygon { + pub points: Vec, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +struct Header { + length: usize, +} + +impl Type for PgPolygon { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("polygon") + } +} + +impl PgHasArrayType for PgPolygon { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_polygon") + } +} + +impl<'r> Decode<'r, Postgres> for PgPolygon { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgPolygon::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgPolygon::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgPolygon { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("polygon")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgPolygon { + type Err = Error; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let parts = sanitised.split(',').collect::>(); + + let 
mut points = vec![]; + + if parts.len() % 2 != 0 { + return Err(Error::Decode( + format!("Unmatched pair in POLYGON: {}", s).into(), + )); + } + + for chunk in parts.chunks_exact(2) { + if let [x_str, y_str] = chunk { + let x = parse_float_from_str(x_str, "could not get x")?; + let y = parse_float_from_str(y_str, "could not get y")?; + + let point = PgPoint { x, y }; + points.push(point); + } + } + + if !points.is_empty() { + return Ok(PgPolygon { points }); + } + + Err(Error::Decode( + format!("could not get polygon from {}", s).into(), + )) + } +} + +impl PgPolygon { + fn header(&self) -> Header { + Header { + length: self.points.len(), + } + } + + fn from_bytes(mut bytes: &[u8]) -> Result { + let header = Header::try_read(&mut bytes)?; + + if bytes.len() != header.data_size() { + return Err(format!( + "expected {} bytes after header, got {}", + header.data_size(), + bytes.len() + ) + .into()); + } + + if bytes.len() % BYTE_WIDTH * 2 != 0 { + return Err(format!( + "data length not divisible by pairs of {BYTE_WIDTH}: {}", + bytes.len() + ) + .into()); + } + + let mut out_points = Vec::with_capacity(bytes.len() / (BYTE_WIDTH * 2)); + while bytes.has_remaining() { + let point = PgPoint { + x: bytes.get_f64(), + y: bytes.get_f64(), + }; + out_points.push(point) + } + Ok(PgPolygon { points: out_points }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { + let header = self.header(); + buff.reserve(header.data_size()); + header.try_write(buff)?; + + for point in &self.points { + buff.extend_from_slice(&point.x.to_be_bytes()); + buff.extend_from_slice(&point.y.to_be_bytes()); + } + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +impl Header { + const HEADER_WIDTH: usize = mem::size_of::() + mem::size_of::(); + + fn data_size(&self) -> usize { + self.length * BYTE_WIDTH * 2 + } + + fn try_read(buf: &mut &[u8]) 
-> Result { + if buf.len() < Self::HEADER_WIDTH { + return Err(format!( + "expected polygon data to contain at least {} bytes, got {}", + Self::HEADER_WIDTH, + buf.len() + )); + } + + let length = buf.get_i32(); + + let length = usize::try_from(length).ok().ok_or_else(|| { + format!( + "received polygon with length: {length}. Expected length between 0 and {}", + usize::MAX + ) + })?; + + Ok(Self { length }) + } + + fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + let length = i32::try_from(self.length).map_err(|_| { + format!( + "polygon length exceeds allowed maximum ({} > {})", + self.length, + i32::MAX + ) + })?; + + buff.extend(length.to_be_bytes()); + + Ok(()) + } +} + +fn parse_float_from_str(s: &str, error_msg: &str) -> Result { + s.parse().map_err(|_| Error::Decode(error_msg.into())) +} + +#[cfg(test)] +mod polygon_tests { + + use std::str::FromStr; + + use crate::types::PgPoint; + + use super::PgPolygon; + + const POLYGON_BYTES: &[u8] = &[ + 0, 0, 0, 12, 192, 0, 0, 0, 0, 0, 0, 0, 192, 8, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, + 0, 192, 8, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 63, + 240, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, + 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 192, + 8, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 192, 8, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 191, + 240, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, + 0, 0, 0, + ]; + + #[test] + fn can_deserialise_polygon_type_bytes() { + let polygon = PgPolygon::from_bytes(POLYGON_BYTES).unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![ + PgPoint { x: -2., y: -3. }, + PgPoint { x: -1., y: -3. }, + PgPoint { x: -1., y: -1. }, + PgPoint { x: 1., y: 1. }, + PgPoint { x: 1., y: 3. }, + PgPoint { x: 2., y: 3. 
}, + PgPoint { x: 2., y: -3. }, + PgPoint { x: 1., y: -3. }, + PgPoint { x: 1., y: 0. }, + PgPoint { x: -1., y: 0. }, + PgPoint { x: -1., y: -2. }, + PgPoint { x: -2., y: -2. } + ] + } + ) + } + + #[test] + fn can_deserialise_polygon_type_str_first_syntax() { + let polygon = PgPolygon::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_polygon_type_str_second_syntax() { + let polygon = PgPolygon::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn cannot_deserialise_polygon_type_str_uneven_points_first_syntax() { + let input_str = "[( 1, 2), (3)]"; + let polygon = PgPolygon::from_str(input_str); + + assert!(polygon.is_err()); + + if let Err(err) = polygon { + assert_eq!( + err.to_string(), + format!("error occurred while decoding: Unmatched pair in POLYGON: {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_polygon_type_str_invalid_numbers() { + let input_str = "[( 1, 2), (2, three)]"; + let polygon = PgPolygon::from_str(input_str); + + assert!(polygon.is_err()); + + if let Err(err) = polygon { + assert_eq!( + err.to_string(), + format!("error occurred while decoding: could not get y") + ) + } + } + + #[test] + fn can_deserialise_polygon_type_str_third_syntax() { + let polygon = PgPolygon::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] + } + ); + } + + #[test] + fn can_deserialise_polygon_type_str_fourth_syntax() { + let polygon = PgPolygon::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. 
}] + } + ); + } + + #[test] + fn can_deserialise_polygon_type_str_float() { + let polygon = PgPolygon::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + polygon, + PgPolygon { + points: vec![PgPoint { x: 1.1, y: 2.2 }, PgPoint { x: 3.3, y: 4.4 }] + } + ); + } + + #[test] + fn can_serialise_polygon_type() { + let polygon = PgPolygon { + points: vec![ + PgPoint { x: -2., y: -3. }, + PgPoint { x: -1., y: -3. }, + PgPoint { x: -1., y: -1. }, + PgPoint { x: 1., y: 1. }, + PgPoint { x: 1., y: 3. }, + PgPoint { x: 2., y: 3. }, + PgPoint { x: 2., y: -3. }, + PgPoint { x: 1., y: -3. }, + PgPoint { x: 1., y: 0. }, + PgPoint { x: -1., y: 0. }, + PgPoint { x: -1., y: -2. }, + PgPoint { x: -2., y: -2. }, + ], + }; + assert_eq!(polygon.serialize_to_vec(), POLYGON_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index 5d684c969..550ce6292 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -26,6 +26,7 @@ //! | [`PgLSeg`] | LSEG | //! | [`PgBox`] | BOX | //! | [`PgPath`] | PATH | +//! | [`PgPolygon`] | POLYGON | //! | [`PgHstore`] | HSTORE | //! //! 1 SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc., @@ -265,6 +266,7 @@ pub use geometry::line::PgLine; pub use geometry::line_segment::PgLSeg; pub use geometry::path::PgPath; pub use geometry::point::PgPoint; +pub use geometry::polygon::PgPolygon; pub use geometry::r#box::PgBox; pub use hstore::PgHstore; pub use interval::PgInterval; diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 0d15caf8d..d88e1657c 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -530,6 +530,15 @@ test_type!(path(Postgres, "path('[(1.0, 2.0), (3.0,4.0)]')" == sqlx::postgres::types::PgPath { closed: false, points: vec![ sqlx::postgres::types::PgPoint { x: 1., y: 2. }, sqlx::postgres::types::PgPoint { x: 3. , y: 4. 
} ]}, )); +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(polygon(Postgres, + "polygon('((-2,-3),(-1,-3),(-1,-1),(1,1),(1,3),(2,3),(2,-3),(1,-3),(1,0),(-1,0),(-1,-2),(-2,-2))')" ~= sqlx::postgres::types::PgPolygon { points: vec![ + sqlx::postgres::types::PgPoint { x: -2., y: -3. }, sqlx::postgres::types::PgPoint { x: -1., y: -3. }, sqlx::postgres::types::PgPoint { x: -1., y: -1. }, sqlx::postgres::types::PgPoint { x: 1., y: 1. }, + sqlx::postgres::types::PgPoint { x: 1., y: 3. }, sqlx::postgres::types::PgPoint { x: 2., y: 3. }, sqlx::postgres::types::PgPoint { x: 2., y: -3. }, sqlx::postgres::types::PgPoint { x: 1., y: -3. }, + sqlx::postgres::types::PgPoint { x: 1., y: 0. }, sqlx::postgres::types::PgPoint { x: -1., y: 0. }, sqlx::postgres::types::PgPoint { x: -1., y: -2. }, sqlx::postgres::types::PgPoint { x: -2., y: -2. }, + ]}, +)); + #[cfg(feature = "rust_decimal")] test_type!(decimal(Postgres, "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(), From 2f10c29dfd48dd9bac66b2fbabced6e8f0cfd445 Mon Sep 17 00:00:00 2001 From: "James H." 
<32926722+jayy-lmao@users.noreply.github.com> Date: Mon, 10 Mar 2025 09:01:30 +1100 Subject: [PATCH 09/12] feat(postgres): add geometry circle (#3773) * feat: circle * docs: comments --- sqlx-postgres/src/type_checking.rs | 2 + sqlx-postgres/src/types/geometry/box.rs | 5 +- sqlx-postgres/src/types/geometry/circle.rs | 250 ++++++++++++++++++ sqlx-postgres/src/types/geometry/line.rs | 5 +- .../src/types/geometry/line_segment.rs | 5 +- sqlx-postgres/src/types/geometry/mod.rs | 1 + sqlx-postgres/src/types/geometry/path.rs | 5 +- sqlx-postgres/src/types/geometry/point.rs | 5 +- sqlx-postgres/src/types/geometry/polygon.rs | 5 +- sqlx-postgres/src/types/mod.rs | 2 + tests/postgres/types.rs | 8 + 11 files changed, 287 insertions(+), 6 deletions(-) create mode 100644 sqlx-postgres/src/types/geometry/circle.rs diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index c82fd6218..a28531c9b 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -44,6 +44,8 @@ impl_type_checking!( sqlx::postgres::types::PgPolygon, + sqlx::postgres::types::PgCircle, + #[cfg(feature = "uuid")] sqlx::types::Uuid, diff --git a/sqlx-postgres/src/types/geometry/box.rs b/sqlx-postgres/src/types/geometry/box.rs index 988c028ed..28016b278 100644 --- a/sqlx-postgres/src/types/geometry/box.rs +++ b/sqlx-postgres/src/types/geometry/box.rs @@ -23,7 +23,10 @@ const ERROR: &str = "error decoding BOX"; /// where `(upper_right_x,upper_right_y) and (lower_left_x,lower_left_y)` are any two opposite corners of the box. /// Any two opposite corners can be supplied on input, but the values will be reordered as needed to store the upper right and lower left corners, in that order. /// -/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-BOXES +/// See [Postgres Manual, Section 8.8.4: Geometric Types - Boxes][PG.S.8.8.4] for details. 
+/// +/// [PG.S.8.8.4]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-GEOMETRIC-BOXES +/// #[derive(Debug, Clone, PartialEq)] pub struct PgBox { pub upper_right_x: f64, diff --git a/sqlx-postgres/src/types/geometry/circle.rs b/sqlx-postgres/src/types/geometry/circle.rs new file mode 100644 index 000000000..dde54dd27 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/circle.rs @@ -0,0 +1,250 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use sqlx_core::Error; +use std::str::FromStr; + +const ERROR: &str = "error decoding CIRCLE"; + +/// ## Postgres Geometric Circle type +/// +/// Description: Circle +/// Representation: `< (x, y), radius >` (center point and radius) +/// +/// ```text +/// < ( x , y ) , radius > +/// ( ( x , y ) , radius ) +/// ( x , y ) , radius +/// x , y , radius +/// ``` +/// where `(x,y)` is the center point. +/// +/// See [Postgres Manual, Section 8.8.7, Geometric Types - Circles][PG.S.8.8.7] for details. 
+/// +/// [PG.S.8.8.7]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-CIRCLE +/// +#[derive(Debug, Clone, PartialEq)] +pub struct PgCircle { + pub x: f64, + pub y: f64, + pub radius: f64, +} + +impl Type for PgCircle { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("circle") + } +} + +impl PgHasArrayType for PgCircle { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_circle") + } +} + +impl<'r> Decode<'r, Postgres> for PgCircle { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgCircle::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgCircle::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgCircle { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("circle")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgCircle { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['<', '>', '(', ')', ' '], ""); + let mut parts = sanitised.split(','); + + let x = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get x from {}", ERROR, s))?; + + let y = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get y from {}", ERROR, s))?; + + let radius = parts + .next() + .and_then(|s| s.trim().parse::().ok()) + .ok_or_else(|| format!("{}: could not get radius from {}", ERROR, s))?; + + if parts.next().is_some() { + return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into()); + } + + if radius < 0. 
{ + return Err(format!("{}: cannot have negative radius: {}", ERROR, s).into()); + } + + Ok(PgCircle { x, y, radius }) + } +} + +impl PgCircle { + fn from_bytes(mut bytes: &[u8]) -> Result { + let x = bytes.get_f64(); + let y = bytes.get_f64(); + let r = bytes.get_f64(); + Ok(PgCircle { x, y, radius: r }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), Error> { + buff.extend_from_slice(&self.x.to_be_bytes()); + buff.extend_from_slice(&self.y.to_be_bytes()); + buff.extend_from_slice(&self.radius.to_be_bytes()); + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +#[cfg(test)] +mod circle_tests { + + use std::str::FromStr; + + use super::PgCircle; + + const CIRCLE_BYTES: &[u8] = &[ + 63, 241, 153, 153, 153, 153, 153, 154, 64, 1, 153, 153, 153, 153, 153, 154, 64, 10, 102, + 102, 102, 102, 102, 102, + ]; + + #[test] + fn can_deserialise_circle_type_bytes() { + let circle = PgCircle::from_bytes(CIRCLE_BYTES).unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.1, + y: 2.2, + radius: 3.3 + } + ) + } + + #[test] + fn can_deserialise_circle_type_str() { + let circle = PgCircle::from_str("<(1, 2), 3 >").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn can_deserialise_circle_type_str_second_syntax() { + let circle = PgCircle::from_str("((1, 2), 3 )").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn can_deserialise_circle_type_str_third_syntax() { + let circle = PgCircle::from_str("(1, 2), 3 ").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn can_deserialise_circle_type_str_fourth_syntax() { + let circle = PgCircle::from_str("1, 2, 3 ").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.0, + y: 2.0, + radius: 3.0 + } + ); + } + + #[test] + fn 
cannot_deserialise_circle_invalid_numbers() { + let input_str = "1, 2, Three"; + let circle = PgCircle::from_str(input_str); + assert!(circle.is_err()); + if let Err(err) = circle { + assert_eq!( + err.to_string(), + format!("error decoding CIRCLE: could not get radius from {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_circle_negative_radius() { + let input_str = "1, 2, -3"; + let circle = PgCircle::from_str(input_str); + assert!(circle.is_err()); + if let Err(err) = circle { + assert_eq!( + err.to_string(), + format!("error decoding CIRCLE: cannot have negative radius: {input_str}") + ) + } + } + + #[test] + fn can_deserialise_circle_type_str_float() { + let circle = PgCircle::from_str("<(1.1, 2.2), 3.3>").unwrap(); + assert_eq!( + circle, + PgCircle { + x: 1.1, + y: 2.2, + radius: 3.3 + } + ); + } + + #[test] + fn can_serialise_circle_type() { + let circle = PgCircle { + x: 1.1, + y: 2.2, + radius: 3.3, + }; + assert_eq!(circle.serialize_to_vec(), CIRCLE_BYTES,) + } +} diff --git a/sqlx-postgres/src/types/geometry/line.rs b/sqlx-postgres/src/types/geometry/line.rs index 43f93c1c3..8f08c949e 100644 --- a/sqlx-postgres/src/types/geometry/line.rs +++ b/sqlx-postgres/src/types/geometry/line.rs @@ -15,7 +15,10 @@ const ERROR: &str = "error decoding LINE"; /// /// Lines are represented by the linear equation Ax + By + C = 0, where A and B are not both zero. /// -/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LINE +/// See [Postgres Manual, Section 8.8.2, Geometric Types - Lines][PG.S.8.8.2] for details. 
+/// +/// [PG.S.8.8.2]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-LINE +/// #[derive(Debug, Clone, PartialEq)] pub struct PgLine { pub a: f64, diff --git a/sqlx-postgres/src/types/geometry/line_segment.rs b/sqlx-postgres/src/types/geometry/line_segment.rs index 5dc5efc74..cd08e4da4 100644 --- a/sqlx-postgres/src/types/geometry/line_segment.rs +++ b/sqlx-postgres/src/types/geometry/line_segment.rs @@ -23,7 +23,10 @@ const ERROR: &str = "error decoding LSEG"; /// ``` /// where `(start_x,start_y) and (end_x,end_y)` are the end points of the line segment. /// -/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LSEG +/// See [Postgres Manual, Section 8.8.3, Geometric Types - Line Segments][PG.S.8.8.3] for details. +/// +/// [PG.S.8.8.3]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-LSEG +/// #[doc(alias = "line segment")] #[derive(Debug, Clone, PartialEq)] pub struct PgLSeg { diff --git a/sqlx-postgres/src/types/geometry/mod.rs b/sqlx-postgres/src/types/geometry/mod.rs index 1437d72c5..c3142145e 100644 --- a/sqlx-postgres/src/types/geometry/mod.rs +++ b/sqlx-postgres/src/types/geometry/mod.rs @@ -1,4 +1,5 @@ pub mod r#box; +pub mod circle; pub mod line; pub mod line_segment; pub mod path; diff --git a/sqlx-postgres/src/types/geometry/path.rs b/sqlx-postgres/src/types/geometry/path.rs index 87a3b3e8d..6799289fa 100644 --- a/sqlx-postgres/src/types/geometry/path.rs +++ b/sqlx-postgres/src/types/geometry/path.rs @@ -27,7 +27,10 @@ const BYTE_WIDTH: usize = mem::size_of::(); /// where the points are the end points of the line segments comprising the path. Square brackets `([])` indicate an open path, while parentheses `(())` indicate a closed path. /// When the outermost parentheses are omitted, as in the third through fifth syntaxes, a closed path is assumed. 
/// -/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-PATHS +/// See [Postgres Manual, Section 8.8.5, Geometric Types - Paths][PG.S.8.8.5] for details. +/// +/// [PG.S.8.8.5]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-GEOMETRIC-PATHS +/// #[derive(Debug, Clone, PartialEq)] pub struct PgPath { pub closed: bool, diff --git a/sqlx-postgres/src/types/geometry/point.rs b/sqlx-postgres/src/types/geometry/point.rs index cc1067295..83b7c24d0 100644 --- a/sqlx-postgres/src/types/geometry/point.rs +++ b/sqlx-postgres/src/types/geometry/point.rs @@ -19,7 +19,10 @@ use std::str::FromStr; /// ```` /// where x and y are the respective coordinates, as floating-point numbers. /// -/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-POINTS +/// See [Postgres Manual, Section 8.8.1, Geometric Types - Points][PG.S.8.8.1] for details. +/// +/// [PG.S.8.8.1]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-GEOMETRIC-POINTS +/// #[derive(Debug, Clone, PartialEq)] pub struct PgPoint { pub x: f64, diff --git a/sqlx-postgres/src/types/geometry/polygon.rs b/sqlx-postgres/src/types/geometry/polygon.rs index 500c9933e..a5a203c68 100644 --- a/sqlx-postgres/src/types/geometry/polygon.rs +++ b/sqlx-postgres/src/types/geometry/polygon.rs @@ -28,7 +28,10 @@ const BYTE_WIDTH: usize = mem::size_of::(); /// /// where the points are the end points of the line segments comprising the boundary of the polygon. /// -/// Seeh ttps://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-POLYGON +/// See [Postgres Manual, Section 8.8.6, Geometric Types - Polygons][PG.S.8.8.6] for details. 
+/// +/// [PG.S.8.8.6]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-POLYGON +/// #[derive(Debug, Clone, PartialEq)] pub struct PgPolygon { pub points: Vec, diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index 550ce6292..c3493139c 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -27,6 +27,7 @@ //! | [`PgBox`] | BOX | //! | [`PgPath`] | PATH | //! | [`PgPolygon`] | POLYGON | +//! | [`PgCircle`] | CIRCLE | //! | [`PgHstore`] | HSTORE | //! //! 1 SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc., @@ -262,6 +263,7 @@ mod bit_vec; pub use array::PgHasArrayType; pub use citext::PgCiText; pub use cube::PgCube; +pub use geometry::circle::PgCircle; pub use geometry::line::PgLine; pub use geometry::line_segment::PgLSeg; pub use geometry::path::PgPath; diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index d88e1657c..da20467ea 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -539,6 +539,14 @@ test_type!(polygon(Postgres, ]}, )); +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(circle(Postgres, + "circle('<(1.1, -2.2), 3.3>')" ~= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, + "circle('((1.1, -2.2), 3.3)')" ~= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, + "circle('(1.1, -2.2), 3.3')" ~= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, + "circle('1.1, -2.2, 3.3')" ~= sqlx::postgres::types::PgCircle { x: 1.1, y:-2.2, radius: 3.3 }, +)); + #[cfg(feature = "rust_decimal")] test_type!(decimal(Postgres, "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(), From 393b731d5e04664cd0ece1bfd130086187090d79 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 10 Mar 2025 14:29:46 -0700 Subject: [PATCH 10/12] Merge of #3427 (by @mpyw) and #3614 (by @bonsairobo) (#3765) * feat: Implement `get_transaction_depth` for drivers * test: Verify 
`get_transaction_depth()` on postgres * Refactor: `TransactionManager` delegation without BC SQLite implementation is currently WIP * Fix: Avoid breaking changes on `AnyConnectionBackend` * Refactor: Remove verbose `SqliteConnection` typing * Feat: Implementation for SQLite I have included `AtomicUsize` in `WorkerSharedState`. Ideally, it is not desirable to execute `load` and `fetch_add` in two separate steps, but we decided to allow it here since there is only one thread writing. To prevent writing from other threads, the field itself was made private, and a getter method was provided with `pub(crate)`. * Refactor: Same approach for `cached_statements_size` ref: a66787d36d62876b55475ef2326d17bade817aed * Fix: Add missing `is_in_transaction` for backend * Doc: Remove verbose "synchronously" word * Fix: Remove useless `mut` qualifier * feat: add Connection::begin_with This patch completes the plumbing of an optional statement from these methods to `TransactionManager::begin` without any validation of the provided statement. There is a new `Error::InvalidSavePoint` which is triggered by any attempt to call `Connection::begin_with` when we are already inside of a transaction. * feat: add Pool::begin_with and Pool::try_begin_with * feat: add Error::BeginFailed and validate that custom "begin" statements are successful * chore: add tests of Error::BeginFailed * chore: add tests of Error::InvalidSavePointStatement * chore: test begin_with works for all SQLite "BEGIN" statements * chore: improve comment on Connection::begin_with * feat: add default impl of `Connection::begin_with` This makes the new method a non-breaking change. * refactor: combine if statement + unwrap_or_else into one match * feat: use in-memory SQLite DB to avoid conflicts across tests run in parallel * feedback: remove public wrapper for sqlite3_txn_state Move the wrapper directly into the test that uses it instead. 
* fix: cache Status on MySqlConnection * fix: compilation errors * fix: format * fix: postgres test * refactor: delete `Connection::get_transaction_depth` * fix: tests --------- Co-authored-by: mpyw Co-authored-by: Duncan Fairbanks --- Cargo.toml | 1 + sqlx-core/src/acquire.rs | 4 +- sqlx-core/src/any/connection/backend.rs | 29 +++++++++++- sqlx-core/src/any/connection/mod.rs | 13 +++++- sqlx-core/src/any/transaction.rs | 13 +++++- sqlx-core/src/connection.rs | 30 +++++++++++- sqlx-core/src/error.rs | 6 +++ sqlx-core/src/pool/connection.rs | 2 +- sqlx-core/src/pool/mod.rs | 39 +++++++++++++++- sqlx-core/src/transaction.rs | 26 +++++++++-- sqlx-mysql/src/any.rs | 12 ++++- sqlx-mysql/src/connection/establish.rs | 1 + sqlx-mysql/src/connection/executor.rs | 4 ++ sqlx-mysql/src/connection/mod.rs | 23 ++++++++- sqlx-mysql/src/protocol/response/status.rs | 2 +- sqlx-mysql/src/transaction.rs | 26 +++++++++-- sqlx-postgres/src/any.rs | 12 ++++- sqlx-postgres/src/connection/mod.rs | 20 +++++++- sqlx-postgres/src/transaction.rs | 28 +++++++++-- sqlx-sqlite/src/any.rs | 13 +++++- sqlx-sqlite/src/connection/establish.rs | 1 - sqlx-sqlite/src/connection/mod.rs | 30 ++++++++---- sqlx-sqlite/src/connection/worker.rs | 54 +++++++++++++++++----- sqlx-sqlite/src/transaction.rs | 26 +++++++++-- tests/mysql/error.rs | 28 ++++++++++- tests/postgres/error.rs | 28 ++++++++++- tests/postgres/postgres.rs | 5 ++ tests/sqlite/error.rs | 28 ++++++++++- tests/sqlite/sqlite.rs | 51 ++++++++++++++++++++ 29 files changed, 494 insertions(+), 61 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5a040e546..f31d715b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -189,6 +189,7 @@ rand_xoshiro = "0.6.0" hex = "0.4.3" tempfile = "3.10.1" criterion = { version = "0.5.1", features = ["async_tokio"] } +libsqlite3-sys = { version = "0.30.1" } # If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`, # even when SQLite isn't the intended test target, and fail if 
the build environment is not set up for compiling C code. diff --git a/sqlx-core/src/acquire.rs b/sqlx-core/src/acquire.rs index c9d7fb215..59bac9fa5 100644 --- a/sqlx-core/src/acquire.rs +++ b/sqlx-core/src/acquire.rs @@ -93,7 +93,7 @@ impl<'a, DB: Database> Acquire<'a> for &'_ Pool { let conn = self.acquire(); Box::pin(async move { - Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?)).await + Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?), None).await }) } } @@ -121,7 +121,7 @@ macro_rules! impl_acquire { 'c, Result<$crate::transaction::Transaction<'c, $DB>, $crate::error::Error>, > { - $crate::transaction::Transaction::begin(self) + $crate::transaction::Transaction::begin(self, None) } } }; diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index b30cbe83f..6c84c1d8c 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -3,6 +3,7 @@ use crate::describe::Describe; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; +use std::borrow::Cow; use std::fmt::Debug; pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { @@ -26,7 +27,13 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn ping(&mut self) -> BoxFuture<'_, crate::Result<()>>; /// Begin a new transaction or establish a savepoint within the active transaction. - fn begin(&mut self) -> BoxFuture<'_, crate::Result<()>>; + /// + /// If this is a new transaction, `statement` may be used instead of the + /// default "BEGIN" statement. + /// + /// If we are already inside a transaction and `statement.is_some()`, then + /// `Error::InvalidSavePoint` is returned without running any statements. 
+ fn begin(&mut self, statement: Option>) -> BoxFuture<'_, crate::Result<()>>; fn commit(&mut self) -> BoxFuture<'_, crate::Result<()>>; @@ -34,6 +41,26 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn start_rollback(&mut self); + /// Returns the current transaction depth. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(&self) -> usize { + unimplemented!("get_transaction_depth() is not implemented for this backend. This is a provided method to avoid a breaking change, but it will become a required method in version 0.9 and later."); + } + + /// Checks if the connection is currently in a transaction. + /// + /// This method returns `true` if the current transaction depth is greater than 0, + /// indicating that a transaction is active. It returns `false` if the transaction depth is 0, + /// meaning no transaction is active. + #[inline] + fn is_in_transaction(&self) -> bool { + self.get_transaction_depth() != 0 + } + /// The number of statements currently cached in the connection. 
fn cached_statements_size(&self) -> usize { 0 diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index b6f795848..8cf8fc510 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -1,4 +1,5 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; use crate::any::{Any, AnyConnectOptions}; use crate::connection::{ConnectOptions, Connection}; @@ -87,7 +88,17 @@ impl Connection for AnyConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-core/src/any/transaction.rs b/sqlx-core/src/any/transaction.rs index fce417562..a553cda92 100644 --- a/sqlx-core/src/any/transaction.rs +++ b/sqlx-core/src/any/transaction.rs @@ -1,6 +1,8 @@ use futures_util::future::BoxFuture; +use std::borrow::Cow; use crate::any::{Any, AnyConnection}; +use crate::database::Database; use crate::error::Error; use crate::transaction::TransactionManager; @@ -9,8 +11,11 @@ pub struct AnyTransactionManager; impl TransactionManager for AnyTransactionManager { type Database = Any; - fn begin(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> { - conn.backend.begin() + fn begin<'conn>( + conn: &'conn mut AnyConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { + conn.backend.begin(statement) } fn commit(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> { @@ -24,4 +29,8 @@ impl TransactionManager for AnyTransactionManager { fn start_rollback(conn: &mut AnyConnection) { conn.backend.start_rollback() } + + fn get_transaction_depth(conn: &::Connection) -> usize { + conn.backend.get_transaction_depth() + } } diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index ce2aa6c62..74e8cd3e8 100644 
--- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -1,9 +1,10 @@ use crate::database::{Database, HasStatementCache}; use crate::error::Error; -use crate::transaction::Transaction; +use crate::transaction::{Transaction, TransactionManager}; use futures_core::future::BoxFuture; use log::LevelFilter; +use std::borrow::Cow; use std::fmt::Debug; use std::str::FromStr; use std::time::Duration; @@ -49,6 +50,33 @@ pub trait Connection: Send { where Self: Sized; + /// Begin a new transaction with a custom statement. + /// + /// Returns a [`Transaction`] for controlling and tracking the new transaction. + /// + /// Returns an error if the connection is already in a transaction or if + /// `statement` does not put the connection into a transaction. + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) + } + + /// Returns `true` if the connection is currently in a transaction. + /// + /// # Note: Automatic Rollbacks May Not Be Counted + /// Certain database errors (such as a serializable isolation failure) + /// can cause automatic rollbacks of a transaction + /// which may not be indicated in the return value of this method. + #[inline] + fn is_in_transaction(&self) -> bool { + ::TransactionManager::get_transaction_depth(self) != 0 + } + /// Execute the function inside a transaction. /// /// If the function returns an error, the transaction will be rolled back. 
If it does not diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 98b42fbcd..9ad5eff46 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -117,6 +117,12 @@ pub enum Error { #[cfg(feature = "migrate")] #[error("{0}")] Migrate(#[source] Box), + + #[error("attempted to call begin_with at non-zero transaction depth")] + InvalidSavePointStatement, + + #[error("got unexpected connection status after attempting to begin transaction")] + BeginFailed, } impl StdError for Box {} diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index bf3a6d4b1..c029fec6e 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -191,7 +191,7 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection futures_core::future::BoxFuture<'c, Result, Error>> { - crate::transaction::Transaction::begin(&mut **self) + crate::transaction::Transaction::begin(&mut **self, None) } } diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 8aa9041ab..d85bce246 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -54,6 +54,7 @@ //! [`Pool::acquire`] or //! [`Pool::begin`]. +use std::borrow::Cow; use std::fmt; use std::future::Future; use std::pin::{pin, Pin}; @@ -368,13 +369,17 @@ impl Pool { /// Retrieves a connection and immediately begins a new transaction. pub async fn begin(&self) -> Result, Error> { - Transaction::begin(MaybePoolConnection::PoolConnection(self.acquire().await?)).await + Transaction::begin( + MaybePoolConnection::PoolConnection(self.acquire().await?), + None, + ) + .await } /// Attempts to retrieve a connection and immediately begins a new transaction if successful. 
pub async fn try_begin(&self) -> Result<Option<Transaction<'static, DB>>, Error> { match self.try_acquire() { - Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn)) + Some(conn) => Transaction::begin(MaybePoolConnection::PoolConnection(conn), None) .await .map(Some), @@ -382,6 +387,36 @@ } } + /// Retrieves a connection and immediately begins a new transaction using `statement`. + pub async fn begin_with( + &self, + statement: impl Into<Cow<'static, str>>, + ) -> Result<Transaction<'static, DB>, Error> { + Transaction::begin( + MaybePoolConnection::PoolConnection(self.acquire().await?), + Some(statement.into()), + ) + .await + } + + /// Attempts to retrieve a connection and, if successful, immediately begins a new + /// transaction using `statement`. + pub async fn try_begin_with( + &self, + statement: impl Into<Cow<'static, str>>, + ) -> Result<Option<Transaction<'static, DB>>, Error> { + match self.try_acquire() { + Some(conn) => Transaction::begin( + MaybePoolConnection::PoolConnection(conn), + Some(statement.into()), + ) + .await + .map(Some), + + None => Ok(None), + } + } + /// Shut down the connection pool, immediately waking all tasks waiting for a connection. /// /// Upon calling this method, any currently waiting or subsequent calls to [`Pool::acquire`] and diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 9cd38aab3..2a84ff655 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -16,9 +16,16 @@ pub trait TransactionManager { type Database: Database; /// Begin a new transaction or establish a savepoint within the active transaction. - fn begin( - conn: &mut <Self::Database as Database>::Connection, - ) -> BoxFuture<'_, Result<(), Error>>; + /// + /// If this is a new transaction, `statement` may be used instead of the + /// default "BEGIN" statement. + /// + /// If we are already inside a transaction and `statement.is_some()`, then + /// `Error::InvalidSavePointStatement` is returned without running any statements.
+ fn begin<'conn>( + conn: &'conn mut ::Connection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>>; /// Commit the active transaction or release the most recent savepoint. fn commit( @@ -32,6 +39,14 @@ pub trait TransactionManager { /// Starts to abort the active transaction or restore from the most recent snapshot. fn start_rollback(conn: &mut ::Connection); + + /// Returns the current transaction depth. + /// + /// Transaction depth indicates the level of nested transactions: + /// - Level 0: No active transaction. + /// - Level 1: A transaction is active. + /// - Level 2 or higher: A transaction is active and one or more SAVEPOINTs have been created within it. + fn get_transaction_depth(conn: &::Connection) -> usize; } /// An in-progress database transaction or savepoint. @@ -83,11 +98,12 @@ where #[doc(hidden)] pub fn begin( conn: impl Into>, + statement: Option>, ) -> BoxFuture<'c, Result> { let mut conn = conn.into(); Box::pin(async move { - DB::TransactionManager::begin(&mut conn).await?; + DB::TransactionManager::begin(&mut conn, statement).await?; Ok(Self { connection: conn, @@ -237,7 +253,7 @@ impl<'c, 't, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<' #[inline] fn begin(self) -> BoxFuture<'t, Result, Error>> { - Transaction::begin(&mut **self) + Transaction::begin(&mut **self, None) } } diff --git a/sqlx-mysql/src/any.rs b/sqlx-mysql/src/any.rs index e01e41d68..19b3a6f27 100644 --- a/sqlx-mysql/src/any.rs +++ b/sqlx-mysql/src/any.rs @@ -16,6 +16,7 @@ use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::transaction::TransactionManager; +use std::borrow::Cow; use std::{future, pin::pin}; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql); @@ -37,8 +38,11 @@ impl AnyConnectionBackend for MySqlConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - MySqlTransactionManager::begin(self) + fn 
begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + MySqlTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { @@ -53,6 +57,10 @@ impl AnyConnectionBackend for MySqlConnection { MySqlTransactionManager::start_rollback(self) } + fn get_transaction_depth(&self) -> usize { + MySqlTransactionManager::get_transaction_depth(self) + } + fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index 0623a0556..85a9d84f9 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -27,6 +27,7 @@ impl MySqlConnection { inner: Box::new(MySqlConnectionInner { stream, transaction_depth: 0, + status_flags: Default::default(), cache_statement: StatementCache::new(options.statement_cache_capacity), log_settings: options.log_settings.clone(), }), diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index d93aac0d6..4f5af4bf6 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -166,6 +166,8 @@ impl MySqlConnection { // this indicates either a successful query with no rows at all or a failed query let ok = packet.ok()?; + self.inner.status_flags = ok.status; + let rows_affected = ok.affected_rows; logger.increase_rows_affected(rows_affected); let done = MySqlQueryResult { @@ -208,6 +210,8 @@ impl MySqlConnection { if packet[0] == 0xfe && packet.len() < 9 { let eof = packet.eof(self.inner.stream.capabilities)?; + self.inner.status_flags = eof.status; + r#yield!(Either::Left(MySqlQueryResult { rows_affected: 0, last_insert_id: 0, diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index c4978a770..0a2f5fb83 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use 
std::fmt::{self, Debug, Formatter}; use futures_core::future::BoxFuture; @@ -7,6 +8,7 @@ pub(crate) use stream::{MySqlStream, Waiting}; use crate::common::StatementCache; use crate::error::Error; +use crate::protocol::response::Status; use crate::protocol::statement::StmtClose; use crate::protocol::text::{Ping, Quit}; use crate::statement::MySqlStatementMetadata; @@ -34,6 +36,7 @@ pub(crate) struct MySqlConnectionInner { // transaction status pub(crate) transaction_depth: usize, + status_flags: Status, // cache by query string to the statement id and metadata cache_statement: StatementCache<(u32, MySqlStatementMetadata)>, @@ -41,6 +44,14 @@ pub(crate) struct MySqlConnectionInner { log_settings: LogSettings, } +impl MySqlConnection { + pub(crate) fn in_transaction(&self) -> bool { + self.inner + .status_flags + .intersects(Status::SERVER_STATUS_IN_TRANS) + } +} + impl Debug for MySqlConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("MySqlConnection").finish() @@ -111,7 +122,17 @@ impl Connection for MySqlConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn shrink_buffers(&mut self) { diff --git a/sqlx-mysql/src/protocol/response/status.rs b/sqlx-mysql/src/protocol/response/status.rs index bf5013dee..4a8bb0375 100644 --- a/sqlx-mysql/src/protocol/response/status.rs +++ b/sqlx-mysql/src/protocol/response/status.rs @@ -1,7 +1,7 @@ // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad // https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status bitflags::bitflags! 
{ - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)] pub struct Status: u16 { // Is raised when a multi-statement transaction has been started, either explicitly, // by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index d8538cc2b..545cb5f4f 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use futures_core::future::BoxFuture; use crate::connection::Waiting; @@ -14,12 +16,24 @@ pub struct MySqlTransactionManager; impl TransactionManager for MySqlTransactionManager { type Database = MySql; - fn begin(conn: &mut MySqlConnection) -> BoxFuture<'_, Result<(), Error>> { + fn begin<'conn>( + conn: &'conn mut MySqlConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { let depth = conn.inner.transaction_depth; - - conn.execute(&*begin_ansi_transaction_sql(depth)).await?; - conn.inner.transaction_depth = depth + 1; + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in a transaction + // (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; + conn.execute(&*statement).await?; + if !conn.in_transaction() { + return Err(Error::BeginFailed); + } + conn.inner.transaction_depth += 1; Ok(()) }) @@ -65,4 +79,8 @@ impl TransactionManager for MySqlTransactionManager { conn.inner.transaction_depth = depth - 1; } } + + fn get_transaction_depth(conn: &MySqlConnection) -> usize { + conn.inner.transaction_depth + } } diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index a7b30fb65..762f53e5d 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -5,6 +5,7 @@ use crate::{ use 
futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt}; +use std::borrow::Cow; use std::{future, pin::pin}; use sqlx_core::any::{ @@ -39,8 +40,11 @@ impl AnyConnectionBackend for PgConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - PgTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + PgTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { @@ -55,6 +59,10 @@ impl AnyConnectionBackend for PgConnection { PgTransactionManager::start_rollback(self) } + fn get_transaction_depth(&self) -> usize { + PgTransactionManager::get_transaction_depth(self) + } + fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index c139f8e53..96e3e2fe1 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -127,6 +128,13 @@ impl PgConnection { Ok(()) } + + pub(crate) fn in_transaction(&self) -> bool { + match self.inner.transaction_status { + TransactionStatus::Transaction => true, + TransactionStatus::Error | TransactionStatus::Idle => false, + } + } } impl Debug for PgConnection { @@ -179,7 +187,17 @@ impl Connection for PgConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index e7c78488e..23352a8dc 100644 --- a/sqlx-postgres/src/transaction.rs +++ 
b/sqlx-postgres/src/transaction.rs @@ -1,4 +1,6 @@ use futures_core::future::BoxFuture; +use sqlx_core::database::Database; +use std::borrow::Cow; use crate::error::Error; use crate::executor::Executor; @@ -13,13 +15,27 @@ pub struct PgTransactionManager; impl TransactionManager for PgTransactionManager { type Database = Postgres; - fn begin(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { + fn begin<'conn>( + conn: &'conn mut PgConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { Box::pin(async move { + let depth = conn.inner.transaction_depth; + let statement = match statement { + // custom `BEGIN` statements are not allowed if we're already in + // a transaction (we need to issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; + let rollback = Rollback::new(conn); - let query = begin_ansi_transaction_sql(rollback.conn.inner.transaction_depth); - rollback.conn.queue_simple_query(&query)?; - rollback.conn.inner.transaction_depth += 1; + rollback.conn.queue_simple_query(&statement)?; rollback.conn.wait_until_ready().await?; + if !rollback.conn.in_transaction() { + return Err(Error::BeginFailed); + } + rollback.conn.inner.transaction_depth += 1; rollback.defuse(); Ok(()) @@ -62,6 +78,10 @@ impl TransactionManager for PgTransactionManager { conn.inner.transaction_depth -= 1; } } + + fn get_transaction_depth(conn: &::Connection) -> usize { + conn.inner.transaction_depth + } } struct Rollback<'c> { diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index 2cc585540..c72370d0f 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use crate::{ Either, Sqlite, SqliteArgumentValue, SqliteArguments, SqliteColumn, SqliteConnectOptions, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteTransactionManager, SqliteTypeInfo, @@ -38,8 +40,11 @@ impl 
AnyConnectionBackend for SqliteConnection { Connection::ping(self) } - fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { - SqliteTransactionManager::begin(self) + fn begin( + &mut self, + statement: Option>, + ) -> BoxFuture<'_, sqlx_core::Result<()>> { + SqliteTransactionManager::begin(self, statement) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { @@ -54,6 +59,10 @@ impl AnyConnectionBackend for SqliteConnection { SqliteTransactionManager::start_rollback(self) } + fn get_transaction_depth(&self) -> usize { + SqliteTransactionManager::get_transaction_depth(self) + } + fn shrink_buffers(&mut self) { // NO-OP. } diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 334b1616a..c5d2450fb 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -288,7 +288,6 @@ impl EstablishParams { Ok(ConnectionState { handle, statements: Statements::new(self.statement_cache_capacity), - transaction_depth: 0, log_settings: self.log_settings.clone(), progress_handler_callback: None, update_hook_callback: None, diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 3316ad40c..b94ad91c4 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::cmp::Ordering; use std::ffi::CStr; use std::fmt::Write; @@ -11,8 +12,8 @@ use futures_core::future::BoxFuture; use futures_intrusive::sync::MutexGuard; use futures_util::future; use libsqlite3_sys::{ - sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, - sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, + sqlite3, sqlite3_commit_hook, sqlite3_get_autocommit, sqlite3_progress_handler, + sqlite3_rollback_hook, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, }; #[cfg(feature = "preupdate-hook")] pub use preupdate_hook::*; @@ -106,9 +107,6 @@ unsafe impl Send 
for RollbackHookHandler {} pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, - // transaction status - pub(crate) transaction_depth: usize, - pub(crate) statements: Statements, log_settings: LogSettings, @@ -253,14 +251,21 @@ impl Connection for SqliteConnection { where Self: Sized, { - Transaction::begin(self) + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl Into>, + ) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into())) } fn cached_statements_size(&self) -> usize { - self.worker - .shared - .cached_statements_size - .load(std::sync::atomic::Ordering::Acquire) + self.worker.shared.get_cached_statements_size() } fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { @@ -547,6 +552,11 @@ impl LockedSqliteHandle<'_> { pub fn last_error(&mut self) -> Option { self.guard.handle.last_error() } + + pub(crate) fn in_transaction(&mut self) -> bool { + let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; + ret == 0 + } } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index 8a1d140b2..00a4c2999 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -36,10 +36,21 @@ pub(crate) struct ConnectionWorker { } pub(crate) struct WorkerSharedState { - pub(crate) cached_statements_size: AtomicUsize, + transaction_depth: AtomicUsize, + cached_statements_size: AtomicUsize, pub(crate) conn: Mutex, } +impl WorkerSharedState { + pub(crate) fn get_transaction_depth(&self) -> usize { + self.transaction_depth.load(Ordering::Acquire) + } + + pub(crate) fn get_cached_statements_size(&self) -> usize { + self.cached_statements_size.load(Ordering::Acquire) + } +} + enum Command { Prepare { query: Box, @@ -68,6 +79,7 @@ enum Command { }, Begin { tx: rendezvous_oneshot::Sender>, + statement: Option>, }, Commit { tx: 
rendezvous_oneshot::Sender>, @@ -105,6 +117,7 @@ impl ConnectionWorker { }; let shared = Arc::new(WorkerSharedState { + transaction_depth: AtomicUsize::new(0), cached_statements_size: AtomicUsize::new(0), // note: must be fair because in `Command::UnlockDb` we unlock the mutex // and then immediately try to relock it; an unfair mutex would immediately @@ -194,13 +207,27 @@ impl ConnectionWorker { update_cached_statements_size(&conn, &shared.cached_statements_size); } - Command::Begin { tx } => { - let depth = conn.transaction_depth; + Command::Begin { tx, statement } => { + let depth = shared.transaction_depth.load(Ordering::Acquire); + + let statement = match statement { + // custom `BEGIN` statements are not allowed if + // we're already in a transaction (we need to + // issue a `SAVEPOINT` instead) + Some(_) if depth > 0 => { + if tx.blocking_send(Err(Error::InvalidSavePointStatement)).is_err() { + break; + } + continue; + }, + Some(statement) => statement, + None => begin_ansi_transaction_sql(depth), + }; let res = conn.handle - .exec(begin_ansi_transaction_sql(depth)) + .exec(statement) .map(|_| { - conn.transaction_depth += 1; + shared.transaction_depth.fetch_add(1, Ordering::Release); }); let res_ok = res.is_ok(); @@ -213,7 +240,7 @@ impl ConnectionWorker { .handle .exec(rollback_ansi_transaction_sql(depth + 1)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) { // The rollback failed. 
To prevent leaving the connection @@ -225,13 +252,13 @@ impl ConnectionWorker { } } Command::Commit { tx } => { - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = if depth > 0 { conn.handle .exec(commit_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) } else { Ok(()) @@ -251,13 +278,13 @@ impl ConnectionWorker { continue; } - let depth = conn.transaction_depth; + let depth = shared.transaction_depth.load(Ordering::Acquire); let res = if depth > 0 { conn.handle .exec(rollback_ansi_transaction_sql(depth)) .map(|_| { - conn.transaction_depth -= 1; + shared.transaction_depth.fetch_sub(1, Ordering::Release); }) } else { Ok(()) @@ -351,8 +378,11 @@ impl ConnectionWorker { Ok(rx) } - pub(crate) async fn begin(&mut self) -> Result<(), Error> { - self.oneshot_cmd_with_ack(|tx| Command::Begin { tx }) + pub(crate) async fn begin( + &mut self, + statement: Option>, + ) -> Result<(), Error> { + self.oneshot_cmd_with_ack(|tx| Command::Begin { tx, statement }) .await? } diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index 24eaca51b..55a80ab9f 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,17 +1,33 @@ use futures_core::future::BoxFuture; +use std::borrow::Cow; -use crate::{Sqlite, SqliteConnection}; use sqlx_core::error::Error; use sqlx_core::transaction::TransactionManager; +use crate::{Sqlite, SqliteConnection}; + /// Implementation of [`TransactionManager`] for SQLite. 
pub struct SqliteTransactionManager; impl TransactionManager for SqliteTransactionManager { type Database = Sqlite; - fn begin(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(conn.worker.begin()) + fn begin<'conn>( + conn: &'conn mut SqliteConnection, + statement: Option>, + ) -> BoxFuture<'conn, Result<(), Error>> { + Box::pin(async { + let is_custom_statement = statement.is_some(); + conn.worker.begin(statement).await?; + if is_custom_statement { + // Check that custom statement actually put the connection into a transaction. + let mut handle = conn.lock_handle().await?; + if !handle.in_transaction() { + return Err(Error::BeginFailed); + } + } + Ok(()) + }) } fn commit(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> { @@ -25,4 +41,8 @@ impl TransactionManager for SqliteTransactionManager { fn start_rollback(conn: &mut SqliteConnection) { conn.worker.start_rollback().ok(); } + + fn get_transaction_depth(conn: &SqliteConnection) -> usize { + conn.worker.shared.get_transaction_depth() + } } diff --git a/tests/mysql/error.rs b/tests/mysql/error.rs index 7c84266c3..3ee1024fc 100644 --- a/tests/mysql/error.rs +++ b/tests/mysql/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, mysql::MySql, Connection}; +use sqlx::{error::ErrorKind, mysql::MySql, Connection, Error}; use sqlx_test::new; #[sqlx_macros::test] @@ -74,3 +74,29 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = 
txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/postgres/error.rs b/tests/postgres/error.rs index d6f78140d..32bf81477 100644 --- a/tests/postgres/error.rs +++ b/tests/postgres/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, postgres::Postgres, Connection}; +use sqlx::{error::ErrorKind, postgres::Postgres, Connection, Error}; use sqlx_test::new; #[sqlx_macros::test] @@ -74,3 +74,29 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), "{err}"); + + Ok(()) +} diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 7de4a9cdc..fc7108bf4 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -515,6 +515,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { #[sqlx_macros::test] async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { let mut conn = new::().await?; + assert!(!conn.is_in_transaction()); conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_2523 (id INTEGER PRIMARY KEY)") .await?; @@ -523,6 +524,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin let mut tx = conn.begin().await?; // transaction + assert!(tx.is_in_transaction()); // insert a user 
sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -532,6 +534,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // begin once more let mut tx2 = tx.begin().await?; // savepoint + assert!(tx2.is_in_transaction()); // insert another user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") @@ -541,6 +544,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // never mind, rollback tx2.rollback().await?; // roll that one back + assert!(tx.is_in_transaction()); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") @@ -551,6 +555,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // actually, commit tx.commit().await?; + assert!(!conn.is_in_transaction()); // did we really? let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") diff --git a/tests/sqlite/error.rs b/tests/sqlite/error.rs index 1f6b797e6..8729842b7 100644 --- a/tests/sqlite/error.rs +++ b/tests/sqlite/error.rs @@ -1,4 +1,4 @@ -use sqlx::{error::ErrorKind, sqlite::Sqlite, Connection, Executor}; +use sqlx::{error::ErrorKind, sqlite::Sqlite, Connection, Error, Executor}; use sqlx_test::new; #[sqlx_macros::test] @@ -70,3 +70,29 @@ async fn it_fails_with_check_violation() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_fails_with_begin_failed() -> anyhow::Result<()> { + let mut conn = new::().await?; + let res = conn.begin_with("SELECT * FROM tweet").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::BeginFailed), "{err:?}"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_invalid_save_point_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut txn = conn.begin().await?; + let txn_conn = sqlx::Acquire::acquire(&mut txn).await?; + let res = txn_conn.begin_with("BEGIN").await; + + let err = res.unwrap_err(); + + assert!(matches!(err, Error::InvalidSavePointStatement), 
"{err}"); + + Ok(()) +} diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 92a113873..c23c4fc9e 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -6,6 +6,7 @@ use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, }; +use sqlx_sqlite::LockedSqliteHandle; use sqlx_test::new; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -1316,3 +1317,53 @@ async fn test_serialize_invalid_schema() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_can_use_transaction_options() -> anyhow::Result<()> { + async fn check_txn_state(conn: &mut SqliteConnection, expected: SqliteTransactionState) { + let state = transaction_state(&mut conn.lock_handle().await.unwrap()); + assert_eq!(state, expected); + } + + let mut conn = SqliteConnectOptions::new() + .in_memory(true) + .connect() + .await + .unwrap(); + + check_txn_state(&mut conn, SqliteTransactionState::None).await; + + let mut tx = conn.begin_with("BEGIN DEFERRED").await?; + check_txn_state(&mut tx, SqliteTransactionState::None).await; + drop(tx); + + let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await; + drop(tx); + + let mut tx = conn.begin_with("BEGIN EXCLUSIVE").await?; + check_txn_state(&mut tx, SqliteTransactionState::Write).await; + drop(tx); + + Ok(()) +} + +fn transaction_state(handle: &mut LockedSqliteHandle) -> SqliteTransactionState { + use libsqlite3_sys::{sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE}; + + let unchecked_state = + unsafe { sqlite3_txn_state(handle.as_raw_handle().as_ptr(), std::ptr::null()) }; + match unchecked_state { + SQLITE_TXN_NONE => SqliteTransactionState::None, + SQLITE_TXN_READ => SqliteTransactionState::Read, + SQLITE_TXN_WRITE => SqliteTransactionState::Write, + _ => panic!("unknown txn state: {unchecked_state}"), + } +} + 
+#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum SqliteTransactionState { + None, + Read, + Write, +} From e474be6d4b4f7b8a1bbeb65363ef53015feebc47 Mon Sep 17 00:00:00 2001 From: Robin Schroer Date: Sun, 16 Mar 2025 15:21:56 +0900 Subject: [PATCH 11/12] docs: Fix a copy-paste error on get_username docs (#3786) I suspect this is a copy-paste error, it's meant to say username, not port. --- sqlx-mysql/src/options/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-mysql/src/options/mod.rs b/sqlx-mysql/src/options/mod.rs index db2b20c19..87732cb40 100644 --- a/sqlx-mysql/src/options/mod.rs +++ b/sqlx-mysql/src/options/mod.rs @@ -448,7 +448,7 @@ impl MySqlConnectOptions { self.socket.as_ref() } - /// Get the server's port. + /// Get the current username. /// /// # Example /// From 1c9cbe939ada22f377e51f3d60d538bcfc567e8f Mon Sep 17 00:00:00 2001 From: Beau Gieskens Date: Mon, 24 Mar 2025 10:19:05 +1000 Subject: [PATCH 12/12] feat: add ipnet support (#3710) * feat: add ipnet support * fix: ipnet not decoding IP address strings * fix: prefer ipnetwork to ipnet for compatibility * fix: unnecessary cfg --- Cargo.lock | 8 ++ Cargo.toml | 3 + README.md | 2 + sqlx-core/Cargo.toml | 1 + sqlx-core/src/types/mod.rs | 7 + sqlx-macros-core/Cargo.toml | 1 + sqlx-macros/Cargo.toml | 1 + sqlx-postgres/Cargo.toml | 2 + sqlx-postgres/src/type_checking.rs | 6 + sqlx-postgres/src/types/ipnet/ipaddr.rs | 62 +++++++++ sqlx-postgres/src/types/ipnet/ipnet.rs | 130 ++++++++++++++++++ sqlx-postgres/src/types/ipnet/mod.rs | 7 + .../src/types/{ => ipnetwork}/ipaddr.rs | 0 .../src/types/{ => ipnetwork}/ipnetwork.rs | 0 sqlx-postgres/src/types/ipnetwork/mod.rs | 5 + sqlx-postgres/src/types/mod.rs | 19 ++- tests/postgres/types.rs | 43 +++++- tests/ui-tests.rs | 2 +- 18 files changed, 293 insertions(+), 6 deletions(-) create mode 100644 sqlx-postgres/src/types/ipnet/ipaddr.rs create mode 100644 sqlx-postgres/src/types/ipnet/ipnet.rs create mode 100644 
sqlx-postgres/src/types/ipnet/mod.rs rename sqlx-postgres/src/types/{ => ipnetwork}/ipaddr.rs (100%) rename sqlx-postgres/src/types/{ => ipnetwork}/ipnetwork.rs (100%) create mode 100644 sqlx-postgres/src/types/ipnetwork/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 07754e7c2..f1c4604c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1939,6 +1939,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ipnet" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" + [[package]] name = "ipnetwork" version = "0.20.0" @@ -3443,6 +3449,7 @@ dependencies = [ "hashbrown 0.15.2", "hashlink", "indexmap 2.7.0", + "ipnet", "ipnetwork", "log", "mac_address", @@ -3698,6 +3705,7 @@ dependencies = [ "hkdf", "hmac", "home", + "ipnet", "ipnetwork", "itoa", "log", diff --git a/Cargo.toml b/Cargo.toml index f31d715b2..fe2669794 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ _unstable-all-types = [ "json", "time", "chrono", + "ipnet", "ipnetwork", "mac_address", "uuid", @@ -117,6 +118,7 @@ json = ["sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sq bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros?/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] bit-vec = ["sqlx-core/bit-vec", "sqlx-macros?/bit-vec", "sqlx-postgres?/bit-vec"] chrono = ["sqlx-core/chrono", "sqlx-macros?/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] +ipnet = ["sqlx-core/ipnet", "sqlx-macros?/ipnet", "sqlx-postgres?/ipnet"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-macros?/ipnetwork", "sqlx-postgres?/ipnetwork"] mac_address = ["sqlx-core/mac_address", "sqlx-macros?/mac_address", "sqlx-postgres?/mac_address"] rust_decimal = ["sqlx-core/rust_decimal", "sqlx-macros?/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] @@ -144,6 +146,7 @@ sqlx = { version = "=0.8.3", path = ".", default-features = 
false } bigdecimal = "0.4.0" bit-vec = "0.6.3" chrono = { version = "0.4.34", default-features = false, features = ["std", "clock"] } +ipnet = "2.3.0" ipnetwork = "0.20.0" mac_address = "1.1.5" rust_decimal = { version = "1.26.1", default-features = false, features = ["std"] } diff --git a/README.md b/README.md index c3b501ca4..cc0ecf2e6 100644 --- a/README.md +++ b/README.md @@ -220,6 +220,8 @@ be removed in the future. - `rust_decimal`: Add support for `NUMERIC` using the `rust_decimal` crate. +- `ipnet`: Add support for `INET` and `CIDR` (in postgres) using the `ipnet` crate. + - `ipnetwork`: Add support for `INET` and `CIDR` (in postgres) using the `ipnetwork` crate. - `json`: Add support for `JSON` and `JSONB` (in postgres) using the `serde_json` crate. diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index dcd808302..f6017a9fe 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -48,6 +48,7 @@ bit-vec = { workspace = true, optional = true } bigdecimal = { workspace = true, optional = true } rust_decimal = { workspace = true, optional = true } time = { workspace = true, optional = true } +ipnet = { workspace = true, optional = true } ipnetwork = { workspace = true, optional = true } mac_address = { workspace = true, optional = true } uuid = { workspace = true, optional = true } diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 909dd4927..b00427daa 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -67,6 +67,13 @@ pub use bigdecimal::BigDecimal; #[doc(no_inline)] pub use rust_decimal::Decimal; +#[cfg(feature = "ipnet")] +#[cfg_attr(docsrs, doc(cfg(feature = "ipnet")))] +pub mod ipnet { + #[doc(no_inline)] + pub use ipnet::{IpNet, Ipv4Net, Ipv6Net}; +} + #[cfg(feature = "ipnetwork")] #[cfg_attr(docsrs, doc(cfg(feature = "ipnetwork")))] pub mod ipnetwork { diff --git a/sqlx-macros-core/Cargo.toml b/sqlx-macros-core/Cargo.toml index 46786b7d8..85efa8091 100644 --- a/sqlx-macros-core/Cargo.toml 
+++ b/sqlx-macros-core/Cargo.toml @@ -38,6 +38,7 @@ json = ["sqlx-core/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlit bigdecimal = ["sqlx-core/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] bit-vec = ["sqlx-core/bit-vec", "sqlx-postgres?/bit-vec"] chrono = ["sqlx-core/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] +ipnet = ["sqlx-core/ipnet", "sqlx-postgres?/ipnet"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-postgres?/ipnetwork"] mac_address = ["sqlx-core/mac_address", "sqlx-postgres?/mac_address"] rust_decimal = ["sqlx-core/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 5617d3f25..b513c3e80 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -37,6 +37,7 @@ sqlite-unbundled = ["sqlx-macros-core/sqlite-unbundled"] bigdecimal = ["sqlx-macros-core/bigdecimal"] bit-vec = ["sqlx-macros-core/bit-vec"] chrono = ["sqlx-macros-core/chrono"] +ipnet = ["sqlx-macros-core/ipnet"] ipnetwork = ["sqlx-macros-core/ipnetwork"] mac_address = ["sqlx-macros-core/mac_address"] rust_decimal = ["sqlx-macros-core/rust_decimal"] diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 174a73b3f..818aadbab 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -19,6 +19,7 @@ offline = ["sqlx-core/offline"] bigdecimal = ["dep:bigdecimal", "dep:num-bigint", "sqlx-core/bigdecimal"] bit-vec = ["dep:bit-vec", "sqlx-core/bit-vec"] chrono = ["dep:chrono", "sqlx-core/chrono"] +ipnet = ["dep:ipnet", "sqlx-core/ipnet"] ipnetwork = ["dep:ipnetwork", "sqlx-core/ipnetwork"] mac_address = ["dep:mac_address", "sqlx-core/mac_address"] rust_decimal = ["dep:rust_decimal", "rust_decimal/maths", "sqlx-core/rust_decimal"] @@ -43,6 +44,7 @@ sha2 = { version = "0.10.0", default-features = false } bigdecimal = { workspace = true, optional = true } bit-vec = { workspace = true, optional = true } chrono = { 
workspace = true, optional = true } +ipnet = { workspace = true, optional = true } ipnetwork = { workspace = true, optional = true } mac_address = { workspace = true, optional = true } rust_decimal = { workspace = true, optional = true } diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index a28531c9b..672d9f73e 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -88,6 +88,9 @@ impl_type_checking!( #[cfg(feature = "ipnetwork")] sqlx::types::ipnetwork::IpNetwork, + #[cfg(feature = "ipnet")] + sqlx::types::ipnet::IpNet, + #[cfg(feature = "mac_address")] sqlx::types::mac_address::MacAddress, @@ -149,6 +152,9 @@ impl_type_checking!( #[cfg(feature = "ipnetwork")] Vec | &[sqlx::types::ipnetwork::IpNetwork], + #[cfg(feature = "ipnet")] + Vec | &[sqlx::types::ipnet::IpNet], + #[cfg(feature = "mac_address")] Vec | &[sqlx::types::mac_address::MacAddress], diff --git a/sqlx-postgres/src/types/ipnet/ipaddr.rs b/sqlx-postgres/src/types/ipnet/ipaddr.rs new file mode 100644 index 000000000..b157eff3c --- /dev/null +++ b/sqlx-postgres/src/types/ipnet/ipaddr.rs @@ -0,0 +1,62 @@ +use std::net::IpAddr; + +use ipnet::IpNet; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres}; + +impl Type for IpAddr +where + IpNet: Type, +{ + fn type_info() -> PgTypeInfo { + IpNet::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + IpNet::compatible(ty) + } +} + +impl PgHasArrayType for IpAddr { + fn array_type_info() -> PgTypeInfo { + ::array_type_info() + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + ::array_compatible(ty) + } +} + +impl<'db> Encode<'db, Postgres> for IpAddr +where + IpNet: Encode<'db, Postgres>, +{ + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + IpNet::from(*self).encode_by_ref(buf) + } + + fn size_hint(&self) -> 
usize {
+        IpNet::from(*self).size_hint()
+    }
+}
+
+impl<'db> Decode<'db, Postgres> for IpAddr
+where
+    IpNet: Decode<'db, Postgres>,
+{
+    fn decode(value: PgValueRef<'db>) -> Result<Self, BoxDynError> {
+        let ipnet = IpNet::decode(value)?;
+
+        if matches!(ipnet, IpNet::V4(net) if net.prefix_len() != 32)
+            || matches!(ipnet, IpNet::V6(net) if net.prefix_len() != 128)
+        {
+            Err("lossy decode from inet/cidr")?
+        }
+
+        Ok(ipnet.addr())
+    }
+}
diff --git a/sqlx-postgres/src/types/ipnet/ipnet.rs b/sqlx-postgres/src/types/ipnet/ipnet.rs
new file mode 100644
index 000000000..1f986174b
--- /dev/null
+++ b/sqlx-postgres/src/types/ipnet/ipnet.rs
@@ -0,0 +1,130 @@
+use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
+
+#[cfg(feature = "ipnet")]
+use ipnet::{IpNet, Ipv4Net, Ipv6Net};
+
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::error::BoxDynError;
+use crate::types::Type;
+use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
+
+// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/include/utils/inet.h#L39
+
+// Technically this is a magic number here but it doesn't make sense to drag in the whole of `libc`
+// just for one constant.
+const PGSQL_AF_INET: u8 = 2; // AF_INET
+const PGSQL_AF_INET6: u8 = PGSQL_AF_INET + 1;
+
+impl Type<Postgres> for IpNet {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::INET
+    }
+
+    fn compatible(ty: &PgTypeInfo) -> bool {
+        *ty == PgTypeInfo::CIDR || *ty == PgTypeInfo::INET
+    }
+}
+
+impl PgHasArrayType for IpNet {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::INET_ARRAY
+    }
+
+    fn array_compatible(ty: &PgTypeInfo) -> bool {
+        *ty == PgTypeInfo::CIDR_ARRAY || *ty == PgTypeInfo::INET_ARRAY
+    }
+}
+
+impl Encode<'_, Postgres> for IpNet {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
+        // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L293
+        // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L271
+
+        match self {
+            IpNet::V4(net) => {
+                buf.push(PGSQL_AF_INET); // ip_family
+                buf.push(net.prefix_len()); // ip_bits
+                buf.push(0); // is_cidr
+                buf.push(4); // nb (number of bytes)
+                buf.extend_from_slice(&net.addr().octets()) // address
+            }
+
+            IpNet::V6(net) => {
+                buf.push(PGSQL_AF_INET6); // ip_family
+                buf.push(net.prefix_len()); // ip_bits
+                buf.push(0); // is_cidr
+                buf.push(16); // nb (number of bytes)
+                buf.extend_from_slice(&net.addr().octets()); // address
+            }
+        }
+
+        Ok(IsNull::No)
+    }
+
+    fn size_hint(&self) -> usize {
+        match self {
+            IpNet::V4(_) => 8,
+            IpNet::V6(_) => 20,
+        }
+    }
+}
+
+impl Decode<'_, Postgres> for IpNet {
+    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+        let bytes = match value.format() {
+            PgValueFormat::Binary => value.as_bytes()?,
+            PgValueFormat::Text => {
+                let s = value.as_str()?;
+                if s.contains('/') {
+                    return Ok(s.parse()?);
+                }
+                // IpNet::from_str doesn't handle conversion from IpAddr to IpNet
+                let addr: IpAddr = s.parse()?;
+                return Ok(addr.into());
+            }
+        };
+
+        if bytes.len() >= 8 {
+            let family = bytes[0];
+            let prefix = bytes[1];
+            let _is_cidr =
bytes[2] != 0; + let len = bytes[3]; + + match family { + PGSQL_AF_INET => { + if bytes.len() == 8 && len == 4 { + let inet = Ipv4Net::new( + Ipv4Addr::new(bytes[4], bytes[5], bytes[6], bytes[7]), + prefix, + )?; + + return Ok(IpNet::V4(inet)); + } + } + + PGSQL_AF_INET6 => { + if bytes.len() == 20 && len == 16 { + let inet = Ipv6Net::new( + Ipv6Addr::from([ + bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9], + bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], + bytes[16], bytes[17], bytes[18], bytes[19], + ]), + prefix, + )?; + + return Ok(IpNet::V6(inet)); + } + } + + _ => { + return Err(format!("unknown ip family {family}").into()); + } + } + } + + Err("invalid data received when expecting an INET".into()) + } +} diff --git a/sqlx-postgres/src/types/ipnet/mod.rs b/sqlx-postgres/src/types/ipnet/mod.rs new file mode 100644 index 000000000..cd40cf30d --- /dev/null +++ b/sqlx-postgres/src/types/ipnet/mod.rs @@ -0,0 +1,7 @@ +// Prefer `ipnetwork` over `ipnet` because it was implemented first (want to avoid breaking change). +#[cfg(not(feature = "ipnetwork"))] +mod ipaddr; + +// Parent module is named after the `ipnet` crate, this is named after the `IpNet` type. 
+#[allow(clippy::module_inception)]
+mod ipnet;
diff --git a/sqlx-postgres/src/types/ipaddr.rs b/sqlx-postgres/src/types/ipnetwork/ipaddr.rs
similarity index 100%
rename from sqlx-postgres/src/types/ipaddr.rs
rename to sqlx-postgres/src/types/ipnetwork/ipaddr.rs
diff --git a/sqlx-postgres/src/types/ipnetwork.rs b/sqlx-postgres/src/types/ipnetwork/ipnetwork.rs
similarity index 100%
rename from sqlx-postgres/src/types/ipnetwork.rs
rename to sqlx-postgres/src/types/ipnetwork/ipnetwork.rs
diff --git a/sqlx-postgres/src/types/ipnetwork/mod.rs b/sqlx-postgres/src/types/ipnetwork/mod.rs
new file mode 100644
index 000000000..de40244c6
--- /dev/null
+++ b/sqlx-postgres/src/types/ipnetwork/mod.rs
@@ -0,0 +1,5 @@
+mod ipaddr;
+
+// Parent module is named after the `ipnetwork` crate, this is named after the `IpNetwork` type.
+#[allow(clippy::module_inception)]
+mod ipnetwork;
diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs
index c3493139c..0faefbb48 100644
--- a/sqlx-postgres/src/types/mod.rs
+++ b/sqlx-postgres/src/types/mod.rs
@@ -87,7 +87,7 @@
 //!
 //! ### [`ipnetwork`](https://crates.io/crates/ipnetwork)
 //!
-//! Requires the `ipnetwork` Cargo feature flag.
+//! Requires the `ipnetwork` Cargo feature flag (takes precedence over `ipnet` if both are used).
 //!
 //! | Rust type                             | Postgres type(s)                                     |
 //! |---------------------------------------|------------------------------------------------------|
@@ -100,6 +100,17 @@
 //!
 //! `IpNetwork` does not have this limitation.
 //!
+//! ### [`ipnet`](https://crates.io/crates/ipnet)
+//!
+//! Requires the `ipnet` Cargo feature flag.
+//!
+//! | Rust type                             | Postgres type(s)                                     |
+//! |---------------------------------------|------------------------------------------------------|
+//! | `ipnet::IpNet`                        | INET, CIDR                                           |
+//! | `std::net::IpAddr`                    | INET, CIDR                                           |
+//!
+//! The same `IpAddr` limitation for smaller network prefixes applies as with `ipnetwork`.
+//!
+//!
### [`mac_address`](https://crates.io/crates/mac_address) //! //! Requires the `mac_address` Cargo feature flag. @@ -248,11 +259,11 @@ mod time; #[cfg(feature = "uuid")] mod uuid; -#[cfg(feature = "ipnetwork")] -mod ipnetwork; +#[cfg(feature = "ipnet")] +mod ipnet; #[cfg(feature = "ipnetwork")] -mod ipaddr; +mod ipnetwork; #[cfg(feature = "mac_address")] mod mac_address; diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index da20467ea..d5d34bc1b 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -2,6 +2,7 @@ extern crate time_ as time; use std::net::SocketAddr; use std::ops::Bound; +use std::str::FromStr; use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange}; use sqlx::postgres::Postgres; @@ -9,7 +10,6 @@ use sqlx_test::{new, test_decode_type, test_prepared_type, test_type}; use sqlx_core::executor::Executor; use sqlx_core::types::Text; -use std::str::FromStr; test_type!(null>(Postgres, "NULL::int2" == None:: @@ -171,6 +171,38 @@ test_type!(uuid_vec>(Postgres, ] )); +#[cfg(feature = "ipnet")] +test_type!(ipnet(Postgres, + "'127.0.0.1'::inet" + == "127.0.0.1/32" + .parse::() + .unwrap(), + "'8.8.8.8/24'::inet" + == "8.8.8.8/24" + .parse::() + .unwrap(), + "'10.1.1/24'::inet" + == "10.1.1.0/24" + .parse::() + .unwrap(), + "'::ffff:1.2.3.0'::inet" + == "::ffff:1.2.3.0/128" + .parse::() + .unwrap(), + "'2001:4f8:3:ba::/64'::inet" + == "2001:4f8:3:ba::/64" + .parse::() + .unwrap(), + "'192.168'::cidr" + == "192.168.0.0/24" + .parse::() + .unwrap(), + "'::ffff:1.2.3.0/120'::cidr" + == "::ffff:1.2.3.0/120" + .parse::() + .unwrap(), +)); + #[cfg(feature = "ipnetwork")] test_type!(ipnetwork(Postgres, "'127.0.0.1'::inet" @@ -232,6 +264,15 @@ test_type!(bitvec( }, )); +#[cfg(feature = "ipnet")] +test_type!(ipnet_vec>(Postgres, + "'{127.0.0.1,8.8.8.8/24}'::inet[]" + == vec![ + "127.0.0.1/32".parse::().unwrap(), + "8.8.8.8/24".parse::().unwrap() + ] +)); + #[cfg(feature = "ipnetwork")] test_type!(ipnetwork_vec>(Postgres, 
"'{127.0.0.1,8.8.8.8/24}'::inet[]" diff --git a/tests/ui-tests.rs b/tests/ui-tests.rs index f74694b87..4a5ca240e 100644 --- a/tests/ui-tests.rs +++ b/tests/ui-tests.rs @@ -17,7 +17,7 @@ fn ui_tests() { t.compile_fail("tests/ui/postgres/gated/uuid.rs"); } - if cfg!(not(feature = "ipnetwork")) { + if cfg!(not(feature = "ipnet")) && cfg!(not(feature = "ipnetwork")) { t.compile_fail("tests/ui/postgres/gated/ipnetwork.rs"); } }