feat: add generated parameters to generated function
- update pretty printing tests - only add generic parameters when function is actually generic (no empty turbofish)
This commit is contained in:
parent
8b3228233e
commit
e2b7278942
2 changed files with 54 additions and 5 deletions
|
|
@ -305,6 +305,7 @@ mod llvm_enzyme {
|
|||
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
|
||||
let d_body = gen_enzyme_body(
|
||||
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
|
||||
&generics,
|
||||
);
|
||||
|
||||
// The first element of it is the name of the function to be generated
|
||||
|
|
@ -477,6 +478,7 @@ mod llvm_enzyme {
|
|||
new_decl_span: Span,
|
||||
idents: &[Ident],
|
||||
errored: bool,
|
||||
generics: &Generics,
|
||||
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
|
||||
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
|
||||
let noop = ast::InlineAsm {
|
||||
|
|
@ -499,7 +501,7 @@ mod llvm_enzyme {
|
|||
};
|
||||
let unsf_expr = ecx.expr_block(P(unsf_block));
|
||||
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
|
||||
let primal_call = gen_primal_call(ecx, span, primal, idents);
|
||||
let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
|
||||
let black_box_primal_call = ecx.expr_call(
|
||||
new_decl_span,
|
||||
blackbox_call_expr.clone(),
|
||||
|
|
@ -548,6 +550,7 @@ mod llvm_enzyme {
|
|||
sig_span: Span,
|
||||
idents: Vec<Ident>,
|
||||
errored: bool,
|
||||
generics: &Generics,
|
||||
) -> P<ast::Block> {
|
||||
let new_decl_span = d_sig.span;
|
||||
|
||||
|
|
@ -568,6 +571,7 @@ mod llvm_enzyme {
|
|||
new_decl_span,
|
||||
&idents,
|
||||
errored,
|
||||
generics,
|
||||
);
|
||||
|
||||
if !has_ret(&d_sig.decl.output) {
|
||||
|
|
@ -610,7 +614,6 @@ mod llvm_enzyme {
|
|||
panic!("Did not expect Default ret ty: {:?}", span);
|
||||
}
|
||||
};
|
||||
|
||||
if x.mode.is_fwd() {
|
||||
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
|
||||
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
|
||||
|
|
@ -670,8 +673,10 @@ mod llvm_enzyme {
|
|||
span: Span,
|
||||
primal: Ident,
|
||||
idents: &[Ident],
|
||||
generics: &Generics,
|
||||
) -> P<ast::Expr> {
|
||||
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
|
||||
|
||||
if has_self {
|
||||
let args: ThinVec<_> =
|
||||
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
|
||||
|
|
@ -680,7 +685,51 @@ mod llvm_enzyme {
|
|||
} else {
|
||||
let args: ThinVec<_> =
|
||||
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
|
||||
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
|
||||
let mut primal_path = ecx.path_ident(span, primal);
|
||||
|
||||
let is_generic = !generics.params.is_empty();
|
||||
|
||||
match (is_generic, primal_path.segments.last_mut()) {
|
||||
(true, Some(function_path)) => {
|
||||
let primal_generic_types = generics
|
||||
.params
|
||||
.iter()
|
||||
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
|
||||
|
||||
let generated_generic_types = primal_generic_types
|
||||
.map(|type_param| {
|
||||
let generic_param = TyKind::Path(
|
||||
None,
|
||||
ast::Path {
|
||||
span,
|
||||
segments: thin_vec![ast::PathSegment {
|
||||
ident: type_param.ident,
|
||||
args: None,
|
||||
id: ast::DUMMY_NODE_ID,
|
||||
}],
|
||||
tokens: None,
|
||||
},
|
||||
);
|
||||
|
||||
ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
|
||||
id: type_param.id,
|
||||
span,
|
||||
kind: generic_param,
|
||||
tokens: None,
|
||||
})))
|
||||
})
|
||||
.collect();
|
||||
|
||||
function_path.args =
|
||||
Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
|
||||
span,
|
||||
args: generated_generic_types,
|
||||
})));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let primal_call_expr = ecx.expr_path(primal_path);
|
||||
ecx.expr_call(span, primal_call_expr, args)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -191,8 +191,8 @@ pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
|
|||
pub fn d_square<T: std::ops::Mul<Output = T> +
|
||||
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f10(x));
|
||||
::core::hint::black_box(f10::<T>(x));
|
||||
::core::hint::black_box((dx_0, dret));
|
||||
::core::hint::black_box(f10(x))
|
||||
::core::hint::black_box(f10::<T>(x))
|
||||
}
|
||||
fn main() {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue