use protocol_err! instead of InvalidData in more places

This commit is contained in:
Austin Bonander 2019-11-22 17:57:00 +00:00
parent 905320ff39
commit 4d033963ce
17 changed files with 48 additions and 96 deletions

View File

@ -18,6 +18,7 @@ use std::{
net::{IpAddr, SocketAddr},
};
use url::Url;
use url::quirks::protocol;
pub struct MariaDb {
pub(crate) stream: BufStream<TcpStream>,
@ -145,24 +146,15 @@ impl MariaDb {
0xfe | 0x00 => OkPacket::decode(buf, capabilities)?,
0xff => {
let err = ErrPacket::decode(buf)?;
// TODO: Bubble as Error::Database
// panic!("received db err = {:?}", err);
return Err(
io::Error::new(io::ErrorKind::InvalidInput, format!("{:?}", err)).into(),
);
return ErrPacket::decode(buf)?.expect_error();
}
id => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
return Err(protocol_err!(
"unexpected packet identifier 0x{:X?} when expecting 0xFE (OK) or 0xFF \
(ERR)",
id
),
)
)
.into());
}
})
@ -228,8 +220,7 @@ impl MariaDb {
if packet[0] == 0x00 {
let _ok = OkPacket::decode(packet, capabilities)?;
} else if packet[0] == 0xFF {
let err = ErrPacket::decode(packet)?;
panic!("received db err = {:?}", err);
return ErrPacket::decode(packet)?.expect_error();
} else {
// A Resultset starts with a [ColumnCountPacket] which is a single field that encodes
// how many columns we can expect when fetching rows from this statement

View File

@ -18,14 +18,11 @@ pub struct ComStmtPrepareOk {
}
impl ComStmtPrepareOk {
pub(crate) fn decode(mut buf: &[u8]) -> io::Result<Self> {
pub(crate) fn decode(mut buf: &[u8]) -> crate::Result<Self> {
let header = buf.get_u8()?;
if header != 0x00 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected COM_STMT_PREPARE_OK (0x00); received {}", header),
));
return Err(protocol_err!("expected COM_STMT_PREPARE_OK (0x00); received {}", header).into());
}
let statement_id = buf.get_u32::<LittleEndian>()?;

View File

@ -15,13 +15,10 @@ pub struct EofPacket {
}
impl EofPacket {
pub(crate) fn decode(mut buf: &[u8]) -> io::Result<Self> {
pub(crate) fn decode(mut buf: &[u8]) -> crate::Result<Self> {
let header = buf.get_u8()?;
if header != 0xFE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected 0xFE; received {}", header),
));
return Err(protocol_err!("expected 0xFE; received {}", header));
}
let warning_count = buf.get_u16::<LittleEndian>()?;

View File

@ -21,13 +21,10 @@ pub struct OkPacket {
}
impl OkPacket {
pub fn decode(mut buf: &[u8], capabilities: Capabilities) -> io::Result<Self> {
pub fn decode(mut buf: &[u8], capabilities: Capabilities) -> crate::Result<Self> {
let header = buf.get_u8()?;
if header != 0 && header != 0xFE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected 0x00 or 0xFE; received 0x{:X}", header),
));
return Err(protocol_err!("expected 0x00 or 0xFE; received 0x{:X}", header));
}
let affected_rows = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);

View File

@ -22,11 +22,13 @@ unsafe impl Send for ResultRow {}
unsafe impl Sync for ResultRow {}
impl ResultRow {
pub fn decode(mut buf: &[u8], columns: &[ColumnDefinitionPacket]) -> io::Result<Self> {
pub fn decode(mut buf: &[u8], columns: &[ColumnDefinitionPacket]) -> crate::Result<Self> {
// 0x00 header : byte<1>
let header = buf.get_u8()?;
// TODO: Replace with InvalidData err
debug_assert_eq!(header, 0);
if header != 0 {
return Err(protocol_err!("expected header 0x00, got: {:#04X}", header).into())
}
// NULL-Bitmap : byte<(number_of_columns + 9) / 8>
let null_len = (columns.len() + 9) / 8;

View File

@ -98,11 +98,7 @@ impl Postgres {
}
auth => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("requires unimplemented authentication method: {:?}", auth),
)
.into());
return Err(protocol_err!("requires unimplemented authentication method: {:?}", auth).into());
}
}
}
@ -118,11 +114,7 @@ impl Postgres {
}
message => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unexpected message: {:?}", message),
)
.into());
return Err(protocol_err!("received unexpected message: {:?}", message).into());
}
}
}
@ -211,10 +203,7 @@ impl Postgres {
}
message => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unexpected message: {:?}", message),
)
return Err(protocol_err!("received unexpected message: {:?}", message)
.into());
}
}
@ -265,10 +254,7 @@ impl Postgres {
b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)),
id => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unknown message id: {:?}", id),
)
return Err(protocol_err!("received unknown message id: {:?}", id)
.into());
}
};

View File

@ -43,7 +43,7 @@ pub enum Authentication {
}
impl Decode for Authentication {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
Ok(match buf.get_u32::<NetworkEndian>()? {
0 => Authentication::Ok,
@ -104,10 +104,7 @@ impl Decode for Authentication {
}
id => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unknown authentication response: {}", id),
));
return Err(protocol_err!("unknown authentication response: {}", id).into());
}
})
}

View File

@ -25,7 +25,7 @@ impl BackendKeyData {
}
impl Decode for BackendKeyData {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
let process_id = buf.get_u32::<NetworkEndian>()?;
let secret_key = buf.get_u32::<NetworkEndian>()?;

View File

@ -15,7 +15,7 @@ impl CommandComplete {
}
impl Decode for CommandComplete {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
// TODO: MariaDb/MySQL return 0 for affected rows in a SELECT .. statement.
// PostgreSQL returns a row count. Should we force return 0 for compatibilities sake?

View File

@ -19,7 +19,7 @@ unsafe impl Send for DataRow {}
unsafe impl Sync for DataRow {}
impl Decode for DataRow {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
let cnt = buf.get_u16::<NetworkEndian>()? as usize;
let buffer: Pin<Box<[u8]>> = Pin::new(buf.into());
let mut buf = &*buffer;

View File

@ -1,7 +1,7 @@
use std::io;
pub trait Decode {
fn decode(src: &[u8]) -> io::Result<Self>
fn decode(src: &[u8]) -> crate::Result<Self>
where
Self: Sized;
}

View File

@ -45,7 +45,7 @@ impl fmt::Debug for NotificationResponse {
}
impl Decode for NotificationResponse {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
let pid = buf.get_u32::<NetworkEndian>()?;
let buffer = Pin::new(buf.into());

View File

@ -9,7 +9,7 @@ pub struct ParameterDescription {
}
impl Decode for ParameterDescription {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
let cnt = buf.get_u16::<NetworkEndian>()? as usize;
let mut ids = Vec::with_capacity(cnt);

View File

@ -34,7 +34,7 @@ impl ParameterStatus {
}
impl Decode for ParameterStatus {
fn decode(buf: &[u8]) -> io::Result<Self> {
fn decode(buf: &[u8]) -> crate::Result<Self> {
let buffer = Pin::new(buf.into());
let mut buf: &[u8] = &*buffer;

View File

@ -28,7 +28,7 @@ impl ReadyForQuery {
}
impl Decode for ReadyForQuery {
fn decode(buf: &[u8]) -> io::Result<Self> {
fn decode(buf: &[u8]) -> crate::Result<Self> {
Ok(Self {
status: match buf[0] {
b'I' => TransactionStatus::Idle,
@ -36,13 +36,10 @@ impl Decode for ReadyForQuery {
b'E' => TransactionStatus::Error,
status => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
return Err(protocol_err!(
"received {:?} for TransactionStatus in ReadyForQuery",
status
),
));
).into());
}
},
})

View File

@ -54,9 +54,9 @@ impl Severity {
}
impl FromStr for Severity {
type Err = io::Error;
type Err = crate::Error;
fn from_str(s: &str) -> io::Result<Self> {
fn from_str(s: &str) -> crate::Result<Self> {
Ok(match s {
"PANIC" => Severity::Panic,
"FATAL" => Severity::Fatal,
@ -68,7 +68,7 @@ impl FromStr for Severity {
"LOG" => Severity::Log,
_ => {
return Err(io::ErrorKind::InvalidData.into());
return Err(protocol_err!("unexpected response severity: {}", s).into());
}
})
}
@ -225,7 +225,7 @@ impl fmt::Debug for Response {
}
impl Decode for Response {
fn decode(buf: &[u8]) -> io::Result<Self> {
fn decode(buf: &[u8]) -> crate::Result<Self> {
let buffer: Pin<Box<[u8]>> = Pin::new(buf.into());
let mut buf: &[u8] = &*buffer;
@ -286,7 +286,7 @@ impl Decode for Response {
position = Some(
field_value
.parse()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?,
.or(Err(protocol_err!("expected int, got: {}", field_value)))?,
);
}
@ -294,7 +294,7 @@ impl Decode for Response {
internal_position = Some(
field_value
.parse()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?,
.or(Err(protocol_err!("expected int, got: {}", field_value)))?,
);
}
@ -334,7 +334,7 @@ impl Decode for Response {
line = Some(
field_value
.parse()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?,
.or(Err(protocol_err!("expected int, got: {}", field_value)))?,
);
}
@ -344,35 +344,23 @@ impl Decode for Response {
_ => {
// TODO: Should we return these somehow, like in a map?
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unknown field in Response: {}", field_type),
));
return Err(protocol_err!("received unknown field in Response: {}", field_type).into());
}
}
}
let severity = severity_non_local
.or_else(move || unsafe { severity?.as_ref() }.parse().ok())
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"did not receieve field `severity` for Response",
)
})?;
.ok_or(protocol_err!("did not receieve field `severity` for Response"))?;
let code = code.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
let code = code.ok_or(
protocol_err!(
"did not receieve field `code` for Response",
)
})?;
let message = message.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"did not receieve field `message` for Response",
)
})?;
)?;
let message = message.ok_or(
protocol_err!("did not receieve field `message` for Response")
)?;
Ok(Self {
buffer,

View File

@ -20,7 +20,7 @@ pub struct RowField {
}
impl Decode for RowDescription {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
let cnt = buf.get_u16::<NetworkEndian>()? as usize;
let mut fields = Vec::with_capacity(cnt);