mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 21:00:54 +00:00
implement a runtime type compatibility check before decoding values
This commit is contained in:
parent
1dc582edd0
commit
129efcd367
@ -66,7 +66,7 @@ fn parse_row_description(rd: RowDescription) -> Statement {
|
||||
}
|
||||
|
||||
columns.push(Column {
|
||||
type_oid: field.type_id.0,
|
||||
type_id: field.type_id,
|
||||
format: field.type_format,
|
||||
});
|
||||
}
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
use std::fmt::{self, Display};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct TypeId(pub(crate) u32);
|
||||
|
||||
@ -72,3 +74,68 @@ impl TypeId {
|
||||
pub(crate) const JSON: TypeId = TypeId(114);
|
||||
pub(crate) const JSONB: TypeId = TypeId(3802);
|
||||
}
|
||||
|
||||
impl Display for TypeId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match *self {
|
||||
TypeId::BOOL => f.write_str("BOOL"),
|
||||
|
||||
TypeId::INT2 => f.write_str("INT2"),
|
||||
TypeId::INT4 => f.write_str("INT4"),
|
||||
TypeId::INT8 => f.write_str("INT8"),
|
||||
|
||||
TypeId::FLOAT4 => f.write_str("FLOAT4"),
|
||||
TypeId::FLOAT8 => f.write_str("FLOAT8"),
|
||||
|
||||
TypeId::NUMERIC => f.write_str("NUMERIC"),
|
||||
|
||||
TypeId::TEXT => f.write_str("TEXT"),
|
||||
TypeId::VARCHAR => f.write_str("VARCHAR"),
|
||||
TypeId::BPCHAR => f.write_str("BPCHAR"),
|
||||
|
||||
TypeId::DATE => f.write_str("DATE"),
|
||||
TypeId::TIME => f.write_str("TIME"),
|
||||
TypeId::TIMESTAMP => f.write_str("TIMESTAMP"),
|
||||
TypeId::TIMESTAMPTZ => f.write_str("TIMESTAMPTZ"),
|
||||
|
||||
TypeId::BYTEA => f.write_str("BYTEA"),
|
||||
|
||||
TypeId::UUID => f.write_str("UUID"),
|
||||
|
||||
TypeId::CIDR => f.write_str("CIDR"),
|
||||
TypeId::INET => f.write_str("INET"),
|
||||
|
||||
TypeId::ARRAY_BOOL => f.write_str("BOOL[]"),
|
||||
|
||||
TypeId::ARRAY_INT2 => f.write_str("INT2[]"),
|
||||
TypeId::ARRAY_INT4 => f.write_str("INT4[]"),
|
||||
TypeId::ARRAY_INT8 => f.write_str("INT8[]"),
|
||||
|
||||
TypeId::ARRAY_FLOAT4 => f.write_str("FLOAT4[]"),
|
||||
TypeId::ARRAY_FLOAT8 => f.write_str("FLOAT8[]"),
|
||||
|
||||
TypeId::ARRAY_TEXT => f.write_str("TEXT[]"),
|
||||
TypeId::ARRAY_VARCHAR => f.write_str("VARCHAR[]"),
|
||||
TypeId::ARRAY_BPCHAR => f.write_str("BPCHAR[]"),
|
||||
|
||||
TypeId::ARRAY_NUMERIC => f.write_str("NUMERIC[]"),
|
||||
|
||||
TypeId::ARRAY_DATE => f.write_str("DATE[]"),
|
||||
TypeId::ARRAY_TIME => f.write_str("TIME[]"),
|
||||
TypeId::ARRAY_TIMESTAMP => f.write_str("TIMESTAMP[]"),
|
||||
TypeId::ARRAY_TIMESTAMPTZ => f.write_str("TIMESTAMPTZ[]"),
|
||||
|
||||
TypeId::ARRAY_BYTEA => f.write_str("BYTEA[]"),
|
||||
|
||||
TypeId::ARRAY_UUID => f.write_str("UUID[]"),
|
||||
|
||||
TypeId::ARRAY_CIDR => f.write_str("CIDR[]"),
|
||||
TypeId::ARRAY_INET => f.write_str("INET[]"),
|
||||
|
||||
TypeId::JSON => f.write_str("JSON"),
|
||||
TypeId::JSONB => f.write_str("JSONB"),
|
||||
|
||||
_ => write!(f, "<{}>", self.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::postgres::protocol::{DataRow, TypeFormat};
|
||||
use crate::postgres::protocol::{DataRow, TypeFormat, TypeId};
|
||||
use crate::postgres::value::PgValue;
|
||||
use crate::postgres::Postgres;
|
||||
use crate::row::{ColumnIndex, Row};
|
||||
@ -11,7 +11,7 @@ use crate::row::{ColumnIndex, Row};
|
||||
// For simple (unprepared) queries, format will always be text
|
||||
// For prepared queries, format will _almost_ always be binary
|
||||
pub(crate) struct Column {
|
||||
pub(crate) type_oid: u32,
|
||||
pub(crate) type_id: TypeId,
|
||||
pub(crate) format: TypeFormat,
|
||||
}
|
||||
|
||||
@ -53,9 +53,9 @@ impl<'c> Row<'c> for PgRow<'c> {
|
||||
let column = &self.statement.columns[index];
|
||||
let buffer = self.data.get(index);
|
||||
let value = match (column.format, buffer) {
|
||||
(_, None) => PgValue::null(column.type_oid),
|
||||
(TypeFormat::Binary, Some(buf)) => PgValue::bytes(column.type_oid, buf),
|
||||
(TypeFormat::Text, Some(buf)) => PgValue::utf8(column.type_oid, buf)?,
|
||||
(_, None) => PgValue::null(column.type_id),
|
||||
(TypeFormat::Binary, Some(buf)) => PgValue::bytes(column.type_id, buf),
|
||||
(TypeFormat::Text, Some(buf)) => PgValue::utf8(column.type_id, buf)?,
|
||||
};
|
||||
|
||||
Ok(value)
|
||||
|
||||
@ -234,7 +234,7 @@ impl Display for PgTypeInfo {
|
||||
if let Some(ref name) = self.name {
|
||||
write!(f, "{}", *name)
|
||||
} else {
|
||||
write!(f, "OID {}", self.id.0)
|
||||
write!(f, "{}", self.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
use crate::decode::Decode;
|
||||
use crate::io::Buf;
|
||||
use crate::postgres::protocol::TypeId;
|
||||
use crate::postgres::{PgData, PgValue, Postgres};
|
||||
use crate::types::Type;
|
||||
use byteorder::BigEndian;
|
||||
@ -60,14 +61,14 @@ impl<'de> PgSequenceDecoder<'de> {
|
||||
|
||||
let value = if len < 0 {
|
||||
// TODO: Grab the correct element OID
|
||||
T::decode(PgValue::null(0))?
|
||||
T::decode(PgValue::null(TypeId(0)))?
|
||||
} else {
|
||||
let value_buf = &buf[..(len as usize)];
|
||||
|
||||
*buf = &buf[(len as usize)..];
|
||||
|
||||
// TODO: Grab the correct element OID
|
||||
T::decode(PgValue::bytes(0, value_buf))?
|
||||
T::decode(PgValue::bytes(TypeId(0), value_buf))?
|
||||
};
|
||||
|
||||
self.len += 1;
|
||||
@ -137,14 +138,14 @@ impl<'de> PgSequenceDecoder<'de> {
|
||||
|
||||
let value = T::decode(if end == Some(0) {
|
||||
// TODO: Grab the correct element OID
|
||||
PgValue::null(0)
|
||||
PgValue::null(TypeId(0))
|
||||
} else if !self.mixed && value == "NULL" {
|
||||
// Yes, in arrays the text encoding of a NULL is just NULL
|
||||
// TODO: Grab the correct element OID
|
||||
PgValue::null(0)
|
||||
PgValue::null(TypeId(0))
|
||||
} else {
|
||||
// TODO: Grab the correct element OID
|
||||
PgValue::str(0, &*value)
|
||||
PgValue::str(TypeId(0), &*value)
|
||||
})?;
|
||||
|
||||
*s = if let Some(end) = end {
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
use crate::error::UnexpectedNullError;
|
||||
use crate::postgres::protocol::TypeId;
|
||||
use crate::postgres::{PgTypeInfo, Postgres};
|
||||
use crate::value::RawValue;
|
||||
use std::str::from_utf8;
|
||||
@ -11,7 +12,7 @@ pub enum PgData<'c> {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PgValue<'c> {
|
||||
type_oid: u32,
|
||||
type_id: TypeId,
|
||||
data: Option<PgData<'c>>,
|
||||
}
|
||||
|
||||
@ -32,30 +33,30 @@ impl<'c> PgValue<'c> {
|
||||
self.data
|
||||
}
|
||||
|
||||
pub(crate) fn null(type_oid: u32) -> Self {
|
||||
pub(crate) fn null(type_id: TypeId) -> Self {
|
||||
Self {
|
||||
type_oid,
|
||||
type_id,
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn bytes(type_oid: u32, buf: &'c [u8]) -> Self {
|
||||
pub(crate) fn bytes(type_id: TypeId, buf: &'c [u8]) -> Self {
|
||||
Self {
|
||||
type_oid,
|
||||
type_id,
|
||||
data: Some(PgData::Binary(buf)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn utf8(type_oid: u32, buf: &'c [u8]) -> crate::Result<Postgres, Self> {
|
||||
pub(crate) fn utf8(type_id: TypeId, buf: &'c [u8]) -> crate::Result<Postgres, Self> {
|
||||
Ok(Self {
|
||||
type_oid,
|
||||
type_id,
|
||||
data: Some(PgData::Text(from_utf8(&buf).map_err(crate::Error::decode)?)),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn str(type_oid: u32, s: &'c str) -> Self {
|
||||
pub(crate) fn str(type_id: TypeId, s: &'c str) -> Self {
|
||||
Self {
|
||||
type_oid,
|
||||
type_id,
|
||||
data: Some(PgData::Text(s)),
|
||||
}
|
||||
}
|
||||
@ -65,6 +66,6 @@ impl<'c> RawValue<'c> for PgValue<'c> {
|
||||
type Database = Postgres;
|
||||
|
||||
fn type_info(&self) -> PgTypeInfo {
|
||||
PgTypeInfo::with_oid(self.type_oid)
|
||||
PgTypeInfo::with_oid(self.type_id.0)
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,8 +2,8 @@
|
||||
|
||||
use crate::database::Database;
|
||||
use crate::decode::Decode;
|
||||
use crate::types::Type;
|
||||
use crate::value::HasRawValue;
|
||||
use crate::types::{Type, TypeInfo};
|
||||
use crate::value::{HasRawValue, RawValue};
|
||||
|
||||
/// A type that can be used to index into a [`Row`].
|
||||
///
|
||||
@ -133,7 +133,22 @@ where
|
||||
I: ColumnIndex<'c, Self>,
|
||||
T: Decode<'c, Self::Database>,
|
||||
{
|
||||
Ok(Decode::decode(self.try_get_raw(index)?)?)
|
||||
let value = self.try_get_raw(index)?;
|
||||
let value_type_info = value.type_info();
|
||||
let output_type_info = T::type_info();
|
||||
|
||||
if !value_type_info.compatible(&output_type_info) {
|
||||
let ty_name = std::any::type_name::<T>();
|
||||
|
||||
return Err(decode_err!(
|
||||
"mismatched types; Rust type `{}` (as SQL type {}) is not compatible with SQL type {}",
|
||||
ty_name,
|
||||
output_type_info,
|
||||
value_type_info
|
||||
));
|
||||
}
|
||||
|
||||
T::decode(value)
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
|
||||
@ -12,7 +12,7 @@ use sqlx_test::{new, test_prepared_type, test_type};
|
||||
test_type!(null(
|
||||
Postgres,
|
||||
Option<i16>,
|
||||
"NULL" == None::<i16>
|
||||
"NULL::int2" == None::<i16>
|
||||
));
|
||||
|
||||
test_type!(bool(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user