[Mlir-commits] [mlir] 24685aa - [mlir][python] allow for detaching operations from a block

Alex Zinenko llvmlistbot at llvm.org
Sun Oct 31 01:42:23 PDT 2021


Author: Alex Zinenko
Date: 2021-10-31T09:42:15+01:00
New Revision: 24685aaeb7371137e74d8290a3cf9c8ad2d544a9

URL: https://github.com/llvm/llvm-project/commit/24685aaeb7371137e74d8290a3cf9c8ad2d544a9
DIFF: https://github.com/llvm/llvm-project/commit/24685aaeb7371137e74d8290a3cf9c8ad2d544a9.diff

LOG: [mlir][python] allow for detaching operations from a block

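Provide support for removing an operation from the block that contains it and
returning it to the detached state. This allows the operation to be moved to a
different block, a common IR manipulation for, e.g., module merging. The Python
bindings gain `Operation.detach_from_parent`, `Operation.move_before`,
`Operation.move_after` and `Block.append`, backed by new C API functions.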

Also fix a potential dereference of a past-the-end iterator in
Operation::moveAfter, discovered in the process.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D112700

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/lib/IR/Operation.cpp
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 456aa93b25dc7..ca0c45224f3a5 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -346,6 +346,10 @@ MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op);
 /// Takes an operation owned by the caller and destroys it.
 MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op);

+/// Removes the given operation from its parent block. The operation is not
+/// destroyed. The ownership of the operation is transferred to the caller.
+MLIR_CAPI_EXPORTED void mlirOperationRemoveFromParent(MlirOperation op);
+
 /// Checks whether the underlying operation is null.
 static inline bool mlirOperationIsNull(MlirOperation op) { return !op.ptr; }

@@ -455,6 +459,20 @@ MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);
 /// Verify the operation and return true if it passes, false if it fails.
 MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op);

+/// Moves the given operation immediately after the other operation in the
+/// other operation's block. The given operation may be owned by the caller or
+/// by its current block. The other operation must belong to a block. In any
+/// case, the ownership is transferred to the block of the other operation.
+MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
+                                               MlirOperation other);
+
+/// Moves the given operation immediately before the other operation in the
+/// other operation's block. The given operation may be owned by the caller or
+/// by its current block. The other operation must belong to a block. In any
+/// case, the ownership is transferred to the block of the other operation.
+MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
+                                                MlirOperation other);
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7abd2a1f6b796..d47d06a3aa75e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -875,6 +875,44 @@ py::object PyOperationBase::getAsm(bool binary,
   return fileObject.attr("getvalue")();
 }

+/// Throws py::value_error unless `other` is in a block and `operation` is
+/// either detached or in a block, as required by the C API move functions.
+static void checkMovable(PyOperation &operation, PyOperation &other) {
+  operation.checkValid();
+  other.checkValid();
+  if (mlirBlockIsNull(mlirOperationGetBlock(other)))
+    throw py::value_error("Cannot move relative to an operation that is not "
+                          "in a block.");
+  if (operation.isAttached() &&
+      mlirBlockIsNull(mlirOperationGetBlock(operation)))
+    throw py::value_error("Cannot move an attached operation that is not in "
+                          "a block.");
+}
+
+void PyOperationBase::moveAfter(PyOperationBase &other) {
+  PyOperation &operation = getOperation();
+  PyOperation &otherOp = other.getOperation();
+  checkMovable(operation, otherOp);
+  mlirOperationMoveAfter(operation, otherOp);
+  // The block of `other` now owns the operation, so a previously detached
+  // (caller-owned) operation must be marked as attached.
+  if (!operation.isAttached())
+    operation.setAttached();
+  operation.parentKeepAlive = otherOp.parentKeepAlive;
+}
+
+void PyOperationBase::moveBefore(PyOperationBase &other) {
+  PyOperation &operation = getOperation();
+  PyOperation &otherOp = other.getOperation();
+  checkMovable(operation, otherOp);
+  mlirOperationMoveBefore(operation, otherOp);
+  // The block of `other` now owns the operation, so a previously detached
+  // (caller-owned) operation must be marked as attached.
+  if (!operation.isAttached())
+    operation.setAttached();
+  operation.parentKeepAlive = otherOp.parentKeepAlive;
+}
+
 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
   checkValid();
   if (!isAttached())
@@ -2185,7 +2223,28 @@ void mlir::python::populateIRCore(py::module &m) {
             return mlirOperationVerify(self.getOperation());
           },
           "Verify the operation and return true if it passes, false if it "
-          "fails.");
+          "fails.")
+      .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
+           "Puts self immediately after the other operation in its parent "
+           "block.")
+      .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
+           "Puts self immediately before the other operation in its parent "
+           "block.")
+      .def(
+          "detach_from_parent",
+          [](PyOperationBase &self) {
+            PyOperation &operation = self.getOperation();
+            operation.checkValid();
+            if (!operation.isAttached())
+              throw py::value_error("Detached operation has no parent.");
+            // Only operations contained in a block can be removed from it.
+            if (mlirBlockIsNull(mlirOperationGetBlock(operation)))
+              throw py::value_error("Operation is not in a block.");
+
+            operation.detachFromParent();
+            return operation.createOpView();
+          },
+          "Detaches the operation from its parent block.");

   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),
@@ -2380,7 +2439,25 @@ void mlir::python::populateIRCore(py::module &m) {
                            printAccum.getUserData());
             return printAccum.join();
           },
-          "Returns the assembly form of the block.");
+          "Returns the assembly form of the block.")
+      .def(
+          "append",
+          [](PyBlock &self, PyOperationBase &operationBase) {
+            PyOperation &operation = operationBase.getOperation();
+            operation.checkValid();
+            self.getParentOperation()->checkValid();
+            if (operation.isAttached()) {
+              // Only operations contained in a block can be moved out of it.
+              if (mlirBlockIsNull(mlirOperationGetBlock(operation)))
+                throw py::value_error("Operation is not in a block.");
+              operation.detachFromParent();
+            }
+
+            mlirBlockAppendOwnedOperation(self.get(), operation.get());
+            operation.setAttached(self.getParentOperation().getObject());
+          },
+          "Appends an operation to this block. If the operation is currently "
+          "in another block, it will be moved.");

   //----------------------------------------------------------------------------
   // Mapping of PyInsertionPoint.

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index dac9486c4e773..73924fc74bdbf 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -399,6 +399,10 @@ class PyOperationBase {
                           bool enableDebugInfo, bool prettyDebugInfo,
                           bool printGenericOpForm, bool useLocalScope);

+  /// Moves the operation immediately after/before the other operation.
+  void moveAfter(PyOperationBase &other);
+  void moveBefore(PyOperationBase &other);
+
   /// Each must provide access to the raw Operation.
   virtual PyOperation &getOperation() = 0;
 };
@@ -428,6 +432,14 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   createDetached(PyMlirContextRef contextRef, MlirOperation operation,
                  pybind11::object parentKeepAlive = pybind11::object());

+  /// Detaches the operation from its parent block, transferring ownership to
+  /// the caller. The operation must be attached and contained in a block.
+  void detachFromParent() {
+    mlirOperationRemoveFromParent(getOperation());
+    setDetached();
+    parentKeepAlive = pybind11::object();
+  }
+
   /// Gets the backing operation.
   operator MlirOperation() const { return get(); }
   MlirOperation get() const {
@@ -441,10 +453,19 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   }

   bool isAttached() { return attached; }
-  void setAttached() {
+  /// Marks the operation as attached (no longer owned by the caller). If
+  /// `parent` is provided, it is stored as the parent keep-alive reference.
+  void setAttached(const pybind11::object &parent = pybind11::object()) {
     assert(!attached && "operation already attached");
     attached = true;
+    if (parent)
+      parentKeepAlive = parent;
   }
+  /// Marks the operation as detached, i.e. owned by the caller.
+  void setDetached() {
+    assert(attached && "operation already detached");
+    attached = false;
+  }
   void checkValid() const;

   /// Gets the owning block or raises an exception if the operation has no
@@ -495,6 +516,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   pybind11::object parentKeepAlive;
   bool attached = true;
   bool valid = true;
+
+  /// Lets PyOperationBase::moveAfter/moveBefore update `parentKeepAlive`.
+  friend class PyOperationBase;
 };

 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index c738198f75b42..6f617dc19269d 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -338,6 +338,8 @@ MlirOperation mlirOperationClone(MlirOperation op) {

 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }

+void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); }
+
 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
   return unwrap(op) == unwrap(other);
 }
@@ -451,6 +453,32 @@ bool mlirOperationVerify(MlirOperation op) {
   return succeeded(verify(unwrap(op)));
 }

+void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
+  Operation *operation = unwrap(op);
+  Operation *existingOp = unwrap(other);
+  // Operation::moveAfter unlinks the operation from its current block, which
+  // a detached operation does not have; insert such an operation directly.
+  if (!operation->getBlock()) {
+    existingOp->getBlock()->getOperations().insertAfter(
+        existingOp->getIterator(), operation);
+    return;
+  }
+  operation->moveAfter(existingOp);
+}
+
+void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
+  Operation *operation = unwrap(op);
+  Operation *existingOp = unwrap(other);
+  // Operation::moveBefore unlinks the operation from its current block, which
+  // a detached operation does not have; insert such an operation directly.
+  if (!operation->getBlock()) {
+    existingOp->getBlock()->getOperations().insert(existingOp->getIterator(),
+                                                   operation);
+    return;
+  }
+  operation->moveBefore(existingOp);
+}
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index e877449fa83fb..3f1310f73a78b 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -505,7 +505,7 @@ void Operation::moveAfter(Operation *existingOp) {
 void Operation::moveAfter(Block *block,
                           llvm::iplist<Operation>::iterator iterator) {
   assert(iterator != block->end() && "cannot move after end of block");
-  moveBefore(&*std::next(iterator));
+  moveBefore(block, std::next(iterator)); // may be end(); do not dereference
 }

 /// This drops all operand uses from this operation, which is an essential

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9cd4824d68997..c94c22ea53a0b 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -740,3 +740,73 @@ def testOperationLoc():
     op = Operation.create("custom.op", loc=loc)
     assert op.location == loc
     assert op.operation.location == loc
+
+
+# CHECK-LABEL: TEST: testModuleMerge
+@run
+def testModuleMerge():
+  with Context():
+    m1 = Module.parse("func private @foo()")
+    m2 = Module.parse("""
+      func private @bar()
+      func private @qux()
+    """)
+    foo = m1.body.operations[0]
+    bar = m2.body.operations[0]
+    qux = m2.body.operations[1]
+    bar.move_before(foo)
+    qux.move_after(foo)
+
+    # CHECK: module
+    # CHECK: func private @bar
+    # CHECK: func private @foo
+    # CHECK: func private @qux
+    print(m1)
+
+    # CHECK: module {
+    # CHECK-NEXT: }
+    print(m2)
+
+
+# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
+@run
+def testAppendMoveFromAnotherBlock():
+  with Context():
+    m1 = Module.parse("func private @foo()")
+    m2 = Module.parse("func private @bar()")
+    func = m1.body.operations[0]
+    m2.body.append(func)
+
+    # CHECK: module
+    # CHECK: func private @bar
+    # CHECK: func private @foo
+    print(m2)
+
+    # CHECK: module {
+    # CHECK-NEXT: }
+    print(m1)
+
+
+# CHECK-LABEL: TEST: testDetachFromParent
+@run
+def testDetachFromParent():
+  with Context():
+    m1 = Module.parse("func private @foo()")
+    func = m1.body.operations[0].detach_from_parent()
+
+    try:
+      func.detach_from_parent()
+    except ValueError as e:
+      if "has no parent" not in str(e):
+        raise
+    else:
+      assert False, "expected ValueError when detaching a detached operation"
+
+    # The module no longer contains the function.
+    # CHECK: module {
+    # CHECK-NEXT: }
+    print(m1)
+
+    # The detached function remains valid and can still be printed.
+    # CHECK: func private @foo
+    print(func)


        


More information about the Mlir-commits mailing list