diff --git a/src/tools/miri/src/shims/x86/avx512.rs b/src/tools/miri/src/shims/x86/avx512.rs index 9231fc446919..b057a78b6c8e 100644 --- a/src/tools/miri/src/shims/x86/avx512.rs +++ b/src/tools/miri/src/shims/x86/avx512.rs @@ -3,7 +3,7 @@ use rustc_middle::ty::Ty; use rustc_span::Symbol; use rustc_target::callconv::FnAbi; -use super::{permute, pmaddbw, pmaddwd, psadbw, pshufb}; +use super::{packssdw, packsswb, packusdw, packuswb, permute, pmaddbw, pmaddwd, psadbw, pshufb}; use crate::*; impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {} @@ -130,6 +130,38 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { vpdpbusd(this, src, a, b, dest)?; } + // Used to implement the _mm512_packs_epi16 function (saturating i16 -> i8 pack). + "packsswb.512" => { + this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?; + + let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; + + packsswb(this, a, b, dest)?; + } + // Used to implement the _mm512_packus_epi16 function (saturating i16 -> u8 pack). + "packuswb.512" => { + this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?; + + let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; + + packuswb(this, a, b, dest)?; + } + // Used to implement the _mm512_packs_epi32 function (saturating i32 -> i16 pack). + "packssdw.512" => { + this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?; + + let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; + + packssdw(this, a, b, dest)?; + } + // Used to implement the _mm512_packus_epi32 function (saturating i32 -> u16 pack). + "packusdw.512" => { + this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?; + + let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; + + packusdw(this, a, b, dest)?; + } _ => return interp_ok(EmulateItemResult::NotSupported), } interp_ok(EmulateItemResult::NeedsReturn) diff --git a/src/tools/miri/tests/pass/shims/x86/intrinsics-x86-avx512.rs b/src/tools/miri/tests/pass/shims/x86/intrinsics-x86-avx512.rs index 
7cc554ef5a3c..e1e23eda8428 100644 --- a/src/tools/miri/tests/pass/shims/x86/intrinsics-x86-avx512.rs +++ b/src/tools/miri/tests/pass/shims/x86/intrinsics-x86-avx512.rs @@ -1,6 +1,6 @@ // We're testing x86 target specific features //@only-target: x86_64 i686 -//@compile-flags: -C target-feature=+avx512f,+avx512vl,+avx512bitalg,+avx512vpopcntdq,+avx512vnni +//@compile-flags: -C target-feature=+avx512f,+avx512vl,+avx512bw,+avx512bitalg,+avx512vpopcntdq,+avx512vnni #[cfg(target_arch = "x86")] use std::arch::x86::*; @@ -11,12 +11,14 @@ use std::mem::transmute; fn main() { assert!(is_x86_feature_detected!("avx512f")); assert!(is_x86_feature_detected!("avx512vl")); + assert!(is_x86_feature_detected!("avx512bw")); assert!(is_x86_feature_detected!("avx512bitalg")); assert!(is_x86_feature_detected!("avx512vpopcntdq")); assert!(is_x86_feature_detected!("avx512vnni")); unsafe { test_avx512(); + test_avx512bw(); test_avx512bitalg(); test_avx512vpopcntdq(); test_avx512ternarylogic(); @@ -579,9 +581,133 @@ unsafe fn test_avx512vnni() { test_mm512_dpbusd_epi32(); } +#[target_feature(enable = "avx512bw")] +unsafe fn test_avx512bw() { + #[target_feature(enable = "avx512bw")] + unsafe fn test_mm512_packs_epi16() { + let a = _mm512_set1_epi16(120); + + // Because `packs` instructions do signed saturation, we expect + // that any value over `i8::MAX` will be saturated to `i8::MAX`, and any value + // less than `i8::MIN` will also be saturated to `i8::MIN`. + let b = _mm512_set_epi16( + 200, 200, 200, 200, 200, 200, 200, 200, -200, -200, -200, -200, -200, -200, -200, -200, + 200, 200, 200, 200, 200, 200, 200, 200, -200, -200, -200, -200, -200, -200, -200, -200, + ); + + // The pack* family of instructions in x86 operates in blocks + // of 128-bit lanes, meaning the first 128-bit lane in `a` is converted and written, + // then the first 128-bit lane of `b`, followed by the second 128-bit lane in `a`, etc. 
+ // Because we are going from 16 bits to 8 bits, each 128-bit block becomes 64 bits in + // the output register. + // This leaves us with alternating runs of 8x 8-bit values from `a` and `b`. + #[rustfmt::skip] + const DST: [i8; 64] = [ + 120, 120, 120, 120, 120, 120, 120, 120, + i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, + 120, 120, 120, 120, 120, 120, 120, 120, + i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, + 120, 120, 120, 120, 120, 120, 120, 120, + i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, i8::MIN, + 120, 120, 120, 120, 120, 120, 120, 120, + i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, i8::MAX, + ]; + let dst = _mm512_loadu_si512(DST.as_ptr().cast::<__m512i>()); + assert_eq_m512i(_mm512_packs_epi16(a, b), dst); + } + test_mm512_packs_epi16(); + + #[target_feature(enable = "avx512bw")] + unsafe fn test_mm512_packus_epi16() { + let a = _mm512_set1_epi16(120); + + // Because `packus` instructions do unsigned saturation, we expect + // that any value over `u8::MAX` will be saturated to `u8::MAX`, and any value + // less than `u8::MIN` (i.e. any negative value) will also be saturated to `u8::MIN`. + let b = _mm512_set_epi16( + 300, 300, 300, 300, 300, 300, 300, 300, -200, -200, -200, -200, -200, -200, -200, -200, + 300, 300, 300, 300, 300, 300, 300, 300, -200, -200, -200, -200, -200, -200, -200, -200, + ); + + // See `test_mm512_packs_epi16` for an explanation of the output structure. 
+ #[rustfmt::skip] + const DST: [u8; 64] = [ + 120, 120, 120, 120, 120, 120, 120, 120, + u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, + 120, 120, 120, 120, 120, 120, 120, 120, + u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, + 120, 120, 120, 120, 120, 120, 120, 120, + u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, u8::MIN, + 120, 120, 120, 120, 120, 120, 120, 120, + u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, u8::MAX, + ]; + let dst = _mm512_loadu_si512(DST.as_ptr().cast::<__m512i>()); + assert_eq_m512i(_mm512_packus_epi16(a, b), dst); + } + test_mm512_packus_epi16(); + + #[target_feature(enable = "avx512bw")] + unsafe fn test_mm512_packs_epi32() { + let a = _mm512_set1_epi32(8_000); + + // Because `packs` instructions do signed saturation, we expect + // that any value over `i16::MAX` will be saturated to `i16::MAX`, and any value + // less than `i16::MIN` will also be saturated to `i16::MIN`. + let b = _mm512_set_epi32( + 50_000, 50_000, 50_000, 50_000, -50_000, -50_000, -50_000, -50_000, 50_000, 50_000, + 50_000, 50_000, -50_000, -50_000, -50_000, -50_000, + ); + + // See `test_mm512_packs_epi16` for an explanation of the output structure. 
+ #[rustfmt::skip] + const DST: [i16; 32] = [ + 8_000, 8_000, 8_000, 8_000, + i16::MIN, i16::MIN, i16::MIN, i16::MIN, + 8_000, 8_000, 8_000, 8_000, + i16::MAX, i16::MAX, i16::MAX, i16::MAX, + 8_000, 8_000, 8_000, 8_000, + i16::MIN, i16::MIN, i16::MIN, i16::MIN, + 8_000, 8_000, 8_000, 8_000, + i16::MAX, i16::MAX, i16::MAX, i16::MAX, + ]; + let dst = _mm512_loadu_si512(DST.as_ptr().cast::<__m512i>()); + assert_eq_m512i(_mm512_packs_epi32(a, b), dst); + } + test_mm512_packs_epi32(); + + #[target_feature(enable = "avx512bw")] + unsafe fn test_mm512_packus_epi32() { + let a = _mm512_set1_epi32(8_000); + + // Because `packus` instructions do unsigned saturation, we expect + // that any value over `u16::MAX` will be saturated to `u16::MAX`, and any value + // less than `u16::MIN` (i.e. any negative value) will also be saturated to `u16::MIN`. + let b = _mm512_set_epi32( + 80_000, 80_000, 80_000, 80_000, -50_000, -50_000, -50_000, -50_000, 80_000, 80_000, + 80_000, 80_000, -50_000, -50_000, -50_000, -50_000, + ); + + // See `test_mm512_packs_epi16` for an explanation of the output structure. + #[rustfmt::skip] + const DST: [u16; 32] = [ + 8_000, 8_000, 8_000, 8_000, + u16::MIN, u16::MIN, u16::MIN, u16::MIN, + 8_000, 8_000, 8_000, 8_000, + u16::MAX, u16::MAX, u16::MAX, u16::MAX, + 8_000, 8_000, 8_000, 8_000, + u16::MIN, u16::MIN, u16::MIN, u16::MIN, + 8_000, 8_000, 8_000, 8_000, + u16::MAX, u16::MAX, u16::MAX, u16::MAX, + ]; + let dst = _mm512_loadu_si512(DST.as_ptr().cast::<__m512i>()); + assert_eq_m512i(_mm512_packus_epi32(a, b), dst); + } + test_mm512_packus_epi32(); +} + #[track_caller] unsafe fn assert_eq_m512i(a: __m512i, b: __m512i) { - assert_eq!(transmute::<_, [i32; 16]>(a), transmute::<_, [i32; 16]>(b)) + assert_eq!(transmute::<_, [u16; 32]>(a), transmute::<_, [u16; 32]>(b)) } #[track_caller]