[llvm-branch-commits] [mlir] [mlir][SCF] Fold unused `index_switch` results (PR #173560)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Dec 25 04:57:20 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add a canonicalization pattern to fold unused `scf.index_switch` results.
---
Full diff: https://github.com/llvm/llvm-project/pull/173560.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+51-1)
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+31)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 652414f6cbe54..0a123112cf68f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4711,9 +4711,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
}
};
+/// Canonicalization pattern that removes the unused results of an
+/// "scf.index_switch" op. The op is rebuilt with only the results that still
+/// have uses; the case and default regions are moved into the rebuilt op and
+/// the yielded values that fed the removed results are dropped.
+struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
+  using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override {
+    const unsigned numResults = op.getNumResults();
+
+    // Flag each result that has no uses and remember the types of the rest.
+    BitVector unusedMask(numResults, false);
+    SmallVector<Type> keptTypes;
+    for (unsigned i = 0; i < numResults; ++i) {
+      Value res = op.getResult(i);
+      if (res.use_empty()) {
+        unusedMask.set(i);
+        continue;
+      }
+      keptTypes.push_back(res.getType());
+    }
+    if (unusedMask.none())
+      return rewriter.notifyMatchFailure(op, "no dead results to fold");
+
+    // Rebuild the op with the same selector, case values and number of case
+    // regions, but only the live result types.
+    auto slimOp = IndexSwitchOp::create(rewriter, op.getLoc(), keptTypes,
+                                        op.getArg(), op.getCases(),
+                                        op.getCaseRegions().size());
+
+    // Move `src` into the still-empty region `dst`, then erase the yield
+    // operands sitting at the positions of the removed results.
+    auto moveRegion = [&](Region &src, Region &dst) {
+      rewriter.inlineRegionBefore(src, dst, dst.begin());
+      Operation *yield = dst.front().getTerminator();
+      assert(isa<YieldOp>(yield) && "expected yield op");
+      rewriter.modifyOpInPlace(yield,
+                               [&]() { yield->eraseOperands(unusedMask); });
+    };
+    for (auto [src, dst] :
+         llvm::zip_equal(op.getCaseRegions(), slimOp.getCaseRegions()))
+      moveRegion(src, dst);
+    moveRegion(op.getDefaultRegion(), slimOp.getDefaultRegion());
+
+    // Removed results map to null values; live ones map, in order, onto the
+    // results of the rebuilt op.
+    SmallVector<Value> replacements;
+    unsigned liveIdx = 0;
+    for (unsigned i = 0; i < numResults; ++i) {
+      if (unusedMask.test(i))
+        replacements.push_back(Value());
+      else
+        replacements.push_back(slimOp.getResult(liveIdx++));
+    }
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+};
+
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldConstantCase>(context);
+ // Register FoldConstantCase together with the new dead-result folding pattern.
+ results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..d5d0aee3bbe25 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2171,3 +2171,34 @@ func.func @scf_for_all_step_size_0() {
}
return
}
+
+// -----
+
+// The i32 result of the switch has no uses, so canonicalization should drop
+// it: the rebuilt switch yields only the index result, while both regions and
+// their memref.store side effects are kept.
+// CHECK-LABEL: func @dead_index_switch_result(
+// CHECK-SAME: %[[arg0:.*]]: index
+// CHECK-DAG: %[[c10:.*]] = arith.constant 10
+// CHECK-DAG: %[[c11:.*]] = arith.constant 11
+// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+// CHECK: case 1 {
+// CHECK: memref.store %[[c10]]
+// CHECK: scf.yield %[[arg0]] : index
+// CHECK: }
+// CHECK: default {
+// CHECK: memref.store %[[c11]]
+// CHECK: scf.yield %[[arg0]] : index
+// CHECK: }
+// CHECK: return %[[switch]]
+func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
+ %non_live, %live = scf.index_switch %arg0 -> i32, index
+ case 1 {
+ %c10 = arith.constant 10 : i32
+ memref.store %c10, %arg1[] : memref<i32>
+ scf.yield %c10, %arg0 : i32, index
+ }
+ default {
+ %c11 = arith.constant 11 : i32
+ memref.store %c11, %arg1[] : memref<i32>
+ scf.yield %c11, %arg0 : i32, index
+ }
+ return %live : index
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/173560
More information about the llvm-branch-commits
mailing list