mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-03-19 16:44:07 +00:00
fix cancellation issues with PgListener, PgStream::recv() (#3467)
* fix(postgres): make `PgStream::recv_unchecked()` cancel-safe * fix(postgres): make `PgListener` close the connection on-error * fix: incorrect math in `BufferedSocket::read_buffered()`
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::ops::{ControlFlow, Deref, DerefMut};
|
||||
use std::str::FromStr;
|
||||
|
||||
use futures_channel::mpsc::UnboundedSender;
|
||||
use futures_util::SinkExt;
|
||||
use log::Level;
|
||||
use sqlx_core::bytes::{Buf, Bytes};
|
||||
use sqlx_core::bytes::Buf;
|
||||
|
||||
use crate::connection::tls::MaybeUpgradeTls;
|
||||
use crate::error::Error;
|
||||
@@ -77,16 +77,45 @@ impl PgStream {
|
||||
}
|
||||
|
||||
pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
|
||||
// all packets in postgres start with a 5-byte header
|
||||
// this header contains the message type and the total length of the message
|
||||
let mut header: Bytes = self.inner.read(5).await?;
|
||||
// NOTE: to not break everything, this should be cancel-safe;
|
||||
// DO NOT modify `buf` unless a full message has been read
|
||||
self.inner
|
||||
.try_read(|buf| {
|
||||
// all packets in postgres start with a 5-byte header
|
||||
// this header contains the message type and the total length of the message
|
||||
let Some(mut header) = buf.get(..5) else {
|
||||
return Ok(ControlFlow::Continue(5));
|
||||
};
|
||||
|
||||
let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
|
||||
let size = (header.get_u32() - 4) as usize;
|
||||
let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
|
||||
|
||||
let contents = self.inner.read(size).await?;
|
||||
let message_len = header.get_u32() as usize;
|
||||
|
||||
Ok(ReceivedMessage { format, contents })
|
||||
let expected_len = message_len
|
||||
.checked_add(1)
|
||||
// this shouldn't really happen but is mostly a sanity check
|
||||
.ok_or_else(|| {
|
||||
err_protocol!("message_len + 1 overflows usize: {message_len}")
|
||||
})?;
|
||||
|
||||
if buf.len() < expected_len {
|
||||
return Ok(ControlFlow::Continue(expected_len));
|
||||
}
|
||||
|
||||
// `buf` SHOULD NOT be modified ABOVE this line
|
||||
|
||||
// pop off the format code since it's not counted in `message_len`
|
||||
buf.advance(1);
|
||||
|
||||
// consume the message, including the length prefix
|
||||
let mut contents = buf.split_to(message_len).freeze();
|
||||
|
||||
// cut off the length prefix
|
||||
contents.advance(4);
|
||||
|
||||
Ok(ControlFlow::Break(ReceivedMessage { format, contents }))
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
// Get the next message from the server
|
||||
|
||||
@@ -262,8 +262,11 @@ impl PgListener {
|
||||
if (err.kind() == io::ErrorKind::ConnectionAborted
|
||||
|| err.kind() == io::ErrorKind::UnexpectedEof) =>
|
||||
{
|
||||
self.buffer_tx = self.connection().await?.stream.notifications.take();
|
||||
self.connection = None;
|
||||
if let Some(mut conn) = self.connection.take() {
|
||||
self.buffer_tx = conn.stream.notifications.take();
|
||||
// Close the connection in a background task, so we can continue.
|
||||
conn.close_on_drop();
|
||||
}
|
||||
|
||||
// lost connection
|
||||
return Ok(None);
|
||||
|
||||
Reference in New Issue
Block a user