From 559169cc79464647e75a70036dc36f76d3b30f23 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Sat, 6 Jun 2020 14:08:46 -0700 Subject: [PATCH] refactor(mssql): clean up unused imports and other warnings --- sqlx-core/src/mssql/connection/establish.rs | 1 - sqlx-core/src/mssql/connection/executor.rs | 19 +- sqlx-core/src/mssql/connection/mod.rs | 6 +- sqlx-core/src/mssql/connection/stream.rs | 34 ++- sqlx-core/src/mssql/protocol/env_change.rs | 18 +- sqlx-core/src/mssql/protocol/login.rs | 24 -- sqlx-core/src/mssql/protocol/message.rs | 6 - sqlx-core/src/mssql/protocol/return_status.rs | 1 - sqlx-core/src/mssql/protocol/row.rs | 4 +- sqlx-core/src/mssql/protocol/rpc.rs | 4 +- sqlx-core/src/mssql/protocol/sql_batch.rs | 3 +- sqlx-core/src/mssql/protocol/type_info.rs | 2 - sqlx-core/src/mssql/transaction.rs | 57 +++- sqlx-core/src/mssql/types/float.rs | 1 - sqlx-core/src/mssql/types/int.rs | 1 - sqlx-core/src/mssql/types/str.rs | 6 +- sqlx-core/src/mssql/value.rs | 2 - sqlx-core/src/transaction.rs | 36 ++- tests/mssql/mssql.rs | 115 ++++++++ tests/postgres/postgres.rs | 274 ++++++++---------- 20 files changed, 370 insertions(+), 244 deletions(-) diff --git a/sqlx-core/src/mssql/connection/establish.rs b/sqlx-core/src/mssql/connection/establish.rs index 6c2717b5..1936da79 100644 --- a/sqlx-core/src/mssql/connection/establish.rs +++ b/sqlx-core/src/mssql/connection/establish.rs @@ -2,7 +2,6 @@ use crate::error::Error; use crate::io::Decode; use crate::mssql::connection::stream::MsSqlStream; use crate::mssql::protocol::login::Login7; -use crate::mssql::protocol::login_ack::LoginAck; use crate::mssql::protocol::message::Message; use crate::mssql::protocol::packet::PacketType; use crate::mssql::protocol::pre_login::{Encrypt, PreLogin, Version}; diff --git a/sqlx-core/src/mssql/connection/executor.rs b/sqlx-core/src/mssql/connection/executor.rs index 75b3df43..17542aa2 100644 --- a/sqlx-core/src/mssql/connection/executor.rs +++ b/sqlx-core/src/mssql/connection/executor.rs @@ -15,15 +15,16 @@ use crate::mssql::protocol::sql_batch::SqlBatch; use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow}; impl MsSqlConnection { - async fn wait_until_ready(&mut self) -> Result<(), Error> { + pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { if !self.stream.wbuf.is_empty() { + self.pending_done_count += 1; self.stream.flush().await?; } while self.pending_done_count > 0 { - if let Message::DoneProc(done) | Message::Done(done) = - self.stream.recv_message().await? - { + let message = self.stream.recv_message().await?; + + if let Message::DoneProc(done) | Message::Done(done) = message { // finished RPC procedure *OR* SQL batch self.handle_done(done); } @@ -59,14 +60,20 @@ impl MsSqlConnection { self.stream.write_packet( PacketType::Rpc, RpcRequest { + transaction_descriptor: self.stream.transaction_descriptor, arguments: &proc_args, procedure: proc, options: OptionFlags::empty(), }, ); } else { - self.stream - .write_packet(PacketType::SqlBatch, SqlBatch { sql: query }); + self.stream.write_packet( + PacketType::SqlBatch, + SqlBatch { + transaction_descriptor: self.stream.transaction_descriptor, + sql: query, + }, + ); } self.stream.flush().await?; diff --git a/sqlx-core/src/mssql/connection/mod.rs b/sqlx-core/src/mssql/connection/mod.rs index 59387623..f833cf12 100644 --- a/sqlx-core/src/mssql/connection/mod.rs +++ b/sqlx-core/src/mssql/connection/mod.rs @@ -5,7 +5,7 @@ use futures_core::future::BoxFuture; use futures_util::{future::ready, FutureExt, TryFutureExt}; use crate::connection::{Connect, Connection}; -use crate::error::{BoxDynError, Error}; +use crate::error::Error; use crate::executor::Executor; use crate::mssql::connection::stream::MsSqlStream; use crate::mssql::{MsSql, MsSqlConnectOptions}; @@ -15,7 +15,7 @@ mod executor; mod stream; pub struct MsSqlConnection { - stream: MsSqlStream, + pub(crate) stream: MsSqlStream, // number of Done* messages that we are currently expecting pub(crate) pending_done_count: usize, @@ -42,7 +42,7 @@ impl Connection for MsSqlConnection { #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { - unimplemented!() + self.wait_until_ready().boxed() } #[doc(hidden)] diff --git a/sqlx-core/src/mssql/connection/stream.rs b/sqlx-core/src/mssql/connection/stream.rs index 71d6c0e3..710e2d86 100644 --- a/sqlx-core/src/mssql/connection/stream.rs +++ b/sqlx-core/src/mssql/connection/stream.rs @@ -1,7 +1,7 @@ use std::ops::{Deref, DerefMut}; use bytes::Bytes; -use sqlx_rt::{TcpStream, TlsStream}; +use sqlx_rt::TcpStream; use crate::error::Error; use crate::io::{BufStream, Encode}; @@ -21,6 +21,10 @@ use crate::net::MaybeTlsStream; pub(crate) struct MsSqlStream { inner: BufStream>, + // current transaction descriptor + // set from ENVCHANGE on `BEGIN` and reset to `0` on a ROLLBACK + pub(crate) transaction_descriptor: u64, + // current TabularResult from the server that we are iterating over response: Option<(PacketHeader, Bytes)>, @@ -39,12 +43,13 @@ impl MsSqlStream { inner, columns: Vec::new(), response: None, + transaction_descriptor: 0, }) } // writes the packet out to the write buffer // will (eventually) handle packet chunking - pub(super) fn write_packet<'en, T: Encode<'en>>(&mut self, ty: PacketType, payload: T) { + pub(crate) fn write_packet<'en, T: Encode<'en>>(&mut self, ty: PacketType, payload: T) { // TODO: Support packet chunking for large packet sizes // We likely need to double-buffer the writes so we know to chunk @@ -98,7 +103,7 @@ impl MsSqlStream { pub(super) async fn recv_message(&mut self) -> Result { loop { while self.response.as_ref().map_or(false, |r| !r.1.is_empty()) { - let mut buf = if let Some((_, buf)) = self.response.as_mut() { + let buf = if let Some((_, buf)) = self.response.as_mut() { buf } else { // this shouldn't be reachable but just nope out @@ -108,8 +113,27 @@ impl MsSqlStream { let ty = MessageType::get(buf)?; let message = match ty { - MessageType::EnvChange => Message::EnvChange(EnvChange::get(buf)?), - MessageType::Info => Message::Info(Info::get(buf)?), + MessageType::EnvChange => { + match EnvChange::get(buf)? { + EnvChange::BeginTransaction(desc) => { + self.transaction_descriptor = desc; + } + + EnvChange::CommitTransaction(_) | EnvChange::RollbackTransaction(_) => { + self.transaction_descriptor = 0; + } + + _ => {} + } + + continue; + } + + MessageType::Info => { + let _ = Info::get(buf)?; + continue; + } + MessageType::Row => Message::Row(Row::get(buf, &self.columns)?), MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?), MessageType::ReturnStatus => Message::ReturnStatus(ReturnStatus::get(buf)?), diff --git a/sqlx-core/src/mssql/protocol/env_change.rs b/sqlx-core/src/mssql/protocol/env_change.rs index db2a5701..b0b27611 100644 --- a/sqlx-core/src/mssql/protocol/env_change.rs +++ b/sqlx-core/src/mssql/protocol/env_change.rs @@ -1,7 +1,6 @@ use bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::Decode; use crate::mssql::io::MsSqlBufExt; #[derive(Debug)] @@ -16,9 +15,9 @@ pub(crate) enum EnvChange { SqlCollation(Bytes), // TDS 7.2+ - BeginTransaction, - CommitTransaction, - RollbackTransaction, + BeginTransaction(u64), + CommitTransaction(u64), + RollbackTransaction(u64), EnlistDtcTransaction, DefectTransaction, RealTimeLogShipping, @@ -46,6 +45,17 @@ impl EnvChange { 5 => EnvChange::UnicodeDataSortingLocalId(data.get_b_varchar()?), 6 => EnvChange::UnicodeDataSortingComparisonFlags(data.get_b_varchar()?), 7 => EnvChange::SqlCollation(data.get_b_varbyte()), + 8 => EnvChange::BeginTransaction(data.get_b_varbyte().get_u64_le()), + + 9 => { + let _ = data.get_u8(); + EnvChange::CommitTransaction(data.get_u64_le()) + } + + 10 => { + let _ = data.get_u8(); + EnvChange::RollbackTransaction(data.get_u64_le()) + } _ => { return Err(err_protocol!("unexpected value {} for ENVCHANGE Type", ty)); diff --git a/sqlx-core/src/mssql/protocol/login.rs b/sqlx-core/src/mssql/protocol/login.rs index 8aa30d01..1c484831 100644 --- a/sqlx-core/src/mssql/protocol/login.rs +++ b/sqlx-core/src/mssql/protocol/login.rs @@ -1,26 +1,6 @@ -use hex::encode; -use std::mem::size_of; - use crate::io::Encode; use crate::mssql::io::MsSqlBufMutExt; -// Stream definition -// LOGIN7 = Length -// TDSVersion -// PacketSize -// ClientProgVer -// ClientPID -// ConnectionID -// OptionFlags1 -// OptionFlags2 -// TypeFlags -// OptionFlags3 -// ClientTimeZone -// ClientLCID -// OffsetLength -// Data -// FeatureExt - #[derive(Debug)] pub struct Login7<'a> { pub version: u32, @@ -156,10 +136,6 @@ impl Encode<'_> for Login7<'_> { // [ChangePassword] New password for the specified login write_offset(buf, &mut offsets, beg); - offsets += 2; - - // [SSPILong] Used for large SSPI data - offsets += 4; // Establish the length of the entire structure let len = buf.len(); diff --git a/sqlx-core/src/mssql/protocol/message.rs b/sqlx-core/src/mssql/protocol/message.rs index bb9167cd..d04c99a5 100644 --- a/sqlx-core/src/mssql/protocol/message.rs +++ b/sqlx-core/src/mssql/protocol/message.rs @@ -1,19 +1,13 @@ use bytes::{Buf, Bytes}; -use crate::mssql::protocol::col_meta_data::ColMetaData; use crate::mssql::protocol::done::Done; -use crate::mssql::protocol::env_change::EnvChange; -use crate::mssql::protocol::error::Error; -use crate::mssql::protocol::info::Info; use crate::mssql::protocol::login_ack::LoginAck; use crate::mssql::protocol::return_status::ReturnStatus; use crate::mssql::protocol::row::Row; #[derive(Debug)] pub(crate) enum Message { - Info(Info), LoginAck(LoginAck), - EnvChange(EnvChange), Done(Done), DoneInProc(Done), DoneProc(Done), diff --git a/sqlx-core/src/mssql/protocol/return_status.rs b/sqlx-core/src/mssql/protocol/return_status.rs index fd779b40..8c51d8b9 100644 --- a/sqlx-core/src/mssql/protocol/return_status.rs +++ b/sqlx-core/src/mssql/protocol/return_status.rs @@ -1,4 +1,3 @@ -use bitflags::bitflags; use bytes::{Buf, Bytes}; use crate::error::Error; diff --git a/sqlx-core/src/mssql/protocol/row.rs b/sqlx-core/src/mssql/protocol/row.rs index 1665ceb4..ec690652 100644 --- a/sqlx-core/src/mssql/protocol/row.rs +++ b/sqlx-core/src/mssql/protocol/row.rs @@ -1,10 +1,8 @@ -use std::ops::Range; - use bytes::Bytes; use crate::error::Error; use crate::mssql::protocol::col_meta_data::ColumnData; -use crate::mssql::{MsSql, MsSqlTypeInfo}; +use crate::mssql::MsSqlTypeInfo; #[derive(Debug)] pub(crate) struct Row { diff --git a/sqlx-core/src/mssql/protocol/rpc.rs b/sqlx-core/src/mssql/protocol/rpc.rs index ad87aecb..bfe97c67 100644 --- a/sqlx-core/src/mssql/protocol/rpc.rs +++ b/sqlx-core/src/mssql/protocol/rpc.rs @@ -7,6 +7,8 @@ use crate::mssql::protocol::header::{AllHeaders, Header}; use crate::mssql::MsSqlArguments; pub(crate) struct RpcRequest<'a> { + pub(crate) transaction_descriptor: u64, + // the procedure can be encoded as a u16 of a built-in or the name for a custom one pub(crate) procedure: Either<&'a str, Procedure>, pub(crate) options: OptionFlags, @@ -67,7 +69,7 @@ impl Encode<'_> for RpcRequest<'_> { fn encode_with(&self, buf: &mut Vec, _: ()) { AllHeaders(&[Header::TransactionDescriptor { outstanding_request_count: 1, - transaction_descriptor: 0, + transaction_descriptor: self.transaction_descriptor, }]) .encode(buf); diff --git a/sqlx-core/src/mssql/protocol/sql_batch.rs b/sqlx-core/src/mssql/protocol/sql_batch.rs index 92439f78..45aaed58 100644 --- a/sqlx-core/src/mssql/protocol/sql_batch.rs +++ b/sqlx-core/src/mssql/protocol/sql_batch.rs @@ -4,6 +4,7 @@ use crate::mssql::protocol::header::{AllHeaders, Header}; #[derive(Debug)] pub(crate) struct SqlBatch<'a> { + pub(crate) transaction_descriptor: u64, pub(crate) sql: &'a str, } @@ -11,7 +12,7 @@ impl Encode<'_> for SqlBatch<'_> { fn encode_with(&self, buf: &mut Vec, _: ()) { AllHeaders(&[Header::TransactionDescriptor { outstanding_request_count: 1, - transaction_descriptor: 0, + transaction_descriptor: self.transaction_descriptor, }]) .encode(buf); diff --git a/sqlx-core/src/mssql/protocol/type_info.rs b/sqlx-core/src/mssql/protocol/type_info.rs index 23995280..6109b97a 100644 --- a/sqlx-core/src/mssql/protocol/type_info.rs +++ b/sqlx-core/src/mssql/protocol/type_info.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use bitflags::bitflags; use bytes::{Buf, Bytes}; use encoding_rs::Encoding; diff --git a/sqlx-core/src/mssql/transaction.rs b/sqlx-core/src/mssql/transaction.rs index df3bff57..f56e63eb 100644 --- a/sqlx-core/src/mssql/transaction.rs +++ b/sqlx-core/src/mssql/transaction.rs @@ -1,12 +1,13 @@ +use std::borrow::Cow; + use futures_core::future::BoxFuture; use crate::error::Error; use crate::executor::Executor; +use crate::mssql::protocol::packet::PacketType; +use crate::mssql::protocol::sql_batch::SqlBatch; use crate::mssql::{MsSql, MsSqlConnection}; -use crate::transaction::{ - begin_ansi_transaction_sql, commit_ansi_transaction_sql, rollback_ansi_transaction_sql, - TransactionManager, -}; +use crate::transaction::TransactionManager; /// Implementation of [`TransactionManager`] for MSSQL. pub struct MsSqlTransactionManager; @@ -15,18 +16,58 @@ impl TransactionManager for MsSqlTransactionManager { type Database = MsSql; fn begin(conn: &mut MsSqlConnection, depth: usize) -> BoxFuture<'_, Result<(), Error>> { - unimplemented!() + Box::pin(async move { + let query = if depth == 0 { + Cow::Borrowed("BEGIN TRAN ") + } else { + Cow::Owned(format!("SAVE TRAN _sqlx_savepoint_{}", depth)) + }; + + conn.execute(&*query).await?; + + Ok(()) + }) } fn commit(conn: &mut MsSqlConnection, depth: usize) -> BoxFuture<'_, Result<(), Error>> { - unimplemented!() + Box::pin(async move { + if depth == 1 { + // savepoints are not released in MSSQL + conn.execute("COMMIT TRAN").await?; + } + + Ok(()) + }) } fn rollback(conn: &mut MsSqlConnection, depth: usize) -> BoxFuture<'_, Result<(), Error>> { - unimplemented!() + Box::pin(async move { + let query = if depth == 1 { + Cow::Borrowed("ROLLBACK TRAN") + } else { + Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1)) + }; + + conn.execute(&*query).await?; + + Ok(()) + }) } fn start_rollback(conn: &mut MsSqlConnection, depth: usize) { - unimplemented!() + let query = if depth == 1 { + Cow::Borrowed("ROLLBACK TRAN") + } else { + Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1)) + }; + + conn.pending_done_count += 1; + conn.stream.write_packet( + PacketType::SqlBatch, + SqlBatch { + transaction_descriptor: conn.stream.transaction_descriptor, + sql: &*query, + }, + ); } } diff --git a/sqlx-core/src/mssql/types/float.rs b/sqlx-core/src/mssql/types/float.rs index b53beef0..3cf059f2 100644 --- a/sqlx-core/src/mssql/types/float.rs +++ b/sqlx-core/src/mssql/types/float.rs @@ -1,6 +1,5 @@ use byteorder::{ByteOrder, LittleEndian}; -use crate::database::{Database, HasArguments, HasValueRef}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; diff --git a/sqlx-core/src/mssql/types/int.rs b/sqlx-core/src/mssql/types/int.rs index a9f45531..1e7b9a02 100644 --- a/sqlx-core/src/mssql/types/int.rs +++ b/sqlx-core/src/mssql/types/int.rs @@ -1,6 +1,5 @@ use byteorder::{ByteOrder, LittleEndian}; -use crate::database::{Database, HasArguments, HasValueRef}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; diff --git a/sqlx-core/src/mssql/types/str.rs b/sqlx-core/src/mssql/types/str.rs index 6417efc8..7c362f0f 100644 --- a/sqlx-core/src/mssql/types/str.rs +++ b/sqlx-core/src/mssql/types/str.rs @@ -1,12 +1,8 @@ -use byteorder::{ByteOrder, LittleEndian}; -use bytes::Buf; - -use crate::database::{Database, HasArguments, HasValueRef}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::mssql::io::MsSqlBufMutExt; -use crate::mssql::protocol::type_info::{Collation, DataType, TypeInfo}; +use crate::mssql::protocol::type_info::{DataType, TypeInfo}; use crate::mssql::{MsSql, MsSqlTypeInfo, MsSqlValueRef}; use crate::types::Type; diff --git a/sqlx-core/src/mssql/value.rs b/sqlx-core/src/mssql/value.rs index 890aa77e..fd25bfd7 100644 --- a/sqlx-core/src/mssql/value.rs +++ b/sqlx-core/src/mssql/value.rs @@ -1,9 +1,7 @@ use std::borrow::Cow; -use std::marker::PhantomData; use bytes::Bytes; -use crate::database::HasValueRef; use crate::error::{BoxDynError, UnexpectedNullError}; use crate::mssql::{MsSql, MsSqlTypeInfo}; use crate::value::{Value, ValueRef}; diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index ed5fe062..7942d238 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; +use std::mem; use std::ops::{Deref, DerefMut}; use futures_core::future::BoxFuture; @@ -89,12 +90,22 @@ where /// Commits this transaction or savepoint. pub async fn commit(mut self) -> Result<(), Error> { - DB::TransactionManager::commit(self.connection.get_mut(), self.depth).await + DB::TransactionManager::commit(self.connection.get_mut(), self.depth).await?; + + // opt-out of the automatic rollback + mem::forget(self); + + Ok(()) } /// Aborts this transaction or savepoint. pub async fn rollback(mut self) -> Result<(), Error> { - DB::TransactionManager::rollback(self.connection.get_mut(), self.depth).await + DB::TransactionManager::rollback(self.connection.get_mut(), self.depth).await?; + + // opt-out of the automatic rollback + mem::forget(self); + + Ok(()) } } @@ -243,28 +254,31 @@ where } #[allow(dead_code)] -pub(crate) fn begin_ansi_transaction_sql(index: usize) -> Cow<'static, str> { - if index == 0 { +pub(crate) fn begin_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { + if depth == 0 { Cow::Borrowed("BEGIN") } else { - Cow::Owned(format!("SAVEPOINT _sqlx_savepoint_{}", index)) + Cow::Owned(format!("SAVEPOINT _sqlx_savepoint_{}", depth)) } } #[allow(dead_code)] -pub(crate) fn commit_ansi_transaction_sql(index: usize) -> Cow<'static, str> { - if index == 1 { +pub(crate) fn commit_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { + if depth == 1 { Cow::Borrowed("COMMIT") } else { - Cow::Owned(format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", index)) + Cow::Owned(format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", depth - 1)) } } #[allow(dead_code)] -pub(crate) fn rollback_ansi_transaction_sql(index: usize) -> Cow<'static, str> { - if index == 1 { +pub(crate) fn rollback_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { + if depth == 1 { Cow::Borrowed("ROLLBACK") } else { - Cow::Owned(format!("ROLLBACK TO SAVEPOINT _sqlx_savepoint_{}", index)) + Cow::Owned(format!( + "ROLLBACK TO SAVEPOINT _sqlx_savepoint_{}", + depth - 1 + )) } } diff --git a/tests/mssql/mssql.rs b/tests/mssql/mssql.rs index 85717c07..d0051201 100644 --- a/tests/mssql/mssql.rs +++ b/tests/mssql/mssql.rs @@ -73,3 +73,118 @@ CREATE TABLE #users (id INTEGER PRIMARY KEY); Ok(()) } + +#[sqlx_macros::test] +async fn it_can_work_with_transactions() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.execute("IF OBJECT_ID('_sqlx_users_1922', 'U') IS NULL CREATE TABLE _sqlx_users_1922 (id INTEGER PRIMARY KEY)") + .await?; + + conn.execute("DELETE FROM _sqlx_users_1922").await?; + + // begin .. rollback + + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") + .bind(10_i32) + .execute(&mut tx) + .await?; + + tx.rollback().await?; + + let (count,): (i32,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") + .fetch_one(&mut conn) + .await?; + + assert_eq!(count, 0); + + // begin .. commit + + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES (@p1)") + .bind(10_i32) + .execute(&mut tx) + .await?; + + tx.commit().await?; + + let (count,): (i32,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") + .fetch_one(&mut conn) + .await?; + + assert_eq!(count, 1); + + // begin .. (drop) + + { + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES (@p1)") + .bind(20_i32) + .execute(&mut tx) + .await?; + } + + conn = new::().await?; + + let (count,): (i32,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") + .fetch_one(&mut conn) + .await?; + + assert_eq!(count, 1); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.execute("IF OBJECT_ID('_sqlx_users_2523', 'U') IS NULL CREATE TABLE _sqlx_users_2523 (id INTEGER PRIMARY KEY)") + .await?; + + conn.execute("DELETE FROM _sqlx_users_2523").await?; + + // begin + let mut tx = conn.begin().await?; + + // insert a user + sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES (@p1)") + .bind(50_i32) + .execute(&mut tx) + .await?; + + // begin once more + let mut tx2 = tx.begin().await?; + + // insert another user + sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES (@p1)") + .bind(10_i32) + .execute(&mut tx2) + .await?; + + // never mind, rollback + tx2.rollback().await?; + + // did we really? + let (count,): (i32,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") + .fetch_one(&mut tx) + .await?; + + assert_eq!(count, 1); + + // actually, commit + tx.commit().await?; + + // did we really? + let (count,): (i32,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") + .fetch_one(&mut conn) + .await?; + + assert_eq!(count, 1); + + Ok(()) +} diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 4fe44686..e40dc2d3 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -179,165 +179,121 @@ async fn it_can_query_all_scalar() -> anyhow::Result<()> { Ok(()) } -// #[cfg_attr(feature = "runtime-async-std", async_std::test)] -// #[cfg_attr(feature = "runtime-tokio", tokio::test)] -// async fn it_can_work_with_transactions() -> anyhow::Result<()> { -// let mut conn = new::().await?; -// -// conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_1922 (id INTEGER PRIMARY KEY)") -// .await?; -// -// conn.execute("TRUNCATE _sqlx_users_1922").await?; -// -// // begin .. rollback -// -// let mut tx = conn.begin().await?; -// -// sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") -// .bind(10_i32) -// .execute(&mut tx) -// .await?; -// -// conn = tx.rollback().await?; -// -// let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") -// .fetch_one(&mut conn) -// .await?; -// -// assert_eq!(count, 0); -// -// // begin .. commit -// -// let mut tx = conn.begin().await?; -// -// sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") -// .bind(10_i32) -// .execute(&mut tx) -// .await?; -// -// conn = tx.commit().await?; -// -// let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") -// .fetch_one(&mut conn) -// .await?; -// -// assert_eq!(count, 1); -// -// // begin .. (drop) -// -// { -// let mut tx = conn.begin().await?; -// -// sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") -// .bind(20_i32) -// .execute(&mut tx) -// .await?; -// } -// -// conn = new::().await?; -// -// let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") -// .fetch_one(&mut conn) -// .await?; -// -// assert_eq!(count, 1); -// -// Ok(()) -// } -// -// #[cfg_attr(feature = "runtime-async-std", async_std::test)] -// #[cfg_attr(feature = "runtime-tokio", tokio::test)] -// async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { -// let mut conn = new::().await?; -// -// conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_2523 (id INTEGER PRIMARY KEY)") -// .await?; -// -// conn.execute("TRUNCATE _sqlx_users_2523").await?; -// -// // begin -// let mut tx = conn.begin().await?; -// -// // insert a user -// sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") -// .bind(50_i32) -// .execute(&mut tx) -// .await?; -// -// // begin once more -// let mut tx = tx.begin().await?; -// -// // insert another user -// sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") -// .bind(10_i32) -// .execute(&mut tx) -// .await?; -// -// // never mind, rollback -// let mut tx = tx.rollback().await?; -// -// // did we really? -// let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") -// .fetch_one(&mut tx) -// .await?; -// -// assert_eq!(count, 1); -// -// // actually, commit -// let mut conn = tx.commit().await?; -// -// // did we really? -// let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") -// .fetch_one(&mut conn) -// .await?; -// -// assert_eq!(count, 1); -// -// Ok(()) -// } -// -// #[cfg_attr(feature = "runtime-async-std", async_std::test)] -// #[cfg_attr(feature = "runtime-tokio", tokio::test)] -// async fn it_can_rollback_nested_transactions() -> anyhow::Result<()> { -// let mut conn = new::().await?; -// -// conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_512412 (id INTEGER PRIMARY KEY)") -// .await?; -// -// conn.execute("TRUNCATE _sqlx_users_512412").await?; -// -// // begin -// let mut tx = conn.begin().await?; -// -// // insert a user -// sqlx::query("INSERT INTO _sqlx_users_512412 (id) VALUES ($1)") -// .bind(50_i32) -// .execute(&mut tx) -// .await?; -// -// // begin once more -// let mut tx = tx.begin().await?; -// -// // insert another user -// sqlx::query("INSERT INTO _sqlx_users_512412 (id) VALUES ($1)") -// .bind(10_i32) -// .execute(&mut tx) -// .await?; -// -// // stop the phone, drop the entire transaction -// tx.close().await?; -// -// // did we really? -// let mut conn = new::().await?; -// let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_512412") -// .fetch_one(&mut conn) -// .await?; -// -// assert_eq!(count, 0); -// -// Ok(()) -// } -// +#[sqlx_macros::test] +async fn it_can_work_with_transactions() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_1922 (id INTEGER PRIMARY KEY)") + .await?; + + conn.execute("TRUNCATE _sqlx_users_1922").await?; + + // begin .. rollback + + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") + .bind(10_i32) + .execute(&mut tx) + .await?; + + tx.rollback().await?; + + let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") + .fetch_one(&mut conn) + .await?; + + assert_eq!(count, 0); + + // begin .. commit + + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") + .bind(10_i32) + .execute(&mut tx) + .await?; + + tx.commit().await?; + + let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") + .fetch_one(&mut conn) + .await?; + + assert_eq!(count, 1); + + // begin .. (drop) + + { + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") + .bind(20_i32) + .execute(&mut tx) + .await?; + } + + conn = new::().await?; + + let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922") + .fetch_one(&mut conn) + .await?; + + assert_eq!(count, 1); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_2523 (id INTEGER PRIMARY KEY)") + .await?; + + conn.execute("TRUNCATE _sqlx_users_2523").await?; + + // begin + let mut tx = conn.begin().await?; + + // insert a user + sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") + .bind(50_i32) + .execute(&mut tx) + .await?; + + // begin once more + let mut tx2 = tx.begin().await?; + + // insert another user + sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES ($1)") + .bind(10_i32) + .execute(&mut tx2) + .await?; + + // never mind, rollback + tx2.rollback().await?; + + // did we really? + let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") + .fetch_one(&mut tx) + .await?; + + assert_eq!(count, 1); + + // actually, commit + tx.commit().await?; + + // did we really? + let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") + .fetch_one(&mut conn) + .await?; + + assert_eq!(count, 1); + + Ok(()) +} + // // run with `cargo test --features postgres -- --ignored --nocapture pool_smoke_test` // #[ignore] // #[cfg_attr(feature = "runtime-async-std", async_std::test)]