diff --git a/src/connection.rs b/src/connection.rs index d91a0c6d..35825c7e 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -2,6 +2,7 @@ use crate::{ backend::Backend, error::Error, executor::Executor, + pool::{Live, SharedPool}, query::{IntoQueryParameters, QueryParameters}, row::FromSqlRow, }; @@ -11,11 +12,11 @@ use futures_channel::oneshot::{channel, Sender}; use futures_core::{future::BoxFuture, stream::BoxStream}; use futures_util::{stream::StreamExt, TryFutureExt}; use std::{ - ops::{Deref, DerefMut}, sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicUsize, Ordering}, Arc, }, + time::Instant, }; /// A connection to the database. @@ -87,37 +88,34 @@ impl Connection where DB: Backend, { - pub async fn establish(url: &str) -> Result { - let raw = ::RawConnection::establish(url).await?; + pub(crate) fn new(live: Live, pool: Option>>) -> Self { let shared = SharedConnection { - raw: AtomicCell::new(Some(Box::new(raw))), - waiting: AtomicBool::new(false), + live: AtomicCell::new(Some(live)), + num_waiters: AtomicUsize::new(0), waiters: SegQueue::new(), + pool, }; - Ok(Self(Arc::new(shared))) + Self(Arc::new(shared)) } - #[inline] - async fn acquire(&self) -> ConnectionFairy<'_, DB> { - ConnectionFairy::new(&self.0, self.0.acquire().await) - } + pub async fn establish(url: &str) -> crate::Result { + let raw = ::RawConnection::establish(url).await?; + let live = Live { + raw, + since: Instant::now(), + }; - /// Release resources for this database connection immediately. - /// - /// This method is not required to be called. A database server will eventually notice - /// and clean up not fully closed connections. - /// - /// It is safe to close an already closed connection. - pub async fn close(&self) -> crate::Result<()> { - let mut conn = self.acquire().await; - conn.close().await + Ok(Self::new(live, None)) } /// Verifies a connection to the database is still alive. pub async fn ping(&self) -> crate::Result<()> { - let mut conn = self.acquire().await; - conn.ping().await + let mut live = self.0.acquire().await; + live.raw.ping().await?; + self.0.release(live); + + Ok(()) } } @@ -136,8 +134,11 @@ where A: IntoQueryParameters + Send, { Box::pin(async move { - let mut conn = self.acquire().await; - conn.execute(query, params.into()).await + let mut live = self.0.acquire().await; + let result = live.raw.execute(query, params.into()).await; + self.0.release(live); + + result }) } @@ -151,12 +152,15 @@ where T: FromSqlRow + Send + Unpin, { Box::pin(async_stream::try_stream! { - let mut conn = self.acquire().await; - let mut s = conn.fetch(query, params.into()); + let mut live = self.0.acquire().await; + let mut s = live.raw.fetch(query, params.into()); while let Some(row) = s.next().await.transpose()? { yield T::from_row(row); } + + drop(s); + self.0.release(live); }) } @@ -170,8 +174,9 @@ where T: FromSqlRow, { Box::pin(async move { - let mut conn = self.acquire().await; - let row = conn.fetch_optional(query, params.into()).await?; + let mut live = self.0.acquire().await; + let row = live.raw.fetch_optional(query, params.into()).await?; + self.0.release(live); Ok(row.map(T::from_row)) }) @@ -182,100 +187,62 @@ struct SharedConnection where DB: Backend, { - raw: AtomicCell>>, - waiting: AtomicBool, - waiters: SegQueue>>, + live: AtomicCell>>, + pool: Option>>, + num_waiters: AtomicUsize, + waiters: SegQueue>>, } impl SharedConnection where DB: Backend, { - async fn acquire(&self) -> Box { - if let Some(raw) = self.raw.swap(None) { + async fn acquire(&self) -> Live { + if let Some(live) = self.live.swap(None) { // Fast path, this connection is not currently in use. // We can directly return the inner connection. - return raw; + return live; } let (sender, receiver) = channel(); self.waiters.push(sender); - self.waiting.store(true, Ordering::Release); + self.num_waiters.fetch_add(1, Ordering::AcqRel); - // TODO: Handle this error - receiver.await.unwrap() + // Waiters are not dropped unless the pool is dropped + // which would drop this future + receiver + .await + .expect("waiter dropped without dropping connection") } - fn release(&self, mut raw: Box) { - // If we have any waiters, iterate until we find a non-dropped waiter - if self.waiting.load(Ordering::Acquire) { + fn release(&self, mut live: Live) { + if self.num_waiters.load(Ordering::Acquire) > 0 { while let Ok(waiter) = self.waiters.pop() { - raw = match waiter.send(raw) { - Err(raw) => raw, - Ok(_) => { + self.num_waiters.fetch_sub(1, Ordering::AcqRel); + + live = match waiter.send(live) { + Ok(()) => { return; } + + Err(live) => live, }; } } - // Otherwise, just re-store the connection until - // we are needed again - self.raw.store(Some(raw)); + self.live.store(Some(live)); } } -struct ConnectionFairy<'a, DB> -where - DB: Backend, -{ - shared: &'a Arc>, - raw: Option>, -} - -impl<'a, DB> ConnectionFairy<'a, DB> -where - DB: Backend, -{ - #[inline] - fn new(shared: &'a Arc>, raw: Box) -> Self { - Self { - shared, - raw: Some(raw), - } - } -} - -impl Deref for ConnectionFairy<'_, DB> -where - DB: Backend, -{ - type Target = DB::RawConnection; - - #[inline] - fn deref(&self) -> &Self::Target { - self.raw.as_ref().expect("connection use after drop") - } -} - -impl DerefMut for ConnectionFairy<'_, DB> -where - DB: Backend, -{ - #[inline] - fn deref_mut(&mut self) -> &mut Self::Target { - self.raw.as_mut().expect("connection use after drop") - } -} - -impl Drop for ConnectionFairy<'_, DB> +impl Drop for SharedConnection where DB: Backend, { fn drop(&mut self) { - if let Some(raw) = self.raw.take() { - self.shared.release(raw); + if let Some(pool) = &self.pool { + // This error should not be able to happen + pool.release(self.live.take().expect("drop while checked out")); } } } diff --git a/src/pool.rs b/src/pool.rs index 18c954ff..b2f36194 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,33 +1,102 @@ use crate::{ - backend::Backend, connection::RawConnection, error::Error, executor::Executor, - query::IntoQueryParameters, row::FromSqlRow, + backend::Backend, + connection::{Connection, RawConnection}, + error::Error, + executor::Executor, + query::IntoQueryParameters, + row::FromSqlRow, }; use crossbeam_queue::{ArrayQueue, SegQueue}; use futures_channel::oneshot; use futures_core::{future::BoxFuture, stream::BoxStream}; use futures_util::stream::StreamExt; use std::{ - ops::{Deref, DerefMut}, + marker::PhantomData, sync::{ - atomic::{AtomicUsize, Ordering}, + atomic::{AtomicU32, AtomicUsize, Ordering}, Arc, }, time::{Duration, Instant}, }; -pub struct PoolOptions { - pub max_size: usize, - pub min_idle: Option, - pub max_lifetime: Option, - pub idle_timeout: Option, - pub connection_timeout: Option, -} - -/// A database connection pool. +/// A pool of database connections. pub struct Pool(Arc>) where DB: Backend; +impl Pool +where + DB: Backend, +{ + /// Creates a connection pool with the default configuration. + pub async fn new(url: &str) -> crate::Result { + Ok(Pool(Arc::new( + SharedPool::new(url, Options::default()).await?, + ))) + } + + /// Returns a [Builder] to configure a new connection pool. + pub fn builder() -> Builder { + Builder::new() + } + + /// Retrieves a connection from the pool. + /// + /// Waits for at most the configured connection timeout before returning an error. + pub async fn acquire(&self) -> crate::Result> { + let live = self.0.acquire().await?; + Ok(Connection::new(live, Some(Arc::clone(&self.0)))) + } + + /// Attempts to retrieve a connection from the pool if there is one available. + /// + /// Returns `None` if there are no idle connections available in the pool. + /// This method will not block waiting to establish a new connection. + pub fn try_acquire(&self) -> Option> { + let live = self.0.try_acquire()?; + Some(Connection::new(live, Some(Arc::clone(&self.0)))) + } + + /// Ends the use of a connection pool. Prevents any new connections + /// and will close all active connections when they are returned to the pool. + /// + /// Does not resolve until all connections are closed. + pub async fn close(&self) { + unimplemented!() + } + + /// Returns the number of connections currently being managed by the pool. + pub fn size(&self) -> u32 { + self.0.size.load(Ordering::Acquire) + } + + /// Returns the number of idle connections. + pub fn idle(&self) -> usize { + self.0.num_idle.load(Ordering::Acquire) + } + + /// Returns the configured maximum pool size. + pub fn max_size(&self) -> u32 { + self.0.options.max_size + } + + /// Returns the configured mimimum idle connection count. + pub fn min_idle(&self) -> Option { + self.0.options.min_idle + } + + /// Returns the configured maximum connection lifetime. + pub fn max_lifetime(&self) -> Option { + self.0.options.max_lifetime + } + + /// Returns the configured idle connection timeout. + pub fn idle_timeout(&self) -> Option { + self.0.options.idle_timeout + } +} + +/// Returns a new [Pool] tied to the same shared connection pool. impl Clone for Pool where DB: Backend, @@ -37,88 +106,169 @@ where } } -impl Pool +pub struct Builder where DB: Backend, { - // TODO: PoolBuilder - pub fn new(url: &str, max_size: usize) -> Self { - Self(Arc::new(SharedPool { - url: url.to_owned(), - idle: ArrayQueue::new(max_size), - total: AtomicUsize::new(0), - waiters: SegQueue::new(), - options: PoolOptions { - idle_timeout: None, - connection_timeout: None, - max_lifetime: None, - max_size, - min_idle: None, - }, - })) + phantom: PhantomData, + options: Options, +} + +impl Builder +where + DB: Backend, +{ + pub fn new() -> Self { + Self { + phantom: PhantomData, + options: Options::default(), + } + } + + pub fn max_size(mut self, max_size: u32) -> Self { + self.options.max_size = max_size; + self + } + + pub fn min_idle(mut self, min_idle: impl Into>) -> Self { + self.options.min_idle = min_idle.into(); + self + } + + pub fn max_lifetime(mut self, max_lifetime: impl Into>) -> Self { + self.options.max_lifetime = max_lifetime.into(); + self + } + + pub fn idle_timeout(mut self, idle_timeout: impl Into>) -> Self { + self.options.idle_timeout = idle_timeout.into(); + self + } + + pub async fn build(self, url: &str) -> crate::Result> { + Ok(Pool(Arc::new(SharedPool::new(url, self.options).await?))) } } -struct SharedPool +struct Options { + max_size: u32, + min_idle: Option, + max_lifetime: Option, + idle_timeout: Option, +} + +impl Default for Options { + fn default() -> Self { + Self { + max_size: 10, + min_idle: None, + max_lifetime: None, + idle_timeout: None, + } + } +} + +pub(crate) struct SharedPool where DB: Backend, { url: String, idle: ArrayQueue>, waiters: SegQueue>>, - total: AtomicUsize, - options: PoolOptions, + size: AtomicU32, + num_waiters: AtomicUsize, + num_idle: AtomicUsize, + options: Options, } impl SharedPool where DB: Backend, { - async fn acquire(&self) -> Result, Error> { - if let Ok(idle) = self.idle.pop() { - return Ok(idle.live); - } + async fn new(url: &str, options: Options) -> crate::Result { + // TODO: Establish [min_idle] connections - let total = self.total.load(Ordering::SeqCst); - - if total >= self.options.max_size { - // Too many already, add a waiter and wait for - // a free connection - let (sender, reciever) = oneshot::channel(); - - self.waiters.push(sender); - - // TODO: Handle errors here - return Ok(reciever.await.unwrap()); - } - - self.total.store(total + 1, Ordering::SeqCst); - - let raw = ::RawConnection::establish(&self.url).await?; - - let live = Live { - raw, - since: Instant::now(), - }; - - Ok(live) + Ok(Self { + url: url.to_owned(), + idle: ArrayQueue::new(options.max_size as usize), + waiters: SegQueue::new(), + size: AtomicU32::new(0), + num_idle: AtomicUsize::new(0), + num_waiters: AtomicUsize::new(0), + options, + }) } - fn release(&self, mut live: Live) { - while let Ok(waiter) = self.waiters.pop() { - live = match waiter.send(live) { - Ok(()) => { - return; - } + #[inline] + fn try_acquire(&self) -> Option> { + if let Ok(idle) = self.idle.pop() { + self.num_idle.fetch_sub(1, Ordering::AcqRel); - Err(live) => live, - }; + return Some(idle.live); + } + + None + } + + async fn acquire(&self) -> crate::Result> { + if let Some(live) = self.try_acquire() { + return Ok(live); + } + + loop { + let size = self.size.load(Ordering::Acquire); + + if size >= self.options.max_size { + // Too many open connections + // Wait until one is available + + let (sender, receiver) = oneshot::channel(); + + self.waiters.push(sender); + self.num_waiters.fetch_add(1, Ordering::AcqRel); + + // Waiters are not dropped unless the pool is dropped + // which would drop this future + return Ok(receiver + .await + .expect("waiter dropped without dropping pool")); + } + + if self.size.compare_and_swap(size, size + 1, Ordering::AcqRel) == size { + // Open a new connection and return directly + + let raw = ::RawConnection::establish(&self.url).await?; + let live = Live { + raw, + since: Instant::now(), + }; + + return Ok(live); + } + } + } + + pub(crate) fn release(&self, mut live: Live) { + if self.num_waiters.load(Ordering::Acquire) > 0 { + while let Ok(waiter) = self.waiters.pop() { + self.num_waiters.fetch_sub(1, Ordering::AcqRel); + + live = match waiter.send(live) { + Ok(()) => { + return; + } + + Err(live) => live, + }; + } } let _ = self.idle.push(Idle { live, since: Instant::now(), }); + + self.num_idle.fetch_add(1, Ordering::AcqRel); } } @@ -137,10 +287,11 @@ where A: IntoQueryParameters + Send, { Box::pin(async move { - let live = self.0.acquire().await?; - let mut conn = PooledConnection::new(&self.0, live); + let mut live = self.0.acquire().await?; + let result = live.raw.execute(query, params.into()).await; + self.0.release(live); - conn.execute(query, params.into()).await + result }) } @@ -154,13 +305,15 @@ where T: FromSqlRow + Send + Unpin, { Box::pin(async_stream::try_stream! { - let live = self.0.acquire().await?; - let mut conn = PooledConnection::new(&self.0, live); - let mut s = conn.fetch(query, params.into()); + let mut live = self.0.acquire().await?; + let mut s = live.raw.fetch(query, params.into()); while let Some(row) = s.next().await.transpose()? { yield T::from_row(row); } + + drop(s); + self.0.release(live); }) } @@ -174,82 +327,30 @@ where T: FromSqlRow, { Box::pin(async move { - let live = self.0.acquire().await?; - let mut conn = PooledConnection::new(&self.0, live); - let row = conn.fetch_optional(query, params.into()).await?; + let mut live = self.0.acquire().await?; + let row = live.raw.fetch_optional(query, params.into()).await?; + + self.0.release(live); Ok(row.map(T::from_row)) }) } } -struct PooledConnection<'a, DB> -where - DB: Backend, -{ - shared: &'a Arc>, - live: Option>, -} - -impl<'a, DB> PooledConnection<'a, DB> -where - DB: Backend, -{ - fn new(shared: &'a Arc>, live: Live) -> Self { - Self { - shared, - live: Some(live), - } - } -} - -impl Deref for PooledConnection<'_, DB> -where - DB: Backend, -{ - type Target = DB::RawConnection; - - fn deref(&self) -> &Self::Target { - &self.live.as_ref().expect("connection use after drop").raw - } -} - -impl DerefMut for PooledConnection<'_, DB> -where - DB: Backend, -{ - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.live.as_mut().expect("connection use after drop").raw - } -} - -impl Drop for PooledConnection<'_, DB> -where - DB: Backend, -{ - fn drop(&mut self) { - if let Some(live) = self.live.take() { - self.shared.release(live); - } - } -} - struct Idle where DB: Backend, { live: Live, - // TODO: Implement idle connection timeouts #[allow(unused)] since: Instant, } -struct Live +pub(crate) struct Live where DB: Backend, { - raw: DB::RawConnection, - // TODO: Implement live connection timeouts + pub(crate) raw: DB::RawConnection, #[allow(unused)] - since: Instant, + pub(crate) since: Instant, } diff --git a/src/postgres/connection.rs b/src/postgres/connection.rs index 2738eba6..e4f1070a 100644 --- a/src/postgres/connection.rs +++ b/src/postgres/connection.rs @@ -229,7 +229,10 @@ impl PostgresRawConnection { async fn step(&mut self) -> crate::Result> { while let Some(message) = self.receive().await? { match message { - Message::BindComplete | Message::ParseComplete | Message::PortalSuspended | Message::CloseComplete => {} + Message::BindComplete + | Message::ParseComplete + | Message::PortalSuspended + | Message::CloseComplete => {} Message::CommandComplete(body) => { return Ok(Some(Step::Command(body.affected_rows()))); diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index 4b1d8f20..76f01c1b 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -85,16 +85,18 @@ mod tests { .await .unwrap(); - let res: Option<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'not-a-user'") - .fetch_optional(&conn) - .await - .unwrap(); + let res: Option<(String, bool)> = + crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'not-a-user'") + .fetch_optional(&conn) + .await + .unwrap(); assert!(res.is_none()); - let res: crate::Result<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'not-a-user'") - .fetch_one(&conn) - .await; + let res: crate::Result<(String, bool)> = + crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'not-a-user'") + .fetch_one(&conn) + .await; matches::assert_matches!(res, Err(crate::Error::NotFound)); } @@ -105,9 +107,10 @@ mod tests { .await .unwrap(); - let res: crate::Result<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles") - .fetch_one(&conn) - .await; + let res: crate::Result<(String, bool)> = + crate::query("SELECT rolname, rolsuper FROM pg_roles") + .fetch_one(&conn) + .await; matches::assert_matches!(res, Err(crate::Error::FoundMoreThanOne)); } @@ -118,10 +121,11 @@ mod tests { .await .unwrap(); - let res: (String, bool) = crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'postgres'") - .fetch_one(&conn) - .await - .unwrap(); + let res: (String, bool) = + crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'postgres'") + .fetch_one(&conn) + .await + .unwrap(); assert_eq!(res.0, "postgres"); assert!(res.1);