From 2666f39d3e5d4b354065056102ef4856aa23fdae Mon Sep 17 00:00:00 2001 From: Guillaume Gomez Date: Wed, 3 Jan 2024 22:14:36 +0100 Subject: [PATCH] Detect unconditional recursion between Default trait impl and static methods --- clippy_lints/src/lib.rs | 2 +- clippy_lints/src/unconditional_recursion.rs | 271 +++++++++++++++++--- 2 files changed, 231 insertions(+), 42 deletions(-) diff --git a/clippy_lints/src/lib.rs b/clippy_lints/src/lib.rs index 854c111f9f53..efdd39259497 100644 --- a/clippy_lints/src/lib.rs +++ b/clippy_lints/src/lib.rs @@ -1083,7 +1083,7 @@ pub fn register_lints(store: &mut rustc_lint::LintStore, conf: &'static Conf) { store.register_late_pass(|_| Box::new(repeat_vec_with_capacity::RepeatVecWithCapacity)); store.register_late_pass(|_| Box::new(uninhabited_references::UninhabitedReferences)); store.register_late_pass(|_| Box::new(ineffective_open_options::IneffectiveOpenOptions)); - store.register_late_pass(|_| Box::new(unconditional_recursion::UnconditionalRecursion)); + store.register_late_pass(|_| Box::::default()); store.register_late_pass(move |_| { Box::new(pub_underscore_fields::PubUnderscoreFields { behavior: pub_underscore_fields_behavior, diff --git a/clippy_lints/src/unconditional_recursion.rs b/clippy_lints/src/unconditional_recursion.rs index 5366b5513d36..e90306ded61c 100644 --- a/clippy_lints/src/unconditional_recursion.rs +++ b/clippy_lints/src/unconditional_recursion.rs @@ -1,13 +1,19 @@ use clippy_utils::diagnostics::span_lint_and_then; -use clippy_utils::{expr_or_init, get_trait_def_id}; +use clippy_utils::{expr_or_init, get_trait_def_id, path_def_id}; use rustc_ast::BinOpKind; +use rustc_data_structures::fx::FxHashMap; +use rustc_hir as hir; +use rustc_hir::def::{DefKind, Res}; use rustc_hir::def_id::{DefId, LocalDefId}; -use rustc_hir::intravisit::{walk_body, FnKind}; -use rustc_hir::{Body, Expr, ExprKind, FnDecl, Item, ItemKind, Node}; +use rustc_hir::intravisit::{walk_body, walk_expr, FnKind, Visitor}; +use rustc_hir::{Body, Expr, ExprKind, FnDecl, HirId, Item, ItemKind, Node, QPath, TyKind}; +use rustc_hir_analysis::hir_ty_to_ty; use rustc_lint::{LateContext, LateLintPass}; -use rustc_middle::ty::{self, Ty}; -use rustc_session::declare_lint_pass; -use rustc_span::symbol::Ident; +use rustc_middle::hir::map::Map; +use rustc_middle::hir::nested_filter; +use rustc_middle::ty::{self, AssocKind, Ty, TyCtxt}; +use rustc_session::impl_lint_pass; +use rustc_span::symbol::{kw, Ident}; use rustc_span::{sym, Span}; use rustc_trait_selection::traits::error_reporting::suggestions::ReturnsVisitor; @@ -42,7 +48,26 @@ declare_clippy_lint! { "detect unconditional recursion in some traits implementation" } -declare_lint_pass!(UnconditionalRecursion => [UNCONDITIONAL_RECURSION]); +#[derive(Default)] +pub struct UnconditionalRecursion { + /// The key is the `DefId` of the type implementing the `Default` trait and the value is the + /// `DefId` of the return call. + default_impl_for_type: FxHashMap, +} + +impl_lint_pass!(UnconditionalRecursion => [UNCONDITIONAL_RECURSION]); + +fn span_error(cx: &LateContext<'_>, method_span: Span, expr: &Expr<'_>) { + span_lint_and_then( + cx, + UNCONDITIONAL_RECURSION, + method_span, + "function cannot return without recursing", + |diag| { + diag.span_note(expr.span, "recursive call site"); + }, + ); +} fn get_ty_def_id(ty: Ty<'_>) -> Option { match ty.peel_refs().kind() { @@ -52,17 +77,60 @@ fn get_ty_def_id(ty: Ty<'_>) -> Option { } } -fn has_conditional_return(body: &Body<'_>, expr: &Expr<'_>) -> bool { +fn get_hir_ty_def_id(tcx: TyCtxt<'_>, hir_ty: rustc_hir::Ty<'_>) -> Option { + let TyKind::Path(qpath) = hir_ty.kind else { return None }; + match qpath { + QPath::Resolved(_, path) => path.res.opt_def_id(), + QPath::TypeRelative(_, _) => { + let ty = hir_ty_to_ty(tcx, &hir_ty); + + match ty.kind() { + ty::Alias(ty::Projection, proj) => { + Res::::Def(DefKind::Trait, proj.trait_ref(tcx).def_id).opt_def_id() + }, + _ => None, + } + }, + QPath::LangItem(..) => None, + } +} + +fn get_return_calls_in_body<'tcx>(body: &'tcx Body<'tcx>) -> Vec<&'tcx Expr<'tcx>> { let mut visitor = ReturnsVisitor::default(); - walk_body(&mut visitor, body); - match visitor.returns.as_slice() { + visitor.visit_body(body); + visitor.returns +} + +fn has_conditional_return(body: &Body<'_>, expr: &Expr<'_>) -> bool { + match get_return_calls_in_body(body).as_slice() { [] => false, [return_expr] => return_expr.hir_id != expr.hir_id, _ => true, } } +fn get_impl_trait_def_id(cx: &LateContext<'_>, method_def_id: LocalDefId) -> Option { + let hir_id = cx.tcx.local_def_id_to_hir_id(method_def_id); + if let Some(( + _, + Node::Item(Item { + kind: ItemKind::Impl(impl_), + owner_id, + .. + }), + )) = cx.tcx.hir().parent_iter(hir_id).next() + // We exclude `impl` blocks generated from rustc's proc macros. + && !cx.tcx.has_attr(*owner_id, sym::automatically_derived) + // It is a implementation of a trait. + && let Some(trait_) = impl_.of_trait + { + trait_.trait_def_id() + } else { + None + } +} + #[allow(clippy::unnecessary_def_path)] fn check_partial_eq(cx: &LateContext<'_>, method_span: Span, method_def_id: LocalDefId, name: Ident, expr: &Expr<'_>) { let args = cx @@ -75,20 +143,7 @@ fn check_partial_eq(cx: &LateContext<'_>, method_span: Span, method_def_id: Loca && let Some(other_arg) = get_ty_def_id(*other_arg) // The two arguments are of the same type. && self_arg == other_arg - && let hir_id = cx.tcx.local_def_id_to_hir_id(method_def_id) - && let Some(( - _, - Node::Item(Item { - kind: ItemKind::Impl(impl_), - owner_id, - .. - }), - )) = cx.tcx.hir().parent_iter(hir_id).next() - // We exclude `impl` blocks generated from rustc's proc macros. - && !cx.tcx.has_attr(*owner_id, sym::automatically_derived) - // It is a implementation of a trait. - && let Some(trait_) = impl_.of_trait - && let Some(trait_def_id) = trait_.trait_def_id() + && let Some(trait_def_id) = get_impl_trait_def_id(cx, method_def_id) // The trait is `PartialEq`. && Some(trait_def_id) == get_trait_def_id(cx, &["core", "cmp", "PartialEq"]) { @@ -125,15 +180,7 @@ fn check_partial_eq(cx: &LateContext<'_>, method_span: Span, method_def_id: Loca _ => false, }; if is_bad { - span_lint_and_then( - cx, - UNCONDITIONAL_RECURSION, - method_span, - "function cannot return without recursing", - |diag| { - diag.span_note(expr.span, "recursive call site"); - }, - ); + span_error(cx, method_span, expr); } } } @@ -177,15 +224,156 @@ fn check_to_string(cx: &LateContext<'_>, method_span: Span, method_def_id: Local _ => false, }; if is_bad { - span_lint_and_then( + span_error(cx, method_span, expr); + } + } +} + +fn is_default_method_on_current_ty(tcx: TyCtxt<'_>, qpath: QPath<'_>, implemented_ty_id: DefId) -> bool { + match qpath { + QPath::Resolved(_, path) => match path.segments { + [first, .., last] => last.ident.name == kw::Default && first.res.opt_def_id() == Some(implemented_ty_id), + _ => false, + }, + QPath::TypeRelative(ty, segment) => { + if segment.ident.name != kw::Default { + return false; + } + if matches!( + ty.kind, + TyKind::Path(QPath::Resolved( + _, + hir::Path { + res: Res::SelfTyAlias { .. }, + .. + }, + )) + ) { + return true; + } + get_hir_ty_def_id(tcx, *ty) == Some(implemented_ty_id) + }, + QPath::LangItem(..) => false, + } +} + +struct CheckCalls<'a, 'tcx> { + cx: &'a LateContext<'tcx>, + map: Map<'tcx>, + implemented_ty_id: DefId, + found_default_call: bool, + method_span: Span, +} + +impl<'a, 'tcx> Visitor<'tcx> for CheckCalls<'a, 'tcx> +where + 'tcx: 'a, +{ + type NestedFilter = nested_filter::OnlyBodies; + + fn nested_visit_map(&mut self) -> Self::Map { + self.map + } + + #[allow(clippy::unnecessary_def_path)] + fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { + if self.found_default_call { + return; + } + walk_expr(self, expr); + + if let ExprKind::Call(f, _) = expr.kind + && let ExprKind::Path(qpath) = f.kind + && is_default_method_on_current_ty(self.cx.tcx, qpath, self.implemented_ty_id) + && let Some(method_def_id) = path_def_id(self.cx, f) + && let Some(trait_def_id) = self.cx.tcx.trait_of_item(method_def_id) + && Some(trait_def_id) == get_trait_def_id(self.cx, &["core", "default", "Default"]) + { + self.found_default_call = true; + span_error(self.cx, self.method_span, expr); + } + } +} + +impl UnconditionalRecursion { + #[allow(clippy::unnecessary_def_path)] + fn init_default_impl_for_type_if_needed(&mut self, cx: &LateContext<'_>) { + if self.default_impl_for_type.is_empty() + && let Some(default_trait_id) = get_trait_def_id(cx, &["core", "default", "Default"]) + { + let impls = cx.tcx.trait_impls_of(default_trait_id); + for (ty, impl_def_ids) in impls.non_blanket_impls() { + let Some(self_def_id) = ty.def() else { continue }; + for impl_def_id in impl_def_ids { + if !cx.tcx.has_attr(*impl_def_id, sym::automatically_derived) && + let Some(assoc_item) = cx + .tcx + .associated_items(impl_def_id) + .in_definition_order() + // We're not interested in foreign implementations of the `Default` trait. + .find(|item| { + item.kind == AssocKind::Fn && item.def_id.is_local() && item.name == kw::Default + }) + && let Some(body_node) = cx.tcx.hir().get_if_local(assoc_item.def_id) + && let Some(body_id) = body_node.body_id() + && let body = cx.tcx.hir().body(body_id) + // We don't want to keep it if it has conditional return. + && let [return_expr] = get_return_calls_in_body(body).as_slice() + && let ExprKind::Call(call_expr, _) = return_expr.kind + // We need to use typeck here to infer the actual function being called. + && let body_def_id = cx.tcx.hir().enclosing_body_owner(call_expr.hir_id) + && let Some(body_owner) = cx.tcx.hir().maybe_body_owned_by(body_def_id) + && let typeck = cx.tcx.typeck_body(body_owner) + && let Some(call_def_id) = typeck.type_dependent_def_id(call_expr.hir_id) + { + self.default_impl_for_type.insert(self_def_id, call_def_id); + } + } + } + } + } + + fn check_default_new<'tcx>( + &mut self, + cx: &LateContext<'tcx>, + decl: &FnDecl<'tcx>, + body: &'tcx Body<'tcx>, + method_span: Span, + method_def_id: LocalDefId, + ) { + // We're only interested into static methods. + if decl.implicit_self.has_implicit_self() { + return; + } + // We don't check trait implementations. + if get_impl_trait_def_id(cx, method_def_id).is_some() { + return; + } + + let hir_id = cx.tcx.local_def_id_to_hir_id(method_def_id); + if let Some(( + _, + Node::Item(Item { + kind: ItemKind::Impl(impl_), + .. + }), + )) = cx.tcx.hir().parent_iter(hir_id).next() + && let Some(implemented_ty_id) = get_hir_ty_def_id(cx.tcx, *impl_.self_ty) + && { + self.init_default_impl_for_type_if_needed(cx); + true + } + && let Some(return_def_id) = self.default_impl_for_type.get(&implemented_ty_id) + && method_def_id.to_def_id() == *return_def_id + { + let mut c = CheckCalls { cx, - UNCONDITIONAL_RECURSION, + map: cx.tcx.hir(), + implemented_ty_id, + found_default_call: false, method_span, - "function cannot return without recursing", - |diag| { - diag.span_note(expr.span, "recursive call site"); - }, - ); + }; + walk_body(&mut c, body); } } } @@ -195,7 +383,7 @@ impl<'tcx> LateLintPass<'tcx> for UnconditionalRecursion { &mut self, cx: &LateContext<'tcx>, kind: FnKind<'tcx>, - _decl: &'tcx FnDecl<'tcx>, + decl: &'tcx FnDecl<'tcx>, body: &'tcx Body<'tcx>, method_span: Span, method_def_id: LocalDefId, @@ -211,6 +399,7 @@ impl<'tcx> LateLintPass<'tcx> for UnconditionalRecursion { } else if name.name == sym::to_string { check_to_string(cx, method_span, method_def_id, name, expr); } + self.check_default_new(cx, decl, body, method_span, method_def_id); } } }