diff --git a/compiler/rustc_mir_build/src/thir/pattern/_match.rs b/compiler/rustc_mir_build/src/thir/pattern/_match.rs index ad94740c1606..1cbfc73a9c6d 100644 --- a/compiler/rustc_mir_build/src/thir/pattern/_match.rs +++ b/compiler/rustc_mir_build/src/thir/pattern/_match.rs @@ -276,7 +276,7 @@ use self::Usefulness::*; use self::WitnessPreference::*; use rustc_data_structures::captures::Captures; -use rustc_data_structures::fx::FxHashSet; +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_index::vec::Idx; use super::{compare_const_vals, PatternFoldable, PatternFolder}; @@ -504,13 +504,27 @@ impl<'p, 'tcx> FromIterator<&'p Pat<'tcx>> for PatStack<'p, 'tcx> { } } +/// Depending on the match patterns, the specialization process might be able to use a fast path. +/// Tracks whether we can use the fast path and the lookup table needed in those cases. +#[derive(Clone, Debug)] +enum SpecializationCache { + /// Patterns consist of only enum variants. + Variants { lookup: FxHashMap>, wilds: SmallVec<[usize; 1]> }, + /// Does not belong to the cases above, use the slow path. + Incompatible, +} + /// A 2D matrix. #[derive(Clone)] -crate struct Matrix<'p, 'tcx>(Vec>); +crate struct Matrix<'p, 'tcx> { + patterns: Vec>, + cache: SpecializationCache, +} impl<'p, 'tcx> Matrix<'p, 'tcx> { crate fn empty() -> Self { - Matrix(vec![]) + // Use SpecializationCache::Incompatible as a placeholder; the initialization is in push(). + Matrix { patterns: vec![], cache: SpecializationCache::Incompatible } } /// Pushes a new row to the matrix. If the row starts with an or-pattern, this expands it. @@ -522,18 +536,65 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> { self.push(row) } } else { - self.0.push(row); + if self.patterns.is_empty() { + self.cache = if row.is_empty() { + SpecializationCache::Incompatible + } else { + match *row.head().kind { + PatKind::Variant { .. } => SpecializationCache::Variants { + lookup: FxHashMap::default(), + wilds: SmallVec::new(), + }, + // Note: If the first pattern is a wildcard, then all patterns after that is not + // useful. The check is simple enough so we treat it as the same as unsupported + // patterns. + _ => SpecializationCache::Incompatible, + } + }; + } + let idx_to_insert = self.patterns.len(); + match &mut self.cache { + SpecializationCache::Variants { ref mut lookup, ref mut wilds } => { + let head = row.head(); + match *head.kind { + _ if head.is_wildcard() => { + for (_, v) in lookup.iter_mut() { + v.push(idx_to_insert); + } + wilds.push(idx_to_insert); + } + PatKind::Variant { adt_def, variant_index, .. } => { + lookup + .entry(adt_def.variants[variant_index].def_id) + .or_insert_with(|| wilds.clone()) + .push(idx_to_insert); + } + _ => { + self.cache = SpecializationCache::Incompatible; + } + } + } + SpecializationCache::Incompatible => {} + } + self.patterns.push(row); } } /// Iterate over the first component of each row fn heads<'a>(&'a self) -> impl Iterator> + Captures<'p> { - self.0.iter().map(|r| r.head()) + self.patterns.iter().map(|r| r.head()) } /// This computes `D(self)`. See top of the file for explanations. fn specialize_wildcard(&self) -> Self { - self.0.iter().filter_map(|r| r.specialize_wildcard()).collect() + match &self.cache { + SpecializationCache::Variants { wilds, .. } => { + wilds.iter().filter_map(|&i| self.patterns[i].specialize_wildcard()).collect() + } + SpecializationCache::Incompatible => { + self.patterns.iter().filter_map(|r| r.specialize_wildcard()).collect() + } + } } /// This computes `S(constructor, self)`. See top of the file for explanations. @@ -543,10 +604,31 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> { constructor: &Constructor<'tcx>, ctor_wild_subpatterns: &Fields<'p, 'tcx>, ) -> Matrix<'p, 'tcx> { - self.0 - .iter() - .filter_map(|r| r.specialize_constructor(cx, constructor, ctor_wild_subpatterns)) - .collect() + match &self.cache { + SpecializationCache::Variants { lookup, wilds } => { + if let Constructor::Variant(id) = constructor { + lookup + .get(id) + .unwrap_or(&wilds) + .iter() + .filter_map(|&i| { + self.patterns[i].specialize_constructor( + cx, + constructor, + ctor_wild_subpatterns, + ) + }) + .collect() + } else { + unreachable!() + } + } + SpecializationCache::Incompatible => self + .patterns + .iter() + .filter_map(|r| r.specialize_constructor(cx, constructor, ctor_wild_subpatterns)) + .collect(), + } } } @@ -568,7 +650,7 @@ impl<'p, 'tcx> fmt::Debug for Matrix<'p, 'tcx> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "\n")?; - let &Matrix(ref m) = self; + let Matrix { patterns: m, .. } = self; let pretty_printed_matrix: Vec> = m.iter().map(|row| row.iter().map(|pat| format!("{:?}", pat)).collect()).collect(); @@ -1824,7 +1906,7 @@ crate fn is_useful<'p, 'tcx>( is_under_guard: bool, is_top_level: bool, ) -> Usefulness<'tcx> { - let &Matrix(ref rows) = matrix; + let Matrix { patterns: rows, .. } = matrix; debug!("is_useful({:#?}, {:#?})", matrix, v); // The base case. We are pattern-matching on () and the return value is @@ -2266,7 +2348,7 @@ fn split_grouped_constructors<'p, 'tcx>( // `borders` is the set of borders between equivalence classes: each equivalence // class lies between 2 borders. let row_borders = matrix - .0 + .patterns .iter() .flat_map(|row| { IntRange::from_pat(tcx, param_env, row.head()).map(|r| (r, row.len()))