[llvm] 0d21c2b - [SandboxIR] Implement CatchReturnInst (#105605)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 22 16:20:55 PDT 2024


Author: vporpo
Date: 2024-08-22T16:20:51-07:00
New Revision: 0d21c2b3e516617ee0fe60e2e5368e0c447b17ad

URL: https://github.com/llvm/llvm-project/commit/0d21c2b3e516617ee0fe60e2e5368e0c447b17ad
DIFF: https://github.com/llvm/llvm-project/commit/0d21c2b3e516617ee0fe60e2e5368e0c447b17ad.diff

LOG: [SandboxIR] Implement CatchReturnInst (#105605)

This patch implements sandboxir::CatchReturnInst mirroring
llvm::CatchReturnInst.

Added: 
    

Modified: 
    llvm/include/llvm/SandboxIR/SandboxIR.h
    llvm/include/llvm/SandboxIR/SandboxIRValues.def
    llvm/lib/SandboxIR/SandboxIR.cpp
    llvm/unittests/SandboxIR/SandboxIRTest.cpp
    llvm/unittests/SandboxIR/TrackerTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index ed5b6f9c9da852..c09e167d67bb1c 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -130,6 +130,7 @@ class CallBrInst;
 class FuncletPadInst;
 class CatchPadInst;
 class CleanupPadInst;
+class CatchReturnInst;
 class GetElementPtrInst;
 class CastInst;
 class PtrToIntInst;
@@ -262,6 +263,7 @@ class Value {
   friend class FuncletPadInst;        // For getting `Val`.
   friend class CatchPadInst;          // For getting `Val`.
   friend class CleanupPadInst;        // For getting `Val`.
+  friend class CatchReturnInst;       // For getting `Val`.
   friend class GetElementPtrInst;     // For getting `Val`.
   friend class CatchSwitchInst;       // For getting `Val`.
   friend class SwitchInst;            // For getting `Val`.
@@ -687,6 +689,7 @@ class Instruction : public sandboxir::User {
   friend class CallBrInst;         // For getTopmostLLVMInstruction().
   friend class CatchPadInst;       // For getTopmostLLVMInstruction().
   friend class CleanupPadInst;     // For getTopmostLLVMInstruction().
+  friend class CatchReturnInst;    // For getTopmostLLVMInstruction().
   friend class GetElementPtrInst;  // For getTopmostLLVMInstruction().
   friend class CatchSwitchInst;    // For getTopmostLLVMInstruction().
   friend class SwitchInst;         // For getTopmostLLVMInstruction().
@@ -1914,6 +1917,42 @@ class CleanupPadInst : public FuncletPadInst {
   }
 };
 
+/// Mirrors llvm::CatchReturnInst: a `catchret` terminator that exits a
+/// catchpad and transfers control to a single successor block.
+class CatchReturnInst final
+    : public SingleLLVMInstructionImpl<llvm::CatchReturnInst> {
+  /// Use CatchReturnInst::create(). Only Context calls the constructor.
+  CatchReturnInst(llvm::CatchReturnInst *CRI, Context &Ctx)
+      : SingleLLVMInstructionImpl(ClassID::CatchRet, Opcode::CatchRet, CRI,
+                                  Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  /// Creates a catchret that exits \p CatchPad and continues at \p BB. It is
+  /// inserted before \p WhereIt, or appended to \p WhereBB if \p WhereIt is
+  /// WhereBB->end().
+  static CatchReturnInst *create(CatchPadInst *CatchPad, BasicBlock *BB,
+                                 BBIterator WhereIt, BasicBlock *WhereBB,
+                                 Context &Ctx);
+  /// \Returns the catchpad that this catchret exits.
+  CatchPadInst *getCatchPad() const;
+  /// Sets the catchpad operand; the change is recorded by the tracker.
+  void setCatchPad(CatchPadInst *CatchPad);
+  /// \Returns the block that control continues at.
+  BasicBlock *getSuccessor() const;
+  /// Sets the successor block; the change is recorded by the tracker.
+  void setSuccessor(BasicBlock *NewSucc);
+  unsigned getNumSuccessors() const {
+    return cast<llvm::CatchReturnInst>(Val)->getNumSuccessors();
+  }
+  /// \Returns the parent pad of the catchswitch that owns this catchret's
+  /// catchpad (see llvm::CatchReturnInst::getCatchSwitchParentPad()).
+  Value *getCatchSwitchParentPad() const;
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::CatchRet;
+  }
+};
+
 class GetElementPtrInst final
     : public SingleLLVMInstructionImpl<llvm::GetElementPtrInst> {
   /// Use Context::createGetElementPtrInst(). Don't call
@@ -2820,6 +2847,8 @@ class Context {
   friend CatchPadInst; // For createCatchPadInst()
   CleanupPadInst *createCleanupPadInst(llvm::CleanupPadInst *I);
   friend CleanupPadInst; // For createCleanupPadInst()
+  CatchReturnInst *createCatchReturnInst(llvm::CatchReturnInst *I);
+  friend CatchReturnInst; // For createCatchReturnInst()
   GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
   friend GetElementPtrInst; // For createGetElementPtrInst()
   CatchSwitchInst *createCatchSwitchInst(llvm::CatchSwitchInst *I);

diff  --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index a75f872bc88acb..b7b396e30dc3ca 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -48,6 +48,7 @@ DEF_INSTR(Invoke,         OP(Invoke),         InvokeInst)
 DEF_INSTR(CallBr,         OP(CallBr),         CallBrInst)
 DEF_INSTR(CatchPad,       OP(CatchPad),       CatchPadInst)
 DEF_INSTR(CleanupPad,     OP(CleanupPad),     CleanupPadInst)
+DEF_INSTR(CatchRet,       OP(CatchRet),       CatchReturnInst)
 DEF_INSTR(GetElementPtr,  OP(GetElementPtr),  GetElementPtrInst)
 DEF_INSTR(CatchSwitch,    OP(CatchSwitch),    CatchSwitchInst)
 DEF_INSTR(Switch,         OP(Switch),         SwitchInst)

diff  --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 1ff82a968a717f..b953e68c33180e 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1105,6 +1105,64 @@ CleanupPadInst *CleanupPadInst::create(Value *ParentPad, ArrayRef<Value *> Args,
   return Ctx.createCleanupPadInst(LLVMI);
 }
 
+/// Creates an llvm::CatchReturnInst that exits \p CatchPad and continues at
+/// \p BB, places it before \p WhereIt (or at the end of \p WhereBB when
+/// \p WhereIt is WhereBB->end()), and returns its SandboxIR wrapper.
+CatchReturnInst *CatchReturnInst::create(CatchPadInst *CatchPad, BasicBlock *BB,
+                                         BBIterator WhereIt,
+                                         BasicBlock *WhereBB, Context &Ctx) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  if (WhereIt == WhereBB->end())
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  auto *LLVMCatchPad = cast<llvm::CatchPadInst>(CatchPad->Val);
+  auto *LLVMSucc = cast<llvm::BasicBlock>(BB->Val);
+  auto *NewLLVMCRI = Builder.CreateCatchRet(LLVMCatchPad, LLVMSucc);
+  return Ctx.createCatchReturnInst(NewLLVMCRI);
+}
+
+/// Maps the LLVM catchpad operand back to its SandboxIR counterpart.
+CatchPadInst *CatchReturnInst::getCatchPad() const {
+  llvm::CatchPadInst *LLVMCatchPad =
+      cast<llvm::CatchReturnInst>(Val)->getCatchPad();
+  return cast<CatchPadInst>(Ctx.getValue(LLVMCatchPad));
+}
+
+/// Replaces the catchpad operand; the previous value is recorded first so the
+/// tracker can revert the change when tracking is enabled.
+void CatchReturnInst::setCatchPad(CatchPadInst *CatchPad) {
+  using CatchPadSetter = GenericSetter<&CatchReturnInst::getCatchPad,
+                                       &CatchReturnInst::setCatchPad>;
+  Ctx.getTracker().emplaceIfTracking<CatchPadSetter>(this);
+  auto *LLVMCatchPad = cast<llvm::CatchPadInst>(CatchPad->Val);
+  cast<llvm::CatchReturnInst>(Val)->setCatchPad(LLVMCatchPad);
+}
+
+/// Maps the LLVM successor block back to its SandboxIR counterpart.
+BasicBlock *CatchReturnInst::getSuccessor() const {
+  llvm::BasicBlock *LLVMSucc =
+      cast<llvm::CatchReturnInst>(Val)->getSuccessor();
+  return cast<BasicBlock>(Ctx.getValue(LLVMSucc));
+}
+
+/// Replaces the successor block, recording the old one first so the tracker
+/// can revert the change when tracking is enabled.
+void CatchReturnInst::setSuccessor(BasicBlock *NewSucc) {
+  using SuccessorSetter = GenericSetter<&CatchReturnInst::getSuccessor,
+                                        &CatchReturnInst::setSuccessor>;
+  Ctx.getTracker().emplaceIfTracking<SuccessorSetter>(this);
+  auto *LLVMNewSucc = cast<llvm::BasicBlock>(NewSucc->Val);
+  cast<llvm::CatchReturnInst>(Val)->setSuccessor(LLVMNewSucc);
+}
+
+/// Forwards to llvm::CatchReturnInst::getCatchSwitchParentPad() and returns
+/// the SandboxIR value that wraps the result.
+Value *CatchReturnInst::getCatchSwitchParentPad() const {
+  auto *LLVMCRI = cast<llvm::CatchReturnInst>(Val);
+  return Ctx.getValue(LLVMCRI->getCatchSwitchParentPad());
+}
+
 Value *GetElementPtrInst::create(Type *Ty, Value *Ptr,
                                  ArrayRef<Value *> IdxList,
                                  BasicBlock::iterator WhereIt,
@@ -2138,6 +2182,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
         std::unique_ptr<CleanupPadInst>(new CleanupPadInst(LLVMCPI, *this));
     return It->second.get();
   }
+  case llvm::Instruction::CatchRet: {
+    auto *LLVMCRI = cast<llvm::CatchReturnInst>(LLVMV);
+    It->second =
+        std::unique_ptr<CatchReturnInst>(new CatchReturnInst(LLVMCRI, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::GetElementPtr: {
     auto *LLVMGEP = cast<llvm::GetElementPtrInst>(LLVMV);
     It->second = std::unique_ptr<GetElementPtrInst>(
@@ -2322,6 +2372,10 @@ CleanupPadInst *Context::createCleanupPadInst(llvm::CleanupPadInst *I) {
   auto NewPtr = std::unique_ptr<CleanupPadInst>(new CleanupPadInst(I, *this));
   return cast<CleanupPadInst>(registerValue(std::move(NewPtr)));
 }
+CatchReturnInst *Context::createCatchReturnInst(llvm::CatchReturnInst *I) {
+  std::unique_ptr<CatchReturnInst> NewCRI(new CatchReturnInst(I, *this));
+  return cast<CatchReturnInst>(registerValue(std::move(NewCRI)));
+}
 GetElementPtrInst *
 Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
   auto NewPtr =

diff  --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 28894397a60d6f..76ca64caeeeb07 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1957,6 +1957,78 @@ define void @foo() {
 #endif // NDEBUG
 }
 
+TEST_F(SandboxIRTest, CatchReturnInst) {
+  parseIR(C, R"IR(
+define void @foo() {
+dispatch:
+  %cs = catchswitch within none [label %catch] unwind to caller
+catch:
+  %catchpad = catchpad within %cs [ptr @foo]
+  catchret from %catchpad to label %continue
+continue:
+  ret void
+catch2:
+  %catchpad2 = catchpad within %cs [ptr @foo]
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  BasicBlock *LLVMCatch = getBasicBlockByName(LLVMF, "catch");
+  auto LLVMIt = LLVMCatch->begin();
+  [[maybe_unused]] auto *LLVMCP = cast<llvm::CatchPadInst>(&*LLVMIt++);
+  auto *LLVMCR = cast<llvm::CatchReturnInst>(&*LLVMIt++);
+
+  sandboxir::Context Ctx(C);
+  [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+  auto *Catch = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMCatch));
+  auto *Catch2 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "catch2")));
+  auto It = Catch->begin();
+  auto *CP = cast<sandboxir::CatchPadInst>(&*It++);
+  auto *CR = cast<sandboxir::CatchReturnInst>(&*It++);
+  auto *CP2 = cast<sandboxir::CatchPadInst>(&*Catch2->begin());
+
+  // Check getCatchPad().
+  EXPECT_EQ(CR->getCatchPad(), Ctx.getValue(LLVMCR->getCatchPad()));
+  // Check setCatchPad().
+  auto *OrigCP = CR->getCatchPad();
+  auto *NewCP = CP2;
+  EXPECT_NE(NewCP, OrigCP);
+  CR->setCatchPad(NewCP);
+  EXPECT_EQ(CR->getCatchPad(), NewCP);
+  CR->setCatchPad(OrigCP);
+  EXPECT_EQ(CR->getCatchPad(), OrigCP);
+  // Check getSuccessor().
+  EXPECT_EQ(CR->getSuccessor(), Ctx.getValue(LLVMCR->getSuccessor()));
+  // Check setSuccessor().
+  auto *OrigSucc = CR->getSuccessor();
+  auto *NewSucc = Catch;
+  EXPECT_NE(NewSucc, OrigSucc);
+  CR->setSuccessor(NewSucc);
+  EXPECT_EQ(CR->getSuccessor(), NewSucc);
+  CR->setSuccessor(OrigSucc);
+  EXPECT_EQ(CR->getSuccessor(), OrigSucc);
+  // Check getNumSuccessors().
+  EXPECT_EQ(CR->getNumSuccessors(), LLVMCR->getNumSuccessors());
+  // Check getCatchSwitchParentPad().
+  EXPECT_EQ(CR->getCatchSwitchParentPad(),
+            Ctx.getValue(LLVMCR->getCatchSwitchParentPad()));
+  // Check create() before an existing instruction.
+  auto *CRI = sandboxir::CatchReturnInst::create(CP, Catch, CP->getIterator(),
+                                                 Catch, Ctx);
+  EXPECT_EQ(CRI->getNextNode(), CP);
+  EXPECT_EQ(CRI->getCatchPad(), CP);
+  EXPECT_EQ(CRI->getSuccessor(), Catch);
+  // Check create() at the end of a block (WhereIt == WhereBB->end()).
+  auto *CRIEnd = sandboxir::CatchReturnInst::create(CP, Catch, Catch->end(),
+                                                    Catch, Ctx);
+  EXPECT_EQ(CRIEnd->getParent(), Catch);
+  EXPECT_EQ(CRIEnd->getPrevNode(), CR);
+  EXPECT_EQ(CRIEnd->getNextNode(), nullptr);
+  EXPECT_EQ(CRIEnd->getCatchPad(), CP);
+  EXPECT_EQ(CRIEnd->getSuccessor(), Catch);
+}
+
 TEST_F(SandboxIRTest, GetElementPtrInstruction) {
   parseIR(C, R"IR(
 define void @foo(ptr %ptr, <2 x ptr> %ptrs) {

diff  --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index c2faf60a57f3b8..6614ab7fa248e1 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -713,6 +713,54 @@ define void @foo(i32 %cond0, i32 %cond1) {
   EXPECT_EQ(*HIt++, Handler1);
 }
 
+TEST_F(TrackerTest, CatchReturnInstSetters) {
+  parseIR(C, R"IR(
+define void @foo() {
+dispatch:
+  %cs = catchswitch within none [label %catch] unwind to caller
+catch:
+  %catchpad = catchpad within %cs [ptr @foo]
+  catchret from %catchpad to label %continue
+continue:
+  ret void
+catch2:
+  %catchpad2 = catchpad within %cs [ptr @foo]
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  BasicBlock *LLVMCatch = getBasicBlockByName(LLVMF, "catch");
+
+  sandboxir::Context Ctx(C);
+  [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+  auto *Catch = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMCatch));
+  auto *Catch2 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "catch2")));
+  auto It = Catch->begin();
+  [[maybe_unused]] auto *CP = cast<sandboxir::CatchPadInst>(&*It++);
+  auto *CR = cast<sandboxir::CatchReturnInst>(&*It++);
+  auto *CP2 = cast<sandboxir::CatchPadInst>(&*Catch2->begin());
+
+  // Check setCatchPad().
+  auto *OrigCP = CR->getCatchPad();
+  auto *NewCP = CP2;
+  EXPECT_NE(NewCP, OrigCP);
+  Ctx.save();
+  CR->setCatchPad(NewCP);
+  EXPECT_EQ(CR->getCatchPad(), NewCP);
+  Ctx.revert();
+  EXPECT_EQ(CR->getCatchPad(), OrigCP);
+  // Check setSuccessor().
+  auto *OrigSucc = CR->getSuccessor();
+  auto *NewSucc = Catch;
+  EXPECT_NE(NewSucc, OrigSucc);
+  Ctx.save();
+  CR->setSuccessor(NewSucc);
+  EXPECT_EQ(CR->getSuccessor(), NewSucc);
+  Ctx.revert();
+  EXPECT_EQ(CR->getSuccessor(), OrigSucc);
+}
+
 TEST_F(TrackerTest, SwitchInstSetters) {
   parseIR(C, R"IR(
 define void @foo(i32 %cond0, i32 %cond1) {


        


More information about the llvm-commits mailing list