[clang] 9cfec72 - [clang] Refactor AST printing tests to share more infrastructure

Nathan Ridge via cfe-commits cfe-commits at lists.llvm.org
Wed Jul 14 16:44:31 PDT 2021


Author: Nathan Ridge
Date: 2021-07-14T19:44:18-04:00
New Revision: 9cfec72ffeec242783b70e792c50bd163dcf9dbb

URL: https://github.com/llvm/llvm-project/commit/9cfec72ffeec242783b70e792c50bd163dcf9dbb
DIFF: https://github.com/llvm/llvm-project/commit/9cfec72ffeec242783b70e792c50bd163dcf9dbb.diff

LOG: [clang] Refactor AST printing tests to share more infrastructure

Differential Revision: https://reviews.llvm.org/D105457

Added:
    (none)

Modified: 
    clang/unittests/AST/ASTPrint.h
    clang/unittests/AST/DeclPrinterTest.cpp
    clang/unittests/AST/NamedDeclPrinterTest.cpp
    clang/unittests/AST/StmtPrinterTest.cpp

Removed:
    (none)


################################################################################
diff  --git a/clang/unittests/AST/ASTPrint.h b/clang/unittests/AST/ASTPrint.h
index c3b6b842316d9..58364499bcd9f 100644
--- a/clang/unittests/AST/ASTPrint.h
+++ b/clang/unittests/AST/ASTPrint.h
@@ -19,72 +19,95 @@
 
 namespace clang {
 
-using PolicyAdjusterType =
-    Optional<llvm::function_ref<void(PrintingPolicy &Policy)>>;
-
-static void PrintStmt(raw_ostream &Out, const ASTContext *Context,
-                      const Stmt *S, PolicyAdjusterType PolicyAdjuster) {
-  assert(S != nullptr && "Expected non-null Stmt");
-  PrintingPolicy Policy = Context->getPrintingPolicy();
-  if (PolicyAdjuster)
-    (*PolicyAdjuster)(Policy);
-  S->printPretty(Out, /*Helper*/ nullptr, Policy);
-}
+using PrintingPolicyAdjuster = llvm::function_ref<void(PrintingPolicy &Policy)>;
+
+template <typename NodeType>
+using NodePrinter =
+    std::function<void(llvm::raw_ostream &Out, const ASTContext *Context,
+                       const NodeType *Node,
+                       PrintingPolicyAdjuster PolicyAdjuster)>;
 
+template <typename NodeType>
+using NodeFilter = std::function<bool(const NodeType *Node)>;
+
+template <typename NodeType>
 class PrintMatch : public ast_matchers::MatchFinder::MatchCallback {
+  using PrinterT = NodePrinter<NodeType>;
+  using FilterT = NodeFilter<NodeType>;
+
   SmallString<1024> Printed;
-  unsigned NumFoundStmts;
-  PolicyAdjusterType PolicyAdjuster;
+  unsigned NumFoundNodes;
+  PrinterT Printer;
+  FilterT Filter;
+  PrintingPolicyAdjuster PolicyAdjuster;
 
 public:
-  PrintMatch(PolicyAdjusterType PolicyAdjuster)
-      : NumFoundStmts(0), PolicyAdjuster(PolicyAdjuster) {}
+  PrintMatch(PrinterT Printer, PrintingPolicyAdjuster PolicyAdjuster,
+             FilterT Filter)
+      : NumFoundNodes(0), Printer(std::move(Printer)),
+        Filter(std::move(Filter)), PolicyAdjuster(PolicyAdjuster) {}
 
   void run(const ast_matchers::MatchFinder::MatchResult &Result) override {
-    const Stmt *S = Result.Nodes.getNodeAs<Stmt>("id");
-    if (!S)
+    const NodeType *N = Result.Nodes.getNodeAs<NodeType>("id");
+    if (!N || !Filter(N))
       return;
-    NumFoundStmts++;
-    if (NumFoundStmts > 1)
+    NumFoundNodes++;
+    if (NumFoundNodes > 1)
       return;
 
     llvm::raw_svector_ostream Out(Printed);
-    PrintStmt(Out, Result.Context, S, PolicyAdjuster);
+    Printer(Out, Result.Context, N, PolicyAdjuster);
   }
 
   StringRef getPrinted() const { return Printed; }
 
-  unsigned getNumFoundStmts() const { return NumFoundStmts; }
+  unsigned getNumFoundNodes() const { return NumFoundNodes; }
 };
 
-template <typename T>
-::testing::AssertionResult
-PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
-                   const T &NodeMatch, StringRef ExpectedPrinted,
-                   PolicyAdjusterType PolicyAdjuster = None) {
+template <typename NodeType> bool NoNodeFilter(const NodeType *) {
+  return true;
+}
 
-  PrintMatch Printer(PolicyAdjuster);
+template <typename NodeType, typename Matcher>
+::testing::AssertionResult
+PrintedNodeMatches(StringRef Code, const std::vector<std::string> &Args,
+                   const Matcher &NodeMatch, StringRef ExpectedPrinted,
+                   StringRef FileName, NodePrinter<NodeType> Printer,
+                   PrintingPolicyAdjuster PolicyAdjuster = nullptr,
+                   bool AllowError = false,
+                   // Would like to use a lambda for the default value, but that
+                   // trips gcc 7 up.
+                   NodeFilter<NodeType> Filter = &NoNodeFilter<NodeType>) {
+
+  PrintMatch<NodeType> Callback(Printer, PolicyAdjuster, Filter);
   ast_matchers::MatchFinder Finder;
-  Finder.addMatcher(NodeMatch, &Printer);
+  Finder.addMatcher(NodeMatch, &Callback);
   std::unique_ptr<tooling::FrontendActionFactory> Factory(
       tooling::newFrontendActionFactory(&Finder));
 
-  if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args))
+  bool ToolResult;
+  if (FileName.empty()) {
+    ToolResult = tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args);
+  } else {
+    ToolResult =
+        tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName);
+  }
+  if (!ToolResult && !AllowError)
     return testing::AssertionFailure()
            << "Parsing error in \"" << Code.str() << "\"";
 
-  if (Printer.getNumFoundStmts() == 0)
-    return testing::AssertionFailure() << "Matcher didn't find any statements";
+  if (Callback.getNumFoundNodes() == 0)
+    return testing::AssertionFailure() << "Matcher didn't find any nodes";
 
-  if (Printer.getNumFoundStmts() > 1)
+  if (Callback.getNumFoundNodes() > 1)
     return testing::AssertionFailure()
-           << "Matcher should match only one statement (found "
-           << Printer.getNumFoundStmts() << ")";
+           << "Matcher should match only one node (found "
+           << Callback.getNumFoundNodes() << ")";
 
-  if (Printer.getPrinted() != ExpectedPrinted)
+  if (Callback.getPrinted() != ExpectedPrinted)
     return ::testing::AssertionFailure()
            << "Expected \"" << ExpectedPrinted.str() << "\", got \""
-           << Printer.getPrinted().str() << "\"";
+           << Callback.getPrinted().str() << "\"";
 
   return ::testing::AssertionSuccess();
 }

diff  --git a/clang/unittests/AST/DeclPrinterTest.cpp b/clang/unittests/AST/DeclPrinterTest.cpp
index e70d2bef72121..bdc23f33f39b0 100644
--- a/clang/unittests/AST/DeclPrinterTest.cpp
+++ b/clang/unittests/AST/DeclPrinterTest.cpp
@@ -18,6 +18,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "ASTPrint.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
@@ -32,10 +33,8 @@ using namespace tooling;
 
 namespace {
 
-using PrintingPolicyModifier = void (*)(PrintingPolicy &policy);
-
 void PrintDecl(raw_ostream &Out, const ASTContext *Context, const Decl *D,
-               PrintingPolicyModifier PolicyModifier) {
+               PrintingPolicyAdjuster PolicyModifier) {
   PrintingPolicy Policy = Context->getPrintingPolicy();
   Policy.TerseOutput = true;
   Policy.Indentation = 0;
@@ -44,74 +43,23 @@ void PrintDecl(raw_ostream &Out, const ASTContext *Context, const Decl *D,
   D->print(Out, Policy, /*Indentation*/ 0, /*PrintInstantiation*/ false);
 }
 
-class PrintMatch : public MatchFinder::MatchCallback {
-  SmallString<1024> Printed;
-  unsigned NumFoundDecls;
-  PrintingPolicyModifier PolicyModifier;
-
-public:
-  PrintMatch(PrintingPolicyModifier PolicyModifier)
-      : NumFoundDecls(0), PolicyModifier(PolicyModifier) {}
-
-  void run(const MatchFinder::MatchResult &Result) override {
-    const Decl *D = Result.Nodes.getNodeAs<Decl>("id");
-    if (!D || D->isImplicit())
-      return;
-    NumFoundDecls++;
-    if (NumFoundDecls > 1)
-      return;
-
-    llvm::raw_svector_ostream Out(Printed);
-    PrintDecl(Out, Result.Context, D, PolicyModifier);
-  }
-
-  StringRef getPrinted() const {
-    return Printed;
-  }
-
-  unsigned getNumFoundDecls() const {
-    return NumFoundDecls;
-  }
-};
-
 ::testing::AssertionResult
 PrintedDeclMatches(StringRef Code, const std::vector<std::string> &Args,
                    const DeclarationMatcher &NodeMatch,
                    StringRef ExpectedPrinted, StringRef FileName,
-                   PrintingPolicyModifier PolicyModifier = nullptr,
+                   PrintingPolicyAdjuster PolicyModifier = nullptr,
                    bool AllowError = false) {
-  PrintMatch Printer(PolicyModifier);
-  MatchFinder Finder;
-  Finder.addMatcher(NodeMatch, &Printer);
-  std::unique_ptr<FrontendActionFactory> Factory(
-      newFrontendActionFactory(&Finder));
-
-  if (!runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName) &&
-      !AllowError)
-    return testing::AssertionFailure()
-      << "Parsing error in \"" << Code.str() << "\"";
-
-  if (Printer.getNumFoundDecls() == 0)
-    return testing::AssertionFailure()
-        << "Matcher didn't find any declarations";
-
-  if (Printer.getNumFoundDecls() > 1)
-    return testing::AssertionFailure()
-        << "Matcher should match only one declaration "
-           "(found " << Printer.getNumFoundDecls() << ")";
-
-  if (Printer.getPrinted() != ExpectedPrinted)
-    return ::testing::AssertionFailure()
-      << "Expected \"" << ExpectedPrinted.str() << "\", "
-         "got \"" << Printer.getPrinted().str() << "\"";
-
-  return ::testing::AssertionSuccess();
+  return PrintedNodeMatches<Decl>(
+      Code, Args, NodeMatch, ExpectedPrinted, FileName, PrintDecl,
+      PolicyModifier, AllowError,
+      // Filter out implicit decls
+      [](const Decl *D) { return !D->isImplicit(); });
 }
 
 ::testing::AssertionResult
 PrintedDeclCXX98Matches(StringRef Code, StringRef DeclName,
                         StringRef ExpectedPrinted,
-                        PrintingPolicyModifier PolicyModifier = nullptr) {
+                        PrintingPolicyAdjuster PolicyModifier = nullptr) {
   std::vector<std::string> Args(1, "-std=c++98");
   return PrintedDeclMatches(Code, Args, namedDecl(hasName(DeclName)).bind("id"),
                             ExpectedPrinted, "input.cc", PolicyModifier);
@@ -120,7 +68,7 @@ PrintedDeclCXX98Matches(StringRef Code, StringRef DeclName,
 ::testing::AssertionResult
 PrintedDeclCXX98Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
                         StringRef ExpectedPrinted,
-                        PrintingPolicyModifier PolicyModifier = nullptr) {
+                        PrintingPolicyAdjuster PolicyModifier = nullptr) {
   std::vector<std::string> Args(1, "-std=c++98");
   return PrintedDeclMatches(Code,
                             Args,
@@ -165,7 +113,7 @@ ::testing::AssertionResult PrintedDeclCXX11nonMSCMatches(
 ::testing::AssertionResult
 PrintedDeclCXX17Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
                         StringRef ExpectedPrinted,
-                        PrintingPolicyModifier PolicyModifier = nullptr) {
+                        PrintingPolicyAdjuster PolicyModifier = nullptr) {
   std::vector<std::string> Args{"-std=c++17", "-fno-delayed-template-parsing"};
   return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.cc",
                             PolicyModifier);
@@ -174,7 +122,7 @@ PrintedDeclCXX17Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
 ::testing::AssertionResult
 PrintedDeclC11Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
                       StringRef ExpectedPrinted,
-                      PrintingPolicyModifier PolicyModifier = nullptr) {
+                      PrintingPolicyAdjuster PolicyModifier = nullptr) {
   std::vector<std::string> Args(1, "-std=c11");
   return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.c",
                             PolicyModifier);

diff  --git a/clang/unittests/AST/NamedDeclPrinterTest.cpp b/clang/unittests/AST/NamedDeclPrinterTest.cpp
index 1042312e8a730..cd833725b448d 100644
--- a/clang/unittests/AST/NamedDeclPrinterTest.cpp
+++ b/clang/unittests/AST/NamedDeclPrinterTest.cpp
@@ -15,6 +15,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "ASTPrint.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/PrettyPrinter.h"
@@ -66,31 +67,11 @@ ::testing::AssertionResult PrintedDeclMatches(
     const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted,
     StringRef FileName,
     std::function<void(llvm::raw_ostream &, const NamedDecl *)> Print) {
-  PrintMatch Printer(std::move(Print));
-  MatchFinder Finder;
-  Finder.addMatcher(NodeMatch, &Printer);
-  std::unique_ptr<FrontendActionFactory> Factory =
-      newFrontendActionFactory(&Finder);
-
-  if (!runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
-    return testing::AssertionFailure()
-        << "Parsing error in \"" << Code.str() << "\"";
-
-  if (Printer.getNumFoundDecls() == 0)
-    return testing::AssertionFailure()
-        << "Matcher didn't find any named declarations";
-
-  if (Printer.getNumFoundDecls() > 1)
-    return testing::AssertionFailure()
-        << "Matcher should match only one named declaration "
-           "(found " << Printer.getNumFoundDecls() << ")";
-
-  if (Printer.getPrinted() != ExpectedPrinted)
-    return ::testing::AssertionFailure()
-        << "Expected \"" << ExpectedPrinted.str() << "\", "
-           "got \"" << Printer.getPrinted().str() << "\"";
-
-  return ::testing::AssertionSuccess();
+  return PrintedNodeMatches<NamedDecl>(
+      Code, Args, NodeMatch, ExpectedPrinted, FileName,
+      [Print](llvm::raw_ostream &Out, const ASTContext *Context,
+              const NamedDecl *ND,
+              PrintingPolicyAdjuster PolicyAdjuster) { Print(Out, ND); });
 }
 
 ::testing::AssertionResult

diff  --git a/clang/unittests/AST/StmtPrinterTest.cpp b/clang/unittests/AST/StmtPrinterTest.cpp
index 29cdbf75a00c8..65dfec4cc5b4a 100644
--- a/clang/unittests/AST/StmtPrinterTest.cpp
+++ b/clang/unittests/AST/StmtPrinterTest.cpp
@@ -38,11 +38,29 @@ DeclarationMatcher FunctionBodyMatcher(StringRef ContainingFunction) {
                       has(compoundStmt(has(stmt().bind("id")))));
 }
 
+static void PrintStmt(raw_ostream &Out, const ASTContext *Context,
+                      const Stmt *S, PrintingPolicyAdjuster PolicyAdjuster) {
+  assert(S != nullptr && "Expected non-null Stmt");
+  PrintingPolicy Policy = Context->getPrintingPolicy();
+  if (PolicyAdjuster)
+    PolicyAdjuster(Policy);
+  S->printPretty(Out, /*Helper*/ nullptr, Policy);
+}
+
+template <typename Matcher>
+::testing::AssertionResult
+PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
+                   const Matcher &NodeMatch, StringRef ExpectedPrinted,
+                   PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
+  return PrintedNodeMatches<Stmt>(Code, Args, NodeMatch, ExpectedPrinted, "",
+                                  PrintStmt, PolicyAdjuster);
+}
+
 template <typename T>
 ::testing::AssertionResult
 PrintedStmtCXXMatches(StdVer Standard, StringRef Code, const T &NodeMatch,
                       StringRef ExpectedPrinted,
-                      PolicyAdjusterType PolicyAdjuster = None) {
+                      PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
   const char *StdOpt;
   switch (Standard) {
   case StdVer::CXX98: StdOpt = "-std=c++98"; break;
@@ -64,7 +82,7 @@ template <typename T>
 ::testing::AssertionResult
 PrintedStmtMSMatches(StringRef Code, const T &NodeMatch,
                      StringRef ExpectedPrinted,
-                     PolicyAdjusterType PolicyAdjuster = None) {
+                     PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
   std::vector<std::string> Args = {
     "-std=c++98",
     "-target", "i686-pc-win32",
@@ -79,7 +97,7 @@ template <typename T>
 ::testing::AssertionResult
 PrintedStmtObjCMatches(StringRef Code, const T &NodeMatch,
                        StringRef ExpectedPrinted,
-                       PolicyAdjusterType PolicyAdjuster = None) {
+                       PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
   std::vector<std::string> Args = {
     "-ObjC",
     "-fobjc-runtime=macosx-10.12.0",
@@ -202,10 +220,10 @@ class A {
 };
 )";
   // No implicit 'this'.
-  ASSERT_TRUE(PrintedStmtCXXMatches(StdVer::CXX11,
-      CPPSource, memberExpr(anything()).bind("id"), "field",
-      PolicyAdjusterType(
-          [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; })));
+  ASSERT_TRUE(PrintedStmtCXXMatches(
+      StdVer::CXX11, CPPSource, memberExpr(anything()).bind("id"), "field",
+
+      [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; }));
   // Print implicit 'this'.
   ASSERT_TRUE(PrintedStmtCXXMatches(StdVer::CXX11,
       CPPSource, memberExpr(anything()).bind("id"), "this->field"));
@@ -222,11 +240,10 @@ class A {
 @end
       )";
   // No implicit 'self'.
-  ASSERT_TRUE(PrintedStmtObjCMatches(ObjCSource, returnStmt().bind("id"),
-                                     "return ivar;\n",
-                                     PolicyAdjusterType([](PrintingPolicy &PP) {
-                                       PP.SuppressImplicitBase = true;
-                                     })));
+  ASSERT_TRUE(PrintedStmtObjCMatches(
+      ObjCSource, returnStmt().bind("id"), "return ivar;\n",
+
+      [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; }));
   // Print implicit 'self'.
   ASSERT_TRUE(PrintedStmtObjCMatches(ObjCSource, returnStmt().bind("id"),
                                      "return self->ivar;\n"));
@@ -243,5 +260,6 @@ TEST(StmtPrinter, TerseOutputWithLambdas) {
   // body not printed when TerseOutput is on.
   ASSERT_TRUE(PrintedStmtCXXMatches(
       StdVer::CXX11, CPPSource, lambdaExpr(anything()).bind("id"), "[] {}",
-      PolicyAdjusterType([](PrintingPolicy &PP) { PP.TerseOutput = true; })));
+
+      [](PrintingPolicy &PP) { PP.TerseOutput = true; }));
 }


        


More information about the cfe-commits mailing list