mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-04-19 21:25:46 +00:00
fix(mssql): potential stall on re-using connection after an dropped incomplete fetch stream
This commit is contained in:
@@ -74,9 +74,6 @@ impl MssqlConnection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self { stream })
|
||||||
stream,
|
|
||||||
pending_done_count: 0,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use crate::describe::{Column, Describe};
|
|||||||
use crate::error::Error;
|
use crate::error::Error;
|
||||||
use crate::executor::{Execute, Executor};
|
use crate::executor::{Execute, Executor};
|
||||||
use crate::mssql::protocol::col_meta_data::Flags;
|
use crate::mssql::protocol::col_meta_data::Flags;
|
||||||
use crate::mssql::protocol::done::{Done, Status};
|
use crate::mssql::protocol::done::Status;
|
||||||
use crate::mssql::protocol::message::Message;
|
use crate::mssql::protocol::message::Message;
|
||||||
use crate::mssql::protocol::packet::PacketType;
|
use crate::mssql::protocol::packet::PacketType;
|
||||||
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
|
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
|
||||||
@@ -17,33 +17,9 @@ use crate::mssql::protocol::sql_batch::SqlBatch;
|
|||||||
use crate::mssql::{Mssql, MssqlArguments, MssqlConnection, MssqlRow, MssqlTypeInfo};
|
use crate::mssql::{Mssql, MssqlArguments, MssqlConnection, MssqlRow, MssqlTypeInfo};
|
||||||
|
|
||||||
impl MssqlConnection {
|
impl MssqlConnection {
|
||||||
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
|
|
||||||
if !self.stream.wbuf.is_empty() {
|
|
||||||
self.pending_done_count += 1;
|
|
||||||
self.stream.flush().await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
while self.pending_done_count > 0 {
|
|
||||||
let message = self.stream.recv_message().await?;
|
|
||||||
|
|
||||||
if let Message::DoneProc(done) | Message::Done(done) = message {
|
|
||||||
if !done.status.contains(Status::DONE_MORE) {
|
|
||||||
// finished RPC procedure *OR* SQL batch
|
|
||||||
self.handle_done(done);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_done(&mut self, _: Done) {
|
|
||||||
self.pending_done_count -= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run(&mut self, query: &str, arguments: Option<MssqlArguments>) -> Result<(), Error> {
|
async fn run(&mut self, query: &str, arguments: Option<MssqlArguments>) -> Result<(), Error> {
|
||||||
self.wait_until_ready().await?;
|
self.stream.wait_until_ready().await?;
|
||||||
self.pending_done_count += 1;
|
self.stream.pending_done_count += 1;
|
||||||
|
|
||||||
if let Some(mut arguments) = arguments {
|
if let Some(mut arguments) = arguments {
|
||||||
let proc = Either::Right(Procedure::ExecuteSql);
|
let proc = Either::Right(Procedure::ExecuteSql);
|
||||||
@@ -112,12 +88,15 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Message::Done(done) | Message::DoneProc(done) => {
|
Message::Done(done) | Message::DoneProc(done) => {
|
||||||
|
if !done.status.contains(Status::DONE_MORE) {
|
||||||
|
self.stream.handle_done(&done);
|
||||||
|
}
|
||||||
|
|
||||||
if done.status.contains(Status::DONE_COUNT) {
|
if done.status.contains(Status::DONE_COUNT) {
|
||||||
r#yield!(Either::Left(done.affected_rows));
|
r#yield!(Either::Left(done.affected_rows));
|
||||||
}
|
}
|
||||||
|
|
||||||
if !done.status.contains(Status::DONE_MORE) {
|
if !done.status.contains(Status::DONE_MORE) {
|
||||||
self.handle_done(done);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -221,12 +200,15 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection {
|
|||||||
|
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
self.stream.flush().await?;
|
self.stream.flush().await?;
|
||||||
|
self.stream.wait_until_ready().await?;
|
||||||
|
self.stream.pending_done_count += 1;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match self.stream.recv_message().await? {
|
match self.stream.recv_message().await? {
|
||||||
Message::DoneProc(done) | Message::Done(done) => {
|
Message::DoneProc(done) | Message::Done(done) => {
|
||||||
if !done.status.contains(Status::DONE_MORE) {
|
if !done.status.contains(Status::DONE_MORE) {
|
||||||
// done with prepare
|
// done with prepare
|
||||||
|
self.stream.handle_done(&done);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,9 +16,6 @@ mod stream;
|
|||||||
|
|
||||||
pub struct MssqlConnection {
|
pub struct MssqlConnection {
|
||||||
pub(crate) stream: MssqlStream,
|
pub(crate) stream: MssqlStream,
|
||||||
|
|
||||||
// number of Done* messages that we are currently expecting
|
|
||||||
pub(crate) pending_done_count: usize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Debug for MssqlConnection {
|
impl Debug for MssqlConnection {
|
||||||
@@ -42,7 +39,7 @@ impl Connection for MssqlConnection {
|
|||||||
|
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
|
fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
|
||||||
self.wait_until_ready().boxed()
|
self.stream.wait_until_ready().boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use sqlx_rt::TcpStream;
|
|||||||
use crate::error::Error;
|
use crate::error::Error;
|
||||||
use crate::io::{BufStream, Encode};
|
use crate::io::{BufStream, Encode};
|
||||||
use crate::mssql::protocol::col_meta_data::{ColMetaData, ColumnData};
|
use crate::mssql::protocol::col_meta_data::{ColMetaData, ColumnData};
|
||||||
use crate::mssql::protocol::done::Done;
|
use crate::mssql::protocol::done::{Done, Status as DoneStatus};
|
||||||
use crate::mssql::protocol::env_change::EnvChange;
|
use crate::mssql::protocol::env_change::EnvChange;
|
||||||
use crate::mssql::protocol::error::Error as ProtocolError;
|
use crate::mssql::protocol::error::Error as ProtocolError;
|
||||||
use crate::mssql::protocol::info::Info;
|
use crate::mssql::protocol::info::Info;
|
||||||
@@ -22,6 +22,9 @@ use crate::net::MaybeTlsStream;
|
|||||||
pub(crate) struct MssqlStream {
|
pub(crate) struct MssqlStream {
|
||||||
inner: BufStream<MaybeTlsStream<TcpStream>>,
|
inner: BufStream<MaybeTlsStream<TcpStream>>,
|
||||||
|
|
||||||
|
// how many Done (or Error) we are currently waiting for
|
||||||
|
pub(crate) pending_done_count: usize,
|
||||||
|
|
||||||
// current transaction descriptor
|
// current transaction descriptor
|
||||||
// set from ENVCHANGE on `BEGIN` and reset to `0` on a ROLLBACK
|
// set from ENVCHANGE on `BEGIN` and reset to `0` on a ROLLBACK
|
||||||
pub(crate) transaction_descriptor: u64,
|
pub(crate) transaction_descriptor: u64,
|
||||||
@@ -44,6 +47,7 @@ impl MssqlStream {
|
|||||||
inner,
|
inner,
|
||||||
columns: Vec::new(),
|
columns: Vec::new(),
|
||||||
response: None,
|
response: None,
|
||||||
|
pending_done_count: 0,
|
||||||
transaction_descriptor: 0,
|
transaction_descriptor: 0,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -146,8 +150,8 @@ impl MssqlStream {
|
|||||||
MessageType::DoneProc => Message::DoneProc(Done::get(buf)?),
|
MessageType::DoneProc => Message::DoneProc(Done::get(buf)?),
|
||||||
|
|
||||||
MessageType::Error => {
|
MessageType::Error => {
|
||||||
let err = ProtocolError::get(buf)?;
|
let error = ProtocolError::get(buf)?;
|
||||||
return Err(MssqlDatabaseError(err).into());
|
return self.handle_error(error);
|
||||||
}
|
}
|
||||||
|
|
||||||
MessageType::ColMetaData => {
|
MessageType::ColMetaData => {
|
||||||
@@ -165,6 +169,35 @@ impl MssqlStream {
|
|||||||
self.response = Some(self.recv_packet().await?);
|
self.response = Some(self.recv_packet().await?);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn handle_done(&mut self, _done: &Done) {
|
||||||
|
self.pending_done_count -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn handle_error<T>(&mut self, error: ProtocolError) -> Result<T, Error> {
|
||||||
|
// error is sent _instead_ of a done
|
||||||
|
self.pending_done_count -= 1;
|
||||||
|
Err(MssqlDatabaseError(error).into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
|
||||||
|
if !self.wbuf.is_empty() {
|
||||||
|
self.flush().await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
while self.pending_done_count > 0 {
|
||||||
|
let message = self.recv_message().await?;
|
||||||
|
|
||||||
|
if let Message::DoneProc(done) | Message::Done(done) = message {
|
||||||
|
if !done.status.contains(DoneStatus::DONE_MORE) {
|
||||||
|
// finished RPC procedure *OR* SQL batch
|
||||||
|
self.handle_done(&done);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Deref for MssqlStream {
|
impl Deref for MssqlStream {
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ impl TransactionManager for MssqlTransactionManager {
|
|||||||
Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1))
|
Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1))
|
||||||
};
|
};
|
||||||
|
|
||||||
conn.pending_done_count += 1;
|
conn.stream.pending_done_count += 1;
|
||||||
conn.stream.write_packet(
|
conn.stream.write_packet(
|
||||||
PacketType::SqlBatch,
|
PacketType::SqlBatch,
|
||||||
SqlBatch {
|
SqlBatch {
|
||||||
|
|||||||
Reference in New Issue
Block a user