Add scalar support for both host and device

Marcelo Domínguez 2025-12-22 21:03:02 +01:00
parent 2c9c5d14a2
commit 307a4fcdf8
9 changed files with 210 additions and 34 deletions

View file

@@ -13,6 +13,7 @@ use rustc_codegen_ssa::back::write::{
TargetMachineFactoryConfig, TargetMachineFactoryFn,
};
use rustc_codegen_ssa::base::wants_wasm_eh;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::*;
use rustc_codegen_ssa::{CompiledModule, ModuleCodegen, ModuleKind};
use rustc_data_structures::profiling::SelfProfilerRef;
@@ -33,6 +34,8 @@ use crate::back::owned_target_machine::OwnedTargetMachine;
use crate::back::profiling::{
LlvmSelfProfiler, selfprofile_after_pass_callback, selfprofile_before_pass_callback,
};
use crate::builder::SBuilder;
use crate::builder::gpu_offload::scalar_width;
use crate::common::AsCCharPtr;
use crate::errors::{
CopyBitcode, FromLlvmDiag, FromLlvmOptimizationDiag, LlvmError, UnknownCompression,
@@ -669,7 +672,17 @@ pub(crate) unsafe fn llvm_optimize(
// Create the new parameter list, with ptr as the first argument
let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
new_param_types.push(cx.type_ptr());
new_param_types.extend(old_param_types);
// This relies on undocumented LLVM behavior: scalar kernel arguments must be passed as i64
for &old_ty in &old_param_types {
let new_ty = match cx.type_kind(old_ty) {
TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => {
cx.type_i64()
}
_ => old_ty,
};
new_param_types.push(new_ty);
}
// Create the new function type
let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) };
@@ -682,10 +695,33 @@ pub(crate) unsafe fn llvm_optimize(
let a0 = llvm::get_param(new_fn, 0);
llvm::set_value_name(a0, CString::new("dyn_ptr").unwrap().as_bytes());
let bb = SBuilder::append_block(cx, new_fn, "entry");
let mut builder = SBuilder::build(cx, bb);
let mut old_args_rebuilt = Vec::with_capacity(old_param_types.len());
for (i, &old_ty) in old_param_types.iter().enumerate() {
let new_arg = llvm::get_param(new_fn, (i + 1) as u32);
let rebuilt = match cx.type_kind(old_ty) {
TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => {
let num_bits = scalar_width(cx, old_ty);
let trunc = builder.trunc(new_arg, cx.type_ix(num_bits));
builder.bitcast(trunc, old_ty)
}
_ => new_arg,
};
old_args_rebuilt.push(rebuilt);
}
builder.ret_void();
// Here we map the old arguments to the rebuilt values; the rebuild loop above reads the
// new parameters with an offset of 1 so that we don't use the newly added `%dyn_ptr`.
unsafe {
llvm::LLVMRustOffloadMapper(old_fn, new_fn);
llvm::LLVMRustOffloadMapper(old_fn, new_fn, old_args_rebuilt.as_ptr());
}
llvm::set_linkage(new_fn, llvm::get_linkage(old_fn));
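For intuition, the bit-level contract between the widening loop and the rebuild loop above can be sketched in plain Rust: the caller zero-extends the scalar's raw bits into the i64 slot, and the wrapper's entry block truncates back to the original width and reinterprets the bits. A minimal sketch of the round trip (pack_f32/unpack_f32 are hypothetical names, not part of this change):

fn pack_f32(k: f32) -> u64 {
    k.to_bits() as u64 // zext i32 -> i64, matching the widened i64 slot
}

fn unpack_f32(slot: u64) -> f32 {
    f32::from_bits(slot as u32) // trunc i64 -> i32, then bitcast back to float
}

fn main() {
    assert_eq!(unpack_f32(pack_f32(42.0)), 42.0);
}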

View file

@@ -97,6 +97,21 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
GenericBuilder { llbuilder, cx: scx }
}
pub(crate) fn append_block(
cx: &'a GenericCx<'ll, CX>,
llfn: &'ll Value,
name: &str,
) -> &'ll BasicBlock {
unsafe {
let name = SmallCStr::new(name);
llvm::LLVMAppendBasicBlockInContext(cx.llcx(), llfn, name.as_ptr())
}
}
pub(crate) fn trunc(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
unsafe { llvm::LLVMBuildTrunc(self.llbuilder, val, dest_ty, UNNAMED) }
}
pub(crate) fn bitcast(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) }
}

View file

@@ -2,6 +2,7 @@ use std::ffi::CString;
use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::bug;
@@ -357,7 +358,6 @@ pub(crate) fn add_global<'ll>(
pub(crate) fn gen_define_handling<'ll>(
cx: &CodegenCx<'ll, '_>,
metadata: &[OffloadMetadata],
types: &[&'ll Type],
symbol: String,
offload_globals: &OffloadGlobals<'ll>,
) -> OffloadKernelGlobals<'ll> {
@@ -367,25 +367,18 @@ pub(crate) fn gen_define_handling<'ll>(
let offload_entry_ty = offload_globals.offload_entry_ty;
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
// reference) types.
let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta),
_ => None,
});
// FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
let (ptr_sizes, ptr_transfer): (Vec<_>, Vec<_>) =
ptr_meta.map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();
let (sizes, transfer): (Vec<_>, Vec<_>) =
metadata.iter().map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();
let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes);
let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &sizes);
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
// An immutable reference or pointer will be 1; an array that is never read but fully overwritten
// will be 2. For now, everything is 3, until our frontend is set up.
// 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
let memtransfer_types =
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &ptr_transfer);
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &transfer);
// Next: For each function, generate these three entries. A weak constant,
// the llvm.rodata entry name, and the llvm_offload_entries value
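As a worked example of the transfer-flag arithmetic in the comment above (a sketch; the real mode bits come from MappingFlags elsewhere in this change): mapping both to and from the gpu plus the extra-input-pointer bit gives 1 + 2 + 32 = 35.

const MAP_TO: u64 = 0x1; // copy to the device
const MAP_FROM: u64 = 0x2; // copy back from the device
const MAP_TARGET_PARAM: u64 = 0x20; // one extra input ptr per argument

fn main() {
    let transfer = MAP_TO | MAP_FROM | MAP_TARGET_PARAM;
    assert_eq!(transfer, 35); // the value emitted into .offload_maptypes.<symbol>
}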
@@ -441,11 +434,24 @@ fn declare_offload_fn<'ll>(
)
}
pub(crate) fn scalar_width<'ll>(cx: &'ll SimpleCx<'_>, ty: &'ll Type) -> u64 {
match cx.type_kind(ty) {
TypeKind::Half
| TypeKind::Float
| TypeKind::Double
| TypeKind::X86_FP80
| TypeKind::FP128
| TypeKind::PPC_FP128 => cx.float_width(ty) as u64,
TypeKind::Integer => cx.int_width(ty),
other => bug!("scalar_width was called on a non-scalar type {other:?}"),
}
}
// For each kernel *call*, we now use some of our previously declared globals to move data to and from
// the gpu. For now, we only handle the data transfer part of it.
// If two consecutive kernels use the same memory, we still move it to the host and back to the gpu.
// Since in our frontend users (by default) don't have to specify data transfer, this is something
// we should optimize in the future! In some cases we can directly zero-allocate ont he device and
// we should optimize in the future! In some cases we can directly zero-allocate on the device and
// only move data back, or if something is immutable, we might only copy it to the device.
//
// Current steps:
@@ -533,8 +539,34 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
let mut geps = vec![];
let i32_0 = cx.get_const_i32(0);
for &v in args {
let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
vals.push(v);
let ty = cx.val_ty(v);
let ty_kind = cx.type_kind(ty);
let (base_val, gep_base) = match ty_kind {
TypeKind::Pointer => (v, v),
TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => {
// FIXME(Sa4dUs): check for `f128` support; the latest NVIDIA cards support it
let num_bits = scalar_width(cx, ty);
let bb = builder.llbb();
unsafe {
llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, builder.llfn());
}
let addr = builder.direct_alloca(cx.type_i64(), Align::EIGHT, "addr");
unsafe {
llvm::LLVMPositionBuilderAtEnd(builder.llbuilder, bb);
}
let cast = builder.bitcast(v, cx.type_ix(num_bits));
let value = builder.zext(cast, cx.type_i64());
builder.store(value, addr, Align::EIGHT);
(value, addr)
}
other => bug!("offload does not support {other:?}"),
};
let gep = builder.inbounds_gep(cx.type_f32(), gep_base, &[i32_0]);
vals.push(base_val);
geps.push(gep);
}
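In effect, each scalar argument now exists in two forms on the host: the widened value itself (stored into .offload_baseptrs) and a stack slot holding it (whose address goes into .offload_ptrs). A plain-Rust sketch of that shape, with a Box standing in for the i64 alloca (stage_scalar is a hypothetical name, not part of this change):

fn stage_scalar(k: f64) -> (u64, Box<u64>) {
    // f64 already fills the 64-bit slot; narrower scalars would have their
    // raw bits zero-extended first, like the zext in the loop above.
    let bits = k.to_bits();
    let slot = Box::new(bits); // stands in for the `%addr` alloca
    (bits, slot) // (base value, address the offload runtime dereferences)
}

fn main() {
    let (bits, slot) = stage_scalar(42.0);
    assert_eq!(bits, *slot);
}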

View file

@@ -1388,7 +1388,7 @@ fn codegen_offload<'ll, 'tcx>(
let args = get_args_from_tuple(bx, args[3], fn_target);
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE);
let sig = tcx.fn_sig(fn_target.def_id()).instantiate_identity();
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder();
let sig = tcx.instantiate_bound_regions_with_erased(sig);
let inputs = sig.inputs();
@@ -1404,7 +1404,7 @@
return;
}
};
let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals);
let offload_data = gen_define_handling(&cx, &metadata, target_symbol, offload_globals);
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims);
}

View file

@@ -1675,7 +1675,11 @@ mod Offload {
_M: &'a Module,
_host_out: *const c_char,
) -> bool;
pub(crate) fn LLVMRustOffloadMapper<'a>(OldFn: &'a Value, NewFn: &'a Value);
pub(crate) fn LLVMRustOffloadMapper<'a>(
OldFn: &'a Value,
NewFn: &'a Value,
RebuiltArgs: *const &Value,
);
}
}
@@ -1702,7 +1706,11 @@ mod Offload_fallback {
unimplemented!("This rustc version was not built with LLVM Offload support!");
}
#[allow(unused_unsafe)]
pub(crate) unsafe fn LLVMRustOffloadMapper<'a>(_OldFn: &'a Value, _NewFn: &'a Value) {
pub(crate) unsafe fn LLVMRustOffloadMapper<'a>(
_OldFn: &'a Value,
_NewFn: &'a Value,
_RebuiltArgs: *const &Value,
) {
unimplemented!("This rustc version was not built with LLVM Offload support!");
}
}

View file

@@ -223,7 +223,12 @@ extern "C" bool LLVMRustOffloadEmbedBufferInModule(LLVMModuleRef HostM,
return true;
}
extern "C" void LLVMRustOffloadMapper(LLVMValueRef OldFn, LLVMValueRef NewFn) {
// Clone OldFn into NewFn, remapping its arguments to RebuiltArgs.
// Each arg of OldFn is replaced with the corresponding value in RebuiltArgs.
// For scalars, RebuiltArgs contains the value cast and/or truncated to the
// original type.
extern "C" void LLVMRustOffloadMapper(LLVMValueRef OldFn, LLVMValueRef NewFn,
const LLVMValueRef *RebuiltArgs) {
llvm::Function *oldFn = llvm::unwrap<llvm::Function>(OldFn);
llvm::Function *newFn = llvm::unwrap<llvm::Function>(NewFn);
@@ -232,15 +237,25 @@ extern "C" void LLVMRustOffloadMapper(LLVMValueRef OldFn, LLVMValueRef NewFn) {
llvm::ValueToValueMapTy vmap;
auto newArgIt = newFn->arg_begin();
newArgIt->setName("dyn_ptr");
++newArgIt; // skip %dyn_ptr
unsigned i = 0;
for (auto &oldArg : oldFn->args()) {
vmap[&oldArg] = &*newArgIt++;
vmap[&oldArg] = unwrap<Value>(RebuiltArgs[i++]);
}
llvm::SmallVector<llvm::ReturnInst *, 8> returns;
llvm::CloneFunctionInto(newFn, oldFn, vmap,
llvm::CloneFunctionChangeType::LocalChangesOnly,
returns);
BasicBlock &entry = newFn->getEntryBlock();
BasicBlock &clonedEntry = *std::next(newFn->begin());
if (entry.getTerminator())
entry.getTerminator()->eraseFromParent();
IRBuilder<> B(&entry);
B.CreateBr(&clonedEntry);
}
#endif
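One design note on the widened FFI signature: RebuiltArgs is a raw pointer with no length, so the C++ side trusts the Rust caller to supply exactly one value per argument of OldFn. A minimal sketch of a call-site check that would make that contract explicit (offload_map is a hypothetical helper, not part of this change):

unsafe fn offload_map<'ll>(
    old_fn: &'ll Value,
    new_fn: &'ll Value,
    rebuilt: &[&'ll Value],
    old_arg_count: usize,
) {
    // LLVMRustOffloadMapper reads RebuiltArgs once per argument of old_fn,
    // so the slice must be exactly as long as the old parameter list.
    assert_eq!(rebuilt.len(), old_arg_count, "one rebuilt value per old argument");
    llvm::LLVMRustOffloadMapper(old_fn, new_fn, rebuilt.as_ptr());
}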

View file

@@ -78,16 +78,13 @@ impl MappingFlags {
use rustc_ast::Mutability::*;
match ty.kind() {
ty::Bool
| ty::Char
| ty::Int(_)
| ty::Uint(_)
| ty::Float(_)
| ty::Adt(_, _)
| ty::Tuple(_)
| ty::Array(_, _)
| ty::Alias(_, _)
| ty::Param(_) => MappingFlags::TO,
ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::Float(_) => {
MappingFlags::LITERAL | MappingFlags::IMPLICIT
}
ty::Adt(_, _) | ty::Tuple(_) | ty::Array(_, _) | ty::Alias(_, _) | ty::Param(_) => {
MappingFlags::TO
}
ty::RawPtr(_, Not) | ty::Ref(_, _, Not) => MappingFlags::TO,
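For context on the LITERAL | IMPLICIT choice: these appear to mirror the OpenMP offload runtime's map-type bits, where (an assumption based on LLVM's omptarget.h, not stated in this change) LITERAL = 0x100 passes the argument by value instead of mapping memory behind a pointer, and IMPLICIT = 0x200 marks it as implicitly captured. A sketch assuming those values:

const MAPTYPE_LITERAL: u64 = 0x100; // assumed: pass the value itself, no mapping
const MAPTYPE_IMPLICIT: u64 = 0x200; // assumed: argument captured implicitly

fn main() {
    // Scalars become by-value literals rather than mapped buffers.
    assert_eq!(MAPTYPE_LITERAL | MAPTYPE_IMPLICIT, 0x300);
}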

View file

@@ -0,0 +1,36 @@
//@ add-minicore
//@ revisions: amdgpu nvptx
//@[nvptx] compile-flags: -Copt-level=0 -Zunstable-options -Zoffload=Device --target nvptx64-nvidia-cuda --crate-type=rlib
//@[nvptx] needs-llvm-components: nvptx
//@[amdgpu] compile-flags: -Copt-level=0 -Zunstable-options -Zoffload=Device --target amdgcn-amd-amdhsa -Ctarget-cpu=gfx900 --crate-type=rlib
//@[amdgpu] needs-llvm-components: amdgpu
//@ no-prefer-dynamic
//@ needs-offload
// This test verifies that the offload intrinsic properly handles scalar args on the device,
// replacing them with i64 and then truncating and bitcasting them back to the original type
#![feature(abi_gpu_kernel, rustc_attrs, no_core)]
#![no_core]
extern crate minicore;
// CHECK: ; Function Attrs
// nvptx-NEXT: define ptx_kernel void @foo(ptr %dyn_ptr, ptr %0, i64 %1)
// amdgpu-NEXT: define amdgpu_kernel void @foo(ptr %dyn_ptr, ptr %0, i64 %1)
// CHECK-NEXT: entry:
// CHECK-NEXT: %2 = trunc i64 %1 to i32
// CHECK-NEXT: %3 = bitcast i32 %2 to float
// CHECK-NEXT: br label %start
// CHECK: start:
// CHECK-NEXT: store float %3, ptr %0, align 4
// CHECK-NEXT: ret void
// CHECK-NEXT: }
#[unsafe(no_mangle)]
#[rustc_offload_kernel]
pub unsafe extern "gpu-kernel" fn foo(x: *mut f32, k: f32) {
unsafe {
*x = k;
};
}

View file

@@ -0,0 +1,37 @@
//@ compile-flags: -Zoffload=Test -Zunstable-options -C opt-level=1 -Clto=fat
//@ no-prefer-dynamic
//@ needs-offload
// This test verifies that the offload intrinsic properly handles scalar args, passing them to
// the kernel as i64
#![feature(abi_gpu_kernel)]
#![feature(rustc_attrs)]
#![feature(core_intrinsics)]
#![no_main]
// CHECK: define{{( dso_local)?}} void @main()
// CHECK-NOT: define
// CHECK: %addr = alloca i64, align 8
// CHECK: store double 4.200000e+01, ptr %0, align 8
// CHECK: %_0.i = load double, ptr %0, align 8
// CHECK: store double %_0.i, ptr %addr, align 8
// CHECK: %1 = getelementptr inbounds nuw i8, ptr %.offload_baseptrs, i64 8
// CHECK-NEXT: store double %_0.i, ptr %1, align 8
// CHECK-NEXT: %2 = getelementptr inbounds nuw i8, ptr %.offload_ptrs, i64 8
// CHECK-NEXT: store ptr %addr, ptr %2, align 8
// CHECK-NEXT: %3 = getelementptr inbounds nuw i8, ptr %.offload_sizes, i64 8
// CHECK-NEXT: store i64 4, ptr %3, align 8
// CHECK-NEXT: call void @__tgt_target_data_begin_mapper
#[unsafe(no_mangle)]
fn main() {
let mut x = 0.0;
let k = core::hint::black_box(42.0);
core::intrinsics::offload::<_, _, ()>(foo, [1, 1, 1], [1, 1, 1], (&mut x, k));
}
unsafe extern "C" {
pub fn foo(x: *mut f32, k: f32);
}