[llvm] ecd2bf7 - [SandboxIR] Add setOperand() and RAUW, RUWIf, RUOW (#98410)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 11 21:46:09 PDT 2024


Author: vporpo
Date: 2024-07-11T21:46:05-07:00
New Revision: ecd2bf73cb212452951b3010bbf06e4d96330a92

URL: https://github.com/llvm/llvm-project/commit/ecd2bf73cb212452951b3010bbf06e4d96330a92
DIFF: https://github.com/llvm/llvm-project/commit/ecd2bf73cb212452951b3010bbf06e4d96330a92.diff

LOG: [SandboxIR] Add setOperand() and RAUW,RUWIf,RUOW (#98410)

This patch adds the following member functions:
- User::setOperand()
- User::replaceUsesOfWith()
- Value::replaceAllUsesWith()
- Value::replaceUsesWithIf()

Added: 
    

Modified: 
    llvm/include/llvm/SandboxIR/SandboxIR.h
    llvm/lib/SandboxIR/SandboxIR.cpp
    llvm/unittests/SandboxIR/SandboxIRTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 8e87470ee1e5c..317884fe07681 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -200,6 +200,7 @@ class Value {
   llvm::Value *Val = nullptr;
 
   friend class Context; // For getting `Val`.
+  friend class User;    // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -284,6 +285,11 @@ class Value {
   Type *getType() const { return Val->getType(); }
 
   Context &getContext() const { return Ctx; }
+
+  void replaceUsesWithIf(Value *OtherV,
+                         llvm::function_ref<bool(const Use &)> ShouldReplace);
+  void replaceAllUsesWith(Value *Other);
+
 #ifndef NDEBUG
   /// Should crash if there is something wrong with the instruction.
   virtual void verify() const = 0;
@@ -349,6 +355,10 @@ class User : public Value {
   virtual unsigned getUseOperandNo(const Use &Use) const = 0;
   friend unsigned Use::getOperandNo() const; // For getUseOperandNo()
 
+#ifndef NDEBUG
+  void verifyUserOfLLVMUse(const llvm::Use &Use) const;
+#endif // NDEBUG
+
 public:
   /// For isa/dyn_cast.
   static bool classof(const Value *From);
@@ -387,6 +397,11 @@ class User : public Value {
     return isa<llvm::User>(Val) ? cast<llvm::User>(Val)->getNumOperands() : 0;
   }
 
+  virtual void setOperand(unsigned OperandIdx, Value *Operand);
+  /// Replaces any operands that match \p FromV with \p ToV. Returns whether any
+  /// operands were replaced.
+  bool replaceUsesOfWith(Value *FromV, Value *ToV);
+
 #ifndef NDEBUG
   void verify() const override {
     assert(isa<llvm::User>(Val) && "Expected User!");

diff  --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index e4a902fb93166..41b66c07bfd43 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -103,6 +103,25 @@ Value::user_iterator Value::user_begin() {
 
 unsigned Value::getNumUses() const { return range_size(Val->users()); }
 
+void Value::replaceUsesWithIf(
+    Value *OtherV, llvm::function_ref<bool(const Use &)> ShouldReplace) {
+  assert(getType() == OtherV->getType() && "Can't replace with different type");
+  llvm::Value *OtherVal = OtherV->Val;
+  Val->replaceUsesWithIf(
+      OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
+        User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
+        if (DstU == nullptr)
+          return false;
+        return ShouldReplace(Use(&LLVMUse, DstU, Ctx));
+      });
+}
+
+void Value::replaceAllUsesWith(Value *Other) {
+  assert(getType() == Other->getType() &&
+         "Replacing with Value of different type!");
+  Val->replaceAllUsesWith(Other->Val);
+}
+
 #ifndef NDEBUG
 std::string Value::getName() const {
   std::stringstream SS;
@@ -165,6 +184,13 @@ Use User::getOperandUseDefault(unsigned OpIdx, bool Verify) const {
   return Use(LLVMUse, const_cast<User *>(this), Ctx);
 }
 
+#ifndef NDEBUG
+void User::verifyUserOfLLVMUse(const llvm::Use &Use) const {
+  assert(Ctx.getValue(Use.getUser()) == this &&
+         "Use not found in this SBUser's operands!");
+}
+#endif
+
 bool User::classof(const Value *From) {
   switch (From->getSubclassID()) {
 #define DEF_VALUE(ID, CLASS)
@@ -180,6 +206,15 @@ bool User::classof(const Value *From) {
   }
 }
 
+void User::setOperand(unsigned OperandIdx, Value *Operand) {
+  assert(isa<llvm::User>(Val) && "No operands!");
+  cast<llvm::User>(Val)->setOperand(OperandIdx, Operand->Val);
+}
+
+bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
+  return cast<llvm::User>(Val)->replaceUsesOfWith(FromV->Val, ToV->Val);
+}
+
 #ifndef NDEBUG
 void User::dumpCommonHeader(raw_ostream &OS) const {
   Value::dumpCommonHeader(OS);
@@ -325,10 +360,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     return It->second.get();
 
   if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
+    It->second = std::unique_ptr<Constant>(new Constant(C, *this));
+    auto *NewC = It->second.get();
     for (llvm::Value *COp : C->operands())
       getOrCreateValueInternal(COp, C);
-    It->second = std::unique_ptr<Constant>(new Constant(C, *this));
-    return It->second.get();
+    return NewC;
   }
   if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) {
     It->second = std::unique_ptr<Argument>(new Argument(Arg, *this));

diff  --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 72e81bf640350..98c0052d878d8 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -122,6 +122,9 @@ define i32 @foo(i32 %v0, i32 %v1) {
   BasicBlock *LLVMBB = &*LLVMF.begin();
   auto LLVMBBIt = LLVMBB->begin();
   Instruction *LLVMI0 = &*LLVMBBIt++;
+  Instruction *LLVMRet = &*LLVMBBIt++;
+  Argument *LLVMArg0 = LLVMF.getArg(0);
+  Argument *LLVMArg1 = LLVMF.getArg(1);
 
   auto &F = *Ctx.createFunction(&LLVMF);
   auto &BB = *F.begin();
@@ -203,6 +206,126 @@ OperandNo: 0
   EXPECT_FALSE(I0->hasNUses(0u));
   EXPECT_TRUE(I0->hasNUses(1u));
   EXPECT_FALSE(I0->hasNUses(2u));
+
+  // Check User.setOperand().
+  Ret->setOperand(0, Arg0);
+  EXPECT_EQ(Ret->getOperand(0), Arg0);
+  EXPECT_EQ(Ret->getOperandUse(0).get(), Arg0);
+  EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg0);
+
+  Ret->setOperand(0, Arg1);
+  EXPECT_EQ(Ret->getOperand(0), Arg1);
+  EXPECT_EQ(Ret->getOperandUse(0).get(), Arg1);
+  EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg1);
+}
+
+TEST_F(SandboxIRTest, RUOW) {
+  parseIR(C, R"IR(
+declare void @bar0()
+declare void @bar1()
+
+@glob0 = global ptr @bar0
+@glob1 = global ptr @bar1
+
+define i32 @foo(i32 %arg0, i32 %arg1) {
+  %add0 = add i32 %arg0, %arg1
+  %gep1 = getelementptr i8, ptr @glob0, i32 1
+  %gep2 = getelementptr i8, ptr @glob1, i32 1
+  ret i32 %add0
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto *Arg0 = F.getArg(0);
+  auto *Arg1 = F.getArg(1);
+  auto It = BB.begin();
+  auto *I0 = &*It++;
+  auto *I1 = &*It++;
+  auto *I2 = &*It++;
+  auto *Ret = &*It++;
+
+  bool Replaced;
+  // Try to replace an operand that doesn't match.
+  Replaced = I0->replaceUsesOfWith(Ret, Arg1);
+  EXPECT_FALSE(Replaced);
+  EXPECT_EQ(I0->getOperand(0), Arg0);
+  EXPECT_EQ(I0->getOperand(1), Arg1);
+
+  // Replace I0 operands when operands differ.
+  Replaced = I0->replaceUsesOfWith(Arg0, Arg1);
+  EXPECT_TRUE(Replaced);
+  EXPECT_EQ(I0->getOperand(0), Arg1);
+  EXPECT_EQ(I0->getOperand(1), Arg1);
+
+  // Replace I0 operands when operands are the same.
+  Replaced = I0->replaceUsesOfWith(Arg1, Arg0);
+  EXPECT_TRUE(Replaced);
+  EXPECT_EQ(I0->getOperand(0), Arg0);
+  EXPECT_EQ(I0->getOperand(1), Arg0);
+
+  // Replace Ret operand.
+  Replaced = Ret->replaceUsesOfWith(I0, Arg0);
+  EXPECT_TRUE(Replaced);
+  EXPECT_EQ(Ret->getOperand(0), Arg0);
+
+  // Check RAUW on constant.
+  auto *Glob0 = cast<sandboxir::Constant>(I1->getOperand(0));
+  auto *Glob1 = cast<sandboxir::Constant>(I2->getOperand(0));
+  auto *Glob0Op = Glob0->getOperand(0);
+  Glob0->replaceUsesOfWith(Glob0Op, Glob1);
+  EXPECT_EQ(Glob0->getOperand(0), Glob1);
+}
+
+TEST_F(SandboxIRTest, RAUW_RUWIf) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+  %ld0 = load float, ptr %ptr
+  %ld1 = load float, ptr %ptr
+  store float %ld0, ptr %ptr
+  store float %ld0, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  llvm::BasicBlock *LLVMBB = &*LLVMF.begin();
+
+  Ctx.createFunction(&LLVMF);
+  auto *BB = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB));
+  auto It = BB->begin();
+  sandboxir::Instruction *Ld0 = &*It++;
+  sandboxir::Instruction *Ld1 = &*It++;
+  sandboxir::Instruction *St0 = &*It++;
+  sandboxir::Instruction *St1 = &*It++;
+  // Check RUWIf when the lambda returns false.
+  Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return false; });
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+  EXPECT_EQ(St1->getOperand(0), Ld0);
+  // Check RUWIf when the lambda returns true.
+  Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return true; });
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+  EXPECT_EQ(St1->getOperand(0), Ld1);
+  St0->setOperand(0, Ld0);
+  St1->setOperand(0, Ld0);
+  // Check RUWIf user == St0.
+  Ld0->replaceUsesWithIf(
+      Ld1, [St0](const sandboxir::Use &Use) { return Use.getUser() == St0; });
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+  EXPECT_EQ(St1->getOperand(0), Ld0);
+  St0->setOperand(0, Ld0);
+  // Check RUWIf user == St1.
+  Ld0->replaceUsesWithIf(
+      Ld1, [St1](const sandboxir::Use &Use) { return Use.getUser() == St1; });
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+  EXPECT_EQ(St1->getOperand(0), Ld1);
+  St1->setOperand(0, Ld0);
+  // Check RAUW.
+  Ld1->replaceAllUsesWith(Ld0);
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+  EXPECT_EQ(St1->getOperand(0), Ld0);
 }
 
 // Check that the operands/users are counted correctly.


        


More information about the llvm-commits mailing list