mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-02 23:35:20 +00:00
141 lines
3.9 KiB
Rust
141 lines
3.9 KiB
Rust
use crate::protocol::{
|
|
client::{Serialize, Terminate},
|
|
server::Message as ServerMessage,
|
|
};
|
|
use bytes::BytesMut;
|
|
use futures::{
|
|
channel::mpsc,
|
|
io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
|
|
SinkExt, StreamExt,
|
|
};
|
|
use mason_core::ConnectOptions;
|
|
use runtime::{net::TcpStream, task::JoinHandle};
|
|
use std::io;
|
|
|
|
mod establish;
|
|
|
|
pub struct Connection {
|
|
writer: WriteHalf<TcpStream>,
|
|
incoming: mpsc::UnboundedReceiver<ServerMessage>,
|
|
|
|
// Buffer used when serializing outgoing messages
|
|
wbuf: Vec<u8>,
|
|
|
|
// Handle to coroutine reading messages from the stream
|
|
receiver: JoinHandle<io::Result<()>>,
|
|
|
|
// Process ID of the Backend
|
|
process_id: i32,
|
|
|
|
// Backend-unique key to use to send a cancel query message to the server
|
|
secret_key: i32,
|
|
}
|
|
|
|
impl Connection {
|
|
pub async fn establish(options: ConnectOptions<'_>) -> io::Result<Self> {
|
|
let stream = TcpStream::connect((options.host, options.port)).await?;
|
|
let (reader, writer) = stream.split();
|
|
let (tx, rx) = mpsc::unbounded();
|
|
let receiver = runtime::spawn(receiver(reader, tx));
|
|
let mut conn = Self {
|
|
wbuf: Vec::with_capacity(1024),
|
|
writer,
|
|
receiver,
|
|
incoming: rx,
|
|
process_id: -1,
|
|
secret_key: -1,
|
|
};
|
|
|
|
establish::establish(&mut conn, options).await?;
|
|
|
|
Ok(conn)
|
|
}
|
|
|
|
pub async fn close(mut self) -> io::Result<()> {
|
|
self.send(Terminate).await?;
|
|
self.writer.close().await?;
|
|
self.receiver.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
// Send client-serializable message to the server
|
|
async fn send<S>(&mut self, message: S) -> io::Result<()>
|
|
where
|
|
S: Serialize,
|
|
{
|
|
self.wbuf.clear();
|
|
|
|
message.serialize(&mut self.wbuf);
|
|
|
|
self.writer.write_all(&self.wbuf).await?;
|
|
self.writer.flush().await?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
async fn receiver(
|
|
mut reader: ReadHalf<TcpStream>,
|
|
mut sender: mpsc::UnboundedSender<ServerMessage>,
|
|
) -> io::Result<()> {
|
|
let mut rbuf = BytesMut::with_capacity(0);
|
|
let mut len = 0;
|
|
|
|
loop {
|
|
// This uses an adaptive system to extend the vector when it fills. We want to
|
|
// avoid paying to allocate and zero a huge chunk of memory if the reader only
|
|
// has 4 bytes while still making large reads if the reader does have a ton
|
|
// of data to return.
|
|
|
|
// See: https://github.com/rust-lang-nursery/futures-rs/blob/master/futures-util/src/io/read_to_end.rs#L50-L54
|
|
|
|
if len == rbuf.len() {
|
|
rbuf.reserve(32);
|
|
|
|
unsafe {
|
|
// Set length to the capacity and efficiently
|
|
// zero-out the memory
|
|
rbuf.set_len(rbuf.capacity());
|
|
reader.initializer().initialize(&mut rbuf[len..]);
|
|
}
|
|
}
|
|
|
|
// TODO: Need a select! on a channel that I can trigger to cancel this
|
|
let cnt = reader.read(&mut rbuf[len..]).await?;
|
|
|
|
if cnt > 0 {
|
|
len += cnt;
|
|
} else {
|
|
// Read 0 bytes from the server; end-of-stream
|
|
break;
|
|
}
|
|
|
|
while len > 0 {
|
|
let size = rbuf.len();
|
|
let message = ServerMessage::deserialize(&mut rbuf)?;
|
|
len -= size - rbuf.len();
|
|
|
|
// TODO: Some messages should be kept behind here
|
|
match message {
|
|
Some(ServerMessage::ParameterStatus(body)) => {
|
|
log::debug!("parameter {} = {}", body.name()?, body.value()?);
|
|
}
|
|
|
|
Some(message) => {
|
|
// TODO: Handle this error?
|
|
sender.send(message).await.unwrap();
|
|
}
|
|
|
|
None => {
|
|
// Did not receive enough bytes to
|
|
// deserialize a complete message
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|