[Mlir-commits] [mlir] [MLIR][SCF] fix scf.index_switch fold convergence (#98535) (PR #98680)
Keyi Zhang
llvmlistbot at llvm.org
Mon Jul 15 16:37:09 PDT 2024
================
@@ -4297,33 +4297,38 @@ void IndexSwitchOp::getRegionInvocationBounds(
bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
}
-LogicalResult IndexSwitchOp::fold(FoldAdaptor adaptor,
- SmallVectorImpl<OpFoldResult> &results) {
- std::optional<int64_t> maybeCst = getConstantIntValue(getArg());
- if (!maybeCst.has_value())
- return failure();
- int64_t cst = *maybeCst;
- int64_t caseIdx, e = getNumCases();
- for (caseIdx = 0; caseIdx < e; ++caseIdx) {
- if (cst == getCases()[caseIdx])
- break;
- }
+struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
+ using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
- Region &r = (caseIdx < getNumCases()) ? getCaseRegions()[caseIdx]
- : getDefaultRegion();
- Block &source = r.front();
- results.assign(source.getTerminator()->getOperands().begin(),
- source.getTerminator()->getOperands().end());
+ LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
+ PatternRewriter &rewriter) const override {
+ std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
+ if (!maybeCst.has_value())
+ return failure();
+ int64_t cst = *maybeCst;
+ int64_t caseIdx, e = op.getNumCases();
+ for (caseIdx = 0; caseIdx < e; ++caseIdx) {
+ if (cst == op.getCases()[caseIdx])
+ break;
+ }
- Block *pDestination = (*this)->getBlock();
- if (!pDestination)
- return failure();
- Block::iterator insertionPoint = (*this)->getIterator();
- pDestination->getOperations().splice(insertionPoint, source.getOperations(),
- source.getOperations().begin(),
- std::prev(source.getOperations().end()));
+ Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
+ : op.getDefaultRegion();
+ Block &source = r.front();
+ Operation *terminator = source.getTerminator();
+ SmallVector<Value> results = terminator->getOperands();
- return success();
+ rewriter.inlineBlockBefore(&source, op);
+ rewriter.eraseOp(terminator);
+ rewriter.replaceOp(op, results);
+
+ return success();
+ }
----------------
Kuree wrote:
Thanks. I've updated the PR description and added more comments there.
https://github.com/llvm/llvm-project/pull/98680
More information about the Mlir-commits
mailing list