chore: remove BoxFuture's (non-breaking) (#3629)

* chore: reduce BoxFuture's when using recursion.

* remove BoxFuture's in WithSocket

* chore: better document previous changes
This commit is contained in:
joeydewaal 2024-12-12 21:43:22 +01:00 committed by GitHub
parent 42ce24dab8
commit 1f6ce33df4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 123 additions and 133 deletions

12
Cargo.lock generated
View File

@ -1177,7 +1177,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.52", "syn 2.0.87",
] ]
[[package]] [[package]]
@ -1914,7 +1914,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.52", "syn 2.0.87",
] ]
[[package]] [[package]]
@ -3986,7 +3986,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.52", "syn 2.0.87",
] ]
[[package]] [[package]]
@ -4815,7 +4815,7 @@ checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.52", "syn 2.0.87",
"synstructure", "synstructure",
] ]
@ -4856,7 +4856,7 @@ checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.52", "syn 2.0.87",
"synstructure", "synstructure",
] ]
@ -4899,5 +4899,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.52", "syn 2.0.87",
] ]

View File

@ -143,7 +143,10 @@ where
pub trait WithSocket { pub trait WithSocket {
type Output; type Output;
fn with_socket<S: Socket>(self, socket: S) -> Self::Output; fn with_socket<S: Socket>(
self,
socket: S,
) -> impl std::future::Future<Output = Self::Output> + Send;
} }
pub struct SocketIntoBox; pub struct SocketIntoBox;
@ -151,7 +154,7 @@ pub struct SocketIntoBox;
impl WithSocket for SocketIntoBox { impl WithSocket for SocketIntoBox {
type Output = Box<dyn Socket>; type Output = Box<dyn Socket>;
fn with_socket<S: Socket>(self, socket: S) -> Self::Output { async fn with_socket<S: Socket>(self, socket: S) -> Self::Output {
Box::new(socket) Box::new(socket)
} }
} }
@ -197,7 +200,7 @@ pub async fn connect_tcp<Ws: WithSocket>(
let stream = TcpStream::connect((host, port)).await?; let stream = TcpStream::connect((host, port)).await?;
stream.set_nodelay(true)?; stream.set_nodelay(true)?;
return Ok(with_socket.with_socket(stream)); return Ok(with_socket.with_socket(stream).await);
} }
#[cfg(feature = "_rt-async-std")] #[cfg(feature = "_rt-async-std")]
@ -217,7 +220,7 @@ pub async fn connect_tcp<Ws: WithSocket>(
Ok(s) Ok(s)
}); });
match stream { match stream {
Ok(stream) => return Ok(with_socket.with_socket(stream)), Ok(stream) => return Ok(with_socket.with_socket(stream).await),
Err(e) => last_err = Some(e), Err(e) => last_err = Some(e),
} }
} }
@ -255,7 +258,7 @@ pub async fn connect_uds<P: AsRef<Path>, Ws: WithSocket>(
let stream = UnixStream::connect(path).await?; let stream = UnixStream::connect(path).await?;
return Ok(with_socket.with_socket(stream)); return Ok(with_socket.with_socket(stream).await);
} }
#[cfg(feature = "_rt-async-std")] #[cfg(feature = "_rt-async-std")]
@ -265,7 +268,7 @@ pub async fn connect_uds<P: AsRef<Path>, Ws: WithSocket>(
let stream = Async::<UnixStream>::connect(path).await?; let stream = Async::<UnixStream>::connect(path).await?;
Ok(with_socket.with_socket(stream)) Ok(with_socket.with_socket(stream).await)
} }
#[cfg(not(feature = "_rt-async-std"))] #[cfg(not(feature = "_rt-async-std"))]

View File

@ -75,10 +75,14 @@ where
Ws: WithSocket, Ws: WithSocket,
{ {
#[cfg(feature = "_tls-native-tls")] #[cfg(feature = "_tls-native-tls")]
return Ok(with_socket.with_socket(tls_native_tls::handshake(socket, config).await?)); return Ok(with_socket
.with_socket(tls_native_tls::handshake(socket, config).await?)
.await);
#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))] #[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
return Ok(with_socket.with_socket(tls_rustls::handshake(socket, config).await?)); return Ok(with_socket
.with_socket(tls_rustls::handshake(socket, config).await?)
.await);
#[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))] #[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))]
{ {

View File

@ -1,6 +1,5 @@
use bytes::buf::Buf; use bytes::buf::Buf;
use bytes::Bytes; use bytes::Bytes;
use futures_core::future::BoxFuture;
use crate::collation::{CharSet, Collation}; use crate::collation::{CharSet, Collation};
use crate::common::StatementCache; use crate::common::StatementCache;
@ -22,7 +21,7 @@ impl MySqlConnection {
None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?, None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?,
}; };
let stream = handshake.await?; let stream = handshake?;
Ok(Self { Ok(Self {
inner: Box::new(MySqlConnectionInner { inner: Box::new(MySqlConnectionInner {
@ -187,9 +186,9 @@ impl<'a> DoHandshake<'a> {
} }
impl<'a> WithSocket for DoHandshake<'a> { impl<'a> WithSocket for DoHandshake<'a> {
type Output = BoxFuture<'a, Result<MySqlStream, Error>>; type Output = Result<MySqlStream, Error>;
fn with_socket<S: Socket>(self, socket: S) -> Self::Output { async fn with_socket<S: Socket>(self, socket: S) -> Self::Output {
Box::pin(self.do_handshake(socket)) self.do_handshake(socket).await
} }
} }

View File

@ -94,7 +94,7 @@ pub(super) async fn maybe_upgrade<S: Socket>(
impl WithSocket for MapStream { impl WithSocket for MapStream {
type Output = MySqlStream; type Output = MySqlStream;
fn with_socket<S: Socket>(self, socket: S) -> Self::Output { async fn with_socket<S: Socket>(self, socket: S) -> Self::Output {
MySqlStream { MySqlStream {
socket: BufferedSocket::new(Box::new(socket)), socket: BufferedSocket::new(Box::new(socket)),
server_version: self.server_version, server_version: self.server_version,

View File

@ -10,7 +10,6 @@ use crate::types::Json;
use crate::types::Oid; use crate::types::Oid;
use crate::HashMap; use crate::HashMap;
use crate::{PgColumn, PgConnection, PgTypeInfo}; use crate::{PgColumn, PgConnection, PgTypeInfo};
use futures_core::future::BoxFuture;
use smallvec::SmallVec; use smallvec::SmallVec;
use sqlx_core::query_builder::QueryBuilder; use sqlx_core::query_builder::QueryBuilder;
use std::sync::Arc; use std::sync::Arc;
@ -169,7 +168,8 @@ impl PgConnection {
// fallback to asking the database directly for a type name // fallback to asking the database directly for a type name
if should_fetch { if should_fetch {
let info = self.fetch_type_by_oid(oid).await?; // we're boxing this future here so we can use async recursion
let info = Box::pin(async { self.fetch_type_by_oid(oid).await }).await?;
// cache the type name <-> oid relationship in a paired hashmap // cache the type name <-> oid relationship in a paired hashmap
// so we don't come down this road again // so we don't come down this road again
@ -190,19 +190,18 @@ impl PgConnection {
} }
} }
fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result<PgTypeInfo, Error>> { async fn fetch_type_by_oid(&mut self, oid: Oid) -> Result<PgTypeInfo, Error> {
Box::pin(async move { let (name, typ_type, category, relation_id, element, base_type): (
let (name, typ_type, category, relation_id, element, base_type): ( String,
String, i8,
i8, i8,
i8, Oid,
Oid, Oid,
Oid, Oid,
Oid, ) = query_as(
) = query_as( // Converting the OID to `regtype` and then `text` will give us the name that
// Converting the OID to `regtype` and then `text` will give us the name that // the type will need to be found at by search_path.
// the type will need to be found at by search_path. "SELECT oid::regtype::text, \
"SELECT oid::regtype::text, \
typtype, \ typtype, \
typcategory, \ typcategory, \
typrelid, \ typrelid, \
@ -210,54 +209,51 @@ impl PgConnection {
typbasetype \ typbasetype \
FROM pg_catalog.pg_type \ FROM pg_catalog.pg_type \
WHERE oid = $1", WHERE oid = $1",
) )
.bind(oid) .bind(oid)
.fetch_one(&mut *self) .fetch_one(&mut *self)
.await?; .await?;
let typ_type = TypType::try_from(typ_type); let typ_type = TypType::try_from(typ_type);
let category = TypCategory::try_from(category); let category = TypCategory::try_from(category);
match (typ_type, category) { match (typ_type, category) {
(Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await, (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
(Ok(TypType::Base), Ok(TypCategory::Array)) => { (Ok(TypType::Base), Ok(TypCategory::Array)) => {
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Array( kind: PgTypeKind::Array(
self.maybe_fetch_type_info_by_oid(element, true).await?, self.maybe_fetch_type_info_by_oid(element, true).await?,
), ),
name: name.into(),
oid,
}))))
}
(Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => {
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Pseudo,
name: name.into(),
oid,
}))))
}
(Ok(TypType::Range), Ok(TypCategory::Range)) => {
self.fetch_range_by_oid(oid, name).await
}
(Ok(TypType::Enum), Ok(TypCategory::Enum)) => {
self.fetch_enum_by_oid(oid, name).await
}
(Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
self.fetch_composite_by_oid(oid, relation_id, name).await
}
_ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Simple,
name: name.into(), name: name.into(),
oid, oid,
})))), }))))
} }
})
(Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => {
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Pseudo,
name: name.into(),
oid,
}))))
}
(Ok(TypType::Range), Ok(TypCategory::Range)) => {
self.fetch_range_by_oid(oid, name).await
}
(Ok(TypType::Enum), Ok(TypCategory::Enum)) => self.fetch_enum_by_oid(oid, name).await,
(Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
self.fetch_composite_by_oid(oid, relation_id, name).await
}
_ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Simple,
name: name.into(),
oid,
})))),
}
} }
async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> { async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
@ -280,15 +276,14 @@ ORDER BY enumsortorder
})))) }))))
} }
fn fetch_composite_by_oid( async fn fetch_composite_by_oid(
&mut self, &mut self,
oid: Oid, oid: Oid,
relation_id: Oid, relation_id: Oid,
name: String, name: String,
) -> BoxFuture<'_, Result<PgTypeInfo, Error>> { ) -> Result<PgTypeInfo, Error> {
Box::pin(async move { let raw_fields: Vec<(String, Oid)> = query_as(
let raw_fields: Vec<(String, Oid)> = query_as( r#"
r#"
SELECT attname, atttypid SELECT attname, atttypid
FROM pg_catalog.pg_attribute FROM pg_catalog.pg_attribute
WHERE attrelid = $1 WHERE attrelid = $1
@ -296,69 +291,60 @@ AND NOT attisdropped
AND attnum > 0 AND attnum > 0
ORDER BY attnum ORDER BY attnum
"#, "#,
) )
.bind(relation_id) .bind(relation_id)
.fetch_all(&mut *self) .fetch_all(&mut *self)
.await?; .await?;
let mut fields = Vec::new(); let mut fields = Vec::new();
for (field_name, field_oid) in raw_fields.into_iter() { for (field_name, field_oid) in raw_fields.into_iter() {
let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?; let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?;
fields.push((field_name, field_type)); fields.push((field_name, field_type));
} }
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
oid, oid,
name: name.into(), name: name.into(),
kind: PgTypeKind::Composite(Arc::from(fields)), kind: PgTypeKind::Composite(Arc::from(fields)),
})))) }))))
})
} }
fn fetch_domain_by_oid( async fn fetch_domain_by_oid(
&mut self, &mut self,
oid: Oid, oid: Oid,
base_type: Oid, base_type: Oid,
name: String, name: String,
) -> BoxFuture<'_, Result<PgTypeInfo, Error>> { ) -> Result<PgTypeInfo, Error> {
Box::pin(async move { let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;
let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
oid, oid,
name: name.into(), name: name.into(),
kind: PgTypeKind::Domain(base_type), kind: PgTypeKind::Domain(base_type),
})))) }))))
})
} }
fn fetch_range_by_oid( async fn fetch_range_by_oid(&mut self, oid: Oid, name: String) -> Result<PgTypeInfo, Error> {
&mut self, let element_oid: Oid = query_scalar(
oid: Oid, r#"
name: String,
) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
Box::pin(async move {
let element_oid: Oid = query_scalar(
r#"
SELECT rngsubtype SELECT rngsubtype
FROM pg_catalog.pg_range FROM pg_catalog.pg_range
WHERE rngtypid = $1 WHERE rngtypid = $1
"#, "#,
) )
.bind(oid) .bind(oid)
.fetch_one(&mut *self) .fetch_one(&mut *self)
.await?; .await?;
let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?; let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?;
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Range(element), kind: PgTypeKind::Range(element),
name: name.into(), name: name.into(),
oid, oid,
})))) }))))
})
} }
pub(crate) async fn resolve_type_id(&mut self, ty: &PgType) -> Result<Oid, Error> { pub(crate) async fn resolve_type_id(&mut self, ty: &PgType) -> Result<Oid, Error> {

View File

@ -42,12 +42,12 @@ pub struct PgStream {
impl PgStream { impl PgStream {
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> { pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
let socket_future = match options.fetch_socket() { let socket_result = match options.fetch_socket() {
Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?, Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?, None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
}; };
let socket = socket_future.await?; let socket = socket_result?;
Ok(Self { Ok(Self {
inner: BufferedSocket::new(socket), inner: BufferedSocket::new(socket),

View File

@ -1,5 +1,3 @@
use futures_core::future::BoxFuture;
use crate::error::Error; use crate::error::Error;
use crate::net::tls::{self, TlsConfig}; use crate::net::tls::{self, TlsConfig};
use crate::net::{Socket, SocketIntoBox, WithSocket}; use crate::net::{Socket, SocketIntoBox, WithSocket};
@ -10,10 +8,10 @@ use crate::{PgConnectOptions, PgSslMode};
pub struct MaybeUpgradeTls<'a>(pub &'a PgConnectOptions); pub struct MaybeUpgradeTls<'a>(pub &'a PgConnectOptions);
impl<'a> WithSocket for MaybeUpgradeTls<'a> { impl<'a> WithSocket for MaybeUpgradeTls<'a> {
type Output = BoxFuture<'a, crate::Result<Box<dyn Socket>>>; type Output = crate::Result<Box<dyn Socket>>;
fn with_socket<S: Socket>(self, socket: S) -> Self::Output { async fn with_socket<S: Socket>(self, socket: S) -> Self::Output {
Box::pin(maybe_upgrade(socket, self.0)) maybe_upgrade(socket, self.0).await
} }
} }