diff --git a/Cargo.lock b/Cargo.lock index ee2fe494b..0cb2f5716 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2516,6 +2516,7 @@ dependencies = [ "dotenv", "env_logger", "futures 0.3.5", + "geo", "paste", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index c875d9c4f..1aae67d6f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,6 +82,7 @@ sqlx-macros = { version = "0.4.0-pre", path = "sqlx-macros", default-features = [dev-dependencies] anyhow = "1.0.31" time_ = { version = "0.2.16", package = "time" } +geo_ = { version = "0.13.0", package = "geo" } futures = "0.3.5" env_logger = "0.7.1" async-std = { version = "1.6.0", features = [ "attributes" ] } diff --git a/sqlx-core/src/postgres/types/geo.rs b/sqlx-core/src/postgres/types/geo.rs index 7731076b3..bcf29ec41 100644 --- a/sqlx-core/src/postgres/types/geo.rs +++ b/sqlx-core/src/postgres/types/geo.rs @@ -6,7 +6,7 @@ use crate::{ types::Type, }; use byteorder::{BigEndian, ByteOrder}; -use geo::{Line, Coordinate}; +use geo::{Coordinate, Line}; use std::{mem, num::ParseFloatError}; // @@ -53,14 +53,6 @@ impl Decode<'_, Postgres> for Coordinate { } } -fn decode_coordinate_binary(buf: &[u8]) -> Result, BoxDynError> { - let x = BigEndian::read_f64(buf); - - let y = BigEndian::read_f64(buf); - - Ok((x, y).into()) -} - impl Encode<'_, Postgres> for Coordinate { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { let _ = Encode::::encode(self.x, buf); @@ -74,6 +66,18 @@ impl Encode<'_, Postgres> for Coordinate { } } +fn decode_coordinate_binary(buf: &[u8]) -> Result, BoxDynError> { + if buf.len() != 16 { + Err("Invalid data received when expecting a POINT".into()) + } else { + let x = BigEndian::read_f64(&buf[..8]); + + let y = BigEndian::read_f64(&buf[8..]); + + Ok((x, y).into()) + } +} + impl Type for Line { fn type_info() -> PgTypeInfo { PgTypeInfo::LSEG @@ -85,9 +89,8 @@ impl Decode<'_, Postgres> for Line { match value.format() { PgValueFormat::Binary => { let buf = value.as_bytes()?; - let start = decode_coordinate_binary(buf)?; - // buf.advance(Encode::::size_hint(&start)); - let end = decode_coordinate_binary(buf)?; + let start = decode_coordinate_binary(&buf[..16])?; + let end = decode_coordinate_binary(&buf[16..])?; Ok(Line::new(start, end)) } @@ -95,17 +98,26 @@ impl Decode<'_, Postgres> for Line { // TODO: is there no way to make this make use of the Decode for Coordinate? PgValueFormat::Text => { let brackets: &[_] = &['[', ']']; - let mut s = value.as_str()? + let mut s = value + .as_str()? .trim_matches(brackets) .split(|c| c == '(' || c == ')' || c == ',') .filter_map(|part| if part == "" { None } else { Some(part) }); match (s.next(), s.next(), s.next(), s.next()) { (Some(x1), Some(y1), Some(x2), Some(y2)) => { - let x1 = x1.parse().map_err(|e: ParseFloatError| crate::error::Error::Decode(e.into()))?; - let y1 = y1.parse().map_err(|e: ParseFloatError| crate::error::Error::Decode(e.into()))?; - let x2 = x2.parse().map_err(|e: ParseFloatError| crate::error::Error::Decode(e.into()))?; - let y2 = y2.parse().map_err(|e: ParseFloatError| crate::error::Error::Decode(e.into()))?; + let x1 = x1 + .parse() + .map_err(|e: ParseFloatError| crate::error::Error::Decode(e.into()))?; + let y1 = y1 + .parse() + .map_err(|e: ParseFloatError| crate::error::Error::Decode(e.into()))?; + let x2 = x2 + .parse() + .map_err(|e: ParseFloatError| crate::error::Error::Decode(e.into()))?; + let y2 = y2 + .parse() + .map_err(|e: ParseFloatError| crate::error::Error::Decode(e.into()))?; let start = Coordinate::from((x1, y1)); let end = Coordinate::from((x2, y2)); diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index fe3f8e65b..d4d03763a 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -386,6 +386,24 @@ test_type!(decimal(Postgres, "12345.6789::numeric" == sqlx::types::Decimal::from_str("12345.6789").unwrap(), )); +#[cfg(feature = "geo")] +mod geometric { + use super::*; + use geo_::{Coordinate, Line}; + + test_type!(point>(Postgres, + "SELECT ({0} ~= $1)::int4, {0} as _2, $2 as _3", + "point (1, 5)" == Coordinate::::from((1.0, 5.0)), + "point (1.5, 7)" == Coordinate::::from((1.5, 7.0)), + "point (5.0, 12.5)" == Coordinate::::from((5.0, 12.5)), + )); + + test_type!(line>(Postgres, + "lseg (point (1, 0), point (2, 0))" == Line::::new((1.0, 0.0), (2.0, 0.0)), + "lseg (point (2, 0), point (1, 0))" == Line::::new((2.0, 0.0), (1.0, 0.0)), + )); +} + const EXC2: Bound = Bound::Excluded(2); const EXC3: Bound = Bound::Excluded(3); const INC1: Bound = Bound::Included(1);