autodiff: recurion added for typetree
This commit is contained in:
parent
4f3f0f48e7
commit
4520926bb5
12 changed files with 191 additions and 20 deletions
|
|
@ -127,6 +127,7 @@ pub(crate) mod Enzyme_AD {
|
|||
);
|
||||
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
|
||||
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
|
||||
pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
|
|
|
|||
|
|
@ -39,11 +39,7 @@ fn process_typetree_recursive(
|
|||
|
||||
let mut indices = parent_indices.to_vec();
|
||||
if !parent_indices.is_empty() {
|
||||
if rust_type.offset == -1 {
|
||||
indices.push(-1);
|
||||
} else {
|
||||
indices.push(rust_type.offset as i64);
|
||||
}
|
||||
indices.push(rust_type.offset as i64);
|
||||
} else if rust_type.offset == -1 {
|
||||
indices.push(-1);
|
||||
} else {
|
||||
|
|
@ -52,7 +48,9 @@ fn process_typetree_recursive(
|
|||
|
||||
enzyme_tt.insert(&indices, concrete_type, llcx);
|
||||
|
||||
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() {
|
||||
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer
|
||||
&& !rust_type.child.0.is_empty()
|
||||
{
|
||||
process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2252,6 +2252,61 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
|
|||
/// Generate TypeTree for a specific type.
|
||||
/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
|
||||
pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
||||
let mut visited = Vec::new();
|
||||
typetree_from_ty_inner(tcx, ty, 0, &mut visited)
|
||||
}
|
||||
|
||||
/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
|
||||
fn typetree_from_ty_inner<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
ty: Ty<'tcx>,
|
||||
depth: usize,
|
||||
visited: &mut Vec<Ty<'tcx>>,
|
||||
) -> TypeTree {
|
||||
#[cfg(llvm_enzyme)]
|
||||
{
|
||||
unsafe extern "C" {
|
||||
fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
|
||||
}
|
||||
let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
|
||||
if depth > max_depth {
|
||||
return TypeTree::new();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(llvm_enzyme))]
|
||||
if depth > 6 {
|
||||
return TypeTree::new();
|
||||
}
|
||||
|
||||
if visited.contains(&ty) {
|
||||
return TypeTree::new();
|
||||
}
|
||||
|
||||
visited.push(ty);
|
||||
let result = typetree_from_ty_impl(tcx, ty, depth, visited);
|
||||
visited.pop();
|
||||
result
|
||||
}
|
||||
|
||||
/// Implementation of TypeTree generation logic.
|
||||
fn typetree_from_ty_impl<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
ty: Ty<'tcx>,
|
||||
depth: usize,
|
||||
visited: &mut Vec<Ty<'tcx>>,
|
||||
) -> TypeTree {
|
||||
typetree_from_ty_impl_inner(tcx, ty, depth, visited, false)
|
||||
}
|
||||
|
||||
/// Internal implementation with context about whether this is for a reference target.
|
||||
fn typetree_from_ty_impl_inner<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
ty: Ty<'tcx>,
|
||||
depth: usize,
|
||||
visited: &mut Vec<Ty<'tcx>>,
|
||||
is_reference_target: bool,
|
||||
) -> TypeTree {
|
||||
if ty.is_scalar() {
|
||||
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
|
||||
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
|
||||
|
|
@ -2267,7 +2322,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
|||
(Kind::Integer, 0)
|
||||
};
|
||||
|
||||
return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]);
|
||||
// Use offset 0 for scalars that are direct targets of references (like &f64)
|
||||
// Use offset -1 for scalars used directly (like function return types)
|
||||
let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 };
|
||||
return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]);
|
||||
}
|
||||
|
||||
if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
|
||||
|
|
@ -2277,7 +2335,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
|||
return TypeTree::new();
|
||||
};
|
||||
|
||||
let child = typetree_from_ty(tcx, inner_ty);
|
||||
let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true);
|
||||
return TypeTree(vec![Type {
|
||||
offset: -1,
|
||||
size: tcx.data_layout.pointer_size().bytes_usize(),
|
||||
|
|
@ -2292,9 +2350,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
|||
if len == 0 {
|
||||
return TypeTree::new();
|
||||
}
|
||||
|
||||
let element_tree = typetree_from_ty(tcx, *element_ty);
|
||||
|
||||
let element_tree =
|
||||
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
|
||||
let mut types = Vec::new();
|
||||
for elem_type in &element_tree.0 {
|
||||
types.push(Type {
|
||||
|
|
@ -2311,7 +2368,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
|||
|
||||
if ty.is_slice() {
|
||||
if let ty::Slice(element_ty) = ty.kind() {
|
||||
let element_tree = typetree_from_ty(tcx, *element_ty);
|
||||
let element_tree =
|
||||
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
|
||||
return element_tree;
|
||||
}
|
||||
}
|
||||
|
|
@ -2325,7 +2383,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
|||
let mut current_offset = 0;
|
||||
|
||||
for tuple_ty in tuple_types.iter() {
|
||||
let element_tree = typetree_from_ty(tcx, tuple_ty);
|
||||
let element_tree =
|
||||
typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false);
|
||||
|
||||
let element_layout = tcx
|
||||
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
|
||||
|
|
@ -2361,7 +2420,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
|||
|
||||
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
|
||||
let field_ty = field_def.ty(tcx, args);
|
||||
let field_tree = typetree_from_ty(tcx, field_ty);
|
||||
let field_tree =
|
||||
typetree_from_ty_impl_inner(tcx, field_ty, depth + 1, visited, false);
|
||||
|
||||
let field_offset = layout.fields.offset(field_idx).bytes_usize();
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Check that enzyme_type attributes are present when TypeTree is enabled
|
||||
// This verifies our TypeTree metadata attachment is working
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_deep{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_graph{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Integer, [-1,16]:Integer, [-1,24]:Float@double}"
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_node{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
llvm_filecheck().patterns("recursion.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
100
tests/run-make/autodiff/type-trees/recursion-typetree/test.rs
Normal file
100
tests/run-make/autodiff/type-trees/recursion-typetree/test.rs
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
// Self-referential struct to test recursion detection
|
||||
#[derive(Clone)]
|
||||
struct Node {
|
||||
value: f64,
|
||||
next: Option<Box<Node>>,
|
||||
}
|
||||
|
||||
// Mutually recursive structs to test cycle detection
|
||||
#[derive(Clone)]
|
||||
struct GraphNodeA {
|
||||
value: f64,
|
||||
connections: Vec<GraphNodeB>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct GraphNodeB {
|
||||
weight: f64,
|
||||
target: Option<Box<GraphNodeA>>,
|
||||
}
|
||||
|
||||
#[autodiff_reverse(d_test_node, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
fn test_node(node: &Node) -> f64 {
|
||||
node.value * 2.0
|
||||
}
|
||||
|
||||
#[autodiff_reverse(d_test_graph, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
fn test_graph(a: &GraphNodeA) -> f64 {
|
||||
a.value * 3.0
|
||||
}
|
||||
|
||||
// Simple depth test - deeply nested but not circular
|
||||
#[derive(Clone)]
|
||||
struct Level1 {
|
||||
val: f64,
|
||||
next: Option<Box<Level2>>,
|
||||
}
|
||||
#[derive(Clone)]
|
||||
struct Level2 {
|
||||
val: f64,
|
||||
next: Option<Box<Level3>>,
|
||||
}
|
||||
#[derive(Clone)]
|
||||
struct Level3 {
|
||||
val: f64,
|
||||
next: Option<Box<Level4>>,
|
||||
}
|
||||
#[derive(Clone)]
|
||||
struct Level4 {
|
||||
val: f64,
|
||||
next: Option<Box<Level5>>,
|
||||
}
|
||||
#[derive(Clone)]
|
||||
struct Level5 {
|
||||
val: f64,
|
||||
next: Option<Box<Level6>>,
|
||||
}
|
||||
#[derive(Clone)]
|
||||
struct Level6 {
|
||||
val: f64,
|
||||
next: Option<Box<Level7>>,
|
||||
}
|
||||
#[derive(Clone)]
|
||||
struct Level7 {
|
||||
val: f64,
|
||||
next: Option<Box<Level8>>,
|
||||
}
|
||||
#[derive(Clone)]
|
||||
struct Level8 {
|
||||
val: f64,
|
||||
}
|
||||
|
||||
#[autodiff_reverse(d_test_deep, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
fn test_deep(deep: &Level1) -> f64 {
|
||||
deep.val * 4.0
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let node = Node { value: 1.0, next: None };
|
||||
|
||||
let graph = GraphNodeA { value: 2.0, connections: vec![] };
|
||||
|
||||
let deep = Level1 { val: 5.0, next: None };
|
||||
|
||||
let mut d_node = Node { value: 0.0, next: None };
|
||||
|
||||
let mut d_graph = GraphNodeA { value: 0.0, connections: vec![] };
|
||||
|
||||
let mut d_deep = Level1 { val: 0.0, next: None };
|
||||
|
||||
let _result1 = d_test_node(&node, &mut d_node, 1.0);
|
||||
let _result2 = d_test_graph(&graph, &mut d_graph, 1.0);
|
||||
let _result3 = d_test_deep(&deep, &mut d_deep, 1.0);
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
; Check that f128 TypeTree metadata is correctly generated
|
||||
; Should show Float@fp128 for f128 values and Pointer for references
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@fp128}"
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@fp128}"
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
; Check that f16 TypeTree metadata is correctly generated
|
||||
; Should show Float@half for f16 values and Pointer for references
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@half}"
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@half}"
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
; Check that f32 TypeTree metadata is correctly generated
|
||||
; Should show Float@float for f32 values and Pointer for references
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}"
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}"
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
; Check that f64 TypeTree metadata is correctly generated
|
||||
; Should show Float@double for f64 values and Pointer for references
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
; Check that i32 TypeTree metadata is correctly generated
|
||||
; Should show Integer for i32 values and Pointer for references
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}"
|
||||
CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer}"
|
||||
Loading…
Add table
Add a link
Reference in a new issue