Merge pull request #4071 from RalfJung/simd_relaxed_fma
implement simd_relaxed_fma
This commit is contained in:
commit
b4b0e0356c
3 changed files with 97 additions and 33 deletions
|
|
@ -1,4 +1,5 @@
|
|||
use either::Either;
|
||||
use rand::Rng;
|
||||
use rustc_abi::{Endian, HasDataLayout};
|
||||
use rustc_apfloat::{Float, Round};
|
||||
use rustc_middle::ty::FloatTy;
|
||||
|
|
@ -286,7 +287,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
|
|||
this.write_scalar(val, &dest)?;
|
||||
}
|
||||
}
|
||||
"fma" => {
|
||||
"fma" | "relaxed_fma" => {
|
||||
let [a, b, c] = check_arg_count(args)?;
|
||||
let (a, a_len) = this.project_to_simd(a)?;
|
||||
let (b, b_len) = this.project_to_simd(b)?;
|
||||
|
|
@ -303,6 +304,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
|
|||
let c = this.read_scalar(&this.project_index(&c, i)?)?;
|
||||
let dest = this.project_index(&dest, i)?;
|
||||
|
||||
let fuse: bool = intrinsic_name == "fma" || this.machine.rng.get_mut().gen();
|
||||
|
||||
// Works for f32 and f64.
|
||||
// FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
|
||||
let ty::Float(float_ty) = dest.layout.ty.kind() else {
|
||||
|
|
@ -314,7 +317,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
|
|||
let a = a.to_f32()?;
|
||||
let b = b.to_f32()?;
|
||||
let c = c.to_f32()?;
|
||||
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
|
||||
let res = if fuse {
|
||||
a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
|
||||
} else {
|
||||
((a * b).value + c).value
|
||||
};
|
||||
let res = this.adjust_nan(res, &[a, b, c]);
|
||||
Scalar::from(res)
|
||||
}
|
||||
|
|
@ -322,7 +329,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
|
|||
let a = a.to_f64()?;
|
||||
let b = b.to_f64()?;
|
||||
let c = c.to_f64()?;
|
||||
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
|
||||
let res = if fuse {
|
||||
a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
|
||||
} else {
|
||||
((a * b).value + c).value
|
||||
};
|
||||
let res = this.adjust_nan(res, &[a, b, c]);
|
||||
Scalar::from(res)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,44 +1,75 @@
|
|||
#![feature(core_intrinsics)]
|
||||
#![feature(core_intrinsics, portable_simd)]
|
||||
use std::intrinsics::simd::simd_relaxed_fma;
|
||||
use std::intrinsics::{fmuladdf32, fmuladdf64};
|
||||
use std::simd::prelude::*;
|
||||
|
||||
fn main() {
|
||||
let mut saw_zero = false;
|
||||
let mut saw_nonzero = false;
|
||||
fn ensure_both_happen(f: impl Fn() -> bool) -> bool {
|
||||
let mut saw_true = false;
|
||||
let mut saw_false = false;
|
||||
for _ in 0..50 {
|
||||
let a = std::hint::black_box(0.1_f64);
|
||||
let b = std::hint::black_box(0.2);
|
||||
let c = std::hint::black_box(-a * b);
|
||||
// It is unspecified whether the following operation is fused or not. The
|
||||
// following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused.
|
||||
let x = unsafe { fmuladdf64(a, b, c) };
|
||||
if x == 0.0 {
|
||||
saw_zero = true;
|
||||
let b = f();
|
||||
if b {
|
||||
saw_true = true;
|
||||
} else {
|
||||
saw_nonzero = true;
|
||||
saw_false = true;
|
||||
}
|
||||
if saw_true && saw_false {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn main() {
|
||||
assert!(
|
||||
saw_zero && saw_nonzero,
|
||||
ensure_both_happen(|| {
|
||||
let a = std::hint::black_box(0.1_f64);
|
||||
let b = std::hint::black_box(0.2);
|
||||
let c = std::hint::black_box(-a * b);
|
||||
// It is unspecified whether the following operation is fused or not. The
|
||||
// following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused.
|
||||
let x = unsafe { fmuladdf64(a, b, c) };
|
||||
x == 0.0
|
||||
}),
|
||||
"`fmuladdf64` failed to be evaluated as both fused and unfused"
|
||||
);
|
||||
|
||||
let mut saw_zero = false;
|
||||
let mut saw_nonzero = false;
|
||||
for _ in 0..50 {
|
||||
let a = std::hint::black_box(0.1_f32);
|
||||
let b = std::hint::black_box(0.2);
|
||||
let c = std::hint::black_box(-a * b);
|
||||
// It is unspecified whether the following operation is fused or not. The
|
||||
// following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused.
|
||||
let x = unsafe { fmuladdf32(a, b, c) };
|
||||
if x == 0.0 {
|
||||
saw_zero = true;
|
||||
} else {
|
||||
saw_nonzero = true;
|
||||
}
|
||||
}
|
||||
assert!(
|
||||
saw_zero && saw_nonzero,
|
||||
ensure_both_happen(|| {
|
||||
let a = std::hint::black_box(0.1_f32);
|
||||
let b = std::hint::black_box(0.2);
|
||||
let c = std::hint::black_box(-a * b);
|
||||
// It is unspecified whether the following operation is fused or not. The
|
||||
// following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused.
|
||||
let x = unsafe { fmuladdf32(a, b, c) };
|
||||
x == 0.0
|
||||
}),
|
||||
"`fmuladdf32` failed to be evaluated as both fused and unfused"
|
||||
);
|
||||
|
||||
assert!(
|
||||
ensure_both_happen(|| {
|
||||
let a = f32x4::splat(std::hint::black_box(0.1));
|
||||
let b = f32x4::splat(std::hint::black_box(0.2));
|
||||
let c = std::hint::black_box(-a * b);
|
||||
let x = unsafe { simd_relaxed_fma(a, b, c) };
|
||||
// Fusing is decided independently for each lane, so lanes 0 and 1 should sometimes
|
||||
// be equal (same choice) and sometimes differ (one fused, the other not).
|
||||
x[0] == x[1]
|
||||
}),
|
||||
"`simd_relaxed_fma` failed to be evaluated as both fused and unfused"
|
||||
);
|
||||
|
||||
assert!(
|
||||
ensure_both_happen(|| {
|
||||
let a = f64x4::splat(std::hint::black_box(0.1));
|
||||
let b = f64x4::splat(std::hint::black_box(0.2));
|
||||
let c = std::hint::black_box(-a * b);
|
||||
let x = unsafe { simd_relaxed_fma(a, b, c) };
|
||||
// Fusing is decided independently for each lane, so lanes 0 and 1 should sometimes
|
||||
// be equal (same choice) and sometimes differ (one fused, the other not).
|
||||
x[0] == x[1]
|
||||
}),
|
||||
"`simd_relaxed_fma` failed to be evaluated as both fused and unfused"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,6 +40,17 @@ fn simd_ops_f32() {
|
|||
f32x4::splat(-3.2).mul_add(b, f32x4::splat(f32::NEG_INFINITY)),
|
||||
f32x4::splat(f32::NEG_INFINITY)
|
||||
);
|
||||
|
||||
unsafe {
|
||||
assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a);
|
||||
assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a);
|
||||
assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b);
|
||||
assert_eq!(
|
||||
intrinsics::simd_relaxed_fma(f32x4::splat(-3.2), b, f32x4::splat(f32::NEG_INFINITY)),
|
||||
f32x4::splat(f32::NEG_INFINITY)
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!((a * a).sqrt(), a);
|
||||
assert_eq!((b * b).sqrt(), b.abs());
|
||||
|
||||
|
|
@ -94,6 +105,17 @@ fn simd_ops_f64() {
|
|||
f64x4::splat(-3.2).mul_add(b, f64x4::splat(f64::NEG_INFINITY)),
|
||||
f64x4::splat(f64::NEG_INFINITY)
|
||||
);
|
||||
|
||||
unsafe {
|
||||
assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a);
|
||||
assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a);
|
||||
assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b);
|
||||
assert_eq!(
|
||||
intrinsics::simd_relaxed_fma(f64x4::splat(-3.2), b, f64x4::splat(f64::NEG_INFINITY)),
|
||||
f64x4::splat(f64::NEG_INFINITY)
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!((a * a).sqrt(), a);
|
||||
assert_eq!((b * b).sqrt(), b.abs());
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue