Expose workgroup/thread dims as intrinsic args

This commit is contained in:
Marcelo Domínguez 2025-12-27 21:50:40 +01:00
parent 1b4325211c
commit 58e2610f71
7 changed files with 108 additions and 25 deletions

View file

@ -2,7 +2,9 @@ use std::ffi::CString;
use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::bug;
use rustc_middle::ty::offload_meta::OffloadMetadata;
use crate::builder::Builder;
@ -69,6 +71,57 @@ impl<'ll> OffloadGlobals<'ll> {
}
}
pub(crate) struct OffloadKernelDims<'ll> {
num_workgroups: &'ll Value,
threads_per_block: &'ll Value,
workgroup_dims: &'ll Value,
thread_dims: &'ll Value,
}
impl<'ll> OffloadKernelDims<'ll> {
pub(crate) fn from_operands<'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
workgroup_op: &OperandRef<'tcx, &'ll llvm::Value>,
thread_op: &OperandRef<'tcx, &'ll llvm::Value>,
) -> Self {
let cx = builder.cx;
let arr_ty = cx.type_array(cx.type_i32(), 3);
let four = Align::from_bytes(4).unwrap();
let OperandValue::Ref(place) = workgroup_op.val else {
bug!("expected array operand by reference");
};
let workgroup_val = builder.load(arr_ty, place.llval, four);
let OperandValue::Ref(place) = thread_op.val else {
bug!("expected array operand by reference");
};
let thread_val = builder.load(arr_ty, place.llval, four);
fn mul_dim3<'ll, 'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
arr: &'ll Value,
) -> &'ll Value {
let x = builder.extract_value(arr, 0);
let y = builder.extract_value(arr, 1);
let z = builder.extract_value(arr, 2);
let xy = builder.mul(x, y);
builder.mul(xy, z)
}
let num_workgroups = mul_dim3(builder, workgroup_val);
let threads_per_block = mul_dim3(builder, thread_val);
OffloadKernelDims {
workgroup_dims: workgroup_val,
thread_dims: thread_val,
num_workgroups,
threads_per_block,
}
}
}
// ; Function Attrs: nounwind
// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
@ -204,12 +257,12 @@ impl KernelArgsTy {
num_args: u64,
memtransfer_types: &'ll Value,
geps: [&'ll Value; 3],
workgroup_dims: &'ll Value,
thread_dims: &'ll Value,
) -> [(Align, &'ll Value); 13] {
let four = Align::from_bytes(4).expect("4 Byte alignment should work");
let eight = Align::EIGHT;
let ti32 = cx.type_i32();
let ci32_0 = cx.get_const_i32(0);
[
(four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
(four, cx.get_const_i32(num_args)),
@ -222,8 +275,8 @@ impl KernelArgsTy {
(eight, cx.const_null(cx.type_ptr())), // dbg
(eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
(eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
(four, cx.const_array(ti32, &[cx.get_const_i32(2097152), ci32_0, ci32_0])),
(four, cx.const_array(ti32, &[cx.get_const_i32(256), ci32_0, ci32_0])),
(four, workgroup_dims),
(four, thread_dims),
(four, cx.get_const_i32(0)),
]
}
@ -413,10 +466,13 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
types: &[&Type],
metadata: &[OffloadMetadata],
offload_globals: &OffloadGlobals<'ll>,
offload_dims: &OffloadKernelDims<'ll>,
) {
let cx = builder.cx;
let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
offload_data;
let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
offload_dims;
let tgt_decl = offload_globals.launcher_fn;
let tgt_target_kernel_ty = offload_globals.launcher_ty;
@ -554,7 +610,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
num_args,
s_ident_t,
);
let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps);
let values =
KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);
// Step 3)
// Here we fill the KernelArgsTy, see the documentation above
@ -567,9 +624,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
s_ident_t,
// FIXME(offload) give users a way to select which GPU to use.
cx.get_const_i64(u64::MAX), // MAX == -1.
// FIXME(offload): Don't hardcode the numbers of threads in the future.
cx.get_const_i32(2097152),
cx.get_const_i32(256),
num_workgroups,
threads_per_block,
region_id,
a5,
];

View file

@ -30,7 +30,7 @@ use tracing::debug;
use crate::abi::FnAbiLlvmExt;
use crate::builder::Builder;
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
use crate::builder::gpu_offload::{gen_call_handling, gen_define_handling};
use crate::builder::gpu_offload::{OffloadKernelDims, gen_call_handling, gen_define_handling};
use crate::context::CodegenCx;
use crate::declare::declare_raw_fn;
use crate::errors::{
@ -1384,7 +1384,8 @@ fn codegen_offload<'ll, 'tcx>(
}
};
let args = get_args_from_tuple(bx, args[1], fn_target);
let offload_dims = OffloadKernelDims::from_operands(bx, &args[1], &args[2]);
let args = get_args_from_tuple(bx, args[3], fn_target);
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE);
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
@ -1403,7 +1404,7 @@ fn codegen_offload<'ll, 'tcx>(
}
};
let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals);
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals);
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims);
}
fn get_args_from_tuple<'ll, 'tcx>(

View file

@ -4,7 +4,7 @@ use rustc_abi::ExternAbi;
use rustc_errors::DiagMessage;
use rustc_hir::{self as hir, LangItem};
use rustc_middle::traits::{ObligationCause, ObligationCauseCode};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::ty::{self, Const, Ty, TyCtxt};
use rustc_span::def_id::LocalDefId;
use rustc_span::{Span, Symbol, sym};
@ -315,7 +315,17 @@ pub(crate) fn check_intrinsic_type(
let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity();
(0, 0, vec![type_id, type_id], tcx.types.bool)
}
sym::offload => (3, 0, vec![param(0), param(1)], param(2)),
sym::offload => (
3,
0,
vec![
param(0),
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
param(1),
],
param(2),
),
sym::offset => (2, 0, vec![param(0), param(1)], param(0)),
sym::arith_offset => (
1,

View file

@ -3385,11 +3385,17 @@ pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) ->
/// - `T`: A tuple of arguments passed to `f`.
/// - `R`: The return type of the kernel.
///
/// Arguments:
/// - `f`: The kernel function to offload.
/// - `workgroup_dim`: A 3D size specifying the number of workgroups to launch.
/// - `thread_dim`: A 3D size specifying the number of threads per workgroup.
/// - `args`: A tuple of arguments forwarded to `f`.
///
/// Example usage (pseudocode):
///
/// ```rust,ignore (pseudocode)
/// fn kernel(x: *mut [f64; 128]) {
/// core::intrinsics::offload(kernel_1, (x,))
/// core::intrinsics::offload(kernel_1, [256, 1, 1], [32, 1, 1], (x,))
/// }
///
/// #[cfg(target_os = "linux")]
@ -3408,7 +3414,12 @@ pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) ->
/// <https://clang.llvm.org/docs/OffloadingDesign.html>.
#[rustc_nounwind]
#[rustc_intrinsic]
pub const fn offload<F, T: crate::marker::Tuple, R>(f: F, args: T) -> R;
pub const fn offload<F, T: crate::marker::Tuple, R>(
f: F,
workgroup_dim: [u32; 3],
thread_dim: [u32; 3],
args: T,
) -> R;
/// Inform Miri that a given pointer definitely has a certain alignment.
#[cfg(miri)]

View file

@ -57,7 +57,7 @@ fn main() {
#[inline(never)]
unsafe fn kernel(x: *mut [f64; 256]) {
core::intrinsics::offload(kernel_1, (x,))
core::intrinsics::offload(_kernel_1, [256, 1, 1], [32, 1, 1], (x,))
}
#[cfg(target_os = "linux")]

View file

@ -21,14 +21,19 @@
// CHECK-NOT define
// CHECK: bb3
// CHECK: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo, ptr null, ptr null)
// CHECK: %10 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 2097152, i32 256, ptr nonnull @.foo.region_id, ptr nonnull %kernel_args)
// CHECK: %10 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @.foo.region_id, ptr nonnull %kernel_args)
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo, ptr null, ptr null)
#[unsafe(no_mangle)]
unsafe fn main() {
let A = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
for i in 0..100 {
core::intrinsics::offload::<_, _, ()>(foo, (A.as_ptr() as *const [f32; 6],));
core::intrinsics::offload::<_, _, ()>(
foo,
[256, 1, 1],
[32, 1, 1],
(A.as_ptr() as *const [f32; 6],),
);
}
}

View file

@ -82,14 +82,14 @@ fn main() {
// CHECK-NEXT: %5 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 40
// CHECK-NEXT: %6 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 72
// CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(32) %5, i8 0, i64 32, i1 false)
// CHECK-NEXT: store <4 x i32> <i32 2097152, i32 0, i32 0, i32 256>, ptr %6, align 8
// CHECK-NEXT: %.fca.1.gep3 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 88
// CHECK-NEXT: store i32 0, ptr %.fca.1.gep3, align 8
// CHECK-NEXT: %.fca.2.gep4 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 92
// CHECK-NEXT: store i32 0, ptr %.fca.2.gep4, align 4
// CHECK-NEXT: store <4 x i32> <i32 256, i32 1, i32 1, i32 32>, ptr %6, align 8
// CHECK-NEXT: %.fca.1.gep5 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 88
// CHECK-NEXT: store i32 1, ptr %.fca.1.gep5, align 8
// CHECK-NEXT: %.fca.2.gep7 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 92
// CHECK-NEXT: store i32 1, ptr %.fca.2.gep7, align 4
// CHECK-NEXT: %7 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 96
// CHECK-NEXT: store i32 0, ptr %7, align 8
// CHECK-NEXT: %8 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 2097152, i32 256, ptr nonnull @._kernel_1.region_id, ptr nonnull %kernel_args)
// CHECK-NEXT: %8 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @._kernel_1.region_id, ptr nonnull %kernel_args)
// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes._kernel_1, ptr null, ptr null)
// CHECK-NEXT: call void @__tgt_unregister_lib(ptr nonnull %EmptyDesc)
// CHECK-NEXT: ret void
@ -98,7 +98,7 @@ fn main() {
#[unsafe(no_mangle)]
#[inline(never)]
pub fn kernel_1(x: &mut [f32; 256]) {
core::intrinsics::offload(_kernel_1, (x,))
core::intrinsics::offload(_kernel_1, [256, 1, 1], [32, 1, 1], (x,))
}
#[unsafe(no_mangle)]