Fix next_packet

This commit is contained in:
Daniel Akhterov 2019-08-05 22:45:25 -07:00
parent 6696e54c33
commit 4c1da595cb
3 changed files with 68 additions and 57 deletions

View File

@ -101,8 +101,10 @@ mod test {
password: None,
}).await?;
println!("selecting db");
conn.select_db("test").await?;
println!("querying");
conn.query("SELECT * FROM users").await?;
Ok(())

View File

@ -104,8 +104,11 @@ impl Connection {
pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<Option<ResultSet>, Error> {
self.send(ComQuery { sql_statement: bytes::Bytes::from(sql_statement) }).await?;
println!("awaiting next packet");
let mut ctx = DeContext::with_stream(&mut self.context, &mut self.stream);
ctx.next_packet().await?;
println!("Got next packet");
match ctx.decoder.peek_tag() {
0xFF => Err(ErrPacket::deserialize(&mut ctx)?.into()),
0x00 => {
@ -162,6 +165,7 @@ impl Connection {
pub struct Framed {
inner: TcpStream,
buf: BytesMut,
index: usize,
}
impl Framed {
@ -169,65 +173,14 @@ impl Framed {
Self {
inner: stream,
buf: BytesMut::with_capacity(8 * 1024),
index: 0,
}
}
pub async fn next_packet(&mut self) -> Result<Bytes, Error> {
let mut rbuf = BytesMut::new();
let mut len = 0usize;
let mut packet_headers: Vec<PacketHeader> = Vec::new();
loop {
if let Some(packet_header) = packet_headers.last() {
if packet_header.combined_length() > rbuf.len() {
let reserve = packet_header.combined_length() - rbuf.len();
rbuf.reserve(reserve);
unsafe {
rbuf.set_len(rbuf.capacity());
self.inner.initializer().initialize(&mut rbuf[len..]);
}
}
} else if rbuf.len() == len {
rbuf.reserve(32);
unsafe {
rbuf.set_len(rbuf.capacity());
self.inner.initializer().initialize(&mut rbuf[len..]);
}
}
// If we have a packet_header and the amount of currently read bytes (len) is less than
// the specified length inside packet_header, then we can continue reading to rbuf; but
// only up until packet_header.length.
// Else if the total number of bytes read is equal to packet_header then we will
// return rbuf as it should contain the entire packet.
// Else we read too many bytes -- which shouldn't happen -- and will return an error.
let bytes_read;
if let Some(packet_header) = packet_headers.last() {
if packet_header.combined_length() > len {
bytes_read = self.inner.read(&mut rbuf[len..packet_header.combined_length()]).await?;
} else {
return Ok(rbuf.freeze());
}
} else {
// Only read header to make sure that we dont' read the next packets buffer.
bytes_read = self.inner.read(&mut rbuf[len..len + 4]).await?;
}
if bytes_read > 0 {
len += bytes_read;
// If we have read less than 4 bytes, and we don't already have a packet_header
// we must try to read again. The packet_header is always present and is 4 bytes long.
if bytes_read < 4 && packet_headers.len() == 0 {
continue;
}
} else {
// Read 0 bytes from the server; end-of-stream
return Ok(rbuf.freeze());
}
// If we don't have a packet header or the last packet header had a length of
// 0xFF_FF_FF (the max possible length); then we must continue receiving packets
// because the entire message hasn't been received.
@ -236,10 +189,62 @@ impl Framed {
// TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions.
if let Some(packet_header) = packet_headers.last() {
if packet_header.length as usize == encode::U24_MAX {
packet_headers.push(PacketHeader::try_from(&rbuf[0..])?);
packet_headers.push(PacketHeader::try_from(&self.buf[self.index..])?);
}
} else if self.buf.len() > 4 {
match PacketHeader::try_from(&self.buf[0..]) {
Ok(v) => packet_headers.push(v),
Err(_) => {},
}
}
if let Some(packet_header) = packet_headers.last() {
if packet_header.combined_length() > self.buf.len() {
self.buf.reserve(packet_header.combined_length() - self.buf.len());
unsafe {
self.buf.set_len(self.buf.capacity());
self.inner.initializer().initialize(&mut self.buf[self.index..]);
}
}
} else if self.buf.len() == self.index {
self.buf.reserve(32);
unsafe {
self.buf.set_len(self.buf.capacity());
self.inner.initializer().initialize(&mut self.buf[self.index..]);
}
}
// If we have a packet_header and the amount of currently read bytes (len) is less than
// the specified length inside packet_header, then we can continue reading to self.buf.
// Else if the total number of bytes read is equal to packet_header then we will
// return self.buf from 0 to self.index as it should contain the entire packet.
let bytes_read;
if let Some(packet_header) = packet_headers.last() {
if packet_header.combined_length() > self.index {
bytes_read = self.inner.read(&mut self.buf[self.index..]).await?;
} else {
// Get the packet from the buffer, reset index, and return packet
let packet = self.buf.split_to(packet_header.combined_length()).freeze();
self.index -= packet.len();
return Ok(packet);
}
} else {
packet_headers.push(PacketHeader::try_from(&rbuf[0..])?);
bytes_read = self.inner.read(&mut self.buf[self.index..]).await?;
}
if bytes_read > 0 {
self.index += bytes_read;
// If we have read less than 4 bytes, and we don't already have a packet_header
// we must try to read again. The packet_header is always present and is 4 bytes long.
if bytes_read < 4 && packet_headers.len() == 0 {
continue;
}
} else {
// Read 0 bytes from the server; end-of-stream
panic!("Cannot read 0 bytes from stream");
}
}
}

View File

@ -1,7 +1,7 @@
use byteorder::LittleEndian;
use byteorder::ByteOrder;
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone, Copy)]
pub struct PacketHeader {
pub length: u32,
pub seq_no: u8,
@ -24,10 +24,14 @@ impl core::convert::TryFrom<&[u8]> for PacketHeader {
if buffer.len() < 4 {
failure::bail!("Buffer length is too short")
} else {
Ok(PacketHeader {
let packet = PacketHeader {
length: LittleEndian::read_u24(&buffer),
seq_no: buffer[3],
})
};
if packet.length == 0 && packet.seq_no == 0{
failure::bail!("Length and seq_no cannot be zero");
}
Ok(packet)
}
}
}