[Mlir-commits] [mlir] 9c9a431 - [mlir][Pass] Add support for an InterfacePass and pass filtering based on OperationName

River Riddle llvmlistbot at llvm.org
Fri Mar 4 15:14:25 PST 2022


Author: River Riddle
Date: 2022-03-04T15:14:04-08:00
New Revision: 9c9a4317359e40cca6e1dbf4439c0f5cd1afbd7a

URL: https://github.com/llvm/llvm-project/commit/9c9a4317359e40cca6e1dbf4439c0f5cd1afbd7a
DIFF: https://github.com/llvm/llvm-project/commit/9c9a4317359e40cca6e1dbf4439c0f5cd1afbd7a.diff

LOG: [mlir][Pass] Add support for an InterfacePass and pass filtering based on OperationName

This commit adds a new Pass hook, `bool canScheduleOn(RegisteredOperationName)`, that
indicates if the given pass can be scheduled on operations of the given type. This makes it
easier to define constraints on generic passes without a) adding conditional checks to
the beginning of `runOnOperation`, or b) defining a new pass type that forwards
from `runOnOperation` (after checking the invariants) to a new hook. This new hook is
used to implement an `InterfacePass` pass class, which represents a generic pass that
runs on operations of the given interface type.

The PassManager will also verify that passes added to a pass manager can actually be
scheduled on that pass manager, meaning that we will properly error when an InterfacePass
is scheduled on an operation that doesn't actually implement that interface.

Differential Revision: https://reviews.llvm.org/D120791

Added: 
    mlir/test/Pass/interface-pass.mlir
    mlir/test/Pass/invalid-interface-pass.mlir

Modified: 
    mlir/include/mlir/Pass/Pass.h
    mlir/include/mlir/Pass/PassBase.td
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/Pass/Pass.cpp
    mlir/test/lib/Pass/TestPassManager.cpp
    mlir/unittests/Pass/PassManagerTest.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 9b67491993c35..89cebc7014822 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -20,6 +20,7 @@
 namespace mlir {
 namespace detail {
 class OpToOpPassAdaptor;
+struct OpPassManagerImpl;
 
 /// The state for a single execution of a pass. This provides a unified
 /// interface for accessing and initializing necessary state for pass execution.
@@ -184,6 +185,11 @@ class Pass {
   /// pipeline won't execute.
   virtual LogicalResult initialize(MLIRContext *context) { return success(); }
 
+  /// Indicate if the current pass can be scheduled on the given operation type.
+  /// This is useful for generic operation passes to add restrictions on the
+  /// operations they operate on.
+  virtual bool canScheduleOn(RegisteredOperationName opName) const = 0;
+
   /// Schedule an arbitrary pass pipeline on the provided operation.
   /// This can be invoke any time in a pass to dynamic schedule more passes.
   /// The provided operation must be the current one or one nested below.
@@ -313,6 +319,9 @@ class Pass {
   /// Allow access to 'clone'.
   friend class OpPassManager;
 
+  /// Allow access to 'canScheduleOn'.
+  friend detail::OpPassManagerImpl;
+
   /// Allow access to 'passState'.
   friend detail::OpToOpPassAdaptor;
 
@@ -346,6 +355,11 @@ template <typename OpT = void> class OperationPass : public Pass {
     return pass->getOpName() == OpT::getOperationName();
   }
 
+  /// Indicate if the current pass can be scheduled on the given operation type.
+  bool canScheduleOn(RegisteredOperationName opName) const final {
+    return opName.getStringRef() == getOpName();
+  }
+
   /// Return the current operation being transformed.
   OpT getOperation() { return cast<OpT>(Pass::getOperation()); }
 
@@ -373,6 +387,46 @@ template <> class OperationPass<void> : public Pass {
 protected:
   OperationPass(TypeID passID) : Pass(passID) {}
   OperationPass(const OperationPass &) = default;
+
+  /// Indicate if the current pass can be scheduled on the given operation type.
+  /// By default, generic operation passes can be scheduled on any operation.
+  bool canScheduleOn(RegisteredOperationName opName) const override {
+    return true;
+  }
+};
+
+/// Pass to transform an operation that implements the given interface.
+///
+/// Interface passes must not:
+///   - modify any other operations within the parent region, as other threads
+///     may be manipulating them concurrently.
+///   - modify any state within the parent operation, this includes adding
+///     additional operations.
+///
+/// Derived interface passes are expected to provide the following:
+///   - A 'void runOnOperation()' method.
+///   - A 'StringRef getName() const' method.
+///   - A 'std::unique_ptr<Pass> clonePass() const' method.
+template <typename InterfaceT>
+class InterfacePass : public OperationPass<> {
+protected:
+  using OperationPass::OperationPass;
+
+  /// Indicate if the current pass can be scheduled on the given operation type.
+  /// For an InterfacePass, this checks if the operation implements the given
+  /// interface.
+  bool canScheduleOn(RegisteredOperationName opName) const final {
+    return opName.hasInterface<InterfaceT>();
+  }
+
+  /// Return the current operation being transformed.
+  InterfaceT getOperation() { return cast<InterfaceT>(Pass::getOperation()); }
+
+  /// Query an analysis for the current operation.
+  template <typename AnalysisT>
+  AnalysisT &getAnalysis() {
+    return Pass::getAnalysis<AnalysisT, InterfaceT>();
+  }
 };
 
 /// This class provides a CRTP wrapper around a base pass class to define

diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td
index 64ef988fdf941..9b1903acc3019 100644
--- a/mlir/include/mlir/Pass/PassBase.td
+++ b/mlir/include/mlir/Pass/PassBase.td
@@ -92,4 +92,8 @@ class PassBase<string passArg, string base> {
 class Pass<string passArg, string operation = "">
   : PassBase<passArg, "::mlir::OperationPass<" # operation # ">">;
 
+// This class represents an mlir::InterfacePass.
+class InterfacePass<string passArg, string interface>
+  : PassBase<passArg, "::mlir::InterfacePass<" # interface # ">">;
+
 #endif // MLIR_PASS_PASSBASE

diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 1389276c2a196..f2c42bf9140fd 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -98,7 +98,7 @@ class OpPassManager {
   size_t size() const;
 
   /// Return the operation name that this pass manager operates on.
-  StringAttr getOpName(MLIRContext &context) const;
+  OperationName getOpName(MLIRContext &context) const;
 
   /// Return the operation name that this pass manager operates on.
   StringRef getOpName() const;

diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 5cc84dcba7e97..22a9641bde0e8 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -80,8 +80,8 @@ void Pass::printAsTextualPipeline(raw_ostream &os) {
 namespace mlir {
 namespace detail {
 struct OpPassManagerImpl {
-  OpPassManagerImpl(StringAttr identifier, OpPassManager::Nesting nesting)
-      : name(identifier.str()), identifier(identifier),
+  OpPassManagerImpl(OperationName opName, OpPassManager::Nesting nesting)
+      : name(opName.getStringRef()), opName(opName),
         initializationGeneration(0), nesting(nesting) {}
   OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
       : name(name), initializationGeneration(0), nesting(nesting) {}
@@ -102,23 +102,24 @@ struct OpPassManagerImpl {
   /// preserved.
   void clear();
 
-  /// Coalesce adjacent AdaptorPasses into one large adaptor. This runs
-  /// recursively through the pipeline graph.
-  void coalesceAdjacentAdaptorPasses();
+  /// Finalize the pass list in preparation for execution. This includes
+  /// coalescing adjacent pass managers when possible, verifying scheduled
+  /// passes, etc.
+  LogicalResult finalizePassList(MLIRContext *ctx);
 
-  /// Return the operation name of this pass manager as an identifier.
-  StringAttr getOpName(MLIRContext &context) {
-    if (!identifier)
-      identifier = StringAttr::get(&context, name);
-    return *identifier;
+  /// Return the operation name of this pass manager.
+  OperationName getOpName(MLIRContext &context) {
+    if (!opName)
+      opName = OperationName(name, &context);
+    return *opName;
   }
 
   /// The name of the operation that passes of this pass manager operate on.
   std::string name;
 
-  /// The cached identifier (internalized in the context) for the name of the
+  /// The cached OperationName (internalized in the context) for the name of the
   /// operation that passes of this pass manager operate on.
-  Optional<StringAttr> identifier;
+  Optional<OperationName> opName;
 
   /// The set of passes to run as part of this pass manager.
   std::vector<std::unique_ptr<Pass>> passes;
@@ -173,18 +174,12 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
 
 void OpPassManagerImpl::clear() { passes.clear(); }
 
-void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
-  // Bail out early if there are no adaptor passes.
-  if (llvm::none_of(passes, [](std::unique_ptr<Pass> &pass) {
-        return isa<OpToOpPassAdaptor>(pass.get());
-      }))
-    return;
-
+LogicalResult OpPassManagerImpl::finalizePassList(MLIRContext *ctx) {
   // Walk the pass list and merge adjacent adaptors.
   OpToOpPassAdaptor *lastAdaptor = nullptr;
-  for (auto &passe : passes) {
+  for (auto &pass : passes) {
     // Check to see if this pass is an adaptor.
-    if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(passe.get())) {
+    if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get())) {
       // If it is the first adaptor in a possible chain, remember it and
       // continue.
       if (!lastAdaptor) {
@@ -194,25 +189,39 @@ void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
 
       // Otherwise, merge into the existing adaptor and delete the current one.
       currentAdaptor->mergeInto(*lastAdaptor);
-      passe.reset();
+      pass.reset();
     } else if (lastAdaptor) {
-      // If this pass is not an adaptor, then coalesce and forget any existing
+      // If this pass is not an adaptor, then finalize and forget any existing
       // adaptor.
       for (auto &pm : lastAdaptor->getPassManagers())
-        pm.getImpl().coalesceAdjacentAdaptorPasses();
+        if (failed(pm.getImpl().finalizePassList(ctx)))
+          return failure();
       lastAdaptor = nullptr;
     }
   }
 
-  // If there was an adaptor at the end of the manager, coalesce it as well.
+  // If there was an adaptor at the end of the manager, finalize it as well.
   if (lastAdaptor) {
     for (auto &pm : lastAdaptor->getPassManagers())
-      pm.getImpl().coalesceAdjacentAdaptorPasses();
+      if (failed(pm.getImpl().finalizePassList(ctx)))
+        return failure();
   }
 
-  // Now that the adaptors have been merged, erase the empty slot corresponding
+  // Now that the adaptors have been merged, erase any empty slots corresponding
   // to the merged adaptors that were nulled-out in the loop above.
+  Optional<RegisteredOperationName> opName =
+      getOpName(*ctx).getRegisteredInfo();
   llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
+
+  // Verify that all of the passes are valid for the operation.
+  for (std::unique_ptr<Pass> &pass : passes) {
+    if (opName && !pass->canScheduleOn(*opName)) {
+      return emitError(UnknownLoc::get(ctx))
+             << "unable to schedule pass '" << pass->getName()
+             << "' on a PassManager intended to run on '" << name << "'!";
+    }
+  }
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -279,7 +288,7 @@ OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
 StringRef OpPassManager::getOpName() const { return impl->name; }
 
 /// Return the operation name that this pass manager operates on.
-StringAttr OpPassManager::getOpName(MLIRContext &context) const {
+OperationName OpPassManager::getOpName(MLIRContext &context) const {
   return impl->getOpName(context);
 }
 
@@ -367,9 +376,9 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
                 "nested under the current operation the pass is processing";
     assert(pipeline.getOpName() == root->getName().getStringRef());
 
-    // Before running, make sure to coalesce any adjacent pass adaptors in the
-    // pipeline.
-    pipeline.getImpl().coalesceAdjacentAdaptorPasses();
+    // Before running, finalize the passes held by the pipeline.
+    if (failed(pipeline.getImpl().finalizePassList(root->getContext())))
+      return failure();
 
     // Initialize the user provided pipeline and execute the pipeline.
     if (failed(pipeline.initialize(root->getContext(), parentInitGeneration)))
@@ -468,7 +477,7 @@ static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
 /// Find an operation pass manager that can operate on an operation of the given
 /// type, or nullptr if one does not exist.
 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
-                                         StringAttr name,
+                                         OperationName name,
                                          MLIRContext &context) {
   auto *it = llvm::find_if(
       mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
@@ -538,8 +547,7 @@ void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
   for (auto &region : getOperation()->getRegions()) {
     for (auto &block : region) {
       for (auto &op : block) {
-        auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier(),
-                                       *op.getContext());
+        auto *mgr = findPassManagerFor(mgrs, op.getName(), *op.getContext());
         if (!mgr)
           continue;
 
@@ -581,7 +589,7 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
     for (auto &block : region) {
       for (auto &op : block) {
         // Add this operation iff the name matches any of the pass managers.
-        if (findPassManagerFor(mgrs, op.getName().getIdentifier(), *context))
+        if (findPassManagerFor(mgrs, op.getName(), *context))
           opAMPairs.emplace_back(&op, am.nest(&op));
       }
     }
@@ -604,9 +612,8 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
     unsigned pmIndex = it - activePMs.begin();
 
     // Get the pass manager for this operation and execute it.
-    auto *pm =
-        findPassManagerFor(asyncExecutors[pmIndex],
-                           opPMPair.first->getName().getIdentifier(), *context);
+    auto *pm = findPassManagerFor(asyncExecutors[pmIndex],
+                                  opPMPair.first->getName(), *context);
     assert(pm && "expected valid pass manager for operation");
 
     unsigned initGeneration = pm->impl->initializationGeneration;
@@ -641,14 +648,10 @@ void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
 /// Run the passes within this manager on the provided operation.
 LogicalResult PassManager::run(Operation *op) {
   MLIRContext *context = getContext();
-  assert(op->getName().getIdentifier() == getOpName(*context) &&
+  assert(op->getName() == getOpName(*context) &&
         "operation has a different name than the PassManager or is from a "
         "different context");
 
-  // Before running, make sure to coalesce any adjacent pass adaptors in the
-  // pipeline.
-  getImpl().coalesceAdjacentAdaptorPasses();
-
   // Register all dialects for the current pipeline.
   DialectRegistry dependentDialects;
   getDependentDialects(dependentDialects);
@@ -656,6 +659,10 @@ LogicalResult PassManager::run(Operation *op) {
   for (StringRef name : dependentDialects.getDialectNames())
     context->getOrLoadDialect(name);
 
+  // Before running, make sure to finalize the pipeline pass list.
+  if (failed(getImpl().finalizePassList(context)))
+    return failure();
+
   // Initialize all of the passes within the pass manager with a new generation.
   llvm::hash_code newInitKey = context->getRegistryHash();
   if (newInitKey != initializationKey) {

diff --git a/mlir/test/Pass/interface-pass.mlir b/mlir/test/Pass/interface-pass.mlir
new file mode 100644
index 0000000000000..4506dde3d747b
--- /dev/null
+++ b/mlir/test/Pass/interface-pass.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt %s -verify-diagnostics -pass-pipeline='builtin.func(test-interface-pass)' -o /dev/null
+
+// Test that we run the interface pass on the function.
+
+// expected-remark at below {{Executing interface pass on operation}}
+func @main() {
+  return
+}

diff --git a/mlir/test/Pass/invalid-interface-pass.mlir b/mlir/test/Pass/invalid-interface-pass.mlir
new file mode 100644
index 0000000000000..25d5baca0b864
--- /dev/null
+++ b/mlir/test/Pass/invalid-interface-pass.mlir
@@ -0,0 +1,9 @@
+// RUN: not mlir-opt %s -pass-pipeline='test-interface-pass' 2>&1 | FileCheck %s
+
+// Test that we emit an error when an interface pass is added to a pass manager it can't be scheduled on.
+
+// CHECK: unable to schedule pass '{{.*}}' on a PassManager intended to run on 'builtin.module'!
+
+func @main() {
+  return
+}

diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index af429ebeeca54..5a759f39412cc 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -29,6 +29,18 @@ struct TestFunctionPass
     return "Test a function pass in the pass manager";
   }
 };
+class TestInterfacePass
+    : public PassWrapper<TestInterfacePass,
+                         InterfacePass<FunctionOpInterface>> {
+  void runOnOperation() final {
+    getOperation()->emitRemark() << "Executing interface pass on operation";
+  }
+  StringRef getArgument() const final { return "test-interface-pass"; }
+  StringRef getDescription() const final {
+    return "Test an interface pass (running on FunctionOpInterface) in the "
+           "pass manager";
+  }
+};
 class TestOptionsPass
     : public PassWrapper<TestOptionsPass, OperationPass<FuncOp>> {
 public:
@@ -128,6 +140,8 @@ void registerPassManagerTestPass() {
 
   PassRegistration<TestFunctionPass>();
 
+  PassRegistration<TestInterfacePass>();
+
   PassRegistration<TestCrashRecoveryPass>();
   PassRegistration<TestFailurePass>();
 

diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 64a25aed24692..fc085ad1ce7e1 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -81,6 +81,9 @@ struct InvalidPass : Pass {
   InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
   StringRef getName() const override { return "Invalid Pass"; }
   void runOnOperation() override {}
+  bool canScheduleOn(RegisteredOperationName opName) const override {
+    return true;
+  }
 
   /// A clone method to create a copy of this pass.
   std::unique_ptr<Pass> clonePass() const override {


        


More information about the Mlir-commits mailing list