Rollup merge of #144197 - KMJ-007:type-tree, r=ZuseZ4
TypeTree support in autodiff
# TypeTrees for Autodiff
## What are TypeTrees?
TypeTrees are memory layout descriptors for Enzyme. They tell Enzyme exactly how types are structured in memory so it can compute derivatives efficiently.
## Structure
```rust
TypeTree(Vec<Type>)
Type {
offset: isize, // byte offset (-1 = everywhere)
size: usize, // size in bytes
kind: Kind, // Float, Integer, Pointer, etc.
child: TypeTree // nested structure
}
```
## Example: `fn compute(x: &f32, data: &[f32]) -> f32`
**Input 0: `x: &f32`**
```rust
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, size: 4, kind: Float,
child: TypeTree::new()
}])
}])
```
**Input 1: `data: &[f32]`**
```rust
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, size: 4, kind: Float, // -1 = all elements
child: TypeTree::new()
}])
}])
```
**Output: `f32`**
```rust
TypeTree(vec![Type {
offset: -1, size: 4, kind: Float,
child: TypeTree::new()
}])
```
## Why Needed?
- Enzyme can't deduce complex type layouts from LLVM IR
- Prevents slow memory pattern analysis
- Enables correct derivative computation for nested structures
- Tells Enzyme which bytes are differentiable vs metadata
## What Enzyme Does With This Information:
Without TypeTrees (current state):
```llvm
; Enzyme sees generic LLVM IR:
define float @distance(ptr* %p1, ptr* %p2) {
; Has to guess what these pointers point to
; Slow analysis of all memory operations
; May miss optimization opportunities
}
```
With TypeTrees (our implementation):
```llvm
define "enzyme_type"="{[]:Float@float}" float @distance(
ptr "enzyme_type"="{[]:Pointer}" %p1,
ptr "enzyme_type"="{[]:Pointer}" %p2
) {
; Enzyme knows exact type layout
; Can generate efficient derivative code directly
}
```
# TypeTrees - Offset and -1 Explained
## Type Structure
```rust
Type {
offset: isize, // WHERE this type starts
size: usize, // HOW BIG this type is
kind: Kind, // WHAT KIND of data (Float, Integer, Pointer)
child: TypeTree // WHAT'S INSIDE (for pointers/containers)
}
```
## Offset Values
### Regular Offset (0, 4, 8, etc.)
**Specific byte position within a structure**
```rust
struct Point {
x: f32, // offset 0, size 4
y: f32, // offset 4, size 4
id: i32, // offset 8, size 4
}
```
TypeTree for `&Point` (internal representation):
```rust
TypeTree(vec![
Type { offset: 0, size: 4, kind: Float }, // x at byte 0
Type { offset: 4, size: 4, kind: Float }, // y at byte 4
Type { offset: 8, size: 4, kind: Integer } // id at byte 8
])
```
Generates LLVM:
```llvm
"enzyme_type"="{[]:Float@float}"
```
### Offset -1 (Special: "Everywhere")
**Means "this pattern repeats for ALL elements"**
#### Example 1: Array `[f32; 100]`
```rust
TypeTree(vec![Type {
offset: -1, // ALL positions
size: 4, // each f32 is 4 bytes
kind: Float, // every element is float
}])
```
This single entry replaces listing 100 separate Types with offsets `0, 4, 8, 12, ..., 396`.
#### Example 2: Slice `&[i32]`
```rust
// Pointer to slice data
TypeTree(vec![Type {
offset: -1, size: 8, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, // ALL slice elements
size: 4, // each i32 is 4 bytes
kind: Integer
}])
}])
```
#### Example 3: Mixed Structure
```rust
struct Container {
header: i64, // offset 0
data: [f32; 1000], // offset 8, but elements use -1
}
```
```rust
TypeTree(vec![
Type { offset: 0, size: 8, kind: Integer }, // header
Type { offset: 8, size: 4000, kind: Pointer,
child: TypeTree(vec![Type {
offset: -1, size: 4, kind: Float // ALL array elements
}])
}
])
```
This commit is contained in:
commit
c29fb2e57e
68 changed files with 1250 additions and 14 deletions
|
|
@ -6,6 +6,7 @@
|
|||
use std::fmt::{self, Display, Formatter};
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::expand::typetree::TypeTree;
|
||||
use crate::expand::{Decodable, Encodable, HashStable_Generic};
|
||||
use crate::{Ty, TyKind};
|
||||
|
||||
|
|
@ -84,6 +85,8 @@ pub struct AutoDiffItem {
|
|||
/// The name of the function being generated
|
||||
pub target: String,
|
||||
pub attrs: AutoDiffAttrs,
|
||||
pub inputs: Vec<TypeTree>,
|
||||
pub output: TypeTree,
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||
|
|
@ -275,14 +278,22 @@ impl AutoDiffAttrs {
|
|||
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
|
||||
}
|
||||
|
||||
pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
|
||||
AutoDiffItem { source, target, attrs: self }
|
||||
pub fn into_item(
|
||||
self,
|
||||
source: String,
|
||||
target: String,
|
||||
inputs: Vec<TypeTree>,
|
||||
output: TypeTree,
|
||||
) -> AutoDiffItem {
|
||||
AutoDiffItem { source, target, inputs, output, attrs: self }
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AutoDiffItem {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
|
||||
write!(f, " with attributes: {:?}", self.attrs)
|
||||
write!(f, " with attributes: {:?}", self.attrs)?;
|
||||
write!(f, " with inputs: {:?}", self.inputs)?;
|
||||
write!(f, " with output: {:?}", self.output)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ pub enum Kind {
|
|||
Half,
|
||||
Float,
|
||||
Double,
|
||||
F128,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
|
|||
_src_align: Align,
|
||||
size: RValue<'gcc>,
|
||||
flags: MemFlags,
|
||||
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
|
||||
) {
|
||||
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
|
||||
let size = self.intcast(size, self.type_size_t(), false);
|
||||
|
|
|
|||
|
|
@ -770,6 +770,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
|
|||
scratch_align,
|
||||
bx.const_usize(self.layout.size.bytes()),
|
||||
MemFlags::empty(),
|
||||
None,
|
||||
);
|
||||
|
||||
bx.lifetime_end(scratch, scratch_size);
|
||||
|
|
|
|||
|
|
@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
|
|||
scratch_align,
|
||||
bx.const_usize(copy_bytes),
|
||||
MemFlags::empty(),
|
||||
None,
|
||||
);
|
||||
bx.lifetime_end(llscratch, scratch_size);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -563,6 +563,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
|
|||
config::AutoDiff::Enable => {}
|
||||
// We handle this below
|
||||
config::AutoDiff::NoPostopt => {}
|
||||
// Disables TypeTree generation
|
||||
config::AutoDiff::NoTT => {}
|
||||
}
|
||||
}
|
||||
// This helps with handling enums for now.
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
|
|||
use std::ops::Deref;
|
||||
use std::{iter, ptr};
|
||||
|
||||
use rustc_ast::expand::typetree::FncTree;
|
||||
pub(crate) mod autodiff;
|
||||
pub(crate) mod gpu_offload;
|
||||
|
||||
|
|
@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
|
|||
src_align: Align,
|
||||
size: &'ll Value,
|
||||
flags: MemFlags,
|
||||
tt: Option<FncTree>,
|
||||
) {
|
||||
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
|
||||
let size = self.intcast(size, self.type_isize(), false);
|
||||
let is_volatile = flags.contains(MemFlags::VOLATILE);
|
||||
unsafe {
|
||||
let memcpy = unsafe {
|
||||
llvm::LLVMRustBuildMemCpy(
|
||||
self.llbuilder,
|
||||
dst,
|
||||
|
|
@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
|
|||
src_align.bytes() as c_uint,
|
||||
size,
|
||||
is_volatile,
|
||||
);
|
||||
)
|
||||
};
|
||||
|
||||
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
|
||||
// a memcpy during autodiff, it needs to know the structure of the data being
|
||||
// copied to properly track derivatives. For example, copying an array of floats
|
||||
// vs. copying a struct with mixed types requires different derivative handling.
|
||||
// The TypeTree tells Enzyme exactly what memory layout to expect.
|
||||
if let Some(tt) = tt {
|
||||
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use std::ptr;
|
||||
|
||||
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
|
||||
use rustc_ast::expand::typetree::FncTree;
|
||||
use rustc_codegen_ssa::common::TypeKind;
|
||||
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
|
||||
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
|
||||
|
|
@ -294,6 +295,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
|
|||
fn_args: &[&'ll Value],
|
||||
attrs: AutoDiffAttrs,
|
||||
dest: PlaceRef<'tcx, &'ll Value>,
|
||||
fnc_tree: FncTree,
|
||||
) {
|
||||
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
|
||||
let mut ad_name: String = match attrs.mode {
|
||||
|
|
@ -370,6 +372,10 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
|
|||
fn_args,
|
||||
);
|
||||
|
||||
if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
|
||||
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
|
||||
}
|
||||
|
||||
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
|
||||
|
||||
builder.store_to_place(call, dest.val);
|
||||
|
|
|
|||
|
|
@ -1212,6 +1212,9 @@ fn codegen_autodiff<'ll, 'tcx>(
|
|||
&mut diff_attrs.input_activity,
|
||||
);
|
||||
|
||||
let fnc_tree =
|
||||
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
|
||||
|
||||
// Build body
|
||||
generate_enzyme_call(
|
||||
bx,
|
||||
|
|
@ -1222,6 +1225,7 @@ fn codegen_autodiff<'ll, 'tcx>(
|
|||
&val_arr,
|
||||
diff_attrs.clone(),
|
||||
result,
|
||||
fnc_tree,
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ mod llvm_util;
|
|||
mod mono_item;
|
||||
mod type_;
|
||||
mod type_of;
|
||||
mod typetree;
|
||||
mod va_arg;
|
||||
mod value;
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,36 @@
|
|||
use libc::{c_char, c_uint};
|
||||
|
||||
use super::MetadataKindId;
|
||||
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
|
||||
use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
|
||||
use crate::llvm::{Bool, Builder};
|
||||
|
||||
// TypeTree types
|
||||
/// Raw pointer to an Enzyme type tree, as exchanged with the Enzyme FFI functions.
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;

/// Zero-sized placeholder for the pointee of [`CTypeTreeRef`].
///
/// The only field is a private, empty byte array, so Rust code never reads or writes
/// through this type; it exists to give the raw pointer a distinct pointee type.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub(crate) struct EnzymeTypeTree {
    _unused: [u8; 0],
}
|
||||
|
||||
/// Concrete type tags passed to the Enzyme type-tree FFI functions.
///
/// The discriminants are fixed explicitly and the enum is `#[repr(u32)]`; values 7 and 8
/// are not declared here.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
// Variant names keep the `DT_` prefix (presumably mirroring the C-side enum — confirm).
#[allow(non_camel_case_types)]
pub(crate) enum CConcreteType {
    DT_Anything = 0,
    DT_Integer = 1,
    DT_Pointer = 2,
    DT_Half = 3,
    DT_Float = 4,
    DT_Double = 5,
    DT_Unknown = 6,
    DT_FP128 = 9,
}
|
||||
|
||||
pub(crate) struct TypeTree {
|
||||
pub(crate) inner: CTypeTreeRef,
|
||||
}
|
||||
|
||||
#[link(name = "llvm-wrapper", kind = "static")]
|
||||
unsafe extern "C" {
|
||||
// Enzyme
|
||||
|
|
@ -68,10 +95,40 @@ pub(crate) mod Enzyme_AD {
|
|||
|
||||
use libc::c_void;
|
||||
|
||||
use super::{CConcreteType, CTypeTreeRef, Context};
|
||||
|
||||
unsafe extern "C" {
|
||||
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
|
||||
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
|
||||
}
|
||||
|
||||
// TypeTree functions
|
||||
unsafe extern "C" {
|
||||
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
|
||||
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
|
||||
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
|
||||
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
|
||||
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
|
||||
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
|
||||
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
|
||||
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
|
||||
arg1: CTypeTreeRef,
|
||||
data_layout: *const c_char,
|
||||
offset: i64,
|
||||
max_size: i64,
|
||||
add_offset: u64,
|
||||
);
|
||||
pub(crate) fn EnzymeTypeTreeInsertEq(
|
||||
CTT: CTypeTreeRef,
|
||||
indices: *const i64,
|
||||
len: usize,
|
||||
ct: CConcreteType,
|
||||
ctx: &Context,
|
||||
);
|
||||
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
|
||||
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
static mut EnzymePrintPerf: c_void;
|
||||
static mut EnzymePrintActivity: c_void;
|
||||
|
|
@ -141,6 +198,67 @@ pub(crate) use self::Fallback_AD::*;
|
|||
pub(crate) mod Fallback_AD {
|
||||
#![allow(unused_variables)]
|
||||
|
||||
use libc::c_char;
|
||||
|
||||
use super::{CConcreteType, CTypeTreeRef, Context};
|
||||
|
||||
// TypeTree function fallbacks
|
||||
pub(crate) unsafe fn EnzymeNewTypeTree() -> CTypeTreeRef {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeShiftIndiciesEq(
|
||||
arg1: CTypeTreeRef,
|
||||
data_layout: *const c_char,
|
||||
offset: i64,
|
||||
max_size: i64,
|
||||
add_offset: u64,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeInsertEq(
|
||||
CTT: CTypeTreeRef,
|
||||
indices: *const i64,
|
||||
len: usize,
|
||||
ct: CConcreteType,
|
||||
ctx: &Context,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn EnzymeTypeTreeToStringFree(arg1: *const c_char) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub(crate) fn set_inline(val: bool) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
|
@ -169,3 +287,89 @@ pub(crate) mod Fallback_AD {
|
|||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeTree {
|
||||
pub(crate) fn new() -> TypeTree {
|
||||
let inner = unsafe { EnzymeNewTypeTree() };
|
||||
TypeTree { inner }
|
||||
}
|
||||
|
||||
pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
|
||||
let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) };
|
||||
TypeTree { inner }
|
||||
}
|
||||
|
||||
pub(crate) fn merge(self, other: Self) -> Self {
|
||||
unsafe {
|
||||
EnzymeMergeTypeTree(self.inner, other.inner);
|
||||
}
|
||||
drop(other);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub(crate) fn shift(
|
||||
self,
|
||||
layout: &str,
|
||||
offset: isize,
|
||||
max_size: isize,
|
||||
add_offset: usize,
|
||||
) -> Self {
|
||||
let layout = std::ffi::CString::new(layout).unwrap();
|
||||
|
||||
unsafe {
|
||||
EnzymeTypeTreeShiftIndiciesEq(
|
||||
self.inner,
|
||||
layout.as_ptr(),
|
||||
offset as i64,
|
||||
max_size as i64,
|
||||
add_offset as u64,
|
||||
);
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
|
||||
unsafe {
|
||||
EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TypeTree {
|
||||
fn clone(&self) -> Self {
|
||||
let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
|
||||
TypeTree { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TypeTree {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
|
||||
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
|
||||
match cstr.to_str() {
|
||||
Ok(x) => write!(f, "{}", x)?,
|
||||
Err(err) => write!(f, "could not parse: {}", err)?,
|
||||
}
|
||||
|
||||
// delete C string pointer
|
||||
unsafe {
|
||||
EnzymeTypeTreeToStringFree(ptr);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for TypeTree {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
<Self as std::fmt::Display>::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TypeTree {
|
||||
fn drop(&mut self) {
|
||||
unsafe { EnzymeFreeTypeTree(self.inner) }
|
||||
}
|
||||
}
|
||||
|
|
|
|||
122
compiler/rustc_codegen_llvm/src/typetree.rs
Normal file
122
compiler/rustc_codegen_llvm/src/typetree.rs
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
use rustc_ast::expand::typetree::FncTree;
|
||||
#[cfg(feature = "llvm_enzyme")]
|
||||
use {
|
||||
crate::attributes,
|
||||
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
|
||||
std::ffi::{CString, c_char, c_uint},
|
||||
};
|
||||
|
||||
use crate::llvm::{self, Value};
|
||||
|
||||
/// Converts a Rust-side type tree into an Enzyme type tree.
///
/// Starts from an empty Enzyme tree and fills it by walking `rust_typetree` with an empty
/// parent index path. `_data_layout` is accepted but not used.
#[cfg(feature = "llvm_enzyme")]
fn to_enzyme_typetree(
    rust_typetree: RustTypeTree,
    _data_layout: &str,
    llcx: &llvm::Context,
) -> llvm::TypeTree {
    let root_path: &[i64] = &[];
    let mut tree = llvm::TypeTree::new();
    process_typetree_recursive(&mut tree, &rust_typetree, root_path, llcx);
    tree
}
|
||||
/// Inserts every entry of `rust_typetree` into `enzyme_tt`, recursing into pointer children.
///
/// For each entry the index path is `parent_indices` followed by the entry's own offset
/// (`-1` is passed through unchanged). Entries of kind `Pointer` with a non-empty child
/// tree are descended into, using the entry's full path as the new parent path.
#[cfg(feature = "llvm_enzyme")]
fn process_typetree_recursive(
    enzyme_tt: &mut llvm::TypeTree,
    rust_typetree: &RustTypeTree,
    parent_indices: &[i64],
    llcx: &llvm::Context,
) {
    use rustc_ast::expand::typetree::Kind;

    for rust_type in &rust_typetree.0 {
        let concrete_type = match rust_type.kind {
            Kind::Anything => llvm::CConcreteType::DT_Anything,
            Kind::Integer => llvm::CConcreteType::DT_Integer,
            Kind::Pointer => llvm::CConcreteType::DT_Pointer,
            Kind::Half => llvm::CConcreteType::DT_Half,
            Kind::Float => llvm::CConcreteType::DT_Float,
            Kind::Double => llvm::CConcreteType::DT_Double,
            Kind::F128 => llvm::CConcreteType::DT_FP128,
            Kind::Unknown => llvm::CConcreteType::DT_Unknown,
        };

        // The previous three-way branch (non-empty parent / offset == -1 / otherwise)
        // pushed the same value in every arm, so a single push is equivalent.
        let mut indices = Vec::with_capacity(parent_indices.len() + 1);
        indices.extend_from_slice(parent_indices);
        indices.push(rust_type.offset as i64);

        enzyme_tt.insert(&indices, concrete_type, llcx);

        if rust_type.kind == Kind::Pointer && !rust_type.child.0.is_empty() {
            process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
        }
    }
}
|
||||
|
||||
/// Attaches `"enzyme_type"` string attributes describing `tt` to `fn_def`.
///
/// One attribute is applied per entry of `tt.args` (at the argument position with the
/// same index) and one for `tt.ret` (at the return-value position). Each attribute value
/// is the string Enzyme produces for the converted type tree.
///
/// # Panics
///
/// Panics if the module's data-layout string is not valid UTF-8.
#[cfg(feature = "llvm_enzyme")]
pub(crate) fn add_tt<'ll>(
    llmod: &'ll llvm::Module,
    llcx: &'ll llvm::Context,
    fn_def: &'ll Value,
    tt: FncTree,
) {
    let inputs = tt.args;
    let ret_tt: RustTypeTree = tt.ret;

    let raw_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
    let llvm_data_layout =
        std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(raw_layout) }.to_bytes())
            .expect("got a non-UTF8 data-layout from LLVM");

    let c_attr_name = CString::new("enzyme_type").unwrap();

    // Shared by arguments and the return value: convert the tree, render it to a string,
    // wrap that string in an attribute at `place`, then release the string.
    let attach = |place: llvm::AttributePlace, tree: RustTypeTree| unsafe {
        let enzyme_tt = to_enzyme_typetree(tree, llvm_data_layout, llcx);
        let rendered = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
        let rendered = std::ffi::CStr::from_ptr(rendered);

        let attr = llvm::LLVMCreateStringAttribute(
            llcx,
            c_attr_name.as_ptr(),
            c_attr_name.as_bytes().len() as c_uint,
            rendered.as_ptr(),
            rendered.to_bytes().len() as c_uint,
        );

        attributes::apply_to_llfn(fn_def, place, &[attr]);
        llvm::EnzymeTypeTreeToStringFree(rendered.as_ptr());
    };

    for (i, input) in inputs.iter().enumerate() {
        attach(llvm::AttributePlace::Argument(i as u32), input.clone());
    }

    attach(llvm::AttributePlace::ReturnValue, ret_tt);
}
|
||||
|
||||
#[cfg(not(feature = "llvm_enzyme"))]
|
||||
pub(crate) fn add_tt<'ll>(
|
||||
_llmod: &'ll llvm::Module,
|
||||
_llcx: &'ll llvm::Context,
|
||||
_fn_def: &'ll Value,
|
||||
_tt: FncTree,
|
||||
) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
|
@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
|
|||
src_align,
|
||||
bx.const_u32(layout.layout.size().bytes() as u32),
|
||||
MemFlags::empty(),
|
||||
None,
|
||||
);
|
||||
tmp
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -1626,6 +1626,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
|
|||
align,
|
||||
bx.const_usize(copy_bytes),
|
||||
MemFlags::empty(),
|
||||
None,
|
||||
);
|
||||
// ...and then load it with the ABI type.
|
||||
llval = load_cast(bx, cast, llscratch, scratch_align);
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
|
|||
if allow_overlap {
|
||||
bx.memmove(dst, align, src, align, size, flags);
|
||||
} else {
|
||||
bx.memcpy(dst, align, src, align, size, flags);
|
||||
bx.memcpy(dst, align, src, align, size, flags, None);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
|
|||
let align = pointee_layout.align;
|
||||
let dst = dst_val.immediate();
|
||||
let src = src_val.immediate();
|
||||
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
|
||||
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
|
||||
}
|
||||
mir::StatementKind::FakeRead(..)
|
||||
| mir::StatementKind::Retag { .. }
|
||||
|
|
|
|||
|
|
@ -451,6 +451,7 @@ pub trait BuilderMethods<'a, 'tcx>:
|
|||
src_align: Align,
|
||||
size: Self::Value,
|
||||
flags: MemFlags,
|
||||
tt: Option<rustc_ast::expand::typetree::FncTree>,
|
||||
);
|
||||
fn memmove(
|
||||
&mut self,
|
||||
|
|
@ -507,7 +508,7 @@ pub trait BuilderMethods<'a, 'tcx>:
|
|||
temp.val.store_with_flags(self, dst.with_type(layout), flags);
|
||||
} else if !layout.is_zst() {
|
||||
let bytes = self.const_usize(layout.size.bytes());
|
||||
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags);
|
||||
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -765,7 +765,7 @@ fn test_unstable_options_tracking_hash() {
|
|||
tracked!(allow_features, Some(vec![String::from("lang_items")]));
|
||||
tracked!(always_encode_mir, true);
|
||||
tracked!(assume_incomplete_release, true);
|
||||
tracked!(autodiff, vec![AutoDiff::Enable]);
|
||||
tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]);
|
||||
tracked!(binary_dep_depinfo, true);
|
||||
tracked!(box_noalias, false);
|
||||
tracked!(
|
||||
|
|
|
|||
|
|
@ -1812,3 +1812,15 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
|
|||
MD.NoHWAddress = true;
|
||||
GV.setSanitizerMetadata(MD);
|
||||
}
|
||||
|
||||
#ifdef ENZYME
|
||||
extern "C" {
|
||||
extern llvm::cl::opt<unsigned> EnzymeMaxTypeDepth;
|
||||
}
|
||||
|
||||
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; }
|
||||
#else
|
||||
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() {
|
||||
return 6; // Default fallback depth
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ pub(crate) struct OpaqueHiddenTypeMismatch<'tcx> {
|
|||
pub sub: TypeMismatchReason,
|
||||
}
|
||||
|
||||
// FIXME(autodiff): I should get used somewhere
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(middle_unsupported_union)]
|
||||
pub struct UnsupportedUnion {
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ pub use generic_args::{GenericArgKind, TermKind, *};
|
|||
pub use generics::*;
|
||||
pub use intrinsic::IntrinsicDef;
|
||||
use rustc_abi::{Align, FieldIdx, Integer, IntegerType, ReprFlags, ReprOptions, VariantIdx};
|
||||
use rustc_ast::expand::typetree::{FncTree, Kind, Type, TypeTree};
|
||||
use rustc_ast::node_id::NodeMap;
|
||||
pub use rustc_ast_ir::{Movability, Mutability, try_visit};
|
||||
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet};
|
||||
|
|
@ -62,7 +63,7 @@ pub use rustc_type_ir::solve::SizedTraitKind;
|
|||
pub use rustc_type_ir::*;
|
||||
#[allow(hidden_glob_reexports, unused_imports)]
|
||||
use rustc_type_ir::{InferCtxtLike, Interner};
|
||||
use tracing::{debug, instrument};
|
||||
use tracing::{debug, instrument, trace};
|
||||
pub use vtable::*;
|
||||
use {rustc_ast as ast, rustc_hir as hir};
|
||||
|
||||
|
|
@ -2216,3 +2217,225 @@ pub struct DestructuredConst<'tcx> {
|
|||
pub variant: Option<VariantIdx>,
|
||||
pub fields: &'tcx [ty::Const<'tcx>],
|
||||
}
|
||||
|
||||
/// Generate TypeTree information for autodiff.
|
||||
/// This function creates TypeTree metadata that describes the memory layout
|
||||
/// of function parameters and return types for Enzyme autodiff.
|
||||
pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
|
||||
// Check if TypeTrees are disabled via NoTT flag
|
||||
if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) {
|
||||
return FncTree { args: vec![], ret: TypeTree::new() };
|
||||
}
|
||||
|
||||
// Check if this is actually a function type
|
||||
if !fn_ty.is_fn() {
|
||||
return FncTree { args: vec![], ret: TypeTree::new() };
|
||||
}
|
||||
|
||||
// Get the function signature
|
||||
let fn_sig = fn_ty.fn_sig(tcx);
|
||||
let sig = tcx.instantiate_bound_regions_with_erased(fn_sig);
|
||||
|
||||
// Create TypeTrees for each input parameter
|
||||
let mut args = vec![];
|
||||
for ty in sig.inputs().iter() {
|
||||
let type_tree = typetree_from_ty(tcx, *ty);
|
||||
args.push(type_tree);
|
||||
}
|
||||
|
||||
// Create TypeTree for return type
|
||||
let ret = typetree_from_ty(tcx, sig.output());
|
||||
|
||||
FncTree { args, ret }
|
||||
}
|
||||
|
||||
/// Generate TypeTree for a specific type.
|
||||
/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
|
||||
pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
||||
let mut visited = Vec::new();
|
||||
typetree_from_ty_inner(tcx, ty, 0, &mut visited)
|
||||
}
|
||||
|
||||
/// Maximum recursion depth for TypeTree generation to prevent stack overflow
|
||||
/// from pathological deeply nested types. Combined with cycle detection.
|
||||
const MAX_TYPETREE_DEPTH: usize = 6;
|
||||
|
||||
/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
|
||||
fn typetree_from_ty_inner<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
ty: Ty<'tcx>,
|
||||
depth: usize,
|
||||
visited: &mut Vec<Ty<'tcx>>,
|
||||
) -> TypeTree {
|
||||
if depth >= MAX_TYPETREE_DEPTH {
|
||||
trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty);
|
||||
return TypeTree::new();
|
||||
}
|
||||
|
||||
if visited.contains(&ty) {
|
||||
return TypeTree::new();
|
||||
}
|
||||
|
||||
visited.push(ty);
|
||||
let result = typetree_from_ty_impl(tcx, ty, depth, visited);
|
||||
visited.pop();
|
||||
result
|
||||
}
|
||||
|
||||
/// Implementation of TypeTree generation logic.
|
||||
fn typetree_from_ty_impl<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
ty: Ty<'tcx>,
|
||||
depth: usize,
|
||||
visited: &mut Vec<Ty<'tcx>>,
|
||||
) -> TypeTree {
|
||||
typetree_from_ty_impl_inner(tcx, ty, depth, visited, false)
|
||||
}
|
||||
|
||||
/// Internal implementation with context about whether this is for a reference target.
|
||||
fn typetree_from_ty_impl_inner<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
ty: Ty<'tcx>,
|
||||
depth: usize,
|
||||
visited: &mut Vec<Ty<'tcx>>,
|
||||
is_reference_target: bool,
|
||||
) -> TypeTree {
|
||||
if ty.is_scalar() {
|
||||
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
|
||||
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
|
||||
} else if ty.is_floating_point() {
|
||||
match ty {
|
||||
x if x == tcx.types.f16 => (Kind::Half, 2),
|
||||
x if x == tcx.types.f32 => (Kind::Float, 4),
|
||||
x if x == tcx.types.f64 => (Kind::Double, 8),
|
||||
x if x == tcx.types.f128 => (Kind::F128, 16),
|
||||
_ => (Kind::Integer, 0),
|
||||
}
|
||||
} else {
|
||||
(Kind::Integer, 0)
|
||||
};
|
||||
|
||||
// Use offset 0 for scalars that are direct targets of references (like &f64)
|
||||
// Use offset -1 for scalars used directly (like function return types)
|
||||
let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 };
|
||||
return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]);
|
||||
}
|
||||
|
||||
if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
|
||||
let inner_ty = if let Some(inner) = ty.builtin_deref(true) {
|
||||
inner
|
||||
} else {
|
||||
return TypeTree::new();
|
||||
};
|
||||
|
||||
let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true);
|
||||
return TypeTree(vec![Type {
|
||||
offset: -1,
|
||||
size: tcx.data_layout.pointer_size().bytes_usize(),
|
||||
kind: Kind::Pointer,
|
||||
child,
|
||||
}]);
|
||||
}
|
||||
|
||||
if ty.is_array() {
|
||||
if let ty::Array(element_ty, len_const) = ty.kind() {
|
||||
let len = len_const.try_to_target_usize(tcx).unwrap_or(0);
|
||||
if len == 0 {
|
||||
return TypeTree::new();
|
||||
}
|
||||
let element_tree =
|
||||
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
|
||||
let mut types = Vec::new();
|
||||
for elem_type in &element_tree.0 {
|
||||
types.push(Type {
|
||||
offset: -1,
|
||||
size: elem_type.size,
|
||||
kind: elem_type.kind,
|
||||
child: elem_type.child.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
return TypeTree(types);
|
||||
}
|
||||
}
|
||||
|
||||
if ty.is_slice() {
|
||||
if let ty::Slice(element_ty) = ty.kind() {
|
||||
let element_tree =
|
||||
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
|
||||
return element_tree;
|
||||
}
|
||||
}
|
||||
|
||||
if let ty::Tuple(tuple_types) = ty.kind() {
|
||||
if tuple_types.is_empty() {
|
||||
return TypeTree::new();
|
||||
}
|
||||
|
||||
let mut types = Vec::new();
|
||||
let mut current_offset = 0;
|
||||
|
||||
for tuple_ty in tuple_types.iter() {
|
||||
let element_tree =
|
||||
typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false);
|
||||
|
||||
let element_layout = tcx
|
||||
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
|
||||
.ok()
|
||||
.map(|layout| layout.size.bytes_usize())
|
||||
.unwrap_or(0);
|
||||
|
||||
for elem_type in &element_tree.0 {
|
||||
types.push(Type {
|
||||
offset: if elem_type.offset == -1 {
|
||||
current_offset as isize
|
||||
} else {
|
||||
current_offset as isize + elem_type.offset
|
||||
},
|
||||
size: elem_type.size,
|
||||
kind: elem_type.kind,
|
||||
child: elem_type.child.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
current_offset += element_layout;
|
||||
}
|
||||
|
||||
return TypeTree(types);
|
||||
}
|
||||
|
||||
if let ty::Adt(adt_def, args) = ty.kind() {
|
||||
if adt_def.is_struct() {
|
||||
let struct_layout =
|
||||
tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty));
|
||||
if let Ok(layout) = struct_layout {
|
||||
let mut types = Vec::new();
|
||||
|
||||
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
|
||||
let field_ty = field_def.ty(tcx, args);
|
||||
let field_tree =
|
||||
typetree_from_ty_impl_inner(tcx, field_ty, depth + 1, visited, false);
|
||||
|
||||
let field_offset = layout.fields.offset(field_idx).bytes_usize();
|
||||
|
||||
for elem_type in &field_tree.0 {
|
||||
types.push(Type {
|
||||
offset: if elem_type.offset == -1 {
|
||||
field_offset as isize
|
||||
} else {
|
||||
field_offset as isize + elem_type.offset
|
||||
},
|
||||
size: elem_type.size,
|
||||
kind: elem_type.kind,
|
||||
child: elem_type.child.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return TypeTree(types);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TypeTree::new()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -258,6 +258,8 @@ pub enum AutoDiff {
|
|||
LooseTypes,
|
||||
/// Runs Enzyme's aggressive inlining
|
||||
Inline,
|
||||
/// Disable Type Tree
|
||||
NoTT,
|
||||
}
|
||||
|
||||
/// Settings for `-Z instrument-xray` flag.
|
||||
|
|
|
|||
|
|
@ -792,7 +792,7 @@ mod desc {
|
|||
pub(crate) const parse_list: &str = "a space-separated list of strings";
|
||||
pub(crate) const parse_list_with_polarity: &str =
|
||||
"a comma-separated list of strings, with elements beginning with + or -";
|
||||
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
|
||||
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`, `NoTT`";
|
||||
pub(crate) const parse_offload: &str = "a comma separated list of settings: `Enable`";
|
||||
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
|
||||
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
|
||||
|
|
@ -1481,6 +1481,7 @@ pub mod parse {
|
|||
"PrintPasses" => AutoDiff::PrintPasses,
|
||||
"LooseTypes" => AutoDiff::LooseTypes,
|
||||
"Inline" => AutoDiff::Inline,
|
||||
"NoTT" => AutoDiff::NoTT,
|
||||
_ => {
|
||||
// FIXME(ZuseZ4): print an error saying which value is not recognized
|
||||
return false;
|
||||
|
|
|
|||
33
tests/codegen-llvm/autodiff/typetree.rs
Normal file
33
tests/codegen-llvm/autodiff/typetree.rs
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
|
||||
//@ no-prefer-dynamic
|
||||
//@ needs-enzyme
|
||||
|
||||
// Test that basic autodiff still works with our TypeTree infrastructure
|
||||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_simple, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
#[inline(never)]
|
||||
// Returns twice the referenced value.
// The FileCheck lines following this function expect an `fmul double` in its IR,
// so the body is kept as a single multiplication.
fn simple(x: &f64) -> f64 {
    2.0 * x
}
|
||||
|
||||
// CHECK-LABEL: @simple
|
||||
// CHECK: fmul double
|
||||
|
||||
// The derivative function should be generated normally
|
||||
// CHECK-LABEL: diffesimple
|
||||
// CHECK: fadd fast double
|
||||
|
||||
fn main() {
|
||||
let x = std::hint::black_box(3.0);
|
||||
let output = simple(&x);
|
||||
assert_eq!(6.0, output);
|
||||
|
||||
let mut df_dx = 0.0;
|
||||
let output_ = d_simple(&x, &mut df_dx, 1.0);
|
||||
assert_eq!(output, output_);
|
||||
assert_eq!(2.0, df_dx);
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
; Check that array TypeTree metadata is correctly generated
|
||||
; Should show a single Float@double entry at offset -1 of the pointee, i.e. one entry covering every array element
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
15
tests/run-make/autodiff/type-trees/array-typetree/test.rs
Normal file
15
tests/run-make/autodiff/type-trees/array-typetree/test.rs
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_test, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
// Sums the five elements of the referenced array, left to right.
fn test_array(arr: &[f64; 5]) -> f64 {
    arr[0] + arr[1] + arr[2] + arr[3] + arr[4]
}
|
||||
|
||||
fn main() {
|
||||
let arr = [1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let mut d_arr = [0.0; 5];
|
||||
let _result = d_test(&arr, &mut d_arr, 1.0);
|
||||
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
; Check that enzyme_type attributes are present in the LLVM IR function definition
|
||||
; This verifies our TypeTree system correctly attaches metadata for Enzyme
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
|
||||
|
||||
; Check that llvm.memcpy exists (either call or declare)
|
||||
CHECK: {{(call|declare).*}}@llvm.memcpy
|
||||
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
CHECK: force_memcpy
|
||||
|
||||
CHECK: @llvm.memcpy.p0.p0.i64
|
||||
|
||||
CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer}:{}
|
||||
|
||||
CHECK-DAG: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double}
|
||||
|
||||
CHECK-DAG: load double{{.*}}: {[-1]:Float@double}
|
||||
|
||||
CHECK-DAG: fmul double{{.*}}: {[-1]:Float@double}
|
||||
|
||||
CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double}
|
||||
36
tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs
Normal file
36
tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
use std::ptr;
|
||||
|
||||
// Copies `count` f64 values from `src` to `dst`.
// Marked `#[inline(never)]` so the copy stays in its own function.
#[inline(never)]
fn force_memcpy(src: *const f64, dst: *mut f64, count: usize) {
    // SAFETY: relies on the caller passing pointers that are valid for `count`
    // elements and do not overlap; the only visible caller passes two distinct
    // 128-element arrays with `count == 128`.
    unsafe {
        ptr::copy_nonoverlapping(src, dst, count);
    }
}
|
||||
|
||||
#[autodiff_reverse(d_test_memcpy, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
fn test_memcpy(input: &[f64; 128]) -> f64 {
|
||||
let mut local_data = [0.0f64; 128];
|
||||
|
||||
// Use a separate function to prevent inlining and optimization
|
||||
force_memcpy(input.as_ptr(), local_data.as_mut_ptr(), 128);
|
||||
|
||||
// Sum only first few elements to keep the computation simple
|
||||
local_data[0] * local_data[0]
|
||||
+ local_data[1] * local_data[1]
|
||||
+ local_data[2] * local_data[2]
|
||||
+ local_data[3] * local_data[3]
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let input = [1.0; 128];
|
||||
let mut d_input = [0.0; 128];
|
||||
let result = test_memcpy(&input);
|
||||
let result_d = d_test_memcpy(&input, &mut d_input, 1.0);
|
||||
|
||||
assert_eq!(result, result_d);
|
||||
println!("Memcpy test passed: result = {}", result);
|
||||
}
|
||||
39
tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs
Normal file
39
tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
// First, compile to LLVM IR to check for enzyme_type attributes
|
||||
let _ir_output = rustc()
|
||||
.input("memcpy.rs")
|
||||
.arg("-Zautodiff=Enable")
|
||||
.arg("-Zautodiff=NoPostopt")
|
||||
.opt_level("0")
|
||||
.arg("--emit=llvm-ir")
|
||||
.arg("-o")
|
||||
.arg("main.ll")
|
||||
.run();
|
||||
|
||||
// Then compile with TypeTree analysis output for the existing checks
|
||||
let output = rustc()
|
||||
.input("memcpy.rs")
|
||||
.arg("-Zautodiff=Enable,PrintTAFn=test_memcpy")
|
||||
.arg("-Zautodiff=NoPostopt")
|
||||
.opt_level("3")
|
||||
.arg("-Clto=fat")
|
||||
.arg("-g")
|
||||
.run();
|
||||
|
||||
let stdout = output.stdout_utf8();
|
||||
let stderr = output.stderr_utf8();
|
||||
let ir_content = rfs::read_to_string("main.ll");
|
||||
|
||||
rfs::write("memcpy.stdout", &stdout);
|
||||
rfs::write("memcpy.stderr", &stderr);
|
||||
rfs::write("main.ir", &ir_content);
|
||||
|
||||
llvm_filecheck().patterns("memcpy.check").stdin_buf(stdout).run();
|
||||
|
||||
llvm_filecheck().patterns("memcpy-ir.check").stdin_buf(ir_content).run();
|
||||
}
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
; Check that a struct mixing an i64 header with a large f32 array yields Integer at offset 0 and Float@float at offset 8
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_mixed_struct{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Float@float}"
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
rustc()
|
||||
.input("test.rs")
|
||||
.arg("-Zautodiff=Enable")
|
||||
.arg("-Zautodiff=NoPostopt")
|
||||
.opt_level("0")
|
||||
.emit("llvm-ir")
|
||||
.run();
|
||||
|
||||
llvm_filecheck().patterns("mixed.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
// Integer header followed by a large float array. `repr(C)` keeps the fields
// in declaration order, so `data` starts right after the 8-byte `header`.
#[repr(C)]
struct Container {
    header: i64,
    data: [f32; 1000],
}
|
||||
|
||||
#[autodiff_reverse(d_test, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
#[inline(never)]
|
||||
fn test_mixed_struct(container: &Container) -> f32 {
|
||||
container.data[0] + container.data[999]
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let container = Container { header: 42, data: [1.0; 1000] };
|
||||
let mut d_container = Container { header: 0, data: [0.0; 1000] };
|
||||
let result = d_test(&container, &mut d_container, 1.0);
|
||||
std::hint::black_box(result);
|
||||
}
|
||||
5
tests/run-make/autodiff/type-trees/nott-flag/nott.check
Normal file
5
tests/run-make/autodiff/type-trees/nott-flag/nott.check
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
// Check that enzyme_type attributes are NOT present when NoTT flag is used
|
||||
// This verifies the NoTT flag correctly disables TypeTree metadata
|
||||
|
||||
CHECK: define{{.*}}@square
|
||||
CHECK-NOT: "enzyme_type"
|
||||
30
tests/run-make/autodiff/type-trees/nott-flag/rmake.rs
Normal file
30
tests/run-make/autodiff/type-trees/nott-flag/rmake.rs
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
// Test with NoTT flag - should not generate TypeTree metadata
|
||||
rustc()
|
||||
.input("test.rs")
|
||||
.arg("-Zautodiff=Enable,NoTT")
|
||||
.emit("llvm-ir")
|
||||
.arg("-o")
|
||||
.arg("nott.ll")
|
||||
.run();
|
||||
|
||||
// Test without NoTT flag - should generate TypeTree metadata
|
||||
rustc()
|
||||
.input("test.rs")
|
||||
.arg("-Zautodiff=Enable")
|
||||
.emit("llvm-ir")
|
||||
.arg("-o")
|
||||
.arg("with_tt.ll")
|
||||
.run();
|
||||
|
||||
// Verify NoTT version does NOT have enzyme_type attributes
|
||||
llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.ll")).run();
|
||||
|
||||
// Verify TypeTree version DOES have enzyme_type attributes
|
||||
llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.ll")).run();
|
||||
}
|
||||
15
tests/run-make/autodiff/type-trees/nott-flag/test.rs
Normal file
15
tests/run-make/autodiff/type-trees/nott-flag/test.rs
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_square, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
// Squares the referenced `f64`.
fn square(x: &f64) -> f64 {
    x * x
}
|
||||
|
||||
fn main() {
|
||||
let x = 2.0;
|
||||
let mut dx = 0.0;
|
||||
let _result = d_square(&x, &mut dx, 1.0);
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
// Check that enzyme_type attributes are present when TypeTree is enabled
|
||||
// This verifies our TypeTree metadata attachment is working
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_deep{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_graph{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Integer, [-1,16]:Integer, [-1,24]:Float@double}"
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_node{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
llvm_filecheck().patterns("recursion.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
100
tests/run-make/autodiff/type-trees/recursion-typetree/test.rs
Normal file
100
tests/run-make/autodiff/type-trees/recursion-typetree/test.rs
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
// Self-referential struct to test recursion detection
|
||||
// Singly linked node whose `next` field refers back to the same type through
// a `Box`.
#[derive(Clone)]
struct Node {
    value: f64,
    next: Option<Box<Node>>,
}
|
||||
|
||||
// Mutually recursive structs to test cycle detection
|
||||
// `GraphNodeA` owns a list of `GraphNodeB`, and `GraphNodeB` may point back to
// a boxed `GraphNodeA`, so the two types form a cycle.
#[derive(Clone)]
struct GraphNodeA {
    value: f64,
    connections: Vec<GraphNodeB>,
}

#[derive(Clone)]
struct GraphNodeB {
    weight: f64,
    target: Option<Box<GraphNodeA>>,
}
|
||||
|
||||
#[autodiff_reverse(d_test_node, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
fn test_node(node: &Node) -> f64 {
|
||||
node.value * 2.0
|
||||
}
|
||||
|
||||
#[autodiff_reverse(d_test_graph, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
fn test_graph(a: &GraphNodeA) -> f64 {
|
||||
a.value * 3.0
|
||||
}
|
||||
|
||||
// Simple depth test - deeply nested but not circular
|
||||
// Chain of eight distinct types: each `LevelN` optionally boxes `LevelN+1`,
// and `Level8` terminates the chain with no `next` field.
#[derive(Clone)]
struct Level1 {
    val: f64,
    next: Option<Box<Level2>>,
}
#[derive(Clone)]
struct Level2 {
    val: f64,
    next: Option<Box<Level3>>,
}
#[derive(Clone)]
struct Level3 {
    val: f64,
    next: Option<Box<Level4>>,
}
#[derive(Clone)]
struct Level4 {
    val: f64,
    next: Option<Box<Level5>>,
}
#[derive(Clone)]
struct Level5 {
    val: f64,
    next: Option<Box<Level6>>,
}
#[derive(Clone)]
struct Level6 {
    val: f64,
    next: Option<Box<Level7>>,
}
#[derive(Clone)]
struct Level7 {
    val: f64,
    next: Option<Box<Level8>>,
}
#[derive(Clone)]
struct Level8 {
    val: f64,
}
|
||||
|
||||
#[autodiff_reverse(d_test_deep, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
fn test_deep(deep: &Level1) -> f64 {
|
||||
deep.val * 4.0
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let node = Node { value: 1.0, next: None };
|
||||
|
||||
let graph = GraphNodeA { value: 2.0, connections: vec![] };
|
||||
|
||||
let deep = Level1 { val: 5.0, next: None };
|
||||
|
||||
let mut d_node = Node { value: 0.0, next: None };
|
||||
|
||||
let mut d_graph = GraphNodeA { value: 0.0, connections: vec![] };
|
||||
|
||||
let mut d_deep = Level1 { val: 0.0, next: None };
|
||||
|
||||
let _result1 = d_test_node(&node, &mut d_node, 1.0);
|
||||
let _result2 = d_test_graph(&graph, &mut d_graph, 1.0);
|
||||
let _result3 = d_test_deep(&deep, &mut d_deep, 1.0);
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
; Check that f128 TypeTree metadata is correctly generated
|
||||
; Should show Float@fp128 for f128 values and Pointer for references
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@fp128}"
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
// Compile with TypeTree enabled and emit LLVM IR
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
|
||||
// Check that f128 TypeTree metadata is correctly generated
|
||||
llvm_filecheck().patterns("f128.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#![feature(autodiff, f128)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_test, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
// Squares the referenced `f128`.
fn test_f128(x: &f128) -> f128 {
    *x * *x
}
|
||||
|
||||
fn main() {
|
||||
let x = 2.0_f128;
|
||||
let mut dx = 0.0_f128;
|
||||
let _result = d_test(&x, &mut dx, 1.0);
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
; Check that f16 TypeTree metadata is correctly generated
|
||||
; Should show Float@half for f16 values and Pointer for references
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@half}"
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
// Compile with TypeTree enabled and emit LLVM IR
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
|
||||
// Check that f16 TypeTree metadata is correctly generated
|
||||
llvm_filecheck().patterns("f16.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#![feature(autodiff, f16)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_test, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
// Squares the referenced `f16`.
fn test_f16(x: &f16) -> f16 {
    *x * *x
}
|
||||
|
||||
fn main() {
|
||||
let x = 2.0_f16;
|
||||
let mut dx = 0.0_f16;
|
||||
let _result = d_test(&x, &mut dx, 1.0);
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
; Check that f32 TypeTree metadata is correctly generated
|
||||
; Should show Float@float for f32 values and Pointer for references
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}"
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
// Compile with TypeTree enabled and emit LLVM IR
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
|
||||
// Check that f32 TypeTree metadata is correctly generated
|
||||
llvm_filecheck().patterns("f32.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_test, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
// Squares the referenced `f32`.
fn test_f32(x: &f32) -> f32 {
    x * x
}
|
||||
|
||||
fn main() {
|
||||
let x = 2.0_f32;
|
||||
let mut dx = 0.0_f32;
|
||||
let _result = d_test(&x, &mut dx, 1.0);
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
; Check that f64 TypeTree metadata is correctly generated
|
||||
; Should show Float@double for f64 values and Pointer for references
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
// Compile with TypeTree enabled and emit LLVM IR
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
|
||||
// Check that f64 TypeTree metadata is correctly generated
|
||||
llvm_filecheck().patterns("f64.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_test, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
// Squares the referenced `f64`.
fn test_f64(x: &f64) -> f64 {
    x * x
}
|
||||
|
||||
fn main() {
|
||||
let x = 2.0_f64;
|
||||
let mut dx = 0.0_f64;
|
||||
let _result = d_test(&x, &mut dx, 1.0);
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
; Check that i32 TypeTree metadata is correctly generated
|
||||
; Should show Integer for i32 values and Pointer for references
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer}"
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
// Compile with TypeTree enabled and emit LLVM IR
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
|
||||
// Check that i32 TypeTree metadata is correctly generated
|
||||
llvm_filecheck().patterns("i32.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_test, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
// Squares the referenced `i32`.
fn test_i32(x: &i32) -> i32 {
    x * x
}
|
||||
|
||||
fn main() {
|
||||
let x = 5_i32;
|
||||
let mut dx = 0_i32;
|
||||
let _result = d_test(&x, &mut dx, 1);
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
llvm_filecheck().patterns("slice.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
; Check that slice TypeTree metadata is correctly generated
|
||||
; Should show Float@double for slice elements
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_slice{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
|
||||
16
tests/run-make/autodiff/type-trees/slice-typetree/test.rs
Normal file
16
tests/run-make/autodiff/type-trees/slice-typetree/test.rs
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_test, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
// Sums every element of the slice.
fn test_slice(slice: &[f64]) -> f64 {
    slice.iter().sum()
}
|
||||
|
||||
fn main() {
|
||||
let arr = [1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let slice = &arr[..];
|
||||
let mut d_slice = [0.0; 5];
|
||||
let _result = d_test(slice, &mut d_slice[..], 1.0);
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
llvm_filecheck().patterns("struct.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
; Check that struct TypeTree metadata is correctly generated
|
||||
; Should show Float@double at offsets 0, 8, 16 for Point struct fields
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}"
|
||||
22
tests/run-make/autodiff/type-trees/struct-typetree/test.rs
Normal file
22
tests/run-make/autodiff/type-trees/struct-typetree/test.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
// Three `f64` coordinates. `repr(C)` keeps them in declaration order, which
// places `x`, `y` and `z` at byte offsets 0, 8 and 16.
#[repr(C)]
struct Point {
    x: f64,
    y: f64,
    z: f64,
}
|
||||
|
||||
#[autodiff_reverse(d_test, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
fn test_struct(point: &Point) -> f64 {
|
||||
point.x + point.y * 2.0 + point.z * 3.0
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let point = Point { x: 1.0, y: 2.0, z: 3.0 };
|
||||
let mut d_point = Point { x: 0.0, y: 0.0, z: 0.0 };
|
||||
let _result = d_test(&point, &mut d_point, 1.0);
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
llvm_filecheck().patterns("tuple.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
15
tests/run-make/autodiff/type-trees/tuple-typetree/test.rs
Normal file
15
tests/run-make/autodiff/type-trees/tuple-typetree/test.rs
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_test, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
// Returns the weighted sum `t.0 + 2*t.1 + 3*t.2` of the tuple's elements.
fn test_tuple(tuple: &(f64, f64, f64)) -> f64 {
    tuple.0 + tuple.1 * 2.0 + tuple.2 * 3.0
}
|
||||
|
||||
fn main() {
|
||||
let tuple = (1.0, 2.0, 3.0);
|
||||
let mut d_tuple = (0.0, 0.0, 0.0);
|
||||
let _result = d_test(&tuple, &mut d_tuple, 1.0);
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
; Check that tuple TypeTree metadata is correctly generated
|
||||
; Should show Float@double at offsets 0, 8, 16 for (f64, f64, f64)
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}"
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
|
||||
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer}
|
||||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer}
|
||||
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !102, !noundef !{{[0-9]+}}: {}
|
||||
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !noundef !{{[0-9]+}}: {}
|
||||
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 16, !dbg !{{[0-9]+}}: {[-1]:Pointer}
|
||||
// CHECK-DAG: %{{[0-9]+}} = load i64, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {}
|
||||
// CHECK-DAG: %{{[0-9]+}} = icmp eq i64 %{{[0-9]+}}, 0, !dbg !{{[0-9]+}}: {[-1]:Integer}
|
||||
|
|
|
|||
19
tests/ui/autodiff/flag_nott.rs
Normal file
19
tests/ui/autodiff/flag_nott.rs
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
//@ compile-flags: -Zautodiff=Enable,NoTT
|
||||
//@ needs-enzyme
|
||||
//@ check-pass
|
||||
|
||||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
// Test that NoTT flag is accepted and doesn't cause compilation errors
|
||||
#[autodiff_reverse(d_square, Duplicated, Active)]
|
||||
// Squares the referenced `f64`.
fn square(x: &f64) -> f64 {
    x * x
}
|
||||
|
||||
fn main() {
|
||||
let x = 2.0;
|
||||
let mut dx = 0.0;
|
||||
let result = d_square(&x, &mut dx, 1.0);
|
||||
}
|
||||
|
|
@ -53,6 +53,7 @@ ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC = "ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC"
|
|||
ERROR_MCA_OCCURED = "ERROR_MCA_OCCURED"
|
||||
ERRNO_ACCES = "ERRNO_ACCES"
|
||||
tolen = "tolen"
|
||||
EnzymeTypeTreeShiftIndiciesEq = "EnzymeTypeTreeShiftIndiciesEq"
|
||||
|
||||
[default]
|
||||
extend-ignore-words-re = [
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue