[clang] e4d5f00 - [ASTMatchers] Fix hasParent while ignoring unwritten nodes

Stephen Kelly via cfe-commits cfe-commits at lists.llvm.org
Thu Feb 18 07:04:12 PST 2021


Author: Stephen Kelly
Date: 2021-02-18T15:04:03Z
New Revision: e4d5f00093bec4099f1d0496181dc670c42ac220

URL: https://github.com/llvm/llvm-project/commit/e4d5f00093bec4099f1d0496181dc670c42ac220
DIFF: https://github.com/llvm/llvm-project/commit/e4d5f00093bec4099f1d0496181dc670c42ac220.diff

LOG: [ASTMatchers] Fix hasParent while ignoring unwritten nodes

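For example, before this patch we could use has() to get from a
cxxRewrittenBinaryOperator to its operand, but hasParent did not get
back to the cxxRewrittenBinaryOperator.  This patch fixes that, and
likewise makes hasParent skip the implicit intermediate nodes between
a cxxForRangeStmt or lambdaExpr and their spelled-in-source children.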

Differential Revision: https://reviews.llvm.org/D96113

Added: 
    

Modified: 
    clang/include/clang/AST/ParentMapContext.h
    clang/lib/AST/ParentMapContext.cpp
    clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/ParentMapContext.h b/clang/include/clang/AST/ParentMapContext.h
index a0412380a864..2edbc987850d 100644
--- a/clang/include/clang/AST/ParentMapContext.h
+++ b/clang/include/clang/AST/ParentMapContext.h
@@ -64,9 +64,10 @@ class ParentMapContext {
   Expr *traverseIgnored(Expr *E) const;
   DynTypedNode traverseIgnored(const DynTypedNode &N) const;
 
+  class ParentMap;
+
 private:
   ASTContext &ASTCtx;
-  class ParentMap;
   TraversalKind Traversal = TK_AsIs;
   std::unique_ptr<ParentMap> Parents;
 };

diff  --git a/clang/lib/AST/ParentMapContext.cpp b/clang/lib/AST/ParentMapContext.cpp
index cb4995312efa..4a3e0a99c8a6 100644
--- a/clang/lib/AST/ParentMapContext.cpp
+++ b/clang/lib/AST/ParentMapContext.cpp
@@ -49,7 +49,17 @@ DynTypedNode ParentMapContext::traverseIgnored(const DynTypedNode &N) const {
   return N;
 }
 
+template <typename T, typename... U>
+std::tuple<bool, DynTypedNodeList, const T *, const U *...>
+matchParents(const DynTypedNodeList &NodeList,
+             ParentMapContext::ParentMap *ParentMap);
+
+template <typename, typename...> struct MatchParents;
+
 class ParentMapContext::ParentMap {
+
+  template <typename, typename...> friend struct ::MatchParents;
+
   /// Contains parents of a node.
   using ParentVector = llvm::SmallVector<DynTypedNode, 2>;
 
@@ -117,11 +127,72 @@ class ParentMapContext::ParentMap {
     if (Node.getNodeKind().hasPointerIdentity()) {
       auto ParentList =
           getDynNodeFromMap(Node.getMemoizationData(), PointerParents);
-      if (ParentList.size() == 1 && TK == TK_IgnoreUnlessSpelledInSource) {
-        const auto *E = ParentList[0].get<Expr>();
-        const auto *Child = Node.get<Expr>();
-        if (E && Child)
-          return AscendIgnoreUnlessSpelledInSource(E, Child);
+      if (ParentList.size() > 0 && TK == TK_IgnoreUnlessSpelledInSource) {
+
+        const auto *ChildExpr = Node.get<Expr>();
+
+        {
+          // Don't match explicit node types because different stdlib
+          // implementations implement this in different ways and have
+          // different intermediate nodes.
+          // Look up 4 levels for a cxxRewrittenBinaryOperator as that is
+          // enough for the major stdlib implementations.
+          auto RewrittenBinOpParentsList = ParentList;
+          int I = 0;
+          while (ChildExpr && RewrittenBinOpParentsList.size() == 1 &&
+                 I++ < 4) {
+            const auto *S = RewrittenBinOpParentsList[0].get<Stmt>();
+            if (!S)
+              break;
+
+            const auto *RWBO = dyn_cast<CXXRewrittenBinaryOperator>(S);
+            if (!RWBO) {
+              RewrittenBinOpParentsList = getDynNodeFromMap(S, PointerParents);
+              continue;
+            }
+            if (RWBO->getLHS()->IgnoreUnlessSpelledInSource() != ChildExpr &&
+                RWBO->getRHS()->IgnoreUnlessSpelledInSource() != ChildExpr)
+              break;
+            return DynTypedNode::create(*RWBO);
+          }
+        }
+
+        const auto *ParentExpr = ParentList[0].get<Expr>();
+        if (ParentExpr && ChildExpr)
+          return AscendIgnoreUnlessSpelledInSource(ParentExpr, ChildExpr);
+
+        {
+          auto AncestorNodes =
+              matchParents<DeclStmt, CXXForRangeStmt>(ParentList, this);
+          if (std::get<bool>(AncestorNodes) &&
+              std::get<const CXXForRangeStmt *>(AncestorNodes)
+                      ->getLoopVarStmt() ==
+                  std::get<const DeclStmt *>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
+        {
+          auto AncestorNodes = matchParents<VarDecl, DeclStmt, CXXForRangeStmt>(
+              ParentList, this);
+          if (std::get<bool>(AncestorNodes) &&
+              std::get<const CXXForRangeStmt *>(AncestorNodes)
+                      ->getRangeStmt() ==
+                  std::get<const DeclStmt *>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
+        {
+          auto AncestorNodes =
+              matchParents<CXXMethodDecl, CXXRecordDecl, LambdaExpr>(ParentList,
+                                                                     this);
+          if (std::get<bool>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
+        {
+          auto AncestorNodes =
+              matchParents<FunctionTemplateDecl, CXXRecordDecl, LambdaExpr>(
+                  ParentList, this);
+          if (std::get<bool>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
       }
       return ParentList;
     }
@@ -194,6 +265,59 @@ class ParentMapContext::ParentMap {
   }
 };
 
+template <typename Tuple, std::size_t... Is>
+auto tuple_pop_front_impl(const Tuple &tuple, std::index_sequence<Is...>) {
+  return std::make_tuple(std::get<1 + Is>(tuple)...);
+}
+
+template <typename Tuple> auto tuple_pop_front(const Tuple &tuple) {
+  return tuple_pop_front_impl(
+      tuple, std::make_index_sequence<std::tuple_size<Tuple>::value - 1>());
+}
+
+template <typename T, typename... U> struct MatchParents {
+  static std::tuple<bool, DynTypedNodeList, const T *, const U *...>
+  match(const DynTypedNodeList &NodeList,
+        ParentMapContext::ParentMap *ParentMap) {
+    if (const auto *TypedNode = NodeList[0].get<T>()) {
+      auto NextParentList =
+          ParentMap->getDynNodeFromMap(TypedNode, ParentMap->PointerParents);
+      if (NextParentList.size() == 1) {
+        auto TailTuple = MatchParents<U...>::match(NextParentList, ParentMap);
+        if (std::get<bool>(TailTuple)) {
+          return std::tuple_cat(
+              std::make_tuple(true, std::get<DynTypedNodeList>(TailTuple),
+                              TypedNode),
+              tuple_pop_front(tuple_pop_front(TailTuple)));
+        }
+      }
+    }
+    return std::tuple_cat(std::make_tuple(false, NodeList),
+                          std::tuple<const T *, const U *...>());
+  }
+};
+
+template <typename T> struct MatchParents<T> {
+  static std::tuple<bool, DynTypedNodeList, const T *>
+  match(const DynTypedNodeList &NodeList,
+        ParentMapContext::ParentMap *ParentMap) {
+    if (const auto *TypedNode = NodeList[0].get<T>()) {
+      auto NextParentList =
+          ParentMap->getDynNodeFromMap(TypedNode, ParentMap->PointerParents);
+      if (NextParentList.size() == 1)
+        return std::make_tuple(true, NodeList, TypedNode);
+    }
+    return std::make_tuple(false, NodeList, nullptr);
+  }
+};
+
+template <typename T, typename... U>
+std::tuple<bool, DynTypedNodeList, const T *, const U *...>
+matchParents(const DynTypedNodeList &NodeList,
+             ParentMapContext::ParentMap *ParentMap) {
+  return MatchParents<T, U...>::match(NodeList, ParentMap);
+}
+
 /// Template specializations to abstract away from pointers and TypeLocs.
 /// @{
 template <typename T> static DynTypedNode createDynTypedNode(const T &Node) {

diff  --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
index b756cf815aaf..8a6e94cf5624 100644
--- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
+++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
@@ -2933,6 +2933,37 @@ struct CtorInitsNonTrivial : NonTrivial
     EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
     EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
   }
+  {
+    auto M = ifStmt(hasParent(compoundStmt(hasParent(cxxForRangeStmt()))));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    auto M = cxxForRangeStmt(
+        has(varDecl(hasName("i"), hasParent(cxxForRangeStmt()))));
+    EXPECT_FALSE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    auto M = cxxForRangeStmt(hasDescendant(varDecl(
+        hasName("i"), hasParent(declStmt(hasParent(cxxForRangeStmt()))))));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    auto M = cxxForRangeStmt(hasRangeInit(declRefExpr(
+        to(varDecl(hasName("arr"))), hasParent(cxxForRangeStmt()))));
+    EXPECT_FALSE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+
+  {
+    auto M = cxxForRangeStmt(hasRangeInit(declRefExpr(
+        to(varDecl(hasName("arr"))), hasParent(varDecl(hasParent(declStmt(
+                                         hasParent(cxxForRangeStmt()))))))));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
 
   Code = R"cpp(
   struct Range {
@@ -3035,6 +3066,15 @@ struct CtorInitsNonTrivial : NonTrivial
         matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
                              true, {"-std=c++20"}));
   }
+  {
+    auto M = cxxForRangeStmt(hasInitStatement(declStmt(
+        hasSingleDecl(varDecl(hasName("a"))), hasParent(cxxForRangeStmt()))));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
+                             true, {"-std=c++20"}));
+  }
 
   Code = R"cpp(
   struct Range {
@@ -3511,6 +3551,20 @@ void func15() {
                            forFunction(functionDecl(hasName("func13"))))))),
       langCxx20OrLater()));
 
+  EXPECT_TRUE(matches(Code,
+                      traverse(TK_IgnoreUnlessSpelledInSource,
+                               compoundStmt(hasParent(lambdaExpr(forFunction(
+                                   functionDecl(hasName("func13"))))))),
+                      langCxx20OrLater()));
+
+  EXPECT_TRUE(matches(
+      Code,
+      traverse(TK_IgnoreUnlessSpelledInSource,
+               templateTypeParmDecl(hasName("TemplateType"),
+                                    hasParent(lambdaExpr(forFunction(
+                                        functionDecl(hasName("func14"))))))),
+      langCxx20OrLater()));
+
   EXPECT_TRUE(matches(
       Code,
       traverse(TK_IgnoreUnlessSpelledInSource,
@@ -3635,6 +3689,16 @@ void binop()
         matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
                              true, {"-std=c++20"}));
   }
+  {
+    auto M = cxxRewrittenBinaryOperator(
+        hasLHS(expr(hasParent(cxxRewrittenBinaryOperator()))),
+        hasRHS(expr(hasParent(cxxRewrittenBinaryOperator()))));
+    EXPECT_FALSE(
+        matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
+                             true, {"-std=c++20"}));
+  }
   {
     EXPECT_TRUE(matchesConditionally(
         Code,


        


More information about the cfe-commits mailing list