[clang] [ASTMatchers] forCallable should not erase binding on success (PR #89657)

via cfe-commits cfe-commits at lists.llvm.org
Wed Apr 24 01:03:42 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Marco Borgeaud (marco-antognini-sonarsource)

<details>
<summary>Changes</summary>

Do not erase the bindings held by `Builder` when the first check fails, because the match could still succeed on the second stack frame.

The problem was that `InnerMatcher.matches` erases the bindings when it returns false. The appropriate solution is to pass it a copy of the bindings and move that copy back into `Builder` only when the match succeeds, similar to what `matchesFirstInRange` does.

---
Full diff: https://github.com/llvm/llvm-project/pull/89657.diff


2 Files Affected:

- (modified) clang/include/clang/ASTMatchers/ASTMatchers.h (+12-4) 
- (modified) clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp (+130) 


``````````diff
diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h
index 8a2bbfff9e9e6b..bfa6b5fba830ee 100644
--- a/clang/include/clang/ASTMatchers/ASTMatchers.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -8363,20 +8363,28 @@ AST_MATCHER_P(Stmt, forCallable, internal::Matcher<Decl>, InnerMatcher) {
     const auto &CurNode = Stack.back();
     Stack.pop_back();
     if (const auto *FuncDeclNode = CurNode.get<FunctionDecl>()) {
-      if (InnerMatcher.matches(*FuncDeclNode, Finder, Builder)) {
+      BoundNodesTreeBuilder B = *Builder;
+      if (InnerMatcher.matches(*FuncDeclNode, Finder, &B)) {
+        *Builder = std::move(B);
         return true;
       }
     } else if (const auto *LambdaExprNode = CurNode.get<LambdaExpr>()) {
+      BoundNodesTreeBuilder B = *Builder;
       if (InnerMatcher.matches(*LambdaExprNode->getCallOperator(), Finder,
-                               Builder)) {
+                               &B)) {
+        *Builder = std::move(B);
         return true;
       }
     } else if (const auto *ObjCMethodDeclNode = CurNode.get<ObjCMethodDecl>()) {
-      if (InnerMatcher.matches(*ObjCMethodDeclNode, Finder, Builder)) {
+      BoundNodesTreeBuilder B = *Builder;
+      if (InnerMatcher.matches(*ObjCMethodDeclNode, Finder, &B)) {
+        *Builder = std::move(B);
         return true;
       }
     } else if (const auto *BlockDeclNode = CurNode.get<BlockDecl>()) {
-      if (InnerMatcher.matches(*BlockDeclNode, Finder, Builder)) {
+      BoundNodesTreeBuilder B = *Builder;
+      if (InnerMatcher.matches(*BlockDeclNode, Finder, &B)) {
+        *Builder = std::move(B);
         return true;
       }
     } else {
diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
index f198dc71eb8337..0a444e16b1c25e 100644
--- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
+++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
@@ -5979,6 +5979,37 @@ TEST(StatementMatcher, ForCallable) {
   EXPECT_TRUE(notMatches(CppString2,
                          returnStmt(forCallable(functionDecl(hasName("F"))))));
 
+  StringRef CodeWithDeepCallExpr = R"cpp(
+void Other();
+void Function() {
+  {
+    (
+      Other()
+    );
+  }
+}
+)cpp";
+  auto ForCallableFirst =
+      callExpr(forCallable(functionDecl(hasName("Function"))),
+               callee(functionDecl(hasName("Other")).bind("callee")))
+          .bind("call");
+  auto ForCallableSecond =
+      callExpr(callee(functionDecl(hasName("Other")).bind("callee")),
+               forCallable(functionDecl(hasName("Function"))))
+          .bind("call");
+  EXPECT_TRUE(matchAndVerifyResultTrue(
+      CodeWithDeepCallExpr, ForCallableFirst,
+      std::make_unique<VerifyIdIsBoundTo<CallExpr>>("call")));
+  EXPECT_TRUE(matchAndVerifyResultTrue(
+      CodeWithDeepCallExpr, ForCallableFirst,
+      std::make_unique<VerifyIdIsBoundTo<FunctionDecl>>("callee")));
+  EXPECT_TRUE(matchAndVerifyResultTrue(
+      CodeWithDeepCallExpr, ForCallableSecond,
+      std::make_unique<VerifyIdIsBoundTo<CallExpr>>("call")));
+  EXPECT_TRUE(matchAndVerifyResultTrue(
+      CodeWithDeepCallExpr, ForCallableSecond,
+      std::make_unique<VerifyIdIsBoundTo<FunctionDecl>>("callee")));
+
   // These tests are specific to forCallable().
   StringRef ObjCString1 = "@interface I"
                           "-(void) foo;"
@@ -6020,6 +6051,105 @@ TEST(StatementMatcher, ForCallable) {
       binaryOperator(forCallable(blockDecl()))));
 }
 
+namespace {
+class ForCallablePreservesBindingWithMultipleParentsTestCallback
+    : public BoundNodesCallback {
+public:
+  bool run(const BoundNodes *BoundNodes) override {
+    FunctionDecl const *FunDecl =
+        BoundNodes->getNodeAs<FunctionDecl>("funDecl");
+    // Validate test assumptions. This would be expressed as ASSERT_* in
+    // a TEST().
+    if (!FunDecl) {
+      EXPECT_TRUE(false && "Incorrect test setup");
+      return false;
+    }
+    auto const *FunDef = FunDecl->getDefinition();
+    if (!FunDef || !FunDef->getBody() ||
+        FunDef->getNameAsString() != "Function") {
+      EXPECT_TRUE(false && "Incorrect test setup");
+      return false;
+    }
+
+    ExpectCorrectResult(
+        "Baseline",
+        callExpr(callee(cxxMethodDecl().bind("callee"))).bind("call"), //
+        FunDecl);
+
+    ExpectCorrectResult("ForCallable first",
+                        callExpr(forCallable(equalsNode(FunDecl)),
+                                 callee(cxxMethodDecl().bind("callee")))
+                            .bind("call"),
+                        FunDecl);
+
+    ExpectCorrectResult("ForCallable second",
+                        callExpr(callee(cxxMethodDecl().bind("callee")),
+                                 forCallable(equalsNode(FunDecl)))
+                            .bind("call"),
+                        FunDecl);
+
+    // This value does not really matter: the EXPECT_* will set the exit code.
+    return true;
+  }
+
+  bool run(const BoundNodes *BoundNodes, ASTContext *Context) override {
+    return run(BoundNodes);
+  }
+
+private:
+  void ExpectCorrectResult(StringRef LogInfo,
+                           ArrayRef<BoundNodes> Results) const {
+    EXPECT_EQ(Results.size(), 1u) << LogInfo;
+    if (Results.empty())
+      return;
+    auto const &R = Results.front();
+    EXPECT_TRUE(R.getNodeAs<CallExpr>("call")) << LogInfo;
+    EXPECT_TRUE(R.getNodeAs<CXXMethodDecl>("callee")) << LogInfo;
+  }
+
+  template <typename MatcherT>
+  void ExpectCorrectResult(StringRef LogInfo, MatcherT Matcher,
+                           FunctionDecl const *FunDef) const {
+    auto &Context = FunDef->getASTContext();
+    auto const &Results = match(findAll(Matcher), *FunDef->getBody(), Context);
+    ExpectCorrectResult(LogInfo, Results);
+  }
+};
+} // namespace
+
+TEST(StatementMatcher, ForCallablePreservesBindingWithMultipleParents) {
+  // Tests in this file are fairly simple and therefore can rely on matches,
+  // matchAndVerifyResultTrue, etc. This test, however, needs a FunctionDecl* to
+  // pass to equalsNode so that it can reproduce the observed issue (bindings
+  // being removed despite forCallable matching the node).
+  //
+  // Because of this and because the machinery to compile the code into an
+  // ASTUnit is not exposed outside matchAndVerifyResultConditionally, it is
+  // cheaper to have a custom BoundNodesCallback for the purpose of this test.
+  StringRef codeWithTemplateFunction = R"cpp(
+struct Klass {
+  void Method();
+  template <typename T>
+  void Function(T t); // Declaration
+};
+
+void Instantiate(Klass k) {
+  k.Function(0);
+}
+
+template <typename T>
+void Klass::Function(T t) { // Definition
+  // Compound statement has two parents: the declaration and the definition.
+  Method();
+}
+)cpp";
+  EXPECT_TRUE(matchAndVerifyResultTrue(
+      codeWithTemplateFunction,
+      callExpr(callee(functionDecl(hasName("Function")).bind("funDecl"))),
+      std::make_unique<
+          ForCallablePreservesBindingWithMultipleParentsTestCallback>()));
+}
+
 TEST(Matcher, ForEachOverriden) {
   const auto ForEachOverriddenInClass = [](const char *ClassName) {
     return cxxMethodDecl(ofClass(hasName(ClassName)), isVirtual(),

``````````

</details>


https://github.com/llvm/llvm-project/pull/89657


More information about the cfe-commits mailing list