From a0643ee9ae5726edaa382a1a125319688477ec98 Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Tue, 7 Jan 2014 00:14:37 +1100 Subject: [PATCH] std::trie: add an mutable-values iterator. --- src/libstd/trie.rs | 168 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) diff --git a/src/libstd/trie.rs b/src/libstd/trie.rs index 4f3f253d5e24..b66472c72cb7 100644 --- a/src/libstd/trie.rs +++ b/src/libstd/trie.rs @@ -156,6 +156,16 @@ impl TrieMap { } } + /// Get an iterator over the key-value pairs in the map, with the + /// ability to mutate the values. + pub fn mut_iter<'a>(&'a mut self) -> TrieMapMutIterator<'a, T> { + TrieMapMutIterator { + stack: ~[self.root.children.mut_iter()], + remaining_min: self.length, + remaining_max: self.length + } + } + // If `upper` is true then returns upper_bound else returns lower_bound. #[inline] fn bound<'a>(&'a self, key: uint, upper: bool) -> TrieMapIterator<'a, T> { @@ -202,6 +212,63 @@ impl TrieMap { pub fn upper_bound<'a>(&'a self, key: uint) -> TrieMapIterator<'a, T> { self.bound(key, true) } + // If `upper` is true then returns upper_bound else returns lower_bound. + #[inline] + fn mut_bound<'a>(&'a mut self, key: uint, upper: bool) -> TrieMapMutIterator<'a, T> { + // we need an unsafe pointer here because we are borrowing + // references to the internals of each of these + // nodes. + // + // However, we're allowed to flaunt rustc like this because we + // never actually modify the "shape" of the nodes. The only + // place that mutation is can actually occur is of the actual + // values of the TrieMap (as the return value of the + // iterator), i.e. we can never cause a deallocation of any + // TrieNodes so this pointer is always valid. + let mut node = &mut self.root as *mut TrieNode; + + let mut idx = 0; + let mut it = TrieMapMutIterator { + stack: ~[], + remaining_min: 0, + remaining_max: self.length + }; + loop { + let children = unsafe {&mut (*node).children}; + let child_id = chunk(key, idx); + match children[child_id] { + Internal(ref mut n) => { + node = &mut **n as *mut TrieNode; + } + External(stored, _) => { + if stored < key || (upper && stored == key) { + it.stack.push(children.mut_slice_from(child_id + 1).mut_iter()); + } else { + it.stack.push(children.mut_slice_from(child_id).mut_iter()); + } + return it; + } + Nothing => { + it.stack.push(children.mut_slice_from(child_id + 1).mut_iter()); + return it + } + } + it.stack.push(children.mut_slice_from(child_id + 1).mut_iter()); + idx += 1; + } + } + + /// Get an iterator pointing to the first key-value pair whose key is not less than `key`. + /// If all keys in the map are less than `key` an empty iterator is returned. + pub fn mut_lower_bound<'a>(&'a mut self, key: uint) -> TrieMapMutIterator<'a, T> { + self.mut_bound(key, false) + } + + /// Get an iterator pointing to the first key-value pair whose key is greater than `key`. + /// If all keys in the map are not greater than `key` an empty iterator is returned. + pub fn mut_upper_bound<'a>(&'a mut self, key: uint) -> TrieMapMutIterator<'a, T> { + self.mut_bound(key, true) + } } impl FromIterator<(uint, T)> for TrieMap { @@ -482,6 +549,47 @@ impl<'a, T> Iterator<(uint, &'a T)> for TrieMapIterator<'a, T> { } } +/// Forward iterator over the key-value pairs of a map, with the +/// values being mutable. +pub struct TrieMapMutIterator<'a, T> { + priv stack: ~[vec::VecMutIterator<'a, Child>], + priv remaining_min: uint, + priv remaining_max: uint +} + +impl<'a, T> Iterator<(uint, &'a mut T)> for TrieMapMutIterator<'a, T> { + fn next(&mut self) -> Option<(uint, &'a mut T)> { + while !self.stack.is_empty() { + match self.stack[self.stack.len() - 1].next() { + None => { + self.stack.pop(); + } + Some(child) => { + match *child { + Internal(ref mut node) => { + self.stack.push(node.children.mut_iter()); + } + External(key, ref mut value) => { + self.remaining_max -= 1; + if self.remaining_min > 0 { + self.remaining_min -= 1; + } + return Some((key, value)); + } + Nothing => {} + } + } + } + } + return None; + } + + #[inline] + fn size_hint(&self) -> (uint, Option) { + (self.remaining_min, Some(self.remaining_max)) + } +} + /// Forward iterator over a set pub struct TrieSetIterator<'a> { priv iter: TrieMapIterator<'a, ()> @@ -712,6 +820,30 @@ mod test_map { assert_eq!(i, last - first); } + #[test] + fn test_mut_iter() { + let mut empty_map : TrieMap = TrieMap::new(); + assert!(empty_map.mut_iter().next().is_none()); + + let first = uint::max_value - 10000; + let last = uint::max_value; + + let mut map = TrieMap::new(); + for x in range(first, last).invert() { + map.insert(x, x / 2); + } + + let mut i = 0; + for (k, v) in map.mut_iter() { + assert_eq!(k, first + i); + *v -= k / 2; + i += 1; + } + assert_eq!(i, last - first); + + assert!(map.iter().all(|(_, &v)| v == 0)); + } + #[test] fn test_bound() { let empty_map : TrieMap = TrieMap::new(); @@ -753,6 +885,42 @@ mod test_map { assert_eq!(ub.next(), None); } } + + #[test] + fn test_mut_bound() { + let empty_map : TrieMap = TrieMap::new(); + assert_eq!(empty_map.lower_bound(0).next(), None); + assert_eq!(empty_map.upper_bound(0).next(), None); + + let mut m_lower = TrieMap::new(); + let mut m_upper = TrieMap::new(); + for i in range(0u, 100) { + m_lower.insert(2 * i, 4 * i); + m_upper.insert(2 * i, 4 * i); + } + + for i in range(0u, 199) { + let mut lb_it = m_lower.mut_lower_bound(i); + let (k, v) = lb_it.next().unwrap(); + let lb = i + i % 2; + assert_eq!(lb, k); + *v -= k; + } + + for i in range(0u, 198) { + let mut ub_it = m_upper.mut_upper_bound(i); + let (k, v) = ub_it.next().unwrap(); + let ub = i + 2 - i % 2; + assert_eq!(ub, k); + *v -= k; + } + + assert!(m_lower.mut_lower_bound(199).next().is_none()); + assert!(m_upper.mut_upper_bound(198).next().is_none()); + + assert!(m_lower.iter().all(|(_, &x)| x == 0)); + assert!(m_upper.iter().all(|(_, &x)| x == 0)); + } } #[cfg(test)]