add various wrappers for gpu code generation
This commit is contained in:
parent
634016478e
commit
5958ebe829
5 changed files with 140 additions and 2 deletions
|
|
@ -3,6 +3,7 @@ use std::ops::Deref;
|
|||
use std::{iter, ptr};
|
||||
|
||||
pub(crate) mod autodiff;
|
||||
pub(crate) mod gpu_offload;
|
||||
|
||||
use libc::{c_char, c_uint, size_t};
|
||||
use rustc_abi as abi;
|
||||
|
|
@ -117,6 +118,74 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
|
|||
}
|
||||
bx
|
||||
}
|
||||
|
||||
// The generic builder has less functionality and thus (unlike the other alloca) we can not
|
||||
// easily jump to the beginning of the function to place our allocas there. We trust the user
|
||||
// to manually do that. FIXME(offload): improve the genericCx and add more llvm wrappers to
|
||||
// handle this.
|
||||
pub(crate) fn direct_alloca(&mut self, ty: &'ll Type, align: Align, name: &str) -> &'ll Value {
|
||||
let val = unsafe {
|
||||
let alloca = llvm::LLVMBuildAlloca(self.llbuilder, ty, UNNAMED);
|
||||
llvm::LLVMSetAlignment(alloca, align.bytes() as c_uint);
|
||||
// Cast to default addrspace if necessary
|
||||
llvm::LLVMBuildPointerCast(self.llbuilder, alloca, self.cx.type_ptr(), UNNAMED)
|
||||
};
|
||||
if name != "" {
|
||||
let name = std::ffi::CString::new(name).unwrap();
|
||||
llvm::set_value_name(val, &name.as_bytes());
|
||||
}
|
||||
val
|
||||
}
|
||||
|
||||
pub(crate) fn inbounds_gep(
|
||||
&mut self,
|
||||
ty: &'ll Type,
|
||||
ptr: &'ll Value,
|
||||
indices: &[&'ll Value],
|
||||
) -> &'ll Value {
|
||||
unsafe {
|
||||
llvm::LLVMBuildGEPWithNoWrapFlags(
|
||||
self.llbuilder,
|
||||
ty,
|
||||
ptr,
|
||||
indices.as_ptr(),
|
||||
indices.len() as c_uint,
|
||||
UNNAMED,
|
||||
GEPNoWrapFlags::InBounds,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn store(&mut self, val: &'ll Value, ptr: &'ll Value, align: Align) -> &'ll Value {
|
||||
debug!("Store {:?} -> {:?}", val, ptr);
|
||||
assert_eq!(self.cx.type_kind(self.cx.val_ty(ptr)), TypeKind::Pointer);
|
||||
unsafe {
|
||||
let store = llvm::LLVMBuildStore(self.llbuilder, val, ptr);
|
||||
llvm::LLVMSetAlignment(store, align.bytes() as c_uint);
|
||||
store
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value {
|
||||
unsafe {
|
||||
let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED);
|
||||
llvm::LLVMSetAlignment(load, align.bytes() as c_uint);
|
||||
load
|
||||
}
|
||||
}
|
||||
|
||||
fn memset(&mut self, ptr: &'ll Value, fill_byte: &'ll Value, size: &'ll Value, align: Align) {
|
||||
unsafe {
|
||||
llvm::LLVMRustBuildMemSet(
|
||||
self.llbuilder,
|
||||
ptr,
|
||||
align.bytes() as c_uint,
|
||||
fill_byte,
|
||||
size,
|
||||
false,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Empty string, to be used where LLVM expects an instruction name, indicating
|
||||
|
|
|
|||
|
|
@ -211,7 +211,7 @@ pub(crate) unsafe fn create_module<'ll>(
|
|||
|
||||
// Ensure the data-layout values hardcoded remain the defaults.
|
||||
{
|
||||
let tm = crate::back::write::create_informational_target_machine(tcx.sess, false);
|
||||
let tm = crate::back::write::create_informational_target_machine(sess, false);
|
||||
unsafe {
|
||||
llvm::LLVMRustSetDataLayoutFromTargetMachine(llmod, tm.raw());
|
||||
}
|
||||
|
|
@ -680,6 +680,22 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
|
|||
unsafe { llvm::LLVMConstInt(ty, val, llvm::False) }
|
||||
}
|
||||
|
||||
pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value {
|
||||
self.get_const_int(self.type_i64(), n)
|
||||
}
|
||||
|
||||
pub(crate) fn get_const_i32(&self, n: u64) -> &'ll Value {
|
||||
self.get_const_int(self.type_i32(), n)
|
||||
}
|
||||
|
||||
pub(crate) fn get_const_i16(&self, n: u64) -> &'ll Value {
|
||||
self.get_const_int(self.type_i16(), n)
|
||||
}
|
||||
|
||||
pub(crate) fn get_const_i8(&self, n: u64) -> &'ll Value {
|
||||
self.get_const_int(self.type_i8(), n)
|
||||
}
|
||||
|
||||
pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
|
||||
let name = SmallCStr::new(name);
|
||||
unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use libc::{c_char, c_uint};
|
|||
|
||||
use super::MetadataKindId;
|
||||
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
|
||||
use crate::llvm::Bool;
|
||||
use crate::llvm::{Bool, Builder};
|
||||
|
||||
#[link(name = "llvm-wrapper", kind = "static")]
|
||||
unsafe extern "C" {
|
||||
|
|
@ -31,6 +31,14 @@ unsafe extern "C" {
|
|||
index: c_uint,
|
||||
kind: AttributeKind,
|
||||
);
|
||||
pub(crate) fn LLVMRustPositionBefore<'a>(B: &'a Builder<'_>, I: &'a Value);
|
||||
pub(crate) fn LLVMRustPositionAfter<'a>(B: &'a Builder<'_>, I: &'a Value);
|
||||
pub(crate) fn LLVMRustGetFunctionCall(
|
||||
F: &Value,
|
||||
name: *const c_char,
|
||||
NameLen: libc::size_t,
|
||||
) -> Option<&Value>;
|
||||
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
|
|
|
|||
|
|
@ -1138,6 +1138,11 @@ unsafe extern "C" {
|
|||
Count: c_uint,
|
||||
Packed: Bool,
|
||||
) -> &'a Value;
|
||||
pub(crate) fn LLVMConstNamedStruct<'a>(
|
||||
StructTy: &'a Type,
|
||||
ConstantVals: *const &'a Value,
|
||||
Count: c_uint,
|
||||
) -> &'a Value;
|
||||
pub(crate) fn LLVMConstVector(ScalarConstantVals: *const &Value, Size: c_uint) -> &Value;
|
||||
|
||||
// Constant expressions
|
||||
|
|
@ -1217,6 +1222,8 @@ unsafe extern "C" {
|
|||
) -> &'a BasicBlock;
|
||||
|
||||
// Operations on instructions
|
||||
pub(crate) fn LLVMGetInstructionParent(Inst: &Value) -> &BasicBlock;
|
||||
pub(crate) fn LLVMGetCalledValue(CallInst: &Value) -> Option<&Value>;
|
||||
pub(crate) fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>;
|
||||
pub(crate) fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock;
|
||||
pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>;
|
||||
|
|
@ -2556,6 +2563,7 @@ unsafe extern "C" {
|
|||
|
||||
pub(crate) fn LLVMRustSetDataLayoutFromTargetMachine<'a>(M: &'a Module, TM: &'a TargetMachine);
|
||||
|
||||
pub(crate) fn LLVMRustPositionBuilderPastAllocas<'a>(B: &Builder<'a>, Fn: &'a Value);
|
||||
pub(crate) fn LLVMRustPositionBuilderAtStart<'a>(B: &Builder<'a>, BB: &'a BasicBlock);
|
||||
|
||||
pub(crate) fn LLVMRustSetModulePICLevel(M: &Module);
|
||||
|
|
|
|||
|
|
@ -1591,12 +1591,49 @@ extern "C" LLVMValueRef LLVMRustBuildMemSet(LLVMBuilderRef B, LLVMValueRef Dst,
|
|||
MaybeAlign(DstAlign), IsVolatile));
|
||||
}
|
||||
|
||||
extern "C" void LLVMRustPositionBuilderPastAllocas(LLVMBuilderRef B,
|
||||
LLVMValueRef Fn) {
|
||||
Function *F = unwrap<Function>(Fn);
|
||||
unwrap(B)->SetInsertPointPastAllocas(F);
|
||||
}
|
||||
extern "C" void LLVMRustPositionBuilderAtStart(LLVMBuilderRef B,
|
||||
LLVMBasicBlockRef BB) {
|
||||
auto Point = unwrap(BB)->getFirstInsertionPt();
|
||||
unwrap(B)->SetInsertPoint(unwrap(BB), Point);
|
||||
}
|
||||
|
||||
extern "C" void LLVMRustPositionBefore(LLVMBuilderRef B, LLVMValueRef Instr) {
|
||||
if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
|
||||
unwrap(B)->SetInsertPoint(I);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void LLVMRustPositionAfter(LLVMBuilderRef B, LLVMValueRef Instr) {
|
||||
if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
|
||||
auto J = I->getNextNonDebugInstruction();
|
||||
unwrap(B)->SetInsertPoint(J);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LLVMValueRef
|
||||
LLVMRustGetFunctionCall(LLVMValueRef Fn, const char *Name, size_t NameLen) {
|
||||
auto targetName = StringRef(Name, NameLen);
|
||||
Function *F = unwrap<Function>(Fn);
|
||||
for (auto &BB : *F) {
|
||||
for (auto &I : BB) {
|
||||
if (auto *callInst = llvm::dyn_cast<llvm::CallBase>(&I)) {
|
||||
const llvm::Function *calledFunc = callInst->getCalledFunction();
|
||||
if (calledFunc && calledFunc->getName() == targetName) {
|
||||
// Found a call to the target function
|
||||
return wrap(callInst);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
extern "C" bool LLVMRustConstIntGetZExtValue(LLVMValueRef CV, uint64_t *value) {
|
||||
auto C = unwrap<llvm::ConstantInt>(CV);
|
||||
if (C->getBitWidth() > 64)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue