[llvm] [SandboxIR][Tracker] Track eraseFromParent() (PR #99431)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 17 23:47:17 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/99431

>From a3606d8f67b5fb61b1376a97d668e86013a718e7 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 16 Jul 2024 16:44:03 -0700
Subject: [PATCH] [SandboxIR][Tracker] Track eraseFromParent()

This patch adds tracking support for Instruction::eraseFromParent().
While tracking, the Instruction is not actually erased; instead it is
detached from the instruction list and its Use edges are dropped.
The original instruction position and Use edges are saved in the
`EraseFromParent` change object and are used during `revert()`
to restore the original state.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h  |  4 +-
 llvm/include/llvm/SandboxIR/Tracker.h    | 45 ++++++++++++++++++-
 llvm/lib/SandboxIR/SandboxIR.cpp         | 22 ++++++++--
 llvm/lib/SandboxIR/Tracker.cpp           | 56 ++++++++++++++++++++++++
 llvm/unittests/SandboxIR/TrackerTest.cpp | 52 ++++++++++++++++++++++
 5 files changed, 174 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index c5d59ba47ca31..a9f0177eb9338 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -493,6 +493,7 @@ class Instruction : public sandboxir::User {
   /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
   /// order.
   virtual SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const = 0;
+  friend class EraseFromParent; // For getLLVMInstrs().
 
 public:
   static const char *getOpcodeName(Opcode Opc);
@@ -658,6 +659,7 @@ class Context {
   friend void Instruction::eraseFromParent(); // For detach().
   /// Take ownership of VPtr and store it in `LLVMValueToValueMap`.
   Value *registerValue(std::unique_ptr<Value> &&VPtr);
+  friend class EraseFromParent; // For registerValue().
   /// This is the actual function that creates sandboxir values for \p V,
   /// and among others handles all instruction types.
   Value *getOrCreateValueInternal(llvm::Value *V, llvm::User *U = nullptr);
@@ -682,7 +684,7 @@ class Context {
   friend class BasicBlock; // For getOrCreateValue().
 
 public:
-  Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx) {}
+  Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx), IRTracker(*this) {}
 
   Tracker &getTracker() { return IRTracker; }
   /// Convenience function for `getTracker().save()`
diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h
index 2d0904f5665b1..68b0f38ac99ab 100644
--- a/llvm/include/llvm/SandboxIR/Tracker.h
+++ b/llvm/include/llvm/SandboxIR/Tracker.h
@@ -40,6 +40,7 @@
 #ifndef LLVM_SANDBOXIR_TRACKER_H
 #define LLVM_SANDBOXIR_TRACKER_H
 
+#include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
@@ -99,6 +100,46 @@ class UseSet : public IRChangeBase {
 #endif
 };
 
+class EraseFromParent : public IRChangeBase {
+  /// Contains all the data we need to restore an "erased" (i.e., detached)
+  /// instruction: the instruction itself and the operands data.
+  struct InstrData {
+    /// The operand and the corresponding operand number.
+    struct OpData {
+      llvm::Value *Op;
+      unsigned OpNum;
+    };
+    /// The operands that got dropped.
+    SmallVector<OpData> OpDataVec;
+    /// The instruction that got "erased".
+    llvm::Instruction *LLVMI;
+  };
+  /// The instruction data is in reverse program order, which helps create the
+  /// original program order during revert().
+  SmallVector<InstrData> InstrData;
+  /// This is either the next Instruction in the stream, or the parent
+  /// BasicBlock if at the end of the BB.
+  PointerUnion<llvm::Instruction *, llvm::BasicBlock *> NextLLVMIOrBB;
+  /// We take ownership of the "erased" instruction.
+  std::unique_ptr<sandboxir::Value> ErasedIPtr;
+
+public:
+  EraseFromParent(std::unique_ptr<sandboxir::Value> &&IPtr, Tracker &Tracker);
+  void revert() final;
+  void accept() final;
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final {
+    dumpCommon(OS);
+    OS << "EraseFromParent";
+  }
+  LLVM_DUMP_METHOD void dump() const final;
+  friend raw_ostream &operator<<(raw_ostream &OS, const EraseFromParent &C) {
+    C.dump(OS);
+    return OS;
+  }
+#endif
+};
+
 /// The tracker collects all the change objects and implements the main API for
 /// saving / reverting / accepting.
 class Tracker {
@@ -116,6 +157,7 @@ class Tracker {
 #endif
   /// The current state of the tracker.
   TrackerState State = TrackerState::Disabled;
+  Context &Ctx;
 
 public:
 #ifndef NDEBUG
@@ -124,8 +166,9 @@ class Tracker {
   bool InMiddleOfCreatingChange = false;
 #endif // NDEBUG
 
-  Tracker() = default;
+  explicit Tracker(Context &Ctx) : Ctx(Ctx) {}
   ~Tracker();
+  Context &getContext() const { return Ctx; }
   /// Record \p Change and take ownership. This is the main function used to
   /// track Sandbox IR changes.
   void track(std::unique_ptr<IRChangeBase> &&Change);
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 944869a37989c..2f80428c231ae 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -341,10 +341,26 @@ void Instruction::removeFromParent() {
 void Instruction::eraseFromParent() {
   assert(users().empty() && "Still connected to users, can't erase!");
   std::unique_ptr<Value> Detached = Ctx.detach(this);
-  // We don't have Tracking yet, so just erase the LLVM IR instructions.
+  auto LLVMInstrs = getLLVMInstrs();
+
+  auto &Tracker = Ctx.getTracker();
+  if (Tracker.isTracking()) {
+    Tracker.track(
+        std::make_unique<EraseFromParent>(std::move(Detached), Tracker));
+    // We don't actually delete the IR instruction, because then it would be
+    // impossible to bring it back from the dead at the same memory location.
+    // Instead we remove it from its BB and track its current location.
+    for (llvm::Instruction *I : LLVMInstrs)
+      I->removeFromParent();
+    // TODO: Multi-instructions need special treatment because some of the
+    // references are internal to the instruction.
+    for (llvm::Instruction *I : LLVMInstrs)
+      I->dropAllReferences();
+    return;
+  }
+
   // Erase in reverse to avoid erasing nstructions with attached uses.
-  auto Instrs = getLLVMInstrs();
-  for (llvm::Instruction *I : reverse(Instrs))
+  for (llvm::Instruction *I : reverse(LLVMInstrs))
     I->eraseFromParent();
 }
 
diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp
index 1182f5c55d10b..8f1dc57c9db01 100644
--- a/llvm/lib/SandboxIR/Tracker.cpp
+++ b/llvm/lib/SandboxIR/Tracker.cpp
@@ -41,6 +41,62 @@ Tracker::~Tracker() {
   assert(Changes.empty() && "You must accept or revert changes!");
 }
 
+EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr,
+                                 Tracker &Tracker)
+    : IRChangeBase(Tracker), ErasedIPtr(std::move(ErasedIPtr)) {
+  auto *I = cast<Instruction>(this->ErasedIPtr.get());
+  auto LLVMInstrs = I->getLLVMInstrs();
+  // Iterate in reverse program order.
+  for (auto *LLVMI : reverse(LLVMInstrs)) {
+    SmallVector<InstrData::OpData> OpDataVec;
+    for (auto [OpNum, Use] : enumerate(LLVMI->operands()))
+      OpDataVec.push_back({Use.get(), static_cast<unsigned>(OpNum)});
+    InstrData.push_back({OpDataVec, LLVMI});
+  }
+#ifndef NDEBUG
+  for (auto Idx : seq<unsigned>(1, InstrData.size()))
+    assert(InstrData[Idx].LLVMI->comesBefore(InstrData[Idx - 1].LLVMI) &&
+           "Expected reverse program order!");
+#endif
+  auto *BotLLVMI = cast<llvm::Instruction>(I->Val);
+  if (BotLLVMI->getNextNode() != nullptr)
+    NextLLVMIOrBB = BotLLVMI->getNextNode();
+  else
+    NextLLVMIOrBB = BotLLVMI->getParent();
+}
+
+void EraseFromParent::accept() {
+  for (const auto &IData : InstrData)
+    IData.LLVMI->deleteValue();
+}
+
+void EraseFromParent::revert() {
+  auto [OpData, BotLLVMI] = InstrData[0];
+  if (auto *NextLLVMI = NextLLVMIOrBB.dyn_cast<llvm::Instruction *>()) {
+    BotLLVMI->insertBefore(NextLLVMI);
+  } else {
+    auto *LLVMBB = NextLLVMIOrBB.get<llvm::BasicBlock *>();
+    BotLLVMI->insertInto(LLVMBB, LLVMBB->end());
+  }
+  for (auto [Op, OpNum] : OpData)
+    BotLLVMI->setOperand(OpNum, Op);
+
+  for (auto [OpData, LLVMI] : drop_begin(InstrData)) {
+    LLVMI->insertBefore(BotLLVMI);
+    for (auto [Op, OpNum] : OpData)
+      LLVMI->setOperand(OpNum, Op);
+    BotLLVMI = LLVMI;
+  }
+  Parent.getContext().registerValue(std::move(ErasedIPtr));
+}
+
+#ifndef NDEBUG
+void EraseFromParent::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif
+
 void Tracker::track(std::unique_ptr<IRChangeBase> &&Change) {
   assert(State == TrackerState::Record && "The tracker should be tracking!");
   Changes.push_back(std::move(Change));
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index f090dc521c32b..ccbdf0b7b71e4 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -146,3 +146,55 @@ define void @foo(ptr %ptr) {
   Ctx.accept();
   EXPECT_EQ(St0->getOperand(0), Ld1);
 }
+
+TEST_F(TrackerTest, EraseFromParent) {
+  parseIR(C, R"IR(
+define void @foo(i32 %v1) {
+  %add0 = add i32 %v1, %v1
+  %add1 = add i32 %add0, %v1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto *F = Ctx.createFunction(&LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  sandboxir::Instruction *Add0 = &*It++;
+  sandboxir::Instruction *Add1 = &*It++;
+  sandboxir::Instruction *Ret = &*It++;
+
+  Ctx.save();
+  auto &Tracker = Ctx.getTracker();
+  // Check erase.
+  Add1->eraseFromParent();
+  It = BB->begin();
+  EXPECT_EQ(&*It++, Add0);
+  EXPECT_EQ(&*It++, Ret);
+  EXPECT_EQ(It, BB->end());
+  EXPECT_EQ(Add0->getNumUses(), 0u);
+
+  // Check revert().
+  Ctx.revert();
+  It = BB->begin();
+  EXPECT_EQ(&*It++, Add0);
+  EXPECT_EQ(&*It++, Add1);
+  EXPECT_EQ(&*It++, Ret);
+  EXPECT_EQ(It, BB->end());
+  EXPECT_EQ(Add1->getOperand(0), Add0);
+
+  // Same for the last instruction in the block.
+  Ctx.save();
+  Ret->eraseFromParent();
+  It = BB->begin();
+  EXPECT_EQ(&*It++, Add0);
+  EXPECT_EQ(&*It++, Add1);
+  EXPECT_EQ(It, BB->end());
+  Ctx.revert();
+  It = BB->begin();
+  EXPECT_EQ(&*It++, Add0);
+  EXPECT_EQ(&*It++, Add1);
+  EXPECT_EQ(&*It++, Ret);
+  EXPECT_EQ(It, BB->end());
+}



More information about the llvm-commits mailing list