From 16e3f1025ad1e106d1acff05f591b8db62d688e2 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 4 Jul 2024 17:17:20 -0700 Subject: [PATCH] fix(postgres): add missing type resolution for arrays by name --- Cargo.lock | 28 ++-- Cargo.toml | 2 +- sqlx-core/Cargo.toml | 2 +- sqlx-core/src/ext/ustr.rs | 14 ++ sqlx-core/src/lib.rs | 7 +- sqlx-core/src/type_info.rs | 10 ++ sqlx-core/src/types/mod.rs | 4 +- sqlx-macros-core/src/derives/type.rs | 52 +++++-- sqlx-postgres/Cargo.toml | 3 + sqlx-postgres/src/arguments.rs | 30 +++- sqlx-postgres/src/connection/describe.rs | 45 +++++- sqlx-postgres/src/connection/establish.rs | 1 + sqlx-postgres/src/connection/executor.rs | 7 +- sqlx-postgres/src/connection/mod.rs | 1 + sqlx-postgres/src/type_info.rs | 182 +++++++++++++++++++--- sqlx-postgres/src/types/array.rs | 10 +- sqlx-sqlite/src/connection/explain.rs | 6 +- tests/migrate/macro.rs | 1 + tests/postgres/derives.rs | 12 +- 19 files changed, 333 insertions(+), 84 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b351e991..851cc162 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -35,7 +35,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", "once_cell", "version_check", "zerocopy", @@ -574,9 +573,9 @@ dependencies = [ [[package]] name = "borsh" -version = "1.3.1" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f58b559fd6448c6e2fd0adb5720cd98a2506594cafa4737ff98c396f3e82f667" +checksum = "a6362ed55def622cddc70a4746a68554d7b687713770de539e59a739b249f8ed" dependencies = [ "borsh-derive", "cfg_aliases", @@ -584,9 +583,9 @@ dependencies = [ [[package]] name = "borsh-derive" -version = "1.3.1" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7aadb5b6ccbd078890f6d7003694e33816e6b784358f18e15e7e6d9f065a57cd" +checksum = 
"c3ef8005764f53cd4dca619f5bf64cafd4664dada50ece25e4d81de54c80cc0b" dependencies = [ "once_cell", "proc-macro-crate", @@ -704,9 +703,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "cfg_aliases" -version = "0.1.1" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" @@ -1561,9 +1560,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.3" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash 0.8.11", "allocator-api2", @@ -1575,7 +1574,7 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "692eaaf7f7607518dd3cef090f1474b61edc5301d8012f09579920df68b725ee" dependencies = [ - "hashbrown 0.14.3", + "hashbrown 0.14.5", ] [[package]] @@ -1789,7 +1788,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" dependencies = [ "equivalent", - "hashbrown 0.14.3", + "hashbrown 0.14.5", ] [[package]] @@ -3227,7 +3226,6 @@ dependencies = [ name = "sqlx-core" version = "0.8.0-alpha.0" dependencies = [ - "ahash 0.8.11", "async-io 1.13.0", "async-std", "atoi", @@ -3248,6 +3246,7 @@ dependencies = [ "futures-intrusive", "futures-io", "futures-util", + "hashbrown 0.14.5", "hashlink", "hex", "indexmap 2.2.5", @@ -3524,6 +3523,7 @@ dependencies = [ "serde_json", "sha2", "smallvec", + "sqlx", "sqlx-core", "stringprep", "thiserror", @@ -3837,9 +3837,9 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.5" +version = "0.6.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" +checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" [[package]] name = "toml_edit" diff --git a/Cargo.toml b/Cargo.toml index e25ac18b..2a7c75d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -135,7 +135,7 @@ bit-vec = "0.6.3" chrono = { version = "0.4.22", default-features = false } ipnetwork = "0.20.0" mac_address = "1.1.5" -rust_decimal = "1.26.1" +rust_decimal = { version = "1.26.1", default-features = false, features = ["std"] } time = { version = "0.3.36", features = ["formatting", "parsing", "macros"] } uuid = "1.1.2" diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index d81414b5..2917932b 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -51,7 +51,6 @@ uuid = { workspace = true, optional = true } async-io = { version = "1.9.0", optional = true } paste = "1.0.6" -ahash = "0.8.7" atoi = "2.0" bytes = "1.1.0" @@ -88,6 +87,7 @@ bstr = { version = "1.0", default-features = false, features = ["std"], optional hashlink = "0.9.0" indexmap = "2.0" event-listener = "5.2.0" +hashbrown = "0.14.5" [dev-dependencies] sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] } diff --git a/sqlx-core/src/ext/ustr.rs b/sqlx-core/src/ext/ustr.rs index 0e60fdfc..95fa754f 100644 --- a/sqlx-core/src/ext/ustr.rs +++ b/sqlx-core/src/ext/ustr.rs @@ -17,6 +17,14 @@ impl UStr { pub fn new(s: &str) -> Self { UStr::Shared(Arc::from(s.to_owned())) } + + /// Apply [str::strip_prefix], without copying if possible. 
+ pub fn strip_prefix(this: &Self, prefix: &str) -> Option { + match this { + UStr::Static(s) => s.strip_prefix(prefix).map(Self::Static), + UStr::Shared(s) => s.strip_prefix(prefix).map(|s| Self::Shared(s.into())), + } + } } impl Deref for UStr { @@ -60,6 +68,12 @@ impl From<&'static str> for UStr { } } +impl<'a> From<&'a UStr> for UStr { + fn from(value: &'a UStr) -> Self { + value.clone() + } +} + impl From for UStr { #[inline] fn from(s: String) -> Self { diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index ef8b267c..cc0122c9 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -95,9 +95,8 @@ pub mod testing; pub use error::{Error, Result}; -/// sqlx uses ahash for increased performance, at the cost of reduced DoS resistance. -pub use ahash::AHashMap as HashMap; pub use either::Either; +pub use hashbrown::{hash_map, HashMap}; pub use indexmap::IndexMap; pub use percent_encoding; pub use smallvec::SmallVec; @@ -105,8 +104,6 @@ pub use url::{self, Url}; pub use bytes; -//type HashMap = std::collections::HashMap; - /// Helper module to get drivers compiling again that used to be in this crate, /// to avoid having to replace tons of `use crate::<...>` imports. /// @@ -119,6 +116,6 @@ pub mod driver_prelude { }; pub use crate::error::{Error, Result}; - pub use crate::HashMap; + pub use crate::{hash_map, HashMap}; pub use either::Either; } diff --git a/sqlx-core/src/type_info.rs b/sqlx-core/src/type_info.rs index 72b2a5a9..812bcf37 100644 --- a/sqlx-core/src/type_info.rs +++ b/sqlx-core/src/type_info.rs @@ -9,6 +9,16 @@ pub trait TypeInfo: Debug + Display + Clone + PartialEq + Send + Sync { /// should be a rough approximation of how they are written in SQL in the given database. fn name(&self) -> &str; + /// Return `true` if `self` and `other` represent mutually compatible types. + /// + /// Defaults to `self == other`. 
+ fn type_compatible(&self, other: &Self) -> bool + where + Self: Sized, + { + self == other + } + #[doc(hidden)] fn is_void(&self) -> bool { false diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index e83c27a1..25837b1e 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -210,8 +210,10 @@ pub trait Type { /// /// When binding arguments with `query!` or `query_as!`, this method is consulted to determine /// if the Rust type is acceptable. + /// + /// Defaults to checking [`TypeInfo::type_compatible()`]. fn compatible(ty: &DB::TypeInfo) -> bool { - *ty == Self::type_info() + Self::type_info().type_compatible(ty) } } diff --git a/sqlx-macros-core/src/derives/type.rs b/sqlx-macros-core/src/derives/type.rs index ef14918d..d035ec5a 100644 --- a/sqlx-macros-core/src/derives/type.rs +++ b/sqlx-macros-core/src/derives/type.rs @@ -14,28 +14,27 @@ use syn::{ pub fn expand_derive_type(input: &DeriveInput) -> syn::Result { let attrs = parse_container_attributes(&input.attrs)?; match &input.data { + // Newtype structs: + // struct Foo(i32); Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. - }) if unnamed.len() == 1 => { - expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) + }) => { + if unnamed.len() == 1 { + expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) + } else { + Err(syn::Error::new_spanned( + input, + "structs with zero or more than one unnamed field are not supported", + )) + } } - Data::Enum(DataEnum { variants, .. }) => match attrs.repr { - Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), - None => expand_derive_has_sql_type_strong_enum(input, variants), - }, + // Record types + // struct Foo { foo: i32, bar: String } Data::Struct(DataStruct { fields: Fields::Named(FieldsNamed { named, .. }), .. 
}) => expand_derive_has_sql_type_struct(input, named), - Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), - Data::Struct(DataStruct { - fields: Fields::Unnamed(..), - .. - }) => Err(syn::Error::new_spanned( - input, - "structs with zero or more than one unnamed field are not supported", - )), Data::Struct(DataStruct { fields: Fields::Unit, .. @@ -43,6 +42,14 @@ pub fn expand_derive_type(input: &DeriveInput) -> syn::Result { input, "unit structs are not supported", )), + + Data::Enum(DataEnum { variants, .. }) => match attrs.repr { + // Enums that encode to/from integers (weak enums) + Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), + // Enums that decode to/from strings (strong enums) + None => expand_derive_has_sql_type_strong_enum(input, variants), + }, + Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), } } @@ -148,9 +155,10 @@ fn expand_derive_has_sql_type_weak_enum( if cfg!(feature = "postgres") && !attrs.no_pg_array { ts.extend(quote!( + #[automatically_derived] impl ::sqlx::postgres::PgHasArrayType for #ident { fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { - <#ident as ::sqlx::postgres::PgHasArrayType>::array_type_info() + <#repr as ::sqlx::postgres::PgHasArrayType>::array_type_info() } } )); @@ -197,9 +205,10 @@ fn expand_derive_has_sql_type_strong_enum( if !attributes.no_pg_array { tts.extend(quote!( + #[automatically_derived] impl ::sqlx::postgres::PgHasArrayType for #ident { fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { - <#ident as ::sqlx::postgres::PgHasArrayType>::array_type_info() + ::sqlx::postgres::PgTypeInfo::array_of(#ty_name) } } )); @@ -244,6 +253,17 @@ fn expand_derive_has_sql_type_struct( } } )); + + if !attributes.no_pg_array { + tts.extend(quote!( + #[automatically_derived] + impl ::sqlx::postgres::PgHasArrayType for #ident { + fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { + ::sqlx::postgres::PgTypeInfo::array_of(#ty_name) + } + } 
+ )); + } } Ok(tts) diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 5f6b8392..1ed9b14f 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -71,5 +71,8 @@ workspace = true # We use JSON in the driver implementation itself so there's no reason not to enable it here. features = ["json"] +[dev-dependencies] +sqlx.workspace = true + [target.'cfg(target_os = "windows")'.dependencies] etcetera = "0.8.0" diff --git a/sqlx-postgres/src/arguments.rs b/sqlx-postgres/src/arguments.rs index 9975b2fd..7911a066 100644 --- a/sqlx-postgres/src/arguments.rs +++ b/sqlx-postgres/src/arguments.rs @@ -1,5 +1,6 @@ use std::fmt::{self, Write}; use std::ops::{Deref, DerefMut}; +use std::sync::Arc; use crate::encode::{Encode, IsNull}; use crate::error::Error; @@ -7,6 +8,7 @@ use crate::ext::ustr::UStr; use crate::types::Type; use crate::{PgConnection, PgTypeInfo, Postgres}; +use crate::type_info::PgArrayOf; pub(crate) use sqlx_core::arguments::Arguments; use sqlx_core::error::BoxDynError; @@ -41,7 +43,12 @@ pub struct PgArgumentBuffer { // This is done for Records and Arrays as the OID is needed well before we are in an async // function and can just ask postgres. 
// - type_holes: Vec<(usize, UStr)>, // Vec<{ offset, type_name }> + type_holes: Vec<(usize, HoleKind)>, // Vec<{ offset, hole_kind }> +} + +enum HoleKind { + Type { name: UStr }, + Array(Arc), } struct Patch { @@ -106,8 +113,11 @@ impl PgArguments { (patch.callback)(buf, ty); } - for (offset, name) in type_holes { - let oid = conn.fetch_type_id_by_name(name).await?; + for (offset, kind) in type_holes { + let oid = match kind { + HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?, + HoleKind::Array(array) => conn.fetch_array_type_id(array).await?, + }; buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes()); } @@ -186,7 +196,19 @@ impl PgArgumentBuffer { let offset = self.len(); self.extend_from_slice(&0_u32.to_be_bytes()); - self.type_holes.push((offset, type_name.clone())); + self.type_holes.push(( + offset, + HoleKind::Type { + name: type_name.clone(), + }, + )); + } + + pub(crate) fn patch_array_type(&mut self, array: Arc) { + let offset = self.len(); + + self.extend_from_slice(&0_u32.to_be_bytes()); + self.type_holes.push((offset, HoleKind::Array(array))); } fn snapshot(&self) -> PgArgumentBufferSnapshot { diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 37952b8a..82bba18f 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -4,7 +4,7 @@ use crate::message::{ParameterDescription, RowDescription}; use crate::query_as::query_as; use crate::query_scalar::{query_scalar, query_scalar_with}; use crate::statement::PgStatementMetadata; -use crate::type_info::{PgCustomType, PgType, PgTypeKind}; +use crate::type_info::{PgArrayOf, PgCustomType, PgType, PgTypeKind}; use crate::types::Json; use crate::types::Oid; use crate::HashMap; @@ -355,6 +355,19 @@ WHERE rngtypid = $1 }) } + pub(crate) async fn resolve_type_id(&mut self, ty: &PgType) -> Result { + if let Some(oid) = ty.try_oid() { + return Ok(oid); + } + + match ty { +
PgType::DeclareWithName(name) => self.fetch_type_id_by_name(name).await, + PgType::DeclareArrayOf(array) => self.fetch_array_type_id(array).await, + // `.try_oid()` should return `Some()` or it should be covered here + _ => unreachable!("(bug) OID should be resolvable for type {ty:?}"), + } + } + pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result { if let Some(oid) = self.cache_type_oid.get(name) { return Ok(*oid); @@ -366,13 +379,41 @@ WHERE rngtypid = $1 .fetch_optional(&mut *self) .await? .ok_or_else(|| Error::TypeNotFound { - type_name: String::from(name), + type_name: name.into(), })?; self.cache_type_oid.insert(name.to_string().into(), oid); Ok(oid) } + pub(crate) async fn fetch_array_type_id(&mut self, array: &PgArrayOf) -> Result { + if let Some(oid) = self + .cache_type_oid + .get(&array.elem_name) + .and_then(|elem_oid| self.cache_elem_type_to_array.get(elem_oid)) + { + return Ok(*oid); + } + + // language=SQL + let (elem_oid, array_oid): (Oid, Oid) = + query_as("SELECT oid, typarray FROM pg_catalog.pg_type WHERE oid = $1::regtype::oid") + .bind(&*array.elem_name) + .fetch_optional(&mut *self) + .await? 
+ .ok_or_else(|| Error::TypeNotFound { + type_name: array.name.to_string(), + })?; + + // Avoids copying `elem_name` until necessary + self.cache_type_oid + .entry_ref(&array.elem_name) + .insert(elem_oid); + self.cache_elem_type_to_array.insert(elem_oid, array_oid); + + Ok(array_oid) + } + pub(crate) async fn get_nullable_for_columns( &mut self, stmt_id: Oid, diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 9f5008f9..83b9843a 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -146,6 +146,7 @@ impl PgConnection { cache_statement: StatementCache::new(options.statement_cache_capacity), cache_type_oid: HashMap::new(), cache_type_info: HashMap::new(), + cache_elem_type_to_array: HashMap::new(), log_settings: options.log_settings.clone(), }) } diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index 019c5b3e..bb73db1e 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -7,7 +7,6 @@ use crate::message::{ RowDescription, }; use crate::statement::PgStatementMetadata; -use crate::type_info::PgType; use crate::types::Oid; use crate::{ statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo, @@ -36,11 +35,7 @@ async fn prepare( let mut param_types = Vec::with_capacity(parameters.len()); for ty in parameters { - param_types.push(if let PgType::DeclareWithName(name) = &ty.0 { - conn.fetch_type_id_by_name(name).await? 
- } else { - ty.0.oid() - }); + param_types.push(conn.resolve_type_id(&ty.0).await?); } // flush and wait until we are re-ready diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 6259033b..1c7a4682 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -55,6 +55,7 @@ pub struct PgConnection { // cache user-defined types by id <-> info cache_type_info: HashMap, cache_type_oid: HashMap, + cache_elem_type_to_array: HashMap, // number of ReadyForQuery messages that we are currently expecting pub(crate) pending_ready_for_query_count: usize, diff --git a/sqlx-postgres/src/type_info.rs b/sqlx-postgres/src/type_info.rs index b01a1bfa..f50ea7fb 100644 --- a/sqlx-postgres/src/type_info.rs +++ b/sqlx-postgres/src/type_info.rs @@ -11,6 +11,34 @@ use crate::types::Oid; pub(crate) use sqlx_core::type_info::TypeInfo; /// Type information for a PostgreSQL type. +/// +/// ### Note: Implementation of `==` ([`PartialEq::eq()`]) +/// Because `==` on [`TypeInfo`]s has been used throughout the SQLx API as a synonym for type compatibility, +/// e.g. in the default impl of [`Type::compatible()`][sqlx_core::types::Type::compatible], +/// some concessions have been made in the implementation. 
+/// +/// When comparing two `PgTypeInfo`s using the `==` operator ([`PartialEq::eq()`]), +/// if one was constructed with [`Self::with_oid()`] and the other with [`Self::with_name()`] or +/// [`Self::array_of()`], `==` will return `true`: +/// +/// ``` +/// # use sqlx::postgres::{types::Oid, PgTypeInfo}; +/// // Potentially surprising result, this assert will pass: +/// assert_eq!(PgTypeInfo::with_oid(Oid(1)), PgTypeInfo::with_name("definitely_not_real")); +/// ``` +/// +/// Since it is not possible in this case to prove the types are _not_ compatible (because +/// both `PgTypeInfo`s need to be resolved by an active connection to know for sure) +/// and type compatibility is mainly done as a sanity check anyway, +/// it was deemed acceptable to fudge equality in this very specific case. +/// +/// This also applies when querying with the text protocol (not using prepared statements, +/// e.g. [`sqlx::raw_sql()`][sqlx_core::raw_sql::raw_sql]), as the connection will be unable +/// to look up the type info like it normally does when preparing a statement: it won't know +/// what the OIDs of the output columns will be until it's in the middle of reading the result, +/// and by that time it's too late. +/// +/// To compare types for exact equality, use [`Self::type_eq()`] instead. #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct PgTypeInfo(pub(crate) PgType); @@ -132,6 +160,8 @@ pub enum PgType { // NOTE: Do we want to bring back type declaration by ID? 
It's notoriously fragile but // someone may have a user for it DeclareWithOid(Oid), + + DeclareArrayOf(Arc), } #[derive(Debug, Clone)] @@ -155,6 +185,13 @@ pub enum PgTypeKind { Range(PgTypeInfo), } +#[derive(Debug)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct PgArrayOf { + pub(crate) elem_name: UStr, + pub(crate) name: Box, +} + impl PgTypeInfo { /// Returns the corresponding `PgTypeInfo` if the OID is a built-in type and recognized by SQLx. pub(crate) fn try_from_oid(oid: Oid) -> Option { @@ -233,18 +270,79 @@ impl PgTypeInfo { /// /// The OID for the type will be fetched from Postgres on use of /// a value of this type. The fetched OID will be cached per-connection. + /// + /// ### Note: Type Names Prefixed with `_` + /// In `pg_catalog.pg_type`, Postgres prefixes a type name with `_` to denote an array of that + /// type, e.g. `int4[]` actually exists in `pg_type` as `_int4`. + /// + /// Previously, it was necessary in manual [`PgHasArrayType`][crate::PgHasArrayType] impls + /// to return [`PgTypeInfo::with_name()`] with the type name prefixed with `_` to denote + /// an array type, but this would not work with schema-qualified names. + /// + /// As of 0.8, [`PgTypeInfo::array_of()`] is used to declare an array type, + /// and the Postgres driver is now able to properly resolve arrays of custom types, + /// even in other schemas, which was not previously supported. + /// + /// It is highly recommended to migrate existing usages to [`PgTypeInfo::array_of()`] where + /// applicable. + /// + /// However, to maintain compatibility, the driver now infers any type name prefixed with `_` + /// to be an array of that type. This may introduce some breakages for types which use + /// a `_` prefix but which are not arrays. 
+ /// + /// As a workaround, type names with `_` as a prefix but which are not arrays should be wrapped + /// in quotes, e.g.: + /// ``` + /// use sqlx::postgres::PgTypeInfo; + /// use sqlx::Type; + /// + /// /// `CREATE TYPE "_foo" AS ENUM ('Bar', 'Baz');` + /// #[derive(sqlx::Type)] + /// // Will prevent SQLx from inferring `_foo` as an array type. + /// #[sqlx(type_name = r#""_foo""#)] + /// enum Foo { + /// Bar, + /// Baz + /// } + /// + /// assert_eq!(Foo::type_info().name(), r#""_foo""#); + /// ``` pub const fn with_name(name: &'static str) -> Self { Self(PgType::DeclareWithName(UStr::Static(name))) } + /// Create a `PgTypeInfo` of an array from the name of its element type. + /// + /// The array type OID will be fetched from Postgres on use of a value of this type. + /// The fetched OID will be cached per-connection. + pub fn array_of(elem_name: &'static str) -> Self { + // to satisfy `name()` and `display_name()`, we need to construct strings to return + Self(PgType::DeclareArrayOf(Arc::new(PgArrayOf { + elem_name: elem_name.into(), + name: format!("{elem_name}[]").into(), + }))) + } + /// Create a `PgTypeInfo` from an OID. /// /// Note that the OID for a type is very dependent on the environment. If you only ever use /// one database or if this is an unhandled built-in type, you should be fine. Otherwise, - /// you will be better served using [`with_name`](Self::with_name). + /// you will be better served using [`Self::with_name()`]. + /// + /// ### Note: Interaction with `==` + /// This constructor may give surprising results with `==`. + /// + /// See [the type-level docs][Self] for details. pub const fn with_oid(oid: Oid) -> Self { Self(PgType::DeclareWithOid(oid)) } + + /// Returns `true` if `self` can be compared exactly to `other`. 
+ /// + /// Unlike `==`, this will return false if one type was declared by OID ([`Self::with_oid()`]) and the other by name ([`Self::with_name()`] or [`Self::array_of()`]), since the two cannot be compared without being resolved by a connection. + pub fn type_eq(&self, other: &Self) -> bool { + self.eq_impl(other, false) + } } // DEVELOPER PRO TIP: find builtin type OIDs easily by grepping this file @@ -464,6 +562,9 @@ impl PgType { PgType::DeclareWithName(_) => { return None; } + PgType::DeclareArrayOf(_) => { + return None; + } }) } @@ -564,6 +665,7 @@ impl PgType { PgType::Custom(ty) => &ty.name, PgType::DeclareWithOid(_) => "?", PgType::DeclareWithName(name) => name, + PgType::DeclareArrayOf(array) => &array.name, } } @@ -664,6 +766,7 @@ impl PgType { PgType::Custom(ty) => &ty.name, PgType::DeclareWithOid(_) => "?", PgType::DeclareWithName(name) => name, + PgType::DeclareArrayOf(array) => &array.name, } } @@ -771,13 +874,16 @@ impl PgType { PgType::DeclareWithName(name) => { unreachable!("(bug) use of unresolved type declaration [name={name}]"); } + PgType::DeclareArrayOf(array) => { + unreachable!( + "(bug) use of unresolved type declaration [array of={}]", + array.elem_name + ); + } } } /// If `self` is an array type, return the type info for its element. - /// - /// This method should only be called on resolved types: calling it on - /// a type that is merely declared (DeclareWithOid/Name) is a bug. pub(crate) fn try_array_element(&self) -> Option> { // We explicitly match on all the `None` cases to ensure an exhaustive match.
match self { @@ -885,14 +991,50 @@ impl PgType { PgTypeKind::Enum(_) => None, PgTypeKind::Range(_) => None, }, - PgType::DeclareWithOid(oid) => { - unreachable!("(bug) use of unresolved type declaration [oid={}]", oid.0); - } + PgType::DeclareWithOid(_) => None, PgType::DeclareWithName(name) => { - unreachable!("(bug) use of unresolved type declaration [name={name}]"); + // LEGACY: infer the array element name from a `_` prefix + UStr::strip_prefix(name, "_") + .map(|elem| Cow::Owned(PgTypeInfo(PgType::DeclareWithName(elem)))) } + PgType::DeclareArrayOf(array) => Some(Cow::Owned(PgTypeInfo(PgType::DeclareWithName( + array.elem_name.clone(), + )))), } } + + /// Returns `true` if this type cannot be matched by name. + fn is_declare_with_oid(&self) -> bool { + matches!(self, Self::DeclareWithOid(_)) + } + + /// Compare two `PgType`s, first by OID, then by array element, then by name. + /// + /// If `soft_eq` is true and `self` or `other` is `DeclareWithOid` but not both, return `true` + /// before checking names. + fn eq_impl(&self, other: &Self, soft_eq: bool) -> bool { + if let (Some(a), Some(b)) = (self.try_oid(), other.try_oid()) { + // If there are OIDs available, use OIDs to perform a direct match + return a == b; + } + + if soft_eq && (self.is_declare_with_oid() || other.is_declare_with_oid()) { + // If we get to this point, one instance is `DeclareWithOid()` and the other is + // `DeclareArrayOf()` or `DeclareWithName()`, which means we can't compare the two. + // + // Since this is only likely to occur when using the text protocol where we can't + // resolve type names before executing a query, we can just opt out of typechecking. 
+ return true; + } + + if let (Some(elem_a), Some(elem_b)) = (self.try_array_element(), other.try_array_element()) + { + return elem_a == elem_b; + } + + // Otherwise, perform a match on the name + name_eq(self.name(), other.name()) + } } impl TypeInfo for PgTypeInfo { @@ -907,6 +1049,13 @@ impl TypeInfo for PgTypeInfo { fn is_void(&self) -> bool { matches!(self.0, PgType::Void) } + + fn type_compatible(&self, other: &Self) -> bool + where + Self: Sized, + { + self == other + } } impl PartialEq for PgCustomType { @@ -1140,22 +1289,7 @@ impl Display for PgTypeInfo { impl PartialEq for PgType { fn eq(&self, other: &PgType) -> bool { - if let (Some(a), Some(b)) = (self.try_oid(), other.try_oid()) { - // If there are OIDs available, use OIDs to perform a direct match - a == b - } else if matches!( - (self, other), - (PgType::DeclareWithName(_), PgType::DeclareWithOid(_)) - | (PgType::DeclareWithOid(_), PgType::DeclareWithName(_)) - ) { - // One is a declare-with-name and the other is a declare-with-id - // This only occurs in the TEXT protocol with custom types - // Just opt-out of type checking here - true - } else { - // Otherwise, perform a match on the name - name_eq(self.name(), other.name()) - } + self.eq_impl(other, true) } } diff --git a/sqlx-postgres/src/types/array.rs b/sqlx-postgres/src/types/array.rs index 4b24a86a..d6242227 100644 --- a/sqlx-postgres/src/types/array.rs +++ b/sqlx-postgres/src/types/array.rs @@ -156,11 +156,10 @@ where T: Encode<'q, Postgres> + Type, { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - let type_info = if self.is_empty() { - T::type_info() - } else { - self[0].produces().unwrap_or_else(T::type_info) - }; + let type_info = self + .first() + .and_then(Encode::produces) + .unwrap_or_else(T::type_info); buf.extend(&1_i32.to_be_bytes()); // number of dimensions buf.extend(&0_i32.to_be_bytes()); // flags @@ -168,6 +167,7 @@ where // element type match type_info.0 { PgType::DeclareWithName(name) => 
buf.patch_type_by_name(&name), + PgType::DeclareArrayOf(array) => buf.patch_array_type(array), ty => { buf.extend(&ty.oid().0.to_be_bytes()); diff --git a/sqlx-sqlite/src/connection/explain.rs b/sqlx-sqlite/src/connection/explain.rs index 6a95b312..a18cd58a 100644 --- a/sqlx-sqlite/src/connection/explain.rs +++ b/sqlx-sqlite/src/connection/explain.rs @@ -5,7 +5,7 @@ use crate::from_row::FromRow; use crate::logger::{BranchParent, BranchResult, DebugDiff}; use crate::type_info::DataType; use crate::SqliteTypeInfo; -use sqlx_core::HashMap; +use sqlx_core::{hash_map, HashMap}; use std::fmt::Debug; use std::str::from_utf8; @@ -535,13 +535,13 @@ impl BranchList { ) { logger.add_branch(&state, &state.branch_parent.unwrap()); match self.visited_branch_state.entry(state.mem) { - std::collections::hash_map::Entry::Vacant(entry) => { + hash_map::Entry::Vacant(entry) => { //this state is not identical to another state, so it will need to be processed state.mem = entry.key().clone(); //replace state.mem since .entry() moved it entry.insert(state.get_reference()); self.states.push(state); } - std::collections::hash_map::Entry::Occupied(entry) => { + hash_map::Entry::Occupied(entry) => { //already saw a state identical to this one, so no point in processing it state.mem = entry.key().clone(); //replace state.mem since .entry() moved it logger.add_result(state, BranchResult::Dedup(*entry.get())); diff --git a/tests/migrate/macro.rs b/tests/migrate/macro.rs index 03cb578f..7bc4e913 100644 --- a/tests/migrate/macro.rs +++ b/tests/migrate/macro.rs @@ -1,3 +1,4 @@ +#![cfg(unix)] use sqlx::migrate::Migrator; use std::path::Path; diff --git a/tests/postgres/derives.rs b/tests/postgres/derives.rs index 086fcc6d..b4c29b48 100644 --- a/tests/postgres/derives.rs +++ b/tests/postgres/derives.rs @@ -155,6 +155,9 @@ test_type!(weak_enum(Postgres, "0::int4" == Weak::One, "2::int4" == Weak::Two, "4::int4" == Weak::Three, +)); + +test_type!(weak_enum_array>(Postgres, "'{0, 2, 4}'::int4[]" == 
vec![Weak::One, Weak::Two, Weak::Three], )); @@ -162,7 +165,10 @@ test_type!(strong_enum(Postgres, "'one'::text" == Strong::One, "'two'::text" == Strong::Two, "'four'::text" == Strong::Three, - "'{'one', 'two', 'four'}'::text[]" == vec![Strong::One, Strong::Two, Strong::Three], +)); + +test_type!(strong_enum_array>(Postgres, + "ARRAY['one', 'two', 'four']" == vec![Strong::One, Strong::Two, Strong::Three], )); test_type!(floatrange(Postgres, @@ -753,11 +759,13 @@ async fn test_enum_with_schema() -> anyhow::Result<()> { assert_eq!(foo, Foo::Baz); - let foos: Vec = sqlx::query_scalar!("SELECT ARRAY($1::foo.\"Foo\", $2::foo.\"Foo\")") + let foos: Vec = sqlx::query_scalar("SELECT ARRAY[$1::foo.\"Foo\", $2::foo.\"Foo\"]") .bind(Foo::Bar) .bind(Foo::Baz) .fetch_one(&mut conn) .await?; assert_eq!(foos, [Foo::Bar, Foo::Baz]); + + Ok(()) }