[Mlir-commits] [mlir] [MLIR][python bindings] invalidate ops after PassManager run (PR #69746)

Maksim Levental llvmlistbot at llvm.org
Fri Oct 20 11:07:03 PDT 2023


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/69746

Fixes https://github.com/llvm/llvm-project/issues/69730 (also see https://reviews.llvm.org/D155543).

>From a7d6f8b3caaf38008cb064f2a39b5abe48344ef6 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 20 Oct 2023 13:00:36 -0500
Subject: [PATCH] [MLIR][python bindings] invalidate ops after PassManager run

---
 mlir/include/mlir-c/IR.h            | 15 +++++++++
 mlir/lib/Bindings/Python/IRCore.cpp |  5 +++
 mlir/lib/Bindings/Python/IRModule.h |  5 +++
 mlir/lib/Bindings/Python/Pass.cpp   | 33 +++++++++++++++----
 mlir/lib/CAPI/IR/IR.cpp             | 15 +++++++++
 mlir/test/python/pass_manager.py    | 51 ++++++++++++++++++++++++++++-
 6 files changed, 116 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index e361f33a0d83641..3163c3cc40c58b1 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -698,6 +698,21 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
 /// ownership is transferred to the block of the other operation.
 MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
                                                 MlirOperation other);
+
+typedef enum MlirWalkOrder {
+  MlirWalkPreOrder,
+  MlirWalkPostOrder
+} MlirWalkOrder;
+
+typedef void (*MlirOperationWalkCallback)(MlirOperation, void *);
+
+/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
+/// `*userData` is passed to the callback as well and can be used to tunnel
+/// some context or other data into the callback.
+MLIR_CAPI_EXPORTED
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder);
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 389a4621c14e594..a8ea1a381edb96e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -635,6 +635,11 @@ size_t PyMlirContext::clearLiveOperations() {
   return numInvalidated;
 }
 
+void PyMlirContext::setOperationInvalid(MlirOperation op) {
+  if (liveOperations.contains(op.ptr))
+    liveOperations[op.ptr].second->setInvalid();
+}
+
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 pybind11::object PyMlirContext::contextEnter() {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index c5412e735dddcb5..26292885711a4e4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -209,6 +209,11 @@ class PyMlirContext {
   /// place.
   size_t clearLiveOperations();
 
+  /// Sets an operation invalid. This is useful for when some non-bindings
+  /// code destroys the operation and the bindings need to be made aware. For
+  /// example, in the case when the pass manager is run.
+  void setOperationInvalid(MlirOperation op);
+
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
   size_t getLiveModuleCount();
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cdbfcfbc22957a6..227910c8ccbd0a4 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -13,6 +13,7 @@
 #include "mlir-c/Pass.h"
 
 namespace py = pybind11;
+using namespace py::literals;
 using namespace mlir;
 using namespace mlir::python;
 
@@ -63,8 +64,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
                  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
              return new PyPassManager(passManager);
            }),
-           py::arg("anchor_op") = py::str("any"),
-           py::arg("context") = py::none(),
+           "anchor_op"_a = py::str("any"), "context"_a = py::none(),
            "Create a new PassManager for the current (or provided) Context.")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyPassManager::getCapsule)
@@ -82,7 +82,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           [](PyPassManager &passManager, bool enable) {
             mlirPassManagerEnableVerifier(passManager.get(), enable);
           },
-          py::arg("enable"), "Enable / disable verify-each.")
+          "enable"_a, "Enable / disable verify-each.")
       .def_static(
           "parse",
           [](const std::string &pipeline, DefaultingPyMlirContext context) {
@@ -96,7 +96,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
               throw py::value_error(std::string(errorMsg.join()));
             return new PyPassManager(passManager);
           },
-          py::arg("pipeline"), py::arg("context") = py::none(),
+          "pipeline"_a, "context"_a = py::none(),
           "Parse a textual pass-pipeline and return a top-level PassManager "
           "that can be applied on a Module. Throw a ValueError if the pipeline "
           "can't be parsed")
@@ -111,12 +111,31 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
             if (mlirLogicalResultIsFailure(status))
               throw py::value_error(std::string(errorMsg.join()));
           },
-          py::arg("pipeline"),
+          "pipeline"_a,
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
       .def(
           "run",
-          [](PyPassManager &passManager, PyOperationBase &op) {
+          [](PyPassManager &passManager, PyOperationBase &op,
+             bool invalidateOps) {
+            if (invalidateOps) {
+              // Mark all ops below the op that the passmanager will be rooted
+              // at as invalid.
+              MlirOperationWalkCallback invalidatingCallback =
+                  [](MlirOperation op, void *rootOpVoid) {
+                    PyOperation *rootOp =
+                        static_cast<PyOperation *>(rootOpVoid);
+                    if (!mlirOperationEqual(rootOp->get(), op)) {
+                      mlirOperationDump(op);
+                      rootOp->getOperation().getContext()->setOperationInvalid(
+                          op);
+                    }
+                  };
+              mlirOperationWalk(op.getOperation(), invalidatingCallback,
+                                static_cast<void *>(&op.getOperation()),
+                                MlirWalkPostOrder);
+            }
+            // Actually run the pass manager.
             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
             MlirLogicalResult status = mlirPassManagerRunOnOp(
                 passManager.get(), op.getOperation().get());
@@ -124,7 +143,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
               throw MLIRError("Failure while executing pass pipeline",
                               errors.take());
           },
-          py::arg("operation"),
+          "operation"_a, "invalidate_ops"_a = true,
           "Run the pass manager on the provided operation, raising an "
           "MLIRError on failure.")
       .def(
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index c1abbbe364611af..0a5151751873f2b 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
 
@@ -705,6 +706,20 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
   return unwrap(op)->moveBefore(unwrap(other));
 }
 
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder) {
+  switch (walkOrder) {
+
+  case MlirWalkPreOrder:
+    unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
+        [callback, userData](Operation *op) { callback(wrap(op), userData); });
+    break;
+  case MlirWalkPostOrder:
+    unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
+        [callback, userData](Operation *op) { callback(wrap(op), userData); });
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 4b3a02ac42bd9b1..53533685ba816d1 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -4,6 +4,8 @@
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.dialects.func import FuncOp
+from mlir.dialects.builtin import ModuleOp
+
 
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -33,6 +35,7 @@ def testCapsule():
 
 run(testCapsule)
 
+
 # CHECK-LABEL: TEST: testConstruct
 @run
 def testConstruct():
@@ -68,6 +71,7 @@ def testParseSuccess():
 
 run(testParseSuccess)
 
+
 # Verify successful round-trip.
 # CHECK-LABEL: TEST: testParseSpacedPipeline
 def testParseSpacedPipeline():
@@ -84,6 +88,7 @@ def testParseSpacedPipeline():
 
 run(testParseSpacedPipeline)
 
+
 # Verify failure on unregistered pass.
 # CHECK-LABEL: TEST: testParseFail
 def testParseFail():
@@ -102,6 +107,7 @@ def testParseFail():
 
 run(testParseFail)
 
+
 # Check that adding to a pass manager works
 # CHECK-LABEL: TEST: testAdd
 @run
@@ -147,6 +153,7 @@ def testRunPipeline():
 # CHECK: func.return        , 1
 run(testRunPipeline)
 
+
 # CHECK-LABEL: TEST: testRunPipelineError
 @run
 def testRunPipelineError():
@@ -162,4 +169,46 @@ def testRunPipelineError():
             # CHECK:   error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
             # CHECK:    note: "-":1:1: see current operation: "test.op"() : () -> ()
             # CHECK: >
-            print(f"Exception: <{e}>")
+            log(f"Exception: <{e}>")
+
+
+# CHECK-LABEL: TEST: testPostPassOpInvalidation
+@run
+def testPostPassOpInvalidation():
+    with Context() as ctx:
+        module = ModuleOp.parse(
+            """
+          module {
+            arith.constant 10
+            func.func @foo() {
+              arith.constant 10
+              return
+            }
+          }
+        """
+        )
+        outer_const_op = module.body.operations[0]
+        # CHECK: %c10_i64 = arith.constant 10 : i64
+        log(outer_const_op)
+        inner_const_op = module.body.operations[1].body.blocks[0].operations[0]
+        # CHECK: %c10_i64_0 = arith.constant 10 : i64
+        log(inner_const_op)
+
+        PassManager.parse("builtin.module(canonicalize)").run(module)
+        try:
+            log(outer_const_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+        try:
+            log(inner_const_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+
+        # CHECK:       module {
+        # CHECK-LABEL:   func.func @foo() {
+        # CHECK:           return
+        # CHECK:         }
+        # CHECK:       }
+        log(module)



More information about the Mlir-commits mailing list