[llvm] [SPIRV] Add vector reduction instructions (PR #82786)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 23 12:33:39 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR adds vector reduction instructions according to https://llvm.org/docs/GlobalISel/GenericOpcode.html#vector-reduction-operations and thereby widens the range of successfully supported conversions, covering new cases of vector reduction instructions that IRTranslator is unable to resolve.

Legalizing vector reduction instructions introduces new instruction patterns that must be addressed, including patterns whose handling is otherwise delegated to the pre-legalizer step. To address this, a new pass is added that brings the instructions newly generated by legalization into the form required by instruction selection.

I am marking this PR as a draft for now, until I add tests that cover the newly supported vector reduction instructions.

---

Patch is 147.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82786.diff


22 Files Affected:

- (modified) llvm/lib/Target/SPIRV/CMakeLists.txt (+1) 
- (modified) llvm/lib/Target/SPIRV/SPIRV.h (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+45-4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+23) 
- (added) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+170) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+49-49) 
- (modified) llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp (+1) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/add.ll (+233) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/and.ll (+233) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fadd.ll (+189) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fmax.ll (+177) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fmaximum.ll (+177) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fmin.ll (+176) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fminimum.ll (+177) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/fmul.ll (+189) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/mul.ll (+232) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/or.ll (+233) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/smax.ll (+233) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/smin.ll (+233) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/umax.ll (+233) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/umin.ll (+233) 
- (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/llvm-vector-reduce/xor.ll (+233) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index d1ada45d17a5bc..afc26dda4c68bd 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -29,6 +29,7 @@ add_llvm_target(SPIRVCodeGen
   SPIRVMetadata.cpp
   SPIRVModuleAnalysis.cpp
   SPIRVPreLegalizer.cpp
+  SPIRVPostLegalizer.cpp
   SPIRVPrepareFunctions.cpp
   SPIRVRegisterBankInfo.cpp
   SPIRVRegisterInfo.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 9460b0808cae89..6979107349d968 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -23,6 +23,7 @@ ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
 FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
 FunctionPass *createSPIRVRegularizerPass();
 FunctionPass *createSPIRVPreLegalizerPass();
+FunctionPass *createSPIRVPostLegalizerPass();
 FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
 InstructionSelector *
 createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
@@ -32,6 +33,7 @@ createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
 void initializeSPIRVModuleAnalysisPass(PassRegistry &);
 void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
 void initializeSPIRVPreLegalizerPass(PassRegistry &);
+void initializeSPIRVPostLegalizerPass(PassRegistry &);
 void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
 } // namespace llvm
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7258d3b4d88ed3..6987d54e2b176d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -183,6 +183,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectLog10(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
+  bool selectUnmergeValues(MachineInstr &I) const;
+
   Register buildI32Constant(uint32_t Val, MachineInstr &I,
                             const SPIRVType *ResType = nullptr) const;
 
@@ -235,7 +237,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
     if (Opcode == SPIRV::ASSIGN_TYPE) { // These pseudos aren't needed any more.
       auto *Def = MRI->getVRegDef(I.getOperand(1).getReg());
       if (isTypeFoldingSupported(Def->getOpcode())) {
-        auto Res = selectImpl(I, *CoverageInfo);
+        bool Res = selectImpl(I, *CoverageInfo);
         assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
         if (Res)
           return Res;
@@ -263,7 +265,8 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
   assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
   if (spvSelect(ResVReg, ResType, I)) {
     if (HasDefs) // Make all vregs 32 bits (for SPIR-V IDs).
-      MRI->setType(ResVReg, LLT::scalar(32));
+      for (unsigned i = 0; i < I.getNumDefs(); ++i)
+        MRI->setType(I.getOperand(i).getReg(), LLT::scalar(32));
     I.removeFromParent();
     return true;
   }
@@ -273,9 +276,9 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
 bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
                                          const SPIRVType *ResType,
                                          MachineInstr &I) const {
-  assert(!isTypeFoldingSupported(I.getOpcode()) ||
-         I.getOpcode() == TargetOpcode::G_CONSTANT);
   const unsigned Opcode = I.getOpcode();
+  if (isTypeFoldingSupported(Opcode) && Opcode != TargetOpcode::G_CONSTANT)
+    return selectImpl(I, *CoverageInfo);
   switch (Opcode) {
   case TargetOpcode::G_CONSTANT:
     return selectConst(ResVReg, ResType, I.getOperand(1).getCImm()->getValue(),
@@ -504,6 +507,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_FENCE:
     return selectFence(I);
 
+  case TargetOpcode::G_UNMERGE_VALUES:
+    return selectUnmergeValues(I);
+
   default:
     return false;
   }
@@ -733,6 +739,41 @@ bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
   return Result;
 }
 
+// Selects G_UNMERGE_VALUES of a vector as one OpCompositeExtract per result
+// (result i extracts lane i); succeeds only if every extract is constrained.
+bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
+  const MachineOperand &SrcOp = I.getOperand(I.getNumOperands() - 1);
+  Register SrcReg = SrcOp.isReg() ? SrcOp.getReg() : Register(0);
+  SPIRVType *DefType =
+      SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
+  if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+    report_fatal_error(
+        "cannot select G_UNMERGE_VALUES with a non-vector argument");
+  SPIRVType *ScalarType =
+      GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+  MachineBasicBlock &BB = *I.getParent();
+  bool Res = true;
+  for (unsigned i = 0; i < I.getNumDefs(); ++i) {
+    Register ResVReg = I.getOperand(i).getReg();
+    SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
+    if (!ResType) {
+      // Untyped result: give it the vector's component type.
+      ResType = ScalarType;
+      MRI->setRegClass(ResVReg, &SPIRV::IDRegClass);
+      MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
+      GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
+    }
+    auto MIB =
+        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+            .addDef(ResVReg)
+            .addUse(GR.getSPIRVTypeID(ResType))
+            .addUse(SrcReg)
+            .addImm(static_cast<int64_t>(i));
+    Res &= MIB.constrainAllUses(TII, TRI, RBI);
+  }
+  return Res;
+}
+
 bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {
   AtomicOrdering AO = AtomicOrdering(I.getOperand(0).getImm());
   uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 4f2e7a240fc2cc..c3f75463dfd23e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -113,6 +113,11 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64, v8s1, v8s8, v8s16,
       v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
 
+  auto allVectors = {v2s1,  v2s8,   v2s16,  v2s32, v2s64, v3s1,  v3s8,
+                     v3s16, v3s32,  v3s64,  v4s1,  v4s8,  v4s16, v4s32,
+                     v4s64, v8s1,   v8s8,   v8s16, v8s32, v8s64, v16s1,
+                     v16s8, v16s16, v16s32, v16s64};
+
   auto allScalarsAndVectors = {
       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
@@ -146,6 +151,24 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   // TODO: add proper rules for vectors legalization.
   getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
 
+  // Vector Reduction Operations
+  getActionDefinitionsBuilder(
+      {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
+       G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
+       G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
+       G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
+      .legalFor(allVectors)
+      .scalarize(1)
+      .lower();
+
+  getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
+      .scalarize(2)
+      .lower();
+
+  // Merge/Unmerge
+  // TODO: add proper legalization rules.
+  getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
+
   getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
       .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
new file mode 100644
index 00000000000000..da24c779ffe066
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -0,0 +1,194 @@
+//===-- SPIRVPostLegalizer.cpp - amend info after legalization --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The pass partially applies pre-legalization logic to new instructions
+// inserted as a result of legalization:
+// - assigns SPIR-V types to registers defined by new instructions;
+// - restores the ASSIGN_TYPE/GET_*ID pattern around newly inserted
+//   instructions whose opcodes support type folding.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVUtils.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Target/TargetIntrinsicInfo.h"
+
+#define DEBUG_TYPE "spirv-postlegalizer"
+
+using namespace llvm;
+
+namespace {
+class SPIRVPostLegalizer : public MachineFunctionPass {
+public:
+  static char ID;
+  SPIRVPostLegalizer() : MachineFunctionPass(ID) {
+    initializeSPIRVPostLegalizerPass(*PassRegistry::getPassRegistry());
+  }
+  bool runOnMachineFunction(MachineFunction &MF) override;
+};
+} // namespace
+
+// Defined in SPIRVLegalizerInfo.cpp.
+extern bool isTypeFoldingSupported(unsigned Opcode);
+
+namespace llvm {
+// Defined in SPIRVPreLegalizer.cpp.
+extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
+                                  SPIRVGlobalRegistry *GR,
+                                  MachineIRBuilder &MIB,
+                                  MachineRegisterInfo &MRI);
+extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
+                         MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR);
+} // namespace llvm
+
+// Returns true for the GET_*ID pseudos that processInstr() inserts right after
+// an instruction it has rewritten.
+static bool isMetaInstrGET(unsigned Opcode) {
+  return Opcode == SPIRV::GET_ID || Opcode == SPIRV::GET_fID ||
+         Opcode == SPIRV::GET_pID || Opcode == SPIRV::GET_vID ||
+         Opcode == SPIRV::GET_vfID;
+}
+
+// Returns true for opcodes the legalizer may insert (e.g. when lowering vector
+// reductions) and whose results therefore may still lack a SPIR-V type.
+static bool mayBeInserted(unsigned Opcode) {
+  switch (Opcode) {
+  case TargetOpcode::G_SMAX:
+  case TargetOpcode::G_UMAX:
+  case TargetOpcode::G_SMIN:
+  case TargetOpcode::G_UMIN:
+  case TargetOpcode::G_FMINNUM:
+  case TargetOpcode::G_FMINIMUM:
+  case TargetOpcode::G_FMAXNUM:
+  case TargetOpcode::G_FMAXIMUM:
+    return true;
+  default:
+    return isTypeFoldingSupported(Opcode);
+  }
+}
+
+// Deduces the SPIR-V type of the value defined by I from a source operand.
+// Operand 1 is used by default; G_SELECT holds its condition there, so its
+// first value operand (2) is used instead, and G_EXTRACT_VECTOR_ELT yields the
+// component type of its vector operand. Returns nullptr if no type is known.
+static SPIRVType *deduceResultType(MachineInstr &I, SPIRVGlobalRegistry *GR) {
+  const unsigned Opcode = I.getOpcode();
+  const MachineOperand &SrcOp =
+      I.getOperand(Opcode == TargetOpcode::G_SELECT ? 2 : 1);
+  SPIRVType *SrcType =
+      SrcOp.isReg() ? GR->getSPIRVTypeForVReg(SrcOp.getReg()) : nullptr;
+  if (SrcType && Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT &&
+      SrcType->getOpcode() == SPIRV::OpTypeVector)
+    return GR->getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
+  return SrcType;
+}
+
+// Assigns the vector component type to every result of a G_UNMERGE_VALUES
+// that has no SPIR-V type yet. The source operand must be a typed vector.
+static void processUnmergeValues(MachineInstr &I, MachineRegisterInfo &MRI,
+                                 SPIRVGlobalRegistry *GR) {
+  unsigned ArgI = I.getNumOperands() - 1;
+  Register SrcReg =
+      I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
+  SPIRVType *DefType =
+      SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
+  if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+    report_fatal_error(
+        "cannot select G_UNMERGE_VALUES with a non-vector argument");
+  // Operand 1 of OpTypeVector is its component type.
+  SPIRVType *ScalarType =
+      GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+  for (unsigned i = 0; i < I.getNumDefs(); ++i) {
+    Register ResVReg = I.getOperand(i).getReg();
+    // Registers that already have a SPIR-V type are left untouched.
+    if (GR->getSPIRVTypeForVReg(ResVReg))
+      continue;
+    MRI.setRegClass(ResVReg, &SPIRV::IDRegClass);
+    MRI.setType(ResVReg,
+                LLT::scalar(GR->getScalarOrVectorBitWidth(ScalarType)));
+    GR->assignSPIRVTypeToVReg(ScalarType, ResVReg, *GR->CurMF);
+  }
+}
+
+// Gives a SPIR-V type to the register defined by a single-def instruction
+// inserted by the legalizer and, for type-folding opcodes, rebuilds the
+// ASSIGN_TYPE/GET_*ID pattern the pre-legalizer creates for such instructions.
+static void processNewTypedInstr(MachineInstr &I, MachineIRBuilder &MIB,
+                                 MachineRegisterInfo &MRI,
+                                 SPIRVGlobalRegistry *GR) {
+  Register ResVReg = I.getOperand(0).getReg();
+  SPIRVType *ResVType = GR->getSPIRVTypeForVReg(ResVReg);
+  // An untyped result is a register newly introduced by the legalizer.
+  if (!ResVType) {
+    ResVType = deduceResultType(I, GR);
+    // Leave the instruction alone if its source operand is untyped as well.
+    if (!ResVType)
+      return;
+    MRI.setRegClass(ResVReg, &SPIRV::IDRegClass);
+    MRI.setType(ResVReg, LLT::scalar(GR->getScalarOrVectorBitWidth(ResVType)));
+    GR->assignSPIRVTypeToVReg(ResVType, ResVReg, *GR->CurMF);
+  }
+  if (!isTypeFoldingSupported(I.getOpcode()))
+    return;
+  // A GET_*ID right after the instruction means it was already processed.
+  MachineInstr *NextMI = I.getNextNode();
+  if (NextMI && isMetaInstrGET(NextMI->getOpcode()))
+    return;
+  // Restore the usual pattern: re-type the result as a 32-bit ID, wrap it in
+  // ASSIGN_TYPE, and feed the source operands through GET_*ID pseudos.
+  MRI.setRegClass(ResVReg, MRI.getType(ResVReg).isVector()
+                               ? &SPIRV::IDRegClass
+                               : &SPIRV::ANYIDRegClass);
+  MRI.setType(ResVReg, LLT::scalar(32));
+  insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
+  processInstr(I, MIB, MRI, GR);
+}
+
+// Visits every instruction and fixes up those that legalization has inserted.
+static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+                             MachineIRBuilder MIB) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &I : MBB) {
+      const unsigned Opcode = I.getOpcode();
+      if (Opcode == TargetOpcode::G_UNMERGE_VALUES)
+        processUnmergeValues(I, MRI, GR);
+      else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
+               I.getNumOperands() > 1 && I.getOperand(1).isReg())
+        processNewTypedInstr(I, MIB, MRI, GR);
+    }
+  }
+}
+
+bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
+  // Initialize the type registry.
+  const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
+  SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+  GR->setCurrentFunc(MF);
+  MachineIRBuilder MIB(MF);
+
+  processNewInstrs(MF, GR, MIB);
+
+  return true;
+}
+
+INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer", false,
+                false)
+
+char SPIRVPostLegalizer::ID = 0;
+
+FunctionPass *llvm::createSPIRVPostLegalizerPass() {
+  return new SPIRVPostLegalizer();
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 144216896eb68c..1e92e5ce264f04 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -212,6 +212,40 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
   return SpirvTy;
 }
 
+// Creates a new virtual register (with a matching register class) to hold the
+// ID form of ValReg, and returns it with the GET_*ID opcode that should define
+// it: GET_pID for pointers, GET_vID/GET_vfID for vectors and GET_ID/GET_fID
+// otherwise, using the "f" variants when ValReg's scalar or component SPIR-V
+// type is OpTypeFloat. ValReg must already have a SPIR-V type assigned.
+// NOTE(review): Opcode is currently unused.
+static std::pair<Register, unsigned>
+createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
+               const SPIRVGlobalRegistry &GR) {
+  LLT NewT = LLT::scalar(32);
+  SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
+  assert(SpvType && "VReg is expected to have SPIRV type");
+  bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
+  bool IsVectorFloat =
+      SpvType->getOpcode() == SPIRV::OpTypeVector &&
+      GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
+          SPIRV::OpTypeFloat;
+  IsFloat |= IsVectorFloat;
+  auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
+  auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
+  if (MRI.getType(ValReg).isPointer()) {
+    NewT = LLT::pointer(0, 32);
+    GetIdOp = SPIRV::GET_pID;
+    DstClass = &SPIRV::pIDRegClass;
+  } else if (MRI.getType(ValReg).isVector()) {
+    NewT = LLT::fixed_vector(2, NewT);
+    GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
+    DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
+  }
+  Register IdReg = MRI.createGenericVirtualRegister(NewT);
+  MRI.setRegClass(IdReg, DstClass);
+  return {IdReg, GetIdOp};
+}
+
 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
 // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
@@ -249,6 +283,32 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
   Def->getOperand(0).setReg(NewReg);
   return NewReg;
 }
+
+// Rewrites MI so that TableGen patterns can select it. The result of MI must
+// have exactly one use (expected to be its ASSIGN_TYPE); MI and that user are
+// switched to a new ID register. Every register use operand of MI is replaced
+// by the result of a GET_*ID instruction that reads the original register and
+// is inserted right after MI. Also used by SPIRVPostLegalizer.cpp.
+void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
+                  MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
+  unsigned Opc = MI.getOpcode();
+  assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
+  MachineInstr &AssignTypeInst =
+      *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
+  auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
+  AssignTypeInst.getOperand(1).setReg(NewReg);
+  MI.getOperand(0).setReg(NewReg);
+  MIB.setInsertPt(*MI.getParent(),
+                  (MI.getNextNode() ? MI.getNextNode()->getIterator()
+                                    : MI.getParent()->end()));
+  for (auto &Op : MI.operands()) {
+    if (!Op.isReg() || Op.isDef())
+      continue;
+    auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
+    MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
+    Op.setReg(IdOpInfo.first);
+  }
+}
 } // namespace llvm
 
 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
@@ -345,55 +405,6 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
     MI->eraseFromParent();
 }
 
-static std::pair<Register, unsigned>
-createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
-               const SPIRVGlobalRegistry &GR) {
-  LLT NewT = LLT::scalar(32);
-  SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
-  assert(SpvType && "VReg is expected to have SPIRV type");
-  bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
-  bool IsVectorFloat =
-      SpvType->getOpcode() == SPIRV::OpTypeVector &&
-      GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
-          SPIRV::OpTypeFloat;
-  IsFloat |= IsVectorFloat;
-  auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
-  auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
-  if (MRI.getType(ValReg).isPointer()) {
-    NewT = LLT::pointer(0, 32);
-    GetIdOp = SPIRV::GET_pID;
-    DstClass = &SPIRV::pIDRegClass;
-  } else if (MRI.getType(ValReg).isVector()) {
-    NewT = LLT::fixed_vector(2, NewT);
-    GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
-    DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
-  }
-  Register IdReg = MRI.createGenericVirtualRegister(NewT);
-  MRI.setRegClass(IdReg, DstClass);
-  return {IdReg, GetIdOp};
-}
-
-static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
-                         MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
-  unsigned Opc = MI.getOpcode();
-  assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
-  MachineInstr &AssignTypeInst =
-      *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
-  auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
-  AssignTypeInst.getOperand(1).setReg(NewReg);
-  MI.getOperand(0).setReg(NewReg);
-  MIB.setInsertPt(*MI.getParent(),
-                  (MI.getNextNode() ? MI.getNextNode()->getIterator()
-                                    : MI.getParent()->end()));
-  for (auto &Op : MI.operands()) {
-    if (!Op.isReg() || Op.isDef())
-      continue;
-    auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
-    MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
-    Op.setReg(IdOpInfo.first);
-  }
-}
-
 // Defined in SPIRVLegalizerInfo.cpp.
 extern bool isTypeFoldingSupported(unsigned Opcode);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index e1b7bdd3140dbe..fbf64f2b1dfb13 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -189,6 +189,7 @@ void SPIRVPassCon...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/82786


More information about the llvm-commits mailing list