Merge pull request #4808 from folkertdev/pmaddwd

add `pmaddwd` shim
This commit is contained in:
Ralf Jung 2026-01-08 18:48:10 +00:00 committed by GitHub
commit 7dc794291f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 131 additions and 57 deletions

View file

@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
use super::{
ShiftOp, horizontal_bin_op, mpsadbw, packssdw, packsswb, packusdw, packuswb, permute, pmaddbw,
pmulhrsw, psadbw, pshufb, psign, shift_simd_by_scalar,
pmaddwd, pmulhrsw, psadbw, pshufb, psign, shift_simd_by_scalar,
};
use crate::*;
@ -232,33 +232,7 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;
assert_eq!(left_len, right_len);
assert_eq!(dest_len.strict_mul(2), left_len);
for i in 0..dest_len {
let j1 = i.strict_mul(2);
let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_i16()?;
let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i16()?;
let j2 = j1.strict_add(1);
let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_i16()?;
let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i16()?;
let dest = this.project_index(&dest, i)?;
// Multiplications are i16*i16->i32, which will not overflow.
let mul1 = i32::from(left1).strict_mul(right1.into());
let mul2 = i32::from(left2).strict_mul(right2.into());
// However, this addition can overflow in the most extreme case
// (-0x8000)*(-0x8000)+(-0x8000)*(-0x8000) = 0x80000000
let res = mul1.wrapping_add(mul2);
this.write_scalar(Scalar::from_i32(res), &dest)?;
}
pmaddwd(this, left, right, dest)?;
}
_ => return interp_ok(EmulateItemResult::NotSupported),
}

View file

@ -3,7 +3,7 @@ use rustc_middle::ty::Ty;
use rustc_span::Symbol;
use rustc_target::callconv::FnAbi;
use super::{permute, pmaddbw, psadbw, pshufb};
use super::{permute, pmaddbw, pmaddwd, psadbw, pshufb};
use crate::*;
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@ -88,6 +88,15 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
psadbw(this, left, right, dest)?
}
// Used to implement the _mm512_madd_epi16 function.
"pmaddw.d.512" => {
this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
pmaddwd(this, left, right, dest)?;
}
// Used to implement the _mm512_maddubs_epi16 function.
"pmaddubs.w.512" => {
let [left, right] =

View file

@ -964,6 +964,52 @@ fn psadbw<'tcx>(
interp_ok(())
}
/// Shared implementation of the `pmaddwd` family of intrinsics.
///
/// `left` and `right` are read as vectors of signed 16-bit lanes. For every output lane `i`,
/// the input lanes `2*i` and `2*i+1` are multiplied pairwise (widened to 32 bits first), and the
/// two products are added with wrap-around; the 32-bit sum is stored in lane `i` of `dest`.
///
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_madd_epi16>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_madd_epi16>
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_madd_epi16>
fn pmaddwd<'tcx>(
    ecx: &mut crate::MiriInterpCx<'tcx>,
    left: &OpTy<'tcx>,
    right: &OpTy<'tcx>,
    dest: &MPlaceTy<'tcx>,
) -> InterpResult<'tcx, ()> {
    let (lhs, lhs_lanes) = ecx.project_to_simd(left)?;
    let (rhs, rhs_lanes) = ecx.project_to_simd(right)?;
    let (out, out_lanes) = ecx.project_to_simd(dest)?;

    // Signatures this helper is used for:
    //   fn pmaddwd(a: i16x8, b: i16x8) -> i32x4;
    //   fn pmaddwd(a: i16x16, b: i16x16) -> i32x8;
    //   fn vpmaddwd(a: i16x32, b: i16x32) -> i32x16;
    // i.e. both inputs have the same lane count, and the output has half as many lanes.
    assert_eq!(lhs_lanes, rhs_lanes);
    assert_eq!(out_lanes.strict_mul(2), lhs_lanes);

    for out_idx in 0..out_lanes {
        // Each output lane consumes one even/odd pair of adjacent input lanes.
        let even = out_idx.strict_mul(2);
        let odd = even.strict_add(1);

        let lhs_even = ecx.read_scalar(&ecx.project_index(&lhs, even)?)?.to_i16()?;
        let rhs_even = ecx.read_scalar(&ecx.project_index(&rhs, even)?)?.to_i16()?;
        let lhs_odd = ecx.read_scalar(&ecx.project_index(&lhs, odd)?)?.to_i16()?;
        let rhs_odd = ecx.read_scalar(&ecx.project_index(&rhs, odd)?)?.to_i16()?;

        let out_lane = ecx.project_index(&out, out_idx)?;

        // An i16 * i16 product always fits in an i32, so the strict multiplication never panics.
        let prod_even = i32::from(lhs_even).strict_mul(i32::from(rhs_even));
        let prod_odd = i32::from(lhs_odd).strict_mul(i32::from(rhs_odd));

        // The sum of the two products, on the other hand, exceeds i32::MAX in exactly one case:
        // (-0x8000)*(-0x8000)+(-0x8000)*(-0x8000) = 0x80000000
        // so the addition has to wrap.
        let sum = prod_even.wrapping_add(prod_odd);

        ecx.write_scalar(Scalar::from_i32(sum), &out_lane)?;
    }

    interp_ok(())
}
/// Multiplies packed 8-bit unsigned integers from `left` and packed
/// signed 8-bit integers from `right` into 16-bit signed integers. Then,
/// the saturating sum of the products with indices `2*i` and `2*i+1`

View file

@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
use super::{
FloatBinOp, ShiftOp, bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int,
packssdw, packsswb, packuswb, psadbw, shift_simd_by_scalar,
packssdw, packsswb, packuswb, pmaddwd, psadbw, shift_simd_by_scalar,
};
use crate::*;
@ -286,33 +286,7 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let [left, right] =
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;
assert_eq!(left_len, right_len);
assert_eq!(dest_len.strict_mul(2), left_len);
for i in 0..dest_len {
let j1 = i.strict_mul(2);
let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_i16()?;
let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i16()?;
let j2 = j1.strict_add(1);
let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_i16()?;
let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i16()?;
let dest = this.project_index(&dest, i)?;
// Multiplications are i16*i16->i32, which will not overflow.
let mul1 = i32::from(left1).strict_mul(right1.into());
let mul2 = i32::from(left2).strict_mul(right2.into());
// However, this addition can overflow in the most extreme case
// (-0x8000)*(-0x8000)+(-0x8000)*(-0x8000) = 0x80000000
let res = mul1.wrapping_add(mul2);
this.write_scalar(Scalar::from_i32(res), &dest)?;
}
pmaddwd(this, left, right, dest)?;
}
_ => return interp_ok(EmulateItemResult::NotSupported),
}

View file

@ -100,6 +100,77 @@ unsafe fn test_avx512() {
}
test_mm512_maddubs_epi16();
// Checks `_mm512_madd_epi16` against hand-computed expected values for a repeating set of
// four input pairs.
#[target_feature(enable = "avx512bw")]
unsafe fn test_mm512_madd_epi16() {
// Input pairs
//
// - `i16::MIN * i16::MIN + i16::MIN * i16::MIN`: the 32-bit addition overflows
// - `i16::MAX * i16::MAX + i16::MAX * i16::MAX`: check that widening happens before
// arithmetic
// - `i16::MIN * i16::MAX + i16::MAX * i16::MIN`: check that large negative values are
// handled correctly
// - `3 * 4 + 1 * 2`: A sanity check, the result should be 14.
//
// The same four pairs are repeated four times to fill all 32 lanes of each input.
#[rustfmt::skip]
let a = _mm512_set_epi16(
i16::MIN, i16::MIN,
i16::MAX, i16::MAX,
i16::MIN, i16::MAX,
3, 1,
i16::MIN, i16::MIN,
i16::MAX, i16::MAX,
i16::MIN, i16::MAX,
3, 1,
i16::MIN, i16::MIN,
i16::MAX, i16::MAX,
i16::MIN, i16::MAX,
3, 1,
i16::MIN, i16::MIN,
i16::MAX, i16::MAX,
i16::MIN, i16::MAX,
3, 1,
);
#[rustfmt::skip]
let b = _mm512_set_epi16(
i16::MIN, i16::MIN,
i16::MAX, i16::MAX,
i16::MAX, i16::MIN,
4, 2,
i16::MIN, i16::MIN,
i16::MAX, i16::MAX,
i16::MAX, i16::MIN,
4, 2,
i16::MIN, i16::MIN,
i16::MAX, i16::MAX,
i16::MAX, i16::MIN,
4, 2,
i16::MIN, i16::MIN,
i16::MAX, i16::MAX,
i16::MAX, i16::MIN,
4, 2,
);
let r = _mm512_madd_epi16(a, b);
// Expected 32-bit results, one per input pair and in the same order as the pairs above:
// wrapped overflow (`i32::MIN`), `2 * i16::MAX * i16::MAX`, `2 * i16::MIN * i16::MAX`, and 14.
#[rustfmt::skip]
let e = _mm512_set_epi32(
i32::MIN, 2_147_352_578, -2_147_418_112, 14,
i32::MIN, 2_147_352_578, -2_147_418_112, 14,
i32::MIN, 2_147_352_578, -2_147_418_112, 14,
i32::MIN, 2_147_352_578, -2_147_418_112, 14,
);
assert_eq_m512i(r, e);
}
test_mm512_madd_epi16();
#[target_feature(enable = "avx512f")]
unsafe fn test_mm512_permutexvar_epi32() {
let a = _mm512_set_epi32(