mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-04-08 11:17:12 +00:00
feat(macros): extend sqlx_macros::test to test cancellation
Signed-off-by: Austin Bonander <austin@launchbadge.com>
This commit is contained in:
committed by
Austin Bonander
parent
5f793c6e95
commit
568256b654
@@ -15,11 +15,15 @@ runtime-actix = [ "actix-rt", "actix-threadpool", "tokio", "tokio-native-tls", "
|
||||
runtime-async-std = [ "async-std", "async-native-tls" ]
|
||||
runtime-tokio = [ "tokio", "tokio-native-tls", "once_cell" ]
|
||||
|
||||
capture-awaits = ["backtrace", "futures"]
|
||||
|
||||
[dependencies]
|
||||
async-native-tls = { version = "0.3.3", optional = true }
|
||||
actix-rt = { version = "1.1.1", optional = true }
|
||||
actix-threadpool = { version = "0.3.2", optional = true }
|
||||
async-std = { version = "1.6.0", features = [ "unstable" ], optional = true }
|
||||
backtrace = { version = "0.3.50", optional = true }
|
||||
futures = { version = "0.3.5", optional = true }
|
||||
tokio = { version = "0.2.21", optional = true, features = [ "blocking", "stream", "fs", "tcp", "uds", "macros", "rt-core", "rt-threaded", "time", "dns", "io-util" ] }
|
||||
tokio-native-tls = { version = "0.1.0", optional = true }
|
||||
native-tls = "0.2.4"
|
||||
|
||||
@@ -130,11 +130,40 @@ where
|
||||
f()
|
||||
}
|
||||
|
||||
/// Capture the last `.await` point in a backtrace.
|
||||
///
|
||||
/// NOTE: backtrace requires `.resolve()` still to get something that isn't all numbers.
|
||||
#[cfg(all(
|
||||
feature = "runtime-async-std",
|
||||
feature = "capture-awaits",
|
||||
not(any(feature = "runtime-actix", feature = "runtime-tokio"))
|
||||
))]
|
||||
pub async fn capture_last_await<Fut>(fut: Fut) -> (Fut::Output, Option<backtrace::Backtrace>)
|
||||
where
|
||||
Fut: futures::Future,
|
||||
{
|
||||
use backtrace::Backtrace;
|
||||
use std::cell::Cell;
|
||||
|
||||
async_std::task_local! {
|
||||
static LAST_AWAIT: Cell<Option<backtrace::Backtrace>> = Cell::new(None);
|
||||
}
|
||||
|
||||
fn capture_await() {
|
||||
LAST_AWAIT.with(|last_await| last_await.set(Some(Backtrace::new_unresolved())));
|
||||
}
|
||||
|
||||
LAST_AWAIT.with(|last| last.set(None));
|
||||
let res = capture_awaits(fut, capture_await).await;
|
||||
let last_await = LAST_AWAIT.with(|last_await| last_await.replace(None));
|
||||
(res, last_await)
|
||||
}
|
||||
|
||||
#[cfg(all(
|
||||
any(feature = "runtime-tokio", feature = "runtime-actix"),
|
||||
not(feature = "runtime-async-std")
|
||||
))]
|
||||
pub use tokio_runtime::{block_on, enter_runtime};
|
||||
pub use tokio_runtime::*;
|
||||
|
||||
#[cfg(any(feature = "runtime-tokio", feature = "runtime-actix"))]
|
||||
mod tokio_runtime {
|
||||
@@ -162,4 +191,114 @@ mod tokio_runtime {
|
||||
{
|
||||
RUNTIME.enter(f)
|
||||
}
|
||||
|
||||
/// Capture the last `.await` point in a backtrace.
|
||||
///
|
||||
/// NOTE: backtrace requires `.resolve()` still to get something that isn't all numbers.
|
||||
#[cfg(feature = "capture-awaits")]
|
||||
pub async fn capture_last_await<Fut>(fut: Fut) -> (Fut::Output, Option<backtrace::Backtrace>)
|
||||
where
|
||||
Fut: futures::Future,
|
||||
{
|
||||
use backtrace::Backtrace;
|
||||
use futures::{future::poll_fn, pin_mut, task::Context, Future};
|
||||
use std::cell::Cell;
|
||||
|
||||
tokio::task_local!(
|
||||
static LAST_AWAIT: Cell<Option<Backtrace>>;
|
||||
);
|
||||
|
||||
fn capture_await() {
|
||||
LAST_AWAIT.with(|last_await| last_await.set(Some(Backtrace::new_unresolved())))
|
||||
}
|
||||
|
||||
LAST_AWAIT
|
||||
.scope(Cell::new(None), async move {
|
||||
let res = super::capture_awaits(fut, capture_await);
|
||||
let backtrace = LAST_AWAIT.with(|last_await| last_await.replace(None));
|
||||
|
||||
(res, backtrace)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a `Waker` which captures a backtrace when it is cloned; when a `Waker` is cloned that
|
||||
/// should mean that it's being squirreled away because the future it was called with is
|
||||
/// going to return `Pending`.
|
||||
#[cfg(feature = "capture-awaits")]
|
||||
fn capture_await_waker(waker: &std::task::Waker, capture_await: fn()) -> std::task::Waker {
|
||||
use std::mem;
|
||||
use std::sync::Arc;
|
||||
use std::task::{RawWaker, RawWakerVTable, Waker};
|
||||
|
||||
struct WakerData {
|
||||
inner: std::task::Waker,
|
||||
capture_await: fn(),
|
||||
}
|
||||
|
||||
// by requiring `capture_await` to be a regular fn pointer, we don't need to leak an
|
||||
// allocation for our VTable
|
||||
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);
|
||||
|
||||
unsafe fn clone(data: *const ()) -> RawWaker {
|
||||
// SAFETY: pointer must be the right type and Arc must not be dropped here without cloning.
|
||||
let data = Arc::from_raw(data as *const WakerData);
|
||||
(data.capture_await)();
|
||||
let cloned = data.clone();
|
||||
mem::forget(data);
|
||||
raw_waker(cloned)
|
||||
}
|
||||
|
||||
unsafe fn wake(data: *const ()) {
|
||||
// SAFETY: pointer must be the right type
|
||||
// LEAK SAFETY: `Arc` *must* be dropped here
|
||||
let data = Arc::from_raw(data as *const WakerData);
|
||||
data.inner.wake_by_ref();
|
||||
}
|
||||
|
||||
unsafe fn wake_by_ref(data: *const ()) {
|
||||
// SAFETY: pointer must be the right type and Arc must *NOT* be dropped here
|
||||
(data as *const WakerData)
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.inner
|
||||
.wake_by_ref();
|
||||
}
|
||||
|
||||
unsafe fn drop(data: *const ()) {
|
||||
// SAFETY: pointer must be the right type
|
||||
// LEAK SAFETY: `Arc` *must* be dropped here
|
||||
let _ = Arc::from_raw(data as *const WakerData);
|
||||
}
|
||||
|
||||
fn raw_waker(data: Arc<WakerData>) -> RawWaker {
|
||||
// SAFETY: Arc must not be dropped here; be sure to use Arc::into_raw
|
||||
RawWaker::new(Arc::into_raw(data) as *const (), &VTABLE)
|
||||
}
|
||||
|
||||
unsafe {
|
||||
// SAFETY: verified above
|
||||
Waker::from_raw(raw_waker(Arc::new(WakerData {
|
||||
inner: waker.clone(),
|
||||
capture_await,
|
||||
})))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "capture-awaits")]
|
||||
async fn capture_awaits<Fut>(fut: Fut, capture_await: fn()) -> Fut::Output
|
||||
where
|
||||
Fut: futures::Future,
|
||||
{
|
||||
use futures::{future::poll_fn, pin_mut, task::Context, Future};
|
||||
pin_mut!(fut);
|
||||
|
||||
poll_fn(move |cx| {
|
||||
let waker = capture_await_waker(cx.waker(), capture_await);
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
fut.as_mut().poll(&mut cx)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user