[Mlir-commits] [mlir] [MLIR][Python] bind InsertionPointAfter (PR #157156)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 5 12:45:52 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/157156.diff


3 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+14-1) 
- (modified) mlir/lib/Bindings/Python/IRModule.h (+4-1) 
- (modified) mlir/test/python/ir/insertion_point.py (+28) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index bf4950fc1a070..f6f3abf9819e9 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2019,7 +2019,7 @@ PyOpView::PyOpView(const nb::object &operationObject)
 // PyInsertionPoint.
 //------------------------------------------------------------------------------
 
-PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
+PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
 
 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
     : refOperation(beforeOperationBase.getOperation().getRef()),
@@ -2073,6 +2073,17 @@ PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
   return PyInsertionPoint{block, std::move(terminatorOpRef)};
 }
 
+PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
+  PyOperation &operation = op.getOperation();
+  MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
+  if (mlirOperationIsNull(nextOperation))
+    return PyInsertionPoint{operation.getBlock()};
+  PyOperationRef nextOpRef =
+      PyOperation::forOperation(operation.getContext(), nextOperation);
+  return PyInsertionPoint{nextOpRef->getOperation().getBlock(),
+                          std::move(nextOpRef)};
+}
+
 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
   return PyThreadContextEntry::pushInsertionPoint(insertPoint);
 }
@@ -3861,6 +3872,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                   nb::arg("block"), "Inserts at the beginning of the block.")
       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
                   nb::arg("block"), "Inserts before the block terminator.")
+      .def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
+                  "Inserts after the operation.")
       .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
            "Inserts an operation.")
       .def_prop_ro(
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 0cc0459ebc9a0..1d1ff29533f98 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -821,7 +821,7 @@ class PyInsertionPoint {
 public:
   /// Creates an insertion point positioned after the last operation in the
   /// block, but still inside the block.
-  PyInsertionPoint(PyBlock &block);
+  PyInsertionPoint(const PyBlock &block);
   /// Creates an insertion point positioned before a reference operation.
   PyInsertionPoint(PyOperationBase &beforeOperationBase);
 
@@ -829,6 +829,9 @@ class PyInsertionPoint {
   static PyInsertionPoint atBlockBegin(PyBlock &block);
   /// Shortcut to create an insertion point before the block terminator.
   static PyInsertionPoint atBlockTerminator(PyBlock &block);
+  /// Shortcut to create an insertion point to the node after the specified
+  /// operation.
+  static PyInsertionPoint after(PyOperationBase &op);
 
   /// Inserts an operation.
   void insert(PyOperationBase &operationBase);
diff --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py
index 5eb861a2c0891..a0296227cb050 100644
--- a/mlir/test/python/ir/insertion_point.py
+++ b/mlir/test/python/ir/insertion_point.py
@@ -63,6 +63,34 @@ def test_insert_before_operation():
 run(test_insert_before_operation)
 
 
+# CHECK-LABEL: TEST: test_insert_after_operation
+def test_insert_after_operation():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        module = Module.parse(
+            r"""
+      func.func @foo() -> () {
+        "custom.op1"() : () -> ()
+        "custom.op2"() : () -> ()
+      }
+    """
+        )
+        entry_block = module.body.operations[0].regions[0].blocks[0]
+        custom_op1 = entry_block.operations[0]
+        custom_op2 = entry_block.operations[1]
+        InsertionPoint.after(custom_op1).insert(Operation.create("custom.op3"))
+        InsertionPoint.after(custom_op2).insert(Operation.create("custom.op4"))
+        # CHECK: "custom.op1"
+        # CHECK: "custom.op3"
+        # CHECK: "custom.op2"
+        # CHECK: "custom.op4"
+        module.operation.print()
+
+
+run(test_insert_after_operation)
+
+
 # CHECK-LABEL: TEST: test_insert_at_block_begin
 def test_insert_at_block_begin():
     ctx = Context()

``````````

</details>


https://github.com/llvm/llvm-project/pull/157156


More information about the Mlir-commits mailing list