diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index b83bad7c..b8cc9d7f 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -41,6 +41,9 @@ pub enum Error { /// because another task encountered too many errors while trying to open a new connection. TimedOut, + /// `Pool::close()` was called while we were waiting in `Pool::acquire()`. + PoolClosed, + // TODO: Remove and replace with `#[non_exhaustive]` when possible #[doc(hidden)] __Nonexhaustive, diff --git a/sqlx-core/src/pool.rs b/sqlx-core/src/pool.rs index 76a15347..6ef9302b 100644 --- a/sqlx-core/src/pool.rs +++ b/sqlx-core/src/pool.rs @@ -15,7 +15,7 @@ use std::{ marker::PhantomData, ops::{Deref, DerefMut}, sync::{ - atomic::{AtomicU32, AtomicUsize, Ordering}, + atomic::{AtomicU32, AtomicUsize, AtomicBool, Ordering}, Arc, }, time::{Duration, Instant}, @@ -68,7 +68,7 @@ where /// /// Does not resolve until all connections are closed. pub async fn close(&self) { - unimplemented!() + let _ = self.0.close().await; } /// Returns the number of connections currently being managed by the pool. @@ -195,6 +195,7 @@ where pool_rx: Receiver>, pool_tx: Sender>, size: AtomicU32, + closed: AtomicBool, options: Options, } @@ -212,12 +213,35 @@ where pool_rx, pool_tx, size: AtomicU32::new(0), + closed: AtomicBool::new(false), options, }) } + async fn close(&self) { + self.closed.store(true, Ordering::Release); + + while self.size.load(Ordering::Acquire) > 0 { + // don't block on the receiver because we own one Sender so it should never return + // `None`; a `select!()` would also work but that produces more complicated code + // and a timeout isn't necessarily appropriate + match self.pool_rx.recv().now_or_never() { + Some(Some(idle)) => { + let _ = idle.raw.inner.close().await; + self.size.fetch_sub(1, Ordering::AcqRel); + }, + Some(None) => panic!("we own a Sender how did this happen"), + None => task::yield_now().await, + } + } + } + #[inline] fn try_acquire(&self) -> Option> { + if self.closed.load(Ordering::Acquire) { + return None; + } + Some(self.pool_rx.recv().now_or_never()??.live(&self.pool_tx)) } @@ -229,7 +253,7 @@ where return Ok(live); } - loop { + while !self.closed.load(Ordering::Acquire) { let size = self.size.load(Ordering::Acquire); if size >= self.options.max_size { @@ -248,6 +272,12 @@ where Err(_) => continue, }; + if self.closed.load(Ordering::Acquire) { + let _ = idle.raw.inner.close().await; + self.size.fetch_sub(1, Ordering::AcqRel); + return Err(Error::PoolClosed); + } + // check if idle connection was within max lifetime (or not set) if self.options.max_lifetime.map_or(true, |max| idle.raw.created.elapsed() < max) // and if connection wasn't idle too long (or not set) @@ -277,10 +307,17 @@ where return self.new_conn(deadline).await } } + + Err(Error::PoolClosed) } async fn new_conn(&self, deadline: Instant) -> crate::Result> { while Instant::now() < deadline { + if self.closed.load(Ordering::Acquire) { + self.size.fetch_sub(1, Ordering::AcqRel); + return Err(Error::PoolClosed); + } + // result here is `Result, TimeoutError>` match timeout(deadline - Instant::now(), DB::open(&self.url)).await { Ok(Ok(raw)) => return Ok(Live::pooled(raw, &self.pool_tx)),