[Mlir-commits] [mlir] [mlir][scf] Remove identical `scf.for` iter args (PR #127145)
Jeff Niu
llvmlistbot at llvm.org
Thu Feb 13 15:28:12 PST 2025
https://github.com/Mogball created https://github.com/llvm/llvm-project/pull/127145
This augments the iter arg canonicalizer to remove iter args that always have the same value, i.e. their corresponding init and yielded values are the same.
>From 6e01664829f3563985988fbd3502ab24aece98d1 Mon Sep 17 00:00:00 2001
From: Jeff Niu <jeffniu at openai.com>
Date: Thu, 13 Feb 2025 15:26:42 -0800
Subject: [PATCH] [mlir][scf] Remove identical `scf.for` iter args
This augments the iter arg canonicalizer to remove iter args that always
have the same value, i.e. their corresponding init and yielded values
are the same.
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 29 ++++++++++++++++++++-----
mlir/test/Dialect/SCF/canonicalize.mlir | 18 +++++++++++++++
2 files changed, 42 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 83ae79ce48266..448141735ba7f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -843,9 +843,8 @@ namespace {
// 3) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
-// These arguments must be defined outside of
-// the ForOp region and can just be forwarded after simplifying the op inits,
-// yields and returns.
+// These arguments must be defined outside of the ForOp region and can just be
+// forwarded after simplifying the op inits, yields and returns.
//
// The implementation uses `inlineBlockBefore` to steal the content of the
// original ForOp and avoid cloning.
@@ -871,6 +870,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
newIterArgs.reserve(forOp.getInitArgs().size());
newYieldValues.reserve(numResults);
newResultValues.reserve(numResults);
+ DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
for (auto [init, arg, result, yielded] :
llvm::zip(forOp.getInitArgs(), // iter from outside
forOp.getRegionIterArgs(), // iter inside region
@@ -884,13 +884,32 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
// result has no use.
bool forwarded = (arg == yielded) || (init == yielded) ||
(arg.use_empty() && result.use_empty());
- keepMask.push_back(!forwarded);
- canonicalize |= forwarded;
if (forwarded) {
+ canonicalize = true;
+ keepMask.push_back(false);
newBlockTransferArgs.push_back(init);
newResultValues.push_back(init);
continue;
}
+
+ // Check if a previous kept argument always has the same values for init
+ // and yielded values.
+ if (auto it = initYieldToArg.find({init, yielded});
+ it != initYieldToArg.end()) {
+ canonicalize = true;
+ keepMask.push_back(false);
+ auto [sameArg, sameResult] = it->second;
+ rewriter.replaceAllUsesWith(arg, sameArg);
+ rewriter.replaceAllUsesWith(result, sameResult);
+ // The replacement value doesn't matter because there are no uses.
+ newBlockTransferArgs.push_back(init);
+ newResultValues.push_back(init);
+ continue;
+ }
+
+ // This value is kept.
+ initYieldToArg.insert({{init, yielded}, {arg, result}});
+ keepMask.push_back(true);
newIterArgs.push_back(init);
newYieldValues.push_back(yielded);
newBlockTransferArgs.push_back(Value()); // placeholder with null value
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 828758df6d31c..c18bd617216f1 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -821,6 +821,24 @@ func.func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
// -----
+// CHECK-LABEL: @replace_duplicate_iter_args
+// CHECK-SAME: [[LB:%arg[0-9]]]: index, [[UB:%arg[0-9]]]: index, [[STEP:%arg[0-9]]]: index, [[A:%arg[0-9]]]: index, [[B:%arg[0-9]]]: index
+func.func @replace_duplicate_iter_args(%lb: index, %ub: index, %step: index, %a: index, %b: index) -> (index, index, index, index) {
+ // CHECK-NEXT: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[K0:%.*]] = [[A]], [[K1:%.*]] = [[B]])
+ %0:4 = scf.for %i = %lb to %ub step %step iter_args(%k0 = %a, %k1 = %b, %k2 = %b, %k3 = %a) -> (index, index, index, index) {
+ // CHECK-NEXT: [[V0:%.*]] = arith.addi [[K0]], [[K1]]
+ %1 = arith.addi %k0, %k1 : index
+ // CHECK-NEXT: [[V1:%.*]] = arith.addi [[K1]], [[K0]]
+ %2 = arith.addi %k2, %k3 : index
+ // CHECK-NEXT: yield [[V0]], [[V1]]
+ scf.yield %1, %2, %2, %1 : index, index, index, index
+ }
+ // CHECK: return [[RES]]#0, [[RES]]#1, [[RES]]#1, [[RES]]#0
+ return %0#0, %0#1, %0#2, %0#3 : index, index, index, index
+}
+
+// -----
+
func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
More information about the Mlir-commits
mailing list