refactor: remove unsafe for utf8 conversions within protocol

This commit is contained in:
Ryan Leckey 2021-02-22 23:41:22 -08:00
parent 3df0743bdf
commit d2da565d0b
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
7 changed files with 37 additions and 67 deletions

View File

@ -1,3 +1,4 @@
use std::convert::TryFrom;
use std::io;
use bytes::{Buf, Bytes};
@ -6,28 +7,22 @@ use memchr::memchr;
#[allow(clippy::module_name_repetitions)]
pub trait BufExt: Buf {
/// # Safety
/// This function is unsafe because it does not check the bytes that are read are valid UTF-8.
#[allow(unsafe_code)]
unsafe fn get_str_unchecked(&mut self, n: usize) -> ByteString;
fn get_str(&mut self, n: usize) -> io::Result<ByteString>;
/// # Safety
/// This function is unsafe because it does not check the bytes that are read are valid UTF-8.
#[allow(unsafe_code)]
unsafe fn get_str_nul_unchecked(&mut self) -> io::Result<ByteString>;
fn get_str_nul(&mut self) -> io::Result<ByteString>;
}
impl BufExt for Bytes {
#[allow(unsafe_code)]
unsafe fn get_str_unchecked(&mut self, n: usize) -> ByteString {
ByteString::from_bytes_unchecked(self.split_to(n))
fn get_str(&mut self, n: usize) -> io::Result<ByteString> {
ByteString::try_from(self.split_to(n))
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
}
#[allow(unsafe_code)]
unsafe fn get_str_nul_unchecked(&mut self) -> io::Result<ByteString> {
fn get_str_nul(&mut self) -> io::Result<ByteString> {
let nul = memchr(b'\0', self).ok_or(io::ErrorKind::InvalidData)?;
Ok(ByteString::from_bytes_unchecked(self.split_to(nul + 1).slice(..nul)))
ByteString::try_from(self.split_to(nul + 1).slice(..nul))
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
}
}
@ -40,32 +35,30 @@ mod tests {
use super::BufExt;
#[test]
fn test_get_str() {
fn test_get_str() -> io::Result<()> {
let mut buf = Bytes::from_static(b"Hello World\0");
#[allow(unsafe_code)]
let s = unsafe { buf.get_str_unchecked(5) };
let s = buf.get_str(5)?;
buf.advance(1);
#[allow(unsafe_code)]
let s2 = unsafe { buf.get_str_unchecked(5) };
let s2 = buf.get_str(5)?;
assert_eq!(&s, "Hello");
assert_eq!(&s2, "World");
Ok(())
}
#[test]
fn test_get_str_nul() -> io::Result<()> {
let mut buf = Bytes::from_static(b"Hello\0 World\0");
#[allow(unsafe_code)]
let s = unsafe { buf.get_str_nul_unchecked()? };
let s = buf.get_str_nul()?;
buf.advance(1);
#[allow(unsafe_code)]
let s2 = unsafe { buf.get_str_nul_unchecked()? };
let s2 = buf.get_str_nul()?;
assert_eq!(&s, "Hello");
assert_eq!(&s2, "World");

View File

@ -1,3 +1,5 @@
use std::io;
use bytes::{Buf, Bytes};
use bytestring::ByteString;
use sqlx_core::io::BufExt;
@ -9,11 +11,9 @@ use sqlx_core::io::BufExt;
pub(crate) trait MySqlBufExt: BufExt {
fn get_uint_lenenc(&mut self) -> u64;
#[allow(unsafe_code)]
unsafe fn get_str_lenenc_unchecked(&mut self) -> ByteString;
fn get_str_lenenc(&mut self) -> io::Result<ByteString>;
#[allow(unsafe_code)]
unsafe fn get_str_eof_unchecked(&mut self) -> ByteString;
fn get_str_eof(&mut self) -> io::Result<ByteString>;
fn get_bytes_lenenc(&mut self) -> Bytes;
}
@ -37,17 +37,14 @@ impl MySqlBufExt for Bytes {
}
}
#[allow(unsafe_code)]
unsafe fn get_str_lenenc_unchecked(&mut self) -> ByteString {
#[allow(clippy::cast_possible_truncation)]
fn get_str_lenenc(&mut self) -> io::Result<ByteString> {
let len = self.get_uint_lenenc() as usize;
self.get_str_unchecked(len)
self.get_str(len)
}
#[allow(unsafe_code)]
unsafe fn get_str_eof_unchecked(&mut self) -> ByteString {
self.get_str_unchecked(self.len())
fn get_str_eof(&mut self) -> io::Result<ByteString> {
self.get_str(self.len())
}
fn get_bytes_lenenc(&mut self) -> Bytes {

View File

@ -18,9 +18,7 @@ impl Deserialize<'_> for AuthSwitch {
let tag = buf.get_u8();
debug_assert_eq!(tag, 0xfe);
// SAFE: auth plugins are ASCII only
#[allow(unsafe_code)]
let name = unsafe { buf.get_str_nul_unchecked()? };
let name = buf.get_str_nul()?;
if buf.ends_with(&[0]) {
// if this terminates in a NUL; drop the NUL

View File

@ -26,21 +26,17 @@ pub(crate) struct ColumnDefinition {
}
impl Deserialize<'_> for ColumnDefinition {
#[allow(unsafe_code)]
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
// UNSAFE: fields are known to be UTF-8 as we have connected with the
// UTF-8 connection charset
let catalog = unsafe { buf.get_str_lenenc_unchecked() };
let catalog = buf.get_str_lenenc()?;
// we are told that this always "def"
debug_assert_eq!(catalog, "def");
let schema = unsafe { buf.get_str_lenenc_unchecked() };
let table_alias = unsafe { buf.get_str_lenenc_unchecked() };
let table = unsafe { buf.get_str_lenenc_unchecked() };
let alias = unsafe { buf.get_str_lenenc_unchecked() };
let name = unsafe { buf.get_str_lenenc_unchecked() };
let schema = buf.get_str_lenenc()?;
let table_alias = buf.get_str_lenenc()?;
let table = buf.get_str_lenenc()?;
let alias = buf.get_str_lenenc()?;
let name = buf.get_str_lenenc()?;
let fixed_len_fields_len = buf.get_uint_lenenc();

View File

@ -37,16 +37,12 @@ impl Deserialize<'_> for ErrPacket {
// if the next byte is '#' then we have the SQL STATE
buf.advance(1);
// UNSAFE: the SQL STATE is an ASCII error code
#[allow(unsafe_code)]
Some(unsafe { buf.get_str_unchecked(5) })
Some(buf.get_str(5)?)
} else {
None
};
// UNSAFE: the human-readable error message is UTF-8
#[allow(unsafe_code)]
let error_message = unsafe { buf.get_str_eof_unchecked() };
let error_message = buf.get_str_eof()?;
Ok(Self { sql_state, error_code, error_message })
}

View File

@ -34,11 +34,7 @@ pub(crate) struct Handshake {
impl Deserialize<'_> for Handshake {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let protocol_version = buf.get_u8();
// UNSAFE: server version is known to be ASCII
#[allow(unsafe_code)]
let server_version = unsafe { buf.get_str_nul_unchecked()? };
let server_version = buf.get_str_nul()?;
let connection_id = buf.get_u32_le();
// first 8 bytes of the auth-plugin data
@ -103,12 +99,7 @@ impl Deserialize<'_> for Handshake {
let auth_plugin_name_end = memchr(b'\0', &buf).unwrap_or_else(|| buf.len());
// UNSAFE: auth plugin names are known to be ASCII
#[allow(unsafe_code)]
let auth_plugin_name_ =
unsafe { Some(buf.get_str_unchecked(auth_plugin_name_end)) };
auth_plugin_name = auth_plugin_name_;
auth_plugin_name = Some(buf.get_str(auth_plugin_name_end)?);
}
}

View File

@ -46,13 +46,12 @@ impl Deserialize<'_, Capabilities> for OkPacket {
ByteString::default()
} else {
// human readable status information
#[allow(unsafe_code)]
if capabilities.contains(Capabilities::SESSION_TRACK) {
// if [CLIENT_SESSION_TRACK] the info comes down as string<lenenc>
unsafe { buf.get_str_lenenc_unchecked() }
buf.get_str_lenenc()?
} else {
// otherwise the ASCII info is sent as string<EOF>
unsafe { buf.get_str_eof_unchecked() }
buf.get_str_eof()?
}
};