Merge remote-tracking branch 'origin/main' into sqlx-toml

# Conflicts:
#	.github/workflows/examples.yml
#	sqlx-postgres/src/connection/mod.rs
This commit is contained in:
Austin Bonander
2025-03-29 15:57:39 -07:00
68 changed files with 2465 additions and 187 deletions

View File

@@ -19,6 +19,7 @@ offline = ["sqlx-core/offline"]
bigdecimal = ["dep:bigdecimal", "dep:num-bigint", "sqlx-core/bigdecimal"]
bit-vec = ["dep:bit-vec", "sqlx-core/bit-vec"]
chrono = ["dep:chrono", "sqlx-core/chrono"]
ipnet = ["dep:ipnet", "sqlx-core/ipnet"]
ipnetwork = ["dep:ipnetwork", "sqlx-core/ipnetwork"]
mac_address = ["dep:mac_address", "sqlx-core/mac_address"]
rust_decimal = ["dep:rust_decimal", "rust_decimal/maths", "sqlx-core/rust_decimal"]
@@ -43,6 +44,7 @@ sha2 = { version = "0.10.0", default-features = false }
bigdecimal = { workspace = true, optional = true }
bit-vec = { workspace = true, optional = true }
chrono = { workspace = true, optional = true }
ipnet = { workspace = true, optional = true }
ipnetwork = { workspace = true, optional = true }
mac_address = { workspace = true, optional = true }
rust_decimal = { workspace = true, optional = true }

View File

@@ -5,6 +5,7 @@ use crate::{
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::{stream, StreamExt, TryFutureExt, TryStreamExt};
use std::borrow::Cow;
use std::{future, pin::pin};
use sqlx_core::any::{
@@ -39,8 +40,11 @@ impl AnyConnectionBackend for PgConnection {
Connection::ping(self)
}
fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
PgTransactionManager::begin(self)
fn begin(
&mut self,
statement: Option<Cow<'static, str>>,
) -> BoxFuture<'_, sqlx_core::Result<()>> {
PgTransactionManager::begin(self, statement)
}
fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
@@ -55,6 +59,10 @@ impl AnyConnectionBackend for PgConnection {
PgTransactionManager::start_rollback(self)
}
fn get_transaction_depth(&self) -> usize {
PgTransactionManager::get_transaction_depth(self)
}
fn shrink_buffers(&mut self) {
Connection::shrink_buffers(self);
}

View File

@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
@@ -135,6 +136,13 @@ impl PgConnection {
Ok(())
}
pub(crate) fn in_transaction(&self) -> bool {
match self.inner.transaction_status {
TransactionStatus::Transaction => true,
TransactionStatus::Error | TransactionStatus::Idle => false,
}
}
}
impl Debug for PgConnection {
@@ -187,7 +195,17 @@ impl Connection for PgConnection {
where
Self: Sized,
{
Transaction::begin(self)
Transaction::begin(self, None)
}
fn begin_with(
&mut self,
statement: impl Into<Cow<'static, str>>,
) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
where
Self: Sized,
{
Transaction::begin(self, Some(statement.into()))
}
fn cached_statements_size(&self) -> usize {

View File

@@ -9,6 +9,7 @@ use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
use sqlx_core::acquire::Acquire;
use sqlx_core::transaction::Transaction;
use sqlx_core::Either;
use tracing::Instrument;
use crate::describe::Describe;
use crate::error::Error;
@@ -366,7 +367,7 @@ impl Drop for PgListener {
};
// Unregister any listeners before returning the connection to the pool.
crate::rt::spawn(fut);
crate::rt::spawn(fut.in_current_span());
}
}
}

View File

@@ -1,4 +1,6 @@
use futures_core::future::BoxFuture;
use sqlx_core::database::Database;
use std::borrow::Cow;
use crate::error::Error;
use crate::executor::Executor;
@@ -13,13 +15,27 @@ pub struct PgTransactionManager;
impl TransactionManager for PgTransactionManager {
type Database = Postgres;
fn begin(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> {
fn begin<'conn>(
conn: &'conn mut PgConnection,
statement: Option<Cow<'static, str>>,
) -> BoxFuture<'conn, Result<(), Error>> {
Box::pin(async move {
let depth = conn.inner.transaction_depth;
let statement = match statement {
// custom `BEGIN` statements are not allowed if we're already in
// a transaction (we need to issue a `SAVEPOINT` instead)
Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement),
Some(statement) => statement,
None => begin_ansi_transaction_sql(depth),
};
let rollback = Rollback::new(conn);
let query = begin_ansi_transaction_sql(rollback.conn.inner.transaction_depth);
rollback.conn.queue_simple_query(&query)?;
rollback.conn.inner.transaction_depth += 1;
rollback.conn.queue_simple_query(&statement)?;
rollback.conn.wait_until_ready().await?;
if !rollback.conn.in_transaction() {
return Err(Error::BeginFailed);
}
rollback.conn.inner.transaction_depth += 1;
rollback.defuse();
Ok(())
@@ -62,6 +78,10 @@ impl TransactionManager for PgTransactionManager {
conn.inner.transaction_depth -= 1;
}
}
fn get_transaction_depth(conn: &<Self::Database as Database>::Connection) -> usize {
conn.inner.transaction_depth
}
}
struct Rollback<'c> {

View File

@@ -40,12 +40,21 @@ impl_type_checking!(
sqlx::postgres::types::PgBox,
sqlx::postgres::types::PgPath,
sqlx::postgres::types::PgPolygon,
sqlx::postgres::types::PgCircle,
#[cfg(feature = "uuid")]
sqlx::types::Uuid,
#[cfg(feature = "ipnetwork")]
sqlx::types::ipnetwork::IpNetwork,
#[cfg(feature = "ipnet")]
sqlx::types::ipnet::IpNet,
#[cfg(feature = "mac_address")]
sqlx::types::mac_address::MacAddress,
@@ -77,6 +86,9 @@ impl_type_checking!(
#[cfg(feature = "ipnetwork")]
Vec<sqlx::types::ipnetwork::IpNetwork> | &[sqlx::types::ipnetwork::IpNetwork],
#[cfg(feature = "ipnet")]
Vec<sqlx::types::ipnet::IpNet> | &[sqlx::types::ipnet::IpNet],
#[cfg(feature = "mac_address")]
Vec<sqlx::types::mac_address::MacAddress> | &[sqlx::types::mac_address::MacAddress],

View File

@@ -23,7 +23,10 @@ const ERROR: &str = "error decoding BOX";
/// where `(upper_right_x,upper_right_y) and (lower_left_x,lower_left_y)` are any two opposite corners of the box.
/// Any two opposite corners can be supplied on input, but the values will be reordered as needed to store the upper right and lower left corners, in that order.
///
/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-BOXES
/// See [Postgres Manual, Section 8.8.4: Geometric Types - Boxes][PG.S.8.8.4] for details.
///
/// [PG.S.8.8.4]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-GEOMETRIC-BOXES
///
#[derive(Debug, Clone, PartialEq)]
pub struct PgBox {
pub upper_right_x: f64,

View File

@@ -0,0 +1,250 @@
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
use sqlx_core::bytes::Buf;
use sqlx_core::Error;
use std::str::FromStr;
const ERROR: &str = "error decoding CIRCLE";
/// ## Postgres Geometric Circle type
///
/// Description: Circle
/// Representation: `< (x, y), radius >` (center point and radius)
///
/// ```text
/// < ( x , y ) , radius >
/// ( ( x , y ) , radius )
/// ( x , y ) , radius
/// x , y , radius
/// ```
/// where `(x,y)` is the center point.
///
/// See [Postgres Manual, Section 8.8.7, Geometric Types - Circles][PG.S.8.8.7] for details.
///
/// [PG.S.8.8.7]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-CIRCLE
///
#[derive(Debug, Clone, PartialEq)]
pub struct PgCircle {
pub x: f64,
pub y: f64,
pub radius: f64,
}
impl Type<Postgres> for PgCircle {
fn type_info() -> PgTypeInfo {
PgTypeInfo::with_name("circle")
}
}
impl PgHasArrayType for PgCircle {
fn array_type_info() -> PgTypeInfo {
PgTypeInfo::with_name("_circle")
}
}
impl<'r> Decode<'r, Postgres> for PgCircle {
fn decode(value: PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
match value.format() {
PgValueFormat::Text => Ok(PgCircle::from_str(value.as_str()?)?),
PgValueFormat::Binary => Ok(PgCircle::from_bytes(value.as_bytes()?)?),
}
}
}
impl<'q> Encode<'q, Postgres> for PgCircle {
fn produces(&self) -> Option<PgTypeInfo> {
Some(PgTypeInfo::with_name("circle"))
}
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
self.serialize(buf)?;
Ok(IsNull::No)
}
}
impl FromStr for PgCircle {
type Err = BoxDynError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let sanitised = s.replace(['<', '>', '(', ')', ' '], "");
let mut parts = sanitised.split(',');
let x = parts
.next()
.and_then(|s| s.trim().parse::<f64>().ok())
.ok_or_else(|| format!("{}: could not get x from {}", ERROR, s))?;
let y = parts
.next()
.and_then(|s| s.trim().parse::<f64>().ok())
.ok_or_else(|| format!("{}: could not get y from {}", ERROR, s))?;
let radius = parts
.next()
.and_then(|s| s.trim().parse::<f64>().ok())
.ok_or_else(|| format!("{}: could not get radius from {}", ERROR, s))?;
if parts.next().is_some() {
return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into());
}
if radius < 0. {
return Err(format!("{}: cannot have negative radius: {}", ERROR, s).into());
}
Ok(PgCircle { x, y, radius })
}
}
impl PgCircle {
fn from_bytes(mut bytes: &[u8]) -> Result<PgCircle, Error> {
let x = bytes.get_f64();
let y = bytes.get_f64();
let r = bytes.get_f64();
Ok(PgCircle { x, y, radius: r })
}
fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), Error> {
buff.extend_from_slice(&self.x.to_be_bytes());
buff.extend_from_slice(&self.y.to_be_bytes());
buff.extend_from_slice(&self.radius.to_be_bytes());
Ok(())
}
#[cfg(test)]
fn serialize_to_vec(&self) -> Vec<u8> {
let mut buff = PgArgumentBuffer::default();
self.serialize(&mut buff).unwrap();
buff.to_vec()
}
}
#[cfg(test)]
mod circle_tests {
use std::str::FromStr;
use super::PgCircle;
const CIRCLE_BYTES: &[u8] = &[
63, 241, 153, 153, 153, 153, 153, 154, 64, 1, 153, 153, 153, 153, 153, 154, 64, 10, 102,
102, 102, 102, 102, 102,
];
#[test]
fn can_deserialise_circle_type_bytes() {
let circle = PgCircle::from_bytes(CIRCLE_BYTES).unwrap();
assert_eq!(
circle,
PgCircle {
x: 1.1,
y: 2.2,
radius: 3.3
}
)
}
#[test]
fn can_deserialise_circle_type_str() {
let circle = PgCircle::from_str("<(1, 2), 3 >").unwrap();
assert_eq!(
circle,
PgCircle {
x: 1.0,
y: 2.0,
radius: 3.0
}
);
}
#[test]
fn can_deserialise_circle_type_str_second_syntax() {
let circle = PgCircle::from_str("((1, 2), 3 )").unwrap();
assert_eq!(
circle,
PgCircle {
x: 1.0,
y: 2.0,
radius: 3.0
}
);
}
#[test]
fn can_deserialise_circle_type_str_third_syntax() {
let circle = PgCircle::from_str("(1, 2), 3 ").unwrap();
assert_eq!(
circle,
PgCircle {
x: 1.0,
y: 2.0,
radius: 3.0
}
);
}
#[test]
fn can_deserialise_circle_type_str_fourth_syntax() {
let circle = PgCircle::from_str("1, 2, 3 ").unwrap();
assert_eq!(
circle,
PgCircle {
x: 1.0,
y: 2.0,
radius: 3.0
}
);
}
#[test]
fn cannot_deserialise_circle_invalid_numbers() {
let input_str = "1, 2, Three";
let circle = PgCircle::from_str(input_str);
assert!(circle.is_err());
if let Err(err) = circle {
assert_eq!(
err.to_string(),
format!("error decoding CIRCLE: could not get radius from {input_str}")
)
}
}
#[test]
fn cannot_deserialise_circle_negative_radius() {
let input_str = "1, 2, -3";
let circle = PgCircle::from_str(input_str);
assert!(circle.is_err());
if let Err(err) = circle {
assert_eq!(
err.to_string(),
format!("error decoding CIRCLE: cannot have negative radius: {input_str}")
)
}
}
#[test]
fn can_deserialise_circle_type_str_float() {
let circle = PgCircle::from_str("<(1.1, 2.2), 3.3>").unwrap();
assert_eq!(
circle,
PgCircle {
x: 1.1,
y: 2.2,
radius: 3.3
}
);
}
#[test]
fn can_serialise_circle_type() {
let circle = PgCircle {
x: 1.1,
y: 2.2,
radius: 3.3,
};
assert_eq!(circle.serialize_to_vec(), CIRCLE_BYTES,)
}
}

View File

@@ -15,7 +15,10 @@ const ERROR: &str = "error decoding LINE";
///
/// Lines are represented by the linear equation Ax + By + C = 0, where A and B are not both zero.
///
/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LINE
/// See [Postgres Manual, Section 8.8.2, Geometric Types - Lines][PG.S.8.8.2] for details.
///
/// [PG.S.8.8.2]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-LINE
///
#[derive(Debug, Clone, PartialEq)]
pub struct PgLine {
pub a: f64,

View File

@@ -23,7 +23,10 @@ const ERROR: &str = "error decoding LSEG";
/// ```
/// where `(start_x,start_y) and (end_x,end_y)` are the end points of the line segment.
///
/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-LSEG
/// See [Postgres Manual, Section 8.8.3, Geometric Types - Line Segments][PG.S.8.8.3] for details.
///
/// [PG.S.8.8.3]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-LSEG
///
#[doc(alias = "line segment")]
#[derive(Debug, Clone, PartialEq)]
pub struct PgLSeg {

View File

@@ -1,4 +1,7 @@
pub mod r#box;
pub mod circle;
pub mod line;
pub mod line_segment;
pub mod path;
pub mod point;
pub mod polygon;

View File

@@ -0,0 +1,375 @@
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::types::{PgPoint, Type};
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
use sqlx_core::bytes::Buf;
use sqlx_core::Error;
use std::mem;
use std::str::FromStr;
const BYTE_WIDTH: usize = mem::size_of::<f64>();
/// ## Postgres Geometric Path type
///
/// Description: Open path or Closed path (similar to polygon)
/// Representation: Open `[(x1,y1),...]`, Closed `((x1,y1),...)`
///
/// Paths are represented by lists of connected points. Paths can be open, where the first and last points in the list are considered not connected, or closed, where the first and last points are considered connected.
/// Values of type path are specified using any of the following syntaxes:
/// ```text
/// [ ( x1 , y1 ) , ... , ( xn , yn ) ]
/// ( ( x1 , y1 ) , ... , ( xn , yn ) )
/// ( x1 , y1 ) , ... , ( xn , yn )
/// ( x1 , y1 , ... , xn , yn )
/// x1 , y1 , ... , xn , yn
/// ```
/// where the points are the end points of the line segments comprising the path. Square brackets `([])` indicate an open path, while parentheses `(())` indicate a closed path.
/// When the outermost parentheses are omitted, as in the third through fifth syntaxes, a closed path is assumed.
///
/// See [Postgres Manual, Section 8.8.5, Geometric Types - Paths][PG.S.8.8.5] for details.
///
/// [PG.S.8.8.5]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-GEOMETRIC-PATHS
///
#[derive(Debug, Clone, PartialEq)]
pub struct PgPath {
pub closed: bool,
pub points: Vec<PgPoint>,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct Header {
is_closed: bool,
length: usize,
}
impl Type<Postgres> for PgPath {
fn type_info() -> PgTypeInfo {
PgTypeInfo::with_name("path")
}
}
impl PgHasArrayType for PgPath {
fn array_type_info() -> PgTypeInfo {
PgTypeInfo::with_name("_path")
}
}
impl<'r> Decode<'r, Postgres> for PgPath {
fn decode(value: PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
match value.format() {
PgValueFormat::Text => Ok(PgPath::from_str(value.as_str()?)?),
PgValueFormat::Binary => Ok(PgPath::from_bytes(value.as_bytes()?)?),
}
}
}
impl<'q> Encode<'q, Postgres> for PgPath {
fn produces(&self) -> Option<PgTypeInfo> {
Some(PgTypeInfo::with_name("path"))
}
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
self.serialize(buf)?;
Ok(IsNull::No)
}
}
impl FromStr for PgPath {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let closed = !s.contains('[');
let sanitised = s.replace(['(', ')', '[', ']', ' '], "");
let parts = sanitised.split(',').collect::<Vec<_>>();
let mut points = vec![];
if parts.len() % 2 != 0 {
return Err(Error::Decode(
format!("Unmatched pair in PATH: {}", s).into(),
));
}
for chunk in parts.chunks_exact(2) {
if let [x_str, y_str] = chunk {
let x = parse_float_from_str(x_str, "could not get x")?;
let y = parse_float_from_str(y_str, "could not get y")?;
let point = PgPoint { x, y };
points.push(point);
}
}
if !points.is_empty() {
return Ok(PgPath { points, closed });
}
Err(Error::Decode(
format!("could not get path from {}", s).into(),
))
}
}
impl PgPath {
fn header(&self) -> Header {
Header {
is_closed: self.closed,
length: self.points.len(),
}
}
fn from_bytes(mut bytes: &[u8]) -> Result<Self, BoxDynError> {
let header = Header::try_read(&mut bytes)?;
if bytes.len() != header.data_size() {
return Err(format!(
"expected {} bytes after header, got {}",
header.data_size(),
bytes.len()
)
.into());
}
if bytes.len() % BYTE_WIDTH * 2 != 0 {
return Err(format!(
"data length not divisible by pairs of {BYTE_WIDTH}: {}",
bytes.len()
)
.into());
}
let mut out_points = Vec::with_capacity(bytes.len() / (BYTE_WIDTH * 2));
while bytes.has_remaining() {
let point = PgPoint {
x: bytes.get_f64(),
y: bytes.get_f64(),
};
out_points.push(point)
}
Ok(PgPath {
closed: header.is_closed,
points: out_points,
})
}
fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> {
let header = self.header();
buff.reserve(header.data_size());
header.try_write(buff)?;
for point in &self.points {
buff.extend_from_slice(&point.x.to_be_bytes());
buff.extend_from_slice(&point.y.to_be_bytes());
}
Ok(())
}
#[cfg(test)]
fn serialize_to_vec(&self) -> Vec<u8> {
let mut buff = PgArgumentBuffer::default();
self.serialize(&mut buff).unwrap();
buff.to_vec()
}
}
impl Header {
const HEADER_WIDTH: usize = mem::size_of::<i8>() + mem::size_of::<i32>();
fn data_size(&self) -> usize {
self.length * BYTE_WIDTH * 2
}
fn try_read(buf: &mut &[u8]) -> Result<Self, String> {
if buf.len() < Self::HEADER_WIDTH {
return Err(format!(
"expected PATH data to contain at least {} bytes, got {}",
Self::HEADER_WIDTH,
buf.len()
));
}
let is_closed = buf.get_i8();
let length = buf.get_i32();
let length = usize::try_from(length).ok().ok_or_else(|| {
format!(
"received PATH data length: {length}. Expected length between 0 and {}",
usize::MAX
)
})?;
Ok(Self {
is_closed: is_closed != 0,
length,
})
}
fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> {
let is_closed = self.is_closed as i8;
let length = i32::try_from(self.length).map_err(|_| {
format!(
"PATH length exceeds allowed maximum ({} > {})",
self.length,
i32::MAX
)
})?;
buff.extend(is_closed.to_be_bytes());
buff.extend(length.to_be_bytes());
Ok(())
}
}
fn parse_float_from_str(s: &str, error_msg: &str) -> Result<f64, Error> {
s.parse().map_err(|_| Error::Decode(error_msg.into()))
}
#[cfg(test)]
mod path_tests {
use std::str::FromStr;
use crate::types::PgPoint;
use super::PgPath;
const PATH_CLOSED_BYTES: &[u8] = &[
1, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
64, 16, 0, 0, 0, 0, 0, 0,
];
const PATH_OPEN_BYTES: &[u8] = &[
0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
64, 16, 0, 0, 0, 0, 0, 0,
];
const PATH_UNEVEN_POINTS: &[u8] = &[
0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
64, 16, 0, 0,
];
#[test]
fn can_deserialise_path_type_bytes_closed() {
let path = PgPath::from_bytes(PATH_CLOSED_BYTES).unwrap();
assert_eq!(
path,
PgPath {
closed: true,
points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }]
}
)
}
#[test]
fn cannot_deserialise_path_type_uneven_point_bytes() {
let path = PgPath::from_bytes(PATH_UNEVEN_POINTS);
assert!(path.is_err());
if let Err(err) = path {
assert_eq!(
err.to_string(),
format!("expected 32 bytes after header, got 28")
)
}
}
#[test]
fn can_deserialise_path_type_bytes_open() {
let path = PgPath::from_bytes(PATH_OPEN_BYTES).unwrap();
assert_eq!(
path,
PgPath {
closed: false,
points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }]
}
)
}
#[test]
fn can_deserialise_path_type_str_first_syntax() {
let path = PgPath::from_str("[( 1, 2), (3, 4 )]").unwrap();
assert_eq!(
path,
PgPath {
closed: false,
points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
}
);
}
#[test]
fn cannot_deserialise_path_type_str_uneven_points_first_syntax() {
let input_str = "[( 1, 2), (3)]";
let path = PgPath::from_str(input_str);
assert!(path.is_err());
if let Err(err) = path {
assert_eq!(
err.to_string(),
format!("error occurred while decoding: Unmatched pair in PATH: {input_str}")
)
}
}
#[test]
fn can_deserialise_path_type_str_second_syntax() {
let path = PgPath::from_str("(( 1, 2), (3, 4 ))").unwrap();
assert_eq!(
path,
PgPath {
closed: true,
points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
}
);
}
#[test]
fn can_deserialise_path_type_str_third_syntax() {
let path = PgPath::from_str("(1, 2), (3, 4 )").unwrap();
assert_eq!(
path,
PgPath {
closed: true,
points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
}
);
}
#[test]
fn can_deserialise_path_type_str_fourth_syntax() {
let path = PgPath::from_str("1, 2, 3, 4").unwrap();
assert_eq!(
path,
PgPath {
closed: true,
points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
}
);
}
#[test]
fn can_deserialise_path_type_str_float() {
let path = PgPath::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap();
assert_eq!(
path,
PgPath {
closed: true,
points: vec![PgPoint { x: 1.1, y: 2.2 }, PgPoint { x: 3.3, y: 4.4 }]
}
);
}
#[test]
fn can_serialise_path_type() {
let path = PgPath {
closed: true,
points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }],
};
assert_eq!(path.serialize_to_vec(), PATH_CLOSED_BYTES,)
}
}

View File

@@ -19,7 +19,10 @@ use std::str::FromStr;
/// ````
/// where x and y are the respective coordinates, as floating-point numbers.
///
/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-POINTS
/// See [Postgres Manual, Section 8.8.1, Geometric Types - Points][PG.S.8.8.1] for details.
///
/// [PG.S.8.8.1]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-GEOMETRIC-POINTS
///
#[derive(Debug, Clone, PartialEq)]
pub struct PgPoint {
pub x: f64,

View File

@@ -0,0 +1,366 @@
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::types::{PgPoint, Type};
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
use sqlx_core::bytes::Buf;
use sqlx_core::Error;
use std::mem;
use std::str::FromStr;
const BYTE_WIDTH: usize = mem::size_of::<f64>();
/// ## Postgres Geometric Polygon type
///
/// Description: Polygon (similar to closed polygon)
/// Representation: `((x1,y1),...)`
///
/// Polygons are represented by lists of points (the vertexes of the polygon). Polygons are very similar to closed paths; the essential semantic difference is that a polygon is considered to include the area within it, while a path is not.
/// An important implementation difference between polygons and paths is that the stored representation of a polygon includes its smallest bounding box. This speeds up certain search operations, although computing the bounding box adds overhead while constructing new polygons.
/// Values of type polygon are specified using any of the following syntaxes:
///
/// ```text
/// ( ( x1 , y1 ) , ... , ( xn , yn ) )
/// ( x1 , y1 ) , ... , ( xn , yn )
/// ( x1 , y1 , ... , xn , yn )
/// x1 , y1 , ... , xn , yn
/// ```
///
/// where the points are the end points of the line segments comprising the boundary of the polygon.
///
/// See [Postgres Manual, Section 8.8.6, Geometric Types - Polygons][PG.S.8.8.6] for details.
///
/// [PG.S.8.8.6]: https://www.postgresql.org/docs/current/datatype-geometric.html#DATATYPE-POLYGON
///
#[derive(Debug, Clone, PartialEq)]
pub struct PgPolygon {
pub points: Vec<PgPoint>,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct Header {
length: usize,
}
impl Type<Postgres> for PgPolygon {
fn type_info() -> PgTypeInfo {
PgTypeInfo::with_name("polygon")
}
}
impl PgHasArrayType for PgPolygon {
fn array_type_info() -> PgTypeInfo {
PgTypeInfo::with_name("_polygon")
}
}
impl<'r> Decode<'r, Postgres> for PgPolygon {
fn decode(value: PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
match value.format() {
PgValueFormat::Text => Ok(PgPolygon::from_str(value.as_str()?)?),
PgValueFormat::Binary => Ok(PgPolygon::from_bytes(value.as_bytes()?)?),
}
}
}
impl<'q> Encode<'q, Postgres> for PgPolygon {
fn produces(&self) -> Option<PgTypeInfo> {
Some(PgTypeInfo::with_name("polygon"))
}
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
self.serialize(buf)?;
Ok(IsNull::No)
}
}
impl FromStr for PgPolygon {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let sanitised = s.replace(['(', ')', '[', ']', ' '], "");
let parts = sanitised.split(',').collect::<Vec<_>>();
let mut points = vec![];
if parts.len() % 2 != 0 {
return Err(Error::Decode(
format!("Unmatched pair in POLYGON: {}", s).into(),
));
}
for chunk in parts.chunks_exact(2) {
if let [x_str, y_str] = chunk {
let x = parse_float_from_str(x_str, "could not get x")?;
let y = parse_float_from_str(y_str, "could not get y")?;
let point = PgPoint { x, y };
points.push(point);
}
}
if !points.is_empty() {
return Ok(PgPolygon { points });
}
Err(Error::Decode(
format!("could not get polygon from {}", s).into(),
))
}
}
impl PgPolygon {
fn header(&self) -> Header {
Header {
length: self.points.len(),
}
}
fn from_bytes(mut bytes: &[u8]) -> Result<Self, BoxDynError> {
let header = Header::try_read(&mut bytes)?;
if bytes.len() != header.data_size() {
return Err(format!(
"expected {} bytes after header, got {}",
header.data_size(),
bytes.len()
)
.into());
}
if bytes.len() % BYTE_WIDTH * 2 != 0 {
return Err(format!(
"data length not divisible by pairs of {BYTE_WIDTH}: {}",
bytes.len()
)
.into());
}
let mut out_points = Vec::with_capacity(bytes.len() / (BYTE_WIDTH * 2));
while bytes.has_remaining() {
let point = PgPoint {
x: bytes.get_f64(),
y: bytes.get_f64(),
};
out_points.push(point)
}
Ok(PgPolygon { points: out_points })
}
fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> {
let header = self.header();
buff.reserve(header.data_size());
header.try_write(buff)?;
for point in &self.points {
buff.extend_from_slice(&point.x.to_be_bytes());
buff.extend_from_slice(&point.y.to_be_bytes());
}
Ok(())
}
#[cfg(test)]
fn serialize_to_vec(&self) -> Vec<u8> {
let mut buff = PgArgumentBuffer::default();
self.serialize(&mut buff).unwrap();
buff.to_vec()
}
}
impl Header {
const HEADER_WIDTH: usize = mem::size_of::<i8>() + mem::size_of::<i32>();
fn data_size(&self) -> usize {
self.length * BYTE_WIDTH * 2
}
fn try_read(buf: &mut &[u8]) -> Result<Self, String> {
if buf.len() < Self::HEADER_WIDTH {
return Err(format!(
"expected polygon data to contain at least {} bytes, got {}",
Self::HEADER_WIDTH,
buf.len()
));
}
let length = buf.get_i32();
let length = usize::try_from(length).ok().ok_or_else(|| {
format!(
"received polygon with length: {length}. Expected length between 0 and {}",
usize::MAX
)
})?;
Ok(Self { length })
}
fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> {
let length = i32::try_from(self.length).map_err(|_| {
format!(
"polygon length exceeds allowed maximum ({} > {})",
self.length,
i32::MAX
)
})?;
buff.extend(length.to_be_bytes());
Ok(())
}
}
fn parse_float_from_str(s: &str, error_msg: &str) -> Result<f64, Error> {
s.parse().map_err(|_| Error::Decode(error_msg.into()))
}
#[cfg(test)]
mod polygon_tests {
use std::str::FromStr;
use crate::types::PgPoint;
use super::PgPolygon;
const POLYGON_BYTES: &[u8] = &[
0, 0, 0, 12, 192, 0, 0, 0, 0, 0, 0, 0, 192, 8, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0,
0, 192, 8, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 63,
240, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0,
0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 192,
8, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0, 0, 0, 0, 192, 8, 0, 0, 0, 0, 0, 0, 63, 240, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 191, 240, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 191,
240, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0,
0, 0, 0,
];
#[test]
fn can_deserialise_polygon_type_bytes() {
let polygon = PgPolygon::from_bytes(POLYGON_BYTES).unwrap();
assert_eq!(
polygon,
PgPolygon {
points: vec![
PgPoint { x: -2., y: -3. },
PgPoint { x: -1., y: -3. },
PgPoint { x: -1., y: -1. },
PgPoint { x: 1., y: 1. },
PgPoint { x: 1., y: 3. },
PgPoint { x: 2., y: 3. },
PgPoint { x: 2., y: -3. },
PgPoint { x: 1., y: -3. },
PgPoint { x: 1., y: 0. },
PgPoint { x: -1., y: 0. },
PgPoint { x: -1., y: -2. },
PgPoint { x: -2., y: -2. }
]
}
)
}
#[test]
fn can_deserialise_polygon_type_str_first_syntax() {
let polygon = PgPolygon::from_str("[( 1, 2), (3, 4 )]").unwrap();
assert_eq!(
polygon,
PgPolygon {
points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
}
);
}
#[test]
fn can_deserialise_polygon_type_str_second_syntax() {
let polygon = PgPolygon::from_str("(( 1, 2), (3, 4 ))").unwrap();
assert_eq!(
polygon,
PgPolygon {
points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
}
);
}
#[test]
fn cannot_deserialise_polygon_type_str_uneven_points_first_syntax() {
let input_str = "[( 1, 2), (3)]";
let polygon = PgPolygon::from_str(input_str);
assert!(polygon.is_err());
if let Err(err) = polygon {
assert_eq!(
err.to_string(),
format!("error occurred while decoding: Unmatched pair in POLYGON: {input_str}")
)
}
}
#[test]
fn cannot_deserialise_polygon_type_str_invalid_numbers() {
let input_str = "[( 1, 2), (2, three)]";
let polygon = PgPolygon::from_str(input_str);
assert!(polygon.is_err());
if let Err(err) = polygon {
assert_eq!(
err.to_string(),
format!("error occurred while decoding: could not get y")
)
}
}
#[test]
fn can_deserialise_polygon_type_str_third_syntax() {
let polygon = PgPolygon::from_str("(1, 2), (3, 4 )").unwrap();
assert_eq!(
polygon,
PgPolygon {
points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
}
);
}
#[test]
fn can_deserialise_polygon_type_str_fourth_syntax() {
let polygon = PgPolygon::from_str("1, 2, 3, 4").unwrap();
assert_eq!(
polygon,
PgPolygon {
points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
}
);
}
#[test]
fn can_deserialise_polygon_type_str_float() {
let polygon = PgPolygon::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap();
assert_eq!(
polygon,
PgPolygon {
points: vec![PgPoint { x: 1.1, y: 2.2 }, PgPoint { x: 3.3, y: 4.4 }]
}
);
}
#[test]
fn can_serialise_polygon_type() {
let polygon = PgPolygon {
points: vec![
PgPoint { x: -2., y: -3. },
PgPoint { x: -1., y: -3. },
PgPoint { x: -1., y: -1. },
PgPoint { x: 1., y: 1. },
PgPoint { x: 1., y: 3. },
PgPoint { x: 2., y: 3. },
PgPoint { x: 2., y: -3. },
PgPoint { x: 1., y: -3. },
PgPoint { x: 1., y: 0. },
PgPoint { x: -1., y: 0. },
PgPoint { x: -1., y: -2. },
PgPoint { x: -2., y: -2. },
],
};
assert_eq!(polygon.serialize_to_vec(), POLYGON_BYTES,)
}
}

View File

@@ -0,0 +1,62 @@
use std::net::IpAddr;
use ipnet::IpNet;
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres};
impl Type<Postgres> for IpAddr
where
IpNet: Type<Postgres>,
{
fn type_info() -> PgTypeInfo {
IpNet::type_info()
}
fn compatible(ty: &PgTypeInfo) -> bool {
IpNet::compatible(ty)
}
}
impl PgHasArrayType for IpAddr {
fn array_type_info() -> PgTypeInfo {
<IpNet as PgHasArrayType>::array_type_info()
}
fn array_compatible(ty: &PgTypeInfo) -> bool {
<IpNet as PgHasArrayType>::array_compatible(ty)
}
}
impl<'db> Encode<'db, Postgres> for IpAddr
where
IpNet: Encode<'db, Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
IpNet::from(*self).encode_by_ref(buf)
}
fn size_hint(&self) -> usize {
IpNet::from(*self).size_hint()
}
}
impl<'db> Decode<'db, Postgres> for IpAddr
where
IpNet: Decode<'db, Postgres>,
{
fn decode(value: PgValueRef<'db>) -> Result<Self, BoxDynError> {
let ipnet = IpNet::decode(value)?;
if matches!(ipnet, IpNet::V4(net) if net.prefix_len() != 32)
|| matches!(ipnet, IpNet::V6(net) if net.prefix_len() != 128)
{
Err("lossy decode from inet/cidr")?
}
Ok(ipnet.addr())
}
}

View File

@@ -0,0 +1,130 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
#[cfg(feature = "ipnet")]
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/include/utils/inet.h#L39
// Technically this is a magic number here but it doesn't make sense to drag in the whole of `libc`
// just for one constant.
const PGSQL_AF_INET: u8 = 2; // AF_INET
const PGSQL_AF_INET6: u8 = PGSQL_AF_INET + 1;
impl Type<Postgres> for IpNet {
fn type_info() -> PgTypeInfo {
PgTypeInfo::INET
}
fn compatible(ty: &PgTypeInfo) -> bool {
*ty == PgTypeInfo::CIDR || *ty == PgTypeInfo::INET
}
}
impl PgHasArrayType for IpNet {
fn array_type_info() -> PgTypeInfo {
PgTypeInfo::INET_ARRAY
}
fn array_compatible(ty: &PgTypeInfo) -> bool {
*ty == PgTypeInfo::CIDR_ARRAY || *ty == PgTypeInfo::INET_ARRAY
}
}
impl Encode<'_, Postgres> for IpNet {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L293
// https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L271
match self {
IpNet::V4(net) => {
buf.push(PGSQL_AF_INET); // ip_family
buf.push(net.prefix_len()); // ip_bits
buf.push(0); // is_cidr
buf.push(4); // nb (number of bytes)
buf.extend_from_slice(&net.addr().octets()) // address
}
IpNet::V6(net) => {
buf.push(PGSQL_AF_INET6); // ip_family
buf.push(net.prefix_len()); // ip_bits
buf.push(0); // is_cidr
buf.push(16); // nb (number of bytes)
buf.extend_from_slice(&net.addr().octets()); // address
}
}
Ok(IsNull::No)
}
fn size_hint(&self) -> usize {
match self {
IpNet::V4(_) => 8,
IpNet::V6(_) => 20,
}
}
}
impl Decode<'_, Postgres> for IpNet {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
let bytes = match value.format() {
PgValueFormat::Binary => value.as_bytes()?,
PgValueFormat::Text => {
let s = value.as_str()?;
println!("{s}");
if s.contains('/') {
return Ok(s.parse()?);
}
// IpNet::from_str doesn't handle conversion from IpAddr to IpNet
let addr: IpAddr = s.parse()?;
return Ok(addr.into());
}
};
if bytes.len() >= 8 {
let family = bytes[0];
let prefix = bytes[1];
let _is_cidr = bytes[2] != 0;
let len = bytes[3];
match family {
PGSQL_AF_INET => {
if bytes.len() == 8 && len == 4 {
let inet = Ipv4Net::new(
Ipv4Addr::new(bytes[4], bytes[5], bytes[6], bytes[7]),
prefix,
)?;
return Ok(IpNet::V4(inet));
}
}
PGSQL_AF_INET6 => {
if bytes.len() == 20 && len == 16 {
let inet = Ipv6Net::new(
Ipv6Addr::from([
bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9],
bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
bytes[16], bytes[17], bytes[18], bytes[19],
]),
prefix,
)?;
return Ok(IpNet::V6(inet));
}
}
_ => {
return Err(format!("unknown ip family {family}").into());
}
}
}
Err("invalid data received when expecting an INET".into())
}
}

View File

@@ -0,0 +1,7 @@
// Prefer `ipnetwork` over `ipnet` because it was implemented first (want to avoid breaking change).
#[cfg(not(feature = "ipnetwork"))]
mod ipaddr;
// Parent module is named after the `ipnet` crate, this is named after the `IpNet` type.
#[allow(clippy::module_inception)]
mod ipnet;

View File

@@ -0,0 +1,5 @@
mod ipaddr;
// Parent module is named after the `ipnetwork` crate, this is named after the `IpNetwork` type.
#[allow(clippy::module_inception)]
mod ipnetwork;

View File

@@ -25,6 +25,9 @@
//! | [`PgLine`] | LINE |
//! | [`PgLSeg`] | LSEG |
//! | [`PgBox`] | BOX |
//! | [`PgPath`] | PATH |
//! | [`PgPolygon`] | POLYGON |
//! | [`PgCircle`] | CIRCLE |
//! | [`PgHstore`] | HSTORE |
//!
//! <sup>1</sup> SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc.,
@@ -84,7 +87,7 @@
//!
//! ### [`ipnetwork`](https://crates.io/crates/ipnetwork)
//!
//! Requires the `ipnetwork` Cargo feature flag.
//! Requires the `ipnetwork` Cargo feature flag (takes precedence over `ipnet` if both are used).
//!
//! | Rust type | Postgres type(s) |
//! |---------------------------------------|------------------------------------------------------|
@@ -97,6 +100,17 @@
//!
//! `IpNetwork` does not have this limitation.
//!
//! ### [`ipnet`](https://crates.io/crates/ipnet)
//!
//! Requires the `ipnet` Cargo feature flag.
//!
//! | Rust type | Postgres type(s) |
//! |---------------------------------------|------------------------------------------------------|
//! | `ipnet::IpNet` | INET, CIDR |
//! | `std::net::IpAddr` | INET, CIDR |
//!
//! The same `IpAddr` limitation for smaller network prefixes applies as with `ipnet`.
//!
//! ### [`mac_address`](https://crates.io/crates/mac_address)
//!
//! Requires the `mac_address` Cargo feature flag.
@@ -245,11 +259,11 @@ mod time;
#[cfg(feature = "uuid")]
mod uuid;
#[cfg(feature = "ipnetwork")]
mod ipnetwork;
#[cfg(feature = "ipnet")]
mod ipnet;
#[cfg(feature = "ipnetwork")]
mod ipaddr;
mod ipnetwork;
#[cfg(feature = "mac_address")]
mod mac_address;
@@ -260,9 +274,12 @@ mod bit_vec;
pub use array::PgHasArrayType;
pub use citext::PgCiText;
pub use cube::PgCube;
pub use geometry::circle::PgCircle;
pub use geometry::line::PgLine;
pub use geometry::line_segment::PgLSeg;
pub use geometry::path::PgPath;
pub use geometry::point::PgPoint;
pub use geometry::polygon::PgPolygon;
pub use geometry::r#box::PgBox;
pub use hstore::PgHstore;
pub use interval::PgInterval;