Add inherit attributes for extract_function assist

Example
---
```rust
#[cfg(test)]
fn foo() {
    foo($01 + 1$0);
}
```

**Before this PR**

```rust
#[cfg(test)]
fn foo() {
    foo(fun_name());
}

fn $0fun_name() -> i32 {
    1 + 1
}
```

**After this PR**

```rust
#[cfg(test)]
fn foo() {
    foo(fun_name());
}

#[cfg(test)]
fn $0fun_name() -> i32 {
    1 + 1
}
```
This commit is contained in:
A4-Tacks 2026-01-11 15:06:38 +08:00
parent 491c3202a5
commit e40bd1cf6e
No known key found for this signature in database
GPG key ID: 9E63F956E66DD9C7

View file

@ -25,7 +25,7 @@ use syntax::{
SyntaxKind::{self, COMMENT},
SyntaxNode, SyntaxToken, T, TextRange, TextSize, TokenAtOffset, WalkEvent,
ast::{
self, AstNode, AstToken, HasGenericParams, HasName, edit::IndentLevel,
self, AstNode, AstToken, HasAttrs, HasGenericParams, HasName, edit::IndentLevel,
edit_in_place::Indent,
},
match_ast, ted,
@ -375,6 +375,7 @@ struct ContainerInfo<'db> {
ret_type: Option<hir::Type<'db>>,
generic_param_lists: Vec<ast::GenericParamList>,
where_clauses: Vec<ast::WhereClause>,
attrs: Vec<ast::Attr>,
edition: Edition,
}
@ -911,6 +912,7 @@ impl FunctionBody {
let parents = generic_parents(&parent);
let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect();
let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect();
let attrs = parents.iter().flat_map(|it| it.attrs()).filter(is_inherit_attr).collect();
Some((
ContainerInfo {
@ -919,6 +921,7 @@ impl FunctionBody {
ret_type: ty,
generic_param_lists,
where_clauses,
attrs,
edition,
},
contains_tail_expr,
@ -1103,6 +1106,14 @@ impl GenericParent {
GenericParent::Trait(trait_) => trait_.where_clause(),
}
}
fn attrs(&self) -> impl Iterator<Item = ast::Attr> {
match self {
GenericParent::Fn(fn_) => fn_.attrs(),
GenericParent::Impl(impl_) => impl_.attrs(),
GenericParent::Trait(trait_) => trait_.attrs(),
}
}
}
/// Search `parent`'s ancestors for items with potentially applicable generic parameters
@ -1578,7 +1589,7 @@ fn format_function(
let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun);
make::fn_(
None,
fun.mods.attrs.clone(),
None,
fun_name,
generic_params,
@ -1958,6 +1969,11 @@ fn format_type(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module)
ty.display_source_code(ctx.db(), module.into(), true).ok().unwrap_or_else(|| "_".to_owned())
}
fn is_inherit_attr(attr: &ast::Attr) -> bool {
let Some(name) = attr.simple_name() else { return false };
matches!(name.as_str(), "track_caller" | "cfg")
}
fn make_ty(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module) -> ast::Type {
let ty_str = format_type(ty, ctx, module);
make::ty(&ty_str)
@ -6372,6 +6388,55 @@ fn foo() {
fn $0fun_name(mut a: i32, mut b: i32) {
(a, b) = (b, a);
}
"#,
);
}
#[test]
fn with_cfg_attr() {
check_assist(
extract_function,
r#"
//- /main.rs crate:main cfg:test
#[cfg(test)]
fn foo() {
foo($01 + 1$0);
}
"#,
r#"
#[cfg(test)]
fn foo() {
foo(fun_name());
}
#[cfg(test)]
fn $0fun_name() -> i32 {
1 + 1
}
"#,
);
}
#[test]
fn with_track_caller() {
check_assist(
extract_function,
r#"
#[track_caller]
fn foo() {
foo($01 + 1$0);
}
"#,
r#"
#[track_caller]
fn foo() {
foo(fun_name());
}
#[track_caller]
fn $0fun_name() -> i32 {
1 + 1
}
"#,
);
}