From 2eb7389b6367c0354b30faa5e8ef4fe09a7e87c4 Mon Sep 17 00:00:00 2001 From: Prajwal S N Date: Tue, 8 Apr 2025 14:18:09 +0530 Subject: [PATCH] refactor: migrate `let_else_to_match` to editor Signed-off-by: Prajwal S N --- .../src/handlers/add_missing_match_arms.rs | 4 +- .../src/handlers/convert_let_else_to_match.rs | 299 ++++++++++-------- .../handlers/destructure_struct_binding.rs | 4 +- .../src/handlers/expand_rest_pattern.rs | 7 +- .../src/handlers/unmerge_match_arm.rs | 10 +- .../src/handlers/missing_fields.rs | 11 +- crates/syntax/src/ast/make.rs | 32 +- .../src/ast/syntax_factory/constructors.rs | 110 ++++++- 8 files changed, 335 insertions(+), 142 deletions(-) diff --git a/crates/ide-assists/src/handlers/add_missing_match_arms.rs b/crates/ide-assists/src/handlers/add_missing_match_arms.rs index 05d21cb979..858d436991 100644 --- a/crates/ide-assists/src/handlers/add_missing_match_arms.rs +++ b/crates/ide-assists/src/handlers/add_missing_match_arms.rs @@ -493,8 +493,8 @@ fn build_pat( hir::StructKind::Record => { let fields = fields .into_iter() - .map(|f| make.name_ref(f.name(db).as_str())) - .map(|name_ref| make.record_pat_field_shorthand(name_ref)); + .map(|f| make.ident_pat(false, false, make.name(f.name(db).as_str()))) + .map(|ident| make.record_pat_field_shorthand(ident.into())); let fields = make.record_pat_field_list(fields, None); make.record_pat_with_fields(path, fields).into() } diff --git a/crates/ide-assists/src/handlers/convert_let_else_to_match.rs b/crates/ide-assists/src/handlers/convert_let_else_to_match.rs index df92b07cba..ebfed9f9ca 100644 --- a/crates/ide-assists/src/handlers/convert_let_else_to_match.rs +++ b/crates/ide-assists/src/handlers/convert_let_else_to_match.rs @@ -1,8 +1,9 @@ -use hir::Semantics; -use ide_db::RootDatabase; use syntax::T; use syntax::ast::RangeItem; -use syntax::ast::{AstNode, HasName, LetStmt, Name, Pat, edit::AstNodeEdit}; +use syntax::ast::edit::IndentLevel; +use syntax::ast::edit_in_place::Indent; +use syntax::ast::syntax_factory::SyntaxFactory; +use syntax::ast::{self, AstNode, HasName, LetStmt, Pat}; use crate::{AssistContext, AssistId, Assists}; @@ -25,155 +26,205 @@ use crate::{AssistContext, AssistId, Assists}; // } // ``` pub(crate) fn convert_let_else_to_match(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { - // should focus on else token to trigger + // Should focus on the `else` token to trigger let let_stmt = ctx .find_token_syntax_at_offset(T![else]) .and_then(|it| it.parent()?.parent()) .or_else(|| ctx.find_token_syntax_at_offset(T![let])?.parent())?; let let_stmt = LetStmt::cast(let_stmt)?; - let let_else_block = let_stmt.let_else()?.block_expr()?; - let let_init = let_stmt.initializer()?; + let else_block = let_stmt.let_else()?.block_expr()?; + let else_expr = if else_block.statements().next().is_none() { + else_block.tail_expr()? + } else { + else_block.into() + }; + let init = let_stmt.initializer()?; + // Ignore let stmt with type annotation if let_stmt.ty().is_some() { - // don't support let with type annotation return None; } let pat = let_stmt.pat()?; - let mut binders = Vec::new(); - binders_in_pat(&mut binders, &pat, &ctx.sema)?; - let target = let_stmt.syntax().text_range(); + let make = SyntaxFactory::with_mappings(); + let mut idents = Vec::default(); + let pat_without_mut = remove_mut_and_collect_idents(&make, &pat, &mut idents)?; + let bindings = idents + .into_iter() + .filter_map(|ref pat| { + // Identifiers which resolve to constants are not bindings + if ctx.sema.resolve_bind_pat_to_const(pat).is_none() { + Some((pat.name()?, pat.ref_token().is_none() && pat.mut_token().is_some())) + } else { + None + } + }) + .collect::>(); + acc.add( AssistId::refactor_rewrite("convert_let_else_to_match"), - "Convert let-else to let and match", - target, - |edit| { - let indent_level = let_stmt.indent_level().0 as usize; - let indent = " ".repeat(indent_level); - let indent1 = " ".repeat(indent_level + 1); + if bindings.is_empty() { + "Convert let-else to match" + } else { + "Convert let-else to let and match" + }, + let_stmt.syntax().text_range(), + |builder| { + let mut editor = builder.make_editor(let_stmt.syntax()); - let binders_str = binders_to_str(&binders, false); - let binders_str_mut = binders_to_str(&binders, true); + let binding_paths = bindings + .iter() + .map(|(name, _)| make.expr_path(make.ident_path(&name.to_string()))) + .collect::>(); - let init_expr = let_init.syntax().text(); - let mut pat_no_mut = pat.syntax().text().to_string(); - // remove the mut from the pattern - for (b, ismut) in binders.iter() { - if *ismut { - pat_no_mut = pat_no_mut.replace(&format!("mut {b}"), &b.to_string()); - } + let binding_arm = make.match_arm( + pat_without_mut, + None, + // There are three possible cases: + // + // - No bindings: `None => {}` + // - Single binding: `Some(it) => it` + // - Multiple bindings: `Foo::Bar { a, b, .. } => (a, b)` + match binding_paths.len() { + 0 => make.expr_empty_block().into(), + + 1 => binding_paths[0].clone(), + _ => make.expr_tuple(binding_paths).into(), + }, + ); + let else_arm = make.match_arm(make.wildcard_pat().into(), None, else_expr); + let match_ = make.expr_match(init, make.match_arm_list([binding_arm, else_arm])); + match_.reindent_to(IndentLevel::from_node(let_stmt.syntax())); + + if bindings.is_empty() { + editor.replace(let_stmt.syntax(), match_.syntax()); + } else { + let ident_pats = bindings + .into_iter() + .map(|(name, is_mut)| make.ident_pat(false, is_mut, name).into()) + .collect::>(); + let new_let_stmt = make.let_stmt( + if ident_pats.len() == 1 { + ident_pats[0].clone() + } else { + make.tuple_pat(ident_pats).into() + }, + None, + Some(match_.into()), + ); + editor.replace(let_stmt.syntax(), new_let_stmt.syntax()); } - let only_expr = let_else_block.statements().next().is_none(); - let branch2 = match &let_else_block.tail_expr() { - Some(tail) if only_expr => format!("{tail},"), - _ => let_else_block.syntax().text().to_string(), - }; - let replace = if binders.is_empty() { - format!( - "match {init_expr} {{ -{indent1}{pat_no_mut} => {binders_str} -{indent1}_ => {branch2} -{indent}}}" - ) - } else { - format!( - "let {binders_str_mut} = match {init_expr} {{ -{indent1}{pat_no_mut} => {binders_str}, -{indent1}_ => {branch2} -{indent}}};" - ) - }; - edit.replace(target, replace); + editor.add_mappings(make.finish_with_mappings()); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } -/// Gets a list of binders in a pattern, and whether they are mut. -fn binders_in_pat( - acc: &mut Vec<(Name, bool)>, - pat: &Pat, - sem: &Semantics<'_, RootDatabase>, -) -> Option<()> { - use Pat::*; - match pat { - IdentPat(p) => { - let ident = p.name()?; - let ismut = p.ref_token().is_none() && p.mut_token().is_some(); - // check for const reference - if sem.resolve_bind_pat_to_const(p).is_none() { - acc.push((ident, ismut)); - } +fn remove_mut_and_collect_idents( + make: &SyntaxFactory, + pat: &ast::Pat, + acc: &mut Vec, +) -> Option { + Some(match pat { + ast::Pat::IdentPat(p) => { + acc.push(p.clone()); + let non_mut_pat = make.ident_pat( + p.ref_token().is_some(), + p.ref_token().is_some() && p.mut_token().is_some(), + p.name()?, + ); if let Some(inner) = p.pat() { - binders_in_pat(acc, &inner, sem)?; + non_mut_pat.set_pat(remove_mut_and_collect_idents(make, &inner, acc)); } - Some(()) + non_mut_pat.into() } - BoxPat(p) => p.pat().and_then(|p| binders_in_pat(acc, &p, sem)), - RestPat(_) | LiteralPat(_) | PathPat(_) | WildcardPat(_) | ConstBlockPat(_) => Some(()), - OrPat(p) => { - for p in p.pats() { - binders_in_pat(acc, &p, sem)?; - } - Some(()) + ast::Pat::BoxPat(p) => { + make.box_pat(remove_mut_and_collect_idents(make, &p.pat()?, acc)?).into() } - ParenPat(p) => p.pat().and_then(|p| binders_in_pat(acc, &p, sem)), - RangePat(p) => { - if let Some(st) = p.start() { - binders_in_pat(acc, &st, sem)? - } - if let Some(ed) = p.end() { - binders_in_pat(acc, &ed, sem)? - } - Some(()) + ast::Pat::OrPat(p) => make + .or_pat( + p.pats() + .map(|pat| remove_mut_and_collect_idents(make, &pat, acc)) + .collect::>>()?, + p.leading_pipe().is_some(), + ) + .into(), + ast::Pat::ParenPat(p) => { + make.paren_pat(remove_mut_and_collect_idents(make, &p.pat()?, acc)?).into() } - RecordPat(p) => { - for f in p.record_pat_field_list()?.fields() { - let pat = f.pat()?; - binders_in_pat(acc, &pat, sem)?; + ast::Pat::RangePat(p) => make + .range_pat( + if let Some(start) = p.start() { + Some(remove_mut_and_collect_idents(make, &start, acc)?) + } else { + None + }, + if let Some(end) = p.end() { + Some(remove_mut_and_collect_idents(make, &end, acc)?) + } else { + None + }, + ) + .into(), + ast::Pat::RecordPat(p) => make + .record_pat_with_fields( + p.path()?, + make.record_pat_field_list( + p.record_pat_field_list()? + .fields() + .map(|field| { + remove_mut_and_collect_idents(make, &field.pat()?, acc).map(|pat| { + if let Some(name_ref) = field.name_ref() { + make.record_pat_field(name_ref, pat) + } else { + make.record_pat_field_shorthand(pat) + } + }) + }) + .collect::>>()?, + p.record_pat_field_list()?.rest_pat(), + ), + ) + .into(), + ast::Pat::RefPat(p) => { + let inner = p.pat()?; + if let ast::Pat::IdentPat(ident) = inner { + acc.push(ident); + p.clone_for_update().into() + } else { + make.ref_pat(remove_mut_and_collect_idents(make, &inner, acc)?).into() } - Some(()) - } - RefPat(p) => p.pat().and_then(|p| binders_in_pat(acc, &p, sem)), - SlicePat(p) => { - for p in p.pats() { - binders_in_pat(acc, &p, sem)?; - } - Some(()) - } - TuplePat(p) => { - for p in p.fields() { - binders_in_pat(acc, &p, sem)?; - } - Some(()) - } - TupleStructPat(p) => { - for p in p.fields() { - binders_in_pat(acc, &p, sem)?; - } - Some(()) } + ast::Pat::SlicePat(p) => make + .slice_pat( + p.pats() + .map(|pat| remove_mut_and_collect_idents(make, &pat, acc)) + .collect::>>()?, + ) + .into(), + ast::Pat::TuplePat(p) => make + .tuple_pat( + p.fields() + .map(|field| remove_mut_and_collect_idents(make, &field, acc)) + .collect::>>()?, + ) + .into(), + ast::Pat::TupleStructPat(p) => make + .tuple_struct_pat( + p.path()?, + p.fields() + .map(|field| remove_mut_and_collect_idents(make, &field, acc)) + .collect::>>()?, + ) + .into(), + ast::Pat::RestPat(_) + | ast::Pat::LiteralPat(_) + | ast::Pat::PathPat(_) + | ast::Pat::WildcardPat(_) + | ast::Pat::ConstBlockPat(_) => pat.clone(), // don't support macro pat yet - MacroPat(_) => None, - } -} - -fn binders_to_str(binders: &[(Name, bool)], addmut: bool) -> String { - let vars = binders - .iter() - .map( - |(ident, ismut)| { - if *ismut && addmut { format!("mut {ident}") } else { ident.to_string() } - }, - ) - .collect::>() - .join(", "); - if binders.is_empty() { - String::from("{}") - } else if binders.len() == 1 { - vars - } else { - format!("({vars})") - } + ast::Pat::MacroPat(_) => return None, + }) } #[cfg(test)] diff --git a/crates/ide-assists/src/handlers/destructure_struct_binding.rs b/crates/ide-assists/src/handlers/destructure_struct_binding.rs index 800ef89ac6..b8c647ac8b 100644 --- a/crates/ide-assists/src/handlers/destructure_struct_binding.rs +++ b/crates/ide-assists/src/handlers/destructure_struct_binding.rs @@ -196,7 +196,9 @@ fn destructure_pat( let fields = field_names.iter().map(|(old_name, new_name)| { // Use shorthand syntax if possible if old_name == new_name && !is_mut { - make.record_pat_field_shorthand(make.name_ref(old_name)) + make.record_pat_field_shorthand( + make.ident_pat(false, false, make.name(old_name)).into(), + ) } else { make.record_pat_field( make.name_ref(old_name), diff --git a/crates/ide-assists/src/handlers/expand_rest_pattern.rs b/crates/ide-assists/src/handlers/expand_rest_pattern.rs index 4e487e2162..b71de5e00c 100644 --- a/crates/ide-assists/src/handlers/expand_rest_pattern.rs +++ b/crates/ide-assists/src/handlers/expand_rest_pattern.rs @@ -56,7 +56,12 @@ fn expand_record_rest_pattern( let new_field_list = make.record_pat_field_list(old_field_list.fields(), None); for (f, _) in missing_fields.iter() { let field = make.record_pat_field_shorthand( - make.name_ref(&f.name(ctx.sema.db).display_no_db(edition).to_smolstr()), + make.ident_pat( + false, + false, + make.name(&f.name(ctx.sema.db).display_no_db(edition).to_smolstr()), + ) + .into(), ); new_field_list.add_field(field); } diff --git a/crates/ide-assists/src/handlers/unmerge_match_arm.rs b/crates/ide-assists/src/handlers/unmerge_match_arm.rs index 31ff47a054..5aedff5cc7 100644 --- a/crates/ide-assists/src/handlers/unmerge_match_arm.rs +++ b/crates/ide-assists/src/handlers/unmerge_match_arm.rs @@ -53,8 +53,14 @@ pub(crate) fn unmerge_match_arm(acc: &mut Assists, ctx: &AssistContext<'_>) -> O |edit| { let pats_after = pipe_token .siblings_with_tokens(Direction::Next) - .filter_map(|it| ast::Pat::cast(it.into_node()?)); - let new_pat = make::or_pat(pats_after, or_pat.leading_pipe().is_some()); + .filter_map(|it| ast::Pat::cast(it.into_node()?)) + .collect::>(); + // It is guaranteed that `pats_after` has at least one element + let new_pat = if pats_after.len() == 1 { + pats_after[0].clone() + } else { + make::or_pat(pats_after, or_pat.leading_pipe().is_some()).into() + }; let new_match_arm = make::match_arm(new_pat, match_arm.guard(), match_arm_body).clone_for_update(); diff --git a/crates/ide-diagnostics/src/handlers/missing_fields.rs b/crates/ide-diagnostics/src/handlers/missing_fields.rs index 9aea2b1056..a354d123f5 100644 --- a/crates/ide-diagnostics/src/handlers/missing_fields.rs +++ b/crates/ide-diagnostics/src/handlers/missing_fields.rs @@ -163,9 +163,14 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingFields) -> Option ast::WildcardPat { } pub fn rest_pat() -> ast::RestPat { - ast_from_text("fn f(..)") + ast_from_text("fn f() { let ..; }") } pub fn literal_pat(lit: &str) -> ast::LiteralPat { @@ -788,8 +788,8 @@ pub fn record_pat_field(name_ref: ast::NameRef, pat: ast::Pat) -> ast::RecordPat ast_from_text(&format!("fn f(S {{ {name_ref}: {pat} }}: ()))")) } -pub fn record_pat_field_shorthand(name_ref: ast::NameRef) -> ast::RecordPatField { - ast_from_text(&format!("fn f(S {{ {name_ref} }}: ()))")) +pub fn record_pat_field_shorthand(pat: ast::Pat) -> ast::RecordPatField { + ast_from_text(&format!("fn f(S {{ {pat} }}: ()))")) } /// Returns a `IdentPat` if the path has just one segment, a `PathPat` otherwise. @@ -801,16 +801,38 @@ pub fn path_pat(path: ast::Path) -> ast::Pat { } /// Returns a `Pat` if the path has just one segment, an `OrPat` otherwise. -pub fn or_pat(pats: impl IntoIterator, leading_pipe: bool) -> ast::Pat { +/// +/// Invariant: `pats` must be length > 1. +pub fn or_pat(pats: impl IntoIterator, leading_pipe: bool) -> ast::OrPat { let leading_pipe = if leading_pipe { "| " } else { "" }; let pats = pats.into_iter().join(" | "); return from_text(&format!("{leading_pipe}{pats}")); - fn from_text(text: &str) -> ast::Pat { + fn from_text(text: &str) -> ast::OrPat { ast_from_text(&format!("fn f({text}: ())")) } } +pub fn box_pat(pat: ast::Pat) -> ast::BoxPat { + ast_from_text(&format!("fn f(box {pat}: ())")) +} + +pub fn paren_pat(pat: ast::Pat) -> ast::ParenPat { + ast_from_text(&format!("fn f(({pat}): ())")) +} + +pub fn range_pat(start: Option, end: Option) -> ast::RangePat { + ast_from_text(&format!( + "fn f({}..{}: ())", + start.map(|e| e.to_string()).unwrap_or_default(), + end.map(|e| e.to_string()).unwrap_or_default() + )) +} + +pub fn ref_pat(pat: ast::Pat) -> ast::RefPat { + ast_from_text(&format!("fn f(&{pat}: ())")) +} + pub fn match_arm(pat: ast::Pat, guard: Option, expr: ast::Expr) -> ast::MatchArm { return match guard { Some(guard) => from_text(&format!("{pat} {guard} => {expr}")), diff --git a/crates/syntax/src/ast/syntax_factory/constructors.rs b/crates/syntax/src/ast/syntax_factory/constructors.rs index 1854000d3d..3b205516c2 100644 --- a/crates/syntax/src/ast/syntax_factory/constructors.rs +++ b/crates/syntax/src/ast/syntax_factory/constructors.rs @@ -3,7 +3,7 @@ use crate::{ AstNode, NodeOrToken, SyntaxKind, SyntaxNode, SyntaxToken, ast::{ self, HasArgList, HasGenericArgs, HasGenericParams, HasLoopBody, HasName, HasTypeBounds, - HasVisibility, make, + HasVisibility, RangeItem, make, }, syntax_editor::SyntaxMappingBuilder, }; @@ -254,12 +254,12 @@ impl SyntaxFactory { ast } - pub fn record_pat_field_shorthand(&self, name_ref: ast::NameRef) -> ast::RecordPatField { - let ast = make::record_pat_field_shorthand(name_ref.clone()).clone_for_update(); + pub fn record_pat_field_shorthand(&self, pat: ast::Pat) -> ast::RecordPatField { + let ast = make::record_pat_field_shorthand(pat.clone()).clone_for_update(); if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_node(name_ref.syntax().clone(), ast.pat().unwrap().syntax().clone()); + builder.map_node(pat.syntax().clone(), ast.pat().unwrap().syntax().clone()); builder.finish(&mut mapping); } @@ -294,6 +294,76 @@ impl SyntaxFactory { make::rest_pat().clone_for_update() } + pub fn or_pat( + &self, + pats: impl IntoIterator, + leading_pipe: bool, + ) -> ast::OrPat { + let (pats, input) = iterator_input(pats); + let ast = make::or_pat(pats, leading_pipe).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_children(input, ast.pats().map(|it| it.syntax().clone())); + builder.finish(&mut mapping); + } + + ast + } + + pub fn box_pat(&self, pat: ast::Pat) -> ast::BoxPat { + let ast = make::box_pat(pat.clone()).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(pat.syntax().clone(), ast.pat().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + + pub fn paren_pat(&self, pat: ast::Pat) -> ast::ParenPat { + let ast = make::paren_pat(pat.clone()).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(pat.syntax().clone(), ast.pat().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + + pub fn range_pat(&self, start: Option, end: Option) -> ast::RangePat { + let ast = make::range_pat(start.clone(), end.clone()).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + if let Some(start) = start { + builder.map_node(start.syntax().clone(), ast.start().unwrap().syntax().clone()); + } + if let Some(end) = end { + builder.map_node(end.syntax().clone(), ast.end().unwrap().syntax().clone()); + } + builder.finish(&mut mapping); + } + + ast + } + + pub fn ref_pat(&self, pat: ast::Pat) -> ast::RefPat { + let ast = make::ref_pat(pat.clone()).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(pat.syntax().clone(), ast.pat().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + pub fn block_expr( &self, statements: impl IntoIterator, @@ -673,6 +743,38 @@ impl SyntaxFactory { ast } + pub fn let_else_stmt( + &self, + pattern: ast::Pat, + ty: Option, + initializer: ast::Expr, + diverging: ast::BlockExpr, + ) -> ast::LetStmt { + let ast = make::let_else_stmt( + pattern.clone(), + ty.clone(), + initializer.clone(), + diverging.clone(), + ) + .clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(pattern.syntax().clone(), ast.pat().unwrap().syntax().clone()); + if let Some(input) = ty { + builder.map_node(input.syntax().clone(), ast.ty().unwrap().syntax().clone()); + } + builder.map_node( + initializer.syntax().clone(), + ast.initializer().unwrap().syntax().clone(), + ); + builder.map_node(diverging.syntax().clone(), ast.let_else().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + pub fn type_arg(&self, ty: ast::Type) -> ast::TypeArg { let ast = make::type_arg(ty.clone()).clone_for_update();