add an avx512 psad shim

also combine the sse2 and avx2 version into one generic function for all 3
This commit is contained in:
Folkert de Vries 2025-11-14 00:57:27 +01:00
parent a3955227a8
commit 04d97bc964
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
5 changed files with 98 additions and 64 deletions

View file

@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
use super::{
ShiftOp, horizontal_bin_op, mask_load, mask_store, mpsadbw, packssdw, packsswb, packusdw,
packuswb, pmulhrsw, psign, shift_simd_by_scalar, shift_simd_by_simd,
packuswb, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd,
};
use crate::*;
@ -241,41 +241,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
}
// Used to implement the _mm256_sad_epu8 function.
// Compute the absolute differences of packed unsigned 8-bit integers
// in `left` and `right`, then horizontally sum each consecutive 8
// differences to produce four unsigned 16-bit integers, and pack
// these unsigned 16-bit integers in the low 16 bits of 64-bit elements
// in `dest`.
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_sad_epu8
"psad.bw" => {
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;
assert_eq!(left_len, right_len);
assert_eq!(left_len, dest_len.strict_mul(8));
for i in 0..dest_len {
let dest = this.project_index(&dest, i)?;
let mut acc: u16 = 0;
for j in 0..8 {
let src_index = i.strict_mul(8).strict_add(j);
let left = this.project_index(&left, src_index)?;
let left = this.read_scalar(&left)?.to_u8()?;
let right = this.project_index(&right, src_index)?;
let right = this.read_scalar(&right)?.to_u8()?;
acc = acc.strict_add(left.abs_diff(right).into());
}
this.write_scalar(Scalar::from_u64(acc.into()), &dest)?;
}
psadbw(this, left, right, dest)?
}
// Used to implement the _mm256_shuffle_epi8 intrinsic.
// Shuffles bytes from `left` using `right` as pattern.

View file

@ -3,6 +3,7 @@ use rustc_middle::ty::Ty;
use rustc_span::Symbol;
use rustc_target::callconv::FnAbi;
use super::psadbw;
use crate::*;
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@ -78,6 +79,15 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
this.write_scalar(Scalar::from_u32(r), &d_lane)?;
}
}
// Used to implement the _mm512_sad_epu8 function.
"psad.bw.512" => {
this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
psadbw(this, left, right, dest)?
}
_ => return interp_ok(EmulateItemResult::NotSupported),
}
interp_ok(EmulateItemResult::NeedsReturn)

View file

@ -1038,6 +1038,54 @@ fn mpsadbw<'tcx>(
interp_ok(())
}
/// Compute the absolute differences of packed unsigned 8-bit integers
/// in `left` and `right`, then horizontally sum each consecutive 8
/// differences to produce unsigned 16-bit integers, and pack
/// these unsigned 16-bit integers in the low 16 bits of 64-bit elements
/// in `dest`.
///
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_sad_epu8>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_sad_epu8>
fn psadbw<'tcx>(
ecx: &mut crate::MiriInterpCx<'tcx>,
left: &OpTy<'tcx>,
right: &OpTy<'tcx>,
dest: &MPlaceTy<'tcx>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = ecx.project_to_simd(left)?;
let (right, right_len) = ecx.project_to_simd(right)?;
let (dest, dest_len) = ecx.project_to_simd(dest)?;
// fn psadbw(a: u8x16, b: u8x16) -> u64x2;
// fn psadbw(a: u8x32, b: u8x32) -> u64x4;
// fn vpsadbw(a: u8x64, b: u8x64) -> u64x8;
assert_eq!(left_len, right_len);
assert_eq!(left_len, left.layout.layout.size().bytes());
assert_eq!(dest_len, left_len.strict_div(8));
for i in 0..dest_len {
let dest = ecx.project_index(&dest, i)?;
let mut acc: u16 = 0;
for j in 0..8 {
let src_index = i.strict_mul(8).strict_add(j);
let left = ecx.project_index(&left, src_index)?;
let left = ecx.read_scalar(&left)?.to_u8()?;
let right = ecx.project_index(&right, src_index)?;
let right = ecx.read_scalar(&right)?.to_u8()?;
acc = acc.strict_add(left.abs_diff(right).into());
}
ecx.write_scalar(Scalar::from_u64(acc.into()), &dest)?;
}
interp_ok(())
}
/// Multiplies packed 16-bit signed integer values, truncates the 32-bit
/// product to the 18 most significant bits by right-shifting, and then
/// divides the 18-bit value by 2 (rounding to nearest) by first adding

View file

@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
use super::{
FloatBinOp, ShiftOp, bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int,
packssdw, packsswb, packuswb, shift_simd_by_scalar,
packssdw, packsswb, packuswb, psadbw, shift_simd_by_scalar,
};
use crate::*;
@ -37,41 +37,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// vectors.
match unprefixed_name {
// Used to implement the _mm_sad_epu8 function.
// Computes the absolute differences of packed unsigned 8-bit integers in `a`
// and `b`, then horizontally sum each consecutive 8 differences to produce
// two unsigned 16-bit integers, and pack these unsigned 16-bit integers in
// the low 16 bits of 64-bit elements returned.
//
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8
"psad.bw" => {
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;
// left and right are u8x16, dest is u64x2
assert_eq!(left_len, right_len);
assert_eq!(left_len, 16);
assert_eq!(dest_len, 2);
for i in 0..dest_len {
let dest = this.project_index(&dest, i)?;
let mut res: u16 = 0;
let n = left_len.strict_div(dest_len);
for j in 0..n {
let op_i = j.strict_add(i.strict_mul(n));
let left = this.read_scalar(&this.project_index(&left, op_i)?)?.to_u8()?;
let right =
this.read_scalar(&this.project_index(&right, op_i)?)?.to_u8()?;
res = res.strict_add(left.abs_diff(right).into());
}
this.write_scalar(Scalar::from_u64(res.into()), &dest)?;
}
psadbw(this, left, right, dest)?
}
// Used to implement the _mm_{sll,srl,sra}_epi{16,32,64} functions
// (except _mm_sra_epi64, which is not available in SSE2).

View file

@ -15,12 +15,48 @@ fn main() {
assert!(is_x86_feature_detected!("avx512vpopcntdq"));
unsafe {
test_avx512();
test_avx512bitalg();
test_avx512vpopcntdq();
test_avx512ternarylogic();
}
}
#[target_feature(enable = "avx512bw")]
unsafe fn test_avx512() {
#[target_feature(enable = "avx512bw")]
unsafe fn test_mm512_sad_epu8() {
let a = _mm512_set_epi8(
71, 70, 69, 68, 67, 66, 65, 64, //
55, 54, 53, 52, 51, 50, 49, 48, //
47, 46, 45, 44, 43, 42, 41, 40, //
39, 38, 37, 36, 35, 34, 33, 32, //
31, 30, 29, 28, 27, 26, 25, 24, //
23, 22, 21, 20, 19, 18, 17, 16, //
15, 14, 13, 12, 11, 10, 9, 8, //
7, 6, 5, 4, 3, 2, 1, 0, //
);
// `d` is the absolute difference with the corresponding row in `a`.
let b = _mm512_set_epi8(
63, 62, 61, 60, 59, 58, 57, 56, // lane 7 (d = 8)
62, 61, 60, 59, 58, 57, 56, 55, // lane 6 (d = 7)
53, 52, 51, 50, 49, 48, 47, 46, // lane 5 (d = 6)
44, 43, 42, 41, 40, 39, 38, 37, // lane 4 (d = 5)
35, 34, 33, 32, 31, 30, 29, 28, // lane 3 (d = 4)
26, 25, 24, 23, 22, 21, 20, 19, // lane 2 (d = 3)
17, 16, 15, 14, 13, 12, 11, 10, // lane 1 (d = 2)
8, 7, 6, 5, 4, 3, 2, 1, // lane 0 (d = 1)
);
let r = _mm512_sad_epu8(a, b);
let e = _mm512_set_epi64(64, 56, 48, 40, 32, 24, 16, 8);
assert_eq_m512i(r, e);
}
test_mm512_sad_epu8();
}
// Some of the constants in the tests below are just bit patterns. They should not
// be interpreted as integers; signedness does not make sense for them, but
// __mXXXi happens to be defined in terms of signed integers.