Unstall obligations by looking for coroutines in old solver

This commit is contained in:
Michael Goulet 2025-05-29 12:34:24 +00:00
parent 96171dc78f
commit 72bc11d146
3 changed files with 52 additions and 54 deletions

View file

@ -625,50 +625,23 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// trigger query cycle ICEs, as doing so requires MIR.
self.select_obligations_where_possible(|_| {});
let coroutines = std::mem::take(&mut *self.deferred_coroutine_interiors.borrow_mut());
debug!(?coroutines);
let ty::TypingMode::Analysis { defining_opaque_types_and_generators } = self.typing_mode()
else {
bug!();
};
let mut obligations = vec![];
if !self.next_trait_solver() {
for &(coroutine_def_id, interior) in coroutines.iter() {
debug!(?coroutine_def_id);
// Create the `CoroutineWitness` type that we will unify with `interior`.
let args = ty::GenericArgs::identity_for_item(
self.tcx,
self.tcx.typeck_root_def_id(coroutine_def_id.to_def_id()),
);
let witness =
Ty::new_coroutine_witness(self.tcx, coroutine_def_id.to_def_id(), args);
// Unify `interior` with `witness` and collect all the resulting obligations.
let span = self.tcx.hir_body_owned_by(coroutine_def_id).value.span;
let ty::Infer(ty::InferTy::TyVar(_)) = interior.kind() else {
span_bug!(span, "coroutine interior witness not infer: {:?}", interior.kind())
};
let ok = self
.at(&self.misc(span), self.param_env)
// Will never define opaque types, as all we do is instantiate a type variable.
.eq(DefineOpaqueTypes::Yes, interior, witness)
.expect("Failed to unify coroutine interior type");
obligations.extend(ok.obligations);
}
}
if !coroutines.is_empty() {
obligations.extend(
if defining_opaque_types_and_generators
.iter()
.any(|def_id| self.tcx.is_coroutine(def_id.to_def_id()))
{
self.typeck_results.borrow_mut().coroutine_stalled_predicates.extend(
self.fulfillment_cx
.borrow_mut()
.drain_stalled_obligations_for_coroutines(&self.infcx),
.drain_stalled_obligations_for_coroutines(&self.infcx)
.into_iter()
.map(|o| (o.predicate, o.cause)),
);
}
self.typeck_results
.borrow_mut()
.coroutine_stalled_predicates
.extend(obligations.into_iter().map(|o| (o.predicate, o.cause)));
}
#[instrument(skip(self), level = "debug")]

View file

@ -255,7 +255,7 @@ where
&mut self,
infcx: &InferCtxt<'tcx>,
) -> PredicateObligations<'tcx> {
let stalled_generators = match infcx.typing_mode() {
let stalled_coroutines = match infcx.typing_mode() {
TypingMode::Analysis { defining_opaque_types_and_generators } => {
defining_opaque_types_and_generators
}
@ -265,7 +265,7 @@ where
| TypingMode::PostAnalysis => return Default::default(),
};
if stalled_generators.is_empty() {
if stalled_coroutines.is_empty() {
return Default::default();
}
@ -276,7 +276,7 @@ where
.visit_proof_tree(
obl.as_goal(),
&mut StalledOnCoroutines {
stalled_generators,
stalled_coroutines,
span: obl.cause.span,
cache: Default::default(),
},
@ -298,10 +298,10 @@ where
///
/// This function can be also return false positives, which will lead to poor diagnostics
/// so we want to keep this visitor *precise* too.
struct StalledOnCoroutines<'tcx> {
stalled_generators: &'tcx ty::List<LocalDefId>,
span: Span,
cache: DelayedSet<Ty<'tcx>>,
pub struct StalledOnCoroutines<'tcx> {
pub stalled_coroutines: &'tcx ty::List<LocalDefId>,
pub span: Span,
pub cache: DelayedSet<Ty<'tcx>>,
}
impl<'tcx> inspect::ProofTreeVisitor<'tcx> for StalledOnCoroutines<'tcx> {
@ -331,7 +331,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for StalledOnCoroutines<'tcx> {
}
if let ty::CoroutineWitness(def_id, _) = *ty.kind()
&& def_id.as_local().is_some_and(|def_id| self.stalled_generators.contains(&def_id))
&& def_id.as_local().is_some_and(|def_id| self.stalled_coroutines.contains(&def_id))
{
ControlFlow::Break(())
} else if ty.has_coroutines() {

View file

@ -3,6 +3,7 @@ use std::marker::PhantomData;
use rustc_data_structures::obligation_forest::{
Error, ForestObligation, ObligationForest, ObligationProcessor, Outcome, ProcessResult,
};
use rustc_hir::def_id::LocalDefId;
use rustc_infer::infer::DefineOpaqueTypes;
use rustc_infer::traits::{
FromSolverError, PolyTraitObligation, PredicateObligations, ProjectionCacheKey, SelectionError,
@ -12,8 +13,9 @@ use rustc_middle::bug;
use rustc_middle::ty::abstract_const::NotConstEvaluatable;
use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::{
self, Binder, Const, GenericArgsRef, TypeVisitableExt, TypingMode, may_use_unstable_feature,
self, Binder, Const, GenericArgsRef, TypeVisitable, TypeVisitableExt, TypingMode,
};
use rustc_span::DUMMY_SP;
use thin_vec::{ThinVec, thin_vec};
use tracing::{debug, debug_span, instrument};
@ -26,6 +28,7 @@ use super::{
};
use crate::error_reporting::InferCtxtErrorExt;
use crate::infer::{InferCtxt, TyOrConstInferVar};
use crate::solve::StalledOnCoroutines;
use crate::traits::normalize::normalize_with_depth_to;
use crate::traits::project::{PolyProjectionObligation, ProjectionCacheKeyExt as _};
use crate::traits::query::evaluate_obligation::InferCtxtExt;
@ -168,8 +171,25 @@ where
&mut self,
infcx: &InferCtxt<'tcx>,
) -> PredicateObligations<'tcx> {
let mut processor =
DrainProcessor { removed_predicates: PredicateObligations::new(), infcx };
let stalled_coroutines = match infcx.typing_mode() {
TypingMode::Analysis { defining_opaque_types_and_generators } => {
defining_opaque_types_and_generators
}
TypingMode::Coherence
| TypingMode::Borrowck { defining_opaque_types: _ }
| TypingMode::PostBorrowckAnalysis { defined_opaque_types: _ }
| TypingMode::PostAnalysis => return Default::default(),
};
if stalled_coroutines.is_empty() {
return Default::default();
}
let mut processor = DrainProcessor {
infcx,
removed_predicates: PredicateObligations::new(),
stalled_coroutines,
};
let outcome: Outcome<_, _> = self.predicates.process_obligations(&mut processor);
assert!(outcome.errors.is_empty());
return processor.removed_predicates;
@ -177,6 +197,7 @@ where
struct DrainProcessor<'a, 'tcx> {
infcx: &'a InferCtxt<'tcx>,
removed_predicates: PredicateObligations<'tcx>,
stalled_coroutines: &'tcx ty::List<LocalDefId>,
}
impl<'tcx> ObligationProcessor for DrainProcessor<'_, 'tcx> {
@ -185,10 +206,14 @@ where
type OUT = Outcome<Self::Obligation, Self::Error>;
fn needs_process_obligation(&self, pending_obligation: &Self::Obligation) -> bool {
pending_obligation
.stalled_on
.iter()
.any(|&var| self.infcx.ty_or_const_infer_var_changed(var))
self.infcx
.resolve_vars_if_possible(pending_obligation.obligation.predicate)
.visit_with(&mut StalledOnCoroutines {
stalled_coroutines: self.stalled_coroutines,
span: DUMMY_SP,
cache: Default::default(),
})
.is_break()
}
fn process_obligation(