From a0da99e128db6985c1791294956477af06ef4e29 Mon Sep 17 00:00:00 2001 From: Anthony Dodd Date: Fri, 31 Jan 2020 16:01:30 -0600 Subject: [PATCH] A good bit of refactoring. Broke up PgListener into two types. PgListener for basic one-off connections, and PgPoolListener for the listener created from the PgPool. The API is a bit more clear now with this change in terms of reconnect behavior and the like. Update `fn stream` to be `fn into_stream`, as that nomenclature is a bit more normative in the Rust ecosystem. --- sqlx-core/src/postgres/listen.rs | 288 ++++++++++++++----------------- 1 file changed, 133 insertions(+), 155 deletions(-) diff --git a/sqlx-core/src/postgres/listen.rs b/sqlx-core/src/postgres/listen.rs index 31945017..d53911db 100644 --- a/sqlx-core/src/postgres/listen.rs +++ b/sqlx-core/src/postgres/listen.rs @@ -1,5 +1,6 @@ use std::ops::DerefMut; +use async_stream::stream; use futures_core::future::BoxFuture; use futures_core::stream::Stream; @@ -13,218 +14,93 @@ use crate::Result; type PgPoolConnection = PoolConnection; /// Extension methods for Postgres connections. -pub trait PgConnectionExt { +pub trait PgConnectionExt { fn listen(self, channels: &[&str]) -> PgListener; } impl PgConnectionExt for PgConnection { /// Register this connection as a listener on the specified channels. fn listen(self, channels: &[&str]) -> PgListener { - PgListener::new(Some(self), channels, None) + PgListener::new(self, channels) } } impl PgConnectionExt for PgPoolConnection { /// Register this connection as a listener on the specified channels. fn listen(self, channels: &[&str]) -> PgListener { - PgListener::new(Some(self), channels, None) - } -} - -/// Extension methods for Postgres connection pools. -pub trait PgPoolExt { - fn listen(&self, channels: &[&str]) -> PgListener; -} - -impl PgPoolExt for PgPool { - /// Fetch a new connection from the pool and register it as a listener on the specified channel. - /// - /// If the underlying connection ever dies, a new connection will be acquired from the pool, - /// and listening will resume as normal. - fn listen(&self, channels: &[&str]) -> PgListener { - PgListener::new(None, channels, Some(self.clone())) + PgListener::new(self, channels) } } /// A stream of async database notifications. /// /// Notifications will always correspond to the channel(s) specified this object is created. +/// +/// This listener is bound to the lifetime of its underlying connection. If the connection ever +/// dies, this listener will terminate and will no longer yield any notifications. pub struct PgListener { needs_to_send_listen_cmd: bool, - connection: Option, + connection: C, channels: Vec, - pool: Option, } impl PgListener { /// Construct a new instance. - pub(self) fn new(connection: Option, channels: &[&str], pool: Option) -> Self { + pub(self) fn new(connection: C, channels: &[&str]) -> Self { let channels = channels.iter().map(|chan| String::from(*chan)).collect(); Self { needs_to_send_listen_cmd: true, connection, channels, - pool, } } } -impl PgListener { +impl PgListener +where + C: Connection, + C: DerefMut, +{ /// Receives the next notification available from any of the subscribed channels. - /// - /// When a `PgListener` is created from `PgPool.listen(..)`, the `PgListener` will perform - /// automatic reconnects to the database using the original `PgPool` and will submit a - /// `LISTEN` command to the database using the same originally specified channels. As such, - /// this routine will never return `None` when called on a `PgListener` created from a `PgPool`. - /// - /// However, if a `PgListener` instance is created outside of the context of a `PgPool`, then - /// this routine will return `None` when the underlying connection dies. At that point, any - /// further calls to this routine will also return `None`. - pub async fn recv(&mut self) -> Option> { + pub async fn recv(&mut self) -> Result> { loop { - // Ensure we have an active connection to work with. - let conn = match &mut self.connection { - Some(conn) => conn, - None => match self.get_new_connection().await { - // A new connection has been established, bind it and loop. - Ok(Some(conn)) => { - self.connection = Some(conn); - continue; - } - // No pool is present on this listener, return None. - Ok(None) => return None, - // We have a pool to work with, but some error has come up. Return the error. - // The next call to `recv` will build a new connection if available. - Err(err) => return Some(Err(err)), - }, - }; // Ensure the current connection has properly registered all listener channels. if self.needs_to_send_listen_cmd { - if let Err(err) = send_listen_query(conn, &self.channels).await { - // If we've encountered an error here, test the connection, drop it if needed, - // and return the error. The next call to recv will build a new connection if possible. - if let Err(_) = conn.ping().await { - self.close_conn().await; - } - return Some(Err(err)); - } - self.needs_to_send_listen_cmd = false; - } - // Await a notification from the DB. - match conn.receive().await { - // We've received an async notification, return it. - Ok(Some(Message::NotificationResponse(notification))) => { - return Some(Ok(notification.into())) - } - // Protocol error, return the error. - Ok(Some(msg)) => { - return Some(Err(protocol_err!( - "unexpected message received from database {:?}", - msg - ) - .into())) - } - // The connection is dead, ensure that it is dropped, update self state, and loop to try again. - Ok(None) => { - self.close_conn().await; - self.needs_to_send_listen_cmd = true; - continue; - } - // An error has come up, return it. - Err(err) => return Some(Err(err)), - } - } - } - - /// Consume this listener, returning a `Stream` of notifications. - pub fn stream(mut self) -> impl Stream> { - use async_stream::stream; - stream! { - loop { - match self.recv().await { - Some(res) => yield res, - None => break, - } - } - } - } - - /// Fetch a new connection from the connection pool, if a connection pool is available. - /// - /// Errors here are transient. `Ok(None)` indicates that no pool is available. - async fn get_new_connection(&mut self) -> Result> { - let pool = match &self.pool { - Some(pool) => pool, - None => return Ok(None), - }; - Ok(Some(pool.acquire().await?)) - } - - /// Close and drop the current connection. - async fn close_conn(&mut self) { - if let Some(conn) = self.connection.take() { - let _ = conn.close().await; - } - } -} - -impl PgListener { - /// Receives the next notification available from any of the subscribed channels. - /// - /// If the underlying connection ever dies, this routine will return `None`. Any further calls - /// to this routine will also return `None`. If automatic reconnect behavior is needed, use - /// `PgPool.listen(..)`, which will automatically establish a new connection from the pool and - /// resusbcribe to all channels. - pub async fn recv(&mut self) -> Option> { - loop { - // Ensure we have an active connection to work with. - let mut conn = match &mut self.connection { - Some(conn) => conn, - None => return None, // This will never practically be hit, but let's make Rust happy. - }; - // Ensure the current connection has properly registered all listener channels. - if self.needs_to_send_listen_cmd { - if let Err(err) = send_listen_query(&mut conn, &self.channels).await { + if let Err(err) = send_listen_query(&mut self.connection, &self.channels).await { // If we've encountered an error here, test the connection. If the connection // is good, we return the error. Else, we return `None` as the connection is dead. - if let Err(_) = conn.ping().await { - return None; + if let Err(_) = self.connection.ping().await { + return Ok(None); } - return Some(Err(err)); + return Err(err); } self.needs_to_send_listen_cmd = false; } // Await a notification from the DB. - match conn.receive().await { + match self.connection.receive().await? { // We've received an async notification, return it. - Ok(Some(Message::NotificationResponse(notification))) => { - return Some(Ok(notification.into())) + Some(Message::NotificationResponse(notification)) => { + return Ok(Some(notification.into())) } // Protocol error, return the error. - Ok(Some(msg)) => { - return Some(Err(protocol_err!( + Some(msg) => { + return Err(protocol_err!( "unexpected message received from database {:?}", msg ) - .into())) + .into()) } // The connection is dead, return None. - Ok(None) => return None, - // An error has come up, return it. - Err(err) => return Some(Err(err)), + None => return Ok(None), } } } /// Consume this listener, returning a `Stream` of notifications. - pub fn stream(mut self) -> impl Stream> { - use async_stream::stream; + pub fn into_stream(mut self) -> impl Stream>> { stream! { loop { - match self.recv().await { - Some(res) => yield res, - None => break, - } + yield self.recv().await } } } @@ -236,9 +112,111 @@ where { /// Close this listener stream and its underlying connection. pub async fn close(self) -> BoxFuture<'static, Result<()>> { - match self.connection { - Some(conn) => conn.close(), - None => Box::pin(futures_util::future::ok(())), + self.connection.close() + } +} + +/// Extension methods for Postgres connection pools. +pub trait PgPoolExt { + fn listen(&self, channels: &[&str]) -> PgPoolListener; +} + +impl PgPoolExt for PgPool { + /// Create a listener which supports automatic reconnects using the connection pool. + fn listen(&self, channels: &[&str]) -> PgPoolListener { + PgPoolListener::new(channels, self.clone()) + } +} + +/// A stream of async database notifications. +/// +/// Notifications will always correspond to the channel(s) specified this object is created. +/// +/// This listener, as it is built from a `PgPool`, supports auto-reconnect. If the active +/// connection being used ever dies, this listener will detect that event, acquire a new connection +/// from the pool, will re-subscribe to all of the originally specified channels, and will resume +/// operations as normal. +pub struct PgPoolListener { + needs_to_send_listen_cmd: bool, + connection: Option, + channels: Vec, + pool: PgPool, +} + +impl PgPoolListener { + /// Construct a new instance. + pub(self) fn new(channels: &[&str], pool: PgPool) -> Self { + let channels = channels.iter().map(|chan| String::from(*chan)).collect(); + Self { + needs_to_send_listen_cmd: true, + connection: None, + channels, + pool, + } + } +} + +impl PgPoolListener { + /// Receives the next notification available from any of the subscribed channels. + pub async fn recv(&mut self) -> Result> { + loop { + // Ensure we have an active connection to work with. + let conn = match &mut self.connection { + Some(conn) => conn, + None => { + let conn = self.pool.acquire().await?; + self.connection = Some(conn); + continue; + } + }; + // Ensure the current connection has properly registered all listener channels. + if self.needs_to_send_listen_cmd { + if let Err(err) = send_listen_query(conn, &self.channels).await { + // If we've encountered an error here, test the connection, drop it if needed, + // and return the error. The next call to recv will build a new connection if possible. + if let Err(_) = conn.ping().await { + self.close_conn().await; + } + return Err(err); + } + self.needs_to_send_listen_cmd = false; + } + // Await a notification from the DB. + match conn.receive().await? { + // We've received an async notification, return it. + Some(Message::NotificationResponse(notification)) => { + return Ok(Some(notification.into())); + } + // Protocol error, return the error. + Some(msg) => { + return Err(protocol_err!( + "unexpected message received from database {:?}", + msg + ) + .into()) + } + // The connection is dead, ensure that it is dropped, update self state, and loop to try again. + None => { + self.close_conn().await; + self.needs_to_send_listen_cmd = true; + continue; + } + } + } + } + + /// Consume this listener, returning a `Stream` of notifications. + pub fn into_stream(mut self) -> impl Stream>> { + stream! { + loop { + yield self.recv().await + } + } + } + /// Close and drop the current connection. + async fn close_conn(&mut self) { + if let Some(conn) = self.connection.take() { + let _ = conn.close().await; } } }