[llvm] [SandboxIR] Add callbacks for instruction insert/remove/move ops (PR #112965)
Jorge Gorbe Moya via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 18 18:08:28 PDT 2024
https://github.com/slackito updated https://github.com/llvm/llvm-project/pull/112965
>From 56ffef4a97b888e1c53153f71301befdc3cfd24d Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 18 Oct 2024 12:00:06 -0700
Subject: [PATCH 1/5] [SandboxIR] Add callbacks for instruction
insert/remove/move ops.
---
llvm/include/llvm/SandboxIR/Context.h | 56 +++++++++++++-
llvm/lib/SandboxIR/Context.cpp | 66 +++++++++++++++--
llvm/lib/SandboxIR/Instruction.cpp | 5 ++
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 85 ++++++++++++++++++++++
4 files changed, 205 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index 1285598a1c0282..836988639a14bb 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -9,18 +9,31 @@
#ifndef LLVM_SANDBOXIR_CONTEXT_H
#define LLVM_SANDBOXIR_CONTEXT_H
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/SandboxIR/Tracker.h"
#include "llvm/SandboxIR/Type.h"
namespace llvm::sandboxir {
-class Module;
-class Value;
class Argument;
+class BBIterator;
class Constant;
+class Module;
+class Value;
class Context {
+public:
+ // A RemoveInstrCallback receives the instruction about to be removed.
+ using RemoveInstrCallback = std::function<void(Instruction *)>;
+ // An InsertInstrCallback receives the instruction that was just created.
+ using InsertInstrCallback = std::function<void(Instruction *)>;
+ // A MoveInstrCallback receives the instruction about to be moved and an
+ // iterator pointing to the insertion position in the destination BB.
+ using MoveInstrCallback =
+ std::function<void(Instruction *, const BBIterator &)>;
+
protected:
LLVMContext &LLVMCtx;
friend class Type; // For LLVMCtx.
@@ -48,6 +61,21 @@ class Context {
/// Type objects.
DenseMap<llvm::Type *, std::unique_ptr<Type, TypeDeleter>> LLVMTypeToTypeMap;
+ /// Callbacks called when an IR instruction is about to get removed. Keys are
+ /// used as IDs for deregistration.
+ DenseMap<int, RemoveInstrCallback> RemoveInstrCallbacks;
+ /// Callbacks called when an IR instruction is about to get inserted. Keys are
+ /// used as IDs for deregistration.
+ DenseMap<int, InsertInstrCallback> InsertInstrCallbacks;
+ /// Callbacks called when an IR instruction is about to get moved. Keys are
+ /// used as IDs for deregistration.
+ DenseMap<int, MoveInstrCallback> MoveInstrCallbacks;
+
+ /// A counter used for assigning callback IDs during registration. The same
+ /// counter is used for all kinds of callbacks so we can detect mismatched
+ /// registration/deregistration.
+ static int NextCallbackId;
+
/// Remove \p V from the maps and returns the unique_ptr.
std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
/// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively
@@ -70,6 +98,10 @@ class Context {
Constant *getOrCreateConstant(llvm::Constant *LLVMC);
friend class Utils; // For getMemoryBase
+ void runRemoveInstrCallbacks(Instruction *I);
+ void runInsertInstrCallbacks(Instruction *I);
+ void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where);
+
// Friends for getOrCreateConstant().
#define DEF_CONST(ID, CLASS) friend class CLASS;
#include "llvm/SandboxIR/Values.def"
@@ -198,6 +230,26 @@ class Context {
/// \Returns the number of values registered with Context.
size_t getNumValues() const { return LLVMValueToValueMap.size(); }
+
+ /// Register a callback that gets called when a SandboxIR instruction is about
+ /// to be removed from its parent. Note that this will also be called when
+ /// reverting the creation of an instruction.
+ /// \Returns a callback ID for later deregistration.
+ int registerRemoveInstrCallback(RemoveInstrCallback CB);
+ void unregisterRemoveInstrCallback(int CallbackId);
+
+ /// Register a callback that gets called right after a SandboxIR instruction
+ /// is created. Note that this will also be called when reverting the removal
+ /// of an instruction.
+ /// \Returns a callback ID for later deregistration.
+ int registerInsertInstrCallback(InsertInstrCallback CB);
+ void unregisterInsertInstrCallback(int CallbackId);
+
+ /// Register a callback that gets called when a SandboxIR instruction is about
+ /// to be moved. Note that this will also be called when reverting a move.
+ /// \Returns a callback ID for later deregistration.
+ int registerMoveInstrCallback(MoveInstrCallback CB);
+ void unregisterMoveInstrCallback(int CallbackId);
};
} // namespace llvm::sandboxir
diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index 486e935bc35fba..e13f833f1ba29d 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -35,17 +35,20 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
assert(VPtr->getSubclassID() != Value::ClassID::User &&
"Can't register a user!");
+ Value *V = VPtr.get();
+ [[maybe_unused]] auto Pair =
+ LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
+ assert(Pair.second && "Already exists!");
+
// Track creation of instructions.
// Please note that we don't allow the creation of detached instructions,
// meaning that the instructions need to be inserted into a block upon
// creation. This is why the tracker class combines creation and insertion.
- if (auto *I = dyn_cast<Instruction>(VPtr.get()))
+ if (auto *I = dyn_cast<Instruction>(V)) {
getTracker().emplaceIfTracking<CreateAndInsertInst>(I);
+ runInsertInstrCallbacks(I);
+ }
- Value *V = VPtr.get();
- [[maybe_unused]] auto Pair =
- LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
- assert(Pair.second && "Already exists!");
return V;
}
@@ -660,4 +663,57 @@ Module *Context::createModule(llvm::Module *LLVMM) {
return M;
}
+void Context::runRemoveInstrCallbacks(Instruction *I) {
+ for (const auto &CBEntry : RemoveInstrCallbacks) {
+ CBEntry.second(I);
+ }
+}
+
+void Context::runInsertInstrCallbacks(Instruction *I) {
+ for (auto &CBEntry : InsertInstrCallbacks) {
+ CBEntry.second(I);
+ }
+}
+
+void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
+ for (auto &CBEntry : MoveInstrCallbacks) {
+ CBEntry.second(I, WhereIt);
+ }
+}
+
+int Context::NextCallbackId = 0;
+
+int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) {
+ int Id = NextCallbackId++;
+ RemoveInstrCallbacks[Id] = CB;
+ return Id;
+}
+void Context::unregisterRemoveInstrCallback(int CallbackId) {
+ [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(CallbackId);
+ assert(erased &&
+ "Callback id not found in RemoveInstrCallbacks during deregistration");
+}
+
+int Context::registerInsertInstrCallback(InsertInstrCallback CB) {
+ int Id = NextCallbackId++;
+ InsertInstrCallbacks[Id] = CB;
+ return Id;
+}
+void Context::unregisterInsertInstrCallback(int CallbackId) {
+ [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(CallbackId);
+ assert(erased &&
+ "Callback id not found in InsertInstrCallbacks during deregistration");
+}
+
+int Context::registerMoveInstrCallback(MoveInstrCallback CB) {
+ int Id = NextCallbackId++;
+ MoveInstrCallbacks[Id] = CB;
+ return Id;
+}
+void Context::unregisterMoveInstrCallback(int CallbackId) {
+ [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(CallbackId);
+ assert(erased &&
+ "Callback id not found in MoveInstrCallbacks during deregistration");
+}
+
} // namespace llvm::sandboxir
diff --git a/llvm/lib/SandboxIR/Instruction.cpp b/llvm/lib/SandboxIR/Instruction.cpp
index d80d10370e32d8..ddeb78eea19f73 100644
--- a/llvm/lib/SandboxIR/Instruction.cpp
+++ b/llvm/lib/SandboxIR/Instruction.cpp
@@ -64,6 +64,8 @@ Instruction *Instruction::getPrevNode() const {
}
void Instruction::removeFromParent() {
+ Ctx.runRemoveInstrCallbacks(this);
+
Ctx.getTracker().emplaceIfTracking<RemoveFromParent>(this);
// Detach all the LLVM IR instructions from their parent BB.
@@ -73,6 +75,8 @@ void Instruction::removeFromParent() {
void Instruction::eraseFromParent() {
assert(users().empty() && "Still connected to users, can't erase!");
+
+ Ctx.runRemoveInstrCallbacks(this);
std::unique_ptr<Value> Detached = Ctx.detach(this);
auto LLVMInstrs = getLLVMInstrs();
@@ -100,6 +104,7 @@ void Instruction::moveBefore(BasicBlock &BB, const BBIterator &WhereIt) {
// Destination is same as origin, nothing to do.
return;
+ Ctx.runMoveInstrCallbacks(this, WhereIt);
Ctx.getTracker().emplaceIfTracking<MoveInstr>(this);
auto *LLVMBB = cast<llvm::BasicBlock>(BB.Val);
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 97113b303f72e5..786580d1046a60 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -22,6 +22,7 @@
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/SourceMgr.h"
#include "gmock/gmock-matchers.h"
+#include "gmock/gmock-more-matchers.h"
#include "gtest/gtest.h"
using namespace llvm;
@@ -5962,3 +5963,87 @@ TEST_F(SandboxIRTest, CheckClassof) {
EXPECT_NE(&sandboxir::CLASS::classof, &sandboxir::Instruction::classof);
#include "llvm/SandboxIR/Values.def"
}
+
+TEST_F(SandboxIRTest, InstructionCallbacks) {
+ parseIR(C, R"IR(
+ define void @foo(ptr %ptr, i8 %val) {
+ ret void
+ }
+ )IR");
+ Function &LLVMF = *M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+
+ auto &F = *Ctx.createFunction(&LLVMF);
+ auto &BB = *F.begin();
+ sandboxir::Argument *Ptr = F.getArg(0);
+ sandboxir::Argument *Val = F.getArg(1);
+ sandboxir::Instruction *Ret = &BB.front();
+
+ SmallVector<sandboxir::Instruction *> Inserted;
+ int InsertCbId = Ctx.registerInsertInstrCallback(
+ [&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); });
+
+ SmallVector<sandboxir::Instruction *> Removed;
+ int RemoveCbId = Ctx.registerRemoveInstrCallback(
+ [&Removed](sandboxir::Instruction *I) { Removed.push_back(I); });
+
+ // Keep the moved instruction and the instruction pointed by the Where
+ // iterator so we can check both callback arguments work as expected.
+ SmallVector<std::pair<sandboxir::Instruction *, sandboxir::Instruction *>>
+ Moved;
+ int MoveCbId = Ctx.registerMoveInstrCallback(
+ [&Moved](sandboxir::Instruction *I, const sandboxir::BBIterator &Where) {
+ // Use a nullptr to signal "move to end" to keep it single. We only
+ // have a basic block in this test case anyway.
+ if (Where == Where.getNodeParent()->end())
+ Moved.push_back(std::make_pair(I, nullptr));
+ else
+ Moved.push_back(std::make_pair(I, &*Where));
+ });
+
+ Ctx.save();
+ auto *NewI = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
+ Ret->getIterator(), Ctx);
+ EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
+ EXPECT_THAT(Removed, testing::IsEmpty());
+ EXPECT_THAT(Moved, testing::IsEmpty());
+
+ Ret->moveBefore(NewI);
+ EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
+ EXPECT_THAT(Removed, testing::IsEmpty());
+ EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
+
+ Ret->eraseFromParent();
+ EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
+ EXPECT_THAT(Removed, testing::ElementsAre(Ret));
+ EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
+
+ NewI->eraseFromParent();
+ EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
+ EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI));
+ EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
+
+ // Check that after revert the callbacks have been called for the inverse
+ // operations of the changes made so far.
+ Ctx.revert();
+ EXPECT_THAT(Inserted, testing::ElementsAre(NewI, NewI, Ret));
+ EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI, NewI));
+ EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI),
+ std::make_pair(Ret, nullptr)));
+
+ // Check that deregistration works. Do an operation of each type after
+ // deregistering callbacks and check.
+ Inserted.clear();
+ Removed.clear();
+ Moved.clear();
+ Ctx.unregisterInsertInstrCallback(InsertCbId);
+ Ctx.unregisterRemoveInstrCallback(RemoveCbId);
+ Ctx.unregisterMoveInstrCallback(MoveCbId);
+ auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
+ Ret->getIterator(), Ctx);
+ Ret->moveBefore(NewI2);
+ Ret->eraseFromParent();
+ EXPECT_THAT(Inserted, testing::IsEmpty());
+ EXPECT_THAT(Removed, testing::IsEmpty());
+ EXPECT_THAT(Moved, testing::IsEmpty());
+}
>From f361d5cc0cde6c4ddfc85bc719025cf84b0818bd Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 18 Oct 2024 13:44:21 -0700
Subject: [PATCH 2/5] Address some review feedback.
- Introduced a `CallbackID` type alias for callback IDs instead of a plain
int.
- Removed unnecessary braces.
---
llvm/include/llvm/SandboxIR/Context.h | 11 +++--
llvm/lib/SandboxIR/Context.cpp | 47 ++++++++++------------
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 8 ++--
3 files changed, 33 insertions(+), 33 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index 836988639a14bb..2da80481f9b6c4 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -34,6 +34,9 @@ class Context {
using MoveInstrCallback =
std::function<void(Instruction *, const BBIterator &)>;
+ /// An ID for a registered callback. Used for deregistration.
+ using CallbackID = int;
+
protected:
LLVMContext &LLVMCtx;
friend class Type; // For LLVMCtx.
@@ -63,18 +66,18 @@ class Context {
/// Callbacks called when an IR instruction is about to get removed. Keys are
/// used as IDs for deregistration.
- DenseMap<int, RemoveInstrCallback> RemoveInstrCallbacks;
+ DenseMap<CallbackID, RemoveInstrCallback> RemoveInstrCallbacks;
/// Callbacks called when an IR instruction is about to get inserted. Keys are
/// used as IDs for deregistration.
- DenseMap<int, InsertInstrCallback> InsertInstrCallbacks;
+ DenseMap<CallbackID, InsertInstrCallback> InsertInstrCallbacks;
/// Callbacks called when an IR instruction is about to get moved. Keys are
/// used as IDs for deregistration.
- DenseMap<int, MoveInstrCallback> MoveInstrCallbacks;
+ DenseMap<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
/// A counter used for assigning callback IDs during registration. The same
/// counter is used for all kinds of callbacks so we can detect mismatched
/// registration/deregistration.
- static int NextCallbackId;
+ static CallbackID NextCallbackID;
/// Remove \p V from the maps and returns the unique_ptr.
std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index e13f833f1ba29d..f22945a076aa28 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -664,56 +664,53 @@ Module *Context::createModule(llvm::Module *LLVMM) {
}
void Context::runRemoveInstrCallbacks(Instruction *I) {
- for (const auto &CBEntry : RemoveInstrCallbacks) {
+ for (const auto &CBEntry : RemoveInstrCallbacks)
CBEntry.second(I);
- }
}
void Context::runInsertInstrCallbacks(Instruction *I) {
- for (auto &CBEntry : InsertInstrCallbacks) {
+ for (auto &CBEntry : InsertInstrCallbacks)
CBEntry.second(I);
- }
}
void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
- for (auto &CBEntry : MoveInstrCallbacks) {
+ for (auto &CBEntry : MoveInstrCallbacks)
CBEntry.second(I, WhereIt);
- }
}
-int Context::NextCallbackId = 0;
+Context::CallbackID Context::NextCallbackID = 0;
int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) {
- int Id = NextCallbackId++;
- RemoveInstrCallbacks[Id] = CB;
- return Id;
+ CallbackID ID = NextCallbackID++;
+ RemoveInstrCallbacks[ID] = CB;
+ return ID;
}
-void Context::unregisterRemoveInstrCallback(int CallbackId) {
- [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(CallbackId);
+void Context::unregisterRemoveInstrCallback(CallbackID ID) {
+ [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(ID);
assert(erased &&
- "Callback id not found in RemoveInstrCallbacks during deregistration");
+ "Callback ID not found in RemoveInstrCallbacks during deregistration");
}
int Context::registerInsertInstrCallback(InsertInstrCallback CB) {
- int Id = NextCallbackId++;
- InsertInstrCallbacks[Id] = CB;
- return Id;
+ CallbackID ID = NextCallbackID++;
+ InsertInstrCallbacks[ID] = CB;
+ return ID;
}
-void Context::unregisterInsertInstrCallback(int CallbackId) {
- [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(CallbackId);
+void Context::unregisterInsertInstrCallback(CallbackID ID) {
+ [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(ID);
assert(erased &&
- "Callback id not found in InsertInstrCallbacks during deregistration");
+ "Callback ID not found in InsertInstrCallbacks during deregistration");
}
int Context::registerMoveInstrCallback(MoveInstrCallback CB) {
- int Id = NextCallbackId++;
- MoveInstrCallbacks[Id] = CB;
- return Id;
+ CallbackID ID = NextCallbackID++;
+ MoveInstrCallbacks[ID] = CB;
+ return ID;
}
-void Context::unregisterMoveInstrCallback(int CallbackId) {
- [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(CallbackId);
+void Context::unregisterMoveInstrCallback(CallbackID ID) {
+ [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(ID);
assert(erased &&
- "Callback id not found in MoveInstrCallbacks during deregistration");
+ "Callback ID not found in MoveInstrCallbacks during deregistration");
}
} // namespace llvm::sandboxir
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 786580d1046a60..268c6e0712c505 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -5980,18 +5980,18 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
sandboxir::Instruction *Ret = &BB.front();
SmallVector<sandboxir::Instruction *> Inserted;
- int InsertCbId = Ctx.registerInsertInstrCallback(
+ auto InsertCbId = Ctx.registerInsertInstrCallback(
[&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); });
SmallVector<sandboxir::Instruction *> Removed;
- int RemoveCbId = Ctx.registerRemoveInstrCallback(
+ auto RemoveCbId = Ctx.registerRemoveInstrCallback(
[&Removed](sandboxir::Instruction *I) { Removed.push_back(I); });
// Keep the moved instruction and the instruction pointed by the Where
// iterator so we can check both callback arguments work as expected.
SmallVector<std::pair<sandboxir::Instruction *, sandboxir::Instruction *>>
Moved;
- int MoveCbId = Ctx.registerMoveInstrCallback(
+ auto MoveCbId = Ctx.registerMoveInstrCallback(
[&Moved](sandboxir::Instruction *I, const sandboxir::BBIterator &Where) {
// Use a nullptr to signal "move to end" to keep it single. We only
// have a basic block in this test case anyway.
@@ -6040,7 +6040,7 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
Ctx.unregisterRemoveInstrCallback(RemoveCbId);
Ctx.unregisterMoveInstrCallback(MoveCbId);
auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
- Ret->getIterator(), Ctx);
+ Ret->getIterator(), Ctx);
Ret->moveBefore(NewI2);
Ret->eraseFromParent();
EXPECT_THAT(Inserted, testing::IsEmpty());
>From e84e9e6a3f889f8bbbe1a8d7074e9213ae9904c3 Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 18 Oct 2024 14:10:25 -0700
Subject: [PATCH 3/5] Another round of feedback.
- Updated the callback (de)registration method signatures in the header to
use the CallbackID type instead of int.
- Corrected the capitalization of some variable names to follow the style guide.
---
llvm/include/llvm/SandboxIR/Context.h | 12 ++++++------
llvm/lib/SandboxIR/Context.cpp | 12 ++++++------
2 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index 2da80481f9b6c4..cc3572db6039da 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -238,21 +238,21 @@ class Context {
/// to be removed from its parent. Note that this will also be called when
/// reverting the creation of an instruction.
/// \Returns a callback ID for later deregistration.
- int registerRemoveInstrCallback(RemoveInstrCallback CB);
- void unregisterRemoveInstrCallback(int CallbackId);
+ CallbackID registerRemoveInstrCallback(RemoveInstrCallback CB);
+ void unregisterRemoveInstrCallback(CallbackID ID);
/// Register a callback that gets called right after a SandboxIR instruction
/// is created. Note that this will also be called when reverting the removal
/// of an instruction.
/// \Returns a callback ID for later deregistration.
- int registerInsertInstrCallback(InsertInstrCallback CB);
- void unregisterInsertInstrCallback(int CallbackId);
+ CallbackID registerInsertInstrCallback(InsertInstrCallback CB);
+ void unregisterInsertInstrCallback(CallbackID ID);
/// Register a callback that gets called when a SandboxIR instruction is about
/// to be moved. Note that this will also be called when reverting a move.
/// \Returns a callback ID for later deregistration.
- int registerMoveInstrCallback(MoveInstrCallback CB);
- void unregisterMoveInstrCallback(int CallbackId);
+ CallbackID registerMoveInstrCallback(MoveInstrCallback CB);
+ void unregisterMoveInstrCallback(CallbackID ID);
};
} // namespace llvm::sandboxir
diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index f22945a076aa28..66ec08757ca312 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -686,8 +686,8 @@ int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) {
return ID;
}
void Context::unregisterRemoveInstrCallback(CallbackID ID) {
- [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(ID);
- assert(erased &&
+ [[maybe_unused]] bool Erased = RemoveInstrCallbacks.erase(ID);
+ assert(Erased &&
"Callback ID not found in RemoveInstrCallbacks during deregistration");
}
@@ -697,8 +697,8 @@ int Context::registerInsertInstrCallback(InsertInstrCallback CB) {
return ID;
}
void Context::unregisterInsertInstrCallback(CallbackID ID) {
- [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(ID);
- assert(erased &&
+ [[maybe_unused]] bool Erased = InsertInstrCallbacks.erase(ID);
+ assert(Erased &&
"Callback ID not found in InsertInstrCallbacks during deregistration");
}
@@ -708,8 +708,8 @@ int Context::registerMoveInstrCallback(MoveInstrCallback CB) {
return ID;
}
void Context::unregisterMoveInstrCallback(CallbackID ID) {
- [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(ID);
- assert(erased &&
+ [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(ID);
+ assert(Erased &&
"Callback ID not found in MoveInstrCallbacks during deregistration");
}
>From fa875d7995cbcfb850e84f3a7dbab0a9e071ae41 Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 18 Oct 2024 14:29:40 -0700
Subject: [PATCH 4/5] Make NextCallbackID not static
---
llvm/include/llvm/SandboxIR/Context.h | 2 +-
llvm/lib/SandboxIR/Context.cpp | 2 --
2 files changed, 1 insertion(+), 3 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index cc3572db6039da..f1ee6b6de222fa 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -77,7 +77,7 @@ class Context {
/// A counter used for assigning callback IDs during registration. The same
/// counter is used for all kinds of callbacks so we can detect mismatched
/// registration/deregistration.
- static CallbackID NextCallbackID;
+ CallbackID NextCallbackID = 0;
/// Remove \p V from the maps and returns the unique_ptr.
std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index 66ec08757ca312..213ad7f5c6d8a3 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -678,8 +678,6 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
CBEntry.second(I, WhereIt);
}
-Context::CallbackID Context::NextCallbackID = 0;
-
int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) {
CallbackID ID = NextCallbackID++;
RemoveInstrCallbacks[ID] = CB;
>From 906854107b4dfda37ed95db066fab308a8a2c6e8 Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 18 Oct 2024 18:07:23 -0700
Subject: [PATCH 5/5] Switched callbacks to MapVector to iterate in
registration order.
Added test to check that registration order is respected when invoking
registered callbacks.
---
llvm/include/llvm/SandboxIR/Context.h | 7 ++++---
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 13 +++++++++++++
2 files changed, 17 insertions(+), 3 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index f1ee6b6de222fa..b8e1f667f14675 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -10,6 +10,7 @@
#define LLVM_SANDBOXIR_CONTEXT_H
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/SandboxIR/Tracker.h"
@@ -66,13 +67,13 @@ class Context {
/// Callbacks called when an IR instruction is about to get removed. Keys are
/// used as IDs for deregistration.
- DenseMap<CallbackID, RemoveInstrCallback> RemoveInstrCallbacks;
+ MapVector<CallbackID, RemoveInstrCallback> RemoveInstrCallbacks;
/// Callbacks called when an IR instruction is about to get inserted. Keys are
/// used as IDs for deregistration.
- DenseMap<CallbackID, InsertInstrCallback> InsertInstrCallbacks;
+ MapVector<CallbackID, InsertInstrCallback> InsertInstrCallbacks;
/// Callbacks called when an IR instruction is about to get moved. Keys are
/// used as IDs for deregistration.
- DenseMap<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
+ MapVector<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
/// A counter used for assigning callback IDs during registration. The same
/// counter is used for all kinds of callbacks so we can detect mismatched
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 268c6e0712c505..5bad56b4064478 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -6001,12 +6001,22 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
Moved.push_back(std::make_pair(I, &*Where));
});
+ // Two more insertion callbacks, to check that they're called in registration
+ // order.
+ SmallVector<int> Order;
+ auto CheckOrderInsertCbId1 = Ctx.registerInsertInstrCallback(
+ [&Order](sandboxir::Instruction *I) { Order.push_back(1); });
+
+ auto CheckOrderInsertCbId2 = Ctx.registerInsertInstrCallback(
+ [&Order](sandboxir::Instruction *I) { Order.push_back(2); });
+
Ctx.save();
auto *NewI = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
Ret->getIterator(), Ctx);
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
EXPECT_THAT(Removed, testing::IsEmpty());
EXPECT_THAT(Moved, testing::IsEmpty());
+ EXPECT_THAT(Order, testing::ElementsAre(1, 2));
Ret->moveBefore(NewI);
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
@@ -6030,6 +6040,7 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI, NewI));
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI),
std::make_pair(Ret, nullptr)));
+ EXPECT_THAT(Order, testing::ElementsAre(1, 2, 1, 2, 1, 2));
// Check that deregistration works. Do an operation of each type after
// deregistering callbacks and check.
@@ -6039,6 +6050,8 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
Ctx.unregisterInsertInstrCallback(InsertCbId);
Ctx.unregisterRemoveInstrCallback(RemoveCbId);
Ctx.unregisterMoveInstrCallback(MoveCbId);
+ Ctx.unregisterInsertInstrCallback(CheckOrderInsertCbId1);
+ Ctx.unregisterInsertInstrCallback(CheckOrderInsertCbId2);
auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
Ret->getIterator(), Ctx);
Ret->moveBefore(NewI2);
More information about the llvm-commits
mailing list