From bf3a28f6ad1d42b4faddedf1238a3b7c66a16250 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eduardo=20S=C3=A1nchez=20Mu=C3=B1oz?= Date: Sat, 12 Aug 2023 01:27:42 +0200 Subject: [PATCH] Implement AVX512F 32-bit shift by immediate (slli_epi32) with `simd_shl` instead of LLVM intrinsics --- .../crates/core_arch/src/x86/avx512f.rs | 76 +++++++++++-------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/library/stdarch/crates/core_arch/src/x86/avx512f.rs b/library/stdarch/crates/core_arch/src/x86/avx512f.rs index c3c1275a69ab..5e567eb1112c 100644 --- a/library/stdarch/crates/core_arch/src/x86/avx512f.rs +++ b/library/stdarch/crates/core_arch/src/x86/avx512f.rs @@ -17141,9 +17141,11 @@ pub unsafe fn _mm_maskz_ror_epi64(k: __mmask8, a: __m128i) -> _ #[rustc_legacy_const_generics(1)] pub unsafe fn _mm512_slli_epi32(a: __m512i) -> __m512i { static_assert_uimm_bits!(IMM8, 8); - let a = a.as_i32x16(); - let r = vpsllid(a, IMM8); - transmute(r) + if IMM8 >= 32 { + _mm512_setzero_si512() + } else { + transmute(simd_shl(a.as_u32x16(), u32x16::splat(IMM8 as u32))) + } } /// Shift packed 32-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -17159,9 +17161,12 @@ pub unsafe fn _mm512_mask_slli_epi32( a: __m512i, ) -> __m512i { static_assert_uimm_bits!(IMM8, 8); - let a = a.as_i32x16(); - let shf = vpsllid(a, IMM8); - transmute(simd_select_bitmask(k, shf, src.as_i32x16())) + let shf = if IMM8 >= 32 { + u32x16::splat(0) + } else { + simd_shl(a.as_u32x16(), u32x16::splat(IMM8)) + }; + transmute(simd_select_bitmask(k, shf, src.as_u32x16())) } /// Shift packed 32-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set). @@ -17173,10 +17178,13 @@ pub unsafe fn _mm512_mask_slli_epi32( #[rustc_legacy_const_generics(2)] pub unsafe fn _mm512_maskz_slli_epi32(k: __mmask16, a: __m512i) -> __m512i { static_assert_uimm_bits!(IMM8, 8); - let a = a.as_i32x16(); - let shf = vpsllid(a, IMM8); - let zero = _mm512_setzero_si512().as_i32x16(); - transmute(simd_select_bitmask(k, shf, zero)) + if IMM8 >= 32 { + _mm512_setzero_si512() + } else { + let shf = simd_shl(a.as_u32x16(), u32x16::splat(IMM8)); + let zero = u32x16::splat(0); + transmute(simd_select_bitmask(k, shf, zero)) + } } /// Shift packed 32-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -17192,9 +17200,12 @@ pub unsafe fn _mm256_mask_slli_epi32( a: __m256i, ) -> __m256i { static_assert_uimm_bits!(IMM8, 8); - let imm8 = IMM8 as i32; - let r = psllid256(a.as_i32x8(), imm8); - transmute(simd_select_bitmask(k, r, src.as_i32x8())) + let r = if IMM8 >= 32 { + u32x8::splat(0) + } else { + simd_shl(a.as_u32x8(), u32x8::splat(IMM8)) + }; + transmute(simd_select_bitmask(k, r, src.as_u32x8())) } /// Shift packed 32-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set). @@ -17206,10 +17217,13 @@ pub unsafe fn _mm256_mask_slli_epi32( #[rustc_legacy_const_generics(2)] pub unsafe fn _mm256_maskz_slli_epi32(k: __mmask8, a: __m256i) -> __m256i { static_assert_uimm_bits!(IMM8, 8); - let imm8 = IMM8 as i32; - let r = psllid256(a.as_i32x8(), imm8); - let zero = _mm256_setzero_si256().as_i32x8(); - transmute(simd_select_bitmask(k, r, zero)) + if IMM8 >= 32 { + _mm256_setzero_si256() + } else { + let r = simd_shl(a.as_u32x8(), u32x8::splat(IMM8)); + let zero = u32x8::splat(0); + transmute(simd_select_bitmask(k, r, zero)) + } } /// Shift packed 32-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -17225,9 +17239,12 @@ pub unsafe fn _mm_mask_slli_epi32( a: __m128i, ) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - let imm8 = IMM8 as i32; - let r = psllid128(a.as_i32x4(), imm8); - transmute(simd_select_bitmask(k, r, src.as_i32x4())) + let r = if IMM8 >= 32 { + u32x4::splat(0) + } else { + simd_shl(a.as_u32x4(), u32x4::splat(IMM8)) + }; + transmute(simd_select_bitmask(k, r, src.as_u32x4())) } /// Shift packed 32-bit integers in a left by imm8 while shifting in zeros, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set). @@ -17239,10 +17256,13 @@ pub unsafe fn _mm_mask_slli_epi32( #[rustc_legacy_const_generics(2)] pub unsafe fn _mm_maskz_slli_epi32(k: __mmask8, a: __m128i) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - let imm8 = IMM8 as i32; - let r = psllid128(a.as_i32x4(), imm8); - let zero = _mm_setzero_si128().as_i32x4(); - transmute(simd_select_bitmask(k, r, zero)) + if IMM8 >= 32 { + _mm_setzero_si128() + } else { + let r = simd_shl(a.as_u32x4(), u32x4::splat(IMM8)); + let zero = u32x4::splat(0); + transmute(simd_select_bitmask(k, r, zero)) + } } /// Shift packed 32-bit integers in a right by imm8 while shifting in zeros, and store the results in dst. @@ -38449,14 +38469,6 @@ extern "C" { #[link_name = "llvm.x86.avx512.psrlv.q.512"] fn vpsrlvq(a: i64x8, b: i64x8) -> i64x8; - #[link_name = "llvm.x86.avx512.pslli.d.512"] - fn vpsllid(a: i32x16, imm8: u32) -> i32x16; - - #[link_name = "llvm.x86.avx2.pslli.d"] - fn psllid256(a: i32x8, imm8: i32) -> i32x8; - #[link_name = "llvm.x86.sse2.pslli.d"] - fn psllid128(a: i32x4, imm8: i32) -> i32x4; - #[link_name = "llvm.x86.avx512.psrli.d.512"] fn vpsrlid(a: i32x16, imm8: u32) -> i32x16;