[llvm] [SandboxIR] Implement SwitchInst (PR #104641)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 16 14:08:38 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/104641

This patch implements sandboxir::SwitchInst mirroring llvm::SwitchInst.

>From b15b5f0aec42f8afc6349cf517c09b720fafa396 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 12 Aug 2024 15:11:54 -0700
Subject: [PATCH] [SandboxIR] Implement SwitchInst

This patch implements sandboxir::SwitchInst mirroring llvm::SwitchInst.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       |  90 ++++++++++++
 .../llvm/SandboxIR/SandboxIRValues.def        |   1 +
 llvm/include/llvm/SandboxIR/Tracker.h         |  33 +++++
 llvm/lib/SandboxIR/SandboxIR.cpp              |  88 ++++++++++++
 llvm/lib/SandboxIR/Tracker.cpp                |  21 +++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 128 ++++++++++++++++++
 llvm/unittests/SandboxIR/TrackerTest.cpp      |  80 +++++++++++
 7 files changed, 441 insertions(+)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 3c568f4956857..a881bdf28f22c 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -131,6 +131,7 @@ class CastInst;
 class PtrToIntInst;
 class BitCastInst;
 class AllocaInst;
+class SwitchInst;
 class UnaryOperator;
 class BinaryOperator;
 class AtomicRMWInst;
@@ -253,6 +254,7 @@ class Value {
   friend class InvokeInst;         // For getting `Val`.
   friend class CallBrInst;         // For getting `Val`.
   friend class GetElementPtrInst;  // For getting `Val`.
+  friend class SwitchInst;         // For getting `Val`.
   friend class UnaryOperator;      // For getting `Val`.
   friend class BinaryOperator;     // For getting `Val`.
   friend class AtomicRMWInst;      // For getting `Val`.
@@ -672,6 +674,7 @@ class Instruction : public sandboxir::User {
   friend class InvokeInst;         // For getTopmostLLVMInstruction().
   friend class CallBrInst;         // For getTopmostLLVMInstruction().
   friend class GetElementPtrInst;  // For getTopmostLLVMInstruction().
+  friend class SwitchInst;         // For getTopmostLLVMInstruction().
   friend class UnaryOperator;      // For getTopmostLLVMInstruction().
   friend class BinaryOperator;     // For getTopmostLLVMInstruction().
   friend class AtomicRMWInst;      // For getTopmostLLVMInstruction().
@@ -1477,6 +1480,91 @@ class GetElementPtrInst final
   // TODO: Add missing member functions.
 };
 
+class SwitchInst : public SingleLLVMInstructionImpl<llvm::SwitchInst> {
+public:
+  SwitchInst(llvm::SwitchInst *SI, Context &Ctx)
+      : SingleLLVMInstructionImpl(ClassID::Switch, Opcode::Switch, SI, Ctx) {}
+
+  static constexpr const unsigned DefaultPseudoIndex =
+      llvm::SwitchInst::DefaultPseudoIndex; // Index used by case_default().
+
+  static SwitchInst *create(Value *V, BasicBlock *Dest, unsigned NumCases,
+                            BasicBlock::iterator WhereIt, BasicBlock *WhereBB,
+                            Context &Ctx, const Twine &Name = "");
+
+  Value *getCondition() const;
+  void setCondition(Value *V);
+  BasicBlock *getDefaultDest() const;
+  bool defaultDestUndefined() const {
+    return cast<llvm::SwitchInst>(Val)->defaultDestUndefined();
+  }
+  void setDefaultDest(BasicBlock *DefaultCase);
+  unsigned getNumCases() const {
+    return cast<llvm::SwitchInst>(Val)->getNumCases();
+  }
+
+  using CaseHandle =
+      llvm::SwitchInst::CaseHandleImpl<SwitchInst, ConstantInt, BasicBlock>;
+  using ConstCaseHandle =
+      llvm::SwitchInst::CaseHandleImpl<const SwitchInst, const ConstantInt,
+                                       const BasicBlock>;
+  using CaseIt = llvm::SwitchInst::CaseIteratorImpl<CaseHandle>;
+  using ConstCaseIt = llvm::SwitchInst::CaseIteratorImpl<ConstCaseHandle>;
+
+  /// Returns a read/write iterator that points to the first case in the
+  /// SwitchInst.
+  CaseIt case_begin() { return CaseIt(this, 0); }
+  ConstCaseIt case_begin() const { return ConstCaseIt(this, 0); }
+  /// Returns a read/write iterator that points one past the last case in the
+  /// SwitchInst.
+  CaseIt case_end() { return CaseIt(this, getNumCases()); }
+  ConstCaseIt case_end() const { return ConstCaseIt(this, getNumCases()); }
+  /// Iteration adapter for range-for loops.
+  iterator_range<CaseIt> cases() {
+    return make_range(case_begin(), case_end());
+  }
+  iterator_range<ConstCaseIt> cases() const {
+    return make_range(case_begin(), case_end());
+  }
+  CaseIt case_default() { return CaseIt(this, DefaultPseudoIndex); }
+  ConstCaseIt case_default() const {
+    return ConstCaseIt(this, DefaultPseudoIndex);
+  }
+  CaseIt findCaseValue(const ConstantInt *C) {
+    return CaseIt(
+        this,
+        const_cast<const SwitchInst *>(this)->findCaseValue(C)->getCaseIndex());
+  }
+  ConstCaseIt findCaseValue(const ConstantInt *C) const {
+    ConstCaseIt I = llvm::find_if(cases(), [C](const ConstCaseHandle &Case) {
+      return Case.getCaseValue() == C;
+    });
+    if (I != case_end())
+      return I;
+    return case_default();
+  }
+  ConstantInt *findCaseDest(BasicBlock *BB); // May return nullptr.
+
+  void addCase(ConstantInt *OnVal, BasicBlock *Dest);
+  /// This method removes the specified case and its successor from the switch
+  /// instruction. Note that this operation may reorder the remaining cases at
+  /// the removed case's index and above.
+  /// Note:
+  /// This action invalidates iterators for all cases following the one removed,
+  /// including the case_end() iterator. It returns an iterator for the next
+  /// case.
+  CaseIt removeCase(CaseIt It);
+
+  unsigned getNumSuccessors() const {
+    return cast<llvm::SwitchInst>(Val)->getNumSuccessors();
+  }
+  BasicBlock *getSuccessor(unsigned Idx) const;
+  void setSuccessor(unsigned Idx, BasicBlock *NewSucc);
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::Switch;
+  }
+};
+
 class UnaryOperator : public UnaryInstruction {
   static Opcode getUnaryOpcode(llvm::Instruction::UnaryOps UnOp) {
     switch (UnOp) {
@@ -2113,6 +2201,8 @@ class Context {
   friend CallBrInst; // For createCallBrInst()
   GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
   friend GetElementPtrInst; // For createGetElementPtrInst()
+  SwitchInst *createSwitchInst(llvm::SwitchInst *I);
+  friend SwitchInst; // For createSwitchInst()
   UnaryOperator *createUnaryOperator(llvm::UnaryOperator *I);
   friend UnaryOperator; // For createUnaryOperator()
   BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index cba5b69ebf121..2b9b44c529b30 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -46,6 +46,7 @@ DEF_INSTR(Call,          OP(Call),          CallInst)
 DEF_INSTR(Invoke,        OP(Invoke),        InvokeInst)
 DEF_INSTR(CallBr,        OP(CallBr),        CallBrInst)
 DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
+DEF_INSTR(Switch,        OP(Switch),        SwitchInst)
 DEF_INSTR(UnOp,          OPCODES( \
                          OP(FNeg) \
                          ),                 UnaryOperator)
diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index 2b4dfafa3b0e5..22d910e1d73e7 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -59,6 +59,8 @@ class StoreInst;
 class Instruction;
 class Tracker;
 class AllocaInst;
+class SwitchInst;
+class ConstantInt;
 
 /// The base class for IR Change classes.
 class IRChangeBase {
@@ -261,6 +263,47 @@ class GenericSetterWithIdx final : public IRChangeBase {
 #endif
 };
 
+/// Tracks SwitchInst::addCase(). Reverting removes the case that was added.
+class SwitchAddCase : public IRChangeBase {
+  SwitchInst *Switch;
+  ConstantInt *Val;
+
+public:
+  SwitchAddCase(SwitchInst *Switch, ConstantInt *Val)
+      : Switch(Switch), Val(Val) {}
+  void revert(Tracker &Tracker) final;
+  void accept() final {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final { OS << "SwitchAddCase"; }
+  LLVM_DUMP_METHOD void dump() const final;
+#endif // NDEBUG
+};
+
+/// Tracks SwitchInst::removeCase(). Removing a case may reorder the remaining
+/// cases, so re-adding just the removed case on revert would not restore the
+/// original order, which would also break the revert of earlier index-based
+/// changes like setSuccessor(). So we save all cases before the removal and
+/// re-create them in their original order on revert.
+class SwitchRemoveCase : public IRChangeBase {
+  SwitchInst *Switch;
+  struct Case {
+    ConstantInt *Val;
+    BasicBlock *Dest;
+  };
+  /// All cases of `Switch`, in order, as they were before the removal.
+  SmallVector<Case> Cases;
+
+public:
+  /// Must be constructed *before* the case gets removed.
+  explicit SwitchRemoveCase(SwitchInst *Switch);
+  void revert(Tracker &Tracker) final;
+  void accept() final {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final { OS << "SwitchRemoveCase"; }
+  LLVM_DUMP_METHOD void dump() const final;
+#endif // NDEBUG
+};
+
 class MoveInstr : public IRChangeBase {
   /// The instruction that moved.
   Instruction *MovedI;
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 6e8c83528bded..c243df7fc864e 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1236,6 +1236,83 @@ static llvm::Instruction::UnaryOps getLLVMUnaryOp(Instruction::Opcode Opc) {
   }
 }
 
+SwitchInst *SwitchInst::create(Value *V, BasicBlock *Dest, unsigned NumCases,
+                               BasicBlock::iterator WhereIt,
+                               BasicBlock *WhereBB, Context &Ctx,
+                               const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  if (WhereIt != WhereBB->end())
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  else
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  llvm::SwitchInst *LLVMSwitch =
+      Builder.CreateSwitch(V->Val, cast<llvm::BasicBlock>(Dest->Val), NumCases);
+  return Ctx.createSwitchInst(LLVMSwitch);
+}
+
+Value *SwitchInst::getCondition() const {
+  return Ctx.getValue(cast<llvm::SwitchInst>(Val)->getCondition());
+}
+
+void SwitchInst::setCondition(Value *V) {
+  Ctx.getTracker()
+      .emplaceIfTracking<
+          GenericSetter<&SwitchInst::getCondition, &SwitchInst::setCondition>>(
+          this);
+  cast<llvm::SwitchInst>(Val)->setCondition(V->Val);
+}
+
+BasicBlock *SwitchInst::getDefaultDest() const {
+  return cast<BasicBlock>(
+      Ctx.getValue(cast<llvm::SwitchInst>(Val)->getDefaultDest()));
+}
+
+void SwitchInst::setDefaultDest(BasicBlock *DefaultCase) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&SwitchInst::getDefaultDest,
+                                       &SwitchInst::setDefaultDest>>(this);
+  cast<llvm::SwitchInst>(Val)->setDefaultDest(
+      cast<llvm::BasicBlock>(DefaultCase->Val));
+}
+
+ConstantInt *SwitchInst::findCaseDest(BasicBlock *BB) {
+  auto *LLVMC = cast<llvm::SwitchInst>(Val)->findCaseDest(
+      cast<llvm::BasicBlock>(BB->Val));
+  return LLVMC != nullptr ? cast<ConstantInt>(Ctx.getValue(LLVMC)) : nullptr;
+}
+
+void SwitchInst::addCase(ConstantInt *OnVal, BasicBlock *Dest) {
+  Ctx.getTracker().emplaceIfTracking<SwitchAddCase>(this, OnVal);
+  cast<llvm::SwitchInst>(Val)->addCase(cast<llvm::ConstantInt>(OnVal->Val),
+                                       cast<llvm::BasicBlock>(Dest->Val));
+}
+
+SwitchInst::CaseIt SwitchInst::removeCase(CaseIt It) {
+  // Record before removing: SwitchRemoveCase saves the current cases.
+  Ctx.getTracker().emplaceIfTracking<SwitchRemoveCase>(this);
+
+  auto *LLVMSwitch = cast<llvm::SwitchInst>(Val);
+  unsigned CaseNum = It - case_begin();
+  llvm::SwitchInst::CaseIt LLVMIt(LLVMSwitch, CaseNum);
+  auto LLVMCaseIt = LLVMSwitch->removeCase(LLVMIt);
+  unsigned Num = LLVMCaseIt - LLVMSwitch->case_begin();
+  return CaseIt(this, Num);
+}
+
+BasicBlock *SwitchInst::getSuccessor(unsigned Idx) const {
+  return cast<BasicBlock>(
+      Ctx.getValue(cast<llvm::SwitchInst>(Val)->getSuccessor(Idx)));
+}
+
+void SwitchInst::setSuccessor(unsigned Idx, BasicBlock *NewSucc) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetterWithIdx<&SwitchInst::getSuccessor,
+                                              &SwitchInst::setSuccessor>>(this,
+                                                                          Idx);
+  cast<llvm::SwitchInst>(Val)->setSuccessor(
+      Idx, cast<llvm::BasicBlock>(NewSucc->Val));
+}
+
 Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
                              BBIterator WhereIt, BasicBlock *WhereBB,
                              Context &Ctx, const Twine &Name) {
@@ -1875,6 +1952,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
         new GetElementPtrInst(LLVMGEP, *this));
     return It->second.get();
   }
+  case llvm::Instruction::Switch: {
+    auto *LLVMSwitchInst = cast<llvm::SwitchInst>(LLVMV);
+    It->second =
+        std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::FNeg: {
     auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
     It->second = std::unique_ptr<UnaryOperator>(
@@ -2033,6 +2116,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
       std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
   return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
 }
+SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
+  auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
+  return cast<SwitchInst>(registerValue(std::move(NewPtr)));
+}
 UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
   auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
   return cast<UnaryOperator>(registerValue(std::move(NewPtr)));
diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index baf218aa41c1c..8faff72bc3e2f 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -160,6 +160,43 @@ void RemoveFromParent::dump() const {
 }
 #endif
 
+SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) {
+  for (const auto &C : Switch->cases())
+    Cases.push_back({C.getCaseValue(), C.getCaseSuccessor()});
+}
+
+void SwitchRemoveCase::revert(Tracker &Tracker) {
+  // Re-adding only the removed case would not restore the original order,
+  // because removeCase() may have reordered the remaining cases. So drop all
+  // current cases and re-add the saved ones in their original order.
+  unsigned NumCases = Switch->getNumCases();
+  for (unsigned I = 0; I < NumCases; ++I)
+    Switch->removeCase(Switch->case_begin());
+  for (const Case &C : Cases)
+    Switch->addCase(C.Val, C.Dest);
+}
+
+#ifndef NDEBUG
+void SwitchRemoveCase::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
+void SwitchAddCase::revert(Tracker &Tracker) {
+  // addCase() appends the new case at the end, and removing the last case
+  // does not reorder the others, so this restores the original case order.
+  auto It = Switch->findCaseValue(Val);
+  Switch->removeCase(It);
+}
+
+#ifndef NDEBUG
+void SwitchAddCase::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
 MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) {
   if (auto *NextI = MovedI->getNextNode())
     NextInstrOrBB = NextI;
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 6c0f1b19243c2..0b2635861c285 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1643,6 +1643,136 @@ define void @foo(i32 %arg, float %farg) {
   EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
 }
 
+TEST_F(SandboxIRTest, SwitchInst) {
+  parseIR(C, R"IR(
+define void @foo(i32 %cond0, i32 %cond1) {
+  entry:
+    switch i32 %cond0, label %default [ i32 0, label %bb0
+                                        i32 1, label %bb1 ]
+  bb0:
+    ret void
+  bb1:
+    ret void
+  default:
+    ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  auto *LLVMEntry = getBasicBlockByName(LLVMF, "entry");
+  auto *LLVMSwitch = cast<llvm::SwitchInst>(&*LLVMEntry->begin());
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *Cond1 = F.getArg(1);
+  auto *Entry = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMEntry));
+  auto *Switch = cast<sandboxir::SwitchInst>(&*Entry->begin());
+  auto *BB0 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
+  auto *BB1 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb1")));
+  auto *Default = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "default")));
+
+  // Check getCondition().
+  EXPECT_EQ(Switch->getCondition(), Ctx.getValue(LLVMSwitch->getCondition()));
+  // Check setCondition().
+  auto *OrigCond = Switch->getCondition();
+  auto *NewCond = Cond1;
+  EXPECT_NE(NewCond, OrigCond);
+  Switch->setCondition(NewCond);
+  EXPECT_EQ(Switch->getCondition(), NewCond);
+  Switch->setCondition(OrigCond);
+  EXPECT_EQ(Switch->getCondition(), OrigCond);
+  // Check getDefaultDest().
+  EXPECT_EQ(Switch->getDefaultDest(),
+            Ctx.getValue(LLVMSwitch->getDefaultDest()));
+  EXPECT_EQ(Switch->getDefaultDest(), Default);
+  // Check defaultDestUndefined().
+  EXPECT_EQ(Switch->defaultDestUndefined(), LLVMSwitch->defaultDestUndefined());
+  // Check setDefaultDest().
+  auto *OrigDefaultDest = Switch->getDefaultDest();
+  auto *NewDefaultDest = Entry;
+  EXPECT_NE(NewDefaultDest, OrigDefaultDest);
+  Switch->setDefaultDest(NewDefaultDest);
+  EXPECT_EQ(Switch->getDefaultDest(), NewDefaultDest);
+  Switch->setDefaultDest(OrigDefaultDest);
+  EXPECT_EQ(Switch->getDefaultDest(), OrigDefaultDest);
+  // Check getNumCases().
+  EXPECT_EQ(Switch->getNumCases(), LLVMSwitch->getNumCases());
+  // Check getNumSuccessors().
+  EXPECT_EQ(Switch->getNumSuccessors(), LLVMSwitch->getNumSuccessors());
+  // Check getSuccessor().
+  for (auto SuccIdx : seq<unsigned>(0, Switch->getNumSuccessors()))
+    EXPECT_EQ(Switch->getSuccessor(SuccIdx),
+              Ctx.getValue(LLVMSwitch->getSuccessor(SuccIdx)));
+  // Check setSuccessor().
+  auto *OrigSucc = Switch->getSuccessor(0);
+  auto *NewSucc = Entry;
+  EXPECT_NE(NewSucc, OrigSucc);
+  Switch->setSuccessor(0, NewSucc);
+  EXPECT_EQ(Switch->getSuccessor(0), NewSucc);
+  Switch->setSuccessor(0, OrigSucc);
+  EXPECT_EQ(Switch->getSuccessor(0), OrigSucc);
+  // Check case_begin(), case_end(), CaseIt.
+  auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx);
+  auto *One = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 1, Ctx);
+  auto CaseIt = Switch->case_begin();
+  // `*It` returns a reference into the iterator itself, so take case handles
+  // by value when the iterator is a temporary, e.g. `*CaseIt++`.
+  {
+    const sandboxir::SwitchInst::CaseHandle Case = *CaseIt++;
+    EXPECT_EQ(Case.getCaseValue(), Zero);
+    EXPECT_EQ(Case.getCaseSuccessor(), BB0);
+    EXPECT_EQ(Case.getCaseIndex(), 0u);
+    EXPECT_EQ(Case.getSuccessorIndex(), 1u);
+  }
+  {
+    const sandboxir::SwitchInst::CaseHandle Case = *CaseIt++;
+    EXPECT_EQ(Case.getCaseValue(), One);
+    EXPECT_EQ(Case.getCaseSuccessor(), BB1);
+    EXPECT_EQ(Case.getCaseIndex(), 1u);
+    EXPECT_EQ(Case.getSuccessorIndex(), 2u);
+  }
+  EXPECT_EQ(CaseIt, Switch->case_end());
+  // Check cases().
+  unsigned CntCase = 0;
+  for (auto &Case : Switch->cases()) {
+    EXPECT_EQ(Case.getCaseIndex(), CntCase);
+    ++CntCase;
+  }
+  EXPECT_EQ(CntCase, 2u);
+  // Check case_default().
+  auto CaseDefault = *Switch->case_default();
+  EXPECT_EQ(CaseDefault.getCaseSuccessor(), Default);
+  EXPECT_EQ(CaseDefault.getCaseIndex(),
+            sandboxir::SwitchInst::DefaultPseudoIndex);
+  // Check findCaseValue().
+  EXPECT_EQ(Switch->findCaseValue(Zero)->getCaseIndex(), 0u);
+  EXPECT_EQ(Switch->findCaseValue(One)->getCaseIndex(), 1u);
+  // Check findCaseDest().
+  EXPECT_EQ(Switch->findCaseDest(BB0), Zero);
+  EXPECT_EQ(Switch->findCaseDest(BB1), One);
+  EXPECT_EQ(Switch->findCaseDest(Entry), nullptr);
+  // Check addCase().
+  auto *Two = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 2, Ctx);
+  Switch->addCase(Two, Entry);
+  auto CaseTwoIt = Switch->findCaseValue(Two);
+  auto &CaseTwo = *CaseTwoIt;
+  EXPECT_EQ(CaseTwo.getCaseValue(), Two);
+  EXPECT_EQ(CaseTwo.getCaseSuccessor(), Entry);
+  EXPECT_EQ(Switch->getNumCases(), 3u);
+  // Check removeCase().
+  auto RemovedIt = Switch->removeCase(CaseTwoIt);
+  EXPECT_EQ(RemovedIt, Switch->case_end());
+  EXPECT_EQ(Switch->getNumCases(), 2u);
+  // Check create().
+  auto *NewSwitch = sandboxir::SwitchInst::create(
+      Cond1, Default, 1, Default->begin(), Default, Ctx, "NewSwitch");
+  EXPECT_TRUE(isa<sandboxir::SwitchInst>(NewSwitch));
+  EXPECT_EQ(NewSwitch->getCondition(), Cond1);
+  EXPECT_EQ(NewSwitch->getDefaultDest(), Default);
+}
+
 TEST_F(SandboxIRTest, UnaryOperator) {
   parseIR(C, R"IR(
 define void @foo(float %arg0) {
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index 380c90e7f0f65..b6770b237853c 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -644,6 +644,86 @@ define void @foo(i8 %arg) {
   EXPECT_EQ(Invoke->getSuccessor(1), ExceptionBB);
 }
 
+TEST_F(TrackerTest, SwitchInstSetters) {
+  parseIR(C, R"IR(
+define void @foo(i32 %cond0, i32 %cond1) {
+  entry:
+    switch i32 %cond0, label %default [ i32 0, label %bb0
+                                        i32 1, label %bb1 ]
+  bb0:
+    ret void
+  bb1:
+    ret void
+  default:
+    ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  auto *LLVMEntry = getBasicBlockByName(LLVMF, "entry");
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *Cond1 = F.getArg(1);
+  auto *Entry = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMEntry));
+  auto *BB0 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
+  auto *BB1 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb1")));
+  auto *Switch = cast<sandboxir::SwitchInst>(&*Entry->begin());
+
+  // Check that setCondition() is reverted.
+  auto *OrigCond = Switch->getCondition();
+  auto *NewCond = Cond1;
+  EXPECT_NE(NewCond, OrigCond);
+  Ctx.save();
+  Switch->setCondition(NewCond);
+  EXPECT_EQ(Switch->getCondition(), NewCond);
+  Ctx.revert();
+  EXPECT_EQ(Switch->getCondition(), OrigCond);
+  // Check that setDefaultDest() is reverted.
+  auto *OrigDefaultDest = Switch->getDefaultDest();
+  auto *NewDefaultDest = Entry;
+  EXPECT_NE(NewDefaultDest, OrigDefaultDest);
+  Ctx.save();
+  Switch->setDefaultDest(NewDefaultDest);
+  EXPECT_EQ(Switch->getDefaultDest(), NewDefaultDest);
+  Ctx.revert();
+  EXPECT_EQ(Switch->getDefaultDest(), OrigDefaultDest);
+  // Check that setSuccessor() is reverted.
+  auto *OrigSucc = Switch->getSuccessor(0);
+  auto *NewSucc = Entry;
+  EXPECT_NE(NewSucc, OrigSucc);
+  Ctx.save();
+  Switch->setSuccessor(0, NewSucc);
+  EXPECT_EQ(Switch->getSuccessor(0), NewSucc);
+  Ctx.revert();
+  EXPECT_EQ(Switch->getSuccessor(0), OrigSucc);
+  // Check that reverting addCase() removes only the newly added case.
+  auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx);
+  auto *One = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 1, Ctx);
+  auto *FortyTwo = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx);
+  Ctx.save();
+  Switch->addCase(FortyTwo, Entry);
+  EXPECT_EQ(Switch->getNumCases(), 3u);
+  EXPECT_EQ(Switch->findCaseDest(Entry), FortyTwo);
+  EXPECT_EQ(Switch->findCaseValue(FortyTwo)->getCaseSuccessor(), Entry);
+  EXPECT_EQ(Switch->findCaseDest(BB0), Zero);
+  EXPECT_EQ(Switch->findCaseDest(BB1), One);
+  Ctx.revert();
+  EXPECT_EQ(Switch->getNumCases(), 2u);
+  EXPECT_EQ(Switch->findCaseDest(BB0), Zero);
+  EXPECT_EQ(Switch->findCaseDest(BB1), One);
+  // Check removeCase() (which may reorder the remaining cases) is reverted.
+  Ctx.save();
+  Switch->removeCase(Switch->findCaseValue(Zero));
+  EXPECT_EQ(Switch->getNumCases(), 1u);
+  EXPECT_EQ(Switch->findCaseDest(BB1), One);
+  Ctx.revert();
+  EXPECT_EQ(Switch->getNumCases(), 2u);
+  EXPECT_EQ(Switch->findCaseDest(BB0), Zero);
+  EXPECT_EQ(Switch->findCaseDest(BB1), One);
+}
+
 TEST_F(TrackerTest, AtomicRMWSetters) {
   parseIR(C, R"IR(
 define void @foo(ptr %ptr, i8 %arg) {



More information about the llvm-commits mailing list