[llvm] [StandardInstrumentations] add `unwrapIR` to simplify code NFCI (PR #75474)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 14 05:52:40 PST 2023
https://github.com/paperchalice created https://github.com/llvm/llvm-project/pull/75474
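Use a pointer to represent `optional` semantics: `unwrapIR<T>` returns the wrapped `const T *` when the `Any` holds that IR unit type, and `nullptr` otherwise.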
From 0e9c183e10e271f147cea62e6096c72b0e21df20 Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Thu, 14 Dec 2023 21:38:41 +0800
Subject: [PATCH] [StandardInstrumentations] add `unwrapIR` to simplify code
NFCI
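Use a pointer to represent `optional` semantics: `unwrapIR<T>` returns the wrapped `const T *` when the `Any` holds that IR unit type, and `nullptr` otherwise.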
---
llvm/lib/Passes/StandardInstrumentations.cpp | 157 +++++++++----------
1 file changed, 76 insertions(+), 81 deletions(-)
diff --git a/llvm/lib/Passes/StandardInstrumentations.cpp b/llvm/lib/Passes/StandardInstrumentations.cpp
index df445c2dd78b77..fd1317e3eb2567 100644
--- a/llvm/lib/Passes/StandardInstrumentations.cpp
+++ b/llvm/lib/Passes/StandardInstrumentations.cpp
@@ -130,6 +130,11 @@ static cl::opt<std::string> IRDumpDirectory(
"files in this directory rather than written to stderr"),
cl::Hidden, cl::value_desc("filename"));
+template <typename IRUnitT> static const IRUnitT *unwrapIR(Any IR) {
+ const IRUnitT **IRPtr = llvm::any_cast<const IRUnitT *>(&IR);
+ return IRPtr ? *IRPtr : nullptr;
+}
+
namespace {
// An option for specifying an executable that will be called with the IR
@@ -147,18 +152,18 @@ static cl::opt<std::string>
/// Extract Module out of \p IR unit. May return nullptr if \p IR does not match
/// certain global filters. Will never return nullptr if \p Force is true.
const Module *unwrapModule(Any IR, bool Force = false) {
- if (const auto **M = llvm::any_cast<const Module *>(&IR))
- return *M;
+ if (const auto *M = unwrapIR<Module>(IR))
+ return M;
- if (const auto **F = llvm::any_cast<const Function *>(&IR)) {
- if (!Force && !isFunctionInPrintList((*F)->getName()))
+ if (const auto *F = unwrapIR<Function>(IR)) {
+ if (!Force && !isFunctionInPrintList(F->getName()))
return nullptr;
- return (*F)->getParent();
+ return F->getParent();
}
- if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
- for (const LazyCallGraph::Node &N : **C) {
+ if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
+ for (const LazyCallGraph::Node &N : *C) {
const Function &F = N.getFunction();
if (Force || (!F.isDeclaration() && isFunctionInPrintList(F.getName()))) {
return F.getParent();
@@ -168,8 +173,8 @@ const Module *unwrapModule(Any IR, bool Force = false) {
return nullptr;
}
- if (const auto **L = llvm::any_cast<const Loop *>(&IR)) {
- const Function *F = (*L)->getHeader()->getParent();
+ if (const auto *L = unwrapIR<Loop>(IR)) {
+ const Function *F = L->getHeader()->getParent();
if (!Force && !isFunctionInPrintList(F->getName()))
return nullptr;
return F->getParent();
@@ -211,20 +216,20 @@ void printIR(raw_ostream &OS, const Loop *L) {
}
std::string getIRName(Any IR) {
- if (llvm::any_cast<const Module *>(&IR))
+ if (unwrapIR<Module>(IR))
return "[module]";
- if (const auto **F = llvm::any_cast<const Function *>(&IR))
- return (*F)->getName().str();
+ if (const auto *F = unwrapIR<Function>(IR))
+ return F->getName().str();
- if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
- return (*C)->getName();
+ if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
+ return C->getName();
- if (const auto **L = llvm::any_cast<const Loop *>(&IR))
- return (*L)->getName().str();
+ if (const auto *L = unwrapIR<Loop>(IR))
+ return L->getName().str();
- if (const auto **MF = llvm::any_cast<const MachineFunction *>(&IR))
- return (*MF)->getName().str();
+ if (const auto *MF = unwrapIR<MachineFunction>(IR))
+ return MF->getName().str();
llvm_unreachable("Unknown wrapped IR type");
}
@@ -246,17 +251,17 @@ bool sccContainsFilterPrintFunc(const LazyCallGraph::SCC &C) {
}
bool shouldPrintIR(Any IR) {
- if (const auto **M = llvm::any_cast<const Module *>(&IR))
- return moduleContainsFilterPrintFunc(**M);
+ if (const auto *M = unwrapIR<Module>(IR))
+ return moduleContainsFilterPrintFunc(*M);
- if (const auto **F = llvm::any_cast<const Function *>(&IR))
- return isFunctionInPrintList((*F)->getName());
+ if (const auto *F = unwrapIR<Function>(IR))
+ return isFunctionInPrintList(F->getName());
- if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
- return sccContainsFilterPrintFunc(**C);
+ if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
+ return sccContainsFilterPrintFunc(*C);
- if (const auto **L = llvm::any_cast<const Loop *>(&IR))
- return isFunctionInPrintList((*L)->getHeader()->getParent()->getName());
+ if (const auto *L = unwrapIR<Loop>(IR))
+ return isFunctionInPrintList(L->getHeader()->getParent()->getName());
llvm_unreachable("Unknown wrapped IR type");
}
@@ -273,23 +278,23 @@ void unwrapAndPrint(raw_ostream &OS, Any IR) {
return;
}
- if (const auto **M = llvm::any_cast<const Module *>(&IR)) {
- printIR(OS, *M);
+ if (const auto *M = unwrapIR<Module>(IR)) {
+ printIR(OS, M);
return;
}
- if (const auto **F = llvm::any_cast<const Function *>(&IR)) {
- printIR(OS, *F);
+ if (const auto *F = unwrapIR<Function>(IR)) {
+ printIR(OS, F);
return;
}
- if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
- printIR(OS, *C);
+ if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
+ printIR(OS, C);
return;
}
- if (const auto **L = llvm::any_cast<const Loop *>(&IR)) {
- printIR(OS, *L);
+ if (const auto *L = unwrapIR<Loop>(IR)) {
+ printIR(OS, L);
return;
}
llvm_unreachable("Unknown wrapped IR type");
@@ -320,13 +325,10 @@ std::string makeHTMLReady(StringRef SR) {
// Return the module when that is the appropriate level of comparison for \p IR.
const Module *getModuleForComparison(Any IR) {
- if (const auto **M = llvm::any_cast<const Module *>(&IR))
- return *M;
- if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
- return (*C)
- ->begin()
- ->getFunction()
- .getParent();
+ if (const auto *M = unwrapIR<Module>(IR))
+ return M;
+ if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
+ return C->begin()->getFunction().getParent();
return nullptr;
}
@@ -339,8 +341,8 @@ bool isInterestingFunction(const Function &F) {
bool isInteresting(Any IR, StringRef PassID, StringRef PassName) {
if (isIgnored(PassID) || !isPassInPrintList(PassName))
return false;
- if (const auto **F = llvm::any_cast<const Function *>(&IR))
- return isInterestingFunction(**F);
+ if (const auto *F = unwrapIR<Function>(IR))
+ return isInterestingFunction(*F);
return true;
}
@@ -662,12 +664,11 @@ template <typename T> void IRComparer<T>::analyzeIR(Any IR, IRDataT<T> &Data) {
return;
}
- const Function **FPtr = llvm::any_cast<const Function *>(&IR);
- const Function *F = FPtr ? *FPtr : nullptr;
+ const auto *F = unwrapIR<Function>(IR);
if (!F) {
- const Loop **L = llvm::any_cast<const Loop *>(&IR);
+ const auto *L = unwrapIR<Loop>(IR);
assert(L && "Unknown IR unit.");
- F = (*L)->getHeader()->getParent();
+ F = L->getHeader()->getParent();
}
assert(F && "Unknown IR unit.");
generateFunctionData(Data, *F);
@@ -706,21 +707,20 @@ static SmallString<32> getIRFileDisplayName(Any IR) {
stable_hash NameHash = stable_hash_combine_string(M->getName());
unsigned int MaxHashWidth = sizeof(stable_hash) * 8 / 4;
write_hex(ResultStream, NameHash, HexPrintStyle::Lower, MaxHashWidth);
- if (llvm::any_cast<const Module *>(&IR)) {
+ if (unwrapIR<Module>(IR)) {
ResultStream << "-module";
- } else if (const Function **F = llvm::any_cast<const Function *>(&IR)) {
+ } else if (const auto *F = unwrapIR<Function>(IR)) {
ResultStream << "-function-";
- stable_hash FunctionNameHash = stable_hash_combine_string((*F)->getName());
+ stable_hash FunctionNameHash = stable_hash_combine_string(F->getName());
write_hex(ResultStream, FunctionNameHash, HexPrintStyle::Lower,
MaxHashWidth);
- } else if (const LazyCallGraph::SCC **C =
- llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
+ } else if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
ResultStream << "-scc-";
- stable_hash SCCNameHash = stable_hash_combine_string((*C)->getName());
+ stable_hash SCCNameHash = stable_hash_combine_string(C->getName());
write_hex(ResultStream, SCCNameHash, HexPrintStyle::Lower, MaxHashWidth);
- } else if (const Loop **L = llvm::any_cast<const Loop *>(&IR)) {
+ } else if (const auto *L = unwrapIR<Loop>(IR)) {
ResultStream << "-loop-";
- stable_hash LoopNameHash = stable_hash_combine_string((*L)->getName());
+ stable_hash LoopNameHash = stable_hash_combine_string(L->getName());
write_hex(ResultStream, LoopNameHash, HexPrintStyle::Lower, MaxHashWidth);
} else {
llvm_unreachable("Unknown wrapped IR type");
@@ -975,11 +975,10 @@ void OptNoneInstrumentation::registerCallbacks(
}
bool OptNoneInstrumentation::shouldRun(StringRef PassID, Any IR) {
- const Function **FPtr = llvm::any_cast<const Function *>(&IR);
- const Function *F = FPtr ? *FPtr : nullptr;
+ const auto *F = unwrapIR<Function>(IR);
if (!F) {
- if (const auto **L = llvm::any_cast<const Loop *>(&IR))
- F = (*L)->getHeader()->getParent();
+ if (const auto *L = unwrapIR<Loop>(IR))
+ F = L->getHeader()->getParent();
}
bool ShouldRun = !(F && F->hasOptNone());
if (!ShouldRun && DebugLogging) {
@@ -1054,15 +1053,14 @@ void PrintPassInstrumentation::registerCallbacks(
auto &OS = print();
OS << "Running pass: " << PassID << " on " << getIRName(IR);
- if (const auto **F = llvm::any_cast<const Function *>(&IR)) {
- unsigned Count = (*F)->getInstructionCount();
+ if (const auto *F = unwrapIR<Function>(IR)) {
+ unsigned Count = F->getInstructionCount();
OS << " (" << Count << " instruction";
if (Count != 1)
OS << 's';
OS << ')';
- } else if (const auto **C =
- llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
- int Count = (*C)->size();
+ } else if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
+ int Count = C->size();
OS << " (" << Count << " node";
if (Count != 1)
OS << 's';
@@ -1277,10 +1275,10 @@ bool PreservedCFGCheckerInstrumentation::CFG::invalidate(
static SmallVector<Function *, 1> GetFunctions(Any IR) {
SmallVector<Function *, 1> Functions;
- if (const auto **MaybeF = llvm::any_cast<const Function *>(&IR)) {
- Functions.push_back(*const_cast<Function **>(MaybeF));
- } else if (const auto **MaybeM = llvm::any_cast<const Module *>(&IR)) {
- for (Function &F : **const_cast<Module **>(MaybeM))
+ if (const auto *MaybeF = unwrapIR<Function>(IR)) {
+ Functions.push_back(const_cast<Function *>(MaybeF));
+ } else if (const auto *MaybeM = unwrapIR<Module>(IR)) {
+ for (Function &F : *const_cast<Module *>(MaybeM))
Functions.push_back(&F);
}
return Functions;
@@ -1315,8 +1313,8 @@ void PreservedCFGCheckerInstrumentation::registerCallbacks(
FAM.getResult<PreservedFunctionHashAnalysis>(*F);
}
- if (auto *MaybeM = llvm::any_cast<const Module *>(&IR)) {
- Module &M = **const_cast<Module **>(MaybeM);
+ if (const auto *MPtr = unwrapIR<Module>(IR)) {
+ auto &M = *const_cast<Module *>(MPtr);
MAM.getResult<PreservedModuleHashAnalysis>(M);
}
});
@@ -1374,8 +1372,8 @@ void PreservedCFGCheckerInstrumentation::registerCallbacks(
CheckCFG(P, F->getName(), *GraphBefore,
CFG(F, /* TrackBBLifetime */ false));
}
- if (auto *MaybeM = llvm::any_cast<const Module *>(&IR)) {
- Module &M = **const_cast<Module **>(MaybeM);
+ if (const auto *MPtr = unwrapIR<Module>(IR)) {
+ auto &M = *const_cast<Module *>(MPtr);
if (auto *HashBefore =
MAM.getCachedResult<PreservedModuleHashAnalysis>(M)) {
if (HashBefore->Hash != StructuralHash(M)) {
@@ -1393,11 +1391,10 @@ void VerifyInstrumentation::registerCallbacks(
[this](StringRef P, Any IR, const PreservedAnalyses &PassPA) {
if (isIgnored(P) || P == "VerifierPass")
return;
- const Function **FPtr = llvm::any_cast<const Function *>(&IR);
- const Function *F = FPtr ? *FPtr : nullptr;
+ const auto *F = unwrapIR<Function>(IR);
if (!F) {
- if (const auto **L = llvm::any_cast<const Loop *>(&IR))
- F = (*L)->getHeader()->getParent();
+ if (const auto *L = unwrapIR<Loop>(IR))
+ F = L->getHeader()->getParent();
}
if (F) {
@@ -1409,12 +1406,10 @@ void VerifyInstrumentation::registerCallbacks(
"\"{0}\", compilation aborted!",
P));
} else {
- const Module **MPtr = llvm::any_cast<const Module *>(&IR);
- const Module *M = MPtr ? *MPtr : nullptr;
+ const auto *M = unwrapIR<Module>(IR);
if (!M) {
- if (const auto **C =
- llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
- M = (*C)->begin()->getFunction().getParent();
+ if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
+ M = C->begin()->getFunction().getParent();
}
if (M) {
More information about the llvm-commits
mailing list