feat: add generated parameters to generated function

- update pretty printing tests
- only add generic parameters when function is actually generic (no empty turbofish)
This commit is contained in:
HaeNoe 2025-04-20 01:10:50 +02:00
parent 8b3228233e
commit e2b7278942
No known key found for this signature in database
2 changed files with 54 additions and 5 deletions

View file

@ -305,6 +305,7 @@ mod llvm_enzyme {
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
let d_body = gen_enzyme_body(
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
&generics,
);
// The first element of it is the name of the function to be generated
@ -477,6 +478,7 @@ mod llvm_enzyme {
new_decl_span: Span,
idents: &[Ident],
errored: bool,
generics: &Generics,
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
let noop = ast::InlineAsm {
@ -499,7 +501,7 @@ mod llvm_enzyme {
};
let unsf_expr = ecx.expr_block(P(unsf_block));
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
let primal_call = gen_primal_call(ecx, span, primal, idents);
let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
let black_box_primal_call = ecx.expr_call(
new_decl_span,
blackbox_call_expr.clone(),
@ -548,6 +550,7 @@ mod llvm_enzyme {
sig_span: Span,
idents: Vec<Ident>,
errored: bool,
generics: &Generics,
) -> P<ast::Block> {
let new_decl_span = d_sig.span;
@ -568,6 +571,7 @@ mod llvm_enzyme {
new_decl_span,
&idents,
errored,
generics,
);
if !has_ret(&d_sig.decl.output) {
@ -610,7 +614,6 @@ mod llvm_enzyme {
panic!("Did not expect Default ret ty: {:?}", span);
}
};
if x.mode.is_fwd() {
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
@ -670,8 +673,10 @@ mod llvm_enzyme {
span: Span,
primal: Ident,
idents: &[Ident],
generics: &Generics,
) -> P<ast::Expr> {
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
if has_self {
let args: ThinVec<_> =
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
@ -680,7 +685,51 @@ mod llvm_enzyme {
} else {
let args: ThinVec<_> =
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
let mut primal_path = ecx.path_ident(span, primal);
let is_generic = !generics.params.is_empty();
match (is_generic, primal_path.segments.last_mut()) {
(true, Some(function_path)) => {
let primal_generic_types = generics
.params
.iter()
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
let generated_generic_types = primal_generic_types
.map(|type_param| {
let generic_param = TyKind::Path(
None,
ast::Path {
span,
segments: thin_vec![ast::PathSegment {
ident: type_param.ident,
args: None,
id: ast::DUMMY_NODE_ID,
}],
tokens: None,
},
);
ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
id: type_param.id,
span,
kind: generic_param,
tokens: None,
})))
})
.collect();
function_path.args =
Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
span,
args: generated_generic_types,
})));
}
_ => {}
}
let primal_call_expr = ecx.expr_path(primal_path);
ecx.expr_call(span, primal_call_expr, args)
}
}

View file

@ -191,8 +191,8 @@ pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
pub fn d_square<T: std::ops::Mul<Output = T> +
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f10(x));
::core::hint::black_box(f10::<T>(x));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f10(x))
::core::hint::black_box(f10::<T>(x))
}
fn main() {}