diff --git a/sqlx-core/src/postgres/listen.rs b/sqlx-core/src/postgres/listen.rs index aad714e0..822d11bb 100644 --- a/sqlx-core/src/postgres/listen.rs +++ b/sqlx-core/src/postgres/listen.rs @@ -1,3 +1,5 @@ +use std::ops::DerefMut; + use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; @@ -15,9 +17,9 @@ impl PgConnection { /// Register this connection as a listener on the specified channel. /// /// If an error is returned here, the connection will be dropped. - pub async fn listen(mut self, channel: &impl AsRef) -> Result> { - let cmd = format!(r#"LISTEN "{}""#, channel.as_ref()); - let _ = self.execute(cmd.as_str(), Default::default()).await?; + pub async fn listen(mut self, channel: impl AsRef) -> Result> { + let cmd = build_listen_all_query(&[channel]); + let _ = self.send(cmd.as_str()).await?; Ok(PgListener::new(self)) } @@ -28,25 +30,18 @@ impl PgConnection { mut self, channels: impl IntoIterator>, ) -> Result> { - for channel in channels { - let cmd = format!(r#"LISTEN "{}""#, channel.as_ref()); - let _ = self.execute(cmd.as_str(), Default::default()).await?; - } + let cmd = build_listen_all_query(channels); + let _ = self.send(cmd.as_str()).await?; Ok(PgListener::new(self)) } - - /// Build a LISTEN query based on the given channel input. - fn build_listen_query(channel: &impl AsRef) -> String { - format!(r#"LISTEN "{}";"#, channel.as_ref()) - } } impl PgPool { /// Fetch a new connection from the pool and register it as a listener on the specified channel. - pub async fn listen(&self, channel: &impl AsRef) -> Result> { + pub async fn listen(&self, channel: impl AsRef) -> Result> { let mut conn = self.acquire().await?; - let cmd = PgConnection::build_listen_query(channel); - let _ = conn.execute(cmd.as_str(), Default::default()).await?; + let cmd = build_listen_all_query(&[channel]); + let _ = conn.send(cmd.as_str()).await?; Ok(PgListener::new(conn)) } @@ -56,31 +51,31 @@ impl PgPool { channels: impl IntoIterator>, ) -> Result> { let mut conn = self.acquire().await?; - for channel in channels { - let cmd = PgConnection::build_listen_query(&channel); - let _ = conn.execute(cmd.as_str(), Default::default()).await?; - } + let cmd = build_listen_all_query(channels); + let _ = conn.send(cmd.as_str()).await?; Ok(PgListener::new(conn)) } } impl PgPoolConnection { - /// Fetch a new connection from the pool and register it as a listener on the specified channel. - pub async fn listen(mut self, channel: &impl AsRef) -> Result> { - let cmd = PgConnection::build_listen_query(channel); - let _ = self.execute(cmd.as_str(), Default::default()).await?; + /// Register this connection as a listener on the specified channel. + /// + /// If an error is returned here, the connection will be dropped. + pub async fn listen(mut self, channel: impl AsRef) -> Result> { + let cmd = build_listen_all_query(&[channel]); + let _ = self.send(cmd.as_str()).await?; Ok(PgListener::new(self)) } - /// Fetch a new connection from the pool and register it as a listener on the specified channels. + /// Register this connection as a listener on all of the specified channels. + /// + /// If an error is returned here, the connection will be dropped. pub async fn listen_all( mut self, channels: impl IntoIterator>, ) -> Result> { - for channel in channels { - let cmd = PgConnection::build_listen_query(&channel); - let _ = self.execute(cmd.as_str(), Default::default()).await?; - } + let cmd = build_listen_all_query(channels); + let _ = self.send(cmd.as_str()).await?; Ok(PgListener::new(self)) } } @@ -99,16 +94,21 @@ impl PgListener { impl PgListener where - C: AsMut, + C: DerefMut, { /// Get the next async notification from the database. pub async fn next(&mut self) -> Result { loop { - match self.0.as_mut().receive().await? { + match (&mut self.0).receive().await? { Some(Message::NotificationResponse(notification)) => return Ok(notification.into()), - // TODO: verify with team if this is correct. Looks like the connection being closed will cause an error - // to propagate up from `recevie`, but it would be good to verify with team. - Some(_) | None => continue, + Some(msg) => { + return Err(protocol_err!( + "unexpected message received from database {:?}", + msg + ) + .into()) + } + None => continue, } } } @@ -170,6 +170,7 @@ impl> crate::Executor for PgListener { } /// An asynchronous message sent from the database. +#[derive(Debug)] #[non_exhaustive] pub struct NotifyMessage { /// The channel of the notification, which can be thought of as a topic. @@ -186,3 +187,30 @@ impl From> for NotifyMessage { } } } + +/// Build a query which issues a LISTEN command for each given channel. +fn build_listen_all_query(channels: impl IntoIterator>) -> String { + channels.into_iter().fold(String::new(), |mut acc, chan| { + acc.push_str(r#"LISTEN ""#); + acc.push_str(chan.as_ref()); + acc.push_str(r#"";"#); + acc + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_listen_all_query_with_single_channel() { + let output = build_listen_all_query(&["test"]); + assert_eq!(output.as_str(), r#"LISTEN "test";"#); + } + + #[test] + fn build_listen_all_query_with_multiple_channels() { + let output = build_listen_all_query(&["channel.0", "channel.1"]); + assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#); + } +}