diff --git a/crates/ra_assists/src/assists/merge_match_arms.rs b/crates/ra_assists/src/assists/merge_match_arms.rs index aca391155f37..64c9379da18c 100644 --- a/crates/ra_assists/src/assists/merge_match_arms.rs +++ b/crates/ra_assists/src/assists/merge_match_arms.rs @@ -1,6 +1,12 @@ -use crate::{Assist, AssistCtx, AssistId, TextRange, TextUnit}; +use std::iter::successors; + use hir::db::HirDatabase; -use ra_syntax::ast::{AstNode, MatchArm}; +use ra_syntax::{ + ast::{self, AstNode}, + Direction, TextUnit, +}; + +use crate::{Assist, AssistCtx, AssistId, TextRange}; // Assist: merge_match_arms // @@ -27,62 +33,80 @@ use ra_syntax::ast::{AstNode, MatchArm}; // } // ``` pub(crate) fn merge_match_arms(ctx: AssistCtx) -> Option { - let current_arm = ctx.find_node_at_offset::()?; - - // We check if the following match arm matches this one. We could, but don't, - // compare to the previous match arm as well. - let next = current_arm.syntax().next_sibling(); - let next_arm = MatchArm::cast(next?)?; - + let current_arm = ctx.find_node_at_offset::()?; // Don't try to handle arms with guards for now - can add support for this later - if current_arm.guard().is_some() || next_arm.guard().is_some() { + if current_arm.guard().is_some() { return None; } - let current_expr = current_arm.expr()?; - let next_expr = next_arm.expr()?; + let current_text_range = current_arm.syntax().text_range(); - // Check for match arm equality by comparing lengths and then string contents - if current_expr.syntax().text_range().len() != next_expr.syntax().text_range().len() { + enum CursorPos { + InExpr(TextUnit), + InPat(TextUnit), + } + let cursor_pos = ctx.frange.range.start(); + let cursor_pos = if current_expr.syntax().text_range().contains(cursor_pos) { + CursorPos::InExpr(current_text_range.end() - cursor_pos) + } else { + CursorPos::InPat(cursor_pos) + }; + + // We check if the following match arms match this one. We could, but don't, + // compare to the previous match arm as well. + let arms_to_merge = successors(Some(current_arm), next_arm) + .take_while(|arm| { + if arm.guard().is_some() { + return false; + } + match arm.expr() { + Some(expr) => expr.syntax().text() == current_expr.syntax().text(), + None => false, + } + }) + .collect::>(); + + if arms_to_merge.len() <= 1 { return None; } - if current_expr.syntax().text() != next_expr.syntax().text() { - return None; - } - - let cursor_to_end = current_arm.syntax().text_range().end() - ctx.frange.range.start(); ctx.add_assist(AssistId("merge_match_arms"), "Merge match arms", |edit| { - fn contains_placeholder(a: &MatchArm) -> bool { - a.pats().any(|x| match x { - ra_syntax::ast::Pat::PlaceholderPat(..) => true, - _ => false, - }) - } - - let pats = if contains_placeholder(¤t_arm) || contains_placeholder(&next_arm) { + let pats = if arms_to_merge.iter().any(contains_placeholder) { "_".into() } else { - let ps: Vec = current_arm - .pats() + arms_to_merge + .iter() + .flat_map(ast::MatchArm::pats) .map(|x| x.syntax().to_string()) - .chain(next_arm.pats().map(|x| x.syntax().to_string())) - .collect(); - ps.join(" | ") + .collect::>() + .join(" | ") }; let arm = format!("{} => {}", pats, current_expr.syntax().text()); - let offset = TextUnit::from_usize(arm.len()) - cursor_to_end; - let start = current_arm.syntax().text_range().start(); - let end = next_arm.syntax().text_range().end(); + let start = arms_to_merge.first().unwrap().syntax().text_range().start(); + let end = arms_to_merge.last().unwrap().syntax().text_range().end(); - edit.target(current_arm.syntax().text_range()); + edit.target(current_text_range); + edit.set_cursor(match cursor_pos { + CursorPos::InExpr(back_offset) => start + TextUnit::from_usize(arm.len()) - back_offset, + CursorPos::InPat(offset) => offset, + }); edit.replace(TextRange::from_to(start, end), arm); - edit.set_cursor(start + offset); }) } +fn contains_placeholder(a: &ast::MatchArm) -> bool { + a.pats().any(|x| match x { + ra_syntax::ast::Pat::PlaceholderPat(..) => true, + _ => false, + }) +} + +fn next_arm(arm: &ast::MatchArm) -> Option { + arm.syntax().siblings(Direction::Next).skip(1).find_map(ast::MatchArm::cast) +} + #[cfg(test)] mod tests { use super::merge_match_arms; @@ -184,6 +208,37 @@ mod tests { ); } + #[test] + fn merges_all_subsequent_arms() { + check_assist( + merge_match_arms, + r#" + enum X { A, B, C, D, E } + + fn main() { + match X::A { + X::A<|> => 92, + X::B => 92, + X::C => 92, + X::D => 62, + _ => panic!(), + } + } + "#, + r#" + enum X { A, B, C, D, E } + + fn main() { + match X::A { + X::A<|> | X::B | X::C => 92, + X::D => 62, + _ => panic!(), + } + } + "#, + ) + } + #[test] fn merge_match_arms_rejects_guards() { check_assist_not_applicable(