From c280ad587fa28bef852fe7f53f267b07fd99ca40 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Thu, 15 Aug 2019 23:03:25 -0700 Subject: [PATCH] Move to sqlx::query( ... ).execute style API --- Cargo.toml | 11 +- examples/contacts/Cargo.toml | 3 +- examples/contacts/src/main.rs | 97 +++++++---- examples/todos/Cargo.toml | 3 +- examples/todos/src/main.rs | 37 +++-- src/backend.rs | 4 +- src/client.rs | 58 ------- src/connection.rs | 238 +++++++++++++++++++++++++-- src/executor.rs | 55 +++++++ src/lib.rs | 9 +- src/pg/backend.rs | 2 +- src/pg/connection/establish.rs | 6 +- src/pg/connection/execute.rs | 50 +++--- src/pg/connection/fetch.rs | 37 +++++ src/pg/connection/fetch_optional.rs | 34 ++++ src/pg/connection/get.rs | 61 ------- src/pg/connection/mod.rs | 85 +++++++--- src/pg/connection/prepare.rs | 69 -------- src/pg/connection/select.rs | 67 -------- src/pg/mod.rs | 2 +- src/pg/protocol/bind.rs | 29 ---- src/pg/protocol/data_row.rs | 1 - src/pg/query.rs | 195 ++++------------------ src/pool.rs | 242 ++++++++++++++++++---------- src/query.rs | 50 +++++- src/row.rs | 2 +- 26 files changed, 776 insertions(+), 671 deletions(-) delete mode 100644 src/client.rs create mode 100644 src/executor.rs create mode 100644 src/pg/connection/fetch.rs create mode 100644 src/pg/connection/fetch_optional.rs delete mode 100644 src/pg/connection/get.rs delete mode 100644 src/pg/connection/prepare.rs delete mode 100644 src/pg/connection/select.rs diff --git a/Cargo.toml b/Cargo.toml index 5839df8c..c28904d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ description = "Asynchronous and expressive database client in pure Rust." edition = "2018" [features] -default = [] +default = ["postgres"] postgres = [] mariadb = [] @@ -23,6 +23,7 @@ bitflags = "1.1.0" byteorder = "1.3.2" bytes = "0.4.12" crossbeam-queue = "0.1.2" +crossbeam-utils = "0.6.6" enum-tryfrom = "0.2.1" enum-tryfrom-derive = "0.2.1" failure = "0.1.5" @@ -33,7 +34,9 @@ log = "0.4.8" md-5 = "0.8.0" url = "2.1.0" memchr = "2.2.1" -runtime = { version = "=0.3.0-alpha.6", default-features = false } +async-stream = "0.1.0" +tokio = { version = "=0.2.0-alpha.1" } -[dev-dependencies] -runtime = { version = "=0.3.0-alpha.6", default-features = true } +[profile.release] +lto = true +codegen-units = 1 diff --git a/examples/contacts/Cargo.toml b/examples/contacts/Cargo.toml index 4e8753bf..844f903b 100644 --- a/examples/contacts/Cargo.toml +++ b/examples/contacts/Cargo.toml @@ -7,8 +7,7 @@ edition = "2018" sqlx = { path = "../..", features = [ "postgres" ] } failure = "0.1.5" env_logger = "0.6.2" -runtime = { version = "=0.3.0-alpha.6", default-features = false } -runtime-tokio = { version = "=0.3.0-alpha.5" } +tokio = { version = "=0.2.0-alpha.1" } futures-preview = "=0.3.0-alpha.17" fake = { version = "2.0", features=[ "derive" ] } rand = "0.7.0" diff --git a/examples/contacts/src/main.rs b/examples/contacts/src/main.rs index 84ec352b..5c10f0a3 100644 --- a/examples/contacts/src/main.rs +++ b/examples/contacts/src/main.rs @@ -1,4 +1,4 @@ -#![feature(async_await)] +#![feature(async_await, try_blocks)] use failure::Fallible; use fake::{ @@ -9,10 +9,17 @@ use fake::{ }, Dummy, Fake, Faker, }; -use futures::future; -use sqlx::{pg::Pg, Client, Connection, Query}; +use std::time::Duration; +use futures::stream::TryStreamExt; +use std::io; +use sqlx::{ + pg::{Pg, PgQuery}, + Pool, Query, Connection, +}; use std::time::Instant; +type PgPool = Pool; + #[derive(Debug, Dummy)] struct Contact { #[dummy(faker = "Name()")] @@ -31,16 +38,22 @@ struct Contact { 
phone: String, } -#[runtime::main(runtime_tokio::Tokio)] +#[tokio::main] async fn main() -> Fallible<()> { env_logger::try_init()?; - let client = Client::::new("postgres://postgres@localhost/sqlx__dev"); + let pool = PgPool::new("postgres://postgres@127.0.0.1/sqlx__dev", 1); - { - let mut conn = client.get().await?; - conn.prepare( - r#" + ensure_schema(&pool).await?; + insert(&pool, 500).await?; + select(&pool, 50_000).await?; + + Ok(()) +} + +async fn ensure_schema(pool: &PgPool) -> io::Result<()> { + sqlx::query::( + r#" CREATE TABLE IF NOT EXISTS contacts ( id BIGSERIAL PRIMARY KEY, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), @@ -50,27 +63,29 @@ CREATE TABLE IF NOT EXISTS contacts ( email TEXT NOT NULL, phone TEXT NOT NULL ) - "#, - ) - .execute() + "#, + ) + .execute(&pool) + .await?; + + sqlx::query::("TRUNCATE contacts") + .execute(&pool) .await?; - conn.prepare("TRUNCATE contacts").execute().await?; - } + Ok(()) +} - let mut handles = vec![]; +async fn insert(pool: &PgPool, count: usize) -> io::Result<()> { let start_at = Instant::now(); - let rows = 10_000; - for _ in 0..rows { - let client = client.clone(); + for _ in 0..count { + let pool = pool.clone(); let contact: Contact = Faker.fake(); - let handle: runtime::task::JoinHandle> = runtime::task::spawn(async move { - let mut conn = client.get().await?; - conn.prepare( + + sqlx::query::( r#" - INSERT INTO contacts (name, username, password, email, phone) - VALUES ($1, $2, $3, $4, $5) +INSERT INTO contacts (name, username, password, email, phone) +VALUES ($1, $2, $3, $4, $5) "#, ) .bind(contact.name) @@ -78,18 +93,38 @@ CREATE TABLE IF NOT EXISTS contacts ( .bind(contact.password) .bind(contact.email) .bind(contact.phone) - .execute() + .execute(&pool) .await?; - - Ok(()) - }); - - handles.push(handle); } - future::join_all(handles).await; + let elapsed = start_at.elapsed(); + let per = Duration::from_nanos((elapsed.as_nanos() / (count as u128)) as u64); - println!("insert {} rows in {:?}", rows, start_at.elapsed()); + println!("insert {} rows in {:?} [ 1 in ~{:?} ]", count, elapsed, per); + + Ok(()) +} + +async fn select(pool: &PgPool, iterations: usize) -> io::Result<()> { + let start_at = Instant::now(); + let mut rows: usize = 0; + + for _ in 0..iterations { + // TODO: Once we have FromRow derives we can replace this with Vec + let contacts: Vec<(String, String, String, String, String)> = sqlx::query::( + r#" +SELECT name, username, password, email, phone +FROM contacts + "#, + ).fetch(&pool).try_collect().await?; + + rows = contacts.len(); + } + + let elapsed = start_at.elapsed(); + let per = Duration::from_nanos((elapsed.as_nanos() / (iterations as u128)) as u64); + + println!("select {} rows in ~{:?} [ x{} in {:?} ]", rows, per, iterations, elapsed); Ok(()) } diff --git a/examples/todos/Cargo.toml b/examples/todos/Cargo.toml index 6d387563..d0cdd4e4 100644 --- a/examples/todos/Cargo.toml +++ b/examples/todos/Cargo.toml @@ -7,7 +7,6 @@ edition = "2018" sqlx = { path = "../..", features = [ "postgres" ] } failure = "0.1.5" env_logger = "0.6.2" -runtime = { version = "=0.3.0-alpha.6", default-features = false } -runtime-tokio = { version = "=0.3.0-alpha.5" } +tokio = { version = "=0.2.0-alpha.1" } futures-preview = "=0.3.0-alpha.17" structopt = "0.2.18" diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index 1cf685a2..5faa4413 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -2,7 +2,10 @@ use failure::Fallible; use futures::{future, TryStreamExt}; -use sqlx::{pg::PgConnection, 
Connection, Query}; +use sqlx::{ + pg::{Pg, PgQuery}, + Connection, Query, +}; use structopt::StructOpt; #[derive(StructOpt, Debug)] @@ -20,13 +23,13 @@ enum Command { MarkAsDone { id: i64 }, } -#[runtime::main(runtime_tokio::Tokio)] +#[tokio::main] async fn main() -> Fallible<()> { env_logger::try_init()?; let opt = Options::from_args(); - let mut conn = PgConnection::establish("postgres://postgres@localhost/sqlx__dev").await?; + let mut conn = Connection::::establish("postgres://postgres@127.0.0.1/sqlx__dev").await?; ensure_schema(&mut conn).await?; @@ -47,11 +50,11 @@ async fn main() -> Fallible<()> { Ok(()) } -async fn ensure_schema(conn: &mut PgConnection) -> Fallible<()> { - conn.prepare("BEGIN").execute().await?; +async fn ensure_schema(conn: &mut Connection) -> Fallible<()> { + sqlx::query::("BEGIN").execute(conn).await?; // language=sql - conn.prepare( + sqlx::query::( r#" CREATE TABLE IF NOT EXISTS tasks ( id BIGSERIAL PRIMARY KEY, @@ -61,24 +64,24 @@ CREATE TABLE IF NOT EXISTS tasks ( ) "#, ) - .execute() + .execute(conn) .await?; - conn.prepare("COMMIT").execute().await?; + sqlx::query::("COMMIT").execute(conn).await?; Ok(()) } -async fn print_all_tasks(conn: &mut PgConnection) -> Fallible<()> { +async fn print_all_tasks(conn: &mut Connection) -> Fallible<()> { // language=sql - conn.prepare( + sqlx::query::( r#" SELECT id, text FROM tasks WHERE done_at IS NULL "#, ) - .fetch() + .fetch(conn) .try_for_each(|(id, text): (i64, String)| { // language=text println!("{:>5} | {}", id, text); @@ -90,24 +93,24 @@ WHERE done_at IS NULL Ok(()) } -async fn add_task(conn: &mut PgConnection, text: &str) -> Fallible<()> { +async fn add_task(conn: &mut Connection, text: &str) -> Fallible<()> { // language=sql - conn.prepare( + sqlx::query::( r#" INSERT INTO tasks ( text ) VALUES ( $1 ) "#, ) .bind(text) - .execute() + .execute(conn) .await?; Ok(()) } -async fn mark_task_as_done(conn: &mut PgConnection, id: i64) -> Fallible<()> { +async fn mark_task_as_done(conn: &mut Connection, id: i64) -> Fallible<()> { // language=sql - conn.prepare( + sqlx::query::( r#" UPDATE tasks SET done_at = now() @@ -115,7 +118,7 @@ WHERE id = $1 "#, ) .bind(id) - .execute() + .execute(conn) .await?; Ok(()) diff --git a/src/backend.rs b/src/backend.rs index f56f8590..ad11b6d2 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,4 +1,4 @@ -use crate::{connection::Connection, row::Row}; +use crate::{connection::RawConnection, row::Row}; /// A database backend. /// @@ -7,6 +7,6 @@ use crate::{connection::Connection, row::Row}; /// to query capabilities within a database backend (e.g., with a specific /// `Connection` can we `bind` a `i64`?). pub trait Backend: Sized { - type Connection: Connection; + type RawConnection: RawConnection; type Row: Row; } diff --git a/src/client.rs b/src/client.rs deleted file mode 100644 index 6ab5347a..00000000 --- a/src/client.rs +++ /dev/null @@ -1,58 +0,0 @@ -use crate::{ - backend::Backend, - connection::{Connection, ConnectionAssocQuery}, - pool::{Pool, PoolOptions}, -}; -use std::{io, ops::DerefMut}; - -pub struct Client { - pool: Pool, -} - -impl Clone for Client { - fn clone(&self) -> Self { - Self { - pool: self.pool.clone(), - } - } -} - -impl Client { - pub fn new(url: &str) -> Self { - Self { - pool: Pool::new( - url, - PoolOptions { - idle_timeout: None, - connection_timeout: None, - max_lifetime: None, - max_size: 70, - min_idle: None, - }, - ), - } - } - - pub async fn get(&self) -> io::Result> { - Ok(self.pool.acquire().await?) 
- } -} - -// impl<'c, 'q, DB: Backend> ConnectionAssocQuery<'c, 'q> for Client { -// type Query = <::Connection as ConnectionAssocQuery<'c, 'q>>::Query; -// } - -// impl Connection for Client { -// type Backend = DB; - -// #[inline] -// fn establish(url: &str) -> BoxFuture> { -// Box::pin(future::ok(Client::new(url))) -// } - -// #[inline] -// fn prepare<'c, 'q>(&'c mut self, query: &'q str) -> <::Connection as ConnectionAssocQuery<'c, 'q>>::Query { -// // TODO: Think on how to handle error here -// self.pool.acquire().unwrap().prepare(query) -// } -// } diff --git a/src/connection.rs b/src/connection.rs index 04c402c4..34775e9d 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,23 +1,233 @@ -pub(crate) use self::internal::ConnectionAssocQuery; -use crate::{backend::Backend, Query}; -use futures::future::BoxFuture; -use std::io; +use crate::{backend::Backend, executor::Executor, query::Query, row::FromRow}; +use crossbeam_queue::SegQueue; +use crossbeam_utils::atomic::AtomicCell; +use futures::{ + channel::oneshot::{channel, Sender}, + future::BoxFuture, + stream::{BoxStream, StreamExt}, +}; +use std::{ + io, + ops::{Deref, DerefMut}, + sync::Arc, + sync::atomic::{AtomicBool, Ordering}, +}; -mod internal { - pub trait ConnectionAssocQuery<'c, 'q> { - type Query: super::Query<'c, 'q>; - } -} - -pub trait Connection: for<'c, 'q> ConnectionAssocQuery<'c, 'q> { +pub trait RawConnection: Send { type Backend: Backend; + /// Establish a new connection to the database server. fn establish(url: &str) -> BoxFuture> where Self: Sized; - fn prepare<'c, 'q>( + /// Release resources for this database connection immediately. + /// This method is not required to be called. A database server will eventually notice + /// and clean up not fully closed connections. + fn finalize<'c>(&'c mut self) -> BoxFuture<'c, io::Result<()>>; + + fn execute<'c, 'q, Q: 'q>(&'c mut self, query: Q) -> BoxFuture<'c, io::Result> + where + Q: Query<'q, Backend = Self::Backend>; + + fn fetch<'c, 'q, Q: 'q>( &'c mut self, - query: &'q str, - ) -> >::Query; + query: Q, + ) -> BoxStream<'c, io::Result<::Row>> + where + Q: Query<'q, Backend = Self::Backend>; + + fn fetch_optional<'c, 'q, Q: 'q>( + &'c mut self, + query: Q, + ) -> BoxFuture<'c, io::Result::Row>>> + where + Q: Query<'q, Backend = Self::Backend>; +} + +pub struct Connection(Arc>) +where + DB: Backend; + +impl Clone for Connection +where + DB: Backend, +{ + #[inline] + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } +} + +impl Connection +where + DB: Backend, +{ + pub async fn establish(url: &str) -> io::Result { + let raw = ::RawConnection::establish(url).await?; + let shared = SharedConnection { + raw: AtomicCell::new(Some(Box::new(raw))), + waiting: AtomicBool::new(false), + waiters: SegQueue::new(), + }; + + Ok(Self(Arc::new(shared))) + } + + async fn get(&self) -> ConnectionFairy { + let raw = self.0.acquire().await; + let conn = ConnectionFairy::new(&self.0, raw); + + conn + } +} + +impl Executor for Connection +where + DB: Backend, +{ + type Backend = DB; + + fn execute<'c, 'q, Q: 'q + 'c>(&'c self, query: Q) -> BoxFuture<'c, io::Result> + where + Q: Query<'q, Backend = Self::Backend>, + { + Box::pin(async move { + let mut conn = self.get().await; + conn.execute(query).await + }) + } + + fn fetch<'c, 'q, A: 'c, T: 'c, Q: 'q + 'c>(&'c self, query: Q) -> BoxStream<'c, io::Result> + where + Q: Query<'q, Backend = Self::Backend>, + T: FromRow + Send + Unpin, + { + Box::pin(async_stream::try_stream! 
{ + let mut conn = self.get().await; + let mut s = conn.fetch(query); + + while let Some(row) = s.next().await.transpose()? { + yield T::from_row(row); + } + }) + } + + fn fetch_optional<'c, 'q, A: 'c, T: 'c, Q: 'q + 'c>( + &'c self, + query: Q, + ) -> BoxFuture<'c, io::Result>> + where + Q: Query<'q, Backend = Self::Backend>, + T: FromRow, + { + Box::pin(async move { + let mut conn = self.get().await; + let row = conn.fetch_optional(query).await?; + + Ok(row.map(T::from_row)) + }) + } +} + +struct SharedConnection +where + DB: Backend, +{ + raw: AtomicCell>>, + waiting: AtomicBool, + waiters: SegQueue>>, +} + +impl SharedConnection +where + DB: Backend, +{ + async fn acquire(&self) -> Box { + if let Some(raw) = self.raw.swap(None) { + // Fast path, this connection is not currently in use. + // We can directly return the inner connection. + return raw; + } + + let (sender, receiver) = channel(); + + self.waiters.push(sender); + self.waiting.store(true, Ordering::Release); + + // TODO: Handle this error + receiver.await.unwrap() + } + + fn release(&self, mut raw: Box) { + // If we have any waiters, iterate until we find a non-dropped waiter + if self.waiting.load(Ordering::Acquire) { + while let Ok(waiter) = self.waiters.pop() { + raw = match waiter.send(raw) { + Err(raw) => raw, + Ok(_) => { + return; + } + }; + } + } + + // Otherwise, just re-store the connection until + // we are needed again + self.raw.store(Some(raw)); + } +} + +struct ConnectionFairy<'a, DB> +where + DB: Backend, +{ + shared: &'a Arc>, + raw: Option>, +} + +impl<'a, DB> ConnectionFairy<'a, DB> +where + DB: Backend, +{ + #[inline] + fn new(shared: &'a Arc>, raw: Box) -> Self { + Self { + shared, + raw: Some(raw), + } + } +} + +impl Deref for ConnectionFairy<'_, DB> +where + DB: Backend, +{ + type Target = DB::RawConnection; + + #[inline] + fn deref(&self) -> &Self::Target { + self.raw.as_ref().expect("connection use after drop") + } +} + +impl DerefMut for ConnectionFairy<'_, DB> +where + DB: Backend, +{ + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + self.raw.as_mut().expect("connection use after drop") + } +} + +impl Drop for ConnectionFairy<'_, DB> +where + DB: Backend, +{ + fn drop(&mut self) { + if let Some(raw) = self.raw.take() { + self.shared.release(raw); + } + } } diff --git a/src/executor.rs b/src/executor.rs new file mode 100644 index 00000000..5bfaddcb --- /dev/null +++ b/src/executor.rs @@ -0,0 +1,55 @@ +use crate::{backend::Backend, row::FromRow, Query}; +use futures::{future::BoxFuture, stream::BoxStream}; +use std::io; + +pub trait Executor: Send { + type Backend: Backend; + + fn execute<'c, 'q, Q: 'q + 'c>(&'c self, query: Q) -> BoxFuture<'c, io::Result> + where + Q: Query<'q, Backend = Self::Backend>; + + fn fetch<'c, 'q, A: 'c, T: 'c, Q: 'q + 'c>(&'c self, query: Q) -> BoxStream<'c, io::Result> + where + Q: Query<'q, Backend = Self::Backend>, + T: FromRow + Send + Unpin; + + fn fetch_optional<'c, 'q, A: 'c, T: 'c, Q: 'q + 'c>( + &'c self, + query: Q, + ) -> BoxFuture<'c, io::Result>> + where + Q: Query<'q, Backend = Self::Backend>, + T: FromRow; +} + +impl<'e, E> Executor for &'e E where E: Executor + Send + Sync { + type Backend = E::Backend; + + #[inline] + fn execute<'c, 'q, Q: 'q + 'c>(&'c self, query: Q) -> BoxFuture<'c, io::Result> + where + Q: Query<'q, Backend = Self::Backend> + { + (*self).execute(query) + } + + fn fetch<'c, 'q, A: 'c, T: 'c, Q: 'q + 'c>(&'c self, query: Q) -> BoxStream<'c, io::Result> + where + Q: Query<'q, Backend = Self::Backend>, + T: FromRow + Send + 
Unpin + { + (*self).fetch(query) + } + + fn fetch_optional<'c, 'q, A: 'c, T: 'c, Q: 'q + 'c>( + &'c self, + query: Q, + ) -> BoxFuture<'c, io::Result>> + where + Q: Query<'q, Backend = Self::Backend>, + T: FromRow + { + (*self).fetch_optional(query) + } +} diff --git a/src/lib.rs b/src/lib.rs index 15edcc8f..c8a1b012 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,11 +30,16 @@ pub mod mariadb; #[cfg(feature = "postgres")] pub mod pg; -mod client; mod connection; +mod executor; mod pool; mod query; -pub use self::{client::Client, connection::Connection, query::Query}; +pub use self::{ + connection::Connection, + pool::Pool, + query::{query, Query}, +}; +// TODO: Remove after Mariadb transitions to URIs mod options; diff --git a/src/pg/backend.rs b/src/pg/backend.rs index 59a2967d..f2ba8df7 100644 --- a/src/pg/backend.rs +++ b/src/pg/backend.rs @@ -3,7 +3,7 @@ use crate::backend::Backend; pub struct Pg; impl Backend for Pg { - type Connection = super::PgConnection; + type RawConnection = super::PgRawConnection; type Row = super::PgRow; } diff --git a/src/pg/connection/establish.rs b/src/pg/connection/establish.rs index f699c4b9..b26e29f4 100644 --- a/src/pg/connection/establish.rs +++ b/src/pg/connection/establish.rs @@ -1,9 +1,9 @@ -use super::PgConnection; +use super::PgRawConnection; use crate::pg::protocol::{Authentication, Message, PasswordMessage, StartupMessage}; -use std::{borrow::Cow, io}; +use std::io; use url::Url; -pub async fn establish<'a, 'b: 'a>(conn: &'a mut PgConnection, url: &'b Url) -> io::Result<()> { +pub async fn establish<'a, 'b: 'a>(conn: &'a mut PgRawConnection, url: &'b Url) -> io::Result<()> { let user = url.username(); let password = url.password().unwrap_or(""); let database = url.path().trim_start_matches('/'); diff --git a/src/pg/connection/execute.rs b/src/pg/connection/execute.rs index 42157d61..fff88791 100644 --- a/src/pg/connection/execute.rs +++ b/src/pg/connection/execute.rs @@ -1,41 +1,31 @@ -use super::prepare::Prepare; -use crate::postgres::protocol::{self, Message}; +use super::PgRawConnection; +use crate::pg::protocol::Message; use std::io; -impl<'a, 'b> Prepare<'a, 'b> { - pub async fn execute(self) -> io::Result { - let conn = self.finish(); +pub async fn execute(conn: &mut PgRawConnection) -> io::Result { + conn.flush().await?; - conn.flush().await?; + let mut rows = 0; - let mut rows = 0; + while let Some(message) = conn.receive().await? { + match message { + Message::BindComplete | Message::ParseComplete | Message::DataRow(_) => {} - while let Some(message) = conn.receive().await? { - match message { - Message::BindComplete | Message::ParseComplete => { - // Indicates successful completion of a phase - } + Message::CommandComplete(body) => { + rows = body.rows(); + } - Message::DataRow(_) => { - // This is EXECUTE so we are ignoring any potential results - } + Message::ReadyForQuery(_) => { + // Successful completion of the whole cycle + return Ok(rows); + } - Message::CommandComplete(body) => { - rows = body.rows(); - } - - Message::ReadyForQuery(_) => { - // Successful completion of the whole cycle - return Ok(rows); - } - - message => { - unimplemented!("received {:?} unimplemented message", message); - } + message => { + unimplemented!("received {:?} unimplemented message", message); } } - - // FIXME: This is an end-of-file error. How we should bubble this up here? - unreachable!() } + + // FIXME: This is an end-of-file error. How we should bubble this up here? 
+ unreachable!() } diff --git a/src/pg/connection/fetch.rs b/src/pg/connection/fetch.rs new file mode 100644 index 00000000..93537136 --- /dev/null +++ b/src/pg/connection/fetch.rs @@ -0,0 +1,37 @@ +use super::{PgRawConnection, PgRow}; +use crate::pg::protocol::Message; +use futures::stream::Stream; +use std::io; + +pub fn fetch<'a>( + conn: &'a mut PgRawConnection, +) -> impl Stream> + 'a { + async_stream::try_stream! { + conn.flush().await?; + + while let Some(message) = conn.receive().await? { + match message { + Message::BindComplete + | Message::ParseComplete + | Message::PortalSuspended + | Message::CloseComplete + | Message::CommandComplete(_) => {} + + Message::DataRow(body) => { + yield PgRow(body); + } + + Message::ReadyForQuery(_) => { + return; + } + + message => { + unimplemented!("received {:?} unimplemented message", message); + } + } + } + + // FIXME: This is an end-of-file error. How we should bubble this up here? + unreachable!() + } +} diff --git a/src/pg/connection/fetch_optional.rs b/src/pg/connection/fetch_optional.rs new file mode 100644 index 00000000..149549bb --- /dev/null +++ b/src/pg/connection/fetch_optional.rs @@ -0,0 +1,34 @@ +use super::{PgRawConnection, PgRow}; +use crate::pg::protocol::Message; +use std::io; + +pub async fn fetch_optional<'a>(conn: &'a mut PgRawConnection) -> io::Result> { + conn.flush().await?; + + let mut row: Option = None; + + while let Some(message) = conn.receive().await? { + match message { + Message::BindComplete + | Message::ParseComplete + | Message::PortalSuspended + | Message::CloseComplete + | Message::CommandComplete(_) => {} + + Message::DataRow(body) => { + row = Some(PgRow(body)); + } + + Message::ReadyForQuery(_) => { + return Ok(row); + } + + message => { + unimplemented!("received {:?} unimplemented message", message); + } + } + } + + // FIXME: This is an end-of-file error. How we should bubble this up here? + unreachable!() +} diff --git a/src/pg/connection/get.rs b/src/pg/connection/get.rs deleted file mode 100644 index ed28df43..00000000 --- a/src/pg/connection/get.rs +++ /dev/null @@ -1,61 +0,0 @@ -use super::prepare::Prepare; -use crate::{ - postgres::{ - protocol::{self, DataRow, Message}, - Postgres, - }, - row::{FromRow, Row}, - types::SqlType, -}; -use std::io; - -// TODO: Think through how best to handle null _rows_ and null _values_ - -impl<'a, 'b> Prepare<'a, 'b> { - #[inline] - pub async fn get(self) -> io::Result - where - T: FromRow, - { - Ok(T::from_row(self.get_raw().await?.unwrap())) - } - - // TODO: Better name? - // TODO: Should this be public? - async fn get_raw(self) -> io::Result>> { - let conn = self.finish(); - - conn.flush().await?; - - let mut row: Option> = None; - - while let Some(message) = conn.receive().await? { - match message { - Message::BindComplete - | Message::ParseComplete - | Message::PortalSuspended - | Message::CloseComplete => { - // Indicates successful completion of a phase - } - - Message::DataRow(body) => { - // note: because we used `EXECUTE 1` this will only execute once - row = Some(Row::(body)); - } - - Message::CommandComplete(_) => {} - - Message::ReadyForQuery(_) => { - return Ok(row); - } - - message => { - unimplemented!("received {:?} unimplemented message", message); - } - } - } - - // FIXME: This is an end-of-file error. How we should bubble this up here? 
- unreachable!() - } -} diff --git a/src/pg/connection/mod.rs b/src/pg/connection/mod.rs index ae255433..2ed75c11 100644 --- a/src/pg/connection/mod.rs +++ b/src/pg/connection/mod.rs @@ -1,23 +1,34 @@ use super::{ - protocol::{Encode, Message, Terminate}, - Pg, PgQuery, + protocol::{Authentication, Encode, Message, PasswordMessage, StartupMessage, Terminate}, + Pg, PgQuery, PgRow, }; -use crate::connection::{Connection, ConnectionAssocQuery}; +use crate::{connection::RawConnection, query::Query, row::FromRow}; use bytes::{BufMut, BytesMut}; use futures::{ future::BoxFuture, - io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}, ready, + stream::{self, BoxStream, Stream}, task::{Context, Poll}, Future, }; -use runtime::net::TcpStream; -use std::{fmt::Debug, io, pin::Pin}; +use std::{ + fmt::Debug, + io, + net::{IpAddr, Shutdown, SocketAddr}, + pin::Pin, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}, + net::TcpStream, +}; use url::Url; mod establish; +mod execute; +mod fetch; +mod fetch_optional; -pub struct PgConnection { +pub struct PgRawConnection { stream: TcpStream, // Do we think that there is data in the read buffer to be decoded @@ -40,15 +51,19 @@ pub struct PgConnection { secret_key: u32, } -impl PgConnection { - pub async fn establish(url: &str) -> io::Result { +impl PgRawConnection { + async fn establish(url: &str) -> io::Result { // TODO: Handle errors let url = Url::parse(url).unwrap(); let host = url.host_str().unwrap_or("localhost"); let port = url.port().unwrap_or(5432); - let stream = TcpStream::connect((host, port)).await?; + // FIXME: handle errors + let host: IpAddr = host.parse().unwrap(); + let addr: SocketAddr = (host, port).into(); + + let stream = TcpStream::connect(&addr).await?; let mut conn = Self { wbuf: Vec::with_capacity(1024), rbuf: BytesMut::with_capacity(1024 * 8), @@ -64,16 +79,16 @@ impl PgConnection { Ok(conn) } - pub async fn close(mut self) -> io::Result<()> { + async fn finalize(&mut self) -> io::Result<()> { self.write(Terminate); self.flush().await?; - self.stream.close().await?; + self.stream.shutdown(Shutdown::Both)?; Ok(()) } // Wait and return the next message to be received from Postgres. - pub(super) async fn receive(&mut self) -> io::Result> { + async fn receive(&mut self) -> io::Result> { loop { if self.stream_eof { // Reached end-of-file on a previous read call. 
@@ -131,7 +146,7 @@ impl PgConnection { message.encode(&mut self.wbuf); } - pub(super) async fn flush(&mut self) -> io::Result<()> { + async fn flush(&mut self) -> io::Result<()> { self.stream.write_all(&self.wbuf).await?; self.wbuf.clear(); @@ -139,20 +154,46 @@ impl PgConnection { } } -impl<'c, 'q> ConnectionAssocQuery<'c, 'q> for PgConnection { - type Query = PgQuery<'c, 'q>; -} - -impl Connection for PgConnection { +impl RawConnection for PgRawConnection { type Backend = Pg; #[inline] fn establish(url: &str) -> BoxFuture> { - Box::pin(PgConnection::establish(url)) + Box::pin(PgRawConnection::establish(url)) } #[inline] - fn prepare<'c, 'q>(&'c mut self, query: &'q str) -> PgQuery<'c, 'q> { - PgQuery::new(self, query) + fn finalize<'c>(&'c mut self) -> BoxFuture<'c, io::Result<()>> { + Box::pin(self.finalize()) + } + + fn execute<'c, 'q, Q: 'q>(&'c mut self, query: Q) -> BoxFuture<'c, io::Result> + where + Q: Query<'q, Backend = Self::Backend>, + { + query.finish(self); + + Box::pin(execute::execute(self)) + } + + fn fetch<'c, 'q, Q: 'q>(&'c mut self, query: Q) -> BoxStream<'c, io::Result> + where + Q: Query<'q, Backend = Self::Backend>, + { + query.finish(self); + + Box::pin(fetch::fetch(self)) + } + + fn fetch_optional<'c, 'q, Q: 'q>( + &'c mut self, + query: Q, + ) -> BoxFuture<'c, io::Result>> + where + Q: Query<'q, Backend = Self::Backend>, + { + query.finish(self); + + Box::pin(fetch_optional::fetch_optional(self)) } } diff --git a/src/pg/connection/prepare.rs b/src/pg/connection/prepare.rs deleted file mode 100644 index 02e3e690..00000000 --- a/src/pg/connection/prepare.rs +++ /dev/null @@ -1,69 +0,0 @@ -use super::RawConnection; -use crate::{ - postgres::{ - protocol::{self, BindValues}, - Postgres, - }, - serialize::ToSql, - types::{AsSql, SqlType}, -}; - -pub struct Prepare<'a, 'b> { - query: &'b str, - pub(super) connection: &'a mut RawConnection, - pub(super) bind: BindValues, -} - -#[inline] -pub fn prepare<'a, 'b>(connection: &'a mut RawConnection, query: &'b str) -> Prepare<'a, 'b> { - // TODO: Use a hash map to cache the parse - // TODO: Use named statements - Prepare { - connection, - query, - bind: BindValues::new(), - } -} - -impl<'a, 'b> Prepare<'a, 'b> { - #[inline] - pub fn bind>(mut self, value: T) -> Self - where - T: ToSql>::Type>, - { - self.bind.add(value); - self - } - - #[inline] - pub fn bind_as, T: ToSql>(mut self, value: T) -> Self { - self.bind.add_as::(value); - self - } - - pub(super) fn finish(self) -> &'a mut RawConnection { - self.connection.write(protocol::Parse { - portal: "", - query: self.query, - param_types: self.bind.types(), - }); - - self.connection.write(protocol::Bind { - portal: "", - statement: "", - formats: self.bind.formats(), - values_len: self.bind.values_len(), - values: self.bind.values(), - result_formats: &[1], - }); - - self.connection.write(protocol::Execute { - portal: "", - limit: 0, - }); - - self.connection.write(protocol::Sync); - - self.connection - } -} diff --git a/src/pg/connection/select.rs b/src/pg/connection/select.rs deleted file mode 100644 index 12b71b3a..00000000 --- a/src/pg/connection/select.rs +++ /dev/null @@ -1,67 +0,0 @@ -use super::prepare::Prepare; -use crate::{ - postgres::{ - protocol::{self, DataRow, Message}, - Postgres, - }, - row::{FromRow, Row}, -}; -use futures::{stream, Stream, TryStreamExt}; -use std::io; - -impl<'a, 'b> Prepare<'a, 'b> { - #[inline] - pub fn select( - self, - ) -> impl Stream> + 'a + Unpin - where - T: FromRow, - { - self.select_raw().map_ok(T::from_row) - } - - // TODO: 
Better name? - // TODO: Should this be public? - fn select_raw(self) -> impl Stream, io::Error>> + 'a + Unpin { - // FIXME: Manually implement Stream on a new type to avoid the unfold adapter - stream::unfold(self.finish(), |conn| { - Box::pin(async { - if !conn.wbuf.is_empty() { - if let Err(e) = conn.flush().await { - return Some((Err(e), conn)); - } - } - - loop { - let message = match conn.receive().await { - Ok(Some(message)) => message, - // FIXME: This is an end-of-file error. How we should bubble this up here? - Ok(None) => unreachable!(), - Err(e) => return Some((Err(e), conn)), - }; - - match message { - Message::BindComplete | Message::ParseComplete => { - // Indicates successful completion of a phase - } - - Message::DataRow(row) => { - break Some((Ok(Row::(row)), conn)); - } - - Message::CommandComplete(_) => {} - - Message::ReadyForQuery(_) => { - // Successful completion of the whole cycle - break None; - } - - message => { - unimplemented!("received {:?} unimplemented message", message); - } - } - } - }) - }) - } -} diff --git a/src/pg/mod.rs b/src/pg/mod.rs index dc663cb2..6a03e04a 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -5,4 +5,4 @@ mod query; mod row; pub mod types; -pub use self::{backend::Pg, connection::PgConnection, query::PgQuery, row::PgRow}; +pub use self::{backend::Pg, connection::PgRawConnection, query::PgQuery, row::PgRow}; diff --git a/src/pg/protocol/bind.rs b/src/pg/protocol/bind.rs index feaf1f85..80295a03 100644 --- a/src/pg/protocol/bind.rs +++ b/src/pg/protocol/bind.rs @@ -50,32 +50,3 @@ impl Encode for Bind<'_> { BigEndian::write_i32(&mut buf[pos..], len as i32); } } - -#[cfg(test)] -mod tests { - use super::{Bind, BindCollector, BufMut, Encode}; - - const BIND: &[u8] = b"B\0\0\0\x18\0\0\0\x01\0\x01\0\x02\0\0\0\x011\0\0\0\x012\0\0"; - - #[test] - fn it_encodes_bind_for_two() { - let mut buf = Vec::new(); - - let mut builder = BindCollector::new(); - builder.add("1"); - builder.add("2"); - - let bind = Bind { - portal: "", - statement: "", - formats: builder.formats(), - values_len: builder.values_len(), - values: builder.values(), - result_formats: &[], - }; - - bind.encode(&mut buf); - - assert_eq!(buf, BIND); - } -} diff --git a/src/pg/protocol/data_row.rs b/src/pg/protocol/data_row.rs index 396c49e1..52169a1d 100644 --- a/src/pg/protocol/data_row.rs +++ b/src/pg/protocol/data_row.rs @@ -81,7 +81,6 @@ impl Debug for DataRow { #[cfg(test)] mod tests { use super::{DataRow, Decode}; - use crate::row::RawRow; use bytes::Bytes; use std::io; diff --git a/src/pg/query.rs b/src/pg/query.rs index 1d35d05b..d1100edd 100644 --- a/src/pg/query.rs +++ b/src/pg/query.rs @@ -1,6 +1,6 @@ use super::{ protocol::{self, BufMut, Message}, - Pg, PgConnection, PgRow, + Pg, PgRawConnection, PgRow, }; use crate::{ query::Query, @@ -16,8 +16,8 @@ use futures::{ }; use std::io; -pub struct PgQuery<'c, 'q> { - conn: &'c mut PgConnection, +pub struct PgQuery<'q> { + limit: i32, query: &'q str, // OIDs of the bind parameters types: Vec, @@ -25,44 +25,20 @@ pub struct PgQuery<'c, 'q> { buf: Vec, } -impl<'c, 'q> PgQuery<'c, 'q> { - pub(super) fn new(conn: &'c mut PgConnection, query: &'q str) -> Self { +impl<'q> Query<'q> for PgQuery<'q> { + type Backend = Pg; + + fn new(query: &'q str) -> Self { Self { + limit: 0, query, - conn, - types: Vec::new(), - buf: Vec::new(), + // Estimates for average number of bind parameters were + // chosen from sampling from internal projects + types: Vec::with_capacity(4), + buf: Vec::with_capacity(32), } } - fn finish(self, limit: i32) -> 
&'c mut PgConnection { - self.conn.write(protocol::Parse { - portal: "", - query: self.query, - param_types: &*self.types, - }); - - self.conn.write(protocol::Bind { - portal: "", - statement: "", - formats: &[1], // [BINARY] - // TODO: Early error if there is more than i16 - values_len: self.types.len() as i16, - values: &*self.buf, - result_formats: &[1], // [BINARY] - }); - - self.conn.write(protocol::Execute { portal: "", limit }); - - self.conn.write(protocol::Sync); - - self.conn - } -} - -impl<'c, 'q> Query<'c, 'q> for PgQuery<'c, 'q> { - type Backend = Pg; - fn bind_as(mut self, value: T) -> Self where Self: Sized, @@ -91,132 +67,29 @@ impl<'c, 'q> Query<'c, 'q> for PgQuery<'c, 'q> { self } - #[inline] - fn execute(self) -> BoxFuture<'c, io::Result> { - Box::pin(execute(self.finish(0))) - } + fn finish(self, conn: &mut PgRawConnection) { + conn.write(protocol::Parse { + portal: "", + query: self.query, + param_types: &*self.types, + }); - #[inline] - fn fetch(self) -> BoxStream<'c, io::Result> - where - T: FromRow, - { - Box::pin(fetch(self.finish(0))) - } + conn.write(protocol::Bind { + portal: "", + statement: "", + formats: &[1], // [BINARY] + // TODO: Early error if there is more than i16 + values_len: self.types.len() as i16, + values: &*self.buf, + result_formats: &[1], // [BINARY] + }); - #[inline] - fn fetch_optional(self) -> BoxFuture<'c, io::Result>> - where - T: FromRow, - { - Box::pin(fetch_optional(self.finish(1))) + // TODO: Make limit be 1 for fetch_optional + conn.write(protocol::Execute { + portal: "", + limit: self.limit, + }); + + conn.write(protocol::Sync); } } - -async fn execute(conn: &mut PgConnection) -> io::Result { - conn.flush().await?; - - let mut rows = 0; - - while let Some(message) = conn.receive().await? { - match message { - Message::BindComplete | Message::ParseComplete | Message::DataRow(_) => {} - - Message::CommandComplete(body) => { - rows = body.rows(); - } - - Message::ReadyForQuery(_) => { - // Successful completion of the whole cycle - return Ok(rows); - } - - message => { - unimplemented!("received {:?} unimplemented message", message); - } - } - } - - // FIXME: This is an end-of-file error. How we should bubble this up here? - unreachable!() -} - -async fn fetch_optional<'a, A: 'a, T: 'a>(conn: &'a mut PgConnection) -> io::Result> -where - T: FromRow, -{ - conn.flush().await?; - - let mut row: Option = None; - - while let Some(message) = conn.receive().await? { - match message { - Message::BindComplete - | Message::ParseComplete - | Message::PortalSuspended - | Message::CloseComplete - | Message::CommandComplete(_) => {} - - Message::DataRow(body) => { - row = Some(PgRow(body)); - } - - Message::ReadyForQuery(_) => { - return Ok(row.map(T::from_row)); - } - - message => { - unimplemented!("received {:?} unimplemented message", message); - } - } - } - - // FIXME: This is an end-of-file error. How we should bubble this up here? - unreachable!() -} - -fn fetch<'a, A: 'a, T: 'a>( - conn: &'a mut PgConnection, -) -> impl Stream> + 'a + Unpin -where - T: FromRow, -{ - // FIXME: Manually implement Stream on a new type to avoid the unfold adapter - stream::unfold(conn, |conn| { - Box::pin(async { - if !conn.wbuf.is_empty() { - if let Err(e) = conn.flush().await { - return Some((Err(e), conn)); - } - } - - loop { - let message = match conn.receive().await { - Ok(Some(message)) => message, - // FIXME: This is an end-of-file error. How we should bubble this up here? 
- Ok(None) => unreachable!(), - Err(e) => return Some((Err(e), conn)), - }; - - match message { - Message::BindComplete - | Message::ParseComplete - | Message::CommandComplete(_) => {} - - Message::DataRow(row) => { - break Some((Ok(T::from_row(PgRow(row))), conn)); - } - - Message::ReadyForQuery(_) => { - // Successful completion of the whole cycle - break None; - } - - message => { - unimplemented!("received {:?} unimplemented message", message); - } - } - } - }) - }) -} diff --git a/src/pool.rs b/src/pool.rs index 6a7e3fb9..8be8de6f 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,7 +1,16 @@ -use crate::{backend::Backend, Connection}; +use crate::{ + backend::Backend, connection::RawConnection, executor::Executor, query::Query, row::FromRow, + Connection, +}; use crossbeam_queue::{ArrayQueue, SegQueue}; -use futures::{channel::oneshot, TryFutureExt}; +use futures::{ + channel::oneshot, + future::BoxFuture, + stream::{self, BoxStream, Stream, StreamExt}, + TryFutureExt, +}; use std::{ + future::Future, io, ops::{Deref, DerefMut}, sync::{ @@ -12,9 +21,6 @@ use std::{ }; use url::Url; -// TODO: Reap old connections -// TODO: Clean up (a lot) and document what's going on - pub struct PoolOptions { pub max_size: usize, pub min_idle: Option, @@ -24,59 +30,59 @@ pub struct PoolOptions { } /// A database connection pool. -pub struct Pool { - inner: Arc>, +pub struct Pool(Arc>) +where + DB: Backend; + +impl Clone for Pool +where + DB: Backend, +{ + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } } -struct InnerPool { +impl Pool +where + DB: Backend, +{ + // TODO: PoolBuilder + pub fn new<'a>(url: &str, max_size: usize) -> Self { + Self(Arc::new(SharedPool { + url: url.to_owned(), + idle: ArrayQueue::new(max_size), + total: AtomicUsize::new(0), + waiters: SegQueue::new(), + options: PoolOptions { + idle_timeout: None, + connection_timeout: None, + max_lifetime: None, + max_size, + min_idle: None, + }, + })) + } +} + +struct SharedPool +where + DB: Backend, +{ url: String, - idle: ArrayQueue>, - waiters: SegQueue>>, - // Total count of connections managed by this connection pool + idle: ArrayQueue>, + waiters: SegQueue>>, total: AtomicUsize, options: PoolOptions, } -pub struct PoolConnection { - connection: Option>, - pool: Arc>, -} - -impl Clone for Pool { - fn clone(&self) -> Self { - Self { - inner: Arc::clone(&self.inner), - } - } -} - -impl Pool { - pub fn new<'a>(url: &str, options: PoolOptions) -> Self { - Self { - inner: Arc::new(InnerPool { - url: url.to_owned(), - idle: ArrayQueue::new(options.max_size), - total: AtomicUsize::new(0), - waiters: SegQueue::new(), - options, - }), - } - } - - pub async fn acquire(&self) -> io::Result> { - self.inner - .acquire() - .map_ok(|live| PoolConnection::new(live, &self.inner)) - .await - } -} - -impl InnerPool { - async fn acquire(&self) -> io::Result> { +impl SharedPool +where + DB: Backend, +{ + async fn acquire(&self) -> io::Result> { if let Ok(idle) = self.idle.pop() { - log::debug!("acquire: found idle connection"); - - return Ok(idle.connection); + return Ok(idle.live); } let total = self.total.load(Ordering::SeqCst); @@ -84,93 +90,159 @@ impl InnerPool { if total >= self.options.max_size { // Too many already, add a waiter and wait for // a free connection - log::debug!("acquire: too many open connections; waiting for a free connection"); - let (sender, reciever) = oneshot::channel(); self.waiters.push(sender); // TODO: Handle errors here - let live = reciever.await.unwrap(); - - log::debug!("acquire: free connection now 
available"); - - return Ok(live); + return Ok(reciever.await.unwrap()); } self.total.store(total + 1, Ordering::SeqCst); - log::debug!("acquire: no idle connections; establish new connection"); - let connection = Conn::establish(&self.url).await?; + let raw = ::RawConnection::establish(&self.url).await?; let live = Live { - connection, + raw, since: Instant::now(), }; Ok(live) } - fn release(&self, mut connection: Live) { + fn release(&self, mut live: Live) { while let Ok(waiter) = self.waiters.pop() { - connection = match waiter.send(connection) { + live = match waiter.send(live) { Ok(()) => { return; } - Err(connection) => connection, + Err(live) => live, }; } let _ = self.idle.push(Idle { - connection, + live, since: Instant::now(), }); } } -impl PoolConnection { - fn new(connection: Live, pool: &Arc>) -> Self { +impl Executor for Pool +where + DB: Backend, +{ + type Backend = DB; + + fn execute<'c, 'q, Q: 'q + 'c>(&'c self, query: Q) -> BoxFuture<'c, io::Result> + where + Q: Query<'q, Backend = Self::Backend>, + { + Box::pin(async move { + let live = self.0.acquire().await?; + let mut conn = PooledConnection::new(&self.0, live); + + conn.execute(query).await + }) + } + + fn fetch<'c, 'q, A: 'c, T: 'c, Q: 'q + 'c>(&'c self, query: Q) -> BoxStream<'c, io::Result> + where + Q: Query<'q, Backend = Self::Backend>, + T: FromRow + Send + Unpin, + { + Box::pin(async_stream::try_stream! { + let live = self.0.acquire().await?; + let mut conn = PooledConnection::new(&self.0, live); + let mut s = conn.fetch(query); + + while let Some(row) = s.next().await.transpose()? { + yield T::from_row(row); + } + }) + } + + fn fetch_optional<'c, 'q, A: 'c, T: 'c, Q: 'q + 'c>( + &'c self, + query: Q, + ) -> BoxFuture<'c, io::Result>> + where + Q: Query<'q, Backend = Self::Backend>, + T: FromRow, + { + Box::pin(async move { + let live = self.0.acquire().await?; + let mut conn = PooledConnection::new(&self.0, live); + let row = conn.fetch_optional(query).await?; + + Ok(row.map(T::from_row)) + }) + } +} + +struct PooledConnection<'a, DB> +where + DB: Backend, +{ + shared: &'a Arc>, + live: Option>, +} + +impl<'a, DB> PooledConnection<'a, DB> +where + DB: Backend, +{ + fn new(shared: &'a Arc>, live: Live) -> Self { Self { - connection: Some(connection), - pool: Arc::clone(pool), + shared, + live: Some(live), } } } -impl Deref for PoolConnection { - type Target = Conn; +impl Deref for PooledConnection<'_, DB> +where + DB: Backend, +{ + type Target = DB::RawConnection; - #[inline] fn deref(&self) -> &Self::Target { - // PANIC: Will not panic unless accessed after drop - &self.connection.as_ref().unwrap().connection + &self.live.as_ref().expect("connection use after drop").raw } } -impl DerefMut for PoolConnection { - #[inline] +impl DerefMut for PooledConnection<'_, DB> +where + DB: Backend, +{ fn deref_mut(&mut self) -> &mut Self::Target { - // PANIC: Will not panic unless accessed after drop - &mut self.connection.as_mut().unwrap().connection + &mut self.live.as_mut().expect("connection use after drop").raw } } -impl Drop for PoolConnection { +impl Drop for PooledConnection<'_, DB> +where + DB: Backend, +{ fn drop(&mut self) { - log::debug!("release: dropping connection; store back in queue"); - if let Some(connection) = self.connection.take() { - self.pool.release(connection); + if let Some(live) = self.live.take() { + self.shared.release(live); } } } -struct Idle { - connection: Live, +struct Idle +where + DB: Backend, +{ + live: Live, since: Instant, } -struct Live { - connection: Conn, +struct Live +where + 
DB: Backend, +{ + raw: DB::RawConnection, since: Instant, } diff --git a/src/query.rs b/src/query.rs index ce0d642e..2f26eff5 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,5 +1,7 @@ use crate::{ backend::Backend, + executor::Executor, + pool::Pool, row::FromRow, serialize::ToSql, types::{AsSqlType, HasSqlType}, @@ -7,13 +9,14 @@ use crate::{ use futures::{future::BoxFuture, stream::BoxStream}; use std::io; -pub trait Query<'c, 'q> { +pub trait Query<'q>: Sized + Send + Sync { type Backend: Backend; + fn new(query: &'q str) -> Self; + #[inline] fn bind(self, value: T) -> Self where - Self: Sized, Self::Backend: HasSqlType<>::SqlType>, T: AsSqlType + ToSql<>::SqlType, Self::Backend>, @@ -23,17 +26,48 @@ pub trait Query<'c, 'q> { fn bind_as(self, value: T) -> Self where - Self: Sized, Self::Backend: HasSqlType, T: ToSql; - fn execute(self) -> BoxFuture<'c, io::Result>; + fn finish(self, conn: &mut ::RawConnection); - fn fetch(self) -> BoxStream<'c, io::Result> + #[inline] + fn execute<'c, C>(self, executor: &'c C) -> BoxFuture<'c, io::Result> where - T: FromRow; + Self: 'c + 'q, + C: Executor, + { + executor.execute(self) + } - fn fetch_optional(self) -> BoxFuture<'c, io::Result>> + #[inline] + fn fetch<'c, A: 'c, T: 'c, C>(self, executor: &'c C) -> BoxStream<'c, io::Result> where - T: FromRow; + Self: 'c + 'q, + C: Executor, + T: FromRow + Send + Unpin, + { + executor.fetch(self) + } + + #[inline] + fn fetch_optional<'c, A: 'c, T: 'c, C>( + self, + executor: &'c C, + ) -> BoxFuture<'c, io::Result>> + where + Self: 'c + 'q, + C: Executor, + T: FromRow, + { + executor.fetch_optional(self) + } +} + +#[inline] +pub fn query<'q, Q>(query: &'q str) -> Q +where + Q: Query<'q>, +{ + Q::new(query) } diff --git a/src/row.rs b/src/row.rs index 1c1358eb..8d6a4988 100644 --- a/src/row.rs +++ b/src/row.rs @@ -1,6 +1,6 @@ use crate::{backend::Backend, deserialize::FromSql, types::HasSqlType}; -pub trait Row { +pub trait Row: Send { type Backend: Backend; fn is_empty(&self) -> bool;
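
For quick reference, the shape of the API this patch lands, condensed from the todos and contacts examples above into one minimal sketch. It assumes the nightly toolchain and the alpha tokio / futures-preview crates pinned in this diff, plus a `tasks` table like the one the todos example creates; the connection URL is the same local placeholder used throughout the patch, not something this sketch provisions.

    #![feature(async_await)] // needed by the 2019-era alpha crates pinned in this patch

    use futures::TryStreamExt;
    use sqlx::{
        pg::{Pg, PgQuery},
        Pool, Query,
    };
    use std::io;

    type PgPool = Pool<Pg>;

    #[tokio::main]
    async fn main() -> io::Result<()> {
        // A pool capped at one connection, mirroring the contacts example.
        let pool = PgPool::new("postgres://postgres@127.0.0.1/sqlx__dev", 1);

        // `execute` runs against any `Executor` (pool or connection) and
        // resolves to the number of affected rows.
        sqlx::query::<PgQuery>("INSERT INTO tasks ( text ) VALUES ( $1 )")
            .bind("try the new API")
            .execute(&pool)
            .await?;

        // `fetch` streams rows decoded through `FromRow`; collecting into a
        // typed Vec drives the row-type inference, as in the contacts example.
        let tasks: Vec<(i64, String)> = sqlx::query::<PgQuery>("SELECT id, text FROM tasks")
            .fetch(&pool)
            .try_collect()
            .await?;

        println!("{} task(s)", tasks.len());

        Ok(())
    }

Both `Pool<Pg>` and `Connection<Pg>` implement the new `Executor` trait, so the same `execute` / `fetch` / `fetch_optional` calls accept either one; moving query construction out of the connection and into `sqlx::query(...)` is what makes that possible.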