[Mlir-commits] [mlir] 15389cd - [mlir][spirv] Add remaining cooperative matrix instructions

Thomas Raoux llvmlistbot at llvm.org
Thu May 21 11:55:57 PDT 2020


Author: Thomas Raoux
Date: 2020-05-21T11:55:33-07:00
New Revision: 15389cdc5b721fc1e10dc5818390fa3fa939a92e

URL: https://github.com/llvm/llvm-project/commit/15389cdc5b721fc1e10dc5818390fa3fa939a92e
DIFF: https://github.com/llvm/llvm-project/commit/15389cdc5b721fc1e10dc5818390fa3fa939a92e.diff

LOG: [mlir][spirv] Add remaining cooperative matrix instructions

Adds support for cooperative matrix support for arithmetic and cast
instructions. It also adds cooperative matrix store, muladd and matrixlength
instructions which are part of the extension.

Differential Revision: https://reviews.llvm.org/D80181

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
    mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
    mlir/test/Dialect/SPIRV/ops.mlir
    mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
index 11d0cdf23d79..350e3659a28d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
@@ -22,7 +22,18 @@ class SPV_ArithmeticBinaryOp<string mnemonic, Type type,
       // Operands type same as result type.
       SPV_BinaryOp<mnemonic, type, type,
                    !listconcat(traits,
-                               [NoSideEffect, SameOperandsAndResultType])>;
+                               [NoSideEffect, SameOperandsAndResultType])> {
+  // In addition to normal types arithmetic instructions can support cooperative
+  // matrix.
+  let arguments = (ins
+    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
+    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
+  );
+
+  let results = (outs
+    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
+  );
+}
 
 class SPV_ArithmeticUnaryOp<string mnemonic, Type type,
                             list<OpTrait> traits = []> :

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index b958a10c5952..a3a2c2bec43b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3004,6 +3004,7 @@ def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
 def SPV_Void : TypeAlias<NoneType, "void">;
 def SPV_Bool : TypeAlias<I1, "bool">;
 def SPV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPV_Int32 : TypeAlias<I32, "Int32">;
 def SPV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPV_Float16or32 : FloatOfWidths<[16, 32]>;
 def SPV_Vector : VectorOfLengthAndType<[2, 3, 4],
@@ -3034,9 +3035,18 @@ def SPV_Type : AnyTypeOf<[
 
 def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
 
+class SPV_CoopMatrixOfType<list<Type> allowedTypes> :
+  ContainerType<AnyTypeOf<allowedTypes>, SPV_IsCooperativeMatrixType,
+    "$_self.cast<::mlir::spirv::CooperativeMatrixNVType>().getElementType()",
+    "Cooperative Matrix">;
+
 class SPV_ScalarOrVectorOf<Type type> :
     AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>;
 
+class SPV_ScalarOrVectorOrCoopMatrixOf<Type type> :
+    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>,
+               SPV_CoopMatrixOfType<[type]>]>;
+
 def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
 def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
 
@@ -3227,6 +3237,9 @@ def SPV_OC_OpGroupNonUniformFMax       : I32EnumAttrCase<"OpGroupNonUniformFMax"
 def SPV_OC_OpSubgroupBallotKHR         : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 def SPV_OC_OpTypeCooperativeMatrixNV   : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
 def SPV_OC_OpCooperativeMatrixLoadNV   : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
+def SPV_OC_OpCooperativeMatrixStoreNV  : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
+def SPV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>;
+def SPV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>;
 
 def SPV_OpcodeAttr :
     SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -3279,7 +3292,9 @@ def SPV_OpcodeAttr :
       SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
       SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
       SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
-      SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV
+      SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
+      SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
+      SPV_OC_OpCooperativeMatrixLengthNV
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td
index 8a64e48d56c0..c67c8d5e4542 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td
@@ -23,11 +23,11 @@ class SPV_CastOp<string mnemonic, Type resultType, Type operandType,
              !listconcat(traits,
                          [NoSideEffect, SameOperandsAndResultShape])> {
   let arguments = (ins
-    SPV_ScalarOrVectorOf<operandType>:$operand
+    SPV_ScalarOrVectorOrCoopMatrixOf<operandType>:$operand
   );
 
   let results = (outs
-    SPV_ScalarOrVectorOf<resultType>:$result
+    SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
   );
 
   let parser = [{ return mlir::impl::parseCastOp(parser, result); }];

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
index 931f56f58755..4645765b66ba 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
@@ -15,6 +15,49 @@
 
 // -----
 
+def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
+  [NoSideEffect]> {
+  let summary = "See extension SPV_NV_cooperative_matrix";
+
+  let description = [{
+    Number of components of a cooperative matrix type accessible to each
+    invocation when treated as a composite.
+
+    Result Type must be an OpTypeInt with 32-bit Width and 0 Signedness.
+
+    Type is a cooperative matrix type.
+
+    ``` {.ebnf}
+    cooperative-matrix-length-op ::= ssa-id `=` `spv.CooperativeMatrixLengthNV
+                                    ` : ` cooperative-matrix-type
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<Subgroup, i32, 8, 16>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_NV_cooperative_matrix]>,
+    Capability<[SPV_C_CooperativeMatrixNV]>
+  ];
+
+  let arguments = (ins
+    TypeAttr:$type
+  );
+
+  let results = (outs
+    SPV_Int32:$result
+  );
+  let verifier = [{ return success(); }];
+}
+
+// -----
+
 def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
   let summary = "See extension SPV_NV_cooperative_matrix";
 
@@ -55,9 +98,10 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
     ### Custom assembly form
 
     ``` {.ebnf}
-    cooperative-matrix-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV`
-                              storage-class ssa-use (`[` memory-access `]`)? `
-                              : ` cooperative-matrix-type
+    cooperative-matrixload-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV`
+                              storage-class ssa-use `,` ssa-use `,` ssa-use
+                              (`[` memory-access `]`)? ` : `
+                              cooperative-matrix-type
     ```
 
     For example:
@@ -86,7 +130,147 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
     SPV_AnyCooperativeMatrix:$result
   );
 
-  let verifier = [{ return success(); }];
+  let verifier = [{
+    return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
+                                          result().getType());
+  }];
+}
+
+// -----
+
+def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
+  [NoSideEffect]> {
+  let summary = "See extension SPV_NV_cooperative_matrix";
+
+  let description = [{
+    Linear-algebraic matrix multiply of A by B and then component-wise add C.
+    The order of the operations is implementation-dependent. The internal
+    precision of floating-point operations is defined by the client API.
+    Integer operations are performed at the precision of the Result Type and are
+    exact unless there is overflow or underflow, in which case the result is
+    undefined.
+
+    Result Type must be a cooperative matrix type with M rows and N columns.
+
+    A is a cooperative matrix with M rows and K columns.
+
+    B is a cooperative matrix with K rows and N columns.
+
+    C is a cooperative matrix with M rows and N columns.
+
+    The values of M, N, and K must be consistent across the result and operands.
+    This is referred to as an MxNxK matrix multiply.
+
+    A, B, C, and Result Type must have the same scope, and this defines the
+    scope of the operation. A, B, C, and Result Type need not necessarily have
+    the same component type, this is defined by the client API.
+
+    If the Component Type of any matrix operand is an integer type, then its
+    components are treated as signed if its Component Type has Signedness of 1
+    and are treated as unsigned otherwise.
+
+    For a given dynamic instance of this instruction, all invocations in a given
+    scope instance must be active or all must be inactive (where the scope is
+    the scope of the operation).
+
+    ``` {.ebnf}
+    cooperative-matrixmuladd-op ::= ssa-id `=` `spv.CooperativeMatrixMulAddNV`
+                              ssa-use `,` ssa-use `,` ssa-use ` : `
+                              a-cooperative-matrix-type,
+                              b-cooperative-matrix-type ->
+                              result-cooperative-matrix-type
+    ```
+    For example:
+
+    ```
+    %0 = spv.CooperativeMatrixMulAddNV %arg0, %arg1, %arg2,  :
+      !spv.coopmatrix<Subgroup, i32, 8, 16>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_NV_cooperative_matrix]>,
+    Capability<[SPV_C_CooperativeMatrixNV]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyCooperativeMatrix:$a,
+    SPV_AnyCooperativeMatrix:$b,
+    SPV_AnyCooperativeMatrix:$c
+  );
+
+  let results = (outs
+    SPV_AnyCooperativeMatrix:$result
+  );
+
+  let verifier = [{ return verifyCoopMatrixMulAdd(*this); }];
+}
+
+// -----
+
+def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> {
+  let summary = "See extension SPV_NV_cooperative_matrix";
+
+  let description = [{
+    Store a cooperative matrix through a pointer.
+
+    Pointer is a pointer into an array. Its type must be an OpTypePointer whose
+    Type operand is a scalar or vector type. The storage class of Pointer must
+    be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
+    supported) PhysicalStorageBufferEXT.
+
+    Object is the object to store. Its type must be an
+    OpTypeCooperativeMatrixNV.
+
+    Stride is the number of elements in the array in memory between the first
+    component of consecutive rows (or columns) in the result. It must be a
+    scalar integer type.
+
+    ColumnMajor indicates whether the values stored to memory are arranged in
+    column-major or row-major order. It must be a boolean constant instruction,
+    with false indicating row major and true indicating column major.
+
+    Memory Access must be a Memory Access literal. If not present, it is the
+    same as specifying None.
+
+    ``` {.ebnf}
+    coop-matrix-store-op ::= `spv.CooperativeMatrixStoreNV `
+                              storage-class ssa-use `, ` ssa-use `, `
+                              ssa-use `, ` ssa-use `, `
+                  (`[` memory-access `]`)? `:` spirv-element-type
+    ```
+
+    For example:
+
+    ```
+      spv.CooperativeMatrixStoreNV "StorageBuffer" %arg0, %arg2, %arg1, %arg3 :
+        !spv.coopmatrix<Workgroup, i32, 16, 8>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_NV_cooperative_matrix]>,
+    Capability<[SPV_C_CooperativeMatrixNV]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyPtr:$pointer,
+    SPV_AnyCooperativeMatrix:$object,
+    SPV_Integer:$stride,
+    SPV_Bool:$columnmajor,
+    OptionalAttr<SPV_MemoryAccessAttr>:$memory_access
+  );
+
+  let results = (outs);
+
+  let verifier = [{
+    return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
+                                          object().getType());
+  }];
 }
 
 // -----

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index eed597b1d21c..630f09842ccd 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -213,6 +213,13 @@ static LogicalResult verifyCastOp(Operation *op,
     resultType = resultType.cast<VectorType>().getElementType();
   }
 
+  if (auto coopMatrixType =
+          operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
+    operandType = coopMatrixType.getElementType();
+    resultType =
+        resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
+  }
+
   auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
   auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
   auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
@@ -2662,6 +2669,138 @@ static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
   printer << " : " << M.getType();
 }
 
+static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
+                                                    Type coopMatrix) {
+  if (pointer.cast<spirv::PointerType>().getPointeeType() !=
+      coopMatrix.cast<spirv::CooperativeMatrixNVType>().getElementType())
+    return op->emitError(
+               "expected the same type for pointer and the cooperative matrix"
+               "element, bu provided ")
+           << pointer << " and " << coopMatrix;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.CooperativeMatrixStoreNV
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
+                                                   OperationState &state) {
+  spirv::StorageClass storageClass;
+  SmallVector<OpAsmParser::OperandType, 4> operandInfo;
+  Type strideType = parser.getBuilder().getIntegerType(32);
+  Type columnMajorType = parser.getBuilder().getIntegerType(1);
+  Type elementType;
+  if (parseEnumStrAttr(storageClass, parser) ||
+      parser.parseOperandList(operandInfo, 4) ||
+      parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
+      parser.parseType(elementType)) {
+    return failure();
+  }
+
+  auto ptrType = spirv::PointerType::get(
+      elementType.cast<spirv::CooperativeMatrixNVType>().getElementType(),
+      storageClass);
+  SmallVector<Type, 4> OperandType = {ptrType, elementType, strideType,
+                                      columnMajorType};
+  if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(),
+                             state.operands)) {
+    return failure();
+  }
+
+  return success();
+}
+
+static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
+                  OpAsmPrinter &printer) {
+  StringRef sc = stringifyStorageClass(coopMatrix.pointer()
+                                           .getType()
+                                           .cast<spirv::PointerType>()
+                                           .getStorageClass());
+  printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " \""
+          << sc << "\" " << coopMatrix.pointer() << ", " << coopMatrix.object()
+          << ", " << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
+  // Print optional memory access attribute.
+  if (auto memAccess = coopMatrix.memory_access())
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
+  printer << " : " << coopMatrix.getOperand(1).getType();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.CooperativeMatrixLengthNV
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseCooperativeMatrixLengthNVOp(OpAsmParser &parser,
+                                                    OperationState &state) {
+  OpAsmParser::OperandType operandInfo;
+  Type dstType = parser.getBuilder().getIntegerType(32);
+  Type type;
+  if (parser.parseColonType(type)) {
+    return failure();
+  }
+  state.addAttribute(kTypeAttrName, TypeAttr::get(type));
+  state.addTypes(dstType);
+  return success();
+}
+
+static void print(spirv::CooperativeMatrixLengthNVOp coopMatrix,
+                  OpAsmPrinter &printer) {
+  printer << coopMatrix.getOperationName() << " : " << coopMatrix.type();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.CooperativeMatrixMulAddNV
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseCooperativeMatrixMulAddNVOp(OpAsmParser &parser,
+                                                    OperationState &state) {
+  SmallVector<OpAsmParser::OperandType, 3> ops;
+  SmallVector<Type, 3> types(3);
+  if (parser.parseOperandList(ops, 3) || parser.parseColon() ||
+      parser.parseType(types[0]) || parser.parseComma() ||
+      parser.parseType(types[1]) || parser.parseArrow() ||
+      parser.parseType(types[2]) ||
+      parser.resolveOperands(ops, types, parser.getNameLoc(), state.operands)) {
+    return failure();
+  }
+  state.addTypes(types[2]);
+  return success();
+}
+
+static void print(spirv::CooperativeMatrixMulAddNVOp coopMatrix,
+                  OpAsmPrinter &printer) {
+  printer << coopMatrix.getOperationName() << ' ' << coopMatrix.getOperand(0)
+          << ", " << coopMatrix.getOperand(1) << ", "
+          << coopMatrix.getOperand(2) << ", "
+          << " : " << coopMatrix.getOperand(0).getType() << ", "
+          << coopMatrix.getOperand(1).getType() << " -> "
+          << coopMatrix.getOperand(2).getType();
+}
+
+static LogicalResult
+verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
+  if (op.c().getType() != op.result().getType())
+    return op.emitOpError(
+        "result and third operand must have the same type");
+  auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
+  auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
+  auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
+  auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
+  if (typeA.getRows() != typeR.getRows() ||
+      typeA.getColumns() != typeB.getRows() ||
+      typeB.getColumns() != typeR.getColumns())
+    return op.emitOpError("matrix size must match");
+  if (typeR.getScope() != typeA.getScope() ||
+      typeR.getScope() != typeB.getScope() ||
+      typeR.getScope() != typeC.getScope())
+    return op.emitOpError("matrix scope must match");
+  if (typeR.getElementType() != typeA.getElementType() ||
+      typeR.getElementType() != typeB.getElementType() ||
+      typeR.getElementType() != typeC.getElementType())
+    return op.emitOpError("matrix element type must match");
+  return success();
+}
+
 namespace mlir {
 namespace spirv {
 

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
index e90996ee24b7..6fb58d859d1f 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
@@ -14,4 +14,81 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_N
     %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
     spv.Return
   }
+
+  // CHECK-LABEL: @cooperative_matrix_store
+  spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
+    // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
+    spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_store_memaccess
+  spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
+    // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+    spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_length
+  spv.func @cooperative_matrix_length() -> i32 "None" {
+    // CHECK: {{%.*}} = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup>
+    %0 = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup>
+    spv.ReturnValue %0 : i32
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_muladd
+  spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}},  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+    %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_add
+  spv.func @cooperative_matrix_add(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>
+    %r = spv.IAdd %a, %b : !spv.coopmatrix<8x16xi32, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_sub
+  spv.func @cooperative_matrix_sub(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>
+    %r = spv.ISub %a, %b : !spv.coopmatrix<8x16xi32, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_sdiv
+  spv.func @cooperative_matrix_sdiv(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>
+    %r = spv.SDiv %a, %b : !spv.coopmatrix<8x16xi32, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_udiv
+  spv.func @cooperative_matrix_udiv(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>
+    %r = spv.UDiv %a, %b : !spv.coopmatrix<8x16xi32, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_fadd
+  spv.func @cooperative_matrix_fadd(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
+    %r = spv.FAdd %a, %b : !spv.coopmatrix<8x16xf32, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_fsub
+  spv.func @cooperative_matrix_fsub(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
+    %r = spv.FSub %a, %b : !spv.coopmatrix<8x16xf32, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_fdiv
+  spv.func @cooperative_matrix_fdiv(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
+    %r = spv.FDiv %a, %b : !spv.coopmatrix<8x16xf32, Subgroup>
+    spv.Return
+  }
 }

diff  --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
index c121943acf82..0b05d8a587e5 100644
--- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -14,3 +14,113 @@ spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>,
   %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
   spv.Return
 }
+
+// CHECK-LABEL: @cooperative_matrix_store
+spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
+  // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Workgroup>
+  spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<8x16xi32, Workgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_store_memaccess
+spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
+  // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_length
+spv.func @cooperative_matrix_length() -> i32 "None" {
+  // CHECK: {{%.*}} = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: @cooperative_matrix_muladd
+spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}},  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+  %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_add
+spv.func @cooperative_matrix_add(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>
+  %r = spv.IAdd %a, %b : !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_sub
+spv.func @cooperative_matrix_sub(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>
+  %r = spv.ISub %a, %b : !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_sdiv
+spv.func @cooperative_matrix_sdiv(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>
+  %r = spv.SDiv %a, %b : !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_udiv
+spv.func @cooperative_matrix_udiv(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>
+  %r = spv.UDiv %a, %b : !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_fadd
+spv.func @cooperative_matrix_fadd(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
+  %r = spv.FAdd %a, %b : !spv.coopmatrix<8x16xf32, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_fsub
+spv.func @cooperative_matrix_fsub(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
+  %r = spv.FSub %a, %b : !spv.coopmatrix<8x16xf32, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_fdiv
+spv.func @cooperative_matrix_fdiv(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<8x16xf32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
+  %r = spv.FDiv %a, %b : !spv.coopmatrix<8x16xf32, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<16x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // expected-error @+1 {{'spv.CooperativeMatrixMulAddNV' op matrix size must match}}
+  %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<16x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // expected-error @+1 {{'spv.CooperativeMatrixMulAddNV' op matrix size must match}}
+  %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Workgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // expected-error @+1 {{'spv.CooperativeMatrixMulAddNV' op matrix scope must match}}
+  %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Workgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // expected-error @+1 {{matrix element type must match}}
+  %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xf32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+  spv.Return
+}
+

diff  --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir
index 5cf91c0b09b9..14e1fa10735e 100644
--- a/mlir/test/Dialect/SPIRV/ops.mlir
+++ b/mlir/test/Dialect/SPIRV/ops.mlir
@@ -328,6 +328,14 @@ func @convert_f_to_u_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {
 
 // -----
 
+func @convert_f_to_u_coopmatrix(%arg0 : !spv.coopmatrix<8x16xf32, Subgroup>) {
+  // CHECK: {{%.*}} = spv.ConvertFToU {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup> to !spv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spv.ConvertFToU %arg0 : !spv.coopmatrix<8x16xf32, Subgroup> to !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}
+
+// -----
+
 func @convert_f_to_u_scalar_invalid(%arg0 : f16) -> i32 {
   // expected-error @+1 {{expected the same bit widths for operand type and result type, but provided 'f16' and 'i32'}}
   %0 = spv.ConvertFToU %arg0 : f16 to i32
@@ -380,6 +388,14 @@ func @f_convert_vector(%arg0 : vector<3xf32>) -> vector<3xf64> {
 
 // -----
 
+func @f_convert_coop_matrix(%arg0 : !spv.coopmatrix<8x16xf32, Subgroup>) {
+  // CHECK: {{%.*}} = spv.FConvert {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup> to !spv.coopmatrix<8x16xf64, Subgroup>
+  %0 = spv.FConvert %arg0 : !spv.coopmatrix<8x16xf32, Subgroup> to !spv.coopmatrix<8x16xf64, Subgroup>
+  spv.Return
+}
+
+// -----
+
 func @f_convert_vector(%arg0 : f32) -> f32 {
   // expected-error @+1 {{expected the 
diff erent bit widths for operand type and result type, but provided 'f32' and 'f32'}}
   %0 = spv.FConvert %arg0 : f32 to f32

diff  --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 73418b831721..e860dbc3fb58 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -508,6 +508,11 @@ static void emitAttributeSerialization(const Attribute &attr,
        << formatv("  {0}.push_back(static_cast<uint32_t>("
                   "attr.cast<IntegerAttr>().getValue().getZExtValue()));\n",
                   operandList);
+  } else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
+    os << tabs
+       << formatv("  {0}.push_back(static_cast<uint32_t>("
+                  "getTypeID(attr.cast<TypeAttr>().getValue())));\n",
+                  operandList);
   } else {
     PrintFatalError(
         loc,
@@ -766,6 +771,11 @@ static void emitAttributeDeserialization(const Attribute &attr,
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
                   attrList, attrName, words, wordIndex);
+  } else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
+    os << tabs
+       << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
+                  "TypeAttr::get(getType({2}[{3}++]))));\n",
+                  attrList, attrName, words, wordIndex);
   } else {
     PrintFatalError(
         loc, llvm::Twine(


        


More information about the Mlir-commits mailing list