[llvm] [SPIR-V] Support `SPV_INTEL_int4` extension (PR #141031)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 22 04:21:44 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Viktoria Maximova (vmaksimo)

<details>
<summary>Changes</summary>

Adds support for a native 4-bit integer type via the `SPV_INTEL_int4` extension.

Spec:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_int4.asciidoc

---
Full diff: https://github.com/llvm/llvm-project/pull/141031.diff


8 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+2-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+17-4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+2-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+2-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+3) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll (+20) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll (+29) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll (+25) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e6cb8cee66a60..fbaca4e0e4d81 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -99,7 +99,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_INTEL_ternary_bitwise_function",
          SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function},
         {"SPV_INTEL_2d_block_io",
-         SPIRV::Extension::Extension::SPV_INTEL_2d_block_io}};
+         SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
+        {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index ac397fc486e19..d9fcb5623b220 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -154,7 +154,8 @@ unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
     report_fatal_error("Unsupported integer width!");
   const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
   if (ST.canUseExtension(
-          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
+          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
+      ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
     return Width;
   if (Width <= 8)
     Width = 8;
@@ -174,9 +175,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
   const SPIRVSubtarget &ST =
       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-    if ((!isPowerOf2_32(Width) || Width < 8) &&
-        ST.canUseExtension(
-            SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+    if (Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
+      MIRBuilder.buildInstr(SPIRV::OpExtension)
+          .addImm(SPIRV::Extension::SPV_INTEL_int4);
+      MIRBuilder.buildInstr(SPIRV::OpCapability)
+          .addImm(SPIRV::Capability::Int4TypeINTEL);
+    } else if ((!isPowerOf2_32(Width) || Width < 8) &&
+               ST.canUseExtension(
+                   SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
       MIRBuilder.buildInstr(SPIRV::OpExtension)
           .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
       MIRBuilder.buildInstr(SPIRV::OpCapability)
@@ -1563,6 +1569,13 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
   const MachineInstr *NewMI =
       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
         SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
+        const Type *ET = getTypeForSPIRVType(ElemType);
+        if (ET->isIntegerTy() && ET->getIntegerBitWidth() == 4 &&
+            cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
+                .canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
+          MIRBuilder.buildInstr(SPIRV::OpCapability)
+              .addImm(SPIRV::Capability::Int4CooperativeMatrixINTEL);
+        }
         return MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
             .addDef(createTypeVReg(MIRBuilder))
             .addUse(getSPIRVTypeID(ElemType))
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 578e82881f6e8..29ec90d2ae8df 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -128,7 +128,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   bool IsExtendedInts =
       ST.canUseExtension(
           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
-      ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
+      ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
+      ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
   auto extendedScalarsAndVectors =
       [IsExtendedInts](const LegalityQuery &Query) {
         const LLT Ty = Query.Types[0];
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index b6a2da6e2045d..2d6d0e37c8a1c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -492,7 +492,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
   bool IsExtendedInts =
       ST->canUseExtension(
           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
-      ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
+      ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
+      ST->canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
 
   for (MachineBasicBlock *MBB : post_order(&MF)) {
     if (MBB->empty())
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 851495bda4979..51a59e441b5b4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -316,6 +316,7 @@ defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
 defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
 defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
 defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
+defm SPV_INTEL_int4 : ExtensionOperand<123>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -521,6 +522,8 @@ defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_
 defm Subgroup2DBlockIOINTEL : CapabilityOperand<6228, 0, 0, [SPV_INTEL_2d_block_io], []>;
 defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
 defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
+defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
+defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll
new file mode 100644
index 0000000000000..02f023276bf5d
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll
@@ -0,0 +1,20 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: Capability Int4TypeINTEL
+; CHECK-DAG: Capability CooperativeMatrixKHR
+; CHECK-DAG: Extension "SPV_INTEL_int4"
+; CHECK-DAG: Capability Int4CooperativeMatrixINTEL
+; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix"
+
+; CHECK: %[[#Int4Ty:]] = OpTypeInt 4 0
+; CHECK: %[[#CoopMatTy:]] = OpTypeCooperativeMatrixKHR %[[#Int4Ty]]
+; CHECK: CompositeConstruct %[[#CoopMatTy]]
+
+define spir_kernel void @foo() {
+entry:
+  %call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef 0)
+  ret void
+}
+
+declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
new file mode 100644
index 0000000000000..17202ab243e8f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
@@ -0,0 +1,29 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_arbitrary_precision_integers %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-4
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-8
+; Unlike Khronos llvm-spirv, no error is reported here, because the integer type width is adjusted
+; when no appropriate extension is enabled. Here we expect the type to be adjusted to 8 bits.
+
+; CHECK-SPIRV: Capability ArbitraryPrecisionIntegersINTEL
+; CHECK-SPIRV: Extension "SPV_INTEL_arbitrary_precision_integers"
+; CHECK-INT-4: %[[#Int4:]] = OpTypeInt 4 0
+; CHECK-INT-8: %[[#Int4:]] = OpTypeInt 8 0
+; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
+; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
+; CHECK: %[[#Const:]] = OpConstant %[[#Int4]]  1
+
+; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
+; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
+; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
+; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
+
+define spir_kernel void @foo() {
+entry:
+  %0 = alloca i4
+  store i4 1, ptr %0
+  %1 = load i4, ptr %0
+  call spir_func void @boo(i4 %1)
+  ret void
+}
+
+declare spir_func void @boo(i4)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll
new file mode 100644
index 0000000000000..f1bee0b963613
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll
@@ -0,0 +1,25 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - | FileCheck %s
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: Capability Int4TypeINTEL
+; CHECK: Extension "SPV_INTEL_int4"
+; CHECK: %[[#Int4:]] = OpTypeInt  4 0
+; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
+; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
+; CHECK: %[[#Const:]] = OpConstant %[[#Int4]]  1
+
+; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
+; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
+; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
+; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
+
+define spir_kernel void @foo() {
+entry:
+  %0 = alloca i4
+  store i4 1, ptr %0
+  %1 = load i4, ptr %0
+  call spir_func void @boo(i4 %1)
+  ret void
+}
+
+declare spir_func void @boo(i4)

``````````

</details>


https://github.com/llvm/llvm-project/pull/141031


More information about the llvm-commits mailing list