attributes: update #[instrument] to support async-trait 0.1.43+ #1228)

It works with both the old and new version of async-trait (except for
one doc test that failed previously, and that works with the new
version). One nice thing is that the code is simpler (e.g.g no self
renaming to _self, which will enable some simplifications in the
future).

A minor nitpick is that I disliked the deeply nested pattern matching in
get_async_trait_kind (previously: get_async_trait_function), so I
"flattened" that a bit.

Fixes  #1219.
This commit is contained in:
Simon THOBY
2021-03-10 21:11:15 +01:00
committed by GitHub
parent b9f722ff71
commit 2e4a4f367d
3 changed files with 257 additions and 173 deletions

View File

@@ -34,15 +34,14 @@ proc-macro = true
[dependencies]
proc-macro2 = "1"
syn = { version = "1", default-features = false, features = ["full", "parsing", "printing", "visit-mut", "clone-impls", "extra-traits", "proc-macro"] }
syn = { version = "1", default-features = false, features = ["full", "parsing", "printing", "visit", "visit-mut", "clone-impls", "extra-traits", "proc-macro"] }
quote = "1"
[dev-dependencies]
tracing = { path = "../tracing", version = "0.2" }
tokio-test = { version = "0.2.0" }
tracing-core = { path = "../tracing-core", version = "0.2"}
async-trait = "0.1"
async-trait = "0.1.44"
[badges]
maintenance = { status = "experimental" }

View File

@@ -74,7 +74,6 @@
patterns_in_fns_without_body,
private_in_public,
unconditional_recursion,
unused,
unused_allocation,
unused_comparisons,
unused_parens,
@@ -89,9 +88,9 @@ use quote::{quote, quote_spanned, ToTokens};
use syn::ext::IdentExt as _;
use syn::parse::{Parse, ParseStream};
use syn::{
punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprCall, FieldPat, FnArg, Ident, Item,
ItemFn, LitInt, LitStr, Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct,
PatType, Path, Signature, Stmt, Token,
punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprAsync, ExprCall, FieldPat, FnArg,
Ident, Item, ItemFn, LitInt, LitStr, Pat, PatIdent, PatReference, PatStruct, PatTuple,
PatTupleStruct, PatType, Path, Signature, Stmt, Token, TypePath,
};
/// Instruments a function to create and enter a `tracing` [span] every time
/// the function is called.
@@ -221,11 +220,12 @@ use syn::{
/// }
/// ```
///
/// An interesting note on this subject is that references to the `Self`
/// type inside the `fields` argument are only allowed when the instrumented
/// function is a method aka. the function receives `self` as an argument.
/// For example, this *will not work* because it doesn't receive `self`:
/// ```compile_fail
/// Note than on `async-trait` <= 0.1.43, references to the `Self`
/// type inside the `fields` argument were only allowed when the instrumented
/// function is a method (i.e., the function receives `self` as an argument).
/// For example, this *used to not work* because the instrument function
/// didn't receive `self`:
/// ```
/// # use tracing::instrument;
/// use async_trait::async_trait;
///
@@ -244,7 +244,8 @@ use syn::{
/// }
/// ```
/// Instead, you should manually rewrite any `Self` types as the type for
/// which you implement the trait: `#[instrument(fields(tmp = std::any::type_name::<Bar>()))]`.
/// which you implement the trait: `#[instrument(fields(tmp = std::any::type_name::<Bar>()))]`
/// (or maybe you can just bump `async-trait`).
///
/// [span]: https://docs.rs/tracing/latest/tracing/span/index.html
/// [`tracing`]: https://github.com/tokio-rs/tracing
@@ -254,30 +255,47 @@ pub fn instrument(
args: proc_macro::TokenStream,
item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let input: ItemFn = syn::parse_macro_input!(item as ItemFn);
let input = syn::parse_macro_input!(item as ItemFn);
let args = syn::parse_macro_input!(args as InstrumentArgs);
let instrumented_function_name = input.sig.ident.to_string();
// check for async_trait-like patterns in the block and wrap the
// internal function with Instrument instead of wrapping the
// async_trait generated wrapper
// check for async_trait-like patterns in the block, and instrument
// the future instead of the wrapper
if let Some(internal_fun) = get_async_trait_info(&input.block, input.sig.asyncness.is_some()) {
// let's rewrite some statements!
let mut stmts: Vec<Stmt> = input.block.stmts.to_vec();
for stmt in &mut stmts {
if let Stmt::Item(Item::Fn(fun)) = stmt {
// instrument the function if we considered it as the one we truly want to trace
if fun.sig.ident == internal_fun.name {
*stmt = syn::parse2(gen_body(
fun,
args,
instrumented_function_name,
Some(internal_fun),
))
.unwrap();
break;
let mut out_stmts = Vec::with_capacity(input.block.stmts.len());
for stmt in &input.block.stmts {
if stmt == internal_fun.source_stmt {
match internal_fun.kind {
// async-trait <= 0.1.43
AsyncTraitKind::Function(fun) => {
out_stmts.push(gen_function(
fun,
args,
instrumented_function_name,
internal_fun.self_type,
));
}
// async-trait >= 0.1.44
AsyncTraitKind::Async(async_expr) => {
// fallback if we couldn't find the '__async_trait' binding, might be
// useful for crates exhibiting the same behaviors as async-trait
let instrumented_block = gen_block(
&async_expr.block,
&input.sig.inputs,
true,
args,
instrumented_function_name,
None,
);
let async_attrs = &async_expr.attrs;
out_stmts.push(quote! {
Box::pin(#(#async_attrs) * async move { #instrumented_block })
});
}
}
break;
}
}
@@ -287,20 +305,21 @@ pub fn instrument(
quote!(
#(#attrs) *
#vis #sig {
#(#stmts) *
#(#out_stmts) *
}
)
.into()
} else {
gen_body(&input, args, instrumented_function_name, None).into()
gen_function(&input, args, instrumented_function_name, None).into()
}
}
fn gen_body(
/// Given an existing function, generate an instrumented version of that function
fn gen_function(
input: &ItemFn,
mut args: InstrumentArgs,
args: InstrumentArgs,
instrumented_function_name: String,
async_trait_fun: Option<AsyncTraitInfo>,
self_type: Option<syn::TypePath>,
) -> proc_macro2::TokenStream {
// these are needed ahead of time, as ItemFn contains the function body _and_
// isn't representable inside a quote!/quote_spanned! macro
@@ -330,9 +349,39 @@ fn gen_body(
..
} = sig;
let err = args.err;
let warnings = args.warnings();
let body = gen_block(
block,
params,
asyncness.is_some(),
args,
instrumented_function_name,
self_type,
);
quote!(
#(#attrs) *
#vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #return_type
#where_clause
{
#warnings
#body
}
)
}
/// Instrument a block
fn gen_block(
block: &Block,
params: &Punctuated<FnArg, Token![,]>,
async_context: bool,
mut args: InstrumentArgs,
instrumented_function_name: String,
self_type: Option<syn::TypePath>,
) -> proc_macro2::TokenStream {
let err = args.err;
// generate the span's name
let span_name = args
// did the user override the span's name?
@@ -353,8 +402,8 @@ fn gen_body(
FnArg::Receiver(_) => Box::new(iter::once(Ident::new("self", param.span()))),
})
// Little dance with new (user-exposed) names and old (internal)
// names of identifiers. That way, you can do the following
// even though async_trait rewrite "self" as "_self":
// names of identifiers. That way, we could do the following
// even though async_trait (<=0.1.43) rewrites "self" as "_self":
// ```
// #[async_trait]
// impl Foo for FooImpl {
@@ -363,10 +412,9 @@ fn gen_body(
// }
// ```
.map(|x| {
// if we are inside a function generated by async-trait, we
// should take care to rewrite "_self" as "self" for
// 'user convenience'
if async_trait_fun.is_some() && x == "_self" {
// if we are inside a function generated by async-trait <=0.1.43, we need to
// take care to rewrite "_self" as "self" for 'user convenience'
if self_type.is_some() && x == "_self" {
(Ident::new("self", x.span()), x)
} else {
(x.clone(), x)
@@ -387,7 +435,7 @@ fn gen_body(
// filter out skipped fields
let quoted_fields: Vec<_> = param_names
.into_iter()
.iter()
.filter(|(param, _)| {
if args.skips.contains(param) {
return false;
@@ -407,13 +455,19 @@ fn gen_body(
.map(|(user_name, real_name)| quote!(#user_name = tracing::field::debug(&#real_name)))
.collect();
// when async-trait is in use, replace instances of "self" with "_self" inside the fields values
if let (Some(ref async_trait_fun), Some(Fields(ref mut fields))) =
(async_trait_fun, &mut args.fields)
{
let mut replacer = SelfReplacer {
ty: async_trait_fun.self_type.clone(),
// replace every use of a variable with its original name
if let Some(Fields(ref mut fields)) = args.fields {
let mut replacer = IdentAndTypesRenamer {
idents: param_names,
types: Vec::new(),
};
// when async-trait <=0.1.43 is in use, replace instances
// of the "Self" type inside the fields values
if let Some(self_type) = self_type {
replacer.types.push(("Self", self_type));
}
for e in fields.iter_mut().filter_map(|f| f.value.as_mut()) {
syn::visit_mut::visit_expr_mut(&mut replacer, e);
}
@@ -436,9 +490,9 @@ fn gen_body(
// which is `instrument`ed using `tracing-futures`. Otherwise, this will
// enter the span and then perform the rest of the body.
// If `err` is in args, instrument any resulting `Err`s.
let body = if asyncness.is_some() {
if async_context {
if err {
quote_spanned! {block.span()=>
quote_spanned!(block.span()=>
let __tracing_attr_span = #span;
tracing::Instrument::instrument(async move {
match async move { #block }.await {
@@ -450,7 +504,7 @@ fn gen_body(
}
}
}, __tracing_attr_span).await
}
)
} else {
quote_spanned!(block.span()=>
let __tracing_attr_span = #span;
@@ -481,17 +535,7 @@ fn gen_body(
let __tracing_attr_guard = __tracing_attr_span.enter();
#block
)
};
quote!(
#(#attrs) *
#vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #return_type
#where_clause
{
#warnings
#body
}
)
}
}
#[derive(Default, Debug)]
@@ -835,6 +879,20 @@ mod kw {
syn::custom_keyword!(err);
}
enum AsyncTraitKind<'a> {
// old construction. Contains the function
Function(&'a ItemFn),
// new construction. Contains a reference to the async block
Async(&'a ExprAsync),
}
struct AsyncTraitInfo<'a> {
// statement that must be patched
source_stmt: &'a Stmt,
kind: AsyncTraitKind<'a>,
self_type: Option<syn::TypePath>,
}
// Get the AST of the inner function we need to hook, if it was generated
// by async-trait.
// When we are given a function annotated by async-trait, that function
@@ -842,118 +900,122 @@ mod kw {
// user logic, and it is that pinned future that needs to be instrumented.
// Were we to instrument its parent, we would only collect information
// regarding the allocation of that future, and not its own span of execution.
// So we inspect the block of the function to find if it matches the pattern
// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` and we return
// the name `foo` if that is the case. 'gen_body' will then be able
// to use that information to instrument the proper function.
// Depending on the version of async-trait, we inspect the block of the function
// to find if it matches the pattern
// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` (<=0.1.43), or if
// it matches `Box::pin(async move { ... }) (>=0.1.44). We the return the
// statement that must be instrumented, along with some other informations.
// 'gen_body' will then be able to use that information to instrument the
// proper function/future.
// (this follows the approach suggested in
// https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673)
fn get_async_trait_function(block: &Block, block_is_async: bool) -> Option<&ItemFn> {
fn get_async_trait_info(block: &Block, block_is_async: bool) -> Option<AsyncTraitInfo<'_>> {
// are we in an async context? If yes, this isn't a async_trait-like pattern
if block_is_async {
return None;
}
// list of async functions declared inside the block
let mut inside_funs = Vec::new();
// last expression declared in the block (it determines the return
// value of the block, so that if we are working on a function
// whose `trait` or `impl` declaration is annotated by async_trait,
// this is quite likely the point where the future is pinned)
let mut last_expr = None;
// obtain the list of direct internal functions and the last
// expression of the block
for stmt in &block.stmts {
let inside_funs = block.stmts.iter().filter_map(|stmt| {
if let Stmt::Item(Item::Fn(fun)) = &stmt {
// is the function declared as async? If so, this is a good
// candidate, let's keep it in hand
// If the function is async, this is a candidate
if fun.sig.asyncness.is_some() {
inside_funs.push(fun);
}
} else if let Stmt::Expr(e) = &stmt {
last_expr = Some(e);
}
}
// let's play with (too much) pattern matching
// is the last expression a function call?
if let Some(Expr::Call(ExprCall {
func: outside_func,
args: outside_args,
..
})) = last_expr
{
if let Expr::Path(path) = outside_func.as_ref() {
// is it a call to `Box::pin()`?
if "Box::pin" == path_to_string(&path.path) {
// does it takes at least an argument? (if it doesn't,
// it's not gonna compile anyway, but that's no reason
// to (try to) perform an out of bounds access)
if outside_args.is_empty() {
return None;
}
// is the argument to Box::pin a function call itself?
if let Expr::Call(ExprCall { func, .. }) = &outside_args[0] {
if let Expr::Path(inside_path) = func.as_ref() {
// "stringify" the path of the function called
let func_name = path_to_string(&inside_path.path);
// is this function directly defined insided the current block?
for fun in inside_funs {
if fun.sig.ident == func_name {
// we must hook this function now
return Some(fun);
}
}
}
}
return Some((stmt, fun));
}
}
}
None
}
struct AsyncTraitInfo {
name: String,
self_type: Option<syn::TypePath>,
}
// Return the informations necessary to process a function annotated with async-trait.
fn get_async_trait_info(block: &Block, block_is_async: bool) -> Option<AsyncTraitInfo> {
let fun = get_async_trait_function(block, block_is_async)?;
// if "_self" is present as an argument, we store its type to be able to rewrite "Self" (the
// parameter type) with the type of "_self"
let self_type = fun
.sig
.inputs
.iter()
.map(|arg| {
if let FnArg::Typed(ty) = arg {
if let Pat::Ident(PatIdent { ident, .. }) = &*ty.pat {
if ident == "_self" {
let mut ty = &*ty.ty;
// extract the inner type if the argument is "&self" or "&mut self"
if let syn::Type::Reference(syn::TypeReference { elem, .. }) = ty {
ty = &*elem;
}
if let syn::Type::Path(tp) = ty {
return Some(tp.clone());
}
}
}
}
None
});
// last expression of the block (it determines the return value
// of the block, so that if we are working on a function whose
// `trait` or `impl` declaration is annotated by async_trait,
// this is quite likely the point where the future is pinned)
let (last_expr_stmt, last_expr) = block.stmts.iter().rev().find_map(|stmt| {
if let Stmt::Expr(expr) = stmt {
Some((stmt, expr))
} else {
None
})
.next();
let self_type = match self_type {
Some(x) => x,
None => None,
}
})?;
// is the last expression a function call?
let (outside_func, outside_args) = match last_expr {
Expr::Call(ExprCall { func, args, .. }) => (func, args),
_ => return None,
};
// is it a call to `Box::pin()`?
let path = match outside_func.as_ref() {
Expr::Path(path) => &path.path,
_ => return None,
};
if !path_to_string(path).ends_with("Box::pin") {
return None;
}
// Does the call take an argument? If it doesn't,
// it's not gonna compile anyway, but that's no reason
// to (try to) perform an out of bounds access
if outside_args.is_empty() {
return None;
}
// Is the argument to Box::pin an async block that
// captures its arguments?
if let Expr::Async(async_expr) = &outside_args[0] {
// check that the move 'keyword' is present
async_expr.capture?;
return Some(AsyncTraitInfo {
source_stmt: last_expr_stmt,
kind: AsyncTraitKind::Async(async_expr),
self_type: None,
});
}
// Is the argument to Box::pin a function call itself?
let func = match &outside_args[0] {
Expr::Call(ExprCall { func, .. }) => func,
_ => return None,
};
// "stringify" the path of the function called
let func_name = match **func {
Expr::Path(ref func_path) => path_to_string(&func_path.path),
_ => return None,
};
// Was that function defined inside of the current block?
// If so, retrieve the statement where it was declared and the function itself
let (stmt_func_declaration, func) = inside_funs
.into_iter()
.find(|(_, fun)| fun.sig.ident == func_name)?;
// If "_self" is present as an argument, we store its type to be able to rewrite "Self" (the
// parameter type) with the type of "_self"
let mut self_type = None;
for arg in &func.sig.inputs {
if let FnArg::Typed(ty) = arg {
if let Pat::Ident(PatIdent { ref ident, .. }) = *ty.pat {
if ident == "_self" {
let mut ty = *ty.ty.clone();
// extract the inner type if the argument is "&self" or "&mut self"
if let syn::Type::Reference(syn::TypeReference { elem, .. }) = ty {
ty = *elem;
}
if let syn::Type::Path(tp) = ty {
self_type = Some(tp);
break;
}
}
}
}
}
Some(AsyncTraitInfo {
name: fun.sig.ident.to_string(),
source_stmt: stmt_func_declaration,
kind: AsyncTraitKind::Function(func),
self_type,
})
}
@@ -973,26 +1035,48 @@ fn path_to_string(path: &Path) -> String {
res
}
// A visitor struct replacing the "self" and "Self" tokens in user-supplied fields expressions when
// the function is generated by async-trait.
struct SelfReplacer {
ty: Option<syn::TypePath>,
/// A visitor struct to replace idents and types in some piece
/// of code (e.g. the "self" and "Self" tokens in user-supplied
/// fields expressions when the function is generated by an old
/// version of async-trait).
struct IdentAndTypesRenamer<'a> {
types: Vec<(&'a str, TypePath)>,
idents: Vec<(Ident, Ident)>,
}
impl syn::visit_mut::VisitMut for SelfReplacer {
impl<'a> syn::visit_mut::VisitMut for IdentAndTypesRenamer<'a> {
// we deliberately compare strings because we want to ignore the spans
// If we apply clippy's lint, the behavior changes
#[allow(clippy::cmp_owned)]
fn visit_ident_mut(&mut self, id: &mut Ident) {
if id == "self" {
*id = Ident::new("_self", id.span())
for (old_ident, new_ident) in &self.idents {
if id.to_string() == old_ident.to_string() {
*id = new_ident.clone();
}
}
}
fn visit_type_mut(&mut self, ty: &mut syn::Type) {
if let syn::Type::Path(syn::TypePath { ref mut path, .. }) = ty {
if path_to_string(path) == "Self" {
if let Some(ref true_type) = self.ty {
*path = true_type.path.clone();
for (type_name, new_type) in &self.types {
if let syn::Type::Path(TypePath { path, .. }) = ty {
if path_to_string(path) == *type_name {
*ty = syn::Type::Path(new_type.clone());
}
}
}
}
}
// A visitor struct that replace an async block by its patched version
struct AsyncTraitBlockReplacer<'a> {
block: &'a Block,
patched_block: Block,
}
impl<'a> syn::visit_mut::VisitMut for AsyncTraitBlockReplacer<'a> {
fn visit_block_mut(&mut self, i: &mut Block) {
if i == self.block {
*i = self.patched_block.clone();
}
}
}

View File

@@ -172,18 +172,19 @@ fn async_fn_with_async_trait_and_fields_expressions() {
#[async_trait]
impl Test for TestImpl {
// check that self is correctly handled, even when using async_trait
#[instrument(fields(val=self.foo(), test=%v+5))]
async fn call(&mut self, v: usize) {}
#[instrument(fields(val=self.foo(), val2=Self::clone(self).foo(), test=%_v+5))]
async fn call(&mut self, _v: usize) {}
}
let span = span::mock().named("call");
let (collector, handle) = collector::mock()
.new_span(
span.clone().with_field(
field::mock("v")
field::mock("_v")
.with_value(&tracing::field::debug(5))
.and(field::mock("test").with_value(&tracing::field::debug(10)))
.and(field::mock("val").with_value(&42u64)),
.and(field::mock("val").with_value(&42u64))
.and(field::mock("val2").with_value(&42u64)),
),
)
.enter(span.clone())