diff --git a/library/stdarch/crates/core_arch/src/x86/avx512bw.rs b/library/stdarch/crates/core_arch/src/x86/avx512bw.rs index 3f58d165fbeb..3640235396a7 100644 --- a/library/stdarch/crates/core_arch/src/x86/avx512bw.rs +++ b/library/stdarch/crates/core_arch/src/x86/avx512bw.rs @@ -5996,9 +5996,7 @@ pub unsafe fn _mm_maskz_sra_epi16(k: __mmask8, a: __m128i, count: __m128i) -> __ #[rustc_legacy_const_generics(1)] pub unsafe fn _mm512_srai_epi16(a: __m512i) -> __m512i { static_assert_uimm_bits!(IMM8, 8); - let a = a.as_i16x32(); - let r = vpsraiw(a, IMM8); - transmute(r) + transmute(simd_shr(a.as_i16x32(), i16x32::splat(IMM8.min(15) as i16))) } /// Shift packed 16-bit integers in a right by imm8 while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -6014,8 +6012,7 @@ pub unsafe fn _mm512_mask_srai_epi16( a: __m512i, ) -> __m512i { static_assert_uimm_bits!(IMM8, 8); - let a = a.as_i16x32(); - let shf = vpsraiw(a, IMM8); + let shf = simd_shr(a.as_i16x32(), i16x32::splat(IMM8.min(15) as i16)); transmute(simd_select_bitmask(k, shf, src.as_i16x32())) } @@ -6028,9 +6025,8 @@ pub unsafe fn _mm512_mask_srai_epi16( #[rustc_legacy_const_generics(2)] pub unsafe fn _mm512_maskz_srai_epi16(k: __mmask32, a: __m512i) -> __m512i { static_assert_uimm_bits!(IMM8, 8); - let a = a.as_i16x32(); - let shf = vpsraiw(a, IMM8); - let zero = _mm512_setzero_si512().as_i16x32(); + let shf = simd_shr(a.as_i16x32(), i16x32::splat(IMM8.min(15) as i16)); + let zero = i16x32::splat(0); transmute(simd_select_bitmask(k, shf, zero)) } @@ -6047,8 +6043,7 @@ pub unsafe fn _mm256_mask_srai_epi16( a: __m256i, ) -> __m256i { static_assert_uimm_bits!(IMM8, 8); - let imm8 = IMM8 as i32; - let r = psraiw256(a.as_i16x16(), imm8); + let r = simd_shr(a.as_i16x16(), i16x16::splat(IMM8.min(15) as i16)); transmute(simd_select_bitmask(k, r, src.as_i16x16())) } @@ -6061,9 +6056,8 @@ pub unsafe fn _mm256_mask_srai_epi16( #[rustc_legacy_const_generics(2)] pub unsafe fn _mm256_maskz_srai_epi16(k: __mmask16, a: __m256i) -> __m256i { static_assert_uimm_bits!(IMM8, 8); - let imm8 = IMM8 as i32; - let r = psraiw256(a.as_i16x16(), imm8); - let zero = _mm256_setzero_si256().as_i16x16(); + let r = simd_shr(a.as_i16x16(), i16x16::splat(IMM8.min(15) as i16)); + let zero = i16x16::splat(0); transmute(simd_select_bitmask(k, r, zero)) } @@ -6080,8 +6074,7 @@ pub unsafe fn _mm_mask_srai_epi16( a: __m128i, ) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - let imm8 = IMM8 as i32; - let r = psraiw128(a.as_i16x8(), imm8); + let r = simd_shr(a.as_i16x8(), i16x8::splat(IMM8.min(15) as i16)); transmute(simd_select_bitmask(k, r, src.as_i16x8())) } @@ -6094,9 +6087,8 @@ pub unsafe fn _mm_mask_srai_epi16( #[rustc_legacy_const_generics(2)] pub unsafe fn _mm_maskz_srai_epi16(k: __mmask8, a: __m128i) -> __m128i { static_assert_uimm_bits!(IMM8, 8); - let imm8 = IMM8 as i32; - let r = psraiw128(a.as_i16x8(), imm8); - let zero = _mm_setzero_si128().as_i16x8(); + let r = simd_shr(a.as_i16x8(), i16x8::splat(IMM8.min(15) as i16)); + let zero = i16x8::splat(0); transmute(simd_select_bitmask(k, r, zero)) } @@ -10013,13 +10005,6 @@ extern "C" { #[link_name = "llvm.x86.avx512.psra.w.512"] fn vpsraw(a: i16x32, count: i16x8) -> i16x32; - #[link_name = "llvm.x86.avx512.psrai.w.512"] - fn vpsraiw(a: i16x32, imm8: u32) -> i16x32; - - #[link_name = "llvm.x86.avx2.psrai.w"] - fn psraiw256(a: i16x16, imm8: i32) -> i16x16; - #[link_name = "llvm.x86.sse2.psrai.w"] - fn psraiw128(a: i16x8, imm8: i32) -> i16x8; #[link_name = "llvm.x86.avx512.psrav.w.512"] fn vpsravw(a: i16x32, count: i16x32) -> i16x32;