[llvm] [SPIR-V] Add SPV_INTEL_joint_matrix extension (PR #118578)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 3 18:20:42 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Dmitry Sidorov (MrSidims)

<details>
<summary>Changes</summary>

The spec is available here:
https://github.com/intel/llvm/pull/12497

The PR doesn't add the OpCooperativeMatrixApplyFunctionINTEL instruction, as it is still experimental and not yet properly tested end-to-end (E2E).

The PR also fixes a few bugs in the related code:
1. The optional Cooperative Matrix Operands operand of CooperativeMatrixMulAddKHR must be a literal, not a constant;
2. Fixed creation of the available-capabilities table for the case where a single extension adds several capabilities whose opcodes are not contiguous.

---

Patch is 50.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118578.diff


17 Files Affected:

- (modified) llvm/docs/SPIRVUsage.rst (+2) 
- (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp (+6-2) 
- (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h (+10) 
- (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp (+27) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+40-5) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+7) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+17) 
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+104) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+68) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_bf16.ll (+46) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll (+59) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_get_coord.ll (+35) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_packed.ll (+78) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_prefetch.ll (+34) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_tf32.ll (+46) 
- (modified) llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll (+1-1) 


``````````diff
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 28e919fdf516a0..8f7ac71f8026b3 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -179,6 +179,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
      - Introduces two new storage classes that are subclasses of the CrossWorkgroup storage class that provides additional information that can enable optimization.
    * - ``SPV_INTEL_variable_length_array``
      - Allows to allocate local arrays whose number of elements is unknown at compile time.
+   * - ``SPV_INTEL_joint_matrix``
+     - Adds a few matrix capabilities on top of the SPV_KHR_cooperative_matrix extension, such as matrix prefetch, get element coordinate, and checked load/store/construct instructions, as well as tensor float 32 and bfloat16 type interpretations for the multiply-add instruction.
    * - ``SPV_KHR_bit_instructions``
      - Enables bit instructions to be used by SPIR-V modules without requiring the Shader capability.
    * - ``SPV_KHR_expect_assume``
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
index 0f9a2a69e07390..67bf4596152492 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
@@ -137,8 +137,12 @@ getCapabilitiesEnabledByExtension(SPIRV::Extension::Extension Extension) {
 
   CapabilityList Capabilities;
   while (Entry &&
-         Entry->Category == SPIRV::OperandCategory::CapabilityOperand &&
-         Entry->ReqExtension == Extension) {
+         Entry->Category == SPIRV::OperandCategory::CapabilityOperand) {
+    // Some capabilities' codes might go not in order.
+    if (Entry->ReqExtension != Extension) {
+      ++Entry;
+      continue;
+    }
     Capabilities.push_back(
         static_cast<SPIRV::Capability::Capability>(Entry->Value));
     ++Entry;
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
index 44625793e94138..823c33ecb6bd38 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -207,6 +207,16 @@ namespace Opcode {
 #include "SPIRVGenTables.inc"
 } // namespace Opcode
 
+namespace CooperativeMatrixLayout {
+#define GET_CooperativeMatrixLayout_DECL
+#include "SPIRVGenTables.inc"
+} // namespace CooperativeMatrixLayout
+
+namespace CooperativeMatrixOperands {
+#define GET_CooperativeMatrixOperands_DECL
+#include "SPIRVGenTables.inc"
+} // namespace CooperativeMatrixOperands
+
 struct ExtendedBuiltin {
   StringRef Name;
   InstructionSet::InstructionSet Set;
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index ff8759755e5176..d05a0e87ca870e 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -211,6 +211,33 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
           // are part of the variable value.
           printOpConstantVarOps(MI, NumFixedOps - 1, OS);
           break;
+        case SPIRV::OpCooperativeMatrixMulAddKHR: {
+          const unsigned NumOps = MI->getNumOperands();
+          if (NumFixedOps == NumOps)
+            break;
+
+          OS << ' ';
+          const unsigned MulAddOp = MI->getOperand(FirstVariableIndex).getImm();
+          if (MulAddOp == 0) { // Zero mask: print its symbolic name directly.
+            printSymbolicOperand<
+              OperandCategory::CooperativeMatrixOperandsOperand>(
+                  MI, FirstVariableIndex, OS);
+          } else {
+            std::string Buffer; // Names of the set flag bits, '|'-separated.
+            for (unsigned Mask = 0x1;
+                 Mask <= SPIRV::CooperativeMatrixOperands::MatrixResultBFloat16ComponentsINTEL;
+                 Mask <<= 1) { // '<=' so the highest flag bit is printed too.
+              if (MulAddOp & Mask) {
+                if (!Buffer.empty())
+                  Buffer += '|';
+                Buffer += getSymbolicOperandMnemonic(
+                    OperandCategory::CooperativeMatrixOperandsOperand, Mask);
+              }
+            }
+            OS << Buffer;
+          }
+          break;
+        }
         default:
           printRemainingVariableOps(MI, NumFixedOps, OS);
           break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 45a49674d4ca21..28fcbda2d1df89 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1969,15 +1969,50 @@ static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
   unsigned Opcode =
       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
-  bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR;
+  bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR &&
+               Opcode != SPIRV::OpCooperativeMatrixStoreCheckedINTEL &&
+               Opcode != SPIRV::OpCooperativeMatrixPrefetchINTEL;
   unsigned ArgSz = Call->Arguments.size();
   unsigned LiteralIdx = 0;
-  if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3)
-    LiteralIdx = 3;
-  else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4)
-    LiteralIdx = 4;
+  switch(Opcode) {
+    // Memory operand is optional and is literal.
+    case SPIRV::OpCooperativeMatrixLoadKHR:
+      LiteralIdx = ArgSz > 3 ? 3 : 0;
+      break;
+    case SPIRV::OpCooperativeMatrixStoreKHR:
+      LiteralIdx = ArgSz > 4 ? 4 : 0;
+      break;
+    case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
+      LiteralIdx = ArgSz > 7 ? 7 : 0;
+      break;
+    case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
+      LiteralIdx = ArgSz > 8 ? 8 : 0;
+      break;
+    // Cooperative Matrix Operands operand is optional and is literal.
+    case SPIRV::OpCooperativeMatrixMulAddKHR:
+      LiteralIdx = ArgSz > 3 ? 3 : 0;
+      break;
+  };
+
   SmallVector<uint32_t, 1> ImmArgs;
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  if (Opcode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
+    const uint32_t CacheLevel =
+        getConstFromIntrinsic(Call->Arguments[3], MRI);
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpCooperativeMatrixPrefetchINTEL)
+        .addUse(Call->Arguments[0])   // pointer
+        .addUse(Call->Arguments[1])   // rows
+        .addUse(Call->Arguments[2])   // columns
+        .addImm(CacheLevel)           // cache level
+        .addUse(Call->Arguments[4]);  // memory layout
+    if (ArgSz > 5)
+      MIB.addUse(Call->Arguments[5]); // stride
+    if (ArgSz > 6) {
+      const uint32_t MemOp = getConstFromIntrinsic(Call->Arguments[6], MRI);
+      MIB.addImm(MemOp);              // memory operand
+    }
+    return true;
+  }
   if (LiteralIdx > 0)
     ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI));
   Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index dc2da4a3a5647a..e29013d28aafe4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -695,6 +695,13 @@ defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreKHR", OpenCL_std, C
 defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixMulAddKHR", OpenCL_std, CoopMatr, 3, 4, OpCooperativeMatrixMulAddKHR>;
 defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLengthKHR", OpenCL_std, CoopMatr, 1, 1, OpCooperativeMatrixLengthKHR>;
 
+// Cooperative Matrix Intel builtin records:
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixPrefetchINTEL", OpenCL_std, CoopMatr, 5, 7, OpCooperativeMatrixPrefetchINTEL>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLoadCheckedINTEL", OpenCL_std, CoopMatr, 6, 8, OpCooperativeMatrixLoadCheckedINTEL>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreCheckedINTEL", OpenCL_std, CoopMatr, 7, 9, OpCooperativeMatrixStoreCheckedINTEL>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixConstructCheckedINTEL", OpenCL_std, CoopMatr, 5, 5, OpCooperativeMatrixConstructCheckedINTEL>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixGetElementCoordINTEL", OpenCL_std, CoopMatr, 2, 2, OpCooperativeMatrixGetElementCoordINTEL>;
+
 //===----------------------------------------------------------------------===//
 // Class defining a work/sub group builtin that should be translated into a
 // SPIR-V instruction using the defined properties.
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e78fc5ce18707b..fb05c1fdbd1e3b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -51,6 +51,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_INTEL_subgroups},
         {"SPV_INTEL_media_block_io",
          SPIRV::Extension::Extension::SPV_INTEL_media_block_io},
+        {"SPV_INTEL_joint_matrix",
+         SPIRV::Extension::Extension::SPV_INTEL_joint_matrix},
         {"SPV_KHR_uniform_group_instructions",
          SPIRV::Extension::Extension::SPV_KHR_uniform_group_instructions},
         {"SPV_KHR_no_integer_wrap_decoration",
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 53f1b644a94983..d95803fea56a58 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -895,6 +895,23 @@ def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
 def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type),
                   "$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">;
 
+// SPV_INTEL_joint_matrix
+def OpCooperativeMatrixLoadCheckedINTEL: Op<6193, (outs ID:$res),
+                  (ins TYPE:$resType, ID:$pointer, ID:$xOffset, ID:$yOffset, ID:$memory_layout, ID:$height, ID:$width, variable_ops),
+                  "$res = OpCooperativeMatrixLoadCheckedINTEL $resType $pointer $xOffset $yOffset $memory_layout $height $width">;
+def OpCooperativeMatrixStoreCheckedINTEL: Op<6194, (outs),
+                  (ins ID:$pointer, ID:$xOffset, ID:$yOffset, ID:$objectToStore, ID:$memory_layout, ID:$height, ID:$width, variable_ops),
+                  "OpCooperativeMatrixStoreCheckedINTEL $pointer $xOffset $yOffset $objectToStore $memory_layout $height $width">;
+def OpCooperativeMatrixConstructCheckedINTEL: Op<6195, (outs ID:$res),
+                  (ins TYPE:$resType, ID:$xOffset, ID:$yOffset, ID:$height, ID:$width, ID:$value),
+                  "$res = OpCooperativeMatrixConstructCheckedINTEL $resType $xOffset $yOffset $height $width $value">;
+def OpCooperativeMatrixGetElementCoordINTEL: Op<6440, (outs ID:$res),
+                  (ins TYPE:$resType, ID:$matrix, ID:$index),
+                  "$res = OpCooperativeMatrixGetElementCoordINTEL $resType $matrix $index">;
+def OpCooperativeMatrixPrefetchINTEL: Op<6449, (outs),
+                  (ins ID:$pointer, ID:$rows, ID:$columns, i32imm:$cacheLevel, ID:$memory_layout, variable_ops),
+                  "OpCooperativeMatrixPrefetchINTEL $pointer $rows $columns $cacheLevel $memory_layout">;
+
 // SPV_EXT_arithmetic_fence
 def OpArithmeticFenceEXT: Op<6145, (outs ID:$res), (ins TYPE:$type, ID:$target),
                   "$res = OpArithmeticFenceEXT $type $target">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 20540814763157..ee7c08d324bb0c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1437,6 +1437,110 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
     }
     break;
+  case SPIRV::OpCooperativeMatrixMulAddKHR: {
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
+      report_fatal_error("Cooperative matrix instructions require the "
+                         "following SPIR-V extension: "
+                         "SPV_KHR_cooperative_matrix", false);
+    Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
+    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+    constexpr unsigned MulAddMaxSize = 6;
+    if (MI.getNumOperands() != MulAddMaxSize)
+      break;
+    const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
+    if (CoopOperands &
+        SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
+      Reqs.addCapability(
+          SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
+    }
+    if (CoopOperands & SPIRV::CooperativeMatrixOperands::
+        MatrixAAndBBFloat16ComponentsINTEL ||
+        CoopOperands & SPIRV::CooperativeMatrixOperands::
+        MatrixCBFloat16ComponentsINTEL ||
+        CoopOperands & SPIRV::CooperativeMatrixOperands::
+        MatrixResultBFloat16ComponentsINTEL) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
+      Reqs.addCapability(
+          SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
+    }
+    break;
+  }
+  case SPIRV::OpCooperativeMatrixLoadKHR:
+  case SPIRV::OpCooperativeMatrixStoreKHR:
+  case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
+  case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
+  case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
+      report_fatal_error("Cooperative matrix instructions require the "
+                         "following SPIR-V extension: "
+                         "SPV_KHR_cooperative_matrix", false);
+    Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
+    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+
+    // Check whether the Layout operand is non-standard (PackedINTEL) and, if
+    // so, add the appropriate capability. Maps opcode -> Layout operand index.
+    static const std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
+      {SPIRV::OpCooperativeMatrixLoadKHR, 3},
+      {SPIRV::OpCooperativeMatrixStoreKHR, 2},
+      {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
+      {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
+      {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}
+    };
+
+    const auto OpCode = MI.getOpcode();
+    const unsigned LayoutNum = LayoutToInstMap.at(OpCode);
+    Register RegLayout = MI.getOperand(LayoutNum).getReg();
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
+    if (MILayout && MILayout->getOpcode() == SPIRV::OpConstantI) {
+      const unsigned LayoutVal = MILayout->getOperand(2).getImm();
+      if (LayoutVal == static_cast<unsigned>(
+            SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
+        if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
+          report_fatal_error("PackedINTEL layout requires the following SPIR-V "
+                             "extension: SPV_INTEL_joint_matrix", false);
+        Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
+        Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
+      }
+    }
+
+    // The KHR load/store need nothing beyond what was added above.
+    if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
+        OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
+      break;
+
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
+      report_fatal_error("OpCooperativeMatrix[Load/Store]CheckedINTEL and "
+                         "OpCooperativeMatrixPrefetchINTEL instructions require "
+                         "the following SPIR-V extension: SPV_INTEL_joint_matrix", false);
+    Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
+    if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
+      Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
+      break;
+    }
+    Reqs.addCapability(
+        SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
+    break;
+  }
+  case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
+      report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL"
+                         " instructions require the following SPIR-V extension:"
+                         " SPV_INTEL_joint_matrix", false);
+    Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
+    Reqs.addCapability(
+        SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
+    break;
+  case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
+      report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
+                         "following SPIR-V extension: SPV_INTEL_joint_matrix",
+                         false);
+    Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
+    Reqs.addCapability(
+        SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
+    break;
   case SPIRV::OpKill: {
     Reqs.addCapability(SPIRV::Capability::Shader);
   } break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index a3a88acdd6c6ae..745d1e1aec67aa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -170,6 +170,8 @@ def GroupOperationOperand : OperandCategory;
 def KernelEnqueueFlagsOperand : OperandCategory;
 def KernelProfilingInfoOperand : OperandCategory;
 def OpcodeOperand : OperandCategory;
+def CooperativeMatrixLayoutOperand : OperandCategory;
+def CooperativeMatrixOperandsOperand : OperandCategory;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Extesions enum values and at the same time
@@ -305,6 +307,7 @@ defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>;
 defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>;
 defm SPV_EXT_arithmetic_fence : ExtensionOperand<112>;
 defm SPV_EXT_optnone : ExtensionOperand<113>;
+defm SPV_INTEL_joint_matrix : ExtensionOperand<114>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -492,6 +495,12 @@ defm CacheControlsINTEL : CapabilityOperand<6441, 0, 0, [SPV_INTEL_cache_control
 defm CooperativeMatrixKHR : CapabilityOperand<6022, 0, 0, [SPV_KHR_cooperative_matrix], []>;
 defm ArithmeticFenceEXT : CapabilityOperand<6144, 0, 0, [SPV_EXT_arithmetic_fence], []>;
 defm SplitBarrierINTEL : CapabilityOperand<6141, 0, 0, [SPV_INTEL_split_barrier], []>;
+defm CooperativeMatrixCheckedInstructionsINTEL : CapabilityOperand<6192, 0, 0, [SPV_INTEL_joint_matrix], []>;
+defm CooperativeMatrixPrefetchINTEL : CapabilityOperand<6411, 0, 0, [SPV_INTEL_joint_matrix], []>;
+defm PackedCooperativeMatrixINTEL : CapabilityOperand<6434, 0, 0, [SPV_INTEL_joint_matrix], []>;
+defm CooperativeMatrixInvocationInstructionsINTEL : CapabilityOperand<6435, 0, 0, [SPV_INTEL_joint_matrix], []>;
+defm CooperativeMatrixTF32ComponentTypeINTEL : CapabilityOperand<6436, 0, 0, [SPV_INTEL_joint_matrix], []>;
+defm CooperativeMatrixBFloat16ComponentTypeINTEL : CapabilityOperand<6437, 0, 0, [SPV_INTEL_joint_matrix], []>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
@@ -1649,3 +1658,62 @@ defm GenericCastToPtr : OpcodeOperand<122>;
 defm Bitcast : OpcodeOperand<124>;
 defm ConvertPtrToU : OpcodeOperand<117>;
 defm ConvertUToPtr : OpcodeOperand<120>;
+
+//===----------------------------------------------------------------------===//
+// Multiclass used to define Cooperative Matrix Layout enum values and at the
+// same time SymbolicOperand entries extensions and capabilities.
+//===----------------------------------------------------------------------===//
+
+def CooperativeMatrixLayout : GenericEnum, Operand<i32> {
+  let FilterClass = "CooperativeMatrixLayout";
+  let NameField = "Name";
+  let ValueField = "Value";
+}
+
+class CooperativeMatrixLayout<string name, bits<32> value> {
+  string Name = name;
+  bits<32> Value = value;
+}
+
+multiclass CooperativeMatrixLayoutOperand<bits<32> value, list<Extension> reqExtensions, list<Capability> reqCapabilities> {
+  def : CooperativeMatrixLayout<NAM...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/118578


More information about the llvm-commits mailing list