[Mlir-commits] [mlir] [mlir][python] fix up affine for (PR #74495)

Maksim Levental llvmlistbot at llvm.org
Wed Dec 6 10:41:27 PST 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/74495

>From c333f86e6825110b51d34d962fb3b268f9a34629 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 5 Dec 2023 10:42:21 -0600
Subject: [PATCH 1/3] [mlir][python] fix up affine for

---
 mlir/python/mlir/dialects/_ods_common.py |   3 +
 mlir/python/mlir/dialects/affine.py      |  85 ++++++-----
 mlir/test/python/dialects/affine.py      | 173 ++++++++++++++++-------
 3 files changed, 168 insertions(+), 93 deletions(-)

diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 60ce83c09f171..20ec08400d081 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -134,3 +134,6 @@ def get_op_result_or_op_results(
 # see the typing.Type doc string.
 _U = _TypeVar("_U", bound=_cext.ir.Value)
 SubClassValueT = _Type[_U]
+
+ResultValueT = _Union[_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value]
+VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 26e827009bc04..834a8cccc7c71 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,8 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._affine_ops_gen import *
-from ._affine_ops_gen import _Dialect, AffineForOp
-from .arith import constant
+from ._affine_ops_gen import _Dialect
 
 try:
     from ..ir import *
@@ -12,6 +11,8 @@
         get_op_result_or_value as _get_op_result_or_value,
         get_op_results_or_values as _get_op_results_or_values,
         _cext as _ods_cext,
+        ResultValueT as _ResultValueT,
+        VariadicResultValueT as _VariadicResultValueT,
     )
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
@@ -21,17 +22,17 @@
 
 @_ods_cext.register_operation(_Dialect, replace=True)
 class AffineForOp(AffineForOp):
-    """Specialization for the Affine for op class"""
+    """Specialization for the Affine for op class."""
 
     def __init__(
         self,
-        lower_bound,
-        upper_bound,
-        step,
-        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        lower_bound: Union[int, _ResultValueT, AffineMap],
+        upper_bound: Optional[Union[int, _ResultValueT, AffineMap]] = None,
+        step: Optional[Union[int, _ResultValueT]] = None,
+        iter_args: Optional[_ResultValueT] = None,
         *,
-        lower_bound_operands=[],
-        upper_bound_operands=[],
+        lower_bound_operands: Optional[_VariadicResultValueT] = None,
+        upper_bound_operands: Optional[_VariadicResultValueT] = None,
         loc=None,
         ip=None,
     ):
@@ -45,23 +46,40 @@ def __init__(
         - `lower_bound_operands` is the list of arguments to substitute the dimensions,
           then symbols in the `lower_bound` affine map, in an increasing order
         - `upper_bound_operands` is the list of arguments to substitute the dimensions,
-          then symbols in the `upper_bound` affine map, in an increasing order
+          then symbols in the `upper_bound` affine map, in an increasing order.
         """
 
+        if lower_bound_operands is None:
+            lower_bound_operands = []
+        if upper_bound_operands is None:
+            upper_bound_operands = []
+
+        if step is None:
+            step = 1
+        if upper_bound is None:
+            upper_bound, lower_bound = lower_bound, 0
+
+        if isinstance(lower_bound, int):
+            lower_bound = AffineMap.get_constant(lower_bound)
+        elif isinstance(lower_bound, _ResultValueT):
+            lower_bound_operands.append(lower_bound)
+            lower_bound = AffineMap.get_constant(1)
+
+        if not isinstance(lower_bound, AffineMap):
+            raise ValueError(f"{lower_bound=} must be int | ResultValueT | AffineMap")
+
+        if isinstance(upper_bound, int):
+            upper_bound = AffineMap.get_constant(upper_bound)
+        elif isinstance(upper_bound, _ResultValueT):
+            upper_bound_operands.append(upper_bound)
+            upper_bound = AffineMap.get_constant(1)
+
+        if not isinstance(upper_bound, AffineMap):
+            raise ValueError(f"{upper_bound=} must be int | ResultValueT | AffineMap")
+
         if iter_args is None:
             iter_args = []
         iter_args = _get_op_results_or_values(iter_args)
-        if len(lower_bound_operands) != lower_bound.n_inputs:
-            raise ValueError(
-                f"Wrong number of lower bound operands passed to AffineForOp. "
-                + "Expected {lower_bound.n_symbols}, got {len(lower_bound_operands)}."
-            )
-
-        if len(upper_bound_operands) != upper_bound.n_inputs:
-            raise ValueError(
-                f"Wrong number of upper bound operands passed to AffineForOp. "
-                + "Expected {upper_bound.n_symbols}, got {len(upper_bound_operands)}."
-            )
 
         results = [arg.type for arg in iter_args]
         super().__init__(
@@ -71,7 +89,7 @@ def __init__(
             inits=list(iter_args),
             lowerBoundMap=AffineMapAttr.get(lower_bound),
             upperBoundMap=AffineMapAttr.get(upper_bound),
-            step=IntegerAttr.get(IndexType.get(), step),
+            step=step,
             loc=loc,
             ip=ip,
         )
@@ -105,30 +123,11 @@ def for_(
     loc=None,
     ip=None,
 ):
-    if step is None:
-        step = 1
-    if stop is None:
-        stop = start
-        start = 0
-    params = [start, stop]
-    for i, p in enumerate(params):
-        if isinstance(p, int):
-            p = constant(IntegerAttr.get(IndexType.get(), p))
-        elif isinstance(p, float):
-            raise ValueError(f"{p=} must be int.")
-        params[i] = p
-
-    start, stop = params
-    s0 = AffineSymbolExpr.get(0)
-    lbmap = AffineMap.get(0, 1, [s0])
-    ubmap = AffineMap.get(0, 1, [s0])
     for_op = AffineForOp(
-        lbmap,
-        ubmap,
+        start,
+        stop,
         step,
         iter_args=iter_args,
-        lower_bound_operands=[start],
-        upper_bound_operands=[stop],
         loc=loc,
         ip=ip,
     )
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index df42f8fcf1a57..737044b293f8c 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -5,6 +5,7 @@
 from mlir.dialects import arith
 from mlir.dialects import memref
 from mlir.dialects import affine
+import mlir.extras.types as T
 
 
 def constructAndPrintInModule(f):
@@ -115,58 +116,130 @@ def affine_for_op_test(buffer):
 
 @constructAndPrintInModule
 def testForSugar():
-    index_type = IndexType.get()
-    memref_t = MemRefType.get([10], index_type)
+    memref_t = T.memref(10, T.index())
     range = affine.for_
 
-    # CHECK:  func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
-    # CHECK:    %[[VAL_4:.*]] = arith.constant 10 : index
-    # CHECK:    affine.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step 2 {
-    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
-    # CHECK:      affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
-    # CHECK:    }
-    # CHECK:    return
-    # CHECK:  }
-    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
-    def range_loop_1(lb, ub, step, memref_v):
-        for i in range(lb, 10, 2):
+    # CHECK-LABEL:   func.func @range_loop_1(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 1 to 1 iter_args() -> () {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:             affine.yield
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_1(lb, ub, memref_v):
+        for i in range(lb, ub, step=1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_2(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 1 to 10 iter_args() -> () {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:             affine.yield
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_2(lb, ub, memref_v):
+        for i in range(lb, 10, step=1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_3(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 1 iter_args() -> () {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:             affine.yield
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_3(lb, ub, memref_v):
+        for i in range(0, ub, step=1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_4(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_4(lb, ub, memref_v):
+        for i in range(0, 10, step=1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_5(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_5(lb, ub, memref_v):
+        for i in range(0, 10, step=1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_6(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_6(lb, ub, memref_v):
+        for i in range(0, 10):
             add = arith.addi(i, i)
-            s0 = AffineSymbolExpr.get(0)
-            map = AffineMap.get(0, 1, [s0])
-            affine.store(add, memref_v, [i], map=map)
-            affine.AffineYieldOp([])
-
-    # CHECK:  func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
-    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
-    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
-    # CHECK:    affine.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] {
-    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
-    # CHECK:      affine.store %[[VAL_8]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_7]]{{\)\]}} : memref<10xindex>
-    # CHECK:    }
-    # CHECK:    return
-    # CHECK:  }
-    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
-    def range_loop_2(lb, ub, step, memref_v):
-        for i in range(0, 10, 1):
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_7(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_7(lb, ub, memref_v):
+        for i in range(10):
             add = arith.addi(i, i)
-            s0 = AffineSymbolExpr.get(0)
-            map = AffineMap.get(0, 1, [s0])
-            affine.store(add, memref_v, [i], map=map)
-            affine.AffineYieldOp([])
-
-    # CHECK:  func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
-    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
-    # CHECK:    affine.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] {
-    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
-    # CHECK:      affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
-    # CHECK:    }
-    # CHECK:    return
-    # CHECK:  }
-    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
-    def range_loop_3(lb, ub, step, memref_v):
-        for i in range(0, ub, 1):
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_8(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
+    # CHECK:             %[[VAL_6:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
+    # CHECK:             memref.store %[[VAL_6]], %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<10xindex>
+    # CHECK:             affine.yield %[[VAL_5]] : memref<10xindex>
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_8(lb, ub, memref_v):
+        for i, it in range(10, iter_args=[memref_v]):
             add = arith.addi(i, i)
-            s0 = AffineSymbolExpr.get(0)
-            map = AffineMap.get(0, 1, [s0])
-            affine.store(add, memref_v, [i], map=map)
-            affine.AffineYieldOp([])
+            memref.store(add, it, [i])
+            affine.yield_([it])

>From 829df046a304739fde172a30046e68a2a4feabfc Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Tue, 5 Dec 2023 14:33:40 -0600
Subject: [PATCH 2/3] Update affine.py

---
 mlir/python/mlir/dialects/affine.py | 18 ++++++++++++++----
 mlir/test/python/dialects/affine.py | 11 +++++------
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 834a8cccc7c71..243951a7ef547 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -61,18 +61,28 @@ def __init__(
 
         if isinstance(lower_bound, int):
             lower_bound = AffineMap.get_constant(lower_bound)
-        elif isinstance(lower_bound, _ResultValueT):
+        elif isinstance(lower_bound, (Operation, OpView, Value)):
+            if len(lower_bound_operands):
+                raise ValueError(
+                    f"Either a concrete lower bound or an AffineMap in combination "
+                    f"with lower bound operands, but not both, is supported."
+                )
             lower_bound_operands.append(lower_bound)
-            lower_bound = AffineMap.get_constant(1)
+            lower_bound = AffineMap.get_identity(1)
 
         if not isinstance(lower_bound, AffineMap):
             raise ValueError(f"{lower_bound=} must be int | ResultValueT | AffineMap")
 
         if isinstance(upper_bound, int):
             upper_bound = AffineMap.get_constant(upper_bound)
-        elif isinstance(upper_bound, _ResultValueT):
+        elif isinstance(upper_bound, (Operation, OpView, Value)):
+            if len(upper_bound_operands):
+                raise ValueError(
+                    f"Either a concrete upper bound or an AffineMap in combination "
+                    f"with upper bound operands, but not both, is supported."
+                )
             upper_bound_operands.append(upper_bound)
-            upper_bound = AffineMap.get_constant(1)
+            upper_bound = AffineMap.get_identity(1)
 
         if not isinstance(upper_bound, AffineMap):
             raise ValueError(f"{upper_bound=} must be int | ResultValueT | AffineMap")
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 737044b293f8c..4789bd8560989 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -119,12 +119,13 @@ def testForSugar():
     memref_t = T.memref(10, T.index())
     range = affine.for_
 
+    # CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> (d0)>
+
     # CHECK-LABEL:   func.func @range_loop_1(
     # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
-    # CHECK:           affine.for %[[VAL_3:.*]] = 1 to 1 iter_args() -> () {
+    # CHECK:           affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to #[[$ATTR_2]](%[[VAL_1]]) {
     # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
     # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
-    # CHECK:             affine.yield
     # CHECK:           }
     # CHECK:           return
     # CHECK:         }
@@ -138,10 +139,9 @@ def range_loop_1(lb, ub, memref_v):
 
     # CHECK-LABEL:   func.func @range_loop_2(
     # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
-    # CHECK:           affine.for %[[VAL_3:.*]] = 1 to 10 iter_args() -> () {
+    # CHECK:           affine.for %[[VAL_3:.*]] = #[[$ATTR_2]](%[[VAL_0]]) to 10 {
     # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
     # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
-    # CHECK:             affine.yield
     # CHECK:           }
     # CHECK:           return
     # CHECK:         }
@@ -154,10 +154,9 @@ def range_loop_2(lb, ub, memref_v):
 
     # CHECK-LABEL:   func.func @range_loop_3(
     # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
-    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 1 iter_args() -> () {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to #[[$ATTR_2]](%[[VAL_1]]) {
     # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
     # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
-    # CHECK:             affine.yield
     # CHECK:           }
     # CHECK:           return
     # CHECK:         }

>From 168ef88920639e66d79e9dfcba9100ae0c7ff2b4 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 6 Dec 2023 12:41:17 -0600
Subject: [PATCH 3/3] incorporate comments

---
 mlir/python/mlir/dialects/_ods_common.py |   3 +-
 mlir/python/mlir/dialects/affine.py      |  62 +++++++------
 mlir/test/python/dialects/affine.py      | 110 +++++++++++++----------
 3 files changed, 98 insertions(+), 77 deletions(-)

diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 20ec08400d081..63558a7915ff3 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -135,5 +135,6 @@ def get_op_result_or_op_results(
 _U = _TypeVar("_U", bound=_cext.ir.Value)
 SubClassValueT = _Type[_U]
 
-ResultValueT = _Union[_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value]
+ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
+ResultValueT = _Union[*ResultValueTypeTuple]
 VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 243951a7ef547..913cea61105ce 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -11,6 +11,7 @@
         get_op_result_or_value as _get_op_result_or_value,
         get_op_results_or_values as _get_op_results_or_values,
         _cext as _ods_cext,
+        ResultValueTypeTuple as _ResultValueTypeTuple,
         ResultValueT as _ResultValueT,
         VariadicResultValueT as _VariadicResultValueT,
     )
@@ -27,8 +28,8 @@ class AffineForOp(AffineForOp):
     def __init__(
         self,
         lower_bound: Union[int, _ResultValueT, AffineMap],
-        upper_bound: Optional[Union[int, _ResultValueT, AffineMap]] = None,
-        step: Optional[Union[int, _ResultValueT]] = None,
+        upper_bound: Optional[Union[int, _ResultValueT, AffineMap]],
+        step: Optional[Union[int, Attribute]] = None,
         iter_args: Optional[_ResultValueT] = None,
         *,
         lower_bound_operands: Optional[_VariadicResultValueT] = None,
@@ -44,7 +45,7 @@ def __init__(
         - `iter_args` is a list of additional loop-carried arguments or an operation
           producing them as results.
         - `lower_bound_operands` is the list of arguments to substitute the dimensions,
-          then symbols in the `lower_bound` affine map, in an increasing order
+          then symbols in the `lower_bound` affine map, in an increasing order.
         - `upper_bound_operands` is the list of arguments to substitute the dimensions,
           then symbols in the `upper_bound` affine map, in an increasing order.
         """
@@ -56,36 +57,41 @@ def __init__(
 
         if step is None:
             step = 1
-        if upper_bound is None:
-            upper_bound, lower_bound = lower_bound, 0
 
-        if isinstance(lower_bound, int):
-            lower_bound = AffineMap.get_constant(lower_bound)
-        elif isinstance(lower_bound, (Operation, OpView, Value)):
-            if len(lower_bound_operands):
+        bounds_operands = [lower_bound_operands, upper_bound_operands]
+        bounds = [lower_bound, upper_bound]
+        bounds_names = ["lower", "upper"]
+        for i, name in enumerate(bounds_names):
+            if isinstance(bounds[i], int):
+                bounds[i] = AffineMap.get_constant(bounds[i])
+            elif isinstance(bounds[i], _ResultValueTypeTuple):
+                if len(bounds_operands[i]):
+                    raise ValueError(
+                        f"Either a concrete {name} bound or an AffineMap in combination "
+                        f"with {name} bound operands, but not both, is supported."
+                    )
+                if (
+                    isinstance(bounds[i], (OpView, Operation))
+                    and len(bounds[i].results) > 1
+                ):
+                    raise ValueError(
+                        f"Only a single concrete value is supported for {name} bound."
+                    )
+
+                bounds_operands[i].append(_get_op_result_or_value(bounds[i]))
+                bounds[i] = AffineMap.get_identity(1)
+
+            if not isinstance(bounds[i], AffineMap):
                 raise ValueError(
-                    f"Either a concrete lower bound or an AffineMap in combination "
-                    f"with lower bound operands, but not both, is supported."
+                    f"{name} bound must be int | ResultValueT | AffineMap."
                 )
-            lower_bound_operands.append(lower_bound)
-            lower_bound = AffineMap.get_identity(1)
-
-        if not isinstance(lower_bound, AffineMap):
-            raise ValueError(f"{lower_bound=} must be int | ResultValueT | AffineMap")
-
-        if isinstance(upper_bound, int):
-            upper_bound = AffineMap.get_constant(upper_bound)
-        elif isinstance(upper_bound, (Operation, OpView, Value)):
-            if len(upper_bound_operands):
+            if len(bounds_operands[i]) != bounds[i].n_inputs:
                 raise ValueError(
-                    f"Either a concrete upper bound or an AffineMap in combination "
-                    f"with upper bound operands, but not both, is supported."
+                    f"Wrong number of {name} bound operands passed to AffineForOp; "
+                    + f"Expected {bounds[i].n_inputs}, got {len(bounds_operands[i])}."
                 )
-            upper_bound_operands.append(upper_bound)
-            upper_bound = AffineMap.get_identity(1)
 
-        if not isinstance(upper_bound, AffineMap):
-            raise ValueError(f"{upper_bound=} must be int | ResultValueT | AffineMap")
+        lower_bound, upper_bound = bounds
 
         if iter_args is None:
             iter_args = []
@@ -126,7 +132,7 @@ def inner_iter_args(self):
 
 def for_(
     start,
-    stop=None,
+    stop,
     step=None,
     iter_args: Optional[Sequence[Value]] = None,
     *,
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 4789bd8560989..9e2027937a09f 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -108,10 +108,69 @@ def affine_for_op_test(buffer):
             # CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
             tmp = memref.LoadOp(buffer, [sum.induction_variable])
             sum_next = arith.AddFOp(sum.inner_iter_args[0], tmp)
-
             affine.AffineYieldOp([sum_next])
 
-        return
+
+# CHECK-LABEL: TEST: testAffineForOpErrors
+@constructAndPrintInModule
+def testAffineForOpErrors():
+    c1 = arith.ConstantOp(T.index(), 1)
+    c2 = arith.ConstantOp(T.index(), 2)
+    c3 = arith.ConstantOp(T.index(), 3)
+    d0 = AffineDimExpr.get(0)
+
+    try:
+        affine.AffineForOp(
+            c1,
+            c2,
+            1,
+            lower_bound_operands=[c3],
+            upper_bound_operands=[],
+        )
+    except ValueError as e:
+        assert (
+            e.args[0]
+            == "Either a concrete lower bound or an AffineMap in combination with lower bound operands, but not both, is supported."
+        )
+
+    try:
+        affine.AffineForOp(
+            AffineMap.get_constant(1),
+            c2,
+            1,
+            lower_bound_operands=[c3, c3],
+            upper_bound_operands=[],
+        )
+    except ValueError as e:
+        assert (
+            e.args[0]
+            == "Wrong number of lower bound operands passed to AffineForOp; Expected 0, got 2."
+        )
+
+    try:
+        two_indices = affine.AffineDelinearizeIndexOp(
+            [T.index(), T.index()], c1, [c1, c1]
+        )
+        affine.AffineForOp(
+            two_indices,
+            c2,
+            1,
+            lower_bound_operands=[],
+            upper_bound_operands=[],
+        )
+    except ValueError as e:
+        assert e.args[0] == "Only a single concrete value is supported for lower bound."
+
+    try:
+        affine.AffineForOp(
+            1.0,
+            c2,
+            1,
+            lower_bound_operands=[],
+            upper_bound_operands=[],
+        )
+    except ValueError as e:
+        assert e.args[0] == "lower bound must be int | ResultValueT | AffineMap."
 
 
 @constructAndPrintInModule
@@ -182,51 +241,6 @@ def range_loop_4(lb, ub, memref_v):
             memref.store(add, memref_v, [i])
             affine.yield_([])
 
-    # CHECK-LABEL:   func.func @range_loop_5(
-    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
-    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
-    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
-    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
-    # CHECK:           }
-    # CHECK:           return
-    # CHECK:         }
-    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
-    def range_loop_5(lb, ub, memref_v):
-        for i in range(0, 10, step=1):
-            add = arith.addi(i, i)
-            memref.store(add, memref_v, [i])
-            affine.yield_([])
-
-    # CHECK-LABEL:   func.func @range_loop_6(
-    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
-    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
-    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
-    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
-    # CHECK:           }
-    # CHECK:           return
-    # CHECK:         }
-    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
-    def range_loop_6(lb, ub, memref_v):
-        for i in range(0, 10):
-            add = arith.addi(i, i)
-            memref.store(add, memref_v, [i])
-            affine.yield_([])
-
-    # CHECK-LABEL:   func.func @range_loop_7(
-    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
-    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
-    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
-    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
-    # CHECK:           }
-    # CHECK:           return
-    # CHECK:         }
-    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
-    def range_loop_7(lb, ub, memref_v):
-        for i in range(10):
-            add = arith.addi(i, i)
-            memref.store(add, memref_v, [i])
-            affine.yield_([])
-
     # CHECK-LABEL:   func.func @range_loop_8(
     # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
     # CHECK:           %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
@@ -238,7 +252,7 @@ def range_loop_7(lb, ub, memref_v):
     # CHECK:         }
     @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
     def range_loop_8(lb, ub, memref_v):
-        for i, it in range(10, iter_args=[memref_v]):
+        for i, it in range(0, 10, iter_args=[memref_v]):
             add = arith.addi(i, i)
             memref.store(add, it, [i])
             affine.yield_([it])



More information about the Mlir-commits mailing list