eagerly compute sub_relations again

This commit is contained in:
lcnr 2025-08-27 10:03:47 +02:00
parent f099b241e2
commit 67965f817d
15 changed files with 152 additions and 111 deletions

View file

@ -21,7 +21,6 @@ use rustc_middle::ty::{self, Const, Ty, TyCtxt, TypeVisitableExt};
use rustc_session::Session;
use rustc_span::{self, DUMMY_SP, ErrorGuaranteed, Ident, Span, sym};
use rustc_trait_selection::error_reporting::TypeErrCtxt;
use rustc_trait_selection::error_reporting::infer::sub_relations::SubRelations;
use rustc_trait_selection::traits::{
self, FulfillmentError, ObligationCause, ObligationCauseCode, ObligationCtxt,
};
@ -188,14 +187,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
///
/// [`InferCtxtErrorExt::err_ctxt`]: rustc_trait_selection::error_reporting::InferCtxtErrorExt::err_ctxt
pub(crate) fn err_ctxt(&'a self) -> TypeErrCtxt<'a, 'tcx> {
let mut sub_relations = SubRelations::default();
sub_relations.add_constraints(
self,
self.fulfillment_cx.borrow_mut().pending_obligations().iter().map(|o| o.predicate),
);
TypeErrCtxt {
infcx: &self.infcx,
sub_relations: RefCell::new(sub_relations),
typeck_results: Some(self.typeck_results.borrow()),
fallback_has_occurred: self.fallback_has_occurred.get(),
normalize_fn_sig: Box::new(|fn_sig| {

View file

@ -179,6 +179,10 @@ impl<'tcx> rustc_type_ir::InferCtxtLike for InferCtxt<'tcx> {
self.inner.borrow_mut().type_variables().equate(a, b);
}
fn sub_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid) {
self.sub_ty_vids_raw(a, b);
}
fn equate_int_vids_raw(&self, a: ty::IntVid, b: ty::IntVid) {
self.inner.borrow_mut().int_unification_table().union(a, b);
}

View file

@ -764,6 +764,7 @@ impl<'tcx> InferCtxt<'tcx> {
let r_b = self.shallow_resolve(predicate.skip_binder().b);
match (r_a.kind(), r_b.kind()) {
(&ty::Infer(ty::TyVar(a_vid)), &ty::Infer(ty::TyVar(b_vid))) => {
self.sub_ty_vids_raw(a_vid, b_vid);
return Err((a_vid, b_vid));
}
_ => {}
@ -1128,6 +1129,14 @@ impl<'tcx> InferCtxt<'tcx> {
self.inner.borrow_mut().type_variables().root_var(var)
}
pub fn sub_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid) {
self.inner.borrow_mut().type_variables().sub(a, b);
}
pub fn sub_root_var(&self, var: ty::TyVid) -> ty::TyVid {
self.inner.borrow_mut().type_variables().sub_root_var(var)
}
pub fn root_const_var(&self, var: ty::ConstVid) -> ty::ConstVid {
self.inner.borrow_mut().const_unification_table().find(var).vid
}

View file

@ -519,6 +519,10 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for Generalizer<'_, 'tcx> {
let origin = inner.type_variables().var_origin(vid);
let new_var_id =
inner.type_variables().new_var(self.for_universe, origin);
// Record that `vid` and `new_var_id` have to be subtypes
// of each other. This is currently only used for diagnostics.
// To see why, see the docs in the `type_variables` module.
inner.type_variables().sub(vid, new_var_id);
// If we're in the new solver and create a new inference
// variable inside of an alias we eagerly constrain that
// inference variable to prevent unexpected ambiguity errors.

View file

@ -20,7 +20,7 @@ pub struct Snapshot<'tcx> {
pub(crate) enum UndoLog<'tcx> {
DuplicateOpaqueType,
OpaqueTypes(OpaqueTypeKey<'tcx>, Option<OpaqueHiddenType<'tcx>>),
TypeVariables(sv::UndoLog<ut::Delegate<type_variable::TyVidEqKey<'tcx>>>),
TypeVariables(type_variable::UndoLog<'tcx>),
ConstUnificationTable(sv::UndoLog<ut::Delegate<ConstVidKey<'tcx>>>),
IntUnificationTable(sv::UndoLog<ut::Delegate<ty::IntVid>>),
FloatUnificationTable(sv::UndoLog<ut::Delegate<ty::FloatVid>>),
@ -49,6 +49,8 @@ impl_from! {
RegionConstraintCollector(region_constraints::UndoLog<'tcx>),
TypeVariables(sv::UndoLog<ut::Delegate<type_variable::TyVidEqKey<'tcx>>>),
TypeVariables(sv::UndoLog<ut::Delegate<type_variable::TyVidSubKey>>),
TypeVariables(type_variable::UndoLog<'tcx>),
IntUnificationTable(sv::UndoLog<ut::Delegate<ty::IntVid>>),
FloatUnificationTable(sv::UndoLog<ut::Delegate<ty::FloatVid>>),

View file

@ -13,12 +13,48 @@ use tracing::debug;
use crate::infer::InferCtxtUndoLogs;
/// Represents a single undo-able action that affects a type inference variable.
#[derive(Clone)]
pub(crate) enum UndoLog<'tcx> {
EqRelation(sv::UndoLog<ut::Delegate<TyVidEqKey<'tcx>>>),
SubRelation(sv::UndoLog<ut::Delegate<TyVidSubKey>>),
}
/// Convert from a specific kind of undo to the more general UndoLog
impl<'tcx> From<sv::UndoLog<ut::Delegate<TyVidEqKey<'tcx>>>> for UndoLog<'tcx> {
fn from(l: sv::UndoLog<ut::Delegate<TyVidEqKey<'tcx>>>) -> Self {
UndoLog::EqRelation(l)
}
}
/// Convert from a specific kind of undo to the more general UndoLog
impl<'tcx> From<sv::UndoLog<ut::Delegate<TyVidSubKey>>> for UndoLog<'tcx> {
fn from(l: sv::UndoLog<ut::Delegate<TyVidSubKey>>) -> Self {
UndoLog::SubRelation(l)
}
}
impl<'tcx> Rollback<sv::UndoLog<ut::Delegate<TyVidEqKey<'tcx>>>> for TypeVariableStorage<'tcx> {
fn reverse(&mut self, undo: sv::UndoLog<ut::Delegate<TyVidEqKey<'tcx>>>) {
self.eq_relations.reverse(undo)
}
}
impl<'tcx> Rollback<sv::UndoLog<ut::Delegate<TyVidSubKey>>> for TypeVariableStorage<'tcx> {
fn reverse(&mut self, undo: sv::UndoLog<ut::Delegate<TyVidSubKey>>) {
self.sub_relations.reverse(undo)
}
}
impl<'tcx> Rollback<UndoLog<'tcx>> for TypeVariableStorage<'tcx> {
fn reverse(&mut self, undo: UndoLog<'tcx>) {
match undo {
UndoLog::EqRelation(undo) => self.eq_relations.reverse(undo),
UndoLog::SubRelation(undo) => self.sub_relations.reverse(undo),
}
}
}
#[derive(Clone, Default)]
pub(crate) struct TypeVariableStorage<'tcx> {
/// The origins of each type variable.
@ -27,6 +63,23 @@ pub(crate) struct TypeVariableStorage<'tcx> {
/// constraint `?X == ?Y`. This table also stores, for each key,
/// the known value.
eq_relations: ut::UnificationTableStorage<TyVidEqKey<'tcx>>,
/// Only used by `-Znext-solver` and for diagnostics.
///
/// When reporting ambiguity errors, we sometimes want to
/// treat all inference vars which are subtypes of each
/// others as if they are equal. For this case we compute
/// the transitive closure of our subtype obligations here.
///
/// E.g. when encountering ambiguity errors, we want to suggest
/// specifying some method argument or to add a type annotation
/// to a local variable. Because subtyping cannot change the
/// shape of a type, it's fine if the cause of the ambiguity error
/// is only related to the suggested variable via subtyping.
///
/// Even for something like `let x = returns_arg(); x.method();` the
/// type of `x` is only a supertype of the argument of `returns_arg`. We
/// still want to suggest specifying the type of the argument.
sub_relations: ut::UnificationTableStorage<TyVidSubKey>,
}
pub(crate) struct TypeVariableTable<'a, 'tcx> {
@ -109,6 +162,16 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> {
debug_assert!(self.probe(a).is_unknown());
debug_assert!(self.probe(b).is_unknown());
self.eq_relations().union(a, b);
self.sub_relations().union(a, b);
}
/// Records that `a <: b`, depending on `dir`.
///
/// Precondition: neither `a` nor `b` are known.
pub(crate) fn sub(&mut self, a: ty::TyVid, b: ty::TyVid) {
debug_assert!(self.probe(a).is_unknown());
debug_assert!(self.probe(b).is_unknown());
self.sub_relations().union(a, b);
}
/// Instantiates `vid` with the type `ty`.
@ -142,6 +205,10 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> {
origin: TypeVariableOrigin,
) -> ty::TyVid {
let eq_key = self.eq_relations().new_key(TypeVariableValue::Unknown { universe });
let sub_key = self.sub_relations().new_key(());
debug_assert_eq!(eq_key.vid, sub_key.vid);
let index = self.storage.values.push(TypeVariableData { origin });
debug_assert_eq!(eq_key.vid, index);
@ -164,6 +231,18 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> {
self.eq_relations().find(vid).vid
}
/// Returns the "root" variable of `vid` in the `sub_relations`
/// equivalence table. All type variables that have been are
/// related via equality or subtyping will yield the same root
/// variable (per the union-find algorithm), so `sub_root_var(a)
/// == sub_root_var(b)` implies that:
/// ```text
/// exists X. (a <: X || X <: a) && (b <: X || X <: b)
/// ```
pub(crate) fn sub_root_var(&mut self, vid: ty::TyVid) -> ty::TyVid {
self.sub_relations().find(vid).vid
}
/// Retrieves the type to which `vid` has been instantiated, if
/// any.
pub(crate) fn probe(&mut self, vid: ty::TyVid) -> TypeVariableValue<'tcx> {
@ -181,6 +260,11 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> {
self.storage.eq_relations.with_log(self.undo_log)
}
#[inline]
fn sub_relations(&mut self) -> super::UnificationTable<'_, 'tcx, TyVidSubKey> {
self.storage.sub_relations.with_log(self.undo_log)
}
/// Returns a range of the type variables created during the snapshot.
pub(crate) fn vars_since_snapshot(
&mut self,
@ -243,6 +327,33 @@ impl<'tcx> ut::UnifyKey for TyVidEqKey<'tcx> {
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) struct TyVidSubKey {
vid: ty::TyVid,
}
impl From<ty::TyVid> for TyVidSubKey {
#[inline] // make this function eligible for inlining - it is quite hot.
fn from(vid: ty::TyVid) -> Self {
TyVidSubKey { vid }
}
}
impl ut::UnifyKey for TyVidSubKey {
type Value = ();
#[inline]
fn index(&self) -> u32 {
self.vid.as_u32()
}
#[inline]
fn from_index(i: u32) -> TyVidSubKey {
TyVidSubKey { vid: ty::TyVid::from_u32(i) }
}
fn tag() -> &'static str {
"TyVidSubKey"
}
}
impl<'tcx> ut::UnifyValue for TypeVariableValue<'tcx> {
type Error = ut::NoError;

View file

@ -900,6 +900,10 @@ where
&& goal.param_env.visit_with(&mut visitor).is_continue()
}
pub(super) fn sub_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid) {
self.delegate.sub_ty_vids_raw(a, b)
}
#[instrument(level = "trace", skip(self, param_env), ret)]
pub(super) fn eq<T: Relate<I>>(
&mut self,

View file

@ -119,11 +119,15 @@ where
#[instrument(level = "trace", skip(self))]
fn compute_subtype_goal(&mut self, goal: Goal<I, ty::SubtypePredicate<I>>) -> QueryResult<I> {
if goal.predicate.a.is_ty_var() && goal.predicate.b.is_ty_var() {
self.evaluate_added_goals_and_make_canonical_response(Certainty::AMBIGUOUS)
} else {
self.sub(goal.param_env, goal.predicate.a, goal.predicate.b)?;
self.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
match (goal.predicate.a.kind(), goal.predicate.b.kind()) {
(ty::Infer(ty::TyVar(a_vid)), ty::Infer(ty::TyVar(b_vid))) => {
self.sub_ty_vids_raw(a_vid, b_vid);
self.evaluate_added_goals_and_make_canonical_response(Certainty::AMBIGUOUS)
}
_ => {
self.sub(goal.param_env, goal.predicate.a, goal.predicate.b)?;
self.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
}
}
}

View file

@ -91,7 +91,6 @@ mod suggest;
pub mod need_type_info;
pub mod nice_region_error;
pub mod region;
pub mod sub_relations;
/// Makes a valid string literal from a string by escaping special characters (" and \),
/// unless they are already escaped.

View file

@ -894,7 +894,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
use ty::{Infer, TyVar};
match (inner_ty.kind(), target_ty.kind()) {
(&Infer(TyVar(a_vid)), &Infer(TyVar(b_vid))) => {
self.tecx.sub_relations.borrow_mut().unified(self.tecx, a_vid, b_vid)
self.tecx.sub_root_var(a_vid) == self.tecx.sub_root_var(b_vid)
}
_ => false,
}

View file

@ -1,81 +0,0 @@
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::undo_log::NoUndo;
use rustc_data_structures::unify as ut;
use rustc_middle::ty;
use crate::infer::InferCtxt;
#[derive(Debug, Copy, Clone, PartialEq)]
struct SubId(u32);
impl ut::UnifyKey for SubId {
type Value = ();
#[inline]
fn index(&self) -> u32 {
self.0
}
#[inline]
fn from_index(i: u32) -> SubId {
SubId(i)
}
fn tag() -> &'static str {
"SubId"
}
}
/// When reporting ambiguity errors, we sometimes want to
/// treat all inference vars which are subtypes of each
/// others as if they are equal. For this case we compute
/// the transitive closure of our subtype obligations here.
///
/// E.g. when encountering ambiguity errors, we want to suggest
/// specifying some method argument or to add a type annotation
/// to a local variable. Because subtyping cannot change the
/// shape of a type, it's fine if the cause of the ambiguity error
/// is only related to the suggested variable via subtyping.
///
/// Even for something like `let x = returns_arg(); x.method();` the
/// type of `x` is only a supertype of the argument of `returns_arg`. We
/// still want to suggest specifying the type of the argument.
#[derive(Default)]
pub struct SubRelations {
map: FxHashMap<ty::TyVid, SubId>,
table: ut::UnificationTableStorage<SubId>,
}
impl SubRelations {
fn get_id<'tcx>(&mut self, infcx: &InferCtxt<'tcx>, vid: ty::TyVid) -> SubId {
let root_vid = infcx.root_var(vid);
*self.map.entry(root_vid).or_insert_with(|| self.table.with_log(&mut NoUndo).new_key(()))
}
pub fn add_constraints<'tcx>(
&mut self,
infcx: &InferCtxt<'tcx>,
obls: impl IntoIterator<Item = ty::Predicate<'tcx>>,
) {
for p in obls {
let (a, b) = match p.kind().skip_binder() {
ty::PredicateKind::Subtype(ty::SubtypePredicate { a_is_expected: _, a, b }) => {
(a, b)
}
ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) => (a, b),
_ => continue,
};
match (a.kind(), b.kind()) {
(&ty::Infer(ty::TyVar(a_vid)), &ty::Infer(ty::TyVar(b_vid))) => {
let a = self.get_id(infcx, a_vid);
let b = self.get_id(infcx, b_vid);
self.table.with_log(&mut NoUndo).unify_var_var(a, b).unwrap();
}
_ => continue,
}
}
}
pub fn unified<'tcx>(&mut self, infcx: &InferCtxt<'tcx>, a: ty::TyVid, b: ty::TyVid) -> bool {
let a = self.get_id(infcx, a);
let b = self.get_id(infcx, b);
self.table.with_log(&mut NoUndo).unioned(a, b)
}
}

View file

@ -7,8 +7,6 @@ use rustc_macros::extension;
use rustc_middle::bug;
use rustc_middle::ty::{self, Ty};
use crate::error_reporting::infer::sub_relations;
pub mod infer;
pub mod traits;
@ -21,7 +19,6 @@ pub mod traits;
/// methods which should not be used during the happy path.
pub struct TypeErrCtxt<'a, 'tcx> {
pub infcx: &'a InferCtxt<'tcx>,
pub sub_relations: std::cell::RefCell<sub_relations::SubRelations>,
pub typeck_results: Option<std::cell::Ref<'a, ty::TypeckResults<'tcx>>>,
pub fallback_has_occurred: bool,
@ -38,7 +35,6 @@ impl<'tcx> InferCtxt<'tcx> {
fn err_ctxt(&self) -> TypeErrCtxt<'_, 'tcx> {
TypeErrCtxt {
infcx: self,
sub_relations: Default::default(),
typeck_results: None,
fallback_has_occurred: false,
normalize_fn_sig: Box::new(|fn_sig| fn_sig),

View file

@ -139,10 +139,6 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
&self,
mut errors: Vec<FulfillmentError<'tcx>>,
) -> ErrorGuaranteed {
self.sub_relations
.borrow_mut()
.add_constraints(self, errors.iter().map(|e| e.obligation.predicate));
#[derive(Debug)]
struct ErrorDescriptor<'tcx> {
goal: Goal<'tcx, ty::Predicate<'tcx>>,

View file

@ -126,13 +126,12 @@ impl<'tcx> rustc_next_trait_solver::delegate::SolverDelegate for SolverDelegate<
}
ty::PredicateKind::Subtype(ty::SubtypePredicate { a, b, .. })
| ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) => {
if self.shallow_resolve(a).is_ty_var() && self.shallow_resolve(b).is_ty_var() {
// FIXME: We also need to register a subtype relation between these vars
// when those are added, and if they aren't in the same sub root then
// we should mark this goal as `has_changed`.
Some(Certainty::AMBIGUOUS)
} else {
None
match (self.shallow_resolve(a).kind(), self.shallow_resolve(b).kind()) {
(&ty::Infer(ty::TyVar(a_vid)), &ty::Infer(ty::TyVar(b_vid))) => {
self.sub_ty_vids_raw(a_vid, b_vid);
Some(Certainty::AMBIGUOUS)
}
_ => None,
}
}
ty::PredicateKind::Clause(ty::ClauseKind::ConstArgHasType(ct, _)) => {

View file

@ -197,6 +197,7 @@ pub trait InferCtxtLike: Sized {
) -> U;
fn equate_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid);
fn sub_ty_vids_raw(&self, a: ty::TyVid, b: ty::TyVid);
fn equate_int_vids_raw(&self, a: ty::IntVid, b: ty::IntVid);
fn equate_float_vids_raw(&self, a: ty::FloatVid, b: ty::FloatVid);
fn equate_const_vids_raw(&self, a: ty::ConstVid, b: ty::ConstVid);