[flang-commits] [flang] 6a6994c - Reland [StandardInstrumentations] Check function analysis invalidation in module passes as well

Arthur Eubanks via flang-commits flang-commits at lists.llvm.org
Wed Mar 15 13:29:42 PDT 2023


Author: Arthur Eubanks
Date: 2023-03-15T13:29:21-07:00
New Revision: 6a6994cc9bc0327aaf8b005c650ff5eb29d2bcce

URL: https://github.com/llvm/llvm-project/commit/6a6994cc9bc0327aaf8b005c650ff5eb29d2bcce
DIFF: https://github.com/llvm/llvm-project/commit/6a6994cc9bc0327aaf8b005c650ff5eb29d2bcce.diff

LOG: Reland [StandardInstrumentations] Check function analysis invalidation in module passes as well

See the code comments in StandardInstrumentations.cpp for why we now need to pass in the ModuleAnalysisManager (MAM) instead of the FunctionAnalysisManager (FAM).

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D146160

Added: 
    

Modified: 
    clang/lib/CodeGen/BackendUtil.cpp
    flang/lib/Frontend/FrontendActions.cpp
    llvm/include/llvm/Passes/StandardInstrumentations.h
    llvm/lib/LTO/LTOBackend.cpp
    llvm/lib/LTO/ThinLTOCodeGenerator.cpp
    llvm/lib/Passes/PassBuilderBindings.cpp
    llvm/lib/Passes/StandardInstrumentations.cpp
    llvm/tools/opt/NewPMDriver.cpp
    llvm/unittests/IR/PassManagerTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/BackendUtil.cpp b/clang/lib/CodeGen/BackendUtil.cpp
index 92bef4be9cf8f..485eb3ad2ab85 100644
--- a/clang/lib/CodeGen/BackendUtil.cpp
+++ b/clang/lib/CodeGen/BackendUtil.cpp
@@ -837,7 +837,7 @@ void EmitAssemblyHelper::RunOptimizationPipeline(
       TheModule->getContext(),
       (CodeGenOpts.DebugPassManager || DebugPassStructure),
       /*VerifyEach*/ false, PrintPassOpts);
-  SI.registerCallbacks(PIC, &FAM);
+  SI.registerCallbacks(PIC, &MAM);
   PassBuilder PB(TM.get(), PTO, PGOOpt, &PIC);
 
   if (CodeGenOpts.EnableAssignmentTracking) {

diff  --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp
index a6dac4df25d72..15c2b8b05688c 100644
--- a/flang/lib/Frontend/FrontendActions.cpp
+++ b/flang/lib/Frontend/FrontendActions.cpp
@@ -699,7 +699,7 @@ void CodeGenAction::runOptimizationPipeline(llvm::raw_pwrite_stream &os) {
   std::optional<llvm::PGOOptions> pgoOpt;
   llvm::StandardInstrumentations si(llvmModule->getContext(),
                                     opts.DebugPassManager);
-  si.registerCallbacks(pic, &fam);
+  si.registerCallbacks(pic, &mam);
   llvm::PassBuilder pb(tm.get(), pto, pgoOpt, &pic);
 
   // Attempt to load pass plugins and register their callbacks with PB.

diff  --git a/llvm/include/llvm/Passes/StandardInstrumentations.h b/llvm/include/llvm/Passes/StandardInstrumentations.h
index 3b748f5e9f44c..759b79dc2b0ba 100644
--- a/llvm/include/llvm/Passes/StandardInstrumentations.h
+++ b/llvm/include/llvm/Passes/StandardInstrumentations.h
@@ -153,7 +153,7 @@ class PreservedCFGCheckerInstrumentation {
 #endif
 
   void registerCallbacks(PassInstrumentationCallbacks &PIC,
-                         FunctionAnalysisManager &FAM);
+                         ModuleAnalysisManager &MAM);
 };
 
 // Base class for classes that report changes to the IR.
@@ -574,7 +574,7 @@ class StandardInstrumentations {
   // Register all the standard instrumentation callbacks. If \p FAM is nullptr
   // then PreservedCFGChecker is not enabled.
   void registerCallbacks(PassInstrumentationCallbacks &PIC,
-                         FunctionAnalysisManager *FAM = nullptr);
+                         ModuleAnalysisManager *MAM = nullptr);
 
   TimePassesHandler &getTimePasses() { return TimePasses; }
 };

diff  --git a/llvm/lib/LTO/LTOBackend.cpp b/llvm/lib/LTO/LTOBackend.cpp
index 4c41a382276a4..d574ae5737d55 100644
--- a/llvm/lib/LTO/LTOBackend.cpp
+++ b/llvm/lib/LTO/LTOBackend.cpp
@@ -260,7 +260,7 @@ static void runNewPMPasses(const Config &Conf, Module &Mod, TargetMachine *TM,
 
   PassInstrumentationCallbacks PIC;
   StandardInstrumentations SI(Mod.getContext(), Conf.DebugPassManager);
-  SI.registerCallbacks(PIC, &FAM);
+  SI.registerCallbacks(PIC, &MAM);
   PassBuilder PB(TM, Conf.PTO, PGOOpt, &PIC);
 
   RegisterPassPlugins(Conf.PassPlugins, PB);

diff  --git a/llvm/lib/LTO/ThinLTOCodeGenerator.cpp b/llvm/lib/LTO/ThinLTOCodeGenerator.cpp
index 5b137a8f8cb34..595906a44bb44 100644
--- a/llvm/lib/LTO/ThinLTOCodeGenerator.cpp
+++ b/llvm/lib/LTO/ThinLTOCodeGenerator.cpp
@@ -245,7 +245,7 @@ static void optimizeModule(Module &TheModule, TargetMachine &TM,
 
   PassInstrumentationCallbacks PIC;
   StandardInstrumentations SI(TheModule.getContext(), DebugPassManager);
-  SI.registerCallbacks(PIC, &FAM);
+  SI.registerCallbacks(PIC, &MAM);
   PipelineTuningOptions PTO;
   PTO.LoopVectorization = true;
   PTO.SLPVectorization = true;

diff  --git a/llvm/lib/Passes/PassBuilderBindings.cpp b/llvm/lib/Passes/PassBuilderBindings.cpp
index a87c0e6dc0a37..2a49ae6e30e0a 100644
--- a/llvm/lib/Passes/PassBuilderBindings.cpp
+++ b/llvm/lib/Passes/PassBuilderBindings.cpp
@@ -66,7 +66,7 @@ LLVMErrorRef LLVMRunPasses(LLVMModuleRef M, const char *Passes,
   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
 
   StandardInstrumentations SI(Mod->getContext(), Debug, VerifyEach);
-  SI.registerCallbacks(PIC, &FAM);
+  SI.registerCallbacks(PIC, &MAM);
   ModulePassManager MPM;
   if (VerifyEach) {
     MPM.addPass(VerifierPass());

diff  --git a/llvm/lib/Passes/StandardInstrumentations.cpp b/llvm/lib/Passes/StandardInstrumentations.cpp
index 8008986046188..af9751745861d 100644
--- a/llvm/lib/Passes/StandardInstrumentations.cpp
+++ b/llvm/lib/Passes/StandardInstrumentations.cpp
@@ -1075,29 +1075,46 @@ bool PreservedCFGCheckerInstrumentation::CFG::invalidate(
            PAC.preservedSet<CFGAnalyses>());
 }
 
+static SmallVector<Function *, 1> GetFunctions(Any IR) {
+  SmallVector<Function *, 1> Functions;
+
+  if (const auto **MaybeF = any_cast<const Function *>(&IR)) {
+    Functions.push_back(*const_cast<Function **>(MaybeF));
+  } else if (const auto **MaybeM = any_cast<const Module *>(&IR)) {
+    for (Function &F : **const_cast<Module **>(MaybeM))
+      Functions.push_back(&F);
+  }
+  return Functions;
+}
+
 void PreservedCFGCheckerInstrumentation::registerCallbacks(
-    PassInstrumentationCallbacks &PIC, FunctionAnalysisManager &FAM) {
+    PassInstrumentationCallbacks &PIC, ModuleAnalysisManager &MAM) {
   if (!VerifyAnalysisInvalidation)
     return;
 
-  FAM.registerPass([&] { return PreservedCFGCheckerAnalysis(); });
-  FAM.registerPass([&] { return PreservedFunctionHashAnalysis(); });
-
-  PIC.registerBeforeNonSkippedPassCallback(
-      [this, &FAM](StringRef P, Any IR) {
+  bool Registered = false;
+  PIC.registerBeforeNonSkippedPassCallback([this, &MAM, Registered](
+                                               StringRef P, Any IR) mutable {
 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
-        assert(&PassStack.emplace_back(P));
+    assert(&PassStack.emplace_back(P));
 #endif
-        (void)this;
-        const auto **F = any_cast<const Function *>(&IR);
-        if (!F)
-          return;
+    (void)this;
 
-        // Make sure a fresh CFG snapshot is available before the pass.
-        FAM.getResult<PreservedCFGCheckerAnalysis>(*const_cast<Function *>(*F));
-        FAM.getResult<PreservedFunctionHashAnalysis>(
-            *const_cast<Function *>(*F));
-      });
+    auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(
+                       *const_cast<Module *>(unwrapModule(IR, /*Force=*/true)))
+                    .getManager();
+    if (!Registered) {
+      FAM.registerPass([&] { return PreservedCFGCheckerAnalysis(); });
+      FAM.registerPass([&] { return PreservedFunctionHashAnalysis(); });
+      Registered = true;
+    }
+
+    for (Function *F : GetFunctions(IR)) {
+      // Make sure a fresh CFG snapshot is available before the pass.
+      FAM.getResult<PreservedCFGCheckerAnalysis>(*F);
+      FAM.getResult<PreservedFunctionHashAnalysis>(*F);
+    }
+  });
 
   PIC.registerAfterPassInvalidatedCallback(
       [this](StringRef P, const PreservedAnalyses &PassPA) {
@@ -1108,7 +1125,7 @@ void PreservedCFGCheckerInstrumentation::registerCallbacks(
         (void)this;
       });
 
-  PIC.registerAfterPassCallback([this, &FAM](StringRef P, Any IR,
+  PIC.registerAfterPassCallback([this, &MAM](StringRef P, Any IR,
                                              const PreservedAnalyses &PassPA) {
 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
     assert(PassStack.pop_back_val() == P &&
@@ -1116,36 +1133,42 @@ void PreservedCFGCheckerInstrumentation::registerCallbacks(
 #endif
     (void)this;
 
-    const auto **MaybeF = any_cast<const Function *>(&IR);
-    if (!MaybeF)
-      return;
-    Function &F = *const_cast<Function *>(*MaybeF);
-
-    if (auto *HashBefore =
-            FAM.getCachedResult<PreservedFunctionHashAnalysis>(F)) {
-      if (HashBefore->Hash != StructuralHash(F)) {
-        report_fatal_error(formatv(
-            "Function @{0} changed by {1} without invalidating analyses",
-            F.getName(), P));
+    // We have to get the FAM via the MAM, rather than directly use a passed in
+    // FAM because if MAM has not cached the FAM, it won't invalidate function
+    // analyses in FAM.
+    auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(
+                       *const_cast<Module *>(unwrapModule(IR, /*Force=*/true)))
+                    .getManager();
+
+    for (Function *F : GetFunctions(IR)) {
+      if (auto *HashBefore =
+              FAM.getCachedResult<PreservedFunctionHashAnalysis>(*F)) {
+        if (HashBefore->Hash != StructuralHash(*F)) {
+          report_fatal_error(formatv(
+              "Function @{0} changed by {1} without invalidating analyses",
+              F->getName(), P));
+        }
       }
-    }
 
-    auto CheckCFG = [](StringRef Pass, StringRef FuncName,
-                       const CFG &GraphBefore, const CFG &GraphAfter) {
-      if (GraphAfter == GraphBefore)
-        return;
-
-      dbgs() << "Error: " << Pass
-             << " does not invalidate CFG analyses but CFG changes detected in "
-                "function @"
-             << FuncName << ":\n";
-      CFG::printDiff(dbgs(), GraphBefore, GraphAfter);
-      report_fatal_error(Twine("CFG unexpectedly changed by ", Pass));
-    };
-
-    if (auto *GraphBefore = FAM.getCachedResult<PreservedCFGCheckerAnalysis>(F))
-      CheckCFG(P, F.getName(), *GraphBefore,
-               CFG(&F, /* TrackBBLifetime */ false));
+      auto CheckCFG = [](StringRef Pass, StringRef FuncName,
+                         const CFG &GraphBefore, const CFG &GraphAfter) {
+        if (GraphAfter == GraphBefore)
+          return;
+
+        dbgs()
+            << "Error: " << Pass
+            << " does not invalidate CFG analyses but CFG changes detected in "
+               "function @"
+            << FuncName << ":\n";
+        CFG::printDiff(dbgs(), GraphBefore, GraphAfter);
+        report_fatal_error(Twine("CFG unexpectedly changed by ", Pass));
+      };
+
+      if (auto *GraphBefore =
+              FAM.getCachedResult<PreservedCFGCheckerAnalysis>(*F))
+        CheckCFG(P, F->getName(), *GraphBefore,
+                 CFG(F, /* TrackBBLifetime */ false));
+    }
   });
 }
 
@@ -2175,7 +2198,7 @@ void PrintCrashIRInstrumentation::registerCallbacks(
 }
 
 void StandardInstrumentations::registerCallbacks(
-    PassInstrumentationCallbacks &PIC, FunctionAnalysisManager *FAM) {
+    PassInstrumentationCallbacks &PIC, ModuleAnalysisManager *MAM) {
   PrintIR.registerCallbacks(PIC);
   PrintPass.registerCallbacks(PIC);
   TimePasses.registerCallbacks(PIC);
@@ -2189,8 +2212,8 @@ void StandardInstrumentations::registerCallbacks(
   WebsiteChangeReporter.registerCallbacks(PIC);
   ChangeTester.registerCallbacks(PIC);
   PrintCrashIR.registerCallbacks(PIC);
-  if (FAM)
-    PreservedCFGChecker.registerCallbacks(PIC, *FAM);
+  if (MAM)
+    PreservedCFGChecker.registerCallbacks(PIC, *MAM);
 
   // TimeProfiling records the pass running time cost.
   // Its 'BeforePassCallback' can be appended at the tail of all the

diff  --git a/llvm/tools/opt/NewPMDriver.cpp b/llvm/tools/opt/NewPMDriver.cpp
index 697f2649d20b0..bcf6a7f3f9aa8 100644
--- a/llvm/tools/opt/NewPMDriver.cpp
+++ b/llvm/tools/opt/NewPMDriver.cpp
@@ -395,7 +395,7 @@ bool llvm::runPassPipeline(StringRef Arg0, Module &M, TargetMachine *TM,
   PrintPassOpts.SkipAnalyses = DebugPM == DebugLogging::Quiet;
   StandardInstrumentations SI(M.getContext(), DebugPM != DebugLogging::None,
                               VerifyEachPass, PrintPassOpts);
-  SI.registerCallbacks(PIC, &FAM);
+  SI.registerCallbacks(PIC, &MAM);
   DebugifyEachInstrumentation Debugify;
   DebugifyStatsMap DIStatsMap;
   DebugInfoPerPass DebugInfoBeforePass;

diff  --git a/llvm/unittests/IR/PassManagerTest.cpp b/llvm/unittests/IR/PassManagerTest.cpp
index 4c8be3702bf0f..dda1d0c7bbd80 100644
--- a/llvm/unittests/IR/PassManagerTest.cpp
+++ b/llvm/unittests/IR/PassManagerTest.cpp
@@ -824,10 +824,13 @@ TEST_F(PassManagerTest, FunctionPassCFGChecker) {
 
   auto *F = M->getFunction("foo");
   FunctionAnalysisManager FAM;
+  ModuleAnalysisManager MAM;
   FunctionPassManager FPM;
   PassInstrumentationCallbacks PIC;
   StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true);
-  SI.registerCallbacks(PIC, &FAM);
+  SI.registerCallbacks(PIC, &MAM);
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
+  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
   FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
   FAM.registerPass([&] { return DominatorTreeAnalysis(); });
   FAM.registerPass([&] { return AssumptionAnalysis(); });
@@ -870,10 +873,13 @@ TEST_F(PassManagerTest, FunctionPassCFGCheckerInvalidateAnalysis) {
 
   auto *F = M->getFunction("foo");
   FunctionAnalysisManager FAM;
+  ModuleAnalysisManager MAM;
   FunctionPassManager FPM;
   PassInstrumentationCallbacks PIC;
   StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true);
-  SI.registerCallbacks(PIC, &FAM);
+  SI.registerCallbacks(PIC, &MAM);
+  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
   FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
   FAM.registerPass([&] { return DominatorTreeAnalysis(); });
   FAM.registerPass([&] { return AssumptionAnalysis(); });
@@ -935,10 +941,13 @@ TEST_F(PassManagerTest, FunctionPassCFGCheckerWrapped) {
 
   auto *F = M->getFunction("foo");
   FunctionAnalysisManager FAM;
+  ModuleAnalysisManager MAM;
   FunctionPassManager FPM;
   PassInstrumentationCallbacks PIC;
   StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ true);
-  SI.registerCallbacks(PIC, &FAM);
+  SI.registerCallbacks(PIC, &MAM);
+  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
   FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
   FAM.registerPass([&] { return DominatorTreeAnalysis(); });
   FAM.registerPass([&] { return AssumptionAnalysis(); });
@@ -961,7 +970,7 @@ struct WrongFunctionPass : PassInfoMixin<WrongFunctionPass> {
   static StringRef name() { return "WrongFunctionPass"; }
 };
 
-TEST_F(PassManagerTest, FunctionAnalysisMissedInvalidation) {
+TEST_F(PassManagerTest, FunctionPassMissedFunctionAnalysisInvalidation) {
   LLVMContext Context;
   auto M = parseIR(Context, "define void @foo() {\n"
                             "  %a = add i32 0, 0\n"
@@ -969,9 +978,12 @@ TEST_F(PassManagerTest, FunctionAnalysisMissedInvalidation) {
                             "}\n");
 
   FunctionAnalysisManager FAM;
+  ModuleAnalysisManager MAM;
   PassInstrumentationCallbacks PIC;
   StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ false);
-  SI.registerCallbacks(PIC, &FAM);
+  SI.registerCallbacks(PIC, &MAM);
+  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
   FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
 
   FunctionPassManager FPM;
@@ -981,6 +993,39 @@ TEST_F(PassManagerTest, FunctionAnalysisMissedInvalidation) {
   EXPECT_DEATH(FPM.run(*F, FAM), "Function @foo changed by WrongFunctionPass without invalidating analyses");
 }
 
-#endif
+struct WrongModulePass : PassInfoMixin<WrongModulePass> {
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM) {
+    for (Function &F : M)
+      F.getEntryBlock().begin()->eraseFromParent();
+
+    return PreservedAnalyses::all();
+  }
+  static StringRef name() { return "WrongModulePass"; }
+};
+
+TEST_F(PassManagerTest, ModulePassMissedFunctionAnalysisInvalidation) {
+  LLVMContext Context;
+  auto M = parseIR(Context, "define void @foo() {\n"
+                            "  %a = add i32 0, 0\n"
+                            "  ret void\n"
+                            "}\n");
+
+  FunctionAnalysisManager FAM;
+  ModuleAnalysisManager MAM;
+  PassInstrumentationCallbacks PIC;
+  StandardInstrumentations SI(M->getContext(), /*DebugLogging*/ false);
+  SI.registerCallbacks(PIC, &MAM);
+  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
+  FAM.registerPass([&] { return PassInstrumentationAnalysis(&PIC); });
 
+  ModulePassManager MPM;
+  MPM.addPass(WrongModulePass());
+
+  EXPECT_DEATH(
+      MPM.run(*M, MAM),
+      "Function @foo changed by WrongModulePass without invalidating analyses");
+}
+
+#endif
 }


        


More information about the flang-commits mailing list