[llvm] [SPIR-V] Support `SPV_INTEL_int4` extension (PR #141031)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 22 04:21:44 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Viktoria Maximova (vmaksimo)
<details>
<summary>Changes</summary>
Adds support for the native 4-bit integer type introduced by the SPV_INTEL_int4 extension.
Spec:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_int4.asciidoc
---
Full diff: https://github.com/llvm/llvm-project/pull/141031.diff
8 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+2-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+17-4)
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+2-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+2-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+3)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll (+20)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll (+29)
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll (+25)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e6cb8cee66a60..fbaca4e0e4d81 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -99,7 +99,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_INTEL_ternary_bitwise_function",
SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function},
{"SPV_INTEL_2d_block_io",
- SPIRV::Extension::Extension::SPV_INTEL_2d_block_io}};
+ SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
+ {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}};
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index ac397fc486e19..d9fcb5623b220 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -154,7 +154,8 @@ unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
report_fatal_error("Unsupported integer width!");
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
- SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
+ SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
+ ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
return Width;
if (Width <= 8)
Width = 8;
@@ -174,9 +175,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
const SPIRVSubtarget &ST =
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
- if ((!isPowerOf2_32(Width) || Width < 8) &&
- ST.canUseExtension(
- SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+ if (Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
+ MIRBuilder.buildInstr(SPIRV::OpExtension)
+ .addImm(SPIRV::Extension::SPV_INTEL_int4);
+ MIRBuilder.buildInstr(SPIRV::OpCapability)
+ .addImm(SPIRV::Capability::Int4TypeINTEL);
+ } else if ((!isPowerOf2_32(Width) || Width < 8) &&
+ ST.canUseExtension(
+ SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
MIRBuilder.buildInstr(SPIRV::OpExtension)
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
MIRBuilder.buildInstr(SPIRV::OpCapability)
@@ -1563,6 +1569,13 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
const MachineInstr *NewMI =
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
+ const Type *ET = getTypeForSPIRVType(ElemType);
+ if (ET->isIntegerTy() && ET->getIntegerBitWidth() == 4 &&
+ cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
+ .canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
+ MIRBuilder.buildInstr(SPIRV::OpCapability)
+ .addImm(SPIRV::Capability::Int4CooperativeMatrixINTEL);
+ }
return MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
.addDef(createTypeVReg(MIRBuilder))
.addUse(getSPIRVTypeID(ElemType))
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 578e82881f6e8..29ec90d2ae8df 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -128,7 +128,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
bool IsExtendedInts =
ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
- ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
+ ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
+ ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
auto extendedScalarsAndVectors =
[IsExtendedInts](const LegalityQuery &Query) {
const LLT Ty = Query.Types[0];
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index b6a2da6e2045d..2d6d0e37c8a1c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -492,7 +492,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
bool IsExtendedInts =
ST->canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
- ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
+ ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
+ ST->canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
for (MachineBasicBlock *MBB : post_order(&MF)) {
if (MBB->empty())
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 851495bda4979..51a59e441b5b4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -316,6 +316,7 @@ defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
+defm SPV_INTEL_int4 : ExtensionOperand<123>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -521,6 +522,8 @@ defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_
defm Subgroup2DBlockIOINTEL : CapabilityOperand<6228, 0, 0, [SPV_INTEL_2d_block_io], []>;
defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
+defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
+defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll
new file mode 100644
index 0000000000000..02f023276bf5d
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll
@@ -0,0 +1,20 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
+; The spirv-val run above is disabled (RUNx). NOTE(review): presumably spirv-val does not accept SPV_INTEL_int4 yet; re-enable once it does.
+; CHECK-DAG: Capability Int4TypeINTEL
+; CHECK-DAG: Capability CooperativeMatrixKHR
+; CHECK-DAG: Extension "SPV_INTEL_int4"
+; CHECK-DAG: Capability Int4CooperativeMatrixINTEL
+; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix"
+; Module-level capabilities/extensions above match in any order; the type and instruction lines below must appear in sequence.
+; CHECK: %[[#Int4Ty:]] = OpTypeInt 4 0
+; CHECK: %[[#CoopMatTy:]] = OpTypeCooperativeMatrixKHR %[[#Int4Ty]]
+; CHECK: CompositeConstruct %[[#CoopMatTy]]
+; foo builds an i4 cooperative matrix via __spirv_CompositeConstruct, so the module needs both the i4 element type and the matrix type.
+define spir_kernel void @foo() {
+entry:
+ %call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef 0)
+ ret void
+}
+
+declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
new file mode 100644
index 0000000000000..17202ab243e8f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
@@ -0,0 +1,29 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_arbitrary_precision_integers %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-4 --implicit-check-not=Int4TypeINTEL
+; Without SPV_INTEL_int4 but with arbitrary precision integers enabled, i4 stays 4 bits wide via ArbitraryPrecisionIntegersINTEL.
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-8 --implicit-check-not=Int4TypeINTEL
+; Unlike the Khronos llvm-spirv translator, no error is reported here: with no suitable extension enabled,
+; the backend widens integer types to a supported size, so i4 is expected to be adjusted to 8 bits.
+
+; CHECK-INT-4: Capability ArbitraryPrecisionIntegersINTEL
+; CHECK-INT-4: Extension "SPV_INTEL_arbitrary_precision_integers"
+; CHECK-INT-4: %[[#Int4:]] = OpTypeInt 4 0
+; CHECK-INT-8: %[[#Int4:]] = OpTypeInt 8 0
+; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
+; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
+; CHECK: %[[#Const:]] = OpConstant %[[#Int4]] 1
+
+; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
+; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
+; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
+; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
+
+define spir_kernel void @foo() {
+entry:
+ %0 = alloca i4
+ store i4 1, ptr %0
+ %1 = load i4, ptr %0
+ call spir_func void @boo(i4 %1)
+ ret void
+}
+
+declare spir_func void @boo(i4)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll
new file mode 100644
index 0000000000000..f1bee0b963613
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll
@@ -0,0 +1,25 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - | FileCheck %s
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - -filetype=obj | spirv-val %}
+; With SPV_INTEL_int4 enabled, i4 is emitted as a native 4-bit OpTypeInt. NOTE(review): the spirv-val run is disabled (RUNx); re-enable once spirv-val supports SPV_INTEL_int4.
+; CHECK: Capability Int4TypeINTEL
+; CHECK: Extension "SPV_INTEL_int4"
+; CHECK: %[[#Int4:]] = OpTypeInt 4 0
+; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
+; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
+; CHECK: %[[#Const:]] = OpConstant %[[#Int4]] 1
+; Function body: a constant is stored to, and reloaded from, an i4 stack slot and then passed to a call.
+; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
+; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
+; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
+; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
+
+define spir_kernel void @foo() {
+entry:
+ %0 = alloca i4
+ store i4 1, ptr %0
+ %1 = load i4, ptr %0
+ call spir_func void @boo(i4 %1)
+ ret void
+}
+
+declare spir_func void @boo(i4)
``````````
</details>
https://github.com/llvm/llvm-project/pull/141031
More information about the llvm-commits
mailing list