From f102de7c8b2f59bcdc8f27dfe42a94725c91fd36 Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Sun, 13 Jun 2021 19:59:17 +0000 Subject: [PATCH] Add mul_add --- crates/core_simd/src/intrinsics.rs | 3 ++ crates/core_simd/src/vector/float.rs | 12 ++++++++ crates/core_simd/tests/ops_macros.rs | 8 ++++++ crates/test_helpers/src/lib.rs | 41 ++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+) diff --git a/crates/core_simd/src/intrinsics.rs b/crates/core_simd/src/intrinsics.rs index 7adf4c24e104..3983beb82ecf 100644 --- a/crates/core_simd/src/intrinsics.rs +++ b/crates/core_simd/src/intrinsics.rs @@ -49,6 +49,9 @@ extern "platform-intrinsic" { /// fsqrt pub(crate) fn simd_fsqrt(x: T) -> T; + /// fma + pub(crate) fn simd_fma(x: T, y: T, z: T) -> T; + pub(crate) fn simd_eq(x: T, y: T) -> U; pub(crate) fn simd_ne(x: T, y: T) -> U; pub(crate) fn simd_lt(x: T, y: T) -> U; diff --git a/crates/core_simd/src/vector/float.rs b/crates/core_simd/src/vector/float.rs index 7061b9b06748..4f0888f29f96 100644 --- a/crates/core_simd/src/vector/float.rs +++ b/crates/core_simd/src/vector/float.rs @@ -36,6 +36,18 @@ macro_rules! impl_float_vector { unsafe { crate::intrinsics::simd_fabs(self) } } + /// Fused multiply-add. Computes `(self * a) + b` with only one rounding error, + /// yielding a more accurate result than an unfused multiply-add. + /// + /// Using `mul_add` *may* be more performant than an unfused multiply-add if the target + /// architecture has a dedicated `fma` CPU instruction. However, this is not always + /// true, and will be heavily dependent on designing algorithms with specific target + /// hardware in mind. + #[inline] + pub fn mul_add(self, a: Self, b: Self) -> Self { + unsafe { crate::intrinsics::simd_fma(self, a, b) } + } + /// Produces a vector where every lane has the square root value /// of the equivalently-indexed lane in `self` #[inline] diff --git a/crates/core_simd/tests/ops_macros.rs b/crates/core_simd/tests/ops_macros.rs index 8ef2edc8370a..4057f33d4470 100644 --- a/crates/core_simd/tests/ops_macros.rs +++ b/crates/core_simd/tests/ops_macros.rs @@ -435,6 +435,14 @@ macro_rules! impl_float_tests { ) } + fn mul_add() { + test_helpers::test_ternary_elementwise( + &Vector::::mul_add, + &Scalar::mul_add, + &|_, _, _| true, + ) + } + fn sqrt() { test_helpers::test_unary_elementwise( &Vector::::sqrt, diff --git a/crates/test_helpers/src/lib.rs b/crates/test_helpers/src/lib.rs index ff6d30a1afb7..4f2380b8e5ba 100644 --- a/crates/test_helpers/src/lib.rs +++ b/crates/test_helpers/src/lib.rs @@ -278,6 +278,47 @@ pub fn test_binary_scalar_lhs_elementwise< }); } +/// Test a ternary vector function against a ternary scalar function, applied elementwise. +#[inline(never)] +pub fn test_ternary_elementwise< + Scalar1, + Scalar2, + Scalar3, + ScalarResult, + Vector1, + Vector2, + Vector3, + VectorResult, + const LANES: usize, +>( + fv: &dyn Fn(Vector1, Vector2, Vector3) -> VectorResult, + fs: &dyn Fn(Scalar1, Scalar2, Scalar3) -> ScalarResult, + check: &dyn Fn([Scalar1; LANES], [Scalar2; LANES], [Scalar3; LANES]) -> bool, +) where + Scalar1: Copy + Default + core::fmt::Debug + DefaultStrategy, + Scalar2: Copy + Default + core::fmt::Debug + DefaultStrategy, + Scalar3: Copy + Default + core::fmt::Debug + DefaultStrategy, + ScalarResult: Copy + Default + biteq::BitEq + core::fmt::Debug + DefaultStrategy, + Vector1: Into<[Scalar1; LANES]> + From<[Scalar1; LANES]> + Copy, + Vector2: Into<[Scalar2; LANES]> + From<[Scalar2; LANES]> + Copy, + Vector3: Into<[Scalar3; LANES]> + From<[Scalar3; LANES]> + Copy, + VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy, +{ + test_3(&|x: [Scalar1; LANES], y: [Scalar2; LANES], z: [Scalar3; LANES]| { + proptest::prop_assume!(check(x, y, z)); + let result_1: [ScalarResult; LANES] = fv(x.into(), y.into(), z.into()).into(); + let result_2: [ScalarResult; LANES] = { + let mut result = [ScalarResult::default(); LANES]; + for ((i1, (i2, i3)), o) in x.iter().zip(y.iter().zip(z.iter())).zip(result.iter_mut()) { + *o = fs(*i1, *i2, *i3); + } + result + }; + crate::prop_assert_biteq!(result_1, result_2); + Ok(()) + }); +} + /// Expand a const-generic test into separate tests for each possible lane count. #[macro_export] macro_rules! test_lanes {