Add mul_add

This commit is contained in:
Caleb Zulawski 2021-06-13 19:59:17 +00:00
parent 74e6262ce4
commit f102de7c8b
4 changed files with 64 additions and 0 deletions

View file

@ -49,6 +49,9 @@ extern "platform-intrinsic" {
/// fsqrt
pub(crate) fn simd_fsqrt<T>(x: T) -> T;
/// fma
///
/// Takes three operands of one type `T` and returns a `T`. The float vector
/// `mul_add` method calls this as `simd_fma(self, a, b)` and documents the
/// result as `(self * a) + b` with a single rounding, i.e. `(x * y) + z` here.
/// NOTE(review): presumably only valid for float vector types — confirm
/// against the compiler's intrinsic definition.
pub(crate) fn simd_fma<T>(x: T, y: T, z: T) -> T;
pub(crate) fn simd_eq<T, U>(x: T, y: T) -> U;
pub(crate) fn simd_ne<T, U>(x: T, y: T) -> U;
pub(crate) fn simd_lt<T, U>(x: T, y: T) -> U;

View file

@ -36,6 +36,18 @@ macro_rules! impl_float_vector {
unsafe { crate::intrinsics::simd_fabs(self) }
}
/// Fused multiply-add. Computes `(self * a) + b` with only one rounding error,
/// yielding a more accurate result than an unfused multiply-add.
///
/// The operation is applied to each lane independently: lane `i` of the result
/// corresponds to the scalar `mul_add` of lane `i` of `self`, `a` and `b`.
///
/// Using `mul_add` *may* be more performant than an unfused multiply-add if the target
/// architecture has a dedicated `fma` CPU instruction. However, this is not always
/// true, and will be heavily dependent on designing algorithms with specific target
/// hardware in mind.
///
/// `self` is not modified; the fused result is returned as a new vector.
#[inline]
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn mul_add(self, a: Self, b: Self) -> Self {
    // SAFETY: all three operands have the same type `Self`, matching the
    // intrinsic's `fn(T, T, T) -> T` signature.
    // NOTE(review): this relies on `Self` being a float SIMD vector, which is
    // what `impl_float_vector` is expected to generate — confirm.
    unsafe { crate::intrinsics::simd_fma(self, a, b) }
}
/// Produces a vector where every lane has the square root value
/// of the equivalently-indexed lane in `self`
#[inline]

View file

@ -435,6 +435,14 @@ macro_rules! impl_float_tests {
)
}
// Checks the vector `mul_add` against the scalar `mul_add`: the helper runs the
// vector function on whole vectors and the scalar function on each lane, then
// compares the two results.
fn mul_add<const LANES: usize>() {
    test_helpers::test_ternary_elementwise(
        // Vector implementation under test.
        &Vector::<LANES>::mul_add,
        // Scalar reference applied to each lane.
        &Scalar::mul_add,
        // No precondition: every input triple is accepted.
        &|_, _, _| true,
    )
}
fn sqrt<const LANES: usize>() {
test_helpers::test_unary_elementwise(
&Vector::<LANES>::sqrt,

View file

@ -278,6 +278,47 @@ pub fn test_binary_scalar_lhs_elementwise<
});
}
/// Checks that a three-argument vector function agrees, lane by lane, with a
/// three-argument scalar function.
///
/// * `fv` - the vector function under test.
/// * `fs` - the scalar reference, called once per lane.
/// * `check` - precondition on the three input arrays; cases for which it returns
///   `false` are passed to `proptest::prop_assume!` and therefore not asserted on.
///
/// The inputs arrive as arrays via `test_3`, are converted into vectors for `fv`,
/// and the vector result is converted back into an array before being compared
/// with the scalar results using `prop_assert_biteq!`.
#[inline(never)]
pub fn test_ternary_elementwise<
    Scalar1,
    Scalar2,
    Scalar3,
    ScalarResult,
    Vector1,
    Vector2,
    Vector3,
    VectorResult,
    const LANES: usize,
>(
    fv: &dyn Fn(Vector1, Vector2, Vector3) -> VectorResult,
    fs: &dyn Fn(Scalar1, Scalar2, Scalar3) -> ScalarResult,
    check: &dyn Fn([Scalar1; LANES], [Scalar2; LANES], [Scalar3; LANES]) -> bool,
) where
    Scalar1: Copy + Default + core::fmt::Debug + DefaultStrategy,
    Scalar2: Copy + Default + core::fmt::Debug + DefaultStrategy,
    Scalar3: Copy + Default + core::fmt::Debug + DefaultStrategy,
    ScalarResult: Copy + Default + biteq::BitEq + core::fmt::Debug + DefaultStrategy,
    Vector1: Into<[Scalar1; LANES]> + From<[Scalar1; LANES]> + Copy,
    Vector2: Into<[Scalar2; LANES]> + From<[Scalar2; LANES]> + Copy,
    Vector3: Into<[Scalar3; LANES]> + From<[Scalar3; LANES]> + Copy,
    VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy,
{
    test_3(&|x: [Scalar1; LANES], y: [Scalar2; LANES], z: [Scalar3; LANES]| {
        proptest::prop_assume!(check(x, y, z));

        // Run the vector function first, on the whole input at once.
        let vector_lanes: [ScalarResult; LANES] = fv(x.into(), y.into(), z.into()).into();

        // Then build the reference result one lane at a time, in lane order.
        let mut scalar_lanes = [ScalarResult::default(); LANES];
        let inputs = x.iter().zip(y.iter()).zip(z.iter());
        for (((&a, &b), &c), slot) in inputs.zip(scalar_lanes.iter_mut()) {
            *slot = fs(a, b, c);
        }

        crate::prop_assert_biteq!(vector_lanes, scalar_lanes);
        Ok(())
    });
}
/// Expand a const-generic test into separate tests for each possible lane count.
#[macro_export]
macro_rules! test_lanes {