[Mlir-commits] [mlir] fd7eee6 - scf::ForOp: Fold away iterator arguments with no use and for which the corresponding input is yielded

Nicolas Vasilache llvmlistbot at llvm.org
Tue Mar 16 00:01:51 PDT 2021


Author: Lorenzo Chelini
Date: 2021-03-16T07:01:25Z
New Revision: fd7eee64c570e5e14e511045c64d4d8cf98dde25

URL: https://github.com/llvm/llvm-project/commit/fd7eee64c570e5e14e511045c64d4d8cf98dde25
DIFF: https://github.com/llvm/llvm-project/commit/fd7eee64c570e5e14e511045c64d4d8cf98dde25.diff

LOG: scf::ForOp: Fold away iterator arguments with no use and for which the corresponding input is yielded

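Enhance 'ForOpIterArgsFolder' to remove unused iteration arguments in an
scf::ForOp. If the region block argument corresponding to an iteration argument
has no uses and the value yielded for it is the same as its input (init) value,
the iteration argument is folded away and the corresponding loop result is
replaced by that input.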

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D98503

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 9c0df1b47c35..c66d0ea497a3 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -408,9 +408,14 @@ static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
 }
 
 namespace {
-// Fold away ForOp iter arguments that are also yielded by the op.
-// These arguments must be defined outside of the ForOp region and can just be
-// forwarded after simplifying the op inits, yields and returns.
+// Fold away ForOp iter arguments when:
+// 1) The op yields the iter arguments.
+// 2) The iter arguments have no use and the corresponding outer region
+// iterators (inputs) are yielded.
+//
+// These arguments must be defined outside of
+// the ForOp region and can just be forwarded after simplifying the op inits,
+// yields and returns.
 //
 // The implementation uses `mergeBlockBefore` to steal the content of the
 // original ForOp and avoid cloning.
@@ -441,8 +446,13 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
                              forOp.getRegionIterArgs(), // iter inside region
                              yieldOp.getOperands()      // iter yield
                              )) {
-      // Forwarded is `true` when the region `iter` argument is yielded.
-      bool forwarded = (std::get<1>(it) == std::get<2>(it));
+      // Forwarded is `true` when:
+      // 1) The region `iter` argument is yielded.
+      // 2) The region `iter` argument has zero use, and the corresponding iter
+      // operand (input) is yielded.
+      bool forwarded =
+          ((std::get<1>(it) == std::get<2>(it)) ||
+           (std::get<1>(it).use_empty() && std::get<0>(it) == std::get<2>(it)));
       keepMask.push_back(!forwarded);
       canonicalize |= forwarded;
       if (forwarded) {
@@ -483,7 +493,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
            "unexpected argument size mismatch");
 
     // No results case: the scf::ForOp builder already created a zero
-    // reult terminator. Merge before this terminator and just get rid of the
+    // result terminator. Merge before this terminator and just get rid of the
     // original terminator that has been merged in.
     if (newIterArgs.empty()) {
       auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8f76926bdff0..6f75532b9bc7 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -335,6 +335,7 @@ func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
 }
 
 // -----
+
 func private @process(%0 : memref<128x128xf32>)
 func private @process_tensor(%0 : tensor<128x128xf32>) -> memref<128x128xf32>
 
@@ -382,3 +383,22 @@ func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
   // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
 }
+
+// -----
+
+// CHECK-LABEL: fold_away_iter_with_no_use_and_yielded_input
+//  CHECK-SAME:   %[[A0:[0-9a-z]*]]: i32
+func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32,
+                    %ub : index, %lb : index, %step : index) -> (i32, i32) {
+  // CHECK-NEXT: %[[C32:.*]] = constant 32 : i32
+  %cst = constant 32 : i32
+  // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) { 
+  %0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst)
+    -> (i32, i32) {
+    %1 = addi %arg2, %cst : i32
+    scf.yield %1, %cst : i32, i32
+  }
+
+  // CHECK: return %[[FOR_RES]], %[[C32]] : i32, i32
+  return %0#0, %0#1 : i32, i32
+}


        


More information about the Mlir-commits mailing list