postgres: test type compatibility for record fields

This commit is contained in:
Ryan Leckey 2020-03-25 02:28:10 -07:00
parent 129efcd367
commit 6ebd5c8c1e
3 changed files with 52 additions and 37 deletions

View File

@ -1,10 +1,23 @@
//! Error<DB>and Result types.
use crate::database::Database;
use crate::types::Type;
use std::any::type_name;
use std::error::Error as StdError;
use std::fmt::{self, Debug, Display};
use std::io;
#[allow(unused_macros)]
macro_rules! decode_err {
($s:literal, $($args:tt)*) => {
crate::Error::Decode(format!($s, $($args)*).into())
};
($expr:expr) => {
crate::Error::decode($expr)
};
}
/// A specialized `Result` type for SQLx.
pub type Result<DB, T> = std::result::Result<T, Error<DB>>;
@ -70,6 +83,21 @@ impl<DB: Database> Error<DB> {
{
Error::<DB>::Decode(err.into())
}
#[allow(dead_code)]
pub(crate) fn mismatched_types<T>(expected: DB::TypeInfo) -> Self
where
T: Type<DB>,
{
let ty_name = type_name::<T>();
return decode_err!(
"mismatched types; Rust type `{}` (as SQL type {}) is not compatible with SQL type {}",
ty_name,
T::type_info(),
expected
);
}
}
impl<DB: Database + Debug> StdError for Error<DB> {
@ -225,17 +253,6 @@ macro_rules! tls_err {
($($args:tt)*) => { crate::error::TlsError { args: format_args!($($args)*)} };
}
#[allow(unused_macros)]
macro_rules! decode_err {
($s:literal, $($args:tt)*) => {
crate::Error::Decode(format!($s, $($args)*).into())
};
($expr:expr) => {
crate::Error::decode($expr)
};
}
/// An unexpected `NULL` was encountered during decoding.
///
/// Returned from `Row::get` if the value from the database is `NULL`

View File

@ -1,8 +1,8 @@
use crate::decode::Decode;
use crate::io::Buf;
use crate::postgres::protocol::TypeId;
use crate::postgres::{PgData, PgValue, Postgres};
use crate::types::Type;
use crate::postgres::{PgData, PgTypeInfo, PgValue, Postgres};
use crate::types::{Type, TypeInfo};
use byteorder::BigEndian;
pub(crate) struct PgSequenceDecoder<'de> {
@ -49,26 +49,32 @@ impl<'de> PgSequenceDecoder<'de> {
// mixed sequences can contain values of many different types
// the OID of the type is encoded next to each value
if self.mixed {
// TODO: We should fail if this type is not _compatible_; but
// I want to make sure we handle this _and_ the outer level
// type mismatch errors at the same time
let type_id = if self.mixed {
let oid = buf.get_u32::<BigEndian>()?;
let expected_ty = PgTypeInfo::with_oid(oid);
let _oid = buf.get_u32::<BigEndian>()?;
}
if !expected_ty.compatible(&T::type_info()) {
return Err(crate::Error::mismatched_types::<T>(expected_ty));
}
TypeId(oid)
} else {
// NOTE: We don't validate the element type for non-mixed sequences because
// the outer type like `text[]` would have already ensured we are dealing
// with a Vec<String>
T::type_info().id
};
let len = buf.get_i32::<BigEndian>()? as isize;
let value = if len < 0 {
// TODO: Grab the correct element OID
T::decode(PgValue::null(TypeId(0)))?
T::decode(PgValue::null(type_id))?
} else {
let value_buf = &buf[..(len as usize)];
*buf = &buf[(len as usize)..];
// TODO: Grab the correct element OID
T::decode(PgValue::bytes(TypeId(0), value_buf))?
T::decode(PgValue::bytes(type_id, value_buf))?
};
self.len += 1;
@ -136,15 +142,15 @@ impl<'de> PgSequenceDecoder<'de> {
break None;
};
// NOTE: We pass `0` as the type ID because we don't have a reasonable value
// we could use. In TEXT mode, sequences aren't typed.
let value = T::decode(if end == Some(0) {
// TODO: Grab the correct element OID
PgValue::null(TypeId(0))
} else if !self.mixed && value == "NULL" {
// Yes, in arrays the text encoding of a NULL is just NULL
// TODO: Grab the correct element OID
PgValue::null(TypeId(0))
} else {
// TODO: Grab the correct element OID
PgValue::str(TypeId(0), &*value)
})?;

View File

@ -134,18 +134,10 @@ where
T: Decode<'c, Self::Database>,
{
let value = self.try_get_raw(index)?;
let value_type_info = value.type_info();
let output_type_info = T::type_info();
let expected_ty = value.type_info();
if !value_type_info.compatible(&output_type_info) {
let ty_name = std::any::type_name::<T>();
return Err(decode_err!(
"mismatched types; Rust type `{}` (as SQL type {}) is not compatible with SQL type {}",
ty_name,
output_type_info,
value_type_info
));
if !expected_ty.compatible(&T::type_info()) {
return Err(crate::Error::mismatched_types::<T>(expected_ty));
}
T::decode(value)