Fix result set decoding OkPacket instead of EofPacket at the end of the transaction

This commit is contained in:
Daniel Akhterov
2019-07-26 13:28:12 -07:00
parent 47598bbcd6
commit 05024dbd0a
4 changed files with 121 additions and 46 deletions

View File

@@ -144,8 +144,12 @@ impl<'a> Decoder<'a> {
}
#[inline]
pub fn decode_string_eof(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.buf.len());
pub fn decode_string_eof(&mut self, length: usize) -> Bytes {
let value = self.buf.slice(self.index, if length >= self.index {
length
} else {
self.buf.len()
});
self.index = self.buf.len();
value
}
@@ -177,8 +181,12 @@ impl<'a> Decoder<'a> {
}
#[inline]
pub fn decode_byte_eof(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.buf.len());
pub fn decode_byte_eof(&mut self, length: usize) -> Bytes {
let value = self.buf.slice(self.index, if length >= self.index {
length
} else {
self.buf.len()
});
self.index = self.buf.len();
value
}
@@ -330,7 +338,7 @@ mod tests {
fn it_decodes_string_eof() {
let buf = Bytes::from(b"\x01".to_vec());
let mut decoder = Decoder::new(&buf);
let string: Bytes = decoder.decode_string_eof();
let string: Bytes = decoder.decode_string_eof(0);
assert_eq!(string[0], b'\x01');
assert_eq!(string.len(), 1);
@@ -367,7 +375,7 @@ mod tests {
fn it_decodes_byte_eof() {
let buf = Bytes::from(b"\x01".to_vec());
let mut decoder = Decoder::new(&buf);
let string: Bytes = decoder.decode_byte_eof();
let string: Bytes = decoder.decode_byte_eof(0);
assert_eq!(string[0], b'\x01');
assert_eq!(string.len(), 1);

View File

@@ -54,9 +54,9 @@ impl Deserialize for ErrPacket {
if decoder.buf[decoder.index] == b'#' {
sql_state_marker = Some(decoder.decode_string_fix(1));
sql_state = Some(decoder.decode_string_fix(5));
error_message = Some(decoder.decode_string_eof());
error_message = Some(decoder.decode_string_eof(length as usize));
} else {
error_message = Some(decoder.decode_string_eof());
error_message = Some(decoder.decode_string_eof(length as usize));
}
}

View File

@@ -22,10 +22,14 @@ pub struct OkPacket {
impl Deserialize for OkPacket {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
// Packet header
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
// Used later for the byte_eof decoding
let index = decoder.index;
// Packet body
let packet_header = decoder.decode_int_1();
if packet_header != 0 && packet_header != 0xFE {
@@ -41,7 +45,7 @@ impl Deserialize for OkPacket {
let session_state_info = None;
let value = None;
let info = decoder.decode_byte_eof();
let info = decoder.decode_byte_eof(index + length as usize);
Ok(OkPacket {
length,
@@ -78,9 +82,9 @@ mod test {
#[rustfmt::skip]
let buf = __bytes_builder!(
// length
// int<3> length
0u8, 0u8, 0u8,
// seq_no
// // int<1> seq_no
1u8,
// 0x00 : OK_Packet header or (0xFE if CLIENT_DEPRECATE_EOF is set)
0u8,

View File

@@ -63,9 +63,9 @@ impl Deserialize for ResultSet {
}
if (ctx.conn.capabilities & Capabilities::CLIENT_DEPRECATE_EOF).is_empty() {
EofPacket::deserialize(ctx)?;
} else {
OkPacket::deserialize(ctx)?;
} else {
EofPacket::deserialize(ctx)?;
}
Ok(ResultSet {
@@ -97,156 +97,219 @@ mod test {
.await?;
// TODO: Use byte string as input for test; this is a valid return from a mariadb.
// Reference: b"\x01\0\0\x01\x04(\0\0\x02\x03def\x04test\x05users\x05users\x02id\x02id\x0c\x08\0\x80\0\0\0\xfd\x03@\0\0\04\0\0\x03\x03def\x04test\x05users\x05users\x08username\x08username\x0c\x08\0\xff\xff\0\0\xfc\x11\x10\0\0\04\0\0\x04\x03def\x04test\x05users\x05users\x08password\x08password\x0c\x08\0\xff\xff\0\0\xfc\x11\x10\0\0\0<\0\0\x05\x03def\x04test\x05users\x05users\x0caccess_level\x0caccess_level\x0c\x08\0\x07\0\0\0\xfe\x01\x11\0\0\0\x05\0\0\x06\xfe\0\0\"\0>\0\0\x07$044d3f34-af65-11e9-a2e5-0242ac110003\x04josh\x0bpassword123\x07regular4\0\0\x08$d83dd1c4-ada9-11e9-96bc-0242ac110003\x06daniel\x01f\x05admin\x05\0\0\t\xfe\0\0\"\0\0
#[rustfmt::skip]
let buf = __bytes_builder!(
// ------------------- //
// Column Count packet //
// ------------------- //
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// int<lenenc> tag code or length
4u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// int<3> length
40u8, 0u8, 0u8,
// int<1> seq_no
2u8,
// string<lenenc> catalog (always 'def')
3u8, b"def",
// string<lenenc> schema
4u8, b"test",
// string<lenenc> table alias
5u8, b"users",
// string<lenenc> table
5u8, b"users",
// string<lenenc> column alias
2u8, b"id",
// string<lenenc> column
2u8, b"id",
// int<lenenc> length of fixed fields (=0xC)
0x0C_u8,
// int<2> character set number
8u8, 0u8,
// int<4> max. column size
0x80_u8, 0u8, 0u8, 0u8,
// int<1> Field types
0xFD_u8,
// int<2> Field detail flag
3u8, 64u8,
// int<1> decimals
0u8,
// int<2> - unused -
0u8, 0u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// int<3> length
52u8, 0u8, 0u8,
// int<1> seq_no
3u8,
// string<lenenc> catalog (always 'def')
3u8, b"def",
// string<lenenc> schema
4u8, b"test",
// string<lenenc> table alias
5u8, b"users",
// string<lenenc> table
5u8, b"users",
// string<lenenc> column alias
8u8, b"username",
// string<lenenc> column
8u8, b"username",
// int<lenenc> length of fixed fields (=0xC)
0x0C_u8,
// int<2> character set number
8u8, 0u8,
// int<4> max. column size
0xFF_u8, 0xFF_u8, 0u8, 0u8,
// int<1> Field types
0xFC_u8,
// int<2> Field detail flag
0x11_u8, 0x10_u8,
// int<1> decimals
0u8,
// int<2> - unused -
0u8, 0u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// int<3> length
52u8, 0u8, 0u8,
// int<1> seq_no
4u8,
// string<lenenc> catalog (always 'def')
3u8, b"def",
// string<lenenc> schema
4u8, b"test",
// string<lenenc> table alias
5u8, b"users",
// string<lenenc> table
5u8, b"users",
// string<lenenc> column alias
8u8, b"password",
// string<lenenc> column
8u8, b"password",
// int<lenenc> length of fixed fields (=0xC)
0x0C_u8,
// int<2> character set number
8u8, 0u8,
// int<4> max. column size
0xFF_u8, 0xFF_u8, 0u8, 0u8,
// int<1> Field types
0xFC_u8,
// int<2> Field detail flag
0x11_u8, 0x10_u8,
// int<1> decimals
0u8,
// int<2> - unused -
0u8, 0u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// int<3> length
60u8, 0u8, 0u8,
// int<1> seq_no
5u8,
// string<lenenc> catalog (always 'def')
3u8, b"def",
// string<lenenc> schema
4u8, b"test",
// string<lenenc> table alias
5u8, b"users",
// string<lenenc> table
5u8, b"users",
12u8, b"access_level",
12u8, b"access_level",
12u8,
// string<lenenc> column alias
0x0C_u8, b"access_level",
// string<lenenc> column
0x0C_u8, b"access_level",
// int<lenenc> length of fixed fields (=0xC)
0x0C_u8,
// int<2> character set number
8u8, 0u8,
// int<4> max. column size
7u8, 0u8, 0u8, 0u8,
// int<1> Field types
0xFE_u8,
// int<2> Field detail flag
1u8, 0x11_u8,
// int<1> decimals
0u8,
// int<2> - unused -
0u8, 0u8,
// ---------- //
// EOF Packet //
// ---------- //
// int<3> length
5u8, 0u8, 0u8,
// int<1> seq_no
6u8,
// int<1> 0xfe : EOF header
0xFE_u8,
// int<2> warning count
0u8, 0u8,
// int<2> server status
34u8, 0u8,
// ----------------- //
// Result Row Packet //
// ----------------- //
// int<3> length
62u8, 0u8, 0u8,
// int<1> seq_no
7u8,
32u8, b"044d3f34-af65-11e9-a2e5-0242ac110003",
// string<lenenc> column data
36u8, b"044d3f34-af65-11e9-a2e5-0242ac110003",
// string<lenenc> column data
4u8, b"josh",
11u8, b"password123",
// string<lenenc> column data
0x0B_u8, b"password123",
// string<lenenc> column data
7u8, b"regular",
// ----------------- //
// Result Row Packet //
// ----------------- //
// int<3> length
52u8, 0u8, 0u8,
// int<1> seq_no
8u8,
32u8, b"d83dd1c4-ada9-11e9-96bc-0242ac110003",
// string<lenenc> column data
36u8, b"d83dd1c4-ada9-11e9-96bc-0242ac110003",
// string<lenenc> column data
6u8, b"daniel",
// string<lenenc> column data
1u8, b"f",
// string<lenenc> column data
5u8, b"admin",
// ------------- //
// OK/EOF Packet //
// ------------- //
// int<3> length
5u8, 0u8, 0u8,
9u8,
// int<1> seq_no
1u8,
// 0xFE: Required header for last packet of result set
0xFE_u8,
// int<2> warning count
0u8, 0u8,
// int<2> server status
34u8, 0u8
);
conn.select_db("test").await?;
conn.query("SELECT * FROM users").await?;
let buf = conn.stream.next_bytes().await?;
println!("{:?}", buf);
let mut ctx = DeContext::new(&mut conn.context, &buf);
ResultSet::deserialize(&mut ctx)?;
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// ---------- //
// EOF Packet //
// ---------- //
// ------------------- //
// N Result Row Packet //
// ------------------- //
// ---------- //
// EOF Packet //
// ---------- //
Ok(())
}
}