diff --git a/src/librustc_typeck/check/mod.rs b/src/librustc_typeck/check/mod.rs index a80550486d62..4a0b3879cb99 100644 --- a/src/librustc_typeck/check/mod.rs +++ b/src/librustc_typeck/check/mod.rs @@ -3687,6 +3687,40 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { } } + /// If `expr` is a `match` expression that has only one non-`!` arm, use that arm's tail + /// expression's `Span`, otherwise return `expr.span`. This is done to give bettern errors + /// when given code like the following: + /// ```text + /// if false { return 0i32; } else { 1u32 } + /// // ^^^^ point at this instead of the whole `if` expression + /// ``` + fn get_expr_coercion_span(&self, expr: &hir::Expr) -> syntax_pos::Span { + if let hir::ExprKind::Match(_, arms, _) = &expr.node { + let arm_spans: Vec = arms.iter().filter_map(|arm| { + self.in_progress_tables + .and_then(|tables| tables.borrow().node_type_opt(arm.body.hir_id)) + .and_then(|arm_ty| { + if arm_ty.is_never() { + None + } else { + Some(match &arm.body.node { + // Point at the tail expression when possible. + hir::ExprKind::Block(block, _) => block.expr + .as_ref() + .map(|e| e.span) + .unwrap_or(block.span), + _ => arm.body.span, + }) + } + }) + }).collect(); + if arm_spans.len() == 1 { + return arm_spans[0]; + } + } + expr.span + } + fn check_block_with_expected( &self, blk: &'tcx hir::Block, @@ -3746,12 +3780,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let coerce = ctxt.coerce.as_mut().unwrap(); if let Some(tail_expr_ty) = tail_expr_ty { let tail_expr = tail_expr.unwrap(); - let cause = self.cause(tail_expr.span, - ObligationCauseCode::BlockTailExpression(blk.hir_id)); - coerce.coerce(self, - &cause, - tail_expr, - tail_expr_ty); + let span = self.get_expr_coercion_span(tail_expr); + let cause = self.cause(span, ObligationCauseCode::BlockTailExpression(blk.hir_id)); + coerce.coerce(self, &cause, tail_expr, tail_expr_ty); } else { // Subtle: if there is no explicit tail expression, // that is typically equivalent to a tail expression diff --git a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs index 95b40368143e..d416db628c03 100644 --- a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs +++ b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.rs @@ -17,10 +17,10 @@ fn bar() -> impl std::fmt::Display { fn baz() -> impl std::fmt::Display { if false { - //~^ ERROR mismatched types return 0i32; } else { 1u32 + //~^ ERROR mismatched types } } diff --git a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr index ee1e36081e77..47644d66d1a2 100644 --- a/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr +++ b/src/test/ui/point-to-type-err-cause-on-impl-trait-return.stderr @@ -29,18 +29,16 @@ LL | return 1u32; found type `u32` error[E0308]: mismatched types - --> $DIR/point-to-type-err-cause-on-impl-trait-return.rs:19:5 + --> $DIR/point-to-type-err-cause-on-impl-trait-return.rs:22:9 | -LL | fn baz() -> impl std::fmt::Display { - | ---------------------- expected because this return type... -LL | / if false { -LL | | -LL | | return 0i32; - | | ---- ...is found to be `i32` here -LL | | } else { -LL | | 1u32 -LL | | } - | |_____^ expected i32, found u32 +LL | fn baz() -> impl std::fmt::Display { + | ---------------------- expected because this return type... +LL | if false { +LL | return 0i32; + | ---- ...is found to be `i32` here +LL | } else { +LL | 1u32 + | ^^^^ expected i32, found u32 | = note: expected type `i32` found type `u32`