[Mlir-commits] [mlir] 1284dc3 - Use an Identifier instead of an OperationName internally for OpPassManager identification (NFC)

Mehdi Amini llvmlistbot at llvm.org
Wed Sep 2 14:46:19 PDT 2020


Author: Mehdi Amini
Date: 2020-09-02T21:46:05Z
New Revision: 1284dc34abd11ce4275ad21c0470ad8c679b59b7

URL: https://github.com/llvm/llvm-project/commit/1284dc34abd11ce4275ad21c0470ad8c679b59b7
DIFF: https://github.com/llvm/llvm-project/commit/1284dc34abd11ce4275ad21c0470ad8c679b59b7.diff

LOG: Use an Identifier instead of an OperationName internally for OpPassManager identification (NFC)

This allows deferring the check for traits to execution time instead of forcing it at pipeline creation.
In particular, this makes pipeline creation tolerant of dialects that are not yet loaded in the context.

Reviewed By: rriddle, GMNGeoffrey

Differential Revision: https://reviews.llvm.org/D86915

Added: 
    

Modified: 
    mlir/include/mlir/Pass/PassInstrumentation.h
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Pass/PassStatistics.cpp
    mlir/lib/Pass/PassTiming.cpp
    mlir/unittests/Pass/PassManagerTest.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h
index dc648b2b0edf..baf230f086fd 100644
--- a/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -9,12 +9,12 @@
 #ifndef MLIR_PASS_PASSINSTRUMENTATION_H_
 #define MLIR_PASS_PASSINSTRUMENTATION_H_
 
+#include "mlir/IR/Identifier.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/TypeID.h"
 
 namespace mlir {
 class Operation;
-class OperationName;
 class Pass;
 
 namespace detail {
@@ -43,13 +43,13 @@ class PassInstrumentation {
   /// A callback to run before a pass pipeline is executed. This function takes
   /// the name of the operation type being operated on, and information related
   /// to the parent that spawned this pipeline.
-  virtual void runBeforePipeline(const OperationName &name,
+  virtual void runBeforePipeline(Identifier name,
                                  const PipelineParentInfo &parentInfo) {}
 
   /// A callback to run after a pass pipeline has executed. This function takes
   /// the name of the operation type being operated on, and information related
   /// to the parent that spawned this pipeline.
-  virtual void runAfterPipeline(const OperationName &name,
+  virtual void runAfterPipeline(Identifier name,
                                 const PipelineParentInfo &parentInfo) {}
 
   /// A callback to run before a pass is executed. This function takes a pointer
@@ -90,12 +90,12 @@ class PassInstrumentor {
 
   /// See PassInstrumentation::runBeforePipeline for details.
   void
-  runBeforePipeline(const OperationName &name,
+  runBeforePipeline(Identifier name,
                     const PassInstrumentation::PipelineParentInfo &parentInfo);
 
   /// See PassInstrumentation::runAfterPipeline for details.
   void
-  runAfterPipeline(const OperationName &name,
+  runAfterPipeline(Identifier name,
                    const PassInstrumentation::PipelineParentInfo &parentInfo);
 
   /// See PassInstrumentation::runBeforePass for details.

diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index e19a1fab7f13..8addd9809f90 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -26,9 +26,9 @@ class Any;
 
 namespace mlir {
 class AnalysisManager;
+class Identifier;
 class MLIRContext;
 class ModuleOp;
-class OperationName;
 class Operation;
 class Pass;
 class PassInstrumentation;
@@ -47,7 +47,7 @@ struct OpPassManagerImpl;
 /// other OpPassManagers or the top-level PassManager.
 class OpPassManager {
 public:
-  OpPassManager(OperationName name, bool verifyPasses);
+  OpPassManager(Identifier name, MLIRContext *context, bool verifyPasses);
   OpPassManager(OpPassManager &&rhs);
   OpPassManager(const OpPassManager &rhs);
   ~OpPassManager();
@@ -70,10 +70,10 @@ class OpPassManager {
 
   /// Nest a new operation pass manager for the given operation kind under this
   /// pass manager.
-  OpPassManager &nest(const OperationName &nestedName);
+  OpPassManager &nest(Identifier nestedName);
   OpPassManager &nest(StringRef nestedName);
   template <typename OpT> OpPassManager &nest() {
-    return nest(OpT::getOperationName());
+    return nest(Identifier::get(OpT::getOperationName(), getContext()));
   }
 
   /// Add the given pass to this pass manager. If this pass has a concrete
@@ -93,7 +93,7 @@ class OpPassManager {
   MLIRContext *getContext() const;
 
   /// Return the operation name that this pass manager operates on.
-  const OperationName &getOpName() const;
+  Identifier getOpName() const;
 
   /// Returns the internal implementation instance.
   detail::OpPassManagerImpl &getImpl();

diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index bb521633b5f3..d3cf62574afd 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -92,17 +92,17 @@ void VerifierPass::runOnOperation() {
 namespace mlir {
 namespace detail {
 struct OpPassManagerImpl {
-  OpPassManagerImpl(OperationName name, bool verifyPasses)
-      : name(name), verifyPasses(verifyPasses) {}
+  OpPassManagerImpl(Identifier name, MLIRContext *ctx, bool verifyPasses)
+      : name(name), context(ctx), verifyPasses(verifyPasses) {}
 
   /// Merge the passes of this pass manager into the one provided.
   void mergeInto(OpPassManagerImpl &rhs);
 
   /// Nest a new operation pass manager for the given operation kind under this
   /// pass manager.
-  OpPassManager &nest(const OperationName &nestedName);
+  OpPassManager &nest(Identifier nestedName);
   OpPassManager &nest(StringRef nestedName) {
-    return nest(OperationName(nestedName, getContext()));
+    return nest(Identifier::get(nestedName, getContext()));
   }
 
   /// Add the given pass to this pass manager. If this pass has a concrete
@@ -118,12 +118,13 @@ struct OpPassManagerImpl {
   void splitAdaptorPasses();
 
   /// Return an instance of the context.
-  MLIRContext *getContext() const {
-    return name.getAbstractOperation()->dialect.getContext();
-  }
+  MLIRContext *getContext() const { return context; }
 
   /// The name of the operation that passes of this pass manager operate on.
-  OperationName name;
+  Identifier name;
+
+  /// The current context for this pass manager
+  MLIRContext *context;
 
   /// Flag that specifies if the IR should be verified after each pass has run.
   bool verifyPasses : 1;
@@ -141,8 +142,8 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
   passes.clear();
 }
 
-OpPassManager &OpPassManagerImpl::nest(const OperationName &nestedName) {
-  OpPassManager nested(nestedName, verifyPasses);
+OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) {
+  OpPassManager nested(nestedName, getContext(), verifyPasses);
   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
   addPass(std::unique_ptr<Pass>(adaptor));
   return adaptor->getPassManagers().front();
@@ -152,7 +153,7 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
   // If this pass runs on a different operation than this pass manager, then
   // implicitly nest a pass manager for this operation.
   auto passOpName = pass->getOpName();
-  if (passOpName && passOpName != name.getStringRef())
+  if (passOpName && passOpName != name.strref())
     return nest(*passOpName).addPass(std::move(pass));
 
   passes.emplace_back(std::move(pass));
@@ -239,19 +240,14 @@ void OpPassManagerImpl::splitAdaptorPasses() {
 // OpPassManager
 //===----------------------------------------------------------------------===//
 
-OpPassManager::OpPassManager(OperationName name, bool verifyPasses)
-    : impl(new OpPassManagerImpl(name, verifyPasses)) {
-  assert(name.getAbstractOperation() &&
-         "OpPassManager can only operate on registered operations");
-  assert(name.getAbstractOperation()->hasProperty(
-             OperationProperty::IsolatedFromAbove) &&
-         "OpPassManager only supports operating on operations marked as "
-         "'IsolatedFromAbove'");
-}
+OpPassManager::OpPassManager(Identifier name, MLIRContext *context,
+                             bool verifyPasses)
+    : impl(new OpPassManagerImpl(name, context, verifyPasses)) {}
 OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}
 OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
 OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
-  impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses));
+  impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->getContext(),
+                                   rhs.impl->verifyPasses));
   for (auto &pass : rhs.impl->passes)
     impl->passes.emplace_back(pass->clone());
   return *this;
@@ -275,7 +271,7 @@ OpPassManager::const_pass_iterator OpPassManager::end() const {
 
 /// Nest a new operation pass manager for the given operation kind under this
 /// pass manager.
-OpPassManager &OpPassManager::nest(const OperationName &nestedName) {
+OpPassManager &OpPassManager::nest(Identifier nestedName) {
   return impl->nest(nestedName);
 }
 OpPassManager &OpPassManager::nest(StringRef nestedName) {
@@ -298,7 +294,7 @@ OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
 MLIRContext *OpPassManager::getContext() const { return impl->getContext(); }
 
 /// Return the operation name that this pass manager operates on.
-const OperationName &OpPassManager::getOpName() const { return impl->name; }
+Identifier OpPassManager::getOpName() const { return impl->name; }
 
 /// Prints out the given passes as the textual representation of a pipeline.
 static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
@@ -336,6 +332,14 @@ void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
 
 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
                                      AnalysisManager am) {
+  if (!op->getName().getAbstractOperation())
+    return op->emitOpError()
+           << "trying to schedule a pass on an unregistered operation";
+  if (!op->getName().getAbstractOperation()->hasProperty(
+          OperationProperty::IsolatedFromAbove))
+    return op->emitOpError() << "trying to schedule a pass on an operation not "
+                                "marked as 'IsolatedFromAbove'";
+
   pass->passState.emplace(op, am);
 
   // Instrument before the pass has run.
@@ -385,7 +389,7 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
 /// Find an operation pass manager that can operate on an operation of the given
 /// type, or nullptr if one does not exist.
 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
-                                         const OperationName &name) {
+                                         Identifier name) {
   auto it = llvm::find_if(
       mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
   return it == mgrs.end() ? nullptr : &*it;
@@ -417,8 +421,8 @@ void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
   // After coalescing, sort the pass managers within rhs by name.
   llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(),
                        [](const OpPassManager *lhs, const OpPassManager *rhs) {
-                         return lhs->getOpName().getStringRef().compare(
-                             rhs->getOpName().getStringRef());
+                         return lhs->getOpName().strref().compare(
+                             rhs->getOpName().strref());
                        });
 }
 
@@ -450,7 +454,7 @@ void OpToOpPassAdaptor::runOnOperationImpl() {
   for (auto &region : getOperation()->getRegions()) {
     for (auto &block : region) {
       for (auto &op : block) {
-        auto *mgr = findPassManagerFor(mgrs, op.getName());
+        auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier());
         if (!mgr)
           continue;
 
@@ -494,8 +498,8 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
   for (auto &region : getOperation()->getRegions()) {
     for (auto &block : region) {
       for (auto &op : block) {
-        // Add this operation iff the name matches the any of the pass managers.
-        if (findPassManagerFor(mgrs, op.getName()))
+        // Add this operation iff the name matches any of the pass managers.
+        if (findPassManagerFor(mgrs, op.getName().getIdentifier()))
           opAMPairs.emplace_back(&op, am.nest(&op));
       }
     }
@@ -531,7 +535,8 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
 
           // Get the pass manager for this operation and execute it.
           auto &it = opAMPairs[nextID];
-          auto *pm = findPassManagerFor(pms, it.first->getName());
+          auto *pm =
+              findPassManagerFor(pms, it.first->getName().getIdentifier());
           assert(pm && "expected valid pass manager for operation");
 
           if (instrumentor)
@@ -732,7 +737,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
 //===----------------------------------------------------------------------===//
 
 PassManager::PassManager(MLIRContext *ctx, bool verifyPasses)
-    : OpPassManager(OperationName(ModuleOp::getOperationName(), ctx),
+    : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx), ctx,
                     verifyPasses),
       passTiming(false), localReproducer(false) {}
 
@@ -870,7 +875,7 @@ PassInstrumentor::~PassInstrumentor() {}
 
 /// See PassInstrumentation::runBeforePipeline for details.
 void PassInstrumentor::runBeforePipeline(
-    const OperationName &name,
+    Identifier name,
     const PassInstrumentation::PipelineParentInfo &parentInfo) {
   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
   for (auto &instr : impl->instrumentations)
@@ -879,7 +884,7 @@ void PassInstrumentor::runBeforePipeline(
 
 /// See PassInstrumentation::runAfterPipeline for details.
 void PassInstrumentor::runAfterPipeline(
-    const OperationName &name,
+    Identifier name,
     const PassInstrumentation::PipelineParentInfo &parentInfo) {
   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
   for (auto &instr : llvm::reverse(impl->instrumentations))

diff --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp
index 6ef0d3bbea6a..3721230b6913 100644
--- a/mlir/lib/Pass/PassStatistics.cpp
+++ b/mlir/lib/Pass/PassStatistics.cpp
@@ -116,7 +116,7 @@ static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) {
 
       // Print each of the children passes.
       for (OpPassManager &mgr : mgrs) {
-        auto name = ("'" + mgr.getOpName().getStringRef() + "' Pipeline").str();
+        auto name = ("'" + mgr.getOpName().strref() + "' Pipeline").str();
         printPassEntry(os, indent, name);
         for (Pass &pass : mgr.getPasses())
           printPass(indent + 2, &pass);

diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp
index 71bf822a864b..e3978751c11c 100644
--- a/mlir/lib/Pass/PassTiming.cpp
+++ b/mlir/lib/Pass/PassTiming.cpp
@@ -165,9 +165,9 @@ struct PassTiming : public PassInstrumentation {
   ~PassTiming() override { print(); }
 
   /// Setup the instrumentation hooks.
-  void runBeforePipeline(const OperationName &name,
+  void runBeforePipeline(Identifier name,
                          const PipelineParentInfo &parentInfo) override;
-  void runAfterPipeline(const OperationName &name,
+  void runAfterPipeline(Identifier name,
                         const PipelineParentInfo &parentInfo) override;
   void runBeforePass(Pass *pass, Operation *) override { startPassTimer(pass); }
   void runAfterPass(Pass *pass, Operation *) override;
@@ -242,15 +242,15 @@ struct PassTiming : public PassInstrumentation {
 };
 } // end anonymous namespace
 
-void PassTiming::runBeforePipeline(const OperationName &name,
+void PassTiming::runBeforePipeline(Identifier name,
                                    const PipelineParentInfo &parentInfo) {
   // We don't actually want to time the pipelines, they gather their total
   // from their held passes.
   getTimer(name.getAsOpaquePointer(), TimerKind::Pipeline,
-           [&] { return ("'" + name.getStringRef() + "' Pipeline").str(); });
+           [&] { return ("'" + name.strref() + "' Pipeline").str(); });
 }
 
-void PassTiming::runAfterPipeline(const OperationName &name,
+void PassTiming::runAfterPipeline(Identifier name,
                                   const PipelineParentInfo &parentInfo) {
   // Pop the timer for the pipeline.
   auto tid = llvm::get_threadid();

diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 29086a2994e8..99d4972ef63c 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -74,4 +74,47 @@ TEST(PassManagerTest, OpSpecificAnalysis) {
   }
 }
 
+namespace {
+struct InvalidPass : Pass {
+  InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
+  StringRef getName() const override { return "Invalid Pass"; }
+  void runOnOperation() override {}
+
+  /// A clone method to create a copy of this pass.
+  std::unique_ptr<Pass> clonePass() const override {
+    return std::make_unique<InvalidPass>(
+        *static_cast<const InvalidPass *>(this));
+  }
+};
+} // anonymous namespace
+
+TEST(PassManagerTest, InvalidPass) {
+  MLIRContext context;
+
+  // Create a module
+  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+
+  // Add a single "invalid_op" operation
+  OpBuilder builder(&module->getBodyRegion());
+  OperationState state(UnknownLoc::get(&context), "invalid_op");
+  builder.insert(Operation::create(state));
+
+  // Register a diagnostic handler to capture the diagnostic so that we can
+  // check it later.
+  std::unique_ptr<Diagnostic> diagnostic;
+  context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
+    diagnostic.reset(new Diagnostic(std::move(diag)));
+  });
+
+  // Instantiate and run our pass.
+  PassManager pm(&context);
+  pm.addPass(std::make_unique<InvalidPass>());
+  LogicalResult result = pm.run(module.get());
+  EXPECT_TRUE(failed(result));
+  ASSERT_TRUE(diagnostic.get() != nullptr);
+  EXPECT_EQ(
+      diagnostic->str(),
+      "'invalid_op' op trying to schedule a pass on an unregistered operation");
+}
+
 } // end namespace


        


More information about the Mlir-commits mailing list