[llvm] 7467f41 - [SandboxIR] Implement ReturnInst (#99784)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 22 15:48:45 PDT 2024


Author: vporpo
Date: 2024-07-22T15:48:42-07:00
New Revision: 7467f41a7d4bc2e305fb368c591790936ad5ef33

URL: https://github.com/llvm/llvm-project/commit/7467f41a7d4bc2e305fb368c591790936ad5ef33
DIFF: https://github.com/llvm/llvm-project/commit/7467f41a7d4bc2e305fb368c591790936ad5ef33.diff

LOG: [SandboxIR] Implement ReturnInst (#99784)

This patch adds the implementation of the SandboxIR ReturnInst, which
mirrors llvm::ReturnInst.

Added: 
    

Modified: 
    llvm/include/llvm/SandboxIR/SandboxIR.h
    llvm/include/llvm/SandboxIR/SandboxIRValues.def
    llvm/lib/SandboxIR/SandboxIR.cpp
    llvm/unittests/SandboxIR/SandboxIRTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index dfffe5c96f1cf..cd77897ccbb94 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -40,7 +40,7 @@
 //                                      |
 //                                      +- PHINode
 //                                      |
-//                                      +- RetInst
+//                                      +- ReturnInst
 //                                      |
 //                                      +- SelectInst
 //                                      |
@@ -76,6 +76,7 @@ class Context;
 class Function;
 class Instruction;
 class LoadInst;
+class ReturnInst;
 class StoreInst;
 class User;
 class Value;
@@ -173,11 +174,12 @@ class Value {
   /// order.
   llvm::Value *Val = nullptr;
 
-  friend class Context;   // For getting `Val`.
-  friend class User;      // For getting `Val`.
-  friend class Use;       // For getting `Val`.
-  friend class LoadInst;  // For getting `Val`.
-  friend class StoreInst; // For getting `Val`.
+  friend class Context;    // For getting `Val`.
+  friend class User;       // For getting `Val`.
+  friend class Use;        // For getting `Val`.
+  friend class LoadInst;   // For getting `Val`.
+  friend class StoreInst;  // For getting `Val`.
+  friend class ReturnInst; // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -497,8 +499,9 @@ class Instruction : public sandboxir::User {
   /// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
   /// returns its topmost LLVM IR instruction.
   llvm::Instruction *getTopmostLLVMInstruction() const;
-  friend class LoadInst;  // For getTopmostLLVMInstruction().
-  friend class StoreInst; // For getTopmostLLVMInstruction().
+  friend class LoadInst;   // For getTopmostLLVMInstruction().
+  friend class StoreInst;  // For getTopmostLLVMInstruction().
+  friend class ReturnInst; // For getTopmostLLVMInstruction().
 
   /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
   /// order.
@@ -639,6 +642,43 @@ class StoreInst final : public Instruction {
 #endif
 };
 
+class ReturnInst final : public Instruction {
+  /// Use ReturnInst::create() instead of calling the constructor.
+  ReturnInst(llvm::ReturnInst *I, Context &Ctx)
+      : Instruction(ClassID::Ret, Opcode::Ret, I, Ctx) {}
+  friend class Context; // For accessing the constructor in create*()
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+  static ReturnInst *createCommon(Value *RetVal, IRBuilder<> &Builder,
+                                  Context &Ctx);
+
+public:
+  /// Creates `ret RetVal`, or `ret void` if \p RetVal is null, either before
+  /// \p InsertBefore or at the end of \p InsertAtEnd.
+  static ReturnInst *create(Value *RetVal, Instruction *InsertBefore,
+                            Context &Ctx);
+  static ReturnInst *create(Value *RetVal, BasicBlock *InsertAtEnd,
+                            Context &Ctx);
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::Ret;
+  }
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+  /// \Returns the value being returned, or null for a `ret void`.
+  Value *getReturnValue() const;
+#ifndef NDEBUG
+  void verify() const final {}
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
 /// An LLLVM Instruction that has no SandboxIR equivalent class gets mapped to
 /// an OpaqueInstr.
 class OpaqueInst : public sandboxir::Instruction {
@@ -776,6 +816,8 @@ class Context {
   friend LoadInst; // For createLoadInst()
   StoreInst *createStoreInst(llvm::StoreInst *SI);
   friend StoreInst; // For createStoreInst()
+  ReturnInst *createReturnInst(llvm::ReturnInst *I);
+  friend ReturnInst; // For createReturnInst()
 
 public:
   Context(LLVMContext &LLVMCtx)

diff  --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 90365ca7a1c45..b2f88741af8d9 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -27,6 +27,7 @@ DEF_USER(Constant, Constant)
 DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
 DEF_INSTR(Load, OP(Load), LoadInst)
 DEF_INSTR(Store, OP(Store), StoreInst)
+DEF_INSTR(Ret, OP(Ret), ReturnInst)
 
 #ifdef DEF_VALUE
 #undef DEF_VALUE

diff  --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 209b677bafbb5..4cf45fa87693a 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -540,6 +540,48 @@ void StoreInst::dump() const {
   dump(dbgs());
   dbgs() << "\n";
 }
+#endif // NDEBUG
+
+ReturnInst *ReturnInst::createCommon(Value *RetVal, IRBuilder<> &Builder,
+                                     Context &Ctx) {
+  // Both create() overloads position Builder first. A null RetVal means the
+  // caller wants `ret void`; otherwise return the LLVM value RetVal wraps.
+  llvm::ReturnInst *LLVMRet = RetVal == nullptr
+                                  ? Builder.CreateRetVoid()
+                                  : Builder.CreateRet(RetVal->Val);
+  return Ctx.createReturnInst(LLVMRet);
+}
+
+ReturnInst *ReturnInst::create(Value *RetVal, Instruction *InsertBefore,
+                               Context &Ctx) {
+  // Emit right above the topmost LLVM IR instruction InsertBefore maps to.
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
+  return createCommon(RetVal, Builder, Ctx);
+}
+
+ReturnInst *ReturnInst::create(Value *RetVal, BasicBlock *InsertAtEnd,
+                               Context &Ctx) {
+  IRBuilder<> &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val)); // Append.
+  return createCommon(RetVal, Builder, Ctx);
+}
+
+Value *ReturnInst::getReturnValue() const {
+  llvm::Value *LLVMRV = cast<llvm::ReturnInst>(Val)->getReturnValue();
+  return LLVMRV == nullptr ? nullptr : Ctx.getValue(LLVMRV);
+}
+
+#ifndef NDEBUG
+void ReturnInst::dump(raw_ostream &OS) const {
+  dumpCommonPrefix(OS);
+  dumpCommonSuffix(OS);
+}
+// Same as dump(raw_ostream &), but prints to dbgs() followed by a newline.
+void ReturnInst::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
 
 void OpaqueInst::dump(raw_ostream &OS) const {
   dumpCommonPrefix(OS);
@@ -626,7 +668,7 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
          "Can't register a user!");
   Value *V = VPtr.get();
   [[maybe_unused]] auto Pair =
-         LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
+      LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
   assert(Pair.second && "Already exists!");
   return V;
 }
@@ -668,6 +710,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this));
     return It->second.get();
   }
+  case llvm::Instruction::Ret: {
+    auto *LLVMRet = cast<llvm::ReturnInst>(LLVMV);
+    It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
+    return It->second.get();
+  }
   default:
     break;
   }
@@ -696,6 +743,11 @@ StoreInst *Context::createStoreInst(llvm::StoreInst *SI) {
   return cast<StoreInst>(registerValue(std::move(NewPtr)));
 }
 
+ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
+  std::unique_ptr<ReturnInst> NewRet(new ReturnInst(I, *this));
+  return cast<ReturnInst>(registerValue(std::move(NewRet)));
+}
+
 Value *Context::getValue(llvm::Value *V) const {
   auto It = LLVMValueToValueMap.find(V);
   if (It != LLVMValueToValueMap.end())

diff  --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 054a81e9cf308..b0d6ae85950d7 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -132,7 +132,7 @@ define i32 @foo(i32 %v0, i32 %v1) {
   auto *Arg1 = F.getArg(1);
   auto It = BB.begin();
   auto *I0 = &*It++;
-  auto *Ret = &*It++;
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
   SmallVector<sandboxir::Argument *> Args{Arg0, Arg1};
   unsigned OpIdx = 0;
@@ -245,7 +245,7 @@ define i32 @foo(i32 %arg0, i32 %arg1) {
   auto *I0 = &*It++;
   auto *I1 = &*It++;
   auto *I2 = &*It++;
-  auto *Ret = &*It++;
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
   bool Replaced;
   // Try to replace an operand that doesn't match.
@@ -401,7 +401,7 @@ void @foo(i32 %arg0, i32 %arg1) {
   br label %bb1 ; SB3. (Opaque)
 
 bb1:
-  ret void ; SB5. (Opaque)
+  ret void ; SB5. (Ret)
 }
 )IR");
   }
@@ -488,7 +488,7 @@ define void @foo(i8 %v1) {
   auto It = BB->begin();
   auto *I0 = &*It++;
   auto *I1 = &*It++;
-  auto *Ret = &*It++;
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
   // Check getPrevNode().
   EXPECT_EQ(Ret->getPrevNode(), I1);
@@ -508,7 +508,7 @@ define void @foo(i8 %v1) {
   // Check getOpcode().
   EXPECT_EQ(I0->getOpcode(), sandboxir::Instruction::Opcode::Opaque);
   EXPECT_EQ(I1->getOpcode(), sandboxir::Instruction::Opcode::Opaque);
-  EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Opaque);
+  EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Ret);
 
   // Check moveBefore(I).
   I1->moveBefore(I0);
@@ -576,7 +576,7 @@ define void @foo(ptr %arg0, ptr %arg1) {
   auto *BB = &*F->begin();
   auto It = BB->begin();
   auto *Ld = cast<sandboxir::LoadInst>(&*It++);
-  auto *Ret = &*It++;
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
   // Check getPointerOperand()
   EXPECT_EQ(Ld->getPointerOperand(), Arg0);
@@ -607,7 +607,7 @@ define void @foo(i8 %val, ptr %ptr) {
   auto *BB = &*F->begin();
   auto It = BB->begin();
   auto *St = cast<sandboxir::StoreInst>(&*It++);
-  auto *Ret = &*It++;
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
   // Check that the StoreInst has been created correctly.
   // Check getPointerOperand()
@@ -624,3 +624,44 @@ define void @foo(i8 %val, ptr %ptr) {
   EXPECT_EQ(NewSt->getPointerOperand(), Ptr);
   EXPECT_EQ(NewSt->getAlign(), 8);
 }
+
+TEST_F(SandboxIRTest, ReturnInst) {
+  parseIR(C, R"IR(
+define i8 @foo(i8 %val) {
+  %add = add i8 %val, 42
+  ret i8 %val
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  sandboxir::Function *F = Ctx.createFunction(LLVMF);
+  auto *Val = F->getArg(0);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  It++; // Skip %add.
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  // Check that the ReturnInst has been created correctly.
+  // Check getReturnValue().
+  EXPECT_EQ(Ret->getReturnValue(), Val);
+
+  // Check create(InsertBefore) a void ReturnInst.
+  auto *NewRet1 =
+      sandboxir::ReturnInst::create(nullptr, /*InsertBefore=*/Ret, Ctx);
+  EXPECT_EQ(NewRet1->getReturnValue(), nullptr);
+  EXPECT_EQ(Ret->getPrevNode(), NewRet1);
+  // Check create(InsertBefore) a non-void ReturnInst.
+  auto *NewRet2 = sandboxir::ReturnInst::create(Val, /*InsertBefore=*/Ret, Ctx);
+  EXPECT_EQ(NewRet2->getReturnValue(), Val);
+  EXPECT_EQ(Ret->getPrevNode(), NewRet2);
+
+  // Check create(InsertAtEnd) a void ReturnInst.
+  auto *NewRet3 =
+      sandboxir::ReturnInst::create(nullptr, /*InsertAtEnd=*/BB, Ctx);
+  EXPECT_EQ(NewRet3->getReturnValue(), nullptr);
+  EXPECT_EQ(NewRet3->getPrevNode(), Ret);
+  // Check create(InsertAtEnd) a non-void ReturnInst.
+  auto *NewRet4 = sandboxir::ReturnInst::create(Val, /*InsertAtEnd=*/BB, Ctx);
+  EXPECT_EQ(NewRet4->getReturnValue(), Val);
+  EXPECT_EQ(NewRet4->getPrevNode(), NewRet3);
+}


        


More information about the llvm-commits mailing list