diff --git a/Cargo.lock b/Cargo.lock index 5fc0c369..51115be7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -829,6 +829,15 @@ dependencies = [ "libc", ] +[[package]] +name = "ipnetwork" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8eca9f51da27bc908ef3dd85c21e1bbba794edaf94d7841e37356275b82d31e" +dependencies = [ + "serde", +] + [[package]] name = "itoa" version = "0.4.5" @@ -1706,6 +1715,8 @@ dependencies = [ "generic-array", "hex", "hmac", + "ipnetwork", + "libc", "libsqlite3-sys", "log", "matches", diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 4a4d6227..9b16aaa6 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -18,6 +18,7 @@ unstable = [] # we need a feature which activates `num-bigint` as well because # `bigdecimal` uses types from it but does not reexport (tsk tsk) bigdecimal_bigint = ["bigdecimal", "num-bigint"] +network-address = [ "ipnetwork", "libc" ] postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ] mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ] sqlite = [ "libsqlite3-sys" ] @@ -43,6 +44,8 @@ futures-util = { version = "0.3.4", default-features = false } generic-array = { version = "0.12.3", default-features = false, optional = true } hex = "0.4.2" hmac = { version = "0.7.1", default-features = false, optional = true } +ipnetwork = { version = "0.16.0", default-feature = false, optional = true } +libc = { version = "0.2.68", default-feature = false, optional = true } log = { version = "0.4.8", default-features = false } md-5 = { version = "0.8.0", default-features = false, optional = true } memchr = { version = "2.3.3", default-features = false } diff --git a/sqlx-core/src/postgres/protocol/type_id.rs b/sqlx-core/src/postgres/protocol/type_id.rs index b2866c57..5a6b9578 100644 --- a/sqlx-core/src/postgres/protocol/type_id.rs +++ b/sqlx-core/src/postgres/protocol/type_id.rs @@ -33,6 +33,9 @@ impl TypeId { pub(crate) const UUID: TypeId = TypeId(2950); + pub(crate) const CIDR: TypeId = TypeId(650); + pub(crate) const INET: TypeId = TypeId(869); + // Arrays pub(crate) const ARRAY_BOOL: TypeId = TypeId(1000); @@ -56,4 +59,7 @@ impl TypeId { pub(crate) const ARRAY_BYTEA: TypeId = TypeId(1001); pub(crate) const ARRAY_UUID: TypeId = TypeId(2951); + + pub(crate) const ARRAY_CIDR: TypeId = TypeId(651); + pub(crate) const ARRAY_INET: TypeId = TypeId(1041); } diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 384a08b5..98574509 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -33,6 +33,14 @@ //! |---------------------------------------|------------------------------------------------------| //! | `uuid::Uuid` | UUID | //! +//! ### [`ipnetwork`](https://crates.io/crates/ipnetwork) +//! +//! Requires the `network-address` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `ipnetwork::IpNetwork` | INET, CIDR | +//! //! # Composite types //! //! Anonymous composite types are represented as tuples. @@ -70,6 +78,8 @@ mod chrono; #[cfg(feature = "uuid")] mod uuid; +mod network; + /// Type information for a Postgres SQL type. #[derive(Debug, Clone)] pub struct PgTypeInfo { diff --git a/sqlx-core/src/postgres/types/network.rs b/sqlx-core/src/postgres/types/network.rs new file mode 100644 index 00000000..18135ef4 --- /dev/null +++ b/sqlx-core/src/postgres/types/network.rs @@ -0,0 +1,119 @@ +use std::convert::TryInto; +use std::net::{Ipv4Addr, Ipv6Addr}; + +use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; + +use crate::decode::Decode; +use crate::encode::Encode; +use crate::postgres::protocol::TypeId; +use crate::postgres::row::PgValue; +use crate::postgres::types::PgTypeInfo; +use crate::postgres::Postgres; +use crate::types::Type; +use crate::Error; + +#[cfg(windows)] +const AF_INET: u8 = 2; +// Maybe not used, but defining to follow Rust's libstd/net/sys +#[cfg(redox)] +const AF_INET: u8 = 1; +#[cfg(not(any(windows, redox)))] +const AF_INET: u8 = libc::AF_INET as u8; + +const PGSQL_AF_INET: u8 = AF_INET; +const PGSQL_AF_INET6: u8 = AF_INET + 1; + +const INET_TYPE: u8 = 0; +const CIDR_TYPE: u8 = 1; + +impl Type for IpNetwork { + fn type_info() -> PgTypeInfo { + PgTypeInfo::new(TypeId::INET, "INET") + } +} + +impl Type for [IpNetwork] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::new(TypeId::ARRAY_INET, "INET[]") + } +} + +impl Encode for IpNetwork { + fn encode(&self, buf: &mut Vec) { + encode(self, INET_TYPE, buf) + } + + fn size_hint(&self) -> usize { + match self { + IpNetwork::V4(_) => 8, + IpNetwork::V6(_) => 20, + } + } +} + +impl<'de> Decode<'de, Postgres> for IpNetwork { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + PgValue::Binary(buf) => decode(buf, INET_TYPE), + PgValue::Text(s) => decode(s.as_bytes(), INET_TYPE), + } + } +} + +fn encode(net: &IpNetwork, net_type: u8, buf: &mut Vec) { + match net { + IpNetwork::V4(net) => { + buf.push(PGSQL_AF_INET); + buf.push(net.prefix()); + buf.push(net_type); + buf.push(4); + buf.extend_from_slice(&net.ip().octets()); + } + IpNetwork::V6(net) => { + buf.push(PGSQL_AF_INET6); + buf.push(net.prefix()); + buf.push(net_type); + buf.push(16); + buf.extend_from_slice(&net.ip().octets()); + } + } +} + +fn decode(bytes: &[u8], net_type: u8) -> crate::Result { + if bytes.len() <= 8 { + return Err(Error::Decode("Input too short".into())); + } + + let af = bytes[0]; + let prefix = bytes[1]; + let type_ = bytes[2]; + let len = bytes[3]; + + if type_ == net_type { + if af == PGSQL_AF_INET && bytes.len() == 8 && len == 4 { + let inet = Ipv4Network::new( + Ipv4Addr::new(bytes[4], bytes[5], bytes[6], bytes[7]), + prefix, + ) + .map_err(Error::decode)?; + + return Ok(IpNetwork::V4(inet)); + } + + if af == PGSQL_AF_INET6 && bytes.len() == 20 && len == 16 { + let inet = Ipv6Network::new( + Ipv6Addr::from([ + bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9], bytes[10], + bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], bytes[16], bytes[17], + bytes[18], bytes[19], + ]), + prefix, + ) + .map_err(Error::decode)?; + + return Ok(IpNetwork::V6(inet)); + } + } + + return Err(Error::Decode("Invalid inet_struct".into())); +}