[llvm] [SandboxIR] Add the ExtractElementInst class (PR #102706)

Jorge Gorbe Moya via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 9 18:14:49 PDT 2024


https://github.com/slackito updated https://github.com/llvm/llvm-project/pull/102706

>From 68f7500fdcbf0f0eb52697bc706ca53994f941e5 Mon Sep 17 00:00:00 2001
From: Jorge Gorbe Moya <jgorbe at google.com>
Date: Fri, 9 Aug 2024 17:09:44 -0700
Subject: [PATCH] [SandboxIR] Add the ExtractElementInst class

---
 llvm/include/llvm/SandboxIR/SandboxIR.h       | 121 +++++++++++++-----
 .../llvm/SandboxIR/SandboxIRValues.def        |  61 ++++-----
 llvm/lib/SandboxIR/SandboxIR.cpp              |  49 +++++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    |  42 ++++++
 4 files changed, 211 insertions(+), 62 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 4c452ce0b4a61b..8a2e3aa630c62e 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -111,6 +111,7 @@ class Context;
 class Function;
 class Instruction;
 class SelectInst;
+class ExtractElementInst;
 class InsertElementInst;
 class BranchInst;
 class UnaryInstruction;
@@ -232,24 +233,25 @@ class Value {
   /// order.
   llvm::Value *Val = nullptr;
 
-  friend class Context;           // For getting `Val`.
-  friend class User;              // For getting `Val`.
-  friend class Use;               // For getting `Val`.
-  friend class SelectInst;        // For getting `Val`.
-  friend class InsertElementInst; // For getting `Val`.
-  friend class BranchInst;        // For getting `Val`.
-  friend class LoadInst;          // For getting `Val`.
-  friend class StoreInst;         // For getting `Val`.
-  friend class ReturnInst;        // For getting `Val`.
-  friend class CallBase;          // For getting `Val`.
-  friend class CallInst;          // For getting `Val`.
-  friend class InvokeInst;        // For getting `Val`.
-  friend class CallBrInst;        // For getting `Val`.
-  friend class GetElementPtrInst; // For getting `Val`.
-  friend class AllocaInst;        // For getting `Val`.
-  friend class CastInst;          // For getting `Val`.
-  friend class PHINode;           // For getting `Val`.
-  friend class UnreachableInst;   // For getting `Val`.
+  friend class Context;            // For getting `Val`.
+  friend class User;               // For getting `Val`.
+  friend class Use;                // For getting `Val`.
+  friend class SelectInst;         // For getting `Val`.
+  friend class ExtractElementInst; // For getting `Val`.
+  friend class InsertElementInst;  // For getting `Val`.
+  friend class BranchInst;         // For getting `Val`.
+  friend class LoadInst;           // For getting `Val`.
+  friend class StoreInst;          // For getting `Val`.
+  friend class ReturnInst;         // For getting `Val`.
+  friend class CallBase;           // For getting `Val`.
+  friend class CallInst;           // For getting `Val`.
+  friend class InvokeInst;         // For getting `Val`.
+  friend class CallBrInst;         // For getting `Val`.
+  friend class GetElementPtrInst;  // For getting `Val`.
+  friend class AllocaInst;         // For getting `Val`.
+  friend class CastInst;           // For getting `Val`.
+  friend class PHINode;            // For getting `Val`.
+  friend class UnreachableInst;    // For getting `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -635,20 +637,21 @@ class Instruction : public sandboxir::User {
   /// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
   /// returns its topmost LLVM IR instruction.
   llvm::Instruction *getTopmostLLVMInstruction() const;
-  friend class SelectInst;        // For getTopmostLLVMInstruction().
-  friend class InsertElementInst; // For getTopmostLLVMInstruction().
-  friend class BranchInst;        // For getTopmostLLVMInstruction().
-  friend class LoadInst;          // For getTopmostLLVMInstruction().
-  friend class StoreInst;         // For getTopmostLLVMInstruction().
-  friend class ReturnInst;        // For getTopmostLLVMInstruction().
-  friend class CallInst;          // For getTopmostLLVMInstruction().
-  friend class InvokeInst;        // For getTopmostLLVMInstruction().
-  friend class CallBrInst;        // For getTopmostLLVMInstruction().
-  friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
-  friend class AllocaInst;        // For getTopmostLLVMInstruction().
-  friend class CastInst;          // For getTopmostLLVMInstruction().
-  friend class PHINode;           // For getTopmostLLVMInstruction().
-  friend class UnreachableInst;   // For getTopmostLLVMInstruction().
+  friend class SelectInst;         // For getTopmostLLVMInstruction().
+  friend class ExtractElementInst; // For getTopmostLLVMInstruction().
+  friend class InsertElementInst;  // For getTopmostLLVMInstruction().
+  friend class BranchInst;         // For getTopmostLLVMInstruction().
+  friend class LoadInst;           // For getTopmostLLVMInstruction().
+  friend class StoreInst;          // For getTopmostLLVMInstruction().
+  friend class ReturnInst;         // For getTopmostLLVMInstruction().
+  friend class CallInst;           // For getTopmostLLVMInstruction().
+  friend class InvokeInst;         // For getTopmostLLVMInstruction().
+  friend class CallBrInst;         // For getTopmostLLVMInstruction().
+  friend class GetElementPtrInst;  // For getTopmostLLVMInstruction().
+  friend class AllocaInst;         // For getTopmostLLVMInstruction().
+  friend class CastInst;           // For getTopmostLLVMInstruction().
+  friend class PHINode;            // For getTopmostLLVMInstruction().
+  friend class UnreachableInst;    // For getTopmostLLVMInstruction().
 
   /// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
   /// order.
@@ -805,6 +808,74 @@ class InsertElementInst final : public Instruction {
 #endif
 };
 
+/// SandboxIR wrapper for an LLVM IR `extractelement` instruction, which reads
+/// a single element of a vector operand at a given index operand.
+class ExtractElementInst final : public Instruction {
+  /// Use Context::createExtractElementInst() instead.
+  /// Takes an llvm::ExtractElementInst (not a generic llvm::Instruction) so
+  /// that wrapping the wrong kind of LLVM instruction fails to compile.
+  ExtractElementInst(llvm::ExtractElementInst *EEI, Context &Ctx)
+      : Instruction(ClassID::ExtractElement, Opcode::ExtractElement, EEI, Ctx) {}
+  friend class Context; // For accessing the constructor in
+                        // create*()
+  Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
+    return getOperandUseDefault(OpIdx, Verify);
+  }
+  /// Maps to exactly one LLVM IR instruction: the wrapped extractelement.
+  SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
+    return {cast<llvm::Instruction>(Val)};
+  }
+
+public:
+  /// Creates an `extractelement` reading element \p Idx of \p Vec, inserted
+  /// right before \p InsertBefore. \Returns a Value rather than an
+  /// ExtractElementInst because the LLVM IRBuilder may fold the operation
+  /// into a constant, in which case the folded Constant is returned.
+  static Value *create(Value *Vec, Value *Idx, Instruction *InsertBefore,
+                       Context &Ctx, const Twine &Name = "");
+  /// Same as the overload above, except that the new instruction (if any) is
+  /// appended to the end of \p InsertAtEnd.
+  static Value *create(Value *Vec, Value *Idx, BasicBlock *InsertAtEnd,
+                       Context &Ctx, const Twine &Name = "");
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::ExtractElement;
+  }
+
+  /// \Returns true if \p Vec and \p Idx are valid operands for an
+  /// `extractelement`, as determined by
+  /// llvm::ExtractElementInst::isValidOperands().
+  static bool isValidOperands(const Value *Vec, const Value *Idx) {
+    return llvm::ExtractElementInst::isValidOperands(Vec->Val, Idx->Val);
+  }
+  /// Operand 0 is the vector being read, operand 1 is the element index.
+  Value *getVectorOperand() { return getOperand(0); }
+  Value *getIndexOperand() { return getOperand(1); }
+  const Value *getVectorOperand() const { return getOperand(0); }
+  const Value *getIndexOperand() const { return getOperand(1); }
+
+  /// \Returns the LLVM IR vector type of the vector operand.
+  llvm::VectorType *getVectorOperandType() const {
+    return cast<llvm::VectorType>(getVectorOperand()->getType());
+  }
+
+  unsigned getUseOperandNo(const Use &Use) const final {
+    return getUseOperandNoDefault(Use);
+  }
+  unsigned getNumOfIRInstrs() const final { return 1u; }
+#ifndef NDEBUG
+  void verify() const final {
+    assert(isa<llvm::ExtractElementInst>(Val) && "Expected ExtractElementInst");
+  }
+  friend raw_ostream &operator<<(raw_ostream &OS,
+                                 const ExtractElementInst &EEI) {
+    EEI.dump(OS);
+    return OS;
+  }
+  void dump(raw_ostream &OS) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+#endif
+};
+
 class BranchInst : public Instruction {
   /// Use Context::createBranchInst(). Don't call the constructor directly.
   BranchInst(llvm::BranchInst *BI, Context &Ctx)
@@ -1899,6 +1954,8 @@ class Context {
   friend SelectInst; // For createSelectInst()
   InsertElementInst *createInsertElementInst(llvm::InsertElementInst *IEI);
   friend InsertElementInst; // For createInsertElementInst()
+  ExtractElementInst *createExtractElementInst(llvm::ExtractElementInst *EEI);
+  friend ExtractElementInst; // For createExtractElementInst()
   BranchInst *createBranchInst(llvm::BranchInst *I);
   friend BranchInst; // For createBranchInst()
   LoadInst *createLoadInst(llvm::LoadInst *LI);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 269aea784dcec1..11f4f2e74712f4 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -32,36 +32,37 @@ DEF_USER(Constant, Constant)
 #define OPCODES(...)
 #endif
 // clang-format off
-//       ClassID,        Opcode(s),         Class
-DEF_INSTR(Opaque,        OP(Opaque),        OpaqueInst)
-DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
-DEF_INSTR(Select,        OP(Select),        SelectInst)
-DEF_INSTR(Br,            OP(Br),            BranchInst)
-DEF_INSTR(Load,          OP(Load),          LoadInst)
-DEF_INSTR(Store,         OP(Store),         StoreInst)
-DEF_INSTR(Ret,           OP(Ret),           ReturnInst)
-DEF_INSTR(Call,          OP(Call),          CallInst)
-DEF_INSTR(Invoke,        OP(Invoke),        InvokeInst)
-DEF_INSTR(CallBr,        OP(CallBr),        CallBrInst)
-DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
-DEF_INSTR(Alloca,        OP(Alloca),        AllocaInst)
-DEF_INSTR(Cast,  OPCODES(\
-                         OP(ZExt)          \
-                         OP(SExt)          \
-                         OP(FPToUI)        \
-                         OP(FPToSI)        \
-                         OP(FPExt)         \
-                         OP(PtrToInt)      \
-                         OP(IntToPtr)      \
-                         OP(SIToFP)        \
-                         OP(UIToFP)        \
-                         OP(Trunc)         \
-                         OP(FPTrunc)       \
-                         OP(BitCast)       \
-                         OP(AddrSpaceCast) \
-                         ),                 CastInst)
-DEF_INSTR(PHI,           OP(PHI),           PHINode)
-DEF_INSTR(Unreachable,   OP(Unreachable),   UnreachableInst)
+//        ClassID,        Opcode(s),          Class
+DEF_INSTR(Opaque,         OP(Opaque),         OpaqueInst)
+DEF_INSTR(ExtractElement, OP(ExtractElement), ExtractElementInst)
+DEF_INSTR(InsertElement,  OP(InsertElement),  InsertElementInst)
+DEF_INSTR(Select,         OP(Select),         SelectInst)
+DEF_INSTR(Br,             OP(Br),             BranchInst)
+DEF_INSTR(Load,           OP(Load),           LoadInst)
+DEF_INSTR(Store,          OP(Store),          StoreInst)
+DEF_INSTR(Ret,            OP(Ret),            ReturnInst)
+DEF_INSTR(Call,           OP(Call),           CallInst)
+DEF_INSTR(Invoke,         OP(Invoke),         InvokeInst)
+DEF_INSTR(CallBr,         OP(CallBr),         CallBrInst)
+DEF_INSTR(GetElementPtr,  OP(GetElementPtr),  GetElementPtrInst)
+DEF_INSTR(Alloca,         OP(Alloca),         AllocaInst)
+DEF_INSTR(Cast,   OPCODES(\
+                          OP(ZExt)          \
+                          OP(SExt)          \
+                          OP(FPToUI)        \
+                          OP(FPToSI)        \
+                          OP(FPExt)         \
+                          OP(PtrToInt)      \
+                          OP(IntToPtr)      \
+                          OP(SIToFP)        \
+                          OP(UIToFP)        \
+                          OP(Trunc)         \
+                          OP(FPTrunc)       \
+                          OP(BitCast)       \
+                          OP(AddrSpaceCast) \
+                          ),                  CastInst)
+DEF_INSTR(PHI,            OP(PHI),            PHINode)
+DEF_INSTR(Unreachable,    OP(Unreachable),    UnreachableInst)
 
 // clang-format on
 #ifdef DEF_VALUE
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index fd74e7f5eaf62e..2b59d7844cd548 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1410,6 +1410,57 @@ void InsertElementInst::dump() const {
 }
 #endif // NDEBUG
 
+/// Builds an `extractelement` of \p Vec at index \p Idx, inserted right
+/// before the topmost LLVM IR instruction of \p InsertBefore. The IRBuilder
+/// may fold the operation into a constant, in which case the folded constant
+/// is returned instead of a new ExtractElementInst.
+Value *ExtractElementInst::create(Value *Vec, Value *Idx,
+                                  Instruction *InsertBefore, Context &Ctx,
+                                  const Twine &Name) {
+  llvm::Instruction *BeforeI = InsertBefore->getTopmostLLVMInstruction();
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(BeforeI);
+  llvm::Value *NewV = Builder.CreateExtractElement(Vec->Val, Idx->Val, Name);
+  auto *NewExtract = dyn_cast<llvm::ExtractElementInst>(NewV);
+  if (NewExtract == nullptr) {
+    // Folded: wrap the resulting constant instead of an instruction.
+    assert(isa<llvm::Constant>(NewV) && "Expected constant");
+    return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
+  }
+  return Ctx.createExtractElementInst(NewExtract);
+}
+
+/// Same as the overload above, except that the new instruction (if any) is
+/// appended to the end of \p InsertAtEnd.
+Value *ExtractElementInst::create(Value *Vec, Value *Idx,
+                                  BasicBlock *InsertAtEnd, Context &Ctx,
+                                  const Twine &Name) {
+  auto *LLVMBB = cast<llvm::BasicBlock>(InsertAtEnd->Val);
+  auto &Builder = Ctx.getLLVMIRBuilder();
+  Builder.SetInsertPoint(LLVMBB);
+  llvm::Value *NewV = Builder.CreateExtractElement(Vec->Val, Idx->Val, Name);
+  auto *NewExtract = dyn_cast<llvm::ExtractElementInst>(NewV);
+  if (NewExtract == nullptr) {
+    // Folded: wrap the resulting constant instead of an instruction.
+    assert(isa<llvm::Constant>(NewV) && "Expected constant");
+    return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
+  }
+  return Ctx.createExtractElementInst(NewExtract);
+}
+
+#ifndef NDEBUG
+void ExtractElementInst::dump(raw_ostream &OS) const {
+  dumpCommonPrefix(OS);
+  dumpCommonSuffix(OS);
+}
+
+void ExtractElementInst::dump() const {
+  raw_ostream &OS = dbgs();
+  dump(OS);
+  OS << "\n";
+}
+#endif // NDEBUG
+
 Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx,
                               bool IsSigned) {
   llvm::Constant *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned);
@@ -1540,6 +1576,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
     return It->second.get();
   }
+  case llvm::Instruction::ExtractElement: {
+    auto *LLVMIns = cast<llvm::ExtractElementInst>(LLVMV);
+    It->second = std::unique_ptr<ExtractElementInst>(
+        new ExtractElementInst(LLVMIns, *this));
+    return It->second.get();
+  }
   case llvm::Instruction::InsertElement: {
     auto *LLVMIns = cast<llvm::InsertElementInst>(LLVMV);
     It->second = std::unique_ptr<InsertElementInst>(
@@ -1643,6 +1685,17 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
   return cast<SelectInst>(registerValue(std::move(NewPtr)));
 }
 
+/// Wraps \p EEI in a new sandboxir::ExtractElementInst and hands ownership of
+/// it to registerValue().
+ExtractElementInst *
+Context::createExtractElementInst(llvm::ExtractElementInst *EEI) {
+  // The constructor is private, so std::make_unique cannot be used here.
+  std::unique_ptr<ExtractElementInst> NewPtr(
+      new ExtractElementInst(EEI, *this));
+  auto *Registered = registerValue(std::move(NewPtr));
+  return cast<ExtractElementInst>(Registered);
+}
+
 InsertElementInst *
 Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
   auto NewPtr =
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 3e52b05ad2e94c..7af1c0b060a873 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -631,6 +631,66 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
   }
 }
 
+TEST_F(SandboxIRTest, ExtractElementInst) {
+  parseIR(C, R"IR(
+define void @foo(<2 x i8> %vec, i32 %idx) {
+  %ins0 = extractelement <2 x i8> %vec, i32 %idx
+  %ins1 = extractelement <2 x i8> <i8 1, i8 2>, i32 0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *ArgVec = F.getArg(0);
+  auto *ArgIdx = F.getArg(1);
+  auto *LLVMArgVec = LLVMF.getArg(0);
+  auto *LLVMArgIdx = LLVMF.getArg(1);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *Ins0 = cast<sandboxir::ExtractElementInst>(&*It++);
+  auto *Ins1 = cast<sandboxir::ExtractElementInst>(&*It++);
+  auto *Ret = &*It++;
+
+  // Check the opcode and the operand accessors.
+  EXPECT_EQ(Ins0->getOpcode(), sandboxir::Instruction::Opcode::ExtractElement);
+  EXPECT_EQ(Ins0->getOperand(0), ArgVec);
+  EXPECT_EQ(Ins0->getOperand(1), ArgIdx);
+  EXPECT_EQ(Ins0->getVectorOperand(), ArgVec);
+  EXPECT_EQ(Ins0->getIndexOperand(), ArgIdx);
+  EXPECT_EQ(Ins0->getVectorOperandType(), LLVMArgVec->getType());
+
+  // Check create() with an InsertBefore instruction.
+  auto *NewI1 =
+      cast<sandboxir::ExtractElementInst>(sandboxir::ExtractElementInst::create(
+          ArgVec, ArgIdx, Ret, Ctx, "NewInsBeforeRet"));
+  EXPECT_EQ(NewI1->getOperand(0), ArgVec);
+  EXPECT_EQ(NewI1->getOperand(1), ArgIdx);
+  EXPECT_EQ(NewI1->getNextNode(), Ret);
+
+  // Check create() with an InsertAtEnd basic block.
+  auto *NewI2 =
+      cast<sandboxir::ExtractElementInst>(sandboxir::ExtractElementInst::create(
+          ArgVec, ArgIdx, BB, Ctx, "NewInsAtEndOfBB"));
+  EXPECT_EQ(NewI2->getPrevNode(), Ret);
+
+  // Check that create() returns a Constant when the IRBuilder folds the
+  // extract, and that no new instruction is inserted in that case.
+  sandboxir::Value *Folded = sandboxir::ExtractElementInst::create(
+      Ins1->getVectorOperand(), Ins1->getIndexOperand(), Ret, Ctx,
+      "NewFolded");
+  EXPECT_TRUE(isa<sandboxir::Constant>(Folded));
+  EXPECT_EQ(Ret->getPrevNode(), NewI1);
+
+  // Check isValidOperands().
+  EXPECT_TRUE(sandboxir::ExtractElementInst::isValidOperands(ArgVec, ArgIdx));
+  EXPECT_FALSE(sandboxir::ExtractElementInst::isValidOperands(ArgIdx, ArgVec));
+  EXPECT_EQ(sandboxir::ExtractElementInst::isValidOperands(ArgVec, ArgIdx),
+            llvm::ExtractElementInst::isValidOperands(LLVMArgVec, LLVMArgIdx));
+  EXPECT_EQ(sandboxir::ExtractElementInst::isValidOperands(ArgIdx, ArgVec),
+            llvm::ExtractElementInst::isValidOperands(LLVMArgIdx, LLVMArgVec));
+}
+
 TEST_F(SandboxIRTest, InsertElementInst) {
   parseIR(C, R"IR(
 define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) {



More information about the llvm-commits mailing list