Only use SSA locals in SimplifyComparisonIntegral

This commit is contained in:
dianqk 2026-01-10 20:10:18 +08:00
parent 528fd2a330
commit ac80ccec5f
No known key found for this signature in database
6 changed files with 52 additions and 18 deletions

View file

@ -374,6 +374,12 @@ impl<'tcx> Place<'tcx> {
self.projection.iter().any(|elem| elem.is_indirect())
}
/// Returns `true` if the `Place` always refers to the same memory region
/// whatever the state of the program.
pub fn is_stable_offset(&self) -> bool {
self.projection.iter().all(|elem| elem.is_stable_offset())
}
/// Returns `true` if this `Place`'s first projection is `Deref`.
///
/// This is useful because for MIR phases `AnalysisPhase::PostCleanup` and later,

View file

@ -9,6 +9,8 @@ use rustc_middle::mir::{
use rustc_middle::ty::{Ty, TyCtxt};
use tracing::trace;
use crate::ssa::SsaLocals;
/// Pass to convert `if` conditions on integrals into switches on the integral.
/// For an example, it turns something like
///
@ -33,11 +35,12 @@ impl<'tcx> crate::MirPass<'tcx> for SimplifyComparisonIntegral {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("Running SimplifyComparisonIntegral on {:?}", body.source);
let typing_env = body.typing_env(tcx);
let ssa = SsaLocals::new(tcx, body, typing_env);
let helper = OptimizationFinder { body };
let opts = helper.find_optimizations();
let opts = helper.find_optimizations(&ssa);
let mut storage_deads_to_insert = vec![];
let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![];
let typing_env = body.typing_env(tcx);
for opt in opts {
trace!("SUCCESS: Applying {:?}", opt);
// replace terminator with a switchInt that switches on the integer directly
@ -154,19 +157,18 @@ struct OptimizationFinder<'a, 'tcx> {
}
impl<'tcx> OptimizationFinder<'_, 'tcx> {
fn find_optimizations(&self) -> Vec<OptimizationInfo<'tcx>> {
fn find_optimizations(&self, ssa: &SsaLocals) -> Vec<OptimizationInfo<'tcx>> {
self.body
.basic_blocks
.iter_enumerated()
.filter_map(|(bb_idx, bb)| {
// find switch
let (place_switched_on, targets, place_switched_on_moved) =
match &bb.terminator().kind {
rustc_middle::mir::TerminatorKind::SwitchInt { discr, targets, .. } => {
Some((discr.place()?, targets, discr.is_move()))
}
_ => None,
}?;
let (discr, targets) = bb.terminator().kind.as_switch()?;
let place_switched_on = discr.place()?;
// Make sure that the place is not modified.
if !ssa.is_ssa(place_switched_on.local) || !place_switched_on.is_stable_offset() {
return None;
}
// find the statement that assigns the place being switched on
bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| {
@ -180,12 +182,12 @@ impl<'tcx> OptimizationFinder<'_, 'tcx> {
box (left, right),
) => {
let (branch_value_scalar, branch_value_ty, to_switch_on) =
find_branch_value_info(left, right)?;
find_branch_value_info(left, right, ssa)?;
Some(OptimizationInfo {
bin_op_stmt_idx: stmt_idx,
bb_idx,
can_remove_bin_op_stmt: place_switched_on_moved,
can_remove_bin_op_stmt: discr.is_move(),
to_switch_on,
branch_value_scalar,
branch_value_ty,
@ -207,6 +209,7 @@ impl<'tcx> OptimizationFinder<'_, 'tcx> {
fn find_branch_value_info<'tcx>(
left: &Operand<'tcx>,
right: &Operand<'tcx>,
ssa: &SsaLocals,
) -> Option<(Scalar, Ty<'tcx>, Place<'tcx>)> {
// check that either left or right is a constant.
// if any are, we can use the other to switch on, and the constant as a value in a switch
@ -214,6 +217,10 @@ fn find_branch_value_info<'tcx>(
match (left, right) {
(Constant(branch_value), Copy(to_switch_on) | Move(to_switch_on))
| (Copy(to_switch_on) | Move(to_switch_on), Constant(branch_value)) => {
// Make sure that the place is not modified.
if !ssa.is_ssa(to_switch_on.local) || !to_switch_on.is_stable_offset() {
return None;
}
let branch_value_ty = branch_value.const_.ty();
// we only want to apply this optimization if we are matching on integrals (and chars),
// as it is not possible to switch on floats

View file

@ -8,8 +8,7 @@
bb0: {
_2 = Eq(copy _1, const 42_u64);
_1 = const 43_u64;
- switchInt(copy _2) -> [1: bb1, otherwise: bb2];
+ switchInt(move _1) -> [42: bb1, otherwise: bb2];
switchInt(copy _2) -> [1: bb1, otherwise: bb2];
}
bb1: {

View file

@ -8,8 +8,7 @@
bb0: {
_3 = Eq(copy _1[_2], const 42_u64);
_2 = const 10_usize;
- switchInt(copy _3) -> [1: bb1, otherwise: bb2];
+ switchInt(move _1[_2]) -> [42: bb1, otherwise: bb2];
switchInt(copy _3) -> [1: bb1, otherwise: bb2];
}
bb1: {

View file

@ -8,8 +8,7 @@
bb0: {
_2 = Eq(copy _1, const 42_u64);
_2 = const false;
- switchInt(copy _2) -> [1: bb1, otherwise: bb2];
+ switchInt(move _1) -> [42: bb1, otherwise: bb2];
switchInt(copy _2) -> [1: bb1, otherwise: bb2];
}
bb1: {

View file

@ -120,6 +120,14 @@ fn dont_opt_floats(a: f32) -> i32 {
// EMIT_MIR if_condition_int.on_non_ssa_switch.SimplifyComparisonIntegral.diff
#[custom_mir(dialect = "runtime")]
pub fn on_non_ssa_switch(mut v: u64) -> i32 {
// CHECK-LABEL: fn on_non_ssa_switch(
// CHECK: [[cmp:_.*]] = Eq(copy _1, const 42_u64);
// CHECK: [[cmp]] = const false;
// CHECK: switchInt(copy [[cmp]]) -> [1: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_i32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_i32;
mir! {
let a: bool;
{
@ -145,6 +153,14 @@ pub fn on_non_ssa_switch(mut v: u64) -> i32 {
// EMIT_MIR if_condition_int.on_non_ssa_cmp.SimplifyComparisonIntegral.diff
#[custom_mir(dialect = "runtime")]
pub fn on_non_ssa_cmp(mut v: u64) -> i32 {
// CHECK-LABEL: fn on_non_ssa_cmp(
// CHECK: [[cmp:_.*]] = Eq(copy _1, const 42_u64);
// CHECK: _1 = const 43_u64;
// CHECK: switchInt(copy [[cmp]]) -> [1: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_i32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_i32;
mir! {
let a: bool;
{
@ -170,6 +186,14 @@ pub fn on_non_ssa_cmp(mut v: u64) -> i32 {
// EMIT_MIR if_condition_int.on_non_ssa_place.SimplifyComparisonIntegral.diff
#[custom_mir(dialect = "runtime")]
pub fn on_non_ssa_place(mut v: [u64; 10], mut i: usize) -> i32 {
// CHECK-LABEL: fn on_non_ssa_place(
// CHECK: [[cmp:_.*]] = Eq(copy _1[_2], const 42_u64);
// CHECK: _2 = const 10_usize;
// CHECK: switchInt(copy [[cmp]]) -> [1: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]];
// CHECK: [[BB1]]:
// CHECK: _0 = const 0_i32;
// CHECK: [[BB2]]:
// CHECK: _0 = const 1_i32;
mir! {
let a: bool;
{