Fixed _mm512_kunpackb, reduce-max and reduce-min

`_mm512_kunpackb` was implemented wrong, and `simd_reduce_max` uses `maxnum` for comparison, which adheres to IEEE754, but Intel specifically says that they do NOT adhere to IEEE754 for NaNs, which can give wrong results
This commit is contained in:
sayantn 2024-06-29 14:04:03 +05:30 committed by Amanieu d'Antras
parent b3e96f2584
commit 95d273aaf9

View file

@ -27663,9 +27663,7 @@ pub unsafe fn _mm512_mask2int(k1: __mmask16) -> i32 {
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
#[cfg_attr(test, assert_instr(mov))] // generate normal and code instead of kunpckbw
pub unsafe fn _mm512_kunpackb(a: __mmask16, b: __mmask16) -> __mmask16 {
let a = a & 0b00000000_11111111;
let b = b & 0b11111111_00000000;
a | b
((a & 0xff) << 8) | (b & 0xff)
}
/// Performs bitwise OR between k1 and k2, storing the result in dst. CF flag is set if dst consists of all 1's.
@ -31554,7 +31552,13 @@ pub unsafe fn _mm512_mask_reduce_max_epu64(k: __mmask8, a: __m512i) -> u64 {
#[target_feature(enable = "avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm512_reduce_max_ps(a: __m512) -> f32 {
simd_reduce_max(a.as_f32x16())
let a = _mm256_max_ps(
simd_shuffle!(a, a, [0, 1, 2, 3, 4, 5, 6, 7]),
simd_shuffle!(a, a, [8, 9, 10, 11, 12, 13, 14, 15]),
);
let a = _mm_max_ps(_mm256_extractf128_ps::<0>(a), _mm256_extractf128_ps::<1>(a));
let a = _mm_max_ps(a, simd_shuffle!(a, a, [2, 3, 0, 1]));
_mm_cvtss_f32(_mm_max_ss(a, _mm_movehdup_ps(a)))
}
/// Reduce the packed single-precision (32-bit) floating-point elements in a by maximum using mask k. Returns the maximum of all active elements in a.
@ -31564,11 +31568,7 @@ pub unsafe fn _mm512_reduce_max_ps(a: __m512) -> f32 {
#[target_feature(enable = "avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm512_mask_reduce_max_ps(k: __mmask16, a: __m512) -> f32 {
simd_reduce_max(simd_select_bitmask(
k,
a.as_f32x16(),
_mm512_undefined_ps().as_f32x16(),
))
_mm512_reduce_max_ps(_mm512_mask_mov_ps(_mm512_set1_ps(f32::MIN), k, a))
}
/// Reduce the packed double-precision (64-bit) floating-point elements in a by maximum. Returns the maximum of all elements in a.
@ -31578,7 +31578,12 @@ pub unsafe fn _mm512_mask_reduce_max_ps(k: __mmask16, a: __m512) -> f32 {
#[target_feature(enable = "avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm512_reduce_max_pd(a: __m512d) -> f64 {
simd_reduce_max(a.as_f64x8())
let a = _mm256_max_pd(
_mm512_extractf64x4_pd::<0>(a),
_mm512_extractf64x4_pd::<1>(a),
);
let a = _mm_max_pd(_mm256_extractf128_pd::<0>(a), _mm256_extractf128_pd::<1>(a));
_mm_cvtsd_f64(_mm_max_sd(a, simd_shuffle!(a, a, [1, 0])))
}
/// Reduce the packed double-precision (64-bit) floating-point elements in a by maximum using mask k. Returns the maximum of all active elements in a.
@ -31588,11 +31593,7 @@ pub unsafe fn _mm512_reduce_max_pd(a: __m512d) -> f64 {
#[target_feature(enable = "avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm512_mask_reduce_max_pd(k: __mmask8, a: __m512d) -> f64 {
simd_reduce_max(simd_select_bitmask(
k,
a.as_f64x8(),
_mm512_undefined_pd().as_f64x8(),
))
_mm512_reduce_max_pd(_mm512_mask_mov_pd(_mm512_set1_pd(f64::MIN), k, a))
}
/// Reduce the packed signed 32-bit integers in a by minimum. Returns the minimum of all elements in a.
@ -31698,7 +31699,13 @@ pub unsafe fn _mm512_mask_reduce_min_epu64(k: __mmask8, a: __m512i) -> u64 {
#[target_feature(enable = "avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm512_reduce_min_ps(a: __m512) -> f32 {
simd_reduce_min(a.as_f32x16())
let a = _mm256_min_ps(
simd_shuffle!(a, a, [0, 1, 2, 3, 4, 5, 6, 7]),
simd_shuffle!(a, a, [8, 9, 10, 11, 12, 13, 14, 15]),
);
let a = _mm_min_ps(_mm256_extractf128_ps::<0>(a), _mm256_extractf128_ps::<1>(a));
let a = _mm_min_ps(a, simd_shuffle!(a, a, [2, 3, 0, 1]));
_mm_cvtss_f32(_mm_min_ss(a, _mm_movehdup_ps(a)))
}
/// Reduce the packed single-precision (32-bit) floating-point elements in a by maximum using mask k. Returns the minimum of all active elements in a.
@ -31708,11 +31715,7 @@ pub unsafe fn _mm512_reduce_min_ps(a: __m512) -> f32 {
#[target_feature(enable = "avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm512_mask_reduce_min_ps(k: __mmask16, a: __m512) -> f32 {
simd_reduce_min(simd_select_bitmask(
k,
a.as_f32x16(),
_mm512_undefined_ps().as_f32x16(),
))
_mm512_reduce_min_ps(_mm512_mask_mov_ps(_mm512_set1_ps(f32::MAX), k, a))
}
/// Reduce the packed double-precision (64-bit) floating-point elements in a by minimum. Returns the minimum of all elements in a.
@ -31722,7 +31725,12 @@ pub unsafe fn _mm512_mask_reduce_min_ps(k: __mmask16, a: __m512) -> f32 {
#[target_feature(enable = "avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm512_reduce_min_pd(a: __m512d) -> f64 {
simd_reduce_min(a.as_f64x8())
let a = _mm256_min_pd(
_mm512_extractf64x4_pd::<0>(a),
_mm512_extractf64x4_pd::<1>(a),
);
let a = _mm_min_pd(_mm256_extractf128_pd::<0>(a), _mm256_extractf128_pd::<1>(a));
_mm_cvtsd_f64(_mm_min_sd(a, simd_shuffle!(a, a, [1, 0])))
}
/// Reduce the packed double-precision (64-bit) floating-point elements in a by maximum using mask k. Returns the minimum of all active elements in a.
@ -31732,11 +31740,7 @@ pub unsafe fn _mm512_reduce_min_pd(a: __m512d) -> f64 {
#[target_feature(enable = "avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm512_mask_reduce_min_pd(k: __mmask8, a: __m512d) -> f64 {
simd_reduce_min(simd_select_bitmask(
k,
a.as_f64x8(),
_mm512_undefined_pd().as_f64x8(),
))
_mm512_reduce_min_pd(_mm512_mask_mov_pd(_mm512_set1_pd(f64::MAX), k, a))
}
/// Reduce the packed 32-bit integers in a by bitwise AND. Returns the bitwise AND of all elements in a.
@ -54323,7 +54327,7 @@ mod tests {
let a: u16 = 0b11001100_00110011;
let b: u16 = 0b00101110_00001011;
let r = _mm512_kunpackb(a, b);
let e: u16 = 0b00101110_00110011;
let e: u16 = 0b00110011_00001011;
assert_eq!(r, e);
}