diff --git a/library/stdarch/src/x86/avx.rs b/library/stdarch/src/x86/avx.rs
index 77a8738e19b9..919609e3e182 100644
--- a/library/stdarch/src/x86/avx.rs
+++ b/library/stdarch/src/x86/avx.rs
@@ -110,6 +110,74 @@ pub unsafe fn _mm256_floor_pd(a: f64x4) -> f64x4 {
     roundpd256(a, 0x01)
 }
 
+/// Round packed single-precision (32-bit) floating point elements in `a`
+/// according to the flag `b`. The value of `b` may be as follows:
+///
+/// * `0x00`: Round to the nearest whole number.
+/// * `0x01`: Round down, toward negative infinity.
+/// * `0x02`: Round up, toward positive infinity.
+/// * `0x03`: Truncate the values.
+///
+/// For a few additional value options, see clang's `avxintrin.h`:
+/// <https://github.com/llvm-mirror/clang/blob/dcd8d797b20291f1a6b3e0ddda085aa2bbb382a8/lib/Headers/avxintrin.h#L382>
+#[inline(always)]
+#[target_feature = "+avx"]
+// #[cfg_attr(test, assert_instr(vroundps))]
+// TODO: Replace with assert_expanded_instr https://github.com/rust-lang-nursery/stdsimd/issues/49
+pub fn _mm256_round_ps(a: f32x8, b: i32) -> f32x8 {
+    macro_rules! call {
+        ($imm8:expr) => {
+            unsafe { roundps256(a, $imm8) }
+        }
+    }
+    constify_imm8!(b, call)
+}
+
+// TODO: Remove once a macro is implemented to automate these tests
+// https://github.com/rust-lang-nursery/stdsimd/issues/49
+#[cfg(test)]
+#[target_feature = "+avx"]
+#[cfg_attr(test, assert_instr(vroundps))]
+fn test_mm256_round_ps(a: f32x8) -> f32x8 {
+    _mm256_round_ps(a, 0x00)
+}
+
+/// Round packed single-precision (32-bit) floating point elements in `a` toward
+/// positive infinity.
+#[inline(always)]
+#[target_feature = "+avx"]
+#[cfg_attr(test, assert_instr(vroundps))]
+pub fn _mm256_ceil_ps(a: f32x8) -> f32x8 {
+    unsafe { roundps256(a, 0x02) }
+}
+
+/// Round packed single-precision (32-bit) floating point elements in `a` toward
+/// negative infinity.
+#[inline(always)]
+#[target_feature = "+avx"]
+#[cfg_attr(test, assert_instr(vroundps))]
+pub fn _mm256_floor_ps(a: f32x8) -> f32x8 {
+    unsafe { roundps256(a, 0x01) }
+}
+
+/// Return the square root of packed single-precision (32-bit) floating point
+/// elements in `a`.
+#[inline(always)]
+#[target_feature = "+avx"]
+#[cfg_attr(test, assert_instr(vsqrtps))]
+pub fn _mm256_sqrt_ps(a: f32x8) -> f32x8 {
+    unsafe { sqrtps256(a) }
+}
+
+/// Return the square root of packed double-precision (64-bit) floating point
+/// elements in `a`.
+#[inline(always)]
+#[target_feature = "+avx"]
+#[cfg_attr(test, assert_instr(vsqrtpd))]
+pub fn _mm256_sqrt_pd(a: f64x4) -> f64x4 {
+    unsafe { sqrtpd256(a) }
+}
+
 /// LLVM intrinsics used in the above functions
 #[allow(improper_ctypes)]
 extern "C" {
@@ -119,9 +187,15 @@ extern "C" {
     fn addsubps256(a: f32x8, b: f32x8) -> f32x8;
     #[link_name = "llvm.x86.avx.round.pd.256"]
     fn roundpd256(a: f64x4, b: i32) -> f64x4;
+    #[link_name = "llvm.x86.avx.round.ps.256"]
+    fn roundps256(a: f32x8, b: i32) -> f32x8;
+    #[link_name = "llvm.x86.avx.sqrt.pd.256"]
+    fn sqrtpd256(a: f64x4) -> f64x4;
+    #[link_name = "llvm.x86.avx.sqrt.ps.256"]
+    fn sqrtps256(a: f32x8) -> f32x8;
 }
 
-#[cfg(test)]
+#[cfg(all(test, target_feature = "avx", any(target_arch = "x86", target_arch = "x86_64")))]
 mod tests {
     use stdsimd_test::simd_test;
 
@@ -229,4 +303,50 @@ mod tests {
         let expected_up = f64x4::new(2.0, 3.0, 4.0, -1.0);
         assert_eq!(result_up, expected_up);
     }
+
+    #[simd_test = "avx"]
+    fn _mm256_round_ps() {
+        let a = f32x8::new(1.55, 2.2, 3.99, -1.2, 1.55, 2.2, 3.99, -1.2);
+        let result_closest = avx::_mm256_round_ps(a, 0b00000000);
+        let result_down = avx::_mm256_round_ps(a, 0b00000001);
+        let result_up = avx::_mm256_round_ps(a, 0b00000010);
+        let expected_closest = f32x8::new(2.0, 2.0, 4.0, -1.0, 2.0, 2.0, 4.0, -1.0);
+        let expected_down = f32x8::new(1.0, 2.0, 3.0, -2.0, 1.0, 2.0, 3.0, -2.0);
+        let expected_up = f32x8::new(2.0, 3.0, 4.0, -1.0, 2.0, 3.0, 4.0, -1.0);
+        assert_eq!(result_closest, expected_closest);
+        assert_eq!(result_down, expected_down);
+        assert_eq!(result_up, expected_up);
+    }
+
+    #[simd_test = "avx"]
+    fn _mm256_floor_ps() {
+        let a = f32x8::new(1.55, 2.2, 3.99, -1.2, 1.55, 2.2, 3.99, -1.2);
+        let result_down = avx::_mm256_floor_ps(a);
+        let expected_down = f32x8::new(1.0, 2.0, 3.0, -2.0, 1.0, 2.0, 3.0, -2.0);
+        assert_eq!(result_down, expected_down);
+    }
+
+    #[simd_test = "avx"]
+    fn _mm256_ceil_ps() {
+        let a = f32x8::new(1.55, 2.2, 3.99, -1.2, 1.55, 2.2, 3.99, -1.2);
+        let result_up = avx::_mm256_ceil_ps(a);
+        let expected_up = f32x8::new(2.0, 3.0, 4.0, -1.0, 2.0, 3.0, 4.0, -1.0);
+        assert_eq!(result_up, expected_up);
+    }
+
+    #[simd_test = "avx"]
+    fn _mm256_sqrt_pd() {
+        let a = f64x4::new(4.0, 9.0, 16.0, 25.0);
+        let r = avx::_mm256_sqrt_pd(a);
+        let e = f64x4::new(2.0, 3.0, 4.0, 5.0);
+        assert_eq!(r, e);
+    }
+
+    #[simd_test = "avx"]
+    fn _mm256_sqrt_ps() {
+        let a = f32x8::new(4.0, 9.0, 16.0, 25.0, 4.0, 9.0, 16.0, 25.0);
+        let r = avx::_mm256_sqrt_ps(a);
+        let e = f32x8::new(2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0);
+        assert_eq!(r, e);
+    }
 }