From b3413c643b2b00b91ea1200ec17e49b8a161e7de Mon Sep 17 00:00:00 2001 From: b-naber Date: Fri, 23 Jun 2023 12:28:19 +0000 Subject: [PATCH] only infer array type on irrefutable patterns --- compiler/rustc_hir_typeck/src/_match.rs | 2 +- compiler/rustc_hir_typeck/src/check.rs | 2 +- .../rustc_hir_typeck/src/fn_ctxt/checks.rs | 5 +- .../rustc_hir_typeck/src/gather_locals.rs | 31 +++- compiler/rustc_hir_typeck/src/pat.rs | 138 +++++++++++++----- 5 files changed, 138 insertions(+), 40 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/_match.rs b/compiler/rustc_hir_typeck/src/_match.rs index f2a43cc414d3..55c3cd8f4c63 100644 --- a/compiler/rustc_hir_typeck/src/_match.rs +++ b/compiler/rustc_hir_typeck/src/_match.rs @@ -41,7 +41,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // #55810: Type check patterns first so we get types for all bindings. let scrut_span = scrut.span.find_ancestor_inside(expr.span).unwrap_or(scrut.span); for arm in arms { - self.check_pat_top(&arm.pat, scrutinee_ty, Some(scrut_span), Some(scrut)); + self.check_pat_top(&arm.pat, scrutinee_ty, Some(scrut_span), Some(scrut), None); } // Now typecheck the blocks. diff --git a/compiler/rustc_hir_typeck/src/check.rs b/compiler/rustc_hir_typeck/src/check.rs index 53bae315d786..1fc1e5aca2b3 100644 --- a/compiler/rustc_hir_typeck/src/check.rs +++ b/compiler/rustc_hir_typeck/src/check.rs @@ -89,7 +89,7 @@ pub(super) fn check_fn<'a, 'tcx>( for (idx, (param_ty, param)) in inputs_fn.chain(maybe_va_list).zip(body.params).enumerate() { // Check the pattern. let ty_span = try { inputs_hir?.get(idx)?.span }; - fcx.check_pat_top(¶m.pat, param_ty, ty_span, None); + fcx.check_pat_top(¶m.pat, param_ty, ty_span, None, None); // Check that argument is Sized. if !params_can_be_unsized { diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs index a9610009db1e..c70de7b8b725 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs @@ -1,7 +1,7 @@ use crate::coercion::CoerceMany; use crate::errors::SuggestPtrNullMut; use crate::fn_ctxt::arg_matrix::{ArgMatrix, Compatibility, Error, ExpectedIdx, ProvidedIdx}; -use crate::gather_locals::Declaration; +use crate::gather_locals::{DeclContext, Declaration}; use crate::method::MethodCallee; use crate::TupleArgumentsFlag::*; use crate::{errors, Expectation::*}; @@ -1474,7 +1474,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { }; // Type check the pattern. Override if necessary to avoid knock-on errors. - self.check_pat_top(&decl.pat, decl_ty, ty_span, origin_expr); + let decl_ctxt = DeclContext { has_else: decl.els.is_some(), origin: decl.origin }; + self.check_pat_top(&decl.pat, decl_ty, ty_span, origin_expr, Some(decl_ctxt)); let pat_ty = self.node_ty(decl.pat.hir_id); self.overwrite_local_ty_if_err(decl.hir_id, decl.pat, pat_ty); diff --git a/compiler/rustc_hir_typeck/src/gather_locals.rs b/compiler/rustc_hir_typeck/src/gather_locals.rs index 4f45a24b216e..15562d574667 100644 --- a/compiler/rustc_hir_typeck/src/gather_locals.rs +++ b/compiler/rustc_hir_typeck/src/gather_locals.rs @@ -9,6 +9,24 @@ use rustc_span::def_id::LocalDefId; use rustc_span::Span; use rustc_trait_selection::traits; +#[derive(Debug, Copy, Clone)] +pub(super) enum DeclOrigin { + // from an `if let` expression + LetExpr, + // from `let x = ..` + LocalDecl, +} + +/// Provides context for checking patterns in declarations. More specifically this +/// allows us to infer array types if the pattern is irrefutable and allows us to infer +/// the size of the array. See issue #76342. +#[derive(Debug, Copy, Clone)] +pub(crate) struct DeclContext { + // whether we're in a let-else context + pub(super) has_else: bool, + pub(super) origin: DeclOrigin, +} + /// A declaration is an abstraction of [hir::Local] and [hir::Let]. /// /// It must have a hir_id, as this is how we connect gather_locals to the check functions. @@ -19,19 +37,28 @@ pub(super) struct Declaration<'a> { pub span: Span, pub init: Option<&'a hir::Expr<'a>>, pub els: Option<&'a hir::Block<'a>>, + pub origin: DeclOrigin, } impl<'a> From<&'a hir::Local<'a>> for Declaration<'a> { fn from(local: &'a hir::Local<'a>) -> Self { let hir::Local { hir_id, pat, ty, span, init, els, source: _ } = *local; - Declaration { hir_id, pat, ty, span, init, els } + Declaration { hir_id, pat, ty, span, init, els, origin: DeclOrigin::LocalDecl } } } impl<'a> From<&'a hir::Let<'a>> for Declaration<'a> { fn from(let_expr: &'a hir::Let<'a>) -> Self { let hir::Let { hir_id, pat, ty, span, init } = *let_expr; - Declaration { hir_id, pat, ty, span, init: Some(init), els: None } + Declaration { + hir_id, + pat, + ty, + span, + init: Some(init), + els: None, + origin: DeclOrigin::LetExpr, + } } } diff --git a/compiler/rustc_hir_typeck/src/pat.rs b/compiler/rustc_hir_typeck/src/pat.rs index 4785564d850f..5616f67f430f 100644 --- a/compiler/rustc_hir_typeck/src/pat.rs +++ b/compiler/rustc_hir_typeck/src/pat.rs @@ -1,3 +1,4 @@ +use crate::gather_locals::{DeclContext, DeclOrigin}; use crate::{errors, FnCtxt, RawTy}; use rustc_ast as ast; use rustc_data_structures::fx::FxHashMap; @@ -135,15 +136,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { /// /// Otherwise, `Some(span)` represents the span of a type expression /// which originated the `expected` type. - pub fn check_pat_top( + pub(crate) fn check_pat_top( &self, pat: &'tcx Pat<'tcx>, expected: Ty<'tcx>, span: Option, origin_expr: Option<&'tcx hir::Expr<'tcx>>, + decl_ctxt: Option, ) { let info = TopInfo { expected, origin_expr, span }; - self.check_pat(pat, expected, INITIAL_BM, info); + self.check_pat(pat, expected, INITIAL_BM, info, decl_ctxt); } /// Type check the given `pat` against the `expected` type @@ -158,6 +160,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { expected: Ty<'tcx>, def_bm: BindingMode, ti: TopInfo<'tcx>, + decl_ctxt: Option, ) { let path_res = match &pat.kind { PatKind::Path(qpath) => { @@ -173,33 +176,41 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { PatKind::Lit(lt) => self.check_pat_lit(pat.span, lt, expected, ti), PatKind::Range(lhs, rhs, _) => self.check_pat_range(pat.span, lhs, rhs, expected, ti), PatKind::Binding(ba, var_id, _, sub) => { - self.check_pat_ident(pat, ba, var_id, sub, expected, def_bm, ti) - } - PatKind::TupleStruct(ref qpath, subpats, ddpos) => { - self.check_pat_tuple_struct(pat, qpath, subpats, ddpos, expected, def_bm, ti) + self.check_pat_ident(pat, ba, var_id, sub, expected, def_bm, ti, decl_ctxt) } + PatKind::TupleStruct(ref qpath, subpats, ddpos) => self.check_pat_tuple_struct( + pat, qpath, subpats, ddpos, expected, def_bm, ti, decl_ctxt, + ), PatKind::Path(ref qpath) => { self.check_pat_path(pat, qpath, path_res.unwrap(), expected, ti) } - PatKind::Struct(ref qpath, fields, has_rest_pat) => { - self.check_pat_struct(pat, qpath, fields, has_rest_pat, expected, def_bm, ti) - } + PatKind::Struct(ref qpath, fields, has_rest_pat) => self.check_pat_struct( + pat, + qpath, + fields, + has_rest_pat, + expected, + def_bm, + ti, + decl_ctxt, + ), PatKind::Or(pats) => { for pat in pats { - self.check_pat(pat, expected, def_bm, ti); + self.check_pat(pat, expected, def_bm, ti, decl_ctxt); } expected } PatKind::Tuple(elements, ddpos) => { - self.check_pat_tuple(pat.span, elements, ddpos, expected, def_bm, ti) + self.check_pat_tuple(pat.span, elements, ddpos, expected, def_bm, ti, decl_ctxt) + } + PatKind::Box(inner) => { + self.check_pat_box(pat.span, inner, expected, def_bm, ti, decl_ctxt) } - PatKind::Box(inner) => self.check_pat_box(pat.span, inner, expected, def_bm, ti), PatKind::Ref(inner, mutbl) => { - self.check_pat_ref(pat, inner, mutbl, expected, def_bm, ti) - } - PatKind::Slice(before, slice, after) => { - self.check_pat_slice(pat.span, before, slice, after, expected, def_bm, ti) + self.check_pat_ref(pat, inner, mutbl, expected, def_bm, ti, decl_ctxt) } + PatKind::Slice(before, slice, after) => self + .check_pat_slice(pat.span, before, slice, after, expected, def_bm, ti, decl_ctxt), }; self.write_ty(pat.hir_id, ty); @@ -582,6 +593,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { expected: Ty<'tcx>, def_bm: BindingMode, ti: TopInfo<'tcx>, + decl_ctxt: Option, ) -> Ty<'tcx> { // Determine the binding mode... let bm = match ba { @@ -620,7 +632,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { } if let Some(p) = sub { - self.check_pat(p, expected, def_bm, ti); + self.check_pat(p, expected, def_bm, ti, decl_ctxt); } local_ty @@ -845,6 +857,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { expected: Ty<'tcx>, def_bm: BindingMode, ti: TopInfo<'tcx>, + decl_ctxt: Option, ) -> Ty<'tcx> { // Resolve the path and check the definition for errors. let (variant, pat_ty) = match self.check_struct_path(qpath, pat.hir_id) { @@ -853,7 +866,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let err = Ty::new_error(self.tcx, guar); for field in fields { let ti = ti; - self.check_pat(field.pat, err, def_bm, ti); + self.check_pat(field.pat, err, def_bm, ti, decl_ctxt); } return err; } @@ -863,7 +876,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { self.demand_eqtype_pat(pat.span, expected, pat_ty, ti); // Type-check subpatterns. - if self.check_struct_pat_fields(pat_ty, &pat, variant, fields, has_rest_pat, def_bm, ti) { + if self.check_struct_pat_fields( + pat_ty, + &pat, + variant, + fields, + has_rest_pat, + def_bm, + ti, + decl_ctxt, + ) { pat_ty } else { Ty::new_misc_error(self.tcx) @@ -1030,11 +1052,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { expected: Ty<'tcx>, def_bm: BindingMode, ti: TopInfo<'tcx>, + decl_ctxt: Option, ) -> Ty<'tcx> { let tcx = self.tcx; let on_error = |e| { for pat in subpats { - self.check_pat(pat, Ty::new_error(tcx, e), def_bm, ti); + self.check_pat(pat, Ty::new_error(tcx, e), def_bm, ti, decl_ctxt); } }; let report_unexpected_res = |res: Res| { @@ -1100,7 +1123,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { for (i, subpat) in subpats.iter().enumerate_and_adjust(variant.fields.len(), ddpos) { let field = &variant.fields[FieldIdx::from_usize(i)]; let field_ty = self.field_ty(subpat.span, field, args); - self.check_pat(subpat, field_ty, def_bm, ti); + self.check_pat(subpat, field_ty, def_bm, ti, decl_ctxt); self.tcx.check_stability( variant.fields[FieldIdx::from_usize(i)].did, @@ -1286,6 +1309,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { expected: Ty<'tcx>, def_bm: BindingMode, ti: TopInfo<'tcx>, + decl_ctxt: Option, ) -> Ty<'tcx> { let tcx = self.tcx; let mut expected_len = elements.len(); @@ -1312,12 +1336,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // further errors being emitted when using the bindings. #50333 let element_tys_iter = (0..max_len).map(|_| Ty::new_error(tcx, reported)); for (_, elem) in elements.iter().enumerate_and_adjust(max_len, ddpos) { - self.check_pat(elem, Ty::new_error(tcx, reported), def_bm, ti); + self.check_pat(elem, Ty::new_error(tcx, reported), def_bm, ti, decl_ctxt); } Ty::new_tup_from_iter(tcx, element_tys_iter) } else { for (i, elem) in elements.iter().enumerate_and_adjust(max_len, ddpos) { - self.check_pat(elem, element_tys[i], def_bm, ti); + self.check_pat(elem, element_tys[i], def_bm, ti, decl_ctxt); } pat_ty } @@ -1332,6 +1356,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { has_rest_pat: bool, def_bm: BindingMode, ti: TopInfo<'tcx>, + decl_ctxt: Option, ) -> bool { let tcx = self.tcx; @@ -1378,7 +1403,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { } }; - self.check_pat(field.pat, field_ty, def_bm, ti); + self.check_pat(field.pat, field_ty, def_bm, ti, decl_ctxt); } let mut unmentioned_fields = variant @@ -1943,6 +1968,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { expected: Ty<'tcx>, def_bm: BindingMode, ti: TopInfo<'tcx>, + decl_ctxt: Option, ) -> Ty<'tcx> { let tcx = self.tcx; let (box_ty, inner_ty) = match self.check_dereferenceable(span, expected, inner) { @@ -1962,7 +1988,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { (err, err) } }; - self.check_pat(inner, inner_ty, def_bm, ti); + self.check_pat(inner, inner_ty, def_bm, ti, decl_ctxt); box_ty } @@ -1975,6 +2001,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { expected: Ty<'tcx>, def_bm: BindingMode, ti: TopInfo<'tcx>, + decl_ctxt: Option, ) -> Ty<'tcx> { let tcx = self.tcx; let expected = self.shallow_resolve(expected); @@ -2013,7 +2040,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { (err, err) } }; - self.check_pat(inner, inner_ty, def_bm, ti); + self.check_pat(inner, inner_ty, def_bm, ti, decl_ctxt); ref_ty } @@ -2043,6 +2070,44 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { Some(tcx.mk_array(inner_ty, len.try_into().unwrap())) } + /// Determines whether we can infer the expected type in the slice pattern to be of type array. + /// This is only possible if we're in an irrefutable pattern. If we were to allow this in refutable + /// patterns we wouldn't e.g. report ambiguity in the following situation: + /// + /// ```ignore(rust) + /// struct Zeroes; + /// const ARR: [usize; 2] = [0; 2]; + /// const ARR2: [usize; 2] = [2; 2]; + /// + /// impl Into<&'static [usize; 2]> for Zeroes { + /// fn into(self) -> &'static [usize; 2] { + /// &ARR + /// } + /// } + /// + /// impl Into<&'static [usize]> for Zeroes { + /// fn into(self) -> &'static [usize] { + /// &ARR2 + /// } + /// } + /// + /// fn main() { + /// let &[a, b]: &[usize] = Zeroes.into() else { + /// .. + /// }; + /// } + /// ``` + /// + /// If we're in an irrefutable pattern we prefer the array impl candidate given that + /// the slice impl candidate would be be rejected anyway (if no ambiguity existed). + fn decl_allows_array_type_infer(&self, decl_ctxt: Option) -> bool { + if let Some(decl_ctxt) = decl_ctxt { + !decl_ctxt.has_else && matches!(decl_ctxt.origin, DeclOrigin::LocalDecl) + } else { + false + } + } + /// Type check a slice pattern. /// /// Syntactically, these look like `[pat_0, ..., pat_n]`. @@ -2062,12 +2127,17 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { expected: Ty<'tcx>, def_bm: BindingMode, ti: TopInfo<'tcx>, + decl_ctxt: Option, ) -> Ty<'tcx> { - // If `expected` is an infer ty, we try to equate it to an array if the given pattern - // allows it. See issue #76342 - if let Some(resolved_arr_ty) = self.try_resolve_slice_ty_to_array_ty(before, slice, span) && expected.is_ty_var() { - debug!(?resolved_arr_ty); - self.demand_eqtype(span, expected, resolved_arr_ty); + // If the pattern is irrefutable and `expected` is an infer ty, we try to equate it + // to an array if the given pattern allows it. See issue #76342 + if self.decl_allows_array_type_infer(decl_ctxt) && expected.is_ty_var() { + if let Some(resolved_arr_ty) = + self.try_resolve_slice_ty_to_array_ty(before, slice, span) + { + debug!(?resolved_arr_ty); + self.demand_eqtype(span, expected, resolved_arr_ty); + } } let expected = self.structurally_resolve_type(span, expected); @@ -2096,15 +2166,15 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // Type check all the patterns before `slice`. for elt in before { - self.check_pat(elt, element_ty, def_bm, ti); + self.check_pat(elt, element_ty, def_bm, ti, decl_ctxt); } // Type check the `slice`, if present, against its expected type. if let Some(slice) = slice { - self.check_pat(slice, opt_slice_ty.unwrap(), def_bm, ti); + self.check_pat(slice, opt_slice_ty.unwrap(), def_bm, ti, decl_ctxt); } // Type check the elements after `slice`, if present. for elt in after { - self.check_pat(elt, element_ty, def_bm, ti); + self.check_pat(elt, element_ty, def_bm, ti, decl_ctxt); } inferred }