[clang] [CIR] Implement switch case simplify (PR #140649)
via cfe-commits
cfe-commits at lists.llvm.org
Wed May 21 09:31:28 PDT 2025
https://github.com/Andres-Salamanca updated https://github.com/llvm/llvm-project/pull/140649
>From c1403f148a58e259cc296310dc21b8c5611f2e82 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Mon, 19 May 2025 18:53:15 -0500
Subject: [PATCH 1/3] Implement CIR switch case simplify with appropriate tests
---
clang/include/clang/CIR/MissingFeatures.h | 1 -
clang/lib/CIR/CodeGen/CIRGenStmt.cpp | 6 -
.../CIR/Dialect/Transforms/CIRSimplify.cpp | 106 +++++++++-
clang/test/CIR/Transforms/switch-fold.cir | 196 ++++++++++++++++++
4 files changed, 300 insertions(+), 9 deletions(-)
create mode 100644 clang/test/CIR/Transforms/switch-fold.cir
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 484822c351746..9f3e5d007d66c 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -114,7 +114,6 @@ struct MissingFeatures {
static bool opUnaryPromotionType() { return false; }
// SwitchOp handling
- static bool foldCascadingCases() { return false; }
static bool foldRangeCase() { return false; }
// Clang early optimizations or things defered to LLVM lowering.
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index cc96e65e4ce1d..7f1ecbda414bd 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -531,12 +531,6 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal),
cir::IntAttr::get(condType, endVal)});
kind = cir::CaseOpKind::Range;
-
- // We don't currently fold case range statements with other case statements.
- // TODO(cir): Add this capability. Folding these cases is going to be
- // implemented in CIRSimplify when it is upstreamed.
- assert(!cir::MissingFeatures::foldRangeCase());
- assert(!cir::MissingFeatures::foldCascadingCases());
} else {
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)});
kind = cir::CaseOpKind::Equal;
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
index b969569b0081c..58300cc219602 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
}
};
+/// Simplify `cir.switch` operations by folding cascading cases
+/// into a single `cir.case` with the `anyof` kind.
+///
+/// This pattern identifies cascading cases within a `cir.switch` operation.
+/// Cascading cases are defined as consecutive `cir.case` operations of kind
+/// `equal`, each containing a single `cir.yield` operation in their body.
+///
+/// The pattern merges these cascading cases into a single `cir.case` operation
+/// with kind `anyof`, aggregating all the case values.
+///
+/// A run is folded into the following `equal` case that does more than yield
+/// (e.g. ends in `cir.break`); a run of two or more cases followed instead by
+/// a non-`equal` case or by the end of the switch folds into its last case.
+///
+/// Example:
+///
+/// Before:
+/// cir.case equal, [#cir.int<0> : !s32i] {
+/// cir.yield
+/// }
+/// cir.case equal, [#cir.int<1> : !s32i] {
+/// cir.yield
+/// }
+/// cir.case equal, [#cir.int<2> : !s32i] {
+/// cir.break
+/// }
+///
+/// After applying SimplifySwitch:
+/// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
+/// !s32i] {
+/// cir.break
+/// }
+struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
+ using OpRewritePattern<SwitchOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(SwitchOp op,
+ PatternRewriter &rewriter) const override {
+
+ LogicalResult changed = mlir::failure();
+ llvm::SmallVector<CaseOp, 8> cases;
+ SmallVector<CaseOp, 4> cascadingCases;
+ SmallVector<mlir::Attribute, 4> cascadingCaseValues;
+
+ op.collectCases(cases);
+ if (cases.empty())
+ return mlir::failure();
+
+ auto flushMergedOps = [&]() {
+ for (CaseOp &c : cascadingCases) {
+ rewriter.eraseOp(c);
+ }
+ cascadingCases.clear();
+ cascadingCaseValues.clear();
+ };
+
+ auto mergeCascadingInto = [&](CaseOp &target) {
+ rewriter.modifyOpInPlace(target, [&]() {
+ target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
+ target.setKind(CaseOpKind::Anyof);
+ });
+ changed = mlir::success();
+ };
+
+ for (CaseOp c : cases) {
+ cir::CaseOpKind kind = c.getKind();
+ if (kind == cir::CaseOpKind::Equal &&
+ isa<YieldOp>(c.getCaseRegion().front().front())) {
+ // If the case contains only a YieldOp, collect it for cascading merge
+ cascadingCases.push_back(c);
+ cascadingCaseValues.push_back(c.getValue()[0]);
+
+ } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
+ // merge previously collected cascading cases
+ cascadingCaseValues.push_back(c.getValue()[0]);
+ mergeCascadingInto(c);
+ flushMergedOps();
+ } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
+ // If a Default, Anyof or Range case is found and there are previous
+ // cascading cases, merge all of them into the last cascading case.
+ CaseOp lastCascadingCase = cascadingCases.back();
+ mergeCascadingInto(lastCascadingCase);
+ cascadingCases.pop_back();
+ flushMergedOps();
+ } else {
+ cascadingCases.clear();
+ cascadingCaseValues.clear();
+ }
+ }
+
+ // Fold a trailing run (2+ cases, possibly all of them) into its last case.
+ if (cascadingCases.size() > 1) {
+ CaseOp lastCascadingCase = cascadingCases.back();
+ mergeCascadingInto(lastCascadingCase);
+ cascadingCases.pop_back();
+ flushMergedOps();
+ }
+ // We don't currently fold case range statements with other case statements.
+ assert(!cir::MissingFeatures::foldRangeCase());
+ return changed;
+ }
+};
+
//===----------------------------------------------------------------------===//
// CIRSimplifyPass
//===----------------------------------------------------------------------===//
@@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
SimplifyTernary,
- SimplifySelect
+ SimplifySelect,
+ SimplifySwitch
>(patterns.getContext());
// clang-format on
}
@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() {
// Collect operations to apply patterns.
llvm::SmallVector<Operation *, 16> ops;
getOperation()->walk([&](Operation *op) {
- if (isa<TernaryOp, SelectOp>(op))
+ if (isa<TernaryOp, SelectOp, SwitchOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/Transforms/switch-fold.cir b/clang/test/CIR/Transforms/switch-fold.cir
new file mode 100644
index 0000000000000..3c2fe8a9cbf25
--- /dev/null
+++ b/clang/test/CIR/Transforms/switch-fold.cir
@@ -0,0 +1,196 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+// Runs of yield-only `equal` cases must be folded into `anyof` cases.
+module {
+ cir.func @foldCascade(%arg0: !s32i) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ %2 = cir.const #cir.int<2> : !s32i
+ cir.store %2, %0 : !s32i, !cir.ptr<!s32i>
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldCascade
+ //CHECK: cir.switch (%[[COND:.*]] : !s32i) {
+ //CHECK-NEXT: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i]) {
+ //CHECK-NEXT: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
+ //CHECK-NEXT: cir.store %[[TWO]], %[[ARG0:.*]] : !s32i, !cir.ptr<!s32i>
+ //CHECK-NEXT: cir.break
+ //CHECK-NEXT: }
+ //CHECK-NEXT: cir.yield
+ //CHECK-NEXT: }
+
+ cir.func @foldCascade2(%arg0: !s32i) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<0> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.break
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: @foldCascade2
+ //CHECK: cir.switch (%[[COND2:.*]] : !s32i) {
+ //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+ cir.func @foldCascade3(%arg0: !s32i ) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"] {alignment = 4 : i64}
+ %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%2 : !s32i) {
+ cir.case(equal, [#cir.int<0> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldCascade3
+ //CHECK: cir.switch (%[[COND3:.*]] : !s32i) {
+ //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+ cir.func @foldCascadeWithDefault(%arg0: !s32i ) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.break
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.yield
+ }
+ cir.case(default, []) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<6> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<7> : !s32i]) {
+ cir.break
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldCascadeWithDefault
+ //CHECK: cir.switch (%[[COND:.*]] : !s32i) {
+ //CHECK: cir.case(equal, [#cir.int<3> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.case(anyof, [#cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.yield
+ //CHECK: }
+ //CHECK: cir.case(default, []) {
+ //CHECK: cir.yield
+ //CHECK: }
+ //CHECK: cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i]) {
+ //CHECK: cir.break
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+ cir.func @foldAllCascade(%arg0: !s32i ) {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init] {alignment = 4 : i64}
+ cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+ cir.scope {
+ %1 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.switch (%1 : !s32i) {
+ cir.case(equal, [#cir.int<0> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.yield
+ }
+ cir.case(equal, [#cir.int<5> : !s32i]) {
+ cir.yield
+ }
+ cir.yield
+ }
+ }
+ cir.return
+ }
+ //CHECK: cir.func @foldAllCascade
+ //CHECK: cir.switch (%[[COND:.*]] : !s32i) {
+ //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.yield
+ //CHECK: }
+ //CHECK: cir.yield
+ //CHECK: }
+}
>From 0afcfd6b0f5cf414f846d2012ac4508e968199d8 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Wed, 21 May 2025 11:27:36 -0500
Subject: [PATCH 2/3] Apply reviews
---
clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp | 8 ++++----
clang/test/CIR/Transforms/switch-fold.cir | 12 ++++++------
2 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
index 58300cc219602..af064c0800fbe 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -197,7 +197,7 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
PatternRewriter &rewriter) const override {
LogicalResult changed = mlir::failure();
- llvm::SmallVector<CaseOp, 8> cases;
+ SmallVector<CaseOp, 8> cases;
SmallVector<CaseOp, 4> cascadingCases;
SmallVector<mlir::Attribute, 4> cascadingCaseValues;
@@ -228,7 +228,6 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
// If the case contains only a YieldOp, collect it for cascading merge
cascadingCases.push_back(c);
cascadingCaseValues.push_back(c.getValue()[0]);
-
} else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
// merge previously collected cascading cases
cascadingCaseValues.push_back(c.getValue()[0]);
@@ -237,6 +236,8 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
} else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
// If a Default, Anyof or Range case is found and there are previous
// cascading cases, merge all of them into the last cascading case.
+ // We don't currently fold case range statements with other case statements.
+ assert(!cir::MissingFeatures::foldRangeCase());
CaseOp lastCascadingCase = cascadingCases.back();
mergeCascadingInto(lastCascadingCase);
cascadingCases.pop_back();
@@ -254,8 +255,7 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
cascadingCases.pop_back();
flushMergedOps();
}
- // We don't currently fold case range statements with other case statements.
- assert(!cir::MissingFeatures::foldRangeCase());
+
return changed;
}
};
diff --git a/clang/test/CIR/Transforms/switch-fold.cir b/clang/test/CIR/Transforms/switch-fold.cir
index 3c2fe8a9cbf25..62a94f4fde2c3 100644
--- a/clang/test/CIR/Transforms/switch-fold.cir
+++ b/clang/test/CIR/Transforms/switch-fold.cir
@@ -45,16 +45,16 @@ module {
cir.case(equal, [#cir.int<0> : !s32i]) {
cir.yield
}
- cir.case(equal, [#cir.int<1> : !s32i]) {
+ cir.case(equal, [#cir.int<2> : !s32i]) {
cir.yield
}
- cir.case(equal, [#cir.int<2> : !s32i]) {
+ cir.case(equal, [#cir.int<4> : !s32i]) {
cir.break
}
- cir.case(equal, [#cir.int<3> : !s32i]) {
+ cir.case(equal, [#cir.int<1> : !s32i]) {
cir.yield
}
- cir.case(equal, [#cir.int<4> : !s32i]) {
+ cir.case(equal, [#cir.int<3> : !s32i]) {
cir.yield
}
cir.case(equal, [#cir.int<5> : !s32i]) {
@@ -67,10 +67,10 @@ module {
}
//CHECK: @foldCascade2
//CHECK: cir.switch (%[[COND2:.*]] : !s32i) {
- //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i]) {
+ //CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<2> : !s32i, #cir.int<4> : !s32i]) {
//CHECK: cir.break
//cehck: }
- //CHECK: cir.case(anyof, [#cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
+ //CHECK: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i]) {
//CHECK: cir.break
//CHECK: }
//CHECK: cir.yield
>From e46e3abe21df1214b535fbe1ba7d1037b49b05e0 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Wed, 21 May 2025 11:31:09 -0500
Subject: [PATCH 3/3] Fix formatting
---
clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
index af064c0800fbe..40716f2467563 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -236,7 +236,8 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
} else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
// If a Default, Anyof or Range case is found and there are previous
// cascading cases, merge all of them into the last cascading case.
- // We don't currently fold case range statements with other case statements.
+ // We don't currently fold case range statements with other case
+ // statements.
assert(!cir::MissingFeatures::foldRangeCase());
CaseOp lastCascadingCase = cascadingCases.back();
mergeCascadingInto(lastCascadingCase);
More information about the cfe-commits
mailing list