Check some float ops approximately

This commit is contained in:
Caleb Zulawski 2025-08-18 01:31:25 -04:00
parent 323484c827
commit c43c8d25a8
5 changed files with 222 additions and 4 deletions

10
Cargo.lock generated
View file

@ -52,6 +52,15 @@ dependencies = [
"wasm-bindgen-test",
]
[[package]]
name = "float-cmp"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b09cf3155332e944990140d967ff5eceb70df778b34f77d8075db46e4704e6d8"
dependencies = [
"num-traits",
]
[[package]]
name = "js-sys"
version = "0.3.77"
@ -225,6 +234,7 @@ dependencies = [
name = "test_helpers"
version = "0.1.0"
dependencies = [
"float-cmp",
"proptest",
]

View file

@ -16,15 +16,33 @@ macro_rules! unary_test {
}
}
macro_rules! binary_test {
macro_rules! unary_approx_test {
{ $scalar:tt, $($func:tt),+ } => {
test_helpers::test_lanes! {
$(
fn $func<const LANES: usize>() {
test_helpers::test_binary_elementwise(
test_helpers::test_unary_elementwise_approx(
&core_simd::simd::Simd::<$scalar, LANES>::$func,
&$scalar::$func,
&|_| true,
8,
)
}
)*
}
}
}
macro_rules! binary_approx_test {
{ $scalar:tt, $($func:tt),+ } => {
test_helpers::test_lanes! {
$(
fn $func<const LANES: usize>() {
test_helpers::test_binary_elementwise_approx(
&core_simd::simd::Simd::<$scalar, LANES>::$func,
&$scalar::$func,
&|_, _| true,
16,
)
}
)*
@ -53,10 +71,13 @@ macro_rules! impl_tests {
mod $scalar {
use std_float::StdFloat;
unary_test! { $scalar, sqrt, sin, cos, exp, exp2, ln, log2, log10, ceil, floor, round, trunc }
binary_test! { $scalar, log }
unary_test! { $scalar, sqrt, ceil, floor, round, trunc }
ternary_test! { $scalar, mul_add }
// https://github.com/rust-lang/miri/issues/3555
unary_approx_test! { $scalar, sin, cos, exp, exp2, ln, log2, log10 }
binary_approx_test! { $scalar, log }
test_helpers::test_lanes! {
fn fract<const LANES: usize>() {
test_helpers::test_unary_elementwise_flush_subnormals(

View file

@ -6,3 +6,4 @@ publish = false
[dependencies]
proptest = { version = "0.10", default-features = false, features = ["alloc"] }
float-cmp = "0.10"

View file

@ -0,0 +1,110 @@
//! Compare numeric types approximately.
use float_cmp::Ulps;
pub trait ApproxEq {
fn approxeq(&self, other: &Self, _ulps: i64) -> bool;
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result;
}
impl ApproxEq for bool {
fn approxeq(&self, other: &Self, _ulps: i64) -> bool {
self == other
}
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "{:?}", self)
}
}
macro_rules! impl_integer_approxeq {
{ $($type:ty),* } => {
$(
impl ApproxEq for $type {
fn approxeq(&self, other: &Self, _ulps: i64) -> bool {
self == other
}
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "{:?} ({:x})", self, self)
}
}
)*
};
}
impl_integer_approxeq! { u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize }
macro_rules! impl_float_approxeq {
{ $($type:ty),* } => {
$(
impl ApproxEq for $type {
fn approxeq(&self, other: &Self, ulps: i64) -> bool {
if self.is_nan() && other.is_nan() {
true
} else {
(self.ulps(other) as i64).abs() <= ulps
}
}
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "{:?} ({:x})", self, self.to_bits())
}
}
)*
};
}
impl_float_approxeq! { f32, f64 }
impl<T: ApproxEq, const N: usize> ApproxEq for [T; N] {
fn approxeq(&self, other: &Self, ulps: i64) -> bool {
self.iter()
.zip(other.iter())
.fold(true, |value, (left, right)| {
value && left.approxeq(right, ulps)
})
}
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
#[repr(transparent)]
struct Wrapper<'a, T: ApproxEq>(&'a T);
impl<T: ApproxEq> core::fmt::Debug for Wrapper<'_, T> {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
self.0.fmt(f)
}
}
f.debug_list()
.entries(self.iter().map(|x| Wrapper(x)))
.finish()
}
}
#[doc(hidden)]
pub struct ApproxEqWrapper<'a, T>(pub &'a T, pub i64);
impl<T: ApproxEq> PartialEq<T> for ApproxEqWrapper<'_, T> {
fn eq(&self, other: &T) -> bool {
self.0.approxeq(other, self.1)
}
}
impl<T: ApproxEq> core::fmt::Debug for ApproxEqWrapper<'_, T> {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
self.0.fmt(f)
}
}
#[macro_export]
macro_rules! prop_assert_approxeq {
{ $a:expr, $b:expr, $ulps:expr $(,)? } => {
{
use $crate::approxeq::ApproxEqWrapper;
let a = $a;
let b = $b;
proptest::prop_assert_eq!(ApproxEqWrapper(&a, $ulps), b);
}
};
}

View file

@ -12,6 +12,9 @@ pub mod wasm;
#[macro_use]
pub mod biteq;
#[macro_use]
pub mod approxeq;
pub mod subnormals;
use subnormals::FlushSubnormals;
@ -185,6 +188,41 @@ pub fn test_unary_elementwise<Scalar, ScalarResult, Vector, VectorResult, const
});
}
/// Test a unary vector function against a unary scalar function, applied elementwise.
///
/// Floats are checked approximately.
pub fn test_unary_elementwise_approx<
Scalar,
ScalarResult,
Vector,
VectorResult,
const LANES: usize,
>(
fv: &dyn Fn(Vector) -> VectorResult,
fs: &dyn Fn(Scalar) -> ScalarResult,
check: &dyn Fn([Scalar; LANES]) -> bool,
ulps: i64,
) where
Scalar: Copy + core::fmt::Debug + DefaultStrategy,
ScalarResult: Copy + approxeq::ApproxEq + core::fmt::Debug + DefaultStrategy,
Vector: Into<[Scalar; LANES]> + From<[Scalar; LANES]> + Copy,
VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy,
{
test_1(&|x: [Scalar; LANES]| {
proptest::prop_assume!(check(x));
let result_1: [ScalarResult; LANES] = fv(x.into()).into();
let result_2: [ScalarResult; LANES] = x
.iter()
.copied()
.map(fs)
.collect::<Vec<_>>()
.try_into()
.unwrap();
crate::prop_assert_approxeq!(result_1, result_2, ulps);
Ok(())
});
}
/// Test a unary vector function against a unary scalar function, applied elementwise.
///
/// Where subnormals are flushed, use approximate equality.
@ -290,6 +328,44 @@ pub fn test_binary_elementwise<
});
}
/// Test a binary vector function against a binary scalar function, applied elementwise.
pub fn test_binary_elementwise_approx<
Scalar1,
Scalar2,
ScalarResult,
Vector1,
Vector2,
VectorResult,
const LANES: usize,
>(
fv: &dyn Fn(Vector1, Vector2) -> VectorResult,
fs: &dyn Fn(Scalar1, Scalar2) -> ScalarResult,
check: &dyn Fn([Scalar1; LANES], [Scalar2; LANES]) -> bool,
ulps: i64,
) where
Scalar1: Copy + core::fmt::Debug + DefaultStrategy,
Scalar2: Copy + core::fmt::Debug + DefaultStrategy,
ScalarResult: Copy + approxeq::ApproxEq + core::fmt::Debug + DefaultStrategy,
Vector1: Into<[Scalar1; LANES]> + From<[Scalar1; LANES]> + Copy,
Vector2: Into<[Scalar2; LANES]> + From<[Scalar2; LANES]> + Copy,
VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy,
{
test_2(&|x: [Scalar1; LANES], y: [Scalar2; LANES]| {
proptest::prop_assume!(check(x, y));
let result_1: [ScalarResult; LANES] = fv(x.into(), y.into()).into();
let result_2: [ScalarResult; LANES] = x
.iter()
.copied()
.zip(y.iter().copied())
.map(|(x, y)| fs(x, y))
.collect::<Vec<_>>()
.try_into()
.unwrap();
crate::prop_assert_approxeq!(result_1, result_2, ulps);
Ok(())
});
}
/// Test a binary vector function against a binary scalar function, applied elementwise.
///
/// Where subnormals are flushed, use approximate equality.