Fix device code generation, to account for an implicit dyn_ptr argument.

This commit is contained in:
Manuel Drehwald 2025-08-31 19:49:40 -07:00
parent 401ae55427
commit 360b38cceb
12 changed files with 129 additions and 2 deletions

View file

@ -616,7 +616,8 @@ pub(crate) fn run_pass_manager(
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage);
}
if enable_gpu && !thin {
// Here we only handle the GPU host (=cpu) code.
if enable_gpu && !thin && !cgcx.target_is_like_gpu {
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx);

View file

@ -43,7 +43,7 @@ use crate::errors::{
use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
use crate::llvm::{self, DiagnosticInfo};
use crate::type_::llvm_type_ptr;
use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, base, common, llvm_util};
pub(crate) fn llvm_err<'a>(dcx: DiagCtxtHandle<'_>, err: LlvmError<'a>) -> ! {
match llvm::last_error() {
@ -645,6 +645,74 @@ pub(crate) unsafe fn llvm_optimize(
None
};
fn handle_offload<'ll>(cx: &'ll SimpleCx<'_>, old_fn: &llvm::Value) {
let old_fn_ty = cx.get_type_of_global(old_fn);
let old_param_types = cx.func_params_types(old_fn_ty);
let old_param_count = old_param_types.len();
if old_param_count == 0 {
return;
}
let first_param = llvm::get_param(old_fn, 0);
let c_name = llvm::get_value_name(first_param);
let first_arg_name = str::from_utf8(&c_name).unwrap();
// We might call llvm_optimize (and thus this code) multiple times on the same IR,
// but we shouldn't add this helper ptr multiple times.
// FIXME(offload): This could break if the user calls his first argument `dyn_ptr`.
if first_arg_name == "dyn_ptr" {
return;
}
// Create the new parameter list, with ptr as the first argument
let mut new_param_types = Vec::with_capacity(old_param_count as usize + 1);
new_param_types.push(cx.type_ptr());
new_param_types.extend(old_param_types);
// Create the new function type
let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) };
let new_fn_ty = cx.type_func(&new_param_types, ret_ty);
// Create the new function, with a temporary .offload name to avoid a name collision.
let old_fn_name = String::from_utf8(llvm::get_value_name(old_fn)).unwrap();
let new_fn_name = format!("{}.offload", &old_fn_name);
let new_fn = cx.add_func(&new_fn_name, new_fn_ty);
let a0 = llvm::get_param(new_fn, 0);
llvm::set_value_name(a0, CString::new("dyn_ptr").unwrap().as_bytes());
// Here we map the old arguments to the new arguments, with an offset of 1 to make sure
// that we don't use the newly added `%dyn_ptr`.
unsafe {
llvm::LLVMRustOffloadMapper(cx.llmod(), old_fn, new_fn);
}
llvm::set_linkage(new_fn, llvm::get_linkage(old_fn));
llvm::set_visibility(new_fn, llvm::get_visibility(old_fn));
// Replace all uses of old_fn with new_fn (RAUW)
unsafe {
llvm::LLVMReplaceAllUsesWith(old_fn, new_fn);
}
let name = llvm::get_value_name(old_fn);
unsafe {
llvm::LLVMDeleteFunction(old_fn);
}
// Now we can re-use the old name, without name collision.
llvm::set_value_name(new_fn, &name);
}
// Device side of offloading: give every known kernel the implicit `dyn_ptr` argument.
if cgcx.target_is_like_gpu && config.offload.contains(&config::Offload::Enable) {
    let cx =
        SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size);
    // For now we only support up to 10 kernels named kernel_0 ... kernel_9, a follow-up PR is
    // introducing a proper offload intrinsic to solve this limitation.
    // The range is inclusive: the previous `0..9` silently skipped `kernel_9`.
    // NOTE(review): the host-side search in `gpu_offload::handle_gpu_code` should cover the
    // same set of names.
    for num in 0..=9 {
        let name = format!("kernel_{num}");
        if let Some(kernel) = cx.get_function(&name) {
            handle_offload(&cx, kernel);
        }
    }
}
let mut llvm_profiler = cgcx
.prof
.llvm_recording_enabled()

View file

@ -19,6 +19,9 @@ pub(crate) fn handle_gpu_code<'ll>(
let mut memtransfer_types = vec![];
let mut region_ids = vec![];
let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
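// This is a temporary hack, we only search for functions named kernel_0, kernel_1, ...
// FIXME(offload): the intended range is kernel_0 to kernel_9, but the exclusive range `0..9`
// below stops at kernel_8.
// There is a draft PR in progress which will introduce a proper offload intrinsic to remove
// this limitation.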
for num in 0..9 {
let kernel = cx.get_function(&format!("kernel_{num}"));
if let Some(kernel) = kernel {

View file

@ -1127,6 +1127,11 @@ unsafe extern "C" {
// Operations on functions
pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
/// LLVM-C binding: adds a function with the given name and function type to `Mod` and
/// returns the new function value. `Name` is passed as a C string pointer, so it is
/// presumably expected to be NUL-terminated (callers in this crate pass a C string).
pub(crate) fn LLVMAddFunction<'a>(
Mod: &'a Module,
Name: *const c_char,
FunctionTy: &'a Type,
) -> &'a Value;
/// LLVM-C binding: deletes the function `Fn`. The value must not be used afterwards.
/// NOTE(review): callers are expected to have replaced all remaining uses first — confirm
/// against the LLVM-C documentation.
pub(crate) fn LLVMDeleteFunction(Fn: &Value);
// Operations about llvm intrinsics
@ -2017,6 +2022,7 @@ unsafe extern "C" {
) -> &Attribute;
// Operations on functions
/// Clones the body of `OldFn` into `NewFn`, where `NewFn` has one additional leading
/// parameter; the old arguments are mapped onto the new ones with an offset of one.
pub(crate) fn LLVMRustOffloadMapper<'a>(M: &'a Module, OldFn: &'a Value, NewFn: &'a Value);
pub(crate) fn LLVMRustGetOrInsertFunction<'a>(
M: &'a Module,
Name: *const c_char,

View file

@ -68,6 +68,11 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
unsafe { llvm::LLVMVectorType(ty, len as c_uint) }
}
/// Adds a function called `name` with the function type `ty` to this context's module and
/// returns the resulting function value.
pub(crate) fn add_func(&self, name: &str, ty: &'ll Type) -> &'ll Value {
    let c_name = SmallCStr::new(name);
    let llmod = self.llmod();
    // SAFETY: `c_name` is kept alive until after the call returns, so the pointer handed to
    // LLVM stays valid for the duration of the call.
    unsafe { llvm::LLVMAddFunction(llmod, c_name.as_ptr(), ty) }
}
pub(crate) fn func_params_types(&self, ty: &'ll Type) -> Vec<&'ll Type> {
unsafe {
let n_args = llvm::LLVMCountParamTypes(ty) as usize;

View file

@ -342,6 +342,7 @@ pub struct CodegenContext<B: WriteBackendMethods> {
pub target_arch: String,
pub target_is_like_darwin: bool,
pub target_is_like_aix: bool,
pub target_is_like_gpu: bool,
pub split_debuginfo: rustc_target::spec::SplitDebuginfo,
pub split_dwarf_kind: rustc_session::config::SplitDwarfKind,
pub pointer_size: Size,
@ -1309,6 +1310,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
target_arch: tcx.sess.target.arch.to_string(),
target_is_like_darwin: tcx.sess.target.is_like_darwin,
target_is_like_aix: tcx.sess.target.is_like_aix,
target_is_like_gpu: tcx.sess.target.is_like_gpu,
split_debuginfo: tcx.sess.split_debuginfo(),
split_dwarf_kind: tcx.sess.opts.unstable_opts.split_dwarf_kind,
parallel: backend.supports_parallel() && !sess.opts.unstable_opts.no_parallel_backend,

View file

@ -35,6 +35,8 @@
#include "llvm/Support/Signals.h"
#include "llvm/Support/Timer.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <iostream>
// for raw `write` in the bad-alloc handler
@ -142,6 +144,28 @@ extern "C" void LLVMRustPrintStatistics(RustStringRef OutBuf) {
llvm::PrintStatistics(OS);
}
// Clones the body of `OldFn` into `NewFn`. `NewFn` is expected to have the same
// parameters as `OldFn`, preceded by one extra leading parameter, which gets
// named `dyn_ptr` here and is not mapped to any argument of `OldFn`.
extern "C" void LLVMRustOffloadMapper(LLVMModuleRef M, LLVMValueRef OldFn,
                                      LLVMValueRef NewFn) {
  // The module handle is part of the FFI signature but is not needed here.
  (void)M;
  llvm::Function *oldFn = llvm::unwrap<llvm::Function>(OldFn);
  llvm::Function *newFn = llvm::unwrap<llvm::Function>(NewFn);

  // Map old arguments to new arguments. We skip the first dyn_ptr argument,
  // since it can't be used directly by user code.
  llvm::ValueToValueMapTy vmap;
  auto newArgIt = newFn->arg_begin();
  newArgIt->setName("dyn_ptr");
  ++newArgIt; // skip %dyn_ptr
  for (auto &oldArg : oldFn->args()) {
    vmap[&oldArg] = &*newArgIt;
    ++newArgIt;
  }

  // Collects the cloned return instructions; required by the API, unused here.
  llvm::SmallVector<llvm::ReturnInst *, 8> returns;
  llvm::CloneFunctionInto(newFn, oldFn, vmap,
                          llvm::CloneFunctionChangeType::LocalChangesOnly,
                          returns);
}
extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,
size_t NameLen) {
return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen)));

View file

@ -578,6 +578,7 @@ impl RiscvInterruptKind {
///
/// The signature represented by this type may not match the MIR function signature.
/// Certain attributes, like `#[track_caller]` can introduce additional arguments, which are present in [`FnAbi`], but not in `FnSig`.
/// The std::offload module also adds an additional dyn_ptr argument to the GpuKernel ABI.
/// While this difference is rarely relevant, it should still be kept in mind.
///
/// I will do my best to describe this structure, but these

View file

@ -147,6 +147,7 @@ impl Target {
forward!(is_like_darwin);
forward!(is_like_solaris);
forward!(is_like_windows);
forward!(is_like_gpu);
forward!(is_like_msvc);
forward!(is_like_wasm);
forward!(is_like_android);
@ -337,6 +338,7 @@ impl ToJson for Target {
target_option_val!(is_like_darwin);
target_option_val!(is_like_solaris);
target_option_val!(is_like_windows);
target_option_val!(is_like_gpu);
target_option_val!(is_like_msvc);
target_option_val!(is_like_wasm);
target_option_val!(is_like_android);
@ -556,6 +558,7 @@ struct TargetSpecJson {
is_like_darwin: Option<bool>,
is_like_solaris: Option<bool>,
is_like_windows: Option<bool>,
is_like_gpu: Option<bool>,
is_like_msvc: Option<bool>,
is_like_wasm: Option<bool>,
is_like_android: Option<bool>,

View file

@ -2180,6 +2180,8 @@ pub struct TargetOptions {
/// Also indicates whether to use Apple-specific ABI changes, such as extending function
/// parameters to 32-bits.
pub is_like_darwin: bool,
/// Whether the target is a GPU (e.g. NVIDIA, AMD, Intel).
pub is_like_gpu: bool,
/// Whether the target toolchain is like Solaris's.
/// Only useful for compiling against Illumos/Solaris,
/// as they have a different set of linker flags. Defaults to false.
@ -2583,6 +2585,7 @@ impl Default for TargetOptions {
abi_return_struct_as_int: false,
is_like_aix: false,
is_like_darwin: false,
is_like_gpu: false,
is_like_solaris: false,
is_like_windows: false,
is_like_msvc: false,
@ -2748,6 +2751,11 @@ impl Target {
self.os == "solaris" || self.os == "illumos",
"`is_like_solaris` must be set if and only if `os` is `solaris` or `illumos`"
);
check_eq!(
self.is_like_gpu,
self.arch == Arch::Nvptx64 || self.arch == Arch::AmdGpu,
"`is_like_gpu` must be set if and only if `target` is `nvptx64` or `amdgcn`"
);
check_eq!(
self.is_like_windows,
self.os == "windows" || self.os == "uefi" || self.os == "cygwin",

View file

@ -34,6 +34,9 @@ pub(crate) fn target() -> Target {
no_builtins: true,
simd_types_indirect: false,
// Clearly a GPU
is_like_gpu: true,
// Allow `cdylib` crate type.
dynamic_linking: true,
only_cdylib: true,

View file

@ -42,6 +42,9 @@ pub(crate) fn target() -> Target {
// Let the `ptx-linker` handle LLVM lowering into MC / assembly.
obj_is_bitcode: true,
// Clearly a GPU
is_like_gpu: true,
// Convenient and predictable naming scheme.
dll_prefix: "".into(),
dll_suffix: ".ptx".into(),