Rollup merge of #57043 - ssomers:master, r=alexcrichton

Fix poor worst case performance of set intersection

Specifically, intersection of asymmetrically sized sets when the large set is on the left. See also the [latest answer on stackoverflow](https://stackoverflow.com/questions/35439376/python-set-intersection-is-faster-then-rust-hashset-intersection).

Also applied to the union member, where the effect is much less but still measurable.

Formatted the changed code only, does not increase the error count reported by tidy check, and tried to adhere to the spirit of the unit tests.
This commit is contained in:
Mazdak Farrokhzad 2019-01-14 20:31:51 +01:00 committed by GitHub
commit 5bc95de47d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -410,9 +410,16 @@ impl<T, S> HashSet<T, S>
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn intersection<'a>(&'a self, other: &'a HashSet<T, S>) -> Intersection<'a, T, S> {
Intersection {
iter: self.iter(),
other,
if self.len() <= other.len() {
Intersection {
iter: self.iter(),
other,
}
} else {
Intersection {
iter: other.iter(),
other: self,
}
}
}
@ -436,7 +443,15 @@ impl<T, S> HashSet<T, S>
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn union<'a>(&'a self, other: &'a HashSet<T, S>) -> Union<'a, T, S> {
Union { iter: self.iter().chain(other.difference(self)) }
if self.len() <= other.len() {
Union {
iter: self.iter().chain(other.difference(self)),
}
} else {
Union {
iter: other.iter().chain(self.difference(other)),
}
}
}
/// Returns the number of elements in the set.
@ -584,7 +599,11 @@ impl<T, S> HashSet<T, S>
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn is_disjoint(&self, other: &HashSet<T, S>) -> bool {
self.iter().all(|v| !other.contains(v))
if self.len() <= other.len() {
self.iter().all(|v| !other.contains(v))
} else {
other.iter().all(|v| !self.contains(v))
}
}
/// Returns `true` if the set is a subset of another,
@ -1494,6 +1513,7 @@ mod test_set {
fn test_intersection() {
let mut a = HashSet::new();
let mut b = HashSet::new();
assert!(a.intersection(&b).next().is_none());
assert!(a.insert(11));
assert!(a.insert(1));
@ -1518,6 +1538,22 @@ mod test_set {
i += 1
}
assert_eq!(i, expected.len());
assert!(a.insert(9)); // make a bigger than b
i = 0;
for x in a.intersection(&b) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());
i = 0;
for x in b.intersection(&a) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());
}
#[test]
@ -1573,11 +1609,11 @@ mod test_set {
fn test_union() {
let mut a = HashSet::new();
let mut b = HashSet::new();
assert!(a.union(&b).next().is_none());
assert!(b.union(&a).next().is_none());
assert!(a.insert(1));
assert!(a.insert(3));
assert!(a.insert(5));
assert!(a.insert(9));
assert!(a.insert(11));
assert!(a.insert(16));
assert!(a.insert(19));
@ -1597,6 +1633,23 @@ mod test_set {
i += 1
}
assert_eq!(i, expected.len());
assert!(a.insert(9)); // make a bigger than b
assert!(a.insert(5));
i = 0;
for x in a.union(&b) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());
i = 0;
for x in b.union(&a) {
assert!(expected.contains(x));
i += 1
}
assert_eq!(i, expected.len());
}
#[test]