[llvm] [SPIR-V] Improve type inference: fix types of return values in call lowering (PR #116609)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 20 04:57:45 PST 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/116609

>From d89bb33f0adb8d98fe668765353f4d79ec64dbd8 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 18 Nov 2024 04:47:27 -0800
Subject: [PATCH 1/5] Improve type inference: return values in call lowering

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 18 +-----
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   | 25 +++++++++
 .../SPIRV/pointers/builtin-ret-reg-type.ll    | 55 +++++++++++++++++++
 .../SPIRV/transcoding/OpGenericCastToPtr.ll   |  2 -
 4 files changed, 83 insertions(+), 17 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 06a37f1f559d44..e34f6c3c282750 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2517,23 +2517,11 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  SPIRVGlobalRegistry *GR) {
   LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
 
-  // SPIR-V type and return register.
-  Register ReturnRegister = OrigRet;
-  SPIRVType *ReturnType = nullptr;
-  if (OrigRetTy && !OrigRetTy->isVoidTy()) {
-    ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
-    if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
-      MIRBuilder.getMRI()->setRegClass(ReturnRegister,
-                                       GR->getRegClass(ReturnType));
-  } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
-    ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
-    MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(64));
-    ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
-  }
-
   // Lookup the builtin in the TableGen records.
+  SPIRVType *SpvType = GR->getSPIRVTypeForVReg(OrigRet);
+  assert(SpvType && "Inconsistent return register: expected valid type info");
   std::unique_ptr<const IncomingCall> Call =
-      lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
+      lookupBuiltin(DemangledCall, Set, OrigRet, SpvType, Args);
 
   if (!Call) {
     LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 3c5397319aaf21..a7b6b0efa99551 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -539,6 +539,31 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   if (isFunctionDecl && !DemangledName.empty() &&
       (canUseGLSL || canUseOpenCL)) {
+    if (ResVReg.isValid()) {
+      if (!GR->getSPIRVTypeForVReg(ResVReg)) {
+        const Type *RetTy = OrigRetTy;
+        if (auto *PtrRetTy = dyn_cast<PointerType>(OrigRetTy)) {
+          const Value *OrigValue = Info.OrigRet.OrigValue;
+          if (!OrigValue)
+            OrigValue = Info.CB;
+          if (OrigValue)
+            if (Type *ElemTy = GR->findDeducedElementType(OrigValue))
+              RetTy =
+                  TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
+        }
+        SPIRVType *SpvType = GR->getOrCreateSPIRVType(RetTy, MIRBuilder);
+        GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+        if (!MRI->getRegClassOrNull(ResVReg)) {
+          MRI->setRegClass(ResVReg, GR->getRegClass(SpvType));
+          MRI->setType(ResVReg, GR->getRegType(SpvType));
+        }
+      }
+    } else {
+      SPIRVType *SpvType = GR->getOrCreateSPIRVType(OrigRetTy, MIRBuilder);
+      ResVReg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
+      MRI->setType(ResVReg, GR->getRegType(SpvType));
+      GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+    }
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
new file mode 100644
index 00000000000000..afa97ccfc0a69c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
@@ -0,0 +1,55 @@
+; The goal of the test case is to ensure that correct types are applied to virtual registers which were
+; used as return values in call lowering. The pass criterion is that spirv-val considers the output valid.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#]] = OpPhi %[[#]] %[[#Ptr:]] %[[#]] %[[#]] %[[#]]
+; CHECK-SPIRV: %[[#Ptr]] = OpPtrAccessChain %[[#]] %[[#]] %[[#]]
+
+
+%t_half = type { half }
+%t_i17 = type { [17 x i32] }
+%t_h17 = type { [17 x %t_half] }
+
+define internal spir_func void @foo(i64 %arrayinit.cur.add_4, half %r1, ptr addrspace(4) noundef align 8 dereferenceable_or_null(72) %this) {
+entry:
+  %r_3 = alloca %t_h17, align 8
+  %p_src = alloca %t_i17, align 4
+  %p_src4 = addrspacecast ptr %p_src to ptr addrspace(4)
+  %call_2 = call spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef %p_src4, i32 noundef 7)
+  br label %l_body
+
+l_body:                                           ; preds = %l_body, %entry
+  %l_done = icmp eq i64 %arrayinit.cur.add_4, 34
+  br i1 %l_done, label %exit, label %l_body
+
+exit:                                             ; preds = %l_body
+  %0 = addrspacecast ptr %call_2 to ptr addrspace(4)
+  %call_6 = call spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef %0, i32 noundef 7)
+  br label %for.cond_3
+
+for.cond_3:                                       ; preds = %for.body_3, %exit
+  %lsr.iv1 = phi ptr [ %scevgep2, %for.body_3 ], [ %call_6, %exit ]
+  %lsr.iv = phi ptr [ %scevgep, %for.body_3 ], [ %r_3, %exit ]
+  %i.0_3 = phi i64 [ 0, %exit ], [ %inc_3, %for.body_3 ]
+  %cmp_3 = icmp ult i64 %i.0_3, 17
+  br i1 %cmp_3, label %for.body_3, label %exit2
+
+for.body_3:                                       ; preds = %for.cond_3
+  %call2_5 = call spir_func noundef half @_Z17__spirv_ocl_frexpDF16_PU3AS0i(half noundef %r1, ptr noundef %lsr.iv1)
+  store half %call2_5, ptr %lsr.iv, align 2
+  %inc_3 = add nuw nsw i64 %i.0_3, 1
+  %scevgep = getelementptr i8, ptr %lsr.iv, i64 2
+  %scevgep2 = getelementptr i8, ptr %lsr.iv1, i64 4
+  br label %for.cond_3
+
+exit2:                                            ; preds = %for.cond_3
+  ret void
+}
+
+declare dso_local spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef, i32 noundef)
+declare dso_local spir_func noundef half @_Z17__spirv_ocl_frexpDF16_PU3AS0i(half noundef, ptr noundef)
+declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg)
+declare void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
index 54b2c786747768..2cba0f6ebd74be 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
@@ -2,9 +2,7 @@
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
-; CHECK-SPIRV-DAG: %[[#GlobalCharPtr:]] = OpTypePointer CrossWorkgroup %[[#Char]]
 ; CHECK-SPIRV-DAG: %[[#LocalCharPtr:]] = OpTypePointer Workgroup %[[#Char]]
-; CHECK-SPIRV-DAG: %[[#PrivateCharPtr:]] = OpTypePointer Function %[[#Char]]
 ; CHECK-SPIRV-DAG: %[[#GenericCharPtr:]] = OpTypePointer Generic %[[#Char]]
 
 ; CHECK-SPIRV-DAG: %[[#Int:]] = OpTypeInt 32 0

>From 6e23f694f2ab278bcb70c5ec1dd2629a1e0e9a8b Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 18 Nov 2024 07:50:54 -0800
Subject: [PATCH 2/5] add and use internal api call to create registers/assign
 types; fix v-reg type/class assignments

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 10 ++--
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   | 12 +----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  4 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  2 +-
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   |  4 +-
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |  5 +-
 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp  |  9 +---
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          | 49 +++++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVUtils.h            | 16 ++++++
 9 files changed, 79 insertions(+), 32 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index e34f6c3c282750..bed34b83d2e546 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -447,12 +447,8 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
                               MachineIRBuilder &MIRBuilder,
                               SPIRVGlobalRegistry *GR, LLT LowLevelType,
                               Register DestinationReg = Register(0)) {
-  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-  if (!DestinationReg.isValid()) {
-    DestinationReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
-    MRI->setType(DestinationReg, LLT::scalar(64));
-    GR->assignSPIRVTypeToVReg(BaseType, DestinationReg, MIRBuilder.getMF());
-  }
+  if (!DestinationReg.isValid())
+    DestinationReg = createVirtualRegister(BaseType, GR, MIRBuilder);
   // TODO: consider using correct address space and alignment (p0 is canonical
   // type for selection though).
   MachinePointerInfo PtrInfo = MachinePointerInfo();
@@ -2129,7 +2125,7 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
     const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
         Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
     for (unsigned I = 0; I < LocalSizeNum; ++I) {
-      Register Reg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
+      Register Reg = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
       MRI->setType(Reg, LLType);
       GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
       auto GEPInst = MIRBuilder.buildIntrinsic(
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index a7b6b0efa99551..3fdaa6aa3257ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -551,18 +551,10 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
               RetTy =
                   TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
         }
-        SPIRVType *SpvType = GR->getOrCreateSPIRVType(RetTy, MIRBuilder);
-        GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
-        if (!MRI->getRegClassOrNull(ResVReg)) {
-          MRI->setRegClass(ResVReg, GR->getRegClass(SpvType));
-          MRI->setType(ResVReg, GR->getRegType(SpvType));
-        }
+        setRegClassType(ResVReg, RetTy, GR, MIRBuilder);
       }
     } else {
-      SPIRVType *SpvType = GR->getOrCreateSPIRVType(OrigRetTy, MIRBuilder);
-      ResVReg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
-      MRI->setType(ResVReg, GR->getRegType(SpvType));
-      GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+      ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder);
     }
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 6f222883ee07de..4e539fcd6c9999 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -69,7 +69,7 @@ SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
 
 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
                                                 Register VReg,
-                                                MachineFunction &MF) {
+                                                const MachineFunction &MF) {
   VRegToTypeMap[&MF][VReg] = SpirvType;
 }
 
@@ -578,7 +578,7 @@ SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
   if (!Res.isValid()) {
     LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
-    CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
+    CurMF->getRegInfo().setRegClass(Res, &SPIRV::pIDRegClass);
     assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
     MIRBuilder.buildInstr(SPIRV::OpConstantNull)
         .addDef(Res)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 3bb86e8be69500..ff4b0ea8757fa4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -330,7 +330,7 @@ class SPIRVGlobalRegistry {
   // In cases where the SPIR-V type is already known, this function can be
   // used to map it to the given VReg via an ASSIGN_TYPE instruction.
   void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg,
-                             MachineFunction &MF);
+                             const MachineFunction &MF);
 
   // Either generate a new OpTypeXXX instruction or return an existing one
   // corresponding to the given LLVM IR type.
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 59a1bf50b771b9..b53ea1f7edf4a0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -111,8 +111,8 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                             SPIRVGlobalRegistry &GR, MachineInstr &I,
                             Register OpReg, unsigned OpIdx,
                             SPIRVType *NewPtrType) {
-  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
   MachineIRBuilder MIB(I);
+  Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
   bool Res = MIB.buildInstr(SPIRV::OpBitcast)
                  .addDef(NewReg)
                  .addUse(GR.getSPIRVTypeID(NewPtrType))
@@ -121,8 +121,6 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                                    *STI.getRegBankInfo());
   if (!Res)
     report_fatal_error("insert validation bitcast: cannot constrain all uses");
-  MRI->setRegClass(NewReg, &SPIRV::iIDRegClass);
-  GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
   I.getOperand(OpIdx).setReg(NewReg);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 460f0127d4ffcd..bd04b8c1c0b333 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -357,12 +357,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   verify(*ST.getInstrInfo());
 }
 
-static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
+static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
                                 LegalizerHelper &Helper,
                                 MachineRegisterInfo &MRI,
                                 SPIRVGlobalRegistry *GR) {
   Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
-  GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
+  MRI.setRegClass(ConvReg, GR->getRegClass(SpvType));
+  GR->assignSPIRVTypeToVReg(SpvType, ConvReg, Helper.MIRBuilder.getMF());
   Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
       .addDef(ConvReg)
       .addUse(Reg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index 11b9e4f6f6d17b..3373d8e24dab48 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -102,10 +102,7 @@ static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           if (!ResType) {
             // There was no "assign type" actions, let's fix this now
             ResType = ScalarType;
-            MRI.setRegClass(ResVReg, &SPIRV::iIDRegClass);
-            MRI.setType(ResVReg,
-                        LLT::scalar(GR->getScalarOrVectorBitWidth(ResType)));
-            GR->assignSPIRVTypeToVReg(ResType, ResVReg, *GR->CurMF);
+            setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
           }
         }
       } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
@@ -124,9 +121,7 @@ static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           if (!ResVType)
             continue;
           // Set type & class
-          MRI.setRegClass(ResVReg, GR->getRegClass(ResVType));
-          MRI.setType(ResVReg, GR->getRegType(ResVType));
-          GR->assignSPIRVTypeToVReg(ResVType, ResVReg, *GR->CurMF);
+          setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
         }
         // If this is a simple operation that is to be reduced by TableGen
         // definition we must apply some of pre-legalizer rules here
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index aeb2c29f7b8618..7e5bb1990626ff 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -13,6 +13,7 @@
 #include "SPIRVUtils.h"
 #include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
 #include "SPIRVInstrInfo.h"
 #include "SPIRVSubtarget.h"
 #include "llvm/ADT/StringRef.h"
@@ -677,4 +678,52 @@ bool getVacantFunctionName(Module &M, std::string &Name) {
   return false;
 }
 
+// Assign SPIR-V type to the register. If the register has no valid assigned
+// class or Force is set, set register LLT type and class per the SPIR-V type.
+void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
+                     MachineRegisterInfo *MRI, const MachineFunction &MF,
+                     bool Force) {
+  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
+  if (!MRI->getRegClassOrNull(Reg) || Force) {
+    MRI->setRegClass(Reg, GR->getRegClass(SpvType));
+    MRI->setType(Reg, GR->getRegType(SpvType));
+  }
+}
+
+// Create a SPIR-V type and assign it to the register. If the register has no
+// valid assigned class or Force is set, set register LLT type and class
+// according to the SPIR-V type.
+void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
+                     MachineIRBuilder &MIRBuilder, bool Force) {
+  setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
+                  MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
+}
+
+// Create a virtual register and assign SPIR-V type to the register. Set
+// register LLT type and class according to the SPIR-V type.
+Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
+                               MachineRegisterInfo *MRI,
+                               const MachineFunction &MF) {
+  Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
+  MRI->setType(Reg, GR->getRegType(SpvType));
+  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
+  return Reg;
+}
+
+// Create a virtual register and assign SPIR-V type to the register. Set
+// register LLT type and class according to the SPIR-V type.
+Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
+                               MachineIRBuilder &MIRBuilder) {
+  return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(),
+                               MIRBuilder.getMF());
+}
+
+// Create a SPIR-V type, virtual register and assign SPIR-V type to the
+// register. Set register LLT type and class according to the SPIR-V type.
+Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
+                               MachineIRBuilder &MIRBuilder) {
+  return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
+                               MIRBuilder);
+}
+
 } // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 298b0b93b0e4d2..7a292b52bd1d16 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -34,6 +34,7 @@ class Register;
 class StringRef;
 class SPIRVInstrInfo;
 class SPIRVSubtarget;
+class SPIRVGlobalRegistry;
 
 // This class implements a partial ordering visitor, which visits a cyclic graph
 // in natural topological-like ordering. Topological ordering is not defined for
@@ -355,5 +356,20 @@ MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
 #define SPIRV_BACKEND_SERVICE_FUN_NAME "__spirv_backend_service_fun"
 bool getVacantFunctionName(Module &M, std::string &Name);
 
+void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
+                     MachineIRBuilder &MIRBuilder, bool Force = false);
+void setRegClassType(Register Reg, const MachineInstr *SpvType,
+                     SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI,
+                     const MachineFunction &MF, bool Force = false);
+Register createVirtualRegister(const MachineInstr *SpvType,
+                               SPIRVGlobalRegistry *GR,
+                               MachineRegisterInfo *MRI,
+                               const MachineFunction &MF);
+Register createVirtualRegister(const MachineInstr *SpvType,
+                               SPIRVGlobalRegistry *GR,
+                               MachineIRBuilder &MIRBuilder);
+Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
+                               MachineIRBuilder &MIRBuilder);
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H

>From c87bad4db5511940a3d061bb9f94b7b9b8fb1535 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 19 Nov 2024 12:55:15 -0800
Subject: [PATCH 3/5] improve type inference: change processing order,
 calculate uncomplete types, speed up postprocessing of types

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 164 ++++++++++++------
 .../fp_two_calls.ll                           |  12 +-
 .../CodeGen/SPIRV/pointers/phi-chain-types.ll |  82 +++++++++
 3 files changed, 200 insertions(+), 58 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/phi-chain-types.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e6ef40e010dc20..c98c22641273ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -78,6 +78,11 @@ class SPIRVEmitIntrinsics
   // a register of Instructions that don't have a complete type definition
   DenseMap<Value *, unsigned> UncompleteTypeInfo;
   SmallVector<Value *> PostprocessWorklist;
+  void addToUncompleteTypeInfo(Value *Op) {
+    auto It = UncompleteTypeInfo.try_emplace(Op, PostprocessWorklist.size());
+    if (It.second)
+      PostprocessWorklist.push_back(Op);
+  }
 
   // well known result types of builtins
   enum WellKnownTypes { Event };
@@ -105,8 +110,9 @@ class SPIRVEmitIntrinsics
                                bool UnknownElemTypeI8);
 
   // deduce Types of operands of the Instruction if possible
-  void deduceOperandElementType(Instruction *I, Instruction *AskOp = 0,
-                                Type *AskTy = 0, CallInst *AssignCI = 0);
+  void deduceOperandElementType(Instruction *I,
+                                const SmallPtrSet<Value *, 4> *AskOps = nullptr,
+                                SmallPtrSet<Value *, 16> *Completed = nullptr);
 
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
@@ -145,12 +151,20 @@ class SPIRVEmitIntrinsics
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
                                   std::unordered_set<Function *> &FVisited);
+
+  bool deduceOperandElementTypeCalledFunction(
+      SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI,
+      SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy);
+  void deduceOperandElementTypeFunctionPointer(
+      CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
+      Type *&KnownElemTy, SmallPtrSet<Value *, 16> *Completed);
+
   void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
                             CallInst *AssignCI);
   void replaceAllUsesWith(Value *Src, Value *Dest, bool DeleteOld = true);
 
   bool runOnFunction(Function &F);
-  bool postprocessTypes();
+  bool postprocessTypes(Module &M);
   bool processFunctionPointers(Module &M);
 
 public:
@@ -286,11 +300,11 @@ void SPIRVEmitIntrinsics::replaceAllUsesWith(Value *Src, Value *Dest,
   if (DeleteOld) {
     unsigned Pos = It->second;
     UncompleteTypeInfo.erase(Src);
-    UncompleteTypeInfo[Dest] = Pos;
-    PostprocessWorklist[Pos] = Dest;
+    auto It = UncompleteTypeInfo.try_emplace(Dest, Pos);
+    if (It.second)
+      PostprocessWorklist[Pos] = Dest;
   } else {
-    UncompleteTypeInfo[Dest] = PostprocessWorklist.size();
-    PostprocessWorklist.push_back(Dest);
+    addToUncompleteTypeInfo(Dest);
   }
 }
 
@@ -455,10 +469,7 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
   if (isUntypedPointerTy(RefTy)) {
     if (!UnknownElemTypeI8)
       return;
-    if (auto *I = dyn_cast<Instruction>(Op)) {
-      UncompleteTypeInfo[I] = PostprocessWorklist.size();
-      PostprocessWorklist.push_back(I);
-    }
+    addToUncompleteTypeInfo(Op);
   }
   Ty = RefTy;
 }
@@ -661,10 +672,7 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I, bool UnknownElemTypeI8) {
     return Ty;
   if (!UnknownElemTypeI8)
     return nullptr;
-  if (auto *Instr = dyn_cast<Instruction>(I)) {
-    UncompleteTypeInfo[Instr] = PostprocessWorklist.size();
-    PostprocessWorklist.push_back(Instr);
-  }
+  addToUncompleteTypeInfo(I);
   return IntegerType::getInt8Ty(I->getContext());
 }
 
@@ -683,8 +691,7 @@ static inline Type *getAtomicElemTy(SPIRVGlobalRegistry *GR, Instruction *I,
 
 // Try to deduce element type for a call base. Returns false if this is an
 // indirect function invocation, and true otherwise.
-static bool deduceOperandElementTypeCalledFunction(
-    SPIRVGlobalRegistry *GR, Instruction *I,
+bool SPIRVEmitIntrinsics::deduceOperandElementTypeCalledFunction(
     SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI,
     SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
   Function *CalledF = CI->getCalledFunction();
@@ -726,7 +733,7 @@ static bool deduceOperandElementTypeCalledFunction(
       case SPIRV::OpAtomicUMax:
       case SPIRV::OpAtomicSMin:
       case SPIRV::OpAtomicSMax: {
-        KnownElemTy = getAtomicElemTy(GR, I, Op);
+        KnownElemTy = getAtomicElemTy(GR, CI, Op);
         if (!KnownElemTy)
           return true;
         Ops.push_back(std::make_pair(Op, 0));
@@ -738,32 +745,44 @@ static bool deduceOperandElementTypeCalledFunction(
 }
 
 // Try to deduce element type for a function pointer.
-static void deduceOperandElementTypeFunctionPointer(
-    SPIRVGlobalRegistry *GR, Instruction *I, CallInst *CI,
-    SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
+void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
+    CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
+    Type *&KnownElemTy, SmallPtrSet<Value *, 16> *Completed) {
   Value *Op = CI->getCalledOperand();
   if (!Op || !isPointerTy(Op->getType()))
     return;
   Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
   FunctionType *FTy = CI->getFunctionType();
-  bool IsNewFTy = false;
+  bool IsNewFTy = false, IsUncomplete = false;
   SmallVector<Type *, 4> ArgTys;
   for (Value *Arg : CI->args()) {
     Type *ArgTy = Arg->getType();
-    if (ArgTy->isPointerTy())
+    if (ArgTy->isPointerTy()) {
       if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
         IsNewFTy = true;
         ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy));
+        if (UncompleteTypeInfo.contains(Arg))
+          IsUncomplete = true;
+      } else {
+        IsUncomplete = true;
       }
+    }
     ArgTys.push_back(ArgTy);
   }
   Type *RetTy = FTy->getReturnType();
-  if (I->getType()->isPointerTy())
-    if (Type *ElemTy = GR->findDeducedElementType(I)) {
+  if (CI->getType()->isPointerTy()) {
+    if (Type *ElemTy = GR->findDeducedElementType(CI)) {
       IsNewFTy = true;
       RetTy =
-          TypedPointerType::get(ElemTy, getPointerAddressSpace(I->getType()));
+          TypedPointerType::get(ElemTy, getPointerAddressSpace(CI->getType()));
+      if (UncompleteTypeInfo.contains(CI))
+        IsUncomplete = true;
+    } else {
+      IsUncomplete = true;
     }
+  }
+  if (!Completed && IsUncomplete)
+    addToUncompleteTypeInfo(Op);
   KnownElemTy =
       IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
 }
@@ -772,10 +791,9 @@ static void deduceOperandElementTypeFunctionPointer(
 // tries to deduce them. If the Instruction has Pointer operands with known
 // types which differ from expected, this function tries to insert a bitcast to
 // resolve the issue.
-void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
-                                                   Instruction *AskOp,
-                                                   Type *AskTy,
-                                                   CallInst *AskCI) {
+void SPIRVEmitIntrinsics::deduceOperandElementType(
+    Instruction *I, const SmallPtrSet<Value *, 4> *AskOps,
+    SmallPtrSet<Value *, 16> *Completed) {
   SmallVector<std::pair<Value *, unsigned>> Ops;
   Type *KnownElemTy = nullptr;
   // look for known basic patterns of type inference
@@ -875,10 +893,9 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
     }
   } else if (CallInst *CI = dyn_cast<CallInst>(I)) {
     if (!CI->isIndirectCall())
-      deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops,
-                                             KnownElemTy);
+      deduceOperandElementTypeCalledFunction(InstrSet, CI, Ops, KnownElemTy);
     else if (HaveFunPtrs)
-      deduceOperandElementTypeFunctionPointer(GR, I, CI, Ops, KnownElemTy);
+      deduceOperandElementTypeFunctionPointer(CI, Ops, KnownElemTy, Completed);
   }
 
   // There is no enough info to deduce types or all is valid.
@@ -889,9 +906,19 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
   IRBuilder<> B(Ctx);
   for (auto &OpIt : Ops) {
     Value *Op = OpIt.first;
-    if (Op->use_empty() || (AskOp && Op != AskOp))
+    if (Op->use_empty())
       continue;
-    Type *Ty = AskOp ? AskTy : GR->findDeducedElementType(Op);
+    Type *AskTy = nullptr;
+    CallInst *AskCI = nullptr;
+    if (AskOps) {
+      auto It = AskOps->find(Op);
+      if (It == AskOps->end())
+        continue;
+      AskTy = GR->findDeducedElementType(Op);
+      AskCI = GR->findAssignPtrTypeInstr(Op);
+      assert(AskTy && AskCI);
+    }
+    Type *Ty = AskTy ? AskTy : GR->findDeducedElementType(Op);
     if (Ty == KnownElemTy)
       continue;
     Value *OpTyVal = PoisonValue::get(KnownElemTy);
@@ -899,6 +926,9 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
     if (!Ty || AskTy || isUntypedPointerTy(Ty) ||
         UncompleteTypeInfo.contains(Op)) {
       GR->addDeducedElementType(Op, KnownElemTy);
+      // check if KnownElemTy is complete
+      if (!Completed && UncompleteTypeInfo.contains(I))
+        addToUncompleteTypeInfo(Op);
       // check if there is existing Intrinsic::spv_assign_ptr_type instruction
       CallInst *AssignCI = AskCI ? AskCI : GR->findAssignPtrTypeInstr(Op);
       if (AssignCI == nullptr) {
@@ -910,6 +940,8 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
         GR->addAssignPtrTypeInstr(Op, CI);
       } else {
         updateAssignType(AssignCI, Op, OpTyVal);
+        if (Completed)
+          Completed->insert(Op);
       }
     } else {
       if (auto *OpI = dyn_cast<Instruction>(Op)) {
@@ -1878,6 +1910,7 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   for (auto &I : instructions(Func))
     Worklist.push_back(&I);
 
+  // Pass forward: use operands to deduce the instruction's result.
   for (auto &I : Worklist) {
     // Don't emit intrinsincs for convergence intrinsics.
     if (isConvergenceIntrinsic(I))
@@ -1894,9 +1927,17 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
       insertAssignPtrTypeIntrs(I, B, true);
   }
 
-  for (auto &I : instructions(Func))
+  // Pass backward: use instruction results to specify/update/cast operands
+  // where needed.
+  for (auto &I : llvm::reverse(instructions(Func)))
     deduceOperandElementType(&I);
 
+  // Pass forward for PHIs only: their operands may not precede the instruction
+  // in meaning of `instructions(Func)`.
+  for (BasicBlock &BB : Func)
+    for (PHINode &Phi : BB.phis())
+      deduceOperandElementType(&Phi);
+
   for (auto *I : Worklist) {
     TrackConstants = true;
     if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
@@ -1938,16 +1979,19 @@ void SPIRVEmitIntrinsics::replaceWithPtrcasted(Instruction *CI, Type *NewElemTy,
 }
 
 // Try to deduce a better type for pointers to untyped ptr.
-bool SPIRVEmitIntrinsics::postprocessTypes() {
-  bool Changed = false;
-  if (!GR)
-    return Changed;
+bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
+  if (!GR || UncompleteTypeInfo.size() == 0)
+    return false;
+
+  DenseMap<Value *, SmallPtrSet<Value *, 4>> ToProcess;
+  SmallPtrSet<Value *, 16> Completed;
   for (auto IB = PostprocessWorklist.rbegin(), IE = PostprocessWorklist.rend();
        IB != IE; ++IB) {
     CallInst *AssignCI = GR->findAssignPtrTypeInstr(*IB);
     Type *KnownTy = GR->findDeducedElementType(*IB);
-    if (!KnownTy || !AssignCI || !isa<Instruction>(AssignCI->getArgOperand(0)))
+    if (!KnownTy || !AssignCI)
       continue;
+    assert(AssignCI->getArgOperand(0) == *IB);
     // Try to improve the type deduced after all Functions are processed.
     if (auto *CI = dyn_cast<CallInst>(*IB)) {
       if (Function *CalledF = CI->getCalledFunction()) {
@@ -1955,24 +1999,37 @@ bool SPIRVEmitIntrinsics::postprocessTypes() {
         // Fix inconsistency between known type and function's return type.
         if (RetElemTy && RetElemTy != KnownTy) {
           replaceWithPtrcasted(CI, RetElemTy, KnownTy, AssignCI);
-          Changed = true;
+          Completed.insert(CI);
           continue;
         }
       }
     }
-    Instruction *I = cast<Instruction>(AssignCI->getArgOperand(0));
-    for (User *U : I->users()) {
+    Value *Op = AssignCI->getArgOperand(0);
+    for (User *U : Op->users()) {
       Instruction *Inst = dyn_cast<Instruction>(U);
-      if (!Inst || isa<IntrinsicInst>(Inst))
+      if (Inst && !isa<IntrinsicInst>(Inst))
+        ToProcess[Inst].insert(Op);
+    }
+  }
+  if (Completed.size() >= UncompleteTypeInfo.size())
+    return true;
+
+  for (auto &F : M) {
+    for (auto &I : llvm::reverse(instructions(F))) {
+      auto It = ToProcess.find(&I);
+      if (It == ToProcess.end())
         continue;
-      deduceOperandElementType(Inst, I, KnownTy, AssignCI);
-      if (KnownTy != GR->findDeducedElementType(I)) {
-        Changed = true;
-        break;
-      }
+      It->second.remove_if(
+          [&Completed](Value *V) { return Completed.contains(V); });
+      if (It->second.size() == 0)
+        continue;
+      deduceOperandElementType(&I, &It->second, &Completed);
+      if (Completed.size() >= UncompleteTypeInfo.size())
+        return true;
     }
   }
-  return Changed;
+
+  return Completed.size() > 0;
 }
 
 bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
@@ -1983,17 +2040,16 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
   for (auto &F : M)
     Changed |= runOnFunction(F);
 
+  // Specify function parameters after all functions were processed.
   for (auto &F : M) {
     // check if function parameter types are set
     if (!F.isDeclaration() && !F.isIntrinsic()) {
-      const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
-      GR = ST.getSPIRVGlobalRegistry();
       IRBuilder<> B(F.getContext());
       processParamTypes(&F, B);
     }
   }
 
-  Changed |= postprocessTypes();
+  Changed |= postprocessTypes(M);
   if (HaveFunPtrs)
     Changed |= processFunctionPointers(M);
 
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
index eb7b1dffaee501..621d06aa4aadee 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
@@ -1,4 +1,4 @@
-; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
 ; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DAG: OpCapability Int8
@@ -15,10 +15,14 @@
 ; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
 ; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0
 ; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]]
-; CHECK-DAG: %[[TyFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrInt8]]
-; CHECK-DAG: %[[TyPtrFp:.*]] = OpTypePointer Function %[[TyFp]]
-; CHECK-DAG: %[[TyBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrFp]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyUncompleteBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrInt8]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyPtrUncompleteBar:.*]] = OpTypePointer Function %[[TyUncompleteBar]]
+; CHECK-DAG: %[[TyUncompleteFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrUncompleteBar]]
+; CHECK-DAG: %[[TyPtrUncompleteFp:.*]] = OpTypePointer Function %[[TyUncompleteFp]]
+; CHECK-DAG: %[[TyBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrUncompleteFp]] %[[TyPtrInt8]]
 ; CHECK-DAG: %[[TyPtrBar:.*]] = OpTypePointer Function %[[TyBar]]
+; CHECK-DAG: %[[TyFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrBar]]
+; CHECK-DAG: %[[TyPtrFp:.*]] = OpTypePointer Function %[[TyFp]]
 ; CHECK-DAG: %[[TyTest:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrFp]] %[[TyPtrInt8]] %[[TyPtrBar]]
 ; CHECK: %[[test]] = OpFunction %[[TyVoid]] None %[[TyTest]]
 ; CHECK: %[[fp]] = OpFunctionParameter %[[TyPtrFp]]
diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-chain-types.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-chain-types.ll
new file mode 100644
index 00000000000000..a9e79df259c4fb
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/phi-chain-types.ll
@@ -0,0 +1,82 @@
+; The goal of the test case is to ensure that correct types are applied to PHIs used as arguments of other PHIs.
+; Pass criterion is that spirv-val considers output valid.
+
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+
+; CHECK-DAG: OpName %[[#Foo:]] "foo"
+; CHECK-DAG: OpName %[[#FooVal1:]] "val1"
+; CHECK-DAG: OpName %[[#FooVal2:]] "val2"
+; CHECK-DAG: OpName %[[#FooVal3:]] "val3"
+; CHECK-DAG: OpName %[[#Bar:]] "bar"
+; CHECK-DAG: OpName %[[#BarVal1:]] "val1"
+; CHECK-DAG: OpName %[[#BarVal2:]] "val2"
+; CHECK-DAG: OpName %[[#BarVal3:]] "val3"
+
+; CHECK-DAG: %[[#Short:]] = OpTypeInt 16 0
+; CHECK-DAG: %[[#ShortGenPtr:]] = OpTypePointer Generic %[[#Short]]
+; CHECK-DAG: %[[#ShortWrkPtr:]] = OpTypePointer Workgroup %[[#Short]]
+; CHECK-DAG: %[[#G1:]] = OpVariable %[[#ShortWrkPtr]] Workgroup
+
+; CHECK: %[[#Foo:]] = OpFunction %[[#]] None %[[#]]
+; CHECK: %[[#FooArgP:]] = OpFunctionParameter %[[#ShortGenPtr]]
+; CHECK: OpFunctionParameter
+; CHECK: OpFunctionParameter
+; CHECK: OpFunctionParameter
+; CHECK: %[[#FooG1:]] = OpPtrCastToGeneric %[[#ShortGenPtr]] %[[#G1]]
+; CHECK: %[[#FooVal2]] = OpPhi %[[#ShortGenPtr]] %[[#FooArgP]] %[[#]] %[[#FooVal3]] %[[#]]
+; CHECK: %[[#FooVal1]] = OpPhi %[[#ShortGenPtr]] %[[#FooG1]] %[[#]] %[[#FooVal2]] %[[#]]
+; CHECK: %[[#FooVal3]] = OpLoad %[[#ShortGenPtr]] %[[#]]
+
+; CHECK: %[[#Bar:]] = OpFunction %[[#]] None %[[#]]
+; CHECK: %[[#BarArgP:]] = OpFunctionParameter %[[#ShortGenPtr]]
+; CHECK: OpFunctionParameter
+; CHECK: OpFunctionParameter
+; CHECK: OpFunctionParameter
+; CHECK: %[[#BarVal3]] = OpLoad %[[#ShortGenPtr]] %[[#]]
+; CHECK: %[[#BarG1:]] = OpPtrCastToGeneric %[[#ShortGenPtr]] %[[#G1]]
+; CHECK: %[[#BarVal1]] = OpPhi %[[#ShortGenPtr]] %[[#BarG1]] %[[#]] %[[#BarVal2]] %[[#]]
+; CHECK: %[[#BarVal2]] = OpPhi %[[#ShortGenPtr]] %[[#BarArgP]] %[[#]] %[[#BarVal3]] %[[#]]
+
+ at G1 = internal addrspace(3) global i16 undef, align 8
+ at G2 = internal unnamed_addr addrspace(3) global ptr addrspace(4) undef, align 8
+
+define spir_kernel void @foo(ptr addrspace(4) %p, i1 %f1, i1 %f2, i1 %f3) {
+entry:
+  br label %l1
+
+l1:
+  br i1 %f1, label %l2, label %exit
+
+l2:
+  %val2 = phi ptr addrspace(4) [ %p, %l1 ], [ %val3, %l3 ]
+  %val1 = phi ptr addrspace(4) [ addrspacecast (ptr addrspace(3) @G1 to ptr addrspace(4)), %l1 ], [ %val2, %l3 ]
+  br i1 %f2, label %l3, label %exit
+
+l3:
+  %val3 = load ptr addrspace(4), ptr addrspace(3) @G2, align 8
+  br i1 %f3, label %l2, label %exit
+
+exit:
+  ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(4) %p, i1 %f1, i1 %f2, i1 %f3) {
+entry:
+  %val3 = load ptr addrspace(4), ptr addrspace(3) @G2, align 8
+  br label %l1
+
+l3:
+  br i1 %f3, label %l2, label %exit
+
+l1:
+  br i1 %f1, label %l2, label %exit
+
+l2:
+  %val1 = phi ptr addrspace(4) [ addrspacecast (ptr addrspace(3) @G1 to ptr addrspace(4)), %l1 ], [ %val2, %l3 ]
+  %val2 = phi ptr addrspace(4) [ %p, %l1 ], [ %val3, %l3 ]
+  br i1 %f2, label %l3, label %exit
+
+exit:
+  ret void
+}

>From ebe4f1212d38149bcedd8a80290de8b7ccdfc3c9 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 20 Nov 2024 03:48:44 -0800
Subject: [PATCH 4/5] rework uncomplete types

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 207 +++++++++++-------
 1 file changed, 129 insertions(+), 78 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index c98c22641273ec..0625d4c1469ed2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -67,7 +67,7 @@ class SPIRVEmitIntrinsics
       public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
   SPIRVTargetMachine *TM = nullptr;
   SPIRVGlobalRegistry *GR = nullptr;
-  Function *F = nullptr;
+  Function *CurrF = nullptr;
   bool TrackConstants = true;
   bool HaveFunPtrs = false;
   DenseMap<Instruction *, Constant *> AggrConsts;
@@ -76,12 +76,27 @@ class SPIRVEmitIntrinsics
   SPIRV::InstructionSet::InstructionSet InstrSet;
 
   // a register of Instructions that don't have a complete type definition
-  DenseMap<Value *, unsigned> UncompleteTypeInfo;
-  SmallVector<Value *> PostprocessWorklist;
-  void addToUncompleteTypeInfo(Value *Op) {
-    auto It = UncompleteTypeInfo.try_emplace(Op, PostprocessWorklist.size());
-    if (It.second)
-      PostprocessWorklist.push_back(Op);
+  bool CanTodoType = true;
+  bool CanUpdateType = true;
+  unsigned TodoTypeSz = 0;
+  DenseMap<Value *, bool> TodoType;
+  void insertTodoType(Value *Op) {
+    if (CanTodoType) {
+      auto It = TodoType.try_emplace(Op, true);
+      if (It.second)
+        ++TodoTypeSz;
+    }
+  }
+  void eraseTodoType(Value *Op) {
+    auto It = TodoType.find(Op);
+    if (It != TodoType.end() && It->second) {
+      TodoType[Op] = false;
+      --TodoTypeSz;
+    }
+  }
+  bool isTodoType(Value *Op) {
+    auto It = TodoType.find(Op);
+    return It != TodoType.end() && It->second;
   }
 
   // well known result types of builtins
@@ -112,7 +127,7 @@ class SPIRVEmitIntrinsics
   // deduce Types of operands of the Instruction if possible
   void deduceOperandElementType(Instruction *I,
                                 const SmallPtrSet<Value *, 4> *AskOps = nullptr,
-                                SmallPtrSet<Value *, 16> *Completed = nullptr);
+                                bool IsPostprocessing = false);
 
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
@@ -157,7 +172,7 @@ class SPIRVEmitIntrinsics
       SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy);
   void deduceOperandElementTypeFunctionPointer(
       CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
-      Type *&KnownElemTy, SmallPtrSet<Value *, 16> *Completed);
+      Type *&KnownElemTy, bool IsPostprocessing);
 
   void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
                             CallInst *AssignCI);
@@ -294,17 +309,10 @@ void SPIRVEmitIntrinsics::replaceAllUsesWith(Value *Src, Value *Dest,
   GR->updateIfExistDeducedElementType(Src, Dest, DeleteOld);
   GR->updateIfExistAssignPtrTypeInstr(Src, Dest, DeleteOld);
   // Update uncomplete type records if any
-  auto It = UncompleteTypeInfo.find(Src);
-  if (It == UncompleteTypeInfo.end())
-    return;
-  if (DeleteOld) {
-    unsigned Pos = It->second;
-    UncompleteTypeInfo.erase(Src);
-    auto It = UncompleteTypeInfo.try_emplace(Dest, Pos);
-    if (It.second)
-      PostprocessWorklist[Pos] = Dest;
-  } else {
-    addToUncompleteTypeInfo(Dest);
+  if (isTodoType(Src)) {
+    if (DeleteOld)
+      eraseTodoType(Src);
+    insertTodoType(Dest);
   }
 }
 
@@ -368,7 +376,7 @@ void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
   Value *OfType = PoisonValue::get(ElemTy);
   CallInst *AssignPtrTyCI = GR->findAssignPtrTypeInstr(Arg);
   if (AssignPtrTyCI == nullptr ||
-      AssignPtrTyCI->getParent()->getParent() != F) {
+      AssignPtrTyCI->getParent()->getParent() != CurrF) {
     AssignPtrTyCI = buildIntrWithMD(
         Intrinsic::spv_assign_ptr_type, {Arg->getType()}, OfType, Arg,
         {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
@@ -469,7 +477,7 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
   if (isUntypedPointerTy(RefTy)) {
     if (!UnknownElemTypeI8)
       return;
-    addToUncompleteTypeInfo(Op);
+    insertTodoType(Op);
   }
   Ty = RefTy;
 }
@@ -672,7 +680,7 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I, bool UnknownElemTypeI8) {
     return Ty;
   if (!UnknownElemTypeI8)
     return nullptr;
-  addToUncompleteTypeInfo(I);
+  insertTodoType(I);
   return IntegerType::getInt8Ty(I->getContext());
 }
 
@@ -747,7 +755,7 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeCalledFunction(
 // Try to deduce element type for a function pointer.
 void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
     CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
-    Type *&KnownElemTy, SmallPtrSet<Value *, 16> *Completed) {
+    Type *&KnownElemTy, bool IsPostprocessing) {
   Value *Op = CI->getCalledOperand();
   if (!Op || !isPointerTy(Op->getType()))
     return;
@@ -761,7 +769,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
       if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
         IsNewFTy = true;
         ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy));
-        if (UncompleteTypeInfo.contains(Arg))
+        if (isTodoType(Arg))
           IsUncomplete = true;
       } else {
         IsUncomplete = true;
@@ -775,14 +783,14 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
       IsNewFTy = true;
       RetTy =
           TypedPointerType::get(ElemTy, getPointerAddressSpace(CI->getType()));
-      if (UncompleteTypeInfo.contains(CI))
+      if (isTodoType(CI))
         IsUncomplete = true;
     } else {
       IsUncomplete = true;
     }
   }
-  if (!Completed && IsUncomplete)
-    addToUncompleteTypeInfo(Op);
+  if (!IsPostprocessing && IsUncomplete)
+    insertTodoType(Op);
   KnownElemTy =
       IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
 }
@@ -793,14 +801,16 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
 // resolve the issue.
 void SPIRVEmitIntrinsics::deduceOperandElementType(
     Instruction *I, const SmallPtrSet<Value *, 4> *AskOps,
-    SmallPtrSet<Value *, 16> *Completed) {
+    bool IsPostprocessing) {
   SmallVector<std::pair<Value *, unsigned>> Ops;
   Type *KnownElemTy = nullptr;
+  bool Uncomplete = false;
   // look for known basic patterns of type inference
   if (auto *Ref = dyn_cast<PHINode>(I)) {
     if (!isPointerTy(I->getType()) ||
         !(KnownElemTy = GR->findDeducedElementType(I)))
       return;
+    Uncomplete = isTodoType(I);
     for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
       Value *Op = Ref->getIncomingValue(i);
       if (isPointerTy(Op->getType()))
@@ -810,6 +820,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
     KnownElemTy = GR->findDeducedElementType(I);
     if (!KnownElemTy)
       return;
+    Uncomplete = isTodoType(I);
     Ops.push_back(std::make_pair(Ref->getPointerOperand(), 0));
   } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
     KnownElemTy = Ref->getSourceElementType();
@@ -855,27 +866,29 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
     if (!isPointerTy(I->getType()) ||
         !(KnownElemTy = GR->findDeducedElementType(I)))
       return;
+    Uncomplete = isTodoType(I);
     for (unsigned i = 0; i < Ref->getNumOperands(); i++) {
       Value *Op = Ref->getOperand(i);
       if (isPointerTy(Op->getType()))
         Ops.push_back(std::make_pair(Op, i));
     }
   } else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
-    Type *RetTy = F->getReturnType();
+    Type *RetTy = CurrF->getReturnType();
     if (!isPointerTy(RetTy))
       return;
     Value *Op = Ref->getReturnValue();
     if (!Op)
       return;
-    if (!(KnownElemTy = GR->findDeducedElementType(F))) {
+    if (!(KnownElemTy = GR->findDeducedElementType(CurrF))) {
       if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
-        GR->addDeducedElementType(F, OpElemTy);
+        GR->addDeducedElementType(CurrF, OpElemTy);
         TypedPointerType *DerivedTy =
             TypedPointerType::get(OpElemTy, getPointerAddressSpace(RetTy));
-        GR->addReturnType(F, DerivedTy);
+        GR->addReturnType(CurrF, DerivedTy);
       }
       return;
     }
+    Uncomplete = isTodoType(CurrF);
     Ops.push_back(std::make_pair(Op, 0));
   } else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
     if (!isPointerTy(Ref->getOperand(0)->getType()))
@@ -886,34 +899,36 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
     Type *ElemTy1 = GR->findDeducedElementType(Op1);
     if (ElemTy0) {
       KnownElemTy = ElemTy0;
+      Uncomplete = isTodoType(Op0);
       Ops.push_back(std::make_pair(Op1, 1));
     } else if (ElemTy1) {
       KnownElemTy = ElemTy1;
+      Uncomplete = isTodoType(Op1);
       Ops.push_back(std::make_pair(Op0, 0));
     }
   } else if (CallInst *CI = dyn_cast<CallInst>(I)) {
     if (!CI->isIndirectCall())
       deduceOperandElementTypeCalledFunction(InstrSet, CI, Ops, KnownElemTy);
     else if (HaveFunPtrs)
-      deduceOperandElementTypeFunctionPointer(CI, Ops, KnownElemTy, Completed);
+      deduceOperandElementTypeFunctionPointer(CI, Ops, KnownElemTy,
+                                              IsPostprocessing);
   }
 
   // There is no enough info to deduce types or all is valid.
   if (!KnownElemTy || Ops.size() == 0)
     return;
 
-  LLVMContext &Ctx = F->getContext();
+  LLVMContext &Ctx = CurrF->getContext();
   IRBuilder<> B(Ctx);
   for (auto &OpIt : Ops) {
     Value *Op = OpIt.first;
     if (Op->use_empty())
       continue;
+    if (AskOps && !AskOps->contains(Op))
+      continue;
     Type *AskTy = nullptr;
     CallInst *AskCI = nullptr;
-    if (AskOps) {
-      auto It = AskOps->find(Op);
-      if (It == AskOps->end())
-        continue;
+    if (IsPostprocessing && AskOps) {
       AskTy = GR->findDeducedElementType(Op);
       AskCI = GR->findAssignPtrTypeInstr(Op);
       assert(AskTy && AskCI);
@@ -923,12 +938,14 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
       continue;
     Value *OpTyVal = PoisonValue::get(KnownElemTy);
     Type *OpTy = Op->getType();
-    if (!Ty || AskTy || isUntypedPointerTy(Ty) ||
-        UncompleteTypeInfo.contains(Op)) {
+    if (!Ty || (CanUpdateType &&
+                (AskTy || isUntypedPointerTy(Ty) || isTodoType(Op)))) {
       GR->addDeducedElementType(Op, KnownElemTy);
       // check if KnownElemTy is complete
-      if (!Completed && UncompleteTypeInfo.contains(I))
-        addToUncompleteTypeInfo(Op);
+      if (!Uncomplete)
+        eraseTodoType(Op);
+      else if (!IsPostprocessing)
+        insertTodoType(Op);
       // check if there is existing Intrinsic::spv_assign_ptr_type instruction
       CallInst *AssignCI = AskCI ? AskCI : GR->findAssignPtrTypeInstr(Op);
       if (AssignCI == nullptr) {
@@ -940,10 +957,9 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
         GR->addAssignPtrTypeInstr(Op, CI);
       } else {
         updateAssignType(AssignCI, Op, OpTyVal);
-        if (Completed)
-          Completed->insert(Op);
       }
     } else {
+      eraseTodoType(Op);
       if (auto *OpI = dyn_cast<Instruction>(Op)) {
         // spv_ptrcast's argument Op denotes an instruction that generates
         // a value, and we may use getInsertionPointAfterDef()
@@ -953,7 +969,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
         B.SetInsertPointPastAllocas(OpA->getParent());
         B.SetCurrentDebugLocation(DebugLoc());
       } else {
-        B.SetInsertPoint(F->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());
+        B.SetInsertPoint(CurrF->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());
       }
       SmallVector<Type *, 2> Types = {OpTy, OpTy};
       SmallVector<Value *, 2> Args = {Op, buildMD(OpTyVal),
@@ -993,7 +1009,7 @@ void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
 
 void SPIRVEmitIntrinsics::preprocessUndefs(IRBuilder<> &B) {
   std::queue<Instruction *> Worklist;
-  for (auto &I : instructions(F))
+  for (auto &I : instructions(CurrF))
     Worklist.push(&I);
 
   while (!Worklist.empty()) {
@@ -1021,7 +1037,7 @@ void SPIRVEmitIntrinsics::preprocessUndefs(IRBuilder<> &B) {
 
 void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
   std::queue<Instruction *> Worklist;
-  for (auto &I : instructions(F))
+  for (auto &I : instructions(CurrF))
     Worklist.push(&I);
 
   while (!Worklist.empty()) {
@@ -1080,7 +1096,7 @@ Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
     return &Call;
 
   const InlineAsm *IA = cast<InlineAsm>(Call.getCalledOperand());
-  LLVMContext &Ctx = F->getContext();
+  LLVMContext &Ctx = CurrF->getContext();
 
   Constant *TyC = UndefValue::get(IA->getFunctionType());
   MDString *ConstraintString = MDString::get(Ctx, IA->getConstraintString());
@@ -1281,10 +1297,10 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
                                                          IRBuilder<> &B) {
   // Handle basic instructions:
   StoreInst *SI = dyn_cast<StoreInst>(I);
-  if (IsKernelArgInt8(F, SI)) {
+  if (IsKernelArgInt8(CurrF, SI)) {
     return replacePointerOperandWithPtrCast(
-        I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0,
-        B);
+        I, SI->getValueOperand(), IntegerType::getInt8Ty(CurrF->getContext()),
+        0, B);
   } else if (SI) {
     Value *Op = SI->getValueOperand();
     Type *OpTy = Op->getType();
@@ -1451,7 +1467,7 @@ Instruction *SPIRVEmitIntrinsics::visitLoadInst(LoadInst &I) {
   TrackConstants = false;
   const auto *TLI = TM->getSubtargetImpl()->getTargetLowering();
   MachineMemOperand::Flags Flags =
-      TLI->getLoadMemOperandFlags(I, F->getDataLayout());
+      TLI->getLoadMemOperandFlags(I, CurrF->getDataLayout());
   auto *NewI =
       B.CreateIntrinsic(Intrinsic::spv_load, {I.getOperand(0)->getType()},
                         {I.getPointerOperand(), B.getInt16(Flags),
@@ -1468,7 +1484,7 @@ Instruction *SPIRVEmitIntrinsics::visitStoreInst(StoreInst &I) {
   TrackConstants = false;
   const auto *TLI = TM->getSubtargetImpl()->getTargetLowering();
   MachineMemOperand::Flags Flags =
-      TLI->getStoreMemOperandFlags(I, F->getDataLayout());
+      TLI->getStoreMemOperandFlags(I, CurrF->getDataLayout());
   auto *PtrOp = I.getPointerOperand();
   auto *NewI = B.CreateIntrinsic(
       Intrinsic::spv_store, {I.getValueOperand()->getType(), PtrOp->getType()},
@@ -1774,9 +1790,28 @@ void SPIRVEmitIntrinsics::processParamTypesByFunHeader(Function *F,
     if (!isUntypedPointerTy(Arg->getType()))
       continue;
     Type *ElemTy = GR->findDeducedElementType(Arg);
-    if (!ElemTy && hasPointeeTypeAttr(Arg) &&
-        (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr)
+    if (ElemTy)
+      continue;
+    if (hasPointeeTypeAttr(Arg) &&
+        (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) {
       buildAssignPtr(B, ElemTy, Arg);
+      continue;
+    }
+    if (HaveFunPtrs) {
+      for (User *U : Arg->users()) {
+        CallInst *CI = dyn_cast<CallInst>(U);
+        if (CI && !isa<IntrinsicInst>(CI) && CI->isIndirectCall() &&
+            CI->getCalledOperand() == Arg &&
+            CI->getParent()->getParent() == CurrF) {
+          SmallVector<std::pair<Value *, unsigned>> Ops;
+          deduceOperandElementTypeFunctionPointer(CI, Ops, ElemTy, false);
+          if (ElemTy) {
+            buildAssignPtr(B, ElemTy, Arg);
+            break;
+          }
+        }
+      }
+    }
   }
 }
 
@@ -1877,17 +1912,17 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std
                               : SPIRV::InstructionSet::GLSL_std_450;
 
-  if (!F)
+  if (!CurrF)
     HaveFunPtrs =
         ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
 
-  F = &Func;
+  CurrF = &Func;
   IRBuilder<> B(Func.getContext());
   AggrConsts.clear();
   AggrConstTypes.clear();
   AggrStores.clear();
 
-  processParamTypesByFunHeader(F, B);
+  processParamTypesByFunHeader(CurrF, B);
 
   // StoreInst's operand type can be changed during the next transformations,
   // so we need to store it in the set. Also store already transformed types.
@@ -1936,7 +1971,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   // in meaning of `instructions(Func)`.
   for (BasicBlock &BB : Func)
     for (PHINode &Phi : BB.phis())
-      deduceOperandElementType(&Phi);
+      if (isPointerTy(Phi.getType()))
+        deduceOperandElementType(&Phi);
 
   for (auto *I : Worklist) {
     TrackConstants = true;
@@ -1980,63 +2016,65 @@ void SPIRVEmitIntrinsics::replaceWithPtrcasted(Instruction *CI, Type *NewElemTy,
 
 // Try to deduce a better type for pointers to untyped ptr.
 bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
-  if (!GR || UncompleteTypeInfo.size() == 0)
+  if (!GR || TodoTypeSz == 0)
     return false;
 
+  unsigned SzTodo = TodoTypeSz;
   DenseMap<Value *, SmallPtrSet<Value *, 4>> ToProcess;
-  SmallPtrSet<Value *, 16> Completed;
-  for (auto IB = PostprocessWorklist.rbegin(), IE = PostprocessWorklist.rend();
-       IB != IE; ++IB) {
-    CallInst *AssignCI = GR->findAssignPtrTypeInstr(*IB);
-    Type *KnownTy = GR->findDeducedElementType(*IB);
+  for (auto [Op, Enabled] : TodoType) {
+    if (!Enabled)
+      continue;
+    CallInst *AssignCI = GR->findAssignPtrTypeInstr(Op);
+    Type *KnownTy = GR->findDeducedElementType(Op);
     if (!KnownTy || !AssignCI)
       continue;
-    assert(AssignCI->getArgOperand(0) == *IB);
+    assert(Op == AssignCI->getArgOperand(0));
     // Try to improve the type deduced after all Functions are processed.
-    if (auto *CI = dyn_cast<CallInst>(*IB)) {
+    if (auto *CI = dyn_cast<CallInst>(Op)) {
+      // TODO: deduceElementTypeHelper() & replaceWithPtrcasted() if
+      // isa<Instruction>(Op)
+      CurrF = CI->getParent()->getParent();
       if (Function *CalledF = CI->getCalledFunction()) {
         Type *RetElemTy = GR->findDeducedElementType(CalledF);
         // Fix inconsistency between known type and function's return type.
         if (RetElemTy && RetElemTy != KnownTy) {
           replaceWithPtrcasted(CI, RetElemTy, KnownTy, AssignCI);
-          Completed.insert(CI);
+          eraseTodoType(Op);
           continue;
         }
       }
     }
-    Value *Op = AssignCI->getArgOperand(0);
     for (User *U : Op->users()) {
       Instruction *Inst = dyn_cast<Instruction>(U);
       if (Inst && !isa<IntrinsicInst>(Inst))
         ToProcess[Inst].insert(Op);
     }
   }
-  if (Completed.size() >= UncompleteTypeInfo.size())
+  if (TodoTypeSz == 0)
     return true;
 
   for (auto &F : M) {
+    CurrF = &F;
     for (auto &I : llvm::reverse(instructions(F))) {
       auto It = ToProcess.find(&I);
       if (It == ToProcess.end())
         continue;
-      It->second.remove_if(
-          [&Completed](Value *V) { return Completed.contains(V); });
+      It->second.remove_if([this](Value *V) { return !isTodoType(V); });
       if (It->second.size() == 0)
         continue;
-      deduceOperandElementType(&I, &It->second, &Completed);
-      if (Completed.size() >= UncompleteTypeInfo.size())
+      deduceOperandElementType(&I, &It->second, true);
+      if (TodoTypeSz == 0)
         return true;
     }
   }
 
-  return Completed.size() > 0;
+  return SzTodo > TodoTypeSz;
 }
 
 bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
   bool Changed = false;
 
-  UncompleteTypeInfo.clear();
-  PostprocessWorklist.clear();
+  TodoType.clear();
   for (auto &F : M)
     Changed |= runOnFunction(F);
 
@@ -2049,7 +2087,20 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
     }
   }
 
+  CanTodoType = false;
   Changed |= postprocessTypes(M);
+
+  // Validation pass.
+  CanUpdateType = false;
+  TodoType.clear();
+  for (auto &F : M) {
+    CurrF = &F;
+    for (BasicBlock &BB : F)
+      for (PHINode &Phi : BB.phis())
+        if (isPointerTy(Phi.getType()))
+          deduceOperandElementType(&Phi, nullptr, true);
+  }
+
   if (HaveFunPtrs)
     Changed |= processFunctionPointers(M);
 

>From 23ed87d14587ca2f3fb4d2c33c699adafeaa23b5 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 20 Nov 2024 04:57:31 -0800
Subject: [PATCH 5/5] fix function pointers and dealing with uncomplete types

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp     | 15 +--------------
 .../SPV_INTEL_function_pointers/fp_two_calls.ll   | 14 +++++++-------
 2 files changed, 8 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 0625d4c1469ed2..7460e0a71aae51 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -77,7 +77,6 @@ class SPIRVEmitIntrinsics
 
   // a register of Instructions that don't have a complete type definition
   bool CanTodoType = true;
-  bool CanUpdateType = true;
   unsigned TodoTypeSz = 0;
   DenseMap<Value *, bool> TodoType;
   void insertTodoType(Value *Op) {
@@ -938,8 +937,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
       continue;
     Value *OpTyVal = PoisonValue::get(KnownElemTy);
     Type *OpTy = Op->getType();
-    if (!Ty || (CanUpdateType &&
-                (AskTy || isUntypedPointerTy(Ty) || isTodoType(Op)))) {
+    if (!Ty || AskTy || isUntypedPointerTy(Ty) || isTodoType(Op)) {
       GR->addDeducedElementType(Op, KnownElemTy);
       // check if KnownElemTy is complete
       if (!Uncomplete)
@@ -2090,17 +2088,6 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
   CanTodoType = false;
   Changed |= postprocessTypes(M);
 
-  // Validation pass.
-  CanUpdateType = false;
-  TodoType.clear();
-  for (auto &F : M) {
-    CurrF = &F;
-    for (BasicBlock &BB : F)
-      for (PHINode &Phi : BB.phis())
-        if (isPointerTy(Phi.getType()))
-          deduceOperandElementType(&Phi, nullptr, true);
-  }
-
   if (HaveFunPtrs)
     Changed |= processFunctionPointers(M);
 
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
index 621d06aa4aadee..1b217c3bb92f16 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll
@@ -12,17 +12,17 @@
 ; CHECK-DAG: OpName %[[test:.*]] "test"
 ; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
 ; CHECK-DAG: %[[TyFloat32:.*]] = OpTypeFloat 32
-; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
 ; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0
+; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
 ; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]]
-; CHECK-DAG: %[[TyUncompleteBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrInt8]] %[[TyPtrInt8]]
-; CHECK-DAG: %[[TyPtrUncompleteBar:.*]] = OpTypePointer Function %[[TyUncompleteBar]]
-; CHECK-DAG: %[[TyUncompleteFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrUncompleteBar]]
+; CHECK-DAG: %[[TyUncompleteFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrInt8]]
 ; CHECK-DAG: %[[TyPtrUncompleteFp:.*]] = OpTypePointer Function %[[TyUncompleteFp]]
-; CHECK-DAG: %[[TyBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrUncompleteFp]] %[[TyPtrInt8]]
-; CHECK-DAG: %[[TyPtrBar:.*]] = OpTypePointer Function %[[TyBar]]
-; CHECK-DAG: %[[TyFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrBar]]
+; CHECK-DAG: %[[TyUncompleteBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrUncompleteFp]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyPtrUncompleteBar:.*]] = OpTypePointer Function %[[TyUncompleteBar]]
+; CHECK-DAG: %[[TyFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrUncompleteBar]]
 ; CHECK-DAG: %[[TyPtrFp:.*]] = OpTypePointer Function %[[TyFp]]
+; CHECK-DAG: %[[TyBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrFp]] %[[TyPtrInt8]]
+; CHECK-DAG: %[[TyPtrBar:.*]] = OpTypePointer Function %[[TyBar]]
 ; CHECK-DAG: %[[TyTest:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrFp]] %[[TyPtrInt8]] %[[TyPtrBar]]
 ; CHECK: %[[test]] = OpFunction %[[TyVoid]] None %[[TyTest]]
 ; CHECK: %[[fp]] = OpFunctionParameter %[[TyPtrFp]]



More information about the llvm-commits mailing list