From 0c0f72ee7f4816581346a0b03059f900512a70f6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eduardo=20S=C3=A1nchez=20Mu=C3=B1oz?=
Date: Sat, 12 Aug 2023 12:30:27 +0200
Subject: [PATCH] Implement AVX512BW 16-bit shift by immediate (slli_epi16)
 with `simd_shl` instead of LLVM intrinsics

---
 .../crates/core_arch/src/x86/avx512bw.rs | 75 +++++++++++--------
 1 file changed, 44 insertions(+), 31 deletions(-)

diff --git a/library/stdarch/crates/core_arch/src/x86/avx512bw.rs b/library/stdarch/crates/core_arch/src/x86/avx512bw.rs
index 1924d42fb475..f55526f4b763 100644
--- a/library/stdarch/crates/core_arch/src/x86/avx512bw.rs
+++ b/library/stdarch/crates/core_arch/src/x86/avx512bw.rs
@@ -5339,9 +5339,11 @@ pub unsafe fn _mm_maskz_sll_epi16(k: __mmask8, a: __m128i, count: __m128i) -> __
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn _mm512_slli_epi16<const IMM8: u32>(a: __m512i) -> __m512i {
     static_assert_uimm_bits!(IMM8, 8);
-    let a = a.as_i16x32();
-    let r = vpslliw(a, IMM8);
-    transmute(r)
+    if IMM8 >= 16 {
+        _mm512_setzero_si512()
+    } else {
+        transmute(simd_shl(a.as_u16x32(), u16x32::splat(IMM8 as u16)))
+    }
 }
 
 /// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -5357,9 +5359,12 @@ pub unsafe fn _mm512_mask_slli_epi16(
     a: __m512i,
 ) -> __m512i {
     static_assert_uimm_bits!(IMM8, 8);
-    let a = a.as_i16x32();
-    let shf = vpslliw(a, IMM8);
-    transmute(simd_select_bitmask(k, shf, src.as_i16x32()))
+    let shf = if IMM8 >= 16 {
+        u16x32::splat(0)
+    } else {
+        simd_shl(a.as_u16x32(), u16x32::splat(IMM8 as u16))
+    };
+    transmute(simd_select_bitmask(k, shf, src.as_u16x32()))
 }
 
 /// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -5371,10 +5376,13 @@ pub unsafe fn _mm512_mask_slli_epi16(
 #[rustc_legacy_const_generics(2)]
 pub unsafe fn _mm512_maskz_slli_epi16<const IMM8: u32>(k: __mmask32, a: __m512i) -> __m512i {
     static_assert_uimm_bits!(IMM8, 8);
-    let a = a.as_i16x32();
-    let shf = vpslliw(a, IMM8);
-    let zero = _mm512_setzero_si512().as_i16x32();
-    transmute(simd_select_bitmask(k, shf, zero))
+    if IMM8 >= 16 {
+        _mm512_setzero_si512()
+    } else {
+        let shf = simd_shl(a.as_u16x32(), u16x32::splat(IMM8 as u16));
+        let zero = u16x32::splat(0);
+        transmute(simd_select_bitmask(k, shf, zero))
+    }
 }
 
 /// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -5390,9 +5398,12 @@ pub unsafe fn _mm256_mask_slli_epi16(
     a: __m256i,
 ) -> __m256i {
     static_assert_uimm_bits!(IMM8, 8);
-    let imm8 = IMM8 as i32;
-    let r = pslliw256(a.as_i16x16(), imm8);
-    transmute(simd_select_bitmask(k, r, src.as_i16x16()))
+    let shf = if IMM8 >= 16 {
+        u16x16::splat(0)
+    } else {
+        simd_shl(a.as_u16x16(), u16x16::splat(IMM8 as u16))
+    };
+    transmute(simd_select_bitmask(k, shf, src.as_u16x16()))
 }
 
 /// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -5404,10 +5415,13 @@ pub unsafe fn _mm256_mask_slli_epi16(
 #[rustc_legacy_const_generics(2)]
 pub unsafe fn _mm256_maskz_slli_epi16<const IMM8: u32>(k: __mmask16, a: __m256i) -> __m256i {
     static_assert_uimm_bits!(IMM8, 8);
-    let imm8 = IMM8 as i32;
-    let r = pslliw256(a.as_i16x16(), imm8);
-    let zero = _mm256_setzero_si256().as_i16x16();
-    transmute(simd_select_bitmask(k, r, zero))
+    if IMM8 >= 16 {
+        _mm256_setzero_si256()
+    } else {
+        let shf = simd_shl(a.as_u16x16(), u16x16::splat(IMM8 as u16));
+        let zero = u16x16::splat(0);
+        transmute(simd_select_bitmask(k, shf, zero))
+    }
 }
 
 /// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -5423,9 +5437,12 @@ pub unsafe fn _mm_mask_slli_epi16(
     a: __m128i,
 ) -> __m128i {
     static_assert_uimm_bits!(IMM8, 8);
-    let imm8 = IMM8 as i32;
-    let r = pslliw128(a.as_i16x8(), imm8);
-    transmute(simd_select_bitmask(k, r, src.as_i16x8()))
+    let shf = if IMM8 >= 16 {
+        u16x8::splat(0)
+    } else {
+        simd_shl(a.as_u16x8(), u16x8::splat(IMM8 as u16))
+    };
+    transmute(simd_select_bitmask(k, shf, src.as_u16x8()))
 }
 
 /// Shift packed 16-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -5437,10 +5454,13 @@ pub unsafe fn _mm_mask_slli_epi16(
 #[rustc_legacy_const_generics(2)]
 pub unsafe fn _mm_maskz_slli_epi16<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i {
     static_assert_uimm_bits!(IMM8, 8);
-    let imm8 = IMM8 as i32;
-    let r = pslliw128(a.as_i16x8(), imm8);
-    let zero = _mm_setzero_si128().as_i16x8();
-    transmute(simd_select_bitmask(k, r, zero))
+    if IMM8 >= 16 {
+        _mm_setzero_si128()
+    } else {
+        let shf = simd_shl(a.as_u16x8(), u16x8::splat(IMM8 as u16));
+        let zero = u16x8::splat(0);
+        transmute(simd_select_bitmask(k, shf, zero))
+    }
 }
 
 /// Shift packed 16-bit integers in a left by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst.
@@ -9965,13 +9985,6 @@ extern "C" {
     #[link_name = "llvm.x86.avx512.psll.w.512"]
     fn vpsllw(a: i16x32, count: i16x8) -> i16x32;
 
-    #[link_name = "llvm.x86.avx512.pslli.w.512"]
-    fn vpslliw(a: i16x32, imm8: u32) -> i16x32;
-
-    #[link_name = "llvm.x86.avx2.pslli.w"]
-    fn pslliw256(a: i16x16, imm8: i32) -> i16x16;
-    #[link_name = "llvm.x86.sse2.pslli.w"]
-    fn pslliw128(a: i16x8, imm8: i32) -> i16x8;
 
     #[link_name = "llvm.x86.avx512.psllv.w.512"]
     fn vpsllvw(a: i16x32, b: i16x32) -> i16x32;
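For reference, here is a minimal standalone scalar sketch (not part of the diff above) of the semantics the `simd_shl`-based paths have to reproduce by hand: the `vpsllw`/`pslliw` instructions clear a 16-bit lane once the shift count reaches 16, whereas a plain LLVM `shl` by 16 or more bits on a 16-bit lane is undefined, which is why every rewritten function guards on `IMM8 >= 16`. The helper names `slli_epi16_lane` and `mask_slli_epi16_lanes` are illustrative only and do not appear in the patch.

// Illustrative scalar model only; these helpers are not part of the patch.
fn slli_epi16_lane(lane: u16, imm8: u32) -> u16 {
    if imm8 >= 16 {
        // vpsllw-style behaviour: an out-of-range immediate clears the lane,
        // where a raw `lane << imm8` would be an overflowing shift.
        0
    } else {
        lane << imm8
    }
}

// Scalar model of the masked variants: `simd_select_bitmask(k, shf, src)`
// keeps the shifted lane where the mask bit is set and the `src` lane elsewhere.
fn mask_slli_epi16_lanes(src: &[u16; 32], k: u32, a: &[u16; 32], imm8: u32) -> [u16; 32] {
    let mut out = [0u16; 32];
    for i in 0..32 {
        out[i] = if (k >> i) & 1 == 1 {
            slli_epi16_lane(a[i], imm8)
        } else {
            src[i]
        };
    }
    out
}

fn main() {
    assert_eq!(slli_epi16_lane(0x00ff, 3), 0x07f8);
    assert_eq!(slli_epi16_lane(0x00ff, 16), 0); // out-of-range count zeroes the lane instead of being UB

    let a = [0x0001u16; 32];
    let src = [0xffffu16; 32];
    let r = mask_slli_epi16_lanes(&src, 0x0000_000f, &a, 4);
    assert_eq!(r[0], 0x0010);  // mask bit set: shifted value
    assert_eq!(r[31], 0xffff); // mask bit clear: copied from src
}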