[Mlir-commits] [mlir] [mlir][scf] Always remove for iter args that are loop invariant (PR #121555)
Jeff Niu
llvmlistbot at llvm.org
Fri Jan 3 11:44:42 PST 2025
https://github.com/Mogball updated https://github.com/llvm/llvm-project/pull/121555
>From 4ded6315aa5e9f1184c5af8cecf3ff05094c3f14 Mon Sep 17 00:00:00 2001
From: Jeff Niu <jeffniu at openai.com>
Date: Fri, 3 Jan 2025 01:12:02 -0800
Subject: [PATCH 1/2] [mlir][scf] Always remove for iter args that are loop
invariant
This alters the condition in ForOpIterArgsFolder to always remove iter
args when their initial value equals the yielded value, not just when
the arg has no use.
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 25 ++++++++++++-------------
mlir/test/Dialect/SCF/canonicalize.mlir | 22 ++++++++++++++++++----
2 files changed, 30 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..872d34de4495bf 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -872,30 +872,29 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
newIterArgs.reserve(forOp.getInitArgs().size());
newYieldValues.reserve(numResults);
newResultValues.reserve(numResults);
- for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
- forOp.getRegionIterArgs(), // iter inside region
- forOp.getResults(), // op results
- forOp.getYieldedValues() // iter yield
- )) {
+ for (auto [init, arg, result, yielded] :
+ llvm::zip(forOp.getInitArgs(), // iter from outside
+ forOp.getRegionIterArgs(), // iter inside region
+ forOp.getResults(), // op results
+ forOp.getYieldedValues() // iter yield
+ )) {
// Forwarded is `true` when:
// 1) The region `iter` argument is yielded.
// 2) The region `iter` argument has no use, and the corresponding iter
// operand (input) is yielded.
// 3) The region `iter` argument has no use, and the corresponding op
// result has no use.
- bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
- (std::get<1>(it).use_empty() &&
- (std::get<0>(it) == std::get<3>(it) ||
- std::get<2>(it).use_empty())));
+ bool forwarded = (arg == yielded) || (init == yielded) ||
+ (arg.use_empty() && result.use_empty());
keepMask.push_back(!forwarded);
canonicalize |= forwarded;
if (forwarded) {
- newBlockTransferArgs.push_back(std::get<0>(it));
- newResultValues.push_back(std::get<0>(it));
+ newBlockTransferArgs.push_back(init);
+ newResultValues.push_back(init);
continue;
}
- newIterArgs.push_back(std::get<0>(it));
- newYieldValues.push_back(std::get<3>(it));
+ newIterArgs.push_back(init);
+ newYieldValues.push_back(yielded);
newBlockTransferArgs.push_back(Value()); // placeholder with null value
newResultValues.push_back(Value()); // placeholder with null value
}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8c4e7a41ee6bc4..828758df6d31c0 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -408,6 +408,20 @@ func.func @for_yields_4() -> i32 {
// -----
+// CHECK-LABEL: @constant_iter_arg
+func.func @constant_iter_arg(%arg0: index, %arg1: index, %arg2: index) {
+ %c0_i32 = arith.constant 0 : i32
+ // CHECK: scf.for %arg3 = %arg0 to %arg1 step %arg2 {
+ %0 = scf.for %i = %arg0 to %arg1 step %arg2 iter_args(%arg3 = %c0_i32) -> i32 {
+ // CHECK-NEXT: "test.use"(%c0_i32)
+ "test.use"(%arg3) : (i32) -> ()
+ scf.yield %c0_i32 : i32
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: @replace_true_if
func.func @replace_true_if() {
%true = arith.constant true
@@ -1789,7 +1803,7 @@ module {
}
// CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
-// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1832,7 +1846,7 @@ module {
}
// CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1856,7 +1870,7 @@ func.func @index_switch_fold() -> (f32, f32) {
%y = arith.constant 42.0 : f32
scf.yield %y : f32
}
-
+
%switch_cst_2 = arith.constant 2: index
%1 = scf.index_switch %switch_cst_2 -> f32
case 0 {
@@ -1867,7 +1881,7 @@ func.func @index_switch_fold() -> (f32, f32) {
%y = arith.constant 42.0 : f32
scf.yield %y : f32
}
-
+
return %0, %1 : f32, f32
}
>From d26846433898c9639b37c2dcb9b3247625e63dd7 Mon Sep 17 00:00:00 2001
From: Jeff Niu <jeffniu at openai.com>
Date: Fri, 3 Jan 2025 11:44:31 -0800
Subject: [PATCH 2/2] update doc
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 872d34de4495bf..83ae79ce482669 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -839,8 +839,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
namespace {
// Fold away ForOp iter arguments when:
// 1) The op yields the iter arguments.
-// 2) The iter arguments have no use and the corresponding outer region
-// iterators (inputs) are yielded.
+// 2) The argument's corresponding outer region iterators (inputs) are yielded.
// 3) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
@@ -880,8 +879,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
)) {
// Forwarded is `true` when:
// 1) The region `iter` argument is yielded.
- // 2) The region `iter` argument has no use, and the corresponding iter
- // operand (input) is yielded.
+ // 2) The input corresponding to the region `iter` argument is yielded.
// 3) The region `iter` argument has no use, and the corresponding op
// result has no use.
bool forwarded = (arg == yielded) || (init == yielded) ||
More information about the Mlir-commits
mailing list