[llvm] [SPIRV][NFC] Refactor pointer creation in GlobalRegistry (PR #134429)
Steven Perron via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 4 11:26:40 PDT 2025
https://github.com/s-perron created https://github.com/llvm/llvm-project/pull/134429
This PR adds new interfaces for creating pointer types and adds
some requirements to the old interfaces. This is the first step toward
https://github.com/llvm/llvm-project/issues/134119.
>From 87e342c8cdf8e94a86cfd5442013270a58ee4a58 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 25 Mar 2025 13:09:23 -0400
Subject: [PATCH] [SPIRV][NFC] Refactor pointer creation in GlobalRegistry
This PR adds new interfaces for creating pointer types and adds
some requirements to the old interfaces. This is the first step toward
https://github.com/llvm/llvm-project/issues/134119.
---
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 14 +---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 84 ++++++++++++++++---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 31 ++++++-
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 6 +-
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 52 +++++-------
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 13 +--
6 files changed, 135 insertions(+), 65 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index d55631e0146cf..5ec8c22dbf473 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -215,11 +215,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
Argument *Arg = F.getArg(ArgIdx);
Type *ArgType = Arg->getType();
if (isTypedPointerTy(ArgType)) {
- SPIRVType *ElementType = GR->getOrCreateSPIRVType(
- cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder,
- SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
- ElementType, MIRBuilder,
+ cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder,
addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
}
@@ -232,11 +229,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
// type.
if (hasPointeeTypeAttr(Arg)) {
- SPIRVType *ElementType =
- GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder,
- SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
- ElementType, MIRBuilder,
+ getPointeeTypeByAttr(Arg), MIRBuilder,
addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
}
@@ -259,10 +253,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
Type *ElementTy =
toTypedPointer(cast<ConstantAsMetadata>(VMD->getMetadata())->getType());
- SPIRVType *ElementType = GR->getOrCreateSPIRVType(
- ElementTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
- ElementType, MIRBuilder,
+ ElementTy, MIRBuilder,
addressSpaceToStorageClass(
cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 60ec1c9f15a0c..5c0744ae128d6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -54,6 +54,40 @@ static unsigned typeToAddressSpace(const Type *Ty) {
report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
}
+static bool
+storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC) {
+ switch (SC) {
+ case SPIRV::StorageClass::Uniform:
+ case SPIRV::StorageClass::PushConstant:
+ case SPIRV::StorageClass::StorageBuffer:
+ // case SPIRV::StorageClass::PhysicalStorageBuffer:
+ return true;
+ case SPIRV::StorageClass::UniformConstant:
+ case SPIRV::StorageClass::Input:
+ case SPIRV::StorageClass::Output:
+ case SPIRV::StorageClass::Workgroup:
+ case SPIRV::StorageClass::CrossWorkgroup:
+ case SPIRV::StorageClass::Private:
+ case SPIRV::StorageClass::Function:
+ case SPIRV::StorageClass::Generic:
+ case SPIRV::StorageClass::AtomicCounter:
+ case SPIRV::StorageClass::Image:
+ case SPIRV::StorageClass::CallableDataNV:
+ case SPIRV::StorageClass::IncomingCallableDataNV:
+ case SPIRV::StorageClass::RayPayloadNV:
+ case SPIRV::StorageClass::HitAttributeNV:
+ case SPIRV::StorageClass::IncomingRayPayloadNV:
+ case SPIRV::StorageClass::ShaderRecordBufferNV:
+ case SPIRV::StorageClass::CodeSectionINTEL:
+ case SPIRV::StorageClass::DeviceOnlyINTEL:
+ case SPIRV::StorageClass::HostOnlyINTEL:
+ return false;
+ default:
+ llvm_unreachable("Unknown storage class");
+ return false;
+ }
+}
+
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
: PointerSize(PointerSize), Bound(0) {}
@@ -1080,6 +1114,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
auto SC = addressSpaceToStorageClass(AddrSpace, *ST);
// Null pointer means we have a loop in type definitions, make and
// return corresponding OpTypeForwardPointer.
+ // TODO: How can this be null?
if (SpvElementType == nullptr) {
auto [It, Inserted] = ForwardPointerTypes.try_emplace(Ty);
if (Inserted)
@@ -1342,7 +1377,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateVulkanBufferType(
SPIRV::Decoration::NonWritable, 0, {});
}
- SPIRVType *R = getOrCreateSPIRVPointerType(BlockType, MIRBuilder, SC);
+ SPIRVType *R = getOrCreateSPIRVPointerTypeInternal(BlockType, MIRBuilder, SC);
add(Key, R);
return R;
}
@@ -1524,7 +1559,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
// Handle "type*" or "type* vector[N]".
if (TypeStr.starts_with("*")) {
- SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
+ SpirvTy = getOrCreateSPIRVPointerType(Ty, MIRBuilder, SC);
TypeStr = TypeStr.substr(strlen("*"));
}
@@ -1693,6 +1728,43 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
+ Type *BaseType, MachineInstr &I, SPIRV::StorageClass::StorageClass SC) {
+ MachineIRBuilder MIRBuilder(I);
+ return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
+ Type *BaseType, MachineIRBuilder &MIRBuilder,
+ SPIRV::StorageClass::StorageClass SC) {
+ SPIRVType *SpirvBaseType = getOrCreateSPIRVType(
+ BaseType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
+ return getOrCreateSPIRVPointerTypeInternal(SpirvBaseType, MIRBuilder, SC);
+}
+
+SPIRVType *SPIRVGlobalRegistry::changePointerStorageClass(
+ SPIRVType *PtrType, SPIRV::StorageClass::StorageClass SC, MachineInstr &I) {
+ SPIRV::StorageClass::StorageClass OldSC = getPointerStorageClass(PtrType);
+ assert(storageClassRequiresExplictLayout(OldSC) ==
+ storageClassRequiresExplictLayout(SC));
+
+ SPIRVType *PointeeType = getPointeeType(PtrType);
+ MachineIRBuilder MIRBuilder(I);
+ return getOrCreateSPIRVPointerTypeInternal(PointeeType, MIRBuilder, SC);
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
+ SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
+ SPIRV::StorageClass::StorageClass SC) {
+ Type *LLVMType = const_cast<Type *>(getTypeForSPIRVType(BaseType));
+ assert(!storageClassRequiresExplictLayout(SC));
+ SPIRVType *R = getOrCreateSPIRVPointerType(LLVMType, MIRBuilder, SC);
+ assert(
+ getPointeeType(R) == BaseType &&
+ "The base type was not correctly laid out for the given storage class.");
+ return R;
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerTypeInternal(
SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC) {
const Type *PointerElementType = getTypeForSPIRVType(BaseType);
@@ -1714,14 +1786,6 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
return finishCreatingSPIRVType(Ty, NewMI);
}
-SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
- SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
- SPIRV::StorageClass::StorageClass SC) {
- MachineInstr *DepMI = const_cast<MachineInstr *>(BaseType);
- MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
- return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
-}
-
Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index c18f17d1f3d23..11fe7eaf8df69 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -466,6 +466,14 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
Constant *CA, unsigned BitWidth,
unsigned ElemCnt);
+ // Returns a pointer to a SPIR-V pointer type with the given base type and
+ // storage class. It is the responsibility of the caller to make sure the
+ // decorations on the base type are valid for the given storage class. For
+ // example, it has the correct offset and stride decorations.
+ SPIRVType *getOrCreateSPIRVPointerTypeInternal(
+ SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
+ SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
+
public:
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
@@ -540,13 +548,32 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
unsigned NumElements, MachineInstr &I,
const SPIRVInstrInfo &TII);
+ // Returns a pointer to a SPIR-V pointer type with the given base type and
+ // storage class. The base type will be translated to a SPIR-V type, and the
+ // appropriate layout decorations will be added to the base type.
SPIRVType *getOrCreateSPIRVPointerType(
- SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
+ Type *BaseType, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
SPIRVType *getOrCreateSPIRVPointerType(
- SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
+ Type *BaseType, MachineInstr &I,
SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
+ // Returns a pointer to a SPIR-V pointer type with the given base type and
+ // storage class. It is the responsibility of the caller to make sure the
+ // decorations on the base type are valid for the given storage class. For
+ // example, it has the correct offset and stride decorations.
+ SPIRVType *getOrCreateSPIRVPointerType(
+ SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
+ SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
+
+ // Returns a pointer to a SPIR-V pointer type that is the same as `PtrType`
+ // except the storage class has been changed to `SC`. It is the responsibility
+ // of the caller to be sure that the original and new storage class have the
+ // same layout requirements.
+ SPIRVType *changePointerStorageClass(SPIRVType *PtrType,
+ SPIRV::StorageClass::StorageClass SC,
+ MachineInstr &I);
+
SPIRVType *getOrCreateVulkanBufferType(MachineIRBuilder &MIRBuilder,
Type *ElemType,
SPIRV::StorageClass::StorageClass SC,
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index c347dde89256f..d274839af82eb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -214,10 +214,8 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
PtrType->getOperand(1).getImm());
MachineIRBuilder MIB(I);
LLVMContext &Context = MF->getFunction().getContext();
- SPIRVType *ElemType =
- GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB,
- SPIRV::AccessQualifier::ReadWrite, false);
- SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
+ SPIRVType *NewPtrType =
+ GR.getOrCreateSPIRVPointerType(IntegerType::getInt8Ty(Context), MIB, SC);
doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 946a295c2df25..c41387559982c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1259,14 +1259,18 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
Register SrcReg = I.getOperand(1).getReg();
bool Result = true;
if (I.getOpcode() == TargetOpcode::G_MEMSET) {
+ MachineIRBuilder MIRBuilder(I);
assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
unsigned Val = getIConstVal(I.getOperand(1).getReg(), MRI);
unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI);
- SPIRVType *ValTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
- SPIRVType *ArrTy = GR.getOrCreateSPIRVArrayType(ValTy, Num, I, TII);
- Register Const = GR.getOrCreateConstIntArray(Val, Num, I, ArrTy, TII);
+ Type *ValTy = Type::getInt8Ty(I.getMF()->getFunction().getContext());
+ Type *ArrTy = ArrayType::get(ValTy, Num);
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
- ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
+ ArrTy, MIRBuilder, SPIRV::StorageClass::UniformConstant);
+
+ SPIRVType *SpvArrTy = GR.getOrCreateSPIRVType(
+ ArrTy, MIRBuilder, SPIRV::AccessQualifier::None, false);
+ Register Const = GR.getOrCreateConstIntArray(Val, Num, I, SpvArrTy, TII);
// TODO: check if we have such GV, add init, use buildGlobalVariable.
Function &CurFunction = GR.CurMF->getFunction();
Type *LLVMArrTy =
@@ -1289,7 +1293,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
buildOpDecorate(VarReg, I, TII, SPIRV::Decoration::Constant, {});
SPIRVType *SourceTy = GR.getOrCreateSPIRVPointerType(
- ValTy, I, TII, SPIRV::StorageClass::UniformConstant);
+ ValTy, I, SPIRV::StorageClass::UniformConstant);
SrcReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
selectOpWithSrcs(SrcReg, SourceTy, I, {VarReg}, SPIRV::OpBitcast);
}
@@ -1590,7 +1594,7 @@ static bool isASCastInGVar(MachineRegisterInfo *MRI, Register ResVReg) {
Register SPIRVInstructionSelector::getUcharPtrTypeReg(
MachineInstr &I, SPIRV::StorageClass::StorageClass SC) const {
return GR.getSPIRVTypeID(GR.getOrCreateSPIRVPointerType(
- GR.getOrCreateSPIRVIntegerType(8, I, TII), I, TII, SC));
+ Type::getInt8Ty(I.getMF()->getFunction().getContext()), I, SC));
}
MachineInstrBuilder
@@ -1608,8 +1612,8 @@ SPIRVInstructionSelector::buildSpecConstantOp(MachineInstr &I, Register Dest,
MachineInstrBuilder
SPIRVInstructionSelector::buildConstGenericPtr(MachineInstr &I, Register SrcPtr,
SPIRVType *SrcPtrTy) const {
- SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
- GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
+ SPIRVType *GenericPtrTy =
+ GR.changePointerStorageClass(SrcPtrTy, SPIRV::StorageClass::Generic, I);
Register Tmp = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
MRI->setType(Tmp, LLT::pointer(storageClassToAddressSpace(
SPIRV::StorageClass::Generic),
@@ -1694,8 +1698,8 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
// Casting between 2 eligible pointers using Generic as an intermediary.
if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
- SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
- GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
+ SPIRVType *GenericPtrTy =
+ GR.changePointerStorageClass(SrcPtrTy, SPIRV::StorageClass::Generic, I);
Register Tmp = createVirtualRegister(GenericPtrTy, &GR, MRI, MRI->getMF());
bool Result = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
.addDef(Tmp)
@@ -3366,18 +3370,20 @@ bool SPIRVInstructionSelector::selectImageWriteIntrinsic(
}
Register SPIRVInstructionSelector::buildPointerToResource(
- const SPIRVType *ResType, SPIRV::StorageClass::StorageClass SC,
+ const SPIRVType *SpirvResType, SPIRV::StorageClass::StorageClass SC,
uint32_t Set, uint32_t Binding, uint32_t ArraySize, Register IndexReg,
bool IsNonUniform, MachineIRBuilder MIRBuilder) const {
+ Type *ResType = const_cast<Type *>(GR.getTypeForSPIRVType(SpirvResType));
if (ArraySize == 1) {
- SPIRVType *PtrType =
- GR.getOrCreateSPIRVPointerType(ResType, MIRBuilder, SC);
+ SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
+ const_cast<Type *>(ResType), MIRBuilder, SC);
+ assert(GR.getPointeeType(PtrType) == SpirvResType &&
+ "SpirvResType did not have an explicit layout.");
return GR.getOrCreateGlobalVariableWithBinding(PtrType, Set, Binding,
MIRBuilder);
}
- const SPIRVType *VarType = GR.getOrCreateSPIRVArrayType(
- ResType, ArraySize, *MIRBuilder.getInsertPt(), TII);
+ Type *VarType = ArrayType::get(ResType, ArraySize);
SPIRVType *VarPointerType =
GR.getOrCreateSPIRVPointerType(VarType, MIRBuilder, SC);
Register VarReg = GR.getOrCreateGlobalVariableWithBinding(
@@ -3807,17 +3813,6 @@ bool SPIRVInstructionSelector::selectGlobalValue(
MachineIRBuilder MIRBuilder(I);
const GlobalValue *GV = I.getOperand(1).getGlobal();
Type *GVType = toTypedPointer(GR.getDeducedGlobalValueType(GV));
- SPIRVType *PointerBaseType;
- if (GVType->isArrayTy()) {
- SPIRVType *ArrayElementType =
- GR.getOrCreateSPIRVType(GVType->getArrayElementType(), MIRBuilder,
- SPIRV::AccessQualifier::ReadWrite, false);
- PointerBaseType = GR.getOrCreateSPIRVArrayType(
- ArrayElementType, GVType->getArrayNumElements(), I, TII);
- } else {
- PointerBaseType = GR.getOrCreateSPIRVType(
- GVType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false);
- }
std::string GlobalIdent;
if (!GV->hasName()) {
@@ -3850,7 +3845,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
? dyn_cast<Function>(GV)
: nullptr;
SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(
- PointerBaseType, I, TII,
+ GVType, I,
GVFun ? SPIRV::StorageClass::CodeSectionINTEL
: addressSpaceToStorageClass(GV->getAddressSpace(), STI));
if (GVFun) {
@@ -3908,8 +3903,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
const unsigned AddrSpace = GV->getAddressSpace();
SPIRV::StorageClass::StorageClass StorageClass =
addressSpaceToStorageClass(AddrSpace, STI);
- SPIRVType *ResType =
- GR.getOrCreateSPIRVPointerType(PointerBaseType, I, TII, StorageClass);
+ SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(GVType, I, StorageClass);
Register Reg = GR.buildGlobalVariable(
ResVReg, ResType, GlobalIdent, GV, StorageClass, Init,
GlobalVar->isConstant(), HasLnkTy, LnkType, MIRBuilder, true);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index e4cc03eff1035..3fcff3dd8f553 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -251,10 +251,8 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Register Def = MI.getOperand(0).getReg();
Register Source = MI.getOperand(2).getReg();
Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
- SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
- ElemTy, MIB, SPIRV::AccessQualifier::ReadWrite, true);
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
- BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
+ ElemTy, MI,
addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
// If the ptrcast would be redundant, replace all uses with the source
@@ -366,9 +364,8 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
RegType.getAddressSpace()) {
const SPIRVSubtarget &ST =
MI->getParent()->getParent()->getSubtarget<SPIRVSubtarget>();
- SpvType = GR->getOrCreateSPIRVPointerType(
- GR->getPointeeType(SpvType), *MI, *ST.getInstrInfo(),
- addressSpaceToStorageClass(RegType.getAddressSpace(), ST));
+ auto TSC = addressSpaceToStorageClass(RegType.getAddressSpace(), ST);
+ SpvType = GR->changePointerStorageClass(SpvType, TSC, *MI);
}
GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
}
@@ -518,10 +515,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Register Reg = MI.getOperand(1).getReg();
MIB.setInsertPt(*MI.getParent(), MI.getIterator());
Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
- SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
- ElementTy, MIB, SPIRV::AccessQualifier::ReadWrite, true);
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
- BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
+ ElementTy, MI,
addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
MachineInstr *Def = MRI.getVRegDef(Reg);
assert(Def && "Expecting an instruction that defines the register");
More information about the llvm-commits
mailing list