[llvm] [SandboxIR] Implement FuncletPadInst, CatchPadInst and CleanupPadInst (PR #105294)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 20 12:37:49 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/105294

This patch implements sandboxir::FuncletPadInst, CatchPadInst and CleanupPadInst, mirroring their llvm:: counterparts.

>From eb718493be6cd15feeef91f33e7f4bdd0b7470e0 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 14 Aug 2024 16:45:04 -0700
Subject: [PATCH] [SandboxIR] Implement FuncletPadInst, CatchPadInst and
 CleanupPadInst

This patch implements sandboxir::FuncletPadInst, CatchPadInst and
CleanupPadInst, mirroring their llvm:: counterparts.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       | 125 ++++++++++++++----
 .../llvm/SandboxIR/SandboxIRValues.def        |   2 +
 llvm/lib/SandboxIR/SandboxIR.cpp              |  83 +++++++++++-
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    |  90 +++++++++++++
 4 files changed, 274 insertions(+), 26 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index ca71566091bf82..562989c227fb3b 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -126,6 +126,9 @@ class CallBase;
 class CallInst;
 class InvokeInst;
 class CallBrInst;
+class FuncletPadInst;
+class CatchPadInst;
+class CleanupPadInst;
 class GetElementPtrInst;
 class CastInst;
 class PtrToIntInst;
@@ -240,31 +243,34 @@ class Value {
   /// order.
   llvm::Value *Val = nullptr;
 
-  friend class Context;            // For getting `Val`.
-  friend class User;               // For getting `Val`.
-  friend class Use;                // For getting `Val`.
-  friend class SelectInst;         // For getting `Val`.
-  friend class ExtractElementInst; // For getting `Val`.
-  friend class InsertElementInst;  // For getting `Val`.
-  friend class BranchInst;         // For getting `Val`.
-  friend class LoadInst;           // For getting `Val`.
-  friend class StoreInst;          // For getting `Val`.
-  friend class ReturnInst;         // For getting `Val`.
-  friend class CallBase;           // For getting `Val`.
-  friend class CallInst;           // For getting `Val`.
-  friend class InvokeInst;         // For getting `Val`.
-  friend class CallBrInst;         // For getting `Val`.
-  friend class GetElementPtrInst;  // For getting `Val`.
-  friend class CatchSwitchInst;    // For getting `Val`.
-  friend class SwitchInst;         // For getting `Val`.
-  friend class UnaryOperator;      // For getting `Val`.
-  friend class BinaryOperator;     // For getting `Val`.
-  friend class AtomicRMWInst;      // For getting `Val`.
-  friend class AtomicCmpXchgInst;  // For getting `Val`.
-  friend class AllocaInst;         // For getting `Val`.
-  friend class CastInst;           // For getting `Val`.
-  friend class PHINode;            // For getting `Val`.
-  friend class UnreachableInst;    // For getting `Val`.
+  friend class Context;               // For getting `Val`.
+  friend class User;                  // For getting `Val`.
+  friend class Use;                   // For getting `Val`.
+  friend class SelectInst;            // For getting `Val`.
+  friend class ExtractElementInst;    // For getting `Val`.
+  friend class InsertElementInst;     // For getting `Val`.
+  friend class BranchInst;            // For getting `Val`.
+  friend class LoadInst;              // For getting `Val`.
+  friend class StoreInst;             // For getting `Val`.
+  friend class ReturnInst;            // For getting `Val`.
+  friend class CallBase;              // For getting `Val`.
+  friend class CallInst;              // For getting `Val`.
+  friend class InvokeInst;            // For getting `Val`.
+  friend class CallBrInst;            // For getting `Val`.
+  friend class FuncletPadInst;        // For getting `Val`.
+  friend class CatchPadInst;          // For getting `Val`.
+  friend class CleanupPadInst;        // For getting `Val`.
+  friend class GetElementPtrInst;     // For getting `Val`.
+  friend class CatchSwitchInst;       // For getting `Val`.
+  friend class SwitchInst;            // For getting `Val`.
+  friend class UnaryOperator;         // For getting `Val`.
+  friend class BinaryOperator;        // For getting `Val`.
+  friend class AtomicRMWInst;         // For getting `Val`.
+  friend class AtomicCmpXchgInst;     // For getting `Val`.
+  friend class AllocaInst;            // For getting `Val`.
+  friend class CastInst;              // For getting `Val`.
+  friend class PHINode;               // For getting `Val`.
+  friend class UnreachableInst;       // For getting `Val`.
   friend class CatchSwitchAddHandler; // For `Val`.
 
   /// All values point to the context.
@@ -676,6 +682,8 @@ class Instruction : public sandboxir::User {
   friend class CallInst;           // For getTopmostLLVMInstruction().
   friend class InvokeInst;         // For getTopmostLLVMInstruction().
   friend class CallBrInst;         // For getTopmostLLVMInstruction().
+  friend class CatchPadInst;       // For getTopmostLLVMInstruction().
+  friend class CleanupPadInst;     // For getTopmostLLVMInstruction().
   friend class GetElementPtrInst;  // For getTopmostLLVMInstruction().
   friend class CatchSwitchInst;    // For getTopmostLLVMInstruction().
   friend class SwitchInst;         // For getTopmostLLVMInstruction().
@@ -842,6 +850,7 @@ template <typename LLVMT> class SingleLLVMInstructionImpl : public Instruction {
 #include "llvm/SandboxIR/SandboxIRValues.def"
   friend class UnaryInstruction;
   friend class CallBase;
+  friend class FuncletPadInst;
 
   Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
     return getOperandUseDefault(OpIdx, Verify);
@@ -1394,6 +1403,68 @@ class CallBrInst final : public CallBase {
   }
 };
 
+class FuncletPadInst : public SingleLLVMInstructionImpl<llvm::FuncletPadInst> {
+  FuncletPadInst(ClassID SubclassID, Opcode Opc, llvm::Instruction *LLVMI,
+                 Context &Ctx)
+      : SingleLLVMInstructionImpl(SubclassID, Opc, LLVMI, Ctx) {}
+  friend class CatchPadInst;   // For constructor.
+  friend class CleanupPadInst; // For constructor.
+
+public:
+  /// \Returns the number of arguments of this funclet pad.
+  unsigned arg_size() const {
+    return cast<llvm::FuncletPadInst>(Val)->arg_size();
+  }
+  /// \Returns the outer EH pad that this funclet is nested within. For a
+  /// CatchPadInst this is its CatchSwitchInst, i.e. the same value that
+  /// CatchPadInst::getCatchSwitch() returns.
+  Value *getParentPad() const;
+  /// Sets the outer EH pad to \p ParentPad; the change is tracked.
+  void setParentPad(Value *ParentPad);
+  /// \Returns the funclet pad argument at index \p Idx.
+  Value *getArgOperand(unsigned Idx) const;
+  /// Sets the argument at index \p Idx to \p V; the change is tracked.
+  void setArgOperand(unsigned Idx, Value *V);
+
+  // TODO: Implement missing functions: arg_operands().
+  static bool classof(const Value *From) {
+    ClassID FromID = From->getSubclassID();
+    return FromID == ClassID::CatchPad || FromID == ClassID::CleanupPad;
+  }
+};
+
+class CatchPadInst final : public FuncletPadInst {
+  CatchPadInst(llvm::CatchPadInst *CPI, Context &Ctx)
+      : FuncletPadInst(ClassID::CatchPad, Opcode::CatchPad, CPI, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  CatchSwitchInst *getCatchSwitch() const;
+  // TODO: We have not implemented setCatchSwitch() because we can't revert it
+  // for now, as there is no CatchPadInst member function that can undo it.
+  /// Inserts before \p WhereIt, or appends to \p WhereBB if WhereIt is end().
+  static CatchPadInst *create(Value *ParentPad, ArrayRef<Value *> Args,
+                              BBIterator WhereIt, BasicBlock *WhereBB,
+                              Context &Ctx, const Twine &Name = "");
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::CatchPad;
+  }
+};
+
+class CleanupPadInst final : public FuncletPadInst {
+  CleanupPadInst(llvm::CleanupPadInst *CPI, Context &Ctx)
+      : FuncletPadInst(ClassID::CleanupPad, Opcode::CleanupPad, CPI, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  static CleanupPadInst *create(Value *ParentPad, ArrayRef<Value *> Args,
+                                BBIterator WhereIt, BasicBlock *WhereBB,
+                                Context &Ctx, const Twine &Name = "");
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::CleanupPad;
+  }
+};
+
 class GetElementPtrInst final
     : public SingleLLVMInstructionImpl<llvm::GetElementPtrInst> {
   /// Use Context::createGetElementPtrInst(). Don't call
@@ -2294,6 +2365,10 @@ class Context {
   friend InvokeInst; // For createInvokeInst()
   CallBrInst *createCallBrInst(llvm::CallBrInst *I);
   friend CallBrInst; // For createCallBrInst()
+  CatchPadInst *createCatchPadInst(llvm::CatchPadInst *I);
+  friend CatchPadInst; // For createCatchPadInst()
+  CleanupPadInst *createCleanupPadInst(llvm::CleanupPadInst *I);
+  friend CleanupPadInst; // For createCleanupPadInst()
   GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
   friend GetElementPtrInst; // For createGetElementPtrInst()
   CatchSwitchInst *createCatchSwitchInst(llvm::CatchSwitchInst *I);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 402b6f3324a222..6d6795d8681fd9 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -45,6 +45,8 @@ DEF_INSTR(Ret,           OP(Ret),           ReturnInst)
 DEF_INSTR(Call,          OP(Call),          CallInst)
 DEF_INSTR(Invoke,        OP(Invoke),        InvokeInst)
 DEF_INSTR(CallBr,        OP(CallBr),        CallBrInst)
+DEF_INSTR(CatchPad,      OP(CatchPad),      CatchPadInst)
+DEF_INSTR(CleanupPad,    OP(CleanupPad),    CleanupPadInst)
 DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
 DEF_INSTR(CatchSwitch,   OP(CatchSwitch),   CatchSwitchInst)
 DEF_INSTR(Switch,        OP(Switch),        SwitchInst)
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 5b170cee20c940..66809dfa15560d 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1043,6 +1043,68 @@ BasicBlock *CallBrInst::getSuccessor(unsigned Idx) const {
       Ctx.getValue(cast<llvm::CallBrInst>(Val)->getSuccessor(Idx)));
 }
 
+Value *FuncletPadInst::getParentPad() const {
+  auto *LLVMParentPad = cast<llvm::FuncletPadInst>(Val)->getParentPad();
+  return Ctx.getValue(LLVMParentPad);
+}
+
+void FuncletPadInst::setParentPad(Value *ParentPad) {
+  using SetterT = GenericSetter<&FuncletPadInst::getParentPad,
+                                &FuncletPadInst::setParentPad>;
+  Ctx.getTracker().emplaceIfTracking<SetterT>(this);
+  cast<llvm::FuncletPadInst>(Val)->setParentPad(ParentPad->Val);
+}
+
+Value *FuncletPadInst::getArgOperand(unsigned Idx) const {
+  return Ctx.getValue(cast<llvm::FuncletPadInst>(Val)->getArgOperand(Idx));
+}
+
+void FuncletPadInst::setArgOperand(unsigned Idx, Value *V) {
+  using SetterT = GenericSetterWithIdx<&FuncletPadInst::getArgOperand,
+                                       &FuncletPadInst::setArgOperand>;
+  Ctx.getTracker().emplaceIfTracking<SetterT>(this, Idx);
+  cast<llvm::FuncletPadInst>(Val)->setArgOperand(Idx, V->Val);
+}
+
+CatchSwitchInst *CatchPadInst::getCatchSwitch() const {
+  auto *LLVMCatchSwitch = cast<llvm::CatchPadInst>(Val)->getCatchSwitch();
+  return cast<CatchSwitchInst>(Ctx.getValue(LLVMCatchSwitch));
+}
+
+CatchPadInst *CatchPadInst::create(Value *ParentPad, ArrayRef<Value *> Args,
+                                   BBIterator WhereIt, BasicBlock *WhereBB,
+                                   Context &Ctx, const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  // Insert before WhereIt, or append to WhereBB when WhereIt is its end().
+  if (WhereIt == WhereBB->end())
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  SmallVector<llvm::Value *> LLVMArgs;
+  LLVMArgs.reserve(Args.size());
+  for (Value *Arg : Args)
+    LLVMArgs.push_back(Arg->Val);
+  auto *NewLLVMCPI = Builder.CreateCatchPad(ParentPad->Val, LLVMArgs, Name);
+  return Ctx.createCatchPadInst(NewLLVMCPI);
+}
+
+CleanupPadInst *CleanupPadInst::create(Value *ParentPad, ArrayRef<Value *> Args,
+                                       BBIterator WhereIt, BasicBlock *WhereBB,
+                                       Context &Ctx, const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  // Insert before WhereIt, or append to WhereBB when WhereIt is its end().
+  if (WhereIt == WhereBB->end())
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  SmallVector<llvm::Value *> LLVMArgs;
+  LLVMArgs.reserve(Args.size());
+  for (Value *Arg : Args)
+    LLVMArgs.push_back(Arg->Val);
+  auto *NewLLVMCLPI = Builder.CreateCleanupPad(ParentPad->Val, LLVMArgs, Name);
+  return Ctx.createCleanupPadInst(NewLLVMCLPI);
+}
+
 Value *GetElementPtrInst::create(Type *Ty, Value *Ptr,
                                  ArrayRef<Value *> IdxList,
                                  BasicBlock::iterator WhereIt,
@@ -1992,6 +2054,18 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     It->second = std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this));
     return It->second.get();
   }
+  case llvm::Instruction::CatchPad: {
+    auto *LLVMCPI = cast<llvm::CatchPadInst>(LLVMV);
+    It->second =
+        std::unique_ptr<CatchPadInst>(new CatchPadInst(LLVMCPI, *this));
+    return It->second.get();
+  }
+  case llvm::Instruction::CleanupPad: {
+    auto *LLVMCPI = cast<llvm::CleanupPadInst>(LLVMV);
+    It->second =
+        std::unique_ptr<CleanupPadInst>(new CleanupPadInst(LLVMCPI, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::GetElementPtr: {
     auto *LLVMGEP = cast<llvm::GetElementPtrInst>(LLVMV);
     It->second = std::unique_ptr<GetElementPtrInst>(
@@ -2161,7 +2235,14 @@ UnreachableInst *Context::createUnreachableInst(llvm::UnreachableInst *UI) {
       std::unique_ptr<UnreachableInst>(new UnreachableInst(UI, *this));
   return cast<UnreachableInst>(registerValue(std::move(NewPtr)));
 }
-
+CatchPadInst *Context::createCatchPadInst(llvm::CatchPadInst *I) {
+  std::unique_ptr<CatchPadInst> NewCPI(new CatchPadInst(I, *this));
+  return cast<CatchPadInst>(registerValue(std::move(NewCPI)));
+}
+CleanupPadInst *Context::createCleanupPadInst(llvm::CleanupPadInst *I) {
+  std::unique_ptr<CleanupPadInst> NewCLPI(new CleanupPadInst(I, *this));
+  return cast<CleanupPadInst>(registerValue(std::move(NewCLPI)));
+}
 GetElementPtrInst *
 Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
   auto NewPtr =
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 712865fd07cd7b..7fba8393ac76ac 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1453,6 +1453,96 @@ define void @foo(i8 %arg) {
   }
 }
 
+TEST_F(SandboxIRTest, FuncletPadInst_CatchPadInst_CleanupPadInst) {
+  parseIR(C, R"IR(
+define void @foo() {
+dispatch:
+  %cs = catchswitch within none [label %handler0] unwind to caller
+handler0:
+  %catchpad = catchpad within %cs [ptr @foo]
+  ret void
+handler1:
+  %cleanuppad = cleanuppad within %cs [ptr @foo]
+  ret void
+bb:
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  BasicBlock *LLVMDispatch = getBasicBlockByName(LLVMF, "dispatch");
+  BasicBlock *LLVMHandler0 = getBasicBlockByName(LLVMF, "handler0");
+  BasicBlock *LLVMHandler1 = getBasicBlockByName(LLVMF, "handler1");
+  auto *LLVMCP = cast<llvm::CatchPadInst>(&*LLVMHandler0->begin());
+  auto *LLVMCLP = cast<llvm::CleanupPadInst>(&*LLVMHandler1->begin());
+
+  sandboxir::Context Ctx(C);
+  [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+  auto *Dispatch = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMDispatch));
+  auto *Handler0 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMHandler0));
+  auto *Handler1 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMHandler1));
+  auto *BB = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb")));
+  auto *BBRet = cast<sandboxir::ReturnInst>(&*BB->begin());
+  auto *CS = cast<sandboxir::CatchSwitchInst>(&*Dispatch->begin());
+  auto *CP = cast<sandboxir::CatchPadInst>(&*Handler0->begin());
+  auto *CLP = cast<sandboxir::CleanupPadInst>(&*Handler1->begin());
+
+  // Check getCatchSwitch().
+  EXPECT_EQ(CP->getCatchSwitch(), CS);
+  EXPECT_EQ(CP->getCatchSwitch(), Ctx.getValue(LLVMCP->getCatchSwitch()));
+  // Check that classof() tells the two funclet pad kinds apart.
+  EXPECT_FALSE(isa<sandboxir::CatchPadInst>(CLP));
+  EXPECT_FALSE(isa<sandboxir::CleanupPadInst>(CP));
+
+  for (llvm::FuncletPadInst *LLVMFPI :
+       {static_cast<llvm::FuncletPadInst *>(LLVMCP),
+        static_cast<llvm::FuncletPadInst *>(LLVMCLP)}) {
+    auto *FPI = cast<sandboxir::FuncletPadInst>(Ctx.getValue(LLVMFPI));
+    // Check arg_size().
+    EXPECT_EQ(FPI->arg_size(), LLVMFPI->arg_size());
+    // Check getParentPad().
+    EXPECT_EQ(FPI->getParentPad(), Ctx.getValue(LLVMFPI->getParentPad()));
+    // Check setParentPad().
+    auto *OrigParentPad = FPI->getParentPad();
+    auto *NewParentPad = Dispatch;
+    EXPECT_NE(NewParentPad, OrigParentPad);
+    FPI->setParentPad(NewParentPad);
+    EXPECT_EQ(FPI->getParentPad(), NewParentPad);
+    FPI->setParentPad(OrigParentPad);
+    EXPECT_EQ(FPI->getParentPad(), OrigParentPad);
+    // Check getArgOperand().
+    for (auto Idx : seq<unsigned>(0, FPI->arg_size()))
+      EXPECT_EQ(FPI->getArgOperand(Idx),
+                Ctx.getValue(LLVMFPI->getArgOperand(Idx)));
+    // Check setArgOperand().
+    auto *OrigArgOperand = FPI->getArgOperand(0);
+    auto *NewArgOperand = Dispatch;
+    EXPECT_NE(NewArgOperand, OrigArgOperand);
+    FPI->setArgOperand(0, NewArgOperand);
+    EXPECT_EQ(FPI->getArgOperand(0), NewArgOperand);
+    FPI->setArgOperand(0, OrigArgOperand);
+    EXPECT_EQ(FPI->getArgOperand(0), OrigArgOperand);
+  }
+  // Check CatchPadInst::create().
+  auto *NewCPI = sandboxir::CatchPadInst::create(CS, {}, BBRet->getIterator(),
+                                                 BB, Ctx, "NewCPI");
+  EXPECT_EQ(NewCPI->getCatchSwitch(), CS);
+  EXPECT_EQ(NewCPI->arg_size(), 0u);
+  EXPECT_EQ(NewCPI->getNextNode(), BBRet);
+#ifndef NDEBUG
+  EXPECT_EQ(NewCPI->getName(), "NewCPI");
+#endif // NDEBUG
+  // Check CleanupPadInst::create().
+  auto *NewCLPI = sandboxir::CleanupPadInst::create(
+      CS, {}, BBRet->getIterator(), BB, Ctx, "NewCLPI");
+  EXPECT_EQ(NewCLPI->getParentPad(), CS);
+  EXPECT_EQ(NewCLPI->arg_size(), 0u);
+  EXPECT_EQ(NewCLPI->getNextNode(), BBRet);
+#ifndef NDEBUG
+  EXPECT_EQ(NewCLPI->getName(), "NewCLPI");
+#endif // NDEBUG
+}
+
 TEST_F(SandboxIRTest, GetElementPtrInstruction) {
   parseIR(C, R"IR(
 define void @foo(ptr %ptr, <2 x ptr> %ptrs) {



More information about the llvm-commits mailing list