Add assembly version of simple operations on aarch64

Replace `core::arch` versions of the following with handwritten
assembly, which avoids recursion issues (cg_gcc using `rint` as a
fallback) as well as problems with `aarch64be`.

* `rint`
* `rintf`

Additionally, add assembly versions of the following:

* `fma`
* `fmaf`
* `sqrt`
* `sqrtf`

If the `fp16` target feature is available, which implies `neon`, also
include the following:

* `rintf16`
* `sqrtf16`

`sqrt` is added to match the implementation for `x86`. `fma` is included
since it is used by many other routines.

There are a handful of other operations that have assembly
implementations. They are omitted here because we should have basic
float math routines available in `core` in the near future, which will
allow us to defer to LLVM for assembly lowering rather than implementing
these ourselves.
This commit is contained in:
Trevor Gross 2025-01-23 01:46:24 +00:00 committed by Trevor Gross
parent 375cb5402f
commit 28b6df8603
9 changed files with 161 additions and 34 deletions

View file

@ -342,12 +342,14 @@
},
"fma": {
"sources": [
"src/math/arch/aarch64.rs",
"src/math/fma.rs"
],
"type": "f64"
},
"fmaf": {
"sources": [
"src/math/arch/aarch64.rs",
"src/math/fma_wide.rs"
],
"type": "f32"
@ -806,6 +808,7 @@
},
"rintf16": {
"sources": [
"src/math/arch/aarch64.rs",
"src/math/rint.rs"
],
"type": "f16"
@ -928,6 +931,7 @@
},
"sqrt": {
"sources": [
"src/math/arch/aarch64.rs",
"src/math/arch/i686.rs",
"src/math/arch/wasm32.rs",
"src/math/generic/sqrt.rs",
@ -937,6 +941,7 @@
},
"sqrtf": {
"sources": [
"src/math/arch/aarch64.rs",
"src/math/arch/i686.rs",
"src/math/arch/wasm32.rs",
"src/math/generic/sqrt.rs",
@ -953,6 +958,7 @@
},
"sqrtf16": {
"sources": [
"src/math/arch/aarch64.rs",
"src/math/generic/sqrt.rs",
"src/math/sqrtf16.rs"
],

View file

@ -1,33 +1,115 @@
use core::arch::aarch64::{
float32x2_t, float64x1_t, vdup_n_f32, vdup_n_f64, vget_lane_f32, vget_lane_f64, vrndn_f32,
vrndn_f64,
};
//! Architecture-specific support for aarch64 with neon.
pub fn rint(x: f64) -> f64 {
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
let x_vec: float64x1_t = unsafe { vdup_n_f64(x) };
use core::arch::asm;
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
let result_vec: float64x1_t = unsafe { vrndn_f64(x_vec) };
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
let result: f64 = unsafe { vget_lane_f64::<0>(result_vec) };
result
pub fn fma(mut x: f64, y: f64, z: f64) -> f64 {
// SAFETY: `fmadd` is available with neon and has no side effects.
unsafe {
asm!(
"fmadd {x:d}, {x:d}, {y:d}, {z:d}",
x = inout(vreg) x,
y = in(vreg) y,
z = in(vreg) z,
options(nomem, nostack, pure)
);
}
x
}
pub fn rintf(x: f32) -> f32 {
// There's a scalar form of this instruction (FRINTN) but core::arch doesn't expose it, so we
// have to use the vector form and drop the other lanes afterwards.
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
let x_vec: float32x2_t = unsafe { vdup_n_f32(x) };
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
let result_vec: float32x2_t = unsafe { vrndn_f32(x_vec) };
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
let result: f32 = unsafe { vget_lane_f32::<0>(result_vec) };
result
pub fn fmaf(mut x: f32, y: f32, z: f32) -> f32 {
// SAFETY: `fmadd` is available with neon and has no side effects.
unsafe {
asm!(
"fmadd {x:s}, {x:s}, {y:s}, {z:s}",
x = inout(vreg) x,
y = in(vreg) y,
z = in(vreg) z,
options(nomem, nostack, pure)
);
}
x
}
pub fn rint(mut x: f64) -> f64 {
// SAFETY: `frintn` is available with neon and has no side effects.
//
// `frintn` is always round-to-nearest which does not match the C specification, but Rust does
// not support rounding modes.
unsafe {
asm!(
"frintn {x:d}, {x:d}",
x = inout(vreg) x,
options(nomem, nostack, pure)
);
}
x
}
pub fn rintf(mut x: f32) -> f32 {
// SAFETY: `frintn` is available with neon and has no side effects.
//
// `frintn` is always round-to-nearest which does not match the C specification, but Rust does
// not support rounding modes.
unsafe {
asm!(
"frintn {x:s}, {x:s}",
x = inout(vreg) x,
options(nomem, nostack, pure)
);
}
x
}
#[cfg(all(f16_enabled, target_feature = "fp16"))]
pub fn rintf16(mut x: f16) -> f16 {
// SAFETY: `frintn` is available for `f16` with `fp16` (implies `neon`) and has no side effects.
//
// `frintn` is always round-to-nearest which does not match the C specification, but Rust does
// not support rounding modes.
unsafe {
asm!(
"frintn {x:h}, {x:h}",
x = inout(vreg) x,
options(nomem, nostack, pure)
);
}
x
}
pub fn sqrt(mut x: f64) -> f64 {
// SAFETY: `fsqrt` is available with neon and has no side effects.
unsafe {
asm!(
"fsqrt {x:d}, {x:d}",
x = inout(vreg) x,
options(nomem, nostack, pure)
);
}
x
}
pub fn sqrtf(mut x: f32) -> f32 {
// SAFETY: `fsqrt` is available with neon and has no side effects.
unsafe {
asm!(
"fsqrt {x:s}, {x:s}",
x = inout(vreg) x,
options(nomem, nostack, pure)
);
}
x
}
#[cfg(all(f16_enabled, target_feature = "fp16"))]
pub fn sqrtf16(mut x: f16) -> f16 {
// SAFETY: `fsqrt` is available for `f16` with `fp16` (implies `neon`) and has no
// side effects.
unsafe {
asm!(
"fsqrt {x:h}, {x:h}",
x = inout(vreg) x,
options(nomem, nostack, pure)
);
}
x
}

View file

@ -18,12 +18,25 @@ cfg_if! {
mod i686;
pub use i686::{sqrt, sqrtf};
} else if #[cfg(all(
target_arch = "aarch64", // TODO: also arm64ec?
target_feature = "neon",
target_endian = "little", // see https://github.com/rust-lang/stdarch/issues/1484
any(target_arch = "aarch64", target_arch = "arm64ec"),
target_feature = "neon"
))] {
mod aarch64;
pub use aarch64::{rint, rintf};
pub use aarch64::{
fma,
fmaf,
rint,
rintf,
sqrt,
sqrtf,
};
#[cfg(all(f16_enabled, target_feature = "fp16"))]
pub use aarch64::{
rintf16,
sqrtf16,
};
}
}

View file

@ -9,6 +9,12 @@ use super::{CastFrom, CastInto, Float, Int, MinInt};
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn fma(x: f64, y: f64, z: f64) -> f64 {
select_implementation! {
name: fma,
use_arch: all(target_arch = "aarch64", target_feature = "neon"),
args: x, y, z,
}
fma_round(x, y, z, Round::Nearest).val
}

View file

@ -17,6 +17,12 @@ pub(crate) fn fmaf16(_x: f16, _y: f16, _z: f16) -> f16 {
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
select_implementation! {
name: fmaf,
use_arch: all(target_arch = "aarch64", target_feature = "neon"),
args: x, y, z,
}
fma_wide_round(x, y, z, Round::Nearest).val
}

View file

@ -4,6 +4,12 @@ use super::support::Round;
#[cfg(f16_enabled)]
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn rintf16(x: f16) -> f16 {
select_implementation! {
name: rintf16,
use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
args: x,
}
super::generic::rint_round(x, Round::Nearest).val
}
@ -13,8 +19,8 @@ pub fn rintf(x: f32) -> f32 {
select_implementation! {
name: rintf,
use_arch: any(
all(target_arch = "aarch64", target_feature = "neon"),
all(target_arch = "wasm32", intrinsics_enabled),
all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"),
),
args: x,
}
@ -28,8 +34,8 @@ pub fn rint(x: f64) -> f64 {
select_implementation! {
name: rint,
use_arch: any(
all(target_arch = "aarch64", target_feature = "neon"),
all(target_arch = "wasm32", intrinsics_enabled),
all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"),
),
args: x,
}

View file

@ -4,6 +4,7 @@ pub fn sqrt(x: f64) -> f64 {
select_implementation! {
name: sqrt,
use_arch: any(
all(target_arch = "aarch64", target_feature = "neon"),
all(target_arch = "wasm32", intrinsics_enabled),
target_feature = "sse2"
),

View file

@ -4,6 +4,7 @@ pub fn sqrtf(x: f32) -> f32 {
select_implementation! {
name: sqrtf,
use_arch: any(
all(target_arch = "aarch64", target_feature = "neon"),
all(target_arch = "wasm32", intrinsics_enabled),
target_feature = "sse2"
),

View file

@ -1,5 +1,11 @@
/// The square root of `x` (f16).
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn sqrtf16(x: f16) -> f16 {
select_implementation! {
name: sqrtf16,
use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
args: x,
}
return super::generic::sqrt(x);
}