[Mlir-commits] [mlir] [mlir][python] Yield results of `scf.for_` (PR #93610)

Guray Ozen llvmlistbot at llvm.org
Tue May 28 13:36:24 PDT 2024


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/93610

Using `for_` is very handy with the Python bindings. Currently, however, it doesn't support loop results, so we had to fall back to the two-line `scf.for` form.

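This PR makes `for_` also yield the results of the `scf.for` op when iter_args are given, e.g. `for i, loc_sum, sum in scf.for_(0, 100, 1, [sum]):`.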

>From d1d2010650bb8f97e3995e2133381d31e3820d62 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 28 May 2024 22:35:37 +0200
Subject: [PATCH] [mlir][python] Yield results of scf.for_

Using `for_` is very handy with the Python bindings. Currently, however, it doesn't support loop results, so we had to fall back to the two-line `scf.for` form.

This PR makes `for_` also yield the results of the `scf.for` op when iter_args are given.
---
 mlir/python/mlir/dialects/scf.py |  4 +--
 mlir/test/python/dialects/scf.py | 50 ++++++++++++++++++++++++++++++++
 2 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index dad7377987e56..d2b52fb8a235e 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -132,8 +132,8 @@ def for_(
     iter_args = tuple(for_op.inner_iter_args)
     with InsertionPoint(for_op.body):
         if len(iter_args) > 1:
-            yield iv, iter_args
+            yield iv, iter_args, for_op.results
         elif len(iter_args) == 1:
-            yield iv, iter_args[0]
+            yield iv, iter_args[0], for_op.results
         else:
             yield iv
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index ee8d09aa301d9..95a6de86b670d 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -176,6 +176,56 @@ def range_loop_7(lb, ub, step, memref_v):
             memref.store(add, memref_v, [i])
             scf.yield_([])
 
+    # CHECK:  func.func @loop_yield_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_5:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_6:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_7:.*]] = arith.constant 100 : index
+    # CHECK:    %[[VAL_8:.*]] = arith.constant 1 : index
+    # CHECK:    %[[VAL_10:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER:.*]] = %[[VAL_4]]) -> (index) {
+    # CHECK:      %[[VAL_9:.*]] = arith.addi %[[ITER]], %[[IV]] : index
+    # CHECK:      scf.yield %[[VAL_9]] : index
+    # CHECK:    }
+    # CHECK:    memref.store %[[VAL_10]], %[[VAL_3]]{{\[}}%[[VAL_5]]] : memref<10xindex>
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def loop_yield_1(lb, ub, step, memref_v):
+        sum = arith.ConstantOp.create_index(0)
+        c0 = arith.ConstantOp.create_index(0)
+        for i, loc_sum, sum in scf.for_(0, 100, 1, [sum]):
+            loc_sum = arith.addi(loc_sum, i)
+            scf.yield_([loc_sum])
+        memref.store(sum, memref_v, [c0])
+
+    # CHECK:  func.func @loop_yield_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
+    # CHECK:    %[[c0:.*]] = arith.constant 0 : index
+    # CHECK:    %[[c2:.*]] = arith.constant 2 : index
+    # CHECK:    %[[REF1:.*]] = arith.constant 0 : index
+    # CHECK:    %[[REF2:.*]] = arith.constant 1 : index
+    # CHECK:    %[[VAL_6:.*]] = arith.constant 0 : index
+    # CHECK:    %[[VAL_7:.*]] = arith.constant 100 : index
+    # CHECK:    %[[VAL_8:.*]] = arith.constant 1 : index
+    # CHECK:    %[[RES:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER1:.*]] = %[[c0]], %[[ITER2:.*]] = %[[c2]]) -> (index, index) {
+    # CHECK:      %[[VAL_9:.*]] = arith.addi %[[ITER1]], %[[IV]] : index
+    # CHECK:      %[[VAL_10:.*]] = arith.addi %[[ITER2]], %[[IV]] : index
+    # CHECK:      scf.yield %[[VAL_9]], %[[VAL_10]] : index, index
+    # CHECK:    }
+    # CHECK:    return
+    # CHECK:  }
+    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
+    def loop_yield_2(lb, ub, step, memref_v):
+        sum1 = arith.ConstantOp.create_index(0)
+        sum2 = arith.ConstantOp.create_index(2)
+        c0 = arith.ConstantOp.create_index(0)
+        c1 = arith.ConstantOp.create_index(1)
+        for i, [loc_sum1, loc_sum2], [sum1, sum2] in scf.for_(0, 100, 1, [sum1, sum2]):
+            loc_sum1 = arith.addi(loc_sum1, i)
+            loc_sum2 = arith.addi(loc_sum2, i)
+            scf.yield_([loc_sum1, loc_sum2])
+        memref.store(sum1, memref_v, [c0])
+        memref.store(sum2, memref_v, [c1])
+
 
 @constructAndPrintInModule
 def testOpsAsArguments():



More information about the Mlir-commits mailing list