fix(mssql): potential stall on re-using connection after an dropped incomplete fetch stream

This commit is contained in:
Ryan Leckey 2020-06-10 00:32:38 -07:00
parent ce4286dff5
commit 72ca9036c5
5 changed files with 49 additions and 40 deletions

View File

@ -74,9 +74,6 @@ impl MssqlConnection {
}
}
Ok(Self {
stream,
pending_done_count: 0,
})
Ok(Self { stream })
}
}

View File

@ -9,7 +9,7 @@ use crate::describe::{Column, Describe};
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::mssql::protocol::col_meta_data::Flags;
use crate::mssql::protocol::done::{Done, Status};
use crate::mssql::protocol::done::Status;
use crate::mssql::protocol::message::Message;
use crate::mssql::protocol::packet::PacketType;
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
@ -17,33 +17,9 @@ use crate::mssql::protocol::sql_batch::SqlBatch;
use crate::mssql::{Mssql, MssqlArguments, MssqlConnection, MssqlRow, MssqlTypeInfo};
impl MssqlConnection {
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
if !self.stream.wbuf.is_empty() {
self.pending_done_count += 1;
self.stream.flush().await?;
}
while self.pending_done_count > 0 {
let message = self.stream.recv_message().await?;
if let Message::DoneProc(done) | Message::Done(done) = message {
if !done.status.contains(Status::DONE_MORE) {
// finished RPC procedure *OR* SQL batch
self.handle_done(done);
}
}
}
Ok(())
}
fn handle_done(&mut self, _: Done) {
self.pending_done_count -= 1;
}
async fn run(&mut self, query: &str, arguments: Option<MssqlArguments>) -> Result<(), Error> {
self.wait_until_ready().await?;
self.pending_done_count += 1;
self.stream.wait_until_ready().await?;
self.stream.pending_done_count += 1;
if let Some(mut arguments) = arguments {
let proc = Either::Right(Procedure::ExecuteSql);
@ -112,12 +88,15 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection {
}
Message::Done(done) | Message::DoneProc(done) => {
if !done.status.contains(Status::DONE_MORE) {
self.stream.handle_done(&done);
}
if done.status.contains(Status::DONE_COUNT) {
r#yield!(Either::Left(done.affected_rows));
}
if !done.status.contains(Status::DONE_MORE) {
self.handle_done(done);
break;
}
}
@ -221,12 +200,15 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection {
Box::pin(async move {
self.stream.flush().await?;
self.stream.wait_until_ready().await?;
self.stream.pending_done_count += 1;
loop {
match self.stream.recv_message().await? {
Message::DoneProc(done) | Message::Done(done) => {
if !done.status.contains(Status::DONE_MORE) {
// done with prepare
self.stream.handle_done(&done);
break;
}
}

View File

@ -16,9 +16,6 @@ mod stream;
pub struct MssqlConnection {
pub(crate) stream: MssqlStream,
// number of Done* messages that we are currently expecting
pub(crate) pending_done_count: usize,
}
impl Debug for MssqlConnection {
@ -42,7 +39,7 @@ impl Connection for MssqlConnection {
#[doc(hidden)]
fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
self.wait_until_ready().boxed()
self.stream.wait_until_ready().boxed()
}
#[doc(hidden)]

View File

@ -6,7 +6,7 @@ use sqlx_rt::TcpStream;
use crate::error::Error;
use crate::io::{BufStream, Encode};
use crate::mssql::protocol::col_meta_data::{ColMetaData, ColumnData};
use crate::mssql::protocol::done::Done;
use crate::mssql::protocol::done::{Done, Status as DoneStatus};
use crate::mssql::protocol::env_change::EnvChange;
use crate::mssql::protocol::error::Error as ProtocolError;
use crate::mssql::protocol::info::Info;
@ -22,6 +22,9 @@ use crate::net::MaybeTlsStream;
pub(crate) struct MssqlStream {
inner: BufStream<MaybeTlsStream<TcpStream>>,
// how many Done (or Error) we are currently waiting for
pub(crate) pending_done_count: usize,
// current transaction descriptor
// set from ENVCHANGE on `BEGIN` and reset to `0` on a ROLLBACK
pub(crate) transaction_descriptor: u64,
@ -44,6 +47,7 @@ impl MssqlStream {
inner,
columns: Vec::new(),
response: None,
pending_done_count: 0,
transaction_descriptor: 0,
})
}
@ -146,8 +150,8 @@ impl MssqlStream {
MessageType::DoneProc => Message::DoneProc(Done::get(buf)?),
MessageType::Error => {
let err = ProtocolError::get(buf)?;
return Err(MssqlDatabaseError(err).into());
let error = ProtocolError::get(buf)?;
return self.handle_error(error);
}
MessageType::ColMetaData => {
@ -165,6 +169,35 @@ impl MssqlStream {
self.response = Some(self.recv_packet().await?);
}
}
pub(crate) fn handle_done(&mut self, _done: &Done) {
self.pending_done_count -= 1;
}
pub(crate) fn handle_error<T>(&mut self, error: ProtocolError) -> Result<T, Error> {
// error is sent _instead_ of a done
self.pending_done_count -= 1;
Err(MssqlDatabaseError(error).into())
}
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
if !self.wbuf.is_empty() {
self.flush().await?;
}
while self.pending_done_count > 0 {
let message = self.recv_message().await?;
if let Message::DoneProc(done) | Message::Done(done) = message {
if !done.status.contains(DoneStatus::DONE_MORE) {
// finished RPC procedure *OR* SQL batch
self.handle_done(&done);
}
}
}
Ok(())
}
}
impl Deref for MssqlStream {

View File

@ -61,7 +61,7 @@ impl TransactionManager for MssqlTransactionManager {
Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1))
};
conn.pending_done_count += 1;
conn.stream.pending_done_count += 1;
conn.stream.write_packet(
PacketType::SqlBatch,
SqlBatch {