mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-03-19 08:39:44 +00:00
breaking: add SqlStr (#3723)
* refactor: introduce `SqlSafeStr` API * rebase main * Add SqlStr + remove Statement lifetime * Update the definition of Executor and AnyConnectionBackend + update Postgres driver * Update MySql driver * Update Sqlite driver * remove debug clone count * Reduce the amount of SqlStr clones * improve QueryBuilder error message * cargo fmt * fix clippy warnings * fix doc test * Avoid panic in `QueryBuilder::reset` * Use `QueryBuilder` when removing all test db's * Add comment to `SqlStr` Co-authored-by: Austin Bonander <austin.bonander@gmail.com> * Update sqlx-core/src/query_builder.rs Co-authored-by: Austin Bonander <austin.bonander@gmail.com> * Add `Clone` as supertrait to `Statement` * Move `Connection`, `AnyConnectionBackend` and `TransactionManager` to `SqlStr` * Replace `sql_cloned` with `sql` in `Statement` * Update `Executor` trait * Update unit tests + QueryBuilder changes * Remove code in comments * Update comment in `QueryBuilder` * Fix clippy warnings * Update `Migrate` comment * Small changes * Move `Migration` to `SqlStr` --------- Co-authored-by: Austin Bonander <austin.bonander@gmail.com>
This commit is contained in:
@@ -5,7 +5,7 @@ use crate::{
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_core::stream::BoxStream;
|
||||
use futures_util::{stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
|
||||
use std::borrow::Cow;
|
||||
use sqlx_core::sql_str::SqlStr;
|
||||
use std::{future, pin::pin};
|
||||
|
||||
use sqlx_core::any::{
|
||||
@@ -40,10 +40,7 @@ impl AnyConnectionBackend for PgConnection {
|
||||
Connection::ping(self).boxed()
|
||||
}
|
||||
|
||||
fn begin(
|
||||
&mut self,
|
||||
statement: Option<Cow<'static, str>>,
|
||||
) -> BoxFuture<'_, sqlx_core::Result<()>> {
|
||||
fn begin(&mut self, statement: Option<SqlStr>) -> BoxFuture<'_, sqlx_core::Result<()>> {
|
||||
PgTransactionManager::begin(self, statement).boxed()
|
||||
}
|
||||
|
||||
@@ -84,7 +81,7 @@ impl AnyConnectionBackend for PgConnection {
|
||||
|
||||
fn fetch_many<'q>(
|
||||
&'q mut self,
|
||||
query: &'q str,
|
||||
query: SqlStr,
|
||||
persistent: bool,
|
||||
arguments: Option<AnyArguments<'q>>,
|
||||
) -> BoxStream<'q, sqlx_core::Result<Either<AnyQueryResult, AnyRow>>> {
|
||||
@@ -110,7 +107,7 @@ impl AnyConnectionBackend for PgConnection {
|
||||
|
||||
fn fetch_optional<'q>(
|
||||
&'q mut self,
|
||||
query: &'q str,
|
||||
query: SqlStr,
|
||||
persistent: bool,
|
||||
arguments: Option<AnyArguments<'q>>,
|
||||
) -> BoxFuture<'q, sqlx_core::Result<Option<AnyRow>>> {
|
||||
@@ -135,20 +132,17 @@ impl AnyConnectionBackend for PgConnection {
|
||||
|
||||
fn prepare_with<'c, 'q: 'c>(
|
||||
&'c mut self,
|
||||
sql: &'q str,
|
||||
sql: SqlStr,
|
||||
_parameters: &[AnyTypeInfo],
|
||||
) -> BoxFuture<'c, sqlx_core::Result<AnyStatement<'q>>> {
|
||||
) -> BoxFuture<'c, sqlx_core::Result<AnyStatement>> {
|
||||
Box::pin(async move {
|
||||
let statement = Executor::prepare_with(self, sql, &[]).await?;
|
||||
AnyStatement::try_from_statement(
|
||||
sql,
|
||||
&statement,
|
||||
statement.metadata.column_names.clone(),
|
||||
)
|
||||
let colunn_names = statement.metadata.column_names.clone();
|
||||
AnyStatement::try_from_statement(statement, colunn_names)
|
||||
})
|
||||
}
|
||||
|
||||
fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result<Describe<Any>>> {
|
||||
fn describe<'c>(&mut self, sql: SqlStr) -> BoxFuture<'_, sqlx_core::Result<Describe<Any>>> {
|
||||
Box::pin(async move {
|
||||
let describe = Executor::describe(self, sql).await?;
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ use crate::{PgColumn, PgConnection, PgTypeInfo};
|
||||
use smallvec::SmallVec;
|
||||
use sqlx_core::column::{ColumnOrigin, TableColumn};
|
||||
use sqlx_core::query_builder::QueryBuilder;
|
||||
use sqlx_core::sql_str::AssertSqlSafe;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Describes the type of the `pg_type.typtype` column
|
||||
@@ -619,7 +620,7 @@ WHERE rngtypid = $1
|
||||
}
|
||||
|
||||
let (Json(explains),): (Json<SmallVec<[Explain; 1]>>,) =
|
||||
query_as(&explain).fetch_one(self).await?;
|
||||
query_as(AssertSqlSafe(explain)).fetch_one(self).await?;
|
||||
|
||||
let mut nullables = Vec::new();
|
||||
|
||||
|
||||
@@ -17,8 +17,9 @@ use futures_core::stream::BoxStream;
|
||||
use futures_core::Stream;
|
||||
use futures_util::TryStreamExt;
|
||||
use sqlx_core::arguments::Arguments;
|
||||
use sqlx_core::sql_str::SqlStr;
|
||||
use sqlx_core::Either;
|
||||
use std::{borrow::Cow, pin::pin, sync::Arc};
|
||||
use std::{pin::pin, sync::Arc};
|
||||
|
||||
async fn prepare(
|
||||
conn: &mut PgConnection,
|
||||
@@ -209,13 +210,11 @@ impl PgConnection {
|
||||
|
||||
pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
|
||||
&'c mut self,
|
||||
query: &'q str,
|
||||
query: SqlStr,
|
||||
arguments: Option<PgArguments>,
|
||||
persistent: bool,
|
||||
metadata_opt: Option<Arc<PgStatementMetadata>>,
|
||||
) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
|
||||
let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
|
||||
|
||||
// before we continue, wait until we are "ready" to accept more queries
|
||||
self.wait_until_ready().await?;
|
||||
|
||||
@@ -238,7 +237,13 @@ impl PgConnection {
|
||||
// prepare the statement if this our first time executing it
|
||||
// always return the statement ID here
|
||||
let (statement, metadata_) = self
|
||||
.get_or_prepare(query, &arguments.types, persistent, metadata_opt, false)
|
||||
.get_or_prepare(
|
||||
query.as_str(),
|
||||
&arguments.types,
|
||||
persistent,
|
||||
metadata_opt,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
|
||||
metadata = metadata_;
|
||||
@@ -291,7 +296,7 @@ impl PgConnection {
|
||||
PgValueFormat::Binary
|
||||
} else {
|
||||
// Query will trigger a ReadyForQuery
|
||||
self.inner.stream.write_msg(Query(query))?;
|
||||
self.inner.stream.write_msg(Query(query.as_str()))?;
|
||||
self.inner.pending_ready_for_query_count += 1;
|
||||
|
||||
// metadata starts out as "nothing"
|
||||
@@ -302,6 +307,7 @@ impl PgConnection {
|
||||
};
|
||||
|
||||
self.inner.stream.flush().await?;
|
||||
let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
|
||||
|
||||
Ok(try_stream! {
|
||||
loop {
|
||||
@@ -402,12 +408,12 @@ impl<'c> Executor<'c> for &'c mut PgConnection {
|
||||
'q: 'e,
|
||||
E: 'q,
|
||||
{
|
||||
let sql = query.sql();
|
||||
// False positive: https://github.com/rust-lang/rust-clippy/issues/12560
|
||||
#[allow(clippy::map_clone)]
|
||||
let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
|
||||
let arguments = query.take_arguments().map_err(Error::Encode);
|
||||
let persistent = query.persistent();
|
||||
let sql = query.sql();
|
||||
|
||||
Box::pin(try_stream! {
|
||||
let arguments = arguments?;
|
||||
@@ -428,7 +434,6 @@ impl<'c> Executor<'c> for &'c mut PgConnection {
|
||||
'q: 'e,
|
||||
E: 'q,
|
||||
{
|
||||
let sql = query.sql();
|
||||
// False positive: https://github.com/rust-lang/rust-clippy/issues/12560
|
||||
#[allow(clippy::map_clone)]
|
||||
let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
|
||||
@@ -436,6 +441,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection {
|
||||
let persistent = query.persistent();
|
||||
|
||||
Box::pin(async move {
|
||||
let sql = query.sql();
|
||||
let arguments = arguments?;
|
||||
let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
|
||||
|
||||
@@ -454,11 +460,11 @@ impl<'c> Executor<'c> for &'c mut PgConnection {
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_with<'e, 'q: 'e>(
|
||||
fn prepare_with<'e>(
|
||||
self,
|
||||
sql: &'q str,
|
||||
sql: SqlStr,
|
||||
parameters: &'e [PgTypeInfo],
|
||||
) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
|
||||
) -> BoxFuture<'e, Result<PgStatement, Error>>
|
||||
where
|
||||
'c: 'e,
|
||||
{
|
||||
@@ -466,27 +472,23 @@ impl<'c> Executor<'c> for &'c mut PgConnection {
|
||||
self.wait_until_ready().await?;
|
||||
|
||||
let (_, metadata) = self
|
||||
.get_or_prepare(sql, parameters, true, None, true)
|
||||
.get_or_prepare(sql.as_str(), parameters, true, None, true)
|
||||
.await?;
|
||||
|
||||
Ok(PgStatement {
|
||||
sql: Cow::Borrowed(sql),
|
||||
metadata,
|
||||
})
|
||||
Ok(PgStatement { sql, metadata })
|
||||
})
|
||||
}
|
||||
|
||||
fn describe<'e, 'q: 'e>(
|
||||
self,
|
||||
sql: &'q str,
|
||||
) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
|
||||
fn describe<'e>(self, sql: SqlStr) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
|
||||
where
|
||||
'c: 'e,
|
||||
{
|
||||
Box::pin(async move {
|
||||
self.wait_until_ready().await?;
|
||||
|
||||
let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None, true).await?;
|
||||
let (stmt_id, metadata) = self
|
||||
.get_or_prepare(sql.as_str(), &[], true, None, true)
|
||||
.await?;
|
||||
|
||||
let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
use std::future::Future;
|
||||
@@ -20,6 +19,7 @@ use crate::types::Oid;
|
||||
use crate::{PgConnectOptions, PgTypeInfo, Postgres};
|
||||
|
||||
pub(crate) use sqlx_core::connection::*;
|
||||
use sqlx_core::sql_str::SqlSafeStr;
|
||||
|
||||
pub use self::stream::PgStream;
|
||||
|
||||
@@ -193,12 +193,12 @@ impl Connection for PgConnection {
|
||||
|
||||
fn begin_with(
|
||||
&mut self,
|
||||
statement: impl Into<Cow<'static, str>>,
|
||||
statement: impl SqlSafeStr,
|
||||
) -> impl Future<Output = Result<Transaction<'_, Self::Database>, Error>> + Send + '_
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Transaction::begin(self, Some(statement.into()))
|
||||
Transaction::begin(self, Some(statement.into_sql_str()))
|
||||
}
|
||||
|
||||
fn cached_statements_size(&self) -> usize {
|
||||
|
||||
@@ -30,7 +30,7 @@ impl Database for Postgres {
|
||||
type Arguments<'q> = PgArguments;
|
||||
type ArgumentBuffer<'q> = PgArgumentBuffer;
|
||||
|
||||
type Statement<'q> = PgStatement<'q>;
|
||||
type Statement = PgStatement;
|
||||
|
||||
const NAME: &'static str = "PostgreSQL";
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ use futures_core::future::BoxFuture;
|
||||
use futures_core::stream::{BoxStream, Stream};
|
||||
use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
|
||||
use sqlx_core::acquire::Acquire;
|
||||
use sqlx_core::sql_str::{AssertSqlSafe, SqlStr};
|
||||
use sqlx_core::transaction::Transaction;
|
||||
use sqlx_core::Either;
|
||||
use tracing::Instrument;
|
||||
@@ -116,7 +117,7 @@ impl PgListener {
|
||||
pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
|
||||
self.connection()
|
||||
.await?
|
||||
.execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
|
||||
.execute(AssertSqlSafe(format!(r#"LISTEN "{}""#, ident(channel))))
|
||||
.await?;
|
||||
|
||||
self.channels.push(channel.to_owned());
|
||||
@@ -133,7 +134,10 @@ impl PgListener {
|
||||
self.channels.extend(channels.into_iter().map(|s| s.into()));
|
||||
|
||||
let query = build_listen_all_query(&self.channels[beg..]);
|
||||
self.connection().await?.execute(&*query).await?;
|
||||
self.connection()
|
||||
.await?
|
||||
.execute(AssertSqlSafe(query))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -145,7 +149,7 @@ impl PgListener {
|
||||
// UNLISTEN (we've disconnected anyways)
|
||||
if let Some(connection) = self.connection.as_mut() {
|
||||
connection
|
||||
.execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
|
||||
.execute(AssertSqlSafe(format!(r#"UNLISTEN "{}""#, ident(channel))))
|
||||
.await?;
|
||||
}
|
||||
|
||||
@@ -176,7 +180,7 @@ impl PgListener {
|
||||
connection.inner.stream.notifications = self.buffer_tx.take();
|
||||
|
||||
connection
|
||||
.execute(&*build_listen_all_query(&self.channels))
|
||||
.execute(AssertSqlSafe(build_listen_all_query(&self.channels)))
|
||||
.await?;
|
||||
|
||||
self.connection = Some(connection);
|
||||
@@ -417,11 +421,11 @@ impl<'c> Executor<'c> for &'c mut PgListener {
|
||||
async move { self.connection().await?.fetch_optional(query).await }.boxed()
|
||||
}
|
||||
|
||||
fn prepare_with<'e, 'q: 'e>(
|
||||
fn prepare_with<'e>(
|
||||
self,
|
||||
query: &'q str,
|
||||
query: SqlStr,
|
||||
parameters: &'e [PgTypeInfo],
|
||||
) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
|
||||
) -> BoxFuture<'e, Result<PgStatement, Error>>
|
||||
where
|
||||
'c: 'e,
|
||||
{
|
||||
@@ -435,10 +439,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn describe<'e, 'q: 'e>(
|
||||
self,
|
||||
query: &'q str,
|
||||
) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
|
||||
fn describe<'e>(self, query: SqlStr) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
|
||||
where
|
||||
'c: 'e,
|
||||
{
|
||||
|
||||
@@ -7,6 +7,7 @@ use futures_core::future::BoxFuture;
|
||||
pub(crate) use sqlx_core::migrate::MigrateError;
|
||||
pub(crate) use sqlx_core::migrate::{AppliedMigration, Migration};
|
||||
pub(crate) use sqlx_core::migrate::{Migrate, MigrateDatabase};
|
||||
use sqlx_core::sql_str::AssertSqlSafe;
|
||||
|
||||
use crate::connection::{ConnectOptions, Connection};
|
||||
use crate::error::Error;
|
||||
@@ -44,10 +45,10 @@ impl MigrateDatabase for Postgres {
|
||||
let mut conn = options.connect().await?;
|
||||
|
||||
let _ = conn
|
||||
.execute(&*format!(
|
||||
.execute(AssertSqlSafe(format!(
|
||||
"CREATE DATABASE \"{}\"",
|
||||
database.replace('"', "\"\"")
|
||||
))
|
||||
)))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
@@ -71,10 +72,10 @@ impl MigrateDatabase for Postgres {
|
||||
let mut conn = options.connect().await?;
|
||||
|
||||
let _ = conn
|
||||
.execute(&*format!(
|
||||
.execute(AssertSqlSafe(format!(
|
||||
"DROP DATABASE IF EXISTS \"{}\"",
|
||||
database.replace('"', "\"\"")
|
||||
))
|
||||
)))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
@@ -92,10 +93,10 @@ impl MigrateDatabase for Postgres {
|
||||
|
||||
let pid_type = if version >= 90200 { "pid" } else { "procpid" };
|
||||
|
||||
conn.execute(&*format!(
|
||||
conn.execute(AssertSqlSafe(format!(
|
||||
"SELECT pg_terminate_backend(pg_stat_activity.{pid_type}) FROM pg_stat_activity \
|
||||
WHERE pg_stat_activity.datname = '{database}' AND {pid_type} <> pg_backend_pid()"
|
||||
))
|
||||
)))
|
||||
.await?;
|
||||
|
||||
Self::drop_database(url).await
|
||||
@@ -109,8 +110,10 @@ impl Migrate for PgConnection {
|
||||
) -> BoxFuture<'e, Result<(), MigrateError>> {
|
||||
Box::pin(async move {
|
||||
// language=SQL
|
||||
self.execute(&*format!(r#"CREATE SCHEMA IF NOT EXISTS {schema_name};"#))
|
||||
.await?;
|
||||
self.execute(AssertSqlSafe(format!(
|
||||
r#"CREATE SCHEMA IF NOT EXISTS {schema_name};"#
|
||||
)))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
@@ -122,7 +125,7 @@ impl Migrate for PgConnection {
|
||||
) -> BoxFuture<'e, Result<(), MigrateError>> {
|
||||
Box::pin(async move {
|
||||
// language=SQL
|
||||
self.execute(&*format!(
|
||||
self.execute(AssertSqlSafe(format!(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
version BIGINT PRIMARY KEY,
|
||||
@@ -133,7 +136,7 @@ CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
execution_time BIGINT NOT NULL
|
||||
);
|
||||
"#
|
||||
))
|
||||
)))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
@@ -146,9 +149,9 @@ CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
) -> BoxFuture<'e, Result<Option<i64>, MigrateError>> {
|
||||
Box::pin(async move {
|
||||
// language=SQL
|
||||
let row: Option<(i64,)> = query_as(&format!(
|
||||
let row: Option<(i64,)> = query_as(AssertSqlSafe(format!(
|
||||
"SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1"
|
||||
))
|
||||
)))
|
||||
.fetch_optional(self)
|
||||
.await?;
|
||||
|
||||
@@ -162,9 +165,9 @@ CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
) -> BoxFuture<'e, Result<Vec<AppliedMigration>, MigrateError>> {
|
||||
Box::pin(async move {
|
||||
// language=SQL
|
||||
let rows: Vec<(i64, Vec<u8>)> = query_as(&format!(
|
||||
let rows: Vec<(i64, Vec<u8>)> = query_as(AssertSqlSafe(format!(
|
||||
"SELECT version, checksum FROM {table_name} ORDER BY version"
|
||||
))
|
||||
)))
|
||||
.fetch_all(self)
|
||||
.await?;
|
||||
|
||||
@@ -245,13 +248,13 @@ CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
|
||||
// language=SQL
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let _ = query(&format!(
|
||||
let _ = query(AssertSqlSafe(format!(
|
||||
r#"
|
||||
UPDATE {table_name}
|
||||
SET execution_time = $1
|
||||
WHERE version = $2
|
||||
"#
|
||||
))
|
||||
)))
|
||||
.bind(elapsed.as_nanos() as i64)
|
||||
.bind(migration.version)
|
||||
.execute(self)
|
||||
@@ -293,17 +296,17 @@ async fn execute_migration(
|
||||
migration: &Migration,
|
||||
) -> Result<(), MigrateError> {
|
||||
let _ = conn
|
||||
.execute(&*migration.sql)
|
||||
.execute(migration.sql.clone())
|
||||
.await
|
||||
.map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?;
|
||||
|
||||
// language=SQL
|
||||
let _ = query(&format!(
|
||||
let _ = query(AssertSqlSafe(format!(
|
||||
r#"
|
||||
INSERT INTO {table_name} ( version, description, success, checksum, execution_time )
|
||||
VALUES ( $1, $2, TRUE, $3, -1 )
|
||||
"#
|
||||
))
|
||||
)))
|
||||
.bind(migration.version)
|
||||
.bind(&*migration.description)
|
||||
.bind(&*migration.checksum)
|
||||
@@ -319,15 +322,17 @@ async fn revert_migration(
|
||||
migration: &Migration,
|
||||
) -> Result<(), MigrateError> {
|
||||
let _ = conn
|
||||
.execute(&*migration.sql)
|
||||
.execute(migration.sql.clone())
|
||||
.await
|
||||
.map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?;
|
||||
|
||||
// language=SQL
|
||||
let _ = query(&format!(r#"DELETE FROM {table_name} WHERE version = $1"#))
|
||||
.bind(migration.version)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
let _ = query(AssertSqlSafe(format!(
|
||||
r#"DELETE FROM {table_name} WHERE version = $1"#
|
||||
)))
|
||||
.bind(migration.version)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -3,15 +3,15 @@ use crate::column::ColumnIndex;
|
||||
use crate::error::Error;
|
||||
use crate::ext::ustr::UStr;
|
||||
use crate::{PgArguments, Postgres};
|
||||
use std::borrow::Cow;
|
||||
use std::sync::Arc;
|
||||
|
||||
use sqlx_core::sql_str::SqlStr;
|
||||
pub(crate) use sqlx_core::statement::Statement;
|
||||
use sqlx_core::{Either, HashMap};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PgStatement<'q> {
|
||||
pub(crate) sql: Cow<'q, str>,
|
||||
pub struct PgStatement {
|
||||
pub(crate) sql: SqlStr,
|
||||
pub(crate) metadata: Arc<PgStatementMetadata>,
|
||||
}
|
||||
|
||||
@@ -24,17 +24,14 @@ pub(crate) struct PgStatementMetadata {
|
||||
pub(crate) parameters: Vec<PgTypeInfo>,
|
||||
}
|
||||
|
||||
impl<'q> Statement<'q> for PgStatement<'q> {
|
||||
impl Statement for PgStatement {
|
||||
type Database = Postgres;
|
||||
|
||||
fn to_owned(&self) -> PgStatement<'static> {
|
||||
PgStatement::<'static> {
|
||||
sql: Cow::Owned(self.sql.clone().into_owned()),
|
||||
metadata: self.metadata.clone(),
|
||||
}
|
||||
fn into_sql(self) -> SqlStr {
|
||||
self.sql
|
||||
}
|
||||
|
||||
fn sql(&self) -> &str {
|
||||
fn sql(&self) -> &SqlStr {
|
||||
&self.sql
|
||||
}
|
||||
|
||||
@@ -49,8 +46,8 @@ impl<'q> Statement<'q> for PgStatement<'q> {
|
||||
impl_statement_query!(PgArguments);
|
||||
}
|
||||
|
||||
impl ColumnIndex<PgStatement<'_>> for &'_ str {
|
||||
fn index(&self, statement: &PgStatement<'_>) -> Result<usize, Error> {
|
||||
impl ColumnIndex<PgStatement> for &'_ str {
|
||||
fn index(&self, statement: &PgStatement) -> Result<usize, Error> {
|
||||
statement
|
||||
.metadata
|
||||
.column_names
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::fmt::Write;
|
||||
use std::future::Future;
|
||||
use std::ops::Deref;
|
||||
use std::str::FromStr;
|
||||
@@ -6,7 +5,9 @@ use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
|
||||
use sqlx_core::connection::Connection;
|
||||
use sqlx_core::query_builder::QueryBuilder;
|
||||
use sqlx_core::query_scalar::query_scalar;
|
||||
use sqlx_core::sql_str::AssertSqlSafe;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::executor::Executor;
|
||||
@@ -52,12 +53,12 @@ impl TestSupport for Postgres {
|
||||
|
||||
let mut deleted_db_names = Vec::with_capacity(delete_db_names.len());
|
||||
|
||||
let mut command = String::new();
|
||||
let mut builder = QueryBuilder::new("drop database if exists ");
|
||||
|
||||
for db_name in &delete_db_names {
|
||||
command.clear();
|
||||
writeln!(command, "drop database if exists {db_name:?};").ok();
|
||||
match conn.execute(&*command).await {
|
||||
builder.push(db_name);
|
||||
|
||||
match builder.build().execute(&mut conn).await {
|
||||
Ok(_deleted) => {
|
||||
deleted_db_names.push(db_name);
|
||||
}
|
||||
@@ -68,6 +69,8 @@ impl TestSupport for Postgres {
|
||||
// Bubble up other errors
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
|
||||
builder.reset();
|
||||
}
|
||||
|
||||
query("delete from _sqlx_test.databases where db_name = any($1::text[])")
|
||||
@@ -163,7 +166,7 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
|
||||
|
||||
let create_command = format!("create database {db_name:?}");
|
||||
debug_assert!(create_command.starts_with("create database \""));
|
||||
conn.execute(&(create_command)[..]).await?;
|
||||
conn.execute(AssertSqlSafe(create_command)).await?;
|
||||
|
||||
Ok(TestContext {
|
||||
pool_opts: PoolOptions::new()
|
||||
@@ -185,7 +188,7 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
|
||||
|
||||
async fn do_cleanup(conn: &mut PgConnection, db_name: &str) -> Result<(), Error> {
|
||||
let delete_db_command = format!("drop database if exists {db_name:?};");
|
||||
conn.execute(&*delete_db_command).await?;
|
||||
conn.execute(AssertSqlSafe(delete_db_command)).await?;
|
||||
query("delete from _sqlx_test.databases where db_name = $1::text")
|
||||
.bind(db_name)
|
||||
.execute(&mut *conn)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use sqlx_core::database::Database;
|
||||
use std::borrow::Cow;
|
||||
use sqlx_core::sql_str::SqlStr;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::executor::Executor;
|
||||
@@ -14,11 +14,9 @@ pub struct PgTransactionManager;
|
||||
impl TransactionManager for PgTransactionManager {
|
||||
type Database = Postgres;
|
||||
|
||||
async fn begin(
|
||||
conn: &mut PgConnection,
|
||||
statement: Option<Cow<'static, str>>,
|
||||
) -> Result<(), Error> {
|
||||
async fn begin(conn: &mut PgConnection, statement: Option<SqlStr>) -> Result<(), Error> {
|
||||
let depth = conn.inner.transaction_depth;
|
||||
|
||||
let statement = match statement {
|
||||
// custom `BEGIN` statements are not allowed if we're already in
|
||||
// a transaction (we need to issue a `SAVEPOINT` instead)
|
||||
@@ -28,7 +26,7 @@ impl TransactionManager for PgTransactionManager {
|
||||
};
|
||||
|
||||
let rollback = Rollback::new(conn);
|
||||
rollback.conn.queue_simple_query(&statement)?;
|
||||
rollback.conn.queue_simple_query(statement.as_str())?;
|
||||
rollback.conn.wait_until_ready().await?;
|
||||
if !rollback.conn.in_transaction() {
|
||||
return Err(Error::BeginFailed);
|
||||
@@ -41,7 +39,7 @@ impl TransactionManager for PgTransactionManager {
|
||||
|
||||
async fn commit(conn: &mut PgConnection) -> Result<(), Error> {
|
||||
if conn.inner.transaction_depth > 0 {
|
||||
conn.execute(&*commit_ansi_transaction_sql(conn.inner.transaction_depth))
|
||||
conn.execute(commit_ansi_transaction_sql(conn.inner.transaction_depth))
|
||||
.await?;
|
||||
|
||||
conn.inner.transaction_depth -= 1;
|
||||
@@ -52,10 +50,8 @@ impl TransactionManager for PgTransactionManager {
|
||||
|
||||
async fn rollback(conn: &mut PgConnection) -> Result<(), Error> {
|
||||
if conn.inner.transaction_depth > 0 {
|
||||
conn.execute(&*rollback_ansi_transaction_sql(
|
||||
conn.inner.transaction_depth,
|
||||
))
|
||||
.await?;
|
||||
conn.execute(rollback_ansi_transaction_sql(conn.inner.transaction_depth))
|
||||
.await?;
|
||||
|
||||
conn.inner.transaction_depth -= 1;
|
||||
}
|
||||
@@ -65,8 +61,10 @@ impl TransactionManager for PgTransactionManager {
|
||||
|
||||
fn start_rollback(conn: &mut PgConnection) {
|
||||
if conn.inner.transaction_depth > 0 {
|
||||
conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.inner.transaction_depth))
|
||||
.expect("BUG: Rollback query somehow too large for protocol");
|
||||
conn.queue_simple_query(
|
||||
rollback_ansi_transaction_sql(conn.inner.transaction_depth).as_str(),
|
||||
)
|
||||
.expect("BUG: Rollback query somehow too large for protocol");
|
||||
|
||||
conn.inner.transaction_depth -= 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user