[Mlir-commits] [mlir] [mlir][SCF] Fold unused `index_switch` results (PR #173560)
Matthias Springer
llvmlistbot at llvm.org
Fri Dec 26 03:47:59 PST 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/173560
>From fea8332f3c3d3ed856655dee35b59c748eba9f3f Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 25 Dec 2025 10:14:32 +0000
Subject: [PATCH 1/2] [mlir][Transforms][NFC] `remove-dead-values`: Simplify
dropped value handling
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 21 +++++++--------------
1 file changed, 7 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 878624abee464..62ce5e0bbb77e 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -111,7 +111,6 @@ struct SuccessorOperandsToCleanup {
struct RDVFinalCleanupList {
SmallVector<Operation *> operations;
- SmallVector<Value> values;
SmallVector<FunctionToCleanUp> functions;
SmallVector<OperandsToCleanup> operands;
SmallVector<ResultsToCleanup> results;
@@ -325,10 +324,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// Do (1).
for (auto [index, arg] : llvm::enumerate(arguments))
- if (arg && nonLiveArgs[index]) {
- cl.values.push_back(arg);
+ if (arg && nonLiveArgs[index])
nonLiveSet.insert(arg);
- }
// Do (2). (Skip creating generic operand cleanup entries for call ops.
// Call arguments will be removed in the call-site specific segment-aware
@@ -850,14 +847,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
op->erase();
}
- // 4. Values
- LDBG() << "Cleaning up " << list.values.size() << " values";
- for (auto &v : list.values) {
- LDBG() << "Dropping all uses of value: " << v;
- v.dropAllUses();
- }
-
- // 5. Functions
+ // 4. Functions
LDBG() << "Cleaning up " << list.functions.size() << " functions";
// Record which function arguments were erased so we can shrink call-site
// argument segments for CallOpInterface operations (e.g. ops using
@@ -874,6 +864,9 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
llvm::interleaveComma(f.nonLiveRets.set_bits(), os);
os << "]";
});
+ // Drop all uses of the dead arguments.
+ for (auto deadIdx : f.nonLiveArgs.set_bits())
+ f.funcOp.getArgument(deadIdx).dropAllUses();
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
@@ -885,7 +878,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
(void)f.funcOp.eraseResults(f.nonLiveRets);
}
- // 6. Operands
+ // 5. Operands
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperandsToCleanup &o : list.operands) {
// Handle call-specific cleanup only when we have a cached callee reference.
@@ -934,7 +927,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
}
}
- // 7. Results
+ // 6. Results
LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
LDBG_OS([&](raw_ostream &os) {
>From da738b44231c3e2b2ede42380057774e12f0d9b3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 25 Dec 2025 12:55:26 +0000
Subject: [PATCH 2/2] [mlir][SCF] Fold unused `index_switch` results
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 52 ++++++++++++++++++++++++-
mlir/test/Dialect/SCF/canonicalize.mlir | 31 +++++++++++++++
2 files changed, 82 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 4a6b8aa7b1125..46d09abd89d69 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4797,9 +4797,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
}
};
+/// Canonicalization pattern that folds away dead (unused) results of
+/// "scf.index_switch" ops.
+struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
+ using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IndexSwitchOp op,
+ PatternRewriter &rewriter) const override {
+ // Mark results without uses as dead; collect the types of the live ones.
+ BitVector deadResults(op.getNumResults(), false);
+ SmallVector<Type> newResultTypes;
+ for (auto [idx, result] : llvm::enumerate(op.getResults())) {
+ if (!result.use_empty()) {
+ newResultTypes.push_back(result.getType());
+ } else {
+ deadResults[idx] = true;
+ }
+ }
+ if (!deadResults.any())
+ return rewriter.notifyMatchFailure(op, "no dead results to fold");
+
+ // Create a new op with only the live result types, then inline all regions.
+ auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
+ op.getArg(), op.getCases(),
+ op.getCaseRegions().size());
+ auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
+ rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+ // Drop the yield operands that correspond to dead results.
+ Operation *terminator = newRegion.front().getTerminator();
+ assert(isa<YieldOp>(terminator) && "expected yield op");
+ rewriter.modifyOpInPlace(
+ terminator, [&]() { terminator->eraseOperands(deadResults); });
+ };
+ for (auto [oldRegion, newRegion] :
+ llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
+ inlineCaseRegion(oldRegion, newRegion);
+ inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
+
+ // Replace op with new op; dead results are mapped to a null Value.
+ SmallVector<Value> newResults(op.getNumResults(), Value());
+ unsigned nextNewResult = 0;
+ for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
+ if (deadResults[idx])
+ continue;
+ newResults[idx] = newOp.getResult(nextNewResult++);
+ }
+ rewriter.replaceOp(op, newResults);
+ return success();
+ }
+};
+
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldConstantCase>(context);
+ results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 37851710ef010..984ea10f7e540 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2207,3 +2207,34 @@ func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : in
}
return %res#0, %res#1, %res#2 : i32, i32, i32
}
+
+// -----
+
+// CHECK-LABEL: func @dead_index_switch_result(
+// CHECK-SAME: %[[arg0:.*]]: index
+// CHECK-DAG: %[[c10:.*]] = arith.constant 10
+// CHECK-DAG: %[[c11:.*]] = arith.constant 11
+// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+// CHECK: case 1 {
+// CHECK: memref.store %[[c10]]
+// CHECK: scf.yield %[[arg0]] : index
+// CHECK: }
+// CHECK: default {
+// CHECK: memref.store %[[c11]]
+// CHECK: scf.yield %[[arg0]] : index
+// CHECK: }
+// CHECK: return %[[switch]]
+func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
+ %non_live, %live = scf.index_switch %arg0 -> i32, index
+ case 1 {
+ %c10 = arith.constant 10 : i32
+ memref.store %c10, %arg1[] : memref<i32>
+ scf.yield %c10, %arg0 : i32, index
+ }
+ default {
+ %c11 = arith.constant 11 : i32
+ memref.store %c11, %arg1[] : memref<i32>
+ scf.yield %c11, %arg0 : i32, index
+ }
+ return %live : index
+}
More information about the Mlir-commits
mailing list