Move round_* functions from shims::x86::sse41 module to shims::x86

This commit is contained in:
Eduardo Sánchez Muñoz 2023-12-07 12:37:48 +01:00
parent c37f4d6a1d
commit 8c5882ec45
2 changed files with 84 additions and 84 deletions

View file

@ -455,6 +455,89 @@ fn unary_op_ps<'tcx>(
Ok(())
}
// Rounds the first element of `right` according to `rounding`
// and copies the remaining elements from `left`.
fn round_first<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
rounding: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;
let op0: F = this.read_scalar(&this.project_index(&right, 0)?)?.to_float()?;
let res = op0.round_to_integral(rounding).value;
this.write_scalar(
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
&this.project_index(&dest, 0)?,
)?;
for i in 1..dest_len {
this.copy_op(
&this.project_index(&left, i)?,
&this.project_index(&dest, i)?,
/*allow_transmute*/ false,
)?;
}
Ok(())
}
// Rounds all elements of `op` according to `rounding`.
fn round_all<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
op: &OpTy<'tcx, Provenance>,
rounding: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, op_len);
let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;
for i in 0..dest_len {
let op: F = this.read_scalar(&this.project_index(&op, i)?)?.to_float()?;
let res = op.round_to_integral(rounding).value;
this.write_scalar(
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
&this.project_index(&dest, i)?,
)?;
}
Ok(())
}
/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
/// `round.{ss,sd,ps,pd}` intrinsics.
fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
// The fourth bit of `rounding` only affects the SSE status
// register, which cannot be accessed from Miri (or from Rust,
// for that matter), so we can ignore it.
match rounding & !0b1000 {
// When the third bit is 0, the rounding mode is determined by the
// first two bits.
0b000 => Ok(rustc_apfloat::Round::NearestTiesToEven),
0b001 => Ok(rustc_apfloat::Round::TowardNegative),
0b010 => Ok(rustc_apfloat::Round::TowardPositive),
0b011 => Ok(rustc_apfloat::Round::TowardZero),
// When the third bit is 1, the rounding mode is determined by the
// SSE status register. Since we do not support modifying it from
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
0b100..=0b111 => Ok(rustc_apfloat::Round::NearestTiesToEven),
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
}
}
/// Converts each element of `op` from floating point to signed integer.
///
/// When the input value is NaN or out of range, fall back to minimum value.

View file

@ -1,8 +1,8 @@
use rustc_middle::mir;
use rustc_span::Symbol;
use rustc_target::abi::Size;
use rustc_target::spec::abi::Abi;
use super::{round_all, round_first};
use crate::*;
use shims::foreign_items::EmulateForeignItemResult;
@ -283,86 +283,3 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
Ok(EmulateForeignItemResult::NeedsJumping)
}
}
// Rounds the first element of `right` according to `rounding`
// and copies the remaining elements from `left`.
fn round_first<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
rounding: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;
let op0: F = this.read_scalar(&this.project_index(&right, 0)?)?.to_float()?;
let res = op0.round_to_integral(rounding).value;
this.write_scalar(
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
&this.project_index(&dest, 0)?,
)?;
for i in 1..dest_len {
this.copy_op(
&this.project_index(&left, i)?,
&this.project_index(&dest, i)?,
/*allow_transmute*/ false,
)?;
}
Ok(())
}
// Rounds all elements of `op` according to `rounding`.
fn round_all<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
op: &OpTy<'tcx, Provenance>,
rounding: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, op_len);
let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;
for i in 0..dest_len {
let op: F = this.read_scalar(&this.project_index(&op, i)?)?.to_float()?;
let res = op.round_to_integral(rounding).value;
this.write_scalar(
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
&this.project_index(&dest, i)?,
)?;
}
Ok(())
}
/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
/// `round.{ss,sd,ps,pd}` intrinsics.
fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
// The fourth bit of `rounding` only affects the SSE status
// register, which cannot be accessed from Miri (or from Rust,
// for that matter), so we can ignore it.
match rounding & !0b1000 {
// When the third bit is 0, the rounding mode is determined by the
// first two bits.
0b000 => Ok(rustc_apfloat::Round::NearestTiesToEven),
0b001 => Ok(rustc_apfloat::Round::TowardNegative),
0b010 => Ok(rustc_apfloat::Round::TowardPositive),
0b011 => Ok(rustc_apfloat::Round::TowardZero),
// When the third bit is 1, the rounding mode is determined by the
// SSE status register. Since we do not support modifying it from
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
0b100..=0b111 => Ok(rustc_apfloat::Round::NearestTiesToEven),
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
}
}