diff --git a/crates/ide_assists/src/handlers/generate_function.rs b/crates/ide_assists/src/handlers/generate_function.rs index 3870b7e75af6..6f95b1a07398 100644 --- a/crates/ide_assists/src/handlers/generate_function.rs +++ b/crates/ide_assists/src/handlers/generate_function.rs @@ -83,17 +83,18 @@ struct FunctionTemplate { leading_ws: String, fn_def: ast::Fn, ret_type: ast::RetType, + should_render_snippet: bool, trailing_ws: String, file: FileId, } impl FunctionTemplate { fn to_string(&self, cap: Option) -> String { - let f = match cap { - Some(cap) => { + let f = match (cap, self.should_render_snippet) { + (Some(cap), true) => { render_snippet(cap, self.fn_def.syntax(), Cursor::Replace(self.ret_type.syntax())) } - None => self.fn_def.to_string(), + _ => self.fn_def.to_string(), }; format!("{}{}{}", self.leading_ws, f, self.trailing_ws) } @@ -104,6 +105,8 @@ struct FunctionBuilder { fn_name: ast::Name, type_params: Option, params: ast::ParamList, + ret_type: ast::RetType, + should_render_snippet: bool, file: FileId, needs_pub: bool, } @@ -132,7 +135,43 @@ impl FunctionBuilder { let fn_name = fn_name(&path)?; let (type_params, params) = fn_args(ctx, target_module, &call)?; - Some(Self { target, fn_name, type_params, params, file, needs_pub }) + // should_render_snippet intends to express a rough level of confidence about + // the correctness of the return type. + // + // If we are able to infer some return type, and that return type is not unit, we + // don't want to render the snippet. The assumption here is in this situation the + // return type is just as likely to be correct as any other part of the generated + // function. + // + // In the case where the return type is inferred as unit it is likely that the + // user does in fact intend for this generated function to return some non unit + // type, but that the current state of their code doesn't allow that return type + // to be accurately inferred. + let (ret_ty, should_render_snippet) = { + match ctx.sema.type_of_expr(&ast::Expr::CallExpr(call.clone())) { + Some(ty) if ty.is_unknown() || ty.is_unit() => (make::ty_unit(), true), + Some(ty) => { + let rendered = ty.display_source_code(ctx.db(), target_module.into()); + match rendered { + Ok(rendered) => (make::ty(&rendered), false), + Err(_) => (make::ty_unit(), true), + } + } + None => (make::ty_unit(), true), + } + }; + let ret_type = make::ret_type(ret_ty); + + Some(Self { + target, + fn_name, + type_params, + params, + ret_type, + should_render_snippet, + file, + needs_pub, + }) } fn render(self) -> FunctionTemplate { @@ -145,7 +184,7 @@ impl FunctionBuilder { self.type_params, self.params, fn_body, - Some(make::ret_type(make::ty_unit())), + Some(self.ret_type), ); let leading_ws; let trailing_ws; @@ -171,6 +210,7 @@ impl FunctionBuilder { insert_offset, leading_ws, ret_type: fn_def.ret_type().unwrap(), + should_render_snippet: self.should_render_snippet, fn_def, trailing_ws, file: self.file, @@ -546,7 +586,7 @@ impl Baz { } } -fn bar(baz: Baz) ${0:-> ()} { +fn bar(baz: Baz) -> Baz { todo!() } ", @@ -1059,6 +1099,27 @@ pub(crate) fn bar() ${0:-> ()} { ) } + #[test] + fn add_function_with_return_type() { + check_assist( + generate_function, + r" +fn main() { + let x: u32 = foo$0(); +} +", + r" +fn main() { + let x: u32 = foo(); +} + +fn foo() -> u32 { + todo!() +} +", + ) + } + #[test] fn add_function_not_applicable_if_function_already_exists() { check_assist_not_applicable(