From 0d9520dfd4cd376c67bf7356df9296bb6ab3713e Mon Sep 17 00:00:00 2001 From: Tobias Decking Date: Mon, 24 Jun 2024 00:06:30 +0200 Subject: [PATCH] Refactor avx512f: sqrt + rounding fix --- .../stdarch/crates/core_arch/missing-x86.md | 4 - .../crates/core_arch/src/x86/avx512f.rs | 198 +++++++++--------- .../crates/core_arch/src/x86_64/avx512f.rs | 16 ++ 3 files changed, 120 insertions(+), 98 deletions(-) diff --git a/library/stdarch/crates/core_arch/missing-x86.md b/library/stdarch/crates/core_arch/missing-x86.md index 8da6074cacca..6daab7715d32 100644 --- a/library/stdarch/crates/core_arch/missing-x86.md +++ b/library/stdarch/crates/core_arch/missing-x86.md @@ -204,8 +204,6 @@ * [ ] [`_mm256_mmask_i64gather_epi64`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_mmask_i64gather_epi64) * [ ] [`_mm256_mmask_i64gather_pd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_mmask_i64gather_pd) * [ ] [`_mm256_mmask_i64gather_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_mmask_i64gather_ps) - * [ ] [`_mm256_rsqrt14_pd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_rsqrt14_pd) - * [ ] [`_mm256_rsqrt14_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_rsqrt14_ps) * [ ] [`_mm_abs_epi64`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_abs_epi64) * [ ] [`_mm_i32scatter_epi32`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_i32scatter_epi32) * [ ] [`_mm_i32scatter_epi64`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_i32scatter_epi64) @@ -236,8 +234,6 @@ * [ ] [`_mm_mmask_i64gather_epi64`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mmask_i64gather_epi64) * [ ] [`_mm_mmask_i64gather_pd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mmask_i64gather_pd) * [ ] [`_mm_mmask_i64gather_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mmask_i64gather_ps) - * [ ] [`_mm_rsqrt14_pd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_rsqrt14_pd) - * [ ] [`_mm_rsqrt14_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_rsqrt14_ps)

diff --git a/library/stdarch/crates/core_arch/src/x86/avx512f.rs b/library/stdarch/crates/core_arch/src/x86/avx512f.rs index 1dbd813d430c..9e6ce92b1707 100644 --- a/library/stdarch/crates/core_arch/src/x86/avx512f.rs +++ b/library/stdarch/crates/core_arch/src/x86/avx512f.rs @@ -3001,7 +3001,7 @@ pub unsafe fn _mm_maskz_min_epu64(k: __mmask8, a: __m128i, b: __m128i) -> __m128 #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtps))] pub unsafe fn _mm512_sqrt_ps(a: __m512) -> __m512 { - transmute(vsqrtps(a.as_f32x16(), _MM_FROUND_CUR_DIRECTION)) + simd_fsqrt(a) } /// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -3012,8 +3012,7 @@ pub unsafe fn _mm512_sqrt_ps(a: __m512) -> __m512 { #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtps))] pub unsafe fn _mm512_mask_sqrt_ps(src: __m512, k: __mmask16, a: __m512) -> __m512 { - let sqrt = _mm512_sqrt_ps(a).as_f32x16(); - transmute(simd_select_bitmask(k, sqrt, src.as_f32x16())) + simd_select_bitmask(k, simd_fsqrt(a), src) } /// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set). @@ -3024,9 +3023,7 @@ pub unsafe fn _mm512_mask_sqrt_ps(src: __m512, k: __mmask16, a: __m512) -> __m51 #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtps))] pub unsafe fn _mm512_maskz_sqrt_ps(k: __mmask16, a: __m512) -> __m512 { - let sqrt = _mm512_sqrt_ps(a).as_f32x16(); - let zero = _mm512_setzero_ps().as_f32x16(); - transmute(simd_select_bitmask(k, sqrt, zero)) + simd_select_bitmask(k, simd_fsqrt(a), _mm512_setzero_ps()) } /// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -3037,8 +3034,7 @@ pub unsafe fn _mm512_maskz_sqrt_ps(k: __mmask16, a: __m512) -> __m512 { #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtps))] pub unsafe fn _mm256_mask_sqrt_ps(src: __m256, k: __mmask8, a: __m256) -> __m256 { - let sqrt = _mm256_sqrt_ps(a).as_f32x8(); - transmute(simd_select_bitmask(k, sqrt, src.as_f32x8())) + simd_select_bitmask(k, simd_fsqrt(a), src) } /// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set). @@ -3049,9 +3045,7 @@ pub unsafe fn _mm256_mask_sqrt_ps(src: __m256, k: __mmask8, a: __m256) -> __m256 #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtps))] pub unsafe fn _mm256_maskz_sqrt_ps(k: __mmask8, a: __m256) -> __m256 { - let sqrt = _mm256_sqrt_ps(a).as_f32x8(); - let zero = _mm256_setzero_ps().as_f32x8(); - transmute(simd_select_bitmask(k, sqrt, zero)) + simd_select_bitmask(k, simd_fsqrt(a), _mm256_setzero_ps()) } /// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -3062,8 +3056,7 @@ pub unsafe fn _mm256_maskz_sqrt_ps(k: __mmask8, a: __m256) -> __m256 { #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtps))] pub unsafe fn _mm_mask_sqrt_ps(src: __m128, k: __mmask8, a: __m128) -> __m128 { - let sqrt = _mm_sqrt_ps(a).as_f32x4(); - transmute(simd_select_bitmask(k, sqrt, src.as_f32x4())) + simd_select_bitmask(k, simd_fsqrt(a), src) } /// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set). @@ -3074,9 +3067,7 @@ pub unsafe fn _mm_mask_sqrt_ps(src: __m128, k: __mmask8, a: __m128) -> __m128 { #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtps))] pub unsafe fn _mm_maskz_sqrt_ps(k: __mmask8, a: __m128) -> __m128 { - let sqrt = _mm_sqrt_ps(a).as_f32x4(); - let zero = _mm_setzero_ps().as_f32x4(); - transmute(simd_select_bitmask(k, sqrt, zero)) + simd_select_bitmask(k, simd_fsqrt(a), _mm_setzero_ps()) } /// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst. @@ -3087,7 +3078,7 @@ pub unsafe fn _mm_maskz_sqrt_ps(k: __mmask8, a: __m128) -> __m128 { #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtpd))] pub unsafe fn _mm512_sqrt_pd(a: __m512d) -> __m512d { - transmute(vsqrtpd(a.as_f64x8(), _MM_FROUND_CUR_DIRECTION)) + simd_fsqrt(a) } /// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -3098,8 +3089,7 @@ pub unsafe fn _mm512_sqrt_pd(a: __m512d) -> __m512d { #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtpd))] pub unsafe fn _mm512_mask_sqrt_pd(src: __m512d, k: __mmask8, a: __m512d) -> __m512d { - let sqrt = _mm512_sqrt_pd(a).as_f64x8(); - transmute(simd_select_bitmask(k, sqrt, src.as_f64x8())) + simd_select_bitmask(k, simd_fsqrt(a), src) } /// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set). @@ -3110,9 +3100,7 @@ pub unsafe fn _mm512_mask_sqrt_pd(src: __m512d, k: __mmask8, a: __m512d) -> __m5 #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtpd))] pub unsafe fn _mm512_maskz_sqrt_pd(k: __mmask8, a: __m512d) -> __m512d { - let sqrt = _mm512_sqrt_pd(a).as_f64x8(); - let zero = _mm512_setzero_pd().as_f64x8(); - transmute(simd_select_bitmask(k, sqrt, zero)) + simd_select_bitmask(k, simd_fsqrt(a), _mm512_setzero_pd()) } /// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -3123,8 +3111,7 @@ pub unsafe fn _mm512_maskz_sqrt_pd(k: __mmask8, a: __m512d) -> __m512d { #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtpd))] pub unsafe fn _mm256_mask_sqrt_pd(src: __m256d, k: __mmask8, a: __m256d) -> __m256d { - let sqrt = _mm256_sqrt_pd(a).as_f64x4(); - transmute(simd_select_bitmask(k, sqrt, src.as_f64x4())) + simd_select_bitmask(k, simd_fsqrt(a), src) } /// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set). @@ -3135,9 +3122,7 @@ pub unsafe fn _mm256_mask_sqrt_pd(src: __m256d, k: __mmask8, a: __m256d) -> __m2 #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtpd))] pub unsafe fn _mm256_maskz_sqrt_pd(k: __mmask8, a: __m256d) -> __m256d { - let sqrt = _mm256_sqrt_pd(a).as_f64x4(); - let zero = _mm256_setzero_pd().as_f64x4(); - transmute(simd_select_bitmask(k, sqrt, zero)) + simd_select_bitmask(k, simd_fsqrt(a), _mm256_setzero_pd()) } /// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). @@ -3148,8 +3133,7 @@ pub unsafe fn _mm256_maskz_sqrt_pd(k: __mmask8, a: __m256d) -> __m256d { #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtpd))] pub unsafe fn _mm_mask_sqrt_pd(src: __m128d, k: __mmask8, a: __m128d) -> __m128d { - let sqrt = _mm_sqrt_pd(a).as_f64x2(); - transmute(simd_select_bitmask(k, sqrt, src.as_f64x2())) + simd_select_bitmask(k, simd_fsqrt(a), src) } /// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set). @@ -3160,9 +3144,7 @@ pub unsafe fn _mm_mask_sqrt_pd(src: __m128d, k: __mmask8, a: __m128d) -> __m128d #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtpd))] pub unsafe fn _mm_maskz_sqrt_pd(k: __mmask8, a: __m128d) -> __m128d { - let sqrt = _mm_sqrt_pd(a).as_f64x2(); - let zero = _mm_setzero_pd().as_f64x2(); - transmute(simd_select_bitmask(k, sqrt, zero)) + simd_select_bitmask(k, simd_fsqrt(a), _mm_setzero_pd()) } /// Multiply packed single-precision (32-bit) floating-point elements in a and b, add the intermediate result to packed elements in c, and store the results in dst. @@ -4764,6 +4746,21 @@ pub unsafe fn _mm512_maskz_rsqrt14_ps(k: __mmask16, a: __m512) -> __m512 { )) } +/// Compute the approximate reciprocal square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst. The maximum relative error for this approximation is less than 2^-14. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_rsqrt14_ps) +#[inline] +#[target_feature(enable = "avx512f,avx512vl")] +#[unstable(feature = "stdarch_x86_avx512", issue = "111137")] +#[cfg_attr(test, assert_instr(vrsqrt14ps))] +pub unsafe fn _mm256_rsqrt14_ps(a: __m256) -> __m256 { + transmute(vrsqrt14ps256( + a.as_f32x8(), + _mm256_setzero_ps().as_f32x8(), + 0b11111111, + )) +} + /// Compute the approximate reciprocal square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). The maximum relative error for this approximation is less than 2^-14. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mask_rsqrt14_ps&expand=4815) @@ -4790,6 +4787,21 @@ pub unsafe fn _mm256_maskz_rsqrt14_ps(k: __mmask8, a: __m256) -> __m256 { )) } +/// Compute the approximate reciprocal square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst. The maximum relative error for this approximation is less than 2^-14. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_rsqrt14_ps) +#[inline] +#[target_feature(enable = "avx512f,avx512vl")] +#[unstable(feature = "stdarch_x86_avx512", issue = "111137")] +#[cfg_attr(test, assert_instr(vrsqrt14ps))] +pub unsafe fn _mm_rsqrt14_ps(a: __m128) -> __m128 { + transmute(vrsqrt14ps128( + a.as_f32x4(), + _mm_setzero_ps().as_f32x4(), + 0b00001111, + )) +} + /// Compute the approximate reciprocal square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). The maximum relative error for this approximation is less than 2^-14. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_mask_rsqrt14_ps&expand=4813) @@ -4849,6 +4861,21 @@ pub unsafe fn _mm512_maskz_rsqrt14_pd(k: __mmask8, a: __m512d) -> __m512d { transmute(vrsqrt14pd(a.as_f64x8(), _mm512_setzero_pd().as_f64x8(), k)) } +/// Compute the approximate reciprocal square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst. The maximum relative error for this approximation is less than 2^-14. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_rsqrt14_pd) +#[inline] +#[target_feature(enable = "avx512f,avx512vl")] +#[unstable(feature = "stdarch_x86_avx512", issue = "111137")] +#[cfg_attr(test, assert_instr(vrsqrt14pd))] +pub unsafe fn _mm256_rsqrt14_pd(a: __m256d) -> __m256d { + transmute(vrsqrt14pd256( + a.as_f64x4(), + _mm256_setzero_pd().as_f64x4(), + 0b00001111, + )) +} + /// Compute the approximate reciprocal square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). The maximum relative error for this approximation is less than 2^-14. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mask_rsqrt14_pd&expand=4808) @@ -4875,6 +4902,21 @@ pub unsafe fn _mm256_maskz_rsqrt14_pd(k: __mmask8, a: __m256d) -> __m256d { )) } +/// Compute the approximate reciprocal square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst. The maximum relative error for this approximation is less than 2^-14. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_rsqrt14_pd) +#[inline] +#[target_feature(enable = "avx512f,avx512vl")] +#[unstable(feature = "stdarch_x86_avx512", issue = "111137")] +#[cfg_attr(test, assert_instr(vrsqrt14pd))] +pub unsafe fn _mm_rsqrt14_pd(a: __m128d) -> __m128d { + transmute(vrsqrt14pd128( + a.as_f64x2(), + _mm_setzero_pd().as_f64x2(), + 0b00000011, + )) +} + /// Compute the approximate reciprocal square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). The maximum relative error for this approximation is less than 2^-14. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_mask_rsqrt14_pd&expand=4806) @@ -34834,13 +34876,7 @@ pub unsafe fn _mm_maskz_min_sd(k: __mmask8, a: __m128d, b: __m128d) -> __m128d { #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtss))] pub unsafe fn _mm_mask_sqrt_ss(src: __m128, k: __mmask8, a: __m128, b: __m128) -> __m128 { - transmute(vsqrtss( - a.as_f32x4(), - b.as_f32x4(), - src.as_f32x4(), - k, - _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC, - )) + vsqrtss(a, b, src, k, _MM_FROUND_CUR_DIRECTION) } /// Compute the square root of the lower single-precision (32-bit) floating-point element in b, store the result in the lower element of dst using zeromask k (the element is zeroed out when mask bit 0 is not set), and copy the upper 3 packed elements from a to the upper elements of dst. @@ -34851,13 +34887,7 @@ pub unsafe fn _mm_mask_sqrt_ss(src: __m128, k: __mmask8, a: __m128, b: __m128) - #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtss))] pub unsafe fn _mm_maskz_sqrt_ss(k: __mmask8, a: __m128, b: __m128) -> __m128 { - transmute(vsqrtss( - a.as_f32x4(), - b.as_f32x4(), - _mm_setzero_ps().as_f32x4(), - k, - _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC, - )) + vsqrtss(a, b, _mm_setzero_ps(), k, _MM_FROUND_CUR_DIRECTION) } /// Compute the square root of the lower double-precision (64-bit) floating-point element in b, store the result in the lower element of dst using writemask k (the element is copied from src when mask bit 0 is not set), and copy the upper element from a to the upper element of dst. @@ -34868,13 +34898,7 @@ pub unsafe fn _mm_maskz_sqrt_ss(k: __mmask8, a: __m128, b: __m128) -> __m128 { #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtsd))] pub unsafe fn _mm_mask_sqrt_sd(src: __m128d, k: __mmask8, a: __m128d, b: __m128d) -> __m128d { - transmute(vsqrtsd( - a.as_f64x2(), - b.as_f64x2(), - src.as_f64x2(), - k, - _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC, - )) + vsqrtsd(a, b, src, k, _MM_FROUND_CUR_DIRECTION) } /// Compute the square root of the lower double-precision (64-bit) floating-point element in b, store the result in the lower element of dst using zeromask k (the element is zeroed out when mask bit 0 is not set), and copy the upper element from a to the upper element of dst. @@ -34885,13 +34909,7 @@ pub unsafe fn _mm_mask_sqrt_sd(src: __m128d, k: __mmask8, a: __m128d, b: __m128d #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] #[cfg_attr(test, assert_instr(vsqrtsd))] pub unsafe fn _mm_maskz_sqrt_sd(k: __mmask8, a: __m128d, b: __m128d) -> __m128d { - transmute(vsqrtsd( - a.as_f64x2(), - b.as_f64x2(), - _mm_setzero_pd().as_f64x2(), - k, - _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC, - )) + vsqrtsd(a, b, _mm_setzero_pd(), k, _MM_FROUND_CUR_DIRECTION) } /// Compute the approximate reciprocal square root of the lower single-precision (32-bit) floating-point element in b, store the result in the lower element of dst, and copy the upper 3 packed elements from a to the upper elements of dst. The maximum relative error for this approximation is less than 2^-14. @@ -36979,11 +36997,7 @@ pub unsafe fn _mm_maskz_min_round_sd( #[rustc_legacy_const_generics(2)] pub unsafe fn _mm_sqrt_round_ss(a: __m128, b: __m128) -> __m128 { static_assert_rounding!(ROUNDING); - let a = a.as_f32x4(); - let b = b.as_f32x4(); - let zero = _mm_setzero_ps().as_f32x4(); - let r = vsqrtss(a, b, zero, 0b1, ROUNDING); - transmute(r) + vsqrtss(a, b, _mm_setzero_ps(), 0b1, ROUNDING) } /// Compute the square root of the lower single-precision (32-bit) floating-point element in b, store the result in the lower element of dst using writemask k (the element is copied from src when mask bit 0 is not set), and copy the upper 3 packed elements from a to the upper elements of dst.\ @@ -37008,11 +37022,7 @@ pub unsafe fn _mm_mask_sqrt_round_ss( b: __m128, ) -> __m128 { static_assert_rounding!(ROUNDING); - let a = a.as_f32x4(); - let b = b.as_f32x4(); - let src = src.as_f32x4(); - let r = vsqrtss(a, b, src, k, ROUNDING); - transmute(r) + vsqrtss(a, b, src, k, ROUNDING) } /// Compute the square root of the lower single-precision (32-bit) floating-point element in b, store the result in the lower element of dst using zeromask k (the element is zeroed out when mask bit 0 is not set), and copy the upper 3 packed elements from a to the upper elements of dst.\ @@ -37036,11 +37046,7 @@ pub unsafe fn _mm_maskz_sqrt_round_ss( b: __m128, ) -> __m128 { static_assert_rounding!(ROUNDING); - let a = a.as_f32x4(); - let b = b.as_f32x4(); - let zero = _mm_setzero_ps().as_f32x4(); - let r = vsqrtss(a, b, zero, k, ROUNDING); - transmute(r) + vsqrtss(a, b, _mm_setzero_ps(), k, ROUNDING) } /// Compute the square root of the lower double-precision (64-bit) floating-point element in b, store the result in the lower element of dst, and copy the upper element from a to the upper element of dst.\ @@ -37060,11 +37066,7 @@ pub unsafe fn _mm_maskz_sqrt_round_ss( #[rustc_legacy_const_generics(2)] pub unsafe fn _mm_sqrt_round_sd(a: __m128d, b: __m128d) -> __m128d { static_assert_rounding!(ROUNDING); - let a = a.as_f64x2(); - let b = b.as_f64x2(); - let zero = _mm_setzero_pd().as_f64x2(); - let r = vsqrtsd(a, b, zero, 0b1, ROUNDING); - transmute(r) + vsqrtsd(a, b, _mm_setzero_pd(), 0b1, ROUNDING) } /// Compute the square root of the lower double-precision (64-bit) floating-point element in b, store the result in the lower element of dst using writemask k (the element is copied from src when mask bit 0 is not set), and copy the upper element from a to the upper element of dst.\ @@ -37089,11 +37091,7 @@ pub unsafe fn _mm_mask_sqrt_round_sd( b: __m128d, ) -> __m128d { static_assert_rounding!(ROUNDING); - let a = a.as_f64x2(); - let b = b.as_f64x2(); - let src = src.as_f64x2(); - let r = vsqrtsd(a, b, src, k, ROUNDING); - transmute(r) + vsqrtsd(a, b, src, k, ROUNDING) } /// Compute the square root of the lower double-precision (64-bit) floating-point element in b, store the result in the lower element of dst using zeromask k (the element is zeroed out when mask bit 0 is not set), and copy the upper element from a to the upper element of dst.\ @@ -37117,11 +37115,7 @@ pub unsafe fn _mm_maskz_sqrt_round_sd( b: __m128d, ) -> __m128d { static_assert_rounding!(ROUNDING); - let a = a.as_f64x2(); - let b = b.as_f64x2(); - let zero = _mm_setzero_pd().as_f64x2(); - let r = vsqrtsd(a, b, zero, k, ROUNDING); - transmute(r) + vsqrtsd(a, b, _mm_setzero_pd(), k, ROUNDING) } /// Convert the exponent of the lower single-precision (32-bit) floating-point element in b to a single-precision (32-bit) floating-point number representing the integer exponent, store the result in the lower element of dst, and copy the upper 3 packed elements from a to the upper elements of dst. This intrinsic essentially calculates floor(log2(x)) for the lower element.\ @@ -39134,7 +39128,7 @@ pub unsafe fn _mm_mask_cvtsd_ss(src: __m128, k: __mmask8, a: __m128, b: __m128d) b.as_f64x2(), src.as_f32x4(), k, - _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC, + _MM_FROUND_CUR_DIRECTION, )) } @@ -39151,7 +39145,7 @@ pub unsafe fn _mm_maskz_cvtsd_ss(k: __mmask8, a: __m128, b: __m128d) -> __m128 { b.as_f64x2(), _mm_setzero_ps().as_f32x4(), k, - _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC, + _MM_FROUND_CUR_DIRECTION, )) } @@ -41117,9 +41111,9 @@ extern "C" { #[link_name = "llvm.x86.avx512.mask.min.sd.round"] fn vminsd(a: f64x2, b: f64x2, src: f64x2, mask: u8, sae: i32) -> f64x2; #[link_name = "llvm.x86.avx512.mask.sqrt.ss"] - fn vsqrtss(a: f32x4, b: f32x4, src: f32x4, mask: u8, rounding: i32) -> f32x4; + fn vsqrtss(a: __m128, b: __m128, src: __m128, mask: u8, rounding: i32) -> __m128; #[link_name = "llvm.x86.avx512.mask.sqrt.sd"] - fn vsqrtsd(a: f64x2, b: f64x2, src: f64x2, mask: u8, rounding: i32) -> f64x2; + fn vsqrtsd(a: __m128d, b: __m128d, src: __m128d, mask: u8, rounding: i32) -> __m128d; #[link_name = "llvm.x86.avx512.mask.getexp.ss"] fn vgetexpss(a: f32x4, b: f32x4, src: f32x4, mask: u8, sae: i32) -> f32x4; #[link_name = "llvm.x86.avx512.mask.getexp.sd"] @@ -43769,6 +43763,14 @@ mod tests { assert_eq_m512(r, e); } + #[simd_test(enable = "avx512f,avx512vl")] + unsafe fn test_mm256_rsqrt14_ps() { + let a = _mm256_set1_ps(3.); + let r = _mm256_rsqrt14_ps(a); + let e = _mm256_set1_ps(0.5773392); + assert_eq_m256(r, e); + } + #[simd_test(enable = "avx512f,avx512vl")] unsafe fn test_mm256_mask_rsqrt14_ps() { let a = _mm256_set1_ps(3.); @@ -43789,6 +43791,14 @@ mod tests { assert_eq_m256(r, e); } + #[simd_test(enable = "avx512f,avx512vl")] + unsafe fn test_mm_rsqrt14_ps() { + let a = _mm_set1_ps(3.); + let r = _mm_rsqrt14_ps(a); + let e = _mm_set1_ps(0.5773392); + assert_eq_m128(r, e); + } + #[simd_test(enable = "avx512f,avx512vl")] unsafe fn test_mm_mask_rsqrt14_ps() { let a = _mm_set1_ps(3.); diff --git a/library/stdarch/crates/core_arch/src/x86_64/avx512f.rs b/library/stdarch/crates/core_arch/src/x86_64/avx512f.rs index 359a66858233..fec18e3ea3eb 100644 --- a/library/stdarch/crates/core_arch/src/x86_64/avx512f.rs +++ b/library/stdarch/crates/core_arch/src/x86_64/avx512f.rs @@ -2745,6 +2745,14 @@ mod tests { assert_eq_m512d(r, e); } + #[simd_test(enable = "avx512f,avx512vl")] + unsafe fn test_mm256_rsqrt14_pd() { + let a = _mm256_set1_pd(3.); + let r = _mm256_rsqrt14_pd(a); + let e = _mm256_set1_pd(0.5773391723632813); + assert_eq_m256d(r, e); + } + #[simd_test(enable = "avx512f,avx512vl")] unsafe fn test_mm256_mask_rsqrt14_pd() { let a = _mm256_set1_pd(3.); @@ -2765,6 +2773,14 @@ mod tests { assert_eq_m256d(r, e); } + #[simd_test(enable = "avx512f,avx512vl")] + unsafe fn test_mm_rsqrt14_pd() { + let a = _mm_set1_pd(3.); + let r = _mm_rsqrt14_pd(a); + let e = _mm_set1_pd(0.5773391723632813); + assert_eq_m128d(r, e); + } + #[simd_test(enable = "avx512f,avx512vl")] unsafe fn test_mm_mask_rsqrt14_pd() { let a = _mm_set1_pd(3.);