From 86198a15d7f67ad61a7988d881e9af68a0bbf361 Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Thu, 28 Dec 2023 08:41:51 +0100 Subject: [PATCH] make float intrinsics return non-deterministic NaN --- src/tools/miri/src/helpers.rs | 45 ++++++ src/tools/miri/src/operator.rs | 4 + src/tools/miri/src/shims/intrinsics/mod.rs | 166 +++++++++++---------- src/tools/miri/tests/pass/float_nan.rs | 98 ++++++++++++ 4 files changed, 238 insertions(+), 75 deletions(-) diff --git a/src/tools/miri/src/helpers.rs b/src/tools/miri/src/helpers.rs index d2fd51b099ad..98f646da6b69 100644 --- a/src/tools/miri/src/helpers.rs +++ b/src/tools/miri/src/helpers.rs @@ -6,6 +6,7 @@ use std::time::Duration; use log::trace; use rustc_apfloat::ieee::{Double, Single}; +use rustc_apfloat::Float; use rustc_hir::def::{DefKind, Namespace}; use rustc_hir::def_id::{DefId, CRATE_DEF_INDEX}; use rustc_index::IndexVec; @@ -117,6 +118,50 @@ fn try_resolve_did(tcx: TyCtxt<'_>, path: &[&str], namespace: Option) } } +/// Convert a softfloat type to its corresponding hostfloat type. +pub trait ToHost { + type HostFloat; + fn to_host(self) -> Self::HostFloat; +} + +/// Convert a hostfloat type to its corresponding softfloat type. +pub trait ToSoft { + type SoftFloat; + fn to_soft(self) -> Self::SoftFloat; +} + +impl ToHost for rustc_apfloat::ieee::Double { + type HostFloat = f64; + + fn to_host(self) -> Self::HostFloat { + f64::from_bits(self.to_bits().try_into().unwrap()) + } +} + +impl ToSoft for f64 { + type SoftFloat = rustc_apfloat::ieee::Double; + + fn to_soft(self) -> Self::SoftFloat { + Float::from_bits(self.to_bits().into()) + } +} + +impl ToHost for rustc_apfloat::ieee::Single { + type HostFloat = f32; + + fn to_host(self) -> Self::HostFloat { + f32::from_bits(self.to_bits().try_into().unwrap()) + } +} + +impl ToSoft for f32 { + type SoftFloat = rustc_apfloat::ieee::Single; + + fn to_soft(self) -> Self::SoftFloat { + Float::from_bits(self.to_bits().into()) + } +} + impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {} pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { /// Checks if the given crate/module exists. diff --git a/src/tools/miri/src/operator.rs b/src/tools/miri/src/operator.rs index e5a437f95f0e..140764446969 100644 --- a/src/tools/miri/src/operator.rs +++ b/src/tools/miri/src/operator.rs @@ -118,4 +118,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { nan } } + + fn adjust_nan, F2: Float>(&self, f: F2, inputs: &[F1]) -> F2 { + if f.is_nan() { self.generate_nan(inputs) } else { f } + } } diff --git a/src/tools/miri/src/shims/intrinsics/mod.rs b/src/tools/miri/src/shims/intrinsics/mod.rs index 625ae3ef39e5..cc81ef6e6c91 100644 --- a/src/tools/miri/src/shims/intrinsics/mod.rs +++ b/src/tools/miri/src/shims/intrinsics/mod.rs @@ -15,7 +15,7 @@ use rustc_target::abi::Size; use crate::*; use atomic::EvalContextExt as _; -use helpers::check_arg_count; +use helpers::{check_arg_count, ToHost, ToSoft}; use simd::EvalContextExt as _; impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {} @@ -146,12 +146,14 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let [f] = check_arg_count(args)?; let f = this.read_scalar(f)?.to_f32()?; // Can be implemented in soft-floats. + // This is a "bitwise" operation, so there's no NaN non-determinism. this.write_scalar(Scalar::from_f32(f.abs()), dest)?; } "fabsf64" => { let [f] = check_arg_count(args)?; let f = this.read_scalar(f)?.to_f64()?; // Can be implemented in soft-floats. + // This is a "bitwise" operation, so there's no NaN non-determinism. this.write_scalar(Scalar::from_f64(f.abs()), dest)?; } #[rustfmt::skip] @@ -170,25 +172,28 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { | "rintf32" => { let [f] = check_arg_count(args)?; + let f = this.read_scalar(f)?.to_f32()?; // FIXME: Using host floats. - let f = f32::from_bits(this.read_scalar(f)?.to_u32()?); - let f = match intrinsic_name { - "sinf32" => f.sin(), - "cosf32" => f.cos(), - "sqrtf32" => f.sqrt(), - "expf32" => f.exp(), - "exp2f32" => f.exp2(), - "logf32" => f.ln(), - "log10f32" => f.log10(), - "log2f32" => f.log2(), - "floorf32" => f.floor(), - "ceilf32" => f.ceil(), - "truncf32" => f.trunc(), - "roundf32" => f.round(), - "rintf32" => f.round_ties_even(), + let f_host = f.to_host(); + let res = match intrinsic_name { + "sinf32" => f_host.sin(), + "cosf32" => f_host.cos(), + "sqrtf32" => f_host.sqrt(), + "expf32" => f_host.exp(), + "exp2f32" => f_host.exp2(), + "logf32" => f_host.ln(), + "log10f32" => f_host.log10(), + "log2f32" => f_host.log2(), + "floorf32" => f_host.floor(), + "ceilf32" => f_host.ceil(), + "truncf32" => f_host.trunc(), + "roundf32" => f_host.round(), + "rintf32" => f_host.round_ties_even(), _ => bug!(), }; - this.write_scalar(Scalar::from_u32(f.to_bits()), dest)?; + let res = res.to_soft(); + let res = this.adjust_nan(res, &[f]); + this.write_scalar(res, dest)?; } #[rustfmt::skip] @@ -207,25 +212,28 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { | "rintf64" => { let [f] = check_arg_count(args)?; + let f = this.read_scalar(f)?.to_f64()?; // FIXME: Using host floats. - let f = f64::from_bits(this.read_scalar(f)?.to_u64()?); - let f = match intrinsic_name { - "sinf64" => f.sin(), - "cosf64" => f.cos(), - "sqrtf64" => f.sqrt(), - "expf64" => f.exp(), - "exp2f64" => f.exp2(), - "logf64" => f.ln(), - "log10f64" => f.log10(), - "log2f64" => f.log2(), - "floorf64" => f.floor(), - "ceilf64" => f.ceil(), - "truncf64" => f.trunc(), - "roundf64" => f.round(), - "rintf64" => f.round_ties_even(), + let f_host = f.to_host(); + let res = match intrinsic_name { + "sinf64" => f_host.sin(), + "cosf64" => f_host.cos(), + "sqrtf64" => f_host.sqrt(), + "expf64" => f_host.exp(), + "exp2f64" => f_host.exp2(), + "logf64" => f_host.ln(), + "log10f64" => f_host.log10(), + "log2f64" => f_host.log2(), + "floorf64" => f_host.floor(), + "ceilf64" => f_host.ceil(), + "truncf64" => f_host.trunc(), + "roundf64" => f_host.round(), + "rintf64" => f_host.round_ties_even(), _ => bug!(), }; - this.write_scalar(Scalar::from_u64(f.to_bits()), dest)?; + let res = res.to_soft(); + let res = this.adjust_nan(res, &[f]); + this.write_scalar(res, dest)?; } #[rustfmt::skip] @@ -272,6 +280,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { if !float_finite(&res)? { throw_ub_format!("`{intrinsic_name}` intrinsic produced non-finite value as result"); } + // This cannot be a NaN so we also don't have to apply any non-determinism. + // (Also, `wrapping_binary_op` already called `generate_nan` if needed.) this.write_immediate(*res, dest)?; } @@ -284,9 +294,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let a = this.read_scalar(a)?.to_f32()?; let b = this.read_scalar(b)?.to_f32()?; let res = match intrinsic_name { - "minnumf32" => a.min(b), - "maxnumf32" => a.max(b), - "copysignf32" => a.copy_sign(b), + "minnumf32" => this.adjust_nan(a.min(b), &[a, b]), + "maxnumf32" => this.adjust_nan(a.max(b), &[a, b]), + "copysignf32" => a.copy_sign(b), // bitwise, no NaN adjustments _ => bug!(), }; this.write_scalar(Scalar::from_f32(res), dest)?; @@ -301,68 +311,74 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let a = this.read_scalar(a)?.to_f64()?; let b = this.read_scalar(b)?.to_f64()?; let res = match intrinsic_name { - "minnumf64" => a.min(b), - "maxnumf64" => a.max(b), - "copysignf64" => a.copy_sign(b), + "minnumf64" => this.adjust_nan(a.min(b), &[a, b]), + "maxnumf64" => this.adjust_nan(a.max(b), &[a, b]), + "copysignf64" => a.copy_sign(b), // bitwise, no NaN adjustments _ => bug!(), }; this.write_scalar(Scalar::from_f64(res), dest)?; } - "powf32" => { - let [f, f2] = check_arg_count(args)?; - // FIXME: Using host floats. - let f = f32::from_bits(this.read_scalar(f)?.to_u32()?); - let f2 = f32::from_bits(this.read_scalar(f2)?.to_u32()?); - let res = f.powf(f2); - this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?; - } - - "powf64" => { - let [f, f2] = check_arg_count(args)?; - // FIXME: Using host floats. - let f = f64::from_bits(this.read_scalar(f)?.to_u64()?); - let f2 = f64::from_bits(this.read_scalar(f2)?.to_u64()?); - let res = f.powf(f2); - this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?; - } - "fmaf32" => { let [a, b, c] = check_arg_count(args)?; + let a = this.read_scalar(a)?.to_f32()?; + let b = this.read_scalar(b)?.to_f32()?; + let c = this.read_scalar(c)?.to_f32()?; // FIXME: Using host floats, to work around https://github.com/rust-lang/rustc_apfloat/issues/11 - let a = f32::from_bits(this.read_scalar(a)?.to_u32()?); - let b = f32::from_bits(this.read_scalar(b)?.to_u32()?); - let c = f32::from_bits(this.read_scalar(c)?.to_u32()?); - let res = a.mul_add(b, c); - this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?; + let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft(); + let res = this.adjust_nan(res, &[a, b, c]); + this.write_scalar(res, dest)?; } "fmaf64" => { let [a, b, c] = check_arg_count(args)?; + let a = this.read_scalar(a)?.to_f64()?; + let b = this.read_scalar(b)?.to_f64()?; + let c = this.read_scalar(c)?.to_f64()?; // FIXME: Using host floats, to work around https://github.com/rust-lang/rustc_apfloat/issues/11 - let a = f64::from_bits(this.read_scalar(a)?.to_u64()?); - let b = f64::from_bits(this.read_scalar(b)?.to_u64()?); - let c = f64::from_bits(this.read_scalar(c)?.to_u64()?); - let res = a.mul_add(b, c); - this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?; + let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft(); + let res = this.adjust_nan(res, &[a, b, c]); + this.write_scalar(res, dest)?; + } + + "powf32" => { + let [f1, f2] = check_arg_count(args)?; + let f1 = this.read_scalar(f1)?.to_f32()?; + let f2 = this.read_scalar(f2)?.to_f32()?; + // FIXME: Using host floats. + let res = f1.to_host().powf(f2.to_host()).to_soft(); + let res = this.adjust_nan(res, &[f1, f2]); + this.write_scalar(res, dest)?; + } + + "powf64" => { + let [f1, f2] = check_arg_count(args)?; + let f1 = this.read_scalar(f1)?.to_f64()?; + let f2 = this.read_scalar(f2)?.to_f64()?; + // FIXME: Using host floats. + let res = f1.to_host().powf(f2.to_host()).to_soft(); + let res = this.adjust_nan(res, &[f1, f2]); + this.write_scalar(res, dest)?; } "powif32" => { let [f, i] = check_arg_count(args)?; - // FIXME: Using host floats. - let f = f32::from_bits(this.read_scalar(f)?.to_u32()?); + let f = this.read_scalar(f)?.to_f32()?; let i = this.read_scalar(i)?.to_i32()?; - let res = f.powi(i); - this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?; + // FIXME: Using host floats. + let res = f.to_host().powi(i).to_soft(); + let res = this.adjust_nan(res, &[f]); + this.write_scalar(res, dest)?; } "powif64" => { let [f, i] = check_arg_count(args)?; - // FIXME: Using host floats. - let f = f64::from_bits(this.read_scalar(f)?.to_u64()?); + let f = this.read_scalar(f)?.to_f64()?; let i = this.read_scalar(i)?.to_i32()?; - let res = f.powi(i); - this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?; + // FIXME: Using host floats. + let res = f.to_host().powi(i).to_soft(); + let res = this.adjust_nan(res, &[f]); + this.write_scalar(res, dest)?; } "float_to_int_unchecked" => { diff --git a/src/tools/miri/tests/pass/float_nan.rs b/src/tools/miri/tests/pass/float_nan.rs index 6ea034e2cda9..99151e5df7c0 100644 --- a/src/tools/miri/tests/pass/float_nan.rs +++ b/src/tools/miri/tests/pass/float_nan.rs @@ -249,6 +249,55 @@ fn test_f32() { check_all_outcomes(HashSet::from_iter([F32::nan(Neg, Signaling, all1_payload)]), || { F32::from(-all1_snan) }); + + // Intrinsics + let nan = F32::nan(Neg, Quiet, 0).as_f32(); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(f32::min(nan, nan)), + ); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(nan.sin()), + ); + check_all_outcomes( + HashSet::from_iter([ + F32::nan(Pos, Quiet, 0), + F32::nan(Neg, Quiet, 0), + F32::nan(Pos, Quiet, 1), + F32::nan(Neg, Quiet, 1), + F32::nan(Pos, Quiet, 2), + F32::nan(Neg, Quiet, 2), + F32::nan(Pos, Quiet, all1_payload), + F32::nan(Neg, Quiet, all1_payload), + F32::nan(Pos, Signaling, all1_payload), + F32::nan(Neg, Signaling, all1_payload), + ]), + || F32::from(just1.mul_add(F32::nan(Neg, Quiet, 2).as_f32(), all1_snan)), + ); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(nan.powf(nan)), + ); + check_all_outcomes( + HashSet::from_iter([1.0f32.into()]), + || F32::from(1.0f32.powf(nan)), // special `pow` rule + ); + check_all_outcomes( + HashSet::from_iter([ + F32::nan(Pos, Quiet, 0), + F32::nan(Neg, Quiet, 0), + F32::nan(Pos, Quiet, 1), + F32::nan(Neg, Quiet, 1), + F32::nan(Pos, Signaling, 1), + F32::nan(Neg, Signaling, 1), + ]), + || F32::from(1.0f32.powf(F32::nan(Pos, Signaling, 1).as_f32())), // unspecified `pow` case + ); + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(nan.powi(1)), + ); } fn test_f64() { @@ -309,6 +358,55 @@ fn test_f64() { ]), || F64::from(just1 % all1_snan), ); + + // Intrinsics + let nan = F64::nan(Neg, Quiet, 0).as_f64(); + check_all_outcomes( + HashSet::from_iter([F64::nan(Pos, Quiet, 0), F64::nan(Neg, Quiet, 0)]), + || F64::from(f64::min(nan, nan)), + ); + check_all_outcomes( + HashSet::from_iter([F64::nan(Pos, Quiet, 0), F64::nan(Neg, Quiet, 0)]), + || F64::from(nan.sin()), + ); + check_all_outcomes( + HashSet::from_iter([ + F64::nan(Pos, Quiet, 0), + F64::nan(Neg, Quiet, 0), + F64::nan(Pos, Quiet, 1), + F64::nan(Neg, Quiet, 1), + F64::nan(Pos, Quiet, 2), + F64::nan(Neg, Quiet, 2), + F64::nan(Pos, Quiet, all1_payload), + F64::nan(Neg, Quiet, all1_payload), + F64::nan(Pos, Signaling, all1_payload), + F64::nan(Neg, Signaling, all1_payload), + ]), + || F64::from(just1.mul_add(F64::nan(Neg, Quiet, 2).as_f64(), all1_snan)), + ); + check_all_outcomes( + HashSet::from_iter([F64::nan(Pos, Quiet, 0), F64::nan(Neg, Quiet, 0)]), + || F64::from(nan.powf(nan)), + ); + check_all_outcomes( + HashSet::from_iter([1.0f64.into()]), + || F64::from(1.0f64.powf(nan)), // special `pow` rule + ); + check_all_outcomes( + HashSet::from_iter([ + F64::nan(Pos, Quiet, 0), + F64::nan(Neg, Quiet, 0), + F64::nan(Pos, Quiet, 1), + F64::nan(Neg, Quiet, 1), + F64::nan(Pos, Signaling, 1), + F64::nan(Neg, Signaling, 1), + ]), + || F64::from(1.0f64.powf(F64::nan(Pos, Signaling, 1).as_f64())), // unspecified `pow` case + ); + check_all_outcomes( + HashSet::from_iter([F64::nan(Pos, Quiet, 0), F64::nan(Neg, Quiet, 0)]), + || F64::from(nan.powi(1)), + ); } fn test_casts() {