[Mlir-commits] [mlir] b359bba - [mlir][spirv] First step to support spirv cooperative matrix extension.

Thomas Raoux llvmlistbot at llvm.org
Tue May 19 19:30:05 PDT 2020


Author: Thomas Raoux
Date: 2020-05-19T19:29:41-07:00
New Revision: b359bbaa8b41a84ae54369e3017ce1a5c7afe1a1

URL: https://github.com/llvm/llvm-project/commit/b359bbaa8b41a84ae54369e3017ce1a5c7afe1a1
DIFF: https://github.com/llvm/llvm-project/commit/b359bbaa8b41a84ae54369e3017ce1a5c7afe1a1.diff

LOG: [mlir][spirv] First step to support spirv cooperative matrix extension.

Add a new type to the SPIR-V dialect for cooperative matrices and a new op for
cooperative matrix load. Most instructions needed to support the cooperative
matrix extension are still missing; this is a stop-gap patch to avoid creating
one big review.

Differential Revision: https://reviews.llvm.org/D80043

Added: 
    mlir/include/mlir/Dialect/SPIRV/ParserUtils.h
    mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
    mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
    mlir/test/Dialect/SPIRV/cooperative-matrix.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
    mlir/test/Dialect/SPIRV/types.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/ParserUtils.h b/mlir/include/mlir/Dialect/SPIRV/ParserUtils.h
new file mode 100644
index 000000000000..f368aec45efb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/ParserUtils.h
@@ -0,0 +1,41 @@
+//===------------ ParserUtils.h - Parse text to SPIR-V ops ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines utilities used for parsing types and ops for the SPIR-V
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_SPIRV_PARSERUTILS_H_
+#define MLIR_DIALECT_SPIRV_PARSERUTILS_H_
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+/// Parses the next keyword in `parser` as an enumerant of the given
+/// `EnumClass`. `ParserType` is any parser (e.g. OpAsmParser or
+/// DialectAsmParser) with getCurrentLocation, parseKeyword and emitError.
+template <typename EnumClass, typename ParserType>
+static ParseResult
+parseEnumKeywordAttr(EnumClass &value, ParserType &parser,
+                     StringRef attrName = spirv::attributeName<EnumClass>()) {
+  StringRef keyword;
+  auto loc = parser.getCurrentLocation();
+  if (parser.parseKeyword(&keyword))
+    return failure();
+  if (Optional<EnumClass> attr = spirv::symbolizeEnum<EnumClass>(keyword)) {
+    value = attr.getValue();
+    return success();
+  }
+  return parser.emitError(loc, "invalid ")
+         << attrName << " attribute specification: " << keyword;
+}
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_PARSERUTILS_H_

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 64063cb77d01..b958a10c5952 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -2991,8 +2991,10 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
     AnyTypeOf<!foreach(w, widths, IOrUI<w>),
               StrJoinInt<widths, "/">.result # "-bit signless/unsigned integer">;
 
-def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
 def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
+def SPV_IsCooperativeMatrixType :
+  CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
+def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
 def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
 def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
 
@@ -3012,6 +3014,9 @@ def SPV_AnyPtr : DialectType<SPIRV_Dialect, SPV_IsPtrType,
                              "any SPIR-V pointer type">;
 def SPV_AnyArray : DialectType<SPIRV_Dialect, SPV_IsArrayType,
                                "any SPIR-V array type">;
+def SPV_AnyCooperativeMatrix :
+    DialectType<SPIRV_Dialect, SPV_IsCooperativeMatrixType,
+                "any SPIR-V cooperative matrix type">;
 def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
                                  "any SPIR-V runtime array type">;
 def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType,
@@ -3220,6 +3225,8 @@ def SPV_OC_OpGroupNonUniformSMax       : I32EnumAttrCase<"OpGroupNonUniformSMax"
 def SPV_OC_OpGroupNonUniformUMax       : I32EnumAttrCase<"OpGroupNonUniformUMax", 357>;
 def SPV_OC_OpGroupNonUniformFMax       : I32EnumAttrCase<"OpGroupNonUniformFMax", 358>;
 def SPV_OC_OpSubgroupBallotKHR         : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
+def SPV_OC_OpTypeCooperativeMatrixNV   : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
+def SPV_OC_OpCooperativeMatrixLoadNV   : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
 
 def SPV_OpcodeAttr :
     SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -3271,7 +3278,8 @@ def SPV_OpcodeAttr :
       SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
       SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
       SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
-      SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR
+      SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
+      SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
new file mode 100644
index 000000000000..931f56f58755
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
@@ -0,0 +1,98 @@
+//===- SPIRVCooperativeMatrixOps.td - cooperative matmul ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of cooperative matrix multiply extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_COOPERATIVE_MATRIX_OPS
+#define SPIRV_COOPERATIVE_MATRIX_OPS
+
+// -----
+
+def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
+  let summary = "See extension SPV_NV_cooperative_matrix";
+
+  let description = [{
+    Load a cooperative matrix through a pointer.
+
+    Result Type is the type of the loaded object. It must be a cooperative
+    matrix type.
+
+    Pointer is a pointer into an array. Its type must be an OpTypePointer whose
+    Type operand is a scalar or vector type. The storage class of Pointer must
+    be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
+    supported) PhysicalStorageBufferEXT.
+
+    Stride is the number of elements in the array in memory between the first
+    component of consecutive rows (or columns) in the result. It must be a
+    scalar integer type.
+
+    ColumnMajor indicates whether the values loaded from memory are arranged in
+    column-major or row-major order. It must be a boolean constant instruction,
+    with false indicating row major and true indicating column major.
+
+    Memory Access must be a Memory Access literal. If not present, it is the
+    same as specifying None.
+
+    If ColumnMajor is false, then elements (row,*) of the result are taken in
+    order from contiguous locations starting at Pointer[row*Stride]. If
+    ColumnMajor is true, then elements (*,col) of the result are taken in order
+    from contiguous locations starting from Pointer[col*Stride]. Any ArrayStride
+    decoration on Pointer is ignored.
+
+    For a given dynamic instance of this instruction, all operands of this
+    instruction must be the same for all invocations in a given scope instance
+    (where the scope is the scope the cooperative matrix type was created with).
+    All invocations in a given scope instance must be active or all must be
+    inactive.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    cooperative-matrix-load-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV`
+                                   storage-class ssa-use `,` ssa-use `,`
+                                   ssa-use (`[` memory-access `]`)? `:`
+                                   cooperative-matrix-type
+    ```
+
+    The three operands are, in order, Pointer, Stride, and ColumnMajor. The
+    custom form currently expects Stride to be i32 and ColumnMajor to be i1.
+
+    For example:
+
+    ```
+    %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %colMajor
+         : !spv.coopmatrix<16x8xi32, Workgroup>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_NV_cooperative_matrix]>,
+    Capability<[SPV_C_CooperativeMatrixNV]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyPtr:$pointer,
+    SPV_Integer:$stride,
+    SPV_Bool:$columnmajor,
+    OptionalAttr<SPV_MemoryAccessAttr>:$memory_access
+  );
+
+  let results = (outs
+    SPV_AnyCooperativeMatrix:$result
+  );
+
+  let verifier = [{ return success(); }];
+}
+
+// -----
+
+#endif // SPIRV_COOPERATIVE_MATRIX_OPS

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index 518dca69873d..520ed14c9624 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -28,6 +28,7 @@ include "mlir/Dialect/SPIRV/SPIRVBitOps.td"
 include "mlir/Dialect/SPIRV/SPIRVCastOps.td"
 include "mlir/Dialect/SPIRV/SPIRVCompositeOps.td"
 include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td"
+include "mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td"
 include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
 include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 3b5a82d239b9..078fb5a67225 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -54,6 +54,7 @@ SmallVector<Capability, 0> getRecursiveImpliedCapabilities(Capability cap);
 
 namespace detail {
 struct ArrayTypeStorage;
+struct CooperativeMatrixTypeStorage;
 struct ImageTypeStorage;
 struct PointerTypeStorage;
 struct RuntimeArrayTypeStorage;
@@ -63,6 +64,7 @@ struct StructTypeStorage;
 namespace TypeKind {
 enum Kind {
   Array = Type::FIRST_SPIRV_TYPE,
+  CooperativeMatrix,
   Image,
   Pointer,
   RuntimeArray,
@@ -330,6 +332,34 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
                        Optional<spirv::StorageClass> storage = llvm::None);
 };
 
+/// SPIR-V cooperative matrix type from the SPV_NV_cooperative_matrix extension.
+class CooperativeMatrixNVType
+    : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
+                            detail::CooperativeMatrixTypeStorage> {
+public:
+  using Base::Base;
+
+  static bool kindof(unsigned kind) {
+    return kind == TypeKind::CooperativeMatrix;
+  }
+
+  static CooperativeMatrixNVType get(Type elementType, spirv::Scope scope,
+                                     unsigned rows, unsigned columns);
+  /// Returns the type of the matrix elements.
+  Type getElementType() const;
+  /// Returns the scope of the cooperative matrix.
+  spirv::Scope getScope() const;
+  /// Returns the number of rows of the matrix.
+  unsigned getRows() const;
+  /// Returns the number of columns of the matrix.
+  unsigned getColumns() const;
+  // Append the requirements of this type and of its element type.
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     Optional<spirv::StorageClass> storage = llvm::None);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       Optional<spirv::StorageClass> storage = llvm::None);
+};
+
 } // end namespace spirv
 } // end namespace mlir
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index c74698a93bfc..8c4d0ebe99a7 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/ParserUtils.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
@@ -115,7 +116,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
 
 SPIRVDialect::SPIRVDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
-  addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
+  addTypes<ArrayType, CooperativeMatrixNVType, ImageType, PointerType,
+           RuntimeArrayType, StructType>();
 
   addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
 
@@ -264,6 +266,36 @@ static Type parseArrayType(SPIRVDialect const &dialect,
   return ArrayType::get(elementType, count, stride);
 }
 
+// cooperative-matrix-type ::= `!spv.coopmatrix` `<` rows `x` columns `x`
+//                             element-type `,` scope `>`
+static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
+                                       DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return Type();
+
+  SmallVector<int64_t, 2> dims;
+  llvm::SMLoc dimsLoc = parser.getCurrentLocation();
+  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
+    return Type();
+
+  if (dims.size() != 2) {
+    parser.emitError(dimsLoc, "expected rows and columns size.");
+    return Type();
+  }
+
+  auto elementTy = parseAndVerifyType(dialect, parser);
+  if (!elementTy)
+    return Type();
+
+  Scope scope;
+  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+    return Type();
+
+  if (parser.parseGreater())
+    return Type();
+  return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
+}
+
 // TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
 // methods in alphabetical order
 //
@@ -525,6 +557,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
 
   if (keyword == "array")
     return parseArrayType(*this, parser);
+  if (keyword == "coopmatrix")
+    return parseCooperativeMatrixType(*this, parser);
   if (keyword == "image")
     return parseImageType(*this, parser);
   if (keyword == "ptr")
@@ -595,11 +629,20 @@ static void print(StructType type, DialectAsmPrinter &os) {
   os << ">";
 }
 
+/// Prints `type` as `coopmatrix<ROWSxCOLUMNSxELEMENT-TYPE, SCOPE>`.
+static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
+  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
+     << type.getElementType() << ", " << stringifyScope(type.getScope()) << ">";
+}
+
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   switch (type.getKind()) {
   case TypeKind::Array:
     print(type.cast<ArrayType>(), os);
     return;
+  case TypeKind::CooperativeMatrix:
+    print(type.cast<CooperativeMatrixNVType>(), os);
+    return;
   case TypeKind::Pointer:
     print(type.cast<PointerType>(), os);
     return;

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index e7bdfe902804..eed597b1d21c 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 
+#include "mlir/Dialect/SPIRV/ParserUtils.h"
 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
@@ -140,25 +141,6 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
   return success();
 }
 
-/// Parses the next keyword in `parser` as an enumerant of the given
-/// `EnumClass`.
-template <typename EnumClass>
-static ParseResult
-parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
-                     StringRef attrName = spirv::attributeName<EnumClass>()) {
-  StringRef keyword;
-  SmallVector<NamedAttribute, 1> attr;
-  auto loc = parser.getCurrentLocation();
-  if (parser.parseKeyword(&keyword))
-    return failure();
-  if (Optional<EnumClass> attr = spirv::symbolizeEnum<EnumClass>(keyword)) {
-    value = attr.getValue();
-    return success();
-  }
-  return parser.emitError(loc, "invalid ")
-         << attrName << " attribute specification: " << keyword;
-}
-
 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
 /// and inserts the enumerant into `state` as an 32-bit integer attribute with
 /// the enum class's name as attribute name.
@@ -2637,6 +2619,68 @@ static LogicalResult verify(spirv::VariableOp varOp) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.CooperativeMatrixLoadNV
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom form of spv.CooperativeMatrixLoadNV:
+///
+///   spv.CooperativeMatrixLoadNV "StorageClass" %ptr, %stride, %colMajor
+///       (["MemoryAccess"])? : !spv.coopmatrix<...>
+///
+/// The pointer type is not spelled out; it is rebuilt from the storage class
+/// and the element type of the result matrix. The stride is resolved as i32
+/// and the column-major flag as i1.
+static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
+                                                  OperationState &state) {
+  spirv::StorageClass storageClass;
+  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
+  Type strideType = parser.getBuilder().getIntegerType(32);
+  Type columnMajorType = parser.getBuilder().getIntegerType(1);
+  Type resultType;
+  llvm::SMLoc typeLoc;
+  if (parseEnumStrAttr(storageClass, parser) ||
+      parser.parseOperandList(operandInfo, 3) ||
+      parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
+      parser.getCurrentLocation(&typeLoc) || parser.parseType(resultType)) {
+    return failure();
+  }
+
+  // The result type comes from the input text: check it rather than cast it
+  // unconditionally, which would assert on e.g. `: i32`.
+  auto matrixType = resultType.dyn_cast<spirv::CooperativeMatrixNVType>();
+  if (!matrixType) {
+    return parser.emitError(typeLoc, "expected cooperative matrix type, got ")
+           << resultType;
+  }
+
+  auto ptrType =
+      spirv::PointerType::get(matrixType.getElementType(), storageClass);
+  SmallVector<Type, 3> operandTypes = {ptrType, strideType, columnMajorType};
+  if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
+                             state.operands)) {
+    return failure();
+  }
+
+  state.addTypes(resultType);
+  return success();
+}
+
+/// Prints spv.CooperativeMatrixLoadNV in the custom form accepted by
+/// parseCooperativeMatrixLoadNVOp; the storage class is taken from the type
+/// of the pointer operand.
+static void print(spirv::CooperativeMatrixLoadNVOp op, OpAsmPrinter &printer) {
+  StringRef sc = stringifyStorageClass(
+      op.pointer().getType().cast<spirv::PointerType>().getStorageClass());
+  printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " \"" << sc
+          << "\" " << op.pointer() << ", " << op.stride() << ", "
+          << op.columnmajor();
+  // Print optional memory access attribute.
+  if (auto memAccess = op.memory_access())
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
+  printer << " : " << op.getType();
+}
+
 namespace mlir {
 namespace spirv {
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 71ca0c3d2bc7..ce5a6c0c4fd9 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -158,6 +158,7 @@ void ArrayType::getCapabilities(
 bool CompositeType::classof(Type type) {
   switch (type.getKind()) {
   case TypeKind::Array:
+  case TypeKind::CooperativeMatrix:
   case TypeKind::RuntimeArray:
   case TypeKind::Struct:
     return true;
@@ -177,6 +178,8 @@ Type CompositeType::getElementType(unsigned index) const {
   switch (getKind()) {
   case spirv::TypeKind::Array:
     return cast<ArrayType>().getElementType();
+  case spirv::TypeKind::CooperativeMatrix:
+    return cast<CooperativeMatrixNVType>().getElementType();
   case spirv::TypeKind::RuntimeArray:
     return cast<RuntimeArrayType>().getElementType();
   case spirv::TypeKind::Struct:
@@ -192,6 +195,9 @@ unsigned CompositeType::getNumElements() const {
   switch (getKind()) {
   case spirv::TypeKind::Array:
     return cast<ArrayType>().getNumElements();
+  case spirv::TypeKind::CooperativeMatrix:
+    return cast<CooperativeMatrixNVType>().getRows() *
+           cast<CooperativeMatrixNVType>().getColumns();
   case spirv::TypeKind::RuntimeArray:
     llvm_unreachable(
         "invalid to query number of elements of spirv::RuntimeArray type");
@@ -211,6 +217,9 @@ void CompositeType::getExtensions(
   case spirv::TypeKind::Array:
     cast<ArrayType>().getExtensions(extensions, storage);
     break;
+  case spirv::TypeKind::CooperativeMatrix:
+    cast<CooperativeMatrixNVType>().getExtensions(extensions, storage);
+    break;
   case spirv::TypeKind::RuntimeArray:
     cast<RuntimeArrayType>().getExtensions(extensions, storage);
     break;
@@ -233,6 +242,9 @@ void CompositeType::getCapabilities(
   case spirv::TypeKind::Array:
     cast<ArrayType>().getCapabilities(capabilities, storage);
     break;
+  case spirv::TypeKind::CooperativeMatrix:
+    cast<CooperativeMatrixNVType>().getCapabilities(capabilities, storage);
+    break;
   case spirv::TypeKind::RuntimeArray:
     cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
     break;
@@ -248,6 +260,78 @@ void CompositeType::getCapabilities(
   }
 }
 
+//===----------------------------------------------------------------------===//
+// CooperativeMatrixType
+//===----------------------------------------------------------------------===//
+
+/// Uniqued storage for CooperativeMatrixNVType. The scope is packed into the
+/// TypeStorage subclass data; element type, rows and columns are fields.
+struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
+  using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
+
+  static CooperativeMatrixTypeStorage *
+  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
+    return new (allocator.allocate<CooperativeMatrixTypeStorage>())
+        CooperativeMatrixTypeStorage(key);
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(elementType, getScope(), rows, columns);
+  }
+
+  CooperativeMatrixTypeStorage(const KeyTy &key)
+      : TypeStorage(static_cast<unsigned>(std::get<1>(key))),
+        elementType(std::get<0>(key)), rows(std::get<2>(key)),
+        columns(std::get<3>(key)) {}
+
+  Scope getScope() const { return static_cast<Scope>(getSubclassData()); }
+
+  Type elementType;
+  unsigned rows;
+  unsigned columns;
+};
+
+CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
+                                                     Scope scope, unsigned rows,
+                                                     unsigned columns) {
+  return Base::get(elementType.getContext(), TypeKind::CooperativeMatrix,
+                   elementType, scope, rows, columns);
+}
+
+Type CooperativeMatrixNVType::getElementType() const {
+  return getImpl()->elementType;
+}
+
+Scope CooperativeMatrixNVType::getScope() const {
+  return getImpl()->getScope();
+}
+
+unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
+
+unsigned CooperativeMatrixNVType::getColumns() const {
+  return getImpl()->columns;
+}
+
+void CooperativeMatrixNVType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    Optional<StorageClass> storage) {
+  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+  // The vector holds ArrayRefs, so the referenced extension list must outlive
+  // this call. Pushing the enumerator directly would build an ArrayRef to a
+  // temporary that dangles as soon as that statement ends.
+  static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
+  extensions.push_back(ArrayRef<Extension>(exts));
+}
+
+void CooperativeMatrixNVType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+  // Static storage so the pushed ArrayRef does not dangle (see getExtensions).
+  static const Capability caps[] = {Capability::CooperativeMatrixNV};
+  capabilities.push_back(ArrayRef<Capability>(caps));
+}
+
 //===----------------------------------------------------------------------===//
 // ImageType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index a45780ba63a0..87f233580b75 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -217,6 +217,8 @@ class Deserializer {
 
   LogicalResult processArrayType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processCooperativeMatrixType(ArrayRef<uint32_t> operands);
+
   LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
 
   LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
@@ -1160,6 +1162,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
   } break;
   case spirv::Opcode::OpTypeArray:
     return processArrayType(operands);
+  case spirv::Opcode::OpTypeCooperativeMatrixNV:
+    return processCooperativeMatrixType(operands);
   case spirv::Opcode::OpTypeFunction:
     return processFunctionType(operands);
   case spirv::Opcode::OpTypeRuntimeArray:
@@ -1229,6 +1233,35 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult
+Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 5) {
+    return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
+                                 "type and row x column parameters");
+  }
+  // Expected operands: result <id>, element type <id>, scope, rows, columns.
+  Type elementTy = getType(operands[1]);
+  if (!elementTy) {
+    return emitError(unknownLoc,
+                     "OpTypeCooperativeMatrix references undefined <id> ")
+           << operands[1];
+  }
+  // NOTE(review): scope is decoded as a literal; the NV spec may use an <id>.
+  auto scope = spirv::symbolizeScope(operands[2]);
+  if (!scope) {
+    return emitError(unknownLoc,
+                     "OpTypeCooperativeMatrix references undefined scope <id> ")
+           << operands[2];
+  }
+  // NOTE(review): rows/columns are literals here; the NV spec may use <id>s.
+  unsigned rows = operands[3];
+  unsigned columns = operands[4];
+  // Record the type under its result <id>.
+  typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
+      elementTy, scope.getValue(), rows, columns);
+  return success();
+}
+
 LogicalResult
 Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
@@ -2210,6 +2243,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
   case spirv::Opcode::OpTypeRuntimeArray:
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
+  case spirv::Opcode::OpTypeCooperativeMatrixNV:
     return processType(opcode, operands);
   case spirv::Opcode::OpConstant:
     return processConstant(operands, /*isSpec=*/false);

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 2b500ddbf985..8ea0c4f4711b 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1096,6 +1096,21 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
     return success();
   }
 
+  if (auto cooperativeMatrixType =
+          type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
+    uint32_t elementTypeID = 0;
+    if (failed(processType(loc, cooperativeMatrixType.getElementType(),
+                           elementTypeID))) {
+      return failure();
+    }
+    typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
+    operands.push_back(elementTypeID);
+    operands.push_back(static_cast<uint32_t>(cooperativeMatrixType.getScope()));
+    operands.push_back(cooperativeMatrixType.getRows());
+    operands.push_back(cooperativeMatrixType.getColumns());
+    return success();
+  }
+
   // TODO(ravishankarm) : Handle other types.
   return emitError(loc, "unhandled type in serialization: ") << type;
 }

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
new file mode 100644
index 000000000000..e90996ee24b7
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+// Round-trip (serialize, then deserialize) tests for cooperative matrix ops.
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
+  // CHECK-LABEL: @cooperative_matrix_load
+  spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+    // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
+    %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_load_memaccess
+  spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+    // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+    %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+    spv.Return
+  }
+}

diff  --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
new file mode 100644
index 000000000000..c121943acf82
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
+// Parse/print round-trip tests for spv.CooperativeMatrixLoadNV.
+// CHECK-LABEL: @cooperative_matrix_load
+spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
+  %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+  spv.Return
+}
+
+// -----
+// CHECK-LABEL: @cooperative_matrix_load_memaccess
+spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}

diff  --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir
index 4c1adafce4a8..697177b0b98e 100644
--- a/mlir/test/Dialect/SPIRV/types.mlir
+++ b/mlir/test/Dialect/SPIRV/types.mlir
@@ -327,3 +327,23 @@ func @struct_type_missing_comma(!spv.struct<f32 [0 NonWritable], i32 [4]>)
 
 // expected-error @+1 {{expected ']'}}
 func @struct_type_missing_comma(!spv.struct<f32 [0, NonWritable NonReadable], i32 [4]>)
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// CooperativeMatrix
+//===----------------------------------------------------------------------===//
+// Syntax: !spv.coopmatrix<rows x columns x element-type, scope>
+// CHECK: func @coop_matrix_type(!spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xf32, Workgroup>)
+func @coop_matrix_type(!spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xf32, Workgroup>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected ','}}
+func @missing_scope(!spv.coopmatrix<8x16xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected rows and columns size}}
+func @missing_count(!spv.coopmatrix<8xi32, Subgroup>) -> ()
+


        


More information about the Mlir-commits mailing list