autodiff: recurion added for typetree

This commit is contained in:
Karan Janthe 2025-09-11 07:30:35 +00:00
parent 4f3f0f48e7
commit 4520926bb5
12 changed files with 191 additions and 20 deletions

View file

@ -127,6 +127,7 @@ pub(crate) mod Enzyme_AD {
);
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
}
unsafe extern "C" {

View file

@ -39,11 +39,7 @@ fn process_typetree_recursive(
let mut indices = parent_indices.to_vec();
if !parent_indices.is_empty() {
if rust_type.offset == -1 {
indices.push(-1);
} else {
indices.push(rust_type.offset as i64);
}
indices.push(rust_type.offset as i64);
} else if rust_type.offset == -1 {
indices.push(-1);
} else {
@ -52,7 +48,9 @@ fn process_typetree_recursive(
enzyme_tt.insert(&indices, concrete_type, llcx);
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() {
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer
&& !rust_type.child.0.is_empty()
{
process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
}
}

View file

@ -2252,6 +2252,61 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
/// Generate TypeTree for a specific type.
/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
let mut visited = Vec::new();
typetree_from_ty_inner(tcx, ty, 0, &mut visited)
}
/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
fn typetree_from_ty_inner<'tcx>(
tcx: TyCtxt<'tcx>,
ty: Ty<'tcx>,
depth: usize,
visited: &mut Vec<Ty<'tcx>>,
) -> TypeTree {
#[cfg(llvm_enzyme)]
{
unsafe extern "C" {
fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
}
let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
if depth > max_depth {
return TypeTree::new();
}
}
#[cfg(not(llvm_enzyme))]
if depth > 6 {
return TypeTree::new();
}
if visited.contains(&ty) {
return TypeTree::new();
}
visited.push(ty);
let result = typetree_from_ty_impl(tcx, ty, depth, visited);
visited.pop();
result
}
/// Implementation of TypeTree generation logic.
fn typetree_from_ty_impl<'tcx>(
tcx: TyCtxt<'tcx>,
ty: Ty<'tcx>,
depth: usize,
visited: &mut Vec<Ty<'tcx>>,
) -> TypeTree {
typetree_from_ty_impl_inner(tcx, ty, depth, visited, false)
}
/// Internal implementation with context about whether this is for a reference target.
fn typetree_from_ty_impl_inner<'tcx>(
tcx: TyCtxt<'tcx>,
ty: Ty<'tcx>,
depth: usize,
visited: &mut Vec<Ty<'tcx>>,
is_reference_target: bool,
) -> TypeTree {
if ty.is_scalar() {
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
@ -2267,7 +2322,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
(Kind::Integer, 0)
};
return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]);
// Use offset 0 for scalars that are direct targets of references (like &f64)
// Use offset -1 for scalars used directly (like function return types)
let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 };
return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]);
}
if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
@ -2277,7 +2335,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
return TypeTree::new();
};
let child = typetree_from_ty(tcx, inner_ty);
let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true);
return TypeTree(vec![Type {
offset: -1,
size: tcx.data_layout.pointer_size().bytes_usize(),
@ -2292,9 +2350,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
if len == 0 {
return TypeTree::new();
}
let element_tree = typetree_from_ty(tcx, *element_ty);
let element_tree =
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
let mut types = Vec::new();
for elem_type in &element_tree.0 {
types.push(Type {
@ -2311,7 +2368,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
if ty.is_slice() {
if let ty::Slice(element_ty) = ty.kind() {
let element_tree = typetree_from_ty(tcx, *element_ty);
let element_tree =
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
return element_tree;
}
}
@ -2325,7 +2383,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
let mut current_offset = 0;
for tuple_ty in tuple_types.iter() {
let element_tree = typetree_from_ty(tcx, tuple_ty);
let element_tree =
typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false);
let element_layout = tcx
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
@ -2361,7 +2420,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
let field_ty = field_def.ty(tcx, args);
let field_tree = typetree_from_ty(tcx, field_ty);
let field_tree =
typetree_from_ty_impl_inner(tcx, field_ty, depth + 1, visited, false);
let field_offset = layout.fields.offset(field_idx).bytes_usize();

View file

@ -1,4 +1,4 @@
// Check that enzyme_type attributes are present when TypeTree is enabled
// This verifies our TypeTree metadata attachment is working
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"

View file

@ -0,0 +1,3 @@
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_deep{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_graph{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Integer, [-1,16]:Integer, [-1,24]:Float@double}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_node{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"

View file

@ -0,0 +1,9 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
llvm_filecheck().patterns("recursion.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,100 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
// Self-referential struct to test recursion detection
#[derive(Clone)]
struct Node {
value: f64,
next: Option<Box<Node>>,
}
// Mutually recursive structs to test cycle detection
#[derive(Clone)]
struct GraphNodeA {
value: f64,
connections: Vec<GraphNodeB>,
}
#[derive(Clone)]
struct GraphNodeB {
weight: f64,
target: Option<Box<GraphNodeA>>,
}
#[autodiff_reverse(d_test_node, Duplicated, Active)]
#[no_mangle]
fn test_node(node: &Node) -> f64 {
node.value * 2.0
}
#[autodiff_reverse(d_test_graph, Duplicated, Active)]
#[no_mangle]
fn test_graph(a: &GraphNodeA) -> f64 {
a.value * 3.0
}
// Simple depth test - deeply nested but not circular
#[derive(Clone)]
struct Level1 {
val: f64,
next: Option<Box<Level2>>,
}
#[derive(Clone)]
struct Level2 {
val: f64,
next: Option<Box<Level3>>,
}
#[derive(Clone)]
struct Level3 {
val: f64,
next: Option<Box<Level4>>,
}
#[derive(Clone)]
struct Level4 {
val: f64,
next: Option<Box<Level5>>,
}
#[derive(Clone)]
struct Level5 {
val: f64,
next: Option<Box<Level6>>,
}
#[derive(Clone)]
struct Level6 {
val: f64,
next: Option<Box<Level7>>,
}
#[derive(Clone)]
struct Level7 {
val: f64,
next: Option<Box<Level8>>,
}
#[derive(Clone)]
struct Level8 {
val: f64,
}
#[autodiff_reverse(d_test_deep, Duplicated, Active)]
#[no_mangle]
fn test_deep(deep: &Level1) -> f64 {
deep.val * 4.0
}
fn main() {
let node = Node { value: 1.0, next: None };
let graph = GraphNodeA { value: 2.0, connections: vec![] };
let deep = Level1 { val: 5.0, next: None };
let mut d_node = Node { value: 0.0, next: None };
let mut d_graph = GraphNodeA { value: 0.0, connections: vec![] };
let mut d_deep = Level1 { val: 0.0, next: None };
let _result1 = d_test_node(&node, &mut d_node, 1.0);
let _result2 = d_test_graph(&graph, &mut d_graph, 1.0);
let _result3 = d_test_deep(&deep, &mut d_deep, 1.0);
}

View file

@ -1,4 +1,4 @@
; Check that f128 TypeTree metadata is correctly generated
; Should show Float@fp128 for f128 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@fp128}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@fp128}"

View file

@ -1,4 +1,4 @@
; Check that f16 TypeTree metadata is correctly generated
; Should show Float@half for f16 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@half}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@half}"

View file

@ -1,4 +1,4 @@
; Check that f32 TypeTree metadata is correctly generated
; Should show Float@float for f32 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}"

View file

@ -1,4 +1,4 @@
; Check that f64 TypeTree metadata is correctly generated
; Should show Float@double for f64 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"

View file

@ -1,4 +1,4 @@
; Check that i32 TypeTree metadata is correctly generated
; Should show Integer for i32 values and Pointer for references
CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}"
CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer}"