diff --git a/crates/ra_hir/src/either.rs b/crates/ra_hir/src/either.rs
index 71c53ebc0662..439e6ec87295 100644
--- a/crates/ra_hir/src/either.rs
+++ b/crates/ra_hir/src/either.rs
@@ -25,6 +25,12 @@ impl Either {
Either::B(b) => Either::B(f2(b)),
}
}
+ pub fn map_a(self, f: F) -> Either
+ where
+ F: FnOnce(A) -> U,
+ {
+ self.map(f, |it| it)
+ }
pub fn a(self) -> Option {
match self {
Either::A(it) => Some(it),
diff --git a/crates/ra_hir/src/source_binder.rs b/crates/ra_hir/src/source_binder.rs
index 55eb7da35654..df67d2c39389 100644
--- a/crates/ra_hir/src/source_binder.rs
+++ b/crates/ra_hir/src/source_binder.rs
@@ -309,15 +309,11 @@ impl SourceAnalyzer {
crate::Resolution::LocalBinding(it) => {
// We get a `PatId` from resolver, but it actually can only
// point at `BindPat`, and not at the arbitrary pattern.
- let pat_ptr = self.body_source_map.as_ref()?.pat_syntax(it)?;
- let pat_ptr = match pat_ptr {
- Either::A(pat) => {
- let pat: AstPtr =
- pat.cast_checking_kind(|kind| kind == BIND_PAT).unwrap();
- Either::A(pat)
- }
- Either::B(self_param) => Either::B(self_param),
- };
+ let pat_ptr = self
+ .body_source_map
+ .as_ref()?
+ .pat_syntax(it)?
+ .map_a(|ptr| ptr.cast::().unwrap());
PathResolution::LocalBinding(pat_ptr)
}
crate::Resolution::GenericParam(it) => PathResolution::GenericParam(it),
diff --git a/crates/ra_syntax/src/ast.rs b/crates/ra_syntax/src/ast.rs
index ceb603c5052f..4a38197f6ff4 100644
--- a/crates/ra_syntax/src/ast.rs
+++ b/crates/ra_syntax/src/ast.rs
@@ -10,7 +10,7 @@ use std::marker::PhantomData;
use crate::{
syntax_node::{SyntaxNode, SyntaxNodeChildren, SyntaxToken},
- SmolStr,
+ SmolStr, SyntaxKind,
};
pub use self::{
@@ -26,6 +26,8 @@ pub use self::{
/// the same representation: a pointer to the tree root and a pointer to the
/// node itself.
pub trait AstNode: Clone {
+ fn can_cast(kind: SyntaxKind) -> bool;
+
fn cast(syntax: SyntaxNode) -> Option
where
Self: Sized;
diff --git a/crates/ra_syntax/src/ast/generated.rs b/crates/ra_syntax/src/ast/generated.rs
index a1f320257892..99fcdbd9ac8d 100644
--- a/crates/ra_syntax/src/ast/generated.rs
+++ b/crates/ra_syntax/src/ast/generated.rs
@@ -10,7 +10,7 @@
#![cfg_attr(rustfmt, rustfmt_skip)]
use crate::{
- SyntaxNode, SyntaxKind::*,
+ SyntaxNode, SyntaxKind::{self, *},
ast::{self, AstNode},
};
@@ -21,12 +21,15 @@ pub struct Alias {
}
impl AstNode for Alias {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- ALIAS => Some(Alias { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ ALIAS => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(Alias { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -41,12 +44,15 @@ pub struct ArgList {
}
impl AstNode for ArgList {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- ARG_LIST => Some(ArgList { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ ARG_LIST => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(ArgList { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -64,12 +70,15 @@ pub struct ArrayExpr {
}
impl AstNode for ArrayExpr {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- ARRAY_EXPR => Some(ArrayExpr { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ ARRAY_EXPR => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(ArrayExpr { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -87,12 +96,15 @@ pub struct ArrayType {
}
impl AstNode for ArrayType {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- ARRAY_TYPE => Some(ArrayType { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ ARRAY_TYPE => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(ArrayType { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -114,12 +126,15 @@ pub struct AssocTypeArg {
}
impl AstNode for AssocTypeArg {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- ASSOC_TYPE_ARG => Some(AssocTypeArg { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ ASSOC_TYPE_ARG => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(AssocTypeArg { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -141,12 +156,15 @@ pub struct Attr {
}
impl AstNode for Attr {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- ATTR => Some(Attr { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ ATTR => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(Attr { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -164,12 +182,15 @@ pub struct BinExpr {
}
impl AstNode for BinExpr {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- BIN_EXPR => Some(BinExpr { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ BIN_EXPR => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(BinExpr { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -183,12 +204,15 @@ pub struct BindPat {
}
impl AstNode for BindPat {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- BIND_PAT => Some(BindPat { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ BIND_PAT => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(BindPat { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -207,12 +231,15 @@ pub struct Block {
}
impl AstNode for Block {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- BLOCK => Some(Block { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ BLOCK => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(Block { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -235,12 +262,15 @@ pub struct BlockExpr {
}
impl AstNode for BlockExpr {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- BLOCK_EXPR => Some(BlockExpr { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ BLOCK_EXPR => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(BlockExpr { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -258,12 +288,15 @@ pub struct BreakExpr {
}
impl AstNode for BreakExpr {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- BREAK_EXPR => Some(BreakExpr { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ BREAK_EXPR => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(BreakExpr { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -281,12 +314,15 @@ pub struct CallExpr {
}
impl AstNode for CallExpr {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- CALL_EXPR => Some(CallExpr { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ CALL_EXPR => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(CallExpr { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -305,12 +341,15 @@ pub struct CastExpr {
}
impl AstNode for CastExpr {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- CAST_EXPR => Some(CastExpr { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ CAST_EXPR => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(CastExpr { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -332,12 +371,15 @@ pub struct Condition {
}
impl AstNode for Condition {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- CONDITION => Some(Condition { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ CONDITION => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(Condition { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -359,12 +401,15 @@ pub struct ConstDef {
}
impl AstNode for ConstDef {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- CONST_DEF => Some(ConstDef { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ CONST_DEF => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(ConstDef { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -388,12 +433,15 @@ pub struct ContinueExpr {
}
impl AstNode for ContinueExpr {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- CONTINUE_EXPR => Some(ContinueExpr { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ CONTINUE_EXPR => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(ContinueExpr { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -407,12 +455,15 @@ pub struct DynTraitType {
}
impl AstNode for DynTraitType {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- DYN_TRAIT_TYPE => Some(DynTraitType { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ DYN_TRAIT_TYPE => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(DynTraitType { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -427,12 +478,15 @@ pub struct EnumDef {
}
impl AstNode for EnumDef {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- ENUM_DEF => Some(EnumDef { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ ENUM_DEF => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(EnumDef { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -455,12 +509,15 @@ pub struct EnumVariant {
}
impl AstNode for EnumVariant {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- ENUM_VARIANT => Some(EnumVariant { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ ENUM_VARIANT => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(EnumVariant { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -481,12 +538,15 @@ pub struct EnumVariantList {
}
impl AstNode for EnumVariantList {
- fn cast(syntax: SyntaxNode) -> Option {
- match syntax.kind() {
- ENUM_VARIANT_LIST => Some(EnumVariantList { syntax }),
- _ => None,
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ ENUM_VARIANT_LIST => true,
+ _ => false,
}
}
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(EnumVariantList { syntax }) } else { None }
+ }
fn syntax(&self) -> &SyntaxNode { &self.syntax }
}
@@ -503,6 +563,20 @@ pub struct Expr {
pub(crate) syntax: SyntaxNode,
}
+impl AstNode for Expr {
+ fn can_cast(kind: SyntaxKind) -> bool {
+ match kind {
+ | TUPLE_EXPR | ARRAY_EXPR | PAREN_EXPR | PATH_EXPR | LAMBDA_EXPR | IF_EXPR | LOOP_EXPR | FOR_EXPR | WHILE_EXPR | CONTINUE_EXPR | BREAK_EXPR | LABEL | BLOCK_EXPR | RETURN_EXPR | MATCH_EXPR | STRUCT_LIT | CALL_EXPR | INDEX_EXPR | METHOD_CALL_EXPR | FIELD_EXPR | TRY_EXPR | TRY_BLOCK_EXPR | CAST_EXPR | REF_EXPR | PREFIX_EXPR | RANGE_EXPR | BIN_EXPR | LITERAL | MACRO_CALL => true,
+ _ => false,
+ }
+ }
+ fn cast(syntax: SyntaxNode) -> Option {
+ if Self::can_cast(syntax.kind()) { Some(Expr { syntax }) } else { None }
+ }
+ fn syntax(&self) -> &SyntaxNode { &self.syntax }
+}
+
+
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExprKind {
TupleExpr(TupleExpr),
@@ -536,190 +610,92 @@ pub enum ExprKind {
MacroCall(MacroCall),
}
impl From for Expr {
- fn from(n: TupleExpr) -> Expr {
- Expr::cast(n.syntax).unwrap()
- }
+ fn from(n: TupleExpr) -> Expr { Expr { syntax: n.syntax } }
}
impl From for Expr {
- fn from(n: ArrayExpr) -> Expr {
- Expr::cast(n.syntax).unwrap()
- }
+ fn from(n: ArrayExpr) -> Expr { Expr { syntax: n.syntax } }
}
impl From for Expr {
- fn from(n: ParenExpr) -> Expr {
- Expr::cast(n.syntax).unwrap()
- }
+ fn from(n: ParenExpr) -> Expr { Expr { syntax: n.syntax } }
}
impl From for Expr {
- fn from(n: PathExpr) -> Expr {
- Expr::cast(n.syntax).unwrap()
- }
+ fn from(n: PathExpr) -> Expr { Expr { syntax: n.syntax } }
}
impl From for Expr {
- fn from(n: LambdaExpr) -> Expr {
- Expr::cast(n.syntax).unwrap()
- }
+ fn from(n: LambdaExpr) -> Expr { Expr { syntax: n.syntax } }
}
impl From for Expr {
- fn from(n: IfExpr) -> Expr {
- Expr::cast(n.syntax).unwrap()
- }
+ fn from(n: IfExpr) -> Expr { Expr { syntax: n.syntax } }
}
impl From for Expr {
- fn from(n: LoopExpr) -> Expr {
- Expr::cast(n.syntax).unwrap()
- }
+ fn from(n: LoopExpr) -> Expr { Expr { syntax: n.syntax } }
}
impl From for Expr {
- fn from(n: ForExpr) -> Expr {
- Expr::cast(n.syntax).unwrap()
- }
+ fn from(n: ForExpr) -> Expr { Expr { syntax: n.syntax } }
}
impl From for Expr {
- fn from(n: WhileExpr) -> Expr {
- Expr::cast(n.syntax).unwrap()
- }
+ fn from(n: WhileExpr) -> Expr { Expr { syntax: n.syntax } }
}
impl From for Expr {
- fn from(n: ContinueExpr) -> Expr {
- Expr::cast(n.syntax).unwrap()
- }
+ fn from(n: ContinueExpr) -> Expr { Expr { syntax: n.syntax } }
}
impl From for Expr {
- fn from(n: BreakExpr) -> Expr {
- Expr::cast(n.syntax).unwrap()
- }
+ fn from(n: BreakExpr) -> Expr { Expr { syntax: n.syntax } }
}
impl From