fix(mssql): handle multi-chunk packets

fixes #523
This commit is contained in:
Ryan Leckey 2020-07-17 06:08:36 -07:00
parent 6fdb0d534f
commit f345c23e51
3 changed files with 88 additions and 26 deletions

View File

@ -69,28 +69,18 @@ where
where
T: Decode<'de, C>,
{
// zero-fills the space in the read buffer
self.rbuf.resize(cnt, 0);
T::decode_with(self.read_raw(cnt).await?.freeze(), context)
}
let mut read = 0;
while cnt > read {
// read in bytes from the stream into the read buffer starting
// from the offset we last read from
let n = self.stream.read(&mut self.rbuf[read..]).await?;
pub async fn read_raw(&mut self, cnt: usize) -> Result<BytesMut, Error> {
read_raw_into(&mut self.stream, &mut self.rbuf, cnt).await?;
let buf = self.rbuf.split_to(cnt);
if n == 0 {
// a zero read when we had space in the read buffer
// should be treated as an EOF
Ok(buf)
}
// and an unexpected EOF means the server told us to go away
return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into());
}
read += n;
}
T::decode_with(self.rbuf.split_to(cnt).freeze(), context)
pub async fn read_raw_into(&mut self, buf: &mut BytesMut, cnt: usize) -> Result<(), Error> {
read_raw_into(&mut self.stream, buf, cnt).await
}
}
@ -113,3 +103,34 @@ where
&mut self.stream
}
}
async fn read_raw_into<S: AsyncRead + Unpin>(
stream: &mut S,
buf: &mut BytesMut,
cnt: usize,
) -> Result<(), Error> {
let offset = buf.len();
// zero-fills the space in the read buffer
buf.resize(offset + cnt, 0);
let mut read = offset;
while (offset + cnt) > read {
// read in bytes from the stream into the read buffer starting
// from the offset we last read from
let n = stream.read(&mut buf[read..]).await?;
if n == 0 {
// a zero read when we had space in the read buffer
// should be treated as an EOF
// and an unexpected EOF means the server told us to go away
return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into());
}
read += n;
}
Ok(())
}

View File

@ -1,6 +1,6 @@
use std::ops::{Deref, DerefMut};
use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use sqlx_rt::TcpStream;
use crate::error::Error;
@ -92,9 +92,7 @@ impl MssqlStream {
// receive the next packet from the database
// blocks until a packet is available
pub(super) async fn recv_packet(&mut self) -> Result<(PacketHeader, Bytes), Error> {
// TODO: Support packet chunking for large packet sizes
let header: PacketHeader = self.inner.read(8).await?;
let mut header: PacketHeader = self.inner.read(8).await?;
// NOTE: From what I can tell, the response type from the server should ~always~
// be TabularResult. Here we expect that and die otherwise.
@ -105,10 +103,21 @@ impl MssqlStream {
));
}
let payload_len = (header.length - 8) as usize;
let payload: Bytes = self.inner.read(payload_len).await?;
let mut payload = BytesMut::new();
Ok((header, payload))
loop {
self.inner
.read_raw_into(&mut payload, (header.length - 8) as usize)
.await?;
if header.status.contains(Status::END_OF_MESSAGE) {
break;
}
header = self.inner.read(8).await?;
}
Ok((header, payload.freeze()))
}
// receive the next ~message~

View File

@ -112,6 +112,38 @@ CREATE TABLE #users (id INTEGER PRIMARY KEY);
Ok(())
}
#[sqlx_macros::test]
async fn it_can_return_1000_rows() -> anyhow::Result<()> {
let mut conn = new::<Mssql>().await?;
let _ = conn
.execute(
r#"
CREATE TABLE #users (id INTEGER PRIMARY KEY);
"#,
)
.await?;
for index in 1..=1000_i32 {
let done = sqlx::query("INSERT INTO #users (id) VALUES (@p1)")
.bind(index * 2)
.execute(&mut conn)
.await?;
assert_eq!(done.rows_affected(), 1);
}
let sum: i32 = sqlx::query("SELECT id FROM #users")
.try_map(|row: MssqlRow| row.try_get::<i32, _>(0))
.fetch(&mut conn)
.try_fold(0_i32, |acc, x| async move { Ok(acc + x) })
.await?;
assert_eq!(sum, 1001000);
Ok(())
}
#[sqlx_macros::test]
async fn it_selects_null() -> anyhow::Result<()> {
let mut conn = new::<Mssql>().await?;