Rollup merge of #151640 - ZuseZ4:cleanup-datatransfer, r=nnethercote

Cleanup offload datatransfer

There are three steps to run code on a GPU: copy the data from the host to the device, launch the kernel, and copy the results back to the host.
At the moment, a single variable describes the memory handling for all three steps, which makes it hard for LLVM's opt pass to understand what is going on. We therefore split it into three variables, each containing only the bits relevant to the corresponding stage.

cc @jdoerfert @kevinsala

r? compiler
This commit is contained in:
Jonathan Brouwer 2026-02-08 19:15:26 +01:00 committed by GitHub
commit 16c7ee5c05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 92 additions and 53 deletions

View file

@@ -1,12 +1,13 @@
use std::ffi::CString;
use bitflags::Flags;
use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::bug;
use rustc_middle::ty::offload_meta::OffloadMetadata;
use rustc_middle::ty::offload_meta::{MappingFlags, OffloadMetadata};
use crate::builder::Builder;
use crate::common::CodegenCx;
@@ -28,10 +29,6 @@ pub(crate) struct OffloadGlobals<'ll> {
pub mapper_fn_ty: &'ll llvm::Type,
pub ident_t_global: &'ll llvm::Value,
// FIXME(offload): Drop this, once we fully automated our offload compilation pipeline, since
// LLVM will initialize them for us if it sees gpu kernels being registered.
pub init_rtls: &'ll llvm::Value,
}
impl<'ll> OffloadGlobals<'ll> {
@@ -42,9 +39,6 @@ impl<'ll> OffloadGlobals<'ll> {
let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
let ident_t_global = generate_at_one(cx);
let init_ty = cx.type_func(&[], cx.type_void());
let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
// We want LLVM's openmp-opt pass to pick up and optimize this module, since it covers both
// openmp and offload optimizations.
llvm::add_module_flag_u32(cx.llmod(), llvm::ModuleFlagMergeBehavior::Max, "openmp", 51);
@@ -58,7 +52,6 @@ impl<'ll> OffloadGlobals<'ll> {
end_mapper,
mapper_fn_ty,
ident_t_global,
init_rtls,
}
}
}
@@ -91,6 +84,11 @@ pub(crate) fn register_offload<'ll>(cx: &CodegenCx<'ll, '_>) {
let atexit = cx.type_func(&[cx.type_ptr()], cx.type_i32());
let atexit_fn = declare_offload_fn(cx, "atexit", atexit);
// FIXME(offload): Drop this, once we fully automated our offload compilation pipeline, since
// LLVM will initialize them for us if it sees gpu kernels being registered.
let init_ty = cx.type_func(&[], cx.type_void());
let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
let desc_ty = cx.type_func(&[], cx.type_void());
let reg_name = ".omp_offloading.descriptor_reg";
let unreg_name = ".omp_offloading.descriptor_unreg";
@@ -104,12 +102,14 @@ pub(crate) fn register_offload<'ll>(cx: &CodegenCx<'ll, '_>) {
// define internal void @.omp_offloading.descriptor_reg() section ".text.startup" {
// entry:
// call void @__tgt_register_lib(ptr @.omp_offloading.descriptor)
// call void @__tgt_init_all_rtls()
// %0 = call i32 @atexit(ptr @.omp_offloading.descriptor_unreg)
// ret void
// }
let bb = Builder::append_block(cx, desc_reg_fn, "entry");
let mut a = Builder::build(cx, bb);
a.call(reg_lib_decl, None, None, register_lib, &[omp_descriptor], None, None);
a.call(init_ty, None, None, init_rtls, &[], None, None);
a.call(atexit, None, None, atexit_fn, &[desc_unreg_fn], None, None);
a.ret_void();
@@ -345,7 +345,9 @@ impl KernelArgsTy {
#[derive(Copy, Clone)]
pub(crate) struct OffloadKernelGlobals<'ll> {
pub offload_sizes: &'ll llvm::Value,
pub memtransfer_types: &'ll llvm::Value,
pub memtransfer_begin: &'ll llvm::Value,
pub memtransfer_kernel: &'ll llvm::Value,
pub memtransfer_end: &'ll llvm::Value,
pub region_id: &'ll llvm::Value,
}
@@ -423,18 +425,38 @@ pub(crate) fn gen_define_handling<'ll>(
let offload_entry_ty = offload_globals.offload_entry_ty;
// FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
let (sizes, transfer): (Vec<_>, Vec<_>) =
metadata.iter().map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();
metadata.iter().map(|m| (m.payload_size, m.mode)).unzip();
// Our begin mapper should only see simplified information about which args have to be
// transferred to the device, the end mapper only about which args should be transferred back.
// Any information beyond that makes it harder for LLVM's opt pass to evaluate whether it can
// safely move (=optimize) the LLVM-IR location of this data transfer. Only the mapping types
// mentioned below are handled, so make sure that we don't generate any other ones.
let handled_mappings = MappingFlags::TO
| MappingFlags::FROM
| MappingFlags::TARGET_PARAM
| MappingFlags::LITERAL
| MappingFlags::IMPLICIT;
for arg in &transfer {
debug_assert!(!arg.contains_unknown_bits());
debug_assert!(handled_mappings.contains(*arg));
}
let valid_begin_mappings = MappingFlags::TO | MappingFlags::LITERAL | MappingFlags::IMPLICIT;
let transfer_to: Vec<u64> =
transfer.iter().map(|m| m.intersection(valid_begin_mappings).bits()).collect();
let transfer_from: Vec<u64> =
transfer.iter().map(|m| m.intersection(MappingFlags::FROM).bits()).collect();
// FIXME(offload): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
let transfer_kernel = vec![MappingFlags::TARGET_PARAM.bits(); transfer_to.len()];
let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &sizes);
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
// will be 2. For now, everything is 3, until we have our frontend set up.
// 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
let memtransfer_types =
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &transfer);
let memtransfer_begin =
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.begin"), &transfer_to);
let memtransfer_kernel =
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.kernel"), &transfer_kernel);
let memtransfer_end =
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.end"), &transfer_from);
// Next: For each function, generate these three entries. A weak constant,
// the llvm.rodata entry name, and the llvm_offload_entries value
@@ -469,7 +491,13 @@ pub(crate) fn gen_define_handling<'ll>(
cx.add_compiler_used_global(offload_entry);
let result = OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id };
let result = OffloadKernelGlobals {
offload_sizes,
memtransfer_begin,
memtransfer_kernel,
memtransfer_end,
region_id,
};
// FIXME(Sa4dUs): use this global for constant offload sizes
cx.add_compiler_used_global(result.offload_sizes);
@@ -535,7 +563,13 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
offload_dims: &OffloadKernelDims<'ll>,
) {
let cx = builder.cx;
let OffloadKernelGlobals { memtransfer_types, region_id, .. } = offload_data;
let OffloadKernelGlobals {
memtransfer_begin,
memtransfer_kernel,
memtransfer_end,
region_id,
..
} = offload_data;
let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
offload_dims;
@@ -608,12 +642,6 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
geps.push(gep);
}
let init_ty = cx.type_func(&[], cx.type_void());
let init_rtls_decl = offload_globals.init_rtls;
// call void @__tgt_init_all_rtls()
builder.call(init_ty, None, None, init_rtls_decl, &[], None, None);
for i in 0..num_args {
let idx = cx.get_const_i32(i);
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
@@ -668,14 +696,14 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
generate_mapper_call(
builder,
geps,
memtransfer_types,
memtransfer_begin,
begin_mapper_decl,
fn_ty,
num_args,
s_ident_t,
);
let values =
KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);
KernelArgsTy::new(&cx, num_args, memtransfer_kernel, geps, workgroup_dims, thread_dims);
// Step 3)
// Here we fill the KernelArgsTy, see the documentation above
@@ -701,7 +729,7 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
generate_mapper_call(
builder,
geps,
memtransfer_types,
memtransfer_end,
end_mapper_decl,
fn_ty,
num_args,

View file

@@ -19,9 +19,9 @@
// CHECK: br label %bb3
// CHECK-NOT: define
// CHECK: bb3
// CHECK: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo, ptr null, ptr null)
// CHECK: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo.begin, ptr null, ptr null)
// CHECK: %10 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @.foo.region_id, ptr nonnull %kernel_args)
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo, ptr null, ptr null)
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo.end, ptr null, ptr null)
#[unsafe(no_mangle)]
unsafe fn main() {
let A = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0];

View file

@@ -14,19 +14,20 @@
#[unsafe(no_mangle)]
fn main() {
let mut x = [3.0; 256];
kernel_1(&mut x);
let y = [1.0; 256];
kernel_1(&mut x, &y);
core::hint::black_box(&x);
core::hint::black_box(&y);
}
pub fn kernel_1(x: &mut [f32; 256]) {
core::intrinsics::offload(kernel_1, [256, 1, 1], [32, 1, 1], (x,))
pub fn kernel_1(x: &mut [f32; 256], y: &[f32; 256]) {
core::intrinsics::offload(_kernel_1, [256, 1, 1], [32, 1, 1], (x, y))
}
#[unsafe(no_mangle)]
#[inline(never)]
pub fn _kernel_1(x: &mut [f32; 256]) {
pub fn _kernel_1(x: &mut [f32; 256], y: &[f32; 256]) {
for i in 0..256 {
x[i] = 21.0;
x[i] = 21.0 + y[i];
}
}
@@ -39,8 +40,10 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
// CHECK-DAG: @.omp_offloading.descriptor = internal constant { i32, ptr, ptr, ptr } zeroinitializer
// CHECK-DAG: @llvm.global_ctors = appending constant [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 101, ptr @.omp_offloading.descriptor_reg, ptr null }]
// CHECK-DAG: @.offload_sizes.[[K:[^ ]*kernel_1]] = private unnamed_addr constant [1 x i64] [i64 1024]
// CHECK-DAG: @.offload_maptypes.[[K]] = private unnamed_addr constant [1 x i64] [i64 35]
// CHECK-DAG: @.offload_sizes.[[K:[^ ]*kernel_1]] = private unnamed_addr constant [2 x i64] [i64 1024, i64 1024]
// CHECK-DAG: @.offload_maptypes.[[K]].begin = private unnamed_addr constant [2 x i64] [i64 1, i64 1]
// CHECK-DAG: @.offload_maptypes.[[K]].kernel = private unnamed_addr constant [2 x i64] [i64 32, i64 32]
// CHECK-DAG: @.offload_maptypes.[[K]].end = private unnamed_addr constant [2 x i64] [i64 2, i64 0]
// CHECK-DAG: @.[[K]].region_id = internal constant i8 0
// CHECK-DAG: @.offloading.entry_name.[[K]] = internal unnamed_addr constant [{{[0-9]+}} x i8] c"[[K]]{{\\00}}", section ".llvm.rodata.offloading", align 1
// CHECK-DAG: @.offloading.entry.[[K]] = internal constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.[[K]].region_id, ptr @.offloading.entry_name.[[K]], i64 0, i64 0, ptr null }, section "llvm_offload_entries", align 8
@@ -49,20 +52,27 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
// CHECK-LABEL: define{{( dso_local)?}} void @main()
// CHECK-NEXT: start:
// CHECK-NEXT: %0 = alloca [8 x i8], align 8
// CHECK-NEXT: %x = alloca [1024 x i8], align 16
// CHECK-NEXT: %.offload_baseptrs = alloca [1 x ptr], align 8
// CHECK-NEXT: %.offload_ptrs = alloca [1 x ptr], align 8
// CHECK-NEXT: %.offload_sizes = alloca [1 x i64], align 8
// CHECK-NEXT: %0 = alloca [8 x i8], align 8
// CHECK-NEXT: %1 = alloca [8 x i8], align 8
// CHECK-NEXT: %y = alloca [1024 x i8], align 16
// CHECK-NEXT: %x = alloca [1024 x i8], align 16
// CHECK-NEXT: %.offload_baseptrs = alloca [2 x ptr], align 8
// CHECK-NEXT: %.offload_ptrs = alloca [2 x ptr], align 8
// CHECK-NEXT: %.offload_sizes = alloca [2 x i64], align 8
// CHECK-NEXT: %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
// CHECK: call void @__tgt_init_all_rtls()
// CHECK-NEXT: store ptr %x, ptr %.offload_baseptrs, align 8
// CHECK: store ptr %x, ptr %.offload_baseptrs, align 8
// CHECK-NEXT: store ptr %x, ptr %.offload_ptrs, align 8
// CHECK-NEXT: store i64 1024, ptr %.offload_sizes, align 8
// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.[[K]], ptr null, ptr null)
// CHECK-NEXT: [[BPTRS_1:%.*]] = getelementptr inbounds nuw i8, ptr %.offload_baseptrs, i64 8
// CHECK-NEXT: store ptr %y, ptr [[BPTRS_1]], align 8
// CHECK-NEXT: [[PTRS_1:%.*]] = getelementptr inbounds nuw i8, ptr %.offload_ptrs, i64 8
// CHECK-NEXT: store ptr %y, ptr [[PTRS_1]], align 8
// CHECK-NEXT: [[SIZES_1:%.*]] = getelementptr inbounds nuw i8, ptr %.offload_sizes, i64 8
// CHECK-NEXT: store i64 1024, ptr [[SIZES_1]], align 8
// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 2, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.[[K]].begin, ptr null, ptr null)
// CHECK-NEXT: store i32 3, ptr %kernel_args, align 8
// CHECK-NEXT: [[P4:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 4
// CHECK-NEXT: store i32 1, ptr [[P4]], align 4
// CHECK-NEXT: store i32 2, ptr [[P4]], align 4
// CHECK-NEXT: [[P8:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 8
// CHECK-NEXT: store ptr %.offload_baseptrs, ptr [[P8]], align 8
// CHECK-NEXT: [[P16:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 16
@@ -70,7 +80,7 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
// CHECK-NEXT: [[P24:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 24
// CHECK-NEXT: store ptr %.offload_sizes, ptr [[P24]], align 8
// CHECK-NEXT: [[P32:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 32
// CHECK-NEXT: store ptr @.offload_maptypes.[[K]], ptr [[P32]], align 8
// CHECK-NEXT: store ptr @.offload_maptypes.[[K]].kernel, ptr [[P32]], align 8
// CHECK-NEXT: [[P40:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 40
// CHECK-NEXT: [[P72:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 72
// CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(32) [[P40]], i8 0, i64 32, i1 false)
@@ -81,9 +91,9 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
// CHECK-NEXT: store i32 1, ptr [[P92]], align 4
// CHECK-NEXT: [[P96:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 96
// CHECK-NEXT: store i32 0, ptr [[P96]], align 8
// CHECK-NEXT: {{%[^ ]+}} = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @.[[K]].region_id, ptr nonnull %kernel_args)
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.[[K]], ptr null, ptr null)
// CHECK: ret void
// CHECK-NEXT: [[TGT_RET:%.*]] = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @.[[K]].region_id, ptr nonnull %kernel_args)
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 2, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.[[K]].end, ptr null, ptr null)
// CHECK: ret void
// CHECK-NEXT: }
// CHECK: declare void @__tgt_register_lib(ptr) local_unnamed_addr
@@ -92,6 +102,7 @@ pub fn _kernel_1(x: &mut [f32; 256]) {
// CHECK-LABEL: define internal void @.omp_offloading.descriptor_reg() section ".text.startup" {
// CHECK-NEXT: entry:
// CHECK-NEXT: call void @__tgt_register_lib(ptr nonnull @.omp_offloading.descriptor)
// CHECK-NEXT: call void @__tgt_init_all_rtls()
// CHECK-NEXT: %0 = {{tail }}call i32 @atexit(ptr nonnull @.omp_offloading.descriptor_unreg)
// CHECK-NEXT: ret void
// CHECK-NEXT: }