[Mlir-commits] [mlir] e9453f3 - [mlir][python] fix `scf.for_` convenience builder (#72170)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 13 18:25:45 PST 2023


Author: Maksim Levental
Date: 2023-11-13T20:25:41-06:00
New Revision: e9453f3c3c7e682e39952c9e18e6b1f8152b0ffa

URL: https://github.com/llvm/llvm-project/commit/e9453f3c3c7e682e39952c9e18e6b1f8152b0ffa
DIFF: https://github.com/llvm/llvm-project/commit/e9453f3c3c7e682e39952c9e18e6b1f8152b0ffa.diff

LOG: [mlir][python] fix `scf.for_` convenience builder (#72170)

Added: 
    

Modified: 
    mlir/python/mlir/dialects/scf.py
    mlir/test/python/dialects/scf.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 71c80cab76dfb86..20bbed9bc93df67 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -120,11 +120,13 @@ def for_(
     params = [start, stop, step]
     for i, p in enumerate(params):
         if isinstance(p, int):
-            p = constant(p)
+            p = constant(IntegerAttr.get(IndexType.get(), p))
         elif isinstance(p, float):
             raise ValueError(f"{p=} must be int.")
         params[i] = p
 
+    start, stop, step = params
+
     for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
     iv = for_op.induction_variable
     iter_args = tuple(for_op.inner_iter_args)

diff  --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 414307d8191513b..ee8d09aa301d98a 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -3,7 +3,9 @@
 from mlir.ir import *
 from mlir.dialects import arith
 from mlir.dialects import func
+from mlir.dialects import memref
 from mlir.dialects import scf
+from mlir.passmanager import PassManager
 
 
 def constructAndPrintInModule(f):
@@ -57,22 +59,122 @@ def induction_var(lb, ub, step):
 @constructAndPrintInModule
 def testForSugar():
     index_type = IndexType.get()
+    memref_t = MemRefType.get([10], index_type)
     range = scf.for_
 
-    @func.FuncOp.from_py_func(index_type, index_type, index_type)
-    def range_loop(lb, ub, step):
+    # CHECK:  func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    scf.for %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+    # CHECK:      %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
+    # CHECK:      memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_4]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_1(lb, ub, step, memref_v):
         for i in range(lb, ub, step):
             add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+
+            scf.yield_([])
+
+    # CHECK:  func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 10 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 1 : index
+    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step %[[VAL_5]] {
+    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_2(lb, ub, step, memref_v):
+        for i in range(lb, 10, 1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
             scf.yield_([])
-        return
 
+    # CHECK:  func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 1 : index
+    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] step %[[VAL_5]] {
+    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_3(lb, ub, step, memref_v):
+        for i in range(0, ub, 1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            scf.yield_([])
 
-# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
-# CHECK:   scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
-# CHECK:     %0 = arith.addi %[[IV]], %[[IV]] : index
-# CHECK:   }
-# CHECK:   return
-# CHECK: }
+    # CHECK:  func.func @range_loop_4(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
+    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_2]] {
+    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_4(lb, ub, step, memref_v):
+        for i in range(0, 10, step):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            scf.yield_([])
+
+    # CHECK:  func.func @range_loop_5(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
+    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
+    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
+    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
+    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_5(lb, ub, step, memref_v):
+        for i in range(0, 10, 1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            scf.yield_([])
+
+    # CHECK:  func.func @range_loop_6(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
+    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
+    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
+    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
+    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_6(lb, ub, step, memref_v):
+        for i in range(0, 10):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            scf.yield_([])
+
+    # CHECK:  func.func @range_loop_7(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
+    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
+    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
+    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
+    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_7(lb, ub, step, memref_v):
+        for i in range(10):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            scf.yield_([])
 
 
 @constructAndPrintInModule


        


More information about the Mlir-commits mailing list