diff --git a/compiler/rustc_ast_lowering/src/item.rs b/compiler/rustc_ast_lowering/src/item.rs index f0f3e2c3c746..852555048fe7 100644 --- a/compiler/rustc_ast_lowering/src/item.rs +++ b/compiler/rustc_ast_lowering/src/item.rs @@ -1,3 +1,5 @@ +use crate::FnReturnTransformation; + use super::errors::{InvalidAbi, InvalidAbiReason, InvalidAbiSuggestion, MisplacedRelaxTraitBound}; use super::ResolverAstLoweringExt; use super::{AstOwner, ImplTraitContext, ImplTraitPosition}; @@ -207,13 +209,33 @@ impl<'hir> LoweringContext<'_, 'hir> { // only cares about the input argument patterns in the function // declaration (decl), not the return types. let asyncness = header.asyncness; - let body_id = - this.lower_maybe_async_body(span, hir_id, decl, asyncness, body.as_deref()); + let genness = header.genness; + let body_id = this.lower_maybe_coroutine_body( + span, + hir_id, + decl, + asyncness, + genness, + body.as_deref(), + ); let itctx = ImplTraitContext::Universal; let (generics, decl) = this.lower_generics(generics, header.constness, id, &itctx, |this| { - let ret_id = asyncness.opt_return_id(); + let ret_id = asyncness + .opt_return_id() + .map(|(node_id, span)| { + crate::FnReturnTransformation::Async(node_id, span) + }) + .or_else(|| match genness { + Gen::Yes { span, closure_id: _, return_impl_trait_id } => { + Some(crate::FnReturnTransformation::Iterator( + return_impl_trait_id, + span, + )) + } + _ => None, + }); this.lower_fn_decl(decl, id, *fn_sig_span, FnDeclKind::Fn, ret_id) }); let sig = hir::FnSig { @@ -732,20 +754,31 @@ impl<'hir> LoweringContext<'_, 'hir> { sig, i.id, FnDeclKind::Trait, - asyncness.opt_return_id(), + asyncness + .opt_return_id() + .map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)), ); (generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Required(names)), false) } AssocItemKind::Fn(box Fn { sig, generics, body: Some(body), .. }) => { let asyncness = sig.header.asyncness; - let body_id = - self.lower_maybe_async_body(i.span, hir_id, &sig.decl, asyncness, Some(body)); + let genness = sig.header.genness; + let body_id = self.lower_maybe_coroutine_body( + i.span, + hir_id, + &sig.decl, + asyncness, + genness, + Some(body), + ); let (generics, sig) = self.lower_method_sig( generics, sig, i.id, FnDeclKind::Trait, - asyncness.opt_return_id(), + asyncness + .opt_return_id() + .map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)), ); (generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Provided(body_id)), true) } @@ -835,11 +868,13 @@ impl<'hir> LoweringContext<'_, 'hir> { ), AssocItemKind::Fn(box Fn { sig, generics, body, .. }) => { let asyncness = sig.header.asyncness; - let body_id = self.lower_maybe_async_body( + let genness = sig.header.genness; + let body_id = self.lower_maybe_coroutine_body( i.span, hir_id, &sig.decl, asyncness, + genness, body.as_deref(), ); let (generics, sig) = self.lower_method_sig( @@ -847,7 +882,9 @@ impl<'hir> LoweringContext<'_, 'hir> { sig, i.id, if self.is_in_trait_impl { FnDeclKind::Impl } else { FnDeclKind::Inherent }, - asyncness.opt_return_id(), + asyncness + .opt_return_id() + .map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)), ); (generics, hir::ImplItemKind::Fn(sig, body_id)) @@ -1011,16 +1048,22 @@ impl<'hir> LoweringContext<'_, 'hir> { }) } - fn lower_maybe_async_body( + /// Takes what may be the body of an `async fn` or a `gen fn` and wraps it in an `async {}` or + /// `gen {}` block as appropriate. + fn lower_maybe_coroutine_body( &mut self, span: Span, fn_id: hir::HirId, decl: &FnDecl, asyncness: Async, + genness: Gen, body: Option<&Block>, ) -> hir::BodyId { - let (closure_id, body) = match (asyncness, body) { - (Async::Yes { closure_id, .. }, Some(body)) => (closure_id, body), + let (closure_id, body) = match (asyncness, genness, body) { + // FIXME(eholk): do something reasonable for `async gen fn`. Probably that's an error + // for now since it's not supported. + (Async::Yes { closure_id, .. }, _, Some(body)) + | (_, Gen::Yes { closure_id, .. }, Some(body)) => (closure_id, body), _ => return self.lower_fn_body_block(span, decl, body), }; @@ -1163,44 +1206,55 @@ impl<'hir> LoweringContext<'_, 'hir> { parameters.push(new_parameter); } - let async_expr = this.make_async_expr( - CaptureBy::Value { move_kw: rustc_span::DUMMY_SP }, - closure_id, - None, - body.span, - hir::CoroutineSource::Fn, - |this| { - // Create a block from the user's function body: - let user_body = this.lower_block_expr(body); + let mkbody = |this: &mut LoweringContext<'_, 'hir>| { + // Create a block from the user's function body: + let user_body = this.lower_block_expr(body); - // Transform into `drop-temps { }`, an expression: - let desugared_span = - this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None); - let user_body = - this.expr_drop_temps(desugared_span, this.arena.alloc(user_body)); + // Transform into `drop-temps { }`, an expression: + let desugared_span = + this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None); + let user_body = this.expr_drop_temps(desugared_span, this.arena.alloc(user_body)); - // As noted above, create the final block like - // - // ``` - // { - // let $param_pattern = $raw_param; - // ... - // drop-temps { } - // } - // ``` - let body = this.block_all( - desugared_span, - this.arena.alloc_from_iter(statements), - Some(user_body), - ); + // As noted above, create the final block like + // + // ``` + // { + // let $param_pattern = $raw_param; + // ... + // drop-temps { } + // } + // ``` + let body = this.block_all( + desugared_span, + this.arena.alloc_from_iter(statements), + Some(user_body), + ); - this.expr_block(body) - }, - ); + this.expr_block(body) + }; + let coroutine_expr = match (asyncness, genness) { + (Async::Yes { .. }, _) => this.make_async_expr( + CaptureBy::Value { move_kw: rustc_span::DUMMY_SP }, + closure_id, + None, + body.span, + hir::CoroutineSource::Fn, + mkbody, + ), + (_, Gen::Yes { .. }) => this.make_gen_expr( + CaptureBy::Value { move_kw: rustc_span::DUMMY_SP }, + closure_id, + None, + body.span, + hir::CoroutineSource::Fn, + mkbody, + ), + _ => unreachable!("we must have either an async fn or a gen fn"), + }; let hir_id = this.lower_node_id(closure_id); this.maybe_forward_track_caller(body.span, fn_id, hir_id); - let expr = hir::Expr { hir_id, kind: async_expr, span: this.lower_span(body.span) }; + let expr = hir::Expr { hir_id, kind: coroutine_expr, span: this.lower_span(body.span) }; (this.arena.alloc_from_iter(parameters), expr) }) @@ -1212,13 +1266,13 @@ impl<'hir> LoweringContext<'_, 'hir> { sig: &FnSig, id: NodeId, kind: FnDeclKind, - is_async: Option<(NodeId, Span)>, + transform_return_type: Option, ) -> (&'hir hir::Generics<'hir>, hir::FnSig<'hir>) { let header = self.lower_fn_header(sig.header); let itctx = ImplTraitContext::Universal; let (generics, decl) = self.lower_generics(generics, sig.header.constness, id, &itctx, |this| { - this.lower_fn_decl(&sig.decl, id, sig.span, kind, is_async) + this.lower_fn_decl(&sig.decl, id, sig.span, kind, transform_return_type) }); (generics, hir::FnSig { header, decl, span: self.lower_span(sig.span) }) } diff --git a/compiler/rustc_ast_lowering/src/lib.rs b/compiler/rustc_ast_lowering/src/lib.rs index aa8ad9784513..96a413e9f73b 100644 --- a/compiler/rustc_ast_lowering/src/lib.rs +++ b/compiler/rustc_ast_lowering/src/lib.rs @@ -493,6 +493,21 @@ enum ParenthesizedGenericArgs { Err, } +/// Describes a return type transformation that can be performed by `LoweringContext::lower_fn_decl` +#[derive(Debug)] +enum FnReturnTransformation { + /// Replaces a return type `T` with `impl Future`. + /// + /// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the + /// `async` keyword. + Async(NodeId, Span), + /// Replaces a return type `T` with `impl Iterator`. + /// + /// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the + /// `gen` keyword. + Iterator(NodeId, Span), +} + impl<'a, 'hir> LoweringContext<'a, 'hir> { fn create_def( &mut self, @@ -1778,13 +1793,15 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { })) } - // Lowers a function declaration. - // - // `decl`: the unlowered (AST) function declaration. - // `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given `NodeId`. - // `make_ret_async`: if `Some`, converts `-> T` into `-> impl Future` in the - // return type. This is used for `async fn` declarations. The `NodeId` is the ID of the - // return type `impl Trait` item, and the `Span` points to the `async` keyword. + /// Lowers a function declaration. + /// + /// `decl`: the unlowered (AST) function declaration. + /// + /// `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given + /// `NodeId`. + /// + /// `transform_return_type`: if `Some`, applies some conversion to the return type, such as is + /// needed for `async fn` and `gen fn`. See [`FnReturnTransformation`] for more details. #[instrument(level = "debug", skip(self))] fn lower_fn_decl( &mut self, @@ -1792,7 +1809,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { fn_node_id: NodeId, fn_span: Span, kind: FnDeclKind, - make_ret_async: Option<(NodeId, Span)>, + transform_return_type: Option, ) -> &'hir hir::FnDecl<'hir> { let c_variadic = decl.c_variadic(); @@ -1821,11 +1838,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { self.lower_ty_direct(¶m.ty, &itctx) })); - let output = if let Some((ret_id, _span)) = make_ret_async { - let fn_def_id = self.local_def_id(fn_node_id); - self.lower_async_fn_ret_ty(&decl.output, fn_def_id, ret_id, kind, fn_span) - } else { - match &decl.output { + let output = match transform_return_type { + Some(transform) => { + let fn_def_id = self.local_def_id(fn_node_id); + self.lower_coroutine_fn_ret_ty(&decl.output, fn_def_id, transform, kind, fn_span) + } + None => match &decl.output { FnRetTy::Ty(ty) => { let context = if kind.return_impl_trait_allowed() { let fn_def_id = self.local_def_id(fn_node_id); @@ -1849,7 +1867,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { hir::FnRetTy::Return(self.lower_ty(ty, &context)) } FnRetTy::Default(span) => hir::FnRetTy::DefaultReturn(self.lower_span(*span)), - } + }, }; self.arena.alloc(hir::FnDecl { @@ -1888,17 +1906,22 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { // `fn_node_id`: `NodeId` of the parent function (used to create child impl trait definition) // `opaque_ty_node_id`: `NodeId` of the opaque `impl Trait` type that should be created #[instrument(level = "debug", skip(self))] - fn lower_async_fn_ret_ty( + fn lower_coroutine_fn_ret_ty( &mut self, output: &FnRetTy, fn_def_id: LocalDefId, - opaque_ty_node_id: NodeId, + transform: FnReturnTransformation, fn_kind: FnDeclKind, fn_span: Span, ) -> hir::FnRetTy<'hir> { let span = self.lower_span(fn_span); let opaque_ty_span = self.mark_span_with_reason(DesugaringKind::Async, span, None); + let opaque_ty_node_id = match transform { + FnReturnTransformation::Async(opaque_ty_node_id, _) + | FnReturnTransformation::Iterator(opaque_ty_node_id, _) => opaque_ty_node_id, + }; + let captured_lifetimes: Vec<_> = self .resolver .take_extra_lifetime_params(opaque_ty_node_id) @@ -1914,8 +1937,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { span, opaque_ty_span, |this| { - let future_bound = this.lower_async_fn_output_type_to_future_bound( + let future_bound = this.lower_coroutine_fn_output_type_to_future_bound( output, + transform, span, ImplTraitContext::ReturnPositionOpaqueTy { origin: hir::OpaqueTyOrigin::FnReturn(fn_def_id), @@ -1931,9 +1955,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { } /// Transforms `-> T` into `Future`. - fn lower_async_fn_output_type_to_future_bound( + fn lower_coroutine_fn_output_type_to_future_bound( &mut self, output: &FnRetTy, + transform: FnReturnTransformation, span: Span, nested_impl_trait_context: ImplTraitContext, ) -> hir::GenericBound<'hir> { @@ -1948,17 +1973,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { FnRetTy::Default(ret_ty_span) => self.arena.alloc(self.ty_tup(*ret_ty_span, &[])), }; - // "" + // "" + let (symbol, lang_item) = match transform { + FnReturnTransformation::Async(..) => (hir::FN_OUTPUT_NAME, hir::LangItem::Future), + FnReturnTransformation::Iterator(..) => { + (hir::ITERATOR_ITEM_NAME, hir::LangItem::Iterator) + } + }; + let future_args = self.arena.alloc(hir::GenericArgs { args: &[], - bindings: arena_vec![self; self.output_ty_binding(span, output_ty)], + bindings: arena_vec![self; self.assoc_ty_binding(symbol, span, output_ty)], parenthesized: hir::GenericArgsParentheses::No, span_ext: DUMMY_SP, }); hir::GenericBound::LangItemTrait( - // ::std::future::Future - hir::LangItem::Future, + lang_item, self.lower_span(span), self.next_id(), future_args, diff --git a/compiler/rustc_ast_lowering/src/path.rs b/compiler/rustc_ast_lowering/src/path.rs index db8ca7c3643f..accb74d7a522 100644 --- a/compiler/rustc_ast_lowering/src/path.rs +++ b/compiler/rustc_ast_lowering/src/path.rs @@ -389,7 +389,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { FnRetTy::Default(_) => self.arena.alloc(self.ty_tup(*span, &[])), }; let args = smallvec![GenericArg::Type(self.arena.alloc(self.ty_tup(*inputs_span, inputs)))]; - let binding = self.output_ty_binding(output_ty.span, output_ty); + let binding = self.assoc_ty_binding(hir::FN_OUTPUT_NAME, output_ty.span, output_ty); ( GenericArgsCtor { args, @@ -401,13 +401,14 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { ) } - /// An associated type binding `Output = $ty`. - pub(crate) fn output_ty_binding( + /// An associated type binding `$symbol = $ty`. + pub(crate) fn assoc_ty_binding( &mut self, + symbol: rustc_span::Symbol, span: Span, ty: &'hir hir::Ty<'hir>, ) -> hir::TypeBinding<'hir> { - let ident = Ident::with_dummy_span(hir::FN_OUTPUT_NAME); + let ident = Ident::with_dummy_span(symbol); let kind = hir::TypeBindingKind::Equality { term: ty.into() }; let args = arena_vec![self;]; let bindings = arena_vec![self;]; diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs index 81733d8f64e2..0d9e174e37aa 100644 --- a/compiler/rustc_hir/src/hir.rs +++ b/compiler/rustc_hir/src/hir.rs @@ -2255,6 +2255,8 @@ pub enum ImplItemKind<'hir> { /// The name of the associated type for `Fn` return types. pub const FN_OUTPUT_NAME: Symbol = sym::Output; +/// The name of the associated type for `Iterator` item types. +pub const ITERATOR_ITEM_NAME: Symbol = sym::Item; /// Bind a type to an associated type (i.e., `A = Foo`). /// diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs index cb36d510149f..f0bb18df48c9 100644 --- a/compiler/rustc_hir_typeck/src/closure.rs +++ b/compiler/rustc_hir_typeck/src/closure.rs @@ -651,9 +651,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { }, ) } - Some(hir::CoroutineKind::Gen(hir::CoroutineSource::Fn)) => { - todo!("gen closures do not exist yet") - } + // For a `gen {}` block created as a `gen fn` body, we need the return type to be + // (). + Some(hir::CoroutineKind::Gen(hir::CoroutineSource::Fn)) => self.tcx.types.unit, _ => astconv.ty_infer(None, decl.output.span()), }, diff --git a/compiler/rustc_parse/src/parser/item.rs b/compiler/rustc_parse/src/parser/item.rs index 55f7310681fa..4f01ab02c044 100644 --- a/compiler/rustc_parse/src/parser/item.rs +++ b/compiler/rustc_parse/src/parser/item.rs @@ -2410,10 +2410,6 @@ impl<'a> Parser<'a> { } } - if let Gen::Yes { span, .. } = genness { - self.sess.emit_err(errors::GenFn { span }); - } - if !self.eat_keyword_case(kw::Fn, case) { // It is possible for `expect_one_of` to recover given the contents of // `self.expected_tokens`, therefore, do not use `self.unexpected()` which doesn't diff --git a/compiler/rustc_resolve/src/def_collector.rs b/compiler/rustc_resolve/src/def_collector.rs index 647c92785e10..306492eaa96c 100644 --- a/compiler/rustc_resolve/src/def_collector.rs +++ b/compiler/rustc_resolve/src/def_collector.rs @@ -156,7 +156,10 @@ impl<'a, 'b, 'tcx> visit::Visitor<'a> for DefCollector<'a, 'b, 'tcx> { fn visit_fn(&mut self, fn_kind: FnKind<'a>, span: Span, _: NodeId) { if let FnKind::Fn(_, _, sig, _, generics, body) = fn_kind { - if let Async::Yes { closure_id, .. } = sig.header.asyncness { + // FIXME(eholk): handle `async gen fn` + if let (Async::Yes { closure_id, .. }, _) | (_, Gen::Yes { closure_id, .. }) = + (sig.header.asyncness, sig.header.genness) + { self.visit_generics(generics); // For async functions, we need to create their inner defs inside of a