diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs index af2c2c759ec7..490a9ee3c04d 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs @@ -1,7 +1,8 @@ -use ide_db::base_db::SourceDatabase; -use syntax::TextSize; use syntax::{ - algo::non_trivia_sibling, ast, AstNode, Direction, SyntaxKind, SyntaxToken, TextRange, T, + algo::non_trivia_sibling, + ast::{self, syntax_factory::SyntaxFactory}, + syntax_editor::{Element, SyntaxMapping}, + AstNode, Direction, NodeOrToken, SyntaxElement, SyntaxKind, SyntaxToken, T, }; use crate::{AssistContext, AssistId, AssistKind, Assists}; @@ -25,8 +26,6 @@ pub(crate) fn flip_comma(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<( let comma = ctx.find_token_syntax_at_offset(T![,])?; let prev = non_trivia_sibling(comma.clone().into(), Direction::Prev)?; let next = non_trivia_sibling(comma.clone().into(), Direction::Next)?; - let (mut prev_text, mut next_text) = (prev.to_string(), next.to_string()); - let (mut prev_range, mut next_range) = (prev.text_range(), next.text_range()); // Don't apply a "flip" in case of a last comma // that typically comes before punctuation @@ -40,53 +39,85 @@ pub(crate) fn flip_comma(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<( return None; } - if let Some(parent) = comma.parent().and_then(ast::TokenTree::cast) { - // An attribute. It often contains a path followed by a token tree (e.g. `align(2)`), so we have - // to be smarter. - let prev_start = - match comma.siblings_with_tokens(Direction::Prev).skip(1).find(|it| it.kind() == T![,]) - { - Some(it) => position_after_token(it.as_token().unwrap()), - None => position_after_token(&parent.left_delimiter_token()?), - }; - let prev_end = prev.text_range().end(); - let next_start = next.text_range().start(); - let next_end = - match comma.siblings_with_tokens(Direction::Next).skip(1).find(|it| it.kind() == T![,]) - { - Some(it) => position_before_token(it.as_token().unwrap()), - None => position_before_token(&parent.right_delimiter_token()?), - }; - prev_range = TextRange::new(prev_start, prev_end); - next_range = TextRange::new(next_start, next_end); - let file_text = ctx.db().file_text(ctx.file_id().file_id()); - prev_text = file_text[prev_range].to_owned(); - next_text = file_text[next_range].to_owned(); - } + // FIXME: remove `clone_for_update` when `SyntaxEditor` handles it for us + let prev = match prev { + SyntaxElement::Node(node) => node.clone_for_update().syntax_element(), + _ => prev, + }; + let next = match next { + SyntaxElement::Node(node) => node.clone_for_update().syntax_element(), + _ => next, + }; acc.add( AssistId("flip_comma", AssistKind::RefactorRewrite), "Flip comma", comma.text_range(), - |edit| { - edit.replace(prev_range, next_text); - edit.replace(next_range, prev_text); + |builder| { + let parent = comma.parent().unwrap(); + let mut editor = builder.make_editor(&parent); + + if let Some(parent) = ast::TokenTree::cast(parent) { + // An attribute. It often contains a path followed by a + // token tree (e.g. `align(2)`), so we have to be smarter. + let (new_tree, mapping) = flip_tree(parent.clone(), comma); + editor.replace(parent.syntax(), new_tree.syntax()); + editor.add_mappings(mapping); + } else { + editor.replace(prev.clone(), next.clone()); + editor.replace(next.clone(), prev.clone()); + } + + builder.add_file_edits(ctx.file_id(), editor); }, ) } -fn position_before_token(token: &SyntaxToken) -> TextSize { - match non_trivia_sibling(token.clone().into(), Direction::Prev) { - Some(prev_token) => prev_token.text_range().end(), - None => token.text_range().start(), - } -} +fn flip_tree(tree: ast::TokenTree, comma: SyntaxToken) -> (ast::TokenTree, SyntaxMapping) { + let mut tree_iter = tree.token_trees_and_tokens(); + let before: Vec<_> = + tree_iter.by_ref().take_while(|it| it.as_token() != Some(&comma)).collect(); + let after: Vec<_> = tree_iter.collect(); -fn position_after_token(token: &SyntaxToken) -> TextSize { - match non_trivia_sibling(token.clone().into(), Direction::Next) { - Some(prev_token) => prev_token.text_range().start(), - None => token.text_range().end(), - } + let not_ws = |element: &NodeOrToken<_, SyntaxToken>| match element { + NodeOrToken::Token(token) => token.kind() != SyntaxKind::WHITESPACE, + NodeOrToken::Node(_) => true, + }; + + let is_comma = |element: &NodeOrToken<_, SyntaxToken>| match element { + NodeOrToken::Token(token) => token.kind() == T![,], + NodeOrToken::Node(_) => false, + }; + + let prev_start_untrimmed = match before.iter().rposition(is_comma) { + Some(pos) => pos + 1, + None => 1, + }; + let prev_end = 1 + before.iter().rposition(not_ws).unwrap(); + let prev_start = prev_start_untrimmed + + before[prev_start_untrimmed..prev_end].iter().position(not_ws).unwrap(); + + let next_start = after.iter().position(not_ws).unwrap(); + let next_end_untrimmed = match after.iter().position(is_comma) { + Some(pos) => pos, + None => after.len() - 1, + }; + let next_end = 1 + after[..next_end_untrimmed].iter().rposition(not_ws).unwrap(); + + let result = [ + &before[1..prev_start], + &after[next_start..next_end], + &before[prev_end..], + &[NodeOrToken::Token(comma)], + &after[..next_start], + &before[prev_start..prev_end], + &after[next_end..after.len() - 1], + ] + .concat(); + + let make = SyntaxFactory::new(); + let new_token_tree = make.token_tree(tree.left_delimiter_token().unwrap().kind(), result); + (new_token_tree, make.finish_with_mappings()) } #[cfg(test)] @@ -147,4 +178,9 @@ mod tests { r#"#[foo(bar, qux, baz(1 + 1), other)] struct Foo;"#, ); } + + #[test] + fn flip_comma_attribute_incomplete() { + check_assist_not_applicable(flip_comma, r#"#[repr(align(2),$0)] struct Foo;"#); + } }