diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 36f84c5d74b2..1ff4fc6aaab2 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -305,6 +305,7 @@ mod llvm_enzyme { let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); let d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, + &generics, ); // The first element of it is the name of the function to be generated @@ -477,6 +478,7 @@ mod llvm_enzyme { new_decl_span: Span, idents: &[Ident], errored: bool, + generics: &Generics, ) -> (P, P, P, P) { let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); let noop = ast::InlineAsm { @@ -499,7 +501,7 @@ mod llvm_enzyme { }; let unsf_expr = ecx.expr_block(P(unsf_block)); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); - let primal_call = gen_primal_call(ecx, span, primal, idents); + let primal_call = gen_primal_call(ecx, span, primal, idents, generics); let black_box_primal_call = ecx.expr_call( new_decl_span, blackbox_call_expr.clone(), @@ -548,6 +550,7 @@ mod llvm_enzyme { sig_span: Span, idents: Vec, errored: bool, + generics: &Generics, ) -> P { let new_decl_span = d_sig.span; @@ -568,6 +571,7 @@ mod llvm_enzyme { new_decl_span, &idents, errored, + generics, ); if !has_ret(&d_sig.decl.output) { @@ -610,7 +614,6 @@ mod llvm_enzyme { panic!("Did not expect Default ret ty: {:?}", span); } }; - if x.mode.is_fwd() { // Fwd mode is easy. If the return activity is Const, we support arbitrary types. // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. @@ -670,8 +673,10 @@ mod llvm_enzyme { span: Span, primal: Ident, idents: &[Ident], + generics: &Generics, ) -> P { let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; + if has_self { let args: ThinVec<_> = idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); @@ -680,7 +685,51 @@ mod llvm_enzyme { } else { let args: ThinVec<_> = idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); - let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); + let mut primal_path = ecx.path_ident(span, primal); + + let is_generic = !generics.params.is_empty(); + + match (is_generic, primal_path.segments.last_mut()) { + (true, Some(function_path)) => { + let primal_generic_types = generics + .params + .iter() + .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. })); + + let generated_generic_types = primal_generic_types + .map(|type_param| { + let generic_param = TyKind::Path( + None, + ast::Path { + span, + segments: thin_vec![ast::PathSegment { + ident: type_param.ident, + args: None, + id: ast::DUMMY_NODE_ID, + }], + tokens: None, + }, + ); + + ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty { + id: type_param.id, + span, + kind: generic_param, + tokens: None, + }))) + }) + .collect(); + + function_path.args = + Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs { + span, + args: generated_generic_types, + }))); + } + _ => {} + } + + let primal_call_expr = ecx.expr_path(primal_path); ecx.expr_call(span, primal_call_expr, args) } } diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index 2fa122a618d6..8253603e8079 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -191,8 +191,8 @@ pub fn f10 + Copy>(x: &T) -> T { *x * *x } pub fn d_square + Copy>(x: &T, dx_0: &mut T, dret: T) -> T { unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f10(x)); + ::core::hint::black_box(f10::(x)); ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f10(x)) + ::core::hint::black_box(f10::(x)) } fn main() {}