[Mlir-commits] [mlir] Makslevental/scf forall parallel (PR #150243)

Maksim Levental llvmlistbot at llvm.org
Wed Jul 23 09:25:53 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/150243

>From 159214a4928bb0e3313e270f8f3255c24d3f18de Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 11:43:45 -0400
Subject: [PATCH 1/7] move helpers

---
 mlir/python/mlir/dialects/arith.py            |  26 +----
 .../dialects/linalg/opdsl/lang/emitter.py     | 106 +++++++-----------
 mlir/python/mlir/util.py                      |  47 ++++++++
 3 files changed, 89 insertions(+), 90 deletions(-)
 create mode 100644 mlir/python/mlir/util.py

diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 92da5df9bce66..3b60ed8ddf94c 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -5,9 +5,11 @@
 from ._arith_ops_gen import *
 from ._arith_ops_gen import _Dialect
 from ._arith_enum_gen import *
+from ..util import is_integer_type, is_index_type, is_float_type
 from array import array as _array
 from typing import overload
 
+
 try:
     from ..ir import *
     from ._ods_common import (
@@ -21,26 +23,6 @@
     raise RuntimeError("Error loading imports from extension module") from e
 
 
-def _isa(obj: Any, cls: type):
-    try:
-        cls(obj)
-    except ValueError:
-        return False
-    return True
-
-
-def _is_any_of(obj: Any, classes: List[type]):
-    return any(_isa(obj, cls) for cls in classes)
-
-
-def _is_integer_like_type(type: Type):
-    return _is_any_of(type, [IntegerType, IndexType])
-
-
-def _is_float_type(type: Type):
-    return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ConstantOp(ConstantOp):
     """Specialization for the constant op class."""
@@ -96,9 +78,9 @@ def value(self):
 
     @property
     def literal_value(self) -> Union[int, float]:
-        if _is_integer_like_type(self.type):
+        if is_integer_type(self.type) or is_index_type(self.type):
             return IntegerAttr(self.value).value
-        elif _is_float_type(self.type):
+        elif is_float_type(self.type):
             return FloatAttr(self.value).value
         else:
             raise ValueError("only integer and float constants have literal values")
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..cae70fc03b9d6 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -5,6 +5,13 @@
 from typing import Callable, Dict, List, Sequence, Tuple, Union
 
 from .....ir import *
+from .....util import (
+    is_complex_type,
+    is_float_type,
+    is_index_type,
+    is_integer_type,
+    get_floating_point_width,
+)
 
 from .... import func
 from .... import linalg
@@ -412,9 +419,9 @@ def _cast(
             )
         if operand.type == to_type:
             return operand
-        if _is_integer_type(to_type):
+        if is_integer_type(to_type):
             return self._cast_to_integer(to_type, operand, is_unsigned_cast)
-        elif _is_floating_point_type(to_type):
+        elif is_float_type(to_type):
             return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
 
     def _cast_to_integer(
@@ -422,11 +429,11 @@ def _cast_to_integer(
     ) -> Value:
         to_width = IntegerType(to_type).width
         operand_type = operand.type
-        if _is_floating_point_type(operand_type):
+        if is_float_type(operand_type):
             if is_unsigned_cast:
                 return arith.FPToUIOp(to_type, operand).result
             return arith.FPToSIOp(to_type, operand).result
-        if _is_index_type(operand_type):
+        if is_index_type(operand_type):
             return arith.IndexCastOp(to_type, operand).result
         # Assume integer.
         from_width = IntegerType(operand_type).width
@@ -444,13 +451,13 @@ def _cast_to_floating_point(
         self, to_type: Type, operand: Value, is_unsigned_cast: bool
     ) -> Value:
         operand_type = operand.type
-        if _is_integer_type(operand_type):
+        if is_integer_type(operand_type):
             if is_unsigned_cast:
                 return arith.UIToFPOp(to_type, operand).result
             return arith.SIToFPOp(to_type, operand).result
         # Assume FloatType.
-        to_width = _get_floating_point_width(to_type)
-        from_width = _get_floating_point_width(operand_type)
+        to_width = get_floating_point_width(to_type)
+        from_width = get_floating_point_width(operand_type)
         if to_width > from_width:
             return arith.ExtFOp(to_type, operand).result
         elif to_width < from_width:
@@ -466,89 +473,89 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
         return self._cast(type_var_name, operand, True)
 
     def _unary_exp(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return math.ExpOp(x).result
         raise NotImplementedError("Unsupported 'exp' operand: {x}")
 
     def _unary_log(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return math.LogOp(x).result
         raise NotImplementedError("Unsupported 'log' operand: {x}")
 
     def _unary_abs(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return math.AbsFOp(x).result
         raise NotImplementedError("Unsupported 'abs' operand: {x}")
 
     def _unary_ceil(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return math.CeilOp(x).result
         raise NotImplementedError("Unsupported 'ceil' operand: {x}")
 
     def _unary_floor(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return math.FloorOp(x).result
         raise NotImplementedError("Unsupported 'floor' operand: {x}")
 
     def _unary_negf(self, x: Value) -> Value:
-        if _is_floating_point_type(x.type):
+        if is_float_type(x.type):
             return arith.NegFOp(x).result
-        if _is_complex_type(x.type):
+        if is_complex_type(x.type):
             return complex.NegOp(x).result
         raise NotImplementedError("Unsupported 'negf' operand: {x}")
 
     def _binary_add(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.AddFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.AddIOp(lhs, rhs).result
-        if _is_complex_type(lhs.type):
+        if is_complex_type(lhs.type):
             return complex.AddOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
 
     def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.SubFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.SubIOp(lhs, rhs).result
-        if _is_complex_type(lhs.type):
+        if is_complex_type(lhs.type):
             return complex.SubOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
 
     def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.MulFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.MulIOp(lhs, rhs).result
-        if _is_complex_type(lhs.type):
+        if is_complex_type(lhs.type):
             return complex.MulOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
 
     def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.MaximumFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.MaxSIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
 
     def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.MaximumFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.MaxUIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
 
     def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.MinimumFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.MinSIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
 
     def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
-        if _is_floating_point_type(lhs.type):
+        if is_float_type(lhs.type):
             return arith.MinimumFOp(lhs, rhs).result
-        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+        if is_integer_type(lhs.type) or is_index_type(lhs.type):
             return arith.MinUIOp(lhs, rhs).result
         raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
 
@@ -609,40 +616,3 @@ def _add_type_mapping(
             )
     type_mapping[name] = element_or_self_type
     block_arg_types.append(element_or_self_type)
-
-
-def _is_complex_type(t: Type) -> bool:
-    return ComplexType.isinstance(t)
-
-
-def _is_floating_point_type(t: Type) -> bool:
-    # TODO: Create a FloatType in the Python API and implement the switch
-    # there.
-    return (
-        F64Type.isinstance(t)
-        or F32Type.isinstance(t)
-        or F16Type.isinstance(t)
-        or BF16Type.isinstance(t)
-    )
-
-
-def _is_integer_type(t: Type) -> bool:
-    return IntegerType.isinstance(t)
-
-
-def _is_index_type(t: Type) -> bool:
-    return IndexType.isinstance(t)
-
-
-def _get_floating_point_width(t: Type) -> int:
-    # TODO: Create a FloatType in the Python API and implement the switch
-    # there.
-    if F64Type.isinstance(t):
-        return 64
-    if F32Type.isinstance(t):
-        return 32
-    if F16Type.isinstance(t):
-        return 16
-    if BF16Type.isinstance(t):
-        return 16
-    raise NotImplementedError(f"Unhandled floating point type switch {t}")
diff --git a/mlir/python/mlir/util.py b/mlir/python/mlir/util.py
new file mode 100644
index 0000000000000..cc85a99337f38
--- /dev/null
+++ b/mlir/python/mlir/util.py
@@ -0,0 +1,65 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Type-predicate helpers shared by the MLIR Python dialect wrappers."""
+
+from .ir import (
+    BF16Type,
+    ComplexType,
+    F16Type,
+    F32Type,
+    F64Type,
+    IndexType,
+    IntegerType,
+    Type,
+)
+
+
+def is_complex_type(t: Type) -> bool:
+    """Return True if ``t`` is a ``ComplexType``."""
+    return ComplexType.isinstance(t)
+
+
+def is_float_type(t: Type) -> bool:
+    """Return True if ``t`` is one of f64, f32, f16 or bf16.
+
+    Any other type, including other floating-point formats, yields False.
+    """
+    # TODO: Create a FloatType in the Python API and implement the switch
+    # there.
+    return (
+        F64Type.isinstance(t)
+        or F32Type.isinstance(t)
+        or F16Type.isinstance(t)
+        or BF16Type.isinstance(t)
+    )
+
+
+def is_integer_type(t: Type) -> bool:
+    """Return True if ``t`` is an ``IntegerType``."""
+    return IntegerType.isinstance(t)
+
+
+def is_index_type(t: Type) -> bool:
+    """Return True if ``t`` is the builtin ``index`` type."""
+    return IndexType.isinstance(t)
+
+
+def get_floating_point_width(t: Type) -> int:
+    # TODO: Create a FloatType in the Python API and implement the switch
+    # there.
+    """Return the bit width of an f64, f32, f16 or bf16 type.
+
+    Raises:
+        NotImplementedError: if ``t`` is not one of those four types.
+    """
+    if F64Type.isinstance(t):
+        return 64
+    if F32Type.isinstance(t):
+        return 32
+    if F16Type.isinstance(t):
+        return 16
+    if BF16Type.isinstance(t):
+        return 16
+    raise NotImplementedError(f"Unhandled floating point type switch {t}")

>From f68d0669a18a3d5769fb20a83e8832db688a82f4 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 11:43:56 -0400
Subject: [PATCH 2/7] add index_cast

---
 mlir/python/mlir/dialects/arith.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 3b60ed8ddf94c..418ec0bdbb6db 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -90,3 +90,29 @@ def constant(
     result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
 ) -> Value:
     return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
+
+
+def index_cast(
+    in_: Value,
+    to: Type = None,
+    *,
+    out: Type = None,
+    loc: Location = None,
+    ip: InsertionPoint = None,
+) -> Value:
+    """Create an ``arith.index_cast`` of ``in_`` and return its result.
+
+    The result type may be passed positionally as ``to`` or by keyword as
+    ``out`` (at most one of them); if neither is given it defaults to index.
+
+    Raises:
+        ValueError: if both ``to`` and ``out`` are given.
+    """
+    # Only the ambiguous case (both spellings supplied) is an error; passing
+    # exactly one of them, or neither, is valid.
+    if to is not None and out is not None:
+        raise ValueError("either `to` or `out` must be set but not both")
+    res_type = out if out is not None else to
+    if res_type is None:
+        res_type = IndexType.get()
+    return _get_op_result_or_op_results(IndexCastOp(res_type, in_, loc=loc, ip=ip))

>From f67bbd06546541fb4a0a68a59641a444c3b9116c Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 11:44:11 -0400
Subject: [PATCH 3/7] add region_adder

---
 mlir/python/mlir/extras/meta.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/mlir/python/mlir/extras/meta.py b/mlir/python/mlir/extras/meta.py
index 3f2defadf7941..fabe1d8e141ed 100644
--- a/mlir/python/mlir/extras/meta.py
+++ b/mlir/python/mlir/extras/meta.py
@@ -6,7 +6,7 @@
 from functools import wraps
 
 from ..dialects._ods_common import get_op_result_or_op_results
-from ..ir import Type, InsertionPoint
+from ..ir import Type, InsertionPoint, Value
 
 
 def op_region_builder(op, op_region, terminator=None):
@@ -81,3 +81,26 @@ def maybe_no_args(*args, **kwargs):
             return op_decorator(*args, **kwargs)
 
     return maybe_no_args
+
+
+def region_adder(terminator=None):
+    """Build a decorator that fills an additional region of an existing op.
+
+    The decorated function receives the op (an ``OpView``, or a ``Value``
+    whose owning op is used instead) plus any extra arguments, and returns
+    the region to populate; that region is then passed to
+    ``op_region_builder`` together with ``terminator``.
+    """
+
+    def wrapper(op_region_adder):
+        @wraps(op_region_adder)
+        def region_adder_decorator(op, *args, **kwargs):
+            if isinstance(op, Value):
+                op = op.owner.opview
+            region = op_region_adder(op, *args, **kwargs)
+
+            return op_region_builder(op, region, terminator)
+
+        return region_adder_decorator
+
+    return wrapper

>From 567f32e7c486349850c3e85c7da8c8c0af2c96a3 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 11:44:36 -0400
Subject: [PATCH 4/7] add forall parallel helper

---
 mlir/python/mlir/dialects/scf.py | 105 ++++++++++++++++++++-
 mlir/test/python/dialects/scf.py | 152 ++++++++++++++++++++++++++++++-
 2 files changed, 254 insertions(+), 3 deletions(-)

diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 678ceeebac204..b77cfa8668b65 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -5,10 +5,12 @@
 
 from ._scf_ops_gen import *
 from ._scf_ops_gen import _Dialect
-from .arith import constant
+from . import arith
+from ..extras.meta import region_op, region_adder
 
 try:
     from ..ir import *
+    from ..util import is_index_type
     from ._ods_common import (
         get_op_result_or_value as _get_op_result_or_value,
         get_op_results_or_values as _get_op_results_or_values,
@@ -237,7 +239,7 @@ def for_(
     params = [start, stop, step]
     for i, p in enumerate(params):
         if isinstance(p, int):
-            p = constant(IndexType.get(), p)
+            p = arith.constant(IndexType.get(), p)
         elif isinstance(p, float):
             raise ValueError(f"{p=} must be int.")
         params[i] = p
@@ -254,3 +256,102 @@ def for_(
             yield iv, iter_args[0], for_op.results[0]
         else:
             yield iv
+
+
+def _parfor(op_ctor):
+    def _base(
+        lower_bounds, upper_bounds=None, steps=None, *, loc=None, ip=None, **kwargs
+    ):
+        if upper_bounds is None:
+            upper_bounds = lower_bounds
+            lower_bounds = [0] * len(upper_bounds)
+        if steps is None:
+            steps = [1] * len(lower_bounds)
+
+        params = [lower_bounds, upper_bounds, steps]
+        for i, p in enumerate(params):
+            for j, pp in enumerate(p):
+                if isinstance(p, int):
+                    pp = arith.constant(IndexType.get(), p)
+                if not is_index_type(pp.type):
+                    pp = arith.index_cast(pp)
+                p[j] = pp
+            params[i] = p
+
+        return op_ctor(*params, loc=loc, ip=ip, **kwargs)
+
+    return _base
+
+
+def _parfor_cm(op_ctor):
+    def _base(*args, **kwargs):
+        for_op = _parfor(op_ctor)(*args, **kwargs)
+        block = for_op.regions[0].blocks[0]
+        block_args = tuple(block.arguments)
+        with InsertionPoint(block):
+            yield block_args
+
+    return _base
+
+
+forall = _parfor_cm(ForallOp)
+
+
+class ParallelOp(ParallelOp):
+    def __init__(
+        self,
+        lower_bounds,
+        upper_bounds,
+        steps,
+        inits: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        assert len(lower_bounds) == len(upper_bounds) == len(steps)
+        if inits is None:
+            inits = []
+        results = [i.type for i in inits]
+        iv_types = [IndexType.get()] * len(lower_bounds)
+        super().__init__(
+            results,
+            lower_bounds,
+            upper_bounds,
+            steps,
+            inits,
+            loc=loc,
+            ip=ip,
+        )
+        self.regions[0].blocks.append(*iv_types)
+
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]
+
+    @property
+    def induction_variables(self):
+        return self.body.arguments
+
+
+parallel = _parfor_cm(ParallelOp)
+
+
+class ReduceOp(ReduceOp):
+    def __init__(self, operands, num_reductions, *, loc=None, ip=None):
+        super().__init__(operands, num_reductions, loc=loc, ip=ip)
+        for i in range(num_reductions):
+            self.regions[i].blocks.append(operands[i].type, operands[i].type)
+
+
+def reduce_(*operands, num_reductions=1):
+    return ReduceOp(operands, num_reductions, loc=loc)
+
+
+reduce = region_op(reduce_, terminator=lambda xs: reduce_return(*xs))
+
+
+@region_adder(terminator=lambda xs: reduce_return(*xs))
+def another_reduce(reduce_op):
+    for r in reduce_op.regions:
+        if len(r.blocks[0].operations) == 0:
+            return r
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 62d11d5e189c8..45ac71e4b0bc7 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -5,7 +5,7 @@
 from mlir.dialects import func
 from mlir.dialects import memref
 from mlir.dialects import scf
-from mlir.passmanager import PassManager
+from mlir.dialects import tensor
 
 
 def constructAndPrintInModule(f):
@@ -38,6 +38,156 @@ def forall_loop(tensor):
         assert loop.verify()
 
 
+# CHECK-LABEL: TEST: test_forall_insert_slice_no_region_with_for
+@constructAndPrintInModule
+def test_forall_insert_slice_no_region_with_for():
+    i32 = IntegerType.get_signless(32)
+    f32 = F32Type.get()
+    ten = tensor.empty([10, 10], i32)
+
+    for i, j, shared_outs in scf.forall([1, 1], [2, 2], [3, 3], shared_outs=[ten]):
+        one = arith.constant(f32, 1.0)
+
+        scf.parallel_insert_slice(
+            ten,
+            shared_outs,
+            offsets=[i, j],
+            static_sizes=[10, 10],
+            static_strides=[1, 1],
+        )
+
+    # CHECK:  %[[VAL_0:.*]] = tensor.empty() : tensor<10x10xi32>
+    # CHECK:  %[[VAL_1:.*]] = arith.constant 1 : index
+    # CHECK:  %[[VAL_2:.*]] = arith.constant 1 : index
+    # CHECK:  %[[VAL_3:.*]] = arith.constant 2 : index
+    # CHECK:  %[[VAL_4:.*]] = arith.constant 2 : index
+    # CHECK:  %[[VAL_5:.*]] = arith.constant 3 : index
+    # CHECK:  %[[VAL_6:.*]] = arith.constant 3 : index
+    # CHECK:  %[[VAL_7:.*]] = scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) shared_outs(%[[VAL_10:.*]] = %[[VAL_0]]) -> (tensor<10x10xi32>) {
+    # CHECK:    %[[VAL_11:.*]] = arith.constant 1.000000e+00 : f32
+    # CHECK:    scf.forall.in_parallel {
+    # CHECK:      tensor.parallel_insert_slice %[[VAL_0]] into %[[VAL_10]]{{\[}}%[[VAL_8]], %[[VAL_9]]] [10, 10] [1, 1] : tensor<10x10xi32> into tensor<10x10xi32>
+    # CHECK:    }
+    # CHECK:  }
+
+
+# CHECK-LABEL: TEST: test_parange_inits_with_for
+@constructAndPrintInModule
+def test_parange_inits_with_for():
+    i32 = IntegerType.get_signless(32)
+    f32 = F32Type.get()
+    tensor_type = RankedTensorType.get([10, 10], f32)
+    ten = tensor.empty([10, 10], i32)
+
+    for i, j in scf.parallel([1, 1], [2, 2], [3, 3], inits=[ten]):
+        one = arith.constant(f32, 1.0)
+        ten2 = tensor.empty([10, 10], i32)
+
+        @scf.reduce(ten2)
+        def res(lhs: tensor_type, rhs: tensor_type):
+            return lhs + rhs
+
+    # CHECK:  %[[VAL_0:.*]] = tensor.empty() : tensor<10x10xi32>
+    # CHECK:  %[[VAL_1:.*]] = arith.constant 1 : index
+    # CHECK:  %[[VAL_2:.*]] = arith.constant 1 : index
+    # CHECK:  %[[VAL_3:.*]] = arith.constant 2 : index
+    # CHECK:  %[[VAL_4:.*]] = arith.constant 2 : index
+    # CHECK:  %[[VAL_5:.*]] = arith.constant 3 : index
+    # CHECK:  %[[VAL_6:.*]] = arith.constant 3 : index
+    # CHECK:  %[[VAL_7:.*]] = scf.parallel (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) init (%[[VAL_0]]) -> tensor<10x10xi32> {
+    # CHECK:    %[[VAL_10:.*]] = arith.constant 1.000000e+00 : f32
+    # CHECK:    %[[VAL_11:.*]] = tensor.empty() : tensor<10x10xi32>
+    # CHECK:    scf.reduce(%[[VAL_11]] : tensor<10x10xi32>) {
+    # CHECK:    ^bb0(%[[VAL_12:.*]]: tensor<10x10xi32>, %[[VAL_13:.*]]: tensor<10x10xi32>):
+    # CHECK:      %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : tensor<10x10xi32>
+    # CHECK:      scf.reduce.return %[[VAL_14]] : tensor<10x10xi32>
+    # CHECK:    }
+    # CHECK:  }
+
+
+# CHECK-LABEL: TEST: test_parange_inits_with_for_with_two_reduce
+@constructAndPrintInModule
+def test_parange_inits_with_for_with_two_reduce():
+    index_type = IndexType.get()
+    one = arith.constant(index_type, 1)
+
+    for i, j in scf.parallel([1, 1], [2, 2], [3, 3], inits=[one, one]):
+
+        @scf.reduce(i, j, num_reductions=2)
+        def res1(lhs: index_type, rhs: index_type):
+            return lhs + rhs
+
+        @scf.another_reduce(res1)
+        def res2(lhs: index_type, rhs: index_type):
+            return lhs + rhs
+
+    # CHECK:  %[[VAL_0:.*]] = arith.constant 1 : index
+    # CHECK:  %[[VAL_1:.*]] = arith.constant 1 : index
+    # CHECK:  %[[VAL_2:.*]] = arith.constant 1 : index
+    # CHECK:  %[[VAL_3:.*]] = arith.constant 2 : index
+    # CHECK:  %[[VAL_4:.*]] = arith.constant 2 : index
+    # CHECK:  %[[VAL_5:.*]] = arith.constant 3 : index
+    # CHECK:  %[[VAL_6:.*]] = arith.constant 3 : index
+    # CHECK:  %[[VAL_7:.*]]:2 = scf.parallel (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) init (%[[VAL_0]], %[[VAL_0]]) -> (index, index) {
+    # CHECK:    scf.reduce(%[[VAL_8]], %[[VAL_9]] : index, index) {
+    # CHECK:    ^bb0(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index):
+    # CHECK:      %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : index
+    # CHECK:      scf.reduce.return %[[VAL_12]] : index
+    # CHECK:    }, {
+    # CHECK:    ^bb0(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: index):
+    # CHECK:      %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index
+    # CHECK:      scf.reduce.return %[[VAL_15]] : index
+    # CHECK:    }
+    # CHECK:  }
+
+
+# CHECK-LABEL: TEST: test_parange_inits_with_for_with_three_reduce
+@constructAndPrintInModule
+def test_parange_inits_with_for_with_three_reduce():
+    index_type = IndexType.get()
+    one = arith.constant(index_type, 1)
+
+    for i, j, k in scf.parallel([1, 1, 1], [2, 2, 2], [3, 3, 3], inits=[one, one, one]):
+
+        @scf.reduce(i, j, k, num_reductions=3)
+        def res1(lhs: index_type, rhs: index_type):
+            return lhs + rhs
+
+        @scf.another_reduce(res1)
+        def res2(lhs: index_type, rhs: index_type):
+            return lhs + rhs
+
+        @scf.another_reduce(res2)
+        def res3(lhs: index_type, rhs: index_type):
+            return lhs + rhs
+
+    # CHECK:  %[[VAL_0:.*]] = arith.constant 1 : index
+    # CHECK:  %[[VAL_1:.*]] = arith.constant 1 : index
+    # CHECK:  %[[VAL_2:.*]] = arith.constant 1 : index
+    # CHECK:  %[[VAL_3:.*]] = arith.constant 1 : index
+    # CHECK:  %[[VAL_4:.*]] = arith.constant 2 : index
+    # CHECK:  %[[VAL_5:.*]] = arith.constant 2 : index
+    # CHECK:  %[[VAL_6:.*]] = arith.constant 2 : index
+    # CHECK:  %[[VAL_7:.*]] = arith.constant 3 : index
+    # CHECK:  %[[VAL_8:.*]] = arith.constant 3 : index
+    # CHECK:  %[[VAL_9:.*]] = arith.constant 3 : index
+    # CHECK:  %[[VAL_10:.*]]:3 = scf.parallel (%[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]]) = (%[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) to (%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) step (%[[VAL_7]], %[[VAL_8]], %[[VAL_9]]) init (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]]) -> (index, index, index) {
+    # CHECK:    scf.reduce(%[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : index, index, index) {
+    # CHECK:    ^bb0(%[[VAL_14:.*]]: index, %[[VAL_15:.*]]: index):
+    # CHECK:      %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : index
+    # CHECK:      scf.reduce.return %[[VAL_16]] : index
+    # CHECK:    }, {
+    # CHECK:    ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index):
+    # CHECK:      %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_18]] : index
+    # CHECK:      scf.reduce.return %[[VAL_19]] : index
+    # CHECK:    }, {
+    # CHECK:    ^bb0(%[[VAL_20:.*]]: index, %[[VAL_21:.*]]: index):
+    # CHECK:      %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : index
+    # CHECK:      scf.reduce.return %[[VAL_22]] : index
+    # CHECK:    }
+    # CHECK:  }
+
+
 # CHECK-LABEL: TEST: testSimpleLoop
 @constructAndPrintInModule
 def testSimpleLoop():

>From 3f645bb8672503e577b490bd2951bf20628bd9b7 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 12:24:38 -0400
Subject: [PATCH 5/7] fix util helpers

---
 mlir/python/CMakeLists.txt                 | 1 +
 mlir/python/mlir/dialects/memref.py        | 5 +++--
 mlir/python/mlir/util.py                   | 4 ++++
 mlir/test/python/dialects/arith_dialect.py | 7 +++----
 4 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 7a0c95ebb8200..f01798f48ff86 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -23,6 +23,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     passmanager.py
     rewrite.py
     dialects/_ods_common.py
+    util.py
 
     # The main _mlir module has submodules: include stubs from each.
     _mlir_libs/_mlir/__init__.pyi
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index bc9a3a52728ad..2130a1966d88e 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -7,8 +7,9 @@
 
 from ._memref_ops_gen import *
 from ._ods_common import _dispatch_mixed_values, MixedValues
-from .arith import ConstantOp, _is_integer_like_type
+from .arith import ConstantOp
 from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
+from ..util import is_integer_like_type
 
 
 def _is_constant_int_like(i):
@@ -16,7 +17,7 @@ def _is_constant_int_like(i):
         isinstance(i, Value)
         and isinstance(i.owner, Operation)
         and isinstance(i.owner.opview, ConstantOp)
-        and _is_integer_like_type(i.type)
+        and is_integer_like_type(i.type)
     )
 
 
diff --git a/mlir/python/mlir/util.py b/mlir/python/mlir/util.py
index cc85a99337f38..453b74777014f 100644
--- a/mlir/python/mlir/util.py
+++ b/mlir/python/mlir/util.py
@@ -33,6 +33,10 @@ def is_index_type(t: Type) -> bool:
     return IndexType.isinstance(t)
 
 
+def is_integer_like_type(t: Type) -> bool:
+    return is_integer_type(t) or is_index_type(t)
+
+
 def get_floating_point_width(t: Type) -> int:
     # TODO: Create a FloatType in the Python API and implement the switch
     # there.
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index c9af5e7b46db8..0a197c4e673f9 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -4,6 +4,7 @@
 from mlir.ir import *
 import mlir.dialects.arith as arith
 import mlir.dialects.func as func
+from mlir.util import is_float_type, is_integer_like_type
 from array import array
 
 
@@ -42,11 +43,9 @@ def testFastMathFlags():
 def testArithValue():
     def _binary_op(lhs, rhs, op: str) -> "ArithValue":
         op = op.capitalize()
-        if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+        if is_float_type(lhs.type) and is_float_type(rhs.type):
             op += "F"
-        elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
-            lhs.type
-        ):
+        elif is_integer_like_type(lhs.type) and is_integer_like_type(rhs.type):
             op += "I"
         else:
             raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")

>From 7f3b59ef83930607ce2d5b3b6796716659312199 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 12:25:00 -0400
Subject: [PATCH 6/7] add parallel_insert_slice

---
 mlir/python/mlir/dialects/scf.py    | 31 +++++++++++++++++++++
 mlir/python/mlir/dialects/tensor.py | 42 +++++++++++++++++++++++++++++
 2 files changed, 73 insertions(+)

diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index b77cfa8668b65..95a5e1c901e66 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -355,3 +355,34 @@ def another_reduce(reduce_op):
     for r in reduce_op.regions:
         if len(r.blocks[0].operations) == 0:
             return r
+
+
+@region_op
+def in_parallel():
+    return InParallelOp()
+
+
+def parallel_insert_slice(
+    source,
+    dest,
+    static_offsets=None,
+    static_sizes=None,
+    static_strides=None,
+    offsets=None,
+    sizes=None,
+    strides=None,
+):
+    from . import tensor
+
+    @in_parallel
+    def foo():
+        tensor.parallel_insert_slice(
+            source,
+            dest,
+            offsets,
+            sizes,
+            strides,
+            static_offsets,
+            static_sizes,
+            static_strides,
+        )
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 146b5f85d07f5..b1baa22b15e23 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -65,3 +65,45 @@ def empty(
     lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
     terminator=lambda args: YieldOp(args[0]),
 )
+
+
+def parallel_insert_slice(
+    source,
+    dest,
+    offsets=None,
+    sizes=None,
+    strides=None,
+    static_offsets=None,
+    static_sizes=None,
+    static_strides=None,
+):
+    S = ShapedType.get_dynamic_size()
+    if static_offsets is None:
+        assert offsets is not None
+        static_offsets = [S, S]
+    if static_sizes is None:
+        assert sizes is not None
+        static_sizes = [S, S]
+    if static_strides is None:
+        assert strides is not None
+        static_strides = [S, S]
+    if offsets is None:
+        assert static_offsets
+        offsets = []
+    if sizes is None:
+        assert static_sizes
+        sizes = []
+    if strides is None:
+        assert static_strides
+        strides = []
+
+    return ParallelInsertSliceOp(
+        source,
+        dest,
+        offsets,
+        sizes,
+        strides,
+        static_offsets,
+        static_sizes,
+        static_strides,
+    )

>From 82d66d01fca460b3653f23df1842345c208c947d Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 12:25:39 -0400
Subject: [PATCH 7/7] fix scf

---
 mlir/python/mlir/dialects/scf.py |  9 +++++----
 mlir/test/python/dialects/scf.py | 12 ++++++------
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 95a5e1c901e66..3b58d5c1c48b6 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -271,8 +271,9 @@ def _base(
         params = [lower_bounds, upper_bounds, steps]
         for i, p in enumerate(params):
             for j, pp in enumerate(p):
-                if isinstance(p, int):
-                    pp = arith.constant(IndexType.get(), p)
+                if isinstance(pp, int):
+                    pp = arith.constant(IndexType.get(), pp)
+                assert isinstance(pp, Value), f"expected ir.Value, got {type(pp)=}"
                 if not is_index_type(pp.type):
                     pp = arith.index_cast(pp)
                 p[j] = pp
@@ -343,8 +344,8 @@ def __init__(self, operands, num_reductions, *, loc=None, ip=None):
             self.regions[i].blocks.append(operands[i].type, operands[i].type)
 
 
-def reduce_(*operands, num_reductions=1):
-    return ReduceOp(operands, num_reductions, loc=loc)
+def reduce_(*operands, num_reductions=1, loc=None, ip=None):
+    return ReduceOp(operands, num_reductions, loc=loc, ip=ip)
 
 
 reduce = region_op(reduce_, terminator=lambda xs: reduce_return(*xs))
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 45ac71e4b0bc7..e374a6114662a 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -85,7 +85,7 @@ def test_parange_inits_with_for():
 
         @scf.reduce(ten2)
         def res(lhs: tensor_type, rhs: tensor_type):
-            return lhs + rhs
+            return arith.addi(lhs, rhs)
 
     # CHECK:  %[[VAL_0:.*]] = tensor.empty() : tensor<10x10xi32>
     # CHECK:  %[[VAL_1:.*]] = arith.constant 1 : index
@@ -115,11 +115,11 @@ def test_parange_inits_with_for_with_two_reduce():
 
         @scf.reduce(i, j, num_reductions=2)
         def res1(lhs: index_type, rhs: index_type):
-            return lhs + rhs
+            return arith.addi(lhs, rhs)
 
         @scf.another_reduce(res1)
         def res2(lhs: index_type, rhs: index_type):
-            return lhs + rhs
+            return arith.addi(lhs, rhs)
 
     # CHECK:  %[[VAL_0:.*]] = arith.constant 1 : index
     # CHECK:  %[[VAL_1:.*]] = arith.constant 1 : index
@@ -151,15 +151,15 @@ def test_parange_inits_with_for_with_three_reduce():
 
         @scf.reduce(i, j, k, num_reductions=3)
         def res1(lhs: index_type, rhs: index_type):
-            return lhs + rhs
+            return arith.addi(lhs, rhs)
 
         @scf.another_reduce(res1)
         def res2(lhs: index_type, rhs: index_type):
-            return lhs + rhs
+            return arith.addi(lhs, rhs)
 
         @scf.another_reduce(res2)
         def res3(lhs: index_type, rhs: index_type):
-            return lhs + rhs
+            return arith.addi(lhs, rhs)
 
     # CHECK:  %[[VAL_0:.*]] = arith.constant 1 : index
     # CHECK:  %[[VAL_1:.*]] = arith.constant 1 : index



More information about the Mlir-commits mailing list