[Mlir-commits] [mlir] 02bc4c9 - [mlir][PassManager] Only reinitialize the pass manager if the context registry changes

River Riddle llvmlistbot at llvm.org
Wed Jan 27 17:50:38 PST 2021


Author: River Riddle
Date: 2021-01-27T17:41:51-08:00
New Revision: 02bc4c95f0729cc819776f73ec94a25405579183

URL: https://github.com/llvm/llvm-project/commit/02bc4c95f0729cc819776f73ec94a25405579183
DIFF: https://github.com/llvm/llvm-project/commit/02bc4c95f0729cc819776f73ec94a25405579183.diff

LOG: [mlir][PassManager] Only reinitialize the pass manager if the context registry changes

This prevents needless reinitialization for clients that want to reuse a pass manager across multiple runs. A new `getRegistryHash` function is exposed by the context to give a rough indicator of when the context registry has changed.

Differential Revision: https://reviews.llvm.org/D95493

Added: 
    

Modified: 
    mlir/include/mlir/IR/MLIRContext.h
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Pass/Pass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 4751f00a36df..eace86f9cb7a 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -166,6 +166,12 @@ class MLIRContext {
   Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
                             function_ref<std::unique_ptr<Dialect>()> ctor);
 
+  /// Returns a hash of the registry of the context that may be used to give
+  /// a rough indicator of if the state of the context registry has changed. The
+  /// context registry correlates to loaded dialects and their entities
+  /// (attributes, operations, types, etc.).
+  llvm::hash_code getRegistryHash();
+
 private:
   const std::unique_ptr<MLIRContextImpl> impl;
 

diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index beb6bc99a18e..e73459f835f2 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -375,6 +375,9 @@ class PassManager : public OpPassManager {
   /// An optional factory to use when generating a crash reproducer if valid.
   ReproducerStreamFactory crashReproducerStreamFactory;
 
+  /// A hash key used to detect when reinitialization is necessary.
+  llvm::hash_code initializationKey;
+
   /// Flag that specifies if pass timing is enabled.
   bool passTiming : 1;
 

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 9307d9cd2c6e..d782a854b6f0 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -492,6 +492,16 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
   return dialect.get();
 }
 
+llvm::hash_code MLIRContext::getRegistryHash() {
+  llvm::hash_code hash(0);
+  // Factor in number of loaded dialects, attributes, operations, types.
+  hash = llvm::hash_combine(hash, impl->loadedDialects.size());
+  hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
+  hash = llvm::hash_combine(hash, impl->registeredOperations.size());
+  hash = llvm::hash_combine(hash, impl->registeredTypes.size());
+  return hash;
+}
+
 bool MLIRContext::allowsUnregisteredDialects() {
   return impl->allowUnregisteredDialects;
 }

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 0828941d0267..66c8f66d7c8c 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -846,6 +846,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
 PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
                          StringRef operationName)
     : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
+      initializationKey(DenseMapInfo<llvm::hash_code>::getTombstoneKey()),
       passTiming(false), localReproducer(false), verifyPasses(true) {}
 
 PassManager::~PassManager() {}
@@ -868,7 +869,11 @@ LogicalResult PassManager::run(Operation *op) {
   dependentDialects.loadAll(context);
 
   // Initialize all of the passes within the pass manager with a new generation.
-  initialize(context, impl->initializationGeneration + 1);
+  llvm::hash_code newInitKey = context->getRegistryHash();
+  if (newInitKey != initializationKey) {
+    initialize(context, impl->initializationGeneration + 1);
+    initializationKey = newInitKey;
+  }
 
   // Construct a top level analysis manager for the pipeline.
   ModuleAnalysisManager am(op, instrumentor.get());


        


More information about the Mlir-commits mailing list