[Mlir-commits] [mlir] 4ba61f5 - [mlir][spirv] Add KHR Cooperative Matrix type and extension

Jakub Kuderski llvmlistbot at llvm.org
Wed Jul 12 18:13:08 PDT 2023


Author: Jakub Kuderski
Date: 2023-07-12T21:11:08-04:00
New Revision: 4ba61f5a30d2f24670c7d35ce0ad3f2572808240

URL: https://github.com/llvm/llvm-project/commit/4ba61f5a30d2f24670c7d35ce0ad3f2572808240
DIFF: https://github.com/llvm/llvm-project/commit/4ba61f5a30d2f24670c7d35ce0ad3f2572808240.diff

LOG: [mlir][spirv] Add KHR Cooperative Matrix type and extension

Start plumbing through support for the `SPV_KHR_cooperative_matrix`
extension: https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_cooperative_matrix.html.

Register the extension, add new coop matrix type, and add
`spirv.KHR.CooperativeMatrixLength` op to exercise it.

Make sure that mixing of the KHR and NV coop matrix extensions is not
allowed. Make cast verification more robust.

Reviewed By: antiagainst, qedawkins

Differential Revision: https://reviews.llvm.org/D154877

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
    mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
    mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
    mlir/test/Dialect/SPIRV/IR/types.mlir
    mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2b33327dec1fbc..885fb7138b1410 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -343,6 +343,7 @@ def SPV_KHR_uniform_group_instructions       : I32EnumAttrCase<"SPV_KHR_uniform_
 def SPV_KHR_subgroup_rotate                  : I32EnumAttrCase<"SPV_KHR_subgroup_rotate", 28>;
 def SPV_KHR_non_semantic_info                : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
 def SPV_KHR_terminate_invocation             : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
+def SPV_KHR_cooperative_matrix               : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
 
 def SPV_EXT_demote_to_helper_invocation  : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
 def SPV_EXT_descriptor_indexing          : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
@@ -435,6 +436,7 @@ def SPIRV_ExtensionAttr :
       SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
       SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
       SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
+      SPV_KHR_cooperative_matrix,
       SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
       SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
       SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
@@ -835,6 +837,12 @@ def SPIRV_C_RayCullMaskKHR                              : I32EnumAttrCase<"RayCu
     Extension<[SPV_KHR_ray_cull_mask]>
   ];
 }
+def SPIRV_C_CooperativeMatrixKHR                        : I32EnumAttrCase<"CooperativeMatrixKHR", 6022> {
+  list<Availability> availability = [
+    Extension<[SPV_KHR_cooperative_matrix]>,
+    MinVersion<SPIRV_V_1_6>
+  ];
+}
 def SPIRV_C_BitInstructions                             : I32EnumAttrCase<"BitInstructions", 6025> {
   list<Availability> availability = [
     Extension<[SPV_KHR_bit_instructions]>
@@ -1457,6 +1465,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
       SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
       SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
+      SPIRV_C_CooperativeMatrixKHR,
       SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
       SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
       SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
@@ -4017,7 +4026,7 @@ def SPIRV_ArrayedAttr : SPIRV_I32EnumAttr<
 def SPIRV_ISI_SingleSampled : I32EnumAttrCase<"SingleSampled", 0>;
 def SPIRV_ISI_MultiSampled  : I32EnumAttrCase<"MultiSampled", 1>;
 
-def SPIRV_SamplingAttr: SPIRV_I32EnumAttr<
+def SPIRV_SamplingAttr : SPIRV_I32EnumAttr<
   "ImageSamplingInfo", "valid SPIR-V Image Sampling specification",
   "image_sampling_info", [SPIRV_ISI_SingleSampled, SPIRV_ISI_MultiSampled]>;
 
@@ -4035,11 +4044,23 @@ def SPIRV_ML_RowMajor    : I32EnumAttrCase<"RowMajor", 1>;
 def SPIRV_ML_PackedA     : I32EnumAttrCase<"PackedA", 2>;
 def SPIRV_ML_PackedB     : I32EnumAttrCase<"PackedB", 3>;
 
-def SPIRV_MatrixLayoutAttr  :
+def SPIRV_MatrixLayoutAttr :
     SPIRV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [
       SPIRV_ML_ColumnMajor, SPIRV_ML_RowMajor, SPIRV_ML_PackedA, SPIRV_ML_PackedB
     ]>;
 
+// Cooperative Matrix Use for the SPV_KHR_cooperative_matrix extension.
+def SPIRV_KHR_CMU_MatrixA   : I32EnumAttrCase<"MatrixA", 0>;
+def SPIRV_KHR_CMU_MatrixB   : I32EnumAttrCase<"MatrixB", 1>;
+def SPIRV_KHR_CMU_MatrixAcc : I32EnumAttrCase<"MatrixAcc", 2>;
+
+def SPIRV_KHR_CooperativeMatrixUseAttr :
+    SPIRV_I32EnumAttr<"CooperativeMatrixUseKHR",
+                      "valid SPIR-V Cooperative Matrix Use (KHR)",
+                      "coop_matrix_use_khr", [
+      SPIRV_KHR_CMU_MatrixA, SPIRV_KHR_CMU_MatrixB, SPIRV_KHR_CMU_MatrixAcc
+    ]>;
+
 //===----------------------------------------------------------------------===//
 // SPIR-V attribute definitions
 //===----------------------------------------------------------------------===//
@@ -4069,6 +4090,8 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
               !interleave(widths, "/") # "-bit signless/unsigned integer">;
 
 def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
+def SPIRV_IsCooperativeMatrixType :
+  CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
 def SPIRV_IsCooperativeMatrixNVType :
   CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">;
 def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
@@ -4100,6 +4123,9 @@ def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
                              "any SPIR-V pointer type">;
 def SPIRV_AnyArray : DialectType<SPIRV_Dialect, SPIRV_IsArrayType,
                                "any SPIR-V array type">;
+def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
+                                   SPIRV_IsCooperativeMatrixType,
+                                  "any SPIR-V cooperative matrix type">;
 def SPIRV_AnyCooperativeMatrixNV : DialectType<SPIRV_Dialect,
                                      SPIRV_IsCooperativeMatrixNVType,
                                      "any SPIR-V NV cooperative matrix type">;
@@ -4121,17 +4147,23 @@ def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
 def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
 def SPIRV_Composite :
     AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-               SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
+               SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
+               SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
 def SPIRV_Type : AnyTypeOf<[
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-    SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
-    SPIRV_AnySampledImage
+    SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
+    SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
 def SPIRV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
 
+class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
+  ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixType,
+    "::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()",
+    "Cooperative Matrix">;
+
 class SPIRV_CoopMatrixNVOfType<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixNVType,
     "::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()",
@@ -4147,10 +4179,12 @@ class SPIRV_ScalarOrVectorOf<Type type> :
 
 class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
     AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
-               SPIRV_CoopMatrixNVOfType<[type]>]>;
+               SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
 
 class SPIRV_MatrixOrCoopMatrixOf<Type type> :
-    AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixNVOfType<[type]>]>;
+    AnyTypeOf<[SPIRV_AnyMatrix,
+               SPIRV_CoopMatrixOfType<[type]>,
+               SPIRV_CoopMatrixNVOfType<[type]>]>;
 
 def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
 def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
@@ -4400,6 +4434,8 @@ def SPIRV_OC_OpSUDot                      : I32EnumAttrCase<"OpSUDot", 4452>;
 def SPIRV_OC_OpSDotAccSat                 : I32EnumAttrCase<"OpSDotAccSat", 4453>;
 def SPIRV_OC_OpUDotAccSat                 : I32EnumAttrCase<"OpUDotAccSat", 4454>;
 def SPIRV_OC_OpSUDotAccSat                : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
+def SPIRV_OC_OpTypeCooperativeMatrixKHR   : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
+def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
 def SPIRV_OC_OpTypeCooperativeMatrixNV    : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
 def SPIRV_OC_OpCooperativeMatrixLoadNV    : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
 def SPIRV_OC_OpCooperativeMatrixStoreNV   : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
@@ -4498,7 +4534,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax,
       SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
-      SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixNV,
+      SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+      SPIRV_OC_OpTypeCooperativeMatrixNV,
       SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV,
       SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV,
       SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index fada7f64899ad0..e6e2bbb26d094e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -7,12 +7,62 @@
 //===----------------------------------------------------------------------===//
 //
 // This is the op definition spec of cooperative matrix multiply extension ops.
+// We support both cooperative matrix extensions:
+//  - SPV_NV_cooperative_matrix
+//  - SPV_KHR_cooperative_matrix
 //
 //===----------------------------------------------------------------------===//
 
 #ifndef MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
 #define MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
 
+//===----------------------------------------------------------------------===//
+// SPV_KHR_cooperative_matrix extension ops.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+def SPIRV_KHRCooperativeMatrixLengthOp :
+      SPIRV_KhrVendorOp<"CooperativeMatrixLength", [Pure]> {
+  let summary = "Queries the number of cooperative matrix components";
+
+  let description = [{
+    Number of components of a cooperative matrix type accessible to each
+    invocation when treated as a composite.
+
+    The type attribute must be a cooperative matrix type.
+
+    ``` {.ebnf}
+    cooperative-matrix-length-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixLength
+                                    ` : ` cooperative-matrix-type
+    ```
+
+    #### Example:
+
+    ```
+    %0 = spirv.KHR.CooperativeMatrixLength :
+           !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+    ```
+  }];
+
+  let assemblyFormat = "attr-dict `:` $cooperative_matrix_type";
+
+  let availability = [
+    MinVersion<SPIRV_V_1_6>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_KHR_cooperative_matrix]>,
+    Capability<[SPIRV_C_CooperativeMatrixKHR]>
+  ];
+
+  let arguments = (ins
+    TypeAttr:$cooperative_matrix_type
+  );
+
+  let results = (outs
+    SPIRV_Int32:$result
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // SPV_NV_cooperative_matrix extension ops.
 //===----------------------------------------------------------------------===//
@@ -59,7 +109,6 @@ def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLengt
   let results = (outs
     SPIRV_Int32:$result
   );
-  let hasVerifier = 0;
 }
 
 // -----

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index b5b1f5ad4f52f1..07f2f158ecabb6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/TypeSupport.h"
 #include "mlir/IR/Types.h"
 
+#include <cstdint>
 #include <tuple>
 
 namespace mlir {
@@ -27,6 +28,7 @@ namespace spirv {
 
 namespace detail {
 struct ArrayTypeStorage;
+struct CooperativeMatrixTypeStorage;
 struct CooperativeMatrixNVTypeStorage;
 struct ImageTypeStorage;
 struct JointMatrixTypeStorage;
@@ -398,6 +400,33 @@ class StructType
 llvm::hash_code
 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
 
+// SPIR-V KHR cooperative matrix type
+class CooperativeMatrixType
+    : public Type::TypeBase<CooperativeMatrixType, CompositeType,
+                            detail::CooperativeMatrixTypeStorage> {
+public:
+  using Base::Base;
+
+  static CooperativeMatrixType get(Type elementType, uint32_t rows,
+                                   uint32_t columns, Scope scope,
+                                   CooperativeMatrixUseKHR use);
+  Type getElementType() const;
+
+  /// Returns the scope of the matrix.
+  Scope getScope() const;
+  /// Returns the number of rows of the matrix.
+  uint32_t getRows() const;
+  /// Returns the number of columns of the matrix.
+  uint32_t getColumns() const;
+  /// Returns the use parameter of the cooperative matrix.
+  CooperativeMatrixUseKHR getUse() const;
+
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     std::optional<StorageClass> storage = std::nullopt);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       std::optional<StorageClass> storage = std::nullopt);
+};
+
 // SPIR-V NV cooperative matrix type
 class CooperativeMatrixNVType
     : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a0bf1300b183a3..2e1c7923e24126 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -319,9 +319,45 @@ static Type parseArrayType(SPIRVDialect const &dialect,
 }
 
 // cooperative-matrix-type ::=
-//   `!spirv.NV.coopmatrix` `<` rows `x` columns `x` element-type ',' scope>
+//   `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
+//                           scope `,` use `>`
 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
                                        DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return {};
+
+  SmallVector<int64_t, 2> dims;
+  SMLoc countLoc = parser.getCurrentLocation();
+  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
+    return {};
+
+  if (dims.size() != 2) {
+    parser.emitError(countLoc, "expected row and column count");
+    return {};
+  }
+
+  auto elementTy = parseAndVerifyType(dialect, parser);
+  if (!elementTy)
+    return {};
+
+  Scope scope;
+  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+    return {};
+
+  CooperativeMatrixUseKHR use;
+  if (parser.parseComma() || parseEnumKeywordAttr(use, parser, "use <id>"))
+    return {};
+
+  if (parser.parseGreater())
+    return {};
+
+  return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
+}
+
+// nv-cooperative-matrix-type ::=
+//   `!spirv.NV.coopmatrix` `<` rows `x` columns `x` element-type `,` scope `>`
+static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,
+                                         DialectAsmParser &parser) {
   if (parser.parseLess())
     return Type();
 
@@ -785,8 +821,10 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
 
   if (keyword == "array")
     return parseArrayType(*this, parser);
-  if (keyword == "NV.coopmatrix")
+  if (keyword == "coopmatrix")
     return parseCooperativeMatrixType(*this, parser);
+  if (keyword == "NV.coopmatrix")
+    return parseCooperativeMatrixNVType(*this, parser);
   if (keyword == "jointmatrix")
     return parseJointMatrixType(*this, parser);
   if (keyword == "image")
@@ -889,6 +927,12 @@ static void print(StructType type, DialectAsmPrinter &os) {
     structContext.remove(type.getIdentifier());
 }
 
+static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
+  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
+     << type.getElementType() << ", " << type.getScope() << ", "
+     << type.getUse() << ">";
+}
+
 static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
   os << "NV.coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
   os << type.getElementType() << ", " << stringifyScope(type.getScope());
@@ -909,9 +953,10 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
 
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
-      .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
-            PointerType, RuntimeArrayType, ImageType, SampledImageType,
-            StructType, MatrixType>([&](auto type) { print(type, os); })
+      .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
+            JointMatrixINTELType, PointerType, RuntimeArrayType, ImageType,
+            SampledImageType, StructType, MatrixType>(
+          [&](auto type) { print(type, os); })
       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
 }
 

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 6747f75f468a08..5c11fe8d0cffa6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -32,6 +32,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cassert>
 #include <numeric>
@@ -429,36 +430,38 @@ static LogicalResult verifyCastOp(Operation *op,
   Type operandType = op->getOperand(0).getType();
   Type resultType = op->getResult(0).getType();
 
-  // ODS checks that result type and operand type have the same shape.
-  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
-    operandType = vectorType.getElementType();
-    resultType = llvm::cast<VectorType>(resultType).getElementType();
-  }
-
-  if (auto coopMatrixType =
-          llvm::dyn_cast<spirv::CooperativeMatrixNVType>(operandType)) {
-    operandType = coopMatrixType.getElementType();
-    resultType =
-        llvm::cast<spirv::CooperativeMatrixNVType>(resultType).getElementType();
-  }
-
-  if (auto jointMatrixType =
-          llvm::dyn_cast<spirv::JointMatrixINTELType>(operandType)) {
-    operandType = jointMatrixType.getElementType();
-    resultType =
-        llvm::cast<spirv::JointMatrixINTELType>(resultType).getElementType();
-  }
-
-  auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
-  auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
-  auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
+  // ODS checks that result type and operand type have the same shape. Check
+  // that composite types match and extract the element types, if any.
+  using TypePair = std::pair<Type, Type>;
+  auto [operandElemTy, resultElemTy] =
+      TypeSwitch<Type, TypePair>(operandType)
+          .Case<VectorType, spirv::CooperativeMatrixType,
+                spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType>(
+              [resultType](auto concreteOperandTy) -> TypePair {
+                if (auto concreteResultTy =
+                        dyn_cast<decltype(concreteOperandTy)>(resultType)) {
+                  return {concreteOperandTy.getElementType(),
+                          concreteResultTy.getElementType()};
+                }
+                return {};
+              })
+          .Default([resultType](Type operandType) -> TypePair {
+            return {operandType, resultType};
+          });
+
+  if (!operandElemTy || !resultElemTy)
+    return op->emitOpError("incompatible operand and result types");
+
+  unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
+  unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
+  bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
 
   if (requireSameBitWidth) {
     if (!isSameBitWidth) {
       return op->emitOpError(
                  "expected the same bit widths for operand type and result "
                  "type, but provided ")
-             << operandType << " and " << resultType;
+             << operandElemTy << " and " << resultElemTy;
     }
     return success();
   }
@@ -467,7 +470,7 @@ static LogicalResult verifyCastOp(Operation *op,
     return op->emitOpError(
                "expected the different bit widths for operand type and result "
                "type, but provided ")
-           << operandType << " and " << resultType;
+           << operandElemTy << " and " << resultElemTy;
   }
   return success();
 }
@@ -4018,6 +4021,34 @@ LogicalResult spirv::VectorShuffleOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLength
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::KHRCooperativeMatrixLengthOp::verify() {
+  if (!isa<spirv::CooperativeMatrixType>(getCooperativeMatrixType())) {
+    return emitOpError(
+               "type attribute must be a '!spirv.coopmatrix' type, found ")
+           << getCooperativeMatrixType() << " instead";
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NV.CooperativeMatrixLength
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::NVCooperativeMatrixLengthOp::verify() {
+  if (!isa<spirv::CooperativeMatrixNVType>(getCooperativeMatrixType())) {
+    return emitOpError(
+               "type attribute must be a '!spirv.NV.coopmatrix' type, found ")
+           << getCooperativeMatrixType() << " instead";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.NV.CooperativeMatrixLoad
 //===----------------------------------------------------------------------===//
@@ -4053,8 +4084,8 @@ void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
   printer << " : " << getPointer().getType() << " as " << getType();
 }
 
-static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
-                                                    Type coopMatrix) {
+static LogicalResult
+verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) {
   Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
   if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
       !llvm::isa<VectorType>(pointeeType))
@@ -4074,8 +4105,8 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
 }
 
 LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getResult().getType());
+  return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
+                                          getResult().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -4114,8 +4145,8 @@ void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getObject().getType());
+  return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
+                                          getObject().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -4123,7 +4154,7 @@ LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
 //===----------------------------------------------------------------------===//
 
 static LogicalResult
-verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
+verifyCoopMatrixMulAddNV(spirv::NVCooperativeMatrixMulAddOp op) {
   if (op.getC().getType() != op.getResult().getType())
     return op.emitOpError("result and third operand must have the same type");
   auto typeA = llvm::cast<spirv::CooperativeMatrixNVType>(op.getA().getType());
@@ -4156,9 +4187,13 @@ verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
 }
 
 LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
-  return verifyCoopMatrixMulAdd(*this);
+  return verifyCoopMatrixMulAddNV(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.JointMatrixLoad
+//===----------------------------------------------------------------------===//
+
 static LogicalResult
 verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
   Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
@@ -4179,10 +4214,6 @@ verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixLoad
-//===----------------------------------------------------------------------===//
-
 LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
   return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
                                          getResult().getType());

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 1599c5bb74ae06..01c694de08a9e8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <cstdint>
 #include <iterator>
 
 using namespace mlir;
@@ -93,9 +94,10 @@ std::optional<int64_t> ArrayType::getSizeInBytes() {
 bool CompositeType::classof(Type type) {
   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
     return isValid(vectorType);
-  return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixNVType,
-                   spirv::JointMatrixINTELType, spirv::MatrixType,
-                   spirv::RuntimeArrayType, spirv::StructType>(type);
+  return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
+                   spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType,
+                   spirv::MatrixType, spirv::RuntimeArrayType,
+                   spirv::StructType>(type);
 }
 
 bool CompositeType::isValid(VectorType type) {
@@ -114,8 +116,8 @@ bool CompositeType::isValid(VectorType type) {
 
 Type CompositeType::getElementType(unsigned index) const {
   return TypeSwitch<Type, Type>(*this)
-      .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
-            RuntimeArrayType, VectorType>(
+      .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
+            JointMatrixINTELType, RuntimeArrayType, VectorType>(
           [](auto type) { return type.getElementType(); })
       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
       .Case<StructType>(
@@ -133,9 +135,9 @@ unsigned CompositeType::getNumElements() const {
     return structType.getNumElements();
   if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
     return vectorType.getNumElements();
-  if (llvm::isa<CooperativeMatrixNVType>(*this)) {
+  if (llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType>(*this)) {
     llvm_unreachable(
-        "invalid to query number of elements of spirv::CooperativeMatrix type");
+        "invalid to query number of elements of spirv Cooperative Matrix type");
   }
   if (llvm::isa<JointMatrixINTELType>(*this)) {
     llvm_unreachable(
@@ -149,16 +151,16 @@ unsigned CompositeType::getNumElements() const {
 }
 
 bool CompositeType::hasCompileTimeKnownNumElements() const {
-  return !llvm::isa<CooperativeMatrixNVType, JointMatrixINTELType,
-              RuntimeArrayType>(*this);
+  return !llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType,
+                    JointMatrixINTELType, RuntimeArrayType>(*this);
 }
 
 void CompositeType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
   TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
-            MatrixType, RuntimeArrayType, StructType>(
+      .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
+            JointMatrixINTELType, MatrixType, RuntimeArrayType, StructType>(
           [&](auto type) { type.getExtensions(extensions, storage); })
       .Case<VectorType>([&](VectorType type) {
         return llvm::cast<ScalarType>(type.getElementType())
@@ -171,8 +173,8 @@ void CompositeType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     std::optional<StorageClass> storage) {
   TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
-            MatrixType, RuntimeArrayType, StructType>(
+      .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
+            JointMatrixINTELType, MatrixType, RuntimeArrayType, StructType>(
           [&](auto type) { type.getCapabilities(capabilities, storage); })
       .Case<VectorType>([&](VectorType type) {
         auto vecSize = getNumElements();
@@ -202,6 +204,77 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
   return std::nullopt;
 }
 
+//===----------------------------------------------------------------------===//
+// CooperativeMatrixType (SPV_KHR_cooperative_matrix)
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
+  using KeyTy =
+      std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
+
+  static CooperativeMatrixTypeStorage *
+  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
+    return new (allocator.allocate<CooperativeMatrixTypeStorage>())
+        CooperativeMatrixTypeStorage(key); // Placement-new in the allocator.
+  }
+
+  bool operator==(const KeyTy &key) const { // Uniquing key equality.
+    return key == KeyTy(elementType, rows, columns, scope, use);
+  }
+
+  CooperativeMatrixTypeStorage(const KeyTy &key) // Unpacks the key tuple.
+      : elementType(std::get<0>(key)), rows(std::get<1>(key)),
+        columns(std::get<2>(key)), scope(std::get<3>(key)),
+        use(std::get<4>(key)) {}
+
+  Type elementType;            // Element type of the matrix entries.
+  uint32_t rows;               // Number of rows (e.g. 8 in 8x16).
+  uint32_t columns;            // Number of columns (e.g. 16 in 8x16).
+  Scope scope;                 // e.g. Subgroup or Workgroup.
+  CooperativeMatrixUseKHR use; // e.g. MatrixA, MatrixB or MatrixAcc.
+};
+
+CooperativeMatrixType CooperativeMatrixType::get(Type elementType,
+                                                 uint32_t rows,
+                                                 uint32_t columns, Scope scope,
+                                                 CooperativeMatrixUseKHR use) {
+  return Base::get(elementType.getContext(), elementType, rows, columns, scope,
+                   use); // Uniqued in elementType's context.
+}
+
+Type CooperativeMatrixType::getElementType() const {
+  return getImpl()->elementType;
+}
+
+uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
+
+uint32_t CooperativeMatrixType::getColumns() const {
+  return getImpl()->columns;
+}
+
+Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
+
+CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
+  return getImpl()->use;
+}
+
+void CooperativeMatrixType::getExtensions( // Element type's exts + KHR ext.
+    SPIRVType::ExtensionArrayRefVector &extensions, // Appended to.
+    std::optional<StorageClass> storage) { // Forwarded to element type.
+  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+  static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
+  extensions.push_back(exts); // exts is static: the stored ref won't dangle.
+}
+
+void CooperativeMatrixType::getCapabilities( // Element type's caps + KHR cap.
+    SPIRVType::CapabilityArrayRefVector &capabilities, // Appended to.
+    std::optional<StorageClass> storage) { // Forwarded to element type.
+  llvm::cast<SPIRVType>(getElementType())
+      .getCapabilities(capabilities, storage);
+  static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
+  capabilities.push_back(caps); // caps is static: the stored ref won't dangle.
+}
+
 //===----------------------------------------------------------------------===//
 // CooperativeMatrixNVType
 //===----------------------------------------------------------------------===//
@@ -1247,7 +1320,7 @@ void MatrixType::getCapabilities(
 //===----------------------------------------------------------------------===//
 
 void SPIRVDialect::registerTypes() {
-  addTypes<ArrayType, CooperativeMatrixNVType, ImageType, JointMatrixINTELType,
-           MatrixType, PointerType, RuntimeArrayType, SampledImageType,
-           StructType>();
+  addTypes<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType, ImageType,
+           JointMatrixINTELType, MatrixType, PointerType, RuntimeArrayType,
+           SampledImageType, StructType>(); // KHR and NV coop matrix types.
 }

diff  --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 1835c6ae1d5f80..4f4a72da7c050a 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -138,6 +138,14 @@ func.func @convert_f_to_u_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {
 
 // -----
 
+func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>) {
+  // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> to !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  %0 = spirv.ConvertFToU %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> to !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  spirv.Return
+}
+
+// -----
+
 func.func @convert_f_to_u_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
   // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   %0 = spirv.ConvertFToU %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
@@ -222,7 +230,15 @@ func.func @f_convert_vector(%arg0 : vector<3xf32>) -> vector<3xf64> {
 
 // -----
 
-func.func @f_convert_coop_matrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
+func.func @f_convert_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>) {
+  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xf64, Subgroup, MatrixA>
+  %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xf64, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+func.func @f_convert_coop_matrix_nv(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
   // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
   %0 = spirv.FConvert %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
   spirv.Return
@@ -238,6 +254,14 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {
 
 // -----
 
+func.func @f_convert_coop_matrix_to_nv_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>) {
+  // expected-error @+1 {{incompatible operand and result types}}
+  %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.SConvert
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 2e387403964612..65311e1db9bb49 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -1,4 +1,29 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect --split-input-file --verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// CooperativeMatrix (KHR)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cooperative_matrix_length
+spirv.func @cooperative_matrix_length() -> i32 "None" {
+  // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  %0 = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.ReturnValue %0 : i32
+}
+
+// -----
+
+spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
+  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixLength' op type attribute must be a '!spirv.coopmatrix'}}
+  %0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.ReturnValue %0 : i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// NV.CooperativeMatrix
+//===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @cooperative_matrix_load
 spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
@@ -7,7 +32,6 @@ spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stri
   spirv.Return
 }
 
-// -----
 // CHECK-LABEL: @cooperative_matrix_load_memaccess
 spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
   // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
@@ -164,3 +188,11 @@ spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>,
   %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }
+
+// -----
+
+spirv.func @nv_cooperative_matrix_length_wrong_matrix() -> i32 "None" {
+  // expected-error @+1 {{'spirv.NV.CooperativeMatrixLength' op type attribute must be a '!spirv.NV.coopmatrix'}}
+  %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  spirv.ReturnValue %0 : i32
+}

diff  --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 06f0ccfd0a3774..e10a6fc77e8566 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -436,11 +436,55 @@ func.func private @id_struct_recursive(!spirv.struct<a10, (!spirv.ptr<!spirv.str
 // -----
 
 //===----------------------------------------------------------------------===//
-// CooperativeMatrix
+// CooperativeMatrix (KHR)
 //===----------------------------------------------------------------------===//
 
-// CHECK: func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>)
-func.func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> ()
+// CHECK-LABEL: func private @coop_matrix_types
+// CHECK-SAME:    !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+// CHECK-SAME:    !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>
+// CHECK-SAME:    !spirv.coopmatrix<4x8xf32, Workgroup, MatrixAcc>
+func.func private @coop_matrix_types(!spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+                                     !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>,
+                                     !spirv.coopmatrix<4x8xf32, Workgroup, MatrixAcc>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected valid keyword}}
+func.func private @missing_scope(!spirv.coopmatrix<8x8xi32, >) -> ()
+
+// -----
+
+// expected-error @+1 {{expected ','}}
+func.func private @missing_use(!spirv.coopmatrix<8x16xi32, Subgroup>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected valid keyword}}
+func.func private @missing_use2(!spirv.coopmatrix<8x8xi32, Subgroup,>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected row and column count}}
+func.func private @missing_count(!spirv.coopmatrix<8xi32, Subgroup, MatrixA>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected row and column count}}
+func.func private @too_many_dims(!spirv.coopmatrix<8x16x32xi32, Subgroup, MatrixB>) -> ()
+
+// -----
+
+// expected-error @+1 {{invalid use <id> attribute specification: Subgroup}}
+func.func private @use_not_integer(!spirv.coopmatrix<8x8xi32, Subgroup, Subgroup>) -> ()
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// NV.CooperativeMatrix
+//===----------------------------------------------------------------------===//
+
+// CHECK: func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>)
+func.func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> ()
 
 // -----
 

diff  --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 7bf7755a8a52f9..ccf4240f8e5608 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -857,7 +857,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
                   attrList, attrName, words, wordIndex);
-  } else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
+  } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
     os << tabs
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "TypeAttr::get(getType({2}[{3}++]))));\n",
@@ -866,7 +866,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
     PrintFatalError(
         loc, llvm::Twine(
                  "unhandled attribute type in deserialization generation : '") +
-                 attr.getAttrDefName() + llvm::Twine("'"));
+                 attrName + llvm::Twine("'"));
   }
 }
 


        


More information about the Mlir-commits mailing list