[Mlir-commits] [mlir] [MLIR] [python] A few improvements to the Python bindings (PR #131686)
Sergei Lebedev
llvmlistbot at llvm.org
Thu Mar 20 04:49:44 PDT 2025
https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/131686
>From c2749515cc9b122b16a062d6efa0af77b9c9d1ab Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Mon, 17 Mar 2025 22:31:47 +0000
Subject: [PATCH] [MLIR] [PYTHON] A few improvements to the Python bindings
* `PyRegionList` is now sliceable. The dialect bindings generator already
assumes it is sliceable (!), yet before this change accessing, e.g., `cases` on
`scf.IndexedSwitchOp` raised a `TypeError` at runtime.
* `PyBlockList`, `PyOperationList`, and indexed access on `PyOpAttributeMap`
now support negative indexing. Python containers commonly support negative
indices, and most containers in the MLIR Python bindings already allow them.
---
mlir/lib/Bindings/Python/IRCore.cpp | 49 ++++++++++++++++--------
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 3 ++
mlir/test/python/ir/operation.py | 22 +++++++++--
3 files changed, 55 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12793f7dd15be..dc41aaea3261c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -361,37 +361,45 @@ class PyRegionIterator {
/// Regions of an op are fixed length and indexed numerically so are represented
/// with a sequence-like container.
-class PyRegionList {
+class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
public:
- PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
+ static constexpr const char *pyClassName = "RegionSequence";
+
+ PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumRegions(operation->get())
+ : length,
+ step),
+ operation(std::move(operation)) {}
PyRegionIterator dunderIter() {
operation->checkValid();
return PyRegionIterator(operation);
}
- intptr_t dunderLen() {
+ static void bindDerived(ClassTy &c) {
+ c.def("__iter__", &PyRegionList::dunderIter);
+ }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyRegionList, PyRegion>;
+
+ intptr_t getRawNumElements() {
operation->checkValid();
return mlirOperationGetNumRegions(operation->get());
}
- PyRegion dunderGetItem(intptr_t index) {
- // dunderLen checks validity.
- if (index < 0 || index >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds region");
- }
- MlirRegion region = mlirOperationGetRegion(operation->get(), index);
- return PyRegion(operation, region);
+ PyRegion getRawElement(intptr_t pos) {
+ operation->checkValid();
+ return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
}
- static void bind(nb::module_ &m) {
- nb::class_<PyRegionList>(m, "RegionSequence")
- .def("__len__", &PyRegionList::dunderLen)
- .def("__iter__", &PyRegionList::dunderIter)
- .def("__getitem__", &PyRegionList::dunderGetItem);
+ PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+ return PyRegionList(operation, startIndex, length, step);
}
-private:
PyOperationRef operation;
};
@@ -450,6 +458,9 @@ class PyBlockList {
PyBlock dunderGetItem(intptr_t index) {
operation->checkValid();
+ if (index < 0) {
+ index += dunderLen();
+ }
if (index < 0) {
throw nb::index_error("attempt to access out of bounds block");
}
@@ -546,6 +557,9 @@ class PyOperationList {
nb::object dunderGetItem(intptr_t index) {
parentOperation->checkValid();
+ if (index < 0) {
+ index += dunderLen();
+ }
if (index < 0) {
throw nb::index_error("attempt to access out of bounds operation");
}
@@ -2629,6 +2643,9 @@ class PyOpAttributeMap {
}
PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
+ if (index < 0) {
+ index += dunderLen();
+ }
if (index < 0 || index >= dunderLen()) {
throw nb::index_error("attempt to access out of bounds attribute");
}
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index c93de2fe3154e..c60ff72ff9fd4 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -2466,7 +2466,10 @@ class RegionIterator:
def __next__(self) -> Region: ...
class RegionSequence:
+ @overload
def __getitem__(self, arg0: int) -> Region: ...
+ @overload
+ def __getitem__(self, arg0: slice) -> Sequence[Region]: ...
def __iter__(self) -> RegionIterator: ...
def __len__(self) -> int: ...
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index dd2731ba2e1f1..b08fe98397fbc 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -44,7 +44,7 @@ def testTraverseOpRegionBlockIterators():
op = module.operation
assert op.context is ctx
# Get the block using iterators off of the named collections.
- regions = list(op.regions)
+ regions = list(op.regions[:])
blocks = list(regions[0].blocks)
# CHECK: MODULE REGIONS=1 BLOCKS=1
print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
@@ -86,8 +86,24 @@ def walk_operations(indent, op):
# CHECK: Block iter: <mlir.{{.+}}.BlockIterator
# CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
print(" Region iter:", iter(op.regions))
- print(" Block iter:", iter(op.regions[0]))
- print("Operation iter:", iter(op.regions[0].blocks[0]))
+ print(" Block iter:", iter(op.regions[-1]))
+ print("Operation iter:", iter(op.regions[-1].blocks[-1]))
+
+ try:
+ op.regions[-42]
+ except IndexError as e:
+ # CHECK: Region OOB: index out of range
+ print("Region OOB:", e)
+ try:
+ op.regions[0].blocks[-42]
+ except IndexError as e:
+ # CHECK: attempt to access out of bounds block
+ print(e)
+ try:
+ op.regions[0].blocks[0].operations[-42]
+ except IndexError as e:
+ # CHECK: attempt to access out of bounds operation
+ print(e)
# Verify index based traversal of the op/region/block hierarchy.
More information about the Mlir-commits
mailing list