Use generic SIMD intrinsics to implement the AVX maskload and maskstore intrinsics

This commit is contained in:
sayantn 2025-11-06 06:26:16 +05:30
parent 9126145419
commit 7ea8483696
No known key found for this signature in database
GPG key ID: B60412E056614AA4
2 changed files with 32 additions and 48 deletions

View file

@ -1675,7 +1675,8 @@ pub unsafe fn _mm256_storeu_si256(mem_addr: *mut __m256i, a: __m256i) {
#[cfg_attr(test, assert_instr(vmaskmovpd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_maskload_pd(mem_addr: *const f64, mask: __m256i) -> __m256d {
    // AVX maskload keys each lane on the element's sign (high) bit.
    // Arithmetic right shift by 63 broadcasts that bit across the lane,
    // producing the all-ones/all-zeros mask the generic intrinsic expects.
    let mask = simd_shr(mask.as_i64x4(), i64x4::splat(63));
    // Unaligned masked load; masked-off lanes take the zero fallback value.
    simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, _mm256_setzero_pd())
}
/// Stores packed double-precision (64-bit) floating-point elements from `a`
@ -1687,7 +1688,8 @@ pub unsafe fn _mm256_maskload_pd(mem_addr: *const f64, mask: __m256i) -> __m256d
#[cfg_attr(test, assert_instr(vmaskmovpd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_maskstore_pd(mem_addr: *mut f64, mask: __m256i, a: __m256d) {
    // Broadcast each lane's sign bit (the AVX maskstore condition bit)
    // into a full-lane mask via arithmetic shift right by 63.
    let mask = simd_shr(mask.as_i64x4(), i64x4::splat(63));
    // Unaligned masked store; lanes with a clear mask are left untouched.
    simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a)
}
/// Loads packed double-precision (64-bit) floating-point elements from memory
@ -1700,7 +1702,8 @@ pub unsafe fn _mm256_maskstore_pd(mem_addr: *mut f64, mask: __m256i, a: __m256d)
#[cfg_attr(test, assert_instr(vmaskmovpd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_maskload_pd(mem_addr: *const f64, mask: __m128i) -> __m128d {
    // Sign-extend each lane's high bit into a full-lane boolean mask.
    let mask = simd_shr(mask.as_i64x2(), i64x2::splat(63));
    // Unaligned masked load; masked-off lanes read as zero.
    simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, _mm_setzero_pd())
}
/// Stores packed double-precision (64-bit) floating-point elements from `a`
@ -1712,7 +1715,8 @@ pub unsafe fn _mm_maskload_pd(mem_addr: *const f64, mask: __m128i) -> __m128d {
#[cfg_attr(test, assert_instr(vmaskmovpd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_maskstore_pd(mem_addr: *mut f64, mask: __m128i, a: __m128d) {
    // Sign-extend each lane's high bit into a full-lane boolean mask.
    let mask = simd_shr(mask.as_i64x2(), i64x2::splat(63));
    // Unaligned masked store; lanes with a clear mask are not written.
    simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a)
}
/// Loads packed single-precision (32-bit) floating-point elements from memory
@ -1725,7 +1729,8 @@ pub unsafe fn _mm_maskstore_pd(mem_addr: *mut f64, mask: __m128i, a: __m128d) {
#[cfg_attr(test, assert_instr(vmaskmovps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_maskload_ps(mem_addr: *const f32, mask: __m256i) -> __m256 {
    // 32-bit lanes: shift right by 31 broadcasts each sign bit lane-wide.
    let mask = simd_shr(mask.as_i32x8(), i32x8::splat(31));
    // Unaligned masked load; masked-off lanes read as zero.
    simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, _mm256_setzero_ps())
}
/// Stores packed single-precision (32-bit) floating-point elements from `a`
@ -1737,7 +1742,8 @@ pub unsafe fn _mm256_maskload_ps(mem_addr: *const f32, mask: __m256i) -> __m256
#[cfg_attr(test, assert_instr(vmaskmovps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_maskstore_ps(mem_addr: *mut f32, mask: __m256i, a: __m256) {
    // 32-bit lanes: shift right by 31 broadcasts each sign bit lane-wide.
    let mask = simd_shr(mask.as_i32x8(), i32x8::splat(31));
    // Unaligned masked store; lanes with a clear mask are not written.
    simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a)
}
/// Loads packed single-precision (32-bit) floating-point elements from memory
@ -1750,7 +1756,8 @@ pub unsafe fn _mm256_maskstore_ps(mem_addr: *mut f32, mask: __m256i, a: __m256)
#[cfg_attr(test, assert_instr(vmaskmovps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_maskload_ps(mem_addr: *const f32, mask: __m128i) -> __m128 {
    // Sign-extend each 32-bit lane's high bit into a full-lane mask.
    let mask = simd_shr(mask.as_i32x4(), i32x4::splat(31));
    // Unaligned masked load; masked-off lanes read as zero.
    simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, _mm_setzero_ps())
}
/// Stores packed single-precision (32-bit) floating-point elements from `a`
@ -1762,7 +1769,8 @@ pub unsafe fn _mm_maskload_ps(mem_addr: *const f32, mask: __m128i) -> __m128 {
#[cfg_attr(test, assert_instr(vmaskmovps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_maskstore_ps(mem_addr: *mut f32, mask: __m128i, a: __m128) {
    // Sign-extend each 32-bit lane's high bit into a full-lane mask.
    let mask = simd_shr(mask.as_i32x4(), i32x4::splat(31));
    // Unaligned masked store; lanes with a clear mask are not written.
    simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a)
}
/// Duplicate odd-indexed single-precision (32-bit) floating-point elements
@ -3147,22 +3155,6 @@ unsafe extern "C" {
fn vpermilpd256(a: __m256d, b: i64x4) -> __m256d;
#[link_name = "llvm.x86.avx.vpermilvar.pd"]
fn vpermilpd(a: __m128d, b: i64x2) -> __m128d;
#[link_name = "llvm.x86.avx.maskload.pd.256"]
fn maskloadpd256(mem_addr: *const i8, mask: i64x4) -> __m256d;
#[link_name = "llvm.x86.avx.maskstore.pd.256"]
fn maskstorepd256(mem_addr: *mut i8, mask: i64x4, a: __m256d);
#[link_name = "llvm.x86.avx.maskload.pd"]
fn maskloadpd(mem_addr: *const i8, mask: i64x2) -> __m128d;
#[link_name = "llvm.x86.avx.maskstore.pd"]
fn maskstorepd(mem_addr: *mut i8, mask: i64x2, a: __m128d);
#[link_name = "llvm.x86.avx.maskload.ps.256"]
fn maskloadps256(mem_addr: *const i8, mask: i32x8) -> __m256;
#[link_name = "llvm.x86.avx.maskstore.ps.256"]
fn maskstoreps256(mem_addr: *mut i8, mask: i32x8, a: __m256);
#[link_name = "llvm.x86.avx.maskload.ps"]
fn maskloadps(mem_addr: *const i8, mask: i32x4) -> __m128;
#[link_name = "llvm.x86.avx.maskstore.ps"]
fn maskstoreps(mem_addr: *mut i8, mask: i32x4, a: __m128);
#[link_name = "llvm.x86.avx.ldu.dq.256"]
fn vlddqu(mem_addr: *const i8) -> i8x32;
#[link_name = "llvm.x86.avx.rcp.ps.256"]

View file

@ -1786,7 +1786,8 @@ pub fn _mm256_maddubs_epi16(a: __m256i, b: __m256i) -> __m256i {
#[cfg_attr(test, assert_instr(vpmaskmovd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_maskload_epi32(mem_addr: *const i32, mask: __m128i) -> __m128i {
    // AVX2 maskload keys each lane on its sign bit; shift right by 31
    // broadcasts that bit into a full-lane true/false mask.
    let mask = simd_shr(mask.as_i32x4(), i32x4::splat(31));
    // Unaligned masked load; masked-off lanes take the zero fallback.
    simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, i32x4::ZERO).as_m128i()
}
/// Loads packed 32-bit integers from memory pointed by `mem_addr` using `mask`
@ -1799,7 +1800,8 @@ pub unsafe fn _mm_maskload_epi32(mem_addr: *const i32, mask: __m128i) -> __m128i
#[cfg_attr(test, assert_instr(vpmaskmovd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_maskload_epi32(mem_addr: *const i32, mask: __m256i) -> __m256i {
    // Broadcast each 32-bit lane's sign bit lane-wide to form the mask.
    let mask = simd_shr(mask.as_i32x8(), i32x8::splat(31));
    // Unaligned masked load; masked-off lanes take the zero fallback.
    simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, i32x8::ZERO).as_m256i()
}
/// Loads packed 64-bit integers from memory pointed by `mem_addr` using `mask`
@ -1812,7 +1814,8 @@ pub unsafe fn _mm256_maskload_epi32(mem_addr: *const i32, mask: __m256i) -> __m2
#[cfg_attr(test, assert_instr(vpmaskmovq))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_maskload_epi64(mem_addr: *const i64, mask: __m128i) -> __m128i {
    // Broadcast each 64-bit lane's sign bit lane-wide to form the mask.
    let mask = simd_shr(mask.as_i64x2(), i64x2::splat(63));
    // Unaligned masked load; masked-off lanes take the zero fallback.
    simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, i64x2::ZERO).as_m128i()
}
/// Loads packed 64-bit integers from memory pointed by `mem_addr` using `mask`
@ -1825,7 +1828,8 @@ pub unsafe fn _mm_maskload_epi64(mem_addr: *const i64, mask: __m128i) -> __m128i
#[cfg_attr(test, assert_instr(vpmaskmovq))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_maskload_epi64(mem_addr: *const i64, mask: __m256i) -> __m256i {
    // Broadcast each 64-bit lane's sign bit lane-wide to form the mask.
    let mask = simd_shr(mask.as_i64x4(), i64x4::splat(63));
    // Unaligned masked load; masked-off lanes take the zero fallback.
    simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, i64x4::ZERO).as_m256i()
}
/// Stores packed 32-bit integers from `a` into memory pointed by `mem_addr`
@ -1838,7 +1842,8 @@ pub unsafe fn _mm256_maskload_epi64(mem_addr: *const i64, mask: __m256i) -> __m2
#[cfg_attr(test, assert_instr(vpmaskmovd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_maskstore_epi32(mem_addr: *mut i32, mask: __m128i, a: __m128i) {
    // Broadcast each 32-bit lane's sign bit lane-wide to form the mask.
    let mask = simd_shr(mask.as_i32x4(), i32x4::splat(31));
    // Unaligned masked store; lanes with a clear mask are not written.
    simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a.as_i32x4())
}
/// Stores packed 32-bit integers from `a` into memory pointed by `mem_addr`
@ -1851,7 +1856,8 @@ pub unsafe fn _mm_maskstore_epi32(mem_addr: *mut i32, mask: __m128i, a: __m128i)
#[cfg_attr(test, assert_instr(vpmaskmovd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_maskstore_epi32(mem_addr: *mut i32, mask: __m256i, a: __m256i) {
    // Broadcast each 32-bit lane's sign bit lane-wide to form the mask.
    let mask = simd_shr(mask.as_i32x8(), i32x8::splat(31));
    // Unaligned masked store; lanes with a clear mask are not written.
    simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a.as_i32x8())
}
/// Stores packed 64-bit integers from `a` into memory pointed by `mem_addr`
@ -1864,7 +1870,8 @@ pub unsafe fn _mm256_maskstore_epi32(mem_addr: *mut i32, mask: __m256i, a: __m25
#[cfg_attr(test, assert_instr(vpmaskmovq))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_maskstore_epi64(mem_addr: *mut i64, mask: __m128i, a: __m128i) {
    // Broadcast each 64-bit lane's sign bit lane-wide to form the mask.
    let mask = simd_shr(mask.as_i64x2(), i64x2::splat(63));
    // Unaligned masked store; lanes with a clear mask are not written.
    simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a.as_i64x2())
}
/// Stores packed 64-bit integers from `a` into memory pointed by `mem_addr`
@ -1877,7 +1884,8 @@ pub unsafe fn _mm_maskstore_epi64(mem_addr: *mut i64, mask: __m128i, a: __m128i)
#[cfg_attr(test, assert_instr(vpmaskmovq))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_maskstore_epi64(mem_addr: *mut i64, mask: __m256i, a: __m256i) {
    // Broadcast each 64-bit lane's sign bit lane-wide to form the mask.
    let mask = simd_shr(mask.as_i64x4(), i64x4::splat(63));
    // Unaligned masked store; lanes with a clear mask are not written.
    simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a.as_i64x4())
}
/// Compares packed 16-bit integers in `a` and `b`, and returns the packed
@ -3645,22 +3653,6 @@ unsafe extern "C" {
fn phsubsw(a: i16x16, b: i16x16) -> i16x16;
#[link_name = "llvm.x86.avx2.pmadd.ub.sw"]
fn pmaddubsw(a: u8x32, b: u8x32) -> i16x16;
#[link_name = "llvm.x86.avx2.maskload.d"]
fn maskloadd(mem_addr: *const i8, mask: i32x4) -> i32x4;
#[link_name = "llvm.x86.avx2.maskload.d.256"]
fn maskloadd256(mem_addr: *const i8, mask: i32x8) -> i32x8;
#[link_name = "llvm.x86.avx2.maskload.q"]
fn maskloadq(mem_addr: *const i8, mask: i64x2) -> i64x2;
#[link_name = "llvm.x86.avx2.maskload.q.256"]
fn maskloadq256(mem_addr: *const i8, mask: i64x4) -> i64x4;
#[link_name = "llvm.x86.avx2.maskstore.d"]
fn maskstored(mem_addr: *mut i8, mask: i32x4, a: i32x4);
#[link_name = "llvm.x86.avx2.maskstore.d.256"]
fn maskstored256(mem_addr: *mut i8, mask: i32x8, a: i32x8);
#[link_name = "llvm.x86.avx2.maskstore.q"]
fn maskstoreq(mem_addr: *mut i8, mask: i64x2, a: i64x2);
#[link_name = "llvm.x86.avx2.maskstore.q.256"]
fn maskstoreq256(mem_addr: *mut i8, mask: i64x4, a: i64x4);
#[link_name = "llvm.x86.avx2.mpsadbw"]
fn mpsadbw(a: u8x32, b: u8x32, imm8: i8) -> u16x16;
#[link_name = "llvm.x86.avx2.pmul.hr.sw"]