listen: merge PgListener and PgPoolListener; allow PgListener to be used as an Executor; allow channels to be adjusted at run-time

This commit is contained in:
Ryan Leckey 2020-03-16 18:35:37 -07:00
parent ed9d6c3b62
commit e99e80cf94
3 changed files with 280 additions and 267 deletions

View File

@ -1,9 +1,10 @@
use std::time::Duration;
use async_std::stream;
use futures::stream::StreamExt;
use sqlx::postgres::PgPoolExt;
use sqlx::Executor;
use futures::StreamExt;
use futures::TryStreamExt;
use sqlx::postgres::PgListener;
use sqlx::{Executor, PgPool};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
#[async_std::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -12,47 +13,56 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
std::env::var("DATABASE_URL").expect("Env var DATABASE_URL is required for this example.");
let pool = sqlx::PgPool::new(&conn_str).await?;
let notify_pool = pool.clone();
let mut listener = PgListener::new(&conn_str).await?;
// let notify_pool = pool.clone();
let _t = async_std::task::spawn(async move {
stream::interval(Duration::from_secs(5))
.for_each(move |_| notify(notify_pool.clone()))
stream::interval(Duration::from_secs(2))
.for_each(|_| notify(&pool))
.await
});
println!("Starting LISTEN loop.");
let mut listener = pool.listen(&["chan0", "chan1", "chan2"]);
listener.listen_all(&["chan0", "chan1", "chan2"]).await?;
let mut counter = 0usize;
loop {
let res = listener.recv().await;
println!("[from recv]: {:?}", res);
let notification = listener.recv().await?;
println!("[from recv]: {:?}", notification);
counter += 1;
if counter >= 3 {
break;
}
}
let stream = listener.into_stream();
futures::pin_mut!(stream);
while let Some(res) = stream.next().await {
println!("[from stream]: {:?}", res);
// Prove that we are buffering messages by waiting for 6 seconds
listener.execute("SELECT pg_sleep(6)").await?;
let mut stream = listener.into_stream();
while let Some(notification) = stream.try_next().await? {
println!("[from stream]: {:?}", notification);
}
Ok(())
}
async fn notify(pool: sqlx::PgPool) {
let mut conn = match pool.acquire().await {
Ok(conn) => conn,
Err(err) => return println!("[from notify]: {:?}", err),
};
let res = conn
.execute(
async fn notify(mut pool: &PgPool) {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let res = pool
.execute(&*format!(
r#"
NOTIFY "chan0", '{"payload": 0}';
NOTIFY "chan1", '{"payload": 1}';
NOTIFY "chan2", '{"payload": 2}';
"#,
)
NOTIFY "chan0", '{{"payload": {}}}';
NOTIFY "chan1", '{{"payload": {}}}';
NOTIFY "chan2", '{{"payload": {}}}';
"#,
COUNTER.fetch_add(1, Ordering::SeqCst),
COUNTER.fetch_add(1, Ordering::SeqCst),
COUNTER.fetch_add(1, Ordering::SeqCst)
))
.await;
println!("[from notify]: {:?}", res);
}

View File

@ -1,283 +1,286 @@
use std::ops::DerefMut;
use std::collections::HashSet;
use std::fmt::{self, Debug};
use std::io;
use async_stream::stream;
use async_stream::try_stream;
use futures_channel::mpsc;
use futures_core::future::BoxFuture;
use futures_core::stream::Stream;
use crate::connection::Connection;
use crate::executor::Executor;
use crate::pool::PoolConnection;
use crate::describe::Describe;
use crate::executor::{Execute, Executor, RefExecutor};
use crate::pool::{Pool, PoolConnection};
use crate::postgres::protocol::{Message, NotificationResponse};
use crate::postgres::{PgConnection, PgPool};
use crate::Result;
use crate::postgres::{PgConnection, PgCursor, Postgres};
type PgPoolConnection = PoolConnection<PgConnection>;
/// Extension methods for Postgres connections.
pub trait PgConnectionExt<C: Connection + Unpin> {
fn listen(self, channels: &[&str]) -> PgListener<C>;
}
impl PgConnectionExt<PgConnection> for PgConnection {
/// Register this connection as a listener on the specified channels.
fn listen(self, channels: &[&str]) -> PgListener<Self> {
PgListener::new(self, channels)
}
}
impl PgConnectionExt<PgPoolConnection> for PgPoolConnection {
/// Register this connection as a listener on the specified channels.
fn listen(self, channels: &[&str]) -> PgListener<Self> {
PgListener::new(self, channels)
}
}
/// A stream of async database notifications.
/// A stream of asynchronous notifications from Postgres.
///
/// Notifications will always correspond to the channel(s) specified when this object was created.
///
/// This listener is bound to the lifetime of its underlying connection. If the connection ever
/// dies, this listener will terminate and will no longer yield any notifications.
pub struct PgListener<C> {
needs_to_send_listen_cmd: bool,
connection: C,
channels: Vec<String>,
}
impl<C> PgListener<C> {
/// Construct a new instance.
pub(self) fn new(connection: C, channels: &[&str]) -> Self {
let channels = channels.iter().map(|chan| String::from(*chan)).collect();
Self {
needs_to_send_listen_cmd: true,
connection,
channels,
}
}
}
impl<C> PgListener<C>
where
C: Connection,
C: DerefMut<Target = PgConnection>,
{
/// Receives the next notification available from any of the subscribed channels.
pub async fn recv(&mut self) -> Result<Option<PgNotification>> {
loop {
// Ensure the current connection has properly registered all listener channels.
if self.needs_to_send_listen_cmd {
if let Err(err) = send_listen_query(&mut self.connection, &self.channels).await {
// If we've encountered an error here, test the connection. If the connection
// is good, we return the error. Else, we return `None` as the connection is dead.
if let Err(_) = self.connection.ping().await {
return Ok(None);
}
return Err(err);
}
self.needs_to_send_listen_cmd = false;
}
// Await a notification from the DB.
return match self.connection.stream.read().await? {
// We've received an async notification, return it.
Message::NotificationResponse => {
let notification = NotificationResponse::read(self.connection.stream.buffer())?;
Ok(Some(notification.into()))
}
// Protocol error, return the error.
message => Err(protocol_err!(
"unexpected message received from database {:?}",
message
)
.into()),
};
}
}
/// Consume this listener, returning a `Stream` of notifications.
pub fn into_stream(mut self) -> impl Stream<Item = Result<PgNotification>> {
stream! {
loop {
match self.recv().await {
Ok(Some(msg)) => yield Ok(msg),
Ok(None) => break,
Err(err) => yield Err(err),
}
}
}
}
}
impl<C> PgListener<C>
where
C: Connection,
{
/// Close this listener stream and its underlying connection.
pub async fn close(self) -> BoxFuture<'static, Result<()>> {
self.connection.close()
}
}
/// Extension methods for Postgres connection pools.
pub trait PgPoolExt {
fn listen(&self, channels: &[&str]) -> PgPoolListener;
}
impl PgPoolExt for PgPool {
/// Create a listener which supports automatic reconnects using the connection pool.
fn listen(&self, channels: &[&str]) -> PgPoolListener {
PgPoolListener::new(channels, self.clone())
}
}
/// A stream of async database notifications.
///
/// Notifications will always correspond to the channel(s) specified when this object was created.
///
/// This listener, as it is built from a `PgPool`, supports auto-reconnect. If the active
/// connection being used ever dies, this listener will detect that event, acquire a new connection
/// from the pool, will re-subscribe to all of the originally specified channels, and will resume
/// This listener will auto-reconnect. If the active
/// connection being used ever dies, this listener will detect that event, create a
/// new connection, will re-subscribe to all of the originally specified channels, and will resume
/// operations as normal.
pub struct PgPoolListener {
needs_to_send_listen_cmd: bool,
connection: Option<PgPoolConnection>,
channels: Vec<String>,
pool: PgPool,
pub struct PgListener {
pool: Pool<PgConnection>,
connection: Option<PoolConnection<PgConnection>>,
buffer_rx: mpsc::UnboundedReceiver<NotificationResponse<'static>>,
buffer_tx: Option<mpsc::UnboundedSender<NotificationResponse<'static>>>,
channels: HashSet<String>,
}
impl PgPoolListener {
/// Construct a new instance.
pub(self) fn new(channels: &[&str], pool: PgPool) -> Self {
let channels = channels.iter().map(|chan| String::from(*chan)).collect();
Self {
needs_to_send_listen_cmd: true,
connection: None,
channels,
pool,
}
/// An asynchronous notification from Postgres.
pub struct PgNotification<'c>(NotificationResponse<'c>);
impl PgListener {
pub async fn new(url: &str) -> crate::Result<Self> {
// Create a pool of 1 without timeouts (as they don't apply here)
// We only use the pool to handle re-connections
let pool = Pool::<PgConnection>::builder()
.max_size(1)
.max_lifetime(None)
.idle_timeout(None)
.build(url)
.await?;
Self::from_pool(&pool).await
}
pub async fn from_pool(pool: &Pool<PgConnection>) -> crate::Result<Self> {
// Pull out an initial connection
let mut connection = pool.acquire().await?;
// Setup a notification buffer
let (sender, receiver) = mpsc::unbounded();
connection.stream.notifications = Some(sender);
Ok(Self {
pool: pool.clone(),
connection: Some(connection),
buffer_rx: receiver,
buffer_tx: None,
channels: HashSet::new(),
})
}
/// Starts listening for notifications on a channel.
pub async fn listen(&mut self, channel: &str) -> crate::Result<()> {
self.connection()
.execute(&*format!("LISTEN {}", ident(channel)))
.await?;
self.channels.insert(channel.to_owned());
Ok(())
}
/// Starts listening for notifications on all channels.
pub async fn listen_all(&mut self, channels: &[impl AsRef<str>]) -> crate::Result<()> {
self.connection()
.execute(&*build_listen_all_query(channels))
.await?;
self.channels
.extend(channels.iter().map(|s| s.as_ref().to_string()));
Ok(())
}
/// Stops listening for notifications on a channel.
pub async fn unlisten(&mut self, channel: &str) -> crate::Result<()> {
if self.channels.contains(channel) {
self.connection()
.execute(&*format!("UNLISTEN {}", ident(channel)))
.await?;
self.channels.remove(channel);
}
Ok(())
}
/// Stops listening for notifications on all channels.
pub async fn unlisten_all(&mut self) -> crate::Result<()> {
self.connection().execute("UNLISTEN *").await?;
self.channels.clear();
Ok(())
}
#[inline]
async fn connect_if_needed(&mut self) -> crate::Result<()> {
if let None = self.connection {
let mut connection = self.pool.acquire().await?;
connection.stream.notifications = self.buffer_tx.take();
self.connection = Some(connection);
let channels: Vec<String> = self.channels.iter().cloned().collect();
self.listen_all(&*channels).await?;
}
Ok(())
}
#[inline]
fn connection(&mut self) -> &mut PgConnection {
self.connection.as_mut().unwrap()
}
}
impl PgPoolListener {
/// Receives the next notification available from any of the subscribed channels.
pub async fn recv(&mut self) -> Result<PgNotification> {
pub async fn recv(&mut self) -> crate::Result<PgNotification<'_>> {
// Flush the buffer first, if anything
// This would only fill up if this listener is used as a connection
if let Ok(Some(notification)) = self.buffer_rx.try_next() {
return Ok(PgNotification(notification));
}
loop {
// Ensure we have an active connection to work with.
let conn = match &mut self.connection {
Some(conn) => conn,
None => {
let conn = self.pool.acquire().await?;
self.connection = Some(conn);
continue;
}
};
// Ensure the current connection has properly registered all listener channels.
if self.needs_to_send_listen_cmd {
if let Err(err) = send_listen_query(conn, &self.channels).await {
// If we've encountered an error here, test the connection, drop it if needed,
// and return the error. The next call to recv will build a new connection if possible.
if let Err(_) = conn.ping().await {
self.close_conn().await;
}
return Err(err);
}
self.needs_to_send_listen_cmd = false;
}
// Await a notification from the DB.
// TODO: Handle connection dead here
match conn.stream.read().await? {
self.connect_if_needed().await?;
match self.connection().stream.read().await {
// We've received an async notification, return it.
Message::NotificationResponse => {
let notification = NotificationResponse::read(conn.stream.buffer())?;
Ok(Message::NotificationResponse) => {
let notification =
NotificationResponse::read(self.connection().stream.buffer())?;
return Ok(notification.into());
return Ok(PgNotification(notification));
}
// Protocol error, return the error.
msg => {
return Err(protocol_err!(
"unexpected message received from database {:?}",
msg
)
.into())
} // The connection is dead, ensure that it is dropped, update self state, and loop to try again.
// None => {
// self.close_conn().await;
// self.needs_to_send_listen_cmd = true;
// continue;
// }
// Mark the connection as ready for another query
Ok(Message::ReadyForQuery) => {
self.connection().is_ready = true;
}
// Ignore unexpected messages
Ok(_) => {}
// The connection is dead, ensure that it is dropped,
// update self state, and loop to try again.
Err(crate::Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
self.buffer_tx = self.connection().stream.notifications.take();
self.connection = None;
}
// Forward other errors
Err(error) => {
return Err(error);
}
}
}
}
/// Consume this listener, returning a `Stream` of notifications.
pub fn into_stream(mut self) -> impl Stream<Item = Result<PgNotification>> {
stream! {
pub fn into_stream(
mut self,
) -> impl Stream<Item = crate::Result<PgNotification<'static>>> + Unpin {
Box::pin(try_stream! {
loop {
yield self.recv().await
let notification = self.recv().await?;
yield notification.into_owned();
}
}
}
/// Close and drop the current connection.
async fn close_conn(&mut self) {
if let Some(conn) = self.connection.take() {
let _ = conn.close().await;
}
}
/// Close this pool listener's current connection & drop the connection.
pub async fn close(mut self) {
self.close_conn().await
})
}
}
/// An asynchronous message sent from the database.
#[derive(Debug)]
#[non_exhaustive]
pub struct PgNotification {
/// The PID of the database process which sent this notification.
pub pid: u32,
/// The channel of the notification, which can be thought of as a topic.
pub channel: String,
/// The payload of the notification.
pub payload: String,
}
impl Executor for PgListener {
type Database = Postgres;
impl From<NotificationResponse> for PgNotification {
fn from(src: NotificationResponse) -> Self {
Self {
pid: src.pid,
channel: src.channel_name,
payload: src.message,
}
fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>(
&'c mut self,
query: E,
) -> BoxFuture<'e, crate::Result<u64>>
where
E: Execute<'q, Self::Database>,
{
self.connection().execute(query)
}
fn fetch<'q, E>(&mut self, query: E) -> PgCursor<'_, 'q>
where
E: Execute<'q, Self::Database>,
{
self.connection().fetch(query)
}
fn describe<'e, 'q, E: 'e>(
&'e mut self,
query: E,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>>
where
E: Execute<'q, Self::Database>,
{
self.connection().describe(query)
}
}
/// Build a query which issues a LISTEN command for each given channel.
impl<'c> RefExecutor<'c> for &'c mut PgListener {
type Database = Postgres;
fn fetch_by_ref<'q, E>(self, query: E) -> PgCursor<'c, 'q>
where
E: Execute<'q, Self::Database>,
{
self.connection().fetch_by_ref(query)
}
}
impl PgNotification<'_> {
/// The process ID of the notifying backend process.
#[inline]
pub fn process_id(&self) -> u32 {
self.0.process_id
}
/// The channel that the notify has been raised on. This can be thought
/// of as the message topic.
#[inline]
pub fn channel(&self) -> &str {
self.0.channel.as_ref()
}
/// The payload of the notification. An empty payload is received as an
/// empty string.
#[inline]
pub fn payload(&self) -> &str {
self.0.payload.as_ref()
}
fn into_owned(self) -> PgNotification<'static> {
PgNotification(self.0.into_owned())
}
}
impl Debug for PgNotification<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PgNotification")
.field("process_id", &self.process_id())
.field("channel", &self.channel())
.field("payload", &self.payload())
.finish()
}
}
fn ident(mut name: &str) -> String {
// If the input string contains a NUL byte, we should truncate the
// identifier.
if let Some(index) = name.find('\0') {
name = &name[..index];
}
// Any double quotes must be escaped
name.replace('"', "\"\"")
}
fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
channels.into_iter().fold(String::new(), |mut acc, chan| {
acc.push_str(r#"LISTEN ""#);
acc.push_str(chan.as_ref());
acc.push_str(&ident(chan.as_ref()));
acc.push_str(r#"";"#);
acc
})
}
/// Send the structure listen query to the database.
async fn send_listen_query<C: DerefMut<Target = PgConnection>>(
conn: &mut C,
channels: impl IntoIterator<Item = impl AsRef<str>>,
) -> Result<()> {
let cmd = build_listen_all_query(channels);
let _ = conn.execute(cmd.as_str()).await?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -5,7 +5,7 @@ pub use connection::PgConnection;
pub use cursor::PgCursor;
pub use database::Postgres;
pub use error::PgError;
pub use listen::{PgConnectionExt, PgListener, PgNotification, PgPoolExt, PgPoolListener};
pub use listen::{PgListener, PgNotification};
pub use row::{PgRow, PgValue};
pub use types::PgTypeInfo;