[Mlir-commits] [mlir] [mlir][python] Yield results of `scf.for_` (PR #93610)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 28 13:36:57 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Guray Ozen (grypp)
<details>
<summary>Changes</summary>
Using `for_` is very handy with the Python bindings. Currently, however, it doesn't support results, so we had to fall back to the two-line `scf.for` form.
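This PR makes `for_` also yield the results of the `scf.for` op. Loops with iter_args therefore now unpack three values, `(iv, iter_args, results)`, instead of two.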
---
Full diff: https://github.com/llvm/llvm-project/pull/93610.diff
2 Files Affected:
- (modified) mlir/python/mlir/dialects/scf.py (+2-2)
- (modified) mlir/test/python/dialects/scf.py (+50)
``````````diff
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index dad7377987e56..d2b52fb8a235e 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -132,8 +132,8 @@ def for_(
iter_args = tuple(for_op.inner_iter_args)
with InsertionPoint(for_op.body):
if len(iter_args) > 1:
- yield iv, iter_args
+ yield iv, iter_args, for_op.results
elif len(iter_args) == 1:
- yield iv, iter_args[0]
+ yield iv, iter_args[0], for_op.results
else:
yield iv
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index ee8d09aa301d9..95a6de86b670d 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -176,6 +176,56 @@ def range_loop_7(lb, ub, step, memref_v):
memref.store(add, memref_v, [i])
scf.yield_([])
+ # CHECK: func.func @loop_yield_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+ # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
+ # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_10:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER:.*]] = %[[VAL_4]]) -> (index) {
+ # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER]], %[[IV]] : index
+ # CHECK: scf.yield %[[VAL_9]] : index
+ # CHECK: }
+ # CHECK: memref.store %[[VAL_10]], %[[VAL_3]]{{\[}}%[[VAL_5]]] : memref<10xindex>
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def loop_yield_1(lb, ub, step, memref_v):
+ sum = arith.ConstantOp.create_index(0)
+ c0 = arith.ConstantOp.create_index(0)
+ for i, loc_sum, sum in scf.for_(0, 100, 1, [sum]):
+ loc_sum = arith.addi(loc_sum, i)
+ scf.yield_([loc_sum])
+ memref.store(sum, memref_v, [c0])
+
+ # CHECK: func.func @loop_yield_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+ # CHECK: %[[c0:.*]] = arith.constant 0 : index
+ # CHECK: %[[c2:.*]] = arith.constant 2 : index
+ # CHECK: %[[REF1:.*]] = arith.constant 0 : index
+ # CHECK: %[[REF2:.*]] = arith.constant 1 : index
+ # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
+ # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
+ # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+ # CHECK: %[[RES:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER1:.*]] = %[[c0]], %[[ITER2:.*]] = %[[c2]]) -> (index, index) {
+ # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER1]], %[[IV]] : index
+ # CHECK: %[[VAL_10:.*]] = arith.addi %[[ITER2]], %[[IV]] : index
+ # CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : index, index
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+ def loop_yield_2(lb, ub, step, memref_v):
+ sum1 = arith.ConstantOp.create_index(0)
+ sum2 = arith.ConstantOp.create_index(2)
+ c0 = arith.ConstantOp.create_index(0)
+ c1 = arith.ConstantOp.create_index(1)
+ for i, [loc_sum1, loc_sum2], [sum1, sum2] in scf.for_(0, 100, 1, [sum1, sum2]):
+ loc_sum1 = arith.addi(loc_sum1, i)
+ loc_sum2 = arith.addi(loc_sum2, i)
+ scf.yield_([loc_sum1, loc_sum2])
+ memref.store(sum1, memref_v, [c0])
+ memref.store(sum2, memref_v, [c1])
+
@constructAndPrintInModule
def testOpsAsArguments():
``````````
</details>
https://github.com/llvm/llvm-project/pull/93610
More information about the Mlir-commits
mailing list