diff --git a/Cargo.lock b/Cargo.lock index 17dea1ba4c..9d0f63d2bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -545,6 +545,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flate2" version = "1.1.2" @@ -775,6 +781,7 @@ dependencies = [ "itertools", "la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "oorandom", + "petgraph", "project-model", "query-group-macro", "ra-ap-rustc_abi", @@ -1594,6 +1601,17 @@ dependencies = [ "libc", ] +[[package]] +name = "petgraph" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54acf3a685220b533e437e264e4d932cfbdc4cc7ec0cd232ed73c08d03b8a7ca" +dependencies = [ + "fixedbitset", + "hashbrown 0.15.4", + "indexmap", +] + [[package]] name = "pin-project-lite" version = "0.2.16" diff --git a/Cargo.toml b/Cargo.toml index 0401367f78..d3a4e37561 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -170,6 +170,7 @@ tracing-subscriber = { version = "0.3.20", default-features = false, features = triomphe = { version = "0.1.14", default-features = false, features = ["std"] } url = "2.5.4" xshell = "0.2.7" +petgraph = { version = "0.8.2", default-features = false } # We need to freeze the version of the crate, as the raw-api feature is considered unstable dashmap = { version = "=6.1.0", features = ["raw-api", "inline"] } diff --git a/crates/hir-ty/Cargo.toml b/crates/hir-ty/Cargo.toml index 138d02e5a6..4013d19ad0 100644 --- a/crates/hir-ty/Cargo.toml +++ b/crates/hir-ty/Cargo.toml @@ -34,6 +34,7 @@ rustc_apfloat = "0.2.3" query-group.workspace = true salsa.workspace = true salsa-macros.workspace = true +petgraph.workspace = true ra-ap-rustc_abi.workspace = true ra-ap-rustc_index.workspace = true diff --git a/crates/hir-ty/src/consteval.rs b/crates/hir-ty/src/consteval.rs index e2a8d1cedc..b2daed425e 100644 --- a/crates/hir-ty/src/consteval.rs +++ b/crates/hir-ty/src/consteval.rs @@ -327,7 +327,7 @@ pub(crate) fn eval_to_const( debruijn: DebruijnIndex, ) -> Const { let db = ctx.db; - let infer = ctx.clone().resolve_all(); + let infer = ctx.fixme_resolve_all_clone(); fn has_closure(body: &Body, expr: ExprId) -> bool { if matches!(body[expr], Expr::Closure { .. }) { return true; diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index 299b73a7d6..1586846bbe 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -36,12 +36,12 @@ fn check_fail( error: impl FnOnce(ConstEvalError) -> bool, ) { let (db, file_id) = TestDB::with_single_file(ra_fixture); - match eval_goal(&db, file_id) { + salsa::attach(&db, || match eval_goal(&db, file_id) { Ok(_) => panic!("Expected fail, but it succeeded"), Err(e) => { - assert!(error(simplify(e.clone())), "Actual error was: {}", pretty_print_err(e, db)) + assert!(error(simplify(e.clone())), "Actual error was: {}", pretty_print_err(e, &db)) } - } + }) } #[track_caller] @@ -79,36 +79,38 @@ fn check_answer( check: impl FnOnce(&[u8], &MemoryMap<'_>), ) { let (db, file_ids) = TestDB::with_many_files(ra_fixture); - let file_id = *file_ids.last().unwrap(); - let r = match eval_goal(&db, file_id) { - Ok(t) => t, - Err(e) => { - let err = pretty_print_err(e, db); - panic!("Error in evaluating goal: {err}"); - } - }; - match &r.data(Interner).value { - chalk_ir::ConstValue::Concrete(c) => match &c.interned { - ConstScalar::Bytes(b, mm) => { - check(b, mm); + salsa::attach(&db, || { + let file_id = *file_ids.last().unwrap(); + let r = match eval_goal(&db, file_id) { + Ok(t) => t, + Err(e) => { + let err = pretty_print_err(e, &db); + panic!("Error in evaluating goal: {err}"); } - x => panic!("Expected number but found {x:?}"), - }, - _ => panic!("result of const eval wasn't a concrete const"), - } + }; + match &r.data(Interner).value { + chalk_ir::ConstValue::Concrete(c) => match &c.interned { + ConstScalar::Bytes(b, mm) => { + check(b, mm); + } + x => panic!("Expected number but found {x:?}"), + }, + _ => panic!("result of const eval wasn't a concrete const"), + } + }); } -fn pretty_print_err(e: ConstEvalError, db: TestDB) -> String { +fn pretty_print_err(e: ConstEvalError, db: &TestDB) -> String { let mut err = String::new(); let span_formatter = |file, range| format!("{file:?} {range:?}"); let display_target = - DisplayTarget::from_crate(&db, *db.all_crates().last().expect("no crate graph present")); + DisplayTarget::from_crate(db, *db.all_crates().last().expect("no crate graph present")); match e { ConstEvalError::MirLowerError(e) => { - e.pretty_print(&mut err, &db, span_formatter, display_target) + e.pretty_print(&mut err, db, span_formatter, display_target) } ConstEvalError::MirEvalError(e) => { - e.pretty_print(&mut err, &db, span_formatter, display_target) + e.pretty_print(&mut err, db, span_formatter, display_target) } } .unwrap(); diff --git a/crates/hir-ty/src/consteval_nextsolver.rs b/crates/hir-ty/src/consteval_nextsolver.rs index 6e07d3afe5..155f1336e4 100644 --- a/crates/hir-ty/src/consteval_nextsolver.rs +++ b/crates/hir-ty/src/consteval_nextsolver.rs @@ -222,7 +222,7 @@ pub(crate) fn const_eval_discriminant_variant( // and make this function private. See the fixme comment on `InferenceContext::resolve_all`. pub(crate) fn eval_to_const<'db>(expr: ExprId, ctx: &mut InferenceContext<'db>) -> Const<'db> { let interner = DbInterner::new_with(ctx.db, None, None); - let infer = ctx.clone().resolve_all(); + let infer = ctx.fixme_resolve_all_clone(); fn has_closure(body: &Body, expr: ExprId) -> bool { if matches!(body[expr], Expr::Closure { .. }) { return true; diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index fd10f92398..287afb039b 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -19,6 +19,7 @@ pub(crate) mod closure; mod coerce; pub(crate) mod diagnostics; mod expr; +mod fallback; mod mutability; mod pat; mod path; @@ -53,16 +54,16 @@ use indexmap::IndexSet; use intern::sym; use la_arena::{ArenaMap, Entry}; use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_type_ir::inherent::Ty as _; use stdx::{always, never}; use triomphe::Arc; -use crate::db::InternedClosureId; use crate::{ AliasEq, AliasTy, Binders, ClosureId, Const, DomainGoal, GenericArg, ImplTraitId, ImplTraitIdx, IncorrectGenericsLenKind, Interner, Lifetime, OpaqueTyId, ParamLoweringMode, PathLoweringDiagnostic, ProjectionTy, Substitution, TargetFeatures, TraitEnvironment, Ty, TyBuilder, TyExt, - db::HirDatabase, + db::{HirDatabase, InternedClosureId}, fold_tys, generics::Generics, infer::{ @@ -75,6 +76,7 @@ use crate::{ mir::MirSpan, next_solver::{ self, DbInterner, + infer::{DefineOpaqueTypes, traits::ObligationCause}, mapping::{ChalkToNextSolver, NextSolverToChalk}, }, static_lifetime, to_assoc_type_id, @@ -138,6 +140,20 @@ pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc for InferenceResult { } } +#[derive(Debug, Clone)] +struct InternedStandardTypesNextSolver<'db> { + unit: crate::next_solver::Ty<'db>, + never: crate::next_solver::Ty<'db>, + i32: crate::next_solver::Ty<'db>, + f64: crate::next_solver::Ty<'db>, +} + +impl<'db> InternedStandardTypesNextSolver<'db> { + fn new(interner: DbInterner<'db>) -> Self { + Self { + unit: crate::next_solver::Ty::new_unit(interner), + never: crate::next_solver::Ty::new(interner, crate::next_solver::TyKind::Never), + i32: crate::next_solver::Ty::new_int(interner, rustc_type_ir::IntTy::I32), + f64: crate::next_solver::Ty::new_float(interner, rustc_type_ir::FloatTy::F64), + } + } +} + /// The inference context contains all information needed during type inference. #[derive(Clone, Debug)] pub(crate) struct InferenceContext<'db> { @@ -718,6 +752,7 @@ pub(crate) struct InferenceContext<'db> { resume_yield_tys: Option<(Ty, Ty)>, diverges: Diverges, breakables: Vec>, + types: InternedStandardTypesNextSolver<'db>, /// Whether we are inside the pattern of a destructuring assignment. inside_assignment: bool, @@ -798,11 +833,13 @@ impl<'db> InferenceContext<'db> { resolver: Resolver<'db>, ) -> Self { let trait_env = db.trait_environment_for_body(owner); + let table = unify::InferenceTable::new(db, trait_env); InferenceContext { + types: InternedStandardTypesNextSolver::new(table.interner), target_features: OnceCell::new(), generics: OnceCell::new(), result: InferenceResult::default(), - table: unify::InferenceTable::new(db, trait_env), + table, tuple_field_accesses_rev: Default::default(), return_ty: TyKind::Error.intern(Interner), // set in collect_* calls resume_yield_tys: None, @@ -865,24 +902,33 @@ impl<'db> InferenceContext<'db> { self.result.has_errors = true; } - // FIXME: This function should be private in module. It is currently only used in the consteval, since we need - // `InferenceResult` in the middle of inference. See the fixme comment in `consteval::eval_to_const`. If you - // used this function for another workaround, mention it here. If you really need this function and believe that - // there is no problem in it being `pub(crate)`, remove this comment. - pub(crate) fn resolve_all(mut self) -> InferenceResult { - self.table.select_obligations_where_possible(); - self.table.fallback_if_possible(); + /// Clones `self` and calls `resolve_all()` on it. + // FIXME: Remove this. + pub(crate) fn fixme_resolve_all_clone(&self) -> InferenceResult { + let mut ctx = self.clone(); + + ctx.type_inference_fallback(); // Comment from rustc: // Even though coercion casts provide type hints, we check casts after fallback for // backwards compatibility. This makes fallback a stronger type hint than a cast coercion. - let cast_checks = std::mem::take(&mut self.deferred_cast_checks); + let cast_checks = std::mem::take(&mut ctx.deferred_cast_checks); for mut cast in cast_checks.into_iter() { - if let Err(diag) = cast.check(&mut self) { - self.diagnostics.push(diag); + if let Err(diag) = cast.check(&mut ctx) { + ctx.diagnostics.push(diag); } } + ctx.table.select_obligations_where_possible(); + + ctx.resolve_all() + } + + // FIXME: This function should be private in module. It is currently only used in the consteval, since we need + // `InferenceResult` in the middle of inference. See the fixme comment in `consteval::eval_to_const`. If you + // used this function for another workaround, mention it here. If you really need this function and believe that + // there is no problem in it being `pub(crate)`, remove this comment. + pub(crate) fn resolve_all(self) -> InferenceResult { let InferenceContext { mut table, mut result, tuple_field_accesses_rev, diagnostics, .. } = self; @@ -914,11 +960,6 @@ impl<'db> InferenceContext<'db> { diagnostics: _, } = &mut result; - // FIXME resolve obligations as well (use Guidance if necessary) - table.select_obligations_where_possible(); - - // make sure diverging type variables are marked as such - table.propagate_diverging_flag(); for ty in type_of_expr.values_mut() { *ty = table.resolve_completely(ty.clone()); *has_errors = *has_errors || ty.contains_unknown(); @@ -1673,6 +1714,22 @@ impl<'db> InferenceContext<'db> { self.resolve_associated_type_with_params(inner_ty, assoc_ty, &[]) } + fn demand_eqtype( + &mut self, + expected: crate::next_solver::Ty<'db>, + actual: crate::next_solver::Ty<'db>, + ) { + let result = self + .table + .infer_ctxt + .at(&ObligationCause::new(), self.table.trait_env.env) + .eq(DefineOpaqueTypes::Yes, expected, actual) + .map(|infer_ok| self.table.register_infer_ok(infer_ok)); + if let Err(_err) = result { + // FIXME: Emit diagnostic. + } + } + fn resolve_associated_type_with_params( &mut self, inner_ty: Ty, diff --git a/crates/hir-ty/src/infer/coerce.rs b/crates/hir-ty/src/infer/coerce.rs index 219b519e46..62ce00a2e3 100644 --- a/crates/hir-ty/src/infer/coerce.rs +++ b/crates/hir-ty/src/infer/coerce.rs @@ -210,9 +210,8 @@ impl<'a, 'b, 'db> Coerce<'a, 'b, 'db> { // Coercing from `!` to any type is allowed: if a.is_never() { // If we're coercing into an inference var, mark it as possibly diverging. - // FIXME: rustc does this differently. - if let TyKind::Infer(rustc_type_ir::TyVar(b)) = b.kind() { - self.table.set_diverging(b.as_u32().into(), chalk_ir::TyVariableKind::General); + if b.is_infer() { + self.table.set_diverging(b); } if self.coerce_never { @@ -1613,16 +1612,21 @@ fn coerce<'db>( chalk_ir::GenericArgData::Const(c) => c.inference_var(Interner), } == Some(iv)) }; - let fallback = |iv, kind, default, binder| match kind { - chalk_ir::VariableKind::Ty(_ty_kind) => find_var(iv) - .map_or(default, |i| crate::BoundVar::new(binder, i).to_ty(Interner).cast(Interner)), - chalk_ir::VariableKind::Lifetime => find_var(iv).map_or(default, |i| { - crate::BoundVar::new(binder, i).to_lifetime(Interner).cast(Interner) - }), - chalk_ir::VariableKind::Const(ty) => find_var(iv).map_or(default, |i| { - crate::BoundVar::new(binder, i).to_const(Interner, ty).cast(Interner) - }), + let fallback = |iv, kind, binder| match kind { + chalk_ir::VariableKind::Ty(_ty_kind) => find_var(iv).map_or_else( + || chalk_ir::TyKind::Error.intern(Interner).cast(Interner), + |i| crate::BoundVar::new(binder, i).to_ty(Interner).cast(Interner), + ), + chalk_ir::VariableKind::Lifetime => find_var(iv).map_or_else( + || crate::LifetimeData::Error.intern(Interner).cast(Interner), + |i| crate::BoundVar::new(binder, i).to_lifetime(Interner).cast(Interner), + ), + chalk_ir::VariableKind::Const(ty) => find_var(iv).map_or_else( + || crate::unknown_const(ty.clone()).cast(Interner), + |i| crate::BoundVar::new(binder, i).to_const(Interner, ty.clone()).cast(Interner), + ), }; // FIXME also map the types in the adjustments + // FIXME: We don't fallback correctly since this is done on `InferenceContext` and we only have `InferenceTable`. Ok((adjustments, table.resolve_with_fallback(ty.to_chalk(table.interner), &fallback))) } diff --git a/crates/hir-ty/src/infer/fallback.rs b/crates/hir-ty/src/infer/fallback.rs new file mode 100644 index 0000000000..2022447ad4 --- /dev/null +++ b/crates/hir-ty/src/infer/fallback.rs @@ -0,0 +1,439 @@ +//! Fallback of infer vars to `!` and `i32`/`f64`. + +use intern::sym; +use petgraph::{ + Graph, + visit::{Dfs, Walker}, +}; +use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet}; +use rustc_type_ir::{ + TyVid, + inherent::{IntoKind, Ty as _}, +}; +use tracing::debug; + +use crate::{ + infer::InferenceContext, + next_solver::{CoercePredicate, PredicateKind, SubtypePredicate, Ty, TyKind}, +}; + +#[derive(Copy, Clone)] +pub(crate) enum DivergingFallbackBehavior { + /// Always fallback to `()` (aka "always spontaneous decay") + ToUnit, + /// Sometimes fallback to `!`, but mainly fallback to `()` so that most of the crates are not broken. + ContextDependent, + /// Always fallback to `!` (which should be equivalent to never falling back + not making + /// never-to-any coercions unless necessary) + ToNever, +} + +impl<'db> InferenceContext<'db> { + pub(super) fn type_inference_fallback(&mut self) { + debug!( + "type-inference-fallback start obligations: {:#?}", + self.table.fulfillment_cx.pending_obligations() + ); + + // All type checking constraints were added, try to fallback unsolved variables. + self.table.select_obligations_where_possible(); + + debug!( + "type-inference-fallback post selection obligations: {:#?}", + self.table.fulfillment_cx.pending_obligations() + ); + + let fallback_occurred = self.fallback_types(); + + if !fallback_occurred { + return; + } + + // We now see if we can make progress. This might cause us to + // unify inference variables for opaque types, since we may + // have unified some other type variables during the first + // phase of fallback. This means that we only replace + // inference variables with their underlying opaque types as a + // last resort. + // + // In code like this: + // + // ```rust + // type MyType = impl Copy; + // fn produce() -> MyType { true } + // fn bad_produce() -> MyType { panic!() } + // ``` + // + // we want to unify the opaque inference variable in `bad_produce` + // with the diverging fallback for `panic!` (e.g. `()` or `!`). + // This will produce a nice error message about conflicting concrete + // types for `MyType`. + // + // If we had tried to fallback the opaque inference variable to `MyType`, + // we will generate a confusing type-check error that does not explicitly + // refer to opaque types. + self.table.select_obligations_where_possible(); + } + + fn diverging_fallback_behavior(&self) -> DivergingFallbackBehavior { + if self.krate().data(self.db).edition.at_least_2024() { + return DivergingFallbackBehavior::ToNever; + } + + if self.resolver.def_map().is_unstable_feature_enabled(&sym::never_type_fallback) { + return DivergingFallbackBehavior::ContextDependent; + } + + DivergingFallbackBehavior::ToUnit + } + + fn fallback_types(&mut self) -> bool { + // Check if we have any unresolved variables. If not, no need for fallback. + let unresolved_variables = self.table.infer_ctxt.unresolved_variables(); + + if unresolved_variables.is_empty() { + return false; + } + + let diverging_fallback_behavior = self.diverging_fallback_behavior(); + + let diverging_fallback = + self.calculate_diverging_fallback(&unresolved_variables, diverging_fallback_behavior); + + // We do fallback in two passes, to try to generate + // better error messages. + // The first time, we do *not* replace opaque types. + let mut fallback_occurred = false; + for ty in unresolved_variables { + debug!("unsolved_variable = {:?}", ty); + fallback_occurred |= self.fallback_if_possible(ty, &diverging_fallback); + } + + fallback_occurred + } + + // Tries to apply a fallback to `ty` if it is an unsolved variable. + // + // - Unconstrained ints are replaced with `i32`. + // + // - Unconstrained floats are replaced with `f64`. + // + // - Non-numerics may get replaced with `()` or `!`, depending on + // how they were categorized by `calculate_diverging_fallback` + // (and the setting of `#![feature(never_type_fallback)]`). + // + // Fallback becomes very dubious if we have encountered + // type-checking errors. In that case, fallback to Error. + // + // Sets `FnCtxt::fallback_has_occurred` if fallback is performed + // during this call. + fn fallback_if_possible( + &mut self, + ty: Ty<'db>, + diverging_fallback: &FxHashMap, Ty<'db>>, + ) -> bool { + // Careful: we do NOT shallow-resolve `ty`. We know that `ty` + // is an unsolved variable, and we determine its fallback + // based solely on how it was created, not what other type + // variables it may have been unified with since then. + // + // The reason this matters is that other attempts at fallback + // may (in principle) conflict with this fallback, and we wish + // to generate a type error in that case. (However, this + // actually isn't true right now, because we're only using the + // builtin fallback rules. This would be true if we were using + // user-supplied fallbacks. But it's still useful to write the + // code to detect bugs.) + // + // (Note though that if we have a general type variable `?T` + // that is then unified with an integer type variable `?I` + // that ultimately never gets resolved to a special integral + // type, `?T` is not considered unsolved, but `?I` is. The + // same is true for float variables.) + let fallback = match ty.kind() { + TyKind::Infer(rustc_type_ir::IntVar(_)) => self.types.i32, + TyKind::Infer(rustc_type_ir::FloatVar(_)) => self.types.f64, + _ => match diverging_fallback.get(&ty) { + Some(&fallback_ty) => fallback_ty, + None => return false, + }, + }; + debug!("fallback_if_possible(ty={:?}): defaulting to `{:?}`", ty, fallback); + + self.demand_eqtype(ty, fallback); + true + } + + /// The "diverging fallback" system is rather complicated. This is + /// a result of our need to balance 'do the right thing' with + /// backwards compatibility. + /// + /// "Diverging" type variables are variables created when we + /// coerce a `!` type into an unbound type variable `?X`. If they + /// never wind up being constrained, the "right and natural" thing + /// is that `?X` should "fallback" to `!`. This means that e.g. an + /// expression like `Some(return)` will ultimately wind up with a + /// type like `Option` (presuming it is not assigned or + /// constrained to have some other type). + /// + /// However, the fallback used to be `()` (before the `!` type was + /// added). Moreover, there are cases where the `!` type 'leaks + /// out' from dead code into type variables that affect live + /// code. The most common case is something like this: + /// + /// ```rust + /// # fn foo() -> i32 { 4 } + /// match foo() { + /// 22 => Default::default(), // call this type `?D` + /// _ => return, // return has type `!` + /// } // call the type of this match `?M` + /// ``` + /// + /// Here, coercing the type `!` into `?M` will create a diverging + /// type variable `?X` where `?X <: ?M`. We also have that `?D <: + /// ?M`. If `?M` winds up unconstrained, then `?X` will + /// fallback. If it falls back to `!`, then all the type variables + /// will wind up equal to `!` -- this includes the type `?D` + /// (since `!` doesn't implement `Default`, we wind up a "trait + /// not implemented" error in code like this). But since the + /// original fallback was `()`, this code used to compile with `?D + /// = ()`. This is somewhat surprising, since `Default::default()` + /// on its own would give an error because the types are + /// insufficiently constrained. + /// + /// Our solution to this dilemma is to modify diverging variables + /// so that they can *either* fallback to `!` (the default) or to + /// `()` (the backwards compatibility case). We decide which + /// fallback to use based on whether there is a coercion pattern + /// like this: + /// + /// ```ignore (not-rust) + /// ?Diverging -> ?V + /// ?NonDiverging -> ?V + /// ?V != ?NonDiverging + /// ``` + /// + /// Here `?Diverging` represents some diverging type variable and + /// `?NonDiverging` represents some non-diverging type + /// variable. `?V` can be any type variable (diverging or not), so + /// long as it is not equal to `?NonDiverging`. + /// + /// Intuitively, what we are looking for is a case where a + /// "non-diverging" type variable (like `?M` in our example above) + /// is coerced *into* some variable `?V` that would otherwise + /// fallback to `!`. In that case, we make `?V` fallback to `!`, + /// along with anything that would flow into `?V`. + /// + /// The algorithm we use: + /// * Identify all variables that are coerced *into* by a + /// diverging variable. Do this by iterating over each + /// diverging, unsolved variable and finding all variables + /// reachable from there. Call that set `D`. + /// * Walk over all unsolved, non-diverging variables, and find + /// any variable that has an edge into `D`. + fn calculate_diverging_fallback( + &self, + unresolved_variables: &[Ty<'db>], + behavior: DivergingFallbackBehavior, + ) -> FxHashMap, Ty<'db>> { + debug!("calculate_diverging_fallback({:?})", unresolved_variables); + + // Construct a coercion graph where an edge `A -> B` indicates + // a type variable is that is coerced + let coercion_graph = self.create_coercion_graph(); + + // Extract the unsolved type inference variable vids; note that some + // unsolved variables are integer/float variables and are excluded. + let unsolved_vids = unresolved_variables.iter().filter_map(|ty| ty.ty_vid()); + + // Compute the diverging root vids D -- that is, the root vid of + // those type variables that (a) are the target of a coercion from + // a `!` type and (b) have not yet been solved. + // + // These variables are the ones that are targets for fallback to + // either `!` or `()`. + let diverging_roots: FxHashSet = self + .table + .diverging_type_vars + .iter() + .map(|&ty| self.shallow_resolve(ty)) + .filter_map(|ty| ty.ty_vid()) + .map(|vid| self.table.infer_ctxt.root_var(vid)) + .collect(); + debug!( + "calculate_diverging_fallback: diverging_type_vars={:?}", + self.table.diverging_type_vars + ); + debug!("calculate_diverging_fallback: diverging_roots={:?}", diverging_roots); + + // Find all type variables that are reachable from a diverging + // type variable. These will typically default to `!`, unless + // we find later that they are *also* reachable from some + // other type variable outside this set. + let mut roots_reachable_from_diverging = Dfs::empty(&coercion_graph); + let mut diverging_vids = vec![]; + let mut non_diverging_vids = vec![]; + for unsolved_vid in unsolved_vids { + let root_vid = self.table.infer_ctxt.root_var(unsolved_vid); + debug!( + "calculate_diverging_fallback: unsolved_vid={:?} root_vid={:?} diverges={:?}", + unsolved_vid, + root_vid, + diverging_roots.contains(&root_vid), + ); + if diverging_roots.contains(&root_vid) { + diverging_vids.push(unsolved_vid); + roots_reachable_from_diverging.move_to(root_vid.as_u32().into()); + + // drain the iterator to visit all nodes reachable from this node + while roots_reachable_from_diverging.next(&coercion_graph).is_some() {} + } else { + non_diverging_vids.push(unsolved_vid); + } + } + + debug!( + "calculate_diverging_fallback: roots_reachable_from_diverging={:?}", + roots_reachable_from_diverging, + ); + + // Find all type variables N0 that are not reachable from a + // diverging variable, and then compute the set reachable from + // N0, which we call N. These are the *non-diverging* type + // variables. (Note that this set consists of "root variables".) + let mut roots_reachable_from_non_diverging = Dfs::empty(&coercion_graph); + for &non_diverging_vid in &non_diverging_vids { + let root_vid = self.table.infer_ctxt.root_var(non_diverging_vid); + if roots_reachable_from_diverging.discovered.contains(root_vid.as_usize()) { + continue; + } + roots_reachable_from_non_diverging.move_to(root_vid.as_u32().into()); + while roots_reachable_from_non_diverging.next(&coercion_graph).is_some() {} + } + debug!( + "calculate_diverging_fallback: roots_reachable_from_non_diverging={:?}", + roots_reachable_from_non_diverging, + ); + + debug!("obligations: {:#?}", self.table.fulfillment_cx.pending_obligations()); + + // For each diverging variable, figure out whether it can + // reach a member of N. If so, it falls back to `()`. Else + // `!`. + let mut diverging_fallback = + FxHashMap::with_capacity_and_hasher(diverging_vids.len(), FxBuildHasher); + + for &diverging_vid in &diverging_vids { + let diverging_ty = Ty::new_var(self.table.interner, diverging_vid); + let root_vid = self.table.infer_ctxt.root_var(diverging_vid); + let can_reach_non_diverging = Dfs::new(&coercion_graph, root_vid.as_u32().into()) + .iter(&coercion_graph) + .any(|n| roots_reachable_from_non_diverging.discovered.contains(n.index())); + + let mut fallback_to = |ty| { + diverging_fallback.insert(diverging_ty, ty); + }; + + match behavior { + DivergingFallbackBehavior::ToUnit => { + debug!("fallback to () - legacy: {:?}", diverging_vid); + fallback_to(self.types.unit); + } + DivergingFallbackBehavior::ContextDependent => { + // FIXME: rustc does the following, but given this is only relevant when the unstable + // `never_type_fallback` feature is active, I chose to not port this. + // if found_infer_var_info.self_in_trait && found_infer_var_info.output { + // // This case falls back to () to ensure that the code pattern in + // // tests/ui/never_type/fallback-closure-ret.rs continues to + // // compile when never_type_fallback is enabled. + // // + // // This rule is not readily explainable from first principles, + // // but is rather intended as a patchwork fix to ensure code + // // which compiles before the stabilization of never type + // // fallback continues to work. + // // + // // Typically this pattern is encountered in a function taking a + // // closure as a parameter, where the return type of that closure + // // (checked by `relationship.output`) is expected to implement + // // some trait (checked by `relationship.self_in_trait`). This + // // can come up in non-closure cases too, so we do not limit this + // // rule to specifically `FnOnce`. + // // + // // When the closure's body is something like `panic!()`, the + // // return type would normally be inferred to `!`. However, it + // // needs to fall back to `()` in order to still compile, as the + // // trait is specifically implemented for `()` but not `!`. + // // + // // For details on the requirements for these relationships to be + // // set, see the relationship finding module in + // // compiler/rustc_trait_selection/src/traits/relationships.rs. + // debug!("fallback to () - found trait and projection: {:?}", diverging_vid); + // fallback_to(self.types.unit); + // } + if can_reach_non_diverging { + debug!("fallback to () - reached non-diverging: {:?}", diverging_vid); + fallback_to(self.types.unit); + } else { + debug!("fallback to ! - all diverging: {:?}", diverging_vid); + fallback_to(self.types.never); + } + } + DivergingFallbackBehavior::ToNever => { + debug!( + "fallback to ! - `rustc_never_type_mode = \"fallback_to_never\")`: {:?}", + diverging_vid + ); + fallback_to(self.types.never); + } + } + } + + diverging_fallback + } + + /// Returns a graph whose nodes are (unresolved) inference variables and where + /// an edge `?A -> ?B` indicates that the variable `?A` is coerced to `?B`. + fn create_coercion_graph(&self) -> Graph<(), ()> { + let pending_obligations = self.table.fulfillment_cx.pending_obligations(); + let pending_obligations_len = pending_obligations.len(); + debug!("create_coercion_graph: pending_obligations={:?}", pending_obligations); + let coercion_edges = pending_obligations + .into_iter() + .filter_map(|obligation| { + // The predicates we are looking for look like `Coerce(?A -> ?B)`. + // They will have no bound variables. + obligation.predicate.kind().no_bound_vars() + }) + .filter_map(|atom| { + // We consider both subtyping and coercion to imply 'flow' from + // some position in the code `a` to a different position `b`. + // This is then used to determine which variables interact with + // live code, and as such must fall back to `()` to preserve + // soundness. + // + // In practice currently the two ways that this happens is + // coercion and subtyping. + let (a, b) = match atom { + PredicateKind::Coerce(CoercePredicate { a, b }) => (a, b), + PredicateKind::Subtype(SubtypePredicate { a_is_expected: _, a, b }) => (a, b), + _ => return None, + }; + + let a_vid = self.root_vid(a)?; + let b_vid = self.root_vid(b)?; + Some((a_vid.as_u32(), b_vid.as_u32())) + }); + let num_ty_vars = self.table.infer_ctxt.num_ty_vars(); + let mut graph = Graph::with_capacity(num_ty_vars, pending_obligations_len); + for _ in 0..num_ty_vars { + graph.add_node(()); + } + graph.extend_with_edges(coercion_edges); + graph + } + + /// If `ty` is an unresolved type variable, returns its root vid. + fn root_vid(&self, ty: Ty<'db>) -> Option { + Some(self.table.infer_ctxt.root_var(self.shallow_resolve(ty).ty_vid()?)) + } +} diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs index dd7e77ba8c..108cf5b1a2 100644 --- a/crates/hir-ty/src/infer/unify.rs +++ b/crates/hir-ty/src/infer/unify.rs @@ -3,8 +3,7 @@ use std::fmt; use chalk_ir::{ - CanonicalVarKind, FloatTy, IntTy, TyVariableKind, cast::Cast, fold::TypeFoldable, - interner::HasInterner, + CanonicalVarKind, TyVariableKind, cast::Cast, fold::TypeFoldable, interner::HasInterner, }; use either::Either; use hir_def::{AdtId, lang_item::LangItem}; @@ -12,7 +11,7 @@ use hir_expand::name::Name; use intern::sym; use rustc_hash::{FxHashMap, FxHashSet}; use rustc_type_ir::{ - FloatVid, IntVid, TyVid, TypeVisitableExt, UpcastFrom, + TyVid, TypeVisitableExt, UpcastFrom, inherent::{IntoKind, Span, Term as _, Ty as _}, relate::{Relate, solver_relating::RelateExt}, solve::{Certainty, GoalSource}, @@ -23,8 +22,8 @@ use triomphe::Arc; use super::{InferResult, InferenceContext, TypeError}; use crate::{ AliasTy, BoundVar, Canonical, Const, ConstValue, DebruijnIndex, GenericArg, GenericArgData, - InferenceVar, Interner, Lifetime, OpaqueTyId, ProjectionTy, Scalar, Substitution, - TraitEnvironment, Ty, TyExt, TyKind, VariableKind, + InferenceVar, Interner, Lifetime, OpaqueTyId, ProjectionTy, Substitution, TraitEnvironment, Ty, + TyExt, TyKind, VariableKind, consteval::unknown_const, db::HirDatabase, fold_generic_args, fold_tys_and_consts, @@ -143,7 +142,6 @@ pub fn could_unify_deeply( let ty1_with_vars = table.normalize_associated_types_in(ty1_with_vars); let ty2_with_vars = table.normalize_associated_types_in(ty2_with_vars); table.select_obligations_where_possible(); - table.propagate_diverging_flag(); let ty1_with_vars = table.resolve_completely(ty1_with_vars); let ty2_with_vars = table.resolve_completely(ty2_with_vars); table.unify_deeply(&ty1_with_vars, &ty2_with_vars) @@ -170,13 +168,19 @@ pub(crate) fn unify( GenericArgData::Const(c) => c.inference_var(Interner), } == Some(iv)) }; - let fallback = |iv, kind, default, binder| match kind { - chalk_ir::VariableKind::Ty(_ty_kind) => find_var(iv) - .map_or(default, |i| BoundVar::new(binder, i).to_ty(Interner).cast(Interner)), - chalk_ir::VariableKind::Lifetime => find_var(iv) - .map_or(default, |i| BoundVar::new(binder, i).to_lifetime(Interner).cast(Interner)), - chalk_ir::VariableKind::Const(ty) => find_var(iv) - .map_or(default, |i| BoundVar::new(binder, i).to_const(Interner, ty).cast(Interner)), + let fallback = |iv, kind, binder| match kind { + chalk_ir::VariableKind::Ty(_ty_kind) => find_var(iv).map_or_else( + || TyKind::Error.intern(Interner).cast(Interner), + |i| BoundVar::new(binder, i).to_ty(Interner).cast(Interner), + ), + chalk_ir::VariableKind::Lifetime => find_var(iv).map_or_else( + || crate::error_lifetime().cast(Interner), + |i| BoundVar::new(binder, i).to_lifetime(Interner).cast(Interner), + ), + chalk_ir::VariableKind::Const(ty) => find_var(iv).map_or_else( + || crate::unknown_const(ty.clone()).cast(Interner), + |i| BoundVar::new(binder, i).to_const(Interner, ty.clone()).cast(Interner), + ), }; Some(Substitution::from_iter( Interner, @@ -215,14 +219,13 @@ pub(crate) struct InferenceTable<'db> { pub(crate) trait_env: Arc>, pub(crate) tait_coercion_table: Option>, pub(crate) infer_ctxt: InferCtxt<'db>, - diverging_tys: FxHashSet, pub(super) fulfillment_cx: FulfillmentCtxt<'db>, + pub(super) diverging_type_vars: FxHashSet>, } pub(crate) struct InferenceTableSnapshot<'db> { ctxt_snapshot: CombinedSnapshot, obligations: FulfillmentCtxt<'db>, - diverging_tys: FxHashSet, } impl<'db> InferenceTable<'db> { @@ -238,7 +241,7 @@ impl<'db> InferenceTable<'db> { tait_coercion_table: None, fulfillment_cx: FulfillmentCtxt::new(&infer_ctxt), infer_ctxt, - diverging_tys: FxHashSet::default(), + diverging_type_vars: FxHashSet::default(), } } @@ -321,74 +324,8 @@ impl<'db> InferenceTable<'db> { } } - /// Chalk doesn't know about the `diverging` flag, so when it unifies two - /// type variables of which one is diverging, the chosen root might not be - /// diverging and we have no way of marking it as such at that time. This - /// function goes through all type variables and make sure their root is - /// marked as diverging if necessary, so that resolving them gives the right - /// result. - pub(super) fn propagate_diverging_flag(&mut self) { - let mut new_tys = FxHashSet::default(); - for ty in self.diverging_tys.iter() { - match ty.kind(Interner) { - TyKind::InferenceVar(var, kind) => match kind { - TyVariableKind::General => { - let root = InferenceVar::from( - self.infer_ctxt.root_var(TyVid::from_u32(var.index())).as_u32(), - ); - if root.index() != var.index() { - new_tys.insert(TyKind::InferenceVar(root, *kind).intern(Interner)); - } - } - TyVariableKind::Integer => { - let root = InferenceVar::from( - self.infer_ctxt - .inner - .borrow_mut() - .int_unification_table() - .find(IntVid::from_usize(var.index() as usize)) - .as_u32(), - ); - if root.index() != var.index() { - new_tys.insert(TyKind::InferenceVar(root, *kind).intern(Interner)); - } - } - TyVariableKind::Float => { - let root = InferenceVar::from( - self.infer_ctxt - .inner - .borrow_mut() - .float_unification_table() - .find(FloatVid::from_usize(var.index() as usize)) - .as_u32(), - ); - if root.index() != var.index() { - new_tys.insert(TyKind::InferenceVar(root, *kind).intern(Interner)); - } - } - }, - _ => {} - } - } - self.diverging_tys.extend(new_tys); - } - - pub(super) fn set_diverging(&mut self, iv: InferenceVar, kind: TyVariableKind) { - self.diverging_tys.insert(TyKind::InferenceVar(iv, kind).intern(Interner)); - } - - fn fallback_value(&self, iv: InferenceVar, kind: TyVariableKind) -> Ty { - let is_diverging = - self.diverging_tys.contains(&TyKind::InferenceVar(iv, kind).intern(Interner)); - if is_diverging { - return TyKind::Never.intern(Interner); - } - match kind { - TyVariableKind::General => TyKind::Error, - TyVariableKind::Integer => TyKind::Scalar(Scalar::Int(IntTy::I32)), - TyVariableKind::Float => TyKind::Scalar(Scalar::Float(FloatTy::F64)), - } - .intern(Interner) + pub(super) fn set_diverging(&mut self, ty: crate::next_solver::Ty<'db>) { + self.diverging_type_vars.insert(ty); } pub(crate) fn canonicalize(&mut self, t: T) -> rustc_type_ir::Canonical, T> @@ -529,7 +466,7 @@ impl<'db> InferenceTable<'db> { let ty = var.to_ty(Interner, kind); if diverging { - self.diverging_tys.insert(ty.clone()); + self.diverging_type_vars.insert(ty.to_nextsolver(self.interner)); } ty } @@ -573,7 +510,7 @@ impl<'db> InferenceTable<'db> { pub(crate) fn resolve_with_fallback( &mut self, t: T, - fallback: &dyn Fn(InferenceVar, VariableKind, GenericArg, DebruijnIndex) -> GenericArg, + fallback: &dyn Fn(InferenceVar, VariableKind, DebruijnIndex) -> GenericArg, ) -> T where T: HasInterner + TypeFoldable, @@ -615,7 +552,7 @@ impl<'db> InferenceTable<'db> { fn resolve_with_fallback_inner( &mut self, t: T, - fallback: &dyn Fn(InferenceVar, VariableKind, GenericArg, DebruijnIndex) -> GenericArg, + fallback: &dyn Fn(InferenceVar, VariableKind, DebruijnIndex) -> GenericArg, ) -> T where T: HasInterner + TypeFoldable, @@ -632,53 +569,15 @@ impl<'db> InferenceTable<'db> { T: HasInterner + TypeFoldable + ChalkToNextSolver<'db, U>, U: NextSolverToChalk<'db, T> + rustc_type_ir::TypeFoldable>, { - let t = self.resolve_with_fallback(t, &|_, _, d, _| d); - let t = self.normalize_associated_types_in(t); - // let t = self.resolve_opaque_tys_in(t); - // Resolve again, because maybe normalization inserted infer vars. - self.resolve_with_fallback(t, &|_, _, d, _| d) - } + let value = t.to_nextsolver(self.interner); + let value = self.infer_ctxt.resolve_vars_if_possible(value); - /// Apply a fallback to unresolved scalar types. Integer type variables and float type - /// variables are replaced with i32 and f64, respectively. - /// - /// This method is only intended to be called just before returning inference results (i.e. in - /// `InferenceContext::resolve_all()`). - /// - /// FIXME: This method currently doesn't apply fallback to unconstrained general type variables - /// whereas rustc replaces them with `()` or `!`. - pub(super) fn fallback_if_possible(&mut self) { - let int_fallback = TyKind::Scalar(Scalar::Int(IntTy::I32)).intern(Interner); - let float_fallback = TyKind::Scalar(Scalar::Float(FloatTy::F64)).intern(Interner); + let mut goals = vec![]; + let value = value.fold_with(&mut resolve_completely::Resolver::new(self, true, &mut goals)); - let int_vars = self.infer_ctxt.inner.borrow_mut().int_unification_table().len(); - for v in 0..int_vars { - let var = InferenceVar::from(v as u32).to_ty(Interner, TyVariableKind::Integer); - let maybe_resolved = self.resolve_ty_shallow(&var); - if let TyKind::InferenceVar(_, kind) = maybe_resolved.kind(Interner) { - // I don't think we can ever unify these vars with float vars, but keep this here for now - let fallback = match kind { - TyVariableKind::Integer => &int_fallback, - TyVariableKind::Float => &float_fallback, - TyVariableKind::General => unreachable!(), - }; - self.unify(&var, fallback); - } - } - let float_vars = self.infer_ctxt.inner.borrow_mut().float_unification_table().len(); - for v in 0..float_vars { - let var = InferenceVar::from(v as u32).to_ty(Interner, TyVariableKind::Float); - let maybe_resolved = self.resolve_ty_shallow(&var); - if let TyKind::InferenceVar(_, kind) = maybe_resolved.kind(Interner) { - // I don't think we can ever unify these vars with float vars, but keep this here for now - let fallback = match kind { - TyVariableKind::Integer => &int_fallback, - TyVariableKind::Float => &float_fallback, - TyVariableKind::General => unreachable!(), - }; - self.unify(&var, fallback); - } - } + // FIXME(next-solver): Handle `goals`. + + value.to_chalk(self.interner) } /// Unify two relatable values (e.g. `Ty`) and register new trait goals that arise from that. @@ -829,15 +728,13 @@ impl<'db> InferenceTable<'db> { pub(crate) fn snapshot(&mut self) -> InferenceTableSnapshot<'db> { let ctxt_snapshot = self.infer_ctxt.start_snapshot(); - let diverging_tys = self.diverging_tys.clone(); let obligations = self.fulfillment_cx.clone(); - InferenceTableSnapshot { ctxt_snapshot, diverging_tys, obligations } + InferenceTableSnapshot { ctxt_snapshot, obligations } } #[tracing::instrument(skip_all)] pub(crate) fn rollback_to(&mut self, snapshot: InferenceTableSnapshot<'db>) { self.infer_ctxt.rollback_to(snapshot.ctxt_snapshot); - self.diverging_tys = snapshot.diverging_tys; self.fulfillment_cx = snapshot.obligations; } @@ -1166,14 +1063,10 @@ impl fmt::Debug for InferenceTable<'_> { mod resolve { use super::InferenceTable; use crate::{ - ConcreteConst, Const, ConstData, ConstScalar, ConstValue, DebruijnIndex, GenericArg, - InferenceVar, Interner, Lifetime, Ty, TyVariableKind, VariableKind, - next_solver::mapping::NextSolverToChalk, - }; - use chalk_ir::{ - cast::Cast, - fold::{TypeFoldable, TypeFolder}, + Const, DebruijnIndex, GenericArg, InferenceVar, Interner, Lifetime, Ty, TyVariableKind, + VariableKind, next_solver::mapping::NextSolverToChalk, }; + use chalk_ir::fold::{TypeFoldable, TypeFolder}; use rustc_type_ir::{FloatVid, IntVid, TyVid}; #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -1187,7 +1080,7 @@ mod resolve { pub(super) struct Resolver< 'a, 'b, - F: Fn(InferenceVar, VariableKind, GenericArg, DebruijnIndex) -> GenericArg, + F: Fn(InferenceVar, VariableKind, DebruijnIndex) -> GenericArg, > { pub(super) table: &'a mut InferenceTable<'b>, pub(super) var_stack: &'a mut Vec<(InferenceVar, VarKind)>, @@ -1195,7 +1088,7 @@ mod resolve { } impl TypeFolder for Resolver<'_, '_, F> where - F: Fn(InferenceVar, VariableKind, GenericArg, DebruijnIndex) -> GenericArg, + F: Fn(InferenceVar, VariableKind, DebruijnIndex) -> GenericArg, { fn as_dyn(&mut self) -> &mut dyn TypeFolder { self @@ -1217,8 +1110,7 @@ mod resolve { let var = InferenceVar::from(vid.as_u32()); if self.var_stack.contains(&(var, VarKind::Ty(kind))) { // recursive type - let default = self.table.fallback_value(var, kind).cast(Interner); - return (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + return (self.fallback)(var, VariableKind::Ty(kind), outer_binder) .assert_ty_ref(Interner) .clone(); } @@ -1230,8 +1122,7 @@ mod resolve { self.var_stack.pop(); result } else { - let default = self.table.fallback_value(var, kind).cast(Interner); - (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + (self.fallback)(var, VariableKind::Ty(kind), outer_binder) .assert_ty_ref(Interner) .clone() } @@ -1247,8 +1138,7 @@ mod resolve { let var = InferenceVar::from(vid.as_u32()); if self.var_stack.contains(&(var, VarKind::Ty(kind))) { // recursive type - let default = self.table.fallback_value(var, kind).cast(Interner); - return (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + return (self.fallback)(var, VariableKind::Ty(kind), outer_binder) .assert_ty_ref(Interner) .clone(); } @@ -1260,8 +1150,7 @@ mod resolve { self.var_stack.pop(); result } else { - let default = self.table.fallback_value(var, kind).cast(Interner); - (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + (self.fallback)(var, VariableKind::Ty(kind), outer_binder) .assert_ty_ref(Interner) .clone() } @@ -1277,8 +1166,7 @@ mod resolve { let var = InferenceVar::from(vid.as_u32()); if self.var_stack.contains(&(var, VarKind::Ty(kind))) { // recursive type - let default = self.table.fallback_value(var, kind).cast(Interner); - return (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + return (self.fallback)(var, VariableKind::Ty(kind), outer_binder) .assert_ty_ref(Interner) .clone(); } @@ -1290,8 +1178,7 @@ mod resolve { self.var_stack.pop(); result } else { - let default = self.table.fallback_value(var, kind).cast(Interner); - (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + (self.fallback)(var, VariableKind::Ty(kind), outer_binder) .assert_ty_ref(Interner) .clone() } @@ -1310,15 +1197,9 @@ mod resolve { .infer_ctxt .root_const_var(rustc_type_ir::ConstVid::from_u32(var.index())); let var = InferenceVar::from(vid.as_u32()); - let default = ConstData { - ty: ty.clone(), - value: ConstValue::Concrete(ConcreteConst { interned: ConstScalar::Unknown }), - } - .intern(Interner) - .cast(Interner); if self.var_stack.contains(&(var, VarKind::Const)) { // recursive - return (self.fallback)(var, VariableKind::Const(ty), default, outer_binder) + return (self.fallback)(var, VariableKind::Const(ty), outer_binder) .assert_const_ref(Interner) .clone(); } @@ -1330,7 +1211,7 @@ mod resolve { self.var_stack.pop(); result } else { - (self.fallback)(var, VariableKind::Const(ty), default, outer_binder) + (self.fallback)(var, VariableKind::Const(ty), outer_binder) .assert_const_ref(Interner) .clone() } @@ -1349,3 +1230,124 @@ mod resolve { } } } + +mod resolve_completely { + use rustc_type_ir::{ + DebruijnIndex, Flags, TypeFolder, TypeSuperFoldable, + inherent::{Const as _, Ty as _}, + }; + + use crate::next_solver::Region; + use crate::{ + infer::unify::InferenceTable, + next_solver::{ + Const, DbInterner, ErrorGuaranteed, Goal, Predicate, Term, Ty, + infer::traits::ObligationCause, + normalize::deeply_normalize_with_skipped_universes_and_ambiguous_coroutine_goals, + }, + }; + + pub(super) struct Resolver<'a, 'db> { + ctx: &'a mut InferenceTable<'db>, + /// Whether we should normalize, disabled when resolving predicates. + should_normalize: bool, + nested_goals: &'a mut Vec>>, + } + + impl<'a, 'db> Resolver<'a, 'db> { + pub(super) fn new( + ctx: &'a mut InferenceTable<'db>, + should_normalize: bool, + nested_goals: &'a mut Vec>>, + ) -> Resolver<'a, 'db> { + Resolver { ctx, nested_goals, should_normalize } + } + + fn handle_term( + &mut self, + value: T, + outer_exclusive_binder: impl FnOnce(T) -> DebruijnIndex, + ) -> T + where + T: Into> + TypeSuperFoldable> + Copy, + { + let value = if self.should_normalize { + let cause = ObligationCause::new(); + let at = self.ctx.infer_ctxt.at(&cause, self.ctx.trait_env.env); + let universes = vec![None; outer_exclusive_binder(value).as_usize()]; + match deeply_normalize_with_skipped_universes_and_ambiguous_coroutine_goals( + at, value, universes, + ) { + Ok((value, goals)) => { + self.nested_goals.extend(goals); + value + } + Err(_errors) => { + // FIXME: Report the error. + value + } + } + } else { + value + }; + + value.fold_with(&mut ReplaceInferWithError { interner: self.ctx.interner }) + } + } + + impl<'cx, 'db> TypeFolder> for Resolver<'cx, 'db> { + fn cx(&self) -> DbInterner<'db> { + self.ctx.interner + } + + fn fold_region(&mut self, r: Region<'db>) -> Region<'db> { + if r.is_var() { Region::error(self.ctx.interner) } else { r } + } + + fn fold_ty(&mut self, ty: Ty<'db>) -> Ty<'db> { + self.handle_term(ty, |it| it.outer_exclusive_binder()) + } + + fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> { + self.handle_term(ct, |it| it.outer_exclusive_binder()) + } + + fn fold_predicate(&mut self, predicate: Predicate<'db>) -> Predicate<'db> { + assert!( + !self.should_normalize, + "normalizing predicates in writeback is not generally sound" + ); + predicate.super_fold_with(self) + } + } + + struct ReplaceInferWithError<'db> { + interner: DbInterner<'db>, + } + + impl<'db> TypeFolder> for ReplaceInferWithError<'db> { + fn cx(&self) -> DbInterner<'db> { + self.interner + } + + fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> { + if t.is_infer() { + Ty::new_error(self.interner, ErrorGuaranteed) + } else { + t.super_fold_with(self) + } + } + + fn fold_const(&mut self, c: Const<'db>) -> Const<'db> { + if c.is_ct_infer() { + Const::new_error(self.interner, ErrorGuaranteed) + } else { + c.super_fold_with(self) + } + } + + fn fold_region(&mut self, r: Region<'db>) -> Region<'db> { + if r.is_var() { Region::error(self.interner) } else { r } + } + } +} diff --git a/crates/hir-ty/src/next_solver.rs b/crates/hir-ty/src/next_solver.rs index 073a02908d..ab167e88af 100644 --- a/crates/hir-ty/src/next_solver.rs +++ b/crates/hir-ty/src/next_solver.rs @@ -13,7 +13,7 @@ pub(crate) mod inspect; pub mod interner; mod ir_print; pub mod mapping; -mod normalize; +pub mod normalize; pub mod obligation_ctxt; mod opaques; pub mod predicate; diff --git a/crates/hir-ty/src/next_solver/region.rs b/crates/hir-ty/src/next_solver/region.rs index d6214d9915..0bfd2b8003 100644 --- a/crates/hir-ty/src/next_solver/region.rs +++ b/crates/hir-ty/src/next_solver/region.rs @@ -15,7 +15,7 @@ use super::{ interner::{BoundVarKind, DbInterner, Placeholder}, }; -type RegionKind<'db> = rustc_type_ir::RegionKind>; +pub type RegionKind<'db> = rustc_type_ir::RegionKind>; #[salsa::interned(constructor = new_, debug)] pub struct Region<'db> { @@ -53,6 +53,10 @@ impl<'db> Region<'db> { Region::new(interner, RegionKind::ReVar(v)) } + pub fn new_erased(interner: DbInterner<'db>) -> Region<'db> { + Region::new(interner, RegionKind::ReErased) + } + pub fn is_placeholder(&self) -> bool { matches!(self.inner(), RegionKind::RePlaceholder(..)) } @@ -61,6 +65,10 @@ impl<'db> Region<'db> { matches!(self.inner(), RegionKind::ReStatic) } + pub fn is_var(&self) -> bool { + matches!(self.inner(), RegionKind::ReVar(_)) + } + pub fn error(interner: DbInterner<'db>) -> Self { Region::new(interner, RegionKind::ReError(ErrorGuaranteed)) } diff --git a/crates/hir-ty/src/next_solver/ty.rs b/crates/hir-ty/src/next_solver/ty.rs index c7a747ade3..16bf082b01 100644 --- a/crates/hir-ty/src/next_solver/ty.rs +++ b/crates/hir-ty/src/next_solver/ty.rs @@ -7,6 +7,7 @@ use hir_def::{GenericDefId, TypeOrConstParamId, TypeParamId}; use intern::{Interned, Symbol, sym}; use rustc_abi::{Float, Integer, Size}; use rustc_ast_ir::{Mutability, try_visit, visit::VisitorResult}; +use rustc_type_ir::TyVid; use rustc_type_ir::{ BoundVar, ClosureKind, CollectAndApply, FlagComputation, Flags, FloatTy, FloatVid, InferTy, IntTy, IntVid, Interner, TypeFoldable, TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, @@ -338,6 +339,14 @@ impl<'db> Ty<'db> { matches!(self.kind(), TyKind::Tuple(tys) if tys.inner().is_empty()) } + #[inline] + pub fn ty_vid(self) -> Option { + match self.kind() { + TyKind::Infer(rustc_type_ir::TyVar(vid)) => Some(vid), + _ => None, + } + } + /// Given a `fn` type, returns an equivalent `unsafe fn` type; /// that is, a `fn` type that is equivalent in every way for being /// unsafe. diff --git a/crates/hir-ty/src/tests/never_type.rs b/crates/hir-ty/src/tests/never_type.rs index af5290d720..4d68179a88 100644 --- a/crates/hir-ty/src/tests/never_type.rs +++ b/crates/hir-ty/src/tests/never_type.rs @@ -14,8 +14,6 @@ fn test() { ); } -// FIXME(next-solver): The never type fallback implemented in r-a no longer works properly because of -// `Coerce` predicates. We should reimplement fallback like rustc. #[test] fn infer_never2() { check_types( @@ -26,7 +24,7 @@ fn test() { let a = gen(); if false { a } else { loop {} }; a; -} //^ {unknown} +} //^ ! "#, ); } @@ -41,7 +39,7 @@ fn test() { let a = gen(); if false { loop {} } else { a }; a; - //^ {unknown} + //^ ! } "#, ); @@ -56,7 +54,7 @@ enum Option { None, Some(T) } fn test() { let a = if true { Option::None } else { Option::Some(return) }; a; -} //^ Option<{unknown}> +} //^ Option "#, ); } @@ -220,7 +218,7 @@ fn test(a: i32) { _ => loop {}, }; i; -} //^ {unknown} +} //^ ! "#, ); } diff --git a/crates/hir-ty/src/tests/regression.rs b/crates/hir-ty/src/tests/regression.rs index a6215ef8fe..00835aa031 100644 --- a/crates/hir-ty/src/tests/regression.rs +++ b/crates/hir-ty/src/tests/regression.rs @@ -1951,7 +1951,7 @@ fn main() { Alias::Braced; //^^^^^^^^^^^^^ {unknown} let Alias::Braced = loop {}; - //^^^^^^^^^^^^^ {unknown} + //^^^^^^^^^^^^^ ! let Alias::Braced(..) = loop {}; //^^^^^^^^^^^^^^^^^ Enum diff --git a/crates/intern/src/symbol/symbols.rs b/crates/intern/src/symbol/symbols.rs index 1db4f8ecd6..920bdd9568 100644 --- a/crates/intern/src/symbol/symbols.rs +++ b/crates/intern/src/symbol/symbols.rs @@ -516,4 +516,5 @@ define_symbols! { flags, precision, width, + never_type_fallback, }