From 434bfaa76a446a1f99fef01cd6059ba2792146bd Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Sat, 6 Jun 2020 12:09:15 -0700 Subject: [PATCH] feat(mssql): handle stream flushing --- sqlx-core/src/mssql/connection/establish.rs | 5 +- sqlx-core/src/mssql/connection/executor.rs | 48 ++++++++++++++++--- sqlx-core/src/mssql/connection/mod.rs | 3 ++ sqlx-core/src/mssql/connection/stream.rs | 11 +++-- sqlx-core/src/mssql/protocol/done.rs | 12 +---- sqlx-core/src/mssql/protocol/message.rs | 10 ++++ sqlx-core/src/mssql/protocol/mod.rs | 1 + sqlx-core/src/mssql/protocol/return_status.rs | 17 +++++++ tests/mssql/mssql.rs | 33 +++++++++++++ 9 files changed, 119 insertions(+), 21 deletions(-) create mode 100644 sqlx-core/src/mssql/protocol/return_status.rs diff --git a/sqlx-core/src/mssql/connection/establish.rs b/sqlx-core/src/mssql/connection/establish.rs index 8612261a..6c2717b5 100644 --- a/sqlx-core/src/mssql/connection/establish.rs +++ b/sqlx-core/src/mssql/connection/establish.rs @@ -76,6 +76,9 @@ impl MsSqlConnection { } } - Ok(Self { stream }) + Ok(Self { + stream, + pending_done_count: 0, + }) } } diff --git a/sqlx-core/src/mssql/connection/executor.rs b/sqlx-core/src/mssql/connection/executor.rs index 4a353217..75b3df43 100644 --- a/sqlx-core/src/mssql/connection/executor.rs +++ b/sqlx-core/src/mssql/connection/executor.rs @@ -7,6 +7,7 @@ use futures_util::TryStreamExt; use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; +use crate::mssql::protocol::done::Done; use crate::mssql::protocol::message::Message; use crate::mssql::protocol::packet::PacketType; use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest}; @@ -14,7 +15,31 @@ use crate::mssql::protocol::sql_batch::SqlBatch; use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow}; impl MsSqlConnection { + async fn wait_until_ready(&mut self) -> Result<(), Error> { + if !self.stream.wbuf.is_empty() { + self.stream.flush().await?; + } + + while self.pending_done_count > 0 { + if let Message::DoneProc(done) | Message::Done(done) = + self.stream.recv_message().await? + { + // finished RPC procedure *OR* SQL batch + self.handle_done(done); + } + } + + Ok(()) + } + + fn handle_done(&mut self, _: Done) { + self.pending_done_count -= 1; + } + async fn run(&mut self, query: &str, arguments: Option) -> Result<(), Error> { + self.wait_until_ready().await?; + self.pending_done_count += 1; + if let Some(mut arguments) = arguments { let proc = Either::Right(Procedure::ExecuteSql); let mut proc_args = MsSqlArguments::default(); @@ -22,12 +47,14 @@ impl MsSqlConnection { // SQL proc_args.add_unnamed(query); - // Declarations - // NAME TYPE, NAME TYPE, ... - proc_args.add_unnamed(&*arguments.declarations); + if !arguments.data.is_empty() { + // Declarations + // NAME TYPE, NAME TYPE, ... + proc_args.add_unnamed(&*arguments.declarations); - // Add the list of SQL parameters _after_ our RPC parameters - proc_args.append(&mut arguments); + // Add the list of SQL parameters _after_ our RPC parameters + proc_args.append(&mut arguments); + } self.stream.write_packet( PacketType::Rpc, @@ -72,10 +99,19 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection { yield v; } - Message::Done(done) => { + Message::DoneProc(done) => { + self.handle_done(done); + break; + } + + Message::DoneInProc(done) => { + // finished SQL query *within* procedure let v = Either::Left(done.affected_rows); yield v; + } + Message::Done(done) => { + self.handle_done(done); break; } diff --git a/sqlx-core/src/mssql/connection/mod.rs b/sqlx-core/src/mssql/connection/mod.rs index d4daf777..59387623 100644 --- a/sqlx-core/src/mssql/connection/mod.rs +++ b/sqlx-core/src/mssql/connection/mod.rs @@ -16,6 +16,9 @@ mod stream; pub struct MsSqlConnection { stream: MsSqlStream, + + // number of Done* messages that we are currently expecting + pub(crate) pending_done_count: usize, } impl Debug for MsSqlConnection { diff --git a/sqlx-core/src/mssql/connection/stream.rs b/sqlx-core/src/mssql/connection/stream.rs index af9683af..71d6c0e3 100644 --- a/sqlx-core/src/mssql/connection/stream.rs +++ b/sqlx-core/src/mssql/connection/stream.rs @@ -13,6 +13,7 @@ use crate::mssql::protocol::info::Info; use crate::mssql::protocol::login_ack::LoginAck; use crate::mssql::protocol::message::{Message, MessageType}; use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status}; +use crate::mssql::protocol::return_status::ReturnStatus; use crate::mssql::protocol::row::Row; use crate::mssql::{MsSqlConnectOptions, MsSqlDatabaseError}; use crate::net::MaybeTlsStream; @@ -106,13 +107,15 @@ impl MsSqlStream { }; let ty = MessageType::get(buf)?; - - return Ok(match ty { + let message = match ty { MessageType::EnvChange => Message::EnvChange(EnvChange::get(buf)?), MessageType::Info => Message::Info(Info::get(buf)?), MessageType::Row => Message::Row(Row::get(buf, &self.columns)?), MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?), + MessageType::ReturnStatus => Message::ReturnStatus(ReturnStatus::get(buf)?), MessageType::Done => Message::Done(Done::get(buf)?), + MessageType::DoneInProc => Message::DoneInProc(Done::get(buf)?), + MessageType::DoneProc => Message::DoneProc(Done::get(buf)?), MessageType::Error => { let err = ProtocolError::get(buf)?; @@ -125,7 +128,9 @@ impl MsSqlStream { ColMetaData::get(buf, &mut self.columns)?; continue; } - }); + }; + + return Ok(message); } // no packet from the server to iterate (or its empty); fill our buffer diff --git a/sqlx-core/src/mssql/protocol/done.rs b/sqlx-core/src/mssql/protocol/done.rs index f5731f0a..5543ab73 100644 --- a/sqlx-core/src/mssql/protocol/done.rs +++ b/sqlx-core/src/mssql/protocol/done.rs @@ -3,21 +3,11 @@ use bytes::{Buf, Bytes}; use crate::error::Error; -// Token Stream Function: -// Indicates the completion status of a SQL statementwithin a stored procedure. - -// Token Stream Definition: -// DONEINPROC = -// TokenType -// Status -// CurCmd -// DoneRowCount - #[derive(Debug)] pub(crate) struct Done { status: Status, - // The token of the current SQL statement. The token value is provided andcontrolled by the + // The token of the current SQL statement. The token value is provided and controlled by the // application layer, which utilizes TDS. The TDS layer does not evaluate the value. cursor_command: u16, diff --git a/sqlx-core/src/mssql/protocol/message.rs b/sqlx-core/src/mssql/protocol/message.rs index e4a98cc1..bb9167cd 100644 --- a/sqlx-core/src/mssql/protocol/message.rs +++ b/sqlx-core/src/mssql/protocol/message.rs @@ -6,6 +6,7 @@ use crate::mssql::protocol::env_change::EnvChange; use crate::mssql::protocol::error::Error; use crate::mssql::protocol::info::Info; use crate::mssql::protocol::login_ack::LoginAck; +use crate::mssql::protocol::return_status::ReturnStatus; use crate::mssql::protocol::row::Row; #[derive(Debug)] @@ -14,7 +15,10 @@ pub(crate) enum Message { LoginAck(LoginAck), EnvChange(EnvChange), Done(Done), + DoneInProc(Done), + DoneProc(Done), Row(Row), + ReturnStatus(ReturnStatus), } #[derive(Debug)] @@ -23,9 +27,12 @@ pub(crate) enum MessageType { LoginAck, EnvChange, Done, + DoneProc, + DoneInProc, Row, Error, ColMetaData, + ReturnStatus, } impl MessageType { @@ -37,7 +44,10 @@ impl MessageType { 0xad => MessageType::LoginAck, 0xd1 => MessageType::Row, 0xe3 => MessageType::EnvChange, + 0x79 => MessageType::ReturnStatus, 0xfd => MessageType::Done, + 0xfe => MessageType::DoneProc, + 0xff => MessageType::DoneInProc, ty => { return Err(err_protocol!( diff --git a/sqlx-core/src/mssql/protocol/mod.rs b/sqlx-core/src/mssql/protocol/mod.rs index 226a0d39..31ef0561 100644 --- a/sqlx-core/src/mssql/protocol/mod.rs +++ b/sqlx-core/src/mssql/protocol/mod.rs @@ -9,6 +9,7 @@ pub(crate) mod login_ack; pub(crate) mod message; pub(crate) mod packet; pub(crate) mod pre_login; +pub(crate) mod return_status; pub(crate) mod row; pub(crate) mod rpc; pub(crate) mod sql_batch; diff --git a/sqlx-core/src/mssql/protocol/return_status.rs b/sqlx-core/src/mssql/protocol/return_status.rs new file mode 100644 index 00000000..fd779b40 --- /dev/null +++ b/sqlx-core/src/mssql/protocol/return_status.rs @@ -0,0 +1,17 @@ +use bitflags::bitflags; +use bytes::{Buf, Bytes}; + +use crate::error::Error; + +#[derive(Debug)] +pub(crate) struct ReturnStatus { + value: i32, +} + +impl ReturnStatus { + pub(crate) fn get(buf: &mut Bytes) -> Result { + let value = buf.get_i32_le(); + + Ok(Self { value }) + } +} diff --git a/tests/mssql/mssql.rs b/tests/mssql/mssql.rs index ac2cf92f..85717c07 100644 --- a/tests/mssql/mssql.rs +++ b/tests/mssql/mssql.rs @@ -1,3 +1,4 @@ +use futures::TryStreamExt; use sqlx::mssql::MsSql; use sqlx::{Connection, Executor, Row}; use sqlx_core::mssql::MsSqlRow; @@ -40,3 +41,35 @@ async fn it_maths() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_executes() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let _ = conn + .execute( + r#" +CREATE TABLE #users (id INTEGER PRIMARY KEY); + "#, + ) + .await?; + + for index in 1..=10_i32 { + let cnt = sqlx::query("INSERT INTO #users (id) VALUES (@p1)") + .bind(index * 2) + .execute(&mut conn) + .await?; + + assert_eq!(cnt, 1); + } + + let sum: i32 = sqlx::query("SELECT id FROM #users") + .try_map(|row: MsSqlRow| row.try_get::(0)) + .fetch(&mut conn) + .try_fold(0_i32, |acc, x| async move { Ok(acc + x) }) + .await?; + + assert_eq!(sum, 110); + + Ok(()) +}