Implement AVX2 shift by immediate (slli, srli, srai) with simd_sh{l,r} instead of LLVM intrinsics

This commit is contained in:
Eduardo Sánchez Muñoz 2023-08-12 01:04:53 +02:00 committed by Amanieu d'Antras
parent 819fe11c49
commit f942d69471

View file

@ -2557,7 +2557,11 @@ pub unsafe fn _mm256_sll_epi64(a: __m256i, count: __m128i) -> __m256i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_slli_epi16<const IMM8: i32>(a: __m256i) -> __m256i {
    static_assert_uimm_bits!(IMM8, 8);
    // vpsllw semantics: a shift count >= the 16-bit element width yields
    // all-zero lanes. `simd_shl` with an out-of-range count is undefined
    // behavior, so that case must be handled explicitly before shifting.
    if IMM8 >= 16 {
        _mm256_setzero_si256()
    } else {
        // In-range count: a plain per-lane logical shift left. Unsigned
        // lanes are used so the shift is well-defined for all bit patterns.
        transmute(simd_shl(a.as_u16x16(), u16x16::splat(IMM8 as u16)))
    }
}
/// Shifts packed 32-bit integers in `a` left by `IMM8` while
@ -2571,7 +2575,11 @@ pub unsafe fn _mm256_slli_epi16<const IMM8: i32>(a: __m256i) -> __m256i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_slli_epi32<const IMM8: i32>(a: __m256i) -> __m256i {
    static_assert_uimm_bits!(IMM8, 8);
    // vpslld semantics: a shift count >= the 32-bit element width yields
    // all-zero lanes; `simd_shl` would be UB for such counts, so guard first.
    if IMM8 >= 32 {
        _mm256_setzero_si256()
    } else {
        // In-range count: per-lane logical shift left on unsigned lanes.
        transmute(simd_shl(a.as_u32x8(), u32x8::splat(IMM8 as u32)))
    }
}
/// Shifts packed 64-bit integers in `a` left by `IMM8` while
@ -2585,7 +2593,11 @@ pub unsafe fn _mm256_slli_epi32<const IMM8: i32>(a: __m256i) -> __m256i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_slli_epi64<const IMM8: i32>(a: __m256i) -> __m256i {
    static_assert_uimm_bits!(IMM8, 8);
    // vpsllq semantics: a shift count >= the 64-bit element width yields
    // all-zero lanes; `simd_shl` would be UB for such counts, so guard first.
    if IMM8 >= 64 {
        _mm256_setzero_si256()
    } else {
        // In-range count: per-lane logical shift left on unsigned lanes.
        transmute(simd_shl(a.as_u64x4(), u64x4::splat(IMM8 as u64)))
    }
}
/// Shifts 128-bit lanes in `a` left by `imm8` bytes while shifting in zeros.
@ -2749,7 +2761,7 @@ pub unsafe fn _mm256_sra_epi32(a: __m256i, count: __m128i) -> __m256i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_srai_epi16<const IMM8: i32>(a: __m256i) -> __m256i {
    static_assert_uimm_bits!(IMM8, 8);
    // vpsraw semantics: counts >= 16 behave as a shift by 15 (every result
    // bit is a copy of the sign bit), so clamping to 15 is exact and keeps
    // the `simd_shr` count in range. Signed lanes make the shift arithmetic.
    transmute(simd_shr(a.as_i16x16(), i16x16::splat(IMM8.min(15) as i16)))
}
/// Shifts packed 32-bit integers in `a` right by `IMM8` while
@ -2763,7 +2775,7 @@ pub unsafe fn _mm256_srai_epi16<const IMM8: i32>(a: __m256i) -> __m256i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_srai_epi32<const IMM8: i32>(a: __m256i) -> __m256i {
    static_assert_uimm_bits!(IMM8, 8);
    // vpsrad semantics: counts >= 32 behave as a shift by 31 (sign-fill),
    // so clamping to 31 is exact and keeps the `simd_shr` count in range.
    // Signed lanes make the shift arithmetic.
    transmute(simd_shr(a.as_i32x8(), i32x8::splat(IMM8.min(31))))
}
/// Shifts packed 32-bit integers in `a` right by the amount specified by the
@ -2996,7 +3008,11 @@ pub unsafe fn _mm256_srl_epi64(a: __m256i, count: __m128i) -> __m256i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_srli_epi16<const IMM8: i32>(a: __m256i) -> __m256i {
    static_assert_uimm_bits!(IMM8, 8);
    // vpsrlw semantics: a shift count >= the 16-bit element width yields
    // all-zero lanes; `simd_shr` would be UB for such counts, so guard first.
    if IMM8 >= 16 {
        _mm256_setzero_si256()
    } else {
        // In-range count: unsigned lanes make `simd_shr` a logical
        // (zero-filling) right shift.
        transmute(simd_shr(a.as_u16x16(), u16x16::splat(IMM8 as u16)))
    }
}
/// Shifts packed 32-bit integers in `a` right by `IMM8` while shifting in
@ -3010,7 +3026,11 @@ pub unsafe fn _mm256_srli_epi16<const IMM8: i32>(a: __m256i) -> __m256i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_srli_epi32<const IMM8: i32>(a: __m256i) -> __m256i {
    static_assert_uimm_bits!(IMM8, 8);
    // vpsrld semantics: a shift count >= the 32-bit element width yields
    // all-zero lanes; `simd_shr` would be UB for such counts, so guard first.
    if IMM8 >= 32 {
        _mm256_setzero_si256()
    } else {
        // In-range count: unsigned lanes make `simd_shr` a logical
        // (zero-filling) right shift.
        transmute(simd_shr(a.as_u32x8(), u32x8::splat(IMM8 as u32)))
    }
}
/// Shifts packed 64-bit integers in `a` right by `IMM8` while shifting in
@ -3024,7 +3044,11 @@ pub unsafe fn _mm256_srli_epi32<const IMM8: i32>(a: __m256i) -> __m256i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_srli_epi64<const IMM8: i32>(a: __m256i) -> __m256i {
    static_assert_uimm_bits!(IMM8, 8);
    // vpsrlq semantics: a shift count >= the 64-bit element width yields
    // all-zero lanes; `simd_shr` would be UB for such counts, so guard first.
    if IMM8 >= 64 {
        _mm256_setzero_si256()
    } else {
        // In-range count: unsigned lanes make `simd_shr` a logical
        // (zero-filling) right shift.
        transmute(simd_shr(a.as_u64x4(), u64x4::splat(IMM8 as u64)))
    }
}
/// Shifts packed 32-bit integers in `a` right by the amount specified by
@ -3677,12 +3701,6 @@ extern "C" {
fn pslld(a: i32x8, count: i32x4) -> i32x8;
#[link_name = "llvm.x86.avx2.psll.q"]
fn psllq(a: i64x4, count: i64x2) -> i64x4;
#[link_name = "llvm.x86.avx2.pslli.w"]
fn pslliw(a: i16x16, imm8: i32) -> i16x16;
#[link_name = "llvm.x86.avx2.pslli.d"]
fn psllid(a: i32x8, imm8: i32) -> i32x8;
#[link_name = "llvm.x86.avx2.pslli.q"]
fn pslliq(a: i64x4, imm8: i32) -> i64x4;
#[link_name = "llvm.x86.avx2.psllv.d"]
fn psllvd(a: i32x4, count: i32x4) -> i32x4;
#[link_name = "llvm.x86.avx2.psllv.d.256"]
@ -3695,10 +3713,6 @@ extern "C" {
fn psraw(a: i16x16, count: i16x8) -> i16x16;
#[link_name = "llvm.x86.avx2.psra.d"]
fn psrad(a: i32x8, count: i32x4) -> i32x8;
#[link_name = "llvm.x86.avx2.psrai.w"]
fn psraiw(a: i16x16, imm8: i32) -> i16x16;
#[link_name = "llvm.x86.avx2.psrai.d"]
fn psraid(a: i32x8, imm8: i32) -> i32x8;
#[link_name = "llvm.x86.avx2.psrav.d"]
fn psravd(a: i32x4, count: i32x4) -> i32x4;
#[link_name = "llvm.x86.avx2.psrav.d.256"]
@ -3709,12 +3723,6 @@ extern "C" {
fn psrld(a: i32x8, count: i32x4) -> i32x8;
#[link_name = "llvm.x86.avx2.psrl.q"]
fn psrlq(a: i64x4, count: i64x2) -> i64x4;
#[link_name = "llvm.x86.avx2.psrli.w"]
fn psrliw(a: i16x16, imm8: i32) -> i16x16;
#[link_name = "llvm.x86.avx2.psrli.d"]
fn psrlid(a: i32x8, imm8: i32) -> i32x8;
#[link_name = "llvm.x86.avx2.psrli.q"]
fn psrliq(a: i64x4, imm8: i32) -> i64x4;
#[link_name = "llvm.x86.avx2.psrlv.d"]
fn psrlvd(a: i32x4, count: i32x4) -> i32x4;
#[link_name = "llvm.x86.avx2.psrlv.d.256"]