[clang] [CIR] Fix multiple returns in switch statements (PR #164468)

via cfe-commits cfe-commits at lists.llvm.org
Tue Oct 21 10:56:02 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clangir

Author: Morris Hafner (mmha)

<details>
<summary>Changes</summary>

Add support for multiple return statements in switch statements. Cases in a switch statement don't have their own scopes, but each case is nonetheless a distinct region. Create a separate return block for each case region and handle all of them in the scope's cleanup code.

---
Full diff: https://github.com/llvm/llvm-project/pull/164468.diff


3 Files Affected:

- (modified) clang/lib/CIR/CodeGen/CIRGenFunction.cpp (+20-13) 
- (modified) clang/lib/CIR/CodeGen/CIRGenFunction.h (+53-28) 
- (modified) clang/test/CIR/CodeGen/switch.cpp (+87) 


``````````diff
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
index 25a46df406df4..23440689c14a9 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
@@ -242,12 +242,15 @@ void CIRGenFunction::LexicalScope::cleanup() {
     }
   };
 
-  if (returnBlock != nullptr) {
-    // Write out the return block, which loads the value from `__retval` and
-    // issues the `cir.return`.
+  // Cleanups are done right before codegen resumes a scope. This is where
+  // objects are destroyed. Process all return blocks.
+  llvm::SmallVector<mlir::Block *> retBlocks;
+  for (mlir::Block *retBlock : localScope->getRetBlocks()) {
     mlir::OpBuilder::InsertionGuard guard(builder);
-    builder.setInsertionPointToEnd(returnBlock);
-    (void)emitReturn(*returnLoc);
+    builder.setInsertionPointToEnd(retBlock);
+    retBlocks.push_back(retBlock);
+    mlir::Location retLoc = localScope->getRetLoc(retBlock);
+    emitReturn(retLoc);
   }
 
   auto insertCleanupAndLeave = [&](mlir::Block *insPt) {
@@ -274,19 +277,21 @@ void CIRGenFunction::LexicalScope::cleanup() {
 
     if (localScope->depth == 0) {
       // Reached the end of the function.
-      if (returnBlock != nullptr) {
-        if (returnBlock->getUses().empty()) {
-          returnBlock->erase();
+      // Special handling only for single return block case
+      if (localScope->getRetBlocks().size() == 1) {
+        mlir::Block *retBlock = localScope->getRetBlocks()[0];
+        mlir::Location retLoc = localScope->getRetLoc(retBlock);
+        if (retBlock->getUses().empty()) {
+          retBlock->erase();
         } else {
           // Thread return block via cleanup block.
           if (cleanupBlock) {
-            for (mlir::BlockOperand &blockUse : returnBlock->getUses()) {
+            for (mlir::BlockOperand &blockUse : retBlock->getUses()) {
               cir::BrOp brOp = mlir::cast<cir::BrOp>(blockUse.getOwner());
               brOp.setSuccessor(cleanupBlock);
             }
           }
-
-          builder.create<cir::BrOp>(*returnLoc, returnBlock);
+          builder.create<cir::BrOp>(retLoc, retBlock);
           return;
         }
       }
@@ -324,8 +329,10 @@ void CIRGenFunction::LexicalScope::cleanup() {
   bool entryBlock = builder.getInsertionBlock()->isEntryBlock();
   if (!entryBlock && curBlock->empty()) {
     curBlock->erase();
-    if (returnBlock != nullptr && returnBlock->getUses().empty())
-      returnBlock->erase();
+    for (mlir::Block *retBlock : retBlocks) {
+      if (retBlock->getUses().empty())
+        retBlock->erase();
+    }
     return;
   }
 
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index 5a71126c8dc07..dfa7917be6b57 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -1040,44 +1040,69 @@ class CIRGenFunction : public CIRGenTypeCache {
     // ---
 
   private:
-    // `returnBlock`, `returnLoc`, and all the functions that deal with them
-    // will change and become more complicated when `switch` statements are
-    // upstreamed.  `case` statements within the `switch` are in the same scope
-    // but have their own regions.  Therefore the LexicalScope will need to
-    // keep track of multiple return blocks.
-    mlir::Block *returnBlock = nullptr;
-    std::optional<mlir::Location> returnLoc;
-
-    // See the comment on `getOrCreateRetBlock`.
+    // On switches we need one return block per region, since cases don't
+    // have their own scopes but are distinct regions nonetheless.
+
+    // TODO: This implementation should change once we have support for early
+    //       exits in MLIR structured control flow (llvm-project#161575)
+    llvm::SmallVector<mlir::Block *> retBlocks;
+    llvm::DenseMap<mlir::Block *, mlir::Location> retLocs;
+    llvm::DenseMap<cir::CaseOp, unsigned> retBlockInCaseIndex;
+    std::optional<unsigned> normalRetBlockIndex;
+
+    // There's usually only one ret block per scope, but this needs to be
+    // get-or-create because of potentially unreachable return statements. Note
+    // that for those, all source locations map to the first one found.
     mlir::Block *createRetBlock(CIRGenFunction &cgf, mlir::Location loc) {
-      assert(returnBlock == nullptr && "only one return block per scope");
-      // Create the cleanup block but don't hook it up just yet.
+      assert((isa_and_nonnull<cir::CaseOp>(
+                  cgf.builder.getBlock()->getParentOp()) ||
+              retBlocks.size() == 0) &&
+             "only switches can hold more than one ret block");
+
+      // Create the return block but don't hook it up just yet.
       mlir::OpBuilder::InsertionGuard guard(cgf.builder);
-      returnBlock =
-          cgf.builder.createBlock(cgf.builder.getBlock()->getParent());
-      updateRetLoc(returnBlock, loc);
-      return returnBlock;
+      auto *b = cgf.builder.createBlock(cgf.builder.getBlock()->getParent());
+      retBlocks.push_back(b);
+      updateRetLoc(b, loc);
+      return b;
     }
 
     cir::ReturnOp emitReturn(mlir::Location loc);
     void emitImplicitReturn();
 
   public:
-    mlir::Block *getRetBlock() { return returnBlock; }
-    mlir::Location getRetLoc(mlir::Block *b) { return *returnLoc; }
-    void updateRetLoc(mlir::Block *b, mlir::Location loc) { returnLoc = loc; }
-
-    // Create the return block for this scope, or return the existing one.
-    // This get-or-create logic is necessary to handle multiple return
-    // statements within the same scope, which can happen if some of them are
-    // dead code or if there is a `goto` into the middle of the scope.
+    llvm::ArrayRef<mlir::Block *> getRetBlocks() { return retBlocks; }
+    mlir::Location getRetLoc(mlir::Block *b) { return retLocs.at(b); }
+    void updateRetLoc(mlir::Block *b, mlir::Location loc) {
+      retLocs.insert_or_assign(b, loc);
+    }
+
     mlir::Block *getOrCreateRetBlock(CIRGenFunction &cgf, mlir::Location loc) {
-      if (returnBlock == nullptr) {
-        returnBlock = createRetBlock(cgf, loc);
-        return returnBlock;
+      // Check if we're inside a case region
+      if (auto caseOp = mlir::dyn_cast_if_present<cir::CaseOp>(
+              cgf.builder.getBlock()->getParentOp())) {
+        auto iter = retBlockInCaseIndex.find(caseOp);
+        if (iter != retBlockInCaseIndex.end()) {
+          // Reuse existing return block
+          mlir::Block *ret = retBlocks[iter->second];
+          updateRetLoc(ret, loc);
+          return ret;
+        }
+        // Create new return block
+        mlir::Block *ret = createRetBlock(cgf, loc);
+        retBlockInCaseIndex[caseOp] = retBlocks.size() - 1;
+        return ret;
       }
-      updateRetLoc(returnBlock, loc);
-      return returnBlock;
+
+      if (normalRetBlockIndex) {
+        mlir::Block *ret = retBlocks[*normalRetBlockIndex];
+        updateRetLoc(ret, loc);
+        return ret;
+      }
+
+      mlir::Block *ret = createRetBlock(cgf, loc);
+      normalRetBlockIndex = retBlocks.size() - 1;
+      return ret;
     }
 
     mlir::Block *getEntryBlock() { return entryBlock; }
diff --git a/clang/test/CIR/CodeGen/switch.cpp b/clang/test/CIR/CodeGen/switch.cpp
index e13aa8f4f4953..3824be0d08c2f 100644
--- a/clang/test/CIR/CodeGen/switch.cpp
+++ b/clang/test/CIR/CodeGen/switch.cpp
@@ -1183,3 +1183,90 @@ int nested_switch(int a) {
 // OGCG: [[IFEND10]]:
 // OGCG:   br label %[[EPILOG]]
 // OGCG: [[EPILOG]]:
+
+int sw_return_multi_cases(int x) {
+  switch (x) {
+  case 0:
+    return 0;
+  case 1:
+    return 1;
+  case 2:
+    return 2;
+  default:
+    return -1;
+  }
+}
+
+// CIR-LABEL: cir.func{{.*}} @_Z21sw_return_multi_casesi
+// CIR:       cir.switch (%{{.*}} : !s32i) {
+// CIR-NEXT:  cir.case(equal, [#cir.int<0> : !s32i]) {
+// CIR:         %[[ZERO:.*]] = cir.const #cir.int<0> : !s32i
+// CIR:         cir.store{{.*}} %[[ZERO]], %{{.*}} : !s32i, !cir.ptr<!s32i>
+// CIR:         %[[RET0:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
+// CIR-NEXT:    cir.return %[[RET0]] : !s32i
+// CIR-NEXT:  }
+// CIR-NEXT:  cir.case(equal, [#cir.int<1> : !s32i]) {
+// CIR:         %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
+// CIR:         cir.store{{.*}} %[[ONE]], %{{.*}} : !s32i, !cir.ptr<!s32i>
+// CIR:         %[[RET1:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
+// CIR-NEXT:    cir.return %[[RET1]] : !s32i
+// CIR-NEXT:  }
+// CIR-NEXT:  cir.case(equal, [#cir.int<2> : !s32i]) {
+// CIR:         %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
+// CIR:         cir.store{{.*}} %[[TWO]], %{{.*}} : !s32i, !cir.ptr<!s32i>
+// CIR:         %[[RET2:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
+// CIR-NEXT:    cir.return %[[RET2]] : !s32i
+// CIR-NEXT:  }
+// CIR-NEXT:  cir.case(default, []) {
+// CIR:         %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
+// CIR:         %[[NEG:.*]] = cir.unary(minus, %[[ONE]]) {{.*}} : !s32i, !s32i
+// CIR:         cir.store{{.*}} %[[NEG]], %{{.*}} : !s32i, !cir.ptr<!s32i>
+// CIR:         %[[RETDEF:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i
+// CIR-NEXT:    cir.return %[[RETDEF]] : !s32i
+// CIR-NEXT:  }
+// CIR-NEXT:  cir.yield
+
+// LLVM-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi
+// LLVM:   switch i32 %{{.*}}, label %[[DEFAULT:.*]] [
+// LLVM-DAG:   i32 0, label %[[CASE0:.*]]
+// LLVM-DAG:   i32 1, label %[[CASE1:.*]]
+// LLVM-DAG:   i32 2, label %[[CASE2:.*]]
+// LLVM:   ]
+// LLVM: [[CASE0]]:
+// LLVM:   store i32 0, ptr %{{.*}}, align 4
+// LLVM:   %{{.*}} = load i32, ptr %{{.*}}, align 4
+// LLVM:   ret i32 %{{.*}}
+// LLVM: [[CASE1]]:
+// LLVM:   store i32 1, ptr %{{.*}}, align 4
+// LLVM:   %{{.*}} = load i32, ptr %{{.*}}, align 4
+// LLVM:   ret i32 %{{.*}}
+// LLVM: [[CASE2]]:
+// LLVM:   store i32 2, ptr %{{.*}}, align 4
+// LLVM:   %{{.*}} = load i32, ptr %{{.*}}, align 4
+// LLVM:   ret i32 %{{.*}}
+// LLVM: [[DEFAULT]]:
+// LLVM:   store i32 -1, ptr %{{.*}}, align 4
+// LLVM:   %{{.*}} = load i32, ptr %{{.*}}, align 4
+// LLVM:   ret i32 %{{.*}}
+
+// OGCG-LABEL: define{{.*}} i32 @_Z21sw_return_multi_casesi
+// OGCG: entry:
+// OGCG:   %[[RETVAL:.*]] = alloca i32, align 4
+// OGCG:   %[[X_ADDR:.*]] = alloca i32, align 4
+// OGCG:   %[[X_VAL:.*]] = load i32, ptr %[[X_ADDR]], align 4
+// OGCG:   switch i32 %[[X_VAL]], label %[[DEFAULT:.*]] [
+// OGCG-DAG:   i32 0, label %[[SW0:.*]]
+// OGCG-DAG:   i32 1, label %[[SW1:.*]]
+// OGCG-DAG:   i32 2, label %[[SW2:.*]]
+// OGCG:   ]
+// OGCG: [[SW0]]:
+// OGCG:   br label %[[RETURN:.*]]
+// OGCG: [[SW1]]:
+// OGCG:   br label %[[RETURN]]
+// OGCG: [[SW2]]:
+// OGCG:   br label %[[RETURN]]
+// OGCG: [[DEFAULT]]:
+// OGCG:   br label %[[RETURN]]
+// OGCG: [[RETURN]]:
+// OGCG:   %[[RETVAL_LOAD:.*]] = load i32, ptr %[[RETVAL]], align 4
+// OGCG:   ret i32 %[[RETVAL_LOAD]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/164468


More information about the cfe-commits mailing list