add an avx512 psad shim
also combine the sse2 and avx2 version into one generic function for all 3
This commit is contained in:
parent
a3955227a8
commit
04d97bc964
5 changed files with 98 additions and 64 deletions
|
|
@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
|
|||
|
||||
use super::{
|
||||
ShiftOp, horizontal_bin_op, mask_load, mask_store, mpsadbw, packssdw, packsswb, packusdw,
|
||||
packuswb, pmulhrsw, psign, shift_simd_by_scalar, shift_simd_by_simd,
|
||||
packuswb, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd,
|
||||
};
|
||||
use crate::*;
|
||||
|
||||
|
|
@ -241,41 +241,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
|
|||
}
|
||||
}
|
||||
// Used to implement the _mm256_sad_epu8 function.
|
||||
// Compute the absolute differences of packed unsigned 8-bit integers
|
||||
// in `left` and `right`, then horizontally sum each consecutive 8
|
||||
// differences to produce four unsigned 16-bit integers, and pack
|
||||
// these unsigned 16-bit integers in the low 16 bits of 64-bit elements
|
||||
// in `dest`.
|
||||
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_sad_epu8
|
||||
"psad.bw" => {
|
||||
let [left, right] =
|
||||
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
|
||||
|
||||
let (left, left_len) = this.project_to_simd(left)?;
|
||||
let (right, right_len) = this.project_to_simd(right)?;
|
||||
let (dest, dest_len) = this.project_to_simd(dest)?;
|
||||
|
||||
assert_eq!(left_len, right_len);
|
||||
assert_eq!(left_len, dest_len.strict_mul(8));
|
||||
|
||||
for i in 0..dest_len {
|
||||
let dest = this.project_index(&dest, i)?;
|
||||
|
||||
let mut acc: u16 = 0;
|
||||
for j in 0..8 {
|
||||
let src_index = i.strict_mul(8).strict_add(j);
|
||||
|
||||
let left = this.project_index(&left, src_index)?;
|
||||
let left = this.read_scalar(&left)?.to_u8()?;
|
||||
|
||||
let right = this.project_index(&right, src_index)?;
|
||||
let right = this.read_scalar(&right)?.to_u8()?;
|
||||
|
||||
acc = acc.strict_add(left.abs_diff(right).into());
|
||||
}
|
||||
|
||||
this.write_scalar(Scalar::from_u64(acc.into()), &dest)?;
|
||||
}
|
||||
psadbw(this, left, right, dest)?
|
||||
}
|
||||
// Used to implement the _mm256_shuffle_epi8 intrinsic.
|
||||
// Shuffles bytes from `left` using `right` as pattern.
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ use rustc_middle::ty::Ty;
|
|||
use rustc_span::Symbol;
|
||||
use rustc_target::callconv::FnAbi;
|
||||
|
||||
use super::psadbw;
|
||||
use crate::*;
|
||||
|
||||
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
|
||||
|
|
@ -78,6 +79,15 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
|
|||
this.write_scalar(Scalar::from_u32(r), &d_lane)?;
|
||||
}
|
||||
}
|
||||
// Used to implement the _mm512_sad_epu8 function.
|
||||
"psad.bw.512" => {
|
||||
this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
|
||||
|
||||
let [left, right] =
|
||||
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
|
||||
|
||||
psadbw(this, left, right, dest)?
|
||||
}
|
||||
_ => return interp_ok(EmulateItemResult::NotSupported),
|
||||
}
|
||||
interp_ok(EmulateItemResult::NeedsReturn)
|
||||
|
|
|
|||
|
|
@ -1038,6 +1038,54 @@ fn mpsadbw<'tcx>(
|
|||
interp_ok(())
|
||||
}
|
||||
|
||||
/// Compute the absolute differences of packed unsigned 8-bit integers
|
||||
/// in `left` and `right`, then horizontally sum each consecutive 8
|
||||
/// differences to produce unsigned 16-bit integers, and pack
|
||||
/// these unsigned 16-bit integers in the low 16 bits of 64-bit elements
|
||||
/// in `dest`.
|
||||
///
|
||||
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8>
|
||||
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_sad_epu8>
|
||||
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_sad_epu8>
|
||||
fn psadbw<'tcx>(
|
||||
ecx: &mut crate::MiriInterpCx<'tcx>,
|
||||
left: &OpTy<'tcx>,
|
||||
right: &OpTy<'tcx>,
|
||||
dest: &MPlaceTy<'tcx>,
|
||||
) -> InterpResult<'tcx, ()> {
|
||||
let (left, left_len) = ecx.project_to_simd(left)?;
|
||||
let (right, right_len) = ecx.project_to_simd(right)?;
|
||||
let (dest, dest_len) = ecx.project_to_simd(dest)?;
|
||||
|
||||
// fn psadbw(a: u8x16, b: u8x16) -> u64x2;
|
||||
// fn psadbw(a: u8x32, b: u8x32) -> u64x4;
|
||||
// fn vpsadbw(a: u8x64, b: u8x64) -> u64x8;
|
||||
assert_eq!(left_len, right_len);
|
||||
assert_eq!(left_len, left.layout.layout.size().bytes());
|
||||
assert_eq!(dest_len, left_len.strict_div(8));
|
||||
|
||||
for i in 0..dest_len {
|
||||
let dest = ecx.project_index(&dest, i)?;
|
||||
|
||||
let mut acc: u16 = 0;
|
||||
for j in 0..8 {
|
||||
let src_index = i.strict_mul(8).strict_add(j);
|
||||
|
||||
let left = ecx.project_index(&left, src_index)?;
|
||||
let left = ecx.read_scalar(&left)?.to_u8()?;
|
||||
|
||||
let right = ecx.project_index(&right, src_index)?;
|
||||
let right = ecx.read_scalar(&right)?.to_u8()?;
|
||||
|
||||
acc = acc.strict_add(left.abs_diff(right).into());
|
||||
}
|
||||
|
||||
ecx.write_scalar(Scalar::from_u64(acc.into()), &dest)?;
|
||||
}
|
||||
|
||||
interp_ok(())
|
||||
}
|
||||
|
||||
/// Multiplies packed 16-bit signed integer values, truncates the 32-bit
|
||||
/// product to the 18 most significant bits by right-shifting, and then
|
||||
/// divides the 18-bit value by 2 (rounding to nearest) by first adding
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
|
|||
|
||||
use super::{
|
||||
FloatBinOp, ShiftOp, bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int,
|
||||
packssdw, packsswb, packuswb, shift_simd_by_scalar,
|
||||
packssdw, packsswb, packuswb, psadbw, shift_simd_by_scalar,
|
||||
};
|
||||
use crate::*;
|
||||
|
||||
|
|
@ -37,41 +37,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
|
|||
// vectors.
|
||||
match unprefixed_name {
|
||||
// Used to implement the _mm_sad_epu8 function.
|
||||
// Computes the absolute differences of packed unsigned 8-bit integers in `a`
|
||||
// and `b`, then horizontally sum each consecutive 8 differences to produce
|
||||
// two unsigned 16-bit integers, and pack these unsigned 16-bit integers in
|
||||
// the low 16 bits of 64-bit elements returned.
|
||||
//
|
||||
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8
|
||||
"psad.bw" => {
|
||||
let [left, right] =
|
||||
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
|
||||
|
||||
let (left, left_len) = this.project_to_simd(left)?;
|
||||
let (right, right_len) = this.project_to_simd(right)?;
|
||||
let (dest, dest_len) = this.project_to_simd(dest)?;
|
||||
|
||||
// left and right are u8x16, dest is u64x2
|
||||
assert_eq!(left_len, right_len);
|
||||
assert_eq!(left_len, 16);
|
||||
assert_eq!(dest_len, 2);
|
||||
|
||||
for i in 0..dest_len {
|
||||
let dest = this.project_index(&dest, i)?;
|
||||
|
||||
let mut res: u16 = 0;
|
||||
let n = left_len.strict_div(dest_len);
|
||||
for j in 0..n {
|
||||
let op_i = j.strict_add(i.strict_mul(n));
|
||||
let left = this.read_scalar(&this.project_index(&left, op_i)?)?.to_u8()?;
|
||||
let right =
|
||||
this.read_scalar(&this.project_index(&right, op_i)?)?.to_u8()?;
|
||||
|
||||
res = res.strict_add(left.abs_diff(right).into());
|
||||
}
|
||||
|
||||
this.write_scalar(Scalar::from_u64(res.into()), &dest)?;
|
||||
}
|
||||
psadbw(this, left, right, dest)?
|
||||
}
|
||||
// Used to implement the _mm_{sll,srl,sra}_epi{16,32,64} functions
|
||||
// (except _mm_sra_epi64, which is not available in SSE2).
|
||||
|
|
|
|||
|
|
@ -15,12 +15,48 @@ fn main() {
|
|||
assert!(is_x86_feature_detected!("avx512vpopcntdq"));
|
||||
|
||||
unsafe {
|
||||
test_avx512();
|
||||
test_avx512bitalg();
|
||||
test_avx512vpopcntdq();
|
||||
test_avx512ternarylogic();
|
||||
}
|
||||
}
|
||||
|
||||
#[target_feature(enable = "avx512bw")]
|
||||
unsafe fn test_avx512() {
|
||||
#[target_feature(enable = "avx512bw")]
|
||||
unsafe fn test_mm512_sad_epu8() {
|
||||
let a = _mm512_set_epi8(
|
||||
71, 70, 69, 68, 67, 66, 65, 64, //
|
||||
55, 54, 53, 52, 51, 50, 49, 48, //
|
||||
47, 46, 45, 44, 43, 42, 41, 40, //
|
||||
39, 38, 37, 36, 35, 34, 33, 32, //
|
||||
31, 30, 29, 28, 27, 26, 25, 24, //
|
||||
23, 22, 21, 20, 19, 18, 17, 16, //
|
||||
15, 14, 13, 12, 11, 10, 9, 8, //
|
||||
7, 6, 5, 4, 3, 2, 1, 0, //
|
||||
);
|
||||
|
||||
// `d` is the absolute difference with the corresponding row in `a`.
|
||||
let b = _mm512_set_epi8(
|
||||
63, 62, 61, 60, 59, 58, 57, 56, // lane 7 (d = 8)
|
||||
62, 61, 60, 59, 58, 57, 56, 55, // lane 6 (d = 7)
|
||||
53, 52, 51, 50, 49, 48, 47, 46, // lane 5 (d = 6)
|
||||
44, 43, 42, 41, 40, 39, 38, 37, // lane 4 (d = 5)
|
||||
35, 34, 33, 32, 31, 30, 29, 28, // lane 3 (d = 4)
|
||||
26, 25, 24, 23, 22, 21, 20, 19, // lane 2 (d = 3)
|
||||
17, 16, 15, 14, 13, 12, 11, 10, // lane 1 (d = 2)
|
||||
8, 7, 6, 5, 4, 3, 2, 1, // lane 0 (d = 1)
|
||||
);
|
||||
|
||||
let r = _mm512_sad_epu8(a, b);
|
||||
let e = _mm512_set_epi64(64, 56, 48, 40, 32, 24, 16, 8);
|
||||
|
||||
assert_eq_m512i(r, e);
|
||||
}
|
||||
test_mm512_sad_epu8();
|
||||
}
|
||||
|
||||
// Some of the constants in the tests below are just bit patterns. They should not
|
||||
// be interpreted as integers; signedness does not make sense for them, but
|
||||
// __mXXXi happens to be defined in terms of signed integers.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue