Merge pull request #375 from rust-lang/bitmask

Simplify bitmasks
This commit is contained in:
Caleb Zulawski 2023-11-18 22:05:02 -05:00 committed by GitHub
commit 7e5c03a33d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 202 additions and 185 deletions

View file

@ -12,9 +12,6 @@
)]
mod mask_impl;
mod to_bitmask;
pub use to_bitmask::{ToBitMask, ToBitMaskArray};
use crate::simd::{
cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount,
};
@ -262,6 +259,45 @@ where
pub fn all(self) -> bool {
self.0.all()
}
/// Create a bitmask from a mask.
///
/// Each bit is set if the corresponding element in the mask is `true`.
/// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
#[inline]
#[must_use = "method returns a new integer and does not mutate the original value"]
pub fn to_bitmask(self) -> u64 {
self.0.to_bitmask_integer()
}
/// Create a mask from a bitmask.
///
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
/// If the mask contains more than 64 elements, the remainder are set to `false`.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask(bitmask: u64) -> Self {
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
}
/// Create a bitmask vector from a mask.
///
/// Each bit is set if the corresponding element in the mask is `true`.
/// The remaining bits are unset.
#[inline]
#[must_use = "method returns a new integer and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
self.0.to_bitmask_vector()
}
/// Create a mask from a bitmask vector.
///
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
Self(mask_impl::Mask::from_bitmask_vector(bitmask))
}
}
// vector/array conversion

View file

@ -1,7 +1,7 @@
#![allow(unused_imports)]
use super::MaskElement;
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
use core::marker::PhantomData;
/// A mask where each lane is represented by a single bit.
@ -120,39 +120,37 @@ where
}
#[inline]
#[must_use = "method returns a new array and does not mutate the original value"]
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M] {
assert!(core::mem::size_of::<Self>() == M);
// Safety: converting an integer to an array of bytes of the same size is safe
unsafe { core::mem::transmute_copy(&self.0) }
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
let mut bitmask = Simd::splat(0);
bitmask.as_mut_array()[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
bitmask
}
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_array<const M: usize>(bitmask: [u8; M]) -> Self {
assert!(core::mem::size_of::<Self>() == M);
// Safety: converting an array of bytes to an integer of the same size is safe
Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
let len = bytes.as_ref().len();
bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);
Self(bytes, PhantomData)
}
#[inline]
pub fn to_bitmask_integer<U>(self) -> U
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { core::mem::transmute_copy(&self.0) }
pub fn to_bitmask_integer(self) -> u64 {
let mut bitmask = [0u8; 8];
bitmask[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
u64::from_ne_bytes(bitmask)
}
#[inline]
pub fn from_bitmask_integer<U>(bitmask: U) -> Self
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
pub fn from_bitmask_integer(bitmask: u64) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
let len = bytes.as_mut().len();
bytes
.as_mut()
.copy_from_slice(&bitmask.to_ne_bytes()[..len]);
Self(bytes, PhantomData)
}
#[inline]

View file

@ -1,8 +1,7 @@
//! Masks that take up full SIMD vector registers.
use super::{to_bitmask::ToBitMaskArray, MaskElement};
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};
#[repr(transparent)]
pub struct Mask<T, const N: usize>(Simd<T, N>)
@ -143,53 +142,49 @@ where
}
#[inline]
#[must_use = "method returns a new array and does not mutate the original value"]
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M]
where
super::Mask<T, N>: ToBitMaskArray,
{
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
let mut bitmask = Simd::splat(0);
// Safety: Bytes is the right size array
unsafe {
// Compute the bitmask
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
intrinsics::simd_bitmask(self.0);
// Transmute to the return type
let mut bitmask: [u8; M] = core::mem::transmute_copy(&bitmask);
// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
for x in bitmask.as_mut() {
*x = x.reverse_bits();
for x in bytes.as_mut() {
*x = x.reverse_bits()
}
};
}
bitmask
bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref());
}
bitmask
}
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_array<const M: usize>(mut bitmask: [u8; M]) -> Self
where
super::Mask<T, N>: ToBitMaskArray,
{
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
// Safety: Bytes is the right size array
unsafe {
let len = bytes.as_ref().len();
bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);
// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
for x in bitmask.as_mut() {
for x in bytes.as_mut() {
*x = x.reverse_bits();
}
}
// Transmute to the bitmask
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
core::mem::transmute_copy(&bitmask);
// Compute the regular mask
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
bitmask,
bytes,
Self::splat(true).to_int(),
Self::splat(false).to_int(),
))
@ -197,40 +192,107 @@ where
}
#[inline]
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
LaneCount<M>: SupportedLaneCount,
{
// Safety: U is required to be the appropriate bitmask type
let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
let resized = self.to_int().resize::<M>(T::FALSE);
// Safety: `resized` is an integer vector with length M, which must match T
let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };
// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
bitmask.reverse_bits(N)
bitmask.reverse_bits(M)
} else {
bitmask
}
}
#[inline]
pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
LaneCount<M>: SupportedLaneCount,
{
// LLVM assumes bit order should match endianness
let bitmask = if cfg!(target_endian = "big") {
bitmask.reverse_bits(N)
bitmask.reverse_bits(M)
} else {
bitmask
};
// Safety: U is required to be the appropriate bitmask type
unsafe {
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
// SAFETY: `mask` is the correct bitmask type for a u64 bitmask
let mask: Simd<T, M> = unsafe {
intrinsics::simd_select_bitmask(
bitmask,
Self::splat(true).to_int(),
Self::splat(false).to_int(),
))
Simd::<T, M>::splat(T::TRUE),
Simd::<T, M>::splat(T::FALSE),
)
};
// SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
}
#[inline]
pub(crate) fn to_bitmask_integer(self) -> u64 {
// TODO modify simd_bitmask to zero-extend output, making this unnecessary
macro_rules! bitmask {
{ $($ty:ty: $($len:literal),*;)* } => {
match N {
$($(
// Safety: bitmask matches length
$len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 },
)*)*
// Safety: bitmask matches length
_ => unsafe { self.to_bitmask_impl::<u64, 64>() },
}
}
}
#[cfg(all_lane_counts)]
bitmask! {
u8: 1, 2, 3, 4, 5, 6, 7, 8;
u16: 9, 10, 11, 12, 13, 14, 15, 16;
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
}
#[cfg(not(all_lane_counts))]
bitmask! {
u8: 1, 2, 4, 8;
u16: 16;
u32: 32;
u64: 64;
}
}
#[inline]
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
// TODO modify simd_bitmask_select to truncate input, making this unnecessary
macro_rules! bitmask {
{ $($ty:ty: $($len:literal),*;)* } => {
match N {
$($(
// Safety: bitmask matches length
$len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) },
)*)*
// Safety: bitmask matches length
_ => unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) },
}
}
}
#[cfg(all_lane_counts)]
bitmask! {
u8: 1, 2, 3, 4, 5, 6, 7, 8;
u16: 9, 10, 11, 12, 13, 14, 15, 16;
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
}
#[cfg(not(all_lane_counts))]
bitmask! {
u8: 1, 2, 4, 8;
u16: 16;
u32: 32;
u64: 64;
}
}

View file

@ -1,111 +0,0 @@
use super::{mask_impl, Mask, MaskElement};
use crate::simd::{LaneCount, SupportedLaneCount};
use core::borrow::{Borrow, BorrowMut};
mod sealed {
pub trait Sealed {}
}
pub use sealed::Sealed;
impl<T, const N: usize> Sealed for Mask<T, N>
where
T: MaskElement,
LaneCount<N>: SupportedLaneCount,
{
}
/// Converts masks to and from integer bitmasks.
///
/// Each bit of the bitmask corresponds to a mask element, starting with the LSB.
pub trait ToBitMask: Sealed {
/// The integer bitmask type.
type BitMask;
/// Converts a mask to a bitmask.
fn to_bitmask(self) -> Self::BitMask;
/// Converts a bitmask to a mask.
fn from_bitmask(bitmask: Self::BitMask) -> Self;
}
/// Converts masks to and from byte array bitmasks.
///
/// Each bit of the bitmask corresponds to a mask element, starting with the LSB of the first byte.
pub trait ToBitMaskArray: Sealed {
/// The bitmask array.
type BitMaskArray: Copy
+ Unpin
+ Send
+ Sync
+ AsRef<[u8]>
+ AsMut<[u8]>
+ Borrow<[u8]>
+ BorrowMut<[u8]>
+ 'static;
/// Converts a mask to a bitmask.
fn to_bitmask_array(self) -> Self::BitMaskArray;
/// Converts a bitmask to a mask.
fn from_bitmask_array(bitmask: Self::BitMaskArray) -> Self;
}
macro_rules! impl_integer {
{ $(impl ToBitMask<BitMask=$int:ty> for Mask<_, $lanes:literal>)* } => {
$(
impl<T: MaskElement> ToBitMask for Mask<T, $lanes> {
type BitMask = $int;
#[inline]
fn to_bitmask(self) -> $int {
self.0.to_bitmask_integer()
}
#[inline]
fn from_bitmask(bitmask: $int) -> Self {
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
}
}
)*
}
}
macro_rules! impl_array {
{ $(impl ToBitMaskArray<Bytes=$int:literal> for Mask<_, $lanes:literal>)* } => {
$(
impl<T: MaskElement> ToBitMaskArray for Mask<T, $lanes> {
type BitMaskArray = [u8; $int];
#[inline]
fn to_bitmask_array(self) -> Self::BitMaskArray {
self.0.to_bitmask_array()
}
#[inline]
fn from_bitmask_array(bitmask: Self::BitMaskArray) -> Self {
Self(mask_impl::Mask::from_bitmask_array(bitmask))
}
}
)*
}
}
impl_integer! {
impl ToBitMask<BitMask=u8> for Mask<_, 1>
impl ToBitMask<BitMask=u8> for Mask<_, 2>
impl ToBitMask<BitMask=u8> for Mask<_, 4>
impl ToBitMask<BitMask=u8> for Mask<_, 8>
impl ToBitMask<BitMask=u16> for Mask<_, 16>
impl ToBitMask<BitMask=u32> for Mask<_, 32>
impl ToBitMask<BitMask=u64> for Mask<_, 64>
}
impl_array! {
impl ToBitMaskArray<Bytes=1> for Mask<_, 1>
impl ToBitMaskArray<Bytes=1> for Mask<_, 2>
impl ToBitMaskArray<Bytes=1> for Mask<_, 4>
impl ToBitMaskArray<Bytes=1> for Mask<_, 8>
impl ToBitMaskArray<Bytes=2> for Mask<_, 16>
impl ToBitMaskArray<Bytes=4> for Mask<_, 32>
impl ToBitMaskArray<Bytes=8> for Mask<_, 64>
}

View file

@ -349,4 +349,39 @@ where
Odd::concat_swizzle(self, other),
)
}
/// Resize a vector.
///
/// If `M` > `N`, extends the length of a vector, setting the new elements to `value`.
/// If `M` < `N`, truncates the vector to the first `M` elements.
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::u32x4;
/// let x = u32x4::from_array([0, 1, 2, 3]);
/// assert_eq!(x.resize::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]);
/// assert_eq!(x.resize::<2>(9).to_array(), [0, 1]);
/// ```
#[inline]
#[must_use = "method returns a new vector and does not mutate the original inputs"]
pub fn resize<const M: usize>(self, value: T) -> Simd<T, M>
where
LaneCount<M>: SupportedLaneCount,
{
struct Resize<const N: usize>;
impl<const N: usize, const M: usize> Swizzle<M> for Resize<N> {
const INDEX: [usize; M] = const {
let mut index = [0; M];
let mut i = 0;
while i < M {
index[i] = if i < N { i } else { N };
i += 1;
}
index
};
}
Resize::<N>::concat_swizzle(self, Simd::splat(value))
}
}

View file

@ -72,7 +72,6 @@ macro_rules! test_mask_api {
#[test]
fn roundtrip_bitmask_conversion() {
use core_simd::simd::ToBitMask;
let values = [
true, false, false, true, false, false, true, false,
true, true, false, false, false, false, false, true,
@ -85,8 +84,6 @@ macro_rules! test_mask_api {
#[test]
fn roundtrip_bitmask_conversion_short() {
use core_simd::simd::ToBitMask;
let values = [
false, false, false, true,
];
@ -126,16 +123,16 @@ macro_rules! test_mask_api {
}
#[test]
fn roundtrip_bitmask_array_conversion() {
use core_simd::simd::ToBitMaskArray;
fn roundtrip_bitmask_vector_conversion() {
use core_simd::simd::ToBytes;
let values = [
true, false, false, true, false, false, true, false,
true, true, false, false, false, false, false, true,
];
let mask = Mask::<$type, 16>::from_array(values);
let bitmask = mask.to_bitmask_array();
assert_eq!(bitmask, [0b01001001, 0b10000011]);
assert_eq!(Mask::<$type, 16>::from_bitmask_array(bitmask), mask);
let bitmask = mask.to_bitmask_vector();
assert_eq!(bitmask.resize::<2>(0).to_ne_bytes()[..2], [0b01001001, 0b10000011]);
assert_eq!(Mask::<$type, 16>::from_bitmask_vector(bitmask), mask);
}
}
}