Prefer to return Box<thir::Pat> instead of thir::PatKind

This will allow extra data to be attached to the `Pat` before it is returned.
This commit is contained in:
Zalathar 2026-01-06 14:34:55 +11:00
parent 0495a73474
commit f85b898c45
2 changed files with 69 additions and 41 deletions

View file

@ -174,10 +174,10 @@ impl<'tcx> ConstToPat<'tcx> {
}
};
// Convert the valtree to a const.
let inlined_const_as_pat = self.valtree_to_pat(valtree, ty);
// Lower the valtree to a THIR pattern.
let mut thir_pat = self.valtree_to_pat(valtree, ty);
if !inlined_const_as_pat.references_error() {
if !thir_pat.references_error() {
// Always check for `PartialEq` if we had no other errors yet.
if !type_has_partial_eq_impl(self.tcx, typing_env, ty).has_impl {
let mut err = self.tcx.dcx().create_err(TypeNotPartialEq { span: self.span, ty });
@ -188,8 +188,12 @@ impl<'tcx> ConstToPat<'tcx> {
// Wrap the pattern in a marker node to indicate that it is the result of lowering a
// constant. This is used for diagnostics.
let kind = PatKind::ExpandedConstant { subpattern: inlined_const_as_pat, def_id: uv.def };
Box::new(Pat { kind, ty, span: self.span })
thir_pat = Box::new(Pat {
ty: thir_pat.ty,
span: thir_pat.span,
kind: PatKind::ExpandedConstant { def_id: uv.def, subpattern: thir_pat },
});
thir_pat
}
fn field_pats(

View file

@ -167,9 +167,10 @@ impl<'tcx> PatCtxt<'tcx> {
// Return None in that case; the caller will use NegInfinity or PosInfinity instead.
let Some(expr) = expr else { return Ok(None) };
// Lower the endpoint into a temporary `PatKind` that will then be
// Lower the endpoint into a temporary `thir::Pat` that will then be
// deconstructed to obtain the constant value and other data.
let mut kind: PatKind<'tcx> = self.lower_pat_expr(pat, expr);
let endpoint_pat: Box<Pat<'tcx>> = self.lower_pat_expr(pat, expr);
let box Pat { mut kind, .. } = endpoint_pat;
// Unpeel any ascription or inline-const wrapper nodes.
loop {
@ -250,7 +251,7 @@ impl<'tcx> PatCtxt<'tcx> {
lo_expr: Option<&'tcx hir::PatExpr<'tcx>>,
hi_expr: Option<&'tcx hir::PatExpr<'tcx>>,
end: RangeEnd,
) -> Result<PatKind<'tcx>, ErrorGuaranteed> {
) -> Result<Box<Pat<'tcx>>, ErrorGuaranteed> {
let ty = self.typeck_results.node_type(pat.hir_id);
let span = pat.span;
@ -306,27 +307,34 @@ impl<'tcx> PatCtxt<'tcx> {
return Err(e);
}
}
let mut thir_pat = Box::new(Pat { ty, span, kind });
// If we are handling a range with associated constants (e.g.
// `Foo::<'a>::A..=Foo::B`), we need to put the ascriptions for the associated
// constants somewhere. Have them on the range pattern.
for ascription in ascriptions {
let subpattern = Box::new(Pat { span, ty, kind });
kind = PatKind::AscribeUserType { ascription, subpattern };
thir_pat = Box::new(Pat {
ty,
span,
kind: PatKind::AscribeUserType { ascription, subpattern: thir_pat },
});
}
// `PatKind::ExpandedConstant` wrappers from range endpoints used to
// also be preserved here, but that was only needed for unsafeck of
// inline `const { .. }` patterns, which were removed by
// <https://github.com/rust-lang/rust/pull/138492>.
Ok(kind)
Ok(thir_pat)
}
#[instrument(skip(self), level = "debug")]
fn lower_pattern_unadjusted(&mut self, pat: &'tcx hir::Pat<'tcx>) -> Box<Pat<'tcx>> {
let mut ty = self.typeck_results.node_type(pat.hir_id);
let mut span = pat.span;
let ty = self.typeck_results.node_type(pat.hir_id);
let span = pat.span;
// Some of these match arms return a `Box<Pat>` early, while others
// evaluate to a `PatKind` that will become a `Box<Pat>` at the end of
// this function.
let kind = match pat.kind {
hir::PatKind::Missing => PatKind::Missing,
@ -334,10 +342,13 @@ impl<'tcx> PatCtxt<'tcx> {
hir::PatKind::Never => PatKind::Never,
hir::PatKind::Expr(value) => self.lower_pat_expr(pat, value),
hir::PatKind::Expr(value) => return self.lower_pat_expr(pat, value),
hir::PatKind::Range(lo_expr, hi_expr, end) => {
self.lower_pattern_range(pat, lo_expr, hi_expr, end).unwrap_or_else(PatKind::Error)
match self.lower_pattern_range(pat, lo_expr, hi_expr, end) {
Ok(thir_pat) => return thir_pat,
Err(e) => PatKind::Error(e),
}
}
hir::PatKind::Deref(subpattern) => {
@ -360,7 +371,7 @@ impl<'tcx> PatCtxt<'tcx> {
},
hir::PatKind::Slice(prefix, slice, suffix) => {
self.slice_or_array_pattern(pat, prefix, slice, suffix)
return self.slice_or_array_pattern(pat, prefix, slice, suffix);
}
hir::PatKind::Tuple(pats, ddpos) => {
@ -372,8 +383,9 @@ impl<'tcx> PatCtxt<'tcx> {
}
hir::PatKind::Binding(explicit_ba, id, ident, sub) => {
let mut thir_pat_span = span;
if let Some(ident_span) = ident.span.find_ancestor_inside(span) {
span = span.with_hi(ident_span.hi());
thir_pat_span = span.with_hi(ident_span.hi());
}
let mode = *self
@ -389,22 +401,23 @@ impl<'tcx> PatCtxt<'tcx> {
// A ref x pattern is the same node used for x, and as such it has
// x's type, which is &T, where we want T (the type being matched).
let var_ty = ty;
let mut thir_pat_ty = ty;
if let hir::ByRef::Yes(pinnedness, _) = mode.0 {
match pinnedness {
hir::Pinnedness::Pinned
if let Some(pty) = ty.pinned_ty()
&& let &ty::Ref(_, rty, _) = pty.kind() =>
{
ty = rty;
thir_pat_ty = rty;
}
hir::Pinnedness::Not if let &ty::Ref(_, rty, _) = ty.kind() => {
ty = rty;
thir_pat_ty = rty;
}
_ => bug!("`ref {}` has wrong type {}", ident, ty),
}
};
PatKind::Binding {
let kind = PatKind::Binding {
mode,
name: ident.name,
var: LocalVarId(id),
@ -412,7 +425,10 @@ impl<'tcx> PatCtxt<'tcx> {
subpattern: self.lower_opt_pattern(sub),
is_primary: id == pat.hir_id,
is_shorthand: false,
}
};
// We might have modified the type or span, so use the modified
// values in the THIR pattern node.
return Box::new(Pat { ty: thir_pat_ty, span: thir_pat_span, kind });
}
hir::PatKind::TupleStruct(ref qpath, pats, ddpos) => {
@ -422,7 +438,7 @@ impl<'tcx> PatCtxt<'tcx> {
};
let variant_def = adt_def.variant_of_res(res);
let subpatterns = self.lower_tuple_subpats(pats, variant_def.fields.len(), ddpos);
self.lower_variant_or_leaf(pat, None, res, subpatterns)
return self.lower_variant_or_leaf(pat, None, res, subpatterns);
}
hir::PatKind::Struct(ref qpath, fields, _) => {
@ -439,7 +455,7 @@ impl<'tcx> PatCtxt<'tcx> {
})
.collect();
self.lower_variant_or_leaf(pat, None, res, subpatterns)
return self.lower_variant_or_leaf(pat, None, res, subpatterns);
}
hir::PatKind::Or(pats) => PatKind::Or { pats: self.lower_patterns(pats) },
@ -450,6 +466,8 @@ impl<'tcx> PatCtxt<'tcx> {
hir::PatKind::Err(guar) => PatKind::Error(guar),
};
// For pattern kinds that haven't already returned, create a `thir::Pat`
// with the HIR pattern node's type and span.
Box::new(Pat { span, ty, kind })
}
@ -482,13 +500,14 @@ impl<'tcx> PatCtxt<'tcx> {
prefix: &'tcx [hir::Pat<'tcx>],
slice: Option<&'tcx hir::Pat<'tcx>>,
suffix: &'tcx [hir::Pat<'tcx>],
) -> PatKind<'tcx> {
) -> Box<Pat<'tcx>> {
let ty = self.typeck_results.node_type(pat.hir_id);
let span = pat.span;
let prefix = self.lower_patterns(prefix);
let slice = self.lower_opt_pattern(slice);
let suffix = self.lower_patterns(suffix);
match ty.kind() {
let kind = match ty.kind() {
// Matching a slice, `[T]`.
ty::Slice(..) => PatKind::Slice { prefix, slice, suffix },
// Fixed-length array, `[T; len]`.
@ -499,8 +518,9 @@ impl<'tcx> PatCtxt<'tcx> {
assert!(len >= prefix.len() as u64 + suffix.len() as u64);
PatKind::Array { prefix, slice, suffix }
}
_ => span_bug!(pat.span, "bad slice pattern type {ty:?}"),
}
_ => span_bug!(span, "bad slice pattern type {ty:?}"),
};
Box::new(Pat { ty, span, kind })
}
fn lower_variant_or_leaf(
@ -509,7 +529,7 @@ impl<'tcx> PatCtxt<'tcx> {
expr: Option<&'tcx hir::PatExpr<'tcx>>,
res: Res,
subpatterns: Vec<FieldPat<'tcx>>,
) -> PatKind<'tcx> {
) -> Box<Pat<'tcx>> {
// Check whether the caller should have provided an `expr` for this pattern kind.
assert_matches!(
(pat.kind, expr),
@ -533,7 +553,7 @@ impl<'tcx> PatCtxt<'tcx> {
res => res,
};
let mut kind = match res {
let kind = match res {
Res::Def(DefKind::Variant, variant_id) => {
let enum_id = self.tcx.parent(variant_id);
let adt_def = self.tcx.adt_def(enum_id);
@ -542,7 +562,7 @@ impl<'tcx> PatCtxt<'tcx> {
ty::Adt(_, args) | ty::FnDef(_, args) => args,
ty::Error(e) => {
// Avoid ICE (#50585)
return PatKind::Error(*e);
return Box::new(Pat { ty, span, kind: PatKind::Error(*e) });
}
_ => bug!("inappropriate type for def: {:?}", ty),
};
@ -583,21 +603,26 @@ impl<'tcx> PatCtxt<'tcx> {
PatKind::Error(e)
}
};
let mut thir_pat = Box::new(Pat { ty, span, kind });
if let Some(user_ty) = self.user_args_applied_to_ty_of_hir_id(hir_id) {
debug!("lower_variant_or_leaf: kind={:?} user_ty={:?} span={:?}", kind, user_ty, span);
debug!(?thir_pat, ?user_ty, ?span, "lower_variant_or_leaf: applying ascription");
let annotation = CanonicalUserTypeAnnotation {
user_ty: Box::new(user_ty),
span,
inferred_ty: self.typeck_results.node_type(hir_id),
};
kind = PatKind::AscribeUserType {
subpattern: Box::new(Pat { span, ty, kind }),
ascription: Ascription { annotation, variance: ty::Covariant },
};
thir_pat = Box::new(Pat {
ty,
span,
kind: PatKind::AscribeUserType {
subpattern: thir_pat,
ascription: Ascription { annotation, variance: ty::Covariant },
},
});
}
kind
thir_pat
}
fn user_args_applied_to_ty_of_hir_id(
@ -632,8 +657,7 @@ impl<'tcx> PatCtxt<'tcx> {
_ => {
// The path isn't the name of a constant, so it must actually
// be a unit struct or unit variant (e.g. `Option::None`).
let kind = self.lower_variant_or_leaf(pat, Some(expr), res, vec![]);
return Box::new(Pat { span, ty, kind });
return self.lower_variant_or_leaf(pat, Some(expr), res, vec![]);
}
};
@ -674,10 +698,10 @@ impl<'tcx> PatCtxt<'tcx> {
&mut self,
pat: &'tcx hir::Pat<'tcx>, // Pattern that directly contains `expr`
expr: &'tcx hir::PatExpr<'tcx>,
) -> PatKind<'tcx> {
) -> Box<Pat<'tcx>> {
assert_matches!(pat.kind, hir::PatKind::Expr(..) | hir::PatKind::Range(..));
match &expr.kind {
hir::PatExprKind::Path(qpath) => self.lower_path(pat, expr, qpath).kind,
hir::PatExprKind::Path(qpath) => self.lower_path(pat, expr, qpath),
hir::PatExprKind::Lit { lit, negated } => {
// We handle byte string literal patterns by using the pattern's type instead of the
// literal's type in `const_to_pat`: if the literal `b"..."` matches on a slice reference,
@ -691,7 +715,7 @@ impl<'tcx> PatCtxt<'tcx> {
let pat_ty = self.typeck_results.node_type(pat.hir_id);
let lit_input = LitToConstInput { lit: lit.node, ty: pat_ty, neg: *negated };
let constant = self.tcx.at(expr.span).lit_to_const(lit_input);
self.const_to_pat(constant, pat_ty, expr.hir_id, lit.span).kind
self.const_to_pat(constant, pat_ty, expr.hir_id, lit.span)
}
}
}