mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-09-26 20:40:54 +00:00
* restore fallback to `async-io` for `connect_tcp()` when `runtime-tokio` feature is enabled * `smol` and `async-global-executor` both use `async-task`, so `JoinHandle` impls can be consolidated * no need for duplicate `yield_now()` impls * delete `impl Socket for ()`
This commit is contained in:
parent
500cd18f19
commit
66526d9c56
1
.github/workflows/sqlx.yml
vendored
1
.github/workflows/sqlx.yml
vendored
@ -22,6 +22,7 @@ jobs:
|
||||
runs-on: ubuntu-24.04
|
||||
strategy:
|
||||
matrix:
|
||||
# Note: because `async-std` is deprecated, we only check it in a single job to save CI time.
|
||||
runtime: [ async-std, async-global-executor, smol, tokio ]
|
||||
tls: [ native-tls, rustls, none ]
|
||||
timeout-minutes: 30
|
||||
|
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -3543,6 +3543,7 @@ dependencies = [
|
||||
"async-global-executor 3.1.0",
|
||||
"async-io",
|
||||
"async-std",
|
||||
"async-task",
|
||||
"base64 0.22.1",
|
||||
"bigdecimal",
|
||||
"bit-vec",
|
||||
|
@ -55,9 +55,11 @@ features = [
|
||||
[features]
|
||||
default = ["postgres", "sqlite", "mysql", "native-tls", "completions", "sqlx-toml"]
|
||||
|
||||
# TLS options
|
||||
rustls = ["sqlx/tls-rustls"]
|
||||
native-tls = ["sqlx/tls-native-tls"]
|
||||
|
||||
# databases
|
||||
mysql = ["sqlx/mysql"]
|
||||
postgres = ["sqlx/postgres"]
|
||||
sqlite = ["sqlx/sqlite", "_sqlite"]
|
||||
|
@ -20,11 +20,13 @@ any = []
|
||||
json = ["serde", "serde_json"]
|
||||
|
||||
# for conditional compilation
|
||||
_rt-async-global-executor = ["async-global-executor", "_rt-async-io"]
|
||||
_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"]
|
||||
_rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration
|
||||
_rt-async-std = ["async-std", "_rt-async-io"]
|
||||
_rt-smol = ["smol", "_rt-async-io"]
|
||||
_rt-async-task = ["async-task"]
|
||||
_rt-smol = ["smol", "_rt-async-io", "_rt-async-task"]
|
||||
_rt-tokio = ["tokio", "tokio-stream"]
|
||||
|
||||
_tls-native-tls = ["native-tls"]
|
||||
_tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"]
|
||||
_tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"]
|
||||
@ -68,6 +70,8 @@ mac_address = { workspace = true, optional = true }
|
||||
uuid = { workspace = true, optional = true }
|
||||
|
||||
async-io = { version = "2.4.1", optional = true }
|
||||
async-task = { version = "4.7.1", optional = true }
|
||||
|
||||
# work around bug in async-fs 2.0.0, which references futures-lite dependency wrongly, see https://github.com/launchbadge/sqlx/pull/3791#issuecomment-3043363281
|
||||
async-fs = { version = "2.1", optional = true }
|
||||
base64 = { version = "0.22.0", default-features = false, features = ["std"] }
|
||||
|
@ -1,18 +1,14 @@
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::pin::Pin;
|
||||
use std::task::{ready, Context, Poll};
|
||||
use std::{
|
||||
future::Future,
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
|
||||
};
|
||||
|
||||
pub use buffered::{BufferedSocket, WriteBuffer};
|
||||
use bytes::BufMut;
|
||||
use cfg_if::cfg_if;
|
||||
|
||||
pub use buffered::{BufferedSocket, WriteBuffer};
|
||||
|
||||
use crate::{io::ReadBuf, rt::spawn_blocking};
|
||||
use crate::io::ReadBuf;
|
||||
|
||||
mod buffered;
|
||||
|
||||
@ -146,10 +142,7 @@ where
|
||||
pub trait WithSocket {
|
||||
type Output;
|
||||
|
||||
fn with_socket<S: Socket>(
|
||||
self,
|
||||
socket: S,
|
||||
) -> impl std::future::Future<Output = Self::Output> + Send;
|
||||
fn with_socket<S: Socket>(self, socket: S) -> impl Future<Output = Self::Output> + Send;
|
||||
}
|
||||
|
||||
pub struct SocketIntoBox;
|
||||
@ -193,98 +186,67 @@ pub async fn connect_tcp<Ws: WithSocket>(
|
||||
port: u16,
|
||||
with_socket: Ws,
|
||||
) -> crate::Result<Ws::Output> {
|
||||
#[cfg(feature = "_rt-tokio")]
|
||||
if crate::rt::rt_tokio::available() {
|
||||
return Ok(with_socket
|
||||
.with_socket(tokio::net::TcpStream::connect((host, port)).await?)
|
||||
.await);
|
||||
}
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(feature = "_rt-async-io")] {
|
||||
Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await)
|
||||
} else {
|
||||
crate::rt::missing_rt((host, port, with_socket))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Open a TCP socket to `host` and `port`.
|
||||
///
|
||||
/// If `host` is a hostname, attempt to connect to each address it resolves to.
|
||||
///
|
||||
/// This implements the same behavior as [`tokio::net::TcpStream::connect()`].
|
||||
#[cfg(feature = "_rt-async-io")]
|
||||
async fn connect_tcp_async_io(host: &str, port: u16) -> crate::Result<impl Socket> {
|
||||
use async_io::Async;
|
||||
use std::net::{IpAddr, TcpStream, ToSocketAddrs};
|
||||
|
||||
// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
|
||||
let host = host.trim_matches(&['[', ']'][..]);
|
||||
|
||||
let addresses = if let Ok(addr) = host.parse::<Ipv4Addr>() {
|
||||
let addr = SocketAddrV4::new(addr, port);
|
||||
vec![SocketAddr::V4(addr)].into_iter()
|
||||
} else if let Ok(addr) = host.parse::<Ipv6Addr>() {
|
||||
let addr = SocketAddrV6::new(addr, port, 0, 0);
|
||||
vec![SocketAddr::V6(addr)].into_iter()
|
||||
} else {
|
||||
let host = host.to_string();
|
||||
spawn_blocking(move || {
|
||||
let addr = (host.as_str(), port);
|
||||
ToSocketAddrs::to_socket_addrs(&addr)
|
||||
})
|
||||
.await?
|
||||
};
|
||||
if let Ok(addr) = host.parse::<IpAddr>() {
|
||||
return Ok(Async::<TcpStream>::connect((addr, port)).await?);
|
||||
}
|
||||
|
||||
let host = host.to_string();
|
||||
|
||||
let addresses = crate::rt::spawn_blocking(move || {
|
||||
let addr = (host.as_str(), port);
|
||||
ToSocketAddrs::to_socket_addrs(&addr)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut last_err = None;
|
||||
|
||||
// Loop through all the Socket Addresses that the hostname resolves to
|
||||
for socket_addr in addresses {
|
||||
match connect_tcp_address(socket_addr).await {
|
||||
Ok(stream) => return Ok(with_socket.with_socket(stream).await),
|
||||
match Async::<TcpStream>::connect(socket_addr).await {
|
||||
Ok(stream) => return Ok(stream),
|
||||
Err(e) => last_err = Some(e),
|
||||
}
|
||||
}
|
||||
|
||||
// If we reach this point, it means we failed to connect to any of the addresses.
|
||||
// Return the last error we encountered, or a custom error if the hostname didn't resolve to any address.
|
||||
Err(match last_err {
|
||||
Some(err) => err,
|
||||
None => io::Error::new(
|
||||
io::ErrorKind::AddrNotAvailable,
|
||||
"Hostname did not resolve to any addresses",
|
||||
)
|
||||
.into(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn connect_tcp_address(socket_addr: SocketAddr) -> crate::Result<impl Socket> {
|
||||
cfg_if! {
|
||||
if #[cfg(feature = "_rt-tokio")] {
|
||||
if crate::rt::rt_tokio::available() {
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
let stream = TcpStream::connect(socket_addr).await?;
|
||||
stream.set_nodelay(true)?;
|
||||
|
||||
Ok(stream)
|
||||
} else {
|
||||
crate::rt::missing_rt(socket_addr)
|
||||
}
|
||||
} else if #[cfg(feature = "_rt-async-io")] {
|
||||
use async_io::Async;
|
||||
use std::net::TcpStream;
|
||||
|
||||
let stream = Async::<TcpStream>::connect(socket_addr).await?;
|
||||
stream.get_ref().set_nodelay(true)?;
|
||||
|
||||
Ok(stream)
|
||||
} else {
|
||||
crate::rt::missing_rt(socket_addr);
|
||||
#[allow(unreachable_code)]
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Work around `impl Socket`` and 'unability to specify test build cargo feature'.
|
||||
// `connect_tcp_address` compilation would fail without this impl with
|
||||
// 'cannot infer return type' error.
|
||||
impl Socket for () {
|
||||
fn try_read(&mut self, _: &mut dyn ReadBuf) -> io::Result<usize> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn try_write(&mut self, _: &[u8]) -> io::Result<usize> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn poll_read_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn poll_write_ready(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
fn poll_shutdown(&mut self, _: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
unreachable!()
|
||||
}
|
||||
Err(last_err
|
||||
.unwrap_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::AddrNotAvailable,
|
||||
"Hostname did not resolve to any addresses",
|
||||
)
|
||||
})
|
||||
.into())
|
||||
}
|
||||
|
||||
/// Connect a Unix Domain Socket at the given path.
|
||||
|
@ -9,12 +9,6 @@ use cfg_if::cfg_if;
|
||||
#[cfg(feature = "_rt-async-io")]
|
||||
pub mod rt_async_io;
|
||||
|
||||
#[cfg(feature = "_rt-async-global-executor")]
|
||||
pub mod rt_async_global_executor;
|
||||
|
||||
#[cfg(feature = "_rt-smol")]
|
||||
pub mod rt_smol;
|
||||
|
||||
#[cfg(feature = "_rt-tokio")]
|
||||
pub mod rt_tokio;
|
||||
|
||||
@ -23,14 +17,16 @@ pub mod rt_tokio;
|
||||
pub struct TimeoutError;
|
||||
|
||||
pub enum JoinHandle<T> {
|
||||
#[cfg(feature = "_rt-async-global-executor")]
|
||||
AsyncGlobalExecutor(rt_async_global_executor::JoinHandle<T>),
|
||||
#[cfg(feature = "_rt-async-std")]
|
||||
AsyncStd(async_std::task::JoinHandle<T>),
|
||||
#[cfg(feature = "_rt-smol")]
|
||||
Smol(rt_smol::JoinHandle<T>),
|
||||
|
||||
#[cfg(feature = "_rt-tokio")]
|
||||
Tokio(tokio::task::JoinHandle<T>),
|
||||
|
||||
// Implementation shared by `smol` and `async-global-executor`
|
||||
#[cfg(feature = "_rt-async-task")]
|
||||
AsyncTask(Option<async_task::Task<T>>),
|
||||
|
||||
// `PhantomData<T>` requires `T: Unpin`
|
||||
_Phantom(PhantomData<fn() -> T>),
|
||||
}
|
||||
@ -41,7 +37,6 @@ pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, T
|
||||
|
||||
#[cfg(feature = "_rt-tokio")]
|
||||
if rt_tokio::available() {
|
||||
#[allow(clippy::needless_return)]
|
||||
return tokio::time::timeout(duration, f)
|
||||
.await
|
||||
.map_err(|_| TimeoutError);
|
||||
@ -84,15 +79,11 @@ where
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(feature = "_rt-async-global-executor")] {
|
||||
JoinHandle::AsyncGlobalExecutor(rt_async_global_executor::JoinHandle {
|
||||
task: Some(async_global_executor::spawn(fut)),
|
||||
})
|
||||
JoinHandle::AsyncTask(Some(async_global_executor::spawn(fut)))
|
||||
} else if #[cfg(feature = "_rt-smol")] {
|
||||
JoinHandle::AsyncTask(Some(smol::spawn(fut)))
|
||||
} else if #[cfg(feature = "_rt-async-std")] {
|
||||
JoinHandle::AsyncStd(async_std::task::spawn(fut))
|
||||
} else if #[cfg(feature = "_rt-smol")] {
|
||||
JoinHandle::Smol(rt_smol::JoinHandle {
|
||||
task: Some(smol::spawn(fut)),
|
||||
})
|
||||
} else {
|
||||
missing_rt(fut)
|
||||
}
|
||||
@ -112,15 +103,11 @@ where
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(feature = "_rt-async-global-executor")] {
|
||||
JoinHandle::AsyncGlobalExecutor(rt_async_global_executor::JoinHandle {
|
||||
task: Some(async_global_executor::spawn_blocking(f)),
|
||||
})
|
||||
JoinHandle::AsyncTask(Some(async_global_executor::spawn_blocking(f)))
|
||||
} else if #[cfg(feature = "_rt-smol")] {
|
||||
JoinHandle::AsyncTask(Some(smol::unblock(f)))
|
||||
} else if #[cfg(feature = "_rt-async-std")] {
|
||||
JoinHandle::AsyncStd(async_std::task::spawn_blocking(f))
|
||||
} else if #[cfg(feature = "_rt-smol")] {
|
||||
JoinHandle::Smol(rt_smol::JoinHandle {
|
||||
task: Some(smol::unblock(f)),
|
||||
})
|
||||
} else {
|
||||
missing_rt(f)
|
||||
}
|
||||
@ -133,17 +120,27 @@ pub async fn yield_now() {
|
||||
return tokio::task::yield_now().await;
|
||||
}
|
||||
|
||||
cfg_if! {
|
||||
if #[cfg(feature = "_rt-async-global-executor")] {
|
||||
rt_async_global_executor::yield_now().await
|
||||
} else if #[cfg(feature = "_rt-async-std")] {
|
||||
async_std::task::yield_now().await
|
||||
} else if #[cfg(feature = "_rt-smol")] {
|
||||
smol::future::yield_now().await
|
||||
// `smol`, `async-global-executor`, and `async-std` all have the same implementation for this.
|
||||
//
|
||||
// By immediately signaling the waker and then returning `Pending`,
|
||||
// this essentially just moves the task to the back of the runnable queue.
|
||||
//
|
||||
// There isn't any special integration with the runtime, so we can save code by rolling our own.
|
||||
//
|
||||
// (Tokio's implementation is nearly identical too,
|
||||
// but has additional integration with `tracing` which may be useful for debugging.)
|
||||
let mut yielded = false;
|
||||
|
||||
std::future::poll_fn(|cx| {
|
||||
if !yielded {
|
||||
yielded = true;
|
||||
cx.waker().wake_by_ref();
|
||||
Poll::Pending
|
||||
} else {
|
||||
missing_rt(())
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
@ -169,7 +166,7 @@ pub const fn missing_rt<T>(_unused: T) -> ! {
|
||||
panic!("this functionality requires a Tokio context")
|
||||
}
|
||||
|
||||
panic!("one of the `runtime-async-global-executor`, `runtime-async-std`, `runtime-smol`, or `runtime-tokio` feature must be enabled")
|
||||
panic!("one of the `runtime` features of SQLx must be enabled")
|
||||
}
|
||||
|
||||
impl<T: Send + 'static> Future for JoinHandle<T> {
|
||||
@ -178,16 +175,20 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
|
||||
#[track_caller]
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match &mut *self {
|
||||
#[cfg(feature = "_rt-async-global-executor")]
|
||||
Self::AsyncGlobalExecutor(handle) => Pin::new(handle).poll(cx),
|
||||
#[cfg(feature = "_rt-async-std")]
|
||||
Self::AsyncStd(handle) => Pin::new(handle).poll(cx),
|
||||
#[cfg(feature = "_rt-smol")]
|
||||
Self::Smol(handle) => Pin::new(handle).poll(cx),
|
||||
|
||||
#[cfg(feature = "_rt-async-task")]
|
||||
Self::AsyncTask(task) => Pin::new(task)
|
||||
.as_pin_mut()
|
||||
.expect("BUG: task taken")
|
||||
.poll(cx),
|
||||
|
||||
#[cfg(feature = "_rt-tokio")]
|
||||
Self::Tokio(handle) => Pin::new(handle)
|
||||
.poll(cx)
|
||||
.map(|res| res.expect("spawned task panicked")),
|
||||
|
||||
Self::_Phantom(_) => {
|
||||
let _ = cx;
|
||||
unreachable!("runtime should have been checked on spawn")
|
||||
@ -195,3 +196,19 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for JoinHandle<T> {
|
||||
fn drop(&mut self) {
|
||||
match self {
|
||||
// `async_task` cancels on-drop by default.
|
||||
// We need to explicitly detach to match Tokio and `async-std`.
|
||||
#[cfg(feature = "_rt-async-task")]
|
||||
Self::AsyncTask(task) => {
|
||||
if let Some(task) = task.take() {
|
||||
task.detach();
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,30 +0,0 @@
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use async_global_executor::Task;
|
||||
|
||||
pub struct JoinHandle<T> {
|
||||
pub task: Option<Task<T>>,
|
||||
}
|
||||
|
||||
impl<T> Drop for JoinHandle<T> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(task) = self.task.take() {
|
||||
task.detach();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for JoinHandle<T> {
|
||||
type Output = T;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.task.as_mut() {
|
||||
Some(task) => Future::poll(Pin::new(task), cx),
|
||||
None => unreachable!("JoinHandle polled after dropping"),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,5 +0,0 @@
|
||||
mod join_handle;
|
||||
pub use join_handle::*;
|
||||
|
||||
pub mod yield_now;
|
||||
pub use yield_now::*;
|
@ -1,28 +0,0 @@
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
pub fn yield_now() -> impl Future<Output = ()> {
|
||||
YieldNow(false)
|
||||
}
|
||||
|
||||
struct YieldNow(bool);
|
||||
|
||||
impl Future for YieldNow {
|
||||
type Output = ();
|
||||
|
||||
// The futures executor is implemented as a FIFO queue, so all this future
|
||||
// does is re-schedule the future back to the end of the queue, giving room
|
||||
// for other futures to progress.
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
if !self.0 {
|
||||
self.0 = true;
|
||||
cx.waker().wake_by_ref();
|
||||
Poll::Pending
|
||||
} else {
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
}
|
@ -1,30 +0,0 @@
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use smol::Task;
|
||||
|
||||
pub struct JoinHandle<T> {
|
||||
pub task: Option<Task<T>>,
|
||||
}
|
||||
|
||||
impl<T> Drop for JoinHandle<T> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(task) = self.task.take() {
|
||||
task.detach();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for JoinHandle<T> {
|
||||
type Output = T;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.task.as_mut() {
|
||||
Some(task) => Future::poll(Pin::new(task), cx),
|
||||
None => unreachable!("JoinHandle polled after dropping"),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,2 +0,0 @@
|
||||
mod join_handle;
|
||||
pub use join_handle::*;
|
Loading…
x
Reference in New Issue
Block a user