[Mlir-commits] [mlir] 774818c - Expose MlirOperationClone in Python bindings.

Alex Zinenko llvmlistbot at llvm.org
Mon Mar 28 06:58:28 PDT 2022


Author: Dominik Grewe
Date: 2022-03-28T15:58:22+02:00
New Revision: 774818c09c9abd952aaae6db6d045be8dd98f168

URL: https://github.com/llvm/llvm-project/commit/774818c09c9abd952aaae6db6d045be8dd98f168
DIFF: https://github.com/llvm/llvm-project/commit/774818c09c9abd952aaae6db6d045be8dd98f168.diff

LOG: Expose MlirOperationClone in Python bindings.

Expose MlirOperationClone in Python bindings.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D122526

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 621c095021c7f..1225c26486a39 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1075,6 +1075,21 @@ py::object PyOperation::createFromCapsule(py::object capsule) {
       .releaseObject();
 }
 
+/// Inserts `op` as directed by `maybeIp`: `False` skips insertion, `None`
+/// selects PyThreadContextEntry's default insertion point, anything else is
+/// cast to PyInsertionPoint. A null insertion point is silently ignored.
+static void maybeInsertOperation(PyOperationRef &op,
+                                 const py::object &maybeIp) {
+  if (maybeIp.is(py::cast(false)))
+    return;
+  PyInsertionPoint *target =
+      maybeIp.is_none() ? PyThreadContextEntry::getDefaultInsertionPoint()
+                        : py::cast<PyInsertionPoint *>(maybeIp);
+  if (!target)
+    return;
+  target->insert(*op.get());
+}
+
 py::object PyOperation::create(
     const std::string &name, llvm::Optional<std::vector<PyType *>> results,
     llvm::Optional<std::vector<PyValue *>> operands,
@@ -1192,22 +1207,20 @@ py::object PyOperation::create(
   MlirOperation operation = mlirOperationCreate(&state);
   PyOperationRef created =
       PyOperation::createDetached(location->getContext(), operation);
-
-  // InsertPoint active?
-  if (!maybeIp.is(py::cast(false))) {
-    PyInsertionPoint *ip;
-    if (maybeIp.is_none()) {
-      ip = PyThreadContextEntry::getDefaultInsertionPoint();
-    } else {
-      ip = py::cast<PyInsertionPoint *>(maybeIp);
-    }
-    if (ip)
-      ip->insert(*created.get());
-  }
+  maybeInsertOperation(created, maybeIp);
 
   return created->createOpView();
 }
 
+/// Clones this operation; the clone is inserted according to `maybeIp`.
+py::object PyOperation::clone(const py::object &maybeIp) {
+  checkValid(); // Guard as createOpView() does; `operation` may be erased.
+  PyOperationRef cloned = PyOperation::createDetached(
+      getContext(), mlirOperationClone(operation));
+  maybeInsertOperation(cloned, maybeIp);
+  return cloned->createOpView();
+}
+
 py::object PyOperation::createOpView() {
   checkValid();
   MlirIdentifier ident = mlirOperationGetName(get());
@@ -2616,6 +2629,7 @@ void mlir::python::populateIRCore(py::module &m) {
                                return py::none();
                              })
       .def("erase", &PyOperation::erase)
+      .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyOperation::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index b1424a994d857..2046ce0c16552 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -575,6 +575,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   /// parent context's live operations map, and sets the valid bit false.
   void erase();
 
+  /// Clones this operation; `ip` selects where (or whether) to insert it.
+  pybind11::object clone(const pybind11::object &ip);
+
 private:
   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
   static PyOperationRef createInstance(PyMlirContextRef contextRef,

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 8dca68385947d..7e23268c2d8ae 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -767,6 +767,26 @@ def testOperationErase():
       Operation.create("custom.op2")
 
 
+# CHECK-LABEL: TEST: testOperationClone
+ at run
+def testOperationClone():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    m = Module.create()
+    with InsertionPoint(m.body):
+      op = Operation.create("custom.op1")
+
+      # CHECK: "custom.op1"
+      print(m)
+
+      clone = op.operation.clone()
+      op.operation.erase()
+
+      # CHECK: "custom.op1"
+      print(m)
+
+
 # CHECK-LABEL: TEST: testOperationLoc
 @run
 def testOperationLoc():


        


More information about the Mlir-commits mailing list