From 720ba18931628f28fb690a2936e53a32233f88d5 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Wed, 12 Feb 2025 09:55:04 +0000 Subject: [PATCH] fma refactor 3/3: combine `fma` public API with its implementation Similar to other recent changes, just put public API in the same file as its generic implementation. To keep things slightly cleaner, split the default implementation from the `_wide` implementation. Also introduces a stub `fmaf16`. --- .../libm/etc/function-definitions.json | 9 +- .../compiler-builtins/libm/src/math/fma.rs | 140 +++++++----------- .../libm/src/math/fma_wide.rs | 97 ++++++++++++ .../compiler-builtins/libm/src/math/fmaf.rs | 21 --- .../libm/src/math/fmaf128.rs | 7 - .../libm/src/math/generic/mod.rs | 2 - .../compiler-builtins/libm/src/math/mod.rs | 10 +- 7 files changed, 161 insertions(+), 125 deletions(-) create mode 100644 library/compiler-builtins/libm/src/math/fma_wide.rs delete mode 100644 library/compiler-builtins/libm/src/math/fmaf.rs delete mode 100644 library/compiler-builtins/libm/src/math/fmaf128.rs diff --git a/library/compiler-builtins/libm/etc/function-definitions.json b/library/compiler-builtins/libm/etc/function-definitions.json index 63d9927ad6f8..a966852b1128 100644 --- a/library/compiler-builtins/libm/etc/function-definitions.json +++ b/library/compiler-builtins/libm/etc/function-definitions.json @@ -343,22 +343,19 @@ }, "fma": { "sources": [ - "src/math/fma.rs", - "src/math/generic/fma.rs" + "src/math/fma.rs" ], "type": "f64" }, "fmaf": { "sources": [ - "src/math/fmaf.rs", - "src/math/generic/fma.rs" + "src/math/fma_wide.rs" ], "type": "f32" }, "fmaf128": { "sources": [ - "src/math/fmaf128.rs", - "src/math/generic/fma.rs" + "src/math/fma.rs" ], "type": "f128" }, diff --git a/library/compiler-builtins/libm/src/math/fma.rs b/library/compiler-builtins/libm/src/math/fma.rs index cb1061cc38b1..a54984c936b7 100644 --- a/library/compiler-builtins/libm/src/math/fma.rs +++ b/library/compiler-builtins/libm/src/math/fma.rs @@ -1,23 +1,28 @@ /* SPDX-License-Identifier: MIT */ -/* origin: musl src/math/{fma,fmaf}.c. Ported to generic Rust algorithm in 2025, TG. */ +/* origin: musl src/math/fma.c. Ported to generic Rust algorithm in 2025, TG. */ use super::super::support::{DInt, FpResult, HInt, IntTy, Round, Status}; -use super::super::{CastFrom, CastInto, DFloat, Float, HFloat, Int, MinInt}; +use super::{CastFrom, CastInto, Float, Int, MinInt}; -/// Fused multiply-add that works when there is not a larger float size available. Currently this -/// is still specialized only for `f64`. Computes `(x * y) + z`. +/// Fused multiply add (f64) +/// +/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision). #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] -pub fn fma(x: F, y: F, z: F) -> F -where - F: Float, - F: CastFrom, - F: CastFrom, - F::Int: HInt, - u32: CastInto, -{ +pub fn fma(x: f64, y: f64, z: f64) -> f64 { fma_round(x, y, z, Round::Nearest).val } +/// Fused multiply add (f128) +/// +/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision). +#[cfg(f128_enabled)] +#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] +pub fn fmaf128(x: f128, y: f128, z: f128) -> f128 { + fma_round(x, y, z, Round::Nearest).val +} + +/// Fused multiply-add that works when there is not a larger float size available. Computes +/// `(x * y) + z`. pub fn fma_round(x: F, y: F, z: F, _round: Round) -> FpResult where F: Float, @@ -222,79 +227,7 @@ where } // Use our exponent to scale the final value. - FpResult::new(super::scalbn(r, e), status) -} - -/// Fma implementation when a hardware-backed larger float type is available. For `f32` and `f64`, -/// `f64` has enough precision to represent the `f32` in its entirety, except for double rounding. -pub fn fma_wide(x: F, y: F, z: F) -> F -where - F: Float + HFloat, - B: Float + DFloat, - B::Int: CastInto, - i32: CastFrom, -{ - fma_wide_round(x, y, z, Round::Nearest).val -} - -pub fn fma_wide_round(x: F, y: F, z: F, round: Round) -> FpResult -where - F: Float + HFloat, - B: Float + DFloat, - B::Int: CastInto, - i32: CastFrom, -{ - let one = IntTy::::ONE; - - let xy: B = x.widen() * y.widen(); - let mut result: B = xy + z.widen(); - let mut ui: B::Int = result.to_bits(); - let re = result.ex(); - let zb: B = z.widen(); - - let prec_diff = B::SIG_BITS - F::SIG_BITS; - let excess_prec = ui & ((one << prec_diff) - one); - let halfway = one << (prec_diff - 1); - - // Common case: the larger precision is fine if... - // This is not a halfway case - if excess_prec != halfway - // Or the result is NaN - || re == B::EXP_SAT - // Or the result is exact - || (result - xy == zb && result - zb == xy) - // Or the mode is something other than round to nearest - || round != Round::Nearest - { - let min_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN_SUBNORM) as u32; - let max_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN) as u32; - - let mut status = Status::OK; - - if (min_inexact_exp..max_inexact_exp).contains(&re) && status.inexact() { - // This branch is never hit; requires previous operations to set a status - status.set_inexact(false); - - result = xy + z.widen(); - if status.inexact() { - status.set_underflow(true); - } else { - status.set_inexact(true); - } - } - - return FpResult { val: result.narrow(), status }; - } - - let neg = ui >> (B::BITS - 1) != IntTy::::ZERO; - let err = if neg == (zb > xy) { xy - result + zb } else { zb - result + xy }; - if neg == (err < B::ZERO) { - ui += one; - } else { - ui -= one; - } - - FpResult::ok(B::from_bits(ui).narrow()) + FpResult::new(super::generic::scalbn(r, e), status) } /// Representation of `F` that has handled subnormals. @@ -363,6 +296,7 @@ impl Norm { mod tests { use super::*; + /// Test the generic `fma_round` algorithm for a given float. fn spec_test() where F: Float, @@ -375,6 +309,8 @@ mod tests { let y = F::from_bits(F::Int::ONE); let z = F::ZERO; + let fma = |x, y, z| fma_round(x, y, z, Round::Nearest).val; + // 754-2020 says "When the exact result of (a × b) + c is non-zero yet the result of // fusedMultiplyAdd is zero because of rounding, the zero result takes the sign of the // exact result" @@ -384,6 +320,11 @@ mod tests { assert_biteq!(fma(-x, -y, z), F::ZERO); } + #[test] + fn spec_test_f32() { + spec_test::(); + } + #[test] fn spec_test_f64() { spec_test::(); @@ -417,4 +358,33 @@ mod tests { fn spec_test_f128() { spec_test::(); } + + #[test] + fn fma_segfault() { + // These two inputs cause fma to segfault on release due to overflow: + assert_eq!( + fma( + -0.0000000000000002220446049250313, + -0.0000000000000002220446049250313, + -0.0000000000000002220446049250313 + ), + -0.00000000000000022204460492503126, + ); + + let result = fma(-0.992, -0.992, -0.992); + //force rounding to storage format on x87 to prevent superious errors. + #[cfg(all(target_arch = "x86", not(target_feature = "sse2")))] + let result = force_eval!(result); + assert_eq!(result, -0.007936000000000007,); + } + + #[test] + fn fma_sbb() { + assert_eq!(fma(-(1.0 - f64::EPSILON), f64::MIN, f64::MIN), -3991680619069439e277); + } + + #[test] + fn fma_underflow() { + assert_eq!(fma(1.1102230246251565e-16, -9.812526705433188e-305, 1.0894e-320), 0.0,); + } } diff --git a/library/compiler-builtins/libm/src/math/fma_wide.rs b/library/compiler-builtins/libm/src/math/fma_wide.rs new file mode 100644 index 000000000000..a8c1a548879d --- /dev/null +++ b/library/compiler-builtins/libm/src/math/fma_wide.rs @@ -0,0 +1,97 @@ +/* SPDX-License-Identifier: MIT */ +/* origin: musl src/math/fmaf.c. Ported to generic Rust algorithm in 2025, TG. */ + +use super::super::support::{FpResult, IntTy, Round, Status}; +use super::{CastFrom, CastInto, DFloat, Float, HFloat, MinInt}; + +// Placeholder so we can have `fmaf16` in the `Float` trait. +#[allow(unused)] +#[cfg(f16_enabled)] +#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] +pub(crate) fn fmaf16(_x: f16, _y: f16, _z: f16) -> f16 { + unimplemented!() +} + +/// Floating multiply add (f32) +/// +/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision). +#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] +pub fn fmaf(x: f32, y: f32, z: f32) -> f32 { + fma_wide_round(x, y, z, Round::Nearest).val +} + +/// Fma implementation when a hardware-backed larger float type is available. For `f32` and `f64`, +/// `f64` has enough precision to represent the `f32` in its entirety, except for double rounding. +pub fn fma_wide_round(x: F, y: F, z: F, round: Round) -> FpResult +where + F: Float + HFloat, + B: Float + DFloat, + B::Int: CastInto, + i32: CastFrom, +{ + let one = IntTy::::ONE; + + let xy: B = x.widen() * y.widen(); + let mut result: B = xy + z.widen(); + let mut ui: B::Int = result.to_bits(); + let re = result.ex(); + let zb: B = z.widen(); + + let prec_diff = B::SIG_BITS - F::SIG_BITS; + let excess_prec = ui & ((one << prec_diff) - one); + let halfway = one << (prec_diff - 1); + + // Common case: the larger precision is fine if... + // This is not a halfway case + if excess_prec != halfway + // Or the result is NaN + || re == B::EXP_SAT + // Or the result is exact + || (result - xy == zb && result - zb == xy) + // Or the mode is something other than round to nearest + || round != Round::Nearest + { + let min_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN_SUBNORM) as u32; + let max_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN) as u32; + + let mut status = Status::OK; + + if (min_inexact_exp..max_inexact_exp).contains(&re) && status.inexact() { + // This branch is never hit; requires previous operations to set a status + status.set_inexact(false); + + result = xy + z.widen(); + if status.inexact() { + status.set_underflow(true); + } else { + status.set_inexact(true); + } + } + + return FpResult { val: result.narrow(), status }; + } + + let neg = ui >> (B::BITS - 1) != IntTy::::ZERO; + let err = if neg == (zb > xy) { xy - result + zb } else { zb - result + xy }; + if neg == (err < B::ZERO) { + ui += one; + } else { + ui -= one; + } + + FpResult::ok(B::from_bits(ui).narrow()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn issue_263() { + let a = f32::from_bits(1266679807); + let b = f32::from_bits(1300234242); + let c = f32::from_bits(1115553792); + let expected = f32::from_bits(1501560833); + assert_eq!(fmaf(a, b, c), expected); + } +} diff --git a/library/compiler-builtins/libm/src/math/fmaf.rs b/library/compiler-builtins/libm/src/math/fmaf.rs deleted file mode 100644 index 40d7f40d6173..000000000000 --- a/library/compiler-builtins/libm/src/math/fmaf.rs +++ /dev/null @@ -1,21 +0,0 @@ -/// Floating multiply add (f32) -/// -/// Computes `(x*y)+z`, rounded as one ternary operation: -/// Computes the value (as if) to infinite precision and rounds once to the result format, -/// according to the rounding mode characterized by the value of FLT_ROUNDS. -#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] -pub fn fmaf(x: f32, y: f32, z: f32) -> f32 { - super::generic::fma_wide(x, y, z) -} - -#[cfg(test)] -mod tests { - #[test] - fn issue_263() { - let a = f32::from_bits(1266679807); - let b = f32::from_bits(1300234242); - let c = f32::from_bits(1115553792); - let expected = f32::from_bits(1501560833); - assert_eq!(super::fmaf(a, b, c), expected); - } -} diff --git a/library/compiler-builtins/libm/src/math/fmaf128.rs b/library/compiler-builtins/libm/src/math/fmaf128.rs deleted file mode 100644 index 50f7360deb45..000000000000 --- a/library/compiler-builtins/libm/src/math/fmaf128.rs +++ /dev/null @@ -1,7 +0,0 @@ -/// Fused multiply add (f128) -/// -/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision). -#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] -pub fn fmaf128(x: f128, y: f128, z: f128) -> f128 { - return super::generic::fma(x, y, z); -} diff --git a/library/compiler-builtins/libm/src/math/generic/mod.rs b/library/compiler-builtins/libm/src/math/generic/mod.rs index f224eba731c3..9be185f809f1 100644 --- a/library/compiler-builtins/libm/src/math/generic/mod.rs +++ b/library/compiler-builtins/libm/src/math/generic/mod.rs @@ -3,7 +3,6 @@ mod copysign; mod fabs; mod fdim; mod floor; -mod fma; mod fmax; mod fmaximum; mod fmaximum_num; @@ -22,7 +21,6 @@ pub use copysign::copysign; pub use fabs::fabs; pub use fdim::fdim; pub use floor::floor; -pub use fma::{fma, fma_wide}; pub use fmax::fmax; pub use fmaximum::fmaximum; pub use fmaximum_num::fmaximum_num; diff --git a/library/compiler-builtins/libm/src/math/mod.rs b/library/compiler-builtins/libm/src/math/mod.rs index e58d79adc419..5fc8fa0b3cd0 100644 --- a/library/compiler-builtins/libm/src/math/mod.rs +++ b/library/compiler-builtins/libm/src/math/mod.rs @@ -164,7 +164,7 @@ mod fdimf; mod floor; mod floorf; mod fma; -mod fmaf; +mod fma_wide; mod fmin_fmax; mod fminimum_fmaximum; mod fminimum_fmaximum_num; @@ -271,7 +271,7 @@ pub use self::fdimf::fdimf; pub use self::floor::floor; pub use self::floorf::floorf; pub use self::fma::fma; -pub use self::fmaf::fmaf; +pub use self::fma_wide::fmaf; pub use self::fmin_fmax::{fmax, fmaxf, fmin, fminf}; pub use self::fminimum_fmaximum::{fmaximum, fmaximumf, fminimum, fminimumf}; pub use self::fminimum_fmaximum_num::{fmaximum_num, fmaximum_numf, fminimum_num, fminimum_numf}; @@ -370,6 +370,9 @@ cfg_if! { pub use self::sqrtf16::sqrtf16; pub use self::truncf16::truncf16; // verify-sorted-end + + #[allow(unused_imports)] + pub(crate) use self::fma_wide::fmaf16; } } @@ -381,7 +384,6 @@ cfg_if! { mod fabsf128; mod fdimf128; mod floorf128; - mod fmaf128; mod fmodf128; mod ldexpf128; mod roundf128; @@ -396,7 +398,7 @@ cfg_if! { pub use self::fabsf128::fabsf128; pub use self::fdimf128::fdimf128; pub use self::floorf128::floorf128; - pub use self::fmaf128::fmaf128; + pub use self::fma::fmaf128; pub use self::fmin_fmax::{fmaxf128, fminf128}; pub use self::fminimum_fmaximum::{fmaximumf128, fminimumf128}; pub use self::fminimum_fmaximum_num::{fmaximum_numf128, fminimum_numf128};