diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index 2423acb8..2bfc30d8 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -1,6 +1,8 @@ //! **PostgreSQL** database driver. +// https://github.com/launchbadge/sqlx/issues/3440 #![deny(clippy::cast_possible_truncation)] #![deny(clippy::cast_possible_wrap)] +#![deny(clippy::cast_sign_loss)] #[macro_use] extern crate sqlx_core; diff --git a/sqlx-postgres/src/types/cube.rs b/sqlx-postgres/src/types/cube.rs index bf778b8b..a489e31d 100644 --- a/sqlx-postgres/src/types/cube.rs +++ b/sqlx-postgres/src/types/cube.rs @@ -3,24 +3,52 @@ use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; use sqlx_core::Error; use std::str::FromStr; const BYTE_WIDTH: usize = 8; -const CUBE_TYPE_ZERO_VOLUME: usize = 128; -const CUBE_TYPE_DEFAULT: usize = 0; -const CUBE_DIMENSION_ONE: usize = 1; -const DIMENSIONALITY_POSITION: usize = 3; -const START_INDEX: usize = 4; +/// +const MAX_DIMENSIONS: usize = 100; + +const IS_POINT_FLAG: u32 = 1 << 31; + +// FIXME(breaking): these variants are confusingly named and structured +// consider changing them or making this an opaque wrapper around `Vec` #[derive(Debug, Clone, PartialEq)] pub enum PgCube { + /// A one-dimensional point. + // FIXME: `Point1D(f64) Point(f64), + /// An N-dimensional point ("represented internally as a zero-volume cube"). + // FIXME: `PointND(f64)` ZeroVolume(Vec), + + /// A one-dimensional interval with starting and ending points. + // FIXME: `Interval1D { start: f64, end: f64 }` OneDimensionInterval(f64, f64), + + // FIXME: add `Cube3D { lower_left: [f64; 3], upper_right: [f64; 3] }`? + /// An N-dimensional cube with points representing lower-left and upper-right corners, respectively. + // FIXME: CubeND { lower_left: Vec, upper_right: Vec }` MultiDimension(Vec>), } +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +struct Header { + dimensions: usize, + is_point: bool, +} + +#[derive(Debug, thiserror::Error)] +#[error("error decoding CUBE (is_point: {is_point}, dimensions: {dimensions})")] +struct DecodeError { + is_point: bool, + dimensions: usize, + message: String, +} + impl Type for PgCube { fn type_info() -> PgTypeInfo { PgTypeInfo::with_name("cube") @@ -37,7 +65,7 @@ impl<'r> Decode<'r, Postgres> for PgCube { fn decode(value: PgValueRef<'r>) -> Result> { match value.format() { PgValueFormat::Text => Ok(PgCube::from_str(value.as_str()?)?), - PgValueFormat::Binary => Ok(pg_cube_from_bytes(value.as_bytes()?)?), + PgValueFormat::Binary => Ok(PgCube::from_bytes(value.as_bytes()?)?), } } } @@ -51,6 +79,10 @@ impl<'q> Encode<'q, Postgres> for PgCube { self.serialize(buf)?; Ok(IsNull::No) } + + fn size_hint(&self) -> usize { + self.header().encoded_size() + } } impl FromStr for PgCube { @@ -81,86 +113,84 @@ impl FromStr for PgCube { } } -fn pg_cube_from_bytes(bytes: &[u8]) -> Result { - let cube_type = bytes - .first() - .map(|&byte| byte as usize) - .ok_or(Error::Decode( - format!("Could not decode cube bytes: {:?}", bytes).into(), - ))?; - - let dimensionality = bytes - .get(DIMENSIONALITY_POSITION) - .map(|&byte| byte as usize) - .ok_or(Error::Decode( - format!("Could not decode cube bytes: {:?}", bytes).into(), - ))?; - - match (cube_type, dimensionality) { - (CUBE_TYPE_ZERO_VOLUME, CUBE_DIMENSION_ONE) => { - let point = get_f64_from_bytes(bytes, 4)?; - Ok(PgCube::Point(point)) - } - (CUBE_TYPE_ZERO_VOLUME, _) => { - Ok(PgCube::ZeroVolume(deserialize_vector(bytes, START_INDEX)?)) - } - (CUBE_TYPE_DEFAULT, CUBE_DIMENSION_ONE) => { - let x_start = 4; - let y_start = x_start + BYTE_WIDTH; - let x = get_f64_from_bytes(bytes, x_start)?; - let y = get_f64_from_bytes(bytes, y_start)?; - Ok(PgCube::OneDimensionInterval(x, y)) - } - (CUBE_TYPE_DEFAULT, dim) => Ok(PgCube::MultiDimension(deserialize_matrix( - bytes, - START_INDEX, - dim, - )?)), - (flag, dimension) => Err(Error::Decode( - format!( - "Could not deserialise cube with flag {} and dimension {}: {:?}", - flag, dimension, bytes - ) - .into(), - )), - } -} - impl PgCube { - fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), Error> { + fn header(&self) -> Header { + match self { + PgCube::Point(..) => Header { + is_point: true, + dimensions: 1, + }, + PgCube::ZeroVolume(values) => Header { + is_point: true, + dimensions: values.len(), + }, + PgCube::OneDimensionInterval(..) => Header { + is_point: false, + dimensions: 1, + }, + PgCube::MultiDimension(multi_values) => Header { + is_point: false, + dimensions: multi_values.first().map(|arr| arr.len()).unwrap_or(0), + }, + } + } + + fn from_bytes(mut bytes: &[u8]) -> Result { + let header = Header::try_read(&mut bytes)?; + + if bytes.len() != header.data_size() { + return Err(DecodeError::new( + &header, + format!( + "expected {} bytes after header, got {}", + header.data_size(), + bytes.len() + ), + ) + .into()); + } + + match (header.is_point, header.dimensions) { + (true, 1) => Ok(PgCube::Point(bytes.get_f64())), + (true, _) => Ok(PgCube::ZeroVolume( + read_vec(&mut bytes).map_err(|e| DecodeError::new(&header, e))?, + )), + (false, 1) => Ok(PgCube::OneDimensionInterval( + bytes.get_f64(), + bytes.get_f64(), + )), + (false, _) => Ok(PgCube::MultiDimension(read_cube(&header, bytes)?)), + } + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + let header = self.header(); + + buff.reserve(header.data_size()); + + header.try_write(buff)?; + match self { PgCube::Point(value) => { - buff.extend(&[CUBE_TYPE_ZERO_VOLUME as u8, 0, 0, CUBE_DIMENSION_ONE as u8]); buff.extend_from_slice(&value.to_be_bytes()); } PgCube::ZeroVolume(values) => { - let dimension = values.len() as u8; - buff.extend_from_slice(&[CUBE_TYPE_ZERO_VOLUME as u8, 0, 0]); - buff.extend_from_slice(&dimension.to_be_bytes()); - let bytes = values - .iter() - .flat_map(|v| v.to_be_bytes()) - .collect::>(); - buff.extend_from_slice(&bytes); + buff.extend(values.iter().flat_map(|v| v.to_be_bytes())); } PgCube::OneDimensionInterval(x, y) => { - buff.extend_from_slice(&[0, 0, 0, CUBE_DIMENSION_ONE as u8]); buff.extend_from_slice(&x.to_be_bytes()); buff.extend_from_slice(&y.to_be_bytes()); } PgCube::MultiDimension(multi_values) => { - let dimension = multi_values - .first() - .map(|arr| arr.len() as u8) - .unwrap_or(1_u8); - buff.extend_from_slice(&[0, 0, 0]); - buff.extend_from_slice(&dimension.to_be_bytes()); - let bytes = multi_values - .iter() - .flatten() - .flat_map(|v| v.to_be_bytes()) - .collect::>(); - buff.extend_from_slice(&bytes); + if multi_values.len() != 2 { + return Err(format!("invalid CUBE value: {self:?}")); + } + + buff.extend( + multi_values + .iter() + .flat_map(|point| point.iter().flat_map(|scalar| scalar.to_be_bytes())), + ); } }; Ok(()) @@ -174,41 +204,46 @@ impl PgCube { } } -fn get_f64_from_bytes(bytes: &[u8], start: usize) -> Result { - bytes - .get(start..start + BYTE_WIDTH) - .ok_or(Error::Decode( - format!("Could not decode cube bytes: {:?}", bytes).into(), - ))? - .try_into() - .map(f64::from_be_bytes) - .map_err(|err| Error::Decode(format!("Invalid bytes slice: {:?}", err).into())) +fn read_vec(bytes: &mut &[u8]) -> Result, String> { + if bytes.len() % BYTE_WIDTH != 0 { + return Err(format!( + "data length not divisible by {BYTE_WIDTH}: {}", + bytes.len() + )); + } + + let mut out = Vec::with_capacity(bytes.len() / BYTE_WIDTH); + + while bytes.has_remaining() { + out.push(bytes.get_f64()); + } + + Ok(out) } -fn deserialize_vector(bytes: &[u8], start_index: usize) -> Result, Error> { - let steps = (bytes.len() - start_index) / BYTE_WIDTH; - (0..steps) - .map(|i| get_f64_from_bytes(bytes, start_index + i * BYTE_WIDTH)) - .collect() -} +fn read_cube(header: &Header, mut bytes: &[u8]) -> Result>, String> { + if bytes.len() != header.data_size() { + return Err(format!( + "expected {} bytes, got {}", + header.data_size(), + bytes.len() + )); + } -fn deserialize_matrix( - bytes: &[u8], - start_index: usize, - dim: usize, -) -> Result>, Error> { - let step = BYTE_WIDTH * dim; - let steps = (bytes.len() - start_index) / step; + let mut out = Vec::with_capacity(2); - (0..steps) - .map(|step_idx| { - (0..dim) - .map(|dim_idx| { - get_f64_from_bytes(bytes, start_index + step_idx * step + dim_idx * BYTE_WIDTH) - }) - .collect() - }) - .collect() + // Expecting exactly 2 N-dimensional points + for _ in 0..2 { + let mut point = Vec::new(); + + for _ in 0..header.dimensions { + point.push(bytes.get_f64()); + } + + out.push(point); + } + + Ok(out) } fn parse_float_from_str(s: &str, error_msg: &str) -> Result { @@ -268,12 +303,86 @@ fn remove_parentheses(s: &str) -> String { s.trim_matches(|c| c == '(' || c == ')').to_string() } +impl Header { + const PACKED_WIDTH: usize = size_of::(); + + fn encoded_size(&self) -> usize { + Self::PACKED_WIDTH + self.data_size() + } + + fn data_size(&self) -> usize { + if self.is_point { + self.dimensions * BYTE_WIDTH + } else { + self.dimensions * BYTE_WIDTH * 2 + } + } + + fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + if self.dimensions > MAX_DIMENSIONS { + return Err(format!( + "CUBE dimensionality exceeds allowed maximum ({} > {MAX_DIMENSIONS})", + self.dimensions + )); + } + + // Cannot overflow thanks to the above check. + #[allow(clippy::cast_possible_truncation)] + let mut packed = self.dimensions as u32; + + // https://github.com/postgres/postgres/blob/e3ec9dc1bf4983fcedb6f43c71ea12ee26aefc7a/contrib/cube/cubedata.h#L18-L24 + if self.is_point { + packed |= IS_POINT_FLAG; + } + + buff.extend(packed.to_be_bytes()); + + Ok(()) + } + + fn try_read(buf: &mut &[u8]) -> Result { + if buf.len() < Self::PACKED_WIDTH { + return Err(format!( + "expected CUBE data to contain at least {} bytes, got {}", + Self::PACKED_WIDTH, + buf.len() + )); + } + + let packed = buf.get_u32(); + + let is_point = packed & IS_POINT_FLAG != 0; + let dimensions = packed & !IS_POINT_FLAG; + + // can only overflow on 16-bit platforms + let dimensions = usize::try_from(dimensions) + .ok() + .filter(|&it| it <= MAX_DIMENSIONS) + .ok_or_else(|| format!("received CUBE data with higher than expected dimensionality: {dimensions} (is_point: {is_point})"))?; + + Ok(Self { + is_point, + dimensions, + }) + } +} + +impl DecodeError { + fn new(header: &Header, message: String) -> Self { + DecodeError { + is_point: header.is_point, + dimensions: header.dimensions, + message, + } + } +} + #[cfg(test)] mod cube_tests { use std::str::FromStr; - use crate::types::{cube::pg_cube_from_bytes, PgCube}; + use super::PgCube; const POINT_BYTES: &[u8] = &[128, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 0]; const ZERO_VOLUME_BYTES: &[u8] = &[ @@ -293,7 +402,7 @@ mod cube_tests { #[test] fn can_deserialise_point_type_byes() { - let cube = pg_cube_from_bytes(POINT_BYTES).unwrap(); + let cube = PgCube::from_bytes(POINT_BYTES).unwrap(); assert_eq!(cube, PgCube::Point(2.)) } @@ -311,7 +420,7 @@ mod cube_tests { } #[test] fn can_deserialise_zero_volume_bytes() { - let cube = pg_cube_from_bytes(ZERO_VOLUME_BYTES).unwrap(); + let cube = PgCube::from_bytes(ZERO_VOLUME_BYTES).unwrap(); assert_eq!(cube, PgCube::ZeroVolume(vec![2., 3.])); } @@ -333,7 +442,7 @@ mod cube_tests { #[test] fn can_deserialise_one_dimension_interval_bytes() { - let cube = pg_cube_from_bytes(ONE_DIMENSIONAL_INTERVAL_BYTES).unwrap(); + let cube = PgCube::from_bytes(ONE_DIMENSIONAL_INTERVAL_BYTES).unwrap(); assert_eq!(cube, PgCube::OneDimensionInterval(7., 8.)) } @@ -355,7 +464,7 @@ mod cube_tests { #[test] fn can_deserialise_multi_dimension_2_dimension_byte() { - let cube = pg_cube_from_bytes(MULTI_DIMENSION_2_DIM_BYTES).unwrap(); + let cube = PgCube::from_bytes(MULTI_DIMENSION_2_DIM_BYTES).unwrap(); assert_eq!( cube, PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]]) @@ -396,7 +505,7 @@ mod cube_tests { #[test] fn can_deserialise_multi_dimension_3_dimension_bytes() { - let cube = pg_cube_from_bytes(MULTI_DIMENSION_3_DIM_BYTES).unwrap(); + let cube = PgCube::from_bytes(MULTI_DIMENSION_3_DIM_BYTES).unwrap(); assert_eq!( cube, PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]])