[clang] 7912508 - [ASTMatchers] Fix traversal below range-for elements

Stephen Kelly via cfe-commits cfe-commits at lists.llvm.org
Sat Jan 30 05:49:14 PST 2021


Author: Stephen Kelly
Date: 2021-01-30T13:47:14Z
New Revision: 79125085f16540579d27c7e4987f63eef9c4aa23

URL: https://github.com/llvm/llvm-project/commit/79125085f16540579d27c7e4987f63eef9c4aa23
DIFF: https://github.com/llvm/llvm-project/commit/79125085f16540579d27c7e4987f63eef9c4aa23.diff

LOG: [ASTMatchers] Fix traversal below range-for elements

Differential Revision: https://reviews.llvm.org/D95562

Added: 
    

Modified: 
    clang/lib/ASTMatchers/ASTMatchFinder.cpp
    clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp

Removed: 
    


################################################################################
diff --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp
index 5034203840fc..89e83ee61574 100644
--- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp
+++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp
@@ -243,10 +243,14 @@ class MatchChildASTVisitor
       return true;
     ScopedIncrement ScopedDepth(&CurrentDepth);
     if (auto *Init = Node->getInit())
-      if (!match(*Init))
+      if (!traverse(*Init))
         return false;
-    if (!match(*Node->getLoopVariable()) || !match(*Node->getRangeInit()) ||
-        !match(*Node->getBody()))
+    if (!match(*Node->getLoopVariable()))
+      return false;
+    if (match(*Node->getRangeInit()))
+      if (!VisitorBase::TraverseStmt(Node->getRangeInit()))
+        return false;
+    if (!match(*Node->getBody()))
       return false;
     return VisitorBase::TraverseStmt(Node->getBody());
   }
@@ -488,15 +492,21 @@ class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
 
   bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) {
     if (auto *RF = dyn_cast<CXXForRangeStmt>(S)) {
-      for (auto *SubStmt : RF->children()) {
-        if (SubStmt == RF->getInit() || SubStmt == RF->getLoopVarStmt() ||
-            SubStmt == RF->getRangeInit() || SubStmt == RF->getBody()) {
-          TraverseStmt(SubStmt, Queue);
-        } else {
-          ASTNodeNotSpelledInSourceScope RAII(this, true);
-          TraverseStmt(SubStmt, Queue);
+      {
+        ASTNodeNotAsIsSourceScope RAII(this, true);
+        TraverseStmt(RF->getInit());
+        // Don't traverse under the loop variable
+        match(*RF->getLoopVariable());
+        TraverseStmt(RF->getRangeInit());
+      }
+      {
+        ASTNodeNotSpelledInSourceScope RAII(this, true);
+        for (auto *SubStmt : RF->children()) {
+          if (SubStmt != RF->getBody())
+            TraverseStmt(SubStmt);
         }
       }
+      TraverseStmt(RF->getBody());
       return true;
     } else if (auto *RBO = dyn_cast<CXXRewrittenBinaryOperator>(S)) {
       {

diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
index a3a09c426673..cbea274cecc9 100644
--- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
+++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
@@ -2820,6 +2820,36 @@ struct CtorInitsNonTrivial : NonTrivial
     EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
   }
 
+  Code = R"cpp(
+  struct Range {
+    int* begin() const;
+    int* end() const;
+  };
+  Range getRange(int);
+
+  void rangeFor()
+  {
+    for (auto i : getRange(42))
+    {
+    }
+  }
+  )cpp";
+  {
+    auto M = integerLiteral(equals(42));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    auto M = callExpr(hasDescendant(integerLiteral(equals(42))));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    auto M = compoundStmt(hasDescendant(integerLiteral(equals(42))));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+
   Code = R"cpp(
   void rangeFor()
   {
@@ -2891,6 +2921,40 @@ struct CtorInitsNonTrivial : NonTrivial
         matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
                              true, {"-std=c++20"}));
   }
+
+  Code = R"cpp(
+  struct Range {
+    int* begin() const;
+    int* end() const;
+  };
+  Range getRange(int);
+
+  int getNum(int);
+
+  void rangeFor()
+  {
+    for (auto j = getNum(42); auto i : getRange(j))
+    {
+    }
+  }
+  )cpp";
+  {
+    auto M = integerLiteral(equals(42));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
+                             true, {"-std=c++20"}));
+  }
+  {
+    auto M = compoundStmt(hasDescendant(integerLiteral(equals(42))));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
+                             true, {"-std=c++20"}));
+  }
+
   Code = R"cpp(
 void hasDefaultArg(int i, int j = 0)
 {


        


More information about the cfe-commits mailing list