From 1a59d3308a5e26863fab6bdb6e0b77270df3a6da Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Mon, 6 Jul 2020 19:39:33 +0200 Subject: [PATCH] Pg: Implementing `BIT` and `VARBIT` using BitVec --- Cargo.lock | 7 ++ Cargo.toml | 3 +- sqlx-core/Cargo.toml | 3 +- sqlx-core/src/postgres/types/bit_vec.rs | 107 ++++++++++++++++++++++++ sqlx-core/src/postgres/types/mod.rs | 11 +++ sqlx-core/src/types/mod.rs | 4 + sqlx-macros/Cargo.toml | 1 + sqlx-macros/src/database/postgres.rs | 3 + tests/postgres/types.rs | 25 ++++++ 9 files changed, 162 insertions(+), 2 deletions(-) create mode 100644 sqlx-core/src/postgres/types/bit_vec.rs diff --git a/Cargo.lock b/Cargo.lock index 00724728..b7019d74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -212,6 +212,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "bit-vec" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0dc55f2d8a1a85650ac47858bb001b4c0dd73d79e3c455a842925e68d29cd3" + [[package]] name = "bitflags" version = "1.2.1" @@ -2033,6 +2039,7 @@ dependencies = [ "atoi", "base64", "bigdecimal", + "bit-vec", "bitflags", "byteorder", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 1fe09d4d..0c9eec32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,7 @@ offline = [ "sqlx-macros/offline", "sqlx-core/offline" ] # intended mainly for CI and docs all = [ "tls", "all-databases", "all-types" ] all-databases = [ "mysql", "sqlite", "postgres", "mssql", "any" ] -all-types = [ "bigdecimal", "decimal", "json", "time", "chrono", "ipnetwork", "uuid" ] +all-types = [ "bigdecimal", "decimal", "json", "time", "chrono", "ipnetwork", "uuid", "bit-vec" ] # runtime runtime-async-std = [ "sqlx-core/runtime-async-std", "sqlx-macros/runtime-async-std" ] @@ -71,6 +71,7 @@ ipnetwork = [ "sqlx-core/ipnetwork", "sqlx-macros/ipnetwork" ] uuid = [ "sqlx-core/uuid", "sqlx-macros/uuid" ] json = [ "sqlx-core/json", "sqlx-macros/json" ] time = [ "sqlx-core/time", "sqlx-macros/time" ] +bit-vec = [ "sqlx-core/bit-vec", "sqlx-macros/bit-vec"] [dependencies] sqlx-core = { version = "=0.4.0-beta.1", path = "sqlx-core", default-features = false } diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 75058f29..8e6fba71 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -25,7 +25,7 @@ mssql = [ "uuid", "encoding_rs", "regex" ] any = [] # types -all-types = [ "chrono", "time", "bigdecimal", "decimal", "ipnetwork", "json", "uuid" ] +all-types = [ "chrono", "time", "bigdecimal", "decimal", "ipnetwork", "json", "uuid", "bit-vec" ] bigdecimal = [ "bigdecimal_", "num-bigint" ] decimal = [ "rust_decimal", "num-bigint", "num-traits" ] json = [ "serde", "serde_json" ] @@ -45,6 +45,7 @@ base64 = { version = "0.12.1", default-features = false, optional = true, featur bigdecimal_ = { version = "0.1.0", optional = true, package = "bigdecimal" } rust_decimal = { version = "1.6.0", optional = true } num-traits = { version = "0.2.12", optional = true } +bit-vec = { version = "0.6.2", optional = true } bitflags = { version = "1.2.1", default-features = false } bytes = "0.5.4" byteorder = { version = "1.3.4", default-features = false, features = [ "std" ] } diff --git a/sqlx-core/src/postgres/types/bit_vec.rs b/sqlx-core/src/postgres/types/bit_vec.rs new file mode 100644 index 00000000..68809e9d --- /dev/null +++ b/sqlx-core/src/postgres/types/bit_vec.rs @@ -0,0 +1,107 @@ +use crate::{ + decode::Decode, + encode::{Encode, IsNull}, + error::BoxDynError, + postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}, + types::Type, +}; +use bit_vec::BitVec; +use bytes::Buf; +use std::{io, mem}; + +impl Type for BitVec { + fn type_info() -> PgTypeInfo { + PgTypeInfo::VARBIT + } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::BIT || *ty == PgTypeInfo::VARBIT + } +} + +impl Type for [BitVec] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::VARBIT_ARRAY + } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::BIT_ARRAY || *ty == PgTypeInfo::VARBIT_ARRAY + } +} + +impl Type for Vec { + fn type_info() -> PgTypeInfo { + <[BitVec] as Type>::type_info() + } +} + +impl Encode<'_, Postgres> for BitVec { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + buf.extend(&(self.len() as i32).to_be_bytes()); + buf.extend(self.to_bytes()); + + IsNull::No + } + + fn size_hint(&self) -> usize { + mem::size_of::() + self.len() + } +} + +impl Decode<'_, Postgres> for BitVec { + fn decode(value: PgValueRef<'_>) -> Result { + match value.format() { + PgValueFormat::Binary => { + let mut bytes = value.as_bytes()?; + let len = bytes.get_i32(); + + if len < 0 { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "Negative VARBIT length.", + ))? + } + + // The smallest amount of data we can read is one byte + let bytes_len = (len as usize + 7) / 8; + + if bytes.remaining() != bytes_len { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "VARBIT length mismatch.", + ))?; + } + + let mut bitvec = BitVec::from_bytes(bytes.bytes()); + + // Chop off zeroes from the back. We get bits in bytes, so if + // our bitvec is not in full bytes, extra zeroes are added to + // the end. + while bitvec.len() > len as usize { + bitvec.pop(); + } + + Ok(bitvec) + } + PgValueFormat::Text => { + let s = value.as_str()?; + let mut bit_vec = BitVec::with_capacity(s.len()); + + for c in s.chars() { + match c { + '0' => bit_vec.push(false), + '1' => bit_vec.push(true), + _ => { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "VARBIT data contains other characters than 1 or 0.", + ))?; + } + } + } + + Ok(bit_vec) + } + } + } +} diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 811ee847..dff2ba2d 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -78,6 +78,14 @@ //! |---------------------------------------|------------------------------------------------------| //! | `ipnetwork::IpNetwork` | INET, CIDR | //! +//! ### [`bit-vec`](https://crates.io/crates/bit-vec) +//! +//! Requires the `bit-vec` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `bit_vec::BitVec` | BIT, VARBIT | +//! //! ### [`json`](https://crates.io/crates/serde_json) //! //! Requires the `json` Cargo feature flag. @@ -193,6 +201,9 @@ mod json; #[cfg(feature = "ipnetwork")] mod ipnetwork; +#[cfg(feature = "bit-vec")] +mod bit_vec; + pub use interval::PgInterval; pub use money::PgMoney; pub use range::PgRange; diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index d9bb35b3..92a512ac 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -36,6 +36,10 @@ pub mod chrono { }; } +#[cfg(feature = "bit-vec")] +#[cfg_attr(docsrs, doc(cfg(feature = "bit-vec")))] +pub use bit_vec::BitVec; + #[cfg(feature = "time")] #[cfg_attr(docsrs, doc(cfg(feature = "time")))] pub mod time { diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 9f3ecc25..8063ca91 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -40,6 +40,7 @@ chrono = [ "sqlx-core/chrono" ] time = [ "sqlx-core/time" ] ipnetwork = [ "sqlx-core/ipnetwork" ] uuid = [ "sqlx-core/uuid" ] +bit-vec = [ "sqlx-core/bit-vec" ] json = [ "sqlx-core/json", "serde_json" ] [dependencies] diff --git a/sqlx-macros/src/database/postgres.rs b/sqlx-macros/src/database/postgres.rs index c22b8ba8..a29b56cc 100644 --- a/sqlx-macros/src/database/postgres.rs +++ b/sqlx-macros/src/database/postgres.rs @@ -61,6 +61,9 @@ impl_database_ext! { #[cfg(feature = "json")] serde_json::Value, + #[cfg(feature = "bit-vec")] + sqlx::types::BitVec, + // Arrays Vec | &[bool], diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index fe3f8e65..9526f357 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -167,6 +167,31 @@ test_type!(ipnetwork(Postgres, .unwrap(), )); +#[cfg(feature = "bit-vec")] +test_type!(bitvec( + Postgres, + // A full byte VARBIT + "B'01101001'" == sqlx::types::BitVec::from_bytes(&[0b0110_1001]), + // A VARBIT value missing five bits from a byte + "B'110'" == { + let mut bit_vec = sqlx::types::BitVec::with_capacity(4); + bit_vec.push(true); + bit_vec.push(true); + bit_vec.push(false); + bit_vec + }, + // A BIT value + "B'01101'::bit(5)" == { + let mut bit_vec = sqlx::types::BitVec::with_capacity(5); + bit_vec.push(false); + bit_vec.push(true); + bit_vec.push(true); + bit_vec.push(false); + bit_vec.push(true); + bit_vec + }, +)); + #[cfg(feature = "ipnetwork")] test_type!(ipnetwork_vec>(Postgres, "'{127.0.0.1,8.8.8.8/24}'::inet[]"