mirror of
https://github.com/embassy-rs/embassy.git
synced 2025-10-02 14:44:32 +00:00
Merge pull request #4443 from Brezak/task-unsafe
executor: mark unsafe tasks as unsafe
This commit is contained in:
commit
f3cb0f3c30
@ -5,7 +5,7 @@ use darling::FromMeta;
|
|||||||
use proc_macro2::{Span, TokenStream};
|
use proc_macro2::{Span, TokenStream};
|
||||||
use quote::{format_ident, quote};
|
use quote::{format_ident, quote};
|
||||||
use syn::visit::{self, Visit};
|
use syn::visit::{self, Visit};
|
||||||
use syn::{Expr, ExprLit, Lit, LitInt, ReturnType, Type};
|
use syn::{Expr, ExprLit, Lit, LitInt, ReturnType, Type, Visibility};
|
||||||
|
|
||||||
use crate::util::*;
|
use crate::util::*;
|
||||||
|
|
||||||
@ -112,13 +112,11 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let task_ident = f.sig.ident.clone();
|
// Copy the generics + where clause to avoid more spurious errors.
|
||||||
let task_inner_ident = format_ident!("__{}_task", task_ident);
|
let generics = &f.sig.generics;
|
||||||
|
let where_clause = &f.sig.generics.where_clause;
|
||||||
let mut task_inner = f.clone();
|
let unsafety = &f.sig.unsafety;
|
||||||
let visibility = task_inner.vis.clone();
|
let visibility = &f.vis;
|
||||||
task_inner.vis = syn::Visibility::Inherited;
|
|
||||||
task_inner.sig.ident = task_inner_ident.clone();
|
|
||||||
|
|
||||||
// assemble the original input arguments,
|
// assemble the original input arguments,
|
||||||
// including any attributes that may have
|
// including any attributes that may have
|
||||||
@ -131,6 +129,64 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let task_ident = f.sig.ident.clone();
|
||||||
|
let task_inner_ident = format_ident!("__{}_task", task_ident);
|
||||||
|
|
||||||
|
let task_inner_future_output = match &f.sig.output {
|
||||||
|
ReturnType::Default => quote! {-> impl ::core::future::Future<Output = ()>},
|
||||||
|
// Special case the never type since we can't stuff it into a `impl Future<Output = !>`
|
||||||
|
ReturnType::Type(arrow, maybe_never)
|
||||||
|
if f.sig.asyncness.is_some() && matches!(**maybe_never, Type::Never(_)) =>
|
||||||
|
{
|
||||||
|
quote! {
|
||||||
|
#arrow impl ::core::future::Future<Output=#embassy_executor::_export::Never>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ReturnType::Type(arrow, maybe_never) if matches!(**maybe_never, Type::Never(_)) => quote! {
|
||||||
|
#arrow #maybe_never
|
||||||
|
},
|
||||||
|
// Grab the arrow span, why not
|
||||||
|
ReturnType::Type(arrow, typ) if f.sig.asyncness.is_some() => quote! {
|
||||||
|
#arrow impl ::core::future::Future<Output = #typ>
|
||||||
|
},
|
||||||
|
// We assume that if `f` isn't async, it must return `-> impl Future<...>`
|
||||||
|
// This is checked using traits later
|
||||||
|
ReturnType::Type(arrow, typ) => quote! {
|
||||||
|
#arrow #typ
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// We have to rename the function since it might be recursive;
|
||||||
|
let mut task_inner_function = f.clone();
|
||||||
|
let task_inner_function_ident = format_ident!("__{}_task_inner_function", task_ident);
|
||||||
|
task_inner_function.sig.ident = task_inner_function_ident.clone();
|
||||||
|
task_inner_function.vis = Visibility::Inherited;
|
||||||
|
|
||||||
|
let task_inner_body = if errors.is_empty() {
|
||||||
|
quote! {
|
||||||
|
#task_inner_function
|
||||||
|
|
||||||
|
// SAFETY: All the preconditions to `#task_ident` apply to
|
||||||
|
// all contexts `#task_inner_ident` is called in
|
||||||
|
#unsafety {
|
||||||
|
#task_inner_function_ident(#(#full_args,)*)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! {
|
||||||
|
async {::core::todo!()}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let task_inner = quote! {
|
||||||
|
#visibility fn #task_inner_ident #generics (#fargs)
|
||||||
|
#task_inner_future_output
|
||||||
|
#where_clause
|
||||||
|
{
|
||||||
|
#task_inner_body
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let spawn = if returns_impl_trait {
|
let spawn = if returns_impl_trait {
|
||||||
quote!(spawn)
|
quote!(spawn)
|
||||||
} else {
|
} else {
|
||||||
@ -173,7 +229,7 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
|
|||||||
unsafe { __task_pool_get(#task_inner_ident).#spawn(move || #task_inner_ident(#(#full_args,)*)) }
|
unsafe { __task_pool_get(#task_inner_ident).#spawn(move || #task_inner_ident(#(#full_args,)*)) }
|
||||||
};
|
};
|
||||||
|
|
||||||
let task_outer_attrs = task_inner.attrs.clone();
|
let task_outer_attrs = &f.attrs;
|
||||||
|
|
||||||
if !errors.is_empty() {
|
if !errors.is_empty() {
|
||||||
task_outer_body = quote! {
|
task_outer_body = quote! {
|
||||||
@ -183,10 +239,6 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy the generics + where clause to avoid more spurious errors.
|
|
||||||
let generics = &f.sig.generics;
|
|
||||||
let where_clause = &f.sig.generics.where_clause;
|
|
||||||
|
|
||||||
let result = quote! {
|
let result = quote! {
|
||||||
// This is the user's task function, renamed.
|
// This is the user's task function, renamed.
|
||||||
// We put it outside the #task_ident fn below, because otherwise
|
// We put it outside the #task_ident fn below, because otherwise
|
||||||
@ -196,7 +248,7 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
|
|||||||
#task_inner
|
#task_inner
|
||||||
|
|
||||||
#(#task_outer_attrs)*
|
#(#task_outer_attrs)*
|
||||||
#visibility fn #task_ident #generics (#fargs) -> #embassy_executor::SpawnToken<impl Sized> #where_clause{
|
#visibility #unsafety fn #task_ident #generics (#fargs) -> #embassy_executor::SpawnToken<impl Sized> #where_clause{
|
||||||
#task_outer_body
|
#task_outer_body
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -213,7 +265,7 @@ fn check_arg_ty(errors: &mut TokenStream, ty: &Type) {
|
|||||||
|
|
||||||
impl<'a, 'ast> Visit<'ast> for Visitor<'a> {
|
impl<'a, 'ast> Visit<'ast> for Visitor<'a> {
|
||||||
fn visit_type_reference(&mut self, i: &'ast syn::TypeReference) {
|
fn visit_type_reference(&mut self, i: &'ast syn::TypeReference) {
|
||||||
// only check for elided lifetime here. If not elided, it's checked by `visit_lifetime`.
|
// Only check for elided lifetime here. If not elided, it's checked by `visit_lifetime`.
|
||||||
if i.lifetime.is_none() {
|
if i.lifetime.is_none() {
|
||||||
error(
|
error(
|
||||||
self.errors,
|
self.errors,
|
||||||
|
@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
- Added support for `-> impl Future<Output = ()>` in `#[task]`
|
- Added support for `-> impl Future<Output = ()>` in `#[task]`
|
||||||
- Fixed `Send` unsoundness with `-> impl Future` tasks
|
- Fixed `Send` unsoundness with `-> impl Future` tasks
|
||||||
- Marked `Spawner::for_current_executor` as `unsafe`
|
- Marked `Spawner::for_current_executor` as `unsafe`
|
||||||
|
- `#[task]` now properly marks the generated function as unsafe if the task is marked unsafe
|
||||||
|
|
||||||
## 0.7.0 - 2025-01-02
|
## 0.7.0 - 2025-01-02
|
||||||
|
|
||||||
|
@ -216,7 +216,7 @@ pub mod _export {
|
|||||||
);
|
);
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
trait HasOutput {
|
pub trait HasOutput {
|
||||||
type Output;
|
type Output;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -225,7 +225,7 @@ pub mod _export {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
type Never = <fn() -> ! as HasOutput>::Output;
|
pub type Never = <fn() -> ! as HasOutput>::Output;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implementation details for embassy macros.
|
/// Implementation details for embassy macros.
|
||||||
@ -242,7 +242,7 @@ pub mod _export {
|
|||||||
impl TaskReturnValue for Never {}
|
impl TaskReturnValue for Never {}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
trait HasOutput {
|
pub trait HasOutput {
|
||||||
type Output;
|
type Output;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,5 +251,5 @@ pub mod _export {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
type Never = <fn() -> ! as HasOutput>::Output;
|
pub type Never = <fn() -> ! as HasOutput>::Output;
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex};
|
|||||||
use std::task::Poll;
|
use std::task::Poll;
|
||||||
|
|
||||||
use embassy_executor::raw::Executor;
|
use embassy_executor::raw::Executor;
|
||||||
use embassy_executor::task;
|
use embassy_executor::{task, Spawner};
|
||||||
|
|
||||||
#[export_name = "__pender"]
|
#[export_name = "__pender"]
|
||||||
fn __pender(context: *mut ()) {
|
fn __pender(context: *mut ()) {
|
||||||
@ -317,3 +317,12 @@ fn executor_task_cfg_args() {
|
|||||||
let (_, _, _) = (a, b, c);
|
let (_, _, _) = (a, b, c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn recursive_task() {
|
||||||
|
#[embassy_executor::task(pool_size = 2)]
|
||||||
|
async fn task1() {
|
||||||
|
let spawner = unsafe { Spawner::for_current_executor().await };
|
||||||
|
spawner.spawn(task1());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -32,4 +32,7 @@ fn ui() {
|
|||||||
t.compile_fail("tests/ui/self.rs");
|
t.compile_fail("tests/ui/self.rs");
|
||||||
t.compile_fail("tests/ui/type_error.rs");
|
t.compile_fail("tests/ui/type_error.rs");
|
||||||
t.compile_fail("tests/ui/where_clause.rs");
|
t.compile_fail("tests/ui/where_clause.rs");
|
||||||
|
t.compile_fail("tests/ui/unsafe_op_in_unsafe_task.rs");
|
||||||
|
|
||||||
|
t.pass("tests/ui/task_safety_attribute.rs");
|
||||||
}
|
}
|
||||||
|
@ -7,4 +7,4 @@ error[E0277]: task futures must resolve to `()` or `!`
|
|||||||
= note: use `async fn` or change the return type to `impl Future<Output = ()>`
|
= note: use `async fn` or change the return type to `impl Future<Output = ()>`
|
||||||
= help: the following other types implement trait `TaskReturnValue`:
|
= help: the following other types implement trait `TaskReturnValue`:
|
||||||
()
|
()
|
||||||
<fn() -> ! as _export::HasOutput>::Output
|
<fn() -> ! as HasOutput>::Output
|
||||||
|
@ -8,3 +8,17 @@ help: indicate the anonymous lifetime
|
|||||||
|
|
|
|
||||||
6 | async fn task(_x: Foo<'_>) {}
|
6 | async fn task(_x: Foo<'_>) {}
|
||||||
| ++++
|
| ++++
|
||||||
|
|
||||||
|
error[E0700]: hidden type for `impl Sized` captures lifetime that does not appear in bounds
|
||||||
|
--> tests/ui/nonstatic_struct_elided.rs:5:1
|
||||||
|
|
|
||||||
|
5 | #[embassy_executor::task]
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^ opaque type defined here
|
||||||
|
6 | async fn task(_x: Foo) {}
|
||||||
|
| --- hidden type `impl Sized` captures the anonymous lifetime defined here
|
||||||
|
|
|
||||||
|
= note: this error originates in the attribute macro `embassy_executor::task` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||||
|
help: add a `use<...>` bound to explicitly capture `'_`
|
||||||
|
|
|
||||||
|
5 | #[embassy_executor::task] + use<'_>
|
||||||
|
| +++++++++
|
||||||
|
25
embassy-executor/tests/ui/task_safety_attribute.rs
Normal file
25
embassy-executor/tests/ui/task_safety_attribute.rs
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
#![cfg_attr(feature = "nightly", feature(impl_trait_in_assoc_type))]
|
||||||
|
#![deny(unused_unsafe)]
|
||||||
|
|
||||||
|
use std::mem;
|
||||||
|
|
||||||
|
#[embassy_executor::task]
|
||||||
|
async fn safe() {}
|
||||||
|
|
||||||
|
#[embassy_executor::task]
|
||||||
|
async unsafe fn not_safe() {}
|
||||||
|
|
||||||
|
#[export_name = "__pender"]
|
||||||
|
fn pender(_: *mut ()) {
|
||||||
|
// The test doesn't link if we don't include this.
|
||||||
|
// We never call this anyway.
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let _forget_me = safe();
|
||||||
|
// SAFETY: not_safe has not safety preconditions
|
||||||
|
let _forget_me2 = unsafe { not_safe() };
|
||||||
|
|
||||||
|
mem::forget(_forget_me);
|
||||||
|
mem::forget(_forget_me2);
|
||||||
|
}
|
10
embassy-executor/tests/ui/unsafe_op_in_unsafe_task.rs
Normal file
10
embassy-executor/tests/ui/unsafe_op_in_unsafe_task.rs
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
#![cfg_attr(feature = "nightly", feature(impl_trait_in_assoc_type))]
|
||||||
|
#![deny(unsafe_op_in_unsafe_fn)]
|
||||||
|
|
||||||
|
#[embassy_executor::task]
|
||||||
|
async unsafe fn task() {
|
||||||
|
let x = 5;
|
||||||
|
(&x as *const i32).read();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
18
embassy-executor/tests/ui/unsafe_op_in_unsafe_task.stderr
Normal file
18
embassy-executor/tests/ui/unsafe_op_in_unsafe_task.stderr
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
error[E0133]: call to unsafe function `std::ptr::const_ptr::<impl *const T>::read` is unsafe and requires unsafe block
|
||||||
|
--> tests/ui/unsafe_op_in_unsafe_task.rs:7:5
|
||||||
|
|
|
||||||
|
7 | (&x as *const i32).read();
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^^^^ call to unsafe function
|
||||||
|
|
|
||||||
|
= note: for more information, see <https://doc.rust-lang.org/nightly/edition-guide/rust-2024/unsafe-op-in-unsafe-fn.html>
|
||||||
|
= note: consult the function's documentation for information on how to avoid undefined behavior
|
||||||
|
note: an unsafe function restricts its caller, but its body is safe by default
|
||||||
|
--> tests/ui/unsafe_op_in_unsafe_task.rs:5:1
|
||||||
|
|
|
||||||
|
5 | async unsafe fn task() {
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
note: the lint level is defined here
|
||||||
|
--> tests/ui/unsafe_op_in_unsafe_task.rs:2:9
|
||||||
|
|
|
||||||
|
2 | #![deny(unsafe_op_in_unsafe_fn)]
|
||||||
|
| ^^^^^^^^^^^^^^^^^^^^^^
|
Loading…
x
Reference in New Issue
Block a user