From 307a4fcdf803f7f1032bd317d8a34413d2d1e2c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 22 Dec 2025 21:03:02 +0100 Subject: [PATCH] Add scalar support for both host and device --- compiler/rustc_codegen_llvm/src/back/write.rs | 40 +++++++++++- compiler/rustc_codegen_llvm/src/builder.rs | 15 +++++ .../src/builder/gpu_offload.rs | 62 ++++++++++++++----- compiler/rustc_codegen_llvm/src/intrinsic.rs | 4 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 12 +++- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 21 ++++++- compiler/rustc_middle/src/ty/offload_meta.rs | 17 +++-- .../codegen-llvm/gpu_offload/scalar_device.rs | 36 +++++++++++ tests/codegen-llvm/gpu_offload/scalar_host.rs | 37 +++++++++++ 9 files changed, 210 insertions(+), 34 deletions(-) create mode 100644 tests/codegen-llvm/gpu_offload/scalar_device.rs create mode 100644 tests/codegen-llvm/gpu_offload/scalar_host.rs diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index bcadb6f0de92..09d1ca1a5952 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -13,6 +13,7 @@ use rustc_codegen_ssa::back::write::{ TargetMachineFactoryConfig, TargetMachineFactoryFn, }; use rustc_codegen_ssa::base::wants_wasm_eh; +use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CompiledModule, ModuleCodegen, ModuleKind}; use rustc_data_structures::profiling::SelfProfilerRef; @@ -33,6 +34,8 @@ use crate::back::owned_target_machine::OwnedTargetMachine; use crate::back::profiling::{ LlvmSelfProfiler, selfprofile_after_pass_callback, selfprofile_before_pass_callback, }; +use crate::builder::SBuilder; +use crate::builder::gpu_offload::scalar_width; use crate::common::AsCCharPtr; use crate::errors::{ CopyBitcode, FromLlvmDiag, FromLlvmOptimizationDiag, LlvmError, UnknownCompression, @@ -669,7 +672,17 @@ pub(crate) unsafe fn llvm_optimize( // 
Create the new parameter list, with ptr as the first argument let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1); new_param_types.push(cx.type_ptr()); - new_param_types.extend(old_param_types); + + // This relies on undocumented LLVM knowledge that scalars must be passed as i64 + for &old_ty in &old_param_types { + let new_ty = match cx.type_kind(old_ty) { + TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => { + cx.type_i64() + } + _ => old_ty, + }; + new_param_types.push(new_ty); + } // Create the new function type let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) }; @@ -682,10 +695,33 @@ pub(crate) unsafe fn llvm_optimize( let a0 = llvm::get_param(new_fn, 0); llvm::set_value_name(a0, CString::new("dyn_ptr").unwrap().as_bytes()); + let bb = SBuilder::append_block(cx, new_fn, "entry"); + let mut builder = SBuilder::build(cx, bb); + + let mut old_args_rebuilt = Vec::with_capacity(old_param_types.len()); + + for (i, &old_ty) in old_param_types.iter().enumerate() { + let new_arg = llvm::get_param(new_fn, (i + 1) as u32); + + let rebuilt = match cx.type_kind(old_ty) { + TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => { + let num_bits = scalar_width(cx, old_ty); + + let trunc = builder.trunc(new_arg, cx.type_ix(num_bits)); + builder.bitcast(trunc, old_ty) + } + _ => new_arg, + }; + + old_args_rebuilt.push(rebuilt); + } + + builder.ret_void(); + // Here we map the old arguments to the new arguments, with an offset of 1 to make sure // that we don't use the newly added `%dyn_ptr`. 
unsafe { - llvm::LLVMRustOffloadMapper(old_fn, new_fn); + llvm::LLVMRustOffloadMapper(old_fn, new_fn, old_args_rebuilt.as_ptr()); } llvm::set_linkage(new_fn, llvm::get_linkage(old_fn)); diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 7a49ba64029e..557ae7b0333e 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -97,6 +97,21 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> { GenericBuilder { llbuilder, cx: scx } } + pub(crate) fn append_block( + cx: &'a GenericCx<'ll, CX>, + llfn: &'ll Value, + name: &str, + ) -> &'ll BasicBlock { + unsafe { + let name = SmallCStr::new(name); + llvm::LLVMAppendBasicBlockInContext(cx.llcx(), llfn, name.as_ptr()) + } + } + + pub(crate) fn trunc(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value { + unsafe { llvm::LLVMBuildTrunc(self.llbuilder, val, dest_ty, UNNAMED) } + } + pub(crate) fn bitcast(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value { unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) } } diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 7817755dafe4..c591c785cae3 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -2,6 +2,7 @@ use std::ffi::CString; use llvm::Linkage::*; use rustc_abi::Align; +use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_middle::bug; @@ -357,7 +358,6 @@ pub(crate) fn add_global<'ll>( pub(crate) fn gen_define_handling<'ll>( cx: &CodegenCx<'ll, '_>, metadata: &[OffloadMetadata], - types: &[&'ll Type], symbol: String, offload_globals: &OffloadGlobals<'ll>, ) -> OffloadKernelGlobals<'ll> { @@ -367,25 +367,18 @@ pub(crate) fn
gen_define_handling<'ll>( let offload_entry_ty = offload_globals.offload_entry_ty; - // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or - // reference) types. - let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) { - rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta), - _ => None, - }); - // FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary - let (ptr_sizes, ptr_transfer): (Vec<_>, Vec<_>) = - ptr_meta.map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip(); + let (sizes, transfer): (Vec<_>, Vec<_>) = + metadata.iter().map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip(); - let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes); + let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &sizes); // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2), // or both to and from the gpu (=3). Other values shouldn't affect us for now. // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten // will be 2. For now, everything is 3, until we have our frontend set up. // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later). let memtransfer_types = - add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &ptr_transfer); + add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &transfer); // Next: For each function, generate these three entries. 
A weak constant, // the llvm.rodata entry name, and the llvm_offload_entries value @@ -441,11 +434,24 @@ fn declare_offload_fn<'ll>( ) } +pub(crate) fn scalar_width<'ll>(cx: &'ll SimpleCx<'_>, ty: &'ll Type) -> u64 { + match cx.type_kind(ty) { + TypeKind::Half + | TypeKind::Float + | TypeKind::Double + | TypeKind::X86_FP80 + | TypeKind::FP128 + | TypeKind::PPC_FP128 => cx.float_width(ty) as u64, + TypeKind::Integer => cx.int_width(ty), + other => bug!("scalar_width was called on a non scalar type {other:?}"), + } +} + // For each kernel *call*, we now use some of our previous declared globals to move data to and from // the gpu. For now, we only handle the data transfer part of it. // If two consecutive kernels use the same memory, we still move it to the host and back to the gpu. // Since in our frontend users (by default) don't have to specify data transfer, this is something -// we should optimize in the future! In some cases we can directly zero-allocate ont he device and +// we should optimize in the future! In some cases we can directly zero-allocate on the device and // only move data back, or if something is immutable, we might only copy it to the device. 
// // Current steps: @@ -533,8 +539,34 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( let mut geps = vec![]; let i32_0 = cx.get_const_i32(0); for &v in args { - let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]); - vals.push(v); + let ty = cx.val_ty(v); + let ty_kind = cx.type_kind(ty); + let (base_val, gep_base) = match ty_kind { + TypeKind::Pointer => (v, v), + TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => { + // FIXME(Sa4dUs): check for `f128` support, latest NVIDIA cards support it + let num_bits = scalar_width(cx, ty); + + let bb = builder.llbb(); + unsafe { + llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, builder.llfn()); + } + let addr = builder.direct_alloca(cx.type_i64(), Align::EIGHT, "addr"); + unsafe { + llvm::LLVMPositionBuilderAtEnd(builder.llbuilder, bb); + } + + let cast = builder.bitcast(v, cx.type_ix(num_bits)); + let value = builder.zext(cast, cx.type_i64()); + builder.store(value, addr, Align::EIGHT); + (value, addr) + } + other => bug!("offload does not support {other:?}"), + }; + + let gep = builder.inbounds_gep(cx.type_f32(), gep_base, &[i32_0]); + + vals.push(base_val); geps.push(gep); } diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 59cbcd78dd0f..c2975df4b6a0 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1388,7 +1388,7 @@ fn codegen_offload<'ll, 'tcx>( let args = get_args_from_tuple(bx, args[3], fn_target); let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE); - let sig = tcx.fn_sig(fn_target.def_id()).instantiate_identity(); + let sig = tcx.fn_sig(fn_target.def_id()).skip_binder(); let sig = tcx.instantiate_bound_regions_with_erased(sig); let inputs = sig.inputs(); @@ -1404,7 +1404,7 @@ fn codegen_offload<'ll, 'tcx>( return; } }; - let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals); + let 
offload_data = gen_define_handling(&cx, &metadata, target_symbol, offload_globals); gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims); } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index a90013801c8c..c535fade9c04 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1675,7 +1675,11 @@ mod Offload { _M: &'a Module, _host_out: *const c_char, ) -> bool; - pub(crate) fn LLVMRustOffloadMapper<'a>(OldFn: &'a Value, NewFn: &'a Value); + pub(crate) fn LLVMRustOffloadMapper<'a>( + OldFn: &'a Value, + NewFn: &'a Value, + RebuiltArgs: *const &Value, + ); } } @@ -1702,7 +1706,11 @@ mod Offload_fallback { unimplemented!("This rustc version was not built with LLVM Offload support!"); } #[allow(unused_unsafe)] - pub(crate) unsafe fn LLVMRustOffloadMapper<'a>(_OldFn: &'a Value, _NewFn: &'a Value) { + pub(crate) unsafe fn LLVMRustOffloadMapper<'a>( + _OldFn: &'a Value, + _NewFn: &'a Value, + _RebuiltArgs: *const &Value, + ) { unimplemented!("This rustc version was not built with LLVM Offload support!"); } } diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 38e8886910f7..336d58974036 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -223,7 +223,12 @@ extern "C" bool LLVMRustOffloadEmbedBufferInModule(LLVMModuleRef HostM, return true; } -extern "C" void LLVMRustOffloadMapper(LLVMValueRef OldFn, LLVMValueRef NewFn) { +// Clone OldFn into NewFn, remapping its arguments to RebuiltArgs. +// Each arg of OldFn is replaced with the corresponding value in RebuiltArgs. +// For scalars, RebuiltArgs contains the value cast and/or truncated to the +// original type. 
+extern "C" void LLVMRustOffloadMapper(LLVMValueRef OldFn, LLVMValueRef NewFn, + const LLVMValueRef *RebuiltArgs) { llvm::Function *oldFn = llvm::unwrap<llvm::Function>(OldFn); llvm::Function *newFn = llvm::unwrap<llvm::Function>(NewFn); @@ -232,15 +237,25 @@ extern "C" void LLVMRustOffloadMapper(LLVMValueRef OldFn, LLVMValueRef NewFn) { llvm::ValueToValueMapTy vmap; auto newArgIt = newFn->arg_begin(); newArgIt->setName("dyn_ptr"); - ++newArgIt; // skip %dyn_ptr + + unsigned i = 0; for (auto &oldArg : oldFn->args()) { - vmap[&oldArg] = &*newArgIt++; + vmap[&oldArg] = unwrap(RebuiltArgs[i++]); } llvm::SmallVector<llvm::ReturnInst *, 8> returns; llvm::CloneFunctionInto(newFn, oldFn, vmap, llvm::CloneFunctionChangeType::LocalChangesOnly, returns); + + BasicBlock &entry = newFn->getEntryBlock(); + BasicBlock &clonedEntry = *std::next(newFn->begin()); + + if (entry.getTerminator()) + entry.getTerminator()->eraseFromParent(); + + IRBuilder<> B(&entry); + B.CreateBr(&clonedEntry); } #endif diff --git a/compiler/rustc_middle/src/ty/offload_meta.rs b/compiler/rustc_middle/src/ty/offload_meta.rs index 04a7cd2c75f2..67c00765ed57 100644 --- a/compiler/rustc_middle/src/ty/offload_meta.rs +++ b/compiler/rustc_middle/src/ty/offload_meta.rs @@ -78,16 +78,13 @@ impl MappingFlags { use rustc_ast::Mutability::*; match ty.kind() { - ty::Bool - | ty::Char - | ty::Int(_) - | ty::Uint(_) - | ty::Float(_) - | ty::Adt(_, _) - | ty::Tuple(_) - | ty::Array(_, _) - | ty::Alias(_, _) - | ty::Param(_) => MappingFlags::TO, + ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::Float(_) => { + MappingFlags::LITERAL | MappingFlags::IMPLICIT + } + + ty::Adt(_, _) | ty::Tuple(_) | ty::Array(_, _) | ty::Alias(_, _) | ty::Param(_) => { + MappingFlags::TO + } ty::RawPtr(_, Not) | ty::Ref(_, _, Not) => MappingFlags::TO, diff --git a/tests/codegen-llvm/gpu_offload/scalar_device.rs b/tests/codegen-llvm/gpu_offload/scalar_device.rs new file mode 100644 index 000000000000..61772d404063 --- /dev/null +++ b/tests/codegen-llvm/gpu_offload/scalar_device.rs @@ -0,0
+1,36 @@ +//@ add-minicore +//@ revisions: amdgpu nvptx +//@[nvptx] compile-flags: -Copt-level=0 -Zunstable-options -Zoffload=Device --target nvptx64-nvidia-cuda --crate-type=rlib +//@[nvptx] needs-llvm-components: nvptx +//@[amdgpu] compile-flags: -Copt-level=0 -Zunstable-options -Zoffload=Device --target amdgcn-amd-amdhsa -Ctarget-cpu=gfx900 --crate-type=rlib +//@[amdgpu] needs-llvm-components: amdgpu +//@ no-prefer-dynamic +//@ needs-offload + +// This test verifies that the offload intrinsic is properly handling scalar args on the device, +// replacing the args by i64 and then trunc and cast them to the original type + +#![feature(abi_gpu_kernel, rustc_attrs, no_core)] +#![no_core] + +extern crate minicore; + +// CHECK: ; Function Attrs +// nvptx-NEXT: define ptx_kernel void @foo(ptr %dyn_ptr, ptr %0, i64 %1) +// amdgpu-NEXT: define amdgpu_kernel void @foo(ptr %dyn_ptr, ptr %0, i64 %1) +// CHECK-NEXT: entry: +// CHECK-NEXT: %2 = trunc i64 %1 to i32 +// CHECK-NEXT: %3 = bitcast i32 %2 to float +// CHECK-NEXT: br label %start +// CHECK: start: +// CHECK-NEXT: store float %3, ptr %0, align 4 +// CHECK-NEXT: ret void +// CHECK-NEXT: } + +#[unsafe(no_mangle)] +#[rustc_offload_kernel] +pub unsafe extern "gpu-kernel" fn foo(x: *mut f32, k: f32) { + unsafe { + *x = k; + }; +} diff --git a/tests/codegen-llvm/gpu_offload/scalar_host.rs b/tests/codegen-llvm/gpu_offload/scalar_host.rs new file mode 100644 index 000000000000..8c7dcd4dd581 --- /dev/null +++ b/tests/codegen-llvm/gpu_offload/scalar_host.rs @@ -0,0 +1,37 @@ +//@ compile-flags: -Zoffload=Test -Zunstable-options -C opt-level=1 -Clto=fat +//@ no-prefer-dynamic +//@ needs-offload + +// This test verifies that the offload intrinsic is properly handling scalar args, passing them to +// the kernel as i64 + +#![feature(abi_gpu_kernel)] +#![feature(rustc_attrs)] +#![feature(core_intrinsics)] +#![no_main] + +// CHECK: define{{( dso_local)?}} void @main() +// CHECK-NOT: define +// CHECK: %addr = alloca i64, align 8 +// 
CHECK: store double 4.200000e+01, ptr %0, align 8 +// CHECK: %_0.i = load double, ptr %0, align 8 +// CHECK: store double %_0.i, ptr %addr, align 8 +// CHECK: %1 = getelementptr inbounds nuw i8, ptr %.offload_baseptrs, i64 8 +// CHECK-NEXT: store double %_0.i, ptr %1, align 8 +// CHECK-NEXT: %2 = getelementptr inbounds nuw i8, ptr %.offload_ptrs, i64 8 +// CHECK-NEXT: store ptr %addr, ptr %2, align 8 +// CHECK-NEXT: %3 = getelementptr inbounds nuw i8, ptr %.offload_sizes, i64 8 +// CHECK-NEXT: store i64 4, ptr %3, align 8 +// CHECK-NEXT: call void @__tgt_target_data_begin_mapper + +#[unsafe(no_mangle)] +fn main() { + let mut x = 0.0; + let k = core::hint::black_box(42.0); + + core::intrinsics::offload::<_, _, ()>(foo, [1, 1, 1], [1, 1, 1], (&mut x, k)); +} + +unsafe extern "C" { + pub fn foo(x: *mut f32, k: f32); +}