[llvm] [SandboxIR] Implement LoadInst (PR #99597)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 19 13:41:30 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/99597

>From 093d6ee722655ff593fac3c07460cbd4a36d8543 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 18 Jul 2024 12:48:22 -0700
Subject: [PATCH] [SandboxIR] Implement LoadInst

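This patch implements a `LoadInst` instruction in SandboxIR.
It mirrors `llvm::LoadInst`.

Supporting changes included in this patch:
- The debug-only `Value::getName()` (which returns the 'SB<number>.' id) is
  renamed to `Value::getUid()`, and a new `Value::getName()` returns the name
  of the underlying LLVM IR value.
- `Context` gains an `IRBuilder`, which `LoadInst::create()` uses to build the
  LLVM IR load.
- `Context::registerValue()` now asserts that the LLVM value has not already
  been registered.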
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       | 66 ++++++++++++++++--
 .../llvm/SandboxIR/SandboxIRValues.def        |  1 +
 llvm/lib/SandboxIR/SandboxIR.cpp              | 69 +++++++++++++++++--
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 31 +++++++++
 4 files changed, 156 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index a9f0177eb9338..f168fdf8b1056 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -59,6 +59,7 @@
 #define LLVM_SANDBOXIR_SANDBOXIR_H
 
 #include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/User.h"
 #include "llvm/IR/Value.h"
 #include "llvm/SandboxIR/Tracker.h"
@@ -74,6 +75,7 @@ class BasicBlock;
 class Context;
 class Function;
 class Instruction;
+class LoadInst;
 class User;
 class Value;
 
@@ -170,9 +172,10 @@ class Value {
   /// order.
   llvm::Value *Val = nullptr;
 
-  friend class Context; // For getting `Val`.
-  friend class User;    // For getting `Val`.
-  friend class Use;     // For getting `Val`.
+  friend class Context;  // For getting `Val`.
+  friend class User;     // For getting `Val`.
+  friend class Use;      // For getting `Val`.
+  friend class LoadInst; // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -262,11 +265,14 @@ class Value {
                          llvm::function_ref<bool(const Use &)> ShouldReplace);
   void replaceAllUsesWith(Value *Other);
 
+  /// \Returns the LLVM IR name of the bottom-most LLVM value.
+  StringRef getName() const { return Val->getName(); }
+
 #ifndef NDEBUG
   /// Should crash if there is something wrong with the instruction.
   virtual void verify() const = 0;
-  /// Returns the name in the form 'SB<number>.' like 'SB1.'
-  std::string getName() const;
+  /// Returns the unique id in the form 'SB<number>.' like 'SB1.'
+  std::string getUid() const;
   virtual void dumpCommonHeader(raw_ostream &OS) const;
   void dumpCommonFooter(raw_ostream &OS) const;
   void dumpCommonPrefix(raw_ostream &OS) const;
@@ -489,6 +495,7 @@ class Instruction : public sandboxir::User {
   /// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
   /// returns its topmost LLVM IR instruction.
   llvm::Instruction *getTopmostLLVMInstruction() const;
+  friend class LoadInst; // For getTopmostLLVMInstruction().
 
   /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
   /// order.
@@ -553,6 +560,45 @@ class Instruction : public sandboxir::User {
 #endif
 };
 
+class LoadInst final : public Instruction {
+  /// Use LoadInst::create() instead of calling the constructor.
+  LoadInst(llvm::LoadInst *LI, Context &Ctx)
+      : Instruction(ClassID::Load, Opcode::Load, LI, Ctx) {}
+  friend Context; // for LoadInst()
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+
+public:
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+  static LoadInst *create(Type *Ty, Value *Ptr, MaybeAlign Align,
+                          Instruction *InsertBefore, Context &Ctx,
+                          const Twine &Name = "");
+  static LoadInst *create(Type *Ty, Value *Ptr, MaybeAlign Align,
+                          BasicBlock *InsertAtEnd, Context &Ctx,
+                          const Twine &Name = "");
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From);
+  Value *getPointerOperand() const;
+  Align getAlign() const { return cast<llvm::LoadInst>(Val)->getAlign(); }
+  bool isUnordered() const { return cast<llvm::LoadInst>(Val)->isUnordered(); }
+  bool isSimple() const { return cast<llvm::LoadInst>(Val)->isSimple(); }
+#ifndef NDEBUG
+  void verify() const final {
+    assert(isa<llvm::LoadInst>(Val) && "Expected LoadInst!");
+  }
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
 /// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
 /// an OpaqueInstr.
 class OpaqueInst : public sandboxir::Instruction {
@@ -683,8 +729,16 @@ class Context {
 
   friend class BasicBlock; // For getOrCreateValue().
 
+  IRBuilder<ConstantFolder> LLVMIRBuilder;
+  auto &getLLVMIRBuilder() { return LLVMIRBuilder; }
+
+  LoadInst *createLoadInst(llvm::LoadInst *LI);
+  friend LoadInst; // For createLoadInst()
+
 public:
-  Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx), IRTracker(*this) {}
+  Context(LLVMContext &LLVMCtx)
+      : LLVMCtx(LLVMCtx), IRTracker(*this),
+        LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}
 
   Tracker &getTracker() { return IRTracker; }
   /// Convenience function for `getTracker().save()`
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index b090ade3ea0ca..e1ed3cdac6bba 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -25,6 +25,7 @@ DEF_USER(Constant, Constant)
 #endif
 //       ClassID, Opcode(s),  Class
 DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
+DEF_INSTR(Load, OP(Load), LoadInst)
 
 #ifdef DEF_VALUE
 #undef DEF_VALUE
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 87134995a1538..f392704a6d27e 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -140,14 +140,14 @@ void Value::replaceAllUsesWith(Value *Other) {
 }
 
 #ifndef NDEBUG
-std::string Value::getName() const {
+std::string Value::getUid() const {
   std::stringstream SS;
   SS << "SB" << UID << ".";
   return SS.str();
 }
 
 void Value::dumpCommonHeader(raw_ostream &OS) const {
-  OS << getName() << " " << getSubclassIDStr(SubclassID) << " ";
+  OS << getUid() << " " << getSubclassIDStr(SubclassID) << " ";
 }
 
 void Value::dumpCommonFooter(raw_ostream &OS) const {
@@ -167,7 +167,7 @@ void Value::dumpCommonPrefix(raw_ostream &OS) const {
 }
 
 void Value::dumpCommonSuffix(raw_ostream &OS) const {
-  OS << " ; " << getName() << " (" << getSubclassIDStr(SubclassID) << ")";
+  OS << " ; " << getUid() << " (" << getSubclassIDStr(SubclassID) << ")";
 }
 
 void Value::printAsOperandCommon(raw_ostream &OS) const {
@@ -453,6 +453,49 @@ void Instruction::dump() const {
   dump(dbgs());
   dbgs() << "\n";
 }
+#endif // NDEBUG
+
+LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
+                           Instruction *InsertBefore, Context &Ctx,
+                           const Twine &Name) {
+  llvm::Instruction *BeforeIR = InsertBefore->getTopmostLLVMInstruction();
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(BeforeIR);
+  auto *NewLI = Builder.CreateAlignedLoad(Ty, Ptr->Val, Align,
+                                          /*isVolatile=*/false, Name);
+  auto *NewSBI = Ctx.createLoadInst(NewLI);
+  return NewSBI;
+}
+
+LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
+                           BasicBlock *InsertAtEnd, Context &Ctx,
+                           const Twine &Name) {
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
+  auto *NewLI = Builder.CreateAlignedLoad(Ty, Ptr->Val, Align,
+                                          /*isVolatile=*/false, Name);
+  auto *NewSBI = Ctx.createLoadInst(NewLI);
+  return NewSBI;
+}
+
+bool LoadInst::classof(const Value *From) {
+  return From->getSubclassID() == ClassID::Load;
+}
+
+Value *LoadInst::getPointerOperand() const {
+  return Ctx.getValue(cast<llvm::LoadInst>(Val)->getPointerOperand());
+}
+
+#ifndef NDEBUG
+void LoadInst::dump(raw_ostream &OS) const {
+  dumpCommonPrefix(OS);
+  dumpCommonSuffix(OS);
+}
+
+void LoadInst::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
 
 void OpaqueInst::dump(raw_ostream &OS) const {
   dumpCommonPrefix(OS);
@@ -538,8 +581,8 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
   assert(VPtr->getSubclassID() != Value::ClassID::User &&
          "Can't register a user!");
   Value *V = VPtr.get();
-  llvm::Value *Key = V->Val;
-  LLVMValueToValueMap[Key] = std::move(VPtr);
+  auto Pair = LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
+  assert(Pair.second && "Already exists!");
   return V;
 }
 
@@ -568,6 +611,17 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     return nullptr;
   }
   assert(isa<llvm::Instruction>(LLVMV) && "Expected Instruction");
+
+  switch (cast<llvm::Instruction>(LLVMV)->getOpcode()) {
+  case llvm::Instruction::Load: {
+    auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
+    It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
+    return It->second.get();
+  }
+  default:
+    break;
+  }
+
   It->second = std::unique_ptr<OpaqueInst>(
       new OpaqueInst(cast<llvm::Instruction>(LLVMV), *this));
   return It->second.get();
@@ -582,6 +636,11 @@ BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
   return BB;
 }
 
+LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
+  auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
+  return cast<LoadInst>(registerValue(std::move(NewPtr)));
+}
+
 Value *Context::getValue(llvm::Value *V) const {
   auto It = LLVMValueToValueMap.find(V);
   if (It != LLVMValueToValueMap.end())
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index ec68ed0afeb2f..04beb429502bc 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -560,3 +560,34 @@ define void @foo(i8 %v1) {
   EXPECT_EQ(I0->getNumUses(), 0u);
   EXPECT_EQ(I0->getNextNode(), Ret);
 }
+
+TEST_F(SandboxIRTest, LoadInst) {
+  parseIR(C, R"IR(
+define void @foo(ptr %arg0, ptr %arg1) {
+  %ld = load i8, ptr %arg0, align 64
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(LLVMF);
+  auto *Arg0 = F->getArg(0);
+  auto *Arg1 = F->getArg(1);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Ld = cast<sandboxir::LoadInst>(&*It++);
+  auto *Ret = &*It++;
+
+  // Check getPointerOperand()
+  EXPECT_EQ(Ld->getPointerOperand(), Arg0);
+  // Check getAlign()
+  EXPECT_EQ(Ld->getAlign(), 64);
+  // Check create(InsertBefore)
+  sandboxir::LoadInst *NewLd =
+      sandboxir::LoadInst::create(Ld->getType(), Arg1, Align(8),
+                                  /*InsertBefore=*/Ret, Ctx, "NewLd");
+  EXPECT_EQ(NewLd->getType(), Ld->getType());
+  EXPECT_EQ(NewLd->getPointerOperand(), Arg1);
+  EXPECT_EQ(NewLd->getAlign(), 8);
+  EXPECT_EQ(NewLd->getName(), "NewLd");
+}



More information about the llvm-commits mailing list