[Mlir-commits] [mlir] [MLIR][Python] bind InsertionPointAfter (PR #157156)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 5 12:45:52 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/157156.diff
3 Files Affected:
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+14-1)
- (modified) mlir/lib/Bindings/Python/IRModule.h (+4-1)
- (modified) mlir/test/python/ir/insertion_point.py (+28)
``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index bf4950fc1a070..f6f3abf9819e9 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2019,7 +2019,7 @@ PyOpView::PyOpView(const nb::object &operationObject)
// PyInsertionPoint.
//------------------------------------------------------------------------------
-PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
+PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
: refOperation(beforeOperationBase.getOperation().getRef()),
@@ -2073,6 +2073,17 @@ PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
return PyInsertionPoint{block, std::move(terminatorOpRef)};
}
+PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
+ PyOperation &operation = op.getOperation();
+ MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
+ if (mlirOperationIsNull(nextOperation))
+ return PyInsertionPoint{operation.getBlock()};
+ PyOperationRef nextOpRef =
+ PyOperation::forOperation(operation.getContext(), nextOperation);
+ return PyInsertionPoint{nextOpRef->getOperation().getBlock(),
+ std::move(nextOpRef)};
+}
+
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
return PyThreadContextEntry::pushInsertionPoint(insertPoint);
}
@@ -3861,6 +3872,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("block"), "Inserts at the beginning of the block.")
.def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
nb::arg("block"), "Inserts before the block terminator.")
+ .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
+ "Inserts after the operation.")
.def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
"Inserts an operation.")
.def_prop_ro(
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 0cc0459ebc9a0..1d1ff29533f98 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -821,7 +821,7 @@ class PyInsertionPoint {
public:
/// Creates an insertion point positioned after the last operation in the
/// block, but still inside the block.
- PyInsertionPoint(PyBlock &block);
+ PyInsertionPoint(const PyBlock &block);
/// Creates an insertion point positioned before a reference operation.
PyInsertionPoint(PyOperationBase &beforeOperationBase);
@@ -829,6 +829,9 @@ class PyInsertionPoint {
static PyInsertionPoint atBlockBegin(PyBlock &block);
/// Shortcut to create an insertion point before the block terminator.
static PyInsertionPoint atBlockTerminator(PyBlock &block);
+ /// Shortcut to create an insertion point to the node after the specified
+ /// operation.
+ static PyInsertionPoint after(PyOperationBase &op);
/// Inserts an operation.
void insert(PyOperationBase &operationBase);
diff --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py
index 5eb861a2c0891..a0296227cb050 100644
--- a/mlir/test/python/ir/insertion_point.py
+++ b/mlir/test/python/ir/insertion_point.py
@@ -63,6 +63,34 @@ def test_insert_before_operation():
run(test_insert_before_operation)
+# CHECK-LABEL: TEST: test_insert_after_operation
+def test_insert_after_operation():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ module = Module.parse(
+ r"""
+ func.func @foo() -> () {
+ "custom.op1"() : () -> ()
+ "custom.op2"() : () -> ()
+ }
+ """
+ )
+ entry_block = module.body.operations[0].regions[0].blocks[0]
+ custom_op1 = entry_block.operations[0]
+ custom_op2 = entry_block.operations[1]
+ InsertionPoint.after(custom_op1).insert(Operation.create("custom.op3"))
+ InsertionPoint.after(custom_op2).insert(Operation.create("custom.op4"))
+ # CHECK: "custom.op1"
+ # CHECK: "custom.op3"
+ # CHECK: "custom.op2"
+ # CHECK: "custom.op4"
+ module.operation.print()
+
+
+run(test_insert_after_operation)
+
+
# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
ctx = Context()
``````````
</details>
https://github.com/llvm/llvm-project/pull/157156
More information about the Mlir-commits mailing list