diff --git a/compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs b/compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs index 0bce31beb8b8..38e8d2fa9368 100644 --- a/compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs +++ b/compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs @@ -143,7 +143,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>( let total_len = lane_count * 2; - let indexes = idx.iter().map(|idx| idx.unwrap_leaf().to_u32()).collect::>(); + let indexes = idx.iter().map(|idx| idx.to_leaf().to_u32()).collect::>(); for &idx in &indexes { assert!(u64::from(idx) < total_len, "idx {} out of range 0..{}", idx, total_len); @@ -961,9 +961,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>( let lane_clif_ty = fx.clif_type(val_lane_ty).unwrap(); let ptr_val = ptr.load_scalar(fx); - let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0] - .unwrap_leaf() - .to_simd_alignment(); + let alignment = + generic_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment(); let memflags = match alignment { SimdAlign::Unaligned => MemFlags::new().with_notrap(), @@ -1006,9 +1005,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>( let lane_clif_ty = fx.clif_type(val_lane_ty).unwrap(); let ret_lane_layout = fx.layout_of(ret_lane_ty); - let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0] - .unwrap_leaf() - .to_simd_alignment(); + let alignment = + generic_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment(); let memflags = match alignment { SimdAlign::Unaligned => MemFlags::new().with_notrap(), @@ -1059,9 +1057,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>( let ret_lane_layout = fx.layout_of(ret_lane_ty); let ptr_val = ptr.load_scalar(fx); - let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0] - .unwrap_leaf() - .to_simd_alignment(); + let alignment = + generic_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment(); let memflags = match alignment { SimdAlign::Unaligned => MemFlags::new().with_notrap(), diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 3bc890310cc8..215738828c98 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -351,7 +351,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { _ => bug!(), }; let ptr = args[0].immediate(); - let locality = fn_args.const_at(1).to_value().valtree.unwrap_leaf().to_i32(); + let locality = fn_args.const_at(1).to_leaf().to_i32(); self.call_intrinsic( "llvm.prefetch", &[self.val_ty(ptr)], @@ -1533,7 +1533,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>( } if name == sym::simd_shuffle_const_generic { - let idx = fn_args[2].expect_const().to_value().valtree.unwrap_branch(); + let idx = fn_args[2].expect_const().to_branch(); let n = idx.len() as u64; let (out_len, out_ty) = require_simd!(ret_ty, SimdReturn); @@ -1552,7 +1552,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>( .iter() .enumerate() .map(|(arg_idx, val)| { - let idx = val.unwrap_leaf().to_i32(); + let idx = val.to_leaf().to_i32(); if idx >= i32::try_from(total_len).unwrap() { bx.sess().dcx().emit_err(InvalidMonomorphization::SimdIndexOutOfBounds { span, @@ -1964,9 +1964,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>( // those lanes whose `mask` bit is enabled. // The memory addresses corresponding to the “off” lanes are not accessed. - let alignment = fn_args[3].expect_const().to_value().valtree.unwrap_branch()[0] - .unwrap_leaf() - .to_simd_alignment(); + let alignment = fn_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment(); // The element type of the "mask" argument must be a signed integer type of any width let mask_ty = in_ty; @@ -2059,9 +2057,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>( // those lanes whose `mask` bit is enabled. // The memory addresses corresponding to the “off” lanes are not accessed. - let alignment = fn_args[3].expect_const().to_value().valtree.unwrap_branch()[0] - .unwrap_leaf() - .to_simd_alignment(); + let alignment = fn_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment(); // The element type of the "mask" argument must be a signed integer type of any width let mask_ty = in_ty; diff --git a/compiler/rustc_codegen_ssa/src/mir/constant.rs b/compiler/rustc_codegen_ssa/src/mir/constant.rs index 11b6ab3cdf1a..abdac4c7c372 100644 --- a/compiler/rustc_codegen_ssa/src/mir/constant.rs +++ b/compiler/rustc_codegen_ssa/src/mir/constant.rs @@ -77,22 +77,21 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { .flatten() .map(|val| { // A SIMD type has a single field, which is an array. - let fields = val.unwrap_branch(); + let fields = val.to_branch(); assert_eq!(fields.len(), 1); - let array = fields[0].unwrap_branch(); + let array = fields[0].to_branch(); // Iterate over the array elements to obtain the values in the vector. let values: Vec<_> = array .iter() .map(|field| { - if let Some(prim) = field.try_to_scalar() { - let layout = bx.layout_of(field_ty); - let BackendRepr::Scalar(scalar) = layout.backend_repr else { - bug!("from_const: invalid ByVal layout: {:#?}", layout); - }; - bx.scalar_to_backend(prim, scalar, bx.immediate_backend_type(layout)) - } else { + let Some(prim) = field.try_to_scalar() else { bug!("field is not a scalar {:?}", field) - } + }; + let layout = bx.layout_of(field_ty); + let BackendRepr::Scalar(scalar) = layout.backend_repr else { + bug!("from_const: invalid ByVal layout: {:#?}", layout); + }; + bx.scalar_to_backend(prim, scalar, bx.immediate_backend_type(layout)) }) .collect(); bx.const_vector(&values) diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index aeb740118234..f4fae40d8828 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -102,7 +102,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { }; let parse_atomic_ordering = |ord: ty::Value<'tcx>| { - let discr = ord.valtree.unwrap_branch()[0].unwrap_leaf(); + let discr = ord.to_branch()[0].to_leaf(); discr.to_atomic_ordering() }; diff --git a/compiler/rustc_const_eval/src/const_eval/valtrees.rs b/compiler/rustc_const_eval/src/const_eval/valtrees.rs index 7c41258ebfe5..b771addb8df5 100644 --- a/compiler/rustc_const_eval/src/const_eval/valtrees.rs +++ b/compiler/rustc_const_eval/src/const_eval/valtrees.rs @@ -36,13 +36,17 @@ fn branches<'tcx>( // For enums, we prepend their variant index before the variant's fields so we can figure out // the variant again when just seeing a valtree. if let Some(variant) = variant { - branches.push(ty::ValTree::from_scalar_int(*ecx.tcx, variant.as_u32().into())); + branches.push(ty::Const::new_value( + *ecx.tcx, + ty::ValTree::from_scalar_int(*ecx.tcx, variant.as_u32().into()), + ecx.tcx.types.u32, + )); } for i in 0..field_count { let field = ecx.project_field(&place, FieldIdx::from_usize(i)).unwrap(); let valtree = const_to_valtree_inner(ecx, &field, num_nodes)?; - branches.push(valtree); + branches.push(ty::Const::new_value(*ecx.tcx, valtree, field.layout.ty)); } // Have to account for ZSTs here @@ -65,7 +69,7 @@ fn slice_branches<'tcx>( for i in 0..n { let place_elem = ecx.project_index(place, i).unwrap(); let valtree = const_to_valtree_inner(ecx, &place_elem, num_nodes)?; - elems.push(valtree); + elems.push(ty::Const::new_value(*ecx.tcx, valtree, place_elem.layout.ty)); } Ok(ty::ValTree::from_branches(*ecx.tcx, elems)) @@ -200,8 +204,8 @@ fn reconstruct_place_meta<'tcx>( &ObligationCause::dummy(), |ty| ty, || { - let branches = last_valtree.unwrap_branch(); - last_valtree = *branches.last().unwrap(); + let branches = last_valtree.to_branch(); + last_valtree = branches.last().unwrap().to_value().valtree; debug!(?branches, ?last_valtree); }, ); @@ -212,7 +216,7 @@ fn reconstruct_place_meta<'tcx>( }; // Get the number of elements in the unsized field. - let num_elems = last_valtree.unwrap_branch().len(); + let num_elems = last_valtree.to_branch().len(); MemPlaceMeta::Meta(Scalar::from_target_usize(num_elems as u64, &tcx)) } @@ -274,7 +278,7 @@ pub fn valtree_to_const_value<'tcx>( mir::ConstValue::ZeroSized } ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char | ty::RawPtr(_, _) => { - mir::ConstValue::Scalar(Scalar::Int(cv.valtree.unwrap_leaf())) + mir::ConstValue::Scalar(Scalar::Int(cv.to_leaf())) } ty::Pat(ty, _) => { let cv = ty::Value { valtree: cv.valtree, ty }; @@ -301,12 +305,13 @@ pub fn valtree_to_const_value<'tcx>( || matches!(cv.ty.kind(), ty::Adt(def, _) if def.is_struct())) { // A Scalar tuple/struct; we can avoid creating an allocation. - let branches = cv.valtree.unwrap_branch(); + let branches = cv.to_branch(); // Find the non-ZST field. (There can be aligned ZST!) for (i, &inner_valtree) in branches.iter().enumerate() { let field = layout.field(&LayoutCx::new(tcx, typing_env), i); if !field.is_zst() { - let cv = ty::Value { valtree: inner_valtree, ty: field.ty }; + let cv = + ty::Value { valtree: inner_valtree.to_value().valtree, ty: field.ty }; return valtree_to_const_value(tcx, typing_env, cv); } } @@ -381,7 +386,7 @@ fn valtree_into_mplace<'tcx>( // Zero-sized type, nothing to do. } ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char | ty::RawPtr(..) => { - let scalar_int = valtree.unwrap_leaf(); + let scalar_int = valtree.to_leaf(); debug!("writing trivial valtree {:?} to place {:?}", scalar_int, place); ecx.write_immediate(Immediate::Scalar(scalar_int.into()), place).unwrap(); } @@ -391,13 +396,13 @@ fn valtree_into_mplace<'tcx>( ecx.write_immediate(imm, place).unwrap(); } ty::Adt(_, _) | ty::Tuple(_) | ty::Array(_, _) | ty::Str | ty::Slice(_) => { - let branches = valtree.unwrap_branch(); + let branches = valtree.to_branch(); // Need to downcast place for enums let (place_adjusted, branches, variant_idx) = match ty.kind() { ty::Adt(def, _) if def.is_enum() => { // First element of valtree corresponds to variant - let scalar_int = branches[0].unwrap_leaf(); + let scalar_int = branches[0].to_leaf(); let variant_idx = VariantIdx::from_u32(scalar_int.to_u32()); let variant = def.variant(variant_idx); debug!(?variant); @@ -425,7 +430,7 @@ fn valtree_into_mplace<'tcx>( }; debug!(?place_inner); - valtree_into_mplace(ecx, &place_inner, *inner_valtree); + valtree_into_mplace(ecx, &place_inner, inner_valtree.to_value().valtree); dump_place(ecx, &place_inner); } diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs b/compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs index 20de47683122..33a115384a88 100644 --- a/compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs +++ b/compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs @@ -545,7 +545,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { let (right, right_len) = self.project_to_simd(&args[1])?; let (dest, dest_len) = self.project_to_simd(&dest)?; - let index = generic_args[2].expect_const().to_value().valtree.unwrap_branch(); + let index = generic_args[2].expect_const().to_branch(); let index_len = index.len(); assert_eq!(left_len, right_len); @@ -553,7 +553,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { for i in 0..dest_len { let src_index: u64 = - index[usize::try_from(i).unwrap()].unwrap_leaf().to_u32().into(); + index[usize::try_from(i).unwrap()].to_leaf().to_u32().into(); let dest = self.project_index(&dest, i)?; let val = if src_index < left_len { @@ -657,9 +657,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { self.check_simd_ptr_alignment( ptr, dest_layout, - generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0] - .unwrap_leaf() - .to_simd_alignment(), + generic_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment(), )?; for i in 0..dest_len { @@ -689,9 +687,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { self.check_simd_ptr_alignment( ptr, args[2].layout, - generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0] - .unwrap_leaf() - .to_simd_alignment(), + generic_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment(), )?; for i in 0..vals_len { diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index d1d4c32184ee..4fa39eb83e9e 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -92,7 +92,7 @@ macro_rules! arena_types { [] name_set: rustc_data_structures::unord::UnordSet, [] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet, - [] valtree: rustc_middle::ty::ValTreeKind<'tcx>, + [] valtree: rustc_middle::ty::ValTreeKind>, [] stable_order_of_exportable_impls: rustc_data_structures::fx::FxIndexMap, diff --git a/compiler/rustc_middle/src/mir/consts.rs b/compiler/rustc_middle/src/mir/consts.rs index fe352df3b9f0..afe39e4481ef 100644 --- a/compiler/rustc_middle/src/mir/consts.rs +++ b/compiler/rustc_middle/src/mir/consts.rs @@ -302,15 +302,7 @@ impl<'tcx> Const<'tcx> { #[inline] pub fn try_to_scalar(self) -> Option { match self { - Const::Ty(_, c) => match c.kind() { - ty::ConstKind::Value(cv) if cv.ty.is_primitive() => { - // A valtree of a type where leaves directly represent the scalar const value. - // Just checking whether it is a leaf is insufficient as e.g. references are leafs - // but the leaf value is the value they point to, not the reference itself! - Some(cv.valtree.unwrap_leaf().into()) - } - _ => None, - }, + Const::Ty(_, c) => c.try_to_scalar(), Const::Val(val, _) => val.try_to_scalar(), Const::Unevaluated(..) => None, } @@ -321,10 +313,7 @@ impl<'tcx> Const<'tcx> { // This is equivalent to `self.try_to_scalar()?.try_to_int().ok()`, but measurably faster. match self { Const::Val(ConstValue::Scalar(Scalar::Int(x)), _) => Some(x), - Const::Ty(_, c) => match c.kind() { - ty::ConstKind::Value(cv) if cv.ty.is_primitive() => Some(cv.valtree.unwrap_leaf()), - _ => None, - }, + Const::Ty(_, c) => c.try_to_leaf(), _ => None, } } @@ -377,14 +366,10 @@ impl<'tcx> Const<'tcx> { tcx: TyCtxt<'tcx>, typing_env: ty::TypingEnv<'tcx>, ) -> Option { - if let Const::Ty(_, c) = self - && let ty::ConstKind::Value(cv) = c.kind() - && cv.ty.is_primitive() - { - // Avoid the `valtree_to_const_val` query. Can only be done on primitive types that - // are valtree leaves, and *not* on references. (References should return the - // pointer here, which valtrees don't represent.) - Some(cv.valtree.unwrap_leaf().into()) + if let Const::Ty(_, c) = self { + // We don't evaluate anything for type system constants as normalizing + // the MIR will handle this for us + c.try_to_scalar() } else { self.eval(tcx, typing_env, DUMMY_SP).ok()?.try_to_scalar() } diff --git a/compiler/rustc_middle/src/thir.rs b/compiler/rustc_middle/src/thir.rs index 3d893bf75e0b..31745cae3c06 100644 --- a/compiler/rustc_middle/src/thir.rs +++ b/compiler/rustc_middle/src/thir.rs @@ -928,7 +928,7 @@ impl<'tcx> PatRange<'tcx> { let lo_is_min = match self.lo { PatRangeBoundary::NegInfinity => true, PatRangeBoundary::Finite(value) => { - let lo = value.try_to_scalar_int().unwrap().to_bits(size) ^ bias; + let lo = value.to_leaf().to_bits(size) ^ bias; lo <= min } PatRangeBoundary::PosInfinity => false, @@ -937,7 +937,7 @@ impl<'tcx> PatRange<'tcx> { let hi_is_max = match self.hi { PatRangeBoundary::NegInfinity => false, PatRangeBoundary::Finite(value) => { - let hi = value.try_to_scalar_int().unwrap().to_bits(size) ^ bias; + let hi = value.to_leaf().to_bits(size) ^ bias; hi > max || hi == max && self.end == RangeEnd::Included } PatRangeBoundary::PosInfinity => true, @@ -1029,7 +1029,7 @@ impl<'tcx> PatRangeBoundary<'tcx> { } pub fn to_bits(self, ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> u128 { match self { - Self::Finite(value) => value.try_to_scalar_int().unwrap().to_bits_unchecked(), + Self::Finite(value) => value.to_leaf().to_bits_unchecked(), Self::NegInfinity => { // Unwrap is ok because the type is known to be numeric. ty.numeric_min_and_max_as_bits(tcx).unwrap().0 @@ -1057,7 +1057,7 @@ impl<'tcx> PatRangeBoundary<'tcx> { // many ranges such as '\u{037A}'..='\u{037F}', and chars can be compared // in this way. (Finite(a), Finite(b)) if matches!(ty.kind(), ty::Int(_) | ty::Uint(_) | ty::Char) => { - if let (Some(a), Some(b)) = (a.try_to_scalar_int(), b.try_to_scalar_int()) { + if let (Some(a), Some(b)) = (a.try_to_leaf(), b.try_to_leaf()) { let sz = ty.primitive_size(tcx); let cmp = match ty.kind() { ty::Uint(_) | ty::Char => a.to_uint(sz).cmp(&b.to_uint(sz)), diff --git a/compiler/rustc_middle/src/ty/consts.rs b/compiler/rustc_middle/src/ty/consts.rs index 787ea5f9363d..da3caf0bb210 100644 --- a/compiler/rustc_middle/src/ty/consts.rs +++ b/compiler/rustc_middle/src/ty/consts.rs @@ -6,6 +6,7 @@ use rustc_macros::{HashStable, TyDecodable, TyEncodable}; use rustc_type_ir::walk::TypeWalker; use rustc_type_ir::{self as ir, TypeFlags, WithCachedTypeInfo}; +use crate::mir::interpret::Scalar; use crate::ty::{self, Ty, TyCtxt}; mod int; @@ -260,7 +261,7 @@ impl<'tcx> Const<'tcx> { /// Attempts to convert to a value. /// - /// Note that this does not evaluate the constant. + /// Note that this does not normalize the constant. pub fn try_to_value(self) -> Option> { match self.kind() { ty::ConstKind::Value(cv) => Some(cv), @@ -268,6 +269,45 @@ impl<'tcx> Const<'tcx> { } } + /// Converts to a `ValTreeKind::Leaf` value, `panic`'ing + /// if this constant is some other kind. + /// + /// Note that this does not normalize the constant. + #[inline] + pub fn to_leaf(self) -> ScalarInt { + self.to_value().to_leaf() + } + + /// Converts to a `ValTreeKind::Branch` value, `panic`'ing + /// if this constant is some other kind. + /// + /// Note that this does not normalize the constant. + #[inline] + pub fn to_branch(self) -> &'tcx [ty::Const<'tcx>] { + self.to_value().to_branch() + } + + /// Attempts to convert to a `ValTreeKind::Leaf` value. + /// + /// Note that this does not normalize the constant. + pub fn try_to_leaf(self) -> Option { + self.try_to_value()?.try_to_leaf() + } + + /// Attempts to convert to a `ValTreeKind::Leaf` value. + /// + /// Note that this does not normalize the constant. + pub fn try_to_scalar(self) -> Option { + self.try_to_leaf().map(Scalar::Int) + } + + /// Attempts to convert to a `ValTreeKind::Branch` value. + /// + /// Note that this does not normalize the constant. + pub fn try_to_branch(self) -> Option<&'tcx [ty::Const<'tcx>]> { + self.try_to_value()?.try_to_branch() + } + /// Convenience method to extract the value of a usize constant, /// useful to get the length of an array type. /// diff --git a/compiler/rustc_middle/src/ty/consts/valtree.rs b/compiler/rustc_middle/src/ty/consts/valtree.rs index a14e47d70821..8afee2dfe3bc 100644 --- a/compiler/rustc_middle/src/ty/consts/valtree.rs +++ b/compiler/rustc_middle/src/ty/consts/valtree.rs @@ -3,89 +3,38 @@ use std::ops::Deref; use rustc_data_structures::intern::Interned; use rustc_hir::def::Namespace; -use rustc_macros::{HashStable, Lift, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable}; +use rustc_macros::{ + HashStable, Lift, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable, extension, +}; use super::ScalarInt; use crate::mir::interpret::{ErrorHandled, Scalar}; use crate::ty::print::{FmtPrinter, PrettyPrinter}; -use crate::ty::{self, Ty, TyCtxt}; +use crate::ty::{self, Ty, TyCtxt, ValTreeKind}; -/// This datastructure is used to represent the value of constants used in the type system. -/// -/// We explicitly choose a different datastructure from the way values are processed within -/// CTFE, as in the type system equal values (according to their `PartialEq`) must also have -/// equal representation (`==` on the rustc data structure, e.g. `ValTree`) and vice versa. -/// Since CTFE uses `AllocId` to represent pointers, it often happens that two different -/// `AllocId`s point to equal values. So we may end up with different representations for -/// two constants whose value is `&42`. Furthermore any kind of struct that has padding will -/// have arbitrary values within that padding, even if the values of the struct are the same. -/// -/// `ValTree` does not have this problem with representation, as it only contains integers or -/// lists of (nested) `ValTree`. -#[derive(Clone, Debug, Hash, Eq, PartialEq)] -#[derive(HashStable, TyEncodable, TyDecodable)] -pub enum ValTreeKind<'tcx> { - /// integers, `bool`, `char` are represented as scalars. - /// See the `ScalarInt` documentation for how `ScalarInt` guarantees that equal values - /// of these types have the same representation. - Leaf(ScalarInt), - - //SliceOrStr(ValSlice<'tcx>), - // don't use SliceOrStr for now - /// The fields of any kind of aggregate. Structs, tuples and arrays are represented by - /// listing their fields' values in order. - /// - /// Enums are represented by storing their variant index as a u32 field, followed by all - /// the fields of the variant. - /// - /// ZST types are represented as an empty slice. - Branch(Box<[ValTree<'tcx>]>), -} - -impl<'tcx> ValTreeKind<'tcx> { - #[inline] - pub fn unwrap_leaf(&self) -> ScalarInt { - match self { - Self::Leaf(s) => *s, - _ => bug!("expected leaf, got {:?}", self), - } - } - - #[inline] - pub fn unwrap_branch(&self) -> &[ValTree<'tcx>] { - match self { - Self::Branch(branch) => &**branch, - _ => bug!("expected branch, got {:?}", self), - } - } - - pub fn try_to_scalar(&self) -> Option { - self.try_to_scalar_int().map(Scalar::Int) - } - - pub fn try_to_scalar_int(&self) -> Option { - match self { - Self::Leaf(s) => Some(*s), - Self::Branch(_) => None, - } - } - - pub fn try_to_branch(&self) -> Option<&[ValTree<'tcx>]> { - match self { - Self::Branch(branch) => Some(&**branch), - Self::Leaf(_) => None, - } +#[extension(pub trait ValTreeKindExt<'tcx>)] +impl<'tcx> ty::ValTreeKind> { + fn try_to_scalar(&self) -> Option { + self.try_to_leaf().map(Scalar::Int) } } /// An interned valtree. Use this rather than `ValTreeKind`, whenever possible. /// -/// See the docs of [`ValTreeKind`] or the [dev guide] for an explanation of this type. +/// See the docs of [`ty::ValTreeKind`] or the [dev guide] for an explanation of this type. /// /// [dev guide]: https://rustc-dev-guide.rust-lang.org/mir/index.html#valtrees #[derive(Copy, Clone, Hash, Eq, PartialEq)] #[derive(HashStable)] -pub struct ValTree<'tcx>(pub(crate) Interned<'tcx, ValTreeKind<'tcx>>); +// FIXME(mgca): Try not interning here. We already intern `ty::Const` which `ValTreeKind` +// recurses through +pub struct ValTree<'tcx>(pub(crate) Interned<'tcx, ty::ValTreeKind>>); + +impl<'tcx> rustc_type_ir::inherent::ValTree> for ValTree<'tcx> { + fn kind(&self) -> &ty::ValTreeKind> { + &self + } +} impl<'tcx> ValTree<'tcx> { /// Returns the zero-sized valtree: `Branch([])`. @@ -94,28 +43,33 @@ impl<'tcx> ValTree<'tcx> { } pub fn is_zst(self) -> bool { - matches!(*self, ValTreeKind::Branch(box [])) + matches!(*self, ty::ValTreeKind::Branch(box [])) } pub fn from_raw_bytes(tcx: TyCtxt<'tcx>, bytes: &[u8]) -> Self { - let branches = bytes.iter().map(|&b| Self::from_scalar_int(tcx, b.into())); + let branches = bytes.iter().map(|&b| { + ty::Const::new_value(tcx, Self::from_scalar_int(tcx, b.into()), tcx.types.u8) + }); Self::from_branches(tcx, branches) } - pub fn from_branches(tcx: TyCtxt<'tcx>, branches: impl IntoIterator) -> Self { - tcx.intern_valtree(ValTreeKind::Branch(branches.into_iter().collect())) + pub fn from_branches( + tcx: TyCtxt<'tcx>, + branches: impl IntoIterator>, + ) -> Self { + tcx.intern_valtree(ty::ValTreeKind::Branch(branches.into_iter().collect())) } pub fn from_scalar_int(tcx: TyCtxt<'tcx>, i: ScalarInt) -> Self { - tcx.intern_valtree(ValTreeKind::Leaf(i)) + tcx.intern_valtree(ty::ValTreeKind::Leaf(i)) } } impl<'tcx> Deref for ValTree<'tcx> { - type Target = &'tcx ValTreeKind<'tcx>; + type Target = &'tcx ty::ValTreeKind>; #[inline] - fn deref(&self) -> &&'tcx ValTreeKind<'tcx> { + fn deref(&self) -> &&'tcx ty::ValTreeKind> { &self.0.0 } } @@ -154,7 +108,7 @@ impl<'tcx> Value<'tcx> { let (ty::Bool | ty::Char | ty::Uint(_) | ty::Int(_) | ty::Float(_)) = self.ty.kind() else { return None; }; - let scalar = self.valtree.try_to_scalar_int()?; + let scalar = self.try_to_leaf()?; let input = typing_env.with_post_analysis_normalized(tcx).as_query_input(self.ty); let size = tcx.layout_of(input).ok()?.size; Some(scalar.to_bits(size)) @@ -164,14 +118,14 @@ impl<'tcx> Value<'tcx> { if !self.ty.is_bool() { return None; } - self.valtree.try_to_scalar_int()?.try_to_bool().ok() + self.try_to_leaf()?.try_to_bool().ok() } pub fn try_to_target_usize(self, tcx: TyCtxt<'tcx>) -> Option { if !self.ty.is_usize() { return None; } - self.valtree.try_to_scalar_int().map(|s| s.to_target_usize(tcx)) + self.try_to_leaf().map(|s| s.to_target_usize(tcx)) } /// Get the values inside the ValTree as a slice of bytes. This only works for @@ -192,9 +146,48 @@ impl<'tcx> Value<'tcx> { _ => return None, } - Some(tcx.arena.alloc_from_iter( - self.valtree.unwrap_branch().into_iter().map(|v| v.unwrap_leaf().to_u8()), - )) + Some(tcx.arena.alloc_from_iter(self.to_branch().into_iter().map(|ct| ct.to_leaf().to_u8()))) + } + + /// Converts to a `ValTreeKind::Leaf` value, `panic`'ing + /// if this constant is some other kind. + #[inline] + pub fn to_leaf(self) -> ScalarInt { + match &**self.valtree { + ValTreeKind::Leaf(s) => *s, + ValTreeKind::Branch(..) => bug!("expected leaf, got {:?}", self), + } + } + + /// Converts to a `ValTreeKind::Branch` value, `panic`'ing + /// if this constant is some other kind. + #[inline] + pub fn to_branch(self) -> &'tcx [ty::Const<'tcx>] { + match &**self.valtree { + ValTreeKind::Branch(branch) => &**branch, + ValTreeKind::Leaf(..) => bug!("expected branch, got {:?}", self), + } + } + + /// Attempts to convert to a `ValTreeKind::Leaf` value. + pub fn try_to_leaf(self) -> Option { + match &**self.valtree { + ValTreeKind::Leaf(s) => Some(*s), + ValTreeKind::Branch(_) => None, + } + } + + /// Attempts to convert to a `ValTreeKind::Leaf` value. + pub fn try_to_scalar(&self) -> Option { + self.try_to_leaf().map(Scalar::Int) + } + + /// Attempts to convert to a `ValTreeKind::Branch` value. + pub fn try_to_branch(self) -> Option<&'tcx [ty::Const<'tcx>]> { + match &**self.valtree { + ValTreeKind::Branch(branch) => Some(&**branch), + ValTreeKind::Leaf(_) => None, + } } } diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index 9e4692b96418..000e7c16098e 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -165,6 +165,7 @@ impl<'tcx> Interner for TyCtxt<'tcx> { type ValueConst = ty::Value<'tcx>; type ExprConst = ty::Expr<'tcx>; type ValTree = ty::ValTree<'tcx>; + type ScalarInt = ty::ScalarInt; type Region = Region<'tcx>; type EarlyParamRegion = ty::EarlyParamRegion; @@ -954,7 +955,7 @@ pub struct CtxtInterners<'tcx> { fields: InternedSet<'tcx, List>, local_def_ids: InternedSet<'tcx, List>, captures: InternedSet<'tcx, List<&'tcx ty::CapturedPlace<'tcx>>>, - valtree: InternedSet<'tcx, ty::ValTreeKind<'tcx>>, + valtree: InternedSet<'tcx, ty::ValTreeKind>>, patterns: InternedSet<'tcx, List>>, outlives: InternedSet<'tcx, List>>, } @@ -2777,7 +2778,7 @@ macro_rules! direct_interners { // crate only, and have a corresponding `mk_` function. direct_interners! { region: pub(crate) intern_region(RegionKind<'tcx>): Region -> Region<'tcx>, - valtree: pub(crate) intern_valtree(ValTreeKind<'tcx>): ValTree -> ValTree<'tcx>, + valtree: pub(crate) intern_valtree(ValTreeKind>): ValTree -> ValTree<'tcx>, pat: pub mk_pat(PatternKind<'tcx>): Pattern -> Pattern<'tcx>, const_allocation: pub mk_const_alloc(Allocation): ConstAllocation -> ConstAllocation<'tcx>, layout: pub mk_layout(LayoutData): Layout -> Layout<'tcx>, diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 60b6b7024dcd..5cc5ab0d5268 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -77,7 +77,7 @@ pub use self::closure::{ }; pub use self::consts::{ AnonConstKind, AtomicOrdering, Const, ConstInt, ConstKind, ConstToValTreeResult, Expr, - ExprKind, ScalarInt, SimdAlign, UnevaluatedConst, ValTree, ValTreeKind, Value, + ExprKind, ScalarInt, SimdAlign, UnevaluatedConst, ValTree, ValTreeKindExt, Value, }; pub use self::context::{ CtxtInterners, CurrentGcx, Feed, FreeRegionInfo, GlobalCtxt, Lift, TyCtxt, TyCtxtFeed, tls, diff --git a/compiler/rustc_middle/src/ty/pattern.rs b/compiler/rustc_middle/src/ty/pattern.rs index 335e5c064743..6acf0aff800f 100644 --- a/compiler/rustc_middle/src/ty/pattern.rs +++ b/compiler/rustc_middle/src/ty/pattern.rs @@ -72,7 +72,7 @@ impl<'tcx> IrPrint> for TyCtxt<'tcx> { write!(f, "{start}")?; if let Some(c) = end.try_to_value() { - let end = c.valtree.unwrap_leaf(); + let end = c.to_leaf(); let size = end.size(); let max = match c.ty.kind() { ty::Int(_) => { diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs index 5126d902a6d5..1a5a3f3965fa 100644 --- a/compiler/rustc_middle/src/ty/structural_impls.rs +++ b/compiler/rustc_middle/src/ty/structural_impls.rs @@ -256,8 +256,8 @@ TrivialTypeTraversalImpls! { crate::ty::AssocItem, crate::ty::AssocKind, crate::ty::BoundRegion, + crate::ty::ScalarInt, crate::ty::UserTypeAnnotationIndex, - crate::ty::ValTree<'tcx>, crate::ty::abstract_const::NotConstEvaluatable, crate::ty::adjustment::AutoBorrowMutability, crate::ty::adjustment::PointerCoercion, @@ -697,6 +697,37 @@ impl<'tcx> TypeSuperVisitable> for ty::Const<'tcx> { } } +impl<'tcx> TypeVisitable> for ty::ValTree<'tcx> { + fn visit_with>>(&self, visitor: &mut V) -> V::Result { + let inner: &ty::ValTreeKind> = &*self; + inner.visit_with(visitor) + } +} + +impl<'tcx> TypeFoldable> for ty::ValTree<'tcx> { + fn try_fold_with>>( + self, + folder: &mut F, + ) -> Result { + let inner: &ty::ValTreeKind> = &*self; + let new_inner = inner.clone().try_fold_with(folder)?; + + if inner == &new_inner { + Ok(self) + } else { + let valtree = folder.cx().intern_valtree(new_inner); + Ok(valtree) + } + } + + fn fold_with>>(self, folder: &mut F) -> Self { + let inner: &ty::ValTreeKind> = &*self; + let new_inner = inner.clone().fold_with(folder); + + if inner == &new_inner { self } else { folder.cx().intern_valtree(new_inner) } + } +} + impl<'tcx> TypeVisitable> for rustc_span::ErrorGuaranteed { fn visit_with>>(&self, visitor: &mut V) -> V::Result { visitor.visit_error(*self) diff --git a/compiler/rustc_mir_build/src/builder/custom/parse/instruction.rs b/compiler/rustc_mir_build/src/builder/custom/parse/instruction.rs index b221318bf0b1..ddaa61c6cc91 100644 --- a/compiler/rustc_mir_build/src/builder/custom/parse/instruction.rs +++ b/compiler/rustc_mir_build/src/builder/custom/parse/instruction.rs @@ -157,7 +157,7 @@ impl<'a, 'tcx> ParseCtxt<'a, 'tcx> { }); } }; - values.push(value.valtree.unwrap_leaf().to_bits_unchecked()); + values.push(value.to_leaf().to_bits_unchecked()); targets.push(self.parse_block(arm.body)?); } diff --git a/compiler/rustc_mir_build/src/builder/matches/mod.rs b/compiler/rustc_mir_build/src/builder/matches/mod.rs index 8897ca7c7210..465b1db9a164 100644 --- a/compiler/rustc_mir_build/src/builder/matches/mod.rs +++ b/compiler/rustc_mir_build/src/builder/matches/mod.rs @@ -2935,7 +2935,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { bug!("malformed valtree for an enum") }; - let ValTreeKind::Leaf(actual_variant_idx) = ***actual_variant_idx else { + let ValTreeKind::Leaf(actual_variant_idx) = *actual_variant_idx.to_value().valtree + else { bug!("malformed valtree for an enum") }; @@ -2943,7 +2944,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } Constructor::IntRange(int_range) => { let size = pat.ty().primitive_size(self.tcx); - let actual_int = valtree.unwrap_leaf().to_bits(size); + let actual_int = valtree.to_leaf().to_bits(size); let actual_int = if pat.ty().is_signed() { MaybeInfiniteInt::new_finite_int(actual_int, size.bits()) } else { @@ -2951,33 +2952,33 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { }; IntRange::from_singleton(actual_int).is_subrange(int_range) } - Constructor::Bool(pattern_value) => match valtree.unwrap_leaf().try_to_bool() { + Constructor::Bool(pattern_value) => match valtree.to_leaf().try_to_bool() { Ok(actual_value) => *pattern_value == actual_value, Err(()) => bug!("bool value with invalid bits"), }, Constructor::F16Range(l, h, end) => { - let actual = valtree.unwrap_leaf().to_f16(); + let actual = valtree.to_leaf().to_f16(); match end { RangeEnd::Included => (*l..=*h).contains(&actual), RangeEnd::Excluded => (*l..*h).contains(&actual), } } Constructor::F32Range(l, h, end) => { - let actual = valtree.unwrap_leaf().to_f32(); + let actual = valtree.to_leaf().to_f32(); match end { RangeEnd::Included => (*l..=*h).contains(&actual), RangeEnd::Excluded => (*l..*h).contains(&actual), } } Constructor::F64Range(l, h, end) => { - let actual = valtree.unwrap_leaf().to_f64(); + let actual = valtree.to_leaf().to_f64(); match end { RangeEnd::Included => (*l..=*h).contains(&actual), RangeEnd::Excluded => (*l..*h).contains(&actual), } } Constructor::F128Range(l, h, end) => { - let actual = valtree.unwrap_leaf().to_f128(); + let actual = valtree.to_leaf().to_f128(); match end { RangeEnd::Included => (*l..=*h).contains(&actual), RangeEnd::Excluded => (*l..*h).contains(&actual), diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index 402587bff7e8..2401bc6648ed 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -116,7 +116,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let switch_targets = SwitchTargets::new( target_blocks.iter().filter_map(|(&branch, &block)| { if let TestBranch::Constant(value) = branch { - let bits = value.valtree.unwrap_leaf().to_bits_unchecked(); + let bits = value.to_leaf().to_bits_unchecked(); Some((bits, block)) } else { None diff --git a/compiler/rustc_mir_build/src/builder/scope.rs b/compiler/rustc_mir_build/src/builder/scope.rs index a176f3e49a50..981704052536 100644 --- a/compiler/rustc_mir_build/src/builder/scope.rs +++ b/compiler/rustc_mir_build/src/builder/scope.rs @@ -897,7 +897,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { self.tcx, ValTree::from_branches( self.tcx, - [ValTree::from_scalar_int(self.tcx, variant_index.as_u32().into())], + [ty::Const::new_value( + self.tcx, + ValTree::from_scalar_int( + self.tcx, + variant_index.as_u32().into(), + ), + self.tcx.types.u32, + )], ), self.thir[value].ty, ), diff --git a/compiler/rustc_mir_build/src/thir/constant.rs b/compiler/rustc_mir_build/src/thir/constant.rs index 6e071fb344c4..563212a51f31 100644 --- a/compiler/rustc_mir_build/src/thir/constant.rs +++ b/compiler/rustc_mir_build/src/thir/constant.rs @@ -63,7 +63,7 @@ pub(crate) fn lit_to_const<'tcx>( // A CStr is a newtype around a byte slice, so we create the inner slice here. // We need a branch for each "level" of the data structure. let bytes = ty::ValTree::from_raw_bytes(tcx, byte_sym.as_byte_str()); - ty::ValTree::from_branches(tcx, [bytes]) + ty::ValTree::from_branches(tcx, [ty::Const::new_value(tcx, bytes, *inner_ty)]) } (ast::LitKind::Int(n, _), ty::Uint(ui)) if !neg => { let scalar_int = trunc(n.get(), *ui); diff --git a/compiler/rustc_mir_build/src/thir/pattern/const_to_pat.rs b/compiler/rustc_mir_build/src/thir/pattern/const_to_pat.rs index c3c4c455b965..ce4c89a8eb2e 100644 --- a/compiler/rustc_mir_build/src/thir/pattern/const_to_pat.rs +++ b/compiler/rustc_mir_build/src/thir/pattern/const_to_pat.rs @@ -239,14 +239,14 @@ impl<'tcx> ConstToPat<'tcx> { return self.mk_err(tcx.dcx().create_err(err), ty); } ty::Adt(adt_def, args) if adt_def.is_enum() => { - let (&variant_index, fields) = cv.unwrap_branch().split_first().unwrap(); - let variant_index = VariantIdx::from_u32(variant_index.unwrap_leaf().to_u32()); + let (&variant_index, fields) = cv.to_branch().split_first().unwrap(); + let variant_index = VariantIdx::from_u32(variant_index.to_leaf().to_u32()); PatKind::Variant { adt_def: *adt_def, args, variant_index, subpatterns: self.field_pats( - fields.iter().copied().zip( + fields.iter().map(|ct| ct.to_value().valtree).zip( adt_def.variants()[variant_index] .fields .iter() @@ -258,28 +258,32 @@ impl<'tcx> ConstToPat<'tcx> { ty::Adt(def, args) => { assert!(!def.is_union()); // Valtree construction would never succeed for unions. PatKind::Leaf { - subpatterns: self.field_pats(cv.unwrap_branch().iter().copied().zip( - def.non_enum_variant().fields.iter().map(|field| field.ty(tcx, args)), - )), + subpatterns: self.field_pats( + cv.to_branch().iter().map(|ct| ct.to_value().valtree).zip( + def.non_enum_variant().fields.iter().map(|field| field.ty(tcx, args)), + ), + ), } } ty::Tuple(fields) => PatKind::Leaf { - subpatterns: self.field_pats(cv.unwrap_branch().iter().copied().zip(fields.iter())), + subpatterns: self.field_pats( + cv.to_branch().iter().map(|ct| ct.to_value().valtree).zip(fields.iter()), + ), }, ty::Slice(elem_ty) => PatKind::Slice { prefix: cv - .unwrap_branch() + .to_branch() .iter() - .map(|val| *self.valtree_to_pat(*val, *elem_ty)) + .map(|val| *self.valtree_to_pat(val.to_value().valtree, *elem_ty)) .collect(), slice: None, suffix: Box::new([]), }, ty::Array(elem_ty, _) => PatKind::Array { prefix: cv - .unwrap_branch() + .to_branch() .iter() - .map(|val| *self.valtree_to_pat(*val, *elem_ty)) + .map(|val| *self.valtree_to_pat(val.to_value().valtree, *elem_ty)) .collect(), slice: None, suffix: Box::new([]), @@ -312,7 +316,7 @@ impl<'tcx> ConstToPat<'tcx> { } }, ty::Float(flt) => { - let v = cv.unwrap_leaf(); + let v = cv.to_leaf(); let is_nan = match flt { ty::FloatTy::F16 => v.to_f16().is_nan(), ty::FloatTy::F32 => v.to_f32().is_nan(), diff --git a/compiler/rustc_pattern_analysis/src/rustc.rs b/compiler/rustc_pattern_analysis/src/rustc.rs index df86233c2b05..d66c303b1726 100644 --- a/compiler/rustc_pattern_analysis/src/rustc.rs +++ b/compiler/rustc_pattern_analysis/src/rustc.rs @@ -440,7 +440,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> { match bdy { PatRangeBoundary::NegInfinity => MaybeInfiniteInt::NegInfinity, PatRangeBoundary::Finite(value) => { - let bits = value.try_to_scalar_int().unwrap().to_bits_unchecked(); + let bits = value.to_leaf().to_bits_unchecked(); match *ty.kind() { ty::Int(ity) => { let size = Integer::from_int_ty(&self.tcx, ity).size().bits(); @@ -540,7 +540,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> { } ty::Char | ty::Int(_) | ty::Uint(_) => { ctor = { - let bits = value.valtree.unwrap_leaf().to_bits_unchecked(); + let bits = value.to_leaf().to_bits_unchecked(); let x = match *ty.kind() { ty::Int(ity) => { let size = Integer::from_int_ty(&cx.tcx, ity).size().bits(); @@ -555,7 +555,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> { } ty::Float(ty::FloatTy::F16) => { use rustc_apfloat::Float; - let bits = value.valtree.unwrap_leaf().to_u16(); + let bits = value.to_leaf().to_u16(); let value = rustc_apfloat::ieee::Half::from_bits(bits.into()); ctor = F16Range(value, value, RangeEnd::Included); fields = vec![]; @@ -563,7 +563,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> { } ty::Float(ty::FloatTy::F32) => { use rustc_apfloat::Float; - let bits = value.valtree.unwrap_leaf().to_u32(); + let bits = value.to_leaf().to_u32(); let value = rustc_apfloat::ieee::Single::from_bits(bits.into()); ctor = F32Range(value, value, RangeEnd::Included); fields = vec![]; @@ -571,7 +571,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> { } ty::Float(ty::FloatTy::F64) => { use rustc_apfloat::Float; - let bits = value.valtree.unwrap_leaf().to_u64(); + let bits = value.to_leaf().to_u64(); let value = rustc_apfloat::ieee::Double::from_bits(bits.into()); ctor = F64Range(value, value, RangeEnd::Included); fields = vec![]; @@ -579,7 +579,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> { } ty::Float(ty::FloatTy::F128) => { use rustc_apfloat::Float; - let bits = value.valtree.unwrap_leaf().to_u128(); + let bits = value.to_leaf().to_u128(); let value = rustc_apfloat::ieee::Quad::from_bits(bits); ctor = F128Range(value, value, RangeEnd::Included); fields = vec![]; @@ -623,12 +623,8 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> { } ty::Float(fty) => { use rustc_apfloat::Float; - let lo = lo - .as_finite() - .map(|c| c.try_to_scalar_int().unwrap().to_bits_unchecked()); - let hi = hi - .as_finite() - .map(|c| c.try_to_scalar_int().unwrap().to_bits_unchecked()); + let lo = lo.as_finite().map(|c| c.to_leaf().to_bits_unchecked()); + let hi = hi.as_finite().map(|c| c.to_leaf().to_bits_unchecked()); match fty { ty::FloatTy::F16 => { use rustc_apfloat::ieee::Half; diff --git a/compiler/rustc_symbol_mangling/src/legacy.rs b/compiler/rustc_symbol_mangling/src/legacy.rs index ee2621af8428..ea16231880e2 100644 --- a/compiler/rustc_symbol_mangling/src/legacy.rs +++ b/compiler/rustc_symbol_mangling/src/legacy.rs @@ -293,7 +293,7 @@ impl<'tcx> Printer<'tcx> for LegacySymbolMangler<'tcx> { ty::ConstKind::Value(cv) if cv.ty.is_integral() => { // The `pretty_print_const` formatting depends on -Zverbose-internals // flag, so we cannot reuse it here. - let scalar = cv.valtree.unwrap_leaf(); + let scalar = cv.to_leaf(); let signed = matches!(cv.ty.kind(), ty::Int(_)); write!( self, diff --git a/compiler/rustc_transmute/src/lib.rs b/compiler/rustc_transmute/src/lib.rs index 36281ff16bce..58cb2eb6556e 100644 --- a/compiler/rustc_transmute/src/lib.rs +++ b/compiler/rustc_transmute/src/lib.rs @@ -129,10 +129,7 @@ mod rustc { use rustc_middle::ty::ScalarInt; use rustc_span::sym; - let Some(cv) = ct.try_to_value() else { - return None; - }; - + let cv = ct.try_to_value()?; let adt_def = cv.ty.ty_adt_def()?; if !tcx.is_lang_item(adt_def.did(), LangItem::TransmuteOpts) { @@ -149,7 +146,7 @@ mod rustc { } let variant = adt_def.non_enum_variant(); - let fields = cv.valtree.unwrap_branch(); + let fields = cv.to_branch(); let get_field = |name| { let (field_idx, _) = variant @@ -158,7 +155,7 @@ mod rustc { .enumerate() .find(|(_, field_def)| name == field_def.name) .unwrap_or_else(|| panic!("There were no fields named `{name}`.")); - fields[field_idx].unwrap_leaf() == ScalarInt::TRUE + fields[field_idx].to_leaf() == ScalarInt::TRUE }; Some(Self { diff --git a/compiler/rustc_ty_utils/src/consts.rs b/compiler/rustc_ty_utils/src/consts.rs index 4b19d0f16d78..d8d5b7fc75cc 100644 --- a/compiler/rustc_ty_utils/src/consts.rs +++ b/compiler/rustc_ty_utils/src/consts.rs @@ -25,15 +25,14 @@ fn destructure_const<'tcx>( let ty::ConstKind::Value(cv) = const_.kind() else { bug!("cannot destructure constant {:?}", const_) }; - - let branches = cv.valtree.unwrap_branch(); + let branches = cv.to_branch(); let (fields, variant) = match cv.ty.kind() { ty::Array(inner_ty, _) | ty::Slice(inner_ty) => { // construct the consts for the elements of the array/slice let field_consts = branches .iter() - .map(|b| ty::Const::new_value(tcx, *b, *inner_ty)) + .map(|b| ty::Const::new_value(tcx, b.to_value().valtree, *inner_ty)) .collect::>(); debug!(?field_consts); @@ -43,7 +42,7 @@ fn destructure_const<'tcx>( ty::Adt(def, args) => { let (variant_idx, branches) = if def.is_enum() { let (head, rest) = branches.split_first().unwrap(); - (VariantIdx::from_u32(head.unwrap_leaf().to_u32()), rest) + (VariantIdx::from_u32(head.to_leaf().to_u32()), rest) } else { (FIRST_VARIANT, branches) }; @@ -52,7 +51,8 @@ fn destructure_const<'tcx>( for (field, field_valtree) in iter::zip(fields, branches) { let field_ty = field.ty(tcx, args); - let field_const = ty::Const::new_value(tcx, *field_valtree, field_ty); + let field_const = + ty::Const::new_value(tcx, field_valtree.to_value().valtree, field_ty); field_consts.push(field_const); } debug!(?field_consts); @@ -61,7 +61,9 @@ fn destructure_const<'tcx>( } ty::Tuple(elem_tys) => { let fields = iter::zip(*elem_tys, branches) - .map(|(elem_ty, elem_valtree)| ty::Const::new_value(tcx, *elem_valtree, elem_ty)) + .map(|(elem_ty, elem_valtree)| { + ty::Const::new_value(tcx, elem_valtree.to_value().valtree, elem_ty) + }) .collect::>(); (fields, None) diff --git a/compiler/rustc_type_ir/src/const_kind.rs b/compiler/rustc_type_ir/src/const_kind.rs index f315e8b3e11c..a5c40d4eb199 100644 --- a/compiler/rustc_type_ir/src/const_kind.rs +++ b/compiler/rustc_type_ir/src/const_kind.rs @@ -127,3 +127,76 @@ impl HashStable for InferConst { } } } + +/// This datastructure is used to represent the value of constants used in the type system. +/// +/// We explicitly choose a different datastructure from the way values are processed within +/// CTFE, as in the type system equal values (according to their `PartialEq`) must also have +/// equal representation (`==` on the rustc data structure, e.g. `ValTree`) and vice versa. +/// Since CTFE uses `AllocId` to represent pointers, it often happens that two different +/// `AllocId`s point to equal values. So we may end up with different representations for +/// two constants whose value is `&42`. Furthermore any kind of struct that has padding will +/// have arbitrary values within that padding, even if the values of the struct are the same. +/// +/// `ValTree` does not have this problem with representation, as it only contains integers or +/// lists of (nested) `ty::Const`s (which may indirectly contain more `ValTree`s). +#[derive_where(Clone, Debug, Hash, Eq, PartialEq; I: Interner)] +#[derive(TypeVisitable_Generic, TypeFoldable_Generic)] +#[cfg_attr( + feature = "nightly", + derive(Decodable_NoContext, Encodable_NoContext, HashStable_NoContext) +)] +pub enum ValTreeKind { + /// integers, `bool`, `char` are represented as scalars. + /// See the `ScalarInt` documentation for how `ScalarInt` guarantees that equal values + /// of these types have the same representation. + Leaf(I::ScalarInt), + + /// The fields of any kind of aggregate. Structs, tuples and arrays are represented by + /// listing their fields' values in order. + /// + /// Enums are represented by storing their variant index as a u32 field, followed by all + /// the fields of the variant. + /// + /// ZST types are represented as an empty slice. + // FIXME(mgca): Use a `List` here instead of a boxed slice + Branch(Box<[I::Const]>), +} + +impl ValTreeKind { + /// Converts to a `ValTreeKind::Leaf` value, `panic`'ing + /// if this valtree is some other kind. + #[inline] + pub fn to_leaf(&self) -> I::ScalarInt { + match self { + ValTreeKind::Leaf(s) => *s, + ValTreeKind::Branch(..) => panic!("expected leaf, got {:?}", self), + } + } + + /// Converts to a `ValTreeKind::Branch` value, `panic`'ing + /// if this valtree is some other kind. + #[inline] + pub fn to_branch(&self) -> &[I::Const] { + match self { + ValTreeKind::Branch(branch) => &**branch, + ValTreeKind::Leaf(..) => panic!("expected branch, got {:?}", self), + } + } + + /// Attempts to convert to a `ValTreeKind::Leaf` value. + pub fn try_to_leaf(&self) -> Option { + match self { + ValTreeKind::Leaf(s) => Some(*s), + ValTreeKind::Branch(_) => None, + } + } + + /// Attempts to convert to a `ValTreeKind::Branch` value. + pub fn try_to_branch(&self) -> Option<&[I::Const]> { + match self { + ValTreeKind::Branch(branch) => Some(&**branch), + ValTreeKind::Leaf(_) => None, + } + } +} diff --git a/compiler/rustc_type_ir/src/flags.rs b/compiler/rustc_type_ir/src/flags.rs index 34b030ee768b..2c1fc7decc3e 100644 --- a/compiler/rustc_type_ir/src/flags.rs +++ b/compiler/rustc_type_ir/src/flags.rs @@ -477,7 +477,17 @@ impl FlagComputation { ty::ConstKind::Placeholder(_) => { self.add_flags(TypeFlags::HAS_CT_PLACEHOLDER); } - ty::ConstKind::Value(cv) => self.add_ty(cv.ty()), + ty::ConstKind::Value(cv) => { + self.add_ty(cv.ty()); + match cv.valtree().kind() { + ty::ValTreeKind::Leaf(_) => (), + ty::ValTreeKind::Branch(cts) => { + for ct in cts { + self.add_const(*ct); + } + } + } + } ty::ConstKind::Expr(e) => self.add_args(e.args().as_slice()), ty::ConstKind::Error(_) => self.add_flags(TypeFlags::HAS_ERROR), } diff --git a/compiler/rustc_type_ir/src/inherent.rs b/compiler/rustc_type_ir/src/inherent.rs index 75ba0231d98c..16f837141e97 100644 --- a/compiler/rustc_type_ir/src/inherent.rs +++ b/compiler/rustc_type_ir/src/inherent.rs @@ -292,6 +292,12 @@ pub trait ValueConst>: Copy + Debug + Hash + Eq { fn valtree(self) -> I::ValTree; } +// FIXME(mgca): This trait can be removed once we're not using a `Box` in `Branch` +pub trait ValTree>: Copy + Debug + Hash + Eq { + // This isnt' `IntoKind` because then we can't return a reference + fn kind(&self) -> &ty::ValTreeKind; +} + pub trait ExprConst>: Copy + Debug + Hash + Eq + Relate { fn args(self) -> I::GenericArgs; } diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs index 3884f29a4fc8..03cf738c0598 100644 --- a/compiler/rustc_type_ir/src/interner.rs +++ b/compiler/rustc_type_ir/src/interner.rs @@ -153,7 +153,8 @@ pub trait Interner: type PlaceholderConst: PlaceholderConst; type ValueConst: ValueConst; type ExprConst: ExprConst; - type ValTree: Copy + Debug + Hash + Eq; + type ValTree: ValTree; + type ScalarInt: Copy + Debug + Hash + Eq; // Kinds of regions type Region: Region; diff --git a/compiler/rustc_type_ir/src/relate.rs b/compiler/rustc_type_ir/src/relate.rs index 4f843503d1af..4954ebc51cfc 100644 --- a/compiler/rustc_type_ir/src/relate.rs +++ b/compiler/rustc_type_ir/src/relate.rs @@ -582,13 +582,27 @@ pub fn structurally_relate_consts>( } (ty::ConstKind::Placeholder(p1), ty::ConstKind::Placeholder(p2)) => p1 == p2, (ty::ConstKind::Value(a_val), ty::ConstKind::Value(b_val)) => { - a_val.valtree() == b_val.valtree() + match (a_val.valtree().kind(), b_val.valtree().kind()) { + (ty::ValTreeKind::Leaf(scalar_a), ty::ValTreeKind::Leaf(scalar_b)) => { + scalar_a == scalar_b + } + (ty::ValTreeKind::Branch(branches_a), ty::ValTreeKind::Branch(branches_b)) + if branches_a.len() == branches_b.len() => + { + branches_a + .into_iter() + .zip(branches_b) + .all(|(a, b)| relation.relate(*a, *b).is_ok()) + } + _ => false, + } } // While this is slightly incorrect, it shouldn't matter for `min_const_generics` // and is the better alternative to waiting until `generic_const_exprs` can // be stabilized. (ty::ConstKind::Unevaluated(au), ty::ConstKind::Unevaluated(bu)) if au.def == bu.def => { + // FIXME(mgca): remove this if cfg!(debug_assertions) { let a_ty = cx.type_of(au.def.into()).instantiate(cx, au.args); let b_ty = cx.type_of(bu.def.into()).instantiate(cx, bu.args); diff --git a/tests/ui/transmutability/structs/repr/transmute_infinitely_recursive_type.stderr b/tests/ui/transmutability/structs/repr/transmute_infinitely_recursive_type.stderr index 1a0563b469c1..a96876a2c25a 100644 --- a/tests/ui/transmutability/structs/repr/transmute_infinitely_recursive_type.stderr +++ b/tests/ui/transmutability/structs/repr/transmute_infinitely_recursive_type.stderr @@ -12,7 +12,7 @@ LL | struct ExplicitlyPadded(Box); error[E0391]: cycle detected when computing layout of `should_pad_explicitly_packed_field::ExplicitlyPadded` | = note: ...which immediately requires computing layout of `should_pad_explicitly_packed_field::ExplicitlyPadded` again - = note: cycle used when evaluating trait selection obligation `(): core::mem::transmutability::TransmuteFrom` + = note: cycle used when evaluating trait selection obligation `(): core::mem::transmutability::TransmuteFrom` = note: see https://rustc-dev-guide.rust-lang.org/overview.html#queries and https://rustc-dev-guide.rust-lang.org/query.html for more information error: aborting due to 2 previous errors