Rollup merge of #140697 - Sa4dUs:split-autodiff, r=ZuseZ4
Split `autodiff` into `autodiff_forward` and `autodiff_reverse` This PR splits `#[autodiff]` macro so `#[autodiff(df, Reverse, args)]` would become `#[autodiff_reverse(df, args)]` and `#[autodiff(df, Forward, args)]` would become `#[autodiff_forwad(df, args)]`.
This commit is contained in:
commit
7f5f29b663
32 changed files with 234 additions and 217 deletions
|
|
@ -56,7 +56,6 @@ builtin_macros_assert_requires_expression = macro requires an expression as an a
|
|||
|
||||
builtin_macros_autodiff = autodiff must be applied to function
|
||||
builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
|
||||
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
|
||||
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
|
||||
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
|
||||
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
|
||||
|
|
|
|||
|
|
@ -86,27 +86,23 @@ mod llvm_enzyme {
|
|||
ecx: &mut ExtCtxt<'_>,
|
||||
meta_item: &ThinVec<MetaItemInner>,
|
||||
has_ret: bool,
|
||||
mode: DiffMode,
|
||||
) -> AutoDiffAttrs {
|
||||
let dcx = ecx.sess.dcx();
|
||||
let mode = name(&meta_item[1]);
|
||||
let Ok(mode) = DiffMode::from_str(&mode) else {
|
||||
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
|
||||
return AutoDiffAttrs::error();
|
||||
};
|
||||
|
||||
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
|
||||
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
|
||||
let mut first_activity = 2;
|
||||
let mut first_activity = 1;
|
||||
|
||||
let width = if let [_, _, x, ..] = &meta_item[..]
|
||||
let width = if let [_, x, ..] = &meta_item[..]
|
||||
&& let Some(x) = width(x)
|
||||
{
|
||||
first_activity = 3;
|
||||
first_activity = 2;
|
||||
match x.try_into() {
|
||||
Ok(x) => x,
|
||||
Err(_) => {
|
||||
dcx.emit_err(errors::AutoDiffInvalidWidth {
|
||||
span: meta_item[2].span(),
|
||||
span: meta_item[1].span(),
|
||||
width: x,
|
||||
});
|
||||
return AutoDiffAttrs::error();
|
||||
|
|
@ -165,6 +161,24 @@ mod llvm_enzyme {
|
|||
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||
}
|
||||
|
||||
pub(crate) fn expand_forward(
|
||||
ecx: &mut ExtCtxt<'_>,
|
||||
expand_span: Span,
|
||||
meta_item: &ast::MetaItem,
|
||||
item: Annotatable,
|
||||
) -> Vec<Annotatable> {
|
||||
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
|
||||
}
|
||||
|
||||
pub(crate) fn expand_reverse(
|
||||
ecx: &mut ExtCtxt<'_>,
|
||||
expand_span: Span,
|
||||
meta_item: &ast::MetaItem,
|
||||
item: Annotatable,
|
||||
) -> Vec<Annotatable> {
|
||||
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
|
||||
}
|
||||
|
||||
/// We expand the autodiff macro to generate a new placeholder function which passes
|
||||
/// type-checking and can be called by users. The function body of the placeholder function will
|
||||
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
|
||||
|
|
@ -198,11 +212,12 @@ mod llvm_enzyme {
|
|||
/// ```
|
||||
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
|
||||
/// in CI.
|
||||
pub(crate) fn expand(
|
||||
pub(crate) fn expand_with_mode(
|
||||
ecx: &mut ExtCtxt<'_>,
|
||||
expand_span: Span,
|
||||
meta_item: &ast::MetaItem,
|
||||
mut item: Annotatable,
|
||||
mode: DiffMode,
|
||||
) -> Vec<Annotatable> {
|
||||
if cfg!(not(llvm_enzyme)) {
|
||||
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
|
||||
|
|
@ -245,29 +260,41 @@ mod llvm_enzyme {
|
|||
// create TokenStream from vec elemtents:
|
||||
// meta_item doesn't have a .tokens field
|
||||
let mut ts: Vec<TokenTree> = vec![];
|
||||
if meta_item_vec.len() < 2 {
|
||||
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
|
||||
// input and output args.
|
||||
if meta_item_vec.len() < 1 {
|
||||
// At the bare minimum, we need a fnc name.
|
||||
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
||||
return vec![item];
|
||||
}
|
||||
|
||||
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
|
||||
let mode_symbol = match mode {
|
||||
DiffMode::Forward => sym::Forward,
|
||||
DiffMode::Reverse => sym::Reverse,
|
||||
_ => unreachable!("Unsupported mode: {:?}", mode),
|
||||
};
|
||||
|
||||
// Insert mode token
|
||||
let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
|
||||
ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
|
||||
ts.insert(
|
||||
1,
|
||||
TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
|
||||
);
|
||||
|
||||
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
|
||||
// If it is not given, we default to 1 (scalar mode).
|
||||
let start_position;
|
||||
let kind: LitKind = LitKind::Integer;
|
||||
let symbol;
|
||||
if meta_item_vec.len() >= 3
|
||||
&& let Some(width) = width(&meta_item_vec[2])
|
||||
if meta_item_vec.len() >= 2
|
||||
&& let Some(width) = width(&meta_item_vec[1])
|
||||
{
|
||||
start_position = 3;
|
||||
start_position = 2;
|
||||
symbol = Symbol::intern(&width.to_string());
|
||||
} else {
|
||||
start_position = 2;
|
||||
start_position = 1;
|
||||
symbol = sym::integer(1);
|
||||
}
|
||||
|
||||
let l: Lit = Lit { kind, symbol, suffix: None };
|
||||
let t = Token::new(TokenKind::Literal(l), Span::default());
|
||||
let comma = Token::new(TokenKind::Comma, Span::default());
|
||||
|
|
@ -289,7 +316,7 @@ mod llvm_enzyme {
|
|||
ts.pop();
|
||||
let ts: TokenStream = TokenStream::from_iter(ts);
|
||||
|
||||
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
|
||||
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
|
||||
if !x.is_active() {
|
||||
// We encountered an error, so we return the original item.
|
||||
// This allows us to potentially parse other attributes.
|
||||
|
|
@ -1017,4 +1044,4 @@ mod llvm_enzyme {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) use llvm_enzyme::expand;
|
||||
pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};
|
||||
|
|
|
|||
|
|
@ -180,14 +180,6 @@ mod autodiff {
|
|||
pub(crate) act: String,
|
||||
}
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(builtin_macros_autodiff_mode)]
|
||||
pub(crate) struct AutoDiffInvalidMode {
|
||||
#[primary_span]
|
||||
pub(crate) span: Span,
|
||||
pub(crate) mode: String,
|
||||
}
|
||||
|
||||
#[derive(Diagnostic)]
|
||||
#[diag(builtin_macros_autodiff_width)]
|
||||
pub(crate) struct AutoDiffInvalidWidth {
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@
|
|||
#![allow(internal_features)]
|
||||
#![allow(rustc::diagnostic_outside_of_impl)]
|
||||
#![allow(rustc::untranslatable_diagnostic)]
|
||||
#![cfg_attr(not(bootstrap), feature(autodiff))]
|
||||
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
|
||||
#![doc(rust_logo)]
|
||||
#![feature(assert_matches)]
|
||||
#![feature(autodiff)]
|
||||
#![feature(box_patterns)]
|
||||
#![feature(decl_macro)]
|
||||
#![feature(if_let_guard)]
|
||||
|
|
@ -112,7 +112,8 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
|
|||
|
||||
register_attr! {
|
||||
alloc_error_handler: alloc_error_handler::expand,
|
||||
autodiff: autodiff::expand,
|
||||
autodiff_forward: autodiff::expand_forward,
|
||||
autodiff_reverse: autodiff::expand_reverse,
|
||||
bench: test::expand_bench,
|
||||
cfg_accessible: cfg_accessible::Expander,
|
||||
cfg_eval: cfg_eval::expand,
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
|
|||
self.check_generic_attr(hir_id, attr, target, Target::Fn);
|
||||
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
|
||||
}
|
||||
[sym::autodiff, ..] => {
|
||||
[sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => {
|
||||
self.check_autodiff(hir_id, attr, span, target)
|
||||
}
|
||||
[sym::coroutine, ..] => {
|
||||
|
|
|
|||
|
|
@ -244,6 +244,7 @@ symbols! {
|
|||
FnMut,
|
||||
FnOnce,
|
||||
Formatter,
|
||||
Forward,
|
||||
From,
|
||||
FromIterator,
|
||||
FromResidual,
|
||||
|
|
@ -339,6 +340,7 @@ symbols! {
|
|||
Result,
|
||||
ResumeTy,
|
||||
Return,
|
||||
Reverse,
|
||||
Right,
|
||||
Rust,
|
||||
RustaceansAreAwesome,
|
||||
|
|
@ -522,7 +524,8 @@ symbols! {
|
|||
audit_that,
|
||||
augmented_assignments,
|
||||
auto_traits,
|
||||
autodiff,
|
||||
autodiff_forward,
|
||||
autodiff_reverse,
|
||||
automatically_derived,
|
||||
avx,
|
||||
avx10_target_feature,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue