[clang] a711a3a - [Syntax] Build mapping from AST to syntax tree nodes

Dmitri Gribenko via cfe-commits cfe-commits at lists.llvm.org
Mon Mar 23 08:26:34 PDT 2020


Author: Marcel Hlopko
Date: 2020-03-23T16:22:14+01:00
New Revision: a711a3a46039154c38eade8bef1138b77fdb05ee

URL: https://github.com/llvm/llvm-project/commit/a711a3a46039154c38eade8bef1138b77fdb05ee
DIFF: https://github.com/llvm/llvm-project/commit/a711a3a46039154c38eade8bef1138b77fdb05ee.diff

LOG: [Syntax] Build mapping from AST to syntax tree nodes

Summary:
Copy of https://reviews.llvm.org/D72446, submitting with Ilya's permission.

For now, the mapping is only used to assign roles to child nodes, which is more
efficient than doing range-based queries.

In the future, it will be exposed in the public API of syntax trees.

Reviewers: gribozavr2

Reviewed By: gribozavr2

Subscribers: cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D76355

Added: 
    

Modified: 
    clang/include/clang/Tooling/Syntax/Tree.h
    clang/lib/Tooling/Syntax/BuildTree.cpp
    clang/lib/Tooling/Syntax/Mutations.cpp
    clang/lib/Tooling/Syntax/Tree.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Tooling/Syntax/Tree.h b/clang/include/clang/Tooling/Syntax/Tree.h
index 8702fe60ce1b..bc581004c46e 100644
--- a/clang/include/clang/Tooling/Syntax/Tree.h
+++ b/clang/include/clang/Tooling/Syntax/Tree.h
@@ -126,6 +126,8 @@ class Node {
   // FactoryImpl sets CanModify flag.
   friend class FactoryImpl;
 
+  void setRole(NodeRole NR);
+
   Tree *Parent;
   Node *NextSibling;
   unsigned Kind : 16;
@@ -171,8 +173,11 @@ class Tree : public Node {
   /// Prepend \p Child to the list of children and and sets the parent pointer.
   /// A very low-level operation that does not check any invariants, only used
   /// by TreeBuilder and FactoryImpl.
-  /// EXPECTS: Role != NodeRoleDetached.
+  /// EXPECTS: Role != Detached.
   void prependChildLowLevel(Node *Child, NodeRole Role);
+  /// Like the previous overload, but does not set role for \p Child.
+  /// EXPECTS: Child->Role != Detached
+  void prependChildLowLevel(Node *Child);
   friend class TreeBuilder;
   friend class FactoryImpl;
 

diff  --git a/clang/lib/Tooling/Syntax/BuildTree.cpp b/clang/lib/Tooling/Syntax/BuildTree.cpp
index a09ac1c53e34..4103a4d92c7d 100644
--- a/clang/lib/Tooling/Syntax/BuildTree.cpp
+++ b/clang/lib/Tooling/Syntax/BuildTree.cpp
@@ -25,6 +25,8 @@
 #include "clang/Tooling/Syntax/Tokens.h"
 #include "clang/Tooling/Syntax/Tree.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallVector.h"
@@ -34,6 +36,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cstddef>
 #include <map>
 
 using namespace clang;
@@ -145,6 +148,30 @@ static SourceRange getDeclaratorRange(const SourceManager &SM, TypeLoc T,
   return SourceRange(Start, End);
 }
 
+namespace {
+/// All AST hierarchy roots that can be represented as pointers.
+using ASTPtr = llvm::PointerUnion<Stmt *, Decl *>;
+/// Maintains a mapping from AST to syntax tree nodes. This class will get more
+/// complicated as we support more kinds of AST nodes, e.g. TypeLocs.
+/// FIXME: expose this as public API.
+class ASTToSyntaxMapping {
+public:
+  void add(ASTPtr From, syntax::Tree *To) {
+    assert(To != nullptr);
+    assert(!From.isNull());
+
+    bool Added = Nodes.insert({From, To}).second;
+    (void)Added;
+    assert(Added && "mapping added twice");
+  }
+
+  syntax::Tree *find(ASTPtr P) const { return Nodes.lookup(P); }
+
+private:
+  llvm::DenseMap<ASTPtr, syntax::Tree *> Nodes;
+};
+} // namespace
+
 /// A helper class for constructing the syntax tree while traversing a clang
 /// AST.
 ///
@@ -172,7 +199,18 @@ class syntax::TreeBuilder {
 
   /// Populate children for \p New node, assuming it covers tokens from \p
   /// Range.
-  void foldNode(llvm::ArrayRef<syntax::Token> Range, syntax::Tree *New);
+  void foldNode(llvm::ArrayRef<syntax::Token> Range, syntax::Tree *New,
+                ASTPtr From) {
+    assert(New);
+    Pending.foldChildren(Arena, Range, New);
+    if (From)
+      Mapping.add(From, New);
+  }
+  void foldNode(llvm::ArrayRef<syntax::Token> Range, syntax::Tree *New,
+                TypeLoc L) {
+    // FIXME: add mapping for TypeLocs
+    foldNode(Range, New, nullptr);
+  }
 
   /// Must be called with the range of each `DeclaratorDecl`. Ensures the
   /// corresponding declarator nodes are covered by `SimpleDeclaration`.
@@ -195,8 +233,10 @@ class syntax::TreeBuilder {
   /// Set role for \p T.
   void markChildToken(const syntax::Token *T, NodeRole R);
 
-  /// Set role for the node that spans exactly \p Range.
-  void markChild(llvm::ArrayRef<syntax::Token> Range, NodeRole R);
+  /// Set role for \p N.
+  void markChild(syntax::Node *N, NodeRole R);
+  /// Set role for the syntax node matching \p N.
+  void markChild(ASTPtr N, NodeRole R);
   /// Set role for the delayed node that spans exactly \p Range.
   void markDelayedChild(llvm::ArrayRef<syntax::Token> Range, NodeRole R);
   /// Set role for the node that may or may not be delayed. Node must span
@@ -221,8 +261,13 @@ class syntax::TreeBuilder {
   /// Finds a token starting at \p L. The token must exist if \p L is valid.
   const syntax::Token *findToken(SourceLocation L) const;
 
-  /// getRange() finds the syntax tokens corresponding to the passed source
-  /// locations.
+  /// Finds the syntax tokens corresponding to the \p SourceRange.
+  llvm::ArrayRef<syntax::Token> getRange(SourceRange Range) const {
+    assert(Range.isValid());
+    return getRange(Range.getBegin(), Range.getEnd());
+  }
+
+  /// Finds the syntax tokens corresponding to the passed source locations.
   /// \p First is the start position of the first token and \p Last is the start
   /// position of the last token.
   llvm::ArrayRef<syntax::Token> getRange(SourceLocation First,
@@ -236,8 +281,7 @@ class syntax::TreeBuilder {
 
   llvm::ArrayRef<syntax::Token>
   getTemplateRange(const ClassTemplateSpecializationDecl *D) const {
-    auto R = D->getSourceRange();
-    auto Tokens = getRange(R.getBegin(), R.getEnd());
+    auto Tokens = getRange(D->getSourceRange());
     return maybeAppendSemicolon(Tokens, D);
   }
 
@@ -247,16 +291,16 @@ class syntax::TreeBuilder {
     if (const auto *S = llvm::dyn_cast<TagDecl>(D))
       Tokens = getRange(S->TypeDecl::getBeginLoc(), S->getEndLoc());
     else
-      Tokens = getRange(D->getBeginLoc(), D->getEndLoc());
+      Tokens = getRange(D->getSourceRange());
     return maybeAppendSemicolon(Tokens, D);
   }
   llvm::ArrayRef<syntax::Token> getExprRange(const Expr *E) const {
-    return getRange(E->getBeginLoc(), E->getEndLoc());
+    return getRange(E->getSourceRange());
   }
   /// Find the adjusted range for the statement, consuming the trailing
   /// semicolon when needed.
   llvm::ArrayRef<syntax::Token> getStmtRange(const Stmt *S) const {
-    auto Tokens = getRange(S->getBeginLoc(), S->getEndLoc());
+    auto Tokens = getRange(S->getSourceRange());
     if (isa<CompoundStmt>(S))
       return Tokens;
 
@@ -290,6 +334,11 @@ class syntax::TreeBuilder {
     return Tokens;
   }
 
+  void setRole(syntax::Node *N, NodeRole R) {
+    assert(N->role() == NodeRole::Detached);
+    N->setRole(R);
+  }
+
   /// A collection of trees covering the input tokens.
   /// When created, each tree corresponds to a single token in the file.
   /// Clients call 'foldChildren' to attach one or more subtrees to a parent
@@ -306,7 +355,7 @@ class syntax::TreeBuilder {
         auto *L = new (A.allocator()) syntax::Leaf(&T);
         L->Original = true;
         L->CanModify = A.tokenBuffer().spelledForExpanded(T).hasValue();
-        Trees.insert(Trees.end(), {&T, NodeAndRole{L}});
+        Trees.insert(Trees.end(), {&T, L});
       }
     }
 
@@ -338,7 +387,9 @@ class syntax::TreeBuilder {
       assert((std::next(It) == Trees.end() ||
               std::next(It)->first == Range.end()) &&
              "no child with the specified range");
-      It->second.Role = Role;
+      assert(It->second->role() == NodeRole::Detached &&
+             "re-assigning role for a child");
+      It->second->setRole(Role);
     }
 
     /// Add \p Node to the forest and attach child nodes based on \p Tokens.
@@ -394,7 +445,7 @@ class syntax::TreeBuilder {
     // EXPECTS: all tokens were consumed and are owned by a single root node.
     syntax::Node *finalize() && {
       assert(Trees.size() == 1);
-      auto *Root = Trees.begin()->second.Node;
+      auto *Root = Trees.begin()->second;
       Trees = {};
       return Root;
     }
@@ -408,9 +459,9 @@ class syntax::TreeBuilder {
                 : A.tokenBuffer().expandedTokens().end() - It->first;
 
         R += std::string(llvm::formatv(
-            "- '{0}' covers '{1}'+{2} tokens\n", It->second.Node->kind(),
+            "- '{0}' covers '{1}'+{2} tokens\n", It->second->kind(),
             It->first->text(A.sourceManager()), CoveredTokens));
-        R += It->second.Node->dump(A);
+        R += It->second->dump(A);
       }
       return R;
     }
@@ -434,32 +485,25 @@ class syntax::TreeBuilder {
           "fold crosses boundaries of existing subtrees");
 
       // We need to go in reverse order, because we can only prepend.
-      for (auto It = EndChildren; It != BeginChildren; --It)
-        Node->prependChildLowLevel(std::prev(It)->second.Node,
-                                   std::prev(It)->second.Role);
+      for (auto It = EndChildren; It != BeginChildren; --It) {
+        auto *C = std::prev(It)->second;
+        if (C->role() == NodeRole::Detached)
+          C->setRole(NodeRole::Unknown);
+        Node->prependChildLowLevel(C);
+      }
 
       // Mark that this node came from the AST and is backed by the source code.
       Node->Original = true;
       Node->CanModify = A.tokenBuffer().spelledForExpanded(Tokens).hasValue();
 
       Trees.erase(BeginChildren, EndChildren);
-      Trees.insert({FirstToken, NodeAndRole(Node)});
+      Trees.insert({FirstToken, Node});
     }
-    /// A with a role that should be assigned to it when adding to a parent.
-    struct NodeAndRole {
-      explicit NodeAndRole(syntax::Node *Node)
-          : Node(Node), Role(NodeRole::Unknown) {}
-
-      syntax::Node *Node;
-      NodeRole Role;
-    };
 
     /// Maps from the start token to a subtree starting at that token.
     /// Keys in the map are pointers into the array of expanded tokens, so
     /// pointer order corresponds to the order of preprocessor tokens.
-    /// FIXME: storing the end tokens is redundant.
-    /// FIXME: the key of a map is redundant, it is also stored in NodeForRange.
-    std::map<const syntax::Token *, NodeAndRole> Trees;
+    std::map<const syntax::Token *, syntax::Node *> Trees;
 
     /// See documentation of `foldChildrenDelayed` for details.
     struct DelayedFold {
@@ -479,6 +523,7 @@ class syntax::TreeBuilder {
       LocationToToken;
   Forest Pending;
   llvm::DenseSet<Decl *> DeclsWithoutSemicolons;
+  ASTToSyntaxMapping Mapping;
 };
 
 namespace {
@@ -505,10 +550,9 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
         Builder.sourceManager(), DD->getTypeSourceInfo()->getTypeLoc(),
         getQualifiedNameStart(DD), Initializer);
     if (Declarator.isValid()) {
-      auto Tokens =
-          Builder.getRange(Declarator.getBegin(), Declarator.getEnd());
-      Builder.foldNode(Tokens, new (allocator()) syntax::SimpleDeclarator);
-      Builder.markChild(Tokens, syntax::NodeRole::SimpleDeclaration_declarator);
+      auto *N = new (allocator()) syntax::SimpleDeclarator;
+      Builder.foldNode(Builder.getRange(Declarator), N, DD);
+      Builder.markChild(N, syntax::NodeRole::SimpleDeclaration_declarator);
     }
 
     return true;
@@ -522,9 +566,9 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
         Builder.sourceManager(), D->getTypeSourceInfo()->getTypeLoc(),
         /*Name=*/D->getLocation(), /*Initializer=*/SourceRange());
     if (R.isValid()) {
-      auto Tokens = Builder.getRange(R.getBegin(), R.getEnd());
-      Builder.foldNode(Tokens, new (allocator()) syntax::SimpleDeclarator);
-      Builder.markChild(Tokens, syntax::NodeRole::SimpleDeclaration_declarator);
+      auto *N = new (allocator()) syntax::SimpleDeclarator;
+      Builder.foldNode(Builder.getRange(R), N, D);
+      Builder.markChild(N, syntax::NodeRole::SimpleDeclaration_declarator);
     }
     return true;
   }
@@ -532,7 +576,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
   bool VisitDecl(Decl *D) {
     assert(!D->isImplicit());
     Builder.foldNode(Builder.getDeclRange(D),
-                     new (allocator()) syntax::UnknownDeclaration());
+                     new (allocator()) syntax::UnknownDeclaration(), D);
     return true;
   }
 
@@ -545,11 +589,11 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
       return false;
     if (C->isExplicitSpecialization())
       return true; // we are only interested in explicit instantiations.
-    if (!WalkUpFromClassTemplateSpecializationDecl(C))
-      return false;
+    auto *Declaration =
+        cast<syntax::SimpleDeclaration>(handleFreeStandingTagDecl(C));
     foldExplicitTemplateInstantiation(
         Builder.getTemplateRange(C), Builder.findToken(C->getExternLoc()),
-        Builder.findToken(C->getTemplateKeywordLoc()), Builder.getDeclRange(C));
+        Builder.findToken(C->getTemplateKeywordLoc()), Declaration, C);
     return true;
   }
 
@@ -557,7 +601,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     foldTemplateDeclaration(
         Builder.getDeclRange(S),
         Builder.findToken(S->getTemplateParameters()->getTemplateLoc()),
-        Builder.getDeclRange(S->getTemplatedDecl()));
+        Builder.getDeclRange(S->getTemplatedDecl()), S);
     return true;
   }
 
@@ -567,24 +611,30 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
       assert(C->getNumTemplateParameterLists() == 0);
       return true;
     }
+    handleFreeStandingTagDecl(C);
+    return true;
+  }
+
+  syntax::Declaration *handleFreeStandingTagDecl(TagDecl *C) {
+    assert(C->isFreeStanding());
     // Class is a declaration specifier and needs a spanning declaration node.
     auto DeclarationRange = Builder.getDeclRange(C);
-    Builder.foldNode(DeclarationRange,
-                     new (allocator()) syntax::SimpleDeclaration);
+    syntax::Declaration *Result = new (allocator()) syntax::SimpleDeclaration;
+    Builder.foldNode(DeclarationRange, Result, nullptr);
 
     // Build TemplateDeclaration nodes if we had template parameters.
     auto ConsumeTemplateParameters = [&](const TemplateParameterList &L) {
       const auto *TemplateKW = Builder.findToken(L.getTemplateLoc());
       auto R = llvm::makeArrayRef(TemplateKW, DeclarationRange.end());
-      foldTemplateDeclaration(R, TemplateKW, DeclarationRange);
-
+      Result =
+          foldTemplateDeclaration(R, TemplateKW, DeclarationRange, nullptr);
       DeclarationRange = R;
     };
     if (auto *S = llvm::dyn_cast<ClassTemplatePartialSpecializationDecl>(C))
       ConsumeTemplateParameters(*S->getTemplateParameters());
     for (unsigned I = C->getNumTemplateParameterLists(); 0 < I; --I)
       ConsumeTemplateParameters(*C->getTemplateParameterList(I - 1));
-    return true;
+    return Result;
   }
 
   bool WalkUpFromTranslationUnitDecl(TranslationUnitDecl *TU) {
@@ -602,14 +652,14 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     Builder.markChildToken(S->getRBracLoc(), NodeRole::CloseParen);
 
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::CompoundStatement);
+                     new (allocator()) syntax::CompoundStatement, S);
     return true;
   }
 
   // Some statements are not yet handled by syntax trees.
   bool WalkUpFromStmt(Stmt *S) {
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::UnknownStatement);
+                     new (allocator()) syntax::UnknownStatement, S);
     return true;
   }
 
@@ -647,7 +697,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
   bool WalkUpFromExpr(Expr *E) {
     assert(!isImplicitExpr(E) && "should be handled by TraverseStmt");
     Builder.foldNode(Builder.getExprRange(E),
-                     new (allocator()) syntax::UnknownExpression);
+                     new (allocator()) syntax::UnknownExpression, E);
     return true;
   }
 
@@ -659,7 +709,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
       // FIXME: build corresponding nodes for the name of this namespace.
       return true;
     }
-    Builder.foldNode(Tokens, new (allocator()) syntax::NamespaceDefinition);
+    Builder.foldNode(Tokens, new (allocator()) syntax::NamespaceDefinition, S);
     return true;
   }
 
@@ -674,7 +724,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     Builder.markChildToken(L.getLParenLoc(), syntax::NodeRole::OpenParen);
     Builder.markChildToken(L.getRParenLoc(), syntax::NodeRole::CloseParen);
     Builder.foldNode(Builder.getRange(L.getLParenLoc(), L.getRParenLoc()),
-                     new (allocator()) syntax::ParenDeclarator);
+                     new (allocator()) syntax::ParenDeclarator, L);
     return true;
   }
 
@@ -685,7 +735,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
                           syntax::NodeRole::ArraySubscript_sizeExpression);
     Builder.markChildToken(L.getRBracketLoc(), syntax::NodeRole::CloseParen);
     Builder.foldNode(Builder.getRange(L.getLBracketLoc(), L.getRBracketLoc()),
-                     new (allocator()) syntax::ArraySubscript);
+                     new (allocator()) syntax::ArraySubscript, L);
     return true;
   }
 
@@ -697,7 +747,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
           syntax::NodeRole::ParametersAndQualifiers_parameter);
     Builder.markChildToken(L.getRParenLoc(), syntax::NodeRole::CloseParen);
     Builder.foldNode(Builder.getRange(L.getLParenLoc(), L.getEndLoc()),
-                     new (allocator()) syntax::ParametersAndQualifiers);
+                     new (allocator()) syntax::ParametersAndQualifiers, L);
     return true;
   }
 
@@ -714,8 +764,8 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
 
   bool WalkUpFromMemberPointerTypeLoc(MemberPointerTypeLoc L) {
     auto SR = L.getLocalSourceRange();
-    Builder.foldNode(Builder.getRange(SR.getBegin(), SR.getEnd()),
-                     new (allocator()) syntax::MemberPointer);
+    Builder.foldNode(Builder.getRange(SR),
+                     new (allocator()) syntax::MemberPointer, L);
     return true;
   }
 
@@ -724,13 +774,13 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
   // and fold resulting nodes.
   bool WalkUpFromDeclStmt(DeclStmt *S) {
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::DeclarationStatement);
+                     new (allocator()) syntax::DeclarationStatement, S);
     return true;
   }
 
   bool WalkUpFromNullStmt(NullStmt *S) {
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::EmptyStatement);
+                     new (allocator()) syntax::EmptyStatement, S);
     return true;
   }
 
@@ -739,7 +789,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
                            syntax::NodeRole::IntroducerKeyword);
     Builder.markStmtChild(S->getBody(), syntax::NodeRole::BodyStatement);
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::SwitchStatement);
+                     new (allocator()) syntax::SwitchStatement, S);
     return true;
   }
 
@@ -749,7 +799,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     Builder.markExprChild(S->getLHS(), syntax::NodeRole::CaseStatement_value);
     Builder.markStmtChild(S->getSubStmt(), syntax::NodeRole::BodyStatement);
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::CaseStatement);
+                     new (allocator()) syntax::CaseStatement, S);
     return true;
   }
 
@@ -758,7 +808,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
                            syntax::NodeRole::IntroducerKeyword);
     Builder.markStmtChild(S->getSubStmt(), syntax::NodeRole::BodyStatement);
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::DefaultStatement);
+                     new (allocator()) syntax::DefaultStatement, S);
     return true;
   }
 
@@ -771,7 +821,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     Builder.markStmtChild(S->getElse(),
                           syntax::NodeRole::IfStatement_elseStatement);
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::IfStatement);
+                     new (allocator()) syntax::IfStatement, S);
     return true;
   }
 
@@ -779,7 +829,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     Builder.markChildToken(S->getForLoc(), syntax::NodeRole::IntroducerKeyword);
     Builder.markStmtChild(S->getBody(), syntax::NodeRole::BodyStatement);
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::ForStatement);
+                     new (allocator()) syntax::ForStatement, S);
     return true;
   }
 
@@ -788,7 +838,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
                            syntax::NodeRole::IntroducerKeyword);
     Builder.markStmtChild(S->getBody(), syntax::NodeRole::BodyStatement);
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::WhileStatement);
+                     new (allocator()) syntax::WhileStatement, S);
     return true;
   }
 
@@ -796,7 +846,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     Builder.markChildToken(S->getContinueLoc(),
                            syntax::NodeRole::IntroducerKeyword);
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::ContinueStatement);
+                     new (allocator()) syntax::ContinueStatement, S);
     return true;
   }
 
@@ -804,7 +854,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     Builder.markChildToken(S->getBreakLoc(),
                            syntax::NodeRole::IntroducerKeyword);
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::BreakStatement);
+                     new (allocator()) syntax::BreakStatement, S);
     return true;
   }
 
@@ -814,7 +864,7 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     Builder.markExprChild(S->getRetValue(),
                           syntax::NodeRole::ReturnStatement_value);
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::ReturnStatement);
+                     new (allocator()) syntax::ReturnStatement, S);
     return true;
   }
 
@@ -822,13 +872,13 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     Builder.markChildToken(S->getForLoc(), syntax::NodeRole::IntroducerKeyword);
     Builder.markStmtChild(S->getBody(), syntax::NodeRole::BodyStatement);
     Builder.foldNode(Builder.getStmtRange(S),
-                     new (allocator()) syntax::RangeBasedForStatement);
+                     new (allocator()) syntax::RangeBasedForStatement, S);
     return true;
   }
 
   bool WalkUpFromEmptyDecl(EmptyDecl *S) {
     Builder.foldNode(Builder.getDeclRange(S),
-                     new (allocator()) syntax::EmptyDeclaration);
+                     new (allocator()) syntax::EmptyDeclaration, S);
     return true;
   }
 
@@ -838,55 +888,56 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     Builder.markExprChild(S->getMessage(),
                           syntax::NodeRole::StaticAssertDeclaration_message);
     Builder.foldNode(Builder.getDeclRange(S),
-                     new (allocator()) syntax::StaticAssertDeclaration);
+                     new (allocator()) syntax::StaticAssertDeclaration, S);
     return true;
   }
 
   bool WalkUpFromLinkageSpecDecl(LinkageSpecDecl *S) {
     Builder.foldNode(Builder.getDeclRange(S),
-                     new (allocator()) syntax::LinkageSpecificationDeclaration);
+                     new (allocator()) syntax::LinkageSpecificationDeclaration,
+                     S);
     return true;
   }
 
   bool WalkUpFromNamespaceAliasDecl(NamespaceAliasDecl *S) {
     Builder.foldNode(Builder.getDeclRange(S),
-                     new (allocator()) syntax::NamespaceAliasDefinition);
+                     new (allocator()) syntax::NamespaceAliasDefinition, S);
     return true;
   }
 
   bool WalkUpFromUsingDirectiveDecl(UsingDirectiveDecl *S) {
     Builder.foldNode(Builder.getDeclRange(S),
-                     new (allocator()) syntax::UsingNamespaceDirective);
+                     new (allocator()) syntax::UsingNamespaceDirective, S);
     return true;
   }
 
   bool WalkUpFromUsingDecl(UsingDecl *S) {
     Builder.foldNode(Builder.getDeclRange(S),
-                     new (allocator()) syntax::UsingDeclaration);
+                     new (allocator()) syntax::UsingDeclaration, S);
     return true;
   }
 
   bool WalkUpFromUnresolvedUsingValueDecl(UnresolvedUsingValueDecl *S) {
     Builder.foldNode(Builder.getDeclRange(S),
-                     new (allocator()) syntax::UsingDeclaration);
+                     new (allocator()) syntax::UsingDeclaration, S);
     return true;
   }
 
   bool WalkUpFromUnresolvedUsingTypenameDecl(UnresolvedUsingTypenameDecl *S) {
     Builder.foldNode(Builder.getDeclRange(S),
-                     new (allocator()) syntax::UsingDeclaration);
+                     new (allocator()) syntax::UsingDeclaration, S);
     return true;
   }
 
   bool WalkUpFromTypeAliasDecl(TypeAliasDecl *S) {
     Builder.foldNode(Builder.getDeclRange(S),
-                     new (allocator()) syntax::TypeAliasDeclaration);
+                     new (allocator()) syntax::TypeAliasDeclaration, S);
     return true;
   }
 
 private:
   /// Returns the range of the built node.
-  llvm::ArrayRef<syntax::Token> BuildTrailingReturn(FunctionProtoTypeLoc L) {
+  syntax::TrailingReturnType *BuildTrailingReturn(FunctionProtoTypeLoc L) {
     assert(L.getTypePtr()->hasTrailingReturn());
 
     auto ReturnedType = L.getReturnLoc();
@@ -895,33 +946,31 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
         getDeclaratorRange(this->Builder.sourceManager(), ReturnedType,
                            /*Name=*/SourceLocation(),
                            /*Initializer=*/SourceLocation());
-    llvm::ArrayRef<syntax::Token> ReturnDeclaratorTokens;
+    syntax::SimpleDeclarator *ReturnDeclarator = nullptr;
     if (ReturnDeclaratorRange.isValid()) {
-      ReturnDeclaratorTokens = Builder.getRange(
-          ReturnDeclaratorRange.getBegin(), ReturnDeclaratorRange.getEnd());
-      Builder.foldNode(ReturnDeclaratorTokens,
-                       new (allocator()) syntax::SimpleDeclarator);
+      ReturnDeclarator = new (allocator()) syntax::SimpleDeclarator;
+      Builder.foldNode(Builder.getRange(ReturnDeclaratorRange),
+                       ReturnDeclarator, nullptr);
     }
 
     // Build node for trailing return type.
-    auto Return =
-        Builder.getRange(ReturnedType.getBeginLoc(), ReturnedType.getEndLoc());
+    auto Return = Builder.getRange(ReturnedType.getSourceRange());
     const auto *Arrow = Return.begin() - 1;
     assert(Arrow->kind() == tok::arrow);
     auto Tokens = llvm::makeArrayRef(Arrow, Return.end());
     Builder.markChildToken(Arrow, syntax::NodeRole::TrailingReturnType_arrow);
-    if (!ReturnDeclaratorTokens.empty())
-      Builder.markChild(ReturnDeclaratorTokens,
+    if (ReturnDeclarator)
+      Builder.markChild(ReturnDeclarator,
                         syntax::NodeRole::TrailingReturnType_declarator);
-    Builder.foldNode(Tokens, new (allocator()) syntax::TrailingReturnType);
-    return Tokens;
+    auto *R = new (allocator()) syntax::TrailingReturnType;
+    Builder.foldNode(Tokens, R, nullptr);
+    return R;
   }
 
-  void
-  foldExplicitTemplateInstantiation(ArrayRef<syntax::Token> Range,
-                                    const syntax::Token *ExternKW,
-                                    const syntax::Token *TemplateKW,
-                                    ArrayRef<syntax::Token> InnerDeclaration) {
+  void foldExplicitTemplateInstantiation(
+      ArrayRef<syntax::Token> Range, const syntax::Token *ExternKW,
+      const syntax::Token *TemplateKW,
+      syntax::SimpleDeclaration *InnerDeclaration, Decl *From) {
     assert(!ExternKW || ExternKW->kind() == tok::kw_extern);
     assert(TemplateKW && TemplateKW->kind() == tok::kw_template);
     Builder.markChildToken(
@@ -931,19 +980,22 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
     Builder.markChild(
         InnerDeclaration,
         syntax::NodeRole::ExplicitTemplateInstantiation_declaration);
-    Builder.foldNode(Range,
-                     new (allocator()) syntax::ExplicitTemplateInstantiation);
+    Builder.foldNode(
+        Range, new (allocator()) syntax::ExplicitTemplateInstantiation, From);
   }
 
-  void foldTemplateDeclaration(ArrayRef<syntax::Token> Range,
-                               const syntax::Token *TemplateKW,
-                               ArrayRef<syntax::Token> TemplatedDeclaration) {
+  syntax::TemplateDeclaration *foldTemplateDeclaration(
+      ArrayRef<syntax::Token> Range, const syntax::Token *TemplateKW,
+      ArrayRef<syntax::Token> TemplatedDeclaration, Decl *From) {
     assert(TemplateKW && TemplateKW->kind() == tok::kw_template);
     Builder.markChildToken(TemplateKW, syntax::NodeRole::IntroducerKeyword);
     Builder.markMaybeDelayedChild(
         TemplatedDeclaration,
         syntax::NodeRole::TemplateDeclaration_declaration);
-    Builder.foldNode(Range, new (allocator()) syntax::TemplateDeclaration);
+
+    auto *N = new (allocator()) syntax::TemplateDeclaration;
+    Builder.foldNode(Range, N, From);
+    return N;
   }
 
   /// A small helper to save some typing.
@@ -954,11 +1006,6 @@ class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
 };
 } // namespace
 
-void syntax::TreeBuilder::foldNode(llvm::ArrayRef<syntax::Token> Range,
-                                   syntax::Tree *New) {
-  Pending.foldChildren(Arena, Range, New);
-}
-
 void syntax::TreeBuilder::noticeDeclRange(llvm::ArrayRef<syntax::Token> Range) {
   if (Pending.extendDelayedFold(Range))
     return;
@@ -982,9 +1029,15 @@ void syntax::TreeBuilder::markChildToken(const syntax::Token *T, NodeRole R) {
   Pending.assignRole(*T, R);
 }
 
-void syntax::TreeBuilder::markChild(llvm::ArrayRef<syntax::Token> Range,
-                                    NodeRole R) {
-  Pending.assignRole(Range, R);
+void syntax::TreeBuilder::markChild(syntax::Node *N, NodeRole R) {
+  assert(N);
+  setRole(N, R);
+}
+
+void syntax::TreeBuilder::markChild(ASTPtr N, NodeRole R) {
+  auto *SN = Mapping.find(N);
+  assert(SN != nullptr);
+  setRole(SN, R);
 }
 
 void syntax::TreeBuilder::markDelayedChild(llvm::ArrayRef<syntax::Token> Range,
@@ -1001,24 +1054,28 @@ void syntax::TreeBuilder::markStmtChild(Stmt *Child, NodeRole Role) {
   if (!Child)
     return;
 
-  auto Range = getStmtRange(Child);
+  syntax::Tree *ChildNode = Mapping.find(Child);
+  assert(ChildNode != nullptr);
+
   // This is an expression in a statement position, consume the trailing
   // semicolon and form an 'ExpressionStatement' node.
   if (auto *E = dyn_cast<Expr>(Child)) {
-    Pending.assignRole(getExprRange(E),
-                       NodeRole::ExpressionStatement_expression);
-    // 'getRange(Stmt)' ensures this already covers a trailing semicolon.
-    Pending.foldChildren(Arena, Range,
-                         new (allocator()) syntax::ExpressionStatement);
+    setRole(ChildNode, NodeRole::ExpressionStatement_expression);
+    ChildNode = new (allocator()) syntax::ExpressionStatement;
+    // (!) 'getStmtRange()' ensures this covers a trailing semicolon.
+    Pending.foldChildren(Arena, getStmtRange(Child), ChildNode);
   }
-  Pending.assignRole(Range, Role);
+  setRole(ChildNode, Role);
 }
 
 void syntax::TreeBuilder::markExprChild(Expr *Child, NodeRole Role) {
   if (!Child)
     return;
+  Child = Child->IgnoreImplicit();
 
-  Pending.assignRole(getExprRange(Child), Role);
+  syntax::Tree *ChildNode = Mapping.find(Child);
+  assert(ChildNode != nullptr);
+  setRole(ChildNode, Role);
 }
 
 const syntax::Token *syntax::TreeBuilder::findToken(SourceLocation L) const {

diff  --git a/clang/lib/Tooling/Syntax/Mutations.cpp b/clang/lib/Tooling/Syntax/Mutations.cpp
index 72458528202e..24048b297a11 100644
--- a/clang/lib/Tooling/Syntax/Mutations.cpp
+++ b/clang/lib/Tooling/Syntax/Mutations.cpp
@@ -35,7 +35,7 @@ class syntax::MutationsImpl {
     assert(!New->isDetached());
     assert(Role != NodeRole::Detached);
 
-    New->Role = static_cast<unsigned>(Role);
+    New->setRole(Role);
     auto *P = Anchor->parent();
     P->replaceChildRangeLowLevel(Anchor, Anchor, New);
 

diff  --git a/clang/lib/Tooling/Syntax/Tree.cpp b/clang/lib/Tooling/Syntax/Tree.cpp
index 9a6270ec4cce..37579e6145b6 100644
--- a/clang/lib/Tooling/Syntax/Tree.cpp
+++ b/clang/lib/Tooling/Syntax/Tree.cpp
@@ -58,22 +58,33 @@ bool syntax::Leaf::classof(const Node *N) {
 
 syntax::Node::Node(NodeKind Kind)
     : Parent(nullptr), NextSibling(nullptr), Kind(static_cast<unsigned>(Kind)),
-      Role(static_cast<unsigned>(NodeRole::Detached)), Original(false),
-      CanModify(false) {}
+      Role(0), Original(false), CanModify(false) {
+  this->setRole(NodeRole::Detached);
+}
 
 bool syntax::Node::isDetached() const { return role() == NodeRole::Detached; }
 
+void syntax::Node::setRole(NodeRole NR) {
+  this->Role = static_cast<unsigned>(NR);
+}
+
 bool syntax::Tree::classof(const Node *N) { return N->kind() > NodeKind::Leaf; }
 
 void syntax::Tree::prependChildLowLevel(Node *Child, NodeRole Role) {
-  assert(Child->Parent == nullptr);
-  assert(Child->NextSibling == nullptr);
   assert(Child->role() == NodeRole::Detached);
   assert(Role != NodeRole::Detached);
 
+  Child->setRole(Role);
+  prependChildLowLevel(Child);
+}
+
+void syntax::Tree::prependChildLowLevel(Node *Child) {
+  assert(Child->Parent == nullptr);
+  assert(Child->NextSibling == nullptr);
+  assert(Child->role() != NodeRole::Detached);
+
   Child->Parent = this;
   Child->NextSibling = this->FirstChild;
-  Child->Role = static_cast<unsigned>(Role);
   this->FirstChild = Child;
 }
 
@@ -94,7 +105,7 @@ void syntax::Tree::replaceChildRangeLowLevel(Node *BeforeBegin, Node *End,
        N != End;) {
     auto *Next = N->NextSibling;
 
-    N->Role = static_cast<unsigned>(NodeRole::Detached);
+    N->setRole(NodeRole::Detached);
     N->Parent = nullptr;
     N->NextSibling = nullptr;
     if (N->Original)


        


More information about the cfe-commits mailing list