From 0d9520dfd4cd376c67bf7356df9296bb6ab3713e Mon Sep 17 00:00:00 2001
From: Tobias Decking
Date: Mon, 24 Jun 2024 00:06:30 +0200
Subject: [PATCH] Refactor avx512f: sqrt + rounding fix
---
.../stdarch/crates/core_arch/missing-x86.md | 4 -
.../crates/core_arch/src/x86/avx512f.rs | 198 +++++++++---------
.../crates/core_arch/src/x86_64/avx512f.rs | 16 ++
3 files changed, 120 insertions(+), 98 deletions(-)
diff --git a/library/stdarch/crates/core_arch/missing-x86.md b/library/stdarch/crates/core_arch/missing-x86.md
index 8da6074cacca..6daab7715d32 100644
--- a/library/stdarch/crates/core_arch/missing-x86.md
+++ b/library/stdarch/crates/core_arch/missing-x86.md
@@ -204,8 +204,6 @@
* [ ] [`_mm256_mmask_i64gather_epi64`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_mmask_i64gather_epi64)
* [ ] [`_mm256_mmask_i64gather_pd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_mmask_i64gather_pd)
* [ ] [`_mm256_mmask_i64gather_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_mmask_i64gather_ps)
- * [ ] [`_mm256_rsqrt14_pd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_rsqrt14_pd)
- * [ ] [`_mm256_rsqrt14_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_rsqrt14_ps)
* [ ] [`_mm_abs_epi64`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_abs_epi64)
* [ ] [`_mm_i32scatter_epi32`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_i32scatter_epi32)
* [ ] [`_mm_i32scatter_epi64`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_i32scatter_epi64)
@@ -236,8 +234,6 @@
* [ ] [`_mm_mmask_i64gather_epi64`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mmask_i64gather_epi64)
* [ ] [`_mm_mmask_i64gather_pd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mmask_i64gather_pd)
* [ ] [`_mm_mmask_i64gather_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mmask_i64gather_ps)
- * [ ] [`_mm_rsqrt14_pd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_rsqrt14_pd)
- * [ ] [`_mm_rsqrt14_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_rsqrt14_ps)
diff --git a/library/stdarch/crates/core_arch/src/x86/avx512f.rs b/library/stdarch/crates/core_arch/src/x86/avx512f.rs
index 1dbd813d430c..9e6ce92b1707 100644
--- a/library/stdarch/crates/core_arch/src/x86/avx512f.rs
+++ b/library/stdarch/crates/core_arch/src/x86/avx512f.rs
@@ -3001,7 +3001,7 @@ pub unsafe fn _mm_maskz_min_epu64(k: __mmask8, a: __m128i, b: __m128i) -> __m128
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtps))]
pub unsafe fn _mm512_sqrt_ps(a: __m512) -> __m512 {
- transmute(vsqrtps(a.as_f32x16(), _MM_FROUND_CUR_DIRECTION))
+ simd_fsqrt(a)
}
/// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -3012,8 +3012,7 @@ pub unsafe fn _mm512_sqrt_ps(a: __m512) -> __m512 {
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtps))]
pub unsafe fn _mm512_mask_sqrt_ps(src: __m512, k: __mmask16, a: __m512) -> __m512 {
- let sqrt = _mm512_sqrt_ps(a).as_f32x16();
- transmute(simd_select_bitmask(k, sqrt, src.as_f32x16()))
+ simd_select_bitmask(k, simd_fsqrt(a), src)
}
/// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -3024,9 +3023,7 @@ pub unsafe fn _mm512_mask_sqrt_ps(src: __m512, k: __mmask16, a: __m512) -> __m51
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtps))]
pub unsafe fn _mm512_maskz_sqrt_ps(k: __mmask16, a: __m512) -> __m512 {
- let sqrt = _mm512_sqrt_ps(a).as_f32x16();
- let zero = _mm512_setzero_ps().as_f32x16();
- transmute(simd_select_bitmask(k, sqrt, zero))
+ simd_select_bitmask(k, simd_fsqrt(a), _mm512_setzero_ps())
}
/// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -3037,8 +3034,7 @@ pub unsafe fn _mm512_maskz_sqrt_ps(k: __mmask16, a: __m512) -> __m512 {
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtps))]
pub unsafe fn _mm256_mask_sqrt_ps(src: __m256, k: __mmask8, a: __m256) -> __m256 {
- let sqrt = _mm256_sqrt_ps(a).as_f32x8();
- transmute(simd_select_bitmask(k, sqrt, src.as_f32x8()))
+ simd_select_bitmask(k, simd_fsqrt(a), src)
}
/// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -3049,9 +3045,7 @@ pub unsafe fn _mm256_mask_sqrt_ps(src: __m256, k: __mmask8, a: __m256) -> __m256
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtps))]
pub unsafe fn _mm256_maskz_sqrt_ps(k: __mmask8, a: __m256) -> __m256 {
- let sqrt = _mm256_sqrt_ps(a).as_f32x8();
- let zero = _mm256_setzero_ps().as_f32x8();
- transmute(simd_select_bitmask(k, sqrt, zero))
+ simd_select_bitmask(k, simd_fsqrt(a), _mm256_setzero_ps())
}
/// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -3062,8 +3056,7 @@ pub unsafe fn _mm256_maskz_sqrt_ps(k: __mmask8, a: __m256) -> __m256 {
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtps))]
pub unsafe fn _mm_mask_sqrt_ps(src: __m128, k: __mmask8, a: __m128) -> __m128 {
- let sqrt = _mm_sqrt_ps(a).as_f32x4();
- transmute(simd_select_bitmask(k, sqrt, src.as_f32x4()))
+ simd_select_bitmask(k, simd_fsqrt(a), src)
}
/// Compute the square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -3074,9 +3067,7 @@ pub unsafe fn _mm_mask_sqrt_ps(src: __m128, k: __mmask8, a: __m128) -> __m128 {
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtps))]
pub unsafe fn _mm_maskz_sqrt_ps(k: __mmask8, a: __m128) -> __m128 {
- let sqrt = _mm_sqrt_ps(a).as_f32x4();
- let zero = _mm_setzero_ps().as_f32x4();
- transmute(simd_select_bitmask(k, sqrt, zero))
+ simd_select_bitmask(k, simd_fsqrt(a), _mm_setzero_ps())
}
/// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst.
@@ -3087,7 +3078,7 @@ pub unsafe fn _mm_maskz_sqrt_ps(k: __mmask8, a: __m128) -> __m128 {
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtpd))]
pub unsafe fn _mm512_sqrt_pd(a: __m512d) -> __m512d {
- transmute(vsqrtpd(a.as_f64x8(), _MM_FROUND_CUR_DIRECTION))
+ simd_fsqrt(a)
}
/// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -3098,8 +3089,7 @@ pub unsafe fn _mm512_sqrt_pd(a: __m512d) -> __m512d {
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtpd))]
pub unsafe fn _mm512_mask_sqrt_pd(src: __m512d, k: __mmask8, a: __m512d) -> __m512d {
- let sqrt = _mm512_sqrt_pd(a).as_f64x8();
- transmute(simd_select_bitmask(k, sqrt, src.as_f64x8()))
+ simd_select_bitmask(k, simd_fsqrt(a), src)
}
/// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -3110,9 +3100,7 @@ pub unsafe fn _mm512_mask_sqrt_pd(src: __m512d, k: __mmask8, a: __m512d) -> __m5
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtpd))]
pub unsafe fn _mm512_maskz_sqrt_pd(k: __mmask8, a: __m512d) -> __m512d {
- let sqrt = _mm512_sqrt_pd(a).as_f64x8();
- let zero = _mm512_setzero_pd().as_f64x8();
- transmute(simd_select_bitmask(k, sqrt, zero))
+ simd_select_bitmask(k, simd_fsqrt(a), _mm512_setzero_pd())
}
/// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -3123,8 +3111,7 @@ pub unsafe fn _mm512_maskz_sqrt_pd(k: __mmask8, a: __m512d) -> __m512d {
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtpd))]
pub unsafe fn _mm256_mask_sqrt_pd(src: __m256d, k: __mmask8, a: __m256d) -> __m256d {
- let sqrt = _mm256_sqrt_pd(a).as_f64x4();
- transmute(simd_select_bitmask(k, sqrt, src.as_f64x4()))
+ simd_select_bitmask(k, simd_fsqrt(a), src)
}
/// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -3135,9 +3122,7 @@ pub unsafe fn _mm256_mask_sqrt_pd(src: __m256d, k: __mmask8, a: __m256d) -> __m2
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtpd))]
pub unsafe fn _mm256_maskz_sqrt_pd(k: __mmask8, a: __m256d) -> __m256d {
- let sqrt = _mm256_sqrt_pd(a).as_f64x4();
- let zero = _mm256_setzero_pd().as_f64x4();
- transmute(simd_select_bitmask(k, sqrt, zero))
+ simd_select_bitmask(k, simd_fsqrt(a), _mm256_setzero_pd())
}
/// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -3148,8 +3133,7 @@ pub unsafe fn _mm256_maskz_sqrt_pd(k: __mmask8, a: __m256d) -> __m256d {
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtpd))]
pub unsafe fn _mm_mask_sqrt_pd(src: __m128d, k: __mmask8, a: __m128d) -> __m128d {
- let sqrt = _mm_sqrt_pd(a).as_f64x2();
- transmute(simd_select_bitmask(k, sqrt, src.as_f64x2()))
+ simd_select_bitmask(k, simd_fsqrt(a), src)
}
/// Compute the square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using zeromask k (elements are zeroed out when the corresponding mask bit is not set).
@@ -3160,9 +3144,7 @@ pub unsafe fn _mm_mask_sqrt_pd(src: __m128d, k: __mmask8, a: __m128d) -> __m128d
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtpd))]
pub unsafe fn _mm_maskz_sqrt_pd(k: __mmask8, a: __m128d) -> __m128d {
- let sqrt = _mm_sqrt_pd(a).as_f64x2();
- let zero = _mm_setzero_pd().as_f64x2();
- transmute(simd_select_bitmask(k, sqrt, zero))
+ simd_select_bitmask(k, simd_fsqrt(a), _mm_setzero_pd())
}
/// Multiply packed single-precision (32-bit) floating-point elements in a and b, add the intermediate result to packed elements in c, and store the results in dst.
@@ -4764,6 +4746,21 @@ pub unsafe fn _mm512_maskz_rsqrt14_ps(k: __mmask16, a: __m512) -> __m512 {
))
}
+/// Compute the approximate reciprocal square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst. The maximum relative error for this approximation is less than 2^-14.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_rsqrt14_ps)
+#[inline]
+#[target_feature(enable = "avx512f,avx512vl")]
+#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
+#[cfg_attr(test, assert_instr(vrsqrt14ps))]
+pub unsafe fn _mm256_rsqrt14_ps(a: __m256) -> __m256 {
+ transmute(vrsqrt14ps256(
+ a.as_f32x8(),
+ _mm256_setzero_ps().as_f32x8(),
+ 0b11111111,
+ ))
+}
+
/// Compute the approximate reciprocal square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). The maximum relative error for this approximation is less than 2^-14.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mask_rsqrt14_ps&expand=4815)
@@ -4790,6 +4787,21 @@ pub unsafe fn _mm256_maskz_rsqrt14_ps(k: __mmask8, a: __m256) -> __m256 {
))
}
+/// Compute the approximate reciprocal square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst. The maximum relative error for this approximation is less than 2^-14.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_rsqrt14_ps)
+#[inline]
+#[target_feature(enable = "avx512f,avx512vl")]
+#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
+#[cfg_attr(test, assert_instr(vrsqrt14ps))]
+pub unsafe fn _mm_rsqrt14_ps(a: __m128) -> __m128 {
+ transmute(vrsqrt14ps128(
+ a.as_f32x4(),
+ _mm_setzero_ps().as_f32x4(),
+ 0b00001111,
+ ))
+}
+
/// Compute the approximate reciprocal square root of packed single-precision (32-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). The maximum relative error for this approximation is less than 2^-14.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_mask_rsqrt14_ps&expand=4813)
@@ -4849,6 +4861,21 @@ pub unsafe fn _mm512_maskz_rsqrt14_pd(k: __mmask8, a: __m512d) -> __m512d {
transmute(vrsqrt14pd(a.as_f64x8(), _mm512_setzero_pd().as_f64x8(), k))
}
+/// Compute the approximate reciprocal square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst. The maximum relative error for this approximation is less than 2^-14.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_rsqrt14_pd)
+#[inline]
+#[target_feature(enable = "avx512f,avx512vl")]
+#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
+#[cfg_attr(test, assert_instr(vrsqrt14pd))]
+pub unsafe fn _mm256_rsqrt14_pd(a: __m256d) -> __m256d {
+ transmute(vrsqrt14pd256(
+ a.as_f64x4(),
+ _mm256_setzero_pd().as_f64x4(),
+ 0b00001111,
+ ))
+}
+
/// Compute the approximate reciprocal square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). The maximum relative error for this approximation is less than 2^-14.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mask_rsqrt14_pd&expand=4808)
@@ -4875,6 +4902,21 @@ pub unsafe fn _mm256_maskz_rsqrt14_pd(k: __mmask8, a: __m256d) -> __m256d {
))
}
+/// Compute the approximate reciprocal square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst. The maximum relative error for this approximation is less than 2^-14.
+///
+/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_rsqrt14_pd)
+#[inline]
+#[target_feature(enable = "avx512f,avx512vl")]
+#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
+#[cfg_attr(test, assert_instr(vrsqrt14pd))]
+pub unsafe fn _mm_rsqrt14_pd(a: __m128d) -> __m128d {
+ transmute(vrsqrt14pd128(
+ a.as_f64x2(),
+ _mm_setzero_pd().as_f64x2(),
+ 0b00000011,
+ ))
+}
+
/// Compute the approximate reciprocal square root of packed double-precision (64-bit) floating-point elements in a, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set). The maximum relative error for this approximation is less than 2^-14.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_mask_rsqrt14_pd&expand=4806)
@@ -34834,13 +34876,7 @@ pub unsafe fn _mm_maskz_min_sd(k: __mmask8, a: __m128d, b: __m128d) -> __m128d {
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtss))]
pub unsafe fn _mm_mask_sqrt_ss(src: __m128, k: __mmask8, a: __m128, b: __m128) -> __m128 {
- transmute(vsqrtss(
- a.as_f32x4(),
- b.as_f32x4(),
- src.as_f32x4(),
- k,
- _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC,
- ))
+ vsqrtss(a, b, src, k, _MM_FROUND_CUR_DIRECTION)
}
/// Compute the square root of the lower single-precision (32-bit) floating-point element in b, store the result in the lower element of dst using zeromask k (the element is zeroed out when mask bit 0 is not set), and copy the upper 3 packed elements from a to the upper elements of dst.
@@ -34851,13 +34887,7 @@ pub unsafe fn _mm_mask_sqrt_ss(src: __m128, k: __mmask8, a: __m128, b: __m128) -
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtss))]
pub unsafe fn _mm_maskz_sqrt_ss(k: __mmask8, a: __m128, b: __m128) -> __m128 {
- transmute(vsqrtss(
- a.as_f32x4(),
- b.as_f32x4(),
- _mm_setzero_ps().as_f32x4(),
- k,
- _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC,
- ))
+ vsqrtss(a, b, _mm_setzero_ps(), k, _MM_FROUND_CUR_DIRECTION)
}
/// Compute the square root of the lower double-precision (64-bit) floating-point element in b, store the result in the lower element of dst using writemask k (the element is copied from src when mask bit 0 is not set), and copy the upper element from a to the upper element of dst.
@@ -34868,13 +34898,7 @@ pub unsafe fn _mm_maskz_sqrt_ss(k: __mmask8, a: __m128, b: __m128) -> __m128 {
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtsd))]
pub unsafe fn _mm_mask_sqrt_sd(src: __m128d, k: __mmask8, a: __m128d, b: __m128d) -> __m128d {
- transmute(vsqrtsd(
- a.as_f64x2(),
- b.as_f64x2(),
- src.as_f64x2(),
- k,
- _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC,
- ))
+ vsqrtsd(a, b, src, k, _MM_FROUND_CUR_DIRECTION)
}
/// Compute the square root of the lower double-precision (64-bit) floating-point element in b, store the result in the lower element of dst using zeromask k (the element is zeroed out when mask bit 0 is not set), and copy the upper element from a to the upper element of dst.
@@ -34885,13 +34909,7 @@ pub unsafe fn _mm_mask_sqrt_sd(src: __m128d, k: __mmask8, a: __m128d, b: __m128d
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(vsqrtsd))]
pub unsafe fn _mm_maskz_sqrt_sd(k: __mmask8, a: __m128d, b: __m128d) -> __m128d {
- transmute(vsqrtsd(
- a.as_f64x2(),
- b.as_f64x2(),
- _mm_setzero_pd().as_f64x2(),
- k,
- _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC,
- ))
+ vsqrtsd(a, b, _mm_setzero_pd(), k, _MM_FROUND_CUR_DIRECTION)
}
/// Compute the approximate reciprocal square root of the lower single-precision (32-bit) floating-point element in b, store the result in the lower element of dst, and copy the upper 3 packed elements from a to the upper elements of dst. The maximum relative error for this approximation is less than 2^-14.
@@ -36979,11 +36997,7 @@ pub unsafe fn _mm_maskz_min_round_sd(
#[rustc_legacy_const_generics(2)]
pub unsafe fn _mm_sqrt_round_ss(a: __m128, b: __m128) -> __m128 {
static_assert_rounding!(ROUNDING);
- let a = a.as_f32x4();
- let b = b.as_f32x4();
- let zero = _mm_setzero_ps().as_f32x4();
- let r = vsqrtss(a, b, zero, 0b1, ROUNDING);
- transmute(r)
+ vsqrtss(a, b, _mm_setzero_ps(), 0b1, ROUNDING)
}
/// Compute the square root of the lower single-precision (32-bit) floating-point element in b, store the result in the lower element of dst using writemask k (the element is copied from src when mask bit 0 is not set), and copy the upper 3 packed elements from a to the upper elements of dst.\
@@ -37008,11 +37022,7 @@ pub unsafe fn _mm_mask_sqrt_round_ss(
b: __m128,
) -> __m128 {
static_assert_rounding!(ROUNDING);
- let a = a.as_f32x4();
- let b = b.as_f32x4();
- let src = src.as_f32x4();
- let r = vsqrtss(a, b, src, k, ROUNDING);
- transmute(r)
+ vsqrtss(a, b, src, k, ROUNDING)
}
/// Compute the square root of the lower single-precision (32-bit) floating-point element in b, store the result in the lower element of dst using zeromask k (the element is zeroed out when mask bit 0 is not set), and copy the upper 3 packed elements from a to the upper elements of dst.\
@@ -37036,11 +37046,7 @@ pub unsafe fn _mm_maskz_sqrt_round_ss(
b: __m128,
) -> __m128 {
static_assert_rounding!(ROUNDING);
- let a = a.as_f32x4();
- let b = b.as_f32x4();
- let zero = _mm_setzero_ps().as_f32x4();
- let r = vsqrtss(a, b, zero, k, ROUNDING);
- transmute(r)
+ vsqrtss(a, b, _mm_setzero_ps(), k, ROUNDING)
}
/// Compute the square root of the lower double-precision (64-bit) floating-point element in b, store the result in the lower element of dst, and copy the upper element from a to the upper element of dst.\
@@ -37060,11 +37066,7 @@ pub unsafe fn _mm_maskz_sqrt_round_ss(
#[rustc_legacy_const_generics(2)]
pub unsafe fn _mm_sqrt_round_sd(a: __m128d, b: __m128d) -> __m128d {
static_assert_rounding!(ROUNDING);
- let a = a.as_f64x2();
- let b = b.as_f64x2();
- let zero = _mm_setzero_pd().as_f64x2();
- let r = vsqrtsd(a, b, zero, 0b1, ROUNDING);
- transmute(r)
+ vsqrtsd(a, b, _mm_setzero_pd(), 0b1, ROUNDING)
}
/// Compute the square root of the lower double-precision (64-bit) floating-point element in b, store the result in the lower element of dst using writemask k (the element is copied from src when mask bit 0 is not set), and copy the upper element from a to the upper element of dst.\
@@ -37089,11 +37091,7 @@ pub unsafe fn _mm_mask_sqrt_round_sd(
b: __m128d,
) -> __m128d {
static_assert_rounding!(ROUNDING);
- let a = a.as_f64x2();
- let b = b.as_f64x2();
- let src = src.as_f64x2();
- let r = vsqrtsd(a, b, src, k, ROUNDING);
- transmute(r)
+ vsqrtsd(a, b, src, k, ROUNDING)
}
/// Compute the square root of the lower double-precision (64-bit) floating-point element in b, store the result in the lower element of dst using zeromask k (the element is zeroed out when mask bit 0 is not set), and copy the upper element from a to the upper element of dst.\
@@ -37117,11 +37115,7 @@ pub unsafe fn _mm_maskz_sqrt_round_sd(
b: __m128d,
) -> __m128d {
static_assert_rounding!(ROUNDING);
- let a = a.as_f64x2();
- let b = b.as_f64x2();
- let zero = _mm_setzero_pd().as_f64x2();
- let r = vsqrtsd(a, b, zero, k, ROUNDING);
- transmute(r)
+ vsqrtsd(a, b, _mm_setzero_pd(), k, ROUNDING)
}
/// Convert the exponent of the lower single-precision (32-bit) floating-point element in b to a single-precision (32-bit) floating-point number representing the integer exponent, store the result in the lower element of dst, and copy the upper 3 packed elements from a to the upper elements of dst. This intrinsic essentially calculates floor(log2(x)) for the lower element.\
@@ -39134,7 +39128,7 @@ pub unsafe fn _mm_mask_cvtsd_ss(src: __m128, k: __mmask8, a: __m128, b: __m128d)
b.as_f64x2(),
src.as_f32x4(),
k,
- _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC,
+ _MM_FROUND_CUR_DIRECTION,
))
}
@@ -39151,7 +39145,7 @@ pub unsafe fn _mm_maskz_cvtsd_ss(k: __mmask8, a: __m128, b: __m128d) -> __m128 {
b.as_f64x2(),
_mm_setzero_ps().as_f32x4(),
k,
- _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC,
+ _MM_FROUND_CUR_DIRECTION,
))
}
@@ -41117,9 +41111,9 @@ extern "C" {
#[link_name = "llvm.x86.avx512.mask.min.sd.round"]
fn vminsd(a: f64x2, b: f64x2, src: f64x2, mask: u8, sae: i32) -> f64x2;
#[link_name = "llvm.x86.avx512.mask.sqrt.ss"]
- fn vsqrtss(a: f32x4, b: f32x4, src: f32x4, mask: u8, rounding: i32) -> f32x4;
+ fn vsqrtss(a: __m128, b: __m128, src: __m128, mask: u8, rounding: i32) -> __m128;
#[link_name = "llvm.x86.avx512.mask.sqrt.sd"]
- fn vsqrtsd(a: f64x2, b: f64x2, src: f64x2, mask: u8, rounding: i32) -> f64x2;
+ fn vsqrtsd(a: __m128d, b: __m128d, src: __m128d, mask: u8, rounding: i32) -> __m128d;
#[link_name = "llvm.x86.avx512.mask.getexp.ss"]
fn vgetexpss(a: f32x4, b: f32x4, src: f32x4, mask: u8, sae: i32) -> f32x4;
#[link_name = "llvm.x86.avx512.mask.getexp.sd"]
@@ -43769,6 +43763,14 @@ mod tests {
assert_eq_m512(r, e);
}
+ #[simd_test(enable = "avx512f,avx512vl")]
+ unsafe fn test_mm256_rsqrt14_ps() {
+ let a = _mm256_set1_ps(3.);
+ let r = _mm256_rsqrt14_ps(a);
+ let e = _mm256_set1_ps(0.5773392);
+ assert_eq_m256(r, e);
+ }
+
#[simd_test(enable = "avx512f,avx512vl")]
unsafe fn test_mm256_mask_rsqrt14_ps() {
let a = _mm256_set1_ps(3.);
@@ -43789,6 +43791,14 @@ mod tests {
assert_eq_m256(r, e);
}
+ #[simd_test(enable = "avx512f,avx512vl")]
+ unsafe fn test_mm_rsqrt14_ps() {
+ let a = _mm_set1_ps(3.);
+ let r = _mm_rsqrt14_ps(a);
+ let e = _mm_set1_ps(0.5773392);
+ assert_eq_m128(r, e);
+ }
+
#[simd_test(enable = "avx512f,avx512vl")]
unsafe fn test_mm_mask_rsqrt14_ps() {
let a = _mm_set1_ps(3.);
diff --git a/library/stdarch/crates/core_arch/src/x86_64/avx512f.rs b/library/stdarch/crates/core_arch/src/x86_64/avx512f.rs
index 359a66858233..fec18e3ea3eb 100644
--- a/library/stdarch/crates/core_arch/src/x86_64/avx512f.rs
+++ b/library/stdarch/crates/core_arch/src/x86_64/avx512f.rs
@@ -2745,6 +2745,14 @@ mod tests {
assert_eq_m512d(r, e);
}
+ #[simd_test(enable = "avx512f,avx512vl")]
+ unsafe fn test_mm256_rsqrt14_pd() {
+ let a = _mm256_set1_pd(3.);
+ let r = _mm256_rsqrt14_pd(a);
+ let e = _mm256_set1_pd(0.5773391723632813);
+ assert_eq_m256d(r, e);
+ }
+
#[simd_test(enable = "avx512f,avx512vl")]
unsafe fn test_mm256_mask_rsqrt14_pd() {
let a = _mm256_set1_pd(3.);
@@ -2765,6 +2773,14 @@ mod tests {
assert_eq_m256d(r, e);
}
+ #[simd_test(enable = "avx512f,avx512vl")]
+ unsafe fn test_mm_rsqrt14_pd() {
+ let a = _mm_set1_pd(3.);
+ let r = _mm_rsqrt14_pd(a);
+ let e = _mm_set1_pd(0.5773391723632813);
+ assert_eq_m128d(r, e);
+ }
+
#[simd_test(enable = "avx512f,avx512vl")]
unsafe fn test_mm_mask_rsqrt14_pd() {
let a = _mm_set1_pd(3.);