From 082e3c8a5da8146b4e3d382d4f84a8a6847dd783 Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Fri, 17 Nov 2023 10:15:12 -0500 Subject: [PATCH] Workaround simd_bitmask limitations --- crates/core_simd/src/masks/full_masks.rs | 90 +++++++++++++++++++++--- crates/core_simd/src/swizzle.rs | 16 ++--- crates/core_simd/tests/masks.rs | 9 +-- 3 files changed, 90 insertions(+), 25 deletions(-) diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs index 73a0d8987009..a529490f3a21 100644 --- a/crates/core_simd/src/masks/full_masks.rs +++ b/crates/core_simd/src/masks/full_masks.rs @@ -207,40 +207,108 @@ where } #[inline] - pub(crate) fn to_bitmask_integer(self) -> u64 { - let resized = self.to_int().extend::<64>(T::FALSE); + unsafe fn to_bitmask_impl(self) -> U + where + LaneCount: SupportedLaneCount, + { + let resized = self.to_int().resize::(T::FALSE); - // SAFETY: `resized` is an integer vector with length 64 - let bitmask: u64 = unsafe { intrinsics::simd_bitmask(resized) }; + // Safety: `resized` is an integer vector with length M, which must match T + let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) }; // LLVM assumes bit order should match endianness if cfg!(target_endian = "big") { - bitmask.reverse_bits() + bitmask.reverse_bits(M) } else { bitmask } } #[inline] - pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self { + unsafe fn from_bitmask_impl(bitmask: U) -> Self + where + LaneCount: SupportedLaneCount, + { // LLVM assumes bit order should match endianness let bitmask = if cfg!(target_endian = "big") { - bitmask.reverse_bits() + bitmask.reverse_bits(M) } else { bitmask }; // SAFETY: `mask` is the correct bitmask type for a u64 bitmask - let mask: Simd = unsafe { + let mask: Simd = unsafe { intrinsics::simd_select_bitmask( bitmask, - Simd::::splat(T::TRUE), - Simd::::splat(T::FALSE), + Simd::::splat(T::TRUE), + Simd::::splat(T::FALSE), ) }; // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE` - unsafe { Self::from_int_unchecked(mask.extend::(T::FALSE)) } + unsafe { Self::from_int_unchecked(mask.resize::(T::FALSE)) } + } + + #[inline] + pub(crate) fn to_bitmask_integer(self) -> u64 { + // TODO modify simd_bitmask to zero-extend output, making this unnecessary + macro_rules! bitmask { + { $($ty:ty: $($len:literal),*;)* } => { + match N { + $($( + // Safety: bitmask matches length + $len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 }, + )*)* + // Safety: bitmask matches length + _ => unsafe { self.to_bitmask_impl::() }, + } + } + } + #[cfg(all_lane_counts)] + bitmask! { + u8: 1, 2, 3, 4, 5, 6, 7, 8; + u16: 9, 10, 11, 12, 13, 14, 15, 16; + u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32; + u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64; + } + #[cfg(not(all_lane_counts))] + bitmask! { + u8: 1, 2, 4, 8; + u16: 16; + u32: 32; + u64: 64; + } + } + + #[inline] + pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self { + // TODO modify simd_bitmask_select to truncate input, making this unnecessary + macro_rules! bitmask { + { $($ty:ty: $($len:literal),*;)* } => { + match N { + $($( + // Safety: bitmask matches length + $len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) }, + )*)* + // Safety: bitmask matches length + _ => unsafe { Self::from_bitmask_impl::(bitmask) }, + } + } + } + #[cfg(all_lane_counts)] + bitmask! { + u8: 1, 2, 3, 4, 5, 6, 7, 8; + u16: 9, 10, 11, 12, 13, 14, 15, 16; + u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32; + u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64; + } + #[cfg(not(all_lane_counts))] + bitmask! { + u8: 1, 2, 4, 8; + u16: 16; + u32: 32; + u64: 64; + } } #[inline] diff --git a/crates/core_simd/src/swizzle.rs b/crates/core_simd/src/swizzle.rs index e5b3d4444d8c..ec8548d55745 100644 --- a/crates/core_simd/src/swizzle.rs +++ b/crates/core_simd/src/swizzle.rs @@ -350,9 +350,9 @@ where ) } - /// Extend a vector. + /// Resize a vector. /// - /// Extends the length of a vector, setting the new elements to `value`. + /// If `M` > `N`, extends the length of a vector, setting the new elements to `value`. /// If `M` < `N`, truncates the vector to the first `M` elements. /// /// ``` @@ -361,17 +361,17 @@ where /// # #[cfg(not(feature = "as_crate"))] use core::simd; /// # use simd::u32x4; /// let x = u32x4::from_array([0, 1, 2, 3]); - /// assert_eq!(x.extend::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]); - /// assert_eq!(x.extend::<2>(9).to_array(), [0, 1]); + /// assert_eq!(x.resize::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]); + /// assert_eq!(x.resize::<2>(9).to_array(), [0, 1]); /// ``` #[inline] #[must_use = "method returns a new vector and does not mutate the original inputs"] - pub fn extend(self, value: T) -> Simd + pub fn resize(self, value: T) -> Simd where LaneCount: SupportedLaneCount, { - struct Extend; - impl Swizzle for Extend { + struct Resize; + impl Swizzle for Resize { const INDEX: [usize; M] = const { let mut index = [0; M]; let mut i = 0; @@ -382,6 +382,6 @@ where index }; } - Extend::::concat_swizzle(self, Simd::splat(value)) + Resize::::concat_swizzle(self, Simd::splat(value)) } } diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs index 92ee53b3e555..00fc2a24e27a 100644 --- a/crates/core_simd/tests/masks.rs +++ b/crates/core_simd/tests/masks.rs @@ -13,7 +13,7 @@ macro_rules! test_mask_api { #[cfg(target_arch = "wasm32")] use wasm_bindgen_test::*; - use core_simd::simd::{Mask, Simd}; + use core_simd::simd::Mask; #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] @@ -124,17 +124,14 @@ macro_rules! test_mask_api { #[test] fn roundtrip_bitmask_vector_conversion() { + use core_simd::simd::ToBytes; let values = [ true, false, false, true, false, false, true, false, true, true, false, false, false, false, false, true, ]; let mask = Mask::<$type, 16>::from_array(values); let bitmask = mask.to_bitmask_vector(); - if core::mem::size_of::<$type>() == 1 { - assert_eq!(bitmask, Simd::from_array([0b01001001 as _, 0b10000011 as _, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ])); - } else { - assert_eq!(bitmask, Simd::from_array([0b1000001101001001 as _, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])); - } + assert_eq!(bitmask.resize::<2>(0).to_ne_bytes()[..2], [0b01001001, 0b10000011]); assert_eq!(Mask::<$type, 16>::from_bitmask_vector(bitmask), mask); } }