mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-30 13:20:59 +00:00
Fix prepared statement execution
This commit is contained in:
parent
433ec628da
commit
c019f91fc6
@ -2,7 +2,7 @@ use super::Connection;
|
||||
use crate::{
|
||||
mariadb::{
|
||||
Capabilities, ComStmtExec, DeContext, Decode, EofPacket, ErrPacket,
|
||||
HandshakeResponsePacket, InitialHandshakePacket, OkPacket, StmtExecFlag,
|
||||
HandshakeResponsePacket, InitialHandshakePacket, OkPacket, StmtExecFlag, ProtocolType
|
||||
},
|
||||
ConnectOptions,
|
||||
};
|
||||
@ -109,10 +109,8 @@ mod test {
|
||||
})
|
||||
.await?;
|
||||
|
||||
println!("selecting db");
|
||||
conn.select_db("test").await?;
|
||||
|
||||
println!("querying");
|
||||
conn.query("SELECT * FROM users").await?;
|
||||
|
||||
Ok(())
|
||||
@ -154,17 +152,10 @@ mod test {
|
||||
.prepare("SELECT id FROM users WHERE username=?")
|
||||
.await?;
|
||||
|
||||
println!("{:?}", prepared);
|
||||
|
||||
// if let Some(param_defs) = &mut prepared.param_defs {
|
||||
// param_defs[0].field_type = FieldType::MysqlTypeBlob;
|
||||
// }
|
||||
|
||||
let exec = ComStmtExec {
|
||||
stmt_id: prepared.ok.stmt_id,
|
||||
flags: StmtExecFlag::NoCursor,
|
||||
// params: None,
|
||||
params: Some(vec![Some(Bytes::from_static(b"daniel"))]),
|
||||
params: Some(vec![Some(Bytes::from_static(b"josh"))]),
|
||||
param_defs: prepared.param_defs,
|
||||
};
|
||||
|
||||
@ -172,26 +163,21 @@ mod test {
|
||||
|
||||
let mut ctx = DeContext::with_stream(&mut conn.context, &mut conn.stream);
|
||||
ctx.next_packet().await?;
|
||||
ctx.columns = Some(prepared.ok.columns as u64);
|
||||
ctx.column_defs = prepared.res_columns;
|
||||
|
||||
match ctx.decoder.peek_tag() {
|
||||
0xFF => println!("{:?}", ErrPacket::decode(&mut ctx)?),
|
||||
0x00 => println!("{:?}", OkPacket::decode(&mut ctx)?),
|
||||
_ => println!("{:?}", ResultSet::deserialize(ctx).await?),
|
||||
0xFF => {
|
||||
ErrPacket::decode(&mut ctx)?;
|
||||
},
|
||||
0x00 => {
|
||||
OkPacket::decode(&mut ctx)?;
|
||||
},
|
||||
_ => {
|
||||
ResultSet::deserialize(ctx, ProtocolType::Binary).await?;
|
||||
},
|
||||
}
|
||||
|
||||
// let fetch = ComStmtFetch {
|
||||
// stmt_id: -1,
|
||||
// rows: 10,
|
||||
// };
|
||||
//
|
||||
// conn.send(fetch).await?;
|
||||
//
|
||||
// let buf = conn.stream.next_packet().await?;
|
||||
//
|
||||
// println!("{:?}", buf);
|
||||
|
||||
// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ use crate::{
|
||||
mariadb::{
|
||||
protocol::encode, Capabilities, ComInitDb, ComPing, ComQuery, ComQuit, ComStmtPrepare,
|
||||
ComStmtPrepareResp, DeContext, Decode, Decoder, Encode, ErrPacket, OkPacket, PacketHeader,
|
||||
ResultSet, ServerStatusFlag,
|
||||
ResultSet, ServerStatusFlag, ProtocolType
|
||||
},
|
||||
ConnectOptions,
|
||||
};
|
||||
@ -130,7 +130,7 @@ impl Connection {
|
||||
Ok(None)
|
||||
}
|
||||
0xFB => unimplemented!(),
|
||||
_ => Ok(Some(ResultSet::deserialize(ctx).await?)),
|
||||
_ => Ok(Some(ResultSet::deserialize(ctx, ProtocolType::Text).await?)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -203,7 +203,6 @@ impl Framed {
|
||||
// After this operation we know that packet_headers.last() *SHOULD* always return valid data,
|
||||
// so the the use of packet_headers.last().unwrap() is allowed.
|
||||
// TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions.
|
||||
println!("{:?}", self.buf);
|
||||
if let Some(packet_header) = packet_headers.last() {
|
||||
if packet_header.length as usize == encode::U24_MAX {
|
||||
packet_headers.push(PacketHeader::try_from(&self.buf[self.index..])?);
|
||||
|
||||
@ -9,6 +9,6 @@ pub use protocol::{
|
||||
ComSetOption, ComShutdown, ComSleep, ComStatistics, ComStmtClose, ComStmtExec, ComStmtFetch,
|
||||
ComStmtPrepare, ComStmtPrepareOk, ComStmtPrepareResp, DeContext, Decode, Decoder, Encode,
|
||||
EofPacket, ErrPacket, ErrorCode, FieldDetailFlag, FieldType, HandshakeResponsePacket,
|
||||
InitialHandshakePacket, OkPacket, PacketHeader, ResultRow, ResultSet, SSLRequestPacket,
|
||||
ServerStatusFlag, SessionChangeType, SetOptionOptions, ShutdownOptions, StmtExecFlag,
|
||||
InitialHandshakePacket, OkPacket, PacketHeader, ResultRowText, ResultRowBinary, ResultRow, ResultSet, SSLRequestPacket,
|
||||
ServerStatusFlag, SessionChangeType, SetOptionOptions, ShutdownOptions, StmtExecFlag, ProtocolType
|
||||
};
|
||||
|
||||
@ -316,9 +316,7 @@ impl BufMut for Vec<u8> {
|
||||
// Same as the string counterpart copied to maintain consistency with the spec.
|
||||
#[inline]
|
||||
fn put_byte_fix(&mut self, bytes: &Bytes, size: usize) {
|
||||
if size != bytes.len() {
|
||||
panic!("Sizes do not match");
|
||||
}
|
||||
assert_eq!(size, bytes.len());
|
||||
|
||||
self.extend_from_slice(bytes);
|
||||
}
|
||||
|
||||
@ -19,8 +19,8 @@ pub use packets::{
|
||||
ComProcessKill, ComQuery, ComQuit, ComResetConnection, ComSetOption, ComShutdown, ComSleep,
|
||||
ComStatistics, ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk,
|
||||
ComStmtPrepareResp, ComStmtReset, EofPacket, ErrPacket, HandshakeResponsePacket,
|
||||
InitialHandshakePacket, OkPacket, PacketHeader, ResultRow, ResultSet, SSLRequestPacket,
|
||||
SetOptionOptions, ShutdownOptions,
|
||||
InitialHandshakePacket, OkPacket, PacketHeader, ResultRowText, ResultRowBinary, ResultSet, SSLRequestPacket,
|
||||
SetOptionOptions, ShutdownOptions, ResultRow
|
||||
};
|
||||
|
||||
pub use decode::{DeContext, Decode, Decoder};
|
||||
@ -30,5 +30,5 @@ pub use encode::{BufMut, Encode};
|
||||
pub use error_codes::ErrorCode;
|
||||
|
||||
pub use types::{
|
||||
Capabilities, FieldDetailFlag, FieldType, ServerStatusFlag, SessionChangeType, StmtExecFlag,
|
||||
ProtocolType, Capabilities, FieldDetailFlag, FieldType, ServerStatusFlag, SessionChangeType, StmtExecFlag,
|
||||
};
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use crate::mariadb::{
|
||||
BufMut, ColumnDefPacket, ConnContext, Connection, Encode, FieldDetailFlag, StmtExecFlag,
|
||||
BufMut, ColumnDefPacket, ConnContext, Connection, Encode, FieldDetailFlag, StmtExecFlag, FieldType
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use failure::Error;
|
||||
@ -26,12 +26,12 @@ impl Encode for ComStmtExec {
|
||||
(Some(params), Some(param_defs)) if params.len() > 0 => {
|
||||
let null_bitmap_size = (params.len() + 7) / 8;
|
||||
let mut shift_amount = 0u8;
|
||||
let mut bitmap = vec![1u8];
|
||||
let send_type = 0u8;
|
||||
let mut bitmap = vec![0u8];
|
||||
let send_type = 1u8;
|
||||
|
||||
// Generate NULL-bitmap from params
|
||||
for param in params {
|
||||
if param.is_some() {
|
||||
if param.is_none() {
|
||||
let last_byte = bitmap.pop().unwrap();
|
||||
bitmap.push(last_byte & (1 << shift_amount));
|
||||
}
|
||||
@ -48,7 +48,7 @@ impl Encode for ComStmtExec {
|
||||
|
||||
if send_type > 0 {
|
||||
for param in param_defs {
|
||||
// buf.put_int_u8(param.field_type as u8);
|
||||
buf.put_int_u8(param.field_type as u8);
|
||||
buf.put_int_u8(0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -14,6 +14,7 @@ pub use com_stmt_prepare::ComStmtPrepare;
|
||||
pub use com_stmt_prepare_ok::ComStmtPrepareOk;
|
||||
pub use com_stmt_prepare_resp::ComStmtPrepareResp;
|
||||
pub use com_stmt_reset::ComStmtReset;
|
||||
pub use result_row::ResultRow;
|
||||
|
||||
pub enum BinaryProtocol {
|
||||
ComStmtPrepare = 0x16,
|
||||
|
||||
@ -3,6 +3,8 @@ use bytes::Bytes;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ResultRow {
|
||||
pub length: u32,
|
||||
pub seq_no: u8,
|
||||
pub columns: Vec<Option<Bytes>>,
|
||||
}
|
||||
|
||||
@ -24,11 +26,11 @@ impl crate::mariadb::Decode for ResultRow {
|
||||
))
|
||||
}?;
|
||||
|
||||
let row = match (&ctx.columns, &ctx.column_defs) {
|
||||
let columns = match (&ctx.columns, &ctx.column_defs) {
|
||||
(Some(columns), Some(column_defs)) => {
|
||||
(0..*columns as usize)
|
||||
.map(|index| {
|
||||
if (1 << (index % 8)) & bitmap[index / 8] as usize == 0 {
|
||||
if (1 << (index % 8)) & bitmap[index / 8] as usize == 1 {
|
||||
None
|
||||
} else {
|
||||
match column_defs[index].field_type {
|
||||
@ -105,6 +107,10 @@ impl crate::mariadb::Decode for ResultRow {
|
||||
_ => Vec::new(),
|
||||
};
|
||||
|
||||
Ok(ResultRow::default())
|
||||
Ok(ResultRow {
|
||||
length,
|
||||
seq_no,
|
||||
columns
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -29,9 +29,10 @@ pub use ssl_request::SSLRequestPacket;
|
||||
pub use text::{
|
||||
ComDebug, ComInitDb, ComPing, ComProcessKill, ComQuery, ComQuit, ComResetConnection,
|
||||
ComSetOption, ComShutdown, ComSleep, ComStatistics, SetOptionOptions, ShutdownOptions,
|
||||
ResultRow as ResultRowText
|
||||
};
|
||||
|
||||
pub use binary::{
|
||||
ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, ComStmtPrepareResp,
|
||||
ComStmtReset,
|
||||
ComStmtReset, ResultRow as ResultRowBinary
|
||||
};
|
||||
|
||||
@ -1,67 +1,29 @@
|
||||
use crate::mariadb::{DeContext, Decode, Decoder, ErrorCode, ServerStatusFlag};
|
||||
use bytes::Bytes;
|
||||
use failure::Error;
|
||||
use std::convert::TryFrom;
|
||||
use crate::mariadb::{ResultRowText, ResultRowBinary};
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
#[derive(Debug)]
|
||||
pub struct ResultRow {
|
||||
pub length: u32,
|
||||
pub seq_no: u8,
|
||||
pub row: Vec<Bytes>,
|
||||
pub columns: Vec<Option<bytes::Bytes>>
|
||||
}
|
||||
|
||||
impl Decode for ResultRow {
|
||||
fn decode(ctx: &mut DeContext) -> Result<Self, Error> {
|
||||
let decoder = &mut ctx.decoder;
|
||||
|
||||
let length = decoder.decode_length()?;
|
||||
let seq_no = decoder.decode_int_u8();
|
||||
|
||||
let row = if let Some(columns) = ctx.columns {
|
||||
(0..columns)
|
||||
.map(|_| decoder.decode_string_lenenc())
|
||||
.collect::<Vec<Bytes>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(ResultRow {
|
||||
length,
|
||||
seq_no,
|
||||
row,
|
||||
})
|
||||
impl From<ResultRowText> for ResultRow {
|
||||
fn from(row: ResultRowText) -> Self {
|
||||
ResultRow {
|
||||
length: row.length,
|
||||
seq_no: row.seq_no,
|
||||
columns: row.columns,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::{
|
||||
__bytes_builder,
|
||||
mariadb::{ConnContext, Decoder},
|
||||
ConnectOptions,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
|
||||
#[test]
|
||||
fn it_decodes_result_row_packet() -> Result<(), Error> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// int<3> length
|
||||
1u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
1u8,
|
||||
// string<lenenc> column data
|
||||
1u8, b"s"
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
ctx.columns = Some(1);
|
||||
|
||||
let _message = ResultRow::decode(&mut ctx)?;
|
||||
|
||||
Ok(())
|
||||
impl From<ResultRowBinary> for ResultRow {
|
||||
fn from(row: ResultRowBinary) -> Self {
|
||||
ResultRow {
|
||||
length: row.length,
|
||||
seq_no: row.seq_no,
|
||||
columns: row.columns,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,7 +3,7 @@ use failure::Error;
|
||||
|
||||
use crate::mariadb::{
|
||||
Capabilities, ColumnDefPacket, ColumnPacket, ConnContext, DeContext, Decode, Decoder,
|
||||
EofPacket, ErrPacket, Framed, OkPacket, ResultRow,
|
||||
EofPacket, ErrPacket, Framed, OkPacket, ResultRowText, ResultRowBinary, ProtocolType, ResultRow
|
||||
};
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@ -14,11 +14,9 @@ pub struct ResultSet {
|
||||
}
|
||||
|
||||
impl ResultSet {
|
||||
pub async fn deserialize<'a>(mut ctx: DeContext<'a>) -> Result<Self, Error> {
|
||||
pub async fn deserialize(mut ctx: DeContext<'_>, protocol: ProtocolType) -> Result<Self, Error> {
|
||||
let column_packet = ColumnPacket::decode(&mut ctx)?;
|
||||
|
||||
println!("{:?}", column_packet);
|
||||
|
||||
let columns = if let Some(columns) = column_packet.columns {
|
||||
let mut column_defs = Vec::new();
|
||||
for _ in 0..columns {
|
||||
@ -30,8 +28,6 @@ impl ResultSet {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
println!("{:?}", columns);
|
||||
|
||||
ctx.next_packet().await?;
|
||||
|
||||
let eof_packet = if !ctx
|
||||
@ -41,7 +37,6 @@ impl ResultSet {
|
||||
{
|
||||
// If we get an eof packet we must update ctx to hold a new buffer of the next packet.
|
||||
let eof_packet = Some(EofPacket::decode(&mut ctx)?);
|
||||
println!("{:?}", eof_packet);
|
||||
ctx.next_packet().await?;
|
||||
eof_packet
|
||||
} else {
|
||||
@ -59,19 +54,35 @@ impl ResultSet {
|
||||
};
|
||||
|
||||
let tag = ctx.decoder.peek_tag();
|
||||
if tag == &0xFE && packet_header.length <= 0xFFFFFF || packet_header.length == 0 {
|
||||
if tag == &0xFE && packet_header.length <= 0xFFFFFF {
|
||||
break;
|
||||
} else {
|
||||
let index = ctx.decoder.index;
|
||||
match ResultRow::decode(&mut ctx) {
|
||||
Ok(v) => {
|
||||
rows.push(v);
|
||||
ctx.next_packet().await?;
|
||||
}
|
||||
Err(_) => {
|
||||
ctx.decoder.index = index;
|
||||
break;
|
||||
}
|
||||
match protocol {
|
||||
ProtocolType::Text => {
|
||||
match ResultRowText::decode(&mut ctx) {
|
||||
Ok(row) => {
|
||||
rows.push(ResultRow::from(row));
|
||||
ctx.next_packet().await?;
|
||||
}
|
||||
Err(_) => {
|
||||
ctx.decoder.index = index;
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
ProtocolType::Binary => {
|
||||
match ResultRowBinary::decode(&mut ctx) {
|
||||
Ok(row) => {
|
||||
rows.push(ResultRow::from(row));
|
||||
ctx.next_packet().await?;
|
||||
}
|
||||
Err(_) => {
|
||||
ctx.decoder.index = index;
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -324,7 +335,7 @@ mod test {
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
ResultSet::deserialize(ctx).await?;
|
||||
ResultSet::deserialize(ctx, ProtocolType::Text).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -9,6 +9,7 @@ pub mod com_set_option;
|
||||
pub mod com_shutdown;
|
||||
pub mod com_sleep;
|
||||
pub mod com_statistics;
|
||||
pub mod result_row;
|
||||
|
||||
pub use com_debug::ComDebug;
|
||||
pub use com_init_db::ComInitDb;
|
||||
@ -21,6 +22,7 @@ pub use com_set_option::{ComSetOption, SetOptionOptions};
|
||||
pub use com_shutdown::{ComShutdown, ShutdownOptions};
|
||||
pub use com_sleep::ComSleep;
|
||||
pub use com_statistics::ComStatistics;
|
||||
pub use result_row::ResultRow;
|
||||
|
||||
// This is an enum of text protocol packet tags.
|
||||
// Tags are the 5th byte of the packet (1st byte of packet body)
|
||||
|
||||
67
src/mariadb/protocol/packets/text/result_row.rs
Normal file
67
src/mariadb/protocol/packets/text/result_row.rs
Normal file
@ -0,0 +1,67 @@
|
||||
use crate::mariadb::{DeContext, Decode, Decoder, ErrorCode, ServerStatusFlag};
|
||||
use bytes::Bytes;
|
||||
use failure::Error;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct ResultRow {
|
||||
pub length: u32,
|
||||
pub seq_no: u8,
|
||||
pub columns: Vec<Option<Bytes>>,
|
||||
}
|
||||
|
||||
impl Decode for ResultRow {
|
||||
fn decode(ctx: &mut DeContext) -> Result<Self, Error> {
|
||||
let decoder = &mut ctx.decoder;
|
||||
|
||||
let length = decoder.decode_length()?;
|
||||
let seq_no = decoder.decode_int_u8();
|
||||
|
||||
let columns = if let Some(columns) = ctx.columns {
|
||||
(0..columns)
|
||||
.map(|_| Some(decoder.decode_string_lenenc()))
|
||||
.collect::<Vec<Option<Bytes>>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(ResultRow {
|
||||
length,
|
||||
seq_no,
|
||||
columns,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::{
|
||||
__bytes_builder,
|
||||
mariadb::{ConnContext, Decoder},
|
||||
ConnectOptions,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
|
||||
#[test]
|
||||
fn it_decodes_result_row_packet() -> Result<(), Error> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// int<3> length
|
||||
1u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
1u8,
|
||||
// string<lenenc> column data
|
||||
1u8, b"s"
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
ctx.columns = Some(1);
|
||||
|
||||
let _message = ResultRow::decode(&mut ctx)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,10 @@
|
||||
use std::convert::TryFrom;
|
||||
|
||||
pub enum ProtocolType {
|
||||
Text,
|
||||
Binary
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
pub struct Capabilities: u128 {
|
||||
const CLIENT_MYSQL = 1;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user