Auto merge of #149495 - scottmcm:assume-filter-count, r=Mark-Simulacrum

Assume the returned value in `.filter(…).count()`

Similar to how this helps in `slice::Iter::position`, LLVM sometimes loses track of how high this can get, so for `TrustedLen` iterators tell it what the upper bound is.
This commit is contained in:
bors 2025-12-06 09:13:21 +00:00
commit da2544bfbe
3 changed files with 107 additions and 2 deletions

View file

@ -4,7 +4,7 @@ use core::ops::ControlFlow;
use crate::fmt;
use crate::iter::adapters::SourceIter;
use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused};
use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused, TrustedLen};
use crate::num::NonZero;
use crate::ops::Try;
@ -138,7 +138,13 @@ where
move |x| predicate(&x) as usize
}
self.iter.map(to_usize(self.predicate)).sum()
let before = self.iter.size_hint().1.unwrap_or(usize::MAX);
let total = self.iter.map(to_usize(self.predicate)).sum();
// SAFETY: `total` and `before` came from the same iterator of type `I`
unsafe {
<I as SpecAssumeCount>::assume_count_le_upper_bound(total, before);
}
total
}
#[inline]
@ -214,3 +220,34 @@ unsafe impl<I: InPlaceIterable, P> InPlaceIterable for Filter<I, P> {
const EXPAND_BY: Option<NonZero<usize>> = I::EXPAND_BY;
const MERGE_BY: Option<NonZero<usize>> = I::MERGE_BY;
}
/// Specialization hook used to relate a finished item count to the upper
/// bound that the iterator reported beforehand.
///
/// The blanket impl for every `Iterator` treats the bound as untrusted, while
/// the impl for `TrustedLen` iterators hands `count <= upper` to the optimizer.
trait SpecAssumeCount {
/// Relates `count` to `upper`; what is done with that relation depends on
/// which impl is selected for the iterator type.
///
/// # Safety
///
/// `count` must be a number of items actually read from the iterator.
///
/// `upper` must either:
/// - have come from `size_hint().1` on the iterator, or
/// - be `usize::MAX`, which will vacuously do nothing.
unsafe fn assume_count_le_upper_bound(count: usize, upper: usize);
}
// Fallback for every `Iterator`: the upper bound is not trusted here, so no
// assumption is made; the subtraction only serves as a sanity check.
impl<I: Iterator> SpecAssumeCount for I {
#[inline]
// NOTE(review): presumably this makes the subtraction below follow the
// overflow-check setting of the calling crate rather than that of `core`.
#[rustc_inherit_overflow_checks]
default unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
// In the default we can't trust the `upper` for soundness
// because it came from an untrusted `size_hint`.
// In debug mode we might as well check that the size_hint wasn't too small:
// `upper - count` underflows exactly when more items were counted than
// the hint allowed.
let _ = upper - count;
}
}
// Specialization for `TrustedLen` iterators: the upper bound is trusted, so
// `count <= upper` is passed to the optimizer as an unchecked assumption.
impl<I: TrustedLen> SpecAssumeCount for I {
#[inline]
unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
// SAFETY: The `upper` is trusted because it came from a `TrustedLen` iterator.
unsafe { crate::hint::assert_unchecked(count <= upper) }
}
}

View file

@ -0,0 +1,34 @@
//@ compile-flags: -Copt-level=3
//@ edition: 2024
#![crate_type = "lib"]
// Similar to how we `assume` that `slice::Iter::position` is within the length,
// check that `count` also does that for `TrustedLen` iterators.
// See https://rust-lang.zulipchat.com/#narrow/channel/122651-general/topic/Overflow-chk.20removed.20for.20array.20of.2059.2C.20but.20not.2060.2C.20elems/with/561070780
// Counts the zero bytes of `bar` through an iterator wrapped in `from_fn`,
// then converts the count to `u16`, unwrapping the conversion result.
// The FileCheck directives below expect the call to `unwrap_failed` to remain
// in the output and no `llvm.assume` to appear around it.
// CHECK-LABEL: @filter_count_untrusted
#[unsafe(no_mangle)]
pub fn filter_count_untrusted(bar: &[u8; 1234]) -> u16 {
// CHECK-NOT: llvm.assume
// CHECK: call void @{{.+}}unwrap_failed
// CHECK-NOT: llvm.assume
let mut iter = bar.iter();
let iter = std::iter::from_fn(|| iter.next()); // Make it not TrustedLen
u16::try_from(iter.filter(|v| **v == 0).count()).unwrap()
}
// Counts the zero bytes of `bar` using the array's own iterator directly,
// then converts the count to `u16`, unwrapping the conversion result.
// The FileCheck directives below expect an `llvm.assume` that the count is
// below 1235 (that is, at most the array length of 1234) and expect no
// reference to `unwrap_failed` inside this function.
// CHECK-LABEL: @filter_count_trusted
#[unsafe(no_mangle)]
pub fn filter_count_trusted(bar: &[u8; 1234]) -> u16 {
// CHECK-NOT: unwrap_failed
// CHECK: %[[ASSUME:.+]] = icmp ult {{i64|i32|i16}} %{{.+}}, 1235
// CHECK-NEXT: tail call void @llvm.assume(i1 %[[ASSUME]])
// CHECK-NOT: unwrap_failed
let iter = bar.iter();
u16::try_from(iter.filter(|v| **v == 0).count()).unwrap()
}
// CHECK: ; core::result::unwrap_failed
// CHECK-NEXT: Function Attrs
// CHECK-NEXT: declare{{.+}}void @{{.+}}unwrap_failed

View file

@ -0,0 +1,34 @@
//@ run-pass
//@ needs-unwind
//@ ignore-backends: gcc
//@ compile-flags: -C overflow-checks
use std::panic;
/// An iterator whose `size_hint` deliberately under-reports its length.
///
/// `Lies(n)` yields `n - 1`, `n - 2`, ..., `0`, yet it always claims an
/// upper bound of two items.
struct Lies(usize);

impl Iterator for Lies {
    type Item = usize;

    /// Decrements the counter and returns its new value, or `None` once the
    /// counter has reached zero.
    fn next(&mut self) -> Option<usize> {
        // `checked_sub` is `None` exactly when the counter is already zero,
        // in which case the counter is left untouched.
        let remaining = self.0.checked_sub(1)?;
        self.0 = remaining;
        Some(remaining)
    }

    /// Always reports `(0, Some(2))`, regardless of how many items remain.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let claimed_upper = 2;
        (0, Some(claimed_upper))
    }
}
/// Asserts that counting a filtered `Lies` iterator unwinds.
fn main() {
    // `Lies(10)` returns more items than its `size_hint` said was possible,
    // which `Filter::count` detects via `overflow-checks`, so the closure is
    // expected to panic.
    let outcome = panic::catch_unwind(|| {
        let liar = Lies(10);
        let _ = liar.filter(|x| *x > 3).count();
    });
    assert!(outcome.is_err());
}