[llvm] [SPIRV] Use a worklist in the post-legalizer (PR #165027)

Steven Perron via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 28 10:01:27 PDT 2025


https://github.com/s-perron updated https://github.com/llvm/llvm-project/pull/165027

>From f2f29a52e3c61d52dbcb2f4728318305026016b3 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Fri, 24 Oct 2025 13:21:24 -0400
Subject: [PATCH 1/2] [SPIRV] Use a worklist in the post-legalizer

This commit refactors the SPIR-V post-legalizer to use a worklist to process
new instructions. Previously, the post-legalizer made a single pass over all
instructions and tried to assign a type to each one as it was visited. This
could fail if a new instruction depended on another new instruction that had
not been processed yet.

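The new implementation adds all new instructions that require a SPIR-V type
to a worklist. It then repeatedly sweeps the worklist, assigning a type to
every instruction whose type can be deduced and deferring the rest to the
next sweep. An instruction that depends on another new instruction is thus
resolved once that dependency has been processed. The loop stops when a sweep
makes no progress, and an assertion checks that the worklist is empty at that
point.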

This change makes the post-legalizer more robust and fixes potential ordering
issues with newly generated instructions.

Existing tests cover existing functionality. More tests will be added as
the legalizer is modified.

Part of #153091
---
 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 412 ++++++++++++++++---
 1 file changed, 359 insertions(+), 53 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index d17528dd882bf..b6c650c802247 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -17,7 +17,8 @@
 #include "SPIRV.h"
 #include "SPIRVSubtarget.h"
 #include "SPIRVUtils.h"
-#include "llvm/IR/Attributes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
 #include <stack>
 
 #define DEBUG_TYPE "spirv-postlegalizer"
@@ -45,6 +46,11 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
 
 static bool mayBeInserted(unsigned Opcode) {
   switch (Opcode) {
+  case TargetOpcode::G_CONSTANT:
+  case TargetOpcode::G_UNMERGE_VALUES:
+  case TargetOpcode::G_EXTRACT_VECTOR_ELT:
+  case TargetOpcode::G_INTRINSIC:
+  case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
   case TargetOpcode::G_SMAX:
   case TargetOpcode::G_UMAX:
   case TargetOpcode::G_SMIN:
@@ -53,73 +59,372 @@ static bool mayBeInserted(unsigned Opcode) {
   case TargetOpcode::G_FMINIMUM:
   case TargetOpcode::G_FMAXNUM:
   case TargetOpcode::G_FMAXIMUM:
+  case TargetOpcode::G_IMPLICIT_DEF:
+  case TargetOpcode::G_BUILD_VECTOR:
+  case TargetOpcode::G_ICMP:
+  case TargetOpcode::G_ANYEXT:
     return true;
   default:
     return isTypeFoldingSupported(Opcode);
   }
 }
 
-static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
-                             MachineIRBuilder MIB) {
+static SPIRVType *deduceTypeForGConstant(MachineInstr *I, MachineFunction &MF,
+                                         SPIRVGlobalRegistry *GR,
+                                         MachineIRBuilder &MIB,
+                                         Register ResVReg) {
   MachineRegisterInfo &MRI = MF.getRegInfo();
+  const LLT &Ty = MRI.getType(ResVReg);
+  unsigned BitWidth = Ty.getScalarSizeInBits();
+  return GR->getOrCreateSPIRVIntegerType(BitWidth, MIB);
+}
 
-  for (MachineBasicBlock &MBB : MF) {
-    for (MachineInstr &I : MBB) {
-      const unsigned Opcode = I.getOpcode();
-      if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
-        unsigned ArgI = I.getNumOperands() - 1;
-        Register SrcReg = I.getOperand(ArgI).isReg()
-                              ? I.getOperand(ArgI).getReg()
-                              : Register(0);
-        SPIRVType *DefType =
-            SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
-        if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
-          report_fatal_error(
-              "cannot select G_UNMERGE_VALUES with a non-vector argument");
-        SPIRVType *ScalarType =
-            GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
-        for (unsigned i = 0; i < I.getNumDefs(); ++i) {
-          Register ResVReg = I.getOperand(i).getReg();
-          SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
-          if (!ResType) {
-            // There was no "assign type" actions, let's fix this now
+static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
+                                           SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
+  if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
+    if (DefType->getOpcode() == SPIRV::OpTypeVector) {
+      SPIRVType *ScalarType =
+          GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+      for (unsigned i = 0; i < I->getNumDefs(); ++i) {
+        Register DefReg = I->getOperand(i).getReg();
+        if (!GR->getSPIRVTypeForVReg(DefReg)) {
+          LLT DefLLT = MRI.getType(DefReg);
+          SPIRVType *ResType;
+          if (DefLLT.isVector()) {
+            const SPIRVInstrInfo *TII =
+                MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+            ResType = GR->getOrCreateSPIRVVectorType(
+                ScalarType, DefLLT.getNumElements(), *I, *TII);
+          } else {
             ResType = ScalarType;
-            setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
           }
+          setRegClassType(DefReg, ResType, GR, &MRI, MF);
         }
-      } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
-                 I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
-        // Legalizer may have added a new instructions and introduced new
-        // registers, we must decorate them as if they were introduced in a
-        // non-automatic way
-        Register ResVReg = I.getOperand(0).getReg();
-        // Check if the register defined by the instruction is newly generated
-        // or already processed
-        // Check if we have type defined for operands of the new instruction
-        bool IsKnownReg = MRI.getRegClassOrNull(ResVReg);
-        SPIRVType *ResVType = GR->getSPIRVTypeForVReg(
-            IsKnownReg ? ResVReg : I.getOperand(1).getReg());
-        if (!ResVType)
-          continue;
-        // Set type & class
-        if (!IsKnownReg)
-          setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
-        // If this is a simple operation that is to be reduced by TableGen
-        // definition we must apply some of pre-legalizer rules here
-        if (isTypeFoldingSupported(Opcode)) {
-          processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg));
-          if (IsKnownReg && MRI.hasOneUse(ResVReg)) {
-            MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg);
-            if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
-              continue;
-          }
-          insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
+      }
+      return true;
+    }
+  }
+  return false;
+}
+
+static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I,
+                                                 MachineFunction &MF,
+                                                 SPIRVGlobalRegistry *GR,
+                                                 Register ResVReg) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  Register VecReg = I->getOperand(1).getReg();
+  if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) {
+    assert(VecType->getOpcode() == SPIRV::OpTypeVector);
+    return GR->getScalarOrVectorComponentType(VecType);
+  }
+
+  // If not handled yet, then check if it is used in a G_BUILD_VECTOR.
+  // If so get the type from there.
+  for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+    if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
+      Register BuildVecResReg = Use.getOperand(0).getReg();
+      if (SPIRVType *BuildVecType = GR->getSPIRVTypeForVReg(BuildVecResReg))
+        return GR->getScalarOrVectorComponentType(BuildVecType);
+    }
+  }
+  return nullptr;
+}
+
+static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
+                                            MachineFunction &MF,
+                                            SPIRVGlobalRegistry *GR,
+                                            MachineIRBuilder &MIB,
+                                            Register ResVReg) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  // First check if any of the operands have a type.
+  for (unsigned i = 1; i < I->getNumOperands(); ++i) {
+    if (SPIRVType *OpType =
+            GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) {
+      const LLT &ResLLT = MRI.getType(ResVReg);
+      return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(),
+                                            MIB, false);
+    }
+  }
+  // If that did not work, then check the uses.
+  for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+    if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+      Register ExtractResReg = Use.getOperand(0).getReg();
+      if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
+        const LLT &ResLLT = MRI.getType(ResVReg);
+        return GR->getOrCreateSPIRVVectorType(
+            ScalarType, ResLLT.getNumElements(), MIB, false);
+      }
+    }
+  }
+  return nullptr;
+}
+
+static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I,
+                                            MachineFunction &MF,
+                                            SPIRVGlobalRegistry *GR,
+                                            Register ResVReg) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+    const unsigned UseOpc = Use.getOpcode();
+    assert(UseOpc == TargetOpcode::G_BUILD_VECTOR ||
+           UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+    // It's possible that the use instruction has not been processed yet.
+    // We should look at the operands of the use to determine the type.
+    for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
+      if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg()))
+        return Type;
+    }
+  }
+  return nullptr;
+}
+
+static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF,
+                                          SPIRVGlobalRegistry *GR,
+                                          MachineIRBuilder &MIB,
+                                          Register ResVReg) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  if (!isSpvIntrinsic(*I, Intrinsic::spv_bitcast))
+    return nullptr;
+
+  for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+    const unsigned UseOpc = Use.getOpcode();
+    assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
+           UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+    Register UseResultReg = Use.getOperand(0).getReg();
+    if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
+      SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
+      const LLT &BitcastLLT = MRI.getType(ResVReg);
+      if (BitcastLLT.isVector())
+        return GR->getOrCreateSPIRVVectorType(
+            ScalarType, BitcastLLT.getNumElements(), MIB, false);
+      return ScalarType;
+    }
+  }
+  return nullptr;
+}
+
+static SPIRVType *deduceTypeForGAnyExt(MachineInstr *I, MachineFunction &MF,
+                                       SPIRVGlobalRegistry *GR,
+                                       MachineIRBuilder &MIB,
+                                       Register ResVReg) {
+  // The result type of G_ANYEXT cannot be inferred from its operand.
+  // We use the result register's LLT to determine the correct integer type.
+  const LLT &ResLLT = MIB.getMRI()->getType(ResVReg);
+  if (!ResLLT.isScalar())
+    return nullptr;
+  return GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
+}
+
+static SPIRVType *deduceTypeForDefault(MachineInstr *I, MachineFunction &MF,
+                                       SPIRVGlobalRegistry *GR) {
+  if (I->getNumDefs() != 1 || I->getNumOperands() <= 1 ||
+      !I->getOperand(1).isReg())
+    return nullptr;
+
+  SPIRVType *OpType = GR->getSPIRVTypeForVReg(I->getOperand(1).getReg());
+  if (!OpType)
+    return nullptr;
+  return OpType;
+}
+
+static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
+                                     SPIRVGlobalRegistry *GR,
+                                     MachineIRBuilder &MIB) {
+  LLVM_DEBUG(dbgs() << "Processing instruction: " << *I);
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  const unsigned Opcode = I->getOpcode();
+  Register ResVReg = I->getOperand(0).getReg();
+  SPIRVType *ResType = nullptr;
+
+  switch (Opcode) {
+  case TargetOpcode::G_CONSTANT: {
+    ResType = deduceTypeForGConstant(I, MF, GR, MIB, ResVReg);
+    break;
+  }
+  case TargetOpcode::G_UNMERGE_VALUES: {
+    // This one is special as it defines multiple registers.
+    if (deduceAndAssignTypeForGUnmerge(I, MF, GR))
+      return true;
+    break;
+  }
+  case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
+    ResType = deduceTypeForGExtractVectorElt(I, MF, GR, ResVReg);
+    break;
+  }
+  case TargetOpcode::G_BUILD_VECTOR: {
+    ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg);
+    break;
+  }
+  case TargetOpcode::G_ANYEXT: {
+    ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg);
+    break;
+  }
+  case TargetOpcode::G_IMPLICIT_DEF: {
+    ResType = deduceTypeForGImplicitDef(I, MF, GR, ResVReg);
+    break;
+  }
+  case TargetOpcode::G_INTRINSIC:
+  case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: {
+    ResType = deduceTypeForGIntrinsic(I, MF, GR, MIB, ResVReg);
+    break;
+  }
+  default:
+    ResType = deduceTypeForDefault(I, MF, GR);
+    break;
+  }
+
+  if (ResType) {
+    LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n");
+    GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF);
+
+    if (!MRI.getRegClassOrNull(ResVReg)) {
+      LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+      setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
+    }
+    return true;
+  }
+  return false;
+}
+
+static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
+                              MachineRegisterInfo &MRI) {
+  LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: "
+                    << I;);
+  if (I.getNumDefs() == 0) {
+    LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n");
+    return false;
+  }
+  if (!mayBeInserted(I.getOpcode())) {
+    LLVM_DEBUG(dbgs() << "Instruction may not be inserted.\n");
+    return false;
+  }
+
+  Register ResultRegister = I.defs().begin()->getReg();
+  if (GR->getSPIRVTypeForVReg(ResultRegister)) {
+    LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n");
+    if (!MRI.getRegClassOrNull(ResultRegister)) {
+      LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+      setRegClassType(ResultRegister, GR->getSPIRVTypeForVReg(ResultRegister),
+                      GR, &MRI, *GR->CurMF, true);
+    }
+    return false;
+  }
+
+  return true;
+}
+
+static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
+                                                SPIRVGlobalRegistry *GR,
+                                                MachineIRBuilder MIB) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  SmallVector<MachineInstr *, 8> Worklist;
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &I : MBB) {
+      if (requiresSpirvType(I, GR, MRI)) {
+        Worklist.push_back(&I);
+      }
+    }
+  }
+
+  if (Worklist.empty()) {
+    LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n");
+    return;
+  }
+
+  LLVM_DEBUG(dbgs() << "Initial worklist:\n";
+             for (auto *I : Worklist) { I->dump(); });
+
+  bool Changed = true;
+  while (Changed) {
+    Changed = false;
+    SmallVector<MachineInstr *, 8> NextWorklist;
+
+    for (MachineInstr *I : Worklist) {
+      if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
+        Changed = true;
+      } else {
+        NextWorklist.push_back(I);
+      }
+    }
+    Worklist = NextWorklist;
+    LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
+  }
+
+  if (!Worklist.empty()) {
+    LLVM_DEBUG(dbgs() << "Remaining worklist:\n";
+               for (auto *I : Worklist) { I->dump(); });
+    assert(Worklist.empty() && "Worklist is not empty");
+  }
+}
+
+static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
+                                           SPIRVGlobalRegistry *GR,
+                                           MachineIRBuilder MIB) {
+  LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
+                    << MF.getName() << "\n");
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (!isTypeFoldingSupported(MI.getOpcode()))
+        continue;
+      if (MI.getNumOperands() == 1 || !MI.getOperand(1).isReg())
+        continue;
+
+      LLVM_DEBUG(dbgs() << "Processing instruction: " << MI);
+
+      // Check uses of MI to see if it already has an use in SPIRV::ASSIGN_TYPE
+      bool HasAssignType = false;
+      Register ResultRegister = MI.defs().begin()->getReg();
+      // All uses of Result register
+      for (MachineInstr &UseInstr :
+           MRI.use_nodbg_instructions(ResultRegister)) {
+        if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) {
+          HasAssignType = true;
+          LLVM_DEBUG(dbgs() << "  Instruction already has an ASSIGN_TYPE use: "
+                            << UseInstr);
+          break;
         }
       }
+
+      if (!HasAssignType) {
+        Register ResultRegister = MI.defs().begin()->getReg();
+        SPIRVType *ResultType = GR->getSPIRVTypeForVReg(ResultRegister);
+        LLVM_DEBUG(
+            dbgs() << "  Adding ASSIGN_TYPE for ResultRegister: "
+                   << printReg(ResultRegister, MRI.getTargetRegisterInfo())
+                   << " with type: " << *ResultType);
+        insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI);
+      }
     }
   }
 }
 
+static void lowerExtractVectorElements(MachineFunction &MF) {
+  SmallVector<MachineInstr *, 8> ExtractInstrs;
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+        ExtractInstrs.push_back(&MI);
+      }
+    }
+  }
+
+  for (MachineInstr *MI : ExtractInstrs) {
+    MachineIRBuilder MIB(*MI);
+    Register Dst = MI->getOperand(0).getReg();
+    Register Vec = MI->getOperand(1).getReg();
+    Register Idx = MI->getOperand(2).getReg();
+
+    auto Intr = MIB.buildIntrinsic(Intrinsic::spv_extractelt, Dst, true, false);
+    Intr.addUse(Vec);
+    Intr.addUse(Idx);
+
+    MI->eraseFromParent();
+  }
+}
+
 // Do a preorder traversal of the CFG starting from the BB |Start|.
 // point. Calls |op| on each basic block encountered during the traversal.
 void visit(MachineFunction &MF, MachineBasicBlock &Start,
@@ -156,8 +461,9 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
   GR->setCurrentFunc(MF);
   MachineIRBuilder MIB(MF);
-
-  processNewInstrs(MF, GR, MIB);
+  registerSpirvTypeForNewInstructions(MF, GR, MIB);
+  ensureAssignTypeForTypeFolding(MF, GR, MIB);
+  lowerExtractVectorElements(MF);
 
   return true;
 }

>From c248de25c59edffdfa2a11e05d78610b10488306 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 28 Oct 2025 13:01:07 -0400
Subject: [PATCH 2/2] Set insertion point in MIB.

---
 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 22 +++++++++++++-------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index b6c650c802247..69de5a6360c66 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -138,11 +138,14 @@ static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
                                             MachineIRBuilder &MIB,
                                             Register ResVReg) {
   MachineRegisterInfo &MRI = MF.getRegInfo();
+  LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Processing " << *I << "\n");
   // First check if any of the operands have a type.
   for (unsigned i = 1; i < I->getNumOperands(); ++i) {
     if (SPIRVType *OpType =
             GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) {
       const LLT &ResLLT = MRI.getType(ResVReg);
+      LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found operand type "
+                        << *OpType << ", returning vector type\n");
       return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(),
                                             MIB, false);
     }
@@ -153,11 +156,14 @@ static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
       Register ExtractResReg = Use.getOperand(0).getReg();
       if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
         const LLT &ResLLT = MRI.getType(ResVReg);
+        LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found use type "
+                          << *ScalarType << ", returning vector type\n");
         return GR->getOrCreateSPIRVVectorType(
             ScalarType, ResLLT.getNumElements(), MIB, false);
       }
     }
   }
+  LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Could not deduce type\n");
   return nullptr;
 }
 
@@ -191,7 +197,8 @@ static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF,
   for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
     const unsigned UseOpc = Use.getOpcode();
     assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
-           UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+           UseOpc == TargetOpcode::G_SHUFFLE_VECTOR ||
+           UseOpc == TargetOpcode::G_BUILD_VECTOR);
     Register UseResultReg = Use.getOperand(0).getReg();
     if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
       SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
@@ -316,8 +323,7 @@ static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
 }
 
 static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
-                                                SPIRVGlobalRegistry *GR,
-                                                MachineIRBuilder MIB) {
+                                                SPIRVGlobalRegistry *GR) {
   MachineRegisterInfo &MRI = MF.getRegInfo();
   SmallVector<MachineInstr *, 8> Worklist;
   for (MachineBasicBlock &MBB : MF) {
@@ -342,6 +348,7 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
     SmallVector<MachineInstr *, 8> NextWorklist;
 
     for (MachineInstr *I : Worklist) {
+      MachineIRBuilder MIB(*I);
       if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
         Changed = true;
       } else {
@@ -360,8 +367,7 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
 }
 
 static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
-                                           SPIRVGlobalRegistry *GR,
-                                           MachineIRBuilder MIB) {
+                                           SPIRVGlobalRegistry *GR) {
   LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
                     << MF.getName() << "\n");
   MachineRegisterInfo &MRI = MF.getRegInfo();
@@ -395,6 +401,7 @@ static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
             dbgs() << "  Adding ASSIGN_TYPE for ResultRegister: "
                    << printReg(ResultRegister, MRI.getTargetRegisterInfo())
                    << " with type: " << *ResultType);
+        MachineIRBuilder MIB(MI);
         insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI);
       }
     }
@@ -460,9 +467,8 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
   const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
   SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
   GR->setCurrentFunc(MF);
-  MachineIRBuilder MIB(MF);
-  registerSpirvTypeForNewInstructions(MF, GR, MIB);
-  ensureAssignTypeForTypeFolding(MF, GR, MIB);
+  registerSpirvTypeForNewInstructions(MF, GR);
+  ensureAssignTypeForTypeFolding(MF, GR);
   lowerExtractVectorElements(MF);
 
   return true;



More information about the llvm-commits mailing list