[Mlir-commits] [mlir] 7f58ffd - [mlir][python] Yield results of `scf.for_` (#93610)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 28 23:43:17 PDT 2024
Author: Guray Ozen
Date: 2024-05-29T08:43:13+02:00
New Revision: 7f58ffd09b29d3ff4f9fa025bd4d05dd8fd9fc38
URL: https://github.com/llvm/llvm-project/commit/7f58ffd09b29d3ff4f9fa025bd4d05dd8fd9fc38
DIFF: https://github.com/llvm/llvm-project/commit/7f58ffd09b29d3ff4f9fa025bd4d05dd8fd9fc38.diff
LOG: [mlir][python] Yield results of `scf.for_` (#93610)
Using `for_` is very handy with the Python bindings. Currently it doesn't
support results, so we had to fall back to the two-line `scf.for` form.
This PR yields the results of `scf.for` from `for_`.
---------
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
Added:
Modified:
mlir/python/mlir/dialects/scf.py
mlir/test/python/dialects/scf.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index dad7377987e56..7025f6e0f1a16 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -132,8 +132,8 @@ def for_(
iter_args = tuple(for_op.inner_iter_args)
with InsertionPoint(for_op.body):
if len(iter_args) > 1:
- yield iv, iter_args
+ yield iv, iter_args, for_op.results
elif len(iter_args) == 1:
- yield iv, iter_args[0]
+ yield iv, iter_args[0], for_op.results[0]
else:
yield iv
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index ee8d09aa301d9..95a6de86b670d 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -176,6 +176,56 @@ def range_loop_7(lb, ub, step, memref_v):
memref.store(add, memref_v, [i])
scf.yield_([])
+ # CHECK: func.func @loop_yield_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+ # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
+ # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_10:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER:.*]] = %[[VAL_4]]) -> (index) {
+ # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER]], %[[IV]] : index
+ # CHECK: scf.yield %[[VAL_9]] : index
+ # CHECK: }
+ # CHECK: memref.store %[[VAL_10]], %[[VAL_3]]{{\[}}%[[VAL_5]]] : memref<10xindex>
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def loop_yield_1(lb, ub, step, memref_v):  # lb/ub/step are unused; the loop below uses literal bounds 0, 100, 1
+ sum = arith.ConstantOp.create_index(0)  # initial iter_arg value; NOTE(review): shadows the builtin `sum`
+ c0 = arith.ConstantOp.create_index(0)  # index used for the memref.store below
+ for i, loc_sum, sum in scf.for_(0, 100, 1, [sum]):  # one iter_arg -> yields (iv, iter_arg, for_op.results[0]); rebinds `sum`
+ loc_sum = arith.addi(loc_sum, i)  # loc_sum is this iteration's loop-carried value
+ scf.yield_([loc_sum])  # body terminator (scf.yield in the CHECK lines)
+ memref.store(sum, memref_v, [c0])  # `sum` is now the scf.for result, not the initial constant
+
+ # CHECK: func.func @loop_yield_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+ # CHECK: %[[c0:.*]] = arith.constant 0 : index
+ # CHECK: %[[c2:.*]] = arith.constant 2 : index
+ # CHECK: %[[REF1:.*]] = arith.constant 0 : index
+ # CHECK: %[[REF2:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
+ # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+ # CHECK: %[[RES:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER1:.*]] = %[[c0]], %[[ITER2:.*]] = %[[c2]]) -> (index, index) {
+ # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER1]], %[[IV]] : index
+ # CHECK: %[[VAL_10:.*]] = arith.addi %[[ITER2]], %[[IV]] : index
+ # CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : index, index
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def loop_yield_2(lb, ub, step, memref_v):  # lb/ub/step are unused; the loop below uses literal bounds 0, 100, 1
+ sum1 = arith.ConstantOp.create_index(0)  # initial value of the first iter_arg
+ sum2 = arith.ConstantOp.create_index(2)  # initial value of the second iter_arg
+ c0 = arith.ConstantOp.create_index(0)  # store index for sum1
+ c1 = arith.ConstantOp.create_index(1)  # store index for sum2
+ for i, [loc_sum1, loc_sum2], [sum1, sum2] in scf.for_(0, 100, 1, [sum1, sum2]):  # >1 iter_args -> yields (iv, iter_args tuple, for_op.results)
+ loc_sum1 = arith.addi(loc_sum1, i)  # update first loop-carried value
+ loc_sum2 = arith.addi(loc_sum2, i)  # update second loop-carried value
+ scf.yield_([loc_sum1, loc_sum2])  # body terminator yielding both values
+ memref.store(sum1, memref_v, [c0])  # sum1/sum2 are now the scf.for results; NOTE(review): the CHECK block for @loop_yield_2 has no memref.store lines
+ memref.store(sum2, memref_v, [c1])  # NOTE(review): so FileCheck does not verify that these stores use %RES
+
@constructAndPrintInModule
def testOpsAsArguments():
More information about the Mlir-commits
mailing list