Empower replace_if_let_with_match

This commit is contained in:
Lukas Wirth 2021-07-02 00:20:27 +02:00
parent 1b9b2d1f40
commit 20be999304
2 changed files with 119 additions and 72 deletions

View File

@ -1,4 +1,4 @@
use std::iter; use std::iter::{self, successors};
use ide_db::{ty_filter::TryEnum, RootDatabase}; use ide_db::{ty_filter::TryEnum, RootDatabase};
use syntax::{ use syntax::{
@ -17,7 +17,7 @@ use crate::{
// Assist: replace_if_let_with_match // Assist: replace_if_let_with_match
// //
// Replaces `if let` with an else branch with a `match` expression. // Replaces a `if let` expression with a `match` expression.
// //
// ``` // ```
// enum Action { Move { distance: u32 }, Stop } // enum Action { Move { distance: u32 }, Stop }
@ -43,14 +43,28 @@ use crate::{
// ``` // ```
pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
let if_expr: ast::IfExpr = ctx.find_node_at_offset()?; let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
let cond = if_expr.condition()?; let mut else_block = None;
let pat = cond.pat()?; let if_exprs = successors(Some(if_expr.clone()), |expr| match expr.else_branch()? {
let expr = cond.expr()?; ast::ElseBranch::IfExpr(expr) => Some(expr),
let then_block = if_expr.then_branch()?; ast::ElseBranch::Block(block) => {
let else_block = match if_expr.else_branch()? { else_block = Some(block);
ast::ElseBranch::Block(it) => it, None
ast::ElseBranch::IfExpr(_) => return None, }
}; });
let scrutinee_to_be_expr = if_expr.condition()?.expr()?;
let mut pat_bodies = Vec::new();
for if_expr in if_exprs {
let cond = if_expr.condition()?;
let expr = cond.expr()?;
if scrutinee_to_be_expr.syntax().text() != expr.syntax().text() {
// Only if all condition expressions are equal we can merge them into a match
return None;
}
let pat = cond.pat()?;
let body = if_expr.then_branch()?;
pat_bodies.push((pat, body));
}
let target = if_expr.syntax().text_range(); let target = if_expr.syntax().text_range();
acc.add( acc.add(
@ -59,33 +73,50 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
target, target,
move |edit| { move |edit| {
let match_expr = { let match_expr = {
let then_arm = {
let then_block = then_block.reset_indent().indent(IndentLevel(1));
let then_expr = unwrap_trivial_block(then_block);
make::match_arm(vec![pat.clone()], then_expr)
};
let else_arm = { let else_arm = {
let pattern = ctx match else_block {
.sema Some(else_block) => {
.type_of_pat(&pat) let pattern = match &*pat_bodies {
.and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty)) [(pat, _)] => ctx
.map(|it| { .sema
if does_pat_match_variant(&pat, &it.sad_pattern()) { .type_of_pat(&pat)
it.happy_pattern() .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty))
} else { .map(|it| {
it.sad_pattern() if does_pat_match_variant(&pat, &it.sad_pattern()) {
it.happy_pattern()
} else {
it.sad_pattern()
}
}),
_ => None,
} }
}) .unwrap_or_else(|| make::wildcard_pat().into());
.unwrap_or_else(|| make::wildcard_pat().into()); make::match_arm(iter::once(pattern), unwrap_trivial_block(else_block))
let else_expr = unwrap_trivial_block(else_block); }
make::match_arm(vec![pattern], else_expr) None => make::match_arm(
iter::once(make::wildcard_pat().into()),
make::expr_unit().into(),
),
}
}; };
let match_expr = let arms = pat_bodies
make::expr_match(expr, make::match_arm_list(vec![then_arm, else_arm])); .into_iter()
.map(|(pat, body)| {
let body = body.reset_indent().indent(IndentLevel(1));
make::match_arm(vec![pat], unwrap_trivial_block(body))
})
.chain(iter::once(else_arm));
let match_expr = make::expr_match(scrutinee_to_be_expr, make::match_arm_list(arms));
match_expr.indent(IndentLevel::from_node(if_expr.syntax())) match_expr.indent(IndentLevel::from_node(if_expr.syntax()))
}; };
edit.replace_ast::<ast::Expr>(if_expr.into(), match_expr); let expr =
if if_expr.syntax().parent().map_or(false, |it| ast::IfExpr::can_cast(it.kind())) {
make::block_expr(None, Some(match_expr)).into()
} else {
match_expr
};
edit.replace_ast::<ast::Expr>(if_expr.into(), expr);
}, },
) )
} }
@ -182,7 +213,33 @@ mod tests {
use crate::tests::{check_assist, check_assist_target}; use crate::tests::{check_assist, check_assist_target};
#[test] #[test]
fn test_replace_if_let_with_match_unwraps_simple_expressions() { fn test_if_let_with_match_no_else() {
check_assist(
replace_if_let_with_match,
r#"
impl VariantData {
pub fn foo(&self) {
if $0let VariantData::Struct(..) = *self {
self.foo();
}
}
} "#,
r#"
impl VariantData {
pub fn foo(&self) {
match *self {
VariantData::Struct(..) => {
self.foo();
}
_ => (),
}
}
} "#,
)
}
#[test]
fn test_if_let_with_match_basic() {
check_assist( check_assist(
replace_if_let_with_match, replace_if_let_with_match,
r#" r#"
@ -190,8 +247,12 @@ impl VariantData {
pub fn is_struct(&self) -> bool { pub fn is_struct(&self) -> bool {
if $0let VariantData::Struct(..) = *self { if $0let VariantData::Struct(..) = *self {
true true
} else { } else if let VariantData::Tuple(..) = *self {
false false
} else {
bar(
123
)
} }
} }
} "#, } "#,
@ -200,7 +261,12 @@ impl VariantData {
pub fn is_struct(&self) -> bool { pub fn is_struct(&self) -> bool {
match *self { match *self {
VariantData::Struct(..) => true, VariantData::Struct(..) => true,
_ => false, VariantData::Tuple(..) => false,
_ => {
bar(
123
)
}
} }
} }
} "#, } "#,
@ -208,53 +274,35 @@ impl VariantData {
} }
#[test] #[test]
fn test_replace_if_let_with_match_doesnt_unwrap_multiline_expressions() { fn test_if_let_with_match_on_tail_if_let() {
check_assist( check_assist(
replace_if_let_with_match, replace_if_let_with_match,
r#" r#"
fn foo() {
if $0let VariantData::Struct(..) = a {
bar(
123
)
} else {
false
}
} "#,
r#"
fn foo() {
match a {
VariantData::Struct(..) => {
bar(
123
)
}
_ => false,
}
} "#,
)
}
#[test]
fn replace_if_let_with_match_target() {
check_assist_target(
replace_if_let_with_match,
r#"
impl VariantData { impl VariantData {
pub fn is_struct(&self) -> bool { pub fn is_struct(&self) -> bool {
if $0let VariantData::Struct(..) = *self { if let VariantData::Struct(..) = *self {
true true
} else if let$0 VariantData::Tuple(..) = *self {
false
} else { } else {
false false
} }
} }
} "#, } "#,
"if let VariantData::Struct(..) = *self { r#"
impl VariantData {
pub fn is_struct(&self) -> bool {
if let VariantData::Struct(..) = *self {
true true
} else { } else {
false match *self {
}", VariantData::Tuple(..) => false,
); _ => false,
}
}
}
} "#,
)
} }
#[test] #[test]

View File

@ -48,15 +48,14 @@ pub fn extract_trivial_expression(block: &ast::BlockExpr) -> Option<ast::Expr> {
return Some(expr); return Some(expr);
} }
// Unwrap `{ continue; }` // Unwrap `{ continue; }`
let (stmt,) = block.statements().next_tuple()?; let stmt = block.statements().next()?;
if let ast::Stmt::ExprStmt(expr_stmt) = stmt { if let ast::Stmt::ExprStmt(expr_stmt) = stmt {
if has_anything_else(expr_stmt.syntax()) { if has_anything_else(expr_stmt.syntax()) {
return None; return None;
} }
let expr = expr_stmt.expr()?; let expr = expr_stmt.expr()?;
match expr.syntax().kind() { if matches!(expr.syntax().kind(), CONTINUE_EXPR | BREAK_EXPR | RETURN_EXPR) {
CONTINUE_EXPR | BREAK_EXPR | RETURN_EXPR => return Some(expr), return Some(expr);
_ => (),
} }
} }
None None