[clang] 9cfec72 - [clang] Refactor AST printing tests to share more infrastructure
Nathan Ridge via cfe-commits
cfe-commits at lists.llvm.org
Wed Jul 14 16:44:31 PDT 2021
Author: Nathan Ridge
Date: 2021-07-14T19:44:18-04:00
New Revision: 9cfec72ffeec242783b70e792c50bd163dcf9dbb
URL: https://github.com/llvm/llvm-project/commit/9cfec72ffeec242783b70e792c50bd163dcf9dbb
DIFF: https://github.com/llvm/llvm-project/commit/9cfec72ffeec242783b70e792c50bd163dcf9dbb.diff
LOG: [clang] Refactor AST printing tests to share more infrastructure
Differential Revision: https://reviews.llvm.org/D105457
Added:
Modified:
clang/unittests/AST/ASTPrint.h
clang/unittests/AST/DeclPrinterTest.cpp
clang/unittests/AST/NamedDeclPrinterTest.cpp
clang/unittests/AST/StmtPrinterTest.cpp
Removed:
################################################################################
diff --git a/clang/unittests/AST/ASTPrint.h b/clang/unittests/AST/ASTPrint.h
index c3b6b842316d9..58364499bcd9f 100644
--- a/clang/unittests/AST/ASTPrint.h
+++ b/clang/unittests/AST/ASTPrint.h
@@ -19,72 +19,95 @@
namespace clang {
-using PolicyAdjusterType =
- Optional<llvm::function_ref<void(PrintingPolicy &Policy)>>;
-
-static void PrintStmt(raw_ostream &Out, const ASTContext *Context,
- const Stmt *S, PolicyAdjusterType PolicyAdjuster) {
- assert(S != nullptr && "Expected non-null Stmt");
- PrintingPolicy Policy = Context->getPrintingPolicy();
- if (PolicyAdjuster)
- (*PolicyAdjuster)(Policy);
- S->printPretty(Out, /*Helper*/ nullptr, Policy);
-}
+using PrintingPolicyAdjuster = llvm::function_ref<void(PrintingPolicy &Policy)>;
+
+template <typename NodeType>
+using NodePrinter =
+ std::function<void(llvm::raw_ostream &Out, const ASTContext *Context,
+ const NodeType *Node,
+ PrintingPolicyAdjuster PolicyAdjuster)>;
+template <typename NodeType>
+using NodeFilter = std::function<bool(const NodeType *Node)>;
+
+template <typename NodeType>
class PrintMatch : public ast_matchers::MatchFinder::MatchCallback {
+ using PrinterT = NodePrinter<NodeType>;
+ using FilterT = NodeFilter<NodeType>;
+
SmallString<1024> Printed;
- unsigned NumFoundStmts;
- PolicyAdjusterType PolicyAdjuster;
+ unsigned NumFoundNodes;
+ PrinterT Printer;
+ FilterT Filter;
+ PrintingPolicyAdjuster PolicyAdjuster;
public:
- PrintMatch(PolicyAdjusterType PolicyAdjuster)
- : NumFoundStmts(0), PolicyAdjuster(PolicyAdjuster) {}
+ PrintMatch(PrinterT Printer, PrintingPolicyAdjuster PolicyAdjuster,
+ FilterT Filter)
+ : NumFoundNodes(0), Printer(std::move(Printer)),
+ Filter(std::move(Filter)), PolicyAdjuster(PolicyAdjuster) {}
void run(const ast_matchers::MatchFinder::MatchResult &Result) override {
- const Stmt *S = Result.Nodes.getNodeAs<Stmt>("id");
- if (!S)
+ const NodeType *N = Result.Nodes.getNodeAs<NodeType>("id");
+ if (!N || !Filter(N))
return;
- NumFoundStmts++;
- if (NumFoundStmts > 1)
+ NumFoundNodes++;
+ if (NumFoundNodes > 1)
return;
llvm::raw_svector_ostream Out(Printed);
- PrintStmt(Out, Result.Context, S, PolicyAdjuster);
+ Printer(Out, Result.Context, N, PolicyAdjuster);
}
StringRef getPrinted() const { return Printed; }
- unsigned getNumFoundStmts() const { return NumFoundStmts; }
+ unsigned getNumFoundNodes() const { return NumFoundNodes; }
};
-template <typename T>
-::testing::AssertionResult
-PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
- const T &NodeMatch, StringRef ExpectedPrinted,
- PolicyAdjusterType PolicyAdjuster = None) {
+template <typename NodeType> bool NoNodeFilter(const NodeType *) {
+ return true;
+}
- PrintMatch Printer(PolicyAdjuster);
+template <typename NodeType, typename Matcher>
+::testing::AssertionResult
+PrintedNodeMatches(StringRef Code, const std::vector<std::string> &Args,
+ const Matcher &NodeMatch, StringRef ExpectedPrinted,
+ StringRef FileName, NodePrinter<NodeType> Printer,
+ PrintingPolicyAdjuster PolicyAdjuster = nullptr,
+ bool AllowError = false,
+ // Would like to use a lambda for the default value, but that
+ // trips gcc 7 up.
+ NodeFilter<NodeType> Filter = &NoNodeFilter<NodeType>) {
+
+ PrintMatch<NodeType> Callback(Printer, PolicyAdjuster, Filter);
ast_matchers::MatchFinder Finder;
- Finder.addMatcher(NodeMatch, &Printer);
+ Finder.addMatcher(NodeMatch, &Callback);
std::unique_ptr<tooling::FrontendActionFactory> Factory(
tooling::newFrontendActionFactory(&Finder));
- if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args))
+ bool ToolResult;
+ if (FileName.empty()) {
+ ToolResult = tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args);
+ } else {
+ ToolResult =
+ tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName);
+ }
+ if (!ToolResult && !AllowError)
return testing::AssertionFailure()
<< "Parsing error in \"" << Code.str() << "\"";
- if (Printer.getNumFoundStmts() == 0)
- return testing::AssertionFailure() << "Matcher didn't find any statements";
+ if (Callback.getNumFoundNodes() == 0)
+ return testing::AssertionFailure() << "Matcher didn't find any nodes";
- if (Printer.getNumFoundStmts() > 1)
+ if (Callback.getNumFoundNodes() > 1)
return testing::AssertionFailure()
- << "Matcher should match only one statement (found "
- << Printer.getNumFoundStmts() << ")";
+ << "Matcher should match only one node (found "
+ << Callback.getNumFoundNodes() << ")";
- if (Printer.getPrinted() != ExpectedPrinted)
+ if (Callback.getPrinted() != ExpectedPrinted)
return ::testing::AssertionFailure()
<< "Expected \"" << ExpectedPrinted.str() << "\", got \""
- << Printer.getPrinted().str() << "\"";
+ << Callback.getPrinted().str() << "\"";
return ::testing::AssertionSuccess();
}
diff --git a/clang/unittests/AST/DeclPrinterTest.cpp b/clang/unittests/AST/DeclPrinterTest.cpp
index e70d2bef72121..bdc23f33f39b0 100644
--- a/clang/unittests/AST/DeclPrinterTest.cpp
+++ b/clang/unittests/AST/DeclPrinterTest.cpp
@@ -18,6 +18,7 @@
//
//===----------------------------------------------------------------------===//
+#include "ASTPrint.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
@@ -32,10 +33,8 @@ using namespace tooling;
namespace {
-using PrintingPolicyModifier = void (*)(PrintingPolicy &policy);
-
void PrintDecl(raw_ostream &Out, const ASTContext *Context, const Decl *D,
- PrintingPolicyModifier PolicyModifier) {
+ PrintingPolicyAdjuster PolicyModifier) {
PrintingPolicy Policy = Context->getPrintingPolicy();
Policy.TerseOutput = true;
Policy.Indentation = 0;
@@ -44,74 +43,23 @@ void PrintDecl(raw_ostream &Out, const ASTContext *Context, const Decl *D,
D->print(Out, Policy, /*Indentation*/ 0, /*PrintInstantiation*/ false);
}
-class PrintMatch : public MatchFinder::MatchCallback {
- SmallString<1024> Printed;
- unsigned NumFoundDecls;
- PrintingPolicyModifier PolicyModifier;
-
-public:
- PrintMatch(PrintingPolicyModifier PolicyModifier)
- : NumFoundDecls(0), PolicyModifier(PolicyModifier) {}
-
- void run(const MatchFinder::MatchResult &Result) override {
- const Decl *D = Result.Nodes.getNodeAs<Decl>("id");
- if (!D || D->isImplicit())
- return;
- NumFoundDecls++;
- if (NumFoundDecls > 1)
- return;
-
- llvm::raw_svector_ostream Out(Printed);
- PrintDecl(Out, Result.Context, D, PolicyModifier);
- }
-
- StringRef getPrinted() const {
- return Printed;
- }
-
- unsigned getNumFoundDecls() const {
- return NumFoundDecls;
- }
-};
-
::testing::AssertionResult
PrintedDeclMatches(StringRef Code, const std::vector<std::string> &Args,
const DeclarationMatcher &NodeMatch,
StringRef ExpectedPrinted, StringRef FileName,
- PrintingPolicyModifier PolicyModifier = nullptr,
+ PrintingPolicyAdjuster PolicyModifier = nullptr,
bool AllowError = false) {
- PrintMatch Printer(PolicyModifier);
- MatchFinder Finder;
- Finder.addMatcher(NodeMatch, &Printer);
- std::unique_ptr<FrontendActionFactory> Factory(
- newFrontendActionFactory(&Finder));
-
- if (!runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName) &&
- !AllowError)
- return testing::AssertionFailure()
- << "Parsing error in \"" << Code.str() << "\"";
-
- if (Printer.getNumFoundDecls() == 0)
- return testing::AssertionFailure()
- << "Matcher didn't find any declarations";
-
- if (Printer.getNumFoundDecls() > 1)
- return testing::AssertionFailure()
- << "Matcher should match only one declaration "
- "(found " << Printer.getNumFoundDecls() << ")";
-
- if (Printer.getPrinted() != ExpectedPrinted)
- return ::testing::AssertionFailure()
- << "Expected \"" << ExpectedPrinted.str() << "\", "
- "got \"" << Printer.getPrinted().str() << "\"";
-
- return ::testing::AssertionSuccess();
+ return PrintedNodeMatches<Decl>(
+ Code, Args, NodeMatch, ExpectedPrinted, FileName, PrintDecl,
+ PolicyModifier, AllowError,
+ // Filter out implicit decls
+ [](const Decl *D) { return !D->isImplicit(); });
}
::testing::AssertionResult
PrintedDeclCXX98Matches(StringRef Code, StringRef DeclName,
StringRef ExpectedPrinted,
- PrintingPolicyModifier PolicyModifier = nullptr) {
+ PrintingPolicyAdjuster PolicyModifier = nullptr) {
std::vector<std::string> Args(1, "-std=c++98");
return PrintedDeclMatches(Code, Args, namedDecl(hasName(DeclName)).bind("id"),
ExpectedPrinted, "input.cc", PolicyModifier);
@@ -120,7 +68,7 @@ PrintedDeclCXX98Matches(StringRef Code, StringRef DeclName,
::testing::AssertionResult
PrintedDeclCXX98Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
StringRef ExpectedPrinted,
- PrintingPolicyModifier PolicyModifier = nullptr) {
+ PrintingPolicyAdjuster PolicyModifier = nullptr) {
std::vector<std::string> Args(1, "-std=c++98");
return PrintedDeclMatches(Code,
Args,
@@ -165,7 +113,7 @@ ::testing::AssertionResult PrintedDeclCXX11nonMSCMatches(
::testing::AssertionResult
PrintedDeclCXX17Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
StringRef ExpectedPrinted,
- PrintingPolicyModifier PolicyModifier = nullptr) {
+ PrintingPolicyAdjuster PolicyModifier = nullptr) {
std::vector<std::string> Args{"-std=c++17", "-fno-delayed-template-parsing"};
return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.cc",
PolicyModifier);
@@ -174,7 +122,7 @@ PrintedDeclCXX17Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
::testing::AssertionResult
PrintedDeclC11Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
StringRef ExpectedPrinted,
- PrintingPolicyModifier PolicyModifier = nullptr) {
+ PrintingPolicyAdjuster PolicyModifier = nullptr) {
std::vector<std::string> Args(1, "-std=c11");
return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.c",
PolicyModifier);
diff --git a/clang/unittests/AST/NamedDeclPrinterTest.cpp b/clang/unittests/AST/NamedDeclPrinterTest.cpp
index 1042312e8a730..cd833725b448d 100644
--- a/clang/unittests/AST/NamedDeclPrinterTest.cpp
+++ b/clang/unittests/AST/NamedDeclPrinterTest.cpp
@@ -15,6 +15,7 @@
//
//===----------------------------------------------------------------------===//
+#include "ASTPrint.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/PrettyPrinter.h"
@@ -66,31 +67,11 @@ ::testing::AssertionResult PrintedDeclMatches(
const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted,
StringRef FileName,
std::function<void(llvm::raw_ostream &, const NamedDecl *)> Print) {
- PrintMatch Printer(std::move(Print));
- MatchFinder Finder;
- Finder.addMatcher(NodeMatch, &Printer);
- std::unique_ptr<FrontendActionFactory> Factory =
- newFrontendActionFactory(&Finder);
-
- if (!runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
- return testing::AssertionFailure()
- << "Parsing error in \"" << Code.str() << "\"";
-
- if (Printer.getNumFoundDecls() == 0)
- return testing::AssertionFailure()
- << "Matcher didn't find any named declarations";
-
- if (Printer.getNumFoundDecls() > 1)
- return testing::AssertionFailure()
- << "Matcher should match only one named declaration "
- "(found " << Printer.getNumFoundDecls() << ")";
-
- if (Printer.getPrinted() != ExpectedPrinted)
- return ::testing::AssertionFailure()
- << "Expected \"" << ExpectedPrinted.str() << "\", "
- "got \"" << Printer.getPrinted().str() << "\"";
-
- return ::testing::AssertionSuccess();
+ return PrintedNodeMatches<NamedDecl>(
+ Code, Args, NodeMatch, ExpectedPrinted, FileName,
+ [Print](llvm::raw_ostream &Out, const ASTContext *Context,
+ const NamedDecl *ND,
+ PrintingPolicyAdjuster PolicyAdjuster) { Print(Out, ND); });
}
::testing::AssertionResult
diff --git a/clang/unittests/AST/StmtPrinterTest.cpp b/clang/unittests/AST/StmtPrinterTest.cpp
index 29cdbf75a00c8..65dfec4cc5b4a 100644
--- a/clang/unittests/AST/StmtPrinterTest.cpp
+++ b/clang/unittests/AST/StmtPrinterTest.cpp
@@ -38,11 +38,29 @@ DeclarationMatcher FunctionBodyMatcher(StringRef ContainingFunction) {
has(compoundStmt(has(stmt().bind("id")))));
}
+static void PrintStmt(raw_ostream &Out, const ASTContext *Context,
+ const Stmt *S, PrintingPolicyAdjuster PolicyAdjuster) {
+ assert(S != nullptr && "Expected non-null Stmt");
+ PrintingPolicy Policy = Context->getPrintingPolicy();
+ if (PolicyAdjuster)
+ PolicyAdjuster(Policy);
+ S->printPretty(Out, /*Helper*/ nullptr, Policy);
+}
+
+template <typename Matcher>
+::testing::AssertionResult
+PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
+ const Matcher &NodeMatch, StringRef ExpectedPrinted,
+ PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
+ return PrintedNodeMatches<Stmt>(Code, Args, NodeMatch, ExpectedPrinted, "",
+ PrintStmt, PolicyAdjuster);
+}
+
template <typename T>
::testing::AssertionResult
PrintedStmtCXXMatches(StdVer Standard, StringRef Code, const T &NodeMatch,
StringRef ExpectedPrinted,
- PolicyAdjusterType PolicyAdjuster = None) {
+ PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
const char *StdOpt;
switch (Standard) {
case StdVer::CXX98: StdOpt = "-std=c++98"; break;
@@ -64,7 +82,7 @@ template <typename T>
::testing::AssertionResult
PrintedStmtMSMatches(StringRef Code, const T &NodeMatch,
StringRef ExpectedPrinted,
- PolicyAdjusterType PolicyAdjuster = None) {
+ PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
std::vector<std::string> Args = {
"-std=c++98",
"-target", "i686-pc-win32",
@@ -79,7 +97,7 @@ template <typename T>
::testing::AssertionResult
PrintedStmtObjCMatches(StringRef Code, const T &NodeMatch,
StringRef ExpectedPrinted,
- PolicyAdjusterType PolicyAdjuster = None) {
+ PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
std::vector<std::string> Args = {
"-ObjC",
"-fobjc-runtime=macosx-10.12.0",
@@ -202,10 +220,10 @@ class A {
};
)";
// No implicit 'this'.
- ASSERT_TRUE(PrintedStmtCXXMatches(StdVer::CXX11,
- CPPSource, memberExpr(anything()).bind("id"), "field",
- PolicyAdjusterType(
- [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; })));
+ ASSERT_TRUE(PrintedStmtCXXMatches(
+ StdVer::CXX11, CPPSource, memberExpr(anything()).bind("id"), "field",
+
+ [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; }));
// Print implicit 'this'.
ASSERT_TRUE(PrintedStmtCXXMatches(StdVer::CXX11,
CPPSource, memberExpr(anything()).bind("id"), "this->field"));
@@ -222,11 +240,10 @@ class A {
@end
)";
// No implicit 'self'.
- ASSERT_TRUE(PrintedStmtObjCMatches(ObjCSource, returnStmt().bind("id"),
- "return ivar;\n",
- PolicyAdjusterType([](PrintingPolicy &PP) {
- PP.SuppressImplicitBase = true;
- })));
+ ASSERT_TRUE(PrintedStmtObjCMatches(
+ ObjCSource, returnStmt().bind("id"), "return ivar;\n",
+
+ [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; }));
// Print implicit 'self'.
ASSERT_TRUE(PrintedStmtObjCMatches(ObjCSource, returnStmt().bind("id"),
"return self->ivar;\n"));
@@ -243,5 +260,6 @@ TEST(StmtPrinter, TerseOutputWithLambdas) {
// body not printed when TerseOutput is on.
ASSERT_TRUE(PrintedStmtCXXMatches(
StdVer::CXX11, CPPSource, lambdaExpr(anything()).bind("id"), "[] {}",
- PolicyAdjusterType([](PrintingPolicy &PP) { PP.TerseOutput = true; })));
+
+ [](PrintingPolicy &PP) { PP.TerseOutput = true; }));
}
More information about the cfe-commits mailing list