[Mlir-commits] [mlir] fd7eee6 - scf::ForOp: Fold away iterator arguments with no use and for which the corresponding input is yielded
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Mar 16 00:01:51 PDT 2021
Author: Lorenzo Chelini
Date: 2021-03-16T07:01:25Z
New Revision: fd7eee64c570e5e14e511045c64d4d8cf98dde25
URL: https://github.com/llvm/llvm-project/commit/fd7eee64c570e5e14e511045c64d4d8cf98dde25
DIFF: https://github.com/llvm/llvm-project/commit/fd7eee64c570e5e14e511045c64d4d8cf98dde25.diff
LOG: scf::ForOp: Fold away iterator arguments with no use and for which the corresponding input is yielded
Enhance 'ForOpIterArgsFolder' to remove unused iteration arguments in an
scf::ForOp. If the block argument corresponding to the given iterator has no
uses and the yielded value equals the input, we fold it away.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D98503
Added:
Modified:
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 9c0df1b47c35..c66d0ea497a3 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -408,9 +408,14 @@ static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
}
namespace {
-// Fold away ForOp iter arguments that are also yielded by the op.
-// These arguments must be defined outside of the ForOp region and can just be
-// forwarded after simplifying the op inits, yields and returns.
+// Fold away ForOp iter arguments when:
+// 1) The op yields the iter arguments.
+// 2) The iter arguments have no use and the corresponding outer region
+// iterators (inputs) are yielded.
+//
+// These arguments must be defined outside of
+// the ForOp region and can just be forwarded after simplifying the op inits,
+// yields and returns.
//
// The implementation uses `mergeBlockBefore` to steal the content of the
// original ForOp and avoid cloning.
@@ -441,8 +446,13 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
forOp.getRegionIterArgs(), // iter inside region
yieldOp.getOperands() // iter yield
)) {
- // Forwarded is `true` when the region `iter` argument is yielded.
- bool forwarded = (std::get<1>(it) == std::get<2>(it));
+ // Forwarded is `true` when:
+ // 1) The region `iter` argument is yielded.
+ // 2) The region `iter` argument has zero use, and the corresponding iter
+ // operand (input) is yielded.
+ bool forwarded =
+ ((std::get<1>(it) == std::get<2>(it)) ||
+ (std::get<1>(it).use_empty() && std::get<0>(it) == std::get<2>(it)));
keepMask.push_back(!forwarded);
canonicalize |= forwarded;
if (forwarded) {
@@ -483,7 +493,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
"unexpected argument size mismatch");
// No results case: the scf::ForOp builder already created a zero
- // reult terminator. Merge before this terminator and just get rid of the
+ // result terminator. Merge before this terminator and just get rid of the
// original terminator that has been merged in.
if (newIterArgs.empty()) {
auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8f76926bdff0..6f75532b9bc7 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -335,6 +335,7 @@ func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
}
// -----
+
func private @process(%0 : memref<128x128xf32>)
func private @process_tensor(%0 : tensor<128x128xf32>) -> memref<128x128xf32>
@@ -382,3 +383,22 @@ func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
// CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
}
+
+// -----
+
+// CHECK-LABEL: fold_away_iter_with_no_use_and_yielded_input
+// CHECK-SAME: %[[A0:[0-9a-z]*]]: i32
+func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32,
+ %ub : index, %lb : index, %step : index) -> (i32, i32) {
+ // CHECK-NEXT: %[[C32:.*]] = constant 32 : i32
+ %cst = constant 32 : i32
+ // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
+ %0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst)
+ -> (i32, i32) {
+ %1 = addi %arg2, %cst : i32
+ scf.yield %1, %cst : i32, i32
+ }
+
+ // CHECK: return %[[FOR_RES]], %[[C32]] : i32, i32
+ return %0#0, %0#1 : i32, i32
+}
More information about the Mlir-commits
mailing list