[llvm] ecd2bf7 - [SandboxIR] Add setOperand() and RAUW, RUWIf, RUOW (#98410)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 11 21:46:09 PDT 2024
Author: vporpo
Date: 2024-07-11T21:46:05-07:00
New Revision: ecd2bf73cb212452951b3010bbf06e4d96330a92
URL: https://github.com/llvm/llvm-project/commit/ecd2bf73cb212452951b3010bbf06e4d96330a92
DIFF: https://github.com/llvm/llvm-project/commit/ecd2bf73cb212452951b3010bbf06e4d96330a92.diff
LOG: [SandboxIR] Add setOperand() and RAUW, RUWIf, RUOW (#98410)
This patch adds the following member functions:
- User::setOperand()
- User::replaceUsesOfWith() (RUOW)
- Value::replaceAllUsesWith() (RAUW)
- Value::replaceUsesWithIf() (RUWIf)
Added:
Modified:
llvm/include/llvm/SandboxIR/SandboxIR.h
llvm/lib/SandboxIR/SandboxIR.cpp
llvm/unittests/SandboxIR/SandboxIRTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 8e87470ee1e5c..317884fe07681 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -200,6 +200,7 @@ class Value {
llvm::Value *Val = nullptr;
friend class Context; // For getting `Val`.
+ friend class User; // For getting `Val`.
/// All values point to the context.
Context &Ctx;
@@ -284,6 +285,11 @@ class Value {
Type *getType() const { return Val->getType(); }
Context &getContext() const { return Ctx; }
+
+ void replaceUsesWithIf(Value *OtherV,
+ llvm::function_ref<bool(const Use &)> ShouldReplace);
+ void replaceAllUsesWith(Value *Other);
+
#ifndef NDEBUG
/// Should crash if there is something wrong with the instruction.
virtual void verify() const = 0;
@@ -349,6 +355,10 @@ class User : public Value {
virtual unsigned getUseOperandNo(const Use &Use) const = 0;
friend unsigned Use::getOperandNo() const; // For getUseOperandNo()
+#ifndef NDEBUG
+ void verifyUserOfLLVMUse(const llvm::Use &Use) const;
+#endif // NDEBUG
+
public:
/// For isa/dyn_cast.
static bool classof(const Value *From);
@@ -387,6 +397,11 @@ class User : public Value {
return isa<llvm::User>(Val) ? cast<llvm::User>(Val)->getNumOperands() : 0;
}
+ virtual void setOperand(unsigned OperandIdx, Value *Operand);
+ /// Replaces any operands that match \p FromV with \p ToV. Returns whether any
+ /// operands were replaced.
+ bool replaceUsesOfWith(Value *FromV, Value *ToV);
+
#ifndef NDEBUG
void verify() const override {
assert(isa<llvm::User>(Val) && "Expected User!");
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index e4a902fb93166..41b66c07bfd43 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -103,6 +103,25 @@ Value::user_iterator Value::user_begin() {
unsigned Value::getNumUses() const { return range_size(Val->users()); }
+/// Replaces uses of this value with \p OtherV, but only those uses for which
+/// \p ShouldReplace returns true. \p OtherV must have the same type as this
+/// value. A use whose LLVM user does not map to a sandboxir::User (i.e.
+/// Ctx.getValue() yields null) is never replaced.
+void Value::replaceUsesWithIf(
+    Value *OtherV, llvm::function_ref<bool(const Use &)> ShouldReplace) {
+  assert(getType() == OtherV->getType() && "Can't replace with different type");
+  llvm::Value *OtherVal = OtherV->Val;
+  Val->replaceUsesWithIf(
+      OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
+        // Translate the LLVM use into a sandboxir::Use so the predicate can
+        // be written purely in terms of SandboxIR objects.
+        User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
+        if (DstU == nullptr)
+          return false;
+        return ShouldReplace(Use(&LLVMUse, DstU, Ctx));
+      });
+}
+
+/// Replaces every use of this value with \p Other by delegating to the
+/// underlying llvm::Value. \p Other must have the same type as this value.
+void Value::replaceAllUsesWith(Value *Other) {
+  assert(getType() == Other->getType() &&
+         "Replacing with Value of different type!");
+  Val->replaceAllUsesWith(Other->Val);
+}
+
#ifndef NDEBUG
std::string Value::getName() const {
std::stringstream SS;
@@ -165,6 +184,13 @@ Use User::getOperandUseDefault(unsigned OpIdx, bool Verify) const {
return Use(LLVMUse, const_cast<User *>(this), Ctx);
}
+#ifndef NDEBUG
+/// Debug-only sanity check: asserts that the LLVM user of \p Use maps back
+/// to this sandboxir::User in the Context.
+void User::verifyUserOfLLVMUse(const llvm::Use &Use) const {
+  assert(Ctx.getValue(Use.getUser()) == this &&
+         "Use not found in this User's operands!");
+}
+#endif // NDEBUG
+
bool User::classof(const Value *From) {
switch (From->getSubclassID()) {
#define DEF_VALUE(ID, CLASS)
@@ -180,6 +206,15 @@ bool User::classof(const Value *From) {
}
}
+/// Makes operand number \p OperandIdx of this User refer to \p Operand by
+/// updating the wrapped llvm::User in place.
+void User::setOperand(unsigned OperandIdx, Value *Operand) {
+  assert(isa<llvm::User>(Val) && "No operands!");
+  auto *LLVMUser = cast<llvm::User>(Val);
+  llvm::Value *NewLLVMOp = Operand->Val;
+  LLVMUser->setOperand(OperandIdx, NewLLVMOp);
+}
+
+/// Replaces every operand of this User that is \p FromV with \p ToV.
+/// \returns true if any operand was replaced.
+bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
+  // Same precondition as setOperand(): there must be an llvm::User to update.
+  assert(isa<llvm::User>(Val) && "No operands!");
+  return cast<llvm::User>(Val)->replaceUsesOfWith(FromV->Val, ToV->Val);
+}
+
#ifndef NDEBUG
void User::dumpCommonHeader(raw_ostream &OS) const {
Value::dumpCommonHeader(OS);
@@ -325,10 +360,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
return It->second.get();
if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
+ It->second = std::unique_ptr<Constant>(new Constant(C, *this));
+ auto *NewC = It->second.get();
for (llvm::Value *COp : C->operands())
getOrCreateValueInternal(COp, C);
- It->second = std::unique_ptr<Constant>(new Constant(C, *this));
- return It->second.get();
+ return NewC;
}
if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) {
It->second = std::unique_ptr<Argument>(new Argument(Arg, *this));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 72e81bf640350..98c0052d878d8 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -122,6 +122,9 @@ define i32 @foo(i32 %v0, i32 %v1) {
BasicBlock *LLVMBB = &*LLVMF.begin();
auto LLVMBBIt = LLVMBB->begin();
Instruction *LLVMI0 = &*LLVMBBIt++;
+ Instruction *LLVMRet = &*LLVMBBIt++;
+ Argument *LLVMArg0 = LLVMF.getArg(0);
+ Argument *LLVMArg1 = LLVMF.getArg(1);
auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
@@ -203,6 +206,126 @@ OperandNo: 0
EXPECT_FALSE(I0->hasNUses(0u));
EXPECT_TRUE(I0->hasNUses(1u));
EXPECT_FALSE(I0->hasNUses(2u));
+
+ // Check User.setOperand().
+ Ret->setOperand(0, Arg0);
+ EXPECT_EQ(Ret->getOperand(0), Arg0);
+ EXPECT_EQ(Ret->getOperandUse(0).get(), Arg0);
+ EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg0);
+
+ Ret->setOperand(0, Arg1);
+ EXPECT_EQ(Ret->getOperand(0), Arg1);
+ EXPECT_EQ(Ret->getOperandUse(0).get(), Arg1);
+ EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg1);
+}
+
+TEST_F(SandboxIRTest, RUOW) {
+  parseIR(C, R"IR(
+declare void @bar0()
+declare void @bar1()
+
+@glob0 = global ptr @bar0
+@glob1 = global ptr @bar1
+
+define i32 @foo(i32 %arg0, i32 %arg1) {
+  %add0 = add i32 %arg0, %arg1
+  %gep1 = getelementptr i8, ptr @glob0, i32 1
+  %gep2 = getelementptr i8, ptr @glob1, i32 1
+  ret i32 %add0
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto *Arg0 = F.getArg(0);
+  auto *Arg1 = F.getArg(1);
+  auto It = BB.begin();
+  auto *I0 = &*It++;
+  auto *I1 = &*It++;
+  auto *I2 = &*It++;
+  auto *Ret = &*It++;
+
+  bool Replaced;
+  // Try to replace an operand that doesn't match.
+  Replaced = I0->replaceUsesOfWith(Ret, Arg1);
+  EXPECT_FALSE(Replaced);
+  EXPECT_EQ(I0->getOperand(0), Arg0);
+  EXPECT_EQ(I0->getOperand(1), Arg1);
+
+  // Replace I0 operands when operands differ.
+  Replaced = I0->replaceUsesOfWith(Arg0, Arg1);
+  EXPECT_TRUE(Replaced);
+  EXPECT_EQ(I0->getOperand(0), Arg1);
+  EXPECT_EQ(I0->getOperand(1), Arg1);
+
+  // Replace I0 operands when operands are the same.
+  Replaced = I0->replaceUsesOfWith(Arg1, Arg0);
+  EXPECT_TRUE(Replaced);
+  EXPECT_EQ(I0->getOperand(0), Arg0);
+  EXPECT_EQ(I0->getOperand(1), Arg0);
+
+  // Replace Ret operand.
+  Replaced = Ret->replaceUsesOfWith(I0, Arg0);
+  EXPECT_TRUE(Replaced);
+  EXPECT_EQ(Ret->getOperand(0), Arg0);
+
+  // Check RUOW on a constant (the global variable @glob0).
+  auto *Glob0 = cast<sandboxir::Constant>(I1->getOperand(0));
+  auto *Glob1 = cast<sandboxir::Constant>(I2->getOperand(0));
+  auto *Glob0Op = Glob0->getOperand(0);
+  Replaced = Glob0->replaceUsesOfWith(Glob0Op, Glob1);
+  EXPECT_TRUE(Replaced);
+  EXPECT_EQ(Glob0->getOperand(0), Glob1);
+}
+
+TEST_F(SandboxIRTest, RAUW_RUWIf) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+  %ld0 = load float, ptr %ptr
+  %ld1 = load float, ptr %ptr
+  store float %ld0, ptr %ptr
+  store float %ld0, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  llvm::BasicBlock *LLVMBB = &*LLVMF.begin();
+
+  Ctx.createFunction(&LLVMF);
+  auto *BB = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB));
+  auto It = BB->begin();
+  sandboxir::Instruction *Ld0 = &*It++;
+  sandboxir::Instruction *Ld1 = &*It++;
+  sandboxir::Instruction *St0 = &*It++;
+  sandboxir::Instruction *St1 = &*It++;
+  // Check RUWIf when the lambda returns false.
+  Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &) { return false; });
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+  EXPECT_EQ(St1->getOperand(0), Ld0);
+  // Check RUWIf when the lambda returns true.
+  Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &) { return true; });
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+  EXPECT_EQ(St1->getOperand(0), Ld1);
+  St0->setOperand(0, Ld0);
+  St1->setOperand(0, Ld0);
+  // Check RUWIf user == St0.
+  Ld0->replaceUsesWithIf(
+      Ld1, [St0](const sandboxir::Use &Use) { return Use.getUser() == St0; });
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+  EXPECT_EQ(St1->getOperand(0), Ld0);
+  St0->setOperand(0, Ld0);
+  // Check RUWIf user == St1.
+  Ld0->replaceUsesWithIf(
+      Ld1, [St1](const sandboxir::Use &Use) { return Use.getUser() == St1; });
+  EXPECT_EQ(St0->getOperand(0), Ld0);
+  EXPECT_EQ(St1->getOperand(0), Ld1);
+  St1->setOperand(0, Ld0);
+  // Check RAUW. Both stores use Ld0 at this point (Ld1 has no users), so
+  // replacing all uses of Ld0 with Ld1 must rewrite both store operands.
+  Ld0->replaceAllUsesWith(Ld1);
+  EXPECT_EQ(St0->getOperand(0), Ld1);
+  EXPECT_EQ(St1->getOperand(0), Ld1);
}
// Check that the operands/users are counted correctly.
More information about the llvm-commits
mailing list