From d2da565d0b7652fdca02dca82b7e454db1f338cb Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Mon, 22 Feb 2021 23:41:22 -0800 Subject: [PATCH] refactor: remove unsafe for utf8 conversions within protocol --- sqlx-core/src/io/buf.rs | 39 +++++++++++--------------- sqlx-mysql/src/io/buf.rs | 19 ++++++------- sqlx-mysql/src/protocol/auth_switch.rs | 4 +-- sqlx-mysql/src/protocol/column_def.rs | 16 ++++------- sqlx-mysql/src/protocol/err.rs | 8 ++---- sqlx-mysql/src/protocol/handshake.rs | 13 ++------- sqlx-mysql/src/protocol/ok.rs | 5 ++-- 7 files changed, 37 insertions(+), 67 deletions(-) diff --git a/sqlx-core/src/io/buf.rs b/sqlx-core/src/io/buf.rs index 838b2662..6c7f18bf 100644 --- a/sqlx-core/src/io/buf.rs +++ b/sqlx-core/src/io/buf.rs @@ -1,3 +1,4 @@ +use std::convert::TryFrom; use std::io; use bytes::{Buf, Bytes}; @@ -6,28 +7,22 @@ use memchr::memchr; #[allow(clippy::module_name_repetitions)] pub trait BufExt: Buf { - /// # Safety - /// This function is unsafe because it does not check the bytes that are read are valid UTF-8. - #[allow(unsafe_code)] - unsafe fn get_str_unchecked(&mut self, n: usize) -> ByteString; + fn get_str(&mut self, n: usize) -> io::Result; - /// # Safety - /// This function is unsafe because it does not check the bytes that are read are valid UTF-8. - #[allow(unsafe_code)] - unsafe fn get_str_nul_unchecked(&mut self) -> io::Result; + fn get_str_nul(&mut self) -> io::Result; } impl BufExt for Bytes { - #[allow(unsafe_code)] - unsafe fn get_str_unchecked(&mut self, n: usize) -> ByteString { - ByteString::from_bytes_unchecked(self.split_to(n)) + fn get_str(&mut self, n: usize) -> io::Result { + ByteString::try_from(self.split_to(n)) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) } - #[allow(unsafe_code)] - unsafe fn get_str_nul_unchecked(&mut self) -> io::Result { + fn get_str_nul(&mut self) -> io::Result { let nul = memchr(b'\0', self).ok_or(io::ErrorKind::InvalidData)?; - Ok(ByteString::from_bytes_unchecked(self.split_to(nul + 1).slice(..nul))) + ByteString::try_from(self.split_to(nul + 1).slice(..nul)) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) } } @@ -40,32 +35,30 @@ mod tests { use super::BufExt; #[test] - fn test_get_str() { + fn test_get_str() -> io::Result<()> { let mut buf = Bytes::from_static(b"Hello World\0"); - #[allow(unsafe_code)] - let s = unsafe { buf.get_str_unchecked(5) }; + let s = buf.get_str(5)?; buf.advance(1); - #[allow(unsafe_code)] - let s2 = unsafe { buf.get_str_unchecked(5) }; + let s2 = buf.get_str(5)?; assert_eq!(&s, "Hello"); assert_eq!(&s2, "World"); + + Ok(()) } #[test] fn test_get_str_nul() -> io::Result<()> { let mut buf = Bytes::from_static(b"Hello\0 World\0"); - #[allow(unsafe_code)] - let s = unsafe { buf.get_str_nul_unchecked()? }; + let s = buf.get_str_nul()?; buf.advance(1); - #[allow(unsafe_code)] - let s2 = unsafe { buf.get_str_nul_unchecked()? }; + let s2 = buf.get_str_nul()?; assert_eq!(&s, "Hello"); assert_eq!(&s2, "World"); diff --git a/sqlx-mysql/src/io/buf.rs b/sqlx-mysql/src/io/buf.rs index a031974c..a74fe6b6 100644 --- a/sqlx-mysql/src/io/buf.rs +++ b/sqlx-mysql/src/io/buf.rs @@ -1,3 +1,5 @@ +use std::io; + use bytes::{Buf, Bytes}; use bytestring::ByteString; use sqlx_core::io::BufExt; @@ -9,11 +11,9 @@ use sqlx_core::io::BufExt; pub(crate) trait MySqlBufExt: BufExt { fn get_uint_lenenc(&mut self) -> u64; - #[allow(unsafe_code)] - unsafe fn get_str_lenenc_unchecked(&mut self) -> ByteString; + fn get_str_lenenc(&mut self) -> io::Result; - #[allow(unsafe_code)] - unsafe fn get_str_eof_unchecked(&mut self) -> ByteString; + fn get_str_eof(&mut self) -> io::Result; fn get_bytes_lenenc(&mut self) -> Bytes; } @@ -37,17 +37,14 @@ impl MySqlBufExt for Bytes { } } - #[allow(unsafe_code)] - unsafe fn get_str_lenenc_unchecked(&mut self) -> ByteString { - #[allow(clippy::cast_possible_truncation)] + fn get_str_lenenc(&mut self) -> io::Result { let len = self.get_uint_lenenc() as usize; - self.get_str_unchecked(len) + self.get_str(len) } - #[allow(unsafe_code)] - unsafe fn get_str_eof_unchecked(&mut self) -> ByteString { - self.get_str_unchecked(self.len()) + fn get_str_eof(&mut self) -> io::Result { + self.get_str(self.len()) } fn get_bytes_lenenc(&mut self) -> Bytes { diff --git a/sqlx-mysql/src/protocol/auth_switch.rs b/sqlx-mysql/src/protocol/auth_switch.rs index 099d89d0..75b4c7f3 100644 --- a/sqlx-mysql/src/protocol/auth_switch.rs +++ b/sqlx-mysql/src/protocol/auth_switch.rs @@ -18,9 +18,7 @@ impl Deserialize<'_> for AuthSwitch { let tag = buf.get_u8(); debug_assert_eq!(tag, 0xfe); - // SAFE: auth plugins are ASCII only - #[allow(unsafe_code)] - let name = unsafe { buf.get_str_nul_unchecked()? }; + let name = buf.get_str_nul()?; if buf.ends_with(&[0]) { // if this terminates in a NUL; drop the NUL diff --git a/sqlx-mysql/src/protocol/column_def.rs b/sqlx-mysql/src/protocol/column_def.rs index 119a2a05..fb1fe3ce 100644 --- a/sqlx-mysql/src/protocol/column_def.rs +++ b/sqlx-mysql/src/protocol/column_def.rs @@ -26,21 +26,17 @@ pub(crate) struct ColumnDefinition { } impl Deserialize<'_> for ColumnDefinition { - #[allow(unsafe_code)] fn deserialize_with(mut buf: Bytes, _: ()) -> Result { - // UNSAFE: fields are known to be UTF-8 as we have connected with the - // UTF-8 connection charset - - let catalog = unsafe { buf.get_str_lenenc_unchecked() }; + let catalog = buf.get_str_lenenc()?; // we are told that this always "def" debug_assert_eq!(catalog, "def"); - let schema = unsafe { buf.get_str_lenenc_unchecked() }; - let table_alias = unsafe { buf.get_str_lenenc_unchecked() }; - let table = unsafe { buf.get_str_lenenc_unchecked() }; - let alias = unsafe { buf.get_str_lenenc_unchecked() }; - let name = unsafe { buf.get_str_lenenc_unchecked() }; + let schema = buf.get_str_lenenc()?; + let table_alias = buf.get_str_lenenc()?; + let table = buf.get_str_lenenc()?; + let alias = buf.get_str_lenenc()?; + let name = buf.get_str_lenenc()?; let fixed_len_fields_len = buf.get_uint_lenenc(); diff --git a/sqlx-mysql/src/protocol/err.rs b/sqlx-mysql/src/protocol/err.rs index bd7f5770..74ad3741 100644 --- a/sqlx-mysql/src/protocol/err.rs +++ b/sqlx-mysql/src/protocol/err.rs @@ -37,16 +37,12 @@ impl Deserialize<'_> for ErrPacket { // if the next byte is '#' then we have the SQL STATE buf.advance(1); - // UNSAFE: the SQL STATE is an ASCII error code - #[allow(unsafe_code)] - Some(unsafe { buf.get_str_unchecked(5) }) + Some(buf.get_str(5)?) } else { None }; - // UNSAFE: the human-readable error message is UTF-8 - #[allow(unsafe_code)] - let error_message = unsafe { buf.get_str_eof_unchecked() }; + let error_message = buf.get_str_eof()?; Ok(Self { sql_state, error_code, error_message }) } diff --git a/sqlx-mysql/src/protocol/handshake.rs b/sqlx-mysql/src/protocol/handshake.rs index f953db0a..240176ec 100644 --- a/sqlx-mysql/src/protocol/handshake.rs +++ b/sqlx-mysql/src/protocol/handshake.rs @@ -34,11 +34,7 @@ pub(crate) struct Handshake { impl Deserialize<'_> for Handshake { fn deserialize_with(mut buf: Bytes, _: ()) -> Result { let protocol_version = buf.get_u8(); - - // UNSAFE: server version is known to be ASCII - #[allow(unsafe_code)] - let server_version = unsafe { buf.get_str_nul_unchecked()? }; - + let server_version = buf.get_str_nul()?; let connection_id = buf.get_u32_le(); // first 8 bytes of the auth-plugin data @@ -103,12 +99,7 @@ impl Deserialize<'_> for Handshake { let auth_plugin_name_end = memchr(b'\0', &buf).unwrap_or_else(|| buf.len()); - // UNSAFE: auth plugin names are known to be ASCII - #[allow(unsafe_code)] - let auth_plugin_name_ = - unsafe { Some(buf.get_str_unchecked(auth_plugin_name_end)) }; - - auth_plugin_name = auth_plugin_name_; + auth_plugin_name = Some(buf.get_str(auth_plugin_name_end)?); } } diff --git a/sqlx-mysql/src/protocol/ok.rs b/sqlx-mysql/src/protocol/ok.rs index 41449cb1..eb3cc71d 100644 --- a/sqlx-mysql/src/protocol/ok.rs +++ b/sqlx-mysql/src/protocol/ok.rs @@ -46,13 +46,12 @@ impl Deserialize<'_, Capabilities> for OkPacket { ByteString::default() } else { // human readable status information - #[allow(unsafe_code)] if capabilities.contains(Capabilities::SESSION_TRACK) { // if [CLIENT_SESSION_TRACK] the info comes down as string - unsafe { buf.get_str_lenenc_unchecked() } + buf.get_str_lenenc()? } else { // otherwise the ASCII info is sent as string - unsafe { buf.get_str_eof_unchecked() } + buf.get_str_eof()? } };