Updates from review and from testing.

This commit is contained in:
Anthony Dodd 2020-01-29 22:55:45 -06:00 committed by Ryan Leckey
parent a52f36468b
commit cb186e6a13

View File

@ -1,3 +1,5 @@
use std::ops::DerefMut;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
@ -15,9 +17,9 @@ impl PgConnection {
/// Register this connection as a listener on the specified channel.
///
/// If an error is returned here, the connection will be dropped.
pub async fn listen(mut self, channel: &impl AsRef<str>) -> Result<PgListener<Self>> {
let cmd = format!(r#"LISTEN "{}""#, channel.as_ref());
let _ = self.execute(cmd.as_str(), Default::default()).await?;
pub async fn listen(mut self, channel: impl AsRef<str>) -> Result<PgListener<Self>> {
let cmd = build_listen_all_query(&[channel]);
let _ = self.send(cmd.as_str()).await?;
Ok(PgListener::new(self))
}
@ -28,25 +30,18 @@ impl PgConnection {
mut self,
channels: impl IntoIterator<Item = impl AsRef<str>>,
) -> Result<PgListener<Self>> {
for channel in channels {
let cmd = format!(r#"LISTEN "{}""#, channel.as_ref());
let _ = self.execute(cmd.as_str(), Default::default()).await?;
}
let cmd = build_listen_all_query(channels);
let _ = self.send(cmd.as_str()).await?;
Ok(PgListener::new(self))
}
/// Build a LISTEN query based on the given channel input.
fn build_listen_query(channel: &impl AsRef<str>) -> String {
format!(r#"LISTEN "{}";"#, channel.as_ref())
}
}
impl PgPool {
/// Fetch a new connection from the pool and register it as a listener on the specified channel.
pub async fn listen(&self, channel: &impl AsRef<str>) -> Result<PgListener<PgPoolConnection>> {
pub async fn listen(&self, channel: impl AsRef<str>) -> Result<PgListener<PgPoolConnection>> {
let mut conn = self.acquire().await?;
let cmd = PgConnection::build_listen_query(channel);
let _ = conn.execute(cmd.as_str(), Default::default()).await?;
let cmd = build_listen_all_query(&[channel]);
let _ = conn.send(cmd.as_str()).await?;
Ok(PgListener::new(conn))
}
@ -56,31 +51,31 @@ impl PgPool {
channels: impl IntoIterator<Item = impl AsRef<str>>,
) -> Result<PgListener<PgPoolConnection>> {
let mut conn = self.acquire().await?;
for channel in channels {
let cmd = PgConnection::build_listen_query(&channel);
let _ = conn.execute(cmd.as_str(), Default::default()).await?;
}
let cmd = build_listen_all_query(channels);
let _ = conn.send(cmd.as_str()).await?;
Ok(PgListener::new(conn))
}
}
impl PgPoolConnection {
/// Fetch a new connection from the pool and register it as a listener on the specified channel.
pub async fn listen(mut self, channel: &impl AsRef<str>) -> Result<PgListener<Self>> {
let cmd = PgConnection::build_listen_query(channel);
let _ = self.execute(cmd.as_str(), Default::default()).await?;
/// Register this connection as a listener on the specified channel.
///
/// If an error is returned here, the connection will be dropped.
pub async fn listen(mut self, channel: impl AsRef<str>) -> Result<PgListener<Self>> {
let cmd = build_listen_all_query(&[channel]);
let _ = self.send(cmd.as_str()).await?;
Ok(PgListener::new(self))
}
/// Fetch a new connection from the pool and register it as a listener on the specified channels.
/// Register this connection as a listener on all of the specified channels.
///
/// If an error is returned here, the connection will be dropped.
pub async fn listen_all(
mut self,
channels: impl IntoIterator<Item = impl AsRef<str>>,
) -> Result<PgListener<Self>> {
for channel in channels {
let cmd = PgConnection::build_listen_query(&channel);
let _ = self.execute(cmd.as_str(), Default::default()).await?;
}
let cmd = build_listen_all_query(channels);
let _ = self.send(cmd.as_str()).await?;
Ok(PgListener::new(self))
}
}
@ -99,16 +94,21 @@ impl<C> PgListener<C> {
impl<C> PgListener<C>
where
C: AsMut<PgConnection>,
C: DerefMut<Target = PgConnection>,
{
/// Get the next async notification from the database.
pub async fn next(&mut self) -> Result<NotifyMessage> {
loop {
match self.0.as_mut().receive().await? {
match (&mut self.0).receive().await? {
Some(Message::NotificationResponse(notification)) => return Ok(notification.into()),
// TODO: verify with team if this is correct. Looks like the connection being closed will cause an error
// to propagate up from `recevie`, but it would be good to verify with team.
Some(_) | None => continue,
Some(msg) => {
return Err(protocol_err!(
"unexpected message received from database {:?}",
msg
)
.into())
}
None => continue,
}
}
}
@ -170,6 +170,7 @@ impl<C: Connection<Database = Postgres>> crate::Executor for PgListener<C> {
}
/// An asynchronous message sent from the database.
#[derive(Debug)]
#[non_exhaustive]
pub struct NotifyMessage {
/// The channel of the notification, which can be thought of as a topic.
@ -186,3 +187,30 @@ impl From<Box<NotificationResponse>> for NotifyMessage {
}
}
}
/// Build a query which issues a LISTEN command for each given channel.
fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
channels.into_iter().fold(String::new(), |mut acc, chan| {
acc.push_str(r#"LISTEN ""#);
acc.push_str(chan.as_ref());
acc.push_str(r#"";"#);
acc
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn build_listen_all_query_with_single_channel() {
let output = build_listen_all_query(&["test"]);
assert_eq!(output.as_str(), r#"LISTEN "test";"#);
}
#[test]
fn build_listen_all_query_with_multiple_channels() {
let output = build_listen_all_query(&["channel.0", "channel.1"]);
assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#);
}
}