[llvm] [SandboxIR] Implement SwitchInst (PR #104641)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 16 14:08:38 PDT 2024
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/104641
This patch implements sandboxir::SwitchInst mirroring llvm::SwitchInst.
>From b15b5f0aec42f8afc6349cf517c09b720fafa396 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 12 Aug 2024 15:11:54 -0700
Subject: [PATCH] [SandboxIR] Implement SwitchInst
This patch implements sandboxir::SwitchInst mirroring llvm::SwitchInst.
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 90 ++++++++++++
.../llvm/SandboxIR/SandboxIRValues.def | 1 +
llvm/include/llvm/SandboxIR/Tracker.h | 33 +++++
llvm/lib/SandboxIR/SandboxIR.cpp | 88 ++++++++++++
llvm/lib/SandboxIR/Tracker.cpp | 21 +++
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 128 ++++++++++++++++++
llvm/unittests/SandboxIR/TrackerTest.cpp | 80 +++++++++++
7 files changed, 441 insertions(+)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 3c568f4956857..a881bdf28f22c 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -131,6 +131,7 @@ class CastInst;
class PtrToIntInst;
class BitCastInst;
class AllocaInst;
+class SwitchInst;
class UnaryOperator;
class BinaryOperator;
class AtomicRMWInst;
@@ -253,6 +254,7 @@ class Value {
friend class InvokeInst; // For getting `Val`.
friend class CallBrInst; // For getting `Val`.
friend class GetElementPtrInst; // For getting `Val`.
+ friend class SwitchInst; // For getting `Val`.
friend class UnaryOperator; // For getting `Val`.
friend class BinaryOperator; // For getting `Val`.
friend class AtomicRMWInst; // For getting `Val`.
@@ -672,6 +674,7 @@ class Instruction : public sandboxir::User {
friend class InvokeInst; // For getTopmostLLVMInstruction().
friend class CallBrInst; // For getTopmostLLVMInstruction().
friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
+ friend class SwitchInst; // For getTopmostLLVMInstruction().
friend class UnaryOperator; // For getTopmostLLVMInstruction().
friend class BinaryOperator; // For getTopmostLLVMInstruction().
friend class AtomicRMWInst; // For getTopmostLLVMInstruction().
@@ -1477,6 +1480,91 @@ class GetElementPtrInst final
// TODO: Add missing member functions.
};
+class SwitchInst : public SingleLLVMInstructionImpl<llvm::SwitchInst> {
+public:
+ SwitchInst(llvm::SwitchInst *SI, Context &Ctx)
+ : SingleLLVMInstructionImpl(ClassID::Switch, Opcode::Switch, SI, Ctx) {}
+
+ static constexpr const unsigned DefaultPseudoIndex =
+ llvm::SwitchInst::DefaultPseudoIndex;
+
+ static SwitchInst *create(Value *V, BasicBlock *Dest, unsigned NumCases,
+ BasicBlock::iterator WhereIt, BasicBlock *WhereBB,
+ Context &Ctx, const Twine &Name = "");
+
+ Value *getCondition() const;
+ void setCondition(Value *V);
+ BasicBlock *getDefaultDest() const;
+ bool defaultDestUndefined() const {
+ return cast<llvm::SwitchInst>(Val)->defaultDestUndefined();
+ }
+ void setDefaultDest(BasicBlock *DefaultCase);
+ unsigned getNumCases() const {
+ return cast<llvm::SwitchInst>(Val)->getNumCases();
+ }
+
+ using CaseHandle =
+ llvm::SwitchInst::CaseHandleImpl<SwitchInst, ConstantInt, BasicBlock>;
+ using ConstCaseHandle =
+ llvm::SwitchInst::CaseHandleImpl<const SwitchInst, const ConstantInt,
+ const BasicBlock>;
+ using CaseIt = llvm::SwitchInst::CaseIteratorImpl<CaseHandle>;
+ using ConstCaseIt = llvm::SwitchInst::CaseIteratorImpl<ConstCaseHandle>;
+
+ /// Returns a read/write iterator that points to the first case in the
+ /// SwitchInst.
+ CaseIt case_begin() { return CaseIt(this, 0); }
+ ConstCaseIt case_begin() const { return ConstCaseIt(this, 0); }
+ /// Returns a read/write iterator that points one past the last case in the
+ /// SwitchInst.
+ CaseIt case_end() { return CaseIt(this, getNumCases()); }
+ ConstCaseIt case_end() const { return ConstCaseIt(this, getNumCases()); }
+ /// Iteration adapter for range-for loops.
+ iterator_range<CaseIt> cases() {
+ return make_range(case_begin(), case_end());
+ }
+ iterator_range<ConstCaseIt> cases() const {
+ return make_range(case_begin(), case_end());
+ }
+ CaseIt case_default() { return CaseIt(this, DefaultPseudoIndex); }
+ ConstCaseIt case_default() const {
+ return ConstCaseIt(this, DefaultPseudoIndex);
+ }
+ CaseIt findCaseValue(const ConstantInt *C) {
+ return CaseIt(
+ this,
+ const_cast<const SwitchInst *>(this)->findCaseValue(C)->getCaseIndex());
+ }
+ ConstCaseIt findCaseValue(const ConstantInt *C) const {
+ ConstCaseIt I = llvm::find_if(cases(), [C](const ConstCaseHandle &Case) {
+ return Case.getCaseValue() == C;
+ });
+ if (I != case_end())
+ return I;
+ return case_default();
+ }
+ ConstantInt *findCaseDest(BasicBlock *BB);
+
+ void addCase(ConstantInt *OnVal, BasicBlock *Dest);
+ /// This method removes the specified case and its successor from the switch
+ /// instruction. Note that this operation may reorder the remaining cases at
+ /// the index of \p It and above.
+ /// Note:
+ /// This action invalidates iterators for all cases following the one removed,
+ /// including the case_end() iterator. It returns an iterator for the next
+ /// case.
+ CaseIt removeCase(CaseIt It);
+
+ unsigned getNumSuccessors() const {
+ return cast<llvm::SwitchInst>(Val)->getNumSuccessors();
+ }
+ BasicBlock *getSuccessor(unsigned Idx) const;
+ void setSuccessor(unsigned Idx, BasicBlock *NewSucc);
+ static bool classof(const Value *From) {
+ return From->getSubclassID() == ClassID::Switch;
+ }
+};
+
class UnaryOperator : public UnaryInstruction {
static Opcode getUnaryOpcode(llvm::Instruction::UnaryOps UnOp) {
switch (UnOp) {
@@ -2113,6 +2201,8 @@ class Context {
friend CallBrInst; // For createCallBrInst()
GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
friend GetElementPtrInst; // For createGetElementPtrInst()
+ SwitchInst *createSwitchInst(llvm::SwitchInst *I);
+ friend SwitchInst; // For createSwitchInst()
UnaryOperator *createUnaryOperator(llvm::UnaryOperator *I);
friend UnaryOperator; // For createUnaryOperator()
BinaryOperator *createBinaryOperator(llvm::BinaryOperator *I);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index cba5b69ebf121..2b9b44c529b30 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -46,6 +46,7 @@ DEF_INSTR(Call, OP(Call), CallInst)
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
+DEF_INSTR(Switch, OP(Switch), SwitchInst)
DEF_INSTR(UnOp, OPCODES( \
OP(FNeg) \
), UnaryOperator)
diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index 2b4dfafa3b0e5..22d910e1d73e7 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -59,6 +59,8 @@ class StoreInst;
class Instruction;
class Tracker;
class AllocaInst;
+class SwitchInst;
+class ConstantInt;
/// The base class for IR Change classes.
class IRChangeBase {
@@ -261,6 +263,37 @@ class GenericSetterWithIdx final : public IRChangeBase {
#endif
};
+class SwitchAddCase : public IRChangeBase {
+ SwitchInst *Switch;
+ ConstantInt *Val;
+
+public:
+ SwitchAddCase(SwitchInst *Switch, ConstantInt *Val)
+ : Switch(Switch), Val(Val) {}
+ void revert(Tracker &Tracker) final;
+ void accept() final {}
+#ifndef NDEBUG
+ void dump(raw_ostream &OS) const final { OS << "SwitchAddCase"; }
+ LLVM_DUMP_METHOD void dump() const final;
+#endif // NDEBUG
+};
+
+class SwitchRemoveCase : public IRChangeBase {
+ SwitchInst *Switch;
+ ConstantInt *Val;
+ BasicBlock *Dest;
+
+public:
+ SwitchRemoveCase(SwitchInst *Switch, ConstantInt *Val, BasicBlock *Dest)
+ : Switch(Switch), Val(Val), Dest(Dest) {}
+ void revert(Tracker &Tracker) final;
+ void accept() final {}
+#ifndef NDEBUG
+ void dump(raw_ostream &OS) const final { OS << "SwitchRemoveCase"; }
+ LLVM_DUMP_METHOD void dump() const final;
+#endif // NDEBUG
+};
+
class MoveInstr : public IRChangeBase {
/// The instruction that moved.
Instruction *MovedI;
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 6e8c83528bded..c243df7fc864e 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1236,6 +1236,84 @@ static llvm::Instruction::UnaryOps getLLVMUnaryOp(Instruction::Opcode Opc) {
}
}
+SwitchInst *SwitchInst::create(Value *V, BasicBlock *Dest, unsigned NumCases,
+ BasicBlock::iterator WhereIt,
+ BasicBlock *WhereBB, Context &Ctx,
+ const Twine &Name) {
+ auto &Builder = Ctx.getLLVMIRBuilder();
+ if (WhereIt != WhereBB->end())
+ Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+ else
+ Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+ llvm::SwitchInst *LLVMSwitch =
+ Builder.CreateSwitch(V->Val, cast<llvm::BasicBlock>(Dest->Val), NumCases);
+ return Ctx.createSwitchInst(LLVMSwitch);
+}
+
+Value *SwitchInst::getCondition() const {
+ return Ctx.getValue(cast<llvm::SwitchInst>(Val)->getCondition());
+}
+
+void SwitchInst::setCondition(Value *V) {
+ Ctx.getTracker()
+ .emplaceIfTracking<
+ GenericSetter<&SwitchInst::getCondition, &SwitchInst::setCondition>>(
+ this);
+ cast<llvm::SwitchInst>(Val)->setCondition(V->Val);
+}
+
+BasicBlock *SwitchInst::getDefaultDest() const {
+ return cast<BasicBlock>(
+ Ctx.getValue(cast<llvm::SwitchInst>(Val)->getDefaultDest()));
+}
+
+void SwitchInst::setDefaultDest(BasicBlock *DefaultCase) {
+ Ctx.getTracker()
+ .emplaceIfTracking<GenericSetter<&SwitchInst::getDefaultDest,
+ &SwitchInst::setDefaultDest>>(this);
+ cast<llvm::SwitchInst>(Val)->setDefaultDest(
+ cast<llvm::BasicBlock>(DefaultCase->Val));
+}
+ConstantInt *SwitchInst::findCaseDest(BasicBlock *BB) {
+ auto *LLVMC = cast<llvm::SwitchInst>(Val)->findCaseDest(
+ cast<llvm::BasicBlock>(BB->Val));
+ return LLVMC != nullptr ? cast<ConstantInt>(Ctx.getValue(LLVMC)) : nullptr;
+}
+
+void SwitchInst::addCase(ConstantInt *OnVal, BasicBlock *Dest) {
+ Ctx.getTracker().emplaceIfTracking<SwitchAddCase>(this, OnVal);
+ // Recorded above (if tracking); now add the case to the LLVM switch.
+ cast<llvm::SwitchInst>(Val)->addCase(cast<llvm::ConstantInt>(OnVal->Val),
+ cast<llvm::BasicBlock>(Dest->Val));
+}
+
+SwitchInst::CaseIt SwitchInst::removeCase(CaseIt It) {
+ auto &Case = *It;
+ Ctx.getTracker().emplaceIfTracking<SwitchRemoveCase>(
+ this, Case.getCaseValue(), Case.getCaseSuccessor());
+
+ auto *LLVMSwitch = cast<llvm::SwitchInst>(Val);
+ unsigned CaseNum = It - case_begin();
+ llvm::SwitchInst::CaseIt LLVMIt(LLVMSwitch, CaseNum);
+ auto LLVMCaseIt = LLVMSwitch->removeCase(LLVMIt);
+ unsigned Num = LLVMCaseIt - LLVMSwitch->case_begin();
+ return CaseIt(this, Num);
+}
+
+BasicBlock *SwitchInst::getSuccessor(unsigned Idx) const {
+ return cast<BasicBlock>(
+ Ctx.getValue(cast<llvm::SwitchInst>(Val)->getSuccessor(Idx)));
+}
+
+void SwitchInst::setSuccessor(unsigned Idx, BasicBlock *NewSucc) {
+ Ctx.getTracker()
+ .emplaceIfTracking<GenericSetterWithIdx<&SwitchInst::getSuccessor,
+ &SwitchInst::setSuccessor>>(this,
+ Idx);
+ cast<llvm::SwitchInst>(Val)->setSuccessor(
+ Idx, cast<llvm::BasicBlock>(NewSucc->Val));
+}
+
Value *UnaryOperator::create(Instruction::Opcode Op, Value *OpV,
BBIterator WhereIt, BasicBlock *WhereBB,
Context &Ctx, const Twine &Name) {
@@ -1875,6 +1953,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
new GetElementPtrInst(LLVMGEP, *this));
return It->second.get();
}
+ case llvm::Instruction::Switch: {
+ auto *LLVMSwitchInst = cast<llvm::SwitchInst>(LLVMV);
+ It->second =
+ std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this));
+ return It->second.get();
+ }
case llvm::Instruction::FNeg: {
auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
It->second = std::unique_ptr<UnaryOperator>(
@@ -2033,6 +2117,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
}
+SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
+ auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
+ return cast<SwitchInst>(registerValue(std::move(NewPtr)));
+}
UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
return cast<UnaryOperator>(registerValue(std::move(NewPtr)));
diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index baf218aa41c1c..8faff72bc3e2f 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -160,6 +160,27 @@ void RemoveFromParent::dump() const {
}
#endif
+void SwitchRemoveCase::revert(Tracker &Tracker) { Switch->addCase(Val, Dest); }
+
+#ifndef NDEBUG
+void SwitchRemoveCase::dump() const {
+ dump(dbgs());
+ dbgs() << "\n";
+}
+#endif // NDEBUG
+
+void SwitchAddCase::revert(Tracker &Tracker) {
+ auto It = Switch->findCaseValue(Val);
+ Switch->removeCase(It);
+}
+
+#ifndef NDEBUG
+void SwitchAddCase::dump() const {
+ dump(dbgs());
+ dbgs() << "\n";
+}
+#endif // NDEBUG
+
MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) {
if (auto *NextI = MovedI->getNextNode())
NextInstrOrBB = NextI;
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 6c0f1b19243c2..0b2635861c285 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1643,6 +1643,134 @@ define void @foo(i32 %arg, float %farg) {
EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
}
+TEST_F(SandboxIRTest, SwitchInst) {
+ parseIR(C, R"IR(
+define void @foo(i32 %cond0, i32 %cond1) {
+ entry:
+ switch i32 %cond0, label %default [ i32 0, label %bb0
+ i32 1, label %bb1 ]
+ bb0:
+ ret void
+ bb1:
+ ret void
+ default:
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ auto *LLVMEntry = getBasicBlockByName(LLVMF, "entry");
+ auto *LLVMSwitch = cast<llvm::SwitchInst>(&*LLVMEntry->begin());
+
+ sandboxir::Context Ctx(C);
+ auto &F = *Ctx.createFunction(&LLVMF);
+ auto *Cond1 = F.getArg(1);
+ auto *Entry = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMEntry));
+ auto *Switch = cast<sandboxir::SwitchInst>(&*Entry->begin());
+ auto *BB0 = cast<sandboxir::BasicBlock>(
+ Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
+ auto *BB1 = cast<sandboxir::BasicBlock>(
+ Ctx.getValue(getBasicBlockByName(LLVMF, "bb1")));
+ auto *Default = cast<sandboxir::BasicBlock>(
+ Ctx.getValue(getBasicBlockByName(LLVMF, "default")));
+
+ // Check getCondition().
+ EXPECT_EQ(Switch->getCondition(), Ctx.getValue(LLVMSwitch->getCondition()));
+ // Check setCondition().
+ auto *OrigCond = Switch->getCondition();
+ auto *NewCond = Cond1;
+ EXPECT_NE(NewCond, OrigCond);
+ Switch->setCondition(NewCond);
+ EXPECT_EQ(Switch->getCondition(), NewCond);
+ Switch->setCondition(OrigCond);
+ EXPECT_EQ(Switch->getCondition(), OrigCond);
+ // Check getDefaultDest().
+ EXPECT_EQ(Switch->getDefaultDest(),
+ Ctx.getValue(LLVMSwitch->getDefaultDest()));
+ EXPECT_EQ(Switch->getDefaultDest(), Default);
+ // Check defaultDestUndefined().
+ EXPECT_EQ(Switch->defaultDestUndefined(), LLVMSwitch->defaultDestUndefined());
+ // Check setDefaultDest().
+ auto *OrigDefaultDest = Switch->getDefaultDest();
+ auto *NewDefaultDest = Entry;
+ EXPECT_NE(NewDefaultDest, OrigDefaultDest);
+ Switch->setDefaultDest(NewDefaultDest);
+ EXPECT_EQ(Switch->getDefaultDest(), NewDefaultDest);
+ Switch->setDefaultDest(OrigDefaultDest);
+ EXPECT_EQ(Switch->getDefaultDest(), OrigDefaultDest);
+ // Check getNumCases().
+ EXPECT_EQ(Switch->getNumCases(), LLVMSwitch->getNumCases());
+ // Check getNumSuccessors().
+ EXPECT_EQ(Switch->getNumSuccessors(), LLVMSwitch->getNumSuccessors());
+ // Check getSuccessor().
+ for (auto SuccIdx : seq<unsigned>(0, Switch->getNumSuccessors()))
+ EXPECT_EQ(Switch->getSuccessor(SuccIdx),
+ Ctx.getValue(LLVMSwitch->getSuccessor(SuccIdx)));
+ // Check setSuccessor().
+ auto *OrigSucc = Switch->getSuccessor(0);
+ auto *NewSucc = Entry;
+ EXPECT_NE(NewSucc, OrigSucc);
+ Switch->setSuccessor(0, NewSucc);
+ EXPECT_EQ(Switch->getSuccessor(0), NewSucc);
+ Switch->setSuccessor(0, OrigSucc);
+ EXPECT_EQ(Switch->getSuccessor(0), OrigSucc);
+ // Check case_begin(), case_end(), CaseIt.
+ auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx);
+ auto *One = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 1, Ctx);
+ auto CaseIt = Switch->case_begin();
+ {
+ const sandboxir::SwitchInst::CaseHandle &Case = *CaseIt++;
+ EXPECT_EQ(Case.getCaseValue(), Zero);
+ EXPECT_EQ(Case.getCaseSuccessor(), BB0);
+ EXPECT_EQ(Case.getCaseIndex(), 0u);
+ EXPECT_EQ(Case.getSuccessorIndex(), 1u);
+ }
+ {
+ const sandboxir::SwitchInst::CaseHandle &Case = *CaseIt++;
+ EXPECT_EQ(Case.getCaseValue(), One);
+ EXPECT_EQ(Case.getCaseSuccessor(), BB1);
+ EXPECT_EQ(Case.getCaseIndex(), 1u);
+ EXPECT_EQ(Case.getSuccessorIndex(), 2u);
+ }
+ EXPECT_EQ(CaseIt, Switch->case_end());
+ // Check cases().
+ unsigned CntCase = 0;
+ for (auto &Case : Switch->cases()) {
+ EXPECT_EQ(Case.getCaseIndex(), CntCase);
+ ++CntCase;
+ }
+ EXPECT_EQ(CntCase, 2u);
+ // Check case_default().
+ auto &CaseDefault = *Switch->case_default();
+ EXPECT_EQ(CaseDefault.getCaseSuccessor(), Default);
+ EXPECT_EQ(CaseDefault.getCaseIndex(),
+ sandboxir::SwitchInst::DefaultPseudoIndex);
+ // Check findCaseValue().
+ EXPECT_EQ(Switch->findCaseValue(Zero)->getCaseIndex(), 0u);
+ EXPECT_EQ(Switch->findCaseValue(One)->getCaseIndex(), 1u);
+ // Check findCaseDest().
+ EXPECT_EQ(Switch->findCaseDest(BB0), Zero);
+ EXPECT_EQ(Switch->findCaseDest(BB1), One);
+ EXPECT_EQ(Switch->findCaseDest(Entry), nullptr);
+ // Check addCase().
+ auto *Two = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 2, Ctx);
+ Switch->addCase(Two, Entry);
+ auto CaseTwoIt = Switch->findCaseValue(Two);
+ auto &CaseTwo = *CaseTwoIt;
+ EXPECT_EQ(CaseTwo.getCaseValue(), Two);
+ EXPECT_EQ(CaseTwo.getCaseSuccessor(), Entry);
+ EXPECT_EQ(Switch->getNumCases(), 3u);
+ // Check removeCase().
+ auto RemovedIt = Switch->removeCase(CaseTwoIt);
+ EXPECT_EQ(RemovedIt, Switch->case_end());
+ EXPECT_EQ(Switch->getNumCases(), 2u);
+ // Check create().
+ auto NewSwitch = sandboxir::SwitchInst::create(
+ Cond1, Default, 1, Default->begin(), Default, Ctx, "NewSwitch");
+ EXPECT_TRUE(isa<sandboxir::SwitchInst>(NewSwitch));
+ EXPECT_EQ(NewSwitch->getCondition(), Cond1);
+ EXPECT_EQ(NewSwitch->getDefaultDest(), Default);
+}
+
TEST_F(SandboxIRTest, UnaryOperator) {
parseIR(C, R"IR(
define void @foo(float %arg0) {
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index 380c90e7f0f65..b6770b237853c 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -644,6 +644,86 @@ define void @foo(i8 %arg) {
EXPECT_EQ(Invoke->getSuccessor(1), ExceptionBB);
}
+TEST_F(TrackerTest, SwitchInstSetters) {
+ parseIR(C, R"IR(
+define void @foo(i32 %cond0, i32 %cond1) {
+ entry:
+ switch i32 %cond0, label %default [ i32 0, label %bb0
+ i32 1, label %bb1 ]
+ bb0:
+ ret void
+ bb1:
+ ret void
+ default:
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ auto *LLVMEntry = getBasicBlockByName(LLVMF, "entry");
+
+ sandboxir::Context Ctx(C);
+ auto &F = *Ctx.createFunction(&LLVMF);
+ auto *Cond1 = F.getArg(1);
+ auto *Entry = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMEntry));
+ auto *BB0 = cast<sandboxir::BasicBlock>(
+ Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
+ auto *BB1 = cast<sandboxir::BasicBlock>(
+ Ctx.getValue(getBasicBlockByName(LLVMF, "bb1")));
+ auto *Switch = cast<sandboxir::SwitchInst>(&*Entry->begin());
+
+ // Check setCondition().
+ auto *OrigCond = Switch->getCondition();
+ auto *NewCond = Cond1;
+ EXPECT_NE(NewCond, OrigCond);
+ Ctx.save();
+ Switch->setCondition(NewCond);
+ EXPECT_EQ(Switch->getCondition(), NewCond);
+ Ctx.revert();
+ EXPECT_EQ(Switch->getCondition(), OrigCond);
+ // Check setDefaultDest().
+ auto *OrigDefaultDest = Switch->getDefaultDest();
+ auto *NewDefaultDest = Entry;
+ EXPECT_NE(NewDefaultDest, OrigDefaultDest);
+ Ctx.save();
+ Switch->setDefaultDest(NewDefaultDest);
+ EXPECT_EQ(Switch->getDefaultDest(), NewDefaultDest);
+ Ctx.revert();
+ EXPECT_EQ(Switch->getDefaultDest(), OrigDefaultDest);
+ // Check setSuccessor().
+ auto *OrigSucc = Switch->getSuccessor(0);
+ auto *NewSucc = Entry;
+ EXPECT_NE(NewSucc, OrigSucc);
+ Ctx.save();
+ Switch->setSuccessor(0, NewSucc);
+ EXPECT_EQ(Switch->getSuccessor(0), NewSucc);
+ Ctx.revert();
+ EXPECT_EQ(Switch->getSuccessor(0), OrigSucc);
+ // Check addCase().
+ auto *Zero = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 0, Ctx);
+ auto *One = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 1, Ctx);
+ auto *FortyTwo = sandboxir::ConstantInt::get(Type::getInt32Ty(C), 42, Ctx);
+ Ctx.save();
+ Switch->addCase(FortyTwo, Entry);
+ EXPECT_EQ(Switch->getNumCases(), 3u);
+ EXPECT_EQ(Switch->findCaseDest(Entry), FortyTwo);
+ EXPECT_EQ(Switch->findCaseValue(FortyTwo)->getCaseSuccessor(), Entry);
+ EXPECT_EQ(Switch->findCaseDest(BB0), Zero);
+ EXPECT_EQ(Switch->findCaseDest(BB1), One);
+ Ctx.revert();
+ EXPECT_EQ(Switch->getNumCases(), 2u);
+ EXPECT_EQ(Switch->findCaseDest(BB0), Zero);
+ EXPECT_EQ(Switch->findCaseDest(BB1), One);
+ // Check removeCase().
+ Ctx.save();
+ Switch->removeCase(Switch->findCaseValue(Zero));
+ EXPECT_EQ(Switch->getNumCases(), 1u);
+ EXPECT_EQ(Switch->findCaseDest(BB1), One);
+ Ctx.revert();
+ EXPECT_EQ(Switch->getNumCases(), 2u);
+ EXPECT_EQ(Switch->findCaseDest(BB0), Zero);
+ EXPECT_EQ(Switch->findCaseDest(BB1), One);
+}
+
TEST_F(TrackerTest, AtomicRMWSetters) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %arg) {
More information about the llvm-commits
mailing list