[llvm] Reland [SPIR-V] Support `SPV_INTEL_int4` extension (PR #141279)

Viktoria Maximova via llvm-commits llvm-commits at lists.llvm.org
Fri May 23 11:53:46 PDT 2025


https://github.com/vmaksimo created https://github.com/llvm/llvm-project/pull/141279

This relands #141031 

This change ensures that the generated SPIR-V is valid and passes machine verification, fixing the following verifier failure:
```
*** Bad machine code: inconsistent constant size ***
- function:    foo
- basic block: %bb.1 entry (0x9ec9298)
- instruction: %12:iid(s8) = G_CONSTANT i4 1
```
This is done by promoting `G_CONSTANT` instructions with small integer types (e.g., `i4`) to `i8` when no extension that permits such non-standard integer widths is enabled.


>From 4c243ab88ad057dd58135a20e5e526bb789e57d8 Mon Sep 17 00:00:00 2001
From: "Maksimova, Viktoria" <viktoria.maksimova at intel.com>
Date: Tue, 20 May 2025 11:45:05 -0700
Subject: [PATCH 1/3] [SPIR-V] Support `SPV_INTEL_int4` extension

Adds support for a native 4-bit integer type.

Spec:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_int4.asciidoc
---
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp    |  3 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 21 +++++++++++---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |  3 +-
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   |  3 +-
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td |  3 ++
 .../SPV_INTEL_int4/cooperative_matrix.ll      | 20 +++++++++++++
 .../extensions/SPV_INTEL_int4/negative.ll     | 29 +++++++++++++++++++
 .../extensions/SPV_INTEL_int4/trivial.ll      | 25 ++++++++++++++++
 8 files changed, 100 insertions(+), 7 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e6cb8cee66a60..fbaca4e0e4d81 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -99,7 +99,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_INTEL_ternary_bitwise_function",
          SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function},
         {"SPV_INTEL_2d_block_io",
-         SPIRV::Extension::Extension::SPV_INTEL_2d_block_io}};
+         SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
+        {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index ac397fc486e19..d9fcb5623b220 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -154,7 +154,8 @@ unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
     report_fatal_error("Unsupported integer width!");
   const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
   if (ST.canUseExtension(
-          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
+          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
+      ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
     return Width;
   if (Width <= 8)
     Width = 8;
@@ -174,9 +175,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
   const SPIRVSubtarget &ST =
       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-    if ((!isPowerOf2_32(Width) || Width < 8) &&
-        ST.canUseExtension(
-            SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+    if (Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
+      MIRBuilder.buildInstr(SPIRV::OpExtension)
+          .addImm(SPIRV::Extension::SPV_INTEL_int4);
+      MIRBuilder.buildInstr(SPIRV::OpCapability)
+          .addImm(SPIRV::Capability::Int4TypeINTEL);
+    } else if ((!isPowerOf2_32(Width) || Width < 8) &&
+               ST.canUseExtension(
+                   SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
       MIRBuilder.buildInstr(SPIRV::OpExtension)
           .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
       MIRBuilder.buildInstr(SPIRV::OpCapability)
@@ -1563,6 +1569,13 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
   const MachineInstr *NewMI =
       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
         SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
+        const Type *ET = getTypeForSPIRVType(ElemType);
+        if (ET->isIntegerTy() && ET->getIntegerBitWidth() == 4 &&
+            cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
+                .canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
+          MIRBuilder.buildInstr(SPIRV::OpCapability)
+              .addImm(SPIRV::Capability::Int4CooperativeMatrixINTEL);
+        }
         return MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
             .addDef(createTypeVReg(MIRBuilder))
             .addUse(getSPIRVTypeID(ElemType))
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 578e82881f6e8..29ec90d2ae8df 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -128,7 +128,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   bool IsExtendedInts =
       ST.canUseExtension(
           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
-      ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
+      ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
+      ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
   auto extendedScalarsAndVectors =
       [IsExtendedInts](const LegalityQuery &Query) {
         const LLT Ty = Query.Types[0];
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index b6a2da6e2045d..2d6d0e37c8a1c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -492,7 +492,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
   bool IsExtendedInts =
       ST->canUseExtension(
           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
-      ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
+      ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
+      ST->canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
 
   for (MachineBasicBlock *MBB : post_order(&MF)) {
     if (MBB->empty())
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 851495bda4979..51a59e441b5b4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -316,6 +316,7 @@ defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
 defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
 defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
 defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
+defm SPV_INTEL_int4 : ExtensionOperand<123>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -521,6 +522,8 @@ defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_
 defm Subgroup2DBlockIOINTEL : CapabilityOperand<6228, 0, 0, [SPV_INTEL_2d_block_io], []>;
 defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
 defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
+defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
+defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll
new file mode 100644
index 0000000000000..02f023276bf5d
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll
@@ -0,0 +1,20 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: Capability Int4TypeINTEL
+; CHECK-DAG: Capability CooperativeMatrixKHR
+; CHECK-DAG: Extension "SPV_INTEL_int4"
+; CHECK-DAG: Capability Int4CooperativeMatrixINTEL
+; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix"
+
+; CHECK: %[[#Int4Ty:]] = OpTypeInt 4 0
+; CHECK: %[[#CoopMatTy:]] = OpTypeCooperativeMatrixKHR %[[#Int4Ty]]
+; CHECK: CompositeConstruct %[[#CoopMatTy]]
+
+define spir_kernel void @foo() {
+entry:
+  %call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef 0)
+  ret void
+}
+
+declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
new file mode 100644
index 0000000000000..17202ab243e8f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
@@ -0,0 +1,29 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_arbitrary_precision_integers %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-4
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-8
+; No error would be reported in comparison to Khronos llvm-spirv, because type adjustments to integer size are made 
+; in case no appropriate extension is enabled. Here we expect that the type is adjusted to 8 bits.
+
+; CHECK-SPIRV: Capability ArbitraryPrecisionIntegersINTEL
+; CHECK-SPIRV: Extension "SPV_INTEL_arbitrary_precision_integers"
+; CHECK-INT-4: %[[#Int4:]] = OpTypeInt 4 0
+; CHECK-INT-8: %[[#Int4:]] = OpTypeInt 8 0
+; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
+; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
+; CHECK: %[[#Const:]] = OpConstant %[[#Int4]]  1
+
+; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
+; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
+; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
+; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
+
+define spir_kernel void @foo() {
+entry:
+  %0 = alloca i4
+  store i4 1, ptr %0
+  %1 = load i4, ptr %0
+  call spir_func void @boo(i4 %1)
+  ret void
+}
+
+declare spir_func void @boo(i4)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll
new file mode 100644
index 0000000000000..f1bee0b963613
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll
@@ -0,0 +1,25 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - | FileCheck %s
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: Capability Int4TypeINTEL
+; CHECK: Extension "SPV_INTEL_int4"
+; CHECK: %[[#Int4:]] = OpTypeInt  4 0
+; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
+; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
+; CHECK: %[[#Const:]] = OpConstant %[[#Int4]]  1
+
+; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
+; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
+; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
+; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
+
+define spir_kernel void @foo() {
+entry:
+  %0 = alloca i4
+  store i4 1, ptr %0
+  %1 = load i4, ptr %0
+  call spir_func void @boo(i4 %1)
+  ret void
+}
+
+declare spir_func void @boo(i4)

>From 5f0b40a7a5177e2689df61f0836d69f0c1862a43 Mon Sep 17 00:00:00 2001
From: "Maksimova, Viktoria" <viktoria.maksimova at intel.com>
Date: Thu, 22 May 2025 07:53:22 -0700
Subject: [PATCH 2/3] Update llvm/docs/SPIRVUsage.rst

---
 llvm/docs/SPIRVUsage.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 4d8b7996dec01..373dd1e3856d8 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -215,6 +215,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
      - Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform. 
    * - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
      - Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix. 
+   * - ``SPV_INTEL_int4``
+     - Adds support for a 4-bit integer type and allows this type to be used in cooperative matrices.
 
 To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:
 

>From 8af2ecc5dedd835200a0c2c044408ae739be0426 Mon Sep 17 00:00:00 2001
From: "Maksimova, Viktoria" <viktoria.maksimova at intel.com>
Date: Fri, 23 May 2025 11:22:54 -0700
Subject: [PATCH 3/3] Fix MachineVerifier failure

This fixes the following error:

```
*** Bad machine code: inconsistent constant size ***
- function:    foo
- basic block: %bb.1 entry (0x9ec9298)
- instruction: %12:iid(s8) = G_CONSTANT i4 1
```
---
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   | 39 ++++++++++++++-----
 .../extensions/SPV_INTEL_int4/negative.ll     |  2 +-
 2 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 2d6d0e37c8a1c..ee6d32eb9aa2c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -380,16 +380,32 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
 // To support current approach and limitations wrt. bit width here we widen a
 // scalar register with a bit width greater than 1 to valid sizes and cap it to
 // 64 width.
-static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
+static unsigned widenBitWidthToNextPow2(unsigned BitWidth) {
+  if (BitWidth == 1)
+    return 1; // No need to widen 1-bit values
+  return std::min(std::max(1u << Log2_32_Ceil(BitWidth), 8u), 64u);
+}
+
+static void widenScalarType(Register Reg, MachineRegisterInfo &MRI) {
   LLT RegType = MRI.getType(Reg);
   if (!RegType.isScalar())
     return;
-  unsigned Sz = RegType.getScalarSizeInBits();
-  if (Sz == 1)
-    return;
-  unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u);
-  if (NewSz != Sz)
-    MRI.setType(Reg, LLT::scalar(NewSz));
+  unsigned CurrentWidth = RegType.getScalarSizeInBits();
+  unsigned NewWidth = widenBitWidthToNextPow2(CurrentWidth);
+  if (NewWidth != CurrentWidth)
+    MRI.setType(Reg, LLT::scalar(NewWidth));
+}
+
+static void widenCImmType(MachineOperand &MOP) {
+  const ConstantInt *CImmVal = MOP.getCImm();
+  unsigned CurrentWidth = CImmVal->getBitWidth();
+  unsigned NewWidth = widenBitWidthToNextPow2(CurrentWidth);
+  if (NewWidth != CurrentWidth) {
+    // Replace the immediate value with the widened version
+    MOP.setCImm(ConstantInt::get(
+        CImmVal->getType()->getContext(),
+        CImmVal->getValue().zextOrTrunc(NewWidth)));
+  }
 }
 
 static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) {
@@ -506,10 +522,13 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
       unsigned MIOp = MI.getOpcode();
 
       if (!IsExtendedInts) {
-        // validate bit width of scalar registers
-        for (const auto &MOP : MI.operands())
+        // validate bit width of scalar registers and immediates
+        for (auto &MOP : MI.operands()) {
           if (MOP.isReg())
-            widenScalarLLTNextPow2(MOP.getReg(), MRI);
+            widenScalarType(MOP.getReg(), MRI);
+          else if (MOP.isCImm())
+            widenCImmType(MOP);
+        }
       }
 
       if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
index 17202ab243e8f..4d5fa52a166f2 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
@@ -1,6 +1,6 @@
 ; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_arbitrary_precision_integers %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-4
 
-; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-8
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-8
 ; No error would be reported in comparison to Khronos llvm-spirv, because type adjustments to integer size are made 
 ; in case no appropriate extension is enabled. Here we expect that the type is adjusted to 8 bits.
 



More information about the llvm-commits mailing list