mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-01-20 15:46:30 +00:00
227 lines
6.9 KiB
Rust
227 lines
6.9 KiB
Rust
use std::net::Shutdown;
|
|
|
|
use byteorder::{ByteOrder, LittleEndian};
|
|
|
|
use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream};
|
|
use crate::mysql::protocol::{Capabilities, Encode, EofPacket, ErrPacket, OkPacket};
|
|
|
|
use crate::mysql::MySqlError;
|
|
use crate::url::Url;
|
|
|
|
// Size before a packet is split
|
|
const MAX_PACKET_SIZE: u32 = 1024;
|
|
|
|
pub(crate) struct MySqlStream {
|
|
pub(super) stream: BufStream<MaybeTlsStream>,
|
|
|
|
// Is the stream ready to send commands
|
|
// Put another way, are we still expecting an EOF or OK packet to terminate
|
|
pub(super) is_ready: bool,
|
|
|
|
// Active capabilities
|
|
pub(super) capabilities: Capabilities,
|
|
|
|
// Packets in a command sequence have an incrementing sequence number
|
|
// This number must be 0 at the start of each command
|
|
pub(super) seq_no: u8,
|
|
|
|
// Packets are buffered into a second buffer from the stream
|
|
// as we may have compressed or split packets to figure out before
|
|
// decoding
|
|
packet_buf: Vec<u8>,
|
|
packet_len: usize,
|
|
}
|
|
|
|
impl MySqlStream {
|
|
pub(super) async fn new(url: &Url) -> crate::Result<Self> {
|
|
let stream = MaybeTlsStream::connect(&url, 3306).await?;
|
|
|
|
let mut capabilities = Capabilities::PROTOCOL_41
|
|
| Capabilities::IGNORE_SPACE
|
|
| Capabilities::DEPRECATE_EOF
|
|
| Capabilities::FOUND_ROWS
|
|
| Capabilities::TRANSACTIONS
|
|
| Capabilities::SECURE_CONNECTION
|
|
| Capabilities::PLUGIN_AUTH_LENENC_DATA
|
|
| Capabilities::MULTI_STATEMENTS
|
|
| Capabilities::MULTI_RESULTS
|
|
| Capabilities::PLUGIN_AUTH;
|
|
|
|
if url.database().is_some() {
|
|
capabilities |= Capabilities::CONNECT_WITH_DB;
|
|
}
|
|
|
|
if cfg!(feature = "tls") {
|
|
capabilities |= Capabilities::SSL;
|
|
}
|
|
|
|
Ok(Self {
|
|
capabilities,
|
|
stream: BufStream::new(stream),
|
|
packet_buf: Vec::with_capacity(MAX_PACKET_SIZE as usize),
|
|
packet_len: 0,
|
|
seq_no: 0,
|
|
is_ready: true,
|
|
})
|
|
}
|
|
|
|
pub(super) fn is_tls(&self) -> bool {
|
|
self.stream.is_tls()
|
|
}
|
|
|
|
pub(super) fn shutdown(&self) -> crate::Result<()> {
|
|
Ok(self.stream.shutdown(Shutdown::Both)?)
|
|
}
|
|
|
|
#[inline]
|
|
pub(super) async fn send<T>(&mut self, packet: T, initial: bool) -> crate::Result<()>
|
|
where
|
|
T: Encode + std::fmt::Debug,
|
|
{
|
|
if initial {
|
|
self.seq_no = 0;
|
|
}
|
|
|
|
self.write(packet);
|
|
self.flush().await
|
|
}
|
|
|
|
#[inline]
|
|
pub(super) async fn flush(&mut self) -> crate::Result<()> {
|
|
Ok(self.stream.flush().await?)
|
|
}
|
|
|
|
/// Write the packet to the buffered stream ( do not send to the server )
|
|
pub(super) fn write<T>(&mut self, packet: T)
|
|
where
|
|
T: Encode,
|
|
{
|
|
let buf = self.stream.buffer_mut();
|
|
|
|
// Allocate room for the header that we write after the packet;
|
|
// so, we can get an accurate and cheap measure of packet length
|
|
|
|
let header_offset = buf.len();
|
|
buf.advance(4);
|
|
|
|
packet.encode(buf, self.capabilities);
|
|
|
|
// Determine length of encoded packet
|
|
// and write to allocated header
|
|
|
|
let len = buf.len() - header_offset - 4;
|
|
let mut header = &mut buf[header_offset..];
|
|
|
|
LittleEndian::write_u32(&mut header, len as u32);
|
|
|
|
// Take the last sequence number received, if any, and increment by 1
|
|
// If there was no sequence number, we only increment if we split packets
|
|
header[3] = self.seq_no;
|
|
self.seq_no = self.seq_no.wrapping_add(1);
|
|
}
|
|
|
|
#[inline]
|
|
pub(super) async fn receive(&mut self) -> crate::Result<&[u8]> {
|
|
self.read().await?;
|
|
|
|
Ok(self.packet())
|
|
}
|
|
|
|
pub(super) async fn read(&mut self) -> crate::Result<()> {
|
|
self.packet_buf.clear();
|
|
self.packet_len = 0;
|
|
|
|
// Read the packet header which contains the length and the sequence number
|
|
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
|
|
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
|
|
let mut header = self.stream.peek(4_usize).await?;
|
|
|
|
self.packet_len = header.get_uint::<LittleEndian>(3)? as usize;
|
|
self.seq_no = header.get_u8()?.wrapping_add(1);
|
|
|
|
self.stream.consume(4);
|
|
|
|
// Read the packet body and copy it into our internal buf
|
|
// We must have a separate buffer around the stream as we can't operate directly
|
|
// on bytes returned from the stream. We have various kinds of payload manipulation
|
|
// that must be handled before decoding.
|
|
let payload = self.stream.peek(self.packet_len).await?;
|
|
|
|
self.packet_buf.reserve(payload.len());
|
|
self.packet_buf.extend_from_slice(payload);
|
|
|
|
self.stream.consume(self.packet_len);
|
|
|
|
// TODO: Implement packet compression
|
|
// TODO: Implement packet joining
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Returns a reference to the most recently received packet data.
|
|
/// A call to `read` invalidates this buffer.
|
|
#[inline]
|
|
pub(super) fn packet(&self) -> &[u8] {
|
|
&self.packet_buf[..self.packet_len]
|
|
}
|
|
}
|
|
|
|
impl MySqlStream {
|
|
pub(crate) async fn maybe_receive_eof(&mut self) -> crate::Result<()> {
|
|
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
|
|
let _eof = EofPacket::read(self.receive().await?)?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub(crate) fn maybe_handle_eof(&mut self) -> crate::Result<Option<EofPacket>> {
|
|
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) && self.packet()[0] == 0xFE {
|
|
Ok(Some(EofPacket::read(self.packet())?))
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
}
|
|
|
|
pub(crate) fn handle_unexpected<T>(&mut self) -> crate::Result<T> {
|
|
Err(protocol_err!("unexpected packet identifier 0x{:X?}", self.packet()[0]).into())
|
|
}
|
|
|
|
pub(crate) fn handle_err<T>(&mut self) -> crate::Result<T> {
|
|
self.is_ready = true;
|
|
Err(MySqlError(ErrPacket::read(self.packet(), self.capabilities)?).into())
|
|
}
|
|
|
|
pub(crate) fn handle_ok(&mut self) -> crate::Result<OkPacket> {
|
|
self.is_ready = true;
|
|
OkPacket::read(self.packet())
|
|
}
|
|
|
|
pub(crate) async fn wait_until_ready(&mut self) -> crate::Result<()> {
|
|
if !self.is_ready {
|
|
loop {
|
|
let packet_id = self.receive().await?[0];
|
|
match packet_id {
|
|
0xFE if self.packet().len() < 0xFF_FF_FF => {
|
|
// OK or EOF packet
|
|
self.is_ready = true;
|
|
break;
|
|
}
|
|
|
|
0xFF => {
|
|
// ERR packet
|
|
self.is_ready = true;
|
|
return self.handle_err();
|
|
}
|
|
|
|
_ => {
|
|
// Something else; skip
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|