Rollup merge of #147390 - ZuseZ4:autodiff-dbg, r=jieyouxu
Use globals instead of metadata for std::autodiff LLVM's Metadata is quite fragile. In debug builds we use incremental compilation, which caused the metadata to be dropped. With this change we use named globals instead of metadata to instruct Enzyme how to differentiate functions. Globals are proper llvm values and thus can't be dropped. Also added an incremental/dbg test which now passes, to unblock the EnzymeAD CI which wants to run Rust autodiff tests. r? compiler
This commit is contained in:
commit
27b3881df8
4 changed files with 76 additions and 83 deletions
|
|
@ -12,7 +12,7 @@ use tracing::debug;
|
|||
use crate::builder::{Builder, PlaceRef, UNNAMED};
|
||||
use crate::context::SimpleCx;
|
||||
use crate::declare::declare_simple_fn;
|
||||
use crate::llvm::{self, Metadata, TRUE, Type, Value};
|
||||
use crate::llvm::{self, TRUE, Type, Value};
|
||||
|
||||
pub(crate) fn adjust_activity_to_abi<'tcx>(
|
||||
tcx: TyCtxt<'tcx>,
|
||||
|
|
@ -143,9 +143,9 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
|
|||
cx: &SimpleCx<'ll>,
|
||||
builder: &mut Builder<'_, 'll, 'tcx>,
|
||||
width: u32,
|
||||
args: &mut Vec<&'ll llvm::Value>,
|
||||
args: &mut Vec<&'ll Value>,
|
||||
inputs: &[DiffActivity],
|
||||
outer_args: &[&'ll llvm::Value],
|
||||
outer_args: &[&'ll Value],
|
||||
) {
|
||||
debug!("matching autodiff arguments");
|
||||
// We now handle the issue that Rust level arguments not always match the llvm-ir level
|
||||
|
|
@ -157,32 +157,36 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
|
|||
let mut outer_pos: usize = 0;
|
||||
let mut activity_pos = 0;
|
||||
|
||||
let enzyme_const = cx.create_metadata(b"enzyme_const");
|
||||
let enzyme_out = cx.create_metadata(b"enzyme_out");
|
||||
let enzyme_dup = cx.create_metadata(b"enzyme_dup");
|
||||
let enzyme_dupv = cx.create_metadata(b"enzyme_dupv");
|
||||
let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed");
|
||||
let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv");
|
||||
// We used to use llvm's metadata to instruct enzyme how to differentiate a function.
|
||||
// In debug mode we would use incremental compilation which caused the metadata to be
|
||||
// dropped. This is prevented by now using named globals, which are also understood
|
||||
// by Enzyme.
|
||||
let global_const = cx.declare_global("enzyme_const", cx.type_ptr());
|
||||
let global_out = cx.declare_global("enzyme_out", cx.type_ptr());
|
||||
let global_dup = cx.declare_global("enzyme_dup", cx.type_ptr());
|
||||
let global_dupv = cx.declare_global("enzyme_dupv", cx.type_ptr());
|
||||
let global_dupnoneed = cx.declare_global("enzyme_dupnoneed", cx.type_ptr());
|
||||
let global_dupnoneedv = cx.declare_global("enzyme_dupnoneedv", cx.type_ptr());
|
||||
|
||||
while activity_pos < inputs.len() {
|
||||
let diff_activity = inputs[activity_pos as usize];
|
||||
// Duplicated arguments received a shadow argument, into which enzyme will write the
|
||||
// gradient.
|
||||
let (activity, duplicated): (&Metadata, bool) = match diff_activity {
|
||||
let (activity, duplicated): (&Value, bool) = match diff_activity {
|
||||
DiffActivity::None => panic!("not a valid input activity"),
|
||||
DiffActivity::Const => (enzyme_const, false),
|
||||
DiffActivity::Active => (enzyme_out, false),
|
||||
DiffActivity::ActiveOnly => (enzyme_out, false),
|
||||
DiffActivity::Dual => (enzyme_dup, true),
|
||||
DiffActivity::Dualv => (enzyme_dupv, true),
|
||||
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
|
||||
DiffActivity::DualvOnly => (enzyme_dupnoneedv, true),
|
||||
DiffActivity::Duplicated => (enzyme_dup, true),
|
||||
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
|
||||
DiffActivity::FakeActivitySize(_) => (enzyme_const, false),
|
||||
DiffActivity::Const => (global_const, false),
|
||||
DiffActivity::Active => (global_out, false),
|
||||
DiffActivity::ActiveOnly => (global_out, false),
|
||||
DiffActivity::Dual => (global_dup, true),
|
||||
DiffActivity::Dualv => (global_dupv, true),
|
||||
DiffActivity::DualOnly => (global_dupnoneed, true),
|
||||
DiffActivity::DualvOnly => (global_dupnoneedv, true),
|
||||
DiffActivity::Duplicated => (global_dup, true),
|
||||
DiffActivity::DuplicatedOnly => (global_dupnoneed, true),
|
||||
DiffActivity::FakeActivitySize(_) => (global_const, false),
|
||||
};
|
||||
let outer_arg = outer_args[outer_pos];
|
||||
args.push(cx.get_metadata_value(activity));
|
||||
args.push(activity);
|
||||
if matches!(diff_activity, DiffActivity::Dualv) {
|
||||
let next_outer_arg = outer_args[outer_pos + 1];
|
||||
let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
|
||||
|
|
@ -242,7 +246,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
|
|||
assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
|
||||
args.push(next_outer_arg2);
|
||||
}
|
||||
args.push(cx.get_metadata_value(enzyme_const));
|
||||
args.push(global_const);
|
||||
args.push(next_outer_arg);
|
||||
outer_pos += 2 + 2 * iterations;
|
||||
activity_pos += 2;
|
||||
|
|
@ -351,13 +355,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
|
|||
let mut args = Vec::with_capacity(num_args as usize + 1);
|
||||
args.push(fn_to_diff);
|
||||
|
||||
let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return");
|
||||
let global_primal_ret = cx.declare_global("enzyme_primal_return", cx.type_ptr());
|
||||
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
|
||||
args.push(cx.get_metadata_value(enzyme_primal_ret));
|
||||
args.push(global_primal_ret);
|
||||
}
|
||||
if attrs.width > 1 {
|
||||
let enzyme_width = cx.create_metadata(b"enzyme_width");
|
||||
args.push(cx.get_metadata_value(enzyme_width));
|
||||
let global_width = cx.declare_global("enzyme_width", cx.type_ptr());
|
||||
args.push(global_width);
|
||||
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -110,15 +110,6 @@ fn f14(x: f32) -> Foo {
|
|||
|
||||
type MyFloat = f32;
|
||||
|
||||
// We would like to support type alias to f32/f64 in argument type in the future,
|
||||
// but that requires us to implement our checks at a later stage
|
||||
// like THIR which has type information available.
|
||||
#[autodiff_reverse(df15, Active, Active)]
|
||||
fn f15(x: MyFloat) -> f32 {
|
||||
//~^^ ERROR failed to resolve: use of undeclared type `MyFloat` [E0433]
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
// We would like to support type alias to f32/f64 in return type in the future
|
||||
#[autodiff_reverse(df16, Active, Active)]
|
||||
fn f16(x: f32) -> MyFloat {
|
||||
|
|
@ -136,13 +127,6 @@ fn f17(x: f64) -> F64Trans {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
// We would like to support `#[repr(transparent)]` f32/f64 wrapper in argument type in the future
|
||||
#[autodiff_reverse(df18, Active, Active)]
|
||||
fn f18(x: F64Trans) -> f64 {
|
||||
//~^^ ERROR failed to resolve: use of undeclared type `F64Trans` [E0433]
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
// Invalid return activity
|
||||
#[autodiff_forward(df19, Dual, Active)]
|
||||
fn f19(x: f32) -> f32 {
|
||||
|
|
@ -163,11 +147,4 @@ fn f21(x: f32) -> f32 {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
struct DoesNotImplDefault;
|
||||
#[autodiff_forward(df22, Dual)]
|
||||
pub fn f22() -> DoesNotImplDefault {
|
||||
//~^^ ERROR the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn main() {}
|
||||
|
|
|
|||
|
|
@ -107,53 +107,24 @@ LL | #[autodiff_reverse(df13, Reverse)]
|
|||
| ^^^^^^^
|
||||
|
||||
error: invalid return activity Active in Forward Mode
|
||||
--> $DIR/autodiff_illegal.rs:147:1
|
||||
--> $DIR/autodiff_illegal.rs:131:1
|
||||
|
|
||||
LL | #[autodiff_forward(df19, Dual, Active)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
error: invalid return activity Dual in Reverse Mode
|
||||
--> $DIR/autodiff_illegal.rs:153:1
|
||||
--> $DIR/autodiff_illegal.rs:137:1
|
||||
|
|
||||
LL | #[autodiff_reverse(df20, Active, Dual)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
error: invalid return activity Duplicated in Reverse Mode
|
||||
--> $DIR/autodiff_illegal.rs:160:1
|
||||
--> $DIR/autodiff_illegal.rs:144:1
|
||||
|
|
||||
LL | #[autodiff_reverse(df21, Active, Duplicated)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
error[E0433]: failed to resolve: use of undeclared type `MyFloat`
|
||||
--> $DIR/autodiff_illegal.rs:116:1
|
||||
|
|
||||
LL | #[autodiff_reverse(df15, Active, Active)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `MyFloat`
|
||||
error: aborting due to 18 previous errors
|
||||
|
||||
error[E0433]: failed to resolve: use of undeclared type `F64Trans`
|
||||
--> $DIR/autodiff_illegal.rs:140:1
|
||||
|
|
||||
LL | #[autodiff_reverse(df18, Active, Active)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `F64Trans`
|
||||
|
||||
error[E0599]: the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
|
||||
--> $DIR/autodiff_illegal.rs:167:1
|
||||
|
|
||||
LL | struct DoesNotImplDefault;
|
||||
| ------------------------- doesn't satisfy `DoesNotImplDefault: Default`
|
||||
LL | #[autodiff_forward(df22, Dual)]
|
||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ function or associated item cannot be called on `(DoesNotImplDefault, DoesNotImplDefault)` due to unsatisfied trait bounds
|
||||
|
|
||||
= note: the following trait bounds were not satisfied:
|
||||
`DoesNotImplDefault: Default`
|
||||
which is required by `(DoesNotImplDefault, DoesNotImplDefault): Default`
|
||||
help: consider annotating `DoesNotImplDefault` with `#[derive(Default)]`
|
||||
|
|
||||
LL + #[derive(Default)]
|
||||
LL | struct DoesNotImplDefault;
|
||||
|
|
||||
|
||||
error: aborting due to 21 previous errors
|
||||
|
||||
Some errors have detailed explanations: E0428, E0433, E0599, E0658.
|
||||
Some errors have detailed explanations: E0428, E0658.
|
||||
For more information about an error, try `rustc --explain E0428`.
|
||||
|
|
|
|||
41
tests/ui/autodiff/incremental.rs
Normal file
41
tests/ui/autodiff/incremental.rs
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
//@ revisions: DEBUG RELEASE
|
||||
//@[RELEASE] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
|
||||
//@[DEBUG] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=0 -Clto=fat -C debuginfo=2
|
||||
//@ needs-enzyme
|
||||
//@ incremental
|
||||
//@ no-prefer-dynamic
|
||||
//@ build-pass
|
||||
#![crate_type = "bin"]
|
||||
#![feature(autodiff)]
|
||||
|
||||
// We used to use llvm's metadata to instruct enzyme how to differentiate a function.
|
||||
// In debug mode we would use incremental compilation which caused the metadata to be
|
||||
// dropped. We now use globals instead and add this test to verify that incremental
|
||||
// keeps working. Also testing debug mode while at it.
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(bar, Duplicated, Duplicated)]
|
||||
pub fn foo(r: &[f64; 10], res: &mut f64) {
|
||||
let mut output = [0.0; 10];
|
||||
output[0] = r[0];
|
||||
output[1] = r[1] * r[2];
|
||||
output[2] = r[4] * r[5];
|
||||
output[3] = r[2] * r[6];
|
||||
output[4] = r[1] * r[7];
|
||||
output[5] = r[2] * r[8];
|
||||
output[6] = r[1] * r[9];
|
||||
output[7] = r[5] * r[6];
|
||||
output[8] = r[5] * r[7];
|
||||
output[9] = r[4] * r[8];
|
||||
*res = output.iter().sum();
|
||||
}
|
||||
fn main() {
|
||||
let inputs = Box::new([3.1; 10]);
|
||||
let mut d_inputs = Box::new([0.0; 10]);
|
||||
let mut res = Box::new(0.0);
|
||||
let mut d_res = Box::new(1.0);
|
||||
|
||||
bar(&inputs, &mut d_inputs, &mut res, &mut d_res);
|
||||
dbg!(&d_inputs);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue