mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 04:04:56 +00:00
feat(postgres): stub start-up flow
This commit is contained in:
parent
424d4b7aa1
commit
1eb1cd3ea9
@ -38,6 +38,7 @@ futures-io = { version = "0.3", optional = true }
|
||||
bytes = "1.0"
|
||||
memchr = "2.3"
|
||||
bitflags = "1.2"
|
||||
base64 = "0.13.0"
|
||||
|
||||
[dev-dependencies]
|
||||
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] }
|
||||
|
||||
@ -2,10 +2,13 @@ use std::fmt::{self, Debug, Formatter};
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
use futures_util::future::{BoxFuture, FutureExt, TryFutureExt};
|
||||
use sqlx_core::net::Stream as NetStream;
|
||||
use sqlx_core::{Close, Connect, Connection, Runtime};
|
||||
|
||||
use crate::stream::PgStream;
|
||||
use crate::Postgres;
|
||||
use crate::{PgConnectOptions, Postgres};
|
||||
|
||||
mod connect;
|
||||
|
||||
/// A single connection (also known as a session) to a
|
||||
/// PostgreSQL database server.
|
||||
@ -64,10 +67,10 @@ impl<Rt: Runtime> Connection<Rt> for PgConnection<Rt> {
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> Connect<Rt> for PgConnection<Rt> {
|
||||
type Options = PostgresConnectOptions;
|
||||
type Options = PgConnectOptions;
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
fn connect_with(options: &PostgresConnectOptions) -> BoxFuture<'_, sqlx_core::Result<Self>>
|
||||
fn connect_with(options: &PgConnectOptions) -> BoxFuture<'_, sqlx_core::Result<Self>>
|
||||
where
|
||||
Self: Sized,
|
||||
Rt: sqlx_core::Async,
|
||||
@ -82,6 +85,52 @@ impl<Rt: Runtime> Close<Rt> for PgConnection<Rt> {
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
todo!()
|
||||
Box::pin(async move {
|
||||
self.stream.close_async().await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
mod blocking {
|
||||
use sqlx_core::blocking::{Close, Connect, Connection, Runtime};
|
||||
|
||||
use super::{PgConnectOptions, PgConnection, Postgres};
|
||||
|
||||
impl<Rt: Runtime> Connection<Rt> for PgConnection<Rt> {
|
||||
#[inline]
|
||||
fn ping(&mut self) -> sqlx_core::Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn describe<'x, 'e, 'q>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
) -> sqlx_core::Result<sqlx_core::Describe<Postgres>>
|
||||
where
|
||||
'e: 'x,
|
||||
'q: 'x,
|
||||
{
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> Connect<Rt> for PgConnection<Rt> {
|
||||
#[inline]
|
||||
fn connect_with(options: &PgConnectOptions) -> sqlx_core::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Self::connect_blocking(options)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> Close<Rt> for PgConnection<Rt> {
|
||||
#[inline]
|
||||
fn close(mut self) -> sqlx_core::Result<()> {
|
||||
self.stream.close_blocking()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
129
sqlx-postgres/src/connection/connect.rs
Normal file
129
sqlx-postgres/src/connection/connect.rs
Normal file
@ -0,0 +1,129 @@
|
||||
//! Implements start-up flow.
|
||||
//!
|
||||
//! To begin a session, a frontend opens a connection to the server
|
||||
//! and sends a startup message.
|
||||
//!
|
||||
//! The server then sends an appropriate authentication request message, to
|
||||
//! which the frontend must reply with an appropriate authentication
|
||||
//! response message.
|
||||
//!
|
||||
//! The authentication cycle ends with the server either rejecting
|
||||
//! the connection attempt (ErrorResponse), or sending AuthenticationOk.
|
||||
//!
|
||||
//! <https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3>
|
||||
|
||||
use sqlx_core::net::Stream as NetStream;
|
||||
use sqlx_core::{Error, Result, Runtime};
|
||||
|
||||
use crate::protocol::backend::{Authentication, BackendMessage, BackendMessageType};
|
||||
use crate::protocol::frontend::Startup;
|
||||
use crate::{PgClientError, PgConnectOptions, PgConnection};
|
||||
|
||||
impl<Rt: Runtime> PgConnection<Rt> {
|
||||
fn write_startup_message(&mut self, options: &PgConnectOptions) -> Result<()> {
|
||||
let params = vec![
|
||||
("user", options.get_username()),
|
||||
("database", options.get_database()),
|
||||
("application_name", options.get_application_name()),
|
||||
// sets the text display format for date and time values
|
||||
// as well as the rules for interpreting ambiguous date input values
|
||||
("DateStyle", Some("ISO, MDY")),
|
||||
// sets the client-side encoding (charset)
|
||||
// NOTE: this must not be changed, too much in the driver depends on this being set to UTF-8
|
||||
("client_encoding", Some("UTF8")),
|
||||
// sets the timezone for displaying and interpreting time stamps
|
||||
// NOTE: this is only used to assume timestamptz values are in UTC
|
||||
("TimeZone", Some("UTC")),
|
||||
];
|
||||
|
||||
self.stream.write_message(&Startup(¶ms))
|
||||
}
|
||||
|
||||
fn handle_startup_response(
|
||||
&mut self,
|
||||
options: &PgConnectOptions,
|
||||
message: BackendMessage,
|
||||
) -> Result<bool> {
|
||||
match message.ty {
|
||||
BackendMessageType::Authentication => match message.deserialize()? {
|
||||
Authentication::Ok => {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
Authentication::Md5Password(_) => {
|
||||
todo!("md5")
|
||||
}
|
||||
|
||||
Authentication::CleartextPassword => {
|
||||
todo!("cleartext")
|
||||
}
|
||||
|
||||
Authentication::Sasl(_) => todo!("sasl"),
|
||||
Authentication::SaslContinue(_) => todo!("sasl continue"),
|
||||
Authentication::SaslFinal(_) => todo!("sasl final"),
|
||||
},
|
||||
|
||||
ty => {
|
||||
return Err(Error::client(PgClientError::UnexpectedMessageType {
|
||||
ty: ty as u8,
|
||||
context: "starting up",
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_connect {
|
||||
(@blocking @new $options:ident) => {
|
||||
NetStream::connect($options.address.as_ref())?
|
||||
};
|
||||
|
||||
(@new $options:ident) => {
|
||||
NetStream::connect_async($options.address.as_ref()).await?
|
||||
};
|
||||
|
||||
($(@$blocking:ident)? $options:ident) => {{
|
||||
// open a network stream to the database server
|
||||
let stream = impl_connect!($(@$blocking)? @new $options);
|
||||
|
||||
// construct a <PgConnection> around the network stream
|
||||
// wraps the stream in a <BufStream> to buffer read and write
|
||||
let mut self_ = Self::new(stream);
|
||||
|
||||
// to begin a session, a frontend should send a startup message
|
||||
// this is built up of various startup parameters that control the connection
|
||||
self_.write_startup_message($options)?;
|
||||
|
||||
// the server then uses this information and the contents of
|
||||
// its configuration files (such as pg_hba.conf) to determine whether the connection is
|
||||
// provisionally acceptable, and what additional
|
||||
// authentication is required (if any).
|
||||
loop {
|
||||
let message = read_message!($(@$blocking)? self_.stream);
|
||||
if self_.handle_startup_response($options, message)? {
|
||||
// complete, successful authentication
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self_)
|
||||
}};
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> PgConnection<Rt> {
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) async fn connect_async(options: &PgConnectOptions) -> Result<Self>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
impl_connect!(options)
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
pub(crate) fn connect_blocking(options: &PgConnectOptions) -> Result<Self>
|
||||
where
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
impl_connect!(@blocking options)
|
||||
}
|
||||
}
|
||||
4
sqlx-postgres/src/error.rs
Normal file
4
sqlx-postgres/src/error.rs
Normal file
@ -0,0 +1,4 @@
|
||||
mod client;
|
||||
mod database;
|
||||
|
||||
pub use client::PgClientError;
|
||||
40
sqlx-postgres/src/error/client.rs
Normal file
40
sqlx-postgres/src/error/client.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use crate::protocol::backend::BackendMessageType;
|
||||
use sqlx_core::ClientError;
|
||||
use std::error::Error as StdError;
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
use std::str::Utf8Error;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum PgClientError {
|
||||
// attempting to interpret data from postgres as UTF-8, when it should
|
||||
// be UTF-8, but for some reason (data corruption?) it is not
|
||||
NotUtf8(Utf8Error),
|
||||
UnknownAuthenticationMethod(u32),
|
||||
UnknownMessageType(u8),
|
||||
UnexpectedMessageType { ty: u8, context: &'static str },
|
||||
}
|
||||
|
||||
impl Display for PgClientError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::NotUtf8(source) => write!(f, "{}", source),
|
||||
|
||||
Self::UnknownAuthenticationMethod(method) => {
|
||||
write!(f, "unknown authentication method: {}", method)
|
||||
}
|
||||
|
||||
Self::UnknownMessageType(ty) => {
|
||||
write!(f, "unknown protocol message type: '{}' ({})", *ty as char, *ty)
|
||||
}
|
||||
|
||||
Self::UnexpectedMessageType { ty, context } => {
|
||||
write!(f, "unexpected message {:?} '{}' while {}", ty, (*ty as u8 as char), context)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StdError for PgClientError {}
|
||||
|
||||
impl ClientError for PgClientError {}
|
||||
0
sqlx-postgres/src/error/database.rs
Normal file
0
sqlx-postgres/src/error/database.rs
Normal file
3
sqlx-postgres/src/io.rs
Normal file
3
sqlx-postgres/src/io.rs
Normal file
@ -0,0 +1,3 @@
|
||||
mod write;
|
||||
|
||||
pub(crate) use write::PgWriteExt;
|
||||
31
sqlx-postgres/src/io/write.rs
Normal file
31
sqlx-postgres/src/io/write.rs
Normal file
@ -0,0 +1,31 @@
|
||||
use sqlx_core::io::WriteExt;
|
||||
use sqlx_core::Result;
|
||||
|
||||
pub trait PgWriteExt: WriteExt {
|
||||
fn write_len_prefixed<F>(&mut self, f: F) -> Result<()>
|
||||
where
|
||||
F: FnOnce(&mut Vec<u8>) -> Result<()>;
|
||||
}
|
||||
|
||||
impl PgWriteExt for Vec<u8> {
|
||||
/// Writes a length-prefixed message, this is used when encoding nearly
|
||||
/// all messages as postgres wants us to send the length of the
|
||||
/// often-variable-sized messages up front.
|
||||
fn write_len_prefixed<F>(&mut self, f: F) -> Result<()>
|
||||
where
|
||||
F: FnOnce(&mut Vec<u8>) -> Result<()>,
|
||||
{
|
||||
// reserve space to write the prefixed length
|
||||
let offset = self.len();
|
||||
self.extend_from_slice(&[0; 4]);
|
||||
|
||||
// write the main body of the message
|
||||
f(self)?;
|
||||
|
||||
// now calculate the size of what we wrote and set the length value
|
||||
let size = (self.len() - offset) as i32;
|
||||
self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -28,8 +28,8 @@ mod stream;
|
||||
mod column;
|
||||
mod connection;
|
||||
mod database;
|
||||
// mod error;
|
||||
// mod io;
|
||||
mod error;
|
||||
mod io;
|
||||
mod options;
|
||||
mod output;
|
||||
mod protocol;
|
||||
@ -46,9 +46,10 @@ pub mod types;
|
||||
// mod mock;
|
||||
|
||||
pub use column::PgColumn;
|
||||
// pub use connection::PgConnection;
|
||||
pub use connection::PgConnection;
|
||||
pub use database::Postgres;
|
||||
// pub use error::PgDatabaseError;
|
||||
pub use error::PgClientError;
|
||||
pub use options::PgConnectOptions;
|
||||
pub use output::PgOutput;
|
||||
pub use query_result::PgQueryResult;
|
||||
|
||||
@ -1,3 +1,2 @@
|
||||
mod message;
|
||||
|
||||
pub(crate) use message::{BackendMessage, BackendMessageType};
|
||||
pub(crate) mod backend;
|
||||
pub(crate) mod frontend;
|
||||
|
||||
7
sqlx-postgres/src/protocol/backend.rs
Normal file
7
sqlx-postgres/src/protocol/backend.rs
Normal file
@ -0,0 +1,7 @@
|
||||
mod auth;
|
||||
mod message;
|
||||
mod sasl;
|
||||
|
||||
pub(crate) use auth::{Authentication, AuthenticationMd5Password};
|
||||
pub(crate) use message::{BackendMessage, BackendMessageType};
|
||||
pub(crate) use sasl::{AuthenticationSasl, AuthenticationSaslContinue, AuthenticationSaslFinal};
|
||||
61
sqlx-postgres/src/protocol/backend/auth.rs
Normal file
61
sqlx-postgres/src/protocol/backend/auth.rs
Normal file
@ -0,0 +1,61 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
use sqlx_core::io::Deserialize;
|
||||
use sqlx_core::{Error, Result};
|
||||
|
||||
use crate::protocol::backend::{
|
||||
AuthenticationSasl, AuthenticationSaslContinue, AuthenticationSaslFinal,
|
||||
};
|
||||
use crate::PgClientError;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum Authentication {
|
||||
/// The authentication exchange is successfully completed.
|
||||
Ok,
|
||||
|
||||
/// The frontend must now send a PasswordMessage containing the
|
||||
/// password in clear-text form.
|
||||
CleartextPassword,
|
||||
|
||||
/// The frontend must now send a PasswordMessage containing the
|
||||
/// password (with user name) encrypted via MD5.
|
||||
Md5Password(AuthenticationMd5Password),
|
||||
|
||||
/// The frontend must now initiate a SASL negotiation,
|
||||
/// using one of the SASL mechanisms listed in the message.
|
||||
Sasl(AuthenticationSasl),
|
||||
|
||||
/// This message contains challenge data from the previous step of
|
||||
/// SASL negotiation.
|
||||
SaslContinue(AuthenticationSaslContinue),
|
||||
|
||||
/// SASL authentication has completed with additional mechanism-specific
|
||||
/// data for the client.
|
||||
SaslFinal(AuthenticationSaslFinal),
|
||||
}
|
||||
|
||||
impl Deserialize<'_> for Authentication {
|
||||
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
match buf.get_u32() {
|
||||
0 => Ok(Self::Ok),
|
||||
3 => Ok(Self::CleartextPassword),
|
||||
|
||||
5 => {
|
||||
let mut salt = [0_u8; 4];
|
||||
buf.copy_to_slice(&mut salt);
|
||||
|
||||
Ok(Self::Md5Password(AuthenticationMd5Password { salt }))
|
||||
}
|
||||
|
||||
10 => AuthenticationSasl::deserialize(buf).map(Self::Sasl),
|
||||
11 => AuthenticationSaslContinue::deserialize(buf).map(Self::SaslContinue),
|
||||
12 => AuthenticationSaslFinal::deserialize(buf).map(Self::SaslFinal),
|
||||
|
||||
ty => Err(Error::client(PgClientError::UnknownAuthenticationMethod(ty))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AuthenticationMd5Password {
|
||||
pub(crate) salt: [u8; 4],
|
||||
}
|
||||
@ -1,6 +1,7 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::PgClientError;
|
||||
use bytes::Bytes;
|
||||
use sqlx_core::io::Deserialize;
|
||||
use sqlx_core::{Error, Result};
|
||||
@ -66,7 +67,7 @@ impl TryFrom<u8> for BackendMessageType {
|
||||
b'c' => Self::CopyDone,
|
||||
|
||||
_ => {
|
||||
todo!("protocol unexpected data error")
|
||||
return Err(Error::client(PgClientError::UnknownMessageType(ty)));
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -74,7 +75,7 @@ impl TryFrom<u8> for BackendMessageType {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct BackendMessage {
|
||||
pub(crate) r#type: BackendMessageType,
|
||||
pub(crate) ty: BackendMessageType,
|
||||
pub(crate) contents: Bytes,
|
||||
}
|
||||
|
||||
88
sqlx-postgres/src/protocol/backend/sasl.rs
Normal file
88
sqlx-postgres/src/protocol/backend/sasl.rs
Normal file
@ -0,0 +1,88 @@
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use bytes::Bytes;
|
||||
use bytestring::ByteString;
|
||||
use sqlx_core::io::Deserialize;
|
||||
use sqlx_core::Result;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AuthenticationSasl(Bytes);
|
||||
|
||||
impl Deserialize<'_> for AuthenticationSasl {
|
||||
fn deserialize_with(buf: Bytes, _: ()) -> Result<Self> {
|
||||
Ok(Self(buf))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AuthenticationSaslContinue {
|
||||
pub(crate) salt: Box<[u8]>,
|
||||
pub(crate) iterations: u32,
|
||||
pub(crate) nonce: ByteString,
|
||||
pub(crate) message: ByteString,
|
||||
}
|
||||
|
||||
impl Deserialize<'_> for AuthenticationSaslContinue {
|
||||
fn deserialize_with(buf: Bytes, _: ()) -> Result<Self> {
|
||||
let mut iterations: u32 = 4096;
|
||||
let mut salt = Vec::new();
|
||||
let mut nonce = Bytes::new();
|
||||
|
||||
// [Example]
|
||||
// r=/z+giZiTxAH7r8sNAeHr7cvpqV3uo7G/bJBIJO3pjVM7t3ng,s=4UV68bIkC8f9/X8xH7aPhg==,i=4096
|
||||
|
||||
for item in buf.split(|b| *b == b',') {
|
||||
let key = item[0];
|
||||
let value = &item[2..];
|
||||
|
||||
match key {
|
||||
b'r' => {
|
||||
nonce = buf.slice_ref(value);
|
||||
}
|
||||
|
||||
b'i' => {
|
||||
iterations = atoi::atoi(value).unwrap_or(4096);
|
||||
}
|
||||
|
||||
b's' => {
|
||||
// FIXME: raise proper protocol errors
|
||||
salt = base64::decode(value).unwrap();
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
iterations,
|
||||
salt: salt.into_boxed_slice(),
|
||||
|
||||
// FIXME: raise proper protocol errors
|
||||
nonce: ByteString::try_from(nonce).unwrap(),
|
||||
message: ByteString::try_from(buf).unwrap(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AuthenticationSaslFinal {
|
||||
pub(crate) verifier: Box<[u8]>,
|
||||
}
|
||||
|
||||
impl Deserialize<'_> for AuthenticationSaslFinal {
|
||||
fn deserialize_with(buf: Bytes, _: ()) -> Result<Self> {
|
||||
let mut verifier = Vec::new();
|
||||
|
||||
for item in buf.split(|b| *b == b',') {
|
||||
let key = item[0];
|
||||
let value = &item[2..];
|
||||
|
||||
if let b'v' = key {
|
||||
// FIXME: raise proper protocol errors
|
||||
verifier = base64::decode(value).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self { verifier: verifier.into_boxed_slice() })
|
||||
}
|
||||
}
|
||||
5
sqlx-postgres/src/protocol/frontend.rs
Normal file
5
sqlx-postgres/src/protocol/frontend.rs
Normal file
@ -0,0 +1,5 @@
|
||||
mod startup;
|
||||
mod terminate;
|
||||
|
||||
pub(crate) use startup::Startup;
|
||||
pub(crate) use terminate::Terminate;
|
||||
55
sqlx-postgres/src/protocol/frontend/startup.rs
Normal file
55
sqlx-postgres/src/protocol/frontend/startup.rs
Normal file
@ -0,0 +1,55 @@
|
||||
use sqlx_core::io::Serialize;
|
||||
use sqlx_core::Result;
|
||||
|
||||
use crate::io::PgWriteExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Startup<'a>(pub(crate) &'a [(&'a str, Option<&'a str>)]);
|
||||
|
||||
impl Serialize<'_> for Startup<'_> {
|
||||
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
|
||||
buf.write_len_prefixed(|buf| {
|
||||
// The protocol version number. The most significant 16 bits are the
|
||||
// major version number (3 for the protocol described here). The least
|
||||
// significant 16 bits are the minor version number (0
|
||||
// for the protocol described here)
|
||||
buf.extend(&196_608_i32.to_be_bytes());
|
||||
|
||||
// For each startup parameter, write the name and value
|
||||
// as NUL-terminated strings
|
||||
for (name, value) in self.0 {
|
||||
if let Some(value) = value {
|
||||
write_startup_param(buf, name, value);
|
||||
}
|
||||
}
|
||||
|
||||
// Followed by a trailing NUL
|
||||
buf.push(0);
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn write_startup_param(buf: &mut Vec<u8>, name: &str, value: &str) {
|
||||
buf.reserve(name.len() + value.len() + 2);
|
||||
buf.extend(name.as_bytes());
|
||||
buf.push(0);
|
||||
buf.extend(value.as_bytes());
|
||||
buf.push(0);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Serialize, Startup};
|
||||
|
||||
#[test]
|
||||
fn should_encode_startup() {
|
||||
let mut buf = Vec::new();
|
||||
let m = Startup(&[("user", Some("postgres")), ("database", Some("postgres"))]);
|
||||
|
||||
m.serialize(&mut buf).unwrap();
|
||||
|
||||
assert_eq!(buf, b"\0\0\0)\0\x03\0\0user\0postgres\0database\0postgres\0\0");
|
||||
}
|
||||
}
|
||||
32
sqlx-postgres/src/protocol/frontend/terminate.rs
Normal file
32
sqlx-postgres/src/protocol/frontend/terminate.rs
Normal file
@ -0,0 +1,32 @@
|
||||
use sqlx_core::io::Serialize;
|
||||
use sqlx_core::Result;
|
||||
|
||||
/// On receipt of this message, the backend closes the connection
|
||||
/// and terminates.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Terminate;
|
||||
|
||||
impl Serialize<'_> for Terminate {
|
||||
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
|
||||
buf.push(b'X');
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use sqlx_core::io::Serialize;
|
||||
|
||||
use super::Terminate;
|
||||
|
||||
#[test]
|
||||
fn should_serialize() -> anyhow::Result<()> {
|
||||
let mut buf = Vec::new();
|
||||
Terminate.serialize(&mut buf)?;
|
||||
|
||||
assert_eq!(&buf, &[b'X']);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -3,11 +3,12 @@ use std::fmt::Debug;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use bytes::Buf;
|
||||
use sqlx_core::io::{BufStream, Serialize};
|
||||
use sqlx_core::io::{BufStream, Serialize, Stream};
|
||||
use sqlx_core::net::Stream as NetStream;
|
||||
use sqlx_core::{Result, Runtime};
|
||||
|
||||
use crate::protocol::{BackendMessage, BackendMessageType};
|
||||
use crate::protocol::backend::{BackendMessage, BackendMessageType};
|
||||
use crate::protocol::frontend::Terminate;
|
||||
|
||||
/// Reads and writes messages to and from the PostgreSQL database server.
|
||||
///
|
||||
@ -38,6 +39,10 @@ impl<Rt: Runtime> PgStream<Rt> {
|
||||
where
|
||||
T: Serialize<'ser> + Debug,
|
||||
{
|
||||
log::trace!("write > {:?}", message);
|
||||
|
||||
message.serialize(self.stream.buffer())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -86,7 +91,7 @@ impl<Rt: Runtime> PgStream<Rt> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
_ => Ok(Some(BackendMessage { contents, r#type: ty })),
|
||||
_ => Ok(Some(BackendMessage { contents, ty })),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -159,3 +164,29 @@ macro_rules! read_message {
|
||||
$stream.read_message_async().await?
|
||||
};
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> PgStream<Rt> {
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) async fn close_async(&mut self) -> Result<()>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
self.write_message(&Terminate)?;
|
||||
self.flush_async().await?;
|
||||
self.shutdown_async().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
pub(crate) fn close_blocking(&mut self) -> Result<()>
|
||||
where
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
self.write_message(&Terminate)?;
|
||||
self.flush()?;
|
||||
self.shutdown()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user