Refactor avx512bw: avg + mulhi + abs

This commit is contained in:
Tobias Decking 2024-06-30 14:31:23 +02:00 committed by Amanieu d'Antras
parent 268ac7fe92
commit 13063410dd
2 changed files with 160 additions and 24 deletions

View file

@ -743,3 +743,142 @@ simd_ty!(
x6,
x7
);
// 1024-bit wide types:
simd_ty!(
u16x64[u16]:
x0,
x1,
x2,
x3,
x4,
x5,
x6,
x7,
x8,
x9,
x10,
x11,
x12,
x13,
x14,
x15,
x16,
x17,
x18,
x19,
x20,
x21,
x22,
x23,
x24,
x25,
x26,
x27,
x28,
x29,
x30,
x31,
x32,
x33,
x34,
x35,
x36,
x37,
x38,
x39,
x40,
x41,
x42,
x43,
x44,
x45,
x46,
x47,
x48,
x49,
x50,
x51,
x52,
x53,
x54,
x55,
x56,
x57,
x58,
x59,
x60,
x61,
x62,
x63
);
simd_ty!(
i32x32[i32]:
x0,
x1,
x2,
x3,
x4,
x5,
x6,
x7,
x8,
x9,
x10,
x11,
x12,
x13,
x14,
x15,
x16,
x17,
x18,
x19,
x20,
x21,
x22,
x23,
x24,
x25,
x26,
x27,
x28,
x29,
x30,
x31
);
simd_ty!(
u32x32[u32]:
x0,
x1,
x2,
x3,
x4,
x5,
x6,
x7,
x8,
x9,
x10,
x11,
x12,
x13,
x14,
x15,
x16,
x17,
x18,
x19,
x20,
x21,
x22,
x23,
x24,
x25,
x26,
x27,
x28,
x29,
x30,
x31
);

View file

@ -2,7 +2,7 @@ use crate::{
arch::asm,
core_arch::{simd::*, x86::*},
intrinsics::simd::*,
mem, ptr,
ptr,
};
#[cfg(test)]
@ -17,11 +17,8 @@ use stdarch_test::assert_instr;
#[cfg_attr(test, assert_instr(vpabsw))]
pub unsafe fn _mm512_abs_epi16(a: __m512i) -> __m512i {
let a = a.as_i16x32();
// all-0 is a properly initialized i16x32
let zero: i16x32 = mem::zeroed();
let sub = simd_sub(zero, a);
let cmp: i16x32 = simd_gt(a, zero);
transmute(simd_select(cmp, a, sub))
let cmp: i16x32 = simd_gt(a, i16x32::splat(0));
transmute(simd_select(cmp, a, simd_neg(a)))
}
/// Compute the absolute value of packed signed 16-bit integers in a, and store the unsigned results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@ -108,11 +105,8 @@ pub unsafe fn _mm_maskz_abs_epi16(k: __mmask8, a: __m128i) -> __m128i {
#[cfg_attr(test, assert_instr(vpabsb))]
pub unsafe fn _mm512_abs_epi8(a: __m512i) -> __m512i {
let a = a.as_i8x64();
// all-0 is a properly initialized i8x64
let zero: i8x64 = mem::zeroed();
let sub = simd_sub(zero, a);
let cmp: i8x64 = simd_gt(a, zero);
transmute(simd_select(cmp, a, sub))
let cmp: i8x64 = simd_gt(a, i8x64::splat(0));
transmute(simd_select(cmp, a, simd_neg(a)))
}
/// Compute the absolute value of packed signed 8-bit integers in a, and store the unsigned results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@ -1368,7 +1362,10 @@ pub unsafe fn _mm_maskz_subs_epi8(k: __mmask16, a: __m128i, b: __m128i) -> __m12
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vpmulhuw))]
pub unsafe fn _mm512_mulhi_epu16(a: __m512i, b: __m512i) -> __m512i {
transmute(vpmulhuw(a.as_u16x32(), b.as_u16x32()))
let a = simd_cast::<_, u32x32>(a.as_u16x32());
let b = simd_cast::<_, u32x32>(b.as_u16x32());
let r = simd_shr(simd_mul(a, b), u32x32::splat(16));
transmute(simd_cast::<u32x32, u16x32>(r))
}
/// Multiply the packed unsigned 16-bit integers in a and b, producing intermediate 32-bit integers, and store the high 16 bits of the intermediate integers in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@ -1464,7 +1461,10 @@ pub unsafe fn _mm_maskz_mulhi_epu16(k: __mmask8, a: __m128i, b: __m128i) -> __m1
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vpmulhw))]
pub unsafe fn _mm512_mulhi_epi16(a: __m512i, b: __m512i) -> __m512i {
transmute(vpmulhw(a.as_i16x32(), b.as_i16x32()))
let a = simd_cast::<_, i32x32>(a.as_i16x32());
let b = simd_cast::<_, i32x32>(b.as_i16x32());
let r = simd_shr(simd_mul(a, b), i32x32::splat(16));
transmute(simd_cast::<i32x32, i16x32>(r))
}
/// Multiply the packed signed 16-bit integers in a and b, producing intermediate 32-bit integers, and store the high 16 bits of the intermediate integers in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@ -5505,7 +5505,10 @@ pub unsafe fn _mm_maskz_packus_epi16(k: __mmask16, a: __m128i, b: __m128i) -> __
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vpavgw))]
pub unsafe fn _mm512_avg_epu16(a: __m512i, b: __m512i) -> __m512i {
transmute(vpavgw(a.as_u16x32(), b.as_u16x32()))
let a = simd_cast::<_, u32x32>(a.as_u16x16());
let b = simd_cast::<_, u32x32>(b.as_u16x16());
let r = simd_shr(simd_add(simd_add(a, b), u32x32::splat(1)), u32x32::splat(1));
transmute(simd_cast::<_, u16x32>(r))
}
/// Average packed unsigned 16-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@ -5591,7 +5594,10 @@ pub unsafe fn _mm_maskz_avg_epu16(k: __mmask8, a: __m128i, b: __m128i) -> __m128
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vpavgb))]
pub unsafe fn _mm512_avg_epu8(a: __m512i, b: __m512i) -> __m512i {
transmute(vpavgb(a.as_u8x64(), b.as_u8x64()))
let a = simd_cast::<_, u16x64>(a.as_u8x64());
let b = simd_cast::<_, u16x64>(b.as_u8x64());
let r = simd_shr(simd_add(simd_add(a, b), u16x64::splat(1)), u16x64::splat(1));
transmute(simd_cast::<_, u8x64>(r))
}
/// Average packed unsigned 8-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@ -10645,10 +10651,6 @@ extern "C" {
#[link_name = "llvm.x86.avx512.mask.psubs.b.128"]
fn vpsubsb128(a: i8x16, b: i8x16, src: i8x16, mask: u16) -> i8x16;
#[link_name = "llvm.x86.avx512.pmulhu.w.512"]
fn vpmulhuw(a: u16x32, b: u16x32) -> u16x32;
#[link_name = "llvm.x86.avx512.pmulh.w.512"]
fn vpmulhw(a: i16x32, b: i16x32) -> i16x32;
#[link_name = "llvm.x86.avx512.pmul.hr.sw.512"]
fn vpmulhrsw(a: i16x32, b: i16x32) -> i16x32;
@ -10712,11 +10714,6 @@ extern "C" {
#[link_name = "llvm.x86.avx512.packuswb.512"]
fn vpackuswb(a: i16x32, b: i16x32) -> u8x64;
#[link_name = "llvm.x86.avx512.pavg.w.512"]
fn vpavgw(a: u16x32, b: u16x32) -> u16x32;
#[link_name = "llvm.x86.avx512.pavg.b.512"]
fn vpavgb(a: u8x64, b: u8x64) -> u8x64;
#[link_name = "llvm.x86.avx512.psll.w.512"]
fn vpsllw(a: i16x32, count: i16x8) -> i16x32;