Simplify Swizzle trait and condense all swizzles into this trait

This commit is contained in:
Caleb Zulawski 2023-10-01 12:31:39 -04:00
parent 21fa6af5c7
commit 4fc3ce733d
5 changed files with 174 additions and 197 deletions

View file

@ -2,10 +2,7 @@
// Code ported from the `packed_simd` crate
// Run this code with `cargo test --example matrix_inversion`
#![feature(array_chunks, portable_simd)]
use core_simd::simd::{
prelude::*,
Which::{self, *},
};
use core_simd::simd::prelude::*;
// Gotta define our own 4x4 matrix since Rust doesn't ship multidim arrays yet :^)
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
@ -166,10 +163,10 @@ pub fn simd_inv4x4(m: Matrix4x4) -> Option<Matrix4x4> {
let m_2 = f32x4::from_array(m[2]);
let m_3 = f32x4::from_array(m[3]);
const SHUFFLE01: [Which; 4] = [First(0), First(1), Second(0), Second(1)];
const SHUFFLE02: [Which; 4] = [First(0), First(2), Second(0), Second(2)];
const SHUFFLE13: [Which; 4] = [First(1), First(3), Second(1), Second(3)];
const SHUFFLE23: [Which; 4] = [First(2), First(3), Second(2), Second(3)];
const SHUFFLE01: [usize; 4] = [0, 1, 4, 5];
const SHUFFLE02: [usize; 4] = [0, 2, 4, 6];
const SHUFFLE13: [usize; 4] = [1, 3, 5, 7];
const SHUFFLE23: [usize; 4] = [2, 3, 6, 7];
let tmp = simd_swizzle!(m_0, m_1, SHUFFLE01);
let row1 = simd_swizzle!(m_2, m_3, SHUFFLE01);

View file

@ -5,6 +5,7 @@
const_mut_refs,
convert_float_to_int,
decl_macro,
inline_const,
intra_doc_pointers,
platform_intrinsics,
repr_simd,

View file

@ -1,17 +1,15 @@
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
use crate::simd::{LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount};
/// Constructs a new SIMD vector by copying elements from selected lanes in other vectors.
/// Constructs a new SIMD vector by copying elements from selected elements in other vectors.
///
/// When swizzling one vector, lanes are selected by a `const` array of `usize`,
/// like [`Swizzle`].
/// When swizzling one vector, elements are selected like [`Swizzle::swizzle`].
///
/// When swizzling two vectors, lanes are selected by a `const` array of [`Which`],
/// like [`Swizzle2`].
/// When swizzling two vectors, elements are selected like [`Swizzle::concat_swizzle`].
///
/// # Examples
///
/// With a single SIMD vector, the const array specifies lane indices in that vector:
/// With a single SIMD vector, the const array specifies element indices in that vector:
/// ```
/// # #![feature(portable_simd)]
/// # use core::simd::{u32x2, u32x4, simd_swizzle};
@ -21,25 +19,27 @@ use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
/// let r: u32x4 = simd_swizzle!(v, [3, 0, 1, 2]);
/// assert_eq!(r.to_array(), [13, 10, 11, 12]);
///
/// // Changing the number of lanes
/// // Changing the number of elements
/// let r: u32x2 = simd_swizzle!(v, [3, 1]);
/// assert_eq!(r.to_array(), [13, 11]);
/// ```
///
/// With two input SIMD vectors, the const array uses `Which` to specify the source of each index:
/// With two input SIMD vectors, the const array specifies element indices in the concatenation of
/// those vectors:
/// ```
/// # #![feature(portable_simd)]
/// # use core::simd::{u32x2, u32x4, simd_swizzle, Which};
/// use Which::{First, Second};
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{u32x2, u32x4, simd_swizzle};
/// let a = u32x4::from_array([0, 1, 2, 3]);
/// let b = u32x4::from_array([4, 5, 6, 7]);
///
/// // Keeping the same size
/// let r: u32x4 = simd_swizzle!(a, b, [First(0), First(1), Second(2), Second(3)]);
/// let r: u32x4 = simd_swizzle!(a, b, [0, 1, 6, 7]);
/// assert_eq!(r.to_array(), [0, 1, 6, 7]);
///
/// // Changing the number of lanes
/// let r: u32x2 = simd_swizzle!(a, b, [First(0), Second(0)]);
/// // Changing the number of elements
/// let r: u32x2 = simd_swizzle!(a, b, [0, 4]);
/// assert_eq!(r.to_array(), [0, 4]);
/// ```
#[allow(unused_macros)]
@ -50,7 +50,7 @@ pub macro simd_swizzle {
{
use $crate::simd::Swizzle;
struct Impl;
impl<const LANES: usize> Swizzle<LANES, {$index.len()}> for Impl {
impl Swizzle<{$index.len()}> for Impl {
const INDEX: [usize; {$index.len()}] = $index;
}
Impl::swizzle($vector)
@ -60,127 +60,117 @@ pub macro simd_swizzle {
$first:expr, $second:expr, $index:expr $(,)?
) => {
{
use $crate::simd::{Which, Swizzle2};
use $crate::simd::Swizzle;
struct Impl;
impl<const LANES: usize> Swizzle2<LANES, {$index.len()}> for Impl {
const INDEX: [Which; {$index.len()}] = $index;
impl Swizzle<{$index.len()}> for Impl {
const INDEX: [usize; {$index.len()}] = $index;
}
Impl::swizzle2($first, $second)
Impl::concat_swizzle($first, $second)
}
}
}
/// Specifies a lane index into one of two SIMD vectors.
///
/// This is an input type for [Swizzle2] and helper macros like [simd_swizzle].
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Which {
/// Index of a lane in the first input SIMD vector.
First(usize),
/// Index of a lane in the second input SIMD vector.
Second(usize),
}
/// Create a vector from the elements of another vector.
pub trait Swizzle<const INPUT_LANES: usize, const OUTPUT_LANES: usize> {
/// Map from the lanes of the input vector to the output vector.
const INDEX: [usize; OUTPUT_LANES];
pub trait Swizzle<const N: usize> {
/// Map from the elements of the input vector to the output vector.
const INDEX: [usize; N];
/// Create a new vector from the lanes of `vector`.
/// Create a new vector from the elements of `vector`.
///
/// Element `i` of the output is `vector[Self::INDEX[i]]`.
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
fn swizzle<T>(vector: Simd<T, INPUT_LANES>) -> Simd<T, OUTPUT_LANES>
fn swizzle<T, const M: usize>(vector: Simd<T, M>) -> Simd<T, N>
where
T: SimdElement,
LaneCount<INPUT_LANES>: SupportedLaneCount,
LaneCount<OUTPUT_LANES>: SupportedLaneCount,
LaneCount<N>: SupportedLaneCount,
LaneCount<M>: SupportedLaneCount,
{
// Safety: `vector` is a vector, and `INDEX_IMPL` is a const array of u32.
unsafe { intrinsics::simd_shuffle(vector, vector, Self::INDEX_IMPL) }
// Safety: `vector` is a vector, and the index is a const array of u32.
unsafe {
intrinsics::simd_shuffle(
vector,
vector,
const {
let mut output = [0; N];
let mut i = 0;
while i < N {
let index = Self::INDEX[i];
assert!(index as u32 as usize == index);
assert!(index < M, "source element index exceeds input vector length");
output[i] = index as u32;
i += 1;
}
output
},
)
}
}
}
/// Create a vector from the elements of two other vectors.
pub trait Swizzle2<const INPUT_LANES: usize, const OUTPUT_LANES: usize> {
/// Map from the lanes of the input vectors to the output vector
const INDEX: [Which; OUTPUT_LANES];
/// Create a new vector from the lanes of `first` and `second`.
/// Create a new vector from the elements of `first` and `second`.
///
/// Lane `i` is `first[j]` when `Self::INDEX[i]` is `First(j)`, or `second[j]` when it is
/// `Second(j)`.
/// Element `i` of the output is `concat[Self::INDEX[i]]`, where `concat` is the concatenation of
/// `first` and `second`.
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
fn swizzle2<T>(
first: Simd<T, INPUT_LANES>,
second: Simd<T, INPUT_LANES>,
) -> Simd<T, OUTPUT_LANES>
fn concat_swizzle<T, const M: usize>(first: Simd<T, M>, second: Simd<T, M>) -> Simd<T, N>
where
T: SimdElement,
LaneCount<INPUT_LANES>: SupportedLaneCount,
LaneCount<OUTPUT_LANES>: SupportedLaneCount,
LaneCount<N>: SupportedLaneCount,
LaneCount<M>: SupportedLaneCount,
{
// Safety: `first` and `second` are vectors, and `INDEX_IMPL` is a const array of u32.
unsafe { intrinsics::simd_shuffle(first, second, Self::INDEX_IMPL) }
// Safety: `first` and `second` are vectors, and the index is a const array of u32.
unsafe {
intrinsics::simd_shuffle(
first,
second,
const {
let mut output = [0; N];
let mut i = 0;
while i < N {
let index = Self::INDEX[i];
assert!(index as u32 as usize == index);
assert!(index < 2 * M, "source element index exceeds input vector length");
output[i] = index as u32;
i += 1;
}
output
},
)
}
}
}
/// The `simd_shuffle` intrinsic expects `u32`, so do error checking and conversion here.
/// This trait hides `INDEX_IMPL` from the public API.
trait SwizzleImpl<const INPUT_LANES: usize, const OUTPUT_LANES: usize> {
const INDEX_IMPL: [u32; OUTPUT_LANES];
}
/// Create a new mask from the elements of `vector`.
///
/// Element `i` of the output is `vector[Self::INDEX[i]]`.
///
/// The mask is converted to its integer vector form with `to_int`, shuffled with
/// [`Swizzle::swizzle`], and converted back into a mask.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original inputs"]
fn swizzle_mask<T, const M: usize>(vector: Mask<T, M>) -> Mask<T, N>
where
T: MaskElement,
LaneCount<N>: SupportedLaneCount,
LaneCount<M>: SupportedLaneCount,
{
// SAFETY: all elements of this mask come from another mask
// (the swizzle only copies elements out of `vector.to_int()`; it never creates new values)
unsafe { Mask::from_int_unchecked(Self::swizzle(vector.to_int())) }
}
impl<T, const INPUT_LANES: usize, const OUTPUT_LANES: usize> SwizzleImpl<INPUT_LANES, OUTPUT_LANES>
for T
where
T: Swizzle<INPUT_LANES, OUTPUT_LANES> + ?Sized,
{
const INDEX_IMPL: [u32; OUTPUT_LANES] = {
let mut output = [0; OUTPUT_LANES];
let mut i = 0;
while i < OUTPUT_LANES {
let index = Self::INDEX[i];
assert!(index as u32 as usize == index);
assert!(index < INPUT_LANES, "source lane exceeds input lane count",);
output[i] = index as u32;
i += 1;
}
output
};
}
/// The `simd_shuffle` intrinsic expects `u32`, so do error checking and conversion here.
/// This trait hides `INDEX_IMPL` from the public API.
trait Swizzle2Impl<const INPUT_LANES: usize, const OUTPUT_LANES: usize> {
const INDEX_IMPL: [u32; OUTPUT_LANES];
}
impl<T, const INPUT_LANES: usize, const OUTPUT_LANES: usize> Swizzle2Impl<INPUT_LANES, OUTPUT_LANES>
for T
where
T: Swizzle2<INPUT_LANES, OUTPUT_LANES> + ?Sized,
{
const INDEX_IMPL: [u32; OUTPUT_LANES] = {
let mut output = [0; OUTPUT_LANES];
let mut i = 0;
while i < OUTPUT_LANES {
let (offset, index) = match Self::INDEX[i] {
Which::First(index) => (false, index),
Which::Second(index) => (true, index),
};
assert!(index < INPUT_LANES, "source lane exceeds input lane count",);
// lanes are indexed by the first vector, then second vector
let index = if offset { index + INPUT_LANES } else { index };
assert!(index as u32 as usize == index);
output[i] = index as u32;
i += 1;
}
output
};
/// Create a new mask from the elements of `first` and `second`.
///
/// Element `i` of the output is `concat[Self::INDEX[i]]`, where `concat` is the concatenation of
/// `first` and `second`.
///
/// Both masks are converted to their integer vector form with `to_int`, shuffled with
/// [`Swizzle::concat_swizzle`], and the result is converted back into a mask.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original inputs"]
fn concat_swizzle_mask<T, const M: usize>(first: Mask<T, M>, second: Mask<T, M>) -> Mask<T, N>
where
T: MaskElement,
LaneCount<N>: SupportedLaneCount,
LaneCount<M>: SupportedLaneCount,
{
// SAFETY: all elements of this mask come from another mask
// (the swizzle only copies elements out of `first.to_int()` and `second.to_int()`)
unsafe { Mask::from_int_unchecked(Self::concat_swizzle(first.to_int(), second.to_int())) }
}
}
impl<T, const LANES: usize> Simd<T, LANES>
@ -188,24 +178,22 @@ where
T: SimdElement,
LaneCount<LANES>: SupportedLaneCount,
{
/// Reverse the order of the lanes in the vector.
/// Reverse the order of the elements in the vector.
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
pub fn reverse(self) -> Self {
const fn reverse_index<const LANES: usize>() -> [usize; LANES] {
let mut index = [0; LANES];
let mut i = 0;
while i < LANES {
index[i] = LANES - i - 1;
i += 1;
}
index
}
struct Reverse;
impl<const LANES: usize> Swizzle<LANES, LANES> for Reverse {
const INDEX: [usize; LANES] = reverse_index::<LANES>();
impl<const N: usize> Swizzle<N> for Reverse {
const INDEX: [usize; N] = const {
let mut index = [0; N];
let mut i = 0;
while i < N {
index[i] = N - i - 1;
i += 1;
}
index
};
}
Reverse::swizzle(self)
@ -217,21 +205,19 @@ where
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
pub fn rotate_lanes_left<const OFFSET: usize>(self) -> Self {
const fn rotate_index<const OFFSET: usize, const LANES: usize>() -> [usize; LANES] {
let offset = OFFSET % LANES;
let mut index = [0; LANES];
let mut i = 0;
while i < LANES {
index[i] = (i + offset) % LANES;
i += 1;
}
index
}
struct Rotate<const OFFSET: usize>;
impl<const OFFSET: usize, const LANES: usize> Swizzle<LANES, LANES> for Rotate<OFFSET> {
const INDEX: [usize; LANES] = rotate_index::<OFFSET, LANES>();
impl<const OFFSET: usize, const N: usize> Swizzle<N> for Rotate<OFFSET> {
const INDEX: [usize; N] = const {
let offset = OFFSET % N;
let mut index = [0; N];
let mut i = 0;
while i < N {
index[i] = (i + offset) % N;
i += 1;
}
index
};
}
Rotate::<OFFSET>::swizzle(self)
@ -243,21 +229,19 @@ where
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
pub fn rotate_lanes_right<const OFFSET: usize>(self) -> Self {
const fn rotate_index<const OFFSET: usize, const LANES: usize>() -> [usize; LANES] {
let offset = LANES - OFFSET % LANES;
let mut index = [0; LANES];
let mut i = 0;
while i < LANES {
index[i] = (i + offset) % LANES;
i += 1;
}
index
}
struct Rotate<const OFFSET: usize>;
impl<const OFFSET: usize, const LANES: usize> Swizzle<LANES, LANES> for Rotate<OFFSET> {
const INDEX: [usize; LANES] = rotate_index::<OFFSET, LANES>();
impl<const OFFSET: usize, const N: usize> Swizzle<N> for Rotate<OFFSET> {
const INDEX: [usize; N] = const {
let offset = N - OFFSET % N;
let mut index = [0; N];
let mut i = 0;
while i < N {
index[i] = (i + offset) % N;
i += 1;
}
index
};
}
Rotate::<OFFSET>::swizzle(self)
@ -265,7 +249,7 @@ where
/// Interleave two vectors.
///
/// The resulting vectors contain lanes taken alternatively from `self` and `other`, first
/// The resulting vectors contain elements taken alternately from `self` and `other`, first
/// filling the first result, and then the second.
///
/// The reverse of this operation is [`Simd::deinterleave`].
@ -282,18 +266,13 @@ where
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
pub fn interleave(self, other: Self) -> (Self, Self) {
const fn interleave<const LANES: usize>(high: bool) -> [Which; LANES] {
let mut idx = [Which::First(0); LANES];
const fn interleave<const N: usize>(high: bool) -> [usize; N] {
let mut idx = [0; N];
let mut i = 0;
while i < LANES {
// Treat the source as a concatenated vector
let dst_index = if high { i + LANES } else { i };
let src_index = dst_index / 2 + (dst_index % 2) * LANES;
idx[i] = if src_index < LANES {
Which::First(src_index)
} else {
Which::Second(src_index % LANES)
};
while i < N {
let dst_index = if high { i + N } else { i };
let src_index = dst_index / 2 + (dst_index % 2) * N;
idx[i] = src_index;
i += 1;
}
idx
@ -302,24 +281,27 @@ where
struct Lo;
struct Hi;
impl<const LANES: usize> Swizzle2<LANES, LANES> for Lo {
const INDEX: [Which; LANES] = interleave::<LANES>(false);
impl<const N: usize> Swizzle<N> for Lo {
const INDEX: [usize; N] = interleave::<N>(false);
}
impl<const LANES: usize> Swizzle2<LANES, LANES> for Hi {
const INDEX: [Which; LANES] = interleave::<LANES>(true);
impl<const N: usize> Swizzle<N> for Hi {
const INDEX: [usize; N] = interleave::<N>(true);
}
(Lo::swizzle2(self, other), Hi::swizzle2(self, other))
(
Lo::concat_swizzle(self, other),
Hi::concat_swizzle(self, other),
)
}
/// Deinterleave two vectors.
///
/// The first result takes every other lane of `self` and then `other`, starting with
/// the first lane.
/// The first result takes every other element of `self` and then `other`, starting with
/// the first element.
///
/// The second result takes every other lane of `self` and then `other`, starting with
/// the second lane.
/// The second result takes every other element of `self` and then `other`, starting with
/// the second element.
///
/// The reverse of this operation is [`Simd::interleave`].
///
@ -335,17 +317,11 @@ where
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
pub fn deinterleave(self, other: Self) -> (Self, Self) {
const fn deinterleave<const LANES: usize>(second: bool) -> [Which; LANES] {
let mut idx = [Which::First(0); LANES];
const fn deinterleave<const N: usize>(second: bool) -> [usize; N] {
let mut idx = [0; N];
let mut i = 0;
while i < LANES {
// Treat the source as a concatenated vector
let src_index = i * 2 + second as usize;
idx[i] = if src_index < LANES {
Which::First(src_index)
} else {
Which::Second(src_index % LANES)
};
while i < N {
idx[i] = i * 2 + second as usize;
i += 1;
}
idx
@ -354,14 +330,17 @@ where
struct Even;
struct Odd;
impl<const LANES: usize> Swizzle2<LANES, LANES> for Even {
const INDEX: [Which; LANES] = deinterleave::<LANES>(false);
impl<const N: usize> Swizzle<N> for Even {
const INDEX: [usize; N] = deinterleave::<N>(false);
}
impl<const LANES: usize> Swizzle2<LANES, LANES> for Odd {
const INDEX: [Which; LANES] = deinterleave::<LANES>(true);
impl<const N: usize> Swizzle<N> for Odd {
const INDEX: [usize; N] = deinterleave::<N>(true);
}
(Even::swizzle2(self, other), Odd::swizzle2(self, other))
(
Even::concat_swizzle(self, other),
Odd::concat_swizzle(self, other),
)
}
}

View file

@ -144,10 +144,10 @@ where
// This is preferred over `[value; N]`, since it's explicitly a splat:
// https://github.com/rust-lang/rust/issues/97804
struct Splat;
impl<const N: usize> Swizzle<1, N> for Splat {
impl<const N: usize> Swizzle<N> for Splat {
const INDEX: [usize; N] = [0; N];
}
Splat::swizzle(Simd::<T, 1>::from([value]))
Splat::swizzle::<T, 1>(Simd::<T, 1>::from([value]))
}
/// Returns an array reference containing the entire SIMD vector.

View file

@ -11,10 +11,10 @@ wasm_bindgen_test_configure!(run_in_browser);
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn swizzle() {
struct Index;
impl Swizzle<4, 4> for Index {
impl Swizzle<4> for Index {
const INDEX: [usize; 4] = [2, 1, 3, 0];
}
impl Swizzle<4, 2> for Index {
impl Swizzle<2> for Index {
const INDEX: [usize; 2] = [1, 1];
}