[Mlir-commits] [mlir] [MLIR][Python] bind InsertionPointAfter (PR #157156)
Maksim Levental
llvmlistbot at llvm.org
Fri Sep 5 13:17:52 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/157156
>From 57eb1d64a5419a04894de73fc7bb58306141e9a3 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 5 Sep 2025 11:22:04 -0700
Subject: [PATCH] [MLIR][Python] bind InsertionPointAfter
---
mlir/lib/Bindings/Python/IRCore.cpp | 15 +++++++++++-
mlir/lib/Bindings/Python/IRModule.h | 5 +++-
mlir/test/python/ir/insertion_point.py | 33 +++++++++++++++++++++++++-
3 files changed, 50 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index bf4950fc1a070..ab2602ac72df0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2019,7 +2019,7 @@ PyOpView::PyOpView(const nb::object &operationObject)
// PyInsertionPoint.
//------------------------------------------------------------------------------
-PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
+PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
: refOperation(beforeOperationBase.getOperation().getRef()),
@@ -2073,6 +2073,17 @@ PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
return PyInsertionPoint{block, std::move(terminatorOpRef)};
}
+PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
+ PyOperation &operation = op.getOperation();
+ PyBlock block = operation.getBlock();
+ MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
+ if (mlirOperationIsNull(nextOperation)) // Last op: append at block end.
+ return PyInsertionPoint(block);
+ PyOperationRef nextOpRef = // Otherwise: before the current successor.
+ PyOperation::forOperation(operation.getContext(), nextOperation);
+ return PyInsertionPoint{block, std::move(nextOpRef)};
+}
+
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
return PyThreadContextEntry::pushInsertionPoint(insertPoint);
}
@@ -3861,6 +3872,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("block"), "Inserts at the beginning of the block.")
.def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
nb::arg("block"), "Inserts before the block terminator.")
+ .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
+ "Inserts after the operation.")
.def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
"Inserts an operation.")
.def_prop_ro(
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 0cc0459ebc9a0..1d1ff29533f98 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -821,7 +821,7 @@ class PyInsertionPoint {
public:
/// Creates an insertion point positioned after the last operation in the
/// block, but still inside the block.
- PyInsertionPoint(PyBlock &block);
+ PyInsertionPoint(const PyBlock &block);
/// Creates an insertion point positioned before a reference operation.
PyInsertionPoint(PyOperationBase &beforeOperationBase);
@@ -829,6 +829,9 @@ class PyInsertionPoint {
static PyInsertionPoint atBlockBegin(PyBlock &block);
/// Shortcut to create an insertion point before the block terminator.
static PyInsertionPoint atBlockTerminator(PyBlock &block);
+ /// Shortcut to create an insertion point right after the given operation:
+ /// before its current successor, or at the end of its block if it is last.
+ static PyInsertionPoint after(PyOperationBase &op);
/// Inserts an operation.
void insert(PyOperationBase &operationBase);
diff --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py
index 5eb861a2c0891..ec2216803e573 100644
--- a/mlir/test/python/ir/insertion_point.py
+++ b/mlir/test/python/ir/insertion_point.py
@@ -63,6 +63,34 @@ def test_insert_before_operation():
run(test_insert_before_operation)
+# CHECK-LABEL: TEST: test_insert_after_operation
+def test_insert_after_operation():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ module = Module.parse(
+ r"""
+ func.func @foo() -> () {
+ "custom.op1"() : () -> ()
+ "custom.op2"() : () -> ()
+ }
+ """
+ )
+ custom_op1, custom_op2 = module.body.operations[0].regions[0].blocks[0].operations
+ assert InsertionPoint.after(custom_op1).ref_operation == custom_op2
+ InsertionPoint.after(custom_op1).insert(Operation.create("custom.op3"))
+ assert InsertionPoint.after(custom_op2).ref_operation is None # No successor.
+ InsertionPoint.after(custom_op2).insert(Operation.create("custom.op4"))
+ # CHECK: "custom.op1"
+ # CHECK: "custom.op3"
+ # CHECK: "custom.op2"
+ # CHECK: "custom.op4"
+ module.operation.print()
+
+
+run(test_insert_after_operation)
+
+
# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
ctx = Context()
@@ -114,9 +142,12 @@ def test_insert_at_terminator():
ip = InsertionPoint.at_block_terminator(entry_block)
assert ip.block == entry_block
assert ip.ref_operation == entry_block.operations[1]
- ip.insert(Operation.create("custom.op2"))
+ custom_op2 = Operation.create("custom.op2")
+ ip.insert(custom_op2)
+ InsertionPoint.after(custom_op2).insert(Operation.create("custom.op3"))
# CHECK: "custom.op1"
# CHECK: "custom.op2"
+ # CHECK: "custom.op3"
module.operation.print()
More information about the Mlir-commits
mailing list