[clang] [CIR] Upstream support for FlattenCFG switch and SwitchFlatOp (PR #139154)
via cfe-commits
cfe-commits at lists.llvm.org
Thu May 15 15:36:30 PDT 2025
https://github.com/Andres-Salamanca updated https://github.com/llvm/llvm-project/pull/139154
>From 803abd7bfa9aa48ba3446e6de8ffb1f20b16ec26 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Thu, 8 May 2025 15:39:49 -0500
Subject: [PATCH 1/7] Add support for FlattenCFG switch and introduce
SwitchFlatOp
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 46 +++
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 97 ++++++
.../Dialect/Transforms/CIRCanonicalize.cpp | 16 +-
.../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 235 ++++++++++++++-
clang/test/CIR/IR/switch-flat.cir | 68 +++++
clang/test/CIR/Transforms/switch.cir | 278 ++++++++++++++++++
6 files changed, 734 insertions(+), 6 deletions(-)
create mode 100644 clang/test/CIR/IR/switch-flat.cir
create mode 100644 clang/test/CIR/Transforms/switch.cir
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index e08f372450285..87b9823da6ddc 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -971,6 +971,52 @@ def SwitchOp : CIR_Op<"switch",
}];
}
+//===----------------------------------------------------------------------===//
+// SwitchFlatOp
+//===----------------------------------------------------------------------===//
+
+def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments,
+ Terminator]> {
+
+ let description = [{
+ The `cir.switch.flat` operation is a region-less and simplified
+ version of the `cir.switch`.
+ It's representation is closer to LLVM IR dialect
+ than the C/C++ language feature.
+ }];
+
+ // `case_values` holds one cir.int attribute per case, typed like
+ // `condition`. `caseOperands` is a variadic-of-variadic operand list whose
+ // per-case group sizes are recorded in `case_operand_segments`.
+ let arguments = (ins
+ CIR_IntType:$condition,
+ Variadic<AnyType>:$defaultOperands,
+ VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
+ ArrayAttr:$case_values,
+ DenseI32ArrayAttr:$case_operand_segments
+ );
+
+ // The default destination comes first, followed by one destination per
+ // entry of `case_values`, in the same order.
+ let successors = (successor
+ AnySuccessor:$defaultDestination,
+ VariadicSuccessor<AnySuccessor>:$caseDestinations
+ );
+
+ // Case entries go through the custom `SwitchFlatOpCases` directive: a
+ // bracketed, comma-separated list of `value: ^dest(operands : types)`.
+ let assemblyFormat = [{
+ $condition `:` type($condition) `,`
+ $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
+ custom<SwitchFlatOpCases>(ref(type($condition)), $case_values,
+ $caseDestinations, $caseOperands,
+ type($caseOperands))
+ attr-dict
+ }];
+
+ // Convenience builder taking raw APInt case values; each one is wrapped in
+ // a cir.int attribute of the condition's type.
+ let builders = [
+ OpBuilder<(ins "mlir::Value":$condition,
+ "mlir::Block *":$defaultDestination,
+ "mlir::ValueRange":$defaultOperands,
+ CArg<"llvm::ArrayRef<llvm::APInt>", "{}">:$caseValues,
+ CArg<"mlir::BlockRange", "{}">:$caseDestinations,
+ CArg<"llvm::ArrayRef<mlir::ValueRange>", "{}">:$caseOperands)>
+ ];
+}
+
//===----------------------------------------------------------------------===//
// BrOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 779114e09d834..dd7ee4a2c1adf 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -22,6 +22,7 @@
#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
#include "clang/CIR/MissingFeatures.h"
+#include <numeric>
using namespace mlir;
using namespace cir;
@@ -962,6 +963,102 @@ bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
});
}
+//===----------------------------------------------------------------------===//
+// SwitchFlatOp
+//===----------------------------------------------------------------------===//
+
+void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
+ Value value, Block *defaultDestination,
+ ValueRange defaultOperands,
+ ArrayRef<APInt> caseValues,
+ BlockRange caseDestinations,
+ ArrayRef<ValueRange> caseOperands) {
+ // Wrap every raw case value in a cir.int attribute typed like the condition.
+ llvm::SmallVector<mlir::Attribute> caseValuesAttrs;
+ caseValuesAttrs.reserve(caseValues.size());
+ for (const APInt &caseValue : caseValues)
+ caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), caseValue));
+ mlir::ArrayAttr attrs = builder.getArrayAttr(caseValuesAttrs);
+ // Delegate to the ODS-generated builder that takes the attribute form.
+ build(builder, result, value, defaultOperands, caseOperands, attrs,
+ defaultDestination, caseDestinations);
+}
+
+/// <cases> ::= `[` (case (`,` case )* )? `]`
+/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
+/// Each integer becomes a cir.int attribute of type `flagType`.
+static ParseResult parseSwitchFlatOpCases(
+ OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues,
+ SmallVectorImpl<Block *> &caseDestinations,
+ SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>>
+ &caseOperands,
+ SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) {
+ if (failed(parser.parseLSquare()))
+ return failure();
+ llvm::SmallVector<mlir::Attribute> values;
+ if (succeeded(parser.parseOptionalRSquare())) {
+ // `[ ]` is what the printer emits for a switch without case values; the
+ // attribute is still required, so produce an empty array.
+ caseValues = ArrayAttr::get(flagType.getContext(), values);
+ return success();
+ }
+ auto parseCase = [&]() {
+ int64_t value = 0;
+ if (failed(parser.parseInteger(value)))
+ return failure();
+ values.push_back(cir::IntAttr::get(flagType, value));
+ Block *destination = nullptr;
+ llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<Type> operandTypes;
+ if (parser.parseColon() || parser.parseSuccessor(destination))
+ return failure();
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
+ /*allowResultNumber=*/false) ||
+ parser.parseColonTypeList(operandTypes) || parser.parseRParen())
+ return failure();
+ }
+ caseDestinations.push_back(destination);
+ caseOperands.emplace_back(std::move(operands));
+ caseOperandTypes.emplace_back(std::move(operandTypes));
+ return success();
+ };
+ if (failed(parser.parseCommaSeparatedList(parseCase)))
+ return failure();
+ caseValues = ArrayAttr::get(flagType.getContext(), values);
+ return parser.parseRSquare();
+}
+
+static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op,
+ Type flagType, mlir::ArrayAttr caseValues,
+ SuccessorRange caseDestinations,
+ OperandRangeRange caseOperands,
+ const TypeRangeRange &caseOperandTypes) {
+ // Layout: `[`, then one ` value: ^dest(operands : types)` entry per line,
+ // entries separated by commas, then `]` on its own line.
+ p << '[';
+ p.printNewline();
+ if (!caseValues) {
+ p << ']';
+ return;
+ }
+
+ size_t caseIndex = 0;
+ for (auto caseEntry : llvm::zip(caseValues, caseDestinations)) {
+ if (caseIndex != 0) {
+ p << ',';
+ p.printNewline();
+ }
+ auto caseValue = mlir::cast<cir::IntAttr>(std::get<0>(caseEntry));
+ p << " " << caseValue.getValue() << ": ";
+ p.printSuccessorAndUseList(std::get<1>(caseEntry),
+ caseOperands[caseIndex]);
+ ++caseIndex;
+ }
+ p.printNewline();
+ p << ']';
+}
+
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 798bc0dab9384..027b08907346b 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -84,6 +84,19 @@ struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
}
};
+struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
+ using OpRewritePattern<SwitchOp>::OpRewritePattern;
+ // Erase a switch whose body is empty or whose entry block begins with yield.
+ LogicalResult matchAndRewrite(SwitchOp op,
+ PatternRewriter &rewriter) const final {
+ mlir::Region &body = op.getBody();
+ if (!body.empty() && !isa<YieldOp>(body.front().front()))
+ return failure();
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// CIRCanonicalizePass
//===----------------------------------------------------------------------===//
@@ -127,8 +140,7 @@ void CIRCanonicalizePass::runOnOperation() {
assert(!cir::MissingFeatures::callOp());
// CastOp, UnaryOp and VecExtractOp are here to perform a manual `fold` in
// applyOpPatternsGreedily.
- if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp, VecExtractOp>(
- op))
+ if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp, VecExtractOp>(op))
ops.push_back(op);
});
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 4a936d33b022a..70f383b556567 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
}
};
+class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
+public:
+ using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
+
+ inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
+ cir::YieldOp yieldOp,
+ mlir::Block *destination) const {
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
+ destination);
+ }
+
+ // Return the new defaultDestination block, which branches to
+ // `rangeDestination` when lowerBound <= cond <= upperBound (signed).
+ Block *condBrToRangeDestination(cir::SwitchOp op,
+ mlir::PatternRewriter &rewriter,
+ mlir::Block *rangeDestination,
+ mlir::Block *defaultDestination,
+ const APInt &lowerBound,
+ const APInt &upperBound) const {
+ assert(lowerBound.sle(upperBound) && "Invalid range");
+ mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
+ // Use the condition's own type rather than a fixed 32-bit one so that the
+ // emitted IR is well-typed for switches over any integer width.
+ auto condIntType = mlir::cast<cir::IntType>(op.getCondition().getType());
+ cir::IntType uIntType =
+ cir::IntType::get(op.getContext(), condIntType.getWidth(), false);
+ cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>(
+ op.getLoc(), cir::IntAttr::get(condIntType, upperBound - lowerBound));
+ cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>(
+ op.getLoc(), cir::IntAttr::get(condIntType, lowerBound));
+ cir::BinOp diffValue = rewriter.create<cir::BinOp>(
+ op.getLoc(), condIntType, cir::BinOpKind::Sub, op.getCondition(),
+ lowerBoundValue);
+ // (cond - lower) <=u (upper - lower) checks both bounds with one compare.
+ cir::CastOp uDiffValue = rewriter.create<cir::CastOp>(
+ op.getLoc(), uIntType, CastKind::integral, diffValue);
+ cir::CastOp uRangeLength = rewriter.create<cir::CastOp>(
+ op.getLoc(), uIntType, CastKind::integral, rangeLength);
+ cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>(
+ op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le,
+ uDiffValue, uRangeLength);
+ rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination,
+ defaultDestination);
+ return resBlock;
+ }
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::SwitchOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ llvm::SmallVector<CaseOp> cases;
+ op.collectCases(cases);
+
+ // Empty switch statement: just erase it.
+ if (cases.empty()) {
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ // Create exit block from the next node of cir.switch op.
+ mlir::Block *exitBlock = rewriter.splitBlock(
+ op->getBlock(), op->getNextNode()->getIterator());
+
+ // We lower cir.switch op in the following process:
+ // 1. Inline the region from the switch op after switch op.
+ // 2. Traverse each cir.case op:
+ //    a. Record the entry block, block arguments and condition for every case.
+ //    b. Inline the case region after the case op.
+ // 3. Replace the empty cir.switch op with a new cir.switch.flat op built
+ //    from the recorded blocks and conditions.
+
+ // inline everything from switch body between the switch op and the exit
+ // block.
+ {
+ cir::YieldOp switchYield = nullptr;
+ // Clear switch operation.
+ for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
+ if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
+ switchYield = yieldOp;
+
+ assert(!op.getBody().empty());
+ mlir::Block *originalBlock = op->getBlock();
+ mlir::Block *swopBlock =
+ rewriter.splitBlock(originalBlock, op->getIterator());
+ rewriter.inlineRegionBefore(op.getBody(), exitBlock);
+
+ if (switchYield)
+ rewriteYieldOp(rewriter, switchYield, exitBlock);
+
+ rewriter.setInsertionPointToEnd(originalBlock);
+ rewriter.create<cir::BrOp>(op.getLoc(), swopBlock);
+ }
+
+ // Allocate required data structures (the default case is tracked
+ // separately and is not stored in these vectors).
+ llvm::SmallVector<mlir::APInt, 8> caseValues;
+ llvm::SmallVector<mlir::Block *, 8> caseDestinations;
+ llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
+
+ llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
+ llvm::SmallVector<mlir::Block *> rangeDestinations;
+ llvm::SmallVector<mlir::ValueRange> rangeOperands;
+
+ // Initialize default case as optional.
+ mlir::Block *defaultDestination = exitBlock;
+ mlir::ValueRange defaultOperands = exitBlock->getArguments();
+
+ // Digest the case statements values and bodies.
+ for (auto caseOp : cases) {
+ mlir::Region &region = caseOp.getCaseRegion();
+
+ // Found default case: save destination and operands.
+ switch (caseOp.getKind()) {
+ case cir::CaseOpKind::Default:
+ defaultDestination = &region.front();
+ defaultOperands = defaultDestination->getArguments();
+ break;
+ case cir::CaseOpKind::Range:
+ assert(caseOp.getValue().size() == 2 &&
+ "Case range should have 2 case value");
+ rangeValues.push_back(
+ {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
+ cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
+ rangeDestinations.push_back(&region.front());
+ rangeOperands.push_back(rangeDestinations.back()->getArguments());
+ break;
+ case cir::CaseOpKind::Anyof:
+ case cir::CaseOpKind::Equal:
+ // AnyOf cases kind can have multiple values, hence the loop below.
+ for (auto &value : caseOp.getValue()) {
+ caseValues.push_back(cast<cir::IntAttr>(value).getValue());
+ caseDestinations.push_back(&region.front());
+ caseOperands.push_back(caseDestinations.back()->getArguments());
+ }
+ break;
+ }
+
+ // Handle break statements.
+ walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
+ region, [&](mlir::Operation *op) {
+ if (!isa<cir::BreakOp>(op))
+ return mlir::WalkResult::advance();
+
+ lowerTerminator(op, exitBlock, rewriter);
+ return mlir::WalkResult::skip();
+ });
+
+ // Track fallthrough in cases.
+ for (auto &blk : region.getBlocks()) {
+ if (blk.getNumSuccessors())
+ continue;
+
+ if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
+ mlir::Operation *nextOp = caseOp->getNextNode();
+ assert(nextOp && "caseOp is not expected to be the last op");
+ mlir::Block *oldBlock = nextOp->getBlock();
+ mlir::Block *newBlock =
+ rewriter.splitBlock(oldBlock, nextOp->getIterator());
+ rewriter.setInsertionPointToEnd(oldBlock);
+ rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(),
+ newBlock);
+ rewriteYieldOp(rewriter, yieldOp, newBlock);
+ }
+ }
+
+ mlir::Block *oldBlock = caseOp->getBlock();
+ mlir::Block *newBlock =
+ rewriter.splitBlock(oldBlock, caseOp->getIterator());
+
+ mlir::Block &entryBlock = caseOp.getCaseRegion().front();
+ rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
+
+ // Create a branch to the entry of the inlined region.
+ rewriter.setInsertionPointToEnd(oldBlock);
+ rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock);
+ }
+
+ // Remove all cases since we've inlined the regions.
+ for (auto caseOp : cases) {
+ mlir::Block *caseBlock = caseOp->getBlock();
+ // Erase the block with no predecessors here to make the generated code
+ // simpler a little bit.
+ if (caseBlock->hasNoPredecessors())
+ rewriter.eraseBlock(caseBlock);
+ else
+ rewriter.eraseOp(caseOp);
+ }
+
+ for (size_t index = 0; index < rangeValues.size(); ++index) {
+ APInt lowerBound = rangeValues[index].first;
+ APInt upperBound = rangeValues[index].second;
+
+ // The case range is unreachable, skip it.
+ if (lowerBound.sgt(upperBound))
+ continue;
+
+ // Expand small ranges into individual cases; the threshold is from CGStmt.
+ constexpr uint64_t kSmallRangeThreshold = 64;
+ APInt rangeSpan = upperBound - lowerBound;
+ if (rangeSpan.ult(kSmallRangeThreshold)) { // uint64_t overload: any width
+ // Count iterations: `iValue.sle(upperBound)` never fails at the max value.
+ APInt iValue = lowerBound;
+ for (uint64_t i = 0, e = rangeSpan.getZExtValue(); i <= e; ++i, ++iValue) {
+ caseValues.push_back(iValue);
+ caseOperands.push_back(rangeOperands[index]);
+ caseDestinations.push_back(rangeDestinations[index]);
+ }
+ continue;
+ }
+
+ defaultDestination =
+ condBrToRangeDestination(op, rewriter, rangeDestinations[index],
+ defaultDestination, lowerBound, upperBound);
+ defaultOperands = defaultDestination->getArguments(); // new block: no args
+ }
+
+ // Set switch op to branch to the newly created blocks.
+ rewriter.setInsertionPoint(op);
+ rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
+ op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
+ caseDestinations, caseOperands);
+
+ return mlir::success();
+ }
+};
+
class CIRLoopOpInterfaceFlattening
: public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
public:
@@ -306,9 +532,10 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
};
+// Register the CFG-flattening rewrite patterns (if, loop, scope, switch and
+// ternary) in `patterns`.
void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
- patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
- CIRScopeOpFlattening, CIRTernaryOpFlattening>(
- patterns.getContext());
+ patterns
+ .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
+ CIRSwitchOpFlattening, CIRTernaryOpFlattening>(
+ patterns.getContext());
}
void CIRFlattenCFGPass::runOnOperation() {
@@ -321,7 +548,7 @@ void CIRFlattenCFGPass::runOnOperation() {
assert(!cir::MissingFeatures::ifOp());
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::tryOp());
- if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
+ if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/IR/switch-flat.cir b/clang/test/CIR/IR/switch-flat.cir
new file mode 100644
index 0000000000000..b072c224b4a2c
--- /dev/null
+++ b/clang/test/CIR/IR/switch-flat.cir
@@ -0,0 +1,68 @@
+// RUN: cir-opt %s | FileCheck %s
+!s32i = !cir.int<s, 32>
+
+// Parse/print round trip of cir.switch.flat: a single case, a default block
+// distinct from the exit block, and cases that forward block operands.
+
+cir.func @FlatSwitchWithoutDefault(%arg0: !s32i) {
+ cir.switch.flat %arg0 : !s32i, ^bb2 [
+ 1: ^bb1
+ ]
+ ^bb1:
+ cir.br ^bb2
+ ^bb2:
+ cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [
+// CHECK-NEXT: 1: ^bb1
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: cir.br ^bb2
+// CHECK-NEXT: ^bb2:
+//CHECK-NEXT: cir.return
+
+cir.func @FlatSwitchWithDefault(%arg0: !s32i) {
+ cir.switch.flat %arg0 : !s32i, ^bb2 [
+ 1: ^bb1
+ ]
+ ^bb1:
+ cir.br ^bb3
+ ^bb2:
+ cir.br ^bb3
+ ^bb3:
+ cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [
+// CHECK-NEXT: 1: ^bb1
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT: cir.return
+
+// Each case destination receives its own operand list, printed per case.
+cir.func @switchWithOperands(%arg0: !s32i, %arg1: !s32i, %arg2: !s32i) {
+ cir.switch.flat %arg0 : !s32i, ^bb3 [
+ 0: ^bb1(%arg1, %arg2 : !s32i, !s32i),
+ 1: ^bb2(%arg2, %arg1 : !s32i, !s32i)
+ ]
+^bb1:
+ cir.br ^bb3
+
+^bb2:
+ cir.br ^bb3
+
+^bb3:
+ cir.return
+}
+
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb3 [
+// CHECK-NEXT: 0: ^bb1(%arg1, %arg2 : !s32i, !s32i),
+// CHECK-NEXT: 1: ^bb2(%arg2, %arg1 : !s32i, !s32i)
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb2:
+// CHECK-NEXT: cir.br ^bb3
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT: cir.return
diff --git a/clang/test/CIR/Transforms/switch.cir b/clang/test/CIR/Transforms/switch.cir
new file mode 100644
index 0000000000000..a05cf37e39728
--- /dev/null
+++ b/clang/test/CIR/Transforms/switch.cir
@@ -0,0 +1,278 @@
+// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s
+
+!s8i = !cir.int<s, 8>
+!s32i = !cir.int<s, 32>
+!s64i = !cir.int<s, 64>
+
+module {
+ cir.func @shouldFlatSwitchWithDefault(%arg0: !s8i) {
+ cir.switch (%arg0 : !s8i) {
+ cir.case (equal, [#cir.int<1> : !s8i]) {
+ cir.break
+ }
+ cir.case (default, []) {
+ cir.break
+ }
+ cir.yield
+ }
+ cir.return
+ }
+// CHECK: cir.func @shouldFlatSwitchWithDefault(%arg0: !s8i) {
+// CHECK: cir.switch.flat %arg0 : !s8i, ^bb[[#DEFAULT:]] [
+// CHECK: 1: ^bb[[#CASE1:]]
+// CHECK: ]
+// CHECK: ^bb[[#CASE1]]:
+// CHECK: cir.br ^bb[[#EXIT:]]
+// CHECK: ^bb[[#DEFAULT]]:
+// CHECK: cir.br ^bb[[#EXIT]]
+// CHECK: ^bb[[#EXIT]]:
+// CHECK: cir.return
+// CHECK: }
+
+ cir.func @shouldFlatSwitchWithoutDefault(%arg0: !s32i) {
+ cir.switch (%arg0 : !s32i) {
+ cir.case (equal, [#cir.int<1> : !s32i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ cir.return
+ }
+// CHECK: cir.func @shouldFlatSwitchWithoutDefault(%arg0: !s32i) {
+// CHECK: cir.switch.flat %arg0 : !s32i, ^bb[[#EXIT:]] [
+// CHECK: 1: ^bb[[#CASE1:]]
+// CHECK: ]
+// CHECK: ^bb[[#CASE1]]:
+// CHECK: cir.br ^bb[[#EXIT]]
+// CHECK: ^bb[[#EXIT]]:
+// CHECK: cir.return
+// CHECK: }
+
+ // An anyof case lists several values that share one case region.
+ cir.func @shouldFlatSwitchWithImplicitFallthrough(%arg0: !s64i) {
+ cir.switch (%arg0 : !s64i) {
+ cir.case (anyof, [#cir.int<1> : !s64i, #cir.int<2> : !s64i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ cir.return
+ }
+// CHECK: cir.func @shouldFlatSwitchWithImplicitFallthrough(%arg0: !s64i) {
+// CHECK: cir.switch.flat %arg0 : !s64i, ^bb[[#EXIT:]] [
+// CHECK: 1: ^bb[[#CASE1N2:]],
+// CHECK: 2: ^bb[[#CASE1N2]]
+// CHECK: ]
+// CHECK: ^bb[[#CASE1N2]]:
+// CHECK: cir.br ^bb[[#EXIT]]
+// CHECK: ^bb[[#EXIT]]:
+// CHECK: cir.return
+// CHECK: }
+
+
+ // A case ending in cir.yield falls through into the next case.
+ cir.func @shouldFlatSwitchWithExplicitFallthrough(%arg0: !s64i) {
+ cir.switch (%arg0 : !s64i) {
+ cir.case (equal, [#cir.int<1> : !s64i]) { // case 1 has its own region
+ cir.yield // fallthrough to case 2
+ }
+ cir.case (equal, [#cir.int<2> : !s64i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ cir.return
+ }
+// CHECK: cir.func @shouldFlatSwitchWithExplicitFallthrough(%arg0: !s64i) {
+// CHECK: cir.switch.flat %arg0 : !s64i, ^bb[[#EXIT:]] [
+// CHECK: 1: ^bb[[#CASE1:]],
+// CHECK: 2: ^bb[[#CASE2:]]
+// CHECK: ]
+// CHECK: ^bb[[#CASE1]]:
+// CHECK: cir.br ^bb[[#CASE2]]
+// CHECK: ^bb[[#CASE2]]:
+// CHECK: cir.br ^bb[[#EXIT]]
+// CHECK: ^bb[[#EXIT]]:
+// CHECK: cir.return
+// CHECK: }
+
+ cir.func @shouldFlatSwitchWithFallthroughToExit(%arg0: !s64i) {
+ cir.switch (%arg0 : !s64i) {
+ cir.case (equal, [#cir.int<1> : !s64i]) {
+ cir.yield // fallthrough to exit
+ }
+ cir.yield
+ }
+ cir.return
+ }
+// CHECK: cir.func @shouldFlatSwitchWithFallthroughToExit(%arg0: !s64i) {
+// CHECK: cir.switch.flat %arg0 : !s64i, ^bb[[#EXIT:]] [
+// CHECK: 1: ^bb[[#CASE1:]]
+// CHECK: ]
+// CHECK: ^bb[[#CASE1]]:
+// CHECK: cir.br ^bb[[#EXIT]]
+// CHECK: ^bb[[#EXIT]]:
+// CHECK: cir.return
+// CHECK: }
+
+ cir.func @shouldDropEmptySwitch(%arg0: !s64i) {
+ cir.switch (%arg0 : !s64i) {
+ cir.yield
+ }
+ // No cases remain, so the pass erases the switch outright.
+ cir.return
+ }
+// CHECK: cir.func @shouldDropEmptySwitch(%arg0: !s64i)
+// CHECK-NOT: cir.switch
+
+ // A case region made of several blocks.
+ cir.func @shouldFlatMultiBlockCase(%arg0: !s32i) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case (equal, [#cir.int<3> : !s32i]) {
+ cir.return
+ ^bb1: // no predecessors
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+
+// CHECK: cir.func @shouldFlatMultiBlockCase(%arg0: !s32i) {
+// CHECK: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+// CHECK: cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+// CHECK: cir.br ^bb1
+// CHECK: ^bb1: // pred: ^bb0
+// CHECK: %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+// CHECK: cir.switch.flat %1 : !s32i, ^bb[[#DEFAULT:]] [
+// CHECK: 3: ^bb[[#BB1:]]
+// CHECK: ]
+// CHECK: ^bb[[#BB1]]:
+// CHECK: cir.return
+// CHECK: ^bb[[#DEFAULT]]:
+// CHECK: cir.br ^bb[[#RET_BB:]]
+// CHECK: ^bb[[#RET_BB]]: // pred: ^bb[[#DEFAULT]]
+// CHECK: cir.return
+// CHECK: }
+
+ // A cir.break nested inside cir.scope/cir.if still exits the switch.
+ cir.func @shouldFlatNestedBreak(%arg0: !s32i, %arg1: !s32i) -> !s32i {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64}
+ %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+ %2 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.store %arg1, %1 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %5 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%5 : !s32i) {
+ cir.case (equal, [#cir.int<0> : !s32i]) {
+ cir.scope {
+ %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+ %7 = cir.const #cir.int<0> : !s32i
+ %8 = cir.cmp(ge, %6, %7) : !s32i, !cir.bool
+ cir.if %8 {
+ cir.break
+ }
+ }
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ %3 = cir.const #cir.int<3> : !s32i
+ cir.store %3, %2 : !s32i, !cir.ptr<!s32i>
+ %4 = cir.load %2 : !cir.ptr<!s32i>, !s32i
+ cir.return %4 : !s32i
+ }
+// CHECK: cir.func @shouldFlatNestedBreak(%arg0: !s32i, %arg1: !s32i) -> !s32i {
+// CHECK: cir.switch.flat %[[COND:.*]] : !s32i, ^bb[[#DEFAULT_BB:]] [
+// CHECK: 0: ^bb[[#BB1:]]
+// CHECK: ]
+// CHECK: ^bb[[#BB1]]:
+// CHECK: cir.br ^bb[[#COND_BB:]]
+// CHECK: ^bb[[#COND_BB]]:
+// CHECK: cir.brcond {{%.*}} ^bb[[#TRUE_BB:]], ^bb[[#FALSE_BB:]]
+// CHECK: ^bb[[#TRUE_BB]]:
+// CHECK: cir.br ^bb[[#DEFAULT_BB]]
+// CHECK: ^bb[[#FALSE_BB]]:
+// CHECK: cir.br ^bb[[#PRED_BB:]]
+// CHECK: ^bb[[#PRED_BB]]:
+// CHECK: cir.br ^bb[[#DEFAULT_BB]]
+// CHECK: ^bb[[#DEFAULT_BB]]:
+// CHECK: cir.br ^bb[[#RET_BB:]]
+// CHECK: ^bb[[#RET_BB]]:
+// CHECK: cir.return
+// CHECK: }
+
+ // The range [1, 100] is too wide to expand, so it becomes a range check.
+ cir.func @flatCaseRange(%arg0: !s32i) -> !s32i {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x", init] {alignment = 4 : i64}
+ %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+ %2 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ %3 = cir.const #cir.int<0> : !s32i
+ cir.store %3, %2 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %6 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%6 : !s32i) {
+ cir.case (equal, [#cir.int<-100> : !s32i]) {
+ %7 = cir.const #cir.int<1> : !s32i
+ cir.store %7, %2 : !s32i, !cir.ptr<!s32i>
+ cir.break
+ }
+ cir.case (range, [#cir.int<1> : !s32i, #cir.int<100> : !s32i]) {
+ %7 = cir.const #cir.int<2> : !s32i
+ cir.store %7, %2 : !s32i, !cir.ptr<!s32i>
+ cir.break
+ }
+ cir.case (default, []) {
+ %7 = cir.const #cir.int<3> : !s32i
+ cir.store %7, %2 : !s32i, !cir.ptr<!s32i>
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ %4 = cir.load %2 : !cir.ptr<!s32i>, !s32i
+ cir.store %4, %1 : !s32i, !cir.ptr<!s32i>
+ %5 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+ cir.return %5 : !s32i
+ }
+// CHECK: cir.func @flatCaseRange(%arg0: !s32i) -> !s32i {
+// CHECK: cir.switch.flat %[[X:[0-9]+]] : !s32i, ^[[JUDGE_RANGE:bb[0-9]+]] [
+// CHECK-NEXT: -100: ^[[CASE_EQUAL:bb[0-9]+]]
+// CHECK-NEXT: ]
+// CHECK-NEXT: ^[[UNREACHABLE_BB:.+]]: // no predecessors
+// CHECK-NEXT: cir.br ^[[CASE_EQUAL]]
+// CHECK-NEXT: ^[[CASE_EQUAL]]:
+// CHECK-NEXT: cir.int<1>
+// CHECK-NEXT: cir.store
+// CHECK-NEXT: cir.br ^[[EPILOG:bb[0-9]+]]
+// CHECK-NEXT: ^[[CASE_RANGE:bb[0-9]+]]:
+// CHECK-NEXT: cir.int<2>
+// CHECK-NEXT: cir.store
+// CHECK-NEXT: cir.br ^[[EPILOG]]
+// CHECK-NEXT: ^[[JUDGE_RANGE]]:
+// CHECK-NEXT: %[[RANGE:[0-9]+]] = cir.const #cir.int<99>
+// CHECK-NEXT: %[[LOWER_BOUND:[0-9]+]] = cir.const #cir.int<1>
+// CHECK-NEXT: %[[DIFF:[0-9]+]] = cir.binop(sub, %[[X]], %[[LOWER_BOUND]])
+// CHECK-NEXT: %[[U_DIFF:[0-9]+]] = cir.cast(integral, %[[DIFF]] : !s32i), !u32i
+// CHECK-NEXT: %[[U_RANGE:[0-9]+]] = cir.cast(integral, %[[RANGE]] : !s32i), !u32i
+// CHECK-NEXT: %[[CMP_RESULT:[0-9]+]] = cir.cmp(le, %[[U_DIFF]], %[[U_RANGE]])
+// CHECK-NEXT: cir.brcond %[[CMP_RESULT]] ^[[CASE_RANGE]], ^[[CASE_DEFAULT:bb[0-9]+]]
+// CHECK-NEXT: ^[[CASE_DEFAULT]]:
+// CHECK-NEXT: cir.int<3>
+// CHECK-NEXT: cir.store
+// CHECK-NEXT: cir.br ^[[EPILOG]]
+// CHECK-NEXT: ^[[EPILOG]]:
+// CHECK-NEXT: cir.br ^[[EPILOG_END:bb[0-9]+]]
+// CHECK-NEXT: ^[[EPILOG_END]]:
+// CHECK: cir.return
+// CHECK: }
+
+}
>From 75f6761c52cb7affa1e4b80d3e23a1d56b54a1c7 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Tue, 13 May 2025 20:46:38 -0500
Subject: [PATCH 2/7] Remove auto, add log-range test, and end-to-end test for
switch flat op
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 6 +-
.../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 17 ++--
clang/test/CIR/CodeGen/switch_flat_op.cpp | 81 +++++++++++++++++++
clang/test/CIR/IR/switch-flat.cir | 2 +-
clang/test/CIR/Transforms/switch.cir | 40 +++++++++
5 files changed, 134 insertions(+), 12 deletions(-)
create mode 100644 clang/test/CIR/CodeGen/switch_flat_op.cpp
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 87b9823da6ddc..71b9a816669bc 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -981,7 +981,7 @@ def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments,
let description = [{
The `cir.switch.flat` operation is a region-less and simplified
version of the `cir.switch`.
- It's representation is closer to LLVM IR dialect
+ Its representation is closer to LLVM IR dialect
than the C/C++ language feature.
}];
@@ -989,7 +989,7 @@ def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments,
CIR_IntType:$condition,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
- ArrayAttr:$case_values,
+ ArrayAttr:$caseValues,
DenseI32ArrayAttr:$case_operand_segments
);
@@ -1001,7 +1001,7 @@ def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments,
let assemblyFormat = [{
$condition `:` type($condition) `,`
$defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
- custom<SwitchFlatOpCases>(ref(type($condition)), $case_values,
+ custom<SwitchFlatOpCases>(ref(type($condition)), $caseValues,
$caseDestinations, $caseOperands,
type($caseOperands))
attr-dict
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 70f383b556567..46e25719abafb 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -247,7 +247,8 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
{
cir::YieldOp switchYield = nullptr;
// Clear switch operation.
- for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
+ for (mlir::Block &block :
+ llvm::make_early_inc_range(op.getBody().getBlocks()))
if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
switchYield = yieldOp;
@@ -279,7 +280,7 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
mlir::ValueRange defaultOperands = exitBlock->getArguments();
// Digest the case statements values and bodies.
- for (auto caseOp : cases) {
+ for (cir::CaseOp caseOp : cases) {
mlir::Region ®ion = caseOp.getCaseRegion();
// Found default case: save destination and operands.
@@ -300,7 +301,7 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
case cir::CaseOpKind::Anyof:
case cir::CaseOpKind::Equal:
// AnyOf cases kind can have multiple values, hence the loop below.
- for (auto &value : caseOp.getValue()) {
+ for (const mlir::Attribute &value : caseOp.getValue()) {
caseValues.push_back(cast<cir::IntAttr>(value).getValue());
caseDestinations.push_back(®ion.front());
caseOperands.push_back(caseDestinations.back()->getArguments());
@@ -319,7 +320,7 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
});
// Track fallthrough in cases.
- for (auto &blk : region.getBlocks()) {
+ for (mlir::Block &blk : region.getBlocks()) {
if (blk.getNumSuccessors())
continue;
@@ -349,7 +350,7 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
}
// Remove all cases since we've inlined the regions.
- for (auto caseOp : cases) {
+ for (cir::CaseOp caseOp : cases) {
mlir::Block *caseBlock = caseOp->getBlock();
// Erase the block with no predecessors here to make the generated code
// simpler a little bit.
@@ -359,9 +360,9 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
rewriter.eraseOp(caseOp);
}
- for (size_t index = 0; index < rangeValues.size(); ++index) {
- APInt lowerBound = rangeValues[index].first;
- APInt upperBound = rangeValues[index].second;
+ for (auto [index, rangeVal] : llvm::enumerate(rangeValues)) {
+ APInt lowerBound = rangeVal.first;
+ APInt upperBound = rangeVal.second;
// The case range is unreachable, skip it.
if (lowerBound.sgt(upperBound))
diff --git a/clang/test/CIR/CodeGen/switch_flat_op.cpp b/clang/test/CIR/CodeGen/switch_flat_op.cpp
new file mode 100644
index 0000000000000..e6ed1db2c9e19
--- /dev/null
+++ b/clang/test/CIR/CodeGen/switch_flat_op.cpp
@@ -0,0 +1,81 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
+// RUN: cir-opt --mlir-print-ir-before=cir-flatten-cfg --cir-flatten-cfg %t.cir -o %t.flattened.before.cir 2> %t.before
+// RUN: FileCheck --input-file=%t.before %s --check-prefix=BEFORE
+// RUN: cir-opt --mlir-print-ir-after=cir-flatten-cfg --cir-flatten-cfg %t.cir -o %t.flattened.after.cir 2> %t.after
+// RUN: FileCheck --input-file=%t.after %s --check-prefix=AFTER
+
+
+
+
+
+void swf(int a) {
+ switch (int b = 3; a) {
+ case 3:
+ b = b * 2; // single-value case: emitted as cir.case(equal, [3])
+ break;
+ case 4 ... 5: // GNU case range: emitted as cir.case(range, [4, 5]); expanded to values 4 and 5 after flattening
+ b = b * 3;
+ break;
+ default: // becomes the default destination of cir.switch.flat after flattening
+ break;
+ }
+
+}
+
+// BEFORE: cir.func @_Z3swfi
+// BEFORE: %[[VAR_B:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] {alignment = 4 : i64}
+// BEFORE: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// BEFORE: cir.switch (%[[COND:.*]] : !s32i) {
+// BEFORE: cir.case(equal, [#cir.int<3> : !s32i]) {
+// BEFORE: %[[LOAD_B_EQ:.*]] = cir.load %[[VAR_B]] : !cir.ptr<!s32i>, !s32i
+// BEFORE: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// BEFORE: %[[MUL_EQ:.*]] = cir.binop(mul, %[[LOAD_B_EQ]], %[[CONST_2]]) nsw : !s32i
+// BEFORE: cir.store %[[MUL_EQ]], %[[VAR_B]] : !s32i, !cir.ptr<!s32i>
+// BEFORE: cir.break
+// BEFORE: }
+// BEFORE: cir.case(range, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+// BEFORE: %[[LOAD_B_RANGE:.*]] = cir.load %[[VAR_B]] : !cir.ptr<!s32i>, !s32i
+// BEFORE: %[[CONST_3_RANGE:.*]] = cir.const #cir.int<3> : !s32i
+// BEFORE: %[[MUL_RANGE:.*]] = cir.binop(mul, %[[LOAD_B_RANGE]], %[[CONST_3_RANGE]]) nsw : !s32i
+// BEFORE: cir.store %[[MUL_RANGE]], %[[VAR_B]] : !s32i, !cir.ptr<!s32i>
+// BEFORE: cir.break
+// BEFORE: }
+// BEFORE: cir.case(default, []) {
+// BEFORE: cir.break
+// BEFORE: }
+// BEFORE: cir.yield
+// BEFORE: }
+// BEFORE: }
+// BEFORE: cir.return
+
+// AFTER: cir.func @_Z3swfi
+// AFTER: %[[VAR_A:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+// AFTER: cir.store %arg0, %[[VAR_A]] : !s32i, !cir.ptr<!s32i>
+// AFTER: %[[VAR_B:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] {alignment = 4 : i64}
+// AFTER: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// AFTER: cir.store %[[CONST_3]], %[[VAR_B]] : !s32i, !cir.ptr<!s32i>
+// AFTER: cir.switch.flat %[[COND:.*]] : !s32i, ^bb[[#BB6:]] [
+// AFTER: 3: ^bb[[#BB4:]],
+// AFTER: 4: ^bb[[#BB5:]],
+// AFTER: 5: ^bb[[#BB5]]
+// AFTER: ]
+// AFTER: ^bb[[#BB4]]:
+// AFTER: %[[LOAD_B_EQ:.*]] = cir.load %[[VAR_B]] : !cir.ptr<!s32i>, !s32i
+// AFTER: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// AFTER: %[[MUL_EQ:.*]] = cir.binop(mul, %[[LOAD_B_EQ]], %[[CONST_2]]) nsw : !s32i
+// AFTER: cir.store %[[MUL_EQ]], %[[VAR_B]] : !s32i, !cir.ptr<!s32i>
+// AFTER: cir.br ^bb[[#BB7:]]
+// AFTER: ^bb[[#BB5]]:
+// AFTER: %[[LOAD_B_RANGE:.*]] = cir.load %[[VAR_B]] : !cir.ptr<!s32i>, !s32i
+// AFTER: %[[CONST_3_AGAIN:.*]] = cir.const #cir.int<3> : !s32i
+// AFTER: %[[MUL_RANGE:.*]] = cir.binop(mul, %[[LOAD_B_RANGE]], %[[CONST_3_AGAIN]]) nsw : !s32i
+// AFTER: cir.store %[[MUL_RANGE]], %[[VAR_B]] : !s32i, !cir.ptr<!s32i>
+// AFTER: cir.br ^bb[[#BB7]]
+// AFTER: ^bb[[#BB6]]:
+// AFTER: cir.br ^bb[[#BB7]]
+// AFTER: ^bb[[#BB7]]:
+// AFTER: cir.br ^bb[[#BB8:]]
+// AFTER: ^bb[[#BB8]]:
+// AFTER: cir.return
+// AFTER: }
+
diff --git a/clang/test/CIR/IR/switch-flat.cir b/clang/test/CIR/IR/switch-flat.cir
index b072c224b4a2c..8c11a74484d39 100644
--- a/clang/test/CIR/IR/switch-flat.cir
+++ b/clang/test/CIR/IR/switch-flat.cir
@@ -17,7 +17,7 @@ cir.func @FlatSwitchWithoutDefault(%arg0: !s32i) {
// CHECK-NEXT: ^bb1:
// CHECK-NEXT: cir.br ^bb2
// CHECK-NEXT: ^bb2:
-//CHECK-NEXT: cir.return
+// CHECK-NEXT: cir.return
cir.func @FlatSwitchWithDefault(%arg0: !s32i) {
cir.switch.flat %arg0 : !s32i, ^bb2 [
diff --git a/clang/test/CIR/Transforms/switch.cir b/clang/test/CIR/Transforms/switch.cir
index a05cf37e39728..00b462a6075c9 100644
--- a/clang/test/CIR/Transforms/switch.cir
+++ b/clang/test/CIR/Transforms/switch.cir
@@ -275,4 +275,44 @@ module {
// CHECK: cir.return
// CHECK: }
+ cir.func @_Z8bigRangei(%arg0: !s32i) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(range, [#cir.int<3> : !s32i, #cir.int<100> : !s32i]) {
+ cir.break
+ }
+ cir.case(default, []) {
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+
+// CHECK: cir.func @_Z8bigRangei(%arg0: !s32i) {
+// CHECK: cir.switch.flat %[[COND:.*]] : !s32i, ^bb[[#RANGE_BR:]] [
+// CHECK: ]
+// CHECK: ^bb[[#NO_PRED_BB:]]: // no predecessors
+// CHECK: cir.br ^bb[[#DEFAULT_BB:]]
+// CHECK: ^bb[[#DEFAULT_BB]]: // 2 preds: ^bb[[#NO_PRED_BB]], ^bb[[#RANGE_BR]]
+// CHECK: cir.br ^bb[[#EXIT:]]
+// CHECK: ^bb[[#RANGE_BR]]: // pred: ^bb[[#BB2:]]
+// CHECK: %[[CONST97:.*]] = cir.const #cir.int<97> : !s32i
+// CHECK: %[[CONST3:.*]] = cir.const #cir.int<3> : !s32i
+// CHECK: %[[SUB:.*]] = cir.binop(sub, %[[COND]], %[[CONST3]]) : !s32i
+// CHECK: %[[CAST1:.*]] = cir.cast(integral, %[[SUB]] : !s32i), !u32i
+// CHECK: %[[CAST2:.*]] = cir.cast(integral, %[[CONST97]] : !s32i), !u32i
+// CHECK: %[[CMP:.*]] = cir.cmp(le, %[[CAST1]], %[[CAST2]]) : !u32i, !cir.bool
+// CHECK: cir.brcond %[[CMP]] ^bb[[#DEFAULT_BB]], ^bb[[#RANGE_BB:]]
+// CHECK: ^bb[[#RANGE_BB]]: // pred: ^bb[[#RANGE_BR]]
+// CHECK: cir.br ^bb[[#EXIT]]
+// CHECK: ^bb[[#EXIT]]: // 2 preds: ^bb[[#DEFAULT_BB]], ^bb[[#RANGE_BB]]
+// CHECK: cir.br ^bb[[#RET_BB:]]
+// CHECK: ^bb[[#RET_BB]]: // pred: ^bb[[#EXIT]]
+// CHECK: cir.return
+// CHECK: }
}
>From 69466b624de06e36a074fd00fa910db777952396 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Tue, 13 May 2025 20:48:26 -0500
Subject: [PATCH 3/7] Fix formatting for switch_flat_op
---
clang/test/CIR/CodeGen/switch_flat_op.cpp | 4 ----
1 file changed, 4 deletions(-)
diff --git a/clang/test/CIR/CodeGen/switch_flat_op.cpp b/clang/test/CIR/CodeGen/switch_flat_op.cpp
index e6ed1db2c9e19..a9fc095025eb0 100644
--- a/clang/test/CIR/CodeGen/switch_flat_op.cpp
+++ b/clang/test/CIR/CodeGen/switch_flat_op.cpp
@@ -4,10 +4,6 @@
// RUN: cir-opt --mlir-print-ir-after=cir-flatten-cfg --cir-flatten-cfg %t.cir -o %t.flattened.after.cir 2> %t.after
// RUN: FileCheck --input-file=%t.after %s --check-prefix=AFTER
-
-
-
-
void swf(int a) {
switch (int b = 3; a) {
case 3:
>From 3df59fdb88563135a291c9540a0b1af01a49373d Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Tue, 13 May 2025 20:54:04 -0500
Subject: [PATCH 4/7] remove auto keyword
---
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index dd7ee4a2c1adf..9ca831c3efd19 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -22,6 +22,7 @@
#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"
#include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc"
#include "clang/CIR/MissingFeatures.h"
+#include "llvm/ADT/APInt.h"
#include <numeric>
using namespace mlir;
@@ -975,7 +976,7 @@ void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<ValueRange> caseOperands) {
std::vector<mlir::Attribute> caseValuesAttrs;
- for (auto &val : caseValues) {
+ for (const APInt &val : caseValues) {
caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
}
mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
>From 494e1d1193cd77648744a195ecf2327f946be3f9 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Wed, 14 May 2025 11:12:18 -0500
Subject: [PATCH 5/7] change enumerate to zip
---
clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 46e25719abafb..71a45d3c84eea 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -360,7 +360,8 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
rewriter.eraseOp(caseOp);
}
- for (auto [index, rangeVal] : llvm::enumerate(rangeValues)) {
+ for (auto [rangeVal, operand, destination] :
+ llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
APInt lowerBound = rangeVal.first;
APInt upperBound = rangeVal.second;
@@ -376,16 +377,16 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
for (APInt iValue = lowerBound; iValue.sle(upperBound);
(void)iValue++) {
caseValues.push_back(iValue);
- caseOperands.push_back(rangeOperands[index]);
- caseDestinations.push_back(rangeDestinations[index]);
+ caseOperands.push_back(operand);
+ caseDestinations.push_back(destination);
}
continue;
}
defaultDestination =
- condBrToRangeDestination(op, rewriter, rangeDestinations[index],
+ condBrToRangeDestination(op, rewriter, destination,
defaultDestination, lowerBound, upperBound);
- defaultOperands = rangeOperands[index];
+ defaultOperands = operand;
}
// Set switch op to branch to the newly created blocks.
>From 37c399916a42cdb9a704f462e369cb91cd1f4368 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Thu, 15 May 2025 15:34:04 -0500
Subject: [PATCH 6/7] switch to ++iValue, drop void cast & braces
---
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 3 +--
clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 3 +--
2 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 9ca831c3efd19..bdd9b2b04d909 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -976,9 +976,8 @@ void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<ValueRange> caseOperands) {
std::vector<mlir::Attribute> caseValuesAttrs;
- for (const APInt &val : caseValues) {
+ for (const APInt &val : caseValues)
caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val));
- }
mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs);
build(builder, result, value, defaultOperands, caseOperands, attrs,
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 71a45d3c84eea..26e5c0572f12e 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -374,8 +374,7 @@ class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
constexpr int kSmallRangeThreshold = 64;
if ((upperBound - lowerBound)
.ult(llvm::APInt(32, kSmallRangeThreshold))) {
- for (APInt iValue = lowerBound; iValue.sle(upperBound);
- (void)iValue++) {
+ for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
caseValues.push_back(iValue);
caseOperands.push_back(operand);
caseDestinations.push_back(destination);
>From 32d8c879771f0bcd7179e27f67327d5366776fc6 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Thu, 15 May 2025 17:32:59 -0500
Subject: [PATCH 7/7] Rebase branch
---
clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 027b08907346b..fb000adee04c6 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -140,7 +140,8 @@ void CIRCanonicalizePass::runOnOperation() {
assert(!cir::MissingFeatures::callOp());
// CastOp, UnaryOp and VecExtractOp are here to perform a manual `fold` in
// applyOpPatternsGreedily.
- if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp, VecExtractOp>(op))
+ if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
+ VecExtractOp>(op))
ops.push_back(op);
});
More information about the cfe-commits
mailing list