[llvm-branch-commits] [mlir] [mlir][SCF] Fold unused `index_switch` results (PR #173560)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun Dec 28 10:10:27 PST 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/173560

>From da738b44231c3e2b2ede42380057774e12f0d9b3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 25 Dec 2025 12:55:26 +0000
Subject: [PATCH 1/2] [mlir][SCF] Fold unused `index_switch` results

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 52 ++++++++++++++++++++++++-
 mlir/test/Dialect/SCF/canonicalize.mlir | 31 +++++++++++++++
 2 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 4a6b8aa7b1125..46d09abd89d69 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4797,9 +4797,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
   }
 };
 
+/// Canonicalization pattern that folds away dead (unused) results of
+/// "scf.index_switch" ops by rebuilding the op with only the live results.
+struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
+  using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override {
+    // Find dead results.
+    BitVector deadResults(op.getNumResults(), false);
+    SmallVector<Type> newResultTypes;
+    for (auto [idx, result] : llvm::enumerate(op.getResults())) {
+      if (!result.use_empty()) {
+        newResultTypes.push_back(result.getType());
+      } else {
+        deadResults[idx] = true;
+      }
+    }
+    if (!deadResults.any())
+      return rewriter.notifyMatchFailure(op, "no dead results to fold");
+
+    // Create new op without dead results and inline case regions.
+    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
+                                       op.getArg(), op.getCases(),
+                                       op.getCaseRegions().size());
+    auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
+      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+      // Drop the yield operands that feed the dead results.
+      Operation *terminator = newRegion.front().getTerminator();
+      assert(isa<YieldOp>(terminator) && "expected yield op");
+      rewriter.modifyOpInPlace(
+          terminator, [&]() { terminator->eraseOperands(deadResults); });
+    };
+    for (auto [oldRegion, newRegion] :
+         llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
+      inlineCaseRegion(oldRegion, newRegion);
+    inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
+
+    // Replace op with new op.
+    SmallVector<Value> newResults(op.getNumResults(), Value());
+    unsigned nextNewResult = 0;
+    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
+      if (deadResults[idx])
+        continue;
+      newResults[idx] = newOp.getResult(nextNewResult++);
+    }
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
+
 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<FoldConstantCase>(context);
+  results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 37851710ef010..984ea10f7e540 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2207,3 +2207,34 @@ func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : in
   }
   return %res#0, %res#1, %res#2 : i32, i32, i32
 }
+
+// -----
+
+// CHECK-LABEL: func @dead_index_switch_result(
+//  CHECK-SAME:     %[[arg0:.*]]: index
+//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10
+//   CHECK-DAG:   %[[c11:.*]] = arith.constant 11
+//       CHECK:   %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+//       CHECK:   case 1 {
+//       CHECK:     memref.store %[[c10]]
+//       CHECK:     scf.yield %[[arg0]] : index
+//       CHECK:   }
+//       CHECK:   default {
+//       CHECK:     memref.store %[[c11]]
+//       CHECK:     scf.yield %[[arg0]] : index
+//       CHECK:   }
+//       CHECK:   return %[[switch]]
+func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
+  %non_live, %live = scf.index_switch %arg0 -> i32, index
+  case 1 {
+    %c10 = arith.constant 10 : i32
+    memref.store %c10, %arg1[] : memref<i32>
+    scf.yield %c10, %arg0 : i32, index
+  }
+  default {
+    %c11 = arith.constant 11 : i32
+    memref.store %c11, %arg1[] : memref<i32>
+    scf.yield %c11, %arg0 : i32, index
+  }
+  return %live : index
+}

>From 9bef67465ba47702c8a24b59fedde3d1aee8f9ad Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 28 Dec 2025 19:10:20 +0100
Subject: [PATCH 2/2] Apply suggestions from code review

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
Co-authored-by: lonely eagle <2020382038 at qq.com>
---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 46d09abd89d69..178e344b8963e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4814,7 +4814,7 @@ struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
         deadResults[idx] = true;
       }
     }
-    if (!deadResults.any())
+    if (newResultTypes.size() == op.getNumResults())
       return rewriter.notifyMatchFailure(op, "no dead results to fold");
 
     // Create new op without dead results and inline case regions.
@@ -4837,7 +4837,7 @@ struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
     // Replace op with new op.
     SmallVector<Value> newResults(op.getNumResults(), Value());
     unsigned nextNewResult = 0;
-    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
+    for (unsigned idx = 0, e = op.getNumResults(); idx < e; ++idx) {
       if (deadResults[idx])
         continue;
       newResults[idx] = newOp.getResult(nextNewResult++);



More information about the llvm-branch-commits mailing list