diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index 9dc4ba6e..b34d06a0 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -5,6 +5,7 @@ use std::fmt::Debug; use crate::arguments::Arguments; use crate::column::Column; use crate::connection::Connection; +use crate::done::Done; use crate::row::Row; use crate::transaction::TransactionManager; use crate::type_info::TypeInfo; @@ -31,6 +32,9 @@ pub trait Database: /// The concrete `Row` implementation for this database. type Row: Row; + /// The concrete `Done` implementation for this database. + type Done: Done; + /// The concrete `Column` implementation for this database. type Column: Column; diff --git a/sqlx-core/src/done.rs b/sqlx-core/src/done.rs new file mode 100644 index 00000000..c3d0ea50 --- /dev/null +++ b/sqlx-core/src/done.rs @@ -0,0 +1,9 @@ +use crate::database::Database; +use std::iter::Extend; + +pub trait Done: 'static + Sized + Send + Sync + Default + Extend { + type Database: Database; + + /// Returns the number of rows affected by an `UPDATE`, `INSERT`, or `DELETE`. + fn rows_affected(&self) -> u64; +} diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index 9abd859c..956ca377 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -1,13 +1,11 @@ -use std::fmt::Debug; - +use crate::database::{Database, HasArguments}; +use crate::error::Error; +use crate::statement::StatementInfo; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{future, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; - -use crate::database::{Database, HasArguments}; -use crate::error::Error; -use crate::statement::StatementInfo; +use std::fmt::Debug; /// A type that contains or can provide a database /// connection to use for executing queries against the database. @@ -28,18 +26,22 @@ pub trait Executor<'c>: Send + Debug + Sized { type Database: Database; /// Execute the query and return the total number of rows affected. - fn execute<'e, 'q: 'e, E: 'q>(self, query: E) -> BoxFuture<'e, Result> + fn execute<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxFuture<'e, Result<::Done, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, { - self.execute_many(query) - .try_fold(0, |acc, x| async move { Ok(acc + x) }) - .boxed() + self.execute_many(query).try_collect().boxed() } /// Execute multiple queries and return the rows affected from each query, in a stream. - fn execute_many<'e, 'q: 'e, E: 'q>(self, query: E) -> BoxStream<'e, Result> + fn execute_many<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxStream<'e, Result<::Done, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, @@ -78,7 +80,13 @@ pub trait Executor<'c>: Send + Debug + Sized { fn fetch_many<'e, 'q: 'e, E: 'q>( self, query: E, - ) -> BoxStream<'e, Result::Row>, Error>> + ) -> BoxStream< + 'e, + Result< + Either<::Done, ::Row>, + Error, + >, + > where 'c: 'e, E: Execute<'q, Self::Database>; diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index d5d384e8..349b2e08 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -52,6 +52,7 @@ pub mod acquire; pub mod column; mod common; pub mod database; +pub mod done; pub mod executor; pub mod from_row; mod io; diff --git a/sqlx-core/src/mssql/connection/executor.rs b/sqlx-core/src/mssql/connection/executor.rs index c6363ce1..419ade01 100644 --- a/sqlx-core/src/mssql/connection/executor.rs +++ b/sqlx-core/src/mssql/connection/executor.rs @@ -6,7 +6,7 @@ use crate::mssql::protocol::message::Message; use crate::mssql::protocol::packet::PacketType; use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest}; use crate::mssql::protocol::sql_batch::SqlBatch; -use crate::mssql::{Mssql, MssqlArguments, MssqlConnection, MssqlRow}; +use crate::mssql::{Mssql, MssqlArguments, MssqlConnection, MssqlDone, MssqlRow}; use crate::statement::StatementInfo; use either::Either; use futures_core::future::BoxFuture; @@ -66,7 +66,7 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { fn fetch_many<'e, 'q: 'e, E: 'q>( self, mut query: E, - ) -> BoxStream<'e, Result, Error>> + ) -> BoxStream<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, @@ -94,7 +94,9 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { } if done.status.contains(Status::DONE_COUNT) { - r#yield!(Either::Left(done.affected_rows)); + r#yield!(Either::Left(MssqlDone { + rows_affected: done.affected_rows, + })); } if !done.status.contains(Status::DONE_MORE) { @@ -104,7 +106,9 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { Message::DoneInProc(done) => { if done.status.contains(Status::DONE_COUNT) { - r#yield!(Either::Left(done.affected_rows)); + r#yield!(Either::Left(MssqlDone { + rows_affected: done.affected_rows, + })); } } diff --git a/sqlx-core/src/mssql/database.rs b/sqlx-core/src/mssql/database.rs index 604a6467..87358fab 100644 --- a/sqlx-core/src/mssql/database.rs +++ b/sqlx-core/src/mssql/database.rs @@ -1,7 +1,7 @@ use crate::database::{Database, HasArguments, HasValueRef}; use crate::mssql::{ - MssqlArguments, MssqlColumn, MssqlConnection, MssqlRow, MssqlTransactionManager, MssqlTypeInfo, - MssqlValue, MssqlValueRef, + MssqlArguments, MssqlColumn, MssqlConnection, MssqlDone, MssqlRow, MssqlTransactionManager, + MssqlTypeInfo, MssqlValue, MssqlValueRef, }; /// MSSQL database driver. @@ -15,6 +15,8 @@ impl Database for Mssql { type Row = MssqlRow; + type Done = MssqlDone; + type Column = MssqlColumn; type TypeInfo = MssqlTypeInfo; diff --git a/sqlx-core/src/mssql/done.rs b/sqlx-core/src/mssql/done.rs new file mode 100644 index 00000000..0bb556bf --- /dev/null +++ b/sqlx-core/src/mssql/done.rs @@ -0,0 +1,24 @@ +use crate::done::Done; +use crate::mssql::Mssql; +use std::iter::{Extend, IntoIterator}; + +#[derive(Debug, Default)] +pub struct MssqlDone { + pub(super) rows_affected: u64, +} + +impl Done for MssqlDone { + type Database = Mssql; + + fn rows_affected(&self) -> u64 { + self.rows_affected + } +} + +impl Extend for MssqlDone { + fn extend>(&mut self, iter: T) { + for elem in iter { + self.rows_affected += elem.rows_affected; + } + } +} diff --git a/sqlx-core/src/mssql/mod.rs b/sqlx-core/src/mssql/mod.rs index 8ae035e8..e7cbf4a8 100644 --- a/sqlx-core/src/mssql/mod.rs +++ b/sqlx-core/src/mssql/mod.rs @@ -4,6 +4,7 @@ mod arguments; mod column; mod connection; mod database; +mod done; mod error; mod io; mod options; @@ -18,6 +19,7 @@ pub use arguments::MssqlArguments; pub use column::MssqlColumn; pub use connection::MssqlConnection; pub use database::Mssql; +pub use done::MssqlDone; pub use error::MssqlDatabaseError; pub use options::MssqlConnectOptions; pub use row::MssqlRow; diff --git a/sqlx-core/src/mysql/connection/executor.rs b/sqlx-core/src/mysql/connection/executor.rs index ab9ba852..9eeed9a9 100644 --- a/sqlx-core/src/mysql/connection/executor.rs +++ b/sqlx-core/src/mysql/connection/executor.rs @@ -19,7 +19,8 @@ use crate::mysql::protocol::statement::{ use crate::mysql::protocol::text::{ColumnDefinition, ColumnFlags, Query, TextRow}; use crate::mysql::protocol::Packet; use crate::mysql::{ - MySql, MySqlArguments, MySqlColumn, MySqlConnection, MySqlRow, MySqlTypeInfo, MySqlValueFormat, + MySql, MySqlArguments, MySqlColumn, MySqlConnection, MySqlDone, MySqlRow, MySqlTypeInfo, + MySqlValueFormat, }; use crate::statement::StatementInfo; @@ -111,7 +112,7 @@ impl MySqlConnection { &'c mut self, query: &str, arguments: Option, - ) -> Result, Error>> + 'c, Error> { + ) -> Result, Error>> + 'c, Error> { self.stream.wait_until_ready().await?; self.stream.busy = Busy::Result; @@ -145,7 +146,12 @@ impl MySqlConnection { // this indicates either a successful query with no rows at all or a failed query let ok = packet.ok()?; - r#yield!(Either::Left(ok.affected_rows)); + let done = MySqlDone { + rows_affected: ok.affected_rows, + last_insert_id: ok.last_insert_id, + }; + + r#yield!(Either::Left(done)); if ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { // more result sets exist, continue to the next one @@ -166,7 +172,11 @@ impl MySqlConnection { if packet[0] == 0xfe && packet.len() < 9 { let eof = packet.eof(self.stream.capabilities)?; - r#yield!(Either::Left(0)); + + r#yield!(Either::Left(MySqlDone { + rows_affected: 0, + last_insert_id: 0, + })); if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { // more result sets exist, continue to the next one @@ -203,7 +213,7 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { fn fetch_many<'e, 'q: 'e, E: 'q>( self, mut query: E, - ) -> BoxStream<'e, Result, Error>> + ) -> BoxStream<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, diff --git a/sqlx-core/src/mysql/database.rs b/sqlx-core/src/mysql/database.rs index 178414db..2ba239b6 100644 --- a/sqlx-core/src/mysql/database.rs +++ b/sqlx-core/src/mysql/database.rs @@ -1,7 +1,8 @@ use crate::database::{Database, HasArguments, HasStatementCache, HasValueRef}; use crate::mysql::value::{MySqlValue, MySqlValueRef}; use crate::mysql::{ - MySqlArguments, MySqlColumn, MySqlConnection, MySqlRow, MySqlTransactionManager, MySqlTypeInfo, + MySqlArguments, MySqlColumn, MySqlConnection, MySqlDone, MySqlRow, MySqlTransactionManager, + MySqlTypeInfo, }; /// MySQL database driver. @@ -15,6 +16,8 @@ impl Database for MySql { type Row = MySqlRow; + type Done = MySqlDone; + type Column = MySqlColumn; type TypeInfo = MySqlTypeInfo; diff --git a/sqlx-core/src/mysql/done.rs b/sqlx-core/src/mysql/done.rs new file mode 100644 index 00000000..e58a8912 --- /dev/null +++ b/sqlx-core/src/mysql/done.rs @@ -0,0 +1,32 @@ +use crate::done::Done; +use crate::mysql::MySql; +use std::iter::{Extend, IntoIterator}; + +#[derive(Debug, Default)] +pub struct MySqlDone { + pub(super) rows_affected: u64, + pub(super) last_insert_id: u64, +} + +impl MySqlDone { + pub fn last_insert_id(&self) -> u64 { + self.last_insert_id + } +} + +impl Done for MySqlDone { + type Database = MySql; + + fn rows_affected(&self) -> u64 { + self.rows_affected + } +} + +impl Extend for MySqlDone { + fn extend>(&mut self, iter: T) { + for elem in iter { + self.rows_affected += elem.rows_affected; + self.last_insert_id = elem.last_insert_id; + } + } +} diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index fd692e8d..fb2f0dde 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -4,6 +4,7 @@ mod arguments; mod column; mod connection; mod database; +mod done; mod error; mod io; mod options; @@ -21,6 +22,7 @@ pub use arguments::MySqlArguments; pub use column::MySqlColumn; pub use connection::MySqlConnection; pub use database::MySql; +pub use done::MySqlDone; pub use error::MySqlDatabaseError; pub use options::{MySqlConnectOptions, MySqlSslMode}; pub use row::MySqlRow; diff --git a/sqlx-core/src/pool/executor.rs b/sqlx-core/src/pool/executor.rs index 7c8a6717..a6a69b8f 100644 --- a/sqlx-core/src/pool/executor.rs +++ b/sqlx-core/src/pool/executor.rs @@ -18,7 +18,7 @@ where fn fetch_many<'e, 'q: 'e, E: 'q>( self, query: E, - ) -> BoxStream<'e, Result, Error>> + ) -> BoxStream<'e, Result, Error>> where E: Execute<'q, Self::Database>, { @@ -75,7 +75,10 @@ macro_rules! impl_executor_for_pool_connection { query: E, ) -> futures_core::stream::BoxStream< 'e, - Result, crate::error::Error>, + Result< + either::Either<<$DB as crate::database::Database>::Done, $R>, + crate::error::Error, + >, > where 'c: 'e, diff --git a/sqlx-core/src/postgres/connection/executor.rs b/sqlx-core/src/postgres/connection/executor.rs index 4af0a9d1..411dfcf4 100644 --- a/sqlx-core/src/postgres/connection/executor.rs +++ b/sqlx-core/src/postgres/connection/executor.rs @@ -12,7 +12,7 @@ use crate::postgres::message::{ Query, RowDescription, }; use crate::postgres::type_info::PgType; -use crate::postgres::{PgArguments, PgConnection, PgRow, PgValueFormat, Postgres}; +use crate::postgres::{PgArguments, PgConnection, PgDone, PgRow, PgValueFormat, Postgres}; use crate::statement::StatementInfo; async fn prepare( @@ -142,7 +142,7 @@ impl PgConnection { query: &str, arguments: Option, limit: u8, - ) -> Result, Error>> + '_, Error> { + ) -> Result, Error>> + '_, Error> { // before we continue, wait until we are "ready" to accept more queries self.wait_until_ready().await?; @@ -219,7 +219,9 @@ impl PgConnection { // a SQL command completed normally let cc: CommandComplete = message.decode()?; - r#yield!(Either::Left(cc.rows_affected())); + r#yield!(Either::Left(PgDone { + rows_affected: cc.rows_affected(), + })); } MessageFormat::EmptyQueryResponse => { @@ -272,7 +274,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { fn fetch_many<'e, 'q: 'e, E: 'q>( self, mut query: E, - ) -> BoxStream<'e, Result, Error>> + ) -> BoxStream<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, diff --git a/sqlx-core/src/postgres/database.rs b/sqlx-core/src/postgres/database.rs index 562dcda0..03b218d7 100644 --- a/sqlx-core/src/postgres/database.rs +++ b/sqlx-core/src/postgres/database.rs @@ -2,7 +2,7 @@ use crate::database::{Database, HasArguments, HasStatementCache, HasValueRef}; use crate::postgres::arguments::PgArgumentBuffer; use crate::postgres::value::{PgValue, PgValueRef}; use crate::postgres::{ - PgArguments, PgColumn, PgConnection, PgRow, PgTransactionManager, PgTypeInfo, + PgArguments, PgColumn, PgConnection, PgDone, PgRow, PgTransactionManager, PgTypeInfo, }; /// PostgreSQL database driver. @@ -16,6 +16,8 @@ impl Database for Postgres { type Row = PgRow; + type Done = PgDone; + type Column = PgColumn; type TypeInfo = PgTypeInfo; diff --git a/sqlx-core/src/postgres/done.rs b/sqlx-core/src/postgres/done.rs new file mode 100644 index 00000000..465c7a19 --- /dev/null +++ b/sqlx-core/src/postgres/done.rs @@ -0,0 +1,24 @@ +use crate::done::Done; +use crate::postgres::Postgres; +use std::iter::{Extend, IntoIterator}; + +#[derive(Debug, Default)] +pub struct PgDone { + pub(super) rows_affected: u64, +} + +impl Done for PgDone { + type Database = Postgres; + + fn rows_affected(&self) -> u64 { + self.rows_affected + } +} + +impl Extend for PgDone { + fn extend>(&mut self, iter: T) { + for elem in iter { + self.rows_affected += elem.rows_affected; + } + } +} diff --git a/sqlx-core/src/postgres/listener.rs b/sqlx-core/src/postgres/listener.rs index 35203bab..bffda88a 100644 --- a/sqlx-core/src/postgres/listener.rs +++ b/sqlx-core/src/postgres/listener.rs @@ -3,7 +3,7 @@ use crate::executor::{Execute, Executor}; use crate::pool::PoolOptions; use crate::pool::{Pool, PoolConnection}; use crate::postgres::message::{MessageFormat, Notification}; -use crate::postgres::{PgConnection, PgRow, Postgres}; +use crate::postgres::{PgConnection, PgDone, PgRow, Postgres}; use crate::statement::StatementInfo; use either::Either; use futures_channel::mpsc; @@ -197,7 +197,7 @@ impl<'c> Executor<'c> for &'c mut PgListener { fn fetch_many<'e, 'q: 'e, E: 'q>( self, query: E, - ) -> BoxStream<'e, Result, Error>> + ) -> BoxStream<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 22e0444b..9f384b05 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -4,6 +4,7 @@ mod arguments; mod column; mod connection; mod database; +mod done; mod error; mod io; mod listener; @@ -22,6 +23,7 @@ pub use arguments::{PgArgumentBuffer, PgArguments}; pub use column::PgColumn; pub use connection::PgConnection; pub use database::Postgres; +pub use done::PgDone; pub use error::{PgDatabaseError, PgErrorPosition}; pub use listener::{PgListener, PgNotification}; pub use message::PgSeverity; diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index ae151f46..d03fdd34 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -104,7 +104,7 @@ where /// Execute the query and return the total number of rows affected. #[inline] - pub async fn execute<'e, 'c: 'e, E>(self, executor: E) -> Result + pub async fn execute<'e, 'c: 'e, E>(self, executor: E) -> Result where 'q: 'e, A: 'e, @@ -115,7 +115,10 @@ where /// Execute multiple queries and return the rows affected from each query, in a stream. #[inline] - pub async fn execute_many<'e, 'c: 'e, E>(self, executor: E) -> BoxStream<'e, Result> + pub async fn execute_many<'e, 'c: 'e, E>( + self, + executor: E, + ) -> BoxStream<'e, Result> where 'q: 'e, A: 'e, @@ -141,7 +144,7 @@ where pub fn fetch_many<'e, 'c: 'e, E>( self, executor: E, - ) -> BoxStream<'e, Result, Error>> + ) -> BoxStream<'e, Result, Error>> where 'q: 'e, A: 'e, @@ -231,7 +234,7 @@ where pub fn fetch_many<'e, 'c: 'e, E>( self, executor: E, - ) -> BoxStream<'e, Result, Error>> + ) -> BoxStream<'e, Result, Error>> where 'q: 'e, E: 'e + Executor<'c, Database = DB>, diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index e70178be..877bf99c 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -74,7 +74,7 @@ where pub fn fetch_many<'e, 'c: 'e, E>( self, executor: E, - ) -> BoxStream<'e, Result, Error>> + ) -> BoxStream<'e, Result, Error>> where 'q: 'e, E: 'e + Executor<'c, Database = DB>, diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index b3ee3d68..370bf441 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -71,7 +71,7 @@ where pub fn fetch_many<'e, 'c: 'e, E>( self, executor: E, - ) -> BoxStream<'e, Result, Error>> + ) -> BoxStream<'e, Result, Error>> where 'q: 'e, E: 'e + Executor<'c, Database = DB>, diff --git a/sqlx-core/src/sqlite/connection/executor.rs b/sqlx-core/src/sqlite/connection/executor.rs index c2e297e2..0ab0c52f 100644 --- a/sqlx-core/src/sqlite/connection/executor.rs +++ b/sqlx-core/src/sqlite/connection/executor.rs @@ -5,6 +5,7 @@ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{FutureExt, TryStreamExt}; use hashbrown::HashMap; +use libsqlite3_sys::sqlite3_last_insert_rowid; use crate::common::StatementCache; use crate::error::Error; @@ -13,7 +14,9 @@ use crate::ext::ustr::UStr; use crate::sqlite::connection::describe::describe; use crate::sqlite::connection::ConnectionHandle; use crate::sqlite::statement::{SqliteStatement, StatementHandle}; -use crate::sqlite::{Sqlite, SqliteArguments, SqliteColumn, SqliteConnection, SqliteRow}; +use crate::sqlite::{ + Sqlite, SqliteArguments, SqliteColumn, SqliteConnection, SqliteDone, SqliteRow, +}; use crate::statement::StatementInfo; fn prepare<'a>( @@ -92,7 +95,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { fn fetch_many<'e, 'q: 'e, E: 'q>( self, mut query: E, - ) -> BoxStream<'e, Result, Error>> + ) -> BoxStream<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, @@ -145,7 +148,16 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { match s { Either::Left(changes) => { - r#yield!(Either::Left(changes)); + let last_insert_rowid = unsafe { + sqlite3_last_insert_rowid(conn.as_ptr()) + }; + + let done = SqliteDone { + changes: changes, + last_insert_rowid: last_insert_rowid, + }; + + r#yield!(Either::Left(done)); break; } diff --git a/sqlx-core/src/sqlite/database.rs b/sqlx-core/src/sqlite/database.rs index de627f89..0630ef2a 100644 --- a/sqlx-core/src/sqlite/database.rs +++ b/sqlx-core/src/sqlite/database.rs @@ -1,6 +1,6 @@ use crate::database::{Database, HasArguments, HasStatementCache, HasValueRef}; use crate::sqlite::{ - SqliteArgumentValue, SqliteArguments, SqliteColumn, SqliteConnection, SqliteRow, + SqliteArgumentValue, SqliteArguments, SqliteColumn, SqliteConnection, SqliteDone, SqliteRow, SqliteTransactionManager, SqliteTypeInfo, SqliteValue, SqliteValueRef, }; @@ -15,6 +15,8 @@ impl Database for Sqlite { type Row = SqliteRow; + type Done = SqliteDone; + type Column = SqliteColumn; type TypeInfo = SqliteTypeInfo; diff --git a/sqlx-core/src/sqlite/done.rs b/sqlx-core/src/sqlite/done.rs new file mode 100644 index 00000000..3b5639f8 --- /dev/null +++ b/sqlx-core/src/sqlite/done.rs @@ -0,0 +1,32 @@ +use crate::done::Done; +use crate::sqlite::Sqlite; +use std::iter::{Extend, IntoIterator}; + +#[derive(Debug, Default)] +pub struct SqliteDone { + pub(super) changes: u64, + pub(super) last_insert_rowid: i64, +} + +impl SqliteDone { + pub fn last_insert_rowid(&self) -> i64 { + self.last_insert_rowid + } +} + +impl Done for SqliteDone { + type Database = Sqlite; + + fn rows_affected(&self) -> u64 { + self.changes + } +} + +impl Extend for SqliteDone { + fn extend>(&mut self, iter: T) { + for elem in iter { + self.changes += elem.changes; + self.last_insert_rowid = elem.last_insert_rowid; + } + } +} diff --git a/sqlx-core/src/sqlite/mod.rs b/sqlx-core/src/sqlite/mod.rs index 35bbcb6b..d3e648c6 100644 --- a/sqlx-core/src/sqlite/mod.rs +++ b/sqlx-core/src/sqlite/mod.rs @@ -9,6 +9,7 @@ mod arguments; mod column; mod connection; mod database; +mod done; mod error; mod options; mod row; @@ -25,6 +26,7 @@ pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; pub use connection::SqliteConnection; pub use database::Sqlite; +pub use done::SqliteDone; pub use error::SqliteError; pub use options::{SqliteConnectOptions, SqliteJournalMode}; pub use row::SqliteRow; diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 99dd9b01..28be36fc 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -108,7 +108,10 @@ macro_rules! impl_executor_for_transaction { query: E, ) -> futures_core::stream::BoxStream< 'e, - Result, crate::error::Error>, + Result< + either::Either<<$DB as crate::database::Database>::Done, $Row>, + crate::error::Error, + >, > where 't: 'e, diff --git a/src/lib.rs b/src/lib.rs index 7885e073..85611592 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub use sqlx_core::arguments::{Arguments, IntoArguments}; pub use sqlx_core::column::Column; pub use sqlx_core::connection::{ConnectOptions, Connection}; pub use sqlx_core::database::{self, Database}; +pub use sqlx_core::done::Done; pub use sqlx_core::executor::{Execute, Executor}; pub use sqlx_core::from_row::FromRow; pub use sqlx_core::pool::{self, Pool}; @@ -113,7 +114,7 @@ pub mod decode { pub use sqlx_macros::Decode; } -/// Return types for the `query` family of functions and macros. +/// Types and traits for the `query` family of functions and macros. pub mod query { pub use sqlx_core::query::{Map, Query}; pub use sqlx_core::query::{MapRow, TryMapRow}; @@ -126,6 +127,7 @@ pub mod prelude { pub use super::Acquire; pub use super::ConnectOptions; pub use super::Connection; + pub use super::Done; pub use super::Executor; pub use super::FromRow; pub use super::IntoArguments; diff --git a/tests/mssql/mssql.rs b/tests/mssql/mssql.rs index 2c286a8d..1a89cb49 100644 --- a/tests/mssql/mssql.rs +++ b/tests/mssql/mssql.rs @@ -1,6 +1,6 @@ use futures::TryStreamExt; use sqlx::mssql::Mssql; -use sqlx::{Connection, Executor, MssqlConnection, Row}; +use sqlx::{Connection, Done, Executor, MssqlConnection, Row}; use sqlx_core::mssql::MssqlRow; use sqlx_test::new; @@ -54,7 +54,7 @@ async fn it_can_fail_to_connect() -> anyhow::Result<()> { async fn it_can_inspect_errors() -> anyhow::Result<()> { let mut conn = new::().await?; - let res: Result = sqlx::query("select f").execute(&mut conn).await; + let res: Result<_, sqlx::Error> = sqlx::query("select f").execute(&mut conn).await; let err = res.unwrap_err(); // can also do [as_database_error] or use `match ..` @@ -93,12 +93,12 @@ CREATE TABLE #users (id INTEGER PRIMARY KEY); .await?; for index in 1..=10_i32 { - let cnt = sqlx::query("INSERT INTO #users (id) VALUES (@p1)") + let done = sqlx::query("INSERT INTO #users (id) VALUES (@p1)") .bind(index * 2) .execute(&mut conn) .await?; - assert_eq!(cnt, 1); + assert_eq!(done.rows_affected(), 1); } let sum: i32 = sqlx::query("SELECT id FROM #users") diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index 8cf73989..8aea5e72 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -1,6 +1,6 @@ use futures::TryStreamExt; use sqlx::mysql::{MySql, MySqlPool, MySqlPoolOptions, MySqlRow}; -use sqlx::{Connection, Executor, Row}; +use sqlx::{Connection, Done, Executor, Row}; use sqlx_test::new; #[sqlx_macros::test] @@ -55,12 +55,12 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY); .await?; for index in 1..=10_i32 { - let cnt = sqlx::query("INSERT INTO users (id) VALUES (?)") + let done = sqlx::query("INSERT INTO users (id) VALUES (?)") .bind(index) .execute(&mut conn) .await?; - assert_eq!(cnt, 1); + assert_eq!(done.rows_affected(), 1); } let sum: i32 = sqlx::query("SELECT id FROM users") @@ -102,12 +102,12 @@ async fn it_drops_results_in_affected_rows() -> anyhow::Result<()> { let mut conn = new::().await?; // ~1800 rows should be iterated and dropped - let affected = conn + let done = conn .execute("select * from mysql.time_zone limit 1575") .await?; // In MySQL, rows being returned isn't enough to flag it as an _affected_ row - assert_eq!(0, affected); + assert_eq!(0, done.rows_affected()); Ok(()) } diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 6311ec86..b48f5f20 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -3,7 +3,7 @@ use sqlx::postgres::{ PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgSeverity, }; use sqlx::postgres::{PgPoolOptions, PgRow}; -use sqlx::{postgres::Postgres, Connection, Executor, Row}; +use sqlx::{postgres::Postgres, Connection, Done, Executor, Row}; use sqlx_test::new; use std::env; use std::thread; @@ -51,7 +51,7 @@ async fn it_maths() -> anyhow::Result<()> { async fn it_can_inspect_errors() -> anyhow::Result<()> { let mut conn = new::().await?; - let res: Result = sqlx::query("select f").execute(&mut conn).await; + let res: Result<_, sqlx::Error> = sqlx::query("select f").execute(&mut conn).await; let err = res.unwrap_err(); // can also do [as_database_error] or use `match ..` @@ -85,12 +85,12 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY); .await?; for index in 1..=10_i32 { - let cnt = sqlx::query("INSERT INTO users (id) VALUES ($1)") + let done = sqlx::query("INSERT INTO users (id) VALUES ($1)") .bind(index) .execute(&mut conn) .await?; - assert_eq!(cnt, 1); + assert_eq!(done.rows_affected(), 1); } let sum: i32 = sqlx::query("SELECT id FROM users") @@ -441,9 +441,9 @@ async fn test_invalid_query() -> anyhow::Result<()> { #[sqlx_macros::test] async fn test_empty_query() -> anyhow::Result<()> { let mut conn = new::().await?; - let affected = conn.execute("").await?; + let done = conn.execute("").await?; - assert_eq!(affected, 0); + assert_eq!(done.rows_affected(), 0); Ok(()) } diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index 605c1850..206b09e4 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -1,8 +1,8 @@ use sqlx::error::DatabaseError; use sqlx::sqlite::{SqliteConnectOptions, SqliteError}; use sqlx::ConnectOptions; +use sqlx::TypeInfo; use sqlx::{sqlite::Sqlite, Column, Executor}; -use sqlx::{SqliteConnection, TypeInfo}; use sqlx_test::new; use std::env; diff --git a/tests/sqlite/sqlite.db b/tests/sqlite/sqlite.db index ac649928..32442999 100644 Binary files a/tests/sqlite/sqlite.db and b/tests/sqlite/sqlite.db differ diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index f38431f8..edb0ba9c 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1,7 +1,7 @@ use futures::TryStreamExt; use sqlx::sqlite::SqlitePoolOptions; use sqlx::{ - query, sqlite::Sqlite, sqlite::SqliteRow, Connection, Executor, Row, SqliteConnection, + query, sqlite::Sqlite, sqlite::SqliteRow, Connection, Done, Executor, Row, SqliteConnection, SqlitePool, }; use sqlx_test::new; @@ -179,9 +179,9 @@ async fn it_fails_to_parse() -> anyhow::Result<()> { #[sqlx_macros::test] async fn it_handles_empty_queries() -> anyhow::Result<()> { let mut conn = new::().await?; - let affected = conn.execute("").await?; + let done = conn.execute("").await?; - assert_eq!(affected, 0); + assert_eq!(done.rows_affected(), 0); Ok(()) } @@ -221,12 +221,12 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY) .await?; for index in 1..=10_i32 { - let cnt = sqlx::query("INSERT INTO users (id) VALUES (?)") + let done = sqlx::query("INSERT INTO users (id) VALUES (?)") .bind(index * 2) .execute(&mut conn) .await?; - assert_eq!(cnt, 1); + assert_eq!(done.rows_affected(), 1); } let sum: i32 = sqlx::query_as("SELECT id FROM users") @@ -243,7 +243,7 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY) async fn it_can_execute_multiple_statements() -> anyhow::Result<()> { let mut conn = new::().await?; - let affected = conn + let done = conn .execute( r#" CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, other INTEGER); @@ -252,7 +252,7 @@ INSERT INTO users DEFAULT VALUES; ) .await?; - assert_eq!(affected, 1); + assert_eq!(done.rows_affected(), 1); for index in 2..5_i32 { let (id, other): (i32, i32) = sqlx::query_as(