Rollup merge of #139308 - Shourya742:2025-03-29-add-autodiff-inline, r=ZuseZ4
add autodiff inline closes: #138920 r? ```@ZuseZ4``` try-job: dist-aarch64-linux
This commit is contained in:
commit
d4845e1b0b
9 changed files with 145 additions and 2 deletions
|
|
@ -1,5 +1,4 @@
|
|||
//! Set and unset common attributes on LLVM values.
|
||||
|
||||
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
|
||||
use rustc_codegen_ssa::traits::*;
|
||||
use rustc_hir::def_id::DefId;
|
||||
|
|
@ -28,6 +27,22 @@ pub(crate) fn apply_to_callsite(callsite: &Value, idx: AttributePlace, attrs: &[
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn has_attr(llfn: &Value, idx: AttributePlace, attr: AttributeKind) -> bool {
|
||||
llvm::HasAttributeAtIndex(llfn, idx, attr)
|
||||
}
|
||||
|
||||
pub(crate) fn has_string_attr(llfn: &Value, name: &str) -> bool {
|
||||
llvm::HasStringAttribute(llfn, name)
|
||||
}
|
||||
|
||||
pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: AttributeKind) {
|
||||
llvm::RemoveRustEnumAttributeAtIndex(llfn, place, kind);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: &str) {
|
||||
llvm::RemoveStringAttrFromFn(llfn, name);
|
||||
}
|
||||
|
||||
/// Get LLVM attribute for the provided inline heuristic.
|
||||
#[inline]
|
||||
fn inline_attr<'ll>(cx: &CodegenCx<'ll, '_>, inline: InlineAttr) -> Option<&'ll Attribute> {
|
||||
|
|
|
|||
|
|
@ -28,8 +28,9 @@ use crate::back::write::{
|
|||
use crate::errors::{
|
||||
DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro,
|
||||
};
|
||||
use crate::llvm::AttributePlace::Function;
|
||||
use crate::llvm::{self, build_string};
|
||||
use crate::{LlvmCodegenBackend, ModuleLlvm};
|
||||
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};
|
||||
|
||||
/// We keep track of the computed LTO cache keys from the previous
|
||||
/// session to determine which CGUs we can reuse.
|
||||
|
|
@ -666,6 +667,31 @@ pub(crate) fn run_pass_manager(
|
|||
}
|
||||
|
||||
if cfg!(llvm_enzyme) && enable_ad && !thin {
|
||||
let cx =
|
||||
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
|
||||
|
||||
for function in cx.get_functions() {
|
||||
let enzyme_marker = "enzyme_marker";
|
||||
if attributes::has_string_attr(function, enzyme_marker) {
|
||||
// Sanity check: Ensure 'noinline' is present before replacing it.
|
||||
assert!(
|
||||
!attributes::has_attr(function, Function, llvm::AttributeKind::NoInline),
|
||||
"Expected __enzyme function to have 'noinline' before adding 'alwaysinline'"
|
||||
);
|
||||
|
||||
attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline);
|
||||
attributes::remove_string_attr_from_llfn(function, enzyme_marker);
|
||||
|
||||
assert!(
|
||||
!attributes::has_string_attr(function, enzyme_marker),
|
||||
"Expected function to not have 'enzyme_marker'"
|
||||
);
|
||||
|
||||
let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx);
|
||||
attributes::apply_to_llfn(function, Function, &[always_inline]);
|
||||
}
|
||||
}
|
||||
|
||||
let opt_stage = llvm::OptStage::FatLTO;
|
||||
let stage = write::AutodiffStage::PostAD;
|
||||
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
|
||||
|
|
|
|||
|
|
@ -361,6 +361,11 @@ fn generate_enzyme_call<'ll>(
|
|||
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
|
||||
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
|
||||
|
||||
// We add a made-up attribute just such that we can recognize it after AD to update
|
||||
// (no)-inline attributes. We'll then also remove this attribute.
|
||||
let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
|
||||
attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]);
|
||||
|
||||
// first, remove all calls from fnc
|
||||
let entry = llvm::LLVMGetFirstBasicBlock(outer_fn);
|
||||
let br = llvm::LLVMRustGetTerminator(entry);
|
||||
|
|
|
|||
|
|
@ -698,6 +698,16 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
|
|||
llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len())
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn get_functions(&self) -> Vec<&'ll Value> {
|
||||
let mut functions = vec![];
|
||||
let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) };
|
||||
while let Some(f) = func {
|
||||
functions.push(f);
|
||||
func = unsafe { llvm::LLVMGetNextFunction(f) }
|
||||
}
|
||||
functions
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,19 @@ unsafe extern "C" {
|
|||
pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
|
||||
pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
|
||||
pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
|
||||
pub(crate) fn LLVMRustHasFnAttribute(
|
||||
F: &Value,
|
||||
Name: *const c_char,
|
||||
NameLen: libc::size_t,
|
||||
) -> bool;
|
||||
pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t);
|
||||
pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>;
|
||||
pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>;
|
||||
pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex(
|
||||
Fn: &Value,
|
||||
index: c_uint,
|
||||
kind: AttributeKind,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
|
|
|
|||
|
|
@ -41,6 +41,32 @@ pub(crate) fn AddFunctionAttributes<'ll>(
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn HasAttributeAtIndex<'ll>(
|
||||
llfn: &'ll Value,
|
||||
idx: AttributePlace,
|
||||
kind: AttributeKind,
|
||||
) -> bool {
|
||||
unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) }
|
||||
}
|
||||
|
||||
pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: &str) -> bool {
|
||||
unsafe { LLVMRustHasFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
|
||||
}
|
||||
|
||||
pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: &str) {
|
||||
unsafe { LLVMRustRemoveFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
|
||||
}
|
||||
|
||||
pub(crate) fn RemoveRustEnumAttributeAtIndex(
|
||||
llfn: &Value,
|
||||
place: AttributePlace,
|
||||
kind: AttributeKind,
|
||||
) {
|
||||
unsafe {
|
||||
LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn AddCallSiteAttributes<'ll>(
|
||||
callsite: &'ll Value,
|
||||
idx: AttributePlace,
|
||||
|
|
|
|||
|
|
@ -128,6 +128,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
|
|||
(**self).borrow().llcx
|
||||
}
|
||||
|
||||
pub(crate) fn llmod(&self) -> &'ll llvm::Module {
|
||||
(**self).borrow().llmod
|
||||
}
|
||||
|
||||
pub(crate) fn isize_ty(&self) -> &'ll Type {
|
||||
(**self).borrow().isize_ty
|
||||
}
|
||||
|
|
|
|||
|
|
@ -973,6 +973,27 @@ extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadata(LLVMValueRef x) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
extern "C" void
|
||||
LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index,
|
||||
LLVMRustAttributeKind RustAttr) {
|
||||
LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr));
|
||||
}
|
||||
|
||||
extern "C" bool LLVMRustHasFnAttribute(LLVMValueRef F, const char *Name,
|
||||
size_t NameLen) {
|
||||
if (auto *Fn = dyn_cast<Function>(unwrap<Value>(F))) {
|
||||
return Fn->hasFnAttribute(StringRef(Name, NameLen));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
extern "C" void LLVMRustRemoveFnAttribute(LLVMValueRef Fn, const char *Name,
|
||||
size_t NameLen) {
|
||||
if (auto *F = dyn_cast<Function>(unwrap<Value>(Fn))) {
|
||||
F->removeFnAttr(StringRef(Name, NameLen));
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" void LLVMRustGlobalAddMetadata(LLVMValueRef Global, unsigned Kind,
|
||||
LLVMMetadataRef MD) {
|
||||
unwrap<GlobalObject>(Global)->addMetadata(Kind, *unwrap<MDNode>(MD));
|
||||
|
|
|
|||
23
tests/codegen/autodiff/inline.rs
Normal file
23
tests/codegen/autodiff/inline.rs
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -Zautodiff=NoPostopt
|
||||
//@ no-prefer-dynamic
|
||||
//@ needs-enzyme
|
||||
|
||||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff;
|
||||
|
||||
#[autodiff(d_square, Reverse, Duplicated, Active)]
|
||||
fn square(x: &f64) -> f64 {
|
||||
x * x
|
||||
}
|
||||
|
||||
// CHECK: ; inline::d_square
|
||||
// CHECK-NEXT: ; Function Attrs: alwaysinline
|
||||
// CHECK-NOT: noinline
|
||||
// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE
|
||||
fn main() {
|
||||
let x = std::hint::black_box(3.0);
|
||||
let mut dx1 = std::hint::black_box(1.0);
|
||||
let _ = d_square(&x, &mut dx1, 1.0);
|
||||
assert_eq!(dx1, 6.0);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue