mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-01-19 15:23:07 +00:00
fix(mssql): potential stall on re-using connection after an dropped incomplete fetch stream
This commit is contained in:
parent
ce4286dff5
commit
72ca9036c5
@ -74,9 +74,6 @@ impl MssqlConnection {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
stream,
|
||||
pending_done_count: 0,
|
||||
})
|
||||
Ok(Self { stream })
|
||||
}
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@ use crate::describe::{Column, Describe};
|
||||
use crate::error::Error;
|
||||
use crate::executor::{Execute, Executor};
|
||||
use crate::mssql::protocol::col_meta_data::Flags;
|
||||
use crate::mssql::protocol::done::{Done, Status};
|
||||
use crate::mssql::protocol::done::Status;
|
||||
use crate::mssql::protocol::message::Message;
|
||||
use crate::mssql::protocol::packet::PacketType;
|
||||
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
|
||||
@ -17,33 +17,9 @@ use crate::mssql::protocol::sql_batch::SqlBatch;
|
||||
use crate::mssql::{Mssql, MssqlArguments, MssqlConnection, MssqlRow, MssqlTypeInfo};
|
||||
|
||||
impl MssqlConnection {
|
||||
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
|
||||
if !self.stream.wbuf.is_empty() {
|
||||
self.pending_done_count += 1;
|
||||
self.stream.flush().await?;
|
||||
}
|
||||
|
||||
while self.pending_done_count > 0 {
|
||||
let message = self.stream.recv_message().await?;
|
||||
|
||||
if let Message::DoneProc(done) | Message::Done(done) = message {
|
||||
if !done.status.contains(Status::DONE_MORE) {
|
||||
// finished RPC procedure *OR* SQL batch
|
||||
self.handle_done(done);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_done(&mut self, _: Done) {
|
||||
self.pending_done_count -= 1;
|
||||
}
|
||||
|
||||
async fn run(&mut self, query: &str, arguments: Option<MssqlArguments>) -> Result<(), Error> {
|
||||
self.wait_until_ready().await?;
|
||||
self.pending_done_count += 1;
|
||||
self.stream.wait_until_ready().await?;
|
||||
self.stream.pending_done_count += 1;
|
||||
|
||||
if let Some(mut arguments) = arguments {
|
||||
let proc = Either::Right(Procedure::ExecuteSql);
|
||||
@ -112,12 +88,15 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection {
|
||||
}
|
||||
|
||||
Message::Done(done) | Message::DoneProc(done) => {
|
||||
if !done.status.contains(Status::DONE_MORE) {
|
||||
self.stream.handle_done(&done);
|
||||
}
|
||||
|
||||
if done.status.contains(Status::DONE_COUNT) {
|
||||
r#yield!(Either::Left(done.affected_rows));
|
||||
}
|
||||
|
||||
if !done.status.contains(Status::DONE_MORE) {
|
||||
self.handle_done(done);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -221,12 +200,15 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection {
|
||||
|
||||
Box::pin(async move {
|
||||
self.stream.flush().await?;
|
||||
self.stream.wait_until_ready().await?;
|
||||
self.stream.pending_done_count += 1;
|
||||
|
||||
loop {
|
||||
match self.stream.recv_message().await? {
|
||||
Message::DoneProc(done) | Message::Done(done) => {
|
||||
if !done.status.contains(Status::DONE_MORE) {
|
||||
// done with prepare
|
||||
self.stream.handle_done(&done);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,9 +16,6 @@ mod stream;
|
||||
|
||||
pub struct MssqlConnection {
|
||||
pub(crate) stream: MssqlStream,
|
||||
|
||||
// number of Done* messages that we are currently expecting
|
||||
pub(crate) pending_done_count: usize,
|
||||
}
|
||||
|
||||
impl Debug for MssqlConnection {
|
||||
@ -42,7 +39,7 @@ impl Connection for MssqlConnection {
|
||||
|
||||
#[doc(hidden)]
|
||||
fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
|
||||
self.wait_until_ready().boxed()
|
||||
self.stream.wait_until_ready().boxed()
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
|
||||
@ -6,7 +6,7 @@ use sqlx_rt::TcpStream;
|
||||
use crate::error::Error;
|
||||
use crate::io::{BufStream, Encode};
|
||||
use crate::mssql::protocol::col_meta_data::{ColMetaData, ColumnData};
|
||||
use crate::mssql::protocol::done::Done;
|
||||
use crate::mssql::protocol::done::{Done, Status as DoneStatus};
|
||||
use crate::mssql::protocol::env_change::EnvChange;
|
||||
use crate::mssql::protocol::error::Error as ProtocolError;
|
||||
use crate::mssql::protocol::info::Info;
|
||||
@ -22,6 +22,9 @@ use crate::net::MaybeTlsStream;
|
||||
pub(crate) struct MssqlStream {
|
||||
inner: BufStream<MaybeTlsStream<TcpStream>>,
|
||||
|
||||
// how many Done (or Error) we are currently waiting for
|
||||
pub(crate) pending_done_count: usize,
|
||||
|
||||
// current transaction descriptor
|
||||
// set from ENVCHANGE on `BEGIN` and reset to `0` on a ROLLBACK
|
||||
pub(crate) transaction_descriptor: u64,
|
||||
@ -44,6 +47,7 @@ impl MssqlStream {
|
||||
inner,
|
||||
columns: Vec::new(),
|
||||
response: None,
|
||||
pending_done_count: 0,
|
||||
transaction_descriptor: 0,
|
||||
})
|
||||
}
|
||||
@ -146,8 +150,8 @@ impl MssqlStream {
|
||||
MessageType::DoneProc => Message::DoneProc(Done::get(buf)?),
|
||||
|
||||
MessageType::Error => {
|
||||
let err = ProtocolError::get(buf)?;
|
||||
return Err(MssqlDatabaseError(err).into());
|
||||
let error = ProtocolError::get(buf)?;
|
||||
return self.handle_error(error);
|
||||
}
|
||||
|
||||
MessageType::ColMetaData => {
|
||||
@ -165,6 +169,35 @@ impl MssqlStream {
|
||||
self.response = Some(self.recv_packet().await?);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn handle_done(&mut self, _done: &Done) {
|
||||
self.pending_done_count -= 1;
|
||||
}
|
||||
|
||||
pub(crate) fn handle_error<T>(&mut self, error: ProtocolError) -> Result<T, Error> {
|
||||
// error is sent _instead_ of a done
|
||||
self.pending_done_count -= 1;
|
||||
Err(MssqlDatabaseError(error).into())
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
|
||||
if !self.wbuf.is_empty() {
|
||||
self.flush().await?;
|
||||
}
|
||||
|
||||
while self.pending_done_count > 0 {
|
||||
let message = self.recv_message().await?;
|
||||
|
||||
if let Message::DoneProc(done) | Message::Done(done) = message {
|
||||
if !done.status.contains(DoneStatus::DONE_MORE) {
|
||||
// finished RPC procedure *OR* SQL batch
|
||||
self.handle_done(&done);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for MssqlStream {
|
||||
|
||||
@ -61,7 +61,7 @@ impl TransactionManager for MssqlTransactionManager {
|
||||
Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1))
|
||||
};
|
||||
|
||||
conn.pending_done_count += 1;
|
||||
conn.stream.pending_done_count += 1;
|
||||
conn.stream.write_packet(
|
||||
PacketType::SqlBatch,
|
||||
SqlBatch {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user