fix cancellation issues with PgListener, PgStream::recv() (#3467)

* fix(postgres): make `PgStream::recv_unchecked()` cancel-safe

* fix(postgres): make `PgListener` close the connection on-error

* fix: incorrect math in `BufferedSocket::read_buffered()`
This commit is contained in:
Austin Bonander
2024-08-27 10:54:31 -07:00
committed by GitHub
parent 20ba796b0d
commit e10789d9d7
4 changed files with 120 additions and 21 deletions

View File

@@ -1,11 +1,11 @@
use std::collections::BTreeMap;
use std::ops::{Deref, DerefMut};
use std::ops::{ControlFlow, Deref, DerefMut};
use std::str::FromStr;
use futures_channel::mpsc::UnboundedSender;
use futures_util::SinkExt;
use log::Level;
use sqlx_core::bytes::{Buf, Bytes};
use sqlx_core::bytes::Buf;
use crate::connection::tls::MaybeUpgradeTls;
use crate::error::Error;
@@ -77,16 +77,45 @@ impl PgStream {
}
pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
// all packets in postgres start with a 5-byte header
// this header contains the message type and the total length of the message
let mut header: Bytes = self.inner.read(5).await?;
// NOTE: to not break everything, this should be cancel-safe;
// DO NOT modify `buf` unless a full message has been read
self.inner
.try_read(|buf| {
// all packets in postgres start with a 5-byte header
// this header contains the message type and the total length of the message
let Some(mut header) = buf.get(..5) else {
return Ok(ControlFlow::Continue(5));
};
let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
let size = (header.get_u32() - 4) as usize;
let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
let contents = self.inner.read(size).await?;
let message_len = header.get_u32() as usize;
Ok(ReceivedMessage { format, contents })
let expected_len = message_len
.checked_add(1)
// this shouldn't really happen but is mostly a sanity check
.ok_or_else(|| {
err_protocol!("message_len + 1 overflows usize: {message_len}")
})?;
if buf.len() < expected_len {
return Ok(ControlFlow::Continue(expected_len));
}
// `buf` SHOULD NOT be modified ABOVE this line
// pop off the format code since it's not counted in `message_len`
buf.advance(1);
// consume the message, including the length prefix
let mut contents = buf.split_to(message_len).freeze();
// cut off the length prefix
contents.advance(4);
Ok(ControlFlow::Break(ReceivedMessage { format, contents }))
})
.await
}
// Get the next message from the server

View File

@@ -262,8 +262,11 @@ impl PgListener {
if (err.kind() == io::ErrorKind::ConnectionAborted
|| err.kind() == io::ErrorKind::UnexpectedEof) =>
{
self.buffer_tx = self.connection().await?.stream.notifications.take();
self.connection = None;
if let Some(mut conn) = self.connection.take() {
self.buffer_tx = conn.stream.notifications.take();
// Close the connection in a background task, so we can continue.
conn.close_on_drop();
}
// lost connection
return Ok(None);