[llvm] [SandboxIR] Add callbacks for instruction insert/remove/move ops (PR #112965)

Jorge Gorbe Moya via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 18:17:58 PDT 2024


https://github.com/slackito updated https://github.com/llvm/llvm-project/pull/112965

>From 56ffef4a97b888e1c53153f71301befdc3cfd24d Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 18 Oct 2024 12:00:06 -0700
Subject: [PATCH 1/8] [SandboxIR] Add callbacks for instruction
 insert/remove/move ops.

---
 llvm/include/llvm/SandboxIR/Context.h      | 56 +++++++++++++-
 llvm/lib/SandboxIR/Context.cpp             | 66 +++++++++++++++--
 llvm/lib/SandboxIR/Instruction.cpp         |  5 ++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 85 ++++++++++++++++++++++
 4 files changed, 205 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index 1285598a1c0282..836988639a14bb 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -9,18 +9,31 @@
 #ifndef LLVM_SANDBOXIR_CONTEXT_H
 #define LLVM_SANDBOXIR_CONTEXT_H
 
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/SandboxIR/Tracker.h"
 #include "llvm/SandboxIR/Type.h"
 
 namespace llvm::sandboxir {
 
-class Module;
-class Value;
 class Argument;
+class BBIterator;
 class Constant;
+class Module;
+class Value;
 
 class Context {
+public:
+  // A RemoveInstrCallback receives the instruction about to be removed.
+  using RemoveInstrCallback = std::function<void(Instruction *)>;
+  // A InsertInstrCallback receives the instruction about to be created.
+  using InsertInstrCallback = std::function<void(Instruction *)>;
+  // A MoveInstrCallback receives the instruction about to be moved, the
+  // destination BB and an iterator pointing to the insertion position.
+  using MoveInstrCallback =
+      std::function<void(Instruction *, const BBIterator &)>;
+
 protected:
   LLVMContext &LLVMCtx;
   friend class Type;        // For LLVMCtx.
@@ -48,6 +61,21 @@ class Context {
   /// Type objects.
   DenseMap<llvm::Type *, std::unique_ptr<Type, TypeDeleter>> LLVMTypeToTypeMap;
 
+  /// Callbacks called when an IR instruction is about to get removed. Keys are
+  /// used as IDs for deregistration.
+  DenseMap<int, RemoveInstrCallback> RemoveInstrCallbacks;
+  /// Callbacks called when an IR instruction is about to get inserted. Keys are
+  /// used as IDs for deregistration.
+  DenseMap<int, InsertInstrCallback> InsertInstrCallbacks;
+  /// Callbacks called when an IR instruction is about to get moved. Keys are
+  /// used as IDs for deregistration.
+  DenseMap<int, MoveInstrCallback> MoveInstrCallbacks;
+
+  /// A counter used for assigning callback IDs during registration. The same
+  /// counter is used for all kinds of callbacks so we can detect mismatched
+  /// registration/deregistration.
+  static int NextCallbackId;
+
   /// Remove \p V from the maps and returns the unique_ptr.
   std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
   /// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively
@@ -70,6 +98,10 @@ class Context {
   Constant *getOrCreateConstant(llvm::Constant *LLVMC);
   friend class Utils; // For getMemoryBase
 
+  void runRemoveInstrCallbacks(Instruction *I);
+  void runInsertInstrCallbacks(Instruction *I);
+  void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where);
+
   // Friends for getOrCreateConstant().
 #define DEF_CONST(ID, CLASS) friend class CLASS;
 #include "llvm/SandboxIR/Values.def"
@@ -198,6 +230,26 @@ class Context {
 
   /// \Returns the number of values registered with Context.
   size_t getNumValues() const { return LLVMValueToValueMap.size(); }
+
+  /// Register a callback that gets called when a SandboxIR instruction is about
+  /// to be removed from its parent. Note that this will also be called when
+  /// reverting the creation of an instruction.
+  /// \Returns a callback ID for later deregistration.
+  int registerRemoveInstrCallback(RemoveInstrCallback CB);
+  void unregisterRemoveInstrCallback(int CallbackId);
+
+  /// Register a callback that gets called right after a SandboxIR instruction
+  /// is created. Note that this will also be called when reverting the removal
+  /// of an instruction.
+  /// \Returns a callback ID for later deregistration.
+  int registerInsertInstrCallback(InsertInstrCallback CB);
+  void unregisterInsertInstrCallback(int CallbackId);
+
+  /// Register a callback that gets called when a SandboxIR instruction is about
+  /// to be moved. Note that this will also be called when reverting a move.
+  /// \Returns a callback ID for later deregistration.
+  int registerMoveInstrCallback(MoveInstrCallback CB);
+  void unregisterMoveInstrCallback(int CallbackId);
 };
 
 } // namespace llvm::sandboxir
diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index 486e935bc35fba..e13f833f1ba29d 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -35,17 +35,20 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
   assert(VPtr->getSubclassID() != Value::ClassID::User &&
          "Can't register a user!");
 
+  Value *V = VPtr.get();
+  [[maybe_unused]] auto Pair =
+      LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
+  assert(Pair.second && "Already exists!");
+
   // Track creation of instructions.
   // Please note that we don't allow the creation of detached instructions,
   // meaning that the instructions need to be inserted into a block upon
   // creation. This is why the tracker class combines creation and insertion.
-  if (auto *I = dyn_cast<Instruction>(VPtr.get()))
+  if (auto *I = dyn_cast<Instruction>(V)) {
     getTracker().emplaceIfTracking<CreateAndInsertInst>(I);
+    runInsertInstrCallbacks(I);
+  }
 
-  Value *V = VPtr.get();
-  [[maybe_unused]] auto Pair =
-      LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
-  assert(Pair.second && "Already exists!");
   return V;
 }
 
@@ -660,4 +663,57 @@ Module *Context::createModule(llvm::Module *LLVMM) {
   return M;
 }
 
+void Context::runRemoveInstrCallbacks(Instruction *I) {
+  for (const auto &CBEntry : RemoveInstrCallbacks) {
+    CBEntry.second(I);
+  }
+}
+
+void Context::runInsertInstrCallbacks(Instruction *I) {
+  for (auto &CBEntry : InsertInstrCallbacks) {
+    CBEntry.second(I);
+  }
+}
+
+void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
+  for (auto &CBEntry : MoveInstrCallbacks) {
+    CBEntry.second(I, WhereIt);
+  }
+}
+
+int Context::NextCallbackId = 0;
+
+int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) {
+  int Id = NextCallbackId++;
+  RemoveInstrCallbacks[Id] = CB;
+  return Id;
+}
+void Context::unregisterRemoveInstrCallback(int CallbackId) {
+  [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(CallbackId);
+  assert(erased &&
+         "Callback id not found in RemoveInstrCallbacks during deregistration");
+}
+
+int Context::registerInsertInstrCallback(InsertInstrCallback CB) {
+  int Id = NextCallbackId++;
+  InsertInstrCallbacks[Id] = CB;
+  return Id;
+}
+void Context::unregisterInsertInstrCallback(int CallbackId) {
+  [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(CallbackId);
+  assert(erased &&
+         "Callback id not found in InsertInstrCallbacks during deregistration");
+}
+
+int Context::registerMoveInstrCallback(MoveInstrCallback CB) {
+  int Id = NextCallbackId++;
+  MoveInstrCallbacks[Id] = CB;
+  return Id;
+}
+void Context::unregisterMoveInstrCallback(int CallbackId) {
+  [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(CallbackId);
+  assert(erased &&
+         "Callback id not found in MoveInstrCallbacks during deregistration");
+}
+
 } // namespace llvm::sandboxir
diff --git a/llvm/lib/SandboxIR/Instruction.cpp b/llvm/lib/SandboxIR/Instruction.cpp
index d80d10370e32d8..ddeb78eea19f73 100644
--- a/llvm/lib/SandboxIR/Instruction.cpp
+++ b/llvm/lib/SandboxIR/Instruction.cpp
@@ -64,6 +64,8 @@ Instruction *Instruction::getPrevNode() const {
 }
 
 void Instruction::removeFromParent() {
+  Ctx.runRemoveInstrCallbacks(this);
+
   Ctx.getTracker().emplaceIfTracking<RemoveFromParent>(this);
 
   // Detach all the LLVM IR instructions from their parent BB.
@@ -73,6 +75,8 @@ void Instruction::removeFromParent() {
 
 void Instruction::eraseFromParent() {
   assert(users().empty() && "Still connected to users, can't erase!");
+
+  Ctx.runRemoveInstrCallbacks(this);
   std::unique_ptr<Value> Detached = Ctx.detach(this);
   auto LLVMInstrs = getLLVMInstrs();
 
@@ -100,6 +104,7 @@ void Instruction::moveBefore(BasicBlock &BB, const BBIterator &WhereIt) {
     // Destination is same as origin, nothing to do.
     return;
 
+  Ctx.runMoveInstrCallbacks(this, WhereIt);
   Ctx.getTracker().emplaceIfTracking<MoveInstr>(this);
 
   auto *LLVMBB = cast<llvm::BasicBlock>(BB.Val);
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 97113b303f72e5..786580d1046a60 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -22,6 +22,7 @@
 #include "llvm/SandboxIR/Value.h"
 #include "llvm/Support/SourceMgr.h"
 #include "gmock/gmock-matchers.h"
+#include "gmock/gmock-more-matchers.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -5962,3 +5963,87 @@ TEST_F(SandboxIRTest, CheckClassof) {
   EXPECT_NE(&sandboxir::CLASS::classof, &sandboxir::Instruction::classof);
 #include "llvm/SandboxIR/Values.def"
 }
+
+TEST_F(SandboxIRTest, InstructionCallbacks) {
+  parseIR(C, R"IR(
+    define void @foo(ptr %ptr, i8 %val) {
+      ret void
+    }
+  )IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  sandboxir::Argument *Ptr = F.getArg(0);
+  sandboxir::Argument *Val = F.getArg(1);
+  sandboxir::Instruction *Ret = &BB.front();
+
+  SmallVector<sandboxir::Instruction *> Inserted;
+  int InsertCbId = Ctx.registerInsertInstrCallback(
+      [&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); });
+
+  SmallVector<sandboxir::Instruction *> Removed;
+  int RemoveCbId = Ctx.registerRemoveInstrCallback(
+      [&Removed](sandboxir::Instruction *I) { Removed.push_back(I); });
+
+  // Keep the moved instruction and the instruction pointed by the Where
+  // iterator so we can check both callback arguments work as expected.
+  SmallVector<std::pair<sandboxir::Instruction *, sandboxir::Instruction *>>
+      Moved;
+  int MoveCbId = Ctx.registerMoveInstrCallback(
+      [&Moved](sandboxir::Instruction *I, const sandboxir::BBIterator &Where) {
+        // Use a nullptr to signal "move to end" to keep it single. We only
+        // have a basic block in this test case anyway.
+        if (Where == Where.getNodeParent()->end())
+          Moved.push_back(std::make_pair(I, nullptr));
+        else
+          Moved.push_back(std::make_pair(I, &*Where));
+      });
+
+  Ctx.save();
+  auto *NewI = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
+                                            Ret->getIterator(), Ctx);
+  EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
+  EXPECT_THAT(Removed, testing::IsEmpty());
+  EXPECT_THAT(Moved, testing::IsEmpty());
+
+  Ret->moveBefore(NewI);
+  EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
+  EXPECT_THAT(Removed, testing::IsEmpty());
+  EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
+
+  Ret->eraseFromParent();
+  EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
+  EXPECT_THAT(Removed, testing::ElementsAre(Ret));
+  EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
+
+  NewI->eraseFromParent();
+  EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
+  EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI));
+  EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
+
+  // Check that after revert the callbacks have been called for the inverse
+  // operations of the changes made so far.
+  Ctx.revert();
+  EXPECT_THAT(Inserted, testing::ElementsAre(NewI, NewI, Ret));
+  EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI, NewI));
+  EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI),
+                                          std::make_pair(Ret, nullptr)));
+
+  // Check that deregistration works. Do an operation of each type after
+  // deregistering callbacks and check.
+  Inserted.clear();
+  Removed.clear();
+  Moved.clear();
+  Ctx.unregisterInsertInstrCallback(InsertCbId);
+  Ctx.unregisterRemoveInstrCallback(RemoveCbId);
+  Ctx.unregisterMoveInstrCallback(MoveCbId);
+  auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
+                                            Ret->getIterator(), Ctx);
+  Ret->moveBefore(NewI2);
+  Ret->eraseFromParent();
+  EXPECT_THAT(Inserted, testing::IsEmpty());
+  EXPECT_THAT(Removed, testing::IsEmpty());
+  EXPECT_THAT(Moved, testing::IsEmpty());
+}

>From f361d5cc0cde6c4ddfc85bc719025cf84b0818bd Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 18 Oct 2024 13:44:21 -0700
Subject: [PATCH 2/8] Address some review feedback.

- Introduced a `CallbackID` type alias for callback IDs rather than a plain
  int.
- Remove unnecessary braces.
---
 llvm/include/llvm/SandboxIR/Context.h      | 11 +++--
 llvm/lib/SandboxIR/Context.cpp             | 47 ++++++++++------------
 llvm/unittests/SandboxIR/SandboxIRTest.cpp |  8 ++--
 3 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index 836988639a14bb..2da80481f9b6c4 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -34,6 +34,9 @@ class Context {
   using MoveInstrCallback =
       std::function<void(Instruction *, const BBIterator &)>;
 
+  /// An ID for a registered callback. Used for deregistration.
+  using CallbackID = int;
+
 protected:
   LLVMContext &LLVMCtx;
   friend class Type;        // For LLVMCtx.
@@ -63,18 +66,18 @@ class Context {
 
   /// Callbacks called when an IR instruction is about to get removed. Keys are
   /// used as IDs for deregistration.
-  DenseMap<int, RemoveInstrCallback> RemoveInstrCallbacks;
+  DenseMap<CallbackID, RemoveInstrCallback> RemoveInstrCallbacks;
   /// Callbacks called when an IR instruction is about to get inserted. Keys are
   /// used as IDs for deregistration.
-  DenseMap<int, InsertInstrCallback> InsertInstrCallbacks;
+  DenseMap<CallbackID, InsertInstrCallback> InsertInstrCallbacks;
   /// Callbacks called when an IR instruction is about to get moved. Keys are
   /// used as IDs for deregistration.
-  DenseMap<int, MoveInstrCallback> MoveInstrCallbacks;
+  DenseMap<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
 
   /// A counter used for assigning callback IDs during registration. The same
   /// counter is used for all kinds of callbacks so we can detect mismatched
   /// registration/deregistration.
-  static int NextCallbackId;
+  static CallbackID NextCallbackID;
 
   /// Remove \p V from the maps and returns the unique_ptr.
   std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index e13f833f1ba29d..f22945a076aa28 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -664,56 +664,53 @@ Module *Context::createModule(llvm::Module *LLVMM) {
 }
 
 void Context::runRemoveInstrCallbacks(Instruction *I) {
-  for (const auto &CBEntry : RemoveInstrCallbacks) {
+  for (const auto &CBEntry : RemoveInstrCallbacks)
     CBEntry.second(I);
-  }
 }
 
 void Context::runInsertInstrCallbacks(Instruction *I) {
-  for (auto &CBEntry : InsertInstrCallbacks) {
+  for (auto &CBEntry : InsertInstrCallbacks)
     CBEntry.second(I);
-  }
 }
 
 void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
-  for (auto &CBEntry : MoveInstrCallbacks) {
+  for (auto &CBEntry : MoveInstrCallbacks)
     CBEntry.second(I, WhereIt);
-  }
 }
 
-int Context::NextCallbackId = 0;
+Context::CallbackID Context::NextCallbackID = 0;
 
 int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) {
-  int Id = NextCallbackId++;
-  RemoveInstrCallbacks[Id] = CB;
-  return Id;
+  CallbackID ID = NextCallbackID++;
+  RemoveInstrCallbacks[ID] = CB;
+  return ID;
 }
-void Context::unregisterRemoveInstrCallback(int CallbackId) {
-  [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(CallbackId);
+void Context::unregisterRemoveInstrCallback(CallbackID ID) {
+  [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(ID);
   assert(erased &&
-         "Callback id not found in RemoveInstrCallbacks during deregistration");
+         "Callback ID not found in RemoveInstrCallbacks during deregistration");
 }
 
 int Context::registerInsertInstrCallback(InsertInstrCallback CB) {
-  int Id = NextCallbackId++;
-  InsertInstrCallbacks[Id] = CB;
-  return Id;
+  CallbackID ID = NextCallbackID++;
+  InsertInstrCallbacks[ID] = CB;
+  return ID;
 }
-void Context::unregisterInsertInstrCallback(int CallbackId) {
-  [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(CallbackId);
+void Context::unregisterInsertInstrCallback(CallbackID ID) {
+  [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(ID);
   assert(erased &&
-         "Callback id not found in InsertInstrCallbacks during deregistration");
+         "Callback ID not found in InsertInstrCallbacks during deregistration");
 }
 
 int Context::registerMoveInstrCallback(MoveInstrCallback CB) {
-  int Id = NextCallbackId++;
-  MoveInstrCallbacks[Id] = CB;
-  return Id;
+  CallbackID ID = NextCallbackID++;
+  MoveInstrCallbacks[ID] = CB;
+  return ID;
 }
-void Context::unregisterMoveInstrCallback(int CallbackId) {
-  [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(CallbackId);
+void Context::unregisterMoveInstrCallback(CallbackID ID) {
+  [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(ID);
   assert(erased &&
-         "Callback id not found in MoveInstrCallbacks during deregistration");
+         "Callback ID not found in MoveInstrCallbacks during deregistration");
 }
 
 } // namespace llvm::sandboxir
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 786580d1046a60..268c6e0712c505 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -5980,18 +5980,18 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
   sandboxir::Instruction *Ret = &BB.front();
 
   SmallVector<sandboxir::Instruction *> Inserted;
-  int InsertCbId = Ctx.registerInsertInstrCallback(
+  auto InsertCbId = Ctx.registerInsertInstrCallback(
       [&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); });
 
   SmallVector<sandboxir::Instruction *> Removed;
-  int RemoveCbId = Ctx.registerRemoveInstrCallback(
+  auto RemoveCbId = Ctx.registerRemoveInstrCallback(
       [&Removed](sandboxir::Instruction *I) { Removed.push_back(I); });
 
   // Keep the moved instruction and the instruction pointed by the Where
   // iterator so we can check both callback arguments work as expected.
   SmallVector<std::pair<sandboxir::Instruction *, sandboxir::Instruction *>>
       Moved;
-  int MoveCbId = Ctx.registerMoveInstrCallback(
+  auto MoveCbId = Ctx.registerMoveInstrCallback(
       [&Moved](sandboxir::Instruction *I, const sandboxir::BBIterator &Where) {
         // Use a nullptr to signal "move to end" to keep it single. We only
         // have a basic block in this test case anyway.
@@ -6040,7 +6040,7 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
   Ctx.unregisterRemoveInstrCallback(RemoveCbId);
   Ctx.unregisterMoveInstrCallback(MoveCbId);
   auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
-                                            Ret->getIterator(), Ctx);
+                                             Ret->getIterator(), Ctx);
   Ret->moveBefore(NewI2);
   Ret->eraseFromParent();
   EXPECT_THAT(Inserted, testing::IsEmpty());

>From e84e9e6a3f889f8bbbe1a8d7074e9213ae9904c3 Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 18 Oct 2024 14:10:25 -0700
Subject: [PATCH 3/8] Another round of feedback.

- Updated callback (de)registration method signatures in header to use
  the CallbackID type instead of int.

- Fixed the capitalization of some variable names (e.g. `erased` ->
  `Erased`) to follow the LLVM coding standards.
---
 llvm/include/llvm/SandboxIR/Context.h | 12 ++++++------
 llvm/lib/SandboxIR/Context.cpp        | 12 ++++++------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index 2da80481f9b6c4..cc3572db6039da 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -238,21 +238,21 @@ class Context {
   /// to be removed from its parent. Note that this will also be called when
   /// reverting the creation of an instruction.
   /// \Returns a callback ID for later deregistration.
-  int registerRemoveInstrCallback(RemoveInstrCallback CB);
-  void unregisterRemoveInstrCallback(int CallbackId);
+  CallbackID registerRemoveInstrCallback(RemoveInstrCallback CB);
+  void unregisterRemoveInstrCallback(CallbackID ID);
 
   /// Register a callback that gets called right after a SandboxIR instruction
   /// is created. Note that this will also be called when reverting the removal
   /// of an instruction.
   /// \Returns a callback ID for later deregistration.
-  int registerInsertInstrCallback(InsertInstrCallback CB);
-  void unregisterInsertInstrCallback(int CallbackId);
+  CallbackID registerInsertInstrCallback(InsertInstrCallback CB);
+  void unregisterInsertInstrCallback(CallbackID ID);
 
   /// Register a callback that gets called when a SandboxIR instruction is about
   /// to be moved. Note that this will also be called when reverting a move.
   /// \Returns a callback ID for later deregistration.
-  int registerMoveInstrCallback(MoveInstrCallback CB);
-  void unregisterMoveInstrCallback(int CallbackId);
+  CallbackID registerMoveInstrCallback(MoveInstrCallback CB);
+  void unregisterMoveInstrCallback(CallbackID ID);
 };
 
 } // namespace llvm::sandboxir
diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index f22945a076aa28..66ec08757ca312 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -686,8 +686,8 @@ int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) {
   return ID;
 }
 void Context::unregisterRemoveInstrCallback(CallbackID ID) {
-  [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(ID);
-  assert(erased &&
+  [[maybe_unused]] bool Erased = RemoveInstrCallbacks.erase(ID);
+  assert(Erased &&
          "Callback ID not found in RemoveInstrCallbacks during deregistration");
 }
 
@@ -697,8 +697,8 @@ int Context::registerInsertInstrCallback(InsertInstrCallback CB) {
   return ID;
 }
 void Context::unregisterInsertInstrCallback(CallbackID ID) {
-  [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(ID);
-  assert(erased &&
+  [[maybe_unused]] bool Erased = InsertInstrCallbacks.erase(ID);
+  assert(Erased &&
          "Callback ID not found in InsertInstrCallbacks during deregistration");
 }
 
@@ -708,8 +708,8 @@ int Context::registerMoveInstrCallback(MoveInstrCallback CB) {
   return ID;
 }
 void Context::unregisterMoveInstrCallback(CallbackID ID) {
-  [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(ID);
-  assert(erased &&
+  [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(ID);
+  assert(Erased &&
          "Callback ID not found in MoveInstrCallbacks during deregistration");
 }
 

>From fa875d7995cbcfb850e84f3a7dbab0a9e071ae41 Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 18 Oct 2024 14:29:40 -0700
Subject: [PATCH 4/8] Make NextCallbackID not static

---
 llvm/include/llvm/SandboxIR/Context.h | 2 +-
 llvm/lib/SandboxIR/Context.cpp        | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index cc3572db6039da..f1ee6b6de222fa 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -77,7 +77,7 @@ class Context {
   /// A counter used for assigning callback IDs during registration. The same
   /// counter is used for all kinds of callbacks so we can detect mismatched
   /// registration/deregistration.
-  static CallbackID NextCallbackID;
+  CallbackID NextCallbackID = 0;
 
   /// Remove \p V from the maps and returns the unique_ptr.
   std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index 66ec08757ca312..213ad7f5c6d8a3 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -678,8 +678,6 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
     CBEntry.second(I, WhereIt);
 }
 
-Context::CallbackID Context::NextCallbackID = 0;
-
 int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) {
   CallbackID ID = NextCallbackID++;
   RemoveInstrCallbacks[ID] = CB;

>From 906854107b4dfda37ed95db066fab308a8a2c6e8 Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 18 Oct 2024 18:07:23 -0700
Subject: [PATCH 5/8] Switched callbacks to MapVector to iterate in
 registration order.

Added test to check that registration order is respected when invoking
registered callbacks.
---
 llvm/include/llvm/SandboxIR/Context.h      |  7 ++++---
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 13 +++++++++++++
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index f1ee6b6de222fa..b8e1f667f14675 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -10,6 +10,7 @@
 #define LLVM_SANDBOXIR_CONTEXT_H
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/SandboxIR/Tracker.h"
@@ -66,13 +67,13 @@ class Context {
 
   /// Callbacks called when an IR instruction is about to get removed. Keys are
   /// used as IDs for deregistration.
-  DenseMap<CallbackID, RemoveInstrCallback> RemoveInstrCallbacks;
+  MapVector<CallbackID, RemoveInstrCallback> RemoveInstrCallbacks;
   /// Callbacks called when an IR instruction is about to get inserted. Keys are
   /// used as IDs for deregistration.
-  DenseMap<CallbackID, InsertInstrCallback> InsertInstrCallbacks;
+  MapVector<CallbackID, InsertInstrCallback> InsertInstrCallbacks;
   /// Callbacks called when an IR instruction is about to get moved. Keys are
   /// used as IDs for deregistration.
-  DenseMap<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
+  MapVector<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
 
   /// A counter used for assigning callback IDs during registration. The same
   /// counter is used for all kinds of callbacks so we can detect mismatched
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 268c6e0712c505..5bad56b4064478 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -6001,12 +6001,22 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
           Moved.push_back(std::make_pair(I, &*Where));
       });
 
+  // Two more insertion callbacks, to check that they're called in registration
+  // order.
+  SmallVector<int> Order;
+  auto CheckOrderInsertCbId1 = Ctx.registerInsertInstrCallback(
+      [&Order](sandboxir::Instruction *I) { Order.push_back(1); });
+
+  auto CheckOrderInsertCbId2 = Ctx.registerInsertInstrCallback(
+      [&Order](sandboxir::Instruction *I) { Order.push_back(2); });
+
   Ctx.save();
   auto *NewI = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
                                             Ret->getIterator(), Ctx);
   EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
   EXPECT_THAT(Removed, testing::IsEmpty());
   EXPECT_THAT(Moved, testing::IsEmpty());
+  EXPECT_THAT(Order, testing::ElementsAre(1, 2));
 
   Ret->moveBefore(NewI);
   EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
@@ -6030,6 +6040,7 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
   EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI, NewI));
   EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI),
                                           std::make_pair(Ret, nullptr)));
+  EXPECT_THAT(Order, testing::ElementsAre(1, 2, 1, 2, 1, 2));
 
   // Check that deregistration works. Do an operation of each type after
   // deregistering callbacks and check.
@@ -6039,6 +6050,8 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
   Ctx.unregisterInsertInstrCallback(InsertCbId);
   Ctx.unregisterRemoveInstrCallback(RemoveCbId);
   Ctx.unregisterMoveInstrCallback(MoveCbId);
+  Ctx.unregisterInsertInstrCallback(CheckOrderInsertCbId1);
+  Ctx.unregisterInsertInstrCallback(CheckOrderInsertCbId2);
   auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
                                              Ret->getIterator(), Ctx);
   Ret->moveBefore(NewI2);

>From 18ed35d02462594640a5061e04b94db07791fcb6 Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 25 Oct 2024 17:17:16 -0700
Subject: [PATCH 6/8] Address more review comments, plus a couple of things
 discussed offline.

- Rename the insert/remove callbacks to create/erase, and don't call the
  erase callback on removeFromParent(). This gives a consistent model with
  only create/erase/move callbacks and no insert/remove ones; the missing
  callbacks can still be added in the future if needed.

- Changed callback IDs to uint64_t to rule out the unlikely case of a
  32-bit counter overflowing in a large compile, which would cause
  hard-to-debug errors.
---
 llvm/include/llvm/SandboxIR/Context.h      | 38 +++++++++++++---------
 llvm/lib/SandboxIR/Context.cpp             | 34 ++++++++++---------
 llvm/lib/SandboxIR/Instruction.cpp         |  4 +--
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 16 ++++-----
 4 files changed, 49 insertions(+), 43 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index b8e1f667f14675..f2056de87cb946 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -16,6 +16,8 @@
 #include "llvm/SandboxIR/Tracker.h"
 #include "llvm/SandboxIR/Type.h"
 
+#include <cstdint>
+
 namespace llvm::sandboxir {
 
 class Argument;
@@ -26,17 +28,19 @@ class Value;
 
 class Context {
 public:
-  // A RemoveInstrCallback receives the instruction about to be removed.
-  using RemoveInstrCallback = std::function<void(Instruction *)>;
-  // A InsertInstrCallback receives the instruction about to be created.
-  using InsertInstrCallback = std::function<void(Instruction *)>;
+  // A EraseInstrCallback receives the instruction about to be erased.
+  using EraseInstrCallback = std::function<void(Instruction *)>;
+  // A CreateInstrCallback receives the instruction about to be created.
+  using CreateInstrCallback = std::function<void(Instruction *)>;
   // A MoveInstrCallback receives the instruction about to be moved, the
   // destination BB and an iterator pointing to the insertion position.
   using MoveInstrCallback =
       std::function<void(Instruction *, const BBIterator &)>;
 
-  /// An ID for a registered callback. Used for deregistration.
-  using CallbackID = int;
+  /// An ID for a registered callback. Used for deregistration. Using a 64-bit
+  /// integer so we don't have to worry about the unlikely case of overflowing
+  /// a 32-bit counter.
+  using CallbackID = uint64_t;
 
 protected:
   LLVMContext &LLVMCtx;
@@ -65,12 +69,12 @@ class Context {
   /// Type objects.
   DenseMap<llvm::Type *, std::unique_ptr<Type, TypeDeleter>> LLVMTypeToTypeMap;
 
-  /// Callbacks called when an IR instruction is about to get removed. Keys are
+  /// Callbacks called when an IR instruction is about to get erased. Keys are
   /// used as IDs for deregistration.
-  MapVector<CallbackID, RemoveInstrCallback> RemoveInstrCallbacks;
-  /// Callbacks called when an IR instruction is about to get inserted. Keys are
+  MapVector<CallbackID, EraseInstrCallback> EraseInstrCallbacks;
+  /// Callbacks called right after an IR instruction is created. Keys are
   /// used as IDs for deregistration.
-  MapVector<CallbackID, InsertInstrCallback> InsertInstrCallbacks;
+  MapVector<CallbackID, CreateInstrCallback> CreateInstrCallbacks;
   /// Callbacks called when an IR instruction is about to get moved. Keys are
   /// used as IDs for deregistration.
   MapVector<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
@@ -102,8 +106,8 @@ class Context {
   Constant *getOrCreateConstant(llvm::Constant *LLVMC);
   friend class Utils; // For getMemoryBase
 
-  void runRemoveInstrCallbacks(Instruction *I);
-  void runInsertInstrCallbacks(Instruction *I);
+  void runEraseInstrCallbacks(Instruction *I);
+  void runCreateInstrCallbacks(Instruction *I);
   void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where);
 
   // Friends for getOrCreateConstant().
@@ -239,21 +243,23 @@ class Context {
   /// to be removed from its parent. Note that this will also be called when
   /// reverting the creation of an instruction.
   /// \Returns a callback ID for later deregistration.
-  CallbackID registerRemoveInstrCallback(RemoveInstrCallback CB);
-  void unregisterRemoveInstrCallback(CallbackID ID);
+  CallbackID registerEraseInstrCallback(EraseInstrCallback CB);
+  void unregisterEraseInstrCallback(CallbackID ID);
 
   /// Register a callback that gets called right after a SandboxIR instruction
   /// is created. Note that this will also be called when reverting the removal
   /// of an instruction.
   /// \Returns a callback ID for later deregistration.
-  CallbackID registerInsertInstrCallback(InsertInstrCallback CB);
-  void unregisterInsertInstrCallback(CallbackID ID);
+  CallbackID registerCreateInstrCallback(CreateInstrCallback CB);
+  void unregisterCreateInstrCallback(CallbackID ID);
 
   /// Register a callback that gets called when a SandboxIR instruction is about
   /// to be moved. Note that this will also be called when reverting a move.
   /// \Returns a callback ID for later deregistration.
   CallbackID registerMoveInstrCallback(MoveInstrCallback CB);
   void unregisterMoveInstrCallback(CallbackID ID);
+
+  // TODO: Add callbacks for instructions inserted/removed if needed.
 };
 
 } // namespace llvm::sandboxir
diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index 213ad7f5c6d8a3..c0b35596928569 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -46,7 +46,7 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
   // creation. This is why the tracker class combines creation and insertion.
   if (auto *I = dyn_cast<Instruction>(V)) {
     getTracker().emplaceIfTracking<CreateAndInsertInst>(I);
-    runInsertInstrCallbacks(I);
+    runCreateInstrCallbacks(I);
   }
 
   return V;
@@ -663,13 +663,13 @@ Module *Context::createModule(llvm::Module *LLVMM) {
   return M;
 }
 
-void Context::runRemoveInstrCallbacks(Instruction *I) {
-  for (const auto &CBEntry : RemoveInstrCallbacks)
+void Context::runEraseInstrCallbacks(Instruction *I) {
+  for (const auto &CBEntry : EraseInstrCallbacks)
     CBEntry.second(I);
 }
 
-void Context::runInsertInstrCallbacks(Instruction *I) {
-  for (auto &CBEntry : InsertInstrCallbacks)
+void Context::runCreateInstrCallbacks(Instruction *I) {
+  for (auto &CBEntry : CreateInstrCallbacks)
     CBEntry.second(I);
 }
 
@@ -678,29 +678,31 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
     CBEntry.second(I, WhereIt);
 }
 
-int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) {
+Context::CallbackID
+Context::registerEraseInstrCallback(EraseInstrCallback CB) {
   CallbackID ID = NextCallbackID++;
-  RemoveInstrCallbacks[ID] = CB;
+  EraseInstrCallbacks[ID] = CB;
   return ID;
 }
-void Context::unregisterRemoveInstrCallback(CallbackID ID) {
-  [[maybe_unused]] bool Erased = RemoveInstrCallbacks.erase(ID);
+void Context::unregisterEraseInstrCallback(CallbackID ID) {
+  [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(ID);
   assert(Erased &&
-         "Callback ID not found in RemoveInstrCallbacks during deregistration");
+         "Callback ID not found in EraseInstrCallbacks during deregistration");
 }
 
-int Context::registerInsertInstrCallback(InsertInstrCallback CB) {
+Context::CallbackID
+Context::registerCreateInstrCallback(CreateInstrCallback CB) {
   CallbackID ID = NextCallbackID++;
-  InsertInstrCallbacks[ID] = CB;
+  CreateInstrCallbacks[ID] = CB;
   return ID;
 }
-void Context::unregisterInsertInstrCallback(CallbackID ID) {
-  [[maybe_unused]] bool Erased = InsertInstrCallbacks.erase(ID);
+void Context::unregisterCreateInstrCallback(CallbackID ID) {
+  [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(ID);
   assert(Erased &&
-         "Callback ID not found in InsertInstrCallbacks during deregistration");
+         "Callback ID not found in CreateInstrCallbacks during deregistration");
 }
 
-int Context::registerMoveInstrCallback(MoveInstrCallback CB) {
+Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
   CallbackID ID = NextCallbackID++;
   MoveInstrCallbacks[ID] = CB;
   return ID;
diff --git a/llvm/lib/SandboxIR/Instruction.cpp b/llvm/lib/SandboxIR/Instruction.cpp
index ddeb78eea19f73..096b827541eeaf 100644
--- a/llvm/lib/SandboxIR/Instruction.cpp
+++ b/llvm/lib/SandboxIR/Instruction.cpp
@@ -64,8 +64,6 @@ Instruction *Instruction::getPrevNode() const {
 }
 
 void Instruction::removeFromParent() {
-  Ctx.runRemoveInstrCallbacks(this);
-
   Ctx.getTracker().emplaceIfTracking<RemoveFromParent>(this);
 
   // Detach all the LLVM IR instructions from their parent BB.
@@ -76,7 +74,7 @@ void Instruction::removeFromParent() {
 void Instruction::eraseFromParent() {
   assert(users().empty() && "Still connected to users, can't erase!");
 
-  Ctx.runRemoveInstrCallbacks(this);
+  Ctx.runEraseInstrCallbacks(this);
   std::unique_ptr<Value> Detached = Ctx.detach(this);
   auto LLVMInstrs = getLLVMInstrs();
 
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 5bad56b4064478..99e14292a91b92 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -5980,11 +5980,11 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
   sandboxir::Instruction *Ret = &BB.front();
 
   SmallVector<sandboxir::Instruction *> Inserted;
-  auto InsertCbId = Ctx.registerInsertInstrCallback(
+  auto InsertCbId = Ctx.registerCreateInstrCallback(
       [&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); });
 
   SmallVector<sandboxir::Instruction *> Removed;
-  auto RemoveCbId = Ctx.registerRemoveInstrCallback(
+  auto RemoveCbId = Ctx.registerEraseInstrCallback(
       [&Removed](sandboxir::Instruction *I) { Removed.push_back(I); });
 
   // Keep the moved instruction and the instruction pointed by the Where
@@ -6004,10 +6004,10 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
   // Two more insertion callbacks, to check that they're called in registration
   // order.
   SmallVector<int> Order;
-  auto CheckOrderInsertCbId1 = Ctx.registerInsertInstrCallback(
+  auto CheckOrderInsertCbId1 = Ctx.registerCreateInstrCallback(
       [&Order](sandboxir::Instruction *I) { Order.push_back(1); });
 
-  auto CheckOrderInsertCbId2 = Ctx.registerInsertInstrCallback(
+  auto CheckOrderInsertCbId2 = Ctx.registerCreateInstrCallback(
       [&Order](sandboxir::Instruction *I) { Order.push_back(2); });
 
   Ctx.save();
@@ -6047,11 +6047,11 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
   Inserted.clear();
   Removed.clear();
   Moved.clear();
-  Ctx.unregisterInsertInstrCallback(InsertCbId);
-  Ctx.unregisterRemoveInstrCallback(RemoveCbId);
+  Ctx.unregisterCreateInstrCallback(InsertCbId);
+  Ctx.unregisterEraseInstrCallback(RemoveCbId);
   Ctx.unregisterMoveInstrCallback(MoveCbId);
-  Ctx.unregisterInsertInstrCallback(CheckOrderInsertCbId1);
-  Ctx.unregisterInsertInstrCallback(CheckOrderInsertCbId2);
+  Ctx.unregisterCreateInstrCallback(CheckOrderInsertCbId1);
+  Ctx.unregisterCreateInstrCallback(CheckOrderInsertCbId2);
   auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
                                              Ret->getIterator(), Ctx);
   Ret->moveBefore(NewI2);

>From 35a20748234d9a448b1c4987ac2b8dc50f1e01df Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 25 Oct 2024 17:27:38 -0700
Subject: [PATCH 7/8] clang-format

---
 llvm/lib/SandboxIR/Context.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index c0b35596928569..0943f9526d0244 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -678,8 +678,7 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
     CBEntry.second(I, WhereIt);
 }
 
-Context::CallbackID
-Context::registerEraseInstrCallback(EraseInstrCallback CB) {
+Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
   CallbackID ID = NextCallbackID++;
   EraseInstrCallbacks[ID] = CB;
   return ID;

>From 367e5b59c9be280e57ff6569c599bda44965c1b4 Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Mon, 28 Oct 2024 18:17:38 -0700
Subject: [PATCH 8/8] Add assertion for max callbacks registered at once

---
 llvm/lib/SandboxIR/Context.cpp | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index 0943f9526d0244..301b4b784016ea 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -678,7 +678,15 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
     CBEntry.second(I, WhereIt);
 }
 
+// An arbitrary limit, to check for accidental misuse. We expect a small number
+// of callbacks to be registered at a time, but we can increase this number if
+// we discover we need more. Only referenced from asserts, hence
+// [[maybe_unused]] to keep NDEBUG builds warning-free.
+[[maybe_unused]] static constexpr int MaxRegisteredCallbacks = 16;
+
 Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
+  assert(EraseInstrCallbacks.size() < MaxRegisteredCallbacks &&
+         "EraseInstrCallbacks size limit exceeded");
   CallbackID ID = NextCallbackID++;
   EraseInstrCallbacks[ID] = CB;
   return ID;
@@ -691,6 +699,8 @@ void Context::unregisterEraseInstrCallback(CallbackID ID) {
 
 Context::CallbackID
 Context::registerCreateInstrCallback(CreateInstrCallback CB) {
+  assert(CreateInstrCallbacks.size() < MaxRegisteredCallbacks &&
+         "CreateInstrCallbacks size limit exceeded");
   CallbackID ID = NextCallbackID++;
   CreateInstrCallbacks[ID] = CB;
   return ID;
@@ -702,6 +712,8 @@ void Context::unregisterCreateInstrCallback(CallbackID ID) {
 }
 
 Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
+  assert(MoveInstrCallbacks.size() < MaxRegisteredCallbacks &&
+         "MoveInstrCallbacks size limit exceeded");
   CallbackID ID = NextCallbackID++;
   MoveInstrCallbacks[ID] = CB;
   return ID;



More information about the llvm-commits mailing list