From d7a8e800c9270023d5a19f793f28054b936f7c5d Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Sat, 19 Aug 2023 16:49:26 -0600 Subject: [PATCH 01/10] feat: initial version of bool_to_enum assist --- .../ide-assists/src/handlers/bool_to_enum.rs | 724 ++++++++++++++++++ crates/ide-assists/src/lib.rs | 2 + crates/syntax/src/ast/make.rs | 28 + 3 files changed, 754 insertions(+) create mode 100644 crates/ide-assists/src/handlers/bool_to_enum.rs diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs new file mode 100644 index 0000000000..dd11824b99 --- /dev/null +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -0,0 +1,724 @@ +use ide_db::{ + assists::{AssistId, AssistKind}, + defs::Definition, + search::{FileReference, SearchScope, UsageSearchResult}, + source_change::SourceChangeBuilder, +}; +use syntax::{ + ast::{ + self, + edit::IndentLevel, + edit_in_place::{AttrsOwnerEdit, Indent}, + make, HasName, + }, + ted, AstNode, NodeOrToken, SyntaxNode, T, +}; + +use crate::assist_context::{AssistContext, Assists}; + +// Assist: bool_to_enum +// +// This converts boolean local variables, fields, constants, and statics into a new +// enum with two variants `Bool::True` and `Bool::False`, as well as replacing +// all assignments with the variants and replacing all usages with `== Bool::True` or +// `== Bool::False`. +// +// ``` +// fn main() { +// let $0bool = true; +// +// if bool { +// println!("foo"); +// } +// } +// ``` +// -> +// ``` +// fn main() { +// #[derive(PartialEq, Eq)] +// enum Bool { True, False } +// +// let bool = Bool::True; +// +// if bool == Bool::True { +// println!("foo"); +// } +// } +// ``` +pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let BoolNodeData { target_node, name, ty_annotation, initializer, definition } = + find_bool_node(ctx)?; + + let target = name.syntax().text_range(); + acc.add( + AssistId("bool_to_enum", AssistKind::RefactorRewrite), + "Convert boolean to enum", + target, + |edit| { + if let Some(ty) = &ty_annotation { + cov_mark::hit!(replaces_ty_annotation); + edit.replace(ty.syntax().text_range(), "Bool"); + } + + if let Some(initializer) = initializer { + replace_bool_expr(edit, initializer); + } + + let usages = definition + .usages(&ctx.sema) + .in_scope(&SearchScope::single_file(ctx.file_id())) + .all(); + replace_usages(edit, &usages); + + add_enum_def(edit, ctx, &usages, target_node); + }, + ) +} + +struct BoolNodeData { + target_node: SyntaxNode, + name: ast::Name, + ty_annotation: Option, + initializer: Option, + definition: Definition, +} + +/// Attempts to find an appropriate node to apply the action to. +fn find_bool_node(ctx: &AssistContext<'_>) -> Option { + if let Some(let_stmt) = ctx.find_node_at_offset::() { + let bind_pat = match let_stmt.pat()? { + ast::Pat::IdentPat(pat) => pat, + _ => { + cov_mark::hit!(not_applicable_in_non_ident_pat); + return None; + } + }; + let def = ctx.sema.to_def(&bind_pat)?; + if !def.ty(ctx.db()).is_bool() { + cov_mark::hit!(not_applicable_non_bool_local); + return None; + } + + Some(BoolNodeData { + target_node: let_stmt.syntax().clone(), + name: bind_pat.name()?, + ty_annotation: let_stmt.ty(), + initializer: let_stmt.initializer(), + definition: Definition::Local(def), + }) + } else if let Some(const_) = ctx.find_node_at_offset::() { + let def = ctx.sema.to_def(&const_)?; + if !def.ty(ctx.db()).is_bool() { + cov_mark::hit!(not_applicable_non_bool_const); + return None; + } + + Some(BoolNodeData { + target_node: const_.syntax().clone(), + name: const_.name()?, + ty_annotation: const_.ty(), + initializer: const_.body(), + definition: Definition::Const(def), + }) + } else if let Some(static_) = ctx.find_node_at_offset::() { + let def = ctx.sema.to_def(&static_)?; + if !def.ty(ctx.db()).is_bool() { + cov_mark::hit!(not_applicable_non_bool_static); + return None; + } + + Some(BoolNodeData { + target_node: static_.syntax().clone(), + name: static_.name()?, + ty_annotation: static_.ty(), + initializer: static_.body(), + definition: Definition::Static(def), + }) + } else if let Some(field_name) = ctx.find_node_at_offset::() { + let field = field_name.syntax().ancestors().find_map(ast::RecordField::cast)?; + if field.name()? != field_name { + return None; + } + + let strukt = field.syntax().ancestors().find_map(ast::Struct::cast)?; + let def = ctx.sema.to_def(&field)?; + if !def.ty(ctx.db()).is_bool() { + cov_mark::hit!(not_applicable_non_bool_field); + return None; + } + Some(BoolNodeData { + target_node: strukt.syntax().clone(), + name: field_name, + ty_annotation: field.ty(), + initializer: None, + definition: Definition::Field(def), + }) + } else { + None + } +} + +fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr) { + let expr_range = expr.syntax().text_range(); + let enum_expr = bool_expr_to_enum_expr(expr); + edit.replace(expr_range, enum_expr.syntax().text()) +} + +/// Converts an expression of type `bool` to one of the new enum type. +fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr { + let true_expr = make::expr_path(make::path_from_text("Bool::True")).clone_for_update(); + let false_expr = make::expr_path(make::path_from_text("Bool::False")).clone_for_update(); + + if let ast::Expr::Literal(literal) = &expr { + match literal.kind() { + ast::LiteralKind::Bool(true) => true_expr, + ast::LiteralKind::Bool(false) => false_expr, + _ => expr, + } + } else { + make::expr_if( + expr, + make::tail_only_block_expr(true_expr), + Some(ast::ElseBranch::Block(make::tail_only_block_expr(false_expr))), + ) + .clone_for_update() + } +} + +/// Replaces all usages of the target identifier, both when read and written to. +fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) { + for (_, references) in usages.iter() { + references + .into_iter() + .filter_map(|FileReference { range, name, .. }| match name { + ast::NameLike::NameRef(name) => Some((*range, name)), + _ => None, + }) + .for_each(|(range, name_ref)| { + if let Some(initializer) = find_assignment_usage(name_ref) { + cov_mark::hit!(replaces_assignment); + + replace_bool_expr(edit, initializer); + } else if let Some((prefix_expr, expr)) = find_negated_usage(name_ref) { + cov_mark::hit!(replaces_negation); + + edit.replace( + prefix_expr.syntax().text_range(), + format!("{} == Bool::False", expr), + ); + } else if let Some((record_field, initializer)) = find_record_expr_usage(name_ref) { + cov_mark::hit!(replaces_record_expr); + + let record_field = edit.make_mut(record_field); + let enum_expr = bool_expr_to_enum_expr(initializer); + record_field.replace_expr(enum_expr); + } else if name_ref.syntax().ancestors().find_map(ast::Expr::cast).is_some() { + // for any other usage in an expression, replace it with a check that it is the true variant + edit.replace(range, format!("{} == Bool::True", name_ref.text())); + } + }) + } +} + +fn find_assignment_usage(name_ref: &ast::NameRef) -> Option { + let bin_expr = name_ref.syntax().ancestors().find_map(ast::BinExpr::cast)?; + + if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() { + bin_expr.rhs() + } else { + None + } +} + +fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::Expr)> { + let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?; + + if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() { + let initializer = prefix_expr.expr()?; + Some((prefix_expr, initializer)) + } else { + None + } +} + +fn find_record_expr_usage(name_ref: &ast::NameRef) -> Option<(ast::RecordExprField, ast::Expr)> { + let record_field = name_ref.syntax().ancestors().find_map(ast::RecordExprField::cast)?; + let initializer = record_field.expr()?; + + Some((record_field, initializer)) +} + +/// Adds the definition of the new enum before the target node. +fn add_enum_def( + edit: &mut SourceChangeBuilder, + ctx: &AssistContext<'_>, + usages: &UsageSearchResult, + target_node: SyntaxNode, +) { + let make_enum_pub = usages.iter().any(|(file_id, _)| file_id != &ctx.file_id()); + let enum_def = make_bool_enum(make_enum_pub); + + let indent = IndentLevel::from_node(&target_node); + enum_def.reindent_to(indent); + + ted::insert_all( + ted::Position::before(&edit.make_syntax_mut(target_node)), + vec![ + enum_def.syntax().clone().into(), + make::tokens::whitespace(&format!("\n\n{indent}")).into(), + ], + ); +} + +fn make_bool_enum(make_pub: bool) -> ast::Enum { + let enum_def = make::enum_( + if make_pub { Some(make::visibility_pub()) } else { None }, + make::name("Bool"), + make::variant_list(vec![ + make::variant(make::name("True"), None), + make::variant(make::name("False"), None), + ]), + ) + .clone_for_update(); + + let derive_eq = make::attr_outer(make::meta_token_tree( + make::ext::ident_path("derive"), + make::token_tree( + T!['('], + vec![ + NodeOrToken::Token(make::tokens::ident("PartialEq")), + NodeOrToken::Token(make::token(T![,])), + NodeOrToken::Token(make::tokens::single_space()), + NodeOrToken::Token(make::tokens::ident("Eq")), + ], + ), + )) + .clone_for_update(); + enum_def.add_attr(derive_eq); + + enum_def +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::{check_assist, check_assist_not_applicable}; + + #[test] + fn local_variable_with_usage() { + check_assist( + bool_to_enum, + r#" +fn main() { + let $0foo = true; + + if foo { + println!("foo"); + } +} +"#, + r#" +fn main() { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + let foo = Bool::True; + + if foo == Bool::True { + println!("foo"); + } +} +"#, + ) + } + + #[test] + fn local_variable_with_usage_negated() { + cov_mark::check!(replaces_negation); + check_assist( + bool_to_enum, + r#" +fn main() { + let $0foo = true; + + if !foo { + println!("foo"); + } +} +"#, + r#" +fn main() { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + let foo = Bool::True; + + if foo == Bool::False { + println!("foo"); + } +} +"#, + ) + } + + #[test] + fn local_variable_with_type_annotation() { + cov_mark::check!(replaces_ty_annotation); + check_assist( + bool_to_enum, + r#" +fn main() { + let $0foo: bool = false; +} +"#, + r#" +fn main() { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + let foo: Bool = Bool::False; +} +"#, + ) + } + + #[test] + fn local_variable_with_non_literal_initializer() { + check_assist( + bool_to_enum, + r#" +fn main() { + let $0foo = 1 == 2; +} +"#, + r#" +fn main() { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + let foo = if 1 == 2 { Bool::True } else { Bool::False }; +} +"#, + ) + } + + #[test] + fn local_variable_binexpr_usage() { + check_assist( + bool_to_enum, + r#" +fn main() { + let $0foo = false; + let bar = true; + + if !foo && bar { + println!("foobar"); + } +} +"#, + r#" +fn main() { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + let foo = Bool::False; + let bar = true; + + if foo == Bool::False && bar { + println!("foobar"); + } +} +"#, + ) + } + + #[test] + fn local_variable_unop_usage() { + check_assist( + bool_to_enum, + r#" +fn main() { + let $0foo = true; + + if *&foo { + println!("foobar"); + } +} +"#, + r#" +fn main() { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + let foo = Bool::True; + + if *&foo == Bool::True { + println!("foobar"); + } +} +"#, + ) + } + + #[test] + fn local_variable_assigned_later() { + cov_mark::check!(replaces_assignment); + check_assist( + bool_to_enum, + r#" +fn main() { + let $0foo: bool; + foo = true; +} +"#, + r#" +fn main() { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + let foo: Bool; + foo = Bool::True; +} +"#, + ) + } + + #[test] + fn local_variable_does_not_apply_recursively() { + check_assist( + bool_to_enum, + r#" +fn main() { + let $0foo = true; + let bar = !foo; + + if bar { + println!("bar"); + } +} +"#, + r#" +fn main() { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + let foo = Bool::True; + let bar = foo == Bool::False; + + if bar { + println!("bar"); + } +} +"#, + ) + } + + #[test] + fn local_variable_non_bool() { + cov_mark::check!(not_applicable_non_bool_local); + check_assist_not_applicable( + bool_to_enum, + r#" +fn main() { + let $0foo = 1; +} +"#, + ) + } + + #[test] + fn local_variable_non_ident_pat() { + cov_mark::check!(not_applicable_in_non_ident_pat); + check_assist_not_applicable( + bool_to_enum, + r#" +fn main() { + let ($0foo, bar) = (true, false); +} +"#, + ) + } + + #[test] + fn field_basic() { + cov_mark::check!(replaces_record_expr); + check_assist( + bool_to_enum, + r#" +struct Foo { + $0bar: bool, + baz: bool, +} + +fn main() { + let foo = Foo { bar: true, baz: false }; + + if foo.bar { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +struct Foo { + bar: Bool, + baz: bool, +} + +fn main() { + let foo = Foo { bar: Bool::True, baz: false }; + + if foo.bar == Bool::True { + println!("foo"); + } +} +"#, + ) + } + + #[test] + fn field_in_mod_properly_indented() { + check_assist( + bool_to_enum, + r#" +mod foo { + struct Bar { + $0baz: bool, + } + + impl Bar { + fn new(baz: bool) -> Self { + Self { baz } + } + } +} +"#, + r#" +mod foo { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + struct Bar { + baz: Bool, + } + + impl Bar { + fn new(baz: bool) -> Self { + Self { baz: if baz { Bool::True } else { Bool::False } } + } + } +} +"#, + ) + } + + #[test] + fn field_non_bool() { + cov_mark::check!(not_applicable_non_bool_field); + check_assist_not_applicable( + bool_to_enum, + r#" +struct Foo { + $0bar: usize, +} + +fn main() { + let foo = Foo { bar: 1 }; +} +"#, + ) + } + + #[test] + fn const_basic() { + check_assist( + bool_to_enum, + r#" +const $0FOO: bool = false; + +fn main() { + if FOO { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +const FOO: Bool = Bool::False; + +fn main() { + if FOO == Bool::True { + println!("foo"); + } +} +"#, + ) + } + + #[test] + fn const_non_bool() { + cov_mark::check!(not_applicable_non_bool_const); + check_assist_not_applicable( + bool_to_enum, + r#" +const $0FOO: &str = "foo"; + +fn main() { + println!("{FOO}"); +} +"#, + ) + } + + #[test] + fn static_basic() { + check_assist( + bool_to_enum, + r#" +static mut $0BOOL: bool = true; + +fn main() { + unsafe { BOOL = false }; + if unsafe { BOOL } { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +static mut BOOL: Bool = Bool::True; + +fn main() { + unsafe { BOOL = Bool::False }; + if unsafe { BOOL == Bool::True } { + println!("foo"); + } +} +"#, + ) + } + + #[test] + fn static_non_bool() { + cov_mark::check!(not_applicable_non_bool_static); + check_assist_not_applicable( + bool_to_enum, + r#" +static mut $0FOO: usize = 0; + +fn main() { + if unsafe { FOO } == 0 { + println!("foo"); + } +} +"#, + ) + } +} diff --git a/crates/ide-assists/src/lib.rs b/crates/ide-assists/src/lib.rs index 6f973ab53e..a17ce93e92 100644 --- a/crates/ide-assists/src/lib.rs +++ b/crates/ide-assists/src/lib.rs @@ -115,6 +115,7 @@ mod handlers { mod apply_demorgan; mod auto_import; mod bind_unused_param; + mod bool_to_enum; mod change_visibility; mod convert_bool_then; mod convert_comment_block; @@ -227,6 +228,7 @@ mod handlers { apply_demorgan::apply_demorgan, auto_import::auto_import, bind_unused_param::bind_unused_param, + bool_to_enum::bool_to_enum, change_visibility::change_visibility, convert_bool_then::convert_bool_then_to_if, convert_bool_then::convert_if_to_bool_then, diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index 17e311c0c5..e0055be6e6 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -973,6 +973,11 @@ pub fn tuple_field(visibility: Option, ty: ast::Type) -> ast::T ast_from_text(&format!("struct f({visibility}{ty});")) } +pub fn variant_list(variants: impl IntoIterator) -> ast::VariantList { + let variants = variants.into_iter().join(", "); + ast_from_text(&format!("enum f {{ {variants} }}")) +} + pub fn variant(name: ast::Name, field_list: Option) -> ast::Variant { let field_list = match field_list { None => String::new(), @@ -1037,6 +1042,19 @@ pub fn struct_( ast_from_text(&format!("{visibility}struct {strukt_name}{type_params}{field_list}{semicolon}",)) } +pub fn enum_( + visibility: Option, + enum_name: ast::Name, + variant_list: ast::VariantList, +) -> ast::Enum { + let visibility = match visibility { + None => String::new(), + Some(it) => format!("{it} "), + }; + + ast_from_text(&format!("{visibility}enum {enum_name} {variant_list}")) +} + pub fn attr_outer(meta: ast::Meta) -> ast::Attr { ast_from_text(&format!("#[{meta}]")) } @@ -1149,6 +1167,16 @@ pub mod tokens { lit.syntax().first_child_or_token().unwrap().into_token().unwrap() } + pub fn ident(text: &str) -> SyntaxToken { + assert_eq!(text.trim(), text); + let path: ast::Path = super::ext::ident_path(text); + path.syntax() + .descendants_with_tokens() + .filter_map(|it| it.into_token()) + .find(|it| it.kind() == IDENT) + .unwrap() + } + pub fn single_newline() -> SyntaxToken { let res = SOURCE_FILE .tree() From 59738d5fd5f868cec69e0ff30e27a6b80fc81ee4 Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Sat, 19 Aug 2023 17:41:44 -0600 Subject: [PATCH 02/10] fix: add generated doctest --- crates/ide-assists/src/tests/generated.rs | 28 +++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs index dfaa53449f..63a08a0e56 100644 --- a/crates/ide-assists/src/tests/generated.rs +++ b/crates/ide-assists/src/tests/generated.rs @@ -280,6 +280,34 @@ fn some_function(x: i32) { ) } +#[test] +fn doctest_bool_to_enum() { + check_doc_test( + "bool_to_enum", + r#####" +fn main() { + let $0bool = true; + + if bool { + println!("foo"); + } +} +"#####, + r#####" +fn main() { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + let bool = Bool::True; + + if bool == Bool::True { + println!("foo"); + } +} +"#####, + ) +} + #[test] fn doctest_change_visibility() { check_doc_test( From 83196fd4d9ed8544410fc82fee5d54830163f248 Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Sat, 19 Aug 2023 17:45:16 -0600 Subject: [PATCH 03/10] fix: remove trailing whitespace --- crates/ide-assists/src/handlers/bool_to_enum.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index dd11824b99..279e558362 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -369,7 +369,7 @@ fn main() { bool_to_enum, r#" fn main() { - let $0foo: bool = false; + let $0foo: bool = false; } "#, r#" @@ -377,7 +377,7 @@ fn main() { #[derive(PartialEq, Eq)] enum Bool { True, False } - let foo: Bool = Bool::False; + let foo: Bool = Bool::False; } "#, ) @@ -389,7 +389,7 @@ fn main() { bool_to_enum, r#" fn main() { - let $0foo = 1 == 2; + let $0foo = 1 == 2; } "#, r#" @@ -397,7 +397,7 @@ fn main() { #[derive(PartialEq, Eq)] enum Bool { True, False } - let foo = if 1 == 2 { Bool::True } else { Bool::False }; + let foo = if 1 == 2 { Bool::True } else { Bool::False }; } "#, ) @@ -468,7 +468,7 @@ fn main() { bool_to_enum, r#" fn main() { - let $0foo: bool; + let $0foo: bool; foo = true; } "#, @@ -477,7 +477,7 @@ fn main() { #[derive(PartialEq, Eq)] enum Bool { True, False } - let foo: Bool; + let foo: Bool; foo = Bool::True; } "#, From 91ac1d619475e1b61bf4ae8d318c4740a0adce66 Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Fri, 8 Sep 2023 07:45:23 -0700 Subject: [PATCH 04/10] fix: initializing struct multiple times --- .../ide-assists/src/handlers/bool_to_enum.rs | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 279e558362..4158b75dc0 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -194,6 +194,7 @@ fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) { ast::NameLike::NameRef(name) => Some((*range, name)), _ => None, }) + .rev() .for_each(|(range, name_ref)| { if let Some(initializer) = find_assignment_usage(name_ref) { cov_mark::hit!(replaces_assignment); @@ -615,6 +616,46 @@ mod foo { ) } + #[test] + fn field_multiple_initializations() { + check_assist( + bool_to_enum, + r#" +struct Foo { + $0bar: bool, + baz: bool, +} + +fn main() { + let foo1 = Foo { bar: true, baz: false }; + let foo2 = Foo { bar: false, baz: false }; + + if foo1.bar && foo2.bar { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum $0Bool { True, False } + +struct Foo { + bar: Bool, + baz: bool, +} + +fn main() { + let foo1 = Foo { bar: Bool::True, baz: false }; + let foo2 = Foo { bar: Bool::False, baz: false }; + + if foo1.bar == Bool::True && foo2.bar == Bool::True { + println!("foo"); + } +} +"#, + ) + } + #[test] fn field_non_bool() { cov_mark::check!(not_applicable_non_bool_field); From 455dacfd3b5387bcf2854f2a88edb9b69361e69f Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Fri, 8 Sep 2023 10:06:17 -0700 Subject: [PATCH 05/10] fix: only trigger assist on Name --- .../ide-assists/src/handlers/bool_to_enum.rs | 41 +++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 4158b75dc0..56749edf46 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -85,7 +85,9 @@ struct BoolNodeData { /// Attempts to find an appropriate node to apply the action to. fn find_bool_node(ctx: &AssistContext<'_>) -> Option { - if let Some(let_stmt) = ctx.find_node_at_offset::() { + let name: ast::Name = ctx.find_node_at_offset()?; + + if let Some(let_stmt) = name.syntax().ancestors().find_map(ast::LetStmt::cast) { let bind_pat = match let_stmt.pat()? { ast::Pat::IdentPat(pat) => pat, _ => { @@ -101,12 +103,12 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option { Some(BoolNodeData { target_node: let_stmt.syntax().clone(), - name: bind_pat.name()?, + name, ty_annotation: let_stmt.ty(), initializer: let_stmt.initializer(), definition: Definition::Local(def), }) - } else if let Some(const_) = ctx.find_node_at_offset::() { + } else if let Some(const_) = name.syntax().ancestors().find_map(ast::Const::cast) { let def = ctx.sema.to_def(&const_)?; if !def.ty(ctx.db()).is_bool() { cov_mark::hit!(not_applicable_non_bool_const); @@ -115,12 +117,12 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option { Some(BoolNodeData { target_node: const_.syntax().clone(), - name: const_.name()?, + name, ty_annotation: const_.ty(), initializer: const_.body(), definition: Definition::Const(def), }) - } else if let Some(static_) = ctx.find_node_at_offset::() { + } else if let Some(static_) = name.syntax().ancestors().find_map(ast::Static::cast) { let def = ctx.sema.to_def(&static_)?; if !def.ty(ctx.db()).is_bool() { cov_mark::hit!(not_applicable_non_bool_static); @@ -129,14 +131,14 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option { Some(BoolNodeData { target_node: static_.syntax().clone(), - name: static_.name()?, + name, ty_annotation: static_.ty(), initializer: static_.body(), definition: Definition::Static(def), }) - } else if let Some(field_name) = ctx.find_node_at_offset::() { - let field = field_name.syntax().ancestors().find_map(ast::RecordField::cast)?; - if field.name()? != field_name { + } else { + let field = name.syntax().ancestors().find_map(ast::RecordField::cast)?; + if field.name()? != name { return None; } @@ -148,13 +150,11 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option { } Some(BoolNodeData { target_node: strukt.syntax().clone(), - name: field_name, + name, ty_annotation: field.ty(), initializer: None, definition: Definition::Field(def), }) - } else { - None } } @@ -528,6 +528,18 @@ fn main() { ) } + #[test] + fn local_variable_cursor_not_on_ident() { + check_assist_not_applicable( + bool_to_enum, + r#" +fn main() { + let foo = $0true; +} +"#, + ) + } + #[test] fn local_variable_non_ident_pat() { cov_mark::check!(not_applicable_in_non_ident_pat); @@ -762,4 +774,9 @@ fn main() { "#, ) } + + #[test] + fn not_applicable_to_other_names() { + check_assist_not_applicable(bool_to_enum, "fn $0main() {}") + } } From 136a9dbe36606cb00b546c3562088c462d8a0926 Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Fri, 8 Sep 2023 10:54:30 -0700 Subject: [PATCH 06/10] style: rename some locals --- crates/ide-assists/src/handlers/bool_to_enum.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 56749edf46..9752264844 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -200,12 +200,12 @@ fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) { cov_mark::hit!(replaces_assignment); replace_bool_expr(edit, initializer); - } else if let Some((prefix_expr, expr)) = find_negated_usage(name_ref) { + } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(name_ref) { cov_mark::hit!(replaces_negation); edit.replace( prefix_expr.syntax().text_range(), - format!("{} == Bool::False", expr), + format!("{} == Bool::False", inner_expr), ); } else if let Some((record_field, initializer)) = find_record_expr_usage(name_ref) { cov_mark::hit!(replaces_record_expr); @@ -235,8 +235,8 @@ fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast:: let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?; if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() { - let initializer = prefix_expr.expr()?; - Some((prefix_expr, initializer)) + let inner_expr = prefix_expr.expr()?; + Some((prefix_expr, inner_expr)) } else { None } From 2e13aed3bc235d47d92f9ce3b8fd4fa3c5f87939 Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Sat, 9 Sep 2023 11:40:29 -0700 Subject: [PATCH 07/10] feat: support cross module imports --- .../ide-assists/src/handlers/bool_to_enum.rs | 226 +++++++++++++++++- 1 file changed, 214 insertions(+), 12 deletions(-) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 9752264844..f59b052813 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -1,9 +1,13 @@ +use hir::ModuleDef; use ide_db::{ assists::{AssistId, AssistKind}, defs::Definition, - search::{FileReference, SearchScope, UsageSearchResult}, + helpers::mod_path_to_ast, + imports::insert_use::{insert_use, ImportScope}, + search::{FileReference, UsageSearchResult}, source_change::SourceChangeBuilder, }; +use itertools::Itertools; use syntax::{ ast::{ self, @@ -48,6 +52,7 @@ use crate::assist_context::{AssistContext, Assists}; pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let BoolNodeData { target_node, name, ty_annotation, initializer, definition } = find_bool_node(ctx)?; + let target_module = ctx.sema.scope(&target_node)?.module(); let target = name.syntax().text_range(); acc.add( @@ -64,13 +69,10 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option replace_bool_expr(edit, initializer); } - let usages = definition - .usages(&ctx.sema) - .in_scope(&SearchScope::single_file(ctx.file_id())) - .all(); - replace_usages(edit, &usages); + let usages = definition.usages(&ctx.sema).all(); - add_enum_def(edit, ctx, &usages, target_node); + add_enum_def(edit, ctx, &usages, target_node, &target_module); + replace_usages(edit, ctx, &usages, &target_module); }, ) } @@ -186,8 +188,45 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr { } /// Replaces all usages of the target identifier, both when read and written to. -fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) { - for (_, references) in usages.iter() { +fn replace_usages( + edit: &mut SourceChangeBuilder, + ctx: &AssistContext<'_>, + usages: &UsageSearchResult, + target_module: &hir::Module, +) { + for (file_id, references) in usages.iter() { + edit.edit_file(*file_id); + + // add imports across modules where needed + references + .iter() + .filter_map(|FileReference { name, .. }| { + ctx.sema.scope(name.syntax()).map(|scope| (name, scope.module())) + }) + .unique_by(|name_and_module| name_and_module.1) + .filter(|(_, module)| module != target_module) + .filter_map(|(name, module)| { + let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema); + let mod_path = module.find_use_path_prefixed( + ctx.sema.db, + ModuleDef::Module(*target_module), + ctx.config.insert_use.prefix_kind, + ctx.config.prefer_no_std, + ); + import_scope.zip(mod_path) + }) + .for_each(|(import_scope, mod_path)| { + let import_scope = match import_scope { + ImportScope::File(it) => ImportScope::File(edit.make_mut(it)), + ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)), + ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)), + }; + let path = + make::path_concat(mod_path_to_ast(&mod_path), make::path_from_text("Bool")); + insert_use(&import_scope, path, &ctx.config.insert_use); + }); + + // replace the usages in expressions references .into_iter() .filter_map(|FileReference { range, name, .. }| match name { @@ -213,7 +252,7 @@ fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) { let record_field = edit.make_mut(record_field); let enum_expr = bool_expr_to_enum_expr(initializer); record_field.replace_expr(enum_expr); - } else if name_ref.syntax().ancestors().find_map(ast::Expr::cast).is_some() { + } else if name_ref.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { // for any other usage in an expression, replace it with a check that it is the true variant edit.replace(range, format!("{} == Bool::True", name_ref.text())); } @@ -255,8 +294,15 @@ fn add_enum_def( ctx: &AssistContext<'_>, usages: &UsageSearchResult, target_node: SyntaxNode, + target_module: &hir::Module, ) { - let make_enum_pub = usages.iter().any(|(file_id, _)| file_id != &ctx.file_id()); + let make_enum_pub = usages + .iter() + .flat_map(|(_, refs)| refs) + .filter_map(|FileReference { name, .. }| { + ctx.sema.scope(name.syntax()).map(|scope| scope.module()) + }) + .any(|module| &module != target_module); let enum_def = make_bool_enum(make_enum_pub); let indent = IndentLevel::from_node(&target_node); @@ -649,7 +695,7 @@ fn main() { "#, r#" #[derive(PartialEq, Eq)] -enum $0Bool { True, False } +enum Bool { True, False } struct Foo { bar: Bool, @@ -713,6 +759,162 @@ fn main() { ) } + #[test] + fn const_in_module() { + check_assist( + bool_to_enum, + r#" +fn main() { + if foo::FOO { + println!("foo"); + } +} + +mod foo { + pub const $0FOO: bool = true; +} +"#, + r#" +use foo::Bool; + +fn main() { + if foo::FOO == Bool::True { + println!("foo"); + } +} + +mod foo { + #[derive(PartialEq, Eq)] + pub enum Bool { True, False } + + pub const FOO: Bool = Bool::True; +} +"#, + ) + } + + #[test] + fn const_in_module_with_import() { + check_assist( + bool_to_enum, + r#" +fn main() { + use foo::FOO; + + if FOO { + println!("foo"); + } +} + +mod foo { + pub const $0FOO: bool = true; +} +"#, + r#" +use crate::foo::Bool; + +fn main() { + use foo::FOO; + + if FOO == Bool::True { + println!("foo"); + } +} + +mod foo { + #[derive(PartialEq, Eq)] + pub enum Bool { True, False } + + pub const FOO: Bool = Bool::True; +} +"#, + ) + } + + #[test] + fn const_cross_file() { + check_assist( + bool_to_enum, + r#" +//- /main.rs +mod foo; + +fn main() { + if foo::FOO { + println!("foo"); + } +} + +//- /foo.rs +pub const $0FOO: bool = true; +"#, + r#" +//- /main.rs +use foo::Bool; + +mod foo; + +fn main() { + if foo::FOO == Bool::True { + println!("foo"); + } +} + +//- /foo.rs +#[derive(PartialEq, Eq)] +pub enum Bool { True, False } + +pub const FOO: Bool = Bool::True; +"#, + ) + } + + #[test] + fn const_cross_file_and_module() { + check_assist( + bool_to_enum, + r#" +//- /main.rs +mod foo; + +fn main() { + use foo::bar; + + if bar::BAR { + println!("foo"); + } +} + +//- /foo.rs +pub mod bar { + pub const $0BAR: bool = false; +} +"#, + r#" +//- /main.rs +use crate::foo::bar::Bool; + +mod foo; + +fn main() { + use foo::bar; + + if bar::BAR == Bool::True { + println!("foo"); + } +} + +//- /foo.rs +pub mod bar { + #[derive(PartialEq, Eq)] + pub enum Bool { True, False } + + pub const BAR: Bool = Bool::False; +} +"#, + ) + } + #[test] fn const_non_bool() { cov_mark::check!(not_applicable_non_bool_const); From 7ba2e130b975f62906fff6ea82aacff883a9e528 Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Sat, 9 Sep 2023 23:54:25 -0700 Subject: [PATCH 08/10] fix: add checks for overwriting incorrect ancestor --- .../ide-assists/src/handlers/bool_to_enum.rs | 166 +++++++++++++++++- 1 file changed, 165 insertions(+), 1 deletion(-) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index f59b052813..784a0d3559 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -263,6 +263,11 @@ fn replace_usages( fn find_assignment_usage(name_ref: &ast::NameRef) -> Option { let bin_expr = name_ref.syntax().ancestors().find_map(ast::BinExpr::cast)?; + if !bin_expr.lhs()?.syntax().descendants().contains(name_ref.syntax()) { + cov_mark::hit!(dont_assign_incorrect_ref); + return None; + } + if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() { bin_expr.rhs() } else { @@ -273,6 +278,11 @@ fn find_assignment_usage(name_ref: &ast::NameRef) -> Option { fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::Expr)> { let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?; + if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) { + cov_mark::hit!(dont_overwrite_expression_inside_negation); + return None; + } + if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() { let inner_expr = prefix_expr.expr()?; Some((prefix_expr, inner_expr)) @@ -285,7 +295,12 @@ fn find_record_expr_usage(name_ref: &ast::NameRef) -> Option<(ast::RecordExprFie let record_field = name_ref.syntax().ancestors().find_map(ast::RecordExprField::cast)?; let initializer = record_field.expr()?; - Some((record_field, initializer)) + if record_field.field_name()?.syntax().descendants().contains(name_ref.syntax()) { + Some((record_field, initializer)) + } else { + cov_mark::hit!(dont_overwrite_wrong_record_field); + None + } } /// Adds the definition of the new enum before the target node. @@ -561,6 +576,37 @@ fn main() { ) } + #[test] + fn local_variable_nested_in_negation() { + cov_mark::check!(dont_overwrite_expression_inside_negation); + check_assist( + bool_to_enum, + r#" +fn main() { + if !"foo".chars().any(|c| { + let $0foo = true; + foo + }) { + println!("foo"); + } +} +"#, + r#" +fn main() { + if !"foo".chars().any(|c| { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + let foo = Bool::True; + foo == Bool::True + }) { + println!("foo"); + } +} +"#, + ) + } + #[test] fn local_variable_non_bool() { cov_mark::check!(not_applicable_non_bool_local); @@ -638,6 +684,42 @@ fn main() { ) } + #[test] + fn field_negated() { + check_assist( + bool_to_enum, + r#" +struct Foo { + $0bar: bool, +} + +fn main() { + let foo = Foo { bar: false }; + + if !foo.bar { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +struct Foo { + bar: Bool, +} + +fn main() { + let foo = Foo { bar: Bool::False }; + + if foo.bar == Bool::False { + println!("foo"); + } +} +"#, + ) + } + #[test] fn field_in_mod_properly_indented() { check_assist( @@ -714,6 +796,88 @@ fn main() { ) } + #[test] + fn field_assigned_to_another() { + cov_mark::check!(dont_assign_incorrect_ref); + check_assist( + bool_to_enum, + r#" +struct Foo { + $0foo: bool, +} + +struct Bar { + bar: bool, +} + +fn main() { + let foo = Foo { foo: true }; + let mut bar = Bar { bar: true }; + + bar.bar = foo.foo; +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +struct Foo { + foo: Bool, +} + +struct Bar { + bar: bool, +} + +fn main() { + let foo = Foo { foo: Bool::True }; + let mut bar = Bar { bar: true }; + + bar.bar = foo.foo == Bool::True; +} +"#, + ) + } + + #[test] + fn field_initialized_with_other() { + cov_mark::check!(dont_overwrite_wrong_record_field); + check_assist( + bool_to_enum, + r#" +struct Foo { + $0foo: bool, +} + +struct Bar { + bar: bool, +} + +fn main() { + let foo = Foo { foo: true }; + let bar = Bar { bar: foo.foo }; +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +struct Foo { + foo: Bool, +} + +struct Bar { + bar: bool, +} + +fn main() { + let foo = Foo { foo: Bool::True }; + let bar = Bar { bar: foo.foo == Bool::True }; +} +"#, + ) + } + #[test] fn field_non_bool() { cov_mark::check!(not_applicable_non_bool_field); From 25b1b3e753c4cb6fa75b2004c8d753daf240b1c6 Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Sun, 10 Sep 2023 22:21:12 -0700 Subject: [PATCH 09/10] feat: add support for other ADT types and destructuring patterns --- .../ide-assists/src/handlers/bool_to_enum.rs | 492 +++++++++++++++--- 1 file changed, 429 insertions(+), 63 deletions(-) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 784a0d3559..b9dbd6e98f 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -6,6 +6,7 @@ use ide_db::{ imports::insert_use::{insert_use, ImportScope}, search::{FileReference, UsageSearchResult}, source_change::SourceChangeBuilder, + FxHashSet, }; use itertools::Itertools; use syntax::{ @@ -17,6 +18,7 @@ use syntax::{ }, ted, AstNode, NodeOrToken, SyntaxNode, T, }; +use text_edit::TextRange; use crate::assist_context::{AssistContext, Assists}; @@ -52,7 +54,7 @@ use crate::assist_context::{AssistContext, Assists}; pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let BoolNodeData { target_node, name, ty_annotation, initializer, definition } = find_bool_node(ctx)?; - let target_module = ctx.sema.scope(&target_node)?.module(); + let target_module = ctx.sema.scope(&target_node)?.module().nearest_non_block_module(ctx.db()); let target = name.syntax().text_range(); acc.add( @@ -70,9 +72,8 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option } let usages = definition.usages(&ctx.sema).all(); - add_enum_def(edit, ctx, &usages, target_node, &target_module); - replace_usages(edit, ctx, &usages, &target_module); + replace_usages(edit, ctx, &usages, definition, &target_module); }, ) } @@ -144,14 +145,14 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option { return None; } - let strukt = field.syntax().ancestors().find_map(ast::Struct::cast)?; + let adt = field.syntax().ancestors().find_map(ast::Adt::cast)?; let def = ctx.sema.to_def(&field)?; if !def.ty(ctx.db()).is_bool() { cov_mark::hit!(not_applicable_non_bool_field); return None; } Some(BoolNodeData { - target_node: strukt.syntax().clone(), + target_node: adt.syntax().clone(), name, ty_annotation: field.ty(), initializer: None, @@ -192,78 +193,171 @@ fn replace_usages( edit: &mut SourceChangeBuilder, ctx: &AssistContext<'_>, usages: &UsageSearchResult, + target_definition: Definition, target_module: &hir::Module, ) { for (file_id, references) in usages.iter() { edit.edit_file(*file_id); - // add imports across modules where needed - references - .iter() - .filter_map(|FileReference { name, .. }| { - ctx.sema.scope(name.syntax()).map(|scope| (name, scope.module())) - }) - .unique_by(|name_and_module| name_and_module.1) - .filter(|(_, module)| module != target_module) - .filter_map(|(name, module)| { - let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema); - let mod_path = module.find_use_path_prefixed( - ctx.sema.db, - ModuleDef::Module(*target_module), - ctx.config.insert_use.prefix_kind, - ctx.config.prefer_no_std, - ); - import_scope.zip(mod_path) - }) - .for_each(|(import_scope, mod_path)| { - let import_scope = match import_scope { - ImportScope::File(it) => ImportScope::File(edit.make_mut(it)), - ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)), - ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)), - }; - let path = - make::path_concat(mod_path_to_ast(&mod_path), make::path_from_text("Bool")); - insert_use(&import_scope, path, &ctx.config.insert_use); - }); + let refs_with_imports = + augment_references_with_imports(edit, ctx, references, target_module); - // replace the usages in expressions - references - .into_iter() - .filter_map(|FileReference { range, name, .. }| match name { - ast::NameLike::NameRef(name) => Some((*range, name)), - _ => None, - }) - .rev() - .for_each(|(range, name_ref)| { - if let Some(initializer) = find_assignment_usage(name_ref) { + refs_with_imports.into_iter().rev().for_each( + |FileReferenceWithImport { range, old_name, new_name, import_data }| { + // replace the usages in patterns and expressions + if let Some(ident_pat) = old_name.syntax().ancestors().find_map(ast::IdentPat::cast) + { + cov_mark::hit!(replaces_record_pat_shorthand); + + let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local); + if let Some(def) = definition { + replace_usages( + edit, + ctx, + &def.usages(&ctx.sema).all(), + target_definition, + target_module, + ) + } + } else if let Some(initializer) = find_assignment_usage(&new_name) { cov_mark::hit!(replaces_assignment); replace_bool_expr(edit, initializer); - } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(name_ref) { + } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&new_name) { cov_mark::hit!(replaces_negation); edit.replace( prefix_expr.syntax().text_range(), format!("{} == Bool::False", inner_expr), ); - } else if let Some((record_field, initializer)) = find_record_expr_usage(name_ref) { + } else if let Some((record_field, initializer)) = old_name + .as_name_ref() + .and_then(ast::RecordExprField::for_field_name) + .and_then(|record_field| ctx.sema.resolve_record_field(&record_field)) + .and_then(|(got_field, _, _)| { + find_record_expr_usage(&new_name, got_field, target_definition) + }) + { cov_mark::hit!(replaces_record_expr); let record_field = edit.make_mut(record_field); let enum_expr = bool_expr_to_enum_expr(initializer); record_field.replace_expr(enum_expr); - } else if name_ref.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { + } else if let Some(pat) = find_record_pat_field_usage(&old_name) { + match pat { + ast::Pat::IdentPat(ident_pat) => { + cov_mark::hit!(replaces_record_pat); + + let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local); + if let Some(def) = definition { + replace_usages( + edit, + ctx, + &def.usages(&ctx.sema).all(), + target_definition, + target_module, + ) + } + } + ast::Pat::LiteralPat(literal_pat) => { + cov_mark::hit!(replaces_literal_pat); + + if let Some(expr) = literal_pat.literal().and_then(|literal| { + literal.syntax().ancestors().find_map(ast::Expr::cast) + }) { + replace_bool_expr(edit, expr); + } + } + _ => (), + } + } else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { // for any other usage in an expression, replace it with a check that it is the true variant - edit.replace(range, format!("{} == Bool::True", name_ref.text())); + if let Some((record_field, expr)) = new_name + .as_name_ref() + .and_then(ast::RecordExprField::for_field_name) + .and_then(|record_field| { + record_field.expr().map(|expr| (record_field, expr)) + }) + { + record_field.replace_expr( + make::expr_bin_op( + expr, + ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }), + make::expr_path(make::path_from_text("Bool::True")), + ) + .clone_for_update(), + ); + } else { + edit.replace(range, format!("{} == Bool::True", new_name.text())); + } } - }) + + // add imports across modules where needed + if let Some((import_scope, path)) = import_data { + insert_use(&import_scope, path, &ctx.config.insert_use); + } + }, + ) } } -fn find_assignment_usage(name_ref: &ast::NameRef) -> Option { - let bin_expr = name_ref.syntax().ancestors().find_map(ast::BinExpr::cast)?; +struct FileReferenceWithImport { + range: TextRange, + old_name: ast::NameLike, + new_name: ast::NameLike, + import_data: Option<(ImportScope, ast::Path)>, +} - if !bin_expr.lhs()?.syntax().descendants().contains(name_ref.syntax()) { +fn augment_references_with_imports( + edit: &mut SourceChangeBuilder, + ctx: &AssistContext<'_>, + references: &[FileReference], + target_module: &hir::Module, +) -> Vec { + let mut visited_modules = FxHashSet::default(); + + references + .iter() + .filter_map(|FileReference { range, name, .. }| { + ctx.sema.scope(name.syntax()).map(|scope| (*range, name, scope.module())) + }) + .map(|(range, name, ref_module)| { + let old_name = name.clone(); + let new_name = edit.make_mut(name.clone()); + + // if the referenced module is not the same as the target one and has not been seen before, add an import + let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module + && !visited_modules.contains(&ref_module) + { + visited_modules.insert(ref_module); + + let import_scope = + ImportScope::find_insert_use_container(new_name.syntax(), &ctx.sema); + let path = ref_module + .find_use_path_prefixed( + ctx.sema.db, + ModuleDef::Module(*target_module), + ctx.config.insert_use.prefix_kind, + ctx.config.prefer_no_std, + ) + .map(|mod_path| { + make::path_concat(mod_path_to_ast(&mod_path), make::path_from_text("Bool")) + }); + + import_scope.zip(path) + } else { + None + }; + + FileReferenceWithImport { range, old_name, new_name, import_data } + }) + .collect() +} + +fn find_assignment_usage(name: &ast::NameLike) -> Option { + let bin_expr = name.syntax().ancestors().find_map(ast::BinExpr::cast)?; + + if !bin_expr.lhs()?.syntax().descendants().contains(name.syntax()) { cov_mark::hit!(dont_assign_incorrect_ref); return None; } @@ -275,8 +369,8 @@ fn find_assignment_usage(name_ref: &ast::NameRef) -> Option { } } -fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::Expr)> { - let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?; +fn find_negated_usage(name: &ast::NameLike) -> Option<(ast::PrefixExpr, ast::Expr)> { + let prefix_expr = name.syntax().ancestors().find_map(ast::PrefixExpr::cast)?; if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) { cov_mark::hit!(dont_overwrite_expression_inside_negation); @@ -291,15 +385,31 @@ fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast:: } } -fn find_record_expr_usage(name_ref: &ast::NameRef) -> Option<(ast::RecordExprField, ast::Expr)> { - let record_field = name_ref.syntax().ancestors().find_map(ast::RecordExprField::cast)?; +fn find_record_expr_usage( + name: &ast::NameLike, + got_field: hir::Field, + target_definition: Definition, +) -> Option<(ast::RecordExprField, ast::Expr)> { + let name_ref = name.as_name_ref()?; + let record_field = ast::RecordExprField::for_field_name(name_ref)?; let initializer = record_field.expr()?; - if record_field.field_name()?.syntax().descendants().contains(name_ref.syntax()) { - Some((record_field, initializer)) - } else { - cov_mark::hit!(dont_overwrite_wrong_record_field); - None + if let Definition::Field(expected_field) = target_definition { + if got_field != expected_field { + return None; + } + } + + Some((record_field, initializer)) +} + +fn find_record_pat_field_usage(name: &ast::NameLike) -> Option { + let record_pat_field = name.syntax().parent().and_then(ast::RecordPatField::cast)?; + let pat = record_pat_field.pat()?; + + match pat { + ast::Pat::IdentPat(_) | ast::Pat::LiteralPat(_) | ast::Pat::WildcardPat(_) => Some(pat), + _ => None, } } @@ -317,7 +427,7 @@ fn add_enum_def( .filter_map(|FileReference { name, .. }| { ctx.sema.scope(name.syntax()).map(|scope| scope.module()) }) - .any(|module| &module != target_module); + .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module); let enum_def = make_bool_enum(make_enum_pub); let indent = IndentLevel::from_node(&target_node); @@ -646,7 +756,7 @@ fn main() { } #[test] - fn field_basic() { + fn field_struct_basic() { cov_mark::check!(replaces_record_expr); check_assist( bool_to_enum, @@ -684,6 +794,263 @@ fn main() { ) } + #[test] + fn field_enum_basic() { + cov_mark::check!(replaces_record_pat); + check_assist( + bool_to_enum, + r#" +enum Foo { + Foo, + Bar { $0bar: bool }, +} + +fn main() { + let foo = Foo::Bar { bar: true }; + + if let Foo::Bar { bar: baz } = foo { + if baz { + println!("foo"); + } + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +enum Foo { + Foo, + Bar { bar: Bool }, +} + +fn main() { + let foo = Foo::Bar { bar: Bool::True }; + + if let Foo::Bar { bar: baz } = foo { + if baz == Bool::True { + println!("foo"); + } + } +} +"#, + ) + } + + #[test] + fn field_enum_cross_file() { + check_assist( + bool_to_enum, + r#" +//- /foo.rs +pub enum Foo { + Foo, + Bar { $0bar: bool }, +} + +fn foo() { + let foo = Foo::Bar { bar: true }; +} + +//- /main.rs +use foo::Foo; + +mod foo; + +fn main() { + let foo = Foo::Bar { bar: false }; +} +"#, + r#" +//- /foo.rs +#[derive(PartialEq, Eq)] +pub enum Bool { True, False } + +pub enum Foo { + Foo, + Bar { bar: Bool }, +} + +fn foo() { + let foo = Foo::Bar { bar: Bool::True }; +} + +//- /main.rs +use foo::{Foo, Bool}; + +mod foo; + +fn main() { + let foo = Foo::Bar { bar: Bool::False }; +} +"#, + ) + } + + #[test] + fn field_enum_shorthand() { + cov_mark::check!(replaces_record_pat_shorthand); + check_assist( + bool_to_enum, + r#" +enum Foo { + Foo, + Bar { $0bar: bool }, +} + +fn main() { + let foo = Foo::Bar { bar: true }; + + match foo { + Foo::Bar { bar } => { + if bar { + println!("foo"); + } + } + _ => (), + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +enum Foo { + Foo, + Bar { bar: Bool }, +} + +fn main() { + let foo = Foo::Bar { bar: Bool::True }; + + match foo { + Foo::Bar { bar } => { + if bar == Bool::True { + println!("foo"); + } + } + _ => (), + } +} +"#, + ) + } + + #[test] + fn field_enum_replaces_literal_patterns() { + cov_mark::check!(replaces_literal_pat); + check_assist( + bool_to_enum, + r#" +enum Foo { + Foo, + Bar { $0bar: bool }, +} + +fn main() { + let foo = Foo::Bar { bar: true }; + + if let Foo::Bar { bar: true } = foo { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +enum Foo { + Foo, + Bar { bar: Bool }, +} + +fn main() { + let foo = Foo::Bar { bar: Bool::True }; + + if let Foo::Bar { bar: Bool::True } = foo { + println!("foo"); + } +} +"#, + ) + } + + #[test] + fn field_enum_keeps_wildcard_patterns() { + check_assist( + bool_to_enum, + r#" +enum Foo { + Foo, + Bar { $0bar: bool }, +} + +fn main() { + let foo = Foo::Bar { bar: true }; + + if let Foo::Bar { bar: _ } = foo { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +enum Foo { + Foo, + Bar { bar: Bool }, +} + +fn main() { + let foo = Foo::Bar { bar: Bool::True }; + + if let Foo::Bar { bar: _ } = foo { + println!("foo"); + } +} +"#, + ) + } + + #[test] + fn field_union_basic() { + check_assist( + bool_to_enum, + r#" +union Foo { + $0foo: bool, + bar: usize, +} + +fn main() { + let foo = Foo { foo: true }; + + if unsafe { foo.foo } { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +union Foo { + foo: Bool, + bar: usize, +} + +fn main() { + let foo = Foo { foo: Bool::True }; + + if unsafe { foo.foo == Bool::True } { + println!("foo"); + } +} +"#, + ) + } + #[test] fn field_negated() { check_assist( @@ -841,7 +1208,6 @@ fn main() { #[test] fn field_initialized_with_other() { - cov_mark::check!(dont_overwrite_wrong_record_field); check_assist( bool_to_enum, r#" From 93562dd5bdec82c70b8eff05c59badb4314c90c8 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 22 Sep 2023 08:53:24 +0200 Subject: [PATCH 10/10] Use parent + and_then instead of ancestors --- crates/ide-assists/src/handlers/bool_to_enum.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index b9dbd6e98f..85b0b87d0c 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -111,7 +111,7 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option { initializer: let_stmt.initializer(), definition: Definition::Local(def), }) - } else if let Some(const_) = name.syntax().ancestors().find_map(ast::Const::cast) { + } else if let Some(const_) = name.syntax().parent().and_then(ast::Const::cast) { let def = ctx.sema.to_def(&const_)?; if !def.ty(ctx.db()).is_bool() { cov_mark::hit!(not_applicable_non_bool_const); @@ -125,7 +125,7 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option { initializer: const_.body(), definition: Definition::Const(def), }) - } else if let Some(static_) = name.syntax().ancestors().find_map(ast::Static::cast) { + } else if let Some(static_) = name.syntax().parent().and_then(ast::Static::cast) { let def = ctx.sema.to_def(&static_)?; if !def.ty(ctx.db()).is_bool() { cov_mark::hit!(not_applicable_non_bool_static); @@ -140,7 +140,7 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option { definition: Definition::Static(def), }) } else { - let field = name.syntax().ancestors().find_map(ast::RecordField::cast)?; + let field = name.syntax().parent().and_then(ast::RecordField::cast)?; if field.name()? != name { return None; }