[llvm] [SPIRV][NFC] Refactor pointer creation in GlobalRegistry (PR #134429)

Steven Perron via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 4 11:26:40 PDT 2025


https://github.com/s-perron created https://github.com/llvm/llvm-project/pull/134429

This PR adds new interfaces for creating pointer types and adds some
requirements (enforced by assertions) to the old interfaces. This is the first step in
https://github.com/llvm/llvm-project/issues/134119.


>From 87e342c8cdf8e94a86cfd5442013270a58ee4a58 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 25 Mar 2025 13:09:23 -0400
Subject: [PATCH] [SPIRV][NFC] Refactor pointer creation in GlobalRegistry

This PR adds new interfaces for creating pointer types and adds some
requirements (enforced by assertions) to the old interfaces. This is the first step in
https://github.com/llvm/llvm-project/issues/134119.
---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   | 14 +---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 84 ++++++++++++++++---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   | 31 ++++++-
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   |  6 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 52 +++++-------
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   | 13 +--
 6 files changed, 135 insertions(+), 65 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index d55631e0146cf..5ec8c22dbf473 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -215,11 +215,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
   Argument *Arg = F.getArg(ArgIdx);
   Type *ArgType = Arg->getType();
   if (isTypedPointerTy(ArgType)) {
-    SPIRVType *ElementType = GR->getOrCreateSPIRVType(
-        cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder,
-        SPIRV::AccessQualifier::ReadWrite, true);
     return GR->getOrCreateSPIRVPointerType(
-        ElementType, MIRBuilder,
+        cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder,
         addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
   }
 
@@ -232,11 +229,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
   // spv_assign_ptr_type intrinsic or otherwise use default pointer element
   // type.
   if (hasPointeeTypeAttr(Arg)) {
-    SPIRVType *ElementType =
-        GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder,
-                                 SPIRV::AccessQualifier::ReadWrite, true);
     return GR->getOrCreateSPIRVPointerType(
-        ElementType, MIRBuilder,
+        getPointeeTypeByAttr(Arg), MIRBuilder,
         addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
   }
 
@@ -259,10 +253,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
     MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
     Type *ElementTy =
         toTypedPointer(cast<ConstantAsMetadata>(VMD->getMetadata())->getType());
-    SPIRVType *ElementType = GR->getOrCreateSPIRVType(
-        ElementTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
     return GR->getOrCreateSPIRVPointerType(
-        ElementType, MIRBuilder,
+        ElementTy, MIRBuilder,
         addressSpaceToStorageClass(
             cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 60ec1c9f15a0c..5c0744ae128d6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -54,6 +54,40 @@ static unsigned typeToAddressSpace(const Type *Ty) {
   report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
 }
 
+static bool
+storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::Uniform:
+  case SPIRV::StorageClass::PushConstant:
+  case SPIRV::StorageClass::StorageBuffer:
+    // case SPIRV::StorageClass::PhysicalStorageBuffer:
+    return true;
+  case SPIRV::StorageClass::UniformConstant:
+  case SPIRV::StorageClass::Input:
+  case SPIRV::StorageClass::Output:
+  case SPIRV::StorageClass::Workgroup:
+  case SPIRV::StorageClass::CrossWorkgroup:
+  case SPIRV::StorageClass::Private:
+  case SPIRV::StorageClass::Function:
+  case SPIRV::StorageClass::Generic:
+  case SPIRV::StorageClass::AtomicCounter:
+  case SPIRV::StorageClass::Image:
+  case SPIRV::StorageClass::CallableDataNV:
+  case SPIRV::StorageClass::IncomingCallableDataNV:
+  case SPIRV::StorageClass::RayPayloadNV:
+  case SPIRV::StorageClass::HitAttributeNV:
+  case SPIRV::StorageClass::IncomingRayPayloadNV:
+  case SPIRV::StorageClass::ShaderRecordBufferNV:
+  case SPIRV::StorageClass::CodeSectionINTEL:
+  case SPIRV::StorageClass::DeviceOnlyINTEL:
+  case SPIRV::StorageClass::HostOnlyINTEL:
+    return false;
+  default:
+    llvm_unreachable("Unknown storage class");
+    return false;
+  }
+}
+
 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
     : PointerSize(PointerSize), Bound(0) {}
 
@@ -1080,6 +1114,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
   auto SC = addressSpaceToStorageClass(AddrSpace, *ST);
   // Null pointer means we have a loop in type definitions, make and
   // return corresponding OpTypeForwardPointer.
+  // TODO: How can this be null?
   if (SpvElementType == nullptr) {
     auto [It, Inserted] = ForwardPointerTypes.try_emplace(Ty);
     if (Inserted)
@@ -1342,7 +1377,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateVulkanBufferType(
                           SPIRV::Decoration::NonWritable, 0, {});
   }
 
-  SPIRVType *R = getOrCreateSPIRVPointerType(BlockType, MIRBuilder, SC);
+  SPIRVType *R = getOrCreateSPIRVPointerTypeInternal(BlockType, MIRBuilder, SC);
   add(Key, R);
   return R;
 }
@@ -1524,7 +1559,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
 
   // Handle "type*" or  "type* vector[N]".
   if (TypeStr.starts_with("*")) {
-    SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
+    SpirvTy = getOrCreateSPIRVPointerType(Ty, MIRBuilder, SC);
     TypeStr = TypeStr.substr(strlen("*"));
   }
 
@@ -1693,6 +1728,43 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
+    Type *BaseType, MachineInstr &I, SPIRV::StorageClass::StorageClass SC) {
+  MachineIRBuilder MIRBuilder(I);
+  return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
+    Type *BaseType, MachineIRBuilder &MIRBuilder,
+    SPIRV::StorageClass::StorageClass SC) {
+  SPIRVType *SpirvBaseType = getOrCreateSPIRVType(
+      BaseType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
+  return getOrCreateSPIRVPointerTypeInternal(SpirvBaseType, MIRBuilder, SC);
+}
+
+SPIRVType *SPIRVGlobalRegistry::changePointerStorageClass(
+    SPIRVType *PtrType, SPIRV::StorageClass::StorageClass SC, MachineInstr &I) {
+  SPIRV::StorageClass::StorageClass OldSC = getPointerStorageClass(PtrType);
+  assert(storageClassRequiresExplictLayout(OldSC) ==
+         storageClassRequiresExplictLayout(SC));
+
+  SPIRVType *PointeeType = getPointeeType(PtrType);
+  MachineIRBuilder MIRBuilder(I);
+  return getOrCreateSPIRVPointerTypeInternal(PointeeType, MIRBuilder, SC);
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
+    SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
+    SPIRV::StorageClass::StorageClass SC) {
+  Type *LLVMType = const_cast<Type *>(getTypeForSPIRVType(BaseType));
+  assert(!storageClassRequiresExplictLayout(SC));
+  SPIRVType *R = getOrCreateSPIRVPointerType(LLVMType, MIRBuilder, SC);
+  assert(
+      getPointeeType(R) == BaseType &&
+      "The base type was not correctly laid out for the given storage class.");
+  return R;
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerTypeInternal(
     SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
     SPIRV::StorageClass::StorageClass SC) {
   const Type *PointerElementType = getTypeForSPIRVType(BaseType);
@@ -1714,14 +1786,6 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
   return finishCreatingSPIRVType(Ty, NewMI);
 }
 
-SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
-    SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
-    SPIRV::StorageClass::StorageClass SC) {
-  MachineInstr *DepMI = const_cast<MachineInstr *>(BaseType);
-  MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
-  return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
-}
-
 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
                                                SPIRVType *SpvType,
                                                const SPIRVInstrInfo &TII) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index c18f17d1f3d23..11fe7eaf8df69 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -466,6 +466,14 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
                                          Constant *CA, unsigned BitWidth,
                                          unsigned ElemCnt);
 
+  // Returns the SPIR-V pointer type with the given base type and storage
+  // class. It is the responsibility of the caller to make sure the
+  // decorations on the base type are valid for the given storage class; for
+  // example, that it has the correct offset and stride decorations.
+  SPIRVType *getOrCreateSPIRVPointerTypeInternal(
+      SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
+      SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
+
 public:
   Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
                             SPIRVType *SpvType, bool EmitIR,
@@ -540,13 +548,32 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
                                        unsigned NumElements, MachineInstr &I,
                                        const SPIRVInstrInfo &TII);
 
+  // Returns the SPIR-V pointer type with the given base type and storage
+  // class. The base type will be translated to a SPIR-V type, and the
+  // appropriate layout decorations will be added to the base type.
   SPIRVType *getOrCreateSPIRVPointerType(
-      SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
+      Type *BaseType, MachineIRBuilder &MIRBuilder,
       SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
   SPIRVType *getOrCreateSPIRVPointerType(
-      SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
+      Type *BaseType, MachineInstr &I,
       SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
 
+  // Returns the SPIR-V pointer type with the given base type and storage
+  // class. It is the responsibility of the caller to make sure the
+  // decorations on the base type are valid for the given storage class; for
+  // example, that it has the correct offset and stride decorations.
+  SPIRVType *getOrCreateSPIRVPointerType(
+      SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
+      SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
+
+  // Returns the SPIR-V pointer type that is the same as `PtrType`
+  // except the storage class has been changed to `SC`. It is the responsibility
+  // of the caller to be sure that the original and new storage class have the
+  // same layout requirements.
+  SPIRVType *changePointerStorageClass(SPIRVType *PtrType,
+                                       SPIRV::StorageClass::StorageClass SC,
+                                       MachineInstr &I);
+
   SPIRVType *getOrCreateVulkanBufferType(MachineIRBuilder &MIRBuilder,
                                          Type *ElemType,
                                          SPIRV::StorageClass::StorageClass SC,
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index c347dde89256f..d274839af82eb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -214,10 +214,8 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
           PtrType->getOperand(1).getImm());
   MachineIRBuilder MIB(I);
   LLVMContext &Context = MF->getFunction().getContext();
-  SPIRVType *ElemType =
-      GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB,
-                              SPIRV::AccessQualifier::ReadWrite, false);
-  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
+  SPIRVType *NewPtrType =
+      GR.getOrCreateSPIRVPointerType(IntegerType::getInt8Ty(Context), MIB, SC);
   doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 946a295c2df25..c41387559982c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1259,14 +1259,18 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
   Register SrcReg = I.getOperand(1).getReg();
   bool Result = true;
   if (I.getOpcode() == TargetOpcode::G_MEMSET) {
+    MachineIRBuilder MIRBuilder(I);
     assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
     unsigned Val = getIConstVal(I.getOperand(1).getReg(), MRI);
     unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI);
-    SPIRVType *ValTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
-    SPIRVType *ArrTy = GR.getOrCreateSPIRVArrayType(ValTy, Num, I, TII);
-    Register Const = GR.getOrCreateConstIntArray(Val, Num, I, ArrTy, TII);
+    Type *ValTy = Type::getInt8Ty(I.getMF()->getFunction().getContext());
+    Type *ArrTy = ArrayType::get(ValTy, Num);
     SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
-        ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
+        ArrTy, MIRBuilder, SPIRV::StorageClass::UniformConstant);
+
+    SPIRVType *SpvArrTy = GR.getOrCreateSPIRVType(
+        ArrTy, MIRBuilder, SPIRV::AccessQualifier::None, false);
+    Register Const = GR.getOrCreateConstIntArray(Val, Num, I, SpvArrTy, TII);
     // TODO: check if we have such GV, add init, use buildGlobalVariable.
     Function &CurFunction = GR.CurMF->getFunction();
     Type *LLVMArrTy =
@@ -1289,7 +1293,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
 
     buildOpDecorate(VarReg, I, TII, SPIRV::Decoration::Constant, {});
     SPIRVType *SourceTy = GR.getOrCreateSPIRVPointerType(
-        ValTy, I, TII, SPIRV::StorageClass::UniformConstant);
+        ValTy, I, SPIRV::StorageClass::UniformConstant);
     SrcReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
     selectOpWithSrcs(SrcReg, SourceTy, I, {VarReg}, SPIRV::OpBitcast);
   }
@@ -1590,7 +1594,7 @@ static bool isASCastInGVar(MachineRegisterInfo *MRI, Register ResVReg) {
 Register SPIRVInstructionSelector::getUcharPtrTypeReg(
     MachineInstr &I, SPIRV::StorageClass::StorageClass SC) const {
   return GR.getSPIRVTypeID(GR.getOrCreateSPIRVPointerType(
-      GR.getOrCreateSPIRVIntegerType(8, I, TII), I, TII, SC));
+      Type::getInt8Ty(I.getMF()->getFunction().getContext()), I, SC));
 }
 
 MachineInstrBuilder
@@ -1608,8 +1612,8 @@ SPIRVInstructionSelector::buildSpecConstantOp(MachineInstr &I, Register Dest,
 MachineInstrBuilder
 SPIRVInstructionSelector::buildConstGenericPtr(MachineInstr &I, Register SrcPtr,
                                                SPIRVType *SrcPtrTy) const {
-  SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
-      GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
+  SPIRVType *GenericPtrTy =
+      GR.changePointerStorageClass(SrcPtrTy, SPIRV::StorageClass::Generic, I);
   Register Tmp = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
   MRI->setType(Tmp, LLT::pointer(storageClassToAddressSpace(
                                      SPIRV::StorageClass::Generic),
@@ -1694,8 +1698,8 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
     return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
   // Casting between 2 eligible pointers using Generic as an intermediary.
   if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
-    SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
-        GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
+    SPIRVType *GenericPtrTy =
+        GR.changePointerStorageClass(SrcPtrTy, SPIRV::StorageClass::Generic, I);
     Register Tmp = createVirtualRegister(GenericPtrTy, &GR, MRI, MRI->getMF());
     bool Result = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
                       .addDef(Tmp)
@@ -3366,18 +3370,20 @@ bool SPIRVInstructionSelector::selectImageWriteIntrinsic(
 }
 
 Register SPIRVInstructionSelector::buildPointerToResource(
-    const SPIRVType *ResType, SPIRV::StorageClass::StorageClass SC,
+    const SPIRVType *SpirvResType, SPIRV::StorageClass::StorageClass SC,
     uint32_t Set, uint32_t Binding, uint32_t ArraySize, Register IndexReg,
     bool IsNonUniform, MachineIRBuilder MIRBuilder) const {
+  Type *ResType = const_cast<Type *>(GR.getTypeForSPIRVType(SpirvResType));
   if (ArraySize == 1) {
-    SPIRVType *PtrType =
-        GR.getOrCreateSPIRVPointerType(ResType, MIRBuilder, SC);
+    SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
+        const_cast<Type *>(ResType), MIRBuilder, SC);
+    assert(GR.getPointeeType(PtrType) == SpirvResType &&
+           "SpirvResType did not have an explicit layout.");
     return GR.getOrCreateGlobalVariableWithBinding(PtrType, Set, Binding,
                                                    MIRBuilder);
   }
 
-  const SPIRVType *VarType = GR.getOrCreateSPIRVArrayType(
-      ResType, ArraySize, *MIRBuilder.getInsertPt(), TII);
+  Type *VarType = ArrayType::get(ResType, ArraySize);
   SPIRVType *VarPointerType =
       GR.getOrCreateSPIRVPointerType(VarType, MIRBuilder, SC);
   Register VarReg = GR.getOrCreateGlobalVariableWithBinding(
@@ -3807,17 +3813,6 @@ bool SPIRVInstructionSelector::selectGlobalValue(
   MachineIRBuilder MIRBuilder(I);
   const GlobalValue *GV = I.getOperand(1).getGlobal();
   Type *GVType = toTypedPointer(GR.getDeducedGlobalValueType(GV));
-  SPIRVType *PointerBaseType;
-  if (GVType->isArrayTy()) {
-    SPIRVType *ArrayElementType =
-        GR.getOrCreateSPIRVType(GVType->getArrayElementType(), MIRBuilder,
-                                SPIRV::AccessQualifier::ReadWrite, false);
-    PointerBaseType = GR.getOrCreateSPIRVArrayType(
-        ArrayElementType, GVType->getArrayNumElements(), I, TII);
-  } else {
-    PointerBaseType = GR.getOrCreateSPIRVType(
-        GVType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false);
-  }
 
   std::string GlobalIdent;
   if (!GV->hasName()) {
@@ -3850,7 +3845,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
               ? dyn_cast<Function>(GV)
               : nullptr;
       SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(
-          PointerBaseType, I, TII,
+          GVType, I,
           GVFun ? SPIRV::StorageClass::CodeSectionINTEL
                 : addressSpaceToStorageClass(GV->getAddressSpace(), STI));
       if (GVFun) {
@@ -3908,8 +3903,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
   const unsigned AddrSpace = GV->getAddressSpace();
   SPIRV::StorageClass::StorageClass StorageClass =
       addressSpaceToStorageClass(AddrSpace, STI);
-  SPIRVType *ResType =
-      GR.getOrCreateSPIRVPointerType(PointerBaseType, I, TII, StorageClass);
+  SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(GVType, I, StorageClass);
   Register Reg = GR.buildGlobalVariable(
       ResVReg, ResType, GlobalIdent, GV, StorageClass, Init,
       GlobalVar->isConstant(), HasLnkTy, LnkType, MIRBuilder, true);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index e4cc03eff1035..3fcff3dd8f553 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -251,10 +251,8 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
       Register Def = MI.getOperand(0).getReg();
       Register Source = MI.getOperand(2).getReg();
       Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
-      SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
-          ElemTy, MIB, SPIRV::AccessQualifier::ReadWrite, true);
       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
-          BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
+          ElemTy, MI,
           addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
 
       // If the ptrcast would be redundant, replace all uses with the source
@@ -366,9 +364,8 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
                 RegType.getAddressSpace()) {
           const SPIRVSubtarget &ST =
               MI->getParent()->getParent()->getSubtarget<SPIRVSubtarget>();
-          SpvType = GR->getOrCreateSPIRVPointerType(
-              GR->getPointeeType(SpvType), *MI, *ST.getInstrInfo(),
-              addressSpaceToStorageClass(RegType.getAddressSpace(), ST));
+          auto TSC = addressSpaceToStorageClass(RegType.getAddressSpace(), ST);
+          SpvType = GR->changePointerStorageClass(SpvType, TSC, *MI);
         }
         GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
       }
@@ -518,10 +515,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
         Register Reg = MI.getOperand(1).getReg();
         MIB.setInsertPt(*MI.getParent(), MI.getIterator());
         Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
-        SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
-            ElementTy, MIB, SPIRV::AccessQualifier::ReadWrite, true);
         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
-            BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
+            ElementTy, MI,
             addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
         MachineInstr *Def = MRI.getVRegDef(Reg);
         assert(Def && "Expecting an instruction that defines the register");



More information about the llvm-commits mailing list