[Mlir-commits] [mlir] c0b6bc0 - Decouple OpPassManager from the MLIRContext (NFC)

Mehdi Amini llvmlistbot at llvm.org
Wed Sep 2 23:02:16 PDT 2020


Author: Mehdi Amini
Date: 2020-09-03T06:02:05Z
New Revision: c0b6bc070e78cbd20bc4351704f52d85192e8804

URL: https://github.com/llvm/llvm-project/commit/c0b6bc070e78cbd20bc4351704f52d85192e8804
DIFF: https://github.com/llvm/llvm-project/commit/c0b6bc070e78cbd20bc4351704f52d85192e8804.diff

LOG: Decouple OpPassManager from the MLIRContext (NFC)

This allows building an OpPassManager from a StringRef instead of an
Identifier, which enables building pipelines without an MLIRContext.
An Identifier is still cached on demand in the OpPassManager for efficiency
during IR traversal.

Added: 
    

Modified: 
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Pass/PassStatistics.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 8addd9809f90..ec88485cd3ef 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -47,7 +47,8 @@ struct OpPassManagerImpl;
 /// other OpPassManagers or the top-level PassManager.
 class OpPassManager {
 public:
-  OpPassManager(Identifier name, MLIRContext *context, bool verifyPasses);
+  OpPassManager(Identifier name, bool verifyPasses);
+  OpPassManager(StringRef name, bool verifyPasses);
   OpPassManager(OpPassManager &&rhs);
   OpPassManager(const OpPassManager &rhs);
   ~OpPassManager();
@@ -73,7 +74,7 @@ class OpPassManager {
   OpPassManager &nest(Identifier nestedName);
   OpPassManager &nest(StringRef nestedName);
   template <typename OpT> OpPassManager &nest() {
-    return nest(Identifier::get(OpT::getOperationName(), getContext()));
+    return nest(OpT::getOperationName());
   }
 
   /// Add the given pass to this pass manager. If this pass has a concrete
@@ -89,11 +90,11 @@ class OpPassManager {
   /// Returns the number of passes held by this manager.
   size_t size() const;
 
-  /// Return an instance of the context.
-  MLIRContext *getContext() const;
+  /// Return the operation name that this pass manager operates on.
+  Identifier getOpName(MLIRContext &context) const;
 
   /// Return the operation name that this pass manager operates on.
-  Identifier getOpName() const;
+  StringRef getOpName() const;
 
   /// Returns the internal implementation instance.
   detail::OpPassManagerImpl &getImpl();
@@ -151,6 +152,9 @@ class PassManager : public OpPassManager {
   LLVM_NODISCARD
   LogicalResult run(ModuleOp module);
 
+  /// Return an instance of the context.
+  MLIRContext *getContext() const { return context; }
+
   /// Enable support for the pass manager to generate a reproducer on the event
   /// of a crash or a pass failure. `outputFile` is a .mlir filename used to
   /// write the generated reproducer. If `genLocalReproducer` is true, the pass
@@ -304,6 +308,8 @@ class PassManager : public OpPassManager {
   runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
                        ModuleOp module, AnalysisManager am);
 
+  MLIRContext *context;
+
   /// Flag that specifies if pass statistics should be dumped.
   Optional<PassDisplayMode> passStatisticsMode;
 

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index d3cf62574afd..3ac41cde7911 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -92,8 +92,10 @@ void VerifierPass::runOnOperation() {
 namespace mlir {
 namespace detail {
 struct OpPassManagerImpl {
-  OpPassManagerImpl(Identifier name, MLIRContext *ctx, bool verifyPasses)
-      : name(name), context(ctx), verifyPasses(verifyPasses) {}
+  OpPassManagerImpl(Identifier identifier, bool verifyPasses)
+      : name(identifier), identifier(identifier), verifyPasses(verifyPasses) {}
+  OpPassManagerImpl(StringRef name, bool verifyPasses)
+      : name(name), verifyPasses(verifyPasses) {}
 
   /// Merge the passes of this pass manager into the one provided.
   void mergeInto(OpPassManagerImpl &rhs);
@@ -101,9 +103,7 @@ struct OpPassManagerImpl {
   /// Nest a new operation pass manager for the given operation kind under this
   /// pass manager.
   OpPassManager &nest(Identifier nestedName);
-  OpPassManager &nest(StringRef nestedName) {
-    return nest(Identifier::get(nestedName, getContext()));
-  }
+  OpPassManager &nest(StringRef nestedName);
 
   /// Add the given pass to this pass manager. If this pass has a concrete
   /// operation type, it must be the same type as this pass manager.
@@ -117,14 +117,18 @@ struct OpPassManagerImpl {
   /// pass.
   void splitAdaptorPasses();
 
-  /// Return an instance of the context.
-  MLIRContext *getContext() const { return context; }
+  Identifier getOpName(MLIRContext &context) {
+    if (!identifier)
+      identifier = Identifier::get(name, &context);
+    return *identifier;
+  }
 
   /// The name of the operation that passes of this pass manager operate on.
-  Identifier name;
+  StringRef name;
 
-  /// The current context for this pass manager
-  MLIRContext *context;
+  /// The cached identifier (internalized in the context) for the name of the
+  /// operation that passes of this pass manager operate on.
+  Optional<Identifier> identifier;
 
   /// Flag that specifies if the IR should be verified after each pass has run.
   bool verifyPasses : 1;
@@ -143,7 +147,14 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
 }
 
 OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) {
-  OpPassManager nested(nestedName, getContext(), verifyPasses);
+  OpPassManager nested(nestedName, verifyPasses);
+  auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
+  addPass(std::unique_ptr<Pass>(adaptor));
+  return adaptor->getPassManagers().front();
+}
+
+OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) {
+  OpPassManager nested(nestedName, verifyPasses);
   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
   addPass(std::unique_ptr<Pass>(adaptor));
   return adaptor->getPassManagers().front();
@@ -153,7 +164,7 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
  // If this pass runs on a different operation than this pass manager, then
   // implicitly nest a pass manager for this operation.
   auto passOpName = pass->getOpName();
-  if (passOpName && passOpName != name.strref())
+  if (passOpName && passOpName != name)
     return nest(*passOpName).addPass(std::move(pass));
 
   passes.emplace_back(std::move(pass));
@@ -240,14 +251,14 @@ void OpPassManagerImpl::splitAdaptorPasses() {
 // OpPassManager
 //===----------------------------------------------------------------------===//
 
-OpPassManager::OpPassManager(Identifier name, MLIRContext *context,
-                             bool verifyPasses)
-    : impl(new OpPassManagerImpl(name, context, verifyPasses)) {}
+OpPassManager::OpPassManager(Identifier name, bool verifyPasses)
+    : impl(new OpPassManagerImpl(name, verifyPasses)) {}
+OpPassManager::OpPassManager(StringRef name, bool verifyPasses)
+    : impl(new OpPassManagerImpl(name, verifyPasses)) {}
 OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}
 OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
 OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
-  impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->getContext(),
-                                   rhs.impl->verifyPasses));
+  impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses));
   for (auto &pass : rhs.impl->passes)
     impl->passes.emplace_back(pass->clone());
   return *this;
@@ -290,11 +301,13 @@ size_t OpPassManager::size() const { return impl->passes.size(); }
 /// Returns the internal implementation instance.
 OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
 
-/// Return an instance of the context.
-MLIRContext *OpPassManager::getContext() const { return impl->getContext(); }
+/// Return the operation name that this pass manager operates on.
+StringRef OpPassManager::getOpName() const { return impl->name; }
 
 /// Return the operation name that this pass manager operates on.
-Identifier OpPassManager::getOpName() const { return impl->name; }
+Identifier OpPassManager::getOpName(MLIRContext &context) const {
+  return impl->getOpName(context);
+}
 
 /// Prints out the given passes as the textual representation of a pipeline.
 static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
@@ -389,12 +402,22 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
 /// Find an operation pass manager that can operate on an operation of the given
 /// type, or nullptr if one does not exist.
 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
-                                         Identifier name) {
+                                         StringRef name) {
   auto it = llvm::find_if(
       mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
   return it == mgrs.end() ? nullptr : &*it;
 }
 
+/// Find an operation pass manager that can operate on an operation of the given
+/// type, or nullptr if one does not exist.
+static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
+                                         Identifier name,
+                                         MLIRContext &context) {
+  auto it = llvm::find_if(
+      mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
+  return it == mgrs.end() ? nullptr : &*it;
+}
+
 OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
   mgrs.emplace_back(std::move(mgr));
 }
@@ -421,8 +444,7 @@ void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
   // After coalescing, sort the pass managers within rhs by name.
   llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(),
                        [](const OpPassManager *lhs, const OpPassManager *rhs) {
-                         return lhs->getOpName().strref().compare(
-                             rhs->getOpName().strref());
+                         return lhs->getOpName().compare(rhs->getOpName());
                        });
 }
 
@@ -454,16 +476,18 @@ void OpToOpPassAdaptor::runOnOperationImpl() {
   for (auto &region : getOperation()->getRegions()) {
     for (auto &block : region) {
       for (auto &op : block) {
-        auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier());
+        auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier(),
+                                       *op.getContext());
         if (!mgr)
           continue;
+        Identifier opName = mgr->getOpName(*getOperation()->getContext());
 
         // Run the held pipeline over the current operation.
         if (instrumentor)
-          instrumentor->runBeforePipeline(mgr->getOpName(), parentInfo);
+          instrumentor->runBeforePipeline(opName, parentInfo);
         auto result = runPipeline(mgr->getPasses(), &op, am.nest(&op));
         if (instrumentor)
-          instrumentor->runAfterPipeline(mgr->getOpName(), parentInfo);
+          instrumentor->runAfterPipeline(opName, parentInfo);
 
         if (failed(result))
           return signalPassFailure();
@@ -499,7 +523,8 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
     for (auto &block : region) {
       for (auto &op : block) {
         // Add this operation iff the name matches any of the pass managers.
-        if (findPassManagerFor(mgrs, op.getName().getIdentifier()))
+        if (findPassManagerFor(mgrs, op.getName().getIdentifier(),
+                               getContext()))
           opAMPairs.emplace_back(&op, am.nest(&op));
       }
     }
@@ -535,16 +560,17 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
 
           // Get the pass manager for this operation and execute it.
           auto &it = opAMPairs[nextID];
-          auto *pm =
-              findPassManagerFor(pms, it.first->getName().getIdentifier());
+          auto *pm = findPassManagerFor(
+              pms, it.first->getName().getIdentifier(), getContext());
           assert(pm && "expected valid pass manager for operation");
 
+          Identifier opName = pm->getOpName(*getOperation()->getContext());
           if (instrumentor)
-            instrumentor->runBeforePipeline(pm->getOpName(), parentInfo);
+            instrumentor->runBeforePipeline(opName, parentInfo);
           auto pipelineResult =
               runPipeline(pm->getPasses(), it.first, it.second);
           if (instrumentor)
-            instrumentor->runAfterPipeline(pm->getOpName(), parentInfo);
+            instrumentor->runAfterPipeline(opName, parentInfo);
 
           // Drop this thread from being tracked by the diagnostic handler.
           // After this task has finished, the thread may be used outside of
@@ -737,9 +763,9 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
 //===----------------------------------------------------------------------===//
 
 PassManager::PassManager(MLIRContext *ctx, bool verifyPasses)
-    : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx), ctx,
+    : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx),
                     verifyPasses),
-      passTiming(false), localReproducer(false) {}
+      context(ctx), passTiming(false), localReproducer(false) {}
 
 PassManager::~PassManager() {}
 

diff  --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp
index 3721230b6913..d909c98abf56 100644
--- a/mlir/lib/Pass/PassStatistics.cpp
+++ b/mlir/lib/Pass/PassStatistics.cpp
@@ -116,7 +116,7 @@ static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) {
 
       // Print each of the children passes.
       for (OpPassManager &mgr : mgrs) {
-        auto name = ("'" + mgr.getOpName().strref() + "' Pipeline").str();
+        auto name = ("'" + mgr.getOpName() + "' Pipeline").str();
         printPassEntry(os, indent, name);
         for (Pass &pass : mgr.getPasses())
           printPass(indent + 2, &pass);


        


More information about the Mlir-commits mailing list