From 075fd364d06d89f5cd79308938e3ed4a5893ec84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?John=20K=C3=A5re=20Alsaker?= Date: Mon, 10 Jul 2017 20:04:15 +0200 Subject: [PATCH] Ensure upvars are dropped when generators have never been resumed --- src/librustc_mir/transform/generator.rs | 158 +++++++++++++----- .../generator/no-arguments-on-generators.rs | 2 +- src/test/run-pass/generator/drop-env.rs | 6 +- 3 files changed, 118 insertions(+), 48 deletions(-) diff --git a/src/librustc_mir/transform/generator.rs b/src/librustc_mir/transform/generator.rs index 5e33eb95d403..a2c1cf415974 100644 --- a/src/librustc_mir/transform/generator.rs +++ b/src/librustc_mir/transform/generator.rs @@ -83,7 +83,7 @@ impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor { lvalue: &mut Lvalue<'tcx>, context: LvalueContext<'tcx>, location: Location) { - if *lvalue == Lvalue::Local(Local::new(2)) { + if *lvalue == Lvalue::Local(Local::new(1)) { *lvalue = Lvalue::Projection(Box::new(Projection { base: lvalue.clone(), elem: ProjectionElem::Deref, @@ -114,10 +114,6 @@ impl<'a, 'tcx> TransformVisitor<'a, 'tcx> { fn make_field(&self, idx: usize, ty: Ty<'tcx>) -> Lvalue<'tcx> { let base = Lvalue::Local(Local::new(1)); - let base = Lvalue::Projection(Box::new(Projection { - base: base, - elem: ProjectionElem::Deref, - })); let field = Projection { base: base, elem: ProjectionElem::Field(Field::new(idx), ty), @@ -258,6 +254,23 @@ fn ensure_generator_state_argument<'a, 'tcx>( let gen_ty = mir.local_decls.raw[2].ty; + // Swap generator and implicit argument + SwapLocalVisitor { + a: Local::new(1), + b: Local::new(2), + }.visit_mir(mir); + + mir.local_decls.raw[..].swap(1, 2); + + (gen_ty, interior) +} + +fn make_generator_state_argument_indirect<'a, 'tcx>( + tcx: TyCtxt<'a, 'tcx, 'tcx>, + def_id: DefId, + mir: &mut Mir<'tcx>) { + let gen_ty = mir.local_decls.raw[1].ty; + let region = ty::ReFree(ty::FreeRegion { scope: def_id, bound_region: ty::BoundRegion::BrEnv, @@ -271,20 +284,10 @@ fn ensure_generator_state_argument<'a, 'tcx>( }); // Replace the by value generator argument - mir.local_decls.raw[2].ty = ref_gen_ty; + mir.local_decls.raw[1].ty = ref_gen_ty; - // Add a deref to accesses of the generator state for upvars + // Add a deref to accesses of the generator state DerefArgVisitor.visit_mir(mir); - - // Swap generator and implicit argument - SwapLocalVisitor { - a: Local::new(1), - b: Local::new(2), - }.visit_mir(mir); - - mir.local_decls.raw[..].swap(1, 2); - - (gen_ty, interior) } fn replace_result_variable<'tcx>(ret_ty: Ty<'tcx>, @@ -412,7 +415,7 @@ fn elaborate_generator_drops<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>, use shim::DropShimElaborator; let param_env = tcx.param_env(def_id); - let gen = Local::new(2); + let gen = Local::new(1); for block in mir.basic_blocks().indices() { let (target, unwind, source_info) = match mir.basic_blocks()[block].terminator() { @@ -460,12 +463,14 @@ fn generate_drop<'a, 'tcx>( def_id: DefId, source: MirSource, gen_ty: Ty<'tcx>, - mir: &mut Mir<'tcx>) { + mir: &mut Mir<'tcx>, + drop_clean: BasicBlock) { let source_info = SourceInfo { span: mir.span, scope: ARGUMENT_VISIBILITY_SCOPE, }; + let return_block = BasicBlock::new(mir.basic_blocks().len()); mir.basic_blocks_mut().push(BasicBlockData { statements: Vec::new(), terminator: Some(Terminator { @@ -475,10 +480,12 @@ fn generate_drop<'a, 'tcx>( is_cleanup: false, }); - let cases: Vec<_> = transform.bb_targets.iter().filter_map(|(&(r, u), &s)| { + let mut cases: Vec<_> = transform.bb_targets.iter().filter_map(|(&(r, u), &s)| { u.map(|d| (s, d)) }).collect(); + cases.insert(0, (0, drop_clean)); + // The poisoned state 1 falls through to the default case which is just to return let switch = TerminatorKind::SwitchInt { @@ -487,7 +494,7 @@ fn generate_drop<'a, 'tcx>( values: Cow::from(cases.iter().map(|&(i, _)| { ConstInt::U32(i) }).collect::>()), - targets: cases.iter().map(|&(_, d)| d).chain(once(transform.return_block)).collect(), + targets: cases.iter().map(|&(_, d)| d).chain(once(return_block)).collect(), }; insert_entry_point(mir, BasicBlockData { @@ -525,6 +532,8 @@ fn generate_drop<'a, 'tcx>( is_user_variable: false, }; + make_generator_state_argument_indirect(tcx, def_id, mir); + // Change the generator argument from &mut to *mut mir.local_decls[Local::new(1)] = LocalDecl { mutability: Mutability::Mut, @@ -544,21 +553,9 @@ fn generate_drop<'a, 'tcx>( dump_mir(tcx, None, "generator_drop", &0, source, mir); } -fn generate_resume<'a, 'tcx>( - tcx: TyCtxt<'a, 'tcx, 'tcx>, - mut transform: TransformVisitor<'a, 'tcx>, - node_id: NodeId, - def_id: DefId, - source: MirSource, - mir: &mut Mir<'tcx>) { - // Poison the generator when it unwinds - for block in mir.basic_blocks_mut() { - let source_info = block.terminator().source_info; - if let &TerminatorKind::Resume = &block.terminator().kind { - block.statements.push(transform.set_state(1, source_info)); - } - } - +fn insert_resume_after_return<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>, + def_id: DefId, + mir: &mut Mir<'tcx>) -> Option { let drop_arg = mir.local_decls.raw[2].ty.needs_drop(tcx, tcx.param_env(def_id)); let cleanup = if drop_arg { @@ -567,6 +564,7 @@ fn generate_resume<'a, 'tcx>( None }; + let assert_block = BasicBlock::new(mir.basic_blocks().len()); let term = TerminatorKind::Assert { cond: Operand::Constant(box Constant { span: mir.span, @@ -577,7 +575,7 @@ fn generate_resume<'a, 'tcx>( }), expected: true, msg: AssertMessage::GeneratorResumedAfterReturn, - target: transform.return_block, + target: assert_block, cleanup: cleanup, }; @@ -623,6 +621,30 @@ fn generate_resume<'a, 'tcx>( }); } + cleanup +} + +fn generate_resume<'a, 'tcx>( + tcx: TyCtxt<'a, 'tcx, 'tcx>, + mut transform: TransformVisitor<'a, 'tcx>, + node_id: NodeId, + def_id: DefId, + source: MirSource, + cleanup: Option, + mir: &mut Mir<'tcx>) { + // Poison the generator when it unwinds + for block in mir.basic_blocks_mut() { + let source_info = block.terminator().source_info; + if let &TerminatorKind::Resume = &block.terminator().kind { + block.statements.push(transform.set_state(1, source_info)); + } + } + + let source_info = SourceInfo { + span: mir.span, + scope: ARGUMENT_VISIBILITY_SCOPE, + }; + let poisoned_block = BasicBlock::new(mir.basic_blocks().len()); let term = TerminatorKind::Assert { @@ -671,6 +693,8 @@ fn generate_resume<'a, 'tcx>( is_cleanup: false, }); + make_generator_state_argument_indirect(tcx, def_id, mir); + // Make sure we remove dead blocks to remove // unrelated code from the drop part of the function simplify::remove_dead_blocks(mir); @@ -678,6 +702,41 @@ fn generate_resume<'a, 'tcx>( dump_mir(tcx, None, "generator_resume", &0, source, mir); } +fn insert_clean_drop<'a, 'tcx>(mir: &mut Mir<'tcx>) -> (BasicBlock, BasicBlock) { + let source_info = SourceInfo { + span: mir.span, + scope: ARGUMENT_VISIBILITY_SCOPE, + }; + + let return_block = BasicBlock::new(mir.basic_blocks().len()); + mir.basic_blocks_mut().push(BasicBlockData { + statements: Vec::new(), + terminator: Some(Terminator { + source_info, + kind: TerminatorKind::Return, + }), + is_cleanup: false, + }); + + // Create a block to destroy an unresumed generators. This can only destroy upvars. + let drop_clean = BasicBlock::new(mir.basic_blocks().len()); + let term = TerminatorKind::Drop { + location: Lvalue::Local(Local::new(1)), + target: return_block, + unwind: None, + }; + mir.basic_blocks_mut().push(BasicBlockData { + statements: Vec::new(), + terminator: Some(Terminator { + source_info, + kind: term, + }), + is_cleanup: false, + }); + + (return_block, drop_clean) +} + impl MirPass for StateTransform { fn run_pass<'a, 'tcx>(&self, tcx: TyCtxt<'a, 'tcx, 'tcx>, @@ -695,8 +754,6 @@ impl MirPass for StateTransform { let node_id = source.item_id(); let def_id = tcx.hir.local_def_id(source.item_id()); - elaborate_generator_drops(tcx, def_id, mir); - let (gen_ty, interior) = ensure_generator_state_argument(tcx, node_id, def_id, mir); let state_did = tcx.lang_items.gen_state().unwrap(); @@ -709,7 +766,7 @@ impl MirPass for StateTransform { let (remap, layout) = compute_layout(tcx, def_id, source, interior, mir); - let return_block = BasicBlock::new(mir.basic_blocks().len()); + let tail_block = BasicBlock::new(mir.basic_blocks().len()); let state_field = mir.upvar_decls.len(); @@ -724,7 +781,7 @@ impl MirPass for StateTransform { bb_target_count: 2, bb_targets, new_ret_local, - return_block, + return_block: tail_block, state_field, }; transform.visit_mir(mir); @@ -735,14 +792,29 @@ impl MirPass for StateTransform { mir.spread_arg = None; mir.generator_layout = Some(layout); + let arg_cleanup = insert_resume_after_return(tcx, def_id, mir); + + let (_return_block, drop_clean) = insert_clean_drop(mir); + + dump_mir(tcx, None, "generator_pre-elab", &0, source, mir); + + elaborate_generator_drops(tcx, def_id, mir); + dump_mir(tcx, None, "generator_post-transform", &0, source, mir); let mut drop_impl = mir.clone(); - generate_drop(tcx, &transform, node_id, def_id, source, gen_ty, &mut drop_impl); + generate_drop(tcx, + &transform, + node_id, + def_id, + source, + gen_ty, + &mut drop_impl, + drop_clean); mir.generator_drop = Some(box drop_impl); - generate_resume(tcx, transform, node_id, def_id, source, mir); + generate_resume(tcx, transform, node_id, def_id, source, arg_cleanup, mir); } } diff --git a/src/test/compile-fail/generator/no-arguments-on-generators.rs b/src/test/compile-fail/generator/no-arguments-on-generators.rs index 9c68fe4ceda8..0f8d29e5cabb 100644 --- a/src/test/compile-fail/generator/no-arguments-on-generators.rs +++ b/src/test/compile-fail/generator/no-arguments-on-generators.rs @@ -13,5 +13,5 @@ fn main() { let gen = |start| { //~ ERROR generators cannot have explicit arguments yield; - }; + }; } \ No newline at end of file diff --git a/src/test/run-pass/generator/drop-env.rs b/src/test/run-pass/generator/drop-env.rs index 43d0af7a0734..1b0df6f87236 100644 --- a/src/test/run-pass/generator/drop-env.rs +++ b/src/test/run-pass/generator/drop-env.rs @@ -58,7 +58,7 @@ fn t2() { fn t3() { let b = B; - let mut foo = || { + let foo = || { let _: () = gen arg; // TODO: this line should not be necessary yield; drop(b); @@ -67,7 +67,5 @@ fn t3() { let n = A.load(Ordering::SeqCst); assert_eq!(A.load(Ordering::SeqCst), n); drop(foo); - // TODO: we should assert n+1 here, not n - // assert_eq!(A.load(Ordering::SeqCst), n + 1); - assert_eq!(A.load(Ordering::SeqCst), n); + assert_eq!(A.load(Ordering::SeqCst), n + 1); }