Convert fmaf to a generic implementation

Introduce a version of generic `fma` that works when there is a larger
hardware-backed float type available to compute the result with more
precision. This is currently used only for `f32`, but with some minor
adjustments it should work for `f16` as well.
This commit is contained in:
Trevor Gross 2025-02-07 03:41:05 +00:00 committed by Trevor Gross
parent 3e2de21344
commit aa4ae487d4
6 changed files with 129 additions and 99 deletions

View file

@ -1,103 +1,11 @@
/* origin: FreeBSD /usr/src/lib/msun/src/s_fmaf.c */
/*-
* Copyright (c) 2005-2011 David Schultz <das@FreeBSD.ORG>
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
*/
use core::f32;
use core::ptr::read_volatile;
use super::fenv::{
FE_INEXACT, FE_TONEAREST, FE_UNDERFLOW, feclearexcept, fegetround, feraiseexcept, fetestexcept,
};
/*
* Fused multiply-add: Compute x * y + z with a single rounding error.
*
* A double has more than twice as much precision than a float, so
* direct double-precision arithmetic suffices, except where double
* rounding occurs.
*/
/// Floating multiply add (f32)
///
/// Computes `(x*y)+z`, rounded as one ternary operation:
/// Computes the value (as if) to infinite precision and rounds once to the result format,
/// according to the rounding mode characterized by the value of FLT_ROUNDS.
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn fmaf(x: f32, y: f32, mut z: f32) -> f32 {
let xy: f64;
let mut result: f64;
let mut ui: u64;
let e: i32;
xy = x as f64 * y as f64;
result = xy + z as f64;
ui = result.to_bits();
e = (ui >> 52) as i32 & 0x7ff;
/* Common case: The double precision result is fine. */
if (
/* not a halfway case */
ui & 0x1fffffff) != 0x10000000 ||
/* NaN */
e == 0x7ff ||
/* exact */
(result - xy == z as f64 && result - z as f64 == xy) ||
/* not round-to-nearest */
fegetround() != FE_TONEAREST
{
/*
underflow may not be raised correctly, example:
fmaf(0x1p-120f, 0x1p-120f, 0x1p-149f)
*/
if ((0x3ff - 149)..(0x3ff - 126)).contains(&e) && fetestexcept(FE_INEXACT) != 0 {
feclearexcept(FE_INEXACT);
// prevent `xy + vz` from being CSE'd with `xy + z` above
let vz: f32 = unsafe { read_volatile(&z) };
result = xy + vz as f64;
if fetestexcept(FE_INEXACT) != 0 {
feraiseexcept(FE_UNDERFLOW);
} else {
feraiseexcept(FE_INEXACT);
}
}
z = result as f32;
return z;
}
/*
* If result is inexact, and exactly halfway between two float values,
* we need to adjust the low-order bit in the direction of the error.
*/
let neg = ui >> 63 != 0;
let err = if neg == (z as f64 > xy) { xy - result + z as f64 } else { z as f64 - result + xy };
if neg == (err < 0.0) {
ui += 1;
} else {
ui -= 1;
}
f64::from_bits(ui) as f32
pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
super::generic::fma_wide(x, y, z)
}
#[cfg(test)]

View file

@ -1,10 +1,13 @@
/* SPDX-License-Identifier: MIT */
/* origin: musl src/math/fma.c. Ported to generic Rust algorithm in 2025, TG. */
/* origin: musl src/math/{fma,fmaf}.c. Ported to generic Rust algorithm in 2025, TG. */
use core::{f32, f64};
use super::super::fenv::{
FE_INEXACT, FE_TONEAREST, FE_UNDERFLOW, feclearexcept, fegetround, feraiseexcept, fetestexcept,
};
use super::super::support::{DInt, HInt, IntTy};
use super::super::{CastFrom, CastInto, Float, Int, MinInt};
use super::super::{CastFrom, CastInto, DFloat, Float, HFloat, Int, MinInt};
/// Fused multiply-add that works when there is not a larger float size available. Currently this
/// is still specialized only for `f64`. Computes `(x * y) + z`.
@ -212,6 +215,66 @@ where
super::scalbn(r, e)
}
/// Fma implementation when a hardware-backed larger float type is available. For `f32` and `f64`,
/// `f64` has enough precision to represent the `f32` in its entirety, except for double rounding.
pub fn fma_wide<F, B>(x: F, y: F, z: F) -> F
where
F: Float + HFloat<D = B>,
B: Float + DFloat<H = F>,
B::Int: CastInto<i32>,
i32: CastFrom<i32>,
{
let one = IntTy::<B>::ONE;
let xy: B = x.widen() * y.widen();
let mut result: B = xy + z.widen();
let mut ui: B::Int = result.to_bits();
let re = result.exp();
let zb: B = z.widen();
let prec_diff = B::SIG_BITS - F::SIG_BITS;
let excess_prec = ui & ((one << prec_diff) - one);
let halfway = one << (prec_diff - 1);
// Common case: the larger precision is fine if...
// This is not a halfway case
if excess_prec != halfway
// Or the result is NaN
|| re == B::EXP_SAT
// Or the result is exact
|| (result - xy == zb && result - zb == xy)
// Or the mode is something other than round to nearest
|| fegetround() != FE_TONEAREST
{
let min_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN_SUBNORM) as u32;
let max_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN) as u32;
if (min_inexact_exp..max_inexact_exp).contains(&re) && fetestexcept(FE_INEXACT) != 0 {
feclearexcept(FE_INEXACT);
// prevent `xy + vz` from being CSE'd with `xy + z` above
let vz: F = force_eval!(z);
result = xy + vz.widen();
if fetestexcept(FE_INEXACT) != 0 {
feraiseexcept(FE_UNDERFLOW);
} else {
feraiseexcept(FE_INEXACT);
}
}
return result.narrow();
}
let neg = ui >> (B::BITS - 1) != IntTy::<B>::ZERO;
let err = if neg == (zb > xy) { xy - result + zb } else { zb - result + xy };
if neg == (err < B::ZERO) {
ui += one;
} else {
ui -= one;
}
B::from_bits(ui).narrow()
}
/// Representation of `F` that has handled subnormals.
#[derive(Clone, Copy, Debug)]
struct Norm<F: Float> {

View file

@ -18,7 +18,7 @@ pub use copysign::copysign;
pub use fabs::fabs;
pub use fdim::fdim;
pub use floor::floor;
pub use fma::fma;
pub use fma::{fma, fma_wide};
pub use fmax::fmax;
pub use fmin::fmin;
pub use fmod::fmod;

View file

@ -121,7 +121,7 @@ use self::rem_pio2::rem_pio2;
use self::rem_pio2_large::rem_pio2_large;
use self::rem_pio2f::rem_pio2f;
#[allow(unused_imports)]
use self::support::{CastFrom, CastInto, DInt, Float, HInt, Int, IntTy, MinInt};
use self::support::{CastFrom, CastInto, DFloat, DInt, Float, HFloat, HInt, Int, IntTy, MinInt};
// Public modules
mod acos;

View file

@ -276,6 +276,64 @@ pub const fn f64_from_bits(bits: u64) -> f64 {
unsafe { mem::transmute::<u64, f64>(bits) }
}
/// Trait for floats twice the bit width of another integer.
pub trait DFloat: Float {
/// Float that is half the bit width of the floatthis trait is implemented for.
type H: HFloat<D = Self>;
/// Narrow the float type.
fn narrow(self) -> Self::H;
}
/// Trait for floats half the bit width of another float.
pub trait HFloat: Float {
/// Float that is double the bit width of the float this trait is implemented for.
type D: DFloat<H = Self>;
/// Widen the float type.
fn widen(self) -> Self::D;
}
macro_rules! impl_d_float {
($($X:ident $D:ident),*) => {
$(
impl DFloat for $D {
type H = $X;
fn narrow(self) -> Self::H {
self as $X
}
}
)*
};
}
macro_rules! impl_h_float {
($($H:ident $X:ident),*) => {
$(
impl HFloat for $H {
type D = $X;
fn widen(self) -> Self::D {
self as $X
}
}
)*
};
}
impl_d_float!(f32 f64);
#[cfg(f16_enabled)]
impl_d_float!(f16 f32);
#[cfg(f128_enabled)]
impl_d_float!(f64 f128);
impl_h_float!(f32 f64);
#[cfg(f16_enabled)]
impl_h_float!(f16 f32);
#[cfg(f128_enabled)]
impl_h_float!(f64 f128);
#[cfg(test)]
mod tests {
use super::*;

View file

@ -5,7 +5,8 @@ mod float_traits;
pub mod hex_float;
mod int_traits;
pub use float_traits::{Float, IntTy};
#[allow(unused_imports)]
pub use float_traits::{DFloat, Float, HFloat, IntTy};
pub(crate) use float_traits::{f32_from_bits, f64_from_bits};
#[cfg(f16_enabled)]
#[allow(unused_imports)]