From 3fb437b867ea30304f07f666d7fe41e2751f3619 Mon Sep 17 00:00:00 2001 From: Amanieu d'Antras Date: Sun, 28 Feb 2021 18:07:42 +0000 Subject: [PATCH] Add stricter validation of const arguments on x86 intrinsics (#1025) --- .../stdarch/crates/core_arch/src/x86/avx2.rs | 64 ++++++++++++------- .../crates/core_arch/src/x86/avx512bw.rs | 12 ++-- .../crates/core_arch/src/x86/avx512f.rs | 20 +++--- .../crates/stdarch-verify/tests/x86-intel.rs | 33 +++++++--- 4 files changed, 80 insertions(+), 49 deletions(-) diff --git a/library/stdarch/crates/core_arch/src/x86/avx2.rs b/library/stdarch/crates/core_arch/src/x86/avx2.rs index 358ce15081b6..78bb0f3481be 100644 --- a/library/stdarch/crates/core_arch/src/x86/avx2.rs +++ b/library/stdarch/crates/core_arch/src/x86/avx2.rs @@ -2929,9 +2929,11 @@ pub unsafe fn _mm256_sll_epi64(a: __m256i, count: __m128i) -> __m256i { /// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_slli_epi16) #[inline] #[target_feature(enable = "avx2")] -#[cfg_attr(test, assert_instr(vpsllw))] +#[cfg_attr(test, assert_instr(vpsllw, imm8 = 7))] +#[rustc_legacy_const_generics(1)] #[stable(feature = "simd_x86", since = "1.27.0")] -pub unsafe fn _mm256_slli_epi16(a: __m256i, imm8: i32) -> __m256i { +pub unsafe fn _mm256_slli_epi16<const imm8: i32>(a: __m256i) -> __m256i { + static_assert_imm8!(imm8); transmute(pslliw(a.as_i16x16(), imm8)) } @@ -2941,9 +2943,11 @@ pub unsafe fn _mm256_slli_epi16(a: __m256i, imm8: i32) -> __m256i { /// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_slli_epi32) #[inline] #[target_feature(enable = "avx2")] -#[cfg_attr(test, assert_instr(vpslld))] +#[cfg_attr(test, assert_instr(vpslld, imm8 = 7))] +#[rustc_legacy_const_generics(1)] #[stable(feature = "simd_x86", since = "1.27.0")] -pub unsafe fn _mm256_slli_epi32(a: __m256i, imm8: i32) -> __m256i { +pub unsafe fn _mm256_slli_epi32<const imm8: i32>(a: __m256i) -> __m256i { + static_assert_imm8!(imm8); 
transmute(psllid(a.as_i32x8(), imm8)) } @@ -2953,9 +2957,11 @@ pub unsafe fn _mm256_slli_epi32(a: __m256i, imm8: i32) -> __m256i { /// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_slli_epi64) #[inline] #[target_feature(enable = "avx2")] -#[cfg_attr(test, assert_instr(vpsllq))] +#[cfg_attr(test, assert_instr(vpsllq, imm8 = 7))] +#[rustc_legacy_const_generics(1)] #[stable(feature = "simd_x86", since = "1.27.0")] -pub unsafe fn _mm256_slli_epi64(a: __m256i, imm8: i32) -> __m256i { +pub unsafe fn _mm256_slli_epi64<const imm8: i32>(a: __m256i) -> __m256i { + static_assert_imm8!(imm8); transmute(pslliq(a.as_i64x4(), imm8)) } @@ -3077,9 +3083,11 @@ pub unsafe fn _mm256_sra_epi32(a: __m256i, count: __m128i) -> __m256i { /// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_srai_epi16) #[inline] #[target_feature(enable = "avx2")] -#[cfg_attr(test, assert_instr(vpsraw))] +#[cfg_attr(test, assert_instr(vpsraw, imm8 = 7))] +#[rustc_legacy_const_generics(1)] #[stable(feature = "simd_x86", since = "1.27.0")] -pub unsafe fn _mm256_srai_epi16(a: __m256i, imm8: i32) -> __m256i { +pub unsafe fn _mm256_srai_epi16<const imm8: i32>(a: __m256i) -> __m256i { + static_assert_imm8!(imm8); transmute(psraiw(a.as_i16x16(), imm8)) } @@ -3089,9 +3097,11 @@ pub unsafe fn _mm256_srai_epi16(a: __m256i, imm8: i32) -> __m256i { /// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_srai_epi32) #[inline] #[target_feature(enable = "avx2")] -#[cfg_attr(test, assert_instr(vpsrad))] +#[cfg_attr(test, assert_instr(vpsrad, imm8 = 7))] +#[rustc_legacy_const_generics(1)] #[stable(feature = "simd_x86", since = "1.27.0")] -pub unsafe fn _mm256_srai_epi32(a: __m256i, imm8: i32) -> __m256i { +pub unsafe fn _mm256_srai_epi32<const imm8: i32>(a: __m256i) -> __m256i { + static_assert_imm8!(imm8); transmute(psraid(a.as_i32x8(), imm8)) } @@ -3197,9 +3207,11 @@ pub unsafe fn _mm256_srl_epi64(a: __m256i, count: __m128i) -> 
__m256i { /// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_srli_epi16) #[inline] #[target_feature(enable = "avx2")] -#[cfg_attr(test, assert_instr(vpsrlw))] +#[cfg_attr(test, assert_instr(vpsrlw, imm8 = 7))] +#[rustc_legacy_const_generics(1)] #[stable(feature = "simd_x86", since = "1.27.0")] -pub unsafe fn _mm256_srli_epi16(a: __m256i, imm8: i32) -> __m256i { +pub unsafe fn _mm256_srli_epi16<const imm8: i32>(a: __m256i) -> __m256i { + static_assert_imm8!(imm8); transmute(psrliw(a.as_i16x16(), imm8)) } @@ -3209,9 +3221,11 @@ pub unsafe fn _mm256_srli_epi16(a: __m256i, imm8: i32) -> __m256i { /// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_srli_epi32) #[inline] #[target_feature(enable = "avx2")] -#[cfg_attr(test, assert_instr(vpsrld))] +#[cfg_attr(test, assert_instr(vpsrld, imm8 = 7))] +#[rustc_legacy_const_generics(1)] #[stable(feature = "simd_x86", since = "1.27.0")] -pub unsafe fn _mm256_srli_epi32(a: __m256i, imm8: i32) -> __m256i { +pub unsafe fn _mm256_srli_epi32<const imm8: i32>(a: __m256i) -> __m256i { + static_assert_imm8!(imm8); transmute(psrlid(a.as_i32x8(), imm8)) } @@ -3221,9 +3235,11 @@ pub unsafe fn _mm256_srli_epi32(a: __m256i, imm8: i32) -> __m256i { /// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_srli_epi64) #[inline] #[target_feature(enable = "avx2")] -#[cfg_attr(test, assert_instr(vpsrlq))] +#[cfg_attr(test, assert_instr(vpsrlq, imm8 = 7))] +#[rustc_legacy_const_generics(1)] #[stable(feature = "simd_x86", since = "1.27.0")] -pub unsafe fn _mm256_srli_epi64(a: __m256i, imm8: i32) -> __m256i { +pub unsafe fn _mm256_srli_epi64<const imm8: i32>(a: __m256i) -> __m256i { + static_assert_imm8!(imm8); transmute(psrliq(a.as_i64x4(), imm8)) } @@ -5204,7 +5220,7 @@ mod tests { #[simd_test(enable = "avx2")] unsafe fn test_mm256_slli_epi16() { assert_eq_m256i( - _mm256_slli_epi16(_mm256_set1_epi16(0xFF), 4), + 
_mm256_slli_epi16::<4>(_mm256_set1_epi16(0xFF)), _mm256_set1_epi16(0xFF0), ); } @@ -5212,7 +5228,7 @@ mod tests { #[simd_test(enable = "avx2")] unsafe fn test_mm256_slli_epi32() { assert_eq_m256i( - _mm256_slli_epi32(_mm256_set1_epi32(0xFFFF), 4), + _mm256_slli_epi32::<4>(_mm256_set1_epi32(0xFFFF)), _mm256_set1_epi32(0xFFFF0), ); } @@ -5220,7 +5236,7 @@ mod tests { #[simd_test(enable = "avx2")] unsafe fn test_mm256_slli_epi64() { assert_eq_m256i( - _mm256_slli_epi64(_mm256_set1_epi64x(0xFFFFFFFF), 4), + _mm256_slli_epi64::<4>(_mm256_set1_epi64x(0xFFFFFFFF)), _mm256_set1_epi64x(0xFFFFFFFF0), ); } @@ -5287,7 +5303,7 @@ mod tests { #[simd_test(enable = "avx2")] unsafe fn test_mm256_srai_epi16() { assert_eq_m256i( - _mm256_srai_epi16(_mm256_set1_epi16(-1), 1), + _mm256_srai_epi16::<1>(_mm256_set1_epi16(-1)), _mm256_set1_epi16(-1), ); } @@ -5295,7 +5311,7 @@ mod tests { #[simd_test(enable = "avx2")] unsafe fn test_mm256_srai_epi32() { assert_eq_m256i( - _mm256_srai_epi32(_mm256_set1_epi32(-1), 1), + _mm256_srai_epi32::<1>(_mm256_set1_epi32(-1)), _mm256_set1_epi32(-1), ); } @@ -5365,7 +5381,7 @@ mod tests { #[simd_test(enable = "avx2")] unsafe fn test_mm256_srli_epi16() { assert_eq_m256i( - _mm256_srli_epi16(_mm256_set1_epi16(0xFF), 4), + _mm256_srli_epi16::<4>(_mm256_set1_epi16(0xFF)), _mm256_set1_epi16(0xF), ); } @@ -5373,7 +5389,7 @@ mod tests { #[simd_test(enable = "avx2")] unsafe fn test_mm256_srli_epi32() { assert_eq_m256i( - _mm256_srli_epi32(_mm256_set1_epi32(0xFFFF), 4), + _mm256_srli_epi32::<4>(_mm256_set1_epi32(0xFFFF)), _mm256_set1_epi32(0xFFF), ); } @@ -5381,7 +5397,7 @@ mod tests { #[simd_test(enable = "avx2")] unsafe fn test_mm256_srli_epi64() { assert_eq_m256i( - _mm256_srli_epi64(_mm256_set1_epi64x(0xFFFFFFFF), 4), + _mm256_srli_epi64::<4>(_mm256_set1_epi64x(0xFFFFFFFF)), _mm256_set1_epi64x(0xFFFFFFF), ); } diff --git a/library/stdarch/crates/core_arch/src/x86/avx512bw.rs b/library/stdarch/crates/core_arch/src/x86/avx512bw.rs index 
16dfdf7eec2c..b6fa9d254a42 100644 --- a/library/stdarch/crates/core_arch/src/x86/avx512bw.rs +++ b/library/stdarch/crates/core_arch/src/x86/avx512bw.rs @@ -5166,7 +5166,7 @@ pub unsafe fn _mm512_maskz_slli_epi16(k: __mmask32, a: __m512i, imm8: u32) -> __ pub unsafe fn _mm256_mask_slli_epi16(src: __m256i, k: __mmask16, a: __m256i, imm8: u32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_slli_epi16(a, $imm8) + _mm256_slli_epi16::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -5183,7 +5183,7 @@ pub unsafe fn _mm256_mask_slli_epi16(src: __m256i, k: __mmask16, a: __m256i, imm pub unsafe fn _mm256_maskz_slli_epi16(k: __mmask16, a: __m256i, imm8: u32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_slli_epi16(a, $imm8) + _mm256_slli_epi16::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -5495,7 +5495,7 @@ pub unsafe fn _mm512_maskz_srli_epi16(k: __mmask32, a: __m512i, imm8: i32) -> __ pub unsafe fn _mm256_mask_srli_epi16(src: __m256i, k: __mmask16, a: __m256i, imm8: i32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_srli_epi16(a, $imm8) + _mm256_srli_epi16::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -5512,7 +5512,7 @@ pub unsafe fn _mm256_mask_srli_epi16(src: __m256i, k: __mmask16, a: __m256i, imm pub unsafe fn _mm256_maskz_srli_epi16(k: __mmask16, a: __m256i, imm8: i32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_srli_epi16(a, $imm8) + _mm256_srli_epi16::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -5823,7 +5823,7 @@ pub unsafe fn _mm512_maskz_srai_epi16(k: __mmask32, a: __m512i, imm8: u32) -> __ pub unsafe fn _mm256_mask_srai_epi16(src: __m256i, k: __mmask16, a: __m256i, imm8: u32) -> __m256i { macro_rules! 
call { ($imm8:expr) => { - _mm256_srai_epi16(a, $imm8) + _mm256_srai_epi16::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -5840,7 +5840,7 @@ pub unsafe fn _mm256_mask_srai_epi16(src: __m256i, k: __mmask16, a: __m256i, imm pub unsafe fn _mm256_maskz_srai_epi16(k: __mmask16, a: __m256i, imm8: u32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_srai_epi16(a, $imm8) + _mm256_srai_epi16::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); diff --git a/library/stdarch/crates/core_arch/src/x86/avx512f.rs b/library/stdarch/crates/core_arch/src/x86/avx512f.rs index b450963c2bc6..439135920444 100644 --- a/library/stdarch/crates/core_arch/src/x86/avx512f.rs +++ b/library/stdarch/crates/core_arch/src/x86/avx512f.rs @@ -18149,7 +18149,7 @@ pub unsafe fn _mm512_maskz_slli_epi32(k: __mmask16, a: __m512i, imm8: u32) -> __ pub unsafe fn _mm256_mask_slli_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8: u32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_slli_epi32(a, $imm8) + _mm256_slli_epi32::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -18166,7 +18166,7 @@ pub unsafe fn _mm256_mask_slli_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8 pub unsafe fn _mm256_maskz_slli_epi32(k: __mmask8, a: __m256i, imm8: u32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_slli_epi32(a, $imm8) + _mm256_slli_epi32::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -18274,7 +18274,7 @@ pub unsafe fn _mm512_maskz_srli_epi32(k: __mmask16, a: __m512i, imm8: u32) -> __ pub unsafe fn _mm256_mask_srli_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8: u32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_srli_epi32(a, $imm8) + _mm256_srli_epi32::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -18291,7 +18291,7 @@ pub unsafe fn _mm256_mask_srli_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8 pub unsafe fn _mm256_maskz_srli_epi32(k: __mmask8, a: __m256i, imm8: u32) -> __m256i { macro_rules! 
call { ($imm8:expr) => { - _mm256_srli_epi32(a, $imm8) + _mm256_srli_epi32::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -18399,7 +18399,7 @@ pub unsafe fn _mm512_maskz_slli_epi64(k: __mmask8, a: __m512i, imm8: u32) -> __m pub unsafe fn _mm256_mask_slli_epi64(src: __m256i, k: __mmask8, a: __m256i, imm8: u32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_slli_epi64(a, $imm8) + _mm256_slli_epi64::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -18416,7 +18416,7 @@ pub unsafe fn _mm256_mask_slli_epi64(src: __m256i, k: __mmask8, a: __m256i, imm8 pub unsafe fn _mm256_maskz_slli_epi64(k: __mmask8, a: __m256i, imm8: u32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_slli_epi64(a, $imm8) + _mm256_slli_epi64::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -18524,7 +18524,7 @@ pub unsafe fn _mm512_maskz_srli_epi64(k: __mmask8, a: __m512i, imm8: u32) -> __m pub unsafe fn _mm256_mask_srli_epi64(src: __m256i, k: __mmask8, a: __m256i, imm8: u32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_srli_epi64(a, $imm8) + _mm256_srli_epi64::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -18541,7 +18541,7 @@ pub unsafe fn _mm256_mask_srli_epi64(src: __m256i, k: __mmask8, a: __m256i, imm8 pub unsafe fn _mm256_maskz_srli_epi64(k: __mmask8, a: __m256i, imm8: u32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_srli_epi64(a, $imm8) + _mm256_srli_epi64::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -19203,7 +19203,7 @@ pub unsafe fn _mm512_maskz_srai_epi32(k: __mmask16, a: __m512i, imm8: u32) -> __ pub unsafe fn _mm256_mask_srai_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8: u32) -> __m256i { macro_rules! 
call { ($imm8:expr) => { - _mm256_srai_epi32(a, $imm8) + _mm256_srai_epi32::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); @@ -19220,7 +19220,7 @@ pub unsafe fn _mm256_mask_srai_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8 pub unsafe fn _mm256_maskz_srai_epi32(k: __mmask8, a: __m256i, imm8: u32) -> __m256i { macro_rules! call { ($imm8:expr) => { - _mm256_srai_epi32(a, $imm8) + _mm256_srai_epi32::<$imm8>(a) }; } let shf = constify_imm8_sae!(imm8, call); diff --git a/library/stdarch/crates/stdarch-verify/tests/x86-intel.rs b/library/stdarch/crates/stdarch-verify/tests/x86-intel.rs index 01a245c54c27..89494bfd2a55 100644 --- a/library/stdarch/crates/stdarch-verify/tests/x86-intel.rs +++ b/library/stdarch/crates/stdarch-verify/tests/x86-intel.rs @@ -128,6 +128,8 @@ struct Intrinsic { struct Parameter { #[serde(rename = "type")] type_: String, + #[serde(default)] + etype: String, } #[derive(Deserialize)] @@ -548,7 +550,7 @@ fn matches(rust: &Function, intel: &Intrinsic) -> Result<(), String> { // Make sure we've got the right return type. 
if let Some(t) = rust.ret { - equate(t, &intel.return_.type_, rust.name, false)?; + equate(t, &intel.return_.type_, "", rust.name, false)?; } else if intel.return_.type_ != "" && intel.return_.type_ != "void" { bail!( "{} returns `{}` with intel, void in rust", @@ -570,7 +572,7 @@ fn matches(rust: &Function, intel: &Intrinsic) -> Result<(), String> { } for (i, (a, b)) in intel.parameters.iter().zip(rust.arguments).enumerate() { let is_const = rust.required_const.contains(&i); - equate(b, &a.type_, &intel.name, is_const)?; + equate(b, &a.type_, &a.etype, &intel.name, is_const)?; } } @@ -669,7 +671,13 @@ fn matches(rust: &Function, intel: &Intrinsic) -> Result<(), String> { Ok(()) } -fn equate(t: &Type, intel: &str, intrinsic: &str, is_const: bool) -> Result<(), String> { +fn equate( + t: &Type, + intel: &str, + etype: &str, + intrinsic: &str, + is_const: bool, +) -> Result<(), String> { // Make pointer adjacent to the type: float * foo => float* foo let mut intel = intel.replace(" *", "*"); // Make mutability modifier adjacent to the pointer: @@ -681,19 +689,26 @@ fn equate(t: &Type, intel: &str, intrinsic: &str, is_const: bool) -> Result<(), intel = intel.replace("const ", ""); intel = intel.replace("*", " const*"); } - let require_const = || { - if is_const { - return Ok(()); + if etype == "IMM" { + // The _bittest intrinsics claim to only accept immediates but actually + // accept run-time values as well. 
+ if !is_const && !intrinsic.starts_with("_bittest") { + return bail!("argument required to be const but isn't"); } - Err(format!("argument required to be const but isn't")) - }; + } else { + // const int must be an IMM + assert_ne!(intel, "const int"); + if is_const { + return bail!("argument is const but shouldn't be"); + } + } match (t, &intel[..]) { (&Type::PrimFloat(32), "float") => {} (&Type::PrimFloat(64), "double") => {} (&Type::PrimSigned(16), "__int16") => {} (&Type::PrimSigned(16), "short") => {} (&Type::PrimSigned(32), "__int32") => {} - (&Type::PrimSigned(32), "const int") => require_const()?, + (&Type::PrimSigned(32), "const int") => {} (&Type::PrimSigned(32), "int") => {} (&Type::PrimSigned(64), "__int64") => {} (&Type::PrimSigned(64), "long long") => {}