Update autodiff tests for the new intrinsics impl
This commit is contained in:
parent
e1d79b9aad
commit
cdd4118204
13 changed files with 153 additions and 222 deletions
|
|
@ -26,12 +26,13 @@
|
|||
|
||||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff;
|
||||
use std::autodiff::autodiff_forward;
|
||||
|
||||
// CHECK: ;
|
||||
#[no_mangle]
|
||||
//#[autodiff(d_square1, Forward, Dual, Dual)]
|
||||
#[autodiff(d_square2, Forward, 4, Dualv, Dualv)]
|
||||
#[autodiff(d_square3, Forward, 4, Dual, Dual)]
|
||||
#[autodiff_forward(d_square2, 4, Dualv, Dualv)]
|
||||
#[autodiff_forward(d_square3, 4, Dual, Dual)]
|
||||
fn square(x: &[f32], y: &mut [f32]) {
|
||||
assert!(x.len() >= 4);
|
||||
assert!(y.len() >= 5);
|
||||
|
|
@ -17,11 +17,12 @@ use std::autodiff::autodiff_forward;
|
|||
#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
|
||||
#[autodiff_forward(d_square1, 4, Dual, Dual)]
|
||||
#[no_mangle]
|
||||
#[inline(never)]
|
||||
fn square(x: &f32) -> f32 {
|
||||
x * x
|
||||
}
|
||||
|
||||
// d_sqaure2
|
||||
// d_square2
|
||||
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
|
||||
// CHECK-NEXT: start:
|
||||
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
|
||||
|
|
@ -32,24 +33,20 @@ fn square(x: &f32) -> f32 {
|
|||
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
|
||||
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
|
||||
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
|
||||
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
|
||||
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
|
||||
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
|
||||
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
|
||||
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
|
||||
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
|
||||
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
|
||||
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
|
||||
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
|
||||
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
|
||||
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
|
||||
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
|
||||
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
|
||||
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
|
||||
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
|
||||
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
|
||||
// CHECK-NEXT: ret [4 x float] %19
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00
|
||||
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
|
||||
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
|
||||
// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00
|
||||
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
|
||||
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
|
||||
// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00
|
||||
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
|
||||
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
|
||||
// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00
|
||||
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
|
||||
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
|
||||
// CHECK-NEXT: ret [4 x float] %15
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// d_square3, the extra float is the original return value (x * x)
|
||||
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
|
||||
|
|
@ -63,26 +60,22 @@ fn square(x: &f32) -> f32 {
|
|||
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
|
||||
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
|
||||
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
|
||||
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
|
||||
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
|
||||
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
|
||||
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
|
||||
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
|
||||
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
|
||||
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
|
||||
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
|
||||
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
|
||||
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
|
||||
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
|
||||
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
|
||||
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
|
||||
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
|
||||
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
|
||||
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
|
||||
// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
|
||||
// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
|
||||
// CHECK-NEXT: ret { float, [4 x float] } %21
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00
|
||||
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
|
||||
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
|
||||
// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00
|
||||
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
|
||||
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
|
||||
// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00
|
||||
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
|
||||
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
|
||||
// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00
|
||||
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
|
||||
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
|
||||
// CHECK-NEXT: %16 = insertvalue { float, [4 x float] } undef, float %_0, 0
|
||||
// CHECK-NEXT: %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1
|
||||
// CHECK-NEXT: ret { float, [4 x float] } %17
|
||||
// CHECK-NEXT: }
|
||||
|
||||
fn main() {
|
||||
let x = std::hint::black_box(3.0);
|
||||
|
|
|
|||
|
|
@ -6,19 +6,11 @@
|
|||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_square, Duplicated, Active)]
|
||||
#[inline(never)]
|
||||
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
|
||||
*x * *x
|
||||
}
|
||||
|
||||
// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
|
||||
//
|
||||
// CHECK: ; generic::square
|
||||
// CHECK-NEXT: ; Function Attrs:
|
||||
// CHECK-NEXT: define internal {{.*}} double
|
||||
// CHECK-NEXT: start:
|
||||
// CHECK-NOT: ret
|
||||
// CHECK: fmul double
|
||||
|
||||
// Ensure that `d_square::<f32>` code is generated
|
||||
//
|
||||
// CHECK: ; generic::square
|
||||
|
|
@ -28,6 +20,15 @@ fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
|
|||
// CHECK-NOT: ret
|
||||
// CHECK: fmul float
|
||||
|
||||
// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
|
||||
//
|
||||
// CHECK: ; generic::square
|
||||
// CHECK-NEXT: ; Function Attrs:
|
||||
// CHECK-NEXT: define internal {{.*}} double
|
||||
// CHECK-NEXT: start:
|
||||
// CHECK-NOT: ret
|
||||
// CHECK: fmul double
|
||||
|
||||
fn main() {
|
||||
let xf32: f32 = std::hint::black_box(3.0);
|
||||
let xf64: f64 = std::hint::black_box(3.0);
|
||||
|
|
|
|||
|
|
@ -14,25 +14,27 @@
|
|||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_square, Duplicated, Active)]
|
||||
#[inline(never)]
|
||||
fn square(x: &f64) -> f64 {
|
||||
x * x
|
||||
}
|
||||
|
||||
#[autodiff_reverse(d_square2, Duplicated, Active)]
|
||||
#[inline(never)]
|
||||
fn square2(x: &f64) -> f64 {
|
||||
x * x
|
||||
}
|
||||
|
||||
// CHECK:; identical_fnc::main
|
||||
// CHECK-NEXT:; Function Attrs:
|
||||
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
|
||||
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17h6009e4f751bf9407E()
|
||||
// CHECK-NEXT:start:
|
||||
// CHECK-NOT:br
|
||||
// CHECK-NOT:ret
|
||||
// CHECK:; call identical_fnc::d_square
|
||||
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
|
||||
// CHECK-NEXT:; call identical_fnc::d_square
|
||||
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)
|
||||
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx1)
|
||||
// CHECK:; call identical_fnc::d_square
|
||||
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx2)
|
||||
|
||||
fn main() {
|
||||
let x = std::hint::black_box(3.0);
|
||||
|
|
|
|||
|
|
@ -1,23 +0,0 @@
|
|||
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -Zautodiff=NoPostopt
|
||||
//@ no-prefer-dynamic
|
||||
//@ needs-enzyme
|
||||
|
||||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[autodiff_reverse(d_square, Duplicated, Active)]
|
||||
fn square(x: &f64) -> f64 {
|
||||
x * x
|
||||
}
|
||||
|
||||
// CHECK: ; inline::d_square
|
||||
// CHECK-NEXT: ; Function Attrs: alwaysinline
|
||||
// CHECK-NOT: noinline
|
||||
// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE
|
||||
fn main() {
|
||||
let x = std::hint::black_box(3.0);
|
||||
let mut dx1 = std::hint::black_box(1.0);
|
||||
let _ = d_square(&x, &mut dx1, 1.0);
|
||||
assert_eq!(dx1, 6.0);
|
||||
}
|
||||
|
|
@ -7,11 +7,12 @@ use std::autodiff::autodiff_reverse;
|
|||
|
||||
#[autodiff_reverse(d_square, Duplicated, Active)]
|
||||
#[no_mangle]
|
||||
#[inline(never)]
|
||||
fn square(x: &f64) -> f64 {
|
||||
x * x
|
||||
}
|
||||
|
||||
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
|
||||
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nonnull align 8 captures(none) %"x'")
|
||||
// CHECK-NEXT:invertstart:
|
||||
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
|
||||
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
|
||||
|
|
|
|||
|
|
@ -13,30 +13,30 @@ use std::autodiff::autodiff_reverse;
|
|||
|
||||
#[no_mangle]
|
||||
#[autodiff_reverse(df, Active, Active, Active)]
|
||||
#[inline(never)]
|
||||
fn primal(x: f32, y: f32) -> f64 {
|
||||
(x * x * y) as f64
|
||||
}
|
||||
|
||||
// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y)
|
||||
// CHECK-NEXT:start:
|
||||
// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y)
|
||||
// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0
|
||||
// CHECK-NEXT: store double %.elt, ptr %_0, align 8
|
||||
// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8
|
||||
// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1
|
||||
// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8
|
||||
// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12
|
||||
// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2
|
||||
// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK-NEXT:}
|
||||
// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y)
|
||||
// CHECK-NEXT: invertstart:
|
||||
// CHECK-NEXT: %_4 = fmul float %x, %x
|
||||
// CHECK-NEXT: %_3 = fmul float %_4, %y
|
||||
// CHECK-NEXT: %_0 = fpext float %_3 to double
|
||||
// CHECK-NEXT: %0 = fadd fast float %y, %y
|
||||
// CHECK-NEXT: %1 = fmul fast float %0, %x
|
||||
// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0
|
||||
// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1
|
||||
// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2
|
||||
// CHECK-NEXT: ret { double, float, float } %4
|
||||
// CHECK-NEXT: }
|
||||
|
||||
fn main() {
|
||||
let x = std::hint::black_box(3.0);
|
||||
let y = std::hint::black_box(2.5);
|
||||
let scalar = std::hint::black_box(1.0);
|
||||
let (r1, r2, r3) = df(x, y, scalar);
|
||||
// 3*3*1.5 = 22.5
|
||||
// 3*3*2.5 = 22.5
|
||||
assert_eq!(r1, 22.5);
|
||||
// 2*x*y = 2*3*2.5 = 15.0
|
||||
assert_eq!(r2, 15.0);
|
||||
|
|
|
|||
30
tests/codegen-llvm/autodiff/trait.rs
Normal file
30
tests/codegen-llvm/autodiff/trait.rs
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
|
||||
//@ no-prefer-dynamic
|
||||
//@ needs-enzyme
|
||||
|
||||
// Just check it does not crash for now
|
||||
// CHECK: ;
|
||||
#![feature(autodiff)]
|
||||
|
||||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
struct Foo {
|
||||
a: f64,
|
||||
}
|
||||
|
||||
trait MyTrait {
|
||||
fn f(&self, x: f64) -> f64;
|
||||
fn df(&self, x: f64, seed: f64) -> (f64, f64);
|
||||
}
|
||||
|
||||
impl MyTrait for Foo {
|
||||
#[autodiff_reverse(df, Const, Active, Active)]
|
||||
fn f(&self, x: f64) -> f64 {
|
||||
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let foo = Foo { a: 3.0f64 };
|
||||
dbg!(foo.df(1.0, 1.0));
|
||||
}
|
||||
|
|
@ -3,10 +3,10 @@
|
|||
//@ needs-enzyme
|
||||
|
||||
#![feature(autodiff)]
|
||||
#[prelude_import]
|
||||
use ::std::prelude::rust_2015::*;
|
||||
#[macro_use]
|
||||
extern crate std;
|
||||
#[prelude_import]
|
||||
use ::std::prelude::rust_2015::*;
|
||||
//@ pretty-mode:expanded
|
||||
//@ pretty-compare-only
|
||||
//@ pp-exact:autodiff_forward.pp
|
||||
|
|
@ -16,7 +16,6 @@ extern crate std;
|
|||
use std::autodiff::{autodiff_forward, autodiff_reverse};
|
||||
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f1(x: &[f64], y: f64) -> f64 {
|
||||
|
||||
|
||||
|
|
@ -36,163 +35,96 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
|
|||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
|
||||
#[inline(never)]
|
||||
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f1(x, y));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(<(f64, f64)>::default())
|
||||
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, bx_0, y))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f2(x: &[f64], y: f64) -> f64 {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
|
||||
#[inline(never)]
|
||||
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f2(x, y));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(f2(x, y))
|
||||
::core::intrinsics::autodiff(f2::<>, df2::<>, (x, bx_0, y))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f3(x: &[f64], y: f64) -> f64 {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
|
||||
#[inline(never)]
|
||||
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f3(x, y));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(f3(x, y))
|
||||
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, bx_0, y))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f4() {}
|
||||
#[rustc_autodiff(Forward, 1, None)]
|
||||
#[inline(never)]
|
||||
pub fn df4() -> () {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f4());
|
||||
::core::hint::black_box(());
|
||||
}
|
||||
pub fn df4() -> () { ::core::intrinsics::autodiff(f4::<>, df4::<>, ()) }
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f5(x: &[f64], y: f64) -> f64 {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
|
||||
#[inline(never)]
|
||||
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f5(x, y));
|
||||
::core::hint::black_box((by_0,));
|
||||
::core::hint::black_box(f5(x, y))
|
||||
::core::intrinsics::autodiff(f5::<>, df5_y::<>, (x, y, by_0))
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
|
||||
#[inline(never)]
|
||||
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f5(x, y));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(f5(x, y))
|
||||
::core::intrinsics::autodiff(f5::<>, df5_x::<>, (x, bx_0, y))
|
||||
}
|
||||
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
|
||||
#[inline(never)]
|
||||
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f5(x, y));
|
||||
::core::hint::black_box((dx_0, dret));
|
||||
::core::hint::black_box(f5(x, y))
|
||||
::core::intrinsics::autodiff(f5::<>, df5_rev::<>, (x, dx_0, y, dret))
|
||||
}
|
||||
struct DoesNotImplDefault;
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f6() -> DoesNotImplDefault {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Const)]
|
||||
#[inline(never)]
|
||||
pub fn df6() -> DoesNotImplDefault {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f6());
|
||||
::core::hint::black_box(());
|
||||
::core::hint::black_box(f6())
|
||||
::core::intrinsics::autodiff(f6::<>, df6::<>, ())
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f7(x: f32) -> () {}
|
||||
#[rustc_autodiff(Forward, 1, Const, None)]
|
||||
#[inline(never)]
|
||||
pub fn df7(x: f32) -> () {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f7(x));
|
||||
::core::hint::black_box(());
|
||||
::core::intrinsics::autodiff(f7::<>, df7::<>, (x,))
|
||||
}
|
||||
#[no_mangle]
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
|
||||
#[rustc_autodiff(Forward, 4, Dual, Dual)]
|
||||
#[inline(never)]
|
||||
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
|
||||
-> [f32; 5usize] {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f8(x));
|
||||
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
|
||||
::core::hint::black_box(<[f32; 5usize]>::default())
|
||||
::core::intrinsics::autodiff(f8::<>, f8_3::<>,
|
||||
(x, bx_0, bx_1, bx_2, bx_3))
|
||||
}
|
||||
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
|
||||
#[inline(never)]
|
||||
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
|
||||
-> [f32; 4usize] {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f8(x));
|
||||
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
|
||||
::core::hint::black_box(<[f32; 4usize]>::default())
|
||||
::core::intrinsics::autodiff(f8::<>, f8_2::<>,
|
||||
(x, bx_0, bx_1, bx_2, bx_3))
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
|
||||
#[inline(never)]
|
||||
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f8(x));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(<f32>::default())
|
||||
::core::intrinsics::autodiff(f8::<>, f8_1::<>, (x, bx_0))
|
||||
}
|
||||
pub fn f9() {
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
fn inner(x: f32) -> f32 { x * x }
|
||||
#[rustc_autodiff(Forward, 1, Dual, Dual)]
|
||||
#[inline(never)]
|
||||
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(inner(x));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(<(f32, f32)>::default())
|
||||
::core::intrinsics::autodiff(inner::<>, d_inner_2::<>, (x, bx_0))
|
||||
}
|
||||
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
|
||||
#[inline(never)]
|
||||
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(inner(x));
|
||||
::core::hint::black_box((bx_0,));
|
||||
::core::hint::black_box(<f32>::default())
|
||||
::core::intrinsics::autodiff(inner::<>, d_inner_1::<>, (x, bx_0))
|
||||
}
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
|
||||
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
|
||||
#[inline(never)]
|
||||
pub fn d_square<T: std::ops::Mul<Output = T> +
|
||||
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f10::<T>(x));
|
||||
::core::hint::black_box((dx_0, dret));
|
||||
::core::hint::black_box(f10::<T>(x))
|
||||
::core::intrinsics::autodiff(f10::<T>, d_square::<T>, (x, dx_0, dret))
|
||||
}
|
||||
fn main() {}
|
||||
|
|
|
|||
|
|
@ -3,10 +3,10 @@
|
|||
//@ needs-enzyme
|
||||
|
||||
#![feature(autodiff)]
|
||||
#[prelude_import]
|
||||
use ::std::prelude::rust_2015::*;
|
||||
#[macro_use]
|
||||
extern crate std;
|
||||
#[prelude_import]
|
||||
use ::std::prelude::rust_2015::*;
|
||||
//@ pretty-mode:expanded
|
||||
//@ pretty-compare-only
|
||||
//@ pp-exact:autodiff_reverse.pp
|
||||
|
|
@ -16,7 +16,6 @@ extern crate std;
|
|||
use std::autodiff::autodiff_reverse;
|
||||
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f1(x: &[f64], y: f64) -> f64 {
|
||||
|
||||
// Not the most interesting derivative, but who are we to judge
|
||||
|
|
@ -29,58 +28,33 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
|
|||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
|
||||
#[inline(never)]
|
||||
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f1(x, y));
|
||||
::core::hint::black_box((dx_0, dret));
|
||||
::core::hint::black_box(f1(x, y))
|
||||
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, dx_0, y, dret))
|
||||
}
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f2() {}
|
||||
#[rustc_autodiff(Reverse, 1, None)]
|
||||
#[inline(never)]
|
||||
pub fn df2() {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f2());
|
||||
::core::hint::black_box(());
|
||||
}
|
||||
pub fn df2() { ::core::intrinsics::autodiff(f2::<>, df2::<>, ()) }
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f3(x: &[f64], y: f64) -> f64 {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
|
||||
#[inline(never)]
|
||||
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f3(x, y));
|
||||
::core::hint::black_box((dx_0, dret));
|
||||
::core::hint::black_box(f3(x, y))
|
||||
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, dx_0, y, dret))
|
||||
}
|
||||
enum Foo { Reverse, }
|
||||
use Foo::Reverse;
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
|
||||
#[rustc_autodiff(Reverse, 1, Const, None)]
|
||||
#[inline(never)]
|
||||
pub fn df4(x: f32) {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f4(x));
|
||||
::core::hint::black_box(());
|
||||
}
|
||||
pub fn df4(x: f32) { ::core::intrinsics::autodiff(f4::<>, df4::<>, (x,)) }
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
pub fn f5(x: *const f32, y: &f32) {
|
||||
::core::panicking::panic("not implemented")
|
||||
}
|
||||
#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
|
||||
#[inline(never)]
|
||||
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(f5(x, y));
|
||||
::core::hint::black_box((dx_0, dy_0));
|
||||
::core::intrinsics::autodiff(f5::<>, df5::<>, (x, dx_0, y, dy_0))
|
||||
}
|
||||
fn main() {}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ pub fn f3(x: &[f64], y: f64) -> f64 {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
enum Foo { Reverse }
|
||||
enum Foo {
|
||||
Reverse,
|
||||
}
|
||||
use Foo::Reverse;
|
||||
// What happens if we already have Reverse in type (enum variant decl) and value (enum variant
|
||||
// constructor) namespace? > It's expected to work normally.
|
||||
|
|
|
|||
|
|
@ -3,10 +3,10 @@
|
|||
//@ needs-enzyme
|
||||
|
||||
#![feature(autodiff)]
|
||||
#[prelude_import]
|
||||
use ::std::prelude::rust_2015::*;
|
||||
#[macro_use]
|
||||
extern crate std;
|
||||
#[prelude_import]
|
||||
use ::std::prelude::rust_2015::*;
|
||||
//@ pretty-mode:expanded
|
||||
//@ pretty-compare-only
|
||||
//@ pp-exact:inherent_impl.pp
|
||||
|
|
@ -26,16 +26,12 @@ trait MyTrait {
|
|||
|
||||
impl MyTrait for Foo {
|
||||
#[rustc_autodiff]
|
||||
#[inline(never)]
|
||||
fn f(&self, x: f64) -> f64 {
|
||||
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
|
||||
}
|
||||
#[rustc_autodiff(Reverse, 1, Const, Active, Active)]
|
||||
#[inline(never)]
|
||||
fn df(&self, x: f64, dret: f64) -> (f64, f64) {
|
||||
unsafe { asm!("NOP", options(pure, nomem)); };
|
||||
::core::hint::black_box(self.f(x));
|
||||
::core::hint::black_box((dret,));
|
||||
::core::hint::black_box((self.f(x), f64::default()))
|
||||
::core::intrinsics::autodiff(Self::f::<>, Self::df::<>,
|
||||
(self, x, dret))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
22
tests/ui/autodiff/macro_hygiene.rs
Normal file
22
tests/ui/autodiff/macro_hygiene.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
|
||||
//@ no-prefer-dynamic
|
||||
//@ needs-enzyme
|
||||
//@ check-pass
|
||||
|
||||
// In the past, we just checked for correct macro hygiene information.
|
||||
|
||||
#![feature(autodiff)]
|
||||
|
||||
macro_rules! demo {
|
||||
() => {
|
||||
#[std::autodiff::autodiff_reverse(fd, Active, Active)]
|
||||
fn f(x: f64) -> f64 {
|
||||
x * x
|
||||
}
|
||||
};
|
||||
}
|
||||
demo!();
|
||||
|
||||
fn main() {
|
||||
dbg!(f(2.0f64));
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue