From e99e80cf94b46ed0e53fa314e5f6243b82a64fc3 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Mon, 16 Mar 2020 18:35:37 -0700 Subject: [PATCH] listen: merge PgListener and PgPoolListener; allow PgListener to be used as an Executor; allow channels to be adjusted at run-time --- examples/listen-postgres/src/main.rs | 64 ++-- sqlx-core/src/postgres/listen.rs | 481 ++++++++++++++------------- sqlx-core/src/postgres/mod.rs | 2 +- 3 files changed, 280 insertions(+), 267 deletions(-) diff --git a/examples/listen-postgres/src/main.rs b/examples/listen-postgres/src/main.rs index e7f3bec5..1f902e15 100644 --- a/examples/listen-postgres/src/main.rs +++ b/examples/listen-postgres/src/main.rs @@ -1,9 +1,10 @@ -use std::time::Duration; - use async_std::stream; -use futures::stream::StreamExt; -use sqlx::postgres::PgPoolExt; -use sqlx::Executor; +use futures::StreamExt; +use futures::TryStreamExt; +use sqlx::postgres::PgListener; +use sqlx::{Executor, PgPool}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; #[async_std::main] async fn main() -> Result<(), Box> { @@ -12,47 +13,56 @@ async fn main() -> Result<(), Box> { std::env::var("DATABASE_URL").expect("Env var DATABASE_URL is required for this example."); let pool = sqlx::PgPool::new(&conn_str).await?; - let notify_pool = pool.clone(); + let mut listener = PgListener::new(&conn_str).await?; + + // let notify_pool = pool.clone(); let _t = async_std::task::spawn(async move { - stream::interval(Duration::from_secs(5)) - .for_each(move |_| notify(notify_pool.clone())) + stream::interval(Duration::from_secs(2)) + .for_each(|_| notify(&pool)) .await }); println!("Starting LISTEN loop."); - let mut listener = pool.listen(&["chan0", "chan1", "chan2"]); + + listener.listen_all(&["chan0", "chan1", "chan2"]).await?; + let mut counter = 0usize; loop { - let res = listener.recv().await; - println!("[from recv]: {:?}", res); + let notification = listener.recv().await?; + println!("[from recv]: {:?}", notification); + counter += 1; if counter >= 3 { break; } } - let stream = listener.into_stream(); - futures::pin_mut!(stream); - while let Some(res) = stream.next().await { - println!("[from stream]: {:?}", res); + // Prove that we are buffering messages by waiting for 6 seconds + listener.execute("SELECT pg_sleep(6)").await?; + + let mut stream = listener.into_stream(); + while let Some(notification) = stream.try_next().await? { + println!("[from stream]: {:?}", notification); } Ok(()) } -async fn notify(pool: sqlx::PgPool) { - let mut conn = match pool.acquire().await { - Ok(conn) => conn, - Err(err) => return println!("[from notify]: {:?}", err), - }; - let res = conn - .execute( +async fn notify(mut pool: &PgPool) { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + + let res = pool + .execute(&*format!( r#" - NOTIFY "chan0", '{"payload": 0}'; - NOTIFY "chan1", '{"payload": 1}'; - NOTIFY "chan2", '{"payload": 2}'; - "#, - ) +NOTIFY "chan0", '{{"payload": {}}}'; +NOTIFY "chan1", '{{"payload": {}}}'; +NOTIFY "chan2", '{{"payload": {}}}'; + "#, + COUNTER.fetch_add(1, Ordering::SeqCst), + COUNTER.fetch_add(1, Ordering::SeqCst), + COUNTER.fetch_add(1, Ordering::SeqCst) + )) .await; + println!("[from notify]: {:?}", res); } diff --git a/sqlx-core/src/postgres/listen.rs b/sqlx-core/src/postgres/listen.rs index d0080b16..95569c69 100644 --- a/sqlx-core/src/postgres/listen.rs +++ b/sqlx-core/src/postgres/listen.rs @@ -1,283 +1,286 @@ -use std::ops::DerefMut; +use std::collections::HashSet; +use std::fmt::{self, Debug}; +use std::io; -use async_stream::stream; +use async_stream::try_stream; +use futures_channel::mpsc; use futures_core::future::BoxFuture; use futures_core::stream::Stream; -use crate::connection::Connection; -use crate::executor::Executor; -use crate::pool::PoolConnection; +use crate::describe::Describe; +use crate::executor::{Execute, Executor, RefExecutor}; +use crate::pool::{Pool, PoolConnection}; use crate::postgres::protocol::{Message, NotificationResponse}; -use crate::postgres::{PgConnection, PgPool}; -use crate::Result; +use crate::postgres::{PgConnection, PgCursor, Postgres}; -type PgPoolConnection = PoolConnection; - -/// Extension methods for Postgres connections. -pub trait PgConnectionExt { - fn listen(self, channels: &[&str]) -> PgListener; -} - -impl PgConnectionExt for PgConnection { - /// Register this connection as a listener on the specified channels. - fn listen(self, channels: &[&str]) -> PgListener { - PgListener::new(self, channels) - } -} - -impl PgConnectionExt for PgPoolConnection { - /// Register this connection as a listener on the specified channels. - fn listen(self, channels: &[&str]) -> PgListener { - PgListener::new(self, channels) - } -} - -/// A stream of async database notifications. +/// A stream of asynchronous notifications from Postgres. /// -/// Notifications will always correspond to the channel(s) specified when this object was created. -/// -/// This listener is bound to the lifetime of its underlying connection. If the connection ever -/// dies, this listener will terminate and will no longer yield any notifications. -pub struct PgListener { - needs_to_send_listen_cmd: bool, - connection: C, - channels: Vec, -} - -impl PgListener { - /// Construct a new instance. - pub(self) fn new(connection: C, channels: &[&str]) -> Self { - let channels = channels.iter().map(|chan| String::from(*chan)).collect(); - Self { - needs_to_send_listen_cmd: true, - connection, - channels, - } - } -} - -impl PgListener -where - C: Connection, - C: DerefMut, -{ - /// Receives the next notification available from any of the subscribed channels. - pub async fn recv(&mut self) -> Result> { - loop { - // Ensure the current connection has properly registered all listener channels. - if self.needs_to_send_listen_cmd { - if let Err(err) = send_listen_query(&mut self.connection, &self.channels).await { - // If we've encountered an error here, test the connection. If the connection - // is good, we return the error. Else, we return `None` as the connection is dead. - if let Err(_) = self.connection.ping().await { - return Ok(None); - } - return Err(err); - } - - self.needs_to_send_listen_cmd = false; - } - - // Await a notification from the DB. - return match self.connection.stream.read().await? { - // We've received an async notification, return it. - Message::NotificationResponse => { - let notification = NotificationResponse::read(self.connection.stream.buffer())?; - - Ok(Some(notification.into())) - } - - // Protocol error, return the error. - message => Err(protocol_err!( - "unexpected message received from database {:?}", - message - ) - .into()), - }; - } - } - - /// Consume this listener, returning a `Stream` of notifications. - pub fn into_stream(mut self) -> impl Stream> { - stream! { - loop { - match self.recv().await { - Ok(Some(msg)) => yield Ok(msg), - Ok(None) => break, - Err(err) => yield Err(err), - } - } - } - } -} - -impl PgListener -where - C: Connection, -{ - /// Close this listener stream and its underlying connection. - pub async fn close(self) -> BoxFuture<'static, Result<()>> { - self.connection.close() - } -} - -/// Extension methods for Postgres connection pools. -pub trait PgPoolExt { - fn listen(&self, channels: &[&str]) -> PgPoolListener; -} - -impl PgPoolExt for PgPool { - /// Create a listener which supports automatic reconnects using the connection pool. - fn listen(&self, channels: &[&str]) -> PgPoolListener { - PgPoolListener::new(channels, self.clone()) - } -} - -/// A stream of async database notifications. -/// -/// Notifications will always correspond to the channel(s) specified when this object was created. -/// -/// This listener, as it is built from a `PgPool`, supports auto-reconnect. If the active -/// connection being used ever dies, this listener will detect that event, acquire a new connection -/// from the pool, will re-subscribe to all of the originally specified channels, and will resume +/// This listener will auto-reconnect. If the active +/// connection being used ever dies, this listener will detect that event, create a +/// new connection, will re-subscribe to all of the originally specified channels, and will resume /// operations as normal. -pub struct PgPoolListener { - needs_to_send_listen_cmd: bool, - connection: Option, - channels: Vec, - pool: PgPool, +pub struct PgListener { + pool: Pool, + connection: Option>, + buffer_rx: mpsc::UnboundedReceiver>, + buffer_tx: Option>>, + channels: HashSet, } -impl PgPoolListener { - /// Construct a new instance. - pub(self) fn new(channels: &[&str], pool: PgPool) -> Self { - let channels = channels.iter().map(|chan| String::from(*chan)).collect(); - Self { - needs_to_send_listen_cmd: true, - connection: None, - channels, - pool, - } +/// An asynchronous notification from Postgres. +pub struct PgNotification<'c>(NotificationResponse<'c>); + +impl PgListener { + pub async fn new(url: &str) -> crate::Result { + // Create a pool of 1 without timeouts (as they don't apply here) + // We only use the pool to handle re-connections + let pool = Pool::::builder() + .max_size(1) + .max_lifetime(None) + .idle_timeout(None) + .build(url) + .await?; + + Self::from_pool(&pool).await + } + + pub async fn from_pool(pool: &Pool) -> crate::Result { + // Pull out an initial connection + let mut connection = pool.acquire().await?; + + // Setup a notification buffer + let (sender, receiver) = mpsc::unbounded(); + connection.stream.notifications = Some(sender); + + Ok(Self { + pool: pool.clone(), + connection: Some(connection), + buffer_rx: receiver, + buffer_tx: None, + channels: HashSet::new(), + }) + } + + /// Starts listening for notifications on a channel. + pub async fn listen(&mut self, channel: &str) -> crate::Result<()> { + self.connection() + .execute(&*format!("LISTEN {}", ident(channel))) + .await?; + + self.channels.insert(channel.to_owned()); + + Ok(()) + } + + /// Starts listening for notifications on all channels. + pub async fn listen_all(&mut self, channels: &[impl AsRef]) -> crate::Result<()> { + self.connection() + .execute(&*build_listen_all_query(channels)) + .await?; + + self.channels + .extend(channels.iter().map(|s| s.as_ref().to_string())); + + Ok(()) + } + + /// Stops listening for notifications on a channel. + pub async fn unlisten(&mut self, channel: &str) -> crate::Result<()> { + if self.channels.contains(channel) { + self.connection() + .execute(&*format!("UNLISTEN {}", ident(channel))) + .await?; + + self.channels.remove(channel); + } + + Ok(()) + } + + /// Stops listening for notifications on all channels. + pub async fn unlisten_all(&mut self) -> crate::Result<()> { + self.connection().execute("UNLISTEN *").await?; + + self.channels.clear(); + + Ok(()) + } + + #[inline] + async fn connect_if_needed(&mut self) -> crate::Result<()> { + if let None = self.connection { + let mut connection = self.pool.acquire().await?; + connection.stream.notifications = self.buffer_tx.take(); + + self.connection = Some(connection); + + let channels: Vec = self.channels.iter().cloned().collect(); + self.listen_all(&*channels).await?; + } + + Ok(()) + } + + #[inline] + fn connection(&mut self) -> &mut PgConnection { + self.connection.as_mut().unwrap() } -} -impl PgPoolListener { /// Receives the next notification available from any of the subscribed channels. - pub async fn recv(&mut self) -> Result { + pub async fn recv(&mut self) -> crate::Result> { + // Flush the buffer first, if anything + // This would only fill up if this listener is used as a connection + if let Ok(Some(notification)) = self.buffer_rx.try_next() { + return Ok(PgNotification(notification)); + } + loop { // Ensure we have an active connection to work with. - let conn = match &mut self.connection { - Some(conn) => conn, - None => { - let conn = self.pool.acquire().await?; - self.connection = Some(conn); - continue; - } - }; - // Ensure the current connection has properly registered all listener channels. - if self.needs_to_send_listen_cmd { - if let Err(err) = send_listen_query(conn, &self.channels).await { - // If we've encountered an error here, test the connection, drop it if needed, - // and return the error. The next call to recv will build a new connection if possible. - if let Err(_) = conn.ping().await { - self.close_conn().await; - } - return Err(err); - } - self.needs_to_send_listen_cmd = false; - } - // Await a notification from the DB. - // TODO: Handle connection dead here - match conn.stream.read().await? { + self.connect_if_needed().await?; + + match self.connection().stream.read().await { // We've received an async notification, return it. - Message::NotificationResponse => { - let notification = NotificationResponse::read(conn.stream.buffer())?; + Ok(Message::NotificationResponse) => { + let notification = + NotificationResponse::read(self.connection().stream.buffer())?; - return Ok(notification.into()); + return Ok(PgNotification(notification)); } - // Protocol error, return the error. - msg => { - return Err(protocol_err!( - "unexpected message received from database {:?}", - msg - ) - .into()) - } // The connection is dead, ensure that it is dropped, update self state, and loop to try again. - // None => { - // self.close_conn().await; - // self.needs_to_send_listen_cmd = true; - // continue; - // } + // Mark the connection as ready for another query + Ok(Message::ReadyForQuery) => { + self.connection().is_ready = true; + } + + // Ignore unexpected messages + Ok(_) => {} + + // The connection is dead, ensure that it is dropped, + // update self state, and loop to try again. + Err(crate::Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { + self.buffer_tx = self.connection().stream.notifications.take(); + self.connection = None; + } + + // Forward other errors + Err(error) => { + return Err(error); + } } } } /// Consume this listener, returning a `Stream` of notifications. - pub fn into_stream(mut self) -> impl Stream> { - stream! { + pub fn into_stream( + mut self, + ) -> impl Stream>> + Unpin { + Box::pin(try_stream! { loop { - yield self.recv().await + let notification = self.recv().await?; + yield notification.into_owned(); } - } - } - - /// Close and drop the current connection. - async fn close_conn(&mut self) { - if let Some(conn) = self.connection.take() { - let _ = conn.close().await; - } - } - - /// Close this pool listener's current connection & drop the connection. - pub async fn close(mut self) { - self.close_conn().await + }) } } -/// An asynchronous message sent from the database. -#[derive(Debug)] -#[non_exhaustive] -pub struct PgNotification { - /// The PID of the database process which sent this notification. - pub pid: u32, - /// The channel of the notification, which can be thought of as a topic. - pub channel: String, - /// The payload of the notification. - pub payload: String, -} +impl Executor for PgListener { + type Database = Postgres; -impl From for PgNotification { - fn from(src: NotificationResponse) -> Self { - Self { - pid: src.pid, - channel: src.channel_name, - payload: src.message, - } + fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>( + &'c mut self, + query: E, + ) -> BoxFuture<'e, crate::Result> + where + E: Execute<'q, Self::Database>, + { + self.connection().execute(query) + } + + fn fetch<'q, E>(&mut self, query: E) -> PgCursor<'_, 'q> + where + E: Execute<'q, Self::Database>, + { + self.connection().fetch(query) + } + + fn describe<'e, 'q, E: 'e>( + &'e mut self, + query: E, + ) -> BoxFuture<'e, crate::Result>> + where + E: Execute<'q, Self::Database>, + { + self.connection().describe(query) } } -/// Build a query which issues a LISTEN command for each given channel. +impl<'c> RefExecutor<'c> for &'c mut PgListener { + type Database = Postgres; + + fn fetch_by_ref<'q, E>(self, query: E) -> PgCursor<'c, 'q> + where + E: Execute<'q, Self::Database>, + { + self.connection().fetch_by_ref(query) + } +} + +impl PgNotification<'_> { + /// The process ID of the notifying backend process. + #[inline] + pub fn process_id(&self) -> u32 { + self.0.process_id + } + + /// The channel that the notify has been raised on. This can be thought + /// of as the message topic. + #[inline] + pub fn channel(&self) -> &str { + self.0.channel.as_ref() + } + + /// The payload of the notification. An empty payload is received as an + /// empty string. + #[inline] + pub fn payload(&self) -> &str { + self.0.payload.as_ref() + } + + fn into_owned(self) -> PgNotification<'static> { + PgNotification(self.0.into_owned()) + } +} + +impl Debug for PgNotification<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PgNotification") + .field("process_id", &self.process_id()) + .field("channel", &self.channel()) + .field("payload", &self.payload()) + .finish() + } +} + +fn ident(mut name: &str) -> String { + // If the input string contains a NUL byte, we should truncate the + // identifier. + if let Some(index) = name.find('\0') { + name = &name[..index]; + } + + // Any double quotes must be escaped + name.replace('"', "\"\"") +} + fn build_listen_all_query(channels: impl IntoIterator>) -> String { channels.into_iter().fold(String::new(), |mut acc, chan| { acc.push_str(r#"LISTEN ""#); - acc.push_str(chan.as_ref()); + acc.push_str(&ident(chan.as_ref())); acc.push_str(r#"";"#); acc }) } -/// Send the structure listen query to the database. -async fn send_listen_query>( - conn: &mut C, - channels: impl IntoIterator>, -) -> Result<()> { - let cmd = build_listen_all_query(channels); - let _ = conn.execute(cmd.as_str()).await?; - - Ok(()) -} - #[cfg(test)] mod tests { use super::*; diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index b2392e5f..f1dd5997 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -5,7 +5,7 @@ pub use connection::PgConnection; pub use cursor::PgCursor; pub use database::Postgres; pub use error::PgError; -pub use listen::{PgConnectionExt, PgListener, PgNotification, PgPoolExt, PgPoolListener}; +pub use listen::{PgListener, PgNotification}; pub use row::{PgRow, PgValue}; pub use types::PgTypeInfo;