From 8e73ea52537fe5189fac5cd02380592563fe7f0c Mon Sep 17 00:00:00 2001 From: hkalbasi Date: Sun, 19 Mar 2023 13:02:51 +0330 Subject: [PATCH] Desugar try blocks --- crates/hir-def/src/body/lower.rs | 87 +++++++++++++++++++++++---- crates/hir-def/src/body/pretty.rs | 3 - crates/hir-def/src/body/scope.rs | 3 +- crates/hir-def/src/expr.rs | 6 -- crates/hir-expand/src/name.rs | 14 ++++- crates/hir-ty/src/consteval/tests.rs | 54 +++++++++++++++++ crates/hir-ty/src/infer.rs | 4 -- crates/hir-ty/src/infer/expr.rs | 20 ------ crates/hir-ty/src/infer/mutability.rs | 1 - crates/hir-ty/src/mir/lower.rs | 87 ++++++++++++++++++--------- crates/hir-ty/src/tests/traits.rs | 20 ++++-- 11 files changed, 215 insertions(+), 84 deletions(-) diff --git a/crates/hir-def/src/body/lower.rs b/crates/hir-def/src/body/lower.rs index 448821d3844a..a93fcb3b1dc2 100644 --- a/crates/hir-def/src/body/lower.rs +++ b/crates/hir-def/src/body/lower.rs @@ -19,7 +19,7 @@ use rustc_hash::FxHashMap; use smallvec::SmallVec; use syntax::{ ast::{ - self, ArrayExprKind, AstChildren, HasArgList, HasLoopBody, HasName, LiteralKind, + self, ArrayExprKind, AstChildren, BlockExpr, HasArgList, HasLoopBody, HasName, LiteralKind, SlicePatComponents, }, AstNode, AstPtr, SyntaxNodePtr, @@ -100,6 +100,7 @@ pub(super) fn lower( _c: Count::new(), }, expander, + current_try_block: None, is_lowering_assignee_expr: false, is_lowering_generator: false, } @@ -113,6 +114,7 @@ struct ExprCollector<'a> { body: Body, krate: CrateId, source_map: BodySourceMap, + current_try_block: Option, is_lowering_assignee_expr: bool, is_lowering_generator: bool, } @@ -222,6 +224,10 @@ impl ExprCollector<'_> { self.source_map.label_map.insert(src, id); id } + // FIXME: desugared labels don't have ptr, that's wrong and should be fixed somehow. + fn alloc_label_desugared(&mut self, label: Label) -> LabelId { + self.body.labels.alloc(label) + } fn make_label(&mut self, label: Label, src: LabelSource) -> LabelId { let id = self.body.labels.alloc(label); self.source_map.label_map_back.insert(id, src); @@ -259,13 +265,7 @@ impl ExprCollector<'_> { self.alloc_expr(Expr::Let { pat, expr }, syntax_ptr) } ast::Expr::BlockExpr(e) => match e.modifier() { - Some(ast::BlockModifier::Try(_)) => { - self.collect_block_(e, |id, statements, tail| Expr::TryBlock { - id, - statements, - tail, - }) - } + Some(ast::BlockModifier::Try(_)) => self.collect_try_block(e), Some(ast::BlockModifier::Unsafe(_)) => { self.collect_block_(e, |id, statements, tail| Expr::Unsafe { id, @@ -606,6 +606,59 @@ impl ExprCollector<'_> { }) } + /// Desugar `try { ; }` into `': { ; ::std::ops::Try::from_output() }`, + /// `try { ; }` into `': { ; ::std::ops::Try::from_output(()) }` + /// and save the `` to use it as a break target for desugaring of the `?` operator. + fn collect_try_block(&mut self, e: BlockExpr) -> ExprId { + let Some(try_from_output) = LangItem::TryTraitFromOutput.path(self.db, self.krate) else { + return self.alloc_expr_desugared(Expr::Missing); + }; + let prev_try_block = self.current_try_block.take(); + self.current_try_block = + Some(self.alloc_label_desugared(Label { name: Name::generate_new_name() })); + let expr_id = self.collect_block(e); + let callee = self.alloc_expr_desugared(Expr::Path(try_from_output)); + let Expr::Block { label, tail, .. } = &mut self.body.exprs[expr_id] else { + unreachable!("It is the output of collect block"); + }; + *label = self.current_try_block; + let next_tail = match *tail { + Some(tail) => self.alloc_expr_desugared(Expr::Call { + callee, + args: Box::new([tail]), + is_assignee_expr: false, + }), + None => { + let unit = self.alloc_expr_desugared(Expr::Tuple { + exprs: Box::new([]), + is_assignee_expr: false, + }); + self.alloc_expr_desugared(Expr::Call { + callee, + args: Box::new([unit]), + is_assignee_expr: false, + }) + } + }; + let Expr::Block { tail, .. } = &mut self.body.exprs[expr_id] else { + unreachable!("It is the output of collect block"); + }; + *tail = Some(next_tail); + self.current_try_block = prev_try_block; + expr_id + } + + /// Desugar `ast::TryExpr` from: `?` into: + /// ```ignore (pseudo-rust) + /// match Try::branch() { + /// ControlFlow::Continue(val) => val, + /// ControlFlow::Break(residual) => + /// // If there is an enclosing `try {...}`: + /// break 'catch_target Try::from_residual(residual), + /// // Otherwise: + /// return Try::from_residual(residual), + /// } + /// ``` fn collect_try_operator(&mut self, syntax_ptr: AstPtr, e: ast::TryExpr) -> ExprId { let (try_branch, cf_continue, cf_break, try_from_residual) = 'if_chain: { if let Some(try_branch) = LangItem::TryTraitBranch.path(self.db, self.krate) { @@ -628,7 +681,9 @@ impl ExprCollector<'_> { Expr::Call { callee: try_branch, args: Box::new([operand]), is_assignee_expr: false }, syntax_ptr.clone(), ); - let continue_binding = self.alloc_binding(name![v1], BindingAnnotation::Unannotated); + let continue_name = Name::generate_new_name(); + let continue_binding = + self.alloc_binding(continue_name.clone(), BindingAnnotation::Unannotated); let continue_bpat = self.alloc_pat_desugared(Pat::Bind { id: continue_binding, subpat: None }); self.add_definition_to_binding(continue_binding, continue_bpat); @@ -639,9 +694,10 @@ impl ExprCollector<'_> { ellipsis: None, }), guard: None, - expr: self.alloc_expr(Expr::Path(Path::from(name![v1])), syntax_ptr.clone()), + expr: self.alloc_expr(Expr::Path(Path::from(continue_name)), syntax_ptr.clone()), }; - let break_binding = self.alloc_binding(name![v1], BindingAnnotation::Unannotated); + let break_name = Name::generate_new_name(); + let break_binding = self.alloc_binding(break_name.clone(), BindingAnnotation::Unannotated); let break_bpat = self.alloc_pat_desugared(Pat::Bind { id: break_binding, subpat: None }); self.add_definition_to_binding(break_binding, break_bpat); let break_arm = MatchArm { @@ -652,13 +708,18 @@ impl ExprCollector<'_> { }), guard: None, expr: { - let x = self.alloc_expr(Expr::Path(Path::from(name![v1])), syntax_ptr.clone()); + let x = self.alloc_expr(Expr::Path(Path::from(break_name)), syntax_ptr.clone()); let callee = self.alloc_expr(Expr::Path(try_from_residual), syntax_ptr.clone()); let result = self.alloc_expr( Expr::Call { callee, args: Box::new([x]), is_assignee_expr: false }, syntax_ptr.clone(), ); - self.alloc_expr(Expr::Return { expr: Some(result) }, syntax_ptr.clone()) + if let Some(label) = self.current_try_block { + let label = Some(self.body.labels[label].name.clone()); + self.alloc_expr(Expr::Break { expr: Some(result), label }, syntax_ptr.clone()) + } else { + self.alloc_expr(Expr::Return { expr: Some(result) }, syntax_ptr.clone()) + } }, }; let arms = Box::new([continue_arm, break_arm]); diff --git a/crates/hir-def/src/body/pretty.rs b/crates/hir-def/src/body/pretty.rs index c091ad0d150f..8c9d77620e18 100644 --- a/crates/hir-def/src/body/pretty.rs +++ b/crates/hir-def/src/body/pretty.rs @@ -420,9 +420,6 @@ impl<'a> Printer<'a> { Expr::Unsafe { id: _, statements, tail } => { self.print_block(Some("unsafe "), statements, tail); } - Expr::TryBlock { id: _, statements, tail } => { - self.print_block(Some("try "), statements, tail); - } Expr::Async { id: _, statements, tail } => { self.print_block(Some("async "), statements, tail); } diff --git a/crates/hir-def/src/body/scope.rs b/crates/hir-def/src/body/scope.rs index 12fc1f116d7d..8ddb89a4725d 100644 --- a/crates/hir-def/src/body/scope.rs +++ b/crates/hir-def/src/body/scope.rs @@ -202,8 +202,7 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope } Expr::Unsafe { id, statements, tail } | Expr::Async { id, statements, tail } - | Expr::Const { id, statements, tail } - | Expr::TryBlock { id, statements, tail } => { + | Expr::Const { id, statements, tail } => { let mut scope = scopes.new_block_scope(*scope, *id, None); // Overwrite the old scope for the block expr, so that every block scope can be found // via the block itself (important for blocks that only contain items, no expressions). diff --git a/crates/hir-def/src/expr.rs b/crates/hir-def/src/expr.rs index 8b1528f81e61..7ede19cc3c2b 100644 --- a/crates/hir-def/src/expr.rs +++ b/crates/hir-def/src/expr.rs @@ -122,11 +122,6 @@ pub enum Expr { tail: Option, label: Option, }, - TryBlock { - id: BlockId, - statements: Box<[Statement]>, - tail: Option, - }, Async { id: BlockId, statements: Box<[Statement]>, @@ -310,7 +305,6 @@ impl Expr { f(*expr); } Expr::Block { statements, tail, .. } - | Expr::TryBlock { statements, tail, .. } | Expr::Unsafe { statements, tail, .. } | Expr::Async { statements, tail, .. } | Expr::Const { statements, tail, .. } => { diff --git a/crates/hir-expand/src/name.rs b/crates/hir-expand/src/name.rs index 71eb35d9df8e..8099c20b027c 100644 --- a/crates/hir-expand/src/name.rs +++ b/crates/hir-expand/src/name.rs @@ -78,7 +78,7 @@ impl Name { Self::new_text(lt.text().into()) } - /// Shortcut to create inline plain text name + /// Shortcut to create inline plain text name. Panics if `text.len() > 22` const fn new_inline(text: &str) -> Name { Name::new_text(SmolStr::new_inline(text)) } @@ -112,6 +112,18 @@ impl Name { Name::new_inline("[missing name]") } + /// Generates a new name which is only equal to itself, by incrementing a counter. Due + /// its implementation, it should not be used in things that salsa considers, like + /// type names or field names, and it should be only used in names of local variables + /// and labels and similar things. + pub fn generate_new_name() -> Name { + use std::sync::atomic::{AtomicUsize, Ordering}; + static CNT: AtomicUsize = AtomicUsize::new(0); + let c = CNT.fetch_add(1, Ordering::Relaxed); + // FIXME: Currently a `__RA_generated_name` in user code will break our analysis + Name::new_text(format!("__RA_geneated_name_{c}").into()) + } + /// Returns the tuple index this name represents if it is a tuple field. pub fn as_tuple_index(&self) -> Option { match self.0 { diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index 97c8d62860c6..2ba0cbd5db4f 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -522,6 +522,42 @@ fn loops() { "#, 4, ); + check_number( + r#" + const GOAL: u8 = { + let mut x = 0; + loop { + x = x + 1; + if x == 5 { + break x + 2; + } + } + }; + "#, + 7, + ); + check_number( + r#" + const GOAL: u8 = { + 'a: loop { + let x = 'b: loop { + let x = 'c: loop { + let x = 'd: loop { + let x = 'e: loop { + break 'd 1; + }; + break 2 + x; + }; + break 3 + x; + }; + break 'a 4 + x; + }; + break 5 + x; + } + }; + "#, + 8, + ); } #[test] @@ -1019,6 +1055,24 @@ fn try_operator() { ); } +#[test] +fn try_block() { + check_number( + r#" + //- minicore: option, try + const fn g(x: Option, y: Option) -> i32 { + let r = try { x? * y? }; + match r { + Some(k) => k, + None => 5, + } + } + const GOAL: i32 = g(Some(10), Some(20)) + g(Some(30), None) + g(None, Some(40)) + g(None, None); + "#, + 215, + ); +} + #[test] fn or_pattern() { check_number( diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index d1b9aff36d24..38b7dee75fd5 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -1025,10 +1025,6 @@ impl<'a> InferenceContext<'a> { self.resolve_lang_item(lang)?.as_trait() } - fn resolve_ops_try_output(&self) -> Option { - self.resolve_output_on(self.resolve_lang_trait(LangItem::Try)?) - } - fn resolve_ops_neg_output(&self) -> Option { self.resolve_output_on(self.resolve_lang_trait(LangItem::Neg)?) } diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index 82119c97ec21..6d2aa59ea359 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -159,26 +159,6 @@ impl<'a> InferenceContext<'a> { }) .1 } - Expr::TryBlock { id: _, statements, tail } => { - // The type that is returned from the try block - let try_ty = self.table.new_type_var(); - if let Some(ty) = expected.only_has_type(&mut self.table) { - self.unify(&try_ty, &ty); - } - - // The ok-ish type that is expected from the last expression - let ok_ty = - self.resolve_associated_type(try_ty.clone(), self.resolve_ops_try_output()); - - self.infer_block( - tgt_expr, - statements, - *tail, - None, - &Expectation::has_type(ok_ty.clone()), - ); - try_ty - } Expr::Async { id: _, statements, tail } => { let ret_ty = self.table.new_type_var(); let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); diff --git a/crates/hir-ty/src/infer/mutability.rs b/crates/hir-ty/src/infer/mutability.rs index 784725da9350..7ed21d230c7f 100644 --- a/crates/hir-ty/src/infer/mutability.rs +++ b/crates/hir-ty/src/infer/mutability.rs @@ -44,7 +44,6 @@ impl<'a> InferenceContext<'a> { } Expr::Let { pat, expr } => self.infer_mut_expr(*expr, self.pat_bound_mutability(*pat)), Expr::Block { id: _, statements, tail, label: _ } - | Expr::TryBlock { id: _, statements, tail } | Expr::Async { id: _, statements, tail } | Expr::Const { id: _, statements, tail } | Expr::Unsafe { id: _, statements, tail } => { diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index 4b43e44a8ec3..5d9ae320726f 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -18,6 +18,7 @@ use hir_def::{ }; use hir_expand::name::Name; use la_arena::ArenaMap; +use rustc_hash::FxHashMap; use crate::{ consteval::ConstEvalError, db::HirDatabase, display::HirDisplay, infer::TypeMismatch, @@ -32,17 +33,21 @@ mod pattern_matching; use pattern_matching::AdtPatternShape; -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] struct LoopBlocks { begin: BasicBlockId, /// `None` for loops that are not terminating end: Option, + place: Place, } struct MirLowerCtx<'a> { result: MirBody, owner: DefWithBodyId, current_loop_blocks: Option, + // FIXME: we should resolve labels in HIR lowering and always work with label id here, not + // with raw names. + labeled_loop_blocks: FxHashMap, discr_temp: Option, db: &'a dyn HirDatabase, body: &'a Body, @@ -72,6 +77,7 @@ pub enum MirLowerError { ImplementationError(&'static str), LangItemNotFound(LangItem), MutatingRvalue, + UnresolvedLabel, } macro_rules! not_supported { @@ -375,19 +381,29 @@ impl MirLowerCtx<'_> { Ok(self.merge_blocks(Some(then_target), else_target)) } Expr::Unsafe { id: _, statements, tail } => { - self.lower_block_to_place(None, statements, current, *tail, place) + self.lower_block_to_place(statements, current, *tail, place) } Expr::Block { id: _, statements, tail, label } => { - self.lower_block_to_place(*label, statements, current, *tail, place) + if let Some(label) = label { + self.lower_loop(current, place.clone(), Some(*label), |this, begin| { + if let Some(block) = this.lower_block_to_place(statements, begin, *tail, place)? { + let end = this.current_loop_end()?; + this.set_goto(block, end); + } + Ok(()) + }) + } else { + self.lower_block_to_place(statements, current, *tail, place) + } } - Expr::Loop { body, label } => self.lower_loop(current, *label, |this, begin| { + Expr::Loop { body, label } => self.lower_loop(current, place, *label, |this, begin| { if let Some((_, block)) = this.lower_expr_as_place(begin, *body, true)? { this.set_goto(block, begin); } Ok(()) }), Expr::While { condition, body, label } => { - self.lower_loop(current, *label, |this, begin| { + self.lower_loop(current, place, *label, |this, begin| { let Some((discr, to_switch)) = this.lower_expr_to_some_operand(*condition, begin)? else { return Ok(()); }; @@ -438,7 +454,7 @@ impl MirLowerCtx<'_> { return Ok(None); }; self.push_assignment(current, ref_mut_iterator_place.clone(), Rvalue::Ref(BorrowKind::Mut { allow_two_phase_borrow: false }, iterator_place), expr_id.into()); - self.lower_loop(current, label, |this, begin| { + self.lower_loop(current, place, label, |this, begin| { let Some(current) = this.lower_call(iter_next_fn_op, vec![Operand::Copy(ref_mut_iterator_place)], option_item_place.clone(), begin, false)? else { return Ok(()); @@ -558,24 +574,28 @@ impl MirLowerCtx<'_> { Some(_) => not_supported!("continue with label"), None => { let loop_data = - self.current_loop_blocks.ok_or(MirLowerError::ContinueWithoutLoop)?; + self.current_loop_blocks.as_ref().ok_or(MirLowerError::ContinueWithoutLoop)?; self.set_goto(current, loop_data.begin); Ok(None) } }, Expr::Break { expr, label } => { - if expr.is_some() { - not_supported!("break with value"); - } - match label { - Some(_) => not_supported!("break with label"), - None => { - let end = - self.current_loop_end()?; - self.set_goto(current, end); - Ok(None) - } + if let Some(expr) = expr { + let loop_data = match label { + Some(l) => self.labeled_loop_blocks.get(l).ok_or(MirLowerError::UnresolvedLabel)?, + None => self.current_loop_blocks.as_ref().ok_or(MirLowerError::BreakWithoutLoop)?, + }; + let Some(c) = self.lower_expr_to_place(*expr, loop_data.place.clone(), current)? else { + return Ok(None); + }; + current = c; } + let end = match label { + Some(l) => self.labeled_loop_blocks.get(l).ok_or(MirLowerError::UnresolvedLabel)?.end.expect("We always generate end for labeled loops"), + None => self.current_loop_end()?, + }; + self.set_goto(current, end); + Ok(None) } Expr::Return { expr } => { if let Some(expr) = expr { @@ -668,7 +688,6 @@ impl MirLowerCtx<'_> { } Expr::Await { .. } => not_supported!("await"), Expr::Yeet { .. } => not_supported!("yeet"), - Expr::TryBlock { .. } => not_supported!("try block"), Expr::Async { .. } => not_supported!("async block"), Expr::Const { .. } => not_supported!("anonymous const block"), Expr::Cast { expr, type_ref: _ } => { @@ -1085,19 +1104,34 @@ impl MirLowerCtx<'_> { fn lower_loop( &mut self, prev_block: BasicBlockId, + place: Place, label: Option, f: impl FnOnce(&mut MirLowerCtx<'_>, BasicBlockId) -> Result<()>, ) -> Result> { - if label.is_some() { - not_supported!("loop with label"); - } let begin = self.new_basic_block(); - let prev = - mem::replace(&mut self.current_loop_blocks, Some(LoopBlocks { begin, end: None })); + let prev = mem::replace( + &mut self.current_loop_blocks, + Some(LoopBlocks { begin, end: None, place }), + ); + let prev_label = if let Some(label) = label { + // We should generate the end now, to make sure that it wouldn't change later. It is + // bad as we may emit end (unneccessary unreachable block) for unterminating loop, but + // it should not affect correctness. + self.current_loop_end()?; + self.labeled_loop_blocks.insert( + self.body.labels[label].name.clone(), + self.current_loop_blocks.as_ref().unwrap().clone(), + ) + } else { + None + }; self.set_goto(prev_block, begin); f(self, begin)?; let my = mem::replace(&mut self.current_loop_blocks, prev) .ok_or(MirLowerError::ImplementationError("current_loop_blocks is corrupt"))?; + if let Some(prev) = prev_label { + self.labeled_loop_blocks.insert(self.body.labels[label.unwrap()].name.clone(), prev); + } Ok(my.end) } @@ -1185,15 +1219,11 @@ impl MirLowerCtx<'_> { fn lower_block_to_place( &mut self, - label: Option, statements: &[hir_def::expr::Statement], mut current: BasicBlockId, tail: Option, place: Place, ) -> Result>> { - if label.is_some() { - not_supported!("block with label"); - } for statement in statements.iter() { match statement { hir_def::expr::Statement::Let { pat, initializer, else_branch, type_ref: _ } => { @@ -1355,6 +1385,7 @@ pub fn lower_to_mir( body, owner, current_loop_blocks: None, + labeled_loop_blocks: Default::default(), discr_temp: None, }; let mut current = start_block; diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index da76d7fd83f7..97ec1bb871d4 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -206,19 +206,27 @@ fn test() { fn infer_try_trait() { check_types( r#" -//- minicore: try, result +//- minicore: try, result, from fn test() { let r: Result = Result::Ok(1); let v = r?; v; } //^ i32 - -impl core::ops::Try for Result { - type Output = O; - type Error = Result; +"#, + ); } -impl> core::ops::FromResidual> for Result {} +#[test] +fn infer_try_block() { + // FIXME: We should test more cases, but it currently doesn't work, since + // our labeled block type inference is broken. + check_types( + r#" +//- minicore: try, option +fn test() { + let x: Option<_> = try { Some(2)?; }; + //^ Option<()> +} "#, ); }