From 5ed11234cc3ae44e971dd4d347ff5aa4e10b4fc4 Mon Sep 17 00:00:00 2001 From: Chayim Refael Friedman Date: Thu, 15 May 2025 23:36:00 +0300 Subject: [PATCH] Improve asm support Including: - Infer `label {}` and `const` operands. - Correctly handle unsafe check inside `label {}`. - Fix an embarrassing parser typo that cause labels to never be part of the AST --- crates/hir-def/src/expr_store.rs | 18 +++--- crates/hir-ty/src/diagnostics/unsafe_check.rs | 58 ++++++++++++++----- crates/hir-ty/src/infer/expr.rs | 36 +++++++----- crates/hir-ty/src/tests/macros.rs | 4 ++ crates/hir-ty/src/tests/simple.rs | 39 +++++++++++++ .../src/handlers/missing_unsafe.rs | 21 +++++++ crates/parser/src/grammar/expressions/atom.rs | 8 ++- crates/parser/test_data/generated/runner.rs | 2 + .../test_data/parser/inline/ok/asm_label.rast | 37 ++++++++++++ .../test_data/parser/inline/ok/asm_label.rs | 3 + 10 files changed, 186 insertions(+), 40 deletions(-) create mode 100644 crates/parser/test_data/parser/inline/ok/asm_label.rast create mode 100644 crates/parser/test_data/parser/inline/ok/asm_label.rs diff --git a/crates/hir-def/src/expr_store.rs b/crates/hir-def/src/expr_store.rs index 09ee286f5c..f617c3225a 100644 --- a/crates/hir-def/src/expr_store.rs +++ b/crates/hir-def/src/expr_store.rs @@ -298,17 +298,16 @@ impl ExpressionStore { Expr::InlineAsm(it) => it.operands.iter().for_each(|(_, op)| match op { AsmOperand::In { expr, .. } | AsmOperand::Out { expr: Some(expr), .. } - | AsmOperand::InOut { expr, .. } => f(*expr), + | AsmOperand::InOut { expr, .. } + | AsmOperand::Const(expr) + | AsmOperand::Label(expr) => f(*expr), AsmOperand::SplitInOut { in_expr, out_expr, .. } => { f(*in_expr); if let Some(out_expr) = out_expr { f(*out_expr); } } - AsmOperand::Out { expr: None, .. } - | AsmOperand::Const(_) - | AsmOperand::Label(_) - | AsmOperand::Sym(_) => (), + AsmOperand::Out { expr: None, .. } | AsmOperand::Sym(_) => (), }), Expr::If { condition, then_branch, else_branch } => { f(*condition); @@ -435,17 +434,16 @@ impl ExpressionStore { Expr::InlineAsm(it) => it.operands.iter().for_each(|(_, op)| match op { AsmOperand::In { expr, .. } | AsmOperand::Out { expr: Some(expr), .. } - | AsmOperand::InOut { expr, .. } => f(*expr), + | AsmOperand::InOut { expr, .. } + | AsmOperand::Const(expr) + | AsmOperand::Label(expr) => f(*expr), AsmOperand::SplitInOut { in_expr, out_expr, .. } => { f(*in_expr); if let Some(out_expr) = out_expr { f(*out_expr); } } - AsmOperand::Out { expr: None, .. } - | AsmOperand::Const(_) - | AsmOperand::Label(_) - | AsmOperand::Sym(_) => (), + AsmOperand::Out { expr: None, .. } | AsmOperand::Sym(_) => (), }), Expr::If { condition, then_branch, else_branch } => { f(*condition); diff --git a/crates/hir-ty/src/diagnostics/unsafe_check.rs b/crates/hir-ty/src/diagnostics/unsafe_check.rs index 6d5aaa3472..20cf3c7811 100644 --- a/crates/hir-ty/src/diagnostics/unsafe_check.rs +++ b/crates/hir-ty/src/diagnostics/unsafe_check.rs @@ -7,7 +7,7 @@ use either::Either; use hir_def::{ AdtId, DefWithBodyId, FieldId, FunctionId, VariantId, expr_store::{Body, path::Path}, - hir::{Expr, ExprId, ExprOrPatId, Pat, PatId, Statement, UnaryOp}, + hir::{AsmOperand, Expr, ExprId, ExprOrPatId, Pat, PatId, Statement, UnaryOp}, resolver::{HasResolver, ResolveValueResult, Resolver, ValueNs}, signatures::StaticFlags, type_ref::Rawness, @@ -199,6 +199,17 @@ impl<'db> UnsafeVisitor<'db> { } } + fn with_inside_unsafe_block( + &mut self, + inside_unsafe_block: InsideUnsafeBlock, + f: impl FnOnce(&mut Self) -> R, + ) -> R { + let old = mem::replace(&mut self.inside_unsafe_block, inside_unsafe_block); + let result = f(self); + self.inside_unsafe_block = old; + result + } + fn walk_pats_top(&mut self, pats: impl Iterator, parent_expr: ExprId) { let guard = self.resolver.update_to_inner_scope(self.db, self.def, parent_expr); pats.for_each(|pat| self.walk_pat(pat)); @@ -303,7 +314,29 @@ impl<'db> UnsafeVisitor<'db> { self.walk_pats_top(std::iter::once(target), current); self.inside_assignment = old_inside_assignment; } - Expr::InlineAsm(_) => self.on_unsafe_op(current.into(), UnsafetyReason::InlineAsm), + Expr::InlineAsm(asm) => { + self.on_unsafe_op(current.into(), UnsafetyReason::InlineAsm); + asm.operands.iter().for_each(|(_, op)| match op { + AsmOperand::In { expr, .. } + | AsmOperand::Out { expr: Some(expr), .. } + | AsmOperand::InOut { expr, .. } + | AsmOperand::Const(expr) => self.walk_expr(*expr), + AsmOperand::SplitInOut { in_expr, out_expr, .. } => { + self.walk_expr(*in_expr); + if let Some(out_expr) = out_expr { + self.walk_expr(*out_expr); + } + } + AsmOperand::Out { expr: None, .. } | AsmOperand::Sym(_) => (), + AsmOperand::Label(expr) => { + // Inline asm labels are considered safe even when inside unsafe blocks. + self.with_inside_unsafe_block(InsideUnsafeBlock::No, |this| { + this.walk_expr(*expr) + }); + } + }); + return; + } // rustc allows union assignment to propagate through field accesses and casts. Expr::Cast { .. } => self.inside_assignment = inside_assignment, Expr::Field { .. } => { @@ -317,17 +350,16 @@ impl<'db> UnsafeVisitor<'db> { } } Expr::Unsafe { statements, .. } => { - let old_inside_unsafe_block = - mem::replace(&mut self.inside_unsafe_block, InsideUnsafeBlock::Yes); - self.walk_pats_top( - statements.iter().filter_map(|statement| match statement { - &Statement::Let { pat, .. } => Some(pat), - _ => None, - }), - current, - ); - self.body.walk_child_exprs_without_pats(current, |child| self.walk_expr(child)); - self.inside_unsafe_block = old_inside_unsafe_block; + self.with_inside_unsafe_block(InsideUnsafeBlock::Yes, |this| { + this.walk_pats_top( + statements.iter().filter_map(|statement| match statement { + &Statement::Let { pat, .. } => Some(pat), + _ => None, + }), + current, + ); + this.body.walk_child_exprs_without_pats(current, |child| this.walk_expr(child)); + }); return; } Expr::Block { statements, .. } | Expr::Async { statements, .. } => { diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index 8084b394d0..87b7f3406f 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -959,8 +959,8 @@ impl InferenceContext<'_> { } Expr::OffsetOf(_) => TyKind::Scalar(Scalar::Uint(UintTy::Usize)).intern(Interner), Expr::InlineAsm(asm) => { - let mut check_expr_asm_operand = |expr, is_input: bool| { - let ty = self.infer_expr_no_expect(expr, ExprIsRead::Yes); + let check_expr_asm_operand = |this: &mut Self, expr, is_input: bool| { + let ty = this.infer_expr_no_expect(expr, ExprIsRead::Yes); // If this is an input value, we require its type to be fully resolved // at this point. This allows us to provide helpful coercions which help @@ -970,18 +970,18 @@ impl InferenceContext<'_> { // allows them to be inferred based on how they are used later in the // function. if is_input { - let ty = self.resolve_ty_shallow(&ty); + let ty = this.resolve_ty_shallow(&ty); match ty.kind(Interner) { TyKind::FnDef(def, parameters) => { let fnptr_ty = TyKind::Function( - CallableSig::from_def(self.db, *def, parameters).to_fn_ptr(), + CallableSig::from_def(this.db, *def, parameters).to_fn_ptr(), ) .intern(Interner); - _ = self.coerce(Some(expr), &ty, &fnptr_ty, CoerceNever::Yes); + _ = this.coerce(Some(expr), &ty, &fnptr_ty, CoerceNever::Yes); } TyKind::Ref(mutbl, _, base_ty) => { let ptr_ty = TyKind::Raw(*mutbl, base_ty.clone()).intern(Interner); - _ = self.coerce(Some(expr), &ty, &ptr_ty, CoerceNever::Yes); + _ = this.coerce(Some(expr), &ty, &ptr_ty, CoerceNever::Yes); } _ => {} } @@ -990,22 +990,28 @@ impl InferenceContext<'_> { let diverge = asm.options.contains(AsmOptions::NORETURN); asm.operands.iter().for_each(|(_, operand)| match *operand { - AsmOperand::In { expr, .. } => check_expr_asm_operand(expr, true), + AsmOperand::In { expr, .. } => check_expr_asm_operand(self, expr, true), AsmOperand::Out { expr: Some(expr), .. } | AsmOperand::InOut { expr, .. } => { - check_expr_asm_operand(expr, false) + check_expr_asm_operand(self, expr, false) } AsmOperand::Out { expr: None, .. } => (), AsmOperand::SplitInOut { in_expr, out_expr, .. } => { - check_expr_asm_operand(in_expr, true); + check_expr_asm_operand(self, in_expr, true); if let Some(out_expr) = out_expr { - check_expr_asm_operand(out_expr, false); + check_expr_asm_operand(self, out_expr, false); } } - // FIXME - AsmOperand::Label(_) => (), - // FIXME - AsmOperand::Const(_) => (), - // FIXME + AsmOperand::Label(expr) => { + self.infer_expr( + expr, + &Expectation::HasType(self.result.standard_types.unit.clone()), + ExprIsRead::No, + ); + } + AsmOperand::Const(expr) => { + self.infer_expr(expr, &Expectation::None, ExprIsRead::No); + } + // FIXME: `sym` should report for things that are not functions or statics. AsmOperand::Sym(_) => (), }); if diverge { diff --git a/crates/hir-ty/src/tests/macros.rs b/crates/hir-ty/src/tests/macros.rs index 446f0b21a2..ea7a113cae 100644 --- a/crates/hir-ty/src/tests/macros.rs +++ b/crates/hir-ty/src/tests/macros.rs @@ -1505,6 +1505,10 @@ fn main() { !119..120 'o': i32 293..294 'o': i32 308..317 'thread_id': usize + !314..320 'OffPtr': usize + !333..338 'OffFn': usize + !354..355 '0': i32 + !371..382 'MEM_RELEASE': usize "#]], ) } diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs index 08127eeb46..cf51671afb 100644 --- a/crates/hir-ty/src/tests/simple.rs +++ b/crates/hir-ty/src/tests/simple.rs @@ -3926,3 +3926,42 @@ fn foo() { "#]], ); } + +#[test] +fn asm_const_label() { + check_infer( + r#" +//- minicore: asm +const fn bar() -> i32 { 123 } +fn baz(s: &str) {} + +fn foo() { + unsafe { + core::arch::asm!( + "mov eax, {}", + "jmp {}", + const bar(), + label { + baz("hello"); + }, + ); + } +} + "#, + expect![[r#" + 22..29 '{ 123 }': i32 + 24..27 '123': i32 + 37..38 's': &'? str + 46..48 '{}': () + !0..68 'builti...");},)': () + !40..43 'bar': fn bar() -> i32 + !40..45 'bar()': i32 + !51..66 '{baz("hello");}': () + !52..55 'baz': fn baz(&'? str) + !52..64 'baz("hello")': () + !56..63 '"hello"': &'static str + 59..257 '{ ... } }': () + 65..255 'unsafe... }': () + "#]], + ); +} diff --git a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs index ff041e183b..364bead34e 100644 --- a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs +++ b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs @@ -894,4 +894,25 @@ fn main() { "#, ); } + + #[test] + fn asm_label() { + check_diagnostics( + r#" +//- minicore: asm +fn foo() { + unsafe { + core::arch::asm!( + "jmp {}", + label { + let p = 0xDEADBEAF as *mut u8; + *p = 3; + // ^^ error: dereference of raw pointer is unsafe and requires an unsafe function or block + }, + ); + } +} + "#, + ); + } } diff --git a/crates/parser/src/grammar/expressions/atom.rs b/crates/parser/src/grammar/expressions/atom.rs index 5faf6fc275..8cc332d463 100644 --- a/crates/parser/src/grammar/expressions/atom.rs +++ b/crates/parser/src/grammar/expressions/atom.rs @@ -381,10 +381,14 @@ fn parse_asm_expr(p: &mut Parser<'_>, m: Marker) -> Option { op.complete(p, ASM_REG_OPERAND); op_n.complete(p, ASM_OPERAND_NAMED); } else if p.eat_contextual_kw(T![label]) { + // test asm_label + // fn foo() { + // builtin#asm("", label {}); + // } dir_spec.abandon(p); block_expr(p); - op.complete(p, ASM_OPERAND_NAMED); - op_n.complete(p, ASM_LABEL); + op.complete(p, ASM_LABEL); + op_n.complete(p, ASM_OPERAND_NAMED); } else if p.eat(T![const]) { dir_spec.abandon(p); expr(p); diff --git a/crates/parser/test_data/generated/runner.rs b/crates/parser/test_data/generated/runner.rs index 24db9478ee..030d8e0f04 100644 --- a/crates/parser/test_data/generated/runner.rs +++ b/crates/parser/test_data/generated/runner.rs @@ -21,6 +21,8 @@ mod ok { #[test] fn asm_expr() { run_and_expect_no_errors("test_data/parser/inline/ok/asm_expr.rs"); } #[test] + fn asm_label() { run_and_expect_no_errors("test_data/parser/inline/ok/asm_label.rs"); } + #[test] fn assoc_const_eq() { run_and_expect_no_errors("test_data/parser/inline/ok/assoc_const_eq.rs"); } diff --git a/crates/parser/test_data/parser/inline/ok/asm_label.rast b/crates/parser/test_data/parser/inline/ok/asm_label.rast new file mode 100644 index 0000000000..38999c9cd3 --- /dev/null +++ b/crates/parser/test_data/parser/inline/ok/asm_label.rast @@ -0,0 +1,37 @@ +SOURCE_FILE + FN + FN_KW "fn" + WHITESPACE " " + NAME + IDENT "foo" + PARAM_LIST + L_PAREN "(" + R_PAREN ")" + WHITESPACE " " + BLOCK_EXPR + STMT_LIST + L_CURLY "{" + WHITESPACE "\n " + EXPR_STMT + ASM_EXPR + BUILTIN_KW "builtin" + POUND "#" + ASM_KW "asm" + L_PAREN "(" + LITERAL + STRING "\"\"" + COMMA "," + WHITESPACE " " + ASM_OPERAND_NAMED + ASM_LABEL + LABEL_KW "label" + WHITESPACE " " + BLOCK_EXPR + STMT_LIST + L_CURLY "{" + R_CURLY "}" + R_PAREN ")" + SEMICOLON ";" + WHITESPACE "\n" + R_CURLY "}" + WHITESPACE "\n" diff --git a/crates/parser/test_data/parser/inline/ok/asm_label.rs b/crates/parser/test_data/parser/inline/ok/asm_label.rs new file mode 100644 index 0000000000..996c1c8477 --- /dev/null +++ b/crates/parser/test_data/parser/inline/ok/asm_label.rs @@ -0,0 +1,3 @@ +fn foo() { + builtin#asm("", label {}); +}