[Mlir-commits] [mlir] [mlir][python] fix scf.for_ convenience builder (PR #72170)

Maksim Levental llvmlistbot at llvm.org
Mon Nov 13 15:04:10 PST 2023


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/72170

None

From ee7cd7f4869be0878f65a473d22a598ed8ad6f25 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 13 Nov 2023 17:03:34 -0600
Subject: [PATCH] [mlir][python] fix scf.for_ convenience builder

---
 mlir/python/mlir/dialects/scf.py |   4 +-
 mlir/test/python/dialects/scf.py | 101 ++++++++++++++++++++++++++-----
 2 files changed, 90 insertions(+), 15 deletions(-)

diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 71c80cab76dfb86..20bbed9bc93df67 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -120,11 +120,13 @@ def for_(
     params = [start, stop, step]
     for i, p in enumerate(params):
         if isinstance(p, int):
-            p = constant(p)
+            p = constant(IntegerAttr.get(IndexType.get(), p))
         elif isinstance(p, float):
             raise ValueError(f"{p=} must be int.")
         params[i] = p
 
+    start, stop, step = params
+
     for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
     iv = for_op.induction_variable
     iter_args = tuple(for_op.inner_iter_args)
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 414307d8191513b..8875c5a23d1847c 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -3,7 +3,9 @@
 from mlir.ir import *
 from mlir.dialects import arith
 from mlir.dialects import func
+from mlir.dialects import memref
 from mlir.dialects import scf
+from mlir.passmanager import PassManager
 
 
 def constructAndPrintInModule(f):
@@ -54,25 +56,96 @@ def induction_var(lb, ub, step):
 
 
 # CHECK-LABEL: TEST: testForSugar
-@constructAndPrintInModule
 def testForSugar():
-    index_type = IndexType.get()
-    range = scf.for_
+    print("TEST: testForSugar")
+    with Context(), Location.unknown():
+        index_type = IndexType.get()
+        memref_t = MemRefType.get([10], index_type)
+        range = scf.for_
+        module = Module.create()
+        with InsertionPoint(module.body):
 
-    @func.FuncOp.from_py_func(index_type, index_type, index_type)
-    def range_loop(lb, ub, step):
-        for i in range(lb, ub, step):
-            add = arith.addi(i, i)
-            scf.yield_([])
-        return
+            @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+            def range_loop(lb, ub, step, memref_v):
+                for i in range(lb, ub, step):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+
+                    scf.yield_([])
+
+                for i in range(lb, 10, 1):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                for i in range(0, ub, 1):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                for i in range(0, 10, step):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                for i in range(0, 10, 1):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                for i in range(0, 10):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                for i in range(10):
+                    add = arith.addi(i, i)
+                    memref.store(add, memref_v, [i])
+                    scf.yield_([])
+
+                return
+
+        pm = PassManager("builtin.module")
+        pm.add("canonicalize")
+        pm.run(module.operation)
+        print(module)
 
 
-# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
-# CHECK:   scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
-# CHECK:     %0 = arith.addi %[[IV]], %[[IV]] : index
+# CHECK:   func.func @range_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<10xindex>) {
+# CHECK:     %c0 = arith.constant 0 : index
+# CHECK:     %c1 = arith.constant 1 : index
+# CHECK:     %c10 = arith.constant 10 : index
+# CHECK:     scf.for %arg4 = %arg0 to %arg1 step %arg2 {
+# CHECK:       %0 = arith.addi %arg4, %arg4 : index
+# CHECK:       memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK:     }
+# CHECK:     scf.for %arg4 = %arg0 to %c10 step %c1 {
+# CHECK:       %0 = arith.addi %arg4, %arg4 : index
+# CHECK:       memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK:     }
+# CHECK:     scf.for %arg4 = %c0 to %arg1 step %c1 {
+# CHECK:       %0 = arith.addi %arg4, %arg4 : index
+# CHECK:       memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK:     }
+# CHECK:     scf.for %arg4 = %c0 to %c10 step %arg2 {
+# CHECK:       %0 = arith.addi %arg4, %arg4 : index
+# CHECK:       memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK:     }
+# CHECK:     scf.for %arg4 = %c0 to %c10 step %c1 {
+# CHECK:       %0 = arith.addi %arg4, %arg4 : index
+# CHECK:       memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK:     }
+# CHECK:     scf.for %arg4 = %c0 to %c10 step %c1 {
+# CHECK:       %0 = arith.addi %arg4, %arg4 : index
+# CHECK:       memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK:     }
+# CHECK:     scf.for %arg4 = %c0 to %c10 step %c1 {
+# CHECK:       %0 = arith.addi %arg4, %arg4 : index
+# CHECK:       memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK:     }
+# CHECK:     return
 # CHECK:   }
-# CHECK:   return
-# CHECK: }
+testForSugar()
 
 
 @constructAndPrintInModule



More information about the Mlir-commits mailing list