[Mlir-commits] [mlir] 1fa9e15 - [mlir][spirv] Add cooperative matrix load op

Jakub Kuderski llvmlistbot at llvm.org
Wed Jul 19 08:00:36 PDT 2023


Author: Jakub Kuderski
Date: 2023-07-19T10:55:27-04:00
New Revision: 1fa9e150b43dad347bd488bac584b3c2ed77d9f8

URL: https://github.com/llvm/llvm-project/commit/1fa9e150b43dad347bd488bac584b3c2ed77d9f8
DIFF: https://github.com/llvm/llvm-project/commit/1fa9e150b43dad347bd488bac584b3c2ed77d9f8.diff

LOG: [mlir][spirv] Add cooperative matrix load op

Implement the cooperative matrix load op (`spirv.KHR.CooperativeMatrixLoad`)
for the `SPV_KHR_cooperative_matrix` extension: https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_cooperative_matrix.html.

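Also make minor fixes to shared custom-parsing helpers: `OpAsmParser::resolveOperands`
now uses `llvm::range_size`/`llvm::zip_equal`, and the SPIR-V enum and
memory-access parsing helpers gain enum static_asserts and a configurable
attribute name.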

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D155616

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 885fb7138b1410..7888d6e3aa7f0a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4061,6 +4061,17 @@ def SPIRV_KHR_CooperativeMatrixUseAttr :
       SPIRV_KHR_CMU_MatrixA, SPIRV_KHR_CMU_MatrixB, SPIRV_KHR_CMU_MatrixAcc
     ]>;
 
+// SPV_KHR_cooperative_matrix Cooperative Matrix Layout enum cases.
+def SPIRV_KHR_CML_RowMajor    : I32EnumAttrCase<"RowMajor", 0>;
+def SPIRV_KHR_CML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 1>;
+// I32 enum attribute over the two layout cases (the op's MemoryLayout).
+def SPIRV_KHR_CooperativeMatrixLayoutAttr :
+    SPIRV_I32EnumAttr<"CooperativeMatrixLayoutKHR",
+                      "valid SPIR-V Cooperative Matrix Layout (KHR)",
+                      "coop_matrix_layout_khr", [
+      SPIRV_KHR_CML_RowMajor, SPIRV_KHR_CML_ColumnMajor
+    ]>;
+
 //===----------------------------------------------------------------------===//
 // SPIR-V attribute definitions
 //===----------------------------------------------------------------------===//
@@ -4435,6 +4446,7 @@ def SPIRV_OC_OpSDotAccSat                 : I32EnumAttrCase<"OpSDotAccSat", 4453
 def SPIRV_OC_OpUDotAccSat                 : I32EnumAttrCase<"OpUDotAccSat", 4454>;
 def SPIRV_OC_OpSUDotAccSat                : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
 def SPIRV_OC_OpTypeCooperativeMatrixKHR   : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
+def SPIRV_OC_OpCooperativeMatrixLoadKHR   : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
 def SPIRV_OC_OpTypeCooperativeMatrixNV    : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
 def SPIRV_OC_OpCooperativeMatrixLoadNV    : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
@@ -4534,7 +4546,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax,
       SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
-      SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+      SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
+      SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
       SPIRV_OC_OpTypeCooperativeMatrixNV,
       SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV,
       SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index e6e2bbb26d094e..c3c1e2cd042800 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -63,6 +63,77 @@ def SPIRV_KHRCooperativeMatrixLengthOp :
   );
 }
 
+// -----
+
+def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad", []> {
+  let summary = "Loads a cooperative matrix through a pointer";
+
+  let description = [{
+    Load a cooperative matrix through a pointer.
+
+    Result Type is the type of the loaded object. It must be a cooperative
+    matrix type.
+
+    Pointer is a pointer. Its type must be an OpTypePointer whose Type operand is
+    a scalar or vector type. If the Shader capability was declared, Pointer must
+    point into an array and any ArrayStride decoration on Pointer is ignored.
+
+    MemoryLayout specifies how matrix elements are laid out in memory. It must
+    come from a 32-bit integer constant instruction whose value corresponds to a
+    Cooperative Matrix Layout. See the Cooperative Matrix Layout table for a
+    description of the layouts and detailed layout-specific rules.
+
+    Stride further qualifies how matrix elements are laid out in memory. It must
+    be a scalar integer type and its exact semantics depend on MemoryLayout.
+
+    Memory Operand must be a Memory Operand literal. If not present, it is the
+    same as specifying None.
+
+    NOTE: In earlier versions of the SPIR-V spec, 'Memory Operand' was known
+    as 'Memory Access'.
+
+    For a given dynamic instance of this instruction, all operands of this
+    instruction must be the same for all invocations in a given scope instance
+    (where the scope is the scope the cooperative matrix type was created with).
+    All invocations in a given scope instance must be active or all must be
+    inactive.
+
+    ``` {.ebnf}
+    cooperative-matrix-load-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixLoad`
+                              ssa-use `,` ssa-use `,`
+                              cooperative-matrix-layout
+                              (`[` memory-operand `]`)? `:`
+                                pointer-type `as` cooperative-matrix-type
+    ```
+
+    #### Example:
+
+    ```
+    %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor
+         : !spirv.ptr<i32, StorageBuffer>
+             as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_6>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_KHR_cooperative_matrix]>,
+    Capability<[SPIRV_C_CooperativeMatrixKHR]>
+  ];
+
+  let arguments = (ins
+    SPIRV_AnyPtr:$pointer,
+    SPIRV_Integer:$stride,
+    SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
+    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+  );
+
+  let results = (outs
+    SPIRV_AnyCooperativeMatrix:$result
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // SPV_NV_cooperative_matrix extension ops.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 81e74f38307774..0eeb8bb1ec8da5 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1445,13 +1445,13 @@ class OpAsmParser : public AsmParser {
   std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
   resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
                   SmallVectorImpl<Value> &result) {
-    size_t operandSize = std::distance(operands.begin(), operands.end());
-    size_t typeSize = std::distance(types.begin(), types.end());
+    size_t operandSize = llvm::range_size(operands);
+    size_t typeSize = llvm::range_size(types);
     if (operandSize != typeSize)
       return emitError(loc)
              << operandSize << " operands present, but expected " << typeSize;
 
-    for (auto [operand, type] : llvm::zip(operands, types))
+    for (auto [operand, type] : llvm::zip_equal(operands, types))
       if (resolveOperand(operand, type, result))
         return failure();
     return success();

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 5c11fe8d0cffa6..61b084e6a56412 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -36,6 +36,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include <cassert>
 #include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 
@@ -53,7 +54,9 @@ constexpr char kGroupOperationAttrName[] = "group_operation";
 constexpr char kIndicesAttrName[] = "indices";
 constexpr char kInitializerAttrName[] = "initializer";
 constexpr char kInterfaceAttrName[] = "interface";
+constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
 constexpr char kMemoryAccessAttrName[] = "memory_access";
+constexpr char kMemoryOperandAttrName[] = "memory_operand";
 constexpr char kMemoryScopeAttrName[] = "memory_scope";
 constexpr char kPackedVectorFormatAttrName[] = "format";
 constexpr char kSemanticsAttrName[] = "semantics";
@@ -176,6 +179,7 @@ template <typename EnumClass>
 static ParseResult
 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
                  StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
   Attribute attrVal;
   NamedAttrList attr;
   auto loc = parser.getCurrentLocation();
@@ -202,6 +206,7 @@ template <typename EnumAttrClass,
 static ParseResult
 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
                  StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
   if (parseEnumStrAttr(value, parser))
     return failure();
   state.addAttribute(attrName,
@@ -218,6 +223,7 @@ static ParseResult
 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
                      OperationState &state,
                      StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
   if (parseEnumKeywordAttr(value, parser))
     return failure();
   state.addAttribute(attrName,
@@ -246,14 +252,15 @@ parseControlAttribute(OpAsmParser &parser, OperationState &state,
   return success();
 }
 
-/// Parses optional memory access attributes attached to a memory access
-/// operand/pointer. Specifically, parses the following syntax:
+/// Parses optional memory access (a.k.a. memory operand) attributes attached to
+/// a memory access operand/pointer. Specifically, parses the following syntax:
 ///     (`[` memory-access `]`)?
 /// where:
 ///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
 ///         integer-literal | `"NonTemporal"`
-static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
-                                               OperationState &state) {
+static ParseResult
+parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state,
+                            StringRef attrName = kMemoryAccessAttrName) {
   // Parse an optional list of attributes staring with '['
   if (parser.parseOptionalLSquare()) {
     // Nothing to do
@@ -262,7 +269,7 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
 
   spirv::MemoryAccess memoryAccessAttr;
   if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
-                                                kMemoryAccessAttrName))
+                                                attrName))
     return failure();
 
   if (spirv::bitEnumContainsAll(memoryAccessAttr,
@@ -4035,6 +4042,95 @@ LogicalResult spirv::KHRCooperativeMatrixLengthOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom assembly form:
+///   ssa-use `,` ssa-use `,` cooperative-matrix-layout
+///   (`[` memory-operand `]`)? `:` pointer-type (`,` stride-type)?
+///   `as` cooperative-matrix-type
+/// When stride-type is omitted, the stride operand is resolved as i32.
+ParseResult spirv::KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
+                                                     OperationState &result) {
+  std::array<OpAsmParser::UnresolvedOperand, 2> operandInfo = {};
+  if (parser.parseOperand(operandInfo[0]) || parser.parseComma())
+    return failure();
+  if (parser.parseOperand(operandInfo[1]) || parser.parseComma())
+    return failure();
+
+  spirv::CooperativeMatrixLayoutKHR layout;
+  if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>(
+          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
+    return failure();
+  }
+
+  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
+    return failure();
+
+  Type ptrType;
+  if (parser.parseColon() || parser.parseType(ptrType))
+    return failure();
+
+  // The `stride` operand accepts any `SPIRV_Integer` width, so a non-i32
+  // stride type is spelled out after the pointer type; otherwise use i32.
+  Type strideType = parser.getBuilder().getIntegerType(32);
+  if (succeeded(parser.parseOptionalComma()) && parser.parseType(strideType))
+    return failure();
+
+  Type elementType;
+  if (parser.parseKeywordType("as", elementType))
+    return failure();
+  result.addTypes(elementType);
+
+  if (parser.resolveOperands(operandInfo, {ptrType, strideType},
+                             parser.getNameLoc(), result.operands)) {
+    return failure();
+  }
+
+  return success();
+}
+
+/// Prints the custom assembly form read by KHRCooperativeMatrixLoadOp::parse.
+/// The stride type is printed only when it is not signless i32, so the common
+/// form is unchanged while other integer widths still round-trip.
+void spirv::KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getStride() << ", "
+          << getMatrixLayout();
+  // Print optional memory operand attribute.
+  if (auto memOperand = getMemoryOperand())
+    printer << " [\"" << memOperand << "\"]";
+  printer << " : " << getPointer().getType();
+  Type strideType = getStride().getType();
+  if (!strideType.isSignlessInteger(32))
+    printer << ", " << strideType;
+  printer << " as " << getType();
+}
+
+/// Verifies that `pointer` points to a scalar or vector type. `coopMatrix` is
+/// not inspected yet.
+static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
+                                                    Type coopMatrix) {
+  auto pointerType = cast<spirv::PointerType>(pointer);
+  Type pointeeType = pointerType.getPointeeType();
+  if (!isa<spirv::ScalarType, VectorType>(pointeeType)) {
+    return op->emitError(
+               "Pointer must point to a scalar or vector type but provided ")
+           << pointeeType;
+  }
+
+  // TODO: Verify the memory object behind the pointer:
+  // > If the Shader capability was declared, Pointer must point into an array
+  // > and any ArrayStride decoration on Pointer is ignored.
+
+  return success();
+}
+
+LogicalResult spirv::KHRCooperativeMatrixLoadOp::verify() {
+  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
+                                        getResult().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.NV.CooperativeMatrixLength
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 65311e1db9bb49..7e38161aebae80 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // CooperativeMatrix (KHR)
@@ -21,6 +21,80 @@ spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
 
 // -----
 
+// CHECK-LABEL: @cooperative_matrix_load
+spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_memoperand
+spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
+  // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor ["Volatile"] :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
+spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor ["Volatile"] :
+  // CHECK-SAME:   !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor ["Volatile"] :
+    !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_function
+spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %stride : i32) "None" {
+  // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
+  // CHECK-SAME:   !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
+    !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{Pointer must point to a scalar or vector type}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
+    !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{expected ','}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_empty_layout(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{expected valid keyword}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // NV.CooperativeMatrix
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list