[clang] [CIR] Upstream initial support for switch statements (PR #137106)

via cfe-commits cfe-commits at lists.llvm.org
Mon Apr 28 11:28:19 PDT 2025


https://github.com/Andres-Salamanca updated https://github.com/llvm/llvm-project/pull/137106

>From f1f56e16d524783c69016867fcdf474ac3e4e09f Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Tue, 22 Apr 2025 15:16:19 -0500
Subject: [PATCH 1/8] Add initial CIR support for switch operation

---
 clang/include/clang/CIR/Dialect/IR/CIROps.td | 224 ++++++++++++++++++-
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp      | 131 ++++++++++-
 clang/test/CIR/IR/switch.cir                 |  38 ++++
 3 files changed, 389 insertions(+), 4 deletions(-)
 create mode 100644 clang/test/CIR/IR/switch.cir

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index bb19de31b4fa5..04bfb76c3b95b 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -470,7 +470,8 @@ def StoreOp : CIR_Op<"store", [
 //===----------------------------------------------------------------------===//
 
 def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "IfOp",
-                                              "DoWhileOp", "WhileOp", "ForOp"]>,
+                                              "SwitchOp", "DoWhileOp",
+                                              "WhileOp", "ForOp", "CaseOp"]>,
                                  Terminator]> {
   let summary = "Return from function";
   let description = [{
@@ -609,8 +610,9 @@ def ConditionOp : CIR_Op<"condition", [
 //===----------------------------------------------------------------------===//
 
 def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
-                               ParentOneOf<["IfOp", "ScopeOp", "WhileOp",
-                                            "ForOp", "DoWhileOp"]>]> {
+                               ParentOneOf<["IfOp", "ScopeOp", "SwitchOp",
+                                            "WhileOp", "ForOp", "CaseOp",
+                                            "DoWhileOp"]>]> {
   let summary = "Represents the default branching behaviour of a region";
   let description = [{
     The `cir.yield` operation terminates regions on different CIR operations,
@@ -753,6 +755,222 @@ def ScopeOp : CIR_Op<"scope", [
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+def CaseOpKind_DT : I32EnumAttrCase<"Default", 1, "default">;
+def CaseOpKind_EQ : I32EnumAttrCase<"Equal", 2, "equal">;
+def CaseOpKind_AO : I32EnumAttrCase<"Anyof", 3, "anyof">;
+def CaseOpKind_RG : I32EnumAttrCase<"Range", 4, "range">;
+
+def CaseOpKind : I32EnumAttr<
+    "CaseOpKind",
+    "case kind",
+    [CaseOpKind_DT, CaseOpKind_EQ, CaseOpKind_AO, CaseOpKind_RG]> {
+  let cppNamespace = "::cir";
+}
+
+def CaseOp : CIR_Op<"case", [
+       DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       RecursivelySpeculatable, AutomaticAllocationScope]> {
+  let summary = "Case operation";
+  let description = [{
+    The `cir.case` operation represents a case within a C/C++ switch.
+    The `cir.case` operation must be in a `cir.switch` operation directly or indirectly.
+
+    The `cir.case` have 4 kinds:
+    - `equal, <constant>`: equality of the second case operand against the
+    condition.
+    - `anyof, [constant-list]`: equals to any of the values in a subsequent
+    following list.
+    - `range, [lower-bound, upper-bound]`: the condition is within the closed interval.
+    - `default`: any other value.
+
+    Each case region must be explicitly terminated.
+  }];
+
+  let arguments = (ins ArrayAttr:$value, CaseOpKind:$kind);
+  let regions = (region AnyRegion:$caseRegion);
+
+  let assemblyFormat = "`(` $kind `,` $value `)` $caseRegion attr-dict";
+
+  let hasVerifier = 1;
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "mlir::ArrayAttr":$value,
+                   "CaseOpKind":$kind,
+                   "mlir::OpBuilder::InsertPoint &":$insertPoint)>
+  ];
+}
+
+def SwitchOp : CIR_Op<"switch",
+      [SameVariadicOperandSize,
+       DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> {
+  let summary = "Switch operation";
+  let description = [{
+    The `cir.switch` operation represents C/C++ switch functionality for
+    conditionally executing multiple regions of code. The operand to a switch
+    is an integral condition value.
+
+    The set of `cir.case` operations and their enclosing `cir.switch`
+    represents the semantics of a C/C++ switch statement. Users can use
+    `collectCases(llvm::SmallVector<CaseOp> &cases)` to collect the `cir.case`
+    operations in the `cir.switch` operation easily.
+
+    The `cir.case` operations don't have to be in the region of `cir.switch`
+    directly. However, when all the `cir.case` operations live in the region
+    of `cir.switch` directly and there are no other operations except the ending
+    `cir.yield` operation in the region of `cir.switch` directly, we say the
+    `cir.switch` operation is in a simple form. Users can use
+    `bool isSimpleForm(llvm::SmallVector<CaseOp> &cases)` member function to
+    detect if the `cir.switch` operation is in a simple form. The simple form
+    makes analysis easier to handle the `cir.switch` operation
+    and makes the boundary to give up pretty clear.
+
+    To make the simple form as common as possible, CIR code generation attaches
+    operations corresponding to the statements that live between top level
+    cases into the closest `cir.case` operation.
+
+    For example,
+
+    ```
+    switch(int cond) {
+      case 4:
+        a++;
+
+      b++;
+      case 5:
+        c++;
+
+      ...
+    }
+    ```
+
+    The statement `b++` is not a sub-statement of the case statement `case 4`.
+    But to make the generated `cir.switch` a simple form, we will attach the
+    statement `b++` into the closest `cir.case` operation. So that the generated
+    code will be like:
+
+    ```
+    cir.switch(int cond) {
+      cir.case(equal, 4) {
+        a++;
+        b++;
+        cir.yield
+      }
+      cir.case(equal, 5) {
+        c++;
+        cir.yield
+      }
+      ...
+    }
+    ```
+
+    For the same reason, we hoist a case statement that is the sub-statement
+    of another case statement so that both end up at the same level. For
+    example,
+
+    ```
+    switch(int cond) {
+      case 4:
+      default:
+      case 5:
+        a++;
+      ...
+    }
+    ```
+
+    will be generated as
+
+    ```
+    cir.switch(int cond) {
+      cir.case(equal, 4) {
+        cir.yield
+      }
+      cir.case(default) {
+        cir.yield
+      }
+      cir.case(equal, 5) {
+        a++;
+        cir.yield
+      }
+      ...
+    }
+    ```
+
+    The cir.switch might not be considered "simple" if any of the following is
+    true:
+    - Some case statements of the switch statement live in scopes other than
+      the top level compound statement scope. Note that a case
+      statement itself doesn't form a scope.
+    - The sub-statement of the switch statement is not a compound statement.
+    - There is code before the first case statement. For example,
+
+    ```
+    switch(int cond) {
+      l:
+        b++;
+
+      case 4:
+        a++;
+        break;
+
+      case 5:
+        goto l;
+      ...
+    }
+    ```
+
+    the generated CIR for this non-simple switch would be:
+
+    ```
+    cir.switch(int cond) {
+      cir.label "l"
+      b++;
+      cir.case(equal, 4) {
+        a++;
+        cir.break
+      }
+      cir.case(equal, 5) {
+        goto "l"
+      }
+      cir.yield
+    }
+    ```
+  }];
+
+  let arguments = (ins CIR_IntType:$condition);
+
+  let regions = (region AnyRegion:$body);
+
+  let hasVerifier = 1;
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "mlir::Value":$condition,
+               "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::OperationState &)>":$switchBuilder)>
+  ];
+
+  let assemblyFormat = [{
+    custom<SwitchOp>(
+      $body, $condition, type($condition)
+    )
+    attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    // Collect cases in the switch.
+    void collectCases(llvm::SmallVector<CaseOp> &cases);
+
+    // Check if the switch is in a simple form. If yes, collect the cases to \param cases.
+    // This is an expensive and need to be used with caution.
+    bool isSimpleForm(llvm::SmallVector<CaseOp> &cases);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BrOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 3cd17053a52ba..06f413acc8263 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -12,8 +12,10 @@
 
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
 
+#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
 #include "clang/CIR/Dialect/IR/CIRTypes.h"
 
+#include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Support/LogicalResult.h"
@@ -166,7 +168,8 @@ void cir::AllocaOp::build(mlir::OpBuilder &odsBuilder,
 
 LogicalResult cir::BreakOp::verify() {
   assert(!cir::MissingFeatures::switchOp());
-  if (!getOperation()->getParentOfType<LoopOpInterface>())
-    return emitOpError("must be within a loop");
+  if (!getOperation()->getParentOfType<LoopOpInterface>() &&
+      !getOperation()->getParentOfType<SwitchOp>())
+    return emitOpError("must be within a loop or switch");
   return success();
 }
@@ -802,6 +805,132 @@ Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// CaseOp
+//===----------------------------------------------------------------------===//
+
+void cir::CaseOp::getSuccessorRegions(
+    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (!point.isParent()) {
+    regions.push_back(RegionSuccessor());
+    return;
+  }
+  regions.push_back(RegionSuccessor(&getCaseRegion()));
+}
+
+void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
+                        ArrayAttr value, CaseOpKind kind,
+                        OpBuilder::InsertPoint &insertPoint) {
+  OpBuilder::InsertionGuard guardSwitch(builder);
+  result.addAttribute("value", value);
+  result.getOrAddProperties<Properties>().kind =
+      cir::CaseOpKindAttr::get(builder.getContext(), kind);
+  Region *caseRegion = result.addRegion();
+  builder.createBlock(caseRegion);
+
+  insertPoint = builder.saveInsertionPoint();
+}
+
+LogicalResult cir::CaseOp::verify() { return success(); }
+
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region &regions,
+                                 mlir::OpAsmParser::UnresolvedOperand &cond,
+                                 mlir::Type &condType) {
+  cir::IntType intCondType;
+
+  if (parser.parseLParen())
+    return ::mlir::failure();
+
+  if (parser.parseOperand(cond))
+    return ::mlir::failure();
+  if (parser.parseColon())
+    return ::mlir::failure();
+  if (parser.parseCustomTypeWithFallback(intCondType))
+    return ::mlir::failure();
+  condType = intCondType;
+
+  if (parser.parseRParen())
+    return ::mlir::failure();
+  if (parser.parseRegion(regions, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  return ::mlir::success();
+}
+
+static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op,
+                          mlir::Region &bodyRegion, mlir::Value condition,
+                          mlir::Type condType) {
+  p << "(";
+  p << condition;
+  p << " : ";
+  p.printStrippedAttrOrType(condType);
+  p << ")";
+
+  p << ' ';
+  p.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/true);
+}
+
+void cir::SwitchOp::getSuccessorRegions(
+    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
+  if (!point.isParent()) {
+    region.push_back(RegionSuccessor());
+    return;
+  }
+
+  region.push_back(RegionSuccessor(&getBody()));
+}
+
+LogicalResult cir::SwitchOp::verify() { return success(); }
+
+void cir::SwitchOp::build(
+    OpBuilder &builder, OperationState &result, Value cond,
+    function_ref<void(OpBuilder &, Location, OperationState &)> switchBuilder) {
+  assert(switchBuilder && "the builder callback for regions must be present");
+  OpBuilder::InsertionGuard guardSwitch(builder);
+  Region *switchRegion = result.addRegion();
+  builder.createBlock(switchRegion);
+  result.addOperands({cond});
+  switchBuilder(builder, result.location, result);
+}
+
+void cir::SwitchOp::collectCases(llvm::SmallVector<CaseOp> &cases) {
+  walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
+    // Cases of a nested switch belong to that switch; don't descend into it.
+    if (isa<cir::SwitchOp>(op) && op != *this)
+      return WalkResult::skip();
+
+    if (auto caseOp = dyn_cast<cir::CaseOp>(op))
+      cases.push_back(caseOp);
+
+    return WalkResult::advance();
+  });
+}
+
+// Simple form: the entry block holds only cir.case ops and a final yield.
+// \p cases is filled either way. Walks the switch body; use with caution.
+bool cir::SwitchOp::isSimpleForm(llvm::SmallVector<CaseOp> &cases) {
+  collectCases(cases);
+
+  if (getBody().empty() || getBody().front().empty())
+    return false;
+
+  if (!isa<YieldOp>(getBody().front().back()))
+    return false;
+
+  if (!llvm::all_of(getBody().front(),
+                    [](Operation &op) { return isa<CaseOp, YieldOp>(op); }))
+    return false;
+
+  return llvm::all_of(cases, [this](CaseOp op) {
+    return op->getBlock() == &getBody().front();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/test/CIR/IR/switch.cir b/clang/test/CIR/IR/switch.cir
new file mode 100644
index 0000000000000..0bdc9c1e7e896
--- /dev/null
+++ b/clang/test/CIR/IR/switch.cir
@@ -0,0 +1,38 @@
+// RUN: cir-opt %s | FileCheck %s
+!s32i = !cir.int<s, 32>
+
+cir.func @s0() {
+  %1 = cir.const #cir.int<2> : !s32i
+  cir.switch (%1 : !s32i) {
+    cir.case (default, []) {
+      cir.return
+    }
+    cir.case (equal, [#cir.int<3> : !s32i]) {
+      cir.yield
+    }
+    cir.case (anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i, #cir.int<8> : !s32i]) {
+      cir.break
+    }
+    cir.case (equal, [#cir.int<5> : !s32i]) {
+      cir.yield
+    }
+    cir.yield
+  }
+  cir.return
+}
+
+// CHECK: cir.switch (%0 : !s32i) {
+// CHECK-NEXT: cir.case(default, []) {
+// CHECK-NEXT:   cir.return
+// CHECK-NEXT: }
+// CHECK-NEXT: cir.case(equal, [#cir.int<3> : !s32i]) {
+// CHECK-NEXT:   cir.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: cir.case(anyof, [#cir.int<6> : !s32i, #cir.int<7> : !s32i, #cir.int<8> : !s32i]) {
+// CHECK-NEXT:   cir.break
+// CHECK-NEXT: }
+// CHECK-NEXT: cir.case(equal, [#cir.int<5> : !s32i]) {
+// CHECK-NEXT:   cir.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: cir.yield
+// CHECK-NEXT: }

>From b5d56d5ae424ea706e8b4b0e138f117bc6a873be Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Wed, 23 Apr 2025 17:16:17 -0500
Subject: [PATCH 2/8] Add codegen for switch and EqualCase kind

---
 clang/include/clang/CIR/Dialect/IR/CIROps.td |   9 +-
 clang/include/clang/CIR/MissingFeatures.h    |   1 +
 clang/lib/CIR/CodeGen/CIRGenFunction.h       |  18 ++
 clang/lib/CIR/CodeGen/CIRGenStmt.cpp         | 210 ++++++++++++++++++-
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp      |   1 -
 clang/test/CIR/CodeGen/switch.cpp            |  92 ++++++++
 6 files changed, 326 insertions(+), 5 deletions(-)
 create mode 100644 clang/test/CIR/CodeGen/switch.cpp

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 04bfb76c3b95b..25cdf156659c7 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -777,14 +777,16 @@ def CaseOp : CIR_Op<"case", [
   let summary = "Case operation";
   let description = [{
     The `cir.case` operation represents a case within a C/C++ switch.
-    The `cir.case` operation must be in a `cir.switch` operation directly or indirectly.
+    The `cir.case` operation must be in a `cir.switch` operation directly
+    or indirectly.
 
     The `cir.case` have 4 kinds:
     - `equal, <constant>`: equality of the second case operand against the
     condition.
     - `anyof, [constant-list]`: equals to any of the values in a subsequent
     following list.
-    - `range, [lower-bound, upper-bound]`: the condition is within the closed interval.
+    - `range, [lower-bound, upper-bound]`: the condition is within the closed
+    interval.
     - `default`: any other value.
 
     Each case region must be explicitly terminated.
@@ -965,7 +967,8 @@ def SwitchOp : CIR_Op<"switch",
     // Collect cases in the switch.
     void collectCases(llvm::SmallVector<CaseOp> &cases);
 
-    // Check if the switch is in a simple form. If yes, collect the cases to \param cases.
+    // Check if the switch is in a simple form.
+    // Always collects the cases into \p cases, even if the form is not simple.
     // This is an expensive and need to be used with caution.
     bool isSimpleForm(llvm::SmallVector<CaseOp> &cases);
   }];
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index bb5dac4faa1e0..fc0999e83b727 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -161,6 +161,7 @@ struct MissingFeatures {
   static bool targetSpecificCXXABI() { return false; }
   static bool moduleNameHash() { return false; }
   static bool setDSOLocal() { return false; }
+  static bool foldCaseStmt() { return false; }
 
   // Missing types
   static bool dataMemberType() { return false; }
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index 74fcd081dec18..ec42aee08ee15 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -63,6 +63,9 @@ class CIRGenFunction : public CIRGenTypeCache {
   /// declarations.
   DeclMapTy localDeclMap;
 
+  /// Condition types of the switch statements being emitted (innermost last).
+  llvm::SmallVector<mlir::Type, 2> condTypeStack;
+
   clang::ASTContext &getContext() const { return cgm.getASTContext(); }
 
   CIRGenBuilderTy &getBuilder() { return builder; }
@@ -469,6 +472,16 @@ class CIRGenFunction : public CIRGenTypeCache {
                       ReturnValueSlot returnValue = ReturnValueSlot());
   CIRGenCallee emitCallee(const clang::Expr *e);
 
+  template <typename T>
+  mlir::LogicalResult emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
+                                             mlir::ArrayAttr value,
+                                             cir::CaseOpKind kind,
+                                             bool buildingTopLevelCase);
+
+  mlir::LogicalResult emitCaseStmt(const clang::CaseStmt &s,
+                                   mlir::Type condType,
+                                   bool buildingTopLevelCase);
+
   mlir::LogicalResult emitContinueStmt(const clang::ContinueStmt &s);
   mlir::LogicalResult emitDoStmt(const clang::DoStmt &s);
 
@@ -595,6 +608,11 @@ class CIRGenFunction : public CIRGenTypeCache {
 
   mlir::Value emitStoreThroughBitfieldLValue(RValue src, LValue dstresult);
 
+  mlir::LogicalResult emitSwitchBody(const clang::Stmt *s);
+  mlir::LogicalResult emitSwitchCase(const clang::SwitchCase &s,
+                                     bool buildingTopLevelCase);
+  mlir::LogicalResult emitSwitchStmt(const clang::SwitchStmt &s);
+
   /// Given a value and its clang type, returns the value casted to its memory
   /// representation.
   /// Note: CIR defers most of the special casting to the final lowering passes
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index 82ac53706b7f9..76fa7bf07f097 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -89,6 +89,8 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
     }
   case Stmt::IfStmtClass:
     return emitIfStmt(cast<IfStmt>(*s));
+  case Stmt::SwitchStmtClass:
+    return emitSwitchStmt(cast<SwitchStmt>(*s));
   case Stmt::ForStmtClass:
     return emitForStmt(cast<ForStmt>(*s));
   case Stmt::WhileStmtClass:
@@ -132,7 +134,6 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
   case Stmt::CaseStmtClass:
   case Stmt::SEHLeaveStmtClass:
   case Stmt::SYCLKernelCallStmtClass:
-  case Stmt::SwitchStmtClass:
   case Stmt::CoroutineBodyStmtClass:
   case Stmt::CoreturnStmtClass:
   case Stmt::CXXTryStmtClass:
@@ -422,6 +423,117 @@ mlir::LogicalResult CIRGenFunction::emitBreakStmt(const clang::BreakStmt &s) {
   return mlir::success();
 }
 
+template <typename T>
+mlir::LogicalResult
+CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
+                                       mlir::ArrayAttr value, CaseOpKind kind,
+                                       bool buildingTopLevelCase) {
+
+  assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
+         "only case or default stmt go here");
+
+  mlir::LogicalResult result = mlir::success();
+
+  mlir::Location loc = getLoc(stmt->getBeginLoc());
+
+  enum class SubStmtKind { Case, Default, Other };
+  SubStmtKind subStmtKind = SubStmtKind::Other;
+  const Stmt *sub = stmt->getSubStmt();
+
+  mlir::OpBuilder::InsertPoint insertPoint;
+  builder.create<CaseOp>(loc, value, kind, insertPoint);
+
+  {
+    mlir::OpBuilder::InsertionGuard guardSwitch(builder);
+    builder.restoreInsertionPoint(insertPoint);
+
+    if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
+      subStmtKind = SubStmtKind::Default;
+      builder.createYield(loc);
+    } else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
+      subStmtKind = SubStmtKind::Case;
+      builder.createYield(loc);
+    } else
+      result = emitStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
+
+    insertPoint = builder.saveInsertionPoint();
+  }
+
+  // If the substmt is default stmt or case stmt, try to handle the special case
+  // to make it into the simple form. e.g.
+  //
+  //  switch () {
+  //    case 1:
+  //    default:
+  //      ...
+  //  }
+  //
+  // we prefer generating
+  //
+  //  cir.switch() {
+  //     cir.case(equal, 1) {
+  //        cir.yield
+  //     }
+  //     cir.case(default) {
+  //        ...
+  //     }
+  //  }
+  //
+  // than
+  //
+  //  cir.switch() {
+  //     cir.case(equal, 1) {
+  //       cir.case(default) {
+  //         ...
+  //       }
+  //     }
+  //  }
+  //
+  // We don't need to revert this if we find the current switch can't be in
+  // simple form later since the conversion itself should be harmless.
+  if (subStmtKind == SubStmtKind::Case)
+    result = emitCaseStmt(*cast<CaseStmt>(sub), condType, buildingTopLevelCase);
+  else if (subStmtKind == SubStmtKind::Default)
+    getCIRGenModule().errorNYI(sub->getSourceRange(), "Default case");
+  else if (buildingTopLevelCase)
+    // If we're building a top level case, try to restore the insert point to
+    // the case we're building, then we can attach more random stmts to the
+    // case to make generating `cir.switch` operation to be a simple form.
+    builder.restoreInsertionPoint(insertPoint);
+
+  return result;
+}
+
+mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
+                                                 mlir::Type condType,
+                                                 bool buildingTopLevelCase) {
+  llvm::APSInt intVal = s.getLHS()->EvaluateKnownConstInt(getContext());
+  SmallVector<mlir::Attribute, 1> caseEltValueListAttr;
+  caseEltValueListAttr.push_back(cir::IntAttr::get(condType, intVal));
+  mlir::ArrayAttr value = builder.getArrayAttr(caseEltValueListAttr);
+  if (s.getRHS()) {
+    getCIRGenModule().errorNYI(s.getSourceRange(), "SwitchOp range kind");
+  }
+  assert(!cir::MissingFeatures::foldCaseStmt());
+  return emitCaseDefaultCascade(&s, condType, value, cir::CaseOpKind::Equal,
+                                buildingTopLevelCase);
+}
+
+mlir::LogicalResult CIRGenFunction::emitSwitchCase(const SwitchCase &s,
+                                                   bool buildingTopLevelCase) {
+  assert(!condTypeStack.empty() &&
+         "build switch case without specifying the type of the condition");
+
+  if (s.getStmtClass() == Stmt::CaseStmtClass)
+    return emitCaseStmt(cast<CaseStmt>(s), condTypeStack.back(),
+                        buildingTopLevelCase);
+
+  if (s.getStmtClass() == Stmt::DefaultStmtClass)
+    getCIRGenModule().errorNYI(s.getSourceRange(), "Default case");
+
+  llvm_unreachable("expect case or default stmt");
+}
+
 mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) {
   cir::ForOp forOp;
 
@@ -600,3 +712,104 @@ mlir::LogicalResult CIRGenFunction::emitWhileStmt(const WhileStmt &s) {
   terminateBody(builder, whileOp.getBody(), getLoc(s.getEndLoc()));
   return mlir::success();
 }
+
+mlir::LogicalResult CIRGenFunction::emitSwitchBody(const Stmt *s) {
+  // It is rare but legal if the switch body is not a compound stmt. e.g.,
+  //
+  //  switch(a)
+  //    while(...) {
+  //      case1
+  //      ...
+  //      case2
+  //      ...
+  //    }
+  if (!isa<CompoundStmt>(s))
+    return emitStmt(s, /*useCurrentScope=*/!false);
+
+  auto *compoundStmt = cast<CompoundStmt>(s);
+
+  mlir::Block *switchBlock = builder.getBlock();
+  for (auto *c : compoundStmt->body()) {
+    if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
+      builder.setInsertionPointToEnd(switchBlock);
+      // emitSwitchCase leaves the insertion point inside the case it built,
+      // so following non-case statements are attached to that case's region,
+      // which keeps the generated `cir.switch` in simple form.
+      if (mlir::failed(
+              emitSwitchCase(*switchCase, /*buildingTopLevelCase=*/true)))
+        return mlir::failure();
+
+      continue;
+    }
+
+    // Otherwise, just build the statements in the nearest case region.
+    if (mlir::failed(emitStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c))))
+      return mlir::failure();
+  }
+
+  return mlir::success();
+}
+
+mlir::LogicalResult CIRGenFunction::emitSwitchStmt(const clang::SwitchStmt &s) {
+  // TODO: LLVM codegen does some early optimization to fold the condition and
+  // only emit live cases. CIR should use MLIR to achieve similar things,
+  // nothing to be done here.
+  // if (ConstantFoldsToSimpleInteger(S.getCond(), ConstantCondValue))...
+
+  SwitchOp swop;
+  auto switchStmtBuilder = [&]() -> mlir::LogicalResult {
+    if (s.getInit())
+      if (emitStmt(s.getInit(), /*useCurrentScope=*/true).failed())
+        return mlir::failure();
+
+    if (s.getConditionVariable())
+      emitDecl(*s.getConditionVariable());
+
+    mlir::Value condV = emitScalarExpr(s.getCond());
+
+    // TODO: PGO and likelihood (e.g. PGO.haveRegionCounts())
+    assert(!cir::MissingFeatures::pgoUse());
+    assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
+    // TODO: if the switch has a condition wrapped by __builtin_unpredictable?
+    assert(!cir::MissingFeatures::insertBuiltinUnpredictable());
+
+    mlir::LogicalResult res = mlir::success();
+    swop = builder.create<SwitchOp>(
+        getLoc(s.getBeginLoc()), condV,
+        /*switchBuilder=*/
+        [&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
+          curLexScope->setAsSwitch();
+
+          condTypeStack.push_back(condV.getType());
+
+          res = emitSwitchBody(s.getBody());
+
+          condTypeStack.pop_back();
+        });
+
+    return res;
+  };
+
+  // The switch scope contains the full source range for SwitchStmt.
+  mlir::Location scopeLoc = getLoc(s.getSourceRange());
+  mlir::LogicalResult res = mlir::success();
+  builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/
+                               [&](mlir::OpBuilder &b, mlir::Location loc) {
+                                 LexicalScope lexScope{
+                                     *this, loc, builder.getInsertionBlock()};
+                                 res = switchStmtBuilder();
+                               });
+
+  // swop is still null if the init statement failed before the SwitchOp was
+  // created; bail out rather than walking a null op.
+  if (!swop)
+    return res;
+
+  llvm::SmallVector<CaseOp> cases;
+  swop.collectCases(cases);
+  for (auto caseOp : cases)
+    terminateBody(builder, caseOp.getCaseRegion(), caseOp.getLoc());
+  terminateBody(builder, swop.getBody(), swop.getLoc());
+
+  return res;
+}
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 06f413acc8263..85ba12d41a516 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -15,7 +15,6 @@
 #include "clang/CIR/Dialect/IR/CIROpsEnums.h"
 #include "clang/CIR/Dialect/IR/CIRTypes.h"
 
-#include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Support/LogicalResult.h"
diff --git a/clang/test/CIR/CodeGen/switch.cpp b/clang/test/CIR/CodeGen/switch.cpp
new file mode 100644
index 0000000000000..5667949b9a0c0
--- /dev/null
+++ b/clang/test/CIR/CodeGen/switch.cpp
@@ -0,0 +1,92 @@
+// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
+// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG
+void sw1(int a) {
+  switch (int b = 1; a) {
+  case 0:
+    b = b + 1;
+    break;
+  case 1:
+    break;
+  case 2: {
+    b = b + 1;
+    int yolo = 100;
+    break;
+  }
+  }
+}
+// CIR: cir.func @sw1
+// CIR: cir.switch (%{{.*}} : !s32i) {
+// CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) {
+// CIR: cir.break
+// CIR: cir.case(equal, [#cir.int<1> : !s32i]) {
+// CIR-NEXT: cir.break
+// CIR: cir.case(equal, [#cir.int<2> : !s32i]) {
+// CIR: cir.scope {
+// CIR: cir.alloca !s32i, !cir.ptr<!s32i>, ["yolo", init]
+// CIR: cir.break
+
+// OGCG: define dso_local void @_Z3sw1i
+// OGCG: entry:
+// OGCG:   %[[A_ADDR:.*]] = alloca i32, align 4
+// OGCG:   %[[B:.*]] = alloca i32, align 4
+// OGCG:   %[[YOLO:.*]] = alloca i32, align 4
+// OGCG:   %[[A_VAL:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   switch i32 %[[A_VAL]], label %[[SW_EPILOG:.*]] [
+// OGCG:     i32 0, label %[[SW0:.*]]
+// OGCG:     i32 1, label %[[SW1:.*]]
+// OGCG:     i32 2, label %[[SW2:.*]]
+// OGCG:   ]
+// OGCG: [[SW0]]:
+// OGCG:   %[[B_LOAD0:.*]] = load i32, ptr %[[B]], align 4
+// OGCG:   %[[B_INC0:.*]] = add nsw i32 %[[B_LOAD0]], 1
+// OGCG:   store i32 %[[B_INC0]], ptr %[[B]], align 4
+// OGCG:   br label %[[SW_EPILOG]]
+// OGCG: [[SW1]]:
+// OGCG:   br label %[[SW_EPILOG]]
+// OGCG: [[SW2]]:
+// OGCG:   %[[B_LOAD2:.*]] = load i32, ptr %[[B]], align 4
+// OGCG:   %[[B_INC2:.*]] = add nsw i32 %[[B_LOAD2]], 1
+// OGCG:   store i32 %[[B_INC2]], ptr %[[B]], align 4
+// OGCG:   store i32 100, ptr %[[YOLO]], align 4
+// OGCG:   br label %[[SW_EPILOG]]
+// OGCG: [[SW_EPILOG]]:
+// OGCG:   ret void
+
+void sw2(int a) {
+  switch (int yolo = 2; a) {
+  case 3:
+    // "fomo" has the same lifetime as "yolo"
+    int fomo = 0;
+    yolo = yolo + fomo;
+    break;
+  }
+}
+
+// CIR: cir.func @sw2
+// CIR: cir.scope {
+// CIR-NEXT:   %[[YOLO:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["yolo", init]
+// CIR-NEXT:   %[[FOMO:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["fomo", init]
+// CIR:        cir.switch (%{{.*}} : !s32i) {
+// CIR-NEXT:   cir.case(equal, [#cir.int<3> : !s32i]) {
+// CIR-NEXT:     %[[ZERO:.*]] = cir.const #cir.int<0> : !s32i
+// CIR-NEXT:     cir.store %[[ZERO]], %[[FOMO]] : !s32i, !cir.ptr<!s32i>
+
+// OGCG: define dso_local void @_Z3sw2i
+// OGCG: entry:
+// OGCG:   %[[A_ADDR:.*]] = alloca i32, align 4
+// OGCG:   %[[YOLO:.*]] = alloca i32, align 4
+// OGCG:   %[[FOMO:.*]] = alloca i32, align 4
+// OGCG:   %[[A_VAL:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   switch i32 %[[A_VAL]], label %[[SW_EPILOG:.*]] [
+// OGCG:     i32 3, label %[[SW3:.*]]
+// OGCG:   ]
+// OGCG: [[SW3]]:
+// OGCG:   %[[Y:.*]] = load i32, ptr %[[YOLO]], align 4
+// OGCG:   %[[F:.*]] = load i32, ptr %[[FOMO]], align 4
+// OGCG:   %[[SUM:.*]] = add nsw i32 %[[Y]], %[[F]]
+// OGCG:   store i32 %[[SUM]], ptr %[[YOLO]], align 4
+// OGCG:   br label %[[SW_EPILOG]]
+// OGCG: [[SW_EPILOG]]:
+// OGCG:   ret void

>From 08ad25699257c630a6da14fca86ef677ae917357 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Wed, 23 Apr 2025 18:43:23 -0500
Subject: [PATCH 3/8] Add early return after NYI to prevent crashes

---
 clang/lib/CIR/CodeGen/CIRGenStmt.cpp | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index 76fa7bf07f097..9c17821415f32 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -493,9 +493,10 @@ CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
   // simple form later since the conversion itself should be harmless.
   if (subStmtKind == SubStmtKind::Case)
     result = emitCaseStmt(*cast<CaseStmt>(sub), condType, buildingTopLevelCase);
-  else if (subStmtKind == SubStmtKind::Default)
+  else if (subStmtKind == SubStmtKind::Default) {
     getCIRGenModule().errorNYI(sub->getSourceRange(), "Default case");
-  else if (buildingTopLevelCase)
+    return mlir::failure();
+  } else if (buildingTopLevelCase)
     // If we're building a top level case, try to restore the insert point to
     // the case we're building, then we can attach more random stmts to the
     // case to make generating `cir.switch` operation to be a simple form.
@@ -513,6 +514,7 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
   mlir::ArrayAttr value = builder.getArrayAttr(caseEltValueListAttr);
   if (s.getRHS()) {
     getCIRGenModule().errorNYI(s.getSourceRange(), "SwitchOp range kind");
+    return mlir::failure();
   }
   assert(!cir::MissingFeatures::foldCaseStmt());
   return emitCaseDefaultCascade(&s, condType, value, cir::CaseOpKind::Equal,
@@ -528,8 +530,10 @@ mlir::LogicalResult CIRGenFunction::emitSwitchCase(const SwitchCase &s,
     return emitCaseStmt(cast<CaseStmt>(s), condTypeStack.back(),
                         buildingTopLevelCase);
 
-  if (s.getStmtClass() == Stmt::DefaultStmtClass)
+  if (s.getStmtClass() == Stmt::DefaultStmtClass) {
     getCIRGenModule().errorNYI(s.getSourceRange(), "Default case");
+    return mlir::failure();
+  }
 
   llvm_unreachable("expect case or default stmt");
 }
@@ -724,7 +728,7 @@ mlir::LogicalResult CIRGenFunction::emitSwitchBody(const Stmt *s) {
   //      ...
   //    }
   if (!isa<CompoundStmt>(s))
-    return emitStmt(s, /*useCurrentScope=*/!false);
+    return emitStmt(s, /*useCurrentScope=*/true);
 
   auto *compoundStmt = cast<CompoundStmt>(s);
 

>From 6e1e7c3784bb9340b67c6bed6c00a1b75d6e3238 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Thu, 24 Apr 2025 17:53:24 -0500
Subject: [PATCH 4/8] Fixed format, addressed reviews, and added more tests for
 better coverage

---
 clang/include/clang/CIR/Dialect/IR/CIROps.td |   4 +-
 clang/lib/CIR/CodeGen/CIRGenStmt.cpp         |   5 +
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp      |  14 +-
 clang/test/CIR/CodeGen/switch.cpp            | 191 +++++++++++++++++++
 4 files changed, 204 insertions(+), 10 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 25cdf156659c7..bab1e994a1577 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -800,8 +800,8 @@ def CaseOp : CIR_Op<"case", [
   let hasVerifier = 1;
 
   let skipDefaultBuilders = 1;
-let builders = [
-    OpBuilder<(ins "mlir::ArrayAttr":$value,
+  let builders = [
+      OpBuilder<(ins "mlir::ArrayAttr":$value,
                    "CaseOpKind":$kind,
                    "mlir::OpBuilder::InsertPoint &":$insertPoint)>
   ];
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index 9c17821415f32..dfa290d7e415e 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -252,6 +252,11 @@ mlir::LogicalResult CIRGenFunction::emitSimpleStmt(const Stmt *s,
   // NullStmt doesn't need any handling, but we need to say we handled it.
   case Stmt::NullStmtClass:
     break;
+  case Stmt::CaseStmtClass:
+    // If we reached here, we are not handling a top-level switch case.
+    return emitSwitchCase(cast<SwitchCase>(*s),
+                          /*buildingTopLevelCase=*/false);
+    break;
 
   case Stmt::BreakStmtClass:
     return emitBreakStmt(cast<BreakStmt>(*s));
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 85ba12d41a516..de942cc88c42e 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -842,22 +842,22 @@ static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region &regions,
   cir::IntType intCondType;
 
   if (parser.parseLParen())
-    return ::mlir::failure();
+    return mlir::failure();
 
   if (parser.parseOperand(cond))
-    return ::mlir::failure();
+    return mlir::failure();
   if (parser.parseColon())
-    return ::mlir::failure();
+    return mlir::failure();
   if (parser.parseCustomTypeWithFallback(intCondType))
-    return ::mlir::failure();
+    return mlir::failure();
   condType = intCondType;
 
   if (parser.parseRParen())
-    return ::mlir::failure();
+    return mlir::failure();
   if (parser.parseRegion(regions, /*arguments=*/{}, /*argTypes=*/{}))
     return failure();
 
-  return ::mlir::success();
+  return mlir::success();
 }
 
 static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op,
@@ -910,8 +910,6 @@ void cir::SwitchOp::collectCases(llvm::SmallVector<CaseOp> &cases) {
   });
 }
 
-// Check if the switch is in a simple form. If yes, collect the cases to \param
-// cases. This is an expensive and need to be used with caution.
 bool cir::SwitchOp::isSimpleForm(llvm::SmallVector<CaseOp> &cases) {
   collectCases(cases);
 
diff --git a/clang/test/CIR/CodeGen/switch.cpp b/clang/test/CIR/CodeGen/switch.cpp
index 5667949b9a0c0..dbcb2694aac5b 100644
--- a/clang/test/CIR/CodeGen/switch.cpp
+++ b/clang/test/CIR/CodeGen/switch.cpp
@@ -90,3 +90,194 @@ void sw2(int a) {
 // OGCG:   br label %[[SW_EPILOG]]
 // OGCG: [[SW_EPILOG]]:
 // OGCG:   ret void
+
+void sw5(int a) { // a case whose body is just a null statement
+  switch (a) {
+  case 1:;
+  }
+}
+
+// CIR: cir.func @sw5
+// CIR: cir.switch (%1 : !s32i) {
+// CIR-NEXT:   cir.case(equal, [#cir.int<1> : !s32i]) {
+// CIR-NEXT:     cir.yield
+// CIR-NEXT:   }
+// CIR-NEXT:   cir.yield
+// CIR-NEXT:   }
+
+// OGCG: define dso_local void @_Z3sw5i
+// OGCG: entry:
+// OGCG:   %[[A_ADDR:.*]] = alloca i32, align 4
+// OGCG:   %[[A_VAL:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   switch i32 %[[A_VAL]], label %[[SW_EPILOG:.*]] [
+// OGCG:     i32 1, label %[[SW1:.*]]
+// OGCG:   ]
+// OGCG: [[SW1]]:
+// OGCG:   br label %[[SW_EPILOG]]
+// OGCG: [[SW_EPILOG]]:
+// OGCG:   ret void
+
+void sw12(int a) { // return followed by a dead break inside a case
+  switch (a)
+  {
+  case 3:
+    return;
+    break; // unreachable: it follows the return
+  }
+}
+
+//      CIR: cir.func @sw12
+//      CIR:   cir.scope {
+//      CIR:     cir.switch
+// CIR-NEXT:     cir.case(equal, [#cir.int<3> : !s32i]) {
+// CIR-NEXT:       cir.return
+// CIR-NEXT:     ^bb1:  // no predecessors
+// CIR-NEXT:       cir.break
+// CIR-NEXT:     }
+
+// OGCG: define dso_local void @_Z4sw12i
+// OGCG: entry:
+// OGCG:   %[[A_ADDR:.*]] = alloca i32, align 4
+// OGCG:   %[[A_VAL:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   switch i32 %[[A_VAL]], label %[[SW_DEFAULT:.*]] [
+// OGCG:     i32 3, label %[[SW3:.*]]
+// OGCG:   ]
+// OGCG: [[SW3]]:
+// OGCG:   br label %[[SW_DEFAULT]]
+// OGCG: [[SW_DEFAULT]]:
+// OGCG:   ret void
+
+void sw13(int a, int b) { // a switch nested inside a case of another switch
+  switch (a) {
+  case 1:
+    switch (b) {
+    case 2:
+      break;
+    }
+  }
+}
+
+//      CIR:  cir.func @sw13
+//      CIR:    cir.scope {
+//      CIR:      cir.switch
+// CIR-NEXT:      cir.case(equal, [#cir.int<1> : !s32i]) {
+// CIR-NEXT:        cir.scope {
+//      CIR:          cir.switch
+// CIR-NEXT:          cir.case(equal, [#cir.int<2> : !s32i]) {
+// CIR-NEXT:            cir.break
+// CIR-NEXT:          }
+// CIR-NEXT:          cir.yield
+// CIR-NEXT:        }
+// CIR-NEXT:      }
+// CIR:         cir.yield
+//      CIR:    }
+//      CIR:    cir.return
+
+// OGCG: define dso_local void @_Z4sw13ii
+// OGCG: entry:
+// OGCG:   %[[A_ADDR:.*]] = alloca i32, align 4
+// OGCG:   %[[B_ADDR:.*]] = alloca i32, align 4
+// OGCG:   %[[A_VAL:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   switch i32 %[[A_VAL]], label %[[EPILOG2:.*]] [
+// OGCG:     i32 1, label %[[SW1:.*]]
+// OGCG:   ]
+// OGCG: [[SW1]]:
+// OGCG:   %[[B_VAL:.*]] = load i32, ptr %[[B_ADDR]], align 4
+// OGCG:   switch i32 %[[B_VAL]], label %[[EPILOG:.*]] [
+// OGCG:     i32 2, label %[[SW12:.*]]
+// OGCG:   ]
+// OGCG: [[SW12]]:
+// OGCG:   br label %[[EPILOG]]
+// OGCG: [[EPILOG]]:
+// OGCG:   br label %[[EPILOG2]]
+// OGCG: [[EPILOG2]]:
+// OGCG:   ret void
+
+int nested_switch(int a) { // case labels inside if-statement bodies
+  switch (int b = 1; a) {
+  case 0:
+    b = b + 1; // no break here, so control falls through to case 1
+  case 1:
+    return b;
+  case 2: {
+    b = b + 1;
+    if (a > 1000) {
+        case 9:
+          b = a + b;
+    }
+    if (a > 500) {
+        case 7:
+          return a + b;
+    }
+    break;
+  }
+  }
+
+  return 0;
+}
+
+// CIR: cir.switch (%6 : !s32i) {
+// CIR:   cir.case(equal, [#cir.int<0> : !s32i]) {
+// CIR:     cir.yield
+// CIR:   }
+// CIR:   cir.case(equal, [#cir.int<1> : !s32i]) {
+// CIR:     cir.return
+// CIR:   }
+// CIR:   cir.case(equal, [#cir.int<2> : !s32i]) {
+// CIR:     cir.scope {
+// CIR:     cir.scope {
+// CIR:       cir.if
+// CIR:         cir.case(equal, [#cir.int<9> : !s32i]) {
+// CIR:         cir.yield
+// CIR:     cir.scope {
+// CIR:         cir.if
+// CIR:           cir.case(equal, [#cir.int<7> : !s32i]) {
+// CIR:           cir.return
+
+// OGCG: define dso_local noundef i32 @_Z13nested_switchi
+// OGCG: entry:
+// OGCG:   %[[RETVAL:.*]] = alloca i32, align 4
+// OGCG:   %[[A_ADDR:.*]] = alloca i32, align 4
+// OGCG:   %[[B:.*]] = alloca i32, align 4
+// OGCG:   %[[A_VAL:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   switch i32 %[[A_VAL]], label %[[EPILOG:.*]] [
+// OGCG:     i32 0, label %[[SW0:.*]]
+// OGCG:     i32 1, label %[[SW1:.*]]
+// OGCG:     i32 2, label %[[SW2:.*]]
+// OGCG:     i32 9, label %[[SW4:.*]]
+// OGCG:     i32 7, label %[[SW8:.*]]
+// OGCG:   ]
+// OGCG: [[SW0]]:
+// OGCG:   %[[B_VAL0:.*]] = load i32, ptr %[[B]], align 4
+// OGCG:   %[[ADD0:.*]] = add nsw i32 %[[B_VAL0]], 1
+// OGCG:   br label %[[SW1]]
+// OGCG: [[SW1]]:
+// OGCG:   %[[B_VAL1:.*]] = load i32, ptr %[[B]], align 4
+// OGCG:   br label %[[RETURN:.*]]
+// OGCG: [[SW2]]:
+// OGCG:   %[[B_VAL2:.*]] = load i32, ptr %[[B]], align 4
+// OGCG:   %[[ADD2:.*]] = add nsw i32 %[[B_VAL2]], 1
+// OGCG:   %[[A_VAL2:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   %[[CMP1000:.*]] = icmp sgt i32 %[[A_VAL2]], 1000
+// OGCG:   br i1 %[[CMP1000]], label %[[IFTHEN:.*]], label %[[IFEND:.*]]
+// OGCG: [[IFTHEN]]:
+// OGCG:   br label %[[SW4]]
+// OGCG: [[SW4]]:
+// OGCG:   %[[A_VAL4:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   %[[B_VAL4:.*]] = load i32, ptr %[[B]], align 4
+// OGCG:   %[[ADD4:.*]] = add nsw i32 %[[A_VAL4]], %[[B_VAL4]]
+// OGCG:   br label %[[IFEND]]
+// OGCG: [[IFEND]]:
+// OGCG:   %[[A_VAL5:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   %[[CMP500:.*]] = icmp sgt i32 %[[A_VAL5]], 500
+// OGCG:   br i1 %[[CMP500]], label %[[IFTHEN7:.*]], label %[[IFEND10:.*]]
+// OGCG: [[IFTHEN7]]:
+// OGCG:   br label %[[SW8]]
+// OGCG: [[SW8]]:
+// OGCG:   %[[A_VAL8:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   %[[B_VAL8:.*]] = load i32, ptr %[[B]], align 4
+// OGCG:   %[[ADD8:.*]] = add nsw i32 %[[A_VAL8]], %[[B_VAL8]]
+// OGCG:   br label %[[RETURN]]
+// OGCG: [[IFEND10]]:
+// OGCG:   br label %[[EPILOG]]
+// OGCG: [[EPILOG]]:

>From a8741c5eecc93094cfdeb1c0ad1030159b21acb8 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Fri, 25 Apr 2025 10:42:12 -0500
Subject: [PATCH 5/8] Remove empty verifier

---
 clang/include/clang/CIR/Dialect/IR/CIROps.td | 4 ----
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp      | 4 ----
 2 files changed, 8 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index bab1e994a1577..a414905cd8d94 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -797,8 +797,6 @@ def CaseOp : CIR_Op<"case", [
 
   let assemblyFormat = "`(` $kind `,` $value `)` $caseRegion attr-dict";
 
-  let hasVerifier = 1;
-
   let skipDefaultBuilders = 1;
   let builders = [
       OpBuilder<(ins "mlir::ArrayAttr":$value,
@@ -948,8 +946,6 @@ def SwitchOp : CIR_Op<"switch",
 
   let regions = (region AnyRegion:$body);
 
-  let hasVerifier = 1;
-
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "mlir::Value":$condition,
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index de942cc88c42e..b523b1048764d 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -830,8 +830,6 @@ void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
   insertPoint = builder.saveInsertionPoint();
 }
 
-LogicalResult cir::CaseOp::verify() { return success(); }
-
 //===----------------------------------------------------------------------===//
 // SwitchOp
 //===----------------------------------------------------------------------===//
@@ -884,8 +882,6 @@ void cir::SwitchOp::getSuccessorRegions(
   region.push_back(RegionSuccessor(&getBody()));
 }
 
-LogicalResult cir::SwitchOp::verify() { return success(); }
-
 void cir::SwitchOp::build(
     OpBuilder &builder, OperationState &result, Value cond,
     function_ref<void(OpBuilder &, Location, OperationState &)> switchBuilder) {

>From ad4fd8747ae275ec74f92d2eea3ba680d07c34e2 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Fri, 25 Apr 2025 14:10:23 -0500
Subject: [PATCH 6/8] Rename tests to match with new mangled names

---
 clang/test/CIR/CodeGen/switch.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/clang/test/CIR/CodeGen/switch.cpp b/clang/test/CIR/CodeGen/switch.cpp
index dbcb2694aac5b..be31b5249182b 100644
--- a/clang/test/CIR/CodeGen/switch.cpp
+++ b/clang/test/CIR/CodeGen/switch.cpp
@@ -16,7 +16,7 @@ void sw1(int a) {
   }
   }
 }
-// CIR: cir.func @sw1
+// CIR: cir.func @_Z3sw1i
 // CIR: cir.switch (%3 : !s32i) {
 // CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) {
 // CIR: cir.break
@@ -64,7 +64,7 @@ void sw2(int a) {
   }
 }
 
-// CIR: cir.func @sw2
+// CIR: cir.func @_Z3sw2i
 // CIR: cir.scope {
 // CIR-NEXT:   %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["yolo", init]
 // CIR-NEXT:   %2 = cir.alloca !s32i, !cir.ptr<!s32i>, ["fomo", init]
@@ -97,7 +97,7 @@ void sw5(int a) {
   }
 }
 
-// CIR: cir.func @sw5
+// CIR: cir.func @_Z3sw5i
 // CIR: cir.switch (%1 : !s32i) {
 // CIR-NEXT:   cir.case(equal, [#cir.int<1> : !s32i]) {
 // CIR-NEXT:     cir.yield
@@ -126,7 +126,7 @@ void sw12(int a) {
   }
 }
 
-//      CIR: cir.func @sw12
+//      CIR: cir.func @_Z4sw12i
 //      CIR:   cir.scope {
 //      CIR:     cir.switch
 // CIR-NEXT:     cir.case(equal, [#cir.int<3> : !s32i]) {
@@ -157,7 +157,7 @@ void sw13(int a, int b) {
   }
 }
 
-//      CIR:  cir.func @sw13
+//      CIR:  cir.func @_Z4sw13ii
 //      CIR:    cir.scope {
 //      CIR:      cir.switch
 // CIR-NEXT:      cir.case(equal, [#cir.int<1> : !s32i]) {

>From 7d5a53f1c2b6c06a628970c7a137a14d962c39ab Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Fri, 25 Apr 2025 17:48:53 -0500
Subject: [PATCH 7/8] Add more switch tests and mark unsupported constructs as
 TODO

---
 clang/test/CIR/CodeGen/switch.cpp | 109 ++++++++++++++++++++++++++++++
 1 file changed, 109 insertions(+)

diff --git a/clang/test/CIR/CodeGen/switch.cpp b/clang/test/CIR/CodeGen/switch.cpp
index be31b5249182b..36523755376a1 100644
--- a/clang/test/CIR/CodeGen/switch.cpp
+++ b/clang/test/CIR/CodeGen/switch.cpp
@@ -91,6 +91,45 @@ void sw2(int a) {
 // OGCG: [[SW_EPILOG]]:
 // OGCG:   ret void
 
+int sw4(int a) { // a braced case body that returns a value
+  switch (a) {
+  case 42: {
+    return 3;
+  }
+  // TODO: add default case when it is upstreamed
+  }
+  return 0;
+}
+
+// CIR: cir.func @_Z3sw4i
+// CIR:       cir.switch (%4 : !s32i) {
+// CIR-NEXT:       cir.case(equal, [#cir.int<42> : !s32i]) {
+// CIR-NEXT:         cir.scope {
+// CIR-NEXT:           %5 = cir.const #cir.int<3> : !s32i
+// CIR-NEXT:           cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
+// CIR-NEXT:           %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+// CIR-NEXT:           cir.return %6 : !s32i
+// CIR-NEXT:         }
+// CIR-NEXT:         cir.yield
+// CIR-NEXT:       }
+
+// OGCG: define dso_local noundef i32 @_Z3sw4i
+// OGCG: entry:
+// OGCG:   %[[RETVAL:.*]] = alloca i32, align 4
+// OGCG:   %[[A_ADDR:.*]] = alloca i32, align 4
+// OGCG:   %[[A_VAL:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   switch i32 %[[A_VAL]], label %[[EPILOG:.*]] [
+// OGCG:     i32 42, label %[[SW42:.*]]
+// OGCG:   ]
+// OGCG: [[SW42]]:
+// OGCG:   br label %[[RETURN:.*]]
+// OGCG: [[EPILOG]]:
+// OGCG:   br label %[[RETURN]]
+// OGCG: [[RETURN]]:
+// OGCG:   %[[RETVAL_LOAD:.*]] = load i32, ptr %[[RETVAL]], align 4
+// OGCG:   ret i32 %[[RETVAL_LOAD]]
+
+
 void sw5(int a) {
   switch (a) {
   case 1:;
@@ -117,6 +156,76 @@ void sw5(int a) {
 // OGCG: [[SW_EPILOG]]:
 // OGCG:   ret void
 
+void sw8(int a) { // two cases, each ending in break
+  switch (a)
+  {
+  case 3:
+    break;
+  case 4:
+  // TODO: add default case when it is upstreamed
+    break;
+  }
+}
+
+// CIR:    cir.func @_Z3sw8i
+// CIR:      cir.case(equal, [#cir.int<3> : !s32i]) {
+// CIR-NEXT:   cir.break
+// CIR-NEXT: }
+// CIR-NEXT: cir.case(equal, [#cir.int<4> : !s32i]) {
+// CIR-NEXT:   cir.break
+// CIR-NEXT: }
+
+
+// OGCG: define dso_local void @_Z3sw8i
+// OGCG: entry:
+// OGCG:   %[[A_ADDR:.*]] = alloca i32, align 4
+// OGCG:   %[[A_VAL:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   switch i32 %[[A_VAL]], label %[[EPILOG:.*]] [
+// OGCG:     i32 3, label %[[SW3:.*]]
+// OGCG:     i32 4, label %[[SW4:.*]]
+// OGCG:   ]
+// OGCG: [[SW3]]:
+// OGCG:   br label %[[EPILOG]]
+// OGCG: [[SW4]]:
+// OGCG:   br label %[[EPILOG]]
+// OGCG: [[EPILOG]]:
+// OGCG:   ret void
+
+
+void sw9(int a) { // two cases, each ending in break
+  switch (a)
+  {
+  case 3:
+    break;
+  // TODO: add default case when it is upstreamed
+  case 4:
+    break;
+  }
+}
+
+// CIR:    cir.func @_Z3sw9i
+// CIR:      cir.case(equal, [#cir.int<3> : !s32i]) {
+// CIR-NEXT:   cir.break
+// CIR-NEXT: }
+// CIR-NEXT: cir.case(equal, [#cir.int<4> : !s32i]) {
+// CIR-NEXT:   cir.break
+// CIR-NEXT: }
+
+// OGCG: define dso_local void @_Z3sw9i
+// OGCG: entry:
+// OGCG:   %[[A_ADDR:.*]] = alloca i32, align 4
+// OGCG:   %[[A_VAL:.*]] = load i32, ptr %[[A_ADDR]], align 4
+// OGCG:   switch i32 %[[A_VAL]], label %[[EPILOG:.*]] [
+// OGCG:     i32 3, label %[[SW3:.*]]
+// OGCG:     i32 4, label %[[SW4:.*]]
+// OGCG:   ]
+// OGCG: [[SW3]]:
+// OGCG:   br label %[[EPILOG]]
+// OGCG: [[SW4]]:
+// OGCG:   br label %[[EPILOG]]
+// OGCG: [[EPILOG]]:
+// OGCG:   ret void
+
 void sw12(int a) {
   switch (a)
   {

>From d7dbe0cd0a20d3db1fb78888f1d08d452897e764 Mon Sep 17 00:00:00 2001
From: Andres Salamanca <andrealebarbaritos at gmail.com>
Date: Mon, 28 Apr 2025 13:20:10 -0500
Subject: [PATCH 8/8] Apply new code review suggestions

---
 .../include/clang/CIR/Dialect/IR/CIRDialect.h |  2 ++
 clang/include/clang/CIR/Dialect/IR/CIROps.td  | 29 +++++++++----------
 clang/include/clang/CIR/MissingFeatures.h     |  1 +
 clang/lib/CIR/CodeGen/CIRGenStmt.cpp          | 11 ++++---
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp       |  9 +++---
 5 files changed, 28 insertions(+), 24 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.h b/clang/include/clang/CIR/Dialect/IR/CIRDialect.h
index 36c5ad1388afa..ffa727cca4064 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.h
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.h
@@ -59,6 +59,8 @@ class SameFirstOperandAndResultType
 
 using BuilderCallbackRef =
     llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>;
+using BuilderOpStateCallbackRef = llvm::function_ref<void(
+    mlir::OpBuilder &, mlir::Location, mlir::OperationState &)>;
 
 namespace cir {
 void buildTerminatedBody(mlir::OpBuilder &builder, mlir::Location loc);
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index a414905cd8d94..283b877ef54ff 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -821,14 +821,14 @@ def SwitchOp : CIR_Op<"switch",
     operation in the `cir.switch` operation easily.
 
     The `cir.case` operations doesn't have to be in the region of `cir.switch`
-    directly. However, when all the `cir.case` operations lives in the region
-    of `cir.switch` directly and there is no other operations except the ending
-    `cir.yield` operation in the region of `cir.switch` directly, we call the
+    directly. However, when all the `cir.case` operations live in the region
+    of `cir.switch` directly and there are no other operations except the ending
+    `cir.yield` operation in the region of `cir.switch` directly, we say the
     `cir.switch` operation is in a simple form. Users can use
     `bool isSimpleForm(llvm::SmallVector<CaseOp> &cases)` member function to
     detect if the `cir.switch` operation is in a simple form. The simple form
-    makes analysis easier to handle the `cir.switch` operation
-    and makes the boundary to give up pretty clear.
+    makes it easier for analyses to handle the `cir.switch` operation
+    and makes it clear when an analysis should give up.
 
     To make the simple form as common as possible, CIR code generation attaches
     operations corresponding to the statements that lives between top level
@@ -840,9 +840,8 @@ def SwitchOp : CIR_Op<"switch",
     switch(int cond) {
       case 4:
         a++;
-
-      b++;
-      case 5;
+        b++;
+      case 5:
         c++;
 
       ...
@@ -877,7 +876,7 @@ def SwitchOp : CIR_Op<"switch",
     switch(int cond) {
       case 4:
       default;
-      case 5;
+      case 5:
         a++;
       ...
     }
@@ -901,13 +900,13 @@ def SwitchOp : CIR_Op<"switch",
     }
     ```
 
-    The cir.switch might not be considered "simple" if any of the following is
+    The cir.switch is not considered "simple" if any of the following is
     true:
-    - There are case statements of the switch statement lives in other scopes
+    - There are case statements of the switch statement that are in scopes
       other than the top level compound statement scope. Note that a case
       statement itself doesn't form a scope.
     - The sub-statement of the switch statement is not a compound statement.
-    - There are codes before the first case statement. For example,
+    - There is code before the first case statement. For example,
 
     ```
     switch(int cond) {
@@ -949,7 +948,7 @@ def SwitchOp : CIR_Op<"switch",
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "mlir::Value":$condition,
-               "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::OperationState &)>":$switchBuilder)>
+               "BuilderOpStateCallbackRef":$switchBuilder)>
   ];
 
   let assemblyFormat = [{
@@ -961,12 +960,12 @@ def SwitchOp : CIR_Op<"switch",
 
   let extraClassDeclaration = [{
     // Collect cases in the switch.
-    void collectCases(llvm::SmallVector<CaseOp> &cases);
+    void collectCases(llvm::SmallVectorImpl<CaseOp> &cases);
 
     // Check if the switch is in a simple form.
     // If yes, collect the cases to \param cases.
     // This is an expensive and need to be used with caution.
-    bool isSimpleForm(llvm::SmallVector<CaseOp> &cases);
+    bool isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases);
   }];
 }
 
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index fc0999e83b727..4d4951aa0e126 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -162,6 +162,7 @@ struct MissingFeatures {
   static bool moduleNameHash() { return false; }
   static bool setDSOLocal() { return false; }
   static bool foldCaseStmt() { return false; }
+  static bool constantFoldSwitchStatement() { return false; }
 
   // Missing types
   static bool dataMemberType() { return false; }
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index dfa290d7e415e..31e29e7828156 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -458,8 +458,9 @@ CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
     } else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
       subStmtKind = SubStmtKind::Case;
       builder.createYield(loc);
-    } else
+    } else {
       result = emitStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
+    }
 
     insertPoint = builder.saveInsertionPoint();
   }
@@ -496,16 +497,17 @@ CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
   //
   // We don't need to revert this if we find the current switch can't be in
   // simple form later since the conversion itself should be harmless.
-  if (subStmtKind == SubStmtKind::Case)
+  if (subStmtKind == SubStmtKind::Case) {
     result = emitCaseStmt(*cast<CaseStmt>(sub), condType, buildingTopLevelCase);
-  else if (subStmtKind == SubStmtKind::Default) {
+  } else if (subStmtKind == SubStmtKind::Default) {
     getCIRGenModule().errorNYI(sub->getSourceRange(), "Default case");
     return mlir::failure();
-  } else if (buildingTopLevelCase)
+  } else if (buildingTopLevelCase) {
     // If we're building a top level case, try to restore the insert point to
     // the case we're building, then we can attach more random stmts to the
     // case to make generating `cir.switch` operation to be a simple form.
     builder.restoreInsertionPoint(insertPoint);
+  }
 
   return result;
 }
@@ -764,6 +766,7 @@ mlir::LogicalResult CIRGenFunction::emitSwitchStmt(const clang::SwitchStmt &s) {
   // only emit live cases. CIR should use MLIR to achieve similar things,
   // nothing to be done here.
   // if (ConstantFoldsToSimpleInteger(S.getCond(), ConstantCondValue))...
+  assert(!cir::MissingFeatures::constantFoldSwitchStatement());
 
   SwitchOp swop;
   auto switchStmtBuilder = [&]() -> mlir::LogicalResult {
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index b523b1048764d..dac5d4222f8af 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -882,9 +882,8 @@ void cir::SwitchOp::getSuccessorRegions(
   region.push_back(RegionSuccessor(&getBody()));
 }
 
-void cir::SwitchOp::build(
-    OpBuilder &builder, OperationState &result, Value cond,
-    function_ref<void(OpBuilder &, Location, OperationState &)> switchBuilder) {
+void cir::SwitchOp::build(OpBuilder &builder, OperationState &result,
+                          Value cond, BuilderOpStateCallbackRef switchBuilder) {
   assert(switchBuilder && "the builder callback for regions must be present");
   OpBuilder::InsertionGuard guardSwitch(builder);
   Region *switchRegion = result.addRegion();
@@ -893,7 +892,7 @@ void cir::SwitchOp::build(
   switchBuilder(builder, result.location, result);
 }
 
-void cir::SwitchOp::collectCases(llvm::SmallVector<CaseOp> &cases) {
+void cir::SwitchOp::collectCases(llvm::SmallVectorImpl<CaseOp> &cases) {
   walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
     // Don't walk in nested switch op.
     if (isa<cir::SwitchOp>(op) && op != *this)
@@ -906,7 +905,7 @@ void cir::SwitchOp::collectCases(llvm::SmallVector<CaseOp> &cases) {
   });
 }
 
-bool cir::SwitchOp::isSimpleForm(llvm::SmallVector<CaseOp> &cases) {
+bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) {
   collectCases(cases);
 
   if (getBody().empty())



More information about the cfe-commits mailing list