Implement all 16 AVX compare operators
`_mm_cmp_{ss,ps,sd,pd}` functions are AVX functions that use `llvm.x86.sse{,2}` prefixed intrinsics, so they were "accidentally" partially implemented when SSE and SSE2 intrinsics were implemented.
The 16 AVX compare operators are now implemented and tested.
This commit is contained in:
parent
cee4c575f2
commit
81303e7ea5
4 changed files with 257 additions and 67 deletions
|
|
@ -119,53 +119,32 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
|
|||
}
|
||||
}
|
||||
|
||||
/// Floating point comparison operation
|
||||
///
|
||||
/// <https://www.felixcloutier.com/x86/cmpss>
|
||||
/// <https://www.felixcloutier.com/x86/cmpps>
|
||||
/// <https://www.felixcloutier.com/x86/cmpsd>
|
||||
/// <https://www.felixcloutier.com/x86/cmppd>
|
||||
#[derive(Copy, Clone)]
|
||||
enum FloatCmpOp {
|
||||
Eq,
|
||||
Lt,
|
||||
Le,
|
||||
Unord,
|
||||
Neq,
|
||||
/// Not less-than
|
||||
Nlt,
|
||||
/// Not less-or-equal
|
||||
Nle,
|
||||
/// Ordered, i.e. neither of them is NaN
|
||||
Ord,
|
||||
}
|
||||
|
||||
impl FloatCmpOp {
|
||||
/// Convert from the `imm` argument used to specify the comparison
|
||||
/// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
|
||||
fn from_intrinsic_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> {
|
||||
match imm {
|
||||
0 => Ok(Self::Eq),
|
||||
1 => Ok(Self::Lt),
|
||||
2 => Ok(Self::Le),
|
||||
3 => Ok(Self::Unord),
|
||||
4 => Ok(Self::Neq),
|
||||
5 => Ok(Self::Nlt),
|
||||
6 => Ok(Self::Nle),
|
||||
7 => Ok(Self::Ord),
|
||||
imm => {
|
||||
throw_unsup_format!("invalid `imm` parameter of {intrinsic}: {imm}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
enum FloatBinOp {
|
||||
/// Arithmetic operation
|
||||
Arith(mir::BinOp),
|
||||
/// Comparison
|
||||
Cmp(FloatCmpOp),
|
||||
///
|
||||
/// The semantics of this operator is a case distinction: we compare the two operands,
|
||||
/// and then we return one of the four booleans `gt`, `lt`, `eq`, `unord` depending on
|
||||
/// which class they fall into.
|
||||
///
|
||||
/// AVX supports all 16 combinations, SSE only a subset
|
||||
///
|
||||
/// <https://www.felixcloutier.com/x86/cmpss>
|
||||
/// <https://www.felixcloutier.com/x86/cmpps>
|
||||
/// <https://www.felixcloutier.com/x86/cmpsd>
|
||||
/// <https://www.felixcloutier.com/x86/cmppd>
|
||||
Cmp {
|
||||
/// Result when lhs > rhs
|
||||
gt: bool,
|
||||
/// Result when lhs < rhs
|
||||
lt: bool,
|
||||
/// Result when lhs == rhs
|
||||
eq: bool,
|
||||
/// Result when lhs is NaN or rhs is NaN
|
||||
unord: bool,
|
||||
},
|
||||
/// Minimum value (with SSE semantics)
|
||||
///
|
||||
/// <https://www.felixcloutier.com/x86/minss>
|
||||
|
|
@ -182,6 +161,44 @@ enum FloatBinOp {
|
|||
Max,
|
||||
}
|
||||
|
||||
impl FloatBinOp {
|
||||
/// Convert from the `imm` argument used to specify the comparison
|
||||
/// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
|
||||
fn cmp_from_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> {
|
||||
// Only bits 0..=4 are used, remaining should be zero.
|
||||
if imm & !0b1_1111 != 0 {
|
||||
throw_unsup_format!("invalid `imm` parameter of {intrinsic}: 0x{imm:x}");
|
||||
}
|
||||
// Bit 4 specifies whether the operation is quiet or signaling, which
|
||||
// we do not care about in Miri.
|
||||
// Bits 0..=2 specify the operation.
|
||||
// `gt` indicates the result to be returned when the LHS is strictly
|
||||
// greater than the RHS, and so on.
|
||||
let (gt, lt, eq, unord) = match imm & 0b111 {
|
||||
// Equal
|
||||
0x0 => (false, false, true, false),
|
||||
// Less-than
|
||||
0x1 => (false, true, false, false),
|
||||
// Less-or-equal
|
||||
0x2 => (false, true, true, false),
|
||||
// Unordered (either is NaN)
|
||||
0x3 => (false, false, false, true),
|
||||
// Not equal
|
||||
0x4 => (true, true, false, true),
|
||||
// Not less-than
|
||||
0x5 => (true, false, true, true),
|
||||
// Not less-or-equal
|
||||
0x6 => (true, false, false, true),
|
||||
// Ordered (neither is NaN)
|
||||
0x7 => (true, true, true, false),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
// When bit 3 is 1 (only possible in AVX), unord is toggled.
|
||||
let unord = unord ^ (imm & 0b1000 != 0);
|
||||
Ok(Self::Cmp { gt, lt, eq, unord })
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs `which` scalar operation on `left` and `right` and returns
|
||||
/// the result.
|
||||
fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
|
||||
|
|
@ -195,20 +212,15 @@ fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
|
|||
let res = this.wrapping_binary_op(which, left, right)?;
|
||||
Ok(res.to_scalar())
|
||||
}
|
||||
FloatBinOp::Cmp(which) => {
|
||||
FloatBinOp::Cmp { gt, lt, eq, unord } => {
|
||||
let left = left.to_scalar().to_float::<F>()?;
|
||||
let right = right.to_scalar().to_float::<F>()?;
|
||||
// FIXME: Make sure that these operations match the semantics
|
||||
// of cmpps/cmpss/cmppd/cmpsd
|
||||
let res = match which {
|
||||
FloatCmpOp::Eq => left == right,
|
||||
FloatCmpOp::Lt => left < right,
|
||||
FloatCmpOp::Le => left <= right,
|
||||
FloatCmpOp::Unord => left.is_nan() || right.is_nan(),
|
||||
FloatCmpOp::Neq => left != right,
|
||||
FloatCmpOp::Nlt => !(left < right),
|
||||
FloatCmpOp::Nle => !(left <= right),
|
||||
FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(),
|
||||
|
||||
let res = match left.partial_cmp(&right) {
|
||||
None => unord,
|
||||
Some(std::cmp::Ordering::Less) => lt,
|
||||
Some(std::cmp::Ordering::Equal) => eq,
|
||||
Some(std::cmp::Ordering::Greater) => gt,
|
||||
};
|
||||
Ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use rustc_target::spec::abi::Abi;
|
|||
|
||||
use rand::Rng as _;
|
||||
|
||||
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
|
||||
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp};
|
||||
use crate::*;
|
||||
use shims::foreign_items::EmulateForeignItemResult;
|
||||
|
||||
|
|
@ -95,33 +95,41 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
|
|||
|
||||
unary_op_ps(this, which, op, dest)?;
|
||||
}
|
||||
// Used to implement the _mm_cmp_ss function.
|
||||
// Used to implement the _mm_cmp*_ss functions.
|
||||
// Performs a comparison operation on the first component of `left`
|
||||
// and `right`, returning 0 if false or `u32::MAX` if true. The remaining
|
||||
// components are copied from `left`.
|
||||
// _mm_cmp_ss is actually an AVX function where the operation is specified
|
||||
// by a const parameter.
|
||||
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ss are SSE functions
|
||||
// with hard-coded operations.
|
||||
"cmp.ss" => {
|
||||
let [left, right, imm] =
|
||||
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
|
||||
|
||||
let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
|
||||
let which = FloatBinOp::cmp_from_imm(
|
||||
this.read_scalar(imm)?.to_i8()?,
|
||||
"llvm.x86.sse.cmp.ss",
|
||||
)?);
|
||||
)?;
|
||||
|
||||
bin_op_simd_float_first::<Single>(this, which, left, right, dest)?;
|
||||
}
|
||||
// Used to implement the _mm_cmp_ps function.
|
||||
// Used to implement the _mm_cmp*_ps functions.
|
||||
// Performs a comparison operation on each component of `left`
|
||||
// and `right`. For each component, returns 0 if false or u32::MAX
|
||||
// if true.
|
||||
// _mm_cmp_ps is actually an AVX function where the operation is specified
|
||||
// by a const parameter.
|
||||
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ps are SSE functions
|
||||
// with hard-coded operations.
|
||||
"cmp.ps" => {
|
||||
let [left, right, imm] =
|
||||
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
|
||||
|
||||
let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
|
||||
let which = FloatBinOp::cmp_from_imm(
|
||||
this.read_scalar(imm)?.to_i8()?,
|
||||
"llvm.x86.sse.cmp.ps",
|
||||
)?);
|
||||
)?;
|
||||
|
||||
bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use rustc_middle::ty::Ty;
|
|||
use rustc_span::Symbol;
|
||||
use rustc_target::spec::abi::Abi;
|
||||
|
||||
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
|
||||
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp};
|
||||
use crate::*;
|
||||
use shims::foreign_items::EmulateForeignItemResult;
|
||||
|
||||
|
|
@ -461,18 +461,22 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
|
|||
this.write_scalar(res, &dest)?;
|
||||
}
|
||||
}
|
||||
// Used to implement the _mm_cmp*_sd function.
|
||||
// Used to implement the _mm_cmp*_sd functions.
|
||||
// Performs a comparison operation on the first component of `left`
|
||||
// and `right`, returning 0 if false or `u64::MAX` if true. The remaining
|
||||
// components are copied from `left`.
|
||||
// _mm_cmp_sd is actually an AVX function where the operation is specified
|
||||
// by a const parameter.
|
||||
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_sd are SSE2 functions
|
||||
// with hard-coded operations.
|
||||
"cmp.sd" => {
|
||||
let [left, right, imm] =
|
||||
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
|
||||
|
||||
let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
|
||||
let which = FloatBinOp::cmp_from_imm(
|
||||
this.read_scalar(imm)?.to_i8()?,
|
||||
"llvm.x86.sse2.cmp.sd",
|
||||
)?);
|
||||
)?;
|
||||
|
||||
bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
|
||||
}
|
||||
|
|
@ -480,14 +484,18 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
|
|||
// Performs a comparison operation on each component of `left`
|
||||
// and `right`. For each component, returns 0 if false or `u64::MAX`
|
||||
// if true.
|
||||
// _mm_cmp_pd is actually an AVX function where the operation is specified
|
||||
// by a const parameter.
|
||||
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_pd are SSE2 functions
|
||||
// with hard-coded operations.
|
||||
"cmp.pd" => {
|
||||
let [left, right, imm] =
|
||||
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
|
||||
|
||||
let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
|
||||
let which = FloatBinOp::cmp_from_imm(
|
||||
this.read_scalar(imm)?.to_i8()?,
|
||||
"llvm.x86.sse2.cmp.pd",
|
||||
)?);
|
||||
)?;
|
||||
|
||||
bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
|
||||
}
|
||||
|
|
|
|||
162
src/tools/miri/tests/pass/intrinsics-x86-avx.rs
Normal file
162
src/tools/miri/tests/pass/intrinsics-x86-avx.rs
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
// Ignore everything except x86 and x86_64
|
||||
// Any new targets that are added to CI should be ignored here.
|
||||
// (We cannot use `cfg`-based tricks here since the `target-feature` flags below only work on x86.)
|
||||
//@ignore-target-aarch64
|
||||
//@ignore-target-arm
|
||||
//@ignore-target-avr
|
||||
//@ignore-target-s390x
|
||||
//@ignore-target-thumbv7em
|
||||
//@ignore-target-wasm32
|
||||
//@compile-flags: -C target-feature=+avx
|
||||
|
||||
#[cfg(target_arch = "x86")]
|
||||
use std::arch::x86::*;
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
use std::arch::x86_64::*;
|
||||
use std::mem::transmute;
|
||||
|
||||
fn main() {
|
||||
assert!(is_x86_feature_detected!("avx"));
|
||||
|
||||
unsafe {
|
||||
test_avx();
|
||||
}
|
||||
}
|
||||
|
||||
#[target_feature(enable = "avx")]
|
||||
unsafe fn test_avx() {
|
||||
fn expected_cmp<F: PartialOrd>(imm: i32, lhs: F, rhs: F, if_t: F, if_f: F) -> F {
|
||||
let res = match imm {
|
||||
_CMP_EQ_OQ => lhs == rhs,
|
||||
_CMP_LT_OS => lhs < rhs,
|
||||
_CMP_LE_OS => lhs <= rhs,
|
||||
_CMP_UNORD_Q => lhs.partial_cmp(&rhs).is_none(),
|
||||
_CMP_NEQ_UQ => lhs != rhs,
|
||||
_CMP_NLT_UQ => !(lhs < rhs),
|
||||
_CMP_NLE_UQ => !(lhs <= rhs),
|
||||
_CMP_ORD_Q => lhs.partial_cmp(&rhs).is_some(),
|
||||
_CMP_EQ_UQ => lhs == rhs || lhs.partial_cmp(&rhs).is_none(),
|
||||
_CMP_NGE_US => !(lhs >= rhs),
|
||||
_CMP_NGT_US => !(lhs > rhs),
|
||||
_CMP_FALSE_OQ => false,
|
||||
_CMP_NEQ_OQ => lhs != rhs && lhs.partial_cmp(&rhs).is_some(),
|
||||
_CMP_GE_OS => lhs >= rhs,
|
||||
_CMP_GT_OS => lhs > rhs,
|
||||
_CMP_TRUE_US => true,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
if res { if_t } else { if_f }
|
||||
}
|
||||
fn expected_cmp_f32(imm: i32, lhs: f32, rhs: f32) -> f32 {
|
||||
expected_cmp(imm, lhs, rhs, f32::from_bits(u32::MAX), 0.0)
|
||||
}
|
||||
fn expected_cmp_f64(imm: i32, lhs: f64, rhs: f64) -> f64 {
|
||||
expected_cmp(imm, lhs, rhs, f64::from_bits(u64::MAX), 0.0)
|
||||
}
|
||||
|
||||
#[target_feature(enable = "avx")]
|
||||
unsafe fn test_mm_cmp_ss<const IMM: i32>() {
|
||||
let values = [
|
||||
(1.0, 1.0),
|
||||
(0.0, 1.0),
|
||||
(1.0, 0.0),
|
||||
(f32::NAN, 0.0),
|
||||
(0.0, f32::NAN),
|
||||
(f32::NAN, f32::NAN),
|
||||
];
|
||||
|
||||
for (lhs, rhs) in values {
|
||||
let a = _mm_setr_ps(lhs, 2.0, 3.0, 4.0);
|
||||
let b = _mm_setr_ps(rhs, 5.0, 6.0, 7.0);
|
||||
let r: [u32; 4] = transmute(_mm_cmp_ss::<IMM>(a, b));
|
||||
let e: [u32; 4] =
|
||||
transmute(_mm_setr_ps(expected_cmp_f32(IMM, lhs, rhs), 2.0, 3.0, 4.0));
|
||||
assert_eq!(r, e);
|
||||
}
|
||||
}
|
||||
|
||||
#[target_feature(enable = "avx")]
|
||||
unsafe fn test_mm_cmp_ps<const IMM: i32>() {
|
||||
let values = [
|
||||
(1.0, 1.0),
|
||||
(0.0, 1.0),
|
||||
(1.0, 0.0),
|
||||
(f32::NAN, 0.0),
|
||||
(0.0, f32::NAN),
|
||||
(f32::NAN, f32::NAN),
|
||||
];
|
||||
|
||||
for (lhs, rhs) in values {
|
||||
let a = _mm_set1_ps(lhs);
|
||||
let b = _mm_set1_ps(rhs);
|
||||
let r: [u32; 4] = transmute(_mm_cmp_ps::<IMM>(a, b));
|
||||
let e: [u32; 4] = transmute(_mm_set1_ps(expected_cmp_f32(IMM, lhs, rhs)));
|
||||
assert_eq!(r, e);
|
||||
}
|
||||
}
|
||||
|
||||
#[target_feature(enable = "avx")]
|
||||
unsafe fn test_mm_cmp_sd<const IMM: i32>() {
|
||||
let values = [
|
||||
(1.0, 1.0),
|
||||
(0.0, 1.0),
|
||||
(1.0, 0.0),
|
||||
(f64::NAN, 0.0),
|
||||
(0.0, f64::NAN),
|
||||
(f64::NAN, f64::NAN),
|
||||
];
|
||||
|
||||
for (lhs, rhs) in values {
|
||||
let a = _mm_setr_pd(lhs, 2.0);
|
||||
let b = _mm_setr_pd(rhs, 3.0);
|
||||
let r: [u64; 2] = transmute(_mm_cmp_sd::<IMM>(a, b));
|
||||
let e: [u64; 2] = transmute(_mm_setr_pd(expected_cmp_f64(IMM, lhs, rhs), 2.0));
|
||||
assert_eq!(r, e);
|
||||
}
|
||||
}
|
||||
|
||||
#[target_feature(enable = "avx")]
|
||||
unsafe fn test_mm_cmp_pd<const IMM: i32>() {
|
||||
let values = [
|
||||
(1.0, 1.0),
|
||||
(0.0, 1.0),
|
||||
(1.0, 0.0),
|
||||
(f64::NAN, 0.0),
|
||||
(0.0, f64::NAN),
|
||||
(f64::NAN, f64::NAN),
|
||||
];
|
||||
|
||||
for (lhs, rhs) in values {
|
||||
let a = _mm_set1_pd(lhs);
|
||||
let b = _mm_set1_pd(rhs);
|
||||
let r: [u64; 2] = transmute(_mm_cmp_pd::<IMM>(a, b));
|
||||
let e: [u64; 2] = transmute(_mm_set1_pd(expected_cmp_f64(IMM, lhs, rhs)));
|
||||
assert_eq!(r, e);
|
||||
}
|
||||
}
|
||||
|
||||
#[target_feature(enable = "avx")]
|
||||
unsafe fn test_cmp<const IMM: i32>() {
|
||||
test_mm_cmp_ss::<IMM>();
|
||||
test_mm_cmp_ps::<IMM>();
|
||||
test_mm_cmp_sd::<IMM>();
|
||||
test_mm_cmp_pd::<IMM>();
|
||||
}
|
||||
|
||||
test_cmp::<_CMP_EQ_OQ>();
|
||||
test_cmp::<_CMP_LT_OS>();
|
||||
test_cmp::<_CMP_LE_OS>();
|
||||
test_cmp::<_CMP_UNORD_Q>();
|
||||
test_cmp::<_CMP_NEQ_UQ>();
|
||||
test_cmp::<_CMP_NLT_UQ>();
|
||||
test_cmp::<_CMP_NLE_UQ>();
|
||||
test_cmp::<_CMP_ORD_Q>();
|
||||
test_cmp::<_CMP_EQ_UQ>();
|
||||
test_cmp::<_CMP_NGE_US>();
|
||||
test_cmp::<_CMP_NGT_US>();
|
||||
test_cmp::<_CMP_FALSE_OQ>();
|
||||
test_cmp::<_CMP_NEQ_OQ>();
|
||||
test_cmp::<_CMP_GE_OS>();
|
||||
test_cmp::<_CMP_GT_OS>();
|
||||
test_cmp::<_CMP_TRUE_US>();
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue