Add stricter validation of const arguments on x86 intrinsics (#1025)

This commit is contained in:
Amanieu d'Antras 2021-02-28 18:07:42 +00:00 committed by GitHub
parent d6a22093aa
commit 3fb437b867
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 49 deletions

View file

@@ -2929,9 +2929,11 @@ pub unsafe fn _mm256_sll_epi64(a: __m256i, count: __m128i) -> __m256i {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_slli_epi16)
#[inline]
#[target_feature(enable = "avx2")]
#[cfg_attr(test, assert_instr(vpsllw))]
#[cfg_attr(test, assert_instr(vpsllw, imm8 = 7))]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_slli_epi16(a: __m256i, imm8: i32) -> __m256i {
pub unsafe fn _mm256_slli_epi16<const imm8: i32>(a: __m256i) -> __m256i {
static_assert_imm8!(imm8);
transmute(pslliw(a.as_i16x16(), imm8))
}
@@ -2941,9 +2943,11 @@ pub unsafe fn _mm256_slli_epi16(a: __m256i, imm8: i32) -> __m256i {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_slli_epi32)
#[inline]
#[target_feature(enable = "avx2")]
#[cfg_attr(test, assert_instr(vpslld))]
#[cfg_attr(test, assert_instr(vpslld, imm8 = 7))]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_slli_epi32(a: __m256i, imm8: i32) -> __m256i {
pub unsafe fn _mm256_slli_epi32<const imm8: i32>(a: __m256i) -> __m256i {
static_assert_imm8!(imm8);
transmute(psllid(a.as_i32x8(), imm8))
}
@@ -2953,9 +2957,11 @@ pub unsafe fn _mm256_slli_epi32(a: __m256i, imm8: i32) -> __m256i {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_slli_epi64)
#[inline]
#[target_feature(enable = "avx2")]
#[cfg_attr(test, assert_instr(vpsllq))]
#[cfg_attr(test, assert_instr(vpsllq, imm8 = 7))]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_slli_epi64(a: __m256i, imm8: i32) -> __m256i {
pub unsafe fn _mm256_slli_epi64<const imm8: i32>(a: __m256i) -> __m256i {
static_assert_imm8!(imm8);
transmute(pslliq(a.as_i64x4(), imm8))
}
@@ -3077,9 +3083,11 @@ pub unsafe fn _mm256_sra_epi32(a: __m256i, count: __m128i) -> __m256i {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_srai_epi16)
#[inline]
#[target_feature(enable = "avx2")]
#[cfg_attr(test, assert_instr(vpsraw))]
#[cfg_attr(test, assert_instr(vpsraw, imm8 = 7))]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_srai_epi16(a: __m256i, imm8: i32) -> __m256i {
pub unsafe fn _mm256_srai_epi16<const imm8: i32>(a: __m256i) -> __m256i {
static_assert_imm8!(imm8);
transmute(psraiw(a.as_i16x16(), imm8))
}
@@ -3089,9 +3097,11 @@ pub unsafe fn _mm256_srai_epi16(a: __m256i, imm8: i32) -> __m256i {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_srai_epi32)
#[inline]
#[target_feature(enable = "avx2")]
#[cfg_attr(test, assert_instr(vpsrad))]
#[cfg_attr(test, assert_instr(vpsrad, imm8 = 7))]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_srai_epi32(a: __m256i, imm8: i32) -> __m256i {
pub unsafe fn _mm256_srai_epi32<const imm8: i32>(a: __m256i) -> __m256i {
static_assert_imm8!(imm8);
transmute(psraid(a.as_i32x8(), imm8))
}
@@ -3197,9 +3207,11 @@ pub unsafe fn _mm256_srl_epi64(a: __m256i, count: __m128i) -> __m256i {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_srli_epi16)
#[inline]
#[target_feature(enable = "avx2")]
#[cfg_attr(test, assert_instr(vpsrlw))]
#[cfg_attr(test, assert_instr(vpsrlw, imm8 = 7))]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_srli_epi16(a: __m256i, imm8: i32) -> __m256i {
pub unsafe fn _mm256_srli_epi16<const imm8: i32>(a: __m256i) -> __m256i {
static_assert_imm8!(imm8);
transmute(psrliw(a.as_i16x16(), imm8))
}
@@ -3209,9 +3221,11 @@ pub unsafe fn _mm256_srli_epi16(a: __m256i, imm8: i32) -> __m256i {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_srli_epi32)
#[inline]
#[target_feature(enable = "avx2")]
#[cfg_attr(test, assert_instr(vpsrld))]
#[cfg_attr(test, assert_instr(vpsrld, imm8 = 7))]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_srli_epi32(a: __m256i, imm8: i32) -> __m256i {
pub unsafe fn _mm256_srli_epi32<const imm8: i32>(a: __m256i) -> __m256i {
static_assert_imm8!(imm8);
transmute(psrlid(a.as_i32x8(), imm8))
}
@@ -3221,9 +3235,11 @@ pub unsafe fn _mm256_srli_epi32(a: __m256i, imm8: i32) -> __m256i {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_srli_epi64)
#[inline]
#[target_feature(enable = "avx2")]
#[cfg_attr(test, assert_instr(vpsrlq))]
#[cfg_attr(test, assert_instr(vpsrlq, imm8 = 7))]
#[rustc_legacy_const_generics(1)]
#[stable(feature = "simd_x86", since = "1.27.0")]
pub unsafe fn _mm256_srli_epi64(a: __m256i, imm8: i32) -> __m256i {
pub unsafe fn _mm256_srli_epi64<const imm8: i32>(a: __m256i) -> __m256i {
static_assert_imm8!(imm8);
transmute(psrliq(a.as_i64x4(), imm8))
}
@@ -5204,7 +5220,7 @@ mod tests {
#[simd_test(enable = "avx2")]
unsafe fn test_mm256_slli_epi16() {
assert_eq_m256i(
_mm256_slli_epi16(_mm256_set1_epi16(0xFF), 4),
_mm256_slli_epi16::<4>(_mm256_set1_epi16(0xFF)),
_mm256_set1_epi16(0xFF0),
);
}
@@ -5212,7 +5228,7 @@ mod tests {
#[simd_test(enable = "avx2")]
unsafe fn test_mm256_slli_epi32() {
assert_eq_m256i(
_mm256_slli_epi32(_mm256_set1_epi32(0xFFFF), 4),
_mm256_slli_epi32::<4>(_mm256_set1_epi32(0xFFFF)),
_mm256_set1_epi32(0xFFFF0),
);
}
@@ -5220,7 +5236,7 @@ mod tests {
#[simd_test(enable = "avx2")]
unsafe fn test_mm256_slli_epi64() {
assert_eq_m256i(
_mm256_slli_epi64(_mm256_set1_epi64x(0xFFFFFFFF), 4),
_mm256_slli_epi64::<4>(_mm256_set1_epi64x(0xFFFFFFFF)),
_mm256_set1_epi64x(0xFFFFFFFF0),
);
}
@@ -5287,7 +5303,7 @@ mod tests {
#[simd_test(enable = "avx2")]
unsafe fn test_mm256_srai_epi16() {
assert_eq_m256i(
_mm256_srai_epi16(_mm256_set1_epi16(-1), 1),
_mm256_srai_epi16::<1>(_mm256_set1_epi16(-1)),
_mm256_set1_epi16(-1),
);
}
@@ -5295,7 +5311,7 @@ mod tests {
#[simd_test(enable = "avx2")]
unsafe fn test_mm256_srai_epi32() {
assert_eq_m256i(
_mm256_srai_epi32(_mm256_set1_epi32(-1), 1),
_mm256_srai_epi32::<1>(_mm256_set1_epi32(-1)),
_mm256_set1_epi32(-1),
);
}
@@ -5365,7 +5381,7 @@ mod tests {
#[simd_test(enable = "avx2")]
unsafe fn test_mm256_srli_epi16() {
assert_eq_m256i(
_mm256_srli_epi16(_mm256_set1_epi16(0xFF), 4),
_mm256_srli_epi16::<4>(_mm256_set1_epi16(0xFF)),
_mm256_set1_epi16(0xF),
);
}
@@ -5373,7 +5389,7 @@ mod tests {
#[simd_test(enable = "avx2")]
unsafe fn test_mm256_srli_epi32() {
assert_eq_m256i(
_mm256_srli_epi32(_mm256_set1_epi32(0xFFFF), 4),
_mm256_srli_epi32::<4>(_mm256_set1_epi32(0xFFFF)),
_mm256_set1_epi32(0xFFF),
);
}
@@ -5381,7 +5397,7 @@ mod tests {
#[simd_test(enable = "avx2")]
unsafe fn test_mm256_srli_epi64() {
assert_eq_m256i(
_mm256_srli_epi64(_mm256_set1_epi64x(0xFFFFFFFF), 4),
_mm256_srli_epi64::<4>(_mm256_set1_epi64x(0xFFFFFFFF)),
_mm256_set1_epi64x(0xFFFFFFF),
);
}

View file

@@ -5166,7 +5166,7 @@ pub unsafe fn _mm512_maskz_slli_epi16(k: __mmask32, a: __m512i, imm8: u32) -> __
pub unsafe fn _mm256_mask_slli_epi16(src: __m256i, k: __mmask16, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_slli_epi16(a, $imm8)
_mm256_slli_epi16::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -5183,7 +5183,7 @@ pub unsafe fn _mm256_mask_slli_epi16(src: __m256i, k: __mmask16, a: __m256i, imm
pub unsafe fn _mm256_maskz_slli_epi16(k: __mmask16, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_slli_epi16(a, $imm8)
_mm256_slli_epi16::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -5495,7 +5495,7 @@ pub unsafe fn _mm512_maskz_srli_epi16(k: __mmask32, a: __m512i, imm8: i32) -> __
pub unsafe fn _mm256_mask_srli_epi16(src: __m256i, k: __mmask16, a: __m256i, imm8: i32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_srli_epi16(a, $imm8)
_mm256_srli_epi16::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -5512,7 +5512,7 @@ pub unsafe fn _mm256_mask_srli_epi16(src: __m256i, k: __mmask16, a: __m256i, imm
pub unsafe fn _mm256_maskz_srli_epi16(k: __mmask16, a: __m256i, imm8: i32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_srli_epi16(a, $imm8)
_mm256_srli_epi16::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -5823,7 +5823,7 @@ pub unsafe fn _mm512_maskz_srai_epi16(k: __mmask32, a: __m512i, imm8: u32) -> __
pub unsafe fn _mm256_mask_srai_epi16(src: __m256i, k: __mmask16, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_srai_epi16(a, $imm8)
_mm256_srai_epi16::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -5840,7 +5840,7 @@ pub unsafe fn _mm256_mask_srai_epi16(src: __m256i, k: __mmask16, a: __m256i, imm
pub unsafe fn _mm256_maskz_srai_epi16(k: __mmask16, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_srai_epi16(a, $imm8)
_mm256_srai_epi16::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);

View file

@@ -18149,7 +18149,7 @@ pub unsafe fn _mm512_maskz_slli_epi32(k: __mmask16, a: __m512i, imm8: u32) -> __
pub unsafe fn _mm256_mask_slli_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_slli_epi32(a, $imm8)
_mm256_slli_epi32::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -18166,7 +18166,7 @@ pub unsafe fn _mm256_mask_slli_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8
pub unsafe fn _mm256_maskz_slli_epi32(k: __mmask8, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_slli_epi32(a, $imm8)
_mm256_slli_epi32::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -18274,7 +18274,7 @@ pub unsafe fn _mm512_maskz_srli_epi32(k: __mmask16, a: __m512i, imm8: u32) -> __
pub unsafe fn _mm256_mask_srli_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_srli_epi32(a, $imm8)
_mm256_srli_epi32::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -18291,7 +18291,7 @@ pub unsafe fn _mm256_mask_srli_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8
pub unsafe fn _mm256_maskz_srli_epi32(k: __mmask8, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_srli_epi32(a, $imm8)
_mm256_srli_epi32::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -18399,7 +18399,7 @@ pub unsafe fn _mm512_maskz_slli_epi64(k: __mmask8, a: __m512i, imm8: u32) -> __m
pub unsafe fn _mm256_mask_slli_epi64(src: __m256i, k: __mmask8, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_slli_epi64(a, $imm8)
_mm256_slli_epi64::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -18416,7 +18416,7 @@ pub unsafe fn _mm256_mask_slli_epi64(src: __m256i, k: __mmask8, a: __m256i, imm8
pub unsafe fn _mm256_maskz_slli_epi64(k: __mmask8, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_slli_epi64(a, $imm8)
_mm256_slli_epi64::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -18524,7 +18524,7 @@ pub unsafe fn _mm512_maskz_srli_epi64(k: __mmask8, a: __m512i, imm8: u32) -> __m
pub unsafe fn _mm256_mask_srli_epi64(src: __m256i, k: __mmask8, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_srli_epi64(a, $imm8)
_mm256_srli_epi64::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -18541,7 +18541,7 @@ pub unsafe fn _mm256_mask_srli_epi64(src: __m256i, k: __mmask8, a: __m256i, imm8
pub unsafe fn _mm256_maskz_srli_epi64(k: __mmask8, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_srli_epi64(a, $imm8)
_mm256_srli_epi64::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -19203,7 +19203,7 @@ pub unsafe fn _mm512_maskz_srai_epi32(k: __mmask16, a: __m512i, imm8: u32) -> __
pub unsafe fn _mm256_mask_srai_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_srai_epi32(a, $imm8)
_mm256_srai_epi32::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);
@@ -19220,7 +19220,7 @@ pub unsafe fn _mm256_mask_srai_epi32(src: __m256i, k: __mmask8, a: __m256i, imm8
pub unsafe fn _mm256_maskz_srai_epi32(k: __mmask8, a: __m256i, imm8: u32) -> __m256i {
macro_rules! call {
($imm8:expr) => {
_mm256_srai_epi32(a, $imm8)
_mm256_srai_epi32::<$imm8>(a)
};
}
let shf = constify_imm8_sae!(imm8, call);

View file

@@ -128,6 +128,8 @@ struct Intrinsic {
struct Parameter {
#[serde(rename = "type")]
type_: String,
#[serde(default)]
etype: String,
}
#[derive(Deserialize)]
@@ -548,7 +550,7 @@ fn matches(rust: &Function, intel: &Intrinsic) -> Result<(), String> {
// Make sure we've got the right return type.
if let Some(t) = rust.ret {
equate(t, &intel.return_.type_, rust.name, false)?;
equate(t, &intel.return_.type_, "", rust.name, false)?;
} else if intel.return_.type_ != "" && intel.return_.type_ != "void" {
bail!(
"{} returns `{}` with intel, void in rust",
@@ -570,7 +572,7 @@ fn matches(rust: &Function, intel: &Intrinsic) -> Result<(), String> {
}
for (i, (a, b)) in intel.parameters.iter().zip(rust.arguments).enumerate() {
let is_const = rust.required_const.contains(&i);
equate(b, &a.type_, &intel.name, is_const)?;
equate(b, &a.type_, &a.etype, &intel.name, is_const)?;
}
}
@@ -669,7 +671,13 @@ fn matches(rust: &Function, intel: &Intrinsic) -> Result<(), String> {
Ok(())
}
fn equate(t: &Type, intel: &str, intrinsic: &str, is_const: bool) -> Result<(), String> {
fn equate(
t: &Type,
intel: &str,
etype: &str,
intrinsic: &str,
is_const: bool,
) -> Result<(), String> {
// Make pointer adjacent to the type: float * foo => float* foo
let mut intel = intel.replace(" *", "*");
// Make mutability modifier adjacent to the pointer:
@@ -681,19 +689,26 @@ fn equate(t: &Type, intel: &str, intrinsic: &str, is_const: bool) -> Result<(),
intel = intel.replace("const ", "");
intel = intel.replace("*", " const*");
}
let require_const = || {
if is_const {
return Ok(());
if etype == "IMM" {
// The _bittest intrinsics claim to only accept immediates but actually
// accept run-time values as well.
if !is_const && !intrinsic.starts_with("_bittest") {
return bail!("argument required to be const but isn't");
}
Err(format!("argument required to be const but isn't"))
};
} else {
// const int must be an IMM
assert_ne!(intel, "const int");
if is_const {
return bail!("argument is const but shouldn't be");
}
}
match (t, &intel[..]) {
(&Type::PrimFloat(32), "float") => {}
(&Type::PrimFloat(64), "double") => {}
(&Type::PrimSigned(16), "__int16") => {}
(&Type::PrimSigned(16), "short") => {}
(&Type::PrimSigned(32), "__int32") => {}
(&Type::PrimSigned(32), "const int") => require_const()?,
(&Type::PrimSigned(32), "const int") => {}
(&Type::PrimSigned(32), "int") => {}
(&Type::PrimSigned(64), "__int64") => {}
(&Type::PrimSigned(64), "long long") => {}