Implement AVX512F 64-bit shift by immediate (srai_epi64) with simd_shr instead of LLVM intrinsics

This commit is contained in:
Eduardo Sánchez Muñoz 2023-08-12 12:20:42 +02:00 committed by Amanieu d'Antras
parent 01ff55e216
commit 4b2efda9a9

View file

@@ -18328,9 +18328,7 @@ pub unsafe fn _mm_maskz_srai_epi32<const IMM8: u32>(k: __mmask8, a: __m128i) ->
#[rustc_legacy_const_generics(1)]
pub unsafe fn _mm512_srai_epi64<const IMM8: u32>(a: __m512i) -> __m512i {
static_assert_uimm_bits!(IMM8, 8);
let a = a.as_i64x8();
let r = vpsraiq(a, IMM8);
transmute(r)
transmute(simd_shr(a.as_i64x8(), i64x8::splat(IMM8.min(63) as i64)))
}
/// Shift packed 64-bit integers in a right by imm8 while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -18346,8 +18344,7 @@ pub unsafe fn _mm512_mask_srai_epi64<const IMM8: u32>(
a: __m512i,
) -> __m512i {
static_assert_uimm_bits!(IMM8, 8);
let a = a.as_i64x8();
let shf = vpsraiq(a, IMM8);
let shf = simd_shr(a.as_i64x8(), i64x8::splat(IMM8.min(63) as i64));
transmute(simd_select_bitmask(k, shf, src.as_i64x8()))
}
@@ -18360,9 +18357,8 @@ pub unsafe fn _mm512_mask_srai_epi64<const IMM8: u32>(
#[rustc_legacy_const_generics(2)]
pub unsafe fn _mm512_maskz_srai_epi64<const IMM8: u32>(k: __mmask8, a: __m512i) -> __m512i {
static_assert_uimm_bits!(IMM8, 8);
let a = a.as_i64x8();
let shf = vpsraiq(a, IMM8);
let zero = _mm512_setzero_si512().as_i64x8();
let shf = simd_shr(a.as_i64x8(), i64x8::splat(IMM8.min(63) as i64));
let zero = i64x8::splat(0);
transmute(simd_select_bitmask(k, shf, zero))
}
@@ -18375,9 +18371,7 @@ pub unsafe fn _mm512_maskz_srai_epi64<const IMM8: u32>(k: __mmask8, a: __m512i)
#[rustc_legacy_const_generics(1)]
pub unsafe fn _mm256_srai_epi64<const IMM8: u32>(a: __m256i) -> __m256i {
static_assert_uimm_bits!(IMM8, 8);
let a = a.as_i64x4();
let r = vpsraiq256(a, IMM8);
transmute(r)
transmute(simd_shr(a.as_i64x4(), i64x4::splat(IMM8.min(63) as i64)))
}
/// Shift packed 64-bit integers in a right by imm8 while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -18393,8 +18387,7 @@ pub unsafe fn _mm256_mask_srai_epi64<const IMM8: u32>(
a: __m256i,
) -> __m256i {
static_assert_uimm_bits!(IMM8, 8);
let a = a.as_i64x4();
let shf = vpsraiq256(a, IMM8);
let shf = simd_shr(a.as_i64x4(), i64x4::splat(IMM8.min(63) as i64));
transmute(simd_select_bitmask(k, shf, src.as_i64x4()))
}
@@ -18407,9 +18400,8 @@ pub unsafe fn _mm256_mask_srai_epi64<const IMM8: u32>(
#[rustc_legacy_const_generics(2)]
pub unsafe fn _mm256_maskz_srai_epi64<const IMM8: u32>(k: __mmask8, a: __m256i) -> __m256i {
static_assert_uimm_bits!(IMM8, 8);
let a = a.as_i64x4();
let shf = vpsraiq256(a, IMM8);
let zero = _mm256_setzero_si256().as_i64x4();
let shf = simd_shr(a.as_i64x4(), i64x4::splat(IMM8.min(63) as i64));
let zero = i64x4::splat(0);
transmute(simd_select_bitmask(k, shf, zero))
}
@@ -18422,9 +18414,7 @@ pub unsafe fn _mm256_maskz_srai_epi64<const IMM8: u32>(k: __mmask8, a: __m256i)
#[rustc_legacy_const_generics(1)]
pub unsafe fn _mm_srai_epi64<const IMM8: u32>(a: __m128i) -> __m128i {
static_assert_uimm_bits!(IMM8, 8);
let a = a.as_i64x2();
let r = vpsraiq128(a, IMM8);
transmute(r)
transmute(simd_shr(a.as_i64x2(), i64x2::splat(IMM8.min(63) as i64)))
}
/// Shift packed 64-bit integers in a right by imm8 while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -18440,8 +18430,7 @@ pub unsafe fn _mm_mask_srai_epi64<const IMM8: u32>(
a: __m128i,
) -> __m128i {
static_assert_uimm_bits!(IMM8, 8);
let a = a.as_i64x2();
let shf = vpsraiq128(a, IMM8);
let shf = simd_shr(a.as_i64x2(), i64x2::splat(IMM8.min(63) as i64));
transmute(simd_select_bitmask(k, shf, src.as_i64x2()))
}
@@ -18454,9 +18443,8 @@ pub unsafe fn _mm_mask_srai_epi64<const IMM8: u32>(
#[rustc_legacy_const_generics(2)]
pub unsafe fn _mm_maskz_srai_epi64<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i {
static_assert_uimm_bits!(IMM8, 8);
let a = a.as_i64x2();
let shf = vpsraiq128(a, IMM8);
let zero = _mm_setzero_si128().as_i64x2();
let shf = simd_shr(a.as_i64x2(), i64x2::splat(IMM8.min(63) as i64));
let zero = i64x2::splat(0);
transmute(simd_select_bitmask(k, shf, zero))
}
@@ -38540,13 +38528,6 @@ extern "C" {
#[link_name = "llvm.x86.avx512.psra.q.128"]
fn vpsraq128(a: i64x2, count: i64x2) -> i64x2;
#[link_name = "llvm.x86.avx512.psrai.q.512"]
fn vpsraiq(a: i64x8, imm8: u32) -> i64x8;
#[link_name = "llvm.x86.avx512.psrai.q.256"]
fn vpsraiq256(a: i64x4, imm8: u32) -> i64x4;
#[link_name = "llvm.x86.avx512.psrai.q.128"]
fn vpsraiq128(a: i64x2, imm8: u32) -> i64x2;
#[link_name = "llvm.x86.avx512.psrav.d.512"]
fn vpsravd(a: i32x16, count: i32x16) -> i32x16;