diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/display.rs b/src/tools/rust-analyzer/crates/hir-ty/src/display.rs index c749a3d24a25..dd1b212d4c29 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/display.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/display.rs @@ -38,7 +38,8 @@ use rustc_apfloat::{ use rustc_ast_ir::FloatTy; use rustc_hash::FxHashSet; use rustc_type_ir::{ - AliasTyKind, BoundVarIndexKind, CoroutineArgsParts, RegionKind, Upcast, + AliasTyKind, BoundVarIndexKind, CoroutineArgsParts, CoroutineClosureArgsParts, RegionKind, + Upcast, inherent::{AdtDef, GenericArgs as _, IntoKind, SliceLike, Term as _, Ty as _, Tys as _}, }; use smallvec::SmallVec; @@ -1444,14 +1445,83 @@ impl<'db> HirDisplay<'db> for Ty<'db> { } if f.closure_style == ClosureStyle::RANotation || !sig.output().is_unit() { write!(f, " -> ")?; - // FIXME: We display `AsyncFn` as `-> impl Future`, but this is hard to fix because - // we don't have a trait environment here, required to normalize `::Output`. sig.output().hir_fmt(f)?; } } else { write!(f, "{{closure}}")?; } } + TyKind::CoroutineClosure(id, args) => { + let id = id.0; + if f.display_kind.is_source_code() { + if !f.display_kind.allows_opaque() { + return Err(HirDisplayError::DisplaySourceCodeError( + DisplaySourceCodeError::OpaqueType, + )); + } else if f.closure_style != ClosureStyle::ImplFn { + never!("Only `impl Fn` is valid for displaying closures in source code"); + } + } + match f.closure_style { + ClosureStyle::Hide => return write!(f, "{TYPE_HINT_TRUNCATION}"), + ClosureStyle::ClosureWithId => { + return write!( + f, + "{{async closure#{:?}}}", + salsa::plumbing::AsId::as_id(&id).index() + ); + } + ClosureStyle::ClosureWithSubst => { + write!( + f, + "{{async closure#{:?}}}", + salsa::plumbing::AsId::as_id(&id).index() + )?; + return hir_fmt_generics(f, args.as_slice(), None, None); + } + _ => (), + } + let CoroutineClosureArgsParts { closure_kind_ty, signature_parts_ty, .. } = + args.split_coroutine_closure_args(); + let kind = closure_kind_ty.to_opt_closure_kind().unwrap(); + let kind = match kind { + rustc_type_ir::ClosureKind::Fn => "AsyncFn", + rustc_type_ir::ClosureKind::FnMut => "AsyncFnMut", + rustc_type_ir::ClosureKind::FnOnce => "AsyncFnOnce", + }; + let TyKind::FnPtr(coroutine_sig, _) = signature_parts_ty.kind() else { + unreachable!("invalid coroutine closure signature"); + }; + let coroutine_sig = coroutine_sig.skip_binder(); + let coroutine_inputs = coroutine_sig.inputs(); + let TyKind::Tuple(coroutine_inputs) = coroutine_inputs.as_slice()[1].kind() else { + unreachable!("invalid coroutine closure signature"); + }; + let TyKind::Tuple(coroutine_output) = coroutine_sig.output().kind() else { + unreachable!("invalid coroutine closure signature"); + }; + let coroutine_output = coroutine_output.as_slice()[1]; + match f.closure_style { + ClosureStyle::ImplFn => write!(f, "impl {kind}(")?, + ClosureStyle::RANotation => write!(f, "async |")?, + _ => unreachable!(), + } + if coroutine_inputs.is_empty() { + } else if f.should_truncate() { + write!(f, "{TYPE_HINT_TRUNCATION}")?; + } else { + f.write_joined(coroutine_inputs, ", ")?; + }; + match f.closure_style { + ClosureStyle::ImplFn => write!(f, ")")?, + ClosureStyle::RANotation => write!(f, "|")?, + _ => unreachable!(), + } + if f.closure_style == ClosureStyle::RANotation || !coroutine_output.is_unit() { + write!(f, " -> ")?; + coroutine_output.hir_fmt(f)?; + } + } TyKind::Placeholder(_) => write!(f, "{{placeholder}}")?, TyKind::Param(param) => { // FIXME: We should not access `param.id`, it should be removed, and we should know the @@ -1545,8 +1615,13 @@ impl<'db> HirDisplay<'db> for Ty<'db> { let CoroutineArgsParts { resume_ty, yield_ty, return_ty, .. } = subst.split_coroutine_args(); let body = db.body(owner); - match &body[expr_id] { - hir_def::hir::Expr::Async { .. } => { + let expr = &body[expr_id]; + match expr { + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::Async, + .. + } + | hir_def::hir::Expr::Async { .. } => { let future_trait = LangItem::Future.resolve_trait(db, owner.module(db).krate()); let output = future_trait.and_then(|t| { @@ -1573,7 +1648,10 @@ impl<'db> HirDisplay<'db> for Ty<'db> { return_ty.hir_fmt(f)?; write!(f, ">")?; } - _ => { + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::Coroutine(..), + .. + } => { if f.display_kind.is_source_code() { return Err(HirDisplayError::DisplaySourceCodeError( DisplaySourceCodeError::Coroutine, @@ -1589,12 +1667,12 @@ impl<'db> HirDisplay<'db> for Ty<'db> { write!(f, " -> ")?; return_ty.hir_fmt(f)?; } + _ => panic!("invalid expr for coroutine: {expr:?}"), } } TyKind::CoroutineWitness(..) => write!(f, "{{coroutine witness}}")?, TyKind::Pat(_, _) => write!(f, "{{pat}}")?, TyKind::UnsafeBinder(_) => write!(f, "{{unsafe binder}}")?, - TyKind::CoroutineClosure(_, _) => write!(f, "{{coroutine closure}}")?, TyKind::Alias(_, _) => write!(f, "{{alias}}")?, } Ok(()) diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs index 3dc277023a32..06f8307eb0ab 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs @@ -11,8 +11,9 @@ use hir_def::{ type_ref::TypeRefId, }; use rustc_type_ir::{ - ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, Interner, TypeSuperVisitable, - TypeVisitable, TypeVisitableExt, TypeVisitor, + ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, CoroutineClosureArgs, + CoroutineClosureArgsParts, Interner, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, + TypeVisitor, inherent::{BoundExistentialPredicates, GenericArgs as _, IntoKind, SliceLike, Ty as _}, }; use tracing::debug; @@ -22,8 +23,9 @@ use crate::{ db::{InternedClosure, InternedCoroutine}, infer::{BreakableKind, Diverges, coerce::CoerceMany}, next_solver::{ - AliasTy, Binder, ClauseKind, DbInterner, ErrorGuaranteed, FnSig, GenericArgs, PolyFnSig, - PolyProjectionPredicate, Predicate, PredicateKind, SolverDefId, Ty, TyKind, + AliasTy, Binder, BoundRegionKind, BoundVarKind, BoundVarKinds, ClauseKind, DbInterner, + ErrorGuaranteed, FnSig, GenericArgs, PolyFnSig, PolyProjectionPredicate, Predicate, + PredicateKind, SolverDefId, Ty, TyKind, abi::Safety, infer::{ BoundRegionConversionTime, InferOk, InferResult, @@ -72,6 +74,8 @@ impl<'db> InferenceContext<'_, 'db> { let sig_ty = Ty::new_fn_ptr(interner, bound_sig); let parent_args = GenericArgs::identity_for_item(interner, self.generic_def.into()); + // FIXME: Make this an infer var and infer it later. + let tupled_upvars_ty = self.types.unit; let (id, ty, resume_yield_tys) = match closure_kind { ClosureKind::Coroutine(_) => { let yield_ty = self.table.next_ty_var(); @@ -80,11 +84,11 @@ impl<'db> InferenceContext<'_, 'db> { // FIXME: Infer the upvars later. let parts = CoroutineArgsParts { parent_args, - kind_ty: Ty::new_unit(interner), + kind_ty: self.types.unit, resume_ty, yield_ty, return_ty: body_ret_ty, - tupled_upvars_ty: Ty::new_unit(interner), + tupled_upvars_ty, }; let coroutine_id = @@ -97,9 +101,7 @@ impl<'db> InferenceContext<'_, 'db> { (None, coroutine_ty, Some((resume_ty, yield_ty))) } - // FIXME(next-solver): `ClosureKind::Async` should really be a separate arm that creates a `CoroutineClosure`. - // But for now we treat it as a closure. - ClosureKind::Closure | ClosureKind::Async => { + ClosureKind::Closure => { let closure_id = self.db.intern_closure(InternedClosure(self.owner, tgt_expr)); match expected_kind { Some(kind) => { @@ -117,7 +119,7 @@ impl<'db> InferenceContext<'_, 'db> { } None => {} }; - // FIXME: Infer the kind and the upvars later when needed. + // FIXME: Infer the kind later if needed. let parts = ClosureArgsParts { parent_args, closure_kind_ty: Ty::from_closure_kind( @@ -125,7 +127,7 @@ impl<'db> InferenceContext<'_, 'db> { expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn), ), closure_sig_as_fn_ptr_ty: sig_ty, - tupled_upvars_ty: Ty::new_unit(interner), + tupled_upvars_ty, }; let closure_ty = Ty::new_closure( interner, @@ -136,6 +138,61 @@ impl<'db> InferenceContext<'_, 'db> { self.add_current_closure_dependency(closure_id); (Some(closure_id), closure_ty, None) } + ClosureKind::Async => { + // async closures always return the type ascribed after the `->` (if present), + // and yield `()`. + let bound_return_ty = bound_sig.skip_binder().output(); + let bound_yield_ty = self.types.unit; + // rustc uses a special lang item type for the resume ty. I don't believe this can cause us problems. + let resume_ty = self.types.unit; + + // FIXME: Infer the kind later if needed. + let closure_kind_ty = Ty::from_closure_kind( + interner, + expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn), + ); + + // FIXME: Infer captures later. + // `for<'env> fn() -> ()`, for no captures. + let coroutine_captures_by_ref_ty = Ty::new_fn_ptr( + interner, + Binder::bind_with_vars( + interner.mk_fn_sig([], self.types.unit, false, Safety::Safe, FnAbi::Rust), + BoundVarKinds::new_from_iter( + interner, + [BoundVarKind::Region(BoundRegionKind::ClosureEnv)], + ), + ), + ); + let closure_args = CoroutineClosureArgs::new( + interner, + CoroutineClosureArgsParts { + parent_args, + closure_kind_ty, + signature_parts_ty: Ty::new_fn_ptr( + interner, + bound_sig.map_bound(|sig| { + interner.mk_fn_sig( + [ + resume_ty, + Ty::new_tup_from_iter(interner, sig.inputs().iter()), + ], + Ty::new_tup(interner, &[bound_yield_ty, bound_return_ty]), + sig.c_variadic, + sig.safety, + sig.abi, + ) + }), + ), + tupled_upvars_ty, + coroutine_captures_by_ref_ty, + }, + ); + + let coroutine_id = + self.db.intern_coroutine(InternedCoroutine(self.owner, tgt_expr)).into(); + (None, Ty::new_coroutine_closure(interner, coroutine_id, closure_args.args), None) + } }; // Now go through the argument patterns diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/interner.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/interner.rs index e3c65689d3fd..081865a99e5c 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/interner.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/interner.rs @@ -17,8 +17,8 @@ use rustc_abi::{ReprFlags, ReprOptions}; use rustc_hash::FxHashSet; use rustc_index::bit_set::DenseBitSet; use rustc_type_ir::{ - AliasTermKind, AliasTyKind, BoundVar, CollectAndApply, DebruijnIndex, EarlyBinder, - FlagComputation, Flags, GenericArgKind, ImplPolarity, InferTy, Interner, TraitRef, + AliasTermKind, AliasTyKind, BoundVar, CollectAndApply, CoroutineWitnessTypes, DebruijnIndex, + EarlyBinder, FlagComputation, Flags, GenericArgKind, ImplPolarity, InferTy, Interner, TraitRef, TypeVisitableExt, UniverseIndex, Upcast, Variance, elaborate::elaborate, error::TypeError, @@ -29,7 +29,7 @@ use rustc_type_ir::{ use crate::{ FnAbi, - db::HirDatabase, + db::{HirDatabase, InternedCoroutine}, method_resolution::{ALL_FLOAT_FPS, ALL_INT_FPS, TyFingerprint}, next_solver::{ AdtIdWrapper, BoundConst, CallableIdWrapper, CanonicalVarKind, ClosureIdWrapper, @@ -1205,12 +1205,28 @@ impl<'db> Interner for DbInterner<'db> { self.db().callable_item_signature(def_id.0) } - fn coroutine_movability(self, _def_id: Self::CoroutineId) -> rustc_ast_ir::Movability { - unimplemented!() + fn coroutine_movability(self, def_id: Self::CoroutineId) -> rustc_ast_ir::Movability { + // FIXME: Make this a query? I don't believe this can be accessed from bodies other than + // the current infer query, except with revealed opaques - is it rare enough to not matter? + let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); + let body = self.db.body(owner); + let expr = &body[expr_id]; + match *expr { + hir_def::hir::Expr::Closure { closure_kind, .. } => match closure_kind { + hir_def::hir::ClosureKind::Coroutine(movability) => match movability { + hir_def::hir::Movability::Static => rustc_ast_ir::Movability::Static, + hir_def::hir::Movability::Movable => rustc_ast_ir::Movability::Movable, + }, + hir_def::hir::ClosureKind::Async => rustc_ast_ir::Movability::Static, + _ => panic!("unexpected expression for a coroutine: {expr:?}"), + }, + hir_def::hir::Expr::Async { .. } => rustc_ast_ir::Movability::Static, + _ => panic!("unexpected expression for a coroutine: {expr:?}"), + } } - fn coroutine_for_closure(self, _def_id: Self::CoroutineId) -> Self::CoroutineId { - unimplemented!() + fn coroutine_for_closure(self, def_id: Self::CoroutineClosureId) -> Self::CoroutineId { + def_id } fn generics_require_sized_self(self, def_id: Self::DefId) -> bool { @@ -1725,23 +1741,39 @@ impl<'db> Interner for DbInterner<'db> { panic!("Bug encountered in next-trait-solver: {}", msg.to_string()) } - fn is_general_coroutine(self, _coroutine_def_id: Self::CoroutineId) -> bool { - // FIXME(next-solver) - true + fn is_general_coroutine(self, def_id: Self::CoroutineId) -> bool { + // FIXME: Make this a query? I don't believe this can be accessed from bodies other than + // the current infer query, except with revealed opaques - is it rare enough to not matter? + let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); + let body = self.db.body(owner); + matches!( + body[expr_id], + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::Coroutine(_), + .. + } + ) } - fn coroutine_is_async(self, _coroutine_def_id: Self::CoroutineId) -> bool { - // FIXME(next-solver) - true + fn coroutine_is_async(self, def_id: Self::CoroutineId) -> bool { + // FIXME: Make this a query? I don't believe this can be accessed from bodies other than + // the current infer query, except with revealed opaques - is it rare enough to not matter? + let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); + let body = self.db.body(owner); + matches!( + body[expr_id], + hir_def::hir::Expr::Closure { closure_kind: hir_def::hir::ClosureKind::Async, .. } + | hir_def::hir::Expr::Async { .. } + ) } fn coroutine_is_gen(self, _coroutine_def_id: Self::CoroutineId) -> bool { - // FIXME(next-solver) + // We don't handle gen coroutines yet. false } fn coroutine_is_async_gen(self, _coroutine_def_id: Self::CoroutineId) -> bool { - // FIXME(next-solver) + // We don't handle gen coroutines yet. false } @@ -1897,10 +1929,12 @@ impl<'db> Interner for DbInterner<'db> { fn coroutine_hidden_types( self, _def_id: Self::CoroutineId, - ) -> EarlyBinder>> - { - // FIXME(next-solver) - unimplemented!() + ) -> EarlyBinder>> { + // FIXME: Actually implement this. + EarlyBinder::bind(Binder::dummy(CoroutineWitnessTypes { + types: Tys::default(), + assumptions: RegionAssumptions::default(), + })) } fn is_default_trait(self, def_id: Self::TraitId) -> bool { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ty.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ty.rs index 1443e2f0b312..b8406fecda31 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ty.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ty.rs @@ -11,10 +11,10 @@ use hir_def::{TraitId, type_ref::Rawness}; use rustc_abi::{Float, Integer, Size}; use rustc_ast_ir::{Mutability, try_visit, visit::VisitorResult}; use rustc_type_ir::{ - AliasTyKind, BoundVar, BoundVarIndexKind, ClosureKind, DebruijnIndex, FlagComputation, Flags, - FloatTy, FloatVid, InferTy, IntTy, IntVid, Interner, TyVid, TypeFoldable, TypeSuperFoldable, - TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor, UintTy, Upcast, - WithCachedTypeInfo, + AliasTyKind, BoundVar, BoundVarIndexKind, ClosureKind, CoroutineArgs, CoroutineArgsParts, + DebruijnIndex, FlagComputation, Flags, FloatTy, FloatVid, InferTy, IntTy, IntVid, Interner, + TyVid, TypeFoldable, TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, + TypeVisitor, UintTy, Upcast, WithCachedTypeInfo, inherent::{ AdtDef as _, BoundExistentialPredicates, BoundVarLike, Const as _, GenericArgs as _, IntoKind, ParamLike, PlaceholderLike, Safety as _, SliceLike, Ty as _, @@ -404,6 +404,40 @@ impl<'db> Ty<'db> { .split_closure_args_untupled() .closure_sig_as_fn_ptr_ty .callable_sig(interner), + TyKind::CoroutineClosure(coroutine_id, args) => { + Some(args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { + let unit_ty = Ty::new_unit(interner); + let return_ty = Ty::new_coroutine( + interner, + coroutine_id, + CoroutineArgs::new( + interner, + CoroutineArgsParts { + parent_args: args.as_coroutine_closure().parent_args(), + kind_ty: unit_ty, + resume_ty: unit_ty, + yield_ty: unit_ty, + return_ty: sig.return_ty, + // FIXME: Deduce this from the coroutine closure's upvars. + tupled_upvars_ty: unit_ty, + }, + ) + .args, + ); + FnSig { + inputs_and_output: Tys::new_from_iter( + interner, + sig.tupled_inputs_ty + .tuple_fields() + .iter() + .chain(std::iter::once(return_ty)), + ), + c_variadic: sig.c_variadic, + safety: sig.safety, + abi: sig.abi, + } + })) + } _ => None, } } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/simple.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/simple.rs index 38af7cb7248f..c2392b36baba 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/simple.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/simple.rs @@ -3856,9 +3856,9 @@ fn main() { 74..75 'f': F 80..82 '{}': () 94..191 '{ ... }); }': () - 100..113 'async_closure': fn async_closure(impl FnOnce(i32)) + 100..113 'async_closure': fn async_closure(impl AsyncFnOnce(i32)) 100..147 'async_... })': () - 114..146 'async ... }': impl FnOnce(i32) + 114..146 'async ... }': impl AsyncFnOnce(i32) 121..124 'arg': i32 126..146 '{ ... }': () 136..139 'arg': i32 diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/traits.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/traits.rs index 0cf723e8514d..f72ca22fd229 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/traits.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/traits.rs @@ -85,7 +85,6 @@ async fn test() { } #[test] -#[ignore = "FIXME(next-solver): fix async closures"] fn infer_async_closure() { check_types( r#" @@ -93,7 +92,7 @@ fn infer_async_closure() { async fn test() { let f = async move |x: i32| x + 42; f; -// ^ impl Fn(i32) -> impl Future +// ^ impl AsyncFn(i32) -> i32 let a = f(4); a; // ^ impl Future @@ -102,7 +101,7 @@ async fn test() { // ^ i32 let f = async move || 42; f; -// ^ impl Fn() -> impl Future +// ^ impl AsyncFn() -> i32 let a = f(); a; // ^ impl Future @@ -119,7 +118,7 @@ async fn test() { }; let _: Option = c().await; c; -// ^ impl Fn() -> impl Future> +// ^ impl AsyncFn() -> Option } "#, ); @@ -4930,7 +4929,6 @@ fn main() { #[test] fn async_fn_return_type() { - // FIXME(next-solver): Async closures are lowered as closures currently. We should fix that. check_infer( r#" //- minicore: async_fn @@ -4948,9 +4946,9 @@ fn main() { 46..53 'loop {}': ! 51..53 '{}': () 67..97 '{ ...()); }': () - 73..76 'foo': fn foo(impl Fn()) + 73..76 'foo': fn foo(impl AsyncFn()) 73..94 'foo(as...|| ())': () - 77..93 'async ... || ()': impl Fn() + 77..93 'async ... || ()': impl AsyncFn() 91..93 '()': () "#]], ); diff --git a/src/tools/rust-analyzer/crates/hir/src/lib.rs b/src/tools/rust-analyzer/crates/hir/src/lib.rs index 941890312317..2bb2f80ecc05 100644 --- a/src/tools/rust-analyzer/crates/hir/src/lib.rs +++ b/src/tools/rust-analyzer/crates/hir/src/lib.rs @@ -75,7 +75,7 @@ use hir_ty::{ TraitEnvironment, TyDefId, TyLoweringDiagnostic, ValueTyDefId, all_super_traits, autoderef, check_orphan_rules, consteval::try_const_usize, - db::InternedClosureId, + db::{InternedClosureId, InternedCoroutineId}, diagnostics::BodyValidationDiagnostic, direct_super_traits, known_const_to_ast, layout::{Layout as TyLayout, RustcEnumVariantIdx, RustcFieldIdx, TagEncoding}, @@ -92,7 +92,7 @@ use itertools::Itertools; use rustc_hash::FxHashSet; use rustc_type_ir::{ AliasTyKind, TypeSuperVisitable, TypeVisitable, TypeVisitor, - inherent::{AdtDef, IntoKind, SliceLike, Term as _, Ty as _}, + inherent::{AdtDef, GenericArgs as _, IntoKind, SliceLike, Term as _, Ty as _}, }; use smallvec::SmallVec; use span::{AstIdNode, Edition, FileId}; @@ -4558,16 +4558,27 @@ impl<'db> TraitRef<'db> { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +enum AnyClosureId { + ClosureId(InternedClosureId), + CoroutineClosureId(InternedCoroutineId), +} + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Closure<'db> { - id: InternedClosureId, + id: AnyClosureId, subst: GenericArgs<'db>, } impl<'db> Closure<'db> { fn as_ty(&self, db: &'db dyn HirDatabase) -> Ty<'db> { let interner = DbInterner::new_with(db, None, None); - Ty::new_closure(interner, self.id.into(), self.subst) + match self.id { + AnyClosureId::ClosureId(id) => Ty::new_closure(interner, id.into(), self.subst), + AnyClosureId::CoroutineClosureId(id) => { + Ty::new_coroutine_closure(interner, id.into(), self.subst) + } + } } pub fn display_with_id(&self, db: &dyn HirDatabase, display_target: DisplayTarget) -> String { @@ -4585,20 +4596,28 @@ impl<'db> Closure<'db> { } pub fn captured_items(&self, db: &'db dyn HirDatabase) -> Vec> { - let owner = db.lookup_intern_closure(self.id).0; + let AnyClosureId::ClosureId(id) = self.id else { + // FIXME: Infer coroutine closures' captures. + return Vec::new(); + }; + let owner = db.lookup_intern_closure(id).0; let infer = db.infer(owner); - let info = infer.closure_info(self.id); + let info = infer.closure_info(id); info.0 .iter() .cloned() - .map(|capture| ClosureCapture { owner, closure: self.id, capture }) + .map(|capture| ClosureCapture { owner, closure: id, capture }) .collect() } pub fn capture_types(&self, db: &'db dyn HirDatabase) -> Vec> { - let owner = db.lookup_intern_closure(self.id).0; + let AnyClosureId::ClosureId(id) = self.id else { + // FIXME: Infer coroutine closures' captures. + return Vec::new(); + }; + let owner = db.lookup_intern_closure(id).0; let infer = db.infer(owner); - let (captures, _) = infer.closure_info(self.id); + let (captures, _) = infer.closure_info(id); let env = db.trait_environment_for_body(owner); captures .iter() @@ -4607,10 +4626,22 @@ impl<'db> Closure<'db> { } pub fn fn_trait(&self, db: &dyn HirDatabase) -> FnTrait { - let owner = db.lookup_intern_closure(self.id).0; - let infer = db.infer(owner); - let info = infer.closure_info(self.id); - info.1 + match self.id { + AnyClosureId::ClosureId(id) => { + let owner = db.lookup_intern_closure(id).0; + let infer = db.infer(owner); + let info = infer.closure_info(id); + info.1 + } + AnyClosureId::CoroutineClosureId(_id) => { + // FIXME: Infer kind for coroutine closures. + match self.subst.as_coroutine_closure().kind() { + rustc_type_ir::ClosureKind::Fn => FnTrait::AsyncFn, + rustc_type_ir::ClosureKind::FnMut => FnTrait::AsyncFnMut, + rustc_type_ir::ClosureKind::FnOnce => FnTrait::AsyncFnOnce, + } + } + } } } @@ -5124,28 +5155,14 @@ impl<'db> Type<'db> { let interner = DbInterner::new_with(db, None, None); let callee = match self.ty.kind() { TyKind::Closure(id, subst) => Callee::Closure(id.0, subst), + TyKind::CoroutineClosure(id, subst) => Callee::CoroutineClosure(id.0, subst), TyKind::FnPtr(..) => Callee::FnPtr, TyKind::FnDef(id, _) => Callee::Def(id.0), - kind => { - // This will happen when it implements fn or fn mut, since we add an autoborrow adjustment - let (ty, kind) = if let TyKind::Ref(_, ty, _) = kind { - (ty, ty.kind()) - } else { - (self.ty, kind) - }; - if let TyKind::Closure(closure, subst) = kind { - let sig = subst - .split_closure_args_untupled() - .closure_sig_as_fn_ptr_ty - .callable_sig(interner)?; - return Some(Callable { - ty: self.clone(), - sig, - callee: Callee::Closure(closure.0, subst), - is_bound_method: false, - }); - } - let (fn_trait, sig) = hir_ty::callable_sig_from_fn_trait(ty, self.env.clone(), db)?; + // This will happen when it implements fn or fn mut, since we add an autoborrow adjustment + TyKind::Ref(_, inner_ty, _) => return self.derived(inner_ty).as_callable(db), + _ => { + let (fn_trait, sig) = + hir_ty::callable_sig_from_fn_trait(self.ty, self.env.clone(), db)?; return Some(Callable { ty: self.clone(), sig, @@ -5165,7 +5182,12 @@ impl<'db> Type<'db> { pub fn as_closure(&self) -> Option> { match self.ty.kind() { - TyKind::Closure(id, subst) => Some(Closure { id: id.0, subst }), + TyKind::Closure(id, subst) => { + Some(Closure { id: AnyClosureId::ClosureId(id.0), subst }) + } + TyKind::CoroutineClosure(id, subst) => { + Some(Closure { id: AnyClosureId::CoroutineClosureId(id.0), subst }) + } _ => None, } } @@ -5824,6 +5846,7 @@ pub struct Callable<'db> { enum Callee<'db> { Def(CallableDefId), Closure(InternedClosureId, GenericArgs<'db>), + CoroutineClosure(InternedCoroutineId, GenericArgs<'db>), FnPtr, FnImpl(FnTrait), } @@ -5845,7 +5868,12 @@ impl<'db> Callable<'db> { Callee::Def(CallableDefId::EnumVariantId(it)) => { CallableKind::TupleEnumVariant(it.into()) } - Callee::Closure(id, ref subst) => CallableKind::Closure(Closure { id, subst: *subst }), + Callee::Closure(id, subst) => { + CallableKind::Closure(Closure { id: AnyClosureId::ClosureId(id), subst }) + } + Callee::CoroutineClosure(id, subst) => { + CallableKind::Closure(Closure { id: AnyClosureId::CoroutineClosureId(id), subst }) + } Callee::FnPtr => CallableKind::FnPtr, Callee::FnImpl(fn_) => CallableKind::FnImpl(fn_), } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_closure_to_fn.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_closure_to_fn.rs index 2cda6d6f1c0a..ca142332d97e 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_closure_to_fn.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_closure_to_fn.rs @@ -805,7 +805,6 @@ impl A { ); } - #[ignore = "FIXME(next-solver): Fix async closures"] #[test] fn replaces_async_closure_with_async_fn() { check_assist(