[Mlir-commits] [mlir] 0b3943f - [MLIR][SCF] fix scf.index_switch fold convergence (#98535) (#98680)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 15 23:37:08 PDT 2024
Author: Keyi Zhang
Date: 2024-07-16T08:37:04+02:00
New Revision: 0b3943f3bba71e0cf9ea9984dae472cf503bca21
URL: https://github.com/llvm/llvm-project/commit/0b3943f3bba71e0cf9ea9984dae472cf503bca21
DIFF: https://github.com/llvm/llvm-project/commit/0b3943f3bba71e0cf9ea9984dae472cf503bca21.diff
LOG: [MLIR][SCF] fix scf.index_switch fold convergence (#98535) (#98680)
If the `scf.index_switch` op has no result, the current fold logic
results in an infinite loop (see #98535). This is because the `fold`
mechanism does not support *erasing* zero-result ops. This PR moves the
fold logic to a canonicalizer and fixes the issue.
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index f35ea962bea16..acbcbae105dbf 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -1159,7 +1159,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
Block &getCaseBlock(unsigned idx);
}];
- let hasFolder = 1;
+ let hasCanonicalizer = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 907d7f794593d..4de8dacc0edbf 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4297,33 +4297,42 @@ void IndexSwitchOp::getRegionInvocationBounds(
bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
}
-LogicalResult IndexSwitchOp::fold(FoldAdaptor adaptor,
- SmallVectorImpl<OpFoldResult> &results) {
- std::optional<int64_t> maybeCst = getConstantIntValue(getArg());
- if (!maybeCst.has_value())
- return failure();
- int64_t cst = *maybeCst;
- int64_t caseIdx, e = getNumCases();
- for (caseIdx = 0; caseIdx < e; ++caseIdx) {
- if (cst == getCases()[caseIdx])
- break;
- }
+struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
+ using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
- Region &r = (caseIdx < getNumCases()) ? getCaseRegions()[caseIdx]
- : getDefaultRegion();
- Block &source = r.front();
- results.assign(source.getTerminator()->getOperands().begin(),
- source.getTerminator()->getOperands().end());
+ LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
+ PatternRewriter &rewriter) const override {
+ // If `op.getArg()` is a constant, select the region that matches with
+ // the constant value. Use the default region if no match is found.
+ std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
+ if (!maybeCst.has_value())
+ return failure();
+ int64_t cst = *maybeCst;
+ int64_t caseIdx, e = op.getNumCases();
+ for (caseIdx = 0; caseIdx < e; ++caseIdx) {
+ if (cst == op.getCases()[caseIdx])
+ break;
+ }
- Block *pDestination = (*this)->getBlock();
- if (!pDestination)
- return failure();
- Block::iterator insertionPoint = (*this)->getIterator();
- pDestination->getOperations().splice(insertionPoint, source.getOperations(),
- source.getOperations().begin(),
- std::prev(source.getOperations().end()));
+ Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
+ : op.getDefaultRegion();
+ Block &source = r.front();
+ Operation *terminator = source.getTerminator();
+ SmallVector<Value> results = terminator->getOperands();
- return success();
+ rewriter.inlineBlockBefore(&source, op);
+ rewriter.eraseOp(terminator);
+ // Replace the operation with a potentially empty list of results.
+ // Fold mechanism doesn't support the case where the result list is empty.
+ rewriter.replaceOp(op, results);
+
+ return success();
+ }
+};
+
+void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldConstantCase>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 459ccd73cfe61..268946803de7a 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1846,3 +1846,21 @@ func.func @index_switch_fold() -> (f32, f32) {
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: %[[c42:.*]] = arith.constant 4.200000e+01 : f32
// CHECK-NEXT: return %[[c1]], %[[c42]] : f32, f32
+
+// -----
+
+func.func @index_switch_fold_no_res() {
+ %c1 = arith.constant 1 : index
+ scf.index_switch %c1
+ case 0 {
+ scf.yield
+ }
+ default {
+ "test.op"() : () -> ()
+ scf.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @index_switch_fold_no_res()
+// CHECK-NEXT: "test.op"() : () -> ()
More information about the Mlir-commits
mailing list