Auto merge of #149495 - scottmcm:assume-filter-count, r=Mark-Simulacrum
Assume the returned value in `.filter(…).count()` Similar to how this helps in `slice::Iter::position`, LLVM sometimes loses track of how high this can get, so for `TrustedLen` iterators tell it what the upper bound is.
This commit is contained in:
commit
da2544bfbe
3 changed files with 107 additions and 2 deletions
|
|
@ -4,7 +4,7 @@ use core::ops::ControlFlow;
|
|||
|
||||
use crate::fmt;
|
||||
use crate::iter::adapters::SourceIter;
|
||||
use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused};
|
||||
use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused, TrustedLen};
|
||||
use crate::num::NonZero;
|
||||
use crate::ops::Try;
|
||||
|
||||
|
|
@ -138,7 +138,13 @@ where
|
|||
move |x| predicate(&x) as usize
|
||||
}
|
||||
|
||||
self.iter.map(to_usize(self.predicate)).sum()
|
||||
let before = self.iter.size_hint().1.unwrap_or(usize::MAX);
|
||||
let total = self.iter.map(to_usize(self.predicate)).sum();
|
||||
// SAFETY: `total` and `before` came from the same iterator of type `I`
|
||||
unsafe {
|
||||
<I as SpecAssumeCount>::assume_count_le_upper_bound(total, before);
|
||||
}
|
||||
total
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
|
@ -214,3 +220,34 @@ unsafe impl<I: InPlaceIterable, P> InPlaceIterable for Filter<I, P> {
|
|||
const EXPAND_BY: Option<NonZero<usize>> = I::EXPAND_BY;
|
||||
const MERGE_BY: Option<NonZero<usize>> = I::MERGE_BY;
|
||||
}
|
||||
|
||||
trait SpecAssumeCount {
|
||||
/// # Safety
|
||||
///
|
||||
/// `count` must be an number of items actually read from the iterator.
|
||||
///
|
||||
/// `upper` must either:
|
||||
/// - have come from `size_hint().1` on the iterator, or
|
||||
/// - be `usize::MAX` which will vacuously do nothing.
|
||||
unsafe fn assume_count_le_upper_bound(count: usize, upper: usize);
|
||||
}
|
||||
|
||||
impl<I: Iterator> SpecAssumeCount for I {
|
||||
#[inline]
|
||||
#[rustc_inherit_overflow_checks]
|
||||
default unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
|
||||
// In the default we can't trust the `upper` for soundness
|
||||
// because it came from an untrusted `size_hint`.
|
||||
|
||||
// In debug mode we might as well check that the size_hint wasn't too small
|
||||
let _ = upper - count;
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: TrustedLen> SpecAssumeCount for I {
|
||||
#[inline]
|
||||
unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
|
||||
// SAFETY: The `upper` is trusted because it came from a `TrustedLen` iterator.
|
||||
unsafe { crate::hint::assert_unchecked(count <= upper) }
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
//@ compile-flags: -Copt-level=3
|
||||
//@ edition: 2024
|
||||
|
||||
#![crate_type = "lib"]
|
||||
|
||||
// Similar to how we `assume` that `slice::Iter::position` is within the length,
|
||||
// check that `count` also does that for `TrustedLen` iterators.
|
||||
// See https://rust-lang.zulipchat.com/#narrow/channel/122651-general/topic/Overflow-chk.20removed.20for.20array.20of.2059.2C.20but.20not.2060.2C.20elems/with/561070780
|
||||
|
||||
// CHECK-LABEL: @filter_count_untrusted
|
||||
#[unsafe(no_mangle)]
|
||||
pub fn filter_count_untrusted(bar: &[u8; 1234]) -> u16 {
|
||||
// CHECK-NOT: llvm.assume
|
||||
// CHECK: call void @{{.+}}unwrap_failed
|
||||
// CHECK-NOT: llvm.assume
|
||||
let mut iter = bar.iter();
|
||||
let iter = std::iter::from_fn(|| iter.next()); // Make it not TrustedLen
|
||||
u16::try_from(iter.filter(|v| **v == 0).count()).unwrap()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @filter_count_trusted
|
||||
#[unsafe(no_mangle)]
|
||||
pub fn filter_count_trusted(bar: &[u8; 1234]) -> u16 {
|
||||
// CHECK-NOT: unwrap_failed
|
||||
// CHECK: %[[ASSUME:.+]] = icmp ult {{i64|i32|i16}} %{{.+}}, 1235
|
||||
// CHECK-NEXT: tail call void @llvm.assume(i1 %[[ASSUME]])
|
||||
// CHECK-NOT: unwrap_failed
|
||||
let iter = bar.iter();
|
||||
u16::try_from(iter.filter(|v| **v == 0).count()).unwrap()
|
||||
}
|
||||
|
||||
// CHECK: ; core::result::unwrap_failed
|
||||
// CHECK-NEXT: Function Attrs
|
||||
// CHECK-NEXT: declare{{.+}}void @{{.+}}unwrap_failed
|
||||
34
tests/ui/iterators/iter-filter-count-debug-check.rs
Normal file
34
tests/ui/iterators/iter-filter-count-debug-check.rs
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
//@ run-pass
|
||||
//@ needs-unwind
|
||||
//@ ignore-backends: gcc
|
||||
//@ compile-flags: -C overflow-checks
|
||||
|
||||
use std::panic;
|
||||
|
||||
struct Lies(usize);
|
||||
|
||||
impl Iterator for Lies {
|
||||
type Item = usize;
|
||||
|
||||
fn next(&mut self) -> Option<usize> {
|
||||
if self.0 == 0 {
|
||||
None
|
||||
} else {
|
||||
self.0 -= 1;
|
||||
Some(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
(0, Some(2))
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let r = panic::catch_unwind(|| {
|
||||
// This returns more items than its `size_hint` said was possible,
|
||||
// which `Filter::count` detects via `overflow-checks`.
|
||||
let _ = Lies(10).filter(|&x| x > 3).count();
|
||||
});
|
||||
assert!(r.is_err());
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue