From c43c8d25a8c4f7035d4265e672f41848effbe615 Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Mon, 18 Aug 2025 01:31:25 -0400 Subject: [PATCH] Check some float ops approximately --- Cargo.lock | 10 +++ crates/std_float/tests/float.rs | 29 +++++++- crates/test_helpers/Cargo.toml | 1 + crates/test_helpers/src/approxeq.rs | 110 ++++++++++++++++++++++++++++ crates/test_helpers/src/lib.rs | 76 +++++++++++++++++++ 5 files changed, 222 insertions(+), 4 deletions(-) create mode 100644 crates/test_helpers/src/approxeq.rs diff --git a/Cargo.lock b/Cargo.lock index d7accf71ab69..5a5f0d8907ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,6 +52,15 @@ dependencies = [ "wasm-bindgen-test", ] +[[package]] +name = "float-cmp" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b09cf3155332e944990140d967ff5eceb70df778b34f77d8075db46e4704e6d8" +dependencies = [ + "num-traits", +] + [[package]] name = "js-sys" version = "0.3.77" @@ -225,6 +234,7 @@ dependencies = [ name = "test_helpers" version = "0.1.0" dependencies = [ + "float-cmp", "proptest", ] diff --git a/crates/std_float/tests/float.rs b/crates/std_float/tests/float.rs index c66c968f8c66..c608ba49564e 100644 --- a/crates/std_float/tests/float.rs +++ b/crates/std_float/tests/float.rs @@ -16,15 +16,33 @@ macro_rules! unary_test { } } -macro_rules! binary_test { +macro_rules! unary_approx_test { { $scalar:tt, $($func:tt),+ } => { test_helpers::test_lanes! { $( fn $func() { - test_helpers::test_binary_elementwise( + test_helpers::test_unary_elementwise_approx( + &core_simd::simd::Simd::<$scalar, LANES>::$func, + &$scalar::$func, + &|_| true, + 8, + ) + } + )* + } + } +} + +macro_rules! binary_approx_test { + { $scalar:tt, $($func:tt),+ } => { + test_helpers::test_lanes! { + $( + fn $func() { + test_helpers::test_binary_elementwise_approx( &core_simd::simd::Simd::<$scalar, LANES>::$func, &$scalar::$func, &|_, _| true, + 16, ) } )* @@ -53,10 +71,13 @@ macro_rules! impl_tests { mod $scalar { use std_float::StdFloat; - unary_test! { $scalar, sqrt, sin, cos, exp, exp2, ln, log2, log10, ceil, floor, round, trunc } - binary_test! { $scalar, log } + unary_test! { $scalar, sqrt, ceil, floor, round, trunc } ternary_test! { $scalar, mul_add } + // https://github.com/rust-lang/miri/issues/3555 + unary_approx_test! { $scalar, sin, cos, exp, exp2, ln, log2, log10 } + binary_approx_test! { $scalar, log } + test_helpers::test_lanes! { fn fract() { test_helpers::test_unary_elementwise_flush_subnormals( diff --git a/crates/test_helpers/Cargo.toml b/crates/test_helpers/Cargo.toml index a5359b9abc84..408bb04c7aa4 100644 --- a/crates/test_helpers/Cargo.toml +++ b/crates/test_helpers/Cargo.toml @@ -6,3 +6,4 @@ publish = false [dependencies] proptest = { version = "0.10", default-features = false, features = ["alloc"] } +float-cmp = "0.10" diff --git a/crates/test_helpers/src/approxeq.rs b/crates/test_helpers/src/approxeq.rs new file mode 100644 index 000000000000..57b43a16bc6f --- /dev/null +++ b/crates/test_helpers/src/approxeq.rs @@ -0,0 +1,110 @@ +//! Compare numeric types approximately. + +use float_cmp::Ulps; + +pub trait ApproxEq { + fn approxeq(&self, other: &Self, _ulps: i64) -> bool; + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result; +} + +impl ApproxEq for bool { + fn approxeq(&self, other: &Self, _ulps: i64) -> bool { + self == other + } + + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(f, "{:?}", self) + } +} + +macro_rules! impl_integer_approxeq { + { $($type:ty),* } => { + $( + impl ApproxEq for $type { + fn approxeq(&self, other: &Self, _ulps: i64) -> bool { + self == other + } + + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(f, "{:?} ({:x})", self, self) + } + } + )* + }; +} + +impl_integer_approxeq! { u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize } + +macro_rules! impl_float_approxeq { + { $($type:ty),* } => { + $( + impl ApproxEq for $type { + fn approxeq(&self, other: &Self, ulps: i64) -> bool { + if self.is_nan() && other.is_nan() { + true + } else { + (self.ulps(other) as i64).abs() <= ulps + } + } + + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(f, "{:?} ({:x})", self, self.to_bits()) + } + } + )* + }; +} + +impl_float_approxeq! { f32, f64 } + +impl ApproxEq for [T; N] { + fn approxeq(&self, other: &Self, ulps: i64) -> bool { + self.iter() + .zip(other.iter()) + .fold(true, |value, (left, right)| { + value && left.approxeq(right, ulps) + }) + } + + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + #[repr(transparent)] + struct Wrapper<'a, T: ApproxEq>(&'a T); + + impl core::fmt::Debug for Wrapper<'_, T> { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + self.0.fmt(f) + } + } + + f.debug_list() + .entries(self.iter().map(|x| Wrapper(x))) + .finish() + } +} + +#[doc(hidden)] +pub struct ApproxEqWrapper<'a, T>(pub &'a T, pub i64); + +impl PartialEq for ApproxEqWrapper<'_, T> { + fn eq(&self, other: &T) -> bool { + self.0.approxeq(other, self.1) + } +} + +impl core::fmt::Debug for ApproxEqWrapper<'_, T> { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + self.0.fmt(f) + } +} + +#[macro_export] +macro_rules! prop_assert_approxeq { + { $a:expr, $b:expr, $ulps:expr $(,)? } => { + { + use $crate::approxeq::ApproxEqWrapper; + let a = $a; + let b = $b; + proptest::prop_assert_eq!(ApproxEqWrapper(&a, $ulps), b); + } + }; +} diff --git a/crates/test_helpers/src/lib.rs b/crates/test_helpers/src/lib.rs index 197c920e11ea..35401a9ddb40 100644 --- a/crates/test_helpers/src/lib.rs +++ b/crates/test_helpers/src/lib.rs @@ -12,6 +12,9 @@ pub mod wasm; #[macro_use] pub mod biteq; +#[macro_use] +pub mod approxeq; + pub mod subnormals; use subnormals::FlushSubnormals; @@ -185,6 +188,41 @@ pub fn test_unary_elementwise( + fv: &dyn Fn(Vector) -> VectorResult, + fs: &dyn Fn(Scalar) -> ScalarResult, + check: &dyn Fn([Scalar; LANES]) -> bool, + ulps: i64, +) where + Scalar: Copy + core::fmt::Debug + DefaultStrategy, + ScalarResult: Copy + approxeq::ApproxEq + core::fmt::Debug + DefaultStrategy, + Vector: Into<[Scalar; LANES]> + From<[Scalar; LANES]> + Copy, + VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy, +{ + test_1(&|x: [Scalar; LANES]| { + proptest::prop_assume!(check(x)); + let result_1: [ScalarResult; LANES] = fv(x.into()).into(); + let result_2: [ScalarResult; LANES] = x + .iter() + .copied() + .map(fs) + .collect::>() + .try_into() + .unwrap(); + crate::prop_assert_approxeq!(result_1, result_2, ulps); + Ok(()) + }); +} + /// Test a unary vector function against a unary scalar function, applied elementwise. /// /// Where subnormals are flushed, use approximate equality. @@ -290,6 +328,44 @@ pub fn test_binary_elementwise< }); } +/// Test a binary vector function against a binary scalar function, applied elementwise. +pub fn test_binary_elementwise_approx< + Scalar1, + Scalar2, + ScalarResult, + Vector1, + Vector2, + VectorResult, + const LANES: usize, +>( + fv: &dyn Fn(Vector1, Vector2) -> VectorResult, + fs: &dyn Fn(Scalar1, Scalar2) -> ScalarResult, + check: &dyn Fn([Scalar1; LANES], [Scalar2; LANES]) -> bool, + ulps: i64, +) where + Scalar1: Copy + core::fmt::Debug + DefaultStrategy, + Scalar2: Copy + core::fmt::Debug + DefaultStrategy, + ScalarResult: Copy + approxeq::ApproxEq + core::fmt::Debug + DefaultStrategy, + Vector1: Into<[Scalar1; LANES]> + From<[Scalar1; LANES]> + Copy, + Vector2: Into<[Scalar2; LANES]> + From<[Scalar2; LANES]> + Copy, + VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy, +{ + test_2(&|x: [Scalar1; LANES], y: [Scalar2; LANES]| { + proptest::prop_assume!(check(x, y)); + let result_1: [ScalarResult; LANES] = fv(x.into(), y.into()).into(); + let result_2: [ScalarResult; LANES] = x + .iter() + .copied() + .zip(y.iter().copied()) + .map(|(x, y)| fs(x, y)) + .collect::>() + .try_into() + .unwrap(); + crate::prop_assert_approxeq!(result_1, result_2, ulps); + Ok(()) + }); +} + /// Test a binary vector function against a binary scalar function, applied elementwise. /// /// Where subnormals are flushed, use approximate equality.