Rollup merge of #144197 - KMJ-007:type-tree, r=ZuseZ4

TypeTree support in autodiff

# TypeTrees for Autodiff

## What are TypeTrees?
Memory layout descriptors for Enzyme. Tell Enzyme exactly how types are structured in memory so it can compute derivatives efficiently.

## Structure
```rust
TypeTree(Vec<Type>)

Type {
    offset: isize,  // byte offset (-1 = everywhere)
    size: usize,    // size in bytes
    kind: Kind,     // Float, Integer, Pointer, etc.
    child: TypeTree // nested structure
}
```

## Example: `fn compute(x: &f32, data: &[f32]) -> f32`

**Input 0: `x: &f32`**
```rust
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, size: 4, kind: Float,
        child: TypeTree::new()
    }])
}])
```

**Input 1: `data: &[f32]`**
```rust
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, size: 4, kind: Float,  // -1 = all elements
        child: TypeTree::new()
    }])
}])
```

**Output: `f32`**
```rust
TypeTree(vec![Type {
    offset: -1, size: 4, kind: Float,
    child: TypeTree::new()
}])
```

## Why Needed?
- Enzyme can't deduce complex type layouts from LLVM IR
- Prevents slow memory pattern analysis
- Enables correct derivative computation for nested structures
- Tells Enzyme which bytes are differentiable vs metadata

## What Enzyme Does With This Information:

Without TypeTrees (current state):
```llvm
; Enzyme sees generic LLVM IR:
define float ``@distance(ptr*`` %p1, ptr* %p2) {
; Has to guess what these pointers point to
; Slow analysis of all memory operations
; May miss optimization opportunities
}
```

With TypeTrees (our implementation):
```llvm
define "enzyme_type"="{[]:Float@float}" float ``@distance(``
    ptr "enzyme_type"="{[]:Pointer}" %p1,
    ptr "enzyme_type"="{[]:Pointer}" %p2
) {
; Enzyme knows exact type layout
; Can generate efficient derivative code directly
}
```

# TypeTrees - Offset and -1 Explained

## Type Structure

```rust
Type {
    offset: isize, // WHERE this type starts
    size: usize,   // HOW BIG this type is
    kind: Kind,    // WHAT KIND of data (Float, Int, Pointer)
    child: TypeTree // WHAT'S INSIDE (for pointers/containers)
}
```

## Offset Values

### Regular Offset (0, 4, 8, etc.)
**Specific byte position within a structure**

```rust
struct Point {
    x: f32, // offset 0, size 4
    y: f32, // offset 4, size 4
    id: i32, // offset 8, size 4
}
```

TypeTree for `&Point` (internal representation):
```rust
TypeTree(vec![
    Type { offset: 0, size: 4, kind: Float },   // x at byte 0
    Type { offset: 4, size: 4, kind: Float },   // y at byte 4
    Type { offset: 8, size: 4, kind: Integer }  // id at byte 8
])
```

Generates LLVM:
```llvm
"enzyme_type"="{[]:Float@float}"
```

### Offset -1 (Special: "Everywhere")
**Means "this pattern repeats for ALL elements"**

#### Example 1: Array `[f32; 100]`
```rust
TypeTree(vec![Type {
    offset: -1, // ALL positions
    size: 4,    // each f32 is 4 bytes
    kind: Float, // every element is float
}])
```

Instead of listing 100 separate Types with offsets `0,4,8,12...396`

#### Example 2: Slice `&[i32]`
```rust
// Pointer to slice data
TypeTree(vec![Type {
    offset: -1, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, // ALL slice elements
        size: 4,    // each i32 is 4 bytes
        kind: Integer
    }])
}])
```

#### Example 3: Mixed Structure
```rust
struct Container {
    header: i64,        // offset 0
    data: [f32; 1000],  // offset 8, but elements use -1
}
```

```rust
TypeTree(vec![
    Type { offset: 0, size: 8, kind: Integer }, // header
    Type { offset: 8, size: 4000, kind: Pointer,
        child: TypeTree(vec![Type {
            offset: -1, size: 4, kind: Float // ALL array elements
        }])
    }
])
```
This commit is contained in:
Matthias Krüger 2025-09-28 18:13:11 +02:00 committed by GitHub
commit c29fb2e57e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
68 changed files with 1250 additions and 14 deletions

View file

@ -6,6 +6,7 @@
use std::fmt::{self, Display, Formatter};
use std::str::FromStr;
use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::{Ty, TyKind};
@ -84,6 +85,8 @@ pub struct AutoDiffItem {
/// The name of the function being generated
pub target: String,
pub attrs: AutoDiffAttrs,
pub inputs: Vec<TypeTree>,
pub output: TypeTree,
}
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@ -275,14 +278,22 @@ impl AutoDiffAttrs {
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
}
pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
AutoDiffItem { source, target, attrs: self }
pub fn into_item(
self,
source: String,
target: String,
inputs: Vec<TypeTree>,
output: TypeTree,
) -> AutoDiffItem {
AutoDiffItem { source, target, inputs, output, attrs: self }
}
}
impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
}
}

View file

@ -31,6 +31,7 @@ pub enum Kind {
Half,
Float,
Double,
F128,
Unknown,
}

View file

@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
_src_align: Align,
size: RValue<'gcc>,
flags: MemFlags,
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_size_t(), false);

View file

@ -770,6 +770,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(self.layout.size.bytes()),
MemFlags::empty(),
None,
);
bx.lifetime_end(scratch, scratch_size);

View file

@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(copy_bytes),
MemFlags::empty(),
None,
);
bx.lifetime_end(llscratch, scratch_size);
}

View file

@ -563,6 +563,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
config::AutoDiff::Enable => {}
// We handle this below
config::AutoDiff::NoPostopt => {}
// Disables TypeTree generation
config::AutoDiff::NoTT => {}
}
}
// This helps with handling enums for now.

View file

@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
use std::ops::Deref;
use std::{iter, ptr};
use rustc_ast::expand::typetree::FncTree;
pub(crate) mod autodiff;
pub(crate) mod gpu_offload;
@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align: Align,
size: &'ll Value,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let memcpy = unsafe {
llvm::LLVMRustBuildMemCpy(
self.llbuilder,
dst,
@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
// a memcpy during autodiff, it needs to know the structure of the data being
// copied to properly track derivatives. For example, copying an array of floats
// vs. copying a struct with mixed types requires different derivative handling.
// The TypeTree tells Enzyme exactly what memory layout to expect.
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
}
}

View file

@ -1,6 +1,7 @@
use std::ptr;
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
use rustc_ast::expand::typetree::FncTree;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
@ -294,6 +295,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
fn_args: &[&'ll Value],
attrs: AutoDiffAttrs,
dest: PlaceRef<'tcx, &'ll Value>,
fnc_tree: FncTree,
) {
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
let mut ad_name: String = match attrs.mode {
@ -370,6 +372,10 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
fn_args,
);
if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
}
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
builder.store_to_place(call, dest.val);

View file

@ -1212,6 +1212,9 @@ fn codegen_autodiff<'ll, 'tcx>(
&mut diff_attrs.input_activity,
);
let fnc_tree =
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
// Build body
generate_enzyme_call(
bx,
@ -1222,6 +1225,7 @@ fn codegen_autodiff<'ll, 'tcx>(
&val_arr,
diff_attrs.clone(),
result,
fnc_tree,
);
}

View file

@ -68,6 +68,7 @@ mod llvm_util;
mod mono_item;
mod type_;
mod type_of;
mod typetree;
mod va_arg;
mod value;

View file

@ -3,9 +3,36 @@
use libc::{c_char, c_uint};
use super::MetadataKindId;
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
use crate::llvm::{Bool, Builder};
// TypeTree types
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub(crate) struct EnzymeTypeTree {
_unused: [u8; 0],
}
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub(crate) enum CConcreteType {
DT_Anything = 0,
DT_Integer = 1,
DT_Pointer = 2,
DT_Half = 3,
DT_Float = 4,
DT_Double = 5,
DT_Unknown = 6,
DT_FP128 = 9,
}
pub(crate) struct TypeTree {
pub(crate) inner: CTypeTreeRef,
}
#[link(name = "llvm-wrapper", kind = "static")]
unsafe extern "C" {
// Enzyme
@ -68,10 +95,40 @@ pub(crate) mod Enzyme_AD {
use libc::c_void;
use super::{CConcreteType, CTypeTreeRef, Context};
unsafe extern "C" {
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
}
// TypeTree functions
unsafe extern "C" {
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool;
pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64);
pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef);
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
arg1: CTypeTreeRef,
data_layout: *const c_char,
offset: i64,
max_size: i64,
add_offset: u64,
);
pub(crate) fn EnzymeTypeTreeInsertEq(
CTT: CTypeTreeRef,
indices: *const i64,
len: usize,
ct: CConcreteType,
ctx: &Context,
);
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
}
unsafe extern "C" {
static mut EnzymePrintPerf: c_void;
static mut EnzymePrintActivity: c_void;
@ -141,6 +198,67 @@ pub(crate) use self::Fallback_AD::*;
pub(crate) mod Fallback_AD {
#![allow(unused_variables)]
use libc::c_char;
use super::{CConcreteType, CTypeTreeRef, Context};
// TypeTree function fallbacks
pub(crate) unsafe fn EnzymeNewTypeTree() -> CTypeTreeRef {
unimplemented!()
}
pub(crate) unsafe fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef {
unimplemented!()
}
pub(crate) unsafe fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef {
unimplemented!()
}
pub(crate) unsafe fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeShiftIndiciesEq(
arg1: CTypeTreeRef,
data_layout: *const c_char,
offset: i64,
max_size: i64,
add_offset: u64,
) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeInsertEq(
CTT: CTypeTreeRef,
indices: *const i64,
len: usize,
ct: CConcreteType,
ctx: &Context,
) {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
unimplemented!()
}
pub(crate) unsafe fn EnzymeTypeTreeToStringFree(arg1: *const c_char) {
unimplemented!()
}
pub(crate) fn set_inline(val: bool) {
unimplemented!()
}
@ -169,3 +287,89 @@ pub(crate) mod Fallback_AD {
unimplemented!()
}
}
impl TypeTree {
pub(crate) fn new() -> TypeTree {
let inner = unsafe { EnzymeNewTypeTree() };
TypeTree { inner }
}
pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) };
TypeTree { inner }
}
pub(crate) fn merge(self, other: Self) -> Self {
unsafe {
EnzymeMergeTypeTree(self.inner, other.inner);
}
drop(other);
self
}
#[must_use]
pub(crate) fn shift(
self,
layout: &str,
offset: isize,
max_size: isize,
add_offset: usize,
) -> Self {
let layout = std::ffi::CString::new(layout).unwrap();
unsafe {
EnzymeTypeTreeShiftIndiciesEq(
self.inner,
layout.as_ptr(),
offset as i64,
max_size as i64,
add_offset as u64,
);
}
self
}
pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
unsafe {
EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
}
}
}
impl Clone for TypeTree {
fn clone(&self) -> Self {
let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
TypeTree { inner }
}
}
impl std::fmt::Display for TypeTree {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
match cstr.to_str() {
Ok(x) => write!(f, "{}", x)?,
Err(err) => write!(f, "could not parse: {}", err)?,
}
// delete C string pointer
unsafe {
EnzymeTypeTreeToStringFree(ptr);
}
Ok(())
}
}
impl std::fmt::Debug for TypeTree {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
<Self as std::fmt::Display>::fmt(self, f)
}
}
impl Drop for TypeTree {
fn drop(&mut self) {
unsafe { EnzymeFreeTypeTree(self.inner) }
}
}

View file

@ -0,0 +1,122 @@
use rustc_ast::expand::typetree::FncTree;
#[cfg(feature = "llvm_enzyme")]
use {
crate::attributes,
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
std::ffi::{CString, c_char, c_uint},
};
use crate::llvm::{self, Value};
#[cfg(feature = "llvm_enzyme")]
fn to_enzyme_typetree(
rust_typetree: RustTypeTree,
_data_layout: &str,
llcx: &llvm::Context,
) -> llvm::TypeTree {
let mut enzyme_tt = llvm::TypeTree::new();
process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);
enzyme_tt
}
#[cfg(feature = "llvm_enzyme")]
fn process_typetree_recursive(
enzyme_tt: &mut llvm::TypeTree,
rust_typetree: &RustTypeTree,
parent_indices: &[i64],
llcx: &llvm::Context,
) {
for rust_type in &rust_typetree.0 {
let concrete_type = match rust_type.kind {
rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything,
rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer,
rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer,
rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half,
rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float,
rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double,
rustc_ast::expand::typetree::Kind::F128 => llvm::CConcreteType::DT_FP128,
rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown,
};
let mut indices = parent_indices.to_vec();
if !parent_indices.is_empty() {
indices.push(rust_type.offset as i64);
} else if rust_type.offset == -1 {
indices.push(-1);
} else {
indices.push(rust_type.offset as i64);
}
enzyme_tt.insert(&indices, concrete_type, llcx);
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer
&& !rust_type.child.0.is_empty()
{
process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
}
}
}
#[cfg(feature = "llvm_enzyme")]
pub(crate) fn add_tt<'ll>(
llmod: &'ll llvm::Module,
llcx: &'ll llvm::Context,
fn_def: &'ll Value,
tt: FncTree,
) {
let inputs = tt.args;
let ret_tt: RustTypeTree = tt.ret;
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
let llvm_data_layout =
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
.expect("got a non-UTF8 data-layout from LLVM");
let attr_name = "enzyme_type";
let c_attr_name = CString::new(attr_name).unwrap();
for (i, input) in inputs.iter().enumerate() {
unsafe {
let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
let c_str = std::ffi::CStr::from_ptr(c_str);
let attr = llvm::LLVMCreateStringAttribute(
llcx,
c_attr_name.as_ptr(),
c_attr_name.as_bytes().len() as c_uint,
c_str.as_ptr(),
c_str.to_bytes().len() as c_uint,
);
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]);
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
}
}
unsafe {
let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
let c_str = std::ffi::CStr::from_ptr(c_str);
let ret_attr = llvm::LLVMCreateStringAttribute(
llcx,
c_attr_name.as_ptr(),
c_attr_name.as_bytes().len() as c_uint,
c_str.as_ptr(),
c_str.to_bytes().len() as c_uint,
);
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
}
}
#[cfg(not(feature = "llvm_enzyme"))]
pub(crate) fn add_tt<'ll>(
_llmod: &'ll llvm::Module,
_llcx: &'ll llvm::Context,
_fn_def: &'ll Value,
_tt: FncTree,
) {
unimplemented!()
}

View file

@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
src_align,
bx.const_u32(layout.layout.size().bytes() as u32),
MemFlags::empty(),
None,
);
tmp
} else {

View file

@ -1626,6 +1626,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
align,
bx.const_usize(copy_bytes),
MemFlags::empty(),
None,
);
// ...and then load it with the ABI type.
llval = load_cast(bx, cast, llscratch, scratch_align);

View file

@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
if allow_overlap {
bx.memmove(dst, align, src, align, size, flags);
} else {
bx.memcpy(dst, align, src, align, size, flags);
bx.memcpy(dst, align, src, align, size, flags, None);
}
}

View file

@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let align = pointee_layout.align;
let dst = dst_val.immediate();
let src = src_val.immediate();
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
}
mir::StatementKind::FakeRead(..)
| mir::StatementKind::Retag { .. }

View file

@ -451,6 +451,7 @@ pub trait BuilderMethods<'a, 'tcx>:
src_align: Align,
size: Self::Value,
flags: MemFlags,
tt: Option<rustc_ast::expand::typetree::FncTree>,
);
fn memmove(
&mut self,
@ -507,7 +508,7 @@ pub trait BuilderMethods<'a, 'tcx>:
temp.val.store_with_flags(self, dst.with_type(layout), flags);
} else if !layout.is_zst() {
let bytes = self.const_usize(layout.size.bytes());
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags);
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None);
}
}

View file

@ -765,7 +765,7 @@ fn test_unstable_options_tracking_hash() {
tracked!(allow_features, Some(vec![String::from("lang_items")]));
tracked!(always_encode_mir, true);
tracked!(assume_incomplete_release, true);
tracked!(autodiff, vec![AutoDiff::Enable]);
tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]);
tracked!(binary_dep_depinfo, true);
tracked!(box_noalias, false);
tracked!(

View file

@ -1812,3 +1812,15 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
MD.NoHWAddress = true;
GV.setSanitizerMetadata(MD);
}
#ifdef ENZYME
extern "C" {
extern llvm::cl::opt<unsigned> EnzymeMaxTypeDepth;
}
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; }
#else
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() {
return 6; // Default fallback depth
}
#endif

View file

@ -37,7 +37,6 @@ pub(crate) struct OpaqueHiddenTypeMismatch<'tcx> {
pub sub: TypeMismatchReason,
}
// FIXME(autodiff): I should get used somewhere
#[derive(Diagnostic)]
#[diag(middle_unsupported_union)]
pub struct UnsupportedUnion {

View file

@ -25,6 +25,7 @@ pub use generic_args::{GenericArgKind, TermKind, *};
pub use generics::*;
pub use intrinsic::IntrinsicDef;
use rustc_abi::{Align, FieldIdx, Integer, IntegerType, ReprFlags, ReprOptions, VariantIdx};
use rustc_ast::expand::typetree::{FncTree, Kind, Type, TypeTree};
use rustc_ast::node_id::NodeMap;
pub use rustc_ast_ir::{Movability, Mutability, try_visit};
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet};
@ -62,7 +63,7 @@ pub use rustc_type_ir::solve::SizedTraitKind;
pub use rustc_type_ir::*;
#[allow(hidden_glob_reexports, unused_imports)]
use rustc_type_ir::{InferCtxtLike, Interner};
use tracing::{debug, instrument};
use tracing::{debug, instrument, trace};
pub use vtable::*;
use {rustc_ast as ast, rustc_hir as hir};
@ -2216,3 +2217,225 @@ pub struct DestructuredConst<'tcx> {
pub variant: Option<VariantIdx>,
pub fields: &'tcx [ty::Const<'tcx>],
}
/// Generate TypeTree information for autodiff.
/// This function creates TypeTree metadata that describes the memory layout
/// of function parameters and return types for Enzyme autodiff.
pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
// Check if TypeTrees are disabled via NoTT flag
if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) {
return FncTree { args: vec![], ret: TypeTree::new() };
}
// Check if this is actually a function type
if !fn_ty.is_fn() {
return FncTree { args: vec![], ret: TypeTree::new() };
}
// Get the function signature
let fn_sig = fn_ty.fn_sig(tcx);
let sig = tcx.instantiate_bound_regions_with_erased(fn_sig);
// Create TypeTrees for each input parameter
let mut args = vec![];
for ty in sig.inputs().iter() {
let type_tree = typetree_from_ty(tcx, *ty);
args.push(type_tree);
}
// Create TypeTree for return type
let ret = typetree_from_ty(tcx, sig.output());
FncTree { args, ret }
}
/// Generate TypeTree for a specific type.
/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
let mut visited = Vec::new();
typetree_from_ty_inner(tcx, ty, 0, &mut visited)
}
/// Maximum recursion depth for TypeTree generation to prevent stack overflow
/// from pathological deeply nested types. Combined with cycle detection.
const MAX_TYPETREE_DEPTH: usize = 6;
/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
fn typetree_from_ty_inner<'tcx>(
tcx: TyCtxt<'tcx>,
ty: Ty<'tcx>,
depth: usize,
visited: &mut Vec<Ty<'tcx>>,
) -> TypeTree {
if depth >= MAX_TYPETREE_DEPTH {
trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty);
return TypeTree::new();
}
if visited.contains(&ty) {
return TypeTree::new();
}
visited.push(ty);
let result = typetree_from_ty_impl(tcx, ty, depth, visited);
visited.pop();
result
}
/// Implementation of TypeTree generation logic.
fn typetree_from_ty_impl<'tcx>(
tcx: TyCtxt<'tcx>,
ty: Ty<'tcx>,
depth: usize,
visited: &mut Vec<Ty<'tcx>>,
) -> TypeTree {
typetree_from_ty_impl_inner(tcx, ty, depth, visited, false)
}
/// Internal implementation with context about whether this is for a reference target.
fn typetree_from_ty_impl_inner<'tcx>(
tcx: TyCtxt<'tcx>,
ty: Ty<'tcx>,
depth: usize,
visited: &mut Vec<Ty<'tcx>>,
is_reference_target: bool,
) -> TypeTree {
if ty.is_scalar() {
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
} else if ty.is_floating_point() {
match ty {
x if x == tcx.types.f16 => (Kind::Half, 2),
x if x == tcx.types.f32 => (Kind::Float, 4),
x if x == tcx.types.f64 => (Kind::Double, 8),
x if x == tcx.types.f128 => (Kind::F128, 16),
_ => (Kind::Integer, 0),
}
} else {
(Kind::Integer, 0)
};
// Use offset 0 for scalars that are direct targets of references (like &f64)
// Use offset -1 for scalars used directly (like function return types)
let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 };
return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]);
}
if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
let inner_ty = if let Some(inner) = ty.builtin_deref(true) {
inner
} else {
return TypeTree::new();
};
let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true);
return TypeTree(vec![Type {
offset: -1,
size: tcx.data_layout.pointer_size().bytes_usize(),
kind: Kind::Pointer,
child,
}]);
}
if ty.is_array() {
if let ty::Array(element_ty, len_const) = ty.kind() {
let len = len_const.try_to_target_usize(tcx).unwrap_or(0);
if len == 0 {
return TypeTree::new();
}
let element_tree =
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
let mut types = Vec::new();
for elem_type in &element_tree.0 {
types.push(Type {
offset: -1,
size: elem_type.size,
kind: elem_type.kind,
child: elem_type.child.clone(),
});
}
return TypeTree(types);
}
}
if ty.is_slice() {
if let ty::Slice(element_ty) = ty.kind() {
let element_tree =
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
return element_tree;
}
}
if let ty::Tuple(tuple_types) = ty.kind() {
if tuple_types.is_empty() {
return TypeTree::new();
}
let mut types = Vec::new();
let mut current_offset = 0;
for tuple_ty in tuple_types.iter() {
let element_tree =
typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false);
let element_layout = tcx
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
.ok()
.map(|layout| layout.size.bytes_usize())
.unwrap_or(0);
for elem_type in &element_tree.0 {
types.push(Type {
offset: if elem_type.offset == -1 {
current_offset as isize
} else {
current_offset as isize + elem_type.offset
},
size: elem_type.size,
kind: elem_type.kind,
child: elem_type.child.clone(),
});
}
current_offset += element_layout;
}
return TypeTree(types);
}
if let ty::Adt(adt_def, args) = ty.kind() {
if adt_def.is_struct() {
let struct_layout =
tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty));
if let Ok(layout) = struct_layout {
let mut types = Vec::new();
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
let field_ty = field_def.ty(tcx, args);
let field_tree =
typetree_from_ty_impl_inner(tcx, field_ty, depth + 1, visited, false);
let field_offset = layout.fields.offset(field_idx).bytes_usize();
for elem_type in &field_tree.0 {
types.push(Type {
offset: if elem_type.offset == -1 {
field_offset as isize
} else {
field_offset as isize + elem_type.offset
},
size: elem_type.size,
kind: elem_type.kind,
child: elem_type.child.clone(),
});
}
}
return TypeTree(types);
}
}
}
TypeTree::new()
}

View file

@ -258,6 +258,8 @@ pub enum AutoDiff {
LooseTypes,
/// Runs Enzyme's aggressive inlining
Inline,
/// Disable Type Tree
NoTT,
}
/// Settings for `-Z instrument-xray` flag.

View file

@ -792,7 +792,7 @@ mod desc {
pub(crate) const parse_list: &str = "a space-separated list of strings";
pub(crate) const parse_list_with_polarity: &str =
"a comma-separated list of strings, with elements beginning with + or -";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`, `NoTT`";
pub(crate) const parse_offload: &str = "a comma separated list of settings: `Enable`";
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
@ -1481,6 +1481,7 @@ pub mod parse {
"PrintPasses" => AutoDiff::PrintPasses,
"LooseTypes" => AutoDiff::LooseTypes,
"Inline" => AutoDiff::Inline,
"NoTT" => AutoDiff::NoTT,
_ => {
// FIXME(ZuseZ4): print an error saying which value is not recognized
return false;

View file

@ -0,0 +1,33 @@
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
// Test that basic autodiff still works with our TypeTree infrastructure
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_simple, Duplicated, Active)]
#[no_mangle]
#[inline(never)]
fn simple(x: &f64) -> f64 {
2.0 * x
}
// CHECK-LABEL: @simple
// CHECK: fmul double
// The derivative function should be generated normally
// CHECK-LABEL: diffesimple
// CHECK: fadd fast double
fn main() {
let x = std::hint::black_box(3.0);
let output = simple(&x);
assert_eq!(6.0, output);
let mut df_dx = 0.0;
let output_ = d_simple(&x, &mut df_dx, 1.0);
assert_eq!(output, output_);
assert_eq!(2.0, df_dx);
}

View file

@ -0,0 +1,4 @@
; Check that array TypeTree metadata is correctly generated
; Should show Float@double at each array element offset (0, 8, 16, 24, 32 bytes)
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"

View file

@ -0,0 +1,9 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,15 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
fn test_array(arr: &[f64; 5]) -> f64 {
arr[0] + arr[1] + arr[2] + arr[3] + arr[4]
}
fn main() {
let arr = [1.0, 2.0, 3.0, 4.0, 5.0];
let mut d_arr = [0.0; 5];
let _result = d_test(&arr, &mut d_arr, 1.0);
}

View file

@ -0,0 +1,8 @@
; Check that enzyme_type attributes are present in the LLVM IR function definition
; This verifies our TypeTree system correctly attaches metadata for Enzyme
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
; Check that llvm.memcpy exists (either call or declare)
CHECK: {{(call|declare).*}}@llvm.memcpy

View file

@ -0,0 +1,13 @@
CHECK: force_memcpy
CHECK: @llvm.memcpy.p0.p0.i64
CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer}:{}
CHECK-DAG: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double}
CHECK-DAG: load double{{.*}}: {[-1]:Float@double}
CHECK-DAG: fmul double{{.*}}: {[-1]:Float@double}
CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double}

View file

@ -0,0 +1,36 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
use std::ptr;
#[inline(never)]
fn force_memcpy(src: *const f64, dst: *mut f64, count: usize) {
unsafe {
ptr::copy_nonoverlapping(src, dst, count);
}
}
#[autodiff_reverse(d_test_memcpy, Duplicated, Active)]
#[no_mangle]
fn test_memcpy(input: &[f64; 128]) -> f64 {
let mut local_data = [0.0f64; 128];
// Use a separate function to prevent inlining and optimization
force_memcpy(input.as_ptr(), local_data.as_mut_ptr(), 128);
// Sum only first few elements to keep the computation simple
local_data[0] * local_data[0]
+ local_data[1] * local_data[1]
+ local_data[2] * local_data[2]
+ local_data[3] * local_data[3]
}
fn main() {
let input = [1.0; 128];
let mut d_input = [0.0; 128];
let result = test_memcpy(&input);
let result_d = d_test_memcpy(&input, &mut d_input, 1.0);
assert_eq!(result, result_d);
println!("Memcpy test passed: result = {}", result);
}

View file

@ -0,0 +1,39 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
// First, compile to LLVM IR to check for enzyme_type attributes
let _ir_output = rustc()
.input("memcpy.rs")
.arg("-Zautodiff=Enable")
.arg("-Zautodiff=NoPostopt")
.opt_level("0")
.arg("--emit=llvm-ir")
.arg("-o")
.arg("main.ll")
.run();
// Then compile with TypeTree analysis output for the existing checks
let output = rustc()
.input("memcpy.rs")
.arg("-Zautodiff=Enable,PrintTAFn=test_memcpy")
.arg("-Zautodiff=NoPostopt")
.opt_level("3")
.arg("-Clto=fat")
.arg("-g")
.run();
let stdout = output.stdout_utf8();
let stderr = output.stderr_utf8();
let ir_content = rfs::read_to_string("main.ll");
rfs::write("memcpy.stdout", &stdout);
rfs::write("memcpy.stderr", &stderr);
rfs::write("main.ir", &ir_content);
llvm_filecheck().patterns("memcpy.check").stdin_buf(stdout).run();
llvm_filecheck().patterns("memcpy-ir.check").stdin_buf(ir_content).run();
}

View file

@ -0,0 +1,2 @@
; Check that mixed struct with large array generates correct detailed type tree
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_mixed_struct{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Float@float}"

View file

@ -0,0 +1,16 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
rustc()
.input("test.rs")
.arg("-Zautodiff=Enable")
.arg("-Zautodiff=NoPostopt")
.opt_level("0")
.emit("llvm-ir")
.run();
llvm_filecheck().patterns("mixed.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,23 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[repr(C)]
struct Container {
header: i64,
data: [f32; 1000],
}
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
#[inline(never)]
fn test_mixed_struct(container: &Container) -> f32 {
container.data[0] + container.data[999]
}
fn main() {
let container = Container { header: 42, data: [1.0; 1000] };
let mut d_container = Container { header: 0, data: [0.0; 1000] };
let result = d_test(&container, &mut d_container, 1.0);
std::hint::black_box(result);
}

View file

@ -0,0 +1,5 @@
// Check that enzyme_type attributes are NOT present when NoTT flag is used
// This verifies the NoTT flag correctly disables TypeTree metadata
CHECK: define{{.*}}@square
CHECK-NOT: "enzyme_type"

View file

@ -0,0 +1,30 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
// Test with NoTT flag - should not generate TypeTree metadata
rustc()
.input("test.rs")
.arg("-Zautodiff=Enable,NoTT")
.emit("llvm-ir")
.arg("-o")
.arg("nott.ll")
.run();
// Test without NoTT flag - should generate TypeTree metadata
rustc()
.input("test.rs")
.arg("-Zautodiff=Enable")
.emit("llvm-ir")
.arg("-o")
.arg("with_tt.ll")
.run();
// Verify NoTT version does NOT have enzyme_type attributes
llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.ll")).run();
// Verify TypeTree version DOES have enzyme_type attributes
llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.ll")).run();
}

View file

@ -0,0 +1,15 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn square(x: &f64) -> f64 {
x * x
}
fn main() {
let x = 2.0;
let mut dx = 0.0;
let _result = d_square(&x, &mut dx, 1.0);
}

View file

@ -0,0 +1,4 @@
// Check that enzyme_type attributes are present when TypeTree is enabled
// This verifies our TypeTree metadata attachment is working
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"

View file

@ -0,0 +1,3 @@
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_deep{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_graph{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Integer, [-1,16]:Integer, [-1,24]:Float@double}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_node{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"

View file

@ -0,0 +1,9 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
llvm_filecheck().patterns("recursion.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,100 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
// Self-referential struct to test recursion detection
#[derive(Clone)]
struct Node {
value: f64,
next: Option<Box<Node>>,
}
// Mutually recursive structs to test cycle detection
#[derive(Clone)]
struct GraphNodeA {
value: f64,
connections: Vec<GraphNodeB>,
}
#[derive(Clone)]
struct GraphNodeB {
weight: f64,
target: Option<Box<GraphNodeA>>,
}
#[autodiff_reverse(d_test_node, Duplicated, Active)]
#[no_mangle]
fn test_node(node: &Node) -> f64 {
node.value * 2.0
}
#[autodiff_reverse(d_test_graph, Duplicated, Active)]
#[no_mangle]
fn test_graph(a: &GraphNodeA) -> f64 {
a.value * 3.0
}
// Simple depth test - deeply nested but not circular
#[derive(Clone)]
struct Level1 {
val: f64,
next: Option<Box<Level2>>,
}
#[derive(Clone)]
struct Level2 {
val: f64,
next: Option<Box<Level3>>,
}
#[derive(Clone)]
struct Level3 {
val: f64,
next: Option<Box<Level4>>,
}
#[derive(Clone)]
struct Level4 {
val: f64,
next: Option<Box<Level5>>,
}
#[derive(Clone)]
struct Level5 {
val: f64,
next: Option<Box<Level6>>,
}
#[derive(Clone)]
struct Level6 {
val: f64,
next: Option<Box<Level7>>,
}
#[derive(Clone)]
struct Level7 {
val: f64,
next: Option<Box<Level8>>,
}
#[derive(Clone)]
struct Level8 {
val: f64,
}
#[autodiff_reverse(d_test_deep, Duplicated, Active)]
#[no_mangle]
fn test_deep(deep: &Level1) -> f64 {
deep.val * 4.0
}
fn main() {
let node = Node { value: 1.0, next: None };
let graph = GraphNodeA { value: 2.0, connections: vec![] };
let deep = Level1 { val: 5.0, next: None };
let mut d_node = Node { value: 0.0, next: None };
let mut d_graph = GraphNodeA { value: 0.0, connections: vec![] };
let mut d_deep = Level1 { val: 0.0, next: None };
let _result1 = d_test_node(&node, &mut d_node, 1.0);
let _result2 = d_test_graph(&graph, &mut d_graph, 1.0);
let _result3 = d_test_deep(&deep, &mut d_deep, 1.0);
}

View file

@ -0,0 +1,4 @@
; Check that f128 TypeTree metadata is correctly generated
; Should show Float@fp128 for f128 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@fp128}"

View file

@ -0,0 +1,12 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
// Compile with TypeTree enabled and emit LLVM IR
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
// Check that f128 TypeTree metadata is correctly generated
llvm_filecheck().patterns("f128.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,15 @@
#![feature(autodiff, f128)]
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
fn test_f128(x: &f128) -> f128 {
*x * *x
}
fn main() {
let x = 2.0_f128;
let mut dx = 0.0_f128;
let _result = d_test(&x, &mut dx, 1.0);
}

View file

@ -0,0 +1,4 @@
; Check that f16 TypeTree metadata is correctly generated
; Should show Float@half for f16 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@half}"

View file

@ -0,0 +1,12 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
// Compile with TypeTree enabled and emit LLVM IR
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
// Check that f16 TypeTree metadata is correctly generated
llvm_filecheck().patterns("f16.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,15 @@
#![feature(autodiff, f16)]
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
fn test_f16(x: &f16) -> f16 {
*x * *x
}
fn main() {
let x = 2.0_f16;
let mut dx = 0.0_f16;
let _result = d_test(&x, &mut dx, 1.0);
}

View file

@ -0,0 +1,4 @@
; Check that f32 TypeTree metadata is correctly generated
; Should show Float@float for f32 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}"

View file

@ -0,0 +1,12 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
// Compile with TypeTree enabled and emit LLVM IR
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
// Check that f32 TypeTree metadata is correctly generated
llvm_filecheck().patterns("f32.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,15 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
fn test_f32(x: &f32) -> f32 {
x * x
}
fn main() {
let x = 2.0_f32;
let mut dx = 0.0_f32;
let _result = d_test(&x, &mut dx, 1.0);
}

View file

@ -0,0 +1,4 @@
; Check that f64 TypeTree metadata is correctly generated
; Should show Float@double for f64 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"

View file

@ -0,0 +1,12 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
// Compile with TypeTree enabled and emit LLVM IR
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
// Check that f64 TypeTree metadata is correctly generated
llvm_filecheck().patterns("f64.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,15 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
fn test_f64(x: &f64) -> f64 {
x * x
}
fn main() {
let x = 2.0_f64;
let mut dx = 0.0_f64;
let _result = d_test(&x, &mut dx, 1.0);
}

View file

@ -0,0 +1,4 @@
; Check that i32 TypeTree metadata is correctly generated
; Should show Integer for i32 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer}"

View file

@ -0,0 +1,12 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
// Compile with TypeTree enabled and emit LLVM IR
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
// Check that i32 TypeTree metadata is correctly generated
llvm_filecheck().patterns("i32.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,15 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
fn test_i32(x: &i32) -> i32 {
x * x
}
fn main() {
let x = 5_i32;
let mut dx = 0_i32;
let _result = d_test(&x, &mut dx, 1);
}

View file

@ -0,0 +1,9 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
llvm_filecheck().patterns("slice.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,4 @@
; Check that slice TypeTree metadata is correctly generated
; Should show Float@double for slice elements
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_slice{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"

View file

@ -0,0 +1,16 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
fn test_slice(slice: &[f64]) -> f64 {
slice.iter().sum()
}
fn main() {
let arr = [1.0, 2.0, 3.0, 4.0, 5.0];
let slice = &arr[..];
let mut d_slice = [0.0; 5];
let _result = d_test(slice, &mut d_slice[..], 1.0);
}

View file

@ -0,0 +1,9 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
llvm_filecheck().patterns("struct.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,4 @@
; Check that struct TypeTree metadata is correctly generated
; Should show Float@double at offsets 0, 8, 16 for Point struct fields
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}"

View file

@ -0,0 +1,22 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[repr(C)]
struct Point {
x: f64,
y: f64,
z: f64,
}
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
fn test_struct(point: &Point) -> f64 {
point.x + point.y * 2.0 + point.z * 3.0
}
fn main() {
let point = Point { x: 1.0, y: 2.0, z: 3.0 };
let mut d_point = Point { x: 0.0, y: 0.0, z: 0.0 };
let _result = d_test(&point, &mut d_point, 1.0);
}

View file

@ -0,0 +1,9 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
llvm_filecheck().patterns("tuple.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,15 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
fn test_tuple(tuple: &(f64, f64, f64)) -> f64 {
tuple.0 + tuple.1 * 2.0 + tuple.2 * 3.0
}
fn main() {
let tuple = (1.0, 2.0, 3.0);
let mut d_tuple = (0.0, 0.0, 0.0);
let _result = d_test(&tuple, &mut d_tuple, 1.0);
}

View file

@ -0,0 +1,4 @@
; Check that tuple TypeTree metadata is correctly generated
; Should show Float@double at offsets 0, 8, 16 for (f64, f64, f64)
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}"

View file

@ -1,7 +1,7 @@
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer}
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !102, !noundef !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !noundef !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 16, !dbg !{{[0-9]+}}: {[-1]:Pointer}
// CHECK-DAG: %{{[0-9]+}} = load i64, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {}
// CHECK-DAG: %{{[0-9]+}} = icmp eq i64 %{{[0-9]+}}, 0, !dbg !{{[0-9]+}}: {[-1]:Integer}

View file

@ -0,0 +1,19 @@
//@ compile-flags: -Zautodiff=Enable,NoTT
//@ needs-enzyme
//@ check-pass
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
// Test that NoTT flag is accepted and doesn't cause compilation errors
#[autodiff_reverse(d_square, Duplicated, Active)]
fn square(x: &f64) -> f64 {
x * x
}
fn main() {
let x = 2.0;
let mut dx = 0.0;
let result = d_square(&x, &mut dx, 1.0);
}

View file

@ -53,6 +53,7 @@ ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC = "ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC"
ERROR_MCA_OCCURED = "ERROR_MCA_OCCURED"
ERRNO_ACCES = "ERRNO_ACCES"
tolen = "tolen"
EnzymeTypeTreeShiftIndiciesEq = "EnzymeTypeTreeShiftIndiciesEq"
[default]
extend-ignore-words-re = [