[llvm] [SandboxIR] SetUse callback (PR #126985)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 13 12:48:37 PST 2025


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/126985

>From b9828f73952b939612aac9cd9536f7a7f039604c Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 30 Jan 2025 13:15:25 -0800
Subject: [PATCH] [SandboxIR] SetUse callback

This patch implements a callback mechanism similar to the existing ones, but
for getting notified whenever a Use edge gets updated. This is going to be
used in a follow-up patch by the Dependency Graph.
---
 llvm/include/llvm/SandboxIR/Context.h      | 15 ++++-
 llvm/lib/SandboxIR/Context.cpp             | 18 ++++++
 llvm/lib/SandboxIR/User.cpp                | 13 +++--
 llvm/lib/SandboxIR/Value.cpp               |  8 ++-
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 66 ++++++++++++++++++++++
 5 files changed, 111 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h
index a88b0003f55bd..714d1ec78f452 100644
--- a/llvm/include/llvm/SandboxIR/Context.h
+++ b/llvm/include/llvm/SandboxIR/Context.h
@@ -26,6 +26,7 @@ class BBIterator;
 class Constant;
 class Module;
 class Value;
+class Use;
 
 class Context {
 public:
@@ -37,6 +38,8 @@ class Context {
   // destination BB and an iterator pointing to the insertion position.
   using MoveInstrCallback =
       std::function<void(Instruction *, const BBIterator &)>;
+  // A SetUseCallback receives the Use about to be set and its new source.
+  using SetUseCallback = std::function<void(const Use &, Value *)>;
 
   /// An ID for a registered callback. Used for deregistration. A dedicated type
   /// is employed so as to keep IDs opaque to the end user; only Context should
@@ -98,6 +101,9 @@ class Context {
   /// Callbacks called when an IR instruction is about to get moved. Keys are
   /// used as IDs for deregistration.
   MapVector<CallbackID, MoveInstrCallback> MoveInstrCallbacks;
+  /// Callbacks called when a Use is about to get its source set. Keys are
+  /// used as IDs for deregistration.
+  MapVector<CallbackID, SetUseCallback> SetUseCallbacks;
 
   /// A counter used for assigning callback IDs during registration. The same
   /// counter is used for all kinds of callbacks so we can detect mismatched
@@ -129,6 +135,10 @@ class Context {
   void runEraseInstrCallbacks(Instruction *I);
   void runCreateInstrCallbacks(Instruction *I);
   void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where);
+  void runSetUseCallbacks(const Use &U, Value *NewSrc);
+
+  friend class User;  // For runSetUseCallbacks().
+  friend class Value; // For runSetUseCallbacks().
 
   // Friends for getOrCreateConstant().
 #define DEF_CONST(ID, CLASS) friend class CLASS;
@@ -281,7 +291,10 @@ class Context {
   CallbackID registerMoveInstrCallback(MoveInstrCallback CB);
   void unregisterMoveInstrCallback(CallbackID ID);
 
-  // TODO: Add callbacks for instructions inserted/removed if needed.
+  /// Register a callback that gets called when a Use is about to get set.
+  /// \Returns a callback ID for later deregistration.
+  CallbackID registerSetUseCallback(SetUseCallback CB);
+  void unregisterSetUseCallback(CallbackID ID);
 };
 
 } // namespace sandboxir
diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp
index 830f2832853fe..6a397b02d6bde 100644
--- a/llvm/lib/SandboxIR/Context.cpp
+++ b/llvm/lib/SandboxIR/Context.cpp
@@ -687,6 +687,11 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
     CBEntry.second(I, WhereIt);
 }
 
+void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) {
+  for (auto &CBEntry : SetUseCallbacks)
+    CBEntry.second(U, NewSrc);
+}
+
 // An arbitrary limit, to check for accidental misuse. We expect a small number
 // of callbacks to be registered at a time, but we can increase this number if
 // we discover we needed more.
@@ -732,4 +737,17 @@ void Context::unregisterMoveInstrCallback(CallbackID ID) {
          "Callback ID not found in MoveInstrCallbacks during deregistration");
 }
 
+Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) {
+  assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks &&
+         "SetUseCallbacks size limit exceeded");
+  CallbackID ID{NextCallbackID++};
+  SetUseCallbacks[ID] = CB;
+  return ID;
+}
+void Context::unregisterSetUseCallback(CallbackID ID) {
+  [[maybe_unused]] bool Erased = SetUseCallbacks.erase(ID);
+  assert(Erased &&
+         "Callback ID not found in SetUseCallbacks during deregistration");
+}
+
 } // namespace llvm::sandboxir
diff --git a/llvm/lib/SandboxIR/User.cpp b/llvm/lib/SandboxIR/User.cpp
index d7e4656e6e90e..43fd565e23836 100644
--- a/llvm/lib/SandboxIR/User.cpp
+++ b/llvm/lib/SandboxIR/User.cpp
@@ -90,17 +90,20 @@ bool User::classof(const Value *From) {
 
 void User::setOperand(unsigned OperandIdx, Value *Operand) {
   assert(isa<llvm::User>(Val) && "No operands!");
-  Ctx.getTracker().emplaceIfTracking<UseSet>(getOperandUse(OperandIdx));
+  const auto &U = getOperandUse(OperandIdx);
+  Ctx.getTracker().emplaceIfTracking<UseSet>(U);
+  Ctx.runSetUseCallbacks(U, Operand);
   // We are delegating to llvm::User::setOperand().
   cast<llvm::User>(Val)->setOperand(OperandIdx, Operand->Val);
 }
 
 bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
   auto &Tracker = Ctx.getTracker();
-  if (Tracker.isTracking()) {
-    for (auto OpIdx : seq<unsigned>(0, getNumOperands())) {
-      auto Use = getOperandUse(OpIdx);
-      if (Use.get() == FromV)
+  for (auto OpIdx : seq<unsigned>(0, getNumOperands())) {
+    auto Use = getOperandUse(OpIdx);
+    if (Use.get() == FromV) {
+      Ctx.runSetUseCallbacks(Use, ToV);
+      if (Tracker.isTracking())
         Tracker.emplaceIfTracking<UseSet>(Use);
     }
   }
diff --git a/llvm/lib/SandboxIR/Value.cpp b/llvm/lib/SandboxIR/Value.cpp
index b9d91c7e11f74..e39bbc44bca00 100644
--- a/llvm/lib/SandboxIR/Value.cpp
+++ b/llvm/lib/SandboxIR/Value.cpp
@@ -51,7 +51,7 @@ void Value::replaceUsesWithIf(
   llvm::Value *OtherVal = OtherV->Val;
   // We are delegating RUWIf to LLVM IR's RUWIf.
   Val->replaceUsesWithIf(
-      OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
+      OtherVal, [&ShouldReplace, this, OtherV](llvm::Use &LLVMUse) -> bool {
         User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
         if (DstU == nullptr)
           return false;
@@ -59,6 +59,7 @@ void Value::replaceUsesWithIf(
         if (!ShouldReplace(UseToReplace))
           return false;
         Ctx.getTracker().emplaceIfTracking<UseSet>(UseToReplace);
+        Ctx.runSetUseCallbacks(UseToReplace, OtherV);
         return true;
       });
 }
@@ -67,8 +68,9 @@ void Value::replaceAllUsesWith(Value *Other) {
   assert(getType() == Other->getType() &&
          "Replacing with Value of different type!");
   auto &Tracker = Ctx.getTracker();
-  if (Tracker.isTracking()) {
-    for (auto Use : uses())
+  for (auto Use : uses()) {
+    Ctx.runSetUseCallbacks(Use, Other);
+    if (Tracker.isTracking())
       Tracker.track(std::make_unique<UseSet>(Use));
   }
   // We are delegating RAUW to LLVM IR's RAUW.
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 9eeac9b60372f..2ad33659c609b 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -6081,6 +6081,72 @@ TEST_F(SandboxIRTest, InstructionCallbacks) {
   EXPECT_THAT(Moved, testing::IsEmpty());
 }
 
+// Check callbacks when we set a Use.
+TEST_F(SandboxIRTest, SetUseCallbacks) {
+  parseIR(C, R"IR(
+define void @foo(i8 %v0, i8 %v1) {
+  %add0 = add i8 %v0, %v1
+  %add1 = add i8 %add0, %v1
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *Arg0 = F->getArg(0);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
+
+  SmallVector<std::pair<sandboxir::Use, sandboxir::Value *>> UsesSet;
+  auto Id = Ctx.registerSetUseCallback(
+      [&UsesSet](sandboxir::Use U, sandboxir::Value *NewSrc) {
+        UsesSet.push_back({U, NewSrc});
+      });
+
+  // Now change %add1 operand to not use %add0.
+  Add1->setOperand(0, Arg0);
+  EXPECT_EQ(UsesSet.size(), 1u);
+  EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
+  EXPECT_EQ(UsesSet[0].second, Arg0);
+  // Restore to previous state.
+  Add1->setOperand(0, Add0);
+  UsesSet.clear();
+
+  // RAUW
+  Add0->replaceAllUsesWith(Arg0);
+  EXPECT_EQ(UsesSet.size(), 1u);
+  EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
+  EXPECT_EQ(UsesSet[0].second, Arg0);
+  // Restore to previous state.
+  Add1->setOperand(0, Add0);
+  UsesSet.clear();
+
+  // RUWIf
+  Add0->replaceUsesWithIf(Arg0, [](const auto &U) { return true; });
+  EXPECT_EQ(UsesSet.size(), 1u);
+  EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
+  EXPECT_EQ(UsesSet[0].second, Arg0);
+  // Restore to previous state.
+  Add1->setOperand(0, Add0);
+  UsesSet.clear();
+
+  // RUOW
+  Add1->replaceUsesOfWith(Add0, Arg0);
+  EXPECT_EQ(UsesSet.size(), 1u);
+  EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get());
+  EXPECT_EQ(UsesSet[0].second, Arg0);
+  // Restore to previous state.
+  Add1->setOperand(0, Add0);
+  UsesSet.clear();
+
+  // Check unregister.
+  Ctx.unregisterSetUseCallback(Id);
+  Add0->replaceAllUsesWith(Arg0);
+  EXPECT_TRUE(UsesSet.empty());
+}
+
 TEST_F(SandboxIRTest, FunctionObjectAlreadyExists) {
   parseIR(C, R"IR(
 define void @foo() {



More information about the llvm-commits mailing list