[llvm] [SPIR-V] Make 'emit intrinsics' a module pass to resolve function return types over the module (PR #88503)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 12 04:42:01 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

The goal of this PR is to make 'emit intrinsics' a module pass so that function return types can be resolved across the whole module. This PR continues https://github.com/llvm/llvm-project/pull/88254 with respect to deducing a function's return type for opaque pointers. The test case is updated (hardened).

---

Patch is 36.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88503.diff


19 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRV.h (+1-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+16-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+158-9) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+6-5) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+20-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+15-8) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp (+1-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+11-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+1-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+10-3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td (+1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td (+16-3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+2-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+7) 
- (added) llvm/test/CodeGen/SPIRV/const-composite.ll (+26) 
- (added) llvm/test/CodeGen/SPIRV/instructions/ret-type.ll (+82) 
- (added) llvm/test/CodeGen/SPIRV/instructions/select-phi.ll (+58) 
- (modified) llvm/test/CodeGen/SPIRV/instructions/select.ll (+15) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 6979107349d968..fb8580cd47c01b 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -24,7 +24,7 @@ FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
 FunctionPass *createSPIRVRegularizerPass();
 FunctionPass *createSPIRVPreLegalizerPass();
 FunctionPass *createSPIRVPostLegalizerPass();
-FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
+ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
 InstructionSelector *
 createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
                                const SPIRVSubtarget &Subtarget,
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 9e4ba2191366b3..c107b99cf4cb63 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -383,7 +383,16 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   if (F.isDeclaration())
     GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
   FunctionType *FTy = getOriginalFunctionType(F);
-  SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
+  Type *FRetTy = FTy->getReturnType();
+  if (isUntypedPointerTy(FRetTy)) {
+    if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
+      TypedPointerType *DerivedTy =
+          TypedPointerType::get(FRetElemTy, getPointerAddressSpace(FRetTy));
+      GR->addReturnType(&F, DerivedTy);
+      FRetTy = DerivedTy;
+    }
+  }
+  SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
   FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
   SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
       FTy, RetTy, ArgTypeVRegs, MIRBuilder);
@@ -505,8 +514,13 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     // TODO: support constexpr casts and indirect calls.
     if (CF == nullptr)
       return false;
-    if (FunctionType *FTy = getOriginalFunctionType(*CF))
+    if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
       OrigRetTy = FTy->getReturnType();
+      if (isUntypedPointerTy(OrigRetTy)) {
+        if (auto *DerivedRetTy = GR->findReturnType(CF))
+          OrigRetTy = DerivedRetTy;
+      }
+    }
   }
 
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e8ce5a35b457d5..472bc8638c9af1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -51,7 +51,7 @@ void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
 
 namespace {
 class SPIRVEmitIntrinsics
-    : public FunctionPass,
+    : public ModulePass,
       public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
   SPIRVTargetMachine *TM = nullptr;
   SPIRVGlobalRegistry *GR = nullptr;
@@ -61,6 +61,9 @@ class SPIRVEmitIntrinsics
   DenseMap<Instruction *, Type *> AggrConstTypes;
   DenseSet<Instruction *> AggrStores;
 
+  // a registry of created Intrinsic::spv_assign_ptr_type instructions
+  DenseMap<Value *, CallInst *> AssignPtrTypeInstr;
+
   // deduce element type of untyped pointers
   Type *deduceElementType(Value *I);
   Type *deduceElementTypeHelper(Value *I);
@@ -75,6 +78,9 @@ class SPIRVEmitIntrinsics
   Type *deduceNestedTypeHelper(User *U, Type *Ty,
                                std::unordered_set<Value *> &Visited);
 
+  // deduce Types of operands of the Instruction if possible
+  void deduceOperandElementType(Instruction *I);
+
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
 
@@ -111,10 +117,10 @@ class SPIRVEmitIntrinsics
 
 public:
   static char ID;
-  SPIRVEmitIntrinsics() : FunctionPass(ID) {
+  SPIRVEmitIntrinsics() : ModulePass(ID) {
     initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
   }
-  SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : FunctionPass(ID), TM(_TM) {
+  SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : ModulePass(ID), TM(_TM) {
     initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry());
   }
   Instruction *visitInstruction(Instruction &I) { return &I; }
@@ -130,7 +136,15 @@ class SPIRVEmitIntrinsics
   Instruction *visitAllocaInst(AllocaInst &I);
   Instruction *visitAtomicCmpXchgInst(AtomicCmpXchgInst &I);
   Instruction *visitUnreachableInst(UnreachableInst &I);
-  bool runOnFunction(Function &F) override;
+
+  StringRef getPassName() const override { return "SPIRV emit intrinsics"; }
+
+  bool runOnModule(Module &M) override;
+  bool runOnFunction(Function &F);
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    ModulePass::getAnalysisUsage(AU);
+  }
 };
 } // namespace
 
@@ -269,6 +283,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
       if (Ty)
         break;
     }
+  } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
+    for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) {
+      Ty = deduceElementTypeByUsersDeep(Op, Visited);
+      if (Ty)
+        break;
+    }
   }
 
   // remember the found relationship
@@ -368,6 +388,112 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
   return IntegerType::getInt8Ty(I->getContext());
 }
 
+// If the Instruction has Pointer operands with unresolved types, this function
+// tries to deduce them. If the Instruction has Pointer operands with known
+// types which differ from expected, this function tries to insert a bitcast to
+// resolve the issue.
+void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
+  SmallVector<std::pair<Value *, unsigned>> Ops;
+  Type *KnownElemTy = nullptr;
+  // look for known basic patterns of type inference
+  if (auto *Ref = dyn_cast<PHINode>(I)) {
+    if (!isPointerTy(I->getType()) ||
+        !(KnownElemTy = GR->findDeducedElementType(I)))
+      return;
+    for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
+      Value *Op = Ref->getIncomingValue(i);
+      if (isPointerTy(Op->getType()))
+        Ops.push_back(std::make_pair(Op, i));
+    }
+  } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
+    if (!isPointerTy(I->getType()) ||
+        !(KnownElemTy = GR->findDeducedElementType(I)))
+      return;
+    for (unsigned i = 0; i < Ref->getNumOperands(); i++) {
+      Value *Op = Ref->getOperand(i);
+      if (isPointerTy(Op->getType()))
+        Ops.push_back(std::make_pair(Op, i));
+    }
+  } else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
+    Type *RetTy = F->getReturnType();
+    if (!isPointerTy(RetTy))
+      return;
+    Value *Op = Ref->getReturnValue();
+    if (!Op)
+      return;
+    if (!(KnownElemTy = GR->findDeducedElementType(F))) {
+      if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
+        GR->addDeducedElementType(F, OpElemTy);
+        TypedPointerType *DerivedTy =
+            TypedPointerType::get(OpElemTy, getPointerAddressSpace(RetTy));
+        GR->addReturnType(F, DerivedTy);
+      }
+      return;
+    }
+    Ops.push_back(std::make_pair(Op, 0));
+  } else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
+    if (!isPointerTy(Ref->getOperand(0)->getType()))
+      return;
+    Value *Op0 = Ref->getOperand(0);
+    Value *Op1 = Ref->getOperand(1);
+    Type *ElemTy0 = GR->findDeducedElementType(Op0);
+    Type *ElemTy1 = GR->findDeducedElementType(Op1);
+    if (ElemTy0) {
+      KnownElemTy = ElemTy0;
+      Ops.push_back(std::make_pair(Op1, 1));
+    } else if (ElemTy1) {
+      KnownElemTy = ElemTy1;
+      Ops.push_back(std::make_pair(Op0, 0));
+    }
+  }
+
+  // There is not enough info to deduce types or all is valid.
+  if (!KnownElemTy || Ops.size() == 0)
+    return;
+
+  LLVMContext &Ctx = F->getContext();
+  IRBuilder<> B(Ctx);
+  for (auto &OpIt : Ops) {
+    Value *Op = OpIt.first;
+    if (Op->use_empty())
+      continue;
+    Type *Ty = GR->findDeducedElementType(Op);
+    if (Ty == KnownElemTy)
+      continue;
+    if (Instruction *User = dyn_cast<Instruction>(Op->use_begin()->get()))
+      setInsertPointSkippingPhis(B, User->getNextNode());
+    else
+      B.SetInsertPoint(I);
+    Value *OpTyVal = Constant::getNullValue(KnownElemTy);
+    Type *OpTy = Op->getType();
+    if (!Ty) {
+      GR->addDeducedElementType(Op, KnownElemTy);
+      // check if there is existing Intrinsic::spv_assign_ptr_type instruction
+      auto It = AssignPtrTypeInstr.find(Op);
+      if (It == AssignPtrTypeInstr.end()) {
+        CallInst *CI =
+            buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy}, OpTyVal, Op,
+                            {B.getInt32(getPointerAddressSpace(OpTy))}, B);
+        AssignPtrTypeInstr[Op] = CI;
+      } else {
+        It->second->setArgOperand(
+            1,
+            MetadataAsValue::get(
+                Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal))));
+      }
+    } else {
+      SmallVector<Type *, 2> Types = {OpTy, OpTy};
+      MetadataAsValue *VMD = MetadataAsValue::get(
+          Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal)));
+      SmallVector<Value *, 2> Args = {Op, VMD,
+                                      B.getInt32(getPointerAddressSpace(OpTy))};
+      CallInst *PtrCastI =
+          B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
+      I->setOperand(OpIt.second, PtrCastI);
+    }
+  }
+}
+
 void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
                                               Instruction *New,
                                               IRBuilder<> &B) {
@@ -630,6 +756,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
         ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
     GR->addDeducedElementType(CI, ExpectedElementType);
     GR->addDeducedElementType(Pointer, ExpectedElementType);
+    AssignPtrTypeInstr[Pointer] = CI;
     return;
   }
 
@@ -914,6 +1041,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
   CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
                                  EltTyConst, I, {B.getInt32(AddressSpace)}, B);
   GR->addDeducedElementType(CI, ElemTy);
+  AssignPtrTypeInstr[I] = CI;
 }
 
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -1070,6 +1198,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
             {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
         GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
         GR->addDeducedElementType(Arg, ElemTy);
+        AssignPtrTypeInstr[Arg] = AssignPtrTyCI;
       }
     }
   }
@@ -1114,6 +1243,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     insertAssignTypeIntrs(I, B);
     insertPtrCastOrAssignTypeInstr(I, B);
   }
+
+  for (auto &I : instructions(Func))
+    deduceOperandElementType(&I);
+
   for (auto *I : Worklist) {
     TrackConstants = true;
     if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
@@ -1126,13 +1259,29 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     processInstrAfterVisit(I, B);
   }
 
-  // check if function parameter types are set
-  if (!F->isIntrinsic())
-    processParamTypes(F, B);
-
   return true;
 }
 
-FunctionPass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
+bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
+  bool Changed = false;
+
+  for (auto &F : M) {
+    Changed |= runOnFunction(F);
+  }
+
+  for (auto &F : M) {
+    // check if function parameter types are set
+    if (!F.isDeclaration() && !F.isIntrinsic()) {
+      const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
+      GR = ST.getSPIRVGlobalRegistry();
+      IRBuilder<> B(F.getContext());
+      processParamTypes(&F, B);
+    }
+  }
+
+  return Changed;
+}
+
+ModulePass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) {
   return new SPIRVEmitIntrinsics(TM);
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 70197e948c6582..05e41e06248e35 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -23,7 +23,6 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Type.h"
-#include "llvm/IR/TypedPointerType.h"
 #include "llvm/Support/Casting.h"
 #include <cassert>
 
@@ -61,7 +60,6 @@ SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
-
   SPIRVType *SpirvType =
       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
@@ -726,7 +724,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
                                                 bool EmitIR) {
   SmallVector<Register, 4> FieldTypes;
   for (const auto &Elem : Ty->elements()) {
-    SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
+    SPIRVType *ElemTy =
+        findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder);
     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
            "Invalid struct element type");
     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -919,8 +918,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   return SpirvType;
 }
 
-SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
-  auto t = VRegToTypeMap.find(CurMF);
+SPIRVType *
+SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
+                                         const MachineFunction *MF) const {
+  auto t = VRegToTypeMap.find(MF ? MF : CurMF);
   if (t != VRegToTypeMap.end()) {
     auto tt = t->second.find(VReg);
     if (tt != t->second.end())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 2e3e69456ac260..55979ba403a0ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -21,6 +21,7 @@
 #include "SPIRVInstrInfo.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/IR/Constant.h"
+#include "llvm/IR/TypedPointerType.h"
 
 namespace llvm {
 using SPIRVType = const MachineInstr;
@@ -58,6 +59,9 @@ class SPIRVGlobalRegistry {
   SmallPtrSet<const Type *, 4> TypesInProcessing;
   DenseMap<const Type *, SPIRVType *> ForwardPointerTypes;
 
+  // if a function returns a pointer, this is to map it into TypedPointerType
+  DenseMap<const Function *, TypedPointerType *> FunResPointerTypes;
+
   // Number of bits pointers and size_t integers require.
   const unsigned PointerSize;
 
@@ -134,6 +138,16 @@ class SPIRVGlobalRegistry {
   void setBound(unsigned V) { Bound = V; }
   unsigned getBound() { return Bound; }
 
+  // Add a record to the map of function return pointer types.
+  void addReturnType(const Function *ArgF, TypedPointerType *DerivedTy) {
+    FunResPointerTypes[ArgF] = DerivedTy;
+  }
+  // Find a record in the map of function return pointer types.
+  const TypedPointerType *findReturnType(const Function *ArgF) {
+    auto It = FunResPointerTypes.find(ArgF);
+    return It == FunResPointerTypes.end() ? nullptr : It->second;
+  }
+
   // Deduced element types of untyped pointers and composites:
   // - Add a record to the map of deduced element types.
   void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; }
@@ -276,8 +290,12 @@ class SPIRVGlobalRegistry {
           SPIRV::AccessQualifier::ReadWrite);
 
   // Return the SPIR-V type instruction corresponding to the given VReg, or
-  // nullptr if no such type instruction exists.
-  SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
+  // nullptr if no such type instruction exists. The second argument MF
+  // allows searching for the association in the context of a machine function
+  // other than the current one, without switching between different "current"
+  // machine functions.
+  SPIRVType *getSPIRVTypeForVReg(Register VReg,
+                                 const MachineFunction *MF = nullptr) const;
 
   // Whether the given VReg has a SPIR-V type mapped to it yet.
   bool hasSPIRVTypeForVReg(Register VReg) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 8db54c74f23690..b8296c3f6eeaee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -88,19 +88,24 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
                              MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
                              MachineInstr &I, unsigned OpIdx,
                              SPIRVType *ResType, const Type *ResTy = nullptr) {
+  // Get operand type
+  MachineFunction *MF = I.getParent()->getParent();
   Register OpReg = I.getOperand(OpIdx).getReg();
   SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
-  SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+  Register OpTypeReg =
       TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
           ? TypeInst->getOperand(1).getReg()
-          : OpReg);
+          : OpReg;
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
     return;
-  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  // Get operand's pointee type
+  Register ElemTypeReg = OpType->getOperand(2).getReg();
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
   if (!ElemType)
     return;
-  bool IsSameMF =
-      ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+  // Check if we need a bitcast to make a statement valid
+  bool IsSameMF = MF == ResType->getParent()->getParent();
   bool IsEqualTypes = IsSameMF ? ElemType == ResType
                                : GR.getTypeForSPIRVType(ElemType) == ResTy;
   if (IsEqualTypes)
@@ -156,7 +161,8 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
     SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
     SPIRVType *DefElemType =
         DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
-            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
+                                     DefPtrType->getParent()->getParent())
             : nullptr;
     if (DefElemType) {
       const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
@@ -177,7 +183,7 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
 // with a processed definition. Return Function pointer if it's a forward
 // call (ahead of definition), and nullptr otherwise.
 const Function *validateFunCall(const SPIRVSubtarget &STI,
-                                MachineRegisterInfo *MRI,
+                                MachineRegisterInfo *CallMRI,
                                 SPIRVGlobalRegistry &GR,
                                 MachineInstr &FunCall) {
   const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
@@ -186,7 +192,8 @@ const Function *validateFunCall(const SPIRVSubtarget &STI,
       const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
   if (!FunDef)
     return F;
-  validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+  MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
+  validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
   return nullptr;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index e3f76419f13137..aacfecc1e313f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -248,7 +248,7 @@ void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
 bool SPIRVInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
   if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_fID ||
       MI.getOpcode() == SPIRV::GET_pID || MI.getOpcode() == SPIRV::GET_vfID ||
-      MI.getOpcode() == SPIRV::GET_vID) {
+      MI.getOpcode() == SPIRV::GET_vID || MI.getOpcode() == SPIRV::GET_vpID) {
     auto &MRI = MI.getMF()->getRegInfo();
     MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
     MI.eraseFromParent();
...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/88503


More information about the llvm-commits mailing list