feat: add test for generics in generated function

This commit is contained in:
HaeNoe 2025-04-19 19:28:33 +02:00
parent 56a0c7dfea
commit 8b3228233e
No known key found for this signature in database
2 changed files with 20 additions and 0 deletions

View file

@ -31,6 +31,8 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
// We want to make sure that we can use the macro for functions defined inside of functions
// Make sure we can handle generics
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
@ -181,4 +183,16 @@ pub fn f9() {
::core::hint::black_box(<f32>::default())
}
}
#[rustc_autodiff]
#[inline(never)]
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
#[inline(never)]
pub fn d_square<T: std::ops::Mul<Output = T> +
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f10(x));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f10(x))
}
fn main() {}

View file

@ -63,4 +63,10 @@ pub fn f9() {
}
}
// Make sure we can handle generics
#[autodiff(d_square, Reverse, Duplicated, Active)]
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
*x * *x
}
fn main() {}