fma refactor 3/3: combine fma public API with its implementation

Similar to other recent changes, just put public API in the same file as
its generic implementation. To keep things slightly cleaner, split the
default implementation from the `_wide` implementation.

Also introduces a stub `fmaf16`.
This commit is contained in:
Trevor Gross 2025-02-12 09:55:04 +00:00
parent bcbdb0b74f
commit 720ba18931
7 changed files with 161 additions and 125 deletions

View file

@ -343,22 +343,19 @@
},
"fma": {
"sources": [
"src/math/fma.rs",
"src/math/generic/fma.rs"
"src/math/fma.rs"
],
"type": "f64"
},
"fmaf": {
"sources": [
"src/math/fmaf.rs",
"src/math/generic/fma.rs"
"src/math/fma_wide.rs"
],
"type": "f32"
},
"fmaf128": {
"sources": [
"src/math/fmaf128.rs",
"src/math/generic/fma.rs"
"src/math/fma.rs"
],
"type": "f128"
},

View file

@ -1,23 +1,28 @@
/* SPDX-License-Identifier: MIT */
/* origin: musl src/math/{fma,fmaf}.c. Ported to generic Rust algorithm in 2025, TG. */
/* origin: musl src/math/fma.c. Ported to generic Rust algorithm in 2025, TG. */
use super::super::support::{DInt, FpResult, HInt, IntTy, Round, Status};
use super::super::{CastFrom, CastInto, DFloat, Float, HFloat, Int, MinInt};
use super::{CastFrom, CastInto, Float, Int, MinInt};
/// Fused multiply-add that works when there is not a larger float size available. Currently this
/// is still specialized only for `f64`. Computes `(x * y) + z`.
/// Fused multiply add (f64)
///
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn fma<F>(x: F, y: F, z: F) -> F
where
F: Float,
F: CastFrom<F::SignedInt>,
F: CastFrom<i8>,
F::Int: HInt,
u32: CastInto<F::Int>,
{
pub fn fma(x: f64, y: f64, z: f64) -> f64 {
fma_round(x, y, z, Round::Nearest).val
}
/// Fused multiply-add for `f128`.
///
/// Evaluates `(x * y) + z` as one ternary operation, i.e. as if the product were calculated
/// with infinite precision before the addition. The generic `fma_round` routine is invoked with
/// `Round::Nearest`, and only the numeric value of its result is returned.
#[cfg(f128_enabled)]
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn fmaf128(x: f128, y: f128, z: f128) -> f128 {
    let rounded = fma_round(x, y, z, Round::Nearest);
    rounded.val
}
/// Fused multiply-add that works when there is not a larger float size available. Computes
/// `(x * y) + z`.
pub fn fma_round<F>(x: F, y: F, z: F, _round: Round) -> FpResult<F>
where
F: Float,
@ -222,79 +227,7 @@ where
}
// Use our exponent to scale the final value.
FpResult::new(super::scalbn(r, e), status)
}
/// Fused multiply-add for a float type `F` that has a hardware-backed wider companion type `B`
/// (for example `f32` computed through `f64`).
///
/// Forwards to `fma_wide_round` with `Round::Nearest` and returns only the value of the result.
pub fn fma_wide<F, B>(x: F, y: F, z: F) -> F
where
    F: Float + HFloat<D = B>,
    B: Float + DFloat<H = F>,
    B::Int: CastInto<i32>,
    i32: CastFrom<i32>,
{
    let rounded = fma_wide_round(x, y, z, Round::Nearest);
    rounded.val
}
pub fn fma_wide_round<F, B>(x: F, y: F, z: F, round: Round) -> FpResult<F>
where
F: Float + HFloat<D = B>,
B: Float + DFloat<H = F>,
B::Int: CastInto<i32>,
i32: CastFrom<i32>,
{
let one = IntTy::<B>::ONE;
let xy: B = x.widen() * y.widen();
let mut result: B = xy + z.widen();
let mut ui: B::Int = result.to_bits();
let re = result.ex();
let zb: B = z.widen();
let prec_diff = B::SIG_BITS - F::SIG_BITS;
let excess_prec = ui & ((one << prec_diff) - one);
let halfway = one << (prec_diff - 1);
// Common case: the larger precision is fine if...
// This is not a halfway case
if excess_prec != halfway
// Or the result is NaN
|| re == B::EXP_SAT
// Or the result is exact
|| (result - xy == zb && result - zb == xy)
// Or the mode is something other than round to nearest
|| round != Round::Nearest
{
let min_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN_SUBNORM) as u32;
let max_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN) as u32;
let mut status = Status::OK;
if (min_inexact_exp..max_inexact_exp).contains(&re) && status.inexact() {
// This branch is never hit; requires previous operations to set a status
status.set_inexact(false);
result = xy + z.widen();
if status.inexact() {
status.set_underflow(true);
} else {
status.set_inexact(true);
}
}
return FpResult { val: result.narrow(), status };
}
let neg = ui >> (B::BITS - 1) != IntTy::<B>::ZERO;
let err = if neg == (zb > xy) { xy - result + zb } else { zb - result + xy };
if neg == (err < B::ZERO) {
ui += one;
} else {
ui -= one;
}
FpResult::ok(B::from_bits(ui).narrow())
FpResult::new(super::generic::scalbn(r, e), status)
}
/// Representation of `F` that has handled subnormals.
@ -363,6 +296,7 @@ impl<F: Float> Norm<F> {
mod tests {
use super::*;
/// Test the generic `fma_round` algorithm for a given float.
fn spec_test<F>()
where
F: Float,
@ -375,6 +309,8 @@ mod tests {
let y = F::from_bits(F::Int::ONE);
let z = F::ZERO;
let fma = |x, y, z| fma_round(x, y, z, Round::Nearest).val;
// 754-2020 says "When the exact result of (a × b) + c is non-zero yet the result of
// fusedMultiplyAdd is zero because of rounding, the zero result takes the sign of the
// exact result"
@ -384,6 +320,11 @@ mod tests {
assert_biteq!(fma(-x, -y, z), F::ZERO);
}
/// Runs the generic `spec_test` checks with `F = f32`, so the generic `fma_round` algorithm is
/// exercised for `f32` as well.
#[test]
fn spec_test_f32() {
spec_test::<f32>();
}
#[test]
fn spec_test_f64() {
spec_test::<f64>();
@ -417,4 +358,33 @@ mod tests {
fn spec_test_f128() {
spec_test::<f128>();
}
/// Regression inputs that made `fma` segfault in release builds due to overflow.
#[test]
fn fma_segfault() {
    // First input: the same tiny negative value used for all three operands.
    let tiny = -0.0000000000000002220446049250313;
    assert_eq!(fma(tiny, tiny, tiny), -0.00000000000000022204460492503126);

    // Second input: the same value just below one in magnitude for all three operands.
    let near_one = -0.992;
    let fused = fma(near_one, near_one, near_one);
    // Force rounding to the storage format on x87 to prevent spurious errors.
    #[cfg(all(target_arch = "x86", not(target_feature = "sse2")))]
    let fused = force_eval!(fused);
    assert_eq!(fused, -0.007936000000000007);
}
/// Checks `fma` with a multiplier just below one in magnitude applied to `f64::MIN`, with
/// `f64::MIN` as the addend.
#[test]
fn fma_sbb() {
    let multiplier = -(1.0 - f64::EPSILON);
    let expected: f64 = -3991680619069439e277;
    assert_eq!(fma(multiplier, f64::MIN, f64::MIN), expected);
}
/// Checks a case where the product and the addend are both around `1e-320` in magnitude with
/// opposite signs; the result is expected to compare equal to zero.
#[test]
fn fma_underflow() {
    let got = fma(1.1102230246251565e-16, -9.812526705433188e-305, 1.0894e-320);
    assert_eq!(got, 0.0);
}
}

View file

@ -0,0 +1,97 @@
/* SPDX-License-Identifier: MIT */
/* origin: musl src/math/fmaf.c. Ported to generic Rust algorithm in 2025, TG. */
use super::super::support::{FpResult, IntTy, Round, Status};
use super::{CastFrom, CastInto, DFloat, Float, HFloat, MinInt};
// Placeholder so we can have `fmaf16` in the `Float` trait.
//
// There is no `f16` implementation behind this yet: the body is `unimplemented!()`, so any call
// panics. The arguments are underscore-prefixed because they are never read.
// `allow(unused)` presumably silences dead-code warnings while nothing calls the stub.
// NOTE(review): the `no_panic` attribute below is attached to a body that always panics —
// confirm this cannot break `assert_no_panic` test builds once something references the stub.
#[allow(unused)]
#[cfg(f16_enabled)]
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub(crate) fn fmaf16(_x: f16, _y: f16, _z: f16) -> f16 {
unimplemented!()
}
/// Fused multiply-add for `f32`.
///
/// Evaluates `(x * y) + z` as one ternary operation, i.e. as if the product were calculated
/// with infinite precision before the addition. The computation is carried out by
/// `fma_wide_round` with `Round::Nearest`; only the value of its result is returned.
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
    let rounded = fma_wide_round(x, y, z, Round::Nearest);
    rounded.val
}
/// Fma implementation when a hardware-backed larger float type is available. For `f32` and `f64`,
/// `f64` has enough precision to represent the `f32` in its entirety, except for double rounding.
///
/// Computes `(x * y) + z` in the wider type `B`, then narrows the result back to `F`. The only
/// case given special treatment is the one where narrowing a round-to-nearest result would round
/// a second time from an exact halfway pattern; there the wide bit pattern is nudged by one unit
/// before narrowing.
///
/// * `x`, `y`, `z`: operands in the narrow type `F`.
/// * `round`: requested rounding mode. Any mode other than `Round::Nearest` always takes the
///   "common case" path below.
///
/// Returns the narrowed value together with a status. As written, the status is always OK (see
/// the note on the never-taken branch).
pub fn fma_wide_round<F, B>(x: F, y: F, z: F, round: Round) -> FpResult<F>
where
F: Float + HFloat<D = B>,
B: Float + DFloat<H = F>,
B::Int: CastInto<i32>,
i32: CastFrom<i32>,
{
let one = IntTy::<B>::ONE;
// Product and sum are both evaluated in the wide type `B`.
// NOTE(review): for f32/f64 the product is presumably exact (24-bit x 24-bit significands fit
// in 53 bits), so only the addition rounds — confirm before using other type pairs.
let xy: B = x.widen() * y.widen();
let mut result: B = xy + z.widen();
// Raw bits of the wide sum; adjusted by one unit in the halfway case at the bottom.
let mut ui: B::Int = result.to_bits();
// Presumably the biased exponent field of the wide result; it is compared against
// `B::EXP_SAT` and biased exponent bounds below.
let re = result.ex();
let zb: B = z.widen();
// Number of significand bits that `B` has beyond `F`; these are dropped by `narrow()`.
let prec_diff = B::SIG_BITS - F::SIG_BITS;
// The low bits of the wide result that do not fit in `F`.
let excess_prec = ui & ((one << prec_diff) - one);
// Pattern of the dropped bits that is exactly half a unit of `F` precision (top bit set only).
let halfway = one << (prec_diff - 1);
// Common case: the larger precision is fine if...
// This is not a halfway case
if excess_prec != halfway
// Or the result is NaN (presumably this saturated-exponent check also matches infinities)
|| re == B::EXP_SAT
// Or the result is exact
|| (result - xy == zb && result - zb == xy)
// Or the mode is something other than round to nearest
|| round != Round::Nearest
{
// Biased (in `B`) exponent bounds built from `F::EXP_MIN_SUBNORM` and `F::EXP_MIN`.
let min_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN_SUBNORM) as u32;
let max_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN) as u32;
let mut status = Status::OK;
if (min_inexact_exp..max_inexact_exp).contains(&re) && status.inexact() {
// This branch is never hit; requires previous operations to set a status
// (`status` was just initialized to `Status::OK`, so `status.inexact()` cannot be set here.)
status.set_inexact(false);
result = xy + z.widen();
if status.inexact() {
status.set_underflow(true);
} else {
status.set_inexact(true);
}
}
return FpResult { val: result.narrow(), status };
}
// Halfway case: the wide sum is inexact and its dropped bits sit exactly on the halfway pattern.
// Sign bit of the wide result.
let neg = ui >> (B::BITS - 1) != IntTy::<B>::ZERO;
// Error of the wide addition, recovered by subtracting the rounded sum back out. The operand
// order is chosen from the result sign and which of `zb`/`xy` is larger.
// NOTE(review): this mirrors the expression in the cited musl origin; the ordering is
// presumably what keeps the recovered error exact — confirm.
let err = if neg == (zb > xy) { xy - result + zb } else { zb - result + xy };
// Move the bit pattern one unit so it no longer sits on the halfway pattern: up by one when
// the error has the same sign as the result, down by one otherwise.
if neg == (err < B::ZERO) {
ui += one;
} else {
ui -= one;
}
// Narrow the adjusted value; this path reports an OK status via `FpResult::ok`.
FpResult::ok(B::from_bits(ui).narrow())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Regression test named after issue 263. Operands and the expected result are specified
    /// as raw `f32` bit patterns.
    #[test]
    fn issue_263() {
        let [x, y, z, want] =
            [1266679807, 1300234242, 1115553792, 1501560833].map(f32::from_bits);
        let got = fmaf(x, y, z);
        assert_eq!(got, want);
    }
}

View file

@ -1,21 +0,0 @@
/// Floating multiply add (f32)
///
/// Computes `(x*y)+z`, rounded as one ternary operation:
/// Computes the value (as if) to infinite precision and rounds once to the result format.
///
/// This is a thin wrapper that forwards to `super::generic::fma_wide`.
/// NOTE(review): earlier wording said rounding follows the mode characterized by `FLT_ROUNDS`;
/// the generic `fma_wide` shown alongside this file always requests `Round::Nearest` — confirm
/// that no dynamic rounding mode is honored before restoring that claim.
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
super::generic::fma_wide(x, y, z)
}
#[cfg(test)]
mod tests {
    /// Regression test named after issue 263. Operands and the expected result are specified
    /// as raw `f32` bit patterns.
    #[test]
    fn issue_263() {
        let x = f32::from_bits(1266679807);
        let y = f32::from_bits(1300234242);
        let z = f32::from_bits(1115553792);
        let want = f32::from_bits(1501560833);

        let got = super::fmaf(x, y, z);
        assert_eq!(got, want);
    }
}

View file

@ -1,7 +0,0 @@
/// Fused multiply-add for `f128`.
///
/// Evaluates `(x * y) + z` as one ternary operation, i.e. as if the product were calculated
/// with infinite precision before the addition. Forwards directly to `super::generic::fma`.
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn fmaf128(x: f128, y: f128, z: f128) -> f128 {
    super::generic::fma(x, y, z)
}

View file

@ -3,7 +3,6 @@ mod copysign;
mod fabs;
mod fdim;
mod floor;
mod fma;
mod fmax;
mod fmaximum;
mod fmaximum_num;
@ -22,7 +21,6 @@ pub use copysign::copysign;
pub use fabs::fabs;
pub use fdim::fdim;
pub use floor::floor;
pub use fma::{fma, fma_wide};
pub use fmax::fmax;
pub use fmaximum::fmaximum;
pub use fmaximum_num::fmaximum_num;

View file

@ -164,7 +164,7 @@ mod fdimf;
mod floor;
mod floorf;
mod fma;
mod fmaf;
mod fma_wide;
mod fmin_fmax;
mod fminimum_fmaximum;
mod fminimum_fmaximum_num;
@ -271,7 +271,7 @@ pub use self::fdimf::fdimf;
pub use self::floor::floor;
pub use self::floorf::floorf;
pub use self::fma::fma;
pub use self::fmaf::fmaf;
pub use self::fma_wide::fmaf;
pub use self::fmin_fmax::{fmax, fmaxf, fmin, fminf};
pub use self::fminimum_fmaximum::{fmaximum, fmaximumf, fminimum, fminimumf};
pub use self::fminimum_fmaximum_num::{fmaximum_num, fmaximum_numf, fminimum_num, fminimum_numf};
@ -370,6 +370,9 @@ cfg_if! {
pub use self::sqrtf16::sqrtf16;
pub use self::truncf16::truncf16;
// verify-sorted-end
#[allow(unused_imports)]
pub(crate) use self::fma_wide::fmaf16;
}
}
@ -381,7 +384,6 @@ cfg_if! {
mod fabsf128;
mod fdimf128;
mod floorf128;
mod fmaf128;
mod fmodf128;
mod ldexpf128;
mod roundf128;
@ -396,7 +398,7 @@ cfg_if! {
pub use self::fabsf128::fabsf128;
pub use self::fdimf128::fdimf128;
pub use self::floorf128::floorf128;
pub use self::fmaf128::fmaf128;
pub use self::fma::fmaf128;
pub use self::fmin_fmax::{fmaxf128, fminf128};
pub use self::fminimum_fmaximum::{fmaximumf128, fminimumf128};
pub use self::fminimum_fmaximum_num::{fmaximum_numf128, fminimum_numf128};