feat(mysql): implement more protocol types: QueryResponse, Query, QueryStep, ColumnDef, Row

This commit is contained in:
Ryan Leckey 2021-01-26 01:14:47 -08:00
parent 76dd78f639
commit d279c6b978
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
18 changed files with 359 additions and 92 deletions

View File

@ -14,6 +14,10 @@ impl MySqlDatabaseError {
pub(crate) fn new(code: u16, message: &str) -> Self {
Self(ErrPacket::new(code, message))
}
pub(crate) fn malformed_packet(message: &str) -> Self {
Self::new(2027, &format!("Malformed packet: {}", message))
}
}
impl DatabaseError for MySqlDatabaseError {

View File

@ -1,25 +1,39 @@
mod auth;
mod auth_plugin;
mod auth_response;
mod auth_switch;
mod capabilities;
mod column_def;
mod command;
mod eof;
mod err;
mod handshake;
mod handshake_response;
mod ok;
mod ping;
mod query;
mod query_response;
mod query_step;
mod packet;
mod quit;
mod row;
mod status;
pub(crate) use auth::{Auth, AuthResponse};
pub(crate) use packet::Packet;
pub(crate) use auth_plugin::AuthPlugin;
pub(crate) use auth_response::AuthResponse;
pub(crate) use auth_switch::AuthSwitch;
pub(crate) use capabilities::Capabilities;
pub(crate) use column_def::ColumnDefinition;
pub(crate) use command::{Command, MaybeCommand};
pub(crate) use eof::EofPacket;
pub(crate) use err::ErrPacket;
pub(crate) use handshake::Handshake;
pub(crate) use handshake_response::HandshakeResponse;
pub(crate) use ok::OkPacket;
pub(crate) use ping::Ping;
pub(crate) use query::Query;
pub(crate) use query_response::QueryResponse;
pub(crate) use query_step::QueryStep;
pub(crate) use quit::Quit;
pub(crate) use row::Row;
pub(crate) use status::Status;

View File

@ -1,48 +0,0 @@
use std::fmt::Debug;
use bytes::Bytes;
use sqlx_core::io::{Deserialize, Serialize};
use sqlx_core::{Error, Result};
use crate::protocol::{AuthSwitch, Capabilities, MaybeCommand, OkPacket};
use crate::MySqlDatabaseError;
#[derive(Debug)]
pub(crate) enum Auth {
Ok(OkPacket),
MoreData(Bytes),
Switch(AuthSwitch),
}
impl Deserialize<'_, Capabilities> for Auth {
fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result<Self> {
match buf[0] {
0x00 => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok),
0x01 => Ok(Self::MoreData(buf.slice(1..))),
0xfe => AuthSwitch::deserialize_with(buf, capabilities).map(Self::Switch),
tag => Err(Error::connect(MySqlDatabaseError::new(
2027,
&format!(
"Malformed packet: Received 0x{:x} but expected one of: 0x0, 0x1, or 0xfe",
tag
),
))),
}
}
}
#[derive(Debug)]
pub(crate) struct AuthResponse {
pub(crate) data: Vec<u8>,
}
impl MaybeCommand for AuthResponse {}
impl Serialize<'_, Capabilities> for AuthResponse {
fn serialize_with(&self, buf: &mut Vec<u8>, _context: Capabilities) -> Result<()> {
buf.extend_from_slice(&self.data);
Ok(())
}
}

View File

@ -0,0 +1,34 @@
use std::fmt::Debug;
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::{Error, Result};
use crate::protocol::{AuthSwitch, Capabilities, OkPacket};
use crate::MySqlDatabaseError;
#[derive(Debug)]
pub(crate) enum AuthResponse {
Ok(OkPacket),
MoreData(Bytes),
Switch(AuthSwitch),
}
impl Deserialize<'_, Capabilities> for AuthResponse {
fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result<Self> {
match buf.get(0) {
Some(0x00) => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok),
Some(0x01) => Ok(Self::MoreData(buf.slice(1..))),
Some(0xfe) => AuthSwitch::deserialize(buf).map(Self::Switch),
Some(tag) => Err(Error::connect(MySqlDatabaseError::malformed_packet(&format!(
"Received 0x{:x} but expected one of: 0x0 (OK), 0x1 (MORE DATA), or 0xfe (SWITCH) for auth response",
tag
)))),
None => Err(Error::connect(MySqlDatabaseError::malformed_packet(
"Received no bytes for auth response",
))),
}
}
}

View File

@ -2,7 +2,6 @@ use bytes::{buf::Chain, Buf, Bytes};
use sqlx_core::io::{BufExt, Deserialize};
use sqlx_core::Result;
use super::Capabilities;
use crate::protocol::AuthPlugin;
// https://dev.mysql.com/doc/internals/en/authentication-method-change.html
@ -14,8 +13,8 @@ pub(crate) struct AuthSwitch {
pub(crate) plugin_data: Chain<Bytes, Bytes>,
}
impl Deserialize<'_, Capabilities> for AuthSwitch {
fn deserialize_with(mut buf: Bytes, _capabilities: Capabilities) -> Result<Self> {
impl Deserialize<'_> for AuthSwitch {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let tag = buf.get_u8();
debug_assert_eq!(tag, 0xfe);

View File

@ -0,0 +1,65 @@
use bytes::{Buf, Bytes};
use bytestring::ByteString;
use sqlx_core::io::Deserialize;
use sqlx_core::Result;
use crate::io::MySqlBufExt;
/// Describes a column in the result set.
///
/// <https://mariadb.com/kb/en/result-set-packets/#column-definition-packet>
/// <https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition>
#[derive(Debug)]
pub(crate) struct ColumnDefinition {
pub(crate) catalog: ByteString,
pub(crate) schema: ByteString,
pub(crate) table_alias: ByteString,
pub(crate) table: ByteString,
pub(crate) alias: ByteString,
pub(crate) name: ByteString,
pub(crate) charset: u16,
pub(crate) max_size: u32,
pub(crate) ty: u8,
pub(crate) flags: u16,
pub(crate) decimals: u8,
}
impl Deserialize<'_> for ColumnDefinition {
#[allow(unsafe_code)]
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
// UNSAFE: fields are known to be UTF-8 as we have connected with the
// UTF-8 connection charset
let catalog = unsafe { buf.get_str_lenenc_unchecked() };
let schema = unsafe { buf.get_str_lenenc_unchecked() };
let table_alias = unsafe { buf.get_str_lenenc_unchecked() };
let table = unsafe { buf.get_str_lenenc_unchecked() };
let alias = unsafe { buf.get_str_lenenc_unchecked() };
let name = unsafe { buf.get_str_lenenc_unchecked() };
let fixed_len_fields_len = buf.get_uint_lenenc();
// we are told that this is *always* 0x0c
debug_assert_eq!(fixed_len_fields_len, 0x0c);
let charset = buf.get_u16_le();
let max_size = buf.get_u32_le();
let ty = buf.get_u8();
let flags = buf.get_u16_le();
let decimals = buf.get_u8();
Ok(Self {
catalog,
schema,
table_alias,
table,
alias,
name,
charset,
max_size,
ty,
flags,
decimals,
})
}
}

View File

@ -13,6 +13,9 @@ pub(crate) trait MaybeCommand {
}
}
// raw bytes are not a command
impl MaybeCommand for &'_ [u8] {}
/// Marker trait to signal that this protocol type is a Command.
pub(crate) trait Command: MaybeCommand {}

View File

@ -0,0 +1,31 @@
use bytes::{Buf, Bytes};
use sqlx_core::io::Deserialize;
use sqlx_core::Result;
use crate::protocol::{Capabilities, Status};
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub(crate) struct EofPacket {
pub(crate) status: Status,
pub(crate) warnings: u16,
}
impl Deserialize<'_, Capabilities> for EofPacket {
fn deserialize_with(mut buf: Bytes, capabilities: Capabilities) -> Result<Self> {
let tag = buf.get_u8();
debug_assert_eq!(tag, 0xfe);
let status =
if capabilities.intersects(Capabilities::PROTOCOL_41 | Capabilities::TRANSACTIONS) {
Status::from_bits_truncate(buf.get_u16_le())
} else {
Status::empty()
};
let warnings =
if capabilities.contains(Capabilities::PROTOCOL_41) { buf.get_u16_le() } else { 0 };
Ok(Self { status, warnings })
}
}

View File

@ -4,7 +4,6 @@ use sqlx_core::io::{BufExt, Deserialize};
use sqlx_core::Result;
use crate::io::MySqlBufExt;
use crate::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html
@ -27,14 +26,14 @@ impl ErrPacket {
}
}
impl Deserialize<'_, Capabilities> for ErrPacket {
fn deserialize_with(mut buf: Bytes, capabilities: Capabilities) -> Result<Self> {
impl Deserialize<'_> for ErrPacket {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let tag = buf.get_u8();
debug_assert!(tag == 0xff);
let error_code = buf.get_u16_le();
let sql_state = if capabilities.contains(Capabilities::PROTOCOL_41) && buf[0] == b'#' {
let sql_state = if buf[0] == b'#' {
// if the next byte is '#' then we have the SQL STATE
buf.advance(1);
@ -55,14 +54,13 @@ impl Deserialize<'_, Capabilities> for ErrPacket {
#[cfg(test)]
mod tests {
use super::{Capabilities, Deserialize, ErrPacket};
use super::{Deserialize, ErrPacket};
#[test]
fn test_err_connect_auth() {
const DATA: &[u8] = b"\xff\xe3\x04Client does not support authentication protocol requested by server; consider upgrading MySQL client";
let capabilities = Capabilities::PROTOCOL_41;
let ok = ErrPacket::deserialize_with(DATA.into(), capabilities).unwrap();
let ok = ErrPacket::deserialize(DATA.into()).unwrap();
assert_eq!(ok.sql_state, None);
assert_eq!(ok.error_code, 1251);
@ -76,8 +74,7 @@ mod tests {
fn test_err_out_of_order() {
const DATA: &[u8] = b"\xff\x84\x04Got packets out of order";
let capabilities = Capabilities::PROTOCOL_41;
let ok = ErrPacket::deserialize_with(DATA.into(), capabilities).unwrap();
let ok = ErrPacket::deserialize(DATA.into()).unwrap();
assert_eq!(ok.sql_state, None);
assert_eq!(ok.error_code, 1156);
@ -88,8 +85,7 @@ mod tests {
fn test_err_unknown_database() {
const DATA: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'";
let capabilities = Capabilities::PROTOCOL_41;
let ok = ErrPacket::deserialize_with(DATA.into(), capabilities).unwrap();
let ok = ErrPacket::deserialize(DATA.into()).unwrap();
assert_eq!(ok.sql_state.as_deref(), Some("42000"));
assert_eq!(ok.error_code, 1049);

View File

@ -31,8 +31,8 @@ pub(crate) struct Handshake {
pub(crate) auth_plugin_data: Chain<Bytes, Bytes>,
}
impl Deserialize<'_, Capabilities> for Handshake {
fn deserialize_with(mut buf: Bytes, _: Capabilities) -> Result<Self> {
impl Deserialize<'_> for Handshake {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
let protocol_version = buf.get_u8();
// UNSAFE: server version is known to be ASCII
@ -135,13 +135,11 @@ mod tests {
use super::{Capabilities, Handshake, Status};
const EMPTY: Capabilities = Capabilities::empty();
#[test]
fn handshake_mysql_8_0_18() {
const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00";
let mut h = Handshake::deserialize_with(HANDSHAKE_MYSQL_8_0_18.into(), EMPTY).unwrap();
let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_8_0_18.into()).unwrap();
assert_eq!(h.protocol_version, 10);
@ -186,7 +184,7 @@ mod tests {
fn handshake_mariadb_10_4_7() {
const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"<H5n\x00mysql_native_password\x00";
let mut h = Handshake::deserialize_with(HANDSHAKE_MARIA_DB_10_4_7.into(), EMPTY).unwrap();
let mut h = Handshake::deserialize(HANDSHAKE_MARIA_DB_10_4_7.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(&*h.server_version, "5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic");
@ -230,7 +228,7 @@ mod tests {
fn handshake_mariadb_10_5_8() {
const HANDSHAKE_MARIA_DB_10_5_8: &[u8] = b"\n5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal\0\x07\0\0\0'PB949cf\0\xfe\xf7-\x02\0\xff\x81\x15\0\0\0\0\0\0\x0f\0\0\0UY>hr&`3{55H\0mysql_native_password\0";
let mut h = Handshake::deserialize_with(HANDSHAKE_MARIA_DB_10_5_8.into(), EMPTY).unwrap();
let mut h = Handshake::deserialize(HANDSHAKE_MARIA_DB_10_5_8.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(&*h.server_version, "5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal");
@ -274,7 +272,7 @@ mod tests {
fn handshake_mysql_5_6_50() {
const HANDSHAKE_MYSQL_5_6_50: &[u8] = b"\n5.6.50\0\x01\0\0\0-VLYZ:Pd\0\xff\xf7\x08\x02\0\x7f\x80\x15\0\0\0\0\0\0\0\0\0\0'2f+BL8nGV[G\0mysql_native_password\0";
let mut h = Handshake::deserialize_with(HANDSHAKE_MYSQL_5_6_50.into(), EMPTY).unwrap();
let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_6_50.into()).unwrap();
assert_eq!(h.protocol_version, 10);
@ -318,7 +316,7 @@ mod tests {
fn handshake_mysql_5_0_96() {
const HANDSHAKE_MYSQL_5_0_96: &[u8] = b"\n5.0.96\0\x03\0\0\0bs=sNiGe\0,\xa2\x08\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0IzMP)yLLx;[9\0";
let mut h = Handshake::deserialize_with(HANDSHAKE_MYSQL_5_0_96.into(), EMPTY).unwrap();
let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_0_96.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(&*h.server_version, "5.0.96");
@ -350,7 +348,7 @@ mod tests {
fn handshake_mysql_5_1_73() {
const HANDSHAKE_MYSQL_5_1_73: &[u8] = b"\n5.1.73\0\x01\0\0\0<fllZ\\Bs\0\xff\xf7\x08\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0<qEC_87JO/9q\0";
let mut h = Handshake::deserialize_with(HANDSHAKE_MYSQL_5_1_73.into(), EMPTY).unwrap();
let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_1_73.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(&*h.server_version, "5.1.73");
@ -386,7 +384,7 @@ mod tests {
fn handshake_mysql_5_5_14() {
const HANDSHAKE_MYSQL_5_5_14: &[u8] = b"\n5.5.14\0\x01\0\0\0`o-/CEp'\0\xff\xf7\x08\x02\0\x0f\x80\x15\0\0\0\0\0\0\0\0\0\0kf@J5j6nJfAP\0mysql_native_password\0";
let mut h = Handshake::deserialize_with(HANDSHAKE_MYSQL_5_5_14.into(), EMPTY).unwrap();
let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_5_14.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(&*h.server_version, "5.5.14");

View File

@ -9,6 +9,7 @@ use crate::protocol::{Capabilities, MaybeCommand};
#[derive(Debug)]
pub(crate) struct HandshakeResponse<'a> {
pub(crate) capabilities: Capabilities,
pub(crate) database: Option<&'a str>,
pub(crate) max_packet_size: u32,
pub(crate) charset: u8,
@ -19,12 +20,12 @@ pub(crate) struct HandshakeResponse<'a> {
impl MaybeCommand for HandshakeResponse<'_> {}
impl Serialize<'_, Capabilities> for HandshakeResponse<'_> {
fn serialize_with(&self, buf: &mut Vec<u8>, capabilities: Capabilities) -> Result<()> {
impl Serialize<'_> for HandshakeResponse<'_> {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
// the truncation is the intent
// capability bits over 32 are MariaDB only (and we don't currently support them)
#[allow(clippy::cast_possible_truncation)]
buf.extend_from_slice(&(capabilities.bits() as u32).to_le_bytes());
buf.extend_from_slice(&(self.capabilities.bits() as u32).to_le_bytes());
buf.extend_from_slice(&self.max_packet_size.to_le_bytes());
buf.push(self.charset);
@ -35,9 +36,9 @@ impl Serialize<'_, Capabilities> for HandshakeResponse<'_> {
let auth_response = self.auth_response.as_slice();
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
if self.capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
buf.write_bytes_lenenc(auth_response);
} else if capabilities.contains(Capabilities::SECURE_CONNECTION) {
} else if self.capabilities.contains(Capabilities::SECURE_CONNECTION) {
debug_assert!(auth_response.len() <= u8::max_value().into());
buf.reserve(1 + auth_response.len());
@ -53,11 +54,11 @@ impl Serialize<'_, Capabilities> for HandshakeResponse<'_> {
buf.push(b'\0');
}
if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
if self.capabilities.contains(Capabilities::CONNECT_WITH_DB) {
buf.write_maybe_str_nul(self.database);
}
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
if self.capabilities.contains(Capabilities::PLUGIN_AUTH) {
buf.write_str_nul(self.auth_plugin_name);
}

View File

@ -0,0 +1,32 @@
use std::fmt::Debug;
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::Result;
#[derive(Debug)]
pub(crate) struct Packet {
pub(crate) bytes: Bytes,
}
impl Packet {
#[inline]
pub(crate) fn deserialize<'de, T>(self) -> Result<T>
where
T: Deserialize<'de> + Debug,
{
self.deserialize_with(())
}
#[inline]
pub(crate) fn deserialize_with<'de, T, Cx: 'de>(self, context: Cx) -> Result<T>
where
T: Deserialize<'de, Cx> + Debug,
{
let packet = T::deserialize_with(self.bytes, context)?;
log::trace!("read > {:?}", packet);
Ok(packet)
}
}

View File

@ -1,7 +1,7 @@
use sqlx_core::io::Serialize;
use sqlx_core::Result;
use crate::protocol::{Capabilities, Command};
use crate::protocol::Command;
/// Check if the server is alive.
///
@ -11,8 +11,8 @@ use crate::protocol::{Capabilities, Command};
#[derive(Debug)]
pub(crate) struct Ping;
impl Serialize<'_, Capabilities> for Ping {
fn serialize_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<()> {
impl Serialize<'_> for Ping {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.push(0x0e);
Ok(())
@ -26,12 +26,11 @@ mod tests {
use sqlx_core::io::Serialize;
use super::Ping;
use crate::protocol::Capabilities;
#[test]
fn should_serialize() -> anyhow::Result<()> {
let mut buf = Vec::new();
Ping.serialize_with(&mut buf, Capabilities::empty())?;
Ping.serialize(&mut buf)?;
assert_eq!(&buf, &[0x0e]);

View File

@ -0,0 +1,25 @@
use sqlx_core::io::Serialize;
use sqlx_core::Result;
use super::Command;
/// Send the server a text-based query that is executed immediately.
///
/// https://dev.mysql.com/doc/internals/en/com-query.html
/// https://mariadb.com/kb/en/com_query/
///
#[derive(Debug)]
pub(crate) struct Query<'q> {
pub(crate) sql: &'q str,
}
impl Serialize<'_> for Query<'_> {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.push(0x03);
buf.extend_from_slice(self.sql.as_bytes());
Ok(())
}
}
impl Command for Query<'_> {}

View File

@ -0,0 +1,47 @@
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::{Error, Result};
use super::{Capabilities, OkPacket};
use crate::io::MySqlBufExt;
use crate::MySqlDatabaseError;
/// The query-response packet is a meta-packet that starts with one of:
///
/// - OK packet
/// - ERR packet
/// - LOCAL INFILE request (unimplemented)
/// - Result Set
///
/// A result set is *also* a meta-packet that starts with a length-encoded
/// integer for the number of columns. That is all we return from this
/// deserialization and expect the executor to follow up with reading
/// more from the stream.
///
/// <https://dev.mysql.com/doc/internals/en/com-query-response.html>
///
#[derive(Debug)]
pub(crate) enum QueryResponse {
Ok(OkPacket),
ResultSet { columns: u64 },
}
impl Deserialize<'_, Capabilities> for QueryResponse {
fn deserialize_with(mut buf: Bytes, capabilities: Capabilities) -> Result<Self> {
// .get does not consume the byte
match buf.get(0) {
Some(0x00) => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok),
// ERR packets are handled on a higher-level (in `recv_packet`), we will
// never receive them here
// If its non-0, then its the number of columns and the start
// of a result set
Some(_) => Ok(Self::ResultSet { columns: buf.get_uint_lenenc() }),
None => Err(Error::connect(MySqlDatabaseError::malformed_packet(
"Received no bytes for COM_QUERY response",
))),
}
}
}

View File

@ -0,0 +1,40 @@
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::{Error, Result};
use super::{Capabilities, ColumnDefinition, OkPacket, Row};
use crate::MySqlDatabaseError;
/// <https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset>
/// <https://mariadb.com/kb/en/result-set-packets/>
#[derive(Debug)]
pub(crate) enum QueryStep {
Row(Row),
End(OkPacket),
}
impl Deserialize<'_, (Capabilities, &'_ [ColumnDefinition])> for QueryStep {
fn deserialize_with(
buf: Bytes,
(capabilities, columns): (Capabilities, &'_ [ColumnDefinition]),
) -> Result<Self> {
// .get does not consume the byte
match buf.get(0) {
// To safely confirm that a packet with a 0xFE header is an OK packet (OK_Packet) or an
// EOF packet (EOF_Packet), you must also check that the packet length is less than 0xFFFFFF
Some(0xfe) if buf.len() < 0xFF_FF_FF => {
OkPacket::deserialize_with(buf, capabilities).map(Self::End)
}
// ERR packets are handled on a higher-level (in `recv_packet`), we will
// never receive them here
// If its non-0, then its a Row
Some(_) => Row::deserialize_with(buf, columns).map(Self::Row),
None => Err(Error::connect(MySqlDatabaseError::malformed_packet(
"Received no bytes for the next step in a result set",
))),
}
}
}

View File

@ -1,7 +1,7 @@
use sqlx_core::io::Serialize;
use sqlx_core::Result;
use crate::protocol::{Capabilities, Command};
use crate::protocol::Command;
/// Tells the server that the client wants to close the connection.
///
@ -10,8 +10,8 @@ use crate::protocol::{Capabilities, Command};
#[derive(Debug)]
pub(crate) struct Quit;
impl Serialize<'_, Capabilities> for Quit {
fn serialize_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<()> {
impl Serialize<'_> for Quit {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
buf.push(0x01);
Ok(())
@ -25,12 +25,11 @@ mod tests {
use sqlx_core::io::Serialize;
use super::Quit;
use crate::protocol::Capabilities;
#[test]
fn should_serialize() -> anyhow::Result<()> {
let mut buf = Vec::new();
Quit.serialize_with(&mut buf, Capabilities::empty())?;
Quit.serialize(&mut buf)?;
assert_eq!(&buf, &[0x01]);

View File

@ -0,0 +1,28 @@
use bytes::{Buf, Bytes};
use sqlx_core::io::Deserialize;
use sqlx_core::Result;
use crate::io::MySqlBufExt;
use crate::protocol::ColumnDefinition;
#[derive(Debug)]
pub(crate) struct Row {
pub(crate) values: Vec<Option<Bytes>>,
}
impl<'de> Deserialize<'de, &'de [ColumnDefinition]> for Row {
fn deserialize_with(mut buf: Bytes, columns: &'de [ColumnDefinition]) -> Result<Self> {
let mut values = Vec::with_capacity(columns.len());
for _ in columns {
values.push(if buf.get(0).copied() == Some(0xfb) {
buf.advance(1);
None
} else {
Some(buf.get_bytes_lenenc())
});
}
Ok(Self { values })
}
}