mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-03-19 08:39:44 +00:00
Integrate new protocol crate to connection
This commit is contained in:
@@ -8,6 +8,7 @@ edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
sqlx-core = { path = "../sqlx-core" }
|
||||
sqlx-postgres-protocol = { path = "../sqlx-postgres-protocol" }
|
||||
runtime = "=0.3.0-alpha.4"
|
||||
futures-preview = "=0.3.0-alpha.16"
|
||||
failure = "0.1"
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
use super::Connection;
|
||||
use crate::protocol::{
|
||||
client::{PasswordMessage, StartupMessage},
|
||||
server::Message as ServerMessage,
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use sqlx_core::ConnectOptions;
|
||||
use md5::{Digest, Md5};
|
||||
use sqlx_postgres_protocol::{Authentication, Message, PasswordMessage, StartupMessage};
|
||||
use std::io;
|
||||
|
||||
pub async fn establish<'a, 'b: 'a>(
|
||||
@@ -14,61 +10,63 @@ pub async fn establish<'a, 'b: 'a>(
|
||||
) -> io::Result<()> {
|
||||
// See this doc for more runtime parameters
|
||||
// https://www.postgresql.org/docs/12/runtime-config-client.html
|
||||
let params = [
|
||||
("user", options.user),
|
||||
("database", options.database),
|
||||
// TODO: Expose this property perhaps?
|
||||
(
|
||||
"application_name",
|
||||
Some(concat!(env!("CARGO_PKG_NAME"), "/v", env!("CARGO_PKG_VERSION"))),
|
||||
),
|
||||
let mut message = StartupMessage::builder();
|
||||
|
||||
if let Some(user) = options.user {
|
||||
// FIXME: User is technically required. We should default this like psql does.
|
||||
message = message.param("user", user);
|
||||
}
|
||||
|
||||
if let Some(database) = options.database {
|
||||
message = message.param("database", database);
|
||||
}
|
||||
|
||||
let message = message
|
||||
// Sets the display format for date and time values,
|
||||
// as well as the rules for interpreting ambiguous date input values.
|
||||
("DateStyle", Some("ISO, MDY")),
|
||||
.param("DateStyle", "ISO, MDY")
|
||||
// Sets the display format for interval values.
|
||||
("IntervalStyle", Some("iso_8601")),
|
||||
.param("IntervalStyle", "iso_8601")
|
||||
// Sets the time zone for displaying and interpreting time stamps.
|
||||
("TimeZone", Some("UTC")),
|
||||
.param("TimeZone", "UTC")
|
||||
// Adjust postgres to return percise values for floats
|
||||
// NOTE: This is default in postgres 12+
|
||||
("extra_float_digits", Some("3")),
|
||||
.param("extra_float_digits", "3")
|
||||
// Sets the client-side encoding (character set).
|
||||
("client_encoding", Some("UTF-8")),
|
||||
];
|
||||
.param("client_encoding", "UTF-8")
|
||||
.build();
|
||||
|
||||
conn.send(StartupMessage { params: ¶ms }).await?;
|
||||
conn.send(message).await?;
|
||||
|
||||
// FIXME: This feels like it could be reduced (see other connection flows)
|
||||
while let Some(message) = conn.incoming.next().await {
|
||||
match message {
|
||||
ServerMessage::AuthenticationOk => {
|
||||
Message::Authentication(Authentication::Ok) => {
|
||||
// Do nothing; server is just telling us that
|
||||
// there is no password needed
|
||||
}
|
||||
|
||||
ServerMessage::AuthenticationClearTextPassword => {
|
||||
conn.send(PasswordMessage { password: options.password.unwrap_or_default() })
|
||||
.await?;
|
||||
Message::Authentication(Authentication::CleartextPassword) => {
|
||||
// FIXME: Should error early (before send) if the user did not supply a password
|
||||
conn.send(PasswordMessage::cleartext(options.password.unwrap_or_default())).await?;
|
||||
}
|
||||
|
||||
ServerMessage::AuthenticationMd5Password(body) => {
|
||||
// Hash password|username
|
||||
// FIXME: ConnectOptions should prepare a default user
|
||||
let pass_user =
|
||||
md5(options.password.unwrap_or_default(), options.user.unwrap_or_default());
|
||||
|
||||
let with_salt = md5(pass_user, body.salt);
|
||||
let password = format!("md5{}", with_salt);
|
||||
|
||||
conn.send(PasswordMessage { password: &password }).await?;
|
||||
Message::Authentication(Authentication::Md5Password { salt }) => {
|
||||
// FIXME: Should error early (before send) if the user did not supply a password
|
||||
conn.send(PasswordMessage::md5(
|
||||
options.password.unwrap_or_default(),
|
||||
options.user.unwrap_or_default(),
|
||||
&salt,
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
ServerMessage::BackendKeyData(body) => {
|
||||
conn.process_id = body.process_id;
|
||||
conn.secret_key = body.secret_key;
|
||||
Message::BackendKeyData(body) => {
|
||||
conn.process_id = body.process_id();
|
||||
conn.secret_key = body.secret_key();
|
||||
}
|
||||
|
||||
ServerMessage::ReadyForQuery(_) => {
|
||||
Message::ReadyForQuery(_) => {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -80,8 +78,3 @@ pub async fn establish<'a, 'b: 'a>(
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn md5(a: impl AsRef<[u8]>, b: impl AsRef<[u8]>) -> String {
|
||||
hex::encode(Md5::new().chain(a).chain(b).result())
|
||||
}
|
||||
|
||||
@@ -1,35 +1,33 @@
|
||||
use crate::protocol::{
|
||||
client::{Serialize, Terminate},
|
||||
server::Message as ServerMessage,
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
|
||||
SinkExt, StreamExt,
|
||||
SinkExt,
|
||||
};
|
||||
use sqlx_core::ConnectOptions;
|
||||
use runtime::{net::TcpStream, task::JoinHandle};
|
||||
use sqlx_core::ConnectOptions;
|
||||
use sqlx_postgres_protocol::{Encode, Message, Terminate};
|
||||
use std::io;
|
||||
|
||||
mod establish;
|
||||
mod query;
|
||||
// mod query;
|
||||
|
||||
pub struct Connection {
|
||||
writer: WriteHalf<TcpStream>,
|
||||
incoming: mpsc::UnboundedReceiver<ServerMessage>,
|
||||
incoming: mpsc::UnboundedReceiver<Message>,
|
||||
|
||||
// Buffer used when serializing outgoing messages
|
||||
// FIXME: Use BytesMut
|
||||
wbuf: Vec<u8>,
|
||||
|
||||
// Handle to coroutine reading messages from the stream
|
||||
receiver: JoinHandle<io::Result<()>>,
|
||||
|
||||
// Process ID of the Backend
|
||||
process_id: i32,
|
||||
process_id: u32,
|
||||
|
||||
// Backend-unique key to use to send a cancel query message to the server
|
||||
secret_key: i32,
|
||||
secret_key: u32,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
@@ -43,8 +41,8 @@ impl Connection {
|
||||
writer,
|
||||
receiver,
|
||||
incoming: rx,
|
||||
process_id: -1,
|
||||
secret_key: -1,
|
||||
process_id: 0,
|
||||
secret_key: 0,
|
||||
};
|
||||
|
||||
establish::establish(&mut conn, options).await?;
|
||||
@@ -52,9 +50,9 @@ impl Connection {
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
pub async fn execute<'a, 'b: 'a>(&'a mut self, query: &'b str) -> io::Result<()> {
|
||||
query::query(self, query).await
|
||||
}
|
||||
// pub async fn execute<'a, 'b: 'a>(&'a mut self, query: &'b str) -> io::Result<()> {
|
||||
// query::query(self, query).await
|
||||
// }
|
||||
|
||||
pub async fn close(mut self) -> io::Result<()> {
|
||||
self.send(Terminate).await?;
|
||||
@@ -64,14 +62,16 @@ impl Connection {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Send client-serializable message to the server
|
||||
async fn send<S>(&mut self, message: S) -> io::Result<()>
|
||||
// Send client message to the server
|
||||
async fn send<T>(&mut self, message: T) -> io::Result<()>
|
||||
where
|
||||
S: Serialize,
|
||||
T: Encode,
|
||||
{
|
||||
self.wbuf.clear();
|
||||
|
||||
message.serialize(&mut self.wbuf);
|
||||
message.encode(&mut self.wbuf)?;
|
||||
|
||||
log::trace!("sending: {:?}", bytes::Bytes::from(self.wbuf.clone()));
|
||||
|
||||
self.writer.write_all(&self.wbuf).await?;
|
||||
self.writer.flush().await?;
|
||||
@@ -82,7 +82,7 @@ impl Connection {
|
||||
|
||||
async fn receiver(
|
||||
mut reader: ReadHalf<TcpStream>,
|
||||
mut sender: mpsc::UnboundedSender<ServerMessage>,
|
||||
mut sender: mpsc::UnboundedSender<Message>,
|
||||
) -> io::Result<()> {
|
||||
let mut rbuf = BytesMut::with_capacity(0);
|
||||
let mut len = 0;
|
||||
@@ -107,6 +107,7 @@ async fn receiver(
|
||||
}
|
||||
|
||||
// TODO: Need a select! on a channel that I can trigger to cancel this
|
||||
log::trace!("waiting to read ...");
|
||||
let cnt = reader.read(&mut rbuf[len..]).await?;
|
||||
|
||||
if cnt > 0 {
|
||||
@@ -117,28 +118,29 @@ async fn receiver(
|
||||
}
|
||||
|
||||
while len > 0 {
|
||||
log::trace!("{} in rbuf", len);
|
||||
log::trace!("rbuf: {:?}", rbuf);
|
||||
|
||||
let size = rbuf.len();
|
||||
let message = ServerMessage::deserialize(&mut rbuf)?;
|
||||
let message = Message::decode(&mut rbuf)?;
|
||||
len -= size - rbuf.len();
|
||||
|
||||
// TODO: Some messages should be kept behind here
|
||||
match message {
|
||||
Some(ServerMessage::ParameterStatus(body)) => {
|
||||
log::debug!("parameter {} = {}", body.name()?, body.value()?);
|
||||
Some(Message::ParameterStatus(body)) => {
|
||||
log::debug!("parameter: {} = {}", body.name(), body.value());
|
||||
}
|
||||
|
||||
Some(ServerMessage::NoticeResponse(body)) => {
|
||||
log::warn!("notice: {:?}", body);
|
||||
Some(Message::Response(body)) => {
|
||||
log::warn!("response: {:?}", body);
|
||||
}
|
||||
|
||||
Some(message) => {
|
||||
// TODO: Handle this error?
|
||||
sender.send(message).await.unwrap();
|
||||
}
|
||||
|
||||
None => {
|
||||
// Did not receive enough bytes to
|
||||
// deserialize a complete message
|
||||
// decode a complete message
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#![feature(non_exhaustive, async_await)]
|
||||
#![allow(clippy::needless_lifetimes)]
|
||||
|
||||
// mod connection;
|
||||
// pub use connection::Connection;
|
||||
mod connection;
|
||||
pub use connection::Connection;
|
||||
|
||||
Reference in New Issue
Block a user