[llvm] 6e8c970 - [SandboxIR] Implement CatchSwitchInst (#104652)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 20 12:30:07 PDT 2024


Author: vporpo
Date: 2024-08-20T12:30:04-07:00
New Revision: 6e8c97035ca32c6b163f8735a340e15e011ec0c8

URL: https://github.com/llvm/llvm-project/commit/6e8c97035ca32c6b163f8735a340e15e011ec0c8
DIFF: https://github.com/llvm/llvm-project/commit/6e8c97035ca32c6b163f8735a340e15e011ec0c8.diff

LOG: [SandboxIR] Implement CatchSwitchInst (#104652)

This patch implements sandboxir::CatchSwitchInst mirroring
llvm::CatchSwitchInst.

Added: 
    

Modified: 
    llvm/include/llvm/SandboxIR/SandboxIR.h
    llvm/include/llvm/SandboxIR/SandboxIRValues.def
    llvm/include/llvm/SandboxIR/Tracker.h
    llvm/lib/SandboxIR/SandboxIR.cpp
    llvm/lib/SandboxIR/Tracker.cpp
    llvm/unittests/SandboxIR/SandboxIRTest.cpp
    llvm/unittests/SandboxIR/TrackerTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index a881bdf28f22c2..ca71566091bf82 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -131,6 +131,7 @@ class CastInst;
 class PtrToIntInst;
 class BitCastInst;
 class AllocaInst;
+class CatchSwitchInst;
 class SwitchInst;
 class UnaryOperator;
 class BinaryOperator;
@@ -254,6 +255,7 @@ class Value {
   friend class InvokeInst;         // For getting `Val`.
   friend class CallBrInst;         // For getting `Val`.
   friend class GetElementPtrInst;  // For getting `Val`.
+  friend class CatchSwitchInst;    // For getting `Val`.
   friend class SwitchInst;         // For getting `Val`.
   friend class UnaryOperator;      // For getting `Val`.
   friend class BinaryOperator;     // For getting `Val`.
@@ -263,6 +265,7 @@ class Value {
   friend class CastInst;           // For getting `Val`.
   friend class PHINode;            // For getting `Val`.
   friend class UnreachableInst;    // For getting `Val`.
+  friend class CatchSwitchAddHandler; // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -674,6 +677,7 @@ class Instruction : public sandboxir::User {
   friend class InvokeInst;         // For getTopmostLLVMInstruction().
   friend class CallBrInst;         // For getTopmostLLVMInstruction().
   friend class GetElementPtrInst;  // For getTopmostLLVMInstruction().
+  friend class CatchSwitchInst;    // For getTopmostLLVMInstruction().
   friend class SwitchInst;         // For getTopmostLLVMInstruction().
   friend class UnaryOperator;      // For getTopmostLLVMInstruction().
   friend class BinaryOperator;     // For getTopmostLLVMInstruction().
@@ -1480,6 +1484,116 @@ class GetElementPtrInst final
   // TODO: Add missing member functions.
 };
 
+/// Sandbox IR counterpart of llvm::CatchSwitchInst.
+///
+/// Operand layout (mirroring LLVM): operand 0 is the parent pad, followed by
+/// the unwind destination if hasUnwindDest(), followed by the handler blocks.
+class CatchSwitchInst
+    : public SingleLLVMInstructionImpl<llvm::CatchSwitchInst> {
+public:
+  CatchSwitchInst(llvm::CatchSwitchInst *CSI, Context &Ctx)
+      : SingleLLVMInstructionImpl(ClassID::CatchSwitch, Opcode::CatchSwitch,
+                                  CSI, Ctx) {}
+
+  /// Creates a new catchswitch before \p WhereIt, or at the end of \p WhereBB
+  /// if \p WhereIt is WhereBB->end(). The remaining arguments are forwarded to
+  /// llvm::IRBuilder::CreateCatchSwitch().
+  static CatchSwitchInst *create(Value *ParentPad, BasicBlock *UnwindBB,
+                                 unsigned NumHandlers, BBIterator WhereIt,
+                                 BasicBlock *WhereBB, Context &Ctx,
+                                 const Twine &Name = "");
+
+  Value *getParentPad() const;
+  /// Replaces the parent pad. The change is tracked.
+  void setParentPad(Value *ParentPad);
+
+  bool hasUnwindDest() const {
+    return cast<llvm::CatchSwitchInst>(Val)->hasUnwindDest();
+  }
+  bool unwindsToCaller() const {
+    return cast<llvm::CatchSwitchInst>(Val)->unwindsToCaller();
+  }
+  /// \returns the unwind destination, or null if this unwinds to the caller.
+  BasicBlock *getUnwindDest() const;
+  /// Replaces the unwind destination. The change is tracked. The underlying
+  /// llvm::CatchSwitchInst::setUnwindDest() requires hasUnwindDest().
+  void setUnwindDest(BasicBlock *UnwindDest);
+
+  unsigned getNumHandlers() const {
+    return cast<llvm::CatchSwitchInst>(Val)->getNumHandlers();
+  }
+
+private:
+  static BasicBlock *handler_helper(Value *V) { return cast<BasicBlock>(V); }
+  static const BasicBlock *handler_helper(const Value *V) {
+    return cast<BasicBlock>(V);
+  }
+
+public:
+  using DerefFnTy = BasicBlock *(*)(Value *);
+  using handler_iterator = mapped_iterator<op_iterator, DerefFnTy>;
+  using handler_range = iterator_range<handler_iterator>;
+  using ConstDerefFnTy = const BasicBlock *(*)(const Value *);
+  using const_handler_iterator =
+      mapped_iterator<const_op_iterator, ConstDerefFnTy>;
+  using const_handler_range = iterator_range<const_handler_iterator>;
+
+  // Handlers start after the parent pad and, if present, the unwind
+  // destination. Step with ++ rather than op_begin() + N (which may
+  // bounds-check its result) so that an empty handler list yields op_end().
+  handler_iterator handler_begin() {
+    op_iterator It = op_begin();
+    ++It; // Skip the parent pad.
+    if (hasUnwindDest())
+      ++It; // Skip the unwind destination.
+    return handler_iterator(It, DerefFnTy(handler_helper));
+  }
+  const_handler_iterator handler_begin() const {
+    const_op_iterator It = op_begin();
+    ++It; // Skip the parent pad.
+    if (hasUnwindDest())
+      ++It; // Skip the unwind destination.
+    return const_handler_iterator(It, ConstDerefFnTy(handler_helper));
+  }
+  handler_iterator handler_end() {
+    return handler_iterator(op_end(), DerefFnTy(handler_helper));
+  }
+  const_handler_iterator handler_end() const {
+    return const_handler_iterator(op_end(), ConstDerefFnTy(handler_helper));
+  }
+  handler_range handlers() {
+    return make_range(handler_begin(), handler_end());
+  }
+  const_handler_range handlers() const {
+    return make_range(handler_begin(), handler_end());
+  }
+
+  /// Appends \p Dest as a new handler. The change is tracked.
+  void addHandler(BasicBlock *Dest);
+
+  // TODO: removeHandler() cannot be reverted because there is no equivalent
+  // addHandler() with a handler_iterator to specify the position. So we can't
+  // implement it for now.
+
+  /// Successors are all operands after the parent pad: the unwind destination
+  /// (if any) followed by the handlers.
+  unsigned getNumSuccessors() const { return getNumOperands() - 1; }
+  BasicBlock *getSuccessor(unsigned Idx) const {
+    assert(Idx < getNumSuccessors() &&
+           "Successor # out of range for catchswitch!");
+    return cast<BasicBlock>(getOperand(Idx + 1));
+  }
+  void setSuccessor(unsigned Idx, BasicBlock *NewSucc) {
+    assert(Idx < getNumSuccessors() &&
+           "Successor # out of range for catchswitch!");
+    setOperand(Idx + 1, NewSucc);
+  }
+
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::CatchSwitch;
+  }
+};
+
 class SwitchInst : public SingleLLVMInstructionImpl<llvm::SwitchInst> {
 public:
   SwitchInst(llvm::SwitchInst *SI, Context &Ctx)
@@ -2201,6 +2315,8 @@ class Context {
   friend CallBrInst; // For createCallBrInst()
   GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
   friend GetElementPtrInst; // For createGetElementPtrInst()
+  CatchSwitchInst *createCatchSwitchInst(llvm::CatchSwitchInst *I);
+  friend CatchSwitchInst; // For createCatchSwitchInst()
   SwitchInst *createSwitchInst(llvm::SwitchInst *I);
   friend SwitchInst; // For createSwitchInst()
   UnaryOperator *createUnaryOperator(llvm::UnaryOperator *I);

diff  --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 2b9b44c529b30d..402b6f3324a222 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -46,6 +46,7 @@ DEF_INSTR(Call,          OP(Call),          CallInst)
 DEF_INSTR(Invoke,        OP(Invoke),        InvokeInst)
 DEF_INSTR(CallBr,        OP(CallBr),        CallBrInst)
 DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
+DEF_INSTR(CatchSwitch,   OP(CatchSwitch),   CatchSwitchInst)
 DEF_INSTR(Switch,        OP(Switch),        SwitchInst)
 DEF_INSTR(UnOp,          OPCODES( \
                          OP(FNeg) \

diff  --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index 22d910e1d73e72..6f205ae2a075c6 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -59,6 +59,7 @@ class StoreInst;
 class Instruction;
 class Tracker;
 class AllocaInst;
+class CatchSwitchInst;
 class SwitchInst;
 class ConstantInt;
 
@@ -263,6 +264,28 @@ class GenericSetterWithIdx final : public IRChangeBase {
 #endif
 };
 
+/// Tracks a CatchSwitchInst::addHandler() call so that it can be reverted.
+/// It must be created before the handler is added: addHandler() appends, so
+/// the handler count at construction time is the new handler's index.
+class CatchSwitchAddHandler final : public IRChangeBase {
+  CatchSwitchInst *CSI;
+  /// Index of the handler appended by the tracked addHandler() call.
+  unsigned HandlerIdx;
+
+public:
+  explicit CatchSwitchAddHandler(CatchSwitchInst *CSI);
+  /// Removes the handler at HandlerIdx.
+  void revert(Tracker &Tracker) final;
+  void accept() final {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final { OS << "CatchSwitchAddHandler"; }
+  LLVM_DUMP_METHOD void dump() const final {
+    dump(dbgs());
+    dbgs() << "\n";
+  }
+#endif // NDEBUG
+};
+
 class SwitchAddCase : public IRChangeBase {
   SwitchInst *Switch;
   ConstantInt *Val;

diff  --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index c243df7fc864ee..5b170cee20c940 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1236,6 +1236,62 @@ static llvm::Instruction::UnaryOps getLLVMUnaryOp(Instruction::Opcode Opc) {
   }
 }
 
+CatchSwitchInst *CatchSwitchInst::create(Value *ParentPad, BasicBlock *UnwindBB,
+                                         unsigned NumHandlers,
+                                         BBIterator WhereIt,
+                                         BasicBlock *WhereBB, Context &Ctx,
+                                         const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  if (WhereIt != WhereBB->end())
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  else
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  // A null UnwindBB requests a catchswitch that unwinds to the caller, which
+  // IRBuilder::CreateCatchSwitch() accepts, so don't dereference it.
+  llvm::BasicBlock *LLVMUnwindBB =
+      UnwindBB != nullptr ? cast<llvm::BasicBlock>(UnwindBB->Val) : nullptr;
+  llvm::CatchSwitchInst *LLVMCSI = Builder.CreateCatchSwitch(
+      ParentPad->Val, LLVMUnwindBB, NumHandlers, Name);
+  return Ctx.createCatchSwitchInst(LLVMCSI);
+}
+
+Value *CatchSwitchInst::getParentPad() const {
+  return Ctx.getValue(cast<llvm::CatchSwitchInst>(Val)->getParentPad());
+}
+
+void CatchSwitchInst::setParentPad(Value *ParentPad) {
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&CatchSwitchInst::getParentPad,
+                                       &CatchSwitchInst::setParentPad>>(this);
+  cast<llvm::CatchSwitchInst>(Val)->setParentPad(ParentPad->Val);
+}
+
+BasicBlock *CatchSwitchInst::getUnwindDest() const {
+  // Returns null when the catchswitch unwinds to the caller.
+  return cast_or_null<BasicBlock>(
+      Ctx.getValue(cast<llvm::CatchSwitchInst>(Val)->getUnwindDest()));
+}
+
+void CatchSwitchInst::setUnwindDest(BasicBlock *UnwindDest) {
+  // Check before recording the change so that an invalid call fails here
+  // rather than after a bogus entry was added to the tracker.
+  assert(UnwindDest != nullptr && hasUnwindDest() &&
+         "Can only replace an existing unwind destination!");
+  Ctx.getTracker()
+      .emplaceIfTracking<GenericSetter<&CatchSwitchInst::getUnwindDest,
+                                       &CatchSwitchInst::setUnwindDest>>(this);
+  cast<llvm::CatchSwitchInst>(Val)->setUnwindDest(
+      cast<llvm::BasicBlock>(UnwindDest->Val));
+}
+
+void CatchSwitchInst::addHandler(BasicBlock *Dest) {
+  // Record first: CatchSwitchAddHandler reads the current handler count, which
+  // is the index that the appended handler will occupy.
+  Ctx.getTracker().emplaceIfTracking<CatchSwitchAddHandler>(this);
+  cast<llvm::CatchSwitchInst>(Val)->addHandler(
+      cast<llvm::BasicBlock>(Dest->Val));
+}
+
 SwitchInst *SwitchInst::create(Value *V, BasicBlock *Dest, unsigned NumCases,
                                BasicBlock::iterator WhereIt,
                                BasicBlock *WhereBB, Context &Ctx,
@@ -1953,6 +1998,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
         new GetElementPtrInst(LLVMGEP, *this));
     return It->second.get();
   }
+  case llvm::Instruction::CatchSwitch: {
+    // Wrap the LLVM catchswitch in a freshly allocated sandboxir object.
+    auto *LLVMCSI = cast<llvm::CatchSwitchInst>(LLVMV);
+    It->second = std::make_unique<CatchSwitchInst>(LLVMCSI, *this);
+    return It->second.get();
+  }
   case llvm::Instruction::Switch: {
     auto *LLVMSwitchInst = cast<llvm::SwitchInst>(LLVMV);
     It->second =
@@ -2117,6 +2168,10 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
       std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
   return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
 }
+CatchSwitchInst *Context::createCatchSwitchInst(llvm::CatchSwitchInst *I) {
+  Value *V = registerValue(std::make_unique<CatchSwitchInst>(I, *this));
+  return cast<CatchSwitchInst>(V);
+}
 SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
   auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
   return cast<SwitchInst>(registerValue(std::move(NewPtr)));

diff  --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index 8faff72bc3e2fb..38a1c03556650e 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -160,6 +160,20 @@ void RemoveFromParent::dump() const {
 }
 #endif
 
+// The tracked addHandler() appends the new handler, so recording the handler
+// count here (before the handler is added) gives the index to remove.
+CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst *CSI)
+    : CSI(CSI), HandlerIdx(CSI->getNumHandlers()) {}
+
+void CatchSwitchAddHandler::revert(Tracker &Tracker) {
+  // TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler()
+  // once it gets implemented.
+  auto *LLVMCSI = cast<llvm::CatchSwitchInst>(CSI->Val);
+  assert(HandlerIdx < LLVMCSI->getNumHandlers() &&
+         "The handler added by addHandler() is missing!");
+  LLVMCSI->removeHandler(LLVMCSI->handler_begin() + HandlerIdx);
+}
+
 void SwitchRemoveCase::revert(Tracker &Tracker) { Switch->addCase(Val, Dest); }
 
 #ifndef NDEBUG

diff  --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 1d6a26728c9c56..712865fd07cd7b 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1643,6 +1643,115 @@ define void @foo(i32 %arg, float %farg) {
   EXPECT_FALSE(FAdd->getFastMathFlags() != LLVMFAdd->getFastMathFlags());
 }
 
+TEST_F(SandboxIRTest, CatchSwitchInst) {
+  parseIR(C, R"IR(
+define void @foo(i32 %cond0, i32 %cond1) {
+  bb0:
+    %cs0 = catchswitch within none [label %handler0, label %handler1] unwind to caller
+  bb1:
+    %cs1 = catchswitch within %cs0 [label %handler0, label %handler1] unwind label %cleanup
+  handler0:
+    ret void
+  handler1:
+    ret void
+  cleanup:
+    ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  auto *LLVMBB0 = getBasicBlockByName(LLVMF, "bb0");
+  auto *LLVMBB1 = getBasicBlockByName(LLVMF, "bb1");
+  auto *LLVMHandler0 = getBasicBlockByName(LLVMF, "handler0");
+  auto *LLVMHandler1 = getBasicBlockByName(LLVMF, "handler1");
+  auto *LLVMCleanup = getBasicBlockByName(LLVMF, "cleanup");
+  auto *LLVMCS0 = cast<llvm::CatchSwitchInst>(&*LLVMBB0->begin());
+  auto *LLVMCS1 = cast<llvm::CatchSwitchInst>(&*LLVMBB1->begin());
+
+  sandboxir::Context Ctx(C);
+  [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+  auto *BB0 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB0));
+  auto *BB1 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB1));
+  auto *Handler0 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMHandler0));
+  auto *Handler1 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMHandler1));
+  auto *Cleanup = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMCleanup));
+  auto *CS0 = cast<sandboxir::CatchSwitchInst>(&*BB0->begin());
+  auto *CS1 = cast<sandboxir::CatchSwitchInst>(&*BB1->begin());
+
+  // Check getParentPad().
+  EXPECT_EQ(CS0->getParentPad(), Ctx.getValue(LLVMCS0->getParentPad()));
+  EXPECT_EQ(CS1->getParentPad(), Ctx.getValue(LLVMCS1->getParentPad()));
+  // Check setParentPad().
+  auto *OrigPad = CS0->getParentPad();
+  auto *NewPad = CS1;
+  EXPECT_NE(NewPad, OrigPad);
+  CS0->setParentPad(NewPad);
+  EXPECT_EQ(CS0->getParentPad(), NewPad);
+  CS0->setParentPad(OrigPad);
+  EXPECT_EQ(CS0->getParentPad(), OrigPad);
+  // Check hasUnwindDest().
+  EXPECT_EQ(CS0->hasUnwindDest(), LLVMCS0->hasUnwindDest());
+  EXPECT_EQ(CS1->hasUnwindDest(), LLVMCS1->hasUnwindDest());
+  // Check unwindsToCaller().
+  EXPECT_EQ(CS0->unwindsToCaller(), LLVMCS0->unwindsToCaller());
+  EXPECT_EQ(CS1->unwindsToCaller(), LLVMCS1->unwindsToCaller());
+  // Check getUnwindDest().
+  EXPECT_EQ(CS0->getUnwindDest(), Ctx.getValue(LLVMCS0->getUnwindDest()));
+  EXPECT_EQ(CS1->getUnwindDest(), Ctx.getValue(LLVMCS1->getUnwindDest()));
+  // Check setUnwindDest().
+  auto *OrigUnwindDest = CS1->getUnwindDest();
+  auto *NewUnwindDest = BB0;
+  EXPECT_NE(NewUnwindDest, OrigUnwindDest);
+  CS1->setUnwindDest(NewUnwindDest);
+  EXPECT_EQ(CS1->getUnwindDest(), NewUnwindDest);
+  CS1->setUnwindDest(OrigUnwindDest);
+  EXPECT_EQ(CS1->getUnwindDest(), OrigUnwindDest);
+  // Check getNumHandlers().
+  EXPECT_EQ(CS0->getNumHandlers(), LLVMCS0->getNumHandlers());
+  EXPECT_EQ(CS1->getNumHandlers(), LLVMCS1->getNumHandlers());
+  // Check handler_begin(), handler_end().
+  auto It = CS0->handler_begin();
+  EXPECT_EQ(*It++, Handler0);
+  EXPECT_EQ(*It++, Handler1);
+  EXPECT_EQ(It, CS0->handler_end());
+  // Check handlers().
+  SmallVector<sandboxir::BasicBlock *, 2> Handlers;
+  for (sandboxir::BasicBlock *Handler : CS0->handlers())
+    Handlers.push_back(Handler);
+  EXPECT_EQ(Handlers.size(), 2u);
+  EXPECT_EQ(Handlers[0], Handler0);
+  EXPECT_EQ(Handlers[1], Handler1);
+  // Check addHandler().
+  CS0->addHandler(BB0);
+  EXPECT_EQ(CS0->getNumHandlers(), 3u);
+  EXPECT_EQ(*std::next(CS0->handler_begin(), 2), BB0);
+  // Check getNumSuccessors().
+  EXPECT_EQ(CS0->getNumSuccessors(), LLVMCS0->getNumSuccessors());
+  EXPECT_EQ(CS1->getNumSuccessors(), LLVMCS1->getNumSuccessors());
+  // Check getSuccessor().
+  for (auto SuccIdx : seq<unsigned>(0, CS0->getNumSuccessors()))
+    EXPECT_EQ(CS0->getSuccessor(SuccIdx),
+              Ctx.getValue(LLVMCS0->getSuccessor(SuccIdx)));
+  // Check setSuccessor().
+  auto *OrigSuccessor = CS0->getSuccessor(0);
+  auto *NewSuccessor = BB0;
+  EXPECT_NE(NewSuccessor, OrigSuccessor);
+  CS0->setSuccessor(0, NewSuccessor);
+  EXPECT_EQ(CS0->getSuccessor(0), NewSuccessor);
+  CS0->setSuccessor(0, OrigSuccessor);
+  EXPECT_EQ(CS0->getSuccessor(0), OrigSuccessor);
+  // Check create().
+  CS1->eraseFromParent();
+  auto *NewCSI = sandboxir::CatchSwitchInst::create(
+      CS0, Cleanup, 2, BB1->begin(), BB1, Ctx, "NewCSI");
+  EXPECT_TRUE(isa<sandboxir::CatchSwitchInst>(NewCSI));
+  EXPECT_EQ(NewCSI->getParentPad(), CS0);
+  EXPECT_EQ(NewCSI->getParent(), BB1);
+  EXPECT_EQ(NewCSI->getUnwindDest(), Cleanup);
+  // NumHandlers only reserves operand space: no handlers exist yet.
+  EXPECT_EQ(NewCSI->getNumHandlers(), 0u);
+  EXPECT_EQ(NewCSI->handler_begin(), NewCSI->handler_end());
+}
+
 TEST_F(SandboxIRTest, SwitchInst) {
   parseIR(C, R"IR(
 define void @foo(i32 %cond0, i32 %cond1) {

diff  --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index b6770b237853cc..9f502375204024 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -644,6 +644,75 @@ define void @foo(i8 %arg) {
   EXPECT_EQ(Invoke->getSuccessor(1), ExceptionBB);
 }
 
+TEST_F(TrackerTest, CatchSwitchInst) {
+  parseIR(C, R"IR(
+define void @foo(i32 %cond0, i32 %cond1) {
+  bb0:
+    %cs0 = catchswitch within none [label %handler0, label %handler1] unwind to caller
+  bb1:
+    %cs1 = catchswitch within %cs0 [label %handler0, label %handler1] unwind label %cleanup
+  handler0:
+    ret void
+  handler1:
+    ret void
+  cleanup:
+    ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+
+  sandboxir::Context Ctx(C);
+  [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+  auto *BB0 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
+  auto *BB1 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb1")));
+  auto *Handler0 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "handler0")));
+  auto *Handler1 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "handler1")));
+  auto *CS0 = cast<sandboxir::CatchSwitchInst>(&*BB0->begin());
+  auto *CS1 = cast<sandboxir::CatchSwitchInst>(&*BB1->begin());
+
+  // Check setParentPad().
+  auto *OrigPad = CS0->getParentPad();
+  auto *NewPad = CS1;
+  EXPECT_NE(NewPad, OrigPad);
+  Ctx.save();
+  CS0->setParentPad(NewPad);
+  EXPECT_EQ(CS0->getParentPad(), NewPad);
+  Ctx.revert();
+  EXPECT_EQ(CS0->getParentPad(), OrigPad);
+  // Check setUnwindDest().
+  auto *OrigUnwindDest = CS1->getUnwindDest();
+  auto *NewUnwindDest = BB0;
+  EXPECT_NE(NewUnwindDest, OrigUnwindDest);
+  Ctx.save();
+  CS1->setUnwindDest(NewUnwindDest);
+  EXPECT_EQ(CS1->getUnwindDest(), NewUnwindDest);
+  Ctx.revert();
+  EXPECT_EQ(CS1->getUnwindDest(), OrigUnwindDest);
+  // Check setSuccessor().
+  auto *OrigSuccessor = CS0->getSuccessor(0);
+  auto *NewSuccessor = BB0;
+  EXPECT_NE(NewSuccessor, OrigSuccessor);
+  Ctx.save();
+  CS0->setSuccessor(0, NewSuccessor);
+  EXPECT_EQ(CS0->getSuccessor(0), NewSuccessor);
+  Ctx.revert();
+  EXPECT_EQ(CS0->getSuccessor(0), OrigSuccessor);
+  // Check addHandler().
+  Ctx.save();
+  CS0->addHandler(BB0);
+  EXPECT_EQ(CS0->getNumHandlers(), 3u);
+  Ctx.revert();
+  EXPECT_EQ(CS0->getNumHandlers(), 2u);
+  auto HIt = CS0->handler_begin();
+  EXPECT_EQ(*HIt++, Handler0);
+  EXPECT_EQ(*HIt++, Handler1);
+  EXPECT_EQ(HIt, CS0->handler_end());
+}
+
 TEST_F(TrackerTest, SwitchInstSetters) {
   parseIR(C, R"IR(
 define void @foo(i32 %cond0, i32 %cond1) {


        


More information about the llvm-commits mailing list