[llvm] 15aa4ef - [SandboxIR] Add the ExtractElementInst class (#102706)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 12 12:53:12 PDT 2024
Author: Jorge Gorbe Moya
Date: 2024-08-12T12:53:07-07:00
New Revision: 15aa4ef057438df5bae8aaf7ff07b31dfcc1ef77
URL: https://github.com/llvm/llvm-project/commit/15aa4ef057438df5bae8aaf7ff07b31dfcc1ef77
DIFF: https://github.com/llvm/llvm-project/commit/15aa4ef057438df5bae8aaf7ff07b31dfcc1ef77.diff
LOG: [SandboxIR] Add the ExtractElementInst class (#102706)
Added:
Modified:
llvm/include/llvm/SandboxIR/SandboxIR.h
llvm/include/llvm/SandboxIR/SandboxIRValues.def
llvm/lib/SandboxIR/SandboxIR.cpp
llvm/unittests/SandboxIR/SandboxIRTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index cf1246951ecc18..c160520788d873 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -111,6 +111,7 @@ class Context;
class Function;
class Instruction;
class SelectInst;
+class ExtractElementInst;
class InsertElementInst;
class BranchInst;
class UnaryInstruction;
@@ -232,24 +233,25 @@ class Value {
/// order.
llvm::Value *Val = nullptr;
- friend class Context; // For getting `Val`.
- friend class User; // For getting `Val`.
- friend class Use; // For getting `Val`.
- friend class SelectInst; // For getting `Val`.
- friend class InsertElementInst; // For getting `Val`.
- friend class BranchInst; // For getting `Val`.
- friend class LoadInst; // For getting `Val`.
- friend class StoreInst; // For getting `Val`.
- friend class ReturnInst; // For getting `Val`.
- friend class CallBase; // For getting `Val`.
- friend class CallInst; // For getting `Val`.
- friend class InvokeInst; // For getting `Val`.
- friend class CallBrInst; // For getting `Val`.
- friend class GetElementPtrInst; // For getting `Val`.
- friend class AllocaInst; // For getting `Val`.
- friend class CastInst; // For getting `Val`.
- friend class PHINode; // For getting `Val`.
- friend class UnreachableInst; // For getting `Val`.
+ friend class Context; // For getting `Val`.
+ friend class User; // For getting `Val`.
+ friend class Use; // For getting `Val`.
+ friend class SelectInst; // For getting `Val`.
+ friend class ExtractElementInst; // For getting `Val`.
+ friend class InsertElementInst; // For getting `Val`.
+ friend class BranchInst; // For getting `Val`.
+ friend class LoadInst; // For getting `Val`.
+ friend class StoreInst; // For getting `Val`.
+ friend class ReturnInst; // For getting `Val`.
+ friend class CallBase; // For getting `Val`.
+ friend class CallInst; // For getting `Val`.
+ friend class InvokeInst; // For getting `Val`.
+ friend class CallBrInst; // For getting `Val`.
+ friend class GetElementPtrInst; // For getting `Val`.
+ friend class AllocaInst; // For getting `Val`.
+ friend class CastInst; // For getting `Val`.
+ friend class PHINode; // For getting `Val`.
+ friend class UnreachableInst; // For getting `Val`.
/// All values point to the context.
Context &Ctx;
@@ -615,20 +617,21 @@ class Instruction : public sandboxir::User {
/// A SandboxIR Instruction may map to multiple LLVM IR Instruction. This
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
- friend class SelectInst; // For getTopmostLLVMInstruction().
- friend class InsertElementInst; // For getTopmostLLVMInstruction().
- friend class BranchInst; // For getTopmostLLVMInstruction().
- friend class LoadInst; // For getTopmostLLVMInstruction().
- friend class StoreInst; // For getTopmostLLVMInstruction().
- friend class ReturnInst; // For getTopmostLLVMInstruction().
- friend class CallInst; // For getTopmostLLVMInstruction().
- friend class InvokeInst; // For getTopmostLLVMInstruction().
- friend class CallBrInst; // For getTopmostLLVMInstruction().
- friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
- friend class AllocaInst; // For getTopmostLLVMInstruction().
- friend class CastInst; // For getTopmostLLVMInstruction().
- friend class PHINode; // For getTopmostLLVMInstruction().
- friend class UnreachableInst; // For getTopmostLLVMInstruction().
+ friend class SelectInst; // For getTopmostLLVMInstruction().
+ friend class ExtractElementInst; // For getTopmostLLVMInstruction().
+ friend class InsertElementInst; // For getTopmostLLVMInstruction().
+ friend class BranchInst; // For getTopmostLLVMInstruction().
+ friend class LoadInst; // For getTopmostLLVMInstruction().
+ friend class StoreInst; // For getTopmostLLVMInstruction().
+ friend class ReturnInst; // For getTopmostLLVMInstruction().
+ friend class CallInst; // For getTopmostLLVMInstruction().
+ friend class InvokeInst; // For getTopmostLLVMInstruction().
+ friend class CallBrInst; // For getTopmostLLVMInstruction().
+ friend class GetElementPtrInst; // For getTopmostLLVMInstruction().
+ friend class AllocaInst; // For getTopmostLLVMInstruction().
+ friend class CastInst; // For getTopmostLLVMInstruction().
+ friend class PHINode; // For getTopmostLLVMInstruction().
+ friend class UnreachableInst; // For getTopmostLLVMInstruction().
/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
@@ -768,6 +771,37 @@ class InsertElementInst final
}
};
+class ExtractElementInst final
+ : public SingleLLVMInstructionImpl<llvm::ExtractElementInst> { // Wraps one llvm::ExtractElementInst.
+ /// Use create() or Context::createExtractElementInst() instead.
+ ExtractElementInst(llvm::Instruction *I, Context &Ctx)
+ : SingleLLVMInstructionImpl(ClassID::ExtractElement,
+ Opcode::ExtractElement, I, Ctx) {}
+ friend class Context; // For accessing the constructor in
+ // create*()
+
+public:
+ static Value *create(Value *Vec, Value *Idx, Instruction *InsertBefore,
+ Context &Ctx, const Twine &Name = ""); // Inserts before InsertBefore; may return a folded Constant.
+ static Value *create(Value *Vec, Value *Idx, BasicBlock *InsertAtEnd,
+ Context &Ctx, const Twine &Name = ""); // Appends to InsertAtEnd; may return a folded Constant.
+ static bool classof(const Value *From) {
+ return From->getSubclassID() == ClassID::ExtractElement;
+ }
+ /// Forwards to llvm::ExtractElementInst::isValidOperands() on the wrapped LLVM values.
+ static bool isValidOperands(const Value *Vec, const Value *Idx) {
+ return llvm::ExtractElementInst::isValidOperands(Vec->Val, Idx->Val);
+ }
+ Value *getVectorOperand() { return getOperand(0); }
+ Value *getIndexOperand() { return getOperand(1); }
+ const Value *getVectorOperand() const { return getOperand(0); }
+ const Value *getIndexOperand() const { return getOperand(1); }
+ /// \Returns the type of the vector operand (operand 0), cast to VectorType.
+ VectorType *getVectorOperandType() const {
+ return cast<VectorType>(getVectorOperand()->getType());
+ }
+};
+
class BranchInst : public SingleLLVMInstructionImpl<llvm::BranchInst> {
/// Use Context::createBranchInst(). Don't call the constructor directly.
BranchInst(llvm::BranchInst *BI, Context &Ctx)
@@ -1644,6 +1678,8 @@ class Context {
friend SelectInst; // For createSelectInst()
InsertElementInst *createInsertElementInst(llvm::InsertElementInst *IEI);
friend InsertElementInst; // For createInsertElementInst()
+ ExtractElementInst *createExtractElementInst(llvm::ExtractElementInst *EEI);
+ friend ExtractElementInst; // For createExtractElementInst()
BranchInst *createBranchInst(llvm::BranchInst *I);
friend BranchInst; // For createBranchInst()
LoadInst *createLoadInst(llvm::LoadInst *LI);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 269aea784dcec1..11f4f2e74712f4 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -32,36 +32,37 @@ DEF_USER(Constant, Constant)
#define OPCODES(...)
#endif
// clang-format off
-// ClassID, Opcode(s), Class
-DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
-DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
-DEF_INSTR(Select, OP(Select), SelectInst)
-DEF_INSTR(Br, OP(Br), BranchInst)
-DEF_INSTR(Load, OP(Load), LoadInst)
-DEF_INSTR(Store, OP(Store), StoreInst)
-DEF_INSTR(Ret, OP(Ret), ReturnInst)
-DEF_INSTR(Call, OP(Call), CallInst)
-DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
-DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
-DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
-DEF_INSTR(Alloca, OP(Alloca), AllocaInst)
-DEF_INSTR(Cast, OPCODES(\
- OP(ZExt) \
- OP(SExt) \
- OP(FPToUI) \
- OP(FPToSI) \
- OP(FPExt) \
- OP(PtrToInt) \
- OP(IntToPtr) \
- OP(SIToFP) \
- OP(UIToFP) \
- OP(Trunc) \
- OP(FPTrunc) \
- OP(BitCast) \
- OP(AddrSpaceCast) \
- ), CastInst)
-DEF_INSTR(PHI, OP(PHI), PHINode)
-DEF_INSTR(Unreachable, OP(Unreachable), UnreachableInst)
+// ClassID, Opcode(s), Class
+DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
+DEF_INSTR(ExtractElement, OP(ExtractElement), ExtractElementInst)
+DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
+DEF_INSTR(Select, OP(Select), SelectInst)
+DEF_INSTR(Br, OP(Br), BranchInst)
+DEF_INSTR(Load, OP(Load), LoadInst)
+DEF_INSTR(Store, OP(Store), StoreInst)
+DEF_INSTR(Ret, OP(Ret), ReturnInst)
+DEF_INSTR(Call, OP(Call), CallInst)
+DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
+DEF_INSTR(CallBr, OP(CallBr), CallBrInst)
+DEF_INSTR(GetElementPtr, OP(GetElementPtr), GetElementPtrInst)
+DEF_INSTR(Alloca, OP(Alloca), AllocaInst)
+DEF_INSTR(Cast, OPCODES(\
+ OP(ZExt) \
+ OP(SExt) \
+ OP(FPToUI) \
+ OP(FPToSI) \
+ OP(FPExt) \
+ OP(PtrToInt) \
+ OP(IntToPtr) \
+ OP(SIToFP) \
+ OP(UIToFP) \
+ OP(Trunc) \
+ OP(FPTrunc) \
+ OP(BitCast) \
+ OP(AddrSpaceCast) \
+ ), CastInst)
+DEF_INSTR(PHI, OP(PHI), PHINode)
+DEF_INSTR(Unreachable, OP(Unreachable), UnreachableInst)
// clang-format on
#ifdef DEF_VALUE
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 192c53d1a13f19..445f56b14e83b5 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -1235,6 +1235,30 @@ Value *InsertElementInst::create(Value *Vec, Value *NewElt, Value *Idx,
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}
+Value *ExtractElementInst::create(Value *Vec, Value *Idx,
+ Instruction *InsertBefore, Context &Ctx,
+ const Twine &Name) {
+ auto &Builder = Ctx.getLLVMIRBuilder();
+ Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction()); // Before InsertBefore's topmost LLVM instr.
+ llvm::Value *NewV = Builder.CreateExtractElement(Vec->Val, Idx->Val, Name);
+ if (auto *NewExtract = dyn_cast<llvm::ExtractElementInst>(NewV)) // A real instruction was emitted.
+ return Ctx.createExtractElementInst(NewExtract);
+ assert(isa<llvm::Constant>(NewV) && "Expected constant"); // Otherwise the builder folded it to a Constant.
+ return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
+}
+
+Value *ExtractElementInst::create(Value *Vec, Value *Idx,
+ BasicBlock *InsertAtEnd, Context &Ctx,
+ const Twine &Name) {
+ auto &Builder = Ctx.getLLVMIRBuilder();
+ Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val)); // End of block, i.e. after any terminator.
+ llvm::Value *NewV = Builder.CreateExtractElement(Vec->Val, Idx->Val, Name);
+ if (auto *NewExtract = dyn_cast<llvm::ExtractElementInst>(NewV)) // A real instruction was emitted.
+ return Ctx.createExtractElementInst(NewExtract);
+ assert(isa<llvm::Constant>(NewV) && "Expected constant"); // Otherwise the builder folded it to a Constant.
+ return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
+}
+
Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx,
bool IsSigned) {
llvm::Constant *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned);
@@ -1356,6 +1380,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
return It->second.get();
}
+ case llvm::Instruction::ExtractElement: {
+ auto *LLVMIns = cast<llvm::ExtractElementInst>(LLVMV);
+ It->second = std::unique_ptr<ExtractElementInst>(
+ new ExtractElementInst(LLVMIns, *this));
+ return It->second.get();
+ }
case llvm::Instruction::InsertElement: {
auto *LLVMIns = cast<llvm::InsertElementInst>(LLVMV);
It->second = std::unique_ptr<InsertElementInst>(
@@ -1459,6 +1489,13 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
return cast<SelectInst>(registerValue(std::move(NewPtr)));
}
+ExtractElementInst *
+Context::createExtractElementInst(llvm::ExtractElementInst *EEI) { // Wraps EEI in a new sandboxir::ExtractElementInst.
+ auto NewPtr =
+ std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this)); // Not make_unique: the ctor is private.
+ return cast<ExtractElementInst>(registerValue(std::move(NewPtr))); // Ownership moves into registerValue().
+}
+
InsertElementInst *
Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
auto NewPtr =
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index f98a60b49ecab3..1cd1ca6a418c61 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -631,6 +631,50 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
}
}
+TEST_F(SandboxIRTest, ExtractElementInst) {
+ parseIR(C, R"IR(
+define void @foo(<2 x i8> %vec, i32 %idx) {
+ %ext0 = extractelement <2 x i8> %vec, i32 %idx
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto &F = *Ctx.createFunction(&LLVMF);
+ auto *ArgVec = F.getArg(0);
+ auto *ArgIdx = F.getArg(1);
+ auto *BB = &*F.begin();
+ auto It = BB->begin();
+ auto *EI = cast<sandboxir::ExtractElementInst>(&*It++);
+ auto *Ret = &*It++;
+ // Opcode and operand accessors.
+ EXPECT_EQ(EI->getOpcode(), sandboxir::Instruction::Opcode::ExtractElement);
+ EXPECT_EQ(EI->getOperand(0), ArgVec);
+ EXPECT_EQ(EI->getOperand(1), ArgIdx);
+ EXPECT_EQ(EI->getVectorOperand(), ArgVec);
+ EXPECT_EQ(EI->getIndexOperand(), ArgIdx);
+ EXPECT_EQ(EI->getVectorOperandType(), ArgVec->getType());
+ // create() with InsertBefore: the new extract is placed right before Ret.
+ auto *NewI1 =
+ cast<sandboxir::ExtractElementInst>(sandboxir::ExtractElementInst::create(
+ ArgVec, ArgIdx, Ret, Ctx, "NewExtrBeforeRet"));
+ EXPECT_EQ(NewI1->getOperand(0), ArgVec);
+ EXPECT_EQ(NewI1->getOperand(1), ArgIdx);
+ EXPECT_EQ(NewI1->getNextNode(), Ret);
+ // create() with InsertAtEnd: appended after Ret, the block's last instr.
+ auto *NewI2 =
+ cast<sandboxir::ExtractElementInst>(sandboxir::ExtractElementInst::create(
+ ArgVec, ArgIdx, BB, Ctx, "NewExtrAtEndOfBB"));
+ EXPECT_EQ(NewI2->getPrevNode(), Ret);
+ // isValidOperands() must agree with LLVM for valid and swapped operands.
+ auto *LLVMArgVec = LLVMF.getArg(0);
+ auto *LLVMArgIdx = LLVMF.getArg(1);
+ EXPECT_EQ(sandboxir::ExtractElementInst::isValidOperands(ArgVec, ArgIdx),
+ llvm::ExtractElementInst::isValidOperands(LLVMArgVec, LLVMArgIdx));
+ EXPECT_EQ(sandboxir::ExtractElementInst::isValidOperands(ArgIdx, ArgVec),
+ llvm::ExtractElementInst::isValidOperands(LLVMArgIdx, LLVMArgVec));
+}
+
TEST_F(SandboxIRTest, InsertElementInst) {
parseIR(C, R"IR(
define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) {
More information about the llvm-commits mailing list