[llvm] [SandboxIR] Remove tight-coupling with LLVM's SwitchInst::CaseHandle (PR #167093)

via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 8 13:03:20 PST 2025


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/167093

From 66ad147064206c37261948cb1f32aef7370cf282 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vasileios.porpodas at amd.com>
Date: Fri, 7 Nov 2025 18:45:59 +0000
Subject: [PATCH] [SandboxIR] Remove tight-coupling with LLVM's
 SwitchInst::CaseHandle

SandboxIR's SwitchInst::CaseHandle was relying on LLVM IR's
SwitchInst::CaseHandleImpl template, which may call private functions of
SandboxIR's SwitchInst. This creates a dependency cycle, which is against the
design principles of SandboxIR.

The issue was exposed by: https://github.com/llvm/llvm-project/pull/166842
Thanks to @aengelke for raising the issue.
---
 llvm/include/llvm/SandboxIR/Instruction.h | 106 +++++++++++++++++-----
 llvm/lib/SandboxIR/Instruction.cpp        |  27 ++++++
 2 files changed, 112 insertions(+), 21 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Instruction.h b/llvm/include/llvm/SandboxIR/Instruction.h
index 5e369a482be57..04bc359b5c266 100644
--- a/llvm/include/llvm/SandboxIR/Instruction.h
+++ b/llvm/include/llvm/SandboxIR/Instruction.h
@@ -1884,22 +1884,89 @@ class SwitchInst : public SingleLLVMInstructionImpl<llvm::SwitchInst> {
     return cast<llvm::SwitchInst>(Val)->getNumCases();
   }
 
+  template <typename LLVMCaseItT, typename BlockT, typename ConstT>
+  class CaseItImpl;
+
+  template <typename LLVMCaseItT, typename BlockT, typename ConstT>
+  class CaseHandleImpl {
+    Context &Ctx;
+    LLVMCaseItT LLVMCaseIt;
+    template <typename T1, typename T2, typename T3> friend class CaseItImpl;
+
+  public:
+    CaseHandleImpl(Context &Ctx, LLVMCaseItT LLVMCaseIt)
+        : Ctx(Ctx), LLVMCaseIt(LLVMCaseIt) {}
+    ConstT *getCaseValue() const;
+    BlockT *getCaseSuccessor() const;
+    unsigned getCaseIndex() const {
+      const auto &LLVMCaseHandle = *LLVMCaseIt;
+      return LLVMCaseHandle.getCaseIndex();
+    }
+    unsigned getSuccessorIndex() const {
+      const auto &LLVMCaseHandle = *LLVMCaseIt;
+      return LLVMCaseHandle.getSuccessorIndex();
+    }
+  };
+
+  template <typename LLVMCaseItT, typename BlockT, typename ConstT>
+  class CaseItImpl : public iterator_facade_base<
+                         CaseItImpl<LLVMCaseItT, BlockT, ConstT>,
+                         std::random_access_iterator_tag,
+                         const CaseHandleImpl<LLVMCaseItT, BlockT, ConstT>> {
+    CaseHandleImpl<LLVMCaseItT, BlockT, ConstT> CH;
+
+  public:
+    CaseItImpl(Context &Ctx, LLVMCaseItT It) : CH(Ctx, It) {}
+    CaseItImpl(SwitchInst *SI, ptrdiff_t CaseNum)
+        : CH(SI->getContext(), llvm::SwitchInst::CaseIt(
+                                   cast<llvm::SwitchInst>(SI->Val), CaseNum)) {}
+    CaseItImpl &operator+=(ptrdiff_t N) {
+      CH.LLVMCaseIt.operator+=(N);
+      return *this;
+    }
+    CaseItImpl &operator-=(ptrdiff_t N) {
+      CH.LLVMCaseIt.operator-=(N);
+      return *this;
+    }
+    ptrdiff_t operator-(const CaseItImpl &Other) const {
+      return CH.LLVMCaseIt - Other.CH.LLVMCaseIt;
+    }
+    bool operator==(const CaseItImpl &Other) const {
+      return CH.LLVMCaseIt == Other.CH.LLVMCaseIt;
+    }
+    bool operator<(const CaseItImpl &Other) const {
+      return CH.LLVMCaseIt < Other.CH.LLVMCaseIt;
+    }
+    const CaseHandleImpl<LLVMCaseItT, BlockT, ConstT> &operator*() const {
+      return CH;
+    }
+  };
+
   using CaseHandle =
-      llvm::SwitchInst::CaseHandleImpl<SwitchInst, ConstantInt, BasicBlock>;
-  using ConstCaseHandle =
-      llvm::SwitchInst::CaseHandleImpl<const SwitchInst, const ConstantInt,
-                                       const BasicBlock>;
-  using CaseIt = llvm::SwitchInst::CaseIteratorImpl<CaseHandle>;
-  using ConstCaseIt = llvm::SwitchInst::CaseIteratorImpl<ConstCaseHandle>;
+      CaseHandleImpl<llvm::SwitchInst::CaseIt, BasicBlock, ConstantInt>;
+  using CaseIt = CaseItImpl<llvm::SwitchInst::CaseIt, BasicBlock, ConstantInt>;
+
+  using ConstCaseHandle = CaseHandleImpl<llvm::SwitchInst::ConstCaseIt,
+                                         const BasicBlock, const ConstantInt>;
+  using ConstCaseIt = CaseItImpl<llvm::SwitchInst::ConstCaseIt,
+                                 const BasicBlock, const ConstantInt>;
 
   /// Returns a read/write iterator that points to the first case in the
   /// SwitchInst.
-  CaseIt case_begin() { return CaseIt(this, 0); }
-  ConstCaseIt case_begin() const { return ConstCaseIt(this, 0); }
+  CaseIt case_begin() {
+    return CaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_begin());
+  }
+  ConstCaseIt case_begin() const {
+    return ConstCaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_begin());
+  }
   /// Returns a read/write iterator that points one past the last in the
   /// SwitchInst.
-  CaseIt case_end() { return CaseIt(this, getNumCases()); }
-  ConstCaseIt case_end() const { return ConstCaseIt(this, getNumCases()); }
+  CaseIt case_end() {
+    return CaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_end());
+  }
+  ConstCaseIt case_end() const {
+    return ConstCaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_end());
+  }
   /// Iteration adapter for range-for loops.
   iterator_range<CaseIt> cases() {
     return make_range(case_begin(), case_end());
@@ -1907,22 +1974,19 @@ class SwitchInst : public SingleLLVMInstructionImpl<llvm::SwitchInst> {
   iterator_range<ConstCaseIt> cases() const {
     return make_range(case_begin(), case_end());
   }
-  CaseIt case_default() { return CaseIt(this, DefaultPseudoIndex); }
+  CaseIt case_default() {
+    return CaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_default());
+  }
   ConstCaseIt case_default() const {
-    return ConstCaseIt(this, DefaultPseudoIndex);
+    return ConstCaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_default());
   }
   CaseIt findCaseValue(const ConstantInt *C) {
-    return CaseIt(
-        this,
-        const_cast<const SwitchInst *>(this)->findCaseValue(C)->getCaseIndex());
+    const llvm::ConstantInt *LLVMC = cast<llvm::ConstantInt>(C->Val);
+    return CaseIt(Ctx, cast<llvm::SwitchInst>(Val)->findCaseValue(LLVMC));
   }
   ConstCaseIt findCaseValue(const ConstantInt *C) const {
-    ConstCaseIt I = llvm::find_if(cases(), [C](const ConstCaseHandle &Case) {
-      return Case.getCaseValue() == C;
-    });
-    if (I != case_end())
-      return I;
-    return case_default();
+    const llvm::ConstantInt *LLVMC = cast<llvm::ConstantInt>(C->Val);
+    return ConstCaseIt(Ctx, cast<llvm::SwitchInst>(Val)->findCaseValue(LLVMC));
   }
   LLVM_ABI ConstantInt *findCaseDest(BasicBlock *BB);
 
diff --git a/llvm/lib/SandboxIR/Instruction.cpp b/llvm/lib/SandboxIR/Instruction.cpp
index 1a81d185acf76..9ae4c98723fba 100644
--- a/llvm/lib/SandboxIR/Instruction.cpp
+++ b/llvm/lib/SandboxIR/Instruction.cpp
@@ -1125,6 +1125,33 @@ void SwitchInst::setDefaultDest(BasicBlock *DefaultCase) {
   cast<llvm::SwitchInst>(Val)->setDefaultDest(
       cast<llvm::BasicBlock>(DefaultCase->Val));
 }
+
+template <typename LLVMCaseItT, typename BlockT, typename ConstT>
+ConstT *
+SwitchInst::CaseHandleImpl<LLVMCaseItT, BlockT, ConstT>::getCaseValue() const {
+  const auto &LLVMCaseHandle = *LLVMCaseIt;
+  auto *LLVMC = Ctx.getValue(LLVMCaseHandle.getCaseValue());
+  return cast<ConstT>(LLVMC);
+}
+
+template <typename LLVMCaseItT, typename BlockT, typename ConstT>
+BlockT *
+SwitchInst::CaseHandleImpl<LLVMCaseItT, BlockT, ConstT>::getCaseSuccessor()
+    const {
+  const auto &LLVMCaseHandle = *LLVMCaseIt;
+  auto *LLVMBB = LLVMCaseHandle.getCaseSuccessor();
+  return cast<BlockT>(Ctx.getValue(LLVMBB));
+}
+
+template class SwitchInst::CaseHandleImpl<llvm::SwitchInst::CaseIt, BasicBlock,
+                                          ConstantInt>;
+template class SwitchInst::CaseItImpl<llvm::SwitchInst::CaseIt, BasicBlock,
+                                      ConstantInt>;
+template class SwitchInst::CaseHandleImpl<llvm::SwitchInst::ConstCaseIt,
+                                          const BasicBlock, const ConstantInt>;
+template class SwitchInst::CaseItImpl<llvm::SwitchInst::ConstCaseIt,
+                                      const BasicBlock, const ConstantInt>;
+
 ConstantInt *SwitchInst::findCaseDest(BasicBlock *BB) {
   auto *LLVMC = cast<llvm::SwitchInst>(Val)->findCaseDest(
       cast<llvm::BasicBlock>(BB->Val));



More information about the llvm-commits mailing list