diff --git a/sqlx-core/src/mariadb/connection.rs b/sqlx-core/src/mariadb/connection.rs index 479c48d4..75d43966 100644 --- a/sqlx-core/src/mariadb/connection.rs +++ b/sqlx-core/src/mariadb/connection.rs @@ -18,6 +18,7 @@ use std::{ net::{IpAddr, SocketAddr}, }; use url::Url; +use url::quirks::protocol; pub struct MariaDb { pub(crate) stream: BufStream, @@ -145,24 +146,15 @@ impl MariaDb { 0xfe | 0x00 => OkPacket::decode(buf, capabilities)?, 0xff => { - let err = ErrPacket::decode(buf)?; - - // TODO: Bubble as Error::Database - // panic!("received db err = {:?}", err); - return Err( - io::Error::new(io::ErrorKind::InvalidInput, format!("{:?}", err)).into(), - ); + return ErrPacket::decode(buf)?.expect_error(); } id => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( + return Err(protocol_err!( "unexpected packet identifier 0x{:X?} when expecting 0xFE (OK) or 0xFF \ (ERR)", id - ), - ) + ) .into()); } }) @@ -228,8 +220,7 @@ impl MariaDb { if packet[0] == 0x00 { let _ok = OkPacket::decode(packet, capabilities)?; } else if packet[0] == 0xFF { - let err = ErrPacket::decode(packet)?; - panic!("received db err = {:?}", err); + return ErrPacket::decode(packet)?.expect_error(); } else { // A Resultset starts with a [ColumnCountPacket] which is a single field that encodes // how many columns we can expect when fetching rows from this statement diff --git a/sqlx-core/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs b/sqlx-core/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs index 5a07dd71..1549ec0f 100644 --- a/sqlx-core/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs +++ b/sqlx-core/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs @@ -18,14 +18,11 @@ pub struct ComStmtPrepareOk { } impl ComStmtPrepareOk { - pub(crate) fn decode(mut buf: &[u8]) -> io::Result { + pub(crate) fn decode(mut buf: &[u8]) -> crate::Result { let header = buf.get_u8()?; if header != 0x00 { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("expected COM_STMT_PREPARE_OK (0x00); received {}", header), - )); + return Err(protocol_err!("expected COM_STMT_PREPARE_OK (0x00); received {}", header).into()); } let statement_id = buf.get_u32::()?; diff --git a/sqlx-core/src/mariadb/protocol/response/eof.rs b/sqlx-core/src/mariadb/protocol/response/eof.rs index 34b90dc2..21d4ed75 100644 --- a/sqlx-core/src/mariadb/protocol/response/eof.rs +++ b/sqlx-core/src/mariadb/protocol/response/eof.rs @@ -15,13 +15,10 @@ pub struct EofPacket { } impl EofPacket { - pub(crate) fn decode(mut buf: &[u8]) -> io::Result { + pub(crate) fn decode(mut buf: &[u8]) -> crate::Result { let header = buf.get_u8()?; if header != 0xFE { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("expected 0xFE; received {}", header), - )); + return Err(protocol_err!("expected 0xFE; received {}", header)); } let warning_count = buf.get_u16::()?; diff --git a/sqlx-core/src/mariadb/protocol/response/ok.rs b/sqlx-core/src/mariadb/protocol/response/ok.rs index ff0be09a..d2075cc6 100644 --- a/sqlx-core/src/mariadb/protocol/response/ok.rs +++ b/sqlx-core/src/mariadb/protocol/response/ok.rs @@ -21,13 +21,10 @@ pub struct OkPacket { } impl OkPacket { - pub fn decode(mut buf: &[u8], capabilities: Capabilities) -> io::Result { + pub fn decode(mut buf: &[u8], capabilities: Capabilities) -> crate::Result { let header = buf.get_u8()?; if header != 0 && header != 0xFE { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("expected 0x00 or 0xFE; received 0x{:X}", header), - )); + return Err(protocol_err!("expected 0x00 or 0xFE; received 0x{:X}", header)); } let affected_rows = buf.get_uint_lenenc::()?.unwrap_or(0); diff --git a/sqlx-core/src/mariadb/protocol/response/row.rs b/sqlx-core/src/mariadb/protocol/response/row.rs index b51badfe..9d16021f 100644 --- a/sqlx-core/src/mariadb/protocol/response/row.rs +++ b/sqlx-core/src/mariadb/protocol/response/row.rs @@ -22,11 +22,13 @@ unsafe impl Send for ResultRow {} unsafe impl Sync for ResultRow {} impl ResultRow { - pub fn decode(mut buf: &[u8], columns: &[ColumnDefinitionPacket]) -> io::Result { + pub fn decode(mut buf: &[u8], columns: &[ColumnDefinitionPacket]) -> crate::Result { // 0x00 header : byte<1> let header = buf.get_u8()?; - // TODO: Replace with InvalidData err - debug_assert_eq!(header, 0); + + if header != 0 { + return Err(protocol_err!("expected header 0x00, got: {:#04X}", header).into()) + } // NULL-Bitmap : byte<(number_of_columns + 9) / 8> let null_len = (columns.len() + 9) / 8; diff --git a/sqlx-core/src/postgres/connection.rs b/sqlx-core/src/postgres/connection.rs index 28052975..079f33b0 100644 --- a/sqlx-core/src/postgres/connection.rs +++ b/sqlx-core/src/postgres/connection.rs @@ -98,11 +98,7 @@ impl Postgres { } auth => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("requires unimplemented authentication method: {:?}", auth), - ) - .into()); + return Err(protocol_err!("requires unimplemented authentication method: {:?}", auth).into()); } } } @@ -118,11 +114,7 @@ impl Postgres { } message => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received unexpected message: {:?}", message), - ) - .into()); + return Err(protocol_err!("received unexpected message: {:?}", message).into()); } } } @@ -211,10 +203,7 @@ impl Postgres { } message => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received unexpected message: {:?}", message), - ) + return Err(protocol_err!("received unexpected message: {:?}", message) .into()); } } @@ -265,10 +254,7 @@ impl Postgres { b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)), id => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received unknown message id: {:?}", id), - ) + return Err(protocol_err!("received unknown message id: {:?}", id) .into()); } }; diff --git a/sqlx-core/src/postgres/protocol/authentication.rs b/sqlx-core/src/postgres/protocol/authentication.rs index 600cac8d..faf3a3f8 100644 --- a/sqlx-core/src/postgres/protocol/authentication.rs +++ b/sqlx-core/src/postgres/protocol/authentication.rs @@ -43,7 +43,7 @@ pub enum Authentication { } impl Decode for Authentication { - fn decode(mut buf: &[u8]) -> io::Result { + fn decode(mut buf: &[u8]) -> crate::Result { Ok(match buf.get_u32::()? { 0 => Authentication::Ok, @@ -104,10 +104,7 @@ impl Decode for Authentication { } id => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("unknown authentication response: {}", id), - )); + return Err(protocol_err!("unknown authentication response: {}", id).into()); } }) } diff --git a/sqlx-core/src/postgres/protocol/backend_key_data.rs b/sqlx-core/src/postgres/protocol/backend_key_data.rs index fbac8487..67ed8be6 100644 --- a/sqlx-core/src/postgres/protocol/backend_key_data.rs +++ b/sqlx-core/src/postgres/protocol/backend_key_data.rs @@ -25,7 +25,7 @@ impl BackendKeyData { } impl Decode for BackendKeyData { - fn decode(mut buf: &[u8]) -> io::Result { + fn decode(mut buf: &[u8]) -> crate::Result { let process_id = buf.get_u32::()?; let secret_key = buf.get_u32::()?; diff --git a/sqlx-core/src/postgres/protocol/command_complete.rs b/sqlx-core/src/postgres/protocol/command_complete.rs index 3fd5b6ab..289b40f1 100644 --- a/sqlx-core/src/postgres/protocol/command_complete.rs +++ b/sqlx-core/src/postgres/protocol/command_complete.rs @@ -15,7 +15,7 @@ impl CommandComplete { } impl Decode for CommandComplete { - fn decode(mut buf: &[u8]) -> io::Result { + fn decode(mut buf: &[u8]) -> crate::Result { // TODO: MariaDb/MySQL return 0 for affected rows in a SELECT .. statement. // PostgreSQL returns a row count. Should we force return 0 for compatibilities sake? diff --git a/sqlx-core/src/postgres/protocol/data_row.rs b/sqlx-core/src/postgres/protocol/data_row.rs index a7c356e4..70039429 100644 --- a/sqlx-core/src/postgres/protocol/data_row.rs +++ b/sqlx-core/src/postgres/protocol/data_row.rs @@ -19,7 +19,7 @@ unsafe impl Send for DataRow {} unsafe impl Sync for DataRow {} impl Decode for DataRow { - fn decode(mut buf: &[u8]) -> io::Result { + fn decode(mut buf: &[u8]) -> crate::Result { let cnt = buf.get_u16::()? as usize; let buffer: Pin> = Pin::new(buf.into()); let mut buf = &*buffer; diff --git a/sqlx-core/src/postgres/protocol/decode.rs b/sqlx-core/src/postgres/protocol/decode.rs index 232e16fe..31c5bf4b 100644 --- a/sqlx-core/src/postgres/protocol/decode.rs +++ b/sqlx-core/src/postgres/protocol/decode.rs @@ -1,7 +1,7 @@ use std::io; pub trait Decode { - fn decode(src: &[u8]) -> io::Result + fn decode(src: &[u8]) -> crate::Result where Self: Sized; } diff --git a/sqlx-core/src/postgres/protocol/notification_response.rs b/sqlx-core/src/postgres/protocol/notification_response.rs index 2aa07494..22f2e9dc 100644 --- a/sqlx-core/src/postgres/protocol/notification_response.rs +++ b/sqlx-core/src/postgres/protocol/notification_response.rs @@ -45,7 +45,7 @@ impl fmt::Debug for NotificationResponse { } impl Decode for NotificationResponse { - fn decode(mut buf: &[u8]) -> io::Result { + fn decode(mut buf: &[u8]) -> crate::Result { let pid = buf.get_u32::()?; let buffer = Pin::new(buf.into()); diff --git a/sqlx-core/src/postgres/protocol/parameter_description.rs b/sqlx-core/src/postgres/protocol/parameter_description.rs index 7514f3e5..c5b31f91 100644 --- a/sqlx-core/src/postgres/protocol/parameter_description.rs +++ b/sqlx-core/src/postgres/protocol/parameter_description.rs @@ -9,7 +9,7 @@ pub struct ParameterDescription { } impl Decode for ParameterDescription { - fn decode(mut buf: &[u8]) -> io::Result { + fn decode(mut buf: &[u8]) -> crate::Result { let cnt = buf.get_u16::()? as usize; let mut ids = Vec::with_capacity(cnt); diff --git a/sqlx-core/src/postgres/protocol/parameter_status.rs b/sqlx-core/src/postgres/protocol/parameter_status.rs index 3815c7f7..b8b03e86 100644 --- a/sqlx-core/src/postgres/protocol/parameter_status.rs +++ b/sqlx-core/src/postgres/protocol/parameter_status.rs @@ -34,7 +34,7 @@ impl ParameterStatus { } impl Decode for ParameterStatus { - fn decode(buf: &[u8]) -> io::Result { + fn decode(buf: &[u8]) -> crate::Result { let buffer = Pin::new(buf.into()); let mut buf: &[u8] = &*buffer; diff --git a/sqlx-core/src/postgres/protocol/ready_for_query.rs b/sqlx-core/src/postgres/protocol/ready_for_query.rs index 83e37187..40d4c639 100644 --- a/sqlx-core/src/postgres/protocol/ready_for_query.rs +++ b/sqlx-core/src/postgres/protocol/ready_for_query.rs @@ -28,7 +28,7 @@ impl ReadyForQuery { } impl Decode for ReadyForQuery { - fn decode(buf: &[u8]) -> io::Result { + fn decode(buf: &[u8]) -> crate::Result { Ok(Self { status: match buf[0] { b'I' => TransactionStatus::Idle, @@ -36,13 +36,10 @@ impl Decode for ReadyForQuery { b'E' => TransactionStatus::Error, status => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( + return Err(protocol_err!( "received {:?} for TransactionStatus in ReadyForQuery", status - ), - )); + ).into()); } }, }) diff --git a/sqlx-core/src/postgres/protocol/response.rs b/sqlx-core/src/postgres/protocol/response.rs index 6ea3adc7..ac076e17 100644 --- a/sqlx-core/src/postgres/protocol/response.rs +++ b/sqlx-core/src/postgres/protocol/response.rs @@ -54,9 +54,9 @@ impl Severity { } impl FromStr for Severity { - type Err = io::Error; + type Err = crate::Error; - fn from_str(s: &str) -> io::Result { + fn from_str(s: &str) -> crate::Result { Ok(match s { "PANIC" => Severity::Panic, "FATAL" => Severity::Fatal, @@ -68,7 +68,7 @@ impl FromStr for Severity { "LOG" => Severity::Log, _ => { - return Err(io::ErrorKind::InvalidData.into()); + return Err(protocol_err!("unexpected response severity: {}", s).into()); } }) } @@ -225,7 +225,7 @@ impl fmt::Debug for Response { } impl Decode for Response { - fn decode(buf: &[u8]) -> io::Result { + fn decode(buf: &[u8]) -> crate::Result { let buffer: Pin> = Pin::new(buf.into()); let mut buf: &[u8] = &*buffer; @@ -286,7 +286,7 @@ impl Decode for Response { position = Some( field_value .parse() - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?, + .or(Err(protocol_err!("expected int, got: {}", field_value)))?, ); } @@ -294,7 +294,7 @@ impl Decode for Response { internal_position = Some( field_value .parse() - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?, + .or(Err(protocol_err!("expected int, got: {}", field_value)))?, ); } @@ -334,7 +334,7 @@ impl Decode for Response { line = Some( field_value .parse() - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?, + .or(Err(protocol_err!("expected int, got: {}", field_value)))?, ); } @@ -344,35 +344,23 @@ impl Decode for Response { _ => { // TODO: Should we return these somehow, like in a map? - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received unknown field in Response: {}", field_type), - )); + return Err(protocol_err!("received unknown field in Response: {}", field_type).into()); } } } let severity = severity_non_local .or_else(move || unsafe { severity?.as_ref() }.parse().ok()) - .ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidData, - "did not receieve field `severity` for Response", - ) - })?; + .ok_or(protocol_err!("did not receieve field `severity` for Response"))?; - let code = code.ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidData, + let code = code.ok_or( + protocol_err!( "did not receieve field `code` for Response", ) - })?; - let message = message.ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidData, - "did not receieve field `message` for Response", - ) - })?; + )?; + let message = message.ok_or( + protocol_err!("did not receieve field `message` for Response") + )?; Ok(Self { buffer, diff --git a/sqlx-core/src/postgres/protocol/row_description.rs b/sqlx-core/src/postgres/protocol/row_description.rs index 98bc67e0..276f22c2 100644 --- a/sqlx-core/src/postgres/protocol/row_description.rs +++ b/sqlx-core/src/postgres/protocol/row_description.rs @@ -20,7 +20,7 @@ pub struct RowField { } impl Decode for RowDescription { - fn decode(mut buf: &[u8]) -> io::Result { + fn decode(mut buf: &[u8]) -> crate::Result { let cnt = buf.get_u16::()? as usize; let mut fields = Vec::with_capacity(cnt);