[Mlir-commits] [mlir] [mlir][scf] Always remove for iter args that are loop invariant (PR #121555)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 3 01:14:47 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jeff Niu (Mogball)
<details>
<summary>Changes</summary>
This changes the condition in `ForOpIterArgsFolder` so that an iter arg is always removed when its initial value equals the yielded value (i.e., the iter arg is loop-invariant), not only when the corresponding region argument has no uses.
---
Full diff: https://github.com/llvm/llvm-project/pull/121555.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+12-13)
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+18-4)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..872d34de4495bf 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -872,30 +872,29 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
newIterArgs.reserve(forOp.getInitArgs().size());
newYieldValues.reserve(numResults);
newResultValues.reserve(numResults);
- for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
- forOp.getRegionIterArgs(), // iter inside region
- forOp.getResults(), // op results
- forOp.getYieldedValues() // iter yield
- )) {
+ for (auto [init, arg, result, yielded] :
+ llvm::zip(forOp.getInitArgs(), // iter from outside
+ forOp.getRegionIterArgs(), // iter inside region
+ forOp.getResults(), // op results
+ forOp.getYieldedValues() // iter yield
+ )) {
// Forwarded is `true` when:
// 1) The region `iter` argument is yielded.
// 2) The region `iter` argument has no use, and the corresponding iter
// operand (input) is yielded.
// 3) The region `iter` argument has no use, and the corresponding op
// result has no use.
- bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
- (std::get<1>(it).use_empty() &&
- (std::get<0>(it) == std::get<3>(it) ||
- std::get<2>(it).use_empty())));
+ bool forwarded = (arg == yielded) || (init == yielded) ||
+ (arg.use_empty() && result.use_empty());
keepMask.push_back(!forwarded);
canonicalize |= forwarded;
if (forwarded) {
- newBlockTransferArgs.push_back(std::get<0>(it));
- newResultValues.push_back(std::get<0>(it));
+ newBlockTransferArgs.push_back(init);
+ newResultValues.push_back(init);
continue;
}
- newIterArgs.push_back(std::get<0>(it));
- newYieldValues.push_back(std::get<3>(it));
+ newIterArgs.push_back(init);
+ newYieldValues.push_back(yielded);
newBlockTransferArgs.push_back(Value()); // placeholder with null value
newResultValues.push_back(Value()); // placeholder with null value
}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8c4e7a41ee6bc4..828758df6d31c0 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -408,6 +408,20 @@ func.func @for_yields_4() -> i32 {
// -----
+// CHECK-LABEL: @constant_iter_arg
+func.func @constant_iter_arg(%arg0: index, %arg1: index, %arg2: index) {
+ %c0_i32 = arith.constant 0 : i32
+ // CHECK: scf.for %arg3 = %arg0 to %arg1 step %arg2 {
+ %0 = scf.for %i = %arg0 to %arg1 step %arg2 iter_args(%arg3 = %c0_i32) -> i32 {
+ // CHECK-NEXT: "test.use"(%c0_i32)
+ "test.use"(%arg3) : (i32) -> ()
+ scf.yield %c0_i32 : i32
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: @replace_true_if
func.func @replace_true_if() {
%true = arith.constant true
@@ -1789,7 +1803,7 @@ module {
}
// CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
-// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1832,7 +1846,7 @@ module {
}
// CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1856,7 +1870,7 @@ func.func @index_switch_fold() -> (f32, f32) {
%y = arith.constant 42.0 : f32
scf.yield %y : f32
}
-
+
%switch_cst_2 = arith.constant 2: index
%1 = scf.index_switch %switch_cst_2 -> f32
case 0 {
@@ -1867,7 +1881,7 @@ func.func @index_switch_fold() -> (f32, f32) {
%y = arith.constant 42.0 : f32
scf.yield %y : f32
}
-
+
return %0, %1 : f32, f32
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/121555
More information about the Mlir-commits mailing list