diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 8bd43940..083b1510 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -59,4 +59,4 @@ default-features = false features = [ "pkg-config", "vcpkg", "bundled" ] [dev-dependencies] -matches = "0.1.8" +matches = "0.1.8" \ No newline at end of file diff --git a/sqlx-core/src/postgres/listen.rs b/sqlx-core/src/postgres/listen.rs new file mode 100644 index 00000000..aad714e0 --- /dev/null +++ b/sqlx-core/src/postgres/listen.rs @@ -0,0 +1,188 @@ +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; + +use crate::connection::Connection; +use crate::describe::Describe; +use crate::executor::Executor; +use crate::pool::PoolConnection; +use crate::postgres::protocol::{Message, NotificationResponse}; +use crate::postgres::{PgArguments, PgConnection, PgPool, PgRow, Postgres}; +use crate::Result; + +type PgPoolConnection = PoolConnection; + +impl PgConnection { + /// Register this connection as a listener on the specified channel. + /// + /// If an error is returned here, the connection will be dropped. + pub async fn listen(mut self, channel: &impl AsRef) -> Result> { + let cmd = format!(r#"LISTEN "{}""#, channel.as_ref()); + let _ = self.execute(cmd.as_str(), Default::default()).await?; + Ok(PgListener::new(self)) + } + + /// Register this connection as a listener on all of the specified channels. + /// + /// If an error is returned here, the connection will be dropped. + pub async fn listen_all( + mut self, + channels: impl IntoIterator>, + ) -> Result> { + for channel in channels { + let cmd = format!(r#"LISTEN "{}""#, channel.as_ref()); + let _ = self.execute(cmd.as_str(), Default::default()).await?; + } + Ok(PgListener::new(self)) + } + + /// Build a LISTEN query based on the given channel input. + fn build_listen_query(channel: &impl AsRef) -> String { + format!(r#"LISTEN "{}";"#, channel.as_ref()) + } +} + +impl PgPool { + /// Fetch a new connection from the pool and register it as a listener on the specified channel. + pub async fn listen(&self, channel: &impl AsRef) -> Result> { + let mut conn = self.acquire().await?; + let cmd = PgConnection::build_listen_query(channel); + let _ = conn.execute(cmd.as_str(), Default::default()).await?; + Ok(PgListener::new(conn)) + } + + /// Fetch a new connection from the pool and register it as a listener on the specified channels. + pub async fn listen_all( + &self, + channels: impl IntoIterator>, + ) -> Result> { + let mut conn = self.acquire().await?; + for channel in channels { + let cmd = PgConnection::build_listen_query(&channel); + let _ = conn.execute(cmd.as_str(), Default::default()).await?; + } + Ok(PgListener::new(conn)) + } +} + +impl PgPoolConnection { + /// Fetch a new connection from the pool and register it as a listener on the specified channel. + pub async fn listen(mut self, channel: &impl AsRef) -> Result> { + let cmd = PgConnection::build_listen_query(channel); + let _ = self.execute(cmd.as_str(), Default::default()).await?; + Ok(PgListener::new(self)) + } + + /// Fetch a new connection from the pool and register it as a listener on the specified channels. + pub async fn listen_all( + mut self, + channels: impl IntoIterator>, + ) -> Result> { + for channel in channels { + let cmd = PgConnection::build_listen_query(&channel); + let _ = self.execute(cmd.as_str(), Default::default()).await?; + } + Ok(PgListener::new(self)) + } +} + +/// A stream of async database notifications. +/// +/// Notifications will always correspond to the channel(s) specified this object is created. +pub struct PgListener(C); + +impl PgListener { + /// Construct a new instance. + pub(self) fn new(conn: C) -> Self { + Self(conn) + } +} + +impl PgListener +where + C: AsMut, +{ + /// Get the next async notification from the database. + pub async fn next(&mut self) -> Result { + loop { + match self.0.as_mut().receive().await? { + Some(Message::NotificationResponse(notification)) => return Ok(notification.into()), + // TODO: verify with team if this is correct. Looks like the connection being closed will cause an error + // to propagate up from `recevie`, but it would be good to verify with team. + Some(_) | None => continue, + } + } + } +} + +impl PgListener +where + C: Connection, +{ + /// Close this listener stream and its underlying connection. + pub async fn close(self) -> BoxFuture<'static, Result<()>> { + self.0.close() + } +} + +impl std::ops::Deref for PgListener { + type Target = C; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for PgListener { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl> crate::Executor for PgListener { + type Database = super::Postgres; + + fn send<'e, 'q: 'e>(&'e mut self, query: &'q str) -> BoxFuture<'e, Result<()>> { + Box::pin(self.0.send(query)) + } + + fn execute<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + args: PgArguments, + ) -> BoxFuture<'e, Result> { + Box::pin(self.0.execute(query, args)) + } + + fn fetch<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + args: PgArguments, + ) -> BoxStream<'e, Result> { + self.0.fetch(query, args) + } + + fn describe<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + ) -> BoxFuture<'e, Result>> { + Box::pin(self.0.describe(query)) + } +} + +/// An asynchronous message sent from the database. +#[non_exhaustive] +pub struct NotifyMessage { + /// The channel of the notification, which can be thought of as a topic. + pub channel: String, + /// The payload of the notification. + pub payload: String, +} + +impl From> for NotifyMessage { + fn from(src: Box) -> Self { + Self { + channel: src.channel_name, + payload: src.message, + } + } +} diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 7ae97c9e..2e433c38 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -5,6 +5,7 @@ pub use connection::PgConnection; pub use cursor::PgCursor; pub use database::Postgres; pub use error::PgError; +pub use listen::{NotifyMessage, PgListener}; pub use row::{PgRow, PgValue}; pub use types::PgTypeInfo; @@ -14,6 +15,7 @@ mod cursor; mod database; mod error; mod executor; +mod listen; mod protocol; mod row; mod sasl; @@ -27,4 +29,4 @@ pub type PgPool = crate::pool::Pool; make_query_as!(PgQueryAs, Postgres, PgRow); impl_map_row_for_row!(Postgres, PgRow); impl_column_index_for_row!(Postgres); -impl_from_row_for_tuples!(Postgres, PgRow); +impl_from_row_for_tuples!(Postgres, PgRow); \ No newline at end of file