refactor: tweaks after #3791 (#4022)

* restore fallback to `async-io` for `connect_tcp()` when `runtime-tokio` feature is enabled
* `smol` and `async-global-executor` both use `async-task`, so `JoinHandle` impls can be consolidated
* no need for duplicate `yield_now()` impls
* delete `impl Socket for ()`
This commit is contained in:
Austin Bonander 2025-09-08 14:28:58 -07:00 committed by GitHub
parent 500cd18f19
commit 66526d9c56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 117 additions and 225 deletions

View File

@ -22,6 +22,7 @@ jobs:
runs-on: ubuntu-24.04
strategy:
matrix:
# Note: because `async-std` is deprecated, we only check it in a single job to save CI time.
runtime: [ async-std, async-global-executor, smol, tokio ]
tls: [ native-tls, rustls, none ]
timeout-minutes: 30

1
Cargo.lock generated
View File

@ -3543,6 +3543,7 @@ dependencies = [
"async-global-executor 3.1.0",
"async-io",
"async-std",
"async-task",
"base64 0.22.1",
"bigdecimal",
"bit-vec",

View File

@ -55,9 +55,11 @@ features = [
[features]
default = ["postgres", "sqlite", "mysql", "native-tls", "completions", "sqlx-toml"]
# TLS options
rustls = ["sqlx/tls-rustls"]
native-tls = ["sqlx/tls-native-tls"]
# databases
mysql = ["sqlx/mysql"]
postgres = ["sqlx/postgres"]
sqlite = ["sqlx/sqlite", "_sqlite"]

View File

@ -20,11 +20,13 @@ any = []
json = ["serde", "serde_json"]
# for conditional compilation
_rt-async-global-executor = ["async-global-executor", "_rt-async-io"]
_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"]
_rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration
_rt-async-std = ["async-std", "_rt-async-io"]
_rt-smol = ["smol", "_rt-async-io"]
_rt-async-task = ["async-task"]
_rt-smol = ["smol", "_rt-async-io", "_rt-async-task"]
_rt-tokio = ["tokio", "tokio-stream"]
_tls-native-tls = ["native-tls"]
_tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"]
_tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"]
@ -68,6 +70,8 @@ mac_address = { workspace = true, optional = true }
uuid = { workspace = true, optional = true }
async-io = { version = "2.4.1", optional = true }
async-task = { version = "4.7.1", optional = true }
# work around bug in async-fs 2.0.0, which references futures-lite dependency wrongly, see https://github.com/launchbadge/sqlx/pull/3791#issuecomment-3043363281
async-fs = { version = "2.1", optional = true }
base64 = { version = "0.22.0", default-features = false, features = ["std"] }

View File

@ -1,18 +1,14 @@
use std::future::Future;
use std::io;
use std::path::Path;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::{
future::Future,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
};
pub use buffered::{BufferedSocket, WriteBuffer};
use bytes::BufMut;
use cfg_if::cfg_if;
pub use buffered::{BufferedSocket, WriteBuffer};
use crate::{io::ReadBuf, rt::spawn_blocking};
use crate::io::ReadBuf;
mod buffered;
@ -146,10 +142,7 @@ where
pub trait WithSocket {
type Output;
fn with_socket<S: Socket>(
self,
socket: S,
) -> impl std::future::Future<Output = Self::Output> + Send;
fn with_socket<S: Socket>(self, socket: S) -> impl Future<Output = Self::Output> + Send;
}
pub struct SocketIntoBox;
@ -193,98 +186,67 @@ pub async fn connect_tcp<Ws: WithSocket>(
port: u16,
with_socket: Ws,
) -> crate::Result<Ws::Output> {
#[cfg(feature = "_rt-tokio")]
if crate::rt::rt_tokio::available() {
return Ok(with_socket
.with_socket(tokio::net::TcpStream::connect((host, port)).await?)
.await);
}
cfg_if! {
if #[cfg(feature = "_rt-async-io")] {
Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await)
} else {
crate::rt::missing_rt((host, port, with_socket))
}
}
}
/// Open a TCP socket to `host` and `port`.
///
/// If `host` is a hostname, attempt to connect to each address it resolves to.
///
/// This implements the same behavior as [`tokio::net::TcpStream::connect()`].
#[cfg(feature = "_rt-async-io")]
async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socket> {
use async_io::Async;
use std::net::{IpAddr, TcpStream, ToSocketAddrs};
// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
let host = host.trim_matches(&['[', ']'][..]);
let addresses = if let Ok(addr) = host.parse::<Ipv4Addr>() {
let addr = SocketAddrV4::new(addr, port);
vec![SocketAddr::V4(addr)].into_iter()
} else if let Ok(addr) = host.parse::<Ipv6Addr>() {
let addr = SocketAddrV6::new(addr, port, 0, 0);
vec![SocketAddr::V6(addr)].into_iter()
} else {
let host = host.to_string();
spawn_blocking(move || {
let addr = (host.as_str(), port);
ToSocketAddrs::to_socket_addrs(&addr)
})
.await?
};
if let Ok(addr) = host.parse::<IpAddr>() {
return Ok(Async::<TcpStream>::connect((addr, port)).await?);
}
let host = host.to_string();
let addresses = crate::rt::spawn_blocking(move || {
let addr = (host.as_str(), port);
ToSocketAddrs::to_socket_addrs(&addr)
})
.await?;
let mut last_err = None;
// Loop through all the Socket Addresses that the hostname resolves to
for socket_addr in addresses {
match connect_tcp_address(socket_addr).await {
Ok(stream) => return Ok(with_socket.with_socket(stream).await),
match Async::<TcpStream>::connect(socket_addr).await {
Ok(stream) => return Ok(stream),
Err(e) => last_err = Some(e),
}
}
// If we reach this point, it means we failed to connect to any of the addresses.
// Return the last error we encountered, or a custom error if the hostname didn't resolve to any address.
Err(match last_err {
Some(err) => err,
None => io::Error::new(
io::ErrorKind::AddrNotAvailable,
"Hostname did not resolve to any addresses",
)
.into(),
})
}
async fn connect_tcp_address(socket_addr: SocketAddr) -> crate::Result<impl Socket> {
cfg_if! {
if #[cfg(feature = "_rt-tokio")] {
if crate::rt::rt_tokio::available() {
use tokio::net::TcpStream;
let stream = TcpStream::connect(socket_addr).await?;
stream.set_nodelay(true)?;
Ok(stream)
} else {
crate::rt::missing_rt(socket_addr)
}
} else if #[cfg(feature = "_rt-async-io")] {
use async_io::Async;
use std::net::TcpStream;
let stream = Async::<TcpStream>::connect(socket_addr).await?;
stream.get_ref().set_nodelay(true)?;
Ok(stream)
} else {
crate::rt::missing_rt(socket_addr);
#[allow(unreachable_code)]
Ok(())
}
}
}
// Work around `impl Socket`` and 'unability to specify test build cargo feature'.
// `connect_tcp_address` compilation would fail without this impl with
// 'cannot infer return type' error.
impl Socket for () {
fn try_read(&mut self, _: &mut dyn ReadBuf) -> io::Result<usize> {
unreachable!()
}
fn try_write(&mut self, _: &[u8]) -> io::Result<usize> {
unreachable!()
}
fn poll_read_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
fn poll_write_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
fn poll_shutdown(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
unreachable!()
}
Err(last_err
.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::AddrNotAvailable,
"Hostname did not resolve to any addresses",
)
})
.into())
}
/// Connect a Unix Domain Socket at the given path.

View File

@ -9,12 +9,6 @@ use cfg_if::cfg_if;
#[cfg(feature = "_rt-async-io")]
pub mod rt_async_io;
#[cfg(feature = "_rt-async-global-executor")]
pub mod rt_async_global_executor;
#[cfg(feature = "_rt-smol")]
pub mod rt_smol;
#[cfg(feature = "_rt-tokio")]
pub mod rt_tokio;
@ -23,14 +17,16 @@ pub mod rt_tokio;
pub struct TimeoutError;
pub enum JoinHandle<T> {
#[cfg(feature = "_rt-async-global-executor")]
AsyncGlobalExecutor(rt_async_global_executor::JoinHandle<T>),
#[cfg(feature = "_rt-async-std")]
AsyncStd(async_std::task::JoinHandle<T>),
#[cfg(feature = "_rt-smol")]
Smol(rt_smol::JoinHandle<T>),
#[cfg(feature = "_rt-tokio")]
Tokio(tokio::task::JoinHandle<T>),
// Implementation shared by `smol` and `async-global-executor`
#[cfg(feature = "_rt-async-task")]
AsyncTask(Option<async_task::Task<T>>),
// `PhantomData<T>` requires `T: Unpin`
_Phantom(PhantomData<fn() -> T>),
}
@ -41,7 +37,6 @@ pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, T
#[cfg(feature = "_rt-tokio")]
if rt_tokio::available() {
#[allow(clippy::needless_return)]
return tokio::time::timeout(duration, f)
.await
.map_err(|_| TimeoutError);
@ -84,15 +79,11 @@ where
cfg_if! {
if #[cfg(feature = "_rt-async-global-executor")] {
JoinHandle::AsyncGlobalExecutor(rt_async_global_executor::JoinHandle {
task: Some(async_global_executor::spawn(fut)),
})
JoinHandle::AsyncTask(Some(async_global_executor::spawn(fut)))
} else if #[cfg(feature = "_rt-smol")] {
JoinHandle::AsyncTask(Some(smol::spawn(fut)))
} else if #[cfg(feature = "_rt-async-std")] {
JoinHandle::AsyncStd(async_std::task::spawn(fut))
} else if #[cfg(feature = "_rt-smol")] {
JoinHandle::Smol(rt_smol::JoinHandle {
task: Some(smol::spawn(fut)),
})
} else {
missing_rt(fut)
}
@ -112,15 +103,11 @@ where
cfg_if! {
if #[cfg(feature = "_rt-async-global-executor")] {
JoinHandle::AsyncGlobalExecutor(rt_async_global_executor::JoinHandle {
task: Some(async_global_executor::spawn_blocking(f)),
})
JoinHandle::AsyncTask(Some(async_global_executor::spawn_blocking(f)))
} else if #[cfg(feature = "_rt-smol")] {
JoinHandle::AsyncTask(Some(smol::unblock(f)))
} else if #[cfg(feature = "_rt-async-std")] {
JoinHandle::AsyncStd(async_std::task::spawn_blocking(f))
} else if #[cfg(feature = "_rt-smol")] {
JoinHandle::Smol(rt_smol::JoinHandle {
task: Some(smol::unblock(f)),
})
} else {
missing_rt(f)
}
@ -133,17 +120,27 @@ pub async fn yield_now() {
return tokio::task::yield_now().await;
}
cfg_if! {
if #[cfg(feature = "_rt-async-global-executor")] {
rt_async_global_executor::yield_now().await
} else if #[cfg(feature = "_rt-async-std")] {
async_std::task::yield_now().await
} else if #[cfg(feature = "_rt-smol")] {
smol::future::yield_now().await
// `smol`, `async-global-executor`, and `async-std` all have the same implementation for this.
//
// By immediately signaling the waker and then returning `Pending`,
// this essentially just moves the task to the back of the runnable queue.
//
// There isn't any special integration with the runtime, so we can save code by rolling our own.
//
// (Tokio's implementation is nearly identical too,
// but has additional integration with `tracing` which may be useful for debugging.)
let mut yielded = false;
std::future::poll_fn(|cx| {
if !yielded {
yielded = true;
cx.waker().wake_by_ref();
Poll::Pending
} else {
missing_rt(())
Poll::Ready(())
}
}
})
.await
}
#[track_caller]
@ -169,7 +166,7 @@ pub const fn missing_rt<T>(_unused: T) -> ! {
panic!("this functionality requires a Tokio context")
}
panic!("one of the `runtime-async-global-executor`, `runtime-async-std`, `runtime-smol`, or `runtime-tokio` feature must be enabled")
panic!("one of the `runtime` features of SQLx must be enabled")
}
impl<T: Send + 'static> Future for JoinHandle<T> {
@ -178,16 +175,20 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
#[track_caller]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match &mut *self {
#[cfg(feature = "_rt-async-global-executor")]
Self::AsyncGlobalExecutor(handle) => Pin::new(handle).poll(cx),
#[cfg(feature = "_rt-async-std")]
Self::AsyncStd(handle) => Pin::new(handle).poll(cx),
#[cfg(feature = "_rt-smol")]
Self::Smol(handle) => Pin::new(handle).poll(cx),
#[cfg(feature = "_rt-async-task")]
Self::AsyncTask(task) => Pin::new(task)
.as_pin_mut()
.expect("BUG: task taken")
.poll(cx),
#[cfg(feature = "_rt-tokio")]
Self::Tokio(handle) => Pin::new(handle)
.poll(cx)
.map(|res| res.expect("spawned task panicked")),
Self::_Phantom(_) => {
let _ = cx;
unreachable!("runtime should have been checked on spawn")
@ -195,3 +196,19 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
}
}
}
impl<T> Drop for JoinHandle<T> {
fn drop(&mut self) {
match self {
// `async_task` cancels on-drop by default.
// We need to explicitly detach to match Tokio and `async-std`.
#[cfg(feature = "_rt-async-task")]
Self::AsyncTask(task) => {
if let Some(task) = task.take() {
task.detach();
}
}
_ => (),
}
}
}

View File

@ -1,30 +0,0 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use async_global_executor::Task;
pub struct JoinHandle<T> {
pub task: Option<Task<T>>,
}
impl<T> Drop for JoinHandle<T> {
fn drop(&mut self) {
if let Some(task) = self.task.take() {
task.detach();
}
}
}
impl<T> Future for JoinHandle<T> {
type Output = T;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.task.as_mut() {
Some(task) => Future::poll(Pin::new(task), cx),
None => unreachable!("JoinHandle polled after dropping"),
}
}
}

View File

@ -1,5 +0,0 @@
mod join_handle;
pub use join_handle::*;
pub mod yield_now;
pub use yield_now::*;

View File

@ -1,28 +0,0 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
pub fn yield_now() -> impl Future<Output = ()> {
YieldNow(false)
}
struct YieldNow(bool);
impl Future for YieldNow {
type Output = ();
// The futures executor is implemented as a FIFO queue, so all this future
// does is re-schedule the future back to the end of the queue, giving room
// for other futures to progress.
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if !self.0 {
self.0 = true;
cx.waker().wake_by_ref();
Poll::Pending
} else {
Poll::Ready(())
}
}
}

View File

@ -1,30 +0,0 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use smol::Task;
pub struct JoinHandle<T> {
pub task: Option<Task<T>>,
}
impl<T> Drop for JoinHandle<T> {
fn drop(&mut self) {
if let Some(task) = self.task.take() {
task.detach();
}
}
}
impl<T> Future for JoinHandle<T> {
type Output = T;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.task.as_mut() {
Some(task) => Future::poll(Pin::new(task), cx),
None => unreachable!("JoinHandle polled after dropping"),
}
}
}

View File

@ -1,2 +0,0 @@
mod join_handle;
pub use join_handle::*;