[llvm] [SPIRV] Use a worklist in the post-legalizer (PR #165027)
Steven Perron via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 7 09:27:49 PST 2025
================
@@ -43,83 +44,334 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
SPIRVType *KnownResType);
} // namespace llvm
-static bool mayBeInserted(unsigned Opcode) {
- switch (Opcode) {
- case TargetOpcode::G_SMAX:
- case TargetOpcode::G_UMAX:
- case TargetOpcode::G_SMIN:
- case TargetOpcode::G_UMIN:
- case TargetOpcode::G_FMINNUM:
- case TargetOpcode::G_FMINIMUM:
- case TargetOpcode::G_FMAXNUM:
- case TargetOpcode::G_FMAXIMUM:
- return true;
+/// Builds an integer SPIR-V type whose width is the scalar size of the LLT
+/// assigned to \p ResVReg. When that LLT is a vector, a vector of that integer
+/// type with the same element count is returned, so vector results (e.g. a
+/// vector G_ANYEXT) are not given a scalar type.
+static SPIRVType *deduceIntTypeFromResult(Register ResVReg,
+                                          MachineIRBuilder &MIB,
+                                          SPIRVGlobalRegistry *GR) {
+  const LLT &Ty = MIB.getMRI()->getType(ResVReg);
+  SPIRVType *IntType =
+      GR->getOrCreateSPIRVIntegerType(Ty.getScalarSizeInBits(), MIB);
+  if (Ty.isVector())
+    return GR->getOrCreateSPIRVVectorType(IntType, Ty.getNumElements(), MIB,
+                                          false);
+  return IntType;
+}
+
+/// Assigns SPIR-V types to the results of the G_UNMERGE_VALUES \p I.
+///
+/// The element type comes from the SPIR-V type of the source (last operand)
+/// when that type is an OpTypeVector. Otherwise it is taken from the first use
+/// of a result whose own result register already has a SPIR-V type (such uses
+/// are expected to be G_BUILD_VECTORs). Untyped scalar results get the element
+/// type and untyped vector results a vector of it; typed results are kept.
+///
+/// \returns false if no element type could be deduced yet, true otherwise.
+static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
+                                           SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
+  SPIRVType *ScalarType = nullptr;
+  if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
+    assert(DefType->getOpcode() == SPIRV::OpTypeVector &&
+           "Expected the source of G_UNMERGE_VALUES to be a vector");
+    // Check as well as assert: operand 1 is the element-type register only for
+    // OpTypeVector, so a non-vector source must not be read that way when
+    // asserts are compiled out. Fall back to the uses instead.
+    if (DefType->getOpcode() == SPIRV::OpTypeVector)
+      ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+  }
+
+  if (!ScalarType) {
+    // If we could not deduce the type from the source, try to deduce it from
+    // the uses of the results.
+    for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) {
+      for (const auto &Use :
+           MRI.use_nodbg_instructions(I->getOperand(i).getReg())) {
+        assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
+               "Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR");
+        if (auto *VecType =
+                GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) {
+          ScalarType = GR->getScalarOrVectorComponentType(VecType);
+          break;
+        }
+      }
+    }
+  }
+
+  if (!ScalarType)
+    return false;
+
+  for (unsigned i = 0; i < I->getNumDefs(); ++i) {
+    Register DefReg = I->getOperand(i).getReg();
+    if (GR->getSPIRVTypeForVReg(DefReg))
+      continue;
+
+    LLT DefLLT = MRI.getType(DefReg);
+    SPIRVType *ResType =
+        DefLLT.isVector()
+            ? GR->getOrCreateSPIRVVectorType(
+                  ScalarType, DefLLT.getNumElements(), *I,
+                  *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo())
+            : ScalarType;
+    setRegClassType(DefReg, ResType, GR, &MRI, MF);
+  }
+  return true;
+}
+
+/// Deduces a result type for \p I from the SPIR-V type of operand \p OpIdx.
+///
+/// The operand type is reduced with getScalarOrVectorComponentType; if the LLT
+/// of \p I's operand 0 is a vector, a vector of that component type with the
+/// same element count is returned, otherwise the component type itself.
+/// Returns nullptr when the operand has no SPIR-V type or no component type.
+static SPIRVType *deduceTypeFromSingleOperand(MachineInstr *I,
+                                              MachineIRBuilder &MIB,
+                                              SPIRVGlobalRegistry *GR,
+                                              unsigned OpIdx) {
+  SPIRVType *OperandType =
+      GR->getSPIRVTypeForVReg(I->getOperand(OpIdx).getReg());
+  if (!OperandType)
+    return nullptr;
+
+  SPIRVType *ElemType = GR->getScalarOrVectorComponentType(OperandType);
+  if (!ElemType)
+    return nullptr;
+
+  const LLT &ResultLLT = MIB.getMRI()->getType(I->getOperand(0).getReg());
+  if (!ResultLLT.isVector())
+    return ElemType;
+  return GR->getOrCreateSPIRVVectorType(ElemType, ResultLLT.getNumElements(),
+                                        MIB, false);
+}
+
+/// Tries operands [StartOp, EndOp) of \p I in order and returns the first type
+/// deduceTypeFromSingleOperand produces, or nullptr if none of them yields one.
+static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I,
+                                             MachineIRBuilder &MIB,
+                                             SPIRVGlobalRegistry *GR,
+                                             unsigned StartOp, unsigned EndOp) {
+  SPIRVType *Found = nullptr;
+  for (unsigned OpIdx = StartOp; !Found && OpIdx < EndOp; ++OpIdx)
+    Found = deduceTypeFromSingleOperand(I, MIB, GR, OpIdx);
+  return Found;
+}
+
+/// Deduces a type for \p UseRegister, a register read by \p Use, from the
+/// SPIR-V types of the registers that \p Use defines.
+///
+/// The first definition whose SPIR-V type has a component type decides: the
+/// result is that component type, or a vector of it when the LLT of
+/// \p UseRegister is a vector. Returns nullptr if no definition helps.
+static SPIRVType *deduceTypeForResultRegister(MachineInstr *Use,
+                                              Register UseRegister,
+                                              SPIRVGlobalRegistry *GR,
+                                              MachineIRBuilder &MIB) {
+  for (const MachineOperand &DefOp : Use->defs()) {
+    if (!DefOp.isReg())
+      continue;
+    SPIRVType *DefType = GR->getSPIRVTypeForVReg(DefOp.getReg());
+    if (!DefType)
+      continue;
+    SPIRVType *ElemType = GR->getScalarOrVectorComponentType(DefType);
+    if (!ElemType)
+      continue;
+
+    const LLT &RegLLT = MIB.getMRI()->getType(UseRegister);
+    return RegLLT.isVector()
+               ? GR->getOrCreateSPIRVVectorType(
+                     ElemType, RegLLT.getNumElements(), MIB, false)
+               : ElemType;
+  }
+  return nullptr;
+}
+
+/// Deduces a type for \p Reg from the non-debug instructions that read it.
+///
+/// Only uses whose result type determines \p Reg's type are consulted:
+/// G_BUILD_VECTOR (Reg is an element), G_UNMERGE_VALUES (Reg is the source),
+/// and G_EXTRACT_VECTOR_ELT when Reg is its vector operand. The index operand
+/// of G_EXTRACT_VECTOR_ELT is skipped: the extracted element's type says
+/// nothing about the index's type. Returns the first type found, or nullptr.
+static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
+                                     SPIRVGlobalRegistry *GR,
+                                     MachineIRBuilder &MIB) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
+    SPIRVType *ResType = nullptr;
+    switch (Use.getOpcode()) {
+    case TargetOpcode::G_EXTRACT_VECTOR_ELT:
+      // Operand 1 is the vector, operand 2 the index; only the vector shares
+      // its element type with the result.
+      if (Use.getOperand(1).getReg() != Reg)
+        break;
+      [[fallthrough]];
+    case TargetOpcode::G_BUILD_VECTOR:
+    case TargetOpcode::G_UNMERGE_VALUES:
+      LLVM_DEBUG(dbgs() << "Looking at use " << Use << "\n");
+      ResType = deduceTypeForResultRegister(&Use, Reg, GR, MIB);
+      break;
+    default:
+      break;
+    }
+    if (ResType)
+      return ResType;
+  }
+  return nullptr;
+}
+
+// Deduces the SPIR-V type of operand 0 of \p I from its operands:
+// - G_CONSTANT / G_ANYEXT: an integer type sized from the result's LLT;
+// - G_BUILD_VECTOR: from the first element operand that has a SPIR-V type;
+// - G_SHUFFLE_VECTOR: from the first typed operand among operands 1 and 2;
+// - anything else with one def and a register operand 1: from operand 1.
+// Returns nullptr when nothing can be deduced yet.
+static SPIRVType *deduceResultTypeFromOperands(MachineInstr *I,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
+ Register ResVReg = I->getOperand(0).getReg();
+ switch (I->getOpcode()) {
+ case TargetOpcode::G_CONSTANT:
+ case TargetOpcode::G_ANYEXT:
+ return deduceIntTypeFromResult(ResVReg, MIB, GR);
+ case TargetOpcode::G_BUILD_VECTOR:
+ return deduceTypeFromOperandRange(I, MIB, GR, 1, I->getNumOperands());
+ case TargetOpcode::G_SHUFFLE_VECTOR:
+ return deduceTypeFromOperandRange(I, MIB, GR, 1, 3);
default:
- return isTypeFoldingSupported(Opcode);
+ // Reuse operand 1's component type. NOTE(review): width-changing opcodes
+ // (e.g. G_TRUNC, G_ZEXT, G_SEXT) would get operand 1's bit width rather
+ // than the result's here - confirm they cannot reach this path.
+ if (I->getNumDefs() == 1 && I->getNumOperands() > 1 &&
+ I->getOperand(1).isReg())
+ return deduceTypeFromSingleOperand(I, MIB, GR, 1);
+ return nullptr;
}
}
-static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+// Tries to give the result register(s) of \p I a SPIR-V type.
+// G_UNMERGE_VALUES is delegated to deduceAndAssignTypeForGUnmerge. For all
+// other instructions the type of operand 0 is deduced from the operands first
+// and, failing that, from its uses; a deduced type is recorded in \p GR and a
+// register class is set if the register has none.
+// Returns true if a type was assigned, false if nothing could be deduced yet
+// (the worklist driver keeps such instructions for another round).
+static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
+ LLVM_DEBUG(dbgs() << "\nProcessing instruction: " << *I);
MachineRegisterInfo &MRI = MF.getRegInfo();
+ Register ResVReg = I->getOperand(0).getReg();
+
+ // G_UNMERGE_VALUES is handled separately because it has multiple definitions,
+ // unlike the other instructions which have a single result register. The main
+ // deduction logic is designed for the single-definition case.
+ if (I->getOpcode() == TargetOpcode::G_UNMERGE_VALUES)
+ return deduceAndAssignTypeForGUnmerge(I, MF, GR);
+ LLVM_DEBUG(dbgs() << "Inferring type from operands\n");
+ SPIRVType *ResType = deduceResultTypeFromOperands(I, GR, MIB);
+ // Uses are only consulted when the operands gave no answer.
+ if (!ResType) {
+ LLVM_DEBUG(dbgs() << "Inferring type from uses\n");
+ ResType = deduceTypeFromUses(ResVReg, MF, GR, MIB);
+ }
+
+ if (ResType) {
+ LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType);
+ GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF);
+
+ // Leave an existing register class untouched; only a register without one
+ // gets a class set from ResType via setRegClassType.
+ if (!MRI.getRegClassOrNull(ResVReg)) {
+ LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+ setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
+ }
+ return true;
+ }
+ return false;
+}
+
+/// Returns true if \p I is a generic (pre-ISel) instruction with at least one
+/// definition whose first definition has no SPIR-V type yet, i.e. it needs to
+/// go through type deduction.
+///
+/// Side effect: if the first definition already has a SPIR-V type but no
+/// register class, the register class is set from that type here.
+static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
+                              MachineRegisterInfo &MRI) {
+  LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: "
+                    << I);
+
+  if (I.getNumDefs() == 0) {
+    LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n");
+    return false;
+  }
+  if (!I.isPreISelOpcode()) {
+    LLVM_DEBUG(dbgs() << "Instruction is not a generic instruction.\n");
+    return false;
+  }
+
+  Register DefReg = I.defs().begin()->getReg();
+  SPIRVType *KnownType = GR->getSPIRVTypeForVReg(DefReg);
+  if (!KnownType)
+    return true;
+
+  LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n");
+  // A typed register may still lack a register class; fill it in here.
+  if (!MRI.getRegClassOrNull(DefReg)) {
+    LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+    setRegClassType(DefReg, KnownType, GR, &MRI, *GR->CurMF, true);
+  }
+  return false;
+}
+
+static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
+ SPIRVGlobalRegistry *GR) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ SmallVector<MachineInstr *, 8> Worklist;
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &I : MBB) {
- const unsigned Opcode = I.getOpcode();
- if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
- unsigned ArgI = I.getNumOperands() - 1;
- Register SrcReg = I.getOperand(ArgI).isReg()
- ? I.getOperand(ArgI).getReg()
- : Register(0);
- SPIRVType *DefType =
- SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
- if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
- report_fatal_error(
- "cannot select G_UNMERGE_VALUES with a non-vector argument");
- SPIRVType *ScalarType =
- GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
- for (unsigned i = 0; i < I.getNumDefs(); ++i) {
- Register ResVReg = I.getOperand(i).getReg();
- SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
- if (!ResType) {
- // There was no "assign type" actions, let's fix this now
- ResType = ScalarType;
- setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
- }
- }
- } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
- I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
- // Legalizer may have added a new instructions and introduced new
- // registers, we must decorate them as if they were introduced in a
- // non-automatic way
- Register ResVReg = I.getOperand(0).getReg();
- // Check if the register defined by the instruction is newly generated
- // or already processed
- // Check if we have type defined for operands of the new instruction
- bool IsKnownReg = MRI.getRegClassOrNull(ResVReg);
- SPIRVType *ResVType = GR->getSPIRVTypeForVReg(
- IsKnownReg ? ResVReg : I.getOperand(1).getReg());
- if (!ResVType)
- continue;
- // Set type & class
- if (!IsKnownReg)
- setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
- // If this is a simple operation that is to be reduced by TableGen
- // definition we must apply some of pre-legalizer rules here
- if (isTypeFoldingSupported(Opcode)) {
- processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg));
- if (IsKnownReg && MRI.hasOneUse(ResVReg)) {
- MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg);
- if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
- continue;
- }
- insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
+ if (requiresSpirvType(I, GR, MRI)) {
+ Worklist.push_back(&I);
+ }
+ }
+ }
+
+ if (Worklist.empty()) {
+ LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n");
+ return;
+ }
+
+ LLVM_DEBUG(dbgs() << "Initial worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
+
+ bool Changed = true;
+ while (Changed) {
+ Changed = false;
+ SmallVector<MachineInstr *, 8> NextWorklist;
+
+ for (MachineInstr *I : Worklist) {
+ MachineIRBuilder MIB(*I);
+ if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
+ Changed = true;
+ } else {
+ NextWorklist.push_back(I);
+ }
+ }
+ Worklist = NextWorklist;
+ LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
+ }
+
+ if (!Worklist.empty()) {
+ LLVM_DEBUG(dbgs() << "Remaining worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
----------------
s-perron wrote:
done
https://github.com/llvm/llvm-project/pull/165027
More information about the llvm-commits mailing list