Merge pull request #4528 from rust-lang/rustup-2025-08-18

Automatic Rustup
This commit is contained in:
Ralf Jung 2025-08-18 07:19:57 +00:00 committed by GitHub
commit 909d3297af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
281 changed files with 4769 additions and 3037 deletions

View file

@ -223,6 +223,11 @@ jobs:
cd src/ci/citool
CARGO_INCREMENTAL=0 CARGO_TARGET_DIR=../../../build/citool cargo build
- name: wait for Windows disk cleanup to finish
if: ${{ matrix.free_disk && startsWith(matrix.os, 'windows-') }}
run: |
python3 src/ci/scripts/free-disk-space-windows-wait.py
- name: run the build
run: |
set +e

View file

@ -80,9 +80,9 @@ dependencies = [
[[package]]
name = "anstream"
version = "0.6.19"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933"
checksum = "3ae563653d1938f79b1ab1b5e668c87c76a9930414574a6583a7b7e11a8e6192"
dependencies = [
"anstyle",
"anstyle-parse",
@ -119,18 +119,18 @@ dependencies = [
[[package]]
name = "anstyle-query"
version = "1.1.3"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9"
checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
name = "anstyle-svg"
version = "0.1.9"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a43964079ef399480603125d5afae2b219aceffb77478956e25f17b9bc3435c"
checksum = "dc03a770ef506fe1396c0e476120ac0e6523cf14b74218dd5f18cd6833326fa9"
dependencies = [
"anstyle",
"anstyle-lossy",
@ -141,13 +141,13 @@ dependencies = [
[[package]]
name = "anstyle-wincon"
version = "3.0.9"
version = "3.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882"
checksum = "3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
@ -353,9 +353,9 @@ checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e"
[[package]]
name = "camino"
version = "1.1.10"
version = "1.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0da45bc31171d8d6960122e222a67740df867c1dd53b4d51caa297084c185cab"
checksum = "5d07aa9a93b00c76f71bc35d598bed923f6d4f3a9ca5c24b7737ae1a292841c0"
dependencies = [
"serde",
]
@ -518,9 +518,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.5.42"
version = "4.5.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed87a9d530bb41a67537289bafcac159cb3ee28460e0a4571123d2a778a6a882"
checksum = "50fd97c9dc2399518aa331917ac6f274280ec5eb34e555dd291899745c48ec6f"
dependencies = [
"clap_builder",
"clap_derive",
@ -538,9 +538,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.5.42"
version = "4.5.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64f4f3f3c77c94aff3c7e9aac9a2ca1974a5adf392a8bb751e827d6d127ab966"
checksum = "c35b5830294e1fa0462034af85cc95225a4cb07092c088c55bda3147cfcd8f65"
dependencies = [
"anstream",
"anstyle",
@ -937,9 +937,9 @@ dependencies = [
[[package]]
name = "cxx"
version = "1.0.161"
version = "1.0.166"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3523cc02ad831111491dd64b27ad999f1ae189986728e477604e61b81f828df"
checksum = "b5287274dfdf7e7eaa3d97d460eb2a94922539e6af214bda423f292105011ee2"
dependencies = [
"cc",
"cxxbridge-cmd",
@ -951,9 +951,9 @@ dependencies = [
[[package]]
name = "cxx-build"
version = "1.0.161"
version = "1.0.166"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "212b754247a6f07b10fa626628c157593f0abf640a3dd04cce2760eca970f909"
checksum = "65f3ce027a744135db10a1ebffa0863dab685aeef48f40a02c201f5e70c667d3"
dependencies = [
"cc",
"codespan-reporting",
@ -966,9 +966,9 @@ dependencies = [
[[package]]
name = "cxxbridge-cmd"
version = "1.0.161"
version = "1.0.166"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f426a20413ec2e742520ba6837c9324b55ffac24ead47491a6e29f933c5b135a"
checksum = "a07dc23f2eea4774297f4c9a17ae4065fecb63127da556e6c9fadb0216d93595"
dependencies = [
"clap",
"codespan-reporting",
@ -980,15 +980,15 @@ dependencies = [
[[package]]
name = "cxxbridge-flags"
version = "1.0.161"
version = "1.0.166"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a258b6069020b4e5da6415df94a50ee4f586a6c38b037a180e940a43d06a070d"
checksum = "f7a4dbad6171f763c4066c83dcd27546b6e93c5c5ae2229f9813bda7233f571d"
[[package]]
name = "cxxbridge-macro"
version = "1.0.161"
version = "1.0.166"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8dec184b52be5008d6eaf7e62fc1802caf1ad1227d11b3b7df2c409c7ffc3f4"
checksum = "a9be4b527950fc42db06163705e78e73eedc8fd723708e942afe3572a9a2c366"
dependencies = [
"indexmap",
"proc-macro2",
@ -1055,9 +1055,9 @@ version = "0.1.91"
[[package]]
name = "derive-where"
version = "1.5.0"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "510c292c8cf384b1a340b816a9a6cf2599eb8f566a44949024af88418000c50b"
checksum = "ef941ded77d15ca19b40374869ac6000af1c9f2a4c0f3d4c70926287e6364a8f"
dependencies = [
"proc-macro2",
"quote",
@ -1567,9 +1567,9 @@ dependencies = [
[[package]]
name = "hashbrown"
version = "0.15.4"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"allocator-api2",
"equivalent",
@ -1688,7 +1688,7 @@ dependencies = [
"potential_utf",
"yoke 0.8.0",
"zerofrom",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -1721,7 +1721,7 @@ dependencies = [
"litemap 0.8.0",
"tinystr 0.8.1",
"writeable 0.6.1",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -1769,7 +1769,7 @@ dependencies = [
"icu_properties",
"icu_provider 2.0.0",
"smallvec",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -1791,7 +1791,7 @@ dependencies = [
"icu_provider 2.0.0",
"potential_utf",
"zerotrie",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -1831,7 +1831,7 @@ dependencies = [
"yoke 0.8.0",
"zerofrom",
"zerotrie",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -1909,9 +1909,9 @@ dependencies = [
[[package]]
name = "indenter"
version = "0.3.3"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683"
checksum = "964de6e86d545b246d84badc0fef527924ace5134f30641c203ef52ba83f58d5"
[[package]]
name = "indexmap"
@ -2971,7 +2971,7 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585"
dependencies = [
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -3283,7 +3283,7 @@ dependencies = [
"regex",
"serde_json",
"similar",
"wasmparser 0.219.2",
"wasmparser 0.236.0",
]
[[package]]
@ -3623,6 +3623,7 @@ dependencies = [
"rustc_hir",
"rustc_incremental",
"rustc_index",
"rustc_lint_defs",
"rustc_macros",
"rustc_metadata",
"rustc_middle",
@ -4309,7 +4310,6 @@ name = "rustc_monomorphize"
version = "0.0.0"
dependencies = [
"rustc_abi",
"rustc_ast",
"rustc_data_structures",
"rustc_errors",
"rustc_fluent_macro",
@ -4318,7 +4318,6 @@ dependencies = [
"rustc_middle",
"rustc_session",
"rustc_span",
"rustc_symbol_mangling",
"rustc_target",
"serde",
"serde_json",
@ -4920,9 +4919,9 @@ dependencies = [
[[package]]
name = "rustversion"
version = "1.0.21"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d"
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
[[package]]
name = "ruzstd"
@ -4980,9 +4979,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "scratch"
version = "1.0.8"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f6280af86e5f559536da57a45ebc84948833b3bee313a7dd25232e09c878a52"
checksum = "d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2"
[[package]]
name = "self_cell"
@ -5508,7 +5507,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b"
dependencies = [
"displaydoc",
"zerovec 0.11.2",
"zerovec 0.11.4",
]
[[package]]
@ -6854,9 +6853,9 @@ dependencies = [
[[package]]
name = "zerovec"
version = "0.11.2"
version = "0.11.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428"
checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b"
dependencies = [
"yoke 0.8.0",
"zerofrom",

View file

@ -9,7 +9,7 @@
# a custom configuration file can also be specified with `--config` to the build
# system.
#
# Note that the following are equivelent, for more details see <https://toml.io/en/v1.0.0>.
# Note that the following are equivalent, for more details see <https://toml.io/en/v1.0.0>.
#
# build.verbose = 1
#
@ -345,9 +345,9 @@
# want to use vendoring. See https://forge.rust-lang.org/infra/other-installation-methods.html#source-code.
#build.vendor = if "is a tarball source" && "vendor" dir exists && ".cargo/config.toml" file exists { true } else { false }
# Typically the build system will build the Rust compiler twice. The second
# compiler, however, will simply use its own libraries to link against. If you
# would rather to perform a full bootstrap, compiling the compiler three times,
# If you build the compiler more than twice (stage3+) or the standard library more than once
# (stage 2+), the third compiler and second library will get uplifted from stage2 and stage1,
# respectively. If you would like to disable this uplifting, and rather perform a full bootstrap,
# then you can set this option to true.
#
# This is only useful for verifying that rustc generates reproducible builds.
@ -482,7 +482,7 @@
# Use `--extra-checks=''` to temporarily disable all extra checks.
#
# Automatically enabled in the "tools" profile.
# Set to the empty string to force disable (recommeded for hdd systems).
# Set to the empty string to force disable (recommended for hdd systems).
#build.tidy-extra-checks = ""
# Indicates whether ccache is used when building certain artifacts (e.g. LLVM).

View file

@ -17,7 +17,7 @@ ast_passes_abi_must_not_have_parameters_or_return_type=
ast_passes_abi_must_not_have_return_type=
invalid signature for `extern {$abi}` function
.note = functions with the "custom" ABI cannot have a return type
.note = functions with the {$abi} ABI cannot have a return type
.help = remove the return type
ast_passes_assoc_const_without_body =
@ -32,6 +32,13 @@ ast_passes_assoc_type_without_body =
associated type in `impl` without body
.suggestion = provide a definition for the type
ast_passes_async_fn_in_const_trait_or_trait_impl =
async functions are not allowed in `const` {$in_impl ->
[true] trait impls
*[false] traits
}
.label = associated functions of `const` cannot be declared `async`
ast_passes_at_least_one_trait = at least one trait must be specified
ast_passes_auto_generic = auto traits cannot have generic parameters

View file

@ -293,6 +293,21 @@ impl<'a> AstValidator<'a> {
});
}
fn check_async_fn_in_const_trait_or_impl(&self, sig: &FnSig, parent: &TraitOrTraitImpl) {
let Some(const_keyword) = parent.constness() else { return };
let Some(CoroutineKind::Async { span: async_keyword, .. }) = sig.header.coroutine_kind
else {
return;
};
self.dcx().emit_err(errors::AsyncFnInConstTraitOrTraitImpl {
async_keyword,
in_impl: matches!(parent, TraitOrTraitImpl::TraitImpl { .. }),
const_keyword,
});
}
fn check_fn_decl(&self, fn_decl: &FnDecl, self_semantic: SelfSemantic) {
self.check_decl_num_args(fn_decl);
self.check_decl_cvariadic_pos(fn_decl);
@ -390,7 +405,13 @@ impl<'a> AstValidator<'a> {
if let InterruptKind::X86 = interrupt_kind {
// "x86-interrupt" is special because it does have arguments.
// FIXME(workingjubilee): properly lint on acceptable input types.
if let FnRetTy::Ty(ref ret_ty) = sig.decl.output {
if let FnRetTy::Ty(ref ret_ty) = sig.decl.output
&& match &ret_ty.kind {
TyKind::Never => false,
TyKind::Tup(tup) if tup.is_empty() => false,
_ => true,
}
{
self.dcx().emit_err(errors::AbiMustNotHaveReturnType {
span: ret_ty.span,
abi,
@ -449,7 +470,13 @@ impl<'a> AstValidator<'a> {
fn reject_params_or_return(&self, abi: ExternAbi, ident: &Ident, sig: &FnSig) {
let mut spans: Vec<_> = sig.decl.inputs.iter().map(|p| p.span).collect();
if let FnRetTy::Ty(ref ret_ty) = sig.decl.output {
if let FnRetTy::Ty(ref ret_ty) = sig.decl.output
&& match &ret_ty.kind {
TyKind::Never => false,
TyKind::Tup(tup) if tup.is_empty() => false,
_ => true,
}
{
spans.push(ret_ty.span);
}
@ -1566,6 +1593,7 @@ impl<'a> Visitor<'a> for AstValidator<'a> {
self.visibility_not_permitted(&item.vis, errors::VisibilityNotPermittedNote::TraitImpl);
if let AssocItemKind::Fn(box Fn { sig, .. }) = &item.kind {
self.check_trait_fn_not_const(sig.header.constness, parent);
self.check_async_fn_in_const_trait_or_impl(sig, parent);
}
}

View file

@ -62,6 +62,16 @@ pub(crate) struct TraitFnConst {
pub make_trait_const_sugg: Option<Span>,
}
#[derive(Diagnostic)]
#[diag(ast_passes_async_fn_in_const_trait_or_trait_impl)]
pub(crate) struct AsyncFnInConstTraitOrTraitImpl {
#[primary_span]
pub async_keyword: Span,
pub in_impl: bool,
#[label]
pub const_keyword: Span,
}
#[derive(Diagnostic)]
#[diag(ast_passes_forbidden_bound)]
pub(crate) struct ForbiddenBound {

View file

@ -103,6 +103,10 @@ impl RegionTracker {
self.max_nameable_universe
}
pub(crate) fn max_placeholder_universe_reached(self) -> UniverseIndex {
self.max_placeholder_universe_reached
}
fn merge_min_max_seen(&mut self, other: &Self) {
self.max_placeholder_universe_reached = std::cmp::max(
self.max_placeholder_universe_reached,

View file

@ -713,7 +713,7 @@ impl<'tcx> RegionInferenceContext<'tcx> {
// If the member region lives in a higher universe, we currently choose
// the most conservative option by leaving it unchanged.
if !self.max_nameable_universe(scc).is_root() {
if !self.max_placeholder_universe_reached(scc).is_root() {
return;
}
@ -1376,6 +1376,13 @@ impl<'tcx> RegionInferenceContext<'tcx> {
self.scc_annotations[scc].max_nameable_universe()
}
pub(crate) fn max_placeholder_universe_reached(
&self,
scc: ConstraintSccIndex,
) -> UniverseIndex {
self.scc_annotations[scc].max_placeholder_universe_reached()
}
/// Checks the final value for the free region `fr` to see if it
/// grew too large. In particular, examine what `end(X)` points
/// wound up in `fr`'s final value; for each `end(X)` where `X !=

View file

@ -222,6 +222,15 @@ builtin_macros_format_unused_args = multiple unused formatting arguments
builtin_macros_format_use_positional = consider using a positional formatting argument instead
builtin_macros_derive_from_wrong_target = `#[derive(From)]` used on {$kind}
builtin_macros_derive_from_wrong_field_count = `#[derive(From)]` used on a struct with {$multiple_fields ->
[true] multiple fields
*[false] no fields
}
builtin_macros_derive_from_usage_note = `#[derive(From)]` can only be used on structs with exactly one field
builtin_macros_multiple_default_attrs = multiple `#[default]` attributes
.note = only one `#[default]` attribute is needed
.label = `#[default]` used here

View file

@ -1,8 +1,8 @@
mod context;
use rustc_ast::token::Delimiter;
use rustc_ast::token::{self, Delimiter};
use rustc_ast::tokenstream::{DelimSpan, TokenStream};
use rustc_ast::{DelimArgs, Expr, ExprKind, MacCall, Path, PathSegment, UnOp, token};
use rustc_ast::{DelimArgs, Expr, ExprKind, MacCall, Path, PathSegment};
use rustc_ast_pretty::pprust;
use rustc_errors::PResult;
use rustc_expand::base::{DummyResult, ExpandResult, ExtCtxt, MacEager, MacroExpanderResult};
@ -29,7 +29,7 @@ pub(crate) fn expand_assert<'cx>(
// `core::panic` and `std::panic` are different macros, so we use call-site
// context to pick up whichever is currently in scope.
let call_site_span = cx.with_call_site_ctxt(span);
let call_site_span = cx.with_call_site_ctxt(cond_expr.span);
let panic_path = || {
if use_panic_2021(span) {
@ -63,7 +63,7 @@ pub(crate) fn expand_assert<'cx>(
}),
})),
);
expr_if_not(cx, call_site_span, cond_expr, then, None)
assert_cond_check(cx, call_site_span, cond_expr, then)
}
// If `generic_assert` is enabled, generates rich captured outputs
//
@ -88,26 +88,33 @@ pub(crate) fn expand_assert<'cx>(
)),
)],
);
expr_if_not(cx, call_site_span, cond_expr, then, None)
assert_cond_check(cx, call_site_span, cond_expr, then)
};
ExpandResult::Ready(MacEager::expr(expr))
}
/// `assert!($cond_expr, $custom_message)`
struct Assert {
cond_expr: Box<Expr>,
custom_message: Option<TokenStream>,
}
// if !{ ... } { ... } else { ... }
fn expr_if_not(
cx: &ExtCtxt<'_>,
span: Span,
cond: Box<Expr>,
then: Box<Expr>,
els: Option<Box<Expr>>,
) -> Box<Expr> {
cx.expr_if(span, cx.expr(span, ExprKind::Unary(UnOp::Not, cond)), then, els)
/// `match <cond> { true => {} _ => <then> }`
fn assert_cond_check(cx: &ExtCtxt<'_>, span: Span, cond: Box<Expr>, then: Box<Expr>) -> Box<Expr> {
// Instead of expanding to `if !<cond> { <then> }`, we expand to
// `match <cond> { true => {} _ => <then> }`.
// This allows us to always complain about mismatched types instead of "cannot apply unary
// operator `!` to type `X`" when passing an invalid `<cond>`, while also allowing `<cond>` to
// be `&true`.
let els = cx.expr_block(cx.block(span, thin_vec![]));
let mut arms = thin_vec![];
arms.push(cx.arm(span, cx.pat_lit(span, cx.expr_bool(span, true)), els));
arms.push(cx.arm(span, cx.pat_wild(span), then));
// We wrap the `match` in a statement to limit the length of any borrows introduced in the
// condition.
cx.expr_block(cx.block(span, [cx.stmt_expr(cx.expr_match(span, cond, arms))].into()))
}
fn parse_assert<'a>(cx: &ExtCtxt<'a>, sp: Span, stream: TokenStream) -> PResult<'a, Assert> {

View file

@ -15,11 +15,12 @@ mod llvm_enzyme {
use rustc_ast::tokenstream::*;
use rustc_ast::visit::AssocCtxt::*;
use rustc_ast::{
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
MetaItemInner, PatKind, QSelf, TyKind, Visibility,
self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,
FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,
MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility,
};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::{Ident, Span, Symbol, kw, sym};
use rustc_span::{Ident, Span, Symbol, sym};
use thin_vec::{ThinVec, thin_vec};
use tracing::{debug, trace};
@ -179,11 +180,8 @@ mod llvm_enzyme {
}
/// We expand the autodiff macro to generate a new placeholder function which passes
/// type-checking and can be called by users. The function body of the placeholder function will
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
/// should just prevent early inlining and optimizations which alter the function signature.
/// The exact signature of the generated function depends on the configuration provided by the
/// user, but here is an example:
/// type-checking and can be called by users. The exact signature of the generated function
/// depends on the configuration provided by the user, but here is an example:
///
/// ```
/// #[autodiff(cos_box, Reverse, Duplicated, Active)]
@ -194,19 +192,12 @@ mod llvm_enzyme {
/// which becomes expanded to:
/// ```
/// #[rustc_autodiff]
/// #[inline(never)]
/// fn sin(x: &Box<f32>) -> f32 {
/// f32::sin(**x)
/// }
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
/// #[inline(never)]
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
/// unsafe {
/// asm!("NOP");
/// };
/// ::core::hint::black_box(sin(x));
/// ::core::hint::black_box((dx, dret));
/// ::core::hint::black_box(sin(x))
/// std::intrinsics::autodiff(sin::<>, cos_box::<>, (x, dx, dret))
/// }
/// ```
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@ -227,16 +218,24 @@ mod llvm_enzyme {
// first get information about the annotable item: visibility, signature, name and generic
// parameters.
// these will be used to generate the differentiated version of the function
let Some((vis, sig, primal, generics)) = (match &item {
Annotatable::Item(iitem) => extract_item_info(iitem),
let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item {
Annotatable::Item(iitem) => {
extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
}
Annotatable::Stmt(stmt) => match &stmt.kind {
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
ast::StmtKind::Item(iitem) => {
extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
}
_ => None,
},
Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
}
Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind {
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some((
assoc_item.vis.clone(),
sig.clone(),
ident.clone(),
generics.clone(),
*of_trait,
)),
_ => None,
},
_ => None,
@ -254,7 +253,6 @@ mod llvm_enzyme {
};
let has_ret = has_ret(&sig.decl.output);
let sig_span = ecx.with_call_site_ctxt(sig.span);
// create TokenStream from vec elemtents:
// meta_item doesn't have a .tokens field
@ -323,19 +321,23 @@ mod llvm_enzyme {
}
let span = ecx.with_def_site_ctxt(expand_span);
let n_active: u32 = x
.input_activity
.iter()
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
.count() as u32;
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
let d_body = gen_enzyme_body(
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
&generics,
let d_sig = gen_enzyme_decl(ecx, &sig, &x, span);
let d_body = ecx.block(
span,
thin_vec![call_autodiff(
ecx,
primal,
first_ident(&meta_item_vec[0]),
span,
&d_sig,
&generics,
impl_of_trait,
)],
);
// The first element of it is the name of the function to be generated
let asdf = Box::new(ast::Fn {
let d_fn = Box::new(ast::Fn {
defaultness: ast::Defaultness::Final,
sig: d_sig,
ident: first_ident(&meta_item_vec[0]),
@ -368,7 +370,7 @@ mod llvm_enzyme {
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
// We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
// We're avoid duplicating the attribute `#[rustc_autodiff]`.
fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
match (attr, item) {
(ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
@ -381,14 +383,16 @@ mod llvm_enzyme {
}
}
let mut has_inline_never = false;
// Don't add it multiple times:
let orig_annotatable: Annotatable = match item {
Annotatable::Item(ref mut iitem) => {
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
iitem.attrs.push(attr);
}
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
iitem.attrs.push(inline_never.clone());
if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
has_inline_never = true;
}
Annotatable::Item(iitem.clone())
}
@ -396,8 +400,8 @@ mod llvm_enzyme {
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
assoc_item.attrs.push(attr);
}
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
assoc_item.attrs.push(inline_never.clone());
if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
has_inline_never = true;
}
Annotatable::AssocItem(assoc_item.clone(), i)
}
@ -407,9 +411,8 @@ mod llvm_enzyme {
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
iitem.attrs.push(attr);
}
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
{
iitem.attrs.push(inline_never.clone());
if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
has_inline_never = true;
}
}
_ => unreachable!("stmt kind checked previously"),
@ -428,12 +431,21 @@ mod llvm_enzyme {
tokens: ts,
});
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
// If the source function has the `#[inline(never)]` attribute, we'll also add it to the diff function
let mut d_attrs = thin_vec![d_attr];
if has_inline_never {
d_attrs.push(inline_never);
}
let d_annotatable = match &item {
Annotatable::AssocItem(_, _) => {
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
let d_fn = Box::new(ast::AssocItem {
attrs: thin_vec![d_attr, inline_never],
attrs: d_attrs,
id: ast::DUMMY_NODE_ID,
span,
vis,
@ -443,13 +455,13 @@ mod llvm_enzyme {
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
}
Annotatable::Item(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
d_fn.vis = vis;
Annotatable::Item(d_fn)
}
Annotatable::Stmt(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
d_fn.vis = vis;
Annotatable::Stmt(Box::new(ast::Stmt {
@ -484,282 +496,95 @@ mod llvm_enzyme {
ty
}
// Will generate a body of the type:
// Generate `autodiff` intrinsic call
// ```
// {
// unsafe {
// asm!("NOP");
// }
// ::core::hint::black_box(primal(args));
// ::core::hint::black_box((args, ret));
// <This part remains to be done by following function>
// }
// std::intrinsics::autodiff(source, diff, (args))
// ```
fn init_body_helper(
fn call_autodiff(
ecx: &ExtCtxt<'_>,
span: Span,
primal: Ident,
new_names: &[String],
sig_span: Span,
new_decl_span: Span,
idents: &[Ident],
errored: bool,
diff: Ident,
span: Span,
d_sig: &FnSig,
generics: &Generics,
) -> (Box<ast::Block>, Box<ast::Expr>, Box<ast::Expr>, Box<ast::Expr>) {
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
let noop = ast::InlineAsm {
asm_macro: ast::AsmMacro::Asm,
template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
template_strs: Box::new([]),
operands: vec![],
clobber_abis: vec![],
options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
line_spans: vec![],
};
let noop_expr = ecx.expr_asm(span, Box::new(noop));
let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
let unsf_block = ast::Block {
stmts: thin_vec![ecx.stmt_semi(noop_expr)],
id: ast::DUMMY_NODE_ID,
tokens: None,
rules: unsf,
is_impl: bool,
) -> rustc_ast::Stmt {
let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);
let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);
let tuple_expr = ecx.expr_tuple(
span,
};
let unsf_expr = ecx.expr_block(Box::new(unsf_block));
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
let black_box_primal_call = ecx.expr_call(
new_decl_span,
blackbox_call_expr.clone(),
thin_vec![primal_call.clone()],
);
let tup_args = new_names
.iter()
.map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
.collect();
let black_box_remaining_args = ecx.expr_call(
sig_span,
blackbox_call_expr.clone(),
thin_vec![ecx.expr_tuple(sig_span, tup_args)],
d_sig
.decl
.inputs
.iter()
.map(|arg| match arg.pat.kind {
PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
_ => todo!(),
})
.collect::<ThinVec<_>>()
.into(),
);
let mut body = ecx.block(span, ThinVec::new());
body.stmts.push(ecx.stmt_semi(unsf_expr));
let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::autodiff]);
let enzyme_path = ecx.path(span, enzyme_path_idents);
let call_expr = ecx.expr_call(
span,
ecx.expr_path(enzyme_path),
vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
);
// This uses primal args which won't be available if we errored before
if !errored {
body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
}
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
(body, primal_call, black_box_primal_call, blackbox_call_expr)
ecx.stmt_expr(call_expr)
}
/// We only want this function to type-check, since we will replace the body
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
/// so instead we manually build something that should pass the type checker.
/// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
/// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
/// bug would ever try to accidentally differentiate this placeholder function body.
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
/// from optimizing any arguments away.
fn gen_enzyme_body(
// Generate turbofish expression from fn name and generics
// Given `foo` and `<A, B, C>` params, gen `foo::<A, B, C>`
// We use this expression when passing primal and diff function to the autodiff intrinsic
fn gen_turbofish_expr(
ecx: &ExtCtxt<'_>,
x: &AutoDiffAttrs,
n_active: u32,
sig: &ast::FnSig,
d_sig: &ast::FnSig,
primal: Ident,
new_names: &[String],
span: Span,
sig_span: Span,
idents: Vec<Ident>,
errored: bool,
ident: Ident,
generics: &Generics,
) -> Box<ast::Block> {
let new_decl_span = d_sig.span;
// Just adding some default inline-asm and black_box usages to prevent early inlining
// and optimizations which alter the function signature.
//
// The bb_primal_call is the black_box call of the primal function. We keep it around,
// since it has the convenient property of returning the type of the primal function,
// Remember, we only care to match types here.
// No matter which return we pick, we always wrap it into a std::hint::black_box call,
// to prevent rustc from propagating it into the caller.
let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
ecx,
span,
primal,
new_names,
sig_span,
new_decl_span,
&idents,
errored,
generics,
);
if !has_ret(&d_sig.decl.output) {
// there is no return type that we have to match, () works fine.
return body;
}
// Everything from here onwards just tries to fulfil the return type. Fun!
// having an active-only return means we'll drop the original return type.
// So that can be treated identical to not having one in the first place.
let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
if primal_ret && n_active == 0 && x.mode.is_rev() {
// We only have the primal ret.
body.stmts.push(ecx.stmt_expr(bb_primal_call));
return body;
}
if !primal_ret && n_active == 1 {
// Again no tuple return, so return default float val.
let ty = match d_sig.decl.output {
FnRetTy::Ty(ref ty) => ty.clone(),
FnRetTy::Default(span) => {
panic!("Did not expect Default ret ty: {:?}", span);
}
};
let arg = ty.kind.is_simple_path().unwrap();
let tmp = ecx.def_site_path(&[arg, kw::Default]);
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
body.stmts.push(ecx.stmt_expr(default_call_expr));
return body;
}
let mut exprs: Box<ast::Expr> = primal_call;
let d_ret_ty = match d_sig.decl.output {
FnRetTy::Ty(ref ty) => ty.clone(),
FnRetTy::Default(span) => {
panic!("Did not expect Default ret ty: {:?}", span);
}
};
if x.mode.is_fwd() {
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
// We checked that (on a best-effort base) in the preceding gen_enzyme_decl function.
// In all three cases, we can return `std::hint::black_box(<T>::default())`.
if x.ret_activity == DiffActivity::Const {
// Here we call the primal function, since our dummy function has the same return
// type due to the Const return activity.
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
} else {
let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 };
let y = ExprKind::Path(
Some(Box::new(q)),
ecx.path_ident(span, Ident::with_dummy_span(kw::Default)),
);
let default_call_expr = ecx.expr(span, y);
let default_call_expr =
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
}
} else if x.mode.is_rev() {
if x.width == 1 {
// We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
match d_ret_ty.kind {
TyKind::Tup(ref args) => {
// We have a tuple return type. We need to create a tuple of the same size
// and fill it with default values.
let mut exprs2 = thin_vec![exprs];
for arg in args.iter().skip(1) {
let arg = arg.kind.is_simple_path().unwrap();
let tmp = ecx.def_site_path(&[arg, kw::Default]);
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
let default_call_expr =
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
exprs2.push(default_call_expr);
}
exprs = ecx.expr_tuple(new_decl_span, exprs2);
}
_ => {
// Interestingly, even the `-> ArbitraryType` case
// ends up getting matched and handled correctly above,
// so we don't have to handle any other case for now.
panic!("Unsupported return type: {:?}", d_ret_ty);
}
}
}
exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
} else {
unreachable!("Unsupported mode: {:?}", x.mode);
}
body.stmts.push(ecx.stmt_expr(exprs));
body
}
fn gen_primal_call(
ecx: &ExtCtxt<'_>,
span: Span,
primal: Ident,
idents: &[Ident],
generics: &Generics,
is_impl: bool,
) -> Box<ast::Expr> {
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
if has_self {
let args: ThinVec<_> =
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
let self_expr = ecx.expr_self(span);
ecx.expr_method_call(span, self_expr, primal, args)
} else {
let args: ThinVec<_> =
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
let mut primal_path = ecx.path_ident(span, primal);
let is_generic = !generics.params.is_empty();
match (is_generic, primal_path.segments.last_mut()) {
(true, Some(function_path)) => {
let primal_generic_types = generics
.params
.iter()
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
let generated_generic_types = primal_generic_types
.map(|type_param| {
let generic_param = TyKind::Path(
None,
ast::Path {
span,
segments: thin_vec![ast::PathSegment {
ident: type_param.ident,
args: None,
id: ast::DUMMY_NODE_ID,
}],
tokens: None,
},
);
ast::AngleBracketedArg::Arg(ast::GenericArg::Type(Box::new(ast::Ty {
id: type_param.id,
span,
kind: generic_param,
tokens: None,
})))
})
.collect();
function_path.args =
Some(Box::new(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
span,
args: generated_generic_types,
})));
let generic_args = generics
.params
.iter()
.filter_map(|p| match &p.kind {
GenericParamKind::Type { .. } => {
let path = ast::Path::from_ident(p.ident);
let ty = ecx.ty_path(path);
Some(AngleBracketedArg::Arg(GenericArg::Type(ty)))
}
_ => {}
}
GenericParamKind::Const { .. } => {
let expr = ecx.expr_path(ast::Path::from_ident(p.ident));
let anon_const = AnonConst { id: ast::DUMMY_NODE_ID, value: expr };
Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))
}
GenericParamKind::Lifetime { .. } => None,
})
.collect::<ThinVec<_>>();
let primal_call_expr = ecx.expr_path(primal_path);
ecx.expr_call(span, primal_call_expr, args)
}
let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args };
let segment = PathSegment {
ident,
id: ast::DUMMY_NODE_ID,
args: Some(Box::new(GenericArgs::AngleBracketed(args))),
};
let segments = if is_impl {
thin_vec![
PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None },
segment,
]
} else {
thin_vec![segment]
};
let path = Path { span, segments, tokens: None };
ecx.expr_path(path)
}
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
@ -778,7 +603,7 @@ mod llvm_enzyme {
sig: &ast::FnSig,
x: &AutoDiffAttrs,
span: Span,
) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
) -> ast::FnSig {
let dcx = ecx.sess.dcx();
let has_ret = has_ret(&sig.decl.output);
let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
@ -790,7 +615,7 @@ mod llvm_enzyme {
found: num_activities,
});
// This is not the right signature, but we can continue parsing.
return (sig.clone(), vec![], vec![], true);
return sig.clone();
}
assert!(sig.decl.inputs.len() == x.input_activity.len());
assert!(has_ret == x.has_ret_activity());
@ -833,7 +658,7 @@ mod llvm_enzyme {
if errors {
// This is not the right signature, but we can continue parsing.
return (sig.clone(), new_inputs, idents, true);
return sig.clone();
}
let unsafe_activities = x
@ -1047,7 +872,7 @@ mod llvm_enzyme {
}
let d_sig = FnSig { header: d_header, decl: d_decl, span };
trace!("Generated signature: {:?}", d_sig);
(d_sig, new_inputs, idents, false)
d_sig
}
}

View file

@ -0,0 +1,132 @@
use rustc_ast as ast;
use rustc_ast::{ItemKind, VariantData};
use rustc_errors::MultiSpan;
use rustc_expand::base::{Annotatable, DummyResult, ExtCtxt};
use rustc_span::{Ident, Span, kw, sym};
use thin_vec::thin_vec;
use crate::deriving::generic::ty::{Bounds, Path, PathKind, Ty};
use crate::deriving::generic::{
BlockOrExpr, FieldlessVariantsStrategy, MethodDef, SubstructureFields, TraitDef,
combine_substructure,
};
use crate::deriving::pathvec_std;
use crate::errors;
/// Generate an implementation of the `From` trait, provided that `item`
/// is a struct or a tuple struct with exactly one field.
pub(crate) fn expand_deriving_from(
cx: &ExtCtxt<'_>,
span: Span,
mitem: &ast::MetaItem,
annotatable: &Annotatable,
push: &mut dyn FnMut(Annotatable),
is_const: bool,
) {
let Annotatable::Item(item) = &annotatable else {
cx.dcx().bug("derive(From) used on something else than an item");
};
// #[derive(From)] is currently usable only on structs with exactly one field.
let field = if let ItemKind::Struct(_, _, data) = &item.kind
&& let [field] = data.fields()
{
Some(field.clone())
} else {
None
};
let from_type = match &field {
Some(field) => Ty::AstTy(field.ty.clone()),
// We don't have a type to put into From<...> if we don't have a single field, so just put
// unit there.
None => Ty::Unit,
};
let path =
Path::new_(pathvec_std!(convert::From), vec![Box::new(from_type.clone())], PathKind::Std);
// Generate code like this:
//
// struct S(u32);
// #[automatically_derived]
// impl ::core::convert::From<u32> for S {
// #[inline]
// fn from(value: u32) -> S {
// Self(value)
// }
// }
let from_trait_def = TraitDef {
span,
path,
skip_path_as_bound: true,
needs_copy_as_bound_if_packed: false,
additional_bounds: Vec::new(),
supports_unions: false,
methods: vec![MethodDef {
name: sym::from,
generics: Bounds { bounds: vec![] },
explicit_self: false,
nonself_args: vec![(from_type, sym::value)],
ret_ty: Ty::Self_,
attributes: thin_vec![cx.attr_word(sym::inline, span)],
fieldless_variants_strategy: FieldlessVariantsStrategy::Default,
combine_substructure: combine_substructure(Box::new(|cx, span, substructure| {
let Some(field) = &field else {
let item_span = item.kind.ident().map(|ident| ident.span).unwrap_or(item.span);
let err_span = MultiSpan::from_spans(vec![span, item_span]);
let error = match &item.kind {
ItemKind::Struct(_, _, data) => {
cx.dcx().emit_err(errors::DeriveFromWrongFieldCount {
span: err_span,
multiple_fields: data.fields().len() > 1,
})
}
ItemKind::Enum(_, _, _) | ItemKind::Union(_, _, _) => {
cx.dcx().emit_err(errors::DeriveFromWrongTarget {
span: err_span,
kind: &format!("{} {}", item.kind.article(), item.kind.descr()),
})
}
_ => cx.dcx().bug("Invalid derive(From) ADT input"),
};
return BlockOrExpr::new_expr(DummyResult::raw_expr(span, Some(error)));
};
let self_kw = Ident::new(kw::SelfUpper, span);
let expr: Box<ast::Expr> = match substructure.fields {
SubstructureFields::StaticStruct(variant, _) => match variant {
// Self {
// field: value
// }
VariantData::Struct { .. } => cx.expr_struct_ident(
span,
self_kw,
thin_vec![cx.field_imm(
span,
field.ident.unwrap(),
cx.expr_ident(span, Ident::new(sym::value, span))
)],
),
// Self(value)
VariantData::Tuple(_, _) => cx.expr_call_ident(
span,
self_kw,
thin_vec![cx.expr_ident(span, Ident::new(sym::value, span))],
),
variant => {
cx.dcx().bug(format!("Invalid derive(From) ADT variant: {variant:?}"));
}
},
_ => cx.dcx().bug("Invalid derive(From) ADT input"),
};
BlockOrExpr::new_expr(expr)
})),
}],
associated_types: Vec::new(),
is_const,
is_staged_api_crate: cx.ecfg.features.staged_api(),
};
from_trait_def.expand(cx, mitem, annotatable, push);
}

View file

@ -2,7 +2,7 @@
//! when specifying impls to be derived.
pub(crate) use Ty::*;
use rustc_ast::{self as ast, Expr, GenericArg, GenericParamKind, Generics, SelfKind};
use rustc_ast::{self as ast, Expr, GenericArg, GenericParamKind, Generics, SelfKind, TyKind};
use rustc_expand::base::ExtCtxt;
use rustc_span::source_map::respan;
use rustc_span::{DUMMY_SP, Ident, Span, Symbol, kw};
@ -65,7 +65,7 @@ impl Path {
}
}
/// A type. Supports pointers, Self, and literals.
/// A type. Supports pointers, Self, literals, unit or an arbitrary AST path.
#[derive(Clone)]
pub(crate) enum Ty {
Self_,
@ -76,6 +76,8 @@ pub(crate) enum Ty {
Path(Path),
/// For () return types.
Unit,
/// An arbitrary type.
AstTy(Box<ast::Ty>),
}
pub(crate) fn self_ref() -> Ty {
@ -101,6 +103,7 @@ impl Ty {
let ty = ast::TyKind::Tup(ThinVec::new());
cx.ty(span, ty)
}
AstTy(ty) => ty.clone(),
}
}
@ -132,6 +135,10 @@ impl Ty {
cx.path_all(span, false, vec![self_ty], params)
}
Path(p) => p.to_path(cx, span, self_ty, generics),
AstTy(ty) => match &ty.kind {
TyKind::Path(_, path) => path.clone(),
_ => cx.dcx().span_bug(span, "non-path in a path in generic `derive`"),
},
Ref(..) => cx.dcx().span_bug(span, "ref in a path in generic `derive`"),
Unit => cx.dcx().span_bug(span, "unit in a path in generic `derive`"),
}

View file

@ -23,6 +23,7 @@ pub(crate) mod clone;
pub(crate) mod coerce_pointee;
pub(crate) mod debug;
pub(crate) mod default;
pub(crate) mod from;
pub(crate) mod hash;
#[path = "cmp/eq.rs"]

View file

@ -446,6 +446,24 @@ pub(crate) struct DefaultHasArg {
pub(crate) span: Span,
}
#[derive(Diagnostic)]
#[diag(builtin_macros_derive_from_wrong_target)]
#[note(builtin_macros_derive_from_usage_note)]
pub(crate) struct DeriveFromWrongTarget<'a> {
#[primary_span]
pub(crate) span: MultiSpan,
pub(crate) kind: &'a str,
}
#[derive(Diagnostic)]
#[diag(builtin_macros_derive_from_wrong_field_count)]
#[note(builtin_macros_derive_from_usage_note)]
pub(crate) struct DeriveFromWrongFieldCount {
#[primary_span]
pub(crate) span: MultiSpan,
pub(crate) multiple_fields: bool,
}
#[derive(Diagnostic)]
#[diag(builtin_macros_derive_macro_call)]
pub(crate) struct DeriveMacroCall {

View file

@ -139,6 +139,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
PartialEq: partial_eq::expand_deriving_partial_eq,
PartialOrd: partial_ord::expand_deriving_partial_ord,
CoercePointee: coerce_pointee::expand_deriving_coerce_pointee,
From: from::expand_deriving_from,
}
let client = rustc_proc_macro::bridge::client::Client::expand1(rustc_proc_macro::quote);

View file

@ -93,7 +93,6 @@ use gccjit::{CType, Context, OptimizationLevel};
#[cfg(feature = "master")]
use gccjit::{TargetInfo, Version};
use rustc_ast::expand::allocator::AllocatorKind;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_codegen_ssa::back::lto::{SerializedModule, ThinModule};
use rustc_codegen_ssa::back::write::{
CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryFn,
@ -363,12 +362,7 @@ impl WriteBackendMethods for GccCodegenBackend {
_exported_symbols_for_lto: &[String],
each_linked_rlib_for_lto: &[PathBuf],
modules: Vec<FatLtoInput<Self>>,
diff_functions: Vec<AutoDiffItem>,
) -> Result<ModuleCodegen<Self::Module>, FatalError> {
if !diff_functions.is_empty() {
unimplemented!();
}
back::lto::run_fat(cgcx, each_linked_rlib_for_lto, modules)
}

View file

@ -8,11 +8,12 @@ use rustc_middle::bug;
use rustc_middle::ty::TyCtxt;
use rustc_session::config::{DebugInfo, OomStrategy};
use rustc_symbol_mangling::mangle_internal_symbol;
use smallvec::SmallVec;
use crate::builder::SBuilder;
use crate::declare::declare_simple_fn;
use crate::llvm::{self, False, True, Type, Value};
use crate::{SimpleCx, attributes, debuginfo};
use crate::{SimpleCx, attributes, debuginfo, llvm_util};
pub(crate) unsafe fn codegen(
tcx: TyCtxt<'_>,
@ -147,6 +148,20 @@ fn create_wrapper_function(
llvm::Visibility::from_generic(tcx.sess.default_visibility()),
ty,
);
let mut attrs = SmallVec::<[_; 2]>::new();
let target_cpu = llvm_util::target_cpu(tcx.sess);
let target_cpu_attr = llvm::CreateAttrStringValue(cx.llcx, "target-cpu", target_cpu);
let tune_cpu_attr = llvm_util::tune_cpu(tcx.sess)
.map(|tune_cpu| llvm::CreateAttrStringValue(cx.llcx, "tune-cpu", tune_cpu));
attrs.push(target_cpu_attr);
attrs.extend(tune_cpu_attr);
attributes::apply_to_llfn(llfn, llvm::AttributePlace::Function, &attrs);
let no_return = if no_return {
// -> ! DIFlagNoReturn
let no_return = llvm::AttributeKind::NoReturn.create_attr(cx.llcx);

View file

@ -28,22 +28,6 @@ pub(crate) fn apply_to_callsite(callsite: &Value, idx: AttributePlace, attrs: &[
}
}
pub(crate) fn has_attr(llfn: &Value, idx: AttributePlace, attr: AttributeKind) -> bool {
llvm::HasAttributeAtIndex(llfn, idx, attr)
}
pub(crate) fn has_string_attr(llfn: &Value, name: &str) -> bool {
llvm::HasStringAttribute(llfn, name)
}
pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: AttributeKind) {
llvm::RemoveRustEnumAttributeAtIndex(llfn, place, kind);
}
pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: &str) {
llvm::RemoveStringAttrFromFn(llfn, name);
}
/// Get LLVM attribute for the provided inline heuristic.
#[inline]
fn inline_attr<'ll>(cx: &CodegenCx<'ll, '_>, inline: InlineAttr) -> Option<&'ll Attribute> {

View file

@ -24,9 +24,8 @@ use crate::back::write::{
self, CodegenDiagnosticsStage, DiagnosticHandlers, bitcode_section_name, save_temp_bitcode,
};
use crate::errors::{LlvmError, LtoBitcodeFromRlib};
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, build_string};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx};
/// We keep track of the computed LTO cache keys from the previous
/// session to determine which CGUs we can reuse.
@ -593,31 +592,6 @@ pub(crate) fn run_pass_manager(
}
if cfg!(llvm_enzyme) && enable_ad && !thin {
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
for function in cx.get_functions() {
let enzyme_marker = "enzyme_marker";
if attributes::has_string_attr(function, enzyme_marker) {
// Sanity check: Ensure 'noinline' is present before replacing it.
assert!(
attributes::has_attr(function, Function, llvm::AttributeKind::NoInline),
"Expected __enzyme function to have 'noinline' before adding 'alwaysinline'"
);
attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline);
attributes::remove_string_attr_from_llfn(function, enzyme_marker);
assert!(
!attributes::has_string_attr(function, enzyme_marker),
"Expected function to not have 'enzyme_marker'"
);
let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx);
attributes::apply_to_llfn(function, Function, &[always_inline]);
}
}
let opt_stage = llvm::OptStage::FatLTO;
let stage = write::AutodiffStage::PostAD;
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {

View file

@ -862,7 +862,7 @@ pub(crate) fn codegen(
.generic_activity_with_arg("LLVM_module_codegen_embed_bitcode", &*module.name);
let thin_bc =
module.thin_lto_buffer.as_deref().expect("cannot find embedded bitcode");
embed_bitcode(cgcx, llcx, llmod, &config.bc_cmdline, &thin_bc);
embed_bitcode(cgcx, llcx, llmod, &thin_bc);
}
}
@ -1058,7 +1058,6 @@ fn embed_bitcode(
cgcx: &CodegenContext<LlvmCodegenBackend>,
llcx: &llvm::Context,
llmod: &llvm::Module,
cmdline: &str,
bitcode: &[u8],
) {
// We're adding custom sections to the output object file, but we definitely
@ -1074,7 +1073,9 @@ fn embed_bitcode(
// * Mach-O - this is for macOS. Inspecting the source code for the native
// linker here shows that the `.llvmbc` and `.llvmcmd` sections are
// automatically skipped by the linker. In that case there's nothing extra
// that we need to do here.
// that we need to do here. We do need to make sure that the
// `__LLVM,__cmdline` section exists even though it is empty as otherwise
// ld64 rejects the object file.
//
// * Wasm - the native LLD linker is hard-coded to skip `.llvmbc` and
// `.llvmcmd` sections, so there's nothing extra we need to do.
@ -1111,7 +1112,7 @@ fn embed_bitcode(
llvm::set_linkage(llglobal, llvm::Linkage::PrivateLinkage);
llvm::LLVMSetGlobalConstant(llglobal, llvm::True);
let llconst = common::bytes_in_context(llcx, cmdline.as_bytes());
let llconst = common::bytes_in_context(llcx, &[]);
let llglobal = llvm::add_global(llmod, common::val_ty(llconst), c"rustc.embedded.cmdline");
llvm::set_initializer(llglobal, llconst);
let section = if cgcx.target_is_like_darwin {
@ -1128,7 +1129,7 @@ fn embed_bitcode(
let section_flags = if cgcx.is_pe_coff { "n" } else { "e" };
let asm = create_section_with_flags_asm(".llvmbc", section_flags, bitcode);
llvm::append_module_inline_asm(llmod, &asm);
let asm = create_section_with_flags_asm(".llvmcmd", section_flags, cmdline.as_bytes());
let asm = create_section_with_flags_asm(".llvmcmd", section_flags, &[]);
llvm::append_module_inline_asm(llmod, &asm);
}
}

View file

@ -1696,7 +1696,11 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
return;
}
self.call_intrinsic(intrinsic, &[self.val_ty(ptr)], &[self.cx.const_u64(size), ptr]);
if crate::llvm_util::get_version() >= (22, 0, 0) {
self.call_intrinsic(intrinsic, &[self.val_ty(ptr)], &[ptr]);
} else {
self.call_intrinsic(intrinsic, &[self.val_ty(ptr)], &[self.cx.const_u64(size), ptr]);
}
}
}
impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {

View file

@ -1,40 +1,92 @@
use std::ptr;
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
use rustc_codegen_ssa::ModuleCodegen;
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
use rustc_errors::FatalError;
use rustc_middle::bug;
use tracing::{debug, trace};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
use rustc_middle::{bug, ty};
use tracing::debug;
use crate::back::write::llvm_err;
use crate::builder::{SBuilder, UNNAMED};
use crate::builder::{Builder, PlaceRef, UNNAMED};
use crate::context::SimpleCx;
use crate::declare::declare_simple_fn;
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
use crate::llvm::AttributePlace::Function;
use crate::llvm::{Metadata, True};
use crate::llvm;
use crate::llvm::{Metadata, True, Type};
use crate::value::Value;
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
fn get_params(fnc: &Value) -> Vec<&Value> {
let param_num = llvm::LLVMCountParams(fnc) as usize;
let mut fnc_args: Vec<&Value> = vec![];
fnc_args.reserve(param_num);
unsafe {
llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr());
fnc_args.set_len(param_num);
pub(crate) fn adjust_activity_to_abi<'tcx>(
tcx: TyCtxt<'tcx>,
fn_ty: Ty<'tcx>,
da: &mut Vec<DiffActivity>,
) {
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
bug!("expected fn def for autodiff, got {:?}", fn_ty);
}
fnc_args
}
fn has_sret(fnc: &Value) -> bool {
let num_args = llvm::LLVMCountParams(fnc) as usize;
if num_args == 0 {
false
} else {
unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
// We don't actually pass the types back into the type system.
// All we do is decide how to handle the arguments.
let sig = fn_ty.fn_sig(tcx).skip_binder();
let mut new_activities = vec![];
let mut new_positions = vec![];
for (i, ty) in sig.inputs().iter().enumerate() {
if let Some(inner_ty) = ty.builtin_deref(true) {
if inner_ty.is_slice() {
// Now we need to figure out the size of each slice element in memory to allow
// safety checks and usability improvements in the backend.
let sty = match inner_ty.builtin_index() {
Some(sty) => sty,
None => {
panic!("slice element type unknown");
}
};
let pci = PseudoCanonicalInput {
typing_env: TypingEnv::fully_monomorphized(),
value: sty,
};
let layout = tcx.layout_of(pci);
let elem_size = match layout {
Ok(layout) => layout.size,
Err(_) => {
bug!("autodiff failed to compute slice element size");
}
};
let elem_size: u32 = elem_size.bytes() as u32;
// We know that the length will be passed as extra arg.
if !da.is_empty() {
// We are looking at a slice. The length of that slice will become an
// extra integer on llvm level. Integers are always const.
// However, if the slice get's duplicated, we want to know to later check the
// size. So we mark the new size argument as FakeActivitySize.
// There is one FakeActivitySize per slice, so for convenience we store the
// slice element size in bytes in it. We will use the size in the backend.
let activity = match da[i] {
DiffActivity::DualOnly
| DiffActivity::Dual
| DiffActivity::Dualv
| DiffActivity::DuplicatedOnly
| DiffActivity::Duplicated => {
DiffActivity::FakeActivitySize(Some(elem_size))
}
DiffActivity::Const => DiffActivity::Const,
_ => bug!("unexpected activity for ptr/ref"),
};
new_activities.push(activity);
new_positions.push(i + 1);
}
continue;
}
}
}
// now add the extra activities coming from slices
// Reverse order to not invalidate the indices
for _ in 0..new_activities.len() {
let pos = new_positions.pop().unwrap();
let activity = new_activities.pop().unwrap();
da.insert(pos, activity);
}
}
@ -48,14 +100,13 @@ fn has_sret(fnc: &Value) -> bool {
// need to match those.
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
// using iterators and peek()?
fn match_args_from_caller_to_enzyme<'ll>(
fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
cx: &SimpleCx<'ll>,
builder: &SBuilder<'ll, 'll>,
builder: &mut Builder<'_, 'll, 'tcx>,
width: u32,
args: &mut Vec<&'ll llvm::Value>,
inputs: &[DiffActivity],
outer_args: &[&'ll llvm::Value],
has_sret: bool,
) {
debug!("matching autodiff arguments");
// We now handle the issue that Rust level arguments not always match the llvm-ir level
@ -67,14 +118,6 @@ fn match_args_from_caller_to_enzyme<'ll>(
let mut outer_pos: usize = 0;
let mut activity_pos = 0;
if has_sret {
// Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
// inner function will still return something. We increase our outer_pos by one,
// and once we're done with all other args we will take the return of the inner call and
// update the sret pointer with it
outer_pos = 1;
}
let enzyme_const = cx.create_metadata(b"enzyme_const");
let enzyme_out = cx.create_metadata(b"enzyme_out");
let enzyme_dup = cx.create_metadata(b"enzyme_dup");
@ -193,92 +236,6 @@ fn match_args_from_caller_to_enzyme<'ll>(
}
}
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
// arguments. We do however need to declare them with their correct return type.
// We already figured the correct return type out in our frontend, when generating the outer_fn,
// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
// Beyond sret, this article describes our challenges nicely:
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
fn compute_enzyme_fn_ty<'ll>(
cx: &SimpleCx<'ll>,
attrs: &AutoDiffAttrs,
fn_to_diff: &'ll Value,
outer_fn: &'ll Value,
) -> &'ll llvm::Type {
let fn_ty = cx.get_type_of_global(outer_fn);
let mut ret_ty = cx.get_return_type(fn_ty);
let has_sret = has_sret(outer_fn);
if has_sret {
// Now we don't just forward the return type, so we have to figure it out based on the
// primal return type, in combination with the autodiff settings.
let fn_ty = cx.get_type_of_global(fn_to_diff);
let inner_ret_ty = cx.get_return_type(fn_ty);
let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
if inner_ret_ty == void_ty {
// This indicates that even the inner function has an sret.
// Right now I only look for an sret in the outer function.
// This *probably* needs some extra handling, but I never ran
// into such a case. So I'll wait for user reports to have a test case.
bug!("sret in inner function");
}
if attrs.width == 1 {
// Enzyme returns a struct of style:
// `{ original_ret(if requested), float, float, ... }`
let mut struct_elements = vec![];
if attrs.has_primal_ret() {
struct_elements.push(inner_ret_ty);
}
// Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
// and therefore part of the return struct.
let param_tys = cx.func_params_types(fn_ty);
for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
if matches!(act, DiffActivity::Active) {
// Now find the float type at position i based on the fn_ty,
// to know what (f16/f32/f64/...) to add to the struct.
struct_elements.push(param_ty);
}
}
ret_ty = cx.type_struct(&struct_elements, false);
} else {
// First we check if we also have to deal with the primal return.
match attrs.mode {
DiffMode::Forward => match attrs.ret_activity {
DiffActivity::Dual => {
let arr_ty =
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
ret_ty = arr_ty;
}
DiffActivity::DualOnly => {
let arr_ty =
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
ret_ty = arr_ty;
}
DiffActivity::Const => {
todo!("Not sure, do we need to do something here?");
}
_ => {
bug!("unreachable");
}
},
DiffMode::Reverse => {
todo!("Handle sret for reverse mode");
}
_ => {
bug!("unreachable");
}
}
}
}
// LLVM can figure out the input types on it's own, so we take a shortcut here.
unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
}
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
/// function with expected naming and calling conventions[^1] which will be
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@ -288,11 +245,15 @@ fn compute_enzyme_fn_ty<'ll>(
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
fn generate_enzyme_call<'ll>(
pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
cx: &SimpleCx<'ll>,
fn_to_diff: &'ll Value,
outer_fn: &'ll Value,
outer_name: &str,
ret_ty: &'ll Type,
fn_args: &[&'ll Value],
attrs: AutoDiffAttrs,
dest: PlaceRef<'tcx, &'ll Value>,
) {
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
let mut ad_name: String = match attrs.mode {
@ -302,11 +263,9 @@ fn generate_enzyme_call<'ll>(
}
.to_string();
// add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
// add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
let name = llvm::get_value_name(outer_fn);
let outer_fn_name = std::str::from_utf8(&name).unwrap();
ad_name.push_str(outer_fn_name);
ad_name.push_str(outer_name);
// Let us assume the user wrote the following function square:
//
@ -316,14 +275,8 @@ fn generate_enzyme_call<'ll>(
// %0 = fmul double %x, %x
// ret double %0
// }
// ```
//
// The user now applies autodiff to the function square, in which case fn_to_diff will be `square`.
// Our macro generates the following placeholder code (slightly simplified):
//
// ```llvm
// define double @dsquare(double %x) {
// ; placeholder code
// return 0.0;
// }
// ```
@ -340,175 +293,44 @@ fn generate_enzyme_call<'ll>(
// ret double %0
// }
// ```
unsafe {
let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) };
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
// think a bit more about what should go here.
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
let ad_fn = declare_simple_fn(
cx,
&ad_name,
llvm::CallConv::try_from(cc).expect("invalid callconv"),
llvm::UnnamedAddr::No,
llvm::Visibility::Default,
enzyme_ty,
);
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
// think a bit more about what should go here.
let cc = unsafe { llvm::LLVMGetFunctionCallConv(fn_to_diff) };
let ad_fn = declare_simple_fn(
cx,
&ad_name,
llvm::CallConv::try_from(cc).expect("invalid callconv"),
llvm::UnnamedAddr::No,
llvm::Visibility::Default,
enzyme_ty,
);
// Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
// do it's work.
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
let num_args = llvm::LLVMCountParams(&fn_to_diff);
let mut args = Vec::with_capacity(num_args as usize + 1);
args.push(fn_to_diff);
// We add a made-up attribute just such that we can recognize it after AD to update
// (no)-inline attributes. We'll then also remove this attribute.
let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]);
// first, remove all calls from fnc
let entry = llvm::LLVMGetFirstBasicBlock(outer_fn);
let br = llvm::LLVMRustGetTerminator(entry);
llvm::LLVMRustEraseInstFromParent(br);
let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
let mut builder = SBuilder::build(cx, entry);
let num_args = llvm::LLVMCountParams(&fn_to_diff);
let mut args = Vec::with_capacity(num_args as usize + 1);
args.push(fn_to_diff);
let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return");
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
args.push(cx.get_metadata_value(enzyme_primal_ret));
}
if attrs.width > 1 {
let enzyme_width = cx.create_metadata(b"enzyme_width");
args.push(cx.get_metadata_value(enzyme_width));
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
}
let has_sret = has_sret(outer_fn);
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
match_args_from_caller_to_enzyme(
&cx,
&builder,
attrs.width,
&mut args,
&attrs.input_activity,
&outer_args,
has_sret,
);
let call = builder.call(enzyme_ty, ad_fn, &args, None);
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
// metadata attached to it, but we just created this code oota. Given that the
// differentiated function already has partly confusing metadata, and given that this
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
// dummy code which we inserted at a higher level.
// FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have,
// and how to best improve it for enzyme core and rust-enzyme.
let md_ty = cx.get_md_kind_id("dbg");
if llvm::LLVMRustHasMetadata(last_inst, md_ty) {
let md = llvm::LLVMRustDIGetInstMetadata(last_inst)
.expect("failed to get instruction metadata");
let md_todiff = cx.get_metadata_value(md);
llvm::LLVMSetMetadata(call, md_ty, md_todiff);
} else {
// We don't panic, since depending on whether we are in debug or release mode, we might
// have no debug info to copy, which would then be ok.
trace!("no dbg info");
}
// Now that we copied the metadata, get rid of dummy code.
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
if cx.val_ty(call) == cx.type_void() || has_sret {
if has_sret {
// This is what we already have in our outer_fn (shortened):
// define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
// %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
// <Here we are, we want to add the following two lines>
// store [4 x double] %7, ptr %0, align 8
// ret void
// }
// now store the result of the enzyme call into the sret pointer.
let sret_ptr = outer_args[0];
let call_ty = cx.val_ty(call);
if attrs.width == 1 {
assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
} else {
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
}
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
}
builder.ret_void();
} else {
builder.ret(call);
}
// Let's crash in case that we messed something up above and generated invalid IR.
llvm::LLVMRustVerifyFunction(
outer_fn,
llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction,
);
let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return");
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
args.push(cx.get_metadata_value(enzyme_primal_ret));
}
}
pub(crate) fn differentiate<'ll>(
module: &'ll ModuleCodegen<ModuleLlvm>,
cgcx: &CodegenContext<LlvmCodegenBackend>,
diff_items: Vec<AutoDiffItem>,
) -> Result<(), FatalError> {
for item in &diff_items {
trace!("{}", item);
}
let diag_handler = cgcx.create_dcx();
let cx = SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size);
// First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag?
if !diff_items.is_empty()
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
{
return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
}
// Here we replace the placeholder code with the actual autodiff code, which calls Enzyme.
for item in diff_items.iter() {
let name = item.source.clone();
let fn_def: Option<&llvm::Value> = cx.get_function(&name);
let Some(fn_def) = fn_def else {
return Err(llvm_err(
diag_handler.handle(),
LlvmError::PrepareAutoDiff {
src: item.source.clone(),
target: item.target.clone(),
error: "could not find source function".to_owned(),
},
));
};
debug!(?item.target);
let fn_target: Option<&llvm::Value> = cx.get_function(&item.target);
let Some(fn_target) = fn_target else {
return Err(llvm_err(
diag_handler.handle(),
LlvmError::PrepareAutoDiff {
src: item.source.clone(),
target: item.target.clone(),
error: "could not find target function".to_owned(),
},
));
};
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
}
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
trace!("done with differentiate()");
Ok(())
if attrs.width > 1 {
let enzyme_width = cx.create_metadata(b"enzyme_width");
args.push(cx.get_metadata_value(enzyme_width));
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
}
match_args_from_caller_to_enzyme(
&cx,
builder,
attrs.width,
&mut args,
&attrs.input_activity,
fn_args,
);
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
builder.store_to_place(call, dest.val);
}

View file

@ -8,7 +8,6 @@ use std::str;
use rustc_abi::{HasDataLayout, Size, TargetDataLayout, VariantIdx};
use rustc_codegen_ssa::back::versioned_llvm_target;
use rustc_codegen_ssa::base::{wants_msvc_seh, wants_wasm_eh};
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::errors as ssa_errors;
use rustc_codegen_ssa::traits::*;
use rustc_data_structures::base_n::{ALPHANUMERIC_ONLY, ToBaseN};
@ -660,10 +659,6 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
}
}
impl<'ll> SimpleCx<'ll> {
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
assert_eq!(self.type_kind(ty), TypeKind::Function);
unsafe { llvm::LLVMGetReturnType(ty) }
}
pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type {
unsafe { llvm::LLVMGlobalGetValueType(val) }
}
@ -727,16 +722,6 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len())
}
}
pub(crate) fn get_functions(&self) -> Vec<&'ll Value> {
let mut functions = vec![];
let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) };
while let Some(f) = func {
functions.push(f);
func = unsafe { llvm::LLVMGetNextFunction(f) }
}
functions
}
}
impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {

View file

@ -3,24 +3,29 @@ use std::cmp::Ordering;
use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size};
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
use rustc_codegen_ssa::traits::*;
use rustc_hir as hir;
use rustc_hir::def_id::LOCAL_CRATE;
use rustc_hir::{self as hir};
use rustc_middle::mir::BinOp;
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
use rustc_middle::ty::{self, GenericArgsRef, Ty};
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
use rustc_middle::{bug, span_bug};
use rustc_span::{Span, Symbol, sym};
use rustc_symbol_mangling::mangle_internal_symbol;
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
use rustc_target::callconv::PassMode;
use rustc_target::spec::PanicStrategy;
use tracing::debug;
use crate::abi::FnAbiLlvmExt;
use crate::builder::Builder;
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
use crate::context::CodegenCx;
use crate::errors::AutoDiffWithoutEnable;
use crate::llvm::{self, Metadata};
use crate::type_::Type;
use crate::type_of::LayoutLlvmExt;
@ -189,6 +194,10 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
&[ptr, args[1].immediate()],
)
}
sym::autodiff => {
codegen_autodiff(self, tcx, instance, args, result);
return Ok(());
}
sym::is_val_statically_known => {
if let OperandValue::Immediate(imm) = args[0].val {
self.call_intrinsic(
@ -1113,6 +1122,143 @@ fn get_rust_try_fn<'a, 'll, 'tcx>(
rust_try
}
fn codegen_autodiff<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
tcx: TyCtxt<'tcx>,
instance: ty::Instance<'tcx>,
args: &[OperandRef<'tcx, &'ll Value>],
result: PlaceRef<'tcx, &'ll Value>,
) {
if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) {
let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutEnable);
}
let fn_args = instance.args;
let callee_ty = instance.ty(tcx, bx.typing_env());
let sig = callee_ty.fn_sig(tcx).skip_binder();
let ret_ty = sig.output();
let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);
// Get source, diff, and attrs
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
ty::FnDef(def_id, source_params) => (def_id, source_params),
_ => bug!("invalid autodiff intrinsic args"),
};
let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) {
Ok(Some(instance)) => instance,
Ok(None) => bug!(
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
source_id,
source_args
),
Err(_) => {
// An error has already been emitted
return;
}
};
let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else {
bug!("could not find source function")
};
let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
ty::FnDef(def_id, diff_args) => (def_id, diff_args),
_ => bug!("invalid args"),
};
let fn_diff = match Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args) {
Ok(Some(instance)) => instance,
Ok(None) => bug!(
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
diff_id,
diff_args
),
Err(_) => {
// An error has already been emitted
return;
}
};
let val_arr = get_args_from_tuple(bx, args[2], fn_diff);
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else {
bug!("could not find autodiff attrs")
};
adjust_activity_to_abi(
tcx,
fn_source.ty(tcx, TypingEnv::fully_monomorphized()),
&mut diff_attrs.input_activity,
);
// Build body
generate_enzyme_call(
bx,
bx.cx,
fn_to_diff,
&diff_symbol,
llret_ty,
&val_arr,
diff_attrs.clone(),
result,
);
}
fn get_args_from_tuple<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
tuple_op: OperandRef<'tcx, &'ll Value>,
fn_instance: Instance<'tcx>,
) -> Vec<&'ll Value> {
let cx = bx.cx;
let fn_abi = cx.fn_abi_of_instance(fn_instance, ty::List::empty());
match tuple_op.val {
OperandValue::Immediate(val) => vec![val],
OperandValue::Pair(v1, v2) => vec![v1, v2],
OperandValue::Ref(ptr) => {
let tuple_place = PlaceRef { val: ptr, layout: tuple_op.layout };
let mut result = Vec::with_capacity(fn_abi.args.len());
let mut tuple_index = 0;
for arg in &fn_abi.args {
match arg.mode {
PassMode::Ignore => {}
PassMode::Direct(_) | PassMode::Cast { .. } => {
let field = tuple_place.project_field(bx, tuple_index);
let llvm_ty = field.layout.llvm_type(bx.cx);
let val = bx.load(llvm_ty, field.val.llval, field.val.align);
result.push(val);
tuple_index += 1;
}
PassMode::Pair(_, _) => {
let field = tuple_place.project_field(bx, tuple_index);
let llvm_ty = field.layout.llvm_type(bx.cx);
let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align);
result.push(bx.extract_value(pair_val, 0));
result.push(bx.extract_value(pair_val, 1));
tuple_index += 1;
}
PassMode::Indirect { .. } => {
let field = tuple_place.project_field(bx, tuple_index);
result.push(field.val.llval);
tuple_index += 1;
}
}
}
result
}
OperandValue::ZeroSized => vec![],
}
}
fn generic_simd_intrinsic<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
name: Symbol,

View file

@ -30,7 +30,6 @@ use context::SimpleCx;
use errors::ParseTargetMachineConfig;
use llvm_util::target_config;
use rustc_ast::expand::allocator::AllocatorKind;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_codegen_ssa::back::lto::{SerializedModule, ThinModule};
use rustc_codegen_ssa::back::write::{
CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn,
@ -173,15 +172,10 @@ impl WriteBackendMethods for LlvmCodegenBackend {
exported_symbols_for_lto: &[String],
each_linked_rlib_for_lto: &[PathBuf],
modules: Vec<FatLtoInput<Self>>,
diff_fncs: Vec<AutoDiffItem>,
) -> Result<ModuleCodegen<Self::Module>, FatalError> {
let mut module =
back::lto::run_fat(cgcx, exported_symbols_for_lto, each_linked_rlib_for_lto, modules)?;
if !diff_fncs.is_empty() {
builder::autodiff::differentiate(&module, cgcx, diff_fncs)?;
}
let dcx = cgcx.create_dcx();
let dcx = dcx.handle();
back::lto::run_pass_manager(cgcx, dcx, &mut module, false)?;

View file

@ -42,32 +42,6 @@ pub(crate) fn AddFunctionAttributes<'ll>(
}
}
pub(crate) fn HasAttributeAtIndex<'ll>(
llfn: &'ll Value,
idx: AttributePlace,
kind: AttributeKind,
) -> bool {
unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) }
}
pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: &str) -> bool {
unsafe { LLVMRustHasFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
}
pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: &str) {
unsafe { LLVMRustRemoveFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
}
pub(crate) fn RemoveRustEnumAttributeAtIndex(
llfn: &Value,
place: AttributePlace,
kind: AttributeKind,
) {
unsafe {
LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind);
}
}
pub(crate) fn AddCallSiteAttributes<'ll>(
callsite: &'ll Value,
idx: AttributePlace,

View file

@ -26,6 +26,7 @@ rustc_hashes = { path = "../rustc_hashes" }
rustc_hir = { path = "../rustc_hir" }
rustc_incremental = { path = "../rustc_incremental" }
rustc_index = { path = "../rustc_index" }
rustc_lint_defs = { path = "../rustc_lint_defs" }
rustc_macros = { path = "../rustc_macros" }
rustc_metadata = { path = "../rustc_metadata" }
rustc_middle = { path = "../rustc_middle" }

View file

@ -7,7 +7,6 @@ use std::{fs, io, mem, str, thread};
use rustc_abi::Size;
use rustc_ast::attr;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_data_structures::fx::FxIndexMap;
use rustc_data_structures::jobserver::{self, Acquired};
use rustc_data_structures::memmap::Mmap;
@ -38,7 +37,7 @@ use tracing::debug;
use super::link::{self, ensure_removed};
use super::lto::{self, SerializedModule};
use crate::back::lto::check_lto_allowed;
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
use crate::errors::ErrorCreatingRemarkDir;
use crate::traits::*;
use crate::{
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
@ -76,12 +75,9 @@ pub struct ModuleConfig {
/// Names of additional optimization passes to run.
pub passes: Vec<String>,
/// Some(level) to optimize at a certain level, or None to run
/// absolutely no optimizations (used for the metadata module).
/// absolutely no optimizations (used for the allocator module).
pub opt_level: Option<config::OptLevel>,
/// Some(level) to optimize binary size, or None to not affect program size.
pub opt_size: Option<config::OptLevel>,
pub pgo_gen: SwitchWithOptPath,
pub pgo_use: Option<PathBuf>,
pub pgo_sample_use: Option<PathBuf>,
@ -102,7 +98,6 @@ pub struct ModuleConfig {
pub emit_obj: EmitObj,
pub emit_thin_lto: bool,
pub emit_thin_lto_summary: bool,
pub bc_cmdline: String,
// Miscellaneous flags. These are mostly copied from command-line
// options.
@ -110,7 +105,6 @@ pub struct ModuleConfig {
pub lint_llvm_ir: bool,
pub no_prepopulate_passes: bool,
pub no_builtins: bool,
pub time_module: bool,
pub vectorize_loop: bool,
pub vectorize_slp: bool,
pub merge_functions: bool,
@ -171,7 +165,6 @@ impl ModuleConfig {
passes: if_regular!(sess.opts.cg.passes.clone(), vec![]),
opt_level: opt_level_and_size,
opt_size: opt_level_and_size,
pgo_gen: if_regular!(
sess.opts.cg.profile_generate.clone(),
@ -221,17 +214,12 @@ impl ModuleConfig {
sess.opts.output_types.contains_key(&OutputType::ThinLinkBitcode),
false
),
bc_cmdline: sess.target.bitcode_llvm_cmdline.to_string(),
verify_llvm_ir: sess.verify_llvm_ir(),
lint_llvm_ir: sess.opts.unstable_opts.lint_llvm_ir,
no_prepopulate_passes: sess.opts.cg.no_prepopulate_passes,
no_builtins: no_builtins || sess.target.no_builtins,
// Exclude metadata and allocator modules from time_passes output,
// since they throw off the "LLVM passes" measurement.
time_module: if_regular!(true, false),
// Copy what clang does by turning on loop vectorization at O2 and
// slp vectorization at O3.
vectorize_loop: !sess.opts.cg.no_vectorize_loops
@ -454,7 +442,6 @@ pub(crate) fn start_async_codegen<B: ExtraBackendMethods>(
backend: B,
tcx: TyCtxt<'_>,
target_cpu: String,
autodiff_items: &[AutoDiffItem],
) -> OngoingCodegen<B> {
let (coordinator_send, coordinator_receive) = channel();
@ -473,7 +460,6 @@ pub(crate) fn start_async_codegen<B: ExtraBackendMethods>(
backend.clone(),
tcx,
&crate_info,
autodiff_items,
shared_emitter,
codegen_worker_send,
coordinator_receive,
@ -728,7 +714,6 @@ pub(crate) enum WorkItem<B: WriteBackendMethods> {
each_linked_rlib_for_lto: Vec<PathBuf>,
needs_fat_lto: Vec<FatLtoInput<B>>,
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
autodiff: Vec<AutoDiffItem>,
},
/// Performs thin-LTO on the given module.
ThinLto(lto::ThinModule<B>),
@ -1001,7 +986,6 @@ fn execute_fat_lto_work_item<B: ExtraBackendMethods>(
each_linked_rlib_for_lto: &[PathBuf],
mut needs_fat_lto: Vec<FatLtoInput<B>>,
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
autodiff: Vec<AutoDiffItem>,
module_config: &ModuleConfig,
) -> Result<WorkItemResult<B>, FatalError> {
for (module, wp) in import_only_modules {
@ -1013,7 +997,6 @@ fn execute_fat_lto_work_item<B: ExtraBackendMethods>(
exported_symbols_for_lto,
each_linked_rlib_for_lto,
needs_fat_lto,
autodiff,
)?;
let module = B::codegen(cgcx, module, module_config)?;
Ok(WorkItemResult::Finished(module))
@ -1105,7 +1088,6 @@ fn start_executing_work<B: ExtraBackendMethods>(
backend: B,
tcx: TyCtxt<'_>,
crate_info: &CrateInfo,
autodiff_items: &[AutoDiffItem],
shared_emitter: SharedEmitter,
codegen_worker_send: Sender<CguMessage>,
coordinator_receive: Receiver<Message<B>>,
@ -1115,7 +1097,6 @@ fn start_executing_work<B: ExtraBackendMethods>(
) -> thread::JoinHandle<Result<CompiledModules, ()>> {
let coordinator_send = tx_to_llvm_workers;
let sess = tcx.sess;
let autodiff_items = autodiff_items.to_vec();
let mut each_linked_rlib_for_lto = Vec::new();
let mut each_linked_rlib_file_for_lto = Vec::new();
@ -1448,7 +1429,6 @@ fn start_executing_work<B: ExtraBackendMethods>(
each_linked_rlib_for_lto: each_linked_rlib_file_for_lto,
needs_fat_lto,
import_only_modules,
autodiff: autodiff_items.clone(),
},
0,
));
@ -1456,11 +1436,6 @@ fn start_executing_work<B: ExtraBackendMethods>(
helper.request_token();
}
} else {
if !autodiff_items.is_empty() {
let dcx = cgcx.create_dcx();
dcx.handle().emit_fatal(AutodiffWithoutLto {});
}
for (work, cost) in generate_thin_lto_work(
&cgcx,
&exported_symbols_for_lto,
@ -1740,7 +1715,7 @@ fn spawn_work<'a, B: ExtraBackendMethods>(
llvm_start_time: &mut Option<VerboseTimingGuard<'a>>,
work: WorkItem<B>,
) {
if cgcx.config(work.module_kind()).time_module && llvm_start_time.is_none() {
if llvm_start_time.is_none() {
*llvm_start_time = Some(cgcx.prof.verbose_generic_activity("LLVM_passes"));
}
@ -1795,7 +1770,6 @@ fn spawn_work<'a, B: ExtraBackendMethods>(
each_linked_rlib_for_lto,
needs_fat_lto,
import_only_modules,
autodiff,
} => {
let _timer = cgcx
.prof
@ -1806,7 +1780,6 @@ fn spawn_work<'a, B: ExtraBackendMethods>(
&each_linked_rlib_for_lto,
needs_fat_lto,
import_only_modules,
autodiff,
module_config,
)
}

View file

@ -647,7 +647,7 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
) -> OngoingCodegen<B> {
// Skip crate items and just output metadata in -Z no-codegen mode.
if tcx.sess.opts.unstable_opts.no_codegen || !tcx.sess.opts.output_types.should_codegen() {
let ongoing_codegen = start_async_codegen(backend, tcx, target_cpu, &[]);
let ongoing_codegen = start_async_codegen(backend, tcx, target_cpu);
ongoing_codegen.codegen_finished(tcx);
@ -665,8 +665,7 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
// Run the monomorphization collector and partition the collected items into
// codegen units.
let MonoItemPartitions { codegen_units, autodiff_items, .. } =
tcx.collect_and_partition_mono_items(());
let MonoItemPartitions { codegen_units, .. } = tcx.collect_and_partition_mono_items(());
// Force all codegen_unit queries so they are already either red or green
// when compile_codegen_unit accesses them. We are not able to re-execute
@ -679,34 +678,7 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
}
}
let ongoing_codegen = start_async_codegen(backend.clone(), tcx, target_cpu, autodiff_items);
// Codegen an allocator shim, if necessary.
if let Some(kind) = allocator_kind_for_codegen(tcx) {
let llmod_id =
cgu_name_builder.build_cgu_name(LOCAL_CRATE, &["crate"], Some("allocator")).to_string();
let module_llvm = tcx.sess.time("write_allocator_module", || {
backend.codegen_allocator(
tcx,
&llmod_id,
kind,
// If allocator_kind is Some then alloc_error_handler_kind must
// also be Some.
tcx.alloc_error_handler_kind(()).unwrap(),
)
});
ongoing_codegen.wait_for_signal_to_codegen_item();
ongoing_codegen.check_for_errors(tcx.sess);
// These modules are generally cheap and won't throw off scheduling.
let cost = 0;
submit_codegened_module_to_llvm(
&ongoing_codegen.coordinator,
ModuleCodegen::new_allocator(llmod_id, module_llvm),
cost,
);
}
let ongoing_codegen = start_async_codegen(backend.clone(), tcx, target_cpu);
// For better throughput during parallel processing by LLVM, we used to sort
// CGUs largest to smallest. This would lead to better thread utilization
@ -823,6 +795,35 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
}
}
// Codegen an allocator shim, if necessary.
// Do this last to ensure the LLVM_passes timer doesn't start while no CGUs have been codegened
// yet for the backend to optimize.
if let Some(kind) = allocator_kind_for_codegen(tcx) {
let llmod_id =
cgu_name_builder.build_cgu_name(LOCAL_CRATE, &["crate"], Some("allocator")).to_string();
let module_llvm = tcx.sess.time("write_allocator_module", || {
backend.codegen_allocator(
tcx,
&llmod_id,
kind,
// If allocator_kind is Some then alloc_error_handler_kind must
// also be Some.
tcx.alloc_error_handler_kind(()).unwrap(),
)
});
ongoing_codegen.wait_for_signal_to_codegen_item();
ongoing_codegen.check_for_errors(tcx.sess);
// These modules are generally cheap and won't throw off scheduling.
let cost = 0;
submit_codegened_module_to_llvm(
&ongoing_codegen.coordinator,
ModuleCodegen::new_allocator(llmod_id, module_llvm),
cost,
);
}
ongoing_codegen.codegen_finished(tcx);
// Since the main thread is sometimes blocked during codegen, we keep track

View file

@ -176,14 +176,6 @@ fn process_builtin_attrs(
let mut interesting_spans = InterestingAttributeDiagnosticSpans::default();
let rust_target_features = tcx.rust_target_features(LOCAL_CRATE);
// If our rustc version supports autodiff/enzyme, then we call our handler
// to check for any `#[rustc_autodiff(...)]` attributes.
// FIXME(jdonszelmann): merge with loop below
if cfg!(llvm_enzyme) {
let ad = autodiff_attrs(tcx, did.into());
codegen_fn_attrs.autodiff_item = ad;
}
for attr in attrs.iter() {
if let hir::Attribute::Parsed(p) = attr {
match p {
@ -609,7 +601,7 @@ fn inherited_align<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Align> {
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
/// panic, unless we introduced a bug when parsing the autodiff macro.
//FIXME(jdonszelmann): put in the main loop. No need to have two..... :/ Let's do that when we make autodiff parsed.
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::<Vec<_>>();

View file

@ -5,6 +5,7 @@ use rustc_ast as ast;
use rustc_ast::{InlineAsmOptions, InlineAsmTemplatePiece};
use rustc_data_structures::packed::Pu128;
use rustc_hir::lang_items::LangItem;
use rustc_lint_defs::builtin::TAIL_CALL_TRACK_CALLER;
use rustc_middle::mir::{self, AssertKind, InlineAsmMacro, SwitchTargets, UnwindTerminateReason};
use rustc_middle::ty::layout::{HasTyCtxt, LayoutOf, ValidityRequirement};
use rustc_middle::ty::print::{with_no_trimmed_paths, with_no_visible_paths};
@ -35,7 +36,7 @@ enum MergingSucc {
True,
}
/// Indicates to the call terminator codegen whether a cal
/// Indicates to the call terminator codegen whether a call
/// is a normal call or an explicit tail call.
#[derive(Debug, PartialEq)]
enum CallKind {
@ -906,7 +907,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
fn_span,
);
let instance = match instance.def {
match instance.def {
// We don't need AsyncDropGlueCtorShim here because it is not `noop func`,
// it is `func returning noop future`
ty::InstanceKind::DropGlue(_, None) => {
@ -995,14 +996,35 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
intrinsic.name,
);
}
instance
(Some(instance), None)
}
}
}
_ => instance,
};
(Some(instance), None)
_ if kind == CallKind::Tail
&& instance.def.requires_caller_location(bx.tcx()) =>
{
if let Some(hir_id) =
terminator.source_info.scope.lint_root(&self.mir.source_scopes)
{
let msg = "tail calling a function marked with `#[track_caller]` has no special effect";
bx.tcx().node_lint(TAIL_CALL_TRACK_CALLER, hir_id, |d| {
_ = d.primary_message(msg).span(fn_span)
});
}
let instance = ty::Instance::resolve_for_fn_ptr(
bx.tcx(),
bx.typing_env(),
def_id,
generic_args,
)
.unwrap();
(None, Some(bx.get_fn_addr(instance)))
}
_ => (Some(instance), None),
}
}
ty::FnPtr(..) => (None, Some(callee.immediate())),
_ => bug!("{} is not callable", callee.layout.ty),

View file

@ -1,6 +1,5 @@
use std::path::PathBuf;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_errors::{DiagCtxtHandle, FatalError};
use rustc_middle::dep_graph::WorkProduct;
@ -23,7 +22,6 @@ pub trait WriteBackendMethods: Clone + 'static {
exported_symbols_for_lto: &[String],
each_linked_rlib_for_lto: &[PathBuf],
modules: Vec<FatLtoInput<Self>>,
diff_fncs: Vec<AutoDiffItem>,
) -> Result<ModuleCodegen<Self::Module>, FatalError>;
/// Performs thin LTO by performing necessary global analysis and returning two
/// lists, one of the modules that need optimization and another for modules that

View file

@ -57,7 +57,7 @@ const_eval_const_context = {$kind ->
}
const_eval_const_heap_ptr_in_final = encountered `const_allocate` pointer in final value that was not made global
.note = use `const_make_global` to make allocated pointers immutable before returning
.note = use `const_make_global` to turn allocated pointers into immutable globals before returning
const_eval_const_make_global_ptr_already_made_global = attempting to call `const_make_global` twice on the same allocation {$alloc}
@ -231,6 +231,9 @@ const_eval_mutable_borrow_escaping =
const_eval_mutable_ptr_in_final = encountered mutable pointer in final value of {const_eval_intern_kind}
const_eval_partial_pointer_in_final = encountered partial pointer in final value of {const_eval_intern_kind}
.note = while pointers can be broken apart into individual bytes during const-evaluation, only complete pointers (with all their bytes in the right order) are supported in the final value
const_eval_nested_static_in_thread_local = #[thread_local] does not support implicit nested statics, please create explicit static items and refer to them instead
const_eval_non_const_await =
@ -299,10 +302,8 @@ const_eval_panic = evaluation panicked: {$msg}
const_eval_panic_non_str = argument to `panic!()` in a const context must have type `&str`
const_eval_partial_pointer_copy =
unable to copy parts of a pointer from memory at {$ptr}
const_eval_partial_pointer_overwrite =
unable to overwrite parts of a pointer in memory at {$ptr}
const_eval_partial_pointer_read =
unable to read parts of a pointer from memory at {$ptr}
const_eval_pointer_arithmetic_overflow =
overflowing pointer arithmetic: the total offset in bytes does not fit in an `isize`

View file

@ -827,7 +827,7 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
// At this point, we are calling a function, `callee`, whose `DefId` is known...
// `begin_panic` and `#[rustc_const_panic_str]` functions accept generic
// `begin_panic` and `panic_display` functions accept generic
// types other than str. Check to enforce that only str can be used in
// const-eval.
@ -841,8 +841,8 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
return;
}
// const-eval of `#[rustc_const_panic_str]` functions assumes the argument is `&&str`
if tcx.has_attr(callee, sym::rustc_const_panic_str) {
// const-eval of `panic_display` assumes the argument is `&&str`
if tcx.is_lang_item(callee, LangItem::PanicDisplay) {
match args[0].node.ty(&self.ccx.body.local_decls, tcx).kind() {
ty::Ref(_, ty, _) if matches!(ty.kind(), ty::Ref(_, ty, _) if ty.is_str()) =>
{}

View file

@ -117,6 +117,13 @@ fn eval_body_using_ecx<'tcx, R: InterpretationResult<'tcx>>(
ecx.tcx.dcx().emit_err(errors::ConstHeapPtrInFinal { span: ecx.tcx.span }),
)));
}
Err(InternError::PartialPointer) => {
throw_inval!(AlreadyReported(ReportedErrorInfo::non_const_eval_error(
ecx.tcx
.dcx()
.emit_err(errors::PartialPtrInFinal { span: ecx.tcx.span, kind: intern_kind }),
)));
}
}
interp_ok(R::make_result(ret, ecx))

View file

@ -237,7 +237,7 @@ impl<'tcx> CompileTimeInterpCx<'tcx> {
) -> InterpResult<'tcx, Option<ty::Instance<'tcx>>> {
let def_id = instance.def_id();
if self.tcx.has_attr(def_id, sym::rustc_const_panic_str)
if self.tcx.is_lang_item(def_id, LangItem::PanicDisplay)
|| self.tcx.is_lang_item(def_id, LangItem::BeginPanic)
{
let args = self.copy_fn_args(args);

View file

@ -51,6 +51,15 @@ pub(crate) struct ConstHeapPtrInFinal {
pub span: Span,
}
#[derive(Diagnostic)]
#[diag(const_eval_partial_pointer_in_final)]
#[note]
pub(crate) struct PartialPtrInFinal {
#[primary_span]
pub span: Span,
pub kind: InternKind,
}
#[derive(Diagnostic)]
#[diag(const_eval_unstable_in_stable_exposed)]
pub(crate) struct UnstableInStableExposed {
@ -836,8 +845,7 @@ impl ReportErrorExt for UnsupportedOpInfo {
UnsupportedOpInfo::Unsupported(s) => s.clone().into(),
UnsupportedOpInfo::ExternTypeField => const_eval_extern_type_field,
UnsupportedOpInfo::UnsizedLocal => const_eval_unsized_local,
UnsupportedOpInfo::OverwritePartialPointer(_) => const_eval_partial_pointer_overwrite,
UnsupportedOpInfo::ReadPartialPointer(_) => const_eval_partial_pointer_copy,
UnsupportedOpInfo::ReadPartialPointer(_) => const_eval_partial_pointer_read,
UnsupportedOpInfo::ReadPointerAsInt(_) => const_eval_read_pointer_as_int,
UnsupportedOpInfo::ThreadLocalStatic(_) => const_eval_thread_local_static,
UnsupportedOpInfo::ExternStatic(_) => const_eval_extern_static,
@ -848,7 +856,7 @@ impl ReportErrorExt for UnsupportedOpInfo {
use UnsupportedOpInfo::*;
use crate::fluent_generated::*;
if let ReadPointerAsInt(_) | OverwritePartialPointer(_) | ReadPartialPointer(_) = self {
if let ReadPointerAsInt(_) | ReadPartialPointer(_) = self {
diag.help(const_eval_ptr_as_bytes_1);
diag.help(const_eval_ptr_as_bytes_2);
}
@ -860,7 +868,7 @@ impl ReportErrorExt for UnsupportedOpInfo {
| UnsupportedOpInfo::ExternTypeField
| Unsupported(_)
| ReadPointerAsInt(_) => {}
OverwritePartialPointer(ptr) | ReadPartialPointer(ptr) => {
ReadPartialPointer(ptr) => {
diag.arg("ptr", ptr);
}
ThreadLocalStatic(did) | ExternStatic(did) => rustc_middle::ty::tls::with(|tcx| {

View file

@ -19,9 +19,12 @@ use rustc_data_structures::fx::{FxHashSet, FxIndexMap};
use rustc_hir as hir;
use rustc_hir::definitions::{DefPathData, DisambiguatorState};
use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs;
use rustc_middle::mir::interpret::{ConstAllocation, CtfeProvenance, InterpResult};
use rustc_middle::mir::interpret::{
AllocBytes, ConstAllocation, CtfeProvenance, InterpResult, Provenance,
};
use rustc_middle::query::TyCtxtAt;
use rustc_middle::span_bug;
use rustc_middle::ty::TyCtxt;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_span::def_id::LocalDefId;
use tracing::{instrument, trace};
@ -52,6 +55,45 @@ impl HasStaticRootDefId for const_eval::CompileTimeMachine<'_> {
}
}
fn prepare_alloc<'tcx, Prov: Provenance, Extra, Bytes: AllocBytes>(
tcx: TyCtxt<'tcx>,
kind: MemoryKind<const_eval::MemoryKind>,
alloc: &mut Allocation<Prov, Extra, Bytes>,
mutability: Mutability,
) -> Result<(), InternError> {
match kind {
MemoryKind::Machine(const_eval::MemoryKind::Heap { was_made_global }) => {
if !was_made_global {
// Attempting to intern a `const_allocate`d pointer that was not made global via
// `const_make_global`.
tcx.dcx().delayed_bug("non-global heap allocation in const value");
return Err(InternError::ConstAllocNotGlobal);
}
}
MemoryKind::Stack | MemoryKind::CallerLocation => {}
}
if !alloc.provenance_merge_bytes(&tcx) {
// Per-byte provenance is not supported by backends, so we cannot accept it here.
tcx.dcx().delayed_bug("partial pointer in const value");
return Err(InternError::PartialPointer);
}
// Set allocation mutability as appropriate. This is used by LLVM to put things into
// read-only memory, and also by Miri when evaluating other globals that
// access this one.
match mutability {
Mutability::Not => {
alloc.mutability = Mutability::Not;
}
Mutability::Mut => {
// This must be already mutable, we won't "un-freeze" allocations ever.
assert_eq!(alloc.mutability, Mutability::Mut);
}
}
Ok(())
}
/// Intern an allocation. Returns `Err` if the allocation does not exist in the local memory.
///
/// `mutability` can be used to force immutable interning: if it is `Mutability::Not`, the
@ -72,31 +114,13 @@ fn intern_shallow<'tcx, M: CompileTimeMachine<'tcx>>(
return Err(InternError::DanglingPointer);
};
match kind {
MemoryKind::Machine(const_eval::MemoryKind::Heap { was_made_global }) => {
if !was_made_global {
// Attempting to intern a `const_allocate`d pointer that was not made global via
// `const_make_global`. We want to error here, but we have to first put the
// allocation back into the `alloc_map` to keep things in a consistent state.
ecx.memory.alloc_map.insert(alloc_id, (kind, alloc));
return Err(InternError::ConstAllocNotGlobal);
}
}
MemoryKind::Stack | MemoryKind::CallerLocation => {}
if let Err(err) = prepare_alloc(*ecx.tcx, kind, &mut alloc, mutability) {
// We want to error here, but we have to first put the
// allocation back into the `alloc_map` to keep things in a consistent state.
ecx.memory.alloc_map.insert(alloc_id, (kind, alloc));
return Err(err);
}
// Set allocation mutability as appropriate. This is used by LLVM to put things into
// read-only memory, and also by Miri when evaluating other globals that
// access this one.
match mutability {
Mutability::Not => {
alloc.mutability = Mutability::Not;
}
Mutability::Mut => {
// This must be already mutable, we won't "un-freeze" allocations ever.
assert_eq!(alloc.mutability, Mutability::Mut);
}
}
// link the alloc id to the actual allocation
let alloc = ecx.tcx.mk_const_alloc(alloc);
if let Some(static_id) = ecx.machine.static_def_id() {
@ -166,6 +190,7 @@ pub enum InternError {
BadMutablePointer,
DanglingPointer,
ConstAllocNotGlobal,
PartialPointer,
}
/// Intern `ret` and everything it references.
@ -221,13 +246,11 @@ pub fn intern_const_alloc_recursive<'tcx, M: CompileTimeMachine<'tcx>>(
let mut todo: Vec<_> = if is_static {
// Do not steal the root allocation, we need it later to create the return value of `eval_static_initializer`.
// But still change its mutability to match the requested one.
let alloc = ecx.memory.alloc_map.get_mut(&base_alloc_id).unwrap();
alloc.1.mutability = base_mutability;
alloc.1.provenance().ptrs().iter().map(|&(_, prov)| prov).collect()
let (kind, alloc) = ecx.memory.alloc_map.get_mut(&base_alloc_id).unwrap();
prepare_alloc(*ecx.tcx, *kind, alloc, base_mutability)?;
alloc.provenance().ptrs().iter().map(|&(_, prov)| prov).collect()
} else {
intern_shallow(ecx, base_alloc_id, base_mutability, Some(&mut disambiguator))
.unwrap()
.collect()
intern_shallow(ecx, base_alloc_id, base_mutability, Some(&mut disambiguator))?.collect()
};
// We need to distinguish "has just been interned" from "was already in `tcx`",
// so we track this in a separate set.
@ -235,7 +258,6 @@ pub fn intern_const_alloc_recursive<'tcx, M: CompileTimeMachine<'tcx>>(
// Whether we encountered a bad mutable pointer.
// We want to first report "dangling" and then "mutable", so we need to delay reporting these
// errors.
let mut result = Ok(());
let mut found_bad_mutable_ptr = false;
// Keep interning as long as there are things to intern.
@ -310,20 +332,15 @@ pub fn intern_const_alloc_recursive<'tcx, M: CompileTimeMachine<'tcx>>(
// okay with losing some potential for immutability here. This can anyway only affect
// `static mut`.
just_interned.insert(alloc_id);
match intern_shallow(ecx, alloc_id, inner_mutability, Some(&mut disambiguator)) {
Ok(nested) => todo.extend(nested),
Err(err) => {
ecx.tcx.dcx().delayed_bug("error during const interning");
result = Err(err);
}
}
let next = intern_shallow(ecx, alloc_id, inner_mutability, Some(&mut disambiguator))?;
todo.extend(next);
}
if found_bad_mutable_ptr && result.is_ok() {
if found_bad_mutable_ptr {
// We found a mutable pointer inside a const where inner allocations should be immutable,
// and there was no other error. This should usually never happen! However, this can happen
// in unleash-miri mode, so report it as a normal error then.
if ecx.tcx.sess.opts.unstable_opts.unleash_the_miri_inside_of_you {
result = Err(InternError::BadMutablePointer);
return Err(InternError::BadMutablePointer);
} else {
span_bug!(
ecx.tcx.span,
@ -331,7 +348,7 @@ pub fn intern_const_alloc_recursive<'tcx, M: CompileTimeMachine<'tcx>>(
);
}
}
result
Ok(())
}
/// Intern `ret`. This function assumes that `ret` references no other allocation.

View file

@ -1314,29 +1314,20 @@ impl<'a, 'tcx, Prov: Provenance, Extra, Bytes: AllocBytes>
}
/// Mark the given sub-range (relative to this allocation reference) as uninitialized.
pub fn write_uninit(&mut self, range: AllocRange) -> InterpResult<'tcx> {
pub fn write_uninit(&mut self, range: AllocRange) {
let range = self.range.subrange(range);
self.alloc
.write_uninit(&self.tcx, range)
.map_err(|e| e.to_interp_error(self.alloc_id))
.into()
self.alloc.write_uninit(&self.tcx, range);
}
/// Mark the entire referenced range as uninitialized
pub fn write_uninit_full(&mut self) -> InterpResult<'tcx> {
self.alloc
.write_uninit(&self.tcx, self.range)
.map_err(|e| e.to_interp_error(self.alloc_id))
.into()
pub fn write_uninit_full(&mut self) {
self.alloc.write_uninit(&self.tcx, self.range);
}
/// Remove all provenance in the reference range.
pub fn clear_provenance(&mut self) -> InterpResult<'tcx> {
self.alloc
.clear_provenance(&self.tcx, self.range)
.map_err(|e| e.to_interp_error(self.alloc_id))
.into()
pub fn clear_provenance(&mut self) {
self.alloc.clear_provenance(&self.tcx, self.range);
}
}
@ -1427,11 +1418,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
// Side-step AllocRef and directly access the underlying bytes more efficiently.
// (We are staying inside the bounds here and all bytes do get overwritten so all is good.)
let alloc_id = alloc_ref.alloc_id;
let bytes = alloc_ref
.alloc
.get_bytes_unchecked_for_overwrite(&alloc_ref.tcx, alloc_ref.range)
.map_err(move |e| e.to_interp_error(alloc_id))?;
let bytes =
alloc_ref.alloc.get_bytes_unchecked_for_overwrite(&alloc_ref.tcx, alloc_ref.range);
// `zip` would stop when the first iterator ends; we want to definitely
// cover all of `bytes`.
for dest in bytes {
@ -1513,10 +1501,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
// `get_bytes_mut` will clear the provenance, which is correct,
// since we don't want to keep any provenance at the target.
// This will also error if copying partial provenance is not supported.
let provenance = src_alloc
.provenance()
.prepare_copy(src_range, dest_offset, num_copies, self)
.map_err(|e| e.to_interp_error(src_alloc_id))?;
let provenance =
src_alloc.provenance().prepare_copy(src_range, dest_offset, num_copies, self);
// Prepare a copy of the initialization mask.
let init = src_alloc.init_mask().prepare_copy(src_range);
@ -1534,10 +1520,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
dest_range,
)?;
// Yes we do overwrite all bytes in `dest_bytes`.
let dest_bytes = dest_alloc
.get_bytes_unchecked_for_overwrite_ptr(&tcx, dest_range)
.map_err(|e| e.to_interp_error(dest_alloc_id))?
.as_mut_ptr();
let dest_bytes =
dest_alloc.get_bytes_unchecked_for_overwrite_ptr(&tcx, dest_range).as_mut_ptr();
if init.no_bytes_init() {
// Fast path: If all bytes are `uninit` then there is nothing to copy. The target range
@ -1546,9 +1530,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
// This also avoids writing to the target bytes so that the backing allocation is never
// touched if the bytes stay uninitialized for the whole interpreter execution. On contemporary
// operating system this can avoid physically allocating the page.
dest_alloc
.write_uninit(&tcx, dest_range)
.map_err(|e| e.to_interp_error(dest_alloc_id))?;
dest_alloc.write_uninit(&tcx, dest_range);
// `write_uninit` also resets the provenance, so we are done.
return interp_ok(());
}

View file

@ -705,7 +705,7 @@ where
match value {
Immediate::Scalar(scalar) => {
alloc.write_scalar(alloc_range(Size::ZERO, scalar.size()), scalar)
alloc.write_scalar(alloc_range(Size::ZERO, scalar.size()), scalar)?;
}
Immediate::ScalarPair(a_val, b_val) => {
let BackendRepr::ScalarPair(a, b) = layout.backend_repr else {
@ -725,10 +725,10 @@ where
alloc.write_scalar(alloc_range(Size::ZERO, a_val.size()), a_val)?;
alloc.write_scalar(alloc_range(b_offset, b_val.size()), b_val)?;
// We don't have to reset padding here, `write_immediate` will anyway do a validation run.
interp_ok(())
}
Immediate::Uninit => alloc.write_uninit_full(),
}
interp_ok(())
}
pub fn write_uninit(
@ -748,7 +748,7 @@ where
// Zero-sized access
return interp_ok(());
};
alloc.write_uninit_full()?;
alloc.write_uninit_full();
}
}
interp_ok(())
@ -772,7 +772,7 @@ where
// Zero-sized access
return interp_ok(());
};
alloc.clear_provenance()?;
alloc.clear_provenance();
}
}
interp_ok(())

View file

@ -949,7 +949,7 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValidityVisitor<'rt, 'tcx, M> {
let padding_size = offset - padding_cleared_until;
let range = alloc_range(padding_start, padding_size);
trace!("reset_padding on {}: resetting padding range {range:?}", mplace.layout.ty);
alloc.write_uninit(range)?;
alloc.write_uninit(range);
}
padding_cleared_until = offset + size;
}
@ -1239,7 +1239,7 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValueVisitor<'tcx, M> for ValidityVisitor<'rt,
if self.reset_provenance_and_padding {
// We can't share this with above as above, we might be looking at read-only memory.
let mut alloc = self.ecx.get_ptr_alloc_mut(mplace.ptr(), size)?.expect("we already excluded size 0");
alloc.clear_provenance()?;
alloc.clear_provenance();
// Also, mark this as containing data, not padding.
self.add_data_range(mplace.ptr(), size);
}

View file

@ -5,7 +5,7 @@ use rustc_hir::def_id::CrateNum;
use rustc_hir::definitions::DisambiguatedDefPathData;
use rustc_middle::bug;
use rustc_middle::ty::print::{PrettyPrinter, PrintError, Printer};
use rustc_middle::ty::{self, GenericArg, GenericArgKind, Ty, TyCtxt};
use rustc_middle::ty::{self, GenericArg, Ty, TyCtxt};
struct TypeNamePrinter<'tcx> {
tcx: TyCtxt<'tcx>,
@ -18,9 +18,10 @@ impl<'tcx> Printer<'tcx> for TypeNamePrinter<'tcx> {
}
fn print_region(&mut self, _region: ty::Region<'_>) -> Result<(), PrintError> {
// This is reachable (via `pretty_print_dyn_existential`) even though
// `<Self As PrettyPrinter>::should_print_region` returns false. See #144994.
Ok(())
// FIXME: most regions have been erased by the time this code runs.
// Just printing `'_` is a bit hacky but gives mostly good results, and
// doing better is difficult. See `should_print_optional_region`.
write!(self, "'_")
}
fn print_type(&mut self, ty: Ty<'tcx>) -> Result<(), PrintError> {
@ -125,10 +126,8 @@ impl<'tcx> Printer<'tcx> for TypeNamePrinter<'tcx> {
args: &[GenericArg<'tcx>],
) -> Result<(), PrintError> {
print_prefix(self)?;
let args =
args.iter().cloned().filter(|arg| !matches!(arg.kind(), GenericArgKind::Lifetime(_)));
if args.clone().next().is_some() {
self.generic_delimiters(|cx| cx.comma_sep(args))
if !args.is_empty() {
self.generic_delimiters(|cx| cx.comma_sep(args.iter().copied()))
} else {
Ok(())
}
@ -136,8 +135,15 @@ impl<'tcx> Printer<'tcx> for TypeNamePrinter<'tcx> {
}
impl<'tcx> PrettyPrinter<'tcx> for TypeNamePrinter<'tcx> {
fn should_print_region(&self, _region: ty::Region<'_>) -> bool {
false
fn should_print_optional_region(&self, _region: ty::Region<'_>) -> bool {
// Bound regions are always printed (as `'_`), which gives some idea that they are special,
// even though the `for` is omitted by the pretty printer.
// E.g. `for<'a, 'b> fn(&'a u32, &'b u32)` is printed as "fn(&'_ u32, &'_ u32)".
match _region.kind() {
ty::ReErased => false,
ty::ReBound(..) => true,
_ => unreachable!(),
}
}
fn generic_delimiters(

View file

@ -1,4 +1,9 @@
An unaligned reference to a field of a [packed] struct got created.
An unaligned reference to a field of a [packed] `struct` or `union` was created.
The `#[repr(packed)]` attribute removes padding between fields, which can
cause fields to be stored at unaligned memory addresses. Creating references
to such fields violates Rust's memory safety guarantees and can lead to
undefined behavior in optimized code.
Erroneous code example:
@ -45,9 +50,36 @@ unsafe {
// For formatting, we can create a copy to avoid the direct reference.
let copy = foo.field1;
println!("{}", copy);
// Creating a copy can be written in a single line with curly braces.
// (This is equivalent to the two lines above.)
println!("{}", { foo.field1 });
// A reference to a field that will always be sufficiently aligned is safe:
println!("{}", foo.field2);
}
```
### Unions
Although creating a reference to a `union` field is `unsafe`, this error
will still be triggered if the referenced field is not sufficiently
aligned. Use `addr_of!` and raw pointers in the same way as for struct fields.
```compile_fail,E0793
#[repr(packed)]
pub union Foo {
field1: u64,
field2: u8,
}
unsafe {
let foo = Foo { field1: 0 };
// Accessing the field directly is fine.
let val = foo.field1;
// A reference to a packed union field causes an error.
let val = &foo.field1; // ERROR
}
```

View file

@ -1158,10 +1158,6 @@ pub static BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[
rustc_do_not_const_check, Normal, template!(Word), WarnFollowing,
EncodeCrossCrate::Yes, "`#[rustc_do_not_const_check]` skips const-check for this function's body",
),
rustc_attr!(
rustc_const_panic_str, Normal, template!(Word), WarnFollowing,
EncodeCrossCrate::Yes, "`#[rustc_const_panic_str]` ensures the argument to this function is &&str during const-check",
),
rustc_attr!(
rustc_const_stable_indirect, Normal,
template!(Word),

View file

@ -470,6 +470,8 @@ declare_features! (
(unstable, deprecated_suggestion, "1.61.0", Some(94785)),
/// Allows deref patterns.
(incomplete, deref_patterns, "1.79.0", Some(87121)),
/// Allows deriving the From trait on single-field structs.
(unstable, derive_from, "CURRENT_RUSTC_VERSION", Some(144889)),
/// Tells rustdoc to automatically generate `#[doc(cfg(...))]`.
(unstable, doc_auto_cfg, "1.58.0", Some(43781)),
/// Allows `#[doc(cfg(...))]`.

View file

@ -286,6 +286,7 @@ language_item_table! {
Panic, sym::panic, panic_fn, Target::Fn, GenericRequirement::Exact(0);
PanicNounwind, sym::panic_nounwind, panic_nounwind, Target::Fn, GenericRequirement::Exact(0);
PanicFmt, sym::panic_fmt, panic_fmt, Target::Fn, GenericRequirement::None;
PanicDisplay, sym::panic_display, panic_display, Target::Fn, GenericRequirement::None;
ConstPanicFmt, sym::const_panic_fmt, const_panic_fmt, Target::Fn, GenericRequirement::None;
PanicBoundsCheck, sym::panic_bounds_check, panic_bounds_check_fn, Target::Fn, GenericRequirement::Exact(0);
PanicMisalignedPointerDereference, sym::panic_misaligned_pointer_dereference, panic_misaligned_pointer_dereference_fn, Target::Fn, GenericRequirement::Exact(0);

View file

@ -135,6 +135,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
| sym::round_ties_even_f32
| sym::round_ties_even_f64
| sym::round_ties_even_f128
| sym::autodiff
| sym::const_eval_select => hir::Safety::Safe,
_ => hir::Safety::Unsafe,
};
@ -198,6 +199,7 @@ pub(crate) fn check_intrinsic_type(
let safety = intrinsic_operation_unsafety(tcx, intrinsic_id);
let n_lts = 0;
let (n_tps, n_cts, inputs, output) = match intrinsic_name {
sym::autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)),
sym::abort => (0, 0, vec![], tcx.types.never),
sym::unreachable => (0, 0, vec![], tcx.types.never),
sym::breakpoint => (0, 0, vec![], tcx.types.unit),

View file

@ -390,45 +390,13 @@ fn check_predicates<'tcx>(
let mut res = Ok(());
for (clause, span) in impl1_predicates {
if !impl2_predicates.iter().any(|pred2| trait_predicates_eq(clause.as_predicate(), *pred2))
{
if !impl2_predicates.iter().any(|&pred2| clause.as_predicate() == pred2) {
res = res.and(check_specialization_on(tcx, clause, span))
}
}
res
}
/// Checks if some predicate on the specializing impl (`predicate1`) is the same
/// as some predicate on the base impl (`predicate2`).
///
/// This basically just checks syntactic equivalence, but is a little more
/// forgiving since we want to equate `T: Tr` with `T: [const] Tr` so this can work:
///
/// ```ignore (illustrative)
/// #[rustc_specialization_trait]
/// trait Specialize { }
///
/// impl<T: Bound> Tr for T { }
/// impl<T: [const] Bound + Specialize> const Tr for T { }
/// ```
///
/// However, we *don't* want to allow the reverse, i.e., when the bound on the
/// specializing impl is not as const as the bound on the base impl:
///
/// ```ignore (illustrative)
/// impl<T: [const] Bound> const Tr for T { }
/// impl<T: Bound + Specialize> const Tr for T { } // should be T: [const] Bound
/// ```
///
/// So we make that check in this function and try to raise a helpful error message.
fn trait_predicates_eq<'tcx>(
predicate1: ty::Predicate<'tcx>,
predicate2: ty::Predicate<'tcx>,
) -> bool {
// FIXME(const_trait_impl)
predicate1 == predicate2
}
#[instrument(level = "debug", skip(tcx))]
fn check_specialization_on<'tcx>(
tcx: TyCtxt<'tcx>,

View file

@ -910,7 +910,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// const stability checking here too, I guess.
if self.tcx.is_conditionally_const(callee_did) {
let q = self.tcx.const_conditions(callee_did);
// FIXME(const_trait_impl): Use this span with a better cause code.
for (idx, (cond, pred_span)) in
q.instantiate(self.tcx, callee_args).into_iter().enumerate()
{

View file

@ -756,22 +756,22 @@ impl<'tcx> LateContext<'tcx> {
}
fn print_region(&mut self, _region: ty::Region<'_>) -> Result<(), PrintError> {
unreachable!(); // because `path_generic_args` ignores the `GenericArgs`
unreachable!(); // because `print_path_with_generic_args` ignores the `GenericArgs`
}
fn print_type(&mut self, _ty: Ty<'tcx>) -> Result<(), PrintError> {
unreachable!(); // because `path_generic_args` ignores the `GenericArgs`
unreachable!(); // because `print_path_with_generic_args` ignores the `GenericArgs`
}
fn print_dyn_existential(
&mut self,
_predicates: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> Result<(), PrintError> {
unreachable!(); // because `path_generic_args` ignores the `GenericArgs`
unreachable!(); // because `print_path_with_generic_args` ignores the `GenericArgs`
}
fn print_const(&mut self, _ct: ty::Const<'tcx>) -> Result<(), PrintError> {
unreachable!(); // because `path_generic_args` ignores the `GenericArgs`
unreachable!(); // because `print_path_with_generic_args` ignores the `GenericArgs`
}
fn print_crate_name(&mut self, cnum: CrateNum) -> Result<(), PrintError> {

View file

@ -151,7 +151,7 @@ impl<'tcx> LateLintPass<'tcx> for DropForgetUseless {
&& let Node::Stmt(stmt) = node
&& let StmtKind::Semi(e) = stmt.kind
&& e.hir_id == expr.hir_id
&& let Some(arg_span) = arg.span.find_ancestor_inside(expr.span)
&& let Some(arg_span) = arg.span.find_ancestor_inside_same_ctxt(expr.span)
{
UseLetUnderscoreIgnoreSuggestion::Suggestion {
start_span: expr.span.shrink_to_lo().until(arg_span),

View file

@ -5105,3 +5105,36 @@ declare_lint! {
report_in_deps: true,
};
}
declare_lint! {
/// The `tail_call_track_caller` lint detects usage of `become` attempting to tail call
/// a function marked with `#[track_caller]`.
///
/// ### Example
///
/// ```rust
/// #![feature(explicit_tail_calls)]
/// #![expect(incomplete_features)]
///
/// #[track_caller]
/// fn f() {}
///
/// fn g() {
/// become f();
/// }
///
/// g();
/// ```
///
/// {{produces}}
///
/// ### Explanation
///
/// Due to implementation details of tail calls and `#[track_caller]` attribute, calls to
/// functions marked with `#[track_caller]` cannot become tail calls. As such using `become`
/// is no different than a normal call (except for changes in drop order).
pub TAIL_CALL_TRACK_CALLER,
Warn,
"detects tail calls of functions marked with `#[track_caller]`",
@feature_gate = explicit_tail_calls;
}

View file

@ -1,7 +1,6 @@
use std::borrow::Cow;
use rustc_abi::Align;
use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs;
use rustc_hir::attrs::{InlineAttr, InstructionSetAttr, Linkage, OptimizeAttr};
use rustc_hir::def_id::DefId;
use rustc_macros::{HashStable, TyDecodable, TyEncodable};
@ -75,8 +74,6 @@ pub struct CodegenFnAttrs {
/// The `#[patchable_function_entry(...)]` attribute. Indicates how many nops should be around
/// the function entry.
pub patchable_function_entry: Option<PatchableFunctionEntry>,
/// For the `#[autodiff]` macros.
pub autodiff_item: Option<AutoDiffAttrs>,
}
#[derive(Copy, Clone, Debug, TyEncodable, TyDecodable, HashStable)]
@ -182,7 +179,6 @@ impl CodegenFnAttrs {
instruction_set: None,
alignment: None,
patchable_function_entry: None,
autodiff_item: None,
}
}

View file

@ -448,6 +448,11 @@ impl<'tcx> Const<'tcx> {
Self::Val(val, ty)
}
#[inline]
pub fn from_ty_value(tcx: TyCtxt<'tcx>, val: ty::Value<'tcx>) -> Self {
Self::Ty(val.ty, ty::Const::new_value(tcx, val.valtree, val.ty))
}
pub fn from_bits(
tcx: TyCtxt<'tcx>,
bits: u128,

View file

@ -306,8 +306,6 @@ pub enum AllocError {
ScalarSizeMismatch(ScalarSizeMismatch),
/// Encountered a pointer where we needed raw bytes.
ReadPointerAsInt(Option<BadBytesAccess>),
/// Partially overwriting a pointer.
OverwritePartialPointer(Size),
/// Partially copying a pointer.
ReadPartialPointer(Size),
/// Using uninitialized data where it is not allowed.
@ -331,9 +329,6 @@ impl AllocError {
ReadPointerAsInt(info) => InterpErrorKind::Unsupported(
UnsupportedOpInfo::ReadPointerAsInt(info.map(|b| (alloc_id, b))),
),
OverwritePartialPointer(offset) => InterpErrorKind::Unsupported(
UnsupportedOpInfo::OverwritePartialPointer(Pointer::new(alloc_id, offset)),
),
ReadPartialPointer(offset) => InterpErrorKind::Unsupported(
UnsupportedOpInfo::ReadPartialPointer(Pointer::new(alloc_id, offset)),
),
@ -633,11 +628,11 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
&mut self,
cx: &impl HasDataLayout,
range: AllocRange,
) -> AllocResult<&mut [u8]> {
) -> &mut [u8] {
self.mark_init(range, true);
self.provenance.clear(range, cx)?;
self.provenance.clear(range, cx);
Ok(&mut self.bytes[range.start.bytes_usize()..range.end().bytes_usize()])
&mut self.bytes[range.start.bytes_usize()..range.end().bytes_usize()]
}
/// A raw pointer variant of `get_bytes_unchecked_for_overwrite` that avoids invalidating existing immutable aliases
@ -646,15 +641,15 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
&mut self,
cx: &impl HasDataLayout,
range: AllocRange,
) -> AllocResult<*mut [u8]> {
) -> *mut [u8] {
self.mark_init(range, true);
self.provenance.clear(range, cx)?;
self.provenance.clear(range, cx);
assert!(range.end().bytes_usize() <= self.bytes.len()); // need to do our own bounds-check
// Crucially, we go via `AllocBytes::as_mut_ptr`, not `AllocBytes::deref_mut`.
let begin_ptr = self.bytes.as_mut_ptr().wrapping_add(range.start.bytes_usize());
let len = range.end().bytes_usize() - range.start.bytes_usize();
Ok(ptr::slice_from_raw_parts_mut(begin_ptr, len))
ptr::slice_from_raw_parts_mut(begin_ptr, len)
}
/// This gives direct mutable access to the entire buffer, just exposing their internal state
@ -723,26 +718,45 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
let ptr = Pointer::new(prov, Size::from_bytes(bits));
return Ok(Scalar::from_pointer(ptr, cx));
}
// If we can work on pointers byte-wise, join the byte-wise provenances.
if Prov::OFFSET_IS_ADDR {
let mut prov = self.provenance.get(range.start, cx);
// The other easy case is total absence of provenance.
if self.provenance.range_empty(range, cx) {
return Ok(Scalar::from_uint(bits, range.size));
}
// If we get here, we have to check per-byte provenance, and join them together.
let prov = 'prov: {
// Initialize with first fragment. Must have index 0.
let Some((mut joint_prov, 0)) = self.provenance.get_byte(range.start, cx) else {
break 'prov None;
};
// Update with the remaining fragments.
for offset in Size::from_bytes(1)..range.size {
let this_prov = self.provenance.get(range.start + offset, cx);
prov = Prov::join(prov, this_prov);
// Ensure there is provenance here and it has the right index.
let Some((frag_prov, frag_idx)) =
self.provenance.get_byte(range.start + offset, cx)
else {
break 'prov None;
};
// Wildcard provenance is allowed to come with any index (this is needed
// for Miri's native-lib mode to work).
if u64::from(frag_idx) != offset.bytes() && Some(frag_prov) != Prov::WILDCARD {
break 'prov None;
}
// Merge this byte's provenance with the previous ones.
joint_prov = match Prov::join(joint_prov, frag_prov) {
Some(prov) => prov,
None => break 'prov None,
};
}
// Now use this provenance.
let ptr = Pointer::new(prov, Size::from_bytes(bits));
return Ok(Scalar::from_maybe_pointer(ptr, cx));
} else {
// Without OFFSET_IS_ADDR, the only remaining case we can handle is total absence of
// provenance.
if self.provenance.range_empty(range, cx) {
return Ok(Scalar::from_uint(bits, range.size));
}
// Else we have mixed provenance, that doesn't work.
break 'prov Some(joint_prov);
};
if prov.is_none() && !Prov::OFFSET_IS_ADDR {
// There are some bytes with provenance here but overall the provenance does not add up.
// We need `OFFSET_IS_ADDR` to fall back to no-provenance here; without that option, we must error.
return Err(AllocError::ReadPartialPointer(range.start));
}
// We can use this provenance.
let ptr = Pointer::new(prov, Size::from_bytes(bits));
return Ok(Scalar::from_maybe_pointer(ptr, cx));
} else {
// We are *not* reading a pointer.
// If we can just ignore provenance or there is none, that's easy.
@ -782,7 +796,7 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
let endian = cx.data_layout().endian;
// Yes we do overwrite all the bytes in `dst`.
let dst = self.get_bytes_unchecked_for_overwrite(cx, range)?;
let dst = self.get_bytes_unchecked_for_overwrite(cx, range);
write_target_uint(endian, dst, bytes).unwrap();
// See if we have to also store some provenance.
@ -795,10 +809,9 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
}
/// Write "uninit" to the given memory range.
pub fn write_uninit(&mut self, cx: &impl HasDataLayout, range: AllocRange) -> AllocResult {
pub fn write_uninit(&mut self, cx: &impl HasDataLayout, range: AllocRange) {
self.mark_init(range, false);
self.provenance.clear(range, cx)?;
Ok(())
self.provenance.clear(range, cx);
}
/// Mark all bytes in the given range as initialised and reset the provenance
@ -817,9 +830,12 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
}
/// Remove all provenance in the given memory range.
pub fn clear_provenance(&mut self, cx: &impl HasDataLayout, range: AllocRange) -> AllocResult {
self.provenance.clear(range, cx)?;
return Ok(());
pub fn clear_provenance(&mut self, cx: &impl HasDataLayout, range: AllocRange) {
self.provenance.clear(range, cx);
}
pub fn provenance_merge_bytes(&mut self, cx: &impl HasDataLayout) -> bool {
self.provenance.merge_bytes(cx)
}
/// Applies a previously prepared provenance copy.

View file

@ -10,7 +10,7 @@ use rustc_macros::HashStable;
use rustc_serialize::{Decodable, Decoder, Encodable, Encoder};
use tracing::trace;
use super::{AllocError, AllocRange, AllocResult, CtfeProvenance, Provenance, alloc_range};
use super::{AllocRange, CtfeProvenance, Provenance, alloc_range};
/// Stores the provenance information of pointers stored in memory.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
@ -19,25 +19,25 @@ pub struct ProvenanceMap<Prov = CtfeProvenance> {
/// `Provenance` in this map applies from the given offset for an entire pointer-size worth of
/// bytes. Two entries in this map are always at least a pointer size apart.
ptrs: SortedMap<Size, Prov>,
/// Provenance in this map only applies to the given single byte.
/// This map is disjoint from the previous. It will always be empty when
/// `Prov::OFFSET_IS_ADDR` is false.
bytes: Option<Box<SortedMap<Size, Prov>>>,
/// This stores byte-sized provenance fragments.
/// The `u8` indicates the position of this byte inside its original pointer.
/// If the bytes are re-assembled in their original order, the pointer can be used again.
/// Wildcard provenance is allowed to have index 0 everywhere.
bytes: Option<Box<SortedMap<Size, (Prov, u8)>>>,
}
// These impls are generic over `Prov` since `CtfeProvenance` is only decodable/encodable
// for some particular `D`/`S`.
impl<D: Decoder, Prov: Provenance + Decodable<D>> Decodable<D> for ProvenanceMap<Prov> {
fn decode(d: &mut D) -> Self {
assert!(!Prov::OFFSET_IS_ADDR); // only `CtfeProvenance` is ever serialized
// `bytes` is not in the serialized format
Self { ptrs: Decodable::decode(d), bytes: None }
}
}
impl<S: Encoder, Prov: Provenance + Encodable<S>> Encodable<S> for ProvenanceMap<Prov> {
fn encode(&self, s: &mut S) {
let Self { ptrs, bytes } = self;
assert!(!Prov::OFFSET_IS_ADDR); // only `CtfeProvenance` is ever serialized
debug_assert!(bytes.is_none()); // without `OFFSET_IS_ADDR`, this is always empty
assert!(bytes.is_none()); // interning refuses allocations with pointer fragments
ptrs.encode(s)
}
}
@ -58,10 +58,10 @@ impl ProvenanceMap {
/// Give access to the ptr-sized provenances (which can also be thought of as relocations, and
/// indeed that is how codegen treats them).
///
/// Only exposed with `CtfeProvenance` provenance, since it panics if there is bytewise provenance.
/// Only use on interned allocations, as other allocations may have per-byte provenance!
#[inline]
pub fn ptrs(&self) -> &SortedMap<Size, CtfeProvenance> {
debug_assert!(self.bytes.is_none()); // `CtfeProvenance::OFFSET_IS_ADDR` is false so this cannot fail
assert!(self.bytes.is_none(), "`ptrs()` called on non-interned allocation");
&self.ptrs
}
}
@ -88,12 +88,12 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
}
/// `pm.range_ptrs_is_empty(r, cx)` == `pm.range_ptrs_get(r, cx).is_empty()`, but is faster.
pub(super) fn range_ptrs_is_empty(&self, range: AllocRange, cx: &impl HasDataLayout) -> bool {
fn range_ptrs_is_empty(&self, range: AllocRange, cx: &impl HasDataLayout) -> bool {
self.ptrs.range_is_empty(Self::adjusted_range_ptrs(range, cx))
}
/// Returns all byte-wise provenance in the given range.
fn range_bytes_get(&self, range: AllocRange) -> &[(Size, Prov)] {
fn range_bytes_get(&self, range: AllocRange) -> &[(Size, (Prov, u8))] {
if let Some(bytes) = self.bytes.as_ref() {
bytes.range(range.start..range.end())
} else {
@ -107,19 +107,47 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
}
/// Get the provenance of a single byte.
pub fn get(&self, offset: Size, cx: &impl HasDataLayout) -> Option<Prov> {
pub fn get_byte(&self, offset: Size, cx: &impl HasDataLayout) -> Option<(Prov, u8)> {
let prov = self.range_ptrs_get(alloc_range(offset, Size::from_bytes(1)), cx);
debug_assert!(prov.len() <= 1);
if let Some(entry) = prov.first() {
// If it overlaps with this byte, it is on this byte.
debug_assert!(self.bytes.as_ref().is_none_or(|b| !b.contains_key(&offset)));
Some(entry.1)
Some((entry.1, (offset - entry.0).bytes() as u8))
} else {
// Look up per-byte provenance.
self.bytes.as_ref().and_then(|b| b.get(&offset).copied())
}
}
/// Attempt to merge per-byte provenance back into ptr chunks, if the right fragments
/// sit next to each other. Return `false` is that is not possible due to partial pointers.
pub fn merge_bytes(&mut self, cx: &impl HasDataLayout) -> bool {
let Some(bytes) = self.bytes.as_deref_mut() else {
return true;
};
let ptr_size = cx.data_layout().pointer_size();
while let Some((offset, (prov, _))) = bytes.iter().next().copied() {
// Check if this fragment starts a pointer.
let range = offset..offset + ptr_size;
let frags = bytes.range(range.clone());
if frags.len() != ptr_size.bytes_usize() {
return false;
}
for (idx, (_offset, (frag_prov, frag_idx))) in frags.iter().copied().enumerate() {
if frag_prov != prov || frag_idx != idx as u8 {
return false;
}
}
// Looks like a pointer! Move it over to the ptr provenance map.
bytes.remove_range(range);
self.ptrs.insert(offset, prov);
}
// We managed to convert everything into whole pointers.
self.bytes = None;
true
}
/// Check if there is ptr-sized provenance at the given index.
/// Does not mean anything for bytewise provenance! But can be useful as an optimization.
pub fn get_ptr(&self, offset: Size) -> Option<Prov> {
@ -137,7 +165,7 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
/// Yields all the provenances stored in this map.
pub fn provenances(&self) -> impl Iterator<Item = Prov> {
let bytes = self.bytes.iter().flat_map(|b| b.values());
let bytes = self.bytes.iter().flat_map(|b| b.values().map(|(p, _i)| p));
self.ptrs.values().chain(bytes).copied()
}
@ -148,16 +176,12 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
/// Removes all provenance inside the given range.
/// If there is provenance overlapping with the edges, might result in an error.
pub fn clear(&mut self, range: AllocRange, cx: &impl HasDataLayout) -> AllocResult {
pub fn clear(&mut self, range: AllocRange, cx: &impl HasDataLayout) {
let start = range.start;
let end = range.end();
// Clear the bytewise part -- this is easy.
if Prov::OFFSET_IS_ADDR {
if let Some(bytes) = self.bytes.as_mut() {
bytes.remove_range(start..end);
}
} else {
debug_assert!(self.bytes.is_none());
if let Some(bytes) = self.bytes.as_mut() {
bytes.remove_range(start..end);
}
let pointer_size = cx.data_layout().pointer_size();
@ -168,7 +192,7 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
// Find all provenance overlapping the given range.
if self.range_ptrs_is_empty(range, cx) {
// No provenance in this range, we are done. This is the common case.
return Ok(());
return;
}
// This redoes some of the work of `range_get_ptrs_is_empty`, but this path is much
@ -179,28 +203,20 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
// We need to handle clearing the provenance from parts of a pointer.
if first < start {
if !Prov::OFFSET_IS_ADDR {
// We can't split up the provenance into less than a pointer.
return Err(AllocError::OverwritePartialPointer(first));
}
// Insert the remaining part in the bytewise provenance.
let prov = self.ptrs[&first];
let bytes = self.bytes.get_or_insert_with(Box::default);
for offset in first..start {
bytes.insert(offset, prov);
bytes.insert(offset, (prov, (offset - first).bytes() as u8));
}
}
if last > end {
let begin_of_last = last - pointer_size;
if !Prov::OFFSET_IS_ADDR {
// We can't split up the provenance into less than a pointer.
return Err(AllocError::OverwritePartialPointer(begin_of_last));
}
// Insert the remaining part in the bytewise provenance.
let prov = self.ptrs[&begin_of_last];
let bytes = self.bytes.get_or_insert_with(Box::default);
for offset in end..last {
bytes.insert(offset, prov);
bytes.insert(offset, (prov, (offset - begin_of_last).bytes() as u8));
}
}
@ -208,8 +224,6 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
// Since provenance do not overlap, we know that removing until `last` (exclusive) is fine,
// i.e., this will not remove any other provenance just after the ones we care about.
self.ptrs.remove_range(first..last);
Ok(())
}
/// Overwrites all provenance in the given range with wildcard provenance.
@ -218,10 +232,6 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
///
/// Provided for usage in Miri and panics otherwise.
pub fn write_wildcards(&mut self, cx: &impl HasDataLayout, range: AllocRange) {
assert!(
Prov::OFFSET_IS_ADDR,
"writing wildcard provenance is not supported when `OFFSET_IS_ADDR` is false"
);
let wildcard = Prov::WILDCARD.unwrap();
let bytes = self.bytes.get_or_insert_with(Box::default);
@ -229,21 +239,22 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
// Remove pointer provenances that overlap with the range, then readd the edge ones bytewise.
let ptr_range = Self::adjusted_range_ptrs(range, cx);
let ptrs = self.ptrs.range(ptr_range.clone());
if let Some((offset, prov)) = ptrs.first() {
for byte_ofs in *offset..range.start {
bytes.insert(byte_ofs, *prov);
if let Some((offset, prov)) = ptrs.first().copied() {
for byte_ofs in offset..range.start {
bytes.insert(byte_ofs, (prov, (byte_ofs - offset).bytes() as u8));
}
}
if let Some((offset, prov)) = ptrs.last() {
for byte_ofs in range.end()..*offset + cx.data_layout().pointer_size() {
bytes.insert(byte_ofs, *prov);
if let Some((offset, prov)) = ptrs.last().copied() {
for byte_ofs in range.end()..offset + cx.data_layout().pointer_size() {
bytes.insert(byte_ofs, (prov, (byte_ofs - offset).bytes() as u8));
}
}
self.ptrs.remove_range(ptr_range);
// Overwrite bytewise provenance.
for offset in range.start..range.end() {
bytes.insert(offset, wildcard);
// The fragment index does not matter for wildcard provenance.
bytes.insert(offset, (wildcard, 0));
}
}
}
@ -253,7 +264,7 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
/// Offsets are already adjusted to the destination allocation.
pub struct ProvenanceCopy<Prov> {
dest_ptrs: Option<Box<[(Size, Prov)]>>,
dest_bytes: Option<Box<[(Size, Prov)]>>,
dest_bytes: Option<Box<[(Size, (Prov, u8))]>>,
}
impl<Prov: Provenance> ProvenanceMap<Prov> {
@ -263,7 +274,7 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
dest: Size,
count: u64,
cx: &impl HasDataLayout,
) -> AllocResult<ProvenanceCopy<Prov>> {
) -> ProvenanceCopy<Prov> {
let shift_offset = move |idx, offset| {
// compute offset for current repetition
let dest_offset = dest + src.size * idx; // `Size` operations
@ -301,24 +312,16 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
let mut dest_bytes_box = None;
let begin_overlap = self.range_ptrs_get(alloc_range(src.start, Size::ZERO), cx).first();
let end_overlap = self.range_ptrs_get(alloc_range(src.end(), Size::ZERO), cx).first();
if !Prov::OFFSET_IS_ADDR {
// There can't be any bytewise provenance, and we cannot split up the begin/end overlap.
if let Some(entry) = begin_overlap {
return Err(AllocError::ReadPartialPointer(entry.0));
}
if let Some(entry) = end_overlap {
return Err(AllocError::ReadPartialPointer(entry.0));
}
debug_assert!(self.bytes.is_none());
} else {
let mut bytes = Vec::new();
// We only need to go here if there is some overlap or some bytewise provenance.
if begin_overlap.is_some() || end_overlap.is_some() || self.bytes.is_some() {
let mut bytes: Vec<(Size, (Prov, u8))> = Vec::new();
// First, if there is a part of a pointer at the start, add that.
if let Some(entry) = begin_overlap {
trace!("start overlapping entry: {entry:?}");
// For really small copies, make sure we don't run off the end of the `src` range.
let entry_end = cmp::min(entry.0 + ptr_size, src.end());
for offset in src.start..entry_end {
bytes.push((offset, entry.1));
bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
}
} else {
trace!("no start overlapping entry");
@ -334,8 +337,9 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
let entry_start = cmp::max(entry.0, src.start);
for offset in entry_start..src.end() {
if bytes.last().is_none_or(|bytes_entry| bytes_entry.0 < offset) {
// The last entry, if it exists, has a lower offset than us.
bytes.push((offset, entry.1));
// The last entry, if it exists, has a lower offset than us, so we
// can add it at the end and remain sorted.
bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
} else {
// There already is an entry for this offset in there! This can happen when the
// start and end range checks actually end up hitting the same pointer, so we
@ -358,7 +362,7 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
dest_bytes_box = Some(dest_bytes.into_boxed_slice());
}
Ok(ProvenanceCopy { dest_ptrs: dest_ptrs_box, dest_bytes: dest_bytes_box })
ProvenanceCopy { dest_ptrs: dest_ptrs_box, dest_bytes: dest_bytes_box }
}
/// Applies a provenance copy.
@ -368,14 +372,10 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
if let Some(dest_ptrs) = copy.dest_ptrs {
self.ptrs.insert_presorted(dest_ptrs.into());
}
if Prov::OFFSET_IS_ADDR {
if let Some(dest_bytes) = copy.dest_bytes
&& !dest_bytes.is_empty()
{
self.bytes.get_or_insert_with(Box::default).insert_presorted(dest_bytes.into());
}
} else {
debug_assert!(copy.dest_bytes.is_none());
if let Some(dest_bytes) = copy.dest_bytes
&& !dest_bytes.is_empty()
{
self.bytes.get_or_insert_with(Box::default).insert_presorted(dest_bytes.into());
}
}
}

View file

@ -582,9 +582,6 @@ pub enum UnsupportedOpInfo {
//
// The variants below are only reachable from CTFE/const prop, miri will never emit them.
//
/// Overwriting parts of a pointer; without knowing absolute addresses, the resulting state
/// cannot be represented by the CTFE interpreter.
OverwritePartialPointer(Pointer<AllocId>),
/// Attempting to read or copy parts of a pointer to somewhere else; without knowing absolute
/// addresses, the resulting state cannot be represented by the CTFE interpreter.
ReadPartialPointer(Pointer<AllocId>),

View file

@ -56,7 +56,7 @@ impl<T: HasDataLayout> PointerArithmetic for T {}
/// mostly opaque; the `Machine` trait extends it with some more operations that also have access to
/// some global state.
/// The `Debug` rendering is used to display bare provenance, and for the default impl of `fmt`.
pub trait Provenance: Copy + fmt::Debug + 'static {
pub trait Provenance: Copy + PartialEq + fmt::Debug + 'static {
/// Says whether the `offset` field of `Pointer`s with this provenance is the actual physical address.
/// - If `false`, the offset *must* be relative. This means the bytes representing a pointer are
/// different from what the Abstract Machine prescribes, so the interpreter must prevent any
@ -79,7 +79,7 @@ pub trait Provenance: Copy + fmt::Debug + 'static {
fn get_alloc_id(self) -> Option<AllocId>;
/// Defines the 'join' of provenance: what happens when doing a pointer load and different bytes have different provenance.
fn join(left: Option<Self>, right: Option<Self>) -> Option<Self>;
fn join(left: Self, right: Self) -> Option<Self>;
}
/// The type of provenance in the compile-time interpreter.
@ -192,8 +192,8 @@ impl Provenance for CtfeProvenance {
Some(self.alloc_id())
}
fn join(_left: Option<Self>, _right: Option<Self>) -> Option<Self> {
panic!("merging provenance is not supported when `OFFSET_IS_ADDR` is false")
fn join(left: Self, right: Self) -> Option<Self> {
if left == right { Some(left) } else { None }
}
}
@ -224,8 +224,8 @@ impl Provenance for AllocId {
Some(self)
}
fn join(_left: Option<Self>, _right: Option<Self>) -> Option<Self> {
panic!("merging provenance is not supported when `OFFSET_IS_ADDR` is false")
fn join(_left: Self, _right: Self) -> Option<Self> {
unreachable!()
}
}

View file

@ -2,7 +2,6 @@ use std::borrow::Cow;
use std::fmt;
use std::hash::Hash;
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_data_structures::base_n::{BaseNString, CASE_INSENSITIVE, ToBaseN};
use rustc_data_structures::fingerprint::Fingerprint;
use rustc_data_structures::fx::FxIndexMap;
@ -336,7 +335,6 @@ impl ToStableHashKey<StableHashingContext<'_>> for MonoItem<'_> {
pub struct MonoItemPartitions<'tcx> {
pub codegen_units: &'tcx [CodegenUnit<'tcx>],
pub all_mono_items: &'tcx DefIdSet,
pub autodiff_items: &'tcx [AutoDiffItem],
}
#[derive(Debug, HashStable)]

View file

@ -1789,7 +1789,7 @@ pub fn write_allocation_bytes<'tcx, Prov: Provenance, Extra, Bytes: AllocBytes>(
ascii.push('╼');
i += ptr_size;
}
} else if let Some(prov) = alloc.provenance().get(i, &tcx) {
} else if let Some((prov, idx)) = alloc.provenance().get_byte(i, &tcx) {
// Memory with provenance must be defined
assert!(
alloc.init_mask().is_range_initialized(alloc_range(i, Size::from_bytes(1))).is_ok()
@ -1799,7 +1799,7 @@ pub fn write_allocation_bytes<'tcx, Prov: Provenance, Extra, Bytes: AllocBytes>(
// Format is similar to "oversized" above.
let j = i.bytes_usize();
let c = alloc.inspect_with_uninit_and_ptr_outside_interpreter(j..j + 1)[0];
write!(w, "╾{c:02x}{prov:#?} (1 ptr byte)╼")?;
write!(w, "╾{c:02x}{prov:#?} (ptr fragment {idx})╼")?;
i += Size::from_bytes(1);
} else if alloc
.init_mask()

View file

@ -832,17 +832,17 @@ pub enum PatKind<'tcx> {
},
/// One of the following:
/// * `&str` (represented as a valtree), which will be handled as a string pattern and thus
/// * `&str`, which will be handled as a string pattern and thus
/// exhaustiveness checking will detect if you use the same string twice in different
/// patterns.
/// * integer, bool, char or float (represented as a valtree), which will be handled by
/// * integer, bool, char or float, which will be handled by
/// exhaustiveness to cover exactly its own value, similar to `&str`, but these values are
/// much simpler.
/// * raw pointers derived from integers, other raw pointers will have already resulted in an
// error.
/// * `String`, if `string_deref_patterns` is enabled.
Constant {
value: mir::Const<'tcx>,
value: ty::Value<'tcx>,
},
/// Pattern obtained by converting a constant (inline or named) to its pattern
@ -935,7 +935,7 @@ impl<'tcx> PatRange<'tcx> {
let lo_is_min = match self.lo {
PatRangeBoundary::NegInfinity => true,
PatRangeBoundary::Finite(value) => {
let lo = value.try_to_bits(size).unwrap() ^ bias;
let lo = value.try_to_scalar_int().unwrap().to_bits(size) ^ bias;
lo <= min
}
PatRangeBoundary::PosInfinity => false,
@ -944,7 +944,7 @@ impl<'tcx> PatRange<'tcx> {
let hi_is_max = match self.hi {
PatRangeBoundary::NegInfinity => false,
PatRangeBoundary::Finite(value) => {
let hi = value.try_to_bits(size).unwrap() ^ bias;
let hi = value.try_to_scalar_int().unwrap().to_bits(size) ^ bias;
hi > max || hi == max && self.end == RangeEnd::Included
}
PatRangeBoundary::PosInfinity => true,
@ -957,22 +957,17 @@ impl<'tcx> PatRange<'tcx> {
}
#[inline]
pub fn contains(
&self,
value: mir::Const<'tcx>,
tcx: TyCtxt<'tcx>,
typing_env: ty::TypingEnv<'tcx>,
) -> Option<bool> {
pub fn contains(&self, value: ty::Value<'tcx>, tcx: TyCtxt<'tcx>) -> Option<bool> {
use Ordering::*;
debug_assert_eq!(self.ty, value.ty());
debug_assert_eq!(value.ty, self.ty);
let ty = self.ty;
let value = PatRangeBoundary::Finite(value);
let value = PatRangeBoundary::Finite(value.valtree);
// For performance, it's important to only do the second comparison if necessary.
Some(
match self.lo.compare_with(value, ty, tcx, typing_env)? {
match self.lo.compare_with(value, ty, tcx)? {
Less | Equal => true,
Greater => false,
} && match value.compare_with(self.hi, ty, tcx, typing_env)? {
} && match value.compare_with(self.hi, ty, tcx)? {
Less => true,
Equal => self.end == RangeEnd::Included,
Greater => false,
@ -981,21 +976,16 @@ impl<'tcx> PatRange<'tcx> {
}
#[inline]
pub fn overlaps(
&self,
other: &Self,
tcx: TyCtxt<'tcx>,
typing_env: ty::TypingEnv<'tcx>,
) -> Option<bool> {
pub fn overlaps(&self, other: &Self, tcx: TyCtxt<'tcx>) -> Option<bool> {
use Ordering::*;
debug_assert_eq!(self.ty, other.ty);
// For performance, it's important to only do the second comparison if necessary.
Some(
match other.lo.compare_with(self.hi, self.ty, tcx, typing_env)? {
match other.lo.compare_with(self.hi, self.ty, tcx)? {
Less => true,
Equal => self.end == RangeEnd::Included,
Greater => false,
} && match self.lo.compare_with(other.hi, self.ty, tcx, typing_env)? {
} && match self.lo.compare_with(other.hi, self.ty, tcx)? {
Less => true,
Equal => other.end == RangeEnd::Included,
Greater => false,
@ -1006,11 +996,13 @@ impl<'tcx> PatRange<'tcx> {
impl<'tcx> fmt::Display for PatRange<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let PatRangeBoundary::Finite(value) = &self.lo {
if let &PatRangeBoundary::Finite(valtree) = &self.lo {
let value = ty::Value { ty: self.ty, valtree };
write!(f, "{value}")?;
}
if let PatRangeBoundary::Finite(value) = &self.hi {
if let &PatRangeBoundary::Finite(valtree) = &self.hi {
write!(f, "{}", self.end)?;
let value = ty::Value { ty: self.ty, valtree };
write!(f, "{value}")?;
} else {
// `0..` is parsed as an inclusive range, we must display it correctly.
@ -1024,7 +1016,8 @@ impl<'tcx> fmt::Display for PatRange<'tcx> {
/// If present, the const must be of a numeric type.
#[derive(Copy, Clone, Debug, PartialEq, HashStable, TypeVisitable)]
pub enum PatRangeBoundary<'tcx> {
Finite(mir::Const<'tcx>),
/// The type of this valtree is stored in the surrounding `PatRange`.
Finite(ty::ValTree<'tcx>),
NegInfinity,
PosInfinity,
}
@ -1035,20 +1028,15 @@ impl<'tcx> PatRangeBoundary<'tcx> {
matches!(self, Self::Finite(..))
}
#[inline]
pub fn as_finite(self) -> Option<mir::Const<'tcx>> {
pub fn as_finite(self) -> Option<ty::ValTree<'tcx>> {
match self {
Self::Finite(value) => Some(value),
Self::NegInfinity | Self::PosInfinity => None,
}
}
pub fn eval_bits(
self,
ty: Ty<'tcx>,
tcx: TyCtxt<'tcx>,
typing_env: ty::TypingEnv<'tcx>,
) -> u128 {
pub fn to_bits(self, ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> u128 {
match self {
Self::Finite(value) => value.eval_bits(tcx, typing_env),
Self::Finite(value) => value.try_to_scalar_int().unwrap().to_bits_unchecked(),
Self::NegInfinity => {
// Unwrap is ok because the type is known to be numeric.
ty.numeric_min_and_max_as_bits(tcx).unwrap().0
@ -1060,14 +1048,8 @@ impl<'tcx> PatRangeBoundary<'tcx> {
}
}
#[instrument(skip(tcx, typing_env), level = "debug", ret)]
pub fn compare_with(
self,
other: Self,
ty: Ty<'tcx>,
tcx: TyCtxt<'tcx>,
typing_env: ty::TypingEnv<'tcx>,
) -> Option<Ordering> {
#[instrument(skip(tcx), level = "debug", ret)]
pub fn compare_with(self, other: Self, ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> Option<Ordering> {
use PatRangeBoundary::*;
match (self, other) {
// When comparing with infinities, we must remember that `0u8..` and `0u8..=255`
@ -1095,8 +1077,8 @@ impl<'tcx> PatRangeBoundary<'tcx> {
_ => {}
}
let a = self.eval_bits(ty, tcx, typing_env);
let b = other.eval_bits(ty, tcx, typing_env);
let a = self.to_bits(ty, tcx);
let b = other.to_bits(ty, tcx);
match ty.kind() {
ty::Float(ty::FloatTy::F16) => {

View file

@ -2,10 +2,12 @@ use std::fmt;
use std::ops::Deref;
use rustc_data_structures::intern::Interned;
use rustc_hir::def::Namespace;
use rustc_macros::{HashStable, Lift, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable};
use super::ScalarInt;
use crate::mir::interpret::{ErrorHandled, Scalar};
use crate::ty::print::{FmtPrinter, PrettyPrinter};
use crate::ty::{self, Ty, TyCtxt};
/// This datastructure is used to represent the value of constants used in the type system.
@ -133,6 +135,8 @@ pub type ConstToValTreeResult<'tcx> = Result<Result<ValTree<'tcx>, Ty<'tcx>>, Er
/// A type-level constant value.
///
/// Represents a typed, fully evaluated constant.
/// Note that this is also used by pattern elaboration to represent values which cannot occur in types,
/// such as raw pointers and floats.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[derive(HashStable, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable, Lift)]
pub struct Value<'tcx> {
@ -203,3 +207,14 @@ impl<'tcx> rustc_type_ir::inherent::ValueConst<TyCtxt<'tcx>> for Value<'tcx> {
self.valtree
}
}
impl<'tcx> fmt::Display for Value<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
ty::tls::with(move |tcx| {
let cv = tcx.lift(*self).unwrap();
let mut p = FmtPrinter::new(tcx, Namespace::ValueNS);
p.pretty_print_const_valtree(cv, /*print_ty*/ true)?;
f.write_str(&p.into_buffer())
})
}
}

View file

@ -337,10 +337,10 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
false
}
/// Returns `true` if the region should be printed in
/// optional positions, e.g., `&'a T` or `dyn Tr + 'b`.
/// This is typically the case for all non-`'_` regions.
fn should_print_region(&self, region: ty::Region<'tcx>) -> bool;
/// Returns `true` if the region should be printed in optional positions,
/// e.g., `&'a T` or `dyn Tr + 'b`. (Regions like the one in `Cow<'static, T>`
/// will always be printed.)
fn should_print_optional_region(&self, region: ty::Region<'tcx>) -> bool;
fn reset_type_limit(&mut self) {}
@ -717,7 +717,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
}
ty::Ref(r, ty, mutbl) => {
write!(self, "&")?;
if self.should_print_region(r) {
if self.should_print_optional_region(r) {
r.print(self)?;
write!(self, " ")?;
}
@ -785,7 +785,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
},
ty::Adt(def, args) => self.print_def_path(def.did(), args)?,
ty::Dynamic(data, r, repr) => {
let print_r = self.should_print_region(r);
let print_r = self.should_print_optional_region(r);
if print_r {
write!(self, "(")?;
}
@ -2494,7 +2494,7 @@ impl<'tcx> PrettyPrinter<'tcx> for FmtPrinter<'_, 'tcx> {
!self.type_length_limit.value_within_limit(self.printed_type_count)
}
fn should_print_region(&self, region: ty::Region<'tcx>) -> bool {
fn should_print_optional_region(&self, region: ty::Region<'tcx>) -> bool {
let highlight = self.region_highlight_mode;
if highlight.region_highlighted(region).is_some() {
return true;

View file

@ -12,7 +12,6 @@ use rustc_hir::def_id::LocalDefId;
use rustc_span::source_map::Spanned;
use rustc_type_ir::{ConstKind, TypeFolder, VisitorResult, try_visit};
use super::print::PrettyPrinter;
use super::{GenericArg, GenericArgKind, Pattern, Region};
use crate::mir::PlaceElem;
use crate::ty::print::{FmtPrinter, Printer, with_no_trimmed_paths};
@ -168,15 +167,11 @@ impl<'tcx> fmt::Debug for ty::Const<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// If this is a value, we spend some effort to make it look nice.
if let ConstKind::Value(cv) = self.kind() {
return ty::tls::with(move |tcx| {
let cv = tcx.lift(cv).unwrap();
let mut p = FmtPrinter::new(tcx, Namespace::ValueNS);
p.pretty_print_const_valtree(cv, /*print_ty*/ true)?;
f.write_str(&p.into_buffer())
});
write!(f, "{}", cv)
} else {
// Fall back to something verbose.
write!(f, "{:?}", self.kind())
}
// Fall back to something verbose.
write!(f, "{:?}", self.kind())
}
}

View file

@ -160,7 +160,7 @@ impl<'a, 'tcx> ParseCtxt<'a, 'tcx> {
});
}
};
values.push(value.eval_bits(self.tcx, self.typing_env));
values.push(value.valtree.unwrap_leaf().to_bits_unchecked());
targets.push(self.parse_block(arm.body)?);
}

View file

@ -16,7 +16,7 @@ use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_hir::{BindingMode, ByRef, LetStmt, LocalSource, Node};
use rustc_middle::bug;
use rustc_middle::middle::region;
use rustc_middle::mir::{self, *};
use rustc_middle::mir::*;
use rustc_middle::thir::{self, *};
use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty, ValTree, ValTreeKind};
use rustc_pattern_analysis::constructor::RangeEnd;
@ -1245,7 +1245,7 @@ struct Ascription<'tcx> {
#[derive(Debug, Clone)]
enum TestCase<'tcx> {
Variant { adt_def: ty::AdtDef<'tcx>, variant_index: VariantIdx },
Constant { value: mir::Const<'tcx> },
Constant { value: ty::Value<'tcx> },
Range(Arc<PatRange<'tcx>>),
Slice { len: usize, variable_length: bool },
Deref { temp: Place<'tcx>, mutability: Mutability },
@ -1316,13 +1316,13 @@ enum TestKind<'tcx> {
If,
/// Test for equality with value, possibly after an unsizing coercion to
/// `ty`,
/// `cast_ty`,
Eq {
value: Const<'tcx>,
value: ty::Value<'tcx>,
// Integer types are handled by `SwitchInt`, and constants with ADT
// types and `&[T]` types are converted back into patterns, so this can
// only be `&str`, `f32` or `f64`.
ty: Ty<'tcx>,
// only be `&str` or floats.
cast_ty: Ty<'tcx>,
},
/// Test whether the value falls within an inclusive or exclusive range.
@ -1357,8 +1357,8 @@ pub(crate) struct Test<'tcx> {
enum TestBranch<'tcx> {
/// Success branch, used for tests with two possible outcomes.
Success,
/// Branch corresponding to this constant.
Constant(Const<'tcx>, u128),
/// Branch corresponding to this constant. Must be a scalar.
Constant(ty::Value<'tcx>),
/// Branch corresponding to this variant.
Variant(VariantIdx),
/// Failure branch for tests with two possible outcomes, and "otherwise" branch for other tests.
@ -1366,8 +1366,8 @@ enum TestBranch<'tcx> {
}
impl<'tcx> TestBranch<'tcx> {
fn as_constant(&self) -> Option<&Const<'tcx>> {
if let Self::Constant(v, _) = self { Some(v) } else { None }
fn as_constant(&self) -> Option<ty::Value<'tcx>> {
if let Self::Constant(v) = self { Some(*v) } else { None }
}
}

View file

@ -35,7 +35,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
TestCase::Constant { .. } if match_pair.pattern_ty.is_bool() => TestKind::If,
TestCase::Constant { .. } if is_switch_ty(match_pair.pattern_ty) => TestKind::SwitchInt,
TestCase::Constant { value } => TestKind::Eq { value, ty: match_pair.pattern_ty },
TestCase::Constant { value } => TestKind::Eq { value, cast_ty: match_pair.pattern_ty },
TestCase::Range(ref range) => {
assert_eq!(range.ty, match_pair.pattern_ty);
@ -112,7 +112,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let otherwise_block = target_block(TestBranch::Failure);
let switch_targets = SwitchTargets::new(
target_blocks.iter().filter_map(|(&branch, &block)| {
if let TestBranch::Constant(_, bits) = branch {
if let TestBranch::Constant(value) = branch {
let bits = value.valtree.unwrap_leaf().to_bits_unchecked();
Some((bits, block))
} else {
None
@ -135,17 +136,17 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
self.cfg.terminate(block, self.source_info(match_start_span), terminator);
}
TestKind::Eq { value, mut ty } => {
TestKind::Eq { value, mut cast_ty } => {
let tcx = self.tcx;
let success_block = target_block(TestBranch::Success);
let fail_block = target_block(TestBranch::Failure);
let mut expect_ty = value.ty();
let mut expect = self.literal_operand(test.span, value);
let mut expect_ty = value.ty;
let mut expect = self.literal_operand(test.span, Const::from_ty_value(tcx, value));
let mut place = place;
let mut block = block;
match ty.kind() {
match cast_ty.kind() {
ty::Str => {
// String literal patterns may have type `str` if `deref_patterns` is
// enabled, in order to allow `deref!("..."): String`. In this case, `value`
@ -167,7 +168,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
Rvalue::Ref(re_erased, BorrowKind::Shared, place),
);
place = ref_place;
ty = ref_str_ty;
cast_ty = ref_str_ty;
}
ty::Adt(def, _) if tcx.is_lang_item(def.did(), LangItem::String) => {
if !tcx.features().string_deref_patterns() {
@ -186,7 +187,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
eq_block,
place,
Mutability::Not,
ty,
cast_ty,
ref_str,
test.span,
);
@ -195,10 +196,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// Similarly, the normal test code should be generated for the `&str`, instead of the `String`.
block = eq_block;
place = ref_str;
ty = ref_str_ty;
cast_ty = ref_str_ty;
}
&ty::Pat(base, _) => {
assert_eq!(ty, value.ty());
assert_eq!(cast_ty, value.ty);
assert!(base.is_trivially_pure_clone_copy());
let transmuted_place = self.temp(base, test.span);
@ -219,14 +220,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
place = transmuted_place;
expect = Operand::Copy(transmuted_expect);
ty = base;
cast_ty = base;
expect_ty = base;
}
_ => {}
}
assert_eq!(expect_ty, ty);
if !ty.is_scalar() {
assert_eq!(expect_ty, cast_ty);
if !cast_ty.is_scalar() {
// Use `PartialEq::eq` instead of `BinOp::Eq`
// (the binop can only handle primitives)
// Make sure that we do *not* call any user-defined code here.
@ -234,10 +235,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// comparison defined in `core`.
// (Interestingly this means that exhaustiveness analysis relies, for soundness,
// on the `PartialEq` impl for `str` to b correct!)
match *ty.kind() {
match *cast_ty.kind() {
ty::Ref(_, deref_ty, _) if deref_ty == self.tcx.types.str_ => {}
_ => {
span_bug!(source_info.span, "invalid type for non-scalar compare: {ty}")
span_bug!(
source_info.span,
"invalid type for non-scalar compare: {cast_ty}"
)
}
};
self.string_compare(
@ -276,7 +280,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
};
if let Some(lo) = range.lo.as_finite() {
let lo = self.literal_operand(test.span, lo);
let lo = ty::Value { ty: range.ty, valtree: lo };
let lo = self.literal_operand(test.span, Const::from_ty_value(self.tcx, lo));
self.compare(
block,
intermediate_block,
@ -289,7 +294,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
};
if let Some(hi) = range.hi.as_finite() {
let hi = self.literal_operand(test.span, hi);
let hi = ty::Value { ty: range.ty, valtree: hi };
let hi = self.literal_operand(test.span, Const::from_ty_value(self.tcx, hi));
let op = match range.end {
RangeEnd::Included => BinOp::Le,
RangeEnd::Excluded => BinOp::Lt,
@ -555,10 +561,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// not to add such values here.
let is_covering_range = |test_case: &TestCase<'tcx>| {
test_case.as_range().is_some_and(|range| {
matches!(
range.contains(value, self.tcx, self.typing_env()),
None | Some(true)
)
matches!(range.contains(value, self.tcx), None | Some(true))
})
};
let is_conflicting_candidate = |candidate: &&mut Candidate<'tcx>| {
@ -575,8 +578,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
None
} else {
fully_matched = true;
let bits = value.eval_bits(self.tcx, self.typing_env());
Some(TestBranch::Constant(value, bits))
Some(TestBranch::Constant(value))
}
}
(TestKind::SwitchInt, TestCase::Range(range)) => {
@ -585,12 +587,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// the values being tested. (This restricts what values can be
// added to the test by subsequent candidates.)
fully_matched = false;
let not_contained =
sorted_candidates.keys().filter_map(|br| br.as_constant()).copied().all(
|val| {
matches!(range.contains(val, self.tcx, self.typing_env()), Some(false))
},
);
let not_contained = sorted_candidates
.keys()
.filter_map(|br| br.as_constant())
.all(|val| matches!(range.contains(val, self.tcx), Some(false)));
not_contained.then(|| {
// No switch values are contained in the pattern range,
@ -601,7 +601,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
(TestKind::If, TestCase::Constant { value }) => {
fully_matched = true;
let value = value.try_eval_bool(self.tcx, self.typing_env()).unwrap_or_else(|| {
let value = value.try_to_bool().unwrap_or_else(|| {
span_bug!(test.span, "expected boolean value but got {value:?}")
});
Some(if value { TestBranch::Success } else { TestBranch::Failure })
@ -681,16 +681,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
fully_matched = false;
// If the testing range does not overlap with pattern range,
// the pattern can be matched only if this test fails.
if !test.overlaps(pat, self.tcx, self.typing_env())? {
Some(TestBranch::Failure)
} else {
None
}
if !test.overlaps(pat, self.tcx)? { Some(TestBranch::Failure) } else { None }
}
}
(TestKind::Range(range), &TestCase::Constant { value }) => {
fully_matched = false;
if !range.contains(value, self.tcx, self.typing_env())? {
if !range.contains(value, self.tcx)? {
// `value` is not contained in the testing range,
// so `value` can be matched only if this test fails.
Some(TestBranch::Failure)

View file

@ -3,6 +3,7 @@ use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_errors::Applicability;
use rustc_hir::LangItem;
use rustc_hir::def::DefKind;
use rustc_hir::def_id::CRATE_DEF_ID;
use rustc_middle::span_bug;
use rustc_middle::thir::visit::{self, Visitor};
use rustc_middle::thir::{BodyTy, Expr, ExprId, ExprKind, Thir};
@ -136,7 +137,15 @@ impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
if caller_sig.inputs_and_output != callee_sig.inputs_and_output {
if caller_sig.inputs() != callee_sig.inputs() {
self.report_arguments_mismatch(expr.span, caller_sig, callee_sig);
self.report_arguments_mismatch(
expr.span,
self.tcx.liberate_late_bound_regions(
CRATE_DEF_ID.to_def_id(),
self.caller_ty.fn_sig(self.tcx),
),
self.tcx
.liberate_late_bound_regions(CRATE_DEF_ID.to_def_id(), ty.fn_sig(self.tcx)),
);
}
// FIXME(explicit_tail_calls): this currently fails for cases where opaques are used.

View file

@ -11,11 +11,11 @@ use rustc_index::Idx;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_infer::traits::Obligation;
use rustc_middle::mir::interpret::ErrorHandled;
use rustc_middle::span_bug;
use rustc_middle::thir::{FieldPat, Pat, PatKind};
use rustc_middle::ty::{
self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitableExt, TypeVisitor, ValTree,
};
use rustc_middle::{mir, span_bug};
use rustc_span::def_id::DefId;
use rustc_span::{DUMMY_SP, Span};
use rustc_trait_selection::traits::ObligationCause;
@ -288,16 +288,12 @@ impl<'tcx> ConstToPat<'tcx> {
// when lowering to MIR in `Builder::perform_test`, treat the constant as a `&str`.
// This works because `str` and `&str` have the same valtree representation.
let ref_str_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, ty);
PatKind::Constant {
value: mir::Const::Ty(ref_str_ty, ty::Const::new_value(tcx, cv, ref_str_ty)),
}
PatKind::Constant { value: ty::Value { ty: ref_str_ty, valtree: cv } }
}
ty::Ref(_, pointee_ty, ..) => match *pointee_ty.kind() {
// `&str` is represented as a valtree, let's keep using this
// optimization for now.
ty::Str => PatKind::Constant {
value: mir::Const::Ty(ty, ty::Const::new_value(tcx, cv, ty)),
},
ty::Str => PatKind::Constant { value: ty::Value { ty, valtree: cv } },
// All other references are converted into deref patterns and then recursively
// convert the dereferenced constant to a pattern that is the sub-pattern of the
// deref pattern.
@ -326,15 +322,13 @@ impl<'tcx> ConstToPat<'tcx> {
// Also see <https://github.com/rust-lang/rfcs/pull/3535>.
return self.mk_err(tcx.dcx().create_err(NaNPattern { span }), ty);
} else {
PatKind::Constant {
value: mir::Const::Ty(ty, ty::Const::new_value(tcx, cv, ty)),
}
PatKind::Constant { value: ty::Value { ty, valtree: cv } }
}
}
ty::Pat(..) | ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::RawPtr(..) => {
// The raw pointers we see here have been "vetted" by valtree construction to be
// just integers, so we simply allow them.
PatKind::Constant { value: mir::Const::Ty(ty, ty::Const::new_value(tcx, cv, ty)) }
PatKind::Constant { value: ty::Value { ty, valtree: cv } }
}
ty::FnPtr(..) => {
unreachable!(

View file

@ -161,8 +161,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
format!("found bad range pattern endpoint `{expr:?}` outside of error recovery");
return Err(self.tcx.dcx().span_delayed_bug(expr.span, msg));
};
Ok(Some(PatRangeBoundary::Finite(value)))
Ok(Some(PatRangeBoundary::Finite(value.valtree)))
}
/// Overflowing literals are linted against in a late pass. This is mostly fine, except when we
@ -235,7 +234,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
let lo = lower_endpoint(lo_expr)?.unwrap_or(PatRangeBoundary::NegInfinity);
let hi = lower_endpoint(hi_expr)?.unwrap_or(PatRangeBoundary::PosInfinity);
let cmp = lo.compare_with(hi, ty, self.tcx, self.typing_env);
let cmp = lo.compare_with(hi, ty, self.tcx);
let mut kind = PatKind::Range(Arc::new(PatRange { lo, hi, end, ty }));
match (end, cmp) {
// `x..y` where `x < y`.
@ -244,7 +243,8 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
(RangeEnd::Included, Some(Ordering::Less)) => {}
// `x..=y` where `x == y` and `x` and `y` are finite.
(RangeEnd::Included, Some(Ordering::Equal)) if lo.is_finite() && hi.is_finite() => {
kind = PatKind::Constant { value: lo.as_finite().unwrap() };
let value = ty::Value { ty, valtree: lo.as_finite().unwrap() };
kind = PatKind::Constant { value };
}
// `..=x` where `x == ty::MIN`.
(RangeEnd::Included, Some(Ordering::Equal)) if !lo.is_finite() => {}

View file

@ -763,7 +763,7 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
}
PatKind::Constant { value } => {
print_indented!(self, "Constant {", depth_lvl + 1);
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
print_indented!(self, format!("value: {}", value), depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
PatKind::ExpandedConstant { def_id, subpattern } => {

View file

@ -1,51 +1,36 @@
use rustc_index::IndexVec;
use rustc_middle::mir::coverage::{BlockMarkerId, BranchSpan, CoverageInfoHi, CoverageKind};
use rustc_middle::mir::coverage::{
BlockMarkerId, BranchSpan, CoverageInfoHi, CoverageKind, Mapping, MappingKind,
};
use rustc_middle::mir::{self, BasicBlock, StatementKind};
use rustc_middle::ty::TyCtxt;
use rustc_span::Span;
use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph};
use crate::coverage::graph::CoverageGraph;
use crate::coverage::hir_info::ExtractedHirInfo;
use crate::coverage::spans::extract_refined_covspans;
use crate::coverage::unexpand::unexpand_into_body_span;
/// Associates an ordinary executable code span with its corresponding BCB.
#[derive(Debug)]
pub(super) struct CodeMapping {
pub(super) span: Span,
pub(super) bcb: BasicCoverageBlock,
}
#[derive(Debug)]
pub(super) struct BranchPair {
pub(super) span: Span,
pub(super) true_bcb: BasicCoverageBlock,
pub(super) false_bcb: BasicCoverageBlock,
}
#[derive(Default)]
pub(super) struct ExtractedMappings {
pub(super) code_mappings: Vec<CodeMapping>,
pub(super) branch_pairs: Vec<BranchPair>,
pub(crate) struct ExtractedMappings {
pub(crate) mappings: Vec<Mapping>,
}
/// Extracts coverage-relevant spans from MIR, and associates them with
/// their corresponding BCBs.
pub(super) fn extract_all_mapping_info_from_mir<'tcx>(
/// Extracts coverage-relevant spans from MIR, and uses them to create
/// coverage mapping data for inclusion in MIR.
pub(crate) fn extract_mappings_from_mir<'tcx>(
tcx: TyCtxt<'tcx>,
mir_body: &mir::Body<'tcx>,
hir_info: &ExtractedHirInfo,
graph: &CoverageGraph,
) -> ExtractedMappings {
let mut code_mappings = vec![];
let mut branch_pairs = vec![];
let mut mappings = vec![];
// Extract ordinary code mappings from MIR statement/terminator spans.
extract_refined_covspans(tcx, mir_body, hir_info, graph, &mut code_mappings);
extract_refined_covspans(tcx, mir_body, hir_info, graph, &mut mappings);
branch_pairs.extend(extract_branch_pairs(mir_body, hir_info, graph));
extract_branch_mappings(mir_body, hir_info, graph, &mut mappings);
ExtractedMappings { code_mappings, branch_pairs }
ExtractedMappings { mappings }
}
fn resolve_block_markers(
@ -69,19 +54,18 @@ fn resolve_block_markers(
block_markers
}
pub(super) fn extract_branch_pairs(
pub(super) fn extract_branch_mappings(
mir_body: &mir::Body<'_>,
hir_info: &ExtractedHirInfo,
graph: &CoverageGraph,
) -> Vec<BranchPair> {
let Some(coverage_info_hi) = mir_body.coverage_info_hi.as_deref() else { return vec![] };
mappings: &mut Vec<Mapping>,
) {
let Some(coverage_info_hi) = mir_body.coverage_info_hi.as_deref() else { return };
let block_markers = resolve_block_markers(coverage_info_hi, mir_body);
coverage_info_hi
.branch_spans
.iter()
.filter_map(|&BranchSpan { span: raw_span, true_marker, false_marker }| {
mappings.extend(coverage_info_hi.branch_spans.iter().filter_map(
|&BranchSpan { span: raw_span, true_marker, false_marker }| try {
// For now, ignore any branch span that was introduced by
// expansion. This makes things like assert macros less noisy.
if !raw_span.ctxt().outer_expn_data().is_root() {
@ -94,7 +78,7 @@ pub(super) fn extract_branch_pairs(
let true_bcb = bcb_from_marker(true_marker)?;
let false_bcb = bcb_from_marker(false_marker)?;
Some(BranchPair { span, true_bcb, false_bcb })
})
.collect::<Vec<_>>()
Mapping { span, kind: MappingKind::Branch { true_bcb, false_bcb } }
},
));
}

View file

@ -1,4 +1,4 @@
use rustc_middle::mir::coverage::{CoverageKind, FunctionCoverageInfo, Mapping, MappingKind};
use rustc_middle::mir::coverage::{CoverageKind, FunctionCoverageInfo};
use rustc_middle::mir::{self, BasicBlock, Statement, StatementKind, TerminatorKind};
use rustc_middle::ty::TyCtxt;
use tracing::{debug, debug_span, trace};
@ -71,10 +71,8 @@ fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir:
////////////////////////////////////////////////////
// Extract coverage spans and other mapping info from MIR.
let extracted_mappings =
mappings::extract_all_mapping_info_from_mir(tcx, mir_body, &hir_info, &graph);
let mappings = create_mappings(&extracted_mappings);
let ExtractedMappings { mappings } =
mappings::extract_mappings_from_mir(tcx, mir_body, &hir_info, &graph);
if mappings.is_empty() {
// No spans could be converted into valid mappings, so skip this function.
debug!("no spans could be converted into valid mappings; skipping");
@ -100,34 +98,6 @@ fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir:
}));
}
/// For each coverage span extracted from MIR, create a corresponding mapping.
///
/// FIXME(Zalathar): This used to be where BCBs in the extracted mappings were
/// resolved to a `CovTerm`. But that is now handled elsewhere, so this
/// function can potentially be simplified even further.
fn create_mappings(extracted_mappings: &ExtractedMappings) -> Vec<Mapping> {
// Fully destructure the mappings struct to make sure we don't miss any kinds.
let ExtractedMappings { code_mappings, branch_pairs } = extracted_mappings;
let mut mappings = Vec::new();
mappings.extend(code_mappings.iter().map(
// Ordinary code mappings are the simplest kind.
|&mappings::CodeMapping { span, bcb }| {
let kind = MappingKind::Code { bcb };
Mapping { kind, span }
},
));
mappings.extend(branch_pairs.iter().map(
|&mappings::BranchPair { span, true_bcb, false_bcb }| {
let kind = MappingKind::Branch { true_bcb, false_bcb };
Mapping { kind, span }
},
));
mappings
}
/// Inject any necessary coverage statements into MIR, so that they influence codegen.
fn inject_coverage_statements<'tcx>(mir_body: &mut mir::Body<'tcx>, graph: &CoverageGraph) {
for (bcb, data) in graph.iter_enumerated() {

View file

@ -1,6 +1,6 @@
use rustc_data_structures::fx::FxHashSet;
use rustc_middle::mir;
use rustc_middle::mir::coverage::START_BCB;
use rustc_middle::mir::coverage::{Mapping, MappingKind, START_BCB};
use rustc_middle::ty::TyCtxt;
use rustc_span::source_map::SourceMap;
use rustc_span::{BytePos, DesugaringKind, ExpnKind, MacroKind, Span};
@ -9,7 +9,7 @@ use tracing::instrument;
use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph};
use crate::coverage::hir_info::ExtractedHirInfo;
use crate::coverage::spans::from_mir::{Hole, RawSpanFromMir, SpanFromMir};
use crate::coverage::{mappings, unexpand};
use crate::coverage::unexpand;
mod from_mir;
@ -18,7 +18,7 @@ pub(super) fn extract_refined_covspans<'tcx>(
mir_body: &mir::Body<'tcx>,
hir_info: &ExtractedHirInfo,
graph: &CoverageGraph,
code_mappings: &mut Vec<mappings::CodeMapping>,
mappings: &mut Vec<Mapping>,
) {
if hir_info.is_async_fn {
// An async function desugars into a function that returns a future,
@ -26,7 +26,7 @@ pub(super) fn extract_refined_covspans<'tcx>(
// outer function will be unhelpful, so just keep the signature span
// and ignore all of the spans in the MIR body.
if let Some(span) = hir_info.fn_sig_span {
code_mappings.push(mappings::CodeMapping { span, bcb: START_BCB });
mappings.push(Mapping { span, kind: MappingKind::Code { bcb: START_BCB } })
}
return;
}
@ -111,9 +111,9 @@ pub(super) fn extract_refined_covspans<'tcx>(
// Merge covspans that can be merged.
covspans.dedup_by(|b, a| a.merge_if_eligible(b));
code_mappings.extend(covspans.into_iter().map(|Covspan { span, bcb }| {
mappings.extend(covspans.into_iter().map(|Covspan { span, bcb }| {
// Each span produced by the refiner represents an ordinary code region.
mappings::CodeMapping { span, bcb }
Mapping { span, kind: MappingKind::Code { bcb } }
}));
}

View file

@ -41,19 +41,40 @@ fn to_profiler_name(type_name: &'static str) -> &'static str {
})
}
// const wrapper for `if let Some((_, tail)) = name.rsplit_once(':') { tail } else { name }`
const fn c_name(name: &'static str) -> &'static str {
// A function that simplifies a pass's type_name. E.g. `Baz`, `Baz<'_>`,
// `foo::bar::Baz`, and `foo::bar::Baz<'a, 'b>` all become `Baz`.
//
// It's `const` for perf reasons: it's called a lot, and doing the string
// operations at runtime causes a non-trivial slowdown. If
// `split_once`/`rsplit_once` become `const` its body could be simplified to
// this:
// ```ignore (fragment)
// let name = if let Some((_, tail)) = name.rsplit_once(':') { tail } else { name };
// let name = if let Some((head, _)) = name.split_once('<') { head } else { name };
// name
// ```
const fn simplify_pass_type_name(name: &'static str) -> &'static str {
// FIXME(const-hack) Simplify the implementation once more `str` methods get const-stable.
// and inline into call site
// Work backwards from the end. If a ':' is hit, strip it and everything before it.
let bytes = name.as_bytes();
let mut i = bytes.len();
while i > 0 && bytes[i - 1] != b':' {
i = i - 1;
i -= 1;
}
let (_, bytes) = bytes.split_at(i);
// Work forwards from the start of what's left. If a '<' is hit, strip it and everything after
// it.
let mut i = 0;
while i < bytes.len() && bytes[i] != b'<' {
i += 1;
}
let (bytes, _) = bytes.split_at(i);
match std::str::from_utf8(bytes) {
Ok(name) => name,
Err(_) => name,
Err(_) => panic!(),
}
}
@ -62,12 +83,7 @@ const fn c_name(name: &'static str) -> &'static str {
/// loop that goes over each available MIR and applies `run_pass`.
pub(super) trait MirPass<'tcx> {
fn name(&self) -> &'static str {
// FIXME(const-hack) Simplify the implementation once more `str` methods get const-stable.
// See copypaste in `MirLint`
const {
let name = std::any::type_name::<Self>();
c_name(name)
}
const { simplify_pass_type_name(std::any::type_name::<Self>()) }
}
fn profiler_name(&self) -> &'static str {
@ -101,12 +117,7 @@ pub(super) trait MirPass<'tcx> {
/// disabled (via the `Lint` adapter).
pub(super) trait MirLint<'tcx> {
fn name(&self) -> &'static str {
// FIXME(const-hack) Simplify the implementation once more `str` methods get const-stable.
// See copypaste in `MirPass`
const {
let name = std::any::type_name::<Self>();
c_name(name)
}
const { simplify_pass_type_name(std::any::type_name::<Self>()) }
}
fn is_enabled(&self, _sess: &Session) -> bool {

View file

@ -6,7 +6,6 @@ edition = "2024"
[dependencies]
# tidy-alphabetical-start
rustc_abi = { path = "../rustc_abi" }
rustc_ast = { path = "../rustc_ast" }
rustc_data_structures = { path = "../rustc_data_structures" }
rustc_errors = { path = "../rustc_errors" }
rustc_fluent_macro = { path = "../rustc_fluent_macro" }
@ -15,7 +14,6 @@ rustc_macros = { path = "../rustc_macros" }
rustc_middle = { path = "../rustc_middle" }
rustc_session = { path = "../rustc_session" }
rustc_span = { path = "../rustc_span" }
rustc_symbol_mangling = { path = "../rustc_symbol_mangling" }
rustc_target = { path = "../rustc_target" }
serde = "1"
serde_json = "1"

View file

@ -205,6 +205,8 @@
//! this is not implemented however: a mono item will be produced
//! regardless of whether it is actually needed or not.
mod autodiff;
use std::cell::OnceCell;
use rustc_data_structures::fx::FxIndexMap;
@ -235,6 +237,7 @@ use rustc_span::source_map::{Spanned, dummy_spanned, respan};
use rustc_span::{DUMMY_SP, Span};
use tracing::{debug, instrument, trace};
use crate::collector::autodiff::collect_autodiff_fn;
use crate::errors::{
self, EncounteredErrorWhileInstantiating, EncounteredErrorWhileInstantiatingGlobalAsm,
NoOptimizedMir, RecursionLimit,
@ -786,7 +789,35 @@ impl<'a, 'tcx> MirVisitor<'tcx> for MirUsedCollector<'a, 'tcx> {
// *Before* monomorphizing, record that we already handled this mention.
self.used_mentioned_items.insert(MentionedItem::Fn(callee_ty));
let callee_ty = self.monomorphize(callee_ty);
visit_fn_use(self.tcx, callee_ty, true, source, &mut self.used_items)
// HACK(explicit_tail_calls): collect tail calls to `#[track_caller]` functions as indirect,
// because we later call them as such, to prevent issues with ABI incompatibility.
// Ideally we'd replace such tail calls with normal call + return, but this requires
// post-mono MIR optimizations, which we don't yet have.
let force_indirect_call =
if matches!(terminator.kind, mir::TerminatorKind::TailCall { .. })
&& let &ty::FnDef(def_id, args) = callee_ty.kind()
&& let instance = ty::Instance::expect_resolve(
self.tcx,
ty::TypingEnv::fully_monomorphized(),
def_id,
args,
source,
)
&& instance.def.requires_caller_location(self.tcx)
{
true
} else {
false
};
visit_fn_use(
self.tcx,
callee_ty,
!force_indirect_call,
source,
&mut self.used_items,
)
}
mir::TerminatorKind::Drop { ref place, .. } => {
let ty = place.ty(self.body, self.tcx).ty;
@ -911,6 +942,8 @@ fn visit_instance_use<'tcx>(
return;
}
if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) {
collect_autodiff_fn(tcx, instance, intrinsic, output);
if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) {
// The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will
// be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any

View file

@ -0,0 +1,48 @@
use rustc_middle::bug;
use rustc_middle::ty::{self, GenericArg, IntrinsicDef, TyCtxt};
use crate::collector::{MonoItems, create_fn_mono_item};
// Here, we force both primal and diff function to be collected in
// mono so this does not interfere in `autodiff` intrinsics
// codegen process. If they are unused, LLVM will remove them when
// compiling with O3.
pub(crate) fn collect_autodiff_fn<'tcx>(
tcx: TyCtxt<'tcx>,
instance: ty::Instance<'tcx>,
intrinsic: IntrinsicDef,
output: &mut MonoItems<'tcx>,
) {
if intrinsic.name != rustc_span::sym::autodiff {
return;
};
collect_autodiff_fn_from_arg(instance.args[0], tcx, output);
}
fn collect_autodiff_fn_from_arg<'tcx>(
arg: GenericArg<'tcx>,
tcx: TyCtxt<'tcx>,
output: &mut MonoItems<'tcx>,
) {
let (instance, span) = match arg.kind() {
ty::GenericArgKind::Type(ty) => match ty.kind() {
ty::FnDef(def_id, substs) => {
let span = tcx.def_span(def_id);
let instance = ty::Instance::expect_resolve(
tcx,
ty::TypingEnv::non_body_analysis(tcx, def_id),
*def_id,
substs,
span,
);
(instance, span)
}
_ => bug!("expected autodiff function"),
},
_ => bug!("expected type when matching autodiff arg"),
};
output.push(create_fn_mono_item(tcx, instance, span));
}

View file

@ -92,8 +92,6 @@
//! source-level module, functions from the same module will be available for
//! inlining, even when they are not marked `#[inline]`.
mod autodiff;
use std::cmp;
use std::collections::hash_map::Entry;
use std::fs::{self, File};
@ -251,17 +249,7 @@ where
always_export_generics,
);
// We can't differentiate a function that got inlined.
let autodiff_active = cfg!(llvm_enzyme)
&& matches!(mono_item, MonoItem::Fn(_))
&& cx
.tcx
.codegen_fn_attrs(mono_item.def_id())
.autodiff_item
.as_ref()
.is_some_and(|ad| ad.is_active());
if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
if visibility == Visibility::Hidden && can_be_internalized {
internalization_candidates.insert(mono_item);
}
let size_estimate = mono_item.size_estimate(cx.tcx);
@ -1157,27 +1145,15 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
}
}
#[cfg(not(llvm_enzyme))]
let autodiff_mono_items: Vec<_> = vec![];
#[cfg(llvm_enzyme)]
let mut autodiff_mono_items: Vec<_> = vec![];
let mono_items: DefIdSet = items
.iter()
.filter_map(|mono_item| match *mono_item {
MonoItem::Fn(ref instance) => {
#[cfg(llvm_enzyme)]
autodiff_mono_items.push((mono_item, instance));
Some(instance.def_id())
}
MonoItem::Fn(ref instance) => Some(instance.def_id()),
MonoItem::Static(def_id) => Some(def_id),
_ => None,
})
.collect();
let autodiff_items =
autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items);
let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items);
// Output monomorphization stats per def_id
if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats
&& let Err(err) =
@ -1235,11 +1211,7 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
}
}
MonoItemPartitions {
all_mono_items: tcx.arena.alloc(mono_items),
codegen_units,
autodiff_items,
}
MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units }
}
/// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s

View file

@ -1,143 +0,0 @@
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
use rustc_hir::def_id::LOCAL_CRATE;
use rustc_middle::bug;
use rustc_middle::mir::mono::MonoItem;
use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
use tracing::{debug, trace};
use crate::partitioning::UsageMap;
fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
bug!("expected fn def for autodiff, got {:?}", fn_ty);
}
// We don't actually pass the types back into the type system.
// All we do is decide how to handle the arguments.
let sig = fn_ty.fn_sig(tcx).skip_binder();
let mut new_activities = vec![];
let mut new_positions = vec![];
for (i, ty) in sig.inputs().iter().enumerate() {
if let Some(inner_ty) = ty.builtin_deref(true) {
if inner_ty.is_slice() {
// Now we need to figure out the size of each slice element in memory to allow
// safety checks and usability improvements in the backend.
let sty = match inner_ty.builtin_index() {
Some(sty) => sty,
None => {
panic!("slice element type unknown");
}
};
let pci = PseudoCanonicalInput {
typing_env: TypingEnv::fully_monomorphized(),
value: sty,
};
let layout = tcx.layout_of(pci);
let elem_size = match layout {
Ok(layout) => layout.size,
Err(_) => {
bug!("autodiff failed to compute slice element size");
}
};
let elem_size: u32 = elem_size.bytes() as u32;
// We know that the length will be passed as extra arg.
if !da.is_empty() {
// We are looking at a slice. The length of that slice will become an
// extra integer on llvm level. Integers are always const.
// However, if the slice get's duplicated, we want to know to later check the
// size. So we mark the new size argument as FakeActivitySize.
// There is one FakeActivitySize per slice, so for convenience we store the
// slice element size in bytes in it. We will use the size in the backend.
let activity = match da[i] {
DiffActivity::DualOnly
| DiffActivity::Dual
| DiffActivity::Dualv
| DiffActivity::DuplicatedOnly
| DiffActivity::Duplicated => {
DiffActivity::FakeActivitySize(Some(elem_size))
}
DiffActivity::Const => DiffActivity::Const,
_ => bug!("unexpected activity for ptr/ref"),
};
new_activities.push(activity);
new_positions.push(i + 1);
}
continue;
}
}
}
// now add the extra activities coming from slices
// Reverse order to not invalidate the indices
for _ in 0..new_activities.len() {
let pos = new_positions.pop().unwrap();
let activity = new_activities.pop().unwrap();
da.insert(pos, activity);
}
}
pub(crate) fn find_autodiff_source_functions<'tcx>(
tcx: TyCtxt<'tcx>,
usage_map: &UsageMap<'tcx>,
autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
) -> Vec<AutoDiffItem> {
let mut autodiff_items: Vec<AutoDiffItem> = vec![];
for (item, instance) in autodiff_mono_items {
let target_id = instance.def_id();
let cg_fn_attr = &tcx.codegen_fn_attrs(target_id).autodiff_item;
let Some(target_attrs) = cg_fn_attr else {
continue;
};
let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
if target_attrs.is_source() {
trace!("source found: {:?}", target_id);
}
if !target_attrs.apply_autodiff() {
continue;
}
let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
let source =
usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
MonoItem::Fn(ref instance_s) => {
let source_id = instance_s.def_id();
if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
&& ad.is_active()
{
return Some(instance_s);
}
None
}
_ => None,
});
let inst = match source {
Some(source) => source,
None => continue,
};
debug!("source_id: {:?}", inst.def_id());
let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
assert!(fn_ty.is_fn());
adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
let mut new_target_attrs = target_attrs.clone();
new_target_attrs.input_activity = input_activities;
let itm = new_target_attrs.into_item(symb, target_symbol);
autodiff_items.push(itm);
}
if !autodiff_items.is_empty() {
trace!("AUTODIFF ITEMS EXIST");
for item in &mut *autodiff_items {
trace!("{}", &item);
}
}
autodiff_items
}

View file

@ -8,6 +8,7 @@ use std::ops::ControlFlow;
use derive_where::derive_where;
use rustc_type_ir::inherent::*;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::search_graph::CandidateHeadUsages;
use rustc_type_ir::solve::SizedTraitKind;
use rustc_type_ir::{
self as ty, Interner, TypeFlags, TypeFoldable, TypeSuperVisitable, TypeVisitable,
@ -33,10 +34,11 @@ enum AliasBoundKind {
///
/// It consists of both the `source`, which describes how that goal would be proven,
/// and the `result` when using the given `source`.
#[derive_where(Clone, Debug; I: Interner)]
#[derive_where(Debug; I: Interner)]
pub(super) struct Candidate<I: Interner> {
pub(super) source: CandidateSource<I>,
pub(super) result: CanonicalResponse<I>,
pub(super) head_usages: CandidateHeadUsages,
}
/// Methods used to assemble candidates for either trait or projection goals.
@ -116,8 +118,11 @@ where
ecx: &mut EvalCtxt<'_, D>,
goal: Goal<I, Self>,
assumption: I::Clause,
) -> Result<Candidate<I>, NoSolution> {
Self::fast_reject_assumption(ecx, goal, assumption)?;
) -> Result<Candidate<I>, CandidateHeadUsages> {
match Self::fast_reject_assumption(ecx, goal, assumption) {
Ok(()) => {}
Err(NoSolution) => return Err(CandidateHeadUsages::default()),
}
// Dealing with `ParamEnv` candidates is a bit of a mess as we need to lazily
// check whether the candidate is global while considering normalization.
@ -126,18 +131,23 @@ where
// in `probe` even if the candidate does not apply before we get there. We handle this
// by using a `Cell` here. We only ever write into it inside of `match_assumption`.
let source = Cell::new(CandidateSource::ParamEnv(ParamEnvSource::Global));
ecx.probe(|result: &QueryResult<I>| inspect::ProbeKind::TraitCandidate {
source: source.get(),
result: *result,
})
.enter(|ecx| {
Self::match_assumption(ecx, goal, assumption, |ecx| {
ecx.try_evaluate_added_goals()?;
source.set(ecx.characterize_param_env_assumption(goal.param_env, assumption)?);
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
let (result, head_usages) = ecx
.probe(|result: &QueryResult<I>| inspect::ProbeKind::TraitCandidate {
source: source.get(),
result: *result,
})
})
.map(|result| Candidate { source: source.get(), result })
.enter_single_candidate(|ecx| {
Self::match_assumption(ecx, goal, assumption, |ecx| {
ecx.try_evaluate_added_goals()?;
source.set(ecx.characterize_param_env_assumption(goal.param_env, assumption)?);
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
})
});
match result {
Ok(result) => Ok(Candidate { source: source.get(), result, head_usages }),
Err(NoSolution) => Err(head_usages),
}
}
/// Try equating an assumption predicate against a goal's predicate. If it
@ -355,6 +365,19 @@ pub(super) enum AssembleCandidatesFrom {
EnvAndBounds,
}
/// This is currently used to track the [CandidateHeadUsages] of all failed `ParamEnv`
/// candidates. This is then used to ignore their head usages in case there's another
/// always applicable `ParamEnv` candidate. Look at how `param_env_head_usages` is
/// used in the code for more details.
///
/// We could easily extend this to also ignore head usages of other ignored candidates.
/// However, we currently don't have any tests where this matters and the complexity of
/// doing so does not feel worth it for now.
#[derive(Debug)]
pub(super) struct FailedCandidateInfo {
pub param_env_head_usages: CandidateHeadUsages,
}
impl<D, I> EvalCtxt<'_, D>
where
D: SolverDelegate<Interner = I>,
@ -364,16 +387,20 @@ where
&mut self,
goal: Goal<I, G>,
assemble_from: AssembleCandidatesFrom,
) -> Vec<Candidate<I>> {
) -> (Vec<Candidate<I>>, FailedCandidateInfo) {
let mut candidates = vec![];
let mut failed_candidate_info =
FailedCandidateInfo { param_env_head_usages: CandidateHeadUsages::default() };
let Ok(normalized_self_ty) =
self.structurally_normalize_ty(goal.param_env, goal.predicate.self_ty())
else {
return vec![];
return (candidates, failed_candidate_info);
};
if normalized_self_ty.is_ty_var() {
debug!("self type has been normalized to infer");
return self.forced_ambiguity(MaybeCause::Ambiguity).into_iter().collect();
candidates.extend(self.forced_ambiguity(MaybeCause::Ambiguity));
return (candidates, failed_candidate_info);
}
let goal: Goal<I, G> = goal
@ -382,16 +409,15 @@ where
// normalizing the self type as well, since type variables are not uniquified.
let goal = self.resolve_vars_if_possible(goal);
let mut candidates = vec![];
if let TypingMode::Coherence = self.typing_mode()
&& let Ok(candidate) = self.consider_coherence_unknowable_candidate(goal)
{
return vec![candidate];
candidates.push(candidate);
return (candidates, failed_candidate_info);
}
self.assemble_alias_bound_candidates(goal, &mut candidates);
self.assemble_param_env_candidates(goal, &mut candidates);
self.assemble_param_env_candidates(goal, &mut candidates, &mut failed_candidate_info);
match assemble_from {
AssembleCandidatesFrom::All => {
@ -423,7 +449,7 @@ where
AssembleCandidatesFrom::EnvAndBounds => {}
}
candidates
(candidates, failed_candidate_info)
}
pub(super) fn forced_ambiguity(
@ -584,9 +610,15 @@ where
&mut self,
goal: Goal<I, G>,
candidates: &mut Vec<Candidate<I>>,
failed_candidate_info: &mut FailedCandidateInfo,
) {
for assumption in goal.param_env.caller_bounds().iter() {
candidates.extend(G::probe_and_consider_param_env_candidate(self, goal, assumption));
match G::probe_and_consider_param_env_candidate(self, goal, assumption) {
Ok(candidate) => candidates.push(candidate),
Err(head_usages) => {
failed_candidate_info.param_env_head_usages.merge_usages(head_usages)
}
}
}
}
@ -661,7 +693,11 @@ where
if let Ok(result) =
self.evaluate_added_goals_and_make_canonical_response(Certainty::AMBIGUOUS)
{
candidates.push(Candidate { source: CandidateSource::AliasBound, result });
candidates.push(Candidate {
source: CandidateSource::AliasBound,
result,
head_usages: CandidateHeadUsages::default(),
});
}
return;
}
@ -959,7 +995,7 @@ where
// Even when a trait bound has been proven using a where-bound, we
// still need to consider alias-bounds for normalization, see
// `tests/ui/next-solver/alias-bound-shadowed-by-env.rs`.
let mut candidates: Vec<_> = self
let (mut candidates, _) = self
.assemble_and_evaluate_candidates(goal, AssembleCandidatesFrom::EnvAndBounds);
// We still need to prefer where-bounds over alias-bounds however.
@ -972,23 +1008,20 @@ where
return inject_normalize_to_rigid_candidate(self);
}
if let Some(response) = self.try_merge_candidates(&candidates) {
if let Some((response, _)) = self.try_merge_candidates(&candidates) {
Ok(response)
} else {
self.flounder(&candidates)
}
}
TraitGoalProvenVia::Misc => {
let mut candidates =
let (mut candidates, _) =
self.assemble_and_evaluate_candidates(goal, AssembleCandidatesFrom::All);
// Prefer "orphaned" param-env normalization predicates, which are used
// (for example, and ideally only) when proving item bounds for an impl.
let candidates_from_env: Vec<_> = candidates
.extract_if(.., |c| matches!(c.source, CandidateSource::ParamEnv(_)))
.collect();
if let Some(response) = self.try_merge_candidates(&candidates_from_env) {
return Ok(response);
if candidates.iter().any(|c| matches!(c.source, CandidateSource::ParamEnv(_))) {
candidates.retain(|c| matches!(c.source, CandidateSource::ParamEnv(_)));
}
// We drop specialized impls to allow normalization via a final impl here. In case
@ -997,7 +1030,7 @@ where
// means we can just ignore inference constraints and don't have to special-case
// constraining the normalized-to `term`.
self.filter_specialized_impls(AllowInferenceConstraints::Yes, &mut candidates);
if let Some(response) = self.try_merge_candidates(&candidates) {
if let Some((response, _)) = self.try_merge_candidates(&candidates) {
Ok(response)
} else {
self.flounder(&candidates)

View file

@ -8,7 +8,7 @@ use rustc_type_ir::fast_reject::DeepRejectCtxt;
use rustc_type_ir::inherent::*;
use rustc_type_ir::relate::Relate;
use rustc_type_ir::relate::solver_relating::RelateExt;
use rustc_type_ir::search_graph::PathKind;
use rustc_type_ir::search_graph::{CandidateHeadUsages, PathKind};
use rustc_type_ir::{
self as ty, CanonicalVarValues, InferCtxtLike, Interner, TypeFoldable, TypeFolder,
TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor,
@ -399,6 +399,10 @@ where
result
}
pub(super) fn ignore_candidate_head_usages(&mut self, usages: CandidateHeadUsages) {
self.search_graph.ignore_candidate_head_usages(usages);
}
/// Recursively evaluates `goal`, returning whether any inference vars have
/// been constrained and the certainty of the result.
fn evaluate_goal(

View file

@ -1,5 +1,6 @@
use std::marker::PhantomData;
use rustc_type_ir::search_graph::CandidateHeadUsages;
use rustc_type_ir::{InferCtxtLike, Interner};
use tracing::instrument;
@ -25,6 +26,20 @@ where
D: SolverDelegate<Interner = I>,
I: Interner,
{
pub(in crate::solve) fn enter_single_candidate(
self,
f: impl FnOnce(&mut EvalCtxt<'_, D>) -> T,
) -> (T, CandidateHeadUsages) {
self.ecx.search_graph.enter_single_candidate();
let mut candidate_usages = CandidateHeadUsages::default();
let result = self.enter(|ecx| {
let result = f(ecx);
candidate_usages = ecx.search_graph.finish_single_candidate();
result
});
(result, candidate_usages)
}
pub(in crate::solve) fn enter(self, f: impl FnOnce(&mut EvalCtxt<'_, D>) -> T) -> T {
let ProbeCtxt { ecx: outer, probe_kind, _result } = self;
@ -78,7 +93,8 @@ where
self,
f: impl FnOnce(&mut EvalCtxt<'_, D>) -> QueryResult<I>,
) -> Result<Candidate<I>, NoSolution> {
self.cx.enter(|ecx| f(ecx)).map(|result| Candidate { source: self.source, result })
let (result, head_usages) = self.cx.enter_single_candidate(f);
result.map(|result| Candidate { source: self.source, result, head_usages })
}
}

View file

@ -236,6 +236,12 @@ where
}
}
#[derive(Debug)]
enum MergeCandidateInfo {
AlwaysApplicable(usize),
EqualResponse,
}
impl<D, I> EvalCtxt<'_, D>
where
D: SolverDelegate<Interner = I>,
@ -248,23 +254,25 @@ where
fn try_merge_candidates(
&mut self,
candidates: &[Candidate<I>],
) -> Option<CanonicalResponse<I>> {
) -> Option<(CanonicalResponse<I>, MergeCandidateInfo)> {
if candidates.is_empty() {
return None;
}
let one: CanonicalResponse<I> = candidates[0].result;
if candidates[1..].iter().all(|candidate| candidate.result == one) {
return Some(one);
let always_applicable = candidates.iter().enumerate().find(|(_, candidate)| {
candidate.result.value.certainty == Certainty::Yes
&& has_no_inference_or_external_constraints(candidate.result)
});
if let Some((i, c)) = always_applicable {
return Some((c.result, MergeCandidateInfo::AlwaysApplicable(i)));
}
candidates
.iter()
.find(|candidate| {
candidate.result.value.certainty == Certainty::Yes
&& has_no_inference_or_external_constraints(candidate.result)
})
.map(|candidate| candidate.result)
let one: CanonicalResponse<I> = candidates[0].result;
if candidates[1..].iter().all(|candidate| candidate.result == one) {
return Some((one, MergeCandidateInfo::EqualResponse));
}
None
}
fn bail_with_ambiguity(&mut self, candidates: &[Candidate<I>]) -> CanonicalResponse<I> {

View file

@ -13,11 +13,13 @@ use tracing::{debug, instrument, trace};
use crate::delegate::SolverDelegate;
use crate::solve::assembly::structural_traits::{self, AsyncCallableRelevantTypes};
use crate::solve::assembly::{self, AllowInferenceConstraints, AssembleCandidatesFrom, Candidate};
use crate::solve::assembly::{
self, AllowInferenceConstraints, AssembleCandidatesFrom, Candidate, FailedCandidateInfo,
};
use crate::solve::inspect::ProbeKind;
use crate::solve::{
BuiltinImplSource, CandidateSource, Certainty, EvalCtxt, Goal, GoalSource, MaybeCause,
NoSolution, ParamEnvSource, QueryResult, has_only_region_constraints,
MergeCandidateInfo, NoSolution, ParamEnvSource, QueryResult, has_only_region_constraints,
};
impl<D, I> assembly::GoalKind<D> for TraitPredicate<I>
@ -1344,9 +1346,10 @@ where
pub(super) fn merge_trait_candidates(
&mut self,
mut candidates: Vec<Candidate<I>>,
failed_candidate_info: FailedCandidateInfo,
) -> Result<(CanonicalResponse<I>, Option<TraitGoalProvenVia>), NoSolution> {
if let TypingMode::Coherence = self.typing_mode() {
return if let Some(response) = self.try_merge_candidates(&candidates) {
return if let Some((response, _)) = self.try_merge_candidates(&candidates) {
Ok((response, Some(TraitGoalProvenVia::Misc)))
} else {
self.flounder(&candidates).map(|r| (r, None))
@ -1376,10 +1379,41 @@ where
let where_bounds: Vec<_> = candidates
.extract_if(.., |c| matches!(c.source, CandidateSource::ParamEnv(_)))
.collect();
return if let Some(response) = self.try_merge_candidates(&where_bounds) {
Ok((response, Some(TraitGoalProvenVia::ParamEnv)))
if let Some((response, info)) = self.try_merge_candidates(&where_bounds) {
match info {
// If there's an always applicable candidate, the result of all
// other candidates does not matter. This means we can ignore
// them when checking whether we've reached a fixpoint.
//
// We always prefer the first always applicable candidate, even if a
// later candidate is also always applicable and would result in fewer
// reruns. We could slightly improve this by e.g. searching for another
// always applicable candidate which doesn't depend on any cycle heads.
//
// NOTE: This is optimization is observable in case there is an always
// applicable global candidate and another non-global candidate which only
// applies because of a provisional result. I can't even think of a test
// case where this would occur and even then, this would not be unsound.
// Supporting this makes the code more involved, so I am just going to
// ignore this for now.
MergeCandidateInfo::AlwaysApplicable(i) => {
for (j, c) in where_bounds.into_iter().enumerate() {
if i != j {
self.ignore_candidate_head_usages(c.head_usages)
}
}
// If a where-bound does not apply, we don't actually get a
// candidate for it. We manually track the head usages
// of all failed `ParamEnv` candidates instead.
self.ignore_candidate_head_usages(
failed_candidate_info.param_env_head_usages,
);
}
MergeCandidateInfo::EqualResponse => {}
}
return Ok((response, Some(TraitGoalProvenVia::ParamEnv)));
} else {
Ok((self.bail_with_ambiguity(&where_bounds), None))
return Ok((self.bail_with_ambiguity(&where_bounds), None));
};
}
@ -1387,7 +1421,7 @@ where
let alias_bounds: Vec<_> = candidates
.extract_if(.., |c| matches!(c.source, CandidateSource::AliasBound))
.collect();
return if let Some(response) = self.try_merge_candidates(&alias_bounds) {
return if let Some((response, _)) = self.try_merge_candidates(&alias_bounds) {
Ok((response, Some(TraitGoalProvenVia::AliasBound)))
} else {
Ok((self.bail_with_ambiguity(&alias_bounds), None))
@ -1412,7 +1446,7 @@ where
TraitGoalProvenVia::Misc
};
if let Some(response) = self.try_merge_candidates(&candidates) {
if let Some((response, _)) = self.try_merge_candidates(&candidates) {
Ok((response, Some(proven_via)))
} else {
self.flounder(&candidates).map(|r| (r, None))
@ -1424,8 +1458,9 @@ where
&mut self,
goal: Goal<I, TraitPredicate<I>>,
) -> Result<(CanonicalResponse<I>, Option<TraitGoalProvenVia>), NoSolution> {
let candidates = self.assemble_and_evaluate_candidates(goal, AssembleCandidatesFrom::All);
self.merge_trait_candidates(candidates)
let (candidates, failed_candidate_info) =
self.assemble_and_evaluate_candidates(goal, AssembleCandidatesFrom::All);
self.merge_trait_candidates(candidates, failed_candidate_info)
}
fn try_stall_coroutine(&mut self, self_ty: I::Ty) -> Option<Result<Candidate<I>, NoSolution>> {

View file

@ -11,6 +11,7 @@ use tracing::debug;
use super::{
AttrWrapper, Capturing, FnParseMode, ForceCollect, Parser, PathStyle, Trailing, UsePreAttrPos,
};
use crate::parser::FnContext;
use crate::{errors, exp, fluent_generated as fluent};
// Public for rustfmt usage
@ -200,7 +201,7 @@ impl<'a> Parser<'a> {
AttrWrapper::empty(),
true,
false,
FnParseMode { req_name: |_| true, req_body: true },
FnParseMode { req_name: |_| true, context: FnContext::Free, req_body: true },
ForceCollect::No,
) {
Ok(Some(item)) => {

View file

@ -44,6 +44,7 @@ use crate::errors::{
UnexpectedConstParamDeclaration, UnexpectedConstParamDeclarationSugg, UnmatchedAngleBrackets,
UseEqInstead, WrapType,
};
use crate::parser::FnContext;
use crate::parser::attr::InnerAttrPolicy;
use crate::{exp, fluent_generated as fluent};
@ -2246,6 +2247,7 @@ impl<'a> Parser<'a> {
pat: Box<ast::Pat>,
require_name: bool,
first_param: bool,
fn_parse_mode: &crate::parser::item::FnParseMode,
) -> Option<Ident> {
// If we find a pattern followed by an identifier, it could be an (incorrect)
// C-style parameter declaration.
@ -2268,7 +2270,14 @@ impl<'a> Parser<'a> {
|| self.token == token::Lt
|| self.token == token::CloseParen)
{
let rfc_note = "anonymous parameters are removed in the 2018 edition (see RFC 1685)";
let maybe_emit_anon_params_note = |this: &mut Self, err: &mut Diag<'_>| {
let ed = this.token.span.with_neighbor(this.prev_token.span).edition();
if matches!(fn_parse_mode.context, crate::parser::item::FnContext::Trait)
&& (fn_parse_mode.req_name)(ed)
{
err.note("anonymous parameters are removed in the 2018 edition (see RFC 1685)");
}
};
let (ident, self_sugg, param_sugg, type_sugg, self_span, param_span, type_span) =
match pat.kind {
@ -2305,7 +2314,7 @@ impl<'a> Parser<'a> {
"_: ".to_string(),
Applicability::MachineApplicable,
);
err.note(rfc_note);
maybe_emit_anon_params_note(self, err);
}
return None;
@ -2313,7 +2322,13 @@ impl<'a> Parser<'a> {
};
// `fn foo(a, b) {}`, `fn foo(a<x>, b<y>) {}` or `fn foo(usize, usize) {}`
if first_param {
if first_param
// Only when the fn is a method, we emit this suggestion.
&& matches!(
fn_parse_mode.context,
FnContext::Trait | FnContext::Impl
)
{
err.span_suggestion_verbose(
self_span,
"if this is a `self` type, give it a parameter name",
@ -2337,7 +2352,7 @@ impl<'a> Parser<'a> {
type_sugg,
Applicability::MachineApplicable,
);
err.note(rfc_note);
maybe_emit_anon_params_note(self, err);
// Don't attempt to recover by using the `X` in `X<Y>` as the parameter name.
return if self.token == token::Lt { None } else { Some(ident) };

View file

@ -116,7 +116,8 @@ impl<'a> Parser<'a> {
impl<'a> Parser<'a> {
pub fn parse_item(&mut self, force_collect: ForceCollect) -> PResult<'a, Option<Box<Item>>> {
let fn_parse_mode = FnParseMode { req_name: |_| true, req_body: true };
let fn_parse_mode =
FnParseMode { req_name: |_| true, context: FnContext::Free, req_body: true };
self.parse_item_(fn_parse_mode, force_collect).map(|i| i.map(Box::new))
}
@ -975,7 +976,8 @@ impl<'a> Parser<'a> {
&mut self,
force_collect: ForceCollect,
) -> PResult<'a, Option<Option<Box<AssocItem>>>> {
let fn_parse_mode = FnParseMode { req_name: |_| true, req_body: true };
let fn_parse_mode =
FnParseMode { req_name: |_| true, context: FnContext::Impl, req_body: true };
self.parse_assoc_item(fn_parse_mode, force_collect)
}
@ -983,8 +985,11 @@ impl<'a> Parser<'a> {
&mut self,
force_collect: ForceCollect,
) -> PResult<'a, Option<Option<Box<AssocItem>>>> {
let fn_parse_mode =
FnParseMode { req_name: |edition| edition >= Edition::Edition2018, req_body: false };
let fn_parse_mode = FnParseMode {
req_name: |edition| edition >= Edition::Edition2018,
context: FnContext::Trait,
req_body: false,
};
self.parse_assoc_item(fn_parse_mode, force_collect)
}
@ -1261,7 +1266,8 @@ impl<'a> Parser<'a> {
&mut self,
force_collect: ForceCollect,
) -> PResult<'a, Option<Option<Box<ForeignItem>>>> {
let fn_parse_mode = FnParseMode { req_name: |_| true, req_body: false };
let fn_parse_mode =
FnParseMode { req_name: |_| true, context: FnContext::Free, req_body: false };
Ok(self.parse_item_(fn_parse_mode, force_collect)?.map(
|Item { attrs, id, span, vis, kind, tokens }| {
let kind = match ForeignItemKind::try_from(kind) {
@ -2135,7 +2141,8 @@ impl<'a> Parser<'a> {
let inherited_vis =
Visibility { span: DUMMY_SP, kind: VisibilityKind::Inherited, tokens: None };
// We use `parse_fn` to get a span for the function
let fn_parse_mode = FnParseMode { req_name: |_| true, req_body: true };
let fn_parse_mode =
FnParseMode { req_name: |_| true, context: FnContext::Free, req_body: true };
match self.parse_fn(
&mut AttrVec::new(),
fn_parse_mode,
@ -2403,6 +2410,9 @@ pub(crate) struct FnParseMode {
/// * The span is from Edition 2015. In particular, you can get a
/// 2015 span inside a 2021 crate using macros.
pub(super) req_name: ReqName,
/// The context in which this function is parsed, used for diagnostics.
/// This indicates the fn is a free function or method and so on.
pub(super) context: FnContext,
/// If this flag is set to `true`, then plain, semicolon-terminated function
/// prototypes are not allowed here.
///
@ -2424,6 +2434,18 @@ pub(crate) struct FnParseMode {
pub(super) req_body: bool,
}
/// The context in which a function is parsed.
/// FIXME(estebank, xizheyin): Use more variants.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum FnContext {
/// Free context.
Free,
/// A Trait context.
Trait,
/// An Impl block.
Impl,
}
/// Parsing of functions and methods.
impl<'a> Parser<'a> {
/// Parse a function starting from the front matter (`const ...`) to the body `{ ... }` or `;`.
@ -2439,11 +2461,8 @@ impl<'a> Parser<'a> {
let header = self.parse_fn_front_matter(vis, case, FrontMatterParsingMode::Function)?; // `const ... fn`
let ident = self.parse_ident()?; // `foo`
let mut generics = self.parse_generics()?; // `<'a, T, ...>`
let decl = match self.parse_fn_decl(
fn_parse_mode.req_name,
AllowPlus::Yes,
RecoverReturnSign::Yes,
) {
let decl = match self.parse_fn_decl(&fn_parse_mode, AllowPlus::Yes, RecoverReturnSign::Yes)
{
Ok(decl) => decl,
Err(old_err) => {
// If we see `for Ty ...` then user probably meant `impl` item.
@ -2961,18 +2980,21 @@ impl<'a> Parser<'a> {
/// Parses the parameter list and result type of a function declaration.
pub(super) fn parse_fn_decl(
&mut self,
req_name: ReqName,
fn_parse_mode: &FnParseMode,
ret_allow_plus: AllowPlus,
recover_return_sign: RecoverReturnSign,
) -> PResult<'a, Box<FnDecl>> {
Ok(Box::new(FnDecl {
inputs: self.parse_fn_params(req_name)?,
inputs: self.parse_fn_params(fn_parse_mode)?,
output: self.parse_ret_ty(ret_allow_plus, RecoverQPath::Yes, recover_return_sign)?,
}))
}
/// Parses the parameter list of a function, including the `(` and `)` delimiters.
pub(super) fn parse_fn_params(&mut self, req_name: ReqName) -> PResult<'a, ThinVec<Param>> {
pub(super) fn parse_fn_params(
&mut self,
fn_parse_mode: &FnParseMode,
) -> PResult<'a, ThinVec<Param>> {
let mut first_param = true;
// Parse the arguments, starting out with `self` being allowed...
if self.token != TokenKind::OpenParen
@ -2988,7 +3010,7 @@ impl<'a> Parser<'a> {
let (mut params, _) = self.parse_paren_comma_seq(|p| {
p.recover_vcs_conflict_marker();
let snapshot = p.create_snapshot_for_diagnostic();
let param = p.parse_param_general(req_name, first_param, true).or_else(|e| {
let param = p.parse_param_general(fn_parse_mode, first_param, true).or_else(|e| {
let guar = e.emit();
// When parsing a param failed, we should check to make the span of the param
// not contain '(' before it.
@ -3019,7 +3041,7 @@ impl<'a> Parser<'a> {
/// - `recover_arg_parse` is used to recover from a failed argument parse.
pub(super) fn parse_param_general(
&mut self,
req_name: ReqName,
fn_parse_mode: &FnParseMode,
first_param: bool,
recover_arg_parse: bool,
) -> PResult<'a, Param> {
@ -3035,16 +3057,22 @@ impl<'a> Parser<'a> {
let is_name_required = match this.token.kind {
token::DotDotDot => false,
_ => req_name(this.token.span.with_neighbor(this.prev_token.span).edition()),
_ => (fn_parse_mode.req_name)(
this.token.span.with_neighbor(this.prev_token.span).edition(),
),
};
let (pat, ty) = if is_name_required || this.is_named_param() {
debug!("parse_param_general parse_pat (is_name_required:{})", is_name_required);
let (pat, colon) = this.parse_fn_param_pat_colon()?;
if !colon {
let mut err = this.unexpected().unwrap_err();
return if let Some(ident) =
this.parameter_without_type(&mut err, pat, is_name_required, first_param)
{
return if let Some(ident) = this.parameter_without_type(
&mut err,
pat,
is_name_required,
first_param,
fn_parse_mode,
) {
let guar = err.emit();
Ok((dummy_arg(ident, guar), Trailing::No, UsePreAttrPos::No))
} else {

View file

@ -22,7 +22,7 @@ use std::{fmt, mem, slice};
use attr_wrapper::{AttrWrapper, UsePreAttrPos};
pub use diagnostics::AttemptLocalParseRecovery;
pub(crate) use expr::ForbiddenLetReason;
pub(crate) use item::FnParseMode;
pub(crate) use item::{FnContext, FnParseMode};
pub use pat::{CommaRecoveryMode, RecoverColon, RecoverComma};
use path::PathStyle;
use rustc_ast::token::{

View file

@ -20,7 +20,9 @@ use crate::errors::{
PathFoundAttributeInParams, PathFoundCVariadicParams, PathSingleColon, PathTripleColon,
};
use crate::exp;
use crate::parser::{CommaRecoveryMode, ExprKind, RecoverColon, RecoverComma};
use crate::parser::{
CommaRecoveryMode, ExprKind, FnContext, FnParseMode, RecoverColon, RecoverComma,
};
/// Specifies how to parse a path.
#[derive(Copy, Clone, PartialEq)]
@ -399,7 +401,13 @@ impl<'a> Parser<'a> {
let dcx = self.dcx();
let parse_params_result = self.parse_paren_comma_seq(|p| {
let param = p.parse_param_general(|_| false, false, false);
// Inside parenthesized type arguments, we want types only, not names.
let mode = FnParseMode {
context: FnContext::Free,
req_name: |_| false,
req_body: false,
};
let param = p.parse_param_general(&mode, false, false);
param.map(move |param| {
if !matches!(param.pat.kind, PatKind::Missing) {
dcx.emit_err(FnPathFoundNamedParams {

View file

@ -19,8 +19,8 @@ use super::diagnostics::AttemptLocalParseRecovery;
use super::pat::{PatternLocation, RecoverComma};
use super::path::PathStyle;
use super::{
AttrWrapper, BlockMode, FnParseMode, ForceCollect, Parser, Restrictions, SemiColonMode,
Trailing, UsePreAttrPos,
AttrWrapper, BlockMode, FnContext, FnParseMode, ForceCollect, Parser, Restrictions,
SemiColonMode, Trailing, UsePreAttrPos,
};
use crate::errors::{self, MalformedLoopLabel};
use crate::exp;
@ -153,7 +153,7 @@ impl<'a> Parser<'a> {
attrs.clone(), // FIXME: unwanted clone of attrs
false,
true,
FnParseMode { req_name: |_| true, req_body: true },
FnParseMode { req_name: |_| true, context: FnContext::Free, req_body: true },
force_collect,
)? {
self.mk_stmt(lo.to(item.span), StmtKind::Item(Box::new(item)))

View file

@ -19,6 +19,7 @@ use crate::errors::{
NestedCVariadicType, ReturnTypesUseThinArrow,
};
use crate::parser::item::FrontMatterParsingMode;
use crate::parser::{FnContext, FnParseMode};
use crate::{exp, maybe_recover_from_interpolated_ty_qpath};
/// Signals whether parsing a type should allow `+`.
@ -769,7 +770,12 @@ impl<'a> Parser<'a> {
if self.may_recover() && self.token == TokenKind::Lt {
self.recover_fn_ptr_with_generics(lo, &mut params, param_insertion_point)?;
}
let decl = self.parse_fn_decl(|_| false, AllowPlus::No, recover_return_sign)?;
let mode = crate::parser::item::FnParseMode {
req_name: |_| false,
context: FnContext::Free,
req_body: false,
};
let decl = self.parse_fn_decl(&mode, AllowPlus::No, recover_return_sign)?;
let decl_span = span_start.to(self.prev_token.span);
Ok(TyKind::FnPtr(Box::new(FnPtrTy {
@ -1314,7 +1320,8 @@ impl<'a> Parser<'a> {
self.bump();
let args_lo = self.token.span;
let snapshot = self.create_snapshot_for_diagnostic();
match self.parse_fn_decl(|_| false, AllowPlus::No, RecoverReturnSign::OnlyFatArrow) {
let mode = FnParseMode { req_name: |_| false, context: FnContext::Free, req_body: false };
match self.parse_fn_decl(&mode, AllowPlus::No, RecoverReturnSign::OnlyFatArrow) {
Ok(decl) => {
self.dcx().emit_err(ExpectedFnPathFoundFnKeyword { fn_token_span });
Some(ast::Path {
@ -1400,8 +1407,9 @@ impl<'a> Parser<'a> {
// Parse `(T, U) -> R`.
let inputs_lo = self.token.span;
let mode = FnParseMode { req_name: |_| false, context: FnContext::Free, req_body: false };
let inputs: ThinVec<_> =
self.parse_fn_params(|_| false)?.into_iter().map(|input| input.ty).collect();
self.parse_fn_params(&mode)?.into_iter().map(|input| input.ty).collect();
let inputs_span = inputs_lo.to(self.prev_token.span);
let output = self.parse_ret_ty(AllowPlus::No, RecoverQPath::No, RecoverReturnSign::No)?;
let args = ast::ParenthesizedArgs {

View file

@ -8,7 +8,6 @@ use rustc_hir::HirId;
use rustc_hir::def_id::DefId;
use rustc_index::{Idx, IndexVec};
use rustc_middle::middle::stability::EvalResult;
use rustc_middle::mir::{self, Const};
use rustc_middle::thir::{self, Pat, PatKind, PatRange, PatRangeBoundary};
use rustc_middle::ty::layout::IntegerExt;
use rustc_middle::ty::{
@ -430,7 +429,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
match bdy {
PatRangeBoundary::NegInfinity => MaybeInfiniteInt::NegInfinity,
PatRangeBoundary::Finite(value) => {
let bits = value.eval_bits(self.tcx, self.typing_env);
let bits = value.try_to_scalar_int().unwrap().to_bits_unchecked();
match *ty.kind() {
ty::Int(ity) => {
let size = Integer::from_int_ty(&self.tcx, ity).size().bits();
@ -520,75 +519,54 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
PatKind::Constant { value } => {
match ty.kind() {
ty::Bool => {
ctor = match value.try_eval_bool(cx.tcx, cx.typing_env) {
Some(b) => Bool(b),
None => Opaque(OpaqueId::new()),
};
ctor = Bool(value.try_to_bool().unwrap());
fields = vec![];
arity = 0;
}
ty::Char | ty::Int(_) | ty::Uint(_) => {
ctor = match value.try_eval_bits(cx.tcx, cx.typing_env) {
Some(bits) => {
let x = match *ty.kind() {
ty::Int(ity) => {
let size = Integer::from_int_ty(&cx.tcx, ity).size().bits();
MaybeInfiniteInt::new_finite_int(bits, size)
}
_ => MaybeInfiniteInt::new_finite_uint(bits),
};
IntRange(IntRange::from_singleton(x))
}
None => Opaque(OpaqueId::new()),
ctor = {
let bits = value.valtree.unwrap_leaf().to_bits_unchecked();
let x = match *ty.kind() {
ty::Int(ity) => {
let size = Integer::from_int_ty(&cx.tcx, ity).size().bits();
MaybeInfiniteInt::new_finite_int(bits, size)
}
_ => MaybeInfiniteInt::new_finite_uint(bits),
};
IntRange(IntRange::from_singleton(x))
};
fields = vec![];
arity = 0;
}
ty::Float(ty::FloatTy::F16) => {
ctor = match value.try_eval_bits(cx.tcx, cx.typing_env) {
Some(bits) => {
use rustc_apfloat::Float;
let value = rustc_apfloat::ieee::Half::from_bits(bits);
F16Range(value, value, RangeEnd::Included)
}
None => Opaque(OpaqueId::new()),
};
use rustc_apfloat::Float;
let bits = value.valtree.unwrap_leaf().to_u16();
let value = rustc_apfloat::ieee::Half::from_bits(bits.into());
ctor = F16Range(value, value, RangeEnd::Included);
fields = vec![];
arity = 0;
}
ty::Float(ty::FloatTy::F32) => {
ctor = match value.try_eval_bits(cx.tcx, cx.typing_env) {
Some(bits) => {
use rustc_apfloat::Float;
let value = rustc_apfloat::ieee::Single::from_bits(bits);
F32Range(value, value, RangeEnd::Included)
}
None => Opaque(OpaqueId::new()),
};
use rustc_apfloat::Float;
let bits = value.valtree.unwrap_leaf().to_u32();
let value = rustc_apfloat::ieee::Single::from_bits(bits.into());
ctor = F32Range(value, value, RangeEnd::Included);
fields = vec![];
arity = 0;
}
ty::Float(ty::FloatTy::F64) => {
ctor = match value.try_eval_bits(cx.tcx, cx.typing_env) {
Some(bits) => {
use rustc_apfloat::Float;
let value = rustc_apfloat::ieee::Double::from_bits(bits);
F64Range(value, value, RangeEnd::Included)
}
None => Opaque(OpaqueId::new()),
};
use rustc_apfloat::Float;
let bits = value.valtree.unwrap_leaf().to_u64();
let value = rustc_apfloat::ieee::Double::from_bits(bits.into());
ctor = F64Range(value, value, RangeEnd::Included);
fields = vec![];
arity = 0;
}
ty::Float(ty::FloatTy::F128) => {
ctor = match value.try_eval_bits(cx.tcx, cx.typing_env) {
Some(bits) => {
use rustc_apfloat::Float;
let value = rustc_apfloat::ieee::Quad::from_bits(bits);
F128Range(value, value, RangeEnd::Included)
}
None => Opaque(OpaqueId::new()),
};
use rustc_apfloat::Float;
let bits = value.valtree.unwrap_leaf().to_u128();
let value = rustc_apfloat::ieee::Quad::from_bits(bits);
ctor = F128Range(value, value, RangeEnd::Included);
fields = vec![];
arity = 0;
}
@ -630,8 +608,12 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
}
ty::Float(fty) => {
use rustc_apfloat::Float;
let lo = lo.as_finite().map(|c| c.eval_bits(cx.tcx, cx.typing_env));
let hi = hi.as_finite().map(|c| c.eval_bits(cx.tcx, cx.typing_env));
let lo = lo
.as_finite()
.map(|c| c.try_to_scalar_int().unwrap().to_bits_unchecked());
let hi = hi
.as_finite()
.map(|c| c.try_to_scalar_int().unwrap().to_bits_unchecked());
match fty {
ty::FloatTy::F16 => {
use rustc_apfloat::ieee::Half;
@ -739,8 +721,8 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
};
match ScalarInt::try_from_uint(bits, size) {
Some(scalar) => {
let value = mir::Const::from_scalar(tcx, scalar.into(), ty.inner());
PatRangeBoundary::Finite(value)
let valtree = ty::ValTree::from_scalar_int(tcx, scalar);
PatRangeBoundary::Finite(valtree)
}
// The value doesn't fit. Since `x >= 0` and 0 always encodes the minimum value
// for a type, the problem isn't that the value is too small. So it must be too
@ -760,7 +742,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
"_".to_string()
} else if range.is_singleton() {
let lo = cx.hoist_pat_range_bdy(range.lo, ty);
let value = lo.as_finite().unwrap();
let value = ty::Value { ty: ty.inner(), valtree: lo.as_finite().unwrap() };
value.to_string()
} else {
// We convert to an inclusive range for diagnostics.
@ -772,7 +754,9 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
// fictitious values after `{u,i}size::MAX` (see [`IntRange::split`] for why we do
// this). We show this to the user as `usize::MAX..` which is slightly incorrect but
// probably clear enough.
lo = PatRangeBoundary::Finite(ty.numeric_max_val(cx.tcx).unwrap());
let max = ty.numeric_max_val(cx.tcx).unwrap();
let max = ty::ValTree::from_scalar_int(cx.tcx, max.try_to_scalar_int().unwrap());
lo = PatRangeBoundary::Finite(max);
}
let hi = if let Some(hi) = range.hi.minus_one() {
hi
@ -907,7 +891,7 @@ impl<'p, 'tcx: 'p> PatCx for RustcPatCtxt<'p, 'tcx> {
type Ty = RevealedTy<'tcx>;
type Error = ErrorGuaranteed;
type VariantIdx = VariantIdx;
type StrLit = Const<'tcx>;
type StrLit = ty::Value<'tcx>;
type ArmData = HirId;
type PatData = &'p Pat<'tcx>;

View file

@ -493,9 +493,6 @@ impl<'a, 'ra, 'tcx> BuildReducedGraphVisitor<'a, 'ra, 'tcx> {
});
}
}
// We don't add prelude imports to the globs since they only affect lexical scopes,
// which are not relevant to import resolution.
ImportKind::Glob { is_prelude: true, .. } => {}
ImportKind::Glob { .. } => current_module.globs.borrow_mut().push(import),
_ => unreachable!(),
}
@ -658,13 +655,19 @@ impl<'a, 'ra, 'tcx> BuildReducedGraphVisitor<'a, 'ra, 'tcx> {
self.add_import(module_path, kind, use_tree.span, item, root_span, item.id, vis);
}
ast::UseTreeKind::Glob => {
let kind = ImportKind::Glob {
is_prelude: ast::attr::contains_name(&item.attrs, sym::prelude_import),
max_vis: Cell::new(None),
id,
};
self.add_import(prefix, kind, use_tree.span, item, root_span, item.id, vis);
if !ast::attr::contains_name(&item.attrs, sym::prelude_import) {
let kind = ImportKind::Glob { max_vis: Cell::new(None), id };
self.add_import(prefix, kind, use_tree.span, item, root_span, item.id, vis);
} else {
// Resolve the prelude import early.
let path_res =
self.r.cm().maybe_resolve_path(&prefix, None, &self.parent_scope, None);
if let PathResult::Module(ModuleOrUniformRoot::Module(module)) = path_res {
self.r.prelude = Some(module);
} else {
self.r.dcx().span_err(use_tree.span, "cannot resolve a prelude import");
}
}
}
ast::UseTreeKind::Nested { ref items, .. } => {
// Ensure there is at most one `self` in the list

View file

@ -2189,9 +2189,9 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
let ast::ExprKind::Struct(struct_expr) = &expr.kind else { return };
// We don't have to handle type-relative paths because they're forbidden in ADT
// expressions, but that would change with `#[feature(more_qualified_paths)]`.
let Some(Res::Def(_, def_id)) =
self.partial_res_map[&struct_expr.path.segments.iter().last().unwrap().id].full_res()
else {
let Some(segment) = struct_expr.path.segments.last() else { return };
let Some(partial_res) = self.partial_res_map.get(&segment.id) else { return };
let Some(Res::Def(_, def_id)) = partial_res.full_res() else {
return;
};
let Some(default_fields) = self.field_defaults(def_id) else { return };

View file

@ -318,7 +318,6 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
let normalized_ident = Ident { span: normalized_span, ..ident };
// Walk backwards up the ribs in scope.
let mut module = self.graph_root;
for (i, rib) in ribs.iter().enumerate().rev() {
debug!("walk rib\n{:?}", rib.bindings);
// Use the rib kind to determine whether we are resolving parameters
@ -334,51 +333,47 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
*original_rib_ident_def,
ribs,
)));
}
module = match rib.kind {
RibKind::Module(module) => module,
RibKind::MacroDefinition(def) if def == self.macro_def(ident.span.ctxt()) => {
// If an invocation of this macro created `ident`, give up on `ident`
// and switch to `ident`'s source from the macro definition.
ident.span.remove_mark();
continue;
}
_ => continue,
};
match module.kind {
ModuleKind::Block => {} // We can see through blocks
_ => break,
}
let item = self.cm().resolve_ident_in_module_unadjusted(
ModuleOrUniformRoot::Module(module),
ident,
ns,
parent_scope,
Shadowing::Unrestricted,
finalize.map(|finalize| Finalize { used: Used::Scope, ..finalize }),
ignore_binding,
None,
);
if let Ok(binding) = item {
// The ident resolves to an item.
} else if let RibKind::Block(Some(module)) = rib.kind
&& let Ok(binding) = self.cm().resolve_ident_in_module_unadjusted(
ModuleOrUniformRoot::Module(module),
ident,
ns,
parent_scope,
Shadowing::Unrestricted,
finalize.map(|finalize| Finalize { used: Used::Scope, ..finalize }),
ignore_binding,
None,
)
{
// The ident resolves to an item in a block.
return Some(LexicalScopeBinding::Item(binding));
} else if let RibKind::Module(module) = rib.kind {
// Encountered a module item, abandon ribs and look into that module and preludes.
return self
.cm()
.early_resolve_ident_in_lexical_scope(
orig_ident,
ScopeSet::Late(ns, module, finalize.map(|finalize| finalize.node_id)),
parent_scope,
finalize,
finalize.is_some(),
ignore_binding,
None,
)
.ok()
.map(LexicalScopeBinding::Item);
}
if let RibKind::MacroDefinition(def) = rib.kind
&& def == self.macro_def(ident.span.ctxt())
{
// If an invocation of this macro created `ident`, give up on `ident`
// and switch to `ident`'s source from the macro definition.
ident.span.remove_mark();
}
}
self.cm()
.early_resolve_ident_in_lexical_scope(
orig_ident,
ScopeSet::Late(ns, module, finalize.map(|finalize| finalize.node_id)),
parent_scope,
finalize,
finalize.is_some(),
ignore_binding,
None,
)
.ok()
.map(LexicalScopeBinding::Item)
unreachable!()
}
/// Resolve an identifier in lexical scope.
@ -1171,6 +1166,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
for rib in ribs {
match rib.kind {
RibKind::Normal
| RibKind::Block(..)
| RibKind::FnOrCoroutine
| RibKind::Module(..)
| RibKind::MacroDefinition(..)
@ -1263,6 +1259,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
for rib in ribs {
let (has_generic_params, def_kind) = match rib.kind {
RibKind::Normal
| RibKind::Block(..)
| RibKind::FnOrCoroutine
| RibKind::Module(..)
| RibKind::MacroDefinition(..)
@ -1356,6 +1353,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
for rib in ribs {
let (has_generic_params, def_kind) = match rib.kind {
RibKind::Normal
| RibKind::Block(..)
| RibKind::FnOrCoroutine
| RibKind::Module(..)
| RibKind::MacroDefinition(..)

View file

@ -87,7 +87,6 @@ pub(crate) enum ImportKind<'ra> {
id: NodeId,
},
Glob {
is_prelude: bool,
// The visibility of the greatest re-export.
// n.b. `max_vis` is only used in `finalize_import` to check for re-export errors.
max_vis: Cell<Option<Visibility>>,
@ -125,12 +124,9 @@ impl<'ra> std::fmt::Debug for ImportKind<'ra> {
.field("nested", nested)
.field("id", id)
.finish(),
Glob { is_prelude, max_vis, id } => f
.debug_struct("Glob")
.field("is_prelude", is_prelude)
.field("max_vis", max_vis)
.field("id", id)
.finish(),
Glob { max_vis, id } => {
f.debug_struct("Glob").field("max_vis", max_vis).field("id", id).finish()
}
ExternCrate { source, target, id } => f
.debug_struct("ExternCrate")
.field("source", source)
@ -1073,7 +1069,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
ImportKind::Single { source, target, ref bindings, type_ns_only, id, .. } => {
(source, target, bindings, type_ns_only, id)
}
ImportKind::Glob { is_prelude, ref max_vis, id } => {
ImportKind::Glob { ref max_vis, id } => {
if import.module_path.len() <= 1 {
// HACK(eddyb) `lint_if_path_starts_with_module` needs at least
// 2 segments, so the `resolve_path` above won't trigger it.
@ -1096,8 +1092,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
module: None,
});
}
if !is_prelude
&& let Some(max_vis) = max_vis.get()
if let Some(max_vis) = max_vis.get()
&& !max_vis.is_at_least(import.vis, self.tcx)
{
let def_id = self.local_def_id(id);
@ -1485,7 +1480,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
fn resolve_glob_import(&mut self, import: Import<'ra>) {
// This function is only called for glob imports.
let ImportKind::Glob { id, is_prelude, .. } = import.kind else { unreachable!() };
let ImportKind::Glob { id, .. } = import.kind else { unreachable!() };
let ModuleOrUniformRoot::Module(module) = import.imported_module.get().unwrap() else {
self.dcx().emit_err(CannotGlobImportAllCrates { span: import.span });
@ -1504,9 +1499,6 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> {
if module == import.parent_scope.module {
return;
} else if is_prelude {
self.prelude = Some(module);
return;
}
// Add to module's glob_importers

View file

@ -192,6 +192,13 @@ pub(crate) enum RibKind<'ra> {
/// No restriction needs to be applied.
Normal,
/// We passed through an `ast::Block`.
/// Behaves like `Normal`, but also partially like `Module` if the block contains items.
/// `Block(None)` must be always processed in the same way as `Block(Some(module))`
/// with empty `module`. The module can be `None` only because creation of some definitely
/// empty modules is skipped as an optimization.
Block(Option<Module<'ra>>),
/// We passed through an impl or trait and are now in one of its
/// methods or associated types. Allow references to ty params that impl or trait
/// binds. Disallow any other upvars (including other ty params that are
@ -210,7 +217,7 @@ pub(crate) enum RibKind<'ra> {
/// All other constants aren't allowed to use generic params at all.
ConstantItem(ConstantHasGenerics, Option<(Ident, ConstantItemKind)>),
/// We passed through a module.
/// We passed through a module item.
Module(Module<'ra>),
/// We passed through a `macro_rules!` statement
@ -242,6 +249,7 @@ impl RibKind<'_> {
pub(crate) fn contains_params(&self) -> bool {
match self {
RibKind::Normal
| RibKind::Block(..)
| RibKind::FnOrCoroutine
| RibKind::ConstantItem(..)
| RibKind::Module(_)
@ -258,15 +266,8 @@ impl RibKind<'_> {
fn is_label_barrier(self) -> bool {
match self {
RibKind::Normal | RibKind::MacroDefinition(..) => false,
RibKind::AssocItem
| RibKind::FnOrCoroutine
| RibKind::Item(..)
| RibKind::ConstantItem(..)
| RibKind::Module(..)
| RibKind::ForwardGenericParamBan(_)
| RibKind::ConstParamTy
| RibKind::InlineAsmSym => true,
RibKind::FnOrCoroutine | RibKind::ConstantItem(..) => true,
kind => bug!("unexpected rib kind: {kind:?}"),
}
}
}
@ -1527,19 +1528,6 @@ impl<'a, 'ast, 'ra, 'tcx> LateResolutionVisitor<'a, 'ast, 'ra, 'tcx> {
ret
}
fn with_mod_rib<T>(&mut self, id: NodeId, f: impl FnOnce(&mut Self) -> T) -> T {
let module = self.r.expect_module(self.r.local_def_id(id).to_def_id());
// Move down in the graph.
let orig_module = replace(&mut self.parent_scope.module, module);
self.with_rib(ValueNS, RibKind::Module(module), |this| {
this.with_rib(TypeNS, RibKind::Module(module), |this| {
let ret = f(this);
this.parent_scope.module = orig_module;
ret
})
})
}
fn visit_generic_params(&mut self, params: &'ast [GenericParam], add_self_upper: bool) {
// For type parameter defaults, we have to ban access
// to following type parameters, as the GenericArgs can only
@ -2677,20 +2665,25 @@ impl<'a, 'ast, 'ra, 'tcx> LateResolutionVisitor<'a, 'ast, 'ra, 'tcx> {
}
ItemKind::Mod(..) => {
self.with_mod_rib(item.id, |this| {
if mod_inner_docs {
this.resolve_doc_links(&item.attrs, MaybeExported::Ok(item.id));
}
let old_macro_rules = this.parent_scope.macro_rules;
visit::walk_item(this, item);
// Maintain macro_rules scopes in the same way as during early resolution
// for diagnostics and doc links.
if item.attrs.iter().all(|attr| {
!attr.has_name(sym::macro_use) && !attr.has_name(sym::macro_escape)
}) {
this.parent_scope.macro_rules = old_macro_rules;
}
let module = self.r.expect_module(self.r.local_def_id(item.id).to_def_id());
let orig_module = replace(&mut self.parent_scope.module, module);
self.with_rib(ValueNS, RibKind::Module(module), |this| {
this.with_rib(TypeNS, RibKind::Module(module), |this| {
if mod_inner_docs {
this.resolve_doc_links(&item.attrs, MaybeExported::Ok(item.id));
}
let old_macro_rules = this.parent_scope.macro_rules;
visit::walk_item(this, item);
// Maintain macro_rules scopes in the same way as during early resolution
// for diagnostics and doc links.
if item.attrs.iter().all(|attr| {
!attr.has_name(sym::macro_use) && !attr.has_name(sym::macro_escape)
}) {
this.parent_scope.macro_rules = old_macro_rules;
}
})
});
self.parent_scope.module = orig_module;
}
ItemKind::Static(box ast::StaticItem {
@ -2821,9 +2814,9 @@ impl<'a, 'ast, 'ra, 'tcx> LateResolutionVisitor<'a, 'ast, 'ra, 'tcx> {
// We also can't shadow bindings from associated parent items.
for ns in [ValueNS, TypeNS] {
for parent_rib in self.ribs[ns].iter().rev() {
// Break at mod level, to account for nested items which are
// Break at module or block level, to account for nested items which are
// allowed to shadow generic param names.
if matches!(parent_rib.kind, RibKind::Module(..)) {
if matches!(parent_rib.kind, RibKind::Module(..) | RibKind::Block(..)) {
break;
}
@ -4652,16 +4645,16 @@ impl<'a, 'ast, 'ra, 'tcx> LateResolutionVisitor<'a, 'ast, 'ra, 'tcx> {
debug!("(resolving block) entering block");
// Move down in the graph, if there's an anonymous module rooted here.
let orig_module = self.parent_scope.module;
let anonymous_module = self.r.block_map.get(&block.id).cloned(); // clones a reference
let anonymous_module = self.r.block_map.get(&block.id).copied();
let mut num_macro_definition_ribs = 0;
if let Some(anonymous_module) = anonymous_module {
debug!("(resolving block) found anonymous module, moving down");
self.ribs[ValueNS].push(Rib::new(RibKind::Module(anonymous_module)));
self.ribs[TypeNS].push(Rib::new(RibKind::Module(anonymous_module)));
self.ribs[ValueNS].push(Rib::new(RibKind::Block(Some(anonymous_module))));
self.ribs[TypeNS].push(Rib::new(RibKind::Block(Some(anonymous_module))));
self.parent_scope.module = anonymous_module;
} else {
self.ribs[ValueNS].push(Rib::new(RibKind::Normal));
self.ribs[ValueNS].push(Rib::new(RibKind::Block(None)));
}
// Descend into the block.

View file

@ -849,9 +849,7 @@ impl<'ast, 'ra, 'tcx> LateResolutionVisitor<'_, 'ast, 'ra, 'tcx> {
}
// Try to find in last block rib
if let Some(rib) = &self.last_block_rib
&& let RibKind::Normal = rib.kind
{
if let Some(rib) = &self.last_block_rib {
for (ident, &res) in &rib.bindings {
if let Res::Local(_) = res
&& path.len() == 1
@ -900,7 +898,7 @@ impl<'ast, 'ra, 'tcx> LateResolutionVisitor<'_, 'ast, 'ra, 'tcx> {
if path.len() == 1 {
for rib in self.ribs[ns].iter().rev() {
let item = path[0].ident;
if let RibKind::Module(module) = rib.kind
if let RibKind::Module(module) | RibKind::Block(Some(module)) = rib.kind
&& let Some(did) = find_doc_alias_name(self.r, module, item.name)
{
return Some((did, item));
@ -2458,9 +2456,7 @@ impl<'ast, 'ra, 'tcx> LateResolutionVisitor<'_, 'ast, 'ra, 'tcx> {
}
}
if let RibKind::Module(module) = rib.kind
&& let ModuleKind::Block = module.kind
{
if let RibKind::Block(Some(module)) = rib.kind {
self.r.add_module_candidates(module, &mut names, &filter_fn, Some(ctxt));
} else if let RibKind::Module(module) = rib.kind {
// Encountered a module item, abandon ribs and look into that module and preludes.

Some files were not shown because too many files have changed in this diff Show more