[llvm] [SPIRV] Add legalization for long vectors (PR #165444)
Steven Perron via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 5 06:35:33 PST 2025
https://github.com/s-perron updated https://github.com/llvm/llvm-project/pull/165444
>From 65c01f567b70555eb48f0dc9058c838fe89802c8 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Fri, 24 Oct 2025 13:21:24 -0400
Subject: [PATCH 1/9] [SPIRV] Use a worklist in the post-legalizer
This commit refactors the SPIRV post-legalizer to use a worklist to process
new instructions. Previously, the post-legalizer would iterate through all
instructions and try to assign types. This could fail if a new instruction
depended on another new instruction that had not been processed yet.
The new implementation adds all new instructions that require a SPIR-V type
to a worklist. It then iteratively processes the worklist until it is empty.
This ensures that all dependencies are met before an instruction is
processed.
This change makes the post-legalizer more robust and fixes potential ordering
issues with newly generated instructions.
Existing tests cover existing functionality. More tests will be added as
the legalizer is modified.
Part of #153091
---
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 412 ++++++++++++++++---
1 file changed, 359 insertions(+), 53 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index d17528dd882bf..b6c650c802247 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -17,7 +17,8 @@
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
-#include "llvm/IR/Attributes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
#include <stack>
#define DEBUG_TYPE "spirv-postlegalizer"
@@ -45,6 +46,11 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
static bool mayBeInserted(unsigned Opcode) {
switch (Opcode) {
+ case TargetOpcode::G_CONSTANT:
+ case TargetOpcode::G_UNMERGE_VALUES:
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT:
+ case TargetOpcode::G_INTRINSIC:
+ case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
case TargetOpcode::G_SMAX:
case TargetOpcode::G_UMAX:
case TargetOpcode::G_SMIN:
@@ -53,73 +59,372 @@ static bool mayBeInserted(unsigned Opcode) {
case TargetOpcode::G_FMINIMUM:
case TargetOpcode::G_FMAXNUM:
case TargetOpcode::G_FMAXIMUM:
+ case TargetOpcode::G_IMPLICIT_DEF:
+ case TargetOpcode::G_BUILD_VECTOR:
+ case TargetOpcode::G_ICMP:
+ case TargetOpcode::G_ANYEXT:
return true;
default:
return isTypeFoldingSupported(Opcode);
}
}
-static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+static SPIRVType *deduceTypeForGConstant(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
MachineRegisterInfo &MRI = MF.getRegInfo();
+ const LLT &Ty = MRI.getType(ResVReg);
+ unsigned BitWidth = Ty.getScalarSizeInBits();
+ return GR->getOrCreateSPIRVIntegerType(BitWidth, MIB);
+}
- for (MachineBasicBlock &MBB : MF) {
- for (MachineInstr &I : MBB) {
- const unsigned Opcode = I.getOpcode();
- if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
- unsigned ArgI = I.getNumOperands() - 1;
- Register SrcReg = I.getOperand(ArgI).isReg()
- ? I.getOperand(ArgI).getReg()
- : Register(0);
- SPIRVType *DefType =
- SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
- if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
- report_fatal_error(
- "cannot select G_UNMERGE_VALUES with a non-vector argument");
- SPIRVType *ScalarType =
- GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
- for (unsigned i = 0; i < I.getNumDefs(); ++i) {
- Register ResVReg = I.getOperand(i).getReg();
- SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
- if (!ResType) {
- // There was no "assign type" actions, let's fix this now
+static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
+ if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
+ if (DefType->getOpcode() == SPIRV::OpTypeVector) {
+ SPIRVType *ScalarType =
+ GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+ for (unsigned i = 0; i < I->getNumDefs(); ++i) {
+ Register DefReg = I->getOperand(i).getReg();
+ if (!GR->getSPIRVTypeForVReg(DefReg)) {
+ LLT DefLLT = MRI.getType(DefReg);
+ SPIRVType *ResType;
+ if (DefLLT.isVector()) {
+ const SPIRVInstrInfo *TII =
+ MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+ ResType = GR->getOrCreateSPIRVVectorType(
+ ScalarType, DefLLT.getNumElements(), *I, *TII);
+ } else {
ResType = ScalarType;
- setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
}
+ setRegClassType(DefReg, ResType, GR, &MRI, MF);
}
- } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
- I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
- // Legalizer may have added a new instructions and introduced new
- // registers, we must decorate them as if they were introduced in a
- // non-automatic way
- Register ResVReg = I.getOperand(0).getReg();
- // Check if the register defined by the instruction is newly generated
- // or already processed
- // Check if we have type defined for operands of the new instruction
- bool IsKnownReg = MRI.getRegClassOrNull(ResVReg);
- SPIRVType *ResVType = GR->getSPIRVTypeForVReg(
- IsKnownReg ? ResVReg : I.getOperand(1).getReg());
- if (!ResVType)
- continue;
- // Set type & class
- if (!IsKnownReg)
- setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
- // If this is a simple operation that is to be reduced by TableGen
- // definition we must apply some of pre-legalizer rules here
- if (isTypeFoldingSupported(Opcode)) {
- processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg));
- if (IsKnownReg && MRI.hasOneUse(ResVReg)) {
- MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg);
- if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
- continue;
- }
- insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ Register VecReg = I->getOperand(1).getReg();
+ if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) {
+ assert(VecType->getOpcode() == SPIRV::OpTypeVector);
+ return GR->getScalarOrVectorComponentType(VecType);
+ }
+
+ // If not handled yet, then check if it is used in a G_BUILD_VECTOR.
+ // If so get the type from there.
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
+ Register BuildVecResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *BuildVecType = GR->getSPIRVTypeForVReg(BuildVecResReg))
+ return GR->getScalarOrVectorComponentType(BuildVecType);
+ }
+ }
+ return nullptr;
+}
+
+static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ // First check if any of the operands have a type.
+ for (unsigned i = 1; i < I->getNumOperands(); ++i) {
+ if (SPIRVType *OpType =
+ GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) {
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(),
+ MIB, false);
+ }
+ }
+ // If that did not work, then check the uses.
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+ Register ExtractResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ return GR->getOrCreateSPIRVVectorType(
+ ScalarType, ResLLT.getNumElements(), MIB, false);
+ }
+ }
+ }
+ return nullptr;
+}
+
+static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ const unsigned UseOpc = Use.getOpcode();
+ assert(UseOpc == TargetOpcode::G_BUILD_VECTOR ||
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ // It's possible that the use instruction has not been processed yet.
+ // We should look at the operands of the use to determine the type.
+ for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
+ if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg()))
+ return Type;
+ }
+ }
+ return nullptr;
+}
+
+static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ if (!isSpvIntrinsic(*I, Intrinsic::spv_bitcast))
+ return nullptr;
+
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ const unsigned UseOpc = Use.getOpcode();
+ assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ Register UseResultReg = Use.getOperand(0).getReg();
+ if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
+ SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
+ const LLT &BitcastLLT = MRI.getType(ResVReg);
+ if (BitcastLLT.isVector())
+ return GR->getOrCreateSPIRVVectorType(
+ ScalarType, BitcastLLT.getNumElements(), MIB, false);
+ return ScalarType;
+ }
+ }
+ return nullptr;
+}
+
+static SPIRVType *deduceTypeForGAnyExt(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ // The result type of G_ANYEXT cannot be inferred from its operand.
+ // We use the result register's LLT to determine the correct integer type.
+ const LLT &ResLLT = MIB.getMRI()->getType(ResVReg);
+ if (!ResLLT.isScalar())
+ return nullptr;
+ return GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
+}
+
+static SPIRVType *deduceTypeForDefault(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR) {
+ if (I->getNumDefs() != 1 || I->getNumOperands() <= 1 ||
+ !I->getOperand(1).isReg())
+ return nullptr;
+
+ SPIRVType *OpType = GR->getSPIRVTypeForVReg(I->getOperand(1).getReg());
+ if (!OpType)
+ return nullptr;
+ return OpType;
+}
+
+static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
+ LLVM_DEBUG(dbgs() << "Processing instruction: " << *I);
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ const unsigned Opcode = I->getOpcode();
+ Register ResVReg = I->getOperand(0).getReg();
+ SPIRVType *ResType = nullptr;
+
+ switch (Opcode) {
+ case TargetOpcode::G_CONSTANT: {
+ ResType = deduceTypeForGConstant(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_UNMERGE_VALUES: {
+ // This one is special as it defines multiple registers.
+ if (deduceAndAssignTypeForGUnmerge(I, MF, GR))
+ return true;
+ break;
+ }
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
+ ResType = deduceTypeForGExtractVectorElt(I, MF, GR, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_BUILD_VECTOR: {
+ ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_ANYEXT: {
+ ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_IMPLICIT_DEF: {
+ ResType = deduceTypeForGImplicitDef(I, MF, GR, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_INTRINSIC:
+ case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: {
+ ResType = deduceTypeForGIntrinsic(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ default:
+ ResType = deduceTypeForDefault(I, MF, GR);
+ break;
+ }
+
+ if (ResType) {
+ LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n");
+ GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF);
+
+ if (!MRI.getRegClassOrNull(ResVReg)) {
+ LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+ setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
+ }
+ return true;
+ }
+ return false;
+}
+
+static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
+ MachineRegisterInfo &MRI) {
+ LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: "
+ << I;);
+ if (I.getNumDefs() == 0) {
+ LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n");
+ return false;
+ }
+ if (!mayBeInserted(I.getOpcode())) {
+ LLVM_DEBUG(dbgs() << "Instruction may not be inserted.\n");
+ return false;
+ }
+
+ Register ResultRegister = I.defs().begin()->getReg();
+ if (GR->getSPIRVTypeForVReg(ResultRegister)) {
+ LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n");
+ if (!MRI.getRegClassOrNull(ResultRegister)) {
+ LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+ setRegClassType(ResultRegister, GR->getSPIRVTypeForVReg(ResultRegister),
+ GR, &MRI, *GR->CurMF, true);
+ }
+ return false;
+ }
+
+ return true;
+}
+
+static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ SmallVector<MachineInstr *, 8> Worklist;
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &I : MBB) {
+ if (requiresSpirvType(I, GR, MRI)) {
+ Worklist.push_back(&I);
+ }
+ }
+ }
+
+ if (Worklist.empty()) {
+ LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n");
+ return;
+ }
+
+ LLVM_DEBUG(dbgs() << "Initial worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
+
+ bool Changed = true;
+ while (Changed) {
+ Changed = false;
+ SmallVector<MachineInstr *, 8> NextWorklist;
+
+ for (MachineInstr *I : Worklist) {
+ if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
+ Changed = true;
+ } else {
+ NextWorklist.push_back(I);
+ }
+ }
+ Worklist = NextWorklist;
+ LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
+ }
+
+ if (!Worklist.empty()) {
+ LLVM_DEBUG(dbgs() << "Remaining worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
+ assert(Worklist.empty() && "Worklist is not empty");
+ }
+}
+
+static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
+ LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
+ << MF.getName() << "\n");
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &MI : MBB) {
+ if (!isTypeFoldingSupported(MI.getOpcode()))
+ continue;
+ if (MI.getNumOperands() == 1 || !MI.getOperand(1).isReg())
+ continue;
+
+ LLVM_DEBUG(dbgs() << "Processing instruction: " << MI);
+
+ // Check uses of MI to see if it already has an use in SPIRV::ASSIGN_TYPE
+ bool HasAssignType = false;
+ Register ResultRegister = MI.defs().begin()->getReg();
+ // All uses of Result register
+ for (MachineInstr &UseInstr :
+ MRI.use_nodbg_instructions(ResultRegister)) {
+ if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) {
+ HasAssignType = true;
+ LLVM_DEBUG(dbgs() << " Instruction already has an ASSIGN_TYPE use: "
+ << UseInstr);
+ break;
}
}
+
+ if (!HasAssignType) {
+ Register ResultRegister = MI.defs().begin()->getReg();
+ SPIRVType *ResultType = GR->getSPIRVTypeForVReg(ResultRegister);
+ LLVM_DEBUG(
+ dbgs() << " Adding ASSIGN_TYPE for ResultRegister: "
+ << printReg(ResultRegister, MRI.getTargetRegisterInfo())
+ << " with type: " << *ResultType);
+ insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI);
+ }
}
}
}
+static void lowerExtractVectorElements(MachineFunction &MF) {
+ SmallVector<MachineInstr *, 8> ExtractInstrs;
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &MI : MBB) {
+ if (MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+ ExtractInstrs.push_back(&MI);
+ }
+ }
+ }
+
+ for (MachineInstr *MI : ExtractInstrs) {
+ MachineIRBuilder MIB(*MI);
+ Register Dst = MI->getOperand(0).getReg();
+ Register Vec = MI->getOperand(1).getReg();
+ Register Idx = MI->getOperand(2).getReg();
+
+ auto Intr = MIB.buildIntrinsic(Intrinsic::spv_extractelt, Dst, true, false);
+ Intr.addUse(Vec);
+ Intr.addUse(Idx);
+
+ MI->eraseFromParent();
+ }
+}
+
// Do a preorder traversal of the CFG starting from the BB |Start|.
// point. Calls |op| on each basic block encountered during the traversal.
void visit(MachineFunction &MF, MachineBasicBlock &Start,
@@ -156,8 +461,9 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
GR->setCurrentFunc(MF);
MachineIRBuilder MIB(MF);
-
- processNewInstrs(MF, GR, MIB);
+ registerSpirvTypeForNewInstructions(MF, GR, MIB);
+ ensureAssignTypeForTypeFolding(MF, GR, MIB);
+ lowerExtractVectorElements(MF);
return true;
}
>From a1f4714914e985dfa2dd4584dffb9d9e4c19310d Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 28 Oct 2025 13:01:07 -0400
Subject: [PATCH 2/9] Set insertion point in MIB.
---
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 22 +++++++++++++-------
1 file changed, 14 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index b6c650c802247..69de5a6360c66 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -138,11 +138,14 @@ static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
MachineIRBuilder &MIB,
Register ResVReg) {
MachineRegisterInfo &MRI = MF.getRegInfo();
+ LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Processing " << *I << "\n");
// First check if any of the operands have a type.
for (unsigned i = 1; i < I->getNumOperands(); ++i) {
if (SPIRVType *OpType =
GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) {
const LLT &ResLLT = MRI.getType(ResVReg);
+ LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found operand type "
+ << *OpType << ", returning vector type\n");
return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(),
MIB, false);
}
@@ -153,11 +156,14 @@ static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
Register ExtractResReg = Use.getOperand(0).getReg();
if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
const LLT &ResLLT = MRI.getType(ResVReg);
+ LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found use type "
+ << *ScalarType << ", returning vector type\n");
return GR->getOrCreateSPIRVVectorType(
ScalarType, ResLLT.getNumElements(), MIB, false);
}
}
}
+ LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Could not deduce type\n");
return nullptr;
}
@@ -191,7 +197,8 @@ static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF,
for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
const unsigned UseOpc = Use.getOpcode();
assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
- UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR ||
+ UseOpc == TargetOpcode::G_BUILD_VECTOR);
Register UseResultReg = Use.getOperand(0).getReg();
if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
@@ -316,8 +323,7 @@ static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
}
static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+ SPIRVGlobalRegistry *GR) {
MachineRegisterInfo &MRI = MF.getRegInfo();
SmallVector<MachineInstr *, 8> Worklist;
for (MachineBasicBlock &MBB : MF) {
@@ -342,6 +348,7 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
SmallVector<MachineInstr *, 8> NextWorklist;
for (MachineInstr *I : Worklist) {
+ MachineIRBuilder MIB(*I);
if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
Changed = true;
} else {
@@ -360,8 +367,7 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
}
static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+ SPIRVGlobalRegistry *GR) {
LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
<< MF.getName() << "\n");
MachineRegisterInfo &MRI = MF.getRegInfo();
@@ -395,6 +401,7 @@ static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
dbgs() << " Adding ASSIGN_TYPE for ResultRegister: "
<< printReg(ResultRegister, MRI.getTargetRegisterInfo())
<< " with type: " << *ResultType);
+ MachineIRBuilder MIB(MI);
insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI);
}
}
@@ -460,9 +467,8 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
GR->setCurrentFunc(MF);
- MachineIRBuilder MIB(MF);
- registerSpirvTypeForNewInstructions(MF, GR, MIB);
- ensureAssignTypeForTypeFolding(MF, GR, MIB);
+ registerSpirvTypeForNewInstructions(MF, GR);
+ ensureAssignTypeForTypeFolding(MF, GR);
lowerExtractVectorElements(MF);
return true;
>From 79f3b0e897bb2b8c9692210c8cf778617a753072 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Wed, 29 Oct 2025 09:36:48 -0400
Subject: [PATCH 3/9] Handle vector shuffle.
---
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 138 +++++++++++++++----
1 file changed, 108 insertions(+), 30 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index 69de5a6360c66..644e010d8cf94 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -62,6 +62,7 @@ static bool mayBeInserted(unsigned Opcode) {
case TargetOpcode::G_IMPLICIT_DEF:
case TargetOpcode::G_BUILD_VECTOR:
case TargetOpcode::G_ICMP:
+ case TargetOpcode::G_SHUFFLE_VECTOR:
case TargetOpcode::G_ANYEXT:
return true;
default:
@@ -83,30 +84,47 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
SPIRVGlobalRegistry *GR) {
MachineRegisterInfo &MRI = MF.getRegInfo();
Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
+ SPIRVType *ScalarType = nullptr;
if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
- if (DefType->getOpcode() == SPIRV::OpTypeVector) {
- SPIRVType *ScalarType =
- GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
- for (unsigned i = 0; i < I->getNumDefs(); ++i) {
- Register DefReg = I->getOperand(i).getReg();
- if (!GR->getSPIRVTypeForVReg(DefReg)) {
- LLT DefLLT = MRI.getType(DefReg);
- SPIRVType *ResType;
- if (DefLLT.isVector()) {
- const SPIRVInstrInfo *TII =
- MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
- ResType = GR->getOrCreateSPIRVVectorType(
- ScalarType, DefLLT.getNumElements(), *I, *TII);
- } else {
- ResType = ScalarType;
- }
- setRegClassType(DefReg, ResType, GR, &MRI, MF);
+ assert(DefType->getOpcode() == SPIRV::OpTypeVector);
+ ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+ }
+
+ if (!ScalarType) {
+ // If we could not deduce the type from the source, try to deduce it from
+ // the uses of the results.
+ for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) {
+ for (const auto &Use :
+ MRI.use_nodbg_instructions(I->getOperand(i).getReg())) {
+ assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
+ "Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR");
+ if (auto *VecType =
+ GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) {
+ ScalarType = GR->getScalarOrVectorComponentType(VecType);
+ break;
}
}
- return true;
}
}
- return false;
+
+ if (!ScalarType)
+ return false;
+
+ for (unsigned i = 0; i < I->getNumDefs(); ++i) {
+ Register DefReg = I->getOperand(i).getReg();
+ if (GR->getSPIRVTypeForVReg(DefReg))
+ continue;
+
+ LLT DefLLT = MRI.getType(DefReg);
+ SPIRVType *ResType =
+ DefLLT.isVector()
+ ? GR->getOrCreateSPIRVVectorType(
+ ScalarType, DefLLT.getNumElements(), *I,
+ *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo())
+ : ScalarType;
+ setRegClassType(DefReg, ResType, GR, &MRI, MF);
+ }
+ return true;
}
static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I,
@@ -167,20 +185,61 @@ static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
return nullptr;
}
+static SPIRVType *deduceTypeForGShuffleVector(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ assert(ResLLT.isVector() && "G_SHUFFLE_VECTOR result must be a vector");
+
+ // The result element type should be the same as the input vector element
+ // types.
+ for (unsigned i = 1; i <= 2; ++i) {
+ Register VReg = I->getOperand(i).getReg();
+ if (auto *VType = GR->getSPIRVTypeForVReg(VReg)) {
+ if (auto *ScalarType = GR->getScalarOrVectorComponentType(VType))
+ return GR->getOrCreateSPIRVVectorType(
+ ScalarType, ResLLT.getNumElements(), MIB, false);
+ }
+ }
+ return nullptr;
+}
+
static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I,
MachineFunction &MF,
SPIRVGlobalRegistry *GR,
Register ResVReg) {
MachineRegisterInfo &MRI = MF.getRegInfo();
- for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
- const unsigned UseOpc = Use.getOpcode();
- assert(UseOpc == TargetOpcode::G_BUILD_VECTOR ||
- UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
- // It's possible that the use instruction has not been processed yet.
- // We should look at the operands of the use to determine the type.
- for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
- if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg()))
- return Type;
+ for (const MachineInstr &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ SPIRVType *ScalarType = nullptr;
+ switch (Use.getOpcode()) {
+ case TargetOpcode::G_BUILD_VECTOR:
+ case TargetOpcode::G_UNMERGE_VALUES:
+ // It's possible that the use instruction has not been processed yet.
+ // We should look at the operands of the use to determine the type.
+ for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
+ if (SPIRVType *OpType =
+ GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg()))
+ ScalarType = GR->getScalarOrVectorComponentType(OpType);
+ }
+ break;
+ case TargetOpcode::G_SHUFFLE_VECTOR:
+ // For G_SHUFFLE_VECTOR, only look at the vector input operands.
+ if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(1).getReg()))
+ ScalarType = GR->getScalarOrVectorComponentType(Type);
+ if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(2).getReg()))
+ ScalarType = GR->getScalarOrVectorComponentType(Type);
+ break;
+ }
+ if (ScalarType) {
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ if (!ResLLT.isVector())
+ return ScalarType;
+ return GR->getOrCreateSPIRVVectorType(
+ ScalarType, ResLLT.getNumElements(), *I,
+ *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo());
}
}
return nullptr;
@@ -198,7 +257,8 @@ static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF,
const unsigned UseOpc = Use.getOpcode();
assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
UseOpc == TargetOpcode::G_SHUFFLE_VECTOR ||
- UseOpc == TargetOpcode::G_BUILD_VECTOR);
+ UseOpc == TargetOpcode::G_BUILD_VECTOR ||
+ UseOpc == TargetOpcode::G_UNMERGE_VALUES);
Register UseResultReg = Use.getOperand(0).getReg();
if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
@@ -264,6 +324,10 @@ static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg);
break;
}
+ case TargetOpcode::G_SHUFFLE_VECTOR: {
+ ResType = deduceTypeForGShuffleVector(I, MF, GR, MIB, ResVReg);
+ break;
+ }
case TargetOpcode::G_ANYEXT: {
ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg);
break;
@@ -362,7 +426,21 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
if (!Worklist.empty()) {
LLVM_DEBUG(dbgs() << "Remaining worklist:\n";
for (auto *I : Worklist) { I->dump(); });
- assert(Worklist.empty() && "Worklist is not empty");
+ for (auto *I : Worklist) {
+ MachineIRBuilder MIB(*I);
+ Register ResVReg = I->getOperand(0).getReg();
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ SPIRVType *ResType = nullptr;
+ if (ResLLT.isVector()) {
+ SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType(
+ ResLLT.getElementType().getSizeInBits(), MIB);
+ ResType = GR->getOrCreateSPIRVVectorType(
+ CompType, ResLLT.getNumElements(), MIB, false);
+ } else {
+ ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
+ }
+ setRegClassType(ResVReg, ResType, GR, &MRI, MF, true);
+ }
}
}
>From 2220ffcfac603c2453b57443a3f70bda1c2d546f Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 4 Nov 2025 14:59:07 -0500
Subject: [PATCH 4/9] t
---
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 294 +++++--------------
1 file changed, 73 insertions(+), 221 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index 644e010d8cf94..ae8f63793f4fc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -44,40 +44,10 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
SPIRVType *KnownResType);
} // namespace llvm
-static bool mayBeInserted(unsigned Opcode) {
- switch (Opcode) {
- case TargetOpcode::G_CONSTANT:
- case TargetOpcode::G_UNMERGE_VALUES:
- case TargetOpcode::G_EXTRACT_VECTOR_ELT:
- case TargetOpcode::G_INTRINSIC:
- case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
- case TargetOpcode::G_SMAX:
- case TargetOpcode::G_UMAX:
- case TargetOpcode::G_SMIN:
- case TargetOpcode::G_UMIN:
- case TargetOpcode::G_FMINNUM:
- case TargetOpcode::G_FMINIMUM:
- case TargetOpcode::G_FMAXNUM:
- case TargetOpcode::G_FMAXIMUM:
- case TargetOpcode::G_IMPLICIT_DEF:
- case TargetOpcode::G_BUILD_VECTOR:
- case TargetOpcode::G_ICMP:
- case TargetOpcode::G_SHUFFLE_VECTOR:
- case TargetOpcode::G_ANYEXT:
- return true;
- default:
- return isTypeFoldingSupported(Opcode);
- }
-}
-
-static SPIRVType *deduceTypeForGConstant(MachineInstr *I, MachineFunction &MF,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIB,
- Register ResVReg) {
- MachineRegisterInfo &MRI = MF.getRegInfo();
- const LLT &Ty = MRI.getType(ResVReg);
- unsigned BitWidth = Ty.getScalarSizeInBits();
- return GR->getOrCreateSPIRVIntegerType(BitWidth, MIB);
+static SPIRVType *deduceIntTypeFromRes(Register ResVReg, MachineIRBuilder &MIB,
+ SPIRVGlobalRegistry *GR) {
+ const LLT &Ty = MIB.getMRI()->getType(ResVReg);
+ return GR->getOrCreateSPIRVIntegerType(Ty.getScalarSizeInBits(), MIB);
}
static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
@@ -127,173 +97,87 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
return true;
}
-static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I,
- MachineFunction &MF,
- SPIRVGlobalRegistry *GR,
- Register ResVReg) {
- MachineRegisterInfo &MRI = MF.getRegInfo();
- Register VecReg = I->getOperand(1).getReg();
- if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) {
- assert(VecType->getOpcode() == SPIRV::OpTypeVector);
- return GR->getScalarOrVectorComponentType(VecType);
- }
-
- // If not handled yet, then check if it is used in a G_BUILD_VECTOR.
- // If so get the type from there.
- for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
- if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
- Register BuildVecResReg = Use.getOperand(0).getReg();
- if (SPIRVType *BuildVecType = GR->getSPIRVTypeForVReg(BuildVecResReg))
- return GR->getScalarOrVectorComponentType(BuildVecType);
+static SPIRVType *deduceTypeFromOperand(MachineInstr *I, MachineIRBuilder &MIB,
+ SPIRVGlobalRegistry *GR,
+ unsigned OpIdx) {
+ Register OpReg = I->getOperand(OpIdx).getReg();
+ if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(OpReg)) {
+ if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(OpType)) {
+ Register ResVReg = I->getOperand(0).getReg();
+ const LLT &ResLLT = MIB.getMRI()->getType(ResVReg);
+ if (ResLLT.isVector())
+ return GR->getOrCreateSPIRVVectorType(CompType, ResLLT.getNumElements(),
+ MIB, false);
+ return CompType;
}
}
return nullptr;
}
-static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
- MachineFunction &MF,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIB,
- Register ResVReg) {
- MachineRegisterInfo &MRI = MF.getRegInfo();
- LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Processing " << *I << "\n");
- // First check if any of the operands have a type.
- for (unsigned i = 1; i < I->getNumOperands(); ++i) {
- if (SPIRVType *OpType =
- GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) {
- const LLT &ResLLT = MRI.getType(ResVReg);
- LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found operand type "
- << *OpType << ", returning vector type\n");
- return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(),
- MIB, false);
- }
- }
- // If that did not work, then check the uses.
- for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
- if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
- Register ExtractResReg = Use.getOperand(0).getReg();
- if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
- const LLT &ResLLT = MRI.getType(ResVReg);
- LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found use type "
- << *ScalarType << ", returning vector type\n");
- return GR->getOrCreateSPIRVVectorType(
- ScalarType, ResLLT.getNumElements(), MIB, false);
- }
- }
+static SPIRVType *deduceTypeFromOperands(MachineInstr *I, MachineIRBuilder &MIB,
+ SPIRVGlobalRegistry *GR,
+ unsigned StartOp, unsigned EndOp) {
+ for (unsigned i = StartOp; i < EndOp; ++i) {
+ if (SPIRVType *Type = deduceTypeFromOperand(I, MIB, GR, i))
+ return Type;
}
- LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Could not deduce type\n");
return nullptr;
}
-static SPIRVType *deduceTypeForGShuffleVector(MachineInstr *I,
- MachineFunction &MF,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIB,
- Register ResVReg) {
- MachineRegisterInfo &MRI = MF.getRegInfo();
- const LLT &ResLLT = MRI.getType(ResVReg);
- assert(ResLLT.isVector() && "G_SHUFFLE_VECTOR result must be a vector");
-
- // The result element type should be the same as the input vector element
- // types.
- for (unsigned i = 1; i <= 2; ++i) {
- Register VReg = I->getOperand(i).getReg();
- if (auto *VType = GR->getSPIRVTypeForVReg(VReg)) {
- if (auto *ScalarType = GR->getScalarOrVectorComponentType(VType))
- return GR->getOrCreateSPIRVVectorType(
- ScalarType, ResLLT.getNumElements(), MIB, false);
- }
+static SPIRVType *deduceTypeForGBuildVectorFromUses(const MachineInstr &Use,
+ SPIRVGlobalRegistry *GR) {
+ Register BuildVecResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *BuildVecType = GR->getSPIRVTypeForVReg(BuildVecResReg))
+ return GR->getScalarOrVectorComponentType(BuildVecType);
+ return nullptr;
+}
+
+static SPIRVType *deduceTypeForGExtractVectorEltFromUses(
+ const MachineInstr &Use, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) {
+ Register ExtractResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
+ const LLT &ResLLT = MIB.getMRI()->getType(Use.getOperand(0).getReg());
+ return GR->getOrCreateSPIRVVectorType(ScalarType, ResLLT.getNumElements(),
+ MIB, false);
}
return nullptr;
}
-static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I,
- MachineFunction &MF,
- SPIRVGlobalRegistry *GR,
- Register ResVReg) {
+static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
MachineRegisterInfo &MRI = MF.getRegInfo();
- for (const MachineInstr &Use : MRI.use_nodbg_instructions(ResVReg)) {
- SPIRVType *ScalarType = nullptr;
+ for (const MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
+ SPIRVType *ResType = nullptr;
switch (Use.getOpcode()) {
case TargetOpcode::G_BUILD_VECTOR:
- case TargetOpcode::G_UNMERGE_VALUES:
- // It's possible that the use instruction has not been processed yet.
- // We should look at the operands of the use to determine the type.
- for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
- if (SPIRVType *OpType =
- GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg()))
- ScalarType = GR->getScalarOrVectorComponentType(OpType);
- }
+ ResType = deduceTypeForGBuildVectorFromUses(Use, GR);
break;
- case TargetOpcode::G_SHUFFLE_VECTOR:
- // For G_SHUFFLE_VECTOR, only look at the vector input operands.
- if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(1).getReg()))
- ScalarType = GR->getScalarOrVectorComponentType(Type);
- if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(2).getReg()))
- ScalarType = GR->getScalarOrVectorComponentType(Type);
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT:
+ ResType = deduceTypeForGExtractVectorEltFromUses(Use, GR, MIB);
break;
}
- if (ScalarType) {
- const LLT &ResLLT = MRI.getType(ResVReg);
- if (!ResLLT.isVector())
- return ScalarType;
- return GR->getOrCreateSPIRVVectorType(
- ScalarType, ResLLT.getNumElements(), *I,
- *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo());
- }
+ if (ResType)
+ return ResType;
}
return nullptr;
}
-static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIB,
- Register ResVReg) {
- MachineRegisterInfo &MRI = MF.getRegInfo();
- if (!isSpvIntrinsic(*I, Intrinsic::spv_bitcast))
- return nullptr;
-
- for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
- const unsigned UseOpc = Use.getOpcode();
- assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
- UseOpc == TargetOpcode::G_SHUFFLE_VECTOR ||
- UseOpc == TargetOpcode::G_BUILD_VECTOR ||
- UseOpc == TargetOpcode::G_UNMERGE_VALUES);
- Register UseResultReg = Use.getOperand(0).getReg();
- if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
- SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
- const LLT &BitcastLLT = MRI.getType(ResVReg);
- if (BitcastLLT.isVector())
- return GR->getOrCreateSPIRVVectorType(
- ScalarType, BitcastLLT.getNumElements(), MIB, false);
- return ScalarType;
- }
+static SPIRVType *deduceTypeFromOperands(MachineInstr *I,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
+ Register ResVReg = I->getOperand(0).getReg();
+ switch (I->getOpcode()) {
+ case TargetOpcode::G_CONSTANT:
+ case TargetOpcode::G_ANYEXT:
+ return deduceIntTypeFromRes(ResVReg, MIB, GR);
+ case TargetOpcode::G_BUILD_VECTOR:
+ return deduceTypeFromOperands(I, MIB, GR, 1, I->getNumOperands());
+ case TargetOpcode::G_SHUFFLE_VECTOR:
+ return deduceTypeFromOperands(I, MIB, GR, 1, 3);
+ default:
+ return deduceTypeFromOperand(I, MIB, GR, 1);
}
- return nullptr;
-}
-
-static SPIRVType *deduceTypeForGAnyExt(MachineInstr *I, MachineFunction &MF,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIB,
- Register ResVReg) {
- // The result type of G_ANYEXT cannot be inferred from its operand.
- // We use the result register's LLT to determine the correct integer type.
- const LLT &ResLLT = MIB.getMRI()->getType(ResVReg);
- if (!ResLLT.isScalar())
- return nullptr;
- return GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
-}
-
-static SPIRVType *deduceTypeForDefault(MachineInstr *I, MachineFunction &MF,
- SPIRVGlobalRegistry *GR) {
- if (I->getNumDefs() != 1 || I->getNumOperands() <= 1 ||
- !I->getOperand(1).isReg())
- return nullptr;
-
- SPIRVType *OpType = GR->getSPIRVTypeForVReg(I->getOperand(1).getReg());
- if (!OpType)
- return nullptr;
- return OpType;
}
static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
@@ -301,50 +185,17 @@ static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
MachineIRBuilder &MIB) {
LLVM_DEBUG(dbgs() << "Processing instruction: " << *I);
MachineRegisterInfo &MRI = MF.getRegInfo();
- const unsigned Opcode = I->getOpcode();
Register ResVReg = I->getOperand(0).getReg();
- SPIRVType *ResType = nullptr;
- switch (Opcode) {
- case TargetOpcode::G_CONSTANT: {
- ResType = deduceTypeForGConstant(I, MF, GR, MIB, ResVReg);
- break;
- }
- case TargetOpcode::G_UNMERGE_VALUES: {
- // This one is special as it defines multiple registers.
- if (deduceAndAssignTypeForGUnmerge(I, MF, GR))
- return true;
- break;
- }
- case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
- ResType = deduceTypeForGExtractVectorElt(I, MF, GR, ResVReg);
- break;
- }
- case TargetOpcode::G_BUILD_VECTOR: {
- ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg);
- break;
- }
- case TargetOpcode::G_SHUFFLE_VECTOR: {
- ResType = deduceTypeForGShuffleVector(I, MF, GR, MIB, ResVReg);
- break;
- }
- case TargetOpcode::G_ANYEXT: {
- ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg);
- break;
- }
- case TargetOpcode::G_IMPLICIT_DEF: {
- ResType = deduceTypeForGImplicitDef(I, MF, GR, ResVReg);
- break;
- }
- case TargetOpcode::G_INTRINSIC:
- case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: {
- ResType = deduceTypeForGIntrinsic(I, MF, GR, MIB, ResVReg);
- break;
- }
- default:
- ResType = deduceTypeForDefault(I, MF, GR);
- break;
- }
+ // G_UNMERGE_VALUES is handled separately because it has multiple definitions,
+ // unlike the other instructions which have a single result register. The main
+ // deduction logic is designed for the single-definition case.
+ if (I->getOpcode() == TargetOpcode::G_UNMERGE_VALUES)
+ return deduceAndAssignTypeForGUnmerge(I, MF, GR);
+
+ SPIRVType *ResType = deduceTypeFromOperands(I, GR, MIB);
+ if (!ResType)
+ ResType = deduceTypeFromUses(ResVReg, MF, GR, MIB);
if (ResType) {
LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n");
@@ -367,8 +218,9 @@ static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n");
return false;
}
- if (!mayBeInserted(I.getOpcode())) {
- LLVM_DEBUG(dbgs() << "Instruction may not be inserted.\n");
+
+ if (!I.isPreISelOpcode()) {
+ LLVM_DEBUG(dbgs() << "Instruction is not a generic instruction.\n");
return false;
}
>From 33fa8df194011e3607ee7e9b75593e8a6f2e9c93 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 4 Nov 2025 15:26:53 -0500
Subject: [PATCH 5/9] t
---
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 64 ++++++++------------
1 file changed, 26 insertions(+), 38 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index ae8f63793f4fc..7bb2af3096ab2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -44,8 +44,9 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
SPIRVType *KnownResType);
} // namespace llvm
-static SPIRVType *deduceIntTypeFromRes(Register ResVReg, MachineIRBuilder &MIB,
- SPIRVGlobalRegistry *GR) {
+static SPIRVType *inferIntTypeFromResult(Register ResVReg,
+ MachineIRBuilder &MIB,
+ SPIRVGlobalRegistry *GR) {
const LLT &Ty = MIB.getMRI()->getType(ResVReg);
return GR->getOrCreateSPIRVIntegerType(Ty.getScalarSizeInBits(), MIB);
}
@@ -97,9 +98,10 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
return true;
}
-static SPIRVType *deduceTypeFromOperand(MachineInstr *I, MachineIRBuilder &MIB,
- SPIRVGlobalRegistry *GR,
- unsigned OpIdx) {
+static SPIRVType *inferTypeFromSingleOperand(MachineInstr *I,
+ MachineIRBuilder &MIB,
+ SPIRVGlobalRegistry *GR,
+ unsigned OpIdx) {
Register OpReg = I->getOperand(OpIdx).getReg();
if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(OpReg)) {
if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(OpType)) {
@@ -114,47 +116,33 @@ static SPIRVType *deduceTypeFromOperand(MachineInstr *I, MachineIRBuilder &MIB,
return nullptr;
}
-static SPIRVType *deduceTypeFromOperands(MachineInstr *I, MachineIRBuilder &MIB,
- SPIRVGlobalRegistry *GR,
- unsigned StartOp, unsigned EndOp) {
+static SPIRVType *inferTypeFromOperandRange(MachineInstr *I,
+ MachineIRBuilder &MIB,
+ SPIRVGlobalRegistry *GR,
+ unsigned StartOp, unsigned EndOp) {
for (unsigned i = StartOp; i < EndOp; ++i) {
- if (SPIRVType *Type = deduceTypeFromOperand(I, MIB, GR, i))
+ if (SPIRVType *Type = inferTypeFromSingleOperand(I, MIB, GR, i))
return Type;
}
return nullptr;
}
-static SPIRVType *deduceTypeForGBuildVectorFromUses(const MachineInstr &Use,
- SPIRVGlobalRegistry *GR) {
- Register BuildVecResReg = Use.getOperand(0).getReg();
- if (SPIRVType *BuildVecType = GR->getSPIRVTypeForVReg(BuildVecResReg))
- return GR->getScalarOrVectorComponentType(BuildVecType);
- return nullptr;
-}
-
-static SPIRVType *deduceTypeForGExtractVectorEltFromUses(
- const MachineInstr &Use, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) {
- Register ExtractResReg = Use.getOperand(0).getReg();
- if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
- const LLT &ResLLT = MIB.getMRI()->getType(Use.getOperand(0).getReg());
- return GR->getOrCreateSPIRVVectorType(ScalarType, ResLLT.getNumElements(),
- MIB, false);
- }
- return nullptr;
+static SPIRVType *inferTypeForResultRegister(MachineInstr *Use,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
+ return inferTypeFromSingleOperand(Use, MIB, GR, 0);
}
static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIB) {
MachineRegisterInfo &MRI = MF.getRegInfo();
- for (const MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
+ for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
SPIRVType *ResType = nullptr;
switch (Use.getOpcode()) {
case TargetOpcode::G_BUILD_VECTOR:
- ResType = deduceTypeForGBuildVectorFromUses(Use, GR);
- break;
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
- ResType = deduceTypeForGExtractVectorEltFromUses(Use, GR, MIB);
+ ResType = inferTypeForResultRegister(&Use, GR, MIB);
break;
}
if (ResType)
@@ -163,20 +151,20 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
return nullptr;
}
-static SPIRVType *deduceTypeFromOperands(MachineInstr *I,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIB) {
+static SPIRVType *inferResultTypeFromOperands(MachineInstr *I,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
Register ResVReg = I->getOperand(0).getReg();
switch (I->getOpcode()) {
case TargetOpcode::G_CONSTANT:
case TargetOpcode::G_ANYEXT:
- return deduceIntTypeFromRes(ResVReg, MIB, GR);
+ return inferIntTypeFromResult(ResVReg, MIB, GR);
case TargetOpcode::G_BUILD_VECTOR:
- return deduceTypeFromOperands(I, MIB, GR, 1, I->getNumOperands());
+ return inferTypeFromOperandRange(I, MIB, GR, 1, I->getNumOperands());
case TargetOpcode::G_SHUFFLE_VECTOR:
- return deduceTypeFromOperands(I, MIB, GR, 1, 3);
+ return inferTypeFromOperandRange(I, MIB, GR, 1, 3);
default:
- return deduceTypeFromOperand(I, MIB, GR, 1);
+ return inferTypeFromSingleOperand(I, MIB, GR, 1);
}
}
@@ -193,7 +181,7 @@ static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
if (I->getOpcode() == TargetOpcode::G_UNMERGE_VALUES)
return deduceAndAssignTypeForGUnmerge(I, MF, GR);
- SPIRVType *ResType = deduceTypeFromOperands(I, GR, MIB);
+ SPIRVType *ResType = inferResultTypeFromOperands(I, GR, MIB);
if (!ResType)
ResType = deduceTypeFromUses(ResVReg, MF, GR, MIB);
>From 0541e99ff6a89c0963434ecc58116e46151f8a1b Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 7 Oct 2025 13:26:47 -0400
Subject: [PATCH 6/9] [SPIRV] Set hasSideEffects flag to false on type and
constant opcodes
This change sets the hasSideEffects flag to false on type and constant opcodes
so that they can be considered trivially dead if their result is unused. This
means that instruction selection will now be able to remove them.
---
llvm/lib/Target/SPIRV/SPIRVInstrFormats.td | 5 +
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 179 +++++++++++-------
.../SPIRV/hlsl-intrinsics/AddUint64.ll | 2 +-
.../pointers/resource-vector-load-store.ll | 27 +--
4 files changed, 130 insertions(+), 83 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
index 2fde2b0bc0b1f..f93240dc35993 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
@@ -25,6 +25,11 @@ class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern =
let Pattern = pattern;
}
+class PureOp<bits<16> Opcode, dag outs, dag ins, string asmstr,
+ list<dag> pattern = []> : Op<Opcode, outs, ins, asmstr, pattern> {
+ let hasSideEffects = 0;
+}
+
class UnknownOp<dag outs, dag ins, string asmstr, list<dag> pattern = []>
: Op<0, outs, ins, asmstr, pattern> {
let isPseudo = 1;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a61351eba03f8..799a82c96b0f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -163,52 +163,74 @@ def OpExecutionModeId: Op<331, (outs), (ins ID:$entry, ExecutionMode:$mode, vari
// 3.42.6 Type-Declaration Instructions
-def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
-def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
-def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
- "$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
- "$type = OpTypeFloat $width">;
-def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
- "$type = OpTypeVector $compType $compCount">;
-def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
- "$type = OpTypeMatrix $colType $colCount">;
-def OpTypeImage: Op<25, (outs TYPE:$res), (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
- i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, ImageFormat:$imFormat, variable_ops),
- "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS $sampled $imFormat">;
-def OpTypeSampler: Op<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
-def OpTypeSampledImage: Op<27, (outs TYPE:$res), (ins TYPE:$imageType),
- "$res = OpTypeSampledImage $imageType">;
-def OpTypeArray: Op<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
- "$type = OpTypeArray $elementType $length">;
-def OpTypeRuntimeArray: Op<29, (outs TYPE:$type), (ins TYPE:$elementType),
- "$type = OpTypeRuntimeArray $elementType">;
-def OpTypeStruct: Op<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
-def OpTypeStructContinuedINTEL: Op<6090, (outs), (ins variable_ops),
- "OpTypeStructContinuedINTEL">;
-def OpTypeOpaque: Op<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
- "$res = OpTypeOpaque $name">;
-def OpTypePointer: Op<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
- "$res = OpTypePointer $storage $type">;
-def OpTypeFunction: Op<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
- "$funcType = OpTypeFunction $returnType">;
-def OpTypeEvent: Op<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
-def OpTypeDeviceEvent: Op<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
-def OpTypeReserveId: Op<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
-def OpTypeQueue: Op<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
-def OpTypePipe: Op<38, (outs TYPE:$res), (ins AccessQualifier:$a), "$res = OpTypePipe $a">;
-def OpTypeForwardPointer: Op<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
- "OpTypeForwardPointer $ptrType $storageClass">;
-def OpTypePipeStorage: Op<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
-def OpTypeNamedBarrier: Op<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
-def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins),
- "$res = OpTypeAccelerationStructureNV">;
-def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res),
- (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
- "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
-def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res),
- (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
- "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">;
+def OpTypeVoid : PureOp<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
+def OpTypeBool : PureOp<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
+def OpTypeInt
+ : PureOp<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
+ "$type = OpTypeInt $width $signedness">;
+def OpTypeFloat
+ : PureOp<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
+ "$type = OpTypeFloat $width">;
+def OpTypeVector
+ : PureOp<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
+ "$type = OpTypeVector $compType $compCount">;
+def OpTypeMatrix
+ : PureOp<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
+ "$type = OpTypeMatrix $colType $colCount">;
+def OpTypeImage : PureOp<25, (outs TYPE:$res),
+ (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
+ i32imm:$arrayed, i32imm:$MS, i32imm:$sampled,
+ ImageFormat:$imFormat, variable_ops),
+ "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS "
+ "$sampled $imFormat">;
+def OpTypeSampler : PureOp<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
+def OpTypeSampledImage : PureOp<27, (outs TYPE:$res), (ins TYPE:$imageType),
+ "$res = OpTypeSampledImage $imageType">;
+def OpTypeArray
+ : PureOp<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
+ "$type = OpTypeArray $elementType $length">;
+def OpTypeRuntimeArray : PureOp<29, (outs TYPE:$type), (ins TYPE:$elementType),
+ "$type = OpTypeRuntimeArray $elementType">;
+def OpTypeStruct
+ : PureOp<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
+def OpTypeStructContinuedINTEL
+ : PureOp<6090, (outs), (ins variable_ops), "OpTypeStructContinuedINTEL">;
+def OpTypeOpaque
+ : PureOp<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
+ "$res = OpTypeOpaque $name">;
+def OpTypePointer
+ : PureOp<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
+ "$res = OpTypePointer $storage $type">;
+def OpTypeFunction
+ : PureOp<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
+ "$funcType = OpTypeFunction $returnType">;
+def OpTypeEvent : PureOp<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
+def OpTypeDeviceEvent
+ : PureOp<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
+def OpTypeReserveId
+ : PureOp<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
+def OpTypeQueue : PureOp<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
+def OpTypePipe : PureOp<38, (outs TYPE:$res), (ins AccessQualifier:$a),
+ "$res = OpTypePipe $a">;
+def OpTypeForwardPointer
+ : PureOp<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
+ "OpTypeForwardPointer $ptrType $storageClass">;
+def OpTypePipeStorage
+ : PureOp<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
+def OpTypeNamedBarrier
+ : PureOp<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
+def OpTypeAccelerationStructureNV
+ : PureOp<5341, (outs TYPE:$res), (ins),
+ "$res = OpTypeAccelerationStructureNV">;
+def OpTypeCooperativeMatrixNV
+ : PureOp<5358, (outs TYPE:$res),
+ (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
+ "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
+def OpTypeCooperativeMatrixKHR
+ : PureOp<4456, (outs TYPE:$res),
+ (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
+ "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols "
+ "$use">;
// 3.42.7 Constant-Creation Instructions
@@ -222,31 +244,46 @@ defm OpConstant: IntFPImm<43, "OpConstant">;
def ConstPseudoTrue: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 1; }]>;
def ConstPseudoFalse: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 0; }]>;
-def OpConstantTrue: Op<41, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantTrue $src_ty",
- [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
-def OpConstantFalse: Op<42, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantFalse $src_ty",
- [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;
-
-def OpConstantComposite: Op<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
- "$res = OpConstantComposite $type">;
-def OpConstantCompositeContinuedINTEL: Op<6091, (outs), (ins variable_ops),
- "OpConstantCompositeContinuedINTEL">;
-
-def OpConstantSampler: Op<45, (outs ID:$res),
- (ins TYPE:$t, SamplerAddressingMode:$s, i32imm:$p, SamplerFilterMode:$f),
- "$res = OpConstantSampler $t $s $p $f">;
-def OpConstantNull: Op<46, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantNull $src_ty">;
-
-def OpSpecConstantTrue: Op<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
-def OpSpecConstantFalse: Op<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
-def OpSpecConstant: Op<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
- "$res = OpSpecConstant $type $imm">;
-def OpSpecConstantComposite: Op<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
- "$res = OpSpecConstantComposite $type">;
-def OpSpecConstantCompositeContinuedINTEL: Op<6092, (outs), (ins variable_ops),
- "OpSpecConstantCompositeContinuedINTEL">;
-def OpSpecConstantOp: Op<52, (outs ID:$res), (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
- "$res = OpSpecConstantOp $t $c $o">;
+def OpConstantTrue
+ : PureOp<41, (outs iID:$dst), (ins TYPE:$src_ty),
+ "$dst = OpConstantTrue $src_ty",
+ [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
+def OpConstantFalse
+ : PureOp<42, (outs iID:$dst), (ins TYPE:$src_ty),
+ "$dst = OpConstantFalse $src_ty",
+ [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;
+
+def OpConstantComposite
+ : PureOp<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
+ "$res = OpConstantComposite $type">;
+def OpConstantCompositeContinuedINTEL
+ : PureOp<6091, (outs), (ins variable_ops),
+ "OpConstantCompositeContinuedINTEL">;
+
+def OpConstantSampler : PureOp<45, (outs ID:$res),
+ (ins TYPE:$t, SamplerAddressingMode:$s,
+ i32imm:$p, SamplerFilterMode:$f),
+ "$res = OpConstantSampler $t $s $p $f">;
+def OpConstantNull : PureOp<46, (outs ID:$dst), (ins TYPE:$src_ty),
+ "$dst = OpConstantNull $src_ty">;
+
+def OpSpecConstantTrue
+ : PureOp<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
+def OpSpecConstantFalse
+ : PureOp<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
+def OpSpecConstant
+ : PureOp<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
+ "$res = OpSpecConstant $type $imm">;
+def OpSpecConstantComposite
+ : PureOp<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
+ "$res = OpSpecConstantComposite $type">;
+def OpSpecConstantCompositeContinuedINTEL
+ : PureOp<6092, (outs), (ins variable_ops),
+ "OpSpecConstantCompositeContinuedINTEL">;
+def OpSpecConstantOp
+ : PureOp<52, (outs ID:$res),
+ (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
+ "$res = OpSpecConstantOp $t $c $o">;
// 3.42.8 Memory Instructions
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll
index a97492b8453ea..a15d628cc3614 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll
@@ -63,7 +63,7 @@ entry:
; CHECK: %[[#a_high:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#a]] %[[#undef_v4i32]] 1 3
; CHECK: %[[#b_low:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#b]] %[[#undef_v4i32]] 0 2
; CHECK: %[[#b_high:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#b]] %[[#undef_v4i32]] 1 3
-; CHECK: %[[#iaddcarry:]] = OpIAddCarry %[[#struct_v2i32_v2i32]] %[[#a_low]] %[[#vec2_int_32]]
+; CHECK: %[[#iaddcarry:]] = OpIAddCarry %[[#struct_v2i32_v2i32]] %[[#a_low]] %[[#b_low]]
; CHECK: %[[#lowsum:]] = OpCompositeExtract %[[#vec2_int_32]] %[[#iaddcarry]] 0
; CHECK: %[[#carry:]] = OpCompositeExtract %[[#vec2_int_32]] %[[#iaddcarry]] 1
; CHECK: %[[#carry_ne0:]] = OpINotEqual %[[#vec2_bool]] %[[#carry]] %[[#const_v2i32_0_0]]
diff --git a/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll b/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll
index 7548f4757dbe6..6fc03a386d14d 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll
@@ -4,18 +4,23 @@
@.str = private unnamed_addr constant [7 x i8] c"buffer\00", align 1
+; The i64 values in the extracts will be turned
+; into immediate values. There should be no 64-bit
+; integers in the module.
+; CHECK-NOT: OpTypeInt 64 0
+
define void @main() "hlsl.shader"="pixel" {
-; CHECK: %24 = OpFunction %2 None %3 ; -- Begin function main
-; CHECK-NEXT: %1 = OpLabel
-; CHECK-NEXT: %25 = OpVariable %13 Function %22
-; CHECK-NEXT: %26 = OpLoad %7 %23
-; CHECK-NEXT: %27 = OpImageRead %5 %26 %15
-; CHECK-NEXT: %28 = OpCompositeExtract %4 %27 0
-; CHECK-NEXT: %29 = OpCompositeExtract %4 %27 1
-; CHECK-NEXT: %30 = OpFAdd %4 %29 %28
-; CHECK-NEXT: %31 = OpCompositeInsert %5 %30 %27 0
-; CHECK-NEXT: %32 = OpLoad %7 %23
-; CHECK-NEXT: OpImageWrite %32 %15 %31
+; CHECK: %[[FUNC:[0-9]+]] = OpFunction %[[VOID:[0-9]+]] None %[[FNTYPE:[0-9]+]] ; -- Begin function main
+; CHECK-NEXT: %[[LABEL:[0-9]+]] = OpLabel
+; CHECK-NEXT: %[[VAR:[0-9]+]] = OpVariable %[[PTR_FN:[a-zA-Z0-9_]+]] Function %[[INIT:[a-zA-Z0-9_]+]]
+; CHECK-NEXT: %[[LOAD1:[0-9]+]] = OpLoad %[[IMG_TYPE:[a-zA-Z0-9_]+]] %[[IMG_VAR:[a-zA-Z0-9_]+]]
+; CHECK-NEXT: %[[READ:[0-9]+]] = OpImageRead %[[VEC4:[a-zA-Z0-9_]+]] %[[LOAD1]] %[[COORD:[a-zA-Z0-9_]+]]
+; CHECK-NEXT: %[[EXTRACT1:[0-9]+]] = OpCompositeExtract %[[FLOAT:[a-zA-Z0-9_]+]] %[[READ]] 0
+; CHECK-NEXT: %[[EXTRACT2:[0-9]+]] = OpCompositeExtract %[[FLOAT]] %[[READ]] 1
+; CHECK-NEXT: %[[ADD:[0-9]+]] = OpFAdd %[[FLOAT]] %[[EXTRACT2]] %[[EXTRACT1]]
+; CHECK-NEXT: %[[INSERT:[0-9]+]] = OpCompositeInsert %[[VEC4]] %[[ADD]] %[[READ]] 0
+; CHECK-NEXT: %[[LOAD2:[0-9]+]] = OpLoad %[[IMG_TYPE]] %[[IMG_VAR]]
+; CHECK-NEXT: OpImageWrite %[[LOAD2]] %[[COORD]] %[[INSERT]]
; CHECK-NEXT: OpReturn
; CHECK-NEXT: OpFunctionEnd
entry:
>From 9f7ea63dfdcdd33fd26fa5c7b8881b22bcf3dd16 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Wed, 5 Nov 2025 09:30:19 -0500
Subject: [PATCH 7/9] t
---
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 34 ++++++++++++++++----
1 file changed, 28 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index 7bb2af3096ab2..8e478899ec42a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -128,9 +128,23 @@ static SPIRVType *inferTypeFromOperandRange(MachineInstr *I,
}
static SPIRVType *inferTypeForResultRegister(MachineInstr *Use,
+ Register UseRegister,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIB) {
- return inferTypeFromSingleOperand(Use, MIB, GR, 0);
+ for (const MachineOperand &MO : Use->defs()) {
+ if (!MO.isReg())
+ continue;
+ if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(MO.getReg())) {
+ if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(OpType)) {
+ const LLT &ResLLT = MIB.getMRI()->getType(UseRegister);
+ if (ResLLT.isVector())
+ return GR->getOrCreateSPIRVVectorType(
+ CompType, ResLLT.getNumElements(), MIB, false);
+ return CompType;
+ }
+ }
+ }
+ return nullptr;
}
static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
@@ -142,7 +156,9 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
switch (Use.getOpcode()) {
case TargetOpcode::G_BUILD_VECTOR:
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
- ResType = inferTypeForResultRegister(&Use, GR, MIB);
+ case TargetOpcode::G_UNMERGE_VALUES:
+ LLVM_DEBUG(dbgs() << "Looking at use " << Use);
+ ResType = inferTypeForResultRegister(&Use, Reg, GR, MIB);
break;
}
if (ResType)
@@ -164,14 +180,17 @@ static SPIRVType *inferResultTypeFromOperands(MachineInstr *I,
case TargetOpcode::G_SHUFFLE_VECTOR:
return inferTypeFromOperandRange(I, MIB, GR, 1, 3);
default:
- return inferTypeFromSingleOperand(I, MIB, GR, 1);
+ if (I->getNumDefs() == 1 && I->getNumOperands() > 1 &&
+ I->getOperand(1).isReg())
+ return inferTypeFromSingleOperand(I, MIB, GR, 1);
+ return nullptr;
}
}
static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIB) {
- LLVM_DEBUG(dbgs() << "Processing instruction: " << *I);
+ LLVM_DEBUG(dbgs() << "\nProcessing instruction: " << *I);
MachineRegisterInfo &MRI = MF.getRegInfo();
Register ResVReg = I->getOperand(0).getReg();
@@ -181,12 +200,15 @@ static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
if (I->getOpcode() == TargetOpcode::G_UNMERGE_VALUES)
return deduceAndAssignTypeForGUnmerge(I, MF, GR);
+ LLVM_DEBUG(dbgs() << "Inferring type from operands\n");
SPIRVType *ResType = inferResultTypeFromOperands(I, GR, MIB);
- if (!ResType)
+ if (!ResType) {
+ LLVM_DEBUG(dbgs() << "Inferring type from uses\n");
ResType = deduceTypeFromUses(ResVReg, MF, GR, MIB);
+ }
if (ResType) {
- LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n");
+ LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType);
GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF);
if (!MRI.getRegClassOrNull(ResVReg)) {
>From 3ec8f3954a20ea53c6c1a90c870cc66295d740aa Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Wed, 5 Nov 2025 09:34:45 -0500
Subject: [PATCH 8/9] t
---
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 56 ++++++++++----------
1 file changed, 28 insertions(+), 28 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index 8e478899ec42a..7bd68dfd2777f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -44,9 +44,9 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
SPIRVType *KnownResType);
} // namespace llvm
-static SPIRVType *inferIntTypeFromResult(Register ResVReg,
- MachineIRBuilder &MIB,
- SPIRVGlobalRegistry *GR) {
+static SPIRVType *deduceIntTypeFromResult(Register ResVReg,
+ MachineIRBuilder &MIB,
+ SPIRVGlobalRegistry *GR) {
const LLT &Ty = MIB.getMRI()->getType(ResVReg);
return GR->getOrCreateSPIRVIntegerType(Ty.getScalarSizeInBits(), MIB);
}
@@ -98,39 +98,39 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
return true;
}
-static SPIRVType *inferTypeFromSingleOperand(MachineInstr *I,
- MachineIRBuilder &MIB,
- SPIRVGlobalRegistry *GR,
- unsigned OpIdx) {
+static SPIRVType *deduceTypeFromSingleOperand(MachineInstr *I,
+ MachineIRBuilder &MIB,
+ SPIRVGlobalRegistry *GR,
+ unsigned OpIdx) {
Register OpReg = I->getOperand(OpIdx).getReg();
if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(OpReg)) {
if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(OpType)) {
Register ResVReg = I->getOperand(0).getReg();
const LLT &ResLLT = MIB.getMRI()->getType(ResVReg);
if (ResLLT.isVector())
- return GR->getOrCreateSPIRVVectorType(CompType, ResLLT.getNumElements(),
- MIB, false);
+ return GR->getOrCreateSPIRVVectorType(
+ CompType, ResLLT.getNumElements(), MIB, false);
return CompType;
}
}
return nullptr;
}
-static SPIRVType *inferTypeFromOperandRange(MachineInstr *I,
- MachineIRBuilder &MIB,
- SPIRVGlobalRegistry *GR,
- unsigned StartOp, unsigned EndOp) {
+static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I,
+ MachineIRBuilder &MIB,
+ SPIRVGlobalRegistry *GR,
+ unsigned StartOp, unsigned EndOp) {
for (unsigned i = StartOp; i < EndOp; ++i) {
- if (SPIRVType *Type = inferTypeFromSingleOperand(I, MIB, GR, i))
+ if (SPIRVType *Type = deduceTypeFromSingleOperand(I, MIB, GR, i))
return Type;
}
return nullptr;
}
-static SPIRVType *inferTypeForResultRegister(MachineInstr *Use,
- Register UseRegister,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIB) {
+static SPIRVType *deduceTypeForResultRegister(MachineInstr *Use,
+ Register UseRegister,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
for (const MachineOperand &MO : Use->defs()) {
if (!MO.isReg())
continue;
@@ -157,8 +157,8 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
case TargetOpcode::G_BUILD_VECTOR:
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
case TargetOpcode::G_UNMERGE_VALUES:
- LLVM_DEBUG(dbgs() << "Looking at use " << Use);
- ResType = inferTypeForResultRegister(&Use, Reg, GR, MIB);
+ LLVM_DEBUG(dbgs() << "Looking at use " << Use << "\n");
+ ResType = deduceTypeForResultRegister(&Use, Reg, GR, MIB);
break;
}
if (ResType)
@@ -167,22 +167,22 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
return nullptr;
}
-static SPIRVType *inferResultTypeFromOperands(MachineInstr *I,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIB) {
+static SPIRVType *deduceResultTypeFromOperands(MachineInstr *I,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
Register ResVReg = I->getOperand(0).getReg();
switch (I->getOpcode()) {
case TargetOpcode::G_CONSTANT:
case TargetOpcode::G_ANYEXT:
- return inferIntTypeFromResult(ResVReg, MIB, GR);
+ return deduceIntTypeFromResult(ResVReg, MIB, GR);
case TargetOpcode::G_BUILD_VECTOR:
- return inferTypeFromOperandRange(I, MIB, GR, 1, I->getNumOperands());
+ return deduceTypeFromOperandRange(I, MIB, GR, 1, I->getNumOperands());
case TargetOpcode::G_SHUFFLE_VECTOR:
- return inferTypeFromOperandRange(I, MIB, GR, 1, 3);
+ return deduceTypeFromOperandRange(I, MIB, GR, 1, 3);
default:
if (I->getNumDefs() == 1 && I->getNumOperands() > 1 &&
I->getOperand(1).isReg())
- return inferTypeFromSingleOperand(I, MIB, GR, 1);
+ return deduceTypeFromSingleOperand(I, MIB, GR, 1);
return nullptr;
}
}
@@ -201,7 +201,7 @@ static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
return deduceAndAssignTypeForGUnmerge(I, MF, GR);
LLVM_DEBUG(dbgs() << "Inferring type from operands\n");
- SPIRVType *ResType = inferResultTypeFromOperands(I, GR, MIB);
+ SPIRVType *ResType = deduceResultTypeFromOperands(I, GR, MIB);
if (!ResType) {
LLVM_DEBUG(dbgs() << "Inferring type from uses\n");
ResType = deduceTypeFromUses(ResVReg, MF, GR, MIB);
>From 2bdaa1c7bd28e27ac7f35414eba746d43b559f8c Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Fri, 24 Oct 2025 15:16:19 -0400
Subject: [PATCH 9/9] [SPIRV] Add legalization for long vectors
This patch introduces the necessary infrastructure to legalize vector
operations on vectors that are longer than what the SPIR-V target
supports. For instance, shaders only support vectors up to 4 elements.
The legalization is done by splitting the long vectors into smaller
vectors of a legal size.
Specifically, this patch does the following:
- Introduces `vectorElementCountIsGreaterThan` and
`vectorElementCountIsLessThanOrEqualTo` legality predicates.
- Adds legalization rules for `G_SHUFFLE_VECTOR`, `G_EXTRACT_VECTOR_ELT`,
`G_BUILD_VECTOR`, `G_CONCAT_VECTORS`, `G_SPLAT_VECTOR`, and
`G_UNMERGE_VALUES`.
- Handles `G_BITCAST` of long vectors by converting them to
`@llvm.spv.bitcast` intrinsics which are then legalized.
- Updates `selectUnmergeValues` to handle extraction of both scalars
and vectors from a larger vector, using `OpCompositeExtract` and
`OpVectorShuffle` respectively.
- Adds a test case to verify the legalization of a bitcast between
a `<8 x i32>` and `<4 x f64>`, which is a pattern generated by
HLSL's `asuint` and `asdouble` intrinsics.
Part of: https://github.com/llvm/llvm-project/issues/153091
---
.../llvm/CodeGen/GlobalISel/LegalizerInfo.h | 10 ++
.../CodeGen/GlobalISel/LegalityPredicates.cpp | 20 +++
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 50 ++++--
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 165 ++++++++++++++++--
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h | 4 +
.../SPIRV/legalization/load-store-global.ll | 84 +++++++++
.../vector-legalization-kernel.ll | 69 ++++++++
7 files changed, 378 insertions(+), 24 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
create mode 100644 llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 51318c9c2736d..a8748965eb2e8 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
unsigned Size);
+/// True iff the specified type index is a fixed vector whose element count
+/// is greater than the given size.
+LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
+ unsigned Size);
+
+/// True iff the specified type index is a fixed vector whose element count
+/// is less than or equal to the given size.
+LLVM_ABI LegalityPredicate
+vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size);
+
/// True iff the specified type index is a scalar or a vector with an element
/// type that's wider than the given size.
LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
index 30c2d089c3121..5e7cd5fd5d9ad 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
@@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
};
}
+LegalityPredicate
+LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
+ unsigned Size) {
+
+ return [=](const LegalityQuery &Query) {
+ const LLT QueryTy = Query.Types[TypeIdx];
+ return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size;
+ };
+}
+
+LegalityPredicate
+LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx,
+ unsigned Size) {
+
+ return [=](const LegalityQuery &Query) {
+ const LLT QueryTy = Query.Types[TypeIdx];
+ return QueryTy.isFixedVector() && QueryTy.getNumElements() <= Size;
+ };
+}
+
LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
unsigned Size) {
return [=](const LegalityQuery &Query) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3f0424f436c72..23f4ee2cd0f7e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1529,33 +1529,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
unsigned ArgI = I.getNumOperands() - 1;
Register SrcReg =
I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
- SPIRVType *DefType =
+ SPIRVType *SrcType =
SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
- if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+ if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
report_fatal_error(
"cannot select G_UNMERGE_VALUES with a non-vector argument");
SPIRVType *ScalarType =
- GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+ GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
MachineBasicBlock &BB = *I.getParent();
bool Res = false;
+ unsigned CurrentIndex = 0;
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
Register ResVReg = I.getOperand(i).getReg();
SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
if (!ResType) {
- // There was no "assign type" actions, let's fix this now
- ResType = ScalarType;
+ LLT ResLLT = MRI->getType(ResVReg);
+ assert(ResLLT.isValid());
+ if (ResLLT.isVector()) {
+ ResType = GR.getOrCreateSPIRVVectorType(
+ ScalarType, ResLLT.getNumElements(), I, TII);
+ } else {
+ ResType = ScalarType;
+ }
MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
- MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
}
- auto MIB =
- BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(SrcReg)
- .addImm(static_cast<int64_t>(i));
- Res |= MIB.constrainAllUses(TII, TRI, RBI);
+
+ if (ResType->getOpcode() == SPIRV::OpTypeVector) {
+ Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
+ auto MIB =
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(SrcReg)
+ .addUse(UndefReg);
+ unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
+ for (unsigned j = 0; j < NumElements; ++j) {
+ MIB.addImm(CurrentIndex + j);
+ }
+ CurrentIndex += NumElements;
+ Res |= MIB.constrainAllUses(TII, TRI, RBI);
+ } else {
+ auto MIB =
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(SrcReg)
+ .addImm(CurrentIndex);
+ CurrentIndex++;
+ Res |= MIB.constrainAllUses(TII, TRI, RBI);
+ }
}
return Res;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 53074ea3b2597..a7d6bde3c5f1a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -14,16 +14,22 @@
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVSubtarget.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
using namespace llvm;
using namespace llvm::LegalizeActions;
using namespace llvm::LegalityPredicates;
+#define DEBUG_TYPE "spirv-legalizer"
+
LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
const LLT Ty = Query.Types[TypeIdx];
@@ -101,6 +107,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
v16s8, v16s16, v16s32, v16s64};
+ auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
+ v3s1, v3s8, v3s16, v3s32, v3s64,
+ v4s1, v4s8, v4s16, v4s32, v4s64};
+
auto allScalarsAndVectors = {
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
@@ -126,6 +136,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
+ auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;
+
bool IsExtendedInts =
ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
@@ -148,14 +160,65 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
return IsExtendedInts && Ty.isValid();
};
- for (auto Opc : getTypeFoldingSupportedOpcodes())
- getActionDefinitionsBuilder(Opc).custom();
+ uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
- getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
+ for (auto Opc : getTypeFoldingSupportedOpcodes()) {
+ if (Opc != G_EXTRACT_VECTOR_ELT)
+ getActionDefinitionsBuilder(Opc).custom();
+ }
- // TODO: add proper rules for vectors legalization.
- getActionDefinitionsBuilder(
- {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
+ getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();
+
+ getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
+ .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
+ .moreElementsToNextPow2(0)
+ .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
+ .moreElementsToNextPow2(1)
+ .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
+ .alwaysLegal();
+
+ getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
+ .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize))
+ .moreElementsToNextPow2(1)
+ .fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize),
+ LegalizeMutations::changeElementCountTo(
+ 1, ElementCount::getFixed(MaxVectorSize)))
+ .custom();
+
+  // Long G_BUILD_VECTOR instructions are split into vectors with at most
+  // MaxVectorSize elements.
+ getActionDefinitionsBuilder(G_BUILD_VECTOR)
+ .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+ .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+ LegalizeMutations::changeElementCountTo(
+ 0, ElementCount::getFixed(MaxVectorSize)));
+
+ // When entering the legalizer, there should be no G_BITCAST instructions.
+ // They should all be calls to the `spv_bitcast` intrinsic. The call to
+ // the intrinsic will be converted to a G_BITCAST during legalization if
+ // the vectors are not legal. After using the rules to legalize a G_BITCAST,
+  // we turn it back into a call to the intrinsic with a custom rule to avoid
+  // potential machine verifier failures.
+ getActionDefinitionsBuilder(G_BITCAST)
+ .moreElementsToNextPow2(0)
+ .moreElementsToNextPow2(1)
+ .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+ LegalizeMutations::changeElementCountTo(
+ 0, ElementCount::getFixed(MaxVectorSize)))
+ .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
+ .custom();
+
+ getActionDefinitionsBuilder(G_CONCAT_VECTORS)
+ .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+ .moreElementsToNextPow2(0)
+ .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
+ .alwaysLegal();
+
+ getActionDefinitionsBuilder(G_SPLAT_VECTOR)
+ .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+ .moreElementsToNextPow2(0)
+ .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+ LegalizeMutations::changeElementSizeTo(0, MaxVectorSize))
.alwaysLegal();
// Vector Reduction Operations
@@ -164,7 +227,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
- .legalFor(allVectors)
+ .legalFor(allowedVectorTypes)
.scalarize(1)
.lower();
@@ -172,9 +235,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
.scalarize(2)
.lower();
- // Merge/Unmerge
- // TODO: add proper legalization rules.
- getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
+ // Illegal G_UNMERGE_VALUES instructions should be handled
+ // during the combine phase.
+ getActionDefinitionsBuilder(G_UNMERGE_VALUES)
+ .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize));
getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
.legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
@@ -228,7 +292,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
all(typeInSet(0, allPtrsScalarsAndVectors),
typeInSet(1, allPtrsScalarsAndVectors)));
- getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
+ getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
+ .legalFor({s1})
+ .legalFor(allFloatAndIntScalarsAndPtrs)
+ .legalFor(allowedVectorTypes)
+ .moreElementsToNextPow2(0)
+ .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+ LegalizeMutations::changeElementCountTo(
+ 0, ElementCount::getFixed(MaxVectorSize)));
getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
@@ -287,6 +358,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
// Pointer-handling.
getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
+ getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);
+
// Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
@@ -374,6 +447,12 @@ bool SPIRVLegalizerInfo::legalizeCustom(
default:
// TODO: implement legalization for other opcodes.
return true;
+ case TargetOpcode::G_BITCAST:
+ return legalizeBitcast(Helper, MI);
+ case TargetOpcode::G_INTRINSIC:
+ case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
+ return legalizeIntrinsic(Helper, MI);
+
case TargetOpcode::G_IS_FPCLASS:
return legalizeIsFPClass(Helper, MI, LocObserver);
case TargetOpcode::G_ICMP: {
@@ -400,6 +479,70 @@ bool SPIRVLegalizerInfo::legalizeCustom(
}
}
+bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
+ MachineInstr &MI) const {
+ LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
+
+ MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+ MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
+ const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
+
+ auto IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
+ if (IntrinsicID == Intrinsic::spv_bitcast) {
+ LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");
+ Register DstReg = MI.getOperand(0).getReg();
+ Register SrcReg = MI.getOperand(2).getReg();
+ LLT DstTy = MRI.getType(DstReg);
+ LLT SrcTy = MRI.getType(SrcReg);
+
+ int32_t MaxVectorSize = ST.isShader() ? 4 : 16;
+
+ bool DstNeedsLegalization = false;
+ bool SrcNeedsLegalization = false;
+
+ if (DstTy.isVector()) {
+ if (DstTy.getNumElements() > 4 &&
+ !isPowerOf2_32(DstTy.getNumElements())) {
+ DstNeedsLegalization = true;
+ }
+
+ if (DstTy.getNumElements() > MaxVectorSize) {
+ DstNeedsLegalization = true;
+ }
+ }
+
+ if (SrcTy.isVector()) {
+ if (SrcTy.getNumElements() > 4 &&
+ !isPowerOf2_32(SrcTy.getNumElements())) {
+ SrcNeedsLegalization = true;
+ }
+
+ if (SrcTy.getNumElements() > MaxVectorSize) {
+ SrcNeedsLegalization = true;
+ }
+ }
+
+ if (DstNeedsLegalization || SrcNeedsLegalization) {
+ LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
+ MIRBuilder.buildBitcast(DstReg, SrcReg);
+ MI.eraseFromParent();
+ }
+ return true;
+ }
+ return true;
+}
+
+bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
+ MachineInstr &MI) const {
+ MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+ Register DstReg = MI.getOperand(0).getReg();
+ Register SrcReg = MI.getOperand(1).getReg();
+ SmallVector<Register, 1> DstRegs = {DstReg};
+ MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg);
+ MI.eraseFromParent();
+ return true;
+}
+
// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
// to ensure that all instructions created during the lowering have SPIR-V types
// assigned to them.
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
index eeefa4239c778..86e7e711caa60 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
@@ -29,11 +29,15 @@ class SPIRVLegalizerInfo : public LegalizerInfo {
public:
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
LostDebugLocObserver &LocObserver) const override;
+ bool legalizeIntrinsic(LegalizerHelper &Helper,
+ MachineInstr &MI) const override;
+
SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
private:
bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI,
LostDebugLocObserver &LocObserver) const;
+ bool legalizeBitcast(LegalizerHelper &Helper, MachineInstr &MI) const;
};
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
diff --git a/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
new file mode 100644
index 0000000000000..fbfec1b3ee7cf
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
@@ -0,0 +1,84 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion"
+; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4
+; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#v4f64:]] = OpTypeVector %[[#double]] 4
+; CHECK-DAG: %[[#v2i32:]] = OpTypeVector %[[#int]] 2
+; CHECK-DAG: %[[#ptr_private_v4i32:]] = OpTypePointer Private %[[#v4i32]]
+; CHECK-DAG: %[[#ptr_private_v4f64:]] = OpTypePointer Private %[[#v4f64]]
+; CHECK-DAG: %[[#global_double:]] = OpVariable %[[#ptr_private_v4f64]] Private
+
+ at G_16 = internal addrspace(10) global [16 x i32] zeroinitializer
+ at G_4_double = internal addrspace(10) global <4 x double> zeroinitializer
+ at G_4_int = internal addrspace(10) global <4 x i32> zeroinitializer
+
+
+; This is the way matrices will be represented in HLSL. The memory type will be
+; an array, but it will be loaded as a vector.
+; TODO: Legalization for loads and stores of long vectors is not implemented yet.
+;define spir_func void @test_load_store_global() {
+;entry:
+;  %0 = load <16 x i32>, ptr addrspace(10) @G_16, align 64
+;  store <16 x i32> %0, ptr addrspace(10) @G_16, align 64
+;  ret void
+;}
+
+; This is the code pattern that can be generated from the `asuint` and `asdouble`
+; HLSL intrinsics.
+
+; TODO: This code is not the best because instruction selection is not folding an
+; extract from another instruction. That needs to be handled.
+define spir_func void @test_int32_double_conversion() {
+; CHECK: %[[#test_int32_double_conversion]] = OpFunction
+entry:
+ ; CHECK: %[[#LOAD:]] = OpLoad %[[#v4f64]] %[[#global_double]]
+ ; CHECK: %[[#VEC_SHUF1:]] = OpVectorShuffle %{{[a-zA-Z0-9_]+}} %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 1
+ ; CHECK: %[[#VEC_SHUF2:]] = OpVectorShuffle %{{[a-zA-Z0-9_]+}} %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 2 3
+ ; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#v4i32]] %[[#VEC_SHUF1]]
+ ; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#v4i32]] %[[#VEC_SHUF2]]
+ ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 0
+ ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 2
+ ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 0
+ ; CHECK: %[[#EXTRACT4:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 2
+ ; CHECK: %[[#CONSTRUCT1:]] = OpCompositeConstruct %[[#v4i32]] %[[#EXTRACT1]] %[[#EXTRACT2]] %[[#EXTRACT3]] %[[#EXTRACT4]]
+ ; CHECK: %[[#EXTRACT5:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 1
+ ; CHECK: %[[#EXTRACT6:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 3
+ ; CHECK: %[[#EXTRACT7:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 1
+ ; CHECK: %[[#EXTRACT8:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 3
+ ; CHECK: %[[#CONSTRUCT2:]] = OpCompositeConstruct %[[#v4i32]] %[[#EXTRACT5]] %[[#EXTRACT6]] %[[#EXTRACT7]] %[[#EXTRACT8]]
+ ; CHECK: %[[#EXTRACT9:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 0
+ ; CHECK: %[[#EXTRACT10:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 0
+ ; CHECK: %[[#EXTRACT11:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 1
+ ; CHECK: %[[#EXTRACT12:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 1
+ ; CHECK: %[[#EXTRACT13:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 2
+ ; CHECK: %[[#EXTRACT14:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 2
+ ; CHECK: %[[#EXTRACT15:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 3
+ ; CHECK: %[[#EXTRACT16:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 3
+ ; CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT9]] %[[#EXTRACT10]]
+ ; CHECK: %[[#CONSTRUCT4:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT11]] %[[#EXTRACT12]]
+ ; CHECK: %[[#CONSTRUCT5:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT13]] %[[#EXTRACT14]]
+ ; CHECK: %[[#CONSTRUCT6:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT15]] %[[#EXTRACT16]]
+ ; CHECK: %[[#BITCAST3:]] = OpBitcast %[[#double]] %[[#CONSTRUCT3]]
+ ; CHECK: %[[#BITCAST4:]] = OpBitcast %[[#double]] %[[#CONSTRUCT4]]
+ ; CHECK: %[[#BITCAST5:]] = OpBitcast %[[#double]] %[[#CONSTRUCT5]]
+ ; CHECK: %[[#BITCAST6:]] = OpBitcast %[[#double]] %[[#CONSTRUCT6]]
+ ; CHECK: %[[#CONSTRUCT7:]] = OpCompositeConstruct %[[#v4f64]] %[[#BITCAST3]] %[[#BITCAST4]] %[[#BITCAST5]] %[[#BITCAST6]]
+ ; CHECK: OpStore %[[#global_double]] %[[#CONSTRUCT7]] Aligned 32
+
+ %0 = load <8 x i32>, ptr addrspace(10) @G_4_double
+ %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %3 = shufflevector <4 x i32> %1, <4 x i32> %2, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
+ store <8 x i32> %3, ptr addrspace(10) @G_4_double
+ ret void
+}
+
+; Add a main function to make it a valid module for spirv-val
+define void @main() #1 {
+ ret void
+}
+
+attributes #1 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll
new file mode 100644
index 0000000000000..4fe6f217dd40f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll
@@ -0,0 +1,69 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion"
+; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#v8i32:]] = OpTypeVector %[[#int]] 8
+; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4
+; CHECK-DAG: %[[#ptr_func_v8i32:]] = OpTypePointer Function %[[#v8i32]]
+
+; CHECK-DAG: OpName %[[#test_v3f64_conversion:]] "test_v3f64_conversion"
+; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#v3f64:]] = OpTypeVector %[[#double]] 3
+; CHECK-DAG: %[[#ptr_func_v3f64:]] = OpTypePointer Function %[[#v3f64]]
+; CHECK-DAG: %[[#v4f64:]] = OpTypeVector %[[#double]] 4
+
+define spir_kernel void @test_int32_double_conversion(ptr %G_vec) {
+; CHECK: %[[#test_int32_double_conversion]] = OpFunction
+; CHECK: %[[#param:]] = OpFunctionParameter %[[#ptr_func_v8i32]]
+entry:
+ ; CHECK: %[[#LOAD:]] = OpLoad %[[#v8i32]] %[[#param]]
+ ; CHECK: %[[#SHUF1:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 2 4 6
+ ; CHECK: %[[#SHUF2:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 1 3 5 7
+ ; CHECK: %[[#SHUF3:]] = OpVectorShuffle %[[#v8i32]] %[[#SHUF1]] %[[#SHUF2]] 0 4 1 5 2 6 3 7
+ ; CHECK: OpStore %[[#param]] %[[#SHUF3]]
+
+ %0 = load <8 x i32>, ptr %G_vec
+ %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %3 = shufflevector <4 x i32> %1, <4 x i32> %2, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
+ store <8 x i32> %3, ptr %G_vec
+ ret void
+}
+
+define spir_kernel void @test_v3f64_conversion(ptr %G_vec) {
+; CHECK: %[[#test_v3f64_conversion:]] = OpFunction
+; CHECK: %[[#param_v3f64:]] = OpFunctionParameter %[[#ptr_func_v3f64]]
+entry:
+ ; CHECK: %[[#LOAD:]] = OpLoad %[[#v3f64]] %[[#param_v3f64]]
+ %0 = load <3 x double>, ptr %G_vec
+
+  ; The 6-element vector is not legal. It gets expanded to 8.
+ ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 0
+ ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 1
+ ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 2
+ ; CHECK: %[[#CONSTRUCT1:]] = OpCompositeConstruct %[[#v4f64]] %[[#EXTRACT1]] %[[#EXTRACT2]] %[[#EXTRACT3]] %{{[a-zA-Z0-9_]+}}
+ ; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#v8i32]] %[[#CONSTRUCT1]]
+ %1 = bitcast <3 x double> %0 to <6 x i32>
+
+ ; CHECK: %[[#SHUFFLE1:]] = OpVectorShuffle %[[#v8i32]] %[[#BITCAST1]] %{{[a-zA-Z0-9_]+}} 0 2 4 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF
+ %2 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
+
+ ; CHECK: %[[#SHUFFLE2:]] = OpVectorShuffle %[[#v8i32]] %[[#BITCAST1]] %{{[a-zA-Z0-9_]+}} 1 3 5 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF
+ %3 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
+
+ ; CHECK: %[[#SHUFFLE3:]] = OpVectorShuffle %[[#v8i32]] %[[#SHUFFLE1]] %[[#SHUFFLE2]] 0 8 1 9 2 10 0xFFFFFFFF 0xFFFFFFFF
+ %4 = shufflevector <3 x i32> %2, <3 x i32> %3, <6 x i32> <i32 0, i32 3, i32 1, i32 4, i32 2, i32 5>
+
+ ; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#v4f64]] %[[#SHUFFLE3]]
+ ; CHECK: %[[#EXTRACT10:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 0
+ ; CHECK: %[[#EXTRACT11:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 1
+ ; CHECK: %[[#EXTRACT12:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 2
+ ; CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v3f64]] %[[#EXTRACT10]] %[[#EXTRACT11]] %[[#EXTRACT12]]
+ %5 = bitcast <6 x i32> %4 to <3 x double>
+
+ ; CHECK: OpStore %[[#param_v3f64]] %[[#CONSTRUCT3]]
+ store <3 x double> %5, ptr %G_vec
+ ret void
+}
+
More information about the llvm-commits
mailing list