WIP feat(postgres): consumer API for NoticeResponse

This commit is contained in:
Austin Bonander 2025-02-02 02:24:02 -08:00
parent 65229f7ff9
commit ba88b8fbb1
11 changed files with 262 additions and 66 deletions

1
Cargo.lock generated
View File

@ -3829,6 +3829,7 @@ dependencies = [
"crc",
"dotenvy",
"etcetera",
"futures",
"futures-channel",
"futures-core",
"futures-util",

View File

@ -53,6 +53,9 @@ pub enum Error {
#[error("encountered unexpected or invalid data: {0}")]
Protocol(String),
#[error("internal error")]
Internal(#[source] BoxDynError),
/// No rows returned by a query that expected to return at least one row.
#[error("no rows returned by a query that expected to return at least one row")]
RowNotFound,

View File

@ -42,16 +42,8 @@ macro_rules! private_tracing_dynamic_event {
pub fn private_level_filter_to_levels(
filter: log::LevelFilter,
) -> Option<(tracing::Level, log::Level)> {
let tracing_level = match filter {
log::LevelFilter::Error => Some(tracing::Level::ERROR),
log::LevelFilter::Warn => Some(tracing::Level::WARN),
log::LevelFilter::Info => Some(tracing::Level::INFO),
log::LevelFilter::Debug => Some(tracing::Level::DEBUG),
log::LevelFilter::Trace => Some(tracing::Level::TRACE),
log::LevelFilter::Off => None,
};
tracing_level.zip(filter.to_level())
filter.to_level()
.map(|level| (log_level_to_tracing_level(level), level))
}
pub(crate) fn private_level_filter_to_trace_level(
@ -60,6 +52,18 @@ pub(crate) fn private_level_filter_to_trace_level(
private_level_filter_to_levels(filter).map(|(level, _)| level)
}
pub fn log_level_to_tracing_level(
level: log::Level,
) -> tracing::Level {
match level {
log::Level::Error => tracing::Level::ERROR,
log::Level::Warn => tracing::Level::WARN,
log::Level::Info => tracing::Level::INFO,
log::Level::Debug => tracing::Level::DEBUG,
log::Level::Trace => tracing::Level::TRACE,
}
}
pub struct QueryLogger<'q> {
sql: &'q str,
rows_returned: u64,

View File

@ -76,6 +76,9 @@ workspace = true
# We use JSON in the driver implementation itself so there's no reason not to enable it here.
features = ["json"]
[dev-dependencies]
futures = "0.3"
[dev-dependencies.sqlx]
workspace = true
features = ["postgres", "derive"]

View File

@ -19,7 +19,7 @@ use crate::types::Oid;
use crate::{PgConnectOptions, PgTypeInfo, Postgres};
pub(crate) use sqlx_core::connection::*;
use crate::notice::PgNoticeSink;
pub use self::stream::PgStream;
pub(crate) mod describe;
@ -78,6 +78,16 @@ impl PgConnection {
self.inner.stream.server_version_num
}
/// Set a consumer for `NoticeResponse`s.
///
/// By default, notices are logged at an appropriate level
/// under the target `sqlx::postgres::notice`.
///
/// See [`PgNoticeSink::log()`] for details.
pub fn set_notice_sink(&mut self, sink: PgNoticeSink) {
self.inner.stream.notice_sink = sink;
}
// will return when the connection is ready for another query
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
if !self.inner.stream.write_buffer_mut().is_empty() {

View File

@ -10,11 +10,11 @@ use sqlx_core::bytes::Buf;
use crate::connection::tls::MaybeUpgradeTls;
use crate::error::Error;
use crate::message::{
BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification,
BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, PgNotice, Notification,
ParameterStatus, ReceivedMessage,
};
use crate::net::{self, BufferedSocket, Socket};
use crate::{PgConnectOptions, PgDatabaseError, PgSeverity};
use crate::{PgConnectOptions, PgDatabaseError, PgNoticeSink, PgSeverity};
// the stream is a separate type from the connection to uphold the invariant where an instantiated
// [PgConnection] is a **valid** connection to postgres
@ -38,6 +38,8 @@ pub struct PgStream {
pub(crate) parameter_statuses: BTreeMap<String, String>,
pub(crate) server_version_num: Option<u32>,
pub(crate) notice_sink: PgNoticeSink,
}
impl PgStream {
@ -54,6 +56,7 @@ impl PgStream {
notifications: None,
parameter_statuses: BTreeMap::default(),
server_version_num: None,
notice_sink: PgNoticeSink::log(),
})
}
@ -159,36 +162,8 @@ impl PgStream {
}
BackendMessageFormat::NoticeResponse => {
// do we need this to be more configurable?
// if you are reading this comment and think so, open an issue
let notice: Notice = message.decode()?;
let (log_level, tracing_level) = match notice.severity() {
PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => {
(Level::Error, tracing::Level::ERROR)
}
PgSeverity::Warning => (Level::Warn, tracing::Level::WARN),
PgSeverity::Notice => (Level::Info, tracing::Level::INFO),
PgSeverity::Debug => (Level::Debug, tracing::Level::DEBUG),
PgSeverity::Info | PgSeverity::Log => (Level::Trace, tracing::Level::TRACE),
};
let log_is_enabled = log::log_enabled!(
target: "sqlx::postgres::notice",
log_level
) || sqlx_core::private_tracing_dynamic_enabled!(
target: "sqlx::postgres::notice",
tracing_level
);
if log_is_enabled {
sqlx_core::private_tracing_dynamic_event!(
target: "sqlx::postgres::notice",
tracing_level,
message = notice.message()
);
}
let notice: PgNotice = message.decode()?;
self.notice_sink.consume(notice).await?;
continue;
}

View File

@ -6,10 +6,10 @@ use smallvec::alloc::borrow::Cow;
use sqlx_core::bytes::Bytes;
pub(crate) use sqlx_core::error::*;
use crate::message::{BackendMessage, BackendMessageFormat, Notice, PgSeverity};
use crate::message::{BackendMessage, BackendMessageFormat, PgNotice, PgSeverity};
/// An error returned from the PostgreSQL database.
pub struct PgDatabaseError(pub(crate) Notice);
pub struct PgDatabaseError(pub(crate) PgNotice);
// Error message fields are documented:
// https://www.postgresql.org/docs/current/protocol-error-fields.html
@ -225,7 +225,7 @@ impl BackendMessage for PgDatabaseError {
#[inline(always)]
fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
Ok(Self(Notice::decode_body(buf)?))
Ok(Self(PgNotice::decode_body(buf)?))
}
}

View File

@ -15,6 +15,9 @@ mod error;
mod io;
mod listener;
mod message;
mod notice;
mod options;
mod query_result;
mod row;
@ -53,7 +56,8 @@ pub use copy::{PgCopyIn, PgPoolCopyExt};
pub use database::Postgres;
pub use error::{PgDatabaseError, PgErrorPosition};
pub use listener::{PgListener, PgNotification};
pub use message::PgSeverity;
pub use message::{PgNotice, PgSeverity};
pub use notice::PgNoticeSink;
pub use options::{PgConnectOptions, PgSslMode};
pub use query_result::PgQueryResult;
pub use row::PgRow;

View File

@ -49,7 +49,7 @@ pub use parse_complete::ParseComplete;
pub use password::Password;
pub use query::Query;
pub use ready_for_query::{ReadyForQuery, TransactionStatus};
pub use response::{Notice, PgSeverity};
pub use response::{PgNotice, PgSeverity};
pub use row_description::RowDescription;
pub use sasl::{SaslInitialResponse, SaslResponse};
use sqlx_core::io::ProtocolEncode;

View File

@ -1,6 +1,6 @@
use std::fmt::{Debug, Display, Formatter};
use std::ops::Range;
use std::str::from_utf8;
use memchr::memchr;
use sqlx_core::bytes::Bytes;
@ -9,6 +9,11 @@ use crate::error::Error;
use crate::io::ProtocolDecode;
use crate::message::{BackendMessage, BackendMessageFormat};
/// Severity level for [`PgDatabaseError`] (`ErrorResponse`) and [`PgNotice`] (`NoticeResponse`).
///
///
/// [`PgDatabaseError`]: sqlx::postgres::PgDatabaseError
/// [`PgNotice`]:
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
pub enum PgSeverity {
@ -27,6 +32,44 @@ impl PgSeverity {
pub fn is_error(self) -> bool {
matches!(self, Self::Panic | Self::Fatal | Self::Error)
}
#[inline]
pub fn as_str(&self) -> &str {
match self {
PgSeverity::Panic => "PANIC",
PgSeverity::Fatal => "FATAL",
PgSeverity::Error => "ERROR",
PgSeverity::Warning => "WARNING",
PgSeverity::Notice => "NOTICE",
PgSeverity::Debug => "DEBUG",
PgSeverity::Info => "INFO",
PgSeverity::Log => "LOG",
}
}
pub(crate) fn to_tracing_level(&self) -> tracing::Level {
match self {
PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => {
tracing::Level::ERROR
}
PgSeverity::Warning => tracing::Level::WARN,
PgSeverity::Notice => tracing::Level::INFO,
PgSeverity::Debug => tracing::Level::DEBUG,
PgSeverity::Info | PgSeverity::Log => tracing::Level::TRACE,
}
}
pub(crate) fn to_log_level(&self) -> log::Level {
match self {
PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => {
log::Level::Error
}
PgSeverity::Warning => log::Level::Warn,
PgSeverity::Notice => log::Level::Info,
PgSeverity::Debug => log::Level::Debug,
PgSeverity::Info | PgSeverity::Log => log::Level::Trace,
}
}
}
impl TryFrom<&str> for PgSeverity {
@ -52,15 +95,25 @@ impl TryFrom<&str> for PgSeverity {
}
}
#[derive(Debug)]
pub struct Notice {
impl Display for PgSeverity {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.pad(self.as_str())
}
}
/// A decoded `NoticeResponse`.
///
/// May be obtained by creating a [`PgNoticeSink`][crate::PgNoticeSink] and calling
/// [`PgConnection::set_notice_sink()`][crate::PgConnection::set_notice_sink()].
pub struct PgNotice {
storage: Bytes,
severity: PgSeverity,
message: Range<usize>,
code: Range<usize>,
}
impl Notice {
impl PgNotice {
#[inline]
pub fn severity(&self) -> PgSeverity {
self.severity
@ -76,23 +129,28 @@ impl Notice {
self.get_cached_str(self.message.clone())
}
// Field descriptions available here:
// https://www.postgresql.org/docs/current/protocol-error-fields.html
/// Get a field from this notice by tag as a string.
///
/// Returns `None` if the field does not exist, or is not valid UTF-8.
///
/// Notice fields reference: <https://www.postgresql.org/docs/current/protocol-error-fields.html>
#[inline]
pub fn get(&self, ty: u8) -> Option<&str> {
self.get_raw(ty).and_then(|v| from_utf8(v).ok())
pub fn get(&self, tag: u8) -> Option<&str> {
self.get_raw(tag).and_then(|v| from_utf8(v).ok())
}
pub fn get_raw(&self, ty: u8) -> Option<&[u8]> {
/// Get a field from this notice by tag as raw bytes.
///
/// Returns `None` if the field does not exist.
///
/// Notice fields reference: <https://www.postgresql.org/docs/current/protocol-error-fields.html>
pub fn get_raw(&self, tag: u8) -> Option<&[u8]> {
self.fields()
.filter(|(field, _)| *field == ty)
.filter(|(field, _)| *field == tag)
.map(|(_, range)| &self.storage[range])
.next()
}
}
impl Notice {
#[inline]
fn fields(&self) -> Fields<'_> {
Fields {
@ -108,7 +166,18 @@ impl Notice {
}
}
impl ProtocolDecode<'_> for Notice {
impl Debug for PgNotice {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PgNotice")
.field("severity", &self.severity)
.field("code", &self.code())
.field("message", &self.message())
.field("fields", &self.fields())
.finish()
}
}
impl ProtocolDecode<'_> for PgNotice {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
// In order to support PostgreSQL 9.5 and older we need to parse the localized S field.
// Newer versions additionally come with the V field that is guaranteed to be in English.
@ -180,7 +249,7 @@ impl ProtocolDecode<'_> for Notice {
}
}
impl BackendMessage for Notice {
impl BackendMessage for PgNotice {
const FORMAT: BackendMessageFormat = BackendMessageFormat::NoticeResponse;
fn decode_body(buf: Bytes) -> Result<Self, Error> {
@ -190,6 +259,7 @@ impl BackendMessage for Notice {
}
/// An iterator over each field in the Error (or Notice) response.
#[derive(Clone)]
struct Fields<'a> {
storage: &'a [u8],
offset: usize,
@ -223,6 +293,21 @@ impl<'a> Iterator for Fields<'a> {
}
}
impl Debug for Fields<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let mut debug_map = f.debug_map();
for (tag, span) in self.clone() {
debug_map.entry(
&format_args!("'{}'", tag.escape_ascii()),
&format_args!("\"{}\"", &self.storage[span].escape_ascii())
);
}
debug_map.finish()
}
}
fn notice_protocol_err() -> Error {
// https://github.com/launchbadge/sqlx/issues/1144
Error::Protocol(
@ -238,7 +323,7 @@ fn notice_protocol_err() -> Error {
fn test_decode_error_response() {
const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0";
let m = Notice::decode(Bytes::from_static(DATA)).unwrap();
let m = PgNotice::decode(Bytes::from_static(DATA)).unwrap();
assert_eq!(
m.message(),
@ -254,7 +339,7 @@ fn test_decode_error_response() {
fn bench_error_response_get_message(b: &mut test::Bencher) {
const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0";
let res = Notice::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
let res = PgNotice::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
b.iter(|| {
let _ = test::black_box(&res).message();
@ -267,6 +352,6 @@ fn bench_decode_error_response(b: &mut test::Bencher) {
const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0";
b.iter(|| {
let _ = Notice::decode(test::black_box(Bytes::from_static(DATA)));
let _ = PgNotice::decode(test::black_box(Bytes::from_static(DATA)));
});
}

111
sqlx-postgres/src/notice.rs Normal file
View File

@ -0,0 +1,111 @@
use std::pin::Pin;
use futures_util::{Sink, SinkExt};
use sqlx_core::error::{BoxDynError, Error};
use sqlx_core::logger::log_level_to_tracing_level;
use crate::message::PgNotice;
/// Sink for Postgres `NoticeResponse`s.
pub struct PgNoticeSink {
inner: SinkInner
}
enum SinkInner {
Discard,
Log,
Closure(Box<dyn FnMut(PgNotice) -> Result<(), BoxDynError> + Send + Sync + 'static>),
Wrapped(Pin<Box<dyn Sink<PgNotice, Error = BoxDynError> + Send + Sync + 'static>>),
}
impl PgNoticeSink {
/// Discard all `NoticeResponse`s.
pub fn discard() -> Self {
PgNoticeSink {
inner: SinkInner::Discard,
}
}
/// Log `NoticeResponse`s according to severity level under the target `sqlx::postgres::notice`.
///
/// | Postgres Severity Level | `log`/`tracing` Level |
/// | ------------------------- | --------------------- |
/// | `PANIC`, `FATAL`, `ERROR` | `ERROR` |
/// | `WARNING` | `WARN` |
/// | `NOTICE` | `INFO` |
/// | `DEBUG` | `DEBUG` |
/// | `INFO`, `LOG` | `TRACE` |
///
/// This is the default behavior of new `PgConnection`s.
///
/// To instead consume `NoticeResponse`s directly as [`PgNotice`]s, see:
///
/// * [`PgNoticeSink::closure()`]
/// * [`PgNoticeSink::wrap()`]
/// * [`PgConnection::set_notice_sink()`][crate::PgConnection::set_notice_sink()]
pub fn log() -> Self {
PgNoticeSink {
inner: SinkInner::Log
}
}
/// Supply a closure to handle [`PgNotice`]s.
///
/// Errors will be bubbled up by the connection as [`Error::Internal`].
///
/// # Warning: Do Not Block!
///
/// The closure is invoked directly by the connection, so it should not block if it is unable
/// to immediately consume the message.
///
/// Instead, use [`Self::wrap()`] to provide a [`futures::Sink`] implementation.
pub fn closure(f: impl FnMut(PgNotice) -> Result<(), BoxDynError> + Send + Sync + 'static) -> Self {
PgNoticeSink {
inner: SinkInner::Closure(Box::new(f)),
}
}
/// Supply a [`futures::Sink`] to handle [`PgNotice`]s.
///
/// Errors will be bubbled up by the connection as [`Error::Internal`].
pub fn wrap(sink: impl Sink<PgNotice, Error = BoxDynError> + Send + Sync + 'static) -> Self {
PgNoticeSink {
inner: SinkInner::Wrapped(Box::pin(sink)),
}
}
pub(crate) async fn consume(&mut self, notice: PgNotice) -> Result<(), Error> {
match &mut self.inner {
SinkInner::Discard => Ok(()),
SinkInner::Log => {
log_notice(notice);
Ok(())
}
SinkInner::Closure(f) => f(notice).map_err(Error::Internal),
SinkInner::Wrapped(sink) => {
sink.as_mut().send(notice).await.map_err(Error::Internal)
}
}
}
}
fn log_notice(notice: PgNotice) {
let tracing_level = notice.severity().to_tracing_level();
let log_is_enabled = log::log_enabled!(
target: "sqlx::postgres::notice",
notice.severity().to_log_level()
) || sqlx_core::private_tracing_dynamic_enabled!(
target: "sqlx::postgres::notice",
tracing_level
);
if log_is_enabled {
sqlx_core::private_tracing_dynamic_event!(
target: "sqlx::postgres::notice",
tracing_level,
severity=%notice.severity(),
code=%notice.code(),
"{}",
notice.message()
);
}
}