[Mlir-commits] [mlir] [mlir][SCF] Fold unused `index_switch` results (PR #173560)

Matthias Springer llvmlistbot at llvm.org
Fri Dec 26 03:47:59 PST 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/173560

>From fea8332f3c3d3ed856655dee35b59c748eba9f3f Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 25 Dec 2025 10:14:32 +0000
Subject: [PATCH 1/2] [mlir][Transforms][NFC] `remove-dead-values`: Simplify
 dropped value handling

---
 mlir/lib/Transforms/RemoveDeadValues.cpp | 21 +++++++--------------
 1 file changed, 7 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 878624abee464..62ce5e0bbb77e 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -111,7 +111,6 @@ struct SuccessorOperandsToCleanup {
 
 struct RDVFinalCleanupList {
   SmallVector<Operation *> operations;
-  SmallVector<Value> values;
   SmallVector<FunctionToCleanUp> functions;
   SmallVector<OperandsToCleanup> operands;
   SmallVector<ResultsToCleanup> results;
@@ -325,10 +324,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
 
   // Do (1).
   for (auto [index, arg] : llvm::enumerate(arguments))
-    if (arg && nonLiveArgs[index]) {
-      cl.values.push_back(arg);
+    if (arg && nonLiveArgs[index])
       nonLiveSet.insert(arg);
-    }
 
   // Do (2). (Skip creating generic operand cleanup entries for call ops.
   // Call arguments will be removed in the call-site specific segment-aware
@@ -850,14 +847,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     op->erase();
   }
 
-  // 4. Values
-  LDBG() << "Cleaning up " << list.values.size() << " values";
-  for (auto &v : list.values) {
-    LDBG() << "Dropping all uses of value: " << v;
-    v.dropAllUses();
-  }
-
-  // 5. Functions
+  // 4. Functions
   LDBG() << "Cleaning up " << list.functions.size() << " functions";
   // Record which function arguments were erased so we can shrink call-site
   // argument segments for CallOpInterface operations (e.g. ops using
@@ -874,6 +864,9 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
       llvm::interleaveComma(f.nonLiveRets.set_bits(), os);
       os << "]";
     });
+    // Drop all uses of the dead arguments.
+    for (auto deadIdx : f.nonLiveArgs.set_bits())
+      f.funcOp.getArgument(deadIdx).dropAllUses();
     // Some functions may not allow erasing arguments or results. These calls
     // return failure in such cases without modifying the function, so it's okay
     // to proceed.
@@ -885,7 +878,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     (void)f.funcOp.eraseResults(f.nonLiveRets);
   }
 
-  // 6. Operands
+  // 5. Operands
   LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperandsToCleanup &o : list.operands) {
     // Handle call-specific cleanup only when we have a cached callee reference.
@@ -934,7 +927,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     }
   }
 
-  // 7. Results
+  // 6. Results
   LDBG() << "Cleaning up " << list.results.size() << " result lists";
   for (auto &r : list.results) {
     LDBG_OS([&](raw_ostream &os) {

>From da738b44231c3e2b2ede42380057774e12f0d9b3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 25 Dec 2025 12:55:26 +0000
Subject: [PATCH 2/2] [mlir][SCF] Fold unused `index_switch` results

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 52 ++++++++++++++++++++++++-
 mlir/test/Dialect/SCF/canonicalize.mlir | 31 +++++++++++++++
 2 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 4a6b8aa7b1125..46d09abd89d69 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4797,9 +4797,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
   }
 };
 
+/// Canonicalization pattern that folds away unused (dead) results of
+/// "scf.index_switch" ops by rebuilding the op with only the used results.
+struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
+  using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override {
+    // Mark results without uses as dead; collect the types of the live ones.
+    BitVector deadResults(op.getNumResults(), false);
+    SmallVector<Type> newResultTypes;
+    for (auto [idx, result] : llvm::enumerate(op.getResults())) {
+      if (!result.use_empty()) {
+        newResultTypes.push_back(result.getType());
+      } else {
+        deadResults[idx] = true;
+      }
+    }
+    if (!deadResults.any())
+      return rewriter.notifyMatchFailure(op, "no dead results to fold");
+
+    // Create a new op with only the live result types and inline all regions.
+    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
+                                       op.getArg(), op.getCases(),
+                                       op.getCaseRegions().size());
+    auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
+      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+      // Drop the yield operands that correspond to dead results.
+      Operation *terminator = newRegion.front().getTerminator();
+      assert(isa<YieldOp>(terminator) && "expected yield op");
+      rewriter.modifyOpInPlace(
+          terminator, [&]() { terminator->eraseOperands(deadResults); });
+    };
+    for (auto [oldRegion, newRegion] :
+         llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
+      inlineCaseRegion(oldRegion, newRegion);
+    inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
+
+    // Replace op with new op; dead results are mapped to a null Value.
+    SmallVector<Value> newResults(op.getNumResults(), Value());
+    unsigned nextNewResult = 0;
+    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
+      if (deadResults[idx])
+        continue;
+      newResults[idx] = newOp.getResult(nextNewResult++);
+    }
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
+
 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<FoldConstantCase>(context);
+  results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 37851710ef010..984ea10f7e540 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2207,3 +2207,34 @@ func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : in
   }
   return %res#0, %res#1, %res#2 : i32, i32, i32
 }
+
+// -----
+
+// CHECK-LABEL: func @dead_index_switch_result(
+//  CHECK-SAME:     %[[arg0:.*]]: index
+//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10
+//   CHECK-DAG:   %[[c11:.*]] = arith.constant 11
+//       CHECK:   %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+//       CHECK:   case 1 {
+//       CHECK:     memref.store %[[c10]]
+//       CHECK:     scf.yield %[[arg0]] : index
+//       CHECK:   } 
+//       CHECK:   default {
+//       CHECK:     memref.store %[[c11]]
+//       CHECK:     scf.yield %[[arg0]] : index
+//       CHECK:   }
+//       CHECK:   return %[[switch]]
+func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
+  %non_live, %live = scf.index_switch %arg0 -> i32, index
+  case 1 {
+    %c10 = arith.constant 10 : i32
+    memref.store %c10, %arg1[] : memref<i32>
+    scf.yield %c10, %arg0 : i32, index
+  }
+  default {
+    %c11 = arith.constant 11 : i32
+    memref.store %c11, %arg1[] : memref<i32>
+    scf.yield %c11, %arg0 : i32, index
+  }
+  return %live : index
+}



More information about the Mlir-commits mailing list