[llvm] [SPIR-V] Improve type inference: fix types of return values in call lowering (PR #116609)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 20 05:09:41 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

Goals of the PR are:
* ensure that correct types are applied to virtual registers which were used as return values in call lowering. A reproducer is attached as a new test case; before the PR it fails because spirv-val considers the output invalid due to wrong result/operand types in OpPhi instructions;
* improve type inference by speeding up postprocessing of types: limit the number of iterations by checking what remains to be processed, and process each instruction just once for any number of operands with uncomplete types;
* improve type inference by handling uncomplete types more accurately (propagate the uncomplete property to dependent operands, ensure consistency of the uncomplete-types data structure);
* change the processing order and add traversal of PHI nodes when type inference applies instruction results to specify/update/cast operand types (fixes an issue where an OpPhi's result type mismatches its operand types).

---

Patch is 46.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116609.diff


14 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+6-22) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+17) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+187-93) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+2-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+1-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+1-3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+3-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+2-7) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+49) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+16) 
- (modified) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll (+7-3) 
- (added) llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll (+55) 
- (added) llvm/test/CodeGen/SPIRV/pointers/phi-chain-types.ll (+82) 
- (modified) llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll (-2) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 06a37f1f559d44..bed34b83d2e546 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -447,12 +447,8 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
                               MachineIRBuilder &MIRBuilder,
                               SPIRVGlobalRegistry *GR, LLT LowLevelType,
                               Register DestinationReg = Register(0)) {
-  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-  if (!DestinationReg.isValid()) {
-    DestinationReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
-    MRI->setType(DestinationReg, LLT::scalar(64));
-    GR->assignSPIRVTypeToVReg(BaseType, DestinationReg, MIRBuilder.getMF());
-  }
+  if (!DestinationReg.isValid())
+    DestinationReg = createVirtualRegister(BaseType, GR, MIRBuilder);
   // TODO: consider using correct address space and alignment (p0 is canonical
   // type for selection though).
   MachinePointerInfo PtrInfo = MachinePointerInfo();
@@ -2129,7 +2125,7 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
     const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
         Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
     for (unsigned I = 0; I < LocalSizeNum; ++I) {
-      Register Reg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
+      Register Reg = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
       MRI->setType(Reg, LLType);
       GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
       auto GEPInst = MIRBuilder.buildIntrinsic(
@@ -2517,23 +2513,11 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  SPIRVGlobalRegistry *GR) {
   LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
 
-  // SPIR-V type and return register.
-  Register ReturnRegister = OrigRet;
-  SPIRVType *ReturnType = nullptr;
-  if (OrigRetTy && !OrigRetTy->isVoidTy()) {
-    ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
-    if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
-      MIRBuilder.getMRI()->setRegClass(ReturnRegister,
-                                       GR->getRegClass(ReturnType));
-  } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
-    ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
-    MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(64));
-    ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
-  }
-
   // Lookup the builtin in the TableGen records.
+  SPIRVType *SpvType = GR->getSPIRVTypeForVReg(OrigRet);
+  assert(SpvType && "Inconsistent return register: expected valid type info");
   std::unique_ptr<const IncomingCall> Call =
-      lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
+      lookupBuiltin(DemangledCall, Set, OrigRet, SpvType, Args);
 
   if (!Call) {
     LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 3c5397319aaf21..3fdaa6aa3257ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -539,6 +539,23 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   if (isFunctionDecl && !DemangledName.empty() &&
       (canUseGLSL || canUseOpenCL)) {
+    if (ResVReg.isValid()) {
+      if (!GR->getSPIRVTypeForVReg(ResVReg)) {
+        const Type *RetTy = OrigRetTy;
+        if (auto *PtrRetTy = dyn_cast<PointerType>(OrigRetTy)) {
+          const Value *OrigValue = Info.OrigRet.OrigValue;
+          if (!OrigValue)
+            OrigValue = Info.CB;
+          if (OrigValue)
+            if (Type *ElemTy = GR->findDeducedElementType(OrigValue))
+              RetTy =
+                  TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
+        }
+        setRegClassType(ResVReg, RetTy, GR, MIRBuilder);
+      }
+    } else {
+      ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder);
+    }
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e6ef40e010dc20..7460e0a71aae51 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -67,7 +67,7 @@ class SPIRVEmitIntrinsics
       public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
   SPIRVTargetMachine *TM = nullptr;
   SPIRVGlobalRegistry *GR = nullptr;
-  Function *F = nullptr;
+  Function *CurrF = nullptr;
   bool TrackConstants = true;
   bool HaveFunPtrs = false;
   DenseMap<Instruction *, Constant *> AggrConsts;
@@ -76,8 +76,27 @@ class SPIRVEmitIntrinsics
   SPIRV::InstructionSet::InstructionSet InstrSet;
 
   // a register of Instructions that don't have a complete type definition
-  DenseMap<Value *, unsigned> UncompleteTypeInfo;
-  SmallVector<Value *> PostprocessWorklist;
+  bool CanTodoType = true;
+  unsigned TodoTypeSz = 0;
+  DenseMap<Value *, bool> TodoType;
+  void insertTodoType(Value *Op) {
+    if (CanTodoType) {
+      auto It = TodoType.try_emplace(Op, true);
+      if (It.second)
+        ++TodoTypeSz;
+    }
+  }
+  void eraseTodoType(Value *Op) {
+    auto It = TodoType.find(Op);
+    if (It != TodoType.end() && It->second) {
+      TodoType[Op] = false;
+      --TodoTypeSz;
+    }
+  }
+  bool isTodoType(Value *Op) {
+    auto It = TodoType.find(Op);
+    return It != TodoType.end() && It->second;
+  }
 
   // well known result types of builtins
   enum WellKnownTypes { Event };
@@ -105,8 +124,9 @@ class SPIRVEmitIntrinsics
                                bool UnknownElemTypeI8);
 
   // deduce Types of operands of the Instruction if possible
-  void deduceOperandElementType(Instruction *I, Instruction *AskOp = 0,
-                                Type *AskTy = 0, CallInst *AssignCI = 0);
+  void deduceOperandElementType(Instruction *I,
+                                const SmallPtrSet<Value *, 4> *AskOps = nullptr,
+                                bool IsPostprocessing = false);
 
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
@@ -145,12 +165,20 @@ class SPIRVEmitIntrinsics
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
                                   std::unordered_set<Function *> &FVisited);
+
+  bool deduceOperandElementTypeCalledFunction(
+      SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI,
+      SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy);
+  void deduceOperandElementTypeFunctionPointer(
+      CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
+      Type *&KnownElemTy, bool IsPostprocessing);
+
   void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
                             CallInst *AssignCI);
   void replaceAllUsesWith(Value *Src, Value *Dest, bool DeleteOld = true);
 
   bool runOnFunction(Function &F);
-  bool postprocessTypes();
+  bool postprocessTypes(Module &M);
   bool processFunctionPointers(Module &M);
 
 public:
@@ -280,17 +308,10 @@ void SPIRVEmitIntrinsics::replaceAllUsesWith(Value *Src, Value *Dest,
   GR->updateIfExistDeducedElementType(Src, Dest, DeleteOld);
   GR->updateIfExistAssignPtrTypeInstr(Src, Dest, DeleteOld);
   // Update uncomplete type records if any
-  auto It = UncompleteTypeInfo.find(Src);
-  if (It == UncompleteTypeInfo.end())
-    return;
-  if (DeleteOld) {
-    unsigned Pos = It->second;
-    UncompleteTypeInfo.erase(Src);
-    UncompleteTypeInfo[Dest] = Pos;
-    PostprocessWorklist[Pos] = Dest;
-  } else {
-    UncompleteTypeInfo[Dest] = PostprocessWorklist.size();
-    PostprocessWorklist.push_back(Dest);
+  if (isTodoType(Src)) {
+    if (DeleteOld)
+      eraseTodoType(Src);
+    insertTodoType(Dest);
   }
 }
 
@@ -354,7 +375,7 @@ void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
   Value *OfType = PoisonValue::get(ElemTy);
   CallInst *AssignPtrTyCI = GR->findAssignPtrTypeInstr(Arg);
   if (AssignPtrTyCI == nullptr ||
-      AssignPtrTyCI->getParent()->getParent() != F) {
+      AssignPtrTyCI->getParent()->getParent() != CurrF) {
     AssignPtrTyCI = buildIntrWithMD(
         Intrinsic::spv_assign_ptr_type, {Arg->getType()}, OfType, Arg,
         {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
@@ -455,10 +476,7 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
   if (isUntypedPointerTy(RefTy)) {
     if (!UnknownElemTypeI8)
       return;
-    if (auto *I = dyn_cast<Instruction>(Op)) {
-      UncompleteTypeInfo[I] = PostprocessWorklist.size();
-      PostprocessWorklist.push_back(I);
-    }
+    insertTodoType(Op);
   }
   Ty = RefTy;
 }
@@ -661,10 +679,7 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I, bool UnknownElemTypeI8) {
     return Ty;
   if (!UnknownElemTypeI8)
     return nullptr;
-  if (auto *Instr = dyn_cast<Instruction>(I)) {
-    UncompleteTypeInfo[Instr] = PostprocessWorklist.size();
-    PostprocessWorklist.push_back(Instr);
-  }
+  insertTodoType(I);
   return IntegerType::getInt8Ty(I->getContext());
 }
 
@@ -683,8 +698,7 @@ static inline Type *getAtomicElemTy(SPIRVGlobalRegistry *GR, Instruction *I,
 
 // Try to deduce element type for a call base. Returns false if this is an
 // indirect function invocation, and true otherwise.
-static bool deduceOperandElementTypeCalledFunction(
-    SPIRVGlobalRegistry *GR, Instruction *I,
+bool SPIRVEmitIntrinsics::deduceOperandElementTypeCalledFunction(
     SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI,
     SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
   Function *CalledF = CI->getCalledFunction();
@@ -726,7 +740,7 @@ static bool deduceOperandElementTypeCalledFunction(
       case SPIRV::OpAtomicUMax:
       case SPIRV::OpAtomicSMin:
       case SPIRV::OpAtomicSMax: {
-        KnownElemTy = getAtomicElemTy(GR, I, Op);
+        KnownElemTy = getAtomicElemTy(GR, CI, Op);
         if (!KnownElemTy)
           return true;
         Ops.push_back(std::make_pair(Op, 0));
@@ -738,32 +752,44 @@ static bool deduceOperandElementTypeCalledFunction(
 }
 
 // Try to deduce element type for a function pointer.
-static void deduceOperandElementTypeFunctionPointer(
-    SPIRVGlobalRegistry *GR, Instruction *I, CallInst *CI,
-    SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
+void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
+    CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
+    Type *&KnownElemTy, bool IsPostprocessing) {
   Value *Op = CI->getCalledOperand();
   if (!Op || !isPointerTy(Op->getType()))
     return;
   Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
   FunctionType *FTy = CI->getFunctionType();
-  bool IsNewFTy = false;
+  bool IsNewFTy = false, IsUncomplete = false;
   SmallVector<Type *, 4> ArgTys;
   for (Value *Arg : CI->args()) {
     Type *ArgTy = Arg->getType();
-    if (ArgTy->isPointerTy())
+    if (ArgTy->isPointerTy()) {
       if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
         IsNewFTy = true;
         ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy));
+        if (isTodoType(Arg))
+          IsUncomplete = true;
+      } else {
+        IsUncomplete = true;
       }
+    }
     ArgTys.push_back(ArgTy);
   }
   Type *RetTy = FTy->getReturnType();
-  if (I->getType()->isPointerTy())
-    if (Type *ElemTy = GR->findDeducedElementType(I)) {
+  if (CI->getType()->isPointerTy()) {
+    if (Type *ElemTy = GR->findDeducedElementType(CI)) {
       IsNewFTy = true;
       RetTy =
-          TypedPointerType::get(ElemTy, getPointerAddressSpace(I->getType()));
+          TypedPointerType::get(ElemTy, getPointerAddressSpace(CI->getType()));
+      if (isTodoType(CI))
+        IsUncomplete = true;
+    } else {
+      IsUncomplete = true;
     }
+  }
+  if (!IsPostprocessing && IsUncomplete)
+    insertTodoType(Op);
   KnownElemTy =
       IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
 }
@@ -772,17 +798,18 @@ static void deduceOperandElementTypeFunctionPointer(
 // tries to deduce them. If the Instruction has Pointer operands with known
 // types which differ from expected, this function tries to insert a bitcast to
 // resolve the issue.
-void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
-                                                   Instruction *AskOp,
-                                                   Type *AskTy,
-                                                   CallInst *AskCI) {
+void SPIRVEmitIntrinsics::deduceOperandElementType(
+    Instruction *I, const SmallPtrSet<Value *, 4> *AskOps,
+    bool IsPostprocessing) {
   SmallVector<std::pair<Value *, unsigned>> Ops;
   Type *KnownElemTy = nullptr;
+  bool Uncomplete = false;
   // look for known basic patterns of type inference
   if (auto *Ref = dyn_cast<PHINode>(I)) {
     if (!isPointerTy(I->getType()) ||
         !(KnownElemTy = GR->findDeducedElementType(I)))
       return;
+    Uncomplete = isTodoType(I);
     for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
       Value *Op = Ref->getIncomingValue(i);
       if (isPointerTy(Op->getType()))
@@ -792,6 +819,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
     KnownElemTy = GR->findDeducedElementType(I);
     if (!KnownElemTy)
       return;
+    Uncomplete = isTodoType(I);
     Ops.push_back(std::make_pair(Ref->getPointerOperand(), 0));
   } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
     KnownElemTy = Ref->getSourceElementType();
@@ -837,27 +865,29 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
     if (!isPointerTy(I->getType()) ||
         !(KnownElemTy = GR->findDeducedElementType(I)))
       return;
+    Uncomplete = isTodoType(I);
     for (unsigned i = 0; i < Ref->getNumOperands(); i++) {
       Value *Op = Ref->getOperand(i);
       if (isPointerTy(Op->getType()))
         Ops.push_back(std::make_pair(Op, i));
     }
   } else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
-    Type *RetTy = F->getReturnType();
+    Type *RetTy = CurrF->getReturnType();
     if (!isPointerTy(RetTy))
       return;
     Value *Op = Ref->getReturnValue();
     if (!Op)
       return;
-    if (!(KnownElemTy = GR->findDeducedElementType(F))) {
+    if (!(KnownElemTy = GR->findDeducedElementType(CurrF))) {
       if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
-        GR->addDeducedElementType(F, OpElemTy);
+        GR->addDeducedElementType(CurrF, OpElemTy);
         TypedPointerType *DerivedTy =
             TypedPointerType::get(OpElemTy, getPointerAddressSpace(RetTy));
-        GR->addReturnType(F, DerivedTy);
+        GR->addReturnType(CurrF, DerivedTy);
       }
       return;
     }
+    Uncomplete = isTodoType(CurrF);
     Ops.push_back(std::make_pair(Op, 0));
   } else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
     if (!isPointerTy(Ref->getOperand(0)->getType()))
@@ -868,37 +898,52 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
     Type *ElemTy1 = GR->findDeducedElementType(Op1);
     if (ElemTy0) {
       KnownElemTy = ElemTy0;
+      Uncomplete = isTodoType(Op0);
       Ops.push_back(std::make_pair(Op1, 1));
     } else if (ElemTy1) {
       KnownElemTy = ElemTy1;
+      Uncomplete = isTodoType(Op1);
       Ops.push_back(std::make_pair(Op0, 0));
     }
   } else if (CallInst *CI = dyn_cast<CallInst>(I)) {
     if (!CI->isIndirectCall())
-      deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops,
-                                             KnownElemTy);
+      deduceOperandElementTypeCalledFunction(InstrSet, CI, Ops, KnownElemTy);
     else if (HaveFunPtrs)
-      deduceOperandElementTypeFunctionPointer(GR, I, CI, Ops, KnownElemTy);
+      deduceOperandElementTypeFunctionPointer(CI, Ops, KnownElemTy,
+                                              IsPostprocessing);
   }
 
   // There is no enough info to deduce types or all is valid.
   if (!KnownElemTy || Ops.size() == 0)
     return;
 
-  LLVMContext &Ctx = F->getContext();
+  LLVMContext &Ctx = CurrF->getContext();
   IRBuilder<> B(Ctx);
   for (auto &OpIt : Ops) {
     Value *Op = OpIt.first;
-    if (Op->use_empty() || (AskOp && Op != AskOp))
+    if (Op->use_empty())
       continue;
-    Type *Ty = AskOp ? AskTy : GR->findDeducedElementType(Op);
+    if (AskOps && !AskOps->contains(Op))
+      continue;
+    Type *AskTy = nullptr;
+    CallInst *AskCI = nullptr;
+    if (IsPostprocessing && AskOps) {
+      AskTy = GR->findDeducedElementType(Op);
+      AskCI = GR->findAssignPtrTypeInstr(Op);
+      assert(AskTy && AskCI);
+    }
+    Type *Ty = AskTy ? AskTy : GR->findDeducedElementType(Op);
     if (Ty == KnownElemTy)
       continue;
     Value *OpTyVal = PoisonValue::get(KnownElemTy);
     Type *OpTy = Op->getType();
-    if (!Ty || AskTy || isUntypedPointerTy(Ty) ||
-        UncompleteTypeInfo.contains(Op)) {
+    if (!Ty || AskTy || isUntypedPointerTy(Ty) || isTodoType(Op)) {
       GR->addDeducedElementType(Op, KnownElemTy);
+      // check if KnownElemTy is complete
+      if (!Uncomplete)
+        eraseTodoType(Op);
+      else if (!IsPostprocessing)
+        insertTodoType(Op);
       // check if there is existing Intrinsic::spv_assign_ptr_type instruction
       CallInst *AssignCI = AskCI ? AskCI : GR->findAssignPtrTypeInstr(Op);
       if (AssignCI == nullptr) {
@@ -912,6 +957,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
         updateAssignType(AssignCI, Op, OpTyVal);
       }
     } else {
+      eraseTodoType(Op);
       if (auto *OpI = dyn_cast<Instruction>(Op)) {
         // spv_ptrcast's argument Op denotes an instruction that generates
         // a value, and we may use getInsertionPointAfterDef()
@@ -921,7 +967,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
         B.SetInsertPointPastAllocas(OpA->getParent());
         B.SetCurrentDebugLocation(DebugLoc());
       } else {
-        B.SetInsertPoint(F->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());
+        B.SetInsertPoint(CurrF->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());
       }
       SmallVector<Type *, 2> Types = {OpTy, OpTy};
       SmallVector<Value *, 2> Args = {Op, buildMD(OpTyVal),
@@ -961,7 +1007,7 @@ void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
 
 void SPIRVEmitIntrinsics::preprocessUndefs(IRBuilder<> &B) {
   std::queue<Instruction *> Worklist;
-  for (auto &I : instructions(F))
+  for (auto &I : instructions(CurrF))
     Worklist.push(&I);
 
   while (!Worklist.empty()) {
@@ -989,7 +1035,7 @@ void SPIRVEmitIntrinsics::preprocessUndefs(IRBuilder<> &B) {
 
 void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
   std::queue<Instruction *> Worklist;
-  for (auto &I : instructions(F))
+  for (auto &I : instructions(CurrF))
     Worklist.push(&I);
 
   while (!Worklist.empty()) {
@@ -1048,7 +1094,7 @@ Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
     return &Call;
 
   const InlineAsm *IA = cast<InlineAsm>(Call.getCalledOperand());
-  LLVMContext &Ctx = F->getContext();
+  LLVMContext &Ctx = CurrF->getContext();
 
   Constant *TyC = UndefValue::get(IA->getFunctionType());
   MDString *ConstraintString = MDString::get(Ctx, IA->getConstraintString());
@@ -1249,10 +1295,10 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
                                                          IRBuilder<> &B) {
   // Handle basic instructions:
   StoreInst *SI = dyn_cast<StoreInst>(I);
-  if (IsKernelArgInt8(F, SI)) {
+  if (IsKernelArgInt8(CurrF, SI)) {
     return replacePointerOperandWithPtrCast(
-        I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0,
-        B);
+        I, SI->getValueOperand(), IntegerType::getInt8Ty(CurrF->getContext()),
+        0, B);
   } else if (SI) {
     Value *Op = SI->getValueOperand();
     Type *OpTy = Op->getType();
@@ -1419,7 +1465,7 @@ Instruction *SPIRVEmitIntrinsics::visitLoadInst(LoadInst &I) {
   TrackConstants = false;
   ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/116609


More information about the llvm-commits mailing list