diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs index a83f6835ca61..e4f5e3523bd2 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs @@ -1,10 +1,12 @@ use syntax::{ - AstNode, SyntaxKind, T, TextRange, + AstNode, SyntaxElement, SyntaxKind, SyntaxNode, T, ast::{ self, edit::{AstNodeEdit, IndentLevel}, make, }, + match_ast, + syntax_editor::{Element, Position, SyntaxEditor}, }; use crate::{AssistContext, AssistId, Assists}; @@ -27,123 +29,108 @@ use crate::{AssistContext, AssistId, Assists}; // } // ``` pub(crate) fn unwrap_block(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { - let assist_id = AssistId::refactor_rewrite("unwrap_block"); - let assist_label = "Unwrap block"; let l_curly_token = ctx.find_token_syntax_at_offset(T!['{'])?; - let mut block = ast::BlockExpr::cast(l_curly_token.parent_ancestors().nth(1)?)?; + let block = l_curly_token.parent_ancestors().nth(1).and_then(ast::BlockExpr::cast)?; let target = block.syntax().text_range(); - let mut parent = block.syntax().parent()?; - if ast::MatchArm::can_cast(parent.kind()) { - parent = parent.ancestors().find(|it| ast::MatchExpr::can_cast(it.kind()))? - } + let mut container = block.syntax().clone(); + let mut replacement = block.clone(); + let mut prefer_container = None; - let kind = parent.kind(); - if matches!(kind, SyntaxKind::STMT_LIST | SyntaxKind::EXPR_STMT) { - acc.add(assist_id, assist_label, target, |builder| { - builder.replace(block.syntax().text_range(), update_expr_string(block.to_string())); - }) - } else if matches!(kind, SyntaxKind::LET_STMT) { - let parent = ast::LetStmt::cast(parent)?; - let pattern = ast::Pat::cast(parent.syntax().first_child()?)?; - let ty = parent.ty(); - let list = block.stmt_list()?; - let replaced = match list.syntax().last_child() { - Some(last) => { - let stmts: Vec = list.statements().collect(); - let initializer = ast::Expr::cast(last)?; - let let_stmt = make::let_stmt(pattern, ty, Some(initializer)); - if !stmts.is_empty() { - let block = make::block_expr(stmts, None); - format!("{}\n {}", update_expr_string(block.to_string()), let_stmt) - } else { - let_stmt.to_string() - } - } - None => { - let empty_tuple = make::ext::expr_unit(); - make::let_stmt(pattern, ty, Some(empty_tuple)).to_string() - } - }; - acc.add(assist_id, assist_label, target, |builder| { - builder.replace(parent.syntax().text_range(), replaced); - }) - } else { - let parent = ast::Expr::cast(parent)?; - match parent.clone() { - ast::Expr::ForExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::LoopExpr(_) => (), - ast::Expr::MatchExpr(_) => block = block.dedent(IndentLevel(1)), - ast::Expr::IfExpr(if_expr) => { - let then_branch = if_expr.then_branch()?; - if then_branch == block { - if let Some(ancestor) = if_expr.syntax().parent().and_then(ast::IfExpr::cast) { - // For `else if` blocks - let ancestor_then_branch = ancestor.then_branch()?; - - return acc.add(assist_id, assist_label, target, |edit| { - let range_to_del_else_if = TextRange::new( - ancestor_then_branch.syntax().text_range().end(), - l_curly_token.text_range().start(), - ); - let range_to_del_rest = TextRange::new( - then_branch.syntax().text_range().end(), - if_expr.syntax().text_range().end(), - ); - - edit.delete(range_to_del_rest); - edit.delete(range_to_del_else_if); - edit.replace( - target, - update_expr_string_without_newline(then_branch.to_string()), - ); - }); - } - } else { - return acc.add(assist_id, assist_label, target, |edit| { - let range_to_del = TextRange::new( - then_branch.syntax().text_range().end(), - l_curly_token.text_range().start(), - ); - - edit.delete(range_to_del); - edit.replace(target, update_expr_string_without_newline(block.to_string())); + let from_indent = block.indent_level(); + let into_indent = loop { + let parent = container.parent()?; + container = match_ast! { + match parent { + ast::ForExpr(it) => it.syntax().clone(), + ast::LoopExpr(it) => it.syntax().clone(), + ast::WhileExpr(it) => it.syntax().clone(), + ast::MatchArm(it) => it.parent_match().syntax().clone(), + ast::LetStmt(it) => { + replacement = wrap_let(&it, replacement); + prefer_container = Some(it.syntax().clone()); + it.syntax().clone() + }, + ast::IfExpr(it) => { + prefer_container.get_or_insert_with(|| { + if let Some(else_branch) = it.else_branch() + && *else_branch.syntax() == container + { + else_branch.syntax().clone() + } else { + it.syntax().clone() + } }); - } + it.syntax().clone() + }, + ast::ExprStmt(it) => it.syntax().clone(), + ast::StmtList(it) => break it.indent_level(), + _ => return None, } - _ => return None, }; + }; + let replacement = replacement.stmt_list()?; - acc.add(assist_id, assist_label, target, |builder| { - builder.replace(parent.syntax().text_range(), update_expr_string(block.to_string())); - }) - } + acc.add(AssistId::refactor_rewrite("unwrap_block"), "Unwrap block", target, |builder| { + let mut edit = builder.make_editor(block.syntax()); + let replacement = replacement.dedent(from_indent).indent(into_indent); + let container = prefer_container.unwrap_or(container); + + edit.replace_with_many(&container, extract_statements(replacement)); + delete_else_before(container, &mut edit); + + builder.add_file_edits(ctx.vfs_file_id(), edit); + }) } -fn update_expr_string(expr_string: String) -> String { - update_expr_string_with_pat(expr_string, &[' ', '\n']) +fn delete_else_before(container: SyntaxNode, edit: &mut SyntaxEditor) { + let Some(else_token) = container + .siblings_with_tokens(syntax::Direction::Prev) + .skip(1) + .map_while(|it| it.into_token()) + .find(|it| it.kind() == T![else]) + else { + return; + }; + itertools::chain(else_token.prev_token(), else_token.next_token()) + .filter(|it| it.kind() == SyntaxKind::WHITESPACE) + .for_each(|it| edit.delete(it)); + let indent = IndentLevel::from_node(&container); + let newline = make::tokens::whitespace(&format!("\n{indent}")); + edit.replace(else_token, newline); } -fn update_expr_string_without_newline(expr_string: String) -> String { - update_expr_string_with_pat(expr_string, &[' ']) +fn wrap_let(assign: &ast::LetStmt, replacement: ast::BlockExpr) -> ast::BlockExpr { + let try_wrap_assign = || { + let initializer = assign.initializer()?.syntax().syntax_element(); + let replacement = replacement.clone_subtree(); + let assign = assign.clone_for_update(); + let tail_expr = replacement.tail_expr()?; + let before = + assign.syntax().children_with_tokens().take_while(|it| *it != initializer).collect(); + let after = assign + .syntax() + .children_with_tokens() + .skip_while(|it| *it != initializer) + .skip(1) + .collect(); + + let mut edit = SyntaxEditor::new(replacement.syntax().clone()); + edit.insert_all(Position::before(tail_expr.syntax()), before); + edit.insert_all(Position::after(tail_expr.syntax()), after); + ast::BlockExpr::cast(edit.finish().new_root().clone()) + }; + try_wrap_assign().unwrap_or(replacement) } -fn update_expr_string_with_pat(expr_str: String, whitespace_pat: &[char]) -> String { - // Remove leading whitespace, index to remove the leading '{', - // then continue to remove leading whitespace. - // We cannot assume the `{` is the first character because there are block modifiers - // (`unsafe`, `async` etc.). - let after_open_brace_index = expr_str.find('{').map_or(0, |it| it + 1); - let expr_str = expr_str[after_open_brace_index..].trim_start_matches(whitespace_pat); - - // Remove trailing whitespace, index [..expr_str.len() - 1] to remove the trailing '}', - // then continue to remove trailing whitespace. - let expr_str = expr_str.trim_end_matches(whitespace_pat); - let expr_str = expr_str[..expr_str.len() - 1].trim_end_matches(whitespace_pat); - - expr_str - .lines() - .map(|line| line.replacen(" ", "", 1)) // Delete indentation - .collect::>() - .join("\n") +fn extract_statements(stmt_list: ast::StmtList) -> Vec { + let mut elements = stmt_list + .syntax() + .children_with_tokens() + .filter(|it| !matches!(it.kind(), T!['{'] | T!['}'])) + .skip_while(|it| it.kind() == SyntaxKind::WHITESPACE) + .collect::>(); + while elements.pop_if(|it| it.kind() == SyntaxKind::WHITESPACE).is_some() {} + elements } #[cfg(test)] @@ -593,6 +580,30 @@ fn main() { ); } + #[test] + fn unwrap_match_arm_in_let() { + check_assist( + unwrap_block, + r#" +fn main() { + let value = match rel_path { + Ok(rel_path) => {$0 + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + Some((*id, rel_path)) + } + Err(_) => None, + }; +} +"#, + r#" +fn main() { + let rel_path = RelativePathBuf::from_path(rel_path).ok()?; + let value = Some((*id, rel_path)); +} +"#, + ); + } + #[test] fn simple_if_in_while_bad_cursor_position() { check_assist_not_applicable( @@ -750,19 +761,6 @@ fn main() -> i32 { check_assist( unwrap_block, r#" -fn main() { - let x = {$0}; -} -"#, - r#" -fn main() { - let x = (); -} -"#, - ); - check_assist( - unwrap_block, - r#" fn main() { let x = {$0 bar @@ -784,8 +782,7 @@ fn main() -> i32 { "#, r#" fn main() -> i32 { - 1; - let _ = 2; + 1; let _ = 2; } "#, ); @@ -795,11 +792,29 @@ fn main() -> i32 { fn main() -> i32 { let mut a = {$01; 2}; } +"#, + r#" +fn main() -> i32 { + 1; let mut a = 2; +} +"#, + ); + check_assist( + unwrap_block, + r#" +fn main() -> i32 { + let mut a = {$0 + 1; + 2; + 3 + }; +} "#, r#" fn main() -> i32 { 1; - let mut a = 2; + 2; + let mut a = 3; } "#, );