[llvm-branch-commits] [mlir] [mlir][SCF] Fold unused `index_switch` results (PR #173560)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Dec 28 10:10:27 PST 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/173560
>From da738b44231c3e2b2ede42380057774e12f0d9b3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 25 Dec 2025 12:55:26 +0000
Subject: [PATCH 1/2] [mlir][SCF] Fold unused `index_switch` results
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 52 ++++++++++++++++++++++++-
mlir/test/Dialect/SCF/canonicalize.mlir | 31 +++++++++++++++
2 files changed, 82 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 4a6b8aa7b1125..46d09abd89d69 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4797,9 +4797,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
}
};
+/// Canonicalization pattern that folds away dead (unused) results of
+/// "scf.index_switch" ops by rebuilding the op with only the live results.
+struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
+ using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IndexSwitchOp op,
+ PatternRewriter &rewriter) const override {
+ // Mark results without uses as dead; collect the types of the live ones.
+ BitVector deadResults(op.getNumResults(), false);
+ SmallVector<Type> newResultTypes;
+ for (auto [idx, result] : llvm::enumerate(op.getResults())) {
+ if (!result.use_empty()) {
+ newResultTypes.push_back(result.getType());
+ } else {
+ deadResults[idx] = true;
+ }
+ }
+ if (!deadResults.any())
+ return rewriter.notifyMatchFailure(op, "no dead results to fold");
+
+ // Create new op without dead results and inline case regions.
+ auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
+ op.getArg(), op.getCases(),
+ op.getCaseRegions().size());
+ auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
+ rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+ // Drop the yield operands that correspond to dead results.
+ Operation *terminator = newRegion.front().getTerminator();
+ assert(isa<YieldOp>(terminator) && "expected yield op");
+ rewriter.modifyOpInPlace(
+ terminator, [&]() { terminator->eraseOperands(deadResults); });
+ };
+ for (auto [oldRegion, newRegion] :
+ llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
+ inlineCaseRegion(oldRegion, newRegion);
+ inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
+
+ // Replace op with new op.
+ SmallVector<Value> newResults(op.getNumResults(), Value());
+ unsigned nextNewResult = 0;
+ for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
+ if (deadResults[idx])
+ continue;
+ newResults[idx] = newOp.getResult(nextNewResult++);
+ }
+ rewriter.replaceOp(op, newResults);
+ return success();
+ }
+};
+
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldConstantCase>(context);
+ results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 37851710ef010..984ea10f7e540 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2207,3 +2207,34 @@ func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : in
}
return %res#0, %res#1, %res#2 : i32, i32, i32
}
+
+// -----
+// Dead "scf.index_switch" results are dropped; live results and stores stay.
+// CHECK-LABEL: func @dead_index_switch_result(
+// CHECK-SAME: %[[arg0:.*]]: index
+// CHECK-DAG: %[[c10:.*]] = arith.constant 10
+// CHECK-DAG: %[[c11:.*]] = arith.constant 11
+// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+// CHECK: case 1 {
+// CHECK: memref.store %[[c10]]
+// CHECK: scf.yield %[[arg0]] : index
+// CHECK: }
+// CHECK: default {
+// CHECK: memref.store %[[c11]]
+// CHECK: scf.yield %[[arg0]] : index
+// CHECK: }
+// CHECK: return %[[switch]]
+func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
+ %non_live, %live = scf.index_switch %arg0 -> i32, index
+ case 1 {
+ %c10 = arith.constant 10 : i32
+ memref.store %c10, %arg1[] : memref<i32>
+ scf.yield %c10, %arg0 : i32, index
+ }
+ default {
+ %c11 = arith.constant 11 : i32
+ memref.store %c11, %arg1[] : memref<i32>
+ scf.yield %c11, %arg0 : i32, index
+ }
+ return %live : index
+}
>From 9bef67465ba47702c8a24b59fedde3d1aee8f9ad Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 28 Dec 2025 19:10:20 +0100
Subject: [PATCH 2/2] Apply suggestions from code review
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
Co-authored-by: lonely eagle <2020382038 at qq.com>
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 46d09abd89d69..178e344b8963e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4814,7 +4814,7 @@ struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
deadResults[idx] = true;
}
}
- if (!deadResults.any())
+ if (newResultTypes.size() == op.getNumResults())
return rewriter.notifyMatchFailure(op, "no dead results to fold");
// Create new op without dead results and inline case regions.
@@ -4837,7 +4837,7 @@ struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
// Replace op with new op.
SmallVector<Value> newResults(op.getNumResults(), Value());
unsigned nextNewResult = 0;
- for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
+ for (unsigned idx = 0, e = op.getNumResults(); idx < e; ++idx) {
if (deadResults[idx])
continue;
newResults[idx] = newOp.getResult(nextNewResult++);
More information about the llvm-branch-commits
mailing list