[llvm] 492ad84 - [SPIRV] Add explicit layout (#135789)
via llvm-commits
llvm-commits at lists.llvm.org
Mon May 5 15:49:15 PDT 2025
Author: Steven Perron
Date: 2025-05-05T18:49:12-04:00
New Revision: 492ad848b1c319ad9641208aaadb41bc575a9c3f
URL: https://github.com/llvm/llvm-project/commit/492ad848b1c319ad9641208aaadb41bc575a9c3f
DIFF: https://github.com/llvm/llvm-project/commit/492ad848b1c319ad9641208aaadb41bc575a9c3f.diff
LOG: [SPIRV] Add explicit layout (#135789)
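Adds Offset decorations to struct members (and ArrayStride decorations to
arrays) when the storage class requires an explicit layout. Because an
explicitly laid-out type is a different SPIR-V type from its undecorated
counterpart, this can cause type mismatches in memory instructions. This change
fixes up OpLoad instructions so that the new code can be tested; other memory
operations will be handled in a follow-up PR.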
Part of https://github.com/llvm/llvm-project/issues/134119.
Added:
llvm/test/CodeGen/SPIRV/spirv-explicit-layout.ll
Modified:
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
llvm/lib/Target/SPIRV/SPIRVIRMapping.h
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
llvm/lib/Target/SPIRV/SPIRVISelLowering.h
llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 88b1e44d15af0..35ddb906c366a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -55,7 +55,6 @@ static unsigned typeToAddressSpace(const Type *Ty) {
reportFatalInternalError("Unable to convert LLVM type to SPIRVType");
}
-#ifndef NDEBUG
+// Reports whether types used with storage class SC need an explicit layout;
+// callers pass the result on as ExplicitLayoutRequired. It is called outside
+// of asserts, so it must not be guarded by NDEBUG.
static bool
storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC) {
switch (SC) {
@@ -87,7 +86,6 @@ storageClassRequiresExplictLayout(SPIRV::StorageClass::StorageClass SC) {
}
llvm_unreachable("Unknown SPIRV::StorageClass enum");
}
-#endif
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
: PointerSize(PointerSize), Bound(0) {}
@@ -837,13 +835,31 @@ static std::string buildSpirvTypeName(const SPIRVType *Type,
}
case SPIRV::OpTypeStruct: {
std::string TypeName = "{";
- for (uint32_t I = 2; I < Type->getNumOperands(); ++I) {
+ for (uint32_t I = 1; I < Type->getNumOperands(); ++I) {
SPIRVType *MemberType =
GR.getSPIRVTypeForVReg(Type->getOperand(I).getReg());
- TypeName = '_' + buildSpirvTypeName(MemberType, MIRBuilder, GR);
+ TypeName += '_' + buildSpirvTypeName(MemberType, MIRBuilder, GR);
}
return TypeName + "}";
}
+ case SPIRV::OpTypeVector: {
+ MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+ Register ElementTypeReg = Type->getOperand(1).getReg();
+ auto *ElementType = MRI->getUniqueVRegDef(ElementTypeReg);
+ uint32_t VectorSize = GR.getScalarOrVectorComponentCount(Type);
+ return (buildSpirvTypeName(ElementType, MIRBuilder, GR) + Twine("[") +
+ Twine(VectorSize) + Twine("]"))
+ .str();
+ }
+ case SPIRV::OpTypeRuntimeArray: {
+ MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+ Register ElementTypeReg = Type->getOperand(1).getReg();
+ auto *ElementType = MRI->getUniqueVRegDef(ElementTypeReg);
+ uint32_t ArraySize = 0;
+ return (buildSpirvTypeName(ElementType, MIRBuilder, GR) + Twine("[") +
+ Twine(ArraySize) + Twine("]"))
+ .str();
+ }
default:
llvm_unreachable("Trying to the the name of an unknown type.");
}
@@ -885,30 +901,41 @@ Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding(
return VarReg;
}
+// Creates an OpTypeArray with NumElems elements of ElemType, or an
+// OpTypeRuntimeArray when NumElems is 0. If ExplicitLayoutRequired is set and
+// ElemType is not a resource type, the new array type also gets an
+// ArrayStride decoration.
+// TODO: Double check the calls to getOpTypeArray to make sure that `ElemType`
+// is explicitly laid out when required.
SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder,
+ bool ExplicitLayoutRequired,
bool EmitIR) {
assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
"Invalid array element type");
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
-
+ SPIRVType *ArrayType = nullptr;
if (NumElems != 0) {
Register NumElementsVReg =
buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR);
- return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+ ArrayType = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeArray)
.addDef(createTypeVReg(MIRBuilder))
.addUse(getSPIRVTypeID(ElemType))
.addUse(NumElementsVReg);
});
+ } else {
+ ArrayType = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+ return MIRBuilder.buildInstr(SPIRV::OpTypeRuntimeArray)
+ .addDef(createTypeVReg(MIRBuilder))
+ .addUse(getSPIRVTypeID(ElemType));
+ });
}
- return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
- return MIRBuilder.buildInstr(SPIRV::OpTypeRuntimeArray)
- .addDef(createTypeVReg(MIRBuilder))
- .addUse(getSPIRVTypeID(ElemType));
- });
+ // Arrays whose elements are resource types (images, samplers, sampled
+ // images, Block-decorated structs) get no ArrayStride decoration.
+ if (ExplicitLayoutRequired && !isResourceType(ElemType)) {
+ Type *ET = const_cast<Type *>(getTypeForSPIRVType(ElemType));
+ addArrayStrideDecorations(ArrayType->defs().begin()->getReg(), ET,
+ MIRBuilder);
+ }
+
+ return ArrayType;
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
@@ -926,7 +953,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
+// Creates an OpTypeStruct for Ty. Members past the first
+// SPIRVStructNumElements are emitted in OpTypeStructContinuedINTEL
+// instructions of up to MaxNumElements members each. When
+// ExplicitLayoutRequired is set, member types are looked up with an explicit
+// layout and every member gets an Offset decoration.
SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(
const StructType *Ty, MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
+ SPIRV::AccessQualifier::AccessQualifier AccQual,
+ bool ExplicitLayoutRequired, bool EmitIR) {
SmallVector<Register, 4> FieldTypes;
constexpr unsigned MaxWordCount = UINT16_MAX;
const size_t NumElements = Ty->getNumElements();
@@ -940,8 +968,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(
}
for (const auto &Elem : Ty->elements()) {
- SPIRVType *ElemTy =
- findSPIRVType(toTypedPointer(Elem), MIRBuilder, AccQual, EmitIR);
+ SPIRVType *ElemTy = findSPIRVType(toTypedPointer(Elem), MIRBuilder, AccQual,
+ ExplicitLayoutRequired, EmitIR);
assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
"Invalid struct element type");
FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -952,18 +980,27 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(
if (Ty->isPacked())
buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
- return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
- auto MIBStruct = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
- for (size_t I = 0; I < SPIRVStructNumElements; ++I)
- MIBStruct.addUse(FieldTypes[I]);
- for (size_t I = SPIRVStructNumElements; I < NumElements;
- I += MaxNumElements) {
- auto MIBCont = MIRBuilder.buildInstr(SPIRV::OpTypeStructContinuedINTEL);
- for (size_t J = I; J < std::min(I + MaxNumElements, NumElements); ++J)
- MIBCont.addUse(FieldTypes[I]);
- }
- return MIBStruct;
- });
+ SPIRVType *SPVType =
+ createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+ auto MIBStruct =
+ MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
+ for (size_t I = 0; I < SPIRVStructNumElements; ++I)
+ MIBStruct.addUse(FieldTypes[I]);
+ for (size_t I = SPIRVStructNumElements; I < NumElements;
+ I += MaxNumElements) {
+ auto MIBCont =
+ MIRBuilder.buildInstr(SPIRV::OpTypeStructContinuedINTEL);
+ // The continuation starting at I lists members
+ // [I, min(I + MaxNumElements, NumElements)), so index with J.
+ for (size_t J = I; J < std::min(I + MaxNumElements, NumElements); ++J)
+ MIBCont.addUse(FieldTypes[J]);
+ }
+ return MIBStruct;
+ });
+
+ if (ExplicitLayoutRequired)
+ addStructOffsetDecorations(SPVType->defs().begin()->getReg(),
+ const_cast<StructType *>(Ty), MIRBuilder);
+
+ return SPVType;
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
@@ -1013,22 +1050,26 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
const Type *Ty, SPIRVType *RetType,
const SmallVectorImpl<SPIRVType *> &ArgTypes,
MachineIRBuilder &MIRBuilder) {
- if (const MachineInstr *MI = findMI(Ty, &MIRBuilder.getMF()))
+ if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF()))
return MI;
const MachineInstr *NewMI = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
- add(Ty, NewMI);
+ add(Ty, false, NewMI);
return finishCreatingSPIRVType(Ty, NewMI);
}
+// Returns the SPIR-V type for Ty: an already registered type if there is one,
+// otherwise a pending forward pointer for Ty, otherwise a newly created type.
SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
+ SPIRV::AccessQualifier::AccessQualifier AccQual,
+ bool ExplicitLayoutRequired, bool EmitIR) {
Ty = adjustIntTypeByWidth(Ty);
- if (const MachineInstr *MI = findMI(Ty, &MIRBuilder.getMF()))
+ // The lookup key depends on ExplicitLayoutRequired, so a struct or array
+ // that needs an explicit layout is found separately from its plain variant.
+ if (const MachineInstr *MI =
+ findMI(Ty, ExplicitLayoutRequired, &MIRBuilder.getMF()))
return MI;
if (auto It = ForwardPointerTypes.find(Ty); It != ForwardPointerTypes.end())
return It->second;
- return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
+ return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, ExplicitLayoutRequired,
+ EmitIR);
}
Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
@@ -1062,11 +1103,13 @@ const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
+// Creates the SPIR-V type instruction for Ty after consulting the type cache.
+// Special opaque types are delegated to getOrCreateSpecialType, and
+// ExplicitLayoutRequired is forwarded to the lookups of nested element,
+// member, return and parameter types.
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
+ SPIRV::AccessQualifier::AccessQualifier AccQual,
+ bool ExplicitLayoutRequired, bool EmitIR) {
if (isSpecialOpaqueType(Ty))
return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
- if (const MachineInstr *MI = findMI(Ty, &MIRBuilder.getMF()))
+ if (const MachineInstr *MI =
+ findMI(Ty, ExplicitLayoutRequired, &MIRBuilder.getMF()))
return MI;
if (auto IType = dyn_cast<IntegerType>(Ty)) {
@@ -1079,27 +1122,31 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
- SPIRVType *El = findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(),
- MIRBuilder, AccQual, EmitIR);
+ SPIRVType *El =
+ findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder,
+ AccQual, ExplicitLayoutRequired, EmitIR);
return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
MIRBuilder);
}
if (Ty->isArrayTy()) {
- SPIRVType *El =
- findSPIRVType(Ty->getArrayElementType(), MIRBuilder, AccQual, EmitIR);
- return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
+ SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder,
+ AccQual, ExplicitLayoutRequired, EmitIR);
+ return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder,
+ ExplicitLayoutRequired, EmitIR);
}
if (auto SType = dyn_cast<StructType>(Ty)) {
if (SType->isOpaque())
return getOpTypeOpaque(SType, MIRBuilder);
- return getOpTypeStruct(SType, MIRBuilder, AccQual, EmitIR);
+ return getOpTypeStruct(SType, MIRBuilder, AccQual, ExplicitLayoutRequired,
+ EmitIR);
}
if (auto FType = dyn_cast<FunctionType>(Ty)) {
- SPIRVType *RetTy =
- findSPIRVType(FType->getReturnType(), MIRBuilder, AccQual, EmitIR);
+ SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder,
+ AccQual, ExplicitLayoutRequired, EmitIR);
SmallVector<SPIRVType *, 4> ParamTypes;
for (const auto &ParamTy : FType->params())
- ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, EmitIR));
+ ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual,
+ ExplicitLayoutRequired, EmitIR));
return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
}
@@ -1114,44 +1161,50 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
auto SC = addressSpaceToStorageClass(AddrSpace, *ST);
- // Null pointer means we have a loop in type definitions, make and
- // return corresponding OpTypeForwardPointer.
- if (SpvElementType == nullptr) {
- auto [It, Inserted] = ForwardPointerTypes.try_emplace(Ty);
- if (Inserted)
- It->second = getOpTypeForwardPointer(SC, MIRBuilder);
- return It->second;
+
+ Type *ElemTy = ::getPointeeType(Ty);
+ if (!ElemTy) {
+ ElemTy = Type::getInt8Ty(MIRBuilder.getContext());
}
+
// If we have forward pointer associated with this type, use its register
// operand to create OpTypePointer.
if (auto It = ForwardPointerTypes.find(Ty); It != ForwardPointerTypes.end()) {
Register Reg = getSPIRVTypeID(It->second);
+ // NOTE(review): this path passes SpvElementType, whereas the path below
+ // re-derives the pointee from ElemTy via getOrCreateSPIRVPointerType, which
+ // applies storageClassRequiresExplictLayout(SC). Confirm that both paths
+ // give the pointee the same layout.
return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
}
- return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
+ return getOrCreateSPIRVPointerType(ElemTy, MIRBuilder, SC);
}
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
+ SPIRV::AccessQualifier::AccessQualifier AccessQual,
+ bool ExplicitLayoutRequired, bool EmitIR) {
+ // TODO: Could this create a problem if one requires an explicit layout, and
+ // the next time it does not?
if (TypesInProcessing.count(Ty) && !isPointerTyOrWrapper(Ty))
return nullptr;
TypesInProcessing.insert(Ty);
- SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
+ SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
+ ExplicitLayoutRequired, EmitIR);
TypesInProcessing.erase(Ty);
VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
+
+ // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
+ // Is that a problem?
SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer ||
- findMI(Ty, &MIRBuilder.getMF()) || isSpecialOpaqueType(Ty))
+ findMI(Ty, false, &MIRBuilder.getMF()) || isSpecialOpaqueType(Ty))
return SpirvType;
if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
ExtTy && isTypedPointerWrapper(ExtTy))
add(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0), SpirvType);
else if (!isPointerTy(Ty))
- add(Ty, SpirvType);
+ add(Ty, ExplicitLayoutRequired, SpirvType);
else if (isTypedPointerTy(Ty))
add(cast<TypedPointerType>(Ty)->getElementType(),
getPointerAddressSpace(Ty), SpirvType);
@@ -1183,14 +1236,15 @@ SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg,
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
- SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
+ SPIRV::AccessQualifier::AccessQualifier AccessQual,
+ bool ExplicitLayoutRequired, bool EmitIR) {
const MachineFunction *MF = &MIRBuilder.getMF();
Register Reg;
if (auto *ExtTy = dyn_cast<TargetExtType>(Ty);
ExtTy && isTypedPointerWrapper(ExtTy))
Reg = find(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0), MF);
else if (!isPointerTy(Ty))
- Reg = find(Ty = adjustIntTypeByWidth(Ty), MF);
+ Reg = find(Ty = adjustIntTypeByWidth(Ty), ExplicitLayoutRequired, MF);
else if (isTypedPointerTy(Ty))
Reg = find(cast<TypedPointerType>(Ty)->getElementType(),
getPointerAddressSpace(Ty), MF);
@@ -1201,15 +1255,20 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
return getSPIRVTypeForVReg(Reg);
TypesInProcessing.clear();
- SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
+ SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual,
+ ExplicitLayoutRequired, EmitIR);
// Create normal pointer types for the corresponding OpTypeForwardPointers.
for (auto &CU : ForwardPointerTypes) {
+ // Pointer type themselves do not require an explicit layout. The types
+ // they pointer to might, but that is taken care of when creating the type.
+ bool PtrNeedsLayout = false;
const Type *Ty2 = CU.first;
SPIRVType *STy2 = CU.second;
- if ((Reg = find(Ty2, MF)).isValid())
+ if ((Reg = find(Ty2, PtrNeedsLayout, MF)).isValid())
STy2 = getSPIRVTypeForVReg(Reg);
else
- STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
+ STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, PtrNeedsLayout,
+ EmitIR);
if (Ty == Ty2)
STy = STy2;
}
@@ -1238,6 +1297,19 @@ bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
return false;
}
+// Returns true if Type is an image, sampler or sampled-image type, or a struct
+// that carries the Block decoration (a buffer block).
+bool SPIRVGlobalRegistry::isResourceType(SPIRVType *Type) const {
+  switch (Type->getOpcode()) {
+  case SPIRV::OpTypeImage:
+  case SPIRV::OpTypeSampler:
+  case SPIRV::OpTypeSampledImage:
+    return true;
+  case SPIRV::OpTypeStruct:
+    return hasBlockDecoration(Type);
+  default:
+    return false;
+  }
+}
unsigned
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg));
@@ -1362,16 +1434,16 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateVulkanBufferType(
if (const MachineInstr *MI = findMI(Key, &MIRBuilder.getMF()))
return MI;
- // TODO(134119): The SPIRVType for `ElemType` will not have an explicit
- // layout. This generates invalid SPIR-V.
+ bool ExplicitLayoutRequired = storageClassRequiresExplictLayout(SC);
+ // We need to get the SPIR-V type for the element here, so we can add the
+ // decoration to it.
auto *T = StructType::create(ElemType);
auto *BlockType =
- getOrCreateSPIRVType(T, MIRBuilder, SPIRV::AccessQualifier::None, EmitIr);
+ getOrCreateSPIRVType(T, MIRBuilder, SPIRV::AccessQualifier::None,
+ ExplicitLayoutRequired, EmitIr);
buildOpDecorate(BlockType->defs().begin()->getReg(), MIRBuilder,
SPIRV::Decoration::Block, {});
- buildOpMemberDecorate(BlockType->defs().begin()->getReg(), MIRBuilder,
- SPIRV::Decoration::Offset, 0, {0});
if (!IsWritable) {
buildOpMemberDecorate(BlockType->defs().begin()->getReg(), MIRBuilder,
@@ -1480,7 +1552,8 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
uint32_t Use, bool EmitIR) {
- if (const MachineInstr *MI = findMI(ExtensionType, &MIRBuilder.getMF()))
+ if (const MachineInstr *MI =
+ findMI(ExtensionType, false, &MIRBuilder.getMF()))
return MI;
const MachineInstr *NewMI =
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
@@ -1493,26 +1566,26 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
.addUse(buildConstantInt(Columns, MIRBuilder, SpvTypeInt32, EmitIR))
.addUse(buildConstantInt(Use, MIRBuilder, SpvTypeInt32, EmitIR));
});
- add(ExtensionType, NewMI);
+ add(ExtensionType, false, NewMI);
return NewMI;
}
+// Returns the type registered for Ty, or creates a bare `Opcode` type
+// instruction (result id only) and registers it for Ty without an explicit
+// layout.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
- if (const MachineInstr *MI = findMI(Ty, &MIRBuilder.getMF()))
+ if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF()))
return MI;
const MachineInstr *NewMI =
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(Opcode).addDef(createTypeVReg(MIRBuilder));
});
- add(Ty, NewMI);
+ add(Ty, false, NewMI);
return NewMI;
}
+// Returns the type registered for Ty, or builds a new `Opcode` type
+// instruction with a fresh result register (presumably with Operands appended)
+// and registers it for Ty without an explicit layout.
SPIRVType *SPIRVGlobalRegistry::getOrCreateUnknownType(
const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode,
const ArrayRef<MCOperand> Operands) {
- if (const MachineInstr *MI = findMI(Ty, &MIRBuilder.getMF()))
+ if (const MachineInstr *MI = findMI(Ty, false, &MIRBuilder.getMF()))
return MI;
Register ResVReg = createTypeVReg(MIRBuilder);
const MachineInstr *NewMI =
@@ -1529,7 +1602,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateUnknownType(
}
return MIB;
});
- add(Ty, NewMI);
+ add(Ty, false, NewMI);
return NewMI;
}
@@ -1545,7 +1618,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
if (hasBuiltinTypePrefix(TypeStr))
return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType(
TypeStr.str(), MIRBuilder.getContext()),
- MIRBuilder, AQ, true);
+ MIRBuilder, AQ, false, true);
// Parse type name in either "typeN" or "type vector[N]" format, where
// N is the number of elements of the vector.
@@ -1556,7 +1629,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
// Unable to recognize SPIRV type name
return nullptr;
- auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ, true);
+ auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ, false, true);
// Handle "type*" or "type* vector[N]".
if (TypeStr.starts_with("*")) {
@@ -1585,7 +1658,7 @@ SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
MachineIRBuilder &MIRBuilder) {
return getOrCreateSPIRVType(
IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
- MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
+ MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, true);
}
SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
@@ -1601,7 +1674,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
const SPIRVInstrInfo &TII,
unsigned SPIRVOPcode,
Type *Ty) {
- if (const MachineInstr *MI = findMI(Ty, CurMF))
+ if (const MachineInstr *MI = findMI(Ty, false, CurMF))
return MI;
MachineBasicBlock &DepMBB = I.getMF()->front();
MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
@@ -1613,7 +1686,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
.addImm(BitWidth)
.addImm(0);
});
- add(Ty, NewMI);
+ add(Ty, false, NewMI);
return finishCreatingSPIRVType(Ty, NewMI);
}
@@ -1654,14 +1727,14 @@ SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder,
bool EmitIR) {
return getOrCreateSPIRVType(
IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
- MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR);
+ MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, EmitIR);
}
+// Returns the OpTypeBool type registered for the LLVM i1 type. If none is
+// registered yet, it is created with a builder placed in the function's entry
+// block and registered without an explicit layout.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
const SPIRVInstrInfo &TII) {
Type *Ty = IntegerType::get(CurMF->getFunction().getContext(), 1);
- if (const MachineInstr *MI = findMI(Ty, CurMF))
+ if (const MachineInstr *MI = findMI(Ty, false, CurMF))
return MI;
MachineBasicBlock &DepMBB = I.getMF()->front();
MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
@@ -1671,7 +1744,7 @@ SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
MIRBuilder.getDL(), TII.get(SPIRV::OpTypeBool))
.addDef(createTypeVReg(CurMF->getRegInfo()));
});
- add(Ty, NewMI);
+ add(Ty, false, NewMI);
return finishCreatingSPIRVType(Ty, NewMI);
}
@@ -1681,7 +1754,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
return getOrCreateSPIRVType(
FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
NumElements),
- MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR);
+ MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false, EmitIR);
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
@@ -1689,7 +1762,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
const SPIRVInstrInfo &TII) {
Type *Ty = FixedVectorType::get(
const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
- if (const MachineInstr *MI = findMI(Ty, CurMF))
+ if (const MachineInstr *MI = findMI(Ty, false, CurMF))
return MI;
MachineInstr *DepMI = const_cast<MachineInstr *>(BaseType);
MachineIRBuilder MIRBuilder(*DepMI->getParent(), DepMI->getIterator());
@@ -1701,30 +1774,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
.addUse(getSPIRVTypeID(BaseType))
.addImm(NumElements);
});
- add(Ty, NewMI);
- return finishCreatingSPIRVType(Ty, NewMI);
-}
-
-SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
- SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
- const SPIRVInstrInfo &TII) {
- Type *Ty = ArrayType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
- NumElements);
- if (const MachineInstr *MI = findMI(Ty, CurMF))
- return MI;
- SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, I, TII);
- Register Len = getOrCreateConstInt(NumElements, I, SpvTypeInt32, TII);
- MachineBasicBlock &DepMBB = I.getMF()->front();
- MachineIRBuilder MIRBuilder(DepMBB, getInsertPtValidEnd(&DepMBB));
- const MachineInstr *NewMI =
- createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
- return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
- MIRBuilder.getDL(), TII.get(SPIRV::OpTypeArray))
- .addDef(createTypeVReg(CurMF->getRegInfo()))
- .addUse(getSPIRVTypeID(BaseType))
- .addUse(Len);
- });
- add(Ty, NewMI);
+ add(Ty, false, NewMI);
return finishCreatingSPIRVType(Ty, NewMI);
}
@@ -1738,8 +1788,11 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
+// Returns an OpTypePointer in storage class SC whose pointee is the SPIR-V
+// translation of BaseType. The pointee is translated with an explicit layout
+// when storageClassRequiresExplictLayout(SC) is true.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
const Type *BaseType, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC) {
+ // TODO: Need to check if EmitIr should always be true.
SPIRVType *SpirvBaseType = getOrCreateSPIRVType(
- BaseType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
+ BaseType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
+ storageClassRequiresExplictLayout(SC), true);
+ assert(SpirvBaseType);
return getOrCreateSPIRVPointerTypeInternal(SpirvBaseType, MIRBuilder, SC);
}
@@ -2006,3 +2059,33 @@ void SPIRVGlobalRegistry::updateAssignType(CallInst *AssignCI, Value *Arg,
addDeducedElementType(AssignCI, ElemTy);
addDeducedElementType(Arg, ElemTy);
}
+
+// Decorates every member of the struct type Reg (the SPIR-V translation of Ty)
+// with an Offset equal to that member's byte offset in Ty's struct layout.
+// Note: the layout comes from a default-constructed DataLayout, not the
+// module's.
+void SPIRVGlobalRegistry::addStructOffsetDecorations(
+    Register Reg, StructType *Ty, MachineIRBuilder &MIRBuilder) {
+  // The StructLayout, and therefore the ArrayRef returned by
+  // getMemberOffsets(), is owned by the DataLayout. Keep the DataLayout alive
+  // while Offsets is in use; calling getStructLayout on a temporary leaves
+  // Offsets dangling once the full-expression ends.
+  DataLayout DL;
+  ArrayRef<TypeSize> Offsets = DL.getStructLayout(Ty)->getMemberOffsets();
+  for (uint32_t I = 0; I < Ty->getNumElements(); ++I) {
+    buildOpMemberDecorate(Reg, MIRBuilder, SPIRV::Decoration::Offset, I,
+                          {static_cast<uint32_t>(Offsets[I])});
+  }
+}
+
+// Decorates the array type Reg with an ArrayStride equal to the size in bytes
+// of ElementType, as reported by a default-constructed DataLayout.
+// NOTE(review): getTypeSizeInBits excludes the tail padding that
+// getTypeAllocSize would include (e.g. for 3-element vectors); confirm that
+// this is the intended stride.
+void SPIRVGlobalRegistry::addArrayStrideDecorations(
+    Register Reg, Type *ElementType, MachineIRBuilder &MIRBuilder) {
+  DataLayout DL;
+  uint32_t StrideInBytes = DL.getTypeSizeInBits(ElementType) / 8;
+  buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::ArrayStride,
+                  {StrideInBytes});
+}
+
+// Returns true if some OpDecorate that uses Type's result register applies the
+// Block decoration (the decoration kind is read from operand 1).
+bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
+  const Register TypeReg = getSPIRVTypeID(Type);
+  const MachineRegisterInfo &MRI = Type->getMF()->getRegInfo();
+  for (const MachineInstr &User : MRI.use_instructions(TypeReg)) {
+    const bool IsBlockDecoration =
+        User.getOpcode() == SPIRV::OpDecorate &&
+        User.getOperand(1).getImm() == SPIRV::Decoration::Block;
+    if (IsBlockDecoration)
+      return true;
+  }
+  return false;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index b05896fb7174c..7338e805956d6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -90,14 +90,14 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// Add a new OpTypeXXX instruction without checking for duplicates.
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AQ,
- bool EmitIR);
+ bool ExplicitLayoutRequired, bool EmitIR);
SPIRVType *findSPIRVType(const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier accessQual,
- bool EmitIR);
+ bool ExplicitLayoutRequired, bool EmitIR);
SPIRVType *
restOfCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual,
- bool EmitIR);
+ bool ExplicitLayoutRequired, bool EmitIR);
// Internal function creating the an OpType at the correct position in the
// function by tweaking the passed "MIRBuilder" insertion point and restoring
@@ -298,10 +298,19 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// EmitIR controls if we emit GMIR or SPV constants (e.g. for array sizes)
// because this method may be called from InstructionSelector and we don't
// want to emit extra IR instructions there.
+  SPIRVType *getOrCreateSPIRVType(const Type *Type, MachineInstr &I,
+                                  SPIRV::AccessQualifier::AccessQualifier AQ,
+                                  bool EmitIR) {
+    // Build at I's position, then defer to the MachineIRBuilder overload,
+    // which does not request an explicit layout.
+    MachineIRBuilder Builder(I);
+    return getOrCreateSPIRVType(Type, Builder, AQ, EmitIR);
+  }
+
+ // Same as the overload that takes ExplicitLayoutRequired, with that flag set
+ // to false.
SPIRVType *getOrCreateSPIRVType(const Type *Type,
MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AQ,
- bool EmitIR);
+ bool EmitIR) {
+ return getOrCreateSPIRVType(Type, MIRBuilder, AQ, false, EmitIR);
+ }
const Type *getTypeForSPIRVType(const SPIRVType *Ty) const {
auto Res = SPIRVToLLVMType.find(Ty);
@@ -364,6 +373,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
+ // Returns true if `Type` is a resource type. This could be an image type
+ // or a struct for a buffer decorated with the block decoration.
+ bool isResourceType(SPIRVType *Type) const;
+
// Return number of elements in a vector if the argument is associated with
// a vector type. Return 1 for a scalar type, and 0 for a missing type.
unsigned getScalarOrVectorComponentCount(Register VReg) const;
@@ -414,6 +427,11 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
const Type *adjustIntTypeByWidth(const Type *Ty) const;
unsigned adjustOpTypeIntWidth(unsigned Width) const;
+ SPIRVType *getOrCreateSPIRVType(const Type *Type,
+ MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier::AccessQualifier AQ,
+ bool ExplicitLayoutRequired, bool EmitIR);
+
SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);
@@ -425,14 +443,15 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType,
- MachineIRBuilder &MIRBuilder, bool EmitIR);
+ MachineIRBuilder &MIRBuilder,
+ bool ExplicitLayoutRequired, bool EmitIR);
SPIRVType *getOpTypeOpaque(const StructType *Ty,
MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeStruct(const StructType *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual,
- bool EmitIR);
+ bool ExplicitLayoutRequired, bool EmitIR);
SPIRVType *getOpTypePointer(SPIRV::StorageClass::StorageClass SC,
SPIRVType *ElemType, MachineIRBuilder &MIRBuilder,
@@ -475,6 +494,12 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC);
+ void addStructOffsetDecorations(Register Reg, StructType *Ty,
+ MachineIRBuilder &MIRBuilder);
+ void addArrayStrideDecorations(Register Reg, Type *ElementType,
+ MachineIRBuilder &MIRBuilder);
+ bool hasBlockDecoration(SPIRVType *Type) const;
+
public:
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
@@ -545,9 +570,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType,
unsigned NumElements, MachineInstr &I,
const SPIRVInstrInfo &TII);
- SPIRVType *getOrCreateSPIRVArrayType(SPIRVType *BaseType,
- unsigned NumElements, MachineInstr &I,
- const SPIRVInstrInfo &TII);
// Returns a pointer to a SPIR-V pointer type with the given base type and
// storage class. The base type will be translated to a SPIR-V type, and the
diff --git a/llvm/lib/Target/SPIRV/SPIRVIRMapping.h b/llvm/lib/Target/SPIRV/SPIRVIRMapping.h
index 9c9c099bc5fc4..a329fd5ed9d29 100644
--- a/llvm/lib/Target/SPIRV/SPIRVIRMapping.h
+++ b/llvm/lib/Target/SPIRV/SPIRVIRMapping.h
@@ -66,6 +66,7 @@ enum SpecialTypeKind {
STK_Value,
STK_MachineInstr,
STK_VkBuffer,
+ STK_ExplictLayoutType,
STK_Last = -1
};
@@ -150,6 +151,11 @@ inline IRHandle irhandle_vkbuffer(const Type *ElementType,
SpecialTypeKind::STK_VkBuffer);
}
+// Key for the explicitly laid-out variant of Ty. It uses the same pointer and
+// type id as handle(const Type *) but is tagged STK_ExplictLayoutType instead
+// of STK_Type, so the two variants get distinct keys.
+inline IRHandle irhandle_explict_layout_type(const Type *Ty) {
+  return irhandle_ptr(unifyPtrType(Ty), Ty->getTypeID(),
+                      STK_ExplictLayoutType);
+}
+
inline IRHandle handle(const Type *Ty) {
const Type *WrpTy = unifyPtrType(Ty);
return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_Type);
@@ -163,6 +169,10 @@ inline IRHandle handle(const MachineInstr *KeyMI) {
return irhandle_ptr(KeyMI, SPIRV::to_hash(KeyMI), STK_MachineInstr);
}
+// Returns true for struct and array types: the types whose explicitly laid-out
+// variants SPIRVIRMapping registers under a separate key.
+inline bool type_has_layout_decoration(const Type *T) {
+  return isa<StructType, ArrayType>(T);
+}
+
} // namespace SPIRV
// Bi-directional mappings between LLVM entities and (v-reg, machine function)
@@ -238,14 +248,49 @@ class SPIRVIRMapping {
return findMI(SPIRV::irhandle_pointee(PointeeTy, AddressSpace), MF);
}
-  template <typename T> bool add(const T *Obj, const MachineInstr *MI) {
+  // Registration. Type keys take an extra RequiresExplicitLayout flag: when it
+  // is set and the type is a struct or array, the entry is filed under the
+  // explicit-layout handle rather than the plain one, so the decorated and
+  // undecorated SPIR-V variants of one LLVM type are tracked separately.
+  bool add(const Value *V, const MachineInstr *MI) {
+    return add(SPIRV::handle(V), MI);
+  }
+
+  bool add(const Type *T, bool RequiresExplicitLayout, const MachineInstr *MI) {
+    const bool UseLayoutKey =
+        RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T);
+    return add(UseLayoutKey ? SPIRV::irhandle_explict_layout_type(T)
+                            : SPIRV::handle(T),
+               MI);
+  }
+
+  bool add(const MachineInstr *Obj, const MachineInstr *MI) {
     return add(SPIRV::handle(Obj), MI);
   }
-  template <typename T> Register find(const T *Obj, const MachineFunction *MF) {
-    return find(SPIRV::handle(Obj), MF);
+
+  // Lookup. The Type overloads pick the key with the same rule as add(), so a
+  // type must be looked up with the RequiresExplicitLayout value it was
+  // registered with.
+  Register find(const Value *V, const MachineFunction *MF) {
+    return find(SPIRV::handle(V), MF);
+  }
+
+  Register find(const Type *T, bool RequiresExplicitLayout,
+                const MachineFunction *MF) {
+    const bool UseLayoutKey =
+        RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T);
+    return find(UseLayoutKey ? SPIRV::irhandle_explict_layout_type(T)
+                             : SPIRV::handle(T),
+                MF);
+  }
+
+  Register find(const MachineInstr *MI, const MachineFunction *MF) {
+    return find(SPIRV::handle(MI), MF);
+  }
+
+  const MachineInstr *findMI(const Value *Obj, const MachineFunction *MF) {
+    return findMI(SPIRV::handle(Obj), MF);
+  }
+
+  const MachineInstr *findMI(const Type *T, bool RequiresExplicitLayout,
+                             const MachineFunction *MF) {
+    const bool UseLayoutKey =
+        RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T);
+    return findMI(UseLayoutKey ? SPIRV::irhandle_explict_layout_type(T)
+                               : SPIRV::handle(T),
+                  MF);
   }
-  template <typename T>
-  const MachineInstr *findMI(const T *Obj, const MachineFunction *MF) {
+
+  const MachineInstr *findMI(const MachineInstr *Obj,
+                             const MachineFunction *MF) {
     return findMI(SPIRV::handle(Obj), MF);
   }
};
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 216c3e26be1bf..8a873426e78d8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -25,6 +25,42 @@
using namespace llvm;
+// Returns true if the types logically match, as defined in
+// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
+//
+// Two OpTypeArray types match when they use the same Length operand and their
+// element types match. Two OpTypeStruct types match when they have the same
+// number of members and each pair of corresponding members matches. Members
+// and elements also match when they are the very same type. Any other pair of
+// distinct types does not match, so callers are expected to test for identity
+// before calling this.
+static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2,
+                                SPIRVGlobalRegistry &GR) {
+  // A register without a known SPIR-V type cannot be shown to match anything.
+  if (!Ty1 || !Ty2)
+    return false;
+
+  if (Ty1->getOpcode() != Ty2->getOpcode())
+    return false;
+
+  if (Ty1->getNumOperands() != Ty2->getNumOperands())
+    return false;
+
+  // Compares the types referenced by operand OpIdx of Ty1 and Ty2.
+  auto OperandTypesMatch = [&](unsigned OpIdx) {
+    SPIRVType *ElemType1 =
+        GR.getSPIRVTypeForVReg(Ty1->getOperand(OpIdx).getReg());
+    SPIRVType *ElemType2 =
+        GR.getSPIRVTypeForVReg(Ty2->getOperand(OpIdx).getReg());
+    return ElemType1 == ElemType2 ||
+           typesLogicallyMatch(ElemType1, ElemType2, GR);
+  };
+
+  if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
+    // Operand 1 is the element type and operand 2 the Length. Lengths are
+    // compared by register, so both arrays must use the same constant.
+    if (Ty1->getOperand(2).getReg() != Ty2->getOperand(2).getReg())
+      return false;
+    return OperandTypesMatch(1);
+  }
+
+  if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
+    // Operand 0 is the struct type's own result register; operands 1..N-1
+    // are the member types.
+    for (unsigned I = 1, E = Ty1->getNumOperands(); I < E; ++I)
+      if (!OperandTypesMatch(I))
+        return false;
+    return true;
+  }
+
+  return false;
+}
+
unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
// This code avoids CallLowering fail inside getVectorTypeBreakdown
@@ -374,6 +410,9 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
// implies that %Op is a pointer to <ResType>
case SPIRV::OpLoad:
// OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
+ if (enforcePtrTypeCompatibility(MI, 2, 0))
+ break;
+
validatePtrTypes(STI, MRI, GR, MI, 2,
GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
break;
@@ -531,3 +570,58 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
ProcessedMF.insert(&MF);
TargetLowering::finalizeLowering(MF);
}
+
+// Modifies either operand PtrOpIdx or OpIdx of I so that the pointee type of
+// operand PtrOpIdx matches the type of operand OpIdx.
+//
+// Returns true if the types already match, or if they logically match (see
+// typesLogicallyMatch) and I was rewritten so that they do. Returns false when
+// the types cannot be reconciled here, including when only one of the two
+// types is known. The caller then falls back to its regular pointer-type
+// validation.
+//
+// So far only a def at OpIdx is handled, by appending an OpCopyLogical after
+// I. Any other logically matching case aborts with an internal error.
+bool SPIRVTargetLowering::enforcePtrTypeCompatibility(
+    MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
+  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
+  SPIRVType *PtrType = GR.getResultType(I.getOperand(PtrOpIdx).getReg());
+  SPIRVType *PointeeType = GR.getPointeeType(PtrType);
+  SPIRVType *OpType = GR.getResultType(I.getOperand(OpIdx).getReg());
+
+  if (PointeeType == OpType)
+    return true;
+
+  // An unknown type on either side cannot be reconciled here.
+  if (!PointeeType || !OpType)
+    return false;
+
+  if (!typesLogicallyMatch(PointeeType, OpType, GR))
+    return false;
+
+  // Apply OpCopyLogical to OpIdx.
+  if (I.getOperand(OpIdx).isDef() && insertLogicalCopyOnResult(I, PointeeType))
+    return true;
+
+  // This is a known, reachable gap, not an impossible state. llvm_unreachable
+  // may become an optimizer hint (undefined behavior) in release builds, while
+  // reportFatalInternalError always produces a diagnostic.
+  reportFatalInternalError("Unable to add OpCopyLogical yet.");
+}
+
+// Changes the single result of I to a fresh register of type NewResultType,
+// then inserts
+//   %old_result = OpCopyLogical %old_type %new_result
+// directly after I. Existing users of the old result register therefore keep
+// seeing a value of the old type. Returns whether the operands of the new
+// OpCopyLogical could be constrained.
+bool SPIRVTargetLowering::insertLogicalCopyOnResult(
+    MachineInstr &I, SPIRVType *NewResultType) const {
+  MachineFunction &MF = *I.getMF();
+  MachineRegisterInfo *MRI = &MF.getRegInfo();
+  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
+
+  Register NewResultReg = createVirtualRegister(NewResultType, &GR, MRI, MF);
+  Register NewTypeReg = GR.getSPIRVTypeID(NewResultType);
+
+  assert(std::distance(I.defs().begin(), I.defs().end()) == 1 &&
+         "Expected only one def");
+  MachineOperand &OldResult = *I.defs().begin();
+  Register OldResultReg = OldResult.getReg();
+  // The result type is expected to be the first use operand, as in
+  // `OpLoad <ResType>, ptr %Op`.
+  MachineOperand &OldType = *I.uses().begin();
+  Register OldTypeReg = OldType.getReg();
+
+  OldResult.setReg(NewResultReg);
+  OldType.setReg(NewTypeReg);
+
+  // Insert through an iterator rather than via *I.getNextNode(), which would
+  // dereference null if I were the last instruction of its block.
+  MachineIRBuilder MIB(*I.getParent(),
+                       std::next(MachineBasicBlock::iterator(I)));
+  MIB.setDebugLoc(I.getDebugLoc());
+  return MIB.buildInstr(SPIRV::OpCopyLogical)
+      .addDef(OldResultReg)
+      .addUse(OldTypeReg)
+      .addUse(NewResultReg)
+      .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
+                        *STI.getRegBankInfo());
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
index eb78299b72f04..9025e6eb0842e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
@@ -71,6 +71,11 @@ class SPIRVTargetLowering : public TargetLowering {
EVT ConditionVT) const override {
return ConditionVT.getSimpleVT();
}
+
+ // Reconciles the pointee type of operand PtrOpIdx of I with the type of
+ // operand OpIdx. Returns true if they already match or if I was rewritten so
+ // that they do, and false if they cannot be reconciled.
+ bool enforcePtrTypeCompatibility(MachineInstr &I, unsigned PtrOpIdx,
+ unsigned OpIdx) const;
+ // Gives the single def of I a fresh register of type NewResultType and
+ // appends an OpCopyLogical that writes the original result register.
+ // Returns whether the new instruction's operands could be constrained.
+ bool insertLogicalCopyOnResult(MachineInstr &I,
+ SPIRVType *NewResultType) const;
};
} // namespace llvm
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll
index fc8faa7300534..f539fdefa3fa2 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll
@@ -11,17 +11,18 @@ declare target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handle
; CHECK: OpDecorate [[BufferVar:%.+]] DescriptorSet 0
; CHECK: OpDecorate [[BufferVar]] Binding 0
-; CHECK: OpDecorate [[BufferType:%.+]] Block
-; CHECK: OpMemberDecorate [[BufferType]] 0 Offset 0
+; CHECK: OpMemberDecorate [[BufferType:%.+]] 0 Offset 0
+; CHECK: OpDecorate [[BufferType]] Block
; CHECK: OpMemberDecorate [[BufferType]] 0 NonWritable
; CHECK: OpDecorate [[RWBufferVar:%.+]] DescriptorSet 0
; CHECK: OpDecorate [[RWBufferVar]] Binding 1
-; CHECK: OpDecorate [[RWBufferType:%.+]] Block
-; CHECK: OpMemberDecorate [[RWBufferType]] 0 Offset 0
+; CHECK: OpDecorate [[ArrayType:%.+]] ArrayStride 4
+; CHECK: OpMemberDecorate [[RWBufferType:%.+]] 0 Offset 0
+; CHECK: OpDecorate [[RWBufferType]] Block
; CHECK: [[int:%[0-9]+]] = OpTypeInt 32 0
-; CHECK: [[ArrayType:%.+]] = OpTypeRuntimeArray
+; CHECK: [[ArrayType]] = OpTypeRuntimeArray
; CHECK: [[RWBufferType]] = OpTypeStruct [[ArrayType]]
; CHECK: [[RWBufferPtrType:%.+]] = OpTypePointer StorageBuffer [[RWBufferType]]
; CHECK: [[BufferType]] = OpTypeStruct [[ArrayType]]
diff --git a/llvm/test/CodeGen/SPIRV/spirv-explicit-layout.ll b/llvm/test/CodeGen/SPIRV/spirv-explicit-layout.ll
new file mode 100644
index 0000000000000..7303471c9929c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/spirv-explicit-layout.ll
@@ -0,0 +1,149 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.6-vulkan1.3-library %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-vulkan1.3-library %s -o - -filetype=obj | spirv-val %}
+
+; Types used in storage classes that need an explicit layout (StorageBuffer
+; here) are emitted as a separate copy carrying Offset / ArrayStride
+; decorations. The same LLVM type used elsewhere (Private here) stays
+; undecorated. Loading an aggregate through such a pointer yields the
+; decorated copy, which is converted back with OpCopyLogical.
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
+
+; CHECK-DAG: OpName [[ScalarBlock_var:%[0-9]+]] "__resource_p_12_{_u32[0]}_0_0"
+; CHECK-DAG: OpName [[buffer_var:%[0-9]+]] "__resource_p_12_{_{_{_u32_f32[3]}[10]}[0]}_0_0"
+; CHECK-DAG: OpName [[array_buffer_var:%[0-9]+]] "__resource_p_12_{_{_{_u32_f32[3]}[10]}[0]}[10]_0_0"
+
+; CHECK-DAG: OpMemberDecorate [[ScalarBlock:%[0-9]+]] 0 Offset 0
+; CHECK-DAG: OpDecorate [[ScalarBlock]] Block
+; CHECK-DAG: OpMemberDecorate [[ScalarBlock]] 0 NonWritable
+; CHECK-DAG: OpMemberDecorate [[T_explicit:%[0-9]+]] 0 Offset 0
+; CHECK-DAG: OpMemberDecorate [[T_explicit]] 1 Offset 16
+; CHECK-DAG: OpDecorate [[T_array_explicit:%[0-9]+]] ArrayStride 32
+; CHECK-DAG: OpMemberDecorate [[S_explicit:%[0-9]+]] 0 Offset 0
+; CHECK-DAG: OpDecorate [[S_array_explicit:%[0-9]+]] ArrayStride 320
+; CHECK-DAG: OpMemberDecorate [[block:%[0-9]+]] 0 Offset 0
+; CHECK-DAG: OpDecorate [[block]] Block
+; CHECK-DAG: OpMemberDecorate [[block]] 0 NonWritable
+
+; %struct.T is emitted twice: once undecorated, and once as an explicitly laid
+; out copy in which member 1 (the <3 x float>) has Offset 16. Arrays of the
+; explicit copy use ArrayStride 32.
+; CHECK-DAG: [[float:%[0-9]+]] = OpTypeFloat 32
+; CHECK-DAG: [[v3f:%[0-9]+]] = OpTypeVector [[float]] 3
+; CHECK-DAG: [[uint:%[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: [[T:%[0-9]+]] = OpTypeStruct [[uint]] [[v3f]]
+; CHECK-DAG: [[T_explicit]] = OpTypeStruct [[uint]] [[v3f]]
+%struct.T = type { i32, <3 x float> }
+
+; %struct.S likewise gets an explicit copy whose member is an array of the
+; explicit %struct.T. A runtime array of that copy uses ArrayStride 320.
+; CHECK-DAG: [[zero:%[0-9]+]] = OpConstant [[uint]] 0{{$}}
+; CHECK-DAG: [[one:%[0-9]+]] = OpConstant [[uint]] 1{{$}}
+; CHECK-DAG: [[ten:%[0-9]+]] = OpConstant [[uint]] 10
+; CHECK-DAG: [[T_array:%[0-9]+]] = OpTypeArray [[T]] [[ten]]
+; CHECK-DAG: [[S:%[0-9]+]] = OpTypeStruct [[T_array]]
+; CHECK-DAG: [[T_array_explicit]] = OpTypeArray [[T_explicit]] [[ten]]
+; CHECK-DAG: [[S_explicit]] = OpTypeStruct [[T_array_explicit]]
+%struct.S = type { [10 x %struct.T] }
+
+; The same %struct.S is emitted undecorated for a Private variable, and as the
+; explicitly laid-out copy for a StorageBuffer variable.
+; CHECK-DAG: [[private_S_ptr:%[0-9]+]] = OpTypePointer Private [[S]]
+; CHECK-DAG: [[private_var:%[0-9]+]] = OpVariable [[private_S_ptr]] Private
+@private = internal addrspace(10) global %struct.S poison
+
+; CHECK-DAG: [[storagebuffer_S_ptr:%[0-9]+]] = OpTypePointer StorageBuffer [[S_explicit]]
+; CHECK-DAG: [[storage_buffer:%[0-9]+]] = OpVariable [[storagebuffer_S_ptr]] StorageBuffer
+@storage_buffer = internal addrspace(11) global %struct.S poison
+
+; Variables backing the spirv.VulkanBuffer handles. Each buffer is a
+; Block-decorated struct wrapping a runtime array, and a %struct.S element
+; inside a buffer uses the explicitly laid-out copy.
+; CHECK-DAG: [[storagebuffer_int_ptr:%[0-9]+]] = OpTypePointer StorageBuffer [[uint]]
+; CHECK-DAG: [[ScalarBlock_array:%[0-9]+]] = OpTypeRuntimeArray [[uint]]
+; CHECK-DAG: [[ScalarBlock]] = OpTypeStruct [[ScalarBlock_array]]
+; CHECK-DAG: [[ScalarBlock_ptr:%[0-9]+]] = OpTypePointer StorageBuffer [[ScalarBlock]]
+; CHECK-DAG: [[ScalarBlock_var]] = OpVariable [[ScalarBlock_ptr]] StorageBuffer
+
+
+; CHECK-DAG: [[S_array_explicit]] = OpTypeRuntimeArray [[S_explicit]]
+; CHECK-DAG: [[block]] = OpTypeStruct [[S_array_explicit]]
+; CHECK-DAG: [[buffer_ptr:%[0-9]+]] = OpTypePointer StorageBuffer [[block]]
+; CHECK-DAG: [[buffer_var]] = OpVariable [[buffer_ptr]] StorageBuffer
+
+; CHECK-DAG: [[array_buffer:%[0-9]+]] = OpTypeArray [[block]] [[ten]]
+; CHECK-DAG: [[array_buffer_ptr:%[0-9]+]] = OpTypePointer StorageBuffer [[array_buffer]]
+; CHECK-DAG: [[array_buffer_var]] = OpVariable [[array_buffer_ptr]] StorageBuffer
+
+; A scalar uint loaded from a buffer has no layout variant, so the load result
+; is returned directly without an OpCopyLogical.
+; CHECK: OpFunction [[uint]] None
+define external i32 @scalar_vulkan_buffer_load() {
+; CHECK-NEXT: OpLabel
+entry:
+; CHECK-NEXT: [[handle:%[0-9]+]] = OpCopyObject [[ScalarBlock_ptr]] [[ScalarBlock_var]]
+ %handle = tail call target("spirv.VulkanBuffer", [0 x i32], 12, 0) @llvm.spv.resource.handlefrombinding(i32 0, i32 0, i32 1, i32 0, i1 false)
+
+; CHECK-NEXT: [[ptr:%[0-9]+]] = OpAccessChain [[storagebuffer_int_ptr]] [[handle]] [[zero]] [[one]]
+ %0 = tail call noundef nonnull align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer(target("spirv.VulkanBuffer", [0 x i32], 12, 0) %handle, i32 1)
+
+; CHECK-NEXT: [[ld:%[0-9]+]] = OpLoad [[uint]] [[ptr]] Aligned 4
+ %1 = load i32, ptr addrspace(11) %0, align 4
+
+; CHECK-NEXT: OpReturnValue [[ld]]
+ ret i32 %1
+
+; CHECK-NEXT: OpFunctionEnd
+}
+
+; Private storage uses the undecorated %struct.S, so the loaded value already
+; has the function's return type and is returned directly.
+; CHECK: OpFunction [[S]] None
+define external %struct.S @private_load() {
+; CHECK-NEXT: OpLabel
+entry:
+
+; CHECK-NEXT: [[ld:%[0-9]+]] = OpLoad [[S]] [[private_var]] Aligned 4
+ %1 = load %struct.S, ptr addrspace(10) @private, align 4
+
+; CHECK-NEXT: OpReturnValue [[ld]]
+ ret %struct.S %1
+
+; CHECK-NEXT: OpFunctionEnd
+}
+
+; Loading %struct.S from a StorageBuffer global produces the explicitly laid
+; out copy. It is converted back to the undecorated type with OpCopyLogical
+; before being returned.
+; CHECK: OpFunction [[S]] None
+define external %struct.S @storage_buffer_load() {
+; CHECK-NEXT: OpLabel
+entry:
+
+; CHECK-NEXT: [[ld:%[0-9]+]] = OpLoad [[S_explicit]] [[storage_buffer]] Aligned 4
+; CHECK-NEXT: [[copy:%[0-9]+]] = OpCopyLogical [[S]] [[ld]]
+ %1 = load %struct.S, ptr addrspace(11) @storage_buffer, align 4
+
+; CHECK-NEXT: OpReturnValue [[copy]]
+ ret %struct.S %1
+
+; CHECK-NEXT: OpFunctionEnd
+}
+
+; Same load-then-OpCopyLogical pattern, but the pointer comes from an access
+; chain into element 1 of a Vulkan buffer of %struct.S.
+; CHECK: OpFunction [[S]] None
+define external %struct.S @vulkan_buffer_load() {
+; CHECK-NEXT: OpLabel
+entry:
+; CHECK-NEXT: [[handle:%[0-9]+]] = OpCopyObject [[buffer_ptr]] [[buffer_var]]
+ %handle = tail call target("spirv.VulkanBuffer", [0 x %struct.S], 12, 0) @llvm.spv.resource.handlefrombinding(i32 0, i32 0, i32 1, i32 0, i1 false)
+
+; CHECK-NEXT: [[ptr:%[0-9]+]] = OpAccessChain [[storagebuffer_S_ptr]] [[handle]] [[zero]] [[one]]
+ %0 = tail call noundef nonnull align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer(target("spirv.VulkanBuffer", [0 x %struct.S], 12, 0) %handle, i32 1)
+
+; CHECK-NEXT: [[ld:%[0-9]+]] = OpLoad [[S_explicit]] [[ptr]] Aligned 4
+; CHECK-NEXT: [[copy:%[0-9]+]] = OpCopyLogical [[S]] [[ld]]
+ %1 = load %struct.S, ptr addrspace(11) %0, align 4
+
+; CHECK-NEXT: OpReturnValue [[copy]]
+ ret %struct.S %1
+
+; CHECK-NEXT: OpFunctionEnd
+}
+
+; Same pattern for a buffer taken from a binding array of ten buffers (index 1
+; is selected), followed by a load of element 1 and an OpCopyLogical.
+; CHECK: OpFunction [[S]] None
+define external %struct.S @array_of_vulkan_buffers_load() {
+; CHECK-NEXT: OpLabel
+entry:
+; CHECK-NEXT: [[h:%[0-9]+]] = OpAccessChain [[buffer_ptr]] [[array_buffer_var]] [[one]]
+; CHECK-NEXT: [[handle:%[0-9]+]] = OpCopyObject [[buffer_ptr]] [[h]]
+ %handle = tail call target("spirv.VulkanBuffer", [0 x %struct.S], 12, 0) @llvm.spv.resource.handlefrombinding(i32 0, i32 0, i32 10, i32 1, i1 false)
+
+; CHECK-NEXT: [[ptr:%[0-9]+]] = OpAccessChain [[storagebuffer_S_ptr]] [[handle]] [[zero]] [[one]]
+ %0 = tail call noundef nonnull align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer(target("spirv.VulkanBuffer", [0 x %struct.S], 12, 0) %handle, i32 1)
+
+; CHECK-NEXT: [[ld:%[0-9]+]] = OpLoad [[S_explicit]] [[ptr]] Aligned 4
+; CHECK-NEXT: [[copy:%[0-9]+]] = OpCopyLogical [[S]] [[ld]]
+ %1 = load %struct.S, ptr addrspace(11) %0, align 4
+
+; CHECK-NEXT: OpReturnValue [[copy]]
+ ret %struct.S %1
+
+; CHECK-NEXT: OpFunctionEnd
+}
More information about the llvm-commits
mailing list