diff --git a/crates/core_simd/src/ops.rs b/crates/core_simd/src/ops.rs index e6d7e695391c..6cfc8f80b53c 100644 --- a/crates/core_simd/src/ops.rs +++ b/crates/core_simd/src/ops.rs @@ -31,27 +31,10 @@ where } } -macro_rules! unsafe_base_op { - ($(impl $op:ident for Simd<$scalar:ty, LANES> { - fn $call:ident(self, rhs: Self) -> Self::Output { - unsafe{ $simd_call:ident } - } - })*) => { - $(impl $op for Simd<$scalar, LANES> - where - $scalar: SimdElement, - LaneCount: SupportedLaneCount, - { - type Output = Self; - - #[inline] - #[must_use = "operator returns a new vector without mutating the inputs"] - fn $call(self, rhs: Self) -> Self::Output { - unsafe { $crate::intrinsics::$simd_call(self, rhs) } - } - } - )* - } +macro_rules! unsafe_base { + ($lhs:ident, $rhs:ident, {$simd_call:ident}, $($_:tt)*) => { + unsafe { $crate::intrinsics::$simd_call($lhs, $rhs) } + }; } /// SAFETY: This macro should not be used for anything except Shl or Shr, and passed the appropriate shift intrinsic. @@ -64,388 +47,191 @@ macro_rules! unsafe_base_op { // FIXME: Consider implementing this in cg_llvm instead? // cg_clif defaults to this, and scalar MIR shifts also default to wrapping macro_rules! wrap_bitshift { - ($(impl $op:ident for Simd<$int:ty, LANES> { - fn $call:ident(self, rhs: Self) -> Self::Output { - unsafe { $simd_call:ident } + ($lhs:ident, $rhs:ident, {$simd_call:ident}, $int:ident) => { + unsafe { + $crate::intrinsics::$simd_call($lhs, $rhs.bitand(Simd::splat(<$int>::BITS as $int - 1))) } - })*) => { - $(impl $op for Simd<$int, LANES> - where - $int: SimdElement, - LaneCount: SupportedLaneCount, - { - type Output = Self; - - #[inline] - #[must_use = "operator returns a new vector without mutating the inputs"] - fn $call(self, rhs: Self) -> Self::Output { - unsafe { - $crate::intrinsics::$simd_call(self, rhs.bitand(Simd::splat(<$int>::BITS as $int - 1))) - } - } - })* }; } -macro_rules! bitops { - ($(impl BitOps for Simd<$int:ty, LANES> { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - })*) => { - $( - unsafe_base_op!{ - impl BitAnd for Simd<$int, LANES> { - fn bitand(self, rhs: Self) -> Self::Output { - unsafe { simd_and } - } - } - - impl BitOr for Simd<$int, LANES> { - fn bitor(self, rhs: Self) -> Self::Output { - unsafe { simd_or } - } - } - - impl BitXor for Simd<$int, LANES> { - fn bitxor(self, rhs: Self) -> Self::Output { - unsafe { simd_xor } - } - } - } - wrap_bitshift! { - impl Shl for Simd<$int, LANES> { - fn shl(self, rhs: Self) -> Self::Output { - unsafe { simd_shl } - } - } - - impl Shr for Simd<$int, LANES> { - fn shr(self, rhs: Self) -> Self::Output { - // This automatically monomorphizes to lshr or ashr, depending, - // so it's fine to use it for both UInts and SInts. - unsafe { simd_shr } - } - } - } - )* - }; -} - -// Integers can always accept bitand, bitor, and bitxor. -// The only question is how to handle shifts >= ::BITS? -// Our current solution uses wrapping logic. -bitops! { - impl BitOps for Simd { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - } - - impl BitOps for Simd { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - } - - impl BitOps for Simd { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - } - - impl BitOps for Simd { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - } - - impl BitOps for Simd { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - } - - impl BitOps for Simd { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - } - - impl BitOps for Simd { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - } - - impl BitOps for Simd { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - } - - impl BitOps for Simd { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - } - - impl BitOps for Simd { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - } -} - -macro_rules! float_arith { - ($(impl FloatArith for Simd<$float:ty, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; - })*) => { - $( - unsafe_base_op!{ - impl Add for Simd<$float, LANES> { - fn add(self, rhs: Self) -> Self::Output { - unsafe { simd_add } - } - } - - impl Mul for Simd<$float, LANES> { - fn mul(self, rhs: Self) -> Self::Output { - unsafe { simd_mul } - } - } - - impl Sub for Simd<$float, LANES> { - fn sub(self, rhs: Self) -> Self::Output { - unsafe { simd_sub } - } - } - - impl Div for Simd<$float, LANES> { - fn div(self, rhs: Self) -> Self::Output { - unsafe { simd_div } - } - } - - impl Rem for Simd<$float, LANES> { - fn rem(self, rhs: Self) -> Self::Output { - unsafe { simd_rem } - } - } - } - )* - }; -} - -// We don't need any special precautions here: -// Floats always accept arithmetic ops, but may become NaN. -float_arith! { - impl FloatArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; - } - - impl FloatArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; - } -} - // Division by zero is poison, according to LLVM. // So is dividing the MIN value of a signed integer by -1, // since that would return MAX + 1. // FIXME: Rust allows ::MIN / -1, // so we should probably figure out how to make that safe. macro_rules! int_divrem_guard { - ($(impl $op:ident for Simd<$sint:ty, LANES> { - const PANIC_ZERO: &'static str = $zero:literal; - const PANIC_OVERFLOW: &'static str = $overflow:literal; - fn $call:ident { - unsafe { $simd_call:ident } - } - })*) => { - $(impl $op for Simd<$sint, LANES> - where - $sint: SimdElement, - LaneCount: SupportedLaneCount, + ( $lhs:ident, + $rhs:ident, + { const PANIC_ZERO: &'static str = $zero:literal; + const PANIC_OVERFLOW: &'static str = $overflow:literal; + $simd_call:ident + }, + $int:ident ) => { + if $rhs.lanes_eq(Simd::splat(0)).any() { + panic!($zero); + } else if <$int>::MIN != 0 + && $lhs.lanes_eq(Simd::splat(<$int>::MIN)) & $rhs.lanes_eq(Simd::splat(-1 as _)) + != Mask::splat(false) { - type Output = Self; - #[inline] - #[must_use = "operator returns a new vector without mutating the inputs"] - fn $call(self, rhs: Self) -> Self::Output { - if rhs.lanes_eq(Simd::splat(0)).any() { - panic!("attempt to calculate the remainder with a divisor of zero"); - } else if <$sint>::MIN != 0 && self.lanes_eq(Simd::splat(<$sint>::MIN)) & rhs.lanes_eq(Simd::splat(-1 as _)) - != Mask::splat(false) - { - panic!("attempt to calculate the remainder with overflow"); - } else { - unsafe { $crate::intrinsics::$simd_call(self, rhs) } - } - } - })* + panic!($overflow); + } else { + unsafe { $crate::intrinsics::$simd_call($lhs, $rhs) } + } }; } -macro_rules! int_arith { - ($(impl IntArith for Simd<$sint:ty, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; - })*) => { - $( - unsafe_base_op!{ - impl Add for Simd<$sint, LANES> { - fn add(self, rhs: Self) -> Self::Output { - unsafe { simd_add } - } - } +macro_rules! for_base_types { + ( T = ($($scalar:ident),*); + type Lhs = Simd; + type Rhs = Simd; + type Output = $out:ty; - impl Mul for Simd<$sint, LANES> { - fn mul(self, rhs: Self) -> Self::Output { - unsafe { simd_mul } - } - } + impl $op:ident::$call:ident { + $macro_impl:ident $inner:tt + }) => { + $( + impl $op for Simd<$scalar, N> + where + $scalar: SimdElement, + LaneCount: SupportedLaneCount, + { + type Output = $out; - impl Sub for Simd<$sint, LANES> { - fn sub(self, rhs: Self) -> Self::Output { - unsafe { simd_sub } - } - } + #[inline] + #[must_use = "operator returns a new vector without mutating the inputs"] + fn $call(self, rhs: Self) -> Self::Output { + $macro_impl!(self, rhs, $inner, $scalar) + } + })* + } +} + +// A "TokenTree muncher": takes a set of scalar types `T = {};` +// type parameters for the ops it implements, `Op::fn` names, +// and a macro that expands into an expr, substituting in an intrinsic. +// It passes that to for_base_types, which expands an impl for the types, +// using the expanded expr in the function, and recurses with itself. +// +// tl;dr impls a set of ops::{Traits} for a set of types +macro_rules! for_base_ops { + ( + T = $types:tt; + type Lhs = Simd; + type Rhs = Simd; + type Output = $out:ident; + impl $op:ident::$call:ident + $inner:tt + $($rest:tt)* + ) => { + for_base_types! { + T = $types; + type Lhs = Simd; + type Rhs = Simd; + type Output = $out; + impl $op::$call + $inner } - - int_divrem_guard!{ - impl Div for Simd<$sint, LANES> { - const PANIC_ZERO: &'static str = "attempt to divide by zero"; - const PANIC_OVERFLOW: &'static str = "attempt to divide with overflow"; - fn div { - unsafe { simd_div } - } - } - - impl Rem for Simd<$sint, LANES> { - const PANIC_ZERO: &'static str = "attempt to calculate the remainder with a divisor of zero"; - const PANIC_OVERFLOW: &'static str = "attempt to calculate the remainder with overflow"; - fn rem { - unsafe { simd_rem } - } - } - })* + for_base_ops! { + T = $types; + type Lhs = Simd; + type Rhs = Simd; + type Output = $out; + $($rest)* + } + }; + ($($done:tt)*) => { + // Done. } } -int_arith! { - impl IntArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; +// Integers can always accept add, mul, sub, bitand, bitor, and bitxor. +// For all of these operations, simd_* intrinsics apply wrapping logic. +for_base_ops! { + T = (i8, i16, i32, i64, isize, u8, u16, u32, u64, usize); + type Lhs = Simd; + type Rhs = Simd; + type Output = Self; + + impl Add::add { + unsafe_base { simd_add } } - impl IntArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl Mul::mul { + unsafe_base { simd_mul } } - impl IntArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl Sub::sub { + unsafe_base { simd_sub } } - impl IntArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl BitAnd::bitand { + unsafe_base { simd_and } } - impl IntArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl BitOr::bitor { + unsafe_base { simd_or } } - impl IntArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl BitXor::bitxor { + unsafe_base { simd_xor } } - impl IntArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl Div::div { + int_divrem_guard { + const PANIC_ZERO: &'static str = "attempt to divide by zero"; + const PANIC_OVERFLOW: &'static str = "attempt to divide with overflow"; + simd_div + } } - impl IntArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl Rem::rem { + int_divrem_guard { + const PANIC_ZERO: &'static str = "attempt to calculate the remainder with a divisor of zero"; + const PANIC_OVERFLOW: &'static str = "attempt to calculate the remainder with overflow"; + simd_rem + } } - impl IntArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + // The only question is how to handle shifts >= ::BITS? + // Our current solution uses wrapping logic. + impl Shl::shl { + wrap_bitshift { simd_shl } } - impl IntArith for Simd { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl Shr::shr { + wrap_bitshift { + // This automatically monomorphizes to lshr or ashr, depending, + // so it's fine to use it for both UInts and SInts. + simd_shr + } + } +} + +// We don't need any special precautions here: +// Floats always accept arithmetic ops, but may become NaN. +for_base_ops! { + T = (f32, f64); + type Lhs = Simd; + type Rhs = Simd; + type Output = Self; + + impl Add::add { + unsafe_base { simd_add } + } + + impl Mul::mul { + unsafe_base { simd_mul } + } + + impl Sub::sub { + unsafe_base { simd_sub } + } + + impl Div::div { + unsafe_base { simd_div } + } + + impl Rem::rem { + unsafe_base { simd_rem } } }