Implement union, intersection, and difference functions for TrieSet.

This commit is contained in:
Chase Southwood 2014-11-25 02:15:28 -06:00
parent 7222ba9650
commit 2a6f197bf4

View file

@ -9,7 +9,6 @@
// except according to those terms.
// FIXME(conventions): implement bounded iterators
// FIXME(conventions): implement union family of fns
// FIXME(conventions): implement BitOr, BitAnd, BitXor, and Sub
// FIXME(conventions): replace each_reverse by making iter DoubleEnded
// FIXME(conventions): implement iter_mut and into_iter
@ -19,6 +18,7 @@ use core::prelude::*;
use core::default::Default;
use core::fmt;
use core::fmt::Show;
use core::iter::Peekable;
use std::hash::Hash;
use trie_map::{TrieMap, Entries};
@ -172,6 +172,106 @@ impl TrieSet {
SetItems{iter: self.map.upper_bound(val)}
}
/// Visits the values representing the difference, in ascending order.
///
/// # Example
///
/// ```
/// use std::collections::TrieSet;
///
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
/// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect();
///
/// // Can be seen as `a - b`.
/// for x in a.difference(&b) {
/// println!("{}", x); // Print 1 then 2
/// }
///
/// let diff1: TrieSet = a.difference(&b).collect();
/// assert_eq!(diff1, [1, 2].iter().map(|&x| x).collect());
///
/// // Note that difference is not symmetric,
/// // and `b - a` means something else:
/// let diff2: TrieSet = b.difference(&a).collect();
/// assert_eq!(diff2, [4, 5].iter().map(|&x| x).collect());
/// ```
#[unstable = "matches collection reform specification, waiting for dust to settle"]
pub fn difference<'a>(&'a self, other: &'a TrieSet) -> DifferenceItems<'a> {
DifferenceItems{a: self.iter().peekable(), b: other.iter().peekable()}
}
/// Visits the values representing the symmetric difference, in ascending order.
///
/// # Example
///
/// ```
/// use std::collections::TrieSet;
///
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
/// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect();
///
/// // Print 1, 2, 4, 5 in ascending order.
/// for x in a.symmetric_difference(&b) {
/// println!("{}", x);
/// }
///
/// let diff1: TrieSet = a.symmetric_difference(&b).collect();
/// let diff2: TrieSet = b.symmetric_difference(&a).collect();
///
/// assert_eq!(diff1, diff2);
/// assert_eq!(diff1, [1, 2, 4, 5].iter().map(|&x| x).collect());
/// ```
#[unstable = "matches collection reform specification, waiting for dust to settle."]
pub fn symmetric_difference<'a>(&'a self, other: &'a TrieSet) -> SymDifferenceItems<'a> {
SymDifferenceItems{a: self.iter().peekable(), b: other.iter().peekable()}
}
/// Visits the values representing the intersection, in ascending order.
///
/// # Example
///
/// ```
/// use std::collections::TrieSet;
///
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
/// let b: TrieSet = [2, 3, 4].iter().map(|&x| x).collect();
///
/// // Print 2, 3 in ascending order.
/// for x in a.intersection(&b) {
/// println!("{}", x);
/// }
///
/// let diff: TrieSet = a.intersection(&b).collect();
/// assert_eq!(diff, [2, 3].iter().map(|&x| x).collect());
/// ```
#[unstable = "matches collection reform specification, waiting for dust to settle"]
pub fn intersection<'a>(&'a self, other: &'a TrieSet) -> IntersectionItems<'a> {
IntersectionItems{a: self.iter().peekable(), b: other.iter().peekable()}
}
/// Visits the values representing the union, in ascending order.
///
/// # Example
///
/// ```
/// use std::collections::TrieSet;
///
/// let a: TrieSet = [1, 2, 3].iter().map(|&x| x).collect();
/// let b: TrieSet = [3, 4, 5].iter().map(|&x| x).collect();
///
/// // Print 1, 2, 3, 4, 5 in ascending order.
/// for x in a.union(&b) {
/// println!("{}", x);
/// }
///
/// let diff: TrieSet = a.union(&b).collect();
/// assert_eq!(diff, [1, 2, 3, 4, 5].iter().map(|&x| x).collect());
/// ```
#[unstable = "matches collection reform specification, waiting for dust to settle"]
pub fn union<'a>(&'a self, other: &'a TrieSet) -> UnionItems<'a> {
UnionItems{a: self.iter().peekable(), b: other.iter().peekable()}
}
/// Return the number of elements in the set
///
/// # Example
@ -368,6 +468,39 @@ pub struct SetItems<'a> {
iter: Entries<'a, ()>
}
/// An iterator producing elements in the set difference (in-order).
pub struct DifferenceItems<'a> {
a: Peekable<uint, SetItems<'a>>,
b: Peekable<uint, SetItems<'a>>,
}
/// An iterator producing elements in the set symmetric difference (in-order).
pub struct SymDifferenceItems<'a> {
a: Peekable<uint, SetItems<'a>>,
b: Peekable<uint, SetItems<'a>>,
}
/// An iterator producing elements in the set intersection (in-order).
pub struct IntersectionItems<'a> {
a: Peekable<uint, SetItems<'a>>,
b: Peekable<uint, SetItems<'a>>,
}
/// An iterator producing elements in the set union (in-order).
pub struct UnionItems<'a> {
a: Peekable<uint, SetItems<'a>>,
b: Peekable<uint, SetItems<'a>>,
}
/// Compare `x` and `y`, but return `short` if x is None and `long` if y is None
fn cmp_opt(x: Option<&uint>, y: Option<&uint>, short: Ordering, long: Ordering) -> Ordering {
match (x, y) {
(None , _ ) => short,
(_ , None ) => long,
(Some(x1), Some(y1)) => x1.cmp(y1),
}
}
impl<'a> Iterator<uint> for SetItems<'a> {
fn next(&mut self) -> Option<uint> {
self.iter.next().map(|(key, _)| key)
@ -378,6 +511,60 @@ impl<'a> Iterator<uint> for SetItems<'a> {
}
}
impl<'a> Iterator<uint> for DifferenceItems<'a> {
fn next(&mut self) -> Option<uint> {
loop {
match cmp_opt(self.a.peek(), self.b.peek(), Less, Less) {
Less => return self.a.next(),
Equal => { self.a.next(); self.b.next(); }
Greater => { self.b.next(); }
}
}
}
}
impl<'a> Iterator<uint> for SymDifferenceItems<'a> {
fn next(&mut self) -> Option<uint> {
loop {
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
Less => return self.a.next(),
Equal => { self.a.next(); self.b.next(); }
Greater => return self.b.next(),
}
}
}
}
impl<'a> Iterator<uint> for IntersectionItems<'a> {
fn next(&mut self) -> Option<uint> {
loop {
let o_cmp = match (self.a.peek(), self.b.peek()) {
(None , _ ) => None,
(_ , None ) => None,
(Some(a1), Some(b1)) => Some(a1.cmp(b1)),
};
match o_cmp {
None => return None,
Some(Less) => { self.a.next(); }
Some(Equal) => { self.b.next(); return self.a.next() }
Some(Greater) => { self.b.next(); }
}
}
}
}
impl<'a> Iterator<uint> for UnionItems<'a> {
fn next(&mut self) -> Option<uint> {
loop {
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
Less => return self.a.next(),
Equal => { self.b.next(); return self.a.next() }
Greater => return self.b.next(),
}
}
}
}
#[cfg(test)]
mod test {
use std::prelude::*;
@ -471,4 +658,84 @@ mod test {
assert!(b > a && b >= a);
assert!(a < b && a <= b);
}
fn check(a: &[uint],
b: &[uint],
expected: &[uint],
f: |&TrieSet, &TrieSet, f: |uint| -> bool| -> bool) {
let mut set_a = TrieSet::new();
let mut set_b = TrieSet::new();
for x in a.iter() { assert!(set_a.insert(*x)) }
for y in b.iter() { assert!(set_b.insert(*y)) }
let mut i = 0;
f(&set_a, &set_b, |x| {
assert_eq!(x, expected[i]);
i += 1;
true
});
assert_eq!(i, expected.len());
}
#[test]
fn test_intersection() {
fn check_intersection(a: &[uint], b: &[uint], expected: &[uint]) {
check(a, b, expected, |x, y, f| x.intersection(y).all(f))
}
check_intersection(&[], &[], &[]);
check_intersection(&[1, 2, 3], &[], &[]);
check_intersection(&[], &[1, 2, 3], &[]);
check_intersection(&[2], &[1, 2, 3], &[2]);
check_intersection(&[1, 2, 3], &[2], &[2]);
check_intersection(&[11, 1, 3, 77, 103, 5],
&[2, 11, 77, 5, 3],
&[3, 5, 11, 77]);
}
#[test]
fn test_difference() {
fn check_difference(a: &[uint], b: &[uint], expected: &[uint]) {
check(a, b, expected, |x, y, f| x.difference(y).all(f))
}
check_difference(&[], &[], &[]);
check_difference(&[1, 12], &[], &[1, 12]);
check_difference(&[], &[1, 2, 3, 9], &[]);
check_difference(&[1, 3, 5, 9, 11],
&[3, 9],
&[1, 5, 11]);
check_difference(&[11, 22, 33, 40, 42],
&[14, 23, 34, 38, 39, 50],
&[11, 22, 33, 40, 42]);
}
#[test]
fn test_symmetric_difference() {
fn check_symmetric_difference(a: &[uint], b: &[uint], expected: &[uint]) {
check(a, b, expected, |x, y, f| x.symmetric_difference(y).all(f))
}
check_symmetric_difference(&[], &[], &[]);
check_symmetric_difference(&[1, 2, 3], &[2], &[1, 3]);
check_symmetric_difference(&[2], &[1, 2, 3], &[1, 3]);
check_symmetric_difference(&[1, 3, 5, 9, 11],
&[3, 9, 14, 22],
&[1, 5, 11, 14, 22]);
}
#[test]
fn test_union() {
fn check_union(a: &[uint], b: &[uint], expected: &[uint]) {
check(a, b, expected, |x, y, f| x.union(y).all(f))
}
check_union(&[], &[], &[]);
check_union(&[1, 2, 3], &[2], &[1, 2, 3]);
check_union(&[2], &[1, 2, 3], &[1, 2, 3]);
check_union(&[1, 3, 5, 9, 11, 16, 19, 24],
&[1, 5, 9, 13, 19],
&[1, 3, 5, 9, 11, 13, 16, 19, 24]);
}
}