[Mlir-commits] [mlir] [MLIR][Python] fix PyRegionList `__iter__` (PR #167466)
Maksim Levental
llvmlistbot at llvm.org
Tue Nov 11 00:19:57 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/167466
From 0acbf74be29a7eaf0bb2be4bacf99322ea613adb Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Tue, 11 Nov 2025 03:06:41 -0500
Subject: [PATCH] [MLIR][Python] fix PyRegionList __iter__
---
mlir/lib/Bindings/Python/IRCore.cpp | 8 +++----
mlir/lib/Bindings/Python/NanobindUtils.h | 1 -
mlir/test/python/ir/operation.py | 28 ++++++++++++++++++++++--
3 files changed, 30 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d90f27bd037e6..40a466beee159 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -204,8 +204,8 @@ namespace {
class PyRegionIterator {
public:
- PyRegionIterator(PyOperationRef operation)
- : operation(std::move(operation)) {}
+ PyRegionIterator(PyOperationRef operation, int nextIndex)
+ : operation(std::move(operation)), nextIndex(nextIndex) {}
PyRegionIterator &dunderIter() { return *this; }
@@ -228,7 +228,7 @@ class PyRegionIterator {
private:
PyOperationRef operation;
- int nextIndex = 0;
+ intptr_t nextIndex = 0;
};
/// Regions of an op are fixed length and indexed numerically so are represented
@@ -247,7 +247,7 @@ class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
PyRegionIterator dunderIter() {
operation->checkValid();
- return PyRegionIterator(operation);
+ return PyRegionIterator(operation, startIndex);
}
static void bindDerived(ClassTy &c) {
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h
index 64ea4329f65f1..658e8ad5330ef 100644
--- a/mlir/lib/Bindings/Python/NanobindUtils.h
+++ b/mlir/lib/Bindings/Python/NanobindUtils.h
@@ -395,7 +395,6 @@ class Sliceable {
/// Hook for derived classes willing to bind more methods.
static void bindDerived(ClassTy &) {}
-private:
intptr_t startIndex;
intptr_t length;
intptr_t step;
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index f5fa4dad856f8..1bdd345d98c05 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -2,12 +2,12 @@
import gc
import io
-import itertools
from tempfile import NamedTemporaryFile
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
-from mlir.dialects import arith
+from mlir.dialects import arith, func, scf
from mlir.dialects._ods_common import _cext
+from mlir.extras import types as T
def run(f):
@@ -1199,3 +1199,27 @@ def testGetOwnerConcreteOpview():
r = arith.AddIOp(a, a, overflowFlags=arith.IntegerOverflowFlags.nsw)
for u in a.result.uses:
assert isinstance(u.owner, arith.AddIOp)
+
+
+# CHECK-LABEL: TEST: testIndexSwitch
+@run
+def testIndexSwitch():
+ with Context() as ctx, Location.unknown():
+ i32 = T.i32()
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(T.index())
+ def index_switch(index):
+ c1 = arith.constant(i32, 1)
+ switch_op = scf.IndexSwitchOp(
+ results_=[i32], arg=index, cases=range(3), num_caseRegions=3
+ )
+
+ assert len(switch_op.regions) == 4
+ assert len(switch_op.regions[2:]) == 2
+ assert len([i for i in switch_op.regions[2:]]) == 2
+ assert len(switch_op.caseRegions) == 3
+ assert len([i for i in switch_op.caseRegions]) == 3
+ assert len(switch_op.caseRegions[1:]) == 2
+ assert len([i for i in switch_op.caseRegions[1:]]) == 2
More information about the Mlir-commits
mailing list