Arm Fused Multiply-Add fixes (#1219)
This commit is contained in:
parent
328553ef64
commit
504b0cf68b
4 changed files with 60 additions and 19 deletions
|
|
@ -8721,7 +8721,7 @@ pub unsafe fn vmull_laneq_u32<const LANE: i32>(a: uint32x2_t, b: uint32x4_t) ->
|
|||
/// Floating-point fused Multiply-Add to accumulator(vector)
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
|
||||
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
|
||||
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
|
||||
pub unsafe fn vfma_f32(a: float32x2_t, b: float32x2_t, c: float32x2_t) -> float32x2_t {
|
||||
|
|
@ -8737,7 +8737,7 @@ vfma_f32_(b, c, a)
|
|||
/// Floating-point fused Multiply-Add to accumulator(vector)
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
|
||||
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
|
||||
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
|
||||
pub unsafe fn vfmaq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
|
||||
|
|
@ -8753,27 +8753,27 @@ vfmaq_f32_(b, c, a)
|
|||
/// Floating-point fused Multiply-Add to accumulator(vector)
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
|
||||
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
|
||||
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
|
||||
pub unsafe fn vfma_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t {
|
||||
vfma_f32(a, b, vdup_n_f32(c))
|
||||
vfma_f32(a, b, vdup_n_f32_vfp4(c))
|
||||
}
|
||||
|
||||
/// Floating-point fused Multiply-Add to accumulator(vector)
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
|
||||
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
|
||||
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
|
||||
pub unsafe fn vfmaq_n_f32(a: float32x4_t, b: float32x4_t, c: f32) -> float32x4_t {
|
||||
vfmaq_f32(a, b, vdupq_n_f32(c))
|
||||
vfmaq_f32(a, b, vdupq_n_f32_vfp4(c))
|
||||
}
|
||||
|
||||
/// Floating-point fused multiply-subtract from accumulator
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
|
||||
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
|
||||
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
|
||||
pub unsafe fn vfms_f32(a: float32x2_t, b: float32x2_t, c: float32x2_t) -> float32x2_t {
|
||||
|
|
@ -8784,7 +8784,7 @@ pub unsafe fn vfms_f32(a: float32x2_t, b: float32x2_t, c: float32x2_t) -> float3
|
|||
/// Floating-point fused multiply-subtract from accumulator
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
|
||||
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
|
||||
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
|
||||
pub unsafe fn vfmsq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
|
||||
|
|
@ -8795,21 +8795,21 @@ pub unsafe fn vfmsq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float
|
|||
/// Floating-point fused Multiply-subtract to accumulator(vector)
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
|
||||
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
|
||||
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
|
||||
pub unsafe fn vfms_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t {
|
||||
vfms_f32(a, b, vdup_n_f32(c))
|
||||
vfms_f32(a, b, vdup_n_f32_vfp4(c))
|
||||
}
|
||||
|
||||
/// Floating-point fused Multiply-subtract to accumulator(vector)
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "fp-armv8,v8"))]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
|
||||
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
|
||||
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
|
||||
pub unsafe fn vfmsq_n_f32(a: float32x4_t, b: float32x4_t, c: f32) -> float32x4_t {
|
||||
vfmsq_f32(a, b, vdupq_n_f32(c))
|
||||
vfmsq_f32(a, b, vdupq_n_f32_vfp4(c))
|
||||
}
|
||||
|
||||
/// Subtract
|
||||
|
|
|
|||
|
|
@ -3786,6 +3786,19 @@ pub unsafe fn vdupq_n_f32(value: f32) -> float32x4_t {
|
|||
float32x4_t(value, value, value, value)
|
||||
}
|
||||
|
||||
/// Duplicate vector element to vector or scalar
|
||||
///
|
||||
/// Private vfp4 version used by FMA intriniscs because LLVM does
|
||||
/// not inline the non-vfp4 version in vfp4 functions.
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
|
||||
#[cfg_attr(all(test, target_arch = "arm"), assert_instr("vdup.32"))]
|
||||
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(dup))]
|
||||
unsafe fn vdupq_n_f32_vfp4(value: f32) -> float32x4_t {
|
||||
float32x4_t(value, value, value, value)
|
||||
}
|
||||
|
||||
/// Duplicate vector element to vector or scalar
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
|
|
@ -3896,6 +3909,19 @@ pub unsafe fn vdup_n_f32(value: f32) -> float32x2_t {
|
|||
float32x2_t(value, value)
|
||||
}
|
||||
|
||||
/// Duplicate vector element to vector or scalar
|
||||
///
|
||||
/// Private vfp4 version used by FMA intriniscs because LLVM does
|
||||
/// not inline the non-vfp4 version in vfp4 functions.
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
|
||||
#[cfg_attr(all(test, target_arch = "arm"), assert_instr("vdup.32"))]
|
||||
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(dup))]
|
||||
unsafe fn vdup_n_f32_vfp4(value: f32) -> float32x2_t {
|
||||
float32x2_t(value, value)
|
||||
}
|
||||
|
||||
/// Duplicate vector element to vector or scalar
|
||||
#[inline]
|
||||
#[target_feature(enable = "neon")]
|
||||
|
|
|
|||
|
|
@ -2733,7 +2733,7 @@ generate float64x1_t
|
|||
aarch64 = fmla
|
||||
generate float64x2_t
|
||||
|
||||
target = fp-armv8
|
||||
target = vfp4
|
||||
arm = vfma
|
||||
link-arm = llvm.fma._EXT_
|
||||
generate float*_t
|
||||
|
|
@ -2741,7 +2741,7 @@ generate float*_t
|
|||
/// Floating-point fused Multiply-Add to accumulator(vector)
|
||||
name = vfma
|
||||
n-suffix
|
||||
multi_fn = vfma-self-noext, a, b, {vdup-nself-noext, c}
|
||||
multi_fn = vfma-self-noext, a, b, {vdup-nselfvfp4-noext, c}
|
||||
a = 2.0, 3.0, 4.0, 5.0
|
||||
b = 6.0, 4.0, 7.0, 8.0
|
||||
c = 8.0
|
||||
|
|
@ -2752,7 +2752,7 @@ generate float64x1_t:float64x1_t:f64:float64x1_t
|
|||
aarch64 = fmla
|
||||
generate float64x2_t:float64x2_t:f64:float64x2_t
|
||||
|
||||
target = fp-armv8
|
||||
target = vfp4
|
||||
arm = vfma
|
||||
generate float32x2_t:float32x2_t:f32:float32x2_t, float32x4_t:float32x4_t:f32:float32x4_t
|
||||
|
||||
|
|
@ -2811,14 +2811,14 @@ generate float64x1_t
|
|||
aarch64 = fmls
|
||||
generate float64x2_t
|
||||
|
||||
target = fp-armv8
|
||||
target = vfp4
|
||||
arm = vfms
|
||||
generate float*_t
|
||||
|
||||
/// Floating-point fused Multiply-subtract to accumulator(vector)
|
||||
name = vfms
|
||||
n-suffix
|
||||
multi_fn = vfms-self-noext, a, b, {vdup-nself-noext, c}
|
||||
multi_fn = vfms-self-noext, a, b, {vdup-nselfvfp4-noext, c}
|
||||
a = 50.0, 35.0, 60.0, 69.0
|
||||
b = 6.0, 4.0, 7.0, 8.0
|
||||
c = 8.0
|
||||
|
|
@ -2829,7 +2829,7 @@ generate float64x1_t:float64x1_t:f64:float64x1_t
|
|||
aarch64 = fmls
|
||||
generate float64x2_t:float64x2_t:f64:float64x2_t
|
||||
|
||||
target = fp-armv8
|
||||
target = vfp4
|
||||
arm = vfms
|
||||
generate float32x2_t:float32x2_t:f32:float32x2_t, float32x4_t:float32x4_t:f32:float32x4_t
|
||||
|
||||
|
|
|
|||
|
|
@ -438,6 +438,7 @@ enum Suffix {
|
|||
enum TargetFeature {
|
||||
Default,
|
||||
ArmV7,
|
||||
Vfp4,
|
||||
FPArmV8,
|
||||
AES,
|
||||
}
|
||||
|
|
@ -980,6 +981,7 @@ fn gen_aarch64(
|
|||
let current_target = match target {
|
||||
Default => "neon",
|
||||
ArmV7 => "v7",
|
||||
Vfp4 => "vfp4",
|
||||
FPArmV8 => "fp-armv8,v8",
|
||||
AES => "neon,aes",
|
||||
};
|
||||
|
|
@ -1120,6 +1122,7 @@ fn gen_aarch64(
|
|||
out_t,
|
||||
fixed,
|
||||
None,
|
||||
true,
|
||||
));
|
||||
}
|
||||
calls
|
||||
|
|
@ -1630,12 +1633,14 @@ fn gen_arm(
|
|||
let current_target_aarch64 = match target {
|
||||
Default => "neon",
|
||||
ArmV7 => "neon",
|
||||
Vfp4 => "neon",
|
||||
FPArmV8 => "neon",
|
||||
AES => "neon,aes",
|
||||
};
|
||||
let current_target_arm = match target {
|
||||
Default => "v7",
|
||||
ArmV7 => "v7",
|
||||
Vfp4 => "vfp4",
|
||||
FPArmV8 => "fp-armv8,v8",
|
||||
AES => "aes,v8",
|
||||
};
|
||||
|
|
@ -1916,6 +1921,7 @@ fn gen_arm(
|
|||
out_t,
|
||||
fixed,
|
||||
None,
|
||||
false,
|
||||
));
|
||||
}
|
||||
calls
|
||||
|
|
@ -2283,6 +2289,7 @@ fn get_call(
|
|||
out_t: &str,
|
||||
fixed: &Vec<String>,
|
||||
n: Option<i32>,
|
||||
aarch64: bool,
|
||||
) -> String {
|
||||
let params: Vec<_> = in_str.split(',').map(|v| v.trim().to_string()).collect();
|
||||
assert!(params.len() > 0);
|
||||
|
|
@ -2450,7 +2457,8 @@ fn get_call(
|
|||
in_t,
|
||||
out_t,
|
||||
fixed,
|
||||
Some(i as i32)
|
||||
Some(i as i32),
|
||||
aarch64
|
||||
)
|
||||
);
|
||||
call.push_str(&sub_match);
|
||||
|
|
@ -2499,6 +2507,7 @@ fn get_call(
|
|||
out_t,
|
||||
fixed,
|
||||
n.clone(),
|
||||
aarch64,
|
||||
);
|
||||
if !param_str.is_empty() {
|
||||
param_str.push_str(", ");
|
||||
|
|
@ -2569,6 +2578,11 @@ fn get_call(
|
|||
fn_name.push_str(type_to_suffix(in_t[1]));
|
||||
} else if fn_format[1] == "nself" {
|
||||
fn_name.push_str(type_to_n_suffix(in_t[1]));
|
||||
} else if fn_format[1] == "nselfvfp4" {
|
||||
fn_name.push_str(type_to_n_suffix(in_t[1]));
|
||||
if !aarch64 {
|
||||
fn_name.push_str("_vfp4");
|
||||
}
|
||||
} else if fn_format[1] == "out" {
|
||||
fn_name.push_str(type_to_suffix(out_t));
|
||||
} else if fn_format[1] == "in0" {
|
||||
|
|
@ -2854,6 +2868,7 @@ mod test {
|
|||
target = match Some(String::from(&line[9..])) {
|
||||
Some(input) => match input.as_str() {
|
||||
"v7" => ArmV7,
|
||||
"vfp4" => Vfp4,
|
||||
"fp-armv8" => FPArmV8,
|
||||
"aes" => AES,
|
||||
_ => Default,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue