Auto merge of #150925 - dianqk:if-cmp, r=saethlin

Only use SSA locals in SimplifyComparisonIntegral

Fixes https://github.com/rust-lang/rust/issues/150904.

The place may be modified from the comparison statement to the switchInt terminator.

Best reviewed commit by commit.
This commit is contained in:
bors 2026-01-15 23:54:21 +00:00
commit 18ae990755
23 changed files with 301 additions and 56 deletions

View file

@ -374,6 +374,12 @@ impl<'tcx> Place<'tcx> {
self.projection.iter().any(|elem| elem.is_indirect())
}
/// Returns `true` if the `Place` always refers to the same memory region
/// whatever the state of the program.
pub fn is_stable_offset(&self) -> bool {
self.projection.iter().all(|elem| elem.is_stable_offset())
}
/// Returns `true` if this `Place`'s first projection is `Deref`.
///
/// This is useful because for MIR phases `AnalysisPhase::PostCleanup` and later,

View file

@ -9,6 +9,8 @@ use rustc_middle::mir::{
use rustc_middle::ty::{Ty, TyCtxt};
use tracing::trace;
use crate::ssa::SsaLocals;
/// Pass to convert `if` conditions on integrals into switches on the integral.
/// For an example, it turns something like
///
@ -27,17 +29,18 @@ pub(super) struct SimplifyComparisonIntegral;
impl<'tcx> crate::MirPass<'tcx> for SimplifyComparisonIntegral {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
sess.mir_opt_level() > 0
sess.mir_opt_level() > 1
}
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("Running SimplifyComparisonIntegral on {:?}", body.source);
let typing_env = body.typing_env(tcx);
let ssa = SsaLocals::new(tcx, body, typing_env);
let helper = OptimizationFinder { body };
let opts = helper.find_optimizations();
let opts = helper.find_optimizations(&ssa);
let mut storage_deads_to_insert = vec![];
let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![];
let typing_env = body.typing_env(tcx);
for opt in opts {
trace!("SUCCESS: Applying {:?}", opt);
// replace terminator with a switchInt that switches on the integer directly
@ -132,7 +135,7 @@ impl<'tcx> crate::MirPass<'tcx> for SimplifyComparisonIntegral {
let terminator = bb.terminator_mut();
terminator.kind =
TerminatorKind::SwitchInt { discr: Operand::Move(opt.to_switch_on), targets };
TerminatorKind::SwitchInt { discr: Operand::Copy(opt.to_switch_on), targets };
}
for (idx, bb_idx) in storage_deads_to_remove {
@ -154,19 +157,18 @@ struct OptimizationFinder<'a, 'tcx> {
}
impl<'tcx> OptimizationFinder<'_, 'tcx> {
fn find_optimizations(&self) -> Vec<OptimizationInfo<'tcx>> {
fn find_optimizations(&self, ssa: &SsaLocals) -> Vec<OptimizationInfo<'tcx>> {
self.body
.basic_blocks
.iter_enumerated()
.filter_map(|(bb_idx, bb)| {
// find switch
let (place_switched_on, targets, place_switched_on_moved) =
match &bb.terminator().kind {
rustc_middle::mir::TerminatorKind::SwitchInt { discr, targets, .. } => {
Some((discr.place()?, targets, discr.is_move()))
}
_ => None,
}?;
let (discr, targets) = bb.terminator().kind.as_switch()?;
let place_switched_on = discr.place()?;
// Make sure that the place is not modified.
if !ssa.is_ssa(place_switched_on.local) || !place_switched_on.is_stable_offset() {
return None;
}
// find the statement that assigns the place being switched on
bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| {
@ -180,12 +182,12 @@ impl<'tcx> OptimizationFinder<'_, 'tcx> {
box (left, right),
) => {
let (branch_value_scalar, branch_value_ty, to_switch_on) =
find_branch_value_info(left, right)?;
find_branch_value_info(left, right, ssa)?;
Some(OptimizationInfo {
bin_op_stmt_idx: stmt_idx,
bb_idx,
can_remove_bin_op_stmt: place_switched_on_moved,
can_remove_bin_op_stmt: discr.is_move(),
to_switch_on,
branch_value_scalar,
branch_value_ty,
@ -207,6 +209,7 @@ impl<'tcx> OptimizationFinder<'_, 'tcx> {
fn find_branch_value_info<'tcx>(
left: &Operand<'tcx>,
right: &Operand<'tcx>,
ssa: &SsaLocals,
) -> Option<(Scalar, Ty<'tcx>, Place<'tcx>)> {
// check that either left or right is a constant.
// if any are, we can use the other to switch on, and the constant as a value in a switch
@ -214,6 +217,10 @@ fn find_branch_value_info<'tcx>(
match (left, right) {
(Constant(branch_value), Copy(to_switch_on) | Move(to_switch_on))
| (Copy(to_switch_on) | Move(to_switch_on), Constant(branch_value)) => {
// Make sure that the place is not modified.
if !ssa.is_ssa(to_switch_on.local) || !to_switch_on.is_stable_offset() {
return None;
}
let branch_value_ty = branch_value.const_.ty();
// we only want to apply this optimization if we are matching on integrals (and chars),
// as it is not possible to switch on floats

View file

@ -1,6 +1,8 @@
//@ min-lldb-version: 310
//@ compile-flags:-g
// FIXME: Investigate why test fails without SimplifyComparisonIntegral pass.
//@ compile-flags: -Zmir-enable-passes=+SimplifyComparisonIntegral
//@ ignore-backends: gcc
// === GDB TESTS ===================================================================================

View file

@ -74,7 +74,7 @@
_23 = copy (((*_1).0: std::fmt::FormattingOptions).0: u32);
_22 = BitAnd(move _23, const core::fmt::flags::PRECISION_FLAG);
StorageDead(_23);
switchInt(move _22) -> [0: bb10, otherwise: bb11];
switchInt(copy _22) -> [0: bb10, otherwise: bb11];
}
bb4: {

View file

@ -74,7 +74,7 @@
_23 = copy (((*_1).0: std::fmt::FormattingOptions).0: u32);
_22 = BitAnd(move _23, const core::fmt::flags::PRECISION_FLAG);
StorageDead(_23);
switchInt(move _22) -> [0: bb10, otherwise: bb11];
switchInt(copy _22) -> [0: bb10, otherwise: bb11];
}
bb4: {

View file

@ -74,7 +74,7 @@
_23 = copy (((*_1).0: std::fmt::FormattingOptions).0: u32);
_22 = BitAnd(move _23, const core::fmt::flags::PRECISION_FLAG);
StorageDead(_23);
switchInt(move _22) -> [0: bb10, otherwise: bb11];
switchInt(copy _22) -> [0: bb10, otherwise: bb11];
}
bb4: {

View file

@ -74,7 +74,7 @@
_23 = copy (((*_1).0: std::fmt::FormattingOptions).0: u32);
_22 = BitAnd(move _23, const core::fmt::flags::PRECISION_FLAG);
StorageDead(_23);
switchInt(move _22) -> [0: bb10, otherwise: bb11];
switchInt(copy _22) -> [0: bb10, otherwise: bb11];
}
bb4: {

View file

@ -9,7 +9,7 @@
bb0: {
StorageLive(_2);
_2 = copy _1;
switchInt(move _2) -> [0: bb2, otherwise: bb1];
switchInt(copy _1) -> [0: bb2, otherwise: bb1];
}
bb1: {

View file

@ -11,7 +11,7 @@
StorageLive(_2);
StorageLive(_3);
_3 = copy _1;
_2 = Eq(move _3, const -42f32);
_2 = Eq(copy _1, const -42f32);
switchInt(move _2) -> [0: bb2, otherwise: bb1];
}

View file

@ -15,23 +15,20 @@
}
bb0: {
StorageLive(_2);
nop;
StorageLive(_3);
_3 = copy _1;
- _2 = Eq(move _3, const 17_i8);
- StorageDead(_3);
_2 = Eq(copy _1, const 17_i8);
StorageDead(_3);
- switchInt(copy _2) -> [0: bb2, otherwise: bb1];
+ _2 = Eq(copy _3, const 17_i8);
+ nop;
+ switchInt(move _3) -> [17: bb1, otherwise: bb2];
+ switchInt(copy _1) -> [17: bb1, otherwise: bb2];
}
bb1: {
+ StorageDead(_3);
StorageLive(_6);
StorageLive(_7);
_7 = copy _2;
_6 = move _7 as i32 (IntToInt);
_6 = copy _2 as i32 (IntToInt);
StorageDead(_7);
_0 = Add(const 100_i32, move _6);
StorageDead(_6);
@ -39,11 +36,10 @@
}
bb2: {
+ StorageDead(_3);
StorageLive(_4);
StorageLive(_5);
_5 = copy _2;
_4 = move _5 as i32 (IntToInt);
_4 = copy _2 as i32 (IntToInt);
StorageDead(_5);
_0 = Add(const 10_i32, move _4);
StorageDead(_4);
@ -51,7 +47,7 @@
}
bb3: {
StorageDead(_2);
nop;
return;
}
}

View file

@ -0,0 +1,24 @@
- // MIR for `on_non_ssa_cmp` before SimplifyComparisonIntegral
+ // MIR for `on_non_ssa_cmp` after SimplifyComparisonIntegral
fn on_non_ssa_cmp(_1: u64) -> i32 {
let mut _0: i32;
let mut _2: bool;
bb0: {
_2 = Eq(copy _1, const 42_u64);
_1 = const 43_u64;
switchInt(copy _2) -> [1: bb1, otherwise: bb2];
}
bb1: {
_0 = const 0_i32;
return;
}
bb2: {
_0 = const 1_i32;
return;
}
}

View file

@ -0,0 +1,24 @@
- // MIR for `on_non_ssa_place` before SimplifyComparisonIntegral
+ // MIR for `on_non_ssa_place` after SimplifyComparisonIntegral
fn on_non_ssa_place(_1: [u64; 10], _2: usize) -> i32 {
let mut _0: i32;
let mut _3: bool;
bb0: {
_3 = Eq(copy _1[_2], const 42_u64);
_2 = const 10_usize;
switchInt(copy _3) -> [1: bb1, otherwise: bb2];
}
bb1: {
_0 = const 0_i32;
return;
}
bb2: {
_0 = const 1_i32;
return;
}
}

View file

@ -0,0 +1,24 @@
- // MIR for `on_non_ssa_switch` before SimplifyComparisonIntegral
+ // MIR for `on_non_ssa_switch` after SimplifyComparisonIntegral
fn on_non_ssa_switch(_1: u64) -> i32 {
let mut _0: i32;
let mut _2: bool;
bb0: {
_2 = Eq(copy _1, const 42_u64);
_2 = const false;
switchInt(copy _2) -> [1: bb1, otherwise: bb2];
}
bb1: {
_0 = const 0_i32;
return;
}
bb2: {
_0 = const 1_i32;
return;
}
}

View file

@ -11,10 +11,10 @@
StorageLive(_2);
StorageLive(_3);
_3 = copy _1;
- _2 = Eq(move _3, const 'x');
- _2 = Eq(copy _1, const 'x');
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
+ nop;
+ switchInt(move _3) -> [120: bb1, otherwise: bb2];
+ switchInt(copy _1) -> [120: bb1, otherwise: bb2];
}
bb1: {

View file

@ -11,10 +11,10 @@
StorageLive(_2);
StorageLive(_3);
_3 = copy _1;
- _2 = Eq(move _3, const 42_i8);
- _2 = Eq(copy _1, const 42_i8);
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
+ nop;
+ switchInt(move _3) -> [42: bb1, otherwise: bb2];
+ switchInt(copy _1) -> [42: bb1, otherwise: bb2];
}
bb1: {

View file

@ -13,10 +13,10 @@
StorageLive(_2);
StorageLive(_3);
_3 = copy _1;
- _2 = Eq(move _3, const 42_u32);
- _2 = Eq(copy _1, const 42_u32);
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
+ nop;
+ switchInt(move _3) -> [42: bb1, otherwise: bb2];
+ switchInt(copy _1) -> [42: bb1, otherwise: bb2];
}
bb1: {
@ -30,10 +30,10 @@
StorageLive(_4);
StorageLive(_5);
_5 = copy _1;
- _4 = Ne(move _5, const 21_u32);
- _4 = Ne(copy _1, const 21_u32);
- switchInt(move _4) -> [0: bb4, otherwise: bb3];
+ nop;
+ switchInt(move _5) -> [21: bb4, otherwise: bb3];
+ switchInt(copy _1) -> [21: bb4, otherwise: bb3];
}
bb3: {

View file

@ -11,10 +11,10 @@
StorageLive(_2);
StorageLive(_3);
_3 = copy _1;
- _2 = Eq(move _3, const -42_i32);
- _2 = Eq(copy _1, const -42_i32);
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
+ nop;
+ switchInt(move _3) -> [4294967254: bb1, otherwise: bb2];
+ switchInt(copy _1) -> [4294967254: bb1, otherwise: bb2];
}
bb1: {

View file

@ -11,10 +11,10 @@
StorageLive(_2);
StorageLive(_3);
_3 = copy _1;
- _2 = Eq(move _3, const 42_u32);
- _2 = Eq(copy _1, const 42_u32);
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
+ nop;
+ switchInt(move _3) -> [42: bb1, otherwise: bb2];
+ switchInt(copy _1) -> [42: bb1, otherwise: bb2];
}
bb1: {

View file

@ -1,36 +1,80 @@
// skip-filecheck
//@ test-mir-pass: SimplifyComparisonIntegral
// EMIT_MIR if_condition_int.opt_u32.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.opt_negative.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.opt_char.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.opt_i8.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.dont_opt_bool.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.opt_multiple_ifs.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.dont_remove_comparison.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.dont_opt_floats.SimplifyComparisonIntegral.diff
// GVN simplifies FileCheck.
//@ compile-flags: -Zmir-enable-passes=+GVN
#![feature(custom_mir, core_intrinsics)]
extern crate core;
use core::intrinsics::mir::*;
// EMIT_MIR if_condition_int.opt_u32.SimplifyComparisonIntegral.diff
fn opt_u32(x: u32) -> u32 {
// CHECK-LABEL: fn opt_u32(
// CHECK: switchInt(copy _1) -> [42: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_u32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_u32;
if x == 42 { 0 } else { 1 }
}
// EMIT_MIR if_condition_int.dont_opt_bool.SimplifyComparisonIntegral.diff
// don't opt: it is already optimal to switch on the bool
fn dont_opt_bool(x: bool) -> u32 {
// CHECK-LABEL: fn dont_opt_bool(
// CHECK: switchInt(copy _1) -> [0: [[BB2:bb.*]], otherwise: [[BB1:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_u32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_u32;
if x { 0 } else { 1 }
}
// EMIT_MIR if_condition_int.opt_char.SimplifyComparisonIntegral.diff
fn opt_char(x: char) -> u32 {
// CHECK-LABEL: fn opt_char(
// CHECK: switchInt(copy _1) -> [120: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_u32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_u32;
if x == 'x' { 0 } else { 1 }
}
// EMIT_MIR if_condition_int.opt_i8.SimplifyComparisonIntegral.diff
fn opt_i8(x: i8) -> u32 {
// CHECK-LABEL: fn opt_i8(
// CHECK: switchInt(copy _1) -> [42: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_u32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_u32;
if x == 42 { 0 } else { 1 }
}
// EMIT_MIR if_condition_int.opt_negative.SimplifyComparisonIntegral.diff
fn opt_negative(x: i32) -> u32 {
// CHECK-LABEL: fn opt_negative(
// CHECK: switchInt(copy _1) -> [4294967254: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_u32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_u32;
if x == -42 { 0 } else { 1 }
}
// EMIT_MIR if_condition_int.opt_multiple_ifs.SimplifyComparisonIntegral.diff
fn opt_multiple_ifs(x: u32) -> u32 {
// CHECK-LABEL: fn opt_multiple_ifs(
// CHECK: switchInt(copy _1) -> [42: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_u32;
// CHECK: [[BB2]]:
// CHECK: switchInt(copy _1) -> [21: [[BB4:bb.*]], otherwise: [[BB3:bb.*]]];
// CHECK: [[BB3]]:
// CHECK: _0 = const 1_u32;
// CHECK: [[BB4]]:
// CHECK: _0 = const 2_u32;
if x == 42 {
0
} else if x != 21 {
@ -40,8 +84,18 @@ fn opt_multiple_ifs(x: u32) -> u32 {
}
}
// EMIT_MIR if_condition_int.dont_remove_comparison.SimplifyComparisonIntegral.diff
// test that we optimize, but do not remove the b statement, as that is used later on
fn dont_remove_comparison(a: i8) -> i32 {
// CHECK-LABEL: fn dont_remove_comparison(
// CHECK: [[b:_.*]] = Eq(copy _1, const 17_i8);
// CHECK: switchInt(copy _1) -> [17: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: [[cast_1:_.*]] = copy [[b]] as i32 (IntToInt);
// CHECK: _0 = Add(const 100_i32, move [[cast_1]]);
// CHECK: [[BB2]]:
// CHECK: [[cast_2:_.*]] = copy [[b]] as i32 (IntToInt);
// CHECK: _0 = Add(const 10_i32, move [[cast_2]]);
let b = a == 17;
match b {
false => 10 + b as i32,
@ -49,11 +103,118 @@ fn dont_remove_comparison(a: i8) -> i32 {
}
}
// EMIT_MIR if_condition_int.dont_opt_floats.SimplifyComparisonIntegral.diff
// test that we do not optimize on floats
fn dont_opt_floats(a: f32) -> i32 {
// CHECK-LABEL: fn dont_opt_floats(
// CHECK: [[cmp:_.*]] = Eq(copy _1, const -42f32);
// CHECK: switchInt(move [[cmp]]) -> [0: [[BB2:bb.*]], otherwise: [[BB1:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_i32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_i32;
if a == -42.0 { 0 } else { 1 }
}
// EMIT_MIR if_condition_int.on_non_ssa_switch.SimplifyComparisonIntegral.diff
#[custom_mir(dialect = "runtime")]
pub fn on_non_ssa_switch(mut v: u64) -> i32 {
// CHECK-LABEL: fn on_non_ssa_switch(
// CHECK: [[cmp:_.*]] = Eq(copy _1, const 42_u64);
// CHECK: [[cmp]] = const false;
// CHECK: switchInt(copy [[cmp]]) -> [1: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_i32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_i32;
mir! {
let a: bool;
{
a = v == 42;
a = false;
match a {
true => bb1,
_ => bb2,
}
}
bb1 = {
RET = 0;
Return()
}
bb2 = {
RET = 1;
Return()
}
}
}
// EMIT_MIR if_condition_int.on_non_ssa_cmp.SimplifyComparisonIntegral.diff
#[custom_mir(dialect = "runtime")]
pub fn on_non_ssa_cmp(mut v: u64) -> i32 {
// CHECK-LABEL: fn on_non_ssa_cmp(
// CHECK: [[cmp:_.*]] = Eq(copy _1, const 42_u64);
// CHECK: _1 = const 43_u64;
// CHECK: switchInt(copy [[cmp]]) -> [1: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_i32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_i32;
mir! {
let a: bool;
{
a = v == 42;
v = 43;
match a {
true => bb1,
_ => bb2,
}
}
bb1 = {
RET = 0;
Return()
}
bb2 = {
RET = 1;
Return()
}
}
}
// EMIT_MIR if_condition_int.on_non_ssa_place.SimplifyComparisonIntegral.diff
#[custom_mir(dialect = "runtime")]
pub fn on_non_ssa_place(mut v: [u64; 10], mut i: usize) -> i32 {
// CHECK-LABEL: fn on_non_ssa_place(
// CHECK: [[cmp:_.*]] = Eq(copy _1[_2], const 42_u64);
// CHECK: _2 = const 10_usize;
// CHECK: switchInt(copy [[cmp]]) -> [1: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_i32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_i32;
mir! {
let a: bool;
{
a = v[i] == 42;
i = 10;
match a {
true => bb1,
_ => bb2,
}
}
bb1 = {
RET = 0;
Return()
}
bb2 = {
RET = 1;
Return()
}
}
}
fn main() {
opt_u32(0);
opt_char('0');
@ -63,4 +224,5 @@ fn main() {
opt_multiple_ifs(0);
dont_remove_comparison(11);
dont_opt_floats(1.0);
on_non_ssa_switch(42);
}

View file

@ -28,7 +28,7 @@ fn num_to_digit(_1: char) -> u32 {
StorageLive(_3);
_3 = discriminant(_2);
StorageDead(_2);
switchInt(move _3) -> [1: bb2, otherwise: bb7];
switchInt(copy _3) -> [1: bb2, otherwise: bb7];
}
bb2: {

View file

@ -28,7 +28,7 @@ fn num_to_digit(_1: char) -> u32 {
StorageLive(_3);
_3 = discriminant(_2);
StorageDead(_2);
switchInt(move _3) -> [1: bb2, otherwise: bb7];
switchInt(copy _3) -> [1: bb2, otherwise: bb7];
}
bb2: {

View file

@ -28,7 +28,7 @@ fn num_to_digit(_1: char) -> u32 {
StorageLive(_3);
_3 = discriminant(_2);
StorageDead(_2);
switchInt(move _3) -> [1: bb2, otherwise: bb7];
switchInt(copy _3) -> [1: bb2, otherwise: bb7];
}
bb2: {

View file

@ -28,7 +28,7 @@ fn num_to_digit(_1: char) -> u32 {
StorageLive(_3);
_3 = discriminant(_2);
StorageDead(_2);
switchInt(move _3) -> [1: bb2, otherwise: bb7];
switchInt(copy _3) -> [1: bb2, otherwise: bb7];
}
bb2: {