diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs index e37c0cf0fd03..00e238648712 100644 --- a/compiler/rustc_infer/src/infer/at.rs +++ b/compiler/rustc_infer/src/infer/at.rs @@ -78,6 +78,10 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { err_count_on_creation: self.err_count_on_creation, in_snapshot: self.in_snapshot.clone(), universe: self.universe.clone(), + normalize_fn_sig_for_diagnostic: self + .normalize_fn_sig_for_diagnostic + .as_ref() + .map(|f| f.clone()), } } } diff --git a/compiler/rustc_infer/src/infer/error_reporting/mod.rs b/compiler/rustc_infer/src/infer/error_reporting/mod.rs index 59ea1f3f9de4..7bb151ef2b9e 100644 --- a/compiler/rustc_infer/src/infer/error_reporting/mod.rs +++ b/compiler/rustc_infer/src/infer/error_reporting/mod.rs @@ -961,12 +961,23 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { } } + fn normalize_fn_sig_for_diagnostic(&self, sig: ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx> { + if let Some(normalize) = &self.normalize_fn_sig_for_diagnostic { + normalize(self, sig) + } else { + sig + } + } + /// Given two `fn` signatures highlight only sub-parts that are different. fn cmp_fn_sig( &self, sig1: &ty::PolyFnSig<'tcx>, sig2: &ty::PolyFnSig<'tcx>, ) -> (DiagnosticStyledString, DiagnosticStyledString) { + let sig1 = &self.normalize_fn_sig_for_diagnostic(*sig1); + let sig2 = &self.normalize_fn_sig_for_diagnostic(*sig2); + let get_lifetimes = |sig| { use rustc_hir::def::Namespace; let (_, sig, reg) = ty::print::FmtPrinter::new(self.tcx, Namespace::TypeNS) diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs index c95738e0018c..a56b34ef3eb5 100644 --- a/compiler/rustc_infer/src/infer/mod.rs +++ b/compiler/rustc_infer/src/infer/mod.rs @@ -337,6 +337,9 @@ pub struct InferCtxt<'a, 'tcx> { /// when we enter into a higher-ranked (`for<..>`) type or trait /// bound. universe: Cell, + + normalize_fn_sig_for_diagnostic: + Option, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>>, } /// See the `error_reporting` module for more details. @@ -540,6 +543,8 @@ pub struct InferCtxtBuilder<'tcx> { defining_use_anchor: DefiningAnchor, considering_regions: bool, fresh_typeck_results: Option>>, + normalize_fn_sig_for_diagnostic: + Option, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>>, } pub trait TyCtxtInferExt<'tcx> { @@ -553,6 +558,7 @@ impl<'tcx> TyCtxtInferExt<'tcx> for TyCtxt<'tcx> { defining_use_anchor: DefiningAnchor::Error, considering_regions: true, fresh_typeck_results: None, + normalize_fn_sig_for_diagnostic: None, } } } @@ -582,6 +588,14 @@ impl<'tcx> InferCtxtBuilder<'tcx> { self } + pub fn with_normalize_fn_sig_for_diagnostic( + mut self, + fun: Lrc, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>, + ) -> Self { + self.normalize_fn_sig_for_diagnostic = Some(fun); + self + } + /// Given a canonical value `C` as a starting point, create an /// inference context that contains each of the bound values /// within instantiated as a fresh variable. The `f` closure is @@ -611,6 +625,7 @@ impl<'tcx> InferCtxtBuilder<'tcx> { defining_use_anchor, considering_regions, ref fresh_typeck_results, + ref normalize_fn_sig_for_diagnostic, } = *self; let in_progress_typeck_results = fresh_typeck_results.as_ref(); f(InferCtxt { @@ -629,6 +644,9 @@ impl<'tcx> InferCtxtBuilder<'tcx> { in_snapshot: Cell::new(false), skip_leak_check: Cell::new(false), universe: Cell::new(ty::UniverseIndex::ROOT), + normalize_fn_sig_for_diagnostic: normalize_fn_sig_for_diagnostic + .as_ref() + .map(|f| f.clone()), }) } } diff --git a/compiler/rustc_typeck/src/check/inherited.rs b/compiler/rustc_typeck/src/check/inherited.rs index f3115fc5c023..1a000f77b0fd 100644 --- a/compiler/rustc_typeck/src/check/inherited.rs +++ b/compiler/rustc_typeck/src/check/inherited.rs @@ -1,18 +1,22 @@ use super::callee::DeferredCallResolution; use rustc_data_structures::fx::FxHashSet; +use rustc_data_structures::sync::Lrc; use rustc_hir as hir; use rustc_hir::def_id::LocalDefId; use rustc_hir::HirIdMap; use rustc_infer::infer; use rustc_infer::infer::{InferCtxt, InferOk, TyCtxtInferExt}; +use rustc_infer::traits::TraitEngineExt as _; use rustc_middle::ty::fold::TypeFoldable; use rustc_middle::ty::visit::TypeVisitable; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_span::def_id::LocalDefIdMap; use rustc_span::{self, Span}; use rustc_trait_selection::infer::InferCtxtExt as _; -use rustc_trait_selection::traits::{self, ObligationCause, TraitEngine, TraitEngineExt}; +use rustc_trait_selection::traits::{ + self, FulfillmentContext, ObligationCause, TraitEngine, TraitEngineExt as _, +}; use std::cell::RefCell; use std::ops::Deref; @@ -84,7 +88,29 @@ impl<'tcx> Inherited<'_, 'tcx> { infcx: tcx .infer_ctxt() .ignoring_regions() - .with_fresh_in_progress_typeck_results(hir_owner), + .with_fresh_in_progress_typeck_results(hir_owner) + .with_normalize_fn_sig_for_diagnostic(Lrc::new(move |infcx, fn_sig| { + if fn_sig.has_escaping_bound_vars() { + return fn_sig; + } + infcx.probe(|_| { + let traits::Normalized { value: normalized_fn_sig, obligations } = + traits::normalize( + &mut traits::SelectionContext::new(infcx), + // FIXME(compiler-errors): This is probably not the right param-env... + infcx.tcx.param_env(def_id), + ObligationCause::dummy(), + fn_sig, + ); + let mut fulfillment_ctxt = FulfillmentContext::new_in_snapshot(); + fulfillment_ctxt.register_predicate_obligations(infcx, obligations); + if fulfillment_ctxt.select_all_or_error(infcx).is_empty() { + infcx.resolve_vars_if_possible(normalized_fn_sig) + } else { + fn_sig + } + }) + })), def_id, } } diff --git a/src/test/ui/mismatched_types/normalize-fn-sig.rs b/src/test/ui/mismatched_types/normalize-fn-sig.rs new file mode 100644 index 000000000000..1a2093c44f02 --- /dev/null +++ b/src/test/ui/mismatched_types/normalize-fn-sig.rs @@ -0,0 +1,16 @@ +trait Foo { + type Bar; +} + +impl Foo for T { + type Bar = i32; +} + +fn foo(_: ::Bar, _: &'static ::Bar) {} + +fn needs_i32_ref_fn(_: fn(&'static i32, i32)) {} + +fn main() { + needs_i32_ref_fn(foo::<()>); + //~^ ERROR mismatched types +} diff --git a/src/test/ui/mismatched_types/normalize-fn-sig.stderr b/src/test/ui/mismatched_types/normalize-fn-sig.stderr new file mode 100644 index 000000000000..6c55f29c5d15 --- /dev/null +++ b/src/test/ui/mismatched_types/normalize-fn-sig.stderr @@ -0,0 +1,19 @@ +error[E0308]: mismatched types + --> $DIR/normalize-fn-sig.rs:14:22 + | +LL | needs_i32_ref_fn(foo::<()>); + | ---------------- ^^^^^^^^^ expected `&i32`, found `i32` + | | + | arguments to this function are incorrect + | + = note: expected fn pointer `fn(&'static i32, i32)` + found fn item `fn(i32, &'static i32) {foo::<()>}` +note: function defined here + --> $DIR/normalize-fn-sig.rs:11:4 + | +LL | fn needs_i32_ref_fn(_: fn(&'static i32, i32)) {} + | ^^^^^^^^^^^^^^^^ ------------------------ + +error: aborting due to previous error + +For more information about this error, try `rustc --explain E0308`.