Add floating-point classification functions

This commit is contained in:
Caleb Zulawski 2021-02-14 23:35:24 -05:00
parent 4e6d44086c
commit 93ce1c1a59
5 changed files with 163 additions and 70 deletions

View file

@ -0,0 +1,86 @@
use crate::LanesAtMost64;
// Generates the lanewise comparison methods (`lanes_eq`, `lanes_ne`, `lanes_lt`,
// `lanes_gt`, `lanes_le`, `lanes_ge`) for each listed vector type.
//
// Each entry has the form `Vector => Mask (InnerMask, InnerInt)`:
// * `Vector`    - the vector type receiving the methods,
// * `Mask`      - the mask type the methods return,
// * `InnerMask` - the mask type built from the raw comparison result via
//                 `from_int_unchecked` and then converted with `.into()`,
// * `InnerInt`  - the integer vector type that must also satisfy `LanesAtMost64`.
macro_rules! implement_mask_ops {
    { $($vector:ident => $mask:ident ($inner_mask_ty:ident, $inner_ty:ident),)* } => {
        $(
            impl<const LANES: usize> crate::$vector<LANES>
            where
                crate::$vector<LANES>: LanesAtMost64,
                crate::$inner_ty<LANES>: LanesAtMost64,
            {
                /// Lanewise equality test: checks whether each lane equals the
                /// corresponding lane of `other`.
                #[inline]
                pub fn lanes_eq(self, other: Self) -> crate::$mask<LANES> {
                    // SAFETY: relies on the comparison intrinsic yielding only 0 or -1
                    // in every lane, which is what `from_int_unchecked` requires.
                    unsafe {
                        let lanes = crate::intrinsics::simd_eq(self, other);
                        crate::$inner_mask_ty::from_int_unchecked(lanes).into()
                    }
                }

                /// Lanewise inequality test: checks whether each lane differs from the
                /// corresponding lane of `other`.
                #[inline]
                pub fn lanes_ne(self, other: Self) -> crate::$mask<LANES> {
                    // SAFETY: relies on the comparison intrinsic yielding only 0 or -1
                    // in every lane, which is what `from_int_unchecked` requires.
                    unsafe {
                        let lanes = crate::intrinsics::simd_ne(self, other);
                        crate::$inner_mask_ty::from_int_unchecked(lanes).into()
                    }
                }

                /// Lanewise `<` test: checks whether each lane is less than the
                /// corresponding lane of `other`.
                #[inline]
                pub fn lanes_lt(self, other: Self) -> crate::$mask<LANES> {
                    // SAFETY: relies on the comparison intrinsic yielding only 0 or -1
                    // in every lane, which is what `from_int_unchecked` requires.
                    unsafe {
                        let lanes = crate::intrinsics::simd_lt(self, other);
                        crate::$inner_mask_ty::from_int_unchecked(lanes).into()
                    }
                }

                /// Lanewise `>` test: checks whether each lane is greater than the
                /// corresponding lane of `other`.
                #[inline]
                pub fn lanes_gt(self, other: Self) -> crate::$mask<LANES> {
                    // SAFETY: relies on the comparison intrinsic yielding only 0 or -1
                    // in every lane, which is what `from_int_unchecked` requires.
                    unsafe {
                        let lanes = crate::intrinsics::simd_gt(self, other);
                        crate::$inner_mask_ty::from_int_unchecked(lanes).into()
                    }
                }

                /// Lanewise `<=` test: checks whether each lane is less than or equal to
                /// the corresponding lane of `other`.
                #[inline]
                pub fn lanes_le(self, other: Self) -> crate::$mask<LANES> {
                    // SAFETY: relies on the comparison intrinsic yielding only 0 or -1
                    // in every lane, which is what `from_int_unchecked` requires.
                    unsafe {
                        let lanes = crate::intrinsics::simd_le(self, other);
                        crate::$inner_mask_ty::from_int_unchecked(lanes).into()
                    }
                }

                /// Lanewise `>=` test: checks whether each lane is greater than or equal
                /// to the corresponding lane of `other`.
                #[inline]
                pub fn lanes_ge(self, other: Self) -> crate::$mask<LANES> {
                    // SAFETY: relies on the comparison intrinsic yielding only 0 or -1
                    // in every lane, which is what `from_int_unchecked` requires.
                    unsafe {
                        let lanes = crate::intrinsics::simd_ge(self, other);
                        crate::$inner_mask_ty::from_int_unchecked(lanes).into()
                    }
                }
            }
        )*
    }
}
// Instantiates the lanewise comparison methods for every listed vector type.
// Each row reads `Vector => Mask (InnerMask, InnerInt)`. Note that the unsigned
// and float rows name the signed-integer mask/vector types of the matching
// width (e.g. `SimdU8` and `SimdI8` both use `SimdMask8`/`SimdI8`, and
// `SimdF32` uses `SimdMask32`/`SimdI32`).
implement_mask_ops! {
SimdI8 => Mask8 (SimdMask8, SimdI8),
SimdI16 => Mask16 (SimdMask16, SimdI16),
SimdI32 => Mask32 (SimdMask32, SimdI32),
SimdI64 => Mask64 (SimdMask64, SimdI64),
SimdI128 => Mask128 (SimdMask128, SimdI128),
SimdIsize => MaskSize (SimdMaskSize, SimdIsize),
SimdU8 => Mask8 (SimdMask8, SimdI8),
SimdU16 => Mask16 (SimdMask16, SimdI16),
SimdU32 => Mask32 (SimdMask32, SimdI32),
SimdU64 => Mask64 (SimdMask64, SimdI64),
SimdU128 => Mask128 (SimdMask128, SimdI128),
SimdUsize => MaskSize (SimdMaskSize, SimdIsize),
SimdF32 => Mask32 (SimdMask32, SimdI32),
SimdF64 => Mask64 (SimdMask64, SimdI64),
}

View file

@ -16,6 +16,7 @@ mod fmt;
mod intrinsics;
mod ops;
mod round;
mod comparisons;
mod math;

View file

@ -75,6 +75,25 @@ macro_rules! define_mask {
0
}
}
/// Creates a mask from an integer vector without checking its lanes.
///
/// The vector is stored as-is; no lane is inspected or modified.
///
/// # Safety
/// All lanes must be either 0 or -1.
#[inline]
pub unsafe fn from_int_unchecked(value: $type) -> Self {
// Direct wrap: the caller is responsible for the 0 / -1 lane invariant.
Self(value)
}
/// Creates a mask from an integer vector, going through the fallible
/// `TryInto` conversion from the integer vector type.
///
/// # Panics
/// Panics if any lane is not 0 or -1.
#[inline]
pub fn from_int(value: $type) -> Self {
    use core::convert::TryInto;
    // Run the checked conversion first, then unwrap so a rejected vector panics.
    let converted: Result<Self, _> = value.try_into();
    converted.unwrap()
}
}
impl<const $lanes: usize> core::convert::From<bool> for $name<$lanes>

View file

@ -360,73 +360,6 @@ define_opaque_mask! {
@bits crate::SimdIsize<LANES>
}
// Generates the lanewise comparison methods (`lanes_eq`, `lanes_ne`, `lanes_lt`,
// `lanes_gt`, `lanes_le`, `lanes_ge`) for each `Vector => Mask (InnerInt)` entry.
// `Vector` receives the methods, `Mask` is the returned mask type, and `InnerInt`
// is an integer vector type that must also satisfy `LanesAtMost64`.
//
// NOTE(review): every method takes `&self` / `&Self` and hands those references
// to the comparison intrinsic unchanged, then wraps the intrinsic's result
// directly in `$mask(...)`. Confirm that the intrinsics accept reference operands
// and can produce the mask's inner field type; neither is visible from here.
macro_rules! implement_mask_ops {
{ $($vector:ident => $mask:ident ($inner_ty:ident),)* } => {
$(
impl<const LANES: usize> crate::$vector<LANES>
where
crate::$vector<LANES>: LanesAtMost64,
crate::$inner_ty<LANES>: LanesAtMost64,
{
/// Test if each lane is equal to the corresponding lane in `other`.
#[inline]
pub fn lanes_eq(&self, other: &Self) -> $mask<LANES> {
// The mask is built with the tuple constructor straight from the intrinsic result.
unsafe { $mask(crate::intrinsics::simd_eq(self, other)) }
}
/// Test if each lane is not equal to the corresponding lane in `other`.
#[inline]
pub fn lanes_ne(&self, other: &Self) -> $mask<LANES> {
unsafe { $mask(crate::intrinsics::simd_ne(self, other)) }
}
/// Test if each lane is less than the corresponding lane in `other`.
#[inline]
pub fn lanes_lt(&self, other: &Self) -> $mask<LANES> {
unsafe { $mask(crate::intrinsics::simd_lt(self, other)) }
}
/// Test if each lane is greater than the corresponding lane in `other`.
#[inline]
pub fn lanes_gt(&self, other: &Self) -> $mask<LANES> {
unsafe { $mask(crate::intrinsics::simd_gt(self, other)) }
}
/// Test if each lane is less than or equal to the corresponding lane in `other`.
#[inline]
pub fn lanes_le(&self, other: &Self) -> $mask<LANES> {
unsafe { $mask(crate::intrinsics::simd_le(self, other)) }
}
/// Test if each lane is greater than or equal to the corresponding lane in `other`.
#[inline]
pub fn lanes_ge(&self, other: &Self) -> $mask<LANES> {
unsafe { $mask(crate::intrinsics::simd_ge(self, other)) }
}
}
)*
}
}
// Instantiates the lanewise comparison methods for every listed vector type.
// Each row reads `Vector => Mask (InnerInt)`; the unsigned and float rows name
// the signed-integer vector of the matching width (e.g. `SimdU8` uses `SimdI8`,
// `SimdF32` uses `SimdI32`).
implement_mask_ops! {
SimdI8 => Mask8 (SimdI8),
SimdI16 => Mask16 (SimdI16),
SimdI32 => Mask32 (SimdI32),
SimdI64 => Mask64 (SimdI64),
SimdI128 => Mask128 (SimdI128),
SimdIsize => MaskSize (SimdIsize),
SimdU8 => Mask8 (SimdI8),
SimdU16 => Mask16 (SimdI16),
SimdU32 => Mask32 (SimdI32),
SimdU64 => Mask64 (SimdI64),
SimdU128 => Mask128 (SimdI128),
SimdUsize => MaskSize (SimdIsize),
SimdF32 => Mask32 (SimdI32),
SimdF64 => Mask64 (SimdI64),
}
/// Vector of eight 8-bit masks (`Mask8` instantiated with 8 lanes).
pub type mask8x8 = Mask8<8>;

View file

@ -4,7 +4,7 @@
/// `$lanes` of float `$type`, which uses `$bits_ty` as its binary
/// representation. Called from `define_float_vector!`.
macro_rules! impl_float_vector {
{ $name:ident, $type:ty, $bits_ty:ident } => {
{ $name:ident, $type:ty, $bits_ty:ident, $mask_ty:ident, $mask_impl_ty:ident } => {
impl_vector! { $name, $type }
impl<const LANES: usize> $name<LANES>
@ -36,6 +36,60 @@ macro_rules! impl_float_vector {
Self::from_bits(self.to_bits() & no_sign)
}
}
impl<const LANES: usize> $name<LANES>
where
    Self: crate::LanesAtMost64,
    crate::$bits_ty<LANES>: crate::LanesAtMost64,
    crate::$mask_impl_ty<LANES>: crate::LanesAtMost64,
{
    /// Returns true for each lane if it has a positive sign, including
    /// `+0.0`, `NaN`s with positive sign bit and positive infinity.
    #[inline]
    pub fn is_sign_positive(self) -> crate::$mask_ty<LANES> {
        // Exactly the lanes whose sign bit is clear.
        !self.is_sign_negative()
    }

    /// Returns true for each lane if it has a negative sign, including
    /// `-0.0`, `NaN`s with negative sign bit and negative infinity.
    #[inline]
    pub fn is_sign_negative(self) -> crate::$mask_ty<LANES> {
        // For the unsigned bits type, `(!0 >> 1) + 1` has only the most
        // significant bit set, i.e. the sign bit of the float representation.
        let sign_bits = self.to_bits() & crate::$bits_ty::splat((!0 >> 1) + 1);
        // A non-zero result means the sign bit was set, so the lane is negative.
        sign_bits.lanes_gt(crate::$bits_ty::splat(0))
    }

    /// Returns true for each lane if its value is `NaN`.
    #[inline]
    pub fn is_nan(self) -> crate::$mask_ty<LANES> {
        // `NaN` is the only value that compares unequal to itself.
        self.lanes_ne(self)
    }

    /// Returns true for each lane if its value is positive infinity or negative infinity.
    #[inline]
    pub fn is_infinite(self) -> crate::$mask_ty<LANES> {
        self.abs().lanes_eq(Self::splat(<$type>::INFINITY))
    }

    /// Returns true for each lane if its value is neither infinite nor `NaN`.
    #[inline]
    pub fn is_finite(self) -> crate::$mask_ty<LANES> {
        // Comparisons with `NaN` are false, so `NaN` lanes are excluded as well.
        self.abs().lanes_lt(Self::splat(<$type>::INFINITY))
    }

    /// Returns true for each lane if its value is subnormal.
    #[inline]
    pub fn is_subnormal(self) -> crate::$mask_ty<LANES> {
        // The bit pattern of infinity has every exponent bit set and nothing else,
        // so it doubles as the exponent mask.
        let exponent_mask = Self::splat(<$type>::INFINITY).to_bits();
        // Subnormal: the value is non-zero but its exponent field is all zeros.
        let is_nonzero = self.abs().lanes_ne(Self::splat(0.0));
        let has_zero_exponent =
            (self.to_bits() & exponent_mask).lanes_eq(crate::$bits_ty::splat(0));
        is_nonzero & has_zero_exponent
    }

    /// Returns true for each lane if its value is neither zero, infinite,
    /// subnormal, nor `NaN`.
    #[inline]
    pub fn is_normal(self) -> crate::$mask_ty<LANES> {
        let is_zero = self.abs().lanes_eq(Self::splat(0.0));
        !(is_zero | self.is_nan() | self.is_subnormal() | self.is_infinite())
    }
}
};
}
@ -46,7 +100,7 @@ pub struct SimdF32<const LANES: usize>([f32; LANES])
where
Self: crate::LanesAtMost64;
impl_float_vector! { SimdF32, f32, SimdU32 }
impl_float_vector! { SimdF32, f32, SimdU32, Mask32, SimdI32 }
from_transmute_x86! { unsafe f32x4 => __m128 }
from_transmute_x86! { unsafe f32x8 => __m256 }
@ -58,7 +112,7 @@ pub struct SimdF64<const LANES: usize>([f64; LANES])
where
Self: crate::LanesAtMost64;
impl_float_vector! { SimdF64, f64, SimdU64 }
impl_float_vector! { SimdF64, f64, SimdU64, Mask64, SimdI64 }
from_transmute_x86! { unsafe f64x2 => __m128d }
from_transmute_x86! { unsafe f64x4 => __m256d }