[Mlir-commits] [mlir] [mlir][python] fix `scf.for_` convenience builder (PR #72170)

Maksim Levental llvmlistbot at llvm.org
Mon Nov 13 15:50:07 PST 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/72170

From 07cfcfb243b5ab32dc3f060a860d19b5997e5460 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 13 Nov 2023 17:03:34 -0600
Subject: [PATCH] [mlir][python] fix scf.for_ convenience builder

---
 mlir/python/mlir/dialects/scf.py |   4 +-
 mlir/test/python/dialects/scf.py | 120 ++++++++++++++++++++++++++++---
 2 files changed, 114 insertions(+), 10 deletions(-)

diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 71c80cab76dfb86..20bbed9bc93df67 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -120,11 +120,13 @@ def for_(
     params = [start, stop, step]
     for i, p in enumerate(params):
         if isinstance(p, int):
-            p = constant(p)
+            p = constant(IntegerAttr.get(IndexType.get(), p))
         elif isinstance(p, float):
             raise ValueError(f"{p=} must be int.")
         params[i] = p
 
+    start, stop, step = params
+
     for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
     iv = for_op.induction_variable
     iter_args = tuple(for_op.inner_iter_args)
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 414307d8191513b..ee8d09aa301d98a 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -3,7 +3,9 @@
 from mlir.ir import *
 from mlir.dialects import arith
 from mlir.dialects import func
+from mlir.dialects import memref
 from mlir.dialects import scf
+from mlir.passmanager import PassManager
 
 
 def constructAndPrintInModule(f):
@@ -57,22 +59,122 @@ def induction_var(lb, ub, step):
 @constructAndPrintInModule
 def testForSugar():
     index_type = IndexType.get()
+    memref_t = MemRefType.get([10], index_type)
     range = scf.for_
 
-    @func.FuncOp.from_py_func(index_type, index_type, index_type)
-    def range_loop(lb, ub, step):
+    # CHECK:  func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    scf.for %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+    # CHECK:      %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
+    # CHECK:      memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_4]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_1(lb, ub, step, memref_v):
         for i in range(lb, ub, step):
             add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+
+            scf.yield_([])
+
+    # CHECK:  func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 10 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 1 : index
+    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step %[[VAL_5]] {
+    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_2(lb, ub, step, memref_v):
+        for i in range(lb, 10, 1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
             scf.yield_([])
-        return
 
+    # CHECK:  func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 1 : index
+    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] step %[[VAL_5]] {
+    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_3(lb, ub, step, memref_v):
+        for i in range(0, ub, 1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            scf.yield_([])
 
-# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
-# CHECK:   scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
-# CHECK:     %0 = arith.addi %[[IV]], %[[IV]] : index
-# CHECK:   }
-# CHECK:   return
-# CHECK: }
+    # CHECK:  func.func @range_loop_4(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
+    # CHECK:    scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_2]] {
+    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
+    # CHECK:      memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_4(lb, ub, step, memref_v):
+        for i in range(0, 10, step):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            scf.yield_([])
+
+    # CHECK:  func.func @range_loop_5(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
+    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
+    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
+    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
+    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_5(lb, ub, step, memref_v):
+        for i in range(0, 10, 1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            scf.yield_([])
+
+    # CHECK:  func.func @range_loop_6(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
+    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
+    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
+    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
+    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_6(lb, ub, step, memref_v):
+        for i in range(0, 10):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            scf.yield_([])
+
+    # CHECK:  func.func @range_loop_7(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
+    # CHECK:    %[[VAL_6:.*]] = arith.constant 1 : index
+    # CHECK:    scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
+    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
+    # CHECK:      memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def range_loop_7(lb, ub, step, memref_v):
+        for i in range(10):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            scf.yield_([])
 
 
 @constructAndPrintInModule



More information about the Mlir-commits mailing list