From ac80ccec5f5999e7251474380661be5a84e56683 Mon Sep 17 00:00:00 2001 From: dianqk Date: Sat, 10 Jan 2026 20:10:18 +0800 Subject: [PATCH] Only use SSA locals in SimplifyComparisonIntegral --- compiler/rustc_middle/src/mir/statement.rs | 6 ++++ .../src/simplify_comparison_integral.rs | 31 ++++++++++++------- ...on_ssa_cmp.SimplifyComparisonIntegral.diff | 3 +- ..._ssa_place.SimplifyComparisonIntegral.diff | 3 +- ...ssa_switch.SimplifyComparisonIntegral.diff | 3 +- tests/mir-opt/if_condition_int.rs | 24 ++++++++++++++ 6 files changed, 52 insertions(+), 18 deletions(-) diff --git a/compiler/rustc_middle/src/mir/statement.rs b/compiler/rustc_middle/src/mir/statement.rs index c1f8c46baddb..adaac3d7ffc2 100644 --- a/compiler/rustc_middle/src/mir/statement.rs +++ b/compiler/rustc_middle/src/mir/statement.rs @@ -374,6 +374,12 @@ impl<'tcx> Place<'tcx> { self.projection.iter().any(|elem| elem.is_indirect()) } + /// Returns `true` if the `Place` always refers to the same memory region + /// whatever the state of the program. + pub fn is_stable_offset(&self) -> bool { + self.projection.iter().all(|elem| elem.is_stable_offset()) + } + /// Returns `true` if this `Place`'s first projection is `Deref`. /// /// This is useful because for MIR phases `AnalysisPhase::PostCleanup` and later, diff --git a/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs index 2643d78990e5..53a796b1179a 100644 --- a/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs +++ b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs @@ -9,6 +9,8 @@ use rustc_middle::mir::{ use rustc_middle::ty::{Ty, TyCtxt}; use tracing::trace; +use crate::ssa::SsaLocals; + /// Pass to convert `if` conditions on integrals into switches on the integral. /// For an example, it turns something like /// @@ -33,11 +35,12 @@ impl<'tcx> crate::MirPass<'tcx> for SimplifyComparisonIntegral { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { trace!("Running SimplifyComparisonIntegral on {:?}", body.source); + let typing_env = body.typing_env(tcx); + let ssa = SsaLocals::new(tcx, body, typing_env); let helper = OptimizationFinder { body }; - let opts = helper.find_optimizations(); + let opts = helper.find_optimizations(&ssa); let mut storage_deads_to_insert = vec![]; let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![]; - let typing_env = body.typing_env(tcx); for opt in opts { trace!("SUCCESS: Applying {:?}", opt); // replace terminator with a switchInt that switches on the integer directly @@ -154,19 +157,18 @@ struct OptimizationFinder<'a, 'tcx> { } impl<'tcx> OptimizationFinder<'_, 'tcx> { - fn find_optimizations(&self) -> Vec> { + fn find_optimizations(&self, ssa: &SsaLocals) -> Vec> { self.body .basic_blocks .iter_enumerated() .filter_map(|(bb_idx, bb)| { // find switch - let (place_switched_on, targets, place_switched_on_moved) = - match &bb.terminator().kind { - rustc_middle::mir::TerminatorKind::SwitchInt { discr, targets, .. } => { - Some((discr.place()?, targets, discr.is_move())) - } - _ => None, - }?; + let (discr, targets) = bb.terminator().kind.as_switch()?; + let place_switched_on = discr.place()?; + // Make sure that the place is not modified. + if !ssa.is_ssa(place_switched_on.local) || !place_switched_on.is_stable_offset() { + return None; + } // find the statement that assigns the place being switched on bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| { @@ -180,12 +182,12 @@ impl<'tcx> OptimizationFinder<'_, 'tcx> { box (left, right), ) => { let (branch_value_scalar, branch_value_ty, to_switch_on) = - find_branch_value_info(left, right)?; + find_branch_value_info(left, right, ssa)?; Some(OptimizationInfo { bin_op_stmt_idx: stmt_idx, bb_idx, - can_remove_bin_op_stmt: place_switched_on_moved, + can_remove_bin_op_stmt: discr.is_move(), to_switch_on, branch_value_scalar, branch_value_ty, @@ -207,6 +209,7 @@ impl<'tcx> OptimizationFinder<'_, 'tcx> { fn find_branch_value_info<'tcx>( left: &Operand<'tcx>, right: &Operand<'tcx>, + ssa: &SsaLocals, ) -> Option<(Scalar, Ty<'tcx>, Place<'tcx>)> { // check that either left or right is a constant. // if any are, we can use the other to switch on, and the constant as a value in a switch @@ -214,6 +217,10 @@ fn find_branch_value_info<'tcx>( match (left, right) { (Constant(branch_value), Copy(to_switch_on) | Move(to_switch_on)) | (Copy(to_switch_on) | Move(to_switch_on), Constant(branch_value)) => { + // Make sure that the place is not modified. + if !ssa.is_ssa(to_switch_on.local) || !to_switch_on.is_stable_offset() { + return None; + } let branch_value_ty = branch_value.const_.ty(); // we only want to apply this optimization if we are matching on integrals (and chars), // as it is not possible to switch on floats diff --git a/tests/mir-opt/if_condition_int.on_non_ssa_cmp.SimplifyComparisonIntegral.diff b/tests/mir-opt/if_condition_int.on_non_ssa_cmp.SimplifyComparisonIntegral.diff index d0983c660623..ce5a2bf172a9 100644 --- a/tests/mir-opt/if_condition_int.on_non_ssa_cmp.SimplifyComparisonIntegral.diff +++ b/tests/mir-opt/if_condition_int.on_non_ssa_cmp.SimplifyComparisonIntegral.diff @@ -8,8 +8,7 @@ bb0: { _2 = Eq(copy _1, const 42_u64); _1 = const 43_u64; -- switchInt(copy _2) -> [1: bb1, otherwise: bb2]; -+ switchInt(move _1) -> [42: bb1, otherwise: bb2]; + switchInt(copy _2) -> [1: bb1, otherwise: bb2]; } bb1: { diff --git a/tests/mir-opt/if_condition_int.on_non_ssa_place.SimplifyComparisonIntegral.diff b/tests/mir-opt/if_condition_int.on_non_ssa_place.SimplifyComparisonIntegral.diff index 0c6c8dca4753..7ad0a87f1cdd 100644 --- a/tests/mir-opt/if_condition_int.on_non_ssa_place.SimplifyComparisonIntegral.diff +++ b/tests/mir-opt/if_condition_int.on_non_ssa_place.SimplifyComparisonIntegral.diff @@ -8,8 +8,7 @@ bb0: { _3 = Eq(copy _1[_2], const 42_u64); _2 = const 10_usize; -- switchInt(copy _3) -> [1: bb1, otherwise: bb2]; -+ switchInt(move _1[_2]) -> [42: bb1, otherwise: bb2]; + switchInt(copy _3) -> [1: bb1, otherwise: bb2]; } bb1: { diff --git a/tests/mir-opt/if_condition_int.on_non_ssa_switch.SimplifyComparisonIntegral.diff b/tests/mir-opt/if_condition_int.on_non_ssa_switch.SimplifyComparisonIntegral.diff index b1b1ab2c2205..e2dc97f76b5c 100644 --- a/tests/mir-opt/if_condition_int.on_non_ssa_switch.SimplifyComparisonIntegral.diff +++ b/tests/mir-opt/if_condition_int.on_non_ssa_switch.SimplifyComparisonIntegral.diff @@ -8,8 +8,7 @@ bb0: { _2 = Eq(copy _1, const 42_u64); _2 = const false; -- switchInt(copy _2) -> [1: bb1, otherwise: bb2]; -+ switchInt(move _1) -> [42: bb1, otherwise: bb2]; + switchInt(copy _2) -> [1: bb1, otherwise: bb2]; } bb1: { diff --git a/tests/mir-opt/if_condition_int.rs b/tests/mir-opt/if_condition_int.rs index ba901f6b9b15..b49f8768253a 100644 --- a/tests/mir-opt/if_condition_int.rs +++ b/tests/mir-opt/if_condition_int.rs @@ -120,6 +120,14 @@ fn dont_opt_floats(a: f32) -> i32 { // EMIT_MIR if_condition_int.on_non_ssa_switch.SimplifyComparisonIntegral.diff #[custom_mir(dialect = "runtime")] pub fn on_non_ssa_switch(mut v: u64) -> i32 { + // CHECK-LABEL: fn on_non_ssa_switch( + // CHECK: [[cmp:_.*]] = Eq(copy _1, const 42_u64); + // CHECK: [[cmp]] = const false; + // CHECK: switchInt(copy [[cmp]]) -> [1: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]]; + // CHECK: [[BB1]]: + // CHECK: _0 = const 0_i32; + // CHECK: [[BB2]]: + // CHECK: _0 = const 1_i32; mir! { let a: bool; { @@ -145,6 +153,14 @@ pub fn on_non_ssa_switch(mut v: u64) -> i32 { // EMIT_MIR if_condition_int.on_non_ssa_cmp.SimplifyComparisonIntegral.diff #[custom_mir(dialect = "runtime")] pub fn on_non_ssa_cmp(mut v: u64) -> i32 { + // CHECK-LABEL: fn on_non_ssa_cmp( + // CHECK: [[cmp:_.*]] = Eq(copy _1, const 42_u64); + // CHECK: _1 = const 43_u64; + // CHECK: switchInt(copy [[cmp]]) -> [1: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]]; + // CHECK: [[BB1]]: + // CHECK: _0 = const 0_i32; + // CHECK: [[BB2]]: + // CHECK: _0 = const 1_i32; mir! { let a: bool; { @@ -170,6 +186,14 @@ pub fn on_non_ssa_cmp(mut v: u64) -> i32 { // EMIT_MIR if_condition_int.on_non_ssa_place.SimplifyComparisonIntegral.diff #[custom_mir(dialect = "runtime")] pub fn on_non_ssa_place(mut v: [u64; 10], mut i: usize) -> i32 { + // CHECK-LABEL: fn on_non_ssa_place( + // CHECK: [[cmp:_.*]] = Eq(copy _1[_2], const 42_u64); + // CHECK: _2 = const 10_usize; + // CHECK: switchInt(copy [[cmp]]) -> [1: [[BB1:bb.*]], otherwise: [[BB2:bb.*]]]; + // CHECK: [[BB1]]: + // CHECK: _0 = const 0_i32; + // CHECK: [[BB2]]: + // CHECK: _0 = const 1_i32; mir! { let a: bool; {