[Mlir-commits] [mlir] c8a9a41 - [MLIR] [python] A few improvements to the Python bindings (#131686)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 20 21:13:17 PDT 2025


Author: Sergei Lebedev
Date: 2025-03-21T00:13:13-04:00
New Revision: c8a9a4109ac7756af3f0f5aab8c70e686a2f30b7

URL: https://github.com/llvm/llvm-project/commit/c8a9a4109ac7756af3f0f5aab8c70e686a2f30b7
DIFF: https://github.com/llvm/llvm-project/commit/c8a9a4109ac7756af3f0f5aab8c70e686a2f30b7.diff

LOG: [MLIR] [python] A few improvements to the Python bindings (#131686)

* `PyRegionList` is now sliceable. Previously it was not, even though the
dialect bindings generator seems to assume it already is (!), so accessing
e.g. `cases` on `scf.IndexedSwitchOp` raised a `TypeError` at runtime.
* `PyBlockList`, `PyOperationList`, and indexed access on `PyOpAttributeMap`
now support negative indexing. Python containers commonly allow this, and
most containers in the MLIR Python bindings already accept negative indices.

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 78ba144acf1e9..5ffcf671741bd 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -361,37 +361,45 @@ class PyRegionIterator {
 
 /// Regions of an op are fixed length and indexed numerically so are represented
 /// with a sequence-like container.
-class PyRegionList {
+class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
 public:
-  PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
+  static constexpr const char *pyClassName = "RegionSequence";
+
+  PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
+               intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumRegions(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
 
   PyRegionIterator dunderIter() {
     operation->checkValid();
     return PyRegionIterator(operation);
   }
 
-  intptr_t dunderLen() {
+  static void bindDerived(ClassTy &c) {
+    c.def("__iter__", &PyRegionList::dunderIter);
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyRegionList, PyRegion>;
+
+  intptr_t getRawNumElements() {
     operation->checkValid();
     return mlirOperationGetNumRegions(operation->get());
   }
 
-  PyRegion dunderGetItem(intptr_t index) {
-    // dunderLen checks validity.
-    if (index < 0 || index >= dunderLen()) {
-      throw nb::index_error("attempt to access out of bounds region");
-    }
-    MlirRegion region = mlirOperationGetRegion(operation->get(), index);
-    return PyRegion(operation, region);
+  PyRegion getRawElement(intptr_t pos) {
+    operation->checkValid();
+    return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
   }
 
-  static void bind(nb::module_ &m) {
-    nb::class_<PyRegionList>(m, "RegionSequence")
-        .def("__len__", &PyRegionList::dunderLen)
-        .def("__iter__", &PyRegionList::dunderIter)
-        .def("__getitem__", &PyRegionList::dunderGetItem);
+  PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyRegionList(operation, startIndex, length, step);
   }
 
-private:
   PyOperationRef operation;
 };
 
@@ -450,6 +458,9 @@ class PyBlockList {
 
   PyBlock dunderGetItem(intptr_t index) {
     operation->checkValid();
+    if (index < 0) {
+      index += dunderLen();
+    }
     if (index < 0) {
       throw nb::index_error("attempt to access out of bounds block");
     }
@@ -546,6 +557,9 @@ class PyOperationList {
 
   nb::object dunderGetItem(intptr_t index) {
     parentOperation->checkValid();
+    if (index < 0) {
+      index += dunderLen();
+    }
     if (index < 0) {
       throw nb::index_error("attempt to access out of bounds operation");
     }
@@ -2629,6 +2643,9 @@ class PyOpAttributeMap {
   }
 
   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
+    if (index < 0) {
+      index += dunderLen();
+    }
     if (index < 0 || index >= dunderLen()) {
       throw nb::index_error("attempt to access out of bounds attribute");
     }

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index c93de2fe3154e..c60ff72ff9fd4 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -2466,7 +2466,10 @@ class RegionIterator:
     def __next__(self) -> Region: ...
 
 class RegionSequence:
+    @overload
     def __getitem__(self, arg0: int) -> Region: ...
+    @overload
+    def __getitem__(self, arg0: slice) -> Sequence[Region]: ...
     def __iter__(self) -> RegionIterator: ...
     def __len__(self) -> int: ...
 

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index dd2731ba2e1f1..b08fe98397fbc 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -44,7 +44,7 @@ def testTraverseOpRegionBlockIterators():
     op = module.operation
     assert op.context is ctx
     # Get the block using iterators off of the named collections.
-    regions = list(op.regions)
+    regions = list(op.regions[:])
     blocks = list(regions[0].blocks)
     # CHECK: MODULE REGIONS=1 BLOCKS=1
     print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
@@ -86,8 +86,24 @@ def walk_operations(indent, op):
     # CHECK:     Block iter: <mlir.{{.+}}.BlockIterator
     # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
     print("   Region iter:", iter(op.regions))
-    print("    Block iter:", iter(op.regions[0]))
-    print("Operation iter:", iter(op.regions[0].blocks[0]))
+    print("    Block iter:", iter(op.regions[-1]))
+    print("Operation iter:", iter(op.regions[-1].blocks[-1]))
+
+    try:
+        op.regions[-42]
+    except IndexError as e:
+        # CHECK: Region OOB: index out of range
+        print("Region OOB:", e)
+    try:
+        op.regions[0].blocks[-42]
+    except IndexError as e:
+        # CHECK: attempt to access out of bounds block
+        print(e)
+    try:
+        op.regions[0].blocks[0].operations[-42]
+    except IndexError as e:
+        # CHECK: attempt to access out of bounds operation
+        print(e)
 
 
 # Verify index based traversal of the op/region/block hierarchy.


        


More information about the Mlir-commits mailing list