[llvm] [SPIR-V] Fix generation of gMIR vs. SPIR-V code from utility methods (PR #128159)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 21 02:39:37 PST 2025
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/128159
>From d579bb2192a9412693f19492d6d692d3083232bb Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 20 Feb 2025 13:42:20 -0800
Subject: [PATCH 1/2] Fix generation of gMIR vs. SPIR-V code from utility methods
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 16 ++++++----
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 3 +-
.../Target/SPIRV/SPIRVEmitNonSemanticDI.cpp | 6 ++--
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 32 +++++++++++--------
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 22 ++++++-------
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 5 +--
6 files changed, 46 insertions(+), 38 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 7b897f7e34c6f..c65166902550c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -476,8 +476,9 @@ static bool buildSelectInst(MachineIRBuilder &MIRBuilder,
if (ReturnType->getOpcode() == SPIRV::OpTypeVector) {
unsigned Bits = GR->getScalarOrVectorBitWidth(ReturnType);
uint64_t AllOnes = APInt::getAllOnes(Bits).getZExtValue();
- TrueConst = GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType);
- FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType);
+ TrueConst =
+ GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType, true);
+ FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType, true);
} else {
TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType);
FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType);
@@ -1457,7 +1458,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
ToTruncate = DefaultReg;
}
auto NewRegister =
- GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
+ GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);
MIRBuilder.buildCopy(DefaultReg, NewRegister);
} else { // If it could be in range, we need to load from the given builtin.
auto Vec3Ty =
@@ -1492,13 +1493,14 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());
// Use G_ICMP to check if idxVReg < 3.
- MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
- GR->buildConstantInt(3, MIRBuilder, IndexType));
+ MIRBuilder.buildICmp(
+ CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
+ GR->buildConstantInt(3, MIRBuilder, IndexType, true));
// Get constant for the default value (0 or 1 depending on which
// function).
Register DefaultRegister =
- GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
+ GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);
// Get a register for the selection result (possibly a new temporary one).
Register SelectionResult = Call->ReturnRegister;
@@ -2277,7 +2279,7 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
SpvFieldTy, *ST.getInstrInfo());
} else {
- Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
+ Const = GR->buildConstantInt(0, MIRBuilder, SpvTy, true);
}
if (!LocalWorkSize.isValid())
LocalWorkSize = Const;
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 78f6b188c45c1..e47dfddd55975 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -669,7 +669,8 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// Make sure there's a valid return reg, even for functions returning void.
if (!ResVReg.isValid())
ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
- SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
+ SPIRVType *RetType = GR->assignTypeToVReg(
+ OrigRetTy, ResVReg, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
// Emit the call instruction and its args.
auto MIB = MIRBuilder.buildInstr(CallOp)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
index b98cef0a4f07f..ee98af5cffe4c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
@@ -193,7 +193,8 @@ bool SPIRVEmitNonSemanticDI::emitGlobalDI(MachineFunction &MF) {
};
const SPIRVType *VoidTy =
- GR->getOrCreateSPIRVType(Type::getVoidTy(*Context), MIRBuilder);
+ GR->getOrCreateSPIRVType(Type::getVoidTy(*Context), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, false);
const auto EmitDIInstruction =
[&](SPIRV::NonSemanticExtInst::NonSemanticExtInst Inst,
@@ -217,7 +218,8 @@ bool SPIRVEmitNonSemanticDI::emitGlobalDI(MachineFunction &MF) {
};
const SPIRVType *I32Ty =
- GR->getOrCreateSPIRVType(Type::getInt32Ty(*Context), MIRBuilder);
+ GR->getOrCreateSPIRVType(Type::getInt32Ty(*Context), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, false);
const Register DwarfVersionReg =
GR->buildConstantInt(DwarfVersion, MIRBuilder, I32Ty, false);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index e2f1b211caa5c..a09474a21534e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -244,7 +244,8 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
if (MIRBuilder)
- assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
+ assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
else
assignIntTypeToVReg(BitWidth, Res, *I, *TII);
DT.add(CI, CurMF, Res);
@@ -271,7 +272,8 @@ SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
if (MIRBuilder)
- assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
+ assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
else
assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
DT.add(CI, CurMF, Res);
@@ -878,12 +880,13 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
});
}
-SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
- MachineIRBuilder &MIRBuilder,
- bool EmitIR) {
+SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(
+ const StructType *Ty, MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
SmallVector<Register, 4> FieldTypes;
for (const auto &Elem : Ty->elements()) {
- SPIRVType *ElemTy = findSPIRVType(toTypedPointer(Elem), MIRBuilder);
+ SPIRVType *ElemTy =
+ findSPIRVType(toTypedPointer(Elem), MIRBuilder, AccQual, EmitIR);
assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
"Invalid struct element type");
FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -1017,26 +1020,27 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
- SPIRVType *El =
- findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
+ SPIRVType *El = findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(),
+ MIRBuilder, AccQual, EmitIR);
return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
MIRBuilder);
}
if (Ty->isArrayTy()) {
- SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
+ SPIRVType *El =
+ findSPIRVType(Ty->getArrayElementType(), MIRBuilder, AccQual, EmitIR);
return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
}
if (auto SType = dyn_cast<StructType>(Ty)) {
if (SType->isOpaque())
return getOpTypeOpaque(SType, MIRBuilder);
- return getOpTypeStruct(SType, MIRBuilder, EmitIR);
+ return getOpTypeStruct(SType, MIRBuilder, AccQual, EmitIR);
}
if (auto FType = dyn_cast<FunctionType>(Ty)) {
- SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
+ SPIRVType *RetTy =
+ findSPIRVType(FType->getReturnType(), MIRBuilder, AccQual, EmitIR);
SmallVector<SPIRVType *, 4> ParamTypes;
- for (const auto &t : FType->params()) {
- ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
- }
+ for (const auto &ParamTy : FType->params())
+ ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, EmitIR));
return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 0c94ec4df97f5..5ad705f197d5a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -94,13 +94,11 @@ class SPIRVGlobalRegistry {
// Add a new OpTypeXXX instruction without checking for duplicates.
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier AQ =
- SPIRV::AccessQualifier::ReadWrite,
- bool EmitIR = true);
+ SPIRV::AccessQualifier::AccessQualifier AQ,
+ bool EmitIR);
SPIRVType *findSPIRVType(const Type *Ty, MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier accessQual =
- SPIRV::AccessQualifier::ReadWrite,
- bool EmitIR = true);
+ SPIRV::AccessQualifier::AccessQualifier accessQual,
+ bool EmitIR);
SPIRVType *
restOfCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual,
@@ -321,9 +319,8 @@ class SPIRVGlobalRegistry {
// and map it to the given VReg by creating an ASSIGN_TYPE instruction.
SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier AQ =
- SPIRV::AccessQualifier::ReadWrite,
- bool EmitIR = true);
+ SPIRV::AccessQualifier::AccessQualifier AQ,
+ bool EmitIR);
SPIRVType *assignIntTypeToVReg(unsigned BitWidth, Register VReg,
MachineInstr &I, const SPIRVInstrInfo &TII);
SPIRVType *assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
@@ -470,13 +467,14 @@ class SPIRVGlobalRegistry {
MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType,
- MachineIRBuilder &MIRBuilder, bool EmitIR = true);
+ MachineIRBuilder &MIRBuilder, bool EmitIR);
SPIRVType *getOpTypeOpaque(const StructType *Ty,
MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeStruct(const StructType *Ty, MachineIRBuilder &MIRBuilder,
- bool EmitIR = true);
+ SPIRV::AccessQualifier::AccessQualifier AccQual,
+ bool EmitIR);
SPIRVType *getOpTypePointer(SPIRV::StorageClass::StorageClass SC,
SPIRVType *ElemType, MachineIRBuilder &MIRBuilder,
@@ -539,7 +537,7 @@ class SPIRVGlobalRegistry {
SPIRVType *SpvType,
const SPIRVInstrInfo &TII);
Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder,
- SPIRVType *SpvType, bool EmitIR = true);
+ SPIRVType *SpvType, bool EmitIR);
Register getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType);
Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index d5b81bf46c804..f622be893919f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -192,7 +192,7 @@ static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
// Insert a bitcast before the instruction to keep SPIR-V code valid.
LLVMContext &Context = MF->getFunction().getContext();
SPIRVType *NewPtrType =
- createNewPtrType(GR, I, OpType, false, true, nullptr,
+ createNewPtrType(GR, I, OpType, false, false, nullptr,
TargetExtType::get(Context, "spirv.Event"));
doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}
@@ -216,7 +216,8 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
MachineIRBuilder MIB(I);
LLVMContext &Context = MF->getFunction().getContext();
SPIRVType *ElemType =
- GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB);
+ GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB,
+ SPIRV::AccessQualifier::ReadWrite, false);
SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
}
>From cfe50c14555d4ebdb24e7f07ebc934a8f2279870 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 21 Feb 2025 02:39:27 -0800
Subject: [PATCH 2/2] Add explicit EmitIR flag to SPIR-V utility functions
---
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 6 ++++--
llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 19 ++++++++++++-------
llvm/lib/Target/SPIRV/SPIRVUtils.h | 9 ++++++---
3 files changed, 22 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index e47dfddd55975..e4f144eeffbbc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -557,10 +557,12 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
RetTy =
TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
}
- setRegClassType(ResVReg, RetTy, GR, MIRBuilder);
+ setRegClassType(ResVReg, RetTy, GR, MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
}
} else {
- ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder);
+ ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, true);
}
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index ddc66f98829a9..c55b735314228 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -738,9 +738,12 @@ void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
// no valid assigned class, set register LLT type and class according to the
// SPIR-V type.
void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIRBuilder, bool Force) {
- setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
- MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
+ MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier::AccessQualifier AccessQual,
+ bool EmitIR, bool Force) {
+ setRegClassType(Reg,
+ GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR),
+ GR, MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
}
// Create a virtual register and assign SPIR-V type to the register. Set
@@ -764,10 +767,12 @@ Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
// Create a SPIR-V type, virtual register and assign SPIR-V type to the
// register. Set register LLT type and class according to the SPIR-V type.
-Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIRBuilder) {
- return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
- MIRBuilder);
+Register createVirtualRegister(
+ const Type *Ty, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
+ return createVirtualRegister(
+ GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR), GR,
+ MIRBuilder);
}
// Return true if there is an opaque pointer type nested in the argument.
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 552adf2df7d17..870649879218a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -411,7 +411,9 @@ MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
bool getVacantFunctionName(Module &M, std::string &Name);
void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIRBuilder, bool Force = false);
+ MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier::AccessQualifier AccessQual,
+ bool EmitIR, bool Force = false);
void setRegClassType(Register Reg, const MachineInstr *SpvType,
SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI,
const MachineFunction &MF, bool Force = false);
@@ -422,8 +424,9 @@ Register createVirtualRegister(const MachineInstr *SpvType,
Register createVirtualRegister(const MachineInstr *SpvType,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIRBuilder);
-Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIRBuilder);
+Register createVirtualRegister(
+ const Type *Ty, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR);
// Return true if there is an opaque pointer type nested in the argument.
bool isNestedPointer(const Type *Ty);
More information about the llvm-commits
mailing list