Auto merge of #130887 - Soveu:repeatn, r=scottmcm

Safer implementation of RepeatN

I've seen the "Use MaybeUninit for RepeatN" commit while reading This Week In Rust and immediately thought about something I've written some time ago - https://github.com/Soveu/repeat_finite/blob/master/src/lib.rs.

Using the fact, that `Option` will find niche in `(T, NonZeroUsize)`, we can construct something that has the same size as `(T, usize)` while completely getting rid of `MaybeUninit`.
This leaves only `unsafe` on `TrustedLen`, which is pretty neat.
This commit is contained in:
bors 2025-06-18 03:18:10 +00:00
commit 27733d46d7
2 changed files with 48 additions and 122 deletions

View file

@ -1,8 +1,7 @@
use crate::fmt;
use crate::iter::{FusedIterator, TrustedLen, UncheckedIterator};
use crate::mem::MaybeUninit;
use crate::num::NonZero;
use crate::ops::{NeverShortCircuit, Try};
use crate::ops::Try;
/// Creates a new iterator that repeats a single element a given number of times.
///
@ -58,14 +57,20 @@ use crate::ops::{NeverShortCircuit, Try};
#[inline]
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
pub fn repeat_n<T: Clone>(element: T, count: usize) -> RepeatN<T> {
let element = if count == 0 {
// `element` gets dropped eagerly.
MaybeUninit::uninit()
} else {
MaybeUninit::new(element)
};
RepeatN { inner: RepeatNInner::new(element, count) }
}
RepeatN { element, count }
#[derive(Clone, Copy)]
struct RepeatNInner<T> {
count: NonZero<usize>,
element: T,
}
impl<T> RepeatNInner<T> {
fn new(element: T, count: usize) -> Option<Self> {
let count = NonZero::<usize>::new(count)?;
Some(Self { element, count })
}
}
/// An iterator that repeats an element an exact number of times.
@ -73,63 +78,27 @@ pub fn repeat_n<T: Clone>(element: T, count: usize) -> RepeatN<T> {
/// This `struct` is created by the [`repeat_n()`] function.
/// See its documentation for more.
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
#[derive(Clone)]
pub struct RepeatN<A> {
count: usize,
// Invariant: uninit iff count == 0.
element: MaybeUninit<A>,
inner: Option<RepeatNInner<A>>,
}
impl<A> RepeatN<A> {
/// Returns the element if it hasn't been dropped already.
fn element_ref(&self) -> Option<&A> {
if self.count > 0 {
// SAFETY: The count is non-zero, so it must be initialized.
Some(unsafe { self.element.assume_init_ref() })
} else {
None
}
}
/// If we haven't already dropped the element, return it in an option.
///
/// Clears the count so it won't be dropped again later.
#[inline]
fn take_element(&mut self) -> Option<A> {
if self.count > 0 {
self.count = 0;
// SAFETY: We just set count to zero so it won't be dropped again,
// and it used to be non-zero so it hasn't already been dropped.
let element = unsafe { self.element.assume_init_read() };
Some(element)
} else {
None
}
}
}
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> Clone for RepeatN<A> {
fn clone(&self) -> RepeatN<A> {
RepeatN {
count: self.count,
element: self.element_ref().cloned().map_or_else(MaybeUninit::uninit, MaybeUninit::new),
}
self.inner.take().map(|inner| inner.element)
}
}
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: fmt::Debug> fmt::Debug for RepeatN<A> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RepeatN")
.field("count", &self.count)
.field("element", &self.element_ref())
.finish()
}
}
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A> Drop for RepeatN<A> {
fn drop(&mut self) {
self.take_element();
let (count, element) = match self.inner.as_ref() {
Some(inner) => (inner.count.get(), Some(&inner.element)),
None => (0, None),
};
f.debug_struct("RepeatN").field("count", &count).field("element", &element).finish()
}
}
@ -139,12 +108,17 @@ impl<A: Clone> Iterator for RepeatN<A> {
#[inline]
fn next(&mut self) -> Option<A> {
if self.count > 0 {
// SAFETY: Just checked it's not empty
unsafe { Some(self.next_unchecked()) }
} else {
None
let inner = self.inner.as_mut()?;
let count = inner.count.get();
if let Some(decremented) = NonZero::<usize>::new(count - 1) {
// Order of these is important for optimization
let tmp = inner.element.clone();
inner.count = decremented;
return Some(tmp);
}
return self.take_element();
}
#[inline]
@ -155,52 +129,19 @@ impl<A: Clone> Iterator for RepeatN<A> {
#[inline]
fn advance_by(&mut self, skip: usize) -> Result<(), NonZero<usize>> {
let len = self.count;
let Some(inner) = self.inner.as_mut() else {
return NonZero::<usize>::new(skip).map(Err).unwrap_or(Ok(()));
};
if skip >= len {
self.take_element();
let len = inner.count.get();
if let Some(new_len) = len.checked_sub(skip).and_then(NonZero::<usize>::new) {
inner.count = new_len;
return Ok(());
}
if skip > len {
// SAFETY: we just checked that the difference is positive
Err(unsafe { NonZero::new_unchecked(skip - len) })
} else {
self.count = len - skip;
Ok(())
}
}
fn try_fold<B, F, R>(&mut self, mut acc: B, mut f: F) -> R
where
F: FnMut(B, A) -> R,
R: Try<Output = B>,
{
if self.count > 0 {
while self.count > 1 {
self.count -= 1;
// SAFETY: the count was larger than 1, so the element is
// initialized and hasn't been dropped.
acc = f(acc, unsafe { self.element.assume_init_ref().clone() })?;
}
// We could just set the count to zero directly, but doing it this
// way should make it easier for the optimizer to fold this tail
// into the loop when `clone()` is equivalent to copying.
self.count -= 1;
// SAFETY: we just set the count to zero from one, so the element
// is still initialized, has not been dropped yet and will not be
// accessed by future calls.
f(acc, unsafe { self.element.assume_init_read() })
} else {
try { acc }
}
}
fn fold<B, F>(mut self, init: B, f: F) -> B
where
F: FnMut(B, A) -> B,
{
self.try_fold(init, NeverShortCircuit::wrap_mut_2(f)).0
self.inner = None;
return NonZero::<usize>::new(skip - len).map(Err).unwrap_or(Ok(()));
}
#[inline]
@ -217,7 +158,7 @@ impl<A: Clone> Iterator for RepeatN<A> {
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> ExactSizeIterator for RepeatN<A> {
fn len(&self) -> usize {
self.count
self.inner.as_ref().map(|inner| inner.count.get()).unwrap_or(0)
}
}
@ -262,20 +203,4 @@ impl<A: Clone> FusedIterator for RepeatN<A> {}
#[unstable(feature = "trusted_len", issue = "37572")]
unsafe impl<A: Clone> TrustedLen for RepeatN<A> {}
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> UncheckedIterator for RepeatN<A> {
#[inline]
unsafe fn next_unchecked(&mut self) -> Self::Item {
// SAFETY: The caller promised the iterator isn't empty
self.count = unsafe { self.count.unchecked_sub(1) };
if self.count == 0 {
// SAFETY: the check above ensured that the count used to be non-zero,
// so element hasn't been dropped yet, and we just lowered the count to
// zero so it won't be dropped later, and thus it's okay to take it here.
unsafe { self.element.assume_init_read() }
} else {
// SAFETY: the count is non-zero, so it must have not been dropped yet.
let element = unsafe { self.element.assume_init_ref() };
A::clone(element)
}
}
}
impl<A: Clone> UncheckedIterator for RepeatN<A> {}

View file

@ -1,5 +1,6 @@
//@ compile-flags: -C opt-level=3
//@ only-x86_64
//@ needs-deterministic-layouts
#![crate_type = "lib"]
#![feature(iter_repeat_n)]
@ -25,10 +26,10 @@ pub fn iter_repeat_n_next(it: &mut std::iter::RepeatN<NotCopy>) -> Option<NotCop
// CHECK-NEXT: br i1 %[[COUNT_ZERO]], label %[[EMPTY:.+]], label %[[NOT_EMPTY:.+]]
// CHECK: [[NOT_EMPTY]]:
// CHECK-NEXT: %[[DEC:.+]] = add i64 %[[COUNT]], -1
// CHECK-NEXT: store i64 %[[DEC]]
// CHECK-NOT: br
// CHECK: %[[VAL:.+]] = load i16
// CHECK: %[[DEC:.+]] = add i64 %[[COUNT]], -1
// CHECK-NEXT: %[[VAL:.+]] = load i16
// CHECK-NEXT: store i64 %[[DEC]]
// CHECK-NEXT: br label %[[EMPTY]]
// CHECK: [[EMPTY]]: