[clang] 0a717d5 - Make it possible to control matcher traversal kind with ASTContext

Stephen Kelly via cfe-commits cfe-commits at lists.llvm.org
Fri Dec 6 15:12:16 PST 2019


Author: Stephen Kelly
Date: 2019-12-06T23:11:32Z
New Revision: 0a717d5b5d31fc2d5bc98ca695031fb09e65beb0

URL: https://github.com/llvm/llvm-project/commit/0a717d5b5d31fc2d5bc98ca695031fb09e65beb0
DIFF: https://github.com/llvm/llvm-project/commit/0a717d5b5d31fc2d5bc98ca695031fb09e65beb0.diff

LOG: Make it possible to control matcher traversal kind with ASTContext

Summary:
This will eventually allow traversal of an AST while ignoring invisible
AST nodes.  Currently, which nodes are ignored depends on the available
TraversalKind enum values (at present only implicit casts and
parentheses).  The enum can be extended to ignore all invisible nodes in
the future.

Reviewers: klimek, aaron.ballman

Subscribers: cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D61837

Added: 
    

Modified: 
    clang/include/clang/AST/ASTContext.h
    clang/include/clang/AST/ASTNodeTraverser.h
    clang/include/clang/ASTMatchers/ASTMatchers.h
    clang/include/clang/ASTMatchers/ASTMatchersInternal.h
    clang/lib/AST/ASTContext.cpp
    clang/lib/ASTMatchers/ASTMatchFinder.cpp
    clang/lib/ASTMatchers/ASTMatchersInternal.cpp
    clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index a0484509fa4a..4ba54fa51885 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -569,7 +569,17 @@ class ASTContext : public RefCountedBase<ASTContext> {
   clang::PrintingPolicy PrintingPolicy;
   std::unique_ptr<interp::Context> InterpContext;
 
+  ast_type_traits::TraversalKind Traversal = ast_type_traits::TK_AsIs;
+
 public:
+  ast_type_traits::TraversalKind getTraversalKind() const { return Traversal; }
+  void setTraversalKind(ast_type_traits::TraversalKind TK) { Traversal = TK; }
+
+  const Expr *traverseIgnored(const Expr *E) const;
+  Expr *traverseIgnored(Expr *E) const;
+  ast_type_traits::DynTypedNode
+  traverseIgnored(const ast_type_traits::DynTypedNode &N) const;
+
   IdentifierTable &Idents;
   SelectorTable &Selectors;
   Builtin::Context &BuiltinInfo;
@@ -2996,7 +3006,7 @@ OPT_LIST(V)
 
   std::vector<Decl *> TraversalScope;
   class ParentMap;
-  std::unique_ptr<ParentMap> Parents;
+  std::map<ast_type_traits::TraversalKind, std::unique_ptr<ParentMap>> Parents;
 
   std::unique_ptr<VTableContextBase> VTContext;
 
@@ -3040,6 +3050,22 @@ inline Selector GetUnarySelector(StringRef name, ASTContext &Ctx) {
   return Ctx.Selectors.getSelector(1, &II);
 }
 
+class TraversalKindScope {
+  ASTContext &Ctx;
+  ast_type_traits::TraversalKind TK = ast_type_traits::TK_AsIs;
+
+public:
+  TraversalKindScope(ASTContext &Ctx,
+                     llvm::Optional<ast_type_traits::TraversalKind> ScopeTK)
+      : Ctx(Ctx) {
+    TK = Ctx.getTraversalKind();
+    if (ScopeTK)
+      Ctx.setTraversalKind(*ScopeTK);
+  }
+
+  ~TraversalKindScope() { Ctx.setTraversalKind(TK); }
+};
+
 } // namespace clang
 
 // operator new and delete aren't allowed inside namespaces.

diff  --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h
index ed9fc14aba42..b2e6d9e9c5e4 100644
--- a/clang/include/clang/AST/ASTNodeTraverser.h
+++ b/clang/include/clang/AST/ASTNodeTraverser.h
@@ -65,6 +65,9 @@ class ASTNodeTraverser
   /// not already been loaded.
   bool Deserialize = false;
 
+  ast_type_traits::TraversalKind Traversal =
+      ast_type_traits::TraversalKind::TK_AsIs;
+
   NodeDelegateType &getNodeDelegate() {
     return getDerived().doGetNodeDelegate();
   }
@@ -74,6 +77,8 @@ class ASTNodeTraverser
   void setDeserialize(bool D) { Deserialize = D; }
   bool getDeserialize() const { return Deserialize; }
 
+  void SetTraversalKind(ast_type_traits::TraversalKind TK) { Traversal = TK; }
+
   void Visit(const Decl *D) {
     getNodeDelegate().AddChild([=] {
       getNodeDelegate().Visit(D);
@@ -97,8 +102,20 @@ class ASTNodeTraverser
     });
   }
 
-  void Visit(const Stmt *S, StringRef Label = {}) {
+  void Visit(const Stmt *Node, StringRef Label = {}) {
     getNodeDelegate().AddChild(Label, [=] {
+      const Stmt *S = Node;
+
+      if (auto *E = dyn_cast_or_null<Expr>(S)) {
+        switch (Traversal) {
+        case ast_type_traits::TK_AsIs:
+          break;
+        case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses:
+          S = E->IgnoreParenImpCasts();
+          break;
+        }
+      }
+
       getNodeDelegate().Visit(S);
 
       if (!S) {

diff  --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h
index 3a5d0c08337e..608454631556 100644
--- a/clang/include/clang/ASTMatchers/ASTMatchers.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -689,6 +689,31 @@ AST_POLYMORPHIC_MATCHER_P(
                              Builder);
 }
 
+/// Causes all nested matchers to be matched with the specified traversal kind.
+///
+/// Given
+/// \code
+///   void foo()
+///   {
+///       int i = 3.0;
+///   }
+/// \endcode
+/// The matcher
+/// \code
+///   traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses,
+///     varDecl(hasInitializer(floatLiteral().bind("init")))
+///   )
+/// \endcode
+/// matches the variable declaration with "init" bound to the "3.0".
+template <typename T>
+internal::Matcher<T> traverse(ast_type_traits::TraversalKind TK,
+                              const internal::Matcher<T> &InnerMatcher) {
+  return internal::DynTypedMatcher::constructRestrictedWrapper(
+             new internal::TraversalMatcher<T>(TK, InnerMatcher),
+             InnerMatcher.getID().first)
+      .template unconditionalConvertTo<T>();
+}
+
 /// Matches expressions that match InnerMatcher after any implicit AST
 /// nodes are stripped off.
 ///

diff  --git a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
index e9fa920b6bce..47fb5a79ffe6 100644
--- a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
@@ -283,6 +283,10 @@ class DynMatcherInterface
   virtual bool dynMatches(const ast_type_traits::DynTypedNode &DynNode,
                           ASTMatchFinder *Finder,
                           BoundNodesTreeBuilder *Builder) const = 0;
+
+  virtual llvm::Optional<ast_type_traits::TraversalKind> TraversalKind() const {
+    return llvm::None;
+  }
 };
 
 /// Generic interface for matchers on an AST node of type T.
@@ -371,6 +375,10 @@ class DynTypedMatcher {
                     ast_type_traits::ASTNodeKind SupportedKind,
                     std::vector<DynTypedMatcher> InnerMatchers);
 
+  static DynTypedMatcher
+  constructRestrictedWrapper(const DynTypedMatcher &InnerMatcher,
+                             ast_type_traits::ASTNodeKind RestrictKind);
+
   /// Get a "true" matcher for \p NodeKind.
   ///
   /// It only checks that the node is of the right kind.
@@ -1002,7 +1010,7 @@ class ASTMatchFinder {
                   std::is_base_of<QualType, T>::value,
                   "unsupported type for recursive matching");
     return matchesChildOf(ast_type_traits::DynTypedNode::create(Node),
-                          Matcher, Builder, Traverse, Bind);
+                          getASTContext(), Matcher, Builder, Traverse, Bind);
   }
 
   template <typename T>
@@ -1018,7 +1026,7 @@ class ASTMatchFinder {
                   std::is_base_of<QualType, T>::value,
                   "unsupported type for recursive matching");
     return matchesDescendantOf(ast_type_traits::DynTypedNode::create(Node),
-                               Matcher, Builder, Bind);
+                               getASTContext(), Matcher, Builder, Bind);
   }
 
   // FIXME: Implement support for BindKind.
@@ -1033,24 +1041,26 @@ class ASTMatchFinder {
                       std::is_base_of<TypeLoc, T>::value,
                   "type not allowed for recursive matching");
     return matchesAncestorOf(ast_type_traits::DynTypedNode::create(Node),
-                             Matcher, Builder, MatchMode);
+                             getASTContext(), Matcher, Builder, MatchMode);
   }
 
   virtual ASTContext &getASTContext() const = 0;
 
 protected:
   virtual bool matchesChildOf(const ast_type_traits::DynTypedNode &Node,
-                              const DynTypedMatcher &Matcher,
+                              ASTContext &Ctx, const DynTypedMatcher &Matcher,
                               BoundNodesTreeBuilder *Builder,
                               ast_type_traits::TraversalKind Traverse,
                               BindKind Bind) = 0;
 
   virtual bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node,
+                                   ASTContext &Ctx,
                                    const DynTypedMatcher &Matcher,
                                    BoundNodesTreeBuilder *Builder,
                                    BindKind Bind) = 0;
 
   virtual bool matchesAncestorOf(const ast_type_traits::DynTypedNode &Node,
+                                 ASTContext &Ctx,
                                  const DynTypedMatcher &Matcher,
                                  BoundNodesTreeBuilder *Builder,
                                  AncestorMatchMode MatchMode) = 0;
@@ -1162,6 +1172,28 @@ struct ArgumentAdaptingMatcherFunc {
   }
 };
 
+template <typename T>
+class TraversalMatcher : public WrapperMatcherInterface<T> {
+  ast_type_traits::TraversalKind Traversal;
+
+public:
+  explicit TraversalMatcher(ast_type_traits::TraversalKind TK,
+                            const Matcher<T> &ChildMatcher)
+      : TraversalMatcher::WrapperMatcherInterface(ChildMatcher), Traversal(TK) {
+  }
+
+  bool matches(const T &Node, ASTMatchFinder *Finder,
+               BoundNodesTreeBuilder *Builder) const override {
+    return this->InnerMatcher.matches(
+        ast_type_traits::DynTypedNode::create(Node), Finder, Builder);
+  }
+
+  llvm::Optional<ast_type_traits::TraversalKind>
+  TraversalKind() const override {
+    return Traversal;
+  }
+};
+
 /// A PolymorphicMatcherWithParamN<MatcherT, P1, ..., PN> object can be
 /// created from N parameters p1, ..., pN (of type P1, ..., PN) and
 /// used as a Matcher<T> where a MatcherT<T, P1, ..., PN>(p1, ..., pN)

diff  --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 10b6718c89b9..d6010caa4a2a 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -99,6 +99,30 @@ using namespace clang;
 enum FloatingRank {
   Float16Rank, HalfRank, FloatRank, DoubleRank, LongDoubleRank, Float128Rank
 };
+const Expr *ASTContext::traverseIgnored(const Expr *E) const {
+  return traverseIgnored(const_cast<Expr *>(E));
+}
+
+Expr *ASTContext::traverseIgnored(Expr *E) const {
+  if (!E)
+    return nullptr;
+
+  switch (Traversal) {
+  case ast_type_traits::TK_AsIs:
+    return E;
+  case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses:
+    return E->IgnoreParenImpCasts();
+  }
+  llvm_unreachable("Invalid Traversal type!");
+}
+
+ast_type_traits::DynTypedNode
+ASTContext::traverseIgnored(const ast_type_traits::DynTypedNode &N) const {
+  if (const auto *E = N.get<Expr>()) {
+    return ast_type_traits::DynTypedNode::create(*traverseIgnored(E));
+  }
+  return N;
+}
 
 /// \returns location that is relevant when searching for Doc comments related
 /// to \p D.
@@ -959,7 +983,7 @@ class ASTContext::ParentMap {
 
 void ASTContext::setTraversalScope(const std::vector<Decl *> &TopLevelDecls) {
   TraversalScope = TopLevelDecls;
-  Parents.reset();
+  Parents.clear();
 }
 
 void ASTContext::AddDeallocation(void (*Callback)(void *), void *Data) const {
@@ -10397,7 +10421,8 @@ createDynTypedNode(const NestedNameSpecifierLoc &Node) {
 class ASTContext::ParentMap::ASTVisitor
     : public RecursiveASTVisitor<ASTVisitor> {
 public:
-  ASTVisitor(ParentMap &Map) : Map(Map) {}
+  ASTVisitor(ParentMap &Map, ASTContext &Context)
+      : Map(Map), Context(Context) {}
 
 private:
   friend class RecursiveASTVisitor<ASTVisitor>;
@@ -10467,9 +10492,12 @@ class ASTContext::ParentMap::ASTVisitor
   }
 
   bool TraverseStmt(Stmt *StmtNode) {
-    return TraverseNode(
-        StmtNode, StmtNode, [&] { return VisitorBase::TraverseStmt(StmtNode); },
-        &Map.PointerParents);
+    Stmt *FilteredNode = StmtNode;
+    if (auto *ExprNode = dyn_cast_or_null<Expr>(FilteredNode))
+      FilteredNode = Context.traverseIgnored(ExprNode);
+    return TraverseNode(FilteredNode, FilteredNode,
+                        [&] { return VisitorBase::TraverseStmt(FilteredNode); },
+                        &Map.PointerParents);
   }
 
   bool TraverseTypeLoc(TypeLoc TypeLocNode) {
@@ -10487,20 +10515,22 @@ class ASTContext::ParentMap::ASTVisitor
   }
 
   ParentMap &Map;
+  ASTContext &Context;
   llvm::SmallVector<ast_type_traits::DynTypedNode, 16> ParentStack;
 };
 
 ASTContext::ParentMap::ParentMap(ASTContext &Ctx) {
-  ASTVisitor(*this).TraverseAST(Ctx);
+  ASTVisitor(*this, Ctx).TraverseAST(Ctx);
 }
 
 ASTContext::DynTypedNodeList
 ASTContext::getParents(const ast_type_traits::DynTypedNode &Node) {
-  if (!Parents)
+  std::unique_ptr<ParentMap> &P = Parents[Traversal];
+  if (!P)
     // We build the parent map for the traversal scope (usually whole TU), as
     // hasAncestor can escape any subtree.
-    Parents = std::make_unique<ParentMap>(*this);
-  return Parents->getParents(Node);
+    P = std::make_unique<ParentMap>(*this);
+  return P->getParents(Node);
 }
 
 bool

diff  --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp
index 8a103f3d89a3..8ac35d522843 100644
--- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp
+++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp
@@ -59,10 +59,12 @@ struct MatchKey {
   DynTypedMatcher::MatcherIDType MatcherID;
   ast_type_traits::DynTypedNode Node;
   BoundNodesTreeBuilder BoundNodes;
+  ast_type_traits::TraversalKind Traversal = ast_type_traits::TK_AsIs;
 
   bool operator<(const MatchKey &Other) const {
-    return std::tie(MatcherID, Node, BoundNodes) <
-           std::tie(Other.MatcherID, Other.Node, Other.BoundNodes);
+    return std::tie(MatcherID, Node, BoundNodes, Traversal) <
+           std::tie(Other.MatcherID, Other.Node, Other.BoundNodes,
+                    Other.Traversal);
   }
 };
 
@@ -143,6 +145,8 @@ class MatchChildASTVisitor
 
     ScopedIncrement ScopedDepth(&CurrentDepth);
     Stmt *StmtToTraverse = StmtNode;
+    if (auto *ExprNode = dyn_cast_or_null<Expr>(StmtNode))
+      StmtToTraverse = Finder->getASTContext().traverseIgnored(ExprNode);
     if (Traversal ==
         ast_type_traits::TraversalKind::TK_IgnoreImplicitCastsAndParentheses) {
       if (Expr *ExprNode = dyn_cast_or_null<Expr>(StmtNode))
@@ -390,6 +394,7 @@ class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
 
   // Matches children or descendants of 'Node' with 'BaseMatcher'.
   bool memoizedMatchesRecursively(const ast_type_traits::DynTypedNode &Node,
+                                  ASTContext &Ctx,
                                   const DynTypedMatcher &Matcher,
                                   BoundNodesTreeBuilder *Builder, int MaxDepth,
                                   ast_type_traits::TraversalKind Traversal,
@@ -404,6 +409,7 @@ class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
     Key.Node = Node;
     // Note that we key on the bindings *before* the match.
     Key.BoundNodes = *Builder;
+    Key.Traversal = Ctx.getTraversalKind();
 
     MemoizationMap::iterator I = ResultCache.find(Key);
     if (I != ResultCache.end()) {
@@ -446,36 +452,36 @@ class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
 
   // Implements ASTMatchFinder::matchesChildOf.
   bool matchesChildOf(const ast_type_traits::DynTypedNode &Node,
-                      const DynTypedMatcher &Matcher,
+                      ASTContext &Ctx, const DynTypedMatcher &Matcher,
                       BoundNodesTreeBuilder *Builder,
                       ast_type_traits::TraversalKind Traversal,
                       BindKind Bind) override {
     if (ResultCache.size() > MaxMemoizationEntries)
       ResultCache.clear();
-    return memoizedMatchesRecursively(Node, Matcher, Builder, 1, Traversal,
+    return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, 1, Traversal,
                                       Bind);
   }
   // Implements ASTMatchFinder::matchesDescendantOf.
   bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node,
-                           const DynTypedMatcher &Matcher,
+                           ASTContext &Ctx, const DynTypedMatcher &Matcher,
                            BoundNodesTreeBuilder *Builder,
                            BindKind Bind) override {
     if (ResultCache.size() > MaxMemoizationEntries)
       ResultCache.clear();
-    return memoizedMatchesRecursively(Node, Matcher, Builder, INT_MAX,
+    return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, INT_MAX,
                                       ast_type_traits::TraversalKind::TK_AsIs,
                                       Bind);
   }
   // Implements ASTMatchFinder::matchesAncestorOf.
   bool matchesAncestorOf(const ast_type_traits::DynTypedNode &Node,
-                         const DynTypedMatcher &Matcher,
+                         ASTContext &Ctx, const DynTypedMatcher &Matcher,
                          BoundNodesTreeBuilder *Builder,
                          AncestorMatchMode MatchMode) override {
     // Reset the cache outside of the recursive call to make sure we
     // don't invalidate any iterators.
     if (ResultCache.size() > MaxMemoizationEntries)
       ResultCache.clear();
-    return memoizedMatchesAncestorOfRecursively(Node, Matcher, Builder,
+    return memoizedMatchesAncestorOfRecursively(Node, Ctx, Matcher, Builder,
                                                 MatchMode);
   }
 
@@ -576,7 +582,7 @@ class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
       if (EnableCheckProfiling)
         Timer.setBucket(&TimeByBucket[MP.second->getID()]);
       BoundNodesTreeBuilder Builder;
-      if (MP.first.matchesNoKindCheck(DynNode, this, &Builder)) {
+      if (MP.first.matches(DynNode, this, &Builder)) {
         MatchVisitor Visitor(ActiveASTContext, MP.second);
         Builder.visitMatches(&Visitor);
       }
@@ -640,16 +646,19 @@ class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
   // allow simple memoization on the ancestors. Thus, we only memoize as long
   // as there is a single parent.
   bool memoizedMatchesAncestorOfRecursively(
-      const ast_type_traits::DynTypedNode &Node, const DynTypedMatcher &Matcher,
-      BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) {
+      const ast_type_traits::DynTypedNode &Node, ASTContext &Ctx,
+      const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder,
+      AncestorMatchMode MatchMode) {
     // For AST-nodes that don't have an identity, we can't memoize.
     if (!Builder->isComparable())
-      return matchesAncestorOfRecursively(Node, Matcher, Builder, MatchMode);
+      return matchesAncestorOfRecursively(Node, Ctx, Matcher, Builder,
+                                          MatchMode);
 
     MatchKey Key;
     Key.MatcherID = Matcher.getID();
     Key.Node = Node;
     Key.BoundNodes = *Builder;
+    Key.Traversal = Ctx.getTraversalKind();
 
     // Note that we cannot use insert and reuse the iterator, as recursive
     // calls to match might invalidate the result cache iterators.
@@ -661,8 +670,8 @@ class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
 
     MemoizedMatchResult Result;
     Result.Nodes = *Builder;
-    Result.ResultOfMatch =
-        matchesAncestorOfRecursively(Node, Matcher, &Result.Nodes, MatchMode);
+    Result.ResultOfMatch = matchesAncestorOfRecursively(
+        Node, Ctx, Matcher, &Result.Nodes, MatchMode);
 
     MemoizedMatchResult &CachedResult = ResultCache[Key];
     CachedResult = std::move(Result);
@@ -672,6 +681,7 @@ class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
   }
 
   bool matchesAncestorOfRecursively(const ast_type_traits::DynTypedNode &Node,
+                                    ASTContext &Ctx,
                                     const DynTypedMatcher &Matcher,
                                     BoundNodesTreeBuilder *Builder,
                                     AncestorMatchMode MatchMode) {
@@ -705,8 +715,8 @@ class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
         return true;
       }
       if (MatchMode != ASTMatchFinder::AMM_ParentOnly) {
-        return memoizedMatchesAncestorOfRecursively(Parent, Matcher, Builder,
-                                                    MatchMode);
+        return memoizedMatchesAncestorOfRecursively(Parent, Ctx, Matcher,
+                                                    Builder, MatchMode);
         // Once we get back from the recursive call, the result will be the
         // same as the parent's result.
       }
@@ -804,8 +814,6 @@ class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
   /// kind (and derived kinds) so it is a waste to try every matcher on every
   /// node.
   /// We precalculate a list of matchers that pass the toplevel restrict check.
-  /// This also allows us to skip the restrict check at matching time. See
-  /// use \c matchesNoKindCheck() above.
   llvm::DenseMap<ast_type_traits::ASTNodeKind, std::vector<unsigned short>>
       MatcherFiltersMap;
 

diff  --git a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
index 4ee32fbe94b1..efa628cfeefc 100644
--- a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
+++ b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
@@ -189,6 +189,14 @@ DynTypedMatcher DynTypedMatcher::constructVariadic(
   llvm_unreachable("Invalid Op value.");
 }
 
+DynTypedMatcher DynTypedMatcher::constructRestrictedWrapper(
+    const DynTypedMatcher &InnerMatcher,
+    ast_type_traits::ASTNodeKind RestrictKind) {
+  DynTypedMatcher Copy = InnerMatcher;
+  Copy.RestrictKind = RestrictKind;
+  return Copy;
+}
+
 DynTypedMatcher DynTypedMatcher::trueMatcher(
     ast_type_traits::ASTNodeKind NodeKind) {
   return DynTypedMatcher(NodeKind, NodeKind, &*TrueMatcherInstance);
@@ -211,8 +219,13 @@ DynTypedMatcher DynTypedMatcher::dynCastTo(
 bool DynTypedMatcher::matches(const ast_type_traits::DynTypedNode &DynNode,
                               ASTMatchFinder *Finder,
                               BoundNodesTreeBuilder *Builder) const {
-  if (RestrictKind.isBaseOf(DynNode.getNodeKind()) &&
-      Implementation->dynMatches(DynNode, Finder, Builder)) {
+  TraversalKindScope RAII(Finder->getASTContext(),
+                          Implementation->TraversalKind());
+
+  auto N = Finder->getASTContext().traverseIgnored(DynNode);
+
+  if (RestrictKind.isBaseOf(N.getNodeKind()) &&
+      Implementation->dynMatches(N, Finder, Builder)) {
     return true;
   }
   // Delete all bindings when a matcher does not match.
@@ -225,8 +238,13 @@ bool DynTypedMatcher::matches(const ast_type_traits::DynTypedNode &DynNode,
 bool DynTypedMatcher::matchesNoKindCheck(
     const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder,
     BoundNodesTreeBuilder *Builder) const {
-  assert(RestrictKind.isBaseOf(DynNode.getNodeKind()));
-  if (Implementation->dynMatches(DynNode, Finder, Builder)) {
+  TraversalKindScope raii(Finder->getASTContext(),
+                          Implementation->TraversalKind());
+
+  auto N = Finder->getASTContext().traverseIgnored(DynNode);
+
+  assert(RestrictKind.isBaseOf(N.getNodeKind()));
+  if (Implementation->dynMatches(N, Finder, Builder)) {
     return true;
   }
   // Delete all bindings when a matcher does not match.

diff  --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
index 693962919051..f67e4d965b5c 100644
--- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
+++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
@@ -1595,6 +1595,91 @@ TEST(IgnoringImplicit, DoesNotMatchIncorrectly) {
       notMatches("class C {}; C a = C();", varDecl(has(cxxConstructExpr()))));
 }
 
+TEST(Traversal, traverseMatcher) {
+
+  StringRef VarDeclCode = R"cpp(
+void foo()
+{
+  int i = 3.0;
+}
+)cpp";
+
+  auto Matcher = varDecl(hasInitializer(floatLiteral()));
+
+  EXPECT_TRUE(
+      notMatches(VarDeclCode, traverse(ast_type_traits::TK_AsIs, Matcher)));
+  EXPECT_TRUE(
+      matches(VarDeclCode,
+              traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses,
+                       Matcher)));
+}
+
+TEST(Traversal, traverseMatcherNesting) {
+
+  StringRef Code = R"cpp(
+float bar(int i)
+{
+  return i;
+}
+
+void foo()
+{
+  bar(bar(3.0));
+}
+)cpp";
+
+  EXPECT_TRUE(matches(
+      Code,
+      traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses,
+               callExpr(has(callExpr(traverse(
+                   ast_type_traits::TK_AsIs,
+                   callExpr(has(implicitCastExpr(has(floatLiteral())))))))))));
+}
+
+TEST(Traversal, traverseMatcherThroughImplicit) {
+  StringRef Code = R"cpp(
+struct S {
+  S(int x);
+};
+
+void constructImplicit() {
+  int a = 8;
+  S s(a);
+}
+  )cpp";
+
+  auto Matcher = traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses,
+                          implicitCastExpr());
+
+  // Verify that it does not segfault
+  EXPECT_FALSE(matches(Code, Matcher));
+}
+
+TEST(Traversal, traverseMatcherThroughMemoization) {
+
+  StringRef Code = R"cpp(
+void foo()
+{
+  int i = 3.0;
+}
+  )cpp";
+
+  auto Matcher = varDecl(hasInitializer(floatLiteral()));
+
+  // Matchers such as hasDescendant memoize their result regarding AST
+  // nodes. In the matcher below, the first use of hasDescendant(Matcher)
+  // fails, and the use of it inside the traverse() matcher should pass
+  // causing the overall matcher to be a true match.
+  // This test verifies that the first false result is not re-used, which
+  // would cause the overall matcher to be incorrectly false.
+
+  EXPECT_TRUE(matches(
+      Code, functionDecl(anyOf(
+                hasDescendant(Matcher),
+                traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses,
+                         functionDecl(hasDescendant(Matcher)))))));
+}
+
 TEST(IgnoringImpCasts, MatchesImpCasts) {
   // This test checks that ignoringImpCasts matches when implicit casts are
   // present and its inner matcher alone does not match.


        


More information about the cfe-commits mailing list