[llvm] f768083 - [SPIR-V] Update type inference and instruction selection (#88254)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 15 00:59:50 PDT 2024
Author: Vyacheslav Levytskyy
Date: 2024-04-15T09:59:47+02:00
New Revision: f76808351677f4e361f2acd9c4b3f385901d37ef
URL: https://github.com/llvm/llvm-project/commit/f76808351677f4e361f2acd9c4b3f385901d37ef
DIFF: https://github.com/llvm/llvm-project/commit/f76808351677f4e361f2acd9c4b3f385901d37ef.diff
LOG: [SPIR-V] Update type inference and instruction selection (#88254)
This PR contains a series of fixes which are to improve type inference
and instruction selection.
Namely, it includes:
* fix OpSelect to support operands of a pointer type, according to the
SPIR-V specification (previously only integer/float/vectors of integer
or float were supported) -- a new test case is added and existing test
case is updated;
* fix TableGen typos in the definition of register classes and introduce a
new reg class that is a vector of pointers;
* fix usage of a machine function context when there is a need to switch
between different machine functions to infer/validate correct types;
* add usage of TypedPointerType instead of PointerType so that later
stages of type inference are able to distinguish pointer types by their
element types, effectively supporting hierarchy of pointer/pointee types
and avoiding more complicated recursive type matching on level of
machine instructions in favor of direct pointer comparison using LLVM's
`Type *` values;
* extract detailed information about operand types using known type
rules for some llvm instructions (for instance, by deducing PHI's
operand pointee types if PHI's result type was deduced on previous
stages of type inference), and add the corresponding
`Intrinsic::spv_assign_ptr_type` to keep type info along subsequent
passes;
* ensure that OpConstantComposite reuses a constant when it's already
created and available in the same machine function -- otherwise there is
a crash while building a dependency graph, the corresponding test case
is attached,
* implement deduction of function's return type for opaque pointers, a
new test case is attached,
* make 'emit intrinsics' a module pass to resolve function return types
over the module -- first types for all functions of the module must be
calculated, and only after that it's feasible to deduce function return
types at this earlier stage of translation.
Added:
llvm/test/CodeGen/SPIRV/const-composite.ll
llvm/test/CodeGen/SPIRV/instructions/ret-type.ll
llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
Modified:
llvm/lib/Target/SPIRV/SPIRV.h
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
llvm/lib/Target/SPIRV/SPIRVUtils.cpp
llvm/lib/Target/SPIRV/SPIRVUtils.h
llvm/test/CodeGen/SPIRV/instructions/select.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 6979107349d968..fb8580cd47c01b 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -24,7 +24,7 @@ FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
FunctionPass *createSPIRVRegularizerPass();
FunctionPass *createSPIRVPreLegalizerPass();
FunctionPass *createSPIRVPostLegalizerPass();
-FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
+ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
InstructionSelector *
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
const SPIRVSubtarget &Subtarget,
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 9e4ba2191366b3..c107b99cf4cb63 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -383,7 +383,16 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
if (F.isDeclaration())
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
FunctionType *FTy = getOriginalFunctionType(F);
- SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
+ Type *FRetTy = FTy->getReturnType();
+ if (isUntypedPointerTy(FRetTy)) {
+ if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
+ TypedPointerType *DerivedTy =
+ TypedPointerType::get(FRetElemTy, getPointerAddressSpace(FRetTy));
+ GR->addReturnType(&F, DerivedTy);
+ FRetTy = DerivedTy;
+ }
+ }
+ SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, RetTy, ArgTypeVRegs, MIRBuilder);
@@ -505,8 +514,13 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
return false;
- if (FunctionType *FTy = getOriginalFunctionType(*CF))
+ if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
OrigRetTy = FTy->getReturnType();
+ if (isUntypedPointerTy(OrigRetTy)) {
+ if (auto *DerivedRetTy = GR->findReturnType(CF))
+ OrigRetTy = DerivedRetTy;
+ }
+ }
}
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e8ce5a35b457d5..472bc8638c9af1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -51,7 +51,7 @@ void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
namespace {
class SPIRVEmitIntrinsics
- : public FunctionPass,
+ : public ModulePass,
public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
SPIRVTargetMachine *TM = nullptr;
SPIRVGlobalRegistry *GR = nullptr;
@@ -61,6 +61,9 @@ class SPIRVEmitIntrinsics
DenseMap<Instruction *, Type *> AggrConstTypes;
DenseSet<Instruction *> AggrStores;
+ // a registry of created Intrinsic::spv_assign_ptr_type instructions
+ DenseMap<Value *, CallInst *> AssignPtrTypeInstr;
+
// deduce element type of untyped pointers
Type *deduceElementType(Value *I);
Type *deduceElementTypeHelper(Value *I);
@@ -75,6 +78,9 @@ class SPIRVEmitIntrinsics
Type *deduceNestedTypeHelper(User *U, Type *Ty,
std::unordered_set<Value *> &Visited);
+ // deduce Types of operands of the Instruction if possible
+ void deduceOperandElementType(Instruction *I);
+
void preprocessCompositeConstants(IRBuilder<> &B);
void preprocessUndefs(IRBuilder<> &B);
@@ -111,10 +117,10 @@ class SPIRVEmitIntrinsics
public:
static char ID;
- SPIRVEmitIntrinsics() : FunctionPass(ID) {
+ SPIRVEmitIntrinsics() : ModulePass(ID) {
initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
}
- SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : FunctionPass(ID), TM(_TM) {
+ SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : ModulePass(ID), TM(_TM) {
initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
}
Instruction *visitInstruction(Instruction &I) { return &I; }
@@ -130,7 +136,15 @@ class SPIRVEmitIntrinsics
Instruction *visitAllocaInst(AllocaInst &I);
Instruction *visitAtomicCmpXchgInst(AtomicCmpXchgInst &I);
Instruction *visitUnreachableInst(UnreachableInst &I);
- bool runOnFunction(Function &F) override;
+
+ StringRef getPassName() const override { return "SPIRV emit intrinsics"; }
+
+ bool runOnModule(Module &M) override;
+ bool runOnFunction(Function &F);
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ ModulePass::getAnalysisUsage(AU);
+ }
};
} // namespace
@@ -269,6 +283,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
if (Ty)
break;
}
+ } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
+ for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) {
+ Ty = deduceElementTypeByUsersDeep(Op, Visited);
+ if (Ty)
+ break;
+ }
}
// remember the found relationship
@@ -368,6 +388,112 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
return IntegerType::getInt8Ty(I->getContext());
}
+// If the Instruction has Pointer operands with unresolved types, this function
+// tries to deduce them. If the Instruction has Pointer operands with known
+// types which differ from expected, this function tries to insert a bitcast to
+// resolve the issue.
+void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
+ SmallVector<std::pair<Value *, unsigned>> Ops;
+ Type *KnownElemTy = nullptr;
+ // look for known basic patterns of type inference
+ if (auto *Ref = dyn_cast<PHINode>(I)) {
+ if (!isPointerTy(I->getType()) ||
+ !(KnownElemTy = GR->findDeducedElementType(I)))
+ return;
+ for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
+ Value *Op = Ref->getIncomingValue(i);
+ if (isPointerTy(Op->getType()))
+ Ops.push_back(std::make_pair(Op, i));
+ }
+ } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
+ if (!isPointerTy(I->getType()) ||
+ !(KnownElemTy = GR->findDeducedElementType(I)))
+ return;
+ for (unsigned i = 0; i < Ref->getNumOperands(); i++) {
+ Value *Op = Ref->getOperand(i);
+ if (isPointerTy(Op->getType()))
+ Ops.push_back(std::make_pair(Op, i));
+ }
+ } else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
+ Type *RetTy = F->getReturnType();
+ if (!isPointerTy(RetTy))
+ return;
+ Value *Op = Ref->getReturnValue();
+ if (!Op)
+ return;
+ if (!(KnownElemTy = GR->findDeducedElementType(F))) {
+ if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
+ GR->addDeducedElementType(F, OpElemTy);
+ TypedPointerType *DerivedTy =
+ TypedPointerType::get(OpElemTy, getPointerAddressSpace(RetTy));
+ GR->addReturnType(F, DerivedTy);
+ }
+ return;
+ }
+ Ops.push_back(std::make_pair(Op, 0));
+ } else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
+ if (!isPointerTy(Ref->getOperand(0)->getType()))
+ return;
+ Value *Op0 = Ref->getOperand(0);
+ Value *Op1 = Ref->getOperand(1);
+ Type *ElemTy0 = GR->findDeducedElementType(Op0);
+ Type *ElemTy1 = GR->findDeducedElementType(Op1);
+ if (ElemTy0) {
+ KnownElemTy = ElemTy0;
+ Ops.push_back(std::make_pair(Op1, 1));
+ } else if (ElemTy1) {
+ KnownElemTy = ElemTy1;
+ Ops.push_back(std::make_pair(Op0, 0));
+ }
+ }
+
+ // There is no enough info to deduce types or all is valid.
+ if (!KnownElemTy || Ops.size() == 0)
+ return;
+
+ LLVMContext &Ctx = F->getContext();
+ IRBuilder<> B(Ctx);
+ for (auto &OpIt : Ops) {
+ Value *Op = OpIt.first;
+ if (Op->use_empty())
+ continue;
+ Type *Ty = GR->findDeducedElementType(Op);
+ if (Ty == KnownElemTy)
+ continue;
+ if (Instruction *User = dyn_cast<Instruction>(Op->use_begin()->get()))
+ setInsertPointSkippingPhis(B, User->getNextNode());
+ else
+ B.SetInsertPoint(I);
+ Value *OpTyVal = Constant::getNullValue(KnownElemTy);
+ Type *OpTy = Op->getType();
+ if (!Ty) {
+ GR->addDeducedElementType(Op, KnownElemTy);
+ // check if there is existing Intrinsic::spv_assign_ptr_type instruction
+ auto It = AssignPtrTypeInstr.find(Op);
+ if (It == AssignPtrTypeInstr.end()) {
+ CallInst *CI =
+ buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy}, OpTyVal, Op,
+ {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+ AssignPtrTypeInstr[Op] = CI;
+ } else {
+ It->second->setArgOperand(
+ 1,
+ MetadataAsValue::get(
+ Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal))));
+ }
+ } else {
+ SmallVector<Type *, 2> Types = {OpTy, OpTy};
+ MetadataAsValue *VMD = MetadataAsValue::get(
+ Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal)));
+ SmallVector<Value *, 2> Args = {Op, VMD,
+ B.getInt32(getPointerAddressSpace(OpTy))};
+ CallInst *PtrCastI =
+ B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
+ I->setOperand(OpIt.second, PtrCastI);
+ }
+ }
+}
+
void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
Instruction *New,
IRBuilder<> &B) {
@@ -630,6 +756,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
GR->addDeducedElementType(CI, ExpectedElementType);
GR->addDeducedElementType(Pointer, ExpectedElementType);
+ AssignPtrTypeInstr[Pointer] = CI;
return;
}
@@ -914,6 +1041,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
EltTyConst, I, {B.getInt32(AddressSpace)}, B);
GR->addDeducedElementType(CI, ElemTy);
+ AssignPtrTypeInstr[I] = CI;
}
void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -1070,6 +1198,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
{B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
GR->addDeducedElementType(Arg, ElemTy);
+ AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
}
}
}
@@ -1114,6 +1243,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
insertAssignTypeIntrs(I, B);
insertPtrCastOrAssignTypeInstr(I, B);
}
+
+ for (auto &I : instructions(Func))
+ deduceOperandElementType(&I);
+
for (auto *I : Worklist) {
TrackConstants = true;
if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
@@ -1126,13 +1259,29 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
processInstrAfterVisit(I, B);
}
- // check if function parameter types are set
- if (!F->isIntrinsic())
- processParamTypes(F, B);
-
return true;
}
-FunctionPass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
+bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
+ bool Changed = false;
+
+ for (auto &F : M) {
+ Changed |= runOnFunction(F);
+ }
+
+ for (auto &F : M) {
+ // check if function parameter types are set
+ if (!F.isDeclaration() && !F.isIntrinsic()) {
+ const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
+ GR = ST.getSPIRVGlobalRegistry();
+ IRBuilder<> B(F.getContext());
+ processParamTypes(&F, B);
+ }
+ }
+
+ return Changed;
+}
+
+ModulePass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
return new SPIRVEmitIntrinsics(TM);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 70197e948c6582..05e41e06248e35 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -23,7 +23,6 @@
#include "llvm/ADT/APInt.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Type.h"
-#include "llvm/IR/TypedPointerType.h"
#include "llvm/Support/Casting.h"
#include <cassert>
@@ -61,7 +60,6 @@ SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
-
SPIRVType *SpirvType =
getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
@@ -726,7 +724,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
bool EmitIR) {
SmallVector<Register, 4> FieldTypes;
for (const auto &Elem : Ty->elements()) {
- SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
+ SPIRVType *ElemTy =
+ findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
"Invalid struct element type");
FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -919,8 +918,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
return SpirvType;
}
-SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
- auto t = VRegToTypeMap.find(CurMF);
+SPIRVType *
+SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
+ const MachineFunction *MF) const {
+ auto t = VRegToTypeMap.find(MF ? MF : CurMF);
if (t != VRegToTypeMap.end()) {
auto tt = t->second.find(VReg);
if (tt != t->second.end())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 2e3e69456ac260..55979ba403a0ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -21,6 +21,7 @@
#include "SPIRVInstrInfo.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/IR/Constant.h"
+#include "llvm/IR/TypedPointerType.h"
namespace llvm {
using SPIRVType = const MachineInstr;
@@ -58,6 +59,9 @@ class SPIRVGlobalRegistry {
SmallPtrSet<const Type *, 4> TypesInProcessing;
DenseMap<const Type *, SPIRVType *> ForwardPointerTypes;
+ // if a function returns a pointer, this is to map it into TypedPointerType
+ DenseMap<const Function *, TypedPointerType *> FunResPointerTypes;
+
// Number of bits pointers and size_t integers require.
const unsigned PointerSize;
@@ -134,6 +138,16 @@ class SPIRVGlobalRegistry {
void setBound(unsigned V) { Bound = V; }
unsigned getBound() { return Bound; }
+ // Add a record to the map of function return pointer types.
+ void addReturnType(const Function *ArgF, TypedPointerType *DerivedTy) {
+ FunResPointerTypes[ArgF] = DerivedTy;
+ }
+ // Find a record in the map of function return pointer types.
+ const TypedPointerType *findReturnType(const Function *ArgF) {
+ auto It = FunResPointerTypes.find(ArgF);
+ return It == FunResPointerTypes.end() ? nullptr : It->second;
+ }
+
// Deduced element types of untyped pointers and composites:
// - Add a record to the map of deduced element types.
void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; }
@@ -276,8 +290,12 @@ class SPIRVGlobalRegistry {
SPIRV::AccessQualifier::ReadWrite);
// Return the SPIR-V type instruction corresponding to the given VReg, or
- // nullptr if no such type instruction exists.
- SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
+ // nullptr if no such type instruction exists. The second argument MF
+ // allows to search for the association in a context of the machine functions
+ // than the current one, without switching between different "current" machine
+ // functions.
+ SPIRVType *getSPIRVTypeForVReg(Register VReg,
+ const MachineFunction *MF = nullptr) const;
// Whether the given VReg has a SPIR-V type mapped to it yet.
bool hasSPIRVTypeForVReg(Register VReg) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 8db54c74f23690..b8296c3f6eeaee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -88,19 +88,24 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
MachineInstr &I, unsigned OpIdx,
SPIRVType *ResType, const Type *ResTy = nullptr) {
+ // Get operand type
+ MachineFunction *MF = I.getParent()->getParent();
Register OpReg = I.getOperand(OpIdx).getReg();
SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
- SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+ Register OpTypeReg =
TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
? TypeInst->getOperand(1).getReg()
- : OpReg);
+ : OpReg;
+ SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
return;
- SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+ // Get operand's pointee type
+ Register ElemTypeReg = OpType->getOperand(2).getReg();
+ SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
if (!ElemType)
return;
- bool IsSameMF =
- ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+ // Check if we need a bitcast to make a statement valid
+ bool IsSameMF = MF == ResType->getParent()->getParent();
bool IsEqualTypes = IsSameMF ? ElemType == ResType
: GR.getTypeForSPIRVType(ElemType) == ResTy;
if (IsEqualTypes)
@@ -156,7 +161,8 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
SPIRVType *DefElemType =
DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
- ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+ ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
+ DefPtrType->getParent()->getParent())
: nullptr;
if (DefElemType) {
const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
@@ -177,7 +183,7 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
// with a processed definition. Return Function pointer if it's a forward
// call (ahead of definition), and nullptr otherwise.
const Function *validateFunCall(const SPIRVSubtarget &STI,
- MachineRegisterInfo *MRI,
+ MachineRegisterInfo *CallMRI,
SPIRVGlobalRegistry &GR,
MachineInstr &FunCall) {
const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
@@ -186,7 +192,8 @@ const Function *validateFunCall(const SPIRVSubtarget &STI,
const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
if (!FunDef)
return F;
- validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+ MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
+ validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
return nullptr;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index e3f76419f13137..aacfecc1e313f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -248,7 +248,7 @@ void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
bool SPIRVInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_fID ||
MI.getOpcode() == SPIRV::GET_pID || MI.getOpcode() == SPIRV::GET_vfID ||
- MI.getOpcode() == SPIRV::GET_vID) {
+ MI.getOpcode() == SPIRV::GET_vID || MI.getOpcode() == SPIRV::GET_vpID) {
auto &MRI = MI.getMF()->getRegInfo();
MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
MI.eraseFromParent();
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 99c57dac4141d8..a3f981457c8daa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -22,6 +22,7 @@ let isCodeGenOnly=1 in {
def GET_pID: Pseudo<(outs pID:$dst_id), (ins ANYID:$src)>;
def GET_vID: Pseudo<(outs vID:$dst_id), (ins ANYID:$src)>;
def GET_vfID: Pseudo<(outs vfID:$dst_id), (ins ANYID:$src)>;
+ def GET_vpID: Pseudo<(outs vpID:$dst_id), (ins ANYID:$src)>;
}
def SPVTypeBin : SDTypeProfile<1, 2, []>;
@@ -55,7 +56,7 @@ multiclass BinOpTypedGen<string name, bits<16> opCode, SDNode node, bit genF = 0
}
}
-multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genI = 1, bit genF = 0, bit genV = 0> {
+multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genP = 1, bit genI = 1, bit genF = 0, bit genV = 0> {
if genF then {
def SFSCond: TernOpTyped<name, opCode, ID, fID, node>;
def SFVCond: TernOpTyped<name, opCode, vID, fID, node>;
@@ -64,6 +65,10 @@ multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genI =
def SISCond: TernOpTyped<name, opCode, ID, ID, node>;
def SIVCond: TernOpTyped<name, opCode, vID, ID, node>;
}
+ if genP then {
+ def SPSCond: TernOpTyped<name, opCode, ID, pID, node>;
+ def SPVCond: TernOpTyped<name, opCode, vID, pID, node>;
+ }
if genV then {
if genF then {
def VFSCond: TernOpTyped<name, opCode, ID, vfID, node>;
@@ -73,6 +78,10 @@ multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genI =
def VISCond: TernOpTyped<name, opCode, ID, vID, node>;
def VIVCond: TernOpTyped<name, opCode, vID, vID, node>;
}
+ if genP then {
+ def VPSCond: TernOpTyped<name, opCode, ID, vpID, node>;
+ def VPVCond: TernOpTyped<name, opCode, vID, vpID, node>;
+ }
}
}
@@ -552,7 +561,7 @@ def OpLogicalOr: BinOp<"OpLogicalOr", 166>;
def OpLogicalAnd: BinOp<"OpLogicalAnd", 167>;
def OpLogicalNot: UnOp<"OpLogicalNot", 168>;
-defm OpSelect: TernOpTypedGen<"OpSelect", 169, select, 1, 1, 1>;
+defm OpSelect: TernOpTypedGen<"OpSelect", 169, select, 1, 1, 1, 1>;
def OpIEqual: BinOp<"OpIEqual", 170>;
def OpINotEqual: BinOp<"OpINotEqual", 171>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index b9d66de9555b11..f069a92ac68683 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -56,7 +56,7 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
static bool isMetaInstrGET(unsigned Opcode) {
return Opcode == SPIRV::GET_ID || Opcode == SPIRV::GET_fID ||
Opcode == SPIRV::GET_pID || Opcode == SPIRV::GET_vID ||
- Opcode == SPIRV::GET_vfID;
+ Opcode == SPIRV::GET_vfID || Opcode == SPIRV::GET_vpID;
}
static bool mayBeInserted(unsigned Opcode) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 7e155a36aadbc4..2c964595fc39e8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -64,9 +64,16 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg());
assert(BuildVec &&
BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
- for (unsigned i = 0; i < ConstVec->getNumElements(); ++i)
- GR->add(ConstVec->getElementAsConstant(i), &MF,
- BuildVec->getOperand(1 + i).getReg());
+ for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
+ // Ensure that OpConstantComposite reuses a constant when it's
+ // already created and available in the same machine function.
+ Constant *ElemConst = ConstVec->getElementAsConstant(i);
+ Register ElemReg = GR->find(ElemConst, &MF);
+ if (!ElemReg.isValid())
+ GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
+ else
+ BuildVec->getOperand(1 + i).setReg(ElemReg);
+ }
}
GR->add(Const, &MF, MI.getOperand(2).getReg());
} else {
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
index 9bf9d7fe5b39e8..5983c9229cb3c2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
@@ -39,6 +39,8 @@ SPIRVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
return SPIRV::vIDRegBank;
case SPIRV::vfIDRegClassID:
return SPIRV::vfIDRegBank;
+ case SPIRV::vpIDRegClassID:
+ return SPIRV::vpIDRegBank;
case SPIRV::ANYIDRegClassID:
case SPIRV::ANYRegClassID:
return SPIRV::IDRegBank;
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
index 90c7f3a6e67265..c7f1e172f3d4f1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
@@ -12,4 +12,5 @@ def IDRegBank : RegisterBank<"IDBank", [ID]>;
def fIDRegBank : RegisterBank<"fIDBank", [fID]>;
def vIDRegBank : RegisterBank<"vIDBank", [vID]>;
def vfIDRegBank : RegisterBank<"vfIDBank", [vfID]>;
+def vpIDRegBank : RegisterBank<"vpIDBank", [vpID]>;
def TYPERegBank : RegisterBank<"TYPEBank", [TYPE]>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
index d0b64b6895d035..6d2bfb91a97f12 100644
--- a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
@@ -12,6 +12,17 @@
let Namespace = "SPIRV" in {
def p0 : PtrValueType <i32, 0>;
+
+ class P0Vec<ValueType scalar>
+ : PtrValueType <scalar, 0> {
+ let nElem = 2;
+ let ElementType = p0;
+ let isInteger = false;
+ let isFP = false;
+ let isVector = true;
+ }
+
+ def v2p0 : P0Vec<i32>;
// All registers are for 32-bit identifiers, so have a single dummy register
// Class for registers that are the result of OpTypeXXX instructions
@@ -21,14 +32,16 @@ let Namespace = "SPIRV" in {
// Class for every other non-type ID
def ID0 : Register<"ID0">;
def ID : RegisterClass<"SPIRV", [i32], 32, (add ID0)>;
- def fID0 : Register<"FID0">;
+ def fID0 : Register<"fID0">;
def fID : RegisterClass<"SPIRV", [f32], 32, (add fID0)>;
def pID0 : Register<"pID0">;
def pID : RegisterClass<"SPIRV", [p0], 32, (add pID0)>;
- def vID0 : Register<"pID0">;
+ def vID0 : Register<"vID0">;
def vID : RegisterClass<"SPIRV", [v2i32], 32, (add vID0)>;
- def vfID0 : Register<"pID0">;
+ def vfID0 : Register<"vfID0">;
def vfID : RegisterClass<"SPIRV", [v2f32], 32, (add vfID0)>;
+ def vpID0 : Register<"vpID0">;
+ def vpID : RegisterClass<"SPIRV", [v2p0], 32, (add vpID0)>;
def ANYID : RegisterClass<"SPIRV", [i32, f32, p0, v2i32, v2f32], 32, (add ID, fID, pID, vID, vfID)>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 299a4341193bfd..2e44c208ed8e04 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -251,7 +251,8 @@ bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
}
Type *getMDOperandAsType(const MDNode *N, unsigned I) {
- return cast<ValueAsMetadata>(N->getOperand(I))->getType();
+ Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType();
+ return toTypedPointer(ElementTy, N->getContext());
}
// The set of names is borrowed from the SPIR-V translator.
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index c2c3475e1a936f..cd1a2af09147e3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -149,5 +149,12 @@ inline Type *reconstructFunctionType(Function *F) {
return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg());
}
+inline Type *toTypedPointer(Type *Ty, LLVMContext &Ctx) {
+ return isUntypedPointerTy(Ty)
+ ? TypedPointerType::get(IntegerType::getInt8Ty(Ctx),
+ getPointerAddressSpace(Ty))
+ : Ty;
+}
+
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
diff --git a/llvm/test/CodeGen/SPIRV/const-composite.ll b/llvm/test/CodeGen/SPIRV/const-composite.ll
new file mode 100644
index 00000000000000..4e304bb9516702
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/const-composite.ll
@@ -0,0 +1,26 @@
+; This test is to ensure that OpConstantComposite reuses a constant when it's
+; already created and available in the same machine function. In this test case
+; it's `1` that is passed implicitly as a part of the `foo` function argument
+; and also takes part in a composite constant creation.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#type_int32:]] = OpTypeInt 32 0
+; CHECK-SPIRV: %[[#const1:]] = OpConstant %[[#type_int32]] 1
+; CHECK-SPIRV: OpTypeArray %[[#]] %[[#const1:]]
+; CHECK-SPIRV: %[[#const0:]] = OpConstant %[[#type_int32]] 0
+; CHECK-SPIRV: OpConstantComposite %[[#]] %[[#const0]] %[[#const1]]
+
+%struct = type { [1 x i64] }
+
+define spir_kernel void @foo(ptr noundef byval(%struct) %arg) {
+entry:
+ call spir_func void @bar(<2 x i32> noundef <i32 0, i32 1>)
+ ret void
+}
+
+define spir_func void @bar(<2 x i32> noundef) {
+entry:
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/instructions/ret-type.ll b/llvm/test/CodeGen/SPIRV/instructions/ret-type.ll
new file mode 100644
index 00000000000000..bf71eb5628e217
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/instructions/ret-type.ll
@@ -0,0 +1,82 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --translator-compatibility-mode %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpName %[[Test1:.*]] "test1"
+; CHECK-DAG: OpName %[[Foo:.*]] "foo"
+; CHECK-DAG: OpName %[[Bar:.*]] "bar"
+; CHECK-DAG: OpName %[[Test2:.*]] "test2"
+
+; CHECK-DAG: %[[Long:.*]] = OpTypeInt 64 0
+; CHECK-DAG: %[[Array:.*]] = OpTypeArray %[[Long]] %[[#]]
+; CHECK-DAG: %[[Struct1:.*]] = OpTypeStruct %[[Array]]
+; CHECK-DAG: %[[Struct2:.*]] = OpTypeStruct %[[Struct1]]
+; CHECK-DAG: %[[StructPtr:.*]] = OpTypePointer Function %[[Struct2]]
+; CHECK-DAG: %[[Bool:.*]] = OpTypeBool
+; CHECK-DAG: %[[FooType:.*]] = OpTypeFunction %[[StructPtr:.*]] %[[StructPtr]] %[[StructPtr]] %[[Bool]]
+; CHECK-DAG: %[[Char:.*]] = OpTypeInt 8 0
+; CHECK-DAG: %[[CharPtr:.*]] = OpTypePointer Function %[[Char]]
+
+; CHECK: %[[Test1]] = OpFunction
+; CHECK: OpFunctionCall %[[StructPtr:.*]] %[[Foo]]
+; CHECK: OpFunctionCall %[[StructPtr:.*]] %[[Bar]]
+; CHECK: OpFunctionEnd
+
+; CHECK: %[[Foo]] = OpFunction %[[StructPtr:.*]] None %[[FooType]]
+; CHECK: %[[Arg1:.*]] = OpFunctionParameter %[[StructPtr]]
+; CHECK: %[[Arg2:.*]] = OpFunctionParameter
+; CHECK: %[[Sw:.*]] = OpFunctionParameter
+; CHECK: %[[Res:.*]] = OpInBoundsPtrAccessChain %[[StructPtr]] %[[Arg1]] %[[#]]
+; CHECK: OpReturnValue %[[Res]]
+; CHECK: OpReturnValue %[[Arg2]]
+
+; CHECK: %[[Bar]] = OpFunction %[[StructPtr:.*]] None %[[#]]
+; CHECK: %[[BarArg:.*]] = OpFunctionParameter
+; CHECK: %[[BarRes:.*]] = OpInBoundsPtrAccessChain %[[CharPtr]] %[[BarArg]] %[[#]]
+; CHECK: %[[BarResCasted:.*]] = OpBitcast %[[StructPtr]] %[[BarRes]]
+; CHECK: %[[BarResStruct:.*]] = OpInBoundsPtrAccessChain %[[StructPtr]] %[[#]] %[[#]]
+; CHECK: OpReturnValue %[[BarResStruct]]
+; CHECK: OpReturnValue %[[BarResCasted]]
+
+; CHECK: %[[Test2]] = OpFunction
+; CHECK: OpFunctionCall %[[StructPtr:.*]] %[[Foo]]
+; CHECK: OpFunctionCall %[[StructPtr:.*]] %[[Bar]]
+; CHECK: OpFunctionEnd
+
+%struct = type { %array }
+%array = type { [1 x i64] }
+
+define spir_func void @test1(ptr %arg1, ptr %arg2, i1 %sw) { ; calls @foo and @bar, both of which are defined later in this file
+entry:
+ %r1 = call ptr @foo(ptr %arg1, ptr %arg2, i1 %sw) ; returned pointer is not used
+ %r2 = call ptr @bar(ptr %arg1, i1 %sw) ; returned pointer is not used
+ ret void
+}
+
+define spir_func ptr @foo(ptr %arg1, ptr %arg2, i1 %sw) { ; returns a pointer from each of its two exits
+entry:
+ br i1 %sw, label %exit, label %sw1 ; %sw true -> %exit, false -> %sw1
+sw1:
+ %result = getelementptr inbounds %struct, ptr %arg1, i64 100 ; GEP with source element type %struct on %arg1
+ ret ptr %result
+exit:
+ ret ptr %arg2 ; %arg2 is returned without any other use
+}
+
+define spir_func ptr @bar(ptr %arg1, i1 %sw) { ; returns either a %struct-based or an i8-based GEP of %arg1
+entry:
+ %charptr = getelementptr inbounds i8, ptr %arg1, i64 0 ; GEP with source element type i8 on %arg1
+ br i1 %sw, label %exit, label %sw1 ; %sw true -> %exit, false -> %sw1
+sw1:
+ %result = getelementptr inbounds %struct, ptr %arg1, i64 100 ; GEP with source element type %struct on the same %arg1
+ ret ptr %result
+exit:
+ ret ptr %charptr ; the i8-based pointer is returned on this path
+}
+
+define spir_func void @test2(ptr %arg1, ptr %arg2, i1 %sw) { ; same calls as @test1, but placed after the definitions of @foo and @bar
+entry:
+ %r1 = call ptr @foo(ptr %arg1, ptr %arg2, i1 %sw) ; returned pointer is not used
+ %r2 = call ptr @bar(ptr %arg1, i1 %sw) ; returned pointer is not used
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
new file mode 100644
index 00000000000000..afc75c616f023b
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
@@ -0,0 +1,58 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --translator-compatibility-mode %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[Char:.*]] = OpTypeInt 8 0
+; CHECK-DAG: %[[Long:.*]] = OpTypeInt 32 0
+; CHECK-DAG: %[[Array:.*]] = OpTypeArray %[[Long]] %[[#]]
+; CHECK-DAG: %[[Struct:.*]] = OpTypeStruct %[[Array]]
+; CHECK-DAG: %[[StructPtr:.*]] = OpTypePointer Function %[[Struct]]
+; CHECK-DAG: %[[CharPtr:.*]] = OpTypePointer Function %[[Char]]
+
+; CHECK: %[[Branch1:.*]] = OpLabel
+; CHECK: %[[Res1:.*]] = OpVariable %[[StructPtr]] Function
+; CHECK: OpBranchConditional %[[#]] %[[#]] %[[Branch2:.*]]
+; CHECK: %[[Res2:.*]] = OpInBoundsPtrAccessChain %[[CharPtr]] %[[#]] %[[#]]
+; CHECK: %[[Res2Casted:.*]] = OpBitcast %[[StructPtr]] %[[Res2]]
+; CHECK: OpBranchConditional %[[#]] %[[#]] %[[BranchSelect:.*]]
+; CHECK: %[[SelectRes:.*]] = OpSelect %[[CharPtr]] %[[#]] %[[#]] %[[#]]
+; CHECK: %[[SelectResCasted:.*]] = OpBitcast %[[StructPtr]] %[[SelectRes]]
+; CHECK: OpLabel
+; CHECK: OpPhi %[[StructPtr]] %[[Res1]] %[[Branch1]] %[[Res2Casted]] %[[Branch2]] %[[SelectResCasted]] %[[BranchSelect]]
+
+%struct = type { %array }
+%array = type { [1 x i64] }
+%array3 = type { [3 x i32] }
+
+define spir_kernel void @foo(ptr addrspace(1) noundef align 1 %arg1, ptr noundef byval(%struct) align 8 %arg2, i1 noundef zeroext %expected) { ; kernel whose pointer values flow through a select and a three-way phi
+entry:
+ %agg = alloca %array3, align 8 ; local of type %array3
+ %r0 = load i64, ptr %arg2, align 8 ; later used as the byte index into %arg1
+ %add.ptr = getelementptr inbounds i8, ptr %agg, i64 12 ; i8-based GEP, byte offset 12 from %agg
+ %r1 = load i32, ptr %agg, align 4
+ %tobool0 = icmp slt i32 %r1, 0
+ br i1 %tobool0, label %exit, label %sw1
+
+sw1: ; preds = %entry
+ %incdec1 = getelementptr inbounds i8, ptr %agg, i64 4 ; i8-based GEP, byte offset 4 from %agg
+ %r2 = load i32, ptr %incdec1, align 4
+ %tobool1 = icmp slt i32 %r2, 0
+ br i1 %tobool1, label %exit, label %sw2
+
+sw2: ; preds = %sw1
+ %incdec2 = getelementptr inbounds i8, ptr %agg, i64 8 ; i8-based GEP, byte offset 8 from %agg
+ %r3 = load i32, ptr %incdec2, align 4
+ %tobool2 = icmp slt i32 %r3, 0
+ %spec.select = select i1 %tobool2, ptr %incdec2, ptr %add.ptr ; select between two i8-based GEP results
+ br label %exit
+
+exit: ; preds = %sw2, %sw1, %entry
+ %retval.0 = phi ptr [ %agg, %entry ], [ %incdec1, %sw1 ], [ %spec.select, %sw2 ] ; merges the alloca, an i8-based GEP and the select result
+ %add.ptr.i = getelementptr inbounds i8, ptr addrspace(1) %arg1, i64 %r0
+ %r4 = icmp eq ptr %retval.0, %add.ptr
+ %cmp = xor i1 %r4, %expected
+ %frombool6.i = zext i1 %cmp to i8
+ store i8 %frombool6.i, ptr addrspace(1) %add.ptr.i, align 1 ; writes the widened comparison result through %arg1
+ %r5 = icmp eq ptr %add.ptr, %retval.0 ; same comparison as %r4 with operands swapped; result is not used
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/instructions/select.ll b/llvm/test/CodeGen/SPIRV/instructions/select.ll
index f54ef21f208596..c4176b17abb449 100644
--- a/llvm/test/CodeGen/SPIRV/instructions/select.ll
+++ b/llvm/test/CodeGen/SPIRV/instructions/select.ll
@@ -1,6 +1,8 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; CHECK-DAG: OpName [[SCALARi32:%.+]] "select_i32"
+; CHECK-DAG: OpName [[SCALARPTR:%.+]] "select_ptr"
; CHECK-DAG: OpName [[VEC2i32:%.+]] "select_i32v2"
; CHECK-DAG: OpName [[VEC2i32v2:%.+]] "select_v2i32v2"
@@ -17,6 +19,19 @@ define i32 @select_i32(i1 %c, i32 %t, i32 %f) {
ret i32 %r
}
+; CHECK: [[SCALARPTR]] = OpFunction
+; CHECK-NEXT: [[C:%.+]] = OpFunctionParameter
+; CHECK-NEXT: [[T:%.+]] = OpFunctionParameter
+; CHECK-NEXT: [[F:%.+]] = OpFunctionParameter
+; CHECK: OpLabel
+; CHECK: [[R:%.+]] = OpSelect {{%.+}} [[C]] [[T]] [[F]]
+; CHECK: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define ptr @select_ptr(i1 %c, ptr %t, ptr %f) { ; select with pointer operands: yields %t when %c is true, otherwise %f
+ %r = select i1 %c, ptr %t, ptr %f
+ ret ptr %r
+}
+
; CHECK: [[VEC2i32]] = OpFunction
; CHECK-NEXT: [[C:%.+]] = OpFunctionParameter
; CHECK-NEXT: [[T:%.+]] = OpFunctionParameter
More information about the llvm-commits
mailing list