Convert some AVX intrinsics to const generics

* _mm256_extractf128_ps
* _mm256_extractf128_pd
* _mm256_extractf128_si256
* _mm256_insertf128_ps
* _mm256_insertf128_pd
* _mm256_insertf128_si256
This commit is contained in:
Tomasz Miąsko 2021-03-04 00:00:00 +00:00 committed by Amanieu d'Antras
parent d4952b2084
commit 3671f418ed

View file

@@ -983,15 +983,17 @@ pub unsafe fn _mm256_cvttps_epi32(a: __m256) -> __m256i {
#[target_feature(enable = "avx")]
#[cfg_attr(
all(test, not(target_os = "windows")),
assert_instr(vextractf128, imm8 = 1)
assert_instr(vextractf128, IMM1 = 1)
)]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_extractf128_ps(a: __m256, imm8: i32) -> __m128 {
match imm8 & 1 {
0 => simd_shuffle4(a, _mm256_undefined_ps(), [0, 1, 2, 3]),
_ => simd_shuffle4(a, _mm256_undefined_ps(), [4, 5, 6, 7]),
}
pub unsafe fn _mm256_extractf128_ps<const IMM1: i32>(a: __m256) -> __m128 {
static_assert_imm1!(IMM1);
simd_shuffle4(
a,
_mm256_undefined_ps(),
[[0, 1, 2, 3], [4, 5, 6, 7]][IMM1 as usize],
)
}
/// Extracts 128 bits (composed of 2 packed double-precision (64-bit)
@@ -1002,15 +1004,13 @@ pub unsafe fn _mm256_extractf128_ps(a: __m256, imm8: i32) -> __m128 {
#[target_feature(enable = "avx")]
#[cfg_attr(
all(test, not(target_os = "windows")),
assert_instr(vextractf128, imm8 = 1)
assert_instr(vextractf128, IMM1 = 1)
)]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_extractf128_pd(a: __m256d, imm8: i32) -> __m128d {
match imm8 & 1 {
0 => simd_shuffle2(a, _mm256_undefined_pd(), [0, 1]),
_ => simd_shuffle2(a, _mm256_undefined_pd(), [2, 3]),
}
pub unsafe fn _mm256_extractf128_pd<const IMM1: i32>(a: __m256d) -> __m128d {
static_assert_imm1!(IMM1);
simd_shuffle2(a, _mm256_undefined_pd(), [[0, 1], [2, 3]][IMM1 as usize])
}
/// Extracts 128 bits (composed of integer data) from `a`, selected with `imm8`.
@@ -1020,16 +1020,17 @@ pub unsafe fn _mm256_extractf128_pd(a: __m256d, imm8: i32) -> __m128d {
#[target_feature(enable = "avx")]
#[cfg_attr(
all(test, not(target_os = "windows")),
assert_instr(vextractf128, imm8 = 1)
assert_instr(vextractf128, IMM1 = 1)
)]
#[rustc_args_required_const(1)]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_extractf128_si256(a: __m256i, imm8: i32) -> __m128i {
let b = _mm256_undefined_si256().as_i64x4();
let dst: i64x2 = match imm8 & 1 {
0 => simd_shuffle2(a.as_i64x4(), b, [0, 1]),
_ => simd_shuffle2(a.as_i64x4(), b, [2, 3]),
};
pub unsafe fn _mm256_extractf128_si256<const IMM1: i32>(a: __m256i) -> __m128i {
static_assert_imm1!(IMM1);
let dst: i64x2 = simd_shuffle2(
a.as_i64x4(),
_mm256_undefined_si256().as_i64x4(),
[[0, 1], [2, 3]][IMM1 as usize],
);
transmute(dst)
}
@@ -1410,16 +1411,17 @@ pub unsafe fn _mm256_broadcast_pd(a: &__m128d) -> __m256d {
#[target_feature(enable = "avx")]
#[cfg_attr(
all(test, not(target_os = "windows")),
assert_instr(vinsertf128, imm8 = 1)
assert_instr(vinsertf128, IMM1 = 1)
)]
#[rustc_args_required_const(2)]
#[rustc_legacy_const_generics(2)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_insertf128_ps(a: __m256, b: __m128, imm8: i32) -> __m256 {
let b = _mm256_castps128_ps256(b);
match imm8 & 1 {
0 => simd_shuffle8(a, b, [8, 9, 10, 11, 4, 5, 6, 7]),
_ => simd_shuffle8(a, b, [0, 1, 2, 3, 8, 9, 10, 11]),
}
pub unsafe fn _mm256_insertf128_ps<const IMM1: i32>(a: __m256, b: __m128) -> __m256 {
static_assert_imm1!(IMM1);
simd_shuffle8(
a,
_mm256_castps128_ps256(b),
[[8, 9, 10, 11, 4, 5, 6, 7], [0, 1, 2, 3, 8, 9, 10, 11]][IMM1 as usize],
)
}
/// Copies `a` to result, then inserts 128 bits (composed of 2 packed
@@ -1431,15 +1433,17 @@ pub unsafe fn _mm256_insertf128_ps(a: __m256, b: __m128, imm8: i32) -> __m256 {
#[target_feature(enable = "avx")]
#[cfg_attr(
all(test, not(target_os = "windows")),
assert_instr(vinsertf128, imm8 = 1)
assert_instr(vinsertf128, IMM1 = 1)
)]
#[rustc_args_required_const(2)]
#[rustc_legacy_const_generics(2)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_insertf128_pd(a: __m256d, b: __m128d, imm8: i32) -> __m256d {
match imm8 & 1 {
0 => simd_shuffle4(a, _mm256_castpd128_pd256(b), [4, 5, 2, 3]),
_ => simd_shuffle4(a, _mm256_castpd128_pd256(b), [0, 1, 4, 5]),
}
pub unsafe fn _mm256_insertf128_pd<const IMM1: i32>(a: __m256d, b: __m128d) -> __m256d {
static_assert_imm1!(IMM1);
simd_shuffle4(
a,
_mm256_castpd128_pd256(b),
[[4, 5, 2, 3], [0, 1, 4, 5]][IMM1 as usize],
)
}
/// Copies `a` to result, then inserts 128 bits from `b` into result
@@ -1450,16 +1454,17 @@ pub unsafe fn _mm256_insertf128_pd(a: __m256d, b: __m128d, imm8: i32) -> __m256d
#[target_feature(enable = "avx")]
#[cfg_attr(
all(test, not(target_os = "windows")),
assert_instr(vinsertf128, imm8 = 1)
assert_instr(vinsertf128, IMM1 = 1)
)]
#[rustc_args_required_const(2)]
#[rustc_legacy_const_generics(2)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_insertf128_si256(a: __m256i, b: __m128i, imm8: i32) -> __m256i {
let b = _mm256_castsi128_si256(b).as_i64x4();
let dst: i64x4 = match imm8 & 1 {
0 => simd_shuffle4(a.as_i64x4(), b, [4, 5, 2, 3]),
_ => simd_shuffle4(a.as_i64x4(), b, [0, 1, 4, 5]),
};
pub unsafe fn _mm256_insertf128_si256<const IMM1: i32>(a: __m256i, b: __m128i) -> __m256i {
static_assert_imm1!(IMM1);
let dst: i64x4 = simd_shuffle4(
a.as_i64x4(),
_mm256_castsi128_si256(b).as_i64x4(),
[[4, 5, 2, 3], [0, 1, 4, 5]][IMM1 as usize],
);
transmute(dst)
}
@@ -2961,7 +2966,7 @@ pub unsafe fn _mm256_setr_m128i(lo: __m128i, hi: __m128i) -> __m256i {
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_loadu2_m128(hiaddr: *const f32, loaddr: *const f32) -> __m256 {
let a = _mm256_castps128_ps256(_mm_loadu_ps(loaddr));
_mm256_insertf128_ps(a, _mm_loadu_ps(hiaddr), 1)
_mm256_insertf128_ps::<1>(a, _mm_loadu_ps(hiaddr))
}
/// Loads two 128-bit values (composed of 2 packed double-precision (64-bit)
@@ -2976,7 +2981,7 @@ pub unsafe fn _mm256_loadu2_m128(hiaddr: *const f32, loaddr: *const f32) -> __m2
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_loadu2_m128d(hiaddr: *const f64, loaddr: *const f64) -> __m256d {
let a = _mm256_castpd128_pd256(_mm_loadu_pd(loaddr));
_mm256_insertf128_pd(a, _mm_loadu_pd(hiaddr), 1)
_mm256_insertf128_pd::<1>(a, _mm_loadu_pd(hiaddr))
}
/// Loads two 128-bit values (composed of integer data) from memory, and combine
@@ -2990,7 +2995,7 @@ pub unsafe fn _mm256_loadu2_m128d(hiaddr: *const f64, loaddr: *const f64) -> __m
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_loadu2_m128i(hiaddr: *const __m128i, loaddr: *const __m128i) -> __m256i {
let a = _mm256_castsi128_si256(_mm_loadu_si128(loaddr));
_mm256_insertf128_si256(a, _mm_loadu_si128(hiaddr), 1)
_mm256_insertf128_si256::<1>(a, _mm_loadu_si128(hiaddr))
}
/// Stores the high and low 128-bit halves (each composed of 4 packed
@@ -3006,7 +3011,7 @@ pub unsafe fn _mm256_loadu2_m128i(hiaddr: *const __m128i, loaddr: *const __m128i
pub unsafe fn _mm256_storeu2_m128(hiaddr: *mut f32, loaddr: *mut f32, a: __m256) {
let lo = _mm256_castps256_ps128(a);
_mm_storeu_ps(loaddr, lo);
let hi = _mm256_extractf128_ps(a, 1);
let hi = _mm256_extractf128_ps::<1>(a);
_mm_storeu_ps(hiaddr, hi);
}
@@ -3023,7 +3028,7 @@ pub unsafe fn _mm256_storeu2_m128(hiaddr: *mut f32, loaddr: *mut f32, a: __m256)
pub unsafe fn _mm256_storeu2_m128d(hiaddr: *mut f64, loaddr: *mut f64, a: __m256d) {
let lo = _mm256_castpd256_pd128(a);
_mm_storeu_pd(loaddr, lo);
let hi = _mm256_extractf128_pd(a, 1);
let hi = _mm256_extractf128_pd::<1>(a);
_mm_storeu_pd(hiaddr, hi);
}
@@ -3039,7 +3044,7 @@ pub unsafe fn _mm256_storeu2_m128d(hiaddr: *mut f64, loaddr: *mut f64, a: __m256
pub unsafe fn _mm256_storeu2_m128i(hiaddr: *mut __m128i, loaddr: *mut __m128i, a: __m256i) {
let lo = _mm256_castsi256_si128(a);
_mm_storeu_si128(loaddr, lo);
let hi = _mm256_extractf128_si256(a, 1);
let hi = _mm256_extractf128_si256::<1>(a);
_mm_storeu_si128(hiaddr, hi);
}
@@ -3727,7 +3732,7 @@ mod tests {
#[simd_test(enable = "avx")]
unsafe fn test_mm256_extractf128_ps() {
let a = _mm256_setr_ps(4., 3., 2., 5., 8., 9., 64., 50.);
let r = _mm256_extractf128_ps(a, 0);
let r = _mm256_extractf128_ps::<0>(a);
let e = _mm_setr_ps(4., 3., 2., 5.);
assert_eq_m128(r, e);
}
@@ -3735,7 +3740,7 @@ mod tests {
#[simd_test(enable = "avx")]
unsafe fn test_mm256_extractf128_pd() {
let a = _mm256_setr_pd(4., 3., 2., 5.);
let r = _mm256_extractf128_pd(a, 0);
let r = _mm256_extractf128_pd::<0>(a);
let e = _mm_setr_pd(4., 3.);
assert_eq_m128d(r, e);
}
@@ -3743,7 +3748,7 @@ mod tests {
#[simd_test(enable = "avx")]
unsafe fn test_mm256_extractf128_si256() {
let a = _mm256_setr_epi64x(4, 3, 2, 5);
let r = _mm256_extractf128_si256(a, 0);
let r = _mm256_extractf128_si256::<0>(a);
let e = _mm_setr_epi64x(4, 3);
assert_eq_m128i(r, e);
}
@@ -3894,7 +3899,7 @@ mod tests {
unsafe fn test_mm256_insertf128_ps() {
let a = _mm256_setr_ps(4., 3., 2., 5., 8., 9., 64., 50.);
let b = _mm_setr_ps(4., 9., 16., 25.);
let r = _mm256_insertf128_ps(a, b, 0);
let r = _mm256_insertf128_ps::<0>(a, b);
let e = _mm256_setr_ps(4., 9., 16., 25., 8., 9., 64., 50.);
assert_eq_m256(r, e);
}
@@ -3903,7 +3908,7 @@ mod tests {
unsafe fn test_mm256_insertf128_pd() {
let a = _mm256_setr_pd(1., 2., 3., 4.);
let b = _mm_setr_pd(5., 6.);
let r = _mm256_insertf128_pd(a, b, 0);
let r = _mm256_insertf128_pd::<0>(a, b);
let e = _mm256_setr_pd(5., 6., 3., 4.);
assert_eq_m256d(r, e);
}
@@ -3912,7 +3917,7 @@ mod tests {
unsafe fn test_mm256_insertf128_si256() {
let a = _mm256_setr_epi64x(1, 2, 3, 4);
let b = _mm_setr_epi64x(5, 6);
let r = _mm256_insertf128_si256(a, b, 0);
let r = _mm256_insertf128_si256::<0>(a, b);
let e = _mm256_setr_epi64x(5, 6, 3, 4);
assert_eq_m256i(r, e);
}