[Mlir-commits] [mlir] [mlir][python] fix up affine for (PR #74495)
Maksim Levental
llvmlistbot at llvm.org
Tue Dec 5 08:50:29 PST 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/74495
>From ab80dc302ab39735caf7949b953108c2f4f04f21 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 5 Dec 2023 10:42:21 -0600
Subject: [PATCH] [mlir][python] fix up affine for
---
mlir/python/mlir/dialects/_ods_common.py | 3 ++
mlir/python/mlir/dialects/affine.py | 68 +++++++++++++-----------
2 files changed, 40 insertions(+), 31 deletions(-)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 60ce83c09f171..7736f1c579320 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -134,3 +134,6 @@ def get_op_result_or_op_results(
# see the typing.Type doc string.
_U = _TypeVar("_U", bound=_cext.ir.Value)
SubClassValueT = _Type[_U]
+
+ResultValueT = _Union[_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value]
+VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 26e827009bc04..47456a50fcc6f 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,8 +3,8 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._affine_ops_gen import *
-from ._affine_ops_gen import _Dialect, AffineForOp
-from .arith import constant
+from ._affine_ops_gen import _Dialect
+from ._ods_common import ResultValueT, VariadicResultValueT
try:
from ..ir import *
@@ -21,17 +21,17 @@
@_ods_cext.register_operation(_Dialect, replace=True)
class AffineForOp(AffineForOp):
- """Specialization for the Affine for op class"""
+ """Specialization for the Affine for op class."""
def __init__(
self,
- lower_bound,
- upper_bound,
- step,
- iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+ lower_bound: Union[int, AffineMap],
+ upper_bound: Optional[Union[int, AffineMap]] = None,
+ step: Optional[int] = None,
+ iter_args: Optional[VariadicResultValueT] = None,
*,
- lower_bound_operands=[],
- upper_bound_operands=[],
+ lower_bound_operands: Optional[VariadicResultValueT] = None,
+ upper_bound_operands: Optional[VariadicResultValueT] = None,
loc=None,
ip=None,
):
@@ -45,12 +45,37 @@ def __init__(
- `lower_bound_operands` is the list of arguments to substitute the dimensions,
then symbols in the `lower_bound` affine map, in an increasing order
- `upper_bound_operands` is the list of arguments to substitute the dimensions,
- then symbols in the `upper_bound` affine map, in an increasing order
+ then symbols in the `upper_bound` affine map, in an increasing order.
"""
+ if step is None:
+ step = 1
+ if upper_bound is None:
+ upper_bound = lower_bound
+ lower_bound = 0
+
+ if isinstance(lower_bound, int):
+ lower_bound = AffineMap.get_constant(lower_bound)
+ if lower_bound_operands:
+ raise ValueError(
+ "Constant lower bound doesn't require lower bound operands."
+ )
+ if isinstance(upper_bound, int):
+ upper_bound = AffineMap.get_constant(upper_bound)
+ if upper_bound_operands:
+ raise ValueError(
+ "Constant upper bound doesn't require upper bound operands."
+ )
+
+ if lower_bound_operands is None:
+ lower_bound_operands = []
+ if upper_bound_operands is None:
+ upper_bound_operands = []
+
if iter_args is None:
iter_args = []
iter_args = _get_op_results_or_values(iter_args)
+
if len(lower_bound_operands) != lower_bound.n_inputs:
raise ValueError(
f"Wrong number of lower bound operands passed to AffineForOp. "
@@ -105,30 +130,11 @@ def for_(
loc=None,
ip=None,
):
- if step is None:
- step = 1
- if stop is None:
- stop = start
- start = 0
- params = [start, stop]
- for i, p in enumerate(params):
- if isinstance(p, int):
- p = constant(IntegerAttr.get(IndexType.get(), p))
- elif isinstance(p, float):
- raise ValueError(f"{p=} must be int.")
- params[i] = p
-
- start, stop = params
- s0 = AffineSymbolExpr.get(0)
- lbmap = AffineMap.get(0, 1, [s0])
- ubmap = AffineMap.get(0, 1, [s0])
for_op = AffineForOp(
- lbmap,
- ubmap,
+ start,
+ stop,
step,
iter_args=iter_args,
- lower_bound_operands=[start],
- upper_bound_operands=[stop],
loc=loc,
ip=ip,
)
More information about the Mlir-commits
mailing list