From 72ca9036c5ebdaa04f019d9ddc908d606ba992da Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Wed, 10 Jun 2020 00:32:38 -0700 Subject: [PATCH] fix(mssql): potential stall on re-using connection after an dropped incomplete fetch stream --- sqlx-core/src/mssql/connection/establish.rs | 5 +-- sqlx-core/src/mssql/connection/executor.rs | 38 ++++++-------------- sqlx-core/src/mssql/connection/mod.rs | 5 +-- sqlx-core/src/mssql/connection/stream.rs | 39 +++++++++++++++++++-- sqlx-core/src/mssql/transaction.rs | 2 +- 5 files changed, 49 insertions(+), 40 deletions(-) diff --git a/sqlx-core/src/mssql/connection/establish.rs b/sqlx-core/src/mssql/connection/establish.rs index b678d23e..919cddd6 100644 --- a/sqlx-core/src/mssql/connection/establish.rs +++ b/sqlx-core/src/mssql/connection/establish.rs @@ -74,9 +74,6 @@ impl MssqlConnection { } } - Ok(Self { - stream, - pending_done_count: 0, - }) + Ok(Self { stream }) } } diff --git a/sqlx-core/src/mssql/connection/executor.rs b/sqlx-core/src/mssql/connection/executor.rs index 00ffd7b5..a38820c0 100644 --- a/sqlx-core/src/mssql/connection/executor.rs +++ b/sqlx-core/src/mssql/connection/executor.rs @@ -9,7 +9,7 @@ use crate::describe::{Column, Describe}; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::mssql::protocol::col_meta_data::Flags; -use crate::mssql::protocol::done::{Done, Status}; +use crate::mssql::protocol::done::Status; use crate::mssql::protocol::message::Message; use crate::mssql::protocol::packet::PacketType; use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest}; @@ -17,33 +17,9 @@ use crate::mssql::protocol::sql_batch::SqlBatch; use crate::mssql::{Mssql, MssqlArguments, MssqlConnection, MssqlRow, MssqlTypeInfo}; impl MssqlConnection { - pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { - if !self.stream.wbuf.is_empty() { - self.pending_done_count += 1; - self.stream.flush().await?; - } - - while self.pending_done_count > 0 { - let message = self.stream.recv_message().await?; - - if let Message::DoneProc(done) | Message::Done(done) = message { - if !done.status.contains(Status::DONE_MORE) { - // finished RPC procedure *OR* SQL batch - self.handle_done(done); - } - } - } - - Ok(()) - } - - fn handle_done(&mut self, _: Done) { - self.pending_done_count -= 1; - } - async fn run(&mut self, query: &str, arguments: Option) -> Result<(), Error> { - self.wait_until_ready().await?; - self.pending_done_count += 1; + self.stream.wait_until_ready().await?; + self.stream.pending_done_count += 1; if let Some(mut arguments) = arguments { let proc = Either::Right(Procedure::ExecuteSql); @@ -112,12 +88,15 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { } Message::Done(done) | Message::DoneProc(done) => { + if !done.status.contains(Status::DONE_MORE) { + self.stream.handle_done(&done); + } + if done.status.contains(Status::DONE_COUNT) { r#yield!(Either::Left(done.affected_rows)); } if !done.status.contains(Status::DONE_MORE) { - self.handle_done(done); break; } } @@ -221,12 +200,15 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { Box::pin(async move { self.stream.flush().await?; + self.stream.wait_until_ready().await?; + self.stream.pending_done_count += 1; loop { match self.stream.recv_message().await? { Message::DoneProc(done) | Message::Done(done) => { if !done.status.contains(Status::DONE_MORE) { // done with prepare + self.stream.handle_done(&done); break; } } diff --git a/sqlx-core/src/mssql/connection/mod.rs b/sqlx-core/src/mssql/connection/mod.rs index 2bd5b4d8..c81d32be 100644 --- a/sqlx-core/src/mssql/connection/mod.rs +++ b/sqlx-core/src/mssql/connection/mod.rs @@ -16,9 +16,6 @@ mod stream; pub struct MssqlConnection { pub(crate) stream: MssqlStream, - - // number of Done* messages that we are currently expecting - pub(crate) pending_done_count: usize, } impl Debug for MssqlConnection { @@ -42,7 +39,7 @@ impl Connection for MssqlConnection { #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { - self.wait_until_ready().boxed() + self.stream.wait_until_ready().boxed() } #[doc(hidden)] diff --git a/sqlx-core/src/mssql/connection/stream.rs b/sqlx-core/src/mssql/connection/stream.rs index a0e8518a..a1829799 100644 --- a/sqlx-core/src/mssql/connection/stream.rs +++ b/sqlx-core/src/mssql/connection/stream.rs @@ -6,7 +6,7 @@ use sqlx_rt::TcpStream; use crate::error::Error; use crate::io::{BufStream, Encode}; use crate::mssql::protocol::col_meta_data::{ColMetaData, ColumnData}; -use crate::mssql::protocol::done::Done; +use crate::mssql::protocol::done::{Done, Status as DoneStatus}; use crate::mssql::protocol::env_change::EnvChange; use crate::mssql::protocol::error::Error as ProtocolError; use crate::mssql::protocol::info::Info; @@ -22,6 +22,9 @@ use crate::net::MaybeTlsStream; pub(crate) struct MssqlStream { inner: BufStream>, + // how many Done (or Error) we are currently waiting for + pub(crate) pending_done_count: usize, + // current transaction descriptor // set from ENVCHANGE on `BEGIN` and reset to `0` on a ROLLBACK pub(crate) transaction_descriptor: u64, @@ -44,6 +47,7 @@ impl MssqlStream { inner, columns: Vec::new(), response: None, + pending_done_count: 0, transaction_descriptor: 0, }) } @@ -146,8 +150,8 @@ impl MssqlStream { MessageType::DoneProc => Message::DoneProc(Done::get(buf)?), MessageType::Error => { - let err = ProtocolError::get(buf)?; - return Err(MssqlDatabaseError(err).into()); + let error = ProtocolError::get(buf)?; + return self.handle_error(error); } MessageType::ColMetaData => { @@ -165,6 +169,35 @@ impl MssqlStream { self.response = Some(self.recv_packet().await?); } } + + pub(crate) fn handle_done(&mut self, _done: &Done) { + self.pending_done_count -= 1; + } + + pub(crate) fn handle_error(&mut self, error: ProtocolError) -> Result { + // error is sent _instead_ of a done + self.pending_done_count -= 1; + Err(MssqlDatabaseError(error).into()) + } + + pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { + if !self.wbuf.is_empty() { + self.flush().await?; + } + + while self.pending_done_count > 0 { + let message = self.recv_message().await?; + + if let Message::DoneProc(done) | Message::Done(done) = message { + if !done.status.contains(DoneStatus::DONE_MORE) { + // finished RPC procedure *OR* SQL batch + self.handle_done(&done); + } + } + } + + Ok(()) + } } impl Deref for MssqlStream { diff --git a/sqlx-core/src/mssql/transaction.rs b/sqlx-core/src/mssql/transaction.rs index 79e94dfe..894e2372 100644 --- a/sqlx-core/src/mssql/transaction.rs +++ b/sqlx-core/src/mssql/transaction.rs @@ -61,7 +61,7 @@ impl TransactionManager for MssqlTransactionManager { Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1)) }; - conn.pending_done_count += 1; + conn.stream.pending_done_count += 1; conn.stream.write_packet( PacketType::SqlBatch, SqlBatch {