[Mlir-commits] [mlir] [mlir][python] fix `scf.for_` convenience builder (PR #72170)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 13 15:33:24 PST 2023


================
@@ -54,25 +56,96 @@ def induction_var(lb, ub, step):
 
 
 # CHECK-LABEL: TEST: testForSugar
-@constructAndPrintInModule
 def testForSugar():
-    index_type = IndexType.get()
-    range = scf.for_
+    print("TEST: testForSugar")
+    with Context(), Location.unknown():
+        index_type = IndexType.get()
+        memref_t = MemRefType.get([10], index_type)
+        range = scf.for_
+        module = Module.create()
+        with InsertionPoint(module.body):
 
-    @func.FuncOp.from_py_func(index_type, index_type, index_type)
-    def range_loop(lb, ub, step):
-        for i in range(lb, ub, step):
-            add = arith.addi(i, i)
-            scf.yield_([])
-        return
+            @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+            def range_loop(lb, ub, step, memref_v):
+                for i in range(lb, ub, step):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+
+                    scf.yield_([])
+
+                for i in range(lb, 10, 1):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                for i in range(0, ub, 1):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                for i in range(0, 10, step):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                for i in range(0, 10, 1):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                for i in range(0, 10):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                for i in range(10):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                return
+
+        pm = PassManager("builtin.module")
+        pm.add("canonicalize")
+        pm.run(module.operation)
+        print(module)
 
 
-# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
-# CHECK:   scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
-# CHECK:     %0 = arith.addi %[[IV]], %[[IV]] : index
+# CHECK:   func.func @range_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<10xindex>) {
+# CHECK:     %c0 = arith.constant 0 : index
----------------
rkayaith wrote:

Maybe split these into separate functions in the module? That seems better than relying on `canonicalize` and hardcoding variable names in the CHECK lines.

https://github.com/llvm/llvm-project/pull/72170


More information about the Mlir-commits mailing list