[llvm] [SPIRV] Use a worklist in the post-legalizer (PR #165027)

Steven Perron via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 7 09:14:30 PST 2025


================
@@ -43,83 +44,334 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
                          SPIRVType *KnownResType);
 } // namespace llvm
 
-static bool mayBeInserted(unsigned Opcode) {
-  switch (Opcode) {
-  case TargetOpcode::G_SMAX:
-  case TargetOpcode::G_UMAX:
-  case TargetOpcode::G_SMIN:
-  case TargetOpcode::G_UMIN:
-  case TargetOpcode::G_FMINNUM:
-  case TargetOpcode::G_FMINIMUM:
-  case TargetOpcode::G_FMAXNUM:
-  case TargetOpcode::G_FMAXIMUM:
-    return true;
+static SPIRVType *deduceIntTypeFromResult(Register ResVReg,
+                                          MachineIRBuilder &MIB,
+                                          SPIRVGlobalRegistry *GR) {
+  const LLT &Ty = MIB.getMRI()->getType(ResVReg);
+  return GR->getOrCreateSPIRVIntegerType(Ty.getScalarSizeInBits(), MIB);
+}
+
+static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
+                                           SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
+  SPIRVType *ScalarType = nullptr;
+  if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
+    assert(DefType->getOpcode() == SPIRV::OpTypeVector);
+    ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+  }
+
+  if (!ScalarType) {
+    // If we could not deduce the type from the source, try to deduce it from
+    // the uses of the results.
+    for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) {
+      for (const auto &Use :
+           MRI.use_nodbg_instructions(I->getOperand(i).getReg())) {
+        assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
+               "Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR");
+        if (auto *VecType =
+                GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) {
+          ScalarType = GR->getScalarOrVectorComponentType(VecType);
+          break;
+        }
+      }
+    }
+  }
+
+  if (!ScalarType)
+    return false;
+
+  for (unsigned i = 0; i < I->getNumDefs(); ++i) {
+    Register DefReg = I->getOperand(i).getReg();
+    if (GR->getSPIRVTypeForVReg(DefReg))
+      continue;
+
+    LLT DefLLT = MRI.getType(DefReg);
+    SPIRVType *ResType =
+        DefLLT.isVector()
+            ? GR->getOrCreateSPIRVVectorType(
+                  ScalarType, DefLLT.getNumElements(), *I,
+                  *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo())
+            : ScalarType;
+    setRegClassType(DefReg, ResType, GR, &MRI, MF);
+  }
+  return true;
+}
+
+static SPIRVType *deduceTypeFromSingleOperand(MachineInstr *I,
+                                              MachineIRBuilder &MIB,
+                                              SPIRVGlobalRegistry *GR,
+                                              unsigned OpIdx) {
+  Register OpReg = I->getOperand(OpIdx).getReg();
+  if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(OpReg)) {
+    if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(OpType)) {
+      Register ResVReg = I->getOperand(0).getReg();
+      const LLT &ResLLT = MIB.getMRI()->getType(ResVReg);
+      if (ResLLT.isVector())
+        return GR->getOrCreateSPIRVVectorType(CompType, ResLLT.getNumElements(),
+                                              MIB, false);
+      return CompType;
+    }
+  }
+  return nullptr;
+}
+
+static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I,
+                                             MachineIRBuilder &MIB,
+                                             SPIRVGlobalRegistry *GR,
+                                             unsigned StartOp, unsigned EndOp) {
+  for (unsigned i = StartOp; i < EndOp; ++i) {
+    if (SPIRVType *Type = deduceTypeFromSingleOperand(I, MIB, GR, i))
+      return Type;
+  }
+  return nullptr;
+}
+
+static SPIRVType *deduceTypeForResultRegister(MachineInstr *Use,
+                                              Register UseRegister,
+                                              SPIRVGlobalRegistry *GR,
+                                              MachineIRBuilder &MIB) {
+  for (const MachineOperand &MO : Use->defs()) {
+    if (!MO.isReg())
+      continue;
+    if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(MO.getReg())) {
+      if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(OpType)) {
+        const LLT &ResLLT = MIB.getMRI()->getType(UseRegister);
+        if (ResLLT.isVector())
+          return GR->getOrCreateSPIRVVectorType(
+              CompType, ResLLT.getNumElements(), MIB, false);
+        return CompType;
+      }
+    }
+  }
+  return nullptr;
+}
+
+static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
+                                     SPIRVGlobalRegistry *GR,
+                                     MachineIRBuilder &MIB) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
+    SPIRVType *ResType = nullptr;
+    switch (Use.getOpcode()) {
+    case TargetOpcode::G_BUILD_VECTOR:
+    case TargetOpcode::G_EXTRACT_VECTOR_ELT:
+    case TargetOpcode::G_UNMERGE_VALUES:
+      LLVM_DEBUG(dbgs() << "Looking at use " << Use << "\n");
+      ResType = deduceTypeForResultRegister(&Use, Reg, GR, MIB);
+      break;
+    }
+    if (ResType)
+      return ResType;
+  }
+  return nullptr;
+}
+
+static SPIRVType *deduceResultTypeFromOperands(MachineInstr *I,
+                                               SPIRVGlobalRegistry *GR,
+                                               MachineIRBuilder &MIB) {
+  Register ResVReg = I->getOperand(0).getReg();
+  switch (I->getOpcode()) {
+  case TargetOpcode::G_CONSTANT:
+  case TargetOpcode::G_ANYEXT:
+    return deduceIntTypeFromResult(ResVReg, MIB, GR);
+  case TargetOpcode::G_BUILD_VECTOR:
+    return deduceTypeFromOperandRange(I, MIB, GR, 1, I->getNumOperands());
+  case TargetOpcode::G_SHUFFLE_VECTOR:
+    return deduceTypeFromOperandRange(I, MIB, GR, 1, 3);
   default:
-    return isTypeFoldingSupported(Opcode);
+    if (I->getNumDefs() == 1 && I->getNumOperands() > 1 &&
+        I->getOperand(1).isReg())
+      return deduceTypeFromSingleOperand(I, MIB, GR, 1);
----------------
s-perron wrote:

The first operand is not a register for a G_INTRINSIC. It is the ID of the intrinsic being called.

Function calls to intrinsics do not get added to the worklist. It is not a PreISelOpcode.

https://github.com/llvm/llvm-project/pull/165027


More information about the llvm-commits mailing list