[llvm] [SPIR-V] Improve implementation of the duplicates tracker's storage (PR #95958)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 18 10:20:14 PDT 2024


https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/95958

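This PR continues https://github.com/llvm/llvm-project/pull/94952: it manages FunctionType the same way that PR manages the pointee types of TypedPointers. As a result, function types whose return or parameter types are untyped pointers are unified to their typed-pointer form before being used as duplicates-tracker keys. This PR also changes hashing to conform with the existing DenseMapInfo approach.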

>From 3867328f0800f085e03b3a17ae624976a88f405d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 18 Jun 2024 10:15:28 -0700
Subject: [PATCH] Improve DuplicatesTracker

---
 .../lib/Target/SPIRV/SPIRVDuplicatesTracker.h | 95 ++++++++++---------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  4 +-
 llvm/lib/Target/SPIRV/SPIRVUtils.h            | 24 +++++
 3 files changed, 76 insertions(+), 47 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 3c8405fadd44e..08b5c3b92a1ae 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -63,8 +63,7 @@ struct SpecialTypeDescriptor {
     STK_Pointer,
     STK_Last = -1
   };
-  SpecialTypeKind Kind;
-
+  unsigned Kind;
   unsigned Hash;
 
   SpecialTypeDescriptor() = delete;
@@ -75,35 +74,41 @@ struct SpecialTypeDescriptor {
   virtual ~SpecialTypeDescriptor() {}
 };
 
+union ImageAttrs {
+  struct BitFlags {
+    unsigned Dim : 3;
+    unsigned Depth : 2;
+    unsigned Arrayed : 1;
+    unsigned MS : 1;
+    unsigned Sampled : 2;
+    unsigned ImageFormat : 6;
+    unsigned AQ : 2;
+  } Flags;
+  unsigned Val;
+
+  ImageAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
+             unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0) {
+    Val = 0;
+    Flags.Dim = Dim;
+    Flags.Depth = Depth;
+    Flags.Arrayed = Arrayed;
+    Flags.MS = MS;
+    Flags.Sampled = Sampled;
+    Flags.ImageFormat = ImageFormat;
+    Flags.AQ = AQ;
+  }
+};
+
 struct ImageTypeDescriptor : public SpecialTypeDescriptor {
-  union ImageAttrs {
-    struct BitFlags {
-      unsigned Dim : 3;
-      unsigned Depth : 2;
-      unsigned Arrayed : 1;
-      unsigned MS : 1;
-      unsigned Sampled : 2;
-      unsigned ImageFormat : 6;
-      unsigned AQ : 2;
-    } Flags;
-    unsigned Val;
-  };
+  using Tuple = std::tuple<const Type *, unsigned, unsigned>;
 
   ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth,
                       unsigned Arrayed, unsigned MS, unsigned Sampled,
                       unsigned ImageFormat, unsigned AQ = 0)
       : SpecialTypeDescriptor(SpecialTypeKind::STK_Image) {
-    ImageAttrs Attrs;
-    Attrs.Val = 0;
-    Attrs.Flags.Dim = Dim;
-    Attrs.Flags.Depth = Depth;
-    Attrs.Flags.Arrayed = Arrayed;
-    Attrs.Flags.MS = MS;
-    Attrs.Flags.Sampled = Sampled;
-    Attrs.Flags.ImageFormat = ImageFormat;
-    Attrs.Flags.AQ = AQ;
-    Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^
-           ((Attrs.Val << 8) | Kind);
+    ImageAttrs Attrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ);
+    Hash = DenseMapInfo<Tuple>::getHashValue(
+        std::make_tuple(SampledTy, Attrs.Val, Kind));
   }
 
   static bool classof(const SpecialTypeDescriptor *TD) {
@@ -112,15 +117,18 @@ struct ImageTypeDescriptor : public SpecialTypeDescriptor {
 };
 
 struct SampledImageTypeDescriptor : public SpecialTypeDescriptor {
+  using Tuple = std::tuple<const Type *, unsigned, unsigned>;
+
   SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy)
       : SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) {
     assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
-    ImageTypeDescriptor TD(
-        SampledTy, ImageTy->getOperand(2).getImm(),
-        ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(),
-        ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(),
-        ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm());
-    Hash = TD.getHash() ^ Kind;
+    ImageAttrs Attrs(
+        ImageTy->getOperand(2).getImm(), ImageTy->getOperand(3).getImm(),
+        ImageTy->getOperand(4).getImm(), ImageTy->getOperand(5).getImm(),
+        ImageTy->getOperand(6).getImm(), ImageTy->getOperand(7).getImm(),
+        ImageTy->getOperand(8).getImm());
+    Hash = DenseMapInfo<Tuple>::getHashValue(
+        std::make_tuple(SampledTy, Attrs.Val, Kind));
   }
 
   static bool classof(const SpecialTypeDescriptor *TD) {
@@ -164,6 +172,8 @@ struct DeviceEventTypeDescriptor : public SpecialTypeDescriptor {
 };
 
 struct PointerTypeDescriptor : public SpecialTypeDescriptor {
+  // Hash key components: (ElementType, AddressSpace, Kind).
+  using Tuple = std::tuple<const Type *, unsigned, unsigned>;
   const Type *ElementType;
   unsigned AddressSpace;
 
@@ -171,8 +181,8 @@ struct PointerTypeDescriptor : public SpecialTypeDescriptor {
   PointerTypeDescriptor(const Type *ElementType, unsigned AddressSpace)
       : SpecialTypeDescriptor(SpecialTypeKind::STK_Pointer),
         ElementType(ElementType), AddressSpace(AddressSpace) {
-    Hash = (DenseMapInfo<Type *>().getHashValue(ElementType) & 0xffff) ^
-           ((AddressSpace << 8) | Kind);
+    Hash = DenseMapInfo<Tuple>::getHashValue(
+        std::make_tuple(ElementType, AddressSpace, Kind));
   }
 
   static bool classof(const SpecialTypeDescriptor *TD) {
@@ -283,16 +293,13 @@ class SPIRVGeneralDuplicatesTracker {
                       MachineModuleInfo *MMI);
 
   void add(const Type *Ty, const MachineFunction *MF, Register R) {
-    TT.add(Ty, MF, R);
+    TT.add(unifyPtrType(Ty), MF, R);
   }
 
   void add(const Type *PointeeTy, unsigned AddressSpace,
            const MachineFunction *MF, Register R) {
-    if (isUntypedPointerTy(PointeeTy))
-      PointeeTy =
-          TypedPointerType::get(IntegerType::getInt8Ty(PointeeTy->getContext()),
-                                getPointerAddressSpace(PointeeTy));
-    ST.add(SPIRV::PointerTypeDescriptor(PointeeTy, AddressSpace), MF, R);
+    ST.add(SPIRV::PointerTypeDescriptor(unifyPtrType(PointeeTy), AddressSpace),
+           MF, R);
   }
 
   void add(const Constant *C, const MachineFunction *MF, Register R) {
@@ -321,16 +328,14 @@ class SPIRVGeneralDuplicatesTracker {
   }
 
   Register find(const Type *Ty, const MachineFunction *MF) {
-    return TT.find(const_cast<Type *>(Ty), MF);
+    return TT.find(unifyPtrType(Ty), MF);
   }
 
   Register find(const Type *PointeeTy, unsigned AddressSpace,
                 const MachineFunction *MF) {
-    if (isUntypedPointerTy(PointeeTy))
-      PointeeTy =
-          TypedPointerType::get(IntegerType::getInt8Ty(PointeeTy->getContext()),
-                                getPointerAddressSpace(PointeeTy));
-    return ST.find(SPIRV::PointerTypeDescriptor(PointeeTy, AddressSpace), MF);
+    return ST.find(
+        SPIRV::PointerTypeDescriptor(unifyPtrType(PointeeTy), AddressSpace),
+        MF);
   }
 
   Register find(const Constant *C, const MachineFunction *MF) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index d434e0b5efbcc..78f4fb142eb0f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -936,7 +936,7 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
   TypesInProcessing.erase(Ty);
   VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
-  SPIRVToLLVMType[SpirvType] = Ty;
+  SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
   // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
   // will be added later. For special types it is already added to DT.
@@ -1268,7 +1268,7 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
                                                         SPIRVType *SpirvType) {
   assert(CurMF == SpirvType->getMF());
   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
-  SPIRVToLLVMType[SpirvType] = LLVMTy;
+  SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
   return SpirvType;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index c131eecb1c137..12725d6bac14a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -160,5 +160,29 @@ inline Type *toTypedPointer(Type *Ty) {
              : Ty;
 }
 
+inline Type *toTypedFunPointer(FunctionType *FTy) {
+  Type *OrigRetTy = FTy->getReturnType();
+  Type *RetTy = toTypedPointer(OrigRetTy);
+  bool IsUntypedPtr = false;
+  for (Type *PTy : FTy->params()) {
+    if (isUntypedPointerTy(PTy)) {
+      IsUntypedPtr = true;
+      break;
+    }
+  }
+  if (!IsUntypedPtr && RetTy == OrigRetTy)
+    return FTy;
+  SmallVector<Type *> ParamTys;
+  for (Type *PTy : FTy->params())
+    ParamTys.push_back(toTypedPointer(PTy));
+  return FunctionType::get(RetTy, ParamTys, FTy->isVarArg());
+}
+
+inline const Type *unifyPtrType(const Type *Ty) {
+  if (auto FTy = dyn_cast<FunctionType>(Ty))
+    return toTypedFunPointer(const_cast<FunctionType *>(FTy));
+  return toTypedPointer(const_cast<Type *>(Ty));
+}
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H



More information about the llvm-commits mailing list