diff --git a/crates/ra_hir/src/ty.rs b/crates/ra_hir/src/ty.rs index d11f80cffb84..45a01679ce2f 100644 --- a/crates/ra_hir/src/ty.rs +++ b/crates/ra_hir/src/ty.rs @@ -522,6 +522,8 @@ struct InferenceContext<'a, D: HirDatabase> { impl_block: Option, var_unification_table: InPlaceUnificationTable, type_of: FxHashMap, + /// The return type of the function being inferred. + return_ty: Ty, } impl<'a, D: HirDatabase> InferenceContext<'a, D> { @@ -534,7 +536,8 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { InferenceContext { type_of: FxHashMap::default(), var_unification_table: InPlaceUnificationTable::new(), - self_param: None, // set during parameter typing + self_param: None, // set during parameter typing + return_ty: Ty::Unknown, // set in collect_fn_signature db, scopes, module, @@ -555,6 +558,14 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { self.type_of.insert(LocalSyntaxPtr::new(node), ty); } + fn make_ty(&self, type_ref: &TypeRef) -> Cancelable { + Ty::from_hir(self.db, &self.module, self.impl_block.as_ref(), type_ref) + } + + fn make_ty_opt(&self, type_ref: Option<&TypeRef>) -> Cancelable { + Ty::from_hir_opt(self.db, &self.module, self.impl_block.as_ref(), type_ref) + } + fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool { match (ty1, ty2) { (Ty::Unknown, ..) => true, @@ -952,6 +963,55 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { self.write_ty(node.syntax(), ty.clone()); Ok(ty) } + + fn collect_fn_signature(&mut self, node: ast::FnDef) -> Cancelable<()> { + if let Some(param_list) = node.param_list() { + if let Some(self_param) = param_list.self_param() { + let self_type = if let Some(type_ref) = self_param.type_ref() { + let ty = self.make_ty(&TypeRef::from_ast(type_ref))?; + self.insert_type_vars(ty) + } else { + // TODO this should be handled by desugaring during HIR conversion + let ty = self.make_ty_opt(self.impl_block.as_ref().map(|i| i.target()))?; + let ty = match self_param.flavor() { + ast::SelfParamFlavor::Owned => ty, + ast::SelfParamFlavor::Ref => Ty::Ref(Arc::new(ty), Mutability::Shared), + ast::SelfParamFlavor::MutRef => Ty::Ref(Arc::new(ty), Mutability::Mut), + }; + self.insert_type_vars(ty) + }; + if let Some(self_kw) = self_param.self_kw() { + let self_param = LocalSyntaxPtr::new(self_kw.syntax()); + self.self_param = Some(self_param); + self.type_of.insert(self_param, self_type); + } + } + for param in param_list.params() { + let pat = if let Some(pat) = param.pat() { + pat + } else { + continue; + }; + let ty = if let Some(type_ref) = param.type_ref() { + let ty = self.make_ty(&TypeRef::from_ast(type_ref))?; + self.insert_type_vars(ty) + } else { + // missing type annotation + self.new_type_var() + }; + self.type_of.insert(LocalSyntaxPtr::new(pat.syntax()), ty); + } + } + + self.return_ty = if let Some(type_ref) = node.ret_type().and_then(|n| n.type_ref()) { + let ty = self.make_ty(&TypeRef::from_ast(type_ref))?; + self.insert_type_vars(ty) + } else { + Ty::unit() + }; + + Ok(()) + } } pub fn infer(db: &impl HirDatabase, def_id: DefId) -> Cancelable> { @@ -964,66 +1024,10 @@ pub fn infer(db: &impl HirDatabase, def_id: DefId) -> Cancelable ty, - ast::SelfParamFlavor::Ref => Ty::Ref(Arc::new(ty), Mutability::Shared), - ast::SelfParamFlavor::MutRef => Ty::Ref(Arc::new(ty), Mutability::Mut), - }; - ctx.insert_type_vars(ty) - } - } else { - log::debug!( - "No impl block found, but self param for function {:?}", - def_id - ); - ctx.new_type_var() - }; - if let Some(self_kw) = self_param.self_kw() { - let self_param = LocalSyntaxPtr::new(self_kw.syntax()); - ctx.self_param = Some(self_param); - ctx.type_of.insert(self_param, self_type); - } - } - for param in param_list.params() { - let pat = if let Some(pat) = param.pat() { - pat - } else { - continue; - }; - let ty = if let Some(type_ref) = param.type_ref() { - let ty = Ty::from_ast(db, &ctx.module, ctx.impl_block.as_ref(), type_ref)?; - ctx.insert_type_vars(ty) - } else { - // missing type annotation - ctx.new_type_var() - }; - ctx.type_of.insert(LocalSyntaxPtr::new(pat.syntax()), ty); - } - } - - let ret_ty = if let Some(type_ref) = node.ret_type().and_then(|n| n.type_ref()) { - let ty = Ty::from_ast(db, &ctx.module, ctx.impl_block.as_ref(), type_ref)?; - ctx.insert_type_vars(ty) - } else { - Ty::unit() - }; + ctx.collect_fn_signature(node)?; if let Some(block) = node.body() { - ctx.infer_block(block, &Expectation::has_type(ret_ty))?; + ctx.infer_block(block, &Expectation::has_type(ctx.return_ty.clone()))?; } Ok(Arc::new(ctx.resolve_all()))