[llvm] [SandboxIR] Implement PHINodes (PR #101111)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 31 11:30:43 PDT 2024


https://github.com/Sterling-Augustine updated https://github.com/llvm/llvm-project/pull/101111

>From 2dc9acaa6daa1369b7fa4b081010a986a0d3fdc7 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 30 Jul 2024 01:10:24 +0000
Subject: [PATCH 1/3] [SandboxIR] Implement PHINodes

This patch implements sandboxir::PHINode which mirrors llvm::PHINode.

Based almost entirely on work by vporpo.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       |  95 ++++++++++++++
 .../llvm/SandboxIR/SandboxIRValues.def        |   1 +
 llvm/include/llvm/SandboxIR/Tracker.h         |  58 ++++++++
 llvm/include/llvm/SandboxIR/Use.h             |   2 +
 llvm/lib/SandboxIR/SandboxIR.cpp              | 111 ++++++++++++++++
 llvm/lib/SandboxIR/Tracker.cpp                |  75 +++++++++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 120 +++++++++++++++++
 llvm/unittests/SandboxIR/TrackerTest.cpp      | 124 ++++++++++++++++++
 8 files changed, 586 insertions(+)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 01aa6146c7228..02a61bf0a0ecb 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -210,6 +210,7 @@ class Value {
   friend class InvokeInst;        // For getting `Val`.
   friend class CallBrInst;        // For getting `Val`.
   friend class GetElementPtrInst; // For getting `Val`.
+  friend class PHINode;           // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -551,6 +552,7 @@ class Instruction : public sandboxir::User {
   friend class InvokeInst;        // For getTopmostLLVMInstruction().
   friend class CallBrInst;        // For getTopmostLLVMInstruction().
   friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
+  friend class PHINode;           // For getTopmostLLVMInstruction().
 
   /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
   /// order.
@@ -1290,6 +1292,97 @@ class GetElementPtrInst final : public Instruction {
 #endif
 };
 
+class PHINode final : public Instruction {
+  /// Use Context::createPHINode(). Don't call the
+  /// constructor directly.
+  PHINode(llvm::PHINode *PHI, Context &Ctx)
+      : Instruction(ClassID::PHI, Opcode::PHI, PHI, Ctx) {}
+  friend Context; // For the private PHINode() constructor.
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+  /// Helper for mapped_iterator: maps an llvm::BasicBlock to its BasicBlock.
+  struct LLVMBBToBB {
+    Context &Ctx;
+    LLVMBBToBB(Context &Ctx) : Ctx(Ctx) {}
+    BasicBlock *operator()(llvm::BasicBlock *LLVMBB) const;
+  };
+
+public:
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+  static PHINode *create(Type *Ty, unsigned NumReservedValues,
+                         Instruction *InsertBefore, Context &Ctx,
+                         const Twine &Name = "");
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From);
+
+  using const_block_iterator =
+      mapped_iterator<llvm::PHINode::const_block_iterator, LLVMBBToBB>;
+
+  const_block_iterator block_begin() const {
+    LLVMBBToBB BBGetter(Ctx);
+    return const_block_iterator(cast<llvm::PHINode>(Val)->block_begin(),
+                                BBGetter);
+  }
+  const_block_iterator block_end() const {
+    LLVMBBToBB BBGetter(Ctx);
+    return const_block_iterator(cast<llvm::PHINode>(Val)->block_end(),
+                                BBGetter);
+  }
+  iterator_range<const_block_iterator> blocks() const {
+    return make_range(block_begin(), block_end());
+  }
+
+  op_range incoming_values() { return operands(); }
+
+  const_op_range incoming_values() const { return operands(); }
+
+  unsigned getNumIncomingValues() const {
+    return cast<llvm::PHINode>(Val)->getNumIncomingValues();
+  }
+  Value *getIncomingValue(unsigned Idx) const;
+  void setIncomingValue(unsigned Idx, Value *V);
+  static unsigned getOperandNumForIncomingValue(unsigned Idx) {
+    return llvm::PHINode::getOperandNumForIncomingValue(Idx);
+  }
+  static unsigned getIncomingValueNumForOperand(unsigned Idx) {
+    return llvm::PHINode::getIncomingValueNumForOperand(Idx);
+  }
+  BasicBlock *getIncomingBlock(unsigned Idx) const;
+  BasicBlock *getIncomingBlock(const Use &U) const;
+
+  void setIncomingBlock(unsigned Idx, BasicBlock *BB);
+
+  void addIncoming(Value *V, BasicBlock *BB);
+
+  Value *removeIncomingValue(unsigned Idx);
+  Value *removeIncomingValue(BasicBlock *BB);
+
+  int getBasicBlockIndex(const BasicBlock *BB) const;
+  Value *getIncomingValueForBlock(const BasicBlock *BB) const;
+
+  Value *hasConstantValue() const;
+
+  bool hasConstantOrUndefValue() const {
+    return cast<llvm::PHINode>(Val)->hasConstantOrUndefValue();
+  }
+  bool isComplete() const { return cast<llvm::PHINode>(Val)->isComplete(); }
+
+#ifndef NDEBUG
+  void verify() const final {
+    assert(isa<llvm::PHINode>(Val) && "Expected PHINode!");
+  }
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
 /// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
 /// an OpaqueInstr.
 class OpaqueInst : public sandboxir::Instruction {
@@ -1446,6 +1539,8 @@ class Context {
   friend CallBrInst; // For createCallBrInst()
   GetElementPtrInst *createGetElementPtrInst(llvm::GetElementPtrInst *I);
   friend GetElementPtrInst; // For createGetElementPtrInst()
+  PHINode *createPHINode(llvm::PHINode *I);
+  friend PHINode; // For createPHINode()
 
 public:
   Context(LLVMContext &LLVMCtx)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 29fb3aa396fdb..d5b6f1c51d7d1 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -34,6 +34,7 @@ DEF_INSTR(Call, OP(Call), CallInst)
 DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
 DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
 DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
+DEF_INSTR(PHI, OP(PHI), PHINode)
 
 #ifdef DEF_VALUE
 #undef DEF_VALUE
diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index 64068461b9490..238e4e9dacd34 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -102,6 +102,64 @@ class UseSet : public IRChangeBase {
 #endif
 };
 
+class PHISetIncoming : public IRChangeBase {
+  PHINode &PHI;
+  unsigned Idx;
+  PointerUnion<Value *, BasicBlock *> OrigValueOrBB;
+
+public:
+  enum class What {
+    Value,
+    Block,
+  };
+  PHISetIncoming(PHINode &PHI, unsigned Idx, What What, Tracker &Tracker);
+  void revert() final;
+  void accept() final {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final {
+    dumpCommon(OS);
+    OS << "PHISetIncoming";
+  }
+  LLVM_DUMP_METHOD void dump() const final;
+#endif
+};
+
+/// Tracks the removal of an incoming (value, block) pair from a PHINode, so
+/// that revert() can re-insert the pair at its original index.
+class PHIRemoveIncoming : public IRChangeBase {
+  PHINode &PHI;
+  /// Index the removed pair occupied before the removal.
+  unsigned RemovedIdx;
+  /// The incoming value and block at RemovedIdx, captured by the constructor.
+  Value *RemovedV;
+  BasicBlock *RemovedBB;
+
+public:
+  PHIRemoveIncoming(PHINode &PHI, unsigned RemovedIdx, Tracker &Tracker);
+  void revert() final;
+  void accept() final {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final {
+    dumpCommon(OS);
+    // Print this class's own name so tracker dumps identify the change kind.
+    OS << "PHIRemoveIncoming";
+  }
+  LLVM_DUMP_METHOD void dump() const final;
+#endif
+};
+
+/// Tracks appending an incoming (value, block) pair to a PHINode. revert()
+/// removes the incoming entry at the recorded index.
+class PHIAddIncoming : public IRChangeBase {
+  PHINode &PHI;
+  /// Number of incoming values when this change was constructed, i.e. the
+  /// index of the entry that revert() removes.
+  unsigned Idx;
+
+public:
+  PHIAddIncoming(PHINode &PHI, Tracker &Tracker);
+  void revert() final;
+  void accept() final {}
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final {
+    dumpCommon(OS);
+    // Print this class's own name so tracker dumps identify the change kind.
+    OS << "PHIAddIncoming";
+  }
+  LLVM_DUMP_METHOD void dump() const final;
+#endif
+};
+
 /// Tracks swapping a Use with another Use.
 class UseSwap : public IRChangeBase {
   Use ThisUse;
diff --git a/llvm/include/llvm/SandboxIR/Use.h b/llvm/include/llvm/SandboxIR/Use.h
index ef728ea387851..35d01daf39f6e 100644
--- a/llvm/include/llvm/SandboxIR/Use.h
+++ b/llvm/include/llvm/SandboxIR/Use.h
@@ -22,6 +22,7 @@ class Context;
 class Value;
 class User;
 class CallBase;
+class PHINode;
 
 /// Represents a Def-use/Use-def edge in SandboxIR.
 /// NOTE: Unlike llvm::Use, this is not an integral part of the use-def chains.
@@ -43,6 +44,7 @@ class Use {
   friend class UserUseIterator;    // For accessing members
   friend class CallBase;           // For LLVMUse
   friend class CallBrInst;         // For constructor
+  friend class PHINode;            // For LLVMUse
 
 public:
   operator Value *() const { return get(); }
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 317a37bbaf6eb..f8cb715198e89 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1061,6 +1061,107 @@ void GetElementPtrInst::dump() const {
   dump(dbgs());
   dbgs() << "\n";
 }
+#endif // NDEBUG
+
+BasicBlock *PHINode::LLVMBBToBB::operator()(llvm::BasicBlock *LLVMBB) const {
+  return cast<BasicBlock>(Ctx.getValue(LLVMBB));
+}
+
+PHINode *PHINode::create(Type *Ty, unsigned NumReservedValues,
+                         Instruction *InsertBefore, Context &Ctx,
+                         const Twine &Name) {
+  llvm::PHINode *NewPHI = llvm::PHINode::Create(
+      Ty, NumReservedValues, Name, InsertBefore->getTopmostLLVMInstruction());
+  return Ctx.createPHINode(NewPHI);
+}
+
+bool PHINode::classof(const Value *From) {
+  return From->getSubclassID() == ClassID::PHI;
+}
+
+Value *PHINode::getIncomingValue(unsigned Idx) const {
+  return Ctx.getValue(cast<llvm::PHINode>(Val)->getIncomingValue(Idx));
+}
+void PHINode::setIncomingValue(unsigned Idx, Value *V) {
+  auto &Tracker = Ctx.getTracker();
+  if (Tracker.isTracking())
+    Tracker.track(std::make_unique<PHISetIncoming>(
+        *this, Idx, PHISetIncoming::What::Value, Tracker));
+
+  cast<llvm::PHINode>(Val)->setIncomingValue(Idx, V->Val);
+}
+BasicBlock *PHINode::getIncomingBlock(unsigned Idx) const {
+  return cast<BasicBlock>(
+      Ctx.getValue(cast<llvm::PHINode>(Val)->getIncomingBlock(Idx)));
+}
+BasicBlock *PHINode::getIncomingBlock(const Use &U) const {
+  llvm::Use *LLVMUse = U.LLVMUse;
+  llvm::BasicBlock *BB = cast<llvm::PHINode>(Val)->getIncomingBlock(*LLVMUse);
+  return cast<BasicBlock>(Ctx.getValue(BB));
+}
+void PHINode::setIncomingBlock(unsigned Idx, BasicBlock *BB) {
+  auto &Tracker = Ctx.getTracker();
+  if (Tracker.isTracking())
+    Tracker.track(std::make_unique<PHISetIncoming>(
+        *this, Idx, PHISetIncoming::What::Block, Tracker));
+  cast<llvm::PHINode>(Val)->setIncomingBlock(Idx,
+                                             cast<llvm::BasicBlock>(BB->Val));
+}
+void PHINode::addIncoming(Value *V, BasicBlock *BB) {
+  auto &Tracker = Ctx.getTracker();
+  if (Tracker.isTracking())
+    Tracker.track(std::make_unique<PHIAddIncoming>(*this, Tracker));
+
+  cast<llvm::PHINode>(Val)->addIncoming(V->Val,
+                                        cast<llvm::BasicBlock>(BB->Val));
+}
+Value *PHINode::removeIncomingValue(unsigned Idx) {
+  auto &Tracker = Ctx.getTracker();
+  if (Tracker.isTracking())
+    Tracker.track(std::make_unique<PHIRemoveIncoming>(*this, Idx, Tracker));
+
+  llvm::Value *LLVMV =
+      cast<llvm::PHINode>(Val)->removeIncomingValue(Idx,
+                                                    /*DeletePHIIfEmpty=*/false);
+  return Ctx.getValue(LLVMV);
+}
+Value *PHINode::removeIncomingValue(BasicBlock *BB) {
+  auto &Tracker = Ctx.getTracker();
+  if (Tracker.isTracking())
+    Tracker.track(std::make_unique<PHIRemoveIncoming>(
+        *this, getBasicBlockIndex(BB), Tracker));
+
+  auto *LLVMBB = cast<llvm::BasicBlock>(BB->Val);
+  llvm::Value *LLVMV =
+      cast<llvm::PHINode>(Val)->removeIncomingValue(LLVMBB,
+                                                    /*DeletePHIIfEmpty=*/false);
+  return Ctx.getValue(LLVMV);
+}
+int PHINode::getBasicBlockIndex(const BasicBlock *BB) const {
+  auto *LLVMBB = cast<llvm::BasicBlock>(BB->Val);
+  return cast<llvm::PHINode>(Val)->getBasicBlockIndex(LLVMBB);
+}
+Value *PHINode::getIncomingValueForBlock(const BasicBlock *BB) const {
+  auto *LLVMBB = cast<llvm::BasicBlock>(BB->Val);
+  llvm::Value *LLVMV =
+      cast<llvm::PHINode>(Val)->getIncomingValueForBlock(LLVMBB);
+  return Ctx.getValue(LLVMV);
+}
+Value *PHINode::hasConstantValue() const {
+  llvm::Value *LLVMV = cast<llvm::PHINode>(Val)->hasConstantValue();
+  return LLVMV != nullptr ? Ctx.getValue(LLVMV) : nullptr;
+}
+
+#ifndef NDEBUG
+void PHINode::dump(raw_ostream &OS) const {
+  dumpCommonPrefix(OS);
+  dumpCommonSuffix(OS);
+}
+
+void PHINode::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
 
 void OpaqueInst::dump(raw_ostream &OS) const {
   dumpCommonPrefix(OS);
@@ -1236,6 +1337,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
         new GetElementPtrInst(LLVMGEP, *this));
     return It->second.get();
   }
+  case llvm::Instruction::PHI: {
+    auto *LLVMPhi = cast<llvm::PHINode>(LLVMV);
+    It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
+    return It->second.get();
+  }
   default:
     break;
   }
@@ -1301,6 +1407,11 @@ Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
   return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
 }
 
+PHINode *Context::createPHINode(llvm::PHINode *I) {
+  auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
+  return cast<PHINode>(registerValue(std::move(NewPtr)));
+}
+
 Value *Context::getValue(llvm::Value *V) const {
   auto It = LLVMValueToValueMap.find(V);
   if (It != LLVMValueToValueMap.end())
diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index eae55d7b3d962..0310160e8bf35 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -42,6 +42,81 @@ void UseSwap::dump() const {
 }
 #endif // NDEBUG
 
+PHISetIncoming::PHISetIncoming(PHINode &PHI, unsigned Idx, What What,
+                               Tracker &Tracker)
+    : IRChangeBase(Tracker), PHI(PHI), Idx(Idx) {
+  switch (What) {
+  case What::Value:
+    OrigValueOrBB = PHI.getIncomingValue(Idx);
+    break;
+  case What::Block:
+    OrigValueOrBB = PHI.getIncomingBlock(Idx);
+    break;
+  }
+}
+
+void PHISetIncoming::revert() {
+  if (auto *V = OrigValueOrBB.dyn_cast<Value *>())
+    PHI.setIncomingValue(Idx, V);
+  else
+    PHI.setIncomingBlock(Idx, OrigValueOrBB.get<BasicBlock *>());
+}
+
+#ifndef NDEBUG
+void PHISetIncoming::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
+PHIRemoveIncoming::PHIRemoveIncoming(PHINode &PHI, unsigned RemovedIdx,
+                                     Tracker &Tracker)
+    : IRChangeBase(Tracker), PHI(PHI), RemovedIdx(RemovedIdx) {
+  RemovedV = PHI.getIncomingValue(RemovedIdx);
+  RemovedBB = PHI.getIncomingBlock(RemovedIdx);
+}
+
+void PHIRemoveIncoming::revert() {
+  // Special case: if the PHI is now empty we don't need to care about the
+  // order of the incoming values, so simply append the removed pair.
+  unsigned NumIncoming = PHI.getNumIncomingValues();
+  if (NumIncoming == 0) {
+    PHI.addIncoming(RemovedV, RemovedBB);
+    return;
+  }
+  // Shift all incoming values by one, from the end down to `RemovedIdx`.
+  // Start by appending a copy of the last incoming value/block pair.
+  unsigned LastIdx = NumIncoming - 1;
+  PHI.addIncoming(PHI.getIncomingValue(LastIdx), PHI.getIncomingBlock(LastIdx));
+  for (unsigned Idx = LastIdx; Idx > RemovedIdx; --Idx) {
+    auto *PrevV = PHI.getIncomingValue(Idx - 1);
+    auto *PrevBB = PHI.getIncomingBlock(Idx - 1);
+    PHI.setIncomingValue(Idx, PrevV);
+    PHI.setIncomingBlock(Idx, PrevBB);
+  }
+  PHI.setIncomingValue(RemovedIdx, RemovedV);
+  PHI.setIncomingBlock(RemovedIdx, RemovedBB);
+}
+
+#ifndef NDEBUG
+void PHIRemoveIncoming::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
+PHIAddIncoming::PHIAddIncoming(PHINode &PHI, Tracker &Tracker)
+    : IRChangeBase(Tracker), PHI(PHI), Idx(PHI.getNumIncomingValues()) {}
+
+void PHIAddIncoming::revert() { PHI.removeIncomingValue(Idx); }
+
+#ifndef NDEBUG
+void PHIAddIncoming::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
 Tracker::~Tracker() {
   assert(Changes.empty() && "You must accept or revert changes!");
 }
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 94065c008dd8c..c2a99bf66f703 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1453,3 +1453,123 @@ define void @foo(ptr %ptr, <2 x ptr> %ptrs) {
   EXPECT_EQ(NewGEP2->getPrevNode(), Ret);
   EXPECT_EQ(NewGEP2->getNextNode(), nullptr);
 }
+
+TEST_F(SandboxIRTest, PHINode) {
+  parseIR(C, R"IR(
+define void @foo(i32 %arg) {
+bb1:
+  br label %bb2
+
+bb2:
+  %phi = phi i32 [ %arg, %bb1 ], [ 0, %bb2 ]
+  br label %bb2
+
+bb3:
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  auto *LLVMBB1 = getBasicBlockByName(LLVMF, "bb1");
+  auto *LLVMBB2 = getBasicBlockByName(LLVMF, "bb2");
+  auto *LLVMBB3 = getBasicBlockByName(LLVMF, "bb3");
+  auto LLVMIt = LLVMBB2->begin();
+  auto *LLVMPHI = cast<llvm::PHINode>(&*LLVMIt++);
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(&LLVMF);
+  auto *Arg = F->getArg(0);
+  auto *BB1 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB1));
+  auto *BB2 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB2));
+  auto *BB3 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB3));
+  auto It = BB2->begin();
+  // Check classof().
+  auto *PHI = cast<sandboxir::PHINode>(&*It++);
+  auto *Br = cast<sandboxir::BranchInst>(&*It++);
+  // Check blocks().
+  EXPECT_EQ(range_size(PHI->blocks()), range_size(LLVMPHI->blocks()));
+  auto BlockIt = PHI->block_begin();
+  for (llvm::BasicBlock *LLVMBB : LLVMPHI->blocks()) {
+    sandboxir::BasicBlock *BB = *BlockIt++;
+    EXPECT_EQ(BB, Ctx.getValue(LLVMBB));
+  }
+  // Check incoming_values().
+  EXPECT_EQ(range_size(PHI->incoming_values()),
+            range_size(LLVMPHI->incoming_values()));
+  auto IncIt = PHI->incoming_values().begin();
+  for (llvm::Value *LLVMV : LLVMPHI->incoming_values()) {
+    sandboxir::Value *IncV = *IncIt++;
+    EXPECT_EQ(IncV, Ctx.getValue(LLVMV));
+  }
+  // Check getNumIncomingValues().
+  EXPECT_EQ(PHI->getNumIncomingValues(), LLVMPHI->getNumIncomingValues());
+  // Check getIncomingValue().
+  EXPECT_EQ(PHI->getIncomingValue(0),
+            Ctx.getValue(LLVMPHI->getIncomingValue(0)));
+  EXPECT_EQ(PHI->getIncomingValue(1),
+            Ctx.getValue(LLVMPHI->getIncomingValue(1)));
+  // Check setIncomingValue().
+  auto *OrigV = PHI->getIncomingValue(0);
+  PHI->setIncomingValue(0, PHI);
+  EXPECT_EQ(PHI->getIncomingValue(0), PHI);
+  PHI->setIncomingValue(0, OrigV);
+  // Check getOperandNumForIncomingValue().
+  EXPECT_EQ(sandboxir::PHINode::getOperandNumForIncomingValue(0),
+            llvm::PHINode::getOperandNumForIncomingValue(0));
+  // Check getIncomingValueNumForOperand().
+  EXPECT_EQ(sandboxir::PHINode::getIncomingValueNumForOperand(0),
+            llvm::PHINode::getIncomingValueNumForOperand(0));
+  // Check getIncomingBlock(unsigned).
+  EXPECT_EQ(PHI->getIncomingBlock(0),
+            Ctx.getValue(LLVMPHI->getIncomingBlock(0)));
+  // Check getIncomingBlock(Use).
+  llvm::Use &LLVMUse = LLVMPHI->getOperandUse(0);
+  sandboxir::Use Use = PHI->getOperandUse(0);
+  EXPECT_EQ(PHI->getIncomingBlock(Use),
+            Ctx.getValue(LLVMPHI->getIncomingBlock(LLVMUse)));
+  // Check setIncomingBlock().
+  sandboxir::BasicBlock *OrigBB = PHI->getIncomingBlock(0);
+  EXPECT_NE(OrigBB, BB2);
+  PHI->setIncomingBlock(0, BB2);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB2);
+  PHI->setIncomingBlock(0, OrigBB);
+  EXPECT_EQ(PHI->getIncomingBlock(0), OrigBB);
+  // Check addIncoming().
+  unsigned OrigNumIncoming = PHI->getNumIncomingValues();
+  PHI->addIncoming(Arg, BB3);
+  EXPECT_EQ(PHI->getNumIncomingValues(), LLVMPHI->getNumIncomingValues());
+  EXPECT_EQ(PHI->getNumIncomingValues(), OrigNumIncoming + 1);
+  EXPECT_EQ(PHI->getIncomingValue(OrigNumIncoming), Arg);
+  EXPECT_EQ(PHI->getIncomingBlock(OrigNumIncoming), BB3);
+  // Check removeIncomingValue(unsigned).
+  PHI->removeIncomingValue(OrigNumIncoming);
+  EXPECT_EQ(PHI->getNumIncomingValues(), OrigNumIncoming);
+  // Check removeIncomingValue(BasicBlock *).
+  PHI->addIncoming(Arg, BB3);
+  PHI->removeIncomingValue(BB3);
+  EXPECT_EQ(PHI->getNumIncomingValues(), OrigNumIncoming);
+  // Check getBasicBlockIndex().
+  EXPECT_EQ(PHI->getBasicBlockIndex(BB1), LLVMPHI->getBasicBlockIndex(LLVMBB1));
+  // Check getIncomingValueForBlock().
+  EXPECT_EQ(PHI->getIncomingValueForBlock(BB1),
+            Ctx.getValue(LLVMPHI->getIncomingValueForBlock(LLVMBB1)));
+  // Check hasConstantValue().
+  llvm::Value *ConstV = LLVMPHI->hasConstantValue();
+  EXPECT_EQ(PHI->hasConstantValue(),
+            ConstV != nullptr ? Ctx.getValue(ConstV) : nullptr);
+  // Check hasConstantOrUndefValue().
+  EXPECT_EQ(PHI->hasConstantOrUndefValue(), LLVMPHI->hasConstantOrUndefValue());
+  // Check isComplete().
+  EXPECT_EQ(PHI->isComplete(), LLVMPHI->isComplete());
+
+  // Check create().
+  auto *NewPHI = cast<sandboxir::PHINode>(
+      sandboxir::PHINode::create(PHI->getType(), 0, Br, Ctx, "NewPHI"));
+  EXPECT_EQ(NewPHI->getType(), PHI->getType());
+  EXPECT_EQ(NewPHI->getNextNode(), Br);
+  EXPECT_EQ(NewPHI->getName(), "NewPHI");
+  EXPECT_EQ(NewPHI->getNumIncomingValues(), 0u);
+  for (auto [Idx, V] : enumerate(PHI->incoming_values())) {
+    sandboxir::BasicBlock *IncBB = PHI->getIncomingBlock(Idx);
+    NewPHI->addIncoming(V, IncBB);
+  }
+  EXPECT_EQ(NewPHI->getNumIncomingValues(), PHI->getNumIncomingValues());
+}
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index cd737d33dd193..d016c7793a52c 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -584,3 +584,127 @@ define void @foo(i8 %arg) {
   Ctx.revert();
   EXPECT_EQ(CallBr->getIndirectDest(0), OrigIndirectDest);
 }
+
+TEST_F(TrackerTest, PHINodeSetters) {
+  parseIR(C, R"IR(
+define void @foo(i8 %arg0, i8 %arg1, i8 %arg2) {
+bb0:
+  br label %bb2
+
+bb1:
+  %phi = phi i8 [ %arg0, %bb0 ], [ %arg1, %bb1 ]
+  br label %bb1
+
+bb2:
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  unsigned ArgIdx = 0;
+  auto *Arg0 = F.getArg(ArgIdx++);
+  auto *Arg1 = F.getArg(ArgIdx++);
+  auto *Arg2 = F.getArg(ArgIdx++);
+  auto *BB0 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb0")));
+  auto *BB1 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb1")));
+  auto *BB2 = cast<sandboxir::BasicBlock>(
+      Ctx.getValue(getBasicBlockByName(LLVMF, "bb2")));
+  auto *PHI = cast<sandboxir::PHINode>(&*BB1->begin());
+
+  // Check setIncomingValue().
+  Ctx.save();
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg0);
+  PHI->setIncomingValue(0, Arg2);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg2);
+  Ctx.revert();
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg0);
+  EXPECT_EQ(PHI->getNumIncomingValues(), 2u);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB0);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg0);
+  EXPECT_EQ(PHI->getIncomingBlock(1), BB1);
+  EXPECT_EQ(PHI->getIncomingValue(1), Arg1);
+
+  // Check setIncomingBlock().
+  Ctx.save();
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB0);
+  PHI->setIncomingBlock(0, BB2);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB2);
+  Ctx.revert();
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB0);
+  EXPECT_EQ(PHI->getNumIncomingValues(), 2u);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB0);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg0);
+  EXPECT_EQ(PHI->getIncomingBlock(1), BB1);
+  EXPECT_EQ(PHI->getIncomingValue(1), Arg1);
+
+  // Check addIncoming().
+  Ctx.save();
+  EXPECT_EQ(PHI->getNumIncomingValues(), 2u);
+  PHI->addIncoming(Arg1, BB2);
+  EXPECT_EQ(PHI->getNumIncomingValues(), 3u);
+  EXPECT_EQ(PHI->getIncomingBlock(2), BB2);
+  EXPECT_EQ(PHI->getIncomingValue(2), Arg1);
+  Ctx.revert();
+  EXPECT_EQ(PHI->getNumIncomingValues(), 2u);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB0);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg0);
+  EXPECT_EQ(PHI->getIncomingBlock(1), BB1);
+  EXPECT_EQ(PHI->getIncomingValue(1), Arg1);
+
+  // Check removeIncomingValue(1).
+  Ctx.save();
+  PHI->removeIncomingValue(1);
+  EXPECT_EQ(PHI->getNumIncomingValues(), 1u);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB0);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg0);
+  Ctx.revert();
+  EXPECT_EQ(PHI->getNumIncomingValues(), 2u);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB0);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg0);
+  EXPECT_EQ(PHI->getIncomingBlock(1), BB1);
+  EXPECT_EQ(PHI->getIncomingValue(1), Arg1);
+
+  // Check removeIncomingValue(0).
+  Ctx.save();
+  PHI->removeIncomingValue(0u);
+  EXPECT_EQ(PHI->getNumIncomingValues(), 1u);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB1);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg1);
+  Ctx.revert();
+  EXPECT_EQ(PHI->getNumIncomingValues(), 2u);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB0);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg0);
+  EXPECT_EQ(PHI->getIncomingBlock(1), BB1);
+  EXPECT_EQ(PHI->getIncomingValue(1), Arg1);
+
+  // Check removeIncomingValue() remove all.
+  Ctx.save();
+  PHI->removeIncomingValue(0u);
+  EXPECT_EQ(PHI->getNumIncomingValues(), 1u);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB1);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg1);
+  PHI->removeIncomingValue(0u);
+  EXPECT_EQ(PHI->getNumIncomingValues(), 0u);
+  Ctx.revert();
+  EXPECT_EQ(PHI->getNumIncomingValues(), 2u);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB0);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg0);
+  EXPECT_EQ(PHI->getIncomingBlock(1), BB1);
+  EXPECT_EQ(PHI->getIncomingValue(1), Arg1);
+
+  // Check removeIncomingValue(BasicBlock *).
+  Ctx.save();
+  PHI->removeIncomingValue(BB1);
+  EXPECT_EQ(PHI->getNumIncomingValues(), 1u);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB0);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg0);
+  Ctx.revert();
+  EXPECT_EQ(PHI->getNumIncomingValues(), 2u);
+  EXPECT_EQ(PHI->getIncomingBlock(0), BB0);
+  EXPECT_EQ(PHI->getIncomingValue(0), Arg0);
+  EXPECT_EQ(PHI->getIncomingBlock(1), BB1);
+  EXPECT_EQ(PHI->getIncomingValue(1), Arg1);
+}

>From 588ba5803842eb7c7be61f6378109cfc18e4fd6f Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Wed, 31 Jul 2024 18:11:03 +0000
Subject: [PATCH 2/3] Cleanup mismerge

---
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index bd1b02abf211a..32b23a1284eda 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -1623,7 +1623,6 @@ define void @foo(ptr %ptr) {
 }
 )IR");
   Function &LLVMF = *M->getFunction("foo");
-<
   sandboxir::Context Ctx(C);
   sandboxir::Function *F = Ctx.createFunction(&LLVMF);
   unsigned ArgIdx = 0;
@@ -1765,6 +1764,10 @@ define void @foo(i32 %arg) {
   br label %bb2
 
 bb3:
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
   auto *LLVMBB1 = getBasicBlockByName(LLVMF, "bb1");
   auto *LLVMBB2 = getBasicBlockByName(LLVMF, "bb2");
   auto *LLVMBB3 = getBasicBlockByName(LLVMF, "bb3");
@@ -1868,4 +1871,4 @@ define void @foo(i32 %arg) {
     NewPHI->addIncoming(V, IncBB);
   }
   EXPECT_EQ(NewPHI->getNumIncomingValues(), PHI->getNumIncomingValues());
-}
\ No newline at end of file
+}

>From b8f3023a763f37ee50fc9245e77f1718112df91a Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Wed, 31 Jul 2024 18:18:52 +0000
Subject: [PATCH 3/3] Fix formatting

---
 llvm/include/llvm/SandboxIR/SandboxIR.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index dd00b1ba83bcb..5c908a0e3b977 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -1371,7 +1371,7 @@ class CastInst : public Instruction {
   LLVM_DUMP_METHOD void dump() const override;
 #endif
 };
-  
+
 class PHINode final : public Instruction {
   /// Use Context::createPHINode(). Don't call the constructor directly.
   PHINode(llvm::PHINode *PHI, Context &Ctx)



More information about the llvm-commits mailing list