Added a bf16 type

This commit is contained in:
sayantn 2024-07-05 14:36:33 +05:30 committed by Amanieu d'Antras
parent 70fbc2e97c
commit c862e4e487
5 changed files with 52 additions and 21 deletions

View file

@ -486,9 +486,9 @@ pub unsafe fn _mm_maskz_cvtpbh_ps(k: __mmask8, a: __m128bh) -> __m128 {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtsbh_ss)
#[inline]
#[target_feature(enable = "avx512bf16,avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_cvtsbh_ss(a: u16) -> f32 {
f32::from_bits((a as u32) << 16)
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub unsafe fn _mm_cvtsbh_ss(a: bf16) -> f32 {
f32::from_bits((a.to_bits() as u32) << 16)
}
/// Converts packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
@ -558,9 +558,10 @@ pub unsafe fn _mm_maskz_cvtneps_pbh(k: __mmask8, a: __m128) -> __m128bh {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh)
#[inline]
#[target_feature(enable = "avx512bf16,avx512vl")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_cvtness_sbh(a: f32) -> u16 {
simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0)
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub unsafe fn _mm_cvtness_sbh(a: f32) -> bf16 {
let value: u16 = simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0);
bf16::from_bits(value)
}
#[cfg(test)]
@ -1910,7 +1911,7 @@ mod tests {
#[simd_test(enable = "avx512bf16")]
unsafe fn test_mm_cvtsbh_ss() {
let r = _mm_cvtsbh_ss(BF16_ONE);
let r = _mm_cvtsbh_ss(bf16::from_bits(BF16_ONE));
assert_eq!(r, 1.);
}
@ -1944,6 +1945,6 @@ mod tests {
#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm_cvtness_sbh() {
let r = _mm_cvtness_sbh(1.);
assert_eq!(r, BF16_ONE);
assert_eq!(r.to_bits(), BF16_ONE);
}
}

View file

@ -1,5 +1,5 @@
use crate::arch::asm;
use crate::core_arch::{simd::*, x86::*};
use crate::core_arch::x86::*;
#[cfg(test)]
use stdarch_test::assert_instr;
@ -15,9 +15,9 @@ use stdarch_test::assert_instr;
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(vbcstnebf162ps)
)]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_bcstnebf16_ps(a: *const u16) -> __m128 {
transmute(bcstnebf162ps_128(a))
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub unsafe fn _mm_bcstnebf16_ps(a: *const bf16) -> __m128 {
bcstnebf162ps_128(a)
}
/// Convert scalar BF16 (16-bit) floating point element stored at memory locations starting at location
@ -31,9 +31,9 @@ pub unsafe fn _mm_bcstnebf16_ps(a: *const u16) -> __m128 {
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(vbcstnebf162ps)
)]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm256_bcstnebf16_ps(a: *const u16) -> __m256 {
transmute(bcstnebf162ps_256(a))
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub unsafe fn _mm256_bcstnebf16_ps(a: *const bf16) -> __m256 {
bcstnebf162ps_256(a)
}
/// Convert packed BF16 (16-bit) floating-point even-indexed elements stored at memory locations starting at
@ -143,9 +143,9 @@ pub unsafe fn _mm256_cvtneps_avx_pbh(a: __m256) -> __m128bh {
#[allow(improper_ctypes)]
extern "C" {
#[link_name = "llvm.x86.vbcstnebf162ps128"]
fn bcstnebf162ps_128(a: *const u16) -> f32x4;
fn bcstnebf162ps_128(a: *const bf16) -> __m128;
#[link_name = "llvm.x86.vbcstnebf162ps256"]
fn bcstnebf162ps_256(a: *const u16) -> f32x8;
fn bcstnebf162ps_256(a: *const bf16) -> __m256;
#[link_name = "llvm.x86.vcvtneebf162ps128"]
fn cvtneebf162ps_128(a: *const __m128bh) -> __m128;
@ -177,7 +177,7 @@ mod tests {
#[simd_test(enable = "avxneconvert")]
unsafe fn test_mm_bcstnebf16_ps() {
let a = BF16_ONE;
let a = bf16::from_bits(BF16_ONE);
let r = _mm_bcstnebf16_ps(addr_of!(a));
let e = _mm_set_ps(1., 1., 1., 1.);
assert_eq_m128(r, e);
@ -185,7 +185,7 @@ mod tests {
#[simd_test(enable = "avxneconvert")]
unsafe fn test_mm256_bcstnebf16_ps() {
let a = BF16_ONE;
let a = bf16::from_bits(BF16_ONE);
let r = _mm256_bcstnebf16_ps(addr_of!(a));
let e = _mm256_set_ps(1., 1., 1., 1., 1., 1., 1., 1.);
assert_eq_m256(r, e);

View file

@ -337,6 +337,31 @@ types! {
);
}
/// The BFloat16 type used in AVX-512 intrinsics.
#[repr(transparent)]
#[derive(Copy, Clone, Debug)]
#[allow(non_camel_case_types)]
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub struct bf16(u16);
impl bf16 {
/// Raw transmutation from `u16`
#[inline]
#[must_use]
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub const fn from_bits(bits: u16) -> bf16 {
bf16(bits)
}
/// Raw transmutation to `u16`
#[inline]
#[must_use = "this returns the result of the operation, without modifying the original"]
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub const fn to_bits(self) -> u16 {
self.0
}
}
/// The `__mmask64` type used in AVX-512 intrinsics, a 64-bit integer
#[allow(non_camel_case_types)]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]

View file

@ -197,6 +197,7 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream {
"_MM_MANTISSA_SIGN_ENUM" => quote! { &MM_MANTISSA_SIGN_ENUM },
"_MM_PERM_ENUM" => quote! { &MM_PERM_ENUM },
"bool" => quote! { &BOOL },
"bf16" => quote! { &BF16 },
"f32" => quote! { &F32 },
"f64" => quote! { &F64 },
"i16" => quote! { &I16 },

View file

@ -22,6 +22,7 @@ struct Function {
has_test: bool,
}
static BF16: Type = Type::BFloat16;
static F32: Type = Type::PrimFloat(32);
static F64: Type = Type::PrimFloat(64);
static I8: Type = Type::PrimSigned(8);
@ -65,6 +66,7 @@ enum Type {
PrimFloat(u8),
PrimSigned(u8),
PrimUnsigned(u8),
BFloat16,
MutPtr(&'static Type),
ConstPtr(&'static Type),
M128,
@ -699,7 +701,8 @@ fn equate(
(&Type::PrimSigned(32), "__int32" | "const int" | "int") => {}
(&Type::PrimSigned(64), "__int64" | "long long") => {}
(&Type::PrimUnsigned(8), "unsigned char") => {}
(&Type::PrimUnsigned(16), "unsigned short" | "__bfloat16") => {}
(&Type::PrimUnsigned(16), "unsigned short") => {}
(&Type::BFloat16, "__bfloat16") => {}
(
&Type::PrimUnsigned(32),
"unsigned __int32" | "unsigned int" | "unsigned long" | "const unsigned int",
@ -758,9 +761,10 @@ fn equate(
(&Type::ConstPtr(&Type::PrimSigned(8)), "char const*") => {}
(&Type::ConstPtr(&Type::PrimSigned(32)), "__int32 const*" | "int const*") => {}
(&Type::ConstPtr(&Type::PrimSigned(64)), "__int64 const*") => {}
(&Type::ConstPtr(&Type::PrimUnsigned(16)), "unsigned short const*" | "__bf16 const*") => {}
(&Type::ConstPtr(&Type::PrimUnsigned(16)), "unsigned short const*") => {}
(&Type::ConstPtr(&Type::PrimUnsigned(32)), "unsigned int const*") => {}
(&Type::ConstPtr(&Type::PrimUnsigned(64)), "unsigned __int64 const*") => {}
(&Type::ConstPtr(&Type::BFloat16), "__bf16 const*") => {}
(&Type::ConstPtr(&Type::M128), "__m128 const*") => {}
(&Type::ConstPtr(&Type::M128BH), "__m128bh const*") => {}