[llvm] [SPIR-V] Remove spv_track_constant() internal intrinsics (PR #130605)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 11 04:44:20 PDT 2025


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/130605

>From 33221a17e7e0d4a12249a248ccbbff0a7df02ae4 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 10 Mar 2025 06:46:43 -0700
Subject: [PATCH 1/3] remove spv_track_constant() internal intrinsics

---
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  2 +-
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  9 +-
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 82 ++++++++++---------
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   | 33 +++-----
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          | 19 ++---
 5 files changed, 65 insertions(+), 80 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 7012ef3534c68..89fb92be0e1eb 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -13,7 +13,7 @@
 let TargetPrefix = "spv" in {
   def int_spv_assign_type : Intrinsic<[], [llvm_any_ty, llvm_metadata_ty]>;
   def int_spv_assign_ptr_type : Intrinsic<[], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
-  def int_spv_assign_name : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
+  def int_spv_assign_name : Intrinsic<[], [llvm_any_ty, llvm_metadata_ty]>;
   def int_spv_assign_decoration : Intrinsic<[], [llvm_any_ty, llvm_metadata_ty]>;
   def int_spv_value_md : Intrinsic<[], [llvm_metadata_ty]>;
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 579e37f68d5d8..f5c31ea737839 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -391,12 +391,9 @@ static MachineInstr *getBlockStructInstr(Register ParamReg,
 // TODO: maybe unify with prelegalizer pass.
 static unsigned getConstFromIntrinsic(Register Reg, MachineRegisterInfo *MRI) {
   MachineInstr *DefMI = MRI->getUniqueVRegDef(Reg);
-  assert(isSpvIntrinsic(*DefMI, Intrinsic::spv_track_constant) &&
-         DefMI->getOperand(2).isReg());
-  MachineInstr *DefMI2 = MRI->getUniqueVRegDef(DefMI->getOperand(2).getReg());
-  assert(DefMI2->getOpcode() == TargetOpcode::G_CONSTANT &&
-         DefMI2->getOperand(1).isCImm());
-  return DefMI2->getOperand(1).getCImm()->getValue().getZExtValue();
+  assert(DefMI->getOpcode() == TargetOpcode::G_CONSTANT &&
+         DefMI->getOperand(1).isCImm());
+  return DefMI->getOperand(1).getCImm()->getValue().getZExtValue();
 }
 
 // Return type of the instruction result from spv_assign_type intrinsic.
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 751ea5ab2dc47..356f4f6dab75c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -63,7 +63,7 @@ class SPIRVEmitIntrinsics
   SPIRVTargetMachine *TM = nullptr;
   SPIRVGlobalRegistry *GR = nullptr;
   Function *CurrF = nullptr;
-  bool TrackConstants = true;
+  bool TrackConstants = false;//true;
   bool HaveFunPtrs = false;
   DenseMap<Instruction *, Constant *> AggrConsts;
   DenseMap<Instruction *, Type *> AggrConstTypes;
@@ -316,8 +316,10 @@ static void emitAssignName(Instruction *I, IRBuilder<> &B) {
     return;
   reportFatalOnTokenType(I);
   setInsertPointAfterDef(B, I);
-  std::vector<Value *> Args = {I};
-  addStringImm(I->getName(), B, Args);
+  LLVMContext &Ctx = I->getContext();
+  std::vector<Value *> Args = {
+      I, MetadataAsValue::get(
+             Ctx, MDNode::get(Ctx, MDString::get(Ctx, I->getName())))};
   B.CreateIntrinsic(Intrinsic::spv_assign_name, {I->getType()}, Args);
 }
 
@@ -2023,7 +2025,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
   auto *II = dyn_cast<IntrinsicInst>(I);
   bool IsConstComposite =
       II && II->getIntrinsicID() == Intrinsic::spv_const_composite;
-  if (IsConstComposite && TrackConstants) {
+  if (IsConstComposite) {
     setInsertPointAfterDef(B, I);
     auto t = AggrConsts.find(I);
     assert(t != AggrConsts.end());
@@ -2035,41 +2037,43 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
   }
   bool IsPhi = isa<PHINode>(I), BPrepared = false;
   for (const auto &Op : I->operands()) {
-    if (isa<PHINode>(I) || isa<SwitchInst>(I))
-      TrackConstants = false;
-    if ((isa<ConstantData>(Op) || isa<ConstantExpr>(Op)) && TrackConstants) {
-      unsigned OpNo = Op.getOperandNo();
-      if (II && ((II->getIntrinsicID() == Intrinsic::spv_gep && OpNo == 0) ||
-                 (II->paramHasAttr(OpNo, Attribute::ImmArg))))
-        continue;
-      if (!BPrepared) {
-        IsPhi ? B.SetInsertPointPastAllocas(I->getParent()->getParent())
-              : B.SetInsertPoint(I);
-        BPrepared = true;
-      }
-      Type *OpTy = Op->getType();
-      Value *OpTyVal = Op;
-      if (OpTy->isTargetExtTy())
-        OpTyVal = getNormalizedPoisonValue(OpTy);
-      CallInst *NewOp =
-          buildIntrWithMD(Intrinsic::spv_track_constant,
-                          {OpTy, OpTyVal->getType()}, Op, OpTyVal, {}, B);
-      Type *OpElemTy = nullptr;
-      if (!IsConstComposite && isPointerTy(OpTy) &&
-          (OpElemTy = GR->findDeducedElementType(Op)) != nullptr &&
-          OpElemTy != IntegerType::getInt8Ty(I->getContext())) {
-        GR->buildAssignPtr(B, IntegerType::getInt8Ty(I->getContext()), NewOp);
-        SmallVector<Type *, 2> Types = {OpTy, OpTy};
-        SmallVector<Value *, 2> Args = {
-            NewOp, buildMD(getNormalizedPoisonValue(OpElemTy)),
-            B.getInt32(getPointerAddressSpace(OpTy))};
-        CallInst *PtrCasted =
-            B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
-        GR->buildAssignPtr(B, OpElemTy, PtrCasted);
-        NewOp = PtrCasted;
-      }
-      I->setOperand(OpNo, NewOp);
+    if (isa<PHINode>(I) || isa<SwitchInst>(I) ||
+        !(isa<ConstantData>(Op) || isa<ConstantExpr>(Op)))
+      continue;
+    unsigned OpNo = Op.getOperandNo();
+    if (II && ((II->getIntrinsicID() == Intrinsic::spv_gep && OpNo == 0) ||
+                (II->paramHasAttr(OpNo, Attribute::ImmArg))))
+      continue;
+
+    if (!BPrepared) {
+      IsPhi ? B.SetInsertPointPastAllocas(I->getParent()->getParent())
+            : B.SetInsertPoint(I);
+      BPrepared = true;
     }
+    Type *OpTy = Op->getType();
+    Value *OpTyVal = Op;
+    if (OpTy->isTargetExtTy())
+      OpTyVal = getNormalizedPoisonValue(OpTy);
+    Value *NewOp = Op;
+    if (OpTy->isTargetExtTy())
+      NewOp = buildIntrWithMD(Intrinsic::spv_track_constant,
+                              {OpTy, OpTyVal->getType()}, Op, OpTyVal, {}, B);
+    Type *OpElemTy = nullptr;
+    if (!IsConstComposite && isPointerTy(OpTy) &&
+        (OpElemTy = GR->findDeducedElementType(Op)) != nullptr &&
+        OpElemTy != IntegerType::getInt8Ty(I->getContext())) {
+      GR->buildAssignPtr(B, IntegerType::getInt8Ty(I->getContext()), NewOp);
+      SmallVector<Type *, 2> Types = {OpTy, OpTy};
+      SmallVector<Value *, 2> Args = {
+          NewOp, buildMD(getNormalizedPoisonValue(OpElemTy)),
+          B.getInt32(getPointerAddressSpace(OpTy))};
+      CallInst *PtrCasted =
+          B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
+      GR->buildAssignPtr(B, OpElemTy, PtrCasted);
+      NewOp = PtrCasted;
+    }
+    if (NewOp != Op)
+      I->setOperand(OpNo, NewOp);
   }
   emitAssignName(I, B);
 }
@@ -2417,7 +2421,7 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
         deduceOperandElementType(&Phi, nullptr);
 
   for (auto *I : Worklist) {
-    TrackConstants = true;
+    TrackConstants = false;//true;
     if (!I->getType()->isVoidTy() || isa<StoreInst>(I))
       setInsertPointAfterDef(B, I);
     // Visitors return either the original/newly created instruction for further
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 3779a4b6ccd34..edf215f0ce00f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -133,32 +133,25 @@ addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR,
     MI->eraseFromParent();
 }
 
-static void
-foldConstantsIntoIntrinsics(MachineFunction &MF,
-                            const SmallSet<Register, 4> &TrackedConstRegs) {
-  SmallVector<MachineInstr *, 10> ToErase;
-  MachineRegisterInfo &MRI = MF.getRegInfo();
-  const unsigned AssignNameOperandShift = 2;
+static void foldConstantsIntoIntrinsics(MachineFunction &MF,
+                                        MachineIRBuilder MIB) {
+  SmallVector<MachineInstr *, 64> ToErase;
   for (MachineBasicBlock &MBB : MF) {
     for (MachineInstr &MI : MBB) {
       if (!isSpvIntrinsic(MI, Intrinsic::spv_assign_name))
         continue;
-      unsigned NumOp = MI.getNumExplicitDefs() + AssignNameOperandShift;
-      while (MI.getOperand(NumOp).isReg()) {
-        MachineOperand &MOp = MI.getOperand(NumOp);
-        MachineInstr *ConstMI = MRI.getVRegDef(MOp.getReg());
-        assert(ConstMI->getOpcode() == TargetOpcode::G_CONSTANT);
-        MI.removeOperand(NumOp);
-        MI.addOperand(MachineOperand::CreateImm(
-            ConstMI->getOperand(1).getCImm()->getZExtValue()));
-        Register DefReg = ConstMI->getOperand(0).getReg();
-        if (MRI.use_empty(DefReg) && !TrackedConstRegs.contains(DefReg))
-          ToErase.push_back(ConstMI);
+      const MDNode *MD = MI.getOperand(2).getMetadata();
+      StringRef ValueName = cast<MDString>(MD->getOperand(0))->getString();
+      if (ValueName.size() > 0) {
+        MIB.setInsertPt(*MI.getParent(), MI);
+        buildOpName(MI.getOperand(1).getReg(), ValueName, MIB);
       }
+      ToErase.push_back(&MI);
     }
+    for (MachineInstr *MI : ToErase)
+      MI->eraseFromParent();
+    ToErase.clear();
   }
-  for (MachineInstr *MI : ToErase)
-    MI->eraseFromParent();
 }
 
 static MachineInstr *findAssignTypeInstr(Register Reg,
@@ -1043,7 +1036,7 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
   // to keep record of tracked constants
   SmallSet<Register, 4> TrackedConstRegs;
   addConstantsToTrack(MF, GR, ST, TargetExtConstTypes, TrackedConstRegs);
-  foldConstantsIntoIntrinsics(MF, TrackedConstRegs);
+  foldConstantsIntoIntrinsics(MF, MIB);
   insertBitcasts(MF, GR, MIB);
   generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index ce4f6d6c9288f..05bebb5a0e9c1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -307,20 +307,11 @@ SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) {
 MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
                                        const MachineRegisterInfo *MRI) {
   MachineInstr *MI = MRI->getVRegDef(ConstReg);
-  MachineInstr *ConstInstr =
-      MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
-          ? MRI->getVRegDef(MI->getOperand(1).getReg())
-          : MI;
-  if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
-    if (GI->is(Intrinsic::spv_track_constant)) {
-      ConstReg = ConstInstr->getOperand(2).getReg();
-      return MRI->getVRegDef(ConstReg);
-    }
-  } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
-    ConstReg = ConstInstr->getOperand(1).getReg();
-    return MRI->getVRegDef(ConstReg);
-  }
-  return MRI->getVRegDef(ConstReg);
+  if (MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT)
+    return getDefInstrMaybeConstant(ConstReg = MI->getOperand(1).getReg(), MRI);
+  if (MI->getOpcode() == SPIRV::ASSIGN_TYPE)
+    return getDefInstrMaybeConstant(ConstReg = MI->getOperand(1).getReg(), MRI);
+  return MI;
 }
 
 uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {

>From c2fcf6357b015fc1272d3658ca45db604fb9b19b Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 10 Mar 2025 11:39:29 -0700
Subject: [PATCH 2/3] fixes

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  5 ++-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 29 ++++++++---------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  3 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 11 +++++++
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          | 32 +++++++++++++++++--
 .../test/CodeGen/SPIRV/SampledImageRetType.ll |  6 ++--
 6 files changed, 60 insertions(+), 26 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index f5c31ea737839..e6c0a526a9cec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1968,8 +1968,7 @@ static bool generateReadImageInst(const StringRef DemangledCall,
       Sampler = GR->buildConstantSampler(
           Register(), getSamplerAddressingModeFromBitmask(SamplerMask),
           getSamplerParamFromBitmask(SamplerMask),
-          getSamplerFilterModeFromBitmask(SamplerMask), MIRBuilder,
-          GR->getSPIRVTypeForVReg(Sampler));
+          getSamplerFilterModeFromBitmask(SamplerMask), MIRBuilder);
     }
     SPIRVType *ImageType = GR->getSPIRVTypeForVReg(Image);
     SPIRVType *SampledImageType =
@@ -2056,7 +2055,7 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
     Register Sampler = GR->buildConstantSampler(
         Call->ReturnRegister, getSamplerAddressingModeFromBitmask(Bitmask),
         getSamplerParamFromBitmask(Bitmask),
-        getSamplerFilterModeFromBitmask(Bitmask), MIRBuilder, Call->ReturnType);
+        getSamplerFilterModeFromBitmask(Bitmask), MIRBuilder);
     return Sampler.isValid();
   } else if (Call->Builtin->Name.contains_insensitive("__spirv_SampledImage")) {
     // Create OpSampledImage.
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cbec1c95eadc3..fb718d9ddd0b4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -398,7 +398,9 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
                      SPIRV::AccessQualifier::ReadWrite, EmitIR);
     DT.add(ConstInt, &MIRBuilder.getMF(), Res);
     if (EmitIR) {
-      MIRBuilder.buildConstant(Res, *ConstInt);
+      createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        return MIRBuilder.buildConstant(Res, *ConstInt);
+      });
     } else {
       Register SpvTypeReg = getSPIRVTypeID(SpvType);
       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
@@ -605,7 +607,9 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
     DT.add(CA, CurMF, SpvVecConst);
     if (EmitIR) {
-      MIRBuilder.buildSplatBuildVector(SpvVecConst, SpvScalConst);
+      createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        return MIRBuilder.buildSplatBuildVector(SpvVecConst, SpvScalConst);
+      });
     } else {
       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
         if (Val) {
@@ -668,17 +672,10 @@ SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
   return Res;
 }
 
-Register SPIRVGlobalRegistry::buildConstantSampler(
-    Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
-    MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
-  SPIRVType *SampTy;
-  if (SpvType)
-    SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder,
-                                  SPIRV::AccessQualifier::ReadWrite, true);
-  else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder,
-                                                false)) == nullptr)
-    report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t");
-
+Register
+SPIRVGlobalRegistry::buildConstantSampler(Register ResReg, unsigned AddrMode,
+                                          unsigned Param, unsigned FilerMode,
+                                          MachineIRBuilder &MIRBuilder) {
   auto Sampler =
       ResReg.isValid()
           ? ResReg
@@ -686,7 +683,7 @@ Register SPIRVGlobalRegistry::buildConstantSampler(
   auto Res = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
     return MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
         .addDef(Sampler)
-        .addUse(getSPIRVTypeID(SampTy))
+        .addUse(getSPIRVTypeID(getOrCreateOpTypeSampler(MIRBuilder)))
         .addImm(AddrMode)
         .addImm(Param)
         .addImm(FilerMode);
@@ -1383,7 +1380,9 @@ SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
     return Res;
   Register ResVReg = createTypeVReg(MIRBuilder);
   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
-  return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
+  });
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 89599f17ef737..467d9b73e2e39 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -544,8 +544,7 @@ class SPIRVGlobalRegistry {
                                    SPIRVType *SpvType);
   Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,
                                 unsigned FilerMode,
-                                MachineIRBuilder &MIRBuilder,
-                                SPIRVType *SpvType);
+                                MachineIRBuilder &MIRBuilder);
   Register getOrCreateUndef(MachineInstr &I, SPIRVType *SpvType,
                             const SPIRVInstrInfo &TII);
   Register buildGlobalVariable(Register Reg, SPIRVType *BaseType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b188f36ca9a9e..ce9ce3ef3135c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2326,6 +2326,17 @@ static bool isConstReg(MachineRegisterInfo *MRI, SPIRVType *OpDef,
         return false;
     }
     return true;
+  case SPIRV::OpConstantTrue:
+  case SPIRV::OpConstantFalse:
+  case SPIRV::OpConstantI:
+  case SPIRV::OpConstantF:
+  case SPIRV::OpConstantComposite:
+  case SPIRV::OpConstantCompositeContinuedINTEL:
+  case SPIRV::OpConstantSampler:
+  case SPIRV::OpConstantNull:
+  case SPIRV::OpUndef:
+  case SPIRV::OpConstantFunctionPointerINTEL:
+    return true;
   }
   }
   return false;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 05bebb5a0e9c1..8380cd579004d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -303,16 +303,42 @@ SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) {
     return SPIRV::Scope::Device;
   return SPIRV::Scope::CrossDevice;
 }
-
+/*
 MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
                                        const MachineRegisterInfo *MRI) {
   MachineInstr *MI = MRI->getVRegDef(ConstReg);
-  if (MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT)
-    return getDefInstrMaybeConstant(ConstReg = MI->getOperand(1).getReg(), MRI);
+  //if (MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT)
+  //  return getDefInstrMaybeConstant(ConstReg = MI->getOperand(1).getReg(), MRI);
   if (MI->getOpcode() == SPIRV::ASSIGN_TYPE)
     return getDefInstrMaybeConstant(ConstReg = MI->getOperand(1).getReg(), MRI);
+  if (auto *GI = dyn_cast<GIntrinsic>(MI))
+    if (GI->is(Intrinsic::spv_track_constant))
+      return MRI->getVRegDef(ConstReg = MI->getOperand(2).getReg());
   return MI;
 }
+*/
+MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
+                                       const MachineRegisterInfo *MRI) {
+  MachineInstr *MI = MRI->getVRegDef(ConstReg);
+  MachineInstr *ConstInstr =
+      MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
+          ? MRI->getVRegDef(MI->getOperand(1).getReg())
+          : MI;
+  if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
+    if (GI->is(Intrinsic::spv_track_constant)) {
+      ConstReg = ConstInstr->getOperand(2).getReg();
+      return MRI->getVRegDef(ConstReg);
+    }
+  } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) {
+    ConstReg = ConstInstr->getOperand(1).getReg();
+    return MRI->getVRegDef(ConstReg);
+  } else if (ConstInstr->getOpcode() == TargetOpcode::G_CONSTANT ||
+             ConstInstr->getOpcode() == TargetOpcode::G_FCONSTANT) {
+    ConstReg = ConstInstr->getOperand(0).getReg();
+    return ConstInstr;
+  }
+  return MRI->getVRegDef(ConstReg);
+}
 
 uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
   const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
diff --git a/llvm/test/CodeGen/SPIRV/SampledImageRetType.ll b/llvm/test/CodeGen/SPIRV/SampledImageRetType.ll
index f034f293dc6a9..91f83e09c94f0 100644
--- a/llvm/test/CodeGen/SPIRV/SampledImageRetType.ll
+++ b/llvm/test/CodeGen/SPIRV/SampledImageRetType.ll
@@ -4,9 +4,9 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
-; CHECK: %[[#image1d_t:]] = OpTypeImage
-; CHECK: %[[#sampler_t:]] = OpTypeSampler
-; CHECK: %[[#sampled_image_t:]] = OpTypeSampledImage
+; CHECK-DAG: %[[#image1d_t:]] = OpTypeImage
+; CHECK-DAG: %[[#sampler_t:]] = OpTypeSampler
+; CHECK-DAG: %[[#sampled_image_t:]] = OpTypeSampledImage
 
 declare dso_local spir_func ptr addrspace(4) @_Z20__spirv_SampledImageI14ocl_image1d_roPvET0_T_11ocl_sampler(target("spirv.Image", void, 0, 0, 0, 0, 0, 0, 0) %0, target("spirv.Sampler") %1) local_unnamed_addr
 

>From 663108c10f9fbd46075d1710dd0beb21a1918331 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 11 Mar 2025 04:44:05 -0700
Subject: [PATCH 3/3] remove duplicate tracker

---
 .../lib/Target/SPIRV/SPIRVDuplicatesTracker.h | 228 ++++-----------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 275 +++++++-----------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  41 +--
 3 files changed, 165 insertions(+), 379 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 441e32c1eb695..21e1069adb2d3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -26,35 +26,8 @@
 
 namespace llvm {
 namespace SPIRV {
-class SPIRVInstrInfo;
-// NOTE: using MapVector instead of DenseMap because it helps getting
-// everything ordered in a stable manner for a price of extra (NumKeys)*PtrSize
-// memory and expensive removals which do not happen anyway.
-class DTSortableEntry : public MapVector<const MachineFunction *, Register> {
-  SmallVector<DTSortableEntry *, 2> Deps;
 
-  struct FlagsTy {
-    unsigned IsFunc : 1;
-    unsigned IsGV : 1;
-    unsigned IsConst : 1;
-    // NOTE: bit-field default init is a C++20 feature.
-    FlagsTy() : IsFunc(0), IsGV(0), IsConst(0) {}
-  };
-  FlagsTy Flags;
-
-public:
-  // Common hoisting utility doesn't support function, because their hoisting
-  // require hoisting of params as well.
-  bool getIsFunc() const { return Flags.IsFunc; }
-  bool getIsGV() const { return Flags.IsGV; }
-  bool getIsConst() const { return Flags.IsConst; }
-  void setIsFunc(bool V) { Flags.IsFunc = V; }
-  void setIsGV(bool V) { Flags.IsGV = V; }
-  void setIsConst(bool V) { Flags.IsConst = V; }
-
-  const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; }
-  void addDep(DTSortableEntry *E) { Deps.push_back(E); }
-};
+using IRHandle = std::tuple<const void *, unsigned, unsigned>;
 
 enum SpecialTypeKind {
   STK_Empty = 0,
@@ -63,12 +36,11 @@ enum SpecialTypeKind {
   STK_Sampler,
   STK_Pipe,
   STK_DeviceEvent,
+  STK_ElementPointer,
   STK_Pointer,
   STK_Last = -1
 };
 
-using SpecialTypeDescriptor = std::tuple<const Type *, unsigned, unsigned>;
-
 union ImageAttrs {
   struct BitFlags {
     unsigned Dim : 3;
@@ -94,18 +66,18 @@ union ImageAttrs {
   }
 };
 
-inline SpecialTypeDescriptor
-make_descr_image(const Type *SampledTy, unsigned Dim, unsigned Depth,
-                 unsigned Arrayed, unsigned MS, unsigned Sampled,
-                 unsigned ImageFormat, unsigned AQ = 0) {
+inline IRHandle make_descr_image(const Type *SampledTy, unsigned Dim,
+                                 unsigned Depth, unsigned Arrayed, unsigned MS,
+                                 unsigned Sampled, unsigned ImageFormat,
+                                 unsigned AQ = 0) {
   return std::make_tuple(
       SampledTy,
       ImageAttrs(Dim, Depth, Arrayed, MS, Sampled, ImageFormat, AQ).Val,
       SpecialTypeKind::STK_Image);
 }
 
-inline SpecialTypeDescriptor
-make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) {
+inline IRHandle make_descr_sampled_image(const Type *SampledTy,
+                                         const MachineInstr *ImageTy) {
   assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
   unsigned AC = AccessQualifier::AccessQualifier::None;
   if (ImageTy->getNumOperands() > 8)
@@ -120,170 +92,84 @@ make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) {
       SpecialTypeKind::STK_SampledImage);
 }
 
-inline SpecialTypeDescriptor make_descr_sampler() {
+inline IRHandle make_descr_sampler() {
   return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler);
 }
 
-inline SpecialTypeDescriptor make_descr_pipe(uint8_t AQ) {
+inline IRHandle make_descr_pipe(uint8_t AQ) {
   return std::make_tuple(nullptr, AQ, SpecialTypeKind::STK_Pipe);
 }
 
-inline SpecialTypeDescriptor make_descr_event() {
+inline IRHandle make_descr_event() {
   return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_DeviceEvent);
 }
 
-inline SpecialTypeDescriptor make_descr_pointee(const Type *ElementType,
-                                                unsigned AddressSpace) {
+inline IRHandle make_descr_pointee(const Type *ElementType,
+                                   unsigned AddressSpace) {
   return std::make_tuple(ElementType, AddressSpace,
-                         SpecialTypeKind::STK_Pointer);
+                         SpecialTypeKind::STK_ElementPointer);
 }
-} // namespace SPIRV
 
-template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
-public:
-  // NOTE: using MapVector instead of DenseMap helps getting everything ordered
-  // in a stable manner for a price of extra (NumKeys)*PtrSize memory and
-  // expensive removals which don't happen anyway.
-  using StorageTy = MapVector<KeyTy, SPIRV::DTSortableEntry>;
+inline IRHandle make_descr_ptr(const void *Ptr) {
+  return std::make_tuple(Ptr, 0U, SpecialTypeKind::STK_Pointer);
+}
+} // namespace SPIRV
 
-private:
-  StorageTy Storage;
+// Bi-directional mapping between LLVM entities (Type, Constant, MachineInstr,
+// ...), identified by an IRHandle, and the virtual register defined for that
+// entity in a given machine function; keeps definitions unique per function.
+class SPIRVIRMap {
+  DenseMap<std::pair<SPIRV::IRHandle, const MachineFunction *>, Register> Vregs;
+  DenseMap<const MachineInstr *, SPIRV::IRHandle> Defs;
 
 public:
-  void add(KeyTy V, const MachineFunction *MF, Register R) {
-    if (find(V, MF).isValid())
-      return;
-
-    auto &S = Storage[V];
-    S[MF] = R;
-    if (std::is_same<Function,
-                     typename std::remove_const<
-                         typename std::remove_pointer<KeyTy>::type>::type>() ||
-        std::is_same<Argument,
-                     typename std::remove_const<
-                         typename std::remove_pointer<KeyTy>::type>::type>())
-      S.setIsFunc(true);
-    if (std::is_same<GlobalVariable,
-                     typename std::remove_const<
-                         typename std::remove_pointer<KeyTy>::type>::type>())
-      S.setIsGV(true);
-    if (std::is_same<Constant,
-                     typename std::remove_const<
-                         typename std::remove_pointer<KeyTy>::type>::type>())
-      S.setIsConst(true);
-  }
-
-  Register find(KeyTy V, const MachineFunction *MF) const {
-    auto iter = Storage.find(V);
-    if (iter != Storage.end()) {
-      auto Map = iter->second;
-      auto iter2 = Map.find(MF);
-      if (iter2 != Map.end())
-        return iter2->second;
+  bool add(SPIRV::IRHandle Handle, const MachineInstr *MI) {
+    auto [It, Inserted] =
+        Vregs.try_emplace(std::make_pair(Handle, MI->getMF()));
+    if (Inserted) {
+      It->second = MI->getOperand(0).getReg();
+      auto [_, IsConsistent] = Defs.insert_or_assign(MI, Handle);
+      assert(IsConsistent);
     }
-    return Register();
-  }
-
-  const StorageTy &getAllUses() const { return Storage; }
-
-private:
-  StorageTy &getAllUses() { return Storage; }
-
-  // The friend class needs to have access to the internal storage
-  // to be able to build dependency graph, can't declare only one
-  // function a 'friend' due to the incomplete declaration at this point
-  // and mutual dependency problems.
-  friend class SPIRVGeneralDuplicatesTracker;
-};
-
-template <typename T>
-class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase<const T *> {};
-
-template <>
-class SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor>
-    : public SPIRVDuplicatesTrackerBase<SPIRV::SpecialTypeDescriptor> {};
-
-class SPIRVGeneralDuplicatesTracker {
-  SPIRVDuplicatesTracker<Type> TT;
-  SPIRVDuplicatesTracker<Constant> CT;
-  SPIRVDuplicatesTracker<GlobalVariable> GT;
-  SPIRVDuplicatesTracker<Function> FT;
-  SPIRVDuplicatesTracker<Argument> AT;
-  SPIRVDuplicatesTracker<MachineInstr> MT;
-  SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
-
-public:
-  void add(const Type *Ty, const MachineFunction *MF, Register R) {
-    TT.add(unifyPtrType(Ty), MF, R);
-  }
-
-  void add(const Type *PointeeTy, unsigned AddressSpace,
-           const MachineFunction *MF, Register R) {
-    ST.add(SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF,
-           R);
+    return Inserted;
   }
-
-  void add(const Constant *C, const MachineFunction *MF, Register R) {
-    CT.add(C, MF, R);
+  bool erase(const MachineInstr *MI) {
+    bool Res = false;
+    if (auto It = Defs.find(MI); It != Defs.end()) {
+      Res = Vregs.erase(std::make_pair(It->second, MI->getMF()));
+      Defs.erase(It);
+    }
+    return Res;
   }
-
-  void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) {
-    GT.add(GV, MF, R);
+  Register find(SPIRV::IRHandle Handle, const MachineFunction *MF) {
+    if (auto It = Vregs.find(std::make_pair(Handle, MF)); It != Vregs.end())
+      return It->second;
+    return Register();
   }
 
-  void add(const Function *F, const MachineFunction *MF, Register R) {
-    FT.add(F, MF, R);
+  // Convenience overloads that build the IRHandle key from a Type or pointer.
+  bool add(const Type *Ty, const MachineInstr *MI) {
+    return add(SPIRV::make_descr_ptr(unifyPtrType(Ty)), MI);
   }
-
-  void add(const Argument *Arg, const MachineFunction *MF, Register R) {
-    AT.add(Arg, MF, R);
+  bool add(const void *Key, const MachineInstr *MI) {
+    return add(SPIRV::make_descr_ptr(Key), MI);
   }
-
-  void add(const MachineInstr *MI, const MachineFunction *MF, Register R) {
-    MT.add(MI, MF, R);
+  bool add(const Type *PointeeTy, unsigned AddressSpace,
+           const MachineInstr *MI) {
+    return add(SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace),
+               MI);
   }
-
-  void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
-           Register R) {
-    ST.add(TD, MF, R);
-  }
-
   Register find(const Type *Ty, const MachineFunction *MF) {
-    return TT.find(unifyPtrType(Ty), MF);
+    return find(SPIRV::make_descr_ptr(unifyPtrType(Ty)), MF);
+  }
+  Register find(const void *Key, const MachineFunction *MF) {
+    return find(SPIRV::make_descr_ptr(Key), MF);
   }
-
   Register find(const Type *PointeeTy, unsigned AddressSpace,
                 const MachineFunction *MF) {
-    return ST.find(
+    return find(
         SPIRV::make_descr_pointee(unifyPtrType(PointeeTy), AddressSpace), MF);
   }
-
-  Register find(const Constant *C, const MachineFunction *MF) {
-    return CT.find(const_cast<Constant *>(C), MF);
-  }
-
-  Register find(const GlobalVariable *GV, const MachineFunction *MF) {
-    return GT.find(const_cast<GlobalVariable *>(GV), MF);
-  }
-
-  Register find(const Function *F, const MachineFunction *MF) {
-    return FT.find(const_cast<Function *>(F), MF);
-  }
-
-  Register find(const Argument *Arg, const MachineFunction *MF) {
-    return AT.find(const_cast<Argument *>(Arg), MF);
-  }
-
-  Register find(const MachineInstr *MI, const MachineFunction *MF) {
-    return MT.find(const_cast<MachineInstr *>(MI), MF);
-  }
-
-  Register find(const SPIRV::SpecialTypeDescriptor &TD,
-                const MachineFunction *MF) {
-    return ST.find(TD, MF);
-  }
-
-  const SPIRVDuplicatesTracker<Type> *getTypes() { return &TT; }
 };
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index fb718d9ddd0b4..c897026dc2426 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -181,7 +181,7 @@ void SPIRVGlobalRegistry::invalidateMachineInstr(MachineInstr *MI) {
   // - take into account duplicate tracker case which is a known issue,
   // - review other data structure wrt. possible issues related to removal
   //   of a machine instruction during instruction selection.
-  const MachineFunction *MF = MI->getParent()->getParent();
+  const MachineFunction *MF = MI->getMF();
   auto It = LastInsertedTypeMap.find(MF);
   if (It == LastInsertedTypeMap.end())
     return;
@@ -245,98 +245,46 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
   });
 }
 
-std::tuple<Register, ConstantInt *, bool, unsigned>
-SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
-                                            MachineIRBuilder *MIRBuilder,
-                                            MachineInstr *I,
-                                            const SPIRVInstrInfo *TII) {
-  assert(SpvType);
-  const IntegerType *LLVMIntTy =
-      cast<IntegerType>(getTypeForSPIRVType(SpvType));
-  unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
-  bool NewInstr = false;
-  // Find a constant in DT or build a new one.
-  ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
-  Register Res = DT.find(CI, CurMF);
-  if (!Res.isValid()) {
-    Res =
-        CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
-    CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
-    if (MIRBuilder)
-      assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder,
-                       SPIRV::AccessQualifier::ReadWrite, true);
-    else
-      assignIntTypeToVReg(BitWidth, Res, *I, *TII);
-    DT.add(CI, CurMF, Res);
-    NewInstr = true;
-  }
-  return std::make_tuple(Res, CI, NewInstr, BitWidth);
-}
-
-std::tuple<Register, ConstantFP *, bool, unsigned>
-SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
-                                              MachineIRBuilder *MIRBuilder,
-                                              MachineInstr *I,
-                                              const SPIRVInstrInfo *TII) {
-  assert(SpvType);
-  LLVMContext &Ctx = CurMF->getFunction().getContext();
-  const Type *LLVMFloatTy = getTypeForSPIRVType(SpvType);
-  unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
-  bool NewInstr = false;
-  // Find a constant in DT or build a new one.
-  auto *const CI = ConstantFP::get(Ctx, Val);
-  Register Res = DT.find(CI, CurMF);
-  if (!Res.isValid()) {
-    Res =
-        CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
-    CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
-    if (MIRBuilder)
-      assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder,
-                       SPIRV::AccessQualifier::ReadWrite, true);
-    else
-      assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
-    DT.add(CI, CurMF, Res);
-    NewInstr = true;
-  }
-  return std::make_tuple(Res, CI, NewInstr, BitWidth);
-}
-
 Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
                                                  SPIRVType *SpvType,
                                                  const SPIRVInstrInfo &TII,
                                                  bool ZeroAsNull) {
-  assert(SpvType);
-  ConstantFP *CI;
-  Register Res;
-  bool New;
-  unsigned BitWidth;
-  std::tie(Res, CI, New, BitWidth) =
-      getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII);
-  // If we have found Res register which is defined by the passed G_CONSTANT
-  // machine instruction, a new constant instruction should be created.
-  if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
+  unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
+  LLVMContext &Ctx = CurMF->getFunction().getContext();
+  auto *const CF = ConstantFP::get(Ctx, Val);
+  Register Res = find(CF, CurMF);
+  if (Res.isValid())
     return Res;
+
+  LLT LLTy = LLT::scalar(BitWidth);
+  Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+  CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
+  assignFloatTypeToVReg(BitWidth, Res, I, TII);
+
   MachineIRBuilder MIRBuilder(I);
-  createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-    MachineInstrBuilder MIB;
-    // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
-    if (Val.isPosZero() && ZeroAsNull) {
-      MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
-                .addDef(Res)
-                .addUse(getSPIRVTypeID(SpvType));
-    } else {
-      MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
-                .addDef(Res)
-                .addUse(getSPIRVTypeID(SpvType));
-      addNumImm(
-          APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()),
-          MIB);
-    }
-    const auto &ST = CurMF->getSubtarget();
-    constrainSelectedInstRegOperands(
-        *MIB, *ST.getInstrInfo(), *ST.getRegisterInfo(), *ST.getRegBankInfo());
-    return MIB;
-  });
+  SPIRVType *NewType =
+      createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        MachineInstrBuilder MIB;
+        // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
+        if (Val.isPosZero() && ZeroAsNull) {
+          MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
+                    .addDef(Res)
+                    .addUse(getSPIRVTypeID(SpvType));
+        } else {
+          MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
+                    .addDef(Res)
+                    .addUse(getSPIRVTypeID(SpvType));
+          addNumImm(APInt(BitWidth,
+                          CF->getValueAPF().bitcastToAPInt().getZExtValue()),
+                    MIB);
+        }
+        const auto &ST = CurMF->getSubtarget();
+        constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
+                                         *ST.getRegisterInfo(),
+                                         *ST.getRegBankInfo());
+        return MIB;
+      });
+  add(CF, NewType);
   return Res;
 }
 
@@ -344,36 +292,37 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
                                                   SPIRVType *SpvType,
                                                   const SPIRVInstrInfo &TII,
                                                   bool ZeroAsNull) {
-  assert(SpvType);
-  ConstantInt *CI;
-  Register Res;
-  bool New;
-  unsigned BitWidth;
-  std::tie(Res, CI, New, BitWidth) =
-      getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
-  // If we have found Res register which is defined by the passed G_CONSTANT
-  // machine instruction, a new constant instruction should be created.
-  if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
+  const IntegerType *Ty = cast<IntegerType>(getTypeForSPIRVType(SpvType));
+  unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
+  auto *const CI = ConstantInt::get(const_cast<IntegerType *>(Ty), Val);
+  Register Res = find(CI, CurMF);
+  if (Res.isValid())
     return Res;
-
+  LLT LLTy = LLT::scalar(BitWidth);
+  Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+  CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
+  assignIntTypeToVReg(BitWidth, Res, I, TII);
   MachineIRBuilder MIRBuilder(I);
-  createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-    MachineInstrBuilder MIB;
-    if (Val || !ZeroAsNull) {
-      MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
-                .addDef(Res)
-                .addUse(getSPIRVTypeID(SpvType));
-      addNumImm(APInt(BitWidth, Val), MIB);
-    } else {
-      MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
-                .addDef(Res)
-                .addUse(getSPIRVTypeID(SpvType));
-    }
-    const auto &ST = CurMF->getSubtarget();
-    constrainSelectedInstRegOperands(
-        *MIB, *ST.getInstrInfo(), *ST.getRegisterInfo(), *ST.getRegBankInfo());
-    return MIB;
-  });
+  SPIRVType *NewType =
+      createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        MachineInstrBuilder MIB;
+        if (Val || !ZeroAsNull) {
+          MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
+                    .addDef(Res)
+                    .addUse(getSPIRVTypeID(SpvType));
+          addNumImm(APInt(BitWidth, Val), MIB);
+        } else {
+          MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
+                    .addDef(Res)
+                    .addUse(getSPIRVTypeID(SpvType));
+        }
+        const auto &ST = CurMF->getSubtarget();
+        constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
+                                         *ST.getRegisterInfo(),
+                                         *ST.getRegBankInfo());
+        return MIB;
+      });
+  add(CI, NewType);
   return Res;
 }
 
@@ -383,27 +332,24 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
                                                bool ZeroAsNull) {
   assert(SpvType);
   auto &MF = MIRBuilder.getMF();
-  const IntegerType *LLVMIntTy =
-      cast<IntegerType>(getTypeForSPIRVType(SpvType));
-  // Find a constant in DT or build a new one.
-  const auto ConstInt =
-      ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
-  Register Res = DT.find(ConstInt, &MF);
-  if (!Res.isValid()) {
-    unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
-    LLT LLTy = LLT::scalar(BitWidth);
-    Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
-    MF.getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
-    assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
-                     SPIRV::AccessQualifier::ReadWrite, EmitIR);
-    DT.add(ConstInt, &MIRBuilder.getMF(), Res);
-    if (EmitIR) {
-      createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-        return MIRBuilder.buildConstant(Res, *ConstInt);
-      });
-    } else {
-      Register SpvTypeReg = getSPIRVTypeID(SpvType);
+  const IntegerType *Ty = cast<IntegerType>(getTypeForSPIRVType(SpvType));
+  auto *const CI = ConstantInt::get(const_cast<IntegerType *>(Ty), Val);
+  Register Res = find(CI, &MF);
+  if (Res.isValid())
+    return Res;
+
+  unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
+  LLT LLTy = LLT::scalar(BitWidth);
+  Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
+  MF.getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
+  assignTypeToVReg(Ty, Res, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
+                   EmitIR);
+
+  SPIRVType *NewType =
       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        if (EmitIR)
+          return MIRBuilder.buildConstant(Res, *CI);
+        Register SpvTypeReg = getSPIRVTypeID(SpvType);
         MachineInstrBuilder MIB;
         if (Val || !ZeroAsNull) {
           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
@@ -421,8 +367,7 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
                                          *Subtarget.getRegBankInfo());
         return MIB;
       });
-    }
-  }
+  add(CI, NewType);
   return Res;
 }
 
@@ -430,31 +375,30 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
                                               MachineIRBuilder &MIRBuilder,
                                               SPIRVType *SpvType) {
   auto &MF = MIRBuilder.getMF();
-  auto &Ctx = MF.getFunction().getContext();
-  if (!SpvType) {
-    const Type *LLVMFPTy = Type::getFloatTy(Ctx);
-    SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder,
-                                   SPIRV::AccessQualifier::ReadWrite, true);
-  }
-  // Find a constant in DT or build a new one.
-  const auto ConstFP = ConstantFP::get(Ctx, Val);
-  Register Res = DT.find(ConstFP, &MF);
-  if (!Res.isValid()) {
-    Res = MF.getRegInfo().createGenericVirtualRegister(
-        LLT::scalar(getScalarOrVectorBitWidth(SpvType)));
-    MF.getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
-    assignSPIRVTypeToVReg(SpvType, Res, MF);
-    DT.add(ConstFP, &MF, Res);
-    createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-      MachineInstrBuilder MIB;
-      MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
-                .addDef(Res)
-                .addUse(getSPIRVTypeID(SpvType));
-      addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
-      return MIB;
-    });
-  }
+  if (!SpvType)
+    SpvType = getOrCreateSPIRVType(
+        Type::getFloatTy(MF.getFunction().getContext()), MIRBuilder,
+        SPIRV::AccessQualifier::ReadWrite, true);
+  auto *const CF = ConstantFP::get(MF.getFunction().getContext(), Val);
+  Register Res = find(CF, &MF);
+  if (Res.isValid())
+    return Res;
 
+  LLT LLTy = LLT::scalar(getScalarOrVectorBitWidth(SpvType));
+  Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
+  MF.getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
+  assignSPIRVTypeToVReg(SpvType, Res, MF);
+
+  SPIRVType *NewType =
+      createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+        MachineInstrBuilder MIB;
+        MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
+                  .addDef(Res)
+                  .addUse(getSPIRVTypeID(SpvType));
+        addNumImm(CF->getValueAPF().bitcastToAPInt(), MIB);
+        return MIB;
+      });
+  add(CF, NewType);
   return Res;
 }
 
@@ -1044,13 +988,8 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
   if (isSpecialOpaqueType(Ty))
     return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
-  auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
-  auto t = TypeToSPIRVTypeMap.find(Ty);
-  if (t != TypeToSPIRVTypeMap.end()) {
-    auto tt = t->second.find(&MIRBuilder.getMF());
-    if (tt != t->second.end())
-      return getSPIRVTypeForVReg(tt->second);
-  }
+  if (Register TyReg = find(Ty, &MIRBuilder.getMF()); TyReg.isValid())
+    return getSPIRVTypeForVReg(TyReg);
 
   if (auto IType = dyn_cast<IntegerType>(Ty)) {
     const unsigned Width = IType->getBitWidth();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 467d9b73e2e39..8840401fe996d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -27,7 +27,7 @@ namespace llvm {
 class SPIRVSubtarget;
 using SPIRVType = const MachineInstr;
 
-class SPIRVGlobalRegistry {
+class SPIRVGlobalRegistry : public SPIRVIRMap {
   // Registers holding values which have types associated with them.
   // Initialized upon VReg definition in IRTranslator.
   // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -37,9 +37,6 @@ class SPIRVGlobalRegistry {
   DenseMap<const MachineFunction *, DenseMap<Register, SPIRVType *>>
       VRegToTypeMap;
 
-  // Map LLVM Type* to <MF, Reg>
-  SPIRVGeneralDuplicatesTracker DT;
-
   DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
 
   // map a Function to its definition (as a machine instruction operand)
@@ -119,42 +116,6 @@ class SPIRVGlobalRegistry {
 
   MachineFunction *CurMF;
 
-  void add(const Constant *C, MachineFunction *MF, Register R) {
-    DT.add(C, MF, R);
-  }
-
-  void add(const GlobalVariable *GV, MachineFunction *MF, Register R) {
-    DT.add(GV, MF, R);
-  }
-
-  void add(const Function *F, MachineFunction *MF, Register R) {
-    DT.add(F, MF, R);
-  }
-
-  void add(const Argument *Arg, MachineFunction *MF, Register R) {
-    DT.add(Arg, MF, R);
-  }
-
-  void add(const MachineInstr *MI, MachineFunction *MF, Register R) {
-    DT.add(MI, MF, R);
-  }
-
-  Register find(const MachineInstr *MI, MachineFunction *MF) {
-    return DT.find(MI, MF);
-  }
-
-  Register find(const Constant *C, MachineFunction *MF) {
-    return DT.find(C, MF);
-  }
-
-  Register find(const GlobalVariable *GV, MachineFunction *MF) {
-    return DT.find(GV, MF);
-  }
-
-  Register find(const Function *F, MachineFunction *MF) {
-    return DT.find(F, MF);
-  }
-
   void setBound(unsigned V) { Bound = V; }
   unsigned getBound() { return Bound; }
 



More information about the llvm-commits mailing list