Implement AVX512F 32-bit shift by immediate (srai_epi32) with simd_shr instead of LLVM intrinsics

This commit is contained in:
Eduardo Sánchez Muñoz 2023-08-12 12:15:11 +02:00 committed by Amanieu d'Antras
parent 29ba594589
commit 01ff55e216

(Diff rendered below; removed and added lines are shown without +/- markers.)
@ -18227,9 +18227,7 @@ pub unsafe fn _mm_maskz_sra_epi64(k: __mmask8, a: __m128i, count: __m128i) -> __
#[rustc_legacy_const_generics(1)]
pub unsafe fn _mm512_srai_epi32<const IMM8: u32>(a: __m512i) -> __m512i {
    static_assert_uimm_bits!(IMM8, 8);
    // Arithmetic right shift of each 32-bit lane by IMM8, shifting in sign bits.
    // `vpsrad` treats counts >= 32 as filling with the sign bit; `simd_shr` would be
    // UB for a shift >= the lane width, so clamp to 31 — an arithmetic shift by 31
    // already yields all sign bits, so the result is identical to the hardware behavior.
    transmute(simd_shr(a.as_i32x16(), i32x16::splat(IMM8.min(31) as i32)))
}
/// Shift packed 32-bit integers in a right by imm8 while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@ -18245,8 +18243,7 @@ pub unsafe fn _mm512_mask_srai_epi32<const IMM8: u32>(
a: __m512i,
) -> __m512i {
static_assert_uimm_bits!(IMM8, 8);
let a = a.as_i32x16();
let r = vpsraid512(a, IMM8);
let r = simd_shr(a.as_i32x16(), i32x16::splat(IMM8.min(31) as i32));
transmute(simd_select_bitmask(k, r, src.as_i32x16()))
}
@ -18259,9 +18256,8 @@ pub unsafe fn _mm512_mask_srai_epi32<const IMM8: u32>(
#[rustc_legacy_const_generics(2)]
pub unsafe fn _mm512_maskz_srai_epi32<const IMM8: u32>(k: __mmask16, a: __m512i) -> __m512i {
    static_assert_uimm_bits!(IMM8, 8);
    // Clamp to 31: counts >= 32 saturate to all-sign-bits on hardware, and a
    // `simd_shr` by >= lane width would be UB. Shifting by 31 gives the same result.
    let r = simd_shr(a.as_i32x16(), i32x16::splat(IMM8.min(31) as i32));
    let zero = i32x16::splat(0);
    // Zero-masking: lanes with a clear mask bit become 0.
    transmute(simd_select_bitmask(k, r, zero))
}
@ -18277,8 +18273,7 @@ pub unsafe fn _mm256_mask_srai_epi32<const IMM8: u32>(
k: __mmask8,
a: __m256i,
) -> __m256i {
let imm8 = IMM8 as i32;
let r = psraid256(a.as_i32x8(), imm8);
let r = simd_shr(a.as_i32x8(), i32x8::splat(IMM8.min(31) as i32));
transmute(simd_select_bitmask(k, r, src.as_i32x8()))
}
@ -18290,9 +18285,8 @@ pub unsafe fn _mm256_mask_srai_epi32<const IMM8: u32>(
#[cfg_attr(test, assert_instr(vpsrad, IMM8 = 1))]
#[rustc_legacy_const_generics(2)]
pub unsafe fn _mm256_maskz_srai_epi32<const IMM8: u32>(k: __mmask8, a: __m256i) -> __m256i {
    // Clamp to 31: counts >= 32 saturate to all-sign-bits on hardware, and a
    // `simd_shr` by >= lane width would be UB. Shifting by 31 gives the same result.
    let r = simd_shr(a.as_i32x8(), i32x8::splat(IMM8.min(31) as i32));
    let zero = i32x8::splat(0);
    // Zero-masking: lanes with a clear mask bit become 0.
    transmute(simd_select_bitmask(k, r, zero))
}
@ -18308,8 +18302,7 @@ pub unsafe fn _mm_mask_srai_epi32<const IMM8: u32>(
k: __mmask8,
a: __m128i,
) -> __m128i {
let imm8 = IMM8 as i32;
let r = psraid128(a.as_i32x4(), imm8);
let r = simd_shr(a.as_i32x4(), i32x4::splat(IMM8.min(31) as i32));
transmute(simd_select_bitmask(k, r, src.as_i32x4()))
}
@ -18321,9 +18314,8 @@ pub unsafe fn _mm_mask_srai_epi32<const IMM8: u32>(
#[cfg_attr(test, assert_instr(vpsrad, IMM8 = 1))]
#[rustc_legacy_const_generics(2)]
pub unsafe fn _mm_maskz_srai_epi32<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i {
    // Clamp to 31: counts >= 32 saturate to all-sign-bits on hardware, and a
    // `simd_shr` by >= lane width would be UB. Shifting by 31 gives the same result.
    let r = simd_shr(a.as_i32x4(), i32x4::splat(IMM8.min(31) as i32));
    let zero = i32x4::splat(0);
    // Zero-masking: lanes with a clear mask bit become 0.
    transmute(simd_select_bitmask(k, r, zero))
}
@ -38548,13 +38540,6 @@ extern "C" {
#[link_name = "llvm.x86.avx512.psra.q.128"]
fn vpsraq128(a: i64x2, count: i64x2) -> i64x2;
#[link_name = "llvm.x86.avx512.psrai.d.512"]
fn vpsraid512(a: i32x16, imm8: u32) -> i32x16;
#[link_name = "llvm.x86.avx2.psrai.d"]
fn psraid256(a: i32x8, imm8: i32) -> i32x8;
#[link_name = "llvm.x86.sse2.psrai.d"]
fn psraid128(a: i32x4, imm8: i32) -> i32x4;
#[link_name = "llvm.x86.avx512.psrai.q.512"]
fn vpsraiq(a: i64x8, imm8: u32) -> i64x8;
#[link_name = "llvm.x86.avx512.psrai.q.256"]