diff --git a/src/librustc_data_structures/bitvec.rs b/src/librustc_data_structures/bitvec.rs index f2f4a69d882b..a0e4f4a3f2d6 100644 --- a/src/librustc_data_structures/bitvec.rs +++ b/src/librustc_data_structures/bitvec.rs @@ -15,26 +15,193 @@ pub struct BitVector { impl BitVector { pub fn new(num_bits: usize) -> BitVector { - let num_words = (num_bits + 63) / 64; + let num_words = u64s(num_bits); BitVector { data: vec![0; num_words] } } - fn word_mask(&self, bit: usize) -> (usize, u64) { - let word = bit / 64; - let mask = 1 << (bit % 64); - (word, mask) - } - pub fn contains(&self, bit: usize) -> bool { - let (word, mask) = self.word_mask(bit); + let (word, mask) = word_mask(bit); (self.data[word] & mask) != 0 } pub fn insert(&mut self, bit: usize) -> bool { - let (word, mask) = self.word_mask(bit); + let (word, mask) = word_mask(bit); let data = &mut self.data[word]; let value = *data; *data = value | mask; (value | mask) != value } + + pub fn insert_all(&mut self, all: &BitVector) -> bool { + assert!(self.data.len() == all.data.len()); + let mut changed = false; + for (i, j) in self.data.iter_mut().zip(&all.data) { + let value = *i; + *i = value | *j; + if value != *i { changed = true; } + } + changed + } + + pub fn grow(&mut self, num_bits: usize) { + let num_words = u64s(num_bits); + let extra_words = self.data.len() - num_words; + self.data.extend((0..extra_words).map(|_| 0)); + } +} + +/// A "bit matrix" is basically a square matrix of booleans +/// represented as one gigantic bitvector. In other words, it is as if +/// you have N bitvectors, each of length N. +#[derive(Clone)] +pub struct BitMatrix { + elements: usize, + vector: Vec, +} + +impl BitMatrix { + pub fn new(elements: usize) -> BitMatrix { + // For every element, we need one bit for every other + // element. Round up to an even number of u64s. + let u64s_per_elem = u64s(elements); + BitMatrix { + elements: elements, + vector: vec![0; elements * u64s_per_elem] + } + } + + /// The range of bits for a given element. + fn range(&self, element: usize) -> (usize, usize) { + let u64s_per_elem = u64s(self.elements); + let start = element * u64s_per_elem; + (start, start + u64s_per_elem) + } + + pub fn add(&mut self, source: usize, target: usize) -> bool { + let (start, _) = self.range(source); + let (word, mask) = word_mask(target); + let mut vector = &mut self.vector[..]; + let v1 = vector[start+word]; + let v2 = v1 | mask; + vector[start+word] = v2; + v1 != v2 + } + + /// Do the bits from `source` contain `target`? + /// Put another way, can `source` reach `target`? + pub fn contains(&self, source: usize, target: usize) -> bool { + let (start, _) = self.range(source); + let (word, mask) = word_mask(target); + (self.vector[start+word] & mask) != 0 + } + + /// Returns those indices that are reachable from both source and + /// target. This is an O(n) operation where `n` is the number of + /// elements (somewhat independent from the actual size of the + /// intersection, in particular). + pub fn intersection(&self, a: usize, b: usize) -> Vec { + let (a_start, a_end) = self.range(a); + let (b_start, b_end) = self.range(b); + let mut result = Vec::with_capacity(self.elements); + for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() { + let mut v = self.vector[i] & self.vector[j]; + for bit in 0..64 { + if v == 0 { break; } + if v & 0x1 != 0 { result.push(base*64 + bit); } + v >>= 1; + } + } + result + } + + /// Add the bits from source to the bits from destination, + /// return true if anything changed. + /// + /// This is used when computing reachability because if you have + /// an edge `destination -> source`, because in that case + /// `destination` can reach everything that `source` can (and + /// potentially more). + pub fn merge(&mut self, source: usize, destination: usize) -> bool { + let (source_start, source_end) = self.range(source); + let (destination_start, destination_end) = self.range(destination); + let vector = &mut self.vector[..]; + let mut changed = false; + for (source_index, destination_index) in + (source_start..source_end).zip(destination_start..destination_end) + { + let v1 = vector[destination_index]; + let v2 = v1 | vector[source_index]; + vector[destination_index] = v2; + changed = changed | (v1 != v2); + } + changed + } +} + +fn u64s(elements: usize) -> usize { + (elements + 63) / 64 +} + +fn word_mask(index: usize) -> (usize, u64) { + let word = index / 64; + let mask = 1 << (index % 64); + (word, mask) +} + +#[test] +fn union_two_vecs() { + let mut vec1 = BitVector::new(65); + let mut vec2 = BitVector::new(65); + assert!(vec1.insert(3)); + assert!(!vec1.insert(3)); + assert!(vec2.insert(5)); + assert!(vec2.insert(64)); + assert!(vec1.insert_all(&vec2)); + assert!(!vec1.insert_all(&vec2)); + assert!(vec1.contains(3)); + assert!(!vec1.contains(4)); + assert!(vec1.contains(5)); + assert!(!vec1.contains(63)); + assert!(vec1.contains(64)); +} + +#[test] +fn grow() { + let mut vec1 = BitVector::new(65); + assert!(vec1.insert(3)); + assert!(!vec1.insert(3)); + assert!(vec1.insert(5)); + assert!(vec1.insert(64)); + vec1.grow(128); + assert!(vec1.contains(3)); + assert!(vec1.contains(5)); + assert!(vec1.contains(64)); + assert!(!vec1.contains(126)); +} + +#[test] +fn matrix_intersection() { + let mut vec1 = BitMatrix::new(200); + + vec1.add(2, 3); + vec1.add(2, 6); + vec1.add(2, 10); + vec1.add(2, 64); + vec1.add(2, 65); + vec1.add(2, 130); + vec1.add(2, 160); + + vec1.add(65, 2); + vec1.add(65, 8); + vec1.add(65, 10); // X + vec1.add(65, 64); // X + vec1.add(65, 68); + vec1.add(65, 133); + vec1.add(65, 160); // X + + let intersection = vec1.intersection(2, 64); + assert!(intersection.is_empty()); + + let intersection = vec1.intersection(2, 65); + assert_eq!(intersection, vec![10, 64, 160]); }