[llvm] [SandboxIR] SetUse callback (PR #126985)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 12 15:11:38 PST 2025
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/126985
This patch implements a callback mechanism, similar to the existing instruction callbacks (erase/create/move), for getting notified whenever a Use edge is about to be updated. It will be used by the Dependency Graph in a follow-up patch.
>From b9828f73952b939612aac9cd9536f7a7f039604c Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 30 Jan 2025 13:15:25 -0800
Subject: [PATCH] [SandboxIR] SetUse callback
This patch implements a callback mechanism, similar to the existing
instruction callbacks (erase/create/move), for getting notified whenever a
Use edge is about to be updated. It will be used by the Dependency Graph in
a follow-up patch.
---
llvm/include/llvm/SandboxIR/Context.h | 15 ++++-
llvm/lib/SandboxIR/Context.cpp | 18 ++++++
llvm/lib/SandboxIR/User.cpp | 13 +++--
llvm/lib/SandboxIR/Value.cpp | 8 ++-
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 66 ++++++++++++++++++++++
5 files changed, 111 insertions(+), 9 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index a88b0003f55bd..714d1ec78f452 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -26,6 +26,7 @@ class BBIterator;
class Constant;
class Module;
class Value;
+class Use;
class Context {
public:
@@ -37,6 +38,8 @@ class Context {
// destination BB and an iterator pointing to the insertion position.
using MoveInstrCallback =
std::function<void(Instruction *, const BBIterator &)>;
+ // A SetUseCallback receives the Use that is about to get its source set and
+ // the Value that is about to become its new source.
+ using SetUseCallback = std::function<void(const Use &, Value *)>;
/// An ID for a registered callback. Used for deregistration. A dedicated type
/// is employed so as to keep IDs opaque to the end user; only Context should
@@ -98,6 +101,9 @@ class Context {
/// Callbacks called when an IR instruction is about to get moved. Keys are
/// used as IDs for deregistration.
MapVector<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
+ /// Callbacks called right before a Use gets its source set. Keys are used as
+ /// IDs for deregistration.
+ MapVector<CallbackID, SetUseCallback> SetUseCallbacks;
/// A counter used for assigning callback IDs during registration. The same
/// counter is used for all kinds of callbacks so we can detect mismatched
@@ -129,6 +135,10 @@ class Context {
void runEraseInstrCallbacks(Instruction *I);
void runCreateInstrCallbacks(Instruction *I);
void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where);
+ void runSetUseCallbacks(const Use &U, Value *NewSrc);
+
+ friend class User; // For runSetUseCallbacks().
+ friend class Value; // For runSetUseCallbacks().
// Friends for getOrCreateConstant().
#define DEF_CONST(ID, CLASS) friend class CLASS;
@@ -281,7 +291,10 @@ class Context {
CallbackID registerMoveInstrCallback(MoveInstrCallback CB);
void unregisterMoveInstrCallback(CallbackID ID);
- // TODO: Add callbacks for instructions inserted/removed if needed.
+ /// Register a callback that gets called right before a Use gets its source
+ /// set, e.g., by User::setOperand() or Value::replaceAllUsesWith(). The
+ /// callback receives the Use and the Value that will become its new source.
+ /// \Returns a callback ID for later deregistration.
+ CallbackID registerSetUseCallback(SetUseCallback CB);
+ /// Unregister the callback identified by \p ID, which must have been returned
+ /// by registerSetUseCallback() and not yet unregistered (asserts otherwise).
+ void unregisterSetUseCallback(CallbackID ID);
};
} // namespace sandboxir
diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index 830f2832853fe..6a397b02d6bde 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -687,6 +687,11 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
CBEntry.second(I, WhereIt);
}
+void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) {
+  // Hand the about-to-be-set Use and its new source to every registered
+  // SetUseCallback; the registration ID is not needed here.
+  for (const auto &[CBID, CB] : SetUseCallbacks)
+    CB(U, NewSrc);
+}
+
// An arbitrary limit, to check for accidental misuse. We expect a small number
// of callbacks to be registered at a time, but we can increase this number if
// we discover we needed more.
@@ -732,4 +737,17 @@ void Context::unregisterMoveInstrCallback(CallbackID ID) {
"Callback ID not found in MoveInstrCallbacks during deregistration");
}
+Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) {
+  assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks &&
+         "SetUseCallbacks size limit exceeded");
+  // IDs come from the counter shared by all callback kinds, so an ID handed
+  // out here can never collide with one from another registry.
+  CallbackID NewID{NextCallbackID++};
+  SetUseCallbacks[NewID] = CB;
+  return NewID;
+}
+void Context::unregisterSetUseCallback(CallbackID ID) {
+  // erase() returns the number of removed entries; zero means ID was never
+  // registered here (or was already unregistered).
+  [[maybe_unused]] auto NumErased = SetUseCallbacks.erase(ID);
+  assert(NumErased != 0 &&
+         "Callback ID not found in SetUseCallbacks during deregistration");
+}
+
} // namespace llvm::sandboxir
diff --git a/llvm/lib/SandboxIR/User.cpp b/llvm/lib/SandboxIR/User.cpp
index d7e4656e6e90e..43fd565e23836 100644
--- a/llvm/lib/SandboxIR/User.cpp
+++ b/llvm/lib/SandboxIR/User.cpp
@@ -90,17 +90,20 @@ bool User::classof(const Value *From) {
+// Set operand OperandIdx to Operand. The change is recorded with the tracker
+// (if tracking) and reported to the SetUse callbacks before it is applied.
void User::setOperand(unsigned OperandIdx, Value *Operand) {
assert(isa<llvm::User>(Val) && "No operands!");
- Ctx.getTracker().emplaceIfTracking<UseSet>(getOperandUse(OperandIdx));
+ const auto &U = getOperandUse(OperandIdx);
+ Ctx.getTracker().emplaceIfTracking<UseSet>(U);
+ // Run the callbacks before delegating below, so that listeners see the Use
+ // while the operand has not been changed yet.
+ Ctx.runSetUseCallbacks(U, Operand);
// We are delegating to llvm::User::setOperand().
cast<llvm::User>(Val)->setOperand(OperandIdx, Operand->Val);
}
bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
auto &Tracker = Ctx.getTracker();
- if (Tracker.isTracking()) {
- for (auto OpIdx : seq<unsigned>(0, getNumOperands())) {
- auto Use = getOperandUse(OpIdx);
- if (Use.get() == FromV)
+ for (auto OpIdx : seq<unsigned>(0, getNumOperands())) {
+ auto Use = getOperandUse(OpIdx);
+ if (Use.get() == FromV) {
+ Ctx.runSetUseCallbacks(Use, ToV);
+ if (Tracker.isTracking())
Tracker.emplaceIfTracking<UseSet>(Use);
}
}
diff --git a/llvm/lib/SandboxIR/Value.cpp b/llvm/lib/SandboxIR/Value.cpp
index b9d91c7e11f74..e39bbc44bca00 100644
--- a/llvm/lib/SandboxIR/Value.cpp
+++ b/llvm/lib/SandboxIR/Value.cpp
@@ -51,7 +51,7 @@ void Value::replaceUsesWithIf(
llvm::Value *OtherVal = OtherV->Val;
// We are delegating RUWIf to LLVM IR's RUWIf.
Val->replaceUsesWithIf(
- OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
+ OtherVal, [&ShouldReplace, this, OtherV](llvm::Use &LLVMUse) -> bool {
User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
if (DstU == nullptr)
return false;
@@ -59,6 +59,7 @@ void Value::replaceUsesWithIf(
if (!ShouldReplace(UseToReplace))
return false;
Ctx.getTracker().emplaceIfTracking<UseSet>(UseToReplace);
+ Ctx.runSetUseCallbacks(UseToReplace, OtherV);
return true;
});
}
@@ -67,8 +68,9 @@ void Value::replaceAllUsesWith(Value *Other) {
assert(getType() == Other->getType() &&
"Replacing with Value of different type!");
auto &Tracker = Ctx.getTracker();
- if (Tracker.isTracking()) {
- for (auto Use : uses())
+ for (auto Use : uses()) {
+ Ctx.runSetUseCallbacks(Use, Other);
+ if (Tracker.isTracking())
Tracker.track(std::make_unique<UseSet>(Use));
}
// We are delegating RAUW to LLVM IR's RAUW.
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 9eeac9b60372f..2ad33659c609b 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -6081,6 +6081,72 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
EXPECT_THAT(Moved, testing::IsEmpty());
}
+// Check callbacks when we set a Use.
+TEST_F(SandboxIRTest, SetUseCallbacks) {
+  parseIR(C, R"IR(
+define void @foo(i8 %v0, i8 %v1) {
+  %add0 = add i8 %v0, %v1
+  %add1 = add i8 %add0, %v1
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *Arg0 = F->getArg(0);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
+
+  // One entry per callback invocation. OldSrc is read inside the callback,
+  // which lets us check that the callback runs *before* the Use is updated.
+  struct SetUseEntry {
+    sandboxir::Use U;
+    sandboxir::Value *OldSrc;
+    sandboxir::Value *NewSrc;
+  };
+  SmallVector<SetUseEntry> UsesSet;
+  auto Id = Ctx.registerSetUseCallback(
+      [&UsesSet](const sandboxir::Use &U, sandboxir::Value *NewSrc) {
+        UsesSet.push_back({U, U.get(), NewSrc});
+      });
+
+  // Expect exactly one reported Use: operand 0 of %add1, going from %add0 to
+  // %v0. Checking the user and operand number (not just U.get()) rules out a
+  // different Use that also points to %v0, e.g., operand 0 of %add0.
+  auto ExpectAdd1Op0Set = [&]() {
+    ASSERT_EQ(UsesSet.size(), 1u);
+    EXPECT_EQ(UsesSet[0].U.getUser(), Add1);
+    EXPECT_EQ(UsesSet[0].U.getOperandNo(), 0u);
+    EXPECT_EQ(UsesSet[0].OldSrc, Add0);
+    EXPECT_EQ(UsesSet[0].NewSrc, Arg0);
+  };
+
+  // Now change %add1 operand to not use %add0.
+  Add1->setOperand(0, Arg0);
+  ExpectAdd1Op0Set();
+  // Restore to previous state.
+  Add1->setOperand(0, Add0);
+  UsesSet.clear();
+
+  // RAUW
+  Add0->replaceAllUsesWith(Arg0);
+  ExpectAdd1Op0Set();
+  // Restore to previous state.
+  Add1->setOperand(0, Add0);
+  UsesSet.clear();
+
+  // RUWIf
+  Add0->replaceUsesWithIf(Arg0, [](const auto &) { return true; });
+  ExpectAdd1Op0Set();
+  // Restore to previous state.
+  Add1->setOperand(0, Add0);
+  UsesSet.clear();
+
+  // RUOW
+  Add1->replaceUsesOfWith(Add0, Arg0);
+  ExpectAdd1Op0Set();
+  // Restore to previous state.
+  Add1->setOperand(0, Add0);
+  UsesSet.clear();
+
+  // Check unregister.
+  Ctx.unregisterSetUseCallback(Id);
+  Add0->replaceAllUsesWith(Arg0);
+  EXPECT_TRUE(UsesSet.empty());
+}
+
TEST_F(SandboxIRTest, FunctionObjectAlreadyExists) {
parseIR(C, R"IR(
define void @foo() {
More information about the llvm-commits
mailing list