mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 21:00:54 +00:00
Adjust mariadb for async-trait usage
This commit is contained in:
parent
47b06edad1
commit
fb877fee28
@ -1,147 +0,0 @@
|
||||
use super::MariaDbRawConnection;
|
||||
use crate::mariadb::protocol::{
|
||||
Capabilities, ComStmtExec, DeContext, Decode, EofPacket, ErrPacket, HandshakeResponsePacket,
|
||||
InitialHandshakePacket, OkPacket, ProtocolType, StmtExecFlag,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use failure::{err_msg, Error};
|
||||
use std::ops::BitAnd;
|
||||
use url::Url;
|
||||
|
||||
pub async fn establish(conn: &mut MariaDbRawConnection, url: Url) -> Result<(), Error> {
|
||||
let buf = conn.stream.next_packet().await?;
|
||||
let mut de_ctx = DeContext::new(&mut conn.context, buf);
|
||||
let initial = InitialHandshakePacket::decode(&mut de_ctx)?;
|
||||
|
||||
de_ctx.ctx.capabilities = de_ctx.ctx.capabilities.bitand(initial.capabilities);
|
||||
|
||||
let handshake: HandshakeResponsePacket = HandshakeResponsePacket {
|
||||
// Minimum client capabilities required to establish connection.bak
|
||||
capabilities: de_ctx.ctx.capabilities,
|
||||
max_packet_size: 1024,
|
||||
extended_capabilities: Some(Capabilities::from_bits_truncate(0)),
|
||||
username: url.username(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
conn.send(handshake).await?;
|
||||
|
||||
let mut ctx = DeContext::new(&mut conn.context, conn.stream.next_packet().await?);
|
||||
|
||||
match ctx.decoder.peek_tag() {
|
||||
0xFF => {
|
||||
return Err(ErrPacket::decode(&mut ctx)?.into());
|
||||
}
|
||||
0x00 => {
|
||||
OkPacket::decode(&mut ctx)?;
|
||||
}
|
||||
_ => failure::bail!("Did not receive an ErrPacket nor OkPacket when one is expected"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::mariadb::{ComStmtFetch, ComStmtPrepareResp, FieldType, ResultSet};
|
||||
use failure::Error;
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_can_connect() -> Result<(), Error> {
|
||||
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_can_ping() -> Result<(), Error> {
|
||||
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
|
||||
|
||||
conn.ping().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_can_select_db() -> Result<(), Error> {
|
||||
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
|
||||
|
||||
conn.select_db("test").await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_can_query() -> Result<(), Error> {
|
||||
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
|
||||
|
||||
conn.select_db("test").await?;
|
||||
|
||||
conn.query("SELECT * FROM users").await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_can_prepare() -> Result<(), Error> {
|
||||
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
|
||||
|
||||
conn.select_db("test").await?;
|
||||
|
||||
conn.prepare("SELECT * FROM users WHERE username = ?")
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_can_execute_prepared() -> Result<(), Error> {
|
||||
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
|
||||
|
||||
conn.select_db("test").await?;
|
||||
|
||||
let mut prepared = conn
|
||||
.prepare("SELECT id FROM users WHERE username=?")
|
||||
.await?;
|
||||
|
||||
let exec = ComStmtExec {
|
||||
stmt_id: prepared.ok.stmt_id,
|
||||
flags: StmtExecFlag::NO_CURSOR,
|
||||
params: Some(vec![Some(Bytes::from_static(b"josh"))]),
|
||||
param_defs: prepared.param_defs,
|
||||
};
|
||||
|
||||
conn.send(exec).await?;
|
||||
|
||||
let mut ctx = DeContext::with_stream(&mut conn.context, &mut conn.stream);
|
||||
ctx.next_packet().await?;
|
||||
ctx.columns = Some(prepared.ok.columns as u64);
|
||||
ctx.column_defs = prepared.res_columns;
|
||||
|
||||
println!("{:?}", ctx.columns);
|
||||
println!("{:?}", ctx.column_defs);
|
||||
|
||||
match ctx.decoder.peek_tag() {
|
||||
0xFF => {
|
||||
ErrPacket::decode(&mut ctx)?;
|
||||
}
|
||||
0x00 => {
|
||||
OkPacket::decode(&mut ctx)?;
|
||||
}
|
||||
_ => {
|
||||
ResultSet::deserialize(ctx, ProtocolType::Binary).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_does_not_connect() -> Result<(), Error> {
|
||||
match MariaDbRawConnection::establish(&"mariadb//roote@127.0.0.1:3306").await {
|
||||
Ok(_) => Err(err_msg("Bad username still worked?")),
|
||||
Err(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,9 +0,0 @@
|
||||
use crate::mariadb::MariaDbRawConnection;
|
||||
use std::io;
|
||||
|
||||
pub async fn execute(conn: &mut MariaDbRawConnection) -> io::Result<u64> {
|
||||
conn.flush().await?;
|
||||
|
||||
let mut rows: u64 = 0;
|
||||
while let Some(message) = conn.receive().await? {}
|
||||
}
|
||||
@ -1,267 +0,0 @@
|
||||
use crate::{
|
||||
error::ErrorKind,
|
||||
mariadb::protocol::{
|
||||
encode, Capabilities, ComInitDb, ComPing, ComQuery, ComQuit, ComStmtPrepare,
|
||||
ComStmtPrepareResp, DeContext, Decode, Decoder, Encode, ErrPacket, OkPacket, PacketHeader,
|
||||
ProtocolType, ResultSet, ServerStatusFlag,
|
||||
},
|
||||
};
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use core::convert::TryFrom;
|
||||
use failure::Error;
|
||||
use futures_core::future::BoxFuture;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWrite, AsyncWriteExt},
|
||||
net::TcpStream,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
mod establish;
|
||||
mod execute;
|
||||
|
||||
pub struct MariaDbRawConnection {
|
||||
pub stream: TcpStream,
|
||||
|
||||
// Buffer used when serializing outgoing messages
|
||||
pub wbuf: Vec<u8>,
|
||||
|
||||
pub rbuf: BytesMut,
|
||||
pub read_index: usize,
|
||||
|
||||
// Context for the connection.bak
|
||||
// Explicitly declared to easily send to deserializers
|
||||
pub context: ConnContext,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConnContext {
|
||||
// MariaDB Connection ID
|
||||
pub connection_id: i32,
|
||||
|
||||
// Sequence Number
|
||||
pub seq_no: u8,
|
||||
|
||||
// Last sequence number return by MariaDB
|
||||
pub last_seq_no: u8,
|
||||
|
||||
// Server Capabilities
|
||||
pub capabilities: Capabilities,
|
||||
|
||||
// Server status
|
||||
pub status: ServerStatusFlag,
|
||||
}
|
||||
|
||||
impl ConnContext {
|
||||
#[cfg(test)]
|
||||
pub fn new() -> Self {
|
||||
ConnContext {
|
||||
connection_id: 0,
|
||||
seq_no: 2,
|
||||
last_seq_no: 0,
|
||||
capabilities: Capabilities::CLIENT_PROTOCOL_41,
|
||||
status: ServerStatusFlag::SERVER_STATUS_IN_TRANS,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn with_eof() -> Self {
|
||||
ConnContext {
|
||||
connection_id: 0,
|
||||
seq_no: 2,
|
||||
last_seq_no: 0,
|
||||
capabilities: Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CLIENT_DEPRECATE_EOF,
|
||||
status: ServerStatusFlag::SERVER_STATUS_IN_TRANS,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MariaDbRawConnection {
|
||||
pub async fn establish(url: &str) -> Result<Self, Error> {
|
||||
// TODO: Handle errors
|
||||
let url = Url::parse(url).map_err(ErrorKind::UrlParse)?;
|
||||
println!("{:?}", url);
|
||||
|
||||
let host = url.host_str().unwrap_or("localhost");
|
||||
let port = url.port().unwrap_or(3306);
|
||||
|
||||
// FIXME: handle errors
|
||||
let host: IpAddr = host.parse().unwrap();
|
||||
let addr: SocketAddr = (host, port).into();
|
||||
let stream = TcpStream::connect(&addr).await?;
|
||||
let mut conn: MariaDbRawConnection = Self {
|
||||
stream,
|
||||
wbuf: Vec::with_capacity(1024),
|
||||
rbuf: BytesMut::with_capacity(8 * 1024),
|
||||
read_index: 0,
|
||||
context: ConnContext {
|
||||
connection_id: -1,
|
||||
seq_no: 1,
|
||||
last_seq_no: 0,
|
||||
capabilities: Capabilities::CLIENT_PROTOCOL_41,
|
||||
status: ServerStatusFlag::default(),
|
||||
},
|
||||
};
|
||||
|
||||
establish::establish(&mut conn, url).await?;
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
// pub async fn send<S>(&mut self, message: S) -> Result<(), Error>
|
||||
// where
|
||||
// S: Encode,
|
||||
// {
|
||||
// self.wbuf.clear();
|
||||
// message.encode(&mut self.wbuf, &mut self.context)?;
|
||||
// self.stream.inner.write_all(&self.wbuf).await?;
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
pub fn write(&mut self, message: impl Encode) {
|
||||
message.encode(&mut self.wbuf);
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<(), Error> {
|
||||
self.stream.flush().await?;
|
||||
self.stream.clear().clear();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn quit(&mut self) -> Result<(), Error> {
|
||||
self.write(ComQuit()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn ping(&mut self) -> Result<(), Error> {
|
||||
self.write(ComPing()).await?;
|
||||
|
||||
// Ping response must be an OkPacket
|
||||
OkPacket::decode(&mut DeContext::new(
|
||||
&mut self.context,
|
||||
self.stream.next_packet().await?,
|
||||
))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn ping(&mut self) -> Result<(), Error> {
|
||||
// Send the ping command and wait for (and drop) an OK packet
|
||||
// SEND ================
|
||||
self.last_seq_no = None;
|
||||
self.write(ComPing);
|
||||
self.stream.flush().await?;
|
||||
// =====================
|
||||
|
||||
let _ = decode_ok_or_err(self.receive().await?)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn prepare(&mut self, query: &str) -> Result<ComStmtPrepareResp, Error> {
|
||||
self.write(ComStmtPrepare {
|
||||
statement: Bytes::from(query),
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut ctx = DeContext::with_stream(&mut self.context, &mut self.stream);
|
||||
ctx.next_packet().await?;
|
||||
Ok(ComStmtPrepareResp::deserialize(ctx).await?)
|
||||
}
|
||||
|
||||
pub async fn next_packet(&mut self) -> Result<Bytes, Error> {
|
||||
let mut packet_headers: Vec<PacketHeader> = Vec::new();
|
||||
|
||||
loop {
|
||||
println!("BUF: {:?}: ", self.rbuf);
|
||||
// If we don't have a packet header or the last packet header had a length of
|
||||
// 0xFF_FF_FF (the max possible length); then we must continue receiving packets
|
||||
// because the entire message hasn't been received.
|
||||
// After this operation we know that packet_headers.last() *SHOULD* always return valid data,
|
||||
// so the the use of packet_headers.last().unwrap() is allowed.
|
||||
// TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions.
|
||||
if let Some(packet_header) = packet_headers.last() {
|
||||
if packet_header.length as usize == encode::U24_MAX {
|
||||
packet_headers.push(PacketHeader::try_from(&self.rbuf[self.read_index..])?);
|
||||
}
|
||||
} else if self.rbuf.len() > 4 {
|
||||
match PacketHeader::try_from(&self.rbuf[0..]) {
|
||||
Ok(v) => packet_headers.push(v),
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(packet_header) = packet_headers.last() {
|
||||
if packet_header.combined_length() > self.rbuf.len() {
|
||||
unsafe {
|
||||
self.rbuf
|
||||
.reserve(packet_header.combined_length() - self.rbuf.len());
|
||||
}
|
||||
}
|
||||
} else if self.rbuf.len() == self.read_index {
|
||||
unsafe {
|
||||
self.rbuf.reserve(32);
|
||||
}
|
||||
}
|
||||
unsafe {
|
||||
self.rbuf.set_len(self.rbuf.capacity());
|
||||
}
|
||||
|
||||
// If we have a packet_header and the amount of currently read bytes (len) is less than
|
||||
// the specified length inside packet_header, then we can continue reading to self.rbuf.
|
||||
// Else if the total number of bytes read is equal to packet_header then we will
|
||||
// return self.rbuf from 0 to self.read_index as it should contain the entire packet.
|
||||
let bytes_read;
|
||||
|
||||
if let Some(packet_header) = packet_headers.last() {
|
||||
if packet_header.combined_length() > self.read_index {
|
||||
bytes_read = self.stream.read(&mut self.rbuf[self.read_index..]).await?;
|
||||
} else {
|
||||
// Get the packet from the rbuffer, reset read_index, and return packet
|
||||
let packet = self.rbuf.split_to(packet_header.combined_length()).freeze();
|
||||
self.read_index -= packet.len();
|
||||
return Ok(packet);
|
||||
}
|
||||
} else {
|
||||
bytes_read = self.stream.read(&mut self.rbuf[self.read_index..]).await?;
|
||||
}
|
||||
|
||||
if bytes_read > 0 {
|
||||
self.read_index += bytes_read;
|
||||
// If we have read less than 4 bytes, and we don't already have a packet_header
|
||||
// we must try to read again. The packet_header is always present and is 4 bytes long.
|
||||
if bytes_read < 4 && packet_headers.len() == 0 {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// Read 0 bytes from the server; end-of-stream
|
||||
panic!("Cannot read 0 bytes from stream");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// impl RawConnection for MariaDbRawConnection {
|
||||
// type Backend = MariaDb;
|
||||
|
||||
// #[inline]
|
||||
// fn establish(url: &str) -> BoxFuture<std::io::Result<Self>> {
|
||||
// Box::pin(MariaDbRawConnection::establish(url))
|
||||
// }
|
||||
|
||||
// #[inline]
|
||||
// fn finalize<'c>(&'c mut self) -> BoxFuture<'c, std::io::Result<()>> {
|
||||
// Box::pin(self.finalize())
|
||||
// }
|
||||
|
||||
// fn execute<'c, 'q, Q: 'q>(&'c mut self, query: Q) -> BoxFuture<'c, std::io::Result<()>>
|
||||
// where
|
||||
// Q: RawQuery<'q, Backend = Self::Backend>,
|
||||
// {
|
||||
// query.finish(self);
|
||||
// Box::pin(execute::execute(self))
|
||||
// }
|
||||
// }
|
||||
@ -13,6 +13,7 @@ use crate::{
|
||||
},
|
||||
Backend, Error, Result,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use futures_core::{future::BoxFuture, stream::BoxStream};
|
||||
use std::{
|
||||
@ -56,7 +57,7 @@ impl MariaDbRawConnection {
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
pub async fn close(&mut self) -> Result<()> {
|
||||
pub async fn close(mut self) -> Result<()> {
|
||||
// Send the quit command
|
||||
|
||||
self.start_sequence();
|
||||
@ -297,55 +298,50 @@ enum ExecResult {
|
||||
Rows(Vec<ColumnDefinitionPacket>),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RawConnection for MariaDbRawConnection {
|
||||
type Backend = MariaDb;
|
||||
|
||||
fn establish(url: &str) -> BoxFuture<Result<Self>>
|
||||
async fn establish(url: &str) -> crate::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Box::pin(MariaDbRawConnection::establish(url))
|
||||
MariaDbRawConnection::establish(url).await
|
||||
}
|
||||
|
||||
fn close(&mut self) -> BoxFuture<'_, Result<()>> {
|
||||
Box::pin(self.close())
|
||||
async fn close(mut self) -> crate::Result<()> {
|
||||
self.close().await
|
||||
}
|
||||
|
||||
fn ping(&mut self) -> BoxFuture<'_, Result<()>> {
|
||||
Box::pin(self.ping())
|
||||
async fn ping(&mut self) -> crate::Result<()> {
|
||||
self.ping().await
|
||||
}
|
||||
|
||||
fn execute<'c>(
|
||||
&'c mut self,
|
||||
query: &str,
|
||||
params: MariaDbQueryParameters,
|
||||
) -> BoxFuture<'c, Result<u64>> {
|
||||
async fn execute(&mut self, query: &str, params: MariaDbQueryParameters) -> crate::Result<u64> {
|
||||
// Write prepare statement to buffer
|
||||
self.start_sequence();
|
||||
self.write(ComStmtPrepare { statement: query });
|
||||
|
||||
Box::pin(async move {
|
||||
let statement_id = self.exec_prepare().await?;
|
||||
let statement_id = self.exec_prepare().await?;
|
||||
|
||||
let affected = self.execute(statement_id, params).await?;
|
||||
let affected = self.execute(statement_id, params).await?;
|
||||
|
||||
Ok(affected)
|
||||
})
|
||||
Ok(affected)
|
||||
}
|
||||
|
||||
fn fetch<'c>(
|
||||
&'c mut self,
|
||||
fn fetch(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: MariaDbQueryParameters,
|
||||
) -> BoxStream<'c, Result<MariaDbRow>> {
|
||||
) -> BoxStream<'_, Result<MariaDbRow>> {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
fn fetch_optional<'c>(
|
||||
&'c mut self,
|
||||
async fn fetch_optional(
|
||||
&mut self,
|
||||
query: &str,
|
||||
params: MariaDbQueryParameters,
|
||||
) -> BoxFuture<'c, Result<Option<<Self::Backend as Backend>::Row>>> {
|
||||
) -> crate::Result<Option<<Self::Backend as Backend>::Row>> {
|
||||
unimplemented!();
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user