[Mlir-commits] [mlir] [mlir][python] Yield results of `scf.for_` (PR #93610)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 28 13:36:57 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

Using `for_` is very handy with the Python bindings. Currently it doesn't support results, so we had to fall back to the two-line `scf.for` form.

This PR yields the results of `scf.for` in `for_`.
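
As a rough illustration of the yield shapes after this change, here is a minimal pure-Python sketch. It does not use the MLIR bindings: `FakeForOp` is a hypothetical stand-in for the real `scf.ForOp`, and plain strings stand in for MLIR values. Only the branching on the number of `iter_args` mirrors the diff below.

```python
class FakeForOp:
    """Hypothetical stand-in for scf.ForOp; strings play the role of MLIR values."""

    def __init__(self, init_args):
        self.induction_variable = "%iv"
        self.inner_iter_args = [f"%iter{i}" for i in range(len(init_args))]
        self.results = [f"%res{i}" for i in range(len(init_args))]


def for_(start, stop, step, iter_args=None):
    """Sketch of the generator: it yields exactly once, so the Python loop
    body runs a single time (in the real bindings, to build the loop body)."""
    for_op = FakeForOp(iter_args or [])
    iv = for_op.induction_variable
    inner = tuple(for_op.inner_iter_args)
    if len(inner) > 1:
        # Several loop-carried values: tuple of block arguments plus results.
        yield iv, inner, for_op.results
    elif len(inner) == 1:
        # One loop-carried value: the block argument is unpacked,
        # while the results are still yielded as a list.
        yield iv, inner[0], for_op.results
    else:
        # No loop-carried values: only the induction variable.
        yield iv


# Usage: the third loop target stays bound after the loop ends.
for i, loc_sum, res in for_(0, 100, 1, ["%init"]):
    pass
```

Note that in the single-`iter_args` case the third element is the whole results list rather than one value, which matches what the diff yields (`for_op.results`).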

---
Full diff: https://github.com/llvm/llvm-project/pull/93610.diff


2 Files Affected:

- (modified) mlir/python/mlir/dialects/scf.py (+2-2) 
- (modified) mlir/test/python/dialects/scf.py (+50) 


``````````diff
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index dad7377987e56..d2b52fb8a235e 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -132,8 +132,8 @@ def for_(
     iter_args = tuple(for_op.inner_iter_args)
     with InsertionPoint(for_op.body):
         if len(iter_args) > 1:
-            yield iv, iter_args
+            yield iv, iter_args, for_op.results
         elif len(iter_args) == 1:
-            yield iv, iter_args[0]
+            yield iv, iter_args[0], for_op.results
         else:
             yield iv
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index ee8d09aa301d9..95a6de86b670d 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -176,6 +176,56 @@ def range_loop_7(lb, ub, step, memref_v):
             memref.store(add, memref_v, [i])
             scf.yield_([])
 
+    # CHECK:  func.func @loop_yield_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_6:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_7:.*]] = arith.constant 100 : index
+    # CHECK:    %[[VAL_8:.*]] = arith.constant 1 : index
+    # CHECK:    %[[VAL_10:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER:.*]] = %[[VAL_4]]) -> (index) {
+    # CHECK:      %[[VAL_9:.*]] = arith.addi %[[ITER]], %[[IV]] : index
+    # CHECK:      scf.yield %[[VAL_9]] : index
+    # CHECK:    }
+    # CHECK:    memref.store %[[VAL_10]], %[[VAL_3]]{{\[}}%[[VAL_5]]] : memref<10xindex>
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def loop_yield_1(lb, ub, step, memref_v):
+        sum = arith.ConstantOp.create_index(0)
+        c0 = arith.ConstantOp.create_index(0)
+        for i, loc_sum, sum in scf.for_(0, 100, 1, [sum]):
+            loc_sum = arith.addi(loc_sum, i)
+            scf.yield_([loc_sum])
+        memref.store(sum, memref_v, [c0])
+
+    # CHECK:  func.func @loop_yield_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[c0:.*]] = arith.constant 0 : index
+    # CHECK:    %[[c2:.*]] = arith.constant 2 : index
+    # CHECK:    %[[REF1:.*]] = arith.constant 0 : index
+    # CHECK:    %[[REF2:.*]] = arith.constant 1 : index
+    # CHECK:    %[[VAL_6:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_7:.*]] = arith.constant 100 : index
+    # CHECK:    %[[VAL_8:.*]] = arith.constant 1 : index
+    # CHECK:    %[[RES:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER1:.*]] = %[[c0]], %[[ITER2:.*]] = %[[c2]]) -> (index, index) {
+    # CHECK:      %[[VAL_9:.*]] = arith.addi %[[ITER1]], %[[IV]] : index
+    # CHECK:      %[[VAL_10:.*]] = arith.addi %[[ITER2]], %[[IV]] : index
+    # CHECK:      scf.yield %[[VAL_9]], %[[VAL_10]] : index, index
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def loop_yield_2(lb, ub, step, memref_v):
+        sum1 = arith.ConstantOp.create_index(0)
+        sum2 = arith.ConstantOp.create_index(2)
+        c0 = arith.ConstantOp.create_index(0)
+        c1 = arith.ConstantOp.create_index(1)
+        for i, [loc_sum1, loc_sum2], [sum1, sum2] in scf.for_(0, 100, 1, [sum1, sum2]):
+            loc_sum1 = arith.addi(loc_sum1, i)
+            loc_sum2 = arith.addi(loc_sum2, i)
+            scf.yield_([loc_sum1, loc_sum2])
+        memref.store(sum1, memref_v, [c0])
+        memref.store(sum2, memref_v, [c1])
+
 
 @constructAndPrintInModule
 def testOpsAsArguments():

``````````

</details>


https://github.com/llvm/llvm-project/pull/93610

