Implement AVX512BW 16-bit shift by immediate (slli_epi16) with simd_shl instead of LLVM intrinsics

This commit is contained in:
Eduardo Sánchez Muñoz 2023-08-12 12:30:27 +02:00 committed by Amanieu d'Antras
parent 4b2efda9a9
commit 0c0f72ee7f

View file

@@ -5339,9 +5339,11 @@ pub unsafe fn _mm_maskz_sll_epi16(k: __mmask8, a: __m128i, count: __m128i) -> __
#[rustc_legacy_const_generics(1)]
pub unsafe fn _mm512_slli_epi16<const IMM8: u32>(a: __m512i) -> __m512i {
    static_assert_uimm_bits!(IMM8, 8);
    if IMM8 >= 16 {
        // A shift count >= the 16-bit element width zeroes every lane,
        // matching hardware vpsllw behavior with a large immediate.
        _mm512_setzero_si512()
    } else {
        // Plain vector shift; `simd_shl` is well-defined here because the
        // count is guaranteed < 16 for every u16 lane.
        transmute(simd_shl(a.as_u16x32(), u16x32::splat(IMM8 as u16)))
    }
}
/// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -5357,9 +5359,12 @@ pub unsafe fn _mm512_mask_slli_epi16<const IMM8: u32>(
a: __m512i,
) -> __m512i {
    static_assert_uimm_bits!(IMM8, 8);
    // Counts >= 16 zero every 16-bit lane; otherwise do a plain vector shift.
    let shf = if IMM8 >= 16 {
        u16x32::splat(0)
    } else {
        simd_shl(a.as_u16x32(), u16x32::splat(IMM8 as u16))
    };
    // Writemask: lanes whose mask bit is clear are copied from `src`.
    transmute(simd_select_bitmask(k, shf, src.as_u16x32()))
}
/// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -5371,10 +5376,13 @@ pub unsafe fn _mm512_mask_slli_epi16<const IMM8: u32>(
#[rustc_legacy_const_generics(2)]
pub unsafe fn _mm512_maskz_slli_epi16<const IMM8: u32>(k: __mmask32, a: __m512i) -> __m512i {
    static_assert_uimm_bits!(IMM8, 8);
    if IMM8 >= 16 {
        // A shift count >= the element width zeroes the whole vector
        // regardless of the mask, so short-circuit.
        _mm512_setzero_si512()
    } else {
        let shf = simd_shl(a.as_u16x32(), u16x32::splat(IMM8 as u16));
        // Zeromask: lanes whose mask bit is clear are zeroed.
        let zero = u16x32::splat(0);
        transmute(simd_select_bitmask(k, shf, zero))
    }
}
/// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -5390,9 +5398,12 @@ pub unsafe fn _mm256_mask_slli_epi16<const IMM8: u32>(
a: __m256i,
) -> __m256i {
    static_assert_uimm_bits!(IMM8, 8);
    // Counts >= 16 zero every 16-bit lane; otherwise do a plain vector shift.
    let shf = if IMM8 >= 16 {
        u16x16::splat(0)
    } else {
        simd_shl(a.as_u16x16(), u16x16::splat(IMM8 as u16))
    };
    // Writemask: lanes whose mask bit is clear are copied from `src`.
    transmute(simd_select_bitmask(k, shf, src.as_u16x16()))
}
/// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -5404,10 +5415,13 @@ pub unsafe fn _mm256_mask_slli_epi16<const IMM8: u32>(
#[rustc_legacy_const_generics(2)]
pub unsafe fn _mm256_maskz_slli_epi16<const IMM8: u32>(k: __mmask16, a: __m256i) -> __m256i {
    static_assert_uimm_bits!(IMM8, 8);
    if IMM8 >= 16 {
        // A shift count >= the element width zeroes the whole vector
        // regardless of the mask, so short-circuit.
        _mm256_setzero_si256()
    } else {
        let shf = simd_shl(a.as_u16x16(), u16x16::splat(IMM8 as u16));
        // Zeromask: lanes whose mask bit is clear are zeroed.
        let zero = u16x16::splat(0);
        transmute(simd_select_bitmask(k, shf, zero))
    }
}
/// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -5423,9 +5437,12 @@ pub unsafe fn _mm_mask_slli_epi16<const IMM8: u32>(
a: __m128i,
) -> __m128i {
    static_assert_uimm_bits!(IMM8, 8);
    // Counts >= 16 zero every 16-bit lane; otherwise do a plain vector shift.
    let shf = if IMM8 >= 16 {
        u16x8::splat(0)
    } else {
        simd_shl(a.as_u16x8(), u16x8::splat(IMM8 as u16))
    };
    // Writemask: lanes whose mask bit is clear are copied from `src`.
    transmute(simd_select_bitmask(k, shf, src.as_u16x8()))
}
/// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -5437,10 +5454,13 @@ pub unsafe fn _mm_mask_slli_epi16<const IMM8: u32>(
#[rustc_legacy_const_generics(2)]
pub unsafe fn _mm_maskz_slli_epi16<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i {
    static_assert_uimm_bits!(IMM8, 8);
    if IMM8 >= 16 {
        // A shift count >= the element width zeroes the whole vector
        // regardless of the mask, so short-circuit.
        _mm_setzero_si128()
    } else {
        let shf = simd_shl(a.as_u16x8(), u16x8::splat(IMM8 as u16));
        // Zeromask: lanes whose mask bit is clear are zeroed.
        let zero = u16x8::splat(0);
        transmute(simd_select_bitmask(k, shf, zero))
    }
}
/// Shift packed 16-bit integers in a left by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst.
@@ -9965,13 +9985,6 @@ extern "C" {
#[link_name = "llvm.x86.avx512.psll.w.512"]
fn vpsllw(a: i16x32, count: i16x8) -> i16x32;
// NOTE: the shift-by-immediate LLVM intrinsic declarations (vpslliw,
// pslliw256, pslliw128) were removed by this change — the slli wrappers
// now lower through `simd_shl` instead, so the declarations are dead.
#[link_name = "llvm.x86.avx512.psllv.w.512"]
fn vpsllvw(a: i16x32, b: i16x32) -> i16x32;