Simplify interleave/deinterleave and fix for odd-length vectors.

This commit is contained in:
Caleb Zulawski 2022-08-01 00:34:58 -04:00
parent 3183afb6b5
commit 6bf5128235

View file

@ -265,13 +265,10 @@ where
/// Interleave two vectors.
///
/// Produces two vectors with lanes taken alternately from `self` and `other`.
/// The resulting vectors contain lanes taken alternatively from `self` and `other`, first
/// filling the first result, and then the second.
///
/// The first result contains the first `LANES / 2` lanes from `self` and `other`,
/// alternating, starting with the first lane of `self`.
///
/// The second result contains the last `LANES / 2` lanes from `self` and `other`,
/// alternating, starting with the lane `LANES / 2` from the start of `self`.
/// The reverse of this operation is [`Simd::deinterleave`].
///
/// ```
/// #![feature(portable_simd)]
@ -285,29 +282,17 @@ where
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
pub fn interleave(self, other: Self) -> (Self, Self) {
const fn lo<const LANES: usize>() -> [Which; LANES] {
const fn interleave<const LANES: usize>(high: bool) -> [Which; LANES] {
let mut idx = [Which::First(0); LANES];
let mut i = 0;
while i < LANES {
let offset = i / 2;
idx[i] = if i % 2 == 0 {
Which::First(offset)
// Treat the source as a concatenated vector
let dst_index = if high { i + LANES } else { i };
let src_index = dst_index / 2 + (dst_index % 2) * LANES;
idx[i] = if src_index < LANES {
Which::First(src_index)
} else {
Which::Second(offset)
};
i += 1;
}
idx
}
const fn hi<const LANES: usize>() -> [Which; LANES] {
let mut idx = [Which::First(0); LANES];
let mut i = 0;
while i < LANES {
let offset = (LANES + i) / 2;
idx[i] = if i % 2 == 0 {
Which::First(offset)
} else {
Which::Second(offset)
Which::Second(src_index % LANES)
};
i += 1;
}
@ -318,18 +303,14 @@ where
struct Hi;
impl<const LANES: usize> Swizzle2<LANES, LANES> for Lo {
const INDEX: [Which; LANES] = lo::<LANES>();
const INDEX: [Which; LANES] = interleave::<LANES>(false);
}
impl<const LANES: usize> Swizzle2<LANES, LANES> for Hi {
const INDEX: [Which; LANES] = hi::<LANES>();
const INDEX: [Which; LANES] = interleave::<LANES>(true);
}
if LANES == 1 {
(self, other)
} else {
(Lo::swizzle2(self, other), Hi::swizzle2(self, other))
}
(Lo::swizzle2(self, other), Hi::swizzle2(self, other))
}
/// Deinterleave two vectors.
@ -340,6 +321,8 @@ where
/// The second result takes every other lane of `self` and then `other`, starting with
/// the second lane.
///
/// The reverse of this operation is [`Simd::interleave`].
///
/// ```
/// #![feature(portable_simd)]
/// # use core::simd::Simd;
@ -352,22 +335,17 @@ where
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
pub fn deinterleave(self, other: Self) -> (Self, Self) {
const fn even<const LANES: usize>() -> [Which; LANES] {
const fn deinterleave<const LANES: usize>(second: bool) -> [Which; LANES] {
let mut idx = [Which::First(0); LANES];
let mut i = 0;
while i < LANES / 2 {
idx[i] = Which::First(2 * i);
idx[i + LANES / 2] = Which::Second(2 * i);
i += 1;
}
idx
}
const fn odd<const LANES: usize>() -> [Which; LANES] {
let mut idx = [Which::First(0); LANES];
let mut i = 0;
while i < LANES / 2 {
idx[i] = Which::First(2 * i + 1);
idx[i + LANES / 2] = Which::Second(2 * i + 1);
while i < LANES {
// Treat the source as a concatenated vector
let src_index = i * 2 + if second { 1 } else { 0 };
idx[i] = if src_index < LANES {
Which::First(src_index)
} else {
Which::Second(src_index % LANES)
};
i += 1;
}
idx
@ -377,11 +355,11 @@ where
struct Odd;
impl<const LANES: usize> Swizzle2<LANES, LANES> for Even {
const INDEX: [Which; LANES] = even::<LANES>();
const INDEX: [Which; LANES] = deinterleave::<LANES>(false);
}
impl<const LANES: usize> Swizzle2<LANES, LANES> for Odd {
const INDEX: [Which; LANES] = odd::<LANES>();
const INDEX: [Which; LANES] = deinterleave::<LANES>(true);
}
if LANES == 1 {