add autodiff examples
This commit is contained in:
parent
a2dce774bc
commit
f5892da3f2
5 changed files with 71 additions and 6 deletions
|
|
@ -23,6 +23,7 @@ optimize_for_size = []
|
|||
# Make `RefCell` store additional debugging information, which is printed out when
|
||||
# a borrow error occurs
|
||||
debug_refcell = []
|
||||
llvm_enzyme = []
|
||||
|
||||
[lints.rust.unexpected_cfgs]
|
||||
level = "warn"
|
||||
|
|
@ -38,4 +39,6 @@ check-cfg = [
|
|||
'cfg(target_has_reliable_f16_math)',
|
||||
'cfg(target_has_reliable_f128)',
|
||||
'cfg(target_has_reliable_f128_math)',
|
||||
'cfg(llvm_enzyme)',
|
||||
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1511,13 +1511,43 @@ pub(crate) mod builtin {
|
|||
/// If used on an input argument, a new shadow argument of the same type will be created,
|
||||
/// directly following the original argument.
|
||||
///
|
||||
/// ### Usage examples:
|
||||
///
|
||||
/// ```rust,ignore (autodiff requires a -Z flag as well as fat-lto for testing)
|
||||
/// #![feature(autodiff)]
|
||||
/// use std::autodiff::*;
|
||||
/// #[autodiff_forward(rb_fwd1, Dual, Const, Dual)]
|
||||
/// #[autodiff_forward(rb_fwd2, Const, Dual, Dual)]
|
||||
/// #[autodiff_forward(rb_fwd3, Dual, Dual, Dual)]
|
||||
/// fn rosenbrock(x: f64, y: f64) -> f64 {
|
||||
/// (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2)
|
||||
/// }
|
||||
/// #[autodiff_forward(rb_inp_fwd, Dual, Dual, Dual)]
|
||||
/// fn rosenbrock_inp(x: f64, y: f64, out: &mut f64) {
|
||||
/// *out = (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2);
|
||||
/// }
|
||||
///
|
||||
/// fn main() {
|
||||
/// let x0 = rosenbrock(1.0, 3.0); // 400.0
|
||||
/// let (x1, dx1) = rb_fwd1(1.0, 1.0, 3.0); // (400.0, -800.0)
|
||||
/// let (x2, dy1) = rb_fwd2(1.0, 3.0, 1.0); // (400.0, 400.0)
|
||||
/// // When seeding both arguments at once the tangent return is the sum of both.
|
||||
/// let (x3, dxy) = rb_fwd3(1.0, 1.0, 3.0, 1.0); // (400.0, -400.0)
|
||||
///
|
||||
/// let mut out = 0.0;
|
||||
/// let mut dout = 0.0;
|
||||
/// rb_inp_fwd(1.0, 1.0, 3.0, 1.0, &mut out, &mut dout);
|
||||
/// // (out, dout) == (400.0, -400.0)
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// We might want to track how one input float affects one or more output floats. In this case,
|
||||
/// the shadow of one input should be initialized to `1.0`, while the shadows of the other
|
||||
/// inputs should be initialized to `0.0`. The shadow of the output(s) should be initialized to
|
||||
/// `0.0`. After calling the generated function, the shadow of the input will be zeroed,
|
||||
/// while the shadow(s) of the output(s) will contain the derivatives. Forward mode is generally
|
||||
/// more efficient if we have more output floats marked as `Dual` than input floats.
|
||||
/// Related information can also be found unter the term "Vector-Jacobian product" (VJP).
|
||||
/// Related information can also be found under the term "Vector-Jacobian product" (VJP).
|
||||
#[unstable(feature = "autodiff", issue = "124509")]
|
||||
#[allow_internal_unstable(rustc_attrs)]
|
||||
#[allow_internal_unstable(core_intrinsics)]
|
||||
|
|
@ -1552,19 +1582,45 @@ pub(crate) mod builtin {
|
|||
/// `Const` should be used on non-float arguments, or float-based arguments as an optimization
|
||||
/// if we are not interested in computing the derivatives with respect to this argument.
|
||||
///
|
||||
/// ### Usage examples:
|
||||
///
|
||||
/// ```rust,ignore (autodiff requires a -Z flag as well as fat-lto for testing)
|
||||
/// #![feature(autodiff)]
|
||||
/// use std::autodiff::*;
|
||||
/// #[autodiff_reverse(rb_rev, Active, Active, Active)]
|
||||
/// fn rosenbrock(x: f64, y: f64) -> f64 {
|
||||
/// (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2)
|
||||
/// }
|
||||
/// #[autodiff_reverse(rb_inp_rev, Active, Active, Duplicated)]
|
||||
/// fn rosenbrock_inp(x: f64, y: f64, out: &mut f64) {
|
||||
/// *out = (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2);
|
||||
/// }
|
||||
///
|
||||
/// fn main() {
|
||||
/// let (output1, dx1, dy1) = rb_rev(1.0, 3.0, 1.0);
|
||||
/// dbg!(output1, dx1, dy1); // (400.0, -800.0, 400.0)
|
||||
/// let mut output2 = 0.0;
|
||||
/// let mut seed = 1.0;
|
||||
/// let (dx2, dy2) = rb_inp_rev(1.0, 3.0, &mut output2, &mut seed);
|
||||
/// // (dx2, dy2, output2, seed) == (-800.0, 400.0, 400.0, 0.0)
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// We often want to track how one or more input floats affect one output float. This output can
|
||||
/// be a scalar return value, or a mutable reference or pointer argument. In this case, the
|
||||
/// shadow of the input should be marked as duplicated and initialized to `0.0`. The shadow of
|
||||
/// be a scalar return value, or a mutable reference or pointer argument. In the latter case, the
|
||||
/// mutable input should be marked as duplicated and its shadow initialized to `0.0`. The shadow of
|
||||
/// the output should be marked as active or duplicated and initialized to `1.0`. After calling
|
||||
/// the generated function, the shadow(s) of the input(s) will contain the derivatives. If the
|
||||
/// function has more than one output float marked as active or duplicated, users might want to
|
||||
/// the generated function, the shadow(s) of the input(s) will contain the derivatives. The
|
||||
/// shadow of the outputs ("seed") will be reset to zero.
|
||||
/// If the function has more than one output float marked as active or duplicated, users might want to
|
||||
/// set one of them to `1.0` and the others to `0.0` to compute partial derivatives.
|
||||
/// Unlike forward-mode, a call to the generated function does not reset the shadow of the
|
||||
/// inputs.
|
||||
/// Reverse mode is generally more efficient if we have more active/duplicated input than
|
||||
/// output floats.
|
||||
///
|
||||
/// Related information can also be found unter the term "Jacobian-Vector Product" (JVP).
|
||||
/// Related information can also be found under the term "Jacobian-Vector Product" (JVP).
|
||||
#[unstable(feature = "autodiff", issue = "124509")]
|
||||
#[allow_internal_unstable(rustc_attrs)]
|
||||
#[allow_internal_unstable(core_intrinsics)]
|
||||
|
|
|
|||
|
|
@ -126,6 +126,7 @@ optimize_for_size = ["core/optimize_for_size", "alloc/optimize_for_size"]
|
|||
# a borrow error occurs
|
||||
debug_refcell = ["core/debug_refcell"]
|
||||
|
||||
llvm_enzyme = ["core/llvm_enzyme"]
|
||||
|
||||
# Enable std_detect features:
|
||||
std_detect_file_io = ["std_detect/std_detect_file_io"]
|
||||
|
|
|
|||
|
|
@ -35,3 +35,4 @@ profiler = ["dep:profiler_builtins"]
|
|||
std_detect_file_io = ["std/std_detect_file_io"]
|
||||
std_detect_dlsym_getauxval = ["std/std_detect_dlsym_getauxval"]
|
||||
windows_raw_dylib = ["std/windows_raw_dylib"]
|
||||
llvm_enzyme = ["std/llvm_enzyme"]
|
||||
|
|
|
|||
|
|
@ -846,6 +846,10 @@ impl Build {
|
|||
features.insert("compiler-builtins-mem");
|
||||
}
|
||||
|
||||
if self.config.llvm_enzyme {
|
||||
features.insert("llvm_enzyme");
|
||||
}
|
||||
|
||||
features.into_iter().collect::<Vec<_>>().join(" ")
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue