From 93b90be9f7dd316889aa01f3e6066062a133180c Mon Sep 17 00:00:00 2001 From: Charles Samborski Date: Tue, 16 Mar 2021 18:39:52 +0100 Subject: [PATCH] fix(postgres): Add support for domain types description Fix commit updates the `postgres::connection::describe` module to add full support for domain types. Domain types were previously confused with their category which caused invalid oid resolution. Fixes launchbadge/sqlx#110 --- sqlx-core/src/postgres/connection/describe.rs | 141 ++++++++++++++++-- tests/postgres/postgres.rs | 112 ++++++++++++++ tests/postgres/setup.sql | 8 + 3 files changed, 245 insertions(+), 16 deletions(-) diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index f9e2ebf3..097058ce 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -9,9 +9,86 @@ use crate::query_scalar::{query_scalar, query_scalar_with}; use crate::types::Json; use crate::HashMap; use futures_core::future::BoxFuture; +use std::convert::TryFrom; use std::fmt::Write; use std::sync::Arc; +/// Describes the type of the `pg_type.typtype` column +/// +/// See +enum TypType { + Base, + Composite, + Domain, + Enum, + Pseudo, + Range, +} + +impl TryFrom for TypType { + type Error = (); + + fn try_from(t: u8) -> Result { + let t = match t { + b'b' => Self::Base, + b'c' => Self::Composite, + b'd' => Self::Domain, + b'e' => Self::Enum, + b'p' => Self::Pseudo, + b'r' => Self::Range, + _ => return Err(()), + }; + Ok(t) + } +} + +/// Describes the type of the `pg_type.typcategory` column +/// +/// See +enum TypCategory { + Array, + Boolean, + Composite, + DateTime, + Enum, + Geometric, + Network, + Numeric, + Pseudo, + Range, + String, + Timespan, + User, + BitString, + Unknown, +} + +impl TryFrom for TypCategory { + type Error = (); + + fn try_from(c: u8) -> Result { + let c = match c { + b'A' => Self::Array, + b'B' => Self::Boolean, + b'C' => Self::Composite, + b'D' => Self::DateTime, + b'E' => Self::Enum, + b'G' => Self::Geometric, + b'I' => Self::Network, + b'N' => Self::Numeric, + b'P' => Self::Pseudo, + b'R' => Self::Range, + b'S' => Self::String, + b'T' => Self::Timespan, + b'U' => Self::User, + b'V' => Self::BitString, + b'X' => Self::Unknown, + _ => return Err(()), + }; + Ok(c) + } +} + impl PgConnection { pub(super) async fn handle_row_description( &mut self, @@ -106,31 +183,46 @@ impl PgConnection { fn fetch_type_by_oid(&mut self, oid: u32) -> BoxFuture<'_, Result> { Box::pin(async move { - let (name, category, relation_id, element): (String, i8, u32, u32) = query_as( - "SELECT typname, typcategory, typrelid, typelem FROM pg_catalog.pg_type WHERE oid = $1", + let (name, typ_type, category, relation_id, element, base_type): (String, i8, i8, u32, u32, u32) = query_as( + "SELECT typname, typtype, typcategory, typrelid, typelem, typbasetype FROM pg_catalog.pg_type WHERE oid = $1", ) .bind(oid) .fetch_one(&mut *self) .await?; - match category as u8 { - b'A' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?), - name: name.into(), - oid, - })))), + let typ_type = TypType::try_from(typ_type as u8); + let category = TypCategory::try_from(category as u8); - b'P' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - kind: PgTypeKind::Pseudo, - name: name.into(), - oid, - })))), + match (typ_type, category) { + (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await, - b'R' => self.fetch_range_by_oid(oid, name).await, + (Ok(TypType::Base), Ok(TypCategory::Array)) => { + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?), + name: name.into(), + oid, + })))) + } - b'E' => self.fetch_enum_by_oid(oid, name).await, + (Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => { + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + kind: PgTypeKind::Pseudo, + name: name.into(), + oid, + })))) + } - b'C' => self.fetch_composite_by_oid(oid, relation_id, name).await, + (Ok(TypType::Range), Ok(TypCategory::Range)) => { + self.fetch_range_by_oid(oid, name).await + } + + (Ok(TypType::Enum), Ok(TypCategory::Enum)) => { + self.fetch_enum_by_oid(oid, name).await + } + + (Ok(TypType::Composite), Ok(TypCategory::Composite)) => { + self.fetch_composite_by_oid(oid, relation_id, name).await + } _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { kind: PgTypeKind::Simple, @@ -198,6 +290,23 @@ ORDER BY attnum }) } + fn fetch_domain_by_oid( + &mut self, + oid: u32, + base_type: u32, + name: String, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?; + + Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { + oid, + name: name.into(), + kind: PgTypeKind::Domain(base_type), + })))) + }) + } + fn fetch_range_by_oid( &mut self, oid: u32, diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index dee9062d..c49750dc 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -887,3 +887,115 @@ from (values (null)) vals(val) Ok(()) } + +#[sqlx_macros::test] +async fn it_supports_domain_types_in_composite_domain_types() -> anyhow::Result<()> { + #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct MonthId(i16); + + impl sqlx::Type for MonthId { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_name("month_id") + } + + fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool { + *ty == Self::type_info() + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for MonthId { + fn decode( + value: sqlx::postgres::PgValueRef<'r>, + ) -> Result> { + Ok(Self(>::decode(value)?)) + } + } + + impl<'q> sqlx::Encode<'q, Postgres> for MonthId { + fn encode_by_ref( + &self, + buf: &mut sqlx::postgres::PgArgumentBuffer, + ) -> sqlx::encode::IsNull { + self.0.encode(buf) + } + } + + #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct WinterYearMonth { + year: i32, + month: MonthId, + } + + impl sqlx::Type for WinterYearMonth { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_name("winter_year_month") + } + + fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool { + *ty == Self::type_info() + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for WinterYearMonth { + fn decode( + value: sqlx::postgres::PgValueRef<'r>, + ) -> Result> { + let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?; + + let year = decoder.try_decode::()?; + let month = decoder.try_decode::()?; + + Ok(Self { year, month }) + } + } + + impl<'q> sqlx::Encode<'q, Postgres> for WinterYearMonth { + fn encode_by_ref( + &self, + buf: &mut sqlx::postgres::PgArgumentBuffer, + ) -> sqlx::encode::IsNull { + let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf); + encoder.encode(self.year); + encoder.encode(self.month); + encoder.finish(); + sqlx::encode::IsNull::No + } + } + + let mut conn = new::().await?; + + { + let result = sqlx::query("DELETE FROM heating_bills;") + .execute(&mut conn) + .await; + + let result = result.unwrap(); + assert_eq!(result.rows_affected(), 1); + } + + { + let result = sqlx::query( + "INSERT INTO heating_bills(month, cost) VALUES($1::winter_year_month, 100);", + ) + .bind(WinterYearMonth { + year: 2021, + month: MonthId(1), + }) + .execute(&mut conn) + .await; + + let result = result.unwrap(); + assert_eq!(result.rows_affected(), 1); + } + + { + let result = sqlx::query("DELETE FROM heating_bills;") + .execute(&mut conn) + .await; + + let result = result.unwrap(); + assert_eq!(result.rows_affected(), 1); + } + + Ok(()) +} diff --git a/tests/postgres/setup.sql b/tests/postgres/setup.sql index 9818d139..d013d434 100644 --- a/tests/postgres/setup.sql +++ b/tests/postgres/setup.sql @@ -29,3 +29,11 @@ CREATE TABLE products ( name TEXT, price NUMERIC CHECK (price > 0) ); + +CREATE DOMAIN month_id AS INT2 CHECK (1 <= value AND value <= 12); +CREATE TYPE year_month AS (year INT4, month month_id); +CREATE DOMAIN winter_year_month AS year_month CHECK ((value).month <= 3); +CREATE TABLE heating_bills ( + month winter_year_month NOT NULL PRIMARY KEY, + cost INT4 NOT NULL +);