Merge pull request #4071 from RalfJung/simd_relaxed_fma

implement simd_relaxed_fma
This commit is contained in:
Ralf Jung 2024-12-04 08:46:25 +00:00 committed by GitHub
commit b4b0e0356c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 97 additions and 33 deletions

View file

@ -1,4 +1,5 @@
use either::Either;
use rand::Rng;
use rustc_abi::{Endian, HasDataLayout};
use rustc_apfloat::{Float, Round};
use rustc_middle::ty::FloatTy;
@ -286,7 +287,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
this.write_scalar(val, &dest)?;
}
}
"fma" => {
"fma" | "relaxed_fma" => {
let [a, b, c] = check_arg_count(args)?;
let (a, a_len) = this.project_to_simd(a)?;
let (b, b_len) = this.project_to_simd(b)?;
@ -303,6 +304,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let c = this.read_scalar(&this.project_index(&c, i)?)?;
let dest = this.project_index(&dest, i)?;
let fuse: bool = intrinsic_name == "fma" || this.machine.rng.get_mut().gen();
// Works for f32 and f64.
// FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
let ty::Float(float_ty) = dest.layout.ty.kind() else {
@ -314,7 +317,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let a = a.to_f32()?;
let b = b.to_f32()?;
let c = c.to_f32()?;
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
let res = if fuse {
a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
} else {
((a * b).value + c).value
};
let res = this.adjust_nan(res, &[a, b, c]);
Scalar::from(res)
}
@ -322,7 +329,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let a = a.to_f64()?;
let b = b.to_f64()?;
let c = c.to_f64()?;
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
let res = if fuse {
a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
} else {
((a * b).value + c).value
};
let res = this.adjust_nan(res, &[a, b, c]);
Scalar::from(res)
}

View file

@ -1,44 +1,75 @@
#![feature(core_intrinsics)]
#![feature(core_intrinsics, portable_simd)]
use std::intrinsics::simd::simd_relaxed_fma;
use std::intrinsics::{fmuladdf32, fmuladdf64};
use std::simd::prelude::*;
fn main() {
let mut saw_zero = false;
let mut saw_nonzero = false;
// Calls `f` up to 50 times and reports whether both possible results were observed.
// Returns `true` as soon as `f` has produced at least one `true` and at least one `false`;
// returns `false` if the loop finishes without having seen both.
//
// NOTE(review): this span comes from a diff rendering without +/- markers. The statements
// that reference `a`, `c`, `x`, `fmuladdf64`, `saw_zero` and `saw_nonzero` appear to be
// removed lines of the previous `main` loop body — confirm against the committed file.
fn ensure_both_happen(f: impl Fn() -> bool) -> bool {
let mut saw_true = false;
let mut saw_false = false;
for _ in 0..50 {
let a = std::hint::black_box(0.1_f64);
let b = std::hint::black_box(0.2);
let c = std::hint::black_box(-a * b);
// It is unspecified whether the following operation is fused or not. The
// following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused.
let x = unsafe { fmuladdf64(a, b, c) };
if x == 0.0 {
saw_zero = true;
// One trial: run the closure and record which of the two outcomes it produced.
let b = f();
if b {
saw_true = true;
} else {
saw_nonzero = true;
saw_false = true;
}
// Stop early once both outcomes have been seen.
if saw_true && saw_false {
return true;
}
}
// Every trial produced the same outcome.
false
}
// Asserts, via `ensure_both_happen`, that `fmuladdf64`, `fmuladdf32`, and `simd_relaxed_fma`
// (on `f32x4` and on `f64x4`) each yield both of their two possible outcomes across
// repeated calls.
//
// NOTE(review): this span comes from a diff rendering without +/- markers. The
// `saw_zero && saw_nonzero,` lines and the hand-written `for _ in 0..50` loop appear to be
// removed lines of the previous implementation — confirm against the committed file.
fn main() {
assert!(
saw_zero && saw_nonzero,
ensure_both_happen(|| {
// NOTE(review): `black_box` presumably keeps the operands from being constant-folded.
let a = std::hint::black_box(0.1_f64);
let b = std::hint::black_box(0.2);
let c = std::hint::black_box(-a * b);
// It is unspecified whether the following operation is fused or not. The
// following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused.
let x = unsafe { fmuladdf64(a, b, c) };
// `true` means this call produced 0.0; `false` means it produced a nonzero value.
x == 0.0
}),
"`fmuladdf64` failed to be evaluated as both fused and unfused"
);
let mut saw_zero = false;
let mut saw_nonzero = false;
for _ in 0..50 {
let a = std::hint::black_box(0.1_f32);
let b = std::hint::black_box(0.2);
let c = std::hint::black_box(-a * b);
// It is unspecified whether the following operation is fused or not. The
// following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused.
let x = unsafe { fmuladdf32(a, b, c) };
if x == 0.0 {
saw_zero = true;
} else {
saw_nonzero = true;
}
}
assert!(
saw_zero && saw_nonzero,
ensure_both_happen(|| {
let a = std::hint::black_box(0.1_f32);
let b = std::hint::black_box(0.2);
let c = std::hint::black_box(-a * b);
// It is unspecified whether the following operation is fused or not. The
// following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused.
let x = unsafe { fmuladdf32(a, b, c) };
// `true` means this call produced 0.0; `false` means it produced a nonzero value.
x == 0.0
}),
"`fmuladdf32` failed to be evaluated as both fused and unfused"
);
assert!(
ensure_both_happen(|| {
// Every lane of `a` is 0.1 and every lane of `b` is 0.2; `c` is the negated
// lane-wise product, mirroring the scalar cases above.
let a = f32x4::splat(std::hint::black_box(0.1));
let b = f32x4::splat(std::hint::black_box(0.2));
let c = std::hint::black_box(-a * b);
let x = unsafe { simd_relaxed_fma(a, b, c) };
// Whether we fuse or not is a per-element decision, so sometimes these should be
// the same and sometimes not.
x[0] == x[1]
}),
"`simd_relaxed_fma` failed to be evaluated as both fused and unfused"
);
assert!(
ensure_both_happen(|| {
// Same check as the `f32x4` case, with `f64` lanes.
let a = f64x4::splat(std::hint::black_box(0.1));
let b = f64x4::splat(std::hint::black_box(0.2));
let c = std::hint::black_box(-a * b);
let x = unsafe { simd_relaxed_fma(a, b, c) };
// Whether we fuse or not is a per-element decision, so sometimes these should be
// the same and sometimes not.
x[0] == x[1]
}),
"`simd_relaxed_fma` failed to be evaluated as both fused and unfused"
);
}

View file

@ -40,6 +40,17 @@ fn simd_ops_f32() {
f32x4::splat(-3.2).mul_add(b, f32x4::splat(f32::NEG_INFINITY)),
f32x4::splat(f32::NEG_INFINITY)
);
unsafe {
assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a);
assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a);
assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b);
assert_eq!(
intrinsics::simd_relaxed_fma(f32x4::splat(-3.2), b, f32x4::splat(f32::NEG_INFINITY)),
f32x4::splat(f32::NEG_INFINITY)
);
}
assert_eq!((a * a).sqrt(), a);
assert_eq!((b * b).sqrt(), b.abs());
@ -94,6 +105,17 @@ fn simd_ops_f64() {
f64x4::splat(-3.2).mul_add(b, f64x4::splat(f64::NEG_INFINITY)),
f64x4::splat(f64::NEG_INFINITY)
);
unsafe {
assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a);
assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a);
assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b);
assert_eq!(
intrinsics::simd_relaxed_fma(f64x4::splat(-3.2), b, f64x4::splat(f64::NEG_INFINITY)),
f64x4::splat(f64::NEG_INFINITY)
);
}
assert_eq!((a * a).sqrt(), a);
assert_eq!((b * b).sqrt(), b.abs());