add autodiff examples

This commit is contained in:
Manuel Drehwald 2025-11-11 21:47:41 -05:00
parent a2dce774bc
commit f5892da3f2
5 changed files with 71 additions and 6 deletions

View file

@ -23,6 +23,7 @@ optimize_for_size = []
# Make `RefCell` store additional debugging information, which is printed out when
# a borrow error occurs
debug_refcell = []
llvm_enzyme = []
[lints.rust.unexpected_cfgs]
level = "warn"
@ -38,4 +39,6 @@ check-cfg = [
'cfg(target_has_reliable_f16_math)',
'cfg(target_has_reliable_f128)',
'cfg(target_has_reliable_f128_math)',
'cfg(llvm_enzyme)',
]

View file

@ -1511,13 +1511,43 @@ pub(crate) mod builtin {
/// If used on an input argument, a new shadow argument of the same type will be created,
/// directly following the original argument.
///
/// ### Usage examples:
///
/// ```rust,ignore (autodiff requires a -Z flag as well as fat-lto for testing)
/// #![feature(autodiff)]
/// use std::autodiff::*;
/// #[autodiff_forward(rb_fwd1, Dual, Const, Dual)]
/// #[autodiff_forward(rb_fwd2, Const, Dual, Dual)]
/// #[autodiff_forward(rb_fwd3, Dual, Dual, Dual)]
/// fn rosenbrock(x: f64, y: f64) -> f64 {
/// (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2)
/// }
/// #[autodiff_forward(rb_inp_fwd, Dual, Dual, Dual)]
/// fn rosenbrock_inp(x: f64, y: f64, out: &mut f64) {
/// *out = (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2);
/// }
///
/// fn main() {
/// let x0 = rosenbrock(1.0, 3.0); // 400.0
/// let (x1, dx1) = rb_fwd1(1.0, 1.0, 3.0); // (400.0, -800.0)
/// let (x2, dy1) = rb_fwd2(1.0, 3.0, 1.0); // (400.0, 400.0)
/// // When seeding both arguments at once, the returned tangent is the sum of
/// // both partial derivatives.
/// let (x3, dxy) = rb_fwd3(1.0, 1.0, 3.0, 1.0); // (400.0, -400.0)
///
/// let mut out = 0.0;
/// let mut dout = 0.0;
/// rb_inp_fwd(1.0, 1.0, 3.0, 1.0, &mut out, &mut dout);
/// // (out, dout) == (400.0, -400.0)
/// }
/// ```
///
/// We might want to track how one input float affects one or more output floats. In this case,
/// the shadow of one input should be initialized to `1.0`, while the shadows of the other
/// inputs should be initialized to `0.0`. The shadow of the output(s) should be initialized to
/// `0.0`. After calling the generated function, the shadow of the input will be zeroed,
/// while the shadow(s) of the output(s) will contain the derivatives. Forward mode is generally
/// more efficient if we have more output floats marked as `Dual` than input floats.
/// Related information can also be found unter the term "Vector-Jacobian product" (VJP).
/// Related information can also be found under the term "Jacobian-Vector product" (JVP).
#[unstable(feature = "autodiff", issue = "124509")]
#[allow_internal_unstable(rustc_attrs)]
#[allow_internal_unstable(core_intrinsics)]
@ -1552,19 +1582,45 @@ pub(crate) mod builtin {
/// `Const` should be used on non-float arguments, or float-based arguments as an optimization
/// if we are not interested in computing the derivatives with respect to this argument.
///
/// ### Usage examples:
///
/// ```rust,ignore (autodiff requires a -Z flag as well as fat-lto for testing)
/// #![feature(autodiff)]
/// use std::autodiff::*;
/// #[autodiff_reverse(rb_rev, Active, Active, Active)]
/// fn rosenbrock(x: f64, y: f64) -> f64 {
/// (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2)
/// }
/// #[autodiff_reverse(rb_inp_rev, Active, Active, Duplicated)]
/// fn rosenbrock_inp(x: f64, y: f64, out: &mut f64) {
/// *out = (1.0 - x).powi(2) + 100.0 * (y - x.powi(2)).powi(2);
/// }
///
/// fn main() {
/// let (output1, dx1, dy1) = rb_rev(1.0, 3.0, 1.0);
/// dbg!(output1, dx1, dy1); // (400.0, -800.0, 400.0)
/// let mut output2 = 0.0;
/// let mut seed = 1.0;
/// let (dx2, dy2) = rb_inp_rev(1.0, 3.0, &mut output2, &mut seed);
/// // (dx2, dy2, output2, seed) == (-800.0, 400.0, 400.0, 0.0)
/// }
/// ```
///
///
/// We often want to track how one or more input floats affect one output float. This output can
/// be a scalar return value, or a mutable reference or pointer argument. In this case, the
/// shadow of the input should be marked as duplicated and initialized to `0.0`. The shadow of
/// be a scalar return value, or a mutable reference or pointer argument. In the latter case, the
/// mutable input should be marked as duplicated and its shadow initialized to `0.0`. The shadow of
/// the output should be marked as active or duplicated and initialized to `1.0`. After calling
/// the generated function, the shadow(s) of the input(s) will contain the derivatives. If the
/// function has more than one output float marked as active or duplicated, users might want to
/// the generated function, the shadow(s) of the input(s) will contain the derivatives. The
/// shadow(s) of the output(s) (the "seed") will be reset to zero.
/// If the function has more than one output float marked as active or duplicated, users might want to
/// set one of them to `1.0` and the others to `0.0` to compute partial derivatives.
/// Unlike forward-mode, a call to the generated function does not reset the shadow of the
/// inputs.
/// Reverse mode is generally more efficient if we have more active/duplicated input than
/// output floats.
///
/// Related information can also be found unter the term "Jacobian-Vector Product" (JVP).
/// Related information can also be found under the term "Vector-Jacobian product" (VJP).
#[unstable(feature = "autodiff", issue = "124509")]
#[allow_internal_unstable(rustc_attrs)]
#[allow_internal_unstable(core_intrinsics)]

View file

@ -126,6 +126,7 @@ optimize_for_size = ["core/optimize_for_size", "alloc/optimize_for_size"]
# a borrow error occurs
debug_refcell = ["core/debug_refcell"]
llvm_enzyme = ["core/llvm_enzyme"]
# Enable std_detect features:
std_detect_file_io = ["std_detect/std_detect_file_io"]

View file

@ -35,3 +35,4 @@ profiler = ["dep:profiler_builtins"]
std_detect_file_io = ["std/std_detect_file_io"]
std_detect_dlsym_getauxval = ["std/std_detect_dlsym_getauxval"]
windows_raw_dylib = ["std/windows_raw_dylib"]
llvm_enzyme = ["std/llvm_enzyme"]

View file

@ -846,6 +846,10 @@ impl Build {
features.insert("compiler-builtins-mem");
}
if self.config.llvm_enzyme {
features.insert("llvm_enzyme");
}
features.into_iter().collect::<Vec<_>>().join(" ")
}