Convert fmaf to a generic implementation
Introduce a version of generic `fma` that works when there is a larger hardware-backed float type available to compute the result with more precision. This is currently used only for `f32`, but with some minor adjustments it should work for `f16` as well.
This commit is contained in:
parent
3e2de21344
commit
aa4ae487d4
6 changed files with 129 additions and 99 deletions
|
|
@ -1,103 +1,11 @@
|
|||
/* origin: FreeBSD /usr/src/lib/msun/src/s_fmaf.c */
|
||||
/*-
|
||||
* Copyright (c) 2005-2011 David Schultz <das@FreeBSD.ORG>
|
||||
* All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions
|
||||
* are met:
|
||||
* 1. Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* 2. Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
* ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
|
||||
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
|
||||
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
||||
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
|
||||
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
|
||||
* SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
use core::f32;
|
||||
use core::ptr::read_volatile;
|
||||
|
||||
use super::fenv::{
|
||||
FE_INEXACT, FE_TONEAREST, FE_UNDERFLOW, feclearexcept, fegetround, feraiseexcept, fetestexcept,
|
||||
};
|
||||
|
||||
/*
|
||||
* Fused multiply-add: Compute x * y + z with a single rounding error.
|
||||
*
|
||||
* A double has more than twice as much precision as a float, so
|
||||
* direct double-precision arithmetic suffices, except where double
|
||||
* rounding occurs.
|
||||
*/
|
||||
|
||||
/// Floating multiply add (f32)
|
||||
///
|
||||
/// Computes `(x*y)+z`, rounded as one ternary operation:
|
||||
/// Computes the value (as if) to infinite precision and rounds once to the result format,
|
||||
/// according to the rounding mode characterized by the value of FLT_ROUNDS.
|
||||
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
|
||||
pub fn fmaf(x: f32, y: f32, mut z: f32) -> f32 {
|
||||
let xy: f64;
|
||||
let mut result: f64;
|
||||
let mut ui: u64;
|
||||
let e: i32;
|
||||
|
||||
xy = x as f64 * y as f64;
|
||||
result = xy + z as f64;
|
||||
ui = result.to_bits();
|
||||
e = (ui >> 52) as i32 & 0x7ff;
|
||||
/* Common case: The double precision result is fine. */
|
||||
if (
|
||||
/* not a halfway case */
|
||||
ui & 0x1fffffff) != 0x10000000 ||
|
||||
/* NaN */
|
||||
e == 0x7ff ||
|
||||
/* exact */
|
||||
(result - xy == z as f64 && result - z as f64 == xy) ||
|
||||
/* not round-to-nearest */
|
||||
fegetround() != FE_TONEAREST
|
||||
{
|
||||
/*
|
||||
underflow may not be raised correctly, example:
|
||||
fmaf(0x1p-120f, 0x1p-120f, 0x1p-149f)
|
||||
*/
|
||||
if ((0x3ff - 149)..(0x3ff - 126)).contains(&e) && fetestexcept(FE_INEXACT) != 0 {
|
||||
feclearexcept(FE_INEXACT);
|
||||
// prevent `xy + vz` from being CSE'd with `xy + z` above
|
||||
let vz: f32 = unsafe { read_volatile(&z) };
|
||||
result = xy + vz as f64;
|
||||
if fetestexcept(FE_INEXACT) != 0 {
|
||||
feraiseexcept(FE_UNDERFLOW);
|
||||
} else {
|
||||
feraiseexcept(FE_INEXACT);
|
||||
}
|
||||
}
|
||||
z = result as f32;
|
||||
return z;
|
||||
}
|
||||
|
||||
/*
|
||||
* If result is inexact, and exactly halfway between two float values,
|
||||
* we need to adjust the low-order bit in the direction of the error.
|
||||
*/
|
||||
let neg = ui >> 63 != 0;
|
||||
let err = if neg == (z as f64 > xy) { xy - result + z as f64 } else { z as f64 - result + xy };
|
||||
if neg == (err < 0.0) {
|
||||
ui += 1;
|
||||
} else {
|
||||
ui -= 1;
|
||||
}
|
||||
f64::from_bits(ui) as f32
|
||||
pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
|
||||
super::generic::fma_wide(x, y, z)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
/* SPDX-License-Identifier: MIT */
|
||||
/* origin: musl src/math/fma.c. Ported to generic Rust algorithm in 2025, TG. */
|
||||
/* origin: musl src/math/{fma,fmaf}.c. Ported to generic Rust algorithm in 2025, TG. */
|
||||
|
||||
use core::{f32, f64};
|
||||
|
||||
use super::super::fenv::{
|
||||
FE_INEXACT, FE_TONEAREST, FE_UNDERFLOW, feclearexcept, fegetround, feraiseexcept, fetestexcept,
|
||||
};
|
||||
use super::super::support::{DInt, HInt, IntTy};
|
||||
use super::super::{CastFrom, CastInto, Float, Int, MinInt};
|
||||
use super::super::{CastFrom, CastInto, DFloat, Float, HFloat, Int, MinInt};
|
||||
|
||||
/// Fused multiply-add that works when there is not a larger float size available. Currently this
|
||||
/// is still specialized only for `f64`. Computes `(x * y) + z`.
|
||||
|
|
@ -212,6 +215,66 @@ where
|
|||
super::scalbn(r, e)
|
||||
}
|
||||
|
||||
/// Fma implementation when a hardware-backed larger float type is available. For `f32` and `f64`,
|
||||
/// `f64` is precise enough to compute the `f32` result directly, except where double rounding occurs.
|
||||
pub fn fma_wide<F, B>(x: F, y: F, z: F) -> F
|
||||
where
|
||||
F: Float + HFloat<D = B>,
|
||||
B: Float + DFloat<H = F>,
|
||||
B::Int: CastInto<i32>,
|
||||
i32: CastFrom<i32>,
|
||||
{
|
||||
let one = IntTy::<B>::ONE;
|
||||
|
||||
let xy: B = x.widen() * y.widen();
|
||||
let mut result: B = xy + z.widen();
|
||||
let mut ui: B::Int = result.to_bits();
|
||||
let re = result.exp();
|
||||
let zb: B = z.widen();
|
||||
|
||||
let prec_diff = B::SIG_BITS - F::SIG_BITS;
|
||||
let excess_prec = ui & ((one << prec_diff) - one);
|
||||
let halfway = one << (prec_diff - 1);
|
||||
|
||||
// Common case: the larger precision is fine if...
|
||||
// This is not a halfway case
|
||||
if excess_prec != halfway
|
||||
// Or the result is NaN
|
||||
|| re == B::EXP_SAT
|
||||
// Or the result is exact
|
||||
|| (result - xy == zb && result - zb == xy)
|
||||
// Or the mode is something other than round to nearest
|
||||
|| fegetround() != FE_TONEAREST
|
||||
{
|
||||
let min_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN_SUBNORM) as u32;
|
||||
let max_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN) as u32;
|
||||
|
||||
if (min_inexact_exp..max_inexact_exp).contains(&re) && fetestexcept(FE_INEXACT) != 0 {
|
||||
feclearexcept(FE_INEXACT);
|
||||
// prevent `xy + vz` from being CSE'd with `xy + z` above
|
||||
let vz: F = force_eval!(z);
|
||||
result = xy + vz.widen();
|
||||
if fetestexcept(FE_INEXACT) != 0 {
|
||||
feraiseexcept(FE_UNDERFLOW);
|
||||
} else {
|
||||
feraiseexcept(FE_INEXACT);
|
||||
}
|
||||
}
|
||||
|
||||
return result.narrow();
|
||||
}
|
||||
|
||||
let neg = ui >> (B::BITS - 1) != IntTy::<B>::ZERO;
|
||||
let err = if neg == (zb > xy) { xy - result + zb } else { zb - result + xy };
|
||||
if neg == (err < B::ZERO) {
|
||||
ui += one;
|
||||
} else {
|
||||
ui -= one;
|
||||
}
|
||||
|
||||
B::from_bits(ui).narrow()
|
||||
}
|
||||
|
||||
/// Representation of `F` that has handled subnormals.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct Norm<F: Float> {
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ pub use copysign::copysign;
|
|||
pub use fabs::fabs;
|
||||
pub use fdim::fdim;
|
||||
pub use floor::floor;
|
||||
pub use fma::fma;
|
||||
pub use fma::{fma, fma_wide};
|
||||
pub use fmax::fmax;
|
||||
pub use fmin::fmin;
|
||||
pub use fmod::fmod;
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ use self::rem_pio2::rem_pio2;
|
|||
use self::rem_pio2_large::rem_pio2_large;
|
||||
use self::rem_pio2f::rem_pio2f;
|
||||
#[allow(unused_imports)]
|
||||
use self::support::{CastFrom, CastInto, DInt, Float, HInt, Int, IntTy, MinInt};
|
||||
use self::support::{CastFrom, CastInto, DFloat, DInt, Float, HFloat, HInt, Int, IntTy, MinInt};
|
||||
|
||||
// Public modules
|
||||
mod acos;
|
||||
|
|
|
|||
|
|
@ -276,6 +276,64 @@ pub const fn f64_from_bits(bits: u64) -> f64 {
|
|||
unsafe { mem::transmute::<u64, f64>(bits) }
|
||||
}
|
||||
|
||||
/// Trait for floats twice the bit width of another float.
|
||||
pub trait DFloat: Float {
|
||||
/// Float that is half the bit width of the float this trait is implemented for.
|
||||
type H: HFloat<D = Self>;
|
||||
|
||||
/// Narrow the float type.
|
||||
fn narrow(self) -> Self::H;
|
||||
}
|
||||
|
||||
/// Trait for floats half the bit width of another float.
|
||||
pub trait HFloat: Float {
|
||||
/// Float that is double the bit width of the float this trait is implemented for.
|
||||
type D: DFloat<H = Self>;
|
||||
|
||||
/// Widen the float type.
|
||||
fn widen(self) -> Self::D;
|
||||
}
|
||||
|
||||
macro_rules! impl_d_float {
|
||||
($($X:ident $D:ident),*) => {
|
||||
$(
|
||||
impl DFloat for $D {
|
||||
type H = $X;
|
||||
|
||||
fn narrow(self) -> Self::H {
|
||||
self as $X
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! impl_h_float {
|
||||
($($H:ident $X:ident),*) => {
|
||||
$(
|
||||
impl HFloat for $H {
|
||||
type D = $X;
|
||||
|
||||
fn widen(self) -> Self::D {
|
||||
self as $X
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl_d_float!(f32 f64);
|
||||
#[cfg(f16_enabled)]
|
||||
impl_d_float!(f16 f32);
|
||||
#[cfg(f128_enabled)]
|
||||
impl_d_float!(f64 f128);
|
||||
|
||||
impl_h_float!(f32 f64);
|
||||
#[cfg(f16_enabled)]
|
||||
impl_h_float!(f16 f32);
|
||||
#[cfg(f128_enabled)]
|
||||
impl_h_float!(f64 f128);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ mod float_traits;
|
|||
pub mod hex_float;
|
||||
mod int_traits;
|
||||
|
||||
pub use float_traits::{Float, IntTy};
|
||||
#[allow(unused_imports)]
|
||||
pub use float_traits::{DFloat, Float, HFloat, IntTy};
|
||||
pub(crate) use float_traits::{f32_from_bits, f64_from_bits};
|
||||
#[cfg(f16_enabled)]
|
||||
#[allow(unused_imports)]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue