NaN non-determinism for SIMD intrinsics

This commit is contained in:
Ralf Jung 2023-12-28 12:14:01 +01:00
parent 0f98c0e610
commit ee42d1eb9f
2 changed files with 145 additions and 75 deletions

View file

@ -5,10 +5,17 @@ use rustc_span::{sym, Symbol};
use rustc_target::abi::{Endian, HasDataLayout};
use crate::helpers::{
bool_to_simd_element, check_arg_count, round_to_next_multiple_of, simd_element_to_bool,
bool_to_simd_element, check_arg_count, round_to_next_multiple_of, simd_element_to_bool, ToHost,
ToSoft,
};
use crate::*;
#[derive(Copy, Clone)]
pub(crate) enum MinMax {
Min,
Max,
}
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
/// Calls the simd intrinsic `intrinsic`; the `simd_` prefix has already been removed.
@ -67,13 +74,17 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let op = this.read_immediate(&this.project_index(&op, i)?)?;
let dest = this.project_index(&dest, i)?;
let val = match which {
Op::MirOp(mir_op) => this.wrapping_unary_op(mir_op, &op)?.to_scalar(),
Op::MirOp(mir_op) => {
// This already does NaN adjustments
this.wrapping_unary_op(mir_op, &op)?.to_scalar()
}
Op::Abs => {
// Works for f32 and f64.
let ty::Float(float_ty) = op.layout.ty.kind() else {
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
};
let op = op.to_scalar();
// "Bitwise" operation, no NaN adjustments
match float_ty {
FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
@ -86,14 +97,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
// FIXME using host floats
match float_ty {
FloatTy::F32 => {
let f = f32::from_bits(op.to_scalar().to_u32()?);
let res = f.sqrt();
Scalar::from_u32(res.to_bits())
let f = op.to_scalar().to_f32()?;
let res = f.to_host().sqrt().to_soft();
let res = this.adjust_nan(res, &[f]);
Scalar::from(res)
}
FloatTy::F64 => {
let f = f64::from_bits(op.to_scalar().to_u64()?);
let res = f.sqrt();
Scalar::from_u64(res.to_bits())
let f = op.to_scalar().to_f64()?;
let res = f.to_host().sqrt().to_soft();
let res = this.adjust_nan(res, &[f]);
Scalar::from(res)
}
}
}
@ -105,11 +118,13 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
FloatTy::F32 => {
let f = op.to_scalar().to_f32()?;
let res = f.round_to_integral(rounding).value;
let res = this.adjust_nan(res, &[f]);
Scalar::from_f32(res)
}
FloatTy::F64 => {
let f = op.to_scalar().to_f64()?;
let res = f.round_to_integral(rounding).value;
let res = this.adjust_nan(res, &[f]);
Scalar::from_f64(res)
}
}
@ -157,8 +172,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
enum Op {
MirOp(BinOp),
SaturatingOp(BinOp),
FMax,
FMin,
FMinMax(MinMax),
WrappingOffset,
}
let which = match intrinsic_name {
@ -178,8 +192,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
"le" => Op::MirOp(BinOp::Le),
"gt" => Op::MirOp(BinOp::Gt),
"ge" => Op::MirOp(BinOp::Ge),
"fmax" => Op::FMax,
"fmin" => Op::FMin,
"fmax" => Op::FMinMax(MinMax::Max),
"fmin" => Op::FMinMax(MinMax::Min),
"saturating_add" => Op::SaturatingOp(BinOp::Add),
"saturating_sub" => Op::SaturatingOp(BinOp::Sub),
"arith_offset" => Op::WrappingOffset,
@ -192,6 +206,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let dest = this.project_index(&dest, i)?;
let val = match which {
Op::MirOp(mir_op) => {
// This does NaN adjustments.
let (val, overflowed) = this.overflowing_binary_op(mir_op, &left, &right)?;
if matches!(mir_op, BinOp::Shl | BinOp::Shr) {
// Shifts have extra UB as SIMD operations that the MIR binop does not have.
@ -225,11 +240,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let offset_ptr = ptr.wrapping_signed_offset(offset_bytes, this);
Scalar::from_maybe_pointer(offset_ptr, this)
}
Op::FMax => {
fmax_op(&left, &right)?
}
Op::FMin => {
fmin_op(&left, &right)?
Op::FMinMax(op) => {
this.fminmax_op(op, &left, &right)?
}
};
this.write_scalar(val, &dest)?;
@ -259,18 +271,20 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
};
let val = match float_ty {
FloatTy::F32 => {
let a = f32::from_bits(a.to_u32()?);
let b = f32::from_bits(b.to_u32()?);
let c = f32::from_bits(c.to_u32()?);
let res = a.mul_add(b, c);
Scalar::from_u32(res.to_bits())
let a = a.to_f32()?;
let b = b.to_f32()?;
let c = c.to_f32()?;
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
let res = this.adjust_nan(res, &[a, b, c]);
Scalar::from(res)
}
FloatTy::F64 => {
let a = f64::from_bits(a.to_u64()?);
let b = f64::from_bits(b.to_u64()?);
let c = f64::from_bits(c.to_u64()?);
let res = a.mul_add(b, c);
Scalar::from_u64(res.to_bits())
let a = a.to_f64()?;
let b = b.to_f64()?;
let c = c.to_f64()?;
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
let res = this.adjust_nan(res, &[a, b, c]);
Scalar::from(res)
}
};
this.write_scalar(val, &dest)?;
@ -295,8 +309,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
enum Op {
MirOp(BinOp),
MirOpBool(BinOp),
Max,
Min,
MinMax(MinMax),
}
let which = match intrinsic_name {
"reduce_and" => Op::MirOp(BinOp::BitAnd),
@ -304,8 +317,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
"reduce_xor" => Op::MirOp(BinOp::BitXor),
"reduce_any" => Op::MirOpBool(BinOp::BitOr),
"reduce_all" => Op::MirOpBool(BinOp::BitAnd),
"reduce_max" => Op::Max,
"reduce_min" => Op::Min,
"reduce_max" => Op::MinMax(MinMax::Max),
"reduce_min" => Op::MinMax(MinMax::Min),
_ => unreachable!(),
};
@ -325,24 +338,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
let op = imm_from_bool(simd_element_to_bool(op)?);
this.wrapping_binary_op(mir_op, &res, &op)?
}
Op::Max => {
Op::MinMax(mmop) => {
if matches!(res.layout.ty.kind(), ty::Float(_)) {
ImmTy::from_scalar(fmax_op(&res, &op)?, res.layout)
ImmTy::from_scalar(this.fminmax_op(mmop, &res, &op)?, res.layout)
} else {
// Just boring integers, so NaNs to worry about
if this.wrapping_binary_op(BinOp::Ge, &res, &op)?.to_scalar().to_bool()? {
res
} else {
op
}
}
}
Op::Min => {
if matches!(res.layout.ty.kind(), ty::Float(_)) {
ImmTy::from_scalar(fmin_op(&res, &op)?, res.layout)
} else {
// Just boring integers, so NaNs to worry about
if this.wrapping_binary_op(BinOp::Le, &res, &op)?.to_scalar().to_bool()? {
let mirop = match mmop {
MinMax::Min => BinOp::Le,
MinMax::Max => BinOp::Ge,
};
if this.wrapping_binary_op(mirop, &res, &op)?.to_scalar().to_bool()? {
res
} else {
op
@ -709,6 +714,43 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}
Ok(())
}
fn fminmax_op(
&self,
op: MinMax,
left: &ImmTy<'tcx, Provenance>,
right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
let this = self.eval_context_ref();
assert_eq!(left.layout.ty, right.layout.ty);
let ty::Float(float_ty) = left.layout.ty.kind() else {
bug!("fmax operand is not a float")
};
let left = left.to_scalar();
let right = right.to_scalar();
Ok(match float_ty {
FloatTy::F32 => {
let left = left.to_f32()?;
let right = right.to_f32()?;
let res = match op {
MinMax::Min => left.min(right),
MinMax::Max => left.max(right),
};
let res = this.adjust_nan(res, &[left, right]);
Scalar::from_f32(res)
}
FloatTy::F64 => {
let left = left.to_f64()?;
let right = right.to_f64()?;
let res = match op {
MinMax::Min => left.min(right),
MinMax::Max => left.max(right),
};
let res = this.adjust_nan(res, &[left, right]);
Scalar::from_f64(res)
}
})
}
}
fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
@ -719,31 +761,3 @@ fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
Endian::Big => vec_len - 1 - idx, // reverse order of bits
}
}
fn fmax_op<'tcx>(
left: &ImmTy<'tcx, Provenance>,
right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
assert_eq!(left.layout.ty, right.layout.ty);
let ty::Float(float_ty) = left.layout.ty.kind() else { bug!("fmax operand is not a float") };
let left = left.to_scalar();
let right = right.to_scalar();
Ok(match float_ty {
FloatTy::F32 => Scalar::from_f32(left.to_f32()?.max(right.to_f32()?)),
FloatTy::F64 => Scalar::from_f64(left.to_f64()?.max(right.to_f64()?)),
})
}
fn fmin_op<'tcx>(
left: &ImmTy<'tcx, Provenance>,
right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
assert_eq!(left.layout.ty, right.layout.ty);
let ty::Float(float_ty) = left.layout.ty.kind() else { bug!("fmin operand is not a float") };
let left = left.to_scalar();
let right = right.to_scalar();
Ok(match float_ty {
FloatTy::F32 => Scalar::from_f32(left.to_f32()?.min(right.to_f32()?)),
FloatTy::F64 => Scalar::from_f64(left.to_f64()?.min(right.to_f64()?)),
})
}

View file

@ -1,4 +1,4 @@
#![feature(float_gamma)]
#![feature(float_gamma, portable_simd, core_intrinsics, platform_intrinsics)]
use std::collections::HashSet;
use std::fmt;
use std::hash::Hash;
@ -535,6 +535,61 @@ fn test_casts() {
);
}
fn test_simd() {
use std::intrinsics::simd::*;
use std::simd::*;
extern "platform-intrinsic" {
fn simd_fsqrt<T>(x: T) -> T;
fn simd_ceil<T>(x: T) -> T;
fn simd_fma<T>(x: T, y: T, z: T) -> T;
}
let nan = F32::nan(Neg, Quiet, 0).as_f32();
check_all_outcomes(
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
|| F32::from(unsafe { simd_div(f32x4::splat(0.0), f32x4::splat(0.0)) }[0]),
);
check_all_outcomes(
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
|| F32::from(unsafe { simd_fmin(f32x4::splat(nan), f32x4::splat(nan)) }[0]),
);
check_all_outcomes(
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
|| F32::from(unsafe { simd_fmax(f32x4::splat(nan), f32x4::splat(nan)) }[0]),
);
check_all_outcomes(
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
|| {
F32::from(
unsafe { simd_fma(f32x4::splat(nan), f32x4::splat(nan), f32x4::splat(nan)) }[0],
)
},
);
check_all_outcomes(
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
|| F32::from(unsafe { simd_reduce_add_ordered::<_, f32>(f32x4::splat(nan), nan) }),
);
check_all_outcomes(
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
|| F32::from(unsafe { simd_reduce_max::<_, f32>(f32x4::splat(nan)) }),
);
check_all_outcomes(
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
|| F32::from(unsafe { simd_fsqrt(f32x4::splat(nan)) }[0]),
);
check_all_outcomes(
HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
|| F32::from(unsafe { simd_ceil(f32x4::splat(nan)) }[0]),
);
// Casts
check_all_outcomes(
HashSet::from_iter([F64::nan(Pos, Quiet, 0), F64::nan(Neg, Quiet, 0)]),
|| F64::from(unsafe { simd_cast::<f32x4, f64x4>(f32x4::splat(nan)) }[0]),
);
}
fn main() {
// Check our constants against std, just to be sure.
// We add 1 since our numbers are the number of bits stored
@ -546,4 +601,5 @@ fn main() {
test_f32();
test_f64();
test_casts();
test_simd();
}