[Mlir-commits] [mlir] 5a600c2 - [mlir][python] Expose `PyInsertionPoint`'s reference operation (#69082)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 18 07:53:22 PDT 2023


Author: Tomás Longeri
Date: 2023-10-18T16:53:18+02:00
New Revision: 5a600c23f9e01f58bac09a8fad096e194fc90ae2

URL: https://github.com/llvm/llvm-project/commit/5a600c23f9e01f58bac09a8fad096e194fc90ae2
DIFF: https://github.com/llvm/llvm-project/commit/5a600c23f9e01f58bac09a8fad096e194fc90ae2.diff

LOG: [mlir][python] Expose `PyInsertionPoint`'s reference operation (#69082)

The reason I want this is that I am writing my own Python bindings and
would like to use the insertion point from
`PyThreadContextEntry::getDefaultInsertionPoint()` to call C++ functions
that take an `OpBuilder`. (I don't strictly need the reference operation
exposed in Python, but exposing it there also seems appropriate.) As far
as I can tell, there is currently no way to translate a
`PyInsertionPoint` into an `OpBuilder`, because the insertion point's
reference operation is inaccessible.

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
    mlir/test/python/ir/insertion_point.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index c8373e06f0db776..389a4621c14e594 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3207,7 +3207,18 @@ void mlir::python::populateIRCore(py::module &m) {
            "Inserts an operation.")
       .def_property_readonly(
           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
-          "Returns the block that this InsertionPoint points to.");
+          "Returns the block that this InsertionPoint points to.")
+      .def_property_readonly(
+          "ref_operation",
+          [](PyInsertionPoint &self) -> py::object {
+            auto ref_operation = self.getRefOperation();
+            if (ref_operation)
+              return ref_operation->getObject();
+            return py::none();
+          },
+          "The reference operation before which new operations are "
+          "inserted, or None if the insertion point is at the end of "
+          "the block");
 
   //----------------------------------------------------------------------------
   // Mapping of PyAttribute.

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 3ca7dd851961a46..c5412e735dddcb5 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -833,6 +833,7 @@ class PyInsertionPoint {
                    const pybind11::object &excTb);
 
   PyBlock &getBlock() { return block; }
+  std::optional<PyOperationRef> &getRefOperation() { return refOperation; }
 
 private:
   // Trampoline constructor that avoids null initializing members while

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index e8f4440d216eeb4..2609117dd220bea 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -755,6 +755,8 @@ class InsertionPoint:
     def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
     @property
     def block(self) -> Block: ...
+    @property
+    def ref_operation(self) -> Optional[_OperationBase]: ...
 
 # TODO: Auto-generated. Audit and fix.
 class IntegerAttr(Attribute):

diff  --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py
index 0dc7d757f56d192..268d2e77d036f5e 100644
--- a/mlir/test/python/ir/insertion_point.py
+++ b/mlir/test/python/ir/insertion_point.py
@@ -27,6 +27,8 @@ def test_insert_at_block_end():
         )
         entry_block = module.body.operations[0].regions[0].blocks[0]
         ip = InsertionPoint(entry_block)
+        assert ip.block == entry_block
+        assert ip.ref_operation is None
         ip.insert(Operation.create("custom.op2"))
         # CHECK: "custom.op1"
         # CHECK: "custom.op2"
@@ -51,6 +53,8 @@ def test_insert_before_operation():
         )
         entry_block = module.body.operations[0].regions[0].blocks[0]
         ip = InsertionPoint(entry_block.operations[1])
+        assert ip.block == entry_block
+        assert ip.ref_operation == entry_block.operations[1]
         ip.insert(Operation.create("custom.op3"))
         # CHECK: "custom.op1"
         # CHECK: "custom.op3"
@@ -75,6 +79,8 @@ def test_insert_at_block_begin():
         )
         entry_block = module.body.operations[0].regions[0].blocks[0]
         ip = InsertionPoint.at_block_begin(entry_block)
+        assert ip.block == entry_block
+        assert ip.ref_operation == entry_block.operations[0]
         ip.insert(Operation.create("custom.op1"))
         # CHECK: "custom.op1"
         # CHECK: "custom.op2"
@@ -108,6 +114,8 @@ def test_insert_at_terminator():
         )
         entry_block = module.body.operations[0].regions[0].blocks[0]
         ip = InsertionPoint.at_block_terminator(entry_block)
+        assert ip.block == entry_block
+        assert ip.ref_operation == entry_block.operations[1]
         ip.insert(Operation.create("custom.op2"))
         # CHECK: "custom.op1"
         # CHECK: "custom.op2"


        


More information about the Mlir-commits mailing list