diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index fa677f43..2c6ce758 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -5,7 +5,9 @@ use std::str::from_utf8; use futures_channel::mpsc; use futures_core::future::BoxFuture; use futures_core::stream::{BoxStream, Stream}; -use futures_util::{FutureExt, StreamExt, TryStreamExt}; +use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt}; +use sqlx_core::acquire::Acquire; +use sqlx_core::transaction::Transaction; use sqlx_core::Either; use crate::describe::Describe; @@ -328,6 +330,19 @@ impl Drop for PgListener { } } +impl<'c> Acquire<'c> for &'c mut PgListener { + type Database = Postgres; + type Connection = &'c mut PgConnection; + + fn acquire(self) -> BoxFuture<'c, Result> { + self.connection().boxed() + } + + fn begin(self) -> BoxFuture<'c, Result, Error>> { + self.connection().and_then(|c| c.begin()).boxed() + } +} + impl<'c> Executor<'c> for &'c mut PgListener { type Database = Postgres; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 7edb5a7a..e6c397d9 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1074,6 +1074,45 @@ async fn test_pg_listener_allows_pool_to_close() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn test_pg_listener_implements_acquire() -> anyhow::Result<()> { + use sqlx::Acquire; + + let pool = pool::().await?; + + let mut listener = PgListener::connect_with(&pool).await?; + listener + .listen("test_pg_listener_implements_acquire") + .await?; + + // Start a transaction on the underlying connection + let mut txn = listener.begin().await?; + + // This will reuse the same connection, so this connection should be listening to the channel + let channels: Vec = sqlx::query_scalar("SELECT pg_listening_channels()") + .fetch_all(&mut *txn) + .await?; + + assert_eq!(channels, vec!["test_pg_listener_implements_acquire"]); + + // Send a notification + sqlx::query("NOTIFY test_pg_listener_implements_acquire, 'hello'") + .execute(&mut *txn) + .await?; + + txn.commit().await?; + + // And now we can receive the notification we sent in the transaction + let notification = listener.recv().await?; + assert_eq!( + notification.channel(), + "test_pg_listener_implements_acquire" + ); + assert_eq!(notification.payload(), "hello"); + + Ok(()) +} + #[sqlx_macros::test] async fn it_supports_domain_types_in_composite_domain_types() -> anyhow::Result<()> { // Only supported in Postgres 11+