[Mlir-commits] [mlir] [mlir][python] Wrappers for scf.index_switch (PR #167458)

Asher Mancinelli llvmlistbot at llvm.org
Mon Nov 10 21:46:06 PST 2025


https://github.com/ashermancinelli created https://github.com/llvm/llvm-project/pull/167458

The C++ index switch op has utilities `getCaseBlock(int i)` and `getDefaultBlock()`, so Python equivalents (`case_block(i)` and `default_block`) have been added.
Optional body builder args have been added for the default case and each switch case.

The list comprehensions for accessing case regions work around what appears to be a bug in RegionSequence; using a comprehension with explicit indices avoids it. `case_block(i: int)` relies on the same workaround, where it is unavoidable.

>From 02fe9fb8b2360339b0085bd9952c8dccf638ad49 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Mon, 10 Nov 2025 20:13:10 -0800
Subject: [PATCH 1/3] [mlir][python] Wrappers for scf.index_switch

The C++ index switch op has utilities getCaseBlock(int i)
and getDefaultBlock(), so these have been added.

Optional builder args have been added for the default case
and each switch case.

The list comprehensions for accessing case regions are due
to what appears to be a bug in RegionSequence; using a comprehension
with explicit indices circumvents this.
The same paradigm is used for case_block(i: int), where it
is unavoidable.
---
 mlir/python/mlir/dialects/scf.py | 75 +++++++++++++++++++++++++++++++-
 mlir/test/python/dialects/scf.py | 72 ++++++++++++++++++++++++++++--
 2 files changed, 142 insertions(+), 5 deletions(-)

diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 678ceeebac204..6fc0034aa9859 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -6,6 +6,7 @@
 from ._scf_ops_gen import *
 from ._scf_ops_gen import _Dialect
 from .arith import constant
+import builtins
 
 try:
     from ..ir import *
@@ -19,7 +20,6 @@
 
 from typing import List, Optional, Sequence, Tuple, Union
 
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ForOp(ForOp):
     """Specialization for the SCF for op class."""
@@ -254,3 +254,76 @@ def for_(
             yield iv, iter_args[0], for_op.results[0]
         else:
             yield iv
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class IndexSwitchOp(IndexSwitchOp):
+    __doc__ = IndexSwitchOp.__doc__
+
+    def __init__(
+        self,
+        results_,
+        arg,
+        cases,
+        case_body_builder=None,
+        default_body_builder=None,
+        loc=None,
+        ip=None,
+    ):
+        cases = DenseI64ArrayAttr.get(cases)
+        super().__init__(
+            results_, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip
+        )
+        for region in self.regions:
+            region.blocks.append()
+
+        if default_body_builder is not None:
+            with InsertionPoint(self.default_block):
+                default_body_builder(self)
+
+        if case_body_builder is not None:
+            for i, case in enumerate(cases):
+                with InsertionPoint(self.case_block(i)):
+                    case_body_builder(self, i, self.cases[i])
+
+    @builtins.property
+    def default_region(self) -> Region:
+        return self.regions[0]
+
+    @builtins.property
+    def default_block(self) -> Block:
+        return self.default_region.blocks[0]
+
+    @builtins.property
+    def case_regions(self) -> Sequence[Region]:
+        return [self.regions[1 + i] for i in range(len(self.cases))]
+
+    def case_region(self, i: int) -> Region:
+        return self.case_regions[i]
+
+    @builtins.property
+    def case_blocks(self) -> Sequence[Block]:
+        return [region.blocks[0] for region in self.case_regions]
+
+    def case_block(self, i: int) -> Block:
+        return self.case_regions[i].blocks[0]
+
+def index_switch(
+    results_,
+    arg,
+    cases,
+    case_body_builder=None,
+    default_body_builder=None,
+    loc=None,
+    ip=None,
+) -> Union[OpResult, OpResultList, IndexSwitchOp]:
+    op = IndexSwitchOp(
+        results_=results_,
+        arg=arg,
+        cases=cases,
+        case_body_builder=case_body_builder,
+        default_body_builder=default_body_builder,
+        loc=loc,
+        ip=ip,
+    )
+    results = op.results
+    return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 62d11d5e189c8..a4293da945b12 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -1,10 +1,14 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
-from mlir.dialects import arith
-from mlir.dialects import func
-from mlir.dialects import memref
-from mlir.dialects import scf
+from mlir.extras import types as T
+from mlir.dialects import (
+    arith,
+    func,
+    memref,
+    scf,
+    cf,
+)
 from mlir.passmanager import PassManager
 
 
@@ -355,3 +359,63 @@ def simple_if_else(cond):
 # CHECK:   scf.yield %[[TWO]], %[[THREE]]
 # CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
 # CHECK: return
+
+
+@constructAndPrintInModule
+def testIndexSwitch():
+
+    i32 = T.i32()
+    @func.FuncOp.from_py_func(T.index(), results=[i32])
+    def index_switch(index):
+        c1 = arith.constant(i32, 1)
+        c0 = arith.constant(i32, 0)
+        value = arith.constant(i32, 5)
+        switch_op = scf.IndexSwitchOp([i32], index, range(3))
+
+        assert switch_op.regions[0] == switch_op.default_region
+        assert switch_op.regions[1] == switch_op.case_regions[0]
+        assert switch_op.regions[1] == switch_op.case_region(0)
+        assert len(switch_op.case_regions) == 3
+        assert len(switch_op.regions) == 4
+
+        with InsertionPoint(switch_op.default_block):
+            cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
+            scf.yield_([c1])
+
+        for i, block in enumerate(switch_op.case_blocks):
+            with InsertionPoint(block):
+                scf.YieldOp([arith.constant(i32, i)])
+
+        func.return_([switch_op.results[0]])
+
+    return index_switch
+
+
+@constructAndPrintInModule
+def testIndexSwitchWithBodyBuilders():
+
+    i32 = T.i32()
+    @func.FuncOp.from_py_func(T.index(), results=[i32])
+    def index_switch(index):
+        c1 = arith.constant(i32, 1)
+        c0 = arith.constant(i32, 0)
+        value = arith.constant(i32, 5)
+
+        def default_body_builder(switch_op):
+            cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
+            scf.yield_([c1])
+
+        def case_body_builder(switch_op, case_index: int, case_value: int):
+            scf.YieldOp([arith.constant(i32, case_value)])
+
+        result = scf.index_switch(
+            results_=[i32],
+            arg=index,
+            cases=range(3),
+            case_body_builder=case_body_builder,
+            default_body_builder=default_body_builder,
+        )
+
+        func.return_([result])
+
+    return index_switch

>From ece9d37f2df78dd235dbf37f78091aab7c21a463 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Mon, 10 Nov 2025 20:28:13 -0800
Subject: [PATCH 2/3] Use result builder utility

---
 mlir/python/mlir/dialects/scf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 6fc0034aa9859..8bb8052d9b0dd 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -13,6 +13,7 @@
     from ._ods_common import (
         get_op_result_or_value as _get_op_result_or_value,
         get_op_results_or_values as _get_op_results_or_values,
+        get_op_result_or_op_results as _get_op_result_or_op_results,
         _cext as _ods_cext,
     )
 except ImportError as e:
@@ -325,5 +326,4 @@ def index_switch(
         loc=loc,
         ip=ip,
     )
-    results = op.results
-    return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
+    return _get_op_result_or_op_results(op)

>From e5b1689526cedb8fba1da3c2f0fa18317945f476 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Mon, 10 Nov 2025 21:43:21 -0800
Subject: [PATCH 3/3] Formatting

---
 mlir/python/mlir/dialects/scf.py |  3 ++
 mlir/test/python/dialects/scf.py | 62 +++++++++++++++++++++++++++++---
 2 files changed, 61 insertions(+), 4 deletions(-)

diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 8bb8052d9b0dd..59ccbce147be3 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -21,6 +21,7 @@
 
 from typing import List, Optional, Sequence, Tuple, Union
 
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ForOp(ForOp):
     """Specialization for the SCF for op class."""
@@ -256,6 +257,7 @@ def for_(
         else:
             yield iv
 
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class IndexSwitchOp(IndexSwitchOp):
     __doc__ = IndexSwitchOp.__doc__
@@ -308,6 +310,7 @@ def case_blocks(self) -> Sequence[Block]:
     def case_block(self, i: int) -> Block:
         return self.case_regions[i].blocks[0]
 
+
 def index_switch(
     results_,
     arg,
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index a4293da945b12..11d207b4a5e07 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -363,8 +363,8 @@ def simple_if_else(cond):
 
 @constructAndPrintInModule
 def testIndexSwitch():
-
     i32 = T.i32()
+
     @func.FuncOp.from_py_func(T.index(), results=[i32])
     def index_switch(index):
         c1 = arith.constant(i32, 1)
@@ -384,17 +384,44 @@ def index_switch(index):
 
         for i, block in enumerate(switch_op.case_blocks):
             with InsertionPoint(block):
-                scf.YieldOp([arith.constant(i32, i)])
+                scf.yield_([arith.constant(i32, i)])
 
         func.return_([switch_op.results[0]])
 
     return index_switch
 
 
+# CHECK-LABEL:   func.func @index_switch(
+# CHECK-SAME:      %[[ARG0:.*]]: index) -> i32 {
+# CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1 : i32
+# CHECK:           %[[CONSTANT_1:.*]] = arith.constant 0 : i32
+# CHECK:           %[[CONSTANT_2:.*]] = arith.constant 5 : i32
+# CHECK:           %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
+# CHECK:           case 0 {
+# CHECK:             %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+# CHECK:             scf.yield %[[CONSTANT_3]] : i32
+# CHECK:           }
+# CHECK:           case 1 {
+# CHECK:             %[[CONSTANT_4:.*]] = arith.constant 1 : i32
+# CHECK:             scf.yield %[[CONSTANT_4]] : i32
+# CHECK:           }
+# CHECK:           case 2 {
+# CHECK:             %[[CONSTANT_5:.*]] = arith.constant 2 : i32
+# CHECK:             scf.yield %[[CONSTANT_5]] : i32
+# CHECK:           }
+# CHECK:           default {
+# CHECK:             %[[CONSTANT_6:.*]] = arith.constant false
+# CHECK:             cf.assert %[[CONSTANT_6]], "Whoops!"
+# CHECK:             scf.yield %[[CONSTANT_0]] : i32
+# CHECK:           }
+# CHECK:           return %[[INDEX_SWITCH_0]] : i32
+# CHECK:         }
+
+
 @constructAndPrintInModule
 def testIndexSwitchWithBodyBuilders():
-
     i32 = T.i32()
+
     @func.FuncOp.from_py_func(T.index(), results=[i32])
     def index_switch(index):
         c1 = arith.constant(i32, 1)
@@ -406,7 +433,7 @@ def default_body_builder(switch_op):
             scf.yield_([c1])
 
         def case_body_builder(switch_op, case_index: int, case_value: int):
-            scf.YieldOp([arith.constant(i32, case_value)])
+            scf.yield_([arith.constant(i32, case_value)])
 
         result = scf.index_switch(
             results_=[i32],
@@ -419,3 +446,30 @@ def case_body_builder(switch_op, case_index: int, case_value: int):
         func.return_([result])
 
     return index_switch
+
+
+# CHECK-LABEL:   func.func @index_switch(
+# CHECK-SAME:      %[[ARG0:.*]]: index) -> i32 {
+# CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1 : i32
+# CHECK:           %[[CONSTANT_1:.*]] = arith.constant 0 : i32
+# CHECK:           %[[CONSTANT_2:.*]] = arith.constant 5 : i32
+# CHECK:           %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
+# CHECK:           case 0 {
+# CHECK:             %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+# CHECK:             scf.yield %[[CONSTANT_3]] : i32
+# CHECK:           }
+# CHECK:           case 1 {
+# CHECK:             %[[CONSTANT_4:.*]] = arith.constant 1 : i32
+# CHECK:             scf.yield %[[CONSTANT_4]] : i32
+# CHECK:           }
+# CHECK:           case 2 {
+# CHECK:             %[[CONSTANT_5:.*]] = arith.constant 2 : i32
+# CHECK:             scf.yield %[[CONSTANT_5]] : i32
+# CHECK:           }
+# CHECK:           default {
+# CHECK:             %[[CONSTANT_6:.*]] = arith.constant false
+# CHECK:             cf.assert %[[CONSTANT_6]], "Whoops!"
+# CHECK:             scf.yield %[[CONSTANT_0]] : i32
+# CHECK:           }
+# CHECK:           return %[[INDEX_SWITCH_0]] : i32
+# CHECK:         }



More information about the Mlir-commits mailing list