[Mlir-commits] [mlir] e9453f3 - [mlir][python] fix `scf.for_` convenience builder (#72170)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 13 18:25:45 PST 2023
Author: Maksim Levental
Date: 2023-11-13T20:25:41-06:00
New Revision: e9453f3c3c7e682e39952c9e18e6b1f8152b0ffa
URL: https://github.com/llvm/llvm-project/commit/e9453f3c3c7e682e39952c9e18e6b1f8152b0ffa
DIFF: https://github.com/llvm/llvm-project/commit/e9453f3c3c7e682e39952c9e18e6b1f8152b0ffa.diff
LOG: [mlir][python] fix `scf.for_` convenience builder (#72170)
Added:
Modified:
mlir/python/mlir/dialects/scf.py
mlir/test/python/dialects/scf.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 71c80cab76dfb86..20bbed9bc93df67 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -120,11 +120,13 @@ def for_(
params = [start, stop, step]
for i, p in enumerate(params):
if isinstance(p, int):
- p = constant(p)
+ p = constant(IntegerAttr.get(IndexType.get(), p))
elif isinstance(p, float):
raise ValueError(f"{p=} must be int.")
params[i] = p
+ start, stop, step = params
+
for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
iv = for_op.induction_variable
iter_args = tuple(for_op.inner_iter_args)
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 414307d8191513b..ee8d09aa301d98a 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -3,7 +3,9 @@
from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects import func
+from mlir.dialects import memref
from mlir.dialects import scf
+from mlir.passmanager import PassManager
def constructAndPrintInModule(f):
@@ -57,22 +59,122 @@ def induction_var(lb, ub, step):
@constructAndPrintInModule
def testForSugar():
index_type = IndexType.get()
+ memref_t = MemRefType.get([10], index_type)
range = scf.for_
- @func.FuncOp.from_py_func(index_type, index_type, index_type)
- def range_loop(lb, ub, step):
+ # CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+ # CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+ # CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
+ # CHECK: memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_4]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def range_loop_1(lb, ub, step, memref_v):
for i in range(lb, ub, step):
add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+
+ scf.yield_([])
+
+ # CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+ # CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+ # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step %[[VAL_5]] {
+ # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+ # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def range_loop_2(lb, ub, step, memref_v):
+ for i in range(lb, 10, 1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
scf.yield_([])
- return
+ # CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+ # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+ # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] step %[[VAL_5]] {
+ # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+ # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def range_loop_3(lb, ub, step, memref_v):
+ for i in range(0, ub, 1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
-# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
-# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
-# CHECK: %0 = arith.addi %[[IV]], %[[IV]] : index
-# CHECK: }
-# CHECK: return
-# CHECK: }
+ # CHECK: func.func @range_loop_4(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+ # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
+ # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_2]] {
+ # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+ # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def range_loop_4(lb, ub, step, memref_v):
+ for i in range(0, 10, step):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ # CHECK: func.func @range_loop_5(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+ # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+ # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
+ # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
+ # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def range_loop_5(lb, ub, step, memref_v):
+ for i in range(0, 10, 1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ # CHECK: func.func @range_loop_6(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+ # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+ # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
+ # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
+ # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def range_loop_6(lb, ub, step, memref_v):
+ for i in range(0, 10):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ # CHECK: func.func @range_loop_7(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+ # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+ # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
+ # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
+ # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def range_loop_7(lb, ub, step, memref_v):
+ for i in range(10):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
@constructAndPrintInModule
More information about the Mlir-commits mailing list