[Mlir-commits] [mlir] 9d8e634 - [mlir][scf] Always remove for iter args that are loop invariant (#121555)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 3 11:44:49 PST 2025


Author: Jeff Niu
Date: 2025-01-03T11:44:46-08:00
New Revision: 9d8e634e85ca46fbec07733d3e69d34c0d7814ac

URL: https://github.com/llvm/llvm-project/commit/9d8e634e85ca46fbec07733d3e69d34c0d7814ac
DIFF: https://github.com/llvm/llvm-project/commit/9d8e634e85ca46fbec07733d3e69d34c0d7814ac.diff

LOG: [mlir][scf] Always remove for iter args that are loop invariant (#121555)

This alters the condition in ForOpIterArgsFolder to always remove iter
args when their initial value equals the yielded value, not just when
the arg has no use.

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..83ae79ce482669 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -839,8 +839,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
 namespace {
 // Fold away ForOp iter arguments when:
 // 1) The op yields the iter arguments.
-// 2) The iter arguments have no use and the corresponding outer region
-// iterators (inputs) are yielded.
+// 2) The argument's corresponding outer region iterators (inputs) are yielded.
 // 3) The iter arguments have no use and the corresponding (operation) results
 // have no use.
 //
@@ -872,30 +871,28 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
     newIterArgs.reserve(forOp.getInitArgs().size());
     newYieldValues.reserve(numResults);
     newResultValues.reserve(numResults);
-    for (auto it : llvm::zip(forOp.getInitArgs(),       // iter from outside
-                             forOp.getRegionIterArgs(), // iter inside region
-                             forOp.getResults(),        // op results
-                             forOp.getYieldedValues()   // iter yield
-                             )) {
+    for (auto [init, arg, result, yielded] :
+         llvm::zip(forOp.getInitArgs(),       // iter from outside
+                   forOp.getRegionIterArgs(), // iter inside region
+                   forOp.getResults(),        // op results
+                   forOp.getYieldedValues()   // iter yield
+                   )) {
       // Forwarded is `true` when:
       // 1) The region `iter` argument is yielded.
-      // 2) The region `iter` argument has no use, and the corresponding iter
-      // operand (input) is yielded.
+      // 2) The region `iter` argument's corresponding input is yielded.
       // 3) The region `iter` argument has no use, and the corresponding op
       // result has no use.
-      bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
-                        (std::get<1>(it).use_empty() &&
-                         (std::get<0>(it) == std::get<3>(it) ||
-                          std::get<2>(it).use_empty())));
+      bool forwarded = (arg == yielded) || (init == yielded) ||
+                       (arg.use_empty() && result.use_empty());
       keepMask.push_back(!forwarded);
       canonicalize |= forwarded;
       if (forwarded) {
-        newBlockTransferArgs.push_back(std::get<0>(it));
-        newResultValues.push_back(std::get<0>(it));
+        newBlockTransferArgs.push_back(init);
+        newResultValues.push_back(init);
         continue;
       }
-      newIterArgs.push_back(std::get<0>(it));
-      newYieldValues.push_back(std::get<3>(it));
+      newIterArgs.push_back(init);
+      newYieldValues.push_back(yielded);
       newBlockTransferArgs.push_back(Value()); // placeholder with null value
       newResultValues.push_back(Value());      // placeholder with null value
     }

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8c4e7a41ee6bc4..828758df6d31c0 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -408,6 +408,20 @@ func.func @for_yields_4() -> i32 {
 
 // -----
 
+// CHECK-LABEL: @constant_iter_arg
+func.func @constant_iter_arg(%arg0: index, %arg1: index, %arg2: index) {
+  %c0_i32 = arith.constant 0 : i32
+  // CHECK: scf.for %arg3 = %arg0 to %arg1 step %arg2 {
+  %0 = scf.for %i = %arg0 to %arg1 step %arg2 iter_args(%arg3 = %c0_i32) -> i32 {
+    // CHECK-NEXT: "test.use"(%c0_i32)
+    "test.use"(%arg3) : (i32) -> ()
+    scf.yield %c0_i32 : i32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @replace_true_if
 func.func @replace_true_if() {
   %true = arith.constant true
@@ -1789,7 +1803,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1832,7 +1846,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1856,7 +1870,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   %switch_cst_2 = arith.constant 2: index
   %1 = scf.index_switch %switch_cst_2 -> f32
   case 0 {
@@ -1867,7 +1881,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   return %0, %1 : f32, f32
 }
 


        


More information about the Mlir-commits mailing list