From d5bcfb334b616e01c3589ba7b1697c7ebc47a024 Mon Sep 17 00:00:00 2001 From: Scott McMurray Date: Fri, 11 Jul 2025 05:07:37 -0700 Subject: [PATCH] Simplify codegen for niche-encoded variant tests --- compiler/rustc_abi/src/lib.rs | 35 ++++++++- compiler/rustc_codegen_ssa/src/mir/operand.rs | 77 ++++++++++++------- tests/codegen/enum/enum-discriminant-eq.rs | 42 ++++------ tests/codegen/enum/enum-match.rs | 19 +++-- 4 files changed, 110 insertions(+), 63 deletions(-) diff --git a/compiler/rustc_abi/src/lib.rs b/compiler/rustc_abi/src/lib.rs index de4b5a46c81a..5bd73502d980 100644 --- a/compiler/rustc_abi/src/lib.rs +++ b/compiler/rustc_abi/src/lib.rs @@ -43,7 +43,7 @@ use std::fmt; #[cfg(feature = "nightly")] use std::iter::Step; use std::num::{NonZeroUsize, ParseIntError}; -use std::ops::{Add, AddAssign, Deref, Mul, RangeInclusive, Sub}; +use std::ops::{Add, AddAssign, Deref, Mul, RangeFull, RangeInclusive, Sub}; use std::str::FromStr; use bitflags::bitflags; @@ -1391,12 +1391,45 @@ impl WrappingRange { } /// Returns `true` if `size` completely fills the range. + /// + /// Note that this is *not* the same as `self == WrappingRange::full(size)`. + /// Niche calculations can produce full ranges which are not the canonical one; + /// for example `Option>` gets `valid_range: (..=0) | (1..)`. #[inline] fn is_full_for(&self, size: Size) -> bool { let max_value = size.unsigned_int_max(); debug_assert!(self.start <= max_value && self.end <= max_value); self.start == (self.end.wrapping_add(1) & max_value) } + + /// Checks whether this range is considered non-wrapping when the values are + /// interpreted as *unsigned* numbers of width `size`. + /// + /// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is, + /// and `Err(..)` if the range is full so it depends how you think about it. + #[inline] + pub fn no_unsigned_wraparound(&self, size: Size) -> Result { + if self.is_full_for(size) { Err(..) } else { Ok(self.start <= self.end) } + } + + /// Checks whether this range is considered non-wrapping when the values are + /// interpreted as *signed* numbers of width `size`. + /// + /// This is heavily dependent on the `size`, as `100..=200` does wrap when + /// interpreted as `i8`, but doesn't when interpreted as `i16`. + /// + /// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is, + /// and `Err(..)` if the range is full so it depends how you think about it. + #[inline] + pub fn no_signed_wraparound(&self, size: Size) -> Result { + if self.is_full_for(size) { + Err(..) + } else { + let start: i128 = size.sign_extend(self.start); + let end: i128 = size.sign_extend(self.end); + Ok(start <= end) + } + } } impl fmt::Debug for WrappingRange { diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs index b0d191528a89..9910cb5469fc 100644 --- a/compiler/rustc_codegen_ssa/src/mir/operand.rs +++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs @@ -486,6 +486,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { // value and the variant index match, since that's all `Niche` can encode. let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32(); + let niche_start_const = bx.cx().const_uint_big(tag_llty, niche_start); // We have a subrange `niche_start..=niche_end` inside `range`. // If the value of the tag is inside this subrange, it's a @@ -511,35 +512,44 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { // } else { // untagged_variant // } - let niche_start = bx.cx().const_uint_big(tag_llty, niche_start); - let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start); + let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start_const); let tagged_discr = bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64); (is_niche, tagged_discr, 0) } else { // The special cases don't apply, so we'll have to go with // the general algorithm. - let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start)); - let cast_tag = bx.intcast(relative_discr, cast_to, false); - let is_niche = bx.icmp( - IntPredicate::IntULE, - relative_discr, - bx.cx().const_uint(tag_llty, relative_max as u64), - ); - // Thanks to parameter attributes and load metadata, LLVM already knows - // the general valid range of the tag. It's possible, though, for there - // to be an impossible value *in the middle*, which those ranges don't - // communicate, so it's worth an `assume` to let the optimizer know. - if niche_variants.contains(&untagged_variant) - && bx.cx().sess().opts.optimize != OptLevel::No - { - let impossible = - u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32()); - let impossible = bx.cx().const_uint(tag_llty, impossible); - let ne = bx.icmp(IntPredicate::IntNE, relative_discr, impossible); - bx.assume(ne); - } + let tag_range = tag_scalar.valid_range(&dl); + let tag_size = tag_scalar.size(&dl); + let niche_end = u128::from(relative_max).wrapping_add(niche_start); + let niche_end = tag_size.truncate(niche_end); + + let relative_discr = bx.sub(tag, niche_start_const); + let cast_tag = bx.intcast(relative_discr, cast_to, false); + let is_niche = if tag_range.no_unsigned_wraparound(tag_size) == Ok(true) { + if niche_start == tag_range.start { + let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end); + bx.icmp(IntPredicate::IntULE, tag, niche_end_const) + } else { + assert_eq!(niche_end, tag_range.end); + bx.icmp(IntPredicate::IntUGE, tag, niche_start_const) + } + } else if tag_range.no_signed_wraparound(tag_size) == Ok(true) { + if niche_start == tag_range.start { + let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end); + bx.icmp(IntPredicate::IntSLE, tag, niche_end_const) + } else { + assert_eq!(niche_end, tag_range.end); + bx.icmp(IntPredicate::IntSGE, tag, niche_start_const) + } + } else { + bx.icmp( + IntPredicate::IntULE, + relative_discr, + bx.cx().const_uint(tag_llty, relative_max as u64), + ) + }; (is_niche, cast_tag, niche_variants.start().as_u32() as u128) }; @@ -550,11 +560,24 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta)) }; - let discr = bx.select( - is_niche, - tagged_discr, - bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64), - ); + let untagged_variant_const = + bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32())); + + // Thanks to parameter attributes and load metadata, LLVM already knows + // the general valid range of the tag. It's possible, though, for there + // to be an impossible value *in the middle*, which those ranges don't + // communicate, so it's worth an `assume` to let the optimizer know. + // Most importantly, this means when optimizing a variant test like + // `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that + // to `!is_niche` because the `complex` part can't possibly match. + if niche_variants.contains(&untagged_variant) + && bx.cx().sess().opts.optimize != OptLevel::No + { + let ne = bx.icmp(IntPredicate::IntNE, tagged_discr, untagged_variant_const); + bx.assume(ne); + } + + let discr = bx.select(is_niche, tagged_discr, untagged_variant_const); // In principle we could insert assumes on the possible range of `discr`, but // currently in LLVM this isn't worth it because the original `tag` will diff --git a/tests/codegen/enum/enum-discriminant-eq.rs b/tests/codegen/enum/enum-discriminant-eq.rs index fd6f3499cb1f..0494c5f551b3 100644 --- a/tests/codegen/enum/enum-discriminant-eq.rs +++ b/tests/codegen/enum/enum-discriminant-eq.rs @@ -89,13 +89,13 @@ pub fn mid_bool_eq_discr(a: Mid, b: Mid) -> bool { // CHECK-LABEL: @mid_bool_eq_discr( // CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2 - // CHECK: %[[A_IS_NICHE:.+]] = icmp ult i8 %[[A_REL_DISCR]], 3 + // CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i8 %a, 1 // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1 // CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]]) // CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1 // CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2 - // CHECK: %[[B_IS_NICHE:.+]] = icmp ult i8 %[[B_REL_DISCR]], 3 + // CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i8 %b, 1 // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1 // CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]]) // CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1 @@ -109,13 +109,13 @@ pub fn mid_ord_eq_discr(a: Mid, b: Mid) -> bool { // CHECK-LABEL: @mid_ord_eq_discr( // CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2 - // CHECK: %[[A_IS_NICHE:.+]] = icmp ult i8 %[[A_REL_DISCR]], 3 + // CHECK: %[[A_IS_NICHE:.+]] = icmp sgt i8 %a, 1 // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1 // CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]]) // CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1 // CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2 - // CHECK: %[[B_IS_NICHE:.+]] = icmp ult i8 %[[B_REL_DISCR]], 3 + // CHECK: %[[B_IS_NICHE:.+]] = icmp sgt i8 %b, 1 // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1 // CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]]) // CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1 @@ -138,13 +138,13 @@ pub fn mid_ac_eq_discr(a: Mid, b: Mid) -> bool { // CHECK-LABEL: @mid_ac_eq_discr( // CHECK: %[[A_REL_DISCR:.+]] = xor i8 %a, -128 - // CHECK: %[[A_IS_NICHE:.+]] = icmp ult i8 %[[A_REL_DISCR]], 3 + // CHECK: %[[A_IS_NICHE:.+]] = icmp slt i8 %a, 0 // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, -127 // CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]]) // CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1 // CHECK: %[[B_REL_DISCR:.+]] = xor i8 %b, -128 - // CHECK: %[[B_IS_NICHE:.+]] = icmp ult i8 %[[B_REL_DISCR]], 3 + // CHECK: %[[B_IS_NICHE:.+]] = icmp slt i8 %b, 0 // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, -127 // CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]]) // CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1 @@ -160,17 +160,17 @@ pub fn mid_ac_eq_discr(a: Mid, b: Mid) -> bool { pub fn mid_giant_eq_discr(a: Mid, b: Mid) -> bool { // CHECK-LABEL: @mid_giant_eq_discr( - // CHECK: %[[A_REL_DISCR_WIDE:.+]] = add nsw i128 %a, -5 - // CHECK: %[[A_REL_DISCR:.+]] = trunc nsw i128 %[[A_REL_DISCR_WIDE]] to i64 - // CHECK: %[[A_IS_NICHE:.+]] = icmp ult i128 %[[A_REL_DISCR_WIDE]], 3 - // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i128 %[[A_REL_DISCR_WIDE]], 1 + // CHECK: %[[A_TRUNC:.+]] = trunc nuw nsw i128 %a to i64 + // CHECK: %[[A_REL_DISCR:.+]] = add nsw i64 %[[A_TRUNC]], -5 + // CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i128 %a, 4 + // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i64 %[[A_REL_DISCR]], 1 // CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]]) // CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_REL_DISCR]], i64 1 - // CHECK: %[[B_REL_DISCR_WIDE:.+]] = add nsw i128 %b, -5 - // CHECK: %[[B_REL_DISCR:.+]] = trunc nsw i128 %[[B_REL_DISCR_WIDE]] to i64 - // CHECK: %[[B_IS_NICHE:.+]] = icmp ult i128 %[[B_REL_DISCR_WIDE]], 3 - // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i128 %[[B_REL_DISCR_WIDE]], 1 + // CHECK: %[[B_TRUNC:.+]] = trunc nuw nsw i128 %b to i64 + // CHECK: %[[B_REL_DISCR:.+]] = add nsw i64 %[[B_TRUNC]], -5 + // CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i128 %b, 4 + // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i64 %[[B_REL_DISCR]], 1 // CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]]) // CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_REL_DISCR]], i64 1 @@ -181,14 +181,11 @@ pub fn mid_giant_eq_discr(a: Mid, b: Mid) -> bool { // In niche-encoded enums, testing for the untagged variant should optimize to a // straight-forward comparison looking for the natural range of the payload value. -// FIXME: A bunch don't, though. #[unsafe(no_mangle)] pub fn mid_bool_is_thing(a: Mid) -> bool { // CHECK-LABEL: @mid_bool_is_thing( - - // CHECK: %[[REL_DISCR:.+]] = add nsw i8 %a, -2 - // CHECK: %[[R:.+]] = icmp ugt i8 %[[REL_DISCR]], 2 + // CHECK: %[[R:.+]] = icmp samesign ult i8 %a, 2 // CHECK: ret i1 %[[R]] discriminant_value(&a) == 1 } @@ -196,8 +193,7 @@ pub fn mid_bool_is_thing(a: Mid) -> bool { #[unsafe(no_mangle)] pub fn mid_ord_is_thing(a: Mid) -> bool { // CHECK-LABEL: @mid_ord_is_thing( - // CHECK: %[[REL_DISCR:.+]] = add nsw i8 %a, -2 - // CHECK: %[[R:.+]] = icmp ugt i8 %[[REL_DISCR]], 2 + // CHECK: %[[R:.+]] = icmp slt i8 %a, 2 // CHECK: ret i1 %[[R]] discriminant_value(&a) == 1 } @@ -221,11 +217,7 @@ pub fn mid_ac_is_thing(a: Mid) -> bool { #[unsafe(no_mangle)] pub fn mid_giant_is_thing(a: Mid) -> bool { // CHECK-LABEL: @mid_giant_is_thing( - // CHECK: %[[REL_DISCR_WIDE:.+]] = add nsw i128 %a, -5 - // CHECK: %[[REL_DISCR:.+]] = trunc nsw i128 %[[REL_DISCR_WIDE]] to i64 - // CHECK: %[[NOT_NICHE:.+]] = icmp ugt i128 %[[REL_DISCR_WIDE]], 2 - // CHECK: %[[IS_MID_VARIANT:.+]] = icmp eq i64 %[[REL_DISCR]], 1 - // CHECK: %[[R:.+]] = select i1 %[[NOT_NICHE]], i1 true, i1 %[[IS_MID_VARIANT]] + // CHECK: %[[R:.+]] = icmp samesign ult i128 %a, 5 // CHECK: ret i1 %[[R]] discriminant_value(&a) == 1 } diff --git a/tests/codegen/enum/enum-match.rs b/tests/codegen/enum/enum-match.rs index 9725e64def6a..57db44ec74e8 100644 --- a/tests/codegen/enum/enum-match.rs +++ b/tests/codegen/enum/enum-match.rs @@ -41,7 +41,7 @@ pub enum Enum1 { // CHECK-NEXT: start: // CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2 // CHECK-NEXT: %[[REL_VAR_WIDE:.+]] = zext i8 %[[REL_VAR]] to i64 -// CHECK-NEXT: %[[IS_NICHE:.+]] = icmp ult i8 %[[REL_VAR]], 2 +// CHECK-NEXT: %[[IS_NICHE:.+]] = icmp{{( samesign)?}} ugt i8 %0, 1 // CHECK-NEXT: %[[NICHE_DISCR:.+]] = add nuw nsw i64 %[[REL_VAR_WIDE]], 1 // CHECK-NEXT: %[[DISCR:.+]] = select i1 %[[IS_NICHE]], i64 %[[NICHE_DISCR]], i64 0 // CHECK-NEXT: switch i64 %[[DISCR]] @@ -148,10 +148,10 @@ pub enum MiddleNiche { // CHECK-LABEL: define{{( dso_local)?}} noundef{{( range\(i8 -?[0-9]+, -?[0-9]+\))?}} i8 @match4(i8{{.+}}%0) // CHECK-NEXT: start: // CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2 -// CHECK-NEXT: %[[IS_NICHE:.+]] = icmp ult i8 %[[REL_VAR]], 5 // CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %[[REL_VAR]], 2 // CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]]) -// CHECK-NEXT: %[[DISCR:.+]] = select i1 %[[IS_NICHE]], i8 %[[REL_VAR]], i8 2 +// CHECK-NEXT: %[[NOT_NICHE:.+]] = icmp{{( samesign)?}} ult i8 %0, 2 +// CHECK-NEXT: %[[DISCR:.+]] = select i1 %[[NOT_NICHE]], i8 2, i8 %[[REL_VAR]] // CHECK-NEXT: switch i8 %[[DISCR]] #[no_mangle] pub fn match4(e: MiddleNiche) -> u8 { @@ -167,11 +167,10 @@ pub fn match4(e: MiddleNiche) -> u8 { // CHECK-LABEL: define{{.+}}i1 @match4_is_c(i8{{.+}}%e) // CHECK-NEXT: start -// CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %e, -2 -// CHECK-NEXT: %[[NOT_NICHE:.+]] = icmp ugt i8 %[[REL_VAR]], 4 -// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %[[REL_VAR]], 2 +// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %e, 4 // CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]]) -// CHECK-NEXT: ret i1 %[[NOT_NICHE]] +// CHECK-NEXT: %[[IS_C:.+]] = icmp{{( samesign)?}} ult i8 %e, 2 +// CHECK-NEXT: ret i1 %[[IS_C]] #[no_mangle] pub fn match4_is_c(e: MiddleNiche) -> bool { // Before #139098, this couldn't optimize out the `select` because it looked @@ -453,10 +452,10 @@ pub enum HugeVariantIndex { // CHECK-NEXT: start: // CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2 // CHECK-NEXT: %[[REL_VAR_WIDE:.+]] = zext i8 %[[REL_VAR]] to i64 -// CHECK-NEXT: %[[IS_NICHE:.+]] = icmp ult i8 %[[REL_VAR]], 3 -// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %[[REL_VAR]], 1 -// CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]]) +// CHECK-NEXT: %[[IS_NICHE:.+]] = icmp{{( samesign)?}} ugt i8 %0, 1 // CHECK-NEXT: %[[NICHE_DISCR:.+]] = add nuw nsw i64 %[[REL_VAR_WIDE]], 257 +// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i64 %[[NICHE_DISCR]], 258 +// CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]]) // CHECK-NEXT: %[[DISCR:.+]] = select i1 %[[IS_NICHE]], i64 %[[NICHE_DISCR]], i64 258 // CHECK-NEXT: switch i64 %[[DISCR]], // CHECK-NEXT: i64 257,