[llvm] [SandboxIR] Implement CallBrInst (PR #100823)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 26 15:11:27 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/100823

This patch implements sandboxir::CallBrInst, which mirrors llvm::CallBrInst.

LLVM IR does not expose the Uses of DefaultDest and IndirectDest, so we need special Tracker objects for both setters.

>From e20334228bd009c8925688345368e8f9d3f9f8df Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 24 Jul 2024 10:06:18 -0700
Subject: [PATCH] [SandboxIR] Implement CallBrInst

This patch implements sandboxir::CallBrInst, which mirrors llvm::CallBrInst.

LLVM IR does not expose the Uses of DefaultDest and IndirectDest, so we need
special Tracker objects for both setters.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h    |  79 +++++++++++++-
 llvm/include/llvm/SandboxIR/Tracker.h      |  36 +++++++
 llvm/include/llvm/SandboxIR/Use.h          |   1 +
 llvm/lib/SandboxIR/SandboxIR.cpp           | 110 ++++++++++++++++++++
 llvm/lib/SandboxIR/Tracker.cpp             |  31 ++++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 114 +++++++++++++++++++++
 llvm/unittests/SandboxIR/TrackerTest.cpp   |  40 ++++++++
 7 files changed, 406 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 1397d9da70643..e19bee4e08036 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -30,11 +30,11 @@
 //                                      |
 //                                      +- CastInst
 //                                      |
-//                                      +- CallBase ------+- CallInst
-//                                      |                 |
-//                                      +- CmpInst        +- InvokeInst
-//                                      |
-//                                      +- ExtractElementInst
+//                                      +- CallBase -----------+- CallBrInst
+//                                      |                      |
+//                                      +- CmpInst             +- CallInst
+//                                      |                      |
+//                                      +- ExtractElementInst  +- InvokeInst
 //                                      |
 //                                      +- GetElementPtrInst
 //                                      |
@@ -92,6 +92,7 @@ class Value;
 class CallBase;
 class CallInst;
 class InvokeInst;
+class CallBrInst;
 
 /// Iterator for the `Use` edges of a User's operands.
 /// \Returns the operand `Use` when dereferenced.
@@ -206,6 +207,7 @@ class Value {
   friend class CallBase;   // For getting `Val`.
   friend class CallInst;   // For getting `Val`.
   friend class InvokeInst; // For getting `Val`.
+  friend class CallBrInst; // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -545,6 +547,7 @@ class Instruction : public sandboxir::User {
   friend class ReturnInst; // For getTopmostLLVMInstruction().
   friend class CallInst;   // For getTopmostLLVMInstruction().
   friend class InvokeInst; // For getTopmostLLVMInstruction().
+  friend class CallBrInst; // For getTopmostLLVMInstruction().
 
   /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
   /// order.
@@ -870,6 +873,7 @@ class CallBase : public Instruction {
       : Instruction(ID, Opc, I, Ctx) {}
   friend class CallInst;   // For constructor.
   friend class InvokeInst; // For constructor.
+  friend class CallBrInst; // For constructor.
 
 public:
   static bool classof(const Value *From) {
@@ -1101,6 +1105,91 @@ class InvokeInst final : public CallBase {
 #endif
 };
 
+/// SandboxIR counterpart of llvm::CallBrInst: a call with a default
+/// destination block and a list of indirect destination blocks.
+class CallBrInst final : public CallBase {
+  /// Use Context::createCallBrInst(). Don't call the constructor directly.
+  CallBrInst(llvm::Instruction *I, Context &Ctx)
+      : CallBase(ClassID::CallBr, Opcode::CallBr, I, Ctx) {}
+  friend class Context; // For accessing the constructor in create*().
+
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+
+public:
+  /// Creates a CallBrInst before \p WhereIt in \p WhereBB, or at the end of
+  /// \p WhereBB when \p WhereIt is WhereBB->end().
+  static CallBrInst *create(FunctionType *FTy, Value *Func,
+                            BasicBlock *DefaultDest,
+                            ArrayRef<BasicBlock *> IndirectDests,
+                            ArrayRef<Value *> Args, BBIterator WhereIt,
+                            BasicBlock *WhereBB, Context &Ctx,
+                            const Twine &NameStr = "");
+  /// Creates a CallBrInst right before \p InsertBefore.
+  static CallBrInst *create(FunctionType *FTy, Value *Func,
+                            BasicBlock *DefaultDest,
+                            ArrayRef<BasicBlock *> IndirectDests,
+                            ArrayRef<Value *> Args, Instruction *InsertBefore,
+                            Context &Ctx, const Twine &NameStr = "");
+  /// Creates a CallBrInst at the end of \p InsertAtEnd.
+  static CallBrInst *create(FunctionType *FTy, Value *Func,
+                            BasicBlock *DefaultDest,
+                            ArrayRef<BasicBlock *> IndirectDests,
+                            ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
+                            Context &Ctx, const Twine &NameStr = "");
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::CallBr;
+  }
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+  /// A CallBrInst always maps to a single LLVM IR instruction.
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+
+  /// \Returns the default destination block.
+  BasicBlock *getDefaultDest() const;
+  /// Sets the default destination to \p BB. The change is recorded by the
+  /// Tracker (when tracking) so that it can be reverted.
+  void setDefaultDest(BasicBlock *BB);
+
+  /// \Returns the number of indirect destinations.
+  unsigned getNumIndirectDests() const {
+    return cast<llvm::CallBrInst>(Val)->getNumIndirectDests();
+  }
+  /// \Returns the \p Idx'th indirect destination block.
+  BasicBlock *getIndirectDest(unsigned Idx) const;
+  /// \Returns all indirect destination blocks, in index order.
+  SmallVector<BasicBlock *, 16> getIndirectDests() const;
+  /// Sets the \p Idx'th indirect destination to \p BB. The change is
+  /// recorded by the Tracker (when tracking) so that it can be reverted.
+  void setIndirectDest(unsigned Idx, BasicBlock *BB);
+  /// \Returns the label operand of the \p Idx'th indirect destination.
+  Value *getIndirectDestLabel(unsigned Idx) const;
+  /// \Returns the Value held by the \p Idx'th indirect destination label's
+  /// operand Use (a Value, not a Use, mirroring llvm::CallBrInst).
+  Value *getIndirectDestLabelUse(unsigned Idx) const;
+
+  /// \Returns the number of successors, as reported by llvm::CallBrInst.
+  unsigned getNumSuccessors() const {
+    return cast<llvm::CallBrInst>(Val)->getNumSuccessors();
+  }
+  /// \Returns the \p Idx'th successor block.
+  BasicBlock *getSuccessor(unsigned Idx) const;
+#ifndef NDEBUG
+  void verify() const final {}
+  friend raw_ostream &operator<<(raw_ostream &OS, const CallBrInst &I) {
+    I.dump(OS);
+    return OS;
+  }
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
 /// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
 /// an OpaqueInstr.
 class OpaqueInst : public sandboxir::Instruction {
@@ -1253,6 +1320,8 @@ class Context {
   friend CallInst; // For createCallInst()
   InvokeInst *createInvokeInst(llvm::InvokeInst *I);
   friend InvokeInst; // For createInvokeInst()
+  CallBrInst *createCallBrInst(llvm::CallBrInst *I);
+  friend CallBrInst; // For createCallBrInst()
 
 public:
   Context(LLVMContext &LLVMCtx)
diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index 3daec3fd5c63c..64068461b9490 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -53,6 +53,7 @@
 namespace llvm::sandboxir {
 
 class BasicBlock;
+class CallBrInst;
 class Instruction;
 class Tracker;
 
@@ -177,6 +178,53 @@ class RemoveFromParent : public IRChangeBase {
 #endif // NDEBUG
 };
 
+/// Records a CallBrInst::setDefaultDest() change. revert() restores the
+/// default destination that was in place when this object was created.
+class CallBrInstSetDefaultDest : public IRChangeBase {
+  /// The instruction whose default destination changed.
+  CallBrInst *CallBr;
+  /// The default destination before the change.
+  BasicBlock *OrigDefaultDest;
+
+public:
+  CallBrInstSetDefaultDest(CallBrInst *CallBr, Tracker &Tracker);
+  void revert() final;
+  /// accept() is a no-op for this change.
+  void accept() final {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final {
+    dumpCommon(OS);
+    OS << "CallBrInstSetDefaultDest";
+  }
+  LLVM_DUMP_METHOD void dump() const final;
+#endif
+};
+
+/// Records a CallBrInst::setIndirectDest() change. revert() restores the
+/// indirect destination at index Idx that was in place when this object was
+/// created.
+class CallBrInstSetIndirectDest : public IRChangeBase {
+  /// The instruction whose indirect destination changed.
+  CallBrInst *CallBr;
+  /// The index of the indirect destination that changed.
+  unsigned Idx;
+  /// The indirect destination at index Idx before the change.
+  BasicBlock *OrigIndirectDest;
+
+public:
+  CallBrInstSetIndirectDest(CallBrInst *CallBr, unsigned Idx, Tracker &Tracker);
+  void revert() final;
+  /// accept() is a no-op for this change.
+  void accept() final {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final {
+    dumpCommon(OS);
+    OS << "CallBrInstSetIndirectDest";
+  }
+  LLVM_DUMP_METHOD void dump() const final;
+#endif
+};
+
 class MoveInstr : public IRChangeBase {
   /// The instruction that moved.
   Instruction *MovedI;
diff --git a/llvm/include/llvm/SandboxIR/Use.h b/llvm/include/llvm/SandboxIR/Use.h
index d30eb90594294..ef728ea387851 100644
--- a/llvm/include/llvm/SandboxIR/Use.h
+++ b/llvm/include/llvm/SandboxIR/Use.h
@@ -42,6 +42,7 @@ class Use {
   friend class OperandUseIterator; // For constructor
   friend class UserUseIterator;    // For accessing members
   friend class CallBase;           // For LLVMUse
+  friend class CallBrInst;         // For constructor
 
 public:
   operator Value *() const { return get(); }
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 2dc9f5864dc5c..8bb81a79b61f7 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/SandboxIR/SandboxIR.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/Support/Debug.h"
 #include <sstream>
@@ -884,6 +885,113 @@ void InvokeInst::dump() const {
   dump(dbgs());
   dbgs() << "\n";
 }
+#endif // NDEBUG
+
+CallBrInst *CallBrInst::create(FunctionType *FTy, Value *Func,
+                               BasicBlock *DefaultDest,
+                               ArrayRef<BasicBlock *> IndirectDests,
+                               ArrayRef<Value *> Args, BBIterator WhereIt,
+                               BasicBlock *WhereBB, Context &Ctx,
+                               const Twine &NameStr) {
+  // Insert before the instruction at WhereIt, or append to WhereBB when
+  // WhereIt is WhereBB's end iterator.
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  if (WhereIt == WhereBB->end())
+    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+  else
+    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
+
+  // Lower the SandboxIR operands to the LLVM IR values they wrap.
+  SmallVector<llvm::BasicBlock *> LLVMIndirectDests;
+  LLVMIndirectDests.reserve(IndirectDests.size());
+  for (auto *SBDest : IndirectDests)
+    LLVMIndirectDests.push_back(cast<llvm::BasicBlock>(SBDest->Val));
+  SmallVector<llvm::Value *> LLVMArgs;
+  LLVMArgs.reserve(Args.size());
+  for (auto *SBArg : Args)
+    LLVMArgs.push_back(SBArg->Val);
+  auto *LLVMDefaultDest = cast<llvm::BasicBlock>(DefaultDest->Val);
+
+  auto *NewLLVMCallBr = Builder.CreateCallBr(
+      FTy, Func->Val, LLVMDefaultDest, LLVMIndirectDests, LLVMArgs, NameStr);
+  return Ctx.createCallBrInst(NewLLVMCallBr);
+}
+
+CallBrInst *CallBrInst::create(FunctionType *FTy, Value *Func,
+                               BasicBlock *DefaultDest,
+                               ArrayRef<BasicBlock *> IndirectDests,
+                               ArrayRef<Value *> Args,
+                               Instruction *InsertBefore, Context &Ctx,
+                               const Twine &NameStr) {
+  // Insert right before InsertBefore, i.e., inside InsertBefore's block.
+  BasicBlock *WhereBB = InsertBefore->getParent();
+  return create(FTy, Func, DefaultDest, IndirectDests, Args,
+                InsertBefore->getIterator(), WhereBB, Ctx, NameStr);
+}
+CallBrInst *CallBrInst::create(FunctionType *FTy, Value *Func,
+                               BasicBlock *DefaultDest,
+                               ArrayRef<BasicBlock *> IndirectDests,
+                               ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
+                               Context &Ctx, const Twine &NameStr) {
+  // Appending to a block is the same as inserting before its end().
+  auto EndIt = InsertAtEnd->end();
+  return create(FTy, Func, DefaultDest, IndirectDests, Args, EndIt,
+                InsertAtEnd, Ctx, NameStr);
+}
+
+Value *CallBrInst::getIndirectDestLabel(unsigned Idx) const {
+  auto *LLVMCallBr = cast<llvm::CallBrInst>(Val);
+  return Ctx.getValue(LLVMCallBr->getIndirectDestLabel(Idx));
+}
+Value *CallBrInst::getIndirectDestLabelUse(unsigned Idx) const {
+  auto *LLVMCallBr = cast<llvm::CallBrInst>(Val);
+  return Ctx.getValue(LLVMCallBr->getIndirectDestLabelUse(Idx));
+}
+BasicBlock *CallBrInst::getDefaultDest() const {
+  llvm::BasicBlock *LLVMBB = cast<llvm::CallBrInst>(Val)->getDefaultDest();
+  return cast<BasicBlock>(Ctx.getValue(LLVMBB));
+}
+BasicBlock *CallBrInst::getIndirectDest(unsigned Idx) const {
+  auto *LLVMCallBr = cast<llvm::CallBrInst>(Val);
+  return cast<BasicBlock>(Ctx.getValue(LLVMCallBr->getIndirectDest(Idx)));
+}
+llvm::SmallVector<BasicBlock *, 16> CallBrInst::getIndirectDests() const {
+  // Collect the SandboxIR counterpart of each indirect destination, in order.
+  SmallVector<BasicBlock *, 16> Dests;
+  for (unsigned I = 0, E = getNumIndirectDests(); I != E; ++I)
+    Dests.push_back(getIndirectDest(I));
+  return Dests;
+}
+void CallBrInst::setDefaultDest(BasicBlock *BB) {
+  // Snapshot the old destination (when tracking) before overwriting it.
+  auto &Trk = Ctx.getTracker();
+  if (Trk.isTracking())
+    Trk.track(std::make_unique<CallBrInstSetDefaultDest>(this, Trk));
+  auto *LLVMCallBr = cast<llvm::CallBrInst>(Val);
+  LLVMCallBr->setDefaultDest(cast<llvm::BasicBlock>(BB->Val));
+}
+void CallBrInst::setIndirectDest(unsigned Idx, BasicBlock *BB) {
+  // Snapshot the old destination (when tracking) before overwriting it.
+  auto &Trk = Ctx.getTracker();
+  if (Trk.isTracking())
+    Trk.track(std::make_unique<CallBrInstSetIndirectDest>(this, Idx, Trk));
+  auto *LLVMCallBr = cast<llvm::CallBrInst>(Val);
+  LLVMCallBr->setIndirectDest(Idx, cast<llvm::BasicBlock>(BB->Val));
+}
+BasicBlock *CallBrInst::getSuccessor(unsigned Idx) const {
+  llvm::BasicBlock *LLVMSucc = cast<llvm::CallBrInst>(Val)->getSuccessor(Idx);
+  return cast<BasicBlock>(Ctx.getValue(LLVMSucc));
+}
+
+#ifndef NDEBUG
+void CallBrInst::dump(raw_ostream &OS) const {
+  dumpCommonPrefix(OS);
+  dumpCommonSuffix(OS);
+}
+void CallBrInst::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
 
 void OpaqueInst::dump(raw_ostream &OS) const {
   dumpCommonPrefix(OS);
@@ -1048,6 +1148,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     It->second = std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this));
     return It->second.get();
   }
+  case llvm::Instruction::CallBr: {
+    auto *LLVMCallBr = cast<llvm::CallBrInst>(LLVMV);
+    It->second = std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this));
+    return It->second.get();
+  }
   default:
     break;
   }
@@ -1101,6 +1206,11 @@ InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) {
   return cast<InvokeInst>(registerValue(std::move(NewPtr)));
 }
 
+CallBrInst *Context::createCallBrInst(llvm::CallBrInst *I) {
+  std::unique_ptr<CallBrInst> NewCallBr(new CallBrInst(I, *this));
+  return cast<CallBrInst>(registerValue(std::move(NewCallBr)));
+}
+
 Value *Context::getValue(llvm::Value *V) const {
   auto It = LLVMValueToValueMap.find(V);
   if (It != LLVMValueToValueMap.end())
diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index c74177608aff2..eae55d7b3d962 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -129,6 +129,38 @@ void RemoveFromParent::dump() const {
 }
 #endif
 
+// Capture the current default destination so that revert() can restore it.
+CallBrInstSetDefaultDest::CallBrInstSetDefaultDest(CallBrInst *CallBr,
+                                                   Tracker &Tracker)
+    : IRChangeBase(Tracker), CallBr(CallBr),
+      OrigDefaultDest(CallBr->getDefaultDest()) {}
+void CallBrInstSetDefaultDest::revert() {
+  CallBr->setDefaultDest(OrigDefaultDest);
+}
+#ifndef NDEBUG
+void CallBrInstSetDefaultDest::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif
+
+// Capture the current indirect destination at Idx so that revert() can
+// restore it.
+CallBrInstSetIndirectDest::CallBrInstSetIndirectDest(CallBrInst *CallBr,
+                                                     unsigned Idx,
+                                                     Tracker &Tracker)
+    : IRChangeBase(Tracker), CallBr(CallBr), Idx(Idx),
+      OrigIndirectDest(CallBr->getIndirectDest(Idx)) {}
+void CallBrInstSetIndirectDest::revert() {
+  CallBr->setIndirectDest(Idx, OrigIndirectDest);
+}
+#ifndef NDEBUG
+void CallBrInstSetIndirectDest::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif
+
 MoveInstr::MoveInstr(Instruction *MovedI, Tracker &Tracker)
     : IRChangeBase(Tracker), MovedI(MovedI) {
   if (auto *NextI = MovedI->getNextNode())
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 7edfe457c32a9..d8ee74057fed0 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1133,3 +1133,123 @@ define void @foo(i8 %arg) {
     EXPECT_EQ(NewInvoke->getNextNode(), nullptr);
   }
 }
+
+TEST_F(SandboxIRTest, CallBrInst) {
+  parseIR(C, R"IR(
+define void @foo(i8 %arg) {
+ bb0:
+   callbr void asm "", ""()
+               to label %bb1 [label %bb2]
+ bb1:
+   ret void
+ bb2:
+   ret void
+ other_bb:
+   ret void
+ bb3:
+   callbr void @foo(i8 %arg)
+               to label %bb1 [label %bb2]
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  auto *LLVMCallBr0 =
+      cast<llvm::CallBrInst>(&*getBasicBlockByName(LLVMF, "bb0")->begin());
+  auto *LLVMCallBr1 =
+      cast<llvm::CallBrInst>(&*getBasicBlockByName(LLVMF, "bb3")->begin());
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *Arg = F.getArg(0);
+  auto *BB0 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
+  auto *BB1 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb1")));
+  auto *BB2 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb2")));
+  auto *BB3 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb3")));
+  auto *OtherBB = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "other_bb")));
+  auto It = BB0->begin();
+  // Check classof(Instruction *).
+  auto *CallBr0 = cast<sandboxir::CallBrInst>(&*It++);
+
+  It = BB3->begin();
+  auto *CallBr1 = cast<sandboxir::CallBrInst>(&*It++);
+  // Check each sandboxir::CallBrInst against the llvm::CallBrInst it wraps,
+  // not against a single fixed LLVM instruction.
+  for (auto [CallBr, LLVMCallBr] : {std::make_pair(CallBr0, LLVMCallBr0),
+                                    std::make_pair(CallBr1, LLVMCallBr1)}) {
+    // Check getNumIndirectDests().
+    EXPECT_EQ(CallBr->getNumIndirectDests(), 1u);
+    // Check getIndirectDestLabel().
+    EXPECT_EQ(CallBr->getIndirectDestLabel(0),
+              Ctx.getValue(LLVMCallBr->getIndirectDestLabel(0)));
+    // Check getIndirectDestLabelUse().
+    EXPECT_EQ(CallBr->getIndirectDestLabelUse(0),
+              Ctx.getValue(LLVMCallBr->getIndirectDestLabelUse(0)));
+    // Check getDefaultDest().
+    EXPECT_EQ(CallBr->getDefaultDest(),
+              Ctx.getValue(LLVMCallBr->getDefaultDest()));
+    // Check getIndirectDest().
+    EXPECT_EQ(CallBr->getIndirectDest(0),
+              Ctx.getValue(LLVMCallBr->getIndirectDest(0)));
+    // Check getIndirectDests() (assert the size before indexing Dests[0]).
+    auto Dests = CallBr->getIndirectDests();
+    ASSERT_EQ(Dests.size(), LLVMCallBr->getIndirectDests().size());
+    EXPECT_EQ(Dests[0], Ctx.getValue(LLVMCallBr->getIndirectDests()[0]));
+    // Check getNumSuccessors().
+    EXPECT_EQ(CallBr->getNumSuccessors(), LLVMCallBr->getNumSuccessors());
+    // Check getSuccessor().
+    for (unsigned SuccIdx = 0, E = CallBr->getNumSuccessors(); SuccIdx != E;
+         ++SuccIdx)
+      EXPECT_EQ(CallBr->getSuccessor(SuccIdx),
+                Ctx.getValue(LLVMCallBr->getSuccessor(SuccIdx)));
+    // Check setDefaultDest().
+    auto *SvDefaultDest = CallBr->getDefaultDest();
+    CallBr->setDefaultDest(OtherBB);
+    EXPECT_EQ(CallBr->getDefaultDest(), OtherBB);
+    CallBr->setDefaultDest(SvDefaultDest);
+    // Check setIndirectDest().
+    auto *SvIndirectDest = CallBr->getIndirectDest(0);
+    CallBr->setIndirectDest(0, OtherBB);
+    EXPECT_EQ(CallBr->getIndirectDest(0), OtherBB);
+    CallBr->setIndirectDest(0, SvIndirectDest);
+  }
+
+  {
+    // Check create() WhereIt, WhereBB.
+    SmallVector<sandboxir::Value *> Args({Arg});
+    auto *NewCallBr = cast<sandboxir::CallBrInst>(sandboxir::CallBrInst::create(
+        F.getFunctionType(), &F, BB1, {BB2}, Args, /*WhereIt=*/BB0->end(),
+        /*WhereBB=*/BB0, Ctx));
+    EXPECT_EQ(NewCallBr->getDefaultDest(), BB1);
+    EXPECT_EQ(NewCallBr->getIndirectDests().size(), 1u);
+    EXPECT_EQ(NewCallBr->getIndirectDests()[0], BB2);
+    EXPECT_EQ(NewCallBr->getNextNode(), nullptr);
+    EXPECT_EQ(NewCallBr->getParent(), BB0);
+  }
+  {
+    // Check create() InsertBefore
+    SmallVector<sandboxir::Value *> Args({Arg});
+    auto *InsertBefore = &*BB0->rbegin();
+    auto *NewCallBr = cast<sandboxir::CallBrInst>(sandboxir::CallBrInst::create(
+        F.getFunctionType(), &F, BB1, {BB2}, Args, InsertBefore, Ctx));
+    EXPECT_EQ(NewCallBr->getDefaultDest(), BB1);
+    EXPECT_EQ(NewCallBr->getIndirectDests().size(), 1u);
+    EXPECT_EQ(NewCallBr->getIndirectDests()[0], BB2);
+    EXPECT_EQ(NewCallBr->getNextNode(), InsertBefore);
+    EXPECT_EQ(NewCallBr->getParent(), BB0);
+  }
+  {
+    // Check create() InsertAtEnd.
+    SmallVector<sandboxir::Value *> Args({Arg});
+    auto *NewCallBr = cast<sandboxir::CallBrInst>(
+        sandboxir::CallBrInst::create(F.getFunctionType(), &F, BB1, {BB2}, Args,
+                                      /*InsertAtEnd=*/BB0, Ctx));
+    EXPECT_EQ(NewCallBr->getDefaultDest(), BB1);
+    EXPECT_EQ(NewCallBr->getIndirectDests().size(), 1u);
+    EXPECT_EQ(NewCallBr->getIndirectDests()[0], BB2);
+    EXPECT_EQ(NewCallBr->getNextNode(), nullptr);
+    EXPECT_EQ(NewCallBr->getParent(), BB0);
+  }
+}
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index 04536411e02d0..cd737d33dd193 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -544,3 +544,45 @@ define void @foo(i8 %arg) {
   Ctx.revert();
   EXPECT_EQ(Invoke->getSuccessor(1), ExceptionBB);
 }
+
+TEST_F(TrackerTest, CallBrSetters) {
+  parseIR(C, R"IR(
+define void @foo(i8 %arg) {
+ bb0:
+   callbr void @foo(i8 %arg)
+               to label %bb1 [label %bb2]
+ bb1:
+   ret void
+ bb2:
+   ret void
+ other_bb:
+   ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  // Only the side effect of creating the function inside Ctx is needed here.
+  [[maybe_unused]] auto &F = *Ctx.createFunction(&LLVMF);
+  auto GetBB = [&](const char *Name) {
+    return cast<sandboxir::BasicBlock>(
+        Ctx.getValue(getBasicBlockByName(LLVMF, Name)));
+  };
+  auto *OtherBB = GetBB("other_bb");
+  auto *CallBr = cast<sandboxir::CallBrInst>(&*GetBB("bb0")->begin());
+
+  // Check that Ctx.revert() undoes setDefaultDest().
+  Ctx.save();
+  auto *OrigDefaultDest = CallBr->getDefaultDest();
+  CallBr->setDefaultDest(OtherBB);
+  EXPECT_EQ(CallBr->getDefaultDest(), OtherBB);
+  Ctx.revert();
+  EXPECT_EQ(CallBr->getDefaultDest(), OrigDefaultDest);
+
+  // Check that Ctx.revert() undoes setIndirectDest().
+  Ctx.save();
+  auto *OrigIndirectDest = CallBr->getIndirectDest(0);
+  CallBr->setIndirectDest(0, OtherBB);
+  EXPECT_EQ(CallBr->getIndirectDest(0), OtherBB);
+  Ctx.revert();
+  EXPECT_EQ(CallBr->getIndirectDest(0), OrigIndirectDest);
+}



More information about the llvm-commits mailing list