feat(macros): extend sqlx_macros::test to test cancellation

Signed-off-by: Austin Bonander <austin@launchbadge.com>
This commit is contained in:
Austin Bonander
2020-07-29 23:24:53 -07:00
committed by Austin Bonander
parent 5f793c6e95
commit 568256b654
6 changed files with 346 additions and 31 deletions

View File

@@ -15,11 +15,15 @@ runtime-actix = [ "actix-rt", "actix-threadpool", "tokio", "tokio-native-tls", "
runtime-async-std = [ "async-std", "async-native-tls" ]
runtime-tokio = [ "tokio", "tokio-native-tls", "once_cell" ]
capture-awaits = ["backtrace", "futures"]
[dependencies]
async-native-tls = { version = "0.3.3", optional = true }
actix-rt = { version = "1.1.1", optional = true }
actix-threadpool = { version = "0.3.2", optional = true }
async-std = { version = "1.6.0", features = [ "unstable" ], optional = true }
backtrace = { version = "0.3.50", optional = true }
futures = { version = "0.3.5", optional = true }
tokio = { version = "0.2.21", optional = true, features = [ "blocking", "stream", "fs", "tcp", "uds", "macros", "rt-core", "rt-threaded", "time", "dns", "io-util" ] }
tokio-native-tls = { version = "0.1.0", optional = true }
native-tls = "0.2.4"

View File

@@ -130,11 +130,40 @@ where
f()
}
/// Capture the last `.await` point in a backtrace.
///
/// NOTE: backtrace requires `.resolve()` still to get something that isn't all numbers.
#[cfg(all(
feature = "runtime-async-std",
feature = "capture-awaits",
not(any(feature = "runtime-actix", feature = "runtime-tokio"))
))]
pub async fn capture_last_await<Fut>(fut: Fut) -> (Fut::Output, Option<backtrace::Backtrace>)
where
Fut: futures::Future,
{
use backtrace::Backtrace;
use std::cell::Cell;
async_std::task_local! {
static LAST_AWAIT: Cell<Option<backtrace::Backtrace>> = Cell::new(None);
}
fn capture_await() {
LAST_AWAIT.with(|last_await| last_await.set(Some(Backtrace::new_unresolved())));
}
LAST_AWAIT.with(|last| last.set(None));
let res = capture_awaits(fut, capture_await).await;
let last_await = LAST_AWAIT.with(|last_await| last_await.replace(None));
(res, last_await)
}
#[cfg(all(
any(feature = "runtime-tokio", feature = "runtime-actix"),
not(feature = "runtime-async-std")
))]
pub use tokio_runtime::{block_on, enter_runtime};
pub use tokio_runtime::*;
#[cfg(any(feature = "runtime-tokio", feature = "runtime-actix"))]
mod tokio_runtime {
@@ -162,4 +191,114 @@ mod tokio_runtime {
{
RUNTIME.enter(f)
}
/// Capture the last `.await` point in a backtrace.
///
/// NOTE: backtrace requires `.resolve()` still to get something that isn't all numbers.
#[cfg(feature = "capture-awaits")]
pub async fn capture_last_await<Fut>(fut: Fut) -> (Fut::Output, Option<backtrace::Backtrace>)
where
Fut: futures::Future,
{
use backtrace::Backtrace;
use futures::{future::poll_fn, pin_mut, task::Context, Future};
use std::cell::Cell;
tokio::task_local!(
static LAST_AWAIT: Cell<Option<Backtrace>>;
);
fn capture_await() {
LAST_AWAIT.with(|last_await| last_await.set(Some(Backtrace::new_unresolved())))
}
LAST_AWAIT
.scope(Cell::new(None), async move {
let res = super::capture_awaits(fut, capture_await);
let backtrace = LAST_AWAIT.with(|last_await| last_await.replace(None));
(res, backtrace)
})
.await
}
}
/// Create a `Waker` which captures a backtrace when it is cloned; when a `Waker` is cloned that
/// should mean that it's being squirreled away because the future it was called with is
/// going to return `Pending`.
#[cfg(feature = "capture-awaits")]
fn capture_await_waker(waker: &std::task::Waker, capture_await: fn()) -> std::task::Waker {
use std::mem;
use std::sync::Arc;
use std::task::{RawWaker, RawWakerVTable, Waker};
struct WakerData {
inner: std::task::Waker,
capture_await: fn(),
}
// by requiring `capture_await` to be a regular fn pointer, we don't need to leak an
// allocation for our VTable
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);
unsafe fn clone(data: *const ()) -> RawWaker {
// SAFETY: pointer must be the right type and Arc must not be dropped here without cloning.
let data = Arc::from_raw(data as *const WakerData);
(data.capture_await)();
let cloned = data.clone();
mem::forget(data);
raw_waker(cloned)
}
unsafe fn wake(data: *const ()) {
// SAFETY: pointer must be the right type
// LEAK SAFETY: `Arc` *must* be dropped here
let data = Arc::from_raw(data as *const WakerData);
data.inner.wake_by_ref();
}
unsafe fn wake_by_ref(data: *const ()) {
// SAFETY: pointer must be the right type and Arc must *NOT* be dropped here
(data as *const WakerData)
.as_ref()
.unwrap()
.inner
.wake_by_ref();
}
unsafe fn drop(data: *const ()) {
// SAFETY: pointer must be the right type
// LEAK SAFETY: `Arc` *must* be dropped here
let _ = Arc::from_raw(data as *const WakerData);
}
fn raw_waker(data: Arc<WakerData>) -> RawWaker {
// SAFETY: Arc must not be dropped here; be sure to use Arc::into_raw
RawWaker::new(Arc::into_raw(data) as *const (), &VTABLE)
}
unsafe {
// SAFETY: verified above
Waker::from_raw(raw_waker(Arc::new(WakerData {
inner: waker.clone(),
capture_await,
})))
}
}
#[cfg(feature = "capture-awaits")]
async fn capture_awaits<Fut>(fut: Fut, capture_await: fn()) -> Fut::Output
where
Fut: futures::Future,
{
use futures::{future::poll_fn, pin_mut, task::Context, Future};
pin_mut!(fut);
poll_fn(move |cx| {
let waker = capture_await_waker(cx.waker(), capture_await);
let mut cx = Context::from_waker(&waker);
fut.as_mut().poll(&mut cx)
})
.await
}