[llvm] legalize long vectors (PR #165444)
Steven Perron via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 28 10:29:23 PDT 2025
https://github.com/s-perron created https://github.com/llvm/llvm-project/pull/165444
- **[SPIRV] Set hasSideEffects flag to false on type and constant opcodes**
- **[SPIRV] Expand spv_bitcast intrinsic during instruction selection**
- **Remove unnecessary pointer checks**
- **[SPIRV] Fix vector bitcast check in LegalizePointerCast**
- **[SPIRV] Use a worklist in the post-legalizer**
- **[SPIRV] Use a worklist in the post-legalizer**
- **Set insertion point in MIB.**
- **[SPIRV] legalize long vectors.**
- **Add tests**
>From aa1e9a0d2c0fdef0c0e8ddcedc115dba06bddb6e Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 7 Oct 2025 13:26:47 -0400
Subject: [PATCH 1/9] [SPIRV] Set hasSideEffects flag to false on type and
constant opcodes
This change sets the hasSideEffects flag to false on type and constant opcodes
so that they can be considered trivially dead if their result is unused. This
means that instruction selection will now be able to remove them.
---
llvm/lib/Target/SPIRV/SPIRVInstrFormats.td | 5 +
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 179 +++++++++++-------
.../SPIRV/hlsl-intrinsics/AddUint64.ll | 2 +-
.../pointers/resource-vector-load-store.ll | 27 +--
4 files changed, 130 insertions(+), 83 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
index 2fde2b0bc0b1f..f93240dc35993 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
@@ -25,6 +25,11 @@ class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern =
let Pattern = pattern;
}
+class PureOp<bits<16> Opcode, dag outs, dag ins, string asmstr,
+ list<dag> pattern = []> : Op<Opcode, outs, ins, asmstr, pattern> {
+ let hasSideEffects = 0;
+}
+
class UnknownOp<dag outs, dag ins, string asmstr, list<dag> pattern = []>
: Op<0, outs, ins, asmstr, pattern> {
let isPseudo = 1;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a61351eba03f8..799a82c96b0f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -163,52 +163,74 @@ def OpExecutionModeId: Op<331, (outs), (ins ID:$entry, ExecutionMode:$mode, vari
// 3.42.6 Type-Declaration Instructions
-def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
-def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
-def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
- "$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
- "$type = OpTypeFloat $width">;
-def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
- "$type = OpTypeVector $compType $compCount">;
-def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
- "$type = OpTypeMatrix $colType $colCount">;
-def OpTypeImage: Op<25, (outs TYPE:$res), (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
- i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, ImageFormat:$imFormat, variable_ops),
- "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS $sampled $imFormat">;
-def OpTypeSampler: Op<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
-def OpTypeSampledImage: Op<27, (outs TYPE:$res), (ins TYPE:$imageType),
- "$res = OpTypeSampledImage $imageType">;
-def OpTypeArray: Op<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
- "$type = OpTypeArray $elementType $length">;
-def OpTypeRuntimeArray: Op<29, (outs TYPE:$type), (ins TYPE:$elementType),
- "$type = OpTypeRuntimeArray $elementType">;
-def OpTypeStruct: Op<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
-def OpTypeStructContinuedINTEL: Op<6090, (outs), (ins variable_ops),
- "OpTypeStructContinuedINTEL">;
-def OpTypeOpaque: Op<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
- "$res = OpTypeOpaque $name">;
-def OpTypePointer: Op<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
- "$res = OpTypePointer $storage $type">;
-def OpTypeFunction: Op<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
- "$funcType = OpTypeFunction $returnType">;
-def OpTypeEvent: Op<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
-def OpTypeDeviceEvent: Op<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
-def OpTypeReserveId: Op<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
-def OpTypeQueue: Op<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
-def OpTypePipe: Op<38, (outs TYPE:$res), (ins AccessQualifier:$a), "$res = OpTypePipe $a">;
-def OpTypeForwardPointer: Op<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
- "OpTypeForwardPointer $ptrType $storageClass">;
-def OpTypePipeStorage: Op<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
-def OpTypeNamedBarrier: Op<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
-def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins),
- "$res = OpTypeAccelerationStructureNV">;
-def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res),
- (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
- "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
-def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res),
- (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
- "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">;
+def OpTypeVoid : PureOp<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
+def OpTypeBool : PureOp<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
+def OpTypeInt
+ : PureOp<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
+ "$type = OpTypeInt $width $signedness">;
+def OpTypeFloat
+ : PureOp<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
+ "$type = OpTypeFloat $width">;
+def OpTypeVector
+ : PureOp<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
+ "$type = OpTypeVector $compType $compCount">;
+def OpTypeMatrix
+ : PureOp<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
+ "$type = OpTypeMatrix $colType $colCount">;
+def OpTypeImage : PureOp<25, (outs TYPE:$res),
+ (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
+ i32imm:$arrayed, i32imm:$MS, i32imm:$sampled,
+ ImageFormat:$imFormat, variable_ops),
+ "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS "
+ "$sampled $imFormat">;
+def OpTypeSampler : PureOp<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
+def OpTypeSampledImage : PureOp<27, (outs TYPE:$res), (ins TYPE:$imageType),
+ "$res = OpTypeSampledImage $imageType">;
+def OpTypeArray
+ : PureOp<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
+ "$type = OpTypeArray $elementType $length">;
+def OpTypeRuntimeArray : PureOp<29, (outs TYPE:$type), (ins TYPE:$elementType),
+ "$type = OpTypeRuntimeArray $elementType">;
+def OpTypeStruct
+ : PureOp<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
+def OpTypeStructContinuedINTEL
+ : PureOp<6090, (outs), (ins variable_ops), "OpTypeStructContinuedINTEL">;
+def OpTypeOpaque
+ : PureOp<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
+ "$res = OpTypeOpaque $name">;
+def OpTypePointer
+ : PureOp<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
+ "$res = OpTypePointer $storage $type">;
+def OpTypeFunction
+ : PureOp<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
+ "$funcType = OpTypeFunction $returnType">;
+def OpTypeEvent : PureOp<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
+def OpTypeDeviceEvent
+ : PureOp<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
+def OpTypeReserveId
+ : PureOp<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
+def OpTypeQueue : PureOp<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
+def OpTypePipe : PureOp<38, (outs TYPE:$res), (ins AccessQualifier:$a),
+ "$res = OpTypePipe $a">;
+def OpTypeForwardPointer
+ : PureOp<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
+ "OpTypeForwardPointer $ptrType $storageClass">;
+def OpTypePipeStorage
+ : PureOp<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
+def OpTypeNamedBarrier
+ : PureOp<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
+def OpTypeAccelerationStructureNV
+ : PureOp<5341, (outs TYPE:$res), (ins),
+ "$res = OpTypeAccelerationStructureNV">;
+def OpTypeCooperativeMatrixNV
+ : PureOp<5358, (outs TYPE:$res),
+ (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
+ "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
+def OpTypeCooperativeMatrixKHR
+ : PureOp<4456, (outs TYPE:$res),
+ (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
+ "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols "
+ "$use">;
// 3.42.7 Constant-Creation Instructions
@@ -222,31 +244,46 @@ defm OpConstant: IntFPImm<43, "OpConstant">;
def ConstPseudoTrue: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 1; }]>;
def ConstPseudoFalse: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 0; }]>;
-def OpConstantTrue: Op<41, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantTrue $src_ty",
- [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
-def OpConstantFalse: Op<42, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantFalse $src_ty",
- [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;
-
-def OpConstantComposite: Op<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
- "$res = OpConstantComposite $type">;
-def OpConstantCompositeContinuedINTEL: Op<6091, (outs), (ins variable_ops),
- "OpConstantCompositeContinuedINTEL">;
-
-def OpConstantSampler: Op<45, (outs ID:$res),
- (ins TYPE:$t, SamplerAddressingMode:$s, i32imm:$p, SamplerFilterMode:$f),
- "$res = OpConstantSampler $t $s $p $f">;
-def OpConstantNull: Op<46, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantNull $src_ty">;
-
-def OpSpecConstantTrue: Op<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
-def OpSpecConstantFalse: Op<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
-def OpSpecConstant: Op<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
- "$res = OpSpecConstant $type $imm">;
-def OpSpecConstantComposite: Op<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
- "$res = OpSpecConstantComposite $type">;
-def OpSpecConstantCompositeContinuedINTEL: Op<6092, (outs), (ins variable_ops),
- "OpSpecConstantCompositeContinuedINTEL">;
-def OpSpecConstantOp: Op<52, (outs ID:$res), (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
- "$res = OpSpecConstantOp $t $c $o">;
+def OpConstantTrue
+ : PureOp<41, (outs iID:$dst), (ins TYPE:$src_ty),
+ "$dst = OpConstantTrue $src_ty",
+ [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
+def OpConstantFalse
+ : PureOp<42, (outs iID:$dst), (ins TYPE:$src_ty),
+ "$dst = OpConstantFalse $src_ty",
+ [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;
+
+def OpConstantComposite
+ : PureOp<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
+ "$res = OpConstantComposite $type">;
+def OpConstantCompositeContinuedINTEL
+ : PureOp<6091, (outs), (ins variable_ops),
+ "OpConstantCompositeContinuedINTEL">;
+
+def OpConstantSampler : PureOp<45, (outs ID:$res),
+ (ins TYPE:$t, SamplerAddressingMode:$s,
+ i32imm:$p, SamplerFilterMode:$f),
+ "$res = OpConstantSampler $t $s $p $f">;
+def OpConstantNull : PureOp<46, (outs ID:$dst), (ins TYPE:$src_ty),
+ "$dst = OpConstantNull $src_ty">;
+
+def OpSpecConstantTrue
+ : PureOp<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
+def OpSpecConstantFalse
+ : PureOp<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
+def OpSpecConstant
+ : PureOp<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
+ "$res = OpSpecConstant $type $imm">;
+def OpSpecConstantComposite
+ : PureOp<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
+ "$res = OpSpecConstantComposite $type">;
+def OpSpecConstantCompositeContinuedINTEL
+ : PureOp<6092, (outs), (ins variable_ops),
+ "OpSpecConstantCompositeContinuedINTEL">;
+def OpSpecConstantOp
+ : PureOp<52, (outs ID:$res),
+ (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
+ "$res = OpSpecConstantOp $t $c $o">;
// 3.42.8 Memory Instructions
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll
index a97492b8453ea..a15d628cc3614 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll
@@ -63,7 +63,7 @@ entry:
; CHECK: %[[#a_high:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#a]] %[[#undef_v4i32]] 1 3
; CHECK: %[[#b_low:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#b]] %[[#undef_v4i32]] 0 2
; CHECK: %[[#b_high:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#b]] %[[#undef_v4i32]] 1 3
-; CHECK: %[[#iaddcarry:]] = OpIAddCarry %[[#struct_v2i32_v2i32]] %[[#a_low]] %[[#vec2_int_32]]
+; CHECK: %[[#iaddcarry:]] = OpIAddCarry %[[#struct_v2i32_v2i32]] %[[#a_low]] %[[#b_low]]
; CHECK: %[[#lowsum:]] = OpCompositeExtract %[[#vec2_int_32]] %[[#iaddcarry]] 0
; CHECK: %[[#carry:]] = OpCompositeExtract %[[#vec2_int_32]] %[[#iaddcarry]] 1
; CHECK: %[[#carry_ne0:]] = OpINotEqual %[[#vec2_bool]] %[[#carry]] %[[#const_v2i32_0_0]]
diff --git a/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll b/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll
index 7548f4757dbe6..6fc03a386d14d 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll
@@ -4,18 +4,23 @@
@.str = private unnamed_addr constant [7 x i8] c"buffer\00", align 1
+; The i64 values in the extracts will be turned
+; into immediate values. There should be no 64-bit
+; integers in the module.
+; CHECK-NOT: OpTypeInt 64 0
+
define void @main() "hlsl.shader"="pixel" {
-; CHECK: %24 = OpFunction %2 None %3 ; -- Begin function main
-; CHECK-NEXT: %1 = OpLabel
-; CHECK-NEXT: %25 = OpVariable %13 Function %22
-; CHECK-NEXT: %26 = OpLoad %7 %23
-; CHECK-NEXT: %27 = OpImageRead %5 %26 %15
-; CHECK-NEXT: %28 = OpCompositeExtract %4 %27 0
-; CHECK-NEXT: %29 = OpCompositeExtract %4 %27 1
-; CHECK-NEXT: %30 = OpFAdd %4 %29 %28
-; CHECK-NEXT: %31 = OpCompositeInsert %5 %30 %27 0
-; CHECK-NEXT: %32 = OpLoad %7 %23
-; CHECK-NEXT: OpImageWrite %32 %15 %31
+; CHECK: %[[FUNC:[0-9]+]] = OpFunction %[[VOID:[0-9]+]] None %[[FNTYPE:[0-9]+]] ; -- Begin function main
+; CHECK-NEXT: %[[LABEL:[0-9]+]] = OpLabel
+; CHECK-NEXT: %[[VAR:[0-9]+]] = OpVariable %[[PTR_FN:[a-zA-Z0-9_]+]] Function %[[INIT:[a-zA-Z0-9_]+]]
+; CHECK-NEXT: %[[LOAD1:[0-9]+]] = OpLoad %[[IMG_TYPE:[a-zA-Z0-9_]+]] %[[IMG_VAR:[a-zA-Z0-9_]+]]
+; CHECK-NEXT: %[[READ:[0-9]+]] = OpImageRead %[[VEC4:[a-zA-Z0-9_]+]] %[[LOAD1]] %[[COORD:[a-zA-Z0-9_]+]]
+; CHECK-NEXT: %[[EXTRACT1:[0-9]+]] = OpCompositeExtract %[[FLOAT:[a-zA-Z0-9_]+]] %[[READ]] 0
+; CHECK-NEXT: %[[EXTRACT2:[0-9]+]] = OpCompositeExtract %[[FLOAT]] %[[READ]] 1
+; CHECK-NEXT: %[[ADD:[0-9]+]] = OpFAdd %[[FLOAT]] %[[EXTRACT2]] %[[EXTRACT1]]
+; CHECK-NEXT: %[[INSERT:[0-9]+]] = OpCompositeInsert %[[VEC4]] %[[ADD]] %[[READ]] 0
+; CHECK-NEXT: %[[LOAD2:[0-9]+]] = OpLoad %[[IMG_TYPE]] %[[IMG_VAR]]
+; CHECK-NEXT: OpImageWrite %[[LOAD2]] %[[COORD]] %[[INSERT]]
; CHECK-NEXT: OpReturn
; CHECK-NEXT: OpFunctionEnd
entry:
>From dd1d522bee0706efe72638dbfcbcb08b397c534c Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 7 Oct 2025 13:26:47 -0400
Subject: [PATCH 2/9] [SPIRV] Expand spv_bitcast intrinsic during instruction
selection
The spv_bitcast intrinsic is currently replaced by an OpBitcast
during prelegalization. This will cause a problem when we need to
legalize the OpBitcast, because the legalizer assumes that an
instruction already lowered to a target-specific opcode is legal.
We cannot lower it to a G_BITCAST because, for some bitcasts, the
source and destination LLTs are the same, which causes a verifier
error even though the SPIR-V types are different.
This commit keeps the intrinsic around until instruction selection.
We can create rules to legalize a G_INTRINSIC* instruction, and it
does not cause problems for the verifier.
---
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 8 +++
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 56 ++++++++++---------
2 files changed, 37 insertions(+), 27 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 021353ab716f7..ccc2c0fc467fb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3119,6 +3119,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectInsertElt(ResVReg, ResType, I);
case Intrinsic::spv_gep:
return selectGEP(ResVReg, ResType, I);
+ case Intrinsic::spv_bitcast: {
+ Register OpReg = I.getOperand(2).getReg();
+ SPIRVType *OpType =
+ OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr;
+ if (!GR.isBitcastCompatible(ResType, OpType))
+ report_fatal_error("incompatible result and operand types in a bitcast");
+ return selectOpWithSrcs(ResVReg, ResType, I, {OpReg}, SPIRV::OpBitcast);
+ }
case Intrinsic::spv_unref_global:
case Intrinsic::spv_init_global: {
MachineInstr *MI = MRI->getVRegDef(I.getOperand(1).getReg());
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index db6f2d61e8f29..43ded6a71dd6d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -192,31 +192,38 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
.addUse(OpReg);
}
-// We do instruction selections early instead of calling MIB.buildBitcast()
-// generating the general op code G_BITCAST. When MachineVerifier validates
-// G_BITCAST we see a check of a kind: if Source Type is equal to Destination
-// Type then report error "bitcast must change the type". This doesn't take into
-// account the notion of a typed pointer that is important for SPIR-V where a
-// user may and should use bitcast between pointers with different pointee types
-// (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast).
-// It's important for correct lowering in SPIR-V, because interpretation of the
-// data type is not left to instructions that utilize the pointer, but encoded
-// by the pointer declaration, and the SPIRV target can and must handle the
-// declaration and use of pointers that specify the type of data they point to.
-// It's not feasible to improve validation of G_BITCAST using just information
-// provided by low level types of source and destination. Therefore we don't
-// produce G_BITCAST as the general op code with semantics different from
-// OpBitcast, but rather lower to OpBitcast immediately. As for now, the only
-// difference would be that CombinerHelper couldn't transform known patterns
-// around G_BUILD_VECTOR. See discussion
-// in https://github.com/llvm/llvm-project/pull/110270 for even more context.
-static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+// We lower G_BITCAST to OpBitcast here to avoid a MachineVerifier error.
+// The verifier checks if the source and destination LLTs of a G_BITCAST are
+// different, but this check is too strict for SPIR-V's typed pointers, which
+// may have the same LLT but different SPIRVType (e.g. pointers to different
+// pointee types). By lowering to OpBitcast here, we bypass the verifier's
+// check. See discussion in https://github.com/llvm/llvm-project/pull/110270
+// for more context.
+//
+// We also handle the llvm.spv.bitcast intrinsic here. If the source and
+// destination SPIR-V types are the same, we lower it to a COPY to enable
+// further optimizations like copy propagation.
+static void lowerBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
SmallVector<MachineInstr *, 16> ToErase;
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
+ if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
+ Register DstReg = MI.getOperand(0).getReg();
+ Register SrcReg = MI.getOperand(2).getReg();
+ SPIRVType *DstType = GR->getSPIRVTypeForVReg(DstReg);
+ SPIRVType *SrcType = GR->getSPIRVTypeForVReg(SrcReg);
+ if (DstType && SrcType && DstType == SrcType) {
+ MIB.setInsertPt(*MI.getParent(), MI);
+ MIB.buildCopy(DstReg, SrcReg);
+ ToErase.push_back(&MI);
+ continue;
+ }
+ }
+
if (MI.getOpcode() != TargetOpcode::G_BITCAST)
continue;
+
MIB.setInsertPt(*MI.getParent(), MI);
buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(),
MI.getOperand(1).getReg());
@@ -237,16 +244,11 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
SmallVector<MachineInstr *, 10> ToErase;
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
- if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) &&
- !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
+ if (!isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
continue;
assert(MI.getOperand(2).isReg());
MIB.setInsertPt(*MI.getParent(), MI);
ToErase.push_back(&MI);
- if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
- MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
- continue;
- }
Register Def = MI.getOperand(0).getReg();
Register Source = MI.getOperand(2).getReg();
Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0);
@@ -1089,7 +1091,7 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
removeImplicitFallthroughs(MF, MIB);
insertSpirvDecorations(MF, GR, MIB);
insertInlineAsm(MF, GR, ST, MIB);
- selectOpBitcasts(MF, GR, MIB);
+ lowerBitcasts(MF, GR, MIB);
return true;
}
>From cd5a1d26985f2a173a490f6445314fabd54f0fa1 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Fri, 24 Oct 2025 09:53:11 -0400
Subject: [PATCH 3/9] Remove unnecessary pointer checks
---
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 43ded6a71dd6d..d538009f0ecbe 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -212,8 +212,13 @@ static void lowerBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Register DstReg = MI.getOperand(0).getReg();
Register SrcReg = MI.getOperand(2).getReg();
SPIRVType *DstType = GR->getSPIRVTypeForVReg(DstReg);
+ assert(
+ DstType &&
+ "Expected destination SPIR-V type to have been assigned already.");
SPIRVType *SrcType = GR->getSPIRVTypeForVReg(SrcReg);
- if (DstType && SrcType && DstType == SrcType) {
+ assert(SrcType &&
+ "Expected source SPIR-V type to have been assigned already.");
+ if (DstType == SrcType) {
MIB.setInsertPt(*MI.getParent(), MI);
MIB.buildCopy(DstReg, SrcReg);
ToErase.push_back(&MI);
>From 9d342d353cc48ec8b73e84686abdf1be15c90a50 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Fri, 24 Oct 2025 10:03:31 -0400
Subject: [PATCH 4/9] [SPIRV] Fix vector bitcast check in LegalizePointerCast
The previous check for vector bitcasts in `loadVectorFromVector` only
compared the number of elements, which is insufficient when the element
types differ. This can lead to incorrect assumptions about the validity
of the cast.
This commit replaces the element count check with a comparison of the
total size of the vectors in bits. This ensures that the bitcast is
only performed between vectors of the same size, preventing potential
miscompilations.
---
.../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 9 +++++++-
.../hlsl-resources/issue-146942-ptr-cast.ll | 4 +---
.../CodeGen/SPIRV/pointers/ptrcast-bitcast.ll | 22 +++++++++++++++++++
3 files changed, 31 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 28a1690ef0be1..a692c24363310 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass {
// Returns the loaded value.
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
FixedVectorType *TargetType, Value *Source) {
- assert(TargetType->getNumElements() <= SourceType->getNumElements());
LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
buildAssignType(B, SourceType, NewLoad);
Value *AssignValue = NewLoad;
if (TargetType->getElementType() != SourceType->getElementType()) {
+ const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+ [[maybe_unused]] TypeSize TargetTypeSize =
+ DL.getTypeSizeInBits(TargetType);
+ [[maybe_unused]] TypeSize SourceTypeSize =
+ DL.getTypeSizeInBits(SourceType);
+ assert(TargetTypeSize == SourceTypeSize);
AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
{TargetType, SourceType}, {NewLoad});
buildAssignType(B, TargetType, AssignValue);
+ return AssignValue;
}
+ assert(TargetType->getNumElements() < SourceType->getNumElements());
SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
Mask[I] = I;
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
index ed67344842b11..4817e7450ac2e 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll
@@ -16,7 +16,6 @@
define void @case1() local_unnamed_addr {
; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
- ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
%1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str)
%2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.2)
%3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
@@ -29,8 +28,7 @@ define void @case1() local_unnamed_addr {
define void @case2() local_unnamed_addr {
; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
- ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
- ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#VEC_SHUFFLE]] %[[#UNDEF_INT4]] 0 1 2
+ ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#CAST_LOAD]] %[[#UNDEF_INT4]] 0 1 2
%1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str)
%2 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.3)
%3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
index 84913283f6868..a1ec2cd1cfdd2 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -26,3 +26,25 @@ entry:
store <4 x i32> %6, ptr addrspace(11) %7, align 16
ret void
}
+
+; This tests a load from a pointer that has been bitcast between vector types
+; which share the same total bit-width but have different numbers of elements.
+; Tests that legalize-pointer-casts works correctly by moving the bitcast to
+; the element that was loaded.
+
+define void @main2() local_unnamed_addr #0 {
+entry:
+; CHECK: %[[LOAD:[0-9]+]] = OpLoad %[[#v2_double]] {{.*}}
+; CHECK: %[[BITCAST1:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD]]
+; CHECK: %[[BITCAST2:[0-9]+]] = OpBitcast %[[#v2_double]] %[[BITCAST1]]
+; CHECK: OpStore {{%[0-9]+}} %[[BITCAST2]] {{.*}}
+
+ %0 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2)
+ %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 0)
+ %3 = load <4 x i32>, ptr addrspace(11) %2
+ %4 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 1)
+ store <4 x i32> %3, ptr addrspace(11) %4
+ ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
>From c16bc1ef99215d81e79e76772f37f08369b04602 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Fri, 24 Oct 2025 13:21:24 -0400
Subject: [PATCH 5/9] [SPIRV] Use a worklist in the post-legalizer
This commit refactors the SPIRV post-legalizer to use a worklist to process
new instructions. Previously, the post-legalizer would iterate through all
instructions and try to assign types. This could fail if a new instruction
depended on another new instruction that had not been processed yet.
The new implementation adds all new instructions that require a SPIR-V type
to a worklist. It then iteratively processes the worklist until it is empty.
This ensures that all dependencies are met before an instruction is
processed.
This change makes the post-legalizer more robust and fixes potential ordering
issues with newly generated instructions.
Existing tests cover existing functionality. More tests will be added as
the legalizer is modified.
Part of #153091
---
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 388 ++++++++++++++++---
1 file changed, 334 insertions(+), 54 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index d17528dd882bf..d11168b70aea8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -17,7 +17,8 @@
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
-#include "llvm/IR/Attributes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
#include <stack>
#define DEBUG_TYPE "spirv-postlegalizer"
@@ -45,6 +46,11 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
static bool mayBeInserted(unsigned Opcode) {
switch (Opcode) {
+ case TargetOpcode::G_CONSTANT:
+ case TargetOpcode::G_UNMERGE_VALUES:
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT:
+ case TargetOpcode::G_INTRINSIC:
+ case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
case TargetOpcode::G_SMAX:
case TargetOpcode::G_UMAX:
case TargetOpcode::G_SMIN:
@@ -53,69 +59,344 @@ static bool mayBeInserted(unsigned Opcode) {
case TargetOpcode::G_FMINIMUM:
case TargetOpcode::G_FMAXNUM:
case TargetOpcode::G_FMAXIMUM:
+ case TargetOpcode::G_IMPLICIT_DEF:
+ case TargetOpcode::G_BUILD_VECTOR:
+ case TargetOpcode::G_ICMP:
+ case TargetOpcode::G_ANYEXT:
return true;
default:
return isTypeFoldingSupported(Opcode);
}
}
-static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+// Deduce the SPIR-V type of a G_CONSTANT result: an integer type whose bit
+// width is the scalar size in bits of the result register's LLT.
+static SPIRVType *deduceTypeForGConstant(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
MachineRegisterInfo &MRI = MF.getRegInfo();
+ const LLT &Ty = MRI.getType(ResVReg);
+ unsigned BitWidth = Ty.getScalarSizeInBits();
+ return GR->getOrCreateSPIRVIntegerType(BitWidth, MIB);
+}
- for (MachineBasicBlock &MBB : MF) {
- for (MachineInstr &I : MBB) {
- const unsigned Opcode = I.getOpcode();
- if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
- unsigned ArgI = I.getNumOperands() - 1;
- Register SrcReg = I.getOperand(ArgI).isReg()
- ? I.getOperand(ArgI).getReg()
- : Register(0);
- SPIRVType *DefType =
- SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
- if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
- report_fatal_error(
- "cannot select G_UNMERGE_VALUES with a non-vector argument");
- SPIRVType *ScalarType =
- GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
- for (unsigned i = 0; i < I.getNumDefs(); ++i) {
- Register ResVReg = I.getOperand(i).getReg();
- SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
- if (!ResType) {
- // There was no "assign type" actions, let's fix this now
+// Assign SPIR-V types to the not-yet-typed results of a G_UNMERGE_VALUES
+// whose source (the last operand) already has an OpTypeVector type. Scalar
+// results get the vector's component type; vector results get a vector of
+// that component type with the result LLT's element count. Returns true once
+// this is done, false (retry later) if the source is untyped or not a vector.
+static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
+ if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
+ if (DefType->getOpcode() == SPIRV::OpTypeVector) {
+ // Operand 1 of the vector type instruction is used as its component type.
+ SPIRVType *ScalarType =
+ GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+ for (unsigned i = 0; i < I->getNumDefs(); ++i) {
+ Register DefReg = I->getOperand(i).getReg();
+ if (!GR->getSPIRVTypeForVReg(DefReg)) {
+ LLT DefLLT = MRI.getType(DefReg);
+ SPIRVType *ResType;
+ // Unmerging into sub-vectors yields vector-typed results.
+ if (DefLLT.isVector()) {
+ const SPIRVInstrInfo *TII =
+ MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+ ResType = GR->getOrCreateSPIRVVectorType(
+ ScalarType, DefLLT.getNumElements(), *I, *TII);
+ } else {
ResType = ScalarType;
- setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
}
+ setRegClassType(DefReg, ResType, GR, &MRI, MF);
}
- } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
- I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
- // Legalizer may have added a new instructions and introduced new
- // registers, we must decorate them as if they were introduced in a
- // non-automatic way
- Register ResVReg = I.getOperand(0).getReg();
- // Check if the register defined by the instruction is newly generated
- // or already processed
- // Check if we have type defined for operands of the new instruction
- bool IsKnownReg = MRI.getRegClassOrNull(ResVReg);
- SPIRVType *ResVType = GR->getSPIRVTypeForVReg(
- IsKnownReg ? ResVReg : I.getOperand(1).getReg());
- if (!ResVType)
- continue;
- // Set type & class
- if (!IsKnownReg)
- setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
- // If this is a simple operation that is to be reduced by TableGen
- // definition we must apply some of pre-legalizer rules here
- if (isTypeFoldingSupported(Opcode)) {
- processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg));
- if (IsKnownReg && MRI.hasOneUse(ResVReg)) {
- MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg);
- if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
- continue;
- }
- insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+// Deduce the SPIR-V type of a G_EXTRACT_VECTOR_ELT result. Preferred source:
+// the component type of the vector operand (operand 1) once it is typed.
+// Fallback: the component type of a typed G_BUILD_VECTOR that uses the
+// extracted value. Returns nullptr when neither is available yet.
+static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ Register VecReg = I->getOperand(1).getReg();
+ if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) {
+ assert(VecType->getOpcode() == SPIRV::OpTypeVector);
+ return GR->getScalarOrVectorComponentType(VecType);
+ }
+
+ // If not handled yet, then check if it is used in a G_BUILD_VECTOR.
+ // If so get the type from there.
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
+ Register BuildVecResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *BuildVecType = GR->getSPIRVTypeForVReg(BuildVecResReg))
+ return GR->getScalarOrVectorComponentType(BuildVecType);
+ }
+ }
+ return nullptr;
+}
+
+// Deduce the SPIR-V vector type of a G_BUILD_VECTOR result. The component
+// type comes from the first source operand that already has a SPIR-V type or,
+// failing that, from the typed result of a G_EXTRACT_VECTOR_ELT user; the
+// element count comes from the result register's LLT. Returns nullptr when no
+// component type is known yet.
+static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ // First check if any of the operands have a type.
+ for (unsigned i = 1; i < I->getNumOperands(); ++i) {
+ if (SPIRVType *OpType =
+ GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) {
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(),
+ MIB, false);
+ }
+ }
+ // If that did not work, then check the uses.
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+ Register ExtractResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ return GR->getOrCreateSPIRVVectorType(
+ ScalarType, ResLLT.getNumElements(), MIB, false);
+ }
+ }
+ }
+ return nullptr;
+}
+
+// Deduce the SPIR-V type of a G_IMPLICIT_DEF result from its users, which are
+// expected (asserted) to be G_BUILD_VECTOR or G_SHUFFLE_VECTOR. The type of
+// the first register source operand of a user that already has a SPIR-V type
+// is returned: a G_BUILD_VECTOR element or a G_SHUFFLE_VECTOR source vector,
+// either of which has the same type as the undefined value. Returns nullptr
+// when no such operand is typed yet.
+static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ // Only read by the assert; [[maybe_unused]] keeps NDEBUG builds quiet.
+ [[maybe_unused]] const unsigned UseOpc = Use.getOpcode();
+ assert(UseOpc == TargetOpcode::G_BUILD_VECTOR ||
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ // It's possible that the use instruction has not been processed yet.
+ // We should look at the operands of the use to determine the type.
+ for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
+ const MachineOperand &SrcOp = Use.getOperand(i);
+ // The mask of a G_SHUFFLE_VECTOR is not a register operand; calling
+ // getReg() on it would assert, so skip non-register operands.
+ if (!SrcOp.isReg())
+ continue;
+ if (auto *Type = GR->getSPIRVTypeForVReg(SrcOp.getReg()))
+ return Type;
+ }
+ }
+ return nullptr;
+}
+
+// Deduce the SPIR-V type of an spv_bitcast intrinsic result from its users;
+// any other intrinsic yields nullptr. The component type is taken from the
+// typed result of a user: a vector result LLT yields a vector of that
+// component with the LLT's element count, otherwise the component type itself
+// is returned. Returns nullptr when no user is typed yet.
+static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ if (!isSpvIntrinsic(*I, Intrinsic::spv_bitcast))
+ return nullptr;
+
+ // NOTE(review): UseOpc below is read only by the assert, so NDEBUG builds
+ // may warn about an unused variable — consider [[maybe_unused]].
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ const unsigned UseOpc = Use.getOpcode();
+ assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ Register UseResultReg = Use.getOperand(0).getReg();
+ if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
+ SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
+ const LLT &BitcastLLT = MRI.getType(ResVReg);
+ if (BitcastLLT.isVector())
+ return GR->getOrCreateSPIRVVectorType(
+ ScalarType, BitcastLLT.getNumElements(), MIB, false);
+ return ScalarType;
+ }
+ }
+ return nullptr;
+}
+
+// Deduce the SPIR-V type of a G_ANYEXT result: an integer type whose width is
+// the size of the result LLT. Only scalar results are handled; any other LLT
+// yields nullptr.
+static SPIRVType *deduceTypeForGAnyExt(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ // The result type of G_ANYEXT cannot be inferred from its operand.
+ // We use the result register's LLT to determine the correct integer type.
+ const LLT &ResLLT = MIB.getMRI()->getType(ResVReg);
+ if (!ResLLT.isScalar())
+ return nullptr;
+ return GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
+}
+
+// Fallback deduction for single-def instructions: reuse the SPIR-V type of
+// operand 1 when it is a register that already has one.
+// NOTE(review): G_ICMP is accepted by mayBeInserted(), but its operand 1 is
+// presumably the predicate rather than a register, so this always returns
+// nullptr for it — confirm G_ICMP results are typed somewhere else.
+static SPIRVType *deduceTypeForDefault(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR) {
+ if (I->getNumDefs() != 1 || I->getNumOperands() <= 1)
+ return nullptr;
+ const MachineOperand &FirstSrc = I->getOperand(1);
+ if (!FirstSrc.isReg())
+ return nullptr;
+ // nullptr while the operand is still untyped; the caller then keeps the
+ // instruction for a later worklist sweep.
+ return GR->getSPIRVTypeForVReg(FirstSrc.getReg());
+}
+
+// Try to deduce a SPIR-V type for the result(s) of I and record it in GR.
+// Returns true on success and false when I must wait for a later sweep.
+// G_UNMERGE_VALUES is special-cased because it defines several registers;
+// every other opcode assigns one type to operand 0 and, if that register has
+// no register class yet, sets one as well.
+static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
+ LLVM_DEBUG(dbgs() << "Processing instruction: " << *I);
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ // The deduction helpers may create type instructions through MIB. A
+ // builder constructed from only the MachineFunction has no insertion
+ // point, so anchor it at I first.
+ MIB.setInsertPt(*I->getParent(), I->getIterator());
+ const unsigned Opcode = I->getOpcode();
+ Register ResVReg = I->getOperand(0).getReg();
+ SPIRVType *ResType = nullptr;
+
+ switch (Opcode) {
+ case TargetOpcode::G_CONSTANT: {
+ ResType = deduceTypeForGConstant(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_UNMERGE_VALUES: {
+ // This one is special as it defines multiple registers.
+ if (deduceAndAssignTypeForGUnmerge(I, MF, GR))
+ return true;
+ break;
+ }
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
+ ResType = deduceTypeForGExtractVectorElt(I, MF, GR, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_BUILD_VECTOR: {
+ ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_ANYEXT: {
+ ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_IMPLICIT_DEF: {
+ ResType = deduceTypeForGImplicitDef(I, MF, GR, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_INTRINSIC:
+ case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: {
+ ResType = deduceTypeForGIntrinsic(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ default:
+ ResType = deduceTypeForDefault(I, MF, GR);
+ break;
+ }
+
+ if (ResType) {
+ LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n");
+ GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF);
+
+ if (!MRI.getRegClassOrNull(ResVReg)) {
+ LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+ setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
+ }
+ return true;
+ }
+ return false;
+}
+
+// Decide whether I should be queued for SPIR-V type deduction: it must define
+// at least one register, mayBeInserted() must accept its opcode, and its first
+// def must not already have a SPIR-V type. Side effect: a first def that is
+// already typed but has no register class gets one here.
+static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
+ MachineRegisterInfo &MRI) {
+ LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: "
+ << I;);
+ if (I.getNumDefs() == 0) {
+ LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n");
+ return false;
+ }
+ if (!mayBeInserted(I.getOpcode())) {
+ LLVM_DEBUG(dbgs() << "Instruction may not be inserted.\n");
+ return false;
+ }
+
+ Register ResultRegister = I.defs().begin()->getReg();
+ if (GR->getSPIRVTypeForVReg(ResultRegister)) {
+ LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n");
+ if (!MRI.getRegClassOrNull(ResultRegister)) {
+ LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+ setRegClassType(ResultRegister, GR->getSPIRVTypeForVReg(ResultRegister),
+ GR, &MRI, *GR->CurMF, true);
+ }
+ return false;
+ }
+
+ return true;
+}
+
+// Collect every instruction whose result still needs a SPIR-V type, then sweep
+// the collection repeatedly, assigning types wherever they can be deduced,
+// until a sweep makes no progress. Anything still untyped at that point is
+// dumped in debug output and trips the assertion at the end.
+static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ SmallVector<MachineInstr *, 8> Worklist;
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &I : MBB) {
+ if (requiresSpirvType(I, GR, MRI)) {
+ Worklist.push_back(&I);
+ }
+ }
+ }
+
+ if (Worklist.empty()) {
+ LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n");
+ return;
+ }
+
+ LLVM_DEBUG(dbgs() << "Initial worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
+
+ // An instruction whose type depends on a not-yet-typed register simply
+ // waits for a later sweep.
+ bool Changed = true;
+ while (Changed) {
+ Changed = false;
+ SmallVector<MachineInstr *, 8> NextWorklist;
+
+ for (MachineInstr *I : Worklist) {
+ if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
+ Changed = true;
+ } else {
+ NextWorklist.push_back(I);
+ }
+ }
+ // NextWorklist dies at the end of this iteration; move instead of copy.
+ Worklist = std::move(NextWorklist);
+ LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
+ }
+
+ if (!Worklist.empty()) {
+ LLVM_DEBUG(dbgs() << "Remaining worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
+ assert(Worklist.empty() && "Worklist is not empty");
+ }
+}
+
+// Make sure the result of every instruction whose opcode supports type
+// folding feeds a SPIRV::ASSIGN_TYPE carrying the SPIR-V type recorded in GR,
+// inserting one where it is missing. Results with no recorded SPIR-V type are
+// skipped.
+static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
+ LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
+ << MF.getName() << "\n");
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &MI : MBB) {
+ if (!isTypeFoldingSupported(MI.getOpcode()))
+ continue;
+ if (MI.getNumOperands() == 1 || !MI.getOperand(1).isReg())
+ continue;
+
+ LLVM_DEBUG(dbgs() << "Processing instruction: " << MI);
+
+ // Check whether the result already has a use in SPIRV::ASSIGN_TYPE.
+ bool HasAssignType = false;
+ Register ResultRegister = MI.defs().begin()->getReg();
+ for (MachineInstr &UseInstr :
+ MRI.use_nodbg_instructions(ResultRegister)) {
+ if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) {
+ HasAssignType = true;
+ LLVM_DEBUG(dbgs() << " Instruction already has an ASSIGN_TYPE use: "
+ << UseInstr);
+ break;
}
}
+
+ if (!HasAssignType) {
+ SPIRVType *ResultType = GR->getSPIRVTypeForVReg(ResultRegister);
+ // Without a recorded type there is nothing to assign. Skip instead of
+ // dereferencing a null type in the debug print below and handing it
+ // to insertAssignInstr (which presumably requires a type — confirm).
+ if (!ResultType) {
+ LLVM_DEBUG(dbgs() << " Result register has no SPIR-V type; skipping.\n");
+ continue;
+ }
+ LLVM_DEBUG(
+ dbgs() << " Adding ASSIGN_TYPE for ResultRegister: "
+ << printReg(ResultRegister, MRI.getTargetRegisterInfo())
+ << " with type: " << *ResultType);
+ insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI);
+ }
}
}
}
@@ -156,9 +437,8 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
GR->setCurrentFunc(MF);
MachineIRBuilder MIB(MF);
-
- processNewInstrs(MF, GR, MIB);
-
+ registerSpirvTypeForNewInstructions(MF, GR, MIB);
+ ensureAssignTypeForTypeFolding(MF, GR, MIB);
return true;
}
>From f2f29a52e3c61d52dbcb2f4728318305026016b3 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Fri, 24 Oct 2025 13:21:24 -0400
Subject: [PATCH 6/9] [SPIRV] Use a worklist in the post-legalizer
This commit refactors the SPIRV post-legalizer to use a worklist to process
new instructions. Previously, the post-legalizer would iterate through all
instructions and try to assign types. This could fail if a new instruction
depended on another new instruction that had not been processed yet.
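The new implementation adds all new instructions that require a SPIR-V type
to a worklist. It then repeatedly sweeps the worklist, assigning a type to
every instruction whose type can now be deduced from its operands or users,
until a sweep makes no further progress; an instruction that is still untyped
at that point triggers an assertion. Because an instruction that cannot be
typed yet is simply retried on a later sweep, the result no longer depends on
the order in which dependent instructions appear.
After types are assigned, the pass also lowers every G_EXTRACT_VECTOR_ELT
to a call to the spv_extractelt intrinsic.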
This change makes the post-legalizer more robust and fixes potential ordering
issues with newly generated instructions.
Existing tests cover existing functionality. More tests will be added as
the legalizer is modified.
Part of #153091
---
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 412 ++++++++++++++++---
1 file changed, 359 insertions(+), 53 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index d17528dd882bf..b6c650c802247 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -17,7 +17,8 @@
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
-#include "llvm/IR/Attributes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Support/Debug.h"
#include <stack>
#define DEBUG_TYPE "spirv-postlegalizer"
@@ -45,6 +46,11 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
static bool mayBeInserted(unsigned Opcode) {
switch (Opcode) {
+ case TargetOpcode::G_CONSTANT:
+ case TargetOpcode::G_UNMERGE_VALUES:
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT:
+ case TargetOpcode::G_INTRINSIC:
+ case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
case TargetOpcode::G_SMAX:
case TargetOpcode::G_UMAX:
case TargetOpcode::G_SMIN:
@@ -53,73 +59,372 @@ static bool mayBeInserted(unsigned Opcode) {
case TargetOpcode::G_FMINIMUM:
case TargetOpcode::G_FMAXNUM:
case TargetOpcode::G_FMAXIMUM:
+ case TargetOpcode::G_IMPLICIT_DEF:
+ case TargetOpcode::G_BUILD_VECTOR:
+ case TargetOpcode::G_ICMP:
+ case TargetOpcode::G_ANYEXT:
return true;
default:
return isTypeFoldingSupported(Opcode);
}
}
-static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+// Deduce the SPIR-V type of a G_CONSTANT result: an integer type whose bit
+// width is the scalar size in bits of the result register's LLT.
+static SPIRVType *deduceTypeForGConstant(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
MachineRegisterInfo &MRI = MF.getRegInfo();
+ const LLT &Ty = MRI.getType(ResVReg);
+ unsigned BitWidth = Ty.getScalarSizeInBits();
+ return GR->getOrCreateSPIRVIntegerType(BitWidth, MIB);
+}
- for (MachineBasicBlock &MBB : MF) {
- for (MachineInstr &I : MBB) {
- const unsigned Opcode = I.getOpcode();
- if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
- unsigned ArgI = I.getNumOperands() - 1;
- Register SrcReg = I.getOperand(ArgI).isReg()
- ? I.getOperand(ArgI).getReg()
- : Register(0);
- SPIRVType *DefType =
- SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
- if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
- report_fatal_error(
- "cannot select G_UNMERGE_VALUES with a non-vector argument");
- SPIRVType *ScalarType =
- GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
- for (unsigned i = 0; i < I.getNumDefs(); ++i) {
- Register ResVReg = I.getOperand(i).getReg();
- SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
- if (!ResType) {
- // There was no "assign type" actions, let's fix this now
+// Assign SPIR-V types to the not-yet-typed results of a G_UNMERGE_VALUES
+// whose source (the last operand) already has an OpTypeVector type. Scalar
+// results get the vector's component type; vector results get a vector of
+// that component type with the result LLT's element count. Returns true once
+// this is done, false (retry later) if the source is untyped or not a vector.
+static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
+ if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
+ if (DefType->getOpcode() == SPIRV::OpTypeVector) {
+ // Operand 1 of the vector type instruction is used as its component type.
+ SPIRVType *ScalarType =
+ GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+ for (unsigned i = 0; i < I->getNumDefs(); ++i) {
+ Register DefReg = I->getOperand(i).getReg();
+ if (!GR->getSPIRVTypeForVReg(DefReg)) {
+ LLT DefLLT = MRI.getType(DefReg);
+ SPIRVType *ResType;
+ // Unmerging into sub-vectors yields vector-typed results.
+ if (DefLLT.isVector()) {
+ const SPIRVInstrInfo *TII =
+ MF.getSubtarget<SPIRVSubtarget>().getInstrInfo();
+ ResType = GR->getOrCreateSPIRVVectorType(
+ ScalarType, DefLLT.getNumElements(), *I, *TII);
+ } else {
ResType = ScalarType;
- setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
}
+ setRegClassType(DefReg, ResType, GR, &MRI, MF);
}
- } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
- I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
- // Legalizer may have added a new instructions and introduced new
- // registers, we must decorate them as if they were introduced in a
- // non-automatic way
- Register ResVReg = I.getOperand(0).getReg();
- // Check if the register defined by the instruction is newly generated
- // or already processed
- // Check if we have type defined for operands of the new instruction
- bool IsKnownReg = MRI.getRegClassOrNull(ResVReg);
- SPIRVType *ResVType = GR->getSPIRVTypeForVReg(
- IsKnownReg ? ResVReg : I.getOperand(1).getReg());
- if (!ResVType)
- continue;
- // Set type & class
- if (!IsKnownReg)
- setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true);
- // If this is a simple operation that is to be reduced by TableGen
- // definition we must apply some of pre-legalizer rules here
- if (isTypeFoldingSupported(Opcode)) {
- processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg));
- if (IsKnownReg && MRI.hasOneUse(ResVReg)) {
- MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg);
- if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
- continue;
- }
- insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+// Deduce the SPIR-V type of a G_EXTRACT_VECTOR_ELT result. Preferred source:
+// the component type of the vector operand (operand 1) once it is typed.
+// Fallback: the component type of a typed G_BUILD_VECTOR that uses the
+// extracted value. Returns nullptr when neither is available yet.
+static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ Register VecReg = I->getOperand(1).getReg();
+ if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) {
+ assert(VecType->getOpcode() == SPIRV::OpTypeVector);
+ return GR->getScalarOrVectorComponentType(VecType);
+ }
+
+ // If not handled yet, then check if it is used in a G_BUILD_VECTOR.
+ // If so get the type from there.
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
+ Register BuildVecResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *BuildVecType = GR->getSPIRVTypeForVReg(BuildVecResReg))
+ return GR->getScalarOrVectorComponentType(BuildVecType);
+ }
+ }
+ return nullptr;
+}
+
+// Deduce the SPIR-V vector type of a G_BUILD_VECTOR result. The component
+// type comes from the first source operand that already has a SPIR-V type or,
+// failing that, from the typed result of a G_EXTRACT_VECTOR_ELT user; the
+// element count comes from the result register's LLT. Returns nullptr when no
+// component type is known yet.
+static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ // First check if any of the operands have a type.
+ for (unsigned i = 1; i < I->getNumOperands(); ++i) {
+ if (SPIRVType *OpType =
+ GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) {
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(),
+ MIB, false);
+ }
+ }
+ // If that did not work, then check the uses.
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+ Register ExtractResReg = Use.getOperand(0).getReg();
+ if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
+ const LLT &ResLLT = MRI.getType(ResVReg);
+ return GR->getOrCreateSPIRVVectorType(
+ ScalarType, ResLLT.getNumElements(), MIB, false);
+ }
+ }
+ }
+ return nullptr;
+}
+
+// Deduce the SPIR-V type of a G_IMPLICIT_DEF result from its users, which are
+// expected (asserted) to be G_BUILD_VECTOR or G_SHUFFLE_VECTOR. The type of
+// the first register source operand of a user that already has a SPIR-V type
+// is returned: a G_BUILD_VECTOR element or a G_SHUFFLE_VECTOR source vector,
+// either of which has the same type as the undefined value. Returns nullptr
+// when no such operand is typed yet.
+static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I,
+ MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ // Only read by the assert; [[maybe_unused]] keeps NDEBUG builds quiet.
+ [[maybe_unused]] const unsigned UseOpc = Use.getOpcode();
+ assert(UseOpc == TargetOpcode::G_BUILD_VECTOR ||
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ // It's possible that the use instruction has not been processed yet.
+ // We should look at the operands of the use to determine the type.
+ for (unsigned i = 1; i < Use.getNumOperands(); ++i) {
+ const MachineOperand &SrcOp = Use.getOperand(i);
+ // The mask of a G_SHUFFLE_VECTOR is not a register operand; calling
+ // getReg() on it would assert, so skip non-register operands.
+ if (!SrcOp.isReg())
+ continue;
+ if (auto *Type = GR->getSPIRVTypeForVReg(SrcOp.getReg()))
+ return Type;
+ }
+ }
+ return nullptr;
+}
+
+// Deduce the SPIR-V type of an spv_bitcast intrinsic result from its users;
+// any other intrinsic yields nullptr. The component type is taken from the
+// typed result of a user: a vector result LLT yields a vector of that
+// component with the LLT's element count, otherwise the component type itself
+// is returned. Returns nullptr when no user is typed yet.
+static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ if (!isSpvIntrinsic(*I, Intrinsic::spv_bitcast))
+ return nullptr;
+
+ // NOTE(review): UseOpc below is read only by the assert, so NDEBUG builds
+ // may warn about an unused variable — consider [[maybe_unused]].
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
+ const unsigned UseOpc = Use.getOpcode();
+ assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ Register UseResultReg = Use.getOperand(0).getReg();
+ if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
+ SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
+ const LLT &BitcastLLT = MRI.getType(ResVReg);
+ if (BitcastLLT.isVector())
+ return GR->getOrCreateSPIRVVectorType(
+ ScalarType, BitcastLLT.getNumElements(), MIB, false);
+ return ScalarType;
+ }
+ }
+ return nullptr;
+}
+
+// Deduce the SPIR-V type of a G_ANYEXT result: an integer type whose width is
+// the size of the result LLT. Only scalar results are handled; any other LLT
+// yields nullptr.
+static SPIRVType *deduceTypeForGAnyExt(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ Register ResVReg) {
+ // The result type of G_ANYEXT cannot be inferred from its operand.
+ // We use the result register's LLT to determine the correct integer type.
+ const LLT &ResLLT = MIB.getMRI()->getType(ResVReg);
+ if (!ResLLT.isScalar())
+ return nullptr;
+ return GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
+}
+
+// Fallback deduction for single-def instructions: reuse the SPIR-V type of
+// operand 1 when it is a register that already has one.
+// NOTE(review): G_ICMP is accepted by mayBeInserted(), but its operand 1 is
+// presumably the predicate rather than a register, so this always returns
+// nullptr for it — confirm G_ICMP results are typed somewhere else.
+static SPIRVType *deduceTypeForDefault(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR) {
+ if (I->getNumDefs() != 1 || I->getNumOperands() <= 1)
+ return nullptr;
+ const MachineOperand &FirstSrc = I->getOperand(1);
+ if (!FirstSrc.isReg())
+ return nullptr;
+ // nullptr while the operand is still untyped; the caller then keeps the
+ // instruction for a later worklist sweep.
+ return GR->getSPIRVTypeForVReg(FirstSrc.getReg());
+}
+
+// Try to deduce a SPIR-V type for the result(s) of I and record it in GR.
+// Returns true on success and false when I must wait for a later sweep.
+// G_UNMERGE_VALUES is special-cased because it defines several registers;
+// every other opcode assigns one type to operand 0 and, if that register has
+// no register class yet, sets one as well.
+static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB) {
+ LLVM_DEBUG(dbgs() << "Processing instruction: " << *I);
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ // The deduction helpers may create type instructions through MIB. A
+ // builder constructed from only the MachineFunction has no insertion
+ // point, so anchor it at I first.
+ MIB.setInsertPt(*I->getParent(), I->getIterator());
+ const unsigned Opcode = I->getOpcode();
+ Register ResVReg = I->getOperand(0).getReg();
+ SPIRVType *ResType = nullptr;
+
+ switch (Opcode) {
+ case TargetOpcode::G_CONSTANT: {
+ ResType = deduceTypeForGConstant(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_UNMERGE_VALUES: {
+ // This one is special as it defines multiple registers.
+ if (deduceAndAssignTypeForGUnmerge(I, MF, GR))
+ return true;
+ break;
+ }
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
+ ResType = deduceTypeForGExtractVectorElt(I, MF, GR, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_BUILD_VECTOR: {
+ ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_ANYEXT: {
+ ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_IMPLICIT_DEF: {
+ ResType = deduceTypeForGImplicitDef(I, MF, GR, ResVReg);
+ break;
+ }
+ case TargetOpcode::G_INTRINSIC:
+ case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: {
+ ResType = deduceTypeForGIntrinsic(I, MF, GR, MIB, ResVReg);
+ break;
+ }
+ default:
+ ResType = deduceTypeForDefault(I, MF, GR);
+ break;
+ }
+
+ if (ResType) {
+ LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n");
+ GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF);
+
+ if (!MRI.getRegClassOrNull(ResVReg)) {
+ LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+ setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
+ }
+ return true;
+ }
+ return false;
+}
+
+// Decide whether I should be queued for SPIR-V type deduction: it must define
+// at least one register, mayBeInserted() must accept its opcode, and its first
+// def must not already have a SPIR-V type. Side effect: a first def that is
+// already typed but has no register class gets one here.
+static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
+ MachineRegisterInfo &MRI) {
+ LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: "
+ << I;);
+ if (I.getNumDefs() == 0) {
+ LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n");
+ return false;
+ }
+ if (!mayBeInserted(I.getOpcode())) {
+ LLVM_DEBUG(dbgs() << "Instruction may not be inserted.\n");
+ return false;
+ }
+
+ Register ResultRegister = I.defs().begin()->getReg();
+ if (GR->getSPIRVTypeForVReg(ResultRegister)) {
+ LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n");
+ if (!MRI.getRegClassOrNull(ResultRegister)) {
+ LLVM_DEBUG(dbgs() << "Updating the register class.\n");
+ setRegClassType(ResultRegister, GR->getSPIRVTypeForVReg(ResultRegister),
+ GR, &MRI, *GR->CurMF, true);
+ }
+ return false;
+ }
+
+ return true;
+}
+
+static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ SmallVector<MachineInstr *, 8> Worklist;
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &I : MBB) {
+ if (requiresSpirvType(I, GR, MRI)) {
+ Worklist.push_back(&I);
+ }
+ }
+ }
+
+ if (Worklist.empty()) {
+ LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n");
+ return;
+ }
+
+ LLVM_DEBUG(dbgs() << "Initial worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
+
+ // Sweep the collected instructions until a pass assigns no new type. An
+ // instruction whose type depends on a not-yet-typed register simply waits
+ // for a later sweep; anything left untyped at the end is dumped and trips
+ // the assertion below.
+ bool Changed = true;
+ while (Changed) {
+ Changed = false;
+ SmallVector<MachineInstr *, 8> NextWorklist;
+
+ for (MachineInstr *I : Worklist) {
+ if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
+ Changed = true;
+ } else {
+ NextWorklist.push_back(I);
+ }
+ }
+ // NextWorklist dies at the end of this iteration; move instead of copy.
+ Worklist = std::move(NextWorklist);
+ LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
+ }
+
+ if (!Worklist.empty()) {
+ LLVM_DEBUG(dbgs() << "Remaining worklist:\n";
+ for (auto *I : Worklist) { I->dump(); });
+ assert(Worklist.empty() && "Worklist is not empty");
+ }
+}
+
+static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
+ LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
+ << MF.getName() << "\n");
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ // Make sure the result of every instruction whose opcode supports type
+ // folding feeds a SPIRV::ASSIGN_TYPE carrying the SPIR-V type recorded in
+ // GR, inserting one where it is missing. Results with no recorded SPIR-V
+ // type are skipped.
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &MI : MBB) {
+ if (!isTypeFoldingSupported(MI.getOpcode()))
+ continue;
+ if (MI.getNumOperands() == 1 || !MI.getOperand(1).isReg())
+ continue;
+
+ LLVM_DEBUG(dbgs() << "Processing instruction: " << MI);
+
+ // Check whether the result already has a use in SPIRV::ASSIGN_TYPE.
+ bool HasAssignType = false;
+ Register ResultRegister = MI.defs().begin()->getReg();
+ for (MachineInstr &UseInstr :
+ MRI.use_nodbg_instructions(ResultRegister)) {
+ if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) {
+ HasAssignType = true;
+ LLVM_DEBUG(dbgs() << " Instruction already has an ASSIGN_TYPE use: "
+ << UseInstr);
+ break;
}
}
+
+ if (!HasAssignType) {
+ SPIRVType *ResultType = GR->getSPIRVTypeForVReg(ResultRegister);
+ // Without a recorded type there is nothing to assign. Skip instead of
+ // dereferencing a null type in the debug print below and handing it
+ // to insertAssignInstr (which presumably requires a type — confirm).
+ if (!ResultType) {
+ LLVM_DEBUG(dbgs() << " Result register has no SPIR-V type; skipping.\n");
+ continue;
+ }
+ LLVM_DEBUG(
+ dbgs() << " Adding ASSIGN_TYPE for ResultRegister: "
+ << printReg(ResultRegister, MRI.getTargetRegisterInfo())
+ << " with type: " << *ResultType);
+ insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI);
+ }
}
}
}
+// Rewrite every G_EXTRACT_VECTOR_ELT in MF as a call to the spv_extractelt
+// intrinsic that defines the same result register from the same vector and
+// index operands. Candidates are gathered before any rewriting so that
+// erasing them never disturbs the block iteration.
+static void lowerExtractVectorElements(MachineFunction &MF) {
+ SmallVector<MachineInstr *, 8> ToLower;
+ for (MachineBasicBlock &BB : MF)
+ for (MachineInstr &Candidate : BB)
+ if (Candidate.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT)
+ ToLower.push_back(&Candidate);
+
+ for (MachineInstr *Extract : ToLower) {
+ // The builder inserts immediately before the generic instruction.
+ MachineIRBuilder Builder(*Extract);
+ // buildIntrinsic arguments: (ID, result, HasSideEffects, isConvergent).
+ Builder
+ .buildIntrinsic(Intrinsic::spv_extractelt,
+ Extract->getOperand(0).getReg(), true, false)
+ .addUse(Extract->getOperand(1).getReg())
+ .addUse(Extract->getOperand(2).getReg());
+ Extract->eraseFromParent();
+ }
+}
+
// Do a preorder traversal of the CFG starting from the BB |Start|.
// point. Calls |op| on each basic block encountered during the traversal.
void visit(MachineFunction &MF, MachineBasicBlock &Start,
@@ -156,8 +461,9 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
GR->setCurrentFunc(MF);
MachineIRBuilder MIB(MF);
-
- processNewInstrs(MF, GR, MIB);
+ registerSpirvTypeForNewInstructions(MF, GR, MIB);
+ ensureAssignTypeForTypeFolding(MF, GR, MIB);
+ lowerExtractVectorElements(MF);
return true;
}
>From c248de25c59edffdfa2a11e05d78610b10488306 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 28 Oct 2025 13:01:07 -0400
Subject: [PATCH 7/9] Set insertion point in MIB.
---
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 22 +++++++++++++-------
1 file changed, 14 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index b6c650c802247..69de5a6360c66 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -138,11 +138,14 @@ static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
MachineIRBuilder &MIB,
Register ResVReg) {
MachineRegisterInfo &MRI = MF.getRegInfo();
+ LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Processing " << *I << "\n");
// First check if any of the operands have a type.
for (unsigned i = 1; i < I->getNumOperands(); ++i) {
if (SPIRVType *OpType =
GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) {
const LLT &ResLLT = MRI.getType(ResVReg);
+ LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found operand type "
+ << *OpType << ", returning vector type\n");
return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(),
MIB, false);
}
@@ -153,11 +156,14 @@ static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I,
Register ExtractResReg = Use.getOperand(0).getReg();
if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) {
const LLT &ResLLT = MRI.getType(ResVReg);
+ LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found use type "
+ << *ScalarType << ", returning vector type\n");
return GR->getOrCreateSPIRVVectorType(
ScalarType, ResLLT.getNumElements(), MIB, false);
}
}
}
+ LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Could not deduce type\n");
return nullptr;
}
@@ -191,7 +197,8 @@ static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF,
for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) {
const unsigned UseOpc = Use.getOpcode();
assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
- UseOpc == TargetOpcode::G_SHUFFLE_VECTOR);
+ UseOpc == TargetOpcode::G_SHUFFLE_VECTOR ||
+ UseOpc == TargetOpcode::G_BUILD_VECTOR);
Register UseResultReg = Use.getOperand(0).getReg();
if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) {
SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType);
@@ -316,8 +323,7 @@ static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR,
}
static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+ SPIRVGlobalRegistry *GR) {
MachineRegisterInfo &MRI = MF.getRegInfo();
SmallVector<MachineInstr *, 8> Worklist;
for (MachineBasicBlock &MBB : MF) {
@@ -342,6 +348,7 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
SmallVector<MachineInstr *, 8> NextWorklist;
for (MachineInstr *I : Worklist) {
+ MachineIRBuilder MIB(*I); // Per-instruction builder, insertion point at I.
if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
Changed = true;
} else {
@@ -360,8 +367,7 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
}
static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
- SPIRVGlobalRegistry *GR,
- MachineIRBuilder MIB) {
+ SPIRVGlobalRegistry *GR) {
LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
<< MF.getName() << "\n");
MachineRegisterInfo &MRI = MF.getRegInfo();
@@ -395,6 +401,7 @@ static void ensureAssignTypeForTypeFolding(MachineFunction &MF,
dbgs() << " Adding ASSIGN_TYPE for ResultRegister: "
<< printReg(ResultRegister, MRI.getTargetRegisterInfo())
<< " with type: " << *ResultType);
+ MachineIRBuilder MIB(MI); // Per-instruction builder, insertion point at MI.
insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI);
}
}
@@ -460,9 +467,8 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
GR->setCurrentFunc(MF);
- MachineIRBuilder MIB(MF);
- registerSpirvTypeForNewInstructions(MF, GR, MIB);
- ensureAssignTypeForTypeFolding(MF, GR, MIB);
+ registerSpirvTypeForNewInstructions(MF, GR);
+ ensureAssignTypeForTypeFolding(MF, GR);
lowerExtractVectorElements(MF);
return true;
>From f1caae0f9bed9c2c8ab167de81b2e206d0460410 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Fri, 24 Oct 2025 15:16:19 -0400
Subject: [PATCH 8/9] [SPIRV] legalize long vectors.
---
.../llvm/CodeGen/GlobalISel/LegalizerInfo.h | 10 ++
.../CodeGen/GlobalISel/LegalityPredicates.cpp | 20 +++
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 50 +++++--
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 123 ++++++++++++++++--
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h | 4 +
5 files changed, 184 insertions(+), 23 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 51318c9c2736d..a8748965eb2e8 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
unsigned Size);
+/// True iff the specified type index is a fixed-length vector with more than
+/// \p Size elements.
+LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
+ unsigned Size);
+
+/// True iff the specified type index is a fixed-length vector with at most
+/// \p Size elements.
+LLVM_ABI LegalityPredicate
+vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size);
+
/// True iff the specified type index is a scalar or a vector with an element
/// type that's wider than the given size.
LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
index 30c2d089c3121..5e7cd5fd5d9ad 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
@@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
};
}
+LegalityPredicate
+LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
+ unsigned Size) {
+ // Scalars and scalable vectors never match: isFixedVector() is tested first.
+ return [=](const LegalityQuery &Query) {
+ const LLT QueryTy = Query.Types[TypeIdx];
+ return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size;
+ };
+}
+
+LegalityPredicate
+LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx,
+ unsigned Size) {
+ // Scalars and scalable vectors never match: isFixedVector() is tested first.
+ return [=](const LegalityQuery &Query) {
+ const LLT QueryTy = Query.Types[TypeIdx];
+ return QueryTy.isFixedVector() && QueryTy.getNumElements() <= Size;
+ };
+}
+
LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
unsigned Size) {
return [=](const LegalityQuery &Query) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ccc2c0fc467fb..f9e6a224f581b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1526,33 +1526,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
unsigned ArgI = I.getNumOperands() - 1;
Register SrcReg =
I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
- SPIRVType *DefType =
+ SPIRVType *SrcType =
SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
- if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+ if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
report_fatal_error(
"cannot select G_UNMERGE_VALUES with a non-vector argument");
SPIRVType *ScalarType =
- GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+ GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
MachineBasicBlock &BB = *I.getParent();
bool Res = false;
+ unsigned CurrentIndex = 0; // First source lane not yet given to a def.
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
Register ResVReg = I.getOperand(i).getReg();
SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
if (!ResType) {
- // There was no "assign type" actions, let's fix this now
- ResType = ScalarType;
+ LLT ResLLT = MRI->getType(ResVReg);
+ assert(ResLLT.isValid());
+ if (ResLLT.isVector()) {
+ ResType = GR.getOrCreateSPIRVVectorType(
+ ScalarType, ResLLT.getNumElements(), I, TII);
+ } else {
+ ResType = ScalarType;
+ }
MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
- MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
}
- auto MIB =
- BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(SrcReg)
- .addImm(static_cast<int64_t>(i));
- Res |= MIB.constrainAllUses(TII, TRI, RBI);
+ // Vector defs take the next NumElements source lanes via OpVectorShuffle.
+ if (ResType->getOpcode() == SPIRV::OpTypeVector) {
+ Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
+ auto MIB =
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(SrcReg)
+ .addUse(UndefReg);
+ unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
+ for (unsigned j = 0; j < NumElements; ++j) {
+ MIB.addImm(CurrentIndex + j);
+ }
+ CurrentIndex += NumElements;
+ Res |= MIB.constrainAllUses(TII, TRI, RBI);
+ } else { // Scalar defs extract the single source lane at CurrentIndex.
+ auto MIB =
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(SrcReg)
+ .addImm(CurrentIndex);
+ CurrentIndex++;
+ Res |= MIB.constrainAllUses(TII, TRI, RBI);
+ }
}
return Res;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 53074ea3b2597..815150f6dd1f1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -14,11 +14,13 @@
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVSubtarget.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
using namespace llvm;
using namespace llvm::LegalizeActions;
@@ -101,6 +103,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
v16s8, v16s16, v16s32, v16s64};
+ auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
+ v3s1, v3s8, v3s16, v3s32, v3s64,
+ v4s1, v4s8, v4s16, v4s32, v4s64};
+
auto allScalarsAndVectors = {
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
@@ -126,6 +132,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12};
+ auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;
+
bool IsExtendedInts =
ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
@@ -148,14 +156,63 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
return IsExtendedInts && Ty.isValid();
};
- for (auto Opc : getTypeFoldingSupportedOpcodes())
- getActionDefinitionsBuilder(Opc).custom();
+ uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
+
+ for (auto Opc : getTypeFoldingSupportedOpcodes()) {
+ if (Opc != G_EXTRACT_VECTOR_ELT) // Has its own rule set below.
+ getActionDefinitionsBuilder(Opc).custom();
+ }
- getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
+ getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
+ .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
+ .moreElementsToNextPow2(0)
+ .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
+ .moreElementsToNextPow2(1)
+ .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
+ .alwaysLegal();
- // TODO: add proper rules for vectors legalization.
- getActionDefinitionsBuilder(
- {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
+ getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
+ .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize))
+ .moreElementsToNextPow2(1)
+ .fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize),
+ LegalizeMutations::changeElementCountTo(
+ 1, ElementCount::getFixed(MaxVectorSize)))
+ .custom();
+
+ // Split G_BUILD_VECTORs that have more than MaxVectorSize elements.
+ getActionDefinitionsBuilder(G_BUILD_VECTOR)
+ .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+ .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+ LegalizeMutations::changeElementCountTo(
+ 0, ElementCount::getFixed(MaxVectorSize)));
+
+ // When entering the legalizer, there should be no G_BITCAST instructions.
+ // They should all be calls to the `spv_bitcast` intrinsic. The call to
+ // the intrinsic will be converted to a G_BITCAST during legalization if
+ // the vectors are not legal. After using the rules to legalize a G_BITCAST,
+ // we turn it back into a call to the intrinsic with a custom rule to avoid
+ // potential machine verifier failures.
+ getActionDefinitionsBuilder(G_BITCAST)
+ .moreElementsToNextPow2(0)
+ .moreElementsToNextPow2(1)
+ .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+ LegalizeMutations::changeElementCountTo(
+ 0, ElementCount::getFixed(MaxVectorSize)))
+ .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
+ .custom();
+
+ getActionDefinitionsBuilder(G_CONCAT_VECTORS)
+ .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+ .moreElementsToNextPow2(0)
+ .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
+ .alwaysLegal();
+
+ getActionDefinitionsBuilder(G_SPLAT_VECTOR)
+ .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+ .moreElementsToNextPow2(0)
+ .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+ LegalizeMutations::changeElementCountTo(
+ 0, ElementCount::getFixed(MaxVectorSize)))
.alwaysLegal();
// Vector Reduction Operations
@@ -164,7 +221,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
- .legalFor(allVectors)
+ .legalFor(allowedVectorTypes)
.scalarize(1)
.lower();
@@ -172,9 +229,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
.scalarize(2)
.lower();
- // Merge/Unmerge
- // TODO: add proper legalization rules.
- getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
+ // Illegal G_UNMERGE_VALUES instructions should be handled
+ // during the combine phase.
+ getActionDefinitionsBuilder(G_UNMERGE_VALUES)
+ .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize));
getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
.legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
@@ -287,6 +345,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
// Pointer-handling.
getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
+ getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);
+
// Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
@@ -374,6 +434,11 @@ bool SPIRVLegalizerInfo::legalizeCustom(
default:
// TODO: implement legalization for other opcodes.
return true;
+ case TargetOpcode::G_BITCAST:
+ return legalizeBitcast(Helper, MI);
+ case TargetOpcode::G_INTRINSIC:
+ return legalizeIntrinsic(Helper, MI);
+
case TargetOpcode::G_IS_FPCLASS:
return legalizeIsFPClass(Helper, MI, LocObserver);
case TargetOpcode::G_ICMP: {
@@ -400,6 +465,53 @@ bool SPIRVLegalizerInfo::legalizeCustom(
}
}
+/// Handles target intrinsics during legalization. A spv_bitcast whose source
+/// or result is a vector with more than the subtarget's maximum number of
+/// elements (4 for shaders, 16 otherwise) is rewritten as a G_BITCAST so the
+/// generic G_BITCAST rules can legalize it. Other intrinsics are left as is.
+bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
+ MachineInstr &MI) const {
+ auto IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
+ if (IntrinsicID != Intrinsic::spv_bitcast)
+ return true;
+
+ MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+ MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
+ const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
+
+ // Operand 0 is the result; operand 1 is the intrinsic ID.
+ Register DstReg = MI.getOperand(0).getReg();
+ Register SrcReg = MI.getOperand(2).getReg();
+ LLT DstTy = MRI.getType(DstReg);
+ LLT SrcTy = MRI.getType(SrcReg);
+
+ // Unsigned, like LLT::getNumElements(), to avoid a signed/unsigned compare.
+ unsigned MaxVectorSize = ST.isShader() ? 4 : 16;
+ bool IsLongVector =
+ (DstTy.isVector() && DstTy.getNumElements() > MaxVectorSize) ||
+ (SrcTy.isVector() && SrcTy.getNumElements() > MaxVectorSize);
+
+ // Bitcasts between legal-length types stay as the intrinsic.
+ if (IsLongVector) {
+ MIRBuilder.buildBitcast(DstReg, SrcReg);
+ MI.eraseFromParent();
+ }
+ return true;
+}
+
+/// Turns a G_BITCAST that the generic legalization rules produced back into
+/// a call to the spv_bitcast intrinsic with the same result and source.
+bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
+ MachineInstr &MI) const {
+ MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+ Register DstReg = MI.getOperand(0).getReg();
+ Register SrcReg = MI.getOperand(1).getReg();
+ SmallVector<Register, 1> DstRegs = {DstReg};
+ MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg);
+ MI.eraseFromParent();
+ return true;
+}
+
// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
// to ensure that all instructions created during the lowering have SPIR-V types
// assigned to them.
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
index eeefa4239c778..86e7e711caa60 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
@@ -29,11 +29,15 @@ class SPIRVLegalizerInfo : public LegalizerInfo {
public:
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
LostDebugLocObserver &LocObserver) const override;
+ bool legalizeIntrinsic(LegalizerHelper &Helper,
+ MachineInstr &MI) const override; // Long-vector spv_bitcast -> G_BITCAST.
+
SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
private:
bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI,
LostDebugLocObserver &LocObserver) const;
+ bool legalizeBitcast(LegalizerHelper &Helper, MachineInstr &MI) const; // G_BITCAST -> spv_bitcast.
};
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
>From f50b420bcdade4fdb51fd3c9cf359a0cc1d42a13 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 28 Oct 2025 12:54:52 -0400
Subject: [PATCH 9/9] Add tests
---
.../SPIRV/legalization/load-store-global.ll | 84 +++++++++++++++++++
1 file changed, 84 insertions(+)
create mode 100644 llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
diff --git a/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
new file mode 100644
index 0000000000000..fbfec1b3ee7cf
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
@@ -0,0 +1,84 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion"
+; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4
+; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#v4f64:]] = OpTypeVector %[[#double]] 4
+; CHECK-DAG: %[[#v2i32:]] = OpTypeVector %[[#int]] 2
+; CHECK-DAG: %[[#ptr_private_v4i32:]] = OpTypePointer Private %[[#v4i32]]
+; CHECK-DAG: %[[#ptr_private_v4f64:]] = OpTypePointer Private %[[#v4f64]]
+; CHECK-DAG: %[[#global_double:]] = OpVariable %[[#ptr_private_v4f64]] Private
+
+@G_16 = internal addrspace(10) global [16 x i32] zeroinitializer
+@G_4_double = internal addrspace(10) global <4 x double> zeroinitializer
+@G_4_int = internal addrspace(10) global <4 x i32> zeroinitializer
+
+
+; This is the way matrices will be represented in HLSL. The memory type will be
+; an array, but it will be loaded as a vector.
+; TODO: Legalization for loads and stores of long vectors is not implemented yet.
+;define spir_func void @test_load_store_global() {
+;entry:
+; %0 = load <16 x i32>, ptr addrspace(10) @G_16, align 64
+; store <16 x i32> %0, ptr addrspace(10) @G_16, align 64
+; ret void
+;}
+
+; This is the code pattern that can be generated from the `asuint` and `asdouble`
+; HLSL intrinsics.
+
+; TODO: This code is not the best because instruction selection is not folding
+; an extract from another instruction. That needs to be handled.
+define spir_func void @test_int32_double_conversion() {
+; CHECK: %[[#test_int32_double_conversion]] = OpFunction
+entry:
+ ; CHECK: %[[#LOAD:]] = OpLoad %[[#v4f64]] %[[#global_double]]
+ ; CHECK: %[[#VEC_SHUF1:]] = OpVectorShuffle %{{[a-zA-Z0-9_]+}} %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 1
+ ; CHECK: %[[#VEC_SHUF2:]] = OpVectorShuffle %{{[a-zA-Z0-9_]+}} %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 2 3
+ ; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#v4i32]] %[[#VEC_SHUF1]]
+ ; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#v4i32]] %[[#VEC_SHUF2]]
+ ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 0
+ ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 2
+ ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 0
+ ; CHECK: %[[#EXTRACT4:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 2
+ ; CHECK: %[[#CONSTRUCT1:]] = OpCompositeConstruct %[[#v4i32]] %[[#EXTRACT1]] %[[#EXTRACT2]] %[[#EXTRACT3]] %[[#EXTRACT4]]
+ ; CHECK: %[[#EXTRACT5:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 1
+ ; CHECK: %[[#EXTRACT6:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 3
+ ; CHECK: %[[#EXTRACT7:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 1
+ ; CHECK: %[[#EXTRACT8:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 3
+ ; CHECK: %[[#CONSTRUCT2:]] = OpCompositeConstruct %[[#v4i32]] %[[#EXTRACT5]] %[[#EXTRACT6]] %[[#EXTRACT7]] %[[#EXTRACT8]]
+ ; CHECK: %[[#EXTRACT9:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 0
+ ; CHECK: %[[#EXTRACT10:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 0
+ ; CHECK: %[[#EXTRACT11:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 1
+ ; CHECK: %[[#EXTRACT12:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 1
+ ; CHECK: %[[#EXTRACT13:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 2
+ ; CHECK: %[[#EXTRACT14:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 2
+ ; CHECK: %[[#EXTRACT15:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 3
+ ; CHECK: %[[#EXTRACT16:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 3
+ ; CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT9]] %[[#EXTRACT10]]
+ ; CHECK: %[[#CONSTRUCT4:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT11]] %[[#EXTRACT12]]
+ ; CHECK: %[[#CONSTRUCT5:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT13]] %[[#EXTRACT14]]
+ ; CHECK: %[[#CONSTRUCT6:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT15]] %[[#EXTRACT16]]
+ ; CHECK: %[[#BITCAST3:]] = OpBitcast %[[#double]] %[[#CONSTRUCT3]]
+ ; CHECK: %[[#BITCAST4:]] = OpBitcast %[[#double]] %[[#CONSTRUCT4]]
+ ; CHECK: %[[#BITCAST5:]] = OpBitcast %[[#double]] %[[#CONSTRUCT5]]
+ ; CHECK: %[[#BITCAST6:]] = OpBitcast %[[#double]] %[[#CONSTRUCT6]]
+ ; CHECK: %[[#CONSTRUCT7:]] = OpCompositeConstruct %[[#v4f64]] %[[#BITCAST3]] %[[#BITCAST4]] %[[#BITCAST5]] %[[#BITCAST6]]
+ ; CHECK: OpStore %[[#global_double]] %[[#CONSTRUCT7]] Aligned 32
+
+ %0 = load <8 x i32>, ptr addrspace(10) @G_4_double
+ %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %3 = shufflevector <4 x i32> %1, <4 x i32> %2, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
+ store <8 x i32> %3, ptr addrspace(10) @G_4_double
+ ret void
+}
+
+; Add a main function to make it a valid module for spirv-val
+define void @main() #1 {
+ ret void
+}
+
+attributes #1 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
More information about the llvm-commits
mailing list