[Mlir-commits] [mlir] 809b440 - Fix MLIR pass manager initialization: hash the pass pipeline to detect when initialization is needed

Mehdi Amini llvmlistbot at llvm.org
Tue Aug 22 12:55:32 PDT 2023


Author: Mehdi Amini
Date: 2023-08-22T12:55:07-07:00
New Revision: 809b44039555d35096722e32aea2a8df778c303b

URL: https://github.com/llvm/llvm-project/commit/809b44039555d35096722e32aea2a8df778c303b
DIFF: https://github.com/llvm/llvm-project/commit/809b44039555d35096722e32aea2a8df778c303b.diff

LOG: Fix MLIR pass manager initialization: hash the pass pipeline to detect when initialization is needed

The current logic hashes the context to detect registration changes and re-run
the pass initialization. However, it did not check for changes to the
pipeline, so a pass added after the first run would not be initialized
during subsequent runs.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D158377

Added: 
    

Modified: 
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/Pass/Pass.cpp
    mlir/unittests/Pass/PassManagerTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 75fe1524221c1c..d5f1ea0fe0350d 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -172,6 +172,10 @@ class OpPassManager {
   /// if a pass manager has already been initialized.
   LogicalResult initialize(MLIRContext *context, unsigned newInitGeneration);
 
+  /// Compute a hash of the pipeline (including nested pass managers) so that
+  /// changes, such as a pass being added, can be detected between runs.
+  llvm::hash_code hash();
+
   /// A pointer to an internal implementation instance.
   std::unique_ptr<detail::OpPassManagerImpl> impl;
 
@@ -439,9 +443,11 @@ class PassManager : public OpPassManager {
   /// generate reproducers.
   std::unique_ptr<detail::PassCrashReproducerGenerator> crashReproGenerator;
 
-  /// A hash key used to detect when reinitialization is necessary.
+  /// Registry and pipeline hashes; a mismatch triggers reinitialization.
   llvm::hash_code initializationKey =
       DenseMapInfo<llvm::hash_code>::getTombstoneKey();
+  llvm::hash_code pipelineInitializationKey =
+      DenseMapInfo<llvm::hash_code>::getTombstoneKey();
 
   /// Flag that specifies if pass timing is enabled.
   bool passTiming : 1;

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 3b41cbe48124ea..3b933fde58a137 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Threading.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/CommandLine.h"
@@ -424,6 +425,23 @@ LogicalResult OpPassManager::initialize(MLIRContext *context,
   return success();
 }
 
+/// Hash the identity of every pass, recursing into nested pass managers.
+llvm::hash_code OpPassManager::hash() {
+  llvm::hash_code hashCode{}; // hash_code() would leave it uninitialized.
+  for (Pass &pass : getPasses()) {
+    // If this pass isn't an adaptor, directly hash it.
+    auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
+    if (!adaptor) {
+      hashCode = llvm::hash_combine(hashCode, &pass);
+      continue;
+    }
+    // Otherwise, hash recursively each of the adaptor's pass managers.
+    for (OpPassManager &adaptorPM : adaptor->getPassManagers())
+      hashCode = llvm::hash_combine(hashCode, adaptorPM.hash());
+  }
+  return hashCode;
+}
+
 //===----------------------------------------------------------------------===//
 // OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//
@@ -825,10 +843,14 @@ LogicalResult PassManager::run(Operation *op) {
 
   // Initialize all of the passes within the pass manager with a new generation.
   llvm::hash_code newInitKey = context->getRegistryHash();
-  if (newInitKey != initializationKey) {
+  // Reinitialize whenever the context registry or the pipeline has changed.
+  llvm::hash_code pipelineKey = hash();
+  if (newInitKey != initializationKey ||
+      pipelineKey != pipelineInitializationKey) {
     if (failed(initialize(context, impl->initializationGeneration + 1)))
       return failure();
     initializationKey = newInitKey;
+    pipelineInitializationKey = pipelineKey;
   }
 
   // Construct a top level analysis manager for the pipeline.

diff  --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 97349d681c3a0b..70a679125c0ea1 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/Pass.h"
 #include "gtest/gtest.h"
 
@@ -144,4 +145,39 @@ TEST(PassManagerTest, InvalidPass) {
                "intend to nest?");
 }
 
+/// Module pass that fails unless its initialize() hook ran before it.
+struct InitializeCheckingPass
+    : public PassWrapper<InitializeCheckingPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass)
+  LogicalResult initialize(MLIRContext *ctx) final {
+    initialized = true;
+    return success();
+  }
+  bool initialized = false; // Set by initialize().
+
+  void runOnOperation() override {
+    if (!initialized) {
+      getOperation()->emitError() << "Pass isn't initialized!";
+      signalPassFailure();
+    }
+  }
+};
+
+TEST(PassManagerTest, PassInitialization) {
+  MLIRContext context;
+  context.allowUnregisteredDialects();
+
+  // Create a module
+  OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+
+  // Instantiate and run our pass.
+  auto pm = PassManager::on<ModuleOp>(&context);
+  pm.addPass(std::make_unique<InitializeCheckingPass>());
+  EXPECT_TRUE(succeeded(pm.run(module.get())));
+
+  // A second instance added after the first run must also be initialized.
+  pm.addPass(std::make_unique<InitializeCheckingPass>());
+  EXPECT_TRUE(succeeded(pm.run(module.get())));
+}
+
 } // namespace


        


More information about the Mlir-commits mailing list