fix: branches-sharing-code suggests wrongly on const and static

This commit is contained in:
yanglsh 2025-08-21 10:56:14 +08:00 committed by Linshu Yang
parent 45168a79cd
commit f83c72f00c
4 changed files with 388 additions and 28 deletions

View file

@ -9,7 +9,7 @@ use clippy_utils::{
use core::iter;
use core::ops::ControlFlow;
use rustc_errors::Applicability;
use rustc_hir::{Block, Expr, ExprKind, HirId, HirIdSet, LetStmt, Node, Stmt, StmtKind, intravisit};
use rustc_hir::{Block, Expr, ExprKind, HirId, HirIdSet, ItemKind, LetStmt, Node, Stmt, StmtKind, UseKind, intravisit};
use rustc_lint::LateContext;
use rustc_span::hygiene::walk_chain;
use rustc_span::source_map::SourceMap;
@ -108,6 +108,7 @@ struct BlockEq {
/// The name and id of every local which can be moved at the beginning and the end.
moved_locals: Vec<(HirId, Symbol)>,
}
impl BlockEq {
fn start_span(&self, b: &Block<'_>, sm: &SourceMap) -> Option<Span> {
match &b.stmts[..self.start_end_eq] {
@ -129,20 +130,33 @@ impl BlockEq {
}
/// If the statement is a local, checks if the bound names match the expected list of names.
fn eq_binding_names(s: &Stmt<'_>, names: &[(HirId, Symbol)]) -> bool {
if let StmtKind::Let(l) = s.kind {
let mut i = 0usize;
let mut res = true;
l.pat.each_binding_or_first(&mut |_, _, _, name| {
if names.get(i).is_some_and(|&(_, n)| n == name.name) {
i += 1;
} else {
res = false;
}
});
res && i == names.len()
} else {
false
fn eq_binding_names(cx: &LateContext<'_>, s: &Stmt<'_>, names: &[(HirId, Symbol)]) -> bool {
match s.kind {
StmtKind::Let(l) => {
let mut i = 0usize;
let mut res = true;
l.pat.each_binding_or_first(&mut |_, _, _, name| {
if names.get(i).is_some_and(|&(_, n)| n == name.name) {
i += 1;
} else {
res = false;
}
});
res && i == names.len()
},
StmtKind::Item(item_id)
if let [(_, name)] = names
&& let item = cx.tcx.hir_item(item_id)
&& let ItemKind::Static(_, ident, ..)
| ItemKind::Const(ident, ..)
| ItemKind::Fn { ident, .. }
| ItemKind::TyAlias(ident, ..)
| ItemKind::Use(_, UseKind::Single(ident))
| ItemKind::Mod(ident, _) = item.kind =>
{
*name == ident.name
},
_ => false,
}
}
@ -164,6 +178,7 @@ fn modifies_any_local<'tcx>(cx: &LateContext<'tcx>, s: &'tcx Stmt<'_>, locals: &
/// Checks if the given statement should be considered equal to the statement in the same
/// position for each block.
fn eq_stmts(
cx: &LateContext<'_>,
stmt: &Stmt<'_>,
blocks: &[&Block<'_>],
get_stmt: impl for<'a> Fn(&'a Block<'a>) -> Option<&'a Stmt<'a>>,
@ -178,7 +193,7 @@ fn eq_stmts(
let new_bindings = &moved_bindings[old_count..];
blocks
.iter()
.all(|b| get_stmt(b).is_some_and(|s| eq_binding_names(s, new_bindings)))
.all(|b| get_stmt(b).is_some_and(|s| eq_binding_names(cx, s, new_bindings)))
} else {
true
}) && blocks.iter().all(|b| get_stmt(b).is_some_and(|s| eq.eq_stmt(s, stmt)))
@ -218,7 +233,7 @@ fn scan_block_for_eq<'tcx>(
return true;
}
modifies_any_local(cx, stmt, &cond_locals)
|| !eq_stmts(stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals)
|| !eq_stmts(cx, stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals)
})
.map_or(block.stmts.len(), |(i, stmt)| {
adjust_by_closest_callsite(i, stmt, block.stmts[..i].iter().enumerate().rev())
@ -279,6 +294,7 @@ fn scan_block_for_eq<'tcx>(
}))
.fold(end_search_start, |init, (stmt, offset)| {
if eq_stmts(
cx,
stmt,
blocks,
|b| b.stmts.get(b.stmts.len() - offset),
@ -290,11 +306,26 @@ fn scan_block_for_eq<'tcx>(
// Clear out all locals seen at the end so far. None of them can be moved.
let stmts = &blocks[0].stmts;
for stmt in &stmts[stmts.len() - init..=stmts.len() - offset] {
if let StmtKind::Let(l) = stmt.kind {
l.pat.each_binding_or_first(&mut |_, id, _, _| {
// FIXME(rust/#120456) - is `swap_remove` correct?
eq.locals.swap_remove(&id);
});
match stmt.kind {
StmtKind::Let(l) => {
l.pat.each_binding_or_first(&mut |_, id, _, _| {
// FIXME(rust/#120456) - is `swap_remove` correct?
eq.locals.swap_remove(&id);
});
},
StmtKind::Item(item_id) => {
let item = cx.tcx.hir_item(item_id);
if let ItemKind::Static(..)
| ItemKind::Const(..)
| ItemKind::Fn { .. }
| ItemKind::TyAlias(..)
| ItemKind::Use(..)
| ItemKind::Mod(..) = item.kind
{
eq.local_items.swap_remove(&item.owner_id.to_def_id());
}
},
_ => {},
}
}
moved_locals.truncate(moved_locals_at_start);

View file

@ -4,14 +4,17 @@ use crate::source::{SpanRange, SpanRangeExt, walk_span_to_context};
use crate::tokenize_with_text;
use rustc_ast::ast;
use rustc_ast::ast::InlineAsmTemplatePiece;
use rustc_data_structures::fx::FxHasher;
use rustc_data_structures::fx::{FxHasher, FxIndexMap};
use rustc_hir::MatchSource::TryDesugar;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::def_id::DefId;
use rustc_hir::{
AssocItemConstraint, BinOpKind, BindingMode, Block, BodyId, ByRef, Closure, ConstArg, ConstArgKind, Expr,
ExprField, ExprKind, FnRetTy, GenericArg, GenericArgs, HirId, HirIdMap, InlineAsmOperand, LetExpr, Lifetime,
LifetimeKind, Node, Pat, PatExpr, PatExprKind, PatField, PatKind, Path, PathSegment, PrimTy, QPath, Stmt, StmtKind,
StructTailExpr, TraitBoundModifiers, Ty, TyKind, TyPat, TyPatKind,
AssocItemConstraint, BinOpKind, BindingMode, Block, BodyId, ByRef, Closure, ConstArg, ConstArgKind, ConstItemRhs,
Expr, ExprField, ExprKind, FnDecl, FnRetTy, FnSig, GenericArg, GenericArgs, GenericBound, GenericBounds,
GenericParam, GenericParamKind, GenericParamSource, Generics, HirId, HirIdMap, InlineAsmOperand, ItemId, ItemKind,
LetExpr, Lifetime, LifetimeKind, LifetimeParamKind, Node, ParamName, Pat, PatExpr, PatExprKind, PatField, PatKind,
Path, PathSegment, PreciseCapturingArgKind, PrimTy, QPath, Stmt, StmtKind, StructTailExpr, TraitBoundModifiers, Ty,
TyKind, TyPat, TyPatKind, UseKind, WherePredicate, WherePredicateKind,
};
use rustc_lexer::{FrontmatterAllowed, TokenKind, tokenize};
use rustc_lint::LateContext;
@ -106,6 +109,7 @@ impl<'a, 'tcx> SpanlessEq<'a, 'tcx> {
left_ctxt: SyntaxContext::root(),
right_ctxt: SyntaxContext::root(),
locals: HirIdMap::default(),
local_items: FxIndexMap::default(),
}
}
@ -144,6 +148,7 @@ pub struct HirEqInterExpr<'a, 'b, 'tcx> {
// right. For example, when comparing `{ let x = 1; x + 2 }` and `{ let y = 1; y + 2 }`,
// these blocks are considered equal since `x` is mapped to `y`.
pub locals: HirIdMap<HirId>,
pub local_items: FxIndexMap<DefId, DefId>,
}
impl HirEqInterExpr<'_, '_, '_> {
@ -168,6 +173,189 @@ impl HirEqInterExpr<'_, '_, '_> {
&& self.eq_pat(l.pat, r.pat)
},
(StmtKind::Expr(l), StmtKind::Expr(r)) | (StmtKind::Semi(l), StmtKind::Semi(r)) => self.eq_expr(l, r),
(StmtKind::Item(l), StmtKind::Item(r)) => self.eq_item(*l, *r),
_ => false,
}
}
pub fn eq_item(&mut self, l: ItemId, r: ItemId) -> bool {
let left = self.inner.cx.tcx.hir_item(l);
let right = self.inner.cx.tcx.hir_item(r);
let eq = match (left.kind, right.kind) {
(
ItemKind::Const(l_ident, l_generics, l_ty, ConstItemRhs::Body(l_body)),
ItemKind::Const(r_ident, r_generics, r_ty, ConstItemRhs::Body(r_body)),
) => {
l_ident.name == r_ident.name
&& self.eq_generics(l_generics, r_generics)
&& self.eq_ty(l_ty, r_ty)
&& self.eq_body(l_body, r_body)
},
(ItemKind::Static(l_mut, l_ident, l_ty, l_body), ItemKind::Static(r_mut, r_ident, r_ty, r_body)) => {
l_mut == r_mut && l_ident.name == r_ident.name && self.eq_ty(l_ty, r_ty) && self.eq_body(l_body, r_body)
},
(
ItemKind::Fn {
sig: l_sig,
ident: l_ident,
generics: l_generics,
body: l_body,
has_body: l_has_body,
},
ItemKind::Fn {
sig: r_sig,
ident: r_ident,
generics: r_generics,
body: r_body,
has_body: r_has_body,
},
) => {
l_ident.name == r_ident.name
&& (l_has_body == r_has_body)
&& self.eq_fn_sig(&l_sig, &r_sig)
&& self.eq_generics(l_generics, r_generics)
&& self.eq_body(l_body, r_body)
},
(ItemKind::TyAlias(l_ident, l_generics, l_ty), ItemKind::TyAlias(r_ident, r_generics, r_ty)) => {
l_ident.name == r_ident.name && self.eq_generics(l_generics, r_generics) && self.eq_ty(l_ty, r_ty)
},
(ItemKind::Use(l_path, l_kind), ItemKind::Use(r_path, r_kind)) => {
self.eq_path_segments(l_path.segments, r_path.segments)
&& match (l_kind, r_kind) {
(UseKind::Single(l_ident), UseKind::Single(r_ident)) => l_ident.name == r_ident.name,
(UseKind::Glob, UseKind::Glob) | (UseKind::ListStem, UseKind::ListStem) => true,
_ => false,
}
},
(ItemKind::Mod(l_ident, l_mod), ItemKind::Mod(r_ident, r_mod)) => {
l_ident.name == r_ident.name && over(l_mod.item_ids, r_mod.item_ids, |l, r| self.eq_item(*l, *r))
},
_ => false,
};
if eq {
self.local_items.insert(l.owner_id.to_def_id(), r.owner_id.to_def_id());
}
eq
}
fn eq_fn_sig(&mut self, left: &FnSig<'_>, right: &FnSig<'_>) -> bool {
left.header.safety == right.header.safety
&& left.header.constness == right.header.constness
&& left.header.asyncness == right.header.asyncness
&& left.header.abi == right.header.abi
&& self.eq_fn_decl(left.decl, right.decl)
}
fn eq_fn_decl(&mut self, left: &FnDecl<'_>, right: &FnDecl<'_>) -> bool {
over(left.inputs, right.inputs, |l, r| self.eq_ty(l, r))
&& (match (left.output, right.output) {
(FnRetTy::DefaultReturn(_), FnRetTy::DefaultReturn(_)) => true,
(FnRetTy::Return(l_ty), FnRetTy::Return(r_ty)) => self.eq_ty(l_ty, r_ty),
_ => false,
})
&& left.c_variadic == right.c_variadic
&& left.implicit_self == right.implicit_self
&& left.lifetime_elision_allowed == right.lifetime_elision_allowed
}
fn eq_generics(&mut self, left: &Generics<'_>, right: &Generics<'_>) -> bool {
self.eq_generics_param(left.params, right.params)
&& self.eq_generics_predicate(left.predicates, right.predicates)
}
fn eq_generics_predicate(&mut self, left: &[WherePredicate<'_>], right: &[WherePredicate<'_>]) -> bool {
over(left, right, |l, r| match (l.kind, r.kind) {
(WherePredicateKind::BoundPredicate(l_bound), WherePredicateKind::BoundPredicate(r_bound)) => {
l_bound.origin == r_bound.origin
&& self.eq_ty(l_bound.bounded_ty, r_bound.bounded_ty)
&& self.eq_generics_param(l_bound.bound_generic_params, r_bound.bound_generic_params)
&& self.eq_generics_bound(l_bound.bounds, r_bound.bounds)
},
(WherePredicateKind::RegionPredicate(l_region), WherePredicateKind::RegionPredicate(r_region)) => {
Self::eq_lifetime(l_region.lifetime, r_region.lifetime)
&& self.eq_generics_bound(l_region.bounds, r_region.bounds)
},
(WherePredicateKind::EqPredicate(l_eq), WherePredicateKind::EqPredicate(r_eq)) => {
self.eq_ty(l_eq.lhs_ty, r_eq.lhs_ty)
},
_ => false,
})
}
fn eq_generics_bound(&mut self, left: GenericBounds<'_>, right: GenericBounds<'_>) -> bool {
over(left, right, |l, r| match (l, r) {
(GenericBound::Trait(l_trait), GenericBound::Trait(r_trait)) => {
l_trait.modifiers == r_trait.modifiers
&& self.eq_path(l_trait.trait_ref.path, r_trait.trait_ref.path)
&& self.eq_generics_param(l_trait.bound_generic_params, r_trait.bound_generic_params)
},
(GenericBound::Outlives(l_lifetime), GenericBound::Outlives(r_lifetime)) => {
Self::eq_lifetime(l_lifetime, r_lifetime)
},
(GenericBound::Use(l_capture, _), GenericBound::Use(r_capture, _)) => {
over(l_capture, r_capture, |l, r| match (l, r) {
(PreciseCapturingArgKind::Lifetime(l_lifetime), PreciseCapturingArgKind::Lifetime(r_lifetime)) => {
Self::eq_lifetime(l_lifetime, r_lifetime)
},
(PreciseCapturingArgKind::Param(l_param), PreciseCapturingArgKind::Param(r_param)) => {
l_param.ident == r_param.ident && l_param.res == r_param.res
},
_ => false,
})
},
_ => false,
})
}
fn eq_generics_param(&mut self, left: &[GenericParam<'_>], right: &[GenericParam<'_>]) -> bool {
over(left, right, |l, r| {
(match (l.name, r.name) {
(ParamName::Plain(l_ident), ParamName::Plain(r_ident))
| (ParamName::Error(l_ident), ParamName::Error(r_ident)) => l_ident.name == r_ident.name,
(ParamName::Fresh, ParamName::Fresh) => true,
_ => false,
}) && l.pure_wrt_drop == r.pure_wrt_drop
&& self.eq_generics_param_kind(&l.kind, &r.kind)
&& (matches!(
(l.source, r.source),
(GenericParamSource::Generics, GenericParamSource::Generics)
| (GenericParamSource::Binder, GenericParamSource::Binder)
))
})
}
fn eq_generics_param_kind(&mut self, left: &GenericParamKind<'_>, right: &GenericParamKind<'_>) -> bool {
match (left, right) {
(GenericParamKind::Lifetime { kind: l_kind }, GenericParamKind::Lifetime { kind: r_kind }) => {
match (l_kind, r_kind) {
(LifetimeParamKind::Explicit, LifetimeParamKind::Explicit)
| (LifetimeParamKind::Error, LifetimeParamKind::Error) => true,
(LifetimeParamKind::Elided(l_lifetime_kind), LifetimeParamKind::Elided(r_lifetime_kind)) => {
l_lifetime_kind == r_lifetime_kind
},
_ => false,
}
},
(
GenericParamKind::Type {
default: l_default,
synthetic: l_synthetic,
},
GenericParamKind::Type {
default: r_default,
synthetic: r_synthetic,
},
) => both(*l_default, *r_default, |l, r| self.eq_ty(l, r)) && l_synthetic == r_synthetic,
(
GenericParamKind::Const {
ty: l_ty,
default: l_default,
},
GenericParamKind::Const {
ty: r_ty,
default: r_default,
},
) => self.eq_ty(l_ty, r_ty) && both(*l_default, *r_default, |l, r| self.eq_const_arg(l, r)),
_ => false,
}
}
@ -563,6 +751,17 @@ impl HirEqInterExpr<'_, '_, '_> {
match (left.res, right.res) {
(Res::Local(l), Res::Local(r)) => l == r || self.locals.get(&l) == Some(&r),
(Res::Local(_), _) | (_, Res::Local(_)) => false,
(Res::Def(l_kind, l), Res::Def(r_kind, r))
if l_kind == r_kind
&& let DefKind::Const
| DefKind::Static { .. }
| DefKind::Fn
| DefKind::TyAlias
| DefKind::Use
| DefKind::Mod = l_kind =>
{
(l == r || self.local_items.get(&l) == Some(&r)) && self.eq_path_segments(left.segments, right.segments)
},
_ => self.eq_path_segments(left.segments, right.segments),
}
}

View file

@ -300,3 +300,65 @@ fn issue15004() {
//~^ branches_sharing_code
};
}
pub fn issue15347<T>() -> isize {
if false {
static A: isize = 4;
return A;
} else {
static A: isize = 5;
return A;
}
if false {
//~^ branches_sharing_code
type ISize = isize;
return ISize::MAX;
} else {
type ISize = isize;
return ISize::MAX;
}
if false {
//~^ branches_sharing_code
fn foo() -> isize {
4
}
return foo();
} else {
fn foo() -> isize {
4
}
return foo();
}
if false {
//~^ branches_sharing_code
use std::num::NonZeroIsize;
return NonZeroIsize::new(4).unwrap().get();
} else {
use std::num::NonZeroIsize;
return NonZeroIsize::new(4).unwrap().get();
}
if false {
//~^ branches_sharing_code
const B: isize = 5;
return B;
} else {
const B: isize = 5;
return B;
}
// Should not lint!
const A: isize = 1;
if false {
const B: isize = A;
return B;
} else {
const C: isize = A;
return C;
}
todo!()
}

View file

@ -202,5 +202,73 @@ LL ~ }
LL ~ 1;
|
error: aborting due to 12 previous errors
error: all if blocks contain the same code at the start
--> tests/ui/branches_sharing_code/shared_at_bottom.rs:313:5
|
LL | / if false {
LL | |
LL | | type ISize = isize;
LL | | return ISize::MAX;
| |__________________________^
|
help: consider moving these statements before the if
|
LL ~ type ISize = isize;
LL + return ISize::MAX;
LL + if false {
|
error: all if blocks contain the same code at the start
--> tests/ui/branches_sharing_code/shared_at_bottom.rs:322:5
|
LL | / if false {
LL | |
LL | | fn foo() -> isize {
LL | | 4
LL | | }
LL | | return foo();
| |_____________________^
|
help: consider moving these statements before the if
|
LL ~ fn foo() -> isize {
LL + 4
LL + }
LL + return foo();
LL + if false {
|
error: all if blocks contain the same code at the start
--> tests/ui/branches_sharing_code/shared_at_bottom.rs:335:5
|
LL | / if false {
LL | |
LL | | use std::num::NonZeroIsize;
LL | | return NonZeroIsize::new(4).unwrap().get();
| |___________________________________________________^
|
help: consider moving these statements before the if
|
LL ~ use std::num::NonZeroIsize;
LL + return NonZeroIsize::new(4).unwrap().get();
LL + if false {
|
error: all if blocks contain the same code at the start
--> tests/ui/branches_sharing_code/shared_at_bottom.rs:344:5
|
LL | / if false {
LL | |
LL | | const B: isize = 5;
LL | | return B;
| |_________________^
|
help: consider moving these statements before the if
|
LL ~ const B: isize = 5;
LL + return B;
LL + if false {
|
error: aborting due to 16 previous errors