implement different types of parameters and double suffixes in code generator (#1083)

This commit is contained in:
Sparrow Li 2021-03-16 02:45:51 +08:00 committed by GitHub
parent 6759925278
commit bb84df7d9f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 263 additions and 26 deletions

View file

@ -1099,6 +1099,60 @@ pub unsafe fn vcaleq_f64(a: float64x2_t, b: float64x2_t) -> uint64x2_t {
vcageq_f64(b, a)
}
/// Floating-point convert to higher precision long
///
/// Casts the two `f32` lanes of `a` into a two-lane `f64` vector with a
/// single `simd_cast`.
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(test, assert_instr(fcvtl))]
pub unsafe fn vcvt_f64_f32(a: float32x2_t) -> float64x2_t {
    // Bind with an explicit type so the target vector type is visible at the cast site.
    let widened: float64x2_t = simd_cast(a);
    widened
}
/// Floating-point convert to higher precision long
///
/// Picks lanes 2 and 3 of the four-lane input `a` and casts that pair to a
/// two-lane `f64` vector.
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(test, assert_instr(fcvtl))]
pub unsafe fn vcvt_high_f64_f32(a: float32x4_t) -> float64x2_t {
    // Indices 2 and 3 address the upper half of `a`; both shuffle operands are `a` itself.
    let upper_half: float32x2_t = simd_shuffle2(a, a, [2, 3]);
    simd_cast(upper_half)
}
/// Floating-point convert to lower precision narrow
///
/// Casts the two `f64` lanes of `a` into a two-lane `f32` vector with a
/// single `simd_cast`.
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(test, assert_instr(fcvtn))]
pub unsafe fn vcvt_f32_f64(a: float64x2_t) -> float32x2_t {
    // Bind with an explicit type so the target vector type is visible at the cast site.
    let narrowed: float32x2_t = simd_cast(a);
    narrowed
}
/// Floating-point convert to lower precision narrow
///
/// Casts `b` to a two-lane `f32` vector and concatenates it after `a`:
/// result lanes 0-1 come from `a`, lanes 2-3 from the converted `b`.
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(test, assert_instr(fcvtn))]
pub unsafe fn vcvt_high_f32_f64(a: float32x2_t, b: float64x2_t) -> float32x4_t {
    let narrowed: float32x2_t = simd_cast(b);
    // Shuffle indices 0-1 address `a`, 2-3 address the second operand.
    simd_shuffle4(a, narrowed, [0, 1, 2, 3])
}
/// Floating-point convert to lower precision narrow, rounding to odd
///
/// Thin wrapper: passes `a` straight to the locally declared LLVM intrinsic
/// and returns its result unchanged.
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(test, assert_instr(fcvtxn))]
pub unsafe fn vcvtx_f32_f64(a: float64x2_t) -> float32x2_t {
// The SIMD vector types are not FFI-safe as far as the lint is concerned,
// hence the allow on this declaration.
#[allow(improper_ctypes)]
extern "C" {
// On aarch64 the declaration is bound to the LLVM fcvtxn intrinsic for
// v2f64 -> v2f32 via `link_name`.
#[cfg_attr(target_arch = "aarch64", link_name = "llvm.aarch64.neon.fcvtxn.v2f32.v2f64")]
fn vcvtx_f32_f64_(a: float64x2_t) -> float32x2_t;
}
vcvtx_f32_f64_(a)
}
/// Floating-point convert to lower precision narrow, rounding to odd
///
/// Converts `b` with `vcvtx_f32_f64` and concatenates the result after `a`:
/// result lanes 0-1 come from `a`, lanes 2-3 from the converted `b`.
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(test, assert_instr(fcvtxn))]
pub unsafe fn vcvtx_high_f32_f64(a: float32x2_t, b: float64x2_t) -> float32x4_t {
    let narrowed: float32x2_t = vcvtx_f32_f64(b);
    // Shuffle indices 0-1 address `a`, 2-3 address the second operand.
    simd_shuffle4(a, narrowed, [0, 1, 2, 3])
}
/// Multiply
#[inline]
#[target_feature(enable = "neon")]
@ -2366,6 +2420,56 @@ mod test {
assert_eq!(r, e);
}
#[simd_test(enable = "neon")]
unsafe fn test_vcvt_f64_f32() {
    // Expected lanes are the same f32 literals widened with `as f64`.
    let input: f32x2 = f32x2::new(-1.2, 1.2);
    let expected: f64x2 = f64x2::new(-1.2f32 as f64, 1.2f32 as f64);
    let actual: f64x2 = transmute(vcvt_f64_f32(transmute(input)));
    assert_eq!(actual, expected);
}
#[simd_test(enable = "neon")]
unsafe fn test_vcvt_high_f64_f32() {
    // Only the last two input lanes (2.3, 3.4) are expected in the output.
    let input: f32x4 = f32x4::new(-1.2, 1.2, 2.3, 3.4);
    let expected: f64x2 = f64x2::new(2.3f32 as f64, 3.4f32 as f64);
    let actual: f64x2 = transmute(vcvt_high_f64_f32(transmute(input)));
    assert_eq!(actual, expected);
}
#[simd_test(enable = "neon")]
unsafe fn test_vcvt_f32_f64() {
    // Expected lanes are the same f64 literals narrowed with `as f32`.
    let input: f64x2 = f64x2::new(-1.2, 1.2);
    let expected: f32x2 = f32x2::new(-1.2f64 as f32, 1.2f64 as f32);
    let actual: f32x2 = transmute(vcvt_f32_f64(transmute(input)));
    assert_eq!(actual, expected);
}
#[simd_test(enable = "neon")]
unsafe fn test_vcvt_high_f32_f64() {
    // `low` is expected unchanged in lanes 0-1; `wide` is narrowed into lanes 2-3.
    let low: f32x2 = f32x2::new(-1.2, 1.2);
    let wide: f64x2 = f64x2::new(-2.3, 3.4);
    let expected: f32x4 = f32x4::new(-1.2, 1.2, -2.3f64 as f32, 3.4f64 as f32);
    let actual: f32x4 = transmute(vcvt_high_f32_f64(transmute(low), transmute(wide)));
    assert_eq!(actual, expected);
}
#[simd_test(enable = "neon")]
unsafe fn test_vcvtx_f32_f64() {
    // The expected output carries the same lane values as the input.
    let input: f64x2 = f64x2::new(-1.0, 2.0);
    let expected: f32x2 = f32x2::new(-1.0, 2.0);
    let actual: f32x2 = transmute(vcvtx_f32_f64(transmute(input)));
    assert_eq!(actual, expected);
}
#[simd_test(enable = "neon")]
unsafe fn test_vcvtx_high_f32_f64() {
    // `low` is expected unchanged in lanes 0-1; `wide` is converted into lanes 2-3.
    let low: f32x2 = f32x2::new(-1.0, 2.0);
    let wide: f64x2 = f64x2::new(-3.0, 4.0);
    let expected: f32x4 = f32x4::new(-1.0, 2.0, -3.0, 4.0);
    let actual: f32x4 = transmute(vcvtx_high_f32_f64(transmute(low), transmute(wide)));
    assert_eq!(actual, expected);
}
#[simd_test(enable = "neon")]
unsafe fn test_vmul_f64() {
let a: f64 = 1.0;

View file

@ -527,7 +527,7 @@ generate int*_t
/// Unsigned count leading zero bits
name = vclz
multi_fn = transmute, [self-signed-ext, transmute(a)]
multi_fn = transmute, {self-signed-ext, transmute(a)}
a = MIN, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, MAX
validate BITS, BITS, BITS_M1, BITS_M1, BITS_M1, BITS_M1, BITS_M1, BITS_M1, BITS_M1, BITS_M1, BITS_M1, BITS_M1, BITS_M1, BITS_M1, BITS_M1, 0
@ -589,6 +589,69 @@ generate float64x1_t:uint64x1_t, float64x2_t:uint64x2_t
arm = vacge.s
generate float32x2_t:uint32x2_t, float32x4_t:uint32x4_t
/// Floating-point convert to higher precision long
name = vcvt
double-suffixes
fn = simd_cast
a = -1.2, 1.2
validate -1.2f32 as f64, 1.2f32 as f64
aarch64 = fcvtl
generate float32x2_t:float64x2_t
/// Floating-point convert to higher precision long
name = vcvt_high
double-suffixes
multi_fn = simd_shuffle2, b:float32x2_t, a, a, [2, 3]
multi_fn = simd_cast, b
a = -1.2, 1.2, 2.3, 3.4
validate 2.3f32 as f64, 3.4f32 as f64
aarch64 = fcvtl
generate float32x4_t:float64x2_t
/// Floating-point convert to lower precision narrow
name = vcvt
double-suffixes
fn = simd_cast
a = -1.2, 1.2
validate -1.2f64 as f32, 1.2f64 as f32
aarch64 = fcvtn
generate float64x2_t:float32x2_t
/// Floating-point convert to lower precision narrow
name = vcvt_high
double-suffixes
multi_fn = simd_shuffle4, a, {simd_cast, b}, [0, 1, 2, 3]
a = -1.2, 1.2
b = -2.3, 3.4
validate -1.2, 1.2, -2.3f64 as f32, 3.4f64 as f32
aarch64 = fcvtn
generate float32x2_t:float64x2_t:float32x4_t
/// Floating-point convert to lower precision narrow, rounding to odd
name = vcvtx
double-suffixes
a = -1.0, 2.0
validate -1.0, 2.0
aarch64 = fcvtxn
link-aarch64 = fcvtxn._EXT2_._EXT_
generate float64x2_t:float32x2_t
/// Floating-point convert to lower precision narrow, rounding to odd
name = vcvtx_high
double-suffixes
multi_fn = simd_shuffle4, a, {vcvtx-doubleself-noext, b}, [0, 1, 2, 3]
a = -1.0, 2.0
b = -3.0, 4.0
validate -1.0, 2.0, -3.0, 4.0
aarch64 = fcvtxn
generate float32x2_t:float64x2_t:float32x4_t
/// Saturating subtract
name = vqsub
a = 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42

View file

@ -160,6 +160,20 @@ fn type_to_unsigned_suffix(t: &str) -> &str {
}
}
/// Returns the two-type intrinsic name suffix for an output/input vector type pair.
///
/// The suffix names the output element type first and the input element type
/// second (for example `_f32_f64`). The `int32x4_t` and `int64x2_t` outputs
/// additionally carry a leading `q`.
///
/// # Panics
///
/// Panics with `unknown type: <out_t>, <in_t>` when the pair is not listed.
fn type_to_double_suffixes<'a>(out_t: &'a str, in_t: &'a str) -> &'a str {
    // Each row is (output type, input type, suffix).
    const DOUBLE_SUFFIXES: &[(&str, &str, &str)] = &[
        ("float32x2_t", "float64x2_t", "_f32_f64"),
        ("float64x2_t", "float32x2_t", "_f64_f32"),
        ("float64x2_t", "float32x4_t", "_f64_f32"),
        ("float32x4_t", "float64x2_t", "_f32_f64"),
        ("int32x2_t", "float32x2_t", "_s32_f32"),
        ("int32x4_t", "float32x4_t", "q_s32_f32"),
        ("int64x1_t", "float64x1_t", "_s64_f64"),
        ("int64x2_t", "float64x2_t", "q_s64_f64"),
    ];

    DOUBLE_SUFFIXES
        .iter()
        .find(|(known_out, known_in, _)| *known_out == out_t && *known_in == in_t)
        .map(|&(_, _, suffix)| suffix)
        .unwrap_or_else(|| panic!("unknown type: {}, {}", out_t, in_t))
}
fn type_to_global_type(t: &str) -> &str {
match t {
"int8x8_t" => "i8x8",
@ -388,15 +402,21 @@ fn gen_aarch64(
current_aarch64: &Option<String>,
link_aarch64: &Option<String>,
in_t: &str,
in_t2: &str,
out_t: &str,
current_tests: &[(Vec<String>, Vec<String>, Vec<String>)],
double_suffixes: bool,
para_num: i32,
fixed: &Vec<String>,
multi_fn: &Vec<String>,
) -> (String, String) {
let _global_t = type_to_global_type(in_t);
let _global_ret_t = type_to_global_type(out_t);
let name = format!("{}{}", current_name, type_to_suffix(in_t));
let name = if double_suffixes {
format!("{}{}", current_name, type_to_double_suffixes(out_t, in_t2))
} else {
format!("{}{}", current_name, type_to_suffix(in_t2))
};
let current_fn = if let Some(current_fn) = current_fn.clone() {
if link_aarch64.is_some() {
panic!(
@ -440,7 +460,7 @@ fn gen_aarch64(
format!("a: {}", in_t)
}
2 => {
format!("a: {}, b: {}", in_t, in_t)
format!("a: {}, b: {}", in_t, in_t2)
}
_ => unimplemented!("unknown para_num"),
},
@ -455,7 +475,14 @@ fn gen_aarch64(
if i > 0 {
calls.push_str("\n ");
}
calls.push_str(&get_call(&multi_fn[i], current_name, in_t, out_t, fixed));
calls.push_str(&get_call(
&multi_fn[i],
current_name,
in_t,
in_t2,
out_t,
fixed,
));
}
calls
} else {
@ -466,7 +493,7 @@ fn gen_aarch64(
r#"pub unsafe fn {}(a: {}, b: {}) -> {} {{
{}{}(a, b)
}}"#,
name, in_t, in_t, out_t, ext_c, current_fn,
name, in_t, in_t2, out_t, ext_c, current_fn,
),
(0, 1, 0) => format!(
r#"pub unsafe fn {}(a: {}) -> {} {{
@ -499,7 +526,7 @@ fn gen_aarch64(
r#"pub unsafe fn {}(a: {}, b: {}) -> {} {{
{}{}
}}"#,
name, in_t, in_t, out_t, ext_c, multi_calls,
name, in_t, in_t2, out_t, ext_c, multi_calls,
),
(_, _, _) => String::new(),
};
@ -517,9 +544,12 @@ fn gen_aarch64(
let test = gen_test(
&name,
&in_t,
&in_t2,
&out_t,
current_tests,
type_len(in_t),
type_len(in_t2),
type_len(out_t),
para_num,
);
(function, test)
@ -528,9 +558,12 @@ fn gen_aarch64(
fn gen_test(
name: &str,
in_t: &str,
in_t2: &str,
out_t: &str,
current_tests: &[(Vec<String>, Vec<String>, Vec<String>)],
len: usize,
len_in: usize,
len_in2: usize,
len_out: usize,
para_num: i32,
) -> String {
let mut test = format!(
@ -540,9 +573,9 @@ fn gen_test(
name,
);
for (a, b, e) in current_tests {
let a: Vec<String> = a.iter().take(len).cloned().collect();
let b: Vec<String> = b.iter().take(len).cloned().collect();
let e: Vec<String> = e.iter().take(len).cloned().collect();
let a: Vec<String> = a.iter().take(len_in).cloned().collect();
let b: Vec<String> = b.iter().take(len_in2).cloned().collect();
let e: Vec<String> = e.iter().take(len_out).cloned().collect();
let t = {
match para_num {
1 => {
@ -569,7 +602,7 @@ fn gen_test(
assert_eq!(r, e);
"#,
values(in_t, &a),
values(in_t, &b),
values(in_t2, &b),
values(out_t, &e),
type_to_global_type(out_t),
name
@ -597,15 +630,21 @@ fn gen_arm(
current_aarch64: &Option<String>,
link_aarch64: &Option<String>,
in_t: &str,
in_t2: &str,
out_t: &str,
current_tests: &[(Vec<String>, Vec<String>, Vec<String>)],
double_suffixes: bool,
para_num: i32,
fixed: &Vec<String>,
multi_fn: &Vec<String>,
) -> (String, String) {
let _global_t = type_to_global_type(in_t);
let _global_ret_t = type_to_global_type(out_t);
let name = format!("{}{}", current_name, type_to_suffix(in_t));
let name = if double_suffixes {
format!("{}{}", current_name, type_to_double_suffixes(out_t, in_t2))
} else {
format!("{}{}", current_name, type_to_suffix(in_t2))
};
let current_aarch64 = current_aarch64
.clone()
.unwrap_or_else(|| current_arm.to_string());
@ -655,7 +694,7 @@ fn gen_arm(
format!("a: {}", in_t)
}
2 => {
format!("a: {}, b: {}", in_t, in_t)
format!("a: {}, b: {}", in_t, in_t2)
}
_ => unimplemented!("unknown para_num"),
},
@ -670,7 +709,14 @@ fn gen_arm(
if i > 0 {
calls.push_str("\n ");
}
calls.push_str(&get_call(&multi_fn[i], current_name, in_t, out_t, fixed));
calls.push_str(&get_call(
&multi_fn[i],
current_name,
in_t,
in_t2,
out_t,
fixed,
));
}
calls
} else {
@ -681,7 +727,7 @@ fn gen_arm(
r#"pub unsafe fn {}(a: {}, b: {}) -> {} {{
{}{}(a, b)
}}"#,
name, in_t, in_t, out_t, ext_c, current_fn,
name, in_t, in_t2, out_t, ext_c, current_fn,
),
(0, 1, 0) => format!(
r#"pub unsafe fn {}(a: {}) -> {} {{
@ -714,7 +760,7 @@ fn gen_arm(
r#"pub unsafe fn {}(a: {}, b: {}) -> {} {{
{}{}
}}"#,
name, in_t, in_t, out_t, ext_c, multi_calls,
name, in_t, in_t2, out_t, ext_c, multi_calls,
),
(_, _, _) => String::new(),
};
@ -736,9 +782,12 @@ fn gen_arm(
let test = gen_test(
&name,
&in_t,
&in_t2,
&out_t,
current_tests,
type_len(in_t),
type_len(in_t2),
type_len(out_t),
para_num,
);
@ -819,6 +868,7 @@ fn get_call(
in_str: &str,
current_name: &str,
in_t: &str,
in_t2: &str,
out_t: &str,
fixed: &Vec<String>,
) -> String {
@ -830,20 +880,20 @@ fn get_call(
let mut i = 1;
while i < params.len() {
let s = &params[i];
if s.starts_with('[') {
if s.starts_with('{') {
let mut sub_fn = String::new();
let mut brackets = 1;
let mut paranthes = 0;
while i < params.len() {
if !sub_fn.is_empty() {
sub_fn.push_str(", ");
}
sub_fn.push_str(&params[i]);
if params[i].starts_with('[') {
brackets += 1;
if params[i].starts_with('{') {
paranthes += 1;
}
if params[i].ends_with("]") {
brackets -= 1;
if brackets == 0 {
if params[i].ends_with('}') {
paranthes -= 1;
if paranthes == 0 {
break;
}
}
@ -853,6 +903,7 @@ fn get_call(
&sub_fn[1..sub_fn.len() - 1],
current_name,
in_t,
in_t2,
out_t,
fixed,
);
@ -868,6 +919,8 @@ fn get_call(
re = Some((re_params[0].clone(), in_t.to_string()));
} else if re_params[1] == "out_t" {
re = Some((re_params[0].clone(), out_t.to_string()));
} else {
re = Some((re_params[0].clone(), re_params[1].clone()));
}
} else {
if !param_str.is_empty() {
@ -891,11 +944,13 @@ fn get_call(
fn_format[0].clone()
};
if fn_format[1] == "self" {
fn_name.push_str(type_to_suffix(in_t));
fn_name.push_str(type_to_suffix(in_t2));
} else if fn_format[1] == "signed" {
fn_name.push_str(type_to_signed_suffix(in_t));
fn_name.push_str(type_to_signed_suffix(in_t2));
} else if fn_format[1] == "unsigned" {
fn_name.push_str(type_to_unsigned_suffix(in_t));
fn_name.push_str(type_to_unsigned_suffix(in_t2));
} else if fn_format[1] == "doubleself" {
fn_name.push_str(type_to_double_suffixes(out_t, in_t2));
} else {
fn_name.push_str(&fn_format[1]);
};
@ -932,6 +987,7 @@ fn main() -> io::Result<()> {
let mut link_arm: Option<String> = None;
let mut link_aarch64: Option<String> = None;
let mut para_num = 2;
let mut double_suffixes = false;
let mut a: Vec<String> = Vec::new();
let mut b: Vec<String> = Vec::new();
let mut fixed: Vec<String> = Vec::new();
@ -1007,6 +1063,7 @@ mod test {
link_arm = None;
current_tests = Vec::new();
para_num = 2;
double_suffixes = false;
a = Vec::new();
b = Vec::new();
fixed = Vec::new();
@ -1022,6 +1079,8 @@ mod test {
current_arm = Some(String::from(&line[6..]));
} else if line.starts_with("aarch64 = ") {
current_aarch64 = Some(String::from(&line[10..]));
} else if line.starts_with("double-suffixes") {
double_suffixes = true;
} else if line.starts_with("a = ") {
a = line[4..].split(',').map(|v| v.trim().to_string()).collect();
} else if line.starts_with("b = ") {
@ -1054,13 +1113,20 @@ mod test {
for line in types {
let spec: Vec<&str> = line.split(':').map(|e| e.trim()).collect();
let in_t;
let in_t2;
let out_t;
if spec.len() == 1 {
in_t = spec[0];
in_t2 = spec[0];
out_t = spec[0];
} else if spec.len() == 2 {
in_t = spec[0];
in_t2 = spec[0];
out_t = spec[1];
} else if spec.len() == 3 {
in_t = spec[0];
in_t2 = spec[1];
out_t = spec[2];
} else {
panic!("Bad spec: {}", line)
}
@ -1078,8 +1144,10 @@ mod test {
&current_aarch64,
&link_aarch64,
&in_t,
&in_t2,
&out_t,
&current_tests,
double_suffixes,
para_num,
&fixed,
&multi_fn,
@ -1094,8 +1162,10 @@ mod test {
&current_aarch64,
&link_aarch64,
&in_t,
&in_t2,
&out_t,
&current_tests,
double_suffixes,
para_num,
&fixed,
&multi_fn,