[llvm] [SPIRV] Legalize long vectors in GlobalISel (PR #164634)
Steven Perron via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 22 07:53:51 PDT 2025
https://github.com/s-perron updated https://github.com/llvm/llvm-project/pull/164634
>From 50f19ff74bd0ec15e0fc7e489a5bb380d9fd81fc Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 7 Oct 2025 13:26:47 -0400
Subject: [PATCH] [SPIRV] Legalize long vectors in GlobalISel
This commit introduces support for legalizing long vectors (vectors with more
than 4 elements) in the SPIR-V backend using GlobalISel. This is primarily
for shader compilation where the GLSL_std_450 instruction set is available.
The main changes include:
- Adding legalization rules for vector operations (G_BUILD_VECTOR,
G_SHUFFLE_VECTOR, G_EXTRACT_VECTOR_ELT, G_BITCAST, G_CONCAT_VECTORS)
to split vectors with more than 4 elements into smaller vectors.
- Enhancing the SPIRVPostLegalizer with a worklist-based approach to
correctly process instructions and types generated during legalization.
- Lowering G_EXTRACT_VECTOR_ELT to a spv_extractelt intrinsic.
- Refining the handling of G_BITCAST to legalize non-pointer bitcasts.
- Marking many SPIR-V operations as pure to aid optimization.
---
.../llvm/CodeGen/GlobalISel/LegalizerInfo.h | 5 +
.../CodeGen/GlobalISel/LegalityPredicates.cpp | 10 +
llvm/lib/Target/SPIRV/SPIRVInstrFormats.td | 5 +
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 179 +++++++----
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 50 ++-
.../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 5 +-
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 55 +++-
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 303 +++++++++++++++---
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 26 +-
9 files changed, 493 insertions(+), 145 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 51318c9c2736d..7cce0ae5359b6 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -314,6 +314,11 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
unsigned Size);
+/// True iff the specified type index is a fixed-length vector with an element
+/// count that's greater than the given size.
+LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
+ unsigned Size);
+
/// True iff the specified type index is a scalar or a vector with an element
/// type that's wider than the given size.
LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
index 30c2d089c3121..757a1fdba7fbe 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
@@ -155,6 +155,16 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
};
}
+LegalityPredicate
+LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
+ unsigned Size) {
+
+ return [=](const LegalityQuery &Query) {
+ const LLT QueryTy = Query.Types[TypeIdx];
+ return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size;
+ };
+}
+
LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
unsigned Size) {
return [=](const LegalityQuery &Query) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
index 2fde2b0bc0b1f..f93240dc35993 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
@@ -25,6 +25,11 @@ class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern =
let Pattern = pattern;
}
+class PureOp<bits<16> Opcode, dag outs, dag ins, string asmstr,
+ list<dag> pattern = []> : Op<Opcode, outs, ins, asmstr, pattern> {
+ let hasSideEffects = 0;
+}
+
class UnknownOp<dag outs, dag ins, string asmstr, list<dag> pattern = []>
: Op<0, outs, ins, asmstr, pattern> {
let isPseudo = 1;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a61351eba03f8..799a82c96b0f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -163,52 +163,74 @@ def OpExecutionModeId: Op<331, (outs), (ins ID:$entry, ExecutionMode:$mode, vari
// 3.42.6 Type-Declaration Instructions
-def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
-def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
-def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
- "$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
- "$type = OpTypeFloat $width">;
-def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
- "$type = OpTypeVector $compType $compCount">;
-def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
- "$type = OpTypeMatrix $colType $colCount">;
-def OpTypeImage: Op<25, (outs TYPE:$res), (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
- i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, ImageFormat:$imFormat, variable_ops),
- "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS $sampled $imFormat">;
-def OpTypeSampler: Op<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
-def OpTypeSampledImage: Op<27, (outs TYPE:$res), (ins TYPE:$imageType),
- "$res = OpTypeSampledImage $imageType">;
-def OpTypeArray: Op<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
- "$type = OpTypeArray $elementType $length">;
-def OpTypeRuntimeArray: Op<29, (outs TYPE:$type), (ins TYPE:$elementType),
- "$type = OpTypeRuntimeArray $elementType">;
-def OpTypeStruct: Op<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
-def OpTypeStructContinuedINTEL: Op<6090, (outs), (ins variable_ops),
- "OpTypeStructContinuedINTEL">;
-def OpTypeOpaque: Op<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
- "$res = OpTypeOpaque $name">;
-def OpTypePointer: Op<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
- "$res = OpTypePointer $storage $type">;
-def OpTypeFunction: Op<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
- "$funcType = OpTypeFunction $returnType">;
-def OpTypeEvent: Op<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
-def OpTypeDeviceEvent: Op<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
-def OpTypeReserveId: Op<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
-def OpTypeQueue: Op<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
-def OpTypePipe: Op<38, (outs TYPE:$res), (ins AccessQualifier:$a), "$res = OpTypePipe $a">;
-def OpTypeForwardPointer: Op<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
- "OpTypeForwardPointer $ptrType $storageClass">;
-def OpTypePipeStorage: Op<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
-def OpTypeNamedBarrier: Op<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
-def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins),
- "$res = OpTypeAccelerationStructureNV">;
-def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res),
- (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
- "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
-def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res),
- (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
- "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">;
+def OpTypeVoid : PureOp<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
+def OpTypeBool : PureOp<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
+def OpTypeInt
+ : PureOp<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
+ "$type = OpTypeInt $width $signedness">;
+def OpTypeFloat
+ : PureOp<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
+ "$type = OpTypeFloat $width">;
+def OpTypeVector
+ : PureOp<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
+ "$type = OpTypeVector $compType $compCount">;
+def OpTypeMatrix
+ : PureOp<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
+ "$type = OpTypeMatrix $colType $colCount">;
+def OpTypeImage : PureOp<25, (outs TYPE:$res),
+ (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
+ i32imm:$arrayed, i32imm:$MS, i32imm:$sampled,
+ ImageFormat:$imFormat, variable_ops),
+ "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS "
+ "$sampled $imFormat">;
+def OpTypeSampler : PureOp<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
+def OpTypeSampledImage : PureOp<27, (outs TYPE:$res), (ins TYPE:$imageType),
+ "$res = OpTypeSampledImage $imageType">;
+def OpTypeArray
+ : PureOp<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
+ "$type = OpTypeArray $elementType $length">;
+def OpTypeRuntimeArray : PureOp<29, (outs TYPE:$type), (ins TYPE:$elementType),
+ "$type = OpTypeRuntimeArray $elementType">;
+def OpTypeStruct
+ : PureOp<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
+def OpTypeStructContinuedINTEL
+ : PureOp<6090, (outs), (ins variable_ops), "OpTypeStructContinuedINTEL">;
+def OpTypeOpaque
+ : PureOp<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
+ "$res = OpTypeOpaque $name">;
+def OpTypePointer
+ : PureOp<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
+ "$res = OpTypePointer $storage $type">;
+def OpTypeFunction
+ : PureOp<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
+ "$funcType = OpTypeFunction $returnType">;
+def OpTypeEvent : PureOp<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
+def OpTypeDeviceEvent
+ : PureOp<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
+def OpTypeReserveId
+ : PureOp<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
+def OpTypeQueue : PureOp<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
+def OpTypePipe : PureOp<38, (outs TYPE:$res), (ins AccessQualifier:$a),
+ "$res = OpTypePipe $a">;
+def OpTypeForwardPointer
+ : PureOp<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
+ "OpTypeForwardPointer $ptrType $storageClass">;
+def OpTypePipeStorage
+ : PureOp<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
+def OpTypeNamedBarrier
+ : PureOp<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
+def OpTypeAccelerationStructureNV
+ : PureOp<5341, (outs TYPE:$res), (ins),
+ "$res = OpTypeAccelerationStructureNV">;
+def OpTypeCooperativeMatrixNV
+ : PureOp<5358, (outs TYPE:$res),
+ (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
+ "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
+def OpTypeCooperativeMatrixKHR
+ : PureOp<4456, (outs TYPE:$res),
+ (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
+ "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols "
+ "$use">;
// 3.42.7 Constant-Creation Instructions
@@ -222,31 +244,46 @@ defm OpConstant: IntFPImm<43, "OpConstant">;
def ConstPseudoTrue: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 1; }]>;
def ConstPseudoFalse: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 0; }]>;
-def OpConstantTrue: Op<41, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantTrue $src_ty",
- [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
-def OpConstantFalse: Op<42, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantFalse $src_ty",
- [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;
-
-def OpConstantComposite: Op<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
- "$res = OpConstantComposite $type">;
-def OpConstantCompositeContinuedINTEL: Op<6091, (outs), (ins variable_ops),
- "OpConstantCompositeContinuedINTEL">;
-
-def OpConstantSampler: Op<45, (outs ID:$res),
- (ins TYPE:$t, SamplerAddressingMode:$s, i32imm:$p, SamplerFilterMode:$f),
- "$res = OpConstantSampler $t $s $p $f">;
-def OpConstantNull: Op<46, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantNull $src_ty">;
-
-def OpSpecConstantTrue: Op<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
-def OpSpecConstantFalse: Op<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
-def OpSpecConstant: Op<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
- "$res = OpSpecConstant $type $imm">;
-def OpSpecConstantComposite: Op<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
- "$res = OpSpecConstantComposite $type">;
-def OpSpecConstantCompositeContinuedINTEL: Op<6092, (outs), (ins variable_ops),
- "OpSpecConstantCompositeContinuedINTEL">;
-def OpSpecConstantOp: Op<52, (outs ID:$res), (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
- "$res = OpSpecConstantOp $t $c $o">;
+def OpConstantTrue
+ : PureOp<41, (outs iID:$dst), (ins TYPE:$src_ty),
+ "$dst = OpConstantTrue $src_ty",
+ [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
+def OpConstantFalse
+ : PureOp<42, (outs iID:$dst), (ins TYPE:$src_ty),
+ "$dst = OpConstantFalse $src_ty",
+ [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;
+
+def OpConstantComposite
+ : PureOp<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
+ "$res = OpConstantComposite $type">;
+def OpConstantCompositeContinuedINTEL
+ : PureOp<6091, (outs), (ins variable_ops),
+ "OpConstantCompositeContinuedINTEL">;
+
+def OpConstantSampler : PureOp<45, (outs ID:$res),
+ (ins TYPE:$t, SamplerAddressingMode:$s,
+ i32imm:$p, SamplerFilterMode:$f),
+ "$res = OpConstantSampler $t $s $p $f">;
+def OpConstantNull : PureOp<46, (outs ID:$dst), (ins TYPE:$src_ty),
+ "$dst = OpConstantNull $src_ty">;
+
+def OpSpecConstantTrue
+ : PureOp<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
+def OpSpecConstantFalse
+ : PureOp<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
+def OpSpecConstant
+ : PureOp<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
+ "$res = OpSpecConstant $type $imm">;
+def OpSpecConstantComposite
+ : PureOp<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
+ "$res = OpSpecConstantComposite $type">;
+def OpSpecConstantCompositeContinuedINTEL
+ : PureOp<6092, (outs), (ins variable_ops),
+ "OpSpecConstantCompositeContinuedINTEL">;
+def OpSpecConstantOp
+ : PureOp<52, (outs ID:$res),
+ (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
+ "$res = OpSpecConstantOp $t $c $o">;
// 3.42.8 Memory Instructions
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 021353ab716f7..23fe0b5a15041 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1526,33 +1526,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
unsigned ArgI = I.getNumOperands() - 1;
Register SrcReg =
I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
- SPIRVType *DefType =
+ SPIRVType *SrcType =
SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
- if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+ if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
report_fatal_error(
"cannot select G_UNMERGE_VALUES with a non-vector argument");
SPIRVType *ScalarType =
- GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+ GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
MachineBasicBlock &BB = *I.getParent();
bool Res = false;
+ unsigned CurrentIndex = 0;
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
Register ResVReg = I.getOperand(i).getReg();
SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
if (!ResType) {
- // There was no "assign type" actions, let's fix this now
- ResType = ScalarType;
+ LLT ResLLT = MRI->getType(ResVReg);
+ assert(ResLLT.isValid());
+ if (ResLLT.isVector()) {
+ ResType = GR.getOrCreateSPIRVVectorType(
+ ScalarType, ResLLT.getNumElements(), I, TII);
+ } else {
+ ResType = ScalarType;
+ }
MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
- MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
}
- auto MIB =
- BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(SrcReg)
- .addImm(static_cast<int64_t>(i));
- Res |= MIB.constrainAllUses(TII, TRI, RBI);
+
+ if (ResType->getOpcode() == SPIRV::OpTypeVector) {
+ Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
+ auto MIB =
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(SrcReg)
+ .addUse(UndefReg);
+ unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
+ for (unsigned j = 0; j < NumElements; ++j) {
+ MIB.addImm(CurrentIndex + j);
+ }
+ CurrentIndex += NumElements;
+ Res |= MIB.constrainAllUses(TII, TRI, RBI);
+ } else {
+ auto MIB =
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(SrcReg)
+ .addImm(CurrentIndex);
+ CurrentIndex++;
+ Res |= MIB.constrainAllUses(TII, TRI, RBI);
+ }
}
return Res;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 28a1690ef0be1..61de82afad389 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -73,7 +73,10 @@ class SPIRVLegalizePointerCast : public FunctionPass {
// Returns the loaded value.
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
FixedVectorType *TargetType, Value *Source) {
- assert(TargetType->getNumElements() <= SourceType->getNumElements());
+ const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+ [[maybe_unused]] TypeSize TargetTypeSize = DL.getTypeSizeInBits(TargetType);
+ [[maybe_unused]] TypeSize SourceTypeSize = DL.getTypeSizeInBits(SourceType);
+ assert(TargetTypeSize <= SourceTypeSize);
LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
buildAssignType(B, SourceType, NewLoad);
Value *AssignValue = NewLoad;
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 53074ea3b2597..bedb42752d241 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -19,6 +19,7 @@
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
using namespace llvm;
using namespace llvm::LegalizeActions;
@@ -101,6 +102,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
v16s8, v16s16, v16s32, v16s64};
+ auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
+ v3s1, v3s8, v3s16, v3s32, v3s64,
+ v4s1, v4s8, v4s16, v4s32, v4s64};
+
+ auto allNonShaderVectors = {v8s1, v8s8, v8s16, v8s32, v8s64,
+ v16s1, v16s8, v16s16, v16s32, v16s64};
+
auto allScalarsAndVectors = {
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
@@ -148,15 +156,46 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
return IsExtendedInts && Ty.isValid();
};
- for (auto Opc : getTypeFoldingSupportedOpcodes())
- getActionDefinitionsBuilder(Opc).custom();
+ // TODO: So far we only legalize vectors for Shaders.
+ // We need to legalize for kernels as well. For Kernels
+ // vector sizes of 8 and 16 are allowed as well.
- getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
+ for (auto Opc : getTypeFoldingSupportedOpcodes()) {
+ if (Opc != G_EXTRACT_VECTOR_ELT)
+ getActionDefinitionsBuilder(Opc).custom();
+ }
- // TODO: add proper rules for vectors legalization.
- getActionDefinitionsBuilder(
- {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
- .alwaysLegal();
+ if (ST.canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450)) {
+ getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
+ .lowerIf(vectorElementCountIsGreaterThan(0, 4))
+ .lowerIf(vectorElementCountIsGreaterThan(1, 4))
+ .legalFor(allShaderVectors);
+ getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
+ .moreElementsToNextPow2(1)
+ .fewerElementsIf(vectorElementCountIsGreaterThan(1, 4),
+ LegalizeMutations::changeElementCountTo(
+ 1, ElementCount::getFixed(4)))
+ .custom();
+ getActionDefinitionsBuilder(G_BUILD_VECTOR)
+ .legalFor(allShaderVectors)
+ .fewerElementsIf(vectorElementCountIsGreaterThan(0, 4),
+ LegalizeMutations::changeElementCountTo(
+ 0, ElementCount::getFixed(4)));
+ getActionDefinitionsBuilder(G_BITCAST)
+ .moreElementsToNextPow2(0)
+ .moreElementsToNextPow2(1)
+ .fewerElementsIf(vectorElementCountIsGreaterThan(0, 4),
+ LegalizeMutations::changeElementCountTo(
+ 0, ElementCount::getFixed(4)))
+ .lowerIf(vectorElementCountIsGreaterThan(1, 4))
+ .alwaysLegal();
+ getActionDefinitionsBuilder(G_CONCAT_VECTORS)
+ .legalFor(allShaderVectors)
+ .lower();
+ } else
+ getActionDefinitionsBuilder(
+ {G_SHUFFLE_VECTOR, G_BUILD_VECTOR, G_SPLAT_VECTOR})
+ .alwaysLegal();
// Vector Reduction Operations
getActionDefinitionsBuilder(
@@ -287,6 +326,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
// Pointer-handling.
getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
+ getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);
+
// Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index d17528dd882bf..94d7417868750 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -17,7 +17,8 @@
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
-#include "llvm/IR/Attributes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
#include <stack>
#define DEBUG_TYPE "spirv-postlegalizer"
@@ -45,6 +46,10 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
static bool mayBeInserted(unsigned Opcode) {
switch (Opcode) {
+ case TargetOpcode::G_CONSTANT:
+ case TargetOpcode::G_UNMERGE_VALUES:
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT:
+ case TargetOpcode::G_BITCAST:
case TargetOpcode::G_SMAX:
case TargetOpcode::G_UMAX:
case TargetOpcode::G_SMIN:
@@ -53,70 +58,230 @@ static bool mayBeInserted(unsigned Opcode) {
case TargetOpcode::G_FMINIMUM:
case TargetOpcode::G_FMAXNUM:
case TargetOpcode::G_FMAXIMUM:
+ case TargetOpcode::G_IMPLICIT_DEF:
+ case TargetOpcode::G_BUILD_VECTOR:
return true;
default:
return isTypeFoldingSupported(Opcode);
}
}
-static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+static bool processInstr(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) {
MachineRegisterInfo &MRI = MF.getRegInfo();
+ const unsigned Opcode = I->getOpcode();
+ Register ResVReg = I->getOperand(0).getReg();
+ SPIRVType *ResType = nullptr;
+ bool Handled = false;
- for (MachineBasicBlock &MBB : MF) {
- for (MachineInstr &I : MBB) {
- const unsigned Opcode = I.getOpcode();
- if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
- unsigned ArgI = I.getNumOperands() - 1;
- Register SrcReg = I.getOperand(ArgI).isReg()
- ? I.getOperand(ArgI).getReg()
- : Register(0);
- SPIRVType *DefType =
- SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
- if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
- report_fatal_error(
- "cannot select G_UNMERGE_VALUES with a non-vector argument");
+ switch (Opcode) {
+ case TargetOpcode::G_CONSTANT: {
+ const LLT &Ty = MRI.getType(ResVReg);
+ unsigned BitWidth = Ty.getScalarSizeInBits();
+ ResType = GR->getOrCreateSPIRVIntegerType(BitWidth, MIB);
+ Handled = true;
+ break;
+ }
+ case TargetOpcode::G_UNMERGE_VALUES: {
+ Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
+ if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
+ if (DefType->getOpcode() == SPIRV::OpTypeVector) {
SPIRVType *ScalarType =
GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
- for (unsigned i = 0; i < I.getNumDefs(); ++i) {
- Register ResVReg = I.getOperand(i).getReg();
- SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
- if (!ResType) {
- // There was no "assign type" actions, let's fix this now
- ResType = ScalarType;
- setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
+ for (unsigned i = 0; i < I->getNumDefs(); ++i) {
+ Register DefReg = I->getOperand(i).getReg();
+ if (!GR->getSPIRVTypeForVReg(DefReg)) {
+ LLT DefLLT = MRI.getType(DefReg);
+ SPIRVType *ResType;
+ if (DefLLT.isVector()) {
+ const SPIRVInstrInfo *TII =
+ MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+ ResType = GR->getOrCreateSPIRVVectorType(
+ ScalarType, DefLLT.getNumElements(), *I, *TII);
+ } else {
+ ResType = ScalarType;
+ }
+ setRegClassType(DefReg, ResType, GR, &MRI, MF);
}
}
- } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
- I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
- // Legalizer may have added a new instructions and introduced new
- // registers, we must decorate them as if they were introduced in a
- // non-automatic way
- Register ResVReg = I.getOperand(0).getReg();
- // Check if the register defined by the instruction is newly generated
- // or already processed
- // Check if we have type defined for operands of the new instruction
- bool IsKnownReg = MRI.getRegClassOrNull(ResVReg);
- SPIRVType *ResVType = GR->getSPIRVTypeForVReg(
- IsKnownReg ? ResVReg : I.getOperand(1).getReg());
- if (!ResVType)
- continue;
- // Set type & class
- if (!IsKnownReg)
- setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
- // If this is a simple operation that is to be reduced by TableGen
- // definition we must apply some of pre-legalizer rules here
- if (isTypeFoldingSupported(Opcode)) {
- processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg));
- if (IsKnownReg && MRI.hasOneUse(ResVReg)) {
- MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg);
- if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
- continue;
+ Handled = true;
+ }
+ }
+ break;
+ }
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
+ LLVM_DEBUG(dbgs() << "Processing G_EXTRACT_VECTOR_ELT: " << *I);
+ Register VecReg = I->getOperand(1).getReg();
+ if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) {
+ LLVM_DEBUG(dbgs() << " Found vector type: " << *VecType << "\n");
+ if (VecType->getOpcode() != SPIRV::OpTypeVector) {
+ VecType->dump();
+ }
+ assert(VecType->getOpcode() == SPIRV::OpTypeVector);
+ ResType = GR->getScalarOrVectorComponentType(VecType);
+ Handled = true;
+ } else {
+ LLVM_DEBUG(dbgs() << " Vector operand " << VecReg
+ << " has no type. Looking at uses of " << ResVReg
+ << ".\n");
+ // If not handled yet, then check if it is used in a G_BUILD_VECTOR.
+ // If so get the type from there.
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ LLVM_DEBUG(dbgs() << " Use: " << Use);
+ if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
+ LLVM_DEBUG(dbgs() << " Use is G_BUILD_VECTOR.\n");
+ Register BuildVecResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *BuildVecType =
+ GR->getSPIRVTypeForVReg(BuildVecResReg)) {
+ LLVM_DEBUG(dbgs() << " Found G_BUILD_VECTOR result type: "
+ << *BuildVecType << "\n");
+ ResType = GR->getScalarOrVectorComponentType(BuildVecType);
+ Handled = true;
+ break;
+ } else {
+ LLVM_DEBUG(dbgs() << " G_BUILD_VECTOR result " << BuildVecResReg
+ << " has no type yet.\n");
}
- insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
}
}
}
+ if (!Handled) {
+ LLVM_DEBUG(
+ dbgs() << " Could not determine type for G_EXTRACT_VECTOR_ELT.\n");
+ }
+ break;
+ }
+ case TargetOpcode::G_BUILD_VECTOR: {
+ // First check if any of the operands have a type.
+ for (unsigned i = 1; i < I->getNumOperands(); ++i) {
+ if (SPIRVType *OpType =
+ GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) {
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ ResType = GR->getOrCreateSPIRVVectorType(
+ OpType, ResLLT.getNumElements(), MIB, false);
+ Handled = true;
+ break;
+ }
+ }
+ if (Handled) {
+ break;
+ }
+ // If that did not work, then check the uses.
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+ Register ExtractResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ ResType = GR->getOrCreateSPIRVVectorType(
+ ScalarType, ResLLT.getNumElements(), MIB, false);
+ Handled = true;
+ break;
+ }
+ }
+ }
+ break;
+ }
+ case TargetOpcode::G_IMPLICIT_DEF: {
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ const unsigned UseOpc = Use.getOpcode();
+ assert(UseOpc == TargetOpcode::G_BUILD_VECTOR ||
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ // It's possible that the use instruction has not been processed yet.
+ // We should look at the operands of the use to determine the type.
+ for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
+ if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg())) {
+ ResType = Type;
+ Handled = true;
+ break;
+ }
+ }
+ if (Handled) {
+ break;
+ }
+ }
+ break;
+ }
+ case TargetOpcode::G_BITCAST: {
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ const unsigned UseOpc = Use.getOpcode();
+ assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ Register UseResultReg = Use.getOperand(0).getReg();
+ if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
+ SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
+ const LLT &BitcastLLT = MRI.getType(ResVReg);
+ if (BitcastLLT.isVector()) {
+ ResType = GR->getOrCreateSPIRVVectorType(
+ ScalarType, BitcastLLT.getNumElements(), MIB, false);
+ } else {
+ ResType = ScalarType;
+ }
+ Handled = true;
+ break;
+ }
+ }
+ break;
+ }
+ default:
+ if (I->getNumDefs() == 1 && I->getNumOperands() > 1 &&
+ I->getOperand(1).isReg()) {
+ if (SPIRVType *OpType =
+ GR->getSPIRVTypeForVReg(I->getOperand(1).getReg())) {
+ ResType = OpType;
+ Handled = true;
+ }
+ }
+ break;
+ }
+
+ if (Handled && ResType) {
+ LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n");
+ GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF);
+ }
+ return Handled;
+}
+
+static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
+ SmallVector<MachineInstr *, 8> Worklist;
+
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &I : MBB) {
+ if (I.getNumDefs() > 0 &&
+ !GR->getSPIRVTypeForVReg(I.getOperand(0).getReg()) &&
+ mayBeInserted(I.getOpcode())) {
+ Worklist.push_back(&I);
+ }
+ }
+ }
+
+ if (Worklist.empty()) {
+ return;
+ }
+
+ LLVM_DEBUG(dbgs() << "Initial worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
+
+ bool Changed = true;
+ while (Changed) {
+ Changed = false;
+ SmallVector<MachineInstr *, 8> NextWorklist;
+
+ for (MachineInstr *I : Worklist) {
+ if (processInstr(I, MF, GR, MIB)) {
+ Changed = true;
+ } else {
+ NextWorklist.push_back(I);
+ }
+ }
+ Worklist = NextWorklist;
+ LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
+ }
+
+ if (!Worklist.empty()) {
+ LLVM_DEBUG(dbgs() << "Remaining worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
+ assert(Worklist.empty() && "Worklist is not empty");
}
}
@@ -159,6 +324,46 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
processNewInstrs(MF, GR, MIB);
+ // TODO: Move this into its own function.
+ SmallVector<MachineInstr *, 8> ExtractInstrs;
+ SmallVector<MachineInstr *, 8> BitcastInstrs;
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &MI : MBB) {
+ if (MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+ ExtractInstrs.push_back(&MI);
+ } else if (MI.getOpcode() == TargetOpcode::G_BITCAST) {
+ BitcastInstrs.push_back(&MI);
+ }
+ }
+ }
+
+ for (MachineInstr *MI : ExtractInstrs) {
+ MachineIRBuilder MIB(*MI);
+ Register Dst = MI->getOperand(0).getReg();
+ Register Vec = MI->getOperand(1).getReg();
+ Register Idx = MI->getOperand(2).getReg();
+
+ auto Intr =
+ MIB.buildIntrinsic(Intrinsic::spv_extractelt, Dst, false, false);
+ Intr.addUse(Vec);
+ Intr.addUse(Idx);
+
+ MI->eraseFromParent();
+ }
+
+ for (MachineInstr *MI : BitcastInstrs) {
+ MachineIRBuilder MIB(*MI);
+ Register Dst = MI->getOperand(0).getReg();
+ Register Src = MI->getOperand(1).getReg();
+ SPIRVType *DstType = GR->getSPIRVTypeForVReg(Dst);
+ assert(DstType && "Destination of G_BITCAST must have a type");
+ MIB.buildInstr(SPIRV::OpBitcast)
+ .addDef(Dst)
+ .addUse(GR->getSPIRVTypeID(DstType))
+ .addUse(Src);
+ MI->eraseFromParent();
+ }
+
return true;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index db6f2d61e8f29..b4b9771b0b50a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -192,6 +192,10 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
.addUse(OpReg);
}
+// TODO: See if the comment needs to be more precise. This is a problem for more
+// than just pointers. A bitcast between any two types that map to the same LLT
+// will cause a problem. For example, a bitcast from a float to an int.
+
// We do instruction selections early instead of calling MIB.buildBitcast()
// generating the general op code G_BITCAST. When MachineVerifier validates
// G_BITCAST we see a check of a kind: if Source Type is equal to Destination
@@ -212,15 +216,29 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
// in https://github.com/llvm/llvm-project/pull/110270 for even more context.
static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineIRBuilder MIB) {
+ // TODO: This change will still cause failures for the machine verifier.
+ // This is done so that the G_BITCAST legalization rules can be used.
+ // We will have to add rules to legalize the intrinsic. Note that we cannot
+ // lower only the bitcasts the verifier will complain about. You could cast
+ // a long vector of floats to a long vector of ints. That has to be legalized
+ // but is also an invalid G_BITCAST.
+ //
+ // Could we use G_COPY for cases where the LLTs are the same? Then lowering
+ // the G_COPY could be either an OpBitcast or OpCopyObject depending on the
+ // source and result types.
+
+ MachineRegisterInfo &MRI = MF.getRegInfo();
SmallVector<MachineInstr *, 16> ToErase;
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
if (MI.getOpcode() != TargetOpcode::G_BITCAST)
continue;
- MIB.setInsertPt(*MI.getParent(), MI);
- buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(),
- MI.getOperand(1).getReg());
- ToErase.push_back(&MI);
+ Register DstReg = MI.getOperand(0).getReg();
+ if (MRI.getType(DstReg).isPointer()) {
+ MIB.setInsertPt(*MI.getParent(), MI);
+ buildOpBitcast(GR, MIB, DstReg, MI.getOperand(1).getReg());
+ ToErase.push_back(&MI);
+ }
}
}
for (MachineInstr *MI : ToErase) {
More information about the llvm-commits
mailing list