[llvm] [StandardInstrumentations] add `unwrapIR` to simplify code NFCI (PR #75474)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 14 05:52:40 PST 2023


https://github.com/paperchalice created https://github.com/llvm/llvm-project/pull/75474

Return a plain pointer, which is null when the `Any` does not hold the requested IR unit, to express `optional` semantics.

>From 0e9c183e10e271f147cea62e6096c72b0e21df20 Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Thu, 14 Dec 2023 21:38:41 +0800
Subject: [PATCH] [StandardInstrumentations] add `unwrapIR` to simplify code
 NFCI

Return a plain pointer, which is null when the `Any` does not hold the requested IR unit, to express `optional` semantics.
---
 llvm/lib/Passes/StandardInstrumentations.cpp | 157 +++++++++----------
 1 file changed, 76 insertions(+), 81 deletions(-)

diff --git a/llvm/lib/Passes/StandardInstrumentations.cpp b/llvm/lib/Passes/StandardInstrumentations.cpp
index df445c2dd78b77..fd1317e3eb2567 100644
--- a/llvm/lib/Passes/StandardInstrumentations.cpp
+++ b/llvm/lib/Passes/StandardInstrumentations.cpp
@@ -130,6 +130,11 @@ static cl::opt<std::string> IRDumpDirectory(
              "files in this directory rather than written to stderr"),
     cl::Hidden, cl::value_desc("filename"));
 
+/// Return the `const IRUnitT *` held by \p IR, or nullptr on type mismatch.
+template <typename IRUnitT> static const IRUnitT *unwrapIR(const Any &IR) {
+  const IRUnitT *const *IRPtr = llvm::any_cast<const IRUnitT *>(&IR);
+  return IRPtr ? *IRPtr : nullptr;
+}
 namespace {
 
 // An option for specifying an executable that will be called with the IR
@@ -147,18 +152,18 @@ static cl::opt<std::string>
 /// Extract Module out of \p IR unit. May return nullptr if \p IR does not match
 /// certain global filters. Will never return nullptr if \p Force is true.
 const Module *unwrapModule(Any IR, bool Force = false) {
-  if (const auto **M = llvm::any_cast<const Module *>(&IR))
-    return *M;
+  if (const auto *M = unwrapIR<Module>(IR))
+    return M;
 
-  if (const auto **F = llvm::any_cast<const Function *>(&IR)) {
-    if (!Force && !isFunctionInPrintList((*F)->getName()))
+  if (const auto *F = unwrapIR<Function>(IR)) {
+    if (!Force && !isFunctionInPrintList(F->getName()))
       return nullptr;
 
-    return (*F)->getParent();
+    return F->getParent();
   }
 
-  if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
-    for (const LazyCallGraph::Node &N : **C) {
+  if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
+    for (const LazyCallGraph::Node &N : *C) {
       const Function &F = N.getFunction();
       if (Force || (!F.isDeclaration() && isFunctionInPrintList(F.getName()))) {
         return F.getParent();
@@ -168,8 +173,8 @@ const Module *unwrapModule(Any IR, bool Force = false) {
     return nullptr;
   }
 
-  if (const auto **L = llvm::any_cast<const Loop *>(&IR)) {
-    const Function *F = (*L)->getHeader()->getParent();
+  if (const auto *L = unwrapIR<Loop>(IR)) {
+    const Function *F = L->getHeader()->getParent();
     if (!Force && !isFunctionInPrintList(F->getName()))
       return nullptr;
     return F->getParent();
@@ -211,20 +216,20 @@ void printIR(raw_ostream &OS, const Loop *L) {
 }
 
 std::string getIRName(Any IR) {
-  if (llvm::any_cast<const Module *>(&IR))
+  if (unwrapIR<Module>(IR))
     return "[module]";
 
-  if (const auto **F = llvm::any_cast<const Function *>(&IR))
-    return (*F)->getName().str();
+  if (const auto *F = unwrapIR<Function>(IR))
+    return F->getName().str();
 
-  if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
-    return (*C)->getName();
+  if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
+    return C->getName();
 
-  if (const auto **L = llvm::any_cast<const Loop *>(&IR))
-    return (*L)->getName().str();
+  if (const auto *L = unwrapIR<Loop>(IR))
+    return L->getName().str();
 
-  if (const auto **MF = llvm::any_cast<const MachineFunction *>(&IR))
-    return (*MF)->getName().str();
+  if (const auto *MF = unwrapIR<MachineFunction>(IR))
+    return MF->getName().str();
 
   llvm_unreachable("Unknown wrapped IR type");
 }
@@ -246,17 +251,17 @@ bool sccContainsFilterPrintFunc(const LazyCallGraph::SCC &C) {
 }
 
 bool shouldPrintIR(Any IR) {
-  if (const auto **M = llvm::any_cast<const Module *>(&IR))
-    return moduleContainsFilterPrintFunc(**M);
+  if (const auto *M = unwrapIR<Module>(IR))
+    return moduleContainsFilterPrintFunc(*M);
 
-  if (const auto **F = llvm::any_cast<const Function *>(&IR))
-    return isFunctionInPrintList((*F)->getName());
+  if (const auto *F = unwrapIR<Function>(IR))
+    return isFunctionInPrintList(F->getName());
 
-  if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
-    return sccContainsFilterPrintFunc(**C);
+  if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
+    return sccContainsFilterPrintFunc(*C);
 
-  if (const auto **L = llvm::any_cast<const Loop *>(&IR))
-    return isFunctionInPrintList((*L)->getHeader()->getParent()->getName());
+  if (const auto *L = unwrapIR<Loop>(IR))
+    return isFunctionInPrintList(L->getHeader()->getParent()->getName());
   llvm_unreachable("Unknown wrapped IR type");
 }
 
@@ -273,23 +278,23 @@ void unwrapAndPrint(raw_ostream &OS, Any IR) {
     return;
   }
 
-  if (const auto **M = llvm::any_cast<const Module *>(&IR)) {
-    printIR(OS, *M);
+  if (const auto *M = unwrapIR<Module>(IR)) {
+    printIR(OS, M);
     return;
   }
 
-  if (const auto **F = llvm::any_cast<const Function *>(&IR)) {
-    printIR(OS, *F);
+  if (const auto *F = unwrapIR<Function>(IR)) {
+    printIR(OS, F);
     return;
   }
 
-  if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
-    printIR(OS, *C);
+  if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
+    printIR(OS, C);
     return;
   }
 
-  if (const auto **L = llvm::any_cast<const Loop *>(&IR)) {
-    printIR(OS, *L);
+  if (const auto *L = unwrapIR<Loop>(IR)) {
+    printIR(OS, L);
     return;
   }
   llvm_unreachable("Unknown wrapped IR type");
@@ -320,13 +325,10 @@ std::string makeHTMLReady(StringRef SR) {
 
 // Return the module when that is the appropriate level of comparison for \p IR.
 const Module *getModuleForComparison(Any IR) {
-  if (const auto **M = llvm::any_cast<const Module *>(&IR))
-    return *M;
-  if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
-    return (*C)
-        ->begin()
-        ->getFunction()
-        .getParent();
+  if (const auto *M = unwrapIR<Module>(IR))
+    return M;
+  if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
+    return C->begin()->getFunction().getParent();
   return nullptr;
 }
 
@@ -339,8 +341,8 @@ bool isInterestingFunction(const Function &F) {
 bool isInteresting(Any IR, StringRef PassID, StringRef PassName) {
   if (isIgnored(PassID) || !isPassInPrintList(PassName))
     return false;
-  if (const auto **F = llvm::any_cast<const Function *>(&IR))
-    return isInterestingFunction(**F);
+  if (const auto *F = unwrapIR<Function>(IR))
+    return isInterestingFunction(*F);
   return true;
 }
 
@@ -662,12 +664,11 @@ template <typename T> void IRComparer<T>::analyzeIR(Any IR, IRDataT<T> &Data) {
     return;
   }
 
-  const Function **FPtr = llvm::any_cast<const Function *>(&IR);
-  const Function *F = FPtr ? *FPtr : nullptr;
+  const auto *F = unwrapIR<Function>(IR);
   if (!F) {
-    const Loop **L = llvm::any_cast<const Loop *>(&IR);
+    const auto *L = unwrapIR<Loop>(IR);
     assert(L && "Unknown IR unit.");
-    F = (*L)->getHeader()->getParent();
+    F = L->getHeader()->getParent();
   }
   assert(F && "Unknown IR unit.");
   generateFunctionData(Data, *F);
@@ -706,21 +707,20 @@ static SmallString<32> getIRFileDisplayName(Any IR) {
   stable_hash NameHash = stable_hash_combine_string(M->getName());
   unsigned int MaxHashWidth = sizeof(stable_hash) * 8 / 4;
   write_hex(ResultStream, NameHash, HexPrintStyle::Lower, MaxHashWidth);
-  if (llvm::any_cast<const Module *>(&IR)) {
+  if (unwrapIR<Module>(IR)) {
     ResultStream << "-module";
-  } else if (const Function **F = llvm::any_cast<const Function *>(&IR)) {
+  } else if (const auto *F = unwrapIR<Function>(IR)) {
     ResultStream << "-function-";
-    stable_hash FunctionNameHash = stable_hash_combine_string((*F)->getName());
+    stable_hash FunctionNameHash = stable_hash_combine_string(F->getName());
     write_hex(ResultStream, FunctionNameHash, HexPrintStyle::Lower,
               MaxHashWidth);
-  } else if (const LazyCallGraph::SCC **C =
-                 llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
+  } else if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
     ResultStream << "-scc-";
-    stable_hash SCCNameHash = stable_hash_combine_string((*C)->getName());
+    stable_hash SCCNameHash = stable_hash_combine_string(C->getName());
     write_hex(ResultStream, SCCNameHash, HexPrintStyle::Lower, MaxHashWidth);
-  } else if (const Loop **L = llvm::any_cast<const Loop *>(&IR)) {
+  } else if (const auto *L = unwrapIR<Loop>(IR)) {
     ResultStream << "-loop-";
-    stable_hash LoopNameHash = stable_hash_combine_string((*L)->getName());
+    stable_hash LoopNameHash = stable_hash_combine_string(L->getName());
     write_hex(ResultStream, LoopNameHash, HexPrintStyle::Lower, MaxHashWidth);
   } else {
     llvm_unreachable("Unknown wrapped IR type");
@@ -975,11 +975,10 @@ void OptNoneInstrumentation::registerCallbacks(
 }
 
 bool OptNoneInstrumentation::shouldRun(StringRef PassID, Any IR) {
-  const Function **FPtr = llvm::any_cast<const Function *>(&IR);
-  const Function *F = FPtr ? *FPtr : nullptr;
+  const auto *F = unwrapIR<Function>(IR);
   if (!F) {
-    if (const auto **L = llvm::any_cast<const Loop *>(&IR))
-      F = (*L)->getHeader()->getParent();
+    if (const auto *L = unwrapIR<Loop>(IR))
+      F = L->getHeader()->getParent();
   }
   bool ShouldRun = !(F && F->hasOptNone());
   if (!ShouldRun && DebugLogging) {
@@ -1054,15 +1053,14 @@ void PrintPassInstrumentation::registerCallbacks(
 
     auto &OS = print();
     OS << "Running pass: " << PassID << " on " << getIRName(IR);
-    if (const auto **F = llvm::any_cast<const Function *>(&IR)) {
-      unsigned Count = (*F)->getInstructionCount();
+    if (const auto *F = unwrapIR<Function>(IR)) {
+      unsigned Count = F->getInstructionCount();
       OS << " (" << Count << " instruction";
       if (Count != 1)
         OS << 's';
       OS << ')';
-    } else if (const auto **C =
-                   llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) {
-      int Count = (*C)->size();
+    } else if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR)) {
+      int Count = C->size();
       OS << " (" << Count << " node";
       if (Count != 1)
         OS << 's';
@@ -1277,10 +1275,10 @@ bool PreservedCFGCheckerInstrumentation::CFG::invalidate(
 static SmallVector<Function *, 1> GetFunctions(Any IR) {
   SmallVector<Function *, 1> Functions;
 
-  if (const auto **MaybeF = llvm::any_cast<const Function *>(&IR)) {
-    Functions.push_back(*const_cast<Function **>(MaybeF));
-  } else if (const auto **MaybeM = llvm::any_cast<const Module *>(&IR)) {
-    for (Function &F : **const_cast<Module **>(MaybeM))
+  if (const auto *MaybeF = unwrapIR<Function>(IR)) {
+    Functions.push_back(const_cast<Function *>(MaybeF));
+  } else if (const auto *MaybeM = unwrapIR<Module>(IR)) {
+    for (Function &F : *const_cast<Module *>(MaybeM))
       Functions.push_back(&F);
   }
   return Functions;
@@ -1315,8 +1313,8 @@ void PreservedCFGCheckerInstrumentation::registerCallbacks(
       FAM.getResult<PreservedFunctionHashAnalysis>(*F);
     }
 
-    if (auto *MaybeM = llvm::any_cast<const Module *>(&IR)) {
-      Module &M = **const_cast<Module **>(MaybeM);
+    if (const auto *MPtr = unwrapIR<Module>(IR)) {
+      auto &M = *const_cast<Module *>(MPtr);
       MAM.getResult<PreservedModuleHashAnalysis>(M);
     }
   });
@@ -1374,8 +1372,8 @@ void PreservedCFGCheckerInstrumentation::registerCallbacks(
         CheckCFG(P, F->getName(), *GraphBefore,
                  CFG(F, /* TrackBBLifetime */ false));
     }
-    if (auto *MaybeM = llvm::any_cast<const Module *>(&IR)) {
-      Module &M = **const_cast<Module **>(MaybeM);
+    if (const auto *MPtr = unwrapIR<Module>(IR)) {
+      auto &M = *const_cast<Module *>(MPtr);
       if (auto *HashBefore =
               MAM.getCachedResult<PreservedModuleHashAnalysis>(M)) {
         if (HashBefore->Hash != StructuralHash(M)) {
@@ -1393,11 +1391,10 @@ void VerifyInstrumentation::registerCallbacks(
       [this](StringRef P, Any IR, const PreservedAnalyses &PassPA) {
         if (isIgnored(P) || P == "VerifierPass")
           return;
-        const Function **FPtr = llvm::any_cast<const Function *>(&IR);
-        const Function *F = FPtr ? *FPtr : nullptr;
+        const auto *F = unwrapIR<Function>(IR);
         if (!F) {
-          if (const auto **L = llvm::any_cast<const Loop *>(&IR))
-            F = (*L)->getHeader()->getParent();
+          if (const auto *L = unwrapIR<Loop>(IR))
+            F = L->getHeader()->getParent();
         }
 
         if (F) {
@@ -1409,12 +1406,10 @@ void VerifyInstrumentation::registerCallbacks(
                                        "\"{0}\", compilation aborted!",
                                        P));
         } else {
-          const Module **MPtr = llvm::any_cast<const Module *>(&IR);
-          const Module *M = MPtr ? *MPtr : nullptr;
+          const auto *M = unwrapIR<Module>(IR);
           if (!M) {
-            if (const auto **C =
-                    llvm::any_cast<const LazyCallGraph::SCC *>(&IR))
-              M = (*C)->begin()->getFunction().getParent();
+            if (const auto *C = unwrapIR<LazyCallGraph::SCC>(IR))
+              M = C->begin()->getFunction().getParent();
           }
 
           if (M) {



More information about the llvm-commits mailing list