[llvm] [SPIR-V] Improve implementation of the duplicates tracker's storage (PR #95958)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 18 10:20:14 PDT 2024
https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/95958
This PR continues https://github.com/llvm/llvm-project/pull/94952: FunctionType is now handled the same way that PR handles the pointee types of TypedPointerType, i.e. untyped pointer return and parameter types are replaced with typed pointers before duplicates-tracker lookups. This PR also changes hashing to follow the existing DenseMapInfo approach.
>From 3867328f0800f085e03b3a17ae624976a88f405d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 18 Jun 2024 10:15:28 -0700
Subject: [PATCH] Improve DuplicatesTracker
---
.../lib/Target/SPIRV/SPIRVDuplicatesTracker.h | 95 ++++++++++---------
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 4 +-
llvm/lib/Target/SPIRV/SPIRVUtils.h | 24 +++++
3 files changed, 76 insertions(+), 47 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 3c8405fadd44e..08b5c3b92a1ae 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -63,8 +63,7 @@ struct SpecialTypeDescriptor {
STK_Pointer,
STK_Last = -1
};
- SpecialTypeKind Kind;
-
+ unsigned Kind;
unsigned Hash;
SpecialTypeDescriptor() = delete;
@@ -75,35 +74,41 @@ struct SpecialTypeDescriptor {
virtual ~SpecialTypeDescriptor() {}
};
+union ImageAttrs { // Packs image-type attributes into one word, Val, used as hash input.
+ struct BitFlags { // NOTE(review): placement of these bit-fields within Val is implementation-defined.
+ unsigned Dim : 3;
+ unsigned Depth : 2;
+ unsigned Arrayed : 1;
+ unsigned MS : 1;
+ unsigned Sampled : 2;
+ unsigned ImageFormat : 6;
+ unsigned AQ : 2;
+ } Flags; // 17 bits in total, so all fields fit in one unsigned.
+ unsigned Val; // Aliases Flags. NOTE(review): reading Val after writing Flags is union type punning (undefined in ISO C++, supported by GCC/Clang).
+
+ ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
+ unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0) {
+ Val = 0; // Clear every bit first so bits not covered by Flags read as 0 via Val.
+ Flags.Dim = Dim;
+ Flags.Depth = Depth;
+ Flags.Arrayed = Arrayed;
+ Flags.MS = MS;
+ Flags.Sampled = Sampled;
+ Flags.ImageFormat = ImageFormat;
+ Flags.AQ = AQ;
+ }
+};
+
struct ImageTypeDescriptor : public SpecialTypeDescriptor {
- union ImageAttrs {
- struct BitFlags {
- unsigned Dim : 3;
- unsigned Depth : 2;
- unsigned Arrayed : 1;
- unsigned MS : 1;
- unsigned Sampled : 2;
- unsigned ImageFormat : 6;
- unsigned AQ : 2;
- } Flags;
- unsigned Val;
- };
+ using Tuple = std::tuple<const Type *, unsigned, unsigned>;
ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth,
unsigned Arrayed, unsigned MS, unsigned Sampled,
unsigned ImageFormat, unsigned AQ = 0)
: SpecialTypeDescriptor(SpecialTypeKind::STK_Image) {
- ImageAttrs Attrs;
- Attrs.Val = 0;
- Attrs.Flags.Dim = Dim;
- Attrs.Flags.Depth = Depth;
- Attrs.Flags.Arrayed = Arrayed;
- Attrs.Flags.MS = MS;
- Attrs.Flags.Sampled = Sampled;
- Attrs.Flags.ImageFormat = ImageFormat;
- Attrs.Flags.AQ = AQ;
- Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^
- ((Attrs.Val << 8) | Kind);
+ ImageAttrs Attrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ);
+ Hash = DenseMapInfo<Tuple>::getHashValue(
+ std::make_tuple(SampledTy, Attrs.Val, Kind));
}
static bool classof(const SpecialTypeDescriptor *TD) {
@@ -112,15 +117,18 @@ struct ImageTypeDescriptor : public SpecialTypeDescriptor {
};
struct SampledImageTypeDescriptor : public SpecialTypeDescriptor {
+ using Tuple = std::tuple<const Type *, unsigned, unsigned>;
+
SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy)
: SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) {
assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
- ImageTypeDescriptor TD(
- SampledTy, ImageTy->getOperand(2).getImm(),
- ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(),
- ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(),
- ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm());
- Hash = TD.getHash() ^ Kind;
+ ImageAttrs Attrs(
+ ImageTy->getOperand(2).getImm(), ImageTy->getOperand(3).getImm(),
+ ImageTy->getOperand(4).getImm(), ImageTy->getOperand(5).getImm(),
+ ImageTy->getOperand(6).getImm(), ImageTy->getOperand(7).getImm(),
+ ImageTy->getOperand(8).getImm());
+ Hash = DenseMapInfo<Tuple>::getHashValue(
+ std::make_tuple(SampledTy, Attrs.Val, Kind));
}
static bool classof(const SpecialTypeDescriptor *TD) {
@@ -164,6 +172,8 @@ struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor {
};
struct PointerTypeDescriptor : public SpecialTypeDescriptor {
+ using Pair = std::pair<const Type *, unsigned>;
+
const Type *ElementType;
unsigned AddressSpace;
@@ -171,8 +181,8 @@ struct PointerTypeDescriptor : public SpecialTypeDescriptor {
PointerTypeDescriptor(const Type *ElementType, unsigned AddressSpace)
: SpecialTypeDescriptor(SpecialTypeKind::STK_Pointer),
ElementType(ElementType), AddressSpace(AddressSpace) {
- Hash = (DenseMapInfo<Type *>().getHashValue(ElementType) & 0xffff) ^
- ((AddressSpace << 8) | Kind);
+ Hash = DenseMapInfo<Pair>::getHashValue(
+ std::make_pair(ElementType, (AddressSpace << 8) | Kind));
}
static bool classof(const SpecialTypeDescriptor *TD) {
@@ -283,16 +293,13 @@ class SPIRVGeneralDuplicatesTracker {
MachineModuleInfo *MMI);
void add(const Type *Ty, const MachineFunction *MF, Register R) {
- TT.add(Ty, MF, R);
+ TT.add(unifyPtrType(Ty), MF, R);
}
void add(const Type *PointeeTy, unsigned AddressSpace,
const MachineFunction *MF, Register R) {
- if (isUntypedPointerTy(PointeeTy))
- PointeeTy =
- TypedPointerType::get(IntegerType::getInt8Ty(PointeeTy->getContext()),
- getPointerAddressSpace(PointeeTy));
- ST.add(SPIRV::PointerTypeDescriptor(PointeeTy, AddressSpace), MF, R);
+ ST.add(SPIRV::PointerTypeDescriptor(unifyPtrType(PointeeTy), AddressSpace),
+ MF, R);
}
void add(const Constant *C, const MachineFunction *MF, Register R) {
@@ -321,16 +328,14 @@ class SPIRVGeneralDuplicatesTracker {
}
Register find(const Type *Ty, const MachineFunction *MF) {
- return TT.find(const_cast<Type *>(Ty), MF);
+ return TT.find(unifyPtrType(Ty), MF);
}
Register find(const Type *PointeeTy, unsigned AddressSpace,
const MachineFunction *MF) {
- if (isUntypedPointerTy(PointeeTy))
- PointeeTy =
- TypedPointerType::get(IntegerType::getInt8Ty(PointeeTy->getContext()),
- getPointerAddressSpace(PointeeTy));
- return ST.find(SPIRV::PointerTypeDescriptor(PointeeTy, AddressSpace), MF);
+ return ST.find(
+ SPIRV::PointerTypeDescriptor(unifyPtrType(PointeeTy), AddressSpace),
+ MF);
}
Register find(const Constant *C, const MachineFunction *MF) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index d434e0b5efbcc..78f4fb142eb0f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -936,7 +936,7 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
TypesInProcessing.erase(Ty);
VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
- SPIRVToLLVMType[SpirvType] = Ty;
+ SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
// Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
// will be added later. For special types it is already added to DT.
@@ -1268,7 +1268,7 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
SPIRVType *SpirvType) {
assert(CurMF == SpirvType->getMF());
VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
- SPIRVToLLVMType[SpirvType] = LLVMTy;
+ SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
return SpirvType;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index c131eecb1c137..12725d6bac14a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -160,5 +160,29 @@ inline Type *toTypedPointer(Type *Ty) {
: Ty;
}
+/// Returns \p FTy with every untyped pointer return or parameter type replaced
+/// via toTypedPointer(); if there is none, \p FTy itself is returned.
+inline Type *toTypedFunPointer(FunctionType *FTy) {
+  Type *OrigRetTy = FTy->getReturnType();
+  Type *RetTy = toTypedPointer(OrigRetTy);
+  bool IsUntypedPtr =
+      any_of(FTy->params(), [](Type *PTy) { return isUntypedPointerTy(PTy); });
+  if (!IsUntypedPtr && RetTy == OrigRetTy)
+    return FTy;
+  SmallVector<Type *> ParamTys;
+  ParamTys.reserve(FTy->getNumParams());
+  for (Type *PTy : FTy->params())
+    ParamTys.push_back(toTypedPointer(PTy));
+  return FunctionType::get(RetTy, ParamTys, FTy->isVarArg());
+}
+
+/// Canonicalizes \p Ty for type lookups: function types go through
+/// toTypedFunPointer(), all other types through toTypedPointer().
+inline const Type *unifyPtrType(const Type *Ty) {
+  if (auto *FTy = dyn_cast<FunctionType>(Ty))
+    return toTypedFunPointer(const_cast<FunctionType *>(FTy));
+  return toTypedPointer(const_cast<Type *>(Ty));
+}
+
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
More information about the llvm-commits
mailing list