From 6377d50bd137424a38f2f71bb3eba29d74d02210 Mon Sep 17 00:00:00 2001 From: hkalbasi Date: Thu, 2 Mar 2023 11:18:50 +0330 Subject: [PATCH] Support "for loop" MIR lowering --- crates/hir-ty/src/consteval/tests.rs | 64 +++++ crates/hir-ty/src/infer.rs | 5 + crates/hir-ty/src/infer/expr.rs | 6 +- crates/hir-ty/src/mir.rs | 4 + crates/hir-ty/src/mir/eval.rs | 7 +- crates/hir-ty/src/mir/lower.rs | 255 ++++++++++++------ .../src/handlers/mutability_errors.rs | 16 ++ crates/test-utils/src/minicore.rs | 14 + 8 files changed, 292 insertions(+), 79 deletions(-) diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index 0f0e68a560..e255bd798e 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -415,6 +415,43 @@ fn loops() { ); } +#[test] +fn for_loops() { + check_number( + r#" + //- minicore: iterator + + struct Range { + start: u8, + end: u8, + } + + impl Iterator for Range { + type Item = u8; + fn next(&mut self) -> Option { + if self.start >= self.end { + None + } else { + let r = self.start; + self.start = self.start + 1; + Some(r) + } + } + } + + const GOAL: u8 = { + let mut sum = 0; + let ar = Range { start: 1, end: 11 }; + for i in ar { + sum = sum + i; + } + sum + }; + "#, + 55, + ); +} + #[test] fn recursion() { check_number( @@ -518,6 +555,33 @@ fn tuples() { ); } +#[test] +fn path_pattern_matching() { + check_number( + r#" + enum Season { + Spring, + Summer, + Fall, + Winter, + } + + use Season::*; + + const fn f(x: Season) -> i32 { + match x { + Spring => 1, + Summer => 2, + Fall => 3, + Winter => 4, + } + } + const GOAL: i32 = f(Spring) + 10 * f(Summer) + 100 * f(Fall) + 1000 * f(Winter); + "#, + 4321, + ); +} + #[test] fn pattern_matching_ergonomics() { check_number( diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index bac733d988..262c562e9f 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -354,6 +354,8 @@ pub struct InferenceResult { pub type_of_pat: ArenaMap, pub type_of_binding: ArenaMap, pub type_of_rpit: ArenaMap, + /// Type of the result of `.into_iter()` on the for. `ExprId` is the one of the whole for loop. + pub type_of_for_iterator: ArenaMap, type_mismatches: FxHashMap, /// Interned common types to return references to. standard_types: InternedStandardTypes, @@ -549,6 +551,9 @@ impl<'a> InferenceContext<'a> { for ty in result.type_of_rpit.values_mut() { *ty = table.resolve_completely(ty.clone()); } + for ty in result.type_of_for_iterator.values_mut() { + *ty = table.resolve_completely(ty.clone()); + } for mismatch in result.type_mismatches.values_mut() { mismatch.expected = table.resolve_completely(mismatch.expected.clone()); mismatch.actual = table.resolve_completely(mismatch.actual.clone()); diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index cca84488c9..535189ff02 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -242,8 +242,10 @@ impl<'a> InferenceContext<'a> { let iterable_ty = self.infer_expr(iterable, &Expectation::none()); let into_iter_ty = self.resolve_associated_type(iterable_ty, self.resolve_into_iter_item()); - let pat_ty = - self.resolve_associated_type(into_iter_ty, self.resolve_iterator_item()); + let pat_ty = self + .resolve_associated_type(into_iter_ty.clone(), self.resolve_iterator_item()); + + self.result.type_of_for_iterator.insert(tgt_expr, into_iter_ty); self.infer_top_pat(pat, &pat_ty); self.with_breakable_ctx(BreakableKind::Loop, None, label, |this| { diff --git a/crates/hir-ty/src/mir.rs b/crates/hir-ty/src/mir.rs index c18a34a192..7c1cbbdf53 100644 --- a/crates/hir-ty/src/mir.rs +++ b/crates/hir-ty/src/mir.rs @@ -83,6 +83,10 @@ impl Operand { fn from_bytes(data: Vec, ty: Ty) -> Self { Operand::from_concrete_const(data, MemoryMap::default(), ty) } + + fn const_zst(ty: Ty) -> Operand { + Self::from_bytes(vec![], ty) + } } #[derive(Debug, PartialEq, Eq, Clone)] diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs index 245cfdb4dd..b0b09fcd53 100644 --- a/crates/hir-ty/src/mir/eval.rs +++ b/crates/hir-ty/src/mir/eval.rs @@ -1122,7 +1122,12 @@ impl Evaluator<'_> { } fn detect_lang_function(&self, def: FunctionId) -> Option { - lang_attr(self.db.upcast(), def) + let candidate = lang_attr(self.db.upcast(), def)?; + // filter normal lang functions out + if [LangItem::IntoIterIntoIter, LangItem::IteratorNext].contains(&candidate) { + return None; + } + Some(candidate) } fn create_memory_map(&self, bytes: &[u8], ty: &Ty, locals: &Locals<'_>) -> Result { diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index e89e16079d..f9a66286b2 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -9,6 +9,7 @@ use hir_def::{ Array, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, Pat, PatId, RecordLitField, }, + lang_item::{LangItem, LangItemTarget}, layout::LayoutError, resolver::{resolver_for_expr, ResolveValueResult, ValueNs}, DefWithBodyId, EnumVariantId, HasModule, @@ -17,8 +18,8 @@ use la_arena::ArenaMap; use crate::{ consteval::ConstEvalError, db::HirDatabase, display::HirDisplay, infer::TypeMismatch, - inhabitedness::is_ty_uninhabited_from, layout::layout_of_ty, mapping::ToChalk, utils::generics, - Adjust, AutoBorrow, CallableDefId, TyBuilder, TyExt, + inhabitedness::is_ty_uninhabited_from, layout::layout_of_ty, mapping::ToChalk, static_lifetime, + utils::generics, Adjust, AutoBorrow, CallableDefId, TyBuilder, TyExt, }; use super::*; @@ -59,6 +60,7 @@ pub enum MirLowerError { Loop, /// Something that should never happen and is definitely a bug, but we don't want to panic if it happened ImplementationError(&'static str), + LangItemNotFound(LangItem), } macro_rules! not_supported { @@ -484,13 +486,64 @@ impl MirLowerCtx<'_> { Ok(()) }) } - Expr::For { .. } => not_supported!("for loop"), + &Expr::For { iterable, pat, body, label } => { + let into_iter_fn = self.resolve_lang_item(LangItem::IntoIterIntoIter)? + .as_function().ok_or(MirLowerError::LangItemNotFound(LangItem::IntoIterIntoIter))?; + let iter_next_fn = self.resolve_lang_item(LangItem::IteratorNext)? + .as_function().ok_or(MirLowerError::LangItemNotFound(LangItem::IteratorNext))?; + let option_some = self.resolve_lang_item(LangItem::OptionSome)? + .as_enum_variant().ok_or(MirLowerError::LangItemNotFound(LangItem::OptionSome))?; + let option = option_some.parent; + let into_iter_fn_op = Operand::const_zst( + TyKind::FnDef( + self.db.intern_callable_def(CallableDefId::FunctionId(into_iter_fn)).into(), + Substitution::from1(Interner, self.expr_ty(iterable)) + ).intern(Interner)); + let iter_next_fn_op = Operand::const_zst( + TyKind::FnDef( + self.db.intern_callable_def(CallableDefId::FunctionId(iter_next_fn)).into(), + Substitution::from1(Interner, self.expr_ty(iterable)) + ).intern(Interner)); + let iterator_ty = &self.infer.type_of_for_iterator[expr_id]; + let ref_mut_iterator_ty = TyKind::Ref(Mutability::Mut, static_lifetime(), iterator_ty.clone()).intern(Interner); + let item_ty = &self.infer.type_of_pat[pat]; + let option_item_ty = TyKind::Adt(chalk_ir::AdtId(option.into()), Substitution::from1(Interner, item_ty.clone())).intern(Interner); + let iterator_place: Place = self.temp(iterator_ty.clone())?.into(); + let option_item_place: Place = self.temp(option_item_ty.clone())?.into(); + let ref_mut_iterator_place: Place = self.temp(ref_mut_iterator_ty)?.into(); + let Some(current) = self.lower_call_and_args(into_iter_fn_op, Some(iterable).into_iter(), iterator_place.clone(), current, false)? + else { + return Ok(None); + }; + self.push_assignment(current, ref_mut_iterator_place.clone(), Rvalue::Ref(BorrowKind::Mut { allow_two_phase_borrow: false }, iterator_place), expr_id.into()); + self.lower_loop(current, label, |this, begin| { + this.push_storage_live(pat, begin)?; + let Some(current) = this.lower_call(iter_next_fn_op, vec![Operand::Copy(ref_mut_iterator_place)], option_item_place.clone(), begin, false)? + else { + return Ok(()); + }; + let end = this.current_loop_end()?; + let (current, _) = this.pattern_matching_variant( + option_item_ty.clone(), + BindingAnnotation::Unannotated, + option_item_place.into(), + option_some.into(), + current, + pat.into(), + Some(end), + &[pat], &None)?; + if let (_, Some(block)) = this.lower_expr_to_some_place(body, current)? { + this.set_goto(block, begin); + } + Ok(()) + }) + }, Expr::Call { callee, args, .. } => { let callee_ty = self.expr_ty_after_adjustments(*callee); match &callee_ty.data(Interner).kind { chalk_ir::TyKind::FnDef(..) => { let func = Operand::from_bytes(vec![], callee_ty.clone()); - self.lower_call(func, args.iter().copied(), place, current, self.is_uninhabited(expr_id)) + self.lower_call_and_args(func, args.iter().copied(), place, current, self.is_uninhabited(expr_id)) } TyKind::Scalar(_) | TyKind::Tuple(_, _) @@ -527,7 +580,7 @@ impl MirLowerCtx<'_> { ) .intern(Interner); let func = Operand::from_bytes(vec![], ty); - self.lower_call( + self.lower_call_and_args( func, iter::once(*receiver).chain(args.iter().copied()), place, @@ -962,7 +1015,7 @@ impl MirLowerCtx<'_> { Ok(prev_block) } - fn lower_call( + fn lower_call_and_args( &mut self, func: Operand, args: impl Iterator, @@ -983,6 +1036,17 @@ impl MirLowerCtx<'_> { else { return Ok(None); }; + self.lower_call(func, args, place, current, is_uninhabited) + } + + fn lower_call( + &mut self, + func: Operand, + args: Vec, + place: Place, + current: BasicBlockId, + is_uninhabited: bool, + ) -> Result> { let b = if is_uninhabited { None } else { Some(self.new_basic_block()) }; self.set_terminator( current, @@ -1112,7 +1176,22 @@ impl MirLowerCtx<'_> { Pat::Record { .. } => not_supported!("record pattern"), Pat::Range { .. } => not_supported!("range pattern"), Pat::Slice { .. } => not_supported!("slice pattern"), - Pat::Path(_) => not_supported!("path pattern"), + Pat::Path(_) => { + let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else { + not_supported!("unresolved variant"); + }; + self.pattern_matching_variant( + cond_ty, + binding_mode, + cond_place, + variant, + current, + pattern.into(), + current_else, + &[], + &None, + )? + } Pat::Lit(l) => { let then_target = self.new_basic_block(); let else_target = current_else.unwrap_or_else(|| self.new_basic_block()); @@ -1183,75 +1262,17 @@ impl MirLowerCtx<'_> { let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else { not_supported!("unresolved variant"); }; - pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place); - let subst = match cond_ty.kind(Interner) { - TyKind::Adt(_, s) => s, - _ => { - return Err(MirLowerError::TypeError( - "non adt type matched with tuple struct", - )) - } - }; - let fields_type = self.db.field_types(variant); - match variant { - VariantId::EnumVariantId(v) => { - let e = self.db.const_eval_discriminant(v)? as u128; - let next = self.new_basic_block(); - let tmp = self.discr_temp_place(); - self.push_assignment( - current, - tmp.clone(), - Rvalue::Discriminant(cond_place.clone()), - pattern.into(), - ); - let else_target = current_else.unwrap_or_else(|| self.new_basic_block()); - self.set_terminator( - current, - Terminator::SwitchInt { - discr: Operand::Copy(tmp), - targets: SwitchTargets::static_if(e, next, else_target), - }, - ); - let enum_data = self.db.enum_data(v.parent); - let fields = - enum_data.variants[v.local_id].variant_data.fields().iter().map( - |(x, _)| { - ( - PlaceElem::Field(FieldId { parent: v.into(), local_id: x }), - fields_type[x].clone().substitute(Interner, subst), - ) - }, - ); - self.pattern_match_tuple_like( - next, - Some(else_target), - args.iter().zip(fields).map(|(x, y)| (y.0, *x, y.1)), - *ellipsis, - &cond_place, - binding_mode, - )? - } - VariantId::StructId(s) => { - let struct_data = self.db.struct_data(s); - let fields = struct_data.variant_data.fields().iter().map(|(x, _)| { - ( - PlaceElem::Field(FieldId { parent: s.into(), local_id: x }), - fields_type[x].clone().substitute(Interner, subst), - ) - }); - self.pattern_match_tuple_like( - current, - current_else, - args.iter().zip(fields).map(|(x, y)| (y.0, *x, y.1)), - *ellipsis, - &cond_place, - binding_mode, - )? - } - VariantId::UnionId(_) => { - return Err(MirLowerError::TypeError("pattern matching on union")) - } - } + self.pattern_matching_variant( + cond_ty, + binding_mode, + cond_place, + variant, + current, + pattern.into(), + current_else, + args, + ellipsis, + )? } Pat::Ref { .. } => not_supported!("& pattern"), Pat::Box { .. } => not_supported!("box pattern"), @@ -1259,6 +1280,83 @@ impl MirLowerCtx<'_> { }) } + fn pattern_matching_variant( + &mut self, + mut cond_ty: Ty, + mut binding_mode: BindingAnnotation, + mut cond_place: Place, + variant: VariantId, + current: BasicBlockId, + span: MirSpan, + current_else: Option, + args: &[PatId], + ellipsis: &Option, + ) -> Result<(BasicBlockId, Option)> { + pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place); + let subst = match cond_ty.kind(Interner) { + TyKind::Adt(_, s) => s, + _ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")), + }; + let fields_type = self.db.field_types(variant); + Ok(match variant { + VariantId::EnumVariantId(v) => { + let e = self.db.const_eval_discriminant(v)? as u128; + let next = self.new_basic_block(); + let tmp = self.discr_temp_place(); + self.push_assignment( + current, + tmp.clone(), + Rvalue::Discriminant(cond_place.clone()), + span, + ); + let else_target = current_else.unwrap_or_else(|| self.new_basic_block()); + self.set_terminator( + current, + Terminator::SwitchInt { + discr: Operand::Copy(tmp), + targets: SwitchTargets::static_if(e, next, else_target), + }, + ); + let enum_data = self.db.enum_data(v.parent); + let fields = + enum_data.variants[v.local_id].variant_data.fields().iter().map(|(x, _)| { + ( + PlaceElem::Field(FieldId { parent: v.into(), local_id: x }), + fields_type[x].clone().substitute(Interner, subst), + ) + }); + self.pattern_match_tuple_like( + next, + Some(else_target), + args.iter().zip(fields).map(|(x, y)| (y.0, *x, y.1)), + *ellipsis, + &cond_place, + binding_mode, + )? + } + VariantId::StructId(s) => { + let struct_data = self.db.struct_data(s); + let fields = struct_data.variant_data.fields().iter().map(|(x, _)| { + ( + PlaceElem::Field(FieldId { parent: s.into(), local_id: x }), + fields_type[x].clone().substitute(Interner, subst), + ) + }); + self.pattern_match_tuple_like( + current, + current_else, + args.iter().zip(fields).map(|(x, y)| (y.0, *x, y.1)), + *ellipsis, + &cond_place, + binding_mode, + )? + } + VariantId::UnionId(_) => { + return Err(MirLowerError::TypeError("pattern matching on union")) + } + }) + } + fn pattern_match_tuple_like( &mut self, mut current: BasicBlockId, @@ -1384,6 +1482,11 @@ impl MirLowerCtx<'_> { }); Ok(()) } + + fn resolve_lang_item(&self, item: LangItem) -> Result { + let crate_id = self.owner.module(self.db.upcast()).krate(); + self.db.lang_item(crate_id, item).ok_or(MirLowerError::LangItemNotFound(item)) + } } fn pattern_matching_dereference( diff --git a/crates/ide-diagnostics/src/handlers/mutability_errors.rs b/crates/ide-diagnostics/src/handlers/mutability_errors.rs index 4f5d958354..a6aa069e27 100644 --- a/crates/ide-diagnostics/src/handlers/mutability_errors.rs +++ b/crates/ide-diagnostics/src/handlers/mutability_errors.rs @@ -507,6 +507,22 @@ fn f(x: i32) { x = 5; //^^^^^ 💡 error: cannot mutate immutable variable `x` } +"#, + ); + } + + #[test] + fn for_loop() { + check_diagnostics( + r#" +//- minicore: iterators +fn f(x: [(i32, u8); 10]) { + for (a, mut b) in x { + //^^^^^ 💡 weak: remove this `mut` + a = 2; + //^^^^^ 💡 error: cannot mutate immutable variable `a` + } +} "#, ); } diff --git a/crates/test-utils/src/minicore.rs b/crates/test-utils/src/minicore.rs index 7b48e42489..44d7f69061 100644 --- a/crates/test-utils/src/minicore.rs +++ b/crates/test-utils/src/minicore.rs @@ -728,6 +728,20 @@ pub mod iter { self } } + pub struct IntoIter([T; N]); + impl IntoIterator for [T; N] { + type Item = T; + type IntoIter = IntoIter; + fn into_iter(self) -> I { + IntoIter(self) + } + } + impl Iterator for IntoIter { + type Item = T; + fn next(&mut self) -> Option { + loop {} + } + } } pub use self::collect::IntoIterator; }