diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs index 62700ab1809f..04d63f5bc8fe 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs @@ -189,7 +189,7 @@ pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti /// This will create a turbofish generic arg list corresponding to the number of arguments fn get_fish_head(make: &SyntaxFactory, number_of_arguments: usize) -> ast::GenericArgList { let args = (0..number_of_arguments).map(|_| make::type_arg(make::ty_placeholder()).into()); - make.turbofish_generic_arg_list(args) + make.generic_arg_list(args, true) } #[cfg(test)] diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs index 658600cd2d0e..0b145dcb06ba 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs @@ -6,10 +6,9 @@ use ide_db::{ famous_defs::FamousDefs, syntax_helpers::node_ext::{for_each_tail_expr, walk_expr}, }; -use itertools::Itertools; use syntax::{ - ast::{self, make, Expr, HasGenericParams}, - match_ast, ted, AstNode, ToSmolStr, + ast::{self, syntax_factory::SyntaxFactory, Expr, HasGenericArgs, HasGenericParams}, + match_ast, AstNode, }; use crate::{AssistContext, AssistId, AssistKind, Assists}; @@ -43,11 +42,11 @@ use crate::{AssistContext, AssistId, AssistKind, Assists}; pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let ret_type = ctx.find_node_at_offset::()?; let parent = ret_type.syntax().parent()?; - let body = match_ast! { + let body_expr = match_ast! { match parent { - ast::Fn(func) => func.body()?, + ast::Fn(func) => func.body()?.into(), ast::ClosureExpr(closure) => match closure.body()? { - Expr::BlockExpr(block) => block, + Expr::BlockExpr(block) => block.into(), // closures require a block when a return type is specified _ => return None, }, @@ -75,56 +74,65 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op kind.assist_id(), kind.label(), type_ref.syntax().text_range(), - |edit| { - let alias = wrapper_alias(ctx, &core_wrapper, type_ref, kind.symbol()); - let new_return_ty = - alias.unwrap_or_else(|| kind.wrap_type(type_ref)).clone_for_update(); - - let body = edit.make_mut(ast::Expr::BlockExpr(body.clone())); + |builder| { + let mut editor = builder.make_editor(&parent); + let make = SyntaxFactory::new(); + let alias = wrapper_alias(ctx, &make, &core_wrapper, type_ref, kind.symbol()); + let new_return_ty = alias.unwrap_or_else(|| match kind { + WrapperKind::Option => make.ty_option(type_ref.clone()), + WrapperKind::Result => make.ty_result(type_ref.clone(), make.ty_infer().into()), + }); let mut exprs_to_wrap = Vec::new(); let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e); - walk_expr(&body, &mut |expr| { + walk_expr(&body_expr, &mut |expr| { if let Expr::ReturnExpr(ret_expr) = expr { if let Some(ret_expr_arg) = &ret_expr.expr() { for_each_tail_expr(ret_expr_arg, tail_cb); } } }); - for_each_tail_expr(&body, tail_cb); + for_each_tail_expr(&body_expr, tail_cb); for ret_expr_arg in exprs_to_wrap { - let happy_wrapped = make::expr_call( - make::expr_path(make::ext::ident_path(kind.happy_ident())), - make::arg_list(iter::once(ret_expr_arg.clone())), - ) - .clone_for_update(); - ted::replace(ret_expr_arg.syntax(), happy_wrapped.syntax()); + let happy_wrapped = make.expr_call( + make.expr_path(make.ident_path(kind.happy_ident())), + make.arg_list(iter::once(ret_expr_arg.clone())), + ); + editor.replace(ret_expr_arg.syntax(), happy_wrapped.syntax()); } - let old_return_ty = edit.make_mut(type_ref.clone()); - ted::replace(old_return_ty.syntax(), new_return_ty.syntax()); + editor.replace(type_ref.syntax(), new_return_ty.syntax()); if let WrapperKind::Result = kind { // Add a placeholder snippet at the first generic argument that doesn't equal the return type. // This is normally the error type, but that may not be the case when we inserted a type alias. - let args = - new_return_ty.syntax().descendants().find_map(ast::GenericArgList::cast); - let error_type_arg = args.and_then(|list| { - list.generic_args().find(|arg| match arg { - ast::GenericArg::TypeArg(_) => { - arg.syntax().text() != type_ref.syntax().text() - } - ast::GenericArg::LifetimeArg(_) => false, - _ => true, - }) + let args = new_return_ty + .path() + .unwrap() + .segment() + .unwrap() + .generic_arg_list() + .unwrap(); + let error_type_arg = args.generic_args().find(|arg| match arg { + ast::GenericArg::TypeArg(_) => { + arg.syntax().text() != type_ref.syntax().text() + } + ast::GenericArg::LifetimeArg(_) => false, + _ => true, }); if let Some(error_type_arg) = error_type_arg { if let Some(cap) = ctx.config.snippet_cap { - edit.add_placeholder_snippet(cap, error_type_arg); + editor.add_annotation( + error_type_arg.syntax(), + builder.make_placeholder_snippet(cap), + ); } } } + + editor.add_mappings(make.finish_with_mappings()); + builder.add_file_edits(ctx.file_id(), editor); }, ); } @@ -176,22 +184,16 @@ impl WrapperKind { WrapperKind::Result => hir::sym::Result.clone(), } } - - fn wrap_type(&self, type_ref: &ast::Type) -> ast::Type { - match self { - WrapperKind::Option => make::ext::ty_option(type_ref.clone()), - WrapperKind::Result => make::ext::ty_result(type_ref.clone(), make::ty_placeholder()), - } - } } // Try to find an wrapper type alias in the current scope (shadowing the default). fn wrapper_alias( ctx: &AssistContext<'_>, + make: &SyntaxFactory, core_wrapper: &hir::Enum, ret_type: &ast::Type, wrapper: hir::Symbol, -) -> Option { +) -> Option { let wrapper_path = hir::ModPath::from_segments( hir::PathKind::Plain, iter::once(hir::Name::new_symbol_root(wrapper)), @@ -207,25 +209,28 @@ fn wrapper_alias( }) .find_map(|alias| { let mut inserted_ret_type = false; - let generic_params = alias - .source(ctx.db())? - .value - .generic_param_list()? - .generic_params() - .map(|param| match param { - // Replace the very first type parameter with the functions return type. - ast::GenericParam::TypeParam(_) if !inserted_ret_type => { - inserted_ret_type = true; - ret_type.to_smolstr() + let generic_args = + alias.source(ctx.db())?.value.generic_param_list()?.generic_params().map(|param| { + match param { + // Replace the very first type parameter with the function's return type. + ast::GenericParam::TypeParam(_) if !inserted_ret_type => { + inserted_ret_type = true; + make.type_arg(ret_type.clone()).into() + } + ast::GenericParam::LifetimeParam(_) => { + make.lifetime_arg(make.lifetime("'_")).into() + } + _ => make.type_arg(make.ty_infer().into()).into(), } - ast::GenericParam::LifetimeParam(_) => make::lifetime("'_").to_smolstr(), - _ => make::ty_placeholder().to_smolstr(), - }) - .join(", "); + }); let name = alias.name(ctx.db()); - let name = name.as_str(); - Some(make::ty(&format!("{name}<{generic_params}>"))) + let generic_arg_list = make.generic_arg_list(generic_args, false); + let path = make.path_unqualified( + make.path_segment_generics(make.name_ref(name.as_str()), generic_arg_list), + ); + + Some(make.ty_path(path)) }) }) } diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs index d62c01ba761a..af7b3c815812 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs @@ -1,6 +1,9 @@ //! Wrappers over [`make`] constructors use crate::{ - ast::{self, make, HasGenericArgs, HasGenericParams, HasName, HasTypeBounds, HasVisibility}, + ast::{ + self, make, HasArgList, HasGenericArgs, HasGenericParams, HasName, HasTypeBounds, + HasVisibility, + }, syntax_editor::SyntaxMappingBuilder, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, SyntaxToken, }; @@ -16,6 +19,10 @@ impl SyntaxFactory { make::name_ref(name).clone_for_update() } + pub fn lifetime(&self, text: &str) -> ast::Lifetime { + make::lifetime(text).clone_for_update() + } + pub fn ty(&self, text: &str) -> ast::Type { make::ty(text).clone_for_update() } @@ -28,6 +35,20 @@ impl SyntaxFactory { ast } + pub fn ty_path(&self, path: ast::Path) -> ast::PathType { + let ast::Type::PathType(ast) = make::ty_path(path.clone()).clone_for_update() else { + unreachable!() + }; + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(path.syntax().clone(), ast.path().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + pub fn type_param( &self, name: ast::Name, @@ -253,6 +274,37 @@ impl SyntaxFactory { ast } + pub fn expr_call(&self, expr: ast::Expr, arg_list: ast::ArgList) -> ast::CallExpr { + // FIXME: `make::expr_call`` should return a `CallExpr`, not just an `Expr` + let ast::Expr::CallExpr(ast) = + make::expr_call(expr.clone(), arg_list.clone()).clone_for_update() + else { + unreachable!() + }; + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(expr.syntax().clone(), ast.expr().unwrap().syntax().clone()); + builder.map_node(arg_list.syntax().clone(), ast.arg_list().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + + pub fn arg_list(&self, args: impl IntoIterator) -> ast::ArgList { + let (args, input) = iterator_input(args); + let ast = make::arg_list(args).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax.clone()); + builder.map_children(input.into_iter(), ast.args().map(|it| it.syntax().clone())); + builder.finish(&mut mapping); + } + + ast + } + pub fn expr_ref(&self, expr: ast::Expr, exclusive: bool) -> ast::Expr { let ast::Expr::RefExpr(ast) = make::expr_ref(expr.clone(), exclusive).clone_for_update() else { @@ -428,6 +480,30 @@ impl SyntaxFactory { ast } + pub fn type_arg(&self, ty: ast::Type) -> ast::TypeArg { + let ast = make::type_arg(ty.clone()).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + + pub fn lifetime_arg(&self, lifetime: ast::Lifetime) -> ast::LifetimeArg { + let ast = make::lifetime_arg(lifetime.clone()).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(lifetime.syntax().clone(), ast.lifetime().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + pub fn item_const( &self, visibility: Option, @@ -495,12 +571,17 @@ impl SyntaxFactory { ast } - pub fn turbofish_generic_arg_list( + pub fn generic_arg_list( &self, generic_args: impl IntoIterator, + is_turbo: bool, ) -> ast::GenericArgList { let (generic_args, input) = iterator_input(generic_args); - let ast = make::turbofish_generic_arg_list(generic_args.clone()).clone_for_update(); + let ast = if is_turbo { + make::turbofish_generic_arg_list(generic_args).clone_for_update() + } else { + make::generic_arg_list(generic_args).clone_for_update() + }; if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); @@ -753,12 +834,31 @@ impl SyntaxFactory { // `ext` constructors impl SyntaxFactory { + pub fn ident_path(&self, ident: &str) -> ast::Path { + self.path_unqualified(self.path_segment(self.name_ref(ident))) + } + pub fn expr_unit(&self) -> ast::Expr { self.expr_tuple([]).into() } - pub fn ident_path(&self, ident: &str) -> ast::Path { - self.path_unqualified(self.path_segment(self.name_ref(ident))) + pub fn ty_option(&self, t: ast::Type) -> ast::PathType { + let generic_arg_list = self.generic_arg_list([self.type_arg(t).into()], false); + let path = self.path_unqualified( + self.path_segment_generics(self.name_ref("Option"), generic_arg_list), + ); + + self.ty_path(path) + } + + pub fn ty_result(&self, t: ast::Type, e: ast::Type) -> ast::PathType { + let generic_arg_list = + self.generic_arg_list([self.type_arg(t).into(), self.type_arg(e).into()], false); + let path = self.path_unqualified( + self.path_segment_generics(self.name_ref("Result"), generic_arg_list), + ); + + self.ty_path(path) } }