diff --git a/sqlx-core/src/mariadb/backend.rs b/sqlx-core/src/mariadb/backend.rs index 5fdd6de9..9c2a3739 100644 --- a/sqlx-core/src/mariadb/backend.rs +++ b/sqlx-core/src/mariadb/backend.rs @@ -1,7 +1,7 @@ use super::{MariaDb, MariaDbQueryParameters, MariaDbRow}; use crate::backend::Backend; use crate::describe::{Describe, ResultField}; -use crate::mariadb::protocol::ColumnDefinitionPacket; +use crate::mariadb::protocol::{StmtExecFlag, ComStmtExecute, ResultRow, Capabilities, OkPacket, EofPacket, ErrPacket, ColumnDefinitionPacket, ColumnCountPacket}; use async_trait::async_trait; use futures_core::stream::BoxStream; @@ -31,17 +31,127 @@ impl Backend for MariaDb { self.start_sequence(); let prepare_ok = self.send_prepare(query).await?; - let affected = self.execute(prepare_ok.statement_id, params).await?; + // SEND ================ + self.start_sequence(); + self.execute(prepare_ok.statement_id, params).await?; + // ===================== - Ok(affected) + // Row Counter, used later + let mut rows = 0u64; + let capabilities = self.capabilities; + let has_eof = capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF); + + let packet = self.receive().await?; + if packet[0] == 0x00 { + let _ok = OkPacket::decode(packet, capabilities)?; + } else if packet[0] == 0xFF { + return ErrPacket::decode(packet)?.expect_error(); + } else { + // A Resultset starts with a [ColumnCountPacket] which is a single field that encodes + // how many columns we can expect when fetching rows from this statement + let column_count: u64 = ColumnCountPacket::decode(packet)?.columns; + + // Next we have a [ColumnDefinitionPacket] which verbosely explains each minute + // detail about the column in question including table, aliasing, and type + // TODO: This information was *already* returned by PREPARE .., is there a way to suppress generation + let mut columns = vec![]; + for _ in 0..column_count { + columns.push(ColumnDefinitionPacket::decode(self.receive().await?)?); + } + + // When (legacy) EOFs are enabled, the fixed number column definitions are further terminated by + // an EOF packet + if !has_eof { + let _eof = EofPacket::decode(self.receive().await?)?; + } + + // For each row in the result set we will receive a ResultRow packet. + // We may receive an [OkPacket], [EofPacket], or [ErrPacket] (depending on if EOFs are enabled) to finalize the iteration. + loop { + let packet = self.receive().await?; + if packet[0] == 0xFE && packet.len() < 0xFF_FF_FF { + // NOTE: It's possible for a ResultRow to start with 0xFE (which would normally signify end-of-rows) + // but it's not possible for an Ok/Eof to be larger than 0xFF_FF_FF. + if !has_eof { + let _eof = EofPacket::decode(packet)?; + } else { + let _ok = OkPacket::decode(packet, capabilities)?; + } + + break; + } else if packet[0] == 0xFF { + let err = ErrPacket::decode(packet)?; + panic!("received db err = {:?}", err); + } else { + // Ignore result rows; exec only returns number of affected rows; + let _ = ResultRow::decode(packet, &columns)?; + + // For every row we decode we increment counter + rows = rows + 1; + } + } + } + + Ok(rows) } fn fetch( &mut self, _query: &str, _params: MariaDbQueryParameters, - ) -> BoxStream<'_, crate::Result> { - unimplemented!(); + ) -> BoxStream<'_, crate::Result> { + Box::pin(async_stream::try_stream! { + // Write prepare statement to buffer + self.start_sequence(); + let prepare_ok = self.send_prepare(query).await?; + + self.start_sequence(); + self.execute(prepare_ok.statement_id, params).await?; + + let capabilities = self.capabilities; + let has_eof = capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF); + + let packet = self.receive().await?; + if packet[0] == 0x00 { + let _ok = OkPacket::decode(packet, capabilities)?; + } else if packet[0] == 0xFF { + return ErrPacket::decode(packet)?.expect_error(); + } + // A Resultset starts with a [ColumnCountPacket] which is a single field that encodes + // how many columns we can expect when fetching rows from this statement + // let column_count: u64 = ColumnCountPacket::decode(packet)?.columns; + + // Next we have a [ColumnDefinitionPacket] which verbosely explains each minute + // detail about the column in question including table, aliasing, and type + // TODO: This information was *already* returned by PREPARE .., is there a way to suppress generation + let mut columns = vec![]; + for _ in 0..column_count { + columns.push(ColumnDefinitionPacket::decode(self.receive().await?)?); + } + + // When (legacy) EOFs are enabled, the fixed number column definitions are further terminated by + // an EOF packet + // if !has_eof { + // let _eof = EofPacket::decode(self.receive().await?)?; + // } + + // loop { + // let packet = self.receive().await?; + // if packet[0] == 0xFE && packet.len() < 0xFF_FF_FF { + // if !has_eof { + // let _eof = EofPacket::decode(packet)?; + // } else { + // let _ok = OkPacket::decode(packet, capabilities)?; + // } + // break; + // } else if packet[0] == 0xFF { + // let err = ErrPacket::decode(packet)?; + // panic!("received db err = {:?}", err); + // } else { + // yield ResultRow::decode(packet, &columns); + // } + // } + }) } async fn fetch_optional( diff --git a/sqlx-core/src/mariadb/connection.rs b/sqlx-core/src/mariadb/connection.rs index 80b23afc..76d32858 100644 --- a/sqlx-core/src/mariadb/connection.rs +++ b/sqlx-core/src/mariadb/connection.rs @@ -192,11 +192,55 @@ impl MariaDb { ComStmtPrepareOk::decode(packet).map_err(Into::into) } + pub(super) async fn step(&mut self, columns: &Vec, packet: &[u8]) -> Result> { + // For each row in the result set we will receive a ResultRow packet. + // We may receive an [OkPacket], [EofPacket], or [ErrPacket] (depending on if EOFs are enabled) to finalize the iteration. + if packet[0] == 0xFE && packet.len() < 0xFF_FF_FF { + // NOTE: It's possible for a ResultRow to start with 0xFE (which would normally signify end-of-rows) + // but it's not possible for an Ok/Eof to be larger than 0xFF_FF_FF. + if !self.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { + let _eof = EofPacket::decode(packet)?; + Ok(None) + } else { + let _ok = OkPacket::decode(packet, self.capabilities)?; + Ok(None) + } + } else if packet[0] == 0xFF { + let _ = ErrPacket::decode(packet)?; + // TODO: Should be error + Ok(None) + } else { + Ok(Some(ResultRow::decode(packet, columns)?)) + } + } + + pub(super) async fn column_definitions(&mut self, packet: &[u8]) -> Result> { + // A Resultset starts with a [ColumnCountPacket] which is a single field that encodes + // how many columns we can expect when fetching rows from this statement + let column_count: u64 = ColumnCountPacket::decode(packet)?.columns; + + // Next we have a [ColumnDefinitionPacket] which verbosely explains each minute + // detail about the column in question including table, aliasing, and type + // TODO: This information was *already* returned by PREPARE .., is there a way to suppress generation + let mut columns = vec![]; + for _ in 0..column_count { + columns.push(ColumnDefinitionPacket::decode(self.receive().await?)?); + } + + // When (legacy) EOFs are enabled, the fixed number column definitions are further terminated by + // an EOF packet + if !self.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { + let _eof = EofPacket::decode(self.receive().await?)?; + } + + Ok(columns) + } + pub(super) async fn execute( &mut self, statement_id: u32, _params: MariaDbQueryParameters, - ) -> Result { + ) -> Result<()> { // TODO: EXECUTE(READ_ONLY) => FETCH instead of EXECUTE(NO) // SEND ================ @@ -211,62 +255,6 @@ impl MariaDb { self.stream.flush().await?; // ===================== - // Row Counter, used later - let mut rows = 0u64; - let capabilities = self.capabilities; - let has_eof = capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF); - - let packet = self.receive().await?; - if packet[0] == 0x00 { - let _ok = OkPacket::decode(packet, capabilities)?; - } else if packet[0] == 0xFF { - return ErrPacket::decode(packet)?.expect_error(); - } else { - // A Resultset starts with a [ColumnCountPacket] which is a single field that encodes - // how many columns we can expect when fetching rows from this statement - let column_count: u64 = ColumnCountPacket::decode(packet)?.columns; - - // Next we have a [ColumnDefinitionPacket] which verbosely explains each minute - // detail about the column in question including table, aliasing, and type - // TODO: This information was *already* returned by PREPARE .., is there a way to suppress generation - let mut columns = vec![]; - for _ in 0..column_count { - columns.push(ColumnDefinitionPacket::decode(self.receive().await?)?); - } - - // When (legacy) EOFs are enabled, the fixed number column definitions are further terminated by - // an EOF packet - if !has_eof { - let _eof = EofPacket::decode(self.receive().await?)?; - } - - // For each row in the result set we will receive a ResultRow packet. - // We may receive an [OkPacket], [EofPacket], or [ErrPacket] (depending on if EOFs are enabled) to finalize the iteration. - loop { - let packet = self.receive().await?; - if packet[0] == 0xFE && packet.len() < 0xFF_FF_FF { - // NOTE: It's possible for a ResultRow to start with 0xFE (which would normally signify end-of-rows) - // but it's not possible for an Ok/Eof to be larger than 0xFF_FF_FF. - if !has_eof { - let _eof = EofPacket::decode(packet)?; - } else { - let _ok = OkPacket::decode(packet, capabilities)?; - } - - break; - } else if packet[0] == 0xFF { - let err = ErrPacket::decode(packet)?; - panic!("received db err = {:?}", err); - } else { - // Ignore result rows; exec only returns number of affected rows; - let _ = ResultRow::decode(packet, &columns)?; - - // For every row we decode we increment counter - rows = rows + 1; - } - } - } - - Ok(rows) + Ok(()) } } diff --git a/sqlx-core/src/mariadb/protocol/response/eof.rs b/sqlx-core/src/mariadb/protocol/response/eof.rs index 21d4ed75..6a46eafb 100644 --- a/sqlx-core/src/mariadb/protocol/response/eof.rs +++ b/sqlx-core/src/mariadb/protocol/response/eof.rs @@ -18,7 +18,7 @@ impl EofPacket { pub(crate) fn decode(mut buf: &[u8]) -> crate::Result { let header = buf.get_u8()?; if header != 0xFE { - return Err(protocol_err!("expected 0xFE; received {}", header)); + return Err(protocol_err!("expected 0xFE; received {}", header))?; } let warning_count = buf.get_u16::()?; diff --git a/sqlx-core/src/mariadb/protocol/response/ok.rs b/sqlx-core/src/mariadb/protocol/response/ok.rs index 62058854..e71b49b5 100644 --- a/sqlx-core/src/mariadb/protocol/response/ok.rs +++ b/sqlx-core/src/mariadb/protocol/response/ok.rs @@ -27,7 +27,7 @@ impl OkPacket { return Err(protocol_err!( "expected 0x00 or 0xFE; received 0x{:X}", header - )); + ))?; } let affected_rows = buf.get_uint_lenenc::()?.unwrap_or(0);