[llvm] 0daeed6 - [SPIR-V] Improve implementation of the duplicates tracker's storage (#95958)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 19 01:39:18 PDT 2024
Author: Vyacheslav Levytskyy
Date: 2024-06-19T10:39:14+02:00
New Revision: 0daeed645d22704250bc22aec121c467ffc72e22
URL: https://github.com/llvm/llvm-project/commit/0daeed645d22704250bc22aec121c467ffc72e22
DIFF: https://github.com/llvm/llvm-project/commit/0daeed645d22704250bc22aec121c467ffc72e22.diff
LOG: [SPIR-V] Improve implementation of the duplicates tracker's storage (#95958)
This PR continues https://github.com/llvm/llvm-project/pull/94952,
managing FunctionType in the same way as pointee types are managed in
https://github.com/llvm/llvm-project/pull/94952 (that is, working with
TypedPointerType-based pointee types rather than with LLVM's original
untyped pointers).
This PR also fully reworks the base type for the duplicates tracker's
storage to conform with and reuse DenseMapInfo. The previous implementation
didn't store enough information to distinguish between key values: isEqual()
was implemented as equality of hash values derived from the arguments, so two
distinct keys whose hashes collided were treated as the same key. This, in
turn, led to random crashes on very rare occasions, when the hash value of an
actual key matched the hash value of the empty or tombstone instance. In this
PR we use std::tuple instead of a tailor-made class hierarchy, both
reusing DenseMapInfo templates and getting rid of the crash condition.
Added:
Modified:
llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
llvm/lib/Target/SPIRV/SPIRVUtils.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 3c8405fadd44e..a37e65a47eda0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -52,152 +52,87 @@ class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
void addDep(DTSortableEntry *E) { Deps.push_back(E); }
};
-struct SpecialTypeDescriptor {
- enum SpecialTypeKind {
- STK_Empty = 0,
- STK_Image,
- STK_SampledImage,
- STK_Sampler,
- STK_Pipe,
- STK_DeviceEvent,
- STK_Pointer,
- STK_Last = -1
- };
- SpecialTypeKind Kind;
-
- unsigned Hash;
-
- SpecialTypeDescriptor() = delete;
- SpecialTypeDescriptor(SpecialTypeKind K) : Kind(K) { Hash = Kind; }
-
- unsigned getHash() const { return Hash; }
-
- virtual ~SpecialTypeDescriptor() {}
-};
-
-struct ImageTypeDescriptor : public SpecialTypeDescriptor {
- union ImageAttrs {
- struct BitFlags {
- unsigned Dim : 3;
- unsigned Depth : 2;
- unsigned Arrayed : 1;
- unsigned MS : 1;
- unsigned Sampled : 2;
- unsigned ImageFormat : 6;
- unsigned AQ : 2;
- } Flags;
- unsigned Val;
- };
-
- ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth,
- unsigned Arrayed, unsigned MS, unsigned Sampled,
- unsigned ImageFormat, unsigned AQ = 0)
- : SpecialTypeDescriptor(SpecialTypeKind::STK_Image) {
- ImageAttrs Attrs;
- Attrs.Val = 0;
- Attrs.Flags.Dim = Dim;
- Attrs.Flags.Depth = Depth;
- Attrs.Flags.Arrayed = Arrayed;
- Attrs.Flags.MS = MS;
- Attrs.Flags.Sampled = Sampled;
- Attrs.Flags.ImageFormat = ImageFormat;
- Attrs.Flags.AQ = AQ;
- Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^
- ((Attrs.Val << 8) | Kind);
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_Image;
- }
-};
-
-struct SampledImageTypeDescriptor : public SpecialTypeDescriptor {
- SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy)
- : SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) {
- assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
- ImageTypeDescriptor TD(
- SampledTy, ImageTy->getOperand(2).getImm(),
- ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(),
- ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(),
- ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm());
- Hash = TD.getHash() ^ Kind;
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_SampledImage;
- }
-};
-
-struct SamplerTypeDescriptor : public SpecialTypeDescriptor {
- SamplerTypeDescriptor()
- : SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler) {
- Hash = Kind;
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_Sampler;
- }
+enum SpecialTypeKind {
+ STK_Empty = 0,
+ STK_Image,
+ STK_SampledImage,
+ STK_Sampler,
+ STK_Pipe,
+ STK_DeviceEvent,
+ STK_Pointer,
+ STK_Last = -1
};
-struct PipeTypeDescriptor : public SpecialTypeDescriptor {
-
- PipeTypeDescriptor(uint8_t AQ)
- : SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe) {
- Hash = (AQ << 8) | Kind;
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_Pipe;
+using SpecialTypeDescriptor = std::tuple<const Type *, unsigned, unsigned>;
+
+union ImageAttrs {
+ struct BitFlags {
+ unsigned Dim : 3;
+ unsigned Depth : 2;
+ unsigned Arrayed : 1;
+ unsigned MS : 1;
+ unsigned Sampled : 2;
+ unsigned ImageFormat : 6;
+ unsigned AQ : 2;
+ } Flags;
+ unsigned Val;
+
+ ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
+ unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0) {
+ Val = 0;
+ Flags.Dim = Dim;
+ Flags.Depth = Depth;
+ Flags.Arrayed = Arrayed;
+ Flags.MS = MS;
+ Flags.Sampled = Sampled;
+ Flags.ImageFormat = ImageFormat;
+ Flags.AQ = AQ;
}
};
-struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor {
-
- DeviceEventTypeDescriptor()
- : SpecialTypeDescriptor(SpecialTypeKind::STK_DeviceEvent) {
- Hash = Kind;
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_DeviceEvent;
- }
-};
-
-struct PointerTypeDescriptor : public SpecialTypeDescriptor {
- const Type *ElementType;
- unsigned AddressSpace;
-
- PointerTypeDescriptor() = delete;
- PointerTypeDescriptor(const Type *ElementType, unsigned AddressSpace)
- : SpecialTypeDescriptor(SpecialTypeKind::STK_Pointer),
- ElementType(ElementType), AddressSpace(AddressSpace) {
- Hash = (DenseMapInfo<Type *>().getHashValue(ElementType) & 0xffff) ^
- ((AddressSpace << 8) | Kind);
- }
-
- static bool classof(const SpecialTypeDescriptor *TD) {
- return TD->Kind == SpecialTypeKind::STK_Pointer;
- }
-};
+inline SpecialTypeDescriptor
+make_descr_image(const Type *SampledTy, unsigned Dim, unsigned Depth,
+ unsigned Arrayed, unsigned MS, unsigned Sampled,
+ unsigned ImageFormat, unsigned AQ = 0) {
+ return std::make_tuple(
+ SampledTy,
+ ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val,
+ SpecialTypeKind::STK_Image);
+}
+
+inline SpecialTypeDescriptor
+make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) {
+ assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
+ return std::make_tuple(
+ SampledTy,
+ ImageAttrs(
+ ImageTy->getOperand(2).getImm(), ImageTy->getOperand(3).getImm(),
+ ImageTy->getOperand(4).getImm(), ImageTy->getOperand(5).getImm(),
+ ImageTy->getOperand(6).getImm(), ImageTy->getOperand(7).getImm(),
+ ImageTy->getOperand(8).getImm())
+ .Val,
+ SpecialTypeKind::STK_SampledImage);
+}
+
+inline SpecialTypeDescriptor make_descr_sampler() {
+ return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler);
+}
+
+inline SpecialTypeDescriptor make_descr_pipe(uint8_t AQ) {
+ return std::make_tuple(nullptr, AQ, SpecialTypeKind::STK_Pipe);
+}
+
+inline SpecialTypeDescriptor make_descr_event() {
+ return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_DeviceEvent);
+}
+
+inline SpecialTypeDescriptor make_descr_pointee(const Type *ElementType,
+ unsigned AddressSpace) {
+ return std::make_tuple(ElementType, AddressSpace,
+ SpecialTypeKind::STK_Pointer);
+}
} // namespace SPIRV
-template <> struct DenseMapInfo<SPIRV::SpecialTypeDescriptor> {
- static inline SPIRV::SpecialTypeDescriptor getEmptyKey() {
- return SPIRV::SpecialTypeDescriptor(
- SPIRV::SpecialTypeDescriptor::STK_Empty);
- }
- static inline SPIRV::SpecialTypeDescriptor getTombstoneKey() {
- return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last);
- }
- static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val) {
- return Val.getHash();
- }
- static bool isEqual(SPIRV::SpecialTypeDescriptor LHS,
- SPIRV::SpecialTypeDescriptor RHS) {
- return getHashValue(LHS) == getHashValue(RHS);
- }
-};
-
template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
public:
// NOTE: using MapVector instead of DenseMap helps getting everything ordered
@@ -283,16 +218,13 @@ class SPIRVGeneralDuplicatesTracker {
MachineModuleInfo *MMI);
void add(const Type *Ty, const MachineFunction *MF, Register R) {
- TT.add(Ty, MF, R);
+ TT.add(unifyPtrType(Ty), MF, R);
}
void add(const Type *PointeeTy, unsigned AddressSpace,
const MachineFunction *MF, Register R) {
- if (isUntypedPointerTy(PointeeTy))
- PointeeTy =
- TypedPointerType::get(IntegerType::getInt8Ty(PointeeTy->getContext()),
- getPointerAddressSpace(PointeeTy));
- ST.add(SPIRV::PointerTypeDescriptor(PointeeTy, AddressSpace), MF, R);
+ ST.add(SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF,
+ R);
}
void add(const Constant *C, const MachineFunction *MF, Register R) {
@@ -321,16 +253,13 @@ class SPIRVGeneralDuplicatesTracker {
}
Register find(const Type *Ty, const MachineFunction *MF) {
- return TT.find(const_cast<Type *>(Ty), MF);
+ return TT.find(unifyPtrType(Ty), MF);
}
Register find(const Type *PointeeTy, unsigned AddressSpace,
const MachineFunction *MF) {
- if (isUntypedPointerTy(PointeeTy))
- PointeeTy =
- TypedPointerType::get(IntegerType::getInt8Ty(PointeeTy->getContext()),
- getPointerAddressSpace(PointeeTy));
- return ST.find(SPIRV::PointerTypeDescriptor(PointeeTy, AddressSpace), MF);
+ return ST.find(
+ SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF);
}
Register find(const Constant *C, const MachineFunction *MF) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index d434e0b5efbcc..b22d2a04f75b1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -936,7 +936,7 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
TypesInProcessing.erase(Ty);
VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
- SPIRVToLLVMType[SpirvType] = Ty;
+ SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
// Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
// will be added later. For special types it is already added to DT.
@@ -1122,9 +1122,9 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
SPIRV::ImageFormat::ImageFormat ImageFormat,
SPIRV::AccessQualifier::AccessQualifier AccessQual) {
- SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth,
- Arrayed, Multisampled, Sampled, ImageFormat,
- AccessQual);
+ auto TD = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim,
+ Depth, Arrayed, Multisampled, Sampled,
+ ImageFormat, AccessQual);
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1143,7 +1143,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
SPIRVType *
SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
- SPIRV::SamplerTypeDescriptor TD;
+ auto TD = SPIRV::make_descr_sampler();
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1154,7 +1154,7 @@ SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual) {
- SPIRV::PipeTypeDescriptor TD(AccessQual);
+ auto TD = SPIRV::make_descr_pipe(AccessQual);
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1166,7 +1166,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
MachineIRBuilder &MIRBuilder) {
- SPIRV::DeviceEventTypeDescriptor TD;
+ auto TD = SPIRV::make_descr_event();
if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
return Res;
Register ResVReg = createTypeVReg(MIRBuilder);
@@ -1176,7 +1176,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
- SPIRV::SampledImageTypeDescriptor TD(
+ auto TD = SPIRV::make_descr_sampled_image(
SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
ImageType->getOperand(1).getReg())),
ImageType);
@@ -1268,7 +1268,7 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
SPIRVType *SpirvType) {
assert(CurMF == SpirvType->getMF());
VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
- SPIRVToLLVMType[SpirvType] = LLVMTy;
+ SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
return SpirvType;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index c131eecb1c137..12725d6bac14a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -160,5 +160,29 @@ inline Type *toTypedPointer(Type *Ty) {
: Ty;
}
+inline Type *toTypedFunPointer(FunctionType *FTy) {
+ Type *OrigRetTy = FTy->getReturnType();
+ Type *RetTy = toTypedPointer(OrigRetTy);
+ bool IsUntypedPtr = false;
+ for (Type *PTy : FTy->params()) {
+ if (isUntypedPointerTy(PTy)) {
+ IsUntypedPtr = true;
+ break;
+ }
+ }
+ if (!IsUntypedPtr && RetTy == OrigRetTy)
+ return FTy;
+ SmallVector<Type *> ParamTys;
+ for (Type *PTy : FTy->params())
+ ParamTys.push_back(toTypedPointer(PTy));
+ return FunctionType::get(RetTy, ParamTys, FTy->isVarArg());
+}
+
+inline const Type *unifyPtrType(const Type *Ty) {
+ if (auto FTy = dyn_cast<FunctionType>(Ty))
+ return toTypedFunPointer(const_cast<FunctionType *>(FTy));
+ return toTypedPointer(const_cast<Type *>(Ty));
+}
+
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
More information about the llvm-commits
mailing list