mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-02-15 04:09:37 +00:00
Updates from review and from testing.
This commit is contained in:
parent
a52f36468b
commit
cb186e6a13
@ -1,3 +1,5 @@
|
||||
use std::ops::DerefMut;
|
||||
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_core::stream::BoxStream;
|
||||
|
||||
@ -15,9 +17,9 @@ impl PgConnection {
|
||||
/// Register this connection as a listener on the specified channel.
|
||||
///
|
||||
/// If an error is returned here, the connection will be dropped.
|
||||
pub async fn listen(mut self, channel: &impl AsRef<str>) -> Result<PgListener<Self>> {
|
||||
let cmd = format!(r#"LISTEN "{}""#, channel.as_ref());
|
||||
let _ = self.execute(cmd.as_str(), Default::default()).await?;
|
||||
pub async fn listen(mut self, channel: impl AsRef<str>) -> Result<PgListener<Self>> {
|
||||
let cmd = build_listen_all_query(&[channel]);
|
||||
let _ = self.send(cmd.as_str()).await?;
|
||||
Ok(PgListener::new(self))
|
||||
}
|
||||
|
||||
@ -28,25 +30,18 @@ impl PgConnection {
|
||||
mut self,
|
||||
channels: impl IntoIterator<Item = impl AsRef<str>>,
|
||||
) -> Result<PgListener<Self>> {
|
||||
for channel in channels {
|
||||
let cmd = format!(r#"LISTEN "{}""#, channel.as_ref());
|
||||
let _ = self.execute(cmd.as_str(), Default::default()).await?;
|
||||
}
|
||||
let cmd = build_listen_all_query(channels);
|
||||
let _ = self.send(cmd.as_str()).await?;
|
||||
Ok(PgListener::new(self))
|
||||
}
|
||||
|
||||
/// Build a LISTEN query based on the given channel input.
|
||||
fn build_listen_query(channel: &impl AsRef<str>) -> String {
|
||||
format!(r#"LISTEN "{}";"#, channel.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
impl PgPool {
|
||||
/// Fetch a new connection from the pool and register it as a listener on the specified channel.
|
||||
pub async fn listen(&self, channel: &impl AsRef<str>) -> Result<PgListener<PgPoolConnection>> {
|
||||
pub async fn listen(&self, channel: impl AsRef<str>) -> Result<PgListener<PgPoolConnection>> {
|
||||
let mut conn = self.acquire().await?;
|
||||
let cmd = PgConnection::build_listen_query(channel);
|
||||
let _ = conn.execute(cmd.as_str(), Default::default()).await?;
|
||||
let cmd = build_listen_all_query(&[channel]);
|
||||
let _ = conn.send(cmd.as_str()).await?;
|
||||
Ok(PgListener::new(conn))
|
||||
}
|
||||
|
||||
@ -56,31 +51,31 @@ impl PgPool {
|
||||
channels: impl IntoIterator<Item = impl AsRef<str>>,
|
||||
) -> Result<PgListener<PgPoolConnection>> {
|
||||
let mut conn = self.acquire().await?;
|
||||
for channel in channels {
|
||||
let cmd = PgConnection::build_listen_query(&channel);
|
||||
let _ = conn.execute(cmd.as_str(), Default::default()).await?;
|
||||
}
|
||||
let cmd = build_listen_all_query(channels);
|
||||
let _ = conn.send(cmd.as_str()).await?;
|
||||
Ok(PgListener::new(conn))
|
||||
}
|
||||
}
|
||||
|
||||
impl PgPoolConnection {
|
||||
/// Fetch a new connection from the pool and register it as a listener on the specified channel.
|
||||
pub async fn listen(mut self, channel: &impl AsRef<str>) -> Result<PgListener<Self>> {
|
||||
let cmd = PgConnection::build_listen_query(channel);
|
||||
let _ = self.execute(cmd.as_str(), Default::default()).await?;
|
||||
/// Register this connection as a listener on the specified channel.
|
||||
///
|
||||
/// If an error is returned here, the connection will be dropped.
|
||||
pub async fn listen(mut self, channel: impl AsRef<str>) -> Result<PgListener<Self>> {
|
||||
let cmd = build_listen_all_query(&[channel]);
|
||||
let _ = self.send(cmd.as_str()).await?;
|
||||
Ok(PgListener::new(self))
|
||||
}
|
||||
|
||||
/// Fetch a new connection from the pool and register it as a listener on the specified channels.
|
||||
/// Register this connection as a listener on all of the specified channels.
|
||||
///
|
||||
/// If an error is returned here, the connection will be dropped.
|
||||
pub async fn listen_all(
|
||||
mut self,
|
||||
channels: impl IntoIterator<Item = impl AsRef<str>>,
|
||||
) -> Result<PgListener<Self>> {
|
||||
for channel in channels {
|
||||
let cmd = PgConnection::build_listen_query(&channel);
|
||||
let _ = self.execute(cmd.as_str(), Default::default()).await?;
|
||||
}
|
||||
let cmd = build_listen_all_query(channels);
|
||||
let _ = self.send(cmd.as_str()).await?;
|
||||
Ok(PgListener::new(self))
|
||||
}
|
||||
}
|
||||
@ -99,16 +94,21 @@ impl<C> PgListener<C> {
|
||||
|
||||
impl<C> PgListener<C>
|
||||
where
|
||||
C: AsMut<PgConnection>,
|
||||
C: DerefMut<Target = PgConnection>,
|
||||
{
|
||||
/// Get the next async notification from the database.
|
||||
pub async fn next(&mut self) -> Result<NotifyMessage> {
|
||||
loop {
|
||||
match self.0.as_mut().receive().await? {
|
||||
match (&mut self.0).receive().await? {
|
||||
Some(Message::NotificationResponse(notification)) => return Ok(notification.into()),
|
||||
// TODO: verify with team if this is correct. Looks like the connection being closed will cause an error
|
||||
// to propagate up from `recevie`, but it would be good to verify with team.
|
||||
Some(_) | None => continue,
|
||||
Some(msg) => {
|
||||
return Err(protocol_err!(
|
||||
"unexpected message received from database {:?}",
|
||||
msg
|
||||
)
|
||||
.into())
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -170,6 +170,7 @@ impl<C: Connection<Database = Postgres>> crate::Executor for PgListener<C> {
|
||||
}
|
||||
|
||||
/// An asynchronous message sent from the database.
|
||||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
pub struct NotifyMessage {
|
||||
/// The channel of the notification, which can be thought of as a topic.
|
||||
@ -186,3 +187,30 @@ impl From<Box<NotificationResponse>> for NotifyMessage {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a query which issues a LISTEN command for each given channel.
|
||||
fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
|
||||
channels.into_iter().fold(String::new(), |mut acc, chan| {
|
||||
acc.push_str(r#"LISTEN ""#);
|
||||
acc.push_str(chan.as_ref());
|
||||
acc.push_str(r#"";"#);
|
||||
acc
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn build_listen_all_query_with_single_channel() {
|
||||
let output = build_listen_all_query(&["test"]);
|
||||
assert_eq!(output.as_str(), r#"LISTEN "test";"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_listen_all_query_with_multiple_channels() {
|
||||
let output = build_listen_all_query(&["channel.0", "channel.1"]);
|
||||
assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#);
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user