Clean up Pool/Connection impl and allow an explicit checkout from Pool to Connection

This commit is contained in:
Ryan Leckey 2019-09-04 17:56:58 -07:00
parent 7bf023b742
commit dce5718981
4 changed files with 314 additions and 239 deletions

View File

@ -2,6 +2,7 @@ use crate::{
backend::Backend,
error::Error,
executor::Executor,
pool::{Live, SharedPool},
query::{IntoQueryParameters, QueryParameters},
row::FromSqlRow,
};
@ -11,11 +12,11 @@ use futures_channel::oneshot::{channel, Sender};
use futures_core::{future::BoxFuture, stream::BoxStream};
use futures_util::{stream::StreamExt, TryFutureExt};
use std::{
ops::{Deref, DerefMut},
sync::{
atomic::{AtomicBool, Ordering},
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Instant,
};
/// A connection to the database.
@ -87,37 +88,34 @@ impl<DB> Connection<DB>
where
DB: Backend,
{
pub async fn establish(url: &str) -> Result<Self, Error> {
let raw = <DB as Backend>::RawConnection::establish(url).await?;
pub(crate) fn new(live: Live<DB>, pool: Option<Arc<SharedPool<DB>>>) -> Self {
let shared = SharedConnection {
raw: AtomicCell::new(Some(Box::new(raw))),
waiting: AtomicBool::new(false),
live: AtomicCell::new(Some(live)),
num_waiters: AtomicUsize::new(0),
waiters: SegQueue::new(),
pool,
};
Ok(Self(Arc::new(shared)))
Self(Arc::new(shared))
}
#[inline]
async fn acquire(&self) -> ConnectionFairy<'_, DB> {
ConnectionFairy::new(&self.0, self.0.acquire().await)
}
pub async fn establish(url: &str) -> crate::Result<Self> {
let raw = <DB as Backend>::RawConnection::establish(url).await?;
let live = Live {
raw,
since: Instant::now(),
};
/// Release resources for this database connection immediately.
///
/// This method is not required to be called. A database server will eventually notice
/// and clean up not fully closed connections.
///
/// It is safe to close an already closed connection.
pub async fn close(&self) -> crate::Result<()> {
let mut conn = self.acquire().await;
conn.close().await
Ok(Self::new(live, None))
}
/// Verifies a connection to the database is still alive.
pub async fn ping(&self) -> crate::Result<()> {
let mut conn = self.acquire().await;
conn.ping().await
let mut live = self.0.acquire().await;
live.raw.ping().await?;
self.0.release(live);
Ok(())
}
}
@ -136,8 +134,11 @@ where
A: IntoQueryParameters<Self::Backend> + Send,
{
Box::pin(async move {
let mut conn = self.acquire().await;
conn.execute(query, params.into()).await
let mut live = self.0.acquire().await;
let result = live.raw.execute(query, params.into()).await;
self.0.release(live);
result
})
}
@ -151,12 +152,15 @@ where
T: FromSqlRow<Self::Backend> + Send + Unpin,
{
Box::pin(async_stream::try_stream! {
let mut conn = self.acquire().await;
let mut s = conn.fetch(query, params.into());
let mut live = self.0.acquire().await;
let mut s = live.raw.fetch(query, params.into());
while let Some(row) = s.next().await.transpose()? {
yield T::from_row(row);
}
drop(s);
self.0.release(live);
})
}
@ -170,8 +174,9 @@ where
T: FromSqlRow<Self::Backend>,
{
Box::pin(async move {
let mut conn = self.acquire().await;
let row = conn.fetch_optional(query, params.into()).await?;
let mut live = self.0.acquire().await;
let row = live.raw.fetch_optional(query, params.into()).await?;
self.0.release(live);
Ok(row.map(T::from_row))
})
@ -182,100 +187,62 @@ struct SharedConnection<DB>
where
DB: Backend,
{
raw: AtomicCell<Option<Box<DB::RawConnection>>>,
waiting: AtomicBool,
waiters: SegQueue<Sender<Box<DB::RawConnection>>>,
live: AtomicCell<Option<Live<DB>>>,
pool: Option<Arc<SharedPool<DB>>>,
num_waiters: AtomicUsize,
waiters: SegQueue<Sender<Live<DB>>>,
}
impl<DB> SharedConnection<DB>
where
DB: Backend,
{
async fn acquire(&self) -> Box<DB::RawConnection> {
if let Some(raw) = self.raw.swap(None) {
async fn acquire(&self) -> Live<DB> {
if let Some(live) = self.live.swap(None) {
// Fast path, this connection is not currently in use.
// We can directly return the inner connection.
return raw;
return live;
}
let (sender, receiver) = channel();
self.waiters.push(sender);
self.waiting.store(true, Ordering::Release);
self.num_waiters.fetch_add(1, Ordering::AcqRel);
// TODO: Handle this error
receiver.await.unwrap()
// Waiters are not dropped unless the pool is dropped
// which would drop this future
receiver
.await
.expect("waiter dropped without dropping connection")
}
fn release(&self, mut raw: Box<DB::RawConnection>) {
// If we have any waiters, iterate until we find a non-dropped waiter
if self.waiting.load(Ordering::Acquire) {
fn release(&self, mut live: Live<DB>) {
if self.num_waiters.load(Ordering::Acquire) > 0 {
while let Ok(waiter) = self.waiters.pop() {
raw = match waiter.send(raw) {
Err(raw) => raw,
Ok(_) => {
self.num_waiters.fetch_sub(1, Ordering::AcqRel);
live = match waiter.send(live) {
Ok(()) => {
return;
}
Err(live) => live,
};
}
}
// Otherwise, just re-store the connection until
// we are needed again
self.raw.store(Some(raw));
self.live.store(Some(live));
}
}
struct ConnectionFairy<'a, DB>
where
DB: Backend,
{
shared: &'a Arc<SharedConnection<DB>>,
raw: Option<Box<DB::RawConnection>>,
}
impl<'a, DB> ConnectionFairy<'a, DB>
where
DB: Backend,
{
#[inline]
fn new(shared: &'a Arc<SharedConnection<DB>>, raw: Box<DB::RawConnection>) -> Self {
Self {
shared,
raw: Some(raw),
}
}
}
impl<DB> Deref for ConnectionFairy<'_, DB>
where
DB: Backend,
{
type Target = DB::RawConnection;
#[inline]
fn deref(&self) -> &Self::Target {
self.raw.as_ref().expect("connection use after drop")
}
}
impl<DB> DerefMut for ConnectionFairy<'_, DB>
where
DB: Backend,
{
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
self.raw.as_mut().expect("connection use after drop")
}
}
impl<DB> Drop for ConnectionFairy<'_, DB>
impl<DB> Drop for SharedConnection<DB>
where
DB: Backend,
{
fn drop(&mut self) {
if let Some(raw) = self.raw.take() {
self.shared.release(raw);
if let Some(pool) = &self.pool {
// This error should not be able to happen
pool.release(self.live.take().expect("drop while checked out"));
}
}
}

View File

@ -1,33 +1,102 @@
use crate::{
backend::Backend, connection::RawConnection, error::Error, executor::Executor,
query::IntoQueryParameters, row::FromSqlRow,
backend::Backend,
connection::{Connection, RawConnection},
error::Error,
executor::Executor,
query::IntoQueryParameters,
row::FromSqlRow,
};
use crossbeam_queue::{ArrayQueue, SegQueue};
use futures_channel::oneshot;
use futures_core::{future::BoxFuture, stream::BoxStream};
use futures_util::stream::StreamExt;
use std::{
ops::{Deref, DerefMut},
marker::PhantomData,
sync::{
atomic::{AtomicUsize, Ordering},
atomic::{AtomicU32, AtomicUsize, Ordering},
Arc,
},
time::{Duration, Instant},
};
pub struct PoolOptions {
pub max_size: usize,
pub min_idle: Option<usize>,
pub max_lifetime: Option<Duration>,
pub idle_timeout: Option<Duration>,
pub connection_timeout: Option<Duration>,
}
/// A database connection pool.
/// A pool of database connections.
pub struct Pool<DB>(Arc<SharedPool<DB>>)
where
DB: Backend;
impl<DB> Pool<DB>
where
DB: Backend,
{
/// Creates a connection pool with the default configuration.
pub async fn new(url: &str) -> crate::Result<Self> {
Ok(Pool(Arc::new(
SharedPool::new(url, Options::default()).await?,
)))
}
/// Returns a [Builder] to configure a new connection pool.
pub fn builder() -> Builder<DB> {
Builder::new()
}
/// Retrieves a connection from the pool.
///
/// Waits for at most the configured connection timeout before returning an error.
pub async fn acquire(&self) -> crate::Result<Connection<DB>> {
let live = self.0.acquire().await?;
Ok(Connection::new(live, Some(Arc::clone(&self.0))))
}
/// Attempts to retrieve a connection from the pool if there is one available.
///
/// Returns `None` if there are no idle connections available in the pool.
/// This method will not block waiting to establish a new connection.
pub fn try_acquire(&self) -> Option<Connection<DB>> {
let live = self.0.try_acquire()?;
Some(Connection::new(live, Some(Arc::clone(&self.0))))
}
/// Ends the use of a connection pool. Prevents any new connections
/// and will close all active connections when they are returned to the pool.
///
/// Does not resolve until all connections are closed.
pub async fn close(&self) {
unimplemented!()
}
/// Returns the number of connections currently being managed by the pool.
pub fn size(&self) -> u32 {
self.0.size.load(Ordering::Acquire)
}
/// Returns the number of idle connections.
pub fn idle(&self) -> usize {
self.0.num_idle.load(Ordering::Acquire)
}
/// Returns the configured maximum pool size.
pub fn max_size(&self) -> u32 {
self.0.options.max_size
}
/// Returns the configured mimimum idle connection count.
pub fn min_idle(&self) -> Option<u32> {
self.0.options.min_idle
}
/// Returns the configured maximum connection lifetime.
pub fn max_lifetime(&self) -> Option<Duration> {
self.0.options.max_lifetime
}
/// Returns the configured idle connection timeout.
pub fn idle_timeout(&self) -> Option<Duration> {
self.0.options.idle_timeout
}
}
/// Returns a new [Pool] tied to the same shared connection pool.
impl<DB> Clone for Pool<DB>
where
DB: Backend,
@ -37,88 +106,169 @@ where
}
}
impl<DB> Pool<DB>
pub struct Builder<DB>
where
DB: Backend,
{
// TODO: PoolBuilder
pub fn new(url: &str, max_size: usize) -> Self {
Self(Arc::new(SharedPool {
url: url.to_owned(),
idle: ArrayQueue::new(max_size),
total: AtomicUsize::new(0),
waiters: SegQueue::new(),
options: PoolOptions {
idle_timeout: None,
connection_timeout: None,
max_lifetime: None,
max_size,
min_idle: None,
},
}))
phantom: PhantomData<DB>,
options: Options,
}
impl<DB> Builder<DB>
where
DB: Backend,
{
pub fn new() -> Self {
Self {
phantom: PhantomData,
options: Options::default(),
}
}
pub fn max_size(mut self, max_size: u32) -> Self {
self.options.max_size = max_size;
self
}
pub fn min_idle(mut self, min_idle: impl Into<Option<u32>>) -> Self {
self.options.min_idle = min_idle.into();
self
}
pub fn max_lifetime(mut self, max_lifetime: impl Into<Option<Duration>>) -> Self {
self.options.max_lifetime = max_lifetime.into();
self
}
pub fn idle_timeout(mut self, idle_timeout: impl Into<Option<Duration>>) -> Self {
self.options.idle_timeout = idle_timeout.into();
self
}
pub async fn build(self, url: &str) -> crate::Result<Pool<DB>> {
Ok(Pool(Arc::new(SharedPool::new(url, self.options).await?)))
}
}
struct SharedPool<DB>
struct Options {
max_size: u32,
min_idle: Option<u32>,
max_lifetime: Option<Duration>,
idle_timeout: Option<Duration>,
}
impl Default for Options {
fn default() -> Self {
Self {
max_size: 10,
min_idle: None,
max_lifetime: None,
idle_timeout: None,
}
}
}
pub(crate) struct SharedPool<DB>
where
DB: Backend,
{
url: String,
idle: ArrayQueue<Idle<DB>>,
waiters: SegQueue<oneshot::Sender<Live<DB>>>,
total: AtomicUsize,
options: PoolOptions,
size: AtomicU32,
num_waiters: AtomicUsize,
num_idle: AtomicUsize,
options: Options,
}
impl<DB> SharedPool<DB>
where
DB: Backend,
{
async fn acquire(&self) -> Result<Live<DB>, Error> {
if let Ok(idle) = self.idle.pop() {
return Ok(idle.live);
}
async fn new(url: &str, options: Options) -> crate::Result<Self> {
// TODO: Establish [min_idle] connections
let total = self.total.load(Ordering::SeqCst);
if total >= self.options.max_size {
// Too many already, add a waiter and wait for
// a free connection
let (sender, reciever) = oneshot::channel();
self.waiters.push(sender);
// TODO: Handle errors here
return Ok(reciever.await.unwrap());
}
self.total.store(total + 1, Ordering::SeqCst);
let raw = <DB as Backend>::RawConnection::establish(&self.url).await?;
let live = Live {
raw,
since: Instant::now(),
};
Ok(live)
Ok(Self {
url: url.to_owned(),
idle: ArrayQueue::new(options.max_size as usize),
waiters: SegQueue::new(),
size: AtomicU32::new(0),
num_idle: AtomicUsize::new(0),
num_waiters: AtomicUsize::new(0),
options,
})
}
fn release(&self, mut live: Live<DB>) {
while let Ok(waiter) = self.waiters.pop() {
live = match waiter.send(live) {
Ok(()) => {
return;
}
#[inline]
fn try_acquire(&self) -> Option<Live<DB>> {
if let Ok(idle) = self.idle.pop() {
self.num_idle.fetch_sub(1, Ordering::AcqRel);
Err(live) => live,
};
return Some(idle.live);
}
None
}
async fn acquire(&self) -> crate::Result<Live<DB>> {
if let Some(live) = self.try_acquire() {
return Ok(live);
}
loop {
let size = self.size.load(Ordering::Acquire);
if size >= self.options.max_size {
// Too many open connections
// Wait until one is available
let (sender, receiver) = oneshot::channel();
self.waiters.push(sender);
self.num_waiters.fetch_add(1, Ordering::AcqRel);
// Waiters are not dropped unless the pool is dropped
// which would drop this future
return Ok(receiver
.await
.expect("waiter dropped without dropping pool"));
}
if self.size.compare_and_swap(size, size + 1, Ordering::AcqRel) == size {
// Open a new connection and return directly
let raw = <DB as Backend>::RawConnection::establish(&self.url).await?;
let live = Live {
raw,
since: Instant::now(),
};
return Ok(live);
}
}
}
pub(crate) fn release(&self, mut live: Live<DB>) {
if self.num_waiters.load(Ordering::Acquire) > 0 {
while let Ok(waiter) = self.waiters.pop() {
self.num_waiters.fetch_sub(1, Ordering::AcqRel);
live = match waiter.send(live) {
Ok(()) => {
return;
}
Err(live) => live,
};
}
}
let _ = self.idle.push(Idle {
live,
since: Instant::now(),
});
self.num_idle.fetch_add(1, Ordering::AcqRel);
}
}
@ -137,10 +287,11 @@ where
A: IntoQueryParameters<Self::Backend> + Send,
{
Box::pin(async move {
let live = self.0.acquire().await?;
let mut conn = PooledConnection::new(&self.0, live);
let mut live = self.0.acquire().await?;
let result = live.raw.execute(query, params.into()).await;
self.0.release(live);
conn.execute(query, params.into()).await
result
})
}
@ -154,13 +305,15 @@ where
T: FromSqlRow<Self::Backend> + Send + Unpin,
{
Box::pin(async_stream::try_stream! {
let live = self.0.acquire().await?;
let mut conn = PooledConnection::new(&self.0, live);
let mut s = conn.fetch(query, params.into());
let mut live = self.0.acquire().await?;
let mut s = live.raw.fetch(query, params.into());
while let Some(row) = s.next().await.transpose()? {
yield T::from_row(row);
}
drop(s);
self.0.release(live);
})
}
@ -174,82 +327,30 @@ where
T: FromSqlRow<Self::Backend>,
{
Box::pin(async move {
let live = self.0.acquire().await?;
let mut conn = PooledConnection::new(&self.0, live);
let row = conn.fetch_optional(query, params.into()).await?;
let mut live = self.0.acquire().await?;
let row = live.raw.fetch_optional(query, params.into()).await?;
self.0.release(live);
Ok(row.map(T::from_row))
})
}
}
struct PooledConnection<'a, DB>
where
DB: Backend,
{
shared: &'a Arc<SharedPool<DB>>,
live: Option<Live<DB>>,
}
impl<'a, DB> PooledConnection<'a, DB>
where
DB: Backend,
{
fn new(shared: &'a Arc<SharedPool<DB>>, live: Live<DB>) -> Self {
Self {
shared,
live: Some(live),
}
}
}
impl<DB> Deref for PooledConnection<'_, DB>
where
DB: Backend,
{
type Target = DB::RawConnection;
fn deref(&self) -> &Self::Target {
&self.live.as_ref().expect("connection use after drop").raw
}
}
impl<DB> DerefMut for PooledConnection<'_, DB>
where
DB: Backend,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.live.as_mut().expect("connection use after drop").raw
}
}
impl<DB> Drop for PooledConnection<'_, DB>
where
DB: Backend,
{
fn drop(&mut self) {
if let Some(live) = self.live.take() {
self.shared.release(live);
}
}
}
struct Idle<DB>
where
DB: Backend,
{
live: Live<DB>,
// TODO: Implement idle connection timeouts
#[allow(unused)]
since: Instant,
}
struct Live<DB>
pub(crate) struct Live<DB>
where
DB: Backend,
{
raw: DB::RawConnection,
// TODO: Implement live connection timeouts
pub(crate) raw: DB::RawConnection,
#[allow(unused)]
since: Instant,
pub(crate) since: Instant,
}

View File

@ -229,7 +229,10 @@ impl PostgresRawConnection {
async fn step(&mut self) -> crate::Result<Option<Step>> {
while let Some(message) = self.receive().await? {
match message {
Message::BindComplete | Message::ParseComplete | Message::PortalSuspended | Message::CloseComplete => {}
Message::BindComplete
| Message::ParseComplete
| Message::PortalSuspended
| Message::CloseComplete => {}
Message::CommandComplete(body) => {
return Ok(Some(Step::Command(body.affected_rows())));

View File

@ -85,16 +85,18 @@ mod tests {
.await
.unwrap();
let res: Option<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'not-a-user'")
.fetch_optional(&conn)
.await
.unwrap();
let res: Option<(String, bool)> =
crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'not-a-user'")
.fetch_optional(&conn)
.await
.unwrap();
assert!(res.is_none());
let res: crate::Result<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'not-a-user'")
.fetch_one(&conn)
.await;
let res: crate::Result<(String, bool)> =
crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'not-a-user'")
.fetch_one(&conn)
.await;
matches::assert_matches!(res, Err(crate::Error::NotFound));
}
@ -105,9 +107,10 @@ mod tests {
.await
.unwrap();
let res: crate::Result<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles")
.fetch_one(&conn)
.await;
let res: crate::Result<(String, bool)> =
crate::query("SELECT rolname, rolsuper FROM pg_roles")
.fetch_one(&conn)
.await;
matches::assert_matches!(res, Err(crate::Error::FoundMoreThanOne));
}
@ -118,10 +121,11 @@ mod tests {
.await
.unwrap();
let res: (String, bool) = crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'postgres'")
.fetch_one(&conn)
.await
.unwrap();
let res: (String, bool) =
crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'postgres'")
.fetch_one(&conn)
.await
.unwrap();
assert_eq!(res.0, "postgres");
assert!(res.1);