[Mlir-commits] [mlir] [mlir][scf] Always remove for iter args that are loop invariant (PR #121555)

Jeff Niu llvmlistbot at llvm.org
Fri Jan 3 01:14:13 PST 2025


https://github.com/Mogball created https://github.com/llvm/llvm-project/pull/121555

This alters the condition in `ForOpIterArgsFolder` so that an iter arg is always removed when its initial value is the same value that the loop yields, not only when the region iter arg additionally has no uses.

>From 4ded6315aa5e9f1184c5af8cecf3ff05094c3f14 Mon Sep 17 00:00:00 2001
From: Jeff Niu <jeffniu at openai.com>
Date: Fri, 3 Jan 2025 01:12:02 -0800
Subject: [PATCH] [mlir][scf] Always remove for iter args that are loop
 invariant

This alters the condition in ForOpIterArgsFolder to always remove iter
args when their initial value equals the yielded value, not just when
the arg has no use.
---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 25 ++++++++++++-------------
 mlir/test/Dialect/SCF/canonicalize.mlir | 22 ++++++++++++++++++----
 2 files changed, 30 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..872d34de4495bf 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -872,30 +872,29 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
     newIterArgs.reserve(forOp.getInitArgs().size());
     newYieldValues.reserve(numResults);
     newResultValues.reserve(numResults);
-    for (auto it : llvm::zip(forOp.getInitArgs(),       // iter from outside
-                             forOp.getRegionIterArgs(), // iter inside region
-                             forOp.getResults(),        // op results
-                             forOp.getYieldedValues()   // iter yield
-                             )) {
+    for (auto [init, arg, result, yielded] :
+         llvm::zip(forOp.getInitArgs(),       // iter from outside
+                   forOp.getRegionIterArgs(), // iter inside region
+                   forOp.getResults(),        // op results
+                   forOp.getYieldedValues()   // iter yield
+                   )) {
       // Forwarded is `true` when:
       // 1) The region `iter` argument is yielded.
       // 2) The region `iter` argument has no use, and the corresponding iter
       // operand (input) is yielded.
       // 3) The region `iter` argument has no use, and the corresponding op
       // result has no use.
-      bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
-                        (std::get<1>(it).use_empty() &&
-                         (std::get<0>(it) == std::get<3>(it) ||
-                          std::get<2>(it).use_empty())));
+      bool forwarded = (arg == yielded) || (init == yielded) ||
+                       (arg.use_empty() && result.use_empty());
       keepMask.push_back(!forwarded);
       canonicalize |= forwarded;
       if (forwarded) {
-        newBlockTransferArgs.push_back(std::get<0>(it));
-        newResultValues.push_back(std::get<0>(it));
+        newBlockTransferArgs.push_back(init);
+        newResultValues.push_back(init);
         continue;
       }
-      newIterArgs.push_back(std::get<0>(it));
-      newYieldValues.push_back(std::get<3>(it));
+      newIterArgs.push_back(init);
+      newYieldValues.push_back(yielded);
       newBlockTransferArgs.push_back(Value()); // placeholder with null value
       newResultValues.push_back(Value());      // placeholder with null value
     }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8c4e7a41ee6bc4..828758df6d31c0 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -408,6 +408,20 @@ func.func @for_yields_4() -> i32 {
 
 // -----
 
+// CHECK-LABEL: @constant_iter_arg
+func.func @constant_iter_arg(%arg0: index, %arg1: index, %arg2: index) {
+  %c0_i32 = arith.constant 0 : i32
+  // CHECK: scf.for %arg3 = %arg0 to %arg1 step %arg2 {
+  %0 = scf.for %i = %arg0 to %arg1 step %arg2 iter_args(%arg3 = %c0_i32) -> i32 {
+    // CHECK-NEXT: "test.use"(%c0_i32)
+    "test.use"(%arg3) : (i32) -> ()
+    scf.yield %c0_i32 : i32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @replace_true_if
 func.func @replace_true_if() {
   %true = arith.constant true
@@ -1789,7 +1803,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1832,7 +1846,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1856,7 +1870,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   %switch_cst_2 = arith.constant 2: index
   %1 = scf.index_switch %switch_cst_2 -> f32
   case 0 {
@@ -1867,7 +1881,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   return %0, %1 : f32, f32
 }
 



More information about the Mlir-commits mailing list