[Mlir-commits] [mlir] 24685aa - [mlir][python] allow for detaching operations from a block
Alex Zinenko
llvmlistbot at llvm.org
Sun Oct 31 01:42:23 PDT 2021
Author: Alex Zinenko
Date: 2021-10-31T09:42:15+01:00
New Revision: 24685aaeb7371137e74d8290a3cf9c8ad2d544a9
URL: https://github.com/llvm/llvm-project/commit/24685aaeb7371137e74d8290a3cf9c8ad2d544a9
DIFF: https://github.com/llvm/llvm-project/commit/24685aaeb7371137e74d8290a3cf9c8ad2d544a9.diff
LOG: [mlir][python] allow for detaching operations from a block
Provide support for removing an operation from the block that contains it and
moving it back to detached state. This allows for the operation to be moved to
a different block, a common IR manipulation for, e.g., module merging.
Also fix a potential one-past-end iterator dereference in Operation::moveAfter
discovered in the process.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D112700
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/IR/Operation.cpp
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 456aa93b25dc7..ca0c45224f3a5 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -346,6 +346,10 @@ MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op);
/// Takes an operation owned by the caller and destroys it.
MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op);
+/// Removes the given operation from its parent block. The operation is not
+/// destroyed. The ownership of the operation is transferred to the caller,
+/// who is then responsible for either destroying it (mlirOperationDestroy) or
+/// handing it to a block again (e.g. mlirOperationMoveBefore/MoveAfter or
+/// mlirBlockAppendOwnedOperation).
+MLIR_CAPI_EXPORTED void mlirOperationRemoveFromParent(MlirOperation op);
+
/// Checks whether the underlying operation is null.
static inline bool mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
@@ -455,6 +459,19 @@ MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);
/// Verify the operation and return true if it passes, false if it fails.
MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op);
+/// Moves the given operation immediately after the other operation in its
+/// parent block. The given operation may be owned by the caller or by its
+/// current block. The other operation must belong to a block; the C API does
+/// not check this precondition. In any case, the ownership is transferred to
+/// the block of the other operation.
+MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
+ MlirOperation other);
+
+/// Moves the given operation immediately before the other operation in its
+/// parent block. The given operation may be owned by the caller or by its
+/// current block. The other operation must belong to a block; the C API does
+/// not check this precondition. In any case, the ownership is transferred to
+/// the block of the other operation.
+MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
+ MlirOperation other);
//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7abd2a1f6b796..d47d06a3aa75e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -875,6 +875,24 @@ py::object PyOperationBase::getAsm(bool binary,
return fileObject.attr("getvalue")();
}
+/// Moves this operation immediately after `other` in `other`'s parent block.
+/// Raises ValueError if `other` is not in a block. Ownership of the operation
+/// passes to that block, so the operation is marked attached and inherits
+/// `other`'s parent keep-alive reference.
+void PyOperationBase::moveAfter(PyOperationBase &other) {
+  PyOperation &operation = getOperation();
+  PyOperation &otherOp = other.getOperation();
+  operation.checkValid();
+  otherOp.checkValid();
+  // The C API requires `other` to belong to a block and does not check it.
+  if (mlirBlockIsNull(mlirOperationGetBlock(otherOp)))
+    throw py::value_error(
+        "Cannot move an operation relative to an operation that has no "
+        "parent block.");
+  mlirOperationMoveAfter(operation, otherOp);
+  // A previously detached operation is now owned by the block; keep the
+  // `attached` flag in sync with that ownership.
+  if (!operation.isAttached())
+    operation.setAttached();
+  operation.parentKeepAlive = otherOp.parentKeepAlive;
+}
+
+/// Moves this operation immediately before `other` in `other`'s parent block.
+/// Raises ValueError if `other` is not in a block. Ownership of the operation
+/// passes to that block, so the operation is marked attached and inherits
+/// `other`'s parent keep-alive reference.
+void PyOperationBase::moveBefore(PyOperationBase &other) {
+  PyOperation &operation = getOperation();
+  PyOperation &otherOp = other.getOperation();
+  operation.checkValid();
+  otherOp.checkValid();
+  // The C API requires `other` to belong to a block and does not check it.
+  if (mlirBlockIsNull(mlirOperationGetBlock(otherOp)))
+    throw py::value_error(
+        "Cannot move an operation relative to an operation that has no "
+        "parent block.");
+  mlirOperationMoveBefore(operation, otherOp);
+  // A previously detached operation is now owned by the block; keep the
+  // `attached` flag in sync with that ownership.
+  if (!operation.isAttached())
+    operation.setAttached();
+  operation.parentKeepAlive = otherOp.parentKeepAlive;
+}
+
llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
checkValid();
if (!isAttached())
@@ -2185,7 +2203,25 @@ void mlir::python::populateIRCore(py::module &m) {
return mlirOperationVerify(self.getOperation());
},
"Verify the operation and return true if it passes, false if it "
- "fails.");
+ "fails.")
+      // Relative moves: both delegate to PyOperationBase, which also updates
+      // the wrapper's keep-alive reference.
+      .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
+           "Puts self immediately after the other operation in its parent "
+           "block.")
+      .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
+           "Puts self immediately before the other operation in its parent "
+           "block.")
+      .def(
+          "detach_from_parent",
+          [](PyOperationBase &self) {
+            PyOperation &op = self.getOperation();
+            op.checkValid();
+            if (op.isAttached()) {
+              op.detachFromParent();
+              // Hand back the now caller-owned operation as its OpView.
+              return op.createOpView();
+            }
+            throw py::value_error("Detached operation has no parent.");
+          },
+          "Detaches the operation from its parent block.");
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
.def_static("create", &PyOperation::create, py::arg("name"),
@@ -2380,7 +2416,20 @@ void mlir::python::populateIRCore(py::module &m) {
printAccum.getUserData());
return printAccum.join();
},
- "Returns the assembly form of the block.");
+ "Returns the assembly form of the block.")
+      .def(
+          "append",
+          [](PyBlock &self, PyOperationBase &operation) {
+            PyOperation &op = operation.getOperation();
+            // Fail with a Python exception on an invalidated operation before
+            // touching its state, as the other mutation bindings do.
+            op.checkValid();
+            // Appending requires a caller-owned operation: first take it out
+            // of whatever block currently holds it.
+            if (op.isAttached())
+              op.detachFromParent();
+
+            mlirBlockAppendOwnedOperation(self.get(), op.get());
+            op.setAttached(self.getParentOperation().getObject());
+          },
+          "Appends an operation to this block. If the operation is currently "
+          "in another block, it will be moved.");
//----------------------------------------------------------------------------
// Mapping of PyInsertionPoint.
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index dac9486c4e773..73924fc74bdbf 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -399,6 +399,10 @@ class PyOperationBase {
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope);
+ /// Moves the operation immediately after `other` in `other`'s parent block;
+ /// ownership of the operation passes to that block.
+ void moveAfter(PyOperationBase &other);
+ /// Moves the operation immediately before `other` in `other`'s parent block;
+ /// ownership of the operation passes to that block.
+ void moveBefore(PyOperationBase &other);
+
/// Each must provide access to the raw Operation.
virtual PyOperation &getOperation() = 0;
};
@@ -428,6 +432,14 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
createDetached(PyMlirContextRef contextRef, MlirOperation operation,
pybind11::object parentKeepAlive = pybind11::object());
+  /// Takes this operation out of its parent block so that ownership passes
+  /// to the caller (see mlirOperationRemoveFromParent). Clears the `attached`
+  /// flag and drops the parent keep-alive reference.
+  void detachFromParent() {
+    PyOperation &op = getOperation();
+    mlirOperationRemoveFromParent(op);
+    setDetached();
+    parentKeepAlive = pybind11::object();
+  }
+
/// Gets the backing operation.
operator MlirOperation() const { return get(); }
MlirOperation get() const {
@@ -441,10 +453,14 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
}
bool isAttached() { return attached; }
- void setAttached() {
+  /// Marks the operation as owned by a parent block. When `parent` is
+  /// non-null it becomes the parent keep-alive reference held by this
+  /// operation's Python wrapper.
+  void setAttached(pybind11::object parent = pybind11::object()) {
 assert(!attached && "operation already attached");
 attached = true;
+    // An empty `parent` leaves the current keep-alive untouched, so callers
+    // that only flip the flag do not lose their existing parent reference.
+    if (parent)
+      parentKeepAlive = std::move(parent);
 }
+  /// Marks the operation as no longer owned by a parent block.
+  void setDetached() {
+    assert(attached && "operation already detached");
+    attached = false;
+  }
void checkValid() const;
/// Gets the owning block or raises an exception if the operation has no
@@ -495,6 +511,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
pybind11::object parentKeepAlive;
bool attached = true;
bool valid = true;
+
+ friend class PyOperationBase;
};
/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index c738198f75b42..6f617dc19269d 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -338,6 +338,8 @@ MlirOperation mlirOperationClone(MlirOperation op) {
void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
+// Unlinks the operation from its block without erasing it; the caller now
+// owns it.
+void mlirOperationRemoveFromParent(MlirOperation op) {
+  Operation *operation = unwrap(op);
+  operation->remove();
+}
+
bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
return unwrap(op) == unwrap(other);
}
@@ -451,6 +453,14 @@ bool mlirOperationVerify(MlirOperation op) {
return succeeded(verify(unwrap(op)));
}
+// Thin C wrapper over Operation::moveAfter; `other` must be in a block.
+void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
+  unwrap(op)->moveAfter(unwrap(other));
+}
+
+// Thin C wrapper over Operation::moveBefore; `other` must be in a block.
+void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
+  unwrap(op)->moveBefore(unwrap(other));
+}
+
//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index e877449fa83fb..3f1310f73a78b 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -505,7 +505,7 @@ void Operation::moveAfter(Operation *existingOp) {
void Operation::moveAfter(Block *block,
 llvm::iplist<Operation>::iterator iterator) {
 assert(iterator != block->end() && "cannot move after end of block");
- moveBefore(&*std::next(iterator));
+ // std::next(iterator) is block->end() when moving after the last operation,
+ // so forward the iterator instead of dereferencing it.
+ moveBefore(block, std::next(iterator));
}
/// This drops all operand uses from this operation, which is an essential
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9cd4824d68997..c94c22ea53a0b 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -740,3 +740,66 @@ def testOperationLoc():
op = Operation.create("custom.op", loc=loc)
assert op.location == loc
assert op.operation.location == loc
+
+# CHECK-LABEL: TEST: testModuleMerge
+ at run
+def testModuleMerge():
+ with Context():
+ m1 = Module.parse("func private @foo()")
+ m2 = Module.parse("""
+ func private @bar()
+ func private @qux()
+ """)
+ foo = m1.body.operations[0]
+ bar = m2.body.operations[0]
+ qux = m2.body.operations[1]
+ bar.move_before(foo)
+ qux.move_after(foo)
+
+ # CHECK: module
+ # CHECK: func private @bar
+ # CHECK: func private @foo
+ # CHECK: func private @qux
+ print(m1)
+
+ # CHECK: module {
+ # CHECK-NEXT: }
+ print(m2)
+
+
+# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
+ at run
+def testAppendMoveFromAnotherBlock():
+ with Context():
+ m1 = Module.parse("func private @foo()")
+ m2 = Module.parse("func private @bar()")
+ func = m1.body.operations[0]
+ m2.body.append(func)
+
+ # CHECK: module
+ # CHECK: func private @bar
+ # CHECK: func private @foo
+
+ print(m2)
+ # CHECK: module {
+ # CHECK-NEXT: }
+ print(m1)
+
+
+# CHECK-LABEL: TEST: testDetachFromParent
+ at run
+def testDetachFromParent():
+ with Context():
+ m1 = Module.parse("func private @foo()")
+ func = m1.body.operations[0].detach_from_parent()
+
+ try:
+ func.detach_from_parent()
+ except ValueError as e:
+ if "has no parent" not in str(e):
+ raise
+ else:
+ assert False, "expected ValueError when detaching a detached operation"
+
+ print(m1)
+ # CHECK-NOT: func private @foo
More information about the Mlir-commits
mailing list