Fix poor worst case performance of set intersection (and union, somewhat) on asymmetrically sized sets and extend unit tests slightly beyond that

This commit is contained in:
Stein Somers 2018-12-21 14:56:52 +01:00
parent 01c6ea2f37
commit f9f71cc324

View file

@ -420,9 +420,16 @@ impl<T, S> HashSet<T, S>
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn intersection<'a>(&'a self, other: &'a HashSet<T, S>) -> Intersection<'a, T, S> {
Intersection {
iter: self.iter(),
other,
if self.len() <= other.len() {
Intersection {
iter: self.iter(),
other,
}
} else {
Intersection {
iter: other.iter(),
other: self,
}
}
}
@ -446,7 +453,15 @@ impl<T, S> HashSet<T, S>
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn union<'a>(&'a self, other: &'a HashSet<T, S>) -> Union<'a, T, S> {
Union { iter: self.iter().chain(other.difference(self)) }
if self.len() <= other.len() {
Union {
iter: self.iter().chain(other.difference(self)),
}
} else {
Union {
iter: other.iter().chain(self.difference(other)),
}
}
}
/// Returns the number of elements in the set.
@ -1504,6 +1519,8 @@ mod test_set {
fn test_intersection() {
let mut a = HashSet::new();
let mut b = HashSet::new();
assert!(a.intersection(&b).next().is_none());
assert!(b.intersection(&a).next().is_none());
assert!(a.insert(11));
assert!(a.insert(1));
@ -1528,6 +1545,22 @@ mod test_set {
i += 1
}
assert_eq!(i, expected.len());
assert!(a.insert(9)); // make a bigger than b
i = 0;
for x in a.intersection(&b) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());
i = 0;
for x in b.intersection(&a) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());
}
#[test]
@ -1583,11 +1616,11 @@ mod test_set {
fn test_union() {
let mut a = HashSet::new();
let mut b = HashSet::new();
assert!(a.union(&b).next().is_none());
assert!(b.union(&a).next().is_none());
assert!(a.insert(1));
assert!(a.insert(3));
assert!(a.insert(5));
assert!(a.insert(9));
assert!(a.insert(11));
assert!(a.insert(16));
assert!(a.insert(19));
@ -1607,6 +1640,23 @@ mod test_set {
i += 1
}
assert_eq!(i, expected.len());
assert!(a.insert(9)); // make a bigger than b
assert!(a.insert(5));
i = 0;
for x in a.union(&b) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());
i = 0;
for x in b.union(&a) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());
}
#[test]