generalize bitvector code into a bitmatrix; write some unit tests, but
probably not enough. This code is so simple, what could possibly go wrong?
This commit is contained in:
parent
6c11e4a48e
commit
5448de72c2
1 changed files with 176 additions and 9 deletions
|
|
@ -15,26 +15,193 @@ pub struct BitVector {
|
|||
|
||||
impl BitVector {
|
||||
pub fn new(num_bits: usize) -> BitVector {
|
||||
let num_words = (num_bits + 63) / 64;
|
||||
let num_words = u64s(num_bits);
|
||||
BitVector { data: vec![0; num_words] }
|
||||
}
|
||||
|
||||
fn word_mask(&self, bit: usize) -> (usize, u64) {
|
||||
let word = bit / 64;
|
||||
let mask = 1 << (bit % 64);
|
||||
(word, mask)
|
||||
}
|
||||
|
||||
pub fn contains(&self, bit: usize) -> bool {
|
||||
let (word, mask) = self.word_mask(bit);
|
||||
let (word, mask) = word_mask(bit);
|
||||
(self.data[word] & mask) != 0
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, bit: usize) -> bool {
|
||||
let (word, mask) = self.word_mask(bit);
|
||||
let (word, mask) = word_mask(bit);
|
||||
let data = &mut self.data[word];
|
||||
let value = *data;
|
||||
*data = value | mask;
|
||||
(value | mask) != value
|
||||
}
|
||||
|
||||
pub fn insert_all(&mut self, all: &BitVector) -> bool {
|
||||
assert!(self.data.len() == all.data.len());
|
||||
let mut changed = false;
|
||||
for (i, j) in self.data.iter_mut().zip(&all.data) {
|
||||
let value = *i;
|
||||
*i = value | *j;
|
||||
if value != *i { changed = true; }
|
||||
}
|
||||
changed
|
||||
}
|
||||
|
||||
pub fn grow(&mut self, num_bits: usize) {
|
||||
let num_words = u64s(num_bits);
|
||||
let extra_words = self.data.len() - num_words;
|
||||
self.data.extend((0..extra_words).map(|_| 0));
|
||||
}
|
||||
}
|
||||
|
||||
/// A "bit matrix" is basically a square matrix of booleans
|
||||
/// represented as one gigantic bitvector. In other words, it is as if
|
||||
/// you have N bitvectors, each of length N.
|
||||
#[derive(Clone)]
|
||||
pub struct BitMatrix {
|
||||
elements: usize,
|
||||
vector: Vec<u64>,
|
||||
}
|
||||
|
||||
impl BitMatrix {
|
||||
pub fn new(elements: usize) -> BitMatrix {
|
||||
// For every element, we need one bit for every other
|
||||
// element. Round up to an even number of u64s.
|
||||
let u64s_per_elem = u64s(elements);
|
||||
BitMatrix {
|
||||
elements: elements,
|
||||
vector: vec![0; elements * u64s_per_elem]
|
||||
}
|
||||
}
|
||||
|
||||
/// The range of bits for a given element.
|
||||
fn range(&self, element: usize) -> (usize, usize) {
|
||||
let u64s_per_elem = u64s(self.elements);
|
||||
let start = element * u64s_per_elem;
|
||||
(start, start + u64s_per_elem)
|
||||
}
|
||||
|
||||
pub fn add(&mut self, source: usize, target: usize) -> bool {
|
||||
let (start, _) = self.range(source);
|
||||
let (word, mask) = word_mask(target);
|
||||
let mut vector = &mut self.vector[..];
|
||||
let v1 = vector[start+word];
|
||||
let v2 = v1 | mask;
|
||||
vector[start+word] = v2;
|
||||
v1 != v2
|
||||
}
|
||||
|
||||
/// Do the bits from `source` contain `target`?
|
||||
/// Put another way, can `source` reach `target`?
|
||||
pub fn contains(&self, source: usize, target: usize) -> bool {
|
||||
let (start, _) = self.range(source);
|
||||
let (word, mask) = word_mask(target);
|
||||
(self.vector[start+word] & mask) != 0
|
||||
}
|
||||
|
||||
/// Returns those indices that are reachable from both source and
|
||||
/// target. This is an O(n) operation where `n` is the number of
|
||||
/// elements (somewhat independent from the actual size of the
|
||||
/// intersection, in particular).
|
||||
pub fn intersection(&self, a: usize, b: usize) -> Vec<usize> {
|
||||
let (a_start, a_end) = self.range(a);
|
||||
let (b_start, b_end) = self.range(b);
|
||||
let mut result = Vec::with_capacity(self.elements);
|
||||
for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() {
|
||||
let mut v = self.vector[i] & self.vector[j];
|
||||
for bit in 0..64 {
|
||||
if v == 0 { break; }
|
||||
if v & 0x1 != 0 { result.push(base*64 + bit); }
|
||||
v >>= 1;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Add the bits from source to the bits from destination,
|
||||
/// return true if anything changed.
|
||||
///
|
||||
/// This is used when computing reachability because if you have
|
||||
/// an edge `destination -> source`, because in that case
|
||||
/// `destination` can reach everything that `source` can (and
|
||||
/// potentially more).
|
||||
pub fn merge(&mut self, source: usize, destination: usize) -> bool {
|
||||
let (source_start, source_end) = self.range(source);
|
||||
let (destination_start, destination_end) = self.range(destination);
|
||||
let vector = &mut self.vector[..];
|
||||
let mut changed = false;
|
||||
for (source_index, destination_index) in
|
||||
(source_start..source_end).zip(destination_start..destination_end)
|
||||
{
|
||||
let v1 = vector[destination_index];
|
||||
let v2 = v1 | vector[source_index];
|
||||
vector[destination_index] = v2;
|
||||
changed = changed | (v1 != v2);
|
||||
}
|
||||
changed
|
||||
}
|
||||
}
|
||||
|
||||
fn u64s(elements: usize) -> usize {
|
||||
(elements + 63) / 64
|
||||
}
|
||||
|
||||
fn word_mask(index: usize) -> (usize, u64) {
|
||||
let word = index / 64;
|
||||
let mask = 1 << (index % 64);
|
||||
(word, mask)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn union_two_vecs() {
|
||||
let mut vec1 = BitVector::new(65);
|
||||
let mut vec2 = BitVector::new(65);
|
||||
assert!(vec1.insert(3));
|
||||
assert!(!vec1.insert(3));
|
||||
assert!(vec2.insert(5));
|
||||
assert!(vec2.insert(64));
|
||||
assert!(vec1.insert_all(&vec2));
|
||||
assert!(!vec1.insert_all(&vec2));
|
||||
assert!(vec1.contains(3));
|
||||
assert!(!vec1.contains(4));
|
||||
assert!(vec1.contains(5));
|
||||
assert!(!vec1.contains(63));
|
||||
assert!(vec1.contains(64));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn grow() {
|
||||
let mut vec1 = BitVector::new(65);
|
||||
assert!(vec1.insert(3));
|
||||
assert!(!vec1.insert(3));
|
||||
assert!(vec1.insert(5));
|
||||
assert!(vec1.insert(64));
|
||||
vec1.grow(128);
|
||||
assert!(vec1.contains(3));
|
||||
assert!(vec1.contains(5));
|
||||
assert!(vec1.contains(64));
|
||||
assert!(!vec1.contains(126));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matrix_intersection() {
|
||||
let mut vec1 = BitMatrix::new(200);
|
||||
|
||||
vec1.add(2, 3);
|
||||
vec1.add(2, 6);
|
||||
vec1.add(2, 10);
|
||||
vec1.add(2, 64);
|
||||
vec1.add(2, 65);
|
||||
vec1.add(2, 130);
|
||||
vec1.add(2, 160);
|
||||
|
||||
vec1.add(65, 2);
|
||||
vec1.add(65, 8);
|
||||
vec1.add(65, 10); // X
|
||||
vec1.add(65, 64); // X
|
||||
vec1.add(65, 68);
|
||||
vec1.add(65, 133);
|
||||
vec1.add(65, 160); // X
|
||||
|
||||
let intersection = vec1.intersection(2, 64);
|
||||
assert!(intersection.is_empty());
|
||||
|
||||
let intersection = vec1.intersection(2, 65);
|
||||
assert_eq!(intersection, vec![10, 64, 160]);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue