Pg: Implementing BIT and VARBIT using BitVec

This commit is contained in:
Julius de Bruijn 2020-07-06 19:39:33 +02:00 committed by Ryan Leckey
parent e419bf9dfa
commit 1a59d3308a
9 changed files with 162 additions and 2 deletions

7
Cargo.lock generated
View File

@ -212,6 +212,12 @@ dependencies = [
"num-traits",
]
[[package]]
name = "bit-vec"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f0dc55f2d8a1a85650ac47858bb001b4c0dd73d79e3c455a842925e68d29cd3"
[[package]]
name = "bitflags"
version = "1.2.1"
@ -2033,6 +2039,7 @@ dependencies = [
"atoi",
"base64",
"bigdecimal",
"bit-vec",
"bitflags",
"byteorder",
"bytes",

View File

@ -49,7 +49,7 @@ offline = [ "sqlx-macros/offline", "sqlx-core/offline" ]
# intended mainly for CI and docs
all = [ "tls", "all-databases", "all-types" ]
all-databases = [ "mysql", "sqlite", "postgres", "mssql", "any" ]
all-types = [ "bigdecimal", "decimal", "json", "time", "chrono", "ipnetwork", "uuid" ]
all-types = [ "bigdecimal", "decimal", "json", "time", "chrono", "ipnetwork", "uuid", "bit-vec" ]
# runtime
runtime-async-std = [ "sqlx-core/runtime-async-std", "sqlx-macros/runtime-async-std" ]
@ -71,6 +71,7 @@ ipnetwork = [ "sqlx-core/ipnetwork", "sqlx-macros/ipnetwork" ]
uuid = [ "sqlx-core/uuid", "sqlx-macros/uuid" ]
json = [ "sqlx-core/json", "sqlx-macros/json" ]
time = [ "sqlx-core/time", "sqlx-macros/time" ]
bit-vec = [ "sqlx-core/bit-vec", "sqlx-macros/bit-vec"]
[dependencies]
sqlx-core = { version = "=0.4.0-beta.1", path = "sqlx-core", default-features = false }

View File

@ -25,7 +25,7 @@ mssql = [ "uuid", "encoding_rs", "regex" ]
any = []
# types
all-types = [ "chrono", "time", "bigdecimal", "decimal", "ipnetwork", "json", "uuid" ]
all-types = [ "chrono", "time", "bigdecimal", "decimal", "ipnetwork", "json", "uuid", "bit-vec" ]
bigdecimal = [ "bigdecimal_", "num-bigint" ]
decimal = [ "rust_decimal", "num-bigint", "num-traits" ]
json = [ "serde", "serde_json" ]
@ -45,6 +45,7 @@ base64 = { version = "0.12.1", default-features = false, optional = true, featur
bigdecimal_ = { version = "0.1.0", optional = true, package = "bigdecimal" }
rust_decimal = { version = "1.6.0", optional = true }
num-traits = { version = "0.2.12", optional = true }
bit-vec = { version = "0.6.2", optional = true }
bitflags = { version = "1.2.1", default-features = false }
bytes = "0.5.4"
byteorder = { version = "1.3.4", default-features = false, features = [ "std" ] }

View File

@ -0,0 +1,107 @@
use crate::{
decode::Decode,
encode::{Encode, IsNull},
error::BoxDynError,
postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres},
types::Type,
};
use bit_vec::BitVec;
use bytes::Buf;
use std::{io, mem};
impl Type<Postgres> for BitVec {
fn type_info() -> PgTypeInfo {
PgTypeInfo::VARBIT
}
fn compatible(ty: &PgTypeInfo) -> bool {
*ty == PgTypeInfo::BIT || *ty == PgTypeInfo::VARBIT
}
}
impl Type<Postgres> for [BitVec] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::VARBIT_ARRAY
}
fn compatible(ty: &PgTypeInfo) -> bool {
*ty == PgTypeInfo::BIT_ARRAY || *ty == PgTypeInfo::VARBIT_ARRAY
}
}
impl Type<Postgres> for Vec<BitVec> {
fn type_info() -> PgTypeInfo {
<[BitVec] as Type<Postgres>>::type_info()
}
}
impl Encode<'_, Postgres> for BitVec {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
buf.extend(&(self.len() as i32).to_be_bytes());
buf.extend(self.to_bytes());
IsNull::No
}
fn size_hint(&self) -> usize {
mem::size_of::<i32>() + self.len()
}
}
impl Decode<'_, Postgres> for BitVec {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
match value.format() {
PgValueFormat::Binary => {
let mut bytes = value.as_bytes()?;
let len = bytes.get_i32();
if len < 0 {
Err(io::Error::new(
io::ErrorKind::InvalidData,
"Negative VARBIT length.",
))?
}
// The smallest amount of data we can read is one byte
let bytes_len = (len as usize + 7) / 8;
if bytes.remaining() != bytes_len {
Err(io::Error::new(
io::ErrorKind::InvalidData,
"VARBIT length mismatch.",
))?;
}
let mut bitvec = BitVec::from_bytes(bytes.bytes());
// Chop off zeroes from the back. We get bits in bytes, so if
// our bitvec is not in full bytes, extra zeroes are added to
// the end.
while bitvec.len() > len as usize {
bitvec.pop();
}
Ok(bitvec)
}
PgValueFormat::Text => {
let s = value.as_str()?;
let mut bit_vec = BitVec::with_capacity(s.len());
for c in s.chars() {
match c {
'0' => bit_vec.push(false),
'1' => bit_vec.push(true),
_ => {
Err(io::Error::new(
io::ErrorKind::InvalidData,
"VARBIT data contains other characters than 1 or 0.",
))?;
}
}
}
Ok(bit_vec)
}
}
}
}

View File

@ -78,6 +78,14 @@
//! |---------------------------------------|------------------------------------------------------|
//! | `ipnetwork::IpNetwork` | INET, CIDR |
//!
//! ### [`bit-vec`](https://crates.io/crates/bit-vec)
//!
//! Requires the `bit-vec` Cargo feature flag.
//!
//! | Rust type | Postgres type(s) |
//! |---------------------------------------|------------------------------------------------------|
//! | `bit_vec::BitVec` | BIT, VARBIT |
//!
//! ### [`json`](https://crates.io/crates/serde_json)
//!
//! Requires the `json` Cargo feature flag.
@ -193,6 +201,9 @@ mod json;
#[cfg(feature = "ipnetwork")]
mod ipnetwork;
#[cfg(feature = "bit-vec")]
mod bit_vec;
pub use interval::PgInterval;
pub use money::PgMoney;
pub use range::PgRange;

View File

@ -36,6 +36,10 @@ pub mod chrono {
};
}
#[cfg(feature = "bit-vec")]
#[cfg_attr(docsrs, doc(cfg(feature = "bit-vec")))]
pub use bit_vec::BitVec;
#[cfg(feature = "time")]
#[cfg_attr(docsrs, doc(cfg(feature = "time")))]
pub mod time {

View File

@ -40,6 +40,7 @@ chrono = [ "sqlx-core/chrono" ]
time = [ "sqlx-core/time" ]
ipnetwork = [ "sqlx-core/ipnetwork" ]
uuid = [ "sqlx-core/uuid" ]
bit-vec = [ "sqlx-core/bit-vec" ]
json = [ "sqlx-core/json", "serde_json" ]
[dependencies]

View File

@ -61,6 +61,9 @@ impl_database_ext! {
#[cfg(feature = "json")]
serde_json::Value,
#[cfg(feature = "bit-vec")]
sqlx::types::BitVec,
// Arrays
Vec<bool> | &[bool],

View File

@ -167,6 +167,31 @@ test_type!(ipnetwork<sqlx::types::ipnetwork::IpNetwork>(Postgres,
.unwrap(),
));
#[cfg(feature = "bit-vec")]
test_type!(bitvec<sqlx::types::BitVec>(
Postgres,
// A full byte VARBIT
"B'01101001'" == sqlx::types::BitVec::from_bytes(&[0b0110_1001]),
// A VARBIT value missing five bits from a byte
"B'110'" == {
let mut bit_vec = sqlx::types::BitVec::with_capacity(4);
bit_vec.push(true);
bit_vec.push(true);
bit_vec.push(false);
bit_vec
},
// A BIT value
"B'01101'::bit(5)" == {
let mut bit_vec = sqlx::types::BitVec::with_capacity(5);
bit_vec.push(false);
bit_vec.push(true);
bit_vec.push(true);
bit_vec.push(false);
bit_vec.push(true);
bit_vec
},
));
#[cfg(feature = "ipnetwork")]
test_type!(ipnetwork_vec<Vec<sqlx::types::ipnetwork::IpNetwork>>(Postgres,
"'{127.0.0.1,8.8.8.8/24}'::inet[]"