[Mlir-commits] [mlir] Makslevental/scf forall parallel (PR #150243)
Maksim Levental
llvmlistbot at llvm.org
Wed Jul 23 09:25:53 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/150243
>From 159214a4928bb0e3313e270f8f3255c24d3f18de Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 11:43:45 -0400
Subject: [PATCH 1/7] move helpers
---
mlir/python/mlir/dialects/arith.py | 26 +----
.../dialects/linalg/opdsl/lang/emitter.py | 106 +++++++-----------
mlir/python/mlir/util.py | 47 ++++++++
3 files changed, 89 insertions(+), 90 deletions(-)
create mode 100644 mlir/python/mlir/util.py
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 92da5df9bce66..3b60ed8ddf94c 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -5,9 +5,11 @@
from ._arith_ops_gen import *
from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
+from ..util import is_integer_type, is_index_type, is_float_type
from array import array as _array
from typing import overload
+
try:
from ..ir import *
from ._ods_common import (
@@ -21,26 +23,6 @@
raise RuntimeError("Error loading imports from extension module") from e
-def _isa(obj: Any, cls: type):
- try:
- cls(obj)
- except ValueError:
- return False
- return True
-
-
-def _is_any_of(obj: Any, classes: List[type]):
- return any(_isa(obj, cls) for cls in classes)
-
-
-def _is_integer_like_type(type: Type):
- return _is_any_of(type, [IntegerType, IndexType])
-
-
-def _is_float_type(type: Type):
- return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
@@ -96,9 +78,9 @@ def value(self):
@property
def literal_value(self) -> Union[int, float]:
- if _is_integer_like_type(self.type):
+ if is_integer_type(self.type) or is_index_type(self.type):
return IntegerAttr(self.value).value
- elif _is_float_type(self.type):
+ elif is_float_type(self.type):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..cae70fc03b9d6 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -5,6 +5,13 @@
from typing import Callable, Dict, List, Sequence, Tuple, Union
from .....ir import *
+from .....util import (
+ is_complex_type,
+ is_float_type,
+ is_index_type,
+ is_integer_type,
+ get_floating_point_width,
+)
from .... import func
from .... import linalg
@@ -412,9 +419,9 @@ def _cast(
)
if operand.type == to_type:
return operand
- if _is_integer_type(to_type):
+ if is_integer_type(to_type):
return self._cast_to_integer(to_type, operand, is_unsigned_cast)
- elif _is_floating_point_type(to_type):
+ elif is_float_type(to_type):
return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
def _cast_to_integer(
@@ -422,11 +429,11 @@ def _cast_to_integer(
) -> Value:
to_width = IntegerType(to_type).width
operand_type = operand.type
- if _is_floating_point_type(operand_type):
+ if is_float_type(operand_type):
if is_unsigned_cast:
return arith.FPToUIOp(to_type, operand).result
return arith.FPToSIOp(to_type, operand).result
- if _is_index_type(operand_type):
+ if is_index_type(operand_type):
return arith.IndexCastOp(to_type, operand).result
# Assume integer.
from_width = IntegerType(operand_type).width
@@ -444,13 +451,13 @@ def _cast_to_floating_point(
self, to_type: Type, operand: Value, is_unsigned_cast: bool
) -> Value:
operand_type = operand.type
- if _is_integer_type(operand_type):
+ if is_integer_type(operand_type):
if is_unsigned_cast:
return arith.UIToFPOp(to_type, operand).result
return arith.SIToFPOp(to_type, operand).result
# Assume FloatType.
- to_width = _get_floating_point_width(to_type)
- from_width = _get_floating_point_width(operand_type)
+ to_width = get_floating_point_width(to_type)
+ from_width = get_floating_point_width(operand_type)
if to_width > from_width:
return arith.ExtFOp(to_type, operand).result
elif to_width < from_width:
@@ -466,89 +473,89 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
return self._cast(type_var_name, operand, True)
def _unary_exp(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if is_float_type(x.type):
return math.ExpOp(x).result
raise NotImplementedError("Unsupported 'exp' operand: {x}")
def _unary_log(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if is_float_type(x.type):
return math.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")
def _unary_abs(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if is_float_type(x.type):
return math.AbsFOp(x).result
raise NotImplementedError("Unsupported 'abs' operand: {x}")
def _unary_ceil(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if is_float_type(x.type):
return math.CeilOp(x).result
raise NotImplementedError("Unsupported 'ceil' operand: {x}")
def _unary_floor(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if is_float_type(x.type):
return math.FloorOp(x).result
raise NotImplementedError("Unsupported 'floor' operand: {x}")
def _unary_negf(self, x: Value) -> Value:
- if _is_floating_point_type(x.type):
+ if is_float_type(x.type):
return arith.NegFOp(x).result
- if _is_complex_type(x.type):
+ if is_complex_type(x.type):
return complex.NegOp(x).result
raise NotImplementedError("Unsupported 'negf' operand: {x}")
def _binary_add(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if is_float_type(lhs.type):
return arith.AddFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.AddIOp(lhs, rhs).result
- if _is_complex_type(lhs.type):
+ if is_complex_type(lhs.type):
return complex.AddOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if is_float_type(lhs.type):
return arith.SubFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.SubIOp(lhs, rhs).result
- if _is_complex_type(lhs.type):
+ if is_complex_type(lhs.type):
return complex.SubOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if is_float_type(lhs.type):
return arith.MulFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MulIOp(lhs, rhs).result
- if _is_complex_type(lhs.type):
+ if is_complex_type(lhs.type):
return complex.MulOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if is_float_type(lhs.type):
return arith.MaximumFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MaxSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if is_float_type(lhs.type):
return arith.MaximumFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MaxUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if is_float_type(lhs.type):
return arith.MinimumFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MinSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
- if _is_floating_point_type(lhs.type):
+ if is_float_type(lhs.type):
return arith.MinimumFOp(lhs, rhs).result
- if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ if is_integer_type(lhs.type) or is_index_type(lhs.type):
return arith.MinUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
@@ -609,40 +616,3 @@ def _add_type_mapping(
)
type_mapping[name] = element_or_self_type
block_arg_types.append(element_or_self_type)
-
-
-def _is_complex_type(t: Type) -> bool:
- return ComplexType.isinstance(t)
-
-
-def _is_floating_point_type(t: Type) -> bool:
- # TODO: Create a FloatType in the Python API and implement the switch
- # there.
- return (
- F64Type.isinstance(t)
- or F32Type.isinstance(t)
- or F16Type.isinstance(t)
- or BF16Type.isinstance(t)
- )
-
-
-def _is_integer_type(t: Type) -> bool:
- return IntegerType.isinstance(t)
-
-
-def _is_index_type(t: Type) -> bool:
- return IndexType.isinstance(t)
-
-
-def _get_floating_point_width(t: Type) -> int:
- # TODO: Create a FloatType in the Python API and implement the switch
- # there.
- if F64Type.isinstance(t):
- return 64
- if F32Type.isinstance(t):
- return 32
- if F16Type.isinstance(t):
- return 16
- if BF16Type.isinstance(t):
- return 16
- raise NotImplementedError(f"Unhandled floating point type switch {t}")
diff --git a/mlir/python/mlir/util.py b/mlir/python/mlir/util.py
new file mode 100644
index 0000000000000..cc85a99337f38
--- /dev/null
+++ b/mlir/python/mlir/util.py
@@ -0,0 +1,47 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Predicates and queries over builtin MLIR types.
+
+Shared by the dialect Python helpers (arith, memref, scf, linalg opdsl).
+"""
+
+from .ir import (
+    BF16Type,
+    ComplexType,
+    F16Type,
+    F32Type,
+    F64Type,
+    IndexType,
+    IntegerType,
+    Type,
+)
+
+
+def is_complex_type(t: Type) -> bool:
+    """Return True if `t` is a `ComplexType`."""
+    return ComplexType.isinstance(t)
+
+
+def is_float_type(t: Type) -> bool:
+    """Return True if `t` is f64, f32, f16 or bf16.
+
+    Only those four float types are recognized; any other type, including
+    other floating-point types, yields False.
+    """
+    # TODO: Create a FloatType in the Python API and implement the switch
+    # there.
+    return (
+        F64Type.isinstance(t)
+        or F32Type.isinstance(t)
+        or F16Type.isinstance(t)
+        or BF16Type.isinstance(t)
+    )
+
+
+def is_integer_type(t: Type) -> bool:
+    """Return True if `t` is an `IntegerType` (`index` is checked separately)."""
+    return IntegerType.isinstance(t)
+
+
+def is_index_type(t: Type) -> bool:
+    """Return True if `t` is the builtin `index` type."""
+    return IndexType.isinstance(t)
+
+
+def get_floating_point_width(t: Type) -> int:
+    # TODO: Create a FloatType in the Python API and implement the switch
+    # there.
+    """Return the bit width of an f64 (64), f32 (32), f16 or bf16 (16) type.
+
+    Raises:
+        NotImplementedError: if `t` is not one of those four types.
+    """
+    if F64Type.isinstance(t):
+        return 64
+    if F32Type.isinstance(t):
+        return 32
+    if F16Type.isinstance(t):
+        return 16
+    if BF16Type.isinstance(t):
+        return 16
+    raise NotImplementedError(f"Unhandled floating point type switch {t}")
>From f68d0669a18a3d5769fb20a83e8832db688a82f4 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 11:43:56 -0400
Subject: [PATCH 2/7] add index_cast
---
mlir/python/mlir/dialects/arith.py | 16 ++++++++++++++++
1 file changed, 16 insertions(+)
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 3b60ed8ddf94c..418ec0bdbb6db 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -90,3 +90,19 @@ def constant(
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
+
+
+def index_cast(
+    in_: Value,
+    to: Type = None,
+    *,
+    out: Type = None,
+    loc: Location = None,
+    ip: InsertionPoint = None,
+) -> Value:
+    """Create an `arith.index_cast` of `in_` and return its result.
+
+    The result type can be passed positionally as `to` or by keyword as `out`;
+    at most one of the two may be given. When neither is given the result
+    type defaults to `index`.
+
+    Raises:
+        ValueError: if both `to` and `out` are given.
+    """
+    # Only giving both is ambiguous; giving one or none is valid.
+    if to is not None and out is not None:
+        raise ValueError("at most one of `to` or `out` may be set")
+    res_type = out if out is not None else to
+    if res_type is None:
+        res_type = IndexType.get()
+    return _get_op_result_or_op_results(IndexCastOp(res_type, in_, loc=loc, ip=ip))
>From f67bbd06546541fb4a0a68a59641a444c3b9116c Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 11:44:11 -0400
Subject: [PATCH 3/7] add region_adder
---
mlir/python/mlir/extras/meta.py | 16 +++++++++++++++-
1 file changed, 15 insertions(+), 1 deletion(-)
diff --git a/mlir/python/mlir/extras/meta.py b/mlir/python/mlir/extras/meta.py
index 3f2defadf7941..fabe1d8e141ed 100644
--- a/mlir/python/mlir/extras/meta.py
+++ b/mlir/python/mlir/extras/meta.py
@@ -6,7 +6,7 @@
from functools import wraps
from ..dialects._ods_common import get_op_result_or_op_results
-from ..ir import Type, InsertionPoint
+from ..ir import Type, InsertionPoint, Value
def op_region_builder(op, op_region, terminator=None):
@@ -81,3 +81,17 @@ def maybe_no_args(*args, **kwargs):
return op_decorator(*args, **kwargs)
return maybe_no_args
+
+
+def region_adder(terminator=None):
+    """Decorator factory for filling an additional region of an existing op.
+
+    The decorated function is called as `fn(op, *args, **kwargs)` and must
+    return the region to populate. If `op` is an `ir.Value`, its
+    `owner.opview` (the op defining it) is passed instead. The wrapper returns
+    `op_region_builder(op, region, terminator)`; in the scf tests that result
+    is applied as a decorator to the function providing the region body
+    (e.g. `@scf.another_reduce(res1)`).
+    """
+
+    def wrapper(op_region_adder):
+        # Keep the adder's name and docstring on the public helper.
+        @wraps(op_region_adder)
+        def region_adder_decorator(op, *args, **kwargs):
+            if isinstance(op, Value):
+                op = op.owner.opview
+            region = op_region_adder(op, *args, **kwargs)
+
+            return op_region_builder(op, region, terminator)
+
+        return region_adder_decorator
+
+    return wrapper
>From 567f32e7c486349850c3e85c7da8c8c0af2c96a3 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 11:44:36 -0400
Subject: [PATCH 4/7] add forall parallel helper
---
mlir/python/mlir/dialects/scf.py | 105 ++++++++++++++++++++-
mlir/test/python/dialects/scf.py | 152 ++++++++++++++++++++++++++++++-
2 files changed, 254 insertions(+), 3 deletions(-)
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 678ceeebac204..b77cfa8668b65 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -5,10 +5,12 @@
from ._scf_ops_gen import *
from ._scf_ops_gen import _Dialect
-from .arith import constant
+from . import arith
+from ..extras.meta import region_op, region_adder
try:
from ..ir import *
+ from ..util import is_index_type
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
@@ -237,7 +239,7 @@ def for_(
params = [start, stop, step]
for i, p in enumerate(params):
if isinstance(p, int):
- p = constant(IndexType.get(), p)
+ p = arith.constant(IndexType.get(), p)
elif isinstance(p, float):
raise ValueError(f"{p=} must be int.")
params[i] = p
@@ -254,3 +256,102 @@ def for_(
yield iv, iter_args[0], for_op.results[0]
else:
yield iv
+
+
+def _parfor(op_ctor):
+    """Wrap `op_ctor`, the builder of a multi-dimensional loop op.
+
+    The returned function takes bounds in a `range`-like way. When only
+    `lower_bounds` is given, it holds the upper bounds and the lower bounds
+    become 0. `steps` defaults to 1 in every dimension. Entries whose type is
+    not `index` are converted with `arith.index_cast`. Everything is then
+    forwarded as `op_ctor(lower, upper, steps, loc=loc, ip=ip, **kwargs)`.
+
+    NOTE(review): the int check below tests `p` (the whole bound list) rather
+    than the element `pp`, so plain Python ints are not turned into index
+    constants here and reach `pp.type` unconverted.
+    NOTE(review): bound lists are updated in place (`p[j] = pp`), so they must
+    be mutable sequences and the caller's lists are modified.
+    """
+
+    def _base(
+        lower_bounds, upper_bounds=None, steps=None, *, loc=None, ip=None, **kwargs
+    ):
+        # Single-argument form: `lower_bounds` actually carries the upper bounds.
+        if upper_bounds is None:
+            upper_bounds = lower_bounds
+            lower_bounds = [0] * len(upper_bounds)
+        if steps is None:
+            steps = [1] * len(lower_bounds)
+
+        params = [lower_bounds, upper_bounds, steps]
+        for i, p in enumerate(params):
+            for j, pp in enumerate(p):
+                if isinstance(p, int):
+                    pp = arith.constant(IndexType.get(), p)
+                if not is_index_type(pp.type):
+                    pp = arith.index_cast(pp)
+                p[j] = pp
+            params[i] = p
+
+        # Remaining keywords (e.g. `shared_outs`, `inits`) go to the builder.
+        return op_ctor(*params, loc=loc, ip=ip, **kwargs)
+
+    return _base
+
+
+def _parfor_cm(op_ctor):
+    """Turn `op_ctor` into a generator meant to drive a Python `for` statement.
+
+    Each call builds the loop op through `_parfor(op_ctor)`. It then yields,
+    exactly once, the tuple of arguments of the first block of the op's first
+    region. The yield happens while an `InsertionPoint` at that block is
+    active, so IR created in the `for` body is placed in the loop body.
+    """
+
+    def _base(*args, **kwargs):
+        loop_op = _parfor(op_ctor)(*args, **kwargs)
+        entry_block = loop_op.regions[0].blocks[0]
+        entry_args = tuple(entry_block.arguments)
+        with InsertionPoint(entry_block):
+            yield entry_args
+
+    return _base
+
+
+forall = _parfor_cm(ForallOp)
+
+
+class ParallelOp(ParallelOp):
+    """`scf.parallel` builder that also creates the body block.
+
+    The body block gets one `index` argument per loop dimension. The result
+    types of the op are the types of the `inits` values.
+    """
+
+    def __init__(
+        self,
+        lower_bounds,
+        upper_bounds,
+        steps,
+        inits: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        # Explicit check instead of `assert`, which is stripped under -O.
+        if not (len(lower_bounds) == len(upper_bounds) == len(steps)):
+            raise ValueError(
+                "lower_bounds, upper_bounds and steps must have the same length, "
+                f"got {len(lower_bounds)}, {len(upper_bounds)} and {len(steps)}"
+            )
+        # Accept an op (its results are used) as well as a sequence of values,
+        # as the annotation advertises.
+        inits = [] if inits is None else list(_get_op_results_or_values(inits))
+        results = [v.type for v in inits]
+        iv_types = [IndexType.get()] * len(lower_bounds)
+        super().__init__(
+            results,
+            lower_bounds,
+            upper_bounds,
+            steps,
+            inits,
+            loc=loc,
+            ip=ip,
+        )
+        self.regions[0].blocks.append(*iv_types)
+
+    @property
+    def body(self):
+        """The first block of the body region (created in `__init__`)."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def induction_variables(self):
+        """The body block arguments, one `index` value per loop dimension."""
+        return self.body.arguments
+
+
+parallel = _parfor_cm(ParallelOp)
+
+
+class ReduceOp(ReduceOp):
+    """`scf.reduce` builder that also creates the reduction blocks.
+
+    For each of the first `num_reductions` regions, it appends one block with
+    two arguments, both typed like the corresponding operand.
+    """
+
+    def __init__(self, operands, num_reductions, *, loc=None, ip=None):
+        super().__init__(operands, num_reductions, loc=loc, ip=ip)
+        # NOTE(review): indexes operands[i] for every region, so this assumes
+        # num_reductions <= len(operands) -- TODO confirm callers guarantee it.
+        for i in range(num_reductions):
+            self.regions[i].blocks.append(operands[i].type, operands[i].type)
+
+
+def reduce_(*operands, num_reductions=1):
+    return ReduceOp(operands, num_reductions, loc=loc)
+
+
+reduce = region_op(reduce_, terminator=lambda xs: reduce_return(*xs))
+
+
+@region_adder(terminator=lambda xs: reduce_return(*xs))
+def another_reduce(reduce_op):
+    """Return the first still-empty reduction region of `reduce_op`.
+
+    Wrapped by `region_adder`, so callers pass the `scf.reduce` op (or a value
+    it defines) and decorate the body of the next reduction region with the
+    result.
+
+    Raises:
+        ValueError: if every region of `reduce_op` already contains operations.
+    """
+    # Fail loudly instead of handing `None` to the region builder.
+    if all(len(r.blocks[0].operations) for r in reduce_op.regions):
+        raise ValueError("all regions of the scf.reduce op are already populated")
+    for r in reduce_op.regions:
+        if len(r.blocks[0].operations) == 0:
+            return r
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 62d11d5e189c8..45ac71e4b0bc7 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -5,7 +5,7 @@
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import scf
-from mlir.passmanager import PassManager
+from mlir.dialects import tensor
def constructAndPrintInModule(f):
@@ -38,6 +38,156 @@ def forall_loop(tensor):
assert loop.verify()
+# CHECK-LABEL: TEST: test_forall_insert_slice_no_region_with_for
+@constructAndPrintInModule
+def test_forall_insert_slice_no_region_with_for():
+    """scf.forall with shared_outs whose body ends in an in_parallel insert."""
+    i32 = IntegerType.get_signless(32)
+    f32 = F32Type.get()
+    ten = tensor.empty([10, 10], i32)
+
+    for i, j, shared_outs in scf.forall([1, 1], [2, 2], [3, 3], shared_outs=[ten]):
+        one = arith.constant(f32, 1.0)
+
+        scf.parallel_insert_slice(
+            ten,
+            shared_outs,
+            offsets=[i, j],
+            static_sizes=[10, 10],
+            static_strides=[1, 1],
+        )
+
+ # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<10x10xi32>
+ # CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
+ # CHECK: %[[VAL_4:.*]] = arith.constant 2 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 3 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 3 : index
+ # CHECK: %[[VAL_7:.*]] = scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) shared_outs(%[[VAL_10:.*]] = %[[VAL_0]]) -> (tensor<10x10xi32>) {
+ # CHECK: %[[VAL_11:.*]] = arith.constant 1.000000e+00 : f32
+ # CHECK: scf.forall.in_parallel {
+ # CHECK: tensor.parallel_insert_slice %[[VAL_0]] into %[[VAL_10]]{{\[}}%[[VAL_8]], %[[VAL_9]]] [10, 10] [1, 1] : tensor<10x10xi32> into tensor<10x10xi32>
+ # CHECK: }
+ # CHECK: }
+
+
+# CHECK-LABEL: TEST: test_parange_inits_with_for
+@constructAndPrintInModule
+def test_parange_inits_with_for():
+    """scf.parallel with one tensor init and a single-region scf.reduce."""
+    i32 = IntegerType.get_signless(32)
+    f32 = F32Type.get()
+    tensor_type = RankedTensorType.get([10, 10], f32)
+    ten = tensor.empty([10, 10], i32)
+
+    for i, j in scf.parallel([1, 1], [2, 2], [3, 3], inits=[ten]):
+        one = arith.constant(f32, 1.0)
+        ten2 = tensor.empty([10, 10], i32)
+
+        @scf.reduce(ten2)
+        def res(lhs: tensor_type, rhs: tensor_type):
+            return lhs + rhs
+
+ # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<10x10xi32>
+ # CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
+ # CHECK: %[[VAL_4:.*]] = arith.constant 2 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 3 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 3 : index
+ # CHECK: %[[VAL_7:.*]] = scf.parallel (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) init (%[[VAL_0]]) -> tensor<10x10xi32> {
+ # CHECK: %[[VAL_10:.*]] = arith.constant 1.000000e+00 : f32
+ # CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<10x10xi32>
+ # CHECK: scf.reduce(%[[VAL_11]] : tensor<10x10xi32>) {
+ # CHECK: ^bb0(%[[VAL_12:.*]]: tensor<10x10xi32>, %[[VAL_13:.*]]: tensor<10x10xi32>):
+ # CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : tensor<10x10xi32>
+ # CHECK: scf.reduce.return %[[VAL_14]] : tensor<10x10xi32>
+ # CHECK: }
+ # CHECK: }
+
+
+# CHECK-LABEL: TEST: test_parange_inits_with_for_with_two_reduce
+@constructAndPrintInModule
+def test_parange_inits_with_for_with_two_reduce():
+    """scf.parallel with two inits; second reduction region via another_reduce."""
+    index_type = IndexType.get()
+    one = arith.constant(index_type, 1)
+
+    for i, j in scf.parallel([1, 1], [2, 2], [3, 3], inits=[one, one]):
+
+        @scf.reduce(i, j, num_reductions=2)
+        def res1(lhs: index_type, rhs: index_type):
+            return lhs + rhs
+
+        @scf.another_reduce(res1)
+        def res2(lhs: index_type, rhs: index_type):
+            return lhs + rhs
+
+ # CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
+ # CHECK: %[[VAL_4:.*]] = arith.constant 2 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 3 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 3 : index
+ # CHECK: %[[VAL_7:.*]]:2 = scf.parallel (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) init (%[[VAL_0]], %[[VAL_0]]) -> (index, index) {
+ # CHECK: scf.reduce(%[[VAL_8]], %[[VAL_9]] : index, index) {
+ # CHECK: ^bb0(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index):
+ # CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_10]], %[[VAL_11]] : index
+ # CHECK: scf.reduce.return %[[VAL_12]] : index
+ # CHECK: }, {
+ # CHECK: ^bb0(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: index):
+ # CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index
+ # CHECK: scf.reduce.return %[[VAL_15]] : index
+ # CHECK: }
+ # CHECK: }
+
+
+# CHECK-LABEL: TEST: test_parange_inits_with_for_with_three_reduce
+@constructAndPrintInModule
+def test_parange_inits_with_for_with_three_reduce():
+    """scf.parallel with three inits; regions 2 and 3 via another_reduce."""
+    index_type = IndexType.get()
+    one = arith.constant(index_type, 1)
+
+    for i, j, k in scf.parallel([1, 1, 1], [2, 2, 2], [3, 3, 3], inits=[one, one, one]):
+
+        @scf.reduce(i, j, k, num_reductions=3)
+        def res1(lhs: index_type, rhs: index_type):
+            return lhs + rhs
+
+        @scf.another_reduce(res1)
+        def res2(lhs: index_type, rhs: index_type):
+            return lhs + rhs
+
+        @scf.another_reduce(res2)
+        def res3(lhs: index_type, rhs: index_type):
+            return lhs + rhs
+
+ # CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_4:.*]] = arith.constant 2 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
+ # CHECK: %[[VAL_7:.*]] = arith.constant 3 : index
+ # CHECK: %[[VAL_8:.*]] = arith.constant 3 : index
+ # CHECK: %[[VAL_9:.*]] = arith.constant 3 : index
+ # CHECK: %[[VAL_10:.*]]:3 = scf.parallel (%[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]]) = (%[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) to (%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) step (%[[VAL_7]], %[[VAL_8]], %[[VAL_9]]) init (%[[VAL_0]], %[[VAL_0]], %[[VAL_0]]) -> (index, index, index) {
+ # CHECK: scf.reduce(%[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : index, index, index) {
+ # CHECK: ^bb0(%[[VAL_14:.*]]: index, %[[VAL_15:.*]]: index):
+ # CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : index
+ # CHECK: scf.reduce.return %[[VAL_16]] : index
+ # CHECK: }, {
+ # CHECK: ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index):
+ # CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_18]] : index
+ # CHECK: scf.reduce.return %[[VAL_19]] : index
+ # CHECK: }, {
+ # CHECK: ^bb0(%[[VAL_20:.*]]: index, %[[VAL_21:.*]]: index):
+ # CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : index
+ # CHECK: scf.reduce.return %[[VAL_22]] : index
+ # CHECK: }
+ # CHECK: }
+
+
# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
>From 3f645bb8672503e577b490bd2951bf20628bd9b7 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 12:24:38 -0400
Subject: [PATCH 5/7] fix util helpers
---
mlir/python/CMakeLists.txt | 1 +
mlir/python/mlir/dialects/memref.py | 5 +++--
mlir/python/mlir/util.py | 4 ++++
mlir/test/python/dialects/arith_dialect.py | 7 +++----
4 files changed, 11 insertions(+), 6 deletions(-)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 7a0c95ebb8200..f01798f48ff86 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -23,6 +23,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
passmanager.py
rewrite.py
dialects/_ods_common.py
+ util.py
# The main _mlir module has submodules: include stubs from each.
_mlir_libs/_mlir/__init__.pyi
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index bc9a3a52728ad..2130a1966d88e 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -7,8 +7,9 @@
from ._memref_ops_gen import *
from ._ods_common import _dispatch_mixed_values, MixedValues
-from .arith import ConstantOp, _is_integer_like_type
+from .arith import ConstantOp
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
+from ..util import is_integer_like_type
def _is_constant_int_like(i):
@@ -16,7 +17,7 @@ def _is_constant_int_like(i):
isinstance(i, Value)
and isinstance(i.owner, Operation)
and isinstance(i.owner.opview, ConstantOp)
- and _is_integer_like_type(i.type)
+ and is_integer_like_type(i.type)
)
diff --git a/mlir/python/mlir/util.py b/mlir/python/mlir/util.py
index cc85a99337f38..453b74777014f 100644
--- a/mlir/python/mlir/util.py
+++ b/mlir/python/mlir/util.py
@@ -33,6 +33,10 @@ def is_index_type(t: Type) -> bool:
return IndexType.isinstance(t)
+def is_integer_like_type(t: Type) -> bool:
+    """Return True for `IntegerType`s and for the builtin `index` type."""
+    if is_integer_type(t):
+        return True
+    return is_index_type(t)
+
+
def get_floating_point_width(t: Type) -> int:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index c9af5e7b46db8..0a197c4e673f9 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -4,6 +4,7 @@
from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
+from mlir.util import is_float_type, is_integer_like_type
from array import array
@@ -42,11 +43,9 @@ def testFastMathFlags():
def testArithValue():
def _binary_op(lhs, rhs, op: str) -> "ArithValue":
op = op.capitalize()
- if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+ if is_float_type(lhs.type) and is_float_type(rhs.type):
op += "F"
-        elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
-            lhs.type
-        ):
+        # Both operands (not lhs twice) must be integer-like for the "I" variant.
+        elif is_integer_like_type(lhs.type) and is_integer_like_type(rhs.type):
op += "I"
else:
raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
>From 7f3b59ef83930607ce2d5b3b6796716659312199 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 12:25:00 -0400
Subject: [PATCH 6/7] add parallel_insert_slice
---
mlir/python/mlir/dialects/scf.py | 31 +++++++++++++++++++++
mlir/python/mlir/dialects/tensor.py | 42 +++++++++++++++++++++++++++++
2 files changed, 73 insertions(+)
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index b77cfa8668b65..95a5e1c901e66 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -355,3 +355,34 @@ def another_reduce(reduce_op):
for r in reduce_op.regions:
if len(r.blocks[0].operations) == 0:
return r
+
+
+@region_op
+def in_parallel():
+    """Create an `scf.forall.in_parallel` op (`InParallelOp`).
+
+    Wrapped by `region_op`; `parallel_insert_slice` below uses it as a
+    decorator whose body is emitted into the op's region.
+    """
+    return InParallelOp()
+
+
+def parallel_insert_slice(
+    source,
+    dest,
+    static_offsets=None,
+    static_sizes=None,
+    static_strides=None,
+    offsets=None,
+    sizes=None,
+    strides=None,
+):
+    """Emit an `scf.forall.in_parallel` op holding a `tensor.parallel_insert_slice`.
+
+    Every argument is forwarded unchanged to `tensor.parallel_insert_slice`.
+    """
+    from . import tensor
+
+    slice_args = dict(
+        offsets=offsets,
+        sizes=sizes,
+        strides=strides,
+        static_offsets=static_offsets,
+        static_sizes=static_sizes,
+        static_strides=static_strides,
+    )
+
+    # Decorating with `in_parallel` builds the op and emits this body into its
+    # region; nothing calls `_emit_insert` afterwards.
+    @in_parallel
+    def _emit_insert():
+        tensor.parallel_insert_slice(source, dest, **slice_args)
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 146b5f85d07f5..b1baa22b15e23 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -65,3 +65,45 @@ def empty(
lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
terminator=lambda args: YieldOp(args[0]),
)
+
+
+def parallel_insert_slice(
+    source,
+    dest,
+    offsets=None,
+    sizes=None,
+    strides=None,
+    static_offsets=None,
+    static_sizes=None,
+    static_strides=None,
+    *,
+    loc=None,
+    ip=None,
+):
+    """Build a `tensor.parallel_insert_slice` of `source` into `dest`.
+
+    Offsets, sizes and strides may each be given as SSA values (`offsets`,
+    ...), as a static list (`static_offsets`, ...), or both. When only the SSA
+    values are given, the static list is filled with
+    `ShapedType.get_dynamic_size()`, one entry per value, so any rank works.
+    When only the static list is given, the SSA list is empty.
+
+    Raises:
+        ValueError: if neither form is given for offsets, sizes or strides.
+    """
+    dynamic = ShapedType.get_dynamic_size()
+
+    def _split(name, values, static_values):
+        # Complete the (SSA values, static values) pair for one operand group.
+        if static_values is None:
+            if values is None:
+                raise ValueError(
+                    f"either `{name}` or `static_{name}` must be provided"
+                )
+            static_values = [dynamic] * len(values)
+        if values is None:
+            values = []
+        return values, static_values
+
+    offsets, static_offsets = _split("offsets", offsets, static_offsets)
+    sizes, static_sizes = _split("sizes", sizes, static_sizes)
+    strides, static_strides = _split("strides", strides, static_strides)
+
+    return ParallelInsertSliceOp(
+        source,
+        dest,
+        offsets,
+        sizes,
+        strides,
+        static_offsets,
+        static_sizes,
+        static_strides,
+        loc=loc,
+        ip=ip,
+    )
>From 82d66d01fca460b3653f23df1842345c208c947d Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 23 Jul 2025 12:25:39 -0400
Subject: [PATCH 7/7] fix scf
---
mlir/python/mlir/dialects/scf.py | 9 +++++----
mlir/test/python/dialects/scf.py | 12 ++++++------
2 files changed, 11 insertions(+), 10 deletions(-)
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 95a5e1c901e66..3b58d5c1c48b6 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -271,8 +271,9 @@ def _base(
params = [lower_bounds, upper_bounds, steps]
for i, p in enumerate(params):
for j, pp in enumerate(p):
- if isinstance(p, int):
- pp = arith.constant(IndexType.get(), p)
+ if isinstance(pp, int):
+ pp = arith.constant(IndexType.get(), pp)
+ assert isinstance(pp, Value), f"expected ir.Value, got {type(pp)=}"
if not is_index_type(pp.type):
pp = arith.index_cast(pp)
p[j] = pp
@@ -343,8 +344,8 @@ def __init__(self, operands, num_reductions, *, loc=None, ip=None):
self.regions[i].blocks.append(operands[i].type, operands[i].type)
-def reduce_(*operands, num_reductions=1):
-    return ReduceOp(operands, num_reductions, loc=loc)
+def reduce_(*operands, num_reductions=None, loc=None, ip=None):
+    """Create an `scf.reduce` op over `operands`.
+
+    `num_reductions` is the number of reduction regions to build. When it is
+    not given, it defaults to one region per operand, so a zero-operand reduce
+    no longer indexes `operands[0]`.
+    """
+    if num_reductions is None:
+        num_reductions = len(operands)
+    return ReduceOp(operands, num_reductions, loc=loc, ip=ip)
reduce = region_op(reduce_, terminator=lambda xs: reduce_return(*xs))
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 45ac71e4b0bc7..e374a6114662a 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -85,7 +85,7 @@ def test_parange_inits_with_for():
@scf.reduce(ten2)
def res(lhs: tensor_type, rhs: tensor_type):
- return lhs + rhs
+ return arith.addi(lhs, rhs)
# CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<10x10xi32>
# CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
@@ -115,11 +115,11 @@ def test_parange_inits_with_for_with_two_reduce():
@scf.reduce(i, j, num_reductions=2)
def res1(lhs: index_type, rhs: index_type):
- return lhs + rhs
+ return arith.addi(lhs, rhs)
@scf.another_reduce(res1)
def res2(lhs: index_type, rhs: index_type):
- return lhs + rhs
+ return arith.addi(lhs, rhs)
# CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
# CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
@@ -151,15 +151,15 @@ def test_parange_inits_with_for_with_three_reduce():
@scf.reduce(i, j, k, num_reductions=3)
def res1(lhs: index_type, rhs: index_type):
- return lhs + rhs
+ return arith.addi(lhs, rhs)
@scf.another_reduce(res1)
def res2(lhs: index_type, rhs: index_type):
- return lhs + rhs
+ return arith.addi(lhs, rhs)
@scf.another_reduce(res2)
def res3(lhs: index_type, rhs: index_type):
- return lhs + rhs
+ return arith.addi(lhs, rhs)
# CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
# CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
More information about the Mlir-commits
mailing list