diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs index 0e5e6185d054..ad983df8a57a 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs @@ -5,15 +5,16 @@ use ide_db::{ assists::AssistId, defs::Definition, helpers::mod_path_to_ast, - imports::insert_use::{ImportScope, insert_use}, + imports::insert_use::{ImportScope, insert_use_with_editor}, search::{FileReference, UsageSearchResult}, source_change::SourceChangeBuilder, syntax_helpers::node_ext::{for_each_tail_expr, walk_expr}, }; use syntax::{ AstNode, SyntaxNode, - ast::{self, HasName, edit::IndentLevel, edit_in_place::Indent, make}, - match_ast, ted, + ast::{self, HasName, edit::IndentLevel, edit_in_place::Indent, syntax_factory::SyntaxFactory}, + match_ast, + syntax_editor::SyntaxEditor, }; use crate::assist_context::{AssistContext, Assists}; @@ -67,14 +68,15 @@ pub(crate) fn convert_tuple_return_type_to_struct( "Convert tuple return type to tuple struct", target, move |edit| { - let ret_type = edit.make_mut(ret_type); - let fn_ = edit.make_mut(fn_); + let mut syntax_editor = edit.make_editor(ret_type.syntax()); + let syntax_factory = SyntaxFactory::with_mappings(); let usages = Definition::Function(fn_def).usages(&ctx.sema).all(); let struct_name = format!("{}Result", stdx::to_camel_case(&fn_name.to_string())); let parent = fn_.syntax().ancestors().find_map(>::cast); add_tuple_struct_def( edit, + &syntax_factory, ctx, &usages, parent.as_ref().map(|it| it.syntax()).unwrap_or(fn_.syntax()), @@ -83,15 +85,23 @@ pub(crate) fn convert_tuple_return_type_to_struct( &target_module, ); - ted::replace( + syntax_editor.replace( ret_type.syntax(), - make::ret_type(make::ty(&struct_name)).syntax().clone_for_update(), + syntax_factory.ret_type(syntax_factory.ty(&struct_name)).syntax(), ); if let Some(fn_body) = fn_.body() { - replace_body_return_values(ast::Expr::BlockExpr(fn_body), &struct_name); + replace_body_return_values( + &mut syntax_editor, + &syntax_factory, + ast::Expr::BlockExpr(fn_body), + &struct_name, + ); } + syntax_editor.add_mappings(syntax_factory.finish_with_mappings()); + edit.add_file_edits(ctx.vfs_file_id(), syntax_editor); + replace_usages(edit, ctx, &usages, &struct_name, &target_module); }, ) @@ -106,24 +116,37 @@ fn replace_usages( target_module: &hir::Module, ) { for (file_id, references) in usages.iter() { - edit.edit_file(file_id.file_id(ctx.db())); + let Some(first_ref) = references.first() else { continue }; - let refs_with_imports = - augment_references_with_imports(edit, ctx, references, struct_name, target_module); + let mut editor = edit.make_editor(first_ref.name.syntax().as_node().unwrap()); + let syntax_factory = SyntaxFactory::with_mappings(); + + let refs_with_imports = augment_references_with_imports( + &syntax_factory, + ctx, + references, + struct_name, + target_module, + ); refs_with_imports.into_iter().rev().for_each(|(name, import_data)| { if let Some(fn_) = name.syntax().parent().and_then(ast::Fn::cast) { cov_mark::hit!(replace_trait_impl_fns); if let Some(ret_type) = fn_.ret_type() { - ted::replace( + editor.replace( ret_type.syntax(), - make::ret_type(make::ty(struct_name)).syntax().clone_for_update(), + syntax_factory.ret_type(syntax_factory.ty(struct_name)).syntax(), ); } if let Some(fn_body) = fn_.body() { - replace_body_return_values(ast::Expr::BlockExpr(fn_body), struct_name); + replace_body_return_values( + &mut editor, + &syntax_factory, + ast::Expr::BlockExpr(fn_body), + struct_name, + ); } } else { // replace tuple patterns @@ -143,22 +166,30 @@ fn replace_usages( _ => None, }); for tuple_pat in tuple_pats { - ted::replace( + editor.replace( tuple_pat.syntax(), - make::tuple_struct_pat( - make::path_from_text(struct_name), - tuple_pat.fields(), - ) - .clone_for_update() - .syntax(), + syntax_factory + .tuple_struct_pat( + syntax_factory.path_from_text(struct_name), + tuple_pat.fields(), + ) + .syntax(), ); } } - // add imports across modules where needed if let Some((import_scope, path)) = import_data { - insert_use(&import_scope, path, &ctx.config.insert_use); + insert_use_with_editor( + &import_scope, + path, + &ctx.config.insert_use, + &mut editor, + &syntax_factory, + ); } - }) + }); + + editor.add_mappings(syntax_factory.finish_with_mappings()); + edit.add_file_edits(file_id.file_id(ctx.db()), editor); } } @@ -176,7 +207,7 @@ fn node_to_pats(node: SyntaxNode) -> Option> { } fn augment_references_with_imports( - edit: &mut SourceChangeBuilder, + syntax_factory: &SyntaxFactory, ctx: &AssistContext<'_>, references: &[FileReference], struct_name: &str, @@ -191,8 +222,6 @@ fn augment_references_with_imports( ctx.sema.scope(name.syntax()).map(|scope| (name, scope.module())) }) .map(|(name, ref_module)| { - let new_name = edit.make_mut(name); - // if the referenced module is not the same as the target one and has not been seen before, add an import let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module && !visited_modules.contains(&ref_module) @@ -201,8 +230,7 @@ fn augment_references_with_imports( let cfg = ctx.config.find_path_config(ctx.sema.is_nightly(ref_module.krate(ctx.sema.db))); - let import_scope = - ImportScope::find_insert_use_container(new_name.syntax(), &ctx.sema); + let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema); let path = ref_module .find_use_path( ctx.sema.db, @@ -211,12 +239,12 @@ fn augment_references_with_imports( cfg, ) .map(|mod_path| { - make::path_concat( + syntax_factory.path_concat( mod_path_to_ast( &mod_path, target_module.krate(ctx.db()).edition(ctx.db()), ), - make::path_from_text(struct_name), + syntax_factory.path_from_text(struct_name), ) }); @@ -225,7 +253,7 @@ fn augment_references_with_imports( None }; - (new_name, import_data) + (name, import_data) }) .collect() } @@ -233,6 +261,7 @@ fn augment_references_with_imports( // Adds the definition of the tuple struct before the parent function. fn add_tuple_struct_def( edit: &mut SourceChangeBuilder, + syntax_factory: &SyntaxFactory, ctx: &AssistContext<'_>, usages: &UsageSearchResult, parent: &SyntaxNode, @@ -248,13 +277,13 @@ fn add_tuple_struct_def( ctx.sema.scope(name.syntax()).map(|scope| scope.module()) }) .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module); - let visibility = if make_struct_pub { Some(make::visibility_pub()) } else { None }; + let visibility = if make_struct_pub { Some(syntax_factory.visibility_pub()) } else { None }; - let field_list = ast::FieldList::TupleFieldList(make::tuple_field_list( - tuple_ty.fields().map(|ty| make::tuple_field(visibility.clone(), ty)), + let field_list = ast::FieldList::TupleFieldList(syntax_factory.tuple_field_list( + tuple_ty.fields().map(|ty| syntax_factory.tuple_field(visibility.clone(), ty)), )); - let struct_name = make::name(struct_name); - let struct_def = make::struct_(visibility, struct_name, None, field_list).clone_for_update(); + let struct_name = syntax_factory.name(struct_name); + let struct_def = syntax_factory.struct_(visibility, struct_name, None, field_list); let indent = IndentLevel::from_node(parent); struct_def.reindent_to(indent); @@ -263,7 +292,12 @@ fn add_tuple_struct_def( } /// Replaces each returned tuple in `body` with the constructor of the tuple struct named `struct_name`. -fn replace_body_return_values(body: ast::Expr, struct_name: &str) { +fn replace_body_return_values( + syntax_editor: &mut SyntaxEditor, + syntax_factory: &SyntaxFactory, + body: ast::Expr, + struct_name: &str, +) { let mut exprs_to_wrap = Vec::new(); let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e); @@ -278,12 +312,11 @@ fn replace_body_return_values(body: ast::Expr, struct_name: &str) { for ret_expr in exprs_to_wrap { if let ast::Expr::TupleExpr(tuple_expr) = &ret_expr { - let struct_constructor = make::expr_call( - make::expr_path(make::ext::ident_path(struct_name)), - make::arg_list(tuple_expr.fields()), - ) - .clone_for_update(); - ted::replace(ret_expr.syntax(), struct_constructor.syntax()); + let struct_constructor = syntax_factory.expr_call( + syntax_factory.expr_path(syntax_factory.ident_path(struct_name)), + syntax_factory.arg_list(tuple_expr.fields()), + ); + syntax_editor.replace(ret_expr.syntax(), struct_constructor.syntax()); } } } diff --git a/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs b/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs index db1d599d550d..f26952fa1535 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs @@ -9,8 +9,9 @@ use syntax::{ Direction, NodeOrToken, SyntaxKind, SyntaxNode, algo, ast::{ self, AstNode, HasAttrs, HasModuleItem, HasVisibility, PathSegmentKind, - edit_in_place::Removable, make, + edit_in_place::Removable, make, syntax_factory::SyntaxFactory, }, + syntax_editor::{Position, SyntaxEditor}, ted, }; @@ -146,6 +147,17 @@ pub fn insert_use(scope: &ImportScope, path: ast::Path, cfg: &InsertUseConfig) { insert_use_with_alias_option(scope, path, cfg, None); } +/// Insert an import path into the given file/node. A `merge` value of none indicates that no import merging is allowed to occur. +pub fn insert_use_with_editor( + scope: &ImportScope, + path: ast::Path, + cfg: &InsertUseConfig, + syntax_editor: &mut SyntaxEditor, + syntax_factory: &SyntaxFactory, +) { + insert_use_with_alias_option_with_editor(scope, path, cfg, None, syntax_editor, syntax_factory); +} + pub fn insert_use_as_alias( scope: &ImportScope, path: ast::Path, @@ -229,6 +241,71 @@ fn insert_use_with_alias_option( insert_use_(scope, use_item, cfg.group); } +fn insert_use_with_alias_option_with_editor( + scope: &ImportScope, + path: ast::Path, + cfg: &InsertUseConfig, + alias: Option, + syntax_editor: &mut SyntaxEditor, + syntax_factory: &SyntaxFactory, +) { + let _p = tracing::info_span!("insert_use_with_alias_option").entered(); + let mut mb = match cfg.granularity { + ImportGranularity::Crate => Some(MergeBehavior::Crate), + ImportGranularity::Module => Some(MergeBehavior::Module), + ImportGranularity::One => Some(MergeBehavior::One), + ImportGranularity::Item => None, + }; + if !cfg.enforce_granularity { + let file_granularity = guess_granularity_from_scope(scope); + mb = match file_granularity { + ImportGranularityGuess::Unknown => mb, + ImportGranularityGuess::Item => None, + ImportGranularityGuess::Module => Some(MergeBehavior::Module), + // We use the user's setting to infer if this is module or item. + ImportGranularityGuess::ModuleOrItem => match mb { + Some(MergeBehavior::Module) | None => mb, + // There isn't really a way to decide between module or item here, so we just pick one. + // FIXME: Maybe it is possible to infer based on semantic analysis? + Some(MergeBehavior::One | MergeBehavior::Crate) => Some(MergeBehavior::Module), + }, + ImportGranularityGuess::Crate => Some(MergeBehavior::Crate), + ImportGranularityGuess::CrateOrModule => match mb { + Some(MergeBehavior::Crate | MergeBehavior::Module) => mb, + Some(MergeBehavior::One) | None => Some(MergeBehavior::Crate), + }, + ImportGranularityGuess::One => Some(MergeBehavior::One), + }; + } + + let use_tree = syntax_factory.use_tree(path, None, alias, false); + if mb == Some(MergeBehavior::One) && use_tree.path().is_some() { + use_tree.wrap_in_tree_list(); + } + let use_item = make::use_(None, None, use_tree).clone_for_update(); + for attr in + scope.required_cfgs.iter().map(|attr| attr.syntax().clone_subtree().clone_for_update()) + { + syntax_editor.insert(Position::first_child_of(use_item.syntax()), attr); + } + + // merge into existing imports if possible + if let Some(mb) = mb { + let filter = |it: &_| !(cfg.skip_glob_imports && ast::Use::is_simple_glob(it)); + for existing_use in + scope.as_syntax_node().children().filter_map(ast::Use::cast).filter(filter) + { + if let Some(merged) = try_merge_imports(&existing_use, &use_item, mb) { + syntax_editor.replace(existing_use.syntax(), merged.syntax()); + return; + } + } + } + // either we weren't allowed to merge or there is no import that fits the merge conditions + // so look for the place we have to insert to + insert_use_with_editor_(scope, use_item, cfg.group, syntax_editor, syntax_factory); +} + pub fn ast_to_remove_for_path_in_use_stmt(path: &ast::Path) -> Option> { // FIXME: improve this if path.parent_path().is_some() { @@ -500,6 +577,127 @@ fn insert_use_(scope: &ImportScope, use_item: ast::Use, group_imports: bool) { } } +fn insert_use_with_editor_( + scope: &ImportScope, + use_item: ast::Use, + group_imports: bool, + syntax_editor: &mut SyntaxEditor, + syntax_factory: &SyntaxFactory, +) { + let scope_syntax = scope.as_syntax_node(); + let insert_use_tree = + use_item.use_tree().expect("`use_item` should have a use tree for `insert_path`"); + let group = ImportGroup::new(&insert_use_tree); + let path_node_iter = scope_syntax + .children() + .filter_map(|node| ast::Use::cast(node.clone()).zip(Some(node))) + .flat_map(|(use_, node)| { + let tree = use_.use_tree()?; + Some((tree, node)) + }); + + if group_imports { + // Iterator that discards anything that's not in the required grouping + // This implementation allows the user to rearrange their import groups as this only takes the first group that fits + let group_iter = path_node_iter + .clone() + .skip_while(|(use_tree, ..)| ImportGroup::new(use_tree) != group) + .take_while(|(use_tree, ..)| ImportGroup::new(use_tree) == group); + + // track the last element we iterated over, if this is still None after the iteration then that means we never iterated in the first place + let mut last = None; + // find the element that would come directly after our new import + let post_insert: Option<(_, SyntaxNode)> = group_iter + .inspect(|(.., node)| last = Some(node.clone())) + .find(|(use_tree, _)| use_tree_cmp(&insert_use_tree, use_tree) != Ordering::Greater); + + if let Some((.., node)) = post_insert { + cov_mark::hit!(insert_group); + // insert our import before that element + return syntax_editor.insert(Position::before(node), use_item.syntax()); + } + if let Some(node) = last { + cov_mark::hit!(insert_group_last); + // there is no element after our new import, so append it to the end of the group + return syntax_editor.insert(Position::after(node), use_item.syntax()); + } + + // the group we were looking for actually doesn't exist, so insert + + let mut last = None; + // find the group that comes after where we want to insert + let post_group = path_node_iter + .inspect(|(.., node)| last = Some(node.clone())) + .find(|(use_tree, ..)| ImportGroup::new(use_tree) > group); + if let Some((.., node)) = post_group { + cov_mark::hit!(insert_group_new_group); + syntax_editor.insert(Position::before(&node), use_item.syntax()); + if let Some(node) = algo::non_trivia_sibling(node.into(), Direction::Prev) { + syntax_editor.insert(Position::after(node), syntax_factory.whitespace("\n")); + } + return; + } + // there is no such group, so append after the last one + if let Some(node) = last { + cov_mark::hit!(insert_group_no_group); + syntax_editor.insert(Position::after(&node), use_item.syntax()); + syntax_editor.insert(Position::after(node), syntax_factory.whitespace("\n")); + return; + } + } else { + // There exists a group, so append to the end of it + if let Some((_, node)) = path_node_iter.last() { + cov_mark::hit!(insert_no_grouping_last); + syntax_editor.insert(Position::after(node), use_item.syntax()); + return; + } + } + + let l_curly = match &scope.kind { + ImportScopeKind::File(_) => None, + // don't insert the imports before the item list/block expr's opening curly brace + ImportScopeKind::Module(item_list) => item_list.l_curly_token(), + // don't insert the imports before the item list's opening curly brace + ImportScopeKind::Block(block) => block.l_curly_token(), + }; + // there are no imports in this file at all + // so put the import after all inner module attributes and possible license header comments + if let Some(last_inner_element) = scope_syntax + .children_with_tokens() + // skip the curly brace + .skip(l_curly.is_some() as usize) + .take_while(|child| match child { + NodeOrToken::Node(node) => is_inner_attribute(node.clone()), + NodeOrToken::Token(token) => { + [SyntaxKind::WHITESPACE, SyntaxKind::COMMENT, SyntaxKind::SHEBANG] + .contains(&token.kind()) + } + }) + .filter(|child| child.as_token().is_none_or(|t| t.kind() != SyntaxKind::WHITESPACE)) + .last() + { + cov_mark::hit!(insert_empty_inner_attr); + syntax_editor.insert(Position::after(&last_inner_element), use_item.syntax()); + syntax_editor.insert(Position::after(last_inner_element), syntax_factory.whitespace("\n")); + } else { + match l_curly { + Some(b) => { + cov_mark::hit!(insert_empty_module); + syntax_editor.insert(Position::after(&b), syntax_factory.whitespace("\n")); + syntax_editor.insert(Position::after(&b), use_item.syntax()); + } + None => { + cov_mark::hit!(insert_empty_file); + syntax_editor.insert( + Position::first_child_of(scope_syntax), + syntax_factory.whitespace("\n\n"), + ); + syntax_editor.insert(Position::first_child_of(scope_syntax), use_item.syntax()); + } + } + } +} + fn is_inner_attribute(node: SyntaxNode) -> bool { ast::Attr::cast(node).map(|attr| attr.kind()) == Some(ast::AttrKind::Inner) } diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs index ad9b9054a8f5..6e17d262a79d 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs @@ -75,6 +75,24 @@ impl SyntaxFactory { make::path_from_text(text).clone_for_update() } + pub fn path_concat(&self, first: ast::Path, second: ast::Path) -> ast::Path { + make::path_concat(first, second).clone_for_update() + } + + pub fn visibility_pub(&self) -> ast::Visibility { + make::visibility_pub() + } + + pub fn struct_( + &self, + visibility: Option, + strukt_name: ast::Name, + generic_param_list: Option, + field_list: ast::FieldList, + ) -> ast::Struct { + make::struct_(visibility, strukt_name, generic_param_list, field_list).clone_for_update() + } + pub fn expr_field(&self, receiver: ast::Expr, field: &str) -> ast::FieldExpr { let ast::Expr::FieldExpr(ast) = make::expr_field(receiver.clone(), field).clone_for_update()