diff --git a/crates/ide-assists/src/handlers/generate_trait_from_impl.rs b/crates/ide-assists/src/handlers/generate_trait_from_impl.rs index ce9eeb820074..d3192ae40914 100644 --- a/crates/ide-assists/src/handlers/generate_trait_from_impl.rs +++ b/crates/ide-assists/src/handlers/generate_trait_from_impl.rs @@ -1,8 +1,8 @@ use crate::assist_context::{AssistContext, Assists}; -use ide_db::{assists::AssistId, SnippetCap}; +use ide_db::assists::AssistId; use syntax::{ - ast::{self, HasGenericParams, HasVisibility}, - AstNode, + ast::{self, edit::IndentLevel, make, HasGenericParams, HasVisibility}, + ted, AstNode, SyntaxKind, }; // NOTES : @@ -68,6 +68,16 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ // Get AST Node let impl_ast = ctx.find_node_at_offset::()?; + // Check if cursor is to the left of assoc item list's L_CURLY. + // if no L_CURLY then return. + let l_curly = impl_ast.assoc_item_list()?.l_curly_token()?; + + let cursor_offset = ctx.offset(); + let l_curly_offset = l_curly.text_range(); + if cursor_offset >= l_curly_offset.start() { + return None; + } + // If impl is not inherent then we don't really need to go any further. if impl_ast.for_token().is_some() { return None; @@ -80,9 +90,11 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ return None; } + let impl_name = impl_ast.self_ty()?; + acc.add( AssistId("generate_trait_from_impl", ide_db::assists::AssistKind::Generate), - "Generate trait from impl".to_owned(), + "Generate trait from impl", impl_ast.syntax().text_range(), |builder| { let trait_items = assoc_items.clone_for_update(); @@ -93,45 +105,43 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ remove_items_visibility(&item); }); - syntax::ted::replace(assoc_items.clone_for_update().syntax(), impl_items.syntax()); + ted::replace(assoc_items.clone_for_update().syntax(), impl_items.syntax()); impl_items.assoc_items().for_each(|item| { remove_items_visibility(&item); }); - let trait_ast = ast::make::trait_( + let trait_ast = make::trait_( false, - "NewTrait".to_string(), - HasGenericParams::generic_param_list(&impl_ast), - HasGenericParams::where_clause(&impl_ast), + "NewTrait", + impl_ast.generic_param_list(), + impl_ast.where_clause(), trait_items, ); // Change `impl Foo` to `impl NewTrait for Foo` - // First find the PATH_TYPE which is what Foo is. - let impl_name = impl_ast.self_ty().unwrap(); - let trait_name = if let Some(genpars) = impl_ast.generic_param_list() { - format!("NewTrait{}", genpars.to_generic_args()) + let arg_list = if let Some(genpars) = impl_ast.generic_param_list() { + genpars.to_generic_args().to_string() } else { - format!("NewTrait") + "".to_string() }; // // Then replace builder.replace( - impl_name.clone().syntax().text_range(), - format!("{} for {}", trait_name, impl_name.to_string()), + impl_name.syntax().text_range(), + format!("NewTrait{} for {}", arg_list, impl_name.to_string()), ); - builder.replace( - impl_ast.assoc_item_list().unwrap().syntax().text_range(), - impl_items.to_string(), - ); + builder.replace(assoc_items.syntax().text_range(), impl_items.to_string()); // Insert trait before TraitImpl - builder.insert_snippet( - SnippetCap::new(true).unwrap(), + builder.insert( impl_ast.syntax().text_range().start(), - format!("{}\n\n", trait_ast.to_string()), + format!( + "{}\n\n{}", + trait_ast.to_string(), + IndentLevel::from_node(impl_ast.syntax()) + ), ); }, ); @@ -144,17 +154,17 @@ fn remove_items_visibility(item: &ast::AssocItem) { match item { ast::AssocItem::Const(c) => { if let Some(vis) = c.visibility() { - syntax::ted::remove(vis.syntax()); + ted::remove(vis.syntax()); } } ast::AssocItem::Fn(f) => { if let Some(vis) = f.visibility() { - syntax::ted::remove(vis.syntax()); + ted::remove(vis.syntax()); } } ast::AssocItem::TypeAlias(t) => { if let Some(vis) = t.visibility() { - syntax::ted::remove(vis.syntax()); + ted::remove(vis.syntax()); } } _ => (), @@ -168,12 +178,12 @@ fn strip_body(item: &ast::AssocItem) { // In constrast to function bodies, we want to see no ws before a semicolon. // So let's remove them if we see any. if let Some(prev) = body.syntax().prev_sibling_or_token() { - if prev.kind() == syntax::SyntaxKind::WHITESPACE { - syntax::ted::remove(prev); + if prev.kind() == SyntaxKind::WHITESPACE { + ted::remove(prev); } } - syntax::ted::replace(body.syntax(), ast::make::tokens::semicolon()); + ted::replace(body.syntax(), ast::make::tokens::semicolon()); } } _ => (), @@ -185,6 +195,21 @@ mod tests { use super::*; use crate::tests::{check_assist, check_assist_not_applicable}; + #[test] + fn test_trigger_when_cursor_on_header() { + check_assist_not_applicable( + generate_trait_from_impl, + r#" +struct Foo(f64); + +impl Foo { $0 + fn add(&mut self, x: f64) { + self.0 += x; + } +}"#, + ); + } + #[test] fn test_assoc_item_fn() { check_assist( @@ -299,7 +324,7 @@ impl NewTrait for Foo { } #[test] - fn test_e0449_avoided() { + fn test_trait_items_should_not_have_vis() { check_assist( generate_trait_from_impl, r#" @@ -334,4 +359,27 @@ impl Emp$0tyImpl{} "#, ) } + + #[test] + fn test_not_top_level_impl() { + check_assist( + generate_trait_from_impl, + r#" +mod a { + impl S$0 { + fn foo() {} + } +}"#, + r#" +mod a { + trait NewTrait { + fn foo(); + } + + impl NewTrait for S { + fn foo() {} + } +}"#, + ) + } } diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index 1675d1af1dd2..3facd90a11d9 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -865,7 +865,7 @@ pub fn param_list( pub fn trait_( is_unsafe: bool, - ident: String, + ident: &str, gen_params: Option, where_clause: Option, assoc_items: ast::AssocItemList,