[llvm] 3993da2 - [SandboxIR] Implement BranchInst (#100063)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 23 11:50:37 PDT 2024


Author: vporpo
Date: 2024-07-23T11:50:33-07:00
New Revision: 3993da23daa0ae75e9e80def76854534903e3761

URL: https://github.com/llvm/llvm-project/commit/3993da23daa0ae75e9e80def76854534903e3761
DIFF: https://github.com/llvm/llvm-project/commit/3993da23daa0ae75e9e80def76854534903e3761.diff

LOG: [SandboxIR] Implement BranchInst (#100063)

This patch implements sandboxir::BranchInst which mirrors
llvm::BranchInst.

BranchInst::swapSuccessors() relies on User::swapOperandsInternal() so
this patch also adds Use::swap() and the corresponding tracking code and
test.

Added: 
    

Modified: 
    llvm/include/llvm/SandboxIR/SandboxIR.h
    llvm/include/llvm/SandboxIR/SandboxIRValues.def
    llvm/include/llvm/SandboxIR/Tracker.h
    llvm/include/llvm/SandboxIR/Use.h
    llvm/lib/SandboxIR/SandboxIR.cpp
    llvm/lib/SandboxIR/Tracker.cpp
    llvm/unittests/SandboxIR/SandboxIRTest.cpp
    llvm/unittests/SandboxIR/TrackerTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 0c67206d307ef..6c04c92e3e70e 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -76,6 +76,7 @@ class Context;
 class Function;
 class Instruction;
 class SelectInst;
+class BranchInst;
 class LoadInst;
 class ReturnInst;
 class StoreInst;
@@ -179,6 +180,7 @@ class Value {
   friend class User;       // For getting `Val`.
   friend class Use;        // For getting `Val`.
   friend class SelectInst; // For getting `Val`.
+  friend class BranchInst; // For getting `Val`.
   friend class LoadInst;   // For getting `Val`.
   friend class StoreInst;  // For getting `Val`.
   friend class ReturnInst; // For getting `Val`.
@@ -343,6 +345,16 @@ class User : public Value {
   virtual unsigned getUseOperandNo(const Use &Use) const = 0;
   friend unsigned Use::getOperandNo() const; // For getUseOperandNo()
 
+  /// Swaps the operands at indices \p OpIdxA and \p OpIdxB of this User by
+  /// swapping their Uses (via Use::swap(), which records the change).
+  void swapOperandsInternal(unsigned OpIdxA, unsigned OpIdxB) {
+    assert(OpIdxA < getNumOperands() && "OpIdxA out of bounds!");
+    assert(OpIdxB < getNumOperands() && "OpIdxB out of bounds!");
+    Use First = getOperandUse(OpIdxA);
+    Use Second = getOperandUse(OpIdxB);
+    First.swap(Second);
+  }
+
 #ifndef NDEBUG
   void verifyUserOfLLVMUse(const llvm::Use &Use) const;
 #endif // NDEBUG
@@ -504,6 +514,7 @@ class Instruction : public sandboxir::User {
   /// returns its topmost LLVM IR instruction.
   llvm::Instruction *getTopmostLLVMInstruction() const;
   friend class SelectInst; // For getTopmostLLVMInstruction().
+  friend class BranchInst; // For getTopmostLLVMInstruction().
   friend class LoadInst;   // For getTopmostLLVMInstruction().
   friend class StoreInst;  // For getTopmostLLVMInstruction().
   friend class ReturnInst; // For getTopmostLLVMInstruction().
@@ -617,6 +628,132 @@ class SelectInst : public Instruction {
 #endif
 };
 
+class BranchInst final : public Instruction {
+  /// Use Context::createBranchInst(). Don't call the constructor directly.
+  BranchInst(llvm::BranchInst *BI, Context &Ctx)
+      : Instruction(ClassID::Br, Opcode::Br, BI, Ctx) {}
+  friend Context; // for BranchInst()
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+
+public:
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+  /// Creates an unconditional branch to \p IfTrue, inserted before
+  /// \p InsertBefore.
+  static BranchInst *create(BasicBlock *IfTrue, Instruction *InsertBefore,
+                            Context &Ctx);
+  /// Creates an unconditional branch to \p IfTrue, appended at the end of
+  /// \p InsertAtEnd.
+  static BranchInst *create(BasicBlock *IfTrue, BasicBlock *InsertAtEnd,
+                            Context &Ctx);
+  /// Creates a conditional branch on \p Cond to \p IfTrue (taken when true)
+  /// or \p IfFalse, inserted before \p InsertBefore.
+  static BranchInst *create(BasicBlock *IfTrue, BasicBlock *IfFalse,
+                            Value *Cond, Instruction *InsertBefore,
+                            Context &Ctx);
+  /// Creates a conditional branch on \p Cond to \p IfTrue (taken when true)
+  /// or \p IfFalse, appended at the end of \p InsertAtEnd.
+  static BranchInst *create(BasicBlock *IfTrue, BasicBlock *IfFalse,
+                            Value *Cond, BasicBlock *InsertAtEnd, Context &Ctx);
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From);
+  bool isUnconditional() const {
+    return cast<llvm::BranchInst>(Val)->isUnconditional();
+  }
+  bool isConditional() const {
+    return cast<llvm::BranchInst>(Val)->isConditional();
+  }
+  /// \Returns the branch condition. Only valid for conditional branches.
+  Value *getCondition() const;
+  /// Sets the branch condition to \p V. Only valid for conditional branches:
+  /// an unconditional branch has no condition and its operand 0 is the
+  /// destination block, which must not be overwritten here.
+  void setCondition(Value *V) {
+    assert(isConditional() && "Cannot set condition of unconditional branch!");
+    setOperand(0, V);
+  }
+  /// \Returns 2 for a conditional branch, 1 for an unconditional one.
+  unsigned getNumSuccessors() const { return 1 + isConditional(); }
+  /// \Returns successor \p SuccIdx; index 0 is the "true" destination.
+  BasicBlock *getSuccessor(unsigned SuccIdx) const;
+  /// Replaces successor \p Idx (indexed as in getSuccessor()) with \p NewSucc.
+  void setSuccessor(unsigned Idx, BasicBlock *NewSucc);
+  /// Swaps the "true" and "false" destinations. Only valid for conditional
+  /// branches. NOTE: unlike llvm::BranchInst::swapSuccessors() this only
+  /// swaps the operands and does not swap any branch-weight metadata.
+  void swapSuccessors() {
+    assert(isConditional() &&
+           "Cannot swap successors of an unconditional branch!");
+    swapOperandsInternal(1, 2);
+  }
+
+private:
+  /// Maps an llvm::BasicBlock to the corresponding sandboxir::BasicBlock.
+  struct LLVMBBToSBBB {
+    Context &Ctx;
+    explicit LLVMBBToSBBB(Context &Ctx) : Ctx(Ctx) {}
+    BasicBlock *operator()(llvm::BasicBlock *BB) const;
+  };
+
+  /// Const counterpart of LLVMBBToSBBB.
+  struct ConstLLVMBBToSBBB {
+    Context &Ctx;
+    explicit ConstLLVMBBToSBBB(Context &Ctx) : Ctx(Ctx) {}
+    const BasicBlock *operator()(const llvm::BasicBlock *BB) const;
+  };
+
+public:
+  using sb_succ_op_iterator =
+      mapped_iterator<llvm::BranchInst::succ_op_iterator, LLVMBBToSBBB>;
+  /// \Returns the successors in operand order, like
+  /// llvm::BranchInst::successors(). For a conditional branch this visits
+  /// the "false" destination before the "true" one, i.e. the reverse of the
+  /// getSuccessor() index order.
+  iterator_range<sb_succ_op_iterator> successors() {
+    iterator_range<llvm::BranchInst::succ_op_iterator> LLVMRange =
+        cast<llvm::BranchInst>(Val)->successors();
+    LLVMBBToSBBB BBMap(Ctx);
+    sb_succ_op_iterator MappedBegin = map_iterator(LLVMRange.begin(), BBMap);
+    sb_succ_op_iterator MappedEnd = map_iterator(LLVMRange.end(), BBMap);
+    return make_range(MappedBegin, MappedEnd);
+  }
+
+  using const_sb_succ_op_iterator =
+      mapped_iterator<llvm::BranchInst::const_succ_op_iterator,
+                      ConstLLVMBBToSBBB>;
+  /// Const version of successors(); same visiting order.
+  iterator_range<const_sb_succ_op_iterator> successors() const {
+    iterator_range<llvm::BranchInst::const_succ_op_iterator> ConstLLVMRange =
+        static_cast<const llvm::BranchInst *>(cast<llvm::BranchInst>(Val))
+            ->successors();
+    ConstLLVMBBToSBBB ConstBBMap(Ctx);
+    const_sb_succ_op_iterator ConstMappedBegin =
+        map_iterator(ConstLLVMRange.begin(), ConstBBMap);
+    const_sb_succ_op_iterator ConstMappedEnd =
+        map_iterator(ConstLLVMRange.end(), ConstBBMap);
+    return make_range(ConstMappedBegin, ConstMappedEnd);
+  }
+
+#ifndef NDEBUG
+  void verify() const final {
+    assert(isa<llvm::BranchInst>(Val) && "Expected BranchInst!");
+  }
+  friend raw_ostream &operator<<(raw_ostream &OS, const BranchInst &BI) {
+    BI.dump(OS);
+    return OS;
+  }
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
 class LoadInst final : public Instruction {
   /// Use LoadInst::create() instead of calling the constructor.
   LoadInst(llvm::LoadInst *LI, Context &Ctx)
@@ -870,6 +975,8 @@ class Context {
 
   SelectInst *createSelectInst(llvm::SelectInst *SI);
   friend SelectInst; // For createSelectInst()
+  BranchInst *createBranchInst(llvm::BranchInst *I);
+  friend BranchInst; // For createBranchInst()
   LoadInst *createLoadInst(llvm::LoadInst *LI);
   friend LoadInst; // For createLoadInst()
   StoreInst *createStoreInst(llvm::StoreInst *SI);

diff  --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index efa9155755587..f3d616774b3fd 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -26,6 +26,7 @@ DEF_USER(Constant, Constant)
 //       ClassID, Opcode(s),  Class
 DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
 DEF_INSTR(Select, OP(Select), SelectInst)
+DEF_INSTR(Br, OP(Br), BranchInst)
 DEF_INSTR(Load, OP(Load), LoadInst)
 DEF_INSTR(Store, OP(Store), StoreInst)
 DEF_INSTR(Ret, OP(Ret), ReturnInst)

diff  --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index b88eb3d2a5280..3daec3fd5c63c 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -101,6 +101,34 @@ class UseSet : public IRChangeBase {
 #endif
 };
 
+/// Records that two Uses of the same User were swapped with each other so
+/// that the swap can be undone. Swapping is its own inverse: revert() swaps
+/// the two Uses back, and accept() has nothing to clean up.
+class UseSwap : public IRChangeBase {
+  Use ThisUse;
+  Use OtherUse;
+
+public:
+  UseSwap(const Use &ThisUse, const Use &OtherUse, Tracker &Tracker)
+      : IRChangeBase(Tracker), ThisUse(ThisUse), OtherUse(OtherUse) {
+    assert(ThisUse.getUser() == OtherUse.getUser() && "Expected same user!");
+  }
+  void revert() final {
+    // Swapping the same pair of Uses a second time restores the operands.
+    ThisUse.swap(OtherUse);
+  }
+  void accept() final {
+    // Nothing to do: Use::swap() already applied the change to the IR.
+  }
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final {
+    dumpCommon(OS);
+    OS << "UseSwap";
+  }
+  LLVM_DUMP_METHOD void dump() const final;
+#endif
+};
+
 class EraseFromParent : public IRChangeBase {
   /// Contains all the data we need to restore an "erased" (i.e., detached)
   /// instruction: the instruction itself and its operands in order.

diff  --git a/llvm/include/llvm/SandboxIR/Use.h b/llvm/include/llvm/SandboxIR/Use.h
index d77b4568d0fab..03cbfe6cb0446 100644
--- a/llvm/include/llvm/SandboxIR/Use.h
+++ b/llvm/include/llvm/SandboxIR/Use.h
@@ -47,6 +47,7 @@ class Use {
   void set(Value *V);
   class User *getUser() const { return Usr; }
   unsigned getOperandNo() const;
+  void swap(Use &OtherUse);
   Context *getContext() const { return Ctx; }
   bool operator==(const Use &Other) const {
     assert(Ctx == Other.Ctx && "Contexts differ!");

diff  --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 51c9af8a6e1fe..ceadb34f53eaf 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -20,6 +20,15 @@ void Use::set(Value *V) { LLVMUse->set(V->Val); }
 
 unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); }
 
+// Swaps the values of this Use and \p OtherUse in the underlying LLVM IR. If
+// the tracker is recording, a UseSwap is logged first so it can be reverted.
+void Use::swap(Use &OtherUse) {
+  auto &Trk = Ctx->getTracker();
+  if (Trk.isTracking())
+    Trk.track(std::make_unique<UseSwap>(*this, OtherUse, Trk));
+  LLVMUse->swap(*OtherUse.LLVMUse);
+}
+
 #ifndef NDEBUG
 void Use::dump(raw_ostream &OS) const {
   Value *Def = nullptr;
@@ -500,6 +507,89 @@ void SelectInst::dump() const {
 }
 #endif // NDEBUG
 
+BranchInst *BranchInst::create(BasicBlock *IfTrue, Instruction *InsertBefore,
+                               Context &Ctx) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(cast<llvm::Instruction>(InsertBefore->Val));
+  llvm::BranchInst *NewBr =
+      Builder.CreateBr(cast<llvm::BasicBlock>(IfTrue->Val));
+  return Ctx.createBranchInst(NewBr);
+}
+
+BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *InsertAtEnd,
+                               Context &Ctx) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
+  llvm::BranchInst *NewBr =
+      Builder.CreateBr(cast<llvm::BasicBlock>(IfTrue->Val));
+  return Ctx.createBranchInst(NewBr);
+}
+
+BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *IfFalse,
+                               Value *Cond, Instruction *InsertBefore,
+                               Context &Ctx) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(cast<llvm::Instruction>(InsertBefore->Val));
+  llvm::BranchInst *NewBr =
+      Builder.CreateCondBr(Cond->Val, cast<llvm::BasicBlock>(IfTrue->Val),
+                           cast<llvm::BasicBlock>(IfFalse->Val));
+  return Ctx.createBranchInst(NewBr);
+}
+
+BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *IfFalse,
+                               Value *Cond, BasicBlock *InsertAtEnd,
+                               Context &Ctx) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
+  llvm::BranchInst *NewBr =
+      Builder.CreateCondBr(Cond->Val, cast<llvm::BasicBlock>(IfTrue->Val),
+                           cast<llvm::BasicBlock>(IfFalse->Val));
+  return Ctx.createBranchInst(NewBr);
+}
+
+bool BranchInst::classof(const Value *From) {
+  return From->getSubclassID() == ClassID::Br;
+}
+
+Value *BranchInst::getCondition() const {
+  assert(isConditional() && "Cannot get condition of an uncond branch!");
+  return Ctx.getValue(cast<llvm::BranchInst>(Val)->getCondition());
+}
+
+BasicBlock *BranchInst::getSuccessor(unsigned SuccIdx) const {
+  assert(SuccIdx < getNumSuccessors() &&
+         "Successor # out of range for Branch!");
+  return cast_or_null<BasicBlock>(
+      Ctx.getValue(cast<llvm::BranchInst>(Val)->getSuccessor(SuccIdx)));
+}
+
+void BranchInst::setSuccessor(unsigned Idx, BasicBlock *NewSucc) {
+  assert(Idx < getNumSuccessors() && "Out of bounds!");
+  // Successors are the trailing operands in reverse order (the last operand
+  // is successor 0), as in llvm::BranchInst. A conditional branch has three
+  // operands and an unconditional one has just one, so index from the end
+  // instead of hard-coding the operand count.
+  setOperand(getNumOperands() - 1u - Idx, NewSucc);
+}
+
+BasicBlock *BranchInst::LLVMBBToSBBB::operator()(llvm::BasicBlock *BB) const {
+  return cast<BasicBlock>(Ctx.getValue(BB));
+}
+const BasicBlock *
+BranchInst::ConstLLVMBBToSBBB::operator()(const llvm::BasicBlock *BB) const {
+  return cast<BasicBlock>(Ctx.getValue(BB));
+}
+#ifndef NDEBUG
+void BranchInst::dump(raw_ostream &OS) const {
+  dumpCommonPrefix(OS);
+  dumpCommonSuffix(OS);
+}
+void BranchInst::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
 LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
                            Instruction *InsertBefore, Context &Ctx,
                            const Twine &Name) {
@@ -758,6 +844,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
     return It->second.get();
   }
+  case llvm::Instruction::Br: {
+    auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
+    It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::Load: {
     auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
     It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
@@ -796,6 +887,14 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
   return cast<SelectInst>(registerValue(std::move(NewPtr)));
 }
 
+// Wraps \p BI in a new sandboxir::BranchInst and hands it to registerValue(),
+// which takes ownership of it.
+BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
+  std::unique_ptr<BranchInst> NewBr(new BranchInst(BI, *this));
+  auto *Registered = registerValue(std::move(NewBr));
+  return cast<BranchInst>(Registered);
+}
+
 LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
   auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
   return cast<LoadInst>(registerValue(std::move(NewPtr)));

diff  --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index 626c9c27d05e5..c74177608aff2 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -35,6 +35,13 @@ void UseSet::dump() const {
   dump(dbgs());
   dbgs() << "\n";
 }
+
+// Print this change to dbgs(), terminated by a newline.
+void UseSwap::dump() const {
+  auto &OS = dbgs();
+  dump(OS);
+  OS << "\n";
+}
 #endif // NDEBUG
 
 Tracker::~Tracker() {

diff  --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index ba90b4f811f8e..783f606c70380 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -398,7 +398,7 @@ define void @foo(i32 %arg0, i32 %arg1) {
     EXPECT_EQ(Buff, R"IR(
 void @foo(i32 %arg0, i32 %arg1) {
 bb0:
-  br label %bb1 ; SB3. (Opaque)
+  br label %bb1 ; SB3. (Br)
 
 bb1:
   ret void ; SB5. (Ret)
@@ -466,7 +466,7 @@ define void @foo(i32 %v1) {
     BB0.dump(BS);
     EXPECT_EQ(Buff, R"IR(
 bb0:
-  br label %bb1 ; SB2. (Opaque)
+  br label %bb1 ; SB2. (Br)
 )IR");
   }
 #endif // NDEBUG
@@ -629,6 +629,120 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
   }
 }
 
+TEST_F(SandboxIRTest, BranchInst) {
+  parseIR(C, R"IR(
+define void @foo(i1 %cond0, i1 %cond2) {
+ bb0:
+   br i1 %cond0, label %bb1, label %bb2
+ bb1:
+   ret void
+ bb2:
+   ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(LLVMF);
+  auto *Cond0 = F->getArg(0);
+  auto *Cond1 = F->getArg(1);
+  auto *BB0 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(*LLVMF, "bb0")));
+  auto *BB1 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(*LLVMF, "bb1")));
+  auto *Ret1 = BB1->getTerminator();
+  auto *BB2 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(*LLVMF, "bb2")));
+  auto *Ret2 = BB2->getTerminator();
+  auto It = BB0->begin();
+  auto *Br0 = cast<sandboxir::BranchInst>(&*It++);
+  // Check isUnconditional().
+  EXPECT_FALSE(Br0->isUnconditional());
+  // Check isConditional().
+  EXPECT_TRUE(Br0->isConditional());
+  // Check getCondition().
+  EXPECT_EQ(Br0->getCondition(), Cond0);
+  // Check setCondition().
+  Br0->setCondition(Cond1);
+  EXPECT_EQ(Br0->getCondition(), Cond1);
+  // Check getNumSuccessors().
+  EXPECT_EQ(Br0->getNumSuccessors(), 2u);
+  // Check getSuccessor().
+  EXPECT_EQ(Br0->getSuccessor(0), BB1);
+  EXPECT_EQ(Br0->getSuccessor(1), BB2);
+  // Check swapSuccessors().
+  Br0->swapSuccessors();
+  EXPECT_EQ(Br0->getSuccessor(0), BB2);
+  EXPECT_EQ(Br0->getSuccessor(1), BB1);
+  // Check successors().
+  EXPECT_EQ(range_size(Br0->successors()), 2u);
+  unsigned SuccIdx = 0;
+  SmallVector<sandboxir::BasicBlock *> ExpectedSuccs({BB1, BB2});
+  for (sandboxir::BasicBlock *Succ : Br0->successors())
+    EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]);
+  // Check setSuccessor(). The successors are currently {BB2, BB1}.
+  Br0->setSuccessor(0, BB1);
+  EXPECT_EQ(Br0->getSuccessor(0), BB1);
+  Br0->setSuccessor(1, BB2);
+  EXPECT_EQ(Br0->getSuccessor(1), BB2);
+
+  {
+    // Check unconditional BranchInst::create() InsertBefore.
+    auto *Br = sandboxir::BranchInst::create(BB1, /*InsertBefore=*/Ret1, Ctx);
+    EXPECT_FALSE(Br->isConditional());
+    EXPECT_TRUE(Br->isUnconditional());
+#ifndef NDEBUG
+    EXPECT_DEATH(Br->getCondition(), ".*condition.*");
+#endif
+    EXPECT_EQ(range_size(Br->successors()), 1u);
+    unsigned SuccIdx = 0;
+    SmallVector<sandboxir::BasicBlock *> ExpectedSuccs({BB1});
+    for (sandboxir::BasicBlock *Succ : Br->successors())
+      EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]);
+    EXPECT_EQ(Br->getNextNode(), Ret1);
+  }
+  {
+    // Check unconditional BranchInst::create() InsertAtEnd.
+    auto *Br = sandboxir::BranchInst::create(BB1, /*InsertAtEnd=*/BB1, Ctx);
+    EXPECT_FALSE(Br->isConditional());
+    EXPECT_TRUE(Br->isUnconditional());
+#ifndef NDEBUG
+    EXPECT_DEATH(Br->getCondition(), ".*condition.*");
+#endif
+    EXPECT_EQ(range_size(Br->successors()), 1u);
+    unsigned SuccIdx = 0;
+    SmallVector<sandboxir::BasicBlock *> ExpectedSuccs({BB1});
+    for (sandboxir::BasicBlock *Succ : Br->successors())
+      EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]);
+    EXPECT_EQ(Br->getPrevNode(), Ret1);
+  }
+  {
+    // Check conditional BranchInst::create() InsertBefore.
+    auto *Br = sandboxir::BranchInst::create(BB1, BB2, Cond0,
+                                             /*InsertBefore=*/Ret1, Ctx);
+    EXPECT_TRUE(Br->isConditional());
+    EXPECT_EQ(Br->getCondition(), Cond0);
+    EXPECT_EQ(range_size(Br->successors()), 2u);
+    unsigned SuccIdx = 0;
+    SmallVector<sandboxir::BasicBlock *> ExpectedSuccs({BB2, BB1});
+    for (sandboxir::BasicBlock *Succ : Br->successors())
+      EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]);
+    EXPECT_EQ(Br->getNextNode(), Ret1);
+  }
+  {
+    // Check conditional BranchInst::create() InsertAtEnd.
+    auto *Br = sandboxir::BranchInst::create(BB1, BB2, Cond0,
+                                             /*InsertAtEnd=*/BB2, Ctx);
+    EXPECT_TRUE(Br->isConditional());
+    EXPECT_EQ(Br->getCondition(), Cond0);
+    EXPECT_EQ(range_size(Br->successors()), 2u);
+    unsigned SuccIdx = 0;
+    SmallVector<sandboxir::BasicBlock *> ExpectedSuccs({BB2, BB1});
+    for (sandboxir::BasicBlock *Succ : Br->successors())
+      EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]);
+    EXPECT_EQ(Br->getPrevNode(), Ret2);
+  }
+}
+
 TEST_F(SandboxIRTest, LoadInst) {
   parseIR(C, R"IR(
 define void @foo(ptr %arg0, ptr %arg1) {

diff  --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index 354cd187adb10..dd9dcd543236e 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -69,6 +69,59 @@ define void @foo(ptr %ptr) {
   EXPECT_EQ(Ld->getOperand(0), Gep0);
 }
 
+TEST_F(TrackerTest, SwapOperands) {
+  parseIR(C, R"IR(
+define void @foo(i1 %cond) {
+ bb0:
+   br i1 %cond, label %bb1, label %bb2
+ bb1:
+   ret void
+ bb2:
+   ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  Ctx.createFunction(&LLVMF);
+  auto *BB0 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
+  auto *BB1 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb1")));
+  auto *BB2 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb2")));
+  auto &Tracker = Ctx.getTracker();
+  Tracker.save();
+  auto It = BB0->begin();
+  auto *Br = cast<sandboxir::BranchInst>(&*It++);
+
+  // successors() visits the operands in order, i.e. the "false" destination
+  // (BB2) before the "true" destination (BB1).
+  EXPECT_EQ(Br->getSuccessor(0), BB1);
+  EXPECT_EQ(Br->getSuccessor(1), BB2);
+  EXPECT_EQ(range_size(Br->successors()), 2u);
+  unsigned SuccIdx = 0;
+  SmallVector<sandboxir::BasicBlock *> ExpectedSuccs({BB2, BB1});
+  for (auto *Succ : Br->successors())
+    EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]);
+
+  // This calls User::swapOperandsInternal() internally.
+  Br->swapSuccessors();
+  EXPECT_EQ(Br->getSuccessor(0), BB2);
+  EXPECT_EQ(Br->getSuccessor(1), BB1);
+
+  SuccIdx = 0;
+  for (auto *Succ : reverse(Br->successors()))
+    EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]);
+
+  // Revert using the same Tracker that recorded the change.
+  Tracker.revert();
+  EXPECT_EQ(Br->getSuccessor(0), BB1);
+  EXPECT_EQ(Br->getSuccessor(1), BB2);
+  SuccIdx = 0;
+  for (auto *Succ : Br->successors())
+    EXPECT_EQ(Succ, ExpectedSuccs[SuccIdx++]);
+}
+
 TEST_F(TrackerTest, RUWIf_RAUW_RUOW) {
   parseIR(C, R"IR(
 define void @foo(ptr %ptr) {


        


More information about the llvm-commits mailing list