[llvm] [SandboxIR] IR Tracker (PR #99238)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 16 14:31:33 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/99238

>From 5e8c1c420f6aabfeed6262814680eb045e53cc2e Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 12 Jul 2024 10:24:55 -0700
Subject: [PATCH] [SandboxIR] IR Tracker

This is the first patch in a series of patches for the IR change tracking
component of SandboxIR.
The tracker collects changes in a vector of `IRChangeBase` objects and
provides a `save()`/`accept()`/`revert()` API.

Each type of IR changing event is captured by a dedicated subclass of
`IRChangeBase`. This patch implements only one of them, `UseSet`, which
records updates to the source value of a `sandboxir::Use`.
---
 llvm/docs/SandboxIR.md                        |  11 ++
 llvm/include/llvm/SandboxIR/SandboxIR.h       |  12 ++
 .../include/llvm/SandboxIR/SandboxIRTracker.h | 181 ++++++++++++++++++
 llvm/include/llvm/SandboxIR/Use.h             |   1 +
 llvm/lib/SandboxIR/CMakeLists.txt             |   1 +
 llvm/lib/SandboxIR/SandboxIR.cpp              |  26 ++-
 llvm/lib/SandboxIR/SandboxIRTracker.cpp       |  84 ++++++++
 llvm/unittests/SandboxIR/CMakeLists.txt       |   1 +
 .../SandboxIR/SandboxIRTrackerTest.cpp        | 154 +++++++++++++++
 9 files changed, 470 insertions(+), 1 deletion(-)
 create mode 100644 llvm/include/llvm/SandboxIR/SandboxIRTracker.h
 create mode 100644 llvm/lib/SandboxIR/SandboxIRTracker.cpp
 create mode 100644 llvm/unittests/SandboxIR/SandboxIRTrackerTest.cpp

diff --git a/llvm/docs/SandboxIR.md b/llvm/docs/SandboxIR.md
index 8f8752f102c76..29f5e5ea9346f 100644
--- a/llvm/docs/SandboxIR.md
+++ b/llvm/docs/SandboxIR.md
@@ -51,3 +51,14 @@ For example, for `sandboxir::User::setOperand(OpIdx, sandboxir::Value *Op)`:
 - We get the corresponding LLVM User: `llvm::User *LLVMU = cast<llvm::User>(Val)`
 - Next we get the corresponding LLVM Operand: `llvm::Value *LLVMOp = Op->Val`
 - Finally we modify `LLVMU`'s operand: `LLVMU->setOperand(OpIdx, LLVMOp)
+
+## IR Change Tracking
+Sandbox IR's state can be saved and restored.
+This is done with the help of the tracker component that is tightly coupled to the public Sandbox IR API functions.
+
+To save the state and enable tracking, the user needs to call `sandboxir::Context::save()`.
+From this point on, any change made to the Sandbox IR state will automatically create a change object and register it with the tracker, without any intervention from the user.
+The changes are accumulated in a vector within the tracker.
+
+To roll back to the saved state, the user needs to call `sandboxir::Context::revert()`; to keep the changes instead, call `sandboxir::Context::accept()`.
+Reverting to the saved state is a matter of going over all the accumulated changes in reverse order and undoing each individual change.
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index fcb581211736e..2e2d5668f2a2c 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -61,6 +61,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/User.h"
 #include "llvm/IR/Value.h"
+#include "llvm/SandboxIR/SandboxIRTracker.h"
 #include "llvm/SandboxIR/Use.h"
 #include "llvm/Support/raw_ostream.h"
 #include <iterator>
@@ -167,6 +168,7 @@ class Value {
 
   friend class Context; // For getting `Val`.
   friend class User;    // For getting `Val`.
+  friend class Use;     // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -630,6 +632,8 @@ class BasicBlock : public Value {
 class Context {
 protected:
   LLVMContext &LLVMCtx;
+  SandboxIRTracker IRTracker;
+
   /// Maps LLVM Value to the corresponding sandboxir::Value. Owns all
   /// SandboxIR objects.
   DenseMap<llvm::Value *, std::unique_ptr<sandboxir::Value>>
@@ -667,6 +671,14 @@ class Context {
 public:
   Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx) {}
 
+  SandboxIRTracker &getTracker() { return IRTracker; } // Owned by Context.
+  /// Starts tracking changes. Convenience function for `getTracker().save()`
+  void save() { IRTracker.save(); }
+  /// Undoes tracked changes. Convenience function for `getTracker().revert()`
+  void revert() { IRTracker.revert(); }
+  /// Keeps tracked changes. Convenience function for `getTracker().accept()`
+  void accept() { IRTracker.accept(); }
+
   sandboxir::Value *getValue(llvm::Value *V) const;
   const sandboxir::Value *getValue(const llvm::Value *V) const {
     return getValue(const_cast<llvm::Value *>(V));
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRTracker.h b/llvm/include/llvm/SandboxIR/SandboxIRTracker.h
new file mode 100644
index 0000000000000..8a819c578a156
--- /dev/null
+++ b/llvm/include/llvm/SandboxIR/SandboxIRTracker.h
@@ -0,0 +1,181 @@
+//===- SandboxIRTracker.h ---------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file is the component of SandboxIR that tracks all changes made to its
+// state, such that we can revert the state when needed.
+//
+// Tracking changes
+// ----------------
+// The user needs to call `SandboxIRTracker::save()` to enable tracking changes
+// made to SandboxIR. From that point on, any change made to SandboxIR will
+// automatically create a change tracking object and register it with the
+// tracker. IR-change objects are subclasses of `IRChangeBase` and get
+// registered with the `SandboxIRTracker::track()` function. The change objects
+// are saved in the order they are registered with the tracker and are stored in
+// the `SandboxIRTracker::Changes` vector. All of this is done transparently to
+// the user.
+//
+// Reverting changes
+// -----------------
+// Calling `SandboxIRTracker::revert()` will restore the state saved when
+// `SandboxIRTracker::save()` was called. Internally this goes through the
+// change objects in `SandboxIRTracker::Changes` in reverse order, calling their
+// `IRChangeBase::revert()` function one by one.
+//
+// Accepting changes
+// -----------------
+// The user needs to either revert or accept changes before the tracker object
+// is destroyed, otherwise the tracker destructor will trigger an assertion.
+// This is the job of `SandboxIRTracker::accept()`. Internally this will go
+// through the change objects in `SandboxIRTracker::Changes` in order, calling
+// `IRChangeBase::accept()`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SANDBOXIR_SANDBOXIRTRACKER_H
+#define LLVM_SANDBOXIR_SANDBOXIRTRACKER_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Module.h"
+#include "llvm/SandboxIR/Use.h"
+#include "llvm/Support/Debug.h"
+#include <memory>
+#include <regex>
+
+namespace llvm::sandboxir {
+
+class BasicBlock;
+
+/// Each IR change type has an ID.
+enum class TrackID {
+  UseSet, ///< Change of the source Value of a sandboxir::Use.
+};
+#ifndef NDEBUG
+/// \Returns the name of \p ID. Not `static`: avoids unused-function warnings.
+inline const char *trackIDToStr(TrackID ID) {
+  switch (ID) {
+  case TrackID::UseSet:
+    return "UseSet";
+  }
+  llvm_unreachable("Unimplemented ID");
+}
+#endif // NDEBUG
+
+class SandboxIRTracker;
+
+/// Base class for IR change objects; one subclass per kind of IR change.
+class IRChangeBase {
+protected:
+#ifndef NDEBUG
+  unsigned Idx = 0; ///< Index of this change in the tracker (debug only).
+#endif
+  const TrackID ID; ///< The kind of change; used by classof() in subclasses.
+  SandboxIRTracker &Parent; ///< The tracker this change was created for.
+
+public:
+  IRChangeBase(TrackID ID, SandboxIRTracker &Parent);
+  TrackID getTrackID() const { return ID; }
+  /// Undoes this change. Called by SandboxIRTracker::revert().
+  virtual void revert() = 0;
+  /// Called by SandboxIRTracker::accept() when the change is kept.
+  virtual void accept() = 0;
+  virtual ~IRChangeBase() = default;
+#ifndef NDEBUG
+  void dumpCommon(raw_ostream &OS) const { // Prints "<Idx>. <TrackID name>".
+    OS << Idx << ". " << trackIDToStr(ID);
+  }
+  virtual void dump(raw_ostream &OS) const = 0;
+  LLVM_DUMP_METHOD virtual void dump() const = 0;
+#endif
+};
+
+/// Records a change of the source Value of a sandboxir::Use.
+class UseSet : public IRChangeBase {
+  Use U;                  ///< The use whose source Value changes.
+  Value *OrigV = nullptr; ///< U's source Value before the change.
+
+public:
+  UseSet(const Use &U, SandboxIRTracker &Tracker) // Create before changing U.
+      : IRChangeBase(TrackID::UseSet, Tracker), U(U), OrigV(U.get()) {}
+  // For isa<> etc.
+  static bool classof(const IRChangeBase *Other) {
+    return Other->getTrackID() == TrackID::UseSet;
+  }
+  void revert() final { U.set(OrigV); } // Point U back to its original Value.
+  void accept() final {}                // Nothing to do.
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const final { dumpCommon(OS); }
+  LLVM_DUMP_METHOD void dump() const final;
+  friend raw_ostream &operator<<(raw_ostream &OS, const UseSet &C) {
+    C.dump(OS);
+    return OS;
+  }
+#endif
+};
+
+/// The tracker collects all the change objects and implements the main API for
+/// saving / reverting / accepting.
+class SandboxIRTracker {
+public:
+  enum class TrackerState {
+    Disabled, ///< Tracking is disabled
+    Record,   ///< Tracking changes
+    Revert,   ///< Undoing changes
+    Accept,   ///< Accepting changes
+  };
+
+private:
+  /// The list of changes that are being tracked.
+  SmallVector<std::unique_ptr<IRChangeBase>> Changes;
+  /// The current state of the tracker.
+  TrackerState State = TrackerState::Disabled;
+
+public:
+#ifndef NDEBUG
+  /// Helps catch bugs where we are creating new change objects while in the
+  /// middle of creating other change objects.
+  bool InMiddleOfCreatingChange = false;
+#endif // NDEBUG
+
+  SandboxIRTracker() = default;
+  ~SandboxIRTracker(); // Asserts that all changes were accepted or reverted.
+  /// Record \p Change and take ownership. This is the main function used to
+  /// track Sandbox IR changes.
+  void track(std::unique_ptr<IRChangeBase> &&Change);
+  /// \Returns true if the tracker is recording changes.
+  bool tracking() const { return State == TrackerState::Record; }
+  /// \Returns the current state of the tracker.
+  TrackerState getState() const { return State; }
+  /// Turns on IR tracking (switches to the Record state).
+  void save();
+  /// Accepts and clears all recorded changes; does not stop tracking.
+  void accept();
+  /// Undoes (newest first) and clears all changes; does not stop tracking.
+  void revert();
+  /// \Returns the number of change entries recorded so far.
+  unsigned size() const { return Changes.size(); }
+  /// \Returns true if there are no change entries recorded so far.
+  bool empty() const { return Changes.empty(); }
+
+#ifndef NDEBUG
+  /// \Returns the \p Idx'th change. This is used for testing.
+  IRChangeBase *getChange(unsigned Idx) const { return Changes[Idx].get(); }
+  void dump(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+  friend raw_ostream &operator<<(raw_ostream &OS, const SandboxIRTracker &C) {
+    C.dump(OS);
+    return OS;
+  }
+#endif // NDEBUG
+};
+
+} // namespace llvm::sandboxir
+
+#endif // LLVM_SANDBOXIR_SANDBOXIRTRACKER_H
diff --git a/llvm/include/llvm/SandboxIR/Use.h b/llvm/include/llvm/SandboxIR/Use.h
index 33afb54c1ff29..d77b4568d0fab 100644
--- a/llvm/include/llvm/SandboxIR/Use.h
+++ b/llvm/include/llvm/SandboxIR/Use.h
@@ -44,6 +44,7 @@ class Use {
 public:
   operator Value *() const { return get(); }
   Value *get() const;
+  void set(Value *V);
   class User *getUser() const { return Usr; }
   unsigned getOperandNo() const;
   Context *getContext() const { return Ctx; }
diff --git a/llvm/lib/SandboxIR/CMakeLists.txt b/llvm/lib/SandboxIR/CMakeLists.txt
index 225eca0cadd1a..74b31fe869aed 100644
--- a/llvm/lib/SandboxIR/CMakeLists.txt
+++ b/llvm/lib/SandboxIR/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_component_library(LLVMSandboxIR
   SandboxIR.cpp
+  SandboxIRTracker.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/SandboxIR
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index a3f350e9ca8b0..a9f564c6591b6 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -16,6 +16,8 @@ using namespace llvm::sandboxir;
 
 Value *Use::get() const { return Ctx->getValue(LLVMUse->get()); }
 
+void Use::set(Value *V) { LLVMUse->set(V->Val); } // Updates the LLVM use.
+
 unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); }
 
 #ifndef NDEBUG
@@ -112,13 +114,24 @@ void Value::replaceUsesWithIf(
         User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
         if (DstU == nullptr)
           return false;
-        return ShouldReplace(Use(&LLVMUse, DstU, Ctx));
+        Use UseToReplace(&LLVMUse, DstU, Ctx);
+        if (!ShouldReplace(UseToReplace))
+          return false;
+        auto &Tracker = Ctx.getTracker();
+        if (Tracker.tracking())
+          Tracker.track(std::make_unique<UseSet>(UseToReplace, Tracker));
+        return true;
       });
 }
 
 void Value::replaceAllUsesWith(Value *Other) {
   assert(getType() == Other->getType() &&
          "Replacing with Value of different type!");
+  SandboxIRTracker &T = Ctx.getTracker();
+  if (T.tracking()) { // Record every use before they all get redirected.
+    for (auto U : uses())
+      T.track(std::make_unique<UseSet>(U, T));
+  }
   Val->replaceAllUsesWith(Other->Val);
 }
 
@@ -208,10 +221,21 @@ bool User::classof(const Value *From) {
 
 void User::setOperand(unsigned OperandIdx, Value *Operand) {
   assert(isa<llvm::User>(Val) && "No operands!");
+  SandboxIRTracker &T = Ctx.getTracker();
+  if (T.tracking()) // Save the old operand so revert() can restore it.
+    T.track(std::make_unique<UseSet>(getOperandUse(OperandIdx), T));
   cast<llvm::User>(Val)->setOperand(OperandIdx, Operand->Val);
 }
 
 bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
+  auto &Tracker = Ctx.getTracker();
+  if (Tracker.tracking() && FromV != ToV) { // Replacing X with X is a no-op.
+    for (auto OpIdx : seq<unsigned>(0, getNumOperands())) {
+      auto U = getOperandUse(OpIdx);
+      if (U.get() == FromV) // Record each operand that is about to change.
+        Tracker.track(std::make_unique<UseSet>(U, Tracker));
+    }
+  }
   return cast<llvm::User>(Val)->replaceUsesOfWith(FromV->Val, ToV->Val);
 }
 
diff --git a/llvm/lib/SandboxIR/SandboxIRTracker.cpp b/llvm/lib/SandboxIR/SandboxIRTracker.cpp
new file mode 100644
index 0000000000000..0b62df46c020c
--- /dev/null
+++ b/llvm/lib/SandboxIR/SandboxIRTracker.cpp
@@ -0,0 +1,84 @@
+//===- SandboxIRTracker.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/SandboxIR/SandboxIRTracker.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+#include <sstream>
+
+using namespace llvm::sandboxir;
+
+IRChangeBase::IRChangeBase(TrackID ID, SandboxIRTracker &Parent)
+    : ID(ID), Parent(Parent) {
+#ifndef NDEBUG
+  Idx = Parent.size(); // Position this change will take in the tracker.
+  // Creating a change while another one is still being created is a bug.
+  assert(!Parent.InMiddleOfCreatingChange &&
+         "We are in the middle of creating another change!");
+  // The flag was just checked to be false; raise it only while tracking.
+  Parent.InMiddleOfCreatingChange = Parent.tracking();
+#endif // NDEBUG
+}
+
+#ifndef NDEBUG
+void UseSet::dump() const { // Prints this change and a newline to dbgs().
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
+SandboxIRTracker::~SandboxIRTracker() {
+  assert(Changes.empty() && "You must accept or revert changes!");
+}
+
+void SandboxIRTracker::track(std::unique_ptr<IRChangeBase> &&Change) {
+  // Reverting must never create new change objects.
+  assert(State != TrackerState::Revert &&
+         "No changes should be tracked during revert()!");
+  // Take ownership; changes are kept in registration order.
+  Changes.push_back(std::move(Change));
+#ifndef NDEBUG
+  // The change is now fully created and registered.
+  InMiddleOfCreatingChange = false;
+#endif
+}
+
+void SandboxIRTracker::save() { State = TrackerState::Record; }
+
+void SandboxIRTracker::revert() {
+  // Undo newest-first, restoring each intermediate state in turn.
+  TrackerState PrevState = std::exchange(State, TrackerState::Revert);
+  for (auto &Change : reverse(Changes))
+    Change->revert();
+  Changes.clear();
+  State = PrevState; // Restore the state we had before reverting.
+}
+
+void SandboxIRTracker::accept() {
+  // Accept changes in the order they were recorded.
+  TrackerState PrevState = std::exchange(State, TrackerState::Accept);
+  for (auto &Change : Changes)
+    Change->accept();
+  Changes.clear();
+  State = PrevState; // Restore the state we had before accepting.
+}
+
+#ifndef NDEBUG
+void SandboxIRTracker::dump(raw_ostream &OS) const {
+  for (const auto &ChangePtr : Changes) { // In registration order.
+    ChangePtr->dump(OS);
+    OS << "\n"; // One change per line.
+  }
+}
+void SandboxIRTracker::dump() const { // Prints all changes to dbgs().
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
diff --git a/llvm/unittests/SandboxIR/CMakeLists.txt b/llvm/unittests/SandboxIR/CMakeLists.txt
index 362653bfff965..1bb1a6efbef30 100644
--- a/llvm/unittests/SandboxIR/CMakeLists.txt
+++ b/llvm/unittests/SandboxIR/CMakeLists.txt
@@ -6,4 +6,5 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_unittest(SandboxIRTests
   SandboxIRTest.cpp
+  SandboxIRTrackerTest.cpp
   )
diff --git a/llvm/unittests/SandboxIR/SandboxIRTrackerTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTrackerTest.cpp
new file mode 100644
index 0000000000000..380d5d9ac1fd8
--- /dev/null
+++ b/llvm/unittests/SandboxIR/SandboxIRTrackerTest.cpp
@@ -0,0 +1,154 @@
+//===- SandboxIRTrackerTest.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Module.h"
+#include "llvm/SandboxIR/SandboxIR.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+struct SandboxIRTrackerTest : public testing::Test {
+  LLVMContext C;
+  std::unique_ptr<Module> M; // Module parsed by parseIR().
+
+  void parseIR(LLVMContext &C, const char *IR) {
+    SMDiagnostic Err;
+    M = parseAssemblyString(IR, Err, C);
+    if (!M) // Print the parse error; M stays null.
+      Err.print("SandboxIRTrackerTest", errs());
+  }
+  BasicBlock *getBasicBlockByName(Function &F, StringRef Name) {
+    for (BasicBlock &BB : F) // NOTE(review): helper unused in this file.
+      if (BB.getName() == Name)
+        return &BB;
+    llvm_unreachable("Expected to find basic block!");
+  }
+};
+
+TEST_F(SandboxIRTrackerTest, SetOperand) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+  %gep0 = getelementptr float, ptr %ptr, i32 0
+  %gep1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %gep0
+  store float undef, ptr %gep0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(&LLVMF);
+  auto *BB = &*F->begin();
+  auto &Tracker = Ctx.getTracker();
+  Tracker.save(); // Start recording changes.
+  auto It = BB->begin();
+  auto *Gep0 = &*It++;
+  auto *Gep1 = &*It++;
+  auto *Ld = &*It++;
+  auto *St = &*It++;
+  St->setOperand(0, Ld); // Each setOperand() records one UseSet.
+  EXPECT_EQ(Tracker.size(), 1u);
+  St->setOperand(1, Gep1);
+  EXPECT_EQ(Tracker.size(), 2u);
+  Ld->setOperand(0, Gep1);
+  EXPECT_EQ(Tracker.size(), 3u);
+  EXPECT_EQ(St->getOperand(0), Ld);
+  EXPECT_EQ(St->getOperand(1), Gep1);
+  EXPECT_EQ(Ld->getOperand(0), Gep1);
+
+  Ctx.getTracker().revert(); // Undo all three setOperand() calls.
+  EXPECT_NE(St->getOperand(0), Ld); // Operand 0 was originally undef.
+  EXPECT_EQ(St->getOperand(1), Gep0);
+  EXPECT_EQ(Ld->getOperand(0), Gep0);
+}
+
+TEST_F(SandboxIRTrackerTest, RUWIf_RAUW_RUOW) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+  %ld0 = load float, ptr %ptr
+  %ld1 = load float, ptr %ptr
+  store float %ld0, ptr %ptr
+  store float %ld0, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  llvm::BasicBlock *LLVMBB = &*LLVMF.begin();
+  auto &Tracker = Ctx.getTracker();
+  Ctx.createFunction(&LLVMF);
+  auto *BB = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB));
+  auto It = BB->begin();
+  sandboxir::Instruction *Ld0 = &*It++;
+  sandboxir::Instruction *Ld1 = &*It++;
+  sandboxir::Instruction *St0 = &*It++;
+  sandboxir::Instruction *St1 = &*It++;
+  Ctx.save();
+  // Check RUWIf when the lambda returns false.
+  Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return false; });
+  EXPECT_TRUE(Tracker.empty());
+
+  // Check RUWIf when the lambda returns true.
+  Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return true; });
+  EXPECT_EQ(Tracker.size(), 2u);
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+  EXPECT_EQ(St1->getOperand(0), Ld1);
+  Ctx.revert();
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+  EXPECT_EQ(St1->getOperand(0), Ld0);
+
+  // Check RUWIf user == St0.
+  Ctx.save();
+  Ld0->replaceUsesWithIf(
+      Ld1, [St0](const sandboxir::Use &Use) { return Use.getUser() == St0; });
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+  EXPECT_EQ(St1->getOperand(0), Ld0);
+  Ctx.revert();
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+  EXPECT_EQ(St1->getOperand(0), Ld0);
+
+  // Check RUWIf user == St1.
+  Ctx.save();
+  Ld0->replaceUsesWithIf(
+      Ld1, [St1](const sandboxir::Use &Use) { return Use.getUser() == St1; });
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+  EXPECT_EQ(St1->getOperand(0), Ld1);
+  Ctx.revert();
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+  EXPECT_EQ(St1->getOperand(0), Ld0);
+
+  // Check RAUW: redirect both users of Ld0 (St0, St1) to Ld1, then revert.
+  Ctx.save();
+  Ld0->replaceAllUsesWith(Ld1);
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+  EXPECT_EQ(St1->getOperand(0), Ld1);
+  Ctx.revert();
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+  EXPECT_EQ(St1->getOperand(0), Ld0);
+
+  // Check RUOW.
+  Ctx.save();
+  St0->replaceUsesOfWith(Ld0, Ld1);
+  EXPECT_EQ(Tracker.size(), 1u);
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+  Ctx.revert();
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+
+  // Check accept(). Tracking is still enabled after revert().
+  St0->replaceUsesOfWith(Ld0, Ld1);
+  EXPECT_EQ(Tracker.size(), 1u);
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+  Ctx.accept();
+  EXPECT_TRUE(Tracker.empty());
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+}



More information about the llvm-commits mailing list