[llvm] [SandboxIR] Implement InvokeInst (PR #100796)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 26 13:34:46 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/100796

>From c80d0dcb4e1a4a53b9951d3558c81ed5a33dcd9a Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 23 Jul 2024 15:19:56 -0700
Subject: [PATCH] [SandboxIR] Implement InvokeInst

This patch implements sandboxir::InvokeInst, which mirrors llvm::InvokeInst.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h    | 78 ++++++++++++++++++-
 llvm/lib/SandboxIR/SandboxIR.cpp           | 85 ++++++++++++++++++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 91 ++++++++++++++++++++++
 llvm/unittests/SandboxIR/TrackerTest.cpp   | 56 +++++++++++++
 4 files changed, 306 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 2678ee0f4f90a..97a5fd4898d45 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -30,9 +30,9 @@
 //                                      |
 //                                      +- CastInst
 //                                      |
-//                                      +- CallBase ----- CallInst
-//                                      |
-//                                      +- CmpInst
+//                                      +- CallBase ------+- CallInst
+//                                      |                 |
+//                                      +- CmpInst        +- InvokeInst
 //                                      |
 //                                      +- ExtractElementInst
 //                                      |
@@ -90,6 +90,7 @@ class User;
 class Value;
 class CallBase;
 class CallInst;
+class InvokeInst;
 
 /// Iterator for the `Use` edges of a User's operands.
 /// \Returns the operand `Use` when dereferenced.
@@ -203,6 +204,7 @@ class Value {
   friend class ReturnInst; // For getting `Val`.
   friend class CallBase;   // For getting `Val`.
   friend class CallInst;   // For getting `Val`.
+  friend class InvokeInst; // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -541,6 +543,7 @@ class Instruction : public sandboxir::User {
   friend class StoreInst;  // For getTopmostLLVMInstruction().
   friend class ReturnInst; // For getTopmostLLVMInstruction().
   friend class CallInst;   // For getTopmostLLVMInstruction().
+  friend class InvokeInst; // For getTopmostLLVMInstruction().
 
   /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
   /// order.
@@ -861,7 +864,8 @@ class ReturnInst final : public Instruction {
 class CallBase : public Instruction {
   CallBase(ClassID ID, Opcode Opc, llvm::Instruction *I, Context &Ctx)
       : Instruction(ID, Opc, I, Ctx) {}
-  friend class CallInst; // For constructor.
+  friend class CallInst;   // For constructor.
+  friend class InvokeInst; // For constructor.
 
 public:
   static bool classof(const Value *From) {
@@ -1029,6 +1033,70 @@ class CallInst final : public CallBase {
 #endif
 };
 
+/// SandboxIR counterpart of llvm::InvokeInst; mirrors its API.
+class InvokeInst final : public CallBase {
+  /// Use Context::createInvokeInst(). Don't call the constructor directly.
+  InvokeInst(llvm::Instruction *I, Context &Ctx)
+      : CallBase(ClassID::Invoke, Opcode::Invoke, I, Ctx) {}
+  friend class Context; // For accessing the constructor in create*().
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+
+public:
+  /// Insert before WhereIt/InsertBefore, or append to WhereBB/InsertAtEnd.
+  static InvokeInst *create(FunctionType *FTy, Value *Func,
+                            BasicBlock *IfNormal, BasicBlock *IfException,
+                            ArrayRef<Value *> Args, BBIterator WhereIt,
+                            BasicBlock *WhereBB, Context &Ctx,
+                            const Twine &NameStr = "");
+  static InvokeInst *create(FunctionType *FTy, Value *Func,
+                            BasicBlock *IfNormal, BasicBlock *IfException,
+                            ArrayRef<Value *> Args, Instruction *InsertBefore,
+                            Context &Ctx, const Twine &NameStr = "");
+  static InvokeInst *create(FunctionType *FTy, Value *Func,
+                            BasicBlock *IfNormal, BasicBlock *IfException,
+                            ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
+                            Context &Ctx, const Twine &NameStr = "");
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::Invoke;
+  }
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+  BasicBlock *getNormalDest() const;
+  BasicBlock *getUnwindDest() const;
+  void setNormalDest(BasicBlock *BB);
+  void setUnwindDest(BasicBlock *BB);
+  // TODO: Return a `LandingPadInst` once implemented.
+  Instruction *getLandingPadInst() const;
+  BasicBlock *getSuccessor(unsigned SuccIdx) const;
+  /// SuccIdx 0 is the normal destination, 1 is the unwind destination.
+  void setSuccessor(unsigned SuccIdx, BasicBlock *NewSucc) {
+    assert(SuccIdx < 2 && "Successor # out of range for invoke!");
+    if (SuccIdx == 0)
+      setNormalDest(NewSucc);
+    else
+      setUnwindDest(NewSucc);
+  }
+  unsigned getNumSuccessors() const {
+    return cast<llvm::InvokeInst>(Val)->getNumSuccessors();
+  }
+#ifndef NDEBUG
+  void verify() const final {}
+  friend raw_ostream &operator<<(raw_ostream &OS, const InvokeInst &I) {
+    I.dump(OS);
+    return OS;
+  }
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
 /// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
 /// an OpaqueInstr.
 class OpaqueInst : public sandboxir::Instruction {
@@ -1179,6 +1247,8 @@ class Context {
   friend ReturnInst; // For createReturnInst()
   CallInst *createCallInst(llvm::CallInst *I);
   friend CallInst; // For createCallInst()
+  InvokeInst *createInvokeInst(llvm::InvokeInst *I);
+  friend InvokeInst; // For createInvokeInst()
 
 public:
   Context(LLVMContext &LLVMCtx)
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index da482765c7d11..2dc9f5864dc5c 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -809,6 +809,81 @@ void CallInst::dump() const {
   dump(dbgs());
   dbgs() << "\n";
 }
+#endif // NDEBUG
+
+InvokeInst *InvokeInst::create(FunctionType *FTy, Value *Func,
+                               BasicBlock *IfNormal, BasicBlock *IfException,
+                               ArrayRef<Value *> Args, BBIterator WhereIt,
+                               BasicBlock *WhereBB, Context &Ctx,
+                               const Twine &NameStr) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  if (WhereIt == WhereBB->end()) // Append to WhereBB.
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+  SmallVector<llvm::Value *> LLVMArgs;
+  LLVMArgs.reserve(Args.size());
+  for (Value *Arg : Args)
+    LLVMArgs.push_back(Arg->Val);
+  auto *NormalLLVMBB = cast<llvm::BasicBlock>(IfNormal->Val);
+  auto *UnwindLLVMBB = cast<llvm::BasicBlock>(IfException->Val);
+  return Ctx.createInvokeInst(Builder.CreateInvoke(
+      FTy, Func->Val, NormalLLVMBB, UnwindLLVMBB, LLVMArgs, NameStr));
+}
+
+InvokeInst *InvokeInst::create(FunctionType *FTy, Value *Func,
+                               BasicBlock *IfNormal, BasicBlock *IfException,
+                               ArrayRef<Value *> Args,
+                               Instruction *InsertBefore, Context &Ctx,
+                               const Twine &NameStr) {
+  auto WhereIt = InsertBefore->getIterator(); // New invoke goes right before.
+  return create(FTy, Func, IfNormal, IfException, Args, WhereIt,
+                InsertBefore->getParent(), Ctx, NameStr);
+}
+
+InvokeInst *InvokeInst::create(FunctionType *FTy, Value *Func,
+                               BasicBlock *IfNormal, BasicBlock *IfException,
+                               ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
+                               Context &Ctx, const Twine &NameStr) {
+  return create(FTy, Func, IfNormal, IfException, Args,
+                /*WhereIt=*/InsertAtEnd->end(), InsertAtEnd, Ctx, NameStr);
+}
+
+BasicBlock *InvokeInst::getNormalDest() const {
+  auto *LLVMBB = cast<llvm::InvokeInst>(Val)->getNormalDest();
+  return cast<BasicBlock>(Ctx.getValue(LLVMBB));
+}
+BasicBlock *InvokeInst::getUnwindDest() const {
+  auto *LLVMBB = cast<llvm::InvokeInst>(Val)->getUnwindDest();
+  return cast<BasicBlock>(Ctx.getValue(LLVMBB));
+}
+void InvokeInst::setNormalDest(BasicBlock *BB) {
+  setOperand(getNumOperands() - 3, BB); // Ops end: normal, unwind, callee.
+  assert(getNormalDest() == BB && "LLVM IR uses a different operand index!");
+}
+void InvokeInst::setUnwindDest(BasicBlock *BB) {
+  setOperand(getNumOperands() - 2, BB); // Ops end: normal, unwind, callee.
+  assert(getUnwindDest() == BB && "LLVM IR uses a different operand index!");
+}
+Instruction *InvokeInst::getLandingPadInst() const {
+  auto *LLVMLPad = cast<llvm::InvokeInst>(Val)->getLandingPadInst();
+  // TODO: Return a sandboxir::LandingPadInst once it is implemented.
+  return cast<Instruction>(Ctx.getValue(LLVMLPad));
+}
+BasicBlock *InvokeInst::getSuccessor(unsigned SuccIdx) const {
+  auto *LLVMSucc = cast<llvm::InvokeInst>(Val)->getSuccessor(SuccIdx);
+  return cast<BasicBlock>(Ctx.getValue(LLVMSucc));
+}
+
+#ifndef NDEBUG
+void InvokeInst::dump(raw_ostream &OS) const {
+  dumpCommonPrefix(OS); // Nothing InvokeInst-specific is printed.
+  dumpCommonSuffix(OS);
+}
+void InvokeInst::dump() const { // dump(dbgs()) followed by a newline.
+  dump(dbgs());
+  dbgs() << "\n";
+}
 
 void OpaqueInst::dump(raw_ostream &OS) const {
   dumpCommonPrefix(OS);
@@ -968,6 +1043,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
     return It->second.get();
   }
+  case llvm::Instruction::Invoke: {
+    auto *LLVMInvoke = cast<llvm::InvokeInst>(LLVMV);
+    It->second = std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this));
+    return It->second.get();
+  }
   default:
     break;
   }
@@ -1016,6 +1096,11 @@ CallInst *Context::createCallInst(llvm::CallInst *I) {
   return cast<CallInst>(registerValue(std::move(NewPtr)));
 }
 
+InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) {
+  std::unique_ptr<InvokeInst> NewInvoke(new InvokeInst(I, *this));
+  return cast<InvokeInst>(registerValue(std::move(NewInvoke)));
+}
+
 Value *Context::getValue(llvm::Value *V) const {
   auto It = LLVMValueToValueMap.find(V);
   if (It != LLVMValueToValueMap.end())
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 05ec42c952eb6..84e895d656fc4 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1036,3 +1036,94 @@ define i8 @foo(i8 %arg) {
     EXPECT_EQ(Call->getArgOperand(0), Arg0);
   }
 }
+
+TEST_F(SandboxIRTest, InvokeInst) {
+  parseIR(C, R"IR(
+define void @foo(i8 %arg) {
+ bb0:
+   invoke i8 @foo(i8 %arg) to label %normal_bb
+                       unwind label %exception_bb
+ normal_bb:
+   ret void
+ exception_bb:
+   %lpad = landingpad { ptr, i32}
+           cleanup
+   ret void
+ other_bb:
+   ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *Arg = F.getArg(0);
+  auto *BB0 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
+  auto *NormalBB = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "normal_bb")));
+  auto *ExceptionBB = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "exception_bb")));
+  auto *LandingPad = &*ExceptionBB->begin();
+  auto *OtherBB = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "other_bb")));
+  auto It = BB0->begin();
+  // Check classof(Instruction *).
+  auto *Invoke = cast<sandboxir::InvokeInst>(&*It++);
+  // NOTE(review): only a 1-argument invoke is exercised; cover other arities.
+  // Check getNormalDest().
+  EXPECT_EQ(Invoke->getNormalDest(), NormalBB);
+  // Check getUnwindDest().
+  EXPECT_EQ(Invoke->getUnwindDest(), ExceptionBB);
+  // Check getSuccessor().
+  EXPECT_EQ(Invoke->getSuccessor(0), NormalBB);
+  EXPECT_EQ(Invoke->getSuccessor(1), ExceptionBB);
+  // Check setNormalDest().
+  Invoke->setNormalDest(OtherBB);
+  EXPECT_EQ(Invoke->getNormalDest(), OtherBB);
+  EXPECT_EQ(Invoke->getUnwindDest(), ExceptionBB);
+  // Check setUnwindDest().
+  Invoke->setUnwindDest(OtherBB);
+  EXPECT_EQ(Invoke->getNormalDest(), OtherBB);
+  EXPECT_EQ(Invoke->getUnwindDest(), OtherBB);
+  // Check setSuccessor().
+  Invoke->setSuccessor(0, NormalBB);
+  EXPECT_EQ(Invoke->getNormalDest(), NormalBB);
+  Invoke->setSuccessor(1, ExceptionBB);
+  EXPECT_EQ(Invoke->getUnwindDest(), ExceptionBB);
+  // Check getLandingPadInst().
+  EXPECT_EQ(Invoke->getLandingPadInst(), LandingPad);
+
+  {
+    // Check create() WhereIt, WhereBB.
+    SmallVector<sandboxir::Value *> Args({Arg});
+    auto *InsertBefore = &*BB0->begin();
+    auto *NewInvoke = cast<sandboxir::InvokeInst>(sandboxir::InvokeInst::create(
+        F.getFunctionType(), &F, NormalBB, ExceptionBB, Args,
+        /*WhereIt=*/InsertBefore->getIterator(), /*WhereBB=*/BB0, Ctx));
+    EXPECT_EQ(NewInvoke->getNormalDest(), NormalBB);
+    EXPECT_EQ(NewInvoke->getUnwindDest(), ExceptionBB);
+    EXPECT_EQ(NewInvoke->getNextNode(), InsertBefore);
+  }
+  {
+    // Check create() InsertBefore.
+    SmallVector<sandboxir::Value *> Args({Arg});
+    auto *InsertBefore = &*BB0->begin();
+    auto *NewInvoke = cast<sandboxir::InvokeInst>(
+        sandboxir::InvokeInst::create(F.getFunctionType(), &F, NormalBB,
+                                      ExceptionBB, Args, InsertBefore, Ctx));
+    EXPECT_EQ(NewInvoke->getNormalDest(), NormalBB);
+    EXPECT_EQ(NewInvoke->getUnwindDest(), ExceptionBB);
+    EXPECT_EQ(NewInvoke->getNextNode(), InsertBefore);
+  }
+  {
+    // Check create() InsertAtEnd; this appends after bb0's terminator.
+    SmallVector<sandboxir::Value *> Args({Arg});
+    auto *NewInvoke = cast<sandboxir::InvokeInst>(sandboxir::InvokeInst::create(
+        F.getFunctionType(), &F, NormalBB, ExceptionBB, Args,
+        /*InsertAtEnd=*/BB0, Ctx));
+    EXPECT_EQ(NewInvoke->getNormalDest(), NormalBB);
+    EXPECT_EQ(NewInvoke->getUnwindDest(), ExceptionBB);
+    EXPECT_EQ(NewInvoke->getParent(), BB0);
+    EXPECT_EQ(NewInvoke->getNextNode(), nullptr);
+  }
+}
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index 5111d5f38798f..04536411e02d0 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -488,3 +488,59 @@ define void @foo(i8 %arg0, i8 %arg1) {
   Ctx.revert();
   EXPECT_EQ(Call->getCalledFunction(), Bar1F);
 }
+
+TEST_F(TrackerTest, InvokeSetters) {
+  parseIR(C, R"IR(
+define void @foo(i8 %arg) {
+ bb0:
+   invoke i8 @foo(i8 %arg) to label %normal_bb
+                       unwind label %exception_bb
+ normal_bb:
+   ret void
+ exception_bb:
+   ret void
+ other_bb:
+   ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+  auto *BB0 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
+  auto *NormalBB = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "normal_bb")));
+  auto *ExceptionBB = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "exception_bb")));
+  auto *OtherBB = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "other_bb")));
+  auto It = BB0->begin();
+  auto *Invoke = cast<sandboxir::InvokeInst>(&*It++);
+  // NOTE(review): only a 1-argument invoke is exercised; cover other arities.
+  // Check setNormalDest().
+  Ctx.save();
+  Invoke->setNormalDest(OtherBB);
+  EXPECT_EQ(Invoke->getNormalDest(), OtherBB);
+  Ctx.revert();
+  EXPECT_EQ(Invoke->getNormalDest(), NormalBB);
+
+  // Check setUnwindDest().
+  Ctx.save();
+  Invoke->setUnwindDest(OtherBB);
+  EXPECT_EQ(Invoke->getUnwindDest(), OtherBB);
+  Ctx.revert();
+  EXPECT_EQ(Invoke->getUnwindDest(), ExceptionBB);
+
+  // Check setSuccessor() on the normal successor (index 0).
+  Ctx.save();
+  Invoke->setSuccessor(0, OtherBB);
+  EXPECT_EQ(Invoke->getSuccessor(0), OtherBB);
+  Ctx.revert();
+  EXPECT_EQ(Invoke->getSuccessor(0), NormalBB);
+  // Check setSuccessor() on the unwind successor (index 1).
+  Ctx.save();
+  Invoke->setSuccessor(1, OtherBB);
+  EXPECT_EQ(Invoke->getSuccessor(1), OtherBB);
+  Ctx.revert();
+  EXPECT_EQ(Invoke->getSuccessor(1), ExceptionBB);
+}



More information about the llvm-commits mailing list