[Mlir-commits] [mlir] [mlir][python] fix up affine for (PR #74495)
Maksim Levental
llvmlistbot at llvm.org
Tue Dec 5 11:39:13 PST 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/74495
>From 14fa3249e69046dd48577f73ff56fa5ae768e667 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 5 Dec 2023 10:42:21 -0600
Subject: [PATCH] [mlir][python] fix up affine for
---
mlir/python/mlir/dialects/_ods_common.py | 3 +
mlir/python/mlir/dialects/affine.py | 101 +++++++++-----
mlir/test/python/dialects/affine.py | 169 ++++++++++++++++-------
3 files changed, 189 insertions(+), 84 deletions(-)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 60ce83c09f171..20ec08400d081 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -134,3 +134,6 @@ def get_op_result_or_op_results(
# see the typing.Type doc string.
_U = _TypeVar("_U", bound=_cext.ir.Value)
SubClassValueT = _Type[_U]
+
+ResultValueT = _Union[_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value]
+VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 26e827009bc04..c7373c23548ce 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,8 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._affine_ops_gen import *
-from ._affine_ops_gen import _Dialect, AffineForOp
-from .arith import constant
+from ._affine_ops_gen import _Dialect
try:
from ..ir import *
@@ -12,6 +11,8 @@
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
+ ResultValueT as _ResultValueT,
+ VariadicResultValueT as _VariadicResultValueT,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -21,17 +22,17 @@
@_ods_cext.register_operation(_Dialect, replace=True)
class AffineForOp(AffineForOp):
- """Specialization for the Affine for op class"""
+ """Specialization for the Affine for op class."""
def __init__(
self,
- lower_bound,
- upper_bound,
- step,
- iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+ lower_bound: Union[int, _ResultValueT, AffineMap],
+ upper_bound: Optional[Union[int, _ResultValueT, AffineMap]] = None,
+ step: Optional[Union[int, Attribute]] = None,
+ iter_args: Optional[_VariadicResultValueT] = None,
*,
- lower_bound_operands=[],
- upper_bound_operands=[],
+ lower_bound_operands: Optional[_VariadicResultValueT] = None,
+ upper_bound_operands: Optional[_VariadicResultValueT] = None,
loc=None,
ip=None,
):
@@ -45,22 +46,73 @@ def __init__(
- `lower_bound_operands` is the list of arguments to substitute the dimensions,
then symbols in the `lower_bound` affine map, in an increasing order
- `upper_bound_operands` is the list of arguments to substitute the dimensions,
- then symbols in the `upper_bound` affine map, in an increasing order
+ then symbols in the `upper_bound` affine map, in an increasing order.
"""
+ # Copy each operand list (expanding an op into its results) so that
+ # appending a value bound below never mutates a list owned by the caller.
+ if lower_bound_operands is None:
+ lower_bound_operands = []
+ lower_bound_operands = list(_get_op_results_or_values(lower_bound_operands))
+ if upper_bound_operands is None:
+ upper_bound_operands = []
+ upper_bound_operands = list(_get_op_results_or_values(upper_bound_operands))
+
+ # `step` is a static index attribute of `affine.for`, not an SSA operand.
+ if step is None:
+ step = 1
+ if isinstance(step, int):
+ step = IntegerAttr.get(IndexType.get(), step)
+
+ if upper_bound is None:
+ # Single-bound form, like Python's `range(stop)`: the given bound and
+ # its operands become the upper bound and the loop starts at 0.
+ if upper_bound_operands:
+ raise ValueError("upper_bound_operands given without an upper_bound.")
+ upper_bound, lower_bound = lower_bound, 0
+ upper_bound_operands, lower_bound_operands = lower_bound_operands, []
+
+ # An int bound becomes a constant map; a value bound becomes the sole
+ # operand of the single-symbol identity map `()[s0] -> (s0)`.
+ if isinstance(lower_bound, int):
+ lower_bound = AffineMap.get_constant(lower_bound)
+ elif isinstance(lower_bound, (Operation, OpView, Value)):
+ if lower_bound_operands:
+ raise ValueError(
+ "lower_bound_operands must be empty when lower_bound is a value."
+ )
+ lower_bound_operands.append(_get_op_result_or_value(lower_bound))
+ lower_bound = AffineMap.get(0, 1, [AffineSymbolExpr.get(0)])
+
+ if not isinstance(lower_bound, AffineMap):
+ raise ValueError(f"{lower_bound=} must be int | ResultValueT | AffineMap")
+
+ if isinstance(upper_bound, int):
+ upper_bound = AffineMap.get_constant(upper_bound)
+ elif isinstance(upper_bound, (Operation, OpView, Value)):
+ if upper_bound_operands:
+ raise ValueError(
+ "upper_bound_operands must be empty when upper_bound is a value."
+ )
+ upper_bound_operands.append(_get_op_result_or_value(upper_bound))
+ upper_bound = AffineMap.get(0, 1, [AffineSymbolExpr.get(0)])
+
+ if not isinstance(upper_bound, AffineMap):
+ raise ValueError(f"{upper_bound=} must be int | ResultValueT | AffineMap")
+
if iter_args is None:
iter_args = []
iter_args = _get_op_results_or_values(iter_args)
if len(lower_bound_operands) != lower_bound.n_inputs:
raise ValueError(
f"Wrong number of lower bound operands passed to AffineForOp. "
- + "Expected {lower_bound.n_symbols}, got {len(lower_bound_operands)}."
+ + f"Expected {lower_bound.n_inputs}, got {len(lower_bound_operands)}."
)
if len(upper_bound_operands) != upper_bound.n_inputs:
raise ValueError(
f"Wrong number of upper bound operands passed to AffineForOp. "
- + "Expected {upper_bound.n_symbols}, got {len(upper_bound_operands)}."
+ + f"Expected {upper_bound.n_inputs}, got {len(upper_bound_operands)}."
)
results = [arg.type for arg in iter_args]
@@ -71,7 +123,7 @@ def __init__(
inits=list(iter_args),
lowerBoundMap=AffineMapAttr.get(lower_bound),
upperBoundMap=AffineMapAttr.get(upper_bound),
- step=IntegerAttr.get(IndexType.get(), step),
+ step=step,
loc=loc,
ip=ip,
)
@@ -105,30 +157,11 @@ def for_(
loc=None,
ip=None,
):
- if step is None:
- step = 1
- if stop is None:
- stop = start
- start = 0
- params = [start, stop]
- for i, p in enumerate(params):
- if isinstance(p, int):
- p = constant(IntegerAttr.get(IndexType.get(), p))
- elif isinstance(p, float):
- raise ValueError(f"{p=} must be int.")
- params[i] = p
-
- start, stop = params
- s0 = AffineSymbolExpr.get(0)
- lbmap = AffineMap.get(0, 1, [s0])
- ubmap = AffineMap.get(0, 1, [s0])
for_op = AffineForOp(
- lbmap,
- ubmap,
+ start,
+ stop,
step,
iter_args=iter_args,
- lower_bound_operands=[start],
- upper_bound_operands=[stop],
loc=loc,
ip=ip,
)
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index df42f8fcf1a57..737044b293f8c 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -5,6 +5,7 @@
from mlir.dialects import arith
from mlir.dialects import memref
from mlir.dialects import affine
+import mlir.extras.types as T
def constructAndPrintInModule(f):
@@ -115,58 +116,126 @@ def affine_for_op_test(buffer):
@constructAndPrintInModule
def testForSugar():
- index_type = IndexType.get()
- memref_t = MemRefType.get([10], index_type)
+ memref_t = T.memref(10, T.index())
range = affine.for_
- # CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
- # CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
- # CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step 2 {
- # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
- # CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
- # CHECK: }
- # CHECK: return
- # CHECK: }
- @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
- def range_loop_1(lb, ub, step, memref_v):
- for i in range(lb, 10, 2):
+ # CHECK-LABEL: func.func @range_loop_1(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_1(lb, ub, memref_v):
+ for i in range(lb, ub, step=1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_2(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = %[[VAL_0]] to 10 {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_2(lb, ub, memref_v):
+ for i in range(lb, 10, step=1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_3(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 0 to %[[VAL_1]] {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_3(lb, ub, memref_v):
+ for i in range(0, ub, step=1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_4(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_4(lb, ub, memref_v):
+ for i in range(0, 10, step=1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_5(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 step 2 {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_5(lb, ub, memref_v):
+ for i in range(0, 10, step=2):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_6(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_6(lb, ub, memref_v):
+ for i in range(0, 10):
add = arith.addi(i, i)
- s0 = AffineSymbolExpr.get(0)
- map = AffineMap.get(0, 1, [s0])
- affine.store(add, memref_v, [i], map=map)
- affine.AffineYieldOp([])
-
- # CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
- # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
- # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
- # CHECK: affine.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] {
- # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
- # CHECK: affine.store %[[VAL_8]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_7]]{{\)\]}} : memref<10xindex>
- # CHECK: }
- # CHECK: return
- # CHECK: }
- @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
- def range_loop_2(lb, ub, step, memref_v):
- for i in range(0, 10, 1):
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_7(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_7(lb, ub, memref_v):
+ for i in range(10):
add = arith.addi(i, i)
- s0 = AffineSymbolExpr.get(0)
- map = AffineMap.get(0, 1, [s0])
- affine.store(add, memref_v, [i], map=map)
- affine.AffineYieldOp([])
-
- # CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
- # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
- # CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] {
- # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
- # CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
- # CHECK: }
- # CHECK: return
- # CHECK: }
- @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
- def range_loop_3(lb, ub, step, memref_v):
- for i in range(0, ub, 1):
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_8(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
+ # CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
+ # CHECK: memref.store %[[VAL_6]], %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<10xindex>
+ # CHECK: affine.yield %[[VAL_5]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_8(lb, ub, memref_v):
+ for i, it in range(10, iter_args=[memref_v]):
add = arith.addi(i, i)
- s0 = AffineSymbolExpr.get(0)
- map = AffineMap.get(0, 1, [s0])
- affine.store(add, memref_v, [i], map=map)
- affine.AffineYieldOp([])
+ memref.store(add, it, [i])
+ affine.yield_([it])
More information about the Mlir-commits
mailing list