diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_range_for_to_while.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_range_for_to_while.rs new file mode 100644 index 000000000000..68cb76403095 --- /dev/null +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_range_for_to_while.rs @@ -0,0 +1,259 @@ +use ide_db::assists::AssistId; +use itertools::Itertools; +use syntax::{ + AstNode, T, + algo::previous_non_trivia_token, + ast::{ + self, HasArgList, HasLoopBody, HasName, RangeItem, edit::AstNodeEdit, make, + syntax_factory::SyntaxFactory, + }, + syntax_editor::{Element, Position}, +}; + +use crate::assist_context::{AssistContext, Assists}; + +// Assist: convert_range_for_to_while +// +// Convert for each range into while loop. +// +// ``` +// fn foo() { +// $0for i in 3..7 { +// foo(i); +// } +// } +// ``` +// -> +// ``` +// fn foo() { +// let mut i = 3; +// while i < 7 { +// foo(i); +// i += 1; +// } +// } +// ``` +pub(crate) fn convert_range_for_to_while(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let for_kw = ctx.find_token_syntax_at_offset(T![for])?; + let for_ = ast::ForExpr::cast(for_kw.parent()?)?; + let ast::Pat::IdentPat(pat) = for_.pat()? else { return None }; + let iterable = for_.iterable()?; + let (start, end, step, inclusive) = extract_range(&iterable)?; + let name = pat.name()?; + let body = for_.loop_body()?; + let last = previous_non_trivia_token(body.stmt_list()?.r_curly_token()?)?; + + let description = if end.is_some() { + "Replace with while expression" + } else { + "Replace with loop expression" + }; + acc.add( + AssistId::refactor("convert_range_for_to_while"), + description, + for_.syntax().text_range(), + |builder| { + let mut edit = builder.make_editor(for_.syntax()); + let make = SyntaxFactory::with_mappings(); + + let indent = for_.indent_level(); + let pat = make.ident_pat(pat.ref_token().is_some(), true, name.clone()); + let let_stmt = make.let_stmt(pat.into(), None, Some(start)); + edit.insert_all( + Position::before(for_.syntax()), + vec![ + let_stmt.syntax().syntax_element(), + make.whitespace(&format!("\n{}", indent)).syntax_element(), + ], + ); + + let mut elements = vec![]; + + let var_expr = make.expr_path(make.ident_path(&name.text())); + let op = ast::BinaryOp::CmpOp(ast::CmpOp::Ord { + ordering: ast::Ordering::Less, + strict: !inclusive, + }); + if let Some(end) = end { + elements.extend([ + make.token(T![while]).syntax_element(), + make.whitespace(" ").syntax_element(), + make.expr_bin(var_expr.clone(), op, end).syntax().syntax_element(), + ]); + } else { + elements.push(make.token(T![loop]).syntax_element()); + } + + edit.replace_all( + for_kw.syntax_element()..=iterable.syntax().syntax_element(), + elements, + ); + + let op = ast::BinaryOp::Assignment { op: Some(ast::ArithOp::Add) }; + edit.insert_all( + Position::after(last), + vec![ + make.whitespace(&format!("\n{}", indent + 1)).syntax_element(), + make.expr_bin(var_expr, op, step).syntax().syntax_element(), + make.token(T![;]).syntax_element(), + ], + ); + + edit.add_mappings(make.finish_with_mappings()); + builder.add_file_edits(ctx.vfs_file_id(), edit); + }, + ) +} + +fn extract_range(iterable: &ast::Expr) -> Option<(ast::Expr, Option, ast::Expr, bool)> { + Some(match iterable { + ast::Expr::ParenExpr(expr) => extract_range(&expr.expr()?)?, + ast::Expr::RangeExpr(range) => { + let inclusive = range.op_kind()? == ast::RangeOp::Inclusive; + (range.start()?, range.end(), make::expr_literal("1").into(), inclusive) + } + ast::Expr::MethodCallExpr(call) if call.name_ref()?.text() == "step_by" => { + let [step] = call.arg_list()?.args().collect_array()?; + let (start, end, _, inclusive) = extract_range(&call.receiver()?)?; + (start, end, step, inclusive) + } + _ => return None, + }) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn test_convert_range_for_to_while() { + check_assist( + convert_range_for_to_while, + " +fn foo() { + $0for i in 3..7 { + foo(i); + } +} + ", + " +fn foo() { + let mut i = 3; + while i < 7 { + foo(i); + i += 1; + } +} + ", + ); + } + + #[test] + fn test_convert_range_for_to_while_no_end_bound() { + check_assist( + convert_range_for_to_while, + " +fn foo() { + $0for i in 3.. { + foo(i); + } +} + ", + " +fn foo() { + let mut i = 3; + loop { + foo(i); + i += 1; + } +} + ", + ); + } + + #[test] + fn test_convert_range_for_to_while_with_mut_binding() { + check_assist( + convert_range_for_to_while, + " +fn foo() { + $0for mut i in 3..7 { + foo(i); + } +} + ", + " +fn foo() { + let mut i = 3; + while i < 7 { + foo(i); + i += 1; + } +} + ", + ); + } + + #[test] + fn test_convert_range_for_to_while_with_label() { + check_assist( + convert_range_for_to_while, + " +fn foo() { + 'a: $0for mut i in 3..7 { + foo(i); + } +} + ", + " +fn foo() { + let mut i = 3; + 'a: while i < 7 { + foo(i); + i += 1; + } +} + ", + ); + } + + #[test] + fn test_convert_range_for_to_while_step_by() { + check_assist( + convert_range_for_to_while, + " +fn foo() { + $0for mut i in (3..7).step_by(2) { + foo(i); + } +} + ", + " +fn foo() { + let mut i = 3; + while i < 7 { + foo(i); + i += 2; + } +} + ", + ); + } + + #[test] + fn test_convert_range_for_to_while_not_applicable_non_range() { + check_assist_not_applicable( + convert_range_for_to_while, + " +fn foo() { + let ident = 3..7; + $0for mut i in ident { + foo(i); + } +} + ", + ); + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/lib.rs b/src/tools/rust-analyzer/crates/ide-assists/src/lib.rs index e9f2d686465e..4b4aa9427955 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/lib.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/lib.rs @@ -131,6 +131,7 @@ mod handlers { mod convert_match_to_let_else; mod convert_named_struct_to_tuple_struct; mod convert_nested_function_to_closure; + mod convert_range_for_to_while; mod convert_to_guarded_return; mod convert_tuple_return_type_to_struct; mod convert_tuple_struct_to_named_struct; @@ -268,6 +269,7 @@ mod handlers { convert_match_to_let_else::convert_match_to_let_else, convert_named_struct_to_tuple_struct::convert_named_struct_to_tuple_struct, convert_nested_function_to_closure::convert_nested_function_to_closure, + convert_range_for_to_while::convert_range_for_to_while, convert_to_guarded_return::convert_to_guarded_return, convert_tuple_return_type_to_struct::convert_tuple_return_type_to_struct, convert_tuple_struct_to_named_struct::convert_tuple_struct_to_named_struct, diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs b/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs index a99fe8de333d..7f0836abdf3c 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs @@ -731,6 +731,29 @@ fn main() { ) } +#[test] +fn doctest_convert_range_for_to_while() { + check_doc_test( + "convert_range_for_to_while", + r#####" +fn foo() { + $0for i in 3..7 { + foo(i); + } +} +"#####, + r#####" +fn foo() { + let mut i = 3; + while i < 7 { + foo(i); + i += 1; + } +} +"#####, + ) +} + #[test] fn doctest_convert_to_guarded_return() { check_doc_test( diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs index 051c5835571b..dba39204e32e 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs @@ -1355,7 +1355,7 @@ pub mod tokens { pub(super) static SOURCE_FILE: LazyLock> = LazyLock::new(|| { SourceFile::parse( - "use crate::foo; const C: <()>::Item = ( true && true , true || true , 1 != 1, 2 == 2, 3 < 3, 4 <= 4, 5 > 5, 6 >= 6, !true, *p, &p , &mut p, async { let _ @ [] })\n;\n\nunsafe impl A for B where: {}", + "use crate::foo; const C: <()>::Item = ( true && true , true || true , 1 != 1, 2 == 2, 3 < 3, 4 <= 4, 5 > 5, 6 >= 6, !true, *p, &p , &mut p, async { let _ @ [] }, while loop {} {})\n;\n\nunsafe impl A for B where: {}", Edition::CURRENT, ) }); diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs index 8bf27f967482..969552392180 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs @@ -644,6 +644,20 @@ impl SyntaxFactory { ast } + pub fn expr_loop(&self, body: ast::BlockExpr) -> ast::LoopExpr { + let ast::Expr::LoopExpr(ast) = make::expr_loop(body.clone()).clone_for_update() else { + unreachable!() + }; + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(body.syntax().clone(), ast.loop_body().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + pub fn expr_while_loop(&self, condition: ast::Expr, body: ast::BlockExpr) -> ast::WhileExpr { let ast = make::expr_while_loop(condition.clone(), body.clone()).clone_for_update();