generalize bitvector code into a bitmatrix; write some unit tests, but

probably not enough. This code is so simple, what could possibly go
wrong?
This commit is contained in:
Niko Matsakis 2015-08-18 17:36:32 -04:00
parent 6c11e4a48e
commit 5448de72c2

View file

@ -15,26 +15,193 @@ pub struct BitVector {
impl BitVector {
pub fn new(num_bits: usize) -> BitVector {
let num_words = (num_bits + 63) / 64;
let num_words = u64s(num_bits);
BitVector { data: vec![0; num_words] }
}
fn word_mask(&self, bit: usize) -> (usize, u64) {
let word = bit / 64;
let mask = 1 << (bit % 64);
(word, mask)
}
pub fn contains(&self, bit: usize) -> bool {
let (word, mask) = self.word_mask(bit);
let (word, mask) = word_mask(bit);
(self.data[word] & mask) != 0
}
pub fn insert(&mut self, bit: usize) -> bool {
let (word, mask) = self.word_mask(bit);
let (word, mask) = word_mask(bit);
let data = &mut self.data[word];
let value = *data;
*data = value | mask;
(value | mask) != value
}
pub fn insert_all(&mut self, all: &BitVector) -> bool {
assert!(self.data.len() == all.data.len());
let mut changed = false;
for (i, j) in self.data.iter_mut().zip(&all.data) {
let value = *i;
*i = value | *j;
if value != *i { changed = true; }
}
changed
}
pub fn grow(&mut self, num_bits: usize) {
let num_words = u64s(num_bits);
let extra_words = self.data.len() - num_words;
self.data.extend((0..extra_words).map(|_| 0));
}
}
/// A "bit matrix" is basically a square matrix of booleans
/// represented as one gigantic bitvector. In other words, it is as if
/// you have N bitvectors, each of length N.
#[derive(Clone)]
pub struct BitMatrix {
elements: usize,
vector: Vec<u64>,
}
impl BitMatrix {
pub fn new(elements: usize) -> BitMatrix {
// For every element, we need one bit for every other
// element. Round up to an even number of u64s.
let u64s_per_elem = u64s(elements);
BitMatrix {
elements: elements,
vector: vec![0; elements * u64s_per_elem]
}
}
/// The range of bits for a given element.
fn range(&self, element: usize) -> (usize, usize) {
let u64s_per_elem = u64s(self.elements);
let start = element * u64s_per_elem;
(start, start + u64s_per_elem)
}
pub fn add(&mut self, source: usize, target: usize) -> bool {
let (start, _) = self.range(source);
let (word, mask) = word_mask(target);
let mut vector = &mut self.vector[..];
let v1 = vector[start+word];
let v2 = v1 | mask;
vector[start+word] = v2;
v1 != v2
}
/// Do the bits from `source` contain `target`?
/// Put another way, can `source` reach `target`?
pub fn contains(&self, source: usize, target: usize) -> bool {
let (start, _) = self.range(source);
let (word, mask) = word_mask(target);
(self.vector[start+word] & mask) != 0
}
/// Returns those indices that are reachable from both source and
/// target. This is an O(n) operation where `n` is the number of
/// elements (somewhat independent from the actual size of the
/// intersection, in particular).
pub fn intersection(&self, a: usize, b: usize) -> Vec<usize> {
let (a_start, a_end) = self.range(a);
let (b_start, b_end) = self.range(b);
let mut result = Vec::with_capacity(self.elements);
for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() {
let mut v = self.vector[i] & self.vector[j];
for bit in 0..64 {
if v == 0 { break; }
if v & 0x1 != 0 { result.push(base*64 + bit); }
v >>= 1;
}
}
result
}
/// Add the bits from source to the bits from destination,
/// return true if anything changed.
///
/// This is used when computing reachability because if you have
/// an edge `destination -> source`, because in that case
/// `destination` can reach everything that `source` can (and
/// potentially more).
pub fn merge(&mut self, source: usize, destination: usize) -> bool {
let (source_start, source_end) = self.range(source);
let (destination_start, destination_end) = self.range(destination);
let vector = &mut self.vector[..];
let mut changed = false;
for (source_index, destination_index) in
(source_start..source_end).zip(destination_start..destination_end)
{
let v1 = vector[destination_index];
let v2 = v1 | vector[source_index];
vector[destination_index] = v2;
changed = changed | (v1 != v2);
}
changed
}
}
fn u64s(elements: usize) -> usize {
(elements + 63) / 64
}
fn word_mask(index: usize) -> (usize, u64) {
let word = index / 64;
let mask = 1 << (index % 64);
(word, mask)
}
#[test]
fn union_two_vecs() {
let mut vec1 = BitVector::new(65);
let mut vec2 = BitVector::new(65);
assert!(vec1.insert(3));
assert!(!vec1.insert(3));
assert!(vec2.insert(5));
assert!(vec2.insert(64));
assert!(vec1.insert_all(&vec2));
assert!(!vec1.insert_all(&vec2));
assert!(vec1.contains(3));
assert!(!vec1.contains(4));
assert!(vec1.contains(5));
assert!(!vec1.contains(63));
assert!(vec1.contains(64));
}
#[test]
fn grow() {
let mut vec1 = BitVector::new(65);
assert!(vec1.insert(3));
assert!(!vec1.insert(3));
assert!(vec1.insert(5));
assert!(vec1.insert(64));
vec1.grow(128);
assert!(vec1.contains(3));
assert!(vec1.contains(5));
assert!(vec1.contains(64));
assert!(!vec1.contains(126));
}
#[test]
fn matrix_intersection() {
let mut vec1 = BitMatrix::new(200);
vec1.add(2, 3);
vec1.add(2, 6);
vec1.add(2, 10);
vec1.add(2, 64);
vec1.add(2, 65);
vec1.add(2, 130);
vec1.add(2, 160);
vec1.add(65, 2);
vec1.add(65, 8);
vec1.add(65, 10); // X
vec1.add(65, 64); // X
vec1.add(65, 68);
vec1.add(65, 133);
vec1.add(65, 160); // X
let intersection = vec1.intersection(2, 64);
assert!(intersection.is_empty());
let intersection = vec1.intersection(2, 65);
assert_eq!(intersection, vec![10, 64, 160]);
}