[llvm] [SandboxIR] Add setOperand() and RAUW,RUWIf,RUOW (PR #98410)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 10 16:27:28 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/98410

>From 044c30df4ea5f85af76a99c96be3dec16886d8b4 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 10 Jul 2024 15:41:18 -0700
Subject: [PATCH] [SandboxIR] Add setOperand() and RAUW,RUWIf,RUOW

This patch adds the following member functions:
- User::setOperand()
- User::replaceUsesOfWith()
- Value::replaceAllUsesWith()
- Value::replaceUsesWithIf()
---
 llvm/include/llvm/SandboxIR/SandboxIR.h    |  22 +++++
 llvm/lib/SandboxIR/SandboxIR.cpp           |  44 ++++++++-
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 104 +++++++++++++++++++++
 3 files changed, 168 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 8e87470ee1e5c..3631dd76cfdaf 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -200,6 +200,7 @@ class Value {
   llvm::Value *Val = nullptr;
 
   friend class Context; // For getting `Val`.
+  friend class ValueAttorney;
 
   /// All values point to the context.
   Context &Ctx;
@@ -284,6 +285,11 @@ class Value {
   Type *getType() const { return Val->getType(); }
 
   Context &getContext() const { return Ctx; }
+
+  void replaceUsesWithIf(Value *OtherV,
+                         llvm::function_ref<bool(const Use &)> ShouldReplace);
+  void replaceAllUsesWith(Value *Other);
+
 #ifndef NDEBUG
   /// Should crash if there is something wrong with the instruction.
   virtual void verify() const = 0;
@@ -303,6 +309,13 @@ class Value {
 #endif
 };
 
+/// Helper Attorney-Client class that gives access to the underlying IR.
+class ValueAttorney {
+private:
+  static llvm::Value *getValue(const Value *SBV) { return SBV->Val; }
+  friend class User;
+};
+
 /// Argument of a sandboxir::Function.
 class Argument : public sandboxir::Value {
   Argument(llvm::Argument *Arg, sandboxir::Context &Ctx)
@@ -349,6 +362,10 @@ class User : public Value {
   virtual unsigned getUseOperandNo(const Use &Use) const = 0;
   friend unsigned Use::getOperandNo() const; // For getUseOperandNo()
 
+#ifndef NDEBUG
+  void verifyUserOfLLVMUse(const llvm::Use &Use) const;
+#endif // NDEBUG
+
 public:
   /// For isa/dyn_cast.
   static bool classof(const Value *From);
@@ -387,6 +404,11 @@ class User : public Value {
     return isa<llvm::User>(Val) ? cast<llvm::User>(Val)->getNumOperands() : 0;
   }
 
+  virtual void setOperand(unsigned OperandIdx, Value *Operand);
+  /// Replaces any operands that match \p FromV with \p ToV. Returns whether any
+  /// operands were replaced.
+  bool replaceUsesOfWith(Value *FromV, Value *ToV);
+
 #ifndef NDEBUG
   void verify() const override {
     assert(isa<llvm::User>(Val) && "Expected User!");
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 160d807738a3c..1824b42fe930e 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -103,6 +103,25 @@ Value::user_iterator Value::user_begin() {
 
 unsigned Value::getNumUses() const { return range_size(Val->users()); }
 
+void Value::replaceUsesWithIf(
+    Value *OtherV, llvm::function_ref<bool(const Use &)> ShouldReplace) {
+  assert(getType() == OtherV->getType() && "Can't replace with different type");
+  llvm::Value *OtherVal = OtherV->Val;
+  Val->replaceUsesWithIf(
+      OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
+        User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
+        if (DstU == nullptr)
+          return false;
+        return ShouldReplace(Use(&LLVMUse, DstU, Ctx));
+      });
+}
+
+void Value::replaceAllUsesWith(Value *Other) {
+  assert(getType() == Other->getType() &&
+         "Replacing with Value of different type!");
+  Val->replaceAllUsesWith(Other->Val);
+}
+
 #ifndef NDEBUG
 std::string Value::getName() const {
   std::stringstream SS;
@@ -165,6 +184,13 @@ Use User::getOperandUseDefault(unsigned OpIdx, bool Verify) const {
   return Use(LLVMUse, const_cast<User *>(this), Ctx);
 }
 
+#ifndef NDEBUG
+void User::verifyUserOfLLVMUse(const llvm::Use &Use) const {
+  assert(Ctx.getValue(Use.getUser()) == this &&
+         "Use not found in this SBUser's operands!");
+}
+#endif
+
 bool User::classof(const Value *From) {
   switch (From->getSubclassID()) {
 #define DEF_VALUE(ID, CLASS)
@@ -180,6 +206,19 @@ bool User::classof(const Value *From) {
   }
 }
 
+void User::setOperand(unsigned OperandIdx, Value *Operand) {
+  if (!isa<llvm::User>(Val))
+    llvm_unreachable("No operands!");
+  cast<llvm::User>(Val)->setOperand(OperandIdx,
+                                    ValueAttorney::getValue(Operand));
+}
+
+bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
+  llvm::Value *FromLLVM = ValueAttorney::getValue(FromV);
+  llvm::Value *ToLLVM = ValueAttorney::getValue(ToV);
+  return cast<llvm::User>(Val)->replaceUsesOfWith(FromLLVM, ToLLVM);
+}
+
 #ifndef NDEBUG
 void User::dumpCommonHeader(raw_ostream &OS) const {
   Value::dumpCommonHeader(OS);
@@ -325,10 +364,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     return It->second.get();
 
   if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
+    It->second = std::unique_ptr<Constant>(new Constant(C, *this));
+    auto *NewC = It->second.get();
     for (llvm::Value *COp : C->operands())
       getOrCreateValueInternal(COp, C);
-    It->second = std::unique_ptr<Constant>(new Constant(C, *this));
-    return It->second.get();
+    return NewC;
   }
   if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) {
     It->second = std::unique_ptr<Argument>(new Argument(Arg, *this));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 72e81bf640350..16e537efba5de 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -122,6 +122,9 @@ define i32 @foo(i32 %v0, i32 %v1) {
   BasicBlock *LLVMBB = &*LLVMF.begin();
   auto LLVMBBIt = LLVMBB->begin();
   Instruction *LLVMI0 = &*LLVMBBIt++;
+  Instruction *LLVMRet = &*LLVMBBIt++;
+  Argument *LLVMArg0 = LLVMF.getArg(0);
+  Argument *LLVMArg1 = LLVMF.getArg(1);
 
   auto &F = *Ctx.createFunction(&LLVMF);
   auto &BB = *F.begin();
@@ -203,6 +206,107 @@ OperandNo: 0
   EXPECT_FALSE(I0->hasNUses(0u));
   EXPECT_TRUE(I0->hasNUses(1u));
   EXPECT_FALSE(I0->hasNUses(2u));
+
+  // Check User.setOperand().
+  Ret->setOperand(0, Arg0);
+  EXPECT_EQ(Ret->getOperand(0), Arg0);
+  EXPECT_EQ(Ret->getOperandUse(0).get(), Arg0);
+  EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg0);
+
+  Ret->setOperand(0, Arg1);
+  EXPECT_EQ(Ret->getOperand(0), Arg1);
+  EXPECT_EQ(Ret->getOperandUse(0).get(), Arg1);
+  EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg1);
+}
+
+TEST_F(SandboxIRTest, RUOW) {
+  parseIR(C, R"IR(
+declare void @bar0()
+declare void @bar1()
+
+@glob0 = global ptr @bar0
+@glob1 = global ptr @bar1
+
+define i32 @foo(i32 %v0, i32 %v1) {
+  %add0 = add i32 %v0, %v1
+  %gep1 = getelementptr i8, ptr @glob0, i32 1
+  %gep2 = getelementptr i8, ptr @glob1, i32 1
+  ret i32 %add0
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto *Arg0 = F.getArg(0);
+  auto *Arg1 = F.getArg(1);
+  auto It = BB.begin();
+  auto *I0 = &*It++;
+  auto *I1 = &*It++;
+  auto *I2 = &*It++;
+  auto *Ret = &*It++;
+
+  bool Replaced;
+  // Try to replace an operand that doesn't match.
+  Replaced = I0->replaceUsesOfWith(Ret, Arg1);
+  EXPECT_FALSE(Replaced);
+  EXPECT_EQ(I0->getOperand(0), Arg0);
+  EXPECT_EQ(I0->getOperand(1), Arg1);
+
+  // Replace I0 operands when operands differ.
+  Replaced = I0->replaceUsesOfWith(Arg0, Arg1);
+  EXPECT_TRUE(Replaced);
+  EXPECT_EQ(I0->getOperand(0), Arg1);
+  EXPECT_EQ(I0->getOperand(1), Arg1);
+
+  // Replace I0 operands when operands are the same.
+  Replaced = I0->replaceUsesOfWith(Arg1, Arg0);
+  EXPECT_TRUE(Replaced);
+  EXPECT_EQ(I0->getOperand(0), Arg0);
+  EXPECT_EQ(I0->getOperand(1), Arg0);
+
+  // Replace Ret operand.
+  Replaced = Ret->replaceUsesOfWith(I0, Arg0);
+  EXPECT_TRUE(Replaced);
+  EXPECT_EQ(Ret->getOperand(0), Arg0);
+
+  // Check RUOW on constant.
+  auto *Glob0 = cast<sandboxir::Constant>(I1->getOperand(0));
+  auto *Glob1 = cast<sandboxir::Constant>(I2->getOperand(0));
+  auto *Glob0Op = Glob0->getOperand(0);
+  Glob0->replaceUsesOfWith(Glob0Op, Glob1);
+  EXPECT_EQ(Glob0->getOperand(0), Glob1);
+}
+
+TEST_F(SandboxIRTest, RAW_RUWIf) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+  %ld0 = load float, ptr %ptr
+  %ld1 = load float, ptr %ptr
+  store float %ld0, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  llvm::BasicBlock *LLVMBB0 = &*LLVMF.begin();
+
+  Ctx.createFunction(&LLVMF);
+  auto *BB0 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB0));
+  auto It = BB0->begin();
+  sandboxir::Instruction *Ld0 = &*It++;
+  sandboxir::Instruction *Ld1 = &*It++;
+  sandboxir::Instruction *St0 = &*It++;
+  // Check RUWIf when the lambda returns false.
+  Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return false; });
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+  // Check RUWIf when the lambda returns true.
+  Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return true; });
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+  // Check RAUW.
+  Ld1->replaceAllUsesWith(Ld0);
+  EXPECT_EQ(St0->getOperand(0), Ld0);
 }
 
 // Check that the operands/users are counted correctly.



More information about the llvm-commits mailing list