diff --git a/sqlx-core/src/io/buf_stream.rs b/sqlx-core/src/io/buf_stream.rs index f4af7475..4646cb7a 100644 --- a/sqlx-core/src/io/buf_stream.rs +++ b/sqlx-core/src/io/buf_stream.rs @@ -69,28 +69,18 @@ where where T: Decode<'de, C>, { - // zero-fills the space in the read buffer - self.rbuf.resize(cnt, 0); + T::decode_with(self.read_raw(cnt).await?.freeze(), context) + } - let mut read = 0; - while cnt > read { - // read in bytes from the stream into the read buffer starting - // from the offset we last read from - let n = self.stream.read(&mut self.rbuf[read..]).await?; + pub async fn read_raw(&mut self, cnt: usize) -> Result { + read_raw_into(&mut self.stream, &mut self.rbuf, cnt).await?; + let buf = self.rbuf.split_to(cnt); - if n == 0 { - // a zero read when we had space in the read buffer - // should be treated as an EOF + Ok(buf) + } - // and an unexpected EOF means the server told us to go away - - return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into()); - } - - read += n; - } - - T::decode_with(self.rbuf.split_to(cnt).freeze(), context) + pub async fn read_raw_into(&mut self, buf: &mut BytesMut, cnt: usize) -> Result<(), Error> { + read_raw_into(&mut self.stream, buf, cnt).await } } @@ -113,3 +103,34 @@ where &mut self.stream } } + +async fn read_raw_into( + stream: &mut S, + buf: &mut BytesMut, + cnt: usize, +) -> Result<(), Error> { + let offset = buf.len(); + + // zero-fills the space in the read buffer + buf.resize(offset + cnt, 0); + + let mut read = offset; + while (offset + cnt) > read { + // read in bytes from the stream into the read buffer starting + // from the offset we last read from + let n = stream.read(&mut buf[read..]).await?; + + if n == 0 { + // a zero read when we had space in the read buffer + // should be treated as an EOF + + // and an unexpected EOF means the server told us to go away + + return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into()); + } + + read += n; + } + + Ok(()) +} diff --git a/sqlx-core/src/mssql/connection/stream.rs b/sqlx-core/src/mssql/connection/stream.rs index 1793c76c..d7d3604f 100644 --- a/sqlx-core/src/mssql/connection/stream.rs +++ b/sqlx-core/src/mssql/connection/stream.rs @@ -1,6 +1,6 @@ use std::ops::{Deref, DerefMut}; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use sqlx_rt::TcpStream; use crate::error::Error; @@ -92,9 +92,7 @@ impl MssqlStream { // receive the next packet from the database // blocks until a packet is available pub(super) async fn recv_packet(&mut self) -> Result<(PacketHeader, Bytes), Error> { - // TODO: Support packet chunking for large packet sizes - - let header: PacketHeader = self.inner.read(8).await?; + let mut header: PacketHeader = self.inner.read(8).await?; // NOTE: From what I can tell, the response type from the server should ~always~ // be TabularResult. Here we expect that and die otherwise. @@ -105,10 +103,21 @@ impl MssqlStream { )); } - let payload_len = (header.length - 8) as usize; - let payload: Bytes = self.inner.read(payload_len).await?; + let mut payload = BytesMut::new(); - Ok((header, payload)) + loop { + self.inner + .read_raw_into(&mut payload, (header.length - 8) as usize) + .await?; + + if header.status.contains(Status::END_OF_MESSAGE) { + break; + } + + header = self.inner.read(8).await?; + } + + Ok((header, payload.freeze())) } // receive the next ~message~ diff --git a/tests/mssql/mssql.rs b/tests/mssql/mssql.rs index 1a89cb49..56b12673 100644 --- a/tests/mssql/mssql.rs +++ b/tests/mssql/mssql.rs @@ -112,6 +112,38 @@ CREATE TABLE #users (id INTEGER PRIMARY KEY); Ok(()) } +#[sqlx_macros::test] +async fn it_can_return_1000_rows() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let _ = conn + .execute( + r#" +CREATE TABLE #users (id INTEGER PRIMARY KEY); + "#, + ) + .await?; + + for index in 1..=1000_i32 { + let done = sqlx::query("INSERT INTO #users (id) VALUES (@p1)") + .bind(index * 2) + .execute(&mut conn) + .await?; + + assert_eq!(done.rows_affected(), 1); + } + + let sum: i32 = sqlx::query("SELECT id FROM #users") + .try_map(|row: MssqlRow| row.try_get::(0)) + .fetch(&mut conn) + .try_fold(0_i32, |acc, x| async move { Ok(acc + x) }) + .await?; + + assert_eq!(sum, 1001000); + + Ok(()) +} + #[sqlx_macros::test] async fn it_selects_null() -> anyhow::Result<()> { let mut conn = new::().await?;