[Mlir-commits] [mlir] 1d51597 - [mlir][spirv] Add missing NV prefix/suffix for coop matrix

Jakub Kuderski llvmlistbot at llvm.org
Mon Jul 10 07:08:19 PDT 2023


Author: Jakub Kuderski
Date: 2023-07-10T10:06:24-04:00
New Revision: 1d515978709cd97d818738d39e699bf5d88dedab

URL: https://github.com/llvm/llvm-project/commit/1d515978709cd97d818738d39e699bf5d88dedab
DIFF: https://github.com/llvm/llvm-project/commit/1d515978709cd97d818738d39e699bf5d88dedab.diff

LOG: [mlir][spirv] Add missing NV prefix/suffix for coop matrix

This is in preparation for adding the KHR version of the cooperative
matrix extension, `SPV_KHR_cooperative_matrix`, that comes with
equivalent ops and type. These are not cross-extension compatible,
so it's better to add prefixes/suffixes to the Nvidia one,
`SPV_NV_cooperative_matrix`, before adding the KHR counterparts.

In the near future, I plan for these two extensions to co-exist in
the SPIR-V dialect, but we may want to remove the NV one at some point.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D154799

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
    mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
    mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
    mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
    mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
    mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
    mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
    mlir/test/Dialect/SPIRV/IR/types.mlir
    mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir
    mlir/test/Target/SPIRV/matrix.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 87d10407739485..2b33327dec1fbc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4069,7 +4069,7 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
               !interleave(widths, "/") # "-bit signless/unsigned integer">;
 
 def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
-def SPIRV_IsCooperativeMatrixType :
+def SPIRV_IsCooperativeMatrixNVType :
   CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">;
 def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
 def SPIRV_IsJointMatrixType :
@@ -4100,9 +4100,9 @@ def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
                              "any SPIR-V pointer type">;
 def SPIRV_AnyArray : DialectType<SPIRV_Dialect, SPIRV_IsArrayType,
                                "any SPIR-V array type">;
-def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
-                               SPIRV_IsCooperativeMatrixType,
-                               "any SPIR-V cooperative matrix type">;
+def SPIRV_AnyCooperativeMatrixNV : DialectType<SPIRV_Dialect,
+                                     SPIRV_IsCooperativeMatrixNVType,
+                                     "any SPIR-V NV cooperative matrix type">;
 def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
                                 "any SPIR-V image type">;
 def SPIRV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPIRV_IsJointMatrixType,
@@ -4121,21 +4121,21 @@ def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
 def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
 def SPIRV_Composite :
     AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-               SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
+               SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
 def SPIRV_Type : AnyTypeOf<[
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-    SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
+    SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
     SPIRV_AnySampledImage
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
 def SPIRV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
 
-class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
-  ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixType,
+class SPIRV_CoopMatrixNVOfType<list<Type> allowedTypes> :
+  ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixNVType,
     "::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()",
-    "Cooperative Matrix">;
+    "Cooperative Matrix NV">;
 
 class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsJointMatrixType,
@@ -4147,10 +4147,10 @@ class SPIRV_ScalarOrVectorOf<Type type> :
 
 class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
     AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
-               SPIRV_CoopMatrixOfType<[type]>]>;
+               SPIRV_CoopMatrixNVOfType<[type]>]>;
 
 class SPIRV_MatrixOrCoopMatrixOf<Type type> :
-    AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixOfType<[type]>]>;
+    AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixNVOfType<[type]>]>;
 
 def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
 def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 22ee3fb301a2d1..71c4f7e17bf013 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -13,6 +13,10 @@
 #ifndef MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
 #define MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
 
+//===----------------------------------------------------------------------===//
+// SPV_NV_cooperative_matrix extension ops.
+//===----------------------------------------------------------------------===//
+
 // -----
 
 def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLength",
@@ -35,7 +39,7 @@ def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLengt
     For example:
 
     ```
-    %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<Subgroup, i32, 8, 16>
+    %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<Subgroup, i32, 8, 16>
     ```
   }];
 
@@ -111,7 +115,7 @@ def SPIRV_NVCooperativeMatrixLoadOp : SPIRV_NvVendorOp<"CooperativeMatrixLoad",
 
     ```
     %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor
-         : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<i32, Workgroup, 16, 8>
+         : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<i32, Workgroup, 16, 8>
     ```
   }];
 
@@ -130,7 +134,7 @@ def SPIRV_NVCooperativeMatrixLoadOp : SPIRV_NvVendorOp<"CooperativeMatrixLoad",
   );
 
   let results = (outs
-    SPIRV_AnyCooperativeMatrix:$result
+    SPIRV_AnyCooperativeMatrixNV:$result
   );
 }
 
@@ -182,7 +186,7 @@ def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAd
 
     ```
     %0 = spirv.NV.CooperativeMatrixMulAdd %arg0, %arg1, %arg2,  :
-      !spirv.coopmatrix<Subgroup, i32, 8, 16>
+      !spirv.NV.coopmatrix<Subgroup, i32, 8, 16>
     ```
   }];
 
@@ -198,13 +202,13 @@ def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAd
   ];
 
   let arguments = (ins
-    SPIRV_AnyCooperativeMatrix:$a,
-    SPIRV_AnyCooperativeMatrix:$b,
-    SPIRV_AnyCooperativeMatrix:$c
+    SPIRV_AnyCooperativeMatrixNV:$a,
+    SPIRV_AnyCooperativeMatrixNV:$b,
+    SPIRV_AnyCooperativeMatrixNV:$c
   );
 
   let results = (outs
-    SPIRV_AnyCooperativeMatrix:$result
+    SPIRV_AnyCooperativeMatrixNV:$result
   );
 }
 
@@ -247,7 +251,7 @@ def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore"
 
     ```
       spirv.NV.CooperativeMatrixStore %arg0, %arg2, %arg1, %arg3 :
-        !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<Workgroup, i32, 16, 8>
+        !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<Workgroup, i32, 16, 8>
     ```
   }];
 
@@ -260,7 +264,7 @@ def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore"
 
   let arguments = (ins
     SPIRV_AnyPtr:$pointer,
-    SPIRV_AnyCooperativeMatrix:$object,
+    SPIRV_AnyCooperativeMatrixNV:$object,
     SPIRV_Integer:$stride,
     SPIRV_Bool:$columnmajor,
     OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index daa4d61f7103f1..b5b1f5ad4f52f1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -27,7 +27,7 @@ namespace spirv {
 
 namespace detail {
 struct ArrayTypeStorage;
-struct CooperativeMatrixTypeStorage;
+struct CooperativeMatrixNVTypeStorage;
 struct ImageTypeStorage;
 struct JointMatrixTypeStorage;
 struct MatrixTypeStorage;
@@ -398,10 +398,10 @@ class StructType
 llvm::hash_code
 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
 
-// SPIR-V cooperative matrix type
+// SPIR-V NV cooperative matrix type
 class CooperativeMatrixNVType
     : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
-                            detail::CooperativeMatrixTypeStorage> {
+                            detail::CooperativeMatrixNVTypeStorage> {
 public:
   using Base::Base;
 
@@ -409,11 +409,11 @@ class CooperativeMatrixNVType
                                      unsigned rows, unsigned columns);
   Type getElementType() const;
 
-  /// Return the scope of the cooperative matrix.
+  /// Returns the scope of the matrix.
   Scope getScope() const;
-  /// return the number of rows of the matrix.
+  /// Returns the number of rows of the matrix.
   unsigned getRows() const;
-  /// return the number of columns of the matrix.
+  /// Returns the number of columns of the matrix.
   unsigned getColumns() const;
 
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 37e5d77caa49c4..a0bf1300b183a3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -318,9 +318,8 @@ static Type parseArrayType(SPIRVDialect const &dialect,
   return ArrayType::get(elementType, count, stride);
 }
 
-// cooperative-matrix-type ::= `!spirv.coopmatrix` `<` element-type ',' scope
-// ','
-//                                                   rows ',' columns>`
+// cooperative-matrix-type ::=
+//   `!spirv.NV.coopmatrix` `<` rows `x` columns `x` element-type ',' scope>
 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
                                        DialectAsmParser &parser) {
   if (parser.parseLess())
@@ -786,7 +785,7 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
 
   if (keyword == "array")
     return parseArrayType(*this, parser);
-  if (keyword == "coopmatrix")
+  if (keyword == "NV.coopmatrix")
     return parseCooperativeMatrixType(*this, parser);
   if (keyword == "jointmatrix")
     return parseJointMatrixType(*this, parser);
@@ -891,7 +890,7 @@ static void print(StructType type, DialectAsmPrinter &os) {
 }
 
 static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
-  os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
+  os << "NV.coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
   os << type.getElementType() << ", " << stringifyScope(type.getScope());
   os << ">";
 }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 30fc3e1d11bb11..1599c5bb74ae06 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -203,23 +203,23 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
 }
 
 //===----------------------------------------------------------------------===//
-// CooperativeMatrixType
+// CooperativeMatrixNVType
 //===----------------------------------------------------------------------===//
 
-struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
+struct spirv::detail::CooperativeMatrixNVTypeStorage : public TypeStorage {
   using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
 
-  static CooperativeMatrixTypeStorage *
+  static CooperativeMatrixNVTypeStorage *
   construct(TypeStorageAllocator &allocator, const KeyTy &key) {
-    return new (allocator.allocate<CooperativeMatrixTypeStorage>())
-        CooperativeMatrixTypeStorage(key);
+    return new (allocator.allocate<CooperativeMatrixNVTypeStorage>())
+        CooperativeMatrixNVTypeStorage(key);
   }
 
   bool operator==(const KeyTy &key) const {
     return key == KeyTy(elementType, scope, rows, columns);
   }
 
-  CooperativeMatrixTypeStorage(const KeyTy &key)
+  CooperativeMatrixNVTypeStorage(const KeyTy &key)
       : elementType(std::get<0>(key)), rows(std::get<2>(key)),
         columns(std::get<3>(key)), scope(std::get<1>(key)) {}
 

diff  --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
index 829107f2625bed..12b6a2eb94268c 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -11,7 +11,7 @@ module attributes {
       %i = arith.constant 16 : index
       %j = arith.constant 16 : index
       // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false
-      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] :  !spirv.ptr<f32, StorageBuffer> as !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] :  !spirv.ptr<f32, StorageBuffer> as !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
       // CHECK: spirv.Return
       gpu.return
@@ -33,7 +33,7 @@ module attributes {
       %i = arith.constant 16 : index
       %j = arith.constant 16 : index
       // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true
-      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] :  !spirv.ptr<f32, StorageBuffer> as !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] :  !spirv.ptr<f32, StorageBuffer> as !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
       // CHECK: spirv.Return
       gpu.return
@@ -49,13 +49,13 @@ module attributes {
   gpu.module @kernels {
     // CHECK-LABEL: spirv.func @gpu_wmma_store_op
     // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
-    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
     gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       %i = arith.constant 16 : index
       %j = arith.constant 16 : index
       // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false
-      //  CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
+      //  CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,  #spirv.storage_class<StorageBuffer>>
       // CHECK: spirv.Return
       gpu.return
@@ -71,14 +71,14 @@ module attributes {
   gpu.module @kernels {
     // CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose
     // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
+    // CHECK-SAME: {{%.*}}: !spirv.NV.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
     // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
     gpu.func @gpu_wmma_store_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       %i = arith.constant 16 : index
       %j = arith.constant 16 : index
       // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true
-      // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,  #spirv.storage_class<StorageBuffer>>
       // CHECK: spirv.Return
       gpu.return
@@ -93,12 +93,12 @@ module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
     // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
-    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
-    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
-    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
     gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, !spirv.NV.coopmatrix<16x16xf16, Subgroup> -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
       // CHECK: spirv.Return
       gpu.return
@@ -117,7 +117,7 @@ module attributes {
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       // CHECK: {{%.*}} = spirv.Constant
       %cst = arith.constant 1.0 : f16
-      // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
       // CHECK: spirv.Return
       gpu.return
@@ -132,15 +132,15 @@ module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
     // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
-    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
-    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
     gpu.func @gpu_wmma_elementwise_op_default(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      // CHECK:  {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK:  {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK:  {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK:  {{%.*}} = spirv.FNegate {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       %D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK:  {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK:  {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       %E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
       // CHECK: spirv.Return
       gpu.return
@@ -155,14 +155,14 @@ module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
     // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
-    // CHECK-SAME:    %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME:    %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
     // CHECK-SAME:    %[[S:.+]]: f16
     gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
       %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
       %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
       // CHECK: spirv.Return
       gpu.return
@@ -177,13 +177,13 @@ module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
   gpu.module @kernels {
     // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
-    // CHECK-SAME:    %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup>
+    // CHECK-SAME:    %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
     // CHECK-SAME:    %[[S:.+]]: f16
     gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
       gpu.return
     }

diff  --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index cfbbf9494a00ec..1835c6ae1d5f80 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -138,9 +138,9 @@ func.func @convert_f_to_u_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {
 
 // -----
 
-func.func @convert_f_to_u_coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) {
-  // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.ConvertFToU %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xi32, Subgroup>
+func.func @convert_f_to_u_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
+  // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %0 = spirv.ConvertFToU %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }
 
@@ -222,9 +222,9 @@ func.func @f_convert_vector(%arg0 : vector<3xf32>) -> vector<3xf64> {
 
 // -----
 
-func.func @f_convert_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) {
-  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xf64, Subgroup>
-  %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xf64, Subgroup>
+func.func @f_convert_coop_matrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
+  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
+  %0 = spirv.FConvert %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
   spirv.Return
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index e363d14b29ad17..ce7f6bc6118b31 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -29,10 +29,10 @@ func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2
 
 // -----
 
-func.func @composite_construct_coopmatrix(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> {
-  // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup>
-  %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup>
-  return %0: !spirv.coopmatrix<8x16xf32, Subgroup>
+func.func @composite_construct_NV.coopmatrix(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
+  // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
 }
 
 // -----
@@ -53,18 +53,18 @@ func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg
 
 // -----
 
-func.func @composite_construct_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> {
+func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
   // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
-  %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup>
-  return %0: !spirv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
 }
 
 // -----
 
-func.func @composite_construct_coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.coopmatrix<8x16xf32, Subgroup> {
+func.func @composite_construct_NV.coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
   // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
-  %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup>
-  return %0: !spirv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
 }
 
 // -----
@@ -121,9 +121,9 @@ func.func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 {
 
 // -----
 
-func.func @composite_extract_coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) -> f32 {
-  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.coopmatrix<8x16xf32, Subgroup>
-  %0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.coopmatrix<8x16xf32, Subgroup>
+func.func @composite_extract_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) -> f32 {
+  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  %0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
   return %0 : f32
 }
 
@@ -249,10 +249,10 @@ func.func @composite_insert_struct(%arg0: !spirv.struct<(!spirv.array<4xf32>, f3
 
 // -----
 
-func.func @composite_insert_coopmatrix(%arg0: !spirv.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.coopmatrix<8x16xi32, Subgroup> {
-  // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.coopmatrix<8x16xi32, Subgroup>
-  return %0: !spirv.coopmatrix<8x16xi32, Subgroup>
+func.func @composite_insert_NV.coopmatrix(%arg0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.NV.coopmatrix<8x16xi32, Subgroup> {
+  // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  return %0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index de31458b94771e..2e387403964612 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -2,150 +2,150 @@
 
 // CHECK-LABEL: @cooperative_matrix_load
 spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup>
+  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
+  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
   spirv.Return
 }
 
 // -----
 // CHECK-LABEL: @cooperative_matrix_load_memaccess
 spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
+  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
 spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
+  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_store
-spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup>
+spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
+  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_store_memaccess
-spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup>
+spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
+  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_length
 spirv.func @cooperative_matrix_length() -> i32 "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup>
+  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.ReturnValue %0 : i32
 }
 
 // CHECK-LABEL: @cooperative_matrix_muladd
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.coopmatrix<8x32xi8, Subgroup>, !spirv.coopmatrix<32x8xi8, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x32xi8, Subgroup>, !spirv.coopmatrix<32x8xi8, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_add
-spirv.func @cooperative_matrix_add(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.IAdd %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_sub
-spirv.func @cooperative_matrix_sub(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.ISub %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_sdiv
-spirv.func @cooperative_matrix_sdiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.SDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_udiv
-spirv.func @cooperative_matrix_udiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.UDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_fadd
-spirv.func @cooperative_matrix_fadd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FAdd %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_fsub
-spirv.func @cooperative_matrix_fsub(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FSub %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
   spirv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_fdiv
-spirv.func @cooperative_matrix_fdiv(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FDiv %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
   spirv.Return
 }
 
 // -----
 
 // CHECK-LABEL: @cooperative_matrix_access_chain
-spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
+spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
   %0 = spirv.Constant 0: i32
-  // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
-  %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+  // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+  %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
   spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
 }
 
 // -----
 
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
   // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<16x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
   spirv.Return
 }
 
 // -----
 
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
   // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
   spirv.Return
 }
 
 // -----
 
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
   // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Workgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Workgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
   spirv.Return
 }
 
 // -----
 
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
   // expected-error @+1 {{matrix A and B non-integer element types must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xf32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
   spirv.Return
 }
 
 // -----
 
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
   // expected-error @+1 {{matrix A and B integer element types must be the same bit width}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xui8, Subgroup>, !spirv.coopmatrix<16x8xsi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, !spirv.NV.coopmatrix<16x8xsi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
   spirv.Return
 }
 
@@ -153,7 +153,7 @@ spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xui8, Subgroup>
 
 spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32, %b : i1) "None" {
   // expected-error @+1 {{Pointer must point to a scalar or vector type}}
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }
 
@@ -161,6 +161,6 @@ spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f
 
 spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
   // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.Return
 }

diff  --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 8cdf2390d7232a..f52666af280e4b 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -9,10 +9,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   }
 
   // CHECK-LABEL: @matrix_times_scalar_2
-  spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" {
-    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
-    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
-    spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup>
+  spirv.func @matrix_times_scalar_2(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
+    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
+    spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
   }
 
   // CHECK-LABEL: @matrix_transpose_1

diff  --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 174485f71d21d1..722e4434aeaf9f 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -797,7 +797,7 @@ spirv.module Logical GLSL450 {
 }
 
 //===----------------------------------------------------------------------===//
-// spirv.SpecConstantComposite (spirv.coopmatrix)
+// spirv.SpecConstantComposite (spirv.NV.coopmatrix)
 //===----------------------------------------------------------------------===//
 
 // -----
@@ -805,7 +805,7 @@ spirv.module Logical GLSL450 {
 spirv.module Logical GLSL450 {
   spirv.SpecConstant @sc1 = 1.5 : f32
   // expected-error @+1 {{unsupported composite type}}
-  spirv.SpecConstantComposite @scc (@sc1) : !spirv.coopmatrix<8x16xf32, Device>
+  spirv.SpecConstantComposite @scc (@sc1) : !spirv.NV.coopmatrix<8x16xf32, Device>
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 7e2833e79646ea..06f0ccfd0a3774 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -439,18 +439,18 @@ func.func private @id_struct_recursive(!spirv.struct<a10, (!spirv.ptr<!spirv.str
 // CooperativeMatrix
 //===----------------------------------------------------------------------===//
 
-// CHECK: func private @coop_matrix_type(!spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xf32, Workgroup>)
-func.func private @coop_matrix_type(!spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xf32, Workgroup>) -> ()
+// CHECK: func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>)
+func.func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> ()
 
 // -----
 
 // expected-error @+1 {{expected ','}}
-func.func private @missing_scope(!spirv.coopmatrix<8x16xi32>) -> ()
+func.func private @missing_scope(!spirv.NV.coopmatrix<8x16xi32>) -> ()
 
 // -----
 
 // expected-error @+1 {{expected rows and columns size}}
-func.func private @missing_count(!spirv.coopmatrix<8xi32, Subgroup>) -> ()
+func.func private @missing_count(!spirv.NV.coopmatrix<8xi32, Subgroup>) -> ()
 
 // -----
 

diff  --git a/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir
index 9b060f18d0fc33..2eec99f72691cc 100644
--- a/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir
@@ -3,100 +3,100 @@
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
   // CHECK-LABEL: @cooperative_matrix_load
   spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup>
-    %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup>
+    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
+    %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_load_memaccess
   spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
-    %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
+    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+    %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_store
-  spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
-    // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup>
-    spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup>
+  spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
+    // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
+    spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_store_memaccess
-  spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-    // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup>
-    spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup>
+  spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
+    // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+    spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_length
   spirv.func @cooperative_matrix_length() -> i32 "None" {
-    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup>
-    %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup>
+    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+    %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
     spirv.ReturnValue %0 : i32
   }
 
   // CHECK-LABEL: @cooperative_matrix_muladd
-  spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
-    %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+  spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+    %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_add
-  spirv.func @cooperative_matrix_add(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
-    %r = spirv.IAdd %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+  spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+    %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_sub
-  spirv.func @cooperative_matrix_sub(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
-    %r = spirv.ISub %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+  spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+    %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_sdiv
-  spirv.func @cooperative_matrix_sdiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
-    %r = spirv.SDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+  spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+    %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_udiv
-  spirv.func @cooperative_matrix_udiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
-    %r = spirv.UDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+  spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+    %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_fadd
-  spirv.func @cooperative_matrix_fadd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
-    %r = spirv.FAdd %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+  spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+    %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_fsub
-  spirv.func @cooperative_matrix_fsub(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
-    %r = spirv.FSub %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+  spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+    %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_fdiv
-  spirv.func @cooperative_matrix_fdiv(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
-    %r = spirv.FDiv %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+  spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+    %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
     spirv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_access_chain
-  spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
+  spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
     %0 = spirv.Constant 0: i32
-    // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
-    %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+    // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+    %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
     spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
   }
 }

diff  --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index 0b71b3f24b19da..af8f41a30d24fc 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -23,10 +23,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   }
 
   // CHECK-LABEL: @matrix_times_scalar_3
-  spirv.func @matrix_times_scalar_3(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" {
-    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
-    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
-    spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup>
+  spirv.func @matrix_times_scalar_3(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
+    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
+    spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
   }
 
   // CHECK-LABEL: @matrix_transpose_1


        


More information about the Mlir-commits mailing list