From ee42d1eb9f5983f8f281e87989617caba45007b6 Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Thu, 28 Dec 2023 12:14:01 +0100 Subject: [PATCH] NaN non-determinism for SIMD intrinsics --- src/tools/miri/src/shims/intrinsics/simd.rs | 162 +++++++++++--------- src/tools/miri/tests/pass/float_nan.rs | 58 ++++++- 2 files changed, 145 insertions(+), 75 deletions(-) diff --git a/src/tools/miri/src/shims/intrinsics/simd.rs b/src/tools/miri/src/shims/intrinsics/simd.rs index 2c8493d8aad1..d749182ed5e7 100644 --- a/src/tools/miri/src/shims/intrinsics/simd.rs +++ b/src/tools/miri/src/shims/intrinsics/simd.rs @@ -5,10 +5,17 @@ use rustc_span::{sym, Symbol}; use rustc_target::abi::{Endian, HasDataLayout}; use crate::helpers::{ - bool_to_simd_element, check_arg_count, round_to_next_multiple_of, simd_element_to_bool, + bool_to_simd_element, check_arg_count, round_to_next_multiple_of, simd_element_to_bool, ToHost, + ToSoft, }; use crate::*; +#[derive(Copy, Clone)] +pub(crate) enum MinMax { + Min, + Max, +} + impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {} pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { /// Calls the simd intrinsic `intrinsic`; the `simd_` prefix has already been removed. @@ -67,13 +74,17 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let op = this.read_immediate(&this.project_index(&op, i)?)?; let dest = this.project_index(&dest, i)?; let val = match which { - Op::MirOp(mir_op) => this.wrapping_unary_op(mir_op, &op)?.to_scalar(), + Op::MirOp(mir_op) => { + // This already does NaN adjustments + this.wrapping_unary_op(mir_op, &op)?.to_scalar() + } Op::Abs => { // Works for f32 and f64. let ty::Float(float_ty) = op.layout.ty.kind() else { span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name) }; let op = op.to_scalar(); + // "Bitwise" operation, no NaN adjustments match float_ty { FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()), FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()), @@ -86,14 +97,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { // FIXME using host floats match float_ty { FloatTy::F32 => { - let f = f32::from_bits(op.to_scalar().to_u32()?); - let res = f.sqrt(); - Scalar::from_u32(res.to_bits()) + let f = op.to_scalar().to_f32()?; + let res = f.to_host().sqrt().to_soft(); + let res = this.adjust_nan(res, &[f]); + Scalar::from(res) } FloatTy::F64 => { - let f = f64::from_bits(op.to_scalar().to_u64()?); - let res = f.sqrt(); - Scalar::from_u64(res.to_bits()) + let f = op.to_scalar().to_f64()?; + let res = f.to_host().sqrt().to_soft(); + let res = this.adjust_nan(res, &[f]); + Scalar::from(res) } } } @@ -105,11 +118,13 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { FloatTy::F32 => { let f = op.to_scalar().to_f32()?; let res = f.round_to_integral(rounding).value; + let res = this.adjust_nan(res, &[f]); Scalar::from_f32(res) } FloatTy::F64 => { let f = op.to_scalar().to_f64()?; let res = f.round_to_integral(rounding).value; + let res = this.adjust_nan(res, &[f]); Scalar::from_f64(res) } } @@ -157,8 +172,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { enum Op { MirOp(BinOp), SaturatingOp(BinOp), - FMax, - FMin, + FMinMax(MinMax), WrappingOffset, } let which = match intrinsic_name { @@ -178,8 +192,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { "le" => Op::MirOp(BinOp::Le), "gt" => Op::MirOp(BinOp::Gt), "ge" => Op::MirOp(BinOp::Ge), - "fmax" => Op::FMax, - "fmin" => Op::FMin, + "fmax" => Op::FMinMax(MinMax::Max), + "fmin" => Op::FMinMax(MinMax::Min), "saturating_add" => Op::SaturatingOp(BinOp::Add), "saturating_sub" => Op::SaturatingOp(BinOp::Sub), "arith_offset" => Op::WrappingOffset, @@ -192,6 +206,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let dest = this.project_index(&dest, i)?; let val = match which { Op::MirOp(mir_op) => { + // This does NaN adjustments. let (val, overflowed) = this.overflowing_binary_op(mir_op, &left, &right)?; if matches!(mir_op, BinOp::Shl | BinOp::Shr) { // Shifts have extra UB as SIMD operations that the MIR binop does not have. @@ -225,11 +240,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let offset_ptr = ptr.wrapping_signed_offset(offset_bytes, this); Scalar::from_maybe_pointer(offset_ptr, this) } - Op::FMax => { - fmax_op(&left, &right)? - } - Op::FMin => { - fmin_op(&left, &right)? + Op::FMinMax(op) => { + this.fminmax_op(op, &left, &right)? } }; this.write_scalar(val, &dest)?; @@ -259,18 +271,20 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { }; let val = match float_ty { FloatTy::F32 => { - let a = f32::from_bits(a.to_u32()?); - let b = f32::from_bits(b.to_u32()?); - let c = f32::from_bits(c.to_u32()?); - let res = a.mul_add(b, c); - Scalar::from_u32(res.to_bits()) + let a = a.to_f32()?; + let b = b.to_f32()?; + let c = c.to_f32()?; + let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft(); + let res = this.adjust_nan(res, &[a, b, c]); + Scalar::from(res) } FloatTy::F64 => { - let a = f64::from_bits(a.to_u64()?); - let b = f64::from_bits(b.to_u64()?); - let c = f64::from_bits(c.to_u64()?); - let res = a.mul_add(b, c); - Scalar::from_u64(res.to_bits()) + let a = a.to_f64()?; + let b = b.to_f64()?; + let c = c.to_f64()?; + let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft(); + let res = this.adjust_nan(res, &[a, b, c]); + Scalar::from(res) } }; this.write_scalar(val, &dest)?; @@ -295,8 +309,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { enum Op { MirOp(BinOp), MirOpBool(BinOp), - Max, - Min, + MinMax(MinMax), } let which = match intrinsic_name { "reduce_and" => Op::MirOp(BinOp::BitAnd), @@ -304,8 +317,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { "reduce_xor" => Op::MirOp(BinOp::BitXor), "reduce_any" => Op::MirOpBool(BinOp::BitOr), "reduce_all" => Op::MirOpBool(BinOp::BitAnd), - "reduce_max" => Op::Max, - "reduce_min" => Op::Min, + "reduce_max" => Op::MinMax(MinMax::Max), + "reduce_min" => Op::MinMax(MinMax::Min), _ => unreachable!(), }; @@ -325,24 +338,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let op = imm_from_bool(simd_element_to_bool(op)?); this.wrapping_binary_op(mir_op, &res, &op)? } - Op::Max => { + Op::MinMax(mmop) => { if matches!(res.layout.ty.kind(), ty::Float(_)) { - ImmTy::from_scalar(fmax_op(&res, &op)?, res.layout) + ImmTy::from_scalar(this.fminmax_op(mmop, &res, &op)?, res.layout) } else { // Just boring integers, so NaNs to worry about - if this.wrapping_binary_op(BinOp::Ge, &res, &op)?.to_scalar().to_bool()? { - res - } else { - op - } - } - } - Op::Min => { - if matches!(res.layout.ty.kind(), ty::Float(_)) { - ImmTy::from_scalar(fmin_op(&res, &op)?, res.layout) - } else { - // Just boring integers, so NaNs to worry about - if this.wrapping_binary_op(BinOp::Le, &res, &op)?.to_scalar().to_bool()? { + let mirop = match mmop { + MinMax::Min => BinOp::Le, + MinMax::Max => BinOp::Ge, + }; + if this.wrapping_binary_op(mirop, &res, &op)?.to_scalar().to_bool()? { res } else { op @@ -709,6 +714,43 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { } Ok(()) } + + fn fminmax_op( + &self, + op: MinMax, + left: &ImmTy<'tcx, Provenance>, + right: &ImmTy<'tcx, Provenance>, + ) -> InterpResult<'tcx, Scalar> { + let this = self.eval_context_ref(); + assert_eq!(left.layout.ty, right.layout.ty); + let ty::Float(float_ty) = left.layout.ty.kind() else { + bug!("fmax operand is not a float") + }; + let left = left.to_scalar(); + let right = right.to_scalar(); + Ok(match float_ty { + FloatTy::F32 => { + let left = left.to_f32()?; + let right = right.to_f32()?; + let res = match op { + MinMax::Min => left.min(right), + MinMax::Max => left.max(right), + }; + let res = this.adjust_nan(res, &[left, right]); + Scalar::from_f32(res) + } + FloatTy::F64 => { + let left = left.to_f64()?; + let right = right.to_f64()?; + let res = match op { + MinMax::Min => left.min(right), + MinMax::Max => left.max(right), + }; + let res = this.adjust_nan(res, &[left, right]); + Scalar::from_f64(res) + } + }) + } } fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 { @@ -719,31 +761,3 @@ fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 { Endian::Big => vec_len - 1 - idx, // reverse order of bits } } - -fn fmax_op<'tcx>( - left: &ImmTy<'tcx, Provenance>, - right: &ImmTy<'tcx, Provenance>, -) -> InterpResult<'tcx, Scalar> { - assert_eq!(left.layout.ty, right.layout.ty); - let ty::Float(float_ty) = left.layout.ty.kind() else { bug!("fmax operand is not a float") }; - let left = left.to_scalar(); - let right = right.to_scalar(); - Ok(match float_ty { - FloatTy::F32 => Scalar::from_f32(left.to_f32()?.max(right.to_f32()?)), - FloatTy::F64 => Scalar::from_f64(left.to_f64()?.max(right.to_f64()?)), - }) -} - -fn fmin_op<'tcx>( - left: &ImmTy<'tcx, Provenance>, - right: &ImmTy<'tcx, Provenance>, -) -> InterpResult<'tcx, Scalar> { - assert_eq!(left.layout.ty, right.layout.ty); - let ty::Float(float_ty) = left.layout.ty.kind() else { bug!("fmin operand is not a float") }; - let left = left.to_scalar(); - let right = right.to_scalar(); - Ok(match float_ty { - FloatTy::F32 => Scalar::from_f32(left.to_f32()?.min(right.to_f32()?)), - FloatTy::F64 => Scalar::from_f64(left.to_f64()?.min(right.to_f64()?)), - }) -} diff --git a/src/tools/miri/tests/pass/float_nan.rs b/src/tools/miri/tests/pass/float_nan.rs index 5e717bdca007..fff103a776f4 100644 --- a/src/tools/miri/tests/pass/float_nan.rs +++ b/src/tools/miri/tests/pass/float_nan.rs @@ -1,4 +1,4 @@ -#![feature(float_gamma)] +#![feature(float_gamma, portable_simd, core_intrinsics, platform_intrinsics)] use std::collections::HashSet; use std::fmt; use std::hash::Hash; @@ -535,6 +535,61 @@ fn test_casts() { ); } +fn test_simd() { + use std::intrinsics::simd::*; + use std::simd::*; + + extern "platform-intrinsic" { + fn simd_fsqrt(x: T) -> T; + fn simd_ceil(x: T) -> T; + fn simd_fma(x: T, y: T, z: T) -> T; + } + + let nan = F32::nan(Neg, Quiet, 0).as_f32(); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(unsafe { simd_div(f32x4::splat(0.0), f32x4::splat(0.0)) }[0]), + ); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(unsafe { simd_fmin(f32x4::splat(nan), f32x4::splat(nan)) }[0]), + ); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(unsafe { simd_fmax(f32x4::splat(nan), f32x4::splat(nan)) }[0]), + ); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || { + F32::from( + unsafe { simd_fma(f32x4::splat(nan), f32x4::splat(nan), f32x4::splat(nan)) }[0], + ) + }, + ); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(unsafe { simd_reduce_add_ordered::<_, f32>(f32x4::splat(nan), nan) }), + ); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(unsafe { simd_reduce_max::<_, f32>(f32x4::splat(nan)) }), + ); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(unsafe { simd_fsqrt(f32x4::splat(nan)) }[0]), + ); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(unsafe { simd_ceil(f32x4::splat(nan)) }[0]), + ); + + // Casts + check_all_outcomes( + HashSet::from_iter([F64::nan(Pos, Quiet, 0), F64::nan(Neg, Quiet, 0)]), + || F64::from(unsafe { simd_cast::(f32x4::splat(nan)) }[0]), + ); +} + fn main() { // Check our constants against std, just to be sure. // We add 1 since our numbers are the number of bits stored @@ -546,4 +601,5 @@ fn main() { test_f32(); test_f64(); test_casts(); + test_simd(); }