autodiff: struct support in typetree
This commit is contained in:
parent
7c5fbfbdbb
commit
574f0b97d6
4 changed files with 67 additions and 0 deletions
|
|
@ -2370,5 +2370,37 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
|
|||
return TypeTree(types);
|
||||
}
|
||||
|
||||
if let ty::Adt(adt_def, args) = ty.kind() {
|
||||
if adt_def.is_struct() {
|
||||
let struct_layout =
|
||||
tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty));
|
||||
if let Ok(layout) = struct_layout {
|
||||
let mut types = Vec::new();
|
||||
|
||||
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
|
||||
let field_ty = field_def.ty(tcx, args);
|
||||
let field_tree = typetree_from_ty(tcx, field_ty);
|
||||
|
||||
let field_offset = layout.fields.offset(field_idx).bytes_usize();
|
||||
|
||||
for elem_type in &field_tree.0 {
|
||||
types.push(Type {
|
||||
offset: if elem_type.offset == -1 {
|
||||
field_offset as isize
|
||||
} else {
|
||||
field_offset as isize + elem_type.offset
|
||||
},
|
||||
size: elem_type.size,
|
||||
kind: elem_type.kind,
|
||||
child: elem_type.child.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return TypeTree(types);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TypeTree::new()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
//@ needs-enzyme
|
||||
//@ ignore-cross-compile
|
||||
|
||||
use run_make_support::{llvm_filecheck, rfs, rustc};
|
||||
|
||||
fn main() {
|
||||
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
|
||||
llvm_filecheck().patterns("struct.check").stdin_buf(rfs::read("test.ll")).run();
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
; Check that struct TypeTree metadata is correctly generated
|
||||
; Should show Float@double at offsets 0, 8, 16 for Point struct fields
|
||||
|
||||
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[]:Pointer}"
|
||||
22
tests/run-make/autodiff/type-trees/struct-typetree/test.rs
Normal file
22
tests/run-make/autodiff/type-trees/struct-typetree/test.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[repr(C)]
|
||||
struct Point {
|
||||
x: f64,
|
||||
y: f64,
|
||||
z: f64,
|
||||
}
|
||||
|
||||
#[autodiff_reverse(d_test, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
fn test_struct(point: &Point) -> f64 {
|
||||
point.x + point.y * 2.0 + point.z * 3.0
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let point = Point { x: 1.0, y: 2.0, z: 3.0 };
|
||||
let mut d_point = Point { x: 0.0, y: 0.0, z: 0.0 };
|
||||
let _result = d_test(&point, &mut d_point, 1.0);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue