Update fma.rs

This commit is contained in:
Tobias Decking 2024-06-16 13:10:24 +02:00 committed by Amanieu d'Antras
parent 61fb419a9b
commit b683da6b0b

View file

@ -83,7 +83,11 @@ pub unsafe fn _mm256_fmadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
#[cfg_attr(test, assert_instr(vfmadd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fmadd_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
vfmaddsd(a, b, c)
simd_insert!(
a,
0,
_mm_cvtsd_f64(a).mul_add(_mm_cvtsd_f64(b), _mm_cvtsd_f64(c))
)
}
/// Multiplies the lower single-precision (32-bit) floating-point elements in
@ -97,7 +101,11 @@ pub unsafe fn _mm_fmadd_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
#[cfg_attr(test, assert_instr(vfmadd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fmadd_ss(a: __m128, b: __m128, c: __m128) -> __m128 {
vfmaddss(a, b, c)
simd_insert!(
a,
0,
_mm_cvtss_f32(a).mul_add(_mm_cvtss_f32(b), _mm_cvtss_f32(c))
)
}
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@ -161,7 +169,7 @@ pub unsafe fn _mm256_fmaddsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
#[cfg_attr(test, assert_instr(vfmsub))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
vfmsubpd(a, b, c)
simd_fma(a, b, simd_neg(c))
}
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@ -173,7 +181,7 @@ pub unsafe fn _mm_fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
#[cfg_attr(test, assert_instr(vfmsub))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_fmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
vfmsubpd256(a, b, c)
simd_fma(a, b, simd_neg(c))
}
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@ -185,7 +193,7 @@ pub unsafe fn _mm256_fmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
#[cfg_attr(test, assert_instr(vfmsub213ps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
vfmsubps(a, b, c)
simd_fma(a, b, simd_neg(c))
}
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@ -197,7 +205,7 @@ pub unsafe fn _mm_fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
#[cfg_attr(test, assert_instr(vfmsub213ps))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_fmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
vfmsubps256(a, b, c)
simd_fma(a, b, simd_neg(c))
}
/// Multiplies the lower double-precision (64-bit) floating-point elements in
@ -211,7 +219,11 @@ pub unsafe fn _mm256_fmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
#[cfg_attr(test, assert_instr(vfmsub))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fmsub_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
vfmsubsd(a, b, c)
simd_insert!(
a,
0,
_mm_cvtsd_f64(a).mul_add(_mm_cvtsd_f64(b), -_mm_cvtsd_f64(c))
)
}
/// Multiplies the lower single-precision (32-bit) floating-point elements in
@ -225,7 +237,11 @@ pub unsafe fn _mm_fmsub_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
#[cfg_attr(test, assert_instr(vfmsub))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fmsub_ss(a: __m128, b: __m128, c: __m128) -> __m128 {
vfmsubss(a, b, c)
simd_insert!(
a,
0,
_mm_cvtss_f32(a).mul_add(_mm_cvtss_f32(b), -_mm_cvtss_f32(c))
)
}
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@ -289,7 +305,7 @@ pub unsafe fn _mm256_fmsubadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
#[cfg_attr(test, assert_instr(vfnmadd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fnmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
vfnmaddpd(a, b, c)
simd_fma(simd_neg(a), b, c)
}
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@ -301,7 +317,7 @@ pub unsafe fn _mm_fnmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
#[cfg_attr(test, assert_instr(vfnmadd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_fnmadd_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
vfnmaddpd256(a, b, c)
simd_fma(simd_neg(a), b, c)
}
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@ -313,7 +329,7 @@ pub unsafe fn _mm256_fnmadd_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
#[cfg_attr(test, assert_instr(vfnmadd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fnmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
vfnmaddps(a, b, c)
simd_fma(simd_neg(a), b, c)
}
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@ -325,7 +341,7 @@ pub unsafe fn _mm_fnmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
#[cfg_attr(test, assert_instr(vfnmadd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_fnmadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
vfnmaddps256(a, b, c)
simd_fma(simd_neg(a), b, c)
}
/// Multiplies the lower double-precision (64-bit) floating-point elements in
@ -339,7 +355,11 @@ pub unsafe fn _mm256_fnmadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
#[cfg_attr(test, assert_instr(vfnmadd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fnmadd_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
vfnmaddsd(a, b, c)
simd_insert!(
a,
0,
_mm_cvtsd_f64(a).mul_add(-_mm_cvtsd_f64(b), _mm_cvtsd_f64(c))
)
}
/// Multiplies the lower single-precision (32-bit) floating-point elements in
@ -353,7 +373,11 @@ pub unsafe fn _mm_fnmadd_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
#[cfg_attr(test, assert_instr(vfnmadd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fnmadd_ss(a: __m128, b: __m128, c: __m128) -> __m128 {
vfnmaddss(a, b, c)
simd_insert!(
a,
0,
_mm_cvtss_f32(a).mul_add(-_mm_cvtss_f32(b), _mm_cvtss_f32(c))
)
}
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@ -366,7 +390,7 @@ pub unsafe fn _mm_fnmadd_ss(a: __m128, b: __m128, c: __m128) -> __m128 {
#[cfg_attr(test, assert_instr(vfnmsub))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fnmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
vfnmsubpd(a, b, c)
simd_fma(simd_neg(a), b, simd_neg(c))
}
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@ -379,7 +403,7 @@ pub unsafe fn _mm_fnmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
#[cfg_attr(test, assert_instr(vfnmsub))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_fnmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
vfnmsubpd256(a, b, c)
simd_fma(simd_neg(a), b, simd_neg(c))
}
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@ -392,7 +416,7 @@ pub unsafe fn _mm256_fnmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
#[cfg_attr(test, assert_instr(vfnmsub))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fnmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
vfnmsubps(a, b, c)
simd_fma(simd_neg(a), b, simd_neg(c))
}
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@ -405,7 +429,7 @@ pub unsafe fn _mm_fnmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
#[cfg_attr(test, assert_instr(vfnmsub))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_fnmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
vfnmsubps256(a, b, c)
simd_fma(simd_neg(a), b, simd_neg(c))
}
/// Multiplies the lower double-precision (64-bit) floating-point elements in
@ -420,7 +444,11 @@ pub unsafe fn _mm256_fnmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
#[cfg_attr(test, assert_instr(vfnmsub))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fnmsub_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
vfnmsubsd(a, b, c)
simd_insert!(
a,
0,
_mm_cvtsd_f64(a).mul_add(-_mm_cvtsd_f64(b), -_mm_cvtsd_f64(c))
)
}
/// Multiplies the lower single-precision (32-bit) floating-point elements in
@ -435,15 +463,15 @@ pub unsafe fn _mm_fnmsub_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
#[cfg_attr(test, assert_instr(vfnmsub))]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm_fnmsub_ss(a: __m128, b: __m128, c: __m128) -> __m128 {
vfnmsubss(a, b, c)
simd_insert!(
a,
0,
_mm_cvtss_f32(a).mul_add(-_mm_cvtss_f32(b), -_mm_cvtss_f32(c))
)
}
#[allow(improper_ctypes)]
extern "C" {
#[link_name = "llvm.x86.fma.vfmadd.sd"]
fn vfmaddsd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
#[link_name = "llvm.x86.fma.vfmadd.ss"]
fn vfmaddss(a: __m128, b: __m128, c: __m128) -> __m128;
#[link_name = "llvm.x86.fma.vfmaddsub.pd"]
fn vfmaddsubpd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
#[link_name = "llvm.x86.fma.vfmaddsub.pd.256"]
@ -452,18 +480,6 @@ extern "C" {
fn vfmaddsubps(a: __m128, b: __m128, c: __m128) -> __m128;
#[link_name = "llvm.x86.fma.vfmaddsub.ps.256"]
fn vfmaddsubps256(a: __m256, b: __m256, c: __m256) -> __m256;
#[link_name = "llvm.x86.fma.vfmsub.pd"]
fn vfmsubpd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
#[link_name = "llvm.x86.fma.vfmsub.pd.256"]
fn vfmsubpd256(a: __m256d, b: __m256d, c: __m256d) -> __m256d;
#[link_name = "llvm.x86.fma.vfmsub.ps"]
fn vfmsubps(a: __m128, b: __m128, c: __m128) -> __m128;
#[link_name = "llvm.x86.fma.vfmsub.ps.256"]
fn vfmsubps256(a: __m256, b: __m256, c: __m256) -> __m256;
#[link_name = "llvm.x86.fma.vfmsub.sd"]
fn vfmsubsd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
#[link_name = "llvm.x86.fma.vfmsub.ss"]
fn vfmsubss(a: __m128, b: __m128, c: __m128) -> __m128;
#[link_name = "llvm.x86.fma.vfmsubadd.pd"]
fn vfmsubaddpd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
#[link_name = "llvm.x86.fma.vfmsubadd.pd.256"]
@ -472,30 +488,6 @@ extern "C" {
fn vfmsubaddps(a: __m128, b: __m128, c: __m128) -> __m128;
#[link_name = "llvm.x86.fma.vfmsubadd.ps.256"]
fn vfmsubaddps256(a: __m256, b: __m256, c: __m256) -> __m256;
#[link_name = "llvm.x86.fma.vfnmadd.pd"]
fn vfnmaddpd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
#[link_name = "llvm.x86.fma.vfnmadd.pd.256"]
fn vfnmaddpd256(a: __m256d, b: __m256d, c: __m256d) -> __m256d;
#[link_name = "llvm.x86.fma.vfnmadd.ps"]
fn vfnmaddps(a: __m128, b: __m128, c: __m128) -> __m128;
#[link_name = "llvm.x86.fma.vfnmadd.ps.256"]
fn vfnmaddps256(a: __m256, b: __m256, c: __m256) -> __m256;
#[link_name = "llvm.x86.fma.vfnmadd.sd"]
fn vfnmaddsd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
#[link_name = "llvm.x86.fma.vfnmadd.ss"]
fn vfnmaddss(a: __m128, b: __m128, c: __m128) -> __m128;
#[link_name = "llvm.x86.fma.vfnmsub.pd"]
fn vfnmsubpd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
#[link_name = "llvm.x86.fma.vfnmsub.pd.256"]
fn vfnmsubpd256(a: __m256d, b: __m256d, c: __m256d) -> __m256d;
#[link_name = "llvm.x86.fma.vfnmsub.ps"]
fn vfnmsubps(a: __m128, b: __m128, c: __m128) -> __m128;
#[link_name = "llvm.x86.fma.vfnmsub.ps.256"]
fn vfnmsubps256(a: __m256, b: __m256, c: __m256) -> __m256;
#[link_name = "llvm.x86.fma.vfnmsub.sd"]
fn vfnmsubsd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
#[link_name = "llvm.x86.fma.vfnmsub.ss"]
fn vfnmsubss(a: __m128, b: __m128, c: __m128) -> __m128;
}
#[cfg(test)]