Add assembly version of simple operations on aarch64
Replace `core::arch` versions of the following with handwritten assembly, which avoids recursion issues (cg_gcc using `rint` as a fallback) as well as problems with `aarch64be`. * `rint` * `rintf` Additionally, add assembly versions of the following: * `fma` * `fmaf` * `sqrt` * `sqrtf` If the `fp16` target feature is available, which implies `neon`, also include the following: * `rintf16` * `sqrtf16` `sqrt` is added to match the implementation for `x86`. `fma` is included since it is used by many other routines. There are a handful of other operations that have assembly implementations. They are omitted here because we should have basic float math routines available in `core` in the near future, which will allow us to defer to LLVM for assembly lowering rather than implementing these ourselves.
This commit is contained in:
parent
375cb5402f
commit
28b6df8603
9 changed files with 161 additions and 34 deletions
|
|
@ -342,12 +342,14 @@
|
|||
},
|
||||
"fma": {
|
||||
"sources": [
|
||||
"src/math/arch/aarch64.rs",
|
||||
"src/math/fma.rs"
|
||||
],
|
||||
"type": "f64"
|
||||
},
|
||||
"fmaf": {
|
||||
"sources": [
|
||||
"src/math/arch/aarch64.rs",
|
||||
"src/math/fma_wide.rs"
|
||||
],
|
||||
"type": "f32"
|
||||
|
|
@ -806,6 +808,7 @@
|
|||
},
|
||||
"rintf16": {
|
||||
"sources": [
|
||||
"src/math/arch/aarch64.rs",
|
||||
"src/math/rint.rs"
|
||||
],
|
||||
"type": "f16"
|
||||
|
|
@ -928,6 +931,7 @@
|
|||
},
|
||||
"sqrt": {
|
||||
"sources": [
|
||||
"src/math/arch/aarch64.rs",
|
||||
"src/math/arch/i686.rs",
|
||||
"src/math/arch/wasm32.rs",
|
||||
"src/math/generic/sqrt.rs",
|
||||
|
|
@ -937,6 +941,7 @@
|
|||
},
|
||||
"sqrtf": {
|
||||
"sources": [
|
||||
"src/math/arch/aarch64.rs",
|
||||
"src/math/arch/i686.rs",
|
||||
"src/math/arch/wasm32.rs",
|
||||
"src/math/generic/sqrt.rs",
|
||||
|
|
@ -953,6 +958,7 @@
|
|||
},
|
||||
"sqrtf16": {
|
||||
"sources": [
|
||||
"src/math/arch/aarch64.rs",
|
||||
"src/math/generic/sqrt.rs",
|
||||
"src/math/sqrtf16.rs"
|
||||
],
|
||||
|
|
|
|||
|
|
@ -1,33 +1,115 @@
|
|||
use core::arch::aarch64::{
|
||||
float32x2_t, float64x1_t, vdup_n_f32, vdup_n_f64, vget_lane_f32, vget_lane_f64, vrndn_f32,
|
||||
vrndn_f64,
|
||||
};
|
||||
//! Architecture-specific support for aarch64 with neon.
|
||||
|
||||
pub fn rint(x: f64) -> f64 {
|
||||
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
|
||||
let x_vec: float64x1_t = unsafe { vdup_n_f64(x) };
|
||||
use core::arch::asm;
|
||||
|
||||
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
|
||||
let result_vec: float64x1_t = unsafe { vrndn_f64(x_vec) };
|
||||
|
||||
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
|
||||
let result: f64 = unsafe { vget_lane_f64::<0>(result_vec) };
|
||||
|
||||
result
|
||||
pub fn fma(mut x: f64, y: f64, z: f64) -> f64 {
|
||||
// SAFETY: `fmadd` is available with neon and has no side effects.
|
||||
unsafe {
|
||||
asm!(
|
||||
"fmadd {x:d}, {x:d}, {y:d}, {z:d}",
|
||||
x = inout(vreg) x,
|
||||
y = in(vreg) y,
|
||||
z = in(vreg) z,
|
||||
options(nomem, nostack, pure)
|
||||
);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
pub fn rintf(x: f32) -> f32 {
|
||||
// There's a scalar form of this instruction (FRINTN) but core::arch doesn't expose it, so we
|
||||
// have to use the vector form and drop the other lanes afterwards.
|
||||
|
||||
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
|
||||
let x_vec: float32x2_t = unsafe { vdup_n_f32(x) };
|
||||
|
||||
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
|
||||
let result_vec: float32x2_t = unsafe { vrndn_f32(x_vec) };
|
||||
|
||||
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
|
||||
let result: f32 = unsafe { vget_lane_f32::<0>(result_vec) };
|
||||
|
||||
result
|
||||
pub fn fmaf(mut x: f32, y: f32, z: f32) -> f32 {
|
||||
// SAFETY: `fmadd` is available with neon and has no side effects.
|
||||
unsafe {
|
||||
asm!(
|
||||
"fmadd {x:s}, {x:s}, {y:s}, {z:s}",
|
||||
x = inout(vreg) x,
|
||||
y = in(vreg) y,
|
||||
z = in(vreg) z,
|
||||
options(nomem, nostack, pure)
|
||||
);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
pub fn rint(mut x: f64) -> f64 {
|
||||
// SAFETY: `frintn` is available with neon and has no side effects.
|
||||
//
|
||||
// `frintn` is always round-to-nearest which does not match the C specification, but Rust does
|
||||
// not support rounding modes.
|
||||
unsafe {
|
||||
asm!(
|
||||
"frintn {x:d}, {x:d}",
|
||||
x = inout(vreg) x,
|
||||
options(nomem, nostack, pure)
|
||||
);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
pub fn rintf(mut x: f32) -> f32 {
|
||||
// SAFETY: `frintn` is available with neon and has no side effects.
|
||||
//
|
||||
// `frintn` is always round-to-nearest which does not match the C specification, but Rust does
|
||||
// not support rounding modes.
|
||||
unsafe {
|
||||
asm!(
|
||||
"frintn {x:s}, {x:s}",
|
||||
x = inout(vreg) x,
|
||||
options(nomem, nostack, pure)
|
||||
);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
#[cfg(all(f16_enabled, target_feature = "fp16"))]
|
||||
pub fn rintf16(mut x: f16) -> f16 {
|
||||
// SAFETY: `frintn` is available for `f16` with `fp16` (implies `neon`) and has no side effects.
|
||||
//
|
||||
// `frintn` is always round-to-nearest which does not match the C specification, but Rust does
|
||||
// not support rounding modes.
|
||||
unsafe {
|
||||
asm!(
|
||||
"frintn {x:h}, {x:h}",
|
||||
x = inout(vreg) x,
|
||||
options(nomem, nostack, pure)
|
||||
);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
pub fn sqrt(mut x: f64) -> f64 {
|
||||
// SAFETY: `fsqrt` is available with neon and has no side effects.
|
||||
unsafe {
|
||||
asm!(
|
||||
"fsqrt {x:d}, {x:d}",
|
||||
x = inout(vreg) x,
|
||||
options(nomem, nostack, pure)
|
||||
);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
pub fn sqrtf(mut x: f32) -> f32 {
|
||||
// SAFETY: `fsqrt` is available with neon and has no side effects.
|
||||
unsafe {
|
||||
asm!(
|
||||
"fsqrt {x:s}, {x:s}",
|
||||
x = inout(vreg) x,
|
||||
options(nomem, nostack, pure)
|
||||
);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
#[cfg(all(f16_enabled, target_feature = "fp16"))]
|
||||
pub fn sqrtf16(mut x: f16) -> f16 {
|
||||
// SAFETY: `fsqrt` is available for `f16` with `fp16` (implies `neon`) and has no
|
||||
// side effects.
|
||||
unsafe {
|
||||
asm!(
|
||||
"fsqrt {x:h}, {x:h}",
|
||||
x = inout(vreg) x,
|
||||
options(nomem, nostack, pure)
|
||||
);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,12 +18,25 @@ cfg_if! {
|
|||
mod i686;
|
||||
pub use i686::{sqrt, sqrtf};
|
||||
} else if #[cfg(all(
|
||||
target_arch = "aarch64", // TODO: also arm64ec?
|
||||
target_feature = "neon",
|
||||
target_endian = "little", // see https://github.com/rust-lang/stdarch/issues/1484
|
||||
any(target_arch = "aarch64", target_arch = "arm64ec"),
|
||||
target_feature = "neon"
|
||||
))] {
|
||||
mod aarch64;
|
||||
pub use aarch64::{rint, rintf};
|
||||
|
||||
pub use aarch64::{
|
||||
fma,
|
||||
fmaf,
|
||||
rint,
|
||||
rintf,
|
||||
sqrt,
|
||||
sqrtf,
|
||||
};
|
||||
|
||||
#[cfg(all(f16_enabled, target_feature = "fp16"))]
|
||||
pub use aarch64::{
|
||||
rintf16,
|
||||
sqrtf16,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,12 @@ use super::{CastFrom, CastInto, Float, Int, MinInt};
|
|||
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
|
||||
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
|
||||
pub fn fma(x: f64, y: f64, z: f64) -> f64 {
|
||||
select_implementation! {
|
||||
name: fma,
|
||||
use_arch: all(target_arch = "aarch64", target_feature = "neon"),
|
||||
args: x, y, z,
|
||||
}
|
||||
|
||||
fma_round(x, y, z, Round::Nearest).val
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,12 @@ pub(crate) fn fmaf16(_x: f16, _y: f16, _z: f16) -> f16 {
|
|||
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
|
||||
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
|
||||
pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
|
||||
select_implementation! {
|
||||
name: fmaf,
|
||||
use_arch: all(target_arch = "aarch64", target_feature = "neon"),
|
||||
args: x, y, z,
|
||||
}
|
||||
|
||||
fma_wide_round(x, y, z, Round::Nearest).val
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,12 @@ use super::support::Round;
|
|||
#[cfg(f16_enabled)]
|
||||
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
|
||||
pub fn rintf16(x: f16) -> f16 {
|
||||
select_implementation! {
|
||||
name: rintf16,
|
||||
use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
|
||||
args: x,
|
||||
}
|
||||
|
||||
super::generic::rint_round(x, Round::Nearest).val
|
||||
}
|
||||
|
||||
|
|
@ -13,8 +19,8 @@ pub fn rintf(x: f32) -> f32 {
|
|||
select_implementation! {
|
||||
name: rintf,
|
||||
use_arch: any(
|
||||
all(target_arch = "aarch64", target_feature = "neon"),
|
||||
all(target_arch = "wasm32", intrinsics_enabled),
|
||||
all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"),
|
||||
),
|
||||
args: x,
|
||||
}
|
||||
|
|
@ -28,8 +34,8 @@ pub fn rint(x: f64) -> f64 {
|
|||
select_implementation! {
|
||||
name: rint,
|
||||
use_arch: any(
|
||||
all(target_arch = "aarch64", target_feature = "neon"),
|
||||
all(target_arch = "wasm32", intrinsics_enabled),
|
||||
all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"),
|
||||
),
|
||||
args: x,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ pub fn sqrt(x: f64) -> f64 {
|
|||
select_implementation! {
|
||||
name: sqrt,
|
||||
use_arch: any(
|
||||
all(target_arch = "aarch64", target_feature = "neon"),
|
||||
all(target_arch = "wasm32", intrinsics_enabled),
|
||||
target_feature = "sse2"
|
||||
),
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ pub fn sqrtf(x: f32) -> f32 {
|
|||
select_implementation! {
|
||||
name: sqrtf,
|
||||
use_arch: any(
|
||||
all(target_arch = "aarch64", target_feature = "neon"),
|
||||
all(target_arch = "wasm32", intrinsics_enabled),
|
||||
target_feature = "sse2"
|
||||
),
|
||||
|
|
|
|||
|
|
@ -1,5 +1,11 @@
|
|||
/// The square root of `x` (f16).
|
||||
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
|
||||
pub fn sqrtf16(x: f16) -> f16 {
|
||||
select_implementation! {
|
||||
name: sqrtf16,
|
||||
use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
|
||||
args: x,
|
||||
}
|
||||
|
||||
return super::generic::sqrt(x);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue