[llvm] [SandboxIR] Implement LoadInst (PR #99597)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 19 13:41:30 PDT 2024
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/99597
>From 093d6ee722655ff593fac3c07460cbd4a36d8543 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 18 Jul 2024 12:48:22 -0700
Subject: [PATCH] [SandboxIR] Implement LoadInst
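This patch implements a `LoadInst` instruction in SandboxIR.
It mirrors `llvm::LoadInst`.
It also renames the debug-only `Value::getName()` (which returned the
'SB<number>.' unique id) to `Value::getUid()`, and adds a new
`Value::getName()` that returns the LLVM IR name of the underlying value.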
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 66 ++++++++++++++++--
.../llvm/SandboxIR/SandboxIRValues.def | 1 +
llvm/lib/SandboxIR/SandboxIR.cpp | 69 +++++++++++++++++--
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 31 +++++++++
4 files changed, 156 insertions(+), 11 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index a9f0177eb9338..f168fdf8b1056 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -59,6 +59,7 @@
#define LLVM_SANDBOXIR_SANDBOXIR_H
#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/SandboxIR/Tracker.h"
@@ -74,6 +75,7 @@ class BasicBlock;
class Context;
class Function;
class Instruction;
+class LoadInst;
class User;
class Value;
@@ -170,9 +172,10 @@ class Value {
/// order.
llvm::Value *Val = nullptr;
- friend class Context; // For getting `Val`.
- friend class User; // For getting `Val`.
- friend class Use; // For getting `Val`.
+ friend class Context; // For getting `Val`.
+ friend class User; // For getting `Val`.
+ friend class Use; // For getting `Val`.
+ friend class LoadInst; // For getting `Val`.
/// All values point to the context.
Context &Ctx;
@@ -262,11 +265,15 @@ class Value {
llvm::function_ref<bool(const Use &)> ShouldReplace);
void replaceAllUsesWith(Value *Other);
+ /// \Returns the LLVM IR name of the bottom-most LLVM value (i.e. of `Val`).
+ /// Unlike the debug-only getUid(), this is available in release builds too.
+ StringRef getName() const { return Val->getName(); }
+
#ifndef NDEBUG
/// Should crash if there is something wrong with the instruction.
virtual void verify() const = 0;
- /// Returns the name in the form 'SB<number>.' like 'SB1.'
- std::string getName() const;
+ /// Returns the unique id in the form 'SB<number>.' like 'SB1.'
+ std::string getUid() const;
virtual void dumpCommonHeader(raw_ostream &OS) const;
void dumpCommonFooter(raw_ostream &OS) const;
void dumpCommonPrefix(raw_ostream &OS) const;
@@ -489,6 +495,7 @@ class Instruction : public sandboxir::User {
/// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
+ friend class LoadInst; // For getTopmostLLVMInstruction().
/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
@@ -553,6 +560,54 @@ class Instruction : public sandboxir::User {
#endif
};
+/// SandboxIR counterpart of llvm::LoadInst. It wraps exactly one
+/// llvm::LoadInst (`Val`) and forwards most queries to it.
+class LoadInst final : public Instruction {
+ /// Use LoadInst::create() instead of calling the constructor.
+ LoadInst(llvm::LoadInst *LI, Context &Ctx)
+ : Instruction(ClassID::Load, Opcode::Load, LI, Ctx) {}
+ friend Context; // for LoadInst()
+ Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+ return getOperandUseDefault(OpIdx, Verify);
+ }
+ SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+ return {cast<llvm::Instruction>(Val)};
+ }
+
+public:
+ unsigned getUseOperandNo(const Use &Use) const final {
+ return getUseOperandNoDefault(Use);
+ }
+
+ unsigned getNumOfIRInstrs() const final { return 1u; }
+ /// Creates a non-volatile load of type \p Ty from \p Ptr with alignment
+ /// \p Align, inserted before the topmost LLVM IR instruction of
+ /// \p InsertBefore. The new LLVM IR instruction is named \p Name.
+ static LoadInst *create(Type *Ty, Value *Ptr, MaybeAlign Align,
+ Instruction *InsertBefore, Context &Ctx,
+ const Twine &Name = "");
+ /// Like the overload above, but appends the new load to the end of
+ /// \p InsertAtEnd.
+ static LoadInst *create(Type *Ty, Value *Ptr, MaybeAlign Align,
+ BasicBlock *InsertAtEnd, Context &Ctx,
+ const Twine &Name = "");
+ /// For isa/dyn_cast.
+ static bool classof(const Value *From);
+ /// \Returns the SandboxIR value used as the pointer operand.
+ Value *getPointerOperand() const;
+ // The following accessors forward to the wrapped llvm::LoadInst.
+ Align getAlign() const { return cast<llvm::LoadInst>(Val)->getAlign(); }
+ bool isUnordered() const { return cast<llvm::LoadInst>(Val)->isUnordered(); }
+ bool isSimple() const { return cast<llvm::LoadInst>(Val)->isSimple(); }
+#ifndef NDEBUG
+ void verify() const final {
+ assert(isa<llvm::LoadInst>(Val) && "Expected LoadInst!");
+ }
+ void dump(raw_ostream &OS) const override;
+ LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
/// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
/// an OpaqueInstr.
class OpaqueInst : public sandboxir::Instruction {
@@ -683,8 +729,20 @@ class Context {
friend class BasicBlock; // For getOrCreateValue().
+ /// Builder for emitting the LLVM IR of newly created SandboxIR instructions.
+ /// Users such as LoadInst::create() set its insert point before each use.
+ IRBuilder<ConstantFolder> LLVMIRBuilder;
+ auto &getLLVMIRBuilder() { return LLVMIRBuilder; }
+
+ /// Wraps \p LI, a newly created LLVM load, in a sandboxir::LoadInst and
+ /// registers it with this context.
+ LoadInst *createLoadInst(llvm::LoadInst *LI);
+ friend LoadInst; // For createLoadInst()
+
public:
- Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx), IRTracker(*this) {}
+ Context(LLVMContext &LLVMCtx)
+ : LLVMCtx(LLVMCtx), IRTracker(*this),
+ LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}
Tracker &getTracker() { return IRTracker; }
/// Convenience function for `getTracker().save()`
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index b090ade3ea0ca..e1ed3cdac6bba 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -25,6 +25,7 @@ DEF_USER(Constant, Constant)
#endif
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
+DEF_INSTR(Load, OP(Load), LoadInst)
#ifdef DEF_VALUE
#undef DEF_VALUE
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 87134995a1538..f392704a6d27e 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -140,14 +140,16 @@ void Value::replaceAllUsesWith(Value *Other) {
}
#ifndef NDEBUG
-std::string Value::getName() const {
+// \Returns this value's debug-only unique id, e.g. "SB1.", which the dump
+// helpers such as dumpCommonHeader() use to tag their output.
+std::string Value::getUid() const {
std::stringstream SS;
SS << "SB" << UID << ".";
return SS.str();
}
void Value::dumpCommonHeader(raw_ostream &OS) const {
- OS << getName() << " " << getSubclassIDStr(SubclassID) << " ";
+ OS << getUid() << " " << getSubclassIDStr(SubclassID) << " ";
}
void Value::dumpCommonFooter(raw_ostream &OS) const {
@@ -167,7 +167,8 @@ void Value::dumpCommonPrefix(raw_ostream &OS) const {
}
+// Prints the trailer " ; <uid> (<subclass id>)" at the end of a dump.
void Value::dumpCommonSuffix(raw_ostream &OS) const {
- OS << " ; " << getName() << " (" << getSubclassIDStr(SubclassID) << ")";
+ OS << " ; " << getUid() << " (" << getSubclassIDStr(SubclassID) << ")";
}
void Value::printAsOperandCommon(raw_ostream &OS) const {
@@ -453,6 +453,52 @@ void Instruction::dump() const {
dump(dbgs());
dbgs() << "\n";
}
+#endif // NDEBUG
+
+LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
+ Instruction *InsertBefore, Context &Ctx,
+ const Twine &Name) {
+ // Emit the LLVM IR load right before the topmost LLVM IR instruction of
+ // InsertBefore, then wrap it in a SandboxIR LoadInst.
+ auto &Builder = Ctx.getLLVMIRBuilder();
+ Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
+ llvm::LoadInst *LLVMLd = Builder.CreateAlignedLoad(
+ Ty, Ptr->Val, Align, /*isVolatile=*/false, Name);
+ return Ctx.createLoadInst(LLVMLd);
+}
+
+LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
+ BasicBlock *InsertAtEnd, Context &Ctx,
+ const Twine &Name) {
+ // Append the LLVM IR load to InsertAtEnd's LLVM block, then wrap it.
+ auto *LLVMBB = cast<llvm::BasicBlock>(InsertAtEnd->Val);
+ auto &Builder = Ctx.getLLVMIRBuilder();
+ Builder.SetInsertPoint(LLVMBB);
+ llvm::LoadInst *LLVMLd = Builder.CreateAlignedLoad(
+ Ty, Ptr->Val, Align, /*isVolatile=*/false, Name);
+ return Ctx.createLoadInst(LLVMLd);
+}
+
+bool LoadInst::classof(const Value *From) {
+ // Every SandboxIR LoadInst is constructed with ClassID::Load.
+ return From->getSubclassID() == ClassID::Load;
+}
+
+Value *LoadInst::getPointerOperand() const {
+ llvm::Value *LLVMPtr = cast<llvm::LoadInst>(Val)->getPointerOperand();
+ return Ctx.getValue(LLVMPtr);
+}
+
+#ifndef NDEBUG
+void LoadInst::dump(raw_ostream &OS) const {
+ dumpCommonPrefix(OS);
+ dumpCommonSuffix(OS);
+}
+
+void LoadInst::dump() const {
+ dump(dbgs());
+ dbgs() << "\n";
+}
void OpaqueInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
@@ -538,8 +581,11 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
assert(VPtr->getSubclassID() != Value::ClassID::User &&
"Can't register a user!");
Value *V = VPtr.get();
- llvm::Value *Key = V->Val;
- LLVMValueToValueMap[Key] = std::move(VPtr);
+ // insert() (unlike operator[]) never replaces an existing entry. `Pair` is
+ // only read by the assert, so keep NDEBUG builds free of unused warnings.
+ [[maybe_unused]] auto Pair =
+ LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
+ assert(Pair.second && "Already exists!");
return V;
}
@@ -568,6 +611,19 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
return nullptr;
}
assert(isa<llvm::Instruction>(LLVMV) && "Expected Instruction");
+
+ // Instructions that have a dedicated SandboxIR class are created here;
+ // all others fall through to the OpaqueInst fallback below.
+ switch (cast<llvm::Instruction>(LLVMV)->getOpcode()) {
+ case llvm::Instruction::Load: {
+ auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
+ It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
+ return It->second.get();
+ }
+ default:
+ break;
+ }
+
It->second = std::unique_ptr<OpaqueInst>(
new OpaqueInst(cast<llvm::Instruction>(LLVMV), *this));
return It->second.get();
@@ -582,6 +636,14 @@ BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
return BB;
}
+LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
+ // LoadInst's constructor is private (Context is a friend), so it cannot be
+ // reached through std::make_unique.
+ std::unique_ptr<LoadInst> NewLd(new LoadInst(LI, *this));
+ Value *Registered = registerValue(std::move(NewLd));
+ return cast<LoadInst>(Registered);
+}
+
Value *Context::getValue(llvm::Value *V) const {
auto It = LLVMValueToValueMap.find(V);
if (It != LLVMValueToValueMap.end())
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index ec68ed0afeb2f..04beb429502bc 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -560,3 +560,46 @@ define void @foo(i8 %v1) {
EXPECT_EQ(I0->getNumUses(), 0u);
EXPECT_EQ(I0->getNextNode(), Ret);
}
+
+TEST_F(SandboxIRTest, LoadInst) {
+ parseIR(C, R"IR(
+define void @foo(ptr %arg0, ptr %arg1) {
+ %ld = load i8, ptr %arg0, align 64
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ sandboxir::Function *F = Ctx.createFunction(LLVMF);
+ auto *Arg0 = F->getArg(0);
+ auto *Arg1 = F->getArg(1);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+ auto *Ld = cast<sandboxir::LoadInst>(&*It++);
+ auto *Ret = &*It++;
+
+ // Check getPointerOperand()
+ EXPECT_EQ(Ld->getPointerOperand(), Arg0);
+ // Check getAlign()
+ EXPECT_EQ(Ld->getAlign(), 64);
+ // A plain (non-atomic, non-volatile) load is both simple and unordered.
+ EXPECT_TRUE(Ld->isSimple());
+ EXPECT_TRUE(Ld->isUnordered());
+ // Check create(InsertBefore)
+ sandboxir::LoadInst *NewLd =
+ sandboxir::LoadInst::create(Ld->getType(), Arg1, Align(8),
+ /*InsertBefore=*/Ret, Ctx, "NewLd");
+ EXPECT_EQ(NewLd->getType(), Ld->getType());
+ EXPECT_EQ(NewLd->getPointerOperand(), Arg1);
+ EXPECT_EQ(NewLd->getAlign(), 8);
+ EXPECT_EQ(NewLd->getName(), "NewLd");
+ EXPECT_EQ(NewLd->getNextNode(), Ret);
+ // Check create(InsertAtEnd)
+ sandboxir::LoadInst *NewLdEnd =
+ sandboxir::LoadInst::create(Ld->getType(), Arg0, Align(4),
+ /*InsertAtEnd=*/BB, Ctx, "NewLdEnd");
+ EXPECT_EQ(NewLdEnd->getPointerOperand(), Arg0);
+ EXPECT_EQ(NewLdEnd->getAlign(), 4);
+ EXPECT_EQ(NewLdEnd->getName(), "NewLdEnd");
+ EXPECT_EQ(Ret->getNextNode(), NewLdEnd);
+}
More information about the llvm-commits
mailing list