[llvm] 74c6671 - [SPIRV] fix several issues in builds with expensive checks

Ilia Diachkov via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 16 13:19:30 PDT 2023


Author: Ilia Diachkov
Date: 2023-03-17T00:08:23+03:00
New Revision: 74c66710a79edb3a0d380079c1c5c82fa441a8e8

URL: https://github.com/llvm/llvm-project/commit/74c66710a79edb3a0d380079c1c5c82fa441a8e8
DIFF: https://github.com/llvm/llvm-project/commit/74c66710a79edb3a0d380079c1c5c82fa441a8e8.diff

LOG: [SPIRV] fix several issues in builds with expensive checks

The patch fixes "Virtual register does not match instruction constraint"
failures, and partly fixes "Illegal virtual register for instruction"
failures, in SPIRV backend builds with LLVM_ENABLE_EXPENSIVE_CHECKS
enabled. As a result, the number of LIT tests that pass with expensive
checks enabled is doubled.

Also, lowering of the ndrange_*D builtins is moved into a separate function (buildNDRange).

Differential Revision: https://reviews.llvm.org/D144897

Added: 
    

Modified: 
    llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
    llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
    llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
    llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c11b36a088545..40b652057e87f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -291,6 +291,7 @@ buildBoolRegister(MachineIRBuilder &MIRBuilder, const SPIRVType *ResultType,
 
   Register ResultRegister =
       MIRBuilder.getMRI()->createGenericVirtualRegister(Type);
+  MIRBuilder.getMRI()->setRegClass(ResultRegister, &SPIRV::IDRegClass);
   GR->assignSPIRVTypeToVReg(BoolType, ResultRegister, MIRBuilder.getMF());
   return std::make_tuple(ResultRegister, BoolType);
 }
@@ -417,33 +418,41 @@ static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
 }
 
 static Register buildScopeReg(Register CLScopeRegister,
+                              SPIRV::Scope::Scope Scope,
                               MachineIRBuilder &MIRBuilder,
                               SPIRVGlobalRegistry *GR,
-                              const MachineRegisterInfo *MRI) {
-  auto CLScope =
-      static_cast<SPIRV::CLMemoryScope>(getIConstVal(CLScopeRegister, MRI));
-  SPIRV::Scope::Scope Scope = getSPIRVScope(CLScope);
-
-  if (CLScope == static_cast<unsigned>(Scope))
-    return CLScopeRegister;
-
+                              MachineRegisterInfo *MRI) {
+  if (CLScopeRegister.isValid()) {
+    auto CLScope =
+        static_cast<SPIRV::CLMemoryScope>(getIConstVal(CLScopeRegister, MRI));
+    Scope = getSPIRVScope(CLScope);
+
+    if (CLScope == static_cast<unsigned>(Scope)) {
+      MRI->setRegClass(CLScopeRegister, &SPIRV::IDRegClass);
+      return CLScopeRegister;
+    }
+  }
   return buildConstantIntReg(Scope, MIRBuilder, GR);
 }
 
 static Register buildMemSemanticsReg(Register SemanticsRegister,
-                                     Register PtrRegister,
-                                     const MachineRegisterInfo *MRI,
+                                     Register PtrRegister, unsigned &Semantics,
+                                     MachineIRBuilder &MIRBuilder,
                                      SPIRVGlobalRegistry *GR) {
-  std::memory_order Order =
-      static_cast<std::memory_order>(getIConstVal(SemanticsRegister, MRI));
-  unsigned Semantics =
-      getSPIRVMemSemantics(Order) |
-      getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
-
-  if (Order == Semantics)
-    return SemanticsRegister;
+  if (SemanticsRegister.isValid()) {
+    MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+    std::memory_order Order =
+        static_cast<std::memory_order>(getIConstVal(SemanticsRegister, MRI));
+    Semantics =
+        getSPIRVMemSemantics(Order) |
+        getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
 
-  return Register();
+    if (Order == Semantics) {
+      MRI->setRegClass(SemanticsRegister, &SPIRV::IDRegClass);
+      return SemanticsRegister;
+    }
+  }
+  return buildConstantIntReg(Semantics, MIRBuilder, GR);
 }
 
 /// Helper function for translating atomic init to OpStore.
@@ -451,7 +460,8 @@ static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder) {
   assert(Call->Arguments.size() == 2 &&
          "Need 2 arguments for atomic init translation");
-
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(SPIRV::OpStore)
       .addUse(Call->Arguments[0])
       .addUse(Call->Arguments[1]);
@@ -463,19 +473,22 @@ static bool buildAtomicLoadInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder,
                                 SPIRVGlobalRegistry *GR) {
   Register PtrRegister = Call->Arguments[0];
+  MIRBuilder.getMRI()->setRegClass(PtrRegister, &SPIRV::IDRegClass);
   // TODO: if true insert call to __translate_ocl_memory_sccope before
   // OpAtomicLoad and the function implementation. We can use Translator's
   // output for transcoding/atomic_explicit_arguments.cl as an example.
   Register ScopeRegister;
-  if (Call->Arguments.size() > 1)
+  if (Call->Arguments.size() > 1) {
     ScopeRegister = Call->Arguments[1];
-  else
+    MIRBuilder.getMRI()->setRegClass(ScopeRegister, &SPIRV::IDRegClass);
+  } else
     ScopeRegister = buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
 
   Register MemSemanticsReg;
   if (Call->Arguments.size() > 2) {
     // TODO: Insert call to __translate_ocl_memory_order before OpAtomicLoad.
     MemSemanticsReg = Call->Arguments[2];
+    MIRBuilder.getMRI()->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
   } else {
     int Semantics =
         SPIRV::MemorySemantics::SequentiallyConsistent |
@@ -499,11 +512,12 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
   Register ScopeRegister =
       buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
   Register PtrRegister = Call->Arguments[0];
+  MIRBuilder.getMRI()->setRegClass(PtrRegister, &SPIRV::IDRegClass);
   int Semantics =
       SPIRV::MemorySemantics::SequentiallyConsistent |
       getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
   Register MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
-
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(SPIRV::OpAtomicStore)
       .addUse(PtrRegister)
       .addUse(ScopeRegister)
@@ -525,6 +539,9 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
   Register ObjectPtr = Call->Arguments[0];   // Pointer (volatile A *object.)
   Register ExpectedArg = Call->Arguments[1]; // Comparator (C* expected).
   Register Desired = Call->Arguments[2];     // Value (C Desired).
+  MRI->setRegClass(ObjectPtr, &SPIRV::IDRegClass);
+  MRI->setRegClass(ExpectedArg, &SPIRV::IDRegClass);
+  MRI->setRegClass(Desired, &SPIRV::IDRegClass);
   SPIRVType *SpvDesiredTy = GR->getSPIRVTypeForVReg(Desired);
   LLT DesiredLLT = MRI->getType(Desired);
 
@@ -564,6 +581,8 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
       MemSemEqualReg = Call->Arguments[3];
     if (MemOrdNeq == MemSemEqual)
       MemSemUnequalReg = Call->Arguments[4];
+    MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
+    MRI->setRegClass(Call->Arguments[4], &SPIRV::IDRegClass);
   }
   if (!MemSemEqualReg.isValid())
     MemSemEqualReg = buildConstantIntReg(MemSemEqual, MIRBuilder, GR);
@@ -580,6 +599,7 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
     Scope = getSPIRVScope(ClScope);
     if (ClScope == static_cast<unsigned>(Scope))
       ScopeReg = Call->Arguments[5];
+    MRI->setRegClass(Call->Arguments[5], &SPIRV::IDRegClass);
   }
   if (!ScopeReg.isValid())
     ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);
@@ -591,6 +611,8 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
   MRI->setType(Expected, DesiredLLT);
   Register Tmp = !IsCmpxchg ? MRI->createGenericVirtualRegister(DesiredLLT)
                             : Call->ReturnRegister;
+  if (!MRI->getRegClassOrNull(Tmp))
+    MRI->setRegClass(Tmp, &SPIRV::IDRegClass);
   GR->assignSPIRVTypeToVReg(SpvDesiredTy, Tmp, MIRBuilder.getMF());
 
   SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
@@ -614,30 +636,23 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
 static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
                                MachineIRBuilder &MIRBuilder,
                                SPIRVGlobalRegistry *GR) {
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-  SPIRV::Scope::Scope Scope = SPIRV::Scope::Workgroup;
-  Register ScopeRegister;
-
-  if (Call->Arguments.size() >= 4) {
-    assert(Call->Arguments.size() == 4 &&
-           "Too many args for explicit atomic RMW");
-    ScopeRegister = buildScopeReg(Call->Arguments[3], MIRBuilder, GR, MRI);
-  }
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  Register ScopeRegister =
+      Call->Arguments.size() >= 4 ? Call->Arguments[3] : Register();
 
-  if (!ScopeRegister.isValid())
-    ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
+  assert(Call->Arguments.size() <= 4 &&
+         "Too many args for explicit atomic RMW");
+  ScopeRegister = buildScopeReg(ScopeRegister, SPIRV::Scope::Workgroup,
+                                MIRBuilder, GR, MRI);
 
   Register PtrRegister = Call->Arguments[0];
   unsigned Semantics = SPIRV::MemorySemantics::None;
-  Register MemSemanticsReg;
-
-  if (Call->Arguments.size() >= 3)
-    MemSemanticsReg =
-        buildMemSemanticsReg(Call->Arguments[2], PtrRegister, MRI, GR);
-
-  if (!MemSemanticsReg.isValid())
-    MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
-
+  MRI->setRegClass(PtrRegister, &SPIRV::IDRegClass);
+  Register MemSemanticsReg =
+      Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
+  MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
+                                         Semantics, MIRBuilder, GR);
+  MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(Opcode)
       .addDef(Call->ReturnRegister)
       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
@@ -653,32 +668,23 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
 static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
                                 unsigned Opcode, MachineIRBuilder &MIRBuilder,
                                 SPIRVGlobalRegistry *GR) {
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register PtrRegister = Call->Arguments[0];
   unsigned Semantics = SPIRV::MemorySemantics::SequentiallyConsistent;
-  Register MemSemanticsReg;
-
-  if (Call->Arguments.size() >= 2)
-    MemSemanticsReg =
-        buildMemSemanticsReg(Call->Arguments[1], PtrRegister, MRI, GR);
-
-  if (!MemSemanticsReg.isValid())
-    MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
+  Register MemSemanticsReg =
+      Call->Arguments.size() >= 2 ? Call->Arguments[1] : Register();
+  MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
+                                         Semantics, MIRBuilder, GR);
 
   assert((Opcode != SPIRV::OpAtomicFlagClear ||
           (Semantics != SPIRV::MemorySemantics::Acquire &&
            Semantics != SPIRV::MemorySemantics::AcquireRelease)) &&
          "Invalid memory order argument!");
 
-  SPIRV::Scope::Scope Scope = SPIRV::Scope::Device;
-  Register ScopeRegister;
-
-  if (Call->Arguments.size() >= 3)
-    ScopeRegister = buildScopeReg(Call->Arguments[2], MIRBuilder, GR, MRI);
-
-  if (!ScopeRegister.isValid())
-    ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
+  Register ScopeRegister =
+      Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
+  ScopeRegister =
+      buildScopeReg(ScopeRegister, SPIRV::Scope::Device, MIRBuilder, GR, MRI);
 
   auto MIB = MIRBuilder.buildInstr(Opcode);
   if (Opcode == SPIRV::OpAtomicFlagTestAndSet)
@@ -694,7 +700,7 @@ static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
 static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
                              MachineIRBuilder &MIRBuilder,
                              SPIRVGlobalRegistry *GR) {
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
   unsigned MemSemantics = SPIRV::MemorySemantics::None;
 
@@ -716,9 +722,10 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
   }
 
   Register MemSemanticsReg;
-  if (MemFlags == MemSemantics)
+  if (MemFlags == MemSemantics) {
     MemSemanticsReg = Call->Arguments[0];
-  else
+    MRI->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
+  } else
     MemSemanticsReg = buildConstantIntReg(MemSemantics, MIRBuilder, GR);
 
   Register ScopeReg;
@@ -738,8 +745,10 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
         (Opcode == SPIRV::OpMemoryBarrier))
       Scope = MemScope;
 
-    if (CLScope == static_cast<unsigned>(Scope))
+    if (CLScope == static_cast<unsigned>(Scope)) {
       ScopeReg = Call->Arguments[1];
+      MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
+    }
   }
 
   if (!ScopeReg.isValid())
@@ -834,7 +843,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
   const SPIRV::GroupBuiltin *GroupBuiltin =
       SPIRV::lookupGroupBuiltin(Builtin->Name);
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register Arg0;
   if (GroupBuiltin->HasBoolArg) {
     Register ConstRegister = Call->Arguments[0];
@@ -876,8 +885,11 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
     MIB.addImm(GroupBuiltin->GroupOperation);
   if (Call->Arguments.size() > 0) {
     MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
-    for (unsigned i = 1; i < Call->Arguments.size(); i++)
+    MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    for (unsigned i = 1; i < Call->Arguments.size(); i++) {
       MIB.addUse(Call->Arguments[i]);
+      MRI->setRegClass(Call->Arguments[i], &SPIRV::IDRegClass);
+    }
   }
 
   // Build select instruction.
@@ -936,16 +948,17 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
   // If it's out of range (max dimension is 3), we can just return the constant
   // default value (0 or 1 depending on which query function).
   if (IsConstantIndex && getIConstVal(IndexRegister, MRI) >= 3) {
-    Register defaultReg = Call->ReturnRegister;
+    Register DefaultReg = Call->ReturnRegister;
     if (PointerSize != ResultWidth) {
-      defaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
-      GR->assignSPIRVTypeToVReg(PointerSizeType, defaultReg,
+      DefaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+      MRI->setRegClass(DefaultReg, &SPIRV::IDRegClass);
+      GR->assignSPIRVTypeToVReg(PointerSizeType, DefaultReg,
                                 MIRBuilder.getMF());
-      ToTruncate = defaultReg;
+      ToTruncate = DefaultReg;
     }
     auto NewRegister =
         GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
-    MIRBuilder.buildCopy(defaultReg, NewRegister);
+    MIRBuilder.buildCopy(DefaultReg, NewRegister);
   } else { // If it could be in range, we need to load from the given builtin.
     auto Vec3Ty =
         GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder);
@@ -956,6 +969,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
     Register Extracted = Call->ReturnRegister;
     if (!IsConstantIndex || PointerSize != ResultWidth) {
       Extracted = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+      MRI->setRegClass(Extracted, &SPIRV::IDRegClass);
       GR->assignSPIRVTypeToVReg(PointerSizeType, Extracted, MIRBuilder.getMF());
     }
     // Use Intrinsic::spv_extractelt so dynamic vs static extraction is
@@ -974,6 +988,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
 
       Register CompareRegister =
           MRI->createGenericVirtualRegister(LLT::scalar(1));
+      MRI->setRegClass(CompareRegister, &SPIRV::IDRegClass);
       GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());
 
       // Use G_ICMP to check if idxVReg < 3.
@@ -990,6 +1005,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
       if (PointerSize != ResultWidth) {
         SelectionResult =
             MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+        MRI->setRegClass(SelectionResult, &SPIRV::IDRegClass);
         GR->assignSPIRVTypeToVReg(PointerSizeType, SelectionResult,
                                   MIRBuilder.getMF());
       }
@@ -1125,6 +1141,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
   if (NumExpectedRetComponents != NumActualRetComponents) {
     QueryResult = MIRBuilder.getMRI()->createGenericVirtualRegister(
         LLT::fixed_vector(NumActualRetComponents, 32));
+    MIRBuilder.getMRI()->setRegClass(QueryResult, &SPIRV::IDRegClass);
     SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
     QueryResultType = GR->getOrCreateSPIRVVectorType(
         IntTy, NumActualRetComponents, MIRBuilder);
@@ -1133,6 +1150,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
   bool IsDimBuf = ImgType->getOperand(2).getImm() == SPIRV::Dim::DIM_Buffer;
   unsigned Opcode =
       IsDimBuf ? SPIRV::OpImageQuerySize : SPIRV::OpImageQuerySizeLod;
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
   auto MIB = MIRBuilder.buildInstr(Opcode)
                  .addDef(QueryResult)
                  .addUse(GR->getSPIRVTypeID(QueryResultType))
@@ -1177,6 +1195,7 @@ static bool generateImageMiscQueryInst(const SPIRV::IncomingCall *Call,
       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
 
   Register Image = Call->Arguments[0];
+  MIRBuilder.getMRI()->setRegClass(Image, &SPIRV::IDRegClass);
   SPIRV::Dim::Dim ImageDimensionality = static_cast<SPIRV::Dim::Dim>(
       GR->getSPIRVTypeForVReg(Image)->getOperand(2).getImm());
 
@@ -1239,8 +1258,13 @@ static bool generateReadImageInst(const StringRef DemangledCall,
                                   SPIRVGlobalRegistry *GR) {
   Register Image = Call->Arguments[0];
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-
-  if (DemangledCall.contains_insensitive("ocl_sampler")) {
+  MRI->setRegClass(Image, &SPIRV::IDRegClass);
+  MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+  bool HasOclSampler = DemangledCall.contains_insensitive("ocl_sampler");
+  bool HasMsaa = DemangledCall.contains_insensitive("msaa");
+  if (HasOclSampler || HasMsaa)
+    MRI->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
+  if (HasOclSampler) {
     Register Sampler = Call->Arguments[1];
 
     if (!GR->isScalarOfType(Sampler, SPIRV::OpTypeSampler) &&
@@ -1274,6 +1298,7 @@ static bool generateReadImageInst(const StringRef DemangledCall,
     }
     LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(TempType));
     Register TempRegister = MRI->createGenericVirtualRegister(LLType);
+    MRI->setRegClass(TempRegister, &SPIRV::IDRegClass);
     GR->assignSPIRVTypeToVReg(TempType, TempRegister, MIRBuilder.getMF());
 
     MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
@@ -1290,7 +1315,7 @@ static bool generateReadImageInst(const StringRef DemangledCall,
           .addUse(GR->getSPIRVTypeID(Call->ReturnType))
           .addUse(TempRegister)
           .addImm(0);
-  } else if (DemangledCall.contains_insensitive("msaa")) {
+  } else if (HasMsaa) {
     MIRBuilder.buildInstr(SPIRV::OpImageRead)
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
@@ -1311,6 +1336,9 @@ static bool generateReadImageInst(const StringRef DemangledCall,
 static bool generateWriteImageInst(const SPIRV::IncomingCall *Call,
                                    MachineIRBuilder &MIRBuilder,
                                    SPIRVGlobalRegistry *GR) {
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(SPIRV::OpImageWrite)
       .addUse(Call->Arguments[0])  // Image.
       .addUse(Call->Arguments[1])  // Coordinate.
@@ -1322,10 +1350,11 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
                                     const SPIRV::IncomingCall *Call,
                                     MachineIRBuilder &MIRBuilder,
                                     SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   if (Call->Builtin->Name.contains_insensitive(
           "__translate_sampler_initializer")) {
     // Build sampler literal.
-    uint64_t Bitmask = getIConstVal(Call->Arguments[0], MIRBuilder.getMRI());
+    uint64_t Bitmask = getIConstVal(Call->Arguments[0], MRI);
     Register Sampler = GR->buildConstantSampler(
         Call->ReturnRegister, getSamplerAddressingModeFromBitmask(Bitmask),
         getSamplerParamFromBitmask(Bitmask),
@@ -1340,7 +1369,7 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
     Register SampledImage =
         Call->ReturnRegister.isValid()
             ? Call->ReturnRegister
-            : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
+            : MRI->createVirtualRegister(&SPIRV::IDRegClass);
     MIRBuilder.buildInstr(SPIRV::OpSampledImage)
         .addDef(SampledImage)
         .addUse(GR->getSPIRVTypeID(SampledImageType))
@@ -1356,6 +1385,10 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
       ReturnType = ReturnType.substr(0, ReturnType.find('('));
     }
     SPIRVType *Type = GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
+    MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+    MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
+
     MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Type))
@@ -1431,6 +1464,75 @@ static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
   }
 }
 
+static bool buildNDRange(const SPIRV::IncomingCall *Call,
+                         MachineIRBuilder &MIRBuilder,
+                         SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+  SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
+  assert(PtrType->getOpcode() == SPIRV::OpTypePointer &&
+         PtrType->getOperand(2).isReg());
+  Register TypeReg = PtrType->getOperand(2).getReg();
+  SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg);
+  MachineFunction &MF = MIRBuilder.getMF();
+  Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  GR->assignSPIRVTypeToVReg(StructType, TmpReg, MF);
+  // Skip the first arg, it's the destination pointer. OpBuildNDRange takes
+  // three other arguments, so pass zero constant on absence.
+  unsigned NumArgs = Call->Arguments.size();
+  assert(NumArgs >= 2);
+  Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2];
+  MRI->setRegClass(GlobalWorkSize, &SPIRV::IDRegClass);
+  Register LocalWorkSize =
+      NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3];
+  if (LocalWorkSize.isValid())
+    MRI->setRegClass(LocalWorkSize, &SPIRV::IDRegClass);
+  Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1];
+  if (GlobalWorkOffset.isValid())
+    MRI->setRegClass(GlobalWorkOffset, &SPIRV::IDRegClass);
+  if (NumArgs < 4) {
+    Register Const;
+    SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize);
+    if (SpvTy->getOpcode() == SPIRV::OpTypePointer) {
+      MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize);
+      assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) &&
+             DefInstr->getOperand(3).isReg());
+      Register GWSPtr = DefInstr->getOperand(3).getReg();
+      if (!MRI->getRegClassOrNull(GWSPtr))
+        MRI->setRegClass(GWSPtr, &SPIRV::IDRegClass);
+      // TODO: Maybe simplify generation of the type of the fields.
+      unsigned Size = Call->Builtin->Name.equals("ndrange_3D") ? 3 : 2;
+      unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
+      Type *BaseTy = IntegerType::get(MF.getFunction().getContext(), BitWidth);
+      Type *FieldTy = ArrayType::get(BaseTy, Size);
+      SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder);
+      GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+      GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize, MF);
+      MIRBuilder.buildInstr(SPIRV::OpLoad)
+          .addDef(GlobalWorkSize)
+          .addUse(GR->getSPIRVTypeID(SpvFieldTy))
+          .addUse(GWSPtr);
+      Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
+    } else {
+      Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
+    }
+    if (!LocalWorkSize.isValid())
+      LocalWorkSize = Const;
+    if (!GlobalWorkOffset.isValid())
+      GlobalWorkOffset = Const;
+  }
+  assert(LocalWorkSize.isValid() && GlobalWorkOffset.isValid());
+  MIRBuilder.buildInstr(SPIRV::OpBuildNDRange)
+      .addDef(TmpReg)
+      .addUse(TypeReg)
+      .addUse(GlobalWorkSize)
+      .addUse(LocalWorkSize)
+      .addUse(GlobalWorkOffset);
+  return MIRBuilder.buildInstr(SPIRV::OpStore)
+      .addUse(Call->Arguments[0])
+      .addUse(TmpReg);
+}
+
 static MachineInstr *getBlockStructInstr(Register ParamReg,
                                          MachineRegisterInfo *MRI) {
   // We expect the following sequence of instructions:
@@ -1538,9 +1640,8 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
     const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
         Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
     for (unsigned I = 0; I < LocalSizeNum; ++I) {
-      Register Reg =
-          MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
-      MIRBuilder.getMRI()->setType(Reg, LLType);
+      Register Reg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+      MRI->setType(Reg, LLType);
       GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
       auto GEPInst = MIRBuilder.buildIntrinsic(Intrinsic::spv_gep,
                                                ArrayRef<Register>{Reg}, true);
@@ -1605,6 +1706,7 @@ static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
   switch (Opcode) {
   case SPIRV::OpRetainEvent:
   case SPIRV::OpReleaseEvent:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode).addUse(Call->Arguments[0]);
   case SPIRV::OpCreateUserEvent:
   case SPIRV::OpGetDefaultQueue:
@@ -1612,77 +1714,27 @@ static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Call->ReturnType));
   case SPIRV::OpIsValidEvent:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode)
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
         .addUse(Call->Arguments[0]);
   case SPIRV::OpSetUserEventStatus:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode)
         .addUse(Call->Arguments[0])
         .addUse(Call->Arguments[1]);
   case SPIRV::OpCaptureEventProfilingInfo:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode)
         .addUse(Call->Arguments[0])
         .addUse(Call->Arguments[1])
         .addUse(Call->Arguments[2]);
-  case SPIRV::OpBuildNDRange: {
-    MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-    SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
-    assert(PtrType->getOpcode() == SPIRV::OpTypePointer &&
-           PtrType->getOperand(2).isReg());
-    Register TypeReg = PtrType->getOperand(2).getReg();
-    SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg);
-    Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    GR->assignSPIRVTypeToVReg(StructType, TmpReg, MIRBuilder.getMF());
-    // Skip the first arg, it's the destination pointer. OpBuildNDRange takes
-    // three other arguments, so pass zero constant on absence.
-    unsigned NumArgs = Call->Arguments.size();
-    assert(NumArgs >= 2);
-    Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2];
-    Register LocalWorkSize =
-        NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3];
-    Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1];
-    if (NumArgs < 4) {
-      Register Const;
-      SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize);
-      if (SpvTy->getOpcode() == SPIRV::OpTypePointer) {
-        MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize);
-        assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) &&
-               DefInstr->getOperand(3).isReg());
-        Register GWSPtr = DefInstr->getOperand(3).getReg();
-        // TODO: Maybe simplify generation of the type of the fields.
-        unsigned Size = Call->Builtin->Name.equals("ndrange_3D") ? 3 : 2;
-        unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
-        Type *BaseTy = IntegerType::get(
-            MIRBuilder.getMF().getFunction().getContext(), BitWidth);
-        Type *FieldTy = ArrayType::get(BaseTy, Size);
-        SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder);
-        GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-        GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize,
-                                  MIRBuilder.getMF());
-        MIRBuilder.buildInstr(SPIRV::OpLoad)
-            .addDef(GlobalWorkSize)
-            .addUse(GR->getSPIRVTypeID(SpvFieldTy))
-            .addUse(GWSPtr);
-        Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
-      } else {
-        Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
-      }
-      if (!LocalWorkSize.isValid())
-        LocalWorkSize = Const;
-      if (!GlobalWorkOffset.isValid())
-        GlobalWorkOffset = Const;
-    }
-    MIRBuilder.buildInstr(Opcode)
-        .addDef(TmpReg)
-        .addUse(TypeReg)
-        .addUse(GlobalWorkSize)
-        .addUse(LocalWorkSize)
-        .addUse(GlobalWorkOffset);
-    return MIRBuilder.buildInstr(SPIRV::OpStore)
-        .addUse(Call->Arguments[0])
-        .addUse(TmpReg);
-  }
+  case SPIRV::OpBuildNDRange:
+    return buildNDRange(Call, MIRBuilder, GR);
   case SPIRV::OpEnqueueKernel:
     return buildEnqueueKernel(Call, MIRBuilder, GR);
   default:
@@ -1817,16 +1869,23 @@ static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
   }
   // Add a pointer to the value to load/store.
   MIB.addUse(Call->Arguments[0]);
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
   // Add a value to store.
-  if (!IsLoad)
+  if (!IsLoad) {
     MIB.addUse(Call->Arguments[1]);
+    MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+  }
   // Add optional memory attributes and an alignment.
-  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   unsigned NumArgs = Call->Arguments.size();
-  if ((IsLoad && NumArgs >= 2) || NumArgs >= 3)
+  if ((IsLoad && NumArgs >= 2) || NumArgs >= 3) {
     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 1 : 2], MRI));
-  if ((IsLoad && NumArgs >= 3) || NumArgs >= 4)
+    MRI->setRegClass(Call->Arguments[IsLoad ? 1 : 2], &SPIRV::IDRegClass);
+  }
+  if ((IsLoad && NumArgs >= 3) || NumArgs >= 4) {
     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 2 : 3], MRI));
+    MRI->setRegClass(Call->Arguments[IsLoad ? 2 : 3], &SPIRV::IDRegClass);
+  }
   return true;
 }
 
@@ -1846,6 +1905,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
   SPIRVType *ReturnType = nullptr;
   if (OrigRetTy && !OrigRetTy->isVoidTy()) {
     ReturnType = GR->assignTypeToVReg(OrigRetTy, OrigRet, MIRBuilder);
+    if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
+      MIRBuilder.getMRI()->setRegClass(ReturnRegister, &SPIRV::IDRegClass);
   } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
     ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
     MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(32));

diff  --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 8b618686ee7da..47b25a1f83515 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -374,6 +374,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     FTy = getOriginalFunctionType(*CF);
   }
 
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register ResVReg =
       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
   std::string FuncName = Info.Callee.getGlobal()->getName().str();
@@ -410,8 +411,9 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     for (const Argument &Arg : CF->args()) {
       if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
         continue; // Don't handle zero sized types.
-      ToInsert.push_back(
-          {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))});
+      Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+      MRI->setRegClass(Reg, &SPIRV::IDRegClass);
+      ToInsert.push_back({Reg});
       VRegArgs.push_back(ToInsert.back());
     }
     // TODO: Reuse FunctionLoweringInfo

diff  --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 062188abbf5e8..c77a7f860eda2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -143,6 +143,7 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
     LLT LLTy = LLT::scalar(32);
     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     if (MIRBuilder)
       assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
     else
@@ -202,6 +203,7 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
     LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
     Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
+    MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
                      SPIRV::AccessQualifier::ReadWrite, EmitIR);
     DT.add(ConstInt, &MIRBuilder.getMF(), Res);
@@ -247,6 +249,7 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
   if (!Res.isValid()) {
     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
     Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
+    MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
     DT.add(ConstFP, &MF, Res);
     MIRBuilder.buildFConstant(Res, *ConstFP);
@@ -272,6 +275,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
     LLT LLTy = LLT::scalar(32);
     Register SpvVecConst =
         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
     DT.add(CA, CurMF, SpvVecConst);
     MachineInstrBuilder MIB;
@@ -343,6 +347,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
     LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
     Register SpvVecConst =
         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
     DT.add(CA, CurMF, SpvVecConst);
     if (EmitIR) {
@@ -411,6 +416,7 @@ SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
   if (!Res.isValid()) {
     LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
     MIRBuilder.buildInstr(SPIRV::OpConstantNull)
         .addDef(Res)
@@ -1090,6 +1096,7 @@ Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
     return Res;
   LLT LLTy = LLT::scalar(32);
   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+  CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
   DT.add(UV, CurMF, Res);
 

diff  --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 27d0e8a976f0d..2818329ece3cb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -85,6 +85,9 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
     Register Reg = MI->getOperand(2).getReg();
     if (RegsAlreadyAddedToDT.find(MI) != RegsAlreadyAddedToDT.end())
       Reg = RegsAlreadyAddedToDT[MI];
+    auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
+    if (!MRI.getRegClassOrNull(Reg) && RC)
+      MRI.setRegClass(Reg, RC);
     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
     MI->eraseFromParent();
   }
@@ -201,8 +204,12 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
                                       : Def->getParent()->end()));
   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
-  if (auto *RC = MRI.getRegClassOrNull(Reg))
+  if (auto *RC = MRI.getRegClassOrNull(Reg)) {
     MRI.setRegClass(NewReg, RC);
+  } else {
+    MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
+    MRI.setRegClass(Reg, &SPIRV::IDRegClass);
+  }
   SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
   // This is to make it convenient for Legalizer to get the SPIRVType
@@ -217,7 +224,6 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
       .addUse(GR->getSPIRVTypeID(SpirvTy))
       .setMIFlags(Flags);
   Def->getOperand(0).setReg(NewReg);
-  MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass);
   return NewReg;
 }
 } // namespace llvm


        


More information about the llvm-commits mailing list