Update autodiff tests for the new intrinsics impl

This commit is contained in:
Marcelo Domínguez 2025-08-14 15:42:14 +00:00
parent e1d79b9aad
commit cdd4118204
13 changed files with 153 additions and 222 deletions

View file

@ -26,12 +26,13 @@
#![feature(autodiff)]
use std::autodiff::autodiff;
use std::autodiff::autodiff_forward;
// CHECK: ;
#[no_mangle]
//#[autodiff(d_square1, Forward, Dual, Dual)]
#[autodiff(d_square2, Forward, 4, Dualv, Dualv)]
#[autodiff(d_square3, Forward, 4, Dual, Dual)]
#[autodiff_forward(d_square2, 4, Dualv, Dualv)]
#[autodiff_forward(d_square3, 4, Dual, Dual)]
fn square(x: &[f32], y: &mut [f32]) {
assert!(x.len() >= 4);
assert!(y.len() >= 5);

View file

@ -17,11 +17,12 @@ use std::autodiff::autodiff_forward;
// Primal under test: returns the square of the referenced value.
// The two `autodiff_forward` attributes name the derivative functions
// `d_square2` and `d_square1`, both with the literal `4` as second argument.
// The FileCheck expectations that follow match on `fwddiffe4square` and on
// `[4 x float]` values computed from `%x.0.val`, so the body below should not
// be reshaped without regenerating those expectations.
#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
#[autodiff_forward(d_square1, 4, Dual, Dual)]
#[no_mangle]
#[inline(never)]
fn square(x: &f32) -> f32 {
x * x
}
// d_square2
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
// CHECK-NEXT: start:
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
@ -32,24 +33,20 @@ fn square(x: &f32) -> f32 {
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
// CHECK-NEXT: ret [4 x float] %19
// CHECK-NEXT: }
// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
// CHECK-NEXT: ret [4 x float] %15
// CHECK-NEXT: }
// d_square3, the extra float is the original return value (x * x)
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
@ -63,26 +60,22 @@ fn square(x: &f32) -> f32 {
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
// CHECK-NEXT: ret { float, [4 x float] } %21
// CHECK-NEXT: }
// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
// CHECK-NEXT: %16 = insertvalue { float, [4 x float] } undef, float %_0, 0
// CHECK-NEXT: %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1
// CHECK-NEXT: ret { float, [4 x float] } %17
// CHECK-NEXT: }
fn main() {
let x = std::hint::black_box(3.0);

View file

@ -6,19 +6,11 @@
use std::autodiff::autodiff_reverse;
// Generic primal: multiplies the dereferenced value by itself.
// The expectations further down look for `fmul float` and `fmul double` inside
// functions labelled `generic::square`, i.e. one f32 and one f64 instantiation.
#[autodiff_reverse(d_square, Duplicated, Active)]
#[inline(never)]
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
*x * *x
}
// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
//
// CHECK: ; generic::square
// CHECK-NEXT: ; Function Attrs:
// CHECK-NEXT: define internal {{.*}} double
// CHECK-NEXT: start:
// CHECK-NOT: ret
// CHECK: fmul double
// Ensure that `d_square::<f32>` code is generated
//
// CHECK: ; generic::square
@ -28,6 +20,15 @@ fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
// CHECK-NOT: ret
// CHECK: fmul float
// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
//
// CHECK: ; generic::square
// CHECK-NEXT: ; Function Attrs:
// CHECK-NEXT: define internal {{.*}} double
// CHECK-NEXT: start:
// CHECK-NOT: ret
// CHECK: fmul double
fn main() {
let xf32: f32 = std::hint::black_box(3.0);
let xf64: f64 = std::hint::black_box(3.0);

View file

@ -14,25 +14,27 @@
use std::autodiff::autodiff_reverse;
// First of two primals that share the same body; the attribute names
// `d_square` as its reverse-mode derivative.
// The expectations below look for two calls to `d_square` inside
// `identical_fnc::main`, one passing `%dx1` and one passing `%dx2`.
#[autodiff_reverse(d_square, Duplicated, Active)]
#[inline(never)]
fn square(x: &f64) -> f64 {
x * x
}
// Same body as `square`, but with its own derivative name `d_square2`.
// NOTE(review): the expected IR calls `d_square` for both shadows, which
// suggests the derivative of this twin is expected to be merged with
// `d_square` — confirm against the test's intent before editing either body.
#[autodiff_reverse(d_square2, Duplicated, Active)]
#[inline(never)]
fn square2(x: &f64) -> f64 {
x * x
}
// CHECK:; identical_fnc::main
// CHECK-NEXT:; Function Attrs:
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17h6009e4f751bf9407E()
// CHECK-NEXT:start:
// CHECK-NOT:br
// CHECK-NOT:ret
// CHECK:; call identical_fnc::d_square
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
// CHECK-NEXT:; call identical_fnc::d_square
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx1)
// CHECK:; call identical_fnc::d_square
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx2)
fn main() {
let x = std::hint::black_box(3.0);

View file

@ -1,23 +0,0 @@
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -Zautodiff=NoPostopt
//@ no-prefer-dynamic
//@ needs-enzyme
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
// Primal for the inline test: squares the value behind the reference.
// The expectations below require the generated `inline::d_square` to carry the
// `alwaysinline` function attribute and forbid `noinline` from appearing.
#[autodiff_reverse(d_square, Duplicated, Active)]
fn square(x: &f64) -> f64 {
x * x
}
// CHECK: ; inline::d_square
// CHECK-NEXT: ; Function Attrs: alwaysinline
// CHECK-NOT: noinline
// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE
// Calls the generated `d_square` once so that it is present in the output.
// Inputs go through `black_box`, presumably so they are not constant-folded.
fn main() {
let x = std::hint::black_box(3.0);
// Shadow for `x`; passed by mutable reference to the derivative call.
let mut dx1 = std::hint::black_box(1.0);
// The trailing `1.0` is the extra argument of the derivative
// (presumably the seed for the return value — confirm).
let _ = d_square(&x, &mut dx1, 1.0);
// NOTE(review): `dx1` starts at 1.0; if the derivative accumulates into the
// shadow instead of overwriting it, the final value would not be 6.0 — confirm.
assert_eq!(dx1, 6.0);
}

View file

@ -7,11 +7,12 @@ use std::autodiff::autodiff_reverse;
// Primal for the scalar test: squares the value behind the reference.
// The expectation that follows matches the derivative symbol `@diffesquare`
// and the `fmul double %x.0.val, %x.0.val` produced from this body.
#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
#[inline(never)]
fn square(x: &f64) -> f64 {
x * x
}
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nonnull align 8 captures(none) %"x'")
// CHECK-NEXT:invertstart:
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val

View file

@ -13,30 +13,30 @@ use std::autodiff::autodiff_reverse;
// Primal for the sret test: computes x*x*y in f32 and widens the result to f64.
// Both inputs and the return value are marked `Active`; the expected IR has
// `@diffeprimal` return `{ double, float, float }`.
#[no_mangle]
#[autodiff_reverse(df, Active, Active, Active)]
#[inline(never)]
fn primal(x: f32, y: f32) -> f64 {
(x * x * y) as f64
}
// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y)
// CHECK-NEXT:start:
// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y)
// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0
// CHECK-NEXT: store double %.elt, ptr %_0, align 8
// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8
// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1
// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8
// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12
// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2
// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4
// CHECK-NEXT: ret void
// CHECK-NEXT:}
// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y)
// CHECK-NEXT: invertstart:
// CHECK-NEXT: %_4 = fmul float %x, %x
// CHECK-NEXT: %_3 = fmul float %_4, %y
// CHECK-NEXT: %_0 = fpext float %_3 to double
// CHECK-NEXT: %0 = fadd fast float %y, %y
// CHECK-NEXT: %1 = fmul fast float %0, %x
// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0
// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1
// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2
// CHECK-NEXT: ret { double, float, float } %4
// CHECK-NEXT: }
fn main() {
let x = std::hint::black_box(3.0);
let y = std::hint::black_box(2.5);
let scalar = std::hint::black_box(1.0);
let (r1, r2, r3) = df(x, y, scalar);
// x*x*y = 3*3*2.5 = 22.5
assert_eq!(r1, 22.5);
// 2*x*y = 2*3*2.5 = 15.0
assert_eq!(r2, 15.0);

View file

@ -0,0 +1,30 @@
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
// Just check it does not crash for now
// CHECK: ;
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
// Receiver type for the trait-method test; `a` is the factor read by `f`.
struct Foo {
a: f64,
}
// Declares the primal `f` together with its derivative `df`.
// The impl in this file writes only `f` by hand; `df` is the name given to the
// `autodiff_reverse` attribute placed on that `f`.
trait MyTrait {
fn f(&self, x: f64) -> f64;
fn df(&self, x: f64, seed: f64) -> (f64, f64);
}
// Implements `f` and attaches the reverse-mode attribute naming `df`.
// NOTE(review): the three activities presumably map, in order, to `&self`
// (`Const`), `x` (`Active`) and the return value (`Active`) — confirm.
impl MyTrait for Foo {
#[autodiff_reverse(df, Const, Active, Active)]
fn f(&self, x: f64) -> f64 {
// a/4 * (x^2 - 1 - 2*ln(x))
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
}
}
// Calls `df` on a `Foo` so that the derivative method is actually used.
fn main() {
let foo = Foo { a: 3.0f64 };
dbg!(foo.df(1.0, 1.0));
}

View file

@ -3,10 +3,10 @@
//@ needs-enzyme
#![feature(autodiff)]
#[prelude_import]
use ::std::prelude::rust_2015::*;
#[macro_use]
extern crate std;
#[prelude_import]
use ::std::prelude::rust_2015::*;
//@ pretty-mode:expanded
//@ pretty-compare-only
//@ pp-exact:autodiff_forward.pp
@ -16,7 +16,6 @@ extern crate std;
use std::autodiff::{autodiff_forward, autodiff_reverse};
#[rustc_autodiff]
#[inline(never)]
pub fn f1(x: &[f64], y: f64) -> f64 {
@ -36,163 +35,96 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
#[inline(never)]
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f1(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<(f64, f64)>::default())
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, bx_0, y))
}
#[rustc_autodiff]
#[inline(never)]
pub fn f2(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
#[inline(never)]
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f2(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(f2(x, y))
::core::intrinsics::autodiff(f2::<>, df2::<>, (x, bx_0, y))
}
#[rustc_autodiff]
#[inline(never)]
pub fn f3(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
#[inline(never)]
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f3(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(f3(x, y))
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, bx_0, y))
}
#[rustc_autodiff]
#[inline(never)]
pub fn f4() {}
#[rustc_autodiff(Forward, 1, None)]
#[inline(never)]
pub fn df4() -> () {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f4());
::core::hint::black_box(());
}
pub fn df4() -> () { ::core::intrinsics::autodiff(f4::<>, df4::<>, ()) }
#[rustc_autodiff]
#[inline(never)]
pub fn f5(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
#[inline(never)]
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((by_0,));
::core::hint::black_box(f5(x, y))
::core::intrinsics::autodiff(f5::<>, df5_y::<>, (x, y, by_0))
}
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
#[inline(never)]
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(f5(x, y))
::core::intrinsics::autodiff(f5::<>, df5_x::<>, (x, bx_0, y))
}
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
#[inline(never)]
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f5(x, y))
::core::intrinsics::autodiff(f5::<>, df5_rev::<>, (x, dx_0, y, dret))
}
struct DoesNotImplDefault;
#[rustc_autodiff]
#[inline(never)]
pub fn f6() -> DoesNotImplDefault {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Const)]
#[inline(never)]
pub fn df6() -> DoesNotImplDefault {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f6());
::core::hint::black_box(());
::core::hint::black_box(f6())
::core::intrinsics::autodiff(f6::<>, df6::<>, ())
}
#[rustc_autodiff]
#[inline(never)]
pub fn f7(x: f32) -> () {}
#[rustc_autodiff(Forward, 1, Const, None)]
#[inline(never)]
pub fn df7(x: f32) -> () {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f7(x));
::core::hint::black_box(());
::core::intrinsics::autodiff(f7::<>, df7::<>, (x,))
}
#[no_mangle]
#[rustc_autodiff]
#[inline(never)]
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
#[rustc_autodiff(Forward, 4, Dual, Dual)]
#[inline(never)]
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
-> [f32; 5usize] {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f8(x));
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
::core::hint::black_box(<[f32; 5usize]>::default())
::core::intrinsics::autodiff(f8::<>, f8_3::<>,
(x, bx_0, bx_1, bx_2, bx_3))
}
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
#[inline(never)]
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
-> [f32; 4usize] {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f8(x));
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
::core::hint::black_box(<[f32; 4usize]>::default())
::core::intrinsics::autodiff(f8::<>, f8_2::<>,
(x, bx_0, bx_1, bx_2, bx_3))
}
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
#[inline(never)]
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f8(x));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<f32>::default())
::core::intrinsics::autodiff(f8::<>, f8_1::<>, (x, bx_0))
}
pub fn f9() {
#[rustc_autodiff]
#[inline(never)]
fn inner(x: f32) -> f32 { x * x }
#[rustc_autodiff(Forward, 1, Dual, Dual)]
#[inline(never)]
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(inner(x));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<(f32, f32)>::default())
::core::intrinsics::autodiff(inner::<>, d_inner_2::<>, (x, bx_0))
}
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
#[inline(never)]
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(inner(x));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<f32>::default())
::core::intrinsics::autodiff(inner::<>, d_inner_1::<>, (x, bx_0))
}
}
#[rustc_autodiff]
#[inline(never)]
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
#[inline(never)]
pub fn d_square<T: std::ops::Mul<Output = T> +
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f10::<T>(x));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f10::<T>(x))
::core::intrinsics::autodiff(f10::<T>, d_square::<T>, (x, dx_0, dret))
}
fn main() {}

View file

@ -3,10 +3,10 @@
//@ needs-enzyme
#![feature(autodiff)]
#[prelude_import]
use ::std::prelude::rust_2015::*;
#[macro_use]
extern crate std;
#[prelude_import]
use ::std::prelude::rust_2015::*;
//@ pretty-mode:expanded
//@ pretty-compare-only
//@ pp-exact:autodiff_reverse.pp
@ -16,7 +16,6 @@ extern crate std;
use std::autodiff::autodiff_reverse;
#[rustc_autodiff]
#[inline(never)]
pub fn f1(x: &[f64], y: f64) -> f64 {
// Not the most interesting derivative, but who are we to judge
@ -29,58 +28,33 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
#[inline(never)]
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f1(x, y));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f1(x, y))
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, dx_0, y, dret))
}
#[rustc_autodiff]
#[inline(never)]
pub fn f2() {}
#[rustc_autodiff(Reverse, 1, None)]
#[inline(never)]
pub fn df2() {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f2());
::core::hint::black_box(());
}
pub fn df2() { ::core::intrinsics::autodiff(f2::<>, df2::<>, ()) }
#[rustc_autodiff]
#[inline(never)]
pub fn f3(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
#[inline(never)]
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f3(x, y));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f3(x, y))
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, dx_0, y, dret))
}
enum Foo { Reverse, }
use Foo::Reverse;
#[rustc_autodiff]
#[inline(never)]
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
#[rustc_autodiff(Reverse, 1, Const, None)]
#[inline(never)]
pub fn df4(x: f32) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f4(x));
::core::hint::black_box(());
}
pub fn df4(x: f32) { ::core::intrinsics::autodiff(f4::<>, df4::<>, (x,)) }
#[rustc_autodiff]
#[inline(never)]
pub fn f5(x: *const f32, y: &f32) {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
#[inline(never)]
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((dx_0, dy_0));
::core::intrinsics::autodiff(f5::<>, df5::<>, (x, dx_0, y, dy_0))
}
fn main() {}

View file

@ -23,7 +23,9 @@ pub fn f3(x: &[f64], y: f64) -> f64 {
unimplemented!()
}
enum Foo { Reverse }
enum Foo {
Reverse,
}
use Foo::Reverse;
// What happens if we already have Reverse in type (enum variant decl) and value (enum variant
// constructor) namespace? > It's expected to work normally.

View file

@ -3,10 +3,10 @@
//@ needs-enzyme
#![feature(autodiff)]
#[prelude_import]
use ::std::prelude::rust_2015::*;
#[macro_use]
extern crate std;
#[prelude_import]
use ::std::prelude::rust_2015::*;
//@ pretty-mode:expanded
//@ pretty-compare-only
//@ pp-exact:inherent_impl.pp
@ -26,16 +26,12 @@ trait MyTrait {
impl MyTrait for Foo {
#[rustc_autodiff]
#[inline(never)]
fn f(&self, x: f64) -> f64 {
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
}
#[rustc_autodiff(Reverse, 1, Const, Active, Active)]
#[inline(never)]
fn df(&self, x: f64, dret: f64) -> (f64, f64) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(self.f(x));
::core::hint::black_box((dret,));
::core::hint::black_box((self.f(x), f64::default()))
::core::intrinsics::autodiff(Self::f::<>, Self::df::<>,
(self, x, dret))
}
}

View file

@ -0,0 +1,22 @@
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
//@ check-pass
// In the past, we just checked for correct macro hygiene information.
#![feature(autodiff)]
// Expands to a function `f` annotated with `autodiff_reverse` through its full
// path, so the attribute is applied to an item that originates inside a macro.
// `fd` is the name given for the generated derivative.
macro_rules! demo {
() => {
#[std::autodiff::autodiff_reverse(fd, Active, Active)]
fn f(x: f64) -> f64 {
x * x
}
};
}
demo!();
// Only the primal `f` from the `demo!` expansion is called; `fd` is not used here.
fn main() {
dbg!(f(2.0f64));
}