mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 21:00:54 +00:00
refactor: remove unsafe for utf8 conversions within protocol
This commit is contained in:
parent
3df0743bdf
commit
d2da565d0b
@ -1,3 +1,4 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::io;
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
@ -6,28 +7,22 @@ use memchr::memchr;
|
||||
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub trait BufExt: Buf {
|
||||
/// # Safety
|
||||
/// This function is unsafe because it does not check the bytes that are read are valid UTF-8.
|
||||
#[allow(unsafe_code)]
|
||||
unsafe fn get_str_unchecked(&mut self, n: usize) -> ByteString;
|
||||
fn get_str(&mut self, n: usize) -> io::Result<ByteString>;
|
||||
|
||||
/// # Safety
|
||||
/// This function is unsafe because it does not check the bytes that are read are valid UTF-8.
|
||||
#[allow(unsafe_code)]
|
||||
unsafe fn get_str_nul_unchecked(&mut self) -> io::Result<ByteString>;
|
||||
fn get_str_nul(&mut self) -> io::Result<ByteString>;
|
||||
}
|
||||
|
||||
impl BufExt for Bytes {
|
||||
#[allow(unsafe_code)]
|
||||
unsafe fn get_str_unchecked(&mut self, n: usize) -> ByteString {
|
||||
ByteString::from_bytes_unchecked(self.split_to(n))
|
||||
fn get_str(&mut self, n: usize) -> io::Result<ByteString> {
|
||||
ByteString::try_from(self.split_to(n))
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
unsafe fn get_str_nul_unchecked(&mut self) -> io::Result<ByteString> {
|
||||
fn get_str_nul(&mut self) -> io::Result<ByteString> {
|
||||
let nul = memchr(b'\0', self).ok_or(io::ErrorKind::InvalidData)?;
|
||||
|
||||
Ok(ByteString::from_bytes_unchecked(self.split_to(nul + 1).slice(..nul)))
|
||||
ByteString::try_from(self.split_to(nul + 1).slice(..nul))
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,32 +35,30 @@ mod tests {
|
||||
use super::BufExt;
|
||||
|
||||
#[test]
|
||||
fn test_get_str() {
|
||||
fn test_get_str() -> io::Result<()> {
|
||||
let mut buf = Bytes::from_static(b"Hello World\0");
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
let s = unsafe { buf.get_str_unchecked(5) };
|
||||
let s = buf.get_str(5)?;
|
||||
|
||||
buf.advance(1);
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
let s2 = unsafe { buf.get_str_unchecked(5) };
|
||||
let s2 = buf.get_str(5)?;
|
||||
|
||||
assert_eq!(&s, "Hello");
|
||||
assert_eq!(&s2, "World");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_str_nul() -> io::Result<()> {
|
||||
let mut buf = Bytes::from_static(b"Hello\0 World\0");
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
let s = unsafe { buf.get_str_nul_unchecked()? };
|
||||
let s = buf.get_str_nul()?;
|
||||
|
||||
buf.advance(1);
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
let s2 = unsafe { buf.get_str_nul_unchecked()? };
|
||||
let s2 = buf.get_str_nul()?;
|
||||
|
||||
assert_eq!(&s, "Hello");
|
||||
assert_eq!(&s2, "World");
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
use std::io;
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use bytestring::ByteString;
|
||||
use sqlx_core::io::BufExt;
|
||||
@ -9,11 +11,9 @@ use sqlx_core::io::BufExt;
|
||||
pub(crate) trait MySqlBufExt: BufExt {
|
||||
fn get_uint_lenenc(&mut self) -> u64;
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
unsafe fn get_str_lenenc_unchecked(&mut self) -> ByteString;
|
||||
fn get_str_lenenc(&mut self) -> io::Result<ByteString>;
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
unsafe fn get_str_eof_unchecked(&mut self) -> ByteString;
|
||||
fn get_str_eof(&mut self) -> io::Result<ByteString>;
|
||||
|
||||
fn get_bytes_lenenc(&mut self) -> Bytes;
|
||||
}
|
||||
@ -37,17 +37,14 @@ impl MySqlBufExt for Bytes {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
unsafe fn get_str_lenenc_unchecked(&mut self) -> ByteString {
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
fn get_str_lenenc(&mut self) -> io::Result<ByteString> {
|
||||
let len = self.get_uint_lenenc() as usize;
|
||||
|
||||
self.get_str_unchecked(len)
|
||||
self.get_str(len)
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
unsafe fn get_str_eof_unchecked(&mut self) -> ByteString {
|
||||
self.get_str_unchecked(self.len())
|
||||
fn get_str_eof(&mut self) -> io::Result<ByteString> {
|
||||
self.get_str(self.len())
|
||||
}
|
||||
|
||||
fn get_bytes_lenenc(&mut self) -> Bytes {
|
||||
|
||||
@ -18,9 +18,7 @@ impl Deserialize<'_> for AuthSwitch {
|
||||
let tag = buf.get_u8();
|
||||
debug_assert_eq!(tag, 0xfe);
|
||||
|
||||
// SAFE: auth plugins are ASCII only
|
||||
#[allow(unsafe_code)]
|
||||
let name = unsafe { buf.get_str_nul_unchecked()? };
|
||||
let name = buf.get_str_nul()?;
|
||||
|
||||
if buf.ends_with(&[0]) {
|
||||
// if this terminates in a NUL; drop the NUL
|
||||
|
||||
@ -26,21 +26,17 @@ pub(crate) struct ColumnDefinition {
|
||||
}
|
||||
|
||||
impl Deserialize<'_> for ColumnDefinition {
|
||||
#[allow(unsafe_code)]
|
||||
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
// UNSAFE: fields are known to be UTF-8 as we have connected with the
|
||||
// UTF-8 connection charset
|
||||
|
||||
let catalog = unsafe { buf.get_str_lenenc_unchecked() };
|
||||
let catalog = buf.get_str_lenenc()?;
|
||||
|
||||
// we are told that this always "def"
|
||||
debug_assert_eq!(catalog, "def");
|
||||
|
||||
let schema = unsafe { buf.get_str_lenenc_unchecked() };
|
||||
let table_alias = unsafe { buf.get_str_lenenc_unchecked() };
|
||||
let table = unsafe { buf.get_str_lenenc_unchecked() };
|
||||
let alias = unsafe { buf.get_str_lenenc_unchecked() };
|
||||
let name = unsafe { buf.get_str_lenenc_unchecked() };
|
||||
let schema = buf.get_str_lenenc()?;
|
||||
let table_alias = buf.get_str_lenenc()?;
|
||||
let table = buf.get_str_lenenc()?;
|
||||
let alias = buf.get_str_lenenc()?;
|
||||
let name = buf.get_str_lenenc()?;
|
||||
|
||||
let fixed_len_fields_len = buf.get_uint_lenenc();
|
||||
|
||||
|
||||
@ -37,16 +37,12 @@ impl Deserialize<'_> for ErrPacket {
|
||||
// if the next byte is '#' then we have the SQL STATE
|
||||
buf.advance(1);
|
||||
|
||||
// UNSAFE: the SQL STATE is an ASCII error code
|
||||
#[allow(unsafe_code)]
|
||||
Some(unsafe { buf.get_str_unchecked(5) })
|
||||
Some(buf.get_str(5)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// UNSAFE: the human-readable error message is UTF-8
|
||||
#[allow(unsafe_code)]
|
||||
let error_message = unsafe { buf.get_str_eof_unchecked() };
|
||||
let error_message = buf.get_str_eof()?;
|
||||
|
||||
Ok(Self { sql_state, error_code, error_message })
|
||||
}
|
||||
|
||||
@ -34,11 +34,7 @@ pub(crate) struct Handshake {
|
||||
impl Deserialize<'_> for Handshake {
|
||||
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
let protocol_version = buf.get_u8();
|
||||
|
||||
// UNSAFE: server version is known to be ASCII
|
||||
#[allow(unsafe_code)]
|
||||
let server_version = unsafe { buf.get_str_nul_unchecked()? };
|
||||
|
||||
let server_version = buf.get_str_nul()?;
|
||||
let connection_id = buf.get_u32_le();
|
||||
|
||||
// first 8 bytes of the auth-plugin data
|
||||
@ -103,12 +99,7 @@ impl Deserialize<'_> for Handshake {
|
||||
|
||||
let auth_plugin_name_end = memchr(b'\0', &buf).unwrap_or_else(|| buf.len());
|
||||
|
||||
// UNSAFE: auth plugin names are known to be ASCII
|
||||
#[allow(unsafe_code)]
|
||||
let auth_plugin_name_ =
|
||||
unsafe { Some(buf.get_str_unchecked(auth_plugin_name_end)) };
|
||||
|
||||
auth_plugin_name = auth_plugin_name_;
|
||||
auth_plugin_name = Some(buf.get_str(auth_plugin_name_end)?);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -46,13 +46,12 @@ impl Deserialize<'_, Capabilities> for OkPacket {
|
||||
ByteString::default()
|
||||
} else {
|
||||
// human readable status information
|
||||
#[allow(unsafe_code)]
|
||||
if capabilities.contains(Capabilities::SESSION_TRACK) {
|
||||
// if [CLIENT_SESSION_TRACK] the info comes down as string<lenenc>
|
||||
unsafe { buf.get_str_lenenc_unchecked() }
|
||||
buf.get_str_lenenc()?
|
||||
} else {
|
||||
// otherwise the ASCII info is sent as string<EOF>
|
||||
unsafe { buf.get_str_eof_unchecked() }
|
||||
buf.get_str_eof()?
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user