Added a bf16 type
This commit is contained in:
parent
70fbc2e97c
commit
c862e4e487
5 changed files with 52 additions and 21 deletions
|
|
@ -486,9 +486,9 @@ pub unsafe fn _mm_maskz_cvtpbh_ps(k: __mmask8, a: __m128bh) -> __m128 {
|
|||
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtsbh_ss)
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512bf16,avx512f")]
|
||||
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
|
||||
pub unsafe fn _mm_cvtsbh_ss(a: u16) -> f32 {
|
||||
f32::from_bits((a as u32) << 16)
|
||||
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
|
||||
pub unsafe fn _mm_cvtsbh_ss(a: bf16) -> f32 {
|
||||
f32::from_bits((a.to_bits() as u32) << 16)
|
||||
}
|
||||
|
||||
/// Converts packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
|
||||
|
|
@ -558,9 +558,10 @@ pub unsafe fn _mm_maskz_cvtneps_pbh(k: __mmask8, a: __m128) -> __m128bh {
|
|||
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh)
|
||||
#[inline]
|
||||
#[target_feature(enable = "avx512bf16,avx512vl")]
|
||||
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
|
||||
pub unsafe fn _mm_cvtness_sbh(a: f32) -> u16 {
|
||||
simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0)
|
||||
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
|
||||
pub unsafe fn _mm_cvtness_sbh(a: f32) -> bf16 {
|
||||
let value: u16 = simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0);
|
||||
bf16::from_bits(value)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -1910,7 +1911,7 @@ mod tests {
|
|||
|
||||
#[simd_test(enable = "avx512bf16")]
|
||||
unsafe fn test_mm_cvtsbh_ss() {
|
||||
let r = _mm_cvtsbh_ss(BF16_ONE);
|
||||
let r = _mm_cvtsbh_ss(bf16::from_bits(BF16_ONE));
|
||||
assert_eq!(r, 1.);
|
||||
}
|
||||
|
||||
|
|
@ -1944,6 +1945,6 @@ mod tests {
|
|||
#[simd_test(enable = "avx512bf16,avx512vl")]
|
||||
unsafe fn test_mm_cvtness_sbh() {
|
||||
let r = _mm_cvtness_sbh(1.);
|
||||
assert_eq!(r, BF16_ONE);
|
||||
assert_eq!(r.to_bits(), BF16_ONE);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::arch::asm;
|
||||
use crate::core_arch::{simd::*, x86::*};
|
||||
use crate::core_arch::x86::*;
|
||||
|
||||
#[cfg(test)]
|
||||
use stdarch_test::assert_instr;
|
||||
|
|
@ -15,9 +15,9 @@ use stdarch_test::assert_instr;
|
|||
all(test, any(target_os = "linux", target_env = "msvc")),
|
||||
assert_instr(vbcstnebf162ps)
|
||||
)]
|
||||
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
|
||||
pub unsafe fn _mm_bcstnebf16_ps(a: *const u16) -> __m128 {
|
||||
transmute(bcstnebf162ps_128(a))
|
||||
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
|
||||
pub unsafe fn _mm_bcstnebf16_ps(a: *const bf16) -> __m128 {
|
||||
bcstnebf162ps_128(a)
|
||||
}
|
||||
|
||||
/// Convert scalar BF16 (16-bit) floating point element stored at memory locations starting at location
|
||||
|
|
@ -31,9 +31,9 @@ pub unsafe fn _mm_bcstnebf16_ps(a: *const u16) -> __m128 {
|
|||
all(test, any(target_os = "linux", target_env = "msvc")),
|
||||
assert_instr(vbcstnebf162ps)
|
||||
)]
|
||||
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
|
||||
pub unsafe fn _mm256_bcstnebf16_ps(a: *const u16) -> __m256 {
|
||||
transmute(bcstnebf162ps_256(a))
|
||||
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
|
||||
pub unsafe fn _mm256_bcstnebf16_ps(a: *const bf16) -> __m256 {
|
||||
bcstnebf162ps_256(a)
|
||||
}
|
||||
|
||||
/// Convert packed BF16 (16-bit) floating-point even-indexed elements stored at memory locations starting at
|
||||
|
|
@ -143,9 +143,9 @@ pub unsafe fn _mm256_cvtneps_avx_pbh(a: __m256) -> __m128bh {
|
|||
#[allow(improper_ctypes)]
|
||||
extern "C" {
|
||||
#[link_name = "llvm.x86.vbcstnebf162ps128"]
|
||||
fn bcstnebf162ps_128(a: *const u16) -> f32x4;
|
||||
fn bcstnebf162ps_128(a: *const bf16) -> __m128;
|
||||
#[link_name = "llvm.x86.vbcstnebf162ps256"]
|
||||
fn bcstnebf162ps_256(a: *const u16) -> f32x8;
|
||||
fn bcstnebf162ps_256(a: *const bf16) -> __m256;
|
||||
|
||||
#[link_name = "llvm.x86.vcvtneebf162ps128"]
|
||||
fn cvtneebf162ps_128(a: *const __m128bh) -> __m128;
|
||||
|
|
@ -177,7 +177,7 @@ mod tests {
|
|||
|
||||
#[simd_test(enable = "avxneconvert")]
|
||||
unsafe fn test_mm_bcstnebf16_ps() {
|
||||
let a = BF16_ONE;
|
||||
let a = bf16::from_bits(BF16_ONE);
|
||||
let r = _mm_bcstnebf16_ps(addr_of!(a));
|
||||
let e = _mm_set_ps(1., 1., 1., 1.);
|
||||
assert_eq_m128(r, e);
|
||||
|
|
@ -185,7 +185,7 @@ mod tests {
|
|||
|
||||
#[simd_test(enable = "avxneconvert")]
|
||||
unsafe fn test_mm256_bcstnebf16_ps() {
|
||||
let a = BF16_ONE;
|
||||
let a = bf16::from_bits(BF16_ONE);
|
||||
let r = _mm256_bcstnebf16_ps(addr_of!(a));
|
||||
let e = _mm256_set_ps(1., 1., 1., 1., 1., 1., 1., 1.);
|
||||
assert_eq_m256(r, e);
|
||||
|
|
|
|||
|
|
@ -337,6 +337,31 @@ types! {
|
|||
);
|
||||
}
|
||||
|
||||
/// The BFloat16 type used in AVX-512 intrinsics.
|
||||
#[repr(transparent)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[allow(non_camel_case_types)]
|
||||
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
|
||||
pub struct bf16(u16);
|
||||
|
||||
impl bf16 {
|
||||
/// Raw transmutation from `u16`
|
||||
#[inline]
|
||||
#[must_use]
|
||||
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
|
||||
pub const fn from_bits(bits: u16) -> bf16 {
|
||||
bf16(bits)
|
||||
}
|
||||
|
||||
/// Raw transmutation to `u16`
|
||||
#[inline]
|
||||
#[must_use = "this returns the result of the operation, without modifying the original"]
|
||||
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
|
||||
pub const fn to_bits(self) -> u16 {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// The `__mmask64` type used in AVX-512 intrinsics, a 64-bit integer
|
||||
#[allow(non_camel_case_types)]
|
||||
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
|
||||
|
|
|
|||
|
|
@ -197,6 +197,7 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream {
|
|||
"_MM_MANTISSA_SIGN_ENUM" => quote! { &MM_MANTISSA_SIGN_ENUM },
|
||||
"_MM_PERM_ENUM" => quote! { &MM_PERM_ENUM },
|
||||
"bool" => quote! { &BOOL },
|
||||
"bf16" => quote! { &BF16 },
|
||||
"f32" => quote! { &F32 },
|
||||
"f64" => quote! { &F64 },
|
||||
"i16" => quote! { &I16 },
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ struct Function {
|
|||
has_test: bool,
|
||||
}
|
||||
|
||||
static BF16: Type = Type::BFloat16;
|
||||
static F32: Type = Type::PrimFloat(32);
|
||||
static F64: Type = Type::PrimFloat(64);
|
||||
static I8: Type = Type::PrimSigned(8);
|
||||
|
|
@ -65,6 +66,7 @@ enum Type {
|
|||
PrimFloat(u8),
|
||||
PrimSigned(u8),
|
||||
PrimUnsigned(u8),
|
||||
BFloat16,
|
||||
MutPtr(&'static Type),
|
||||
ConstPtr(&'static Type),
|
||||
M128,
|
||||
|
|
@ -699,7 +701,8 @@ fn equate(
|
|||
(&Type::PrimSigned(32), "__int32" | "const int" | "int") => {}
|
||||
(&Type::PrimSigned(64), "__int64" | "long long") => {}
|
||||
(&Type::PrimUnsigned(8), "unsigned char") => {}
|
||||
(&Type::PrimUnsigned(16), "unsigned short" | "__bfloat16") => {}
|
||||
(&Type::PrimUnsigned(16), "unsigned short") => {}
|
||||
(&Type::BFloat16, "__bfloat16") => {}
|
||||
(
|
||||
&Type::PrimUnsigned(32),
|
||||
"unsigned __int32" | "unsigned int" | "unsigned long" | "const unsigned int",
|
||||
|
|
@ -758,9 +761,10 @@ fn equate(
|
|||
(&Type::ConstPtr(&Type::PrimSigned(8)), "char const*") => {}
|
||||
(&Type::ConstPtr(&Type::PrimSigned(32)), "__int32 const*" | "int const*") => {}
|
||||
(&Type::ConstPtr(&Type::PrimSigned(64)), "__int64 const*") => {}
|
||||
(&Type::ConstPtr(&Type::PrimUnsigned(16)), "unsigned short const*" | "__bf16 const*") => {}
|
||||
(&Type::ConstPtr(&Type::PrimUnsigned(16)), "unsigned short const*") => {}
|
||||
(&Type::ConstPtr(&Type::PrimUnsigned(32)), "unsigned int const*") => {}
|
||||
(&Type::ConstPtr(&Type::PrimUnsigned(64)), "unsigned __int64 const*") => {}
|
||||
(&Type::ConstPtr(&Type::BFloat16), "__bf16 const*") => {}
|
||||
|
||||
(&Type::ConstPtr(&Type::M128), "__m128 const*") => {}
|
||||
(&Type::ConstPtr(&Type::M128BH), "__m128bh const*") => {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue