Merge pull request #20598 from A4-Tacks/let-chain-sup-conv-to-guarded-ret

Add let-chain support for convert_to_guarded_return
This commit is contained in:
Shoyu Vanilla (Flint) 2025-09-26 06:42:10 +00:00 committed by GitHub
commit acd320f7b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,13 +1,12 @@
use std::iter::once;
use ide_db::{
syntax_helpers::node_ext::{is_pattern_cond, single_let},
ty_filter::TryEnum,
};
use either::Either;
use hir::{Semantics, TypeInfo};
use ide_db::{RootDatabase, ty_filter::TryEnum};
use syntax::{
AstNode,
SyntaxKind::{FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
T,
SyntaxKind::{CLOSURE_EXPR, FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
SyntaxNode, T,
ast::{
self,
edit::{AstNodeEdit, IndentLevel},
@ -44,12 +43,9 @@ use crate::{
// }
// ```
pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
if let Some(let_stmt) = ctx.find_node_at_offset() {
let_stmt_to_guarded_return(let_stmt, acc, ctx)
} else if let Some(if_expr) = ctx.find_node_at_offset() {
if_expr_to_guarded_return(if_expr, acc, ctx)
} else {
None
match ctx.find_node_at_offset::<Either<ast::LetStmt, ast::IfExpr>>()? {
Either::Left(let_stmt) => let_stmt_to_guarded_return(let_stmt, acc, ctx),
Either::Right(if_expr) => if_expr_to_guarded_return(if_expr, acc, ctx),
}
}
@ -73,13 +69,7 @@ fn if_expr_to_guarded_return(
return None;
}
// Check if there is an IfLet that we can handle.
let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) {
let let_ = single_let(cond)?;
(Some(let_.pat()?), let_.expr()?)
} else {
(None, cond)
};
let let_chains = flat_let_chain(cond);
let then_block = if_expr.then_branch()?;
let then_block = then_block.stmt_list()?;
@ -106,11 +96,7 @@ fn if_expr_to_guarded_return(
let parent_container = parent_block.syntax().parent()?;
let early_expression: ast::Expr = match parent_container.kind() {
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
FN => make::expr_return(None),
_ => return None,
};
let early_expression: ast::Expr = early_expression(parent_container, &ctx.sema)?;
then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{'])?;
@ -132,32 +118,42 @@ fn if_expr_to_guarded_return(
target,
|edit| {
let if_indent_level = IndentLevel::from_node(if_expr.syntax());
let replacement = match if_let_pat {
None => {
// If.
let new_expr = {
let then_branch =
make::block_expr(once(make::expr_stmt(early_expression).into()), None);
let cond = invert_boolean_expression_legacy(cond_expr);
make::expr_if(cond, then_branch, None).indent(if_indent_level)
};
new_expr.syntax().clone()
}
Some(pat) => {
let replacement = let_chains.into_iter().map(|expr| {
if let ast::Expr::LetExpr(let_expr) = &expr
&& let (Some(pat), Some(expr)) = (let_expr.pat(), let_expr.expr())
{
// If-let.
let let_else_stmt = make::let_else_stmt(
pat,
None,
cond_expr,
ast::make::tail_only_block_expr(early_expression),
expr,
ast::make::tail_only_block_expr(early_expression.clone()),
);
let let_else_stmt = let_else_stmt.indent(if_indent_level);
let_else_stmt.syntax().clone()
} else {
// If.
let new_expr = {
let then_branch = make::block_expr(
once(make::expr_stmt(early_expression.clone()).into()),
None,
);
let cond = invert_boolean_expression_legacy(expr);
make::expr_if(cond, then_branch, None).indent(if_indent_level)
};
new_expr.syntax().clone()
}
};
});
let newline = &format!("\n{if_indent_level}");
let then_statements = replacement
.children_with_tokens()
.enumerate()
.flat_map(|(i, node)| {
(i != 0)
.then(|| make::tokens::whitespace(newline).into())
.into_iter()
.chain(node.children_with_tokens())
})
.chain(
then_block_items
.syntax()
@ -201,11 +197,7 @@ fn let_stmt_to_guarded_return(
let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
let parent_container = parent_block.syntax().parent()?;
match parent_container.kind() {
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
FN => make::expr_return(None),
_ => return None,
}
early_expression(parent_container, &ctx.sema)?
};
acc.add(
@ -232,6 +224,54 @@ fn let_stmt_to_guarded_return(
)
}
fn early_expression(
parent_container: SyntaxNode,
sema: &Semantics<'_, RootDatabase>,
) -> Option<ast::Expr> {
let return_none_expr = || {
let none_expr = make::expr_path(make::ext::ident_path("None"));
make::expr_return(Some(none_expr))
};
if let Some(fn_) = ast::Fn::cast(parent_container.clone())
&& let Some(fn_def) = sema.to_def(&fn_)
&& let Some(TryEnum::Option) = TryEnum::from_ty(sema, &fn_def.ret_type(sema.db))
{
return Some(return_none_expr());
}
if let Some(body) = ast::ClosureExpr::cast(parent_container.clone()).and_then(|it| it.body())
&& let Some(ret_ty) = sema.type_of_expr(&body).map(TypeInfo::original)
&& let Some(TryEnum::Option) = TryEnum::from_ty(sema, &ret_ty)
{
return Some(return_none_expr());
}
Some(match parent_container.kind() {
WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
FN | CLOSURE_EXPR => make::expr_return(None),
_ => return None,
})
}
fn flat_let_chain(mut expr: ast::Expr) -> Vec<ast::Expr> {
let mut chains = vec![];
while let ast::Expr::BinExpr(bin_expr) = &expr
&& bin_expr.op_kind() == Some(ast::BinaryOp::LogicOp(ast::LogicOp::And))
&& let (Some(lhs), Some(rhs)) = (bin_expr.lhs(), bin_expr.rhs())
{
if let Some(last) = chains.pop_if(|last| !matches!(last, ast::Expr::LetExpr(_))) {
chains.push(make::expr_bin_op(rhs, ast::BinaryOp::LogicOp(ast::LogicOp::And), last));
} else {
chains.push(rhs);
}
expr = lhs;
}
chains.push(expr);
chains.reverse();
chains
}
#[cfg(test)]
mod tests {
use crate::tests::{check_assist, check_assist_not_applicable};
@ -268,6 +308,71 @@ fn main() {
);
}
#[test]
fn convert_inside_fn_return_option() {
check_assist(
convert_to_guarded_return,
r#"
//- minicore: option
fn ret_option() -> Option<()> {
bar();
if$0 true {
foo();
// comment
bar();
}
}
"#,
r#"
fn ret_option() -> Option<()> {
bar();
if false {
return None;
}
foo();
// comment
bar();
}
"#,
);
}
#[test]
fn convert_inside_closure() {
check_assist(
convert_to_guarded_return,
r#"
fn main() {
let _f = || {
bar();
if$0 true {
foo();
// comment
bar();
}
}
}
"#,
r#"
fn main() {
let _f = || {
bar();
if false {
return;
}
foo();
// comment
bar();
}
}
"#,
);
}
#[test]
fn convert_let_inside_fn() {
check_assist(
@ -316,6 +421,82 @@ fn main() {
);
}
#[test]
fn convert_if_let_result_inside_let() {
check_assist(
convert_to_guarded_return,
r#"
fn main() {
let _x = loop {
if$0 let Ok(x) = Err(92) {
foo(x);
}
};
}
"#,
r#"
fn main() {
let _x = loop {
let Ok(x) = Err(92) else { continue };
foo(x);
};
}
"#,
);
}
#[test]
fn convert_if_let_chain_result() {
check_assist(
convert_to_guarded_return,
r#"
fn main() {
if$0 let Ok(x) = Err(92)
&& x < 30
&& let Some(y) = Some(8)
{
foo(x, y);
}
}
"#,
r#"
fn main() {
let Ok(x) = Err(92) else { return };
if x >= 30 {
return;
}
let Some(y) = Some(8) else { return };
foo(x, y);
}
"#,
);
check_assist(
convert_to_guarded_return,
r#"
fn main() {
if$0 let Ok(x) = Err(92)
&& x < 30
&& y < 20
&& let Some(y) = Some(8)
{
foo(x, y);
}
}
"#,
r#"
fn main() {
let Ok(x) = Err(92) else { return };
if !(x < 30 && y < 20) {
return;
}
let Some(y) = Some(8) else { return };
foo(x, y);
}
"#,
);
}
#[test]
fn convert_let_ok_inside_fn() {
check_assist(
@ -560,6 +741,32 @@ fn main() {
);
}
#[test]
fn convert_let_stmt_inside_fn_return_option() {
check_assist(
convert_to_guarded_return,
r#"
//- minicore: option
fn foo() -> Option<i32> {
None
}
fn ret_option() -> Option<i32> {
let x$0 = foo();
}
"#,
r#"
fn foo() -> Option<i32> {
None
}
fn ret_option() -> Option<i32> {
let Some(x) = foo() else { return None };
}
"#,
);
}
#[test]
fn convert_let_stmt_inside_loop() {
check_assist(