From 819fe11c495bba563d59cc1302b4104f794c0be3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eduardo=20S=C3=A1nchez=20Mu=C3=B1oz?= Date: Sat, 12 Aug 2023 00:52:13 +0200 Subject: [PATCH] Implement SSE2 shift by immediate (slli, srli, srai) with `simd_sh{l,r}` instead of LLVM intrinsics --- .../stdarch/crates/core_arch/src/x86/sse2.rs | 56 +++++++++++-------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/library/stdarch/crates/core_arch/src/x86/sse2.rs b/library/stdarch/crates/core_arch/src/x86/sse2.rs index 342423c84ff9..1136747ea436 100644 --- a/library/stdarch/crates/core_arch/src/x86/sse2.rs +++ b/library/stdarch/crates/core_arch/src/x86/sse2.rs @@ -501,7 +501,11 @@ pub unsafe fn _mm_bsrli_si128(a: __m128i) -> __m128i { #[stable(feature = "simd_x86", since = "1.27.0")] pub unsafe fn _mm_slli_epi16(a: __m128i) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - transmute(pslliw(a.as_i16x8(), IMM8)) + if IMM8 >= 16 { + _mm_setzero_si128() + } else { + transmute(simd_shl(a.as_u16x8(), u16x8::splat(IMM8 as u16))) + } } /// Shifts packed 16-bit integers in `a` left by `count` while shifting in @@ -526,7 +530,11 @@ pub unsafe fn _mm_sll_epi16(a: __m128i, count: __m128i) -> __m128i { #[stable(feature = "simd_x86", since = "1.27.0")] pub unsafe fn _mm_slli_epi32(a: __m128i) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - transmute(psllid(a.as_i32x4(), IMM8)) + if IMM8 >= 32 { + _mm_setzero_si128() + } else { + transmute(simd_shl(a.as_u32x4(), u32x4::splat(IMM8 as u32))) + } } /// Shifts packed 32-bit integers in `a` left by `count` while shifting in @@ -551,7 +559,11 @@ pub unsafe fn _mm_sll_epi32(a: __m128i, count: __m128i) -> __m128i { #[stable(feature = "simd_x86", since = "1.27.0")] pub unsafe fn _mm_slli_epi64(a: __m128i) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - transmute(pslliq(a.as_i64x2(), IMM8)) + if IMM8 >= 64 { + _mm_setzero_si128() + } else { + transmute(simd_shl(a.as_u64x2(), u64x2::splat(IMM8 as u64))) + } } /// Shifts packed 64-bit integers in `a` left by `count` while shifting in @@ -577,7 +589,7 @@ pub unsafe fn _mm_sll_epi64(a: __m128i, count: __m128i) -> __m128i { #[stable(feature = "simd_x86", since = "1.27.0")] pub unsafe fn _mm_srai_epi16(a: __m128i) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - transmute(psraiw(a.as_i16x8(), IMM8)) + transmute(simd_shr(a.as_i16x8(), i16x8::splat(IMM8.min(15) as i16))) } /// Shifts packed 16-bit integers in `a` right by `count` while shifting in sign @@ -603,7 +615,7 @@ pub unsafe fn _mm_sra_epi16(a: __m128i, count: __m128i) -> __m128i { #[stable(feature = "simd_x86", since = "1.27.0")] pub unsafe fn _mm_srai_epi32(a: __m128i) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - transmute(psraid(a.as_i32x4(), IMM8)) + transmute(simd_shr(a.as_i32x4(), i32x4::splat(IMM8.min(31)))) } /// Shifts packed 32-bit integers in `a` right by `count` while shifting in sign @@ -680,7 +692,11 @@ unsafe fn _mm_srli_si128_impl(a: __m128i) -> __m128i { #[stable(feature = "simd_x86", since = "1.27.0")] pub unsafe fn _mm_srli_epi16(a: __m128i) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - transmute(psrliw(a.as_i16x8(), IMM8)) + if IMM8 >= 16 { + _mm_setzero_si128() + } else { + transmute(simd_shr(a.as_u16x8(), u16x8::splat(IMM8 as u16))) + } } /// Shifts packed 16-bit integers in `a` right by `count` while shifting in @@ -706,7 +722,11 @@ pub unsafe fn _mm_srl_epi16(a: __m128i, count: __m128i) -> __m128i { #[stable(feature = "simd_x86", since = "1.27.0")] pub unsafe fn _mm_srli_epi32(a: __m128i) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - transmute(psrlid(a.as_i32x4(), IMM8)) + if IMM8 >= 32 { + _mm_setzero_si128() + } else { + transmute(simd_shr(a.as_u32x4(), u32x4::splat(IMM8 as u32))) + } } /// Shifts packed 32-bit integers in `a` right by `count` while shifting in @@ -732,7 +752,11 @@ pub unsafe fn _mm_srl_epi32(a: __m128i, count: __m128i) -> __m128i { #[stable(feature = "simd_x86", since = "1.27.0")] pub unsafe fn _mm_srli_epi64(a: __m128i) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - transmute(psrliq(a.as_i64x2(), IMM8)) + if IMM8 >= 32 { + _mm_setzero_si128() + } else { + transmute(simd_shr(a.as_u64x2(), u64x2::splat(IMM8 as u64))) + } } /// Shifts packed 64-bit integers in `a` right by `count` while shifting in @@ -2816,36 +2840,20 @@ extern "C" { fn pmuludq(a: u32x4, b: u32x4) -> u64x2; #[link_name = "llvm.x86.sse2.psad.bw"] fn psadbw(a: u8x16, b: u8x16) -> u64x2; - #[link_name = "llvm.x86.sse2.pslli.w"] - fn pslliw(a: i16x8, imm8: i32) -> i16x8; #[link_name = "llvm.x86.sse2.psll.w"] fn psllw(a: i16x8, count: i16x8) -> i16x8; - #[link_name = "llvm.x86.sse2.pslli.d"] - fn psllid(a: i32x4, imm8: i32) -> i32x4; #[link_name = "llvm.x86.sse2.psll.d"] fn pslld(a: i32x4, count: i32x4) -> i32x4; - #[link_name = "llvm.x86.sse2.pslli.q"] - fn pslliq(a: i64x2, imm8: i32) -> i64x2; #[link_name = "llvm.x86.sse2.psll.q"] fn psllq(a: i64x2, count: i64x2) -> i64x2; - #[link_name = "llvm.x86.sse2.psrai.w"] - fn psraiw(a: i16x8, imm8: i32) -> i16x8; #[link_name = "llvm.x86.sse2.psra.w"] fn psraw(a: i16x8, count: i16x8) -> i16x8; - #[link_name = "llvm.x86.sse2.psrai.d"] - fn psraid(a: i32x4, imm8: i32) -> i32x4; #[link_name = "llvm.x86.sse2.psra.d"] fn psrad(a: i32x4, count: i32x4) -> i32x4; - #[link_name = "llvm.x86.sse2.psrli.w"] - fn psrliw(a: i16x8, imm8: i32) -> i16x8; #[link_name = "llvm.x86.sse2.psrl.w"] fn psrlw(a: i16x8, count: i16x8) -> i16x8; - #[link_name = "llvm.x86.sse2.psrli.d"] - fn psrlid(a: i32x4, imm8: i32) -> i32x4; #[link_name = "llvm.x86.sse2.psrl.d"] fn psrld(a: i32x4, count: i32x4) -> i32x4; - #[link_name = "llvm.x86.sse2.psrli.q"] - fn psrliq(a: i64x2, imm8: i32) -> i64x2; #[link_name = "llvm.x86.sse2.psrl.q"] fn psrlq(a: i64x2, count: i64x2) -> i64x2; #[link_name = "llvm.x86.sse2.cvtdq2ps"]