[Mlir-commits] [mlir] [MLIR][Python] bind InsertionPointAfter (PR #157156)

Maksim Levental llvmlistbot at llvm.org
Fri Sep 5 13:56:26 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/157156

From b503ef0813c63e8969fc2665b8771185cc14db9a Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 5 Sep 2025 11:22:04 -0700
Subject: [PATCH] [MLIR][Python] bind InsertionPointAfter

---
 mlir/lib/Bindings/Python/IRCore.cpp    | 15 ++++++++-
 mlir/lib/Bindings/Python/IRModule.h    |  5 ++-
 mlir/test/python/ir/insertion_point.py | 42 ++++++++++++++++++++++++--
 3 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index bf4950fc1a070..ba00ef712084b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2019,7 +2019,7 @@ PyOpView::PyOpView(const nb::object &operationObject)
 // PyInsertionPoint.
 //------------------------------------------------------------------------------
 
-PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
+PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
 
 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
     : refOperation(beforeOperationBase.getOperation().getRef()),
@@ -2073,6 +2073,17 @@ PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
   return PyInsertionPoint{block, std::move(terminatorOpRef)};
 }
 
+PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
+  // Anchor on op's successor; if op is last, fall back to the block's end.
+  PyOperation &operation = op.getOperation();
+  PyBlock block = operation.getBlock();
+  MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
+  if (mlirOperationIsNull(nextOperation))
+    return PyInsertionPoint(block);
+  return PyInsertionPoint{
+      block, PyOperation::forOperation(operation.getContext(), nextOperation)};
+}
+
 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
   return PyThreadContextEntry::pushInsertionPoint(insertPoint);
 }
@@ -3861,6 +3872,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                   nb::arg("block"), "Inserts at the beginning of the block.")
       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
                   nb::arg("block"), "Inserts before the block terminator.")
+      .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
+                  "Inserts after the operation.")
       .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
            "Inserts an operation.")
       .def_prop_ro(
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 0cc0459ebc9a0..1d1ff29533f98 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -821,7 +821,7 @@ class PyInsertionPoint {
 public:
   /// Creates an insertion point positioned after the last operation in the
   /// block, but still inside the block.
-  PyInsertionPoint(PyBlock &block);
+  PyInsertionPoint(const PyBlock &block);
   /// Creates an insertion point positioned before a reference operation.
   PyInsertionPoint(PyOperationBase &beforeOperationBase);
 
@@ -829,6 +829,9 @@ class PyInsertionPoint {
   static PyInsertionPoint atBlockBegin(PyBlock &block);
   /// Shortcut to create an insertion point before the block terminator.
   static PyInsertionPoint atBlockTerminator(PyBlock &block);
+  /// Shortcut to create an insertion point immediately after the specified
+  /// operation: before its successor, or at the end of its block if it is last.
+  static PyInsertionPoint after(PyOperationBase &op);
 
   /// Inserts an operation.
   void insert(PyOperationBase &operationBase);
diff --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py
index 5eb861a2c0891..41fa619941723 100644
--- a/mlir/test/python/ir/insertion_point.py
+++ b/mlir/test/python/ir/insertion_point.py
@@ -63,6 +63,34 @@ def test_insert_before_operation():
 run(test_insert_before_operation)
 
 
+# CHECK-LABEL: TEST: test_insert_after_operation
+def test_insert_after_operation():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        module = Module.parse(
+            r"""
+      func.func @foo() -> () {
+        "custom.op1"() : () -> ()
+        "custom.op2"() : () -> ()
+      }
+    """
+        )
+        # Capture both anchors before mutating the block.
+        block_ops = module.body.operations[0].regions[0].blocks[0].operations
+        anchors = (block_ops[0], block_ops[1])
+        for anchor, new_name in zip(anchors, ("custom.op3", "custom.op4")):
+            InsertionPoint.after(anchor).insert(Operation.create(new_name))
+        # CHECK: "custom.op1"
+        # CHECK: "custom.op3"
+        # CHECK: "custom.op2"
+        # CHECK: "custom.op4"
+        module.operation.print()
+
+
+run(test_insert_after_operation)
+
+
 # CHECK-LABEL: TEST: test_insert_at_block_begin
 def test_insert_at_block_begin():
     ctx = Context()
@@ -111,14 +139,24 @@ def test_insert_at_terminator():
     """
         )
         entry_block = module.body.operations[0].regions[0].blocks[0]
+        return_op = entry_block.operations[1]
         ip = InsertionPoint.at_block_terminator(entry_block)
         assert ip.block == entry_block
-        assert ip.ref_operation == entry_block.operations[1]
-        ip.insert(Operation.create("custom.op2"))
+        assert ip.ref_operation == return_op
+        custom_op2 = Operation.create("custom.op2")
+        ip.insert(custom_op2)
+        InsertionPoint.after(custom_op2).insert(Operation.create("custom.op3"))
         # CHECK: "custom.op1"
         # CHECK: "custom.op2"
+        # CHECK: "custom.op3"
         module.operation.print()
 
+        try:
+            InsertionPoint.after(return_op).insert(Operation.create("custom.op4"))
+        except IndexError as e:
+            # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
+            print(f"ERROR: {e}")
+
 
 run(test_insert_at_terminator)
 



More information about the Mlir-commits mailing list