Implement SSE2 shift by immediate (slli, srli, srai) with simd_sh{l,r} instead of LLVM intrinsics

This commit is contained in:
Eduardo Sánchez Muñoz 2023-08-12 00:52:13 +02:00 committed by Amanieu d'Antras
parent 3b7dc00f66
commit 819fe11c49

View file

@ -501,7 +501,11 @@ pub unsafe fn _mm_bsrli_si128<const IMM8: i32>(a: __m128i) -> __m128i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_slli_epi16<const IMM8: i32>(a: __m128i) -> __m128i {
static_assert_uimm_bits!(IMM8, 8);
transmute(pslliw(a.as_i16x8(), IMM8))
// A count of 16 or more (the lane width in bits) returns the all-zero vector
// and never reaches `simd_shl`.
if IMM8 >= 16 {
_mm_setzero_si128()
} else {
// In-range count: shift each of the eight unsigned 16-bit lanes left by IMM8.
transmute(simd_shl(a.as_u16x8(), u16x8::splat(IMM8 as u16)))
}
}
/// Shifts packed 16-bit integers in `a` left by `count` while shifting in
@ -526,7 +530,11 @@ pub unsafe fn _mm_sll_epi16(a: __m128i, count: __m128i) -> __m128i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_slli_epi32<const IMM8: i32>(a: __m128i) -> __m128i {
static_assert_uimm_bits!(IMM8, 8);
transmute(psllid(a.as_i32x4(), IMM8))
// A count of 32 or more (the lane width in bits) returns the all-zero vector
// and never reaches `simd_shl`.
if IMM8 >= 32 {
_mm_setzero_si128()
} else {
// In-range count: shift each of the four unsigned 32-bit lanes left by IMM8.
transmute(simd_shl(a.as_u32x4(), u32x4::splat(IMM8 as u32)))
}
}
/// Shifts packed 32-bit integers in `a` left by `count` while shifting in
@ -551,7 +559,11 @@ pub unsafe fn _mm_sll_epi32(a: __m128i, count: __m128i) -> __m128i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_slli_epi64<const IMM8: i32>(a: __m128i) -> __m128i {
static_assert_uimm_bits!(IMM8, 8);
transmute(pslliq(a.as_i64x2(), IMM8))
// A count of 64 or more (the lane width in bits) returns the all-zero vector
// and never reaches `simd_shl`.
if IMM8 >= 64 {
_mm_setzero_si128()
} else {
// In-range count: shift each of the two unsigned 64-bit lanes left by IMM8.
transmute(simd_shl(a.as_u64x2(), u64x2::splat(IMM8 as u64)))
}
}
/// Shifts packed 64-bit integers in `a` left by `count` while shifting in
@ -577,7 +589,7 @@ pub unsafe fn _mm_sll_epi64(a: __m128i, count: __m128i) -> __m128i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_srai_epi16<const IMM8: i32>(a: __m128i) -> __m128i {
static_assert_uimm_bits!(IMM8, 8);
transmute(psraiw(a.as_i16x8(), IMM8))
// The count is clamped to 15, so any larger IMM8 behaves like a shift by 15.
// The lanes are signed (`i16x8`), so the shift brings in sign bits and a
// 15-bit shift leaves each lane filled with copies of its sign bit.
transmute(simd_shr(a.as_i16x8(), i16x8::splat(IMM8.min(15) as i16)))
}
/// Shifts packed 16-bit integers in `a` right by `count` while shifting in sign
@ -603,7 +615,7 @@ pub unsafe fn _mm_sra_epi16(a: __m128i, count: __m128i) -> __m128i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_srai_epi32<const IMM8: i32>(a: __m128i) -> __m128i {
static_assert_uimm_bits!(IMM8, 8);
transmute(psraid(a.as_i32x4(), IMM8))
// The count is clamped to 31, so any larger IMM8 behaves like a shift by 31.
// The lanes are signed (`i32x4`), so the shift brings in sign bits and a
// 31-bit shift leaves each lane filled with copies of its sign bit.
transmute(simd_shr(a.as_i32x4(), i32x4::splat(IMM8.min(31))))
}
/// Shifts packed 32-bit integers in `a` right by `count` while shifting in sign
@ -680,7 +692,11 @@ unsafe fn _mm_srli_si128_impl<const IMM8: i32>(a: __m128i) -> __m128i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_srli_epi16<const IMM8: i32>(a: __m128i) -> __m128i {
static_assert_uimm_bits!(IMM8, 8);
transmute(psrliw(a.as_i16x8(), IMM8))
// A count of 16 or more (the lane width in bits) returns the all-zero vector
// and never reaches `simd_shr`.
if IMM8 >= 16 {
_mm_setzero_si128()
} else {
// The lanes are viewed as unsigned (`u16x8`), so the right shift is logical.
transmute(simd_shr(a.as_u16x8(), u16x8::splat(IMM8 as u16)))
}
}
/// Shifts packed 16-bit integers in `a` right by `count` while shifting in
@ -706,7 +722,11 @@ pub unsafe fn _mm_srl_epi16(a: __m128i, count: __m128i) -> __m128i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_srli_epi32<const IMM8: i32>(a: __m128i) -> __m128i {
static_assert_uimm_bits!(IMM8, 8);
transmute(psrlid(a.as_i32x4(), IMM8))
// A count of 32 or more (the lane width in bits) returns the all-zero vector
// and never reaches `simd_shr`.
if IMM8 >= 32 {
_mm_setzero_si128()
} else {
// The lanes are viewed as unsigned (`u32x4`), so the right shift is logical.
transmute(simd_shr(a.as_u32x4(), u32x4::splat(IMM8 as u32)))
}
}
/// Shifts packed 32-bit integers in `a` right by `count` while shifting in
@ -732,7 +752,11 @@ pub unsafe fn _mm_srl_epi32(a: __m128i, count: __m128i) -> __m128i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_srli_epi64<const IMM8: i32>(a: __m128i) -> __m128i {
static_assert_uimm_bits!(IMM8, 8);
transmute(psrliq(a.as_i64x2(), IMM8))
// The lanes are 64 bits wide, so only a count of 64 or more shifts every bit
// out and returns zero. Counts in 32..=63 must still perform a real shift, so
// the bound is the lane width, the same `IMM8 >= 64` used by `_mm_slli_epi64`.
if IMM8 >= 64 {
_mm_setzero_si128()
} else {
// The lanes are viewed as unsigned (`u64x2`), so the right shift is logical.
transmute(simd_shr(a.as_u64x2(), u64x2::splat(IMM8 as u64)))
}
}
/// Shifts packed 64-bit integers in `a` right by `count` while shifting in
@ -2816,36 +2840,20 @@ extern "C" {
fn pmuludq(a: u32x4, b: u32x4) -> u64x2;
#[link_name = "llvm.x86.sse2.psad.bw"]
fn psadbw(a: u8x16, b: u8x16) -> u64x2;
#[link_name = "llvm.x86.sse2.pslli.w"]
fn pslliw(a: i16x8, imm8: i32) -> i16x8;
#[link_name = "llvm.x86.sse2.psll.w"]
fn psllw(a: i16x8, count: i16x8) -> i16x8;
#[link_name = "llvm.x86.sse2.pslli.d"]
fn psllid(a: i32x4, imm8: i32) -> i32x4;
#[link_name = "llvm.x86.sse2.psll.d"]
fn pslld(a: i32x4, count: i32x4) -> i32x4;
#[link_name = "llvm.x86.sse2.pslli.q"]
fn pslliq(a: i64x2, imm8: i32) -> i64x2;
#[link_name = "llvm.x86.sse2.psll.q"]
fn psllq(a: i64x2, count: i64x2) -> i64x2;
#[link_name = "llvm.x86.sse2.psrai.w"]
fn psraiw(a: i16x8, imm8: i32) -> i16x8;
#[link_name = "llvm.x86.sse2.psra.w"]
fn psraw(a: i16x8, count: i16x8) -> i16x8;
#[link_name = "llvm.x86.sse2.psrai.d"]
fn psraid(a: i32x4, imm8: i32) -> i32x4;
#[link_name = "llvm.x86.sse2.psra.d"]
fn psrad(a: i32x4, count: i32x4) -> i32x4;
#[link_name = "llvm.x86.sse2.psrli.w"]
fn psrliw(a: i16x8, imm8: i32) -> i16x8;
#[link_name = "llvm.x86.sse2.psrl.w"]
fn psrlw(a: i16x8, count: i16x8) -> i16x8;
#[link_name = "llvm.x86.sse2.psrli.d"]
fn psrlid(a: i32x4, imm8: i32) -> i32x4;
#[link_name = "llvm.x86.sse2.psrl.d"]
fn psrld(a: i32x4, count: i32x4) -> i32x4;
#[link_name = "llvm.x86.sse2.psrli.q"]
fn psrliq(a: i64x2, imm8: i32) -> i64x2;
#[link_name = "llvm.x86.sse2.psrl.q"]
fn psrlq(a: i64x2, count: i64x2) -> i64x2;
#[link_name = "llvm.x86.sse2.cvtdq2ps"]