New MIR Pass: SsaRangePropagation

This commit is contained in:
dianqk 2025-12-23 22:25:11 +08:00
parent 0ac9e59d8f
commit 0051e31f6f
No known key found for this signature in database
9 changed files with 494 additions and 0 deletions

View file

@ -374,6 +374,12 @@ impl<'tcx> Place<'tcx> {
self.projection.iter().any(|elem| elem.is_indirect())
}
/// Returns `true` if the `Place` always refers to the same memory region
/// whatever the state of the program.
pub fn is_stable_offset(&self) -> bool {
self.projection.iter().all(|elem| elem.is_stable_offset())
}
/// Returns `true` if this `Place`'s first projection is `Deref`.
///
/// This is useful because for MIR phases `AnalysisPhase::PostCleanup` and later,

View file

@ -198,6 +198,7 @@ declare_passes! {
mod single_use_consts : SingleUseConsts;
mod sroa : ScalarReplacementOfAggregates;
mod strip_debuginfo : StripDebugInfo;
mod ssa_range_prop: SsaRangePropagation;
mod unreachable_enum_branching : UnreachableEnumBranching;
mod unreachable_prop : UnreachablePropagation;
mod validate : Validator;
@ -741,6 +742,9 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'
&dead_store_elimination::DeadStoreElimination::Initial,
&gvn::GVN,
&simplify::SimplifyLocals::AfterGVN,
// This pass does attempt to track assignments.
// Keep it close to GVN which merges identical values into the same local.
&ssa_range_prop::SsaRangePropagation,
&match_branches::MatchBranchSimplification,
&dataflow_const_prop::DataflowConstProp,
&single_use_consts::SingleUseConsts,

View file

@ -0,0 +1,203 @@
//! A pass that propagates the known ranges of SSA locals.
//! We can know the ranges of SSA locals in certain locations for the following code:
//! ```
//! fn foo(a: u32) {
//! let b = a < 9; // the integer representation of b is within the full range [0, 2).
//! if b {
//! let c = b; // c is true since b is within the range [1, 2).
//! let d = a < 8; // d is true since a is within the range [0, 9).
//! }
//! }
//! ```
use rustc_abi::WrappingRange;
use rustc_const_eval::interpret::Scalar;
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::graph::dominators::Dominators;
use rustc_index::bit_set::DenseBitSet;
use rustc_middle::mir::visit::MutVisitor;
use rustc_middle::mir::{BasicBlock, Body, Location, Operand, Place, TerminatorKind, *};
use rustc_middle::ty::{TyCtxt, TypingEnv};
use rustc_span::DUMMY_SP;
use crate::ssa::SsaLocals;
pub(super) struct SsaRangePropagation;
impl<'tcx> crate::MirPass<'tcx> for SsaRangePropagation {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
sess.mir_opt_level() > 1
}
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let typing_env = body.typing_env(tcx);
let ssa = SsaLocals::new(tcx, body, typing_env);
// Clone dominators because we need them while mutating the body.
let dominators = body.basic_blocks.dominators().clone();
let mut range_set =
RangeSet::new(tcx, typing_env, body, &ssa, &body.local_decls, dominators);
let reverse_postorder = body.basic_blocks.reverse_postorder().to_vec();
for bb in reverse_postorder {
let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb];
range_set.visit_basic_block_data(bb, data);
}
}
fn is_required(&self) -> bool {
false
}
}
struct RangeSet<'tcx, 'body, 'a> {
tcx: TyCtxt<'tcx>,
typing_env: TypingEnv<'tcx>,
ssa: &'a SsaLocals,
local_decls: &'body LocalDecls<'tcx>,
dominators: Dominators<BasicBlock>,
/// Known ranges at each locations.
ranges: FxHashMap<Place<'tcx>, Vec<(Location, WrappingRange)>>,
/// Determines if the basic block has a single unique predecessor.
unique_predecessors: DenseBitSet<BasicBlock>,
}
impl<'tcx, 'body, 'a> RangeSet<'tcx, 'body, 'a> {
fn new(
tcx: TyCtxt<'tcx>,
typing_env: TypingEnv<'tcx>,
body: &Body<'tcx>,
ssa: &'a SsaLocals,
local_decls: &'body LocalDecls<'tcx>,
dominators: Dominators<BasicBlock>,
) -> Self {
let predecessors = body.basic_blocks.predecessors();
let mut unique_predecessors = DenseBitSet::new_empty(body.basic_blocks.len());
for bb in body.basic_blocks.indices() {
if predecessors[bb].len() == 1 {
unique_predecessors.insert(bb);
}
}
RangeSet {
tcx,
typing_env,
ssa,
local_decls,
dominators,
ranges: FxHashMap::default(),
unique_predecessors,
}
}
/// Create a new known range at the location.
fn insert_range(&mut self, place: Place<'tcx>, location: Location, range: WrappingRange) {
assert!(self.is_ssa(place));
self.ranges.entry(place).or_default().push((location, range));
}
/// Get the known range at the location.
fn get_range(&self, place: &Place<'tcx>, location: Location) -> Option<WrappingRange> {
let Some(ranges) = self.ranges.get(place) else {
return None;
};
// FIXME: This should use the intersection of all valid ranges.
let (_, range) =
ranges.iter().find(|(range_loc, _)| range_loc.dominates(location, &self.dominators))?;
Some(*range)
}
fn try_as_constant(
&mut self,
place: Place<'tcx>,
location: Location,
) -> Option<ConstOperand<'tcx>> {
if let Some(range) = self.get_range(&place, location)
&& range.start == range.end
{
let ty = place.ty(self.local_decls, self.tcx).ty;
let layout = self.tcx.layout_of(self.typing_env.as_query_input(ty)).ok()?;
let value = ConstValue::Scalar(Scalar::from_uint(range.start, layout.size));
let const_ = Const::Val(value, ty);
return Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_ });
}
None
}
fn is_ssa(&self, place: Place<'tcx>) -> bool {
self.ssa.is_ssa(place.local) && place.is_stable_offset()
}
}
impl<'tcx> MutVisitor<'tcx> for RangeSet<'tcx, '_, '_> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}
fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
// Attempts to simplify an operand to a constant value.
if let Some(place) = operand.place()
&& let Some(const_) = self.try_as_constant(place, location)
{
*operand = Operand::Constant(Box::new(const_));
};
}
fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
self.super_terminator(terminator, location);
match &terminator.kind {
TerminatorKind::Assert { cond, expected, target, .. } => {
if let Some(place) = cond.place()
&& self.is_ssa(place)
{
let successor = Location { block: *target, statement_index: 0 };
if location.dominates(successor, &self.dominators) {
assert_ne!(location.block, successor.block);
let val = *expected as u128;
let range = WrappingRange { start: val, end: val };
self.insert_range(place, successor, range);
}
}
}
TerminatorKind::SwitchInt { discr, targets } => {
if let Some(place) = discr.place()
&& self.is_ssa(place)
// Reduce the potential compile-time overhead.
&& targets.all_targets().len() < 16
{
let mut distinct_targets: FxHashMap<BasicBlock, u64> = FxHashMap::default();
for (_, target) in targets.iter() {
let targets = distinct_targets.entry(target).or_default();
*targets += 1;
}
for (val, target) in targets.iter() {
if distinct_targets[&target] != 1 {
// FIXME: For multiple targets, the range can be the union of their values.
continue;
}
let successor = Location { block: target, statement_index: 0 };
if self.unique_predecessors.contains(successor.block) {
assert_ne!(location.block, successor.block);
let range = WrappingRange { start: val, end: val };
self.insert_range(place, successor, range);
}
}
// FIXME: The range for the otherwise target be extend to more types.
// For instance, `val` is within the range [4, 1) at the otherwise target of `matches!(val, 1 | 2 | 3)`.
let otherwise = Location { block: targets.otherwise(), statement_index: 0 };
if place.ty(self.local_decls, self.tcx).ty.is_bool()
&& let [val] = targets.all_values()
&& self.unique_predecessors.contains(otherwise.block)
{
assert_ne!(location.block, otherwise.block);
let range = if val.get() == 0 {
WrappingRange { start: 1, end: 1 }
} else {
WrappingRange { start: 0, end: 0 }
};
self.insert_range(place, otherwise, range);
}
}
}
_ => {}
}
}
}

View file

@ -0,0 +1,69 @@
- // MIR for `on_assert` before SsaRangePropagation
+ // MIR for `on_assert` after SsaRangePropagation
fn on_assert(_1: usize, _2: &[u8]) -> u8 {
debug i => _1;
debug v => _2;
let mut _0: u8;
let _3: ();
let mut _4: bool;
let mut _5: usize;
let mut _6: usize;
let mut _7: &[u8];
let mut _8: !;
let _9: usize;
let mut _10: usize;
let mut _11: bool;
scope 1 (inlined core::slice::<impl [u8]>::len) {
scope 2 (inlined std::ptr::metadata::<[u8]>) {
}
}
bb0: {
StorageLive(_3);
nop;
StorageLive(_5);
_5 = copy _1;
nop;
StorageLive(_7);
_7 = &(*_2);
_6 = PtrMetadata(copy _2);
StorageDead(_7);
_4 = Lt(copy _1, copy _6);
switchInt(copy _4) -> [0: bb2, otherwise: bb1];
}
bb1: {
nop;
StorageDead(_5);
_3 = const ();
nop;
StorageDead(_3);
StorageLive(_9);
_9 = copy _1;
_10 = copy _6;
- _11 = copy _4;
- assert(copy _4, "index out of bounds: the length is {} but the index is {}", copy _6, copy _1) -> [success: bb3, unwind unreachable];
+ _11 = const true;
+ assert(const true, "index out of bounds: the length is {} but the index is {}", copy _6, copy _1) -> [success: bb3, unwind unreachable];
}
bb2: {
nop;
StorageDead(_5);
StorageLive(_8);
_8 = panic(const "assertion failed: i < v.len()") -> unwind unreachable;
}
bb3: {
_0 = copy (*_2)[_1];
StorageDead(_9);
return;
}
}
ALLOC0 (size: 29, align: 1) {
0x00 │ 61 73 73 65 72 74 69 6f 6e 20 66 61 69 6c 65 64 │ assertion failed
0x10 │ 3a 20 69 20 3c 20 76 2e 6c 65 6e 28 29 │ : i < v.len()
}

View file

@ -0,0 +1,63 @@
- // MIR for `on_if` before SsaRangePropagation
+ // MIR for `on_if` after SsaRangePropagation
fn on_if(_1: usize, _2: &[u8]) -> u8 {
debug i => _1;
debug v => _2;
let mut _0: u8;
let mut _3: bool;
let mut _4: usize;
let mut _5: usize;
let mut _6: &[u8];
let _7: usize;
let mut _8: usize;
let mut _9: bool;
scope 1 (inlined core::slice::<impl [u8]>::len) {
scope 2 (inlined std::ptr::metadata::<[u8]>) {
}
}
bb0: {
nop;
StorageLive(_4);
_4 = copy _1;
nop;
StorageLive(_6);
_6 = &(*_2);
_5 = PtrMetadata(copy _2);
StorageDead(_6);
_3 = Lt(copy _1, copy _5);
switchInt(copy _3) -> [0: bb3, otherwise: bb1];
}
bb1: {
nop;
StorageDead(_4);
StorageLive(_7);
_7 = copy _1;
_8 = copy _5;
- _9 = copy _3;
- assert(copy _3, "index out of bounds: the length is {} but the index is {}", copy _5, copy _1) -> [success: bb2, unwind unreachable];
+ _9 = const true;
+ assert(const true, "index out of bounds: the length is {} but the index is {}", copy _5, copy _1) -> [success: bb2, unwind unreachable];
}
bb2: {
_0 = copy (*_2)[_1];
StorageDead(_7);
goto -> bb4;
}
bb3: {
nop;
StorageDead(_4);
_0 = const 0_u8;
goto -> bb4;
}
bb4: {
nop;
return;
}
}

View file

@ -0,0 +1,20 @@
- // MIR for `on_if_2` before SsaRangePropagation
+ // MIR for `on_if_2` after SsaRangePropagation
fn on_if_2(_1: bool) -> bool {
let mut _0: bool;
bb0: {
switchInt(copy _1) -> [1: bb2, otherwise: bb1];
}
bb1: {
goto -> bb2;
}
bb2: {
_0 = copy _1;
return;
}
}

View file

@ -0,0 +1,33 @@
- // MIR for `on_match` before SsaRangePropagation
+ // MIR for `on_match` after SsaRangePropagation
fn on_match(_1: u8) -> u8 {
debug i => _1;
let mut _0: u8;
bb0: {
switchInt(copy _1) -> [1: bb3, 2: bb2, otherwise: bb1];
}
bb1: {
_0 = const 0_u8;
goto -> bb4;
}
bb2: {
- _0 = copy _1;
+ _0 = const 2_u8;
goto -> bb4;
}
bb3: {
- _0 = copy _1;
+ _0 = const 1_u8;
goto -> bb4;
}
bb4: {
return;
}
}

View file

@ -0,0 +1,26 @@
- // MIR for `on_match_2` before SsaRangePropagation
+ // MIR for `on_match_2` after SsaRangePropagation
fn on_match_2(_1: u8) -> u8 {
debug i => _1;
let mut _0: u8;
bb0: {
switchInt(copy _1) -> [1: bb2, 2: bb2, otherwise: bb1];
}
bb1: {
_0 = const 0_u8;
goto -> bb3;
}
bb2: {
_0 = copy _1;
goto -> bb3;
}
bb3: {
return;
}
}

View file

@ -0,0 +1,70 @@
//@ test-mir-pass: SsaRangePropagation
//@ compile-flags: -Zmir-enable-passes=+GVN,+Inline --crate-type=lib -Cpanic=abort
#![feature(custom_mir, core_intrinsics)]
use std::intrinsics::mir::*;
// EMIT_MIR ssa_range.on_if.SsaRangePropagation.diff
pub fn on_if(i: usize, v: &[u8]) -> u8 {
// CHECK-LABEL: fn on_if(
// CHECK: assert(const true
if i < v.len() { v[i] } else { 0 }
}
// EMIT_MIR ssa_range.on_assert.SsaRangePropagation.diff
pub fn on_assert(i: usize, v: &[u8]) -> u8 {
// CHECK-LABEL: fn on_assert(
// CHECK: assert(const true
assert!(i < v.len());
v[i]
}
// EMIT_MIR ssa_range.on_match.SsaRangePropagation.diff
pub fn on_match(i: u8) -> u8 {
// CHECK-LABEL: fn on_match(
// CHECK: switchInt(copy _1) -> [1: [[BB_V1:bb.*]], 2: [[BB_V2:bb.*]],
// CHECK: [[BB_V2]]: {
// CHECK-NEXT: _0 = const 2_u8;
// CHECK: [[BB_V1]]: {
// CHECK-NEXT: _0 = const 1_u8;
match i {
1 => i,
2 => i,
_ => 0,
}
}
// EMIT_MIR ssa_range.on_match_2.SsaRangePropagation.diff
pub fn on_match_2(i: u8) -> u8 {
// CHECK-LABEL: fn on_match_2(
// CHECK: switchInt(copy _1) -> [1: [[BB:bb.*]], 2: [[BB]],
// CHECK: [[BB]]: {
// CHECK-NEXT: _0 = copy _1;
match i {
1 | 2 => i,
_ => 0,
}
}
// EMIT_MIR ssa_range.on_if_2.SsaRangePropagation.diff
#[custom_mir(dialect = "runtime", phase = "post-cleanup")]
pub fn on_if_2(a: bool) -> bool {
// CHECK-LABEL: fn on_if_2(
// CHECK: _0 = copy _1;
mir! {
{
match a {
true => bb2,
_ => bb1
}
}
bb1 = {
Goto(bb2)
}
bb2 = {
RET = a;
Return()
}
}
}