Rollup merge of #55073 - alexcrichton:demote-simd, r=nagisa

The issue of passing around SIMD types as values between functions has
seen [quite a lot] of [discussion], and although we thought [we fixed
it][quite a lot] it [wasn't]! This PR is a change to rustc to, again,
try to fix this issue.

The fundamental problem here remains the same, if a SIMD vector argument
is passed by-value in LLVM's function type, then if the caller and
callee disagree on target features a miscompile happens. We solve this
by never passing SIMD vectors by-value, but LLVM will still thwart us
with its argument promotion pass to promote by-ref SIMD arguments to
by-val SIMD arguments.

This commit is an attempt to thwart LLVM thwarting us. We, just before
codegen, will take yet another look at the LLVM module and demote any
by-value SIMD arguments we see. This is a very manual attempt by us to
ensure the codegen for a module keeps working, and it unfortunately is
likely producing suboptimal code, even in release mode. The saving grace
for this, in theory, is that if SIMD types are passed by-value across
a boundary in release mode it's pretty unlikely to be performance
sensitive (as it's already doing a load/store, and otherwise
perf-sensitive bits should be inlined).

The implementation here is basically a big wad of C++. It was largely
copied from LLVM's own argument promotion pass, only doing the reverse.
In local testing this...

Closes #50154
Closes #52636
Closes #54583
Closes #55059

[quite a lot]: https://github.com/rust-lang/rust/pull/47743
[discussion]: https://github.com/rust-lang/rust/issues/44367
[wasn't]: https://github.com/rust-lang/rust/issues/50154
This commit is contained in:
Manish Goregaokar 2018-10-20 13:15:39 -07:00
commit b860765355
9 changed files with 332 additions and 9 deletions

View file

@ -0,0 +1,13 @@
-include ../../run-make-fulldeps/tools.mk
ifeq ($(TARGET),x86_64-unknown-linux-gnu)
all:
$(RUSTC) t1.rs -C opt-level=3
$(TMPDIR)/t1
$(RUSTC) t2.rs -C opt-level=3
$(TMPDIR)/t2
$(RUSTC) t3.rs -C opt-level=3
$(TMPDIR)/t3
else
all:
endif

View file

@ -0,0 +1,21 @@
use std::arch::x86_64;
fn main() {
if !is_x86_feature_detected!("avx2") {
return println!("AVX2 is not supported on this machine/build.");
}
let load_bytes: [u8; 32] = [0x0f; 32];
let lb_ptr = load_bytes.as_ptr();
let reg_load = unsafe {
x86_64::_mm256_loadu_si256(
lb_ptr as *const x86_64::__m256i
)
};
println!("{:?}", reg_load);
let mut store_bytes: [u8; 32] = [0; 32];
let sb_ptr = store_bytes.as_mut_ptr();
unsafe {
x86_64::_mm256_storeu_si256(sb_ptr as *mut x86_64::__m256i, reg_load);
}
assert_eq!(load_bytes, store_bytes);
}

View file

@ -0,0 +1,14 @@
use std::arch::x86_64::*;
fn main() {
if !is_x86_feature_detected!("avx") {
return println!("AVX is not supported on this machine/build.");
}
unsafe {
let f = _mm256_set_pd(2.0, 2.0, 2.0, 2.0);
let r = _mm256_mul_pd(f, f);
union A { a: __m256d, b: [f64; 4] }
assert_eq!(A { a: r }.b, [4.0, 4.0, 4.0, 4.0]);
}
}

View file

@ -0,0 +1,52 @@
use std::arch::x86_64::*;
#[target_feature(enable = "avx")]
unsafe fn avx_mul(a: __m256, b: __m256) -> __m256 {
_mm256_mul_ps(a, b)
}
#[target_feature(enable = "avx")]
unsafe fn avx_store(p: *mut f32, a: __m256) {
_mm256_storeu_ps(p, a)
}
#[target_feature(enable = "avx")]
unsafe fn avx_setr(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32, h: f32) -> __m256 {
_mm256_setr_ps(a, b, c, d, e, f, g, h)
}
#[target_feature(enable = "avx")]
unsafe fn avx_set1(a: f32) -> __m256 {
_mm256_set1_ps(a)
}
struct Avx(__m256);
fn mul(a: Avx, b: Avx) -> Avx {
unsafe { Avx(avx_mul(a.0, b.0)) }
}
fn set1(a: f32) -> Avx {
unsafe { Avx(avx_set1(a)) }
}
fn setr(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32, h: f32) -> Avx {
unsafe { Avx(avx_setr(a, b, c, d, e, f, g, h)) }
}
unsafe fn store(p: *mut f32, a: Avx) {
avx_store(p, a.0);
}
fn main() {
if !is_x86_feature_detected!("avx") {
return println!("AVX is not supported on this machine/build.");
}
let mut result = [0.0f32; 8];
let a = mul(setr(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0), set1(0.25));
unsafe {
store(result.as_mut_ptr(), a);
}
assert_eq!(result, [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.50, 1.75]);
}