[Mlir-commits] [mlir] 774818c - Expose MlirOperationClone in Python bindings.
Alex Zinenko
llvmlistbot at llvm.org
Mon Mar 28 06:58:28 PDT 2022
Author: Dominik Grewe
Date: 2022-03-28T15:58:22+02:00
New Revision: 774818c09c9abd952aaae6db6d045be8dd98f168
URL: https://github.com/llvm/llvm-project/commit/774818c09c9abd952aaae6db6d045be8dd98f168
DIFF: https://github.com/llvm/llvm-project/commit/774818c09c9abd952aaae6db6d045be8dd98f168.diff
LOG: Expose MlirOperationClone in Python bindings.
Expose MlirOperationClone in Python bindings.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D122526
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 621c095021c7f..1225c26486a39 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1075,6 +1075,21 @@ py::object PyOperation::createFromCapsule(py::object capsule) {
.releaseObject();
}
+/// Inserts `op` according to the Python-level `maybeIp` argument:
+///   - `False`: no insertion; `op` is left as-is.
+///   - `None`: use PyThreadContextEntry's default insertion point, and skip
+///     insertion if there is none.
+///   - anything else: cast to a PyInsertionPoint and insert there.
+static void maybeInsertOperation(PyOperationRef &op,
+                                 const py::object &maybeIp) {
+  // An explicit `False` opts out of insertion entirely.
+  if (maybeIp.is(py::cast(false)))
+    return;
+
+  PyInsertionPoint *insertionPoint =
+      maybeIp.is_none() ? PyThreadContextEntry::getDefaultInsertionPoint()
+                        : py::cast<PyInsertionPoint *>(maybeIp);
+  if (!insertionPoint)
+    return;
+  insertionPoint->insert(*op.get());
+}
+
py::object PyOperation::create(
const std::string &name, llvm::Optional<std::vector<PyType *>> results,
llvm::Optional<std::vector<PyValue *>> operands,
@@ -1192,22 +1207,20 @@ py::object PyOperation::create(
MlirOperation operation = mlirOperationCreate(&state);
PyOperationRef created =
PyOperation::createDetached(location->getContext(), operation);
-
- // InsertPoint active?
- if (!maybeIp.is(py::cast(false))) {
- PyInsertionPoint *ip;
- if (maybeIp.is_none()) {
- ip = PyThreadContextEntry::getDefaultInsertionPoint();
- } else {
- ip = py::cast<PyInsertionPoint *>(maybeIp);
- }
- if (ip)
- ip->insert(*created.get());
- }
+ maybeInsertOperation(created, maybeIp);
return created->createOpView();
}
+/// Clones the underlying MlirOperation with mlirOperationClone and wraps the
+/// result as a detached PyOperation in this operation's context. The clone is
+/// then inserted as directed by `maybeIp` (False: stay detached; None: default
+/// insertion point, if any; otherwise the given InsertionPoint). Returns the
+/// clone's OpView.
+py::object PyOperation::clone(const py::object &maybeIp) {
+  // erase() clears the valid bit, and the MlirOperation handle must not be
+  // used afterwards. Check validity before handing the handle to the C API,
+  // the same way createOpView() does, rather than cloning a dangling pointer.
+  checkValid();
+  MlirOperation clonedOperation = mlirOperationClone(operation);
+  PyOperationRef cloned =
+      PyOperation::createDetached(getContext(), clonedOperation);
+  maybeInsertOperation(cloned, maybeIp);
+
+  return cloned->createOpView();
+}
+
py::object PyOperation::createOpView() {
checkValid();
MlirIdentifier ident = mlirOperationGetName(get());
@@ -2616,6 +2629,7 @@ void mlir::python::populateIRCore(py::module &m) {
return py::none();
})
.def("erase", &PyOperation::erase)
+ .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyOperation::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index b1424a994d857..2046ce0c16552 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -575,6 +575,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
/// parent context's live operations map, and sets the valid bit false.
void erase();
+ /// Clones this operation and returns an OpView of the clone. `ip` controls
+ /// insertion of the clone: `False` leaves it detached, `None` uses the
+ /// current default insertion point (if one is active), and an
+ /// InsertionPoint inserts it at that point.
+ pybind11::object clone(const pybind11::object &ip);
+
private:
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
static PyOperationRef createInstance(PyMlirContextRef contextRef,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 8dca68385947d..7e23268c2d8ae 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -767,6 +767,26 @@ def testOperationErase():
Operation.create("custom.op2")
+# CHECK-LABEL: TEST: testOperationClone
+ at run
+def testOperationClone():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ m = Module.create()
+ with InsertionPoint(m.body):
+ op = Operation.create("custom.op1")
+
+ # CHECK: "custom.op1"
+ print(m)
+
+ clone = op.operation.clone()
+ op.operation.erase()
+
+ # CHECK: "custom.op1"
+ print(m)
+
+
# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
More information about the Mlir-commits
mailing list