[llvm] [SPIR-V] Improve type inference: fix types of return values in call lowering (PR #116609)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 18 07:51:06 PST 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/116609

>From d89bb33f0adb8d98fe668765353f4d79ec64dbd8 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 18 Nov 2024 04:47:27 -0800
Subject: [PATCH 1/2] Improve type inference: return values in call lowering

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 18 +-----
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   | 25 +++++++++
 .../SPIRV/pointers/builtin-ret-reg-type.ll    | 55 +++++++++++++++++++
 .../SPIRV/transcoding/OpGenericCastToPtr.ll   |  2 -
 4 files changed, 83 insertions(+), 17 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 06a37f1f559d44..e34f6c3c282750 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2517,23 +2517,11 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  SPIRVGlobalRegistry *GR) {
   LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
 
-  // SPIR-V type and return register.
-  Register ReturnRegister = OrigRet;
-  SPIRVType *ReturnType = nullptr;
-  if (OrigRetTy && !OrigRetTy->isVoidTy()) {
-    ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
-    if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
-      MIRBuilder.getMRI()->setRegClass(ReturnRegister,
-                                       GR->getRegClass(ReturnType));
-  } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
-    ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
-    MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(64));
-    ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
-  }
-
   // Lookup the builtin in the TableGen records.
+  SPIRVType *SpvType = GR->getSPIRVTypeForVReg(OrigRet);
+  assert(SpvType && "Inconsistent return register: expected valid type info");
   std::unique_ptr<const IncomingCall> Call =
-      lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
+      lookupBuiltin(DemangledCall, Set, OrigRet, SpvType, Args);
 
   if (!Call) {
     LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 3c5397319aaf21..a7b6b0efa99551 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -539,6 +539,31 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   if (isFunctionDecl && !DemangledName.empty() &&
       (canUseGLSL || canUseOpenCL)) {
+    if (ResVReg.isValid()) {
+      if (!GR->getSPIRVTypeForVReg(ResVReg)) {
+        const Type *RetTy = OrigRetTy;
+        if (auto *PtrRetTy = dyn_cast<PointerType>(OrigRetTy)) {
+          const Value *OrigValue = Info.OrigRet.OrigValue;
+          if (!OrigValue)
+            OrigValue = Info.CB;
+          if (OrigValue)
+            if (Type *ElemTy = GR->findDeducedElementType(OrigValue))
+              RetTy =
+                  TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
+        }
+        SPIRVType *SpvType = GR->getOrCreateSPIRVType(RetTy, MIRBuilder);
+        GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+        if (!MRI->getRegClassOrNull(ResVReg)) {
+          MRI->setRegClass(ResVReg, GR->getRegClass(SpvType));
+          MRI->setType(ResVReg, GR->getRegType(SpvType));
+        }
+      }
+    } else {
+      SPIRVType *SpvType = GR->getOrCreateSPIRVType(OrigRetTy, MIRBuilder);
+      ResVReg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
+      MRI->setType(ResVReg, GR->getRegType(SpvType));
+      GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+    }
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
new file mode 100644
index 00000000000000..afa97ccfc0a69c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
@@ -0,0 +1,55 @@
+; The goal of the test case is to ensure that correct types are applied to virtual registers which were
+; used as return values in call lowering. Pass criterion is that spirv-val considers output valid.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#]] = OpPhi %[[#]] %[[#Ptr:]] %[[#]] %[[#]] %[[#]]
+; CHECK-SPIRV: %[[#Ptr]] = OpPtrAccessChain %[[#]] %[[#]] %[[#]]
+
+
+%t_half = type { half }
+%t_i17 = type { [17 x i32] }
+%t_h17 = type { [17 x %t_half] }
+
+define internal spir_func void @foo(i64 %arrayinit.cur.add_4, half %r1, ptr addrspace(4) noundef align 8 dereferenceable_or_null(72) %this) {
+entry:
+  %r_3 = alloca %t_h17, align 8
+  %p_src = alloca %t_i17, align 4
+  %p_src4 = addrspacecast ptr %p_src to ptr addrspace(4)
+  %call_2 = call spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef %p_src4, i32 noundef 7)
+  br label %l_body
+
+l_body:                                           ; preds = %l_body, %entry
+  %l_done = icmp eq i64 %arrayinit.cur.add_4, 34
+  br i1 %l_done, label %exit, label %l_body
+
+exit:                                             ; preds = %l_body
+  %0 = addrspacecast ptr %call_2 to ptr addrspace(4)
+  %call_6 = call spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef %0, i32 noundef 7)
+  br label %for.cond_3
+
+for.cond_3:                                       ; preds = %for.body_3, %exit
+  %lsr.iv1 = phi ptr [ %scevgep2, %for.body_3 ], [ %call_6, %exit ]
+  %lsr.iv = phi ptr [ %scevgep, %for.body_3 ], [ %r_3, %exit ]
+  %i.0_3 = phi i64 [ 0, %exit ], [ %inc_3, %for.body_3 ]
+  %cmp_3 = icmp ult i64 %i.0_3, 17
+  br i1 %cmp_3, label %for.body_3, label %exit2
+
+for.body_3:                                       ; preds = %for.cond_3
+  %call2_5 = call spir_func noundef half @_Z17__spirv_ocl_frexpDF16_PU3AS0i(half noundef %r1, ptr noundef %lsr.iv1)
+  store half %call2_5, ptr %lsr.iv, align 2
+  %inc_3 = add nuw nsw i64 %i.0_3, 1
+  %scevgep = getelementptr i8, ptr %lsr.iv, i64 2
+  %scevgep2 = getelementptr i8, ptr %lsr.iv1, i64 4
+  br label %for.cond_3
+
+exit2:                                            ; preds = %for.cond_3
+  ret void
+}
+
+declare dso_local spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef, i32 noundef)
+declare dso_local spir_func noundef half @_Z17__spirv_ocl_frexpDF16_PU3AS0i(half noundef, ptr noundef)
+declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg)
+declare void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
index 54b2c786747768..2cba0f6ebd74be 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
@@ -2,9 +2,7 @@
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
-; CHECK-SPIRV-DAG: %[[#GlobalCharPtr:]] = OpTypePointer CrossWorkgroup %[[#Char]]
 ; CHECK-SPIRV-DAG: %[[#LocalCharPtr:]] = OpTypePointer Workgroup %[[#Char]]
-; CHECK-SPIRV-DAG: %[[#PrivateCharPtr:]] = OpTypePointer Function %[[#Char]]
 ; CHECK-SPIRV-DAG: %[[#GenericCharPtr:]] = OpTypePointer Generic %[[#Char]]
 
 ; CHECK-SPIRV-DAG: %[[#Int:]] = OpTypeInt 32 0

>From 6e23f694f2ab278bcb70c5ec1dd2629a1e0e9a8b Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 18 Nov 2024 07:50:54 -0800
Subject: [PATCH 2/2] add and use internal api call to create registers/assign
 types; fix v-reg type/class assignments

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 10 ++--
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   | 12 +----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  4 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  2 +-
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   |  4 +-
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |  5 +-
 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp  |  9 +---
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          | 49 +++++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVUtils.h            | 16 ++++++
 9 files changed, 79 insertions(+), 32 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index e34f6c3c282750..bed34b83d2e546 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -447,12 +447,8 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
                               MachineIRBuilder &MIRBuilder,
                               SPIRVGlobalRegistry *GR, LLT LowLevelType,
                               Register DestinationReg = Register(0)) {
-  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-  if (!DestinationReg.isValid()) {
-    DestinationReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
-    MRI->setType(DestinationReg, LLT::scalar(64));
-    GR->assignSPIRVTypeToVReg(BaseType, DestinationReg, MIRBuilder.getMF());
-  }
+  if (!DestinationReg.isValid())
+    DestinationReg = createVirtualRegister(BaseType, GR, MIRBuilder);
   // TODO: consider using correct address space and alignment (p0 is canonical
   // type for selection though).
   MachinePointerInfo PtrInfo = MachinePointerInfo();
@@ -2129,7 +2125,7 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
     const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
         Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
     for (unsigned I = 0; I < LocalSizeNum; ++I) {
-      Register Reg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
+      Register Reg = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
       MRI->setType(Reg, LLType);
       GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
       auto GEPInst = MIRBuilder.buildIntrinsic(
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index a7b6b0efa99551..3fdaa6aa3257ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -551,18 +551,10 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
               RetTy =
                   TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
         }
-        SPIRVType *SpvType = GR->getOrCreateSPIRVType(RetTy, MIRBuilder);
-        GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
-        if (!MRI->getRegClassOrNull(ResVReg)) {
-          MRI->setRegClass(ResVReg, GR->getRegClass(SpvType));
-          MRI->setType(ResVReg, GR->getRegType(SpvType));
-        }
+        setRegClassType(ResVReg, RetTy, GR, MIRBuilder);
       }
     } else {
-      SPIRVType *SpvType = GR->getOrCreateSPIRVType(OrigRetTy, MIRBuilder);
-      ResVReg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
-      MRI->setType(ResVReg, GR->getRegType(SpvType));
-      GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+      ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder);
     }
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 6f222883ee07de..4e539fcd6c9999 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -69,7 +69,7 @@ SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
 
 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
                                                 Register VReg,
-                                                MachineFunction &MF) {
+                                                const MachineFunction &MF) {
   VRegToTypeMap[&MF][VReg] = SpirvType;
 }
 
@@ -578,7 +578,7 @@ SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
   if (!Res.isValid()) {
     LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
-    CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
+    CurMF->getRegInfo().setRegClass(Res, &SPIRV::pIDRegClass);
     assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
     MIRBuilder.buildInstr(SPIRV::OpConstantNull)
         .addDef(Res)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 3bb86e8be69500..ff4b0ea8757fa4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -330,7 +330,7 @@ class SPIRVGlobalRegistry {
   // In cases where the SPIR-V type is already known, this function can be
   // used to map it to the given VReg via an ASSIGN_TYPE instruction.
   void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg,
-                             MachineFunction &MF);
+                             const MachineFunction &MF);
 
   // Either generate a new OpTypeXXX instruction or return an existing one
   // corresponding to the given LLVM IR type.
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 59a1bf50b771b9..b53ea1f7edf4a0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -111,8 +111,8 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                             SPIRVGlobalRegistry &GR, MachineInstr &I,
                             Register OpReg, unsigned OpIdx,
                             SPIRVType *NewPtrType) {
-  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
   MachineIRBuilder MIB(I);
+  Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
   bool Res = MIB.buildInstr(SPIRV::OpBitcast)
                  .addDef(NewReg)
                  .addUse(GR.getSPIRVTypeID(NewPtrType))
@@ -121,8 +121,6 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                                    *STI.getRegBankInfo());
   if (!Res)
     report_fatal_error("insert validation bitcast: cannot constrain all uses");
-  MRI->setRegClass(NewReg, &SPIRV::iIDRegClass);
-  GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
   I.getOperand(OpIdx).setReg(NewReg);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 460f0127d4ffcd..bd04b8c1c0b333 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -357,12 +357,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   verify(*ST.getInstrInfo());
 }
 
-static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
+static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
                                 LegalizerHelper &Helper,
                                 MachineRegisterInfo &MRI,
                                 SPIRVGlobalRegistry *GR) {
   Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
-  GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
+  MRI.setRegClass(ConvReg, GR->getRegClass(SpvType));
+  GR->assignSPIRVTypeToVReg(SpvType, ConvReg, Helper.MIRBuilder.getMF());
   Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
       .addDef(ConvReg)
       .addUse(Reg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index 11b9e4f6f6d17b..3373d8e24dab48 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -102,10 +102,7 @@ static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           if (!ResType) {
             // There was no "assign type" actions, let's fix this now
             ResType = ScalarType;
-            MRI.setRegClass(ResVReg, &SPIRV::iIDRegClass);
-            MRI.setType(ResVReg,
-                        LLT::scalar(GR->getScalarOrVectorBitWidth(ResType)));
-            GR->assignSPIRVTypeToVReg(ResType, ResVReg, *GR->CurMF);
+            setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
           }
         }
       } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
@@ -124,9 +121,7 @@ static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           if (!ResVType)
             continue;
           // Set type & class
-          MRI.setRegClass(ResVReg, GR->getRegClass(ResVType));
-          MRI.setType(ResVReg, GR->getRegType(ResVType));
-          GR->assignSPIRVTypeToVReg(ResVType, ResVReg, *GR->CurMF);
+          setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
         }
         // If this is a simple operation that is to be reduced by TableGen
         // definition we must apply some of pre-legalizer rules here
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index aeb2c29f7b8618..7e5bb1990626ff 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -13,6 +13,7 @@
 #include "SPIRVUtils.h"
 #include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
 #include "SPIRVInstrInfo.h"
 #include "SPIRVSubtarget.h"
 #include "llvm/ADT/StringRef.h"
@@ -677,4 +678,52 @@ bool getVacantFunctionName(Module &M, std::string &Name) {
   return false;
 }
 
+// Assign SPIR-V type to the register. If the register has no assigned class
+// or Force is set, set its LLT type and class according to the SPIR-V type.
+void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
+                     MachineRegisterInfo *MRI, const MachineFunction &MF,
+                     bool Force) {
+  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
+  if (!MRI->getRegClassOrNull(Reg) || Force) {
+    MRI->setRegClass(Reg, GR->getRegClass(SpvType));
+    MRI->setType(Reg, GR->getRegType(SpvType));
+  }
+}
+
+// Create a SPIR-V type, assign SPIR-V type to the register. If the register has
+// no assigned class or Force is set, set register LLT type and class according
+// to the SPIR-V type.
+void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
+                     MachineIRBuilder &MIRBuilder, bool Force) {
+  setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
+                  MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
+}
+
+// Create a virtual register and assign SPIR-V type to the register. Set
+// register LLT type and class according to the SPIR-V type.
+Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
+                               MachineRegisterInfo *MRI,
+                               const MachineFunction &MF) {
+  Register Reg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
+  MRI->setType(Reg, GR->getRegType(SpvType));
+  GR->assignSPIRVTypeToVReg(SpvType, Reg, MF);
+  return Reg;
+}
+
+// Create a virtual register and assign SPIR-V type to the register. Set
+// register LLT type and class according to the SPIR-V type.
+Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
+                               MachineIRBuilder &MIRBuilder) {
+  return createVirtualRegister(SpvType, GR, MIRBuilder.getMRI(),
+                               MIRBuilder.getMF());
+}
+
+// Create a SPIR-V type, virtual register and assign SPIR-V type to the
+// register. Set register LLT type and class according to the SPIR-V type.
+Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
+                               MachineIRBuilder &MIRBuilder) {
+  return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
+                               MIRBuilder);
+}
+
 } // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 298b0b93b0e4d2..7a292b52bd1d16 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -34,6 +34,7 @@ class Register;
 class StringRef;
 class SPIRVInstrInfo;
 class SPIRVSubtarget;
+class SPIRVGlobalRegistry;
 
 // This class implements a partial ordering visitor, which visits a cyclic graph
 // in natural topological-like ordering. Topological ordering is not defined for
@@ -355,5 +356,20 @@ MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
 #define SPIRV_BACKEND_SERVICE_FUN_NAME "__spirv_backend_service_fun"
 bool getVacantFunctionName(Module &M, std::string &Name);
 
+void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
+                     MachineIRBuilder &MIRBuilder, bool Force = false);
+void setRegClassType(Register Reg, const MachineInstr *SpvType,
+                     SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI,
+                     const MachineFunction &MF, bool Force = false);
+Register createVirtualRegister(const MachineInstr *SpvType,
+                               SPIRVGlobalRegistry *GR,
+                               MachineRegisterInfo *MRI,
+                               const MachineFunction &MF);
+Register createVirtualRegister(const MachineInstr *SpvType,
+                               SPIRVGlobalRegistry *GR,
+                               MachineIRBuilder &MIRBuilder);
+Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
+                               MachineIRBuilder &MIRBuilder);
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H



More information about the llvm-commits mailing list