From 8e05ab04e54f2531198bb61c32ad2232d682a63c Mon Sep 17 00:00:00 2001 From: Camille GILLOT Date: Sun, 5 Feb 2023 12:08:42 +0000 Subject: [PATCH] Run SROA to fixpoint. --- .../rustc_mir_dataflow/src/value_analysis.rs | 2 +- compiler/rustc_mir_transform/src/sroa.rs | 74 +++++++++---------- ...ble_variable_aggregate.main.ConstProp.diff | 27 +++---- ....copies.ScalarReplacementOfAggregates.diff | 42 ++++++++--- 4 files changed, 78 insertions(+), 67 deletions(-) diff --git a/compiler/rustc_mir_dataflow/src/value_analysis.rs b/compiler/rustc_mir_dataflow/src/value_analysis.rs index 90d07c81256f..8bf6493be4b0 100644 --- a/compiler/rustc_mir_dataflow/src/value_analysis.rs +++ b/compiler/rustc_mir_dataflow/src/value_analysis.rs @@ -824,7 +824,7 @@ pub fn iter_fields<'tcx>( } /// Returns all locals with projections that have their reference or address taken. -fn excluded_locals(body: &Body<'_>) -> IndexVec { +pub fn excluded_locals(body: &Body<'_>) -> IndexVec { struct Collector { result: IndexVec, } diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index 3cfa0b16499a..28963c77aa55 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -6,7 +6,7 @@ use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; -use rustc_mir_dataflow::value_analysis::iter_fields; +use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields}; pub struct ScalarReplacementOfAggregates; @@ -18,26 +18,38 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates { #[instrument(level = "debug", skip(self, tcx, body))] fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { debug!(def_id = ?body.source.def_id()); - let escaping = escaping_locals(&*body); - debug!(?escaping); - let replacements = compute_flattening(tcx, body, escaping); - debug!(?replacements); - replace_flattened_locals(tcx, body, replacements); + let mut excluded = excluded_locals(body); + loop { + debug!(?excluded); + let escaping = escaping_locals(&excluded, body); + debug!(?escaping); + let replacements = compute_flattening(tcx, body, escaping); + debug!(?replacements); + let all_dead_locals = replace_flattened_locals(tcx, body, replacements); + if !all_dead_locals.is_empty() && tcx.sess.mir_opt_level() >= 4 { + for local in excluded.indices() { + excluded[local] |= all_dead_locals.contains(local) ; + } + excluded.raw.resize(body.local_decls.len(), false); + } else { + break + } + } } } /// Identify all locals that are not eligible for SROA. /// /// There are 3 cases: -/// - the aggegated local is used or passed to other code (function parameters and arguments); +/// - the aggregated local is used or passed to other code (function parameters and arguments); /// - the locals is a union or an enum; /// - the local's address is taken, and thus the relative addresses of the fields are observable to /// client code. -fn escaping_locals(body: &Body<'_>) -> BitSet { +fn escaping_locals(excluded: &IndexVec, body: &Body<'_>) -> BitSet { let mut set = BitSet::new_empty(body.local_decls.len()); set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count)); for (local, decl) in body.local_decls().iter_enumerated() { - if decl.ty.is_union() || decl.ty.is_enum() { + if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] { set.insert(local); } } @@ -62,17 +74,6 @@ fn escaping_locals(body: &Body<'_>) -> BitSet { self.super_place(place, context, location); } - fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { - if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue { - if !place.is_indirect() { - // Raw pointers may be used to access anything inside the enclosing place. - self.set.insert(place.local); - return; - } - } - self.super_rvalue(rvalue, location) - } - fn visit_assign( &mut self, lvalue: &Place<'tcx>, @@ -102,21 +103,6 @@ fn escaping_locals(body: &Body<'_>) -> BitSet { } } - fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { - // Drop implicitly calls `drop_in_place`, which takes a `&mut`. - // This implies that `Drop` implicitly takes the address of the place. - if let TerminatorKind::Drop { place, .. } - | TerminatorKind::DropAndReplace { place, .. } = terminator.kind - { - if !place.is_indirect() { - // Raw pointers may be used to access anything inside the enclosing place. - self.set.insert(place.local); - return; - } - } - self.super_terminator(terminator, location); - } - // We ignore anything that happens in debuginfo, since we expand it using // `VarDebugInfoContents::Composite`. fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {} @@ -198,14 +184,14 @@ fn replace_flattened_locals<'tcx>( tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, replacements: ReplacementMap<'tcx>, -) { +) -> BitSet { let mut all_dead_locals = BitSet::new_empty(body.local_decls.len()); for p in replacements.fields.keys() { all_dead_locals.insert(p.local); } debug!(?all_dead_locals); if all_dead_locals.is_empty() { - return; + return all_dead_locals; } let mut visitor = ReplacementVisitor { @@ -227,7 +213,9 @@ fn replace_flattened_locals<'tcx>( for var_debug_info in &mut body.var_debug_info { visitor.visit_var_debug_info(var_debug_info); } - visitor.patch.apply(body); + let ReplacementVisitor { patch, all_dead_locals, .. } = visitor; + patch.apply(body); + all_dead_locals } struct ReplacementVisitor<'tcx, 'll> { @@ -361,6 +349,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { } } + #[instrument(level = "trace", skip(self))] fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) { match &mut var_debug_info.value { VarDebugInfoContents::Place(ref mut place) => { @@ -375,11 +364,12 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { } VarDebugInfoContents::Composite { ty: _, ref mut fragments } => { let mut new_fragments = Vec::new(); + debug!(?fragments); fragments .drain_filter(|fragment| { if let Some(repl) = self.replace_place(fragment.contents.as_ref()) { fragment.contents = repl; - true + false } else if let Some(frg) = self .replacements .gather_debug_info_fragments(fragment.contents.as_ref()) @@ -388,12 +378,14 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { f.projection.splice(0..0, fragment.projection.iter().copied()); f })); - false - } else { true + } else { + false } }) .for_each(drop); + debug!(?fragments); + debug!(?new_fragments); fragments.extend(new_fragments); } VarDebugInfoContents::Const(_) => {} diff --git a/tests/mir-opt/const_prop/mutable_variable_aggregate.main.ConstProp.diff b/tests/mir-opt/const_prop/mutable_variable_aggregate.main.ConstProp.diff index 37fbcf9dd496..d088c4f662b7 100644 --- a/tests/mir-opt/const_prop/mutable_variable_aggregate.main.ConstProp.diff +++ b/tests/mir-opt/const_prop/mutable_variable_aggregate.main.ConstProp.diff @@ -3,30 +3,27 @@ fn main() -> () { let mut _0: (); // return place in scope 0 at $DIR/mutable_variable_aggregate.rs:+0:11: +0:11 - let mut _1: (i32, i32); // in scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14 + let mut _3: i32; // in scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14 + let mut _4: i32; // in scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14 scope 1 { - debug x => _1; // in scope 1 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14 + debug x => (i32, i32){ .0 => _3, .1 => _4, }; // in scope 1 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14 + let _1: i32; // in scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10 let _2: i32; // in scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10 - let _3: i32; // in scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10 scope 2 { - debug y => (i32, i32){ .0 => _2, .1 => _3, }; // in scope 2 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10 + debug y => (i32, i32){ .0 => _3, .1 => _2, }; // in scope 2 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10 } } bb0: { - StorageLive(_1); // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14 -- _1 = (const 42_i32, const 43_i32); // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25 -+ _1 = const (42_i32, 43_i32); // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25 - (_1.1: i32) = const 99_i32; // scope 1 at $DIR/mutable_variable_aggregate.rs:+2:5: +2:13 + StorageLive(_4); // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14 + _3 = const 42_i32; // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25 + _4 = const 43_i32; // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25 + _4 = const 99_i32; // scope 1 at $DIR/mutable_variable_aggregate.rs:+2:5: +2:13 StorageLive(_2); // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10 - StorageLive(_3); // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10 -- _2 = (_1.0: i32); // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14 -- _3 = (_1.1: i32); // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14 -+ _2 = const 42_i32; // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14 -+ _3 = const 99_i32; // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14 +- _2 = _4; // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14 ++ _2 = const 99_i32; // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14 StorageDead(_2); // scope 1 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2 - StorageDead(_3); // scope 1 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2 - StorageDead(_1); // scope 0 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2 + StorageDead(_4); // scope 0 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2 return; // scope 0 at $DIR/mutable_variable_aggregate.rs:+4:2: +4:2 } } diff --git a/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff b/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff index b76e2d6d0f29..976f6d44b752 100644 --- a/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff +++ b/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff @@ -5,8 +5,13 @@ debug x => _1; // in scope 0 at $DIR/sroa.rs:+0:11: +0:12 let mut _0: (); // return place in scope 0 at $DIR/sroa.rs:+0:19: +0:19 let _2: Foo; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ let _11: u8; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ let _12: (); // in scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ let _13: &str; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ let _14: std::option::Option; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10 scope 1 { - debug y => _2; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10 +- debug y => _2; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10 ++ debug y => Foo{ .0 => _11, .1 => _12, .2 => _13, .3 => _14, }; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10 let _3: u8; // in scope 1 at $DIR/sroa.rs:+2:9: +2:10 scope 2 { debug t => _3; // in scope 2 at $DIR/sroa.rs:+2:9: +2:10 @@ -31,23 +36,35 @@ } bb0: { - StorageLive(_2); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 - _2 = _1; // scope 0 at $DIR/sroa.rs:+1:13: +1:14 +- StorageLive(_2); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 +- _2 = _1; // scope 0 at $DIR/sroa.rs:+1:13: +1:14 ++ StorageLive(_11); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ StorageLive(_12); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ StorageLive(_13); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ StorageLive(_14); // scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ nop; // scope 0 at $DIR/sroa.rs:+1:9: +1:10 ++ _11 = (_1.0: u8); // scope 0 at $DIR/sroa.rs:+1:13: +1:14 ++ _12 = (_1.1: ()); // scope 0 at $DIR/sroa.rs:+1:13: +1:14 ++ _13 = (_1.2: &str); // scope 0 at $DIR/sroa.rs:+1:13: +1:14 ++ _14 = (_1.3: std::option::Option); // scope 0 at $DIR/sroa.rs:+1:13: +1:14 ++ nop; // scope 0 at $DIR/sroa.rs:+1:13: +1:14 StorageLive(_3); // scope 1 at $DIR/sroa.rs:+2:9: +2:10 - _3 = (_2.0: u8); // scope 1 at $DIR/sroa.rs:+2:13: +2:16 +- _3 = (_2.0: u8); // scope 1 at $DIR/sroa.rs:+2:13: +2:16 ++ _3 = _11; // scope 1 at $DIR/sroa.rs:+2:13: +2:16 StorageLive(_4); // scope 2 at $DIR/sroa.rs:+3:9: +3:10 - _4 = (_2.2: &str); // scope 2 at $DIR/sroa.rs:+3:13: +3:16 +- _4 = (_2.2: &str); // scope 2 at $DIR/sroa.rs:+3:13: +3:16 - StorageLive(_5); // scope 3 at $DIR/sroa.rs:+4:9: +4:10 - _5 = _2; // scope 3 at $DIR/sroa.rs:+4:13: +4:14 ++ _4 = _13; // scope 2 at $DIR/sroa.rs:+3:13: +3:16 + StorageLive(_7); // scope 3 at $DIR/sroa.rs:+4:9: +4:10 + StorageLive(_8); // scope 3 at $DIR/sroa.rs:+4:9: +4:10 + StorageLive(_9); // scope 3 at $DIR/sroa.rs:+4:9: +4:10 + StorageLive(_10); // scope 3 at $DIR/sroa.rs:+4:9: +4:10 + nop; // scope 3 at $DIR/sroa.rs:+4:9: +4:10 -+ _7 = (_2.0: u8); // scope 3 at $DIR/sroa.rs:+4:13: +4:14 -+ _8 = (_2.1: ()); // scope 3 at $DIR/sroa.rs:+4:13: +4:14 -+ _9 = (_2.2: &str); // scope 3 at $DIR/sroa.rs:+4:13: +4:14 -+ _10 = (_2.3: std::option::Option); // scope 3 at $DIR/sroa.rs:+4:13: +4:14 ++ _7 = _11; // scope 3 at $DIR/sroa.rs:+4:13: +4:14 ++ _8 = _12; // scope 3 at $DIR/sroa.rs:+4:13: +4:14 ++ _9 = _13; // scope 3 at $DIR/sroa.rs:+4:13: +4:14 ++ _10 = _14; // scope 3 at $DIR/sroa.rs:+4:13: +4:14 + nop; // scope 3 at $DIR/sroa.rs:+4:13: +4:14 StorageLive(_6); // scope 4 at $DIR/sroa.rs:+5:9: +5:10 - _6 = (_5.1: ()); // scope 4 at $DIR/sroa.rs:+5:13: +5:16 @@ -62,7 +79,12 @@ + nop; // scope 3 at $DIR/sroa.rs:+6:1: +6:2 StorageDead(_4); // scope 2 at $DIR/sroa.rs:+6:1: +6:2 StorageDead(_3); // scope 1 at $DIR/sroa.rs:+6:1: +6:2 - StorageDead(_2); // scope 0 at $DIR/sroa.rs:+6:1: +6:2 +- StorageDead(_2); // scope 0 at $DIR/sroa.rs:+6:1: +6:2 ++ StorageDead(_11); // scope 0 at $DIR/sroa.rs:+6:1: +6:2 ++ StorageDead(_12); // scope 0 at $DIR/sroa.rs:+6:1: +6:2 ++ StorageDead(_13); // scope 0 at $DIR/sroa.rs:+6:1: +6:2 ++ StorageDead(_14); // scope 0 at $DIR/sroa.rs:+6:1: +6:2 ++ nop; // scope 0 at $DIR/sroa.rs:+6:1: +6:2 return; // scope 0 at $DIR/sroa.rs:+6:2: +6:2 } }