From 91bd957a21bc936f2d8732c658c0f4a1737361b3 Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Wed, 4 Dec 2024 08:35:31 +0100 Subject: [PATCH] implement simd_relaxed_fma --- src/tools/miri/src/intrinsics/simd.rs | 17 +++- .../intrinsics/fmuladd_nondeterministic.rs | 91 +++++++++++++------ .../tests/pass/intrinsics/portable-simd.rs | 22 +++++ 3 files changed, 97 insertions(+), 33 deletions(-) diff --git a/src/tools/miri/src/intrinsics/simd.rs b/src/tools/miri/src/intrinsics/simd.rs index 075b6f35e0ee..54bdd3f02c2c 100644 --- a/src/tools/miri/src/intrinsics/simd.rs +++ b/src/tools/miri/src/intrinsics/simd.rs @@ -1,4 +1,5 @@ use either::Either; +use rand::Rng; use rustc_abi::{Endian, HasDataLayout}; use rustc_apfloat::{Float, Round}; use rustc_middle::ty::FloatTy; @@ -286,7 +287,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { this.write_scalar(val, &dest)?; } } - "fma" => { + "fma" | "relaxed_fma" => { let [a, b, c] = check_arg_count(args)?; let (a, a_len) = this.project_to_simd(a)?; let (b, b_len) = this.project_to_simd(b)?; @@ -303,6 +304,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let c = this.read_scalar(&this.project_index(&c, i)?)?; let dest = this.project_index(&dest, i)?; + let fuse: bool = intrinsic_name == "fma" || this.machine.rng.get_mut().gen(); + // Works for f32 and f64. // FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468. let ty::Float(float_ty) = dest.layout.ty.kind() else { @@ -314,7 +317,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let a = a.to_f32()?; let b = b.to_f32()?; let c = c.to_f32()?; - let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft(); + let res = if fuse { + a.to_host().mul_add(b.to_host(), c.to_host()).to_soft() + } else { + ((a * b).value + c).value + }; let res = this.adjust_nan(res, &[a, b, c]); Scalar::from(res) } @@ -322,7 +329,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let a = a.to_f64()?; let b = b.to_f64()?; let c = c.to_f64()?; - let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft(); + let res = if fuse { + a.to_host().mul_add(b.to_host(), c.to_host()).to_soft() + } else { + ((a * b).value + c).value + }; let res = this.adjust_nan(res, &[a, b, c]); Scalar::from(res) } diff --git a/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs b/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs index b46cf1ddf65d..b688405c4b18 100644 --- a/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs +++ b/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs @@ -1,44 +1,75 @@ -#![feature(core_intrinsics)] +#![feature(core_intrinsics, portable_simd)] +use std::intrinsics::simd::simd_relaxed_fma; use std::intrinsics::{fmuladdf32, fmuladdf64}; +use std::simd::prelude::*; -fn main() { - let mut saw_zero = false; - let mut saw_nonzero = false; +fn ensure_both_happen(f: impl Fn() -> bool) -> bool { + let mut saw_true = false; + let mut saw_false = false; for _ in 0..50 { - let a = std::hint::black_box(0.1_f64); - let b = std::hint::black_box(0.2); - let c = std::hint::black_box(-a * b); - // It is unspecified whether the following operation is fused or not. The - // following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused. - let x = unsafe { fmuladdf64(a, b, c) }; - if x == 0.0 { - saw_zero = true; + let b = f(); + if b { + saw_true = true; } else { - saw_nonzero = true; + saw_false = true; + } + if saw_true && saw_false { + return true; } } + false +} + +fn main() { assert!( - saw_zero && saw_nonzero, + ensure_both_happen(|| { + let a = std::hint::black_box(0.1_f64); + let b = std::hint::black_box(0.2); + let c = std::hint::black_box(-a * b); + // It is unspecified whether the following operation is fused or not. The + // following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused. + let x = unsafe { fmuladdf64(a, b, c) }; + x == 0.0 + }), "`fmuladdf64` failed to be evaluated as both fused and unfused" ); - let mut saw_zero = false; - let mut saw_nonzero = false; - for _ in 0..50 { - let a = std::hint::black_box(0.1_f32); - let b = std::hint::black_box(0.2); - let c = std::hint::black_box(-a * b); - // It is unspecified whether the following operation is fused or not. The - // following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused. - let x = unsafe { fmuladdf32(a, b, c) }; - if x == 0.0 { - saw_zero = true; - } else { - saw_nonzero = true; - } - } assert!( - saw_zero && saw_nonzero, + ensure_both_happen(|| { + let a = std::hint::black_box(0.1_f32); + let b = std::hint::black_box(0.2); + let c = std::hint::black_box(-a * b); + // It is unspecified whether the following operation is fused or not. The + // following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused. + let x = unsafe { fmuladdf32(a, b, c) }; + x == 0.0 + }), "`fmuladdf32` failed to be evaluated as both fused and unfused" ); + + assert!( + ensure_both_happen(|| { + let a = f32x4::splat(std::hint::black_box(0.1)); + let b = f32x4::splat(std::hint::black_box(0.2)); + let c = std::hint::black_box(-a * b); + let x = unsafe { simd_relaxed_fma(a, b, c) }; + // Whether we fuse or not is a per-element decision, so sometimes these should be + // the same and sometimes not. + x[0] == x[1] + }), + "`simd_relaxed_fma` failed to be evaluated as both fused and unfused" + ); + + assert!( + ensure_both_happen(|| { + let a = f64x4::splat(std::hint::black_box(0.1)); + let b = f64x4::splat(std::hint::black_box(0.2)); + let c = std::hint::black_box(-a * b); + let x = unsafe { simd_relaxed_fma(a, b, c) }; + // Whether we fuse or not is a per-element decision, so sometimes these should be + // the same and sometimes not. + x[0] == x[1] + }), + "`simd_relaxed_fma` failed to be evaluated as both fused and unfused" + ); } diff --git a/src/tools/miri/tests/pass/intrinsics/portable-simd.rs b/src/tools/miri/tests/pass/intrinsics/portable-simd.rs index f560669dd635..acd3502f5289 100644 --- a/src/tools/miri/tests/pass/intrinsics/portable-simd.rs +++ b/src/tools/miri/tests/pass/intrinsics/portable-simd.rs @@ -40,6 +40,17 @@ fn simd_ops_f32() { f32x4::splat(-3.2).mul_add(b, f32x4::splat(f32::NEG_INFINITY)), f32x4::splat(f32::NEG_INFINITY) ); + + unsafe { + assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a); + assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a); + assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b); + assert_eq!( + intrinsics::simd_relaxed_fma(f32x4::splat(-3.2), b, f32x4::splat(f32::NEG_INFINITY)), + f32x4::splat(f32::NEG_INFINITY) + ); + } + assert_eq!((a * a).sqrt(), a); assert_eq!((b * b).sqrt(), b.abs()); @@ -94,6 +105,17 @@ fn simd_ops_f64() { f64x4::splat(-3.2).mul_add(b, f64x4::splat(f64::NEG_INFINITY)), f64x4::splat(f64::NEG_INFINITY) ); + + unsafe { + assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a); + assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a); + assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b); + assert_eq!( + intrinsics::simd_relaxed_fma(f64x4::splat(-3.2), b, f64x4::splat(f64::NEG_INFINITY)), + f64x4::splat(f64::NEG_INFINITY) + ); + } + assert_eq!((a * a).sqrt(), a); assert_eq!((b * b).sqrt(), b.abs());