Add TypeTree metadata attachment for autodiff
- Add F128 support to TypeTree Kind enum - Implement TypeTree FFI bindings and conversion functions - Add typetree.rs module for metadata attachment to LLVM functions - Integrate TypeTree generation with autodiff intrinsic pipeline - Support scalar types: f32, f64, integers, f16, f128 - Attach enzyme_type attributes as LLVM string metadata for Enzyme Signed-off-by: Karan Janthe <karanjanthe@gmail.com>
This commit is contained in:
parent
e1258e79d6
commit
375e14ef49
7 changed files with 343 additions and 14 deletions
|
|
@ -31,6 +31,7 @@ pub enum Kind {
|
|||
Half,
|
||||
Float,
|
||||
Double,
|
||||
F128,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use std::ptr;
|
||||
|
||||
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
|
||||
use rustc_ast::expand::typetree::FncTree;
|
||||
use rustc_codegen_ssa::common::TypeKind;
|
||||
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
|
||||
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
|
||||
|
|
@ -294,6 +295,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
|
|||
fn_args: &[&'ll Value],
|
||||
attrs: AutoDiffAttrs,
|
||||
dest: PlaceRef<'tcx, &'ll Value>,
|
||||
fnc_tree: FncTree,
|
||||
) {
|
||||
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
|
||||
let mut ad_name: String = match attrs.mode {
|
||||
|
|
@ -370,6 +372,10 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
|
|||
fn_args,
|
||||
);
|
||||
|
||||
if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
|
||||
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
|
||||
}
|
||||
|
||||
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
|
||||
|
||||
builder.store_to_place(call, dest.val);
|
||||
|
|
|
|||
|
|
@ -1213,6 +1213,9 @@ fn codegen_autodiff<'ll, 'tcx>(
|
|||
&mut diff_attrs.input_activity,
|
||||
);
|
||||
|
||||
let fnc_tree =
|
||||
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
|
||||
|
||||
// Build body
|
||||
generate_enzyme_call(
|
||||
bx,
|
||||
|
|
@ -1223,6 +1226,7 @@ fn codegen_autodiff<'ll, 'tcx>(
|
|||
&val_arr,
|
||||
diff_attrs.clone(),
|
||||
result,
|
||||
fnc_tree,
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ mod llvm_util;
|
|||
mod mono_item;
|
||||
mod type_;
|
||||
mod type_of;
|
||||
mod typetree;
|
||||
mod va_arg;
|
||||
mod value;
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,35 @@
|
|||
use libc::{c_char, c_uint};
|
||||
|
||||
use super::MetadataKindId;
|
||||
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
|
||||
use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
|
||||
use crate::llvm::{Bool, Builder};
|
||||
|
||||
// TypeTree types
|
||||
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub(crate) struct EnzymeTypeTree {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
|
||||
#[repr(u32)]
|
||||
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
|
||||
#[allow(non_camel_case_types)]
|
||||
pub(crate) enum CConcreteType {
|
||||
DT_Anything = 0,
|
||||
DT_Integer = 1,
|
||||
DT_Pointer = 2,
|
||||
DT_Half = 3,
|
||||
DT_Float = 4,
|
||||
DT_Double = 5,
|
||||
DT_Unknown = 6,
|
||||
}
|
||||
|
||||
pub(crate) struct TypeTree {
|
||||
pub(crate) inner: CTypeTreeRef,
|
||||
}
|
||||
|
||||
#[link(name = "llvm-wrapper", kind = "static")]
|
||||
unsafe extern "C" {
|
||||
// Enzyme
|
||||
|
|
@ -68,10 +94,33 @@ pub(crate) mod Enzyme_AD {
|
|||
|
||||
use libc::c_void;
|
||||
|
||||
use super::{CConcreteType, CTypeTreeRef, Context};
|
||||
|
||||
unsafe extern "C" {
|
||||
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
|
||||
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
|
||||
}
|
||||
|
||||
// TypeTree functions
|
||||
unsafe extern "C" {
|
||||
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
|
||||
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
|
||||
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
|
||||
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
|
||||
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
|
||||
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
|
||||
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
|
||||
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
|
||||
arg1: CTypeTreeRef,
|
||||
data_layout: *const c_char,
|
||||
offset: i64,
|
||||
max_size: i64,
|
||||
add_offset: u64,
|
||||
);
|
||||
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
|
||||
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
static mut EnzymePrintPerf: c_void;
|
||||
static mut EnzymePrintActivity: c_void;
|
||||
|
|
@ -141,6 +190,57 @@ pub(crate) use self::Fallback_AD::*;
|
|||
pub(crate) mod Fallback_AD {
|
||||
#![allow(unused_variables)]
|
||||
|
||||
use libc::c_char;
|
||||
|
||||
use super::{CConcreteType, CTypeTreeRef, Context};
|
||||
|
||||
// TypeTree function fallbacks
|
||||
pub(crate) unsafe fn EnzymeNewTypeTree() -> CTypeTreeRef {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeShiftIndiciesEq(
|
||||
arg1: CTypeTreeRef,
|
||||
data_layout: *const c_char,
|
||||
offset: i64,
|
||||
max_size: i64,
|
||||
add_offset: u64,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeToStringFree(arg1: *const c_char) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) fn set_inline(val: bool) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
|
@ -169,3 +269,83 @@ pub(crate) mod Fallback_AD {
|
|||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeTree {
|
||||
pub(crate) fn new() -> TypeTree {
|
||||
let inner = unsafe { EnzymeNewTypeTree() };
|
||||
TypeTree { inner }
|
||||
}
|
||||
|
||||
pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
|
||||
let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) };
|
||||
TypeTree { inner }
|
||||
}
|
||||
|
||||
pub(crate) fn merge(self, other: Self) -> Self {
|
||||
unsafe {
|
||||
EnzymeMergeTypeTree(self.inner, other.inner);
|
||||
}
|
||||
drop(other);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub(crate) fn shift(
|
||||
self,
|
||||
layout: &str,
|
||||
offset: isize,
|
||||
max_size: isize,
|
||||
add_offset: usize,
|
||||
) -> Self {
|
||||
let layout = std::ffi::CString::new(layout).unwrap();
|
||||
|
||||
unsafe {
|
||||
EnzymeTypeTreeShiftIndiciesEq(
|
||||
self.inner,
|
||||
layout.as_ptr(),
|
||||
offset as i64,
|
||||
max_size as i64,
|
||||
add_offset as u64,
|
||||
);
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TypeTree {
|
||||
fn clone(&self) -> Self {
|
||||
let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
|
||||
TypeTree { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TypeTree {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
|
||||
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
|
||||
match cstr.to_str() {
|
||||
Ok(x) => write!(f, "{}", x)?,
|
||||
Err(err) => write!(f, "could not parse: {}", err)?,
|
||||
}
|
||||
|
||||
// delete C string pointer
|
||||
unsafe {
|
||||
EnzymeTypeTreeToStringFree(ptr);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for TypeTree {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
<Self as std::fmt::Display>::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TypeTree {
|
||||
fn drop(&mut self) {
|
||||
unsafe { EnzymeFreeTypeTree(self.inner) }
|
||||
}
|
||||
}
|
||||
|
|
|
|||
144
compiler/rustc_codegen_llvm/src/typetree.rs
Normal file
144
compiler/rustc_codegen_llvm/src/typetree.rs
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
use std::ffi::{CString, c_char, c_uint};
|
||||
|
||||
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
|
||||
|
||||
use crate::attributes;
|
||||
use crate::llvm::{self, Value};
|
||||
|
||||
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
|
||||
///
|
||||
/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree)
|
||||
/// and converts it to Enzyme's internal C++ TypeTree representation that
|
||||
/// Enzyme can understand during differentiation analysis.
|
||||
#[cfg(llvm_enzyme)]
|
||||
fn to_enzyme_typetree(
|
||||
rust_typetree: RustTypeTree,
|
||||
data_layout: &str,
|
||||
llcx: &llvm::Context,
|
||||
) -> llvm::TypeTree {
|
||||
// Start with an empty TypeTree
|
||||
let mut enzyme_tt = llvm::TypeTree::new();
|
||||
|
||||
// Convert each Type in the Rust TypeTree to Enzyme format
|
||||
for rust_type in rust_typetree.0 {
|
||||
let concrete_type = match rust_type.kind {
|
||||
rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything,
|
||||
rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer,
|
||||
rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer,
|
||||
rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half,
|
||||
rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float,
|
||||
rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double,
|
||||
rustc_ast::expand::typetree::Kind::F128 => llvm::CConcreteType::DT_Unknown,
|
||||
rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown,
|
||||
};
|
||||
|
||||
// Create a TypeTree for this specific type
|
||||
let type_tt = llvm::TypeTree::from_type(concrete_type, llcx);
|
||||
|
||||
// Apply offset if specified
|
||||
let type_tt = if rust_type.offset == -1 {
|
||||
type_tt // -1 means everywhere/no specific offset
|
||||
} else {
|
||||
// Apply specific offset positioning
|
||||
type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0)
|
||||
};
|
||||
|
||||
// Merge this type into the main TypeTree
|
||||
enzyme_tt = enzyme_tt.merge(type_tt);
|
||||
}
|
||||
|
||||
enzyme_tt
|
||||
}
|
||||
|
||||
#[cfg(not(llvm_enzyme))]
|
||||
fn to_enzyme_typetree(
|
||||
_rust_typetree: RustTypeTree,
|
||||
_data_layout: &str,
|
||||
_llcx: &llvm::Context,
|
||||
) -> ! {
|
||||
unimplemented!("TypeTree conversion not available without llvm_enzyme support")
|
||||
}
|
||||
|
||||
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
|
||||
#[cfg(llvm_enzyme)]
|
||||
pub(crate) fn add_tt<'ll>(
|
||||
llmod: &'ll llvm::Module,
|
||||
llcx: &'ll llvm::Context,
|
||||
fn_def: &'ll Value,
|
||||
tt: FncTree,
|
||||
) {
|
||||
let inputs = tt.args;
|
||||
let ret_tt: RustTypeTree = tt.ret;
|
||||
|
||||
// Get LLVM data layout string for TypeTree conversion
|
||||
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
|
||||
let llvm_data_layout =
|
||||
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
|
||||
.expect("got a non-UTF8 data-layout from LLVM");
|
||||
|
||||
// Attribute name that Enzyme recognizes for TypeTree information
|
||||
let attr_name = "enzyme_type";
|
||||
let c_attr_name = CString::new(attr_name).unwrap();
|
||||
|
||||
// Attach TypeTree attributes to each input parameter
|
||||
// Enzyme uses these to understand parameter memory layouts during differentiation
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
unsafe {
|
||||
// Convert Rust TypeTree to Enzyme's internal format
|
||||
let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
|
||||
|
||||
// Serialize TypeTree to string format that Enzyme can parse
|
||||
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
|
||||
let c_str = std::ffi::CStr::from_ptr(c_str);
|
||||
|
||||
// Create LLVM string attribute with TypeTree information
|
||||
let attr = llvm::LLVMCreateStringAttribute(
|
||||
llcx,
|
||||
c_attr_name.as_ptr(),
|
||||
c_attr_name.as_bytes().len() as c_uint,
|
||||
c_str.as_ptr(),
|
||||
c_str.to_bytes().len() as c_uint,
|
||||
);
|
||||
|
||||
// Attach attribute to the specific function parameter
|
||||
// Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments
|
||||
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]);
|
||||
|
||||
// Free the C string to prevent memory leaks
|
||||
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
// Attach TypeTree attribute to the return type
|
||||
// Enzyme needs this to understand how to handle return value derivatives
|
||||
unsafe {
|
||||
let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
|
||||
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
|
||||
let c_str = std::ffi::CStr::from_ptr(c_str);
|
||||
|
||||
let ret_attr = llvm::LLVMCreateStringAttribute(
|
||||
llcx,
|
||||
c_attr_name.as_ptr(),
|
||||
c_attr_name.as_bytes().len() as c_uint,
|
||||
c_str.as_ptr(),
|
||||
c_str.to_bytes().len() as c_uint,
|
||||
);
|
||||
|
||||
// Attach to function return type
|
||||
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
|
||||
|
||||
// Free the C string
|
||||
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback implementation when Enzyme is not available
|
||||
#[cfg(not(llvm_enzyme))]
|
||||
pub(crate) fn add_tt<'ll>(
|
||||
_llmod: &'ll llvm::Module,
|
||||
_llcx: &'ll llvm::Context,
|
||||
_fn_def: &'ll Value,
|
||||
_tt: FncTree,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
|
@ -2251,36 +2251,29 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
|
|||
|
||||
/// Generate TypeTree for a specific type.
|
||||
/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
|
||||
fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
||||
// Handle basic scalar types
|
||||
pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
||||
if ty.is_scalar() {
|
||||
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
|
||||
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
|
||||
} else if ty.is_floating_point() {
|
||||
match ty {
|
||||
x if x == tcx.types.f16 => (Kind::Half, 2),
|
||||
x if x == tcx.types.f32 => (Kind::Float, 4),
|
||||
x if x == tcx.types.f64 => (Kind::Double, 8),
|
||||
_ => return TypeTree::new(), // Unknown float type
|
||||
x if x == tcx.types.f128 => (Kind::F128, 16),
|
||||
_ => return TypeTree::new(),
|
||||
}
|
||||
} else {
|
||||
// TODO(KMJ-007): Handle other scalar types if needed
|
||||
return TypeTree::new();
|
||||
};
|
||||
|
||||
return TypeTree(vec![Type {
|
||||
offset: -1,
|
||||
size,
|
||||
kind,
|
||||
child: TypeTree::new()
|
||||
}]);
|
||||
|
||||
return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]);
|
||||
}
|
||||
|
||||
// Handle references and pointers
|
||||
if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
|
||||
let inner_ty = if let Some(inner) = ty.builtin_deref(true) {
|
||||
inner
|
||||
} else {
|
||||
// TODO(KMJ-007): Handle complex pointer types
|
||||
return TypeTree::new();
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue