[Mlir-commits] [mlir] [mlir][scf] Remove identical `scf.for` iter args (PR #127145)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 13 15:28:48 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jeff Niu (Mogball)

<details>
<summary>Changes</summary>

This augments the iter arg canonicalizer to remove iter args that always have the same value as another iter arg, i.e. their corresponding init and yielded values are the same.

---
Full diff: https://github.com/llvm/llvm-project/pull/127145.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+24-5) 
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+18) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 83ae79ce48266..448141735ba7f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -843,9 +843,8 @@ namespace {
 // 3) The iter arguments have no use and the corresponding (operation) results
 // have no use.
 //
-// These arguments must be defined outside of
-// the ForOp region and can just be forwarded after simplifying the op inits,
-// yields and returns.
+// These arguments must be defined outside of the ForOp region and can just be
+// forwarded after simplifying the op inits, yields and returns.
 //
 // The implementation uses `inlineBlockBefore` to steal the content of the
 // original ForOp and avoid cloning.
@@ -871,6 +870,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
     newIterArgs.reserve(forOp.getInitArgs().size());
     newYieldValues.reserve(numResults);
     newResultValues.reserve(numResults);
+    DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
     for (auto [init, arg, result, yielded] :
          llvm::zip(forOp.getInitArgs(),       // iter from outside
                    forOp.getRegionIterArgs(), // iter inside region
@@ -884,13 +884,32 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
       // result has no use.
       bool forwarded = (arg == yielded) || (init == yielded) ||
                        (arg.use_empty() && result.use_empty());
-      keepMask.push_back(!forwarded);
-      canonicalize |= forwarded;
       if (forwarded) {
+        canonicalize = true;
+        keepMask.push_back(false);
         newBlockTransferArgs.push_back(init);
         newResultValues.push_back(init);
         continue;
       }
+
+      // Check if a previous kept argument always has the same values for init
+      // and yielded values.
+      if (auto it = initYieldToArg.find({init, yielded});
+          it != initYieldToArg.end()) {
+        canonicalize = true;
+        keepMask.push_back(false);
+        auto [sameArg, sameResult] = it->second;
+        rewriter.replaceAllUsesWith(arg, sameArg);
+        rewriter.replaceAllUsesWith(result, sameResult);
+        // The replacement value doesn't matter because there are no uses.
+        newBlockTransferArgs.push_back(init);
+        newResultValues.push_back(init);
+        continue;
+      }
+
+      // This value is kept.
+      initYieldToArg.insert({{init, yielded}, {arg, result}});
+      keepMask.push_back(true);
       newIterArgs.push_back(init);
       newYieldValues.push_back(yielded);
       newBlockTransferArgs.push_back(Value()); // placeholder with null value
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 828758df6d31c..c18bd617216f1 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -821,6 +821,24 @@ func.func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
 
 // -----
 
+// CHECK-LABEL: @replace_duplicate_iter_args
+// CHECK-SAME: [[LB:%arg[0-9]]]: index, [[UB:%arg[0-9]]]: index, [[STEP:%arg[0-9]]]: index, [[A:%arg[0-9]]]: index, [[B:%arg[0-9]]]: index
+func.func @replace_duplicate_iter_args(%lb: index, %ub: index, %step: index, %a: index, %b: index) -> (index, index, index, index) {
+  // CHECK-NEXT: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[K0:%.*]] = [[A]], [[K1:%.*]] = [[B]])
+  %0:4 = scf.for %i = %lb to %ub step %step iter_args(%k0 = %a, %k1 = %b, %k2 = %b, %k3 = %a) -> (index, index, index, index) {
+    // CHECK-NEXT: [[V0:%.*]] = arith.addi [[K0]], [[K1]]
+    %1 = arith.addi %k0, %k1 : index
+    // CHECK-NEXT: [[V1:%.*]] = arith.addi [[K1]], [[K0]]
+    %2 = arith.addi %k2, %k3 : index
+    // CHECK-NEXT: yield [[V0]], [[V1]]
+    scf.yield %1, %2, %2, %1 : index, index, index, index
+  }
+  // CHECK: return [[RES]]#0, [[RES]]#1, [[RES]]#1, [[RES]]#0
+  return %0#0, %0#1, %0#2, %0#3 : index, index, index, index
+}
+
+// -----
+
 func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
 
 func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {

``````````

</details>


https://github.com/llvm/llvm-project/pull/127145


More information about the Mlir-commits mailing list