diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs index 5697e73f73a7..e1d0da803fce 100644 --- a/compiler/rustc_infer/src/infer/mod.rs +++ b/compiler/rustc_infer/src/infer/mod.rs @@ -691,11 +691,12 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { &self, a: ty::Unevaluated<'tcx, ()>, b: ty::Unevaluated<'tcx, ()>, + param_env: ty::ParamEnv<'tcx>, ) -> bool { let canonical = self.canonicalize_query((a, b), &mut OriginalQueryValues::default()); debug!("canonical consts: {:?}", &canonical.value); - self.tcx.try_unify_abstract_consts(canonical.value) + self.tcx.try_unify_abstract_consts(param_env.and(canonical.value)) } pub fn is_in_snapshot(&self) -> bool { diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 44b622c1e3d8..d39ae43fe8c1 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -329,12 +329,12 @@ rustc_queries! { } } - query try_unify_abstract_consts(key: ( - ty::Unevaluated<'tcx, ()>, ty::Unevaluated<'tcx, ()> - )) -> bool { + query try_unify_abstract_consts(key: + ty::ParamEnvAnd<'tcx, (ty::Unevaluated<'tcx, ()>, ty::Unevaluated<'tcx, ()> + )>) -> bool { desc { |tcx| "trying to unify the generic constants {} and {}", - tcx.def_path_str(key.0.def.did), tcx.def_path_str(key.1.def.did) + tcx.def_path_str(key.value.0.def.did), tcx.def_path_str(key.value.1.def.did) } } diff --git a/compiler/rustc_middle/src/ty/relate.rs b/compiler/rustc_middle/src/ty/relate.rs index 81ee7942c4d3..5d6cbcf69070 100644 --- a/compiler/rustc_middle/src/ty/relate.rs +++ b/compiler/rustc_middle/src/ty/relate.rs @@ -585,7 +585,7 @@ pub fn super_relate_consts<'tcx, R: TypeRelation<'tcx>>( (ty::ConstKind::Unevaluated(au), ty::ConstKind::Unevaluated(bu)) if tcx.features().generic_const_exprs => { - tcx.try_unify_abstract_consts((au.shrink(), bu.shrink())) + tcx.try_unify_abstract_consts(relation.param_env().and((au.shrink(), bu.shrink()))) } // While this is slightly incorrect, it shouldn't matter for `min_const_generics` diff --git a/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs b/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs index f880b28b3c8d..0c33ea858fd0 100644 --- a/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs +++ b/compiler/rustc_trait_selection/src/traits/const_evaluatable.rs @@ -28,13 +28,13 @@ use std::iter; use std::ops::ControlFlow; /// Check if a given constant can be evaluated. +#[instrument(skip(infcx), level = "debug")] pub fn is_const_evaluatable<'cx, 'tcx>( infcx: &InferCtxt<'cx, 'tcx>, uv: ty::Unevaluated<'tcx, ()>, param_env: ty::ParamEnv<'tcx>, span: Span, ) -> Result<(), NotConstEvaluatable> { - debug!("is_const_evaluatable({:?})", uv); let tcx = infcx.tcx; if tcx.features().generic_const_exprs { @@ -185,6 +185,7 @@ pub fn is_const_evaluatable<'cx, 'tcx>( } } +#[instrument(skip(tcx), level = "debug")] fn satisfied_from_param_env<'tcx>( tcx: TyCtxt<'tcx>, ct: AbstractConst<'tcx>, @@ -197,11 +198,12 @@ fn satisfied_from_param_env<'tcx>( // Try to unify with each subtree in the AbstractConst to allow for // `N + 1` being const evaluatable even if theres only a `ConstEvaluatable` // predicate for `(N + 1) * 2` - let result = - walk_abstract_const(tcx, b_ct, |b_ct| match try_unify(tcx, ct, b_ct) { + let result = walk_abstract_const(tcx, b_ct, |b_ct| { + match try_unify(tcx, ct, b_ct, param_env) { true => ControlFlow::BREAK, false => ControlFlow::CONTINUE, - }); + } + }); if let ControlFlow::Break(()) = result { debug!("is_const_evaluatable: abstract_const ~~> ok"); @@ -570,11 +572,12 @@ pub(super) fn thir_abstract_const<'tcx>( pub(super) fn try_unify_abstract_consts<'tcx>( tcx: TyCtxt<'tcx>, (a, b): (ty::Unevaluated<'tcx, ()>, ty::Unevaluated<'tcx, ()>), + param_env: ty::ParamEnv<'tcx>, ) -> bool { (|| { if let Some(a) = AbstractConst::new(tcx, a)? { if let Some(b) = AbstractConst::new(tcx, b)? { - return Ok(try_unify(tcx, a, b)); + return Ok(try_unify(tcx, a, b, param_env)); } } @@ -619,32 +622,59 @@ where recurse(tcx, ct, &mut f) } -/// Tries to unify two abstract constants using structural equality. -pub(super) fn try_unify<'tcx>( +// Substitutes generics repeatedly to allow AbstractConsts to unify where a +// ConstKind::Unevalated could be turned into an AbstractConst that would unify e.g. +// Param(N) should unify with Param(T), substs: [Unevaluated("T2", [Unevaluated("T3", [Param(N)])])] +#[inline] +#[instrument(skip(tcx), level = "debug")] +fn try_replace_substs_in_root<'tcx>( tcx: TyCtxt<'tcx>, - mut a: AbstractConst<'tcx>, - mut b: AbstractConst<'tcx>, -) -> bool { - // We substitute generics repeatedly to allow AbstractConsts to unify where a - // ConstKind::Unevalated could be turned into an AbstractConst that would unify e.g. - // Param(N) should unify with Param(T), substs: [Unevaluated("T2", [Unevaluated("T3", [Param(N)])])] - while let Node::Leaf(a_ct) = a.root(tcx) { - match AbstractConst::from_const(tcx, a_ct) { - Ok(Some(a_act)) => a = a_act, + mut abstr_const: AbstractConst<'tcx>, +) -> Option> { + while let Node::Leaf(ct) = abstr_const.root(tcx) { + match AbstractConst::from_const(tcx, ct) { + Ok(Some(act)) => abstr_const = act, Ok(None) => break, - Err(_) => return true, - } - } - while let Node::Leaf(b_ct) = b.root(tcx) { - match AbstractConst::from_const(tcx, b_ct) { - Ok(Some(b_act)) => b = b_act, - Ok(None) => break, - Err(_) => return true, + Err(_) => return None, } } - match (a.root(tcx), b.root(tcx)) { + Some(abstr_const) +} + +/// Tries to unify two abstract constants using structural equality. +#[instrument(skip(tcx), level = "debug")] +pub(super) fn try_unify<'tcx>( + tcx: TyCtxt<'tcx>, + a: AbstractConst<'tcx>, + b: AbstractConst<'tcx>, + param_env: ty::ParamEnv<'tcx>, +) -> bool { + let a = match try_replace_substs_in_root(tcx, a) { + Some(a) => a, + None => { + return true; + } + }; + + let b = match try_replace_substs_in_root(tcx, b) { + Some(b) => b, + None => { + return true; + } + }; + + let a_root = a.root(tcx); + let b_root = b.root(tcx); + debug!(?a_root, ?b_root); + + match (a_root, b_root) { (Node::Leaf(a_ct), Node::Leaf(b_ct)) => { + let a_ct = a_ct.eval(tcx, param_env); + debug!("a_ct evaluated: {:?}", a_ct); + let b_ct = b_ct.eval(tcx, param_env); + debug!("b_ct evaluated: {:?}", b_ct); + if a_ct.ty() != b_ct.ty() { return false; } @@ -678,23 +708,23 @@ pub(super) fn try_unify<'tcx>( } } (Node::Binop(a_op, al, ar), Node::Binop(b_op, bl, br)) if a_op == b_op => { - try_unify(tcx, a.subtree(al), b.subtree(bl)) - && try_unify(tcx, a.subtree(ar), b.subtree(br)) + try_unify(tcx, a.subtree(al), b.subtree(bl), param_env) + && try_unify(tcx, a.subtree(ar), b.subtree(br), param_env) } (Node::UnaryOp(a_op, av), Node::UnaryOp(b_op, bv)) if a_op == b_op => { - try_unify(tcx, a.subtree(av), b.subtree(bv)) + try_unify(tcx, a.subtree(av), b.subtree(bv), param_env) } (Node::FunctionCall(a_f, a_args), Node::FunctionCall(b_f, b_args)) if a_args.len() == b_args.len() => { - try_unify(tcx, a.subtree(a_f), b.subtree(b_f)) + try_unify(tcx, a.subtree(a_f), b.subtree(b_f), param_env) && iter::zip(a_args, b_args) - .all(|(&an, &bn)| try_unify(tcx, a.subtree(an), b.subtree(bn))) + .all(|(&an, &bn)| try_unify(tcx, a.subtree(an), b.subtree(bn), param_env)) } (Node::Cast(a_kind, a_operand, a_ty), Node::Cast(b_kind, b_operand, b_ty)) if (a_ty == b_ty) && (a_kind == b_kind) => { - try_unify(tcx, a.subtree(a_operand), b.subtree(b_operand)) + try_unify(tcx, a.subtree(a_operand), b.subtree(b_operand), param_env) } // use this over `_ => false` to make adding variants to `Node` less error prone (Node::Cast(..), _)