Auto merge of #145423 - Zalathar:rollup-9jtefpl, r=Zalathar

Rollup of 21 pull requests

Successful merges:

 - rust-lang/rust#118087 (Add Ref/RefMut try_map method)
 - rust-lang/rust#122661 (Change the desugaring of `assert!` for better error output)
 - rust-lang/rust#142640 (Implement autodiff using intrinsics)
 - rust-lang/rust#143075 (compiler: Allow `extern "interrupt" fn() -> !`)
 - rust-lang/rust#144865 (Fix tail calls to `#[track_caller]` functions)
 - rust-lang/rust#144944 (E0793: Clarify that it applies to unions as well)
 - rust-lang/rust#144947 (Fix description of unsigned `checked_exact_div`)
 - rust-lang/rust#145004 (Couple of minor cleanups)
 - rust-lang/rust#145005 (strip prefix of temporary file names when it exceeds filesystem name length limit)
 - rust-lang/rust#145012 (Tail call diagnostics to include lifetime info)
 - rust-lang/rust#145065 (resolve: Introduce `RibKind::Block`)
 - rust-lang/rust#145120 (llvm: Accept new LLVM lifetime format)
 - rust-lang/rust#145189 (Weekly `cargo update`)
 - rust-lang/rust#145235 (Minor `[const]` tweaks)
 - rust-lang/rust#145275 (fix(compiler/rustc_codegen_llvm): apply `target-cpu` attribute)
 - rust-lang/rust#145322 (Resolve the prelude import in `build_reduced_graph`)
 - rust-lang/rust#145331 (Make std use the edition 2024 prelude)
 - rust-lang/rust#145369 (Do not ICE on private type in field of unresolved struct)
 - rust-lang/rust#145378 (Add `FnContext` in parser for diagnostic)
 - rust-lang/rust#145389 ([rustdoc] Revert "rustdoc search: prefer stable items in search results")
 - rust-lang/rust#145392 (coverage: Remove intermediate data structures from mapping creation)

r? `@ghost`
`@rustbot` modify labels: rollup
This commit is contained in:
bors 2025-08-15 09:13:10 +00:00
commit ba412a6e70
153 changed files with 2286 additions and 1847 deletions

View file

@ -80,9 +80,9 @@ dependencies = [
[[package]]
name = "anstream"
version = "0.6.19"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933"
checksum = "3ae563653d1938f79b1ab1b5e668c87c76a9930414574a6583a7b7e11a8e6192"
dependencies = [
"anstyle",
"anstyle-parse",
@ -119,18 +119,18 @@ dependencies = [
[[package]]
name = "anstyle-query"
version = "1.1.3"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9"
checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
name = "anstyle-svg"
version = "0.1.9"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a43964079ef399480603125d5afae2b219aceffb77478956e25f17b9bc3435c"
checksum = "dc03a770ef506fe1396c0e476120ac0e6523cf14b74218dd5f18cd6833326fa9"
dependencies = [
"anstyle",
"anstyle-lossy",
@ -141,13 +141,13 @@ dependencies = [
[[package]]
name = "anstyle-wincon"
version = "3.0.9"
version = "3.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882"
checksum = "3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
@ -353,9 +353,9 @@ checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e"
[[package]]
name = "camino"
version = "1.1.10"
version = "1.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0da45bc31171d8d6960122e222a67740df867c1dd53b4d51caa297084c185cab"
checksum = "5d07aa9a93b00c76f71bc35d598bed923f6d4f3a9ca5c24b7737ae1a292841c0"
dependencies = [
"serde",
]
@ -518,9 +518,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.5.42"
version = "4.5.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed87a9d530bb41a67537289bafcac159cb3ee28460e0a4571123d2a778a6a882"
checksum = "50fd97c9dc2399518aa331917ac6f274280ec5eb34e555dd291899745c48ec6f"
dependencies = [
"clap_builder",
"clap_derive",
@ -538,9 +538,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.5.42"
version = "4.5.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64f4f3f3c77c94aff3c7e9aac9a2ca1974a5adf392a8bb751e827d6d127ab966"
checksum = "c35b5830294e1fa0462034af85cc95225a4cb07092c088c55bda3147cfcd8f65"
dependencies = [
"anstream",
"anstyle",
@ -937,9 +937,9 @@ dependencies = [
[[package]]
name = "cxx"
version = "1.0.161"
version = "1.0.166"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3523cc02ad831111491dd64b27ad999f1ae189986728e477604e61b81f828df"
checksum = "b5287274dfdf7e7eaa3d97d460eb2a94922539e6af214bda423f292105011ee2"
dependencies = [
"cc",
"cxxbridge-cmd",
@ -951,9 +951,9 @@ dependencies = [
[[package]]
name = "cxx-build"
version = "1.0.161"
version = "1.0.166"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "212b754247a6f07b10fa626628c157593f0abf640a3dd04cce2760eca970f909"
checksum = "65f3ce027a744135db10a1ebffa0863dab685aeef48f40a02c201f5e70c667d3"
dependencies = [
"cc",
"codespan-reporting",
@ -966,9 +966,9 @@ dependencies = [
[[package]]
name = "cxxbridge-cmd"
version = "1.0.161"
version = "1.0.166"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f426a20413ec2e742520ba6837c9324b55ffac24ead47491a6e29f933c5b135a"
checksum = "a07dc23f2eea4774297f4c9a17ae4065fecb63127da556e6c9fadb0216d93595"
dependencies = [
"clap",
"codespan-reporting",
@ -980,15 +980,15 @@ dependencies = [
[[package]]
name = "cxxbridge-flags"
version = "1.0.161"
version = "1.0.166"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a258b6069020b4e5da6415df94a50ee4f586a6c38b037a180e940a43d06a070d"
checksum = "f7a4dbad6171f763c4066c83dcd27546b6e93c5c5ae2229f9813bda7233f571d"
[[package]]
name = "cxxbridge-macro"
version = "1.0.161"
version = "1.0.166"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8dec184b52be5008d6eaf7e62fc1802caf1ad1227d11b3b7df2c409c7ffc3f4"
checksum = "a9be4b527950fc42db06163705e78e73eedc8fd723708e942afe3572a9a2c366"
dependencies = [
"indexmap",
"proc-macro2",
@ -1055,9 +1055,9 @@ version = "0.1.91"
[[package]]
name = "derive-where"
version = "1.5.0"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "510c292c8cf384b1a340b816a9a6cf2599eb8f566a44949024af88418000c50b"
checksum = "ef941ded77d15ca19b40374869ac6000af1c9f2a4c0f3d4c70926287e6364a8f"
dependencies = [
"proc-macro2",
"quote",
@ -1567,9 +1567,9 @@ dependencies = [
[[package]]
name = "hashbrown"
version = "0.15.4"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"allocator-api2",
"equivalent",
@ -1688,7 +1688,7 @@ dependencies = [
"potential_utf",
"yoke 0.8.0",
"zerofrom",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -1721,7 +1721,7 @@ dependencies = [
"litemap 0.8.0",
"tinystr 0.8.1",
"writeable 0.6.1",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -1769,7 +1769,7 @@ dependencies = [
"icu_properties",
"icu_provider 2.0.0",
"smallvec",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -1791,7 +1791,7 @@ dependencies = [
"icu_provider 2.0.0",
"potential_utf",
"zerotrie",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -1831,7 +1831,7 @@ dependencies = [
"yoke 0.8.0",
"zerofrom",
"zerotrie",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -1909,9 +1909,9 @@ dependencies = [
[[package]]
name = "indenter"
version = "0.3.3"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683"
checksum = "964de6e86d545b246d84badc0fef527924ace5134f30641c203ef52ba83f58d5"
[[package]]
name = "indexmap"
@ -2971,7 +2971,7 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585"
dependencies = [
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -3283,7 +3283,7 @@ dependencies = [
"regex",
"serde_json",
"similar",
"wasmparser 0.219.2",
"wasmparser 0.236.0",
]
[[package]]
@ -3623,6 +3623,7 @@ dependencies = [
"rustc_hir",
"rustc_incremental",
"rustc_index",
"rustc_lint_defs",
"rustc_macros",
"rustc_metadata",
"rustc_middle",
@ -4309,7 +4310,6 @@ name = "rustc_monomorphize"
version = "0.0.0"
dependencies = [
"rustc_abi",
"rustc_ast",
"rustc_data_structures",
"rustc_errors",
"rustc_fluent_macro",
@ -4318,7 +4318,6 @@ dependencies = [
"rustc_middle",
"rustc_session",
"rustc_span",
"rustc_symbol_mangling",
"rustc_target",
"serde",
"serde_json",
@ -4920,9 +4919,9 @@ dependencies = [
[[package]]
name = "rustversion"
version = "1.0.21"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d"
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
[[package]]
name = "ruzstd"
@ -4980,9 +4979,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "scratch"
version = "1.0.8"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f6280af86e5f559536da57a45ebc84948833b3bee313a7dd25232e09c878a52"
checksum = "d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2"
[[package]]
name = "self_cell"
@ -5508,7 +5507,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b"
dependencies = [
"displaydoc",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -6854,9 +6853,9 @@ dependencies = [
[[package]]
name = "zerovec"
version = "0.11.2"
version = "0.11.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428"
checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b"
dependencies = [
"yoke 0.8.0",
"zerofrom",

View file

@ -17,7 +17,7 @@ ast_passes_abi_must_not_have_parameters_or_return_type=
ast_passes_abi_must_not_have_return_type=
invalid signature for `extern {$abi}` function
.note = functions with the "custom" ABI cannot have a return type
.note = functions with the {$abi} ABI cannot have a return type
.help = remove the return type
ast_passes_assoc_const_without_body =

View file

@ -390,7 +390,13 @@ impl<'a> AstValidator<'a> {
if let InterruptKind::X86 = interrupt_kind {
// "x86-interrupt" is special because it does have arguments.
// FIXME(workingjubilee): properly lint on acceptable input types.
if let FnRetTy::Ty(ref ret_ty) = sig.decl.output {
if let FnRetTy::Ty(ref ret_ty) = sig.decl.output
&& match &ret_ty.kind {
TyKind::Never => false,
TyKind::Tup(tup) if tup.is_empty() => false,
_ => true,
}
{
self.dcx().emit_err(errors::AbiMustNotHaveReturnType {
span: ret_ty.span,
abi,
@ -449,7 +455,13 @@ impl<'a> AstValidator<'a> {
fn reject_params_or_return(&self, abi: ExternAbi, ident: &Ident, sig: &FnSig) {
let mut spans: Vec<_> = sig.decl.inputs.iter().map(|p| p.span).collect();
if let FnRetTy::Ty(ref ret_ty) = sig.decl.output {
if let FnRetTy::Ty(ref ret_ty) = sig.decl.output
&& match &ret_ty.kind {
TyKind::Never => false,
TyKind::Tup(tup) if tup.is_empty() => false,
_ => true,
}
{
spans.push(ret_ty.span);
}

View file

@ -1,8 +1,8 @@
mod context;
use rustc_ast::token::Delimiter;
use rustc_ast::token::{self, Delimiter};
use rustc_ast::tokenstream::{DelimSpan, TokenStream};
use rustc_ast::{DelimArgs, Expr, ExprKind, MacCall, Path, PathSegment, UnOp, token};
use rustc_ast::{DelimArgs, Expr, ExprKind, MacCall, Path, PathSegment};
use rustc_ast_pretty::pprust;
use rustc_errors::PResult;
use rustc_expand::base::{DummyResult, ExpandResult, ExtCtxt, MacEager, MacroExpanderResult};
@ -29,7 +29,7 @@ pub(crate) fn expand_assert<'cx>(
// `core::panic` and `std::panic` are different macros, so we use call-site
// context to pick up whichever is currently in scope.
let call_site_span = cx.with_call_site_ctxt(span);
let call_site_span = cx.with_call_site_ctxt(cond_expr.span);
let panic_path = || {
if use_panic_2021(span) {
@ -63,7 +63,7 @@ pub(crate) fn expand_assert<'cx>(
}),
})),
);
expr_if_not(cx, call_site_span, cond_expr, then, None)
assert_cond_check(cx, call_site_span, cond_expr, then)
}
// If `generic_assert` is enabled, generates rich captured outputs
//
@ -88,26 +88,33 @@ pub(crate) fn expand_assert<'cx>(
)),
)],
);
expr_if_not(cx, call_site_span, cond_expr, then, None)
assert_cond_check(cx, call_site_span, cond_expr, then)
};
ExpandResult::Ready(MacEager::expr(expr))
}
/// `assert!($cond_expr, $custom_message)`
struct Assert {
cond_expr: Box<Expr>,
custom_message: Option<TokenStream>,
}
// if !{ ... } { ... } else { ... }
fn expr_if_not(
cx: &ExtCtxt<'_>,
span: Span,
cond: Box<Expr>,
then: Box<Expr>,
els: Option<Box<Expr>>,
) -> Box<Expr> {
cx.expr_if(span, cx.expr(span, ExprKind::Unary(UnOp::Not, cond)), then, els)
/// `match <cond> { true => {} _ => <then> }`
fn assert_cond_check(cx: &ExtCtxt<'_>, span: Span, cond: Box<Expr>, then: Box<Expr>) -> Box<Expr> {
// Instead of expanding to `if !<cond> { <then> }`, we expand to
// `match <cond> { true => {} _ => <then> }`.
// This allows us to always complain about mismatched types instead of "cannot apply unary
// operator `!` to type `X`" when passing an invalid `<cond>`, while also allowing `<cond>` to
// be `&true`.
let els = cx.expr_block(cx.block(span, thin_vec![]));
let mut arms = thin_vec![];
arms.push(cx.arm(span, cx.pat_lit(span, cx.expr_bool(span, true)), els));
arms.push(cx.arm(span, cx.pat_wild(span), then));
// We wrap the `match` in a statement to limit the length of any borrows introduced in the
// condition.
cx.expr_block(cx.block(span, [cx.stmt_expr(cx.expr_match(span, cond, arms))].into()))
}
fn parse_assert<'a>(cx: &ExtCtxt<'a>, sp: Span, stream: TokenStream) -> PResult<'a, Assert> {

View file

@ -15,11 +15,12 @@ mod llvm_enzyme {
use rustc_ast::tokenstream::*;
use rustc_ast::visit::AssocCtxt::*;
use rustc_ast::{
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
MetaItemInner, PatKind, QSelf, TyKind, Visibility,
self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,
FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,
MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility,
};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::{Ident, Span, Symbol, kw, sym};
use rustc_span::{Ident, Span, Symbol, sym};
use thin_vec::{ThinVec, thin_vec};
use tracing::{debug, trace};
@ -179,11 +180,8 @@ mod llvm_enzyme {
}
/// We expand the autodiff macro to generate a new placeholder function which passes
/// type-checking and can be called by users. The function body of the placeholder function will
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
/// should just prevent early inlining and optimizations which alter the function signature.
/// The exact signature of the generated function depends on the configuration provided by the
/// user, but here is an example:
/// type-checking and can be called by users. The exact signature of the generated function
/// depends on the configuration provided by the user, but here is an example:
///
/// ```
/// #[autodiff(cos_box, Reverse, Duplicated, Active)]
@ -194,19 +192,12 @@ mod llvm_enzyme {
/// which becomes expanded to:
/// ```
/// #[rustc_autodiff]
/// #[inline(never)]
/// fn sin(x: &Box<f32>) -> f32 {
/// f32::sin(**x)
/// }
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
/// #[inline(never)]
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
/// unsafe {
/// asm!("NOP");
/// };
/// ::core::hint::black_box(sin(x));
/// ::core::hint::black_box((dx, dret));
/// ::core::hint::black_box(sin(x))
/// std::intrinsics::autodiff(sin::<>, cos_box::<>, (x, dx, dret))
/// }
/// ```
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@ -227,16 +218,24 @@ mod llvm_enzyme {
// first get information about the annotable item: visibility, signature, name and generic
// parameters.
// these will be used to generate the differentiated version of the function
let Some((vis, sig, primal, generics)) = (match &item {
Annotatable::Item(iitem) => extract_item_info(iitem),
let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item {
Annotatable::Item(iitem) => {
extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
}
Annotatable::Stmt(stmt) => match &stmt.kind {
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
ast::StmtKind::Item(iitem) => {
extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
}
_ => None,
},
Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
}
Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some((
assoc_item.vis.clone(),
sig.clone(),
ident.clone(),
generics.clone(),
*of_trait,
)),
_ => None,
},
_ => None,
@ -254,7 +253,6 @@ mod llvm_enzyme {
};
let has_ret = has_ret(&sig.decl.output);
let sig_span = ecx.with_call_site_ctxt(sig.span);
// create TokenStream from vec elemtents:
// meta_item doesn't have a .tokens field
@ -323,19 +321,23 @@ mod llvm_enzyme {
}
let span = ecx.with_def_site_ctxt(expand_span);
let n_active: u32 = x
.input_activity
.iter()
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
.count() as u32;
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
let d_body = gen_enzyme_body(
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
&generics,
let d_sig = gen_enzyme_decl(ecx, &sig, &x, span);
let d_body = ecx.block(
span,
thin_vec![call_autodiff(
ecx,
primal,
first_ident(&meta_item_vec[0]),
span,
&d_sig,
&generics,
impl_of_trait,
)],
);
// The first element of it is the name of the function to be generated
let asdf = Box::new(ast::Fn {
let d_fn = Box::new(ast::Fn {
defaultness: ast::Defaultness::Final,
sig: d_sig,
ident: first_ident(&meta_item_vec[0]),
@ -368,7 +370,7 @@ mod llvm_enzyme {
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
// We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
// We're avoid duplicating the attribute `#[rustc_autodiff]`.
fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
match (attr, item) {
(ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
@ -381,14 +383,16 @@ mod llvm_enzyme {
}
}
let mut has_inline_never = false;
// Don't add it multiple times:
let orig_annotatable: Annotatable = match item {
Annotatable::Item(ref mut iitem) => {
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
iitem.attrs.push(attr);
}
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
iitem.attrs.push(inline_never.clone());
if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
has_inline_never = true;
}
Annotatable::Item(iitem.clone())
}
@ -396,8 +400,8 @@ mod llvm_enzyme {
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
assoc_item.attrs.push(attr);
}
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
assoc_item.attrs.push(inline_never.clone());
if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
has_inline_never = true;
}
Annotatable::AssocItem(assoc_item.clone(), i)
}
@ -407,9 +411,8 @@ mod llvm_enzyme {
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
iitem.attrs.push(attr);
}
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
{
iitem.attrs.push(inline_never.clone());
if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
has_inline_never = true;
}
}
_ => unreachable!("stmt kind checked previously"),
@ -428,12 +431,21 @@ mod llvm_enzyme {
tokens: ts,
});
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
// If the source function has the `#[inline(never)]` attribute, we'll also add it to the diff function
let mut d_attrs = thin_vec![d_attr];
if has_inline_never {
d_attrs.push(inline_never);
}
let d_annotatable = match &item {
Annotatable::AssocItem(_, _) => {
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
let d_fn = Box::new(ast::AssocItem {
attrs: thin_vec![d_attr, inline_never],
attrs: d_attrs,
id: ast::DUMMY_NODE_ID,
span,
vis,
@ -443,13 +455,13 @@ mod llvm_enzyme {
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
}
Annotatable::Item(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
d_fn.vis = vis;
Annotatable::Item(d_fn)
}
Annotatable::Stmt(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
d_fn.vis = vis;
Annotatable::Stmt(Box::new(ast::Stmt {
@ -484,282 +496,95 @@ mod llvm_enzyme {
ty
}
// Will generate a body of the type:
// Generate `autodiff` intrinsic call
// ```
// {
// unsafe {
// asm!("NOP");
// }
// ::core::hint::black_box(primal(args));
// ::core::hint::black_box((args, ret));
// <This part remains to be done by following function>
// }
// std::intrinsics::autodiff(source, diff, (args))
// ```
fn init_body_helper(
fn call_autodiff(
ecx: &ExtCtxt<'_>,
span: Span,
primal: Ident,
new_names: &[String],
sig_span: Span,
new_decl_span: Span,
idents: &[Ident],
errored: bool,
diff: Ident,
span: Span,
d_sig: &FnSig,
generics: &Generics,
) -> (Box<ast::Block>, Box<ast::Expr>, Box<ast::Expr>, Box<ast::Expr>) {
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
let noop = ast::InlineAsm {
asm_macro: ast::AsmMacro::Asm,
template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
template_strs: Box::new([]),
operands: vec![],
clobber_abis: vec![],
options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
line_spans: vec![],
};
let noop_expr = ecx.expr_asm(span, Box::new(noop));
let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
let unsf_block = ast::Block {
stmts: thin_vec![ecx.stmt_semi(noop_expr)],
id: ast::DUMMY_NODE_ID,
tokens: None,
rules: unsf,
is_impl: bool,
) -> rustc_ast::Stmt {
let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);
let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);
let tuple_expr = ecx.expr_tuple(
span,
};
let unsf_expr = ecx.expr_block(Box::new(unsf_block));
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
let black_box_primal_call = ecx.expr_call(
new_decl_span,
blackbox_call_expr.clone(),
thin_vec![primal_call.clone()],
);
let tup_args = new_names
.iter()
.map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
.collect();
let black_box_remaining_args = ecx.expr_call(
sig_span,
blackbox_call_expr.clone(),
thin_vec![ecx.expr_tuple(sig_span, tup_args)],
d_sig
.decl
.inputs
.iter()
.map(|arg| match arg.pat.kind {
PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
_ => todo!(),
})
.collect::<ThinVec<_>>()
.into(),
);
let mut body = ecx.block(span, ThinVec::new());
body.stmts.push(ecx.stmt_semi(unsf_expr));
let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::autodiff]);
let enzyme_path = ecx.path(span, enzyme_path_idents);
let call_expr = ecx.expr_call(
span,
ecx.expr_path(enzyme_path),
vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
);
// This uses primal args which won't be available if we errored before
if !errored {
body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
}
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
(body, primal_call, black_box_primal_call, blackbox_call_expr)
ecx.stmt_expr(call_expr)
}
/// We only want this function to type-check, since we will replace the body
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
/// so instead we manually build something that should pass the type checker.
/// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
/// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
/// bug would ever try to accidentally differentiate this placeholder function body.
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
/// from optimizing any arguments away.
fn gen_enzyme_body(
// Generate turbofish expression from fn name and generics
// Given `foo` and `<A, B, C>` params, gen `foo::<A, B, C>`
// We use this expression when passing primal and diff function to the autodiff intrinsic
fn gen_turbofish_expr(
ecx: &ExtCtxt<'_>,
x: &AutoDiffAttrs,
n_active: u32,
sig: &ast::FnSig,
d_sig: &ast::FnSig,
primal: Ident,
new_names: &[String],
span: Span,
sig_span: Span,
idents: Vec<Ident>,
errored: bool,
ident: Ident,
generics: &Generics,
) -> Box<ast::Block> {
let new_decl_span = d_sig.span;
// Just adding some default inline-asm and black_box usages to prevent early inlining
// and optimizations which alter the function signature.
//
// The bb_primal_call is the black_box call of the primal function. We keep it around,
// since it has the convenient property of returning the type of the primal function,
// Remember, we only care to match types here.
// No matter which return we pick, we always wrap it into a std::hint::black_box call,
// to prevent rustc from propagating it into the caller.
let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
ecx,
span,
primal,
new_names,
sig_span,
new_decl_span,
&idents,
errored,
generics,
);
if !has_ret(&d_sig.decl.output) {
// there is no return type that we have to match, () works fine.
return body;
}
// Everything from here onwards just tries to fulfil the return type. Fun!
// having an active-only return means we'll drop the original return type.
// So that can be treated identical to not having one in the first place.
let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
if primal_ret && n_active == 0 && x.mode.is_rev() {
// We only have the primal ret.
body.stmts.push(ecx.stmt_expr(bb_primal_call));
return body;
}
if !primal_ret && n_active == 1 {
// Again no tuple return, so return default float val.
let ty = match d_sig.decl.output {
FnRetTy::Ty(ref ty) => ty.clone(),
FnRetTy::Default(span) => {
panic!("Did not expect Default ret ty: {:?}", span);
}
};
let arg = ty.kind.is_simple_path().unwrap();
let tmp = ecx.def_site_path(&[arg, kw::Default]);
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
body.stmts.push(ecx.stmt_expr(default_call_expr));
return body;
}
let mut exprs: Box<ast::Expr> = primal_call;
let d_ret_ty = match d_sig.decl.output {
FnRetTy::Ty(ref ty) => ty.clone(),
FnRetTy::Default(span) => {
panic!("Did not expect Default ret ty: {:?}", span);
}
};
if x.mode.is_fwd() {
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
// We checked that (on a best-effort base) in the preceding gen_enzyme_decl function.
// In all three cases, we can return `std::hint::black_box(<T>::default())`.
if x.ret_activity == DiffActivity::Const {
// Here we call the primal function, since our dummy function has the same return
// type due to the Const return activity.
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
} else {
let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 };
let y = ExprKind::Path(
Some(Box::new(q)),
ecx.path_ident(span, Ident::with_dummy_span(kw::Default)),
);
let default_call_expr = ecx.expr(span, y);
let default_call_expr =
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
}
} else if x.mode.is_rev() {
if x.width == 1 {
// We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
match d_ret_ty.kind {
TyKind::Tup(ref args) => {
// We have a tuple return type. We need to create a tuple of the same size
// and fill it with default values.
let mut exprs2 = thin_vec![exprs];
for arg in args.iter().skip(1) {
let arg = arg.kind.is_simple_path().unwrap();
let tmp = ecx.def_site_path(&[arg, kw::Default]);
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
let default_call_expr =
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
exprs2.push(default_call_expr);
}
exprs = ecx.expr_tuple(new_decl_span, exprs2);
}
_ => {
// Interestingly, even the `-> ArbitraryType` case
// ends up getting matched and handled correctly above,
// so we don't have to handle any other case for now.
panic!("Unsupported return type: {:?}", d_ret_ty);
}
}
}
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
} else {
unreachable!("Unsupported mode: {:?}", x.mode);
}
body.stmts.push(ecx.stmt_expr(exprs));
body
}
fn gen_primal_call(
ecx: &ExtCtxt<'_>,
span: Span,
primal: Ident,
idents: &[Ident],
generics: &Generics,
is_impl: bool,
) -> Box<ast::Expr> {
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
if has_self {
let args: ThinVec<_> =
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
let self_expr = ecx.expr_self(span);
ecx.expr_method_call(span, self_expr, primal, args)
} else {
let args: ThinVec<_> =
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
let mut primal_path = ecx.path_ident(span, primal);
let is_generic = !generics.params.is_empty();
match (is_generic, primal_path.segments.last_mut()) {
(true, Some(function_path)) => {
let primal_generic_types = generics
.params
.iter()
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
let generated_generic_types = primal_generic_types
.map(|type_param| {
let generic_param = TyKind::Path(
None,
ast::Path {
span,
segments: thin_vec![ast::PathSegment {
ident: type_param.ident,
args: None,
id: ast::DUMMY_NODE_ID,
}],
tokens: None,
},
);
ast::AngleBracketedArg::Arg(ast::GenericArg::Type(Box::new(ast::Ty {
id: type_param.id,
span,
kind: generic_param,
tokens: None,
})))
})
.collect();
function_path.args =
Some(Box::new(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
span,
args: generated_generic_types,
})));
let generic_args = generics
.params
.iter()
.filter_map(|p| match &p.kind {
GenericParamKind::Type { .. } => {
let path = ast::Path::from_ident(p.ident);
let ty = ecx.ty_path(path);
Some(AngleBracketedArg::Arg(GenericArg::Type(ty)))
}
_ => {}
}
GenericParamKind::Const { .. } => {
let expr = ecx.expr_path(ast::Path::from_ident(p.ident));
let anon_const = AnonConst { id: ast::DUMMY_NODE_ID, value: expr };
Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))
}
GenericParamKind::Lifetime { .. } => None,
})
.collect::<ThinVec<_>>();
let primal_call_expr = ecx.expr_path(primal_path);
ecx.expr_call(span, primal_call_expr, args)
}
let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args };
let segment = PathSegment {
ident,
id: ast::DUMMY_NODE_ID,
args: Some(Box::new(GenericArgs::AngleBracketed(args))),
};
let segments = if is_impl {
thin_vec![
PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None },
segment,
]
} else {
thin_vec![segment]
};
let path = Path { span, segments, tokens: None };
ecx.expr_path(path)
}
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
@ -778,7 +603,7 @@ mod llvm_enzyme {
sig: &ast::FnSig,
x: &AutoDiffAttrs,
span: Span,
) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
) -> ast::FnSig {
let dcx = ecx.sess.dcx();
let has_ret = has_ret(&sig.decl.output);
let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
@ -790,7 +615,7 @@ mod llvm_enzyme {
found: num_activities,
});
// This is not the right signature, but we can continue parsing.
return (sig.clone(), vec![], vec![], true);
return sig.clone();
}
assert!(sig.decl.inputs.len() == x.input_activity.len());
assert!(has_ret == x.has_ret_activity());
@ -833,7 +658,7 @@ mod llvm_enzyme {
if errors {
// This is not the right signature, but we can continue parsing.
return (sig.clone(), new_inputs, idents, true);
return sig.clone();
}
let unsafe_activities = x
@ -1047,7 +872,7 @@ mod llvm_enzyme {
}
let d_sig = FnSig { header: d_header, decl: d_decl, span };
trace!("Generated signature: {:?}", d_sig);
(d_sig, new_inputs, idents, false)
d_sig
}
}

View file

@ -93,7 +93,6 @@ use gccjit::{CType, Context, OptimizationLevel};
#[cfg(feature = "master")]
use gccjit::{TargetInfo, Version};
use rustc_ast::expand::allocator::AllocatorKind;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_codegen_ssa::back::lto::{SerializedModule, ThinModule};
use rustc_codegen_ssa::back::write::{
CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryFn,
@ -363,12 +362,7 @@ impl WriteBackendMethods for GccCodegenBackend {
_exported_symbols_for_lto: &[String],
each_linked_rlib_for_lto: &[PathBuf],
modules: Vec<FatLtoInput<Self>>,
diff_functions: Vec<AutoDiffItem>,
) -> Result<ModuleCodegen<Self::Module>, FatalError> {
if !diff_functions.is_empty() {
unimplemented!();
}
back::lto::run_fat(cgcx, each_linked_rlib_for_lto, modules)
}

View file

@ -8,11 +8,12 @@ use rustc_middle::bug;
use rustc_middle::ty::TyCtxt;
use rustc_session::config::{DebugInfo, OomStrategy};
use rustc_symbol_mangling::mangle_internal_symbol;
use smallvec::SmallVec;
use crate::builder::SBuilder;
use crate::declare::declare_simple_fn;
use crate::llvm::{self, False, True, Type, Value};
use crate::{SimpleCx, attributes, debuginfo};
use crate::{SimpleCx, attributes, debuginfo, llvm_util};
pub(crate) unsafe fn codegen(
tcx: TyCtxt<'_>,
@ -147,6 +148,20 @@ fn create_wrapper_function(
llvm::Visibility::from_generic(tcx.sess.default_visibility()),
ty,
);
let mut attrs = SmallVec::<[_; 2]>::new();
let target_cpu = llvm_util::target_cpu(tcx.sess);
let target_cpu_attr = llvm::CreateAttrStringValue(cx.llcx, "target-cpu", target_cpu);
let tune_cpu_attr = llvm_util::tune_cpu(tcx.sess)
.map(|tune_cpu| llvm::CreateAttrStringValue(cx.llcx, "tune-cpu", tune_cpu));
attrs.push(target_cpu_attr);
attrs.extend(tune_cpu_attr);
attributes::apply_to_llfn(llfn, llvm::AttributePlace::Function, &attrs);
let no_return = if no_return {
// -> ! DIFlagNoReturn
let no_return = llvm::AttributeKind::NoReturn.create_attr(cx.llcx);

View file

@ -28,22 +28,6 @@ pub(crate) fn apply_to_callsite(callsite: &Value, idx: AttributePlace, attrs: &[
}
}
pub(crate) fn has_attr(llfn: &Value, idx: AttributePlace, attr: AttributeKind) -> bool {
llvm::HasAttributeAtIndex(llfn, idx, attr)
}
pub(crate) fn has_string_attr(llfn: &Value, name: &str) -> bool {
llvm::HasStringAttribute(llfn, name)
}
pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: AttributeKind) {
llvm::RemoveRustEnumAttributeAtIndex(llfn, place, kind);
}
pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: &str) {
llvm::RemoveStringAttrFromFn(llfn, name);
}
/// Get LLVM attribute for the provided inline heuristic.
#[inline]
fn inline_attr<'ll>(cx: &CodegenCx<'ll, '_>, inline: InlineAttr) -> Option<&'ll Attribute> {

View file

@ -24,9 +24,8 @@ use crate::back::write::{
self, CodegenDiagnosticsStage, DiagnosticHandlers, bitcode_section_name, save_temp_bitcode,
};
use crate::errors::{LlvmError, LtoBitcodeFromRlib};
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, build_string};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx};
/// We keep track of the computed LTO cache keys from the previous
/// session to determine which CGUs we can reuse.
@ -593,31 +592,6 @@ pub(crate) fn run_pass_manager(
}
if cfg!(llvm_enzyme) && enable_ad && !thin {
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
for function in cx.get_functions() {
let enzyme_marker = "enzyme_marker";
if attributes::has_string_attr(function, enzyme_marker) {
// Sanity check: Ensure 'noinline' is present before replacing it.
assert!(
attributes::has_attr(function, Function, llvm::AttributeKind::NoInline),
"Expected __enzyme function to have 'noinline' before adding 'alwaysinline'"
);
attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline);
attributes::remove_string_attr_from_llfn(function, enzyme_marker);
assert!(
!attributes::has_string_attr(function, enzyme_marker),
"Expected function to not have 'enzyme_marker'"
);
let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx);
attributes::apply_to_llfn(function, Function, &[always_inline]);
}
}
let opt_stage = llvm::OptStage::FatLTO;
let stage = write::AutodiffStage::PostAD;
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {

View file

@ -862,7 +862,7 @@ pub(crate) fn codegen(
.generic_activity_with_arg("LLVM_module_codegen_embed_bitcode", &*module.name);
let thin_bc =
module.thin_lto_buffer.as_deref().expect("cannot find embedded bitcode");
embed_bitcode(cgcx, llcx, llmod, &config.bc_cmdline, &thin_bc);
embed_bitcode(cgcx, llcx, llmod, &thin_bc);
}
}
@ -1058,7 +1058,6 @@ fn embed_bitcode(
cgcx: &CodegenContext<LlvmCodegenBackend>,
llcx: &llvm::Context,
llmod: &llvm::Module,
cmdline: &str,
bitcode: &[u8],
) {
// We're adding custom sections to the output object file, but we definitely
@ -1074,7 +1073,9 @@ fn embed_bitcode(
// * Mach-O - this is for macOS. Inspecting the source code for the native
// linker here shows that the `.llvmbc` and `.llvmcmd` sections are
// automatically skipped by the linker. In that case there's nothing extra
// that we need to do here.
// that we need to do here. We do need to make sure that the
// `__LLVM,__cmdline` section exists even though it is empty as otherwise
// ld64 rejects the object file.
//
// * Wasm - the native LLD linker is hard-coded to skip `.llvmbc` and
// `.llvmcmd` sections, so there's nothing extra we need to do.
@ -1111,7 +1112,7 @@ fn embed_bitcode(
llvm::set_linkage(llglobal, llvm::Linkage::PrivateLinkage);
llvm::LLVMSetGlobalConstant(llglobal, llvm::True);
let llconst = common::bytes_in_context(llcx, cmdline.as_bytes());
let llconst = common::bytes_in_context(llcx, &[]);
let llglobal = llvm::add_global(llmod, common::val_ty(llconst), c"rustc.embedded.cmdline");
llvm::set_initializer(llglobal, llconst);
let section = if cgcx.target_is_like_darwin {
@ -1128,7 +1129,7 @@ fn embed_bitcode(
let section_flags = if cgcx.is_pe_coff { "n" } else { "e" };
let asm = create_section_with_flags_asm(".llvmbc", section_flags, bitcode);
llvm::append_module_inline_asm(llmod, &asm);
let asm = create_section_with_flags_asm(".llvmcmd", section_flags, cmdline.as_bytes());
let asm = create_section_with_flags_asm(".llvmcmd", section_flags, &[]);
llvm::append_module_inline_asm(llmod, &asm);
}
}

View file

@ -1696,7 +1696,11 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
return;
}
self.call_intrinsic(intrinsic, &[self.val_ty(ptr)], &[self.cx.const_u64(size), ptr]);
if crate::llvm_util::get_version() >= (22, 0, 0) {
self.call_intrinsic(intrinsic, &[self.val_ty(ptr)], &[ptr]);
} else {
self.call_intrinsic(intrinsic, &[self.val_ty(ptr)], &[self.cx.const_u64(size), ptr]);
}
}
}
impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {

View file

@ -1,40 +1,92 @@
use std::ptr;
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
use rustc_codegen_ssa::ModuleCodegen;
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
use rustc_errors::FatalError;
use rustc_middle::bug;
use tracing::{debug, trace};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
use rustc_middle::{bug, ty};
use tracing::debug;
use crate::back::write::llvm_err;
use crate::builder::{SBuilder, UNNAMED};
use crate::builder::{Builder, PlaceRef, UNNAMED};
use crate::context::SimpleCx;
use crate::declare::declare_simple_fn;
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
use crate::llvm::AttributePlace::Function;
use crate::llvm::{Metadata, True};
use crate::llvm;
use crate::llvm::{Metadata, True, Type};
use crate::value::Value;
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
fn get_params(fnc: &Value) -> Vec<&Value> {
let param_num = llvm::LLVMCountParams(fnc) as usize;
let mut fnc_args: Vec<&Value> = vec![];
fnc_args.reserve(param_num);
unsafe {
llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr());
fnc_args.set_len(param_num);
pub(crate) fn adjust_activity_to_abi<'tcx>(
tcx: TyCtxt<'tcx>,
fn_ty: Ty<'tcx>,
da: &mut Vec<DiffActivity>,
) {
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
bug!("expected fn def for autodiff, got {:?}", fn_ty);
}
fnc_args
}
fn has_sret(fnc: &Value) -> bool {
let num_args = llvm::LLVMCountParams(fnc) as usize;
if num_args == 0 {
false
} else {
unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
// We don't actually pass the types back into the type system.
// All we do is decide how to handle the arguments.
let sig = fn_ty.fn_sig(tcx).skip_binder();
let mut new_activities = vec![];
let mut new_positions = vec![];
for (i, ty) in sig.inputs().iter().enumerate() {
if let Some(inner_ty) = ty.builtin_deref(true) {
if inner_ty.is_slice() {
// Now we need to figure out the size of each slice element in memory to allow
// safety checks and usability improvements in the backend.
let sty = match inner_ty.builtin_index() {
Some(sty) => sty,
None => {
panic!("slice element type unknown");
}
};
let pci = PseudoCanonicalInput {
typing_env: TypingEnv::fully_monomorphized(),
value: sty,
};
let layout = tcx.layout_of(pci);
let elem_size = match layout {
Ok(layout) => layout.size,
Err(_) => {
bug!("autodiff failed to compute slice element size");
}
};
let elem_size: u32 = elem_size.bytes() as u32;
// We know that the length will be passed as extra arg.
if !da.is_empty() {
// We are looking at a slice. The length of that slice will become an
// extra integer on llvm level. Integers are always const.
// However, if the slice get's duplicated, we want to know to later check the
// size. So we mark the new size argument as FakeActivitySize.
// There is one FakeActivitySize per slice, so for convenience we store the
// slice element size in bytes in it. We will use the size in the backend.
let activity = match da[i] {
DiffActivity::DualOnly
| DiffActivity::Dual
| DiffActivity::Dualv
| DiffActivity::DuplicatedOnly
| DiffActivity::Duplicated => {
DiffActivity::FakeActivitySize(Some(elem_size))
}
DiffActivity::Const => DiffActivity::Const,
_ => bug!("unexpected activity for ptr/ref"),
};
new_activities.push(activity);
new_positions.push(i + 1);
}
continue;
}
}
}
// now add the extra activities coming from slices
// Reverse order to not invalidate the indices
for _ in 0..new_activities.len() {
let pos = new_positions.pop().unwrap();
let activity = new_activities.pop().unwrap();
da.insert(pos, activity);
}
}
@ -48,14 +100,13 @@ fn has_sret(fnc: &Value) -> bool {
// need to match those.
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
// using iterators and peek()?
fn match_args_from_caller_to_enzyme<'ll>(
fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
cx: &SimpleCx<'ll>,
builder: &SBuilder<'ll, 'll>,
builder: &mut Builder<'_, 'll, 'tcx>,
width: u32,
args: &mut Vec<&'ll llvm::Value>,
inputs: &[DiffActivity],
outer_args: &[&'ll llvm::Value],
has_sret: bool,
) {
debug!("matching autodiff arguments");
// We now handle the issue that Rust level arguments not always match the llvm-ir level
@ -67,14 +118,6 @@ fn match_args_from_caller_to_enzyme<'ll>(
let mut outer_pos: usize = 0;
let mut activity_pos = 0;
if has_sret {
// Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
// inner function will still return something. We increase our outer_pos by one,
// and once we're done with all other args we will take the return of the inner call and
// update the sret pointer with it
outer_pos = 1;
}
let enzyme_const = cx.create_metadata(b"enzyme_const");
let enzyme_out = cx.create_metadata(b"enzyme_out");
let enzyme_dup = cx.create_metadata(b"enzyme_dup");
@ -193,92 +236,6 @@ fn match_args_from_caller_to_enzyme<'ll>(
}
}
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
// arguments. We do however need to declare them with their correct return type.
// We already figured the correct return type out in our frontend, when generating the outer_fn,
// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
// Beyond sret, this article describes our challenges nicely:
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
fn compute_enzyme_fn_ty<'ll>(
cx: &SimpleCx<'ll>,
attrs: &AutoDiffAttrs,
fn_to_diff: &'ll Value,
outer_fn: &'ll Value,
) -> &'ll llvm::Type {
let fn_ty = cx.get_type_of_global(outer_fn);
let mut ret_ty = cx.get_return_type(fn_ty);
let has_sret = has_sret(outer_fn);
if has_sret {
// Now we don't just forward the return type, so we have to figure it out based on the
// primal return type, in combination with the autodiff settings.
let fn_ty = cx.get_type_of_global(fn_to_diff);
let inner_ret_ty = cx.get_return_type(fn_ty);
let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
if inner_ret_ty == void_ty {
// This indicates that even the inner function has an sret.
// Right now I only look for an sret in the outer function.
// This *probably* needs some extra handling, but I never ran
// into such a case. So I'll wait for user reports to have a test case.
bug!("sret in inner function");
}
if attrs.width == 1 {
// Enzyme returns a struct of style:
// `{ original_ret(if requested), float, float, ... }`
let mut struct_elements = vec![];
if attrs.has_primal_ret() {
struct_elements.push(inner_ret_ty);
}
// Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
// and therefore part of the return struct.
let param_tys = cx.func_params_types(fn_ty);
for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
if matches!(act, DiffActivity::Active) {
// Now find the float type at position i based on the fn_ty,
// to know what (f16/f32/f64/...) to add to the struct.
struct_elements.push(param_ty);
}
}
ret_ty = cx.type_struct(&struct_elements, false);
} else {
// First we check if we also have to deal with the primal return.
match attrs.mode {
DiffMode::Forward => match attrs.ret_activity {
DiffActivity::Dual => {
let arr_ty =
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
ret_ty = arr_ty;
}
DiffActivity::DualOnly => {
let arr_ty =
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
ret_ty = arr_ty;
}
DiffActivity::Const => {
todo!("Not sure, do we need to do something here?");
}
_ => {
bug!("unreachable");
}
},
DiffMode::Reverse => {
todo!("Handle sret for reverse mode");
}
_ => {
bug!("unreachable");
}
}
}
}
// LLVM can figure out the input types on it's own, so we take a shortcut here.
unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
}
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
/// function with expected naming and calling conventions[^1] which will be
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@ -288,11 +245,15 @@ fn compute_enzyme_fn_ty<'ll>(
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
fn generate_enzyme_call<'ll>(
pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
cx: &SimpleCx<'ll>,
fn_to_diff: &'ll Value,
outer_fn: &'ll Value,
outer_name: &str,
ret_ty: &'ll Type,
fn_args: &[&'ll Value],
attrs: AutoDiffAttrs,
dest: PlaceRef<'tcx, &'ll Value>,
) {
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
let mut ad_name: String = match attrs.mode {
@ -302,11 +263,9 @@ fn generate_enzyme_call<'ll>(
}
.to_string();
// add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
// add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
let name = llvm::get_value_name(outer_fn);
let outer_fn_name = std::str::from_utf8(&name).unwrap();
ad_name.push_str(outer_fn_name);
ad_name.push_str(outer_name);
// Let us assume the user wrote the following function square:
//
@ -316,14 +275,8 @@ fn generate_enzyme_call<'ll>(
// %0 = fmul double %x, %x
// ret double %0
// }
// ```
//
// The user now applies autodiff to the function square, in which case fn_to_diff will be `square`.
// Our macro generates the following placeholder code (slightly simplified):
//
// ```llvm
// define double @dsquare(double %x) {
// ; placeholder code
// return 0.0;
// }
// ```
@ -340,175 +293,44 @@ fn generate_enzyme_call<'ll>(
// ret double %0
// }
// ```
unsafe {
let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) };
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
// think a bit more about what should go here.
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
let ad_fn = declare_simple_fn(
cx,
&ad_name,
llvm::CallConv::try_from(cc).expect("invalid callconv"),
llvm::UnnamedAddr::No,
llvm::Visibility::Default,
enzyme_ty,
);
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
// think a bit more about what should go here.
let cc = unsafe { llvm::LLVMGetFunctionCallConv(fn_to_diff) };
let ad_fn = declare_simple_fn(
cx,
&ad_name,
llvm::CallConv::try_from(cc).expect("invalid callconv"),
llvm::UnnamedAddr::No,
llvm::Visibility::Default,
enzyme_ty,
);
// Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
// do it's work.
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
let num_args = llvm::LLVMCountParams(&fn_to_diff);
let mut args = Vec::with_capacity(num_args as usize + 1);
args.push(fn_to_diff);
// We add a made-up attribute just such that we can recognize it after AD to update
// (no)-inline attributes. We'll then also remove this attribute.
let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]);
// first, remove all calls from fnc
let entry = llvm::LLVMGetFirstBasicBlock(outer_fn);
let br = llvm::LLVMRustGetTerminator(entry);
llvm::LLVMRustEraseInstFromParent(br);
let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
let mut builder = SBuilder::build(cx, entry);
let num_args = llvm::LLVMCountParams(&fn_to_diff);
let mut args = Vec::with_capacity(num_args as usize + 1);
args.push(fn_to_diff);
let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return");
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
args.push(cx.get_metadata_value(enzyme_primal_ret));
}
if attrs.width > 1 {
let enzyme_width = cx.create_metadata(b"enzyme_width");
args.push(cx.get_metadata_value(enzyme_width));
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
}
let has_sret = has_sret(outer_fn);
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
match_args_from_caller_to_enzyme(
&cx,
&builder,
attrs.width,
&mut args,
&attrs.input_activity,
&outer_args,
has_sret,
);
let call = builder.call(enzyme_ty, ad_fn, &args, None);
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
// metadata attached to it, but we just created this code oota. Given that the
// differentiated function already has partly confusing metadata, and given that this
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
// dummy code which we inserted at a higher level.
// FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have,
// and how to best improve it for enzyme core and rust-enzyme.
let md_ty = cx.get_md_kind_id("dbg");
if llvm::LLVMRustHasMetadata(last_inst, md_ty) {
let md = llvm::LLVMRustDIGetInstMetadata(last_inst)
.expect("failed to get instruction metadata");
let md_todiff = cx.get_metadata_value(md);
llvm::LLVMSetMetadata(call, md_ty, md_todiff);
} else {
// We don't panic, since depending on whether we are in debug or release mode, we might
// have no debug info to copy, which would then be ok.
trace!("no dbg info");
}
// Now that we copied the metadata, get rid of dummy code.
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
if cx.val_ty(call) == cx.type_void() || has_sret {
if has_sret {
// This is what we already have in our outer_fn (shortened):
// define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
// %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
// <Here we are, we want to add the following two lines>
// store [4 x double] %7, ptr %0, align 8
// ret void
// }
// now store the result of the enzyme call into the sret pointer.
let sret_ptr = outer_args[0];
let call_ty = cx.val_ty(call);
if attrs.width == 1 {
assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
} else {
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
}
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
}
builder.ret_void();
} else {
builder.ret(call);
}
// Let's crash in case that we messed something up above and generated invalid IR.
llvm::LLVMRustVerifyFunction(
outer_fn,
llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction,
);
let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return");
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
args.push(cx.get_metadata_value(enzyme_primal_ret));
}
}
pub(crate) fn differentiate<'ll>(
module: &'ll ModuleCodegen<ModuleLlvm>,
cgcx: &CodegenContext<LlvmCodegenBackend>,
diff_items: Vec<AutoDiffItem>,
) -> Result<(), FatalError> {
for item in &diff_items {
trace!("{}", item);
}
let diag_handler = cgcx.create_dcx();
let cx = SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size);
// First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag?
if !diff_items.is_empty()
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
{
return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
}
// Here we replace the placeholder code with the actual autodiff code, which calls Enzyme.
for item in diff_items.iter() {
let name = item.source.clone();
let fn_def: Option<&llvm::Value> = cx.get_function(&name);
let Some(fn_def) = fn_def else {
return Err(llvm_err(
diag_handler.handle(),
LlvmError::PrepareAutoDiff {
src: item.source.clone(),
target: item.target.clone(),
error: "could not find source function".to_owned(),
},
));
};
debug!(?item.target);
let fn_target: Option<&llvm::Value> = cx.get_function(&item.target);
let Some(fn_target) = fn_target else {
return Err(llvm_err(
diag_handler.handle(),
LlvmError::PrepareAutoDiff {
src: item.source.clone(),
target: item.target.clone(),
error: "could not find target function".to_owned(),
},
));
};
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
}
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
trace!("done with differentiate()");
Ok(())
if attrs.width > 1 {
let enzyme_width = cx.create_metadata(b"enzyme_width");
args.push(cx.get_metadata_value(enzyme_width));
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
}
match_args_from_caller_to_enzyme(
&cx,
builder,
attrs.width,
&mut args,
&attrs.input_activity,
fn_args,
);
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
builder.store_to_place(call, dest.val);
}

View file

@ -8,7 +8,6 @@ use std::str;
use rustc_abi::{HasDataLayout, Size, TargetDataLayout, VariantIdx};
use rustc_codegen_ssa::back::versioned_llvm_target;
use rustc_codegen_ssa::base::{wants_msvc_seh, wants_wasm_eh};
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::errors as ssa_errors;
use rustc_codegen_ssa::traits::*;
use rustc_data_structures::base_n::{ALPHANUMERIC_ONLY, ToBaseN};
@ -660,10 +659,6 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
}
}
impl<'ll> SimpleCx<'ll> {
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
assert_eq!(self.type_kind(ty), TypeKind::Function);
unsafe { llvm::LLVMGetReturnType(ty) }
}
pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type {
unsafe { llvm::LLVMGlobalGetValueType(val) }
}
@ -727,16 +722,6 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len())
}
}
pub(crate) fn get_functions(&self) -> Vec<&'ll Value> {
let mut functions = vec![];
let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) };
while let Some(f) = func {
functions.push(f);
func = unsafe { llvm::LLVMGetNextFunction(f) }
}
functions
}
}
impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {

View file

@ -3,24 +3,29 @@ use std::cmp::Ordering;
use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size};
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
use rustc_codegen_ssa::traits::*;
use rustc_hir as hir;
use rustc_hir::def_id::LOCAL_CRATE;
use rustc_hir::{self as hir};
use rustc_middle::mir::BinOp;
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
use rustc_middle::ty::{self, GenericArgsRef, Ty};
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
use rustc_middle::{bug, span_bug};
use rustc_span::{Span, Symbol, sym};
use rustc_symbol_mangling::mangle_internal_symbol;
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
use rustc_target::callconv::PassMode;
use rustc_target::spec::PanicStrategy;
use tracing::debug;
use crate::abi::FnAbiLlvmExt;
use crate::builder::Builder;
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
use crate::context::CodegenCx;
use crate::errors::AutoDiffWithoutEnable;
use crate::llvm::{self, Metadata};
use crate::type_::Type;
use crate::type_of::LayoutLlvmExt;
@ -189,6 +194,10 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
&[ptr, args[1].immediate()],
)
}
sym::autodiff => {
codegen_autodiff(self, tcx, instance, args, result);
return Ok(());
}
sym::is_val_statically_known => {
if let OperandValue::Immediate(imm) = args[0].val {
self.call_intrinsic(
@ -1113,6 +1122,143 @@ fn get_rust_try_fn<'a, 'll, 'tcx>(
rust_try
}
fn codegen_autodiff<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
tcx: TyCtxt<'tcx>,
instance: ty::Instance<'tcx>,
args: &[OperandRef<'tcx, &'ll Value>],
result: PlaceRef<'tcx, &'ll Value>,
) {
if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) {
let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutEnable);
}
let fn_args = instance.args;
let callee_ty = instance.ty(tcx, bx.typing_env());
let sig = callee_ty.fn_sig(tcx).skip_binder();
let ret_ty = sig.output();
let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);
// Get source, diff, and attrs
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
ty::FnDef(def_id, source_params) => (def_id, source_params),
_ => bug!("invalid autodiff intrinsic args"),
};
let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) {
Ok(Some(instance)) => instance,
Ok(None) => bug!(
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
source_id,
source_args
),
Err(_) => {
// An error has already been emitted
return;
}
};
let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else {
bug!("could not find source function")
};
let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
ty::FnDef(def_id, diff_args) => (def_id, diff_args),
_ => bug!("invalid args"),
};
let fn_diff = match Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args) {
Ok(Some(instance)) => instance,
Ok(None) => bug!(
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
diff_id,
diff_args
),
Err(_) => {
// An error has already been emitted
return;
}
};
let val_arr = get_args_from_tuple(bx, args[2], fn_diff);
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else {
bug!("could not find autodiff attrs")
};
adjust_activity_to_abi(
tcx,
fn_source.ty(tcx, TypingEnv::fully_monomorphized()),
&mut diff_attrs.input_activity,
);
// Build body
generate_enzyme_call(
bx,
bx.cx,
fn_to_diff,
&diff_symbol,
llret_ty,
&val_arr,
diff_attrs.clone(),
result,
);
}
fn get_args_from_tuple<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
tuple_op: OperandRef<'tcx, &'ll Value>,
fn_instance: Instance<'tcx>,
) -> Vec<&'ll Value> {
let cx = bx.cx;
let fn_abi = cx.fn_abi_of_instance(fn_instance, ty::List::empty());
match tuple_op.val {
OperandValue::Immediate(val) => vec![val],
OperandValue::Pair(v1, v2) => vec![v1, v2],
OperandValue::Ref(ptr) => {
let tuple_place = PlaceRef { val: ptr, layout: tuple_op.layout };
let mut result = Vec::with_capacity(fn_abi.args.len());
let mut tuple_index = 0;
for arg in &fn_abi.args {
match arg.mode {
PassMode::Ignore => {}
PassMode::Direct(_) | PassMode::Cast { .. } => {
let field = tuple_place.project_field(bx, tuple_index);
let llvm_ty = field.layout.llvm_type(bx.cx);
let val = bx.load(llvm_ty, field.val.llval, field.val.align);
result.push(val);
tuple_index += 1;
}
PassMode::Pair(_, _) => {
let field = tuple_place.project_field(bx, tuple_index);
let llvm_ty = field.layout.llvm_type(bx.cx);
let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align);
result.push(bx.extract_value(pair_val, 0));
result.push(bx.extract_value(pair_val, 1));
tuple_index += 1;
}
PassMode::Indirect { .. } => {
let field = tuple_place.project_field(bx, tuple_index);
result.push(field.val.llval);
tuple_index += 1;
}
}
}
result
}
OperandValue::ZeroSized => vec![],
}
}
fn generic_simd_intrinsic<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
name: Symbol,

View file

@ -30,7 +30,6 @@ use context::SimpleCx;
use errors::ParseTargetMachineConfig;
use llvm_util::target_config;
use rustc_ast::expand::allocator::AllocatorKind;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_codegen_ssa::back::lto::{SerializedModule, ThinModule};
use rustc_codegen_ssa::back::write::{
CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn,
@ -173,15 +172,10 @@ impl WriteBackendMethods for LlvmCodegenBackend {
exported_symbols_for_lto: &[String],
each_linked_rlib_for_lto: &[PathBuf],
modules: Vec<FatLtoInput<Self>>,
diff_fncs: Vec<AutoDiffItem>,
) -> Result<ModuleCodegen<Self::Module>, FatalError> {
let mut module =
back::lto::run_fat(cgcx, exported_symbols_for_lto, each_linked_rlib_for_lto, modules)?;
if !diff_fncs.is_empty() {
builder::autodiff::differentiate(&module, cgcx, diff_fncs)?;
}
let dcx = cgcx.create_dcx();
let dcx = dcx.handle();
back::lto::run_pass_manager(cgcx, dcx, &mut module, false)?;

View file

@ -42,32 +42,6 @@ pub(crate) fn AddFunctionAttributes<'ll>(
}
}
pub(crate) fn HasAttributeAtIndex<'ll>(
llfn: &'ll Value,
idx: AttributePlace,
kind: AttributeKind,
) -> bool {
unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) }
}
pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: &str) -> bool {
unsafe { LLVMRustHasFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
}
pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: &str) {
unsafe { LLVMRustRemoveFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
}
pub(crate) fn RemoveRustEnumAttributeAtIndex(
llfn: &Value,
place: AttributePlace,
kind: AttributeKind,
) {
unsafe {
LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind);
}
}
pub(crate) fn AddCallSiteAttributes<'ll>(
callsite: &'ll Value,
idx: AttributePlace,

View file

@ -26,6 +26,7 @@ rustc_hashes = { path = "../rustc_hashes" }
rustc_hir = { path = "../rustc_hir" }
rustc_incremental = { path = "../rustc_incremental" }
rustc_index = { path = "../rustc_index" }
rustc_lint_defs = { path = "../rustc_lint_defs" }
rustc_macros = { path = "../rustc_macros" }
rustc_metadata = { path = "../rustc_metadata" }
rustc_middle = { path = "../rustc_middle" }

View file

@ -7,7 +7,6 @@ use std::{fs, io, mem, str, thread};
use rustc_abi::Size;
use rustc_ast::attr;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_data_structures::fx::FxIndexMap;
use rustc_data_structures::jobserver::{self, Acquired};
use rustc_data_structures::memmap::Mmap;
@ -38,7 +37,7 @@ use tracing::debug;
use super::link::{self, ensure_removed};
use super::lto::{self, SerializedModule};
use crate::back::lto::check_lto_allowed;
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
use crate::errors::ErrorCreatingRemarkDir;
use crate::traits::*;
use crate::{
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
@ -76,12 +75,9 @@ pub struct ModuleConfig {
/// Names of additional optimization passes to run.
pub passes: Vec<String>,
/// Some(level) to optimize at a certain level, or None to run
/// absolutely no optimizations (used for the metadata module).
/// absolutely no optimizations (used for the allocator module).
pub opt_level: Option<config::OptLevel>,
/// Some(level) to optimize binary size, or None to not affect program size.
pub opt_size: Option<config::OptLevel>,
pub pgo_gen: SwitchWithOptPath,
pub pgo_use: Option<PathBuf>,
pub pgo_sample_use: Option<PathBuf>,
@ -102,7 +98,6 @@ pub struct ModuleConfig {
pub emit_obj: EmitObj,
pub emit_thin_lto: bool,
pub emit_thin_lto_summary: bool,
pub bc_cmdline: String,
// Miscellaneous flags. These are mostly copied from command-line
// options.
@ -110,7 +105,6 @@ pub struct ModuleConfig {
pub lint_llvm_ir: bool,
pub no_prepopulate_passes: bool,
pub no_builtins: bool,
pub time_module: bool,
pub vectorize_loop: bool,
pub vectorize_slp: bool,
pub merge_functions: bool,
@ -171,7 +165,6 @@ impl ModuleConfig {
passes: if_regular!(sess.opts.cg.passes.clone(), vec![]),
opt_level: opt_level_and_size,
opt_size: opt_level_and_size,
pgo_gen: if_regular!(
sess.opts.cg.profile_generate.clone(),
@ -221,17 +214,12 @@ impl ModuleConfig {
sess.opts.output_types.contains_key(&OutputType::ThinLinkBitcode),
false
),
bc_cmdline: sess.target.bitcode_llvm_cmdline.to_string(),
verify_llvm_ir: sess.verify_llvm_ir(),
lint_llvm_ir: sess.opts.unstable_opts.lint_llvm_ir,
no_prepopulate_passes: sess.opts.cg.no_prepopulate_passes,
no_builtins: no_builtins || sess.target.no_builtins,
// Exclude metadata and allocator modules from time_passes output,
// since they throw off the "LLVM passes" measurement.
time_module: if_regular!(true, false),
// Copy what clang does by turning on loop vectorization at O2 and
// slp vectorization at O3.
vectorize_loop: !sess.opts.cg.no_vectorize_loops
@ -454,7 +442,6 @@ pub(crate) fn start_async_codegen<B: ExtraBackendMethods>(
backend: B,
tcx: TyCtxt<'_>,
target_cpu: String,
autodiff_items: &[AutoDiffItem],
) -> OngoingCodegen<B> {
let (coordinator_send, coordinator_receive) = channel();
@ -473,7 +460,6 @@ pub(crate) fn start_async_codegen<B: ExtraBackendMethods>(
backend.clone(),
tcx,
&crate_info,
autodiff_items,
shared_emitter,
codegen_worker_send,
coordinator_receive,
@ -728,7 +714,6 @@ pub(crate) enum WorkItem<B: WriteBackendMethods> {
each_linked_rlib_for_lto: Vec<PathBuf>,
needs_fat_lto: Vec<FatLtoInput<B>>,
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
autodiff: Vec<AutoDiffItem>,
},
/// Performs thin-LTO on the given module.
ThinLto(lto::ThinModule<B>),
@ -1001,7 +986,6 @@ fn execute_fat_lto_work_item<B: ExtraBackendMethods>(
each_linked_rlib_for_lto: &[PathBuf],
mut needs_fat_lto: Vec<FatLtoInput<B>>,
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
autodiff: Vec<AutoDiffItem>,
module_config: &ModuleConfig,
) -> Result<WorkItemResult<B>, FatalError> {
for (module, wp) in import_only_modules {
@ -1013,7 +997,6 @@ fn execute_fat_lto_work_item<B: ExtraBackendMethods>(
exported_symbols_for_lto,
each_linked_rlib_for_lto,
needs_fat_lto,
autodiff,
)?;
let module = B::codegen(cgcx, module, module_config)?;
Ok(WorkItemResult::Finished(module))
@ -1105,7 +1088,6 @@ fn start_executing_work<B: ExtraBackendMethods>(
backend: B,
tcx: TyCtxt<'_>,
crate_info: &CrateInfo,
autodiff_items: &[AutoDiffItem],
shared_emitter: SharedEmitter,
codegen_worker_send: Sender<CguMessage>,
coordinator_receive: Receiver<Message<B>>,
@ -1115,7 +1097,6 @@ fn start_executing_work<B: ExtraBackendMethods>(
) -> thread::JoinHandle<Result<CompiledModules, ()>> {
let coordinator_send = tx_to_llvm_workers;
let sess = tcx.sess;
let autodiff_items = autodiff_items.to_vec();
let mut each_linked_rlib_for_lto = Vec::new();
let mut each_linked_rlib_file_for_lto = Vec::new();
@ -1448,7 +1429,6 @@ fn start_executing_work<B: ExtraBackendMethods>(
each_linked_rlib_for_lto: each_linked_rlib_file_for_lto,
needs_fat_lto,
import_only_modules,
autodiff: autodiff_items.clone(),
},
0,
));
@ -1456,11 +1436,6 @@ fn start_executing_work<B: ExtraBackendMethods>(
helper.request_token();
}
} else {
if !autodiff_items.is_empty() {
let dcx = cgcx.create_dcx();
dcx.handle().emit_fatal(AutodiffWithoutLto {});
}
for (work, cost) in generate_thin_lto_work(
&cgcx,
&exported_symbols_for_lto,
@ -1740,7 +1715,7 @@ fn spawn_work<'a, B: ExtraBackendMethods>(
llvm_start_time: &mut Option<VerboseTimingGuard<'a>>,
work: WorkItem<B>,
) {
if cgcx.config(work.module_kind()).time_module && llvm_start_time.is_none() {
if llvm_start_time.is_none() {
*llvm_start_time = Some(cgcx.prof.verbose_generic_activity("LLVM_passes"));
}
@ -1795,7 +1770,6 @@ fn spawn_work<'a, B: ExtraBackendMethods>(
each_linked_rlib_for_lto,
needs_fat_lto,
import_only_modules,
autodiff,
} => {
let _timer = cgcx
.prof
@ -1806,7 +1780,6 @@ fn spawn_work<'a, B: ExtraBackendMethods>(
&each_linked_rlib_for_lto,
needs_fat_lto,
import_only_modules,
autodiff,
module_config,
)
}

View file

@ -647,7 +647,7 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
) -> OngoingCodegen<B> {
// Skip crate items and just output metadata in -Z no-codegen mode.
if tcx.sess.opts.unstable_opts.no_codegen || !tcx.sess.opts.output_types.should_codegen() {
let ongoing_codegen = start_async_codegen(backend, tcx, target_cpu, &[]);
let ongoing_codegen = start_async_codegen(backend, tcx, target_cpu);
ongoing_codegen.codegen_finished(tcx);
@ -665,8 +665,7 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
// Run the monomorphization collector and partition the collected items into
// codegen units.
let MonoItemPartitions { codegen_units, autodiff_items, .. } =
tcx.collect_and_partition_mono_items(());
let MonoItemPartitions { codegen_units, .. } = tcx.collect_and_partition_mono_items(());
// Force all codegen_unit queries so they are already either red or green
// when compile_codegen_unit accesses them. We are not able to re-execute
@ -679,34 +678,7 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
}
}
let ongoing_codegen = start_async_codegen(backend.clone(), tcx, target_cpu, autodiff_items);
// Codegen an allocator shim, if necessary.
if let Some(kind) = allocator_kind_for_codegen(tcx) {
let llmod_id =
cgu_name_builder.build_cgu_name(LOCAL_CRATE, &["crate"], Some("allocator")).to_string();
let module_llvm = tcx.sess.time("write_allocator_module", || {
backend.codegen_allocator(
tcx,
&llmod_id,
kind,
// If allocator_kind is Some then alloc_error_handler_kind must
// also be Some.
tcx.alloc_error_handler_kind(()).unwrap(),
)
});
ongoing_codegen.wait_for_signal_to_codegen_item();
ongoing_codegen.check_for_errors(tcx.sess);
// These modules are generally cheap and won't throw off scheduling.
let cost = 0;
submit_codegened_module_to_llvm(
&ongoing_codegen.coordinator,
ModuleCodegen::new_allocator(llmod_id, module_llvm),
cost,
);
}
let ongoing_codegen = start_async_codegen(backend.clone(), tcx, target_cpu);
// For better throughput during parallel processing by LLVM, we used to sort
// CGUs largest to smallest. This would lead to better thread utilization
@ -823,6 +795,35 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
}
}
// Codegen an allocator shim, if necessary.
// Do this last to ensure the LLVM_passes timer doesn't start while no CGUs have been codegened
// yet for the backend to optimize.
if let Some(kind) = allocator_kind_for_codegen(tcx) {
let llmod_id =
cgu_name_builder.build_cgu_name(LOCAL_CRATE, &["crate"], Some("allocator")).to_string();
let module_llvm = tcx.sess.time("write_allocator_module", || {
backend.codegen_allocator(
tcx,
&llmod_id,
kind,
// If allocator_kind is Some then alloc_error_handler_kind must
// also be Some.
tcx.alloc_error_handler_kind(()).unwrap(),
)
});
ongoing_codegen.wait_for_signal_to_codegen_item();
ongoing_codegen.check_for_errors(tcx.sess);
// These modules are generally cheap and won't throw off scheduling.
let cost = 0;
submit_codegened_module_to_llvm(
&ongoing_codegen.coordinator,
ModuleCodegen::new_allocator(llmod_id, module_llvm),
cost,
);
}
ongoing_codegen.codegen_finished(tcx);
// Since the main thread is sometimes blocked during codegen, we keep track

View file

@ -177,14 +177,6 @@ fn process_builtin_attrs(
let mut interesting_spans = InterestingAttributeDiagnosticSpans::default();
let rust_target_features = tcx.rust_target_features(LOCAL_CRATE);
// If our rustc version supports autodiff/enzyme, then we call our handler
// to check for any `#[rustc_autodiff(...)]` attributes.
// FIXME(jdonszelmann): merge with loop below
if cfg!(llvm_enzyme) {
let ad = autodiff_attrs(tcx, did.into());
codegen_fn_attrs.autodiff_item = ad;
}
for attr in attrs.iter() {
if let hir::Attribute::Parsed(p) = attr {
match p {
@ -612,7 +604,7 @@ fn inherited_align<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Align> {
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
/// panic, unless we introduced a bug when parsing the autodiff macro.
//FIXME(jdonszelmann): put in the main loop. No need to have two..... :/ Let's do that when we make autodiff parsed.
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::<Vec<_>>();

View file

@ -5,6 +5,7 @@ use rustc_ast as ast;
use rustc_ast::{InlineAsmOptions, InlineAsmTemplatePiece};
use rustc_data_structures::packed::Pu128;
use rustc_hir::lang_items::LangItem;
use rustc_lint_defs::builtin::TAIL_CALL_TRACK_CALLER;
use rustc_middle::mir::{self, AssertKind, InlineAsmMacro, SwitchTargets, UnwindTerminateReason};
use rustc_middle::ty::layout::{HasTyCtxt, LayoutOf, ValidityRequirement};
use rustc_middle::ty::print::{with_no_trimmed_paths, with_no_visible_paths};
@ -35,7 +36,7 @@ enum MergingSucc {
True,
}
/// Indicates to the call terminator codegen whether a cal
/// Indicates to the call terminator codegen whether a call
/// is a normal call or an explicit tail call.
#[derive(Debug, PartialEq)]
enum CallKind {
@ -906,7 +907,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
fn_span,
);
let instance = match instance.def {
match instance.def {
// We don't need AsyncDropGlueCtorShim here because it is not `noop func`,
// it is `func returning noop future`
ty::InstanceKind::DropGlue(_, None) => {
@ -995,14 +996,35 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
intrinsic.name,
);
}
instance
(Some(instance), None)
}
}
}
_ => instance,
};
(Some(instance), None)
_ if kind == CallKind::Tail
&& instance.def.requires_caller_location(bx.tcx()) =>
{
if let Some(hir_id) =
terminator.source_info.scope.lint_root(&self.mir.source_scopes)
{
let msg = "tail calling a function marked with `#[track_caller]` has no special effect";
bx.tcx().node_lint(TAIL_CALL_TRACK_CALLER, hir_id, |d| {
_ = d.primary_message(msg).span(fn_span)
});
}
let instance = ty::Instance::resolve_for_fn_ptr(
bx.tcx(),
bx.typing_env(),
def_id,
generic_args,
)
.unwrap();
(None, Some(bx.get_fn_addr(instance)))
}
_ => (Some(instance), None),
}
}
ty::FnPtr(..) => (None, Some(callee.immediate())),
_ => bug!("{} is not callable", callee.layout.ty),

View file

@ -1,6 +1,5 @@
use std::path::PathBuf;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_errors::{DiagCtxtHandle, FatalError};
use rustc_middle::dep_graph::WorkProduct;
@ -23,7 +22,6 @@ pub trait WriteBackendMethods: Clone + 'static {
exported_symbols_for_lto: &[String],
each_linked_rlib_for_lto: &[PathBuf],
modules: Vec<FatLtoInput<Self>>,
diff_fncs: Vec<AutoDiffItem>,
) -> Result<ModuleCodegen<Self::Module>, FatalError>;
/// Performs thin LTO by performing necessary global analysis and returning two
/// lists, one of the modules that need optimization and another for modules that

View file

@ -1,4 +1,9 @@
An unaligned reference to a field of a [packed] struct got created.
An unaligned reference to a field of a [packed] `struct` or `union` was created.
The `#[repr(packed)]` attribute removes padding between fields, which can
cause fields to be stored at unaligned memory addresses. Creating references
to such fields violates Rust's memory safety guarantees and can lead to
undefined behavior in optimized code.
Erroneous code example:
@ -45,9 +50,36 @@ unsafe {
// For formatting, we can create a copy to avoid the direct reference.
let copy = foo.field1;
println!("{}", copy);
// Creating a copy can be written in a single line with curly braces.
// (This is equivalent to the two lines above.)
println!("{}", { foo.field1 });
// A reference to a field that will always be sufficiently aligned is safe:
println!("{}", foo.field2);
}
```
### Unions
Although creating a reference to a `union` field is `unsafe`, this error
will still be triggered if the referenced field is not sufficiently
aligned. Use `addr_of!` and raw pointers in the same way as for struct fields.
```compile_fail,E0793
#[repr(packed)]
pub union Foo {
field1: u64,
field2: u8,
}
unsafe {
let foo = Foo { field1: 0 };
// Accessing the field directly is fine.
let val = foo.field1;
// A reference to a packed union field causes an error.
let val = &foo.field1; // ERROR
}
```

View file

@ -135,6 +135,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
| sym::round_ties_even_f32
| sym::round_ties_even_f64
| sym::round_ties_even_f128
| sym::autodiff
| sym::const_eval_select => hir::Safety::Safe,
_ => hir::Safety::Unsafe,
};
@ -198,6 +199,7 @@ pub(crate) fn check_intrinsic_type(
let safety = intrinsic_operation_unsafety(tcx, intrinsic_id);
let n_lts = 0;
let (n_tps, n_cts, inputs, output) = match intrinsic_name {
sym::autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)),
sym::abort => (0, 0, vec![], tcx.types.never),
sym::unreachable => (0, 0, vec![], tcx.types.never),
sym::breakpoint => (0, 0, vec![], tcx.types.unit),

View file

@ -390,45 +390,13 @@ fn check_predicates<'tcx>(
let mut res = Ok(());
for (clause, span) in impl1_predicates {
if !impl2_predicates.iter().any(|pred2| trait_predicates_eq(clause.as_predicate(), *pred2))
{
if !impl2_predicates.iter().any(|&pred2| clause.as_predicate() == pred2) {
res = res.and(check_specialization_on(tcx, clause, span))
}
}
res
}
/// Checks if some predicate on the specializing impl (`predicate1`) is the same
/// as some predicate on the base impl (`predicate2`).
///
/// This basically just checks syntactic equivalence, but is a little more
/// forgiving since we want to equate `T: Tr` with `T: [const] Tr` so this can work:
///
/// ```ignore (illustrative)
/// #[rustc_specialization_trait]
/// trait Specialize { }
///
/// impl<T: Bound> Tr for T { }
/// impl<T: [const] Bound + Specialize> const Tr for T { }
/// ```
///
/// However, we *don't* want to allow the reverse, i.e., when the bound on the
/// specializing impl is not as const as the bound on the base impl:
///
/// ```ignore (illustrative)
/// impl<T: [const] Bound> const Tr for T { }
/// impl<T: Bound + Specialize> const Tr for T { } // should be T: [const] Bound
/// ```
///
/// So we make that check in this function and try to raise a helpful error message.
fn trait_predicates_eq<'tcx>(
predicate1: ty::Predicate<'tcx>,
predicate2: ty::Predicate<'tcx>,
) -> bool {
// FIXME(const_trait_impl)
predicate1 == predicate2
}
#[instrument(level = "debug", skip(tcx))]
fn check_specialization_on<'tcx>(
tcx: TyCtxt<'tcx>,

View file

@ -910,7 +910,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// const stability checking here too, I guess.
if self.tcx.is_conditionally_const(callee_did) {
let q = self.tcx.const_conditions(callee_did);
// FIXME(const_trait_impl): Use this span with a better cause code.
for (idx, (cond, pred_span)) in
q.instantiate(self.tcx, callee_args).into_iter().enumerate()
{

View file

@ -5105,3 +5105,36 @@ declare_lint! {
report_in_deps: true,
};
}
declare_lint! {
/// The `tail_call_track_caller` lint detects usage of `become` attempting to tail call
/// a function marked with `#[track_caller]`.
///
/// ### Example
///
/// ```rust
/// #![feature(explicit_tail_calls)]
/// #![expect(incomplete_features)]
///
/// #[track_caller]
/// fn f() {}
///
/// fn g() {
/// become f();
/// }
///
/// g();
/// ```
///
/// {{produces}}
///
/// ### Explanation
///
/// Due to implementation details of tail calls and `#[track_caller]` attribute, calls to
/// functions marked with `#[track_caller]` cannot become tail calls. As such using `become`
/// is no different than a normal call (except for changes in drop order).
pub TAIL_CALL_TRACK_CALLER,
Warn,
"detects tail calls of functions marked with `#[track_caller]`",
@feature_gate = explicit_tail_calls;
}

View file

@ -1,7 +1,6 @@
use std::borrow::Cow;
use rustc_abi::Align;
use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs;
use rustc_hir::attrs::{InlineAttr, InstructionSetAttr, Linkage, OptimizeAttr};
use rustc_hir::def_id::DefId;
use rustc_macros::{HashStable, TyDecodable, TyEncodable};
@ -75,8 +74,6 @@ pub struct CodegenFnAttrs {
/// The `#[patchable_function_entry(...)]` attribute. Indicates how many nops should be around
/// the function entry.
pub patchable_function_entry: Option<PatchableFunctionEntry>,
/// For the `#[autodiff]` macros.
pub autodiff_item: Option<AutoDiffAttrs>,
}
#[derive(Copy, Clone, Debug, TyEncodable, TyDecodable, HashStable)]
@ -182,7 +179,6 @@ impl CodegenFnAttrs {
instruction_set: None,
alignment: None,
patchable_function_entry: None,
autodiff_item: None,
}
}

View file

@ -2,7 +2,6 @@ use std::borrow::Cow;
use std::fmt;
use std::hash::Hash;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_data_structures::base_n::{BaseNString, CASE_INSENSITIVE, ToBaseN};
use rustc_data_structures::fingerprint::Fingerprint;
use rustc_data_structures::fx::FxIndexMap;
@ -336,7 +335,6 @@ impl ToStableHashKey<StableHashingContext<'_>> for MonoItem<'_> {
pub struct MonoItemPartitions<'tcx> {
pub codegen_units: &'tcx [CodegenUnit<'tcx>],
pub all_mono_items: &'tcx DefIdSet,
pub autodiff_items: &'tcx [AutoDiffItem],
}
#[derive(Debug, HashStable)]

View file

@ -3,6 +3,7 @@ use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_errors::Applicability;
use rustc_hir::LangItem;
use rustc_hir::def::DefKind;
use rustc_hir::def_id::CRATE_DEF_ID;
use rustc_middle::span_bug;
use rustc_middle::thir::visit::{self, Visitor};
use rustc_middle::thir::{BodyTy, Expr, ExprId, ExprKind, Thir};
@ -136,7 +137,15 @@ impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
if caller_sig.inputs_and_output != callee_sig.inputs_and_output {
if caller_sig.inputs() != callee_sig.inputs() {
self.report_arguments_mismatch(expr.span, caller_sig, callee_sig);
self.report_arguments_mismatch(
expr.span,
self.tcx.liberate_late_bound_regions(
CRATE_DEF_ID.to_def_id(),
self.caller_ty.fn_sig(self.tcx),
),
self.tcx
.liberate_late_bound_regions(CRATE_DEF_ID.to_def_id(), ty.fn_sig(self.tcx)),
);
}
// FIXME(explicit_tail_calls): this currently fails for cases where opaques are used.

View file

@ -1,51 +1,36 @@
use rustc_index::IndexVec;
use rustc_middle::mir::coverage::{BlockMarkerId, BranchSpan, CoverageInfoHi, CoverageKind};
use rustc_middle::mir::coverage::{
BlockMarkerId, BranchSpan, CoverageInfoHi, CoverageKind, Mapping, MappingKind,
};
use rustc_middle::mir::{self, BasicBlock, StatementKind};
use rustc_middle::ty::TyCtxt;
use rustc_span::Span;
use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph};
use crate::coverage::graph::CoverageGraph;
use crate::coverage::hir_info::ExtractedHirInfo;
use crate::coverage::spans::extract_refined_covspans;
use crate::coverage::unexpand::unexpand_into_body_span;
/// Associates an ordinary executable code span with its corresponding BCB.
#[derive(Debug)]
pub(super) struct CodeMapping {
pub(super) span: Span,
pub(super) bcb: BasicCoverageBlock,
}
#[derive(Debug)]
pub(super) struct BranchPair {
pub(super) span: Span,
pub(super) true_bcb: BasicCoverageBlock,
pub(super) false_bcb: BasicCoverageBlock,
}
#[derive(Default)]
pub(super) struct ExtractedMappings {
pub(super) code_mappings: Vec<CodeMapping>,
pub(super) branch_pairs: Vec<BranchPair>,
pub(crate) struct ExtractedMappings {
pub(crate) mappings: Vec<Mapping>,
}
/// Extracts coverage-relevant spans from MIR, and associates them with
/// their corresponding BCBs.
pub(super) fn extract_all_mapping_info_from_mir<'tcx>(
/// Extracts coverage-relevant spans from MIR, and uses them to create
/// coverage mapping data for inclusion in MIR.
pub(crate) fn extract_mappings_from_mir<'tcx>(
tcx: TyCtxt<'tcx>,
mir_body: &mir::Body<'tcx>,
hir_info: &ExtractedHirInfo,
graph: &CoverageGraph,
) -> ExtractedMappings {
let mut code_mappings = vec![];
let mut branch_pairs = vec![];
let mut mappings = vec![];
// Extract ordinary code mappings from MIR statement/terminator spans.
extract_refined_covspans(tcx, mir_body, hir_info, graph, &mut code_mappings);
extract_refined_covspans(tcx, mir_body, hir_info, graph, &mut mappings);
branch_pairs.extend(extract_branch_pairs(mir_body, hir_info, graph));
extract_branch_mappings(mir_body, hir_info, graph, &mut mappings);
ExtractedMappings { code_mappings, branch_pairs }
ExtractedMappings { mappings }
}
fn resolve_block_markers(
@ -69,19 +54,18 @@ fn resolve_block_markers(
block_markers
}
pub(super) fn extract_branch_pairs(
pub(super) fn extract_branch_mappings(
mir_body: &mir::Body<'_>,
hir_info: &ExtractedHirInfo,
graph: &CoverageGraph,
) -> Vec<BranchPair> {
let Some(coverage_info_hi) = mir_body.coverage_info_hi.as_deref() else { return vec![] };
mappings: &mut Vec<Mapping>,
) {
let Some(coverage_info_hi) = mir_body.coverage_info_hi.as_deref() else { return };
let block_markers = resolve_block_markers(coverage_info_hi, mir_body);
coverage_info_hi
.branch_spans
.iter()
.filter_map(|&BranchSpan { span: raw_span, true_marker, false_marker }| {
mappings.extend(coverage_info_hi.branch_spans.iter().filter_map(
|&BranchSpan { span: raw_span, true_marker, false_marker }| try {
// For now, ignore any branch span that was introduced by
// expansion. This makes things like assert macros less noisy.
if !raw_span.ctxt().outer_expn_data().is_root() {
@ -94,7 +78,7 @@ pub(super) fn extract_branch_pairs(
let true_bcb = bcb_from_marker(true_marker)?;
let false_bcb = bcb_from_marker(false_marker)?;
Some(BranchPair { span, true_bcb, false_bcb })
})
.collect::<Vec<_>>()
Mapping { span, kind: MappingKind::Branch { true_bcb, false_bcb } }
},
));
}

View file

@ -1,4 +1,4 @@
use rustc_middle::mir::coverage::{CoverageKind, FunctionCoverageInfo, Mapping, MappingKind};
use rustc_middle::mir::coverage::{CoverageKind, FunctionCoverageInfo};
use rustc_middle::mir::{self, BasicBlock, Statement, StatementKind, TerminatorKind};
use rustc_middle::ty::TyCtxt;
use tracing::{debug, debug_span, trace};
@ -71,10 +71,8 @@ fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir:
////////////////////////////////////////////////////
// Extract coverage spans and other mapping info from MIR.
let extracted_mappings =
mappings::extract_all_mapping_info_from_mir(tcx, mir_body, &hir_info, &graph);
let mappings = create_mappings(&extracted_mappings);
let ExtractedMappings { mappings } =
mappings::extract_mappings_from_mir(tcx, mir_body, &hir_info, &graph);
if mappings.is_empty() {
// No spans could be converted into valid mappings, so skip this function.
debug!("no spans could be converted into valid mappings; skipping");
@ -100,34 +98,6 @@ fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir:
}));
}
/// For each coverage span extracted from MIR, create a corresponding mapping.
///
/// FIXME(Zalathar): This used to be where BCBs in the extracted mappings were
/// resolved to a `CovTerm`. But that is now handled elsewhere, so this
/// function can potentially be simplified even further.
fn create_mappings(extracted_mappings: &ExtractedMappings) -> Vec<Mapping> {
// Fully destructure the mappings struct to make sure we don't miss any kinds.
let ExtractedMappings { code_mappings, branch_pairs } = extracted_mappings;
let mut mappings = Vec::new();
mappings.extend(code_mappings.iter().map(
// Ordinary code mappings are the simplest kind.
|&mappings::CodeMapping { span, bcb }| {
let kind = MappingKind::Code { bcb };
Mapping { kind, span }
},
));
mappings.extend(branch_pairs.iter().map(
|&mappings::BranchPair { span, true_bcb, false_bcb }| {
let kind = MappingKind::Branch { true_bcb, false_bcb };
Mapping { kind, span }
},
));
mappings
}
/// Inject any necessary coverage statements into MIR, so that they influence codegen.
fn inject_coverage_statements<'tcx>(mir_body: &mut mir::Body<'tcx>, graph: &CoverageGraph) {
for (bcb, data) in graph.iter_enumerated() {

View file

@ -1,6 +1,6 @@
use rustc_data_structures::fx::FxHashSet;
use rustc_middle::mir;
use rustc_middle::mir::coverage::START_BCB;
use rustc_middle::mir::coverage::{Mapping, MappingKind, START_BCB};
use rustc_middle::ty::TyCtxt;
use rustc_span::source_map::SourceMap;
use rustc_span::{BytePos, DesugaringKind, ExpnKind, MacroKind, Span};
@ -9,7 +9,7 @@ use tracing::instrument;
use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph};
use crate::coverage::hir_info::ExtractedHirInfo;
use crate::coverage::spans::from_mir::{Hole, RawSpanFromMir, SpanFromMir};
use crate::coverage::{mappings, unexpand};
use crate::coverage::unexpand;
mod from_mir;
@ -18,7 +18,7 @@ pub(super) fn extract_refined_covspans<'tcx>(
mir_body: &mir::Body<'tcx>,
hir_info: &ExtractedHirInfo,
graph: &CoverageGraph,
code_mappings: &mut Vec<mappings::CodeMapping>,
mappings: &mut Vec<Mapping>,
) {
if hir_info.is_async_fn {
// An async function desugars into a function that returns a future,
@ -26,7 +26,7 @@ pub(super) fn extract_refined_covspans<'tcx>(
// outer function will be unhelpful, so just keep the signature span
// and ignore all of the spans in the MIR body.
if let Some(span) = hir_info.fn_sig_span {
code_mappings.push(mappings::CodeMapping { span, bcb: START_BCB });
mappings.push(Mapping { span, kind: MappingKind::Code { bcb: START_BCB } })
}
return;
}
@ -111,9 +111,9 @@ pub(super) fn extract_refined_covspans<'tcx>(
// Merge covspans that can be merged.
covspans.dedup_by(|b, a| a.merge_if_eligible(b));
code_mappings.extend(covspans.into_iter().map(|Covspan { span, bcb }| {
mappings.extend(covspans.into_iter().map(|Covspan { span, bcb }| {
// Each span produced by the refiner represents an ordinary code region.
mappings::CodeMapping { span, bcb }
Mapping { span, kind: MappingKind::Code { bcb } }
}));
}

View file

@ -6,7 +6,6 @@ edition = "2024"
[dependencies]
# tidy-alphabetical-start
rustc_abi = { path = "../rustc_abi" }
rustc_ast = { path = "../rustc_ast" }
rustc_data_structures = { path = "../rustc_data_structures" }
rustc_errors = { path = "../rustc_errors" }
rustc_fluent_macro = { path = "../rustc_fluent_macro" }
@ -15,7 +14,6 @@ rustc_macros = { path = "../rustc_macros" }
rustc_middle = { path = "../rustc_middle" }
rustc_session = { path = "../rustc_session" }
rustc_span = { path = "../rustc_span" }
rustc_symbol_mangling = { path = "../rustc_symbol_mangling" }
rustc_target = { path = "../rustc_target" }
serde = "1"
serde_json = "1"

View file

@ -205,6 +205,8 @@
//! this is not implemented however: a mono item will be produced
//! regardless of whether it is actually needed or not.
mod autodiff;
use std::cell::OnceCell;
use rustc_data_structures::fx::FxIndexMap;
@ -235,6 +237,7 @@ use rustc_span::source_map::{Spanned, dummy_spanned, respan};
use rustc_span::{DUMMY_SP, Span};
use tracing::{debug, instrument, trace};
use crate::collector::autodiff::collect_autodiff_fn;
use crate::errors::{
self, EncounteredErrorWhileInstantiating, EncounteredErrorWhileInstantiatingGlobalAsm,
NoOptimizedMir, RecursionLimit,
@ -786,7 +789,35 @@ impl<'a, 'tcx> MirVisitor<'tcx> for MirUsedCollector<'a, 'tcx> {
// *Before* monomorphizing, record that we already handled this mention.
self.used_mentioned_items.insert(MentionedItem::Fn(callee_ty));
let callee_ty = self.monomorphize(callee_ty);
visit_fn_use(self.tcx, callee_ty, true, source, &mut self.used_items)
// HACK(explicit_tail_calls): collect tail calls to `#[track_caller]` functions as indirect,
// because we later call them as such, to prevent issues with ABI incompatibility.
// Ideally we'd replace such tail calls with normal call + return, but this requires
// post-mono MIR optimizations, which we don't yet have.
let force_indirect_call =
if matches!(terminator.kind, mir::TerminatorKind::TailCall { .. })
&& let &ty::FnDef(def_id, args) = callee_ty.kind()
&& let instance = ty::Instance::expect_resolve(
self.tcx,
ty::TypingEnv::fully_monomorphized(),
def_id,
args,
source,
)
&& instance.def.requires_caller_location(self.tcx)
{
true
} else {
false
};
visit_fn_use(
self.tcx,
callee_ty,
!force_indirect_call,
source,
&mut self.used_items,
)
}
mir::TerminatorKind::Drop { ref place, .. } => {
let ty = place.ty(self.body, self.tcx).ty;
@ -911,6 +942,8 @@ fn visit_instance_use<'tcx>(
return;
}
if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) {
collect_autodiff_fn(tcx, instance, intrinsic, output);
if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) {
// The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will
// be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any

View file

@ -0,0 +1,48 @@
use rustc_middle::bug;
use rustc_middle::ty::{self, GenericArg, IntrinsicDef, TyCtxt};
use crate::collector::{MonoItems, create_fn_mono_item};
// Here, we force both primal and diff function to be collected in
// mono so this does not interfere in `autodiff` intrinsics
// codegen process. If they are unused, LLVM will remove them when
// compiling with O3.
pub(crate) fn collect_autodiff_fn<'tcx>(
tcx: TyCtxt<'tcx>,
instance: ty::Instance<'tcx>,
intrinsic: IntrinsicDef,
output: &mut MonoItems<'tcx>,
) {
if intrinsic.name != rustc_span::sym::autodiff {
return;
};
collect_autodiff_fn_from_arg(instance.args[0], tcx, output);
}
fn collect_autodiff_fn_from_arg<'tcx>(
arg: GenericArg<'tcx>,
tcx: TyCtxt<'tcx>,
output: &mut MonoItems<'tcx>,
) {
let (instance, span) = match arg.kind() {
ty::GenericArgKind::Type(ty) => match ty.kind() {
ty::FnDef(def_id, substs) => {
let span = tcx.def_span(def_id);
let instance = ty::Instance::expect_resolve(
tcx,
ty::TypingEnv::non_body_analysis(tcx, def_id),
*def_id,
substs,
span,
);
(instance, span)
}
_ => bug!("expected autodiff function"),
},
_ => bug!("expected type when matching autodiff arg"),
};
output.push(create_fn_mono_item(tcx, instance, span));
}

View file

@ -92,8 +92,6 @@
//! source-level module, functions from the same module will be available for
//! inlining, even when they are not marked `#[inline]`.
mod autodiff;
use std::cmp;
use std::collections::hash_map::Entry;
use std::fs::{self, File};
@ -251,17 +249,7 @@ where
always_export_generics,
);
// We can't differentiate a function that got inlined.
let autodiff_active = cfg!(llvm_enzyme)
&& matches!(mono_item, MonoItem::Fn(_))
&& cx
.tcx
.codegen_fn_attrs(mono_item.def_id())
.autodiff_item
.as_ref()
.is_some_and(|ad| ad.is_active());
if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
if visibility == Visibility::Hidden && can_be_internalized {
internalization_candidates.insert(mono_item);
}
let size_estimate = mono_item.size_estimate(cx.tcx);
@ -1157,27 +1145,15 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
}
}
#[cfg(not(llvm_enzyme))]
let autodiff_mono_items: Vec<_> = vec![];
#[cfg(llvm_enzyme)]
let mut autodiff_mono_items: Vec<_> = vec![];
let mono_items: DefIdSet = items
.iter()
.filter_map(|mono_item| match *mono_item {
MonoItem::Fn(ref instance) => {
#[cfg(llvm_enzyme)]
autodiff_mono_items.push((mono_item, instance));
Some(instance.def_id())
}
MonoItem::Fn(ref instance) => Some(instance.def_id()),
MonoItem::Static(def_id) => Some(def_id),
_ => None,
})
.collect();
let autodiff_items =
autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items);
let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items);
// Output monomorphization stats per def_id
if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats
&& let Err(err) =
@ -1235,11 +1211,7 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
}
}
MonoItemPartitions {
all_mono_items: tcx.arena.alloc(mono_items),
codegen_units,
autodiff_items,
}
MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units }
}
/// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s

View file

@ -1,143 +0,0 @@
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
use rustc_hir::def_id::LOCAL_CRATE;
use rustc_middle::bug;
use rustc_middle::mir::mono::MonoItem;
use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
use tracing::{debug, trace};
use crate::partitioning::UsageMap;
fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
bug!("expected fn def for autodiff, got {:?}", fn_ty);
}
// We don't actually pass the types back into the type system.
// All we do is decide how to handle the arguments.
let sig = fn_ty.fn_sig(tcx).skip_binder();
let mut new_activities = vec![];
let mut new_positions = vec![];
for (i, ty) in sig.inputs().iter().enumerate() {
if let Some(inner_ty) = ty.builtin_deref(true) {
if inner_ty.is_slice() {
// Now we need to figure out the size of each slice element in memory to allow
// safety checks and usability improvements in the backend.
let sty = match inner_ty.builtin_index() {
Some(sty) => sty,
None => {
panic!("slice element type unknown");
}
};
let pci = PseudoCanonicalInput {
typing_env: TypingEnv::fully_monomorphized(),
value: sty,
};
let layout = tcx.layout_of(pci);
let elem_size = match layout {
Ok(layout) => layout.size,
Err(_) => {
bug!("autodiff failed to compute slice element size");
}
};
let elem_size: u32 = elem_size.bytes() as u32;
// We know that the length will be passed as extra arg.
if !da.is_empty() {
// We are looking at a slice. The length of that slice will become an
// extra integer on llvm level. Integers are always const.
// However, if the slice get's duplicated, we want to know to later check the
// size. So we mark the new size argument as FakeActivitySize.
// There is one FakeActivitySize per slice, so for convenience we store the
// slice element size in bytes in it. We will use the size in the backend.
let activity = match da[i] {
DiffActivity::DualOnly
| DiffActivity::Dual
| DiffActivity::Dualv
| DiffActivity::DuplicatedOnly
| DiffActivity::Duplicated => {
DiffActivity::FakeActivitySize(Some(elem_size))
}
DiffActivity::Const => DiffActivity::Const,
_ => bug!("unexpected activity for ptr/ref"),
};
new_activities.push(activity);
new_positions.push(i + 1);
}
continue;
}
}
}
// now add the extra activities coming from slices
// Reverse order to not invalidate the indices
for _ in 0..new_activities.len() {
let pos = new_positions.pop().unwrap();
let activity = new_activities.pop().unwrap();
da.insert(pos, activity);
}
}
pub(crate) fn find_autodiff_source_functions<'tcx>(
tcx: TyCtxt<'tcx>,
usage_map: &UsageMap<'tcx>,
autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
) -> Vec<AutoDiffItem> {
let mut autodiff_items: Vec<AutoDiffItem> = vec![];
for (item, instance) in autodiff_mono_items {
let target_id = instance.def_id();
let cg_fn_attr = &tcx.codegen_fn_attrs(target_id).autodiff_item;
let Some(target_attrs) = cg_fn_attr else {
continue;
};
let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
if target_attrs.is_source() {
trace!("source found: {:?}", target_id);
}
if !target_attrs.apply_autodiff() {
continue;
}
let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
let source =
usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
MonoItem::Fn(ref instance_s) => {
let source_id = instance_s.def_id();
if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
&& ad.is_active()
{
return Some(instance_s);
}
None
}
_ => None,
});
let inst = match source {
Some(source) => source,
None => continue,
};
debug!("source_id: {:?}", inst.def_id());
let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
assert!(fn_ty.is_fn());
adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
let mut new_target_attrs = target_attrs.clone();
new_target_attrs.input_activity = input_activities;
let itm = new_target_attrs.into_item(symb, target_symbol);
autodiff_items.push(itm);
}
if !autodiff_items.is_empty() {
trace!("AUTODIFF ITEMS EXIST");
for item in &mut *autodiff_items {
trace!("{}", &item);
}
}
autodiff_items
}

View file

@ -11,6 +11,7 @@ use tracing::debug;
use super::{
AttrWrapper, Capturing, FnParseMode, ForceCollect, Parser, PathStyle, Trailing, UsePreAttrPos,
};
use crate::parser::FnContext;
use crate::{errors, exp, fluent_generated as fluent};
// Public for rustfmt usage
@ -200,7 +201,7 @@ impl<'a> Parser<'a> {
AttrWrapper::empty(),
true,
false,
FnParseMode { req_name: |_| true, req_body: true },
FnParseMode { req_name: |_| true, context: FnContext::Free, req_body: true },
ForceCollect::No,
) {
Ok(Some(item)) => {

View file

@ -44,6 +44,7 @@ use crate::errors::{
UnexpectedConstParamDeclaration, UnexpectedConstParamDeclarationSugg, UnmatchedAngleBrackets,
UseEqInstead, WrapType,
};
use crate::parser::FnContext;
use crate::parser::attr::InnerAttrPolicy;
use crate::{exp, fluent_generated as fluent};
@ -2246,6 +2247,7 @@ impl<'a> Parser<'a> {
pat: Box<ast::Pat>,
require_name: bool,
first_param: bool,
fn_parse_mode: &crate::parser::item::FnParseMode,
) -> Option<Ident> {
// If we find a pattern followed by an identifier, it could be an (incorrect)
// C-style parameter declaration.
@ -2268,7 +2270,14 @@ impl<'a> Parser<'a> {
|| self.token == token::Lt
|| self.token == token::CloseParen)
{
let rfc_note = "anonymous parameters are removed in the 2018 edition (see RFC 1685)";
let maybe_emit_anon_params_note = |this: &mut Self, err: &mut Diag<'_>| {
let ed = this.token.span.with_neighbor(this.prev_token.span).edition();
if matches!(fn_parse_mode.context, crate::parser::item::FnContext::Trait)
&& (fn_parse_mode.req_name)(ed)
{
err.note("anonymous parameters are removed in the 2018 edition (see RFC 1685)");
}
};
let (ident, self_sugg, param_sugg, type_sugg, self_span, param_span, type_span) =
match pat.kind {
@ -2305,7 +2314,7 @@ impl<'a> Parser<'a> {
"_: ".to_string(),
Applicability::MachineApplicable,
);
err.note(rfc_note);
maybe_emit_anon_params_note(self, err);
}
return None;
@ -2313,7 +2322,13 @@ impl<'a> Parser<'a> {
};
// `fn foo(a, b) {}`, `fn foo(a<x>, b<y>) {}` or `fn foo(usize, usize) {}`
if first_param {
if first_param
// Only when the fn is a method, we emit this suggestion.
&& matches!(
fn_parse_mode.context,
FnContext::Trait | FnContext::Impl
)
{
err.span_suggestion_verbose(
self_span,
"if this is a `self` type, give it a parameter name",
@ -2337,7 +2352,7 @@ impl<'a> Parser<'a> {
type_sugg,
Applicability::MachineApplicable,
);
err.note(rfc_note);
maybe_emit_anon_params_note(self, err);
// Don't attempt to recover by using the `X` in `X<Y>` as the parameter name.
return if self.token == token::Lt { None } else { Some(ident) };

View file

@ -116,7 +116,8 @@ impl<'a> Parser<'a> {
impl<'a> Parser<'a> {
pub fn parse_item(&mut self, force_collect: ForceCollect) -> PResult<'a, Option<Box<Item>>> {
let fn_parse_mode = FnParseMode { req_name: |_| true, req_body: true };
let fn_parse_mode =
FnParseMode { req_name: |_| true, context: FnContext::Free, req_body: true };
self.parse_item_(fn_parse_mode, force_collect).map(|i| i.map(Box::new))
}
@ -975,7 +976,8 @@ impl<'a> Parser<'a> {
&mut self,
force_collect: ForceCollect,
) -> PResult<'a, Option<Option<Box<AssocItem>>>> {
let fn_parse_mode = FnParseMode { req_name: |_| true, req_body: true };
let fn_parse_mode =
FnParseMode { req_name: |_| true, context: FnContext::Impl, req_body: true };
self.parse_assoc_item(fn_parse_mode, force_collect)
}
@ -983,8 +985,11 @@ impl<'a> Parser<'a> {
&mut self,
force_collect: ForceCollect,
) -> PResult<'a, Option<Option<Box<AssocItem>>>> {
let fn_parse_mode =
FnParseMode { req_name: |edition| edition >= Edition::Edition2018, req_body: false };
let fn_parse_mode = FnParseMode {
req_name: |edition| edition >= Edition::Edition2018,
context: FnContext::Trait,
req_body: false,
};
self.parse_assoc_item(fn_parse_mode, force_collect)
}
@ -1261,7 +1266,8 @@ impl<'a> Parser<'a> {
&mut self,
force_collect: ForceCollect,
) -> PResult<'a, Option<Option<Box<ForeignItem>>>> {
let fn_parse_mode = FnParseMode { req_name: |_| true, req_body: false };
let fn_parse_mode =
FnParseMode { req_name: |_| true, context: FnContext::Free, req_body: false };
Ok(self.parse_item_(fn_parse_mode, force_collect)?.map(
|Item { attrs, id, span, vis, kind, tokens }| {
let kind = match ForeignItemKind::try_from(kind) {
@ -2135,7 +2141,8 @@ impl<'a> Parser<'a> {
let inherited_vis =
Visibility { span: DUMMY_SP, kind: VisibilityKind::Inherited, tokens: None };
// We use `parse_fn` to get a span for the function
let fn_parse_mode = FnParseMode { req_name: |_| true, req_body: true };
let fn_parse_mode =
FnParseMode { req_name: |_| true, context: FnContext::Free, req_body: true };
match self.parse_fn(
&mut AttrVec::new(),
fn_parse_mode,
@ -2403,6 +2410,9 @@ pub(crate) struct FnParseMode {
/// * The span is from Edition 2015. In particular, you can get a
/// 2015 span inside a 2021 crate using macros.
pub(super) req_name: ReqName,
/// The context in which this function is parsed, used for diagnostics.
/// This indicates the fn is a free function or method and so on.
pub(super) context: FnContext,
/// If this flag is set to `true`, then plain, semicolon-terminated function
/// prototypes are not allowed here.
///
@ -2424,6 +2434,18 @@ pub(crate) struct FnParseMode {
pub(super) req_body: bool,
}
/// The context in which a function is parsed.
/// FIXME(estebank, xizheyin): Use more variants.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum FnContext {
/// Free context.
Free,
/// A Trait context.
Trait,
/// An Impl block.
Impl,
}
/// Parsing of functions and methods.
impl<'a> Parser<'a> {
/// Parse a function starting from the front matter (`const ...`) to the body `{ ... }` or `;`.
@ -2439,11 +2461,8 @@ impl<'a> Parser<'a> {
let header = self.parse_fn_front_matter(vis, case, FrontMatterParsingMode::Function)?; // `const ... fn`
let ident = self.parse_ident()?; // `foo`
let mut generics = self.parse_generics()?; // `<'a, T, ...>`
let decl = match self.parse_fn_decl(
fn_parse_mode.req_name,
AllowPlus::Yes,
RecoverReturnSign::Yes,
) {
let decl = match self.parse_fn_decl(&fn_parse_mode, AllowPlus::Yes, RecoverReturnSign::Yes)
{
Ok(decl) => decl,
Err(old_err) => {
// If we see `for Ty ...` then user probably meant `impl` item.
@ -2961,18 +2980,21 @@ impl<'a> Parser<'a> {
/// Parses the parameter list and result type of a function declaration.
pub(super) fn parse_fn_decl(
&mut self,
req_name: ReqName,
fn_parse_mode: &FnParseMode,
ret_allow_plus: AllowPlus,
recover_return_sign: RecoverReturnSign,
) -> PResult<'a, Box<FnDecl>> {
Ok(Box::new(FnDecl {
inputs: self.parse_fn_params(req_name)?,
inputs: self.parse_fn_params(fn_parse_mode)?,
output: self.parse_ret_ty(ret_allow_plus, RecoverQPath::Yes, recover_return_sign)?,
}))
}
/// Parses the parameter list of a function, including the `(` and `)` delimiters.
pub(super) fn parse_fn_params(&mut self, req_name: ReqName) -> PResult<'a, ThinVec<Param>> {
pub(super) fn parse_fn_params(
&mut self,
fn_parse_mode: &FnParseMode,
) -> PResult<'a, ThinVec<Param>> {
let mut first_param = true;
// Parse the arguments, starting out with `self` being allowed...
if self.token != TokenKind::OpenParen
@ -2988,7 +3010,7 @@ impl<'a> Parser<'a> {
let (mut params, _) = self.parse_paren_comma_seq(|p| {
p.recover_vcs_conflict_marker();
let snapshot = p.create_snapshot_for_diagnostic();
let param = p.parse_param_general(req_name, first_param, true).or_else(|e| {
let param = p.parse_param_general(fn_parse_mode, first_param, true).or_else(|e| {
let guar = e.emit();
// When parsing a param failed, we should check to make the span of the param
// not contain '(' before it.
@ -3019,7 +3041,7 @@ impl<'a> Parser<'a> {
/// - `recover_arg_parse` is used to recover from a failed argument parse.
pub(super) fn parse_param_general(
&mut self,
req_name: ReqName,
fn_parse_mode: &FnParseMode,
first_param: bool,
recover_arg_parse: bool,
) -> PResult<'a, Param> {
@ -3035,16 +3057,22 @@ impl<'a> Parser<'a> {
let is_name_required = match this.token.kind {
token::DotDotDot => false,
_ => req_name(this.token.span.with_neighbor(this.prev_token.span).edition()),
_ => (fn_parse_mode.req_name)(
this.token.span.with_neighbor(this.prev_token.span).edition(),
),
};
let (pat, ty) = if is_name_required || this.is_named_param() {
debug!("parse_param_general parse_pat (is_name_required:{})", is_name_required);
let (pat, colon) = this.parse_fn_param_pat_colon()?;
if !colon {
let mut err = this.unexpected().unwrap_err();
return if let Some(ident) =
this.parameter_without_type(&mut err, pat, is_name_required, first_param)
{
return if let Some(ident) = this.parameter_without_type(
&mut err,
pat,
is_name_required,
first_param,
fn_parse_mode,
) {
let guar = err.emit();
Ok((dummy_arg(ident, guar), Trailing::No, UsePreAttrPos::No))
} else {

View file

@ -22,7 +22,7 @@ use std::{fmt, mem, slice};
use attr_wrapper::{AttrWrapper, UsePreAttrPos};
pub use diagnostics::AttemptLocalParseRecovery;
pub(crate) use expr::ForbiddenLetReason;
pub(crate) use item::FnParseMode;
pub(crate) use item::{FnContext, FnParseMode};
pub use pat::{CommaRecoveryMode, RecoverColon, RecoverComma};
use path::PathStyle;
use rustc_ast::token::{

View file

@ -20,7 +20,9 @@ use crate::errors::{
PathFoundAttributeInParams, PathFoundCVariadicParams, PathSingleColon, PathTripleColon,
};
use crate::exp;
use crate::parser::{CommaRecoveryMode, ExprKind, RecoverColon, RecoverComma};
use crate::parser::{
CommaRecoveryMode, ExprKind, FnContext, FnParseMode, RecoverColon, RecoverComma,
};
/// Specifies how to parse a path.
#[derive(Copy, Clone, PartialEq)]
@ -399,7 +401,13 @@ impl<'a> Parser<'a> {
let dcx = self.dcx();
let parse_params_result = self.parse_paren_comma_seq(|p| {
let param = p.parse_param_general(|_| false, false, false);
// Inside parenthesized type arguments, we want types only, not names.
let mode = FnParseMode {
context: FnContext::Free,
req_name: |_| false,
req_body: false,
};
let param = p.parse_param_general(&mode, false, false);
param.map(move |param| {
if !matches!(param.pat.kind, PatKind::Missing) {
dcx.emit_err(FnPathFoundNamedParams {

View file

@ -19,8 +19,8 @@ use super::diagnostics::AttemptLocalParseRecovery;
use super::pat::{PatternLocation, RecoverComma};
use super::path::PathStyle;
use super::{
AttrWrapper, BlockMode, FnParseMode, ForceCollect, Parser, Restrictions, SemiColonMode,
Trailing, UsePreAttrPos,
AttrWrapper, BlockMode, FnContext, FnParseMode, ForceCollect, Parser, Restrictions,
SemiColonMode, Trailing, UsePreAttrPos,
};
use crate::errors::{self, MalformedLoopLabel};
use crate::exp;
@ -153,7 +153,7 @@ impl<'a> Parser<'a> {
attrs.clone(), // FIXME: unwanted clone of attrs
false,
true,
FnParseMode { req_name: |_| true, req_body: true },
FnParseMode { req_name: |_| true, context: FnContext::Free, req_body: true },
force_collect,
)? {
self.mk_stmt(lo.to(item.span), StmtKind::Item(Box::new(item)))

View file

@ -19,6 +19,7 @@ use crate::errors::{
NestedCVariadicType, ReturnTypesUseThinArrow,
};
use crate::parser::item::FrontMatterParsingMode;
use crate::parser::{FnContext, FnParseMode};
use crate::{exp, maybe_recover_from_interpolated_ty_qpath};
/// Signals whether parsing a type should allow `+`.
@ -769,7 +770,12 @@ impl<'a> Parser<'a> {
if self.may_recover() && self.token == TokenKind::Lt {
self.recover_fn_ptr_with_generics(lo, &mut params, param_insertion_point)?;
}
let decl = self.parse_fn_decl(|_| false, AllowPlus::No, recover_return_sign)?;
let mode = crate::parser::item::FnParseMode {
req_name: |_| false,
context: FnContext::Free,
req_body: false,
};
let decl = self.parse_fn_decl(&mode, AllowPlus::No, recover_return_sign)?;
let decl_span = span_start.to(self.prev_token.span);
Ok(TyKind::FnPtr(Box::new(FnPtrTy {
@ -1314,7 +1320,8 @@ impl<'a> Parser<'a> {
self.bump();
let args_lo = self.token.span;
let snapshot = self.create_snapshot_for_diagnostic();
match self.parse_fn_decl(|_| false, AllowPlus::No, RecoverReturnSign::OnlyFatArrow) {
let mode = FnParseMode { req_name: |_| false, context: FnContext::Free, req_body: false };
match self.parse_fn_decl(&mode, AllowPlus::No, RecoverReturnSign::OnlyFatArrow) {
Ok(decl) => {
self.dcx().emit_err(ExpectedFnPathFoundFnKeyword { fn_token_span });
Some(ast::Path {
@ -1400,8 +1407,9 @@ impl<'a> Parser<'a> {
// Parse `(T, U) -> R`.
let inputs_lo = self.token.span;
let mode = FnParseMode { req_name: |_| false, context: FnContext::Free, req_body: false };
let inputs: ThinVec<_> =
self.parse_fn_params(|_| false)?.into_iter().map(|input| input.ty).collect();
self.parse_fn_params(&mode)?.into_iter().map(|input| input.ty).collect();
let inputs_span = inputs_lo.to(self.prev_token.span);
let output = self.parse_ret_ty(AllowPlus::No, RecoverQPath::No, RecoverReturnSign::No)?;
let args = ast::ParenthesizedArgs {

View file

@ -493,9 +493,6 @@ impl<'a, 'ra, 'tcx> BuildReducedGraphVisitor<'a, 'ra, 'tcx> {
});
}
}
// We don't add prelude imports to the globs since they only affect lexical scopes,
// which are not relevant to import resolution.
ImportKind::Glob { is_prelude: true, .. } => {}
ImportKind::Glob { .. } => current_module.globs.borrow_mut().push(import),
_ => unreachable!(),
}
@ -658,13 +655,19 @@ impl<'a, 'ra, 'tcx> BuildReducedGraphVisitor<'a, 'ra, 'tcx> {
self.add_import(module_path, kind, use_tree.span, item, root_span, item.id, vis);
}
ast::UseTreeKind::Glob => {
let kind = ImportKind::Glob {
is_prelude: ast::attr::contains_name(&item.attrs, sym::prelude_import),
max_vis: Cell::new(None),
id,
};
self.add_import(prefix, kind, use_tree.span, item, root_span, item.id, vis);
if !ast::attr::contains_name(&item.attrs, sym::prelude_import) {
let kind = ImportKind::Glob { max_vis: Cell::new(None), id };
self.add_import(prefix, kind, use_tree.span, item, root_span, item.id, vis);
} else {
// Resolve the prelude import early.
let path_res =
self.r.cm().maybe_resolve_path(&prefix, None, &self.parent_scope, None);
if let PathResult::Module(ModuleOrUniformRoot::Module(module)) = path_res {
self.r.prelude = Some(module);
} else {
self.r.dcx().span_err(use_tree.span, "cannot resolve a prelude import");
}
}
}
ast::UseTreeKind::Nested { ref items, .. } => {
// Ensure there is at most one `self` in the list

View file

@ -2189,9 +2189,9 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
let ast::ExprKind::Struct(struct_expr) = &expr.kind else { return };
// We don't have to handle type-relative paths because they're forbidden in ADT
// expressions, but that would change with `#[feature(more_qualified_paths)]`.
let Some(Res::Def(_, def_id)) =
self.partial_res_map[&struct_expr.path.segments.iter().last().unwrap().id].full_res()
else {
let Some(segment) = struct_expr.path.segments.last() else { return };
let Some(partial_res) = self.partial_res_map.get(&segment.id) else { return };
let Some(Res::Def(_, def_id)) = partial_res.full_res() else {
return;
};
let Some(default_fields) = self.field_defaults(def_id) else { return };

View file

@ -318,7 +318,6 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
let normalized_ident = Ident { span: normalized_span, ..ident };
// Walk backwards up the ribs in scope.
let mut module = self.graph_root;
for (i, rib) in ribs.iter().enumerate().rev() {
debug!("walk rib\n{:?}", rib.bindings);
// Use the rib kind to determine whether we are resolving parameters
@ -334,51 +333,47 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
*original_rib_ident_def,
ribs,
)));
}
module = match rib.kind {
RibKind::Module(module) => module,
RibKind::MacroDefinition(def) if def == self.macro_def(ident.span.ctxt()) => {
// If an invocation of this macro created `ident`, give up on `ident`
// and switch to `ident`'s source from the macro definition.
ident.span.remove_mark();
continue;
}
_ => continue,
};
match module.kind {
ModuleKind::Block => {} // We can see through blocks
_ => break,
}
let item = self.cm().resolve_ident_in_module_unadjusted(
ModuleOrUniformRoot::Module(module),
ident,
ns,
parent_scope,
Shadowing::Unrestricted,
finalize.map(|finalize| Finalize { used: Used::Scope, ..finalize }),
ignore_binding,
None,
);
if let Ok(binding) = item {
// The ident resolves to an item.
} else if let RibKind::Block(Some(module)) = rib.kind
&& let Ok(binding) = self.cm().resolve_ident_in_module_unadjusted(
ModuleOrUniformRoot::Module(module),
ident,
ns,
parent_scope,
Shadowing::Unrestricted,
finalize.map(|finalize| Finalize { used: Used::Scope, ..finalize }),
ignore_binding,
None,
)
{
// The ident resolves to an item in a block.
return Some(LexicalScopeBinding::Item(binding));
} else if let RibKind::Module(module) = rib.kind {
// Encountered a module item, abandon ribs and look into that module and preludes.
return self
.cm()
.early_resolve_ident_in_lexical_scope(
orig_ident,
ScopeSet::Late(ns, module, finalize.map(|finalize| finalize.node_id)),
parent_scope,
finalize,
finalize.is_some(),
ignore_binding,
None,
)
.ok()
.map(LexicalScopeBinding::Item);
}
if let RibKind::MacroDefinition(def) = rib.kind
&& def == self.macro_def(ident.span.ctxt())
{
// If an invocation of this macro created `ident`, give up on `ident`
// and switch to `ident`'s source from the macro definition.
ident.span.remove_mark();
}
}
self.cm()
.early_resolve_ident_in_lexical_scope(
orig_ident,
ScopeSet::Late(ns, module, finalize.map(|finalize| finalize.node_id)),
parent_scope,
finalize,
finalize.is_some(),
ignore_binding,
None,
)
.ok()
.map(LexicalScopeBinding::Item)
unreachable!()
}
/// Resolve an identifier in lexical scope.
@ -1171,6 +1166,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
for rib in ribs {
match rib.kind {
RibKind::Normal
| RibKind::Block(..)
| RibKind::FnOrCoroutine
| RibKind::Module(..)
| RibKind::MacroDefinition(..)
@ -1263,6 +1259,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
for rib in ribs {
let (has_generic_params, def_kind) = match rib.kind {
RibKind::Normal
| RibKind::Block(..)
| RibKind::FnOrCoroutine
| RibKind::Module(..)
| RibKind::MacroDefinition(..)
@ -1356,6 +1353,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
for rib in ribs {
let (has_generic_params, def_kind) = match rib.kind {
RibKind::Normal
| RibKind::Block(..)
| RibKind::FnOrCoroutine
| RibKind::Module(..)
| RibKind::MacroDefinition(..)

View file

@ -87,7 +87,6 @@ pub(crate) enum ImportKind<'ra> {
id: NodeId,
},
Glob {
is_prelude: bool,
// The visibility of the greatest re-export.
// n.b. `max_vis` is only used in `finalize_import` to check for re-export errors.
max_vis: Cell<Option<Visibility>>,
@ -125,12 +124,9 @@ impl<'ra> std::fmt::Debug for ImportKind<'ra> {
.field("nested", nested)
.field("id", id)
.finish(),
Glob { is_prelude, max_vis, id } => f
.debug_struct("Glob")
.field("is_prelude", is_prelude)
.field("max_vis", max_vis)
.field("id", id)
.finish(),
Glob { max_vis, id } => {
f.debug_struct("Glob").field("max_vis", max_vis).field("id", id).finish()
}
ExternCrate { source, target, id } => f
.debug_struct("ExternCrate")
.field("source", source)
@ -1073,7 +1069,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
ImportKind::Single { source, target, ref bindings, type_ns_only, id, .. } => {
(source, target, bindings, type_ns_only, id)
}
ImportKind::Glob { is_prelude, ref max_vis, id } => {
ImportKind::Glob { ref max_vis, id } => {
if import.module_path.len() <= 1 {
// HACK(eddyb) `lint_if_path_starts_with_module` needs at least
// 2 segments, so the `resolve_path` above won't trigger it.
@ -1096,8 +1092,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
module: None,
});
}
if !is_prelude
&& let Some(max_vis) = max_vis.get()
if let Some(max_vis) = max_vis.get()
&& !max_vis.is_at_least(import.vis, self.tcx)
{
let def_id = self.local_def_id(id);
@ -1485,7 +1480,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
fn resolve_glob_import(&mut self, import: Import<'ra>) {
// This function is only called for glob imports.
let ImportKind::Glob { id, is_prelude, .. } = import.kind else { unreachable!() };
let ImportKind::Glob { id, .. } = import.kind else { unreachable!() };
let ModuleOrUniformRoot::Module(module) = import.imported_module.get().unwrap() else {
self.dcx().emit_err(CannotGlobImportAllCrates { span: import.span });
@ -1504,9 +1499,6 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
if module == import.parent_scope.module {
return;
} else if is_prelude {
self.prelude = Some(module);
return;
}
// Add to module's glob_importers

View file

@ -192,6 +192,13 @@ pub(crate) enum RibKind<'ra> {
/// No restriction needs to be applied.
Normal,
/// We passed through an `ast::Block`.
/// Behaves like `Normal`, but also partially like `Module` if the block contains items.
/// `Block(None)` must be always processed in the same way as `Block(Some(module))`
/// with empty `module`. The module can be `None` only because creation of some definitely
/// empty modules is skipped as an optimization.
Block(Option<Module<'ra>>),
/// We passed through an impl or trait and are now in one of its
/// methods or associated types. Allow references to ty params that impl or trait
/// binds. Disallow any other upvars (including other ty params that are
@ -210,7 +217,7 @@ pub(crate) enum RibKind<'ra> {
/// All other constants aren't allowed to use generic params at all.
ConstantItem(ConstantHasGenerics, Option<(Ident, ConstantItemKind)>),
/// We passed through a module.
/// We passed through a module item.
Module(Module<'ra>),
/// We passed through a `macro_rules!` statement
@ -242,6 +249,7 @@ impl RibKind<'_> {
pub(crate) fn contains_params(&self) -> bool {
match self {
RibKind::Normal
| RibKind::Block(..)
| RibKind::FnOrCoroutine
| RibKind::ConstantItem(..)
| RibKind::Module(_)
@ -258,15 +266,8 @@ impl RibKind<'_> {
fn is_label_barrier(self) -> bool {
match self {
RibKind::Normal | RibKind::MacroDefinition(..) => false,
RibKind::AssocItem
| RibKind::FnOrCoroutine
| RibKind::Item(..)
| RibKind::ConstantItem(..)
| RibKind::Module(..)
| RibKind::ForwardGenericParamBan(_)
| RibKind::ConstParamTy
| RibKind::InlineAsmSym => true,
RibKind::FnOrCoroutine | RibKind::ConstantItem(..) => true,
kind => bug!("unexpected rib kind: {kind:?}"),
}
}
}
@ -1527,19 +1528,6 @@ impl<'a, 'ast, 'ra, 'tcx> LateResolutionVisitor<'a, 'ast, 'ra, 'tcx> {
ret
}
fn with_mod_rib<T>(&mut self, id: NodeId, f: impl FnOnce(&mut Self) -> T) -> T {
let module = self.r.expect_module(self.r.local_def_id(id).to_def_id());
// Move down in the graph.
let orig_module = replace(&mut self.parent_scope.module, module);
self.with_rib(ValueNS, RibKind::Module(module), |this| {
this.with_rib(TypeNS, RibKind::Module(module), |this| {
let ret = f(this);
this.parent_scope.module = orig_module;
ret
})
})
}
fn visit_generic_params(&mut self, params: &'ast [GenericParam], add_self_upper: bool) {
// For type parameter defaults, we have to ban access
// to following type parameters, as the GenericArgs can only
@ -2677,20 +2665,25 @@ impl<'a, 'ast, 'ra, 'tcx> LateResolutionVisitor<'a, 'ast, 'ra, 'tcx> {
}
ItemKind::Mod(..) => {
self.with_mod_rib(item.id, |this| {
if mod_inner_docs {
this.resolve_doc_links(&item.attrs, MaybeExported::Ok(item.id));
}
let old_macro_rules = this.parent_scope.macro_rules;
visit::walk_item(this, item);
// Maintain macro_rules scopes in the same way as during early resolution
// for diagnostics and doc links.
if item.attrs.iter().all(|attr| {
!attr.has_name(sym::macro_use) && !attr.has_name(sym::macro_escape)
}) {
this.parent_scope.macro_rules = old_macro_rules;
}
let module = self.r.expect_module(self.r.local_def_id(item.id).to_def_id());
let orig_module = replace(&mut self.parent_scope.module, module);
self.with_rib(ValueNS, RibKind::Module(module), |this| {
this.with_rib(TypeNS, RibKind::Module(module), |this| {
if mod_inner_docs {
this.resolve_doc_links(&item.attrs, MaybeExported::Ok(item.id));
}
let old_macro_rules = this.parent_scope.macro_rules;
visit::walk_item(this, item);
// Maintain macro_rules scopes in the same way as during early resolution
// for diagnostics and doc links.
if item.attrs.iter().all(|attr| {
!attr.has_name(sym::macro_use) && !attr.has_name(sym::macro_escape)
}) {
this.parent_scope.macro_rules = old_macro_rules;
}
})
});
self.parent_scope.module = orig_module;
}
ItemKind::Static(box ast::StaticItem {
@ -2821,9 +2814,9 @@ impl<'a, 'ast, 'ra, 'tcx> LateResolutionVisitor<'a, 'ast, 'ra, 'tcx> {
// We also can't shadow bindings from associated parent items.
for ns in [ValueNS, TypeNS] {
for parent_rib in self.ribs[ns].iter().rev() {
// Break at mod level, to account for nested items which are
// Break at module or block level, to account for nested items which are
// allowed to shadow generic param names.
if matches!(parent_rib.kind, RibKind::Module(..)) {
if matches!(parent_rib.kind, RibKind::Module(..) | RibKind::Block(..)) {
break;
}
@ -4652,16 +4645,16 @@ impl<'a, 'ast, 'ra, 'tcx> LateResolutionVisitor<'a, 'ast, 'ra, 'tcx> {
debug!("(resolving block) entering block");
// Move down in the graph, if there's an anonymous module rooted here.
let orig_module = self.parent_scope.module;
let anonymous_module = self.r.block_map.get(&block.id).cloned(); // clones a reference
let anonymous_module = self.r.block_map.get(&block.id).copied();
let mut num_macro_definition_ribs = 0;
if let Some(anonymous_module) = anonymous_module {
debug!("(resolving block) found anonymous module, moving down");
self.ribs[ValueNS].push(Rib::new(RibKind::Module(anonymous_module)));
self.ribs[TypeNS].push(Rib::new(RibKind::Module(anonymous_module)));
self.ribs[ValueNS].push(Rib::new(RibKind::Block(Some(anonymous_module))));
self.ribs[TypeNS].push(Rib::new(RibKind::Block(Some(anonymous_module))));
self.parent_scope.module = anonymous_module;
} else {
self.ribs[ValueNS].push(Rib::new(RibKind::Normal));
self.ribs[ValueNS].push(Rib::new(RibKind::Block(None)));
}
// Descend into the block.

View file

@ -849,9 +849,7 @@ impl<'ast, 'ra, 'tcx> LateResolutionVisitor<'_, 'ast, 'ra, 'tcx> {
}
// Try to find in last block rib
if let Some(rib) = &self.last_block_rib
&& let RibKind::Normal = rib.kind
{
if let Some(rib) = &self.last_block_rib {
for (ident, &res) in &rib.bindings {
if let Res::Local(_) = res
&& path.len() == 1
@ -900,7 +898,7 @@ impl<'ast, 'ra, 'tcx> LateResolutionVisitor<'_, 'ast, 'ra, 'tcx> {
if path.len() == 1 {
for rib in self.ribs[ns].iter().rev() {
let item = path[0].ident;
if let RibKind::Module(module) = rib.kind
if let RibKind::Module(module) | RibKind::Block(Some(module)) = rib.kind
&& let Some(did) = find_doc_alias_name(self.r, module, item.name)
{
return Some((did, item));
@ -2458,9 +2456,7 @@ impl<'ast, 'ra, 'tcx> LateResolutionVisitor<'_, 'ast, 'ra, 'tcx> {
}
}
if let RibKind::Module(module) = rib.kind
&& let ModuleKind::Block = module.kind
{
if let RibKind::Block(Some(module)) = rib.kind {
self.r.add_module_candidates(module, &mut names, &filter_fn, Some(ctxt));
} else if let RibKind::Module(module) = rib.kind {
// Encountered a module item, abandon ribs and look into that module and preludes.

View file

@ -16,10 +16,11 @@ use std::{cmp, fmt, fs, iter};
use externs::{ExternOpt, split_extern_opt};
use rustc_data_structures::fx::{FxHashSet, FxIndexMap};
use rustc_data_structures::stable_hasher::{StableOrd, ToStableHashKey};
use rustc_data_structures::stable_hasher::{StableHasher, StableOrd, ToStableHashKey};
use rustc_errors::emitter::HumanReadableErrorType;
use rustc_errors::{ColorConfig, DiagArgValue, DiagCtxtFlags, IntoDiagArg};
use rustc_feature::UnstableFeatures;
use rustc_hashes::Hash64;
use rustc_macros::{Decodable, Encodable, HashStable_Generic};
use rustc_span::edition::{DEFAULT_EDITION, EDITION_NAME_LIST, Edition, LATEST_STABLE_EDITION};
use rustc_span::source_map::FilePathMapping;
@ -1195,7 +1196,25 @@ pub struct OutputFilenames {
pub const RLINK_EXT: &str = "rlink";
pub const RUST_CGU_EXT: &str = "rcgu";
pub const DWARF_OBJECT_EXT: &str = "dwo";
pub const MAX_FILENAME_LENGTH: usize = 143; // ecryptfs limits filenames to 143 bytes see #49914
/// Ensure the filename is not too long, as some filesystems have a limit.
/// If the filename is too long, hash part of it and append the hash to the filename.
/// This is a workaround for long crate names generating overly long filenames.
fn maybe_strip_file_name(mut path: PathBuf) -> PathBuf {
if path.file_name().map_or(0, |name| name.len()) > MAX_FILENAME_LENGTH {
let filename = path.file_name().unwrap().to_string_lossy();
let hash_len = 64 / 4; // Hash64 is 64 bits encoded in hex
let stripped_len = filename.len() - MAX_FILENAME_LENGTH + hash_len;
let mut hasher = StableHasher::new();
filename[..stripped_len].hash(&mut hasher);
let hash = hasher.finish::<Hash64>();
path.set_file_name(format!("{:x}-{}", hash, &filename[stripped_len..]));
}
path
}
impl OutputFilenames {
pub fn new(
out_directory: PathBuf,
@ -1288,7 +1307,7 @@ impl OutputFilenames {
}
let temps_directory = self.temps_directory.as_ref().unwrap_or(&self.out_directory);
self.with_directory_and_extension(temps_directory, &extension)
maybe_strip_file_name(self.with_directory_and_extension(temps_directory, &extension))
}
pub fn temp_path_for_diagnostic(&self, ext: &str) -> PathBuf {

View file

@ -542,6 +542,7 @@ symbols! {
audit_that,
augmented_assignments,
auto_traits,
autodiff,
autodiff_forward,
autodiff_reverse,
automatically_derived,

View file

@ -168,7 +168,6 @@ impl Target {
forward!(main_needs_argc_argv);
forward!(has_thread_local);
forward!(obj_is_bitcode);
forward!(bitcode_llvm_cmdline);
forward_opt!(max_atomic_width);
forward_opt!(min_atomic_width);
forward!(atomic_cas);
@ -361,7 +360,6 @@ impl ToJson for Target {
target_option_val!(main_needs_argc_argv);
target_option_val!(has_thread_local);
target_option_val!(obj_is_bitcode);
target_option_val!(bitcode_llvm_cmdline);
target_option_val!(min_atomic_width);
target_option_val!(max_atomic_width);
target_option_val!(atomic_cas);
@ -555,7 +553,6 @@ struct TargetSpecJson {
main_needs_argc_argv: Option<bool>,
has_thread_local: Option<bool>,
obj_is_bitcode: Option<bool>,
bitcode_llvm_cmdline: Option<StaticCow<str>>,
max_atomic_width: Option<u64>,
min_atomic_width: Option<u64>,
atomic_cas: Option<bool>,

View file

@ -2624,8 +2624,6 @@ pub struct TargetOptions {
/// If we give emcc .o files that are actually .bc files it
/// will 'just work'.
pub obj_is_bitcode: bool,
/// Content of the LLVM cmdline section associated with embedded bitcode.
pub bitcode_llvm_cmdline: StaticCow<str>,
/// Don't use this field; instead use the `.min_atomic_width()` method.
pub min_atomic_width: Option<u64>,
@ -2989,7 +2987,6 @@ impl Default for TargetOptions {
allow_asm: true,
has_thread_local: false,
obj_is_bitcode: false,
bitcode_llvm_cmdline: "".into(),
min_atomic_width: None,
max_atomic_width: None,
atomic_cas: true,

View file

@ -1619,8 +1619,18 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
{
let e = self.tcx.erase_regions(e);
let f = self.tcx.erase_regions(f);
let expected = with_forced_trimmed_paths!(e.sort_string(self.tcx));
let found = with_forced_trimmed_paths!(f.sort_string(self.tcx));
let mut expected = with_forced_trimmed_paths!(e.sort_string(self.tcx));
let mut found = with_forced_trimmed_paths!(f.sort_string(self.tcx));
if let ObligationCauseCode::Pattern { span, .. } = cause.code()
&& let Some(span) = span
&& !span.from_expansion()
&& cause.span.from_expansion()
{
// When the type error comes from a macro like `assert!()`, and we are pointing at
// code the user wrote the cause and effect are reversed as the expected value is
// what the macro expanded to.
(found, expected) = (expected, found);
}
if expected == found {
label_or_note(span, terr.to_string(self.tcx));
} else {
@ -2143,7 +2153,9 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
) -> Option<(DiagStyledString, DiagStyledString)> {
match values {
ValuePairs::Regions(exp_found) => self.expected_found_str(exp_found),
ValuePairs::Terms(exp_found) => self.expected_found_str_term(exp_found, long_ty_path),
ValuePairs::Terms(exp_found) => {
self.expected_found_str_term(cause, exp_found, long_ty_path)
}
ValuePairs::Aliases(exp_found) => self.expected_found_str(exp_found),
ValuePairs::ExistentialTraitRef(exp_found) => self.expected_found_str(exp_found),
ValuePairs::ExistentialProjection(exp_found) => self.expected_found_str(exp_found),
@ -2182,6 +2194,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
fn expected_found_str_term(
&self,
cause: &ObligationCause<'tcx>,
exp_found: ty::error::ExpectedFound<ty::Term<'tcx>>,
long_ty_path: &mut Option<PathBuf>,
) -> Option<(DiagStyledString, DiagStyledString)> {
@ -2189,8 +2202,27 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
if exp_found.references_error() {
return None;
}
let (mut expected, mut found) = (exp_found.expected, exp_found.found);
Some(match (exp_found.expected.kind(), exp_found.found.kind()) {
if let ObligationCauseCode::Pattern { span, .. } = cause.code()
&& let Some(span) = span
&& !span.from_expansion()
&& cause.span.from_expansion()
{
// When the type error comes from a macro like `assert!()`, and we are pointing at
// code the user wrote, the cause and effect are reversed as the expected value is
// what the macro expanded to. So if the user provided a `Type` when the macro is
// written in such a way that a `bool` was expected, we want to print:
// = note: expected `bool`
// found `Type`"
// but as far as the compiler is concerned, after expansion what was expected was `Type`
// = note: expected `Type`
// found `bool`"
// so we reverse them here to match user expectation.
(expected, found) = (found, expected);
}
Some(match (expected.kind(), found.kind()) {
(ty::TermKind::Ty(expected), ty::TermKind::Ty(found)) => {
let (mut exp, mut fnd) = self.cmp(expected, found);
// Use the terminal width as the basis to determine when to compress the printed

View file

@ -37,7 +37,6 @@ use super::on_unimplemented::{AppendConstMessage, OnUnimplementedNote};
use super::suggestions::get_explanation_based_on_obligation;
use super::{
ArgKind, CandidateSimilarity, FindExprBySpan, GetSafeTransmuteErrorAndReason, ImplCandidate,
UnsatisfiedConst,
};
use crate::error_reporting::TypeErrCtxt;
use crate::error_reporting::infer::TyCategory;
@ -374,13 +373,6 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
}
}
let UnsatisfiedConst(unsatisfied_const) = self
.maybe_add_note_for_unsatisfied_const(
leaf_trait_predicate,
&mut err,
span,
);
if let Some((msg, span)) = type_def {
err.span_label(span, msg);
}
@ -506,7 +498,6 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
span,
is_fn_trait,
suggested,
unsatisfied_const,
);
// Changing mutability doesn't make a difference to whether we have
@ -2716,7 +2707,6 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
span: Span,
is_fn_trait: bool,
suggested: bool,
unsatisfied_const: bool,
) {
let body_def_id = obligation.cause.body_id;
let span = if let ObligationCauseCode::BinOp { rhs_span: Some(rhs_span), .. } =
@ -2763,10 +2753,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
self.tcx.def_span(trait_def_id),
crate::fluent_generated::trait_selection_trait_has_no_impls,
);
} else if !suggested
&& !unsatisfied_const
&& trait_predicate.polarity() == ty::PredicatePolarity::Positive
{
} else if !suggested && trait_predicate.polarity() == ty::PredicatePolarity::Positive {
// Can't show anything else useful, try to find similar impls.
let impl_candidates = self.find_similar_impl_candidates(trait_predicate);
if !self.report_similar_impl_candidates(
@ -2878,17 +2865,6 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
}
}
fn maybe_add_note_for_unsatisfied_const(
&self,
_trait_predicate: ty::PolyTraitPredicate<'tcx>,
_err: &mut Diag<'_>,
_span: Span,
) -> UnsatisfiedConst {
let unsatisfied_const = UnsatisfiedConst(false);
// FIXME(const_trait_impl)
unsatisfied_const
}
fn report_closure_error(
&self,
obligation: &PredicateObligation<'tcx>,

View file

@ -51,8 +51,6 @@ enum GetSafeTransmuteErrorAndReason {
Error { err_msg: String, safe_transmute_explanation: Option<String> },
}
struct UnsatisfiedConst(pub bool);
/// Crude way of getting back an `Expr` from a `Span`.
pub struct FindExprBySpan<'hir> {
pub span: Span,

View file

@ -120,9 +120,9 @@ dependencies = [
[[package]]
name = "hashbrown"
version = "0.15.4"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"rustc-std-workspace-alloc",
"rustc-std-workspace-core",

View file

@ -1573,6 +1573,47 @@ impl<'b, T: ?Sized> Ref<'b, T> {
}
}
/// Tries to makes a new `Ref` for a component of the borrowed data.
/// On failure, the original guard is returned alongside with the error
/// returned by the closure.
///
/// The `RefCell` is already immutably borrowed, so this cannot fail.
///
/// This is an associated function that needs to be used as
/// `Ref::try_map(...)`. A method would interfere with methods of the same
/// name on the contents of a `RefCell` used through `Deref`.
///
/// # Examples
///
/// ```
/// #![feature(refcell_try_map)]
/// use std::cell::{RefCell, Ref};
/// use std::str::{from_utf8, Utf8Error};
///
/// let c = RefCell::new(vec![0xF0, 0x9F, 0xA6 ,0x80]);
/// let b1: Ref<'_, Vec<u8>> = c.borrow();
/// let b2: Result<Ref<'_, str>, _> = Ref::try_map(b1, |v| from_utf8(v));
/// assert_eq!(&*b2.unwrap(), "🦀");
///
/// let c = RefCell::new(vec![0xF0, 0x9F, 0xA6]);
/// let b1: Ref<'_, Vec<u8>> = c.borrow();
/// let b2: Result<_, (Ref<'_, Vec<u8>>, Utf8Error)> = Ref::try_map(b1, |v| from_utf8(v));
/// let (b3, e) = b2.unwrap_err();
/// assert_eq!(*b3, vec![0xF0, 0x9F, 0xA6]);
/// assert_eq!(e.valid_up_to(), 0);
/// ```
#[unstable(feature = "refcell_try_map", issue = "143801")]
#[inline]
pub fn try_map<U: ?Sized, E>(
orig: Ref<'b, T>,
f: impl FnOnce(&T) -> Result<&U, E>,
) -> Result<Ref<'b, U>, (Self, E)> {
match f(&*orig) {
Ok(value) => Ok(Ref { value: NonNull::from(value), borrow: orig.borrow }),
Err(e) => Err((orig, e)),
}
}
/// Splits a `Ref` into multiple `Ref`s for different components of the
/// borrowed data.
///
@ -1734,6 +1775,58 @@ impl<'b, T: ?Sized> RefMut<'b, T> {
}
}
/// Tries to makes a new `RefMut` for a component of the borrowed data.
/// On failure, the original guard is returned alongside with the error
/// returned by the closure.
///
/// The `RefCell` is already mutably borrowed, so this cannot fail.
///
/// This is an associated function that needs to be used as
/// `RefMut::try_map(...)`. A method would interfere with methods of the same
/// name on the contents of a `RefCell` used through `Deref`.
///
/// # Examples
///
/// ```
/// #![feature(refcell_try_map)]
/// use std::cell::{RefCell, RefMut};
/// use std::str::{from_utf8_mut, Utf8Error};
///
/// let c = RefCell::new(vec![0x68, 0x65, 0x6C, 0x6C, 0x6F]);
/// {
/// let b1: RefMut<'_, Vec<u8>> = c.borrow_mut();
/// let b2: Result<RefMut<'_, str>, _> = RefMut::try_map(b1, |v| from_utf8_mut(v));
/// let mut b2 = b2.unwrap();
/// assert_eq!(&*b2, "hello");
/// b2.make_ascii_uppercase();
/// }
/// assert_eq!(*c.borrow(), "HELLO".as_bytes());
///
/// let c = RefCell::new(vec![0xFF]);
/// let b1: RefMut<'_, Vec<u8>> = c.borrow_mut();
/// let b2: Result<_, (RefMut<'_, Vec<u8>>, Utf8Error)> = RefMut::try_map(b1, |v| from_utf8_mut(v));
/// let (b3, e) = b2.unwrap_err();
/// assert_eq!(*b3, vec![0xFF]);
/// assert_eq!(e.valid_up_to(), 0);
/// ```
#[unstable(feature = "refcell_try_map", issue = "143801")]
#[inline]
pub fn try_map<U: ?Sized, E>(
mut orig: RefMut<'b, T>,
f: impl FnOnce(&mut T) -> Result<&mut U, E>,
) -> Result<RefMut<'b, U>, (Self, E)> {
// SAFETY: function holds onto an exclusive reference for the duration
// of its call through `orig`, and the pointer is only de-referenced
// inside of the function call never allowing the exclusive reference to
// escape.
match f(&mut *orig) {
Ok(value) => {
Ok(RefMut { value: NonNull::from(value), borrow: orig.borrow, marker: PhantomData })
}
Err(e) => Err((orig, e)),
}
}
/// Splits a `RefMut` into multiple `RefMut`s for different components of the
/// borrowed data.
///

View file

@ -3157,6 +3157,44 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64;
#[rustc_intrinsic]
pub const unsafe fn copysignf128(x: f128, y: f128) -> f128;
/// Generates the LLVM body for the automatic differentiation of `f` using Enzyme,
/// with `df` as the derivative function and `args` as its arguments.
///
/// Used internally as the body of `df` when expanding the `#[autodiff_forward]`
/// and `#[autodiff_reverse]` attribute macros.
///
/// Type Parameters:
/// - `F`: The original function to differentiate. Must be a function item.
/// - `G`: The derivative function. Must be a function item.
/// - `T`: A tuple of arguments passed to `df`.
/// - `R`: The return type of the derivative function.
///
/// This shows where the `autodiff` intrinsic is used during macro expansion:
///
/// ```rust,ignore (macro example)
/// #[autodiff_forward(df1, Dual, Const, Dual)]
/// pub fn f1(x: &[f64], y: f64) -> f64 {
/// unimplemented!()
/// }
/// ```
///
/// expands to:
///
/// ```rust,ignore (macro example)
/// #[rustc_autodiff]
/// #[inline(never)]
/// pub fn f1(x: &[f64], y: f64) -> f64 {
/// ::core::panicking::panic("not implemented")
/// }
/// #[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
/// pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
/// ::core::intrinsics::autodiff(f1::<>, df1::<>, (x, bx_0, y))
/// }
/// ```
#[rustc_nounwind]
#[rustc_intrinsic]
pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) -> R;
/// Inform Miri that a given pointer definitely has a certain alignment.
#[cfg(miri)]
#[rustc_allow_const_fn_unstable(const_eval_select)]

View file

@ -208,6 +208,10 @@
#[allow(unused_extern_crates)]
extern crate self as core;
/* The core prelude, not as all-encompassing as the std prelude */
// The compiler expects the prelude definition to be defined before it's use statement.
pub mod prelude;
#[prelude_import]
#[allow(unused)]
use prelude::rust_2024::*;
@ -293,10 +297,6 @@ pub mod f64;
#[macro_use]
pub mod num;
/* The core prelude, not as all-encompassing as the std prelude */
pub mod prelude;
/* Core modules for ownership management */
pub mod hint;

View file

@ -1495,6 +1495,7 @@ pub(crate) mod builtin {
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
#[unstable(feature = "autodiff", issue = "124509")]
#[allow_internal_unstable(rustc_attrs)]
#[allow_internal_unstable(core_intrinsics)]
#[rustc_builtin_macro]
pub macro autodiff_forward($item:item) {
/* compiler built-in */
@ -1513,6 +1514,7 @@ pub(crate) mod builtin {
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
#[unstable(feature = "autodiff", issue = "124509")]
#[allow_internal_unstable(rustc_attrs)]
#[allow_internal_unstable(core_intrinsics)]
#[rustc_builtin_macro]
pub macro autodiff_reverse($item:item) {
/* compiler built-in */

View file

@ -1094,23 +1094,17 @@ macro_rules! uint_impl {
self / rhs
}
/// Checked integer division without remainder. Computes `self / rhs`.
///
/// # Panics
///
/// This function will panic if `rhs == 0` or `self % rhs != 0`.
/// Checked integer division without remainder. Computes `self / rhs`,
/// returning `None` if `rhs == 0` or if `self % rhs != 0`.
///
/// # Examples
///
/// ```
/// #![feature(exact_div)]
#[doc = concat!("assert_eq!(64", stringify!($SelfT), ".exact_div(2), 32);")]
#[doc = concat!("assert_eq!(64", stringify!($SelfT), ".exact_div(32), 2);")]
/// ```
///
/// ```should_panic
/// #![feature(exact_div)]
#[doc = concat!("let _ = 65", stringify!($SelfT), ".exact_div(2);")]
#[doc = concat!("assert_eq!(64", stringify!($SelfT), ".checked_exact_div(2), Some(32));")]
#[doc = concat!("assert_eq!(64", stringify!($SelfT), ".checked_exact_div(32), Some(2));")]
#[doc = concat!("assert_eq!(64", stringify!($SelfT), ".checked_exact_div(0), None);")]
#[doc = concat!("assert_eq!(65", stringify!($SelfT), ".checked_exact_div(2), None);")]
/// ```
#[unstable(
feature = "exact_div",

View file

@ -428,11 +428,15 @@
//
#![default_lib_allocator]
// The Rust prelude
// The compiler expects the prelude definition to be defined before it's use statement.
pub mod prelude;
// Explicitly import the prelude. The compiler uses this same unstable attribute
// to import the prelude implicitly when building crates that depend on std.
#[prelude_import]
#[allow(unused)]
use prelude::rust_2021::*;
use prelude::rust_2024::*;
// Access to Bencher, etc.
#[cfg(test)]
@ -483,9 +487,6 @@ mod macros;
#[macro_use]
pub mod rt;
// The Rust prelude
pub mod prelude;
#[stable(feature = "rust1", since = "1.0.0")]
pub use core::any;
#[stable(feature = "core_array", since = "1.35.0")]

View file

@ -107,7 +107,6 @@
- [Installation](./autodiff/installation.md)
- [How to debug](./autodiff/debugging.md)
- [Autodiff flags](./autodiff/flags.md)
- [Current limitations](./autodiff/limitations.md)
# Source Code Representation

View file

@ -1,27 +0,0 @@
# Current limitations
## Safety and Soundness
Enzyme currently assumes that the user passes shadow arguments (`dx`, `dy`, ...) of appropriate size. Under Reverse Mode, we additionally assume that shadow arguments are mutable. In Reverse Mode we adjust the outermost pointer or reference to be mutable. Therefore `&f32` will receive the shadow type `&mut f32`. However, we do not check length for other types than slices (e.g. enums, Vec). We also do not enforce mutability of inner references, but will warn if we recognize them. We do intend to add additional checks over time.
## ABI adjustments
In some cases, a function parameter might get lowered in a way that we currently don't handle correctly, leading to a compile time type mismatch in the `rustc_codegen_llvm` backend. Here are some [examples](https://github.com/EnzymeAD/rust/issues/105).
## Compile Times
Enzyme will often achieve excellent runtime performance, but might increase your compile time by a large factor. For Rust, we already have made significant improvements and have a list of further improvements planed - please reach out if you have time to help here.
### Type Analysis
Most of the times, Type Analysis (TA) is the reason of large (>5x) compile time increases when using Enzyme. This poster explains why we need to run Type Analysis in the bottom left part: [Poster Link](https://c.wsmoses.com/posters/Enzyme-llvmdev.pdf).
We intend to increase the number of locations where we pass down Type information based on Rust types, which in turn will reduce the number of locations where Enzyme has to run Type Analysis, which will help compile times.
### Duplicated Optimizations
The key reason for Enzyme offering often excellent performance is that Enzyme differentiates already optimized LLVM-IR. However, we also (have to) run LLVM's optimization pipeline after differentiating, to make sure that the code which Enzyme generates is optimized properly. As a result you should have excellent runtime performance (please fill an issue if not), but at a compile time cost for running optimizations twice.
### Fat-LTO
The usage of `#[autodiff(...)]` currently requires compiling your project with Fat-LTO. We technically only need LTO if the function being differentiated calls functions in other compilation units. Therefore, other solutions are possible, but this is the most simple one to get started.

View file

@ -602,7 +602,6 @@ fn add_item_to_search_index(tcx: TyCtxt<'_>, cache: &mut Cache, item: &clean::It
search_type,
aliases,
deprecation,
is_unstable: item.stability(tcx).map(|x| x.is_unstable()).unwrap_or(false),
};
cache.search_index.push(index_item);
}

View file

@ -139,7 +139,6 @@ pub(crate) struct IndexItem {
pub(crate) search_type: Option<IndexItemFunctionType>,
pub(crate) aliases: Box<[Symbol]>,
pub(crate) deprecation: Option<Deprecation>,
pub(crate) is_unstable: bool,
}
/// A type used for the search index.

View file

@ -93,7 +93,6 @@ pub(crate) fn build_index(
),
aliases: item.attrs.get_doc_aliases(),
deprecation: item.deprecation(tcx),
is_unstable: item.stability(tcx).is_some_and(|x| x.is_unstable()),
});
}
}
@ -656,7 +655,6 @@ pub(crate) fn build_index(
let mut parents_backref_queue = VecDeque::new();
let mut functions = String::with_capacity(self.items.len());
let mut deprecated = Vec::with_capacity(self.items.len());
let mut unstable = Vec::with_capacity(self.items.len());
let mut type_backref_queue = VecDeque::new();
@ -713,9 +711,6 @@ pub(crate) fn build_index(
// bitmasks always use 1-indexing for items, with 0 as the crate itself
deprecated.push(u32::try_from(index + 1).unwrap());
}
if item.is_unstable {
unstable.push(u32::try_from(index + 1).unwrap());
}
}
for (index, path) in &revert_extra_paths {
@ -754,7 +749,6 @@ pub(crate) fn build_index(
crate_data.serialize_field("r", &re_exports)?;
crate_data.serialize_field("b", &self.associated_item_disambiguators)?;
crate_data.serialize_field("c", &bitmap_to_string(&deprecated))?;
crate_data.serialize_field("u", &bitmap_to_string(&unstable))?;
crate_data.serialize_field("e", &bitmap_to_string(&self.empty_desc))?;
crate_data.serialize_field("P", &param_names)?;
if has_aliases {

View file

@ -129,7 +129,7 @@ declare namespace rustdoc {
/**
* A single parsed "atom" in a search query. For example,
*
*
* std::fmt::Formatter, Write -> Result<()>
*
* QueryElement {
@ -449,8 +449,6 @@ declare namespace rustdoc {
* of `p`) but is used for modules items like free functions.
*
* `c` is an array of item indices that are deprecated.
*
* `u` is an array of item indices that are unstable.
*/
type RawSearchIndexCrate = {
doc: string,
@ -465,7 +463,6 @@ declare namespace rustdoc {
p: Array<[number, string] | [number, string, number] | [number, string, number, number] | [number, string, number, number, string]>,
b: Array<[number, String]>,
c: string,
u: string,
r: Array<[number, number]>,
P: Array<[number, string]>,
};

View file

@ -1464,11 +1464,6 @@ class DocSearch {
* @type {Map<String, RoaringBitmap>}
*/
this.searchIndexEmptyDesc = new Map();
/**
* @type {Map<String, RoaringBitmap>}
*/
this.searchIndexUnstable = new Map();
/**
* @type {Uint32Array}
*/
@ -2057,12 +2052,9 @@ class DocSearch {
};
const descShardList = [descShard];
// Deprecated and unstable items and items with no description
// Deprecated items and items with no description
this.searchIndexDeprecated.set(crate, new RoaringBitmap(crateCorpus.c));
this.searchIndexEmptyDesc.set(crate, new RoaringBitmap(crateCorpus.e));
if (crateCorpus.u !== undefined && crateCorpus.u !== null) {
this.searchIndexUnstable.set(crate, new RoaringBitmap(crateCorpus.u));
}
let descIndex = 0;
/**
@ -3334,25 +3326,6 @@ class DocSearch {
return a - b;
}
// sort unstable items later
// FIXME: there is some doubt if this is the most effecient way to implement this.
// alternative options include:
// * put is_unstable on each item when the index is built.
// increases memory usage but avoids a hashmap lookup.
// * put is_unstable on each item before sorting.
// better worst case performance but worse average case performance.
a = Number(
// @ts-expect-error
this.searchIndexUnstable.get(aaa.item.crate).contains(aaa.item.bitIndex),
);
b = Number(
// @ts-expect-error
this.searchIndexUnstable.get(bbb.item.crate).contains(bbb.item.bitIndex),
);
if (a !== b) {
return a - b;
}
// sort by crate (current crate comes first)
a = Number(aaa.item.crate !== preferredCrate);
b = Number(bbb.item.crate !== preferredCrate);

View file

@ -11,7 +11,7 @@ use rustc_ast::{BinOpKind, LitKind, RangeLimits};
use rustc_data_structures::packed::Pu128;
use rustc_data_structures::unhash::UnindexMap;
use rustc_errors::{Applicability, Diag};
use rustc_hir::{Block, Body, Expr, ExprKind, UnOp};
use rustc_hir::{Body, Expr, ExprKind};
use rustc_lint::{LateContext, LateLintPass};
use rustc_session::declare_lint_pass;
use rustc_span::source_map::Spanned;
@ -135,12 +135,12 @@ fn assert_len_expr<'hir>(
cx: &LateContext<'_>,
expr: &'hir Expr<'hir>,
) -> Option<(LengthComparison, usize, &'hir Expr<'hir>)> {
let (cmp, asserted_len, slice_len) = if let Some(higher::If { cond, then, .. }) = higher::If::hir(expr)
&& let ExprKind::Unary(UnOp::Not, condition) = &cond.kind
&& let ExprKind::Binary(bin_op, left, right) = &condition.kind
let (cmp, asserted_len, slice_len) = if let Some(
higher::IfLetOrMatch::Match(cond, [_, then], _)
) = higher::IfLetOrMatch::parse(cx, expr)
&& let ExprKind::Binary(bin_op, left, right) = &cond.kind
// check if `then` block has a never type expression
&& let ExprKind::Block(Block { expr: Some(then_expr), .. }, _) = then.kind
&& cx.typeck_results().expr_ty(then_expr).is_never()
&& cx.typeck_results().expr_ty(then.body).is_never()
{
len_comparison(bin_op.node, left, right)?
} else if let Some((macro_call, bin_op)) = first_node_macro_backtrace(cx, expr).find_map(|macro_call| {

View file

@ -196,6 +196,7 @@ fn issue_13106() {
const {
assert!(EMPTY_STR.is_empty());
//~^ const_is_empty
}
const {

View file

@ -158,10 +158,16 @@ LL | let _ = val.is_empty();
| ^^^^^^^^^^^^^^
error: this expression always evaluates to true
--> tests/ui/const_is_empty.rs:202:9
--> tests/ui/const_is_empty.rs:198:17
|
LL | assert!(EMPTY_STR.is_empty());
| ^^^^^^^^^^^^^^^^^^^^
error: this expression always evaluates to true
--> tests/ui/const_is_empty.rs:203:9
|
LL | EMPTY_STR.is_empty();
| ^^^^^^^^^^^^^^^^^^^^
error: aborting due to 27 previous errors
error: aborting due to 28 previous errors

View file

@ -1,6 +1,6 @@
#![warn(clippy::incompatible_msrv)]
#![feature(custom_inner_attributes)]
#![allow(stable_features)]
#![allow(stable_features, clippy::diverging_sub_expression)]
#![feature(strict_provenance)] // For use in test
#![clippy::msrv = "1.3.0"]

View file

@ -17,7 +17,7 @@ object = "0.37"
regex = "1.11"
serde_json = "1.0"
similar = "2.7"
wasmparser = { version = "0.219", default-features = false, features = ["std"] }
wasmparser = { version = "0.236", default-features = false, features = ["std", "features", "validate"] }
# tidy-alphabetical-end
# Shared with bootstrap and compiletest

View file

@ -47,9 +47,9 @@ dependencies = [
[[package]]
name = "anstream"
version = "0.6.19"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933"
checksum = "3ae563653d1938f79b1ab1b5e668c87c76a9930414574a6583a7b7e11a8e6192"
dependencies = [
"anstyle",
"anstyle-parse",
@ -77,22 +77,22 @@ dependencies = [
[[package]]
name = "anstyle-query"
version = "1.1.3"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9"
checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.9"
version = "3.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882"
checksum = "3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
@ -156,9 +156,9 @@ checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43"
[[package]]
name = "cc"
version = "1.2.31"
version = "1.2.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3a42d84bb6b69d3a8b3eaacf0d88f179e1929695e1ad012b6cf64d9caaa5fd2"
checksum = "2352e5597e9c544d5e6d9c95190d5d27738ade584fa8db0a16e130e5c2b5296e"
dependencies = [
"shlex",
]
@ -185,9 +185,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.5.42"
version = "4.5.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed87a9d530bb41a67537289bafcac159cb3ee28460e0a4571123d2a778a6a882"
checksum = "50fd97c9dc2399518aa331917ac6f274280ec5eb34e555dd291899745c48ec6f"
dependencies = [
"clap_builder",
"clap_derive",
@ -195,9 +195,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.5.42"
version = "4.5.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64f4f3f3c77c94aff3c7e9aac9a2ca1974a5adf392a8bb751e827d6d127ab966"
checksum = "c35b5830294e1fa0462034af85cc95225a4cb07092c088c55bda3147cfcd8f65"
dependencies = [
"anstream",
"anstyle",
@ -208,9 +208,9 @@ dependencies = [
[[package]]
name = "clap_complete"
version = "4.5.55"
version = "4.5.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5abde44486daf70c5be8b8f8f1b66c49f86236edf6fa2abadb4d961c4c6229a"
checksum = "67e4efcbb5da11a92e8a609233aa1e8a7d91e38de0be865f016d14700d45a7fd"
dependencies = [
"clap",
]
@ -564,9 +564,9 @@ dependencies = [
[[package]]
name = "hashbrown"
version = "0.15.4"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
[[package]]
name = "heck"
@ -1406,9 +1406,9 @@ dependencies = [
[[package]]
name = "rustversion"
version = "1.0.21"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d"
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
[[package]]
name = "ryu"
@ -2183,9 +2183,9 @@ dependencies = [
[[package]]
name = "zerovec"
version = "0.11.2"
version = "0.11.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428"
checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b"
dependencies = [
"yoke",
"zerofrom",

View file

@ -55,10 +55,10 @@ extern "C" {
pub unsafe fn rust_to_c_increases_alignment(x: Align1) {
// i686-linux: start:
// i686-linux-NEXT: [[ALLOCA:%[0-9a-z]+]] = alloca [48 x i8], align 4
// i686-linux-NEXT: call void @llvm.lifetime.start.p0(i64 48, ptr {{.*}}[[ALLOCA]])
// i686-linux-NEXT: call void @llvm.lifetime.start.p0({{(i64 48, )?}}ptr {{.*}}[[ALLOCA]])
// i686-linux-NEXT: call void @llvm.memcpy.{{.+}}(ptr {{.*}}align 4 {{.*}}[[ALLOCA]], ptr {{.*}}align 1 {{.*}}%x
// i686-linux-NEXT: call void @extern_c_align1({{.+}} [[ALLOCA]])
// i686-linux-NEXT: call void @llvm.lifetime.end.p0(i64 48, ptr {{.*}}[[ALLOCA]])
// i686-linux-NEXT: call void @llvm.lifetime.end.p0({{(i64 48, )?}}ptr {{.*}}[[ALLOCA]])
// x86_64-linux: start:
// x86_64-linux-NEXT: call void @extern_c_align1

View file

@ -26,12 +26,13 @@
#![feature(autodiff)]
use std::autodiff::autodiff;
use std::autodiff::autodiff_forward;
// CHECK: ;
#[no_mangle]
//#[autodiff(d_square1, Forward, Dual, Dual)]
#[autodiff(d_square2, Forward, 4, Dualv, Dualv)]
#[autodiff(d_square3, Forward, 4, Dual, Dual)]
#[autodiff_forward(d_square2, 4, Dualv, Dualv)]
#[autodiff_forward(d_square3, 4, Dual, Dual)]
fn square(x: &[f32], y: &mut [f32]) {
assert!(x.len() >= 4);
assert!(y.len() >= 5);

View file

@ -17,11 +17,12 @@ use std::autodiff::autodiff_forward;
#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
#[autodiff_forward(d_square1, 4, Dual, Dual)]
#[no_mangle]
#[inline(never)]
fn square(x: &f32) -> f32 {
x * x
}
// d_sqaure2
// d_square2
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
// CHECK-NEXT: start:
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
@ -32,24 +33,20 @@ fn square(x: &f32) -> f32 {
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
// CHECK-NEXT: ret [4 x float] %19
// CHECK-NEXT: }
// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
// CHECK-NEXT: ret [4 x float] %15
// CHECK-NEXT: }
// d_square3, the extra float is the original return value (x * x)
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
@ -63,26 +60,22 @@ fn square(x: &f32) -> f32 {
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
// CHECK-NEXT: ret { float, [4 x float] } %21
// CHECK-NEXT: }
// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
// CHECK-NEXT: %16 = insertvalue { float, [4 x float] } undef, float %_0, 0
// CHECK-NEXT: %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1
// CHECK-NEXT: ret { float, [4 x float] } %17
// CHECK-NEXT: }
fn main() {
let x = std::hint::black_box(3.0);

View file

@ -6,19 +6,11 @@
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_square, Duplicated, Active)]
#[inline(never)]
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
*x * *x
}
// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
//
// CHECK: ; generic::square
// CHECK-NEXT: ; Function Attrs:
// CHECK-NEXT: define internal {{.*}} double
// CHECK-NEXT: start:
// CHECK-NOT: ret
// CHECK: fmul double
// Ensure that `d_square::<f32>` code is generated
//
// CHECK: ; generic::square
@ -28,6 +20,15 @@ fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
// CHECK-NOT: ret
// CHECK: fmul float
// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
//
// CHECK: ; generic::square
// CHECK-NEXT: ; Function Attrs:
// CHECK-NEXT: define internal {{.*}} double
// CHECK-NEXT: start:
// CHECK-NOT: ret
// CHECK: fmul double
fn main() {
let xf32: f32 = std::hint::black_box(3.0);
let xf64: f64 = std::hint::black_box(3.0);

View file

@ -14,25 +14,27 @@
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_square, Duplicated, Active)]
#[inline(never)]
fn square(x: &f64) -> f64 {
x * x
}
#[autodiff_reverse(d_square2, Duplicated, Active)]
#[inline(never)]
fn square2(x: &f64) -> f64 {
x * x
}
// CHECK:; identical_fnc::main
// CHECK-NEXT:; Function Attrs:
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17h6009e4f751bf9407E()
// CHECK-NEXT:start:
// CHECK-NOT:br
// CHECK-NOT:ret
// CHECK:; call identical_fnc::d_square
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
// CHECK-NEXT:; call identical_fnc::d_square
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx1)
// CHECK:; call identical_fnc::d_square
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx2)
fn main() {
let x = std::hint::black_box(3.0);

View file

@ -1,23 +0,0 @@
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -Zautodiff=NoPostopt
//@ no-prefer-dynamic
//@ needs-enzyme
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_square, Duplicated, Active)]
fn square(x: &f64) -> f64 {
x * x
}
// CHECK: ; inline::d_square
// CHECK-NEXT: ; Function Attrs: alwaysinline
// CHECK-NOT: noinline
// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE
fn main() {
let x = std::hint::black_box(3.0);
let mut dx1 = std::hint::black_box(1.0);
let _ = d_square(&x, &mut dx1, 1.0);
assert_eq!(dx1, 6.0);
}

View file

@ -7,11 +7,12 @@ use std::autodiff::autodiff_reverse;
#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
#[inline(never)]
fn square(x: &f64) -> f64 {
x * x
}
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nonnull align 8 captures(none) %"x'")
// CHECK-NEXT:invertstart:
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val

View file

@ -13,30 +13,30 @@ use std::autodiff::autodiff_reverse;
#[no_mangle]
#[autodiff_reverse(df, Active, Active, Active)]
#[inline(never)]
fn primal(x: f32, y: f32) -> f64 {
(x * x * y) as f64
}
// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y)
// CHECK-NEXT:start:
// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y)
// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0
// CHECK-NEXT: store double %.elt, ptr %_0, align 8
// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8
// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1
// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8
// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12
// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2
// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4
// CHECK-NEXT: ret void
// CHECK-NEXT:}
// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y)
// CHECK-NEXT: invertstart:
// CHECK-NEXT: %_4 = fmul float %x, %x
// CHECK-NEXT: %_3 = fmul float %_4, %y
// CHECK-NEXT: %_0 = fpext float %_3 to double
// CHECK-NEXT: %0 = fadd fast float %y, %y
// CHECK-NEXT: %1 = fmul fast float %0, %x
// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0
// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1
// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2
// CHECK-NEXT: ret { double, float, float } %4
// CHECK-NEXT: }
fn main() {
let x = std::hint::black_box(3.0);
let y = std::hint::black_box(2.5);
let scalar = std::hint::black_box(1.0);
let (r1, r2, r3) = df(x, y, scalar);
// 3*3*1.5 = 22.5
// 3*3*2.5 = 22.5
assert_eq!(r1, 22.5);
// 2*x*y = 2*3*2.5 = 15.0
assert_eq!(r2, 15.0);

View file

@ -0,0 +1,30 @@
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
// Just check it does not crash for now
// CHECK: ;
#![feature(autodiff)]
use std::autodiff::autodiff_reverse;
struct Foo {
a: f64,
}
trait MyTrait {
fn f(&self, x: f64) -> f64;
fn df(&self, x: f64, seed: f64) -> (f64, f64);
}
impl MyTrait for Foo {
#[autodiff_reverse(df, Const, Active, Active)]
fn f(&self, x: f64) -> f64 {
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
}
}
fn main() {
let foo = Foo { a: 3.0f64 };
dbg!(foo.df(1.0, 1.0));
}

View file

@ -16,14 +16,14 @@ use minicore::*;
// CHECK-NEXT: start:
// CHECK-NEXT: [[B:%.*]] = alloca
// CHECK-NEXT: [[A:%.*]] = alloca
// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4096, ptr [[A]])
// CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 4096, )?}}ptr [[A]])
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 4 {{.*}}, i32 4096, i1 false)
// CHECK-NEXT: call void %h(ptr {{.*}} [[A]])
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4096, ptr [[A]])
// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4096, ptr [[B]])
// CHECK-NEXT: call void @llvm.lifetime.end.p0({{(i64 4096, )?}}ptr [[A]])
// CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 4096, )?}}ptr [[B]])
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 4096, i1 false)
// CHECK-NEXT: call void %h(ptr {{.*}} [[B]])
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4096, ptr [[B]])
// CHECK-NEXT: call void @llvm.lifetime.end.p0({{(i64 4096, )?}}ptr [[B]])
#[no_mangle]
pub fn const_indirect(h: extern "C" fn([u32; 1024])) {
const C: [u32; 1024] = [0; 1024];
@ -42,12 +42,12 @@ pub struct Str {
// CHECK-LABEL: define void @immediate_indirect(ptr {{.*}}%s.0, i32 {{.*}}%s.1, ptr {{.*}}%g)
// CHECK-NEXT: start:
// CHECK-NEXT: [[A:%.*]] = alloca
// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 8, ptr [[A]])
// CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 8, )?}}ptr [[A]])
// CHECK-NEXT: store ptr %s.0, ptr [[A]]
// CHECK-NEXT: [[B:%.]] = getelementptr inbounds i8, ptr [[A]], i32 4
// CHECK-NEXT: store i32 %s.1, ptr [[B]]
// CHECK-NEXT: call void %g(ptr {{.*}} [[A]])
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr [[A]])
// CHECK-NEXT: call void @llvm.lifetime.end.p0({{(i64 8, )?}}ptr [[A]])
#[no_mangle]
pub fn immediate_indirect(s: Str, g: extern "C" fn(Str)) {
g(s);
@ -58,10 +58,10 @@ pub fn immediate_indirect(s: Str, g: extern "C" fn(Str)) {
// CHECK-LABEL: define void @align_indirect(ptr{{.*}} align 1{{.*}} %a, ptr{{.*}} %fun)
// CHECK-NEXT: start:
// CHECK-NEXT: [[A:%.*]] = alloca [1024 x i8], align 4
// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 1024, ptr [[A]])
// CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 1024, )?}}ptr [[A]])
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 1 %a, i32 1024, i1 false)
// CHECK-NEXT: call void %fun(ptr {{.*}} [[A]])
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 1024, ptr [[A]])
// CHECK-NEXT: call void @llvm.lifetime.end.p0({{(i64 1024, )?}}ptr [[A]])
#[no_mangle]
pub fn align_indirect(a: [u8; 1024], fun: extern "C" fn([u8; 1024])) {
fun(a);

View file

@ -97,10 +97,10 @@ pub extern "C" fn float_ptr_same_lanes(v: f64x2) -> PtrX2 {
// CHECK-NOT: alloca
// CHECK: %[[TEMP:.+]] = alloca [16 x i8]
// CHECK-NOT: alloca
// CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: store <2 x double> %v, ptr %[[TEMP]]
// CHECK: %[[RET:.+]] = load <2 x ptr>, ptr %[[TEMP]]
// CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: ret <2 x ptr> %[[RET]]
unsafe { transmute(v) }
}
@ -111,10 +111,10 @@ pub extern "C" fn ptr_float_same_lanes(v: PtrX2) -> f64x2 {
// CHECK-NOT: alloca
// CHECK: %[[TEMP:.+]] = alloca [16 x i8]
// CHECK-NOT: alloca
// CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: store <2 x ptr> %v, ptr %[[TEMP]]
// CHECK: %[[RET:.+]] = load <2 x double>, ptr %[[TEMP]]
// CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: ret <2 x double> %[[RET]]
unsafe { transmute(v) }
}
@ -125,10 +125,10 @@ pub extern "C" fn int_ptr_same_lanes(v: i64x2) -> PtrX2 {
// CHECK-NOT: alloca
// CHECK: %[[TEMP:.+]] = alloca [16 x i8]
// CHECK-NOT: alloca
// CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: store <2 x i64> %v, ptr %[[TEMP]]
// CHECK: %[[RET:.+]] = load <2 x ptr>, ptr %[[TEMP]]
// CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: ret <2 x ptr> %[[RET]]
unsafe { transmute(v) }
}
@ -139,10 +139,10 @@ pub extern "C" fn ptr_int_same_lanes(v: PtrX2) -> i64x2 {
// CHECK-NOT: alloca
// CHECK: %[[TEMP:.+]] = alloca [16 x i8]
// CHECK-NOT: alloca
// CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: store <2 x ptr> %v, ptr %[[TEMP]]
// CHECK: %[[RET:.+]] = load <2 x i64>, ptr %[[TEMP]]
// CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: ret <2 x i64> %[[RET]]
unsafe { transmute(v) }
}
@ -153,10 +153,10 @@ pub extern "C" fn float_ptr_widen(v: f32x4) -> PtrX2 {
// CHECK-NOT: alloca
// CHECK: %[[TEMP:.+]] = alloca [16 x i8]
// CHECK-NOT: alloca
// CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: store <4 x float> %v, ptr %[[TEMP]]
// CHECK: %[[RET:.+]] = load <2 x ptr>, ptr %[[TEMP]]
// CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: ret <2 x ptr> %[[RET]]
unsafe { transmute(v) }
}
@ -167,10 +167,10 @@ pub extern "C" fn int_ptr_widen(v: i32x4) -> PtrX2 {
// CHECK-NOT: alloca
// CHECK: %[[TEMP:.+]] = alloca [16 x i8]
// CHECK-NOT: alloca
// CHECK: call void @llvm.lifetime.start.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.start.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: store <4 x i32> %v, ptr %[[TEMP]]
// CHECK: %[[RET:.+]] = load <2 x ptr>, ptr %[[TEMP]]
// CHECK: call void @llvm.lifetime.end.p0(i64 16, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.end.p0({{(i64 16, )?}}ptr %[[TEMP]])
// CHECK: ret <2 x ptr> %[[RET]]
unsafe { transmute(v) }
}

View file

@ -192,12 +192,12 @@ pub unsafe fn check_byte_from_bool(x: bool) -> u8 {
#[no_mangle]
pub unsafe fn check_to_pair(x: u64) -> Option<i32> {
// CHECK: %[[TEMP:.+]] = alloca [8 x i8], align 8
// CHECK: call void @llvm.lifetime.start.p0(i64 8, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.start.p0({{(i64 8, )?}}ptr %[[TEMP]])
// CHECK: store i64 %x, ptr %[[TEMP]], align 8
// CHECK: %[[PAIR0:.+]] = load i32, ptr %[[TEMP]], align 8
// CHECK: %[[PAIR1P:.+]] = getelementptr inbounds i8, ptr %[[TEMP]], i64 4
// CHECK: %[[PAIR1:.+]] = load i32, ptr %[[PAIR1P]], align 4
// CHECK: call void @llvm.lifetime.end.p0(i64 8, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.end.p0({{(i64 8, )?}}ptr %[[TEMP]])
// CHECK: insertvalue {{.+}}, i32 %[[PAIR0]], 0
// CHECK: insertvalue {{.+}}, i32 %[[PAIR1]], 1
transmute(x)
@ -207,12 +207,12 @@ pub unsafe fn check_to_pair(x: u64) -> Option<i32> {
#[no_mangle]
pub unsafe fn check_from_pair(x: Option<i32>) -> u64 {
// CHECK: %[[TEMP:.+]] = alloca [8 x i8], align 8
// CHECK: call void @llvm.lifetime.start.p0(i64 8, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.start.p0({{(i64 8, )?}}ptr %[[TEMP]])
// CHECK: store i32 %x.0, ptr %[[TEMP]], align 8
// CHECK: %[[PAIR1P:.+]] = getelementptr inbounds i8, ptr %[[TEMP]], i64 4
// CHECK: store i32 %x.1, ptr %[[PAIR1P]], align 4
// CHECK: %[[R:.+]] = load i64, ptr %[[TEMP]], align 8
// CHECK: call void @llvm.lifetime.end.p0(i64 8, ptr %[[TEMP]])
// CHECK: call void @llvm.lifetime.end.p0({{(i64 8, )?}}ptr %[[TEMP]])
// CHECK: ret i64 %[[R]]
transmute(x)
}

View file

@ -19,6 +19,6 @@ pub fn outer_function(x: S, y: S) -> usize {
// CHECK: [[spill:%.*]] = alloca
// CHECK-NOT: [[ptr_tmp:%.*]] = getelementptr inbounds i8, ptr [[spill]]
// CHECK-NOT: [[load:%.*]] = load ptr, ptr
// CHECK: call void @llvm.lifetime.start{{.*}}({{.*}}, ptr [[spill]])
// CHECK: call void @llvm.lifetime.start{{.*}}({{(.*, )?}}ptr [[spill]])
// CHECK: [[inner:%.*]] = getelementptr inbounds i8, ptr [[spill]]
// CHECK: call void @llvm.memcpy{{.*}}(ptr {{align .*}} [[inner]], ptr {{align .*}} %x

View file

@ -8,27 +8,27 @@ pub fn test() {
let a = 0u8;
&a; // keep variable in an alloca
// CHECK: call void @llvm.lifetime.start{{.*}}(i{{[0-9 ]+}}, ptr %a)
// CHECK: call void @llvm.lifetime.start{{.*}}({{(i[0-9 ]+, )?}}ptr %a)
{
let b = &Some(a);
&b; // keep variable in an alloca
// CHECK: call void @llvm.lifetime.start{{.*}}(i{{[0-9 ]+}}, {{.*}})
// CHECK: call void @llvm.lifetime.start{{.*}}({{(i[0-9 ]+, )?}}{{.*}})
// CHECK: call void @llvm.lifetime.start{{.*}}(i{{[0-9 ]+}}, {{.*}})
// CHECK: call void @llvm.lifetime.start{{.*}}({{(i[0-9 ]+, )?}}{{.*}})
// CHECK: call void @llvm.lifetime.end{{.*}}(i{{[0-9 ]+}}, {{.*}})
// CHECK: call void @llvm.lifetime.end{{.*}}({{(i[0-9 ]+, )?}}{{.*}})
// CHECK: call void @llvm.lifetime.end{{.*}}(i{{[0-9 ]+}}, {{.*}})
// CHECK: call void @llvm.lifetime.end{{.*}}({{(i[0-9 ]+, )?}}{{.*}})
}
let c = 1u8;
&c; // keep variable in an alloca
// CHECK: call void @llvm.lifetime.start{{.*}}(i{{[0-9 ]+}}, ptr %c)
// CHECK: call void @llvm.lifetime.start{{.*}}({{(i[0-9 ]+, )?}}ptr %c)
// CHECK: call void @llvm.lifetime.end{{.*}}(i{{[0-9 ]+}}, ptr %c)
// CHECK: call void @llvm.lifetime.end{{.*}}({{(i[0-9 ]+, )?}}ptr %c)
// CHECK: call void @llvm.lifetime.end{{.*}}(i{{[0-9 ]+}}, ptr %a)
// CHECK: call void @llvm.lifetime.end{{.*}}({{(i[0-9 ]+, )?}}ptr %a)
}

View file

@ -23,7 +23,7 @@ extern "Rust" {
#[no_mangle]
pub fn test_uninhabited_ret_by_ref() {
// CHECK: %_1 = alloca [24 x i8], align {{8|4}}
// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %_1)
// CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 24, )?}}ptr nonnull %_1)
// CHECK-NEXT: call void @opaque({{.*}} sret([24 x i8]) {{.*}} %_1) #2
// CHECK-NEXT: unreachable
unsafe {
@ -35,7 +35,7 @@ pub fn test_uninhabited_ret_by_ref() {
#[no_mangle]
pub fn test_uninhabited_ret_by_ref_with_arg(rsi: u32) {
// CHECK: %_2 = alloca [24 x i8], align {{8|4}}
// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %_2)
// CHECK-NEXT: call void @llvm.lifetime.start.p0({{(i64 24, )?}}ptr nonnull %_2)
// CHECK-NEXT: call void @opaque_with_arg({{.*}} sret([24 x i8]) {{.*}} %_2, i32 noundef %rsi) #2
// CHECK-NEXT: unreachable
unsafe {

View file

@ -3,10 +3,10 @@
//@ needs-enzyme
#![feature(autodiff)]
#[prelude_import]
use ::std::prelude::rust_2015::*;
#[macro_use]
extern crate std;
#[prelude_import]
use ::std::prelude::rust_2015::*;
//@ pretty-mode:expanded
//@ pretty-compare-only
//@ pp-exact:autodiff_forward.pp
@ -16,7 +16,6 @@ extern crate std;
use std::autodiff::{autodiff_forward, autodiff_reverse};
#[rustc_autodiff]
#[inline(never)]
pub fn f1(x: &[f64], y: f64) -> f64 {
@ -36,163 +35,96 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
#[inline(never)]
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f1(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<(f64, f64)>::default())
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, bx_0, y))
}
#[rustc_autodiff]
#[inline(never)]
pub fn f2(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
#[inline(never)]
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f2(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(f2(x, y))
::core::intrinsics::autodiff(f2::<>, df2::<>, (x, bx_0, y))
}
#[rustc_autodiff]
#[inline(never)]
pub fn f3(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
#[inline(never)]
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f3(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(f3(x, y))
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, bx_0, y))
}
#[rustc_autodiff]
#[inline(never)]
pub fn f4() {}
#[rustc_autodiff(Forward, 1, None)]
#[inline(never)]
pub fn df4() -> () {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f4());
::core::hint::black_box(());
}
pub fn df4() -> () { ::core::intrinsics::autodiff(f4::<>, df4::<>, ()) }
#[rustc_autodiff]
#[inline(never)]
pub fn f5(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
#[inline(never)]
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((by_0,));
::core::hint::black_box(f5(x, y))
::core::intrinsics::autodiff(f5::<>, df5_y::<>, (x, y, by_0))
}
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
#[inline(never)]
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(f5(x, y))
::core::intrinsics::autodiff(f5::<>, df5_x::<>, (x, bx_0, y))
}
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
#[inline(never)]
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f5(x, y))
::core::intrinsics::autodiff(f5::<>, df5_rev::<>, (x, dx_0, y, dret))
}
struct DoesNotImplDefault;
#[rustc_autodiff]
#[inline(never)]
pub fn f6() -> DoesNotImplDefault {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Const)]
#[inline(never)]
pub fn df6() -> DoesNotImplDefault {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f6());
::core::hint::black_box(());
::core::hint::black_box(f6())
::core::intrinsics::autodiff(f6::<>, df6::<>, ())
}
#[rustc_autodiff]
#[inline(never)]
pub fn f7(x: f32) -> () {}
#[rustc_autodiff(Forward, 1, Const, None)]
#[inline(never)]
pub fn df7(x: f32) -> () {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f7(x));
::core::hint::black_box(());
::core::intrinsics::autodiff(f7::<>, df7::<>, (x,))
}
#[no_mangle]
#[rustc_autodiff]
#[inline(never)]
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
#[rustc_autodiff(Forward, 4, Dual, Dual)]
#[inline(never)]
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
-> [f32; 5usize] {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f8(x));
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
::core::hint::black_box(<[f32; 5usize]>::default())
::core::intrinsics::autodiff(f8::<>, f8_3::<>,
(x, bx_0, bx_1, bx_2, bx_3))
}
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
#[inline(never)]
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
-> [f32; 4usize] {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f8(x));
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
::core::hint::black_box(<[f32; 4usize]>::default())
::core::intrinsics::autodiff(f8::<>, f8_2::<>,
(x, bx_0, bx_1, bx_2, bx_3))
}
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
#[inline(never)]
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f8(x));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<f32>::default())
::core::intrinsics::autodiff(f8::<>, f8_1::<>, (x, bx_0))
}
pub fn f9() {
#[rustc_autodiff]
#[inline(never)]
fn inner(x: f32) -> f32 { x * x }
#[rustc_autodiff(Forward, 1, Dual, Dual)]
#[inline(never)]
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(inner(x));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<(f32, f32)>::default())
::core::intrinsics::autodiff(inner::<>, d_inner_2::<>, (x, bx_0))
}
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
#[inline(never)]
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(inner(x));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<f32>::default())
::core::intrinsics::autodiff(inner::<>, d_inner_1::<>, (x, bx_0))
}
}
#[rustc_autodiff]
#[inline(never)]
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
#[inline(never)]
pub fn d_square<T: std::ops::Mul<Output = T> +
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f10::<T>(x));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f10::<T>(x))
::core::intrinsics::autodiff(f10::<T>, d_square::<T>, (x, dx_0, dret))
}
fn main() {}

View file

@ -3,10 +3,10 @@
//@ needs-enzyme
#![feature(autodiff)]
#[prelude_import]
use ::std::prelude::rust_2015::*;
#[macro_use]
extern crate std;
#[prelude_import]
use ::std::prelude::rust_2015::*;
//@ pretty-mode:expanded
//@ pretty-compare-only
//@ pp-exact:autodiff_reverse.pp
@ -16,7 +16,6 @@ extern crate std;
use std::autodiff::autodiff_reverse;
#[rustc_autodiff]
#[inline(never)]
pub fn f1(x: &[f64], y: f64) -> f64 {
// Not the most interesting derivative, but who are we to judge
@ -29,58 +28,33 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
#[inline(never)]
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f1(x, y));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f1(x, y))
::core::intrinsics::autodiff(f1::<>, df1::<>, (x, dx_0, y, dret))
}
#[rustc_autodiff]
#[inline(never)]
pub fn f2() {}
#[rustc_autodiff(Reverse, 1, None)]
#[inline(never)]
pub fn df2() {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f2());
::core::hint::black_box(());
}
pub fn df2() { ::core::intrinsics::autodiff(f2::<>, df2::<>, ()) }
#[rustc_autodiff]
#[inline(never)]
pub fn f3(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
#[inline(never)]
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f3(x, y));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f3(x, y))
::core::intrinsics::autodiff(f3::<>, df3::<>, (x, dx_0, y, dret))
}
enum Foo { Reverse, }
use Foo::Reverse;
#[rustc_autodiff]
#[inline(never)]
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
#[rustc_autodiff(Reverse, 1, Const, None)]
#[inline(never)]
pub fn df4(x: f32) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f4(x));
::core::hint::black_box(());
}
pub fn df4(x: f32) { ::core::intrinsics::autodiff(f4::<>, df4::<>, (x,)) }
#[rustc_autodiff]
#[inline(never)]
pub fn f5(x: *const f32, y: &f32) {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
#[inline(never)]
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((dx_0, dy_0));
::core::intrinsics::autodiff(f5::<>, df5::<>, (x, dx_0, y, dy_0))
}
fn main() {}

View file

@ -23,7 +23,9 @@ pub fn f3(x: &[f64], y: f64) -> f64 {
unimplemented!()
}
enum Foo { Reverse }
enum Foo {
Reverse,
}
use Foo::Reverse;
// What happens if we already have Reverse in type (enum variant decl) and value (enum variant
// constructor) namespace? > It's expected to work normally.

View file

@ -3,10 +3,10 @@
//@ needs-enzyme
#![feature(autodiff)]
#[prelude_import]
use ::std::prelude::rust_2015::*;
#[macro_use]
extern crate std;
#[prelude_import]
use ::std::prelude::rust_2015::*;
//@ pretty-mode:expanded
//@ pretty-compare-only
//@ pp-exact:inherent_impl.pp
@ -26,16 +26,12 @@ trait MyTrait {
impl MyTrait for Foo {
#[rustc_autodiff]
#[inline(never)]
fn f(&self, x: f64) -> f64 {
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
}
#[rustc_autodiff(Reverse, 1, Const, Active, Active)]
#[inline(never)]
fn df(&self, x: f64, dret: f64) -> (f64, f64) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(self.f(x));
::core::hint::black_box((dret,));
::core::hint::black_box((self.f(x), f64::default()))
::core::intrinsics::autodiff(Self::f::<>, Self::df::<>,
(self, x, dret))
}
}

View file

@ -0,0 +1,7 @@
// This file has very long lines, but there is no way to avoid it as we are testing
// long crate names. so:
// ignore-tidy-linelength
extern crate generated_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_crate_name;
fn main() {}

View file

@ -0,0 +1,32 @@
// This file has very long lines, but there is no way to avoid it as we are testing
// long crate names. so:
// ignore-tidy-linelength
// A variant of the smoke test to check that link time optimization
// (LTO) is accepted by the compiler, and that
// passing its various flags still results in successful compilation, even for very long crate names.
// See https://github.com/rust-lang/rust/issues/49914
//@ ignore-cross-compile
use std::fs;
use run_make_support::{rfs, rustc};
// This test make sure we don't get such following error:
// error: could not write output to generated_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_crate_name.generated_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_crate_name.9384edb61bfd127c-cgu.0.rcgu.o: File name too long
// as reported in issue #49914
fn main() {
let lto_flags = ["-Clto", "-Clto=yes", "-Clto=off", "-Clto=thin", "-Clto=fat"];
let aux_file = "generated_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_large_crate_name.rs";
// The auxiliary file is used to test long crate names.
// The file name is intentionally long to test the handling of long filenames.
// We don't commit it to avoid issues with Windows paths which have known limitations for the full path length.
// Posix usually only have a limit for the length of the file name.
rfs::write(aux_file, "#![crate_type = \"rlib\"]\n");
for flag in lto_flags {
rustc().input(aux_file).arg(flag).run();
rustc().input("main.rs").arg(flag).run();
}
}

View file

@ -0,0 +1,43 @@
#![no_core]
#![crate_type = "cdylib"]
#![feature(no_core, lang_items, allocator_internals, rustc_attrs)]
#![needs_allocator]
#![allow(internal_features)]
#[rustc_std_internal_symbol]
unsafe fn __rust_alloc(_size: usize, _align: usize) -> *mut u8 {
0 as *mut u8
}
unsafe extern "Rust" {
#[rustc_std_internal_symbol]
fn __rust_alloc_error_handler(size: usize, align: usize) -> !;
}
#[used]
static mut BUF: [u8; 1024] = [0; 1024];
#[unsafe(no_mangle)]
extern "C" fn init() {
unsafe {
__rust_alloc_error_handler(0, 0);
}
}
mod minicore {
#[lang = "pointee_sized"]
pub trait PointeeSized {}
#[lang = "meta_sized"]
pub trait MetaSized: PointeeSized {}
#[lang = "sized"]
pub trait Sized: MetaSized {}
#[lang = "copy"]
pub trait Copy {}
impl Copy for u8 {}
#[lang = "drop_in_place"]
fn drop_in_place<T>(_: *mut T) {}
}

Some files were not shown because too many files have changed in this diff Show more