[clang] [CIR] Upstream support for FlattenCFG switch and SwitchFlatOp (PR #139154)
Andy Kaylor via cfe-commits
cfe-commits at lists.llvm.org
Thu May 8 16:16:19 PDT 2025
================
@@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
}
};
+class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
+public:
+ using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
+ /* Replaces `yieldOp` with a cir::BrOp built from the yield's operands and
+    `destination`. The rewriter's insertion point is set to `yieldOp` first,
+    then the yield is swapped out via replaceOpWithNewOp. */
+ inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
+ cir::YieldOp yieldOp,
+ mlir::Block *destination) const {
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
+ destination);
+ }
+
+ /* Builds the check for one case range [lowerBound, upperBound] of switch
+    `op`: creates a new block, emits the comparison of the switch condition
+    against the range into it, and ends it with a cir::BrCondOp whose
+    destinations are `rangeDestination` and `defaultDestination` (passed in
+    that order). Asserts that lowerBound <= upperBound (signed compare).
+    Returns the newly created block; per the original comment this is the
+    new defaultDestination block.
+    NOTE(review): the block comes from createBlock(defaultDestination),
+    presumably inserted just before `defaultDestination` with the insertion
+    point moved into it -- confirm against the MLIR builder API.
+    NOTE(review): every integer type here is a hard-coded 32-bit
+    cir::IntType; confirm that matches the switch condition's type. */
+ Block *condBrToRangeDestination(cir::SwitchOp op,
+ mlir::PatternRewriter &rewriter,
+ mlir::Block *rangeDestination,
+ mlir::Block *defaultDestination,
+ const APInt &lowerBound,
+ const APInt &upperBound) const {
+ assert(lowerBound.sle(upperBound) && "Invalid range");
+ mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
+ cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
+ cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);
+
+ cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>(
+ op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound));
+
+ cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>(
+ op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
+ cir::BinOp diffValue =
+ rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub,
+ op.getCondition(), lowerBoundValue);
+
+ /* Use unsigned comparison to check if the condition is in the range:
+    (condition - lowerBound) <= (upperBound - lowerBound), with both sides
+    cast to the unsigned type. Presumably a condition below lowerBound
+    wraps to a large unsigned value and fails the test, so one compare
+    covers both bounds -- confirm the cast semantics. */
+ cir::CastOp uDiffValue = rewriter.create<cir::CastOp>(
+ op.getLoc(), uIntType, CastKind::integral, diffValue);
+ cir::CastOp uRangeLength = rewriter.create<cir::CastOp>(
+ op.getLoc(), uIntType, CastKind::integral, rangeLength);
+
+ cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>(
+ op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le,
+ uDiffValue, uRangeLength);
+ rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination,
+ defaultDestination);
+ return resBlock;
+ }
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::SwitchOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ llvm::SmallVector<CaseOp> cases;
+ op.collectCases(cases);
+
+ // Empty switch statement: just erase it.
+ if (cases.empty()) {
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ // Create exit block from the next node of cir.switch op.
+ mlir::Block *exitBlock = rewriter.splitBlock(
+ rewriter.getBlock(), op->getNextNode()->getIterator());
+
+ // We lower cir.switch op in the following process:
+ // 1. Inline the region from the switch op after switch op.
+ // 2. Traverse each cir.case op:
+ // a. Record the entry block, block arguments and condition for every
+ // case.
+ // b. Inline the case region after the case op.
+ // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
+ // recorded block and conditions.
+
+ // inline everything from switch body between the switch op and the exit
+ // block.
+ {
+ cir::YieldOp switchYield = nullptr;
+ // Clear switch operation.
+ for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
+ if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
+ switchYield = yieldOp;
+
+ assert(!op.getBody().empty());
+ mlir::Block *originalBlock = op->getBlock();
+ mlir::Block *swopBlock =
+ rewriter.splitBlock(originalBlock, op->getIterator());
+ rewriter.inlineRegionBefore(op.getBody(), exitBlock);
+
+ if (switchYield)
+ rewriteYieldOp(rewriter, switchYield, exitBlock);
+
+ rewriter.setInsertionPointToEnd(originalBlock);
+ rewriter.create<cir::BrOp>(op.getLoc(), swopBlock);
+ }
+
+ // Allocate required data structures (the default case is not stored in
+ // these vectors).
+ llvm::SmallVector<mlir::APInt, 8> caseValues;
+ llvm::SmallVector<mlir::Block *, 8> caseDestinations;
+ llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
+
+ llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
+ llvm::SmallVector<mlir::Block *> rangeDestinations;
+ llvm::SmallVector<mlir::ValueRange> rangeOperands;
+
+ // Initialize default case as optional.
+ mlir::Block *defaultDestination = exitBlock;
+ mlir::ValueRange defaultOperands = exitBlock->getArguments();
+
+ // Digest the case statements values and bodies.
+ for (auto caseOp : cases) {
----------------
andykaylor wrote:
```suggestion
for (cir::CaseOp caseOp : cases) {
```
https://github.com/llvm/llvm-project/pull/139154
More information about the cfe-commits
mailing list