WIP: Add serde for prepared statements

This commit is contained in:
Daniel Akhterov 2019-07-29 21:35:08 -07:00
parent 3ddf4508af
commit 9ca4e10836
34 changed files with 506 additions and 93 deletions

View File

@ -102,15 +102,10 @@ impl Connection {
Ok(conn)
}
pub async fn send<S>(&mut self, message: S) -> Result<(), Error>
where
S: Serialize,
{
pub async fn send<S>(&mut self, message: S) -> Result<(), Error> where S: Serialize {
self.encoder.clear();
self.encoder.alloc_packet_header();
self.encoder.seq_no(self.context.seq_no);
message.serialize(&mut self.context, &mut self.encoder)?;
self.encoder.encode_length();
self.stream.inner.write_all(&self.encoder.buf).await?;
self.stream.inner.flush().await?;

View File

@ -29,6 +29,11 @@ pub use protocol::PacketHeader;
pub use protocol::ResultSet;
pub use protocol::ResultRow;
pub use protocol::SSLRequestPacket;
pub use protocol::ComStmtPrepare;
pub use protocol::ComStmtPrepareOk;
pub use protocol::ComStmtPrepareResp;
pub use protocol::ComStmtClose;
pub use protocol::ComStmtExec;
pub use protocol::Decoder;
pub use protocol::DeContext;
pub use protocol::Deserialize;
@ -41,4 +46,6 @@ pub use protocol::ServerStatusFlag;
pub use protocol::FieldType;
pub use protocol::FieldDetailFlag;
pub use protocol::SessionChangeType;
pub use protocol::StmtExecFlag;
pub use protocol::TextProtocol;
pub use protocol::BinaryProtocol;

View File

@ -35,3 +35,16 @@ impl Into<u8> for TextProtocol {
self as u8
}
}
pub enum BinaryProtocol {
ComStmtPrepare = 0x16,
ComStmtClose = 0x19,
ComStmtExec = 0x17,
}
// Helper method to easily transform into u8
impl Into<u8> for BinaryProtocol {
fn into(self) -> u8 {
self as u8
}
}

View File

@ -108,7 +108,7 @@ impl<'a> Decoder<'a> {
// Decode an int<8> which is a i64
#[inline]
pub fn decode_int_8(&mut self) -> i64 {
pub fn decode_int_i64(&mut self) -> i64 {
let value = LittleEndian::read_i64(&self.buf[self.index..]);
self.index += 8;
value
@ -116,16 +116,16 @@ impl<'a> Decoder<'a> {
// Decode an int<4> which is a i32
#[inline]
pub fn decode_int_4(&mut self) -> i32 {
pub fn decode_int_i32(&mut self) -> i32 {
let value = LittleEndian::read_i32(&self.buf[self.index..]);
self.index += 4;
value
}
// Decode an int<4> which is a i32
// Decode an int<4> which is a u32
// This is a helper method for decoding flags.
#[inline]
pub fn decode_int_4_unsigned(&mut self) -> u32 {
pub fn decode_int_u32(&mut self) -> u32 {
let value = LittleEndian::read_u32(&self.buf[self.index..]);
self.index += 4;
value
@ -133,7 +133,7 @@ impl<'a> Decoder<'a> {
// Decode an int<3> which is a i24
#[inline]
pub fn decode_int_3(&mut self) -> i32 {
pub fn decode_int_i24(&mut self) -> i32 {
let value = LittleEndian::read_i24(&self.buf[self.index..]);
self.index += 3;
value
@ -141,7 +141,7 @@ impl<'a> Decoder<'a> {
// Decode an int<2> which is a i16
#[inline]
pub fn decode_int_2(&mut self) -> i16 {
pub fn decode_int_i16(&mut self) -> i16 {
let value = LittleEndian::read_i16(&self.buf[self.index..]);
self.index += 2;
value
@ -150,7 +150,7 @@ impl<'a> Decoder<'a> {
// Decode an int<2> as an u16
// This is a helper method for decoding flags.
#[inline]
pub fn decode_int_2_unsigned(&mut self) -> u16 {
pub fn decode_int_u16(&mut self) -> u16 {
let value = LittleEndian::read_u16(&self.buf[self.index..]);
self.index += 2;
value
@ -158,7 +158,7 @@ impl<'a> Decoder<'a> {
// Decode an int<1> which is a u8
#[inline]
pub fn decode_int_1(&mut self) -> u8 {
pub fn decode_int_u8(&mut self) -> u8 {
let value = self.buf[self.index];
self.index += 1;
value
@ -221,7 +221,7 @@ impl<'a> Decoder<'a> {
// Same as the string counter part, but copied to maintain consistency with the spec.
#[inline]
pub fn decode_byte_lenenc(&mut self) -> Bytes {
let length = self.decode_int_1();
let length = self.decode_int_u8();
let value = self.buf.slice(self.index, self.index + length as usize);
self.index = self.index + length as usize;
value
@ -319,7 +319,7 @@ mod tests {
fn it_decodes_int_8() {
let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8);
let mut decoder = Decoder::new(&buf);
let int: i64 = decoder.decode_int_8();
let int: i64 = decoder.decode_int_i64();
assert_eq!(int, 0x0101010101010101);
assert_eq!(decoder.index, 8);
@ -329,7 +329,7 @@ mod tests {
fn it_decodes_int_4() {
let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8);
let mut decoder = Decoder::new(&buf);
let int: i32 = decoder.decode_int_4();
let int: i32 = decoder.decode_int_i32();
assert_eq!(int, 0x01010101);
assert_eq!(decoder.index, 4);
@ -339,7 +339,7 @@ mod tests {
fn it_decodes_int_3() {
let buf = __bytes_builder!(1u8, 1u8, 1u8);
let mut decoder = Decoder::new(&buf);
let int: i32 = decoder.decode_int_3();
let int: i32 = decoder.decode_int_i24();
assert_eq!(int, 0x010101);
assert_eq!(decoder.index, 3);
@ -349,7 +349,7 @@ mod tests {
fn it_decodes_int_2() {
let buf = __bytes_builder!(1u8, 1u8);
let mut decoder = Decoder::new(&buf);
let int: i16 = decoder.decode_int_2();
let int: i16 = decoder.decode_int_i16();
assert_eq!(int, 0x0101);
assert_eq!(decoder.index, 2);
@ -359,7 +359,7 @@ mod tests {
fn it_decodes_int_1() {
let buf = __bytes_builder!(1u8);
let mut decoder = Decoder::new(&buf);
let int: u8 = decoder.decode_int_1();
let int: u8 = decoder.decode_int_u8();
assert_eq!(int, 1u8);
assert_eq!(decoder.index, 1);

View File

@ -1,5 +1,6 @@
use byteorder::{ByteOrder, LittleEndian};
use bytes::{BufMut, Bytes, BytesMut};
use crate::mariadb::FieldType;
const U24_MAX: usize = 0xFF_FF_FF;
@ -37,7 +38,7 @@ impl Encoder {
let mut length = [0; 3];
if self.buf.len() > U24_MAX {
panic!("Buffer too long");
} else if self.buf.len() <= 4 {
} else if self.buf.len() < 4 {
panic!("Buffer too short. Only contains packet length and sequence number")
}
@ -52,31 +53,80 @@ impl Encoder {
// Encode a u64 as an int<8>
#[inline]
pub fn encode_int_8(&mut self, value: u64) {
pub fn encode_int_u64(&mut self, value: u64) {
self.buf.extend_from_slice(&value.to_le_bytes());
}
// Encode a i64 as an int<8>
#[inline]
pub fn encode_int_i64(&mut self, value: i64) {
self.buf.extend_from_slice(&value.to_le_bytes());
}
#[inline]
pub fn encode_int_8(&mut self, bytes: &Bytes) {
self.buf.extend_from_slice(bytes);
}
// Encode a u32 as an int<4>
#[inline]
pub fn encode_int_4(&mut self, value: u32) {
pub fn encode_int_u32(&mut self, value: u32) {
self.buf.extend_from_slice(&value.to_le_bytes());
}
// Encode a i32 as an int<4>
#[inline]
pub fn encode_int_i32(&mut self, value: i32) {
self.buf.extend_from_slice(&value.to_le_bytes());
}
#[inline]
pub fn encode_int_4(&mut self, bytes: &Bytes) {
self.buf.extend_from_slice(bytes);
}
// Encode a u32 (truncated to u24) as an int<3>
#[inline]
pub fn encode_int_3(&mut self, value: u32) {
pub fn encode_int_u24(&mut self, value: u32) {
self.buf.extend_from_slice(&value.to_le_bytes()[0..3]);
}
// Encode a i32 (truncated to i24) as an int<3>
#[inline]
pub fn encode_int_i24(&mut self, value: i32) {
self.buf.extend_from_slice(&value.to_le_bytes()[0..3]);
}
// Encode a u16 as an int<2>
#[inline]
pub fn encode_int_2(&mut self, value: u16) {
pub fn encode_int_u16(&mut self, value: u16) {
self.buf.extend_from_slice(&value.to_le_bytes());
}
// Encode a i16 as an int<2>
#[inline]
pub fn encode_int_i16(&mut self, value: i16) {
self.buf.extend_from_slice(&value.to_le_bytes());
}
#[inline]
pub fn encode_int_2(&mut self, bytes: &Bytes) {
self.buf.extend_from_slice(bytes);
}
// Encode a u8 as an int<1>
#[inline]
pub fn encode_int_1(&mut self, value: u8) {
pub fn encode_int_u8(&mut self, value: u8) {
self.buf.extend_from_slice(&value.to_le_bytes());
}
#[inline]
pub fn encode_int_1(&mut self, bytes: &Bytes) {
self.buf.extend_from_slice(bytes);
}
// Encode a i8 as an int<1>
#[inline]
pub fn encode_int_i8(&mut self, value: i8) {
self.buf.extend_from_slice(&value.to_le_bytes());
}
@ -87,15 +137,15 @@ impl Encoder {
if let Some(value) = value {
if *value > U24_MAX && *value <= std::u64::MAX as usize {
self.buf.put_u8(0xFE);
self.encode_int_8(*value as u64);
self.encode_int_u64(*value as u64);
} else if *value > std::u16::MAX as usize && *value <= U24_MAX {
self.buf.put_u8(0xFD);
self.encode_int_3(*value as u32);
self.encode_int_u24(*value as u32);
} else if *value > std::u8::MAX as usize && *value <= std::u16::MAX as usize {
self.buf.put_u8(0xFC);
self.encode_int_2(*value as u16);
self.encode_int_u16(*value as u16);
} else if *value <= std::u8::MAX as usize {
match *value {
@ -160,7 +210,7 @@ impl Encoder {
panic!("String inside string lenenc serialization is too long");
}
self.encode_int_3(bytes.len() as u32);
self.encode_int_u24(bytes.len() as u32);
self.buf.extend_from_slice(bytes);
}
@ -179,6 +229,43 @@ impl Encoder {
pub fn encode_byte_eof(&mut self, bytes: &Bytes) {
self.buf.extend_from_slice(bytes);
}
#[inline]
pub fn encode_binary(&mut self, bytes: &Bytes, ty: &FieldType) {
match ty {
FieldType::MysqlTypeDecimal => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeTiny => self.encode_int_1(bytes),
FieldType::MysqlTypeShort => self.encode_int_2(bytes),
FieldType::MysqlTypeLong => self.encode_int_4(bytes),
FieldType::MysqlTypeFloat => self.encode_int_4(bytes),
FieldType::MysqlTypeDouble => self.encode_int_8(bytes),
FieldType::MysqlTypeNull => panic!("Type cannot be FieldType::MysqlTypeNull"),
FieldType::MysqlTypeTimestamp => unimplemented!(),
FieldType::MysqlTypeLonglong => self.encode_int_8(bytes),
FieldType::MysqlTypeInt24 => self.encode_int_4(bytes),
FieldType::MysqlTypeDate => unimplemented!(),
FieldType::MysqlTypeTime => unimplemented!(),
FieldType::MysqlTypeDatetime => unimplemented!(),
FieldType::MysqlTypeYear => self.encode_int_4(bytes),
FieldType::MysqlTypeNewdate => unimplemented!(),
FieldType::MysqlTypeVarchar => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeBit => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeTimestamp2 => unimplemented!(),
FieldType::MysqlTypeDatetime2 => unimplemented!(),
FieldType::MysqlTypeTime2 =>unimplemented!(),
FieldType::MysqlTypeJson => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeNewdecimal => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeEnum => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeSet => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeTinyBlob => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeMediumBlob => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeLongBlob => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeBlob => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeVarString => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeString => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeGeometry => self.encode_string_lenenc(bytes),
}
}
}
impl From<BytesMut> for Encoder {
@ -292,7 +379,7 @@ mod tests {
#[test]
fn it_encodes_int_u64() {
let mut encoder = Encoder::new(128);
encoder.encode_int_8(std::u64::MAX);
encoder.encode_int_u64(std::u64::MAX);
assert_eq!(&encoder.buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
@ -300,7 +387,7 @@ mod tests {
#[test]
fn it_encodes_int_u32() {
let mut encoder = Encoder::new(128);
encoder.encode_int_4(std::u32::MAX);
encoder.encode_int_u32(std::u32::MAX);
assert_eq!(&encoder.buf[..], b"\xFF\xFF\xFF\xFF");
}
@ -308,7 +395,7 @@ mod tests {
#[test]
fn it_encodes_int_u24() {
let mut encoder = Encoder::new(128);
encoder.encode_int_3(U24_MAX as u32);
encoder.encode_int_u24(U24_MAX as u32);
assert_eq!(&encoder.buf[..], b"\xFF\xFF\xFF");
}
@ -316,7 +403,7 @@ mod tests {
#[test]
fn it_encodes_int_u16() {
let mut encoder = Encoder::new(128);
encoder.encode_int_2(std::u16::MAX);
encoder.encode_int_u16(std::u16::MAX);
assert_eq!(&encoder.buf[..], b"\xFF\xFF");
}
@ -324,7 +411,7 @@ mod tests {
#[test]
fn it_encodes_int_u8() {
let mut encoder = Encoder::new(128);
encoder.encode_int_1(std::u8::MAX);
encoder.encode_int_u8(std::u8::MAX);
assert_eq!(&encoder.buf[..], b"\xFF");
}

View File

@ -34,6 +34,11 @@ pub use packets::PacketHeader;
pub use packets::ResultSet;
pub use packets::ResultRow;
pub use packets::SSLRequestPacket;
pub use packets::ComStmtPrepare;
pub use packets::ComStmtPrepareOk;
pub use packets::ComStmtPrepareResp;
pub use packets::ComStmtClose;
pub use packets::ComStmtExec;
pub use decode::Decoder;
@ -53,5 +58,7 @@ pub use types::ServerStatusFlag;
pub use types::FieldType;
pub use types::FieldDetailFlag;
pub use types::SessionChangeType;
pub use types::StmtExecFlag;
pub use client::TextProtocol;
pub use client::BinaryProtocol;

View File

@ -11,7 +11,7 @@ pub struct AuthenticationSwitchRequestPacket {
impl Serialize for AuthenticationSwitchRequestPacket {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(0xFE);
encoder.encode_int_u8(0xFE);
encoder.encode_string_null(&self.auth_plugin_name);
encoder.encode_byte_eof(&self.auth_plugin_data);

View File

@ -17,7 +17,7 @@ impl Deserialize for ColumnPacket {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let seq_no = decoder.decode_int_u8();
let columns = decoder.decode_int_lenenc();

View File

@ -25,7 +25,7 @@ impl Deserialize for ColumnDefPacket {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let seq_no = decoder.decode_int_u8();
// string<lenenc> catalog (always 'def')
let catalog = decoder.decode_string_lenenc();
@ -42,15 +42,15 @@ impl Deserialize for ColumnDefPacket {
// int<lenenc> length of fixed fields (=0xC)
let length_of_fixed_fields = decoder.decode_int_lenenc();
// int<2> character set number
let char_set = decoder.decode_int_2();
let char_set = decoder.decode_int_i16();
// int<4> max. column size
let max_columns = decoder.decode_int_4();
let max_columns = decoder.decode_int_i32();
// int<1> Field types
let field_type = FieldType::try_from(decoder.decode_int_1())?;
let field_type = FieldType::try_from(decoder.decode_int_u8())?;
// int<2> Field detail flag
let field_details = FieldDetailFlag::from_bits_truncate(decoder.decode_int_2_unsigned());
let field_details = FieldDetailFlag::from_bits_truncate(decoder.decode_int_u16());
// int<1> decimals
let decimals = decoder.decode_int_1();
let decimals = decoder.decode_int_u8();
// int<2> - unused -
decoder.skip_bytes(2);

View File

@ -5,7 +5,12 @@ pub struct ComDebug();
impl Serialize for ComDebug {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(TextProtocol::ComDebug.into());
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(TextProtocol::ComDebug.into());
encoder.encode_length();
Ok(())
}

View File

@ -8,9 +8,14 @@ pub struct ComInitDb {
impl Serialize for ComInitDb {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(TextProtocol::ComInitDb.into());
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(TextProtocol::ComInitDb.into());
encoder.encode_string_null(&self.schema_name);
encoder.encode_length();
Ok(())
}
}

View File

@ -5,7 +5,12 @@ pub struct ComPing();
impl Serialize for ComPing {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(TextProtocol::ComPing.into());
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(TextProtocol::ComPing.into());
encoder.encode_length();
Ok(())
}

View File

@ -7,8 +7,13 @@ pub struct ComProcessKill {
impl Serialize for ComProcessKill {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(TextProtocol::ComProcessKill.into());
encoder.encode_int_4(self.process_id);
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(TextProtocol::ComProcessKill.into());
encoder.encode_int_u32(self.process_id);
encoder.encode_length();
Ok(())
}

View File

@ -8,9 +8,14 @@ pub struct ComQuery {
impl Serialize for ComQuery {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(TextProtocol::ComQuery.into());
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(TextProtocol::ComQuery.into());
encoder.encode_string_eof(&self.sql_statement);
encoder.encode_length();
Ok(())
}
}

View File

@ -5,7 +5,12 @@ pub struct ComQuit();
impl Serialize for ComQuit {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(TextProtocol::ComQuit.into());
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(TextProtocol::ComQuit.into());
encoder.encode_length();
Ok(())
}

View File

@ -5,7 +5,12 @@ pub struct ComResetConnection();
impl Serialize for ComResetConnection {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(TextProtocol::ComResetConnection.into());
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(TextProtocol::ComResetConnection.into());
encoder.encode_length();
Ok(())
}

View File

@ -13,8 +13,13 @@ pub struct ComSetOption {
impl Serialize for ComSetOption {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(TextProtocol::ComSetOption.into());
encoder.encode_int_2(self.option.into());
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(TextProtocol::ComSetOption.into());
encoder.encode_int_u16(self.option.into());
encoder.encode_length();
Ok(())
}

View File

@ -12,8 +12,13 @@ pub struct ComShutdown {
impl Serialize for ComShutdown {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(TextProtocol::ComShutdown.into());
encoder.encode_int_1(self.option.into());
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(TextProtocol::ComShutdown.into());
encoder.encode_int_u8(self.option.into());
encoder.encode_length();
Ok(())
}

View File

@ -5,7 +5,12 @@ pub struct ComSleep();
impl Serialize for ComSleep {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(TextProtocol::ComSleep.into());
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(TextProtocol::ComSleep.into());
encoder.encode_length();
Ok(())
}

View File

@ -5,7 +5,12 @@ pub struct ComStatistics();
impl Serialize for ComStatistics {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_1(TextProtocol::ComStatistics.into());
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(TextProtocol::ComStatistics.into());
encoder.encode_length();
Ok(())
}

View File

@ -0,0 +1,20 @@
use std::convert::TryInto;
#[derive(Debug)]
pub struct ComStmtClose {
stmt_id: i32
}
impl crate::mariadb::Serialize for ComStmtClose {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), failure::Error> {
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(crate::mariadb::BinaryProtocol::ComStmtClose.into());
encoder.encode_int_i32(self.stmt_id);
encoder.encode_length();
Ok(())
}
}

View File

@ -0,0 +1,68 @@
use crate::mariadb::{StmtExecFlag, ColumnDefPacket, FieldDetailFlag};
use bytes::Bytes;
#[derive(Debug)]
pub struct ComStmtExec {
pub stmt_id: i32,
pub flags: StmtExecFlag,
pub params: Option<Vec<Option<Bytes>>>,
pub param_defs: Option<Vec<ColumnDefPacket>>,
}
impl crate::mariadb::Serialize for ComStmtExec {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), failure::Error> {
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(crate::mariadb::BinaryProtocol::ComStmtExec.into());
encoder.encode_int_i32(self.stmt_id);
encoder.encode_int_u8(self.flags as u8);
encoder.encode_int_u8(0);
if let Some(params) = &self.params {
let null_bitmap_size = (params.len() + 7) / 8;
let mut shift_amount = 0u8;
let mut bitmap = vec![0u8];
// Generate NULL-bitmap from params
for param in params {
if param.is_none() {
bitmap.push(bitmap.last().unwrap() & (1 << shift_amount));
}
shift_amount = (shift_amount + 1) % 8;
if shift_amount % 8 == 0 {
bitmap.push(0u8);
}
}
// Do not send the param types
encoder.encode_int_u8(if self.param_defs.is_some() {
1u8
} else {
0u8
});
if let Some(params) = &self.param_defs {
for param in params {
encoder.encode_int_u8(param.field_type as u8);
encoder.encode_int_u8(if (param.field_details & FieldDetailFlag::UNSIGNED).is_empty() {
1u8
} else {
0u8
});
}
}
// Encode params
for param in params {
}
}
encoder.encode_length();
Ok(())
}
}

View File

@ -0,0 +1,20 @@
use bytes::Bytes;
#[derive(Debug)]
pub struct ComStmtPrepare {
statement: Bytes
}
impl crate::mariadb::Serialize for ComStmtPrepare {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), failure::Error> {
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u8(crate::mariadb::BinaryProtocol::ComStmtPrepare.into());
encoder.encode_string_eof(&self.statement);
encoder.encode_length();
Ok(())
}
}

View File

@ -0,0 +1,34 @@
use std::convert::TryFrom;
#[derive(Debug)]
pub struct ComStmtPrepareOk {
pub stmt_id: i32,
pub columns: i16,
pub params: i16,
pub warnings: i16,
}
impl crate::mariadb::Deserialize for ComStmtPrepareOk {
fn deserialize(ctx: &mut crate::mariadb::DeContext) -> Result<Self, failure::Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_u8();
let stmt_id = decoder.decode_int_i32();
let columns = decoder.decode_int_i16();
let params = decoder.decode_int_i16();
// Skip 1 unused byte;
decoder.skip_bytes(1);
let warnings = decoder.decode_int_i16();
Ok(ComStmtPrepareOk {
stmt_id,
columns,
params,
warnings
})
}
}

View File

@ -0,0 +1,60 @@
use crate::mariadb::{ComStmtPrepareOk, ColumnDefPacket, Capabilities, EofPacket};
#[derive(Debug)]
pub struct ComStmtPrepareResp {
pub ok: ComStmtPrepareOk,
pub param_defs: Option<Vec<ColumnDefPacket>>,
pub res_columns: Option<Vec<ColumnDefPacket>>,
}
//int<1> 0x00 COM_STMT_PREPARE_OK header
//int<4> statement id
//int<2> number of columns in the returned result set (or 0 if statement does not return result set)
//int<2> number of prepared statement parameters ('?' placeholders)
//string<1> -not used-
//int<2> number of warnings
impl crate::mariadb::Deserialize for ComStmtPrepareResp {
fn deserialize(ctx: &mut crate::mariadb::DeContext) -> Result<Self, failure::Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let ok = ComStmtPrepareOk::deserialize(ctx)?;
let param_defs = if ok.params > 0 {
let param_defs = (0..ok.params).map(|_| ColumnDefPacket::deserialize(ctx))
.filter(Result::is_ok)
.map(Result::unwrap)
.collect::<Vec<ColumnDefPacket>>();
if (ctx.conn.capabilities & Capabilities::CLIENT_DEPRECATE_EOF).is_empty() {
EofPacket::deserialize(ctx)?;
}
Some(param_defs)
} else {
None
};
let res_columns = if ok.columns > 0 {
let param_defs = (0..ok.columns).map(|_| ColumnDefPacket::deserialize(ctx))
.filter(Result::is_ok)
.map(Result::unwrap)
.collect::<Vec<ColumnDefPacket>>();
if (ctx.conn.capabilities & Capabilities::CLIENT_DEPRECATE_EOF).is_empty() {
EofPacket::deserialize(ctx)?;
}
Some(param_defs)
} else {
None
};
Ok(ComStmtPrepareResp {
ok,
param_defs,
res_columns
})
}
}

View File

@ -18,16 +18,16 @@ impl Deserialize for EofPacket {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let seq_no = decoder.decode_int_u8();
let packet_header = decoder.decode_int_1();
let packet_header = decoder.decode_int_u8();
if packet_header != 0xFE {
panic!("Packet header is not 0xFE for ErrPacket");
}
let warning_count = decoder.decode_int_2();
let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2_unsigned());
let warning_count = decoder.decode_int_i16();
let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_u16());
Ok(EofPacket { length, seq_no, warning_count, status })
}

View File

@ -24,14 +24,14 @@ impl Deserialize for ErrPacket {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let seq_no = decoder.decode_int_u8();
let packet_header = decoder.decode_int_1();
let packet_header = decoder.decode_int_u8();
if packet_header != 0xFF {
panic!("Packet header is not 0xFF for ErrPacket");
}
let error_code = ErrorCode::try_from(decoder.decode_int_2())?;
let error_code = ErrorCode::try_from(decoder.decode_int_i16())?;
let mut stage = None;
let mut max_stage = None;
@ -44,9 +44,9 @@ impl Deserialize for ErrPacket {
// Progress Reporting
if error_code as u16 == 0xFFFF {
stage = Some(decoder.decode_int_1());
max_stage = Some(decoder.decode_int_1());
progress = Some(decoder.decode_int_3());
stage = Some(decoder.decode_int_u8());
max_stage = Some(decoder.decode_int_u8());
progress = Some(decoder.decode_int_i24());
progress_info = Some(decoder.decode_string_lenenc());
} else {
if decoder.buf[decoder.index] == b'#' {

View File

@ -20,9 +20,12 @@ pub struct HandshakeResponsePacket {
impl Serialize for HandshakeResponsePacket {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_4(self.capabilities.bits() as u32);
encoder.encode_int_4(self.max_packet_size);
encoder.encode_int_1(self.collation);
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u32(self.capabilities.bits() as u32);
encoder.encode_int_u32(self.max_packet_size);
encoder.encode_int_u8(self.collation);
// Filler
encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 19]), 19);
@ -31,7 +34,7 @@ impl Serialize for HandshakeResponsePacket {
&& !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
{
if let Some(capabilities) = self.extended_capabilities {
encoder.encode_int_4(capabilities.bits() as u32);
encoder.encode_int_u32(capabilities.bits() as u32);
}
} else {
encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 4]), 4);
@ -45,12 +48,12 @@ impl Serialize for HandshakeResponsePacket {
}
} else if !(ctx.capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
if let Some(auth_response) = &self.auth_response {
encoder.encode_int_1(self.auth_response_len.unwrap());
encoder.encode_int_u8(self.auth_response_len.unwrap());
encoder
.encode_string_fix(&auth_response, self.auth_response_len.unwrap() as usize);
}
} else {
encoder.encode_int_1(0);
encoder.encode_int_u8(0);
}
if !(ctx.capabilities & Capabilities::CONNECT_WITH_DB).is_empty() {
@ -80,6 +83,8 @@ impl Serialize for HandshakeResponsePacket {
}
}
encoder.encode_length();
Ok(())
}
}

View File

@ -23,31 +23,31 @@ impl Deserialize for InitialHandshakePacket {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let seq_no = decoder.decode_int_u8();
if seq_no != 0 {
return Err(err_msg("Sequence Number of Initial Handshake Packet is not 0"));
}
let protocol_version = decoder.decode_int_1();
let protocol_version = decoder.decode_int_u8();
let server_version = decoder.decode_string_null()?;
let connection_id = decoder.decode_int_4();
let connection_id = decoder.decode_int_i32();
let auth_seed = decoder.decode_string_fix(8);
// Skip reserved byte
decoder.skip_bytes(1);
let mut capabilities = Capabilities::from_bits_truncate(decoder.decode_int_2_unsigned().into());
let mut capabilities = Capabilities::from_bits_truncate(decoder.decode_int_u16().into());
let collation = decoder.decode_int_1();
let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2_unsigned().into());
let collation = decoder.decode_int_u8();
let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_u16().into());
capabilities |=
Capabilities::from_bits_truncate(((decoder.decode_int_2() as u32) << 16).into());
Capabilities::from_bits_truncate(((decoder.decode_int_i16() as u32) << 16).into());
let mut plugin_data_length = 0;
if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
plugin_data_length = decoder.decode_int_1();
plugin_data_length = decoder.decode_int_u8();
} else {
// Skip reserve byte
decoder.skip_bytes(1);
@ -58,7 +58,7 @@ impl Deserialize for InitialHandshakePacket {
if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
capabilities |=
Capabilities::from_bits_truncate(((decoder.decode_int_4_unsigned() as u128) << 32).into());
Capabilities::from_bits_truncate(((decoder.decode_int_u32() as u128) << 32).into());
} else {
// Skip filler
decoder.skip_bytes(4);

View File

@ -21,6 +21,11 @@ pub mod packet_header;
pub mod result_set;
pub mod ssl_request;
pub mod result_row;
pub mod com_stmt_prepare;
pub mod com_stmt_prepare_ok;
pub mod com_stmt_prepare_resp;
pub mod com_stmt_close;
pub mod com_stmt_exec;
pub use auth_switch_request::AuthenticationSwitchRequestPacket;
pub use column::ColumnPacket;
@ -47,3 +52,8 @@ pub use packet_header::PacketHeader;
pub use result_set::ResultSet;
pub use result_row::ResultRow;
pub use ssl_request::SSLRequestPacket;
pub use com_stmt_prepare::ComStmtPrepare;
pub use com_stmt_prepare_ok::ComStmtPrepareOk;
pub use com_stmt_prepare_resp::ComStmtPrepareResp;
pub use com_stmt_close::ComStmtClose;
pub use com_stmt_exec::ComStmtExec;

View File

@ -23,21 +23,21 @@ impl Deserialize for OkPacket {
// Packet header
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let seq_no = decoder.decode_int_u8();
// Used later for the byte_eof decoding
let index = decoder.index;
// Packet body
let packet_header = decoder.decode_int_1();
let packet_header = decoder.decode_int_u8();
if packet_header != 0 && packet_header != 0xFE {
return Err(err_msg("Packet header is not 0 or 0xFE for OkPacket"));
}
let affected_rows = decoder.decode_int_lenenc();
let last_insert_id = decoder.decode_int_lenenc();
let server_status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2_unsigned().into());
let warning_count = decoder.decode_int_2();
let server_status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_u16().into());
let warning_count = decoder.decode_int_i16();
// Assuming CLIENT_SESSION_TRACK is unsupported
let session_state_info = None;

View File

@ -17,7 +17,7 @@ impl Deserialize for ResultRow {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let seq_no = decoder.decode_int_u8();
let row = if let Some(columns) = ctx.columns {
(0..columns).map(|_| decoder.decode_string_lenenc()).collect::<Vec<Bytes>>()

View File

@ -13,9 +13,12 @@ pub struct SSLRequestPacket {
impl Serialize for SSLRequestPacket {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.encode_int_4(self.capabilities.bits() as u32);
encoder.encode_int_4(self.max_packet_size);
encoder.encode_int_1(self.collation);
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.encode_int_u32(self.capabilities.bits() as u32);
encoder.encode_int_u32(self.max_packet_size);
encoder.encode_int_u8(self.collation);
// Filler
encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 19]), 19);
@ -24,12 +27,14 @@ impl Serialize for SSLRequestPacket {
&& !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
{
if let Some(capabilities) = self.extended_capabilities {
encoder.encode_int_4(capabilities.bits() as u32);
encoder.encode_int_u32(capabilities.bits() as u32);
}
} else {
encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 4]), 4);
}
encoder.encode_length();
Ok(())
}
}

View File

@ -111,6 +111,21 @@ pub enum FieldType {
MysqlTypeGeometry = 255,
}
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive)]
#[TryFromPrimitiveType = "u8"]
pub enum StmtExecFlag {
NoCursor = 0,
ReadOnly = 1,
CursorForUpdate = 2,
ScrollableCursor = 3,
}
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive)]
#[TryFromPrimitiveType = "u8"]
pub enum ParamFlag {
Unsigned = 128,
}
impl Default for Capabilities {
fn default() -> Self {
Capabilities::CLIENT_MYSQL
@ -135,6 +150,18 @@ impl Default for FieldType {
}
}
impl Default for StmtExecFlag {
fn default() -> Self {
StmtExecFlag::NoCursor
}
}
impl Default for ParamFlag {
fn default() -> Self {
ParamFlag::Unsigned
}
}
#[cfg(test)]
mod test {
use super::super::{decode::Decoder, types::Capabilities};
@ -144,6 +171,6 @@ mod test {
fn it_decodes_capabilities() {
let buf = Bytes::from(b"\xfe\xf7".to_vec());
let mut decoder = Decoder::new(&buf);
Capabilities::from_bits_truncate(decoder.decode_int_2_unsigned().into());
Capabilities::from_bits_truncate(decoder.decode_int_u16().into());
}
}