[llvm] [SandboxIR] Add setOperand() and RAUW,RUWIf,RUOW (PR #98410)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 10 16:27:28 PDT 2024
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/98410
>From 044c30df4ea5f85af76a99c96be3dec16886d8b4 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 10 Jul 2024 15:41:18 -0700
Subject: [PATCH] [SandboxIR] Add setOperand() and RAUW,RUWIf,RUOW
This patch adds the following member functions:
- User::setOperand()
- User::replaceUsesOfWith()
- Value::replaceAllUsesWith()
- Value::replaceUsesWithIf()
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 22 +++++
llvm/lib/SandboxIR/SandboxIR.cpp | 44 ++++++++-
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 104 +++++++++++++++++++++
3 files changed, 168 insertions(+), 2 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 8e87470ee1e5c..3631dd76cfdaf 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -200,6 +200,7 @@ class Value {
llvm::Value *Val = nullptr;
friend class Context; // For getting `Val`.
+ friend class ValueAttorney;
/// All values point to the context.
Context &Ctx;
@@ -284,6 +285,11 @@ class Value {
Type *getType() const { return Val->getType(); }
Context &getContext() const { return Ctx; }
+
+ void replaceUsesWithIf(Value *OtherV,
+ llvm::function_ref<bool(const Use &)> ShouldReplace);
+ void replaceAllUsesWith(Value *Other);
+
#ifndef NDEBUG
/// Should crash if there is something wrong with the instruction.
virtual void verify() const = 0;
@@ -303,6 +309,13 @@ class Value {
#endif
};
+/// Attorney-Client helper: lets sandboxir::User read the llvm::Value wrapped
+/// by any sandboxir::Value (Value::Val), using the friendship that Value
+/// grants to this class.
+class ValueAttorney {
+ friend class User;
+ static llvm::Value *getValue(const Value *SBV) { return SBV->Val; }
+};
+
/// Argument of a sandboxir::Function.
class Argument : public sandboxir::Value {
Argument(llvm::Argument *Arg, sandboxir::Context &Ctx)
@@ -349,6 +362,10 @@ class User : public Value {
virtual unsigned getUseOperandNo(const Use &Use) const = 0;
friend unsigned Use::getOperandNo() const; // For getUseOperandNo()
+#ifndef NDEBUG
+ void verifyUserOfLLVMUse(const llvm::Use &Use) const;
+#endif // NDEBUG
+
public:
/// For isa/dyn_cast.
static bool classof(const Value *From);
@@ -387,6 +404,11 @@ class User : public Value {
return isa<llvm::User>(Val) ? cast<llvm::User>(Val)->getNumOperands() : 0;
}
+ virtual void setOperand(unsigned OperandIdx, Value *Operand);
+ /// Replaces any operands that match \p FromV with \p ToV. Returns whether any
+ /// operands were replaced.
+ bool replaceUsesOfWith(Value *FromV, Value *ToV);
+
#ifndef NDEBUG
void verify() const override {
assert(isa<llvm::User>(Val) && "Expected User!");
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 160d807738a3c..1824b42fe930e 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -103,6 +103,25 @@ Value::user_iterator Value::user_begin() {
unsigned Value::getNumUses() const { return range_size(Val->users()); }
+/// Replaces with \p OtherV each use of this value for which \p ShouldReplace
+/// returns true. Uses whose user has no sandboxir counterpart in the Context
+/// are never replaced.
+void Value::replaceUsesWithIf(
+ Value *OtherV, llvm::function_ref<bool(const Use &)> ShouldReplace) {
+ assert(getType() == OtherV->getType() && "Can't replace with different type");
+ // Adapt the sandboxir predicate to the LLVM API by mapping each llvm::Use
+ // back to the sandboxir User that owns it.
+ auto Filter = [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
+ auto *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
+ return DstU != nullptr && ShouldReplace(Use(&LLVMUse, DstU, Ctx));
+ };
+ Val->replaceUsesWithIf(OtherV->Val, Filter);
+}
+
+/// Redirects every use of this value to \p Other, which must have the same
+/// type (checked by assertion).
+void Value::replaceAllUsesWith(Value *Other) {
+ assert(Other->getType() == getType() &&
+ "Replacing with Value of different type!");
+ llvm::Value *NewLLVMV = Other->Val;
+ Val->replaceAllUsesWith(NewLLVMV);
+}
+
#ifndef NDEBUG
std::string Value::getName() const {
std::stringstream SS;
@@ -165,6 +184,13 @@ Use User::getOperandUseDefault(unsigned OpIdx, bool Verify) const {
return Use(LLVMUse, const_cast<User *>(this), Ctx);
}
+#ifndef NDEBUG
+/// Debug-only check: the sandboxir value registered in the Context for the
+/// llvm::User of \p Use must be this User.
+void User::verifyUserOfLLVMUse(const llvm::Use &Use) const {
+ auto *UseOwner = Ctx.getValue(Use.getUser());
+ assert(UseOwner == this && "Use not found in this SBUser's operands!");
+}
+#endif // NDEBUG
+
bool User::classof(const Value *From) {
switch (From->getSubclassID()) {
#define DEF_VALUE(ID, CLASS)
@@ -180,6 +206,19 @@ bool User::classof(const Value *From) {
}
}
+/// Sets operand \p OperandIdx of the wrapped llvm::User to the llvm::Value
+/// underlying \p Operand. Reaching this on a non-llvm::User is a bug.
+void User::setOperand(unsigned OperandIdx, Value *Operand) {
+ auto *LLVMUser = dyn_cast<llvm::User>(Val);
+ if (LLVMUser == nullptr)
+ llvm_unreachable("No operands!");
+ LLVMUser->setOperand(OperandIdx, ValueAttorney::getValue(Operand));
+}
+
+/// Replaces every operand equal to \p FromV with \p ToV, forwarding to
+/// llvm::User::replaceUsesOfWith; returns whether any operand was replaced.
+bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
+ auto *LLVMUser = cast<llvm::User>(Val);
+ return LLVMUser->replaceUsesOfWith(ValueAttorney::getValue(FromV),
+ ValueAttorney::getValue(ToV));
+}
+
#ifndef NDEBUG
void User::dumpCommonHeader(raw_ostream &OS) const {
Value::dumpCommonHeader(OS);
@@ -325,10 +364,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
return It->second.get();
if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
+ It->second = std::unique_ptr<Constant>(new Constant(C, *this));
+ auto *NewC = It->second.get();
for (llvm::Value *COp : C->operands())
getOrCreateValueInternal(COp, C);
- It->second = std::unique_ptr<Constant>(new Constant(C, *this));
- return It->second.get();
+ return NewC;
}
if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) {
It->second = std::unique_ptr<Argument>(new Argument(Arg, *this));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 72e81bf640350..16e537efba5de 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -122,6 +122,9 @@ define i32 @foo(i32 %v0, i32 %v1) {
BasicBlock *LLVMBB = &*LLVMF.begin();
auto LLVMBBIt = LLVMBB->begin();
Instruction *LLVMI0 = &*LLVMBBIt++;
+ Instruction *LLVMRet = &*LLVMBBIt++;
+ Argument *LLVMArg0 = LLVMF.getArg(0);
+ Argument *LLVMArg1 = LLVMF.getArg(1);
auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
@@ -203,6 +206,107 @@ OperandNo: 0
EXPECT_FALSE(I0->hasNUses(0u));
EXPECT_TRUE(I0->hasNUses(1u));
EXPECT_FALSE(I0->hasNUses(2u));
+
+ // Check User.setOperand().
+ Ret->setOperand(0, Arg0);
+ EXPECT_EQ(Ret->getOperand(0), Arg0);
+ EXPECT_EQ(Ret->getOperandUse(0).get(), Arg0);
+ EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg0);
+
+ Ret->setOperand(0, Arg1);
+ EXPECT_EQ(Ret->getOperand(0), Arg1);
+ EXPECT_EQ(Ret->getOperandUse(0).get(), Arg1);
+ EXPECT_EQ(LLVMRet->getOperand(0), LLVMArg1);
+}
+
+// Checks User::replaceUsesOfWith() (RUOW) on instructions and on a constant.
+TEST_F(SandboxIRTest, RUOW) {
+ parseIR(C, R"IR(
+declare void @bar0()
+declare void @bar1()
+
+@glob0 = global ptr @bar0
+@glob1 = global ptr @bar1
+
+define i32 @foo(i32 %v0, i32 %v1) {
+ %add0 = add i32 %v0, %v1
+ %gep1 = getelementptr i8, ptr @glob0, i32 1
+ %gep2 = getelementptr i8, ptr @glob1, i32 1
+ ret i32 %add0
+}
+)IR");
+ llvm::Function &LLVMF = *M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+
+ auto &F = *Ctx.createFunction(&LLVMF);
+ auto &BB = *F.begin();
+ auto *Arg0 = F.getArg(0);
+ auto *Arg1 = F.getArg(1);
+ auto It = BB.begin();
+ auto *I0 = &*It++;
+ auto *I1 = &*It++;
+ auto *I2 = &*It++;
+ auto *Ret = &*It++;
+
+ bool Replaced;
+ // Try to replace an operand that doesn't match.
+ Replaced = I0->replaceUsesOfWith(Ret, Arg1);
+ EXPECT_FALSE(Replaced);
+ EXPECT_EQ(I0->getOperand(0), Arg0);
+ EXPECT_EQ(I0->getOperand(1), Arg1);
+
+ // Replace I0 operands when operands differ.
+ Replaced = I0->replaceUsesOfWith(Arg0, Arg1);
+ EXPECT_TRUE(Replaced);
+ EXPECT_EQ(I0->getOperand(0), Arg1);
+ EXPECT_EQ(I0->getOperand(1), Arg1);
+
+ // Replace I0 operands when operands are the same.
+ Replaced = I0->replaceUsesOfWith(Arg1, Arg0);
+ EXPECT_TRUE(Replaced);
+ EXPECT_EQ(I0->getOperand(0), Arg0);
+ EXPECT_EQ(I0->getOperand(1), Arg0);
+
+ // Replace Ret operand.
+ Replaced = Ret->replaceUsesOfWith(I0, Arg0);
+ EXPECT_TRUE(Replaced);
+ EXPECT_EQ(Ret->getOperand(0), Arg0);
+
+ // Check RUOW on a constant: replace @glob0's operand (@bar0) with @glob1.
+ auto *Glob0 = cast<sandboxir::Constant>(I1->getOperand(0));
+ auto *Glob1 = cast<sandboxir::Constant>(I2->getOperand(0));
+ auto *Glob0Op = Glob0->getOperand(0);
+ Replaced = Glob0->replaceUsesOfWith(Glob0Op, Glob1);
+ EXPECT_TRUE(Replaced);
+ EXPECT_EQ(Glob0->getOperand(0), Glob1);
+}
+
+// Checks Value::replaceUsesWithIf() (RUWIf) and Value::replaceAllUsesWith()
+// (RAUW) through the store's value operand.
+TEST_F(SandboxIRTest, RAUW_RUWIf) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+ %ld0 = load float, ptr %ptr
+ %ld1 = load float, ptr %ptr
+ store float %ld0, ptr %ptr
+ ret void
+}
+)IR");
+ llvm::Function &LLVMF = *M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ llvm::BasicBlock *LLVMBB0 = &*LLVMF.begin();
+
+ Ctx.createFunction(&LLVMF);
+ auto *BB0 = cast<sandboxir::BasicBlock>(Ctx.getValue(LLVMBB0));
+ auto It = BB0->begin();
+ sandboxir::Instruction *Ld0 = &*It++;
+ sandboxir::Instruction *Ld1 = &*It++;
+ sandboxir::Instruction *St0 = &*It++;
+ // Check RUWIf when the lambda returns false: nothing is replaced.
+ Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &) { return false; });
+ EXPECT_EQ(St0->getOperand(0), Ld0);
+ // Check RUWIf when the lambda returns true: the store now uses %ld1.
+ Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &) { return true; });
+ EXPECT_EQ(St0->getOperand(0), Ld1);
+ // Check RAUW: move all uses of %ld1 back to %ld0.
+ Ld1->replaceAllUsesWith(Ld0);
+ EXPECT_EQ(St0->getOperand(0), Ld0);
}
// Check that the operands/users are counted correctly.
More information about the llvm-commits
mailing list