From 6bf6f4ff1dd8d342c061708041810c64fe983ab8 Mon Sep 17 00:00:00 2001 From: Chayim Refael Friedman Date: Sun, 23 Jan 2022 05:39:26 +0200 Subject: [PATCH] Lower `let` expressions --- crates/hir_def/src/body/lower.rs | 73 ++----------- crates/hir_def/src/body/scope.rs | 102 +++++++++++------- crates/hir_def/src/expr.rs | 16 +-- .../mbe/tt_conversion.rs | 6 +- 4 files changed, 87 insertions(+), 110 deletions(-) diff --git a/crates/hir_def/src/body/lower.rs b/crates/hir_def/src/body/lower.rs index 7cbeef1488a0..06ad7ce4cd08 100644 --- a/crates/hir_def/src/body/lower.rs +++ b/crates/hir_def/src/body/lower.rs @@ -28,7 +28,7 @@ use crate::{ db::DefDatabase, expr::{ dummy_expr_id, Array, BindingAnnotation, Expr, ExprId, Label, LabelId, Literal, MatchArm, - MatchGuard, Pat, PatId, RecordFieldPat, RecordLitField, Statement, + Pat, PatId, RecordFieldPat, RecordLitField, Statement, }, intern::Interned, item_scope::BuiltinShadowMode, @@ -155,9 +155,6 @@ impl ExprCollector<'_> { fn alloc_expr_desugared(&mut self, expr: Expr) -> ExprId { self.make_expr(expr, Err(SyntheticSyntax)) } - fn unit(&mut self) -> ExprId { - self.alloc_expr_desugared(Expr::Tuple { exprs: Box::default() }) - } fn missing_expr(&mut self) -> ExprId { self.alloc_expr_desugared(Expr::Missing) } @@ -215,33 +212,15 @@ impl ExprCollector<'_> { } }); - let condition = match e.condition() { - None => self.missing_expr(), - Some(condition) => match condition.pat() { - None => self.collect_expr_opt(condition.expr()), - // if let -- desugar to match - Some(pat) => { - let pat = self.collect_pat(pat); - let match_expr = self.collect_expr_opt(condition.expr()); - let placeholder_pat = self.missing_pat(); - let arms = vec![ - MatchArm { pat, expr: then_branch, guard: None }, - MatchArm { - pat: placeholder_pat, - expr: else_branch.unwrap_or_else(|| self.unit()), - guard: None, - }, - ] - .into(); - return Some( - self.alloc_expr(Expr::Match { expr: match_expr, arms }, syntax_ptr), - ); - } - }, - }; + let condition = self.collect_expr_opt(e.condition()); self.alloc_expr(Expr::If { condition, then_branch, else_branch }, syntax_ptr) } + ast::Expr::LetExpr(e) => { + let pat = self.collect_pat_opt(e.pat()); + let expr = self.collect_expr_opt(e.expr()); + self.alloc_expr(Expr::Let { pat, expr }, syntax_ptr) + } ast::Expr::BlockExpr(e) => match e.modifier() { Some(ast::BlockModifier::Try(_)) => { let body = self.collect_block(e); @@ -282,31 +261,7 @@ impl ExprCollector<'_> { let label = e.label().map(|label| self.collect_label(label)); let body = self.collect_block_opt(e.loop_body()); - let condition = match e.condition() { - None => self.missing_expr(), - Some(condition) => match condition.pat() { - None => self.collect_expr_opt(condition.expr()), - // if let -- desugar to match - Some(pat) => { - cov_mark::hit!(infer_resolve_while_let); - let pat = self.collect_pat(pat); - let match_expr = self.collect_expr_opt(condition.expr()); - let placeholder_pat = self.missing_pat(); - let break_ = - self.alloc_expr_desugared(Expr::Break { expr: None, label: None }); - let arms = vec![ - MatchArm { pat, expr: body, guard: None }, - MatchArm { pat: placeholder_pat, expr: break_, guard: None }, - ] - .into(); - let match_expr = - self.alloc_expr_desugared(Expr::Match { expr: match_expr, arms }); - return Some( - self.alloc_expr(Expr::Loop { body: match_expr, label }, syntax_ptr), - ); - } - }, - }; + let condition = self.collect_expr_opt(e.condition()); self.alloc_expr(Expr::While { condition, body, label }, syntax_ptr) } @@ -352,15 +307,9 @@ impl ExprCollector<'_> { self.check_cfg(&arm).map(|()| MatchArm { pat: self.collect_pat_opt(arm.pat()), expr: self.collect_expr_opt(arm.expr()), - guard: arm.guard().map(|guard| match guard.pat() { - Some(pat) => MatchGuard::IfLet { - pat: self.collect_pat(pat), - expr: self.collect_expr_opt(guard.expr()), - }, - None => { - MatchGuard::If { expr: self.collect_expr_opt(guard.expr()) } - } - }), + guard: arm + .guard() + .map(|guard| self.collect_expr_opt(guard.condition())), }) }) .collect() diff --git a/crates/hir_def/src/body/scope.rs b/crates/hir_def/src/body/scope.rs index 2658eece8e85..505d33fa482a 100644 --- a/crates/hir_def/src/body/scope.rs +++ b/crates/hir_def/src/body/scope.rs @@ -8,7 +8,7 @@ use rustc_hash::FxHashMap; use crate::{ body::Body, db::DefDatabase, - expr::{Expr, ExprId, LabelId, MatchGuard, Pat, PatId, Statement}, + expr::{Expr, ExprId, LabelId, Pat, PatId, Statement}, BlockId, DefWithBodyId, }; @@ -53,9 +53,9 @@ impl ExprScopes { fn new(body: &Body) -> ExprScopes { let mut scopes = ExprScopes { scopes: Arena::default(), scope_by_expr: FxHashMap::default() }; - let root = scopes.root_scope(); + let mut root = scopes.root_scope(); scopes.add_params_bindings(body, root, &body.params); - compute_expr_scopes(body.body_expr, body, &mut scopes, root); + compute_expr_scopes(body.body_expr, body, &mut scopes, &mut root); scopes } @@ -151,32 +151,32 @@ fn compute_block_scopes( match stmt { Statement::Let { pat, initializer, else_branch, .. } => { if let Some(expr) = initializer { - compute_expr_scopes(*expr, body, scopes, scope); + compute_expr_scopes(*expr, body, scopes, &mut scope); } if let Some(expr) = else_branch { - compute_expr_scopes(*expr, body, scopes, scope); + compute_expr_scopes(*expr, body, scopes, &mut scope); } scope = scopes.new_scope(scope); scopes.add_bindings(body, scope, *pat); } Statement::Expr { expr, .. } => { - compute_expr_scopes(*expr, body, scopes, scope); + compute_expr_scopes(*expr, body, scopes, &mut scope); } } } if let Some(expr) = tail { - compute_expr_scopes(expr, body, scopes, scope); + compute_expr_scopes(expr, body, scopes, &mut scope); } } -fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) { +fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: &mut ScopeId) { let make_label = |label: &Option| label.map(|label| (label, body.labels[label].name.clone())); - scopes.set_scope(expr, scope); + scopes.set_scope(expr, *scope); match &body[expr] { Expr::Block { statements, tail, id, label } => { - let scope = scopes.new_block_scope(scope, *id, make_label(label)); + let scope = scopes.new_block_scope(*scope, *id, make_label(label)); // Overwrite the old scope for the block expr, so that every block scope can be found // via the block itself (important for blocks that only contain items, no expressions). scopes.set_scope(expr, scope); @@ -184,46 +184,49 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope } Expr::For { iterable, pat, body: body_expr, label } => { compute_expr_scopes(*iterable, body, scopes, scope); - let scope = scopes.new_labeled_scope(scope, make_label(label)); + let mut scope = scopes.new_labeled_scope(*scope, make_label(label)); scopes.add_bindings(body, scope, *pat); - compute_expr_scopes(*body_expr, body, scopes, scope); + compute_expr_scopes(*body_expr, body, scopes, &mut scope); } Expr::While { condition, body: body_expr, label } => { - let scope = scopes.new_labeled_scope(scope, make_label(label)); - compute_expr_scopes(*condition, body, scopes, scope); - compute_expr_scopes(*body_expr, body, scopes, scope); + let mut scope = scopes.new_labeled_scope(*scope, make_label(label)); + compute_expr_scopes(*condition, body, scopes, &mut scope); + compute_expr_scopes(*body_expr, body, scopes, &mut scope); } Expr::Loop { body: body_expr, label } => { - let scope = scopes.new_labeled_scope(scope, make_label(label)); - compute_expr_scopes(*body_expr, body, scopes, scope); + let mut scope = scopes.new_labeled_scope(*scope, make_label(label)); + compute_expr_scopes(*body_expr, body, scopes, &mut scope); } Expr::Lambda { args, body: body_expr, .. } => { - let scope = scopes.new_scope(scope); + let mut scope = scopes.new_scope(*scope); scopes.add_params_bindings(body, scope, args); - compute_expr_scopes(*body_expr, body, scopes, scope); + compute_expr_scopes(*body_expr, body, scopes, &mut scope); } Expr::Match { expr, arms } => { compute_expr_scopes(*expr, body, scopes, scope); for arm in arms.iter() { - let mut scope = scopes.new_scope(scope); + let mut scope = scopes.new_scope(*scope); scopes.add_bindings(body, scope, arm.pat); - match arm.guard { - Some(MatchGuard::If { expr: guard }) => { - scopes.set_scope(guard, scope); - compute_expr_scopes(guard, body, scopes, scope); - } - Some(MatchGuard::IfLet { pat, expr: guard }) => { - scopes.set_scope(guard, scope); - compute_expr_scopes(guard, body, scopes, scope); - scope = scopes.new_scope(scope); - scopes.add_bindings(body, scope, pat); - } - _ => {} - }; - scopes.set_scope(arm.expr, scope); - compute_expr_scopes(arm.expr, body, scopes, scope); + if let Some(guard) = arm.guard { + scope = scopes.new_scope(scope); + compute_expr_scopes(guard, body, scopes, &mut scope); + } + compute_expr_scopes(arm.expr, body, scopes, &mut scope); } } + &Expr::If { condition, then_branch, else_branch } => { + let mut then_branch_scope = scopes.new_scope(*scope); + compute_expr_scopes(condition, body, scopes, &mut then_branch_scope); + compute_expr_scopes(then_branch, body, scopes, &mut then_branch_scope); + if let Some(else_branch) = else_branch { + compute_expr_scopes(else_branch, body, scopes, scope); + } + } + &Expr::Let { pat, expr } => { + compute_expr_scopes(expr, body, scopes, scope); + *scope = scopes.new_scope(*scope); + scopes.add_bindings(body, *scope, pat); + } e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)), }; } @@ -500,8 +503,7 @@ fn foo() { } #[test] - fn while_let_desugaring() { - cov_mark::check!(infer_resolve_while_let); + fn while_let_adds_binding() { do_check_local_name( r#" fn test() { @@ -513,5 +515,31 @@ fn test() { "#, 75, ); + do_check_local_name( + r#" +fn test() { + let foo: Option = None; + while (((let Option::Some(_) = foo))) && let Option::Some(spam) = foo { + spam$0 + } +} +"#, + 107, + ); + } + + #[test] + fn match_guard_if_let() { + do_check_local_name( + r#" +fn test() { + let foo: Option = None; + match foo { + _ if let Option::Some(spam) = foo => spam$0, + } +} +"#, + 93, + ); } } diff --git a/crates/hir_def/src/expr.rs b/crates/hir_def/src/expr.rs index 6534f970ee6b..4dca8238880d 100644 --- a/crates/hir_def/src/expr.rs +++ b/crates/hir_def/src/expr.rs @@ -59,6 +59,10 @@ pub enum Expr { then_branch: ExprId, else_branch: Option, }, + Let { + pat: PatId, + expr: ExprId, + }, Block { id: BlockId, statements: Box<[Statement]>, @@ -189,17 +193,10 @@ pub enum Array { #[derive(Debug, Clone, Eq, PartialEq)] pub struct MatchArm { pub pat: PatId, - pub guard: Option, + pub guard: Option, pub expr: ExprId, } -#[derive(Debug, Clone, Eq, PartialEq)] -pub enum MatchGuard { - If { expr: ExprId }, - - IfLet { pat: PatId, expr: ExprId }, -} - #[derive(Debug, Clone, Eq, PartialEq)] pub struct RecordLitField { pub name: Name, @@ -232,6 +229,9 @@ impl Expr { f(else_branch); } } + Expr::Let { expr, .. } => { + f(*expr); + } Expr::Block { statements, tail, .. } => { for stmt in statements.iter() { match stmt { diff --git a/crates/hir_def/src/macro_expansion_tests/mbe/tt_conversion.rs b/crates/hir_def/src/macro_expansion_tests/mbe/tt_conversion.rs index 5f4b7d6d0bca..84cc3f3872f2 100644 --- a/crates/hir_def/src/macro_expansion_tests/mbe/tt_conversion.rs +++ b/crates/hir_def/src/macro_expansion_tests/mbe/tt_conversion.rs @@ -108,18 +108,18 @@ fn expansion_does_not_parse_as_expression() { check( r#" macro_rules! stmts { - () => { let _ = 0; } + () => { fn foo() {} } } fn f() { let _ = stmts!/*+errors*/(); } "#, expect![[r#" macro_rules! stmts { - () => { let _ = 0; } + () => { fn foo() {} } } fn f() { let _ = /* parse error: expected expression */ -let _ = 0;; } +fn foo() {}; } "#]], ) }