[clang] [ASTMatchers] forCallable should not erase binding on success (PR #89657)

Marco Borgeaud via cfe-commits cfe-commits at lists.llvm.org
Wed Apr 24 22:39:26 PDT 2024


https://github.com/marco-antognini-sonarsource updated https://github.com/llvm/llvm-project/pull/89657

>From ebc417fe98f1cb0e030ec77c17c0150c3fcca7f9 Mon Sep 17 00:00:00 2001
From: Marco Borgeaud
 <89914223+marco-antognini-sonarsource at users.noreply.github.com>
Date: Fri, 19 Apr 2024 17:33:22 +0200
Subject: [PATCH] forCallable should not erase binding on success

Do not erase the bindings in Builder when the check fails for the first
node on the stack, because it could still succeed for a later node.

The problem was that `InnerMatcher.matches` erases the bindings when it
returns false. The appropriate solution is to pass a copy of the
bindings, similar to what `matchesFirstInRange` does.
---
 clang/include/clang/ASTMatchers/ASTMatchers.h |  16 ++-
 .../ASTMatchers/ASTMatchersTraversalTest.cpp  | 130 ++++++++++++++++++
 2 files changed, 142 insertions(+), 4 deletions(-)

diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h
index dc1f49525a004a..54671fe4043378 100644
--- a/clang/include/clang/ASTMatchers/ASTMatchers.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -8341,20 +8341,28 @@ AST_MATCHER_P(Stmt, forCallable, internal::Matcher<Decl>, InnerMatcher) {
     const auto &CurNode = Stack.back();
     Stack.pop_back();
     if (const auto *FuncDeclNode = CurNode.get<FunctionDecl>()) {
-      if (InnerMatcher.matches(*FuncDeclNode, Finder, Builder)) {
+      BoundNodesTreeBuilder B = *Builder;
+      if (InnerMatcher.matches(*FuncDeclNode, Finder, &B)) {
+        *Builder = std::move(B);
         return true;
       }
     } else if (const auto *LambdaExprNode = CurNode.get<LambdaExpr>()) {
+      BoundNodesTreeBuilder B = *Builder;
       if (InnerMatcher.matches(*LambdaExprNode->getCallOperator(), Finder,
-                               Builder)) {
+                               &B)) {
+        *Builder = std::move(B);
         return true;
       }
     } else if (const auto *ObjCMethodDeclNode = CurNode.get<ObjCMethodDecl>()) {
-      if (InnerMatcher.matches(*ObjCMethodDeclNode, Finder, Builder)) {
+      BoundNodesTreeBuilder B = *Builder;
+      if (InnerMatcher.matches(*ObjCMethodDeclNode, Finder, &B)) {
+        *Builder = std::move(B);
         return true;
       }
     } else if (const auto *BlockDeclNode = CurNode.get<BlockDecl>()) {
-      if (InnerMatcher.matches(*BlockDeclNode, Finder, Builder)) {
+      BoundNodesTreeBuilder B = *Builder;
+      if (InnerMatcher.matches(*BlockDeclNode, Finder, &B)) {
+        *Builder = std::move(B);
         return true;
       }
     } else {
diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
index 6911d7600a7188..2ecbdf659358fe 100644
--- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
+++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
@@ -5916,6 +5916,37 @@ TEST(StatementMatcher, ForCallable) {
   EXPECT_TRUE(notMatches(CppString2,
                          returnStmt(forCallable(functionDecl(hasName("F"))))));
 
+  StringRef CodeWithDeepCallExpr = R"cpp(
+void Other();
+void Function() {
+  {
+    (
+      Other()
+    );
+  }
+}
+)cpp";
+  auto ForCallableFirst =
+      callExpr(forCallable(functionDecl(hasName("Function"))),
+               callee(functionDecl(hasName("Other")).bind("callee")))
+          .bind("call");
+  auto ForCallableSecond =
+      callExpr(callee(functionDecl(hasName("Other")).bind("callee")),
+               forCallable(functionDecl(hasName("Function"))))
+          .bind("call");
+  EXPECT_TRUE(matchAndVerifyResultTrue(
+      CodeWithDeepCallExpr, ForCallableFirst,
+      std::make_unique<VerifyIdIsBoundTo<CallExpr>>("call")));
+  EXPECT_TRUE(matchAndVerifyResultTrue(
+      CodeWithDeepCallExpr, ForCallableFirst,
+      std::make_unique<VerifyIdIsBoundTo<FunctionDecl>>("callee")));
+  EXPECT_TRUE(matchAndVerifyResultTrue(
+      CodeWithDeepCallExpr, ForCallableSecond,
+      std::make_unique<VerifyIdIsBoundTo<CallExpr>>("call")));
+  EXPECT_TRUE(matchAndVerifyResultTrue(
+      CodeWithDeepCallExpr, ForCallableSecond,
+      std::make_unique<VerifyIdIsBoundTo<FunctionDecl>>("callee")));
+
   // These tests are specific to forCallable().
   StringRef ObjCString1 = "@interface I"
                           "-(void) foo;"
@@ -5957,6 +5988,105 @@ TEST(StatementMatcher, ForCallable) {
       binaryOperator(forCallable(blockDecl()))));
 }
 
+namespace {
+class ForCallablePreservesBindingWithMultipleParentsTestCallback
+    : public BoundNodesCallback {
+public:
+  bool run(const BoundNodes *BoundNodes) override {
+    FunctionDecl const *FunDecl =
+        BoundNodes->getNodeAs<FunctionDecl>("funDecl");
+    // Validate test assumptions. This would be expressed as ASSERT_* in
+    // a TEST().
+    if (!FunDecl) {
+      EXPECT_TRUE(false && "Incorrect test setup");
+      return false;
+    }
+    auto const *FunDef = FunDecl->getDefinition();
+    if (!FunDef || !FunDef->getBody() ||
+        FunDef->getNameAsString() != "Function") {
+      EXPECT_TRUE(false && "Incorrect test setup");
+      return false;
+    }
+
+    ExpectCorrectResult(
+        "Baseline",
+        callExpr(callee(cxxMethodDecl().bind("callee"))).bind("call"), //
+        FunDecl);
+
+    ExpectCorrectResult("ForCallable first",
+                        callExpr(forCallable(equalsNode(FunDecl)),
+                                 callee(cxxMethodDecl().bind("callee")))
+                            .bind("call"),
+                        FunDecl);
+
+    ExpectCorrectResult("ForCallable second",
+                        callExpr(callee(cxxMethodDecl().bind("callee")),
+                                 forCallable(equalsNode(FunDecl)))
+                            .bind("call"),
+                        FunDecl);
+
+    // This value does not really matter: the EXPECT_* will set the exit code.
+    return true;
+  }
+
+  bool run(const BoundNodes *BoundNodes, ASTContext *Context) override {
+    return run(BoundNodes);
+  }
+
+private:
+  void ExpectCorrectResult(StringRef LogInfo,
+                           ArrayRef<BoundNodes> Results) const {
+    EXPECT_EQ(Results.size(), 1u) << LogInfo;
+    if (Results.empty())
+      return;
+    auto const &R = Results.front();
+    EXPECT_TRUE(R.getNodeAs<CallExpr>("call")) << LogInfo;
+    EXPECT_TRUE(R.getNodeAs<CXXMethodDecl>("callee")) << LogInfo;
+  }
+
+  template <typename MatcherT>
+  void ExpectCorrectResult(StringRef LogInfo, MatcherT Matcher,
+                           FunctionDecl const *FunDef) const {
+    auto &Context = FunDef->getASTContext();
+    auto const &Results = match(findAll(Matcher), *FunDef->getBody(), Context);
+    ExpectCorrectResult(LogInfo, Results);
+  }
+};
+} // namespace
+
+TEST(StatementMatcher, ForCallablePreservesBindingWithMultipleParents) {
+  // Tests in this file are fairly simple and therefore can rely on matches,
+  // matchAndVerifyResultTrue, etc. This test, however, needs a FunctionDecl*
+  // to pass to equalsNode so that it can reproduce the observed issue
+  // (bindings being removed despite forCallable matching the node).
+  //
+  // Because of this and because the machinery to compile the code into an
+  // ASTUnit is not exposed outside matchAndVerifyResultConditionally, it is
+  // cheaper to have a custom BoundNodesCallback for the purpose of this test.
+  StringRef codeWithTemplateFunction = R"cpp(
+struct Klass {
+  void Method();
+  template <typename T>
+  void Function(T t); // Declaration
+};
+
+void Instantiate(Klass k) {
+  k.Function(0);
+}
+
+template <typename T>
+void Klass::Function(T t) { // Definition
+  // Compound statement has two parents: the declaration and the definition.
+  Method();
+}
+)cpp";
+  EXPECT_TRUE(matchAndVerifyResultTrue(
+      codeWithTemplateFunction,
+      callExpr(callee(functionDecl(hasName("Function")).bind("funDecl"))),
+      std::make_unique<
+          ForCallablePreservesBindingWithMultipleParentsTestCallback>()));
+}
+
 TEST(Matcher, ForEachOverriden) {
   const auto ForEachOverriddenInClass = [](const char *ClassName) {
     return cxxMethodDecl(ofClass(hasName(ClassName)), isVirtual(),



More information about the cfe-commits mailing list