diff --git a/src/mariadb/connection/establish.rs b/src/mariadb/connection/establish.rs index 8929bdda..bd67a342 100644 --- a/src/mariadb/connection/establish.rs +++ b/src/mariadb/connection/establish.rs @@ -101,8 +101,10 @@ mod test { password: None, }).await?; + println!("selecting db"); conn.select_db("test").await?; + println!("querying"); conn.query("SELECT * FROM users").await?; Ok(()) diff --git a/src/mariadb/connection/mod.rs b/src/mariadb/connection/mod.rs index 6d658a21..271bbbfa 100644 --- a/src/mariadb/connection/mod.rs +++ b/src/mariadb/connection/mod.rs @@ -104,8 +104,11 @@ impl Connection { pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result, Error> { self.send(ComQuery { sql_statement: bytes::Bytes::from(sql_statement) }).await?; + println!("awaiting next packet"); let mut ctx = DeContext::with_stream(&mut self.context, &mut self.stream); ctx.next_packet().await?; + + println!("Got next packet"); match ctx.decoder.peek_tag() { 0xFF => Err(ErrPacket::deserialize(&mut ctx)?.into()), 0x00 => { @@ -162,6 +165,7 @@ impl Connection { pub struct Framed { inner: TcpStream, buf: BytesMut, + index: usize, } impl Framed { @@ -169,65 +173,14 @@ impl Framed { Self { inner: stream, buf: BytesMut::with_capacity(8 * 1024), + index: 0, } } pub async fn next_packet(&mut self) -> Result { - let mut rbuf = BytesMut::new(); - let mut len = 0usize; let mut packet_headers: Vec = Vec::new(); loop { - if let Some(packet_header) = packet_headers.last() { - if packet_header.combined_length() > rbuf.len() { - let reserve = packet_header.combined_length() - rbuf.len(); - rbuf.reserve(reserve); - - unsafe { - rbuf.set_len(rbuf.capacity()); - self.inner.initializer().initialize(&mut rbuf[len..]); - } - } - } else if rbuf.len() == len { - rbuf.reserve(32); - - unsafe { - rbuf.set_len(rbuf.capacity()); - self.inner.initializer().initialize(&mut rbuf[len..]); - } - } - - // If we have a packet_header and the amount of currently read bytes (len) is less than - // the specified length inside packet_header, then we can continue reading to rbuf; but - // only up until packet_header.length. - // Else if the total number of bytes read is equal to packet_header then we will - // return rbuf as it should contain the entire packet. - // Else we read too many bytes -- which shouldn't happen -- and will return an error. - let bytes_read; - - if let Some(packet_header) = packet_headers.last() { - if packet_header.combined_length() > len { - bytes_read = self.inner.read(&mut rbuf[len..packet_header.combined_length()]).await?; - } else { - return Ok(rbuf.freeze()); - } - } else { - // Only read header to make sure that we dont' read the next packets buffer. - bytes_read = self.inner.read(&mut rbuf[len..len + 4]).await?; - } - - if bytes_read > 0 { - len += bytes_read; - // If we have read less than 4 bytes, and we don't already have a packet_header - // we must try to read again. The packet_header is always present and is 4 bytes long. - if bytes_read < 4 && packet_headers.len() == 0 { - continue; - } - } else { - // Read 0 bytes from the server; end-of-stream - return Ok(rbuf.freeze()); - } - // If we don't have a packet header or the last packet header had a length of // 0xFF_FF_FF (the max possible length); then we must continue receiving packets // because the entire message hasn't been received. @@ -236,10 +189,62 @@ impl Framed { // TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions. if let Some(packet_header) = packet_headers.last() { if packet_header.length as usize == encode::U24_MAX { - packet_headers.push(PacketHeader::try_from(&rbuf[0..])?); + packet_headers.push(PacketHeader::try_from(&self.buf[self.index..])?); + } + } else if self.buf.len() > 4 { + match PacketHeader::try_from(&self.buf[0..]) { + Ok(v) => packet_headers.push(v), + Err(_) => {}, + } + } + + if let Some(packet_header) = packet_headers.last() { + if packet_header.combined_length() > self.buf.len() { + self.buf.reserve(packet_header.combined_length() - self.buf.len()); + + unsafe { + self.buf.set_len(self.buf.capacity()); + self.inner.initializer().initialize(&mut self.buf[self.index..]); + } + } + } else if self.buf.len() == self.index { + self.buf.reserve(32); + + unsafe { + self.buf.set_len(self.buf.capacity()); + self.inner.initializer().initialize(&mut self.buf[self.index..]); + } + } + + // If we have a packet_header and the amount of currently read bytes (len) is less than + // the specified length inside packet_header, then we can continue reading to self.buf. + // Else if the total number of bytes read is equal to packet_header then we will + // return self.buf from 0 to self.index as it should contain the entire packet. + let bytes_read; + + if let Some(packet_header) = packet_headers.last() { + if packet_header.combined_length() > self.index { + bytes_read = self.inner.read(&mut self.buf[self.index..]).await?; + } else { + // Get the packet from the buffer, reset index, and return packet + let packet = self.buf.split_to(packet_header.combined_length()).freeze(); + self.index -= packet.len(); + return Ok(packet); } } else { - packet_headers.push(PacketHeader::try_from(&rbuf[0..])?); + bytes_read = self.inner.read(&mut self.buf[self.index..]).await?; + } + + if bytes_read > 0 { + self.index += bytes_read; + // If we have read less than 4 bytes, and we don't already have a packet_header + // we must try to read again. The packet_header is always present and is 4 bytes long. + if bytes_read < 4 && packet_headers.len() == 0 { + continue; + } + } else { + // Read 0 bytes from the server; end-of-stream + panic!("Cannot read 0 bytes from stream"); } } } diff --git a/src/mariadb/protocol/packets/packet_header.rs b/src/mariadb/protocol/packets/packet_header.rs index 8a742810..28e38dd9 100644 --- a/src/mariadb/protocol/packets/packet_header.rs +++ b/src/mariadb/protocol/packets/packet_header.rs @@ -1,7 +1,7 @@ use byteorder::LittleEndian; use byteorder::ByteOrder; -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone, Copy)] pub struct PacketHeader { pub length: u32, pub seq_no: u8, @@ -24,10 +24,14 @@ impl core::convert::TryFrom<&[u8]> for PacketHeader { if buffer.len() < 4 { failure::bail!("Buffer length is too short") } else { - Ok(PacketHeader { + let packet = PacketHeader { length: LittleEndian::read_u24(&buffer), seq_no: buffer[3], - }) + }; + if packet.length == 0 && packet.seq_no == 0{ + failure::bail!("Length and seq_no cannot be zero"); + } + Ok(packet) } } }