autodiff: struct support in typetree

This commit is contained in:
Karan Janthe 2025-09-01 16:28:14 +00:00
parent 7c5fbfbdbb
commit 574f0b97d6
4 changed files with 67 additions and 0 deletions

View file

@ -2370,5 +2370,37 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
return TypeTree(types);
}
if let ty::Adt(adt_def, args) = ty.kind() {
if adt_def.is_struct() {
let struct_layout =
tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty));
if let Ok(layout) = struct_layout {
let mut types = Vec::new();
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
let field_ty = field_def.ty(tcx, args);
let field_tree = typetree_from_ty(tcx, field_ty);
let field_offset = layout.fields.offset(field_idx).bytes_usize();
for elem_type in &field_tree.0 {
types.push(Type {
offset: if elem_type.offset == -1 {
field_offset as isize
} else {
field_offset as isize + elem_type.offset
},
size: elem_type.size,
kind: elem_type.kind,
child: elem_type.child.clone(),
});
}
}
return TypeTree(types);
}
}
}
TypeTree::new()
}

View file

@ -0,0 +1,9 @@
//@ needs-enzyme
//@ ignore-cross-compile
use run_make_support::{llvm_filecheck, rfs, rustc};
fn main() {
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
llvm_filecheck().patterns("struct.check").stdin_buf(rfs::read("test.ll")).run();
}

View file

@ -0,0 +1,4 @@
; Check that struct TypeTree metadata is correctly generated
; Should show Float@double at offsets 0, 8, 16 for Point struct fields
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[]:Pointer}"

View file

@ -0,0 +1,22 @@
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[repr(C)]
struct Point {
x: f64,
y: f64,
z: f64,
}
#[autodiff_reverse(d_test, Duplicated, Active)]
#[no_mangle]
fn test_struct(point: &Point) -> f64 {
point.x + point.y * 2.0 + point.z * 3.0
}
fn main() {
let point = Point { x: 1.0, y: 2.0, z: 3.0 };
let mut d_point = Point { x: 0.0, y: 0.0, z: 0.0 };
let _result = d_test(&point, &mut d_point, 1.0);
}