From 129efcd3675dbb79aae6e7cca2d25c0ec7dbfa34 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Wed, 25 Mar 2020 02:07:17 -0700 Subject: [PATCH] implement a runtime type compatibility check before decoding values --- sqlx-core/src/postgres/cursor.rs | 2 +- sqlx-core/src/postgres/protocol/type_id.rs | 67 ++++++++++++++++++++ sqlx-core/src/postgres/row.rs | 10 +-- sqlx-core/src/postgres/types/mod.rs | 2 +- sqlx-core/src/postgres/types/raw/sequence.rs | 11 ++-- sqlx-core/src/postgres/value.rs | 21 +++--- sqlx-core/src/row.rs | 21 +++++- tests/postgres-types.rs | 2 +- 8 files changed, 110 insertions(+), 26 deletions(-) diff --git a/sqlx-core/src/postgres/cursor.rs b/sqlx-core/src/postgres/cursor.rs index cf2f3bd4..9824f58f 100644 --- a/sqlx-core/src/postgres/cursor.rs +++ b/sqlx-core/src/postgres/cursor.rs @@ -66,7 +66,7 @@ fn parse_row_description(rd: RowDescription) -> Statement { } columns.push(Column { - type_oid: field.type_id.0, + type_id: field.type_id, format: field.type_format, }); } diff --git a/sqlx-core/src/postgres/protocol/type_id.rs b/sqlx-core/src/postgres/protocol/type_id.rs index 5622e5fc..88441f76 100644 --- a/sqlx-core/src/postgres/protocol/type_id.rs +++ b/sqlx-core/src/postgres/protocol/type_id.rs @@ -1,3 +1,5 @@ +use std::fmt::{self, Display}; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct TypeId(pub(crate) u32); @@ -72,3 +74,68 @@ impl TypeId { pub(crate) const JSON: TypeId = TypeId(114); pub(crate) const JSONB: TypeId = TypeId(3802); } + +impl Display for TypeId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + TypeId::BOOL => f.write_str("BOOL"), + + TypeId::INT2 => f.write_str("INT2"), + TypeId::INT4 => f.write_str("INT4"), + TypeId::INT8 => f.write_str("INT8"), + + TypeId::FLOAT4 => f.write_str("FLOAT4"), + TypeId::FLOAT8 => f.write_str("FLOAT8"), + + TypeId::NUMERIC => f.write_str("NUMERIC"), + + TypeId::TEXT => f.write_str("TEXT"), + TypeId::VARCHAR => f.write_str("VARCHAR"), + TypeId::BPCHAR => f.write_str("BPCHAR"), + + TypeId::DATE => f.write_str("DATE"), + TypeId::TIME => f.write_str("TIME"), + TypeId::TIMESTAMP => f.write_str("TIMESTAMP"), + TypeId::TIMESTAMPTZ => f.write_str("TIMESTAMPTZ"), + + TypeId::BYTEA => f.write_str("BYTEA"), + + TypeId::UUID => f.write_str("UUID"), + + TypeId::CIDR => f.write_str("CIDR"), + TypeId::INET => f.write_str("INET"), + + TypeId::ARRAY_BOOL => f.write_str("BOOL[]"), + + TypeId::ARRAY_INT2 => f.write_str("INT2[]"), + TypeId::ARRAY_INT4 => f.write_str("INT4[]"), + TypeId::ARRAY_INT8 => f.write_str("INT8[]"), + + TypeId::ARRAY_FLOAT4 => f.write_str("FLOAT4[]"), + TypeId::ARRAY_FLOAT8 => f.write_str("FLOAT8[]"), + + TypeId::ARRAY_TEXT => f.write_str("TEXT[]"), + TypeId::ARRAY_VARCHAR => f.write_str("VARCHAR[]"), + TypeId::ARRAY_BPCHAR => f.write_str("BPCHAR[]"), + + TypeId::ARRAY_NUMERIC => f.write_str("NUMERIC[]"), + + TypeId::ARRAY_DATE => f.write_str("DATE[]"), + TypeId::ARRAY_TIME => f.write_str("TIME[]"), + TypeId::ARRAY_TIMESTAMP => f.write_str("TIMESTAMP[]"), + TypeId::ARRAY_TIMESTAMPTZ => f.write_str("TIMESTAMPTZ[]"), + + TypeId::ARRAY_BYTEA => f.write_str("BYTEA[]"), + + TypeId::ARRAY_UUID => f.write_str("UUID[]"), + + TypeId::ARRAY_CIDR => f.write_str("CIDR[]"), + TypeId::ARRAY_INET => f.write_str("INET[]"), + + TypeId::JSON => f.write_str("JSON"), + TypeId::JSONB => f.write_str("JSONB"), + + _ => write!(f, "<{}>", self.0), + } + } +} diff --git a/sqlx-core/src/postgres/row.rs b/sqlx-core/src/postgres/row.rs index 55509efc..148fd8f9 100644 --- a/sqlx-core/src/postgres/row.rs +++ b/sqlx-core/src/postgres/row.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::postgres::protocol::{DataRow, TypeFormat}; +use crate::postgres::protocol::{DataRow, TypeFormat, TypeId}; use crate::postgres::value::PgValue; use crate::postgres::Postgres; use crate::row::{ColumnIndex, Row}; @@ -11,7 +11,7 @@ use crate::row::{ColumnIndex, Row}; // For simple (unprepared) queries, format will always be text // For prepared queries, format will _almost_ always be binary pub(crate) struct Column { - pub(crate) type_oid: u32, + pub(crate) type_id: TypeId, pub(crate) format: TypeFormat, } @@ -53,9 +53,9 @@ impl<'c> Row<'c> for PgRow<'c> { let column = &self.statement.columns[index]; let buffer = self.data.get(index); let value = match (column.format, buffer) { - (_, None) => PgValue::null(column.type_oid), - (TypeFormat::Binary, Some(buf)) => PgValue::bytes(column.type_oid, buf), - (TypeFormat::Text, Some(buf)) => PgValue::utf8(column.type_oid, buf)?, + (_, None) => PgValue::null(column.type_id), + (TypeFormat::Binary, Some(buf)) => PgValue::bytes(column.type_id, buf), + (TypeFormat::Text, Some(buf)) => PgValue::utf8(column.type_id, buf)?, }; Ok(value) diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 2dd0fc31..7a914c45 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -234,7 +234,7 @@ impl Display for PgTypeInfo { if let Some(ref name) = self.name { write!(f, "{}", *name) } else { - write!(f, "OID {}", self.id.0) + write!(f, "{}", self.id) } } } diff --git a/sqlx-core/src/postgres/types/raw/sequence.rs b/sqlx-core/src/postgres/types/raw/sequence.rs index 652924ba..f700867c 100644 --- a/sqlx-core/src/postgres/types/raw/sequence.rs +++ b/sqlx-core/src/postgres/types/raw/sequence.rs @@ -1,5 +1,6 @@ use crate::decode::Decode; use crate::io::Buf; +use crate::postgres::protocol::TypeId; use crate::postgres::{PgData, PgValue, Postgres}; use crate::types::Type; use byteorder::BigEndian; @@ -60,14 +61,14 @@ impl<'de> PgSequenceDecoder<'de> { let value = if len < 0 { // TODO: Grab the correct element OID - T::decode(PgValue::null(0))? + T::decode(PgValue::null(TypeId(0)))? } else { let value_buf = &buf[..(len as usize)]; *buf = &buf[(len as usize)..]; // TODO: Grab the correct element OID - T::decode(PgValue::bytes(0, value_buf))? + T::decode(PgValue::bytes(TypeId(0), value_buf))? }; self.len += 1; @@ -137,14 +138,14 @@ impl<'de> PgSequenceDecoder<'de> { let value = T::decode(if end == Some(0) { // TODO: Grab the correct element OID - PgValue::null(0) + PgValue::null(TypeId(0)) } else if !self.mixed && value == "NULL" { // Yes, in arrays the text encoding of a NULL is just NULL // TODO: Grab the correct element OID - PgValue::null(0) + PgValue::null(TypeId(0)) } else { // TODO: Grab the correct element OID - PgValue::str(0, &*value) + PgValue::str(TypeId(0), &*value) })?; *s = if let Some(end) = end { diff --git a/sqlx-core/src/postgres/value.rs b/sqlx-core/src/postgres/value.rs index 0878aa28..a9926de2 100644 --- a/sqlx-core/src/postgres/value.rs +++ b/sqlx-core/src/postgres/value.rs @@ -1,4 +1,5 @@ use crate::error::UnexpectedNullError; +use crate::postgres::protocol::TypeId; use crate::postgres::{PgTypeInfo, Postgres}; use crate::value::RawValue; use std::str::from_utf8; @@ -11,7 +12,7 @@ pub enum PgData<'c> { #[derive(Debug)] pub struct PgValue<'c> { - type_oid: u32, + type_id: TypeId, data: Option>, } @@ -32,30 +33,30 @@ impl<'c> PgValue<'c> { self.data } - pub(crate) fn null(type_oid: u32) -> Self { + pub(crate) fn null(type_id: TypeId) -> Self { Self { - type_oid, + type_id, data: None, } } - pub(crate) fn bytes(type_oid: u32, buf: &'c [u8]) -> Self { + pub(crate) fn bytes(type_id: TypeId, buf: &'c [u8]) -> Self { Self { - type_oid, + type_id, data: Some(PgData::Binary(buf)), } } - pub(crate) fn utf8(type_oid: u32, buf: &'c [u8]) -> crate::Result { + pub(crate) fn utf8(type_id: TypeId, buf: &'c [u8]) -> crate::Result { Ok(Self { - type_oid, + type_id, data: Some(PgData::Text(from_utf8(&buf).map_err(crate::Error::decode)?)), }) } - pub(crate) fn str(type_oid: u32, s: &'c str) -> Self { + pub(crate) fn str(type_id: TypeId, s: &'c str) -> Self { Self { - type_oid, + type_id, data: Some(PgData::Text(s)), } } @@ -65,6 +66,6 @@ impl<'c> RawValue<'c> for PgValue<'c> { type Database = Postgres; fn type_info(&self) -> PgTypeInfo { - PgTypeInfo::with_oid(self.type_oid) + PgTypeInfo::with_oid(self.type_id.0) } } diff --git a/sqlx-core/src/row.rs b/sqlx-core/src/row.rs index 5bf758a1..ac532d83 100644 --- a/sqlx-core/src/row.rs +++ b/sqlx-core/src/row.rs @@ -2,8 +2,8 @@ use crate::database::Database; use crate::decode::Decode; -use crate::types::Type; -use crate::value::HasRawValue; +use crate::types::{Type, TypeInfo}; +use crate::value::{HasRawValue, RawValue}; /// A type that can be used to index into a [`Row`]. /// @@ -133,7 +133,22 @@ where I: ColumnIndex<'c, Self>, T: Decode<'c, Self::Database>, { - Ok(Decode::decode(self.try_get_raw(index)?)?) + let value = self.try_get_raw(index)?; + let value_type_info = value.type_info(); + let output_type_info = T::type_info(); + + if !value_type_info.compatible(&output_type_info) { + let ty_name = std::any::type_name::(); + + return Err(decode_err!( + "mismatched types; Rust type `{}` (as SQL type {}) is not compatible with SQL type {}", + ty_name, + output_type_info, + value_type_info + )); + } + + T::decode(value) } #[doc(hidden)] diff --git a/tests/postgres-types.rs b/tests/postgres-types.rs index 565b6c27..1a464810 100644 --- a/tests/postgres-types.rs +++ b/tests/postgres-types.rs @@ -12,7 +12,7 @@ use sqlx_test::{new, test_prepared_type, test_type}; test_type!(null( Postgres, Option, - "NULL" == None:: + "NULL::int2" == None:: )); test_type!(bool(