[Mlir-commits] [mlir] [MLIR][python bindings] invalidate ops after PassManager run (PR #69746)

Maksim Levental llvmlistbot at llvm.org
Fri Oct 20 15:51:33 PDT 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/69746

>From 8b5c99cdda78df93e9ef7ac167bee0c318094981 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 20 Oct 2023 13:00:36 -0500
Subject: [PATCH] [MLIR][python bindings] invalidate ops after PassManager run

---
 mlir/include/mlir-c/IR.h            | 19 +++++-
 mlir/lib/Bindings/Python/IRCore.cpp |  5 ++
 mlir/lib/Bindings/Python/IRModule.h |  5 ++
 mlir/lib/Bindings/Python/Pass.cpp   | 37 +++++++++--
 mlir/lib/CAPI/IR/IR.cpp             | 15 +++++
 mlir/test/CAPI/ir.c                 | 47 ++++++++++++++
 mlir/test/python/pass_manager.py    | 99 ++++++++++++++++++++++++++++-
 7 files changed, 218 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index e361f33a0d83641..7b121d4df328641 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -73,7 +73,6 @@ DEFINE_C_API_STRUCT(MlirValue, const void);
 ///
 /// A named attribute is essentially a (name, attribute) pair where the name is
 /// a string.
-
 struct MlirNamedAttribute {
   MlirIdentifier name;
   MlirAttribute attribute;
@@ -698,6 +697,24 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
 /// ownership is transferred to the block of the other operation.
 MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
                                                 MlirOperation other);
+
+/// Traversal order used by `mlirOperationWalk`.
+typedef enum MlirWalkOrder {
+  MlirWalkPreOrder,
+  MlirWalkPostOrder
+} MlirWalkOrder;
+
+/// Operation walker type. The handler is passed an (opaque) reference to an
+/// operation and the `userData` pointer given to `mlirOperationWalk`.
+typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
+
+/// Walks operation `op` and every operation nested in it in `walkOrder`,
+/// calling `callback` on each. `userData` is passed to the callback as well
+/// and can be used to tunnel context or other data into the callback.
+MLIR_CAPI_EXPORTED
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder);
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 389a4621c14e594..a8ea1a381edb96e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -635,6 +635,17 @@ size_t PyMlirContext::clearLiveOperations() {
   return numInvalidated;
 }
 
+void PyMlirContext::setOperationInvalid(MlirOperation op) {
+  auto it = liveOperations.find(op.ptr);
+  if (it == liveOperations.end())
+    return;
+  // The operation is being destroyed outside the bindings, so its address may
+  // be reused by a later operation. Erase the entry (not just invalidate it)
+  // so the map can never resolve that address to this stale PyOperation.
+  it->second.second->setInvalid();
+  liveOperations.erase(it);
+}
+
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 pybind11::object PyMlirContext::contextEnter() {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index c5412e735dddcb5..26292885711a4e4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -209,6 +209,11 @@ class PyMlirContext {
   /// place.
   size_t clearLiveOperations();
 
+  /// Marks the Python object for `op`, if one is live, as invalid. This is
+  /// needed when non-bindings code destroys the operation and the bindings
+  /// must be made aware of it, e.g. when the pass manager is run.
+  void setOperationInvalid(MlirOperation op);
+
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
   size_t getLiveModuleCount();
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cdbfcfbc22957a6..2627efe90897770 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -13,6 +13,7 @@
 #include "mlir-c/Pass.h"
 
 namespace py = pybind11;
+using namespace py::literals;
 using namespace mlir;
 using namespace mlir::python;
 
@@ -63,8 +64,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
                  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
              return new PyPassManager(passManager);
            }),
-           py::arg("anchor_op") = py::str("any"),
-           py::arg("context") = py::none(),
+           "anchor_op"_a = py::str("any"), "context"_a = py::none(),
            "Create a new PassManager for the current (or provided) Context.")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyPassManager::getCapsule)
@@ -82,7 +82,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           [](PyPassManager &passManager, bool enable) {
             mlirPassManagerEnableVerifier(passManager.get(), enable);
           },
-          py::arg("enable"), "Enable / disable verify-each.")
+          "enable"_a, "Enable / disable verify-each.")
       .def_static(
           "parse",
           [](const std::string &pipeline, DefaultingPyMlirContext context) {
@@ -96,7 +96,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
               throw py::value_error(std::string(errorMsg.join()));
             return new PyPassManager(passManager);
           },
-          py::arg("pipeline"), py::arg("context") = py::none(),
+          "pipeline"_a, "context"_a = py::none(),
           "Parse a textual pass-pipeline and return a top-level PassManager "
           "that can be applied on a Module. Throw a ValueError if the pipeline "
           "can't be parsed")
@@ -111,12 +111,35 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
             if (mlirLogicalResultIsFailure(status))
               throw py::value_error(std::string(errorMsg.join()));
           },
-          py::arg("pipeline"),
+          "pipeline"_a,
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
       .def(
           "run",
-          [](PyPassManager &passManager, PyOperationBase &op) {
+          [](PyPassManager &passManager, PyOperationBase &op,
+             bool invalidateOps) {
+            if (invalidateOps) {
+              // State tunneled through the C walk callback via `userData`.
+              struct CallbackData {
+                PyMlirContextRef context;
+                bool rootSeen;
+              };
+              CallbackData data{op.getOperation().getContext(), false};
+              // Mark every op nested under the op the pass manager is rooted
+              // at as invalid, since passes may erase or replace them. The
+              // root is visited first in pre-order and is skipped.
+              MlirOperationWalkCallback invalidatingCallback =
+                  [](MlirOperation nestedOp, void *userData) {
+                    auto *state = static_cast<CallbackData *>(userData);
+                    if (LLVM_LIKELY(state->rootSeen))
+                      state->context->setOperationInvalid(nestedOp);
+                    else
+                      state->rootSeen = true;
+                  };
+              mlirOperationWalk(op.getOperation(), invalidatingCallback,
+                                static_cast<void *>(&data), MlirWalkPreOrder);
+            }
+            // Actually run the pass manager.
             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
             MlirLogicalResult status = mlirPassManagerRunOnOp(
                 passManager.get(), op.getOperation().get());
@@ -124,7 +147,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
               throw MLIRError("Failure while executing pass pipeline",
                               errors.take());
           },
-          py::arg("operation"),
+          "operation"_a, "invalidate_ops"_a = true,
           "Run the pass manager on the provided operation, raising an "
           "MLIRError on failure.")
       .def(
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index c1abbbe364611af..0a5151751873f2b 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
 
@@ -705,6 +706,22 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
   return unwrap(op)->moveBefore(unwrap(other));
 }
 
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder) {
+  // Adapts the C callback (plus its opaque `userData`) to the C++ walker.
+  auto forward = [callback, userData](Operation *visited) {
+    callback(wrap(visited), userData);
+  };
+  switch (walkOrder) {
+  case MlirWalkPreOrder:
+    unwrap(op)->walk<mlir::WalkOrder::PreOrder>(forward);
+    break;
+  case MlirWalkPostOrder:
+    unwrap(op)->walk<mlir::WalkOrder::PostOrder>(forward);
+    break;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index a181332e219db8a..d5ca8884306c344 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2210,6 +2210,53 @@ int testSymbolTable(MlirContext ctx) {
   return 0;
 }
 
+typedef struct {
+  const char *x;
+} callBackData;
+
+void walkCallBack(MlirOperation op, void *userData) {
+  MlirStringRef name = mlirIdentifierStr(mlirOperationGetName(op));
+  fprintf(stderr, "%s: %.*s\n", ((callBackData *)(userData))->x,
+          (int)name.length, name.data);
+}
+
+int testOperationWalk(MlirContext ctx) {
+  // CHECK-LABEL: @testOperationWalk
+  fprintf(stderr, "@testOperationWalk\n");
+
+  const char *moduleString = "module {\n"
+                             "  func.func @foo() {\n"
+                             "    %1 = arith.constant 10: i32\n"
+                             "    arith.addi %1, %1: i32\n"
+                             "    return\n"
+                             "  }\n"
+                             "}";
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+
+  callBackData data;
+  data.x = "i love you";
+
+  // CHECK: i love you: arith.constant
+  // CHECK: i love you: arith.addi
+  // CHECK: i love you: func.return
+  // CHECK: i love you: func.func
+  // CHECK: i love you: builtin.module
+  mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
+                    (void *)(&data), MlirWalkPostOrder);
+
+  data.x = "i don't love you";
+  // CHECK: i don't love you: builtin.module
+  // CHECK: i don't love you: func.func
+  // CHECK: i don't love you: arith.constant
+  // CHECK: i don't love you: arith.addi
+  // CHECK: i don't love you: func.return
+  mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
+                    (void *)(&data), MlirWalkPreOrder);
+  mlirModuleDestroy(module);
+  return 0;
+}
+
 int testDialectRegistry(void) {
   fprintf(stderr, "@testDialectRegistry\n");
 
@@ -2349,6 +2396,8 @@ int main(void) {
     return 14;
   if (testDialectRegistry())
     return 15;
+  if (testOperationWalk(ctx))
+    return 16;
 
   testExplicitThreadPools();
   testDiagnostics();
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 4b3a02ac42bd9b1..e7f79ddc75113e0 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -4,6 +4,8 @@
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.dialects.func import FuncOp
+from mlir.dialects.builtin import ModuleOp
+
 
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -33,6 +35,7 @@ def testCapsule():
 
 run(testCapsule)
 
+
 # CHECK-LABEL: TEST: testConstruct
 @run
 def testConstruct():
@@ -68,6 +71,7 @@ def testParseSuccess():
 
 run(testParseSuccess)
 
+
 # Verify successful round-trip.
 # CHECK-LABEL: TEST: testParseSpacedPipeline
 def testParseSpacedPipeline():
@@ -84,6 +88,7 @@ def testParseSpacedPipeline():
 
 run(testParseSpacedPipeline)
 
+
 # Verify failure on unregistered pass.
 # CHECK-LABEL: TEST: testParseFail
 def testParseFail():
@@ -102,6 +107,7 @@ def testParseFail():
 
 run(testParseFail)
 
+
 # Check that adding to a pass manager works
 # CHECK-LABEL: TEST: testAdd
 @run
@@ -147,6 +153,7 @@ def testRunPipeline():
 # CHECK: func.return        , 1
 run(testRunPipeline)
 
+
 # CHECK-LABEL: TEST: testRunPipelineError
 @run
 def testRunPipelineError():
@@ -162,4 +169,106 @@ def testRunPipelineError():
             # CHECK:   error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
             # CHECK:    note: "-":1:1: see current operation: "test.op"() : () -> ()
             # CHECK: >
-            print(f"Exception: <{e}>")
+            log(f"Exception: <{e}>")
+
+
+# CHECK-LABEL: TEST: testPostPassOpInvalidation
+@run
+def testPostPassOpInvalidation():
+    with Context() as ctx:
+        module = ModuleOp.parse(
+            """
+          module {
+            arith.constant 10
+            func.func @foo() {
+              arith.constant 10
+              return
+            }
+          }
+        """
+        )
+
+        # CHECK: invalidate_ops=False
+        log("invalidate_ops=False")
+
+        outer_const_op = module.body.operations[0]
+        # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
+        log(outer_const_op)
+
+        func_op = module.body.operations[1]
+        # CHECK: func.func @[[FOO:.*]]() {
+        # CHECK:   %[[VAL1:.*]] = arith.constant 10 : i64
+        # CHECK:   return
+        # CHECK: }
+        log(func_op)
+
+        inner_const_op = func_op.body.blocks[0].operations[0]
+        # CHECK: %[[VAL1]] = arith.constant 10 : i64
+        log(inner_const_op)
+
+        PassManager.parse("builtin.module(canonicalize)").run(
+            module, invalidate_ops=False
+        )
+        # CHECK: func.func @foo() {
+        # CHECK:   return
+        # CHECK: }
+        log(func_op)
+
+        # CHECK: func.func @foo() {
+        # CHECK:   return
+        # CHECK: }
+        log(module)
+
+        # canonicalize is expected to erase both (unused) constants, but with
+        # invalidate_ops=False their Python objects stay valid. Drop them so
+        # their stale live-operation entries cannot alias operations created
+        # by the next parse.
+        del outer_const_op, inner_const_op
+
+        # CHECK: invalidate_ops=True
+        log("invalidate_ops=True")
+
+        module = ModuleOp.parse(
+            """
+          module {
+            arith.constant 10
+            func.func @foo() {
+              arith.constant 10
+              return
+            }
+          }
+        """
+        )
+        outer_const_op = module.body.operations[0]
+        func_op = module.body.operations[1]
+        inner_const_op = func_op.body.blocks[0].operations[0]
+
+        PassManager.parse("builtin.module(canonicalize)").run(module)
+        try:
+            log(func_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+        else:
+            raise AssertionError("func_op was not invalidated")
+
+        try:
+            log(outer_const_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+        else:
+            raise AssertionError("outer_const_op was not invalidated")
+
+        try:
+            log(inner_const_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+        else:
+            raise AssertionError("inner_const_op was not invalidated")
+
+        # CHECK: func.func @foo() {
+        # CHECK:   return
+        # CHECK: }
+        log(module)



More information about the Mlir-commits mailing list