Convert _mm256_blend_{ps,pd} to const generics (#1058)

This commit is contained in:
tmiasko 2021-03-07 16:20:05 +01:00 committed by GitHub
parent e54e113b05
commit 1f1bfc92df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -458,44 +458,21 @@ pub unsafe fn _mm256_sqrt_pd(a: __m256d) -> __m256d {
// Note: LLVM7 prefers single-precision blend instructions when
// possible, see: https://bugs.llvm.org/show_bug.cgi?id=38194
// #[cfg_attr(test, assert_instr(vblendpd, imm8 = 9))]
#[cfg_attr(test, assert_instr(vblendps, imm8 = 9))]
#[rustc_args_required_const(2)]
#[cfg_attr(test, assert_instr(vblendps, IMM4 = 9))]
#[rustc_legacy_const_generics(2)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_blend_pd(a: __m256d, b: __m256d, imm8: i32) -> __m256d {
let imm8 = (imm8 & 0xFF) as u8;
macro_rules! blend4 {
($a:expr, $b:expr, $c:expr, $d:expr) => {
simd_shuffle4(a, b, [$a, $b, $c, $d])
};
}
macro_rules! blend3 {
($a:expr, $b:expr, $c:expr) => {
match imm8 & 0x8 {
0 => blend4!($a, $b, $c, 3),
_ => blend4!($a, $b, $c, 7),
}
};
}
macro_rules! blend2 {
($a:expr, $b:expr) => {
match imm8 & 0x4 {
0 => blend3!($a, $b, 2),
_ => blend3!($a, $b, 6),
}
};
}
macro_rules! blend1 {
($a:expr) => {
match imm8 & 0x2 {
0 => blend2!($a, 1),
_ => blend2!($a, 5),
}
};
}
match imm8 & 0x1 {
0 => blend1!(0),
_ => blend1!(4),
}
pub unsafe fn _mm256_blend_pd<const IMM4: i32>(a: __m256d, b: __m256d) -> __m256d {
static_assert_imm4!(IMM4);
simd_shuffle4(
a,
b,
[
((IMM4 as u32 >> 0) & 1) * 4 + 0,
((IMM4 as u32 >> 1) & 1) * 4 + 1,
((IMM4 as u32 >> 2) & 1) * 4 + 2,
((IMM4 as u32 >> 3) & 1) * 4 + 3,
],
)
}
/// Blends packed single-precision (32-bit) floating-point elements from
@ -504,61 +481,25 @@ pub unsafe fn _mm256_blend_pd(a: __m256d, b: __m256d, imm8: i32) -> __m256d {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_blend_ps)
#[inline]
#[target_feature(enable = "avx")]
#[cfg_attr(test, assert_instr(vblendps, imm8 = 9))]
#[rustc_args_required_const(2)]
#[cfg_attr(test, assert_instr(vblendps, IMM8 = 9))]
#[rustc_legacy_const_generics(2)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_blend_ps(a: __m256, b: __m256, imm8: i32) -> __m256 {
let imm8 = (imm8 & 0xFF) as u8;
macro_rules! blend4 {
(
$a:expr,
$b:expr,
$c:expr,
$d:expr,
$e:expr,
$f:expr,
$g:expr,
$h:expr
) => {
simd_shuffle8(a, b, [$a, $b, $c, $d, $e, $f, $g, $h])
};
}
macro_rules! blend3 {
($a:expr, $b:expr, $c:expr, $d:expr, $e:expr, $f:expr) => {
match (imm8 >> 6) & 0b11 {
0b00 => blend4!($a, $b, $c, $d, $e, $f, 6, 7),
0b01 => blend4!($a, $b, $c, $d, $e, $f, 14, 7),
0b10 => blend4!($a, $b, $c, $d, $e, $f, 6, 15),
_ => blend4!($a, $b, $c, $d, $e, $f, 14, 15),
}
};
}
macro_rules! blend2 {
($a:expr, $b:expr, $c:expr, $d:expr) => {
match (imm8 >> 4) & 0b11 {
0b00 => blend3!($a, $b, $c, $d, 4, 5),
0b01 => blend3!($a, $b, $c, $d, 12, 5),
0b10 => blend3!($a, $b, $c, $d, 4, 13),
_ => blend3!($a, $b, $c, $d, 12, 13),
}
};
}
macro_rules! blend1 {
($a:expr, $b:expr) => {
match (imm8 >> 2) & 0b11 {
0b00 => blend2!($a, $b, 2, 3),
0b01 => blend2!($a, $b, 10, 3),
0b10 => blend2!($a, $b, 2, 11),
_ => blend2!($a, $b, 10, 11),
}
};
}
match imm8 & 0b11 {
0b00 => blend1!(0, 1),
0b01 => blend1!(8, 1),
0b10 => blend1!(0, 9),
_ => blend1!(8, 9),
}
pub unsafe fn _mm256_blend_ps<const IMM8: i32>(a: __m256, b: __m256) -> __m256 {
static_assert_imm8!(IMM8);
simd_shuffle8(
a,
b,
[
((IMM8 as u32 >> 0) & 1) * 8 + 0,
((IMM8 as u32 >> 1) & 1) * 8 + 1,
((IMM8 as u32 >> 2) & 1) * 8 + 2,
((IMM8 as u32 >> 3) & 1) * 8 + 3,
((IMM8 as u32 >> 4) & 1) * 8 + 4,
((IMM8 as u32 >> 5) & 1) * 8 + 5,
((IMM8 as u32 >> 6) & 1) * 8 + 6,
((IMM8 as u32 >> 7) & 1) * 8 + 7,
],
)
}
/// Blends packed double-precision (64-bit) floating-point elements from
@ -3378,11 +3319,11 @@ mod tests {
unsafe fn test_mm256_blend_pd() {
let a = _mm256_setr_pd(4., 9., 16., 25.);
let b = _mm256_setr_pd(4., 3., 2., 5.);
let r = _mm256_blend_pd(a, b, 0x0);
let r = _mm256_blend_pd::<0x0>(a, b);
assert_eq_m256d(r, _mm256_setr_pd(4., 9., 16., 25.));
let r = _mm256_blend_pd(a, b, 0x3);
let r = _mm256_blend_pd::<0x3>(a, b);
assert_eq_m256d(r, _mm256_setr_pd(4., 3., 16., 25.));
let r = _mm256_blend_pd(a, b, 0xF);
let r = _mm256_blend_pd::<0xF>(a, b);
assert_eq_m256d(r, _mm256_setr_pd(4., 3., 2., 5.));
}
@ -3390,11 +3331,11 @@ mod tests {
unsafe fn test_mm256_blend_ps() {
let a = _mm256_setr_ps(1., 4., 5., 8., 9., 12., 13., 16.);
let b = _mm256_setr_ps(2., 3., 6., 7., 10., 11., 14., 15.);
let r = _mm256_blend_ps(a, b, 0x0);
let r = _mm256_blend_ps::<0x0>(a, b);
assert_eq_m256(r, _mm256_setr_ps(1., 4., 5., 8., 9., 12., 13., 16.));
let r = _mm256_blend_ps(a, b, 0x3);
let r = _mm256_blend_ps::<0x3>(a, b);
assert_eq_m256(r, _mm256_setr_ps(2., 3., 5., 8., 9., 12., 13., 16.));
let r = _mm256_blend_ps(a, b, 0xF);
let r = _mm256_blend_ps::<0xF>(a, b);
assert_eq_m256(r, _mm256_setr_ps(2., 3., 6., 7., 9., 12., 13., 16.));
}