[llvm] [SPIR-V] Add SPV_INTEL_joint_matrix extension (PR #118578)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 3 18:23:18 PST 2024


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff f6f16b5f541773bb074dd042746456deff169de2 0556328ee5adec96f38c2105536e33d852ab216f --extensions cpp,h -- llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
index 823c33ecb6..2437fbb820 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -210,12 +210,12 @@ namespace Opcode {
 namespace CooperativeMatrixLayout {
 #define GET_CooperativeMatrixLayout_DECL
 #include "SPIRVGenTables.inc"
-} // namespace Opcode
+} // namespace CooperativeMatrixLayout
 
 namespace CooperativeMatrixOperands {
 #define GET_CooperativeMatrixOperands_DECL
 #include "SPIRVGenTables.inc"
-} // namespace Opcode
+} // namespace CooperativeMatrixOperands
 
 struct ExtendedBuiltin {
   StringRef Name;
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index d05a0e87ca..2ee0c79b8f 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -220,12 +220,13 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
           const unsigned MulAddOp = MI->getOperand(FirstVariableIndex).getImm();
           if (MulAddOp == 0) {
             printSymbolicOperand<
-              OperandCategory::CooperativeMatrixOperandsOperand>(
-                  MI, FirstVariableIndex, OS);
+                OperandCategory::CooperativeMatrixOperandsOperand>(
+                MI, FirstVariableIndex, OS);
           } else {
             std::string Buffer;
             for (unsigned Mask = 0x1;
-                 Mask != SPIRV::CooperativeMatrixOperands::MatrixResultBFloat16ComponentsINTEL;
+                 Mask != SPIRV::CooperativeMatrixOperands::
+                             MatrixResultBFloat16ComponentsINTEL;
                  Mask <<= 1) {
               if (MulAddOp & Mask) {
                 if (!Buffer.empty())
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 28fcbda2d1..9b6c2a849e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1974,42 +1974,41 @@ static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
                Opcode != SPIRV::OpCooperativeMatrixPrefetchINTEL;
   unsigned ArgSz = Call->Arguments.size();
   unsigned LiteralIdx = 0;
-  switch(Opcode) {
-    // Memory operand is optional and is literal.
-    case SPIRV::OpCooperativeMatrixLoadKHR:
-      LiteralIdx = ArgSz > 3 ? 3 : 0;
-      break;
-    case SPIRV::OpCooperativeMatrixStoreKHR:
-      LiteralIdx = ArgSz > 4 ? 4 : 0;
-      break;
-    case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
-      LiteralIdx = ArgSz > 7 ? 7 : 0;
-      break;
-    case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
-      LiteralIdx = ArgSz > 8 ? 8 : 0;
-      break;
-    // Cooperative Matrix Operands operand is optional and is literal.
-    case SPIRV::OpCooperativeMatrixMulAddKHR:
-      LiteralIdx = ArgSz > 3 ? 3 : 0;
-      break;
+  switch (Opcode) {
+  // Memory operand is optional and is literal.
+  case SPIRV::OpCooperativeMatrixLoadKHR:
+    LiteralIdx = ArgSz > 3 ? 3 : 0;
+    break;
+  case SPIRV::OpCooperativeMatrixStoreKHR:
+    LiteralIdx = ArgSz > 4 ? 4 : 0;
+    break;
+  case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
+    LiteralIdx = ArgSz > 7 ? 7 : 0;
+    break;
+  case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
+    LiteralIdx = ArgSz > 8 ? 8 : 0;
+    break;
+  // Cooperative Matrix Operands operand is optional and is literal.
+  case SPIRV::OpCooperativeMatrixMulAddKHR:
+    LiteralIdx = ArgSz > 3 ? 3 : 0;
+    break;
   };
 
   SmallVector<uint32_t, 1> ImmArgs;
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   if (Opcode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
-    const uint32_t CacheLevel =
-        getConstFromIntrinsic(Call->Arguments[3], MRI);
+    const uint32_t CacheLevel = getConstFromIntrinsic(Call->Arguments[3], MRI);
     auto MIB = MIRBuilder.buildInstr(SPIRV::OpCooperativeMatrixPrefetchINTEL)
-        .addUse(Call->Arguments[0])   // pointer
-        .addUse(Call->Arguments[1])   // rows
-        .addUse(Call->Arguments[2])   // columns
-        .addImm(CacheLevel)           // cache level
-        .addUse(Call->Arguments[4]);  // memory layout
+                   .addUse(Call->Arguments[0])  // pointer
+                   .addUse(Call->Arguments[1])  // rows
+                   .addUse(Call->Arguments[2])  // columns
+                   .addImm(CacheLevel)          // cache level
+                   .addUse(Call->Arguments[4]); // memory layout
     if (ArgSz > 5)
       MIB.addUse(Call->Arguments[5]); // stride
     if (ArgSz > 6) {
       const uint32_t MemOp = getConstFromIntrinsic(Call->Arguments[6], MRI);
-      MIB.addImm(MemOp);              // memory operand
+      MIB.addImm(MemOp); // memory operand
     }
     return true;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index ee7c08d324..616e9d348e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1441,7 +1441,8 @@ void addInstrRequirements(const MachineInstr &MI,
     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
       report_fatal_error("Cooperative matrix instructions require the "
                          "following SPIR-V extension: "
-                         "SPV_KHR_cooperative_matrix", false);
+                         "SPV_KHR_cooperative_matrix",
+                         false);
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
     constexpr unsigned MulAddMaxSize = 6;
@@ -1455,11 +1456,11 @@ void addInstrRequirements(const MachineInstr &MI,
           SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
     }
     if (CoopOperands & SPIRV::CooperativeMatrixOperands::
-        MatrixAAndBBFloat16ComponentsINTEL ||
-        CoopOperands & SPIRV::CooperativeMatrixOperands::
-        MatrixCBFloat16ComponentsINTEL ||
+                           MatrixAAndBBFloat16ComponentsINTEL ||
+        CoopOperands &
+            SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
         CoopOperands & SPIRV::CooperativeMatrixOperands::
-        MatrixResultBFloat16ComponentsINTEL) {
+                           MatrixResultBFloat16ComponentsINTEL) {
       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
       Reqs.addCapability(
           SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
@@ -1474,19 +1475,19 @@ void addInstrRequirements(const MachineInstr &MI,
     if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
       report_fatal_error("Cooperative matrix instructions require the "
                          "following SPIR-V extension: "
-                         "SPV_KHR_cooperative_matrix", false);
+                         "SPV_KHR_cooperative_matrix",
+                         false);
     Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
     Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
 
     // Check Layout operand in case if it's not a standart one and add the
     // appropriate capability.
     std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
-      {SPIRV::OpCooperativeMatrixLoadKHR, 3},
-      {SPIRV::OpCooperativeMatrixStoreKHR, 2},
-      {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
-      {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
-      {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}
-    };
+        {SPIRV::OpCooperativeMatrixLoadKHR, 3},
+        {SPIRV::OpCooperativeMatrixStoreKHR, 2},
+        {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
+        {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
+        {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
 
     const auto OpCode = MI.getOpcode();
     const unsigned LayoutNum = LayoutToInstMap[OpCode];
@@ -1495,11 +1496,12 @@ void addInstrRequirements(const MachineInstr &MI,
     MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
     if (MILayout->getOpcode() == SPIRV::OpConstantI) {
       const unsigned LayoutVal = MILayout->getOperand(2).getImm();
-      if (LayoutVal == static_cast<unsigned>(
-            SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
+      if (LayoutVal ==
+          static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
         if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
           report_fatal_error("PackedINTEL layout require the following SPIR-V "
-                             "extension: SPV_INTEL_joint_matrix", false);
+                             "extension: SPV_INTEL_joint_matrix",
+                             false);
         Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
         Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
       }
@@ -1513,7 +1515,8 @@ void addInstrRequirements(const MachineInstr &MI,
     if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
       report_fatal_error("OpCooperativeMatrix[Load/Store]CheckedINTEL "
                          "instructions require the following SPIR-V extension: "
-                         "SPV_INTEL_joint_matrix", false);
+                         "SPV_INTEL_joint_matrix",
+                         false);
     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
     if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
       Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
@@ -1527,7 +1530,8 @@ void addInstrRequirements(const MachineInstr &MI,
     if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
       report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL"
                          " instructions require the following SPIR-V extension:"
-                         " SPV_INTEL_joint_matrix", false);
+                         " SPV_INTEL_joint_matrix",
+                         false);
     Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
     Reqs.addCapability(
         SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);

``````````

</details>


https://github.com/llvm/llvm-project/pull/118578


More information about the llvm-commits mailing list