[llvm] [SandboxIR][NFC] Implement InsertPosition (PR #110730)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 1 12:56:53 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/110730

This patch implements the InsertPosition class that is used to specify where an instruction should be placed.

It also switches a couple of create() functions from the old API to the new one that uses InsertPosition.

>From d54ec4dbd769eec2d7fdcd27b927e256b85c3b59 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 1 Oct 2024 10:31:14 -0700
Subject: [PATCH] [SandboxIR][NFC] Implement InsertPosition

This patch implements the InsertPosition class that is used to specify
where an instruction should be placed.

It also switches a couple of create() functions from the old API
to the new one that uses InsertPosition.
---
 llvm/include/llvm/SandboxIR/Instruction.h  | 59 ++++++++++++++++++------
 llvm/lib/SandboxIR/Instruction.cpp         | 55 +++++------
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 10 ++--
 3 files changed, 61 insertions(+), 63 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Instruction.h b/llvm/include/llvm/SandboxIR/Instruction.h
index a34573a0bc1b01..cfa7d21f1401db 100644
--- a/llvm/include/llvm/SandboxIR/Instruction.h
+++ b/llvm/include/llvm/SandboxIR/Instruction.h
@@ -18,6 +18,31 @@
 
 namespace llvm::sandboxir {
 
+/// Describes where a newly created instruction should be placed: either
+/// before the instruction at a given BBIterator, or at the end of a
+/// BasicBlock (represented by that block's end() iterator). Both forms
+/// convert implicitly, so a create() function can take one InsertPosition.
+class InsertPosition {
+  BBIterator InsertAt;
+
+public:
+  /// Insert at the end of \p InsertAtEnd, which must not be null.
+  InsertPosition(BasicBlock *InsertAtEnd) {
+    assert(InsertAtEnd != nullptr && "Expected non-null!");
+    InsertAt = InsertAtEnd->end();
+  }
+  /// Insert before the instruction at \p InsertAt, or at the end of the
+  /// block if \p InsertAt is that block's end().
+  InsertPosition(BBIterator InsertAt) : InsertAt(InsertAt) {}
+  operator BBIterator() const { return InsertAt; }
+  const BBIterator &getIterator() const { return InsertAt; }
+  /// \Returns the instruction at this position. Must not be called when the
+  /// position is the end() of its block.
+  Instruction &operator*() { return *InsertAt; }
+  /// \Returns the block that this position refers to.
+  BasicBlock *getBasicBlock() const { return InsertAt.getNodeParent(); }
+};
+
 /// A sandboxir::User with operands, opcode and linked with previous/next
 /// instructions in an instruction list.
 class Instruction : public User {
@@ -79,6 +104,23 @@ class Instruction : public User {
   virtual SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const = 0;
   friend class EraseFromParent; // For getLLVMInstrs().
 
+  /// Helper function for create(). It sets the insert point of the
+  /// context's LLVM IRBuilder according to \p Pos and returns that builder.
+  /// If \p Pos points to an instruction, the builder inserts before that
+  /// instruction's topmost LLVM IR instruction; otherwise it inserts at the
+  /// end of \p Pos's block.
+  static IRBuilder<> &setInsertPos(InsertPosition Pos) {
+    auto *WhereBB = Pos.getBasicBlock();
+    auto WhereIt = Pos.getIterator();
+    auto &Ctx = WhereBB->getContext();
+    auto &Builder = Ctx.getLLVMIRBuilder();
+    if (WhereIt != WhereBB->end())
+      Builder.SetInsertPoint((*Pos).getTopmostLLVMInstruction());
+    else
+      Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+    return Builder;
+  }
+
 public:
   static const char *getOpcodeName(Opcode Opc);
   /// This is used by BasicBlock::iterator.
@@ -398,8 +440,8 @@ class FenceInst : public SingleLLVMInstructionImpl<llvm::FenceInst> {
   friend Context; // For constructor;
 
 public:
-  static FenceInst *create(AtomicOrdering Ordering, BBIterator WhereIt,
-                           BasicBlock *WhereBB, Context &Ctx,
+  static FenceInst *create(AtomicOrdering Ordering, InsertPosition Pos,
+                           Context &Ctx,
                            SyncScope::ID SSID = SyncScope::System);
   /// Returns the ordering constraint of this fence instruction.
   AtomicOrdering getOrdering() const {
@@ -425,16 +467,10 @@ class SelectInst : public SingleLLVMInstructionImpl<llvm::SelectInst> {
   SelectInst(llvm::SelectInst *CI, Context &Ctx)
       : SingleLLVMInstructionImpl(ClassID::Select, Opcode::Select, CI, Ctx) {}
   friend Context; // for SelectInst()
-  static Value *createCommon(Value *Cond, Value *True, Value *False,
-                             const Twine &Name, IRBuilder<> &Builder,
-                             Context &Ctx);
 
 public:
   static Value *create(Value *Cond, Value *True, Value *False,
-                       Instruction *InsertBefore, Context &Ctx,
-                       const Twine &Name = "");
-  static Value *create(Value *Cond, Value *True, Value *False,
-                       BasicBlock *InsertAtEnd, Context &Ctx,
+                       InsertPosition Pos, Context &Ctx,
                        const Twine &Name = "");
 
   const Value *getCondition() const { return getOperand(0); }
@@ -471,10 +507,7 @@ class InsertElementInst final
 
 public:
   static Value *create(Value *Vec, Value *NewElt, Value *Idx,
-                       Instruction *InsertBefore, Context &Ctx,
-                       const Twine &Name = "");
-  static Value *create(Value *Vec, Value *NewElt, Value *Idx,
-                       BasicBlock *InsertAtEnd, Context &Ctx,
+                       InsertPosition Pos, Context &Ctx,
                        const Twine &Name = "");
   static bool classof(const Value *From) {
     return From->getSubclassID() == ClassID::InsertElement;
diff --git a/llvm/lib/SandboxIR/Instruction.cpp b/llvm/lib/SandboxIR/Instruction.cpp
index 276c4f0872b109..0437ea3d6009f5 100644
--- a/llvm/lib/SandboxIR/Instruction.cpp
+++ b/llvm/lib/SandboxIR/Instruction.cpp
@@ -315,14 +315,9 @@ FreezeInst *FreezeInst::create(Value *V, BBIterator WhereIt,
   return Ctx.createFreezeInst(LLVMI);
 }
 
-FenceInst *FenceInst::create(AtomicOrdering Ordering, BBIterator WhereIt,
-                             BasicBlock *WhereBB, Context &Ctx,
-                             SyncScope::ID SSID) {
-  auto &Builder = Ctx.getLLVMIRBuilder();
-  if (WhereIt != WhereBB->end())
-    Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
-  else
-    Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
+FenceInst *FenceInst::create(AtomicOrdering Ordering, InsertPosition Pos,
+                             Context &Ctx, SyncScope::ID SSID) {
+  auto &Builder = Instruction::setInsertPos(Pos);
   llvm::FenceInst *LLVMI = Builder.CreateFence(Ordering, SSID);
   return Ctx.createFenceInst(LLVMI);
 }
@@ -342,9 +337,10 @@ void FenceInst::setSyncScopeID(SyncScope::ID SSID) {
   cast<llvm::FenceInst>(Val)->setSyncScopeID(SSID);
 }
 
-Value *SelectInst::createCommon(Value *Cond, Value *True, Value *False,
-                                const Twine &Name, IRBuilder<> &Builder,
-                                Context &Ctx) {
+Value *SelectInst::create(Value *Cond, Value *True, Value *False,
+                          InsertPosition Pos, Context &Ctx, const Twine &Name) {
+  auto &Builder = Instruction::setInsertPos(Pos);
+  // IRBuilder may constant-fold the select, so NewV may be a Constant.
   llvm::Value *NewV =
       Builder.CreateSelect(Cond->Val, True->Val, False->Val, Name);
   if (auto *NewSI = dyn_cast<llvm::SelectInst>(NewV))
@@ -353,24 +349,6 @@ Value *SelectInst::createCommon(Value *Cond, Value *True, Value *False,
   return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
 }
 
-Value *SelectInst::create(Value *Cond, Value *True, Value *False,
-                          Instruction *InsertBefore, Context &Ctx,
-                          const Twine &Name) {
-  llvm::Instruction *BeforeIR = InsertBefore->getTopmostLLVMInstruction();
-  auto &Builder = Ctx.getLLVMIRBuilder();
-  Builder.SetInsertPoint(BeforeIR);
-  return createCommon(Cond, True, False, Name, Builder, Ctx);
-}
-
-Value *SelectInst::create(Value *Cond, Value *True, Value *False,
-                          BasicBlock *InsertAtEnd, Context &Ctx,
-                          const Twine &Name) {
-  auto *IRInsertAtEnd = cast<llvm::BasicBlock>(InsertAtEnd->Val);
-  auto &Builder = Ctx.getLLVMIRBuilder();
-  Builder.SetInsertPoint(IRInsertAtEnd);
-  return createCommon(Cond, True, False, Name, Builder, Ctx);
-}
-
 void SelectInst::swapValues() {
   Ctx.getTracker().emplaceIfTracking<UseSwap>(getOperandUse(1),
                                               getOperandUse(2));
@@ -1791,23 +1769,10 @@ void PossiblyNonNegInst::setNonNeg(bool B) {
 }
 
 Value *InsertElementInst::create(Value *Vec, Value *NewElt, Value *Idx,
-                                 Instruction *InsertBefore, Context &Ctx,
+                                 InsertPosition Pos, Context &Ctx,
                                  const Twine &Name) {
-  auto &Builder = Ctx.getLLVMIRBuilder();
-  Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
-  llvm::Value *NewV =
-      Builder.CreateInsertElement(Vec->Val, NewElt->Val, Idx->Val, Name);
-  if (auto *NewInsert = dyn_cast<llvm::InsertElementInst>(NewV))
-    return Ctx.createInsertElementInst(NewInsert);
-  assert(isa<llvm::Constant>(NewV) && "Expected constant");
-  return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
-}
-
-Value *InsertElementInst::create(Value *Vec, Value *NewElt, Value *Idx,
-                                 BasicBlock *InsertAtEnd, Context &Ctx,
-                                 const Twine &Name) {
-  auto &Builder = Ctx.getLLVMIRBuilder();
-  Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
+  auto &Builder = Instruction::setInsertPos(Pos);
+  // IRBuilder may constant-fold the insertelement, so NewV may be a Constant.
   llvm::Value *NewV =
       Builder.CreateInsertElement(Vec->Val, NewElt->Val, Idx->Val, Name);
   if (auto *NewInsert = dyn_cast<llvm::InsertElementInst>(NewV))
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 7206ee34d36e3a..d1f80b43c4a60a 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -2169,7 +2169,7 @@ define void @foo() {
   // Check create().
   auto *NewFence =
       sandboxir::FenceInst::create(AtomicOrdering::Release, Ret->getIterator(),
-                                   BB, Ctx, SyncScope::SingleThread);
+                                   Ctx, SyncScope::SingleThread);
   EXPECT_EQ(NewFence->getNextNode(), Ret);
   EXPECT_EQ(NewFence->getOrdering(), AtomicOrdering::Release);
   EXPECT_EQ(NewFence->getSyncScopeID(), SyncScope::SingleThread);
@@ -2224,7 +2224,7 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
   {
     // Check SelectInst::create() InsertBefore.
     auto *NewSel = cast<sandboxir::SelectInst>(sandboxir::SelectInst::create(
-        Cond0, V0, V1, /*InsertBefore=*/Ret, Ctx));
+        Cond0, V0, V1, /*Pos=*/Ret->getIterator(), Ctx));
     EXPECT_EQ(NewSel->getCondition(), Cond0);
     EXPECT_EQ(NewSel->getTrueValue(), V0);
     EXPECT_EQ(NewSel->getFalseValue(), V1);
@@ -2246,8 +2246,8 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
     auto *FortyTwo =
         sandboxir::ConstantInt::get(sandboxir::Type::getInt1Ty(Ctx), 42,
                                     /*IsSigned=*/false);
-    auto *NewSel =
-        sandboxir::SelectInst::create(False, FortyTwo, FortyTwo, Ret, Ctx);
+    auto *NewSel = sandboxir::SelectInst::create(False, FortyTwo, FortyTwo,
+                                                 Ret->getIterator(), Ctx);
     EXPECT_TRUE(isa<sandboxir::Constant>(NewSel));
     EXPECT_EQ(NewSel, FortyTwo);
   }
@@ -2325,7 +2325,7 @@ define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) {
   auto *Idx = Ins0->getOperand(2);
   auto *NewI1 =
       cast<sandboxir::InsertElementInst>(sandboxir::InsertElementInst::create(
-          Poison, Arg0, Idx, Ret, Ctx, "NewIns1"));
+          Poison, Arg0, Idx, Ret->getIterator(), Ctx, "NewIns1"));
   EXPECT_EQ(NewI1->getOperand(0), Poison);
   EXPECT_EQ(NewI1->getNextNode(), Ret);
 



More information about the llvm-commits mailing list