From 5b202cb6636e29a3b2ca9c479182239d5ce24c7d Mon Sep 17 00:00:00 2001 From: Shoyu Vanilla Date: Sun, 9 Mar 2025 18:42:46 +0900 Subject: [PATCH] fix: Prevent wrong invocations of `needs_parens_in` with non-ancestral "parent"s --- .../src/handlers/apply_demorgan.rs | 14 +++-- .../src/handlers/inline_local_variable.rs | 54 +++++++++++++++- crates/syntax/src/ast/prec.rs | 63 ++++++++++++++++--- docs/book/src/assists_generated.md | 6 +- 4 files changed, 120 insertions(+), 17 deletions(-) diff --git a/crates/ide-assists/src/handlers/apply_demorgan.rs b/crates/ide-assists/src/handlers/apply_demorgan.rs index 77562c588e..67bf8eed23 100644 --- a/crates/ide-assists/src/handlers/apply_demorgan.rs +++ b/crates/ide-assists/src/handlers/apply_demorgan.rs @@ -128,7 +128,9 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti let parent = neg_expr.syntax().parent(); editor = builder.make_editor(neg_expr.syntax()); - if parent.is_some_and(|parent| demorganed.needs_parens_in(&parent)) { + if parent.is_some_and(|parent| { + demorganed.needs_parens_in_place_of(&parent, neg_expr.syntax()) + }) { cov_mark::hit!(demorgan_keep_parens_for_op_precedence2); editor.replace(neg_expr.syntax(), make.expr_paren(demorganed).syntax()); } else { @@ -392,15 +394,19 @@ fn f() { !(S <= S || S < S) } #[test] fn demorgan_keep_pars_for_op_precedence3() { - check_assist(apply_demorgan, "fn f() { (a || !(b &&$0 c); }", "fn f() { (a || !b || !c; }"); + check_assist( + apply_demorgan, + "fn f() { (a || !(b &&$0 c); }", + "fn f() { (a || (!b || !c); }", + ); } #[test] - fn demorgan_removes_pars_in_eq_precedence() { + fn demorgan_keeps_pars_in_eq_precedence() { check_assist( apply_demorgan, "fn() { let x = a && !(!b |$0| !c); }", - "fn() { let x = a && b && c; }", + "fn() { let x = a && (b && c); }", ) } diff --git a/crates/ide-assists/src/handlers/inline_local_variable.rs b/crates/ide-assists/src/handlers/inline_local_variable.rs index 0aa9970a72..36eed290dc 100644 --- a/crates/ide-assists/src/handlers/inline_local_variable.rs +++ b/crates/ide-assists/src/handlers/inline_local_variable.rs @@ -57,12 +57,14 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>) } let usage_node = name_ref.syntax().ancestors().find(|it| ast::PathExpr::can_cast(it.kind())); - let usage_parent_option = usage_node.and_then(|it| it.parent()); + let usage_parent_option = usage_node.as_ref().and_then(|it| it.parent()); let usage_parent = match usage_parent_option { Some(u) => u, None => return Some((name_ref, false)), }; - Some((name_ref, initializer_expr.needs_parens_in(&usage_parent))) + let should_wrap = initializer_expr + .needs_parens_in_place_of(&usage_parent, usage_node.as_ref().unwrap()); + Some((name_ref, should_wrap)) }) .collect::>>()?; @@ -941,6 +943,54 @@ fn main() { fn main() { let _ = (|| 2)(); } +"#, + ); + } + + #[test] + fn test_wrap_in_parens() { + check_assist( + inline_local_variable, + r#" +fn main() { + let $0a = 123 < 456; + let b = !a; +} +"#, + r#" +fn main() { + let b = !(123 < 456); +} +"#, + ); + check_assist( + inline_local_variable, + r#" +trait Foo { + fn foo(&self); +} + +impl Foo for bool { + fn foo(&self) {} +} + +fn main() { + let $0a = 123 < 456; + let b = a.foo(); +} +"#, + r#" +trait Foo { + fn foo(&self); +} + +impl Foo for bool { + fn foo(&self) {} +} + +fn main() { + let b = (123 < 456).foo(); +} "#, ); } diff --git a/crates/syntax/src/ast/prec.rs b/crates/syntax/src/ast/prec.rs index 0c4da76299..4f0e2cad17 100644 --- a/crates/syntax/src/ast/prec.rs +++ b/crates/syntax/src/ast/prec.rs @@ -1,5 +1,7 @@ //! Precedence representation. +use stdx::always; + use crate::{ ast::{self, BinaryOp, Expr, HasArgList, RangeItem}, match_ast, AstNode, SyntaxNode, @@ -140,6 +142,22 @@ pub fn precedence(expr: &ast::Expr) -> ExprPrecedence { } } +fn check_ancestry(ancestor: &SyntaxNode, descendent: &SyntaxNode) -> bool { + let bail = || always!(false, "{} is not an ancestor of {}", ancestor, descendent); + + if !ancestor.text_range().contains_range(descendent.text_range()) { + return bail(); + } + + for anc in descendent.ancestors() { + if anc == *ancestor { + return true; + } + } + + bail() +} + impl Expr { pub fn precedence(&self) -> ExprPrecedence { precedence(self) @@ -153,9 +171,19 @@ impl Expr { /// Returns `true` if `self` would need to be wrapped in parentheses given that its parent is `parent`. pub fn needs_parens_in(&self, parent: &SyntaxNode) -> bool { + self.needs_parens_in_place_of(parent, self.syntax()) + } + + /// Returns `true` if `self` would need to be wrapped in parentheses if it replaces `place_of` + /// given that `place_of`'s parent is `parent`. + pub fn needs_parens_in_place_of(&self, parent: &SyntaxNode, place_of: &SyntaxNode) -> bool { + if !check_ancestry(parent, place_of) { + return false; + } + match_ast! { match parent { - ast::Expr(e) => self.needs_parens_in_expr(&e), + ast::Expr(e) => self.needs_parens_in_expr(&e, place_of), ast::Stmt(e) => self.needs_parens_in_stmt(Some(&e)), ast::StmtList(_) => self.needs_parens_in_stmt(None), ast::ArgList(_) => false, @@ -165,7 +193,7 @@ impl Expr { } } - fn needs_parens_in_expr(&self, parent: &Expr) -> bool { + fn needs_parens_in_expr(&self, parent: &Expr, place_of: &SyntaxNode) -> bool { // Parentheses are necessary when calling a function-like pointer that is a member of a struct or union // (e.g. `(a.f)()`). let is_parent_call_expr = matches!(parent, ast::Expr::CallExpr(_)); @@ -199,13 +227,17 @@ impl Expr { if self.is_paren_like() || parent.is_paren_like() - || self.is_prefix() && (parent.is_prefix() || !self.is_ordered_before(parent)) - || self.is_postfix() && (parent.is_postfix() || self.is_ordered_before(parent)) + || self.is_prefix() + && (parent.is_prefix() + || !self.is_ordered_before_parent_in_place_of(parent, place_of)) + || self.is_postfix() + && (parent.is_postfix() + || self.is_ordered_before_parent_in_place_of(parent, place_of)) { return false; } - let (left, right, inv) = match self.is_ordered_before(parent) { + let (left, right, inv) = match self.is_ordered_before_parent_in_place_of(parent, place_of) { true => (self, parent, false), false => (parent, self, true), }; @@ -413,13 +445,28 @@ impl Expr { } } - fn is_ordered_before(&self, other: &Expr) -> bool { + fn is_ordered_before_parent_in_place_of(&self, parent: &Expr, place_of: &SyntaxNode) -> bool { + use rowan::TextSize; use Expr::*; - return order(self) < order(other); + let self_range = self.syntax().text_range(); + let place_of_range = place_of.text_range(); + + let self_order_adjusted = order(self) - self_range.start() + place_of_range.start(); + + let parent_order = order(parent); + let parent_order_adjusted = if parent_order <= place_of_range.start() { + parent_order + } else if parent_order >= place_of_range.end() { + parent_order - place_of_range.len() + self_range.len() + } else { + return false; + }; + + return self_order_adjusted < parent_order_adjusted; /// Returns text range that can be used to compare two expression for order (which goes first). - fn order(this: &Expr) -> rowan::TextSize { + fn order(this: &Expr) -> TextSize { // For non-paren-like operators: get the operator itself let token = match this { RangeExpr(e) => e.op_token(), diff --git a/docs/book/src/assists_generated.md b/docs/book/src/assists_generated.md index 918ae4a579..9a80185179 100644 --- a/docs/book/src/assists_generated.md +++ b/docs/book/src/assists_generated.md @@ -306,7 +306,7 @@ fn main() { ### `apply_demorgan_iterator` -**Source:** [apply_demorgan.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/apply_demorgan.rs#L154) +**Source:** [apply_demorgan.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/apply_demorgan.rs#L156) Apply [De Morgan's law](https://en.wikipedia.org/wiki/De_Morgan%27s_laws) to `Iterator::all` and `Iterator::any`. @@ -1070,7 +1070,7 @@ pub use foo::{Bar, Baz}; ### `expand_record_rest_pattern` -**Source:** [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L24) +**Source:** [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L26) Fills fields by replacing rest pattern in record patterns. @@ -1094,7 +1094,7 @@ fn foo(bar: Bar) { ### `expand_tuple_struct_rest_pattern` -**Source:** [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L80) +**Source:** [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L82) Fills fields by replacing rest pattern in tuple struct patterns.