[Mlir-commits] [mlir] [mlir][python] fix up affine for (PR #74495)

Maksim Levental llvmlistbot at llvm.org
Tue Dec 5 08:43:02 PST 2023


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/74495

Fix up https://github.com/llvm/llvm-project/pull/74408.

From 598b4bbe0f547d54178b05b155b2f91529b1ce65 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 5 Dec 2023 10:42:21 -0600
Subject: [PATCH] [mlir][python] fix up affine for

---
 mlir/python/mlir/dialects/_ods_common.py |  3 ++
 mlir/python/mlir/dialects/affine.py      | 60 ++++++++++++------------
 2 files changed, 32 insertions(+), 31 deletions(-)

diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 60ce83c09f171..7736f1c579320 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -134,3 +134,6 @@ def get_op_result_or_op_results(
 # see the typing.Type doc string.
 _U = _TypeVar("_U", bound=_cext.ir.Value)
 SubClassValueT = _Type[_U]
+
+ResultValueT = _Union[_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value]
+VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 26e827009bc04..7ed51bd4ed631 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,8 +3,8 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._affine_ops_gen import *
-from ._affine_ops_gen import _Dialect, AffineForOp
-from .arith import constant
+from ._affine_ops_gen import _Dialect
+from ._ods_common import ResultValueT, VariadicResultValueT
 
 try:
     from ..ir import *
@@ -21,17 +21,17 @@

 @_ods_cext.register_operation(_Dialect, replace=True)
 class AffineForOp(AffineForOp):
-    """Specialization for the Affine for op class"""
+    """Specialization for the Affine for op class."""

     def __init__(
         self,
-        lower_bound,
-        upper_bound,
-        step,
-        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        lower_bound: Union[int, ResultValueT, AffineMap],
+        upper_bound: Optional[Union[int, ResultValueT, AffineMap]],
+        step: Optional[Union[int, Attribute]] = None,
+        iter_args: Optional[VariadicResultValueT] = None,
         *,
-        lower_bound_operands=[],
-        upper_bound_operands=[],
+        lower_bound_operands: Optional[Sequence[ResultValueT]] = None,
+        upper_bound_operands: Optional[Sequence[ResultValueT]] = None,
         loc=None,
         ip=None,
     ):
@@ -45,12 +45,29 @@ def __init__(
         - `lower_bound_operands` is the list of arguments to substitute the dimensions,
           then symbols in the `lower_bound` affine map, in an increasing order
         - `upper_bound_operands` is the list of arguments to substitute the dimensions,
-          then symbols in the `upper_bound` affine map, in an increasing order
+          then symbols in the `upper_bound` affine map, in an increasing order.
         """

+        if step is None:
+            step = 1
+        if upper_bound is None:
+            lower_bound, upper_bound = 0, lower_bound
+        lower_bound_operands = list(lower_bound_operands or [])  # copy; may append
+        upper_bound_operands = list(upper_bound_operands or [])  # copy; may append
+        if isinstance(lower_bound, int):
+            lower_bound = AffineMap.get_constant(lower_bound)
+        elif isinstance(lower_bound, (Operation, OpView, Value)):  # bind to s0
+            lower_bound_operands.append(lower_bound)
+            lower_bound = AffineMap.get(0, 1, [AffineSymbolExpr.get(0)])
+        if isinstance(upper_bound, int):
+            upper_bound = AffineMap.get_constant(upper_bound)
+        elif isinstance(upper_bound, (Operation, OpView, Value)):  # bind to s0
+            upper_bound_operands.append(upper_bound)
+            upper_bound = AffineMap.get(0, 1, [AffineSymbolExpr.get(0)])
         if iter_args is None:
             iter_args = []
         iter_args = _get_op_results_or_values(iter_args)
+
         if len(lower_bound_operands) != lower_bound.n_inputs:
             raise ValueError(
                 f"Wrong number of lower bound operands passed to AffineForOp. "
@@ -105,30 +122,11 @@ def for_(
     loc=None,
     ip=None,
 ):
-    if step is None:
-        step = 1
-    if stop is None:
-        stop = start
-        start = 0
-    params = [start, stop]
-    for i, p in enumerate(params):
-        if isinstance(p, int):
-            p = constant(IntegerAttr.get(IndexType.get(), p))
-        elif isinstance(p, float):
-            raise ValueError(f"{p=} must be int.")
-        params[i] = p
-
-    start, stop = params
-    s0 = AffineSymbolExpr.get(0)
-    lbmap = AffineMap.get(0, 1, [s0])
-    ubmap = AffineMap.get(0, 1, [s0])
     for_op = AffineForOp(
-        lbmap,
-        ubmap,
+        start,  # AffineForOp.__init__ turns int bounds into constant maps
+        stop,  # None: AffineForOp.__init__ then uses 0 and `start` as bounds
         step,
         iter_args=iter_args,
-        lower_bound_operands=[start],
-        upper_bound_operands=[stop],
         loc=loc,
         ip=ip,
     )



More information about the Mlir-commits mailing list