fix(mssql): potential stall on re-using connection after an dropped incomplete fetch stream

This commit is contained in:
Ryan Leckey
2020-06-10 00:32:38 -07:00
parent ce4286dff5
commit 72ca9036c5
5 changed files with 49 additions and 40 deletions

View File

@@ -74,9 +74,6 @@ impl MssqlConnection {
} }
} }
Ok(Self { Ok(Self { stream })
stream,
pending_done_count: 0,
})
} }
} }

View File

@@ -9,7 +9,7 @@ use crate::describe::{Column, Describe};
use crate::error::Error; use crate::error::Error;
use crate::executor::{Execute, Executor}; use crate::executor::{Execute, Executor};
use crate::mssql::protocol::col_meta_data::Flags; use crate::mssql::protocol::col_meta_data::Flags;
use crate::mssql::protocol::done::{Done, Status}; use crate::mssql::protocol::done::Status;
use crate::mssql::protocol::message::Message; use crate::mssql::protocol::message::Message;
use crate::mssql::protocol::packet::PacketType; use crate::mssql::protocol::packet::PacketType;
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest}; use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
@@ -17,33 +17,9 @@ use crate::mssql::protocol::sql_batch::SqlBatch;
use crate::mssql::{Mssql, MssqlArguments, MssqlConnection, MssqlRow, MssqlTypeInfo}; use crate::mssql::{Mssql, MssqlArguments, MssqlConnection, MssqlRow, MssqlTypeInfo};
impl MssqlConnection { impl MssqlConnection {
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
if !self.stream.wbuf.is_empty() {
self.pending_done_count += 1;
self.stream.flush().await?;
}
while self.pending_done_count > 0 {
let message = self.stream.recv_message().await?;
if let Message::DoneProc(done) | Message::Done(done) = message {
if !done.status.contains(Status::DONE_MORE) {
// finished RPC procedure *OR* SQL batch
self.handle_done(done);
}
}
}
Ok(())
}
fn handle_done(&mut self, _: Done) {
self.pending_done_count -= 1;
}
async fn run(&mut self, query: &str, arguments: Option<MssqlArguments>) -> Result<(), Error> { async fn run(&mut self, query: &str, arguments: Option<MssqlArguments>) -> Result<(), Error> {
self.wait_until_ready().await?; self.stream.wait_until_ready().await?;
self.pending_done_count += 1; self.stream.pending_done_count += 1;
if let Some(mut arguments) = arguments { if let Some(mut arguments) = arguments {
let proc = Either::Right(Procedure::ExecuteSql); let proc = Either::Right(Procedure::ExecuteSql);
@@ -112,12 +88,15 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection {
} }
Message::Done(done) | Message::DoneProc(done) => { Message::Done(done) | Message::DoneProc(done) => {
if !done.status.contains(Status::DONE_MORE) {
self.stream.handle_done(&done);
}
if done.status.contains(Status::DONE_COUNT) { if done.status.contains(Status::DONE_COUNT) {
r#yield!(Either::Left(done.affected_rows)); r#yield!(Either::Left(done.affected_rows));
} }
if !done.status.contains(Status::DONE_MORE) { if !done.status.contains(Status::DONE_MORE) {
self.handle_done(done);
break; break;
} }
} }
@@ -221,12 +200,15 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection {
Box::pin(async move { Box::pin(async move {
self.stream.flush().await?; self.stream.flush().await?;
self.stream.wait_until_ready().await?;
self.stream.pending_done_count += 1;
loop { loop {
match self.stream.recv_message().await? { match self.stream.recv_message().await? {
Message::DoneProc(done) | Message::Done(done) => { Message::DoneProc(done) | Message::Done(done) => {
if !done.status.contains(Status::DONE_MORE) { if !done.status.contains(Status::DONE_MORE) {
// done with prepare // done with prepare
self.stream.handle_done(&done);
break; break;
} }
} }

View File

@@ -16,9 +16,6 @@ mod stream;
pub struct MssqlConnection { pub struct MssqlConnection {
pub(crate) stream: MssqlStream, pub(crate) stream: MssqlStream,
// number of Done* messages that we are currently expecting
pub(crate) pending_done_count: usize,
} }
impl Debug for MssqlConnection { impl Debug for MssqlConnection {
@@ -42,7 +39,7 @@ impl Connection for MssqlConnection {
#[doc(hidden)] #[doc(hidden)]
fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
self.wait_until_ready().boxed() self.stream.wait_until_ready().boxed()
} }
#[doc(hidden)] #[doc(hidden)]

View File

@@ -6,7 +6,7 @@ use sqlx_rt::TcpStream;
use crate::error::Error; use crate::error::Error;
use crate::io::{BufStream, Encode}; use crate::io::{BufStream, Encode};
use crate::mssql::protocol::col_meta_data::{ColMetaData, ColumnData}; use crate::mssql::protocol::col_meta_data::{ColMetaData, ColumnData};
use crate::mssql::protocol::done::Done; use crate::mssql::protocol::done::{Done, Status as DoneStatus};
use crate::mssql::protocol::env_change::EnvChange; use crate::mssql::protocol::env_change::EnvChange;
use crate::mssql::protocol::error::Error as ProtocolError; use crate::mssql::protocol::error::Error as ProtocolError;
use crate::mssql::protocol::info::Info; use crate::mssql::protocol::info::Info;
@@ -22,6 +22,9 @@ use crate::net::MaybeTlsStream;
pub(crate) struct MssqlStream { pub(crate) struct MssqlStream {
inner: BufStream<MaybeTlsStream<TcpStream>>, inner: BufStream<MaybeTlsStream<TcpStream>>,
// how many Done (or Error) we are currently waiting for
pub(crate) pending_done_count: usize,
// current transaction descriptor // current transaction descriptor
// set from ENVCHANGE on `BEGIN` and reset to `0` on a ROLLBACK // set from ENVCHANGE on `BEGIN` and reset to `0` on a ROLLBACK
pub(crate) transaction_descriptor: u64, pub(crate) transaction_descriptor: u64,
@@ -44,6 +47,7 @@ impl MssqlStream {
inner, inner,
columns: Vec::new(), columns: Vec::new(),
response: None, response: None,
pending_done_count: 0,
transaction_descriptor: 0, transaction_descriptor: 0,
}) })
} }
@@ -146,8 +150,8 @@ impl MssqlStream {
MessageType::DoneProc => Message::DoneProc(Done::get(buf)?), MessageType::DoneProc => Message::DoneProc(Done::get(buf)?),
MessageType::Error => { MessageType::Error => {
let err = ProtocolError::get(buf)?; let error = ProtocolError::get(buf)?;
return Err(MssqlDatabaseError(err).into()); return self.handle_error(error);
} }
MessageType::ColMetaData => { MessageType::ColMetaData => {
@@ -165,6 +169,35 @@ impl MssqlStream {
self.response = Some(self.recv_packet().await?); self.response = Some(self.recv_packet().await?);
} }
} }
pub(crate) fn handle_done(&mut self, _done: &Done) {
self.pending_done_count -= 1;
}
pub(crate) fn handle_error<T>(&mut self, error: ProtocolError) -> Result<T, Error> {
// error is sent _instead_ of a done
self.pending_done_count -= 1;
Err(MssqlDatabaseError(error).into())
}
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
if !self.wbuf.is_empty() {
self.flush().await?;
}
while self.pending_done_count > 0 {
let message = self.recv_message().await?;
if let Message::DoneProc(done) | Message::Done(done) = message {
if !done.status.contains(DoneStatus::DONE_MORE) {
// finished RPC procedure *OR* SQL batch
self.handle_done(&done);
}
}
}
Ok(())
}
} }
impl Deref for MssqlStream { impl Deref for MssqlStream {

View File

@@ -61,7 +61,7 @@ impl TransactionManager for MssqlTransactionManager {
Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1)) Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1))
}; };
conn.pending_done_count += 1; conn.stream.pending_done_count += 1;
conn.stream.write_packet( conn.stream.write_packet(
PacketType::SqlBatch, PacketType::SqlBatch,
SqlBatch { SqlBatch {