From 7defd9b4290bce1fd912f55887217c213a642937 Mon Sep 17 00:00:00 2001 From: Hanna Kruppe Date: Sun, 12 Jan 2025 10:56:30 +0100 Subject: [PATCH 1/2] Use wasm32 arch intrinsics for rint{,f} --- .../compiler-builtins/libm/etc/function-definitions.json | 2 ++ library/compiler-builtins/libm/src/math/arch/mod.rs | 4 +++- library/compiler-builtins/libm/src/math/arch/wasm32.rs | 8 ++++++++ library/compiler-builtins/libm/src/math/rint.rs | 6 ++++++ library/compiler-builtins/libm/src/math/rintf.rs | 6 ++++++ 5 files changed, 25 insertions(+), 1 deletion(-) diff --git a/library/compiler-builtins/libm/etc/function-definitions.json b/library/compiler-builtins/libm/etc/function-definitions.json index 3cf7e0fed3de..f60a7e5673c9 100644 --- a/library/compiler-builtins/libm/etc/function-definitions.json +++ b/library/compiler-builtins/libm/etc/function-definitions.json @@ -604,12 +604,14 @@ "rint": { "sources": [ "src/libm_helper.rs", + "src/math/arch/wasm32.rs", "src/math/rint.rs" ], "type": "f64" }, "rintf": { "sources": [ + "src/math/arch/wasm32.rs", "src/math/rintf.rs" ], "type": "f32" diff --git a/library/compiler-builtins/libm/src/math/arch/mod.rs b/library/compiler-builtins/libm/src/math/arch/mod.rs index bd79ae1c69c2..3992419cbaad 100644 --- a/library/compiler-builtins/libm/src/math/arch/mod.rs +++ b/library/compiler-builtins/libm/src/math/arch/mod.rs @@ -11,7 +11,9 @@ cfg_if! { if #[cfg(all(target_arch = "wasm32", intrinsics_enabled))] { mod wasm32; - pub use wasm32::{ceil, ceilf, fabs, fabsf, floor, floorf, sqrt, sqrtf, trunc, truncf}; + pub use wasm32::{ + ceil, ceilf, fabs, fabsf, floor, floorf, rint, rintf, sqrt, sqrtf, trunc, truncf, + }; } else if #[cfg(target_feature = "sse2")] { mod i686; pub use i686::{sqrt, sqrtf}; diff --git a/library/compiler-builtins/libm/src/math/arch/wasm32.rs b/library/compiler-builtins/libm/src/math/arch/wasm32.rs index 384445f12914..de80c8a58172 100644 --- a/library/compiler-builtins/libm/src/math/arch/wasm32.rs +++ b/library/compiler-builtins/libm/src/math/arch/wasm32.rs @@ -25,6 +25,14 @@ pub fn floorf(x: f32) -> f32 { core::arch::wasm32::f32_floor(x) } +pub fn rint(x: f64) -> f64 { + core::arch::wasm32::f64_nearest(x) +} + +pub fn rintf(x: f32) -> f32 { + core::arch::wasm32::f32_nearest(x) +} + pub fn sqrt(x: f64) -> f64 { core::arch::wasm32::f64_sqrt(x) } diff --git a/library/compiler-builtins/libm/src/math/rint.rs b/library/compiler-builtins/libm/src/math/rint.rs index cbdc3c2b91c8..50192ffdf03d 100644 --- a/library/compiler-builtins/libm/src/math/rint.rs +++ b/library/compiler-builtins/libm/src/math/rint.rs @@ -1,5 +1,11 @@ #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] pub fn rint(x: f64) -> f64 { + select_implementation! { + name: rint, + use_arch: all(target_arch = "wasm32", intrinsics_enabled), + args: x, + } + let one_over_e = 1.0 / f64::EPSILON; let as_u64: u64 = x.to_bits(); let exponent: u64 = (as_u64 >> 52) & 0x7ff; diff --git a/library/compiler-builtins/libm/src/math/rintf.rs b/library/compiler-builtins/libm/src/math/rintf.rs index 2d22c9393543..64968b6be3a3 100644 --- a/library/compiler-builtins/libm/src/math/rintf.rs +++ b/library/compiler-builtins/libm/src/math/rintf.rs @@ -1,5 +1,11 @@ #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] pub fn rintf(x: f32) -> f32 { + select_implementation! { + name: rintf, + use_arch: all(target_arch = "wasm32", intrinsics_enabled), + args: x, + } + let one_over_e = 1.0 / f32::EPSILON; let as_u32: u32 = x.to_bits(); let exponent: u32 = (as_u32 >> 23) & 0xff; From 87cc064e35d4fc65caf2da13bf263b1323b16760 Mon Sep 17 00:00:00 2001 From: Hanna Kruppe Date: Sun, 12 Jan 2025 11:16:40 +0100 Subject: [PATCH 2/2] Introduce arch::aarch64 and use it for rint{,f} --- .../libm/etc/function-definitions.json | 2 ++ .../libm/src/math/arch/aarch64.rs | 33 +++++++++++++++++++ .../libm/src/math/arch/mod.rs | 7 ++++ .../compiler-builtins/libm/src/math/rint.rs | 5 ++- .../compiler-builtins/libm/src/math/rintf.rs | 5 ++- 5 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 library/compiler-builtins/libm/src/math/arch/aarch64.rs diff --git a/library/compiler-builtins/libm/etc/function-definitions.json b/library/compiler-builtins/libm/etc/function-definitions.json index f60a7e5673c9..39b6c97029cb 100644 --- a/library/compiler-builtins/libm/etc/function-definitions.json +++ b/library/compiler-builtins/libm/etc/function-definitions.json @@ -604,6 +604,7 @@ "rint": { "sources": [ "src/libm_helper.rs", + "src/math/arch/aarch64.rs", "src/math/arch/wasm32.rs", "src/math/rint.rs" ], @@ -611,6 +612,7 @@ }, "rintf": { "sources": [ + "src/math/arch/aarch64.rs", "src/math/arch/wasm32.rs", "src/math/rintf.rs" ], diff --git a/library/compiler-builtins/libm/src/math/arch/aarch64.rs b/library/compiler-builtins/libm/src/math/arch/aarch64.rs new file mode 100644 index 000000000000..374ec11bfec3 --- /dev/null +++ b/library/compiler-builtins/libm/src/math/arch/aarch64.rs @@ -0,0 +1,33 @@ +use core::arch::aarch64::{ + float32x2_t, float64x1_t, vdup_n_f32, vdup_n_f64, vget_lane_f32, vget_lane_f64, vrndn_f32, + vrndn_f64, +}; + +pub fn rint(x: f64) -> f64 { + // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. + let x_vec: float64x1_t = unsafe { vdup_n_f64(x) }; + + // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. + let result_vec: float64x1_t = unsafe { vrndn_f64(x_vec) }; + + // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. + let result: f64 = unsafe { vget_lane_f64::<0>(result_vec) }; + + result +} + +pub fn rintf(x: f32) -> f32 { + // There's a scalar form of this instruction (FRINTN) but core::arch doesn't expose it, so we + // have to use the vector form and drop the other lanes afterwards. + + // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. + let x_vec: float32x2_t = unsafe { vdup_n_f32(x) }; + + // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. + let result_vec: float32x2_t = unsafe { vrndn_f32(x_vec) }; + + // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. + let result: f32 = unsafe { vget_lane_f32::<0>(result_vec) }; + + result +} diff --git a/library/compiler-builtins/libm/src/math/arch/mod.rs b/library/compiler-builtins/libm/src/math/arch/mod.rs index 3992419cbaad..091d7650a5ac 100644 --- a/library/compiler-builtins/libm/src/math/arch/mod.rs +++ b/library/compiler-builtins/libm/src/math/arch/mod.rs @@ -17,6 +17,13 @@ cfg_if! { } else if #[cfg(target_feature = "sse2")] { mod i686; pub use i686::{sqrt, sqrtf}; + } else if #[cfg(all( + target_arch = "aarch64", // TODO: also arm64ec? + target_feature = "neon", + target_endian = "little", // see https://github.com/rust-lang/stdarch/issues/1484 + ))] { + mod aarch64; + pub use aarch64::{rint, rintf}; } } diff --git a/library/compiler-builtins/libm/src/math/rint.rs b/library/compiler-builtins/libm/src/math/rint.rs index 50192ffdf03d..c9ea6402ec73 100644 --- a/library/compiler-builtins/libm/src/math/rint.rs +++ b/library/compiler-builtins/libm/src/math/rint.rs @@ -2,7 +2,10 @@ pub fn rint(x: f64) -> f64 { select_implementation! { name: rint, - use_arch: all(target_arch = "wasm32", intrinsics_enabled), + use_arch: any( + all(target_arch = "wasm32", intrinsics_enabled), + all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"), + ), args: x, } diff --git a/library/compiler-builtins/libm/src/math/rintf.rs b/library/compiler-builtins/libm/src/math/rintf.rs index 64968b6be3a3..33b5b3ddebf6 100644 --- a/library/compiler-builtins/libm/src/math/rintf.rs +++ b/library/compiler-builtins/libm/src/math/rintf.rs @@ -2,7 +2,10 @@ pub fn rintf(x: f32) -> f32 { select_implementation! { name: rintf, - use_arch: all(target_arch = "wasm32", intrinsics_enabled), + use_arch: any( + all(target_arch = "wasm32", intrinsics_enabled), + all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"), + ), args: x, }