[llvm] [SPIR-V] Make 'emit intrinsics' a module pass to resolve function return types over the module (PR #88503)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 12 04:42:01 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)
<details>
<summary>Changes</summary>
The goal of this PR is to make 'emit intrinsics' a module pass so that function return types can be resolved across the whole module. This PR continues https://github.com/llvm/llvm-project/pull/88254, specifically the part that deduces a function's return type for opaque pointers. The test case is updated (hardened).
---
Patch is 36.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88503.diff
19 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRV.h (+1-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+16-2)
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+158-9)
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+6-5)
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+20-2)
- (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+15-8)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp (+1-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+11-2)
- (modified) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+1-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+10-3)
- (modified) llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp (+2)
- (modified) llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td (+1)
- (modified) llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td (+16-3)
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+2-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+7)
- (added) llvm/test/CodeGen/SPIRV/const-composite.ll (+26)
- (added) llvm/test/CodeGen/SPIRV/instructions/ret-type.ll (+82)
- (added) llvm/test/CodeGen/SPIRV/instructions/select-phi.ll (+58)
- (modified) llvm/test/CodeGen/SPIRV/instructions/select.ll (+15)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 6979107349d968..fb8580cd47c01b 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -24,7 +24,7 @@ FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
FunctionPass *createSPIRVRegularizerPass();
FunctionPass *createSPIRVPreLegalizerPass();
FunctionPass *createSPIRVPostLegalizerPass();
-FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
+ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
InstructionSelector *
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
const SPIRVSubtarget &Subtarget,
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 9e4ba2191366b3..c107b99cf4cb63 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -383,7 +383,16 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
if (F.isDeclaration())
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
FunctionType *FTy = getOriginalFunctionType(F);
- SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
+ Type *FRetTy = FTy->getReturnType();
+ if (isUntypedPointerTy(FRetTy)) {
+ if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
+ TypedPointerType *DerivedTy =
+ TypedPointerType::get(FRetElemTy, getPointerAddressSpace(FRetTy));
+ GR->addReturnType(&F, DerivedTy);
+ FRetTy = DerivedTy;
+ }
+ }
+ SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, RetTy, ArgTypeVRegs, MIRBuilder);
@@ -505,8 +514,13 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
return false;
- if (FunctionType *FTy = getOriginalFunctionType(*CF))
+ if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
OrigRetTy = FTy->getReturnType();
+ if (isUntypedPointerTy(OrigRetTy)) {
+ if (auto *DerivedRetTy = GR->findReturnType(CF))
+ OrigRetTy = DerivedRetTy;
+ }
+ }
}
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e8ce5a35b457d5..472bc8638c9af1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -51,7 +51,7 @@ void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
namespace {
class SPIRVEmitIntrinsics
- : public FunctionPass,
+ : public ModulePass,
public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
SPIRVTargetMachine *TM = nullptr;
SPIRVGlobalRegistry *GR = nullptr;
@@ -61,6 +61,9 @@ class SPIRVEmitIntrinsics
DenseMap<Instruction *, Type *> AggrConstTypes;
DenseSet<Instruction *> AggrStores;
+ // a registry of created Intrinsic::spv_assign_ptr_type instructions
+ DenseMap<Value *, CallInst *> AssignPtrTypeInstr;
+
// deduce element type of untyped pointers
Type *deduceElementType(Value *I);
Type *deduceElementTypeHelper(Value *I);
@@ -75,6 +78,9 @@ class SPIRVEmitIntrinsics
Type *deduceNestedTypeHelper(User *U, Type *Ty,
std::unordered_set<Value *> &Visited);
+ // deduce Types of operands of the Instruction if possible
+ void deduceOperandElementType(Instruction *I);
+
void preprocessCompositeConstants(IRBuilder<> &B);
void preprocessUndefs(IRBuilder<> &B);
@@ -111,10 +117,10 @@ class SPIRVEmitIntrinsics
public:
static char ID;
- SPIRVEmitIntrinsics() : FunctionPass(ID) {
+ SPIRVEmitIntrinsics() : ModulePass(ID) {
initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
}
- SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : FunctionPass(ID), TM(_TM) {
+ SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : ModulePass(ID), TM(_TM) {
initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
}
Instruction *visitInstruction(Instruction &I) { return &I; }
@@ -130,7 +136,15 @@ class SPIRVEmitIntrinsics
Instruction *visitAllocaInst(AllocaInst &I);
Instruction *visitAtomicCmpXchgInst(AtomicCmpXchgInst &I);
Instruction *visitUnreachableInst(UnreachableInst &I);
- bool runOnFunction(Function &F) override;
+
+ StringRef getPassName() const override { return "SPIRV emit intrinsics"; }
+
+ bool runOnModule(Module &M) override;
+ bool runOnFunction(Function &F);
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ ModulePass::getAnalysisUsage(AU);
+ }
};
} // namespace
@@ -269,6 +283,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
if (Ty)
break;
}
+ } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
+ for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) {
+ Ty = deduceElementTypeByUsersDeep(Op, Visited);
+ if (Ty)
+ break;
+ }
}
// remember the found relationship
@@ -368,6 +388,112 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
return IntegerType::getInt8Ty(I->getContext());
}
+// If the Instruction has Pointer operands with unresolved types, this function
+// tries to deduce them. If the Instruction has Pointer operands with known
+// types which differ from expected, this function tries to insert a bitcast to
+// resolve the issue.
+void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
+ SmallVector<std::pair<Value *, unsigned>> Ops;
+ Type *KnownElemTy = nullptr;
+ // look for known basic patterns of type inference
+ if (auto *Ref = dyn_cast<PHINode>(I)) {
+ if (!isPointerTy(I->getType()) ||
+ !(KnownElemTy = GR->findDeducedElementType(I)))
+ return;
+ for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
+ Value *Op = Ref->getIncomingValue(i);
+ if (isPointerTy(Op->getType()))
+ Ops.push_back(std::make_pair(Op, i));
+ }
+ } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
+ if (!isPointerTy(I->getType()) ||
+ !(KnownElemTy = GR->findDeducedElementType(I)))
+ return;
+ for (unsigned i = 0; i < Ref->getNumOperands(); i++) {
+ Value *Op = Ref->getOperand(i);
+ if (isPointerTy(Op->getType()))
+ Ops.push_back(std::make_pair(Op, i));
+ }
+ } else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
+ Type *RetTy = F->getReturnType();
+ if (!isPointerTy(RetTy))
+ return;
+ Value *Op = Ref->getReturnValue();
+ if (!Op)
+ return;
+ if (!(KnownElemTy = GR->findDeducedElementType(F))) {
+ if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
+ GR->addDeducedElementType(F, OpElemTy);
+ TypedPointerType *DerivedTy =
+ TypedPointerType::get(OpElemTy, getPointerAddressSpace(RetTy));
+ GR->addReturnType(F, DerivedTy);
+ }
+ return;
+ }
+ Ops.push_back(std::make_pair(Op, 0));
+ } else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
+ if (!isPointerTy(Ref->getOperand(0)->getType()))
+ return;
+ Value *Op0 = Ref->getOperand(0);
+ Value *Op1 = Ref->getOperand(1);
+ Type *ElemTy0 = GR->findDeducedElementType(Op0);
+ Type *ElemTy1 = GR->findDeducedElementType(Op1);
+ if (ElemTy0) {
+ KnownElemTy = ElemTy0;
+ Ops.push_back(std::make_pair(Op1, 1));
+ } else if (ElemTy1) {
+ KnownElemTy = ElemTy1;
+ Ops.push_back(std::make_pair(Op0, 0));
+ }
+ }
+
+ // There is not enough info to deduce types, or everything is already valid.
+ if (!KnownElemTy || Ops.size() == 0)
+ return;
+
+ LLVMContext &Ctx = F->getContext();
+ IRBuilder<> B(Ctx);
+ for (auto &OpIt : Ops) {
+ Value *Op = OpIt.first;
+ if (Op->use_empty())
+ continue;
+ Type *Ty = GR->findDeducedElementType(Op);
+ if (Ty == KnownElemTy)
+ continue;
+ if (Instruction *User = dyn_cast<Instruction>(Op->use_begin()->get()))
+ setInsertPointSkippingPhis(B, User->getNextNode());
+ else
+ B.SetInsertPoint(I);
+ Value *OpTyVal = Constant::getNullValue(KnownElemTy);
+ Type *OpTy = Op->getType();
+ if (!Ty) {
+ GR->addDeducedElementType(Op, KnownElemTy);
+ // check if there is existing Intrinsic::spv_assign_ptr_type instruction
+ auto It = AssignPtrTypeInstr.find(Op);
+ if (It == AssignPtrTypeInstr.end()) {
+ CallInst *CI =
+ buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy}, OpTyVal, Op,
+ {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+ AssignPtrTypeInstr[Op] = CI;
+ } else {
+ It->second->setArgOperand(
+ 1,
+ MetadataAsValue::get(
+ Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal))));
+ }
+ } else {
+ SmallVector<Type *, 2> Types = {OpTy, OpTy};
+ MetadataAsValue *VMD = MetadataAsValue::get(
+ Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal)));
+ SmallVector<Value *, 2> Args = {Op, VMD,
+ B.getInt32(getPointerAddressSpace(OpTy))};
+ CallInst *PtrCastI =
+ B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
+ I->setOperand(OpIt.second, PtrCastI);
+ }
+ }
+}
+
void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
Instruction *New,
IRBuilder<> &B) {
@@ -630,6 +756,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
GR->addDeducedElementType(CI, ExpectedElementType);
GR->addDeducedElementType(Pointer, ExpectedElementType);
+ AssignPtrTypeInstr[Pointer] = CI;
return;
}
@@ -914,6 +1041,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
EltTyConst, I, {B.getInt32(AddressSpace)}, B);
GR->addDeducedElementType(CI, ElemTy);
+ AssignPtrTypeInstr[I] = CI;
}
void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -1070,6 +1198,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
{B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
GR->addDeducedElementType(Arg, ElemTy);
+ AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
}
}
}
@@ -1114,6 +1243,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
insertAssignTypeIntrs(I, B);
insertPtrCastOrAssignTypeInstr(I, B);
}
+
+ for (auto &I : instructions(Func))
+ deduceOperandElementType(&I);
+
for (auto *I : Worklist) {
TrackConstants = true;
if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
@@ -1126,13 +1259,29 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
processInstrAfterVisit(I, B);
}
- // check if function parameter types are set
- if (!F->isIntrinsic())
- processParamTypes(F, B);
-
return true;
}
-FunctionPass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
+bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
+ bool Changed = false;
+
+ for (auto &F : M) {
+ Changed |= runOnFunction(F);
+ }
+
+ for (auto &F : M) {
+ // check if function parameter types are set
+ if (!F.isDeclaration() && !F.isIntrinsic()) {
+ const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
+ GR = ST.getSPIRVGlobalRegistry();
+ IRBuilder<> B(F.getContext());
+ processParamTypes(&F, B);
+ }
+ }
+
+ return Changed;
+}
+
+ModulePass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
return new SPIRVEmitIntrinsics(TM);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 70197e948c6582..05e41e06248e35 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -23,7 +23,6 @@
#include "llvm/ADT/APInt.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Type.h"
-#include "llvm/IR/TypedPointerType.h"
#include "llvm/Support/Casting.h"
#include <cassert>
@@ -61,7 +60,6 @@ SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
-
SPIRVType *SpirvType =
getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
@@ -726,7 +724,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
bool EmitIR) {
SmallVector<Register, 4> FieldTypes;
for (const auto &Elem : Ty->elements()) {
- SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
+ SPIRVType *ElemTy =
+ findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
"Invalid struct element type");
FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -919,8 +918,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
return SpirvType;
}
-SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
- auto t = VRegToTypeMap.find(CurMF);
+SPIRVType *
+SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
+ const MachineFunction *MF) const {
+ auto t = VRegToTypeMap.find(MF ? MF : CurMF);
if (t != VRegToTypeMap.end()) {
auto tt = t->second.find(VReg);
if (tt != t->second.end())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 2e3e69456ac260..55979ba403a0ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -21,6 +21,7 @@
#include "SPIRVInstrInfo.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/IR/Constant.h"
+#include "llvm/IR/TypedPointerType.h"
namespace llvm {
using SPIRVType = const MachineInstr;
@@ -58,6 +59,9 @@ class SPIRVGlobalRegistry {
SmallPtrSet<const Type *, 4> TypesInProcessing;
DenseMap<const Type *, SPIRVType *> ForwardPointerTypes;
+ // if a function returns a pointer, this is to map it into TypedPointerType
+ DenseMap<const Function *, TypedPointerType *> FunResPointerTypes;
+
// Number of bits pointers and size_t integers require.
const unsigned PointerSize;
@@ -134,6 +138,16 @@ class SPIRVGlobalRegistry {
void setBound(unsigned V) { Bound = V; }
unsigned getBound() { return Bound; }
+ // Add a record to the map of function return pointer types.
+ void addReturnType(const Function *ArgF, TypedPointerType *DerivedTy) {
+ FunResPointerTypes[ArgF] = DerivedTy;
+ }
+ // Find a record in the map of function return pointer types.
+ const TypedPointerType *findReturnType(const Function *ArgF) {
+ auto It = FunResPointerTypes.find(ArgF);
+ return It == FunResPointerTypes.end() ? nullptr : It->second;
+ }
+
// Deduced element types of untyped pointers and composites:
// - Add a record to the map of deduced element types.
void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; }
@@ -276,8 +290,12 @@ class SPIRVGlobalRegistry {
SPIRV::AccessQualifier::ReadWrite);
// Return the SPIR-V type instruction corresponding to the given VReg, or
- // nullptr if no such type instruction exists.
- SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
+ // nullptr if no such type instruction exists. The optional second argument
+ // MF allows looking up the association in the context of a machine function
+ // other than the current one, without changing which machine function the
+ // registry treats as "current".
+ SPIRVType *getSPIRVTypeForVReg(Register VReg,
+ const MachineFunction *MF = nullptr) const;
// Whether the given VReg has a SPIR-V type mapped to it yet.
bool hasSPIRVTypeForVReg(Register VReg) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 8db54c74f23690..b8296c3f6eeaee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -88,19 +88,24 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
MachineInstr &I, unsigned OpIdx,
SPIRVType *ResType, const Type *ResTy = nullptr) {
+ // Get operand type
+ MachineFunction *MF = I.getParent()->getParent();
Register OpReg = I.getOperand(OpIdx).getReg();
SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
- SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+ Register OpTypeReg =
TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
? TypeInst->getOperand(1).getReg()
- : OpReg);
+ : OpReg;
+ SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
return;
- SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+ // Get operand's pointee type
+ Register ElemTypeReg = OpType->getOperand(2).getReg();
+ SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
if (!ElemType)
return;
- bool IsSameMF =
- ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+ // Check if we need a bitcast to make a statement valid
+ bool IsSameMF = MF == ResType->getParent()->getParent();
bool IsEqualTypes = IsSameMF ? ElemType == ResType
: GR.getTypeForSPIRVType(ElemType) == ResTy;
if (IsEqualTypes)
@@ -156,7 +161,8 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
SPIRVType *DefElemType =
DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
- ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+ ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
+ DefPtrType->getParent()->getParent())
: nullptr;
if (DefElemType) {
const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
@@ -177,7 +183,7 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
// with a processed definition. Return Function pointer if it's a forward
// call (ahead of definition), and nullptr otherwise.
const Function *validateFunCall(const SPIRVSubtarget &STI,
- MachineRegisterInfo *MRI,
+ MachineRegisterInfo *CallMRI,
SPIRVGlobalRegistry &GR,
MachineInstr &FunCall) {
const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
@@ -186,7 +192,8 @@ const Function *validateFunCall(const SPIRVSubtarget &STI,
const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
if (!FunDef)
return F;
- validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+ MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
+ validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
return nullptr;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index e3f76419f13137..aacfecc1e313f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -248,7 +248,7 @@ void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
bool SPIRVInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_fID ||
MI.getOpcode() == SPIRV::GET_pID || MI.getOpcode() == SPIRV::GET_vfID ||
- MI.getOpcode() == SPIRV::GET_vID) {
+ MI.getOpcode() == SPIRV::GET_vID || MI.getOpcode() == SPIRV::GET_vpID) {
auto &MRI = MI.getMF()->getRegInfo();
MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
MI.eraseFromParent();
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/88503
More information about the llvm-commits mailing list