Check some float ops approximately
This commit is contained in:
parent
323484c827
commit
c43c8d25a8
5 changed files with 222 additions and 4 deletions
10
Cargo.lock
generated
10
Cargo.lock
generated
|
|
@ -52,6 +52,15 @@ dependencies = [
|
|||
"wasm-bindgen-test",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "float-cmp"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b09cf3155332e944990140d967ff5eceb70df778b34f77d8075db46e4704e6d8"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.77"
|
||||
|
|
@ -225,6 +234,7 @@ dependencies = [
|
|||
name = "test_helpers"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"float-cmp",
|
||||
"proptest",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -16,15 +16,33 @@ macro_rules! unary_test {
|
|||
}
|
||||
}
|
||||
|
||||
macro_rules! binary_test {
|
||||
macro_rules! unary_approx_test {
|
||||
{ $scalar:tt, $($func:tt),+ } => {
|
||||
test_helpers::test_lanes! {
|
||||
$(
|
||||
fn $func<const LANES: usize>() {
|
||||
test_helpers::test_binary_elementwise(
|
||||
test_helpers::test_unary_elementwise_approx(
|
||||
&core_simd::simd::Simd::<$scalar, LANES>::$func,
|
||||
&$scalar::$func,
|
||||
&|_| true,
|
||||
8,
|
||||
)
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! binary_approx_test {
|
||||
{ $scalar:tt, $($func:tt),+ } => {
|
||||
test_helpers::test_lanes! {
|
||||
$(
|
||||
fn $func<const LANES: usize>() {
|
||||
test_helpers::test_binary_elementwise_approx(
|
||||
&core_simd::simd::Simd::<$scalar, LANES>::$func,
|
||||
&$scalar::$func,
|
||||
&|_, _| true,
|
||||
16,
|
||||
)
|
||||
}
|
||||
)*
|
||||
|
|
@ -53,10 +71,13 @@ macro_rules! impl_tests {
|
|||
mod $scalar {
|
||||
use std_float::StdFloat;
|
||||
|
||||
unary_test! { $scalar, sqrt, sin, cos, exp, exp2, ln, log2, log10, ceil, floor, round, trunc }
|
||||
binary_test! { $scalar, log }
|
||||
unary_test! { $scalar, sqrt, ceil, floor, round, trunc }
|
||||
ternary_test! { $scalar, mul_add }
|
||||
|
||||
// https://github.com/rust-lang/miri/issues/3555
|
||||
unary_approx_test! { $scalar, sin, cos, exp, exp2, ln, log2, log10 }
|
||||
binary_approx_test! { $scalar, log }
|
||||
|
||||
test_helpers::test_lanes! {
|
||||
fn fract<const LANES: usize>() {
|
||||
test_helpers::test_unary_elementwise_flush_subnormals(
|
||||
|
|
|
|||
|
|
@ -6,3 +6,4 @@ publish = false
|
|||
|
||||
[dependencies]
|
||||
proptest = { version = "0.10", default-features = false, features = ["alloc"] }
|
||||
float-cmp = "0.10"
|
||||
|
|
|
|||
110
crates/test_helpers/src/approxeq.rs
Normal file
110
crates/test_helpers/src/approxeq.rs
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
//! Compare numeric types approximately.
|
||||
|
||||
use float_cmp::Ulps;
|
||||
|
||||
pub trait ApproxEq {
|
||||
fn approxeq(&self, other: &Self, _ulps: i64) -> bool;
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result;
|
||||
}
|
||||
|
||||
impl ApproxEq for bool {
|
||||
fn approxeq(&self, other: &Self, _ulps: i64) -> bool {
|
||||
self == other
|
||||
}
|
||||
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_integer_approxeq {
|
||||
{ $($type:ty),* } => {
|
||||
$(
|
||||
impl ApproxEq for $type {
|
||||
fn approxeq(&self, other: &Self, _ulps: i64) -> bool {
|
||||
self == other
|
||||
}
|
||||
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
write!(f, "{:?} ({:x})", self, self)
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl_integer_approxeq! { u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize }
|
||||
|
||||
macro_rules! impl_float_approxeq {
|
||||
{ $($type:ty),* } => {
|
||||
$(
|
||||
impl ApproxEq for $type {
|
||||
fn approxeq(&self, other: &Self, ulps: i64) -> bool {
|
||||
if self.is_nan() && other.is_nan() {
|
||||
true
|
||||
} else {
|
||||
(self.ulps(other) as i64).abs() <= ulps
|
||||
}
|
||||
}
|
||||
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
write!(f, "{:?} ({:x})", self, self.to_bits())
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl_float_approxeq! { f32, f64 }
|
||||
|
||||
impl<T: ApproxEq, const N: usize> ApproxEq for [T; N] {
|
||||
fn approxeq(&self, other: &Self, ulps: i64) -> bool {
|
||||
self.iter()
|
||||
.zip(other.iter())
|
||||
.fold(true, |value, (left, right)| {
|
||||
value && left.approxeq(right, ulps)
|
||||
})
|
||||
}
|
||||
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
#[repr(transparent)]
|
||||
struct Wrapper<'a, T: ApproxEq>(&'a T);
|
||||
|
||||
impl<T: ApproxEq> core::fmt::Debug for Wrapper<'_, T> {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
f.debug_list()
|
||||
.entries(self.iter().map(|x| Wrapper(x)))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub struct ApproxEqWrapper<'a, T>(pub &'a T, pub i64);
|
||||
|
||||
impl<T: ApproxEq> PartialEq<T> for ApproxEqWrapper<'_, T> {
|
||||
fn eq(&self, other: &T) -> bool {
|
||||
self.0.approxeq(other, self.1)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ApproxEq> core::fmt::Debug for ApproxEqWrapper<'_, T> {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! prop_assert_approxeq {
|
||||
{ $a:expr, $b:expr, $ulps:expr $(,)? } => {
|
||||
{
|
||||
use $crate::approxeq::ApproxEqWrapper;
|
||||
let a = $a;
|
||||
let b = $b;
|
||||
proptest::prop_assert_eq!(ApproxEqWrapper(&a, $ulps), b);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
@ -12,6 +12,9 @@ pub mod wasm;
|
|||
#[macro_use]
|
||||
pub mod biteq;
|
||||
|
||||
#[macro_use]
|
||||
pub mod approxeq;
|
||||
|
||||
pub mod subnormals;
|
||||
use subnormals::FlushSubnormals;
|
||||
|
||||
|
|
@ -185,6 +188,41 @@ pub fn test_unary_elementwise<Scalar, ScalarResult, Vector, VectorResult, const
|
|||
});
|
||||
}
|
||||
|
||||
/// Test a unary vector function against a unary scalar function, applied elementwise.
|
||||
///
|
||||
/// Floats are checked approximately.
|
||||
pub fn test_unary_elementwise_approx<
|
||||
Scalar,
|
||||
ScalarResult,
|
||||
Vector,
|
||||
VectorResult,
|
||||
const LANES: usize,
|
||||
>(
|
||||
fv: &dyn Fn(Vector) -> VectorResult,
|
||||
fs: &dyn Fn(Scalar) -> ScalarResult,
|
||||
check: &dyn Fn([Scalar; LANES]) -> bool,
|
||||
ulps: i64,
|
||||
) where
|
||||
Scalar: Copy + core::fmt::Debug + DefaultStrategy,
|
||||
ScalarResult: Copy + approxeq::ApproxEq + core::fmt::Debug + DefaultStrategy,
|
||||
Vector: Into<[Scalar; LANES]> + From<[Scalar; LANES]> + Copy,
|
||||
VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy,
|
||||
{
|
||||
test_1(&|x: [Scalar; LANES]| {
|
||||
proptest::prop_assume!(check(x));
|
||||
let result_1: [ScalarResult; LANES] = fv(x.into()).into();
|
||||
let result_2: [ScalarResult; LANES] = x
|
||||
.iter()
|
||||
.copied()
|
||||
.map(fs)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
crate::prop_assert_approxeq!(result_1, result_2, ulps);
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
/// Test a unary vector function against a unary scalar function, applied elementwise.
|
||||
///
|
||||
/// Where subnormals are flushed, use approximate equality.
|
||||
|
|
@ -290,6 +328,44 @@ pub fn test_binary_elementwise<
|
|||
});
|
||||
}
|
||||
|
||||
/// Test a binary vector function against a binary scalar function, applied elementwise.
|
||||
pub fn test_binary_elementwise_approx<
|
||||
Scalar1,
|
||||
Scalar2,
|
||||
ScalarResult,
|
||||
Vector1,
|
||||
Vector2,
|
||||
VectorResult,
|
||||
const LANES: usize,
|
||||
>(
|
||||
fv: &dyn Fn(Vector1, Vector2) -> VectorResult,
|
||||
fs: &dyn Fn(Scalar1, Scalar2) -> ScalarResult,
|
||||
check: &dyn Fn([Scalar1; LANES], [Scalar2; LANES]) -> bool,
|
||||
ulps: i64,
|
||||
) where
|
||||
Scalar1: Copy + core::fmt::Debug + DefaultStrategy,
|
||||
Scalar2: Copy + core::fmt::Debug + DefaultStrategy,
|
||||
ScalarResult: Copy + approxeq::ApproxEq + core::fmt::Debug + DefaultStrategy,
|
||||
Vector1: Into<[Scalar1; LANES]> + From<[Scalar1; LANES]> + Copy,
|
||||
Vector2: Into<[Scalar2; LANES]> + From<[Scalar2; LANES]> + Copy,
|
||||
VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy,
|
||||
{
|
||||
test_2(&|x: [Scalar1; LANES], y: [Scalar2; LANES]| {
|
||||
proptest::prop_assume!(check(x, y));
|
||||
let result_1: [ScalarResult; LANES] = fv(x.into(), y.into()).into();
|
||||
let result_2: [ScalarResult; LANES] = x
|
||||
.iter()
|
||||
.copied()
|
||||
.zip(y.iter().copied())
|
||||
.map(|(x, y)| fs(x, y))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
crate::prop_assert_approxeq!(result_1, result_2, ulps);
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
/// Test a binary vector function against a binary scalar function, applied elementwise.
|
||||
///
|
||||
/// Where subnormals are flushed, use approximate equality.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue