diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 3b39707cf6..46d2e88160 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -1727,10 +1727,10 @@ impl Adt { pub fn ty_with_args<'db>( self, db: &'db dyn HirDatabase, - args: impl Iterator>, + args: impl IntoIterator>, ) -> Type<'db> { let id = AdtId::from(self); - let mut it = args.map(|t| t.ty); + let mut it = args.into_iter().map(|t| t.ty); let ty = TyBuilder::def_ty(db, id.into(), None) .fill(|x| { let r = it.next().unwrap_or_else(|| TyKind::Error.intern(Interner)); diff --git a/crates/ide-assists/src/handlers/wrap_return_type.rs b/crates/ide-assists/src/handlers/wrap_return_type.rs index 9ea78719b2..d7189aa5db 100644 --- a/crates/ide-assists/src/handlers/wrap_return_type.rs +++ b/crates/ide-assists/src/handlers/wrap_return_type.rs @@ -56,7 +56,8 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op }; let type_ref = &ret_type.ty()?; - let ty = ctx.sema.resolve_type(type_ref)?.as_adt(); + let ty = ctx.sema.resolve_type(type_ref)?; + let ty_adt = ty.as_adt(); let famous_defs = FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()); for kind in WrapperKind::ALL { @@ -64,7 +65,7 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op continue; }; - if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == core_wrapper) { + if matches!(ty_adt, Some(hir::Adt::Enum(ret_type)) if ret_type == core_wrapper) { // The return type is already wrapped cov_mark::hit!(wrap_return_type_simple_return_type_already_wrapped); continue; @@ -78,10 +79,23 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op |builder| { let mut editor = builder.make_editor(&parent); let make = SyntaxFactory::with_mappings(); - let alias = wrapper_alias(ctx, &make, &core_wrapper, type_ref, kind.symbol()); - let new_return_ty = alias.unwrap_or_else(|| match kind { - WrapperKind::Option => make.ty_option(type_ref.clone()), - WrapperKind::Result => make.ty_result(type_ref.clone(), make.ty_infer().into()), + let alias = wrapper_alias(ctx, &make, core_wrapper, type_ref, &ty, kind.symbol()); + let (ast_new_return_ty, semantic_new_return_ty) = alias.unwrap_or_else(|| { + let (ast_ty, ty_constructor) = match kind { + WrapperKind::Option => { + (make.ty_option(type_ref.clone()), famous_defs.core_option_Option()) + } + WrapperKind::Result => ( + make.ty_result(type_ref.clone(), make.ty_infer().into()), + famous_defs.core_result_Result(), + ), + }; + let semantic_ty = ty_constructor + .map(|ty_constructor| { + hir::Adt::from(ty_constructor).ty_with_args(ctx.db(), [ty.clone()]) + }) + .unwrap_or_else(|| ty.clone()); + (ast_ty, semantic_ty) }); let mut exprs_to_wrap = Vec::new(); @@ -96,6 +110,17 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op for_each_tail_expr(&body_expr, tail_cb); for ret_expr_arg in exprs_to_wrap { + if let Some(ty) = ctx.sema.type_of_expr(&ret_expr_arg) { + if ty.adjusted().could_unify_with(ctx.db(), &semantic_new_return_ty) { + // The type is already correct, don't wrap it. + // We deliberately don't use `could_unify_with_deeply()`, because as long as the outer + // enum matches it's okay for us, as we don't trigger the assist if the return type + // is already `Option`/`Result`, so mismatched exact type is more likely a mistake + // than something intended. + continue; + } + } + let happy_wrapped = make.expr_call( make.expr_path(make.ident_path(kind.happy_ident())), make.arg_list(iter::once(ret_expr_arg.clone())), @@ -103,12 +128,12 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op editor.replace(ret_expr_arg.syntax(), happy_wrapped.syntax()); } - editor.replace(type_ref.syntax(), new_return_ty.syntax()); + editor.replace(type_ref.syntax(), ast_new_return_ty.syntax()); if let WrapperKind::Result = kind { // Add a placeholder snippet at the first generic argument that doesn't equal the return type. // This is normally the error type, but that may not be the case when we inserted a type alias. - let args = new_return_ty + let args = ast_new_return_ty .path() .unwrap() .segment() @@ -188,27 +213,28 @@ impl WrapperKind { } // Try to find an wrapper type alias in the current scope (shadowing the default). -fn wrapper_alias( - ctx: &AssistContext<'_>, +fn wrapper_alias<'db>( + ctx: &AssistContext<'db>, make: &SyntaxFactory, - core_wrapper: &hir::Enum, - ret_type: &ast::Type, + core_wrapper: hir::Enum, + ast_ret_type: &ast::Type, + semantic_ret_type: &hir::Type<'db>, wrapper: hir::Symbol, -) -> Option { +) -> Option<(ast::PathType, hir::Type<'db>)> { let wrapper_path = hir::ModPath::from_segments( hir::PathKind::Plain, iter::once(hir::Name::new_symbol_root(wrapper)), ); - ctx.sema.resolve_mod_path(ret_type.syntax(), &wrapper_path).and_then(|def| { + ctx.sema.resolve_mod_path(ast_ret_type.syntax(), &wrapper_path).and_then(|def| { def.filter_map(|def| match def.into_module_def() { hir::ModuleDef::TypeAlias(alias) => { let enum_ty = alias.ty(ctx.db()).as_adt()?.as_enum()?; - (&enum_ty == core_wrapper).then_some(alias) + (enum_ty == core_wrapper).then_some((alias, enum_ty)) } _ => None, }) - .find_map(|alias| { + .find_map(|(alias, enum_ty)| { let mut inserted_ret_type = false; let generic_args = alias.source(ctx.db())?.value.generic_param_list()?.generic_params().map(|param| { @@ -216,7 +242,7 @@ fn wrapper_alias( // Replace the very first type parameter with the function's return type. ast::GenericParam::TypeParam(_) if !inserted_ret_type => { inserted_ret_type = true; - make.type_arg(ret_type.clone()).into() + make.type_arg(ast_ret_type.clone()).into() } ast::GenericParam::LifetimeParam(_) => { make.lifetime_arg(make.lifetime("'_")).into() @@ -231,7 +257,10 @@ fn wrapper_alias( make.path_segment_generics(make.name_ref(name.as_str()), generic_arg_list), ); - Some(make.ty_path(path)) + let new_ty = + hir::Adt::from(enum_ty).ty_with_args(ctx.db(), [semantic_ret_type.clone()]); + + Some((make.ty_path(path), new_ty)) }) }) } @@ -605,29 +634,39 @@ fn foo() -> Option { check_assist_by_label( wrap_return_type, r#" -//- minicore: option +//- minicore: option, future +struct F(i32); +impl core::future::Future for F { + type Output = i32; + fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll { 0 } +} async fn foo() -> i$032 { if true { if false { - 1.await + F(1).await } else { - 2.await + F(2).await } } else { - 24i32.await + F(24i32).await } } "#, r#" +struct F(i32); +impl core::future::Future for F { + type Output = i32; + fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll { 0 } +} async fn foo() -> Option { if true { if false { - Some(1.await) + Some(F(1).await) } else { - Some(2.await) + Some(F(2).await) } } else { - Some(24i32.await) + Some(F(24i32).await) } } "#, @@ -1666,29 +1705,39 @@ fn foo() -> Result { check_assist_by_label( wrap_return_type, r#" -//- minicore: result +//- minicore: result, future +struct F(i32); +impl core::future::Future for F { + type Output = i32; + fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll { 0 } +} async fn foo() -> i$032 { if true { if false { - 1.await + F(1).await } else { - 2.await + F(2).await } } else { - 24i32.await + F(24i32).await } } "#, r#" +struct F(i32); +impl core::future::Future for F { + type Output = i32; + fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll { 0 } +} async fn foo() -> Result { if true { if false { - Ok(1.await) + Ok(F(1).await) } else { - Ok(2.await) + Ok(F(2).await) } } else { - Ok(24i32.await) + Ok(F(24i32).await) } } "#, @@ -2455,6 +2504,56 @@ type Result = core::result::Result, Bar>; fn foo() -> Result { Ok(0) +} + "#, + WrapperKind::Result.label(), + ); + } + + #[test] + fn already_wrapped() { + check_assist_by_label( + wrap_return_type, + r#" +//- minicore: option +fn foo() -> i32$0 { + if false { + 0 + } else { + Some(1) + } +} + "#, + r#" +fn foo() -> Option { + if false { + Some(0) + } else { + Some(1) + } +} + "#, + WrapperKind::Option.label(), + ); + check_assist_by_label( + wrap_return_type, + r#" +//- minicore: result +fn foo() -> i32$0 { + if false { + 0 + } else { + Ok(1) + } +} + "#, + r#" +fn foo() -> Result { + if false { + Ok(0) + } else { + Ok(1) + } } "#, WrapperKind::Result.label(),