diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index 28963c77aa55..26acd406ed8a 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -1,11 +1,10 @@ use crate::MirPass; -use rustc_data_structures::fx::FxIndexMap; use rustc_index::bit_set::BitSet; use rustc_index::vec::IndexVec; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::{Ty, TyCtxt}; use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields}; pub struct ScalarReplacementOfAggregates; @@ -26,13 +25,13 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates { let replacements = compute_flattening(tcx, body, escaping); debug!(?replacements); let all_dead_locals = replace_flattened_locals(tcx, body, replacements); - if !all_dead_locals.is_empty() && tcx.sess.mir_opt_level() >= 4 { + if !all_dead_locals.is_empty() { for local in excluded.indices() { - excluded[local] |= all_dead_locals.contains(local) ; + excluded[local] |= all_dead_locals.contains(local); } excluded.raw.resize(body.local_decls.len(), false); } else { - break + break; } } } @@ -111,36 +110,29 @@ fn escaping_locals(excluded: &IndexVec, body: &Body<'_>) -> BitSet< #[derive(Default, Debug)] struct ReplacementMap<'tcx> { - fields: FxIndexMap, Local>, /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage /// and deinit statement and debuginfo. - fragments: IndexVec], Local)>>>, + fragments: IndexVec, Local)>>>>, } impl<'tcx> ReplacementMap<'tcx> { - fn gather_debug_info_fragments( - &self, - place: PlaceRef<'tcx>, - ) -> Option>> { - let mut fragments = Vec::new(); - let Some(parts) = &self.fragments[place.local] else { return None }; - for (proj, replacement_local) in parts { - if proj.starts_with(place.projection) { - fragments.push(VarDebugInfoFragment { - projection: proj[place.projection.len()..].to_vec(), - contents: Place::from(*replacement_local), - }); - } - } - Some(fragments) + fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option> { + let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else { return None; }; + let fields = self.fragments[place.local].as_ref()?; + let (_, new_local) = fields[f]?; + Some(Place { local: new_local, projection: tcx.intern_place_elems(&rest) }) } fn place_fragments( &self, place: Place<'tcx>, - ) -> Option<&Vec<(&'tcx [PlaceElem<'tcx>], Local)>> { + ) -> Option, Local)> + '_> { let local = place.as_local()?; - self.fragments[local].as_ref() + let fields = self.fragments[local].as_ref()?; + Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| { + let (ty, local) = opt_ty_local?; + Some((field, ty, local)) + })) } } @@ -153,8 +145,7 @@ fn compute_flattening<'tcx>( body: &mut Body<'tcx>, escaping: BitSet, ) -> ReplacementMap<'tcx> { - let mut fields = FxIndexMap::default(); - let mut fragments = IndexVec::from_elem(None::>, &body.local_decls); + let mut fragments = IndexVec::from_elem(None, &body.local_decls); for local in body.local_decls.indices() { if escaping.contains(local) { @@ -169,14 +160,10 @@ fn compute_flattening<'tcx>( }; let new_local = body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() }); - let place = Place::from(local) - .project_deeper(&[PlaceElem::Field(field, field_ty)], tcx) - .as_ref(); - fields.insert(place, new_local); - fragments[local].get_or_insert_default().push((place.projection, new_local)); + fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local)); }); } - ReplacementMap { fields, fragments } + ReplacementMap { fragments } } /// Perform the replacement computed by `compute_flattening`. @@ -186,8 +173,10 @@ fn replace_flattened_locals<'tcx>( replacements: ReplacementMap<'tcx>, ) -> BitSet { let mut all_dead_locals = BitSet::new_empty(body.local_decls.len()); - for p in replacements.fields.keys() { - all_dead_locals.insert(p.local); + for (local, replacements) in replacements.fragments.iter_enumerated() { + if replacements.is_some() { + all_dead_locals.insert(local); + } } debug!(?all_dead_locals); if all_dead_locals.is_empty() { @@ -197,7 +186,7 @@ fn replace_flattened_locals<'tcx>( let mut visitor = ReplacementVisitor { tcx, local_decls: &body.local_decls, - replacements, + replacements: &replacements, all_dead_locals, patch: MirPatch::new(body), }; @@ -223,21 +212,23 @@ struct ReplacementVisitor<'tcx, 'll> { /// This is only used to compute the type for `VarDebugInfoContents::Composite`. local_decls: &'ll LocalDecls<'tcx>, /// Work to do. - replacements: ReplacementMap<'tcx>, + replacements: &'ll ReplacementMap<'tcx>, /// This is used to check that we are not leaving references to replaced locals behind. all_dead_locals: BitSet, patch: MirPatch<'tcx>, } -impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> { - fn replace_place(&self, place: PlaceRef<'tcx>) -> Option> { - if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection { - let pr = PlaceRef { local: place.local, projection: &place.projection[..1] }; - let local = self.replacements.fields.get(&pr)?; - Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) }) - } else { - None +impl<'tcx> ReplacementVisitor<'tcx, '_> { + fn gather_debug_info_fragments(&self, local: Local) -> Option>> { + let mut fragments = Vec::new(); + let parts = self.replacements.place_fragments(local.into())?; + for (field, ty, replacement_local) in parts { + fragments.push(VarDebugInfoFragment { + projection: vec![PlaceElem::Field(field, ty)], + contents: Place::from(replacement_local), + }); } + Some(fragments) } } @@ -246,12 +237,21 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { self.tcx } + fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { + if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) { + *place = repl + } else { + self.super_place(place, context, location) + } + } + #[instrument(level = "trace", skip(self))] fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { match statement.kind { + // Duplicate storage and deinit statements, as they pretty much apply to all fields. StatementKind::StorageLive(l) => { - if let Some(final_locals) = &self.replacements.fragments[l] { - for &(_, fl) in final_locals { + if let Some(final_locals) = self.replacements.place_fragments(l.into()) { + for (_, _, fl) in final_locals { self.patch.add_statement(location, StatementKind::StorageLive(fl)); } statement.make_nop(); @@ -259,8 +259,8 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { return; } StatementKind::StorageDead(l) => { - if let Some(final_locals) = &self.replacements.fragments[l] { - for &(_, fl) in final_locals { + if let Some(final_locals) = self.replacements.place_fragments(l.into()) { + for (_, _, fl) in final_locals { self.patch.add_statement(location, StatementKind::StorageDead(fl)); } statement.make_nop(); @@ -269,7 +269,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { } StatementKind::Deinit(box place) => { if let Some(final_locals) = self.replacements.place_fragments(place) { - for &(_, fl) in final_locals { + for (_, _, fl) in final_locals { self.patch .add_statement(location, StatementKind::Deinit(Box::new(fl.into()))); } @@ -278,48 +278,80 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { } } - StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref operands))) => { - if let Some(final_locals) = self.replacements.place_fragments(place) { - for &(projection, fl) in final_locals { - let &[PlaceElem::Field(index, _)] = projection else { bug!() }; - let index = index.as_usize(); - let rvalue = Rvalue::Use(operands[index].clone()); - self.patch.add_statement( - location, - StatementKind::Assign(Box::new((fl.into(), rvalue))), - ); + // We have `a = Struct { 0: x, 1: y, .. }`. + // We replace it by + // ``` + // a_0 = x + // a_1 = y + // ... + // ``` + StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => { + if let Some(local) = place.as_local() + && let Some(final_locals) = &self.replacements.fragments[local] + { + // This is ok as we delete the statement later. + let operands = std::mem::take(operands); + for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) { + if let Some((_, new_local)) = opt_ty_local { + // Replace mentions of SROA'd locals that appear in the operand. + self.visit_operand(&mut operand, location); + + let rvalue = Rvalue::Use(operand); + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } } statement.make_nop(); return; } } + // We have `a = some constant` + // We add the projections. + // ``` + // a_0 = a.0 + // a_1 = a.1 + // ... + // ``` + // ConstProp will pick up the pieces and replace them by actual constants. StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => { if let Some(final_locals) = self.replacements.place_fragments(place) { - for &(projection, fl) in final_locals { - let rvalue = - Rvalue::Use(Operand::Move(place.project_deeper(projection, self.tcx))); + for (field, ty, new_local) in final_locals { + let rplace = self.tcx.mk_place_field(place, field, ty); + let rvalue = Rvalue::Use(Operand::Move(rplace)); self.patch.add_statement( location, - StatementKind::Assign(Box::new((fl.into(), rvalue))), + StatementKind::Assign(Box::new((new_local.into(), rvalue))), ); } - self.all_dead_locals.remove(place.local); + // We still need `place.local` to exist, so don't make it nop. return; } } + // We have `a = move? place` + // We replace it by + // ``` + // a_0 = move? place.0 + // a_1 = move? place.1 + // ... + // ``` StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => { - let (rplace, copy) = match op { + let (rplace, copy) = match *op { Operand::Copy(rplace) => (rplace, true), Operand::Move(rplace) => (rplace, false), Operand::Constant(_) => bug!(), }; if let Some(final_locals) = self.replacements.place_fragments(lhs) { - for &(projection, fl) in final_locals { - let rplace = rplace.project_deeper(projection, self.tcx); + for (field, ty, new_local) in final_locals { + let rplace = self.tcx.mk_place_field(rplace, field, ty); debug!(?rplace); - let rplace = self.replace_place(rplace.as_ref()).unwrap_or(rplace); + let rplace = self + .replacements + .replace_place(self.tcx, rplace.as_ref()) + .unwrap_or(rplace); debug!(?rplace); let rvalue = if copy { Rvalue::Use(Operand::Copy(rplace)) @@ -328,7 +360,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { }; self.patch.add_statement( location, - StatementKind::Assign(Box::new((fl.into(), rvalue))), + StatementKind::Assign(Box::new((new_local.into(), rvalue))), ); } statement.make_nop(); @@ -341,22 +373,14 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { self.super_statement(statement, location) } - fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { - if let Some(repl) = self.replace_place(place.as_ref()) { - *place = repl - } else { - self.super_place(place, context, location) - } - } - #[instrument(level = "trace", skip(self))] fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) { match &mut var_debug_info.value { VarDebugInfoContents::Place(ref mut place) => { - if let Some(repl) = self.replace_place(place.as_ref()) { + if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) { *place = repl; - } else if let Some(fragments) = - self.replacements.gather_debug_info_fragments(place.as_ref()) + } else if let Some(local) = place.as_local() + && let Some(fragments) = self.gather_debug_info_fragments(local) { let ty = place.ty(self.local_decls, self.tcx).ty; var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments }; @@ -367,12 +391,13 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { debug!(?fragments); fragments .drain_filter(|fragment| { - if let Some(repl) = self.replace_place(fragment.contents.as_ref()) { + if let Some(repl) = + self.replacements.replace_place(self.tcx, fragment.contents.as_ref()) + { fragment.contents = repl; false - } else if let Some(frg) = self - .replacements - .gather_debug_info_fragments(fragment.contents.as_ref()) + } else if let Some(local) = fragment.contents.as_local() + && let Some(frg) = self.gather_debug_info_fragments(local) { new_fragments.extend(frg.into_iter().map(|mut f| { f.projection.splice(0..0, fragment.projection.iter().copied());