diff --git a/crates/core_simd/src/vector/float.rs b/crates/core_simd/src/vector/float.rs index 5044ac57ec56..7061b9b06748 100644 --- a/crates/core_simd/src/vector/float.rs +++ b/crates/core_simd/src/vector/float.rs @@ -136,6 +136,47 @@ macro_rules! impl_float_vector { let magnitude = self.to_bits() & !Self::splat(-0.).to_bits(); Self::from_bits(sign_bit | magnitude) } + + /// Returns the minimum of each lane. + /// + /// If one of the values is `NAN`, then the other value is returned. + #[inline] + pub fn min(self, other: Self) -> Self { + // TODO consider using an intrinsic + self.is_nan().select( + other, + self.lanes_ge(other).select(other, self) + ) + } + + /// Returns the maximum of each lane. + /// + /// If one of the values is `NAN`, then the other value is returned. + #[inline] + pub fn max(self, other: Self) -> Self { + // TODO consider using an intrinsic + self.is_nan().select( + other, + self.lanes_le(other).select(other, self) + ) + } + + /// Restrict each lane to a certain interval unless it is NaN. + /// + /// For each lane in `self`, returns the corresponding lane in `max` if the lane is + /// greater than `max`, and the corresponding lane in `min` if the lane is less + /// than `min`. Otherwise returns the lane in `self`. + #[inline] + pub fn clamp(self, min: Self, max: Self) -> Self { + assert!( + min.lanes_le(max).all(), + "each lane in `min` must be less than or equal to the corresponding lane in `max`", + ); + let mut x = self; + x = x.lanes_lt(min).select(min, x); + x = x.lanes_gt(max).select(max, x); + x + } } }; } diff --git a/crates/core_simd/tests/ops_macros.rs b/crates/core_simd/tests/ops_macros.rs index 9ada95e851e1..8ef2edc8370a 100644 --- a/crates/core_simd/tests/ops_macros.rs +++ b/crates/core_simd/tests/ops_macros.rs @@ -483,6 +483,76 @@ macro_rules! impl_float_tests { ) } + fn min() { + // Regular conditions (both values aren't zero) + test_helpers::test_binary_elementwise( + &Vector::::min, + &Scalar::min, + // Reject the case where both values are zero with different signs + &|a, b| { + for (a, b) in a.iter().zip(b.iter()) { + if *a == 0. && *b == 0. && a.signum() != b.signum() { + return false; + } + } + true + } + ); + + // Special case where both values are zero + let p_zero = Vector::::splat(0.); + let n_zero = Vector::::splat(-0.); + assert!(p_zero.min(n_zero).to_array().iter().all(|x| *x == 0.)); + assert!(n_zero.min(p_zero).to_array().iter().all(|x| *x == 0.)); + } + + fn max() { + // Regular conditions (both values aren't zero) + test_helpers::test_binary_elementwise( + &Vector::::max, + &Scalar::max, + // Reject the case where both values are zero with different signs + &|a, b| { + for (a, b) in a.iter().zip(b.iter()) { + if *a == 0. && *b == 0. && a.signum() != b.signum() { + return false; + } + } + true + } + ); + + // Special case where both values are zero + let p_zero = Vector::::splat(0.); + let n_zero = Vector::::splat(-0.); + assert!(p_zero.min(n_zero).to_array().iter().all(|x| *x == 0.)); + assert!(n_zero.min(p_zero).to_array().iter().all(|x| *x == 0.)); + } + + fn clamp() { + test_helpers::test_3(&|value: [Scalar; LANES], mut min: [Scalar; LANES], mut max: [Scalar; LANES]| { + for (min, max) in min.iter_mut().zip(max.iter_mut()) { + if max < min { + core::mem::swap(min, max); + } + if min.is_nan() { + *min = Scalar::NEG_INFINITY; + } + if max.is_nan() { + *max = Scalar::INFINITY; + } + } + + let mut result_scalar = [Scalar::default(); LANES]; + for i in 0..LANES { + result_scalar[i] = value[i].clamp(min[i], max[i]); + } + let result_vector = Vector::from_array(value).clamp(min.into(), max.into()).to_array(); + test_helpers::prop_assert_biteq!(result_scalar, result_vector); + Ok(()) + }) + } + fn horizontal_sum() { test_helpers::test_1(&|x| { test_helpers::prop_assert_biteq! ( diff --git a/crates/test_helpers/src/lib.rs b/crates/test_helpers/src/lib.rs index fffd088f4da3..ff6d30a1afb7 100644 --- a/crates/test_helpers/src/lib.rs +++ b/crates/test_helpers/src/lib.rs @@ -97,6 +97,27 @@ pub fn test_2( + f: &dyn Fn(A, B, C) -> proptest::test_runner::TestCaseResult, +) { + let mut runner = proptest::test_runner::TestRunner::default(); + runner + .run( + &( + A::default_strategy(), + B::default_strategy(), + C::default_strategy(), + ), + |(a, b, c)| f(a, b, c), + ) + .unwrap(); +} + /// Test a unary vector function against a unary scalar function, applied elementwise. #[inline(never)] pub fn test_unary_elementwise(