Add connection string parsing for postgres, run rustfmt over mariadb, use constants for ErrorCode

This commit is contained in:
Ryan Leckey 2019-08-09 21:46:49 -07:00
parent feaa209c33
commit 78b3ae4a19
22 changed files with 1150 additions and 1136 deletions

View File

@ -1,7 +1,7 @@
[workspace]
members = [
".",
"examples/todo",
"examples/todos",
"examples/contacts"
]
@ -31,6 +31,7 @@ hex = "0.3.2"
itoa = "0.4.4"
log = "0.4.8"
md-5 = "0.8.0"
url = "2.1.0"
memchr = "2.2.1"
runtime = { version = "=0.3.0-alpha.6", default-features = false }

View File

@ -10,7 +10,7 @@ use fake::{
Dummy, Fake, Faker,
};
use futures::future;
use sqlx::{pool::Pool, postgres::Postgres};
use sqlx::{Pool, Postgres};
#[derive(Debug, Dummy)]
struct Contact {
@ -34,13 +34,7 @@ struct Contact {
async fn main() -> Fallible<()> {
env_logger::try_init()?;
let options = sqlx::ConnectOptions::new()
.host("127.0.0.1")
.port(5432)
.user("postgres")
.database("sqlx__dev__contacts");
let pool = Pool::<Postgres>::new(options);
let pool = Pool::<Postgres>::new("postgres://postgres@localhost/sqlx__dev");
{
let mut conn = pool.acquire().await?;

View File

@ -1,5 +1,5 @@
[package]
name = "todo"
name = "todos"
version = "0.1.0"
edition = "2018"

View File

@ -2,7 +2,7 @@
use failure::Fallible;
use futures::{future, TryStreamExt};
use sqlx::postgres::Connection;
use sqlx::{Connection, Postgres};
use structopt::StructOpt;
#[derive(StructOpt, Debug)]
@ -26,14 +26,8 @@ async fn main() -> Fallible<()> {
let opt = Options::from_args();
let mut conn = Connection::establish(
sqlx::ConnectOptions::new()
.host("127.0.0.1")
.port(5432)
.user("postgres")
.database("sqlx__dev__tasks"),
)
.await?;
let mut conn =
Connection::<Postgres>::establish("postgres://postgres@localhost/sqlx__dev").await?;
ensure_schema(&mut conn).await?;
@ -54,7 +48,7 @@ async fn main() -> Fallible<()> {
Ok(())
}
async fn ensure_schema(conn: &mut Connection) -> Fallible<()> {
async fn ensure_schema(conn: &mut Connection<Postgres>) -> Fallible<()> {
conn.prepare("BEGIN").execute().await?;
// language=sql
@ -76,7 +70,7 @@ CREATE TABLE IF NOT EXISTS tasks (
Ok(())
}
async fn print_all_tasks(conn: &mut Connection) -> Fallible<()> {
async fn print_all_tasks(conn: &mut Connection<Postgres>) -> Fallible<()> {
// language=sql
conn.prepare(
r#"
@ -97,7 +91,7 @@ WHERE done_at IS NULL
Ok(())
}
async fn add_task(conn: &mut Connection, text: &str) -> Fallible<()> {
async fn add_task(conn: &mut Connection<Postgres>, text: &str) -> Fallible<()> {
// language=sql
conn.prepare(
r#"
@ -112,7 +106,7 @@ VALUES ( $1 )
Ok(())
}
async fn mark_task_as_done(conn: &mut Connection, id: i64) -> Fallible<()> {
async fn mark_task_as_done(conn: &mut Connection<Postgres>, id: i64) -> Fallible<()> {
// language=sql
conn.prepare(
r#"

View File

@ -1,9 +1,57 @@
use crate::{backend::Backend, ConnectOptions};
use crate::backend::Backend;
use futures::future::BoxFuture;
use std::io;
use std::{
io,
ops::{Deref, DerefMut},
};
use url::Url;
// TODO: Re-implement and forward to Raw instead of using Deref
pub trait RawConnection {
fn establish(options: ConnectOptions<'_>) -> BoxFuture<io::Result<Self>>
fn establish(url: &Url) -> BoxFuture<io::Result<Self>>
where
Self: Sized;
}
pub struct Connection<B>
where
B: Backend,
{
pub(crate) inner: B::RawConnection,
}
impl<B> Connection<B>
where
B: Backend,
{
#[inline]
pub async fn establish(url: &str) -> io::Result<Self> {
// TODO: Handle url parse errors
let url = Url::parse(url).unwrap();
Ok(Self {
inner: B::RawConnection::establish(&url).await?,
})
}
}
impl<B> Deref for Connection<B>
where
B: Backend,
{
type Target = B::RawConnection;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<B> DerefMut for Connection<B>
where
B: Backend,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}

View File

@ -1,4 +1,4 @@
#![feature(non_exhaustive, async_await, async_closure)]
#![feature(async_await)]
#![cfg_attr(test, feature(test))]
#![allow(clippy::needless_lifetimes)]
// FIXME: Remove this once API has matured
@ -12,12 +12,8 @@ extern crate bitflags;
#[macro_use]
extern crate enum_tryfrom_derive;
mod options;
pub use self::options::ConnectOptions;
// Helper macro for writing long complex tests
#[macro_use]
pub mod macros;
mod macros;
pub mod backend;
pub mod deserialize;
@ -34,6 +30,12 @@ pub mod mariadb;
#[cfg(feature = "postgres")]
mod postgres;
// TODO: This module is not intended to be directly public
#[cfg(feature = "postgres")]
pub use self::postgres::Postgres;
pub mod connection;
pub mod pool;
pub use self::{connection::Connection, pool::Pool};
mod options;

View File

@ -2,9 +2,9 @@ use super::Connection;
use crate::{
mariadb::{
Capabilities, ComStmtExec, DeContext, Decode, EofPacket, ErrPacket,
HandshakeResponsePacket, InitialHandshakePacket, OkPacket, StmtExecFlag, ProtocolType
HandshakeResponsePacket, InitialHandshakePacket, OkPacket, ProtocolType, StmtExecFlag,
},
ConnectOptions,
options::ConnectOptions,
};
use bytes::{BufMut, Bytes};
use failure::{err_msg, Error};
@ -169,13 +169,13 @@ mod test {
match ctx.decoder.peek_tag() {
0xFF => {
ErrPacket::decode(&mut ctx)?;
},
}
0x00 => {
OkPacket::decode(&mut ctx)?;
},
}
_ => {
ResultSet::deserialize(ctx, ProtocolType::Binary).await?;
},
}
}
Ok(())

View File

@ -2,9 +2,9 @@ use crate::{
mariadb::{
protocol::encode, Capabilities, ComInitDb, ComPing, ComQuery, ComQuit, ComStmtPrepare,
ComStmtPrepareResp, DeContext, Decode, Decoder, Encode, ErrPacket, OkPacket, PacketHeader,
ResultSet, ServerStatusFlag, ProtocolType
ProtocolType, ResultSet, ServerStatusFlag,
},
ConnectOptions,
options::ConnectOptions,
};
use byteorder::{ByteOrder, LittleEndian};
use bytes::{Bytes, BytesMut};

View File

@ -9,6 +9,7 @@ pub use protocol::{
ComSetOption, ComShutdown, ComSleep, ComStatistics, ComStmtClose, ComStmtExec, ComStmtFetch,
ComStmtPrepare, ComStmtPrepareOk, ComStmtPrepareResp, DeContext, Decode, Decoder, Encode,
EofPacket, ErrPacket, ErrorCode, FieldDetailFlag, FieldType, HandshakeResponsePacket,
InitialHandshakePacket, OkPacket, PacketHeader, ResultRowText, ResultRowBinary, ResultRow, ResultSet, SSLRequestPacket,
ServerStatusFlag, SessionChangeType, SetOptionOptions, ShutdownOptions, StmtExecFlag, ProtocolType
InitialHandshakePacket, OkPacket, PacketHeader, ProtocolType, ResultRow, ResultRowBinary,
ResultRowText, ResultSet, SSLRequestPacket, ServerStatusFlag, SessionChangeType,
SetOptionOptions, ShutdownOptions, StmtExecFlag,
};

File diff suppressed because it is too large Load Diff

View File

@ -19,8 +19,8 @@ pub use packets::{
ComProcessKill, ComQuery, ComQuit, ComResetConnection, ComSetOption, ComShutdown, ComSleep,
ComStatistics, ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk,
ComStmtPrepareResp, ComStmtReset, EofPacket, ErrPacket, HandshakeResponsePacket,
InitialHandshakePacket, OkPacket, PacketHeader, ResultRowText, ResultRowBinary, ResultSet, SSLRequestPacket,
SetOptionOptions, ShutdownOptions, ResultRow
InitialHandshakePacket, OkPacket, PacketHeader, ResultRow, ResultRowBinary, ResultRowText,
ResultSet, SSLRequestPacket, SetOptionOptions, ShutdownOptions,
};
pub use decode::{DeContext, Decode, Decoder};
@ -30,5 +30,6 @@ pub use encode::{BufMut, Encode};
pub use error_codes::ErrorCode;
pub use types::{
ProtocolType, Capabilities, FieldDetailFlag, FieldType, ServerStatusFlag, SessionChangeType, StmtExecFlag,
Capabilities, FieldDetailFlag, FieldType, ProtocolType, ServerStatusFlag, SessionChangeType,
StmtExecFlag,
};

View File

@ -1,5 +1,6 @@
use crate::mariadb::{
BufMut, ColumnDefPacket, ConnContext, Connection, Encode, FieldDetailFlag, StmtExecFlag, FieldType
BufMut, ColumnDefPacket, ConnContext, Connection, Encode, FieldDetailFlag, FieldType,
StmtExecFlag,
};
use bytes::Bytes;
use failure::Error;

View File

@ -110,7 +110,7 @@ impl crate::mariadb::Decode for ResultRow {
Ok(ResultRow {
length,
seq_no,
columns
columns,
})
}
}

View File

@ -30,7 +30,7 @@ impl Decode for ErrPacket {
panic!("Packet header is not 0xFF for ErrPacket");
}
let error_code = ErrorCode::try_from(decoder.decode_int_i16())?;
let error_code = ErrorCode(decoder.decode_int_u16());
let mut stage = None;
let mut max_stage = None;
@ -42,7 +42,7 @@ impl Decode for ErrPacket {
let mut error_message = None;
// Progress Reporting
if error_code as u16 == 0xFFFF {
if error_code.0 == 0xFFFF {
stage = Some(decoder.decode_int_u8());
max_stage = Some(decoder.decode_int_u8());
progress = Some(decoder.decode_int_i24());

View File

@ -28,11 +28,11 @@ pub use ssl_request::SSLRequestPacket;
pub use text::{
ComDebug, ComInitDb, ComPing, ComProcessKill, ComQuery, ComQuit, ComResetConnection,
ComSetOption, ComShutdown, ComSleep, ComStatistics, SetOptionOptions, ShutdownOptions,
ResultRow as ResultRowText
ComSetOption, ComShutdown, ComSleep, ComStatistics, ResultRow as ResultRowText,
SetOptionOptions, ShutdownOptions,
};
pub use binary::{
ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, ComStmtPrepareResp,
ComStmtReset, ResultRow as ResultRowBinary
ComStmtReset, ResultRow as ResultRowBinary,
};

View File

@ -1,10 +1,10 @@
use crate::mariadb::{ResultRowText, ResultRowBinary};
use crate::mariadb::{ResultRowBinary, ResultRowText};
#[derive(Debug)]
pub struct ResultRow {
pub length: u32,
pub seq_no: u8,
pub columns: Vec<Option<bytes::Bytes>>
pub columns: Vec<Option<bytes::Bytes>>,
}
impl From<ResultRowText> for ResultRow {
@ -17,7 +17,6 @@ impl From<ResultRowText> for ResultRow {
}
}
impl From<ResultRowBinary> for ResultRow {
fn from(row: ResultRowBinary) -> Self {
ResultRow {

View File

@ -3,7 +3,8 @@ use failure::Error;
use crate::mariadb::{
Capabilities, ColumnDefPacket, ColumnPacket, ConnContext, DeContext, Decode, Decoder,
EofPacket, ErrPacket, Framed, OkPacket, ResultRowText, ResultRowBinary, ProtocolType, ResultRow
EofPacket, ErrPacket, Framed, OkPacket, ProtocolType, ResultRow, ResultRowBinary,
ResultRowText,
};
#[derive(Debug, Default)]
@ -14,7 +15,10 @@ pub struct ResultSet {
}
impl ResultSet {
pub async fn deserialize(mut ctx: DeContext<'_>, protocol: ProtocolType) -> Result<Self, Error> {
pub async fn deserialize(
mut ctx: DeContext<'_>,
protocol: ProtocolType,
) -> Result<Self, Error> {
let column_packet = ColumnPacket::decode(&mut ctx)?;
let columns = if let Some(columns) = column_packet.columns {
@ -58,29 +62,25 @@ impl ResultSet {
break;
} else {
let index = ctx.decoder.index;
match protocol {
ProtocolType::Text => {
match ResultRowText::decode(&mut ctx) {
Ok(row) => {
rows.push(ResultRow::from(row));
ctx.next_packet().await?;
}
Err(_) => {
ctx.decoder.index = index;
break;
}
match protocol {
ProtocolType::Text => match ResultRowText::decode(&mut ctx) {
Ok(row) => {
rows.push(ResultRow::from(row));
ctx.next_packet().await?;
}
Err(_) => {
ctx.decoder.index = index;
break;
}
},
ProtocolType::Binary => {
match ResultRowBinary::decode(&mut ctx) {
Ok(row) => {
rows.push(ResultRow::from(row));
ctx.next_packet().await?;
}
Err(_) => {
ctx.decoder.index = index;
break;
}
ProtocolType::Binary => match ResultRowBinary::decode(&mut ctx) {
Ok(row) => {
rows.push(ResultRow::from(row));
ctx.next_packet().await?;
}
Err(_) => {
ctx.decoder.index = index;
break;
}
},
}

View File

@ -2,7 +2,7 @@ use std::convert::TryFrom;
pub enum ProtocolType {
Text,
Binary
Binary,
}
bitflags! {

View File

@ -2,18 +2,18 @@ use std::borrow::Cow;
#[derive(Debug, Clone, PartialEq)]
pub struct ConnectOptions<'a> {
pub host: Cow<'a, str>,
pub host: &'a str,
pub port: u16,
pub user: Option<Cow<'a, str>>,
pub database: Option<Cow<'a, str>>,
pub password: Option<Cow<'a, str>>,
pub user: Option<&'a str>,
pub database: Option<&'a str>,
pub password: Option<&'a str>,
}
impl<'a> Default for ConnectOptions<'a> {
#[inline]
fn default() -> Self {
Self {
host: Cow::Borrowed("localhost"),
host: "localhost",
port: 5432,
user: None,
database: None,
@ -28,20 +28,9 @@ impl<'a> ConnectOptions<'a> {
Self::default()
}
#[inline]
pub fn into_owned(self) -> ConnectOptions<'static> {
ConnectOptions {
host: self.host.into_owned().into(),
port: self.port,
user: self.user.map(|s| s.into_owned().into()),
database: self.database.map(|s| s.into_owned().into()),
password: self.password.map(|s| s.into_owned().into()),
}
}
#[inline]
pub fn host(mut self, host: &'a str) -> Self {
self.host = Cow::Borrowed(host);
self.host = host;
self
}
@ -53,19 +42,19 @@ impl<'a> ConnectOptions<'a> {
#[inline]
pub fn user(mut self, user: &'a str) -> Self {
self.user = Some(Cow::Borrowed(user));
self.user = Some(user);
self
}
#[inline]
pub fn database(mut self, database: &'a str) -> Self {
self.database = Some(Cow::Borrowed(database));
self.database = Some(database);
self
}
#[inline]
pub fn password(mut self, password: &'a str) -> Self {
self.password = Some(Cow::Borrowed(password));
self.password = Some(password);
self
}
}

View File

@ -1,5 +1,5 @@
use super::connection::RawConnection;
use crate::{backend::Backend, ConnectOptions};
use crate::{backend::Backend, Connection};
use crossbeam_queue::{ArrayQueue, SegQueue};
use futures::{channel::oneshot, TryFutureExt};
use std::{
@ -11,11 +11,10 @@ use std::{
},
time::Instant,
};
use url::Url;
// TODO: Add a sqlx::Connection type so we don't leak the RawConnection
// TODO: Reap old connections
// TODO: Clean up (a lot) and document what's going on
// TODO: sqlx::ConnectOptions needs to be removed and replaced with URIs everywhere
pub struct Pool<B>
where
@ -24,6 +23,24 @@ where
inner: Arc<InnerPool<B>>,
}
struct InnerPool<B>
where
B: Backend,
{
url: Url,
idle: ArrayQueue<Idle<B>>,
waiters: SegQueue<oneshot::Sender<Live<B>>>,
total: AtomicUsize,
}
pub struct PooledConnection<B>
where
B: Backend,
{
connection: Option<Live<B>>,
pool: Arc<InnerPool<B>>,
}
impl<B> Clone for Pool<B>
where
B: Backend,
@ -35,24 +52,15 @@ where
}
}
struct InnerPool<B>
where
B: Backend,
{
options: ConnectOptions<'static>,
idle: ArrayQueue<Idle<B>>,
waiters: SegQueue<oneshot::Sender<Live<B>>>,
total: AtomicUsize,
}
impl<B> Pool<B>
where
B: Backend,
{
pub fn new<'a>(options: ConnectOptions<'a>) -> Self {
pub fn new<'a>(url: &str) -> Self {
Self {
inner: Arc::new(InnerPool {
options: options.into_owned(),
// TODO: Handle errors nicely
url: Url::parse(url).unwrap(),
idle: ArrayQueue::new(10),
total: AtomicUsize::new(0),
waiters: SegQueue::new(),
@ -60,10 +68,10 @@ where
}
}
pub async fn acquire(&self) -> io::Result<Connection<B>> {
pub async fn acquire(&self) -> io::Result<PooledConnection<B>> {
self.inner
.acquire()
.map_ok(|live| Connection::new(live, &self.inner))
.map_ok(|live| PooledConnection::new(live, &self.inner))
.await
}
}
@ -101,7 +109,9 @@ where
self.total.store(total + 1, Ordering::SeqCst);
log::debug!("acquire: no idle connections; establish new connection");
let connection = B::RawConnection::establish(self.options.clone()).await?;
let connection = B::RawConnection::establish(&self.url).await?;
let connection = Connection { inner: connection };
let live = Live {
connection,
since: Instant::now(),
@ -127,18 +137,7 @@ where
});
}
}
// TODO: Need a better name here than [pool::Connection] ?
pub struct Connection<B>
where
B: Backend,
{
connection: Option<Live<B>>,
pool: Arc<InnerPool<B>>,
}
impl<B> Connection<B>
impl<B> PooledConnection<B>
where
B: Backend,
{
@ -150,11 +149,11 @@ where
}
}
impl<B> Deref for Connection<B>
impl<B> Deref for PooledConnection<B>
where
B: Backend,
{
type Target = B::RawConnection;
type Target = Connection<B>;
#[inline]
fn deref(&self) -> &Self::Target {
@ -163,7 +162,7 @@ where
}
}
impl<B> DerefMut for Connection<B>
impl<B> DerefMut for PooledConnection<B>
where
B: Backend,
{
@ -174,7 +173,7 @@ where
}
}
impl<B> Drop for Connection<B>
impl<B> Drop for PooledConnection<B>
where
B: Backend,
{
@ -198,6 +197,6 @@ struct Live<B>
where
B: Backend,
{
connection: B::RawConnection,
connection: Connection<B>,
since: Instant,
}

View File

@ -1,16 +1,12 @@
use super::RawConnection;
use crate::{
postgres::protocol::{Authentication, Message, PasswordMessage, StartupMessage},
ConnectOptions,
};
use crate::postgres::protocol::{Authentication, Message, PasswordMessage, StartupMessage};
use std::{borrow::Cow, io};
use url::Url;
pub async fn establish<'a, 'b: 'a>(
conn: &'a mut RawConnection,
options: ConnectOptions<'b>,
) -> io::Result<()> {
let user = &*options.user.expect("user is required");
let password = &*options.password.unwrap_or(Cow::Borrowed(""));
pub async fn establish<'a, 'b: 'a>(conn: &'a mut RawConnection, url: &'b Url) -> io::Result<()> {
let user = url.username();
let password = url.password().unwrap_or("");
let database = url.path().trim_start_matches('/');
// See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html
@ -18,10 +14,7 @@ pub async fn establish<'a, 'b: 'a>(
// FIXME: ConnectOptions user and database need to be required parameters and error
// before they get here
("user", user),
(
"database",
&*options.database.expect("database is required"),
),
("database", database),
// Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values.
("DateStyle", "ISO, MDY"),

View File

@ -1,5 +1,4 @@
use super::protocol::{Encode, Message, Terminate};
use crate::ConnectOptions;
use bytes::{BufMut, BytesMut};
use futures::{
future::BoxFuture,
@ -10,6 +9,7 @@ use futures::{
};
use runtime::net::TcpStream;
use std::{fmt::Debug, io, pin::Pin};
use url::Url;
mod establish;
mod execute;
@ -41,8 +41,11 @@ pub struct RawConnection {
}
impl RawConnection {
pub async fn establish(options: ConnectOptions<'_>) -> io::Result<Self> {
let stream = TcpStream::connect((&*options.host, options.port)).await?;
pub async fn establish(url: &Url) -> io::Result<Self> {
let host = url.host_str().unwrap_or("localhost");
let port = url.port().unwrap_or(5432);
let stream = TcpStream::connect((host, port)).await?;
let mut conn = Self {
wbuf: Vec::with_capacity(1024),
rbuf: BytesMut::with_capacity(1024 * 8),
@ -53,7 +56,7 @@ impl RawConnection {
secret_key: 0,
};
establish::establish(&mut conn, options).await?;
establish::establish(&mut conn, &url).await?;
Ok(conn)
}
@ -139,7 +142,7 @@ impl RawConnection {
impl crate::connection::RawConnection for RawConnection {
#[inline]
fn establish(options: ConnectOptions<'_>) -> BoxFuture<io::Result<Self>> {
Box::pin(RawConnection::establish(options))
fn establish(url: &Url) -> BoxFuture<io::Result<Self>> {
Box::pin(RawConnection::establish(url))
}
}