[llvm] [SPIR-V] Fix generation of gMIR vs. SPIR-V code from utility methods (PR #128159)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 24 02:36:29 PST 2025
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/128159
>From d579bb2192a9412693f19492d6d692d3083232bb Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 20 Feb 2025 13:42:20 -0800
Subject: [PATCH 1/3] Fix generation of gMIR vs. SPIR-V code from utility methods
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 16 ++++++----
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 3 +-
.../Target/SPIRV/SPIRVEmitNonSemanticDI.cpp | 6 ++--
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 32 +++++++++++--------
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 22 ++++++-------
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 5 +--
6 files changed, 46 insertions(+), 38 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 7b897f7e34c6f..c65166902550c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -476,8 +476,9 @@ static bool buildSelectInst(MachineIRBuilder &MIRBuilder,
if (ReturnType->getOpcode() == SPIRV::OpTypeVector) {
unsigned Bits = GR->getScalarOrVectorBitWidth(ReturnType);
uint64_t AllOnes = APInt::getAllOnes(Bits).getZExtValue();
- TrueConst = GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType);
- FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType);
+ TrueConst =
+ GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType, true);
+ FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType, true);
} else {
TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType);
FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType);
@@ -1457,7 +1458,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
ToTruncate = DefaultReg;
}
auto NewRegister =
- GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
+ GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);
MIRBuilder.buildCopy(DefaultReg, NewRegister);
} else { // If it could be in range, we need to load from the given builtin.
auto Vec3Ty =
@@ -1492,13 +1493,14 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());
// Use G_ICMP to check if idxVReg < 3.
- MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
- GR->buildConstantInt(3, MIRBuilder, IndexType));
+ MIRBuilder.buildICmp(
+ CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
+ GR->buildConstantInt(3, MIRBuilder, IndexType, true));
// Get constant for the default value (0 or 1 depending on which
// function).
Register DefaultRegister =
- GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
+ GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);
// Get a register for the selection result (possibly a new temporary one).
Register SelectionResult = Call->ReturnRegister;
@@ -2277,7 +2279,7 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
SpvFieldTy, *ST.getInstrInfo());
} else {
- Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
+ Const = GR->buildConstantInt(0, MIRBuilder, SpvTy, true);
}
if (!LocalWorkSize.isValid())
LocalWorkSize = Const;
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 78f6b188c45c1..e47dfddd55975 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -669,7 +669,8 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// Make sure there's a valid return reg, even for functions returning void.
if (!ResVReg.isValid())
ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
- SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
+ SPIRVType *RetType = GR->assignTypeToVReg(
+ OrigRetTy, ResVReg, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
// Emit the call instruction and its args.
auto MIB = MIRBuilder.buildInstr(CallOp)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
index b98cef0a4f07f..ee98af5cffe4c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
@@ -193,7 +193,8 @@ bool SPIRVEmitNonSemanticDI::emitGlobalDI(MachineFunction &MF) {
};
const SPIRVType *VoidTy =
- GR->getOrCreateSPIRVType(Type::getVoidTy(*Context), MIRBuilder);
+ GR->getOrCreateSPIRVType(Type::getVoidTy(*Context), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, false);
const auto EmitDIInstruction =
[&](SPIRV::NonSemanticExtInst::NonSemanticExtInst Inst,
@@ -217,7 +218,8 @@ bool SPIRVEmitNonSemanticDI::emitGlobalDI(MachineFunction &MF) {
};
const SPIRVType *I32Ty =
- GR->getOrCreateSPIRVType(Type::getInt32Ty(*Context), MIRBuilder);
+ GR->getOrCreateSPIRVType(Type::getInt32Ty(*Context), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, false);
const Register DwarfVersionReg =
GR->buildConstantInt(DwarfVersion, MIRBuilder, I32Ty, false);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index e2f1b211caa5c..a09474a21534e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -244,7 +244,8 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
if (MIRBuilder)
- assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
+ assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
else
assignIntTypeToVReg(BitWidth, Res, *I, *TII);
DT.add(CI, CurMF, Res);
@@ -271,7 +272,8 @@ SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
if (MIRBuilder)
- assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
+ assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
else
assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
DT.add(CI, CurMF, Res);
@@ -878,12 +880,13 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
});
}
-SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
- MachineIRBuilder &MIRBuilder,
- bool EmitIR) {
+SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(
+ const StructType *Ty, MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
SmallVector<Register, 4> FieldTypes;
for (const auto &Elem : Ty->elements()) {
- SPIRVType *ElemTy = findSPIRVType(toTypedPointer(Elem), MIRBuilder);
+ SPIRVType *ElemTy =
+ findSPIRVType(toTypedPointer(Elem), MIRBuilder, AccQual, EmitIR);
assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
"Invalid struct element type");
FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -1017,26 +1020,27 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
- SPIRVType *El =
- findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
+ SPIRVType *El = findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(),
+ MIRBuilder, AccQual, EmitIR);
return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
MIRBuilder);
}
if (Ty->isArrayTy()) {
- SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
+ SPIRVType *El =
+ findSPIRVType(Ty->getArrayElementType(), MIRBuilder, AccQual, EmitIR);
return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
}
if (auto SType = dyn_cast<StructType>(Ty)) {
if (SType->isOpaque())
return getOpTypeOpaque(SType, MIRBuilder);
- return getOpTypeStruct(SType, MIRBuilder, EmitIR);
+ return getOpTypeStruct(SType, MIRBuilder, AccQual, EmitIR);
}
if (auto FType = dyn_cast<FunctionType>(Ty)) {
- SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
+ SPIRVType *RetTy =
+ findSPIRVType(FType->getReturnType(), MIRBuilder, AccQual, EmitIR);
SmallVector<SPIRVType *, 4> ParamTypes;
- for (const auto &t : FType->params()) {
- ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
- }
+ for (const auto &ParamTy : FType->params())
+ ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, EmitIR));
return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 0c94ec4df97f5..5ad705f197d5a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -94,13 +94,11 @@ class SPIRVGlobalRegistry {
// Add a new OpTypeXXX instruction without checking for duplicates.
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier AQ =
- SPIRV::AccessQualifier::ReadWrite,
- bool EmitIR = true);
+ SPIRV::AccessQualifier::AccessQualifier AQ,
+ bool EmitIR);
SPIRVType *findSPIRVType(const Type *Ty, MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier accessQual =
- SPIRV::AccessQualifier::ReadWrite,
- bool EmitIR = true);
+ SPIRV::AccessQualifier::AccessQualifier accessQual,
+ bool EmitIR);
SPIRVType *
restOfCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual,
@@ -321,9 +319,8 @@ class SPIRVGlobalRegistry {
// and map it to the given VReg by creating an ASSIGN_TYPE instruction.
SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier AQ =
- SPIRV::AccessQualifier::ReadWrite,
- bool EmitIR = true);
+ SPIRV::AccessQualifier::AccessQualifier AQ,
+ bool EmitIR);
SPIRVType *assignIntTypeToVReg(unsigned BitWidth, Register VReg,
MachineInstr &I, const SPIRVInstrInfo &TII);
SPIRVType *assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
@@ -470,13 +467,14 @@ class SPIRVGlobalRegistry {
MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType,
- MachineIRBuilder &MIRBuilder, bool EmitIR = true);
+ MachineIRBuilder &MIRBuilder, bool EmitIR);
SPIRVType *getOpTypeOpaque(const StructType *Ty,
MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeStruct(const StructType *Ty, MachineIRBuilder &MIRBuilder,
- bool EmitIR = true);
+ SPIRV::AccessQualifier::AccessQualifier AccQual,
+ bool EmitIR);
SPIRVType *getOpTypePointer(SPIRV::StorageClass::StorageClass SC,
SPIRVType *ElemType, MachineIRBuilder &MIRBuilder,
@@ -539,7 +537,7 @@ class SPIRVGlobalRegistry {
SPIRVType *SpvType,
const SPIRVInstrInfo &TII);
Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder,
- SPIRVType *SpvType, bool EmitIR = true);
+ SPIRVType *SpvType, bool EmitIR);
Register getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType);
Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index d5b81bf46c804..f622be893919f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -192,7 +192,7 @@ static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
// Insert a bitcast before the instruction to keep SPIR-V code valid.
LLVMContext &Context = MF->getFunction().getContext();
SPIRVType *NewPtrType =
- createNewPtrType(GR, I, OpType, false, true, nullptr,
+ createNewPtrType(GR, I, OpType, false, false, nullptr,
TargetExtType::get(Context, "spirv.Event"));
doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}
@@ -216,7 +216,8 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
MachineIRBuilder MIB(I);
LLVMContext &Context = MF->getFunction().getContext();
SPIRVType *ElemType =
- GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB);
+ GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB,
+ SPIRV::AccessQualifier::ReadWrite, false);
SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
}
>From cfe50c14555d4ebdb24e7f07ebc934a8f2279870 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 21 Feb 2025 02:39:27 -0800
Subject: [PATCH 2/3] Add explicit EmitIR flag to utility functions
---
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 6 ++++--
llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 19 ++++++++++++-------
llvm/lib/Target/SPIRV/SPIRVUtils.h | 9 ++++++---
3 files changed, 22 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index e47dfddd55975..e4f144eeffbbc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -557,10 +557,12 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
RetTy =
TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
}
- setRegClassType(ResVReg, RetTy, GR, MIRBuilder);
+ setRegClassType(ResVReg, RetTy, GR, MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
}
} else {
- ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder);
+ ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
}
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index ddc66f98829a9..c55b735314228 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -738,9 +738,12 @@ void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
// no valid assigned class, set register LLT type and class according to the
// SPIR-V type.
void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIRBuilder, bool Force) {
- setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
- MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
+ MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier::AccessQualifier AccessQual,
+ bool EmitIR, bool Force) {
+ setRegClassType(Reg,
+ GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR),
+ GR, MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
}
// Create a virtual register and assign SPIR-V type to the register. Set
@@ -764,10 +767,12 @@ Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
// Create a SPIR-V type, virtual register and assign SPIR-V type to the
// register. Set register LLT type and class according to the SPIR-V type.
-Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIRBuilder) {
- return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
- MIRBuilder);
+Register createVirtualRegister(
+ const Type *Ty, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
+ return createVirtualRegister(
+ GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR), GR,
+ MIRBuilder);
}
// Return true if there is an opaque pointer type nested in the argument.
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 552adf2df7d17..870649879218a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -411,7 +411,9 @@ MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
bool getVacantFunctionName(Module &M, std::string &Name);
void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIRBuilder, bool Force = false);
+ MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier::AccessQualifier AccessQual,
+ bool EmitIR, bool Force = false);
void setRegClassType(Register Reg, const MachineInstr *SpvType,
SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI,
const MachineFunction &MF, bool Force = false);
@@ -422,8 +424,9 @@ Register createVirtualRegister(const MachineInstr *SpvType,
Register createVirtualRegister(const MachineInstr *SpvType,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIRBuilder);
-Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIRBuilder);
+Register createVirtualRegister(
+ const Type *Ty, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR);
// Return true if there is an opaque pointer type nested in the argument.
bool isNestedPointer(const Type *Ty);
>From f15671d718a779f54a54cd53a9c4d56609564ac7 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 24 Feb 2025 02:36:16 -0800
Subject: [PATCH 3/3] Remove reliance on default EmitIR arguments in call chains
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 51 ++++++++++---------
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 29 +++++++----
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 40 ++++++++-------
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 18 ++++---
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 21 ++++----
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 19 +++----
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 3 +-
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 34 ++++++++-----
8 files changed, 123 insertions(+), 92 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c65166902550c..907031388e3e7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -445,12 +445,12 @@ static std::tuple<Register, SPIRVType *>
buildBoolRegister(MachineIRBuilder &MIRBuilder, const SPIRVType *ResultType,
SPIRVGlobalRegistry *GR) {
LLT Type;
- SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
+ SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true);
if (ResultType->getOpcode() == SPIRV::OpTypeVector) {
unsigned VectorElements = ResultType->getOperand(2).getImm();
- BoolType =
- GR->getOrCreateSPIRVVectorType(BoolType, VectorElements, MIRBuilder);
+ BoolType = GR->getOrCreateSPIRVVectorType(BoolType, VectorElements,
+ MIRBuilder, true);
const FixedVectorType *LLVMVectorType =
cast<FixedVectorType>(GR->getTypeForSPIRVType(BoolType));
Type = LLT::vector(LLVMVectorType->getElementCount(), 1);
@@ -480,8 +480,8 @@ static bool buildSelectInst(MachineIRBuilder &MIRBuilder,
GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType, true);
FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType, true);
} else {
- TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType);
- FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType);
+ TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType, true);
+ FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType, true);
}
return MIRBuilder.buildSelect(ReturnRegister, SourceRegister, TrueConst,
@@ -581,8 +581,8 @@ static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
static Register buildConstantIntReg32(uint64_t Val,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
- return GR->buildConstantInt(Val, MIRBuilder,
- GR->getOrCreateSPIRVIntegerType(32, MIRBuilder));
+ return GR->buildConstantInt(
+ Val, MIRBuilder, GR->getOrCreateSPIRVIntegerType(32, MIRBuilder), true);
}
static Register buildScopeReg(Register CLScopeRegister,
@@ -1153,7 +1153,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
Register Arg0;
if (GroupBuiltin->HasBoolArg) {
- SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
+ SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true);
Register BoolReg = Call->Arguments[0];
SPIRVType *BoolRegType = GR->getSPIRVTypeForVReg(BoolReg);
if (!BoolRegType)
@@ -1162,14 +1162,15 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
if (ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT) {
if (BoolRegType->getOpcode() != SPIRV::OpTypeBool)
Arg0 = GR->buildConstantInt(getIConstVal(BoolReg, MRI), MIRBuilder,
- BoolType);
+ BoolType, true);
} else {
if (BoolRegType->getOpcode() == SPIRV::OpTypeInt) {
Arg0 = MRI->createGenericVirtualRegister(LLT::scalar(1));
MRI->setRegClass(Arg0, &SPIRV::iIDRegClass);
GR->assignSPIRVTypeToVReg(BoolType, Arg0, MIRBuilder.getMF());
- MIRBuilder.buildICmp(CmpInst::ICMP_NE, Arg0, BoolReg,
- GR->buildConstantInt(0, MIRBuilder, BoolRegType));
+ MIRBuilder.buildICmp(
+ CmpInst::ICMP_NE, Arg0, BoolReg,
+ GR->buildConstantInt(0, MIRBuilder, BoolRegType, true));
insertAssignInstr(Arg0, nullptr, BoolType, GR, MIRBuilder,
MIRBuilder.getMF().getRegInfo());
} else if (BoolRegType->getOpcode() != SPIRV::OpTypeBool) {
@@ -1214,7 +1215,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
LLT::fixed_vector(VecLen, MRI->getType(ElemReg)));
MRI->setRegClass(VecReg, &SPIRV::vIDRegClass);
SPIRVType *VecType =
- GR->getOrCreateSPIRVVectorType(ElemType, VecLen, MIRBuilder);
+ GR->getOrCreateSPIRVVectorType(ElemType, VecLen, MIRBuilder, true);
GR->assignSPIRVTypeToVReg(VecType, VecReg, MIRBuilder.getMF());
auto MIB =
MIRBuilder.buildInstr(TargetOpcode::G_BUILD_VECTOR).addDef(VecReg);
@@ -1462,7 +1463,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
MIRBuilder.buildCopy(DefaultReg, NewRegister);
} else { // If it could be in range, we need to load from the given builtin.
auto Vec3Ty =
- GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder);
+ GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder, true);
Register LoadedVector =
buildBuiltinVariableLoad(MIRBuilder, Vec3Ty, GR, BuiltinValue,
LLT::fixed_vector(3, PointerSize));
@@ -1485,7 +1486,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
*MRI);
auto IndexType = GR->getSPIRVTypeForVReg(IndexRegister);
- auto BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
+ auto BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder, true);
Register CompareRegister =
MRI->createGenericVirtualRegister(LLT::scalar(1));
@@ -1814,7 +1815,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
MIRBuilder.getMRI()->setRegClass(QueryResult, &SPIRV::vIDRegClass);
SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
QueryResultType = GR->getOrCreateSPIRVVectorType(
- IntTy, NumActualRetComponents, MIRBuilder);
+ IntTy, NumActualRetComponents, MIRBuilder, true);
GR->assignSPIRVTypeToVReg(QueryResultType, QueryResult, MIRBuilder.getMF());
}
bool IsDimBuf = ImgType->getOperand(2).getImm() == SPIRV::Dim::DIM_Buffer;
@@ -1971,7 +1972,7 @@ static bool generateReadImageInst(const StringRef DemangledCall,
if (Call->ReturnType->getOpcode() != SPIRV::OpTypeVector) {
SPIRVType *TempType =
- GR->getOrCreateSPIRVVectorType(Call->ReturnType, 4, MIRBuilder);
+ GR->getOrCreateSPIRVVectorType(Call->ReturnType, 4, MIRBuilder, true);
Register TempRegister =
MRI->createGenericVirtualRegister(GR->getRegType(TempType));
MRI->setRegClass(TempRegister, GR->getRegClass(TempType));
@@ -2069,7 +2070,7 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
SPIRVType *Type =
Call->ReturnType
? Call->ReturnType
- : GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
+ : GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder, true);
if (!Type) {
std::string DiagMsg =
"Unable to recognize SPIRV type name: " + ReturnType;
@@ -2267,7 +2268,8 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
Type *BaseTy = IntegerType::get(MF.getFunction().getContext(), BitWidth);
Type *FieldTy = ArrayType::get(BaseTy, Size);
- SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder);
+ SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(
+ FieldTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize, MF);
MIRBuilder.buildInstr(SPIRV::OpLoad)
@@ -2305,7 +2307,8 @@ getOrCreateSPIRVDeviceEventPointer(MachineIRBuilder &MIRBuilder,
LLVMContext &Context = MIRBuilder.getMF().getFunction().getContext();
unsigned SC1 = storageClassToAddressSpace(SPIRV::StorageClass::Generic);
Type *PtrType = PointerType::get(Context, SC1);
- return GR->getOrCreateSPIRVType(PtrType, MIRBuilder);
+ return GR->getOrCreateSPIRVType(PtrType, MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
}
static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
@@ -2454,7 +2457,7 @@ static bool generateAsyncCopy(const SPIRV::IncomingCall *Call,
SPIRVType *NewType =
Call->ReturnType->getOpcode() == SPIRV::OpTypeEvent
? nullptr
- : GR->getOrCreateSPIRVTypeByName("spirv.Event", MIRBuilder);
+ : GR->getOrCreateSPIRVTypeByName("spirv.Event", MIRBuilder, true);
Register TypeReg = GR->getSPIRVTypeID(NewType ? NewType : Call->ReturnType);
unsigned NumArgs = Call->Arguments.size();
Register EventReg = Call->Arguments[NumArgs - 1];
@@ -2955,12 +2958,13 @@ static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
assert(ExtensionType->getNumTypeParameters() == 1 &&
"SPIR-V coop matrices builtin type must have a type parameter!");
const SPIRVType *ElemType =
- GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
+ GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
// Create or get an existing type from GlobalRegistry.
return GR->getOrCreateOpTypeCoopMatr(
MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0),
ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
- ExtensionType->getIntParameter(3));
+ ExtensionType->getIntParameter(3), true);
}
static SPIRVType *
@@ -2970,7 +2974,8 @@ getImageType(const TargetExtType *ExtensionType,
assert(ExtensionType->getNumTypeParameters() == 1 &&
"SPIR-V image builtin type must have sampled type parameter!");
const SPIRVType *SampledType =
- GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
+ GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
assert((ExtensionType->getNumIntParameters() == 7 ||
ExtensionType->getNumIntParameters() == 6) &&
"Invalid number of parameters for SPIR-V image builtin!");
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index e4f144eeffbbc..b5a3b1953d5a6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -209,13 +209,15 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
// If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
// be legally reassigned later).
if (!isPointerTy(OriginalArgType))
- return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
+ return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual,
+ true);
Argument *Arg = F.getArg(ArgIdx);
Type *ArgType = Arg->getType();
if (isTypedPointerTy(ArgType)) {
SPIRVType *ElementType = GR->getOrCreateSPIRVType(
- cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder);
+ cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
@@ -231,7 +233,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
// type.
if (hasPointeeTypeAttr(Arg)) {
SPIRVType *ElementType =
- GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder);
+ GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
@@ -245,7 +248,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
Type *BuiltinType =
cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
- return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual);
+ return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual,
+ true);
}
// Check if this is spv_assign_ptr_type assigning pointer element type.
@@ -255,7 +259,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
Type *ElementTy =
toTypedPointer(cast<ConstantAsMetadata>(VMD->getMetadata())->getType());
- SPIRVType *ElementType = GR->getOrCreateSPIRVType(ElementTy, MIRBuilder);
+ SPIRVType *ElementType = GR->getOrCreateSPIRVType(
+ ElementTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
addressSpaceToStorageClass(
@@ -265,7 +270,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
// Replace PointerType with TypedPointerType to be able to map SPIR-V types to
// LLVM types in a consistent manner
return GR->getOrCreateSPIRVType(toTypedPointer(OriginalArgType), MIRBuilder,
- ArgAccessQual);
+ ArgAccessQual, true);
}
static SPIRV::ExecutionModel::ExecutionModel
@@ -405,7 +410,8 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
FRetTy = DerivedTy;
}
}
- SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
+ SPIRVType *RetTy = GR->getOrCreateSPIRVType(
+ FRetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, RetTy, ArgTypeVRegs, MIRBuilder);
@@ -486,10 +492,12 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
// Create indirect call data types if any
MachineFunction &MF = MIRBuilder.getMF();
for (auto const &IC : IndirectCalls) {
- SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
+ SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(
+ IC.RetTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
SmallVector<SPIRVType *, 4> SpirvArgTypes;
for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
- SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
+ SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(
+ IC.ArgTys[i], MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
SpirvArgTypes.push_back(SPIRVTy);
if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
@@ -586,7 +594,8 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
ArgTy = Arg.Ty;
}
if (ArgTy) {
- SpvType = GR->getOrCreateSPIRVType(ArgTy, MIRBuilder);
+ SpvType = GR->getOrCreateSPIRVType(
+ ArgTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
}
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index a09474a21534e..df08c61d60c39 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -412,7 +412,8 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
auto &Ctx = MF.getFunction().getContext();
if (!SpvType) {
const Type *LLVMFPTy = Type::getFloatTy(Ctx);
- SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);
+ SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
}
// Find a constant in DT or build a new one.
const auto ConstFP = ConstantFP::get(Ctx, Val);
@@ -653,9 +654,10 @@ Register SPIRVGlobalRegistry::buildConstantSampler(
MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
SPIRVType *SampTy;
if (SpvType)
- SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
- else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t",
- MIRBuilder)) == nullptr)
+ SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
+ else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder,
+ false)) == nullptr)
report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t");
auto Sampler =
@@ -1381,7 +1383,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
- uint32_t Use) {
+ uint32_t Use, bool EmitIR) {
Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF());
if (ResVReg.isValid())
return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
@@ -1391,10 +1393,10 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
.addDef(ResVReg)
.addUse(getSPIRVTypeID(ElemType))
- .addUse(buildConstantInt(Scope, MIRBuilder, SpvTypeInt32, true))
- .addUse(buildConstantInt(Rows, MIRBuilder, SpvTypeInt32, true))
- .addUse(buildConstantInt(Columns, MIRBuilder, SpvTypeInt32, true))
- .addUse(buildConstantInt(Use, MIRBuilder, SpvTypeInt32, true));
+ .addUse(buildConstantInt(Scope, MIRBuilder, SpvTypeInt32, EmitIR))
+ .addUse(buildConstantInt(Rows, MIRBuilder, SpvTypeInt32, EmitIR))
+ .addUse(buildConstantInt(Columns, MIRBuilder, SpvTypeInt32, EmitIR))
+ .addUse(buildConstantInt(Use, MIRBuilder, SpvTypeInt32, EmitIR));
DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
return SpirvTy;
}
@@ -1421,7 +1423,7 @@ SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
// Returns nullptr if unable to recognize SPIRV type name
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
- StringRef TypeStr, MachineIRBuilder &MIRBuilder,
+ StringRef TypeStr, MachineIRBuilder &MIRBuilder, bool EmitIR,
SPIRV::StorageClass::StorageClass SC,
SPIRV::AccessQualifier::AccessQualifier AQ) {
unsigned VecElts = 0;
@@ -1431,7 +1433,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
if (hasBuiltinTypePrefix(TypeStr))
return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType(
TypeStr.str(), MIRBuilder.getContext()),
- MIRBuilder, AQ);
+ MIRBuilder, AQ, true);
// Parse type name in either "typeN" or "type vector[N]" format, where
// N is the number of elements of the vector.
@@ -1442,7 +1444,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
// Unable to recognize SPIRV type name
return nullptr;
- auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);
+ auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ, true);
// Handle "type*" or "type* vector[N]".
if (TypeStr.starts_with("*")) {
@@ -1458,7 +1460,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
}
TypeStr.getAsInteger(10, VecElts);
if (VecElts > 0)
- SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
+ SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder, EmitIR);
if (IsPtrToVec)
SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
@@ -1471,7 +1473,7 @@ SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
MachineIRBuilder &MIRBuilder) {
return getOrCreateSPIRVType(
IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
- MIRBuilder);
+ MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
}
SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
@@ -1531,10 +1533,11 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
}
SPIRVType *
-SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
+SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder,
+ bool EmitIR) {
return getOrCreateSPIRVType(
IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
- MIRBuilder);
+ MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR);
}
SPIRVType *
@@ -1552,11 +1555,12 @@ SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
- SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
+ SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder,
+ bool EmitIR) {
return getOrCreateSPIRVType(
FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
NumElements),
- MIRBuilder);
+ MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR);
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 5ad705f197d5a..2c24ba79ea8e6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -341,9 +341,8 @@ class SPIRVGlobalRegistry {
// want to emit extra IR instructions there.
SPIRVType *getOrCreateSPIRVType(const Type *Type,
MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier AQ =
- SPIRV::AccessQualifier::ReadWrite,
- bool EmitIR = true);
+ SPIRV::AccessQualifier::AccessQualifier AQ,
+ bool EmitIR);
const Type *getTypeForSPIRVType(const SPIRVType *Ty) const {
auto Res = SPIRVToLLVMType.find(Ty);
@@ -360,7 +359,7 @@ class SPIRVGlobalRegistry {
// corresponding to the given string containing the name of the builtin type.
// Return nullptr if unable to recognize SPIRV type name from `TypeStr`.
SPIRVType *getOrCreateSPIRVTypeByName(
- StringRef TypeStr, MachineIRBuilder &MIRBuilder,
+ StringRef TypeStr, MachineIRBuilder &MIRBuilder, bool EmitIR,
SPIRV::StorageClass::StorageClass SC = SPIRV::StorageClass::Function,
SPIRV::AccessQualifier::AccessQualifier AQ =
SPIRV::AccessQualifier::ReadWrite);
@@ -516,7 +515,7 @@ class SPIRVGlobalRegistry {
public:
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
- SPIRVType *SpvType, bool EmitIR = true,
+ SPIRVType *SpvType, bool EmitIR,
bool ZeroAsNull = true);
Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
@@ -568,12 +567,14 @@ class SPIRVGlobalRegistry {
unsigned SPIRVOPcode, Type *LLVMTy);
SPIRVType *getOrCreateSPIRVFloatType(unsigned BitWidth, MachineInstr &I,
const SPIRVInstrInfo &TII);
- SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder);
+ SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder,
+ bool EmitIR);
SPIRVType *getOrCreateSPIRVBoolType(MachineInstr &I,
const SPIRVInstrInfo &TII);
SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType,
unsigned NumElements,
- MachineIRBuilder &MIRBuilder);
+ MachineIRBuilder &MIRBuilder,
+ bool EmitIR);
SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType,
unsigned NumElements, MachineInstr &I,
const SPIRVInstrInfo &TII);
@@ -603,7 +604,8 @@ class SPIRVGlobalRegistry {
const TargetExtType *ExtensionType,
const SPIRVType *ElemType,
uint32_t Scope, uint32_t Rows,
- uint32_t Columns, uint32_t Use);
+ uint32_t Columns, uint32_t Use,
+ bool EmitIR);
SPIRVType *
getOrCreateOpTypePipe(MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual);
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index f622be893919f..c347dde89256f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -126,8 +126,7 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
SPIRVType *OpType, bool ReuseType,
- bool EmitIR, SPIRVType *ResType,
- const Type *ResTy) {
+ SPIRVType *ResType, const Type *ResTy) {
SPIRV::StorageClass::StorageClass SC =
static_cast<SPIRV::StorageClass::StorageClass>(
OpType->getOperand(1).getImm());
@@ -135,7 +134,7 @@ static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
SPIRVType *NewBaseType =
ReuseType ? ResType
: GR.getOrCreateSPIRVType(
- ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
+ ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
}
@@ -166,7 +165,7 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
// There is a type mismatch between results and operand types
// and we insert a bitcast before the instruction to keep SPIR-V code valid
SPIRVType *NewPtrType =
- createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
+ createNewPtrType(GR, I, OpType, IsSameMF, ResType, ResTy);
if (!GR.isBitcastCompatible(NewPtrType, OpType))
report_fatal_error(
"insert validation bitcast: incompatible result and operand types");
@@ -192,7 +191,7 @@ static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
// Insert a bitcast before the instruction to keep SPIR-V code valid.
LLVMContext &Context = MF->getFunction().getContext();
SPIRVType *NewPtrType =
- createNewPtrType(GR, I, OpType, false, false, nullptr,
+ createNewPtrType(GR, I, OpType, false, nullptr,
TargetExtType::get(Context, "spirv.Event"));
doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}
@@ -493,12 +492,12 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
assert(RetType && "Expected return type");
- validatePtrTypes(
- STI, MRI, GR, MI, MI.getNumOperands() - 1,
- RetType->getOpcode() != SPIRV::OpTypeVector
- ? Int32Type
- : GR.getOrCreateSPIRVVectorType(
- Int32Type, RetType->getOperand(2).getImm(), MIB));
+ validatePtrTypes(STI, MRI, GR, MI, MI.getNumOperands() - 1,
+ RetType->getOpcode() != SPIRV::OpTypeVector
+ ? Int32Type
+ : GR.getOrCreateSPIRVVectorType(
+ Int32Type, RetType->getOperand(2).getImm(),
+ MIB, false));
} break;
case SPIRV::OpenCLExtInst::fract:
case SPIRV::OpenCLExtInst::modf:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index e7d8fe5bd8015..888cc9ec05277 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3391,9 +3391,10 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
MachineIRBuilder MIRBuilder(I);
SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
SPIRVType *I64Type = GR.getOrCreateSPIRVIntegerType(64, MIRBuilder);
- SPIRVType *I64x2Type = GR.getOrCreateSPIRVVectorType(I64Type, 2, MIRBuilder);
+ SPIRVType *I64x2Type =
+ GR.getOrCreateSPIRVVectorType(I64Type, 2, MIRBuilder, false);
SPIRVType *Vec2ResType =
- GR.getOrCreateSPIRVVectorType(BaseType, 2, MIRBuilder);
+ GR.getOrCreateSPIRVVectorType(BaseType, 2, MIRBuilder, false);
std::vector<Register> PartialRegs;
@@ -3476,8 +3477,8 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
// 1. Split int64 into 2 pieces using a bitcast
MachineIRBuilder MIRBuilder(I);
- SPIRVType *PostCastType =
- GR.getOrCreateSPIRVVectorType(BaseType, 2 * ComponentCount, MIRBuilder);
+ SPIRVType *PostCastType = GR.getOrCreateSPIRVVectorType(
+ BaseType, 2 * ComponentCount, MIRBuilder, false);
Register BitcastReg =
MRI->createVirtualRegister(GR.getRegClass(PostCastType));
@@ -3554,8 +3555,8 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
SelectOp = SPIRV::OpSelectSISCond;
AddOp = SPIRV::OpIAddS;
} else {
- BoolType =
- GR.getOrCreateSPIRVVectorType(BoolType, ComponentCount, MIRBuilder);
+ BoolType = GR.getOrCreateSPIRVVectorType(BoolType, ComponentCount,
+ MIRBuilder, false);
NegOneReg =
GR.getOrCreateConstVector((unsigned)-1, I, ResType, TII, ZeroAsNull);
Reg0 = GR.getOrCreateConstVector(0, I, ResType, TII, ZeroAsNull);
@@ -3922,7 +3923,7 @@ bool SPIRVInstructionSelector::loadVec3BuiltinInputID(
MachineIRBuilder MIRBuilder(I);
const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder);
const SPIRVType *Vec3Ty =
- GR.getOrCreateSPIRVVectorType(U32Type, 3, MIRBuilder);
+ GR.getOrCreateSPIRVVectorType(U32Type, 3, MIRBuilder, false);
const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input);
@@ -3971,7 +3972,7 @@ SPIRVType *SPIRVInstructionSelector::widenTypeToVec4(const SPIRVType *Type,
MachineInstr &I) const {
MachineIRBuilder MIRBuilder(I);
if (Type->getOpcode() != SPIRV::OpTypeVector)
- return GR.getOrCreateSPIRVVectorType(Type, 4, MIRBuilder);
+ return GR.getOrCreateSPIRVVectorType(Type, 4, MIRBuilder, false);
uint64_t VectorSize = Type->getOperand(2).getImm();
if (VectorSize == 4)
@@ -3979,7 +3980,7 @@ SPIRVType *SPIRVInstructionSelector::widenTypeToVec4(const SPIRVType *Type,
Register ScalarTypeReg = Type->getOperand(1).getReg();
const SPIRVType *ScalarType = GR.getSPIRVTypeForVReg(ScalarTypeReg);
- return GR.getOrCreateSPIRVVectorType(ScalarType, 4, MIRBuilder);
+ return GR.getOrCreateSPIRVVectorType(ScalarType, 4, MIRBuilder, false);
}
bool SPIRVInstructionSelector::loadHandleBeforePosition(
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index fa5e0a80576d0..daa8ea52ffe03 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -405,7 +405,8 @@ bool SPIRVLegalizerInfo::legalizeCustom(
LLT ConvT = LLT::scalar(ST->getPointerSize());
Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
ST->getPointerSize());
- SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
+ SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(
+ LLVMTy, Helper.MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 32f6af3d1440f..5d70b9c2a4a59 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -98,8 +98,9 @@ addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
TargetExtConstTypes[SrcMI] = Const->getType();
if (Const->isNullValue()) {
MachineIRBuilder MIB(MF);
- SPIRVType *ExtType =
- GR->getOrCreateSPIRVType(Const->getType(), MIB);
+ SPIRVType *ExtType = GR->getOrCreateSPIRVType(
+ Const->getType(), MIB, SPIRV::AccessQualifier::ReadWrite,
+ true);
SrcMI->setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
SrcMI->addOperand(MachineOperand::CreateReg(
GR->getSPIRVTypeID(ExtType), false));
@@ -248,7 +249,8 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Register Def = MI.getOperand(0).getReg();
Register Source = MI.getOperand(2).getReg();
Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
- SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElemTy, MIB);
+ SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
+ ElemTy, MIB, SPIRV::AccessQualifier::ReadWrite, true);
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
@@ -299,7 +301,8 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
case TargetOpcode::G_CONSTANT: {
MIB.setInsertPt(*MI->getParent(), MI);
Type *Ty = MI->getOperand(1).getCImm()->getType();
- SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
+ SpvType = GR->getOrCreateSPIRVType(
+ Ty, MIB, SPIRV::AccessQualifier::ReadWrite, true);
break;
}
case TargetOpcode::G_GLOBAL_VALUE: {
@@ -308,7 +311,8 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
Type *ElementTy = toTypedPointer(GR->getDeducedGlobalValueType(Global));
auto *Ty = TypedPointerType::get(ElementTy,
Global->getType()->getAddressSpace());
- SpvType = GR->getOrCreateSPIRVType(Ty, MIB);
+ SpvType = GR->getOrCreateSPIRVType(
+ Ty, MIB, SPIRV::AccessQualifier::ReadWrite, true);
break;
}
case TargetOpcode::G_ANYEXT:
@@ -324,8 +328,8 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
unsigned NumElements = GR->getScalarOrVectorComponentCount(Def);
SpvType = GR->getOrCreateSPIRVIntegerType(ExpectedBW, MIB);
if (NumElements > 1)
- SpvType =
- GR->getOrCreateSPIRVVectorType(SpvType, NumElements, MIB);
+ SpvType = GR->getOrCreateSPIRVVectorType(SpvType, NumElements,
+ MIB, true);
}
}
}
@@ -431,7 +435,9 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
MachineInstr *Def = MRI.getVRegDef(Reg);
setInsertPtAfterDef(MIB, Def);
- SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB);
+ SpvType = SpvType ? SpvType
+ : GR->getOrCreateSPIRVType(
+ Ty, MIB, SPIRV::AccessQualifier::ReadWrite, true);
Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
if (auto *RC = MRI.getRegClassOrNull(Reg)) {
MRI.setRegClass(NewReg, RC);
@@ -518,7 +524,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Register Reg = MI.getOperand(1).getReg();
MIB.setInsertPt(*MI.getParent(), MI.getIterator());
Type *ElementTy = getMDOperandAsType(MI.getOperand(2).getMetadata(), 0);
- SPIRVType *BaseTy = GR->getOrCreateSPIRVType(ElementTy, MIB);
+ SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
+ ElementTy, MIB, SPIRV::AccessQualifier::ReadWrite, true);
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
@@ -737,9 +744,11 @@ insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
FunctionType *FTy = cast<FunctionType>(getMDOperandAsType(IAMD, 0));
SmallVector<SPIRVType *, 4> ArgTypes;
for (const auto &ArgTy : FTy->params())
- ArgTypes.push_back(GR->getOrCreateSPIRVType(ArgTy, MIRBuilder));
+ ArgTypes.push_back(GR->getOrCreateSPIRVType(
+ ArgTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true));
SPIRVType *RetType =
- GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
+ GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
SPIRVType *FuncType = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, RetType, ArgTypes, MIRBuilder);
@@ -772,7 +781,8 @@ insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
DefReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
MRI.setRegClass(DefReg, &SPIRV::iIDRegClass);
SPIRVType *VoidType = GR->getOrCreateSPIRVType(
- Type::getVoidTy(MF.getFunction().getContext()), MIRBuilder);
+ Type::getVoidTy(MF.getFunction().getContext()), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
GR->assignSPIRVTypeToVReg(VoidType, DefReg, MF);
}
More information about the llvm-commits
mailing list