[Mlir-commits] [mlir] [mlir][python] fix scf.for_ convenience builder (PR #72170)
Maksim Levental
llvmlistbot at llvm.org
Mon Nov 13 15:04:10 PST 2023
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/72170
(No pull request description was provided.)
>From ee7cd7f4869be0878f65a473d22a598ed8ad6f25 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 13 Nov 2023 17:03:34 -0600
Subject: [PATCH] [mlir][python] fix scf.for_ convenience builder
---
mlir/python/mlir/dialects/scf.py | 4 +-
mlir/test/python/dialects/scf.py | 101 ++++++++++++++++++++++++++-----
2 files changed, 90 insertions(+), 15 deletions(-)
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 71c80cab76dfb86..20bbed9bc93df67 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -120,11 +120,13 @@ def for_(
params = [start, stop, step]
for i, p in enumerate(params):
if isinstance(p, int):
- p = constant(p)
+ p = constant(IntegerAttr.get(IndexType.get(), p))
elif isinstance(p, float):
raise ValueError(f"{p=} must be int.")
params[i] = p
+ start, stop, step = params
+
for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
iv = for_op.induction_variable
iter_args = tuple(for_op.inner_iter_args)
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 414307d8191513b..8875c5a23d1847c 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -3,7 +3,9 @@
from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects import func
+from mlir.dialects import memref
from mlir.dialects import scf
+from mlir.passmanager import PassManager
def constructAndPrintInModule(f):
@@ -54,25 +56,96 @@ def induction_var(lb, ub, step):
# CHECK-LABEL: TEST: testForSugar
- at constructAndPrintInModule
def testForSugar():
- index_type = IndexType.get()
- range = scf.for_
+ print("TEST: testForSugar")
+ with Context(), Location.unknown():
+ index_type = IndexType.get()
+ memref_t = MemRefType.get([10], index_type)
+ range = scf.for_
+ module = Module.create()
+ with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(index_type, index_type, index_type)
- def range_loop(lb, ub, step):
- for i in range(lb, ub, step):
- add = arith.addi(i, i)
- scf.yield_([])
- return
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def range_loop(lb, ub, step, memref_v):
+ for i in range(lb, ub, step):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+
+ scf.yield_([])
+
+ for i in range(lb, 10, 1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ for i in range(0, ub, 1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ for i in range(0, 10, step):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ for i in range(0, 10, 1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ for i in range(0, 10):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ for i in range(10):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ scf.yield_([])
+
+ return
+
+ pm = PassManager("builtin.module")
+ pm.add("canonicalize")
+ pm.run(module.operation)
+ print(module)
-# CHECK: func.func @range_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) {
-# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
-# CHECK: %0 = arith.addi %[[IV]], %[[IV]] : index
+# CHECK: func.func @range_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<10xindex>) {
+# CHECK: %c0 = arith.constant 0 : index
+# CHECK: %c1 = arith.constant 1 : index
+# CHECK: %c10 = arith.constant 10 : index
+# CHECK: scf.for %arg4 = %arg0 to %arg1 step %arg2 {
+# CHECK: %0 = arith.addi %arg4, %arg4 : index
+# CHECK: memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK: }
+# CHECK: scf.for %arg4 = %arg0 to %c10 step %c1 {
+# CHECK: %0 = arith.addi %arg4, %arg4 : index
+# CHECK: memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK: }
+# CHECK: scf.for %arg4 = %c0 to %arg1 step %c1 {
+# CHECK: %0 = arith.addi %arg4, %arg4 : index
+# CHECK: memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK: }
+# CHECK: scf.for %arg4 = %c0 to %c10 step %arg2 {
+# CHECK: %0 = arith.addi %arg4, %arg4 : index
+# CHECK: memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK: }
+# CHECK: scf.for %arg4 = %c0 to %c10 step %c1 {
+# CHECK: %0 = arith.addi %arg4, %arg4 : index
+# CHECK: memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK: }
+# CHECK: scf.for %arg4 = %c0 to %c10 step %c1 {
+# CHECK: %0 = arith.addi %arg4, %arg4 : index
+# CHECK: memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK: }
+# CHECK: scf.for %arg4 = %c0 to %c10 step %c1 {
+# CHECK: %0 = arith.addi %arg4, %arg4 : index
+# CHECK: memref.store %0, %arg3[%arg4] : memref<10xindex>
+# CHECK: }
+# CHECK: return
# CHECK: }
-# CHECK: return
-# CHECK: }
+testForSugar()
@constructAndPrintInModule
More information about the Mlir-commits mailing list