[Mlir-commits] [mlir] bdc3e6c - [MLIR][python bindings] invalidate ops after PassManager run (#69746)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 20 18:28:36 PDT 2023
Author: Maksim Levental
Date: 2023-10-20T20:28:32-05:00
New Revision: bdc3e6cb45203ba59e6654da2cb7212ef3a15854
URL: https://github.com/llvm/llvm-project/commit/bdc3e6cb45203ba59e6654da2cb7212ef3a15854
DIFF: https://github.com/llvm/llvm-project/commit/bdc3e6cb45203ba59e6654da2cb7212ef3a15854.diff
LOG: [MLIR][python bindings] invalidate ops after PassManager run (#69746)
Fixes https://github.com/llvm/llvm-project/issues/69730 (also see
https://reviews.llvm.org/D155543).
There are two things outstanding (which is why I didn't land this earlier):
1. add some C API tests for `mlirOperationWalk`;
2. potentially refactor how the invalidation in `run` works; the first
version of the code looked like this:
```cpp
if (invalidateOps) {
auto *context = op.getOperation().getContext().get();
MlirOperationWalkCallback invalidatingCallback =
[](MlirOperation op, void *userData) {
PyMlirContext *context =
static_cast<PyMlirContext *>(userData);
context->setOperationInvalid(op);
};
auto numRegions =
mlirOperationGetNumRegions(op.getOperation().get());
for (int i = 0; i < numRegions; ++i) {
MlirRegion region =
mlirOperationGetRegion(op.getOperation().get(), i);
for (MlirBlock block = mlirRegionGetFirstBlock(region);
!mlirBlockIsNull(block);
block = mlirBlockGetNextInRegion(block))
for (MlirOperation childOp =
mlirBlockGetFirstOperation(block);
!mlirOperationIsNull(childOp);
childOp = mlirOperationGetNextInBlock(childOp))
mlirOperationWalk(childOp, invalidatingCallback, context,
MlirWalkPostOrder);
}
}
```
This is verbose and ugly, but it has the important benefit of not
executing `mlirOperationEqual(rootOp->get(), op)` for every op
underneath the root op.
Supposing there's no desire for the slightly more efficient but highly
convoluted approach, I can land this "posthaste".
But since we have eyes on this now, any suggestions, alternative approaches,
needs, or concerns are welcome.
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/Pass.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c
mlir/test/python/pass_manager.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index e361f33a0d83641..7b121d4df328641 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -73,7 +73,6 @@ DEFINE_C_API_STRUCT(MlirValue, const void);
///
/// A named attribute is essentially a (name, attribute) pair where the name is
/// a string.
-
struct MlirNamedAttribute {
MlirIdentifier name;
MlirAttribute attribute;
@@ -698,6 +697,24 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
/// ownership is transferred to the block of the other operation.
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
MlirOperation other);
+
+/// Traversal order for operation walk.
+typedef enum MlirWalkOrder {
+  /// Visit an operation before the operations nested in its regions.
+  MlirWalkPreOrder,
+  /// Visit an operation after the operations nested in its regions.
+  MlirWalkPostOrder
+} MlirWalkOrder;
+
+/// Operation walker type. The handler is passed an (opaque) reference to an
+/// operation and the `userData` pointer that was given to the walk.
+typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
+
+/// Walks operation `op` and every operation nested within its regions in
+/// `walkOrder`, calling `callback` on each of them (including `op` itself).
+/// `userData` is forwarded unchanged to every invocation of the callback and
+/// can be used to tunnel some context or other data into the callback.
+MLIR_CAPI_EXPORTED
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder);
+
//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 389a4621c14e594..a8ea1a381edb96e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -635,6 +635,11 @@ size_t PyMlirContext::clearLiveOperations() {
return numInvalidated;
}
+/// Marks the Python wrapper of `op` (if one is live) invalid and removes it
+/// from the live-operation map. Removing the entry keeps the map consistent
+/// with `clearLiveOperations` and ensures that a later operation allocated at
+/// the same address is not resolved to the stale, invalidated wrapper.
+void PyMlirContext::setOperationInvalid(MlirOperation op) {
+  // Single lookup instead of contains() followed by operator[].
+  auto it = liveOperations.find(op.ptr);
+  if (it == liveOperations.end())
+    return;
+  it->second.second->setInvalid();
+  liveOperations.erase(it);
+}
+
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
pybind11::object PyMlirContext::contextEnter() {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index c5412e735dddcb5..26292885711a4e4 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -209,6 +209,11 @@ class PyMlirContext {
/// place.
size_t clearLiveOperations();
+  /// Sets an operation invalid. This is useful when some non-bindings code
+  /// destroys the operation and the bindings need to be made aware, for
+  /// example when the pass manager is run. Does nothing if `op` has no live
+  /// Python wrapper.
+  void setOperationInvalid(MlirOperation op);
+
/// Gets the count of live modules associated with this context.
/// Used for testing.
size_t getLiveModuleCount();
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cdbfcfbc22957a6..2175cea79960ca6 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -13,6 +13,7 @@
#include "mlir-c/Pass.h"
namespace py = pybind11;
+using namespace py::literals;
using namespace mlir;
using namespace mlir::python;
@@ -63,8 +64,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
return new PyPassManager(passManager);
}),
- py::arg("anchor_op") = py::str("any"),
- py::arg("context") = py::none(),
+ "anchor_op"_a = py::str("any"), "context"_a = py::none(),
"Create a new PassManager for the current (or provided) Context.")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyPassManager::getCapsule)
@@ -82,7 +82,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
[](PyPassManager &passManager, bool enable) {
mlirPassManagerEnableVerifier(passManager.get(), enable);
},
- py::arg("enable"), "Enable / disable verify-each.")
+ "enable"_a, "Enable / disable verify-each.")
.def_static(
"parse",
[](const std::string &pipeline, DefaultingPyMlirContext context) {
@@ -96,7 +96,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
throw py::value_error(std::string(errorMsg.join()));
return new PyPassManager(passManager);
},
- py::arg("pipeline"), py::arg("context") = py::none(),
+ "pipeline"_a, "context"_a = py::none(),
"Parse a textual pass-pipeline and return a top-level PassManager "
"that can be applied on a Module. Throw a ValueError if the pipeline "
"can't be parsed")
@@ -111,12 +111,35 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
if (mlirLogicalResultIsFailure(status))
throw py::value_error(std::string(errorMsg.join()));
},
- py::arg("pipeline"),
+ "pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
.def(
"run",
- [](PyPassManager &passManager, PyOperationBase &op) {
+ [](PyPassManager &passManager, PyOperationBase &op,
+ bool invalidateOps) {
+ if (invalidateOps) {
+ typedef struct {
+ PyOperation &rootOp;
+ bool rootSeen;
+ } callBackData;
+ callBackData data{op.getOperation(), false};
+ // Mark all ops below the op that the passmanager will be rooted
+ // at (but not op itself - note the preorder) as invalid.
+ MlirOperationWalkCallback invalidatingCallback =
+ [](MlirOperation op, void *userData) {
+ callBackData *data = static_cast<callBackData *>(userData);
+ if (LLVM_LIKELY(data->rootSeen))
+ data->rootOp.getOperation()
+ .getContext()
+ ->setOperationInvalid(op);
+ else
+ data->rootSeen = true;
+ };
+ mlirOperationWalk(op.getOperation(), invalidatingCallback,
+ static_cast<void *>(&data), MlirWalkPreOrder);
+ }
+ // Actually run the pass manager.
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
MlirLogicalResult status = mlirPassManagerRunOnOp(
passManager.get(), op.getOperation().get());
@@ -124,7 +147,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
throw MLIRError("Failure while executing pass pipeline",
errors.take());
},
- py::arg("operation"),
+ "operation"_a, "invalidate_ops"_a = true,
"Run the pass manager on the provided operation, raising an "
"MLIRError on failure.")
.def(
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index c1abbbe364611af..0a5151751873f2b 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
@@ -705,6 +706,20 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
return unwrap(op)->moveBefore(unwrap(other));
}
+/// C API entry point: walks `op` and all operations nested in its regions in
+/// the requested order, forwarding each visited operation and `userData` to
+/// `callback`.
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder) {
+  // Adapt the C callback once. The parameter is named `nested` so that it
+  // does not shadow the `op` argument of this function.
+  auto visit = [callback, userData](Operation *nested) {
+    callback(wrap(nested), userData);
+  };
+  switch (walkOrder) {
+  case MlirWalkPreOrder:
+    unwrap(op)->walk<mlir::WalkOrder::PreOrder>(visit);
+    break;
+  case MlirWalkPostOrder:
+    unwrap(op)->walk<mlir::WalkOrder::PostOrder>(visit);
+    break;
+  }
+}
+
//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index a181332e219db8a..d5ca8884306c344 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2210,6 +2210,51 @@ int testSymbolTable(MlirContext ctx) {
return 0;
}
+/// Payload tunnelled through `mlirOperationWalk` into `walkCallBack`: a
+/// prefix printed in front of every visited operation's name.
+typedef struct {
+  const char *x;
+} callBackData;
+
+/// Walk callback used by testOperationWalk: prints "<prefix>: <op name>" to
+/// stderr. `userData` must point to a `callBackData`.
+static void walkCallBack(MlirOperation op, void *userData) {
+  const callBackData *data = (const callBackData *)userData;
+  MlirStringRef name = mlirIdentifierStr(mlirOperationGetName(op));
+  // MlirStringRef is not guaranteed to be NUL-terminated, so print it with
+  // an explicit length rather than relying on "%s".
+  fprintf(stderr, "%s: %.*s\n", data->x, (int)name.length, name.data);
+}
+
+/// Checks that mlirOperationWalk visits every nested operation in both
+/// post-order and pre-order. Returns 0 on success, non-zero on failure.
+int testOperationWalk(MlirContext ctx) {
+  // CHECK-LABEL: @testOperationWalk
+  fprintf(stderr, "@testOperationWalk\n");
+
+  const char *moduleString = "module {\n"
+                             " func.func @foo() {\n"
+                             " %1 = arith.constant 10: i32\n"
+                             " arith.addi %1, %1: i32\n"
+                             " return\n"
+                             " }\n"
+                             "}";
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  // Fail the test instead of walking a null operation if parsing failed.
+  if (mlirModuleIsNull(module))
+    return 1;
+  MlirOperation moduleOp = mlirModuleGetOperation(module);
+
+  callBackData data;
+  data.x = "i love you";
+
+  // Post-order: nested operations are visited before their parents.
+  // CHECK: i love you: arith.constant
+  // CHECK: i love you: arith.addi
+  // CHECK: i love you: func.return
+  // CHECK: i love you: func.func
+  // CHECK: i love you: builtin.module
+  mlirOperationWalk(moduleOp, walkCallBack, &data, MlirWalkPostOrder);
+
+  data.x = "i don't love you";
+  // Pre-order: parents are visited before the operations nested in them.
+  // CHECK: i don't love you: builtin.module
+  // CHECK: i don't love you: func.func
+  // CHECK: i don't love you: arith.constant
+  // CHECK: i don't love you: arith.addi
+  // CHECK: i don't love you: func.return
+  mlirOperationWalk(moduleOp, walkCallBack, &data, MlirWalkPreOrder);
+
+  // The parsed module is owned by this test; release it.
+  mlirModuleDestroy(module);
+  return 0;
+}
+
int testDialectRegistry(void) {
fprintf(stderr, "@testDialectRegistry\n");
@@ -2349,6 +2394,8 @@ int main(void) {
return 14;
if (testDialectRegistry())
return 15;
+ if (testOperationWalk(ctx))
+ return 16;
testExplicitThreadPools();
testDiagnostics();
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 4b3a02ac42bd9b1..e7f79ddc75113e0 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -4,6 +4,8 @@
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.func import FuncOp
+from mlir.dialects.builtin import ModuleOp
+
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
@@ -33,6 +35,7 @@ def testCapsule():
run(testCapsule)
+
# CHECK-LABEL: TEST: testConstruct
@run
def testConstruct():
@@ -68,6 +71,7 @@ def testParseSuccess():
run(testParseSuccess)
+
# Verify successful round-trip.
# CHECK-LABEL: TEST: testParseSpacedPipeline
def testParseSpacedPipeline():
@@ -84,6 +88,7 @@ def testParseSpacedPipeline():
run(testParseSpacedPipeline)
+
# Verify failure on unregistered pass.
# CHECK-LABEL: TEST: testParseFail
def testParseFail():
@@ -102,6 +107,7 @@ def testParseFail():
run(testParseFail)
+
# Check that adding to a pass manager works
# CHECK-LABEL: TEST: testAdd
@run
@@ -147,6 +153,7 @@ def testRunPipeline():
# CHECK: func.return , 1
run(testRunPipeline)
+
# CHECK-LABEL: TEST: testRunPipelineError
@run
def testRunPipelineError():
@@ -162,4 +169,94 @@ def testRunPipelineError():
# CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
# CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
# CHECK: >
- print(f"Exception: <{e}>")
+ log(f"Exception: <{e}>")
+
+
+# CHECK-LABEL: TEST: testPostPassOpInvalidation
+ at run
+def testPostPassOpInvalidation():
+ with Context() as ctx:
+ module = ModuleOp.parse(
+ """
+ module {
+ arith.constant 10
+ func.func @foo() {
+ arith.constant 10
+ return
+ }
+ }
+ """
+ )
+
+ # CHECK: invalidate_ops=False
+ log("invalidate_ops=False")
+
+ outer_const_op = module.body.operations[0]
+ # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
+ log(outer_const_op)
+
+ func_op = module.body.operations[1]
+ # CHECK: func.func @[[FOO:.*]]() {
+ # CHECK: %[[VAL1:.*]] = arith.constant 10 : i64
+ # CHECK: return
+ # CHECK: }
+ log(func_op)
+
+ inner_const_op = func_op.body.blocks[0].operations[0]
+ # CHECK: %[[VAL1]] = arith.constant 10 : i64
+ log(inner_const_op)
+
+ PassManager.parse("builtin.module(canonicalize)").run(
+ module, invalidate_ops=False
+ )
+ # CHECK: func.func @foo() {
+ # CHECK: return
+ # CHECK: }
+ log(func_op)
+
+ # CHECK: func.func @foo() {
+ # CHECK: return
+ # CHECK: }
+ log(module)
+
+ # CHECK: invalidate_ops=True
+ log("invalidate_ops=True")
+
+ module = ModuleOp.parse(
+ """
+ module {
+ arith.constant 10
+ func.func @foo() {
+ arith.constant 10
+ return
+ }
+ }
+ """
+ )
+ outer_const_op = module.body.operations[0]
+ func_op = module.body.operations[1]
+ inner_const_op = func_op.body.blocks[0].operations[0]
+
+ PassManager.parse("builtin.module(canonicalize)").run(module)
+ try:
+ log(func_op)
+ except RuntimeError as e:
+ # CHECK: the operation has been invalidated
+ log(e)
+
+ try:
+ log(outer_const_op)
+ except RuntimeError as e:
+ # CHECK: the operation has been invalidated
+ log(e)
+
+ try:
+ log(inner_const_op)
+ except RuntimeError as e:
+ # CHECK: the operation has been invalidated
+ log(e)
+
+ # CHECK: func.func @foo() {
+ # CHECK: return
+ # CHECK: }
+ log(module)
More information about the Mlir-commits
mailing list