mirror of
				https://github.com/rust-lang/rust-analyzer.git
				synced 2025-11-03 13:13:18 +00:00 
			
		
		
		
	fix multiple definition binding in match to let-else
This commit is contained in:
		
							parent
							
								
									38e9a110d4
								
							
						
					
					
						commit
						811190b913
					
				@ -1,6 +1,6 @@
 | 
				
			|||||||
use ide_db::defs::{Definition, NameRefClass};
 | 
					use ide_db::defs::{Definition, NameRefClass};
 | 
				
			||||||
use syntax::{
 | 
					use syntax::{
 | 
				
			||||||
    ast::{self, HasName},
 | 
					    ast::{self, HasName, Name},
 | 
				
			||||||
    ted, AstNode, SyntaxNode,
 | 
					    ted, AstNode, SyntaxNode,
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -48,7 +48,7 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'
 | 
				
			|||||||
        other => format!("{{ {other} }}"),
 | 
					        other => format!("{{ {other} }}"),
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
    let extracting_arm_pat = extracting_arm.pat()?;
 | 
					    let extracting_arm_pat = extracting_arm.pat()?;
 | 
				
			||||||
    let extracted_variable = find_extracted_variable(ctx, &extracting_arm)?;
 | 
					    let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    acc.add(
 | 
					    acc.add(
 | 
				
			||||||
        AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite),
 | 
					        AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite),
 | 
				
			||||||
@ -56,7 +56,7 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'
 | 
				
			|||||||
        let_stmt.syntax().text_range(),
 | 
					        let_stmt.syntax().text_range(),
 | 
				
			||||||
        |builder| {
 | 
					        |builder| {
 | 
				
			||||||
            let extracting_arm_pat =
 | 
					            let extracting_arm_pat =
 | 
				
			||||||
                rename_variable(&extracting_arm_pat, extracted_variable, binding);
 | 
					                rename_variable(&extracting_arm_pat, &extracted_variable_positions, binding);
 | 
				
			||||||
            builder.replace(
 | 
					            builder.replace(
 | 
				
			||||||
                let_stmt.syntax().text_range(),
 | 
					                let_stmt.syntax().text_range(),
 | 
				
			||||||
                format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"),
 | 
					                format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"),
 | 
				
			||||||
@ -95,14 +95,15 @@ fn find_arms(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Given an extracting arm, find the extracted variable.
 | 
					// Given an extracting arm, find the extracted variable.
 | 
				
			||||||
fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<ast::Name> {
 | 
					fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<Vec<Name>> {
 | 
				
			||||||
    match arm.expr()? {
 | 
					    match arm.expr()? {
 | 
				
			||||||
        ast::Expr::PathExpr(path) => {
 | 
					        ast::Expr::PathExpr(path) => {
 | 
				
			||||||
            let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
 | 
					            let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
 | 
				
			||||||
            match NameRefClass::classify(&ctx.sema, &name_ref)? {
 | 
					            match NameRefClass::classify(&ctx.sema, &name_ref)? {
 | 
				
			||||||
                NameRefClass::Definition(Definition::Local(local)) => {
 | 
					                NameRefClass::Definition(Definition::Local(local)) => {
 | 
				
			||||||
                    let source = local.primary_source(ctx.db()).into_ident_pat()?;
 | 
					                    let source =
 | 
				
			||||||
                    Some(source.name()?)
 | 
					                        local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name());
 | 
				
			||||||
 | 
					                    source.collect()
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
                _ => None,
 | 
					                _ => None,
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
@ -115,27 +116,34 @@ fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Opti
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Rename `extracted` with `binding` in `pat`.
 | 
					// Rename `extracted` with `binding` in `pat`.
 | 
				
			||||||
fn rename_variable(pat: &ast::Pat, extracted: ast::Name, binding: ast::Pat) -> SyntaxNode {
 | 
					fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode {
 | 
				
			||||||
    let syntax = pat.syntax().clone_for_update();
 | 
					    let syntax = pat.syntax().clone_for_update();
 | 
				
			||||||
    let extracted_syntax = syntax.covering_element(extracted.syntax().text_range());
 | 
					    let extracted = extracted
 | 
				
			||||||
 | 
					        .iter()
 | 
				
			||||||
 | 
					        .map(|e| syntax.covering_element(e.syntax().text_range()))
 | 
				
			||||||
 | 
					        .collect::<Vec<_>>();
 | 
				
			||||||
 | 
					    for extracted_syntax in extracted {
 | 
				
			||||||
        // If `extracted` variable is a record field, we should rename it to `binding`,
 | 
					        // If `extracted` variable is a record field, we should rename it to `binding`,
 | 
				
			||||||
        // otherwise we just need to replace `extracted` with `binding`.
 | 
					        // otherwise we just need to replace `extracted` with `binding`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if let Some(record_pat_field) = extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
 | 
					        if let Some(record_pat_field) =
 | 
				
			||||||
 | 
					            extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            if let Some(name_ref) = record_pat_field.field_name() {
 | 
					            if let Some(name_ref) = record_pat_field.field_name() {
 | 
				
			||||||
                ted::replace(
 | 
					                ted::replace(
 | 
				
			||||||
                    record_pat_field.syntax(),
 | 
					                    record_pat_field.syntax(),
 | 
				
			||||||
                ast::make::record_pat_field(ast::make::name_ref(&name_ref.text()), binding)
 | 
					                    ast::make::record_pat_field(
 | 
				
			||||||
 | 
					                        ast::make::name_ref(&name_ref.text()),
 | 
				
			||||||
 | 
					                        binding.clone(),
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
                    .syntax()
 | 
					                    .syntax()
 | 
				
			||||||
                    .clone_for_update(),
 | 
					                    .clone_for_update(),
 | 
				
			||||||
                );
 | 
					                );
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
        ted::replace(extracted_syntax, binding.syntax().clone_for_update());
 | 
					            ted::replace(extracted_syntax, binding.clone().syntax().clone_for_update());
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    syntax
 | 
					    syntax
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -162,6 +170,39 @@ fn foo(opt: Option<()>) {
 | 
				
			|||||||
        );
 | 
					        );
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #[test]
 | 
				
			||||||
 | 
					    fn or_pattern_multiple_binding() {
 | 
				
			||||||
 | 
					        check_assist(
 | 
				
			||||||
 | 
					            convert_match_to_let_else,
 | 
				
			||||||
 | 
					            r#"
 | 
				
			||||||
 | 
					//- minicore: option
 | 
				
			||||||
 | 
					enum Foo {
 | 
				
			||||||
 | 
					    A(u32),
 | 
				
			||||||
 | 
					    B(u32),
 | 
				
			||||||
 | 
					    C(String),
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					fn foo(opt: Option<Foo>) -> Result<u32, ()> {
 | 
				
			||||||
 | 
					    let va$0lue = match opt {
 | 
				
			||||||
 | 
					        Some(Foo::A(it) | Foo::B(it)) => it,
 | 
				
			||||||
 | 
					        _ => return Err(()),
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					    "#,
 | 
				
			||||||
 | 
					            r#"
 | 
				
			||||||
 | 
					enum Foo {
 | 
				
			||||||
 | 
					    A(u32),
 | 
				
			||||||
 | 
					    B(u32),
 | 
				
			||||||
 | 
					    C(String),
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					fn foo(opt: Option<Foo>) -> Result<u32, ()> {
 | 
				
			||||||
 | 
					    let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					    "#,
 | 
				
			||||||
 | 
					        );
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #[test]
 | 
					    #[test]
 | 
				
			||||||
    fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
 | 
					    fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
 | 
				
			||||||
        cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
 | 
					        cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user