diff --git a/library/compiler-builtins/libm/crates/libm-test/Cargo.toml b/library/compiler-builtins/libm/crates/libm-test/Cargo.toml index 1e76fb70707e..b6e2ced5877d 100644 --- a/library/compiler-builtins/libm/crates/libm-test/Cargo.toml +++ b/library/compiler-builtins/libm/crates/libm-test/Cargo.toml @@ -12,6 +12,7 @@ default = [] test-musl-serialized = ["rand"] [dependencies] +anyhow = "1.0.90" libm = { path = "../.." } libm-macros = { path = "../libm-macros" } diff --git a/library/compiler-builtins/libm/crates/libm-test/src/lib.rs b/library/compiler-builtins/libm/crates/libm-test/src/lib.rs index 5444709d87cf..41873099f637 100644 --- a/library/compiler-builtins/libm/crates/libm-test/src/lib.rs +++ b/library/compiler-builtins/libm/crates/libm-test/src/lib.rs @@ -1,6 +1,12 @@ mod num_traits; +mod test_traits; pub use num_traits::{Float, Hex, Int}; +pub use test_traits::{CheckBasis, CheckCtx, CheckOutput, GenerateInput, TupleCall}; + +/// Result type for tests is usually from `anyhow`. Most times there is no success value to +/// propagate. +pub type TestResult = Result; // List of all files present in libm's source include!(concat!(env!("OUT_DIR"), "/all_files.rs")); diff --git a/library/compiler-builtins/libm/crates/libm-test/src/num_traits.rs b/library/compiler-builtins/libm/crates/libm-test/src/num_traits.rs index 835d6e46d437..d7d806bab9ea 100644 --- a/library/compiler-builtins/libm/crates/libm-test/src/num_traits.rs +++ b/library/compiler-builtins/libm/crates/libm-test/src/num_traits.rs @@ -1,5 +1,7 @@ use std::fmt; +use crate::TestResult; + /// Common types and methods for floating point numbers. pub trait Float: Copy + fmt::Display + fmt::Debug + PartialEq { type Int: Int; @@ -134,6 +136,29 @@ macro_rules! impl_int { format!("{self:#0width$x}", width = ((Self::BITS / 4) + 2) as usize) } } + + impl $crate::CheckOutput for $ty { + fn validate<'a>( + self, + expected: Self, + input: Input, + _ctx: &$crate::CheckCtx, + ) -> TestResult { + anyhow::ensure!( + self == expected, + "\ + \n input: {input:?} {ibits}\ + \n expected: {expected:<22?} {expbits}\ + \n actual: {self:<22?} {actbits}\ + ", + actbits = self.hex(), + expbits = expected.hex(), + ibits = input.hex(), + ); + + Ok(()) + } + } } } diff --git a/library/compiler-builtins/libm/crates/libm-test/src/test_traits.rs b/library/compiler-builtins/libm/crates/libm-test/src/test_traits.rs new file mode 100644 index 000000000000..c6f1f84ae05c --- /dev/null +++ b/library/compiler-builtins/libm/crates/libm-test/src/test_traits.rs @@ -0,0 +1,217 @@ +//! Traits related to testing. +//! +//! There are three main traits in this module: +//! +//! - `GenerateInput`: implemented on any types that create test cases. +//! - `TupleCall`: implemented on tuples to allow calling them as function arguments. +//! - `CheckOutput`: implemented on anything that is an output type for validation against an +//! expected value. + +use std::fmt; + +use anyhow::{Context, bail, ensure}; + +use crate::{Float, Hex, Int, TestResult}; + +/// Implement this on types that can generate a sequence of tuples for test input. +pub trait GenerateInput { + fn get_cases(&self) -> impl Iterator; +} + +/// Trait for calling a function with a tuple as arguments. +/// +/// Implemented on the tuple with the function signature as the generic (so we can use the same +/// tuple for multiple signatures). +pub trait TupleCall: fmt::Debug { + type Output; + fn call(self, f: Func) -> Self::Output; +} + +/// Context passed to [`CheckOutput`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct CheckCtx { + /// Allowed ULP deviation + pub ulp: u32, + /// Function name. + pub fname: &'static str, + /// Source of truth for tests. + pub basis: CheckBasis, +} + +/// Possible items to test against +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum CheckBasis {} + +/// A trait to implement on any output type so we can verify it in a generic way. +pub trait CheckOutput: Sized { + /// Validate `self` (actual) and `expected` are the same. + /// + /// `input` is only used here for error messages. + fn validate<'a>(self, expected: Self, input: Input, ctx: &CheckCtx) -> TestResult; +} + +impl TupleCall R> for (T1,) +where + T1: fmt::Debug, +{ + type Output = R; + + fn call(self, f: fn(T1) -> R) -> Self::Output { + f(self.0) + } +} + +impl TupleCall R> for (T1, T2) +where + T1: fmt::Debug, + T2: fmt::Debug, +{ + type Output = R; + + fn call(self, f: fn(T1, T2) -> R) -> Self::Output { + f(self.0, self.1) + } +} + +impl TupleCall R> for (T1,) +where + T1: fmt::Debug, + T2: fmt::Debug + Default, +{ + type Output = (R, T2); + + fn call(self, f: fn(T1, &mut T2) -> R) -> Self::Output { + let mut t2 = T2::default(); + (f(self.0, &mut t2), t2) + } +} + +impl TupleCall R> for (T1, T2, T3) +where + T1: fmt::Debug, + T2: fmt::Debug, + T3: fmt::Debug, +{ + type Output = R; + + fn call(self, f: fn(T1, T2, T3) -> R) -> Self::Output { + f(self.0, self.1, self.2) + } +} + +impl TupleCall R> for (T1, T2) +where + T1: fmt::Debug, + T2: fmt::Debug, + T3: fmt::Debug + Default, +{ + type Output = (R, T3); + + fn call(self, f: fn(T1, T2, &mut T3) -> R) -> Self::Output { + let mut t3 = T3::default(); + (f(self.0, self.1, &mut t3), t3) + } +} + +impl TupleCall for (T1,) +where + T1: fmt::Debug, + T2: fmt::Debug + Default, + T3: fmt::Debug + Default, +{ + type Output = (T2, T3); + + fn call(self, f: fn(T1, &mut T2, &mut T3)) -> Self::Output { + let mut t2 = T2::default(); + let mut t3 = T3::default(); + f(self.0, &mut t2, &mut t3); + (t2, t3) + } +} + +// Implement for floats +impl CheckOutput for F +where + F: Float + Hex, + Input: Hex + fmt::Debug, + u32: TryFrom, +{ + fn validate<'a>(self, expected: Self, input: Input, ctx: &CheckCtx) -> TestResult { + // Create a wrapper function so we only need to `.with_context` once. + let inner = || -> TestResult { + // Check when both are NaNs + if self.is_nan() && expected.is_nan() { + ensure!(self.to_bits() == expected.to_bits(), "NaNs have different bitpatterns"); + // Nothing else to check + return Ok(()); + } else if self.is_nan() || expected.is_nan() { + // Check when only one is a NaN + bail!("real value != NaN") + } + + // Make sure that the signs are the same before checing ULP to avoid wraparound + let act_sig = self.signum(); + let exp_sig = expected.signum(); + ensure!(act_sig == exp_sig, "mismatched signs {act_sig} {exp_sig}"); + + if self.is_infinite() ^ expected.is_infinite() { + bail!("mismatched infinities"); + } + + let act_bits = self.to_bits().signed(); + let exp_bits = expected.to_bits().signed(); + + let ulp_diff = act_bits.checked_sub(exp_bits).unwrap().abs(); + + let ulp_u32 = u32::try_from(ulp_diff) + .map_err(|e| anyhow::anyhow!("{e:?}: ulp of {ulp_diff} exceeds u32::MAX"))?; + + let allowed_ulp = ctx.ulp; + ensure!(ulp_u32 <= allowed_ulp, "ulp {ulp_diff} > {allowed_ulp}",); + + Ok(()) + }; + + inner().with_context(|| { + format!( + "\ + \n input: {input:?} {ibits}\ + \n expected: {expected:<22?} {expbits}\ + \n actual: {self:<22?} {actbits}\ + ", + actbits = self.hex(), + expbits = expected.hex(), + ibits = input.hex(), + ) + }) + } +} + +/// Implement `CheckOutput` for combinations of types. +macro_rules! impl_tuples { + ($(($a:ty, $b:ty);)*) => { + $( + impl CheckOutput for ($a, $b) { + fn validate<'a>( + self, + expected: Self, + input: Input, + ctx: &CheckCtx, + ) -> TestResult { + self.0.validate(expected.0, input, ctx,) + .and_then(|()| self.1.validate(expected.1, input, ctx)) + .with_context(|| format!( + "full input {input:?} full actual {self:?} expected {expected:?}" + )) + } + } + )* + }; +} + +impl_tuples!( + (f32, i32); + (f64, i32); + (f32, f32); + (f64, f64); +);