[Mlir-commits] [mlir] 809b440 - Fix MLIR pass manager initialization: hash the pass pipeline to detect when initialization is needed
Mehdi Amini
llvmlistbot at llvm.org
Tue Aug 22 12:55:32 PDT 2023
Author: Mehdi Amini
Date: 2023-08-22T12:55:07-07:00
New Revision: 809b44039555d35096722e32aea2a8df778c303b
URL: https://github.com/llvm/llvm-project/commit/809b44039555d35096722e32aea2a8df778c303b
DIFF: https://github.com/llvm/llvm-project/commit/809b44039555d35096722e32aea2a8df778c303b.diff
LOG: Fix MLIR pass manager initialization: hash the pass pipeline to detect when initialization is needed
The current logic hashes the context's dialect registry to detect registration
changes and re-run the pass initialization. However, it did not check for
changes to the pipeline, so a pass added after a first run would not be
initialized during subsequent runs.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D158377
Added:
Modified:
mlir/include/mlir/Pass/PassManager.h
mlir/lib/Pass/Pass.cpp
mlir/unittests/Pass/PassManagerTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 75fe1524221c1c..d5f1ea0fe0350d 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -172,6 +172,10 @@ class OpPassManager {
/// if a pass manager has already been initialized.
LogicalResult initialize(MLIRContext *context, unsigned newInitGeneration);
+ /// Compute a hash of the pipeline, based on the identity (address) of the
+ /// pass instances it contains, so that changes to the pipeline (e.g. a pass
+ /// being added after a previous run) can be detected. PassManager::run uses
+ /// it to decide whether the passes need to be re-initialized.
+ llvm::hash_code hash();
+
/// A pointer to an internal implementation instance.
std::unique_ptr<detail::OpPassManagerImpl> impl;
@@ -439,9 +443,11 @@ class PassManager : public OpPassManager {
/// generate reproducers.
std::unique_ptr<detail::PassCrashReproducerGenerator> crashReproGenerator;
- /// A hash key used to detect when reinitialization is necessary.
+ /// Hash keys used to detect when reinitialization is necessary.
+ /// `initializationKey` is compared against the context's registry hash and
+ /// `pipelineInitializationKey` against the pipeline hash (see
+ /// OpPassManager::hash); PassManager::run re-initializes the passes when
+ /// either comparison differs. Both start out as the tombstone key.
llvm::hash_code initializationKey =
DenseMapInfo<llvm::hash_code>::getTombstoneKey();
+ llvm::hash_code pipelineInitializationKey =
+ DenseMapInfo<llvm::hash_code>::getTombstoneKey();
/// Flag that specifies if pass timing is enabled.
bool passTiming : 1;
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 3b41cbe48124ea..3b933fde58a137 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Threading.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
@@ -424,6 +425,23 @@ LogicalResult OpPassManager::initialize(MLIRContext *context,
return success();
}
+/// Hash the structure of this pipeline: plain passes contribute their address,
+/// adaptors contribute the (recursive) hashes of their nested pass managers.
+llvm::hash_code OpPassManager::hash() {
+  // Start from an explicit zero: a default-constructed llvm::hash_code leaves
+  // its value uninitialized, which would make the result non-deterministic.
+  llvm::hash_code hashCode{};
+  for (Pass &pass : getPasses()) {
+    // If this pass isn't an adaptor, hash its identity directly.
+    auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
+    if (!adaptor) {
+      hashCode = llvm::hash_combine(hashCode, &pass);
+      continue;
+    }
+    // Otherwise, fold in the hash of each nested pass manager so that passes
+    // added to a nested pipeline are detected too. hash_combine returns the
+    // combined value, so it must be assigned back.
+    for (OpPassManager &adaptorPM : adaptor->getPassManagers())
+      hashCode = llvm::hash_combine(hashCode, adaptorPM.hash());
+  }
+  return hashCode;
+}
+
//===----------------------------------------------------------------------===//
// OpToOpPassAdaptor
//===----------------------------------------------------------------------===//
@@ -825,10 +843,12 @@ LogicalResult PassManager::run(Operation *op) {
// Initialize all of the passes within the pass manager with a new generation.
llvm::hash_code newInitKey = context->getRegistryHash();
- if (newInitKey != initializationKey) {
+  // Re-initialize when either the registry or the pipeline (e.g. a pass was
+  // added since the previous run) has changed.
+  llvm::hash_code pipelineKey = hash();
+  if (newInitKey != initializationKey ||
+      pipelineKey != pipelineInitializationKey) {
if (failed(initialize(context, impl->initializationGeneration + 1)))
return failure();
initializationKey = newInitKey;
+    // Record the pipeline hash we just initialized. Assigning in the other
+    // direction would leave the stored key at its initial value forever, so
+    // every run would needlessly re-initialize all passes.
+    pipelineInitializationKey = pipelineKey;
}
// Construct a top level analysis manager for the pipeline.
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 97349d681c3a0b..70a679125c0ea1 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "gtest/gtest.h"
@@ -144,4 +145,39 @@ TEST(PassManagerTest, InvalidPass) {
"intend to nest?");
}
+/// Module pass that verifies the pass manager initialized it before running:
+/// `initialize()` sets a flag, and `runOnOperation()` emits an error on the
+/// module and signals failure if that flag was never set.
+struct InitializeCheckingPass
+    : public PassWrapper<InitializeCheckingPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass)
+  LogicalResult initialize(MLIRContext *ctx) final {
+    initialized = true;
+    return success();
+  }
+  /// Set once initialize() has been called on this pass instance.
+  bool initialized = false;
+
+  void runOnOperation() override {
+    if (!initialized) {
+      getOperation()->emitError() << "Pass isn't initialized!";
+      signalPassFailure();
+    }
+  }
+};
+
+TEST(PassManagerTest, PassInitialization) {
+  MLIRContext ctx;
+  ctx.allowUnregisteredDialects();
+
+  // An empty module suffices: the checking pass only looks at whether it was
+  // initialized, not at the IR.
+  OwningOpRef<ModuleOp> moduleOp(ModuleOp::create(UnknownLoc::get(&ctx)));
+
+  // First run with a single checking pass.
+  auto passManager = PassManager::on<ModuleOp>(&ctx);
+  passManager.addPass(std::make_unique<InitializeCheckingPass>());
+  EXPECT_TRUE(succeeded(passManager.run(moduleOp.get())));
+
+  // Append another checking pass after the first run; it must also be
+  // initialized before it executes.
+  passManager.addPass(std::make_unique<InitializeCheckingPass>());
+  EXPECT_TRUE(succeeded(passManager.run(moduleOp.get())));
+}
+
} // namespace
More information about the Mlir-commits
mailing list