diff --git a/examples/postgres.rs b/examples/postgres.rs index 7eccecd4..51b81901 100644 --- a/examples/postgres.rs +++ b/examples/postgres.rs @@ -1,88 +1,85 @@ -// #![feature(async_await)] +#![feature(async_await)] -// use sqlx::{postgres::Connection, ConnectOptions}; -// use std::io; +use sqlx::{postgres::Connection, ConnectOptions}; +use std::io; -// // TODO: ToSql and FromSql (to [de]serialize values from/to Rust and SQL) -// // TODO: Connection strings ala postgres@localhost/sqlx_dev +// TODO: ToSql and FromSql (to [de]serialize values from/to Rust and SQL) +// TODO: Connection strings ala postgres@localhost/sqlx_dev -// #[runtime::main(runtime_tokio::Tokio)] -// async fn main() -> io::Result<()> { -// env_logger::init(); +#[runtime::main(runtime_tokio::Tokio)] +async fn main() -> io::Result<()> { + env_logger::init(); -// // Connect as postgres / postgres and DROP the sqlx__dev database -// // if exists and then re-create it -// let mut conn = Connection::establish( -// ConnectOptions::new() -// .host("127.0.0.1") -// .port(5432) -// .user("postgres") -// .database("postgres"), -// ) -// .await?; + // Connect as postgres / postgres and DROP the sqlx__dev database + // if exists and then re-create it + let mut conn = Connection::establish( + ConnectOptions::new() + .host("127.0.0.1") + .port(5432) + .user("postgres") + .database("postgres"), + ) + .await?; -// println!(" :: drop database (if exists) sqlx__dev"); + // println!(" :: drop database (if exists) sqlx__dev"); -// conn.prepare("DROP DATABASE IF EXISTS sqlx__dev") -// .execute() -// .await?; + // conn.prepare("DROP DATABASE IF EXISTS sqlx__dev") + // .execute() + // .await?; -// println!(" :: create database sqlx__dev"); + // println!(" :: create database sqlx__dev"); -// conn.prepare("CREATE DATABASE sqlx__dev").execute().await?; + // conn.prepare("CREATE DATABASE sqlx__dev").execute().await?; -// conn.close().await?; + // conn.close().await?; -// let mut conn = Connection::establish( -// ConnectOptions::new() -// .host("127.0.0.1") -// .port(5432) -// .user("postgres") -// .database("sqlx__dev"), -// ) -// .await?; + // let mut conn = Connection::establish( + // ConnectOptions::new() + // .host("127.0.0.1") + // .port(5432) + // .user("postgres") + // .database("sqlx__dev"), + // ) + // .await?; -// println!(" :: create schema"); + // println!(" :: create schema"); -// conn.prepare( -// r#" -// CREATE TABLE IF NOT EXISTS users ( -// id BIGSERIAL PRIMARY KEY, -// name TEXT NOT NULL -// ); -// "#, -// ) -// .execute() -// .await?; + // conn.prepare( + // r#" + // CREATE TABLE IF NOT EXISTS users ( + // id BIGSERIAL PRIMARY KEY, + // name TEXT NOT NULL + // ); + // "#, + // ) + // .execute() + // .await?; -// println!(" :: insert"); + // println!(" :: insert"); -// let new_row = conn -// .prepare("INSERT INTO users (name) VALUES ($1) RETURNING id") -// .bind(b"Joe") -// .get() -// .await?; + let row = conn + .prepare("SELECT pg_typeof($1), pg_typeof($2)") + .bind(20) + .bind_as::(10) + .get() + .await?; -// let new_id = new_row.as_ref().unwrap().get(0); + println!("{:?}", row); -// println!("insert {:?}", new_id); + // println!(" :: select"); -// // println!(" :: select"); + // conn.prepare("SELECT id FROM users") + // .select() + // .try_for_each(|row| { + // let id = row.get(0); -// // conn.prepare("SELECT id FROM users") -// // .select() -// // .try_for_each(|row| { -// // let id = row.get(0); + // println!("select {:?}", id); -// // println!("select {:?}", id); + // future::ok(()) + // }) + // .await?; -// // future::ok(()) -// // }) -// // .await?; + conn.close().await?; -// conn.close().await?; - -// Ok(()) -// } - -fn main() {} + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index b4196b64..893b6b43 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,3 +21,5 @@ pub mod postgres; // Helper macro for writing long complex tests #[macro_use] pub mod macros; + +pub mod types; diff --git a/src/postgres/connection/execute.rs b/src/postgres/connection/execute.rs index 7ec7faa0..42157d61 100644 --- a/src/postgres/connection/execute.rs +++ b/src/postgres/connection/execute.rs @@ -2,23 +2,15 @@ use super::prepare::Prepare; use crate::postgres::protocol::{self, Message}; use std::io; -impl<'a> Prepare<'a> { +impl<'a, 'b> Prepare<'a, 'b> { pub async fn execute(self) -> io::Result { - // protocol::bind::trailer( - // &mut self.connection.wbuf, - // self.bind_state, - // self.bind_values, - // &[], - // ); + let conn = self.finish(); - // protocol::execute(&mut self.connection.wbuf, "", 0); - // protocol::sync(&mut self.connection.wbuf); - - self.connection.flush().await?; + conn.flush().await?; let mut rows = 0; - while let Some(message) = self.connection.receive().await? { + while let Some(message) = conn.receive().await? { match message { Message::BindComplete | Message::ParseComplete => { // Indicates successful completion of a phase diff --git a/src/postgres/connection/get.rs b/src/postgres/connection/get.rs index 2f4eec6b..5f64b4d6 100644 --- a/src/postgres/connection/get.rs +++ b/src/postgres/connection/get.rs @@ -2,24 +2,15 @@ use super::prepare::Prepare; use crate::postgres::protocol::{self, DataRow, Message}; use std::io; -impl<'a> Prepare<'a> { +impl<'a, 'b> Prepare<'a, 'b> { pub async fn get(self) -> io::Result> { - // protocol::bind::trailer( - // &mut self.connection.wbuf, - // self.bind_state, - // self.bind_values, - // &[], - // ); + let conn = self.finish(); - // protocol::execute(&mut self.connection.wbuf, "", 1); - // protocol::close::portal(&mut self.connection.wbuf, ""); - // protocol::sync(&mut self.connection.wbuf); - - self.connection.flush().await?; + conn.flush().await?; let mut row: Option = None; - while let Some(message) = self.connection.receive().await? { + while let Some(message) = conn.receive().await? { match message { Message::BindComplete | Message::ParseComplete diff --git a/src/postgres/connection/mod.rs b/src/postgres/connection/mod.rs index c3f87a29..86b0b757 100644 --- a/src/postgres/connection/mod.rs +++ b/src/postgres/connection/mod.rs @@ -57,7 +57,7 @@ impl Connection { Ok(conn) } - pub fn prepare(&mut self, query: &str) -> prepare::Prepare { + pub fn prepare<'a, 'b>(&'a mut self, query: &'b str) -> prepare::Prepare<'a, 'b> { prepare::prepare(self, query) } diff --git a/src/postgres/connection/prepare.rs b/src/postgres/connection/prepare.rs index 1ca223fe..b409c1a1 100644 --- a/src/postgres/connection/prepare.rs +++ b/src/postgres/connection/prepare.rs @@ -1,29 +1,65 @@ use super::Connection; use crate::{ - postgres::protocol::{self, Parse}, - types::ToSql, + postgres::protocol::{self, BindValues}, + types::{SqlType, ToSql, ToSqlAs}, }; -pub struct Prepare<'a> { +pub struct Prepare<'a, 'b> { + query: &'b str, pub(super) connection: &'a mut Connection, + pub(super) bind: BindValues, } #[inline] -pub fn prepare<'a, 'b>(connection: &'a mut Connection, query: &'b str) -> Prepare<'a> { +pub fn prepare<'a, 'b>(connection: &'a mut Connection, query: &'b str) -> Prepare<'a, 'b> { // TODO: Use a hash map to cache the parse // TODO: Use named statements - connection.write(Parse { - portal: "", + Prepare { + connection, query, - param_types: &[], - }); - - Prepare { connection } + bind: BindValues::new(), + } } -// impl<'a> Prepare<'a> { -// #[inline] -// pub fn bind(mut self, value: impl ToSql) -> Self { -// unimplemented!() -// } -// } +impl<'a, 'b> Prepare<'a, 'b> { + #[inline] + pub fn bind(mut self, value: T) -> Self + where + T: ToSqlAs<::Type>, + { + self.bind.add(value); + self + } + + #[inline] + pub fn bind_as>(mut self, value: T) -> Self { + self.bind.add_as::(value); + self + } + + pub(super) fn finish(self) -> &'a mut Connection { + self.connection.write(protocol::Parse { + portal: "", + query: self.query, + param_types: self.bind.types(), + }); + + self.connection.write(protocol::Bind { + portal: "", + statement: "", + formats: self.bind.formats(), + values_len: self.bind.values_len(), + values: self.bind.values(), + result_formats: &[], + }); + + self.connection.write(protocol::Execute { + portal: "", + limit: 0, + }); + + self.connection.write(protocol::Sync); + + self.connection + } +} diff --git a/src/postgres/connection/select.rs b/src/postgres/connection/select.rs index 94f4814f..e1ec1e68 100644 --- a/src/postgres/connection/select.rs +++ b/src/postgres/connection/select.rs @@ -3,20 +3,10 @@ use crate::postgres::protocol::{self, DataRow, Message}; use futures::{stream, Stream}; use std::io; -impl<'a> Prepare<'a> { +impl<'a, 'b> Prepare<'a, 'b> { pub fn select(self) -> impl Stream> + 'a + Unpin { - // protocol::bind::trailer( - // &mut self.connection.wbuf, - // self.bind_state, - // self.bind_values, - // &[], - // ); - - // protocol::execute(&mut self.connection.wbuf, "", 0); - // protocol::sync(&mut self.connection.wbuf); - // FIXME: Manually implement Stream on a new type to avoid the unfold adapter - stream::unfold(self.connection, |conn| { + stream::unfold(self.finish(), |conn| { Box::pin(async { if !conn.wbuf.is_empty() { if let Err(e) = conn.flush().await { diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index 756d3345..8d959ded 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -2,3 +2,4 @@ mod connection; pub use connection::Connection; mod protocol; +pub mod types; diff --git a/src/postgres/protocol/bind.rs b/src/postgres/protocol/bind.rs index 6094d269..211b19e9 100644 --- a/src/postgres/protocol/bind.rs +++ b/src/postgres/protocol/bind.rs @@ -1,21 +1,87 @@ use super::{BufMut, Encode}; +use crate::types::{SqlType, ToSql, ToSqlAs}; use byteorder::{BigEndian, ByteOrder}; +const TEXT: i16 = 0; +const BINARY: i16 = 1; + +// FIXME: Think of a better name here +pub struct BindValues { + types: Vec, + formats: Vec, + values_len: i16, + values: Vec, +} + +impl BindValues { + pub fn new() -> Self { + BindValues { + types: Vec::new(), + formats: Vec::new(), + values: Vec::new(), + values_len: 0, + } + } + + #[inline] + pub fn add(&mut self, value: T) + where + T: ToSqlAs<::Type>, + { + self.add_as::(value); + } + + pub fn add_as>(&mut self, value: T) { + // TODO: When/if we receive types that do _not_ support BINARY, we need to check here + // TODO: There is no need to be explicit unless we are expecting mixed BINARY / TEXT + + self.types.push(ST::OID as i32); + + let pos = self.values.len(); + self.values.put_int_32(0); // skip over len + + value.to_sql(&mut self.values); + self.values_len += 1; + + // Write-back the len to the beginning of this frame (not including the len of len) + let len = self.values.len() - pos - 4; + BigEndian::write_i32(&mut self.values[pos..], len as i32); + } + + pub fn types(&self) -> &[i32] { + &self.types + } + + pub fn formats(&self) -> &[i16] { + // &self.formats + &[BINARY] + } + + pub fn values(&self) -> &[u8] { + &self.values + } + + pub fn values_len(&self) -> i16 { + self.values_len + } +} + pub struct Bind<'a> { /// The name of the destination portal (an empty string selects the unnamed portal). - portal: &'a str, + pub portal: &'a str, /// The name of the source prepared statement (an empty string selects the unnamed prepared statement). - statement: &'a str, + pub statement: &'a str, /// The parameter format codes. Each must presently be zero (text) or one (binary). /// /// There can be zero to indicate that there are no parameters or that the parameters all use the /// default format (text); or one, in which case the specified format code is applied to all /// parameters; or it can equal the actual number of parameters. - formats: &'a [i16], + pub formats: &'a [i16], - values: &'a [u8], + pub values_len: i16, + pub values: &'a [u8], /// The result-column format codes. Each must presently be zero (text) or one (binary). /// @@ -23,7 +89,7 @@ pub struct Bind<'a> { /// result columns should all use the default format (text); or one, in which /// case the specified format code is applied to all result columns (if any); /// or it can equal the actual number of result columns of the query. - result_formats: &'a [i16], + pub result_formats: &'a [i16], } impl Encode for Bind<'_> { @@ -35,8 +101,13 @@ impl Encode for Bind<'_> { buf.put_str(self.portal); buf.put_str(self.statement); + buf.put_array_int_16(&self.formats); + + buf.put_int_16(self.values_len); + buf.put(self.values); + buf.put_array_int_16(&self.result_formats); // Write-back the len to the beginning of this frame @@ -44,3 +115,32 @@ impl Encode for Bind<'_> { BigEndian::write_i32(&mut buf[pos..], len as i32); } } + +#[cfg(test)] +mod tests { + use super::{Bind, BindValues, BufMut, Encode}; + + const BIND: &[u8] = b"B\0\0\0\x16\0\0\0\0\0\x02\0\0\0\x011\0\0\0\x012\0\0"; + + #[test] + fn it_encodes_bind_for_two() { + let mut buf = Vec::new(); + + let mut builder = BindValues::new(); + builder.add("1"); + builder.add("2"); + + let bind = Bind { + portal: "", + statement: "", + formats: builder.formats(), + values_len: builder.values_len(), + values: builder.values(), + result_formats: &[], + }; + + bind.encode(&mut buf); + + assert_eq!(buf, BIND); + } +} diff --git a/src/postgres/protocol/mod.rs b/src/postgres/protocol/mod.rs index f5f4dfa0..1910286d 100644 --- a/src/postgres/protocol/mod.rs +++ b/src/postgres/protocol/mod.rs @@ -22,7 +22,7 @@ mod terminate; // TODO: mod ssl_request; pub use self::{ - bind::Bind, + bind::{Bind, BindValues}, cancel_request::CancelRequest, close::Close, copy_data::CopyData, diff --git a/src/postgres/types/mod.rs b/src/postgres/types/mod.rs new file mode 100644 index 00000000..cd0a140d --- /dev/null +++ b/src/postgres/types/mod.rs @@ -0,0 +1,119 @@ +use crate::types::{SqlType, ToSql, ToSqlAs}; + +// TODO: Generalize by Backend and move common types to crate [sqlx::types] + +// Character +// https://www.postgresql.org/docs/devel/datatype-character.html + +pub struct Text; + +impl SqlType for Text { + const OID: u32 = 25; +} + +impl ToSql for &'_ str { + type Type = Text; +} + +impl ToSqlAs for &'_ str { + #[inline] + fn to_sql(self, buf: &mut Vec) { + buf.extend_from_slice(self.as_bytes()); + } +} + +// Numeric +// https://www.postgresql.org/docs/devel/datatype-numeric.html + +// i16 +pub struct SmallInt; + +impl SqlType for SmallInt { + const OID: u32 = 21; +} + +impl ToSql for i16 { + type Type = SmallInt; +} + +impl ToSqlAs for i16 { + #[inline] + fn to_sql(self, buf: &mut Vec) { + buf.extend_from_slice(&self.to_be_bytes()); + } +} + +// i32 +pub struct Int; + +impl SqlType for Int { + const OID: u32 = 23; +} + +impl ToSql for i32 { + type Type = Int; +} + +impl ToSqlAs for i32 { + #[inline] + fn to_sql(self, buf: &mut Vec) { + buf.extend_from_slice(&self.to_be_bytes()); + } +} + +// i64 +pub struct BigInt; + +impl SqlType for BigInt { + const OID: u32 = 20; +} + +impl ToSql for i64 { + type Type = BigInt; +} + +impl ToSqlAs for i64 { + #[inline] + fn to_sql(self, buf: &mut Vec) { + buf.extend_from_slice(&self.to_be_bytes()); + } +} + +// decimal? +// TODO pub struct Decimal; + +// f32 +pub struct Real; + +impl SqlType for Real { + const OID: u32 = 700; +} + +impl ToSql for f32 { + type Type = Real; +} + +impl ToSqlAs for f32 { + #[inline] + fn to_sql(self, buf: &mut Vec) { + (self.to_bits() as i32).to_sql(buf); + } +} + +// f64 +pub struct Double; + +impl SqlType for Double { + const OID: u32 = 701; +} + +impl ToSql for f64 { + type Type = Double; +} + +impl ToSqlAs for f64 { + #[inline] + fn to_sql(self, buf: &mut Vec) { + (self.to_bits() as i64).to_sql(buf); + } +} diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 00000000..778781ac --- /dev/null +++ b/src/types.rs @@ -0,0 +1,16 @@ +// TODO: Better name for ToSql/ToSqlAs. ToSqlAs is the _conversion_ trait. +// ToSql is type fallback for Rust/SQL (e.g., what is the probable SQL type for this Rust type) + +pub trait SqlType { + // FIXME: This is a postgres thing + const OID: u32; +} + +pub trait ToSql { + /// SQL type that should be inferred from the implementing Rust type. + type Type: SqlType; +} + +pub trait ToSqlAs: ToSql { + fn to_sql(self, buf: &mut Vec); +}