implement a runtime type compatibility check before decoding values

This commit is contained in:
Ryan Leckey 2020-03-25 02:07:17 -07:00
parent 1dc582edd0
commit 129efcd367
8 changed files with 110 additions and 26 deletions

View File

@ -66,7 +66,7 @@ fn parse_row_description(rd: RowDescription) -> Statement {
}
columns.push(Column {
type_oid: field.type_id.0,
type_id: field.type_id,
format: field.type_format,
});
}

View File

@ -1,3 +1,5 @@
use std::fmt::{self, Display};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TypeId(pub(crate) u32);
@ -72,3 +74,68 @@ impl TypeId {
pub(crate) const JSON: TypeId = TypeId(114);
pub(crate) const JSONB: TypeId = TypeId(3802);
}
impl Display for TypeId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
TypeId::BOOL => f.write_str("BOOL"),
TypeId::INT2 => f.write_str("INT2"),
TypeId::INT4 => f.write_str("INT4"),
TypeId::INT8 => f.write_str("INT8"),
TypeId::FLOAT4 => f.write_str("FLOAT4"),
TypeId::FLOAT8 => f.write_str("FLOAT8"),
TypeId::NUMERIC => f.write_str("NUMERIC"),
TypeId::TEXT => f.write_str("TEXT"),
TypeId::VARCHAR => f.write_str("VARCHAR"),
TypeId::BPCHAR => f.write_str("BPCHAR"),
TypeId::DATE => f.write_str("DATE"),
TypeId::TIME => f.write_str("TIME"),
TypeId::TIMESTAMP => f.write_str("TIMESTAMP"),
TypeId::TIMESTAMPTZ => f.write_str("TIMESTAMPTZ"),
TypeId::BYTEA => f.write_str("BYTEA"),
TypeId::UUID => f.write_str("UUID"),
TypeId::CIDR => f.write_str("CIDR"),
TypeId::INET => f.write_str("INET"),
TypeId::ARRAY_BOOL => f.write_str("BOOL[]"),
TypeId::ARRAY_INT2 => f.write_str("INT2[]"),
TypeId::ARRAY_INT4 => f.write_str("INT4[]"),
TypeId::ARRAY_INT8 => f.write_str("INT8[]"),
TypeId::ARRAY_FLOAT4 => f.write_str("FLOAT4[]"),
TypeId::ARRAY_FLOAT8 => f.write_str("FLOAT8[]"),
TypeId::ARRAY_TEXT => f.write_str("TEXT[]"),
TypeId::ARRAY_VARCHAR => f.write_str("VARCHAR[]"),
TypeId::ARRAY_BPCHAR => f.write_str("BPCHAR[]"),
TypeId::ARRAY_NUMERIC => f.write_str("NUMERIC[]"),
TypeId::ARRAY_DATE => f.write_str("DATE[]"),
TypeId::ARRAY_TIME => f.write_str("TIME[]"),
TypeId::ARRAY_TIMESTAMP => f.write_str("TIMESTAMP[]"),
TypeId::ARRAY_TIMESTAMPTZ => f.write_str("TIMESTAMPTZ[]"),
TypeId::ARRAY_BYTEA => f.write_str("BYTEA[]"),
TypeId::ARRAY_UUID => f.write_str("UUID[]"),
TypeId::ARRAY_CIDR => f.write_str("CIDR[]"),
TypeId::ARRAY_INET => f.write_str("INET[]"),
TypeId::JSON => f.write_str("JSON"),
TypeId::JSONB => f.write_str("JSONB"),
_ => write!(f, "<{}>", self.0),
}
}
}

View File

@ -1,7 +1,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::postgres::protocol::{DataRow, TypeFormat};
use crate::postgres::protocol::{DataRow, TypeFormat, TypeId};
use crate::postgres::value::PgValue;
use crate::postgres::Postgres;
use crate::row::{ColumnIndex, Row};
@ -11,7 +11,7 @@ use crate::row::{ColumnIndex, Row};
// For simple (unprepared) queries, format will always be text
// For prepared queries, format will _almost_ always be binary
pub(crate) struct Column {
pub(crate) type_oid: u32,
pub(crate) type_id: TypeId,
pub(crate) format: TypeFormat,
}
@ -53,9 +53,9 @@ impl<'c> Row<'c> for PgRow<'c> {
let column = &self.statement.columns[index];
let buffer = self.data.get(index);
let value = match (column.format, buffer) {
(_, None) => PgValue::null(column.type_oid),
(TypeFormat::Binary, Some(buf)) => PgValue::bytes(column.type_oid, buf),
(TypeFormat::Text, Some(buf)) => PgValue::utf8(column.type_oid, buf)?,
(_, None) => PgValue::null(column.type_id),
(TypeFormat::Binary, Some(buf)) => PgValue::bytes(column.type_id, buf),
(TypeFormat::Text, Some(buf)) => PgValue::utf8(column.type_id, buf)?,
};
Ok(value)

View File

@ -234,7 +234,7 @@ impl Display for PgTypeInfo {
if let Some(ref name) = self.name {
write!(f, "{}", *name)
} else {
write!(f, "OID {}", self.id.0)
write!(f, "{}", self.id)
}
}
}

View File

@ -1,5 +1,6 @@
use crate::decode::Decode;
use crate::io::Buf;
use crate::postgres::protocol::TypeId;
use crate::postgres::{PgData, PgValue, Postgres};
use crate::types::Type;
use byteorder::BigEndian;
@ -60,14 +61,14 @@ impl<'de> PgSequenceDecoder<'de> {
let value = if len < 0 {
// TODO: Grab the correct element OID
T::decode(PgValue::null(0))?
T::decode(PgValue::null(TypeId(0)))?
} else {
let value_buf = &buf[..(len as usize)];
*buf = &buf[(len as usize)..];
// TODO: Grab the correct element OID
T::decode(PgValue::bytes(0, value_buf))?
T::decode(PgValue::bytes(TypeId(0), value_buf))?
};
self.len += 1;
@ -137,14 +138,14 @@ impl<'de> PgSequenceDecoder<'de> {
let value = T::decode(if end == Some(0) {
// TODO: Grab the correct element OID
PgValue::null(0)
PgValue::null(TypeId(0))
} else if !self.mixed && value == "NULL" {
// Yes, in arrays the text encoding of a NULL is just NULL
// TODO: Grab the correct element OID
PgValue::null(0)
PgValue::null(TypeId(0))
} else {
// TODO: Grab the correct element OID
PgValue::str(0, &*value)
PgValue::str(TypeId(0), &*value)
})?;
*s = if let Some(end) = end {

View File

@ -1,4 +1,5 @@
use crate::error::UnexpectedNullError;
use crate::postgres::protocol::TypeId;
use crate::postgres::{PgTypeInfo, Postgres};
use crate::value::RawValue;
use std::str::from_utf8;
@ -11,7 +12,7 @@ pub enum PgData<'c> {
#[derive(Debug)]
pub struct PgValue<'c> {
type_oid: u32,
type_id: TypeId,
data: Option<PgData<'c>>,
}
@ -32,30 +33,30 @@ impl<'c> PgValue<'c> {
self.data
}
pub(crate) fn null(type_oid: u32) -> Self {
pub(crate) fn null(type_id: TypeId) -> Self {
Self {
type_oid,
type_id,
data: None,
}
}
pub(crate) fn bytes(type_oid: u32, buf: &'c [u8]) -> Self {
pub(crate) fn bytes(type_id: TypeId, buf: &'c [u8]) -> Self {
Self {
type_oid,
type_id,
data: Some(PgData::Binary(buf)),
}
}
pub(crate) fn utf8(type_oid: u32, buf: &'c [u8]) -> crate::Result<Postgres, Self> {
pub(crate) fn utf8(type_id: TypeId, buf: &'c [u8]) -> crate::Result<Postgres, Self> {
Ok(Self {
type_oid,
type_id,
data: Some(PgData::Text(from_utf8(&buf).map_err(crate::Error::decode)?)),
})
}
pub(crate) fn str(type_oid: u32, s: &'c str) -> Self {
pub(crate) fn str(type_id: TypeId, s: &'c str) -> Self {
Self {
type_oid,
type_id,
data: Some(PgData::Text(s)),
}
}
@ -65,6 +66,6 @@ impl<'c> RawValue<'c> for PgValue<'c> {
type Database = Postgres;
fn type_info(&self) -> PgTypeInfo {
PgTypeInfo::with_oid(self.type_oid)
PgTypeInfo::with_oid(self.type_id.0)
}
}

View File

@ -2,8 +2,8 @@
use crate::database::Database;
use crate::decode::Decode;
use crate::types::Type;
use crate::value::HasRawValue;
use crate::types::{Type, TypeInfo};
use crate::value::{HasRawValue, RawValue};
/// A type that can be used to index into a [`Row`].
///
@ -133,7 +133,22 @@ where
I: ColumnIndex<'c, Self>,
T: Decode<'c, Self::Database>,
{
Ok(Decode::decode(self.try_get_raw(index)?)?)
let value = self.try_get_raw(index)?;
let value_type_info = value.type_info();
let output_type_info = T::type_info();
if !value_type_info.compatible(&output_type_info) {
let ty_name = std::any::type_name::<T>();
return Err(decode_err!(
"mismatched types; Rust type `{}` (as SQL type {}) is not compatible with SQL type {}",
ty_name,
output_type_info,
value_type_info
));
}
T::decode(value)
}
#[doc(hidden)]

View File

@ -12,7 +12,7 @@ use sqlx_test::{new, test_prepared_type, test_type};
test_type!(null(
Postgres,
Option<i16>,
"NULL" == None::<i16>
"NULL::int2" == None::<i16>
));
test_type!(bool(