add neon instruction vfma (#1116)

This commit is contained in:
surechen 2021-04-14 22:34:53 +08:00 committed by GitHub
parent 6405058a6f
commit aaaa9335eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 121 additions and 1 deletions

View file

@ -2784,6 +2784,32 @@ pub unsafe fn vmull_high_p8(a: poly8x16_t, b: poly8x16_t) -> poly16x8_t {
vmull_p8(a, b)
}
/// Floating-point fused Multiply-Add to accumulator(vector)
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(test, assert_instr(fmadd))]
pub unsafe fn vfma_f64(a: float64x1_t, b: float64x1_t, c: float64x1_t) -> float64x1_t {
#[allow(improper_ctypes)]
extern "C" {
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.fma.v1f64")]
fn vfma_f64_(a: float64x1_t, b: float64x1_t, c: float64x1_t) -> float64x1_t;
}
vfma_f64_(a, b, c)
}
/// Floating-point fused Multiply-Add to accumulator(vector)
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(test, assert_instr(fmla))]
pub unsafe fn vfmaq_f64(a: float64x2_t, b: float64x2_t, c: float64x2_t) -> float64x2_t {
#[allow(improper_ctypes)]
extern "C" {
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.fma.v2f64")]
fn vfmaq_f64_(a: float64x2_t, b: float64x2_t, c: float64x2_t) -> float64x2_t;
}
vfmaq_f64_(a, b, c)
}
/// Divide
#[inline]
#[target_feature(enable = "neon")]
@ -7233,6 +7259,26 @@ mod test {
assert_eq!(r, e);
}
#[simd_test(enable = "neon")]
unsafe fn test_vfma_f64() {
let a: f64 = 2.0;
let b: f64 = 6.0;
let c: f64 = 8.0;
let e: f64 = 20.0;
let r: f64 = transmute(vfma_f64(transmute(a), transmute(b), transmute(c)));
assert_eq!(r, e);
}
#[simd_test(enable = "neon")]
unsafe fn test_vfmaq_f64() {
let a: f64x2 = f64x2::new(2.0, 3.0);
let b: f64x2 = f64x2::new(6.0, 4.0);
let c: f64x2 = f64x2::new(8.0, 18.0);
let e: f64x2 = f64x2::new(20.0, 30.0);
let r: f64x2 = transmute(vfmaq_f64(transmute(a), transmute(b), transmute(c)));
assert_eq!(r, e);
}
#[simd_test(enable = "neon")]
unsafe fn test_vdiv_f32() {
let a: f32x2 = f32x2::new(2.0, 6.0);

View file

@ -4706,6 +4706,38 @@ pub unsafe fn vmull_p8(a: poly8x8_t, b: poly8x8_t) -> poly16x8_t {
vmull_p8_(a, b)
}
/// Floating-point fused Multiply-Add to accumulator(vector)
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
pub unsafe fn vfma_f32(a: float32x2_t, b: float32x2_t, c: float32x2_t) -> float32x2_t {
#[allow(improper_ctypes)]
extern "C" {
#[cfg_attr(target_arch = "arm", link_name = "llvm.fma.v2f32")]
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.fma.v2f32")]
fn vfma_f32_(a: float32x2_t, b: float32x2_t, c: float32x2_t) -> float32x2_t;
}
vfma_f32_(a, b, c)
}
/// Floating-point fused Multiply-Add to accumulator(vector)
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
pub unsafe fn vfmaq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
#[allow(improper_ctypes)]
extern "C" {
#[cfg_attr(target_arch = "arm", link_name = "llvm.fma.v4f32")]
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.fma.v4f32")]
fn vfmaq_f32_(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t;
}
vfmaq_f32_(a, b, c)
}
/// Subtract
#[inline]
#[target_feature(enable = "neon")]
@ -12642,6 +12674,26 @@ mod test {
assert_eq!(r, e);
}
#[simd_test(enable = "neon")]
unsafe fn test_vfma_f32() {
let a: f32x2 = f32x2::new(2.0, 3.0);
let b: f32x2 = f32x2::new(6.0, 4.0);
let c: f32x2 = f32x2::new(8.0, 18.0);
let e: f32x2 = f32x2::new(20.0, 30.0);
let r: f32x2 = transmute(vfma_f32(transmute(a), transmute(b), transmute(c)));
assert_eq!(r, e);
}
#[simd_test(enable = "neon")]
unsafe fn test_vfmaq_f32() {
let a: f32x4 = f32x4::new(2.0, 3.0, 4.0, 5.0);
let b: f32x4 = f32x4::new(6.0, 4.0, 7.0, 8.0);
let c: f32x4 = f32x4::new(8.0, 18.0, 12.0, 10.0);
let e: f32x4 = f32x4::new(20.0, 30.0, 40.0, 50.0);
let r: f32x4 = transmute(vfmaq_f32(transmute(a), transmute(b), transmute(c)));
assert_eq!(r, e);
}
#[simd_test(enable = "neon")]
unsafe fn test_vsub_s8() {
let a: i8x8 = i8x8::new(1, 2, 3, 4, 5, 6, 7, 8);

View file

@ -1544,6 +1544,28 @@ validate 9, 30, 11, 20, 13, 18, 15, 48
aarch64 = pmull
generate poly8x16_t:poly8x16_t:poly16x8_t
/// Floating-point fused Multiply-Add to accumulator(vector)
name = vfma
a = 2.0, 3.0, 4.0, 5.0
b = 6.0, 4.0, 7.0, 8.0
c = 8.0, 18.0, 12.0, 10.0
validate 20.0, 30.0, 40.0, 50.0
aarch64 = fmadd
link-aarch64 = llvm.fma._EXT_
generate float64x1_t
aarch64 = fmla
link-aarch64 = llvm.fma._EXT_
generate float64x2_t
target = fp-armv8
arm = vfma
aarch64 = fmla
link-arm = llvm.fma._EXT_
link-aarch64 = llvm.fma._EXT_
generate float*_t
/// Divide
name = vdiv
fn = simd_div

View file

@ -1238,7 +1238,7 @@ fn gen_arm(
),
(0, 3, _) => format!(
r#"pub unsafe fn {}{}(a: {}, b: {}, c: {}) -> {} {{
{}{}(a, b)
{}{}(a, b, c)
}}"#,
name, const_declare, in_t[0], in_t[1], in_t[2], out_t, ext_c, current_fn,
),