[Mlir-commits] [mlir] 02bc4c9 - [mlir][PassManager] Only reinitialize the pass manager if the context registry changes
River Riddle
llvmlistbot at llvm.org
Wed Jan 27 17:50:38 PST 2021
Author: River Riddle
Date: 2021-01-27T17:41:51-08:00
New Revision: 02bc4c95f0729cc819776f73ec94a25405579183
URL: https://github.com/llvm/llvm-project/commit/02bc4c95f0729cc819776f73ec94a25405579183
DIFF: https://github.com/llvm/llvm-project/commit/02bc4c95f0729cc819776f73ec94a25405579183.diff
LOG: [mlir][PassManager] Only reinitialize the pass manager if the context registry changes
This prevents needless reinitialization for clients that want to reuse a pass manager multiple times. A new `getRegistryHash` function is exposed by the context to give a rough indicator of when the context registry has changed.
Differential Revision: https://reviews.llvm.org/D95493
Added:
Modified:
mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/Pass/PassManager.h
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Pass/Pass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 4751f00a36df..eace86f9cb7a 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -166,6 +166,12 @@ class MLIRContext {
Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
function_ref<std::unique_ptr<Dialect>()> ctor);
+ /// Returns a hash of the context registry that can be used as a rough
+ /// indicator of whether the state of the registry has changed. The context
+ /// registry corresponds to the loaded dialects and their entities
+ /// (attributes, operations, types, etc.). Only the number of entries is
+ /// hashed, so registries of equal size are not distinguished.
+ llvm::hash_code getRegistryHash();
+
private:
const std::unique_ptr<MLIRContextImpl> impl;
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index beb6bc99a18e..e73459f835f2 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -375,6 +375,9 @@ class PassManager : public OpPassManager {
/// An optional factory to use when generating a crash reproducer if valid.
ReproducerStreamFactory crashReproducerStreamFactory;
+ /// A hash key used to detect when reinitialization is necessary.
+ llvm::hash_code initializationKey;
+
/// Flag that specifies if pass timing is enabled.
bool passTiming : 1;
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 9307d9cd2c6e..d782a854b6f0 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -492,6 +492,16 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
return dialect.get();
}
+/// Computes a cheap fingerprint of the context registry. PassManager::run
+/// compares it against the previous value to decide whether passes must be
+/// reinitialized.
+///
+/// Only the *sizes* of the loaded-dialect table and of the registered
+/// attribute/operation/type tables are hashed, not their contents.
+/// NOTE(review): this is a reliable change detector only if those tables are
+/// append-only (entries are never removed or replaced). Confirm that
+/// invariant holds.
+llvm::hash_code MLIRContext::getRegistryHash() {
+ llvm::hash_code hash(0);
+ // Factor in number of loaded dialects, attributes, operations, types.
+ // Each step folds the previous hash in, so the result depends on all four
+ // counts in this fixed order.
+ hash = llvm::hash_combine(hash, impl->loadedDialects.size());
+ hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
+ hash = llvm::hash_combine(hash, impl->registeredOperations.size());
+ hash = llvm::hash_combine(hash, impl->registeredTypes.size());
+ return hash;
+}
+
/// Returns the context's `allowUnregisteredDialects` flag.
bool MLIRContext::allowsUnregisteredDialects() {
return impl->allowUnregisteredDialects;
}
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 0828941d0267..66c8f66d7c8c 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -846,6 +846,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
StringRef operationName)
: OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
+ // Seed the key with a sentinel so that the first run() sees a mismatch
+ // against context->getRegistryHash() and initializes the passes.
+ // NOTE(review): assumes getRegistryHash() never yields the tombstone value.
+ initializationKey(DenseMapInfo<llvm::hash_code>::getTombstoneKey()),
passTiming(false), localReproducer(false), verifyPasses(true) {}
PassManager::~PassManager() {}
@@ -868,7 +869,11 @@ LogicalResult PassManager::run(Operation *op) {
dependentDialects.loadAll(context);
// Initialize all of the passes within the pass manager with a new generation.
- initialize(context, impl->initializationGeneration + 1);
+ llvm::hash_code newInitKey = context->getRegistryHash();
+ if (newInitKey != initializationKey) {
+ initialize(context, impl->initializationGeneration + 1);
+ initializationKey = newInitKey;
+ }
// Construct a top level analysis manager for the pipeline.
ModuleAnalysisManager am(op, instrumentor.get());
More information about the Mlir-commits
mailing list