[Mlir-commits] [mlir] [mlir][python] fix `scf.for_` convenience builder (PR #72170)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 13 15:33:24 PST 2023
================
@@ -54,25 +56,96 @@ def induction_var(lb, ub, step):
# CHECK-LABEL: TEST: testForSugar
-@constructAndPrintInModule
def testForSugar():
- index_type = IndexType.get()
- range = scf.for_
+ print("TEST: testForSugar")
+ with Context(), Location.unknown():
+ index_type = IndexType.get()
+ memref_t = MemRefType.get([10], index_type)
+ range = scf.for_
+ module = Module.create()
+ with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(index_type, index_type, index_type)
- def range_loop(lb, ub, step):
- for i in range(lb, ub, step):
- add = arith.addi(i, i)
- scf.yield_([])
- return
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def range_loop(lb, ub, step, memref_v):
+ for i in range(lb, ub, step):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+
+ scf.yield_([])
+
+ for i in range(lb, 10, 1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ for i in range(0, ub, 1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ for i in range(0, 10, step):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ for i in range(0, 10, 1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ for i in range(0, 10):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ for i in range(10):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ return
+
+ pm = PassManager("builtin.module")
+ pm.add("canonicalize")
+ pm.run(module.operation)
+ print(module)
-# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
-# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
-# CHECK: %0 = arith.addi %[[IV]], %[[IV]] : index
+# CHECK: func.func @range_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<10xindex>) {
+# CHECK: %c0 = arith.constant 0 : index
----------------
rkayaith wrote:
Maybe split these loops into separate functions in the module? That seems better than relying on `canonicalize` and hard-coding SSA value names (e.g. `%arg0`, `%c0`) in the CHECK lines.
https://github.com/llvm/llvm-project/pull/72170
More information about the Mlir-commits
mailing list