[Mlir-commits] [mlir] [mlir][spirv] Drop support for SPV_NV_cooperative_matrix (PR #76782)

Jakub Kuderski llvmlistbot at llvm.org
Tue Jan 2 22:07:26 PST 2024


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/76782

This extension has been superseded by SPV_KHR_cooperative_matrix, which is supported across major GPU vendors like Nvidia, AMD, and Intel.

Given that the KHR version has been supported for nearly half a year, drop the NV-specific extension to reduce the maintenance burden and code duplication.

>From 01987b32f359c871af8f9dba39a6c9cb14b8f745 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 3 Jan 2024 01:02:49 -0500
Subject: [PATCH] [mlir][spirv] Drop support for SPV_NV_cooperative_matrix

This extension has been superseded by SPV_KHR_cooperative_matrix, which is
supported across major GPU vendors like Nvidia, AMD, and Intel.

Given that the KHR version has been supported for nearly half a year,
drop the NV-specific extension to reduce the maintenance burden and code
duplication.
---
 .../mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h   |  12 +-
 mlir/include/mlir/Conversion/Passes.td        |   4 -
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  38 +--
 .../SPIRV/IR/SPIRVCooperativeMatrixOps.td     | 247 ------------------
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        |  27 --
 .../Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp  |  12 +-
 .../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp  | 133 +---------
 mlir/lib/Dialect/SPIRV/IR/CastOps.cpp         |   2 +-
 .../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 152 -----------
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  46 +---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |   8 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      |  93 +------
 .../SPIRV/Deserialization/DeserializeOps.cpp  |   1 -
 .../SPIRV/Deserialization/Deserializer.cpp    |  33 ---
 .../Target/SPIRV/Serialization/Serializer.cpp |  20 --
 .../wmma-ops-to-spirv-khr-coop-matrix.mlir    |   2 +-
 .../wmma-ops-to-spirv-nv-coop-matrix.mlir     | 194 --------------
 mlir/test/Dialect/SPIRV/IR/cast-ops.mlir      |  24 --
 mlir/test/Dialect/SPIRV/IR/composite-ops.mlir |  39 ---
 .../SPIRV/IR/khr-cooperative-matrix-ops.mlir  |  26 --
 mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir    |   8 +-
 .../SPIRV/IR/nv-cooperative-matrix-ops.mlir   | 177 -------------
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir |   4 +-
 mlir/test/Dialect/SPIRV/IR/types.mlir         |  19 --
 mlir/test/Target/SPIRV/matrix.mlir            |   8 +-
 .../SPIRV/nv-cooperative-matrix-ops.mlir      | 102 --------
 26 files changed, 49 insertions(+), 1382 deletions(-)
 delete mode 100644 mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
 delete mode 100644 mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir
 delete mode 100644 mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir

diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index cd650345f1daa2..d34549432161db 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -31,16 +31,10 @@ void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
 void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
 
-/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
-/// using the NV Cooperative Matrix extension.
-void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
-    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
-
-/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix type
-/// conversion to the type converter. Defaults to KHR cooperative matrix types.
-/// When `useNVTypes` is `true`, uses the NV cooperative matrix types.
+/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix KHR type
+/// conversion to the type converter.
 void populateMMAToSPIRVCoopMatrixTypeConversion(
-    SPIRVTypeConverter &typeConverter, bool useNVTypes = false);
+    SPIRVTypeConverter &typeConverter);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6193aeb545bc6b..71be8841ca7c03 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -564,10 +564,6 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
     Option<"use64bitIndex", "use-64bit-index",
            "bool", /*default=*/"false",
            "Use 64-bit integers to convert index types">,
-    Option<"useCoopMatrixNV", "use-coop-matrix-nv",
-           "bool", /*default=*/"false",
-           "Use the NV cooperative matrix extension insted of the KHR extension"
-           " to lower GPU WMMA ops">,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index ee1fbba1e2844e..6ec97e17c5dcc8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -1253,12 +1253,6 @@ def SPIRV_C_RayTracingProvisionalKHR                    : I32EnumAttrCase<"RayTr
     Extension<[SPV_KHR_ray_tracing]>
   ];
 }
-def SPIRV_C_CooperativeMatrixNV                         : I32EnumAttrCase<"CooperativeMatrixNV", 5357> {
-  list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
-  list<Availability> availability = [
-    Extension<[SPV_NV_cooperative_matrix]>
-  ];
-}
 def SPIRV_C_FragmentShaderSampleInterlockEXT            : I32EnumAttrCase<"FragmentShaderSampleInterlockEXT", 5363> {
   list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
   list<Availability> availability = [
@@ -1501,7 +1495,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_ShaderNonUniform, SPIRV_C_RuntimeDescriptorArray,
       SPIRV_C_StorageTexelBufferArrayDynamicIndexing, SPIRV_C_RayTracingNV,
       SPIRV_C_RayTracingMotionBlurNV, SPIRV_C_PhysicalStorageBufferAddresses,
-      SPIRV_C_RayTracingProvisionalKHR, SPIRV_C_CooperativeMatrixNV,
+      SPIRV_C_RayTracingProvisionalKHR,
       SPIRV_C_FragmentShaderSampleInterlockEXT,
       SPIRV_C_FragmentShaderShadingRateInterlockEXT, SPIRV_C_ShaderSMBuiltinsNV,
       SPIRV_C_FragmentShaderPixelInterlockEXT, SPIRV_C_DemoteToHelperInvocation,
@@ -4123,8 +4117,6 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
 def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
 def SPIRV_IsCooperativeMatrixType :
   CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
-def SPIRV_IsCooperativeMatrixNVType :
-  CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">;
 def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
 def SPIRV_IsJointMatrixType :
   CPred<"::llvm::isa<::mlir::spirv::JointMatrixINTELType>($_self)">;
@@ -4157,9 +4149,6 @@ def SPIRV_AnyArray : DialectType<SPIRV_Dialect, SPIRV_IsArrayType,
 def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
                                    SPIRV_IsCooperativeMatrixType,
                                   "any SPIR-V cooperative matrix type">;
-def SPIRV_AnyCooperativeMatrixNV : DialectType<SPIRV_Dialect,
-                                     SPIRV_IsCooperativeMatrixNVType,
-                                     "any SPIR-V NV cooperative matrix type">;
 def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
                                 "any SPIR-V image type">;
 def SPIRV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPIRV_IsJointMatrixType,
@@ -4178,13 +4167,12 @@ def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
 def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
 def SPIRV_Composite :
     AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-               SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
-               SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
+               SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
 def SPIRV_Type : AnyTypeOf<[
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-    SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
-    SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
+    SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
+    SPIRV_AnySampledImage
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4195,11 +4183,6 @@ class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
     "::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()",
     "Cooperative Matrix">;
 
-class SPIRV_CoopMatrixNVOfType<list<Type> allowedTypes> :
-  ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixNVType,
-    "::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()",
-    "Cooperative Matrix NV">;
-
 class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsJointMatrixType,
     "::llvm::cast<::mlir::spirv::JointMatrixINTELType>($_self).getElementType()",
@@ -4213,12 +4196,11 @@ class SPIRV_ScalarOrVectorOf<Type type> :
 
 class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
     AnyTypeOf<[type, SPIRV_VectorOf<type>,
-               SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
+               SPIRV_CoopMatrixOfType<[type]>]>;
 
 class SPIRV_MatrixOrCoopMatrixOf<Type type> :
     AnyTypeOf<[SPIRV_AnyMatrix,
-               SPIRV_CoopMatrixOfType<[type]>,
-               SPIRV_CoopMatrixNVOfType<[type]>]>;
+               SPIRV_CoopMatrixOfType<[type]>]>;
 
 def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
 def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
@@ -4480,11 +4462,6 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR   : I32EnumAttrCase<"OpCooperativeMatrix
 def SPIRV_OC_OpCooperativeMatrixStoreKHR  : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
 def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
-def SPIRV_OC_OpTypeCooperativeMatrixNV    : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
-def SPIRV_OC_OpCooperativeMatrixLoadNV    : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
-def SPIRV_OC_OpCooperativeMatrixStoreNV   : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
-def SPIRV_OC_OpCooperativeMatrixMulAddNV  : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>;
-def SPIRV_OC_OpCooperativeMatrixLengthNV  : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>;
 def SPIRV_OC_OpSubgroupBlockReadINTEL     : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
 def SPIRV_OC_OpSubgroupBlockWriteINTEL    : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
 def SPIRV_OC_OpAssumeTrueKHR              : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
@@ -4585,9 +4562,6 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
       SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
       SPIRV_OC_OpCooperativeMatrixLengthKHR,
-      SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
-      SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
-      SPIRV_OC_OpCooperativeMatrixLengthNV,
       SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR,
       SPIRV_OC_OpGroupFMulKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 29ad45bddd5529..46732ba19afed5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -338,253 +338,6 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
   ];
 }
 
-//===----------------------------------------------------------------------===//
-// SPV_NV_cooperative_matrix extension ops.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLength",
-  [Pure]> {
-  let summary = "See extension SPV_NV_cooperative_matrix";
-
-  let description = [{
-    Number of components of a cooperative matrix type accessible to each
-    invocation when treated as a composite.
-
-    Result Type must be an OpTypeInt with 32-bit Width and 0 Signedness.
-
-    Type is a cooperative matrix type.
-
-    #### Example:
-
-    ```
-    %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    ```
-  }];
-
-  let assemblyFormat = "attr-dict `:` $cooperative_matrix_type";
-
-  let availability = [
-    MinVersion<SPIRV_V_1_0>,
-    MaxVersion<SPIRV_V_1_6>,
-    Extension<[SPV_NV_cooperative_matrix]>,
-    Capability<[SPIRV_C_CooperativeMatrixNV]>
-  ];
-
-  let arguments = (ins
-    TypeAttr:$cooperative_matrix_type
-  );
-
-  let results = (outs
-    SPIRV_Int32:$result
-  );
-}
-
-// -----
-
-def SPIRV_NVCooperativeMatrixLoadOp : SPIRV_NvVendorOp<"CooperativeMatrixLoad", []> {
-  let summary = "See extension SPV_NV_cooperative_matrix";
-
-  let description = [{
-    Load a cooperative matrix through a pointer.
-
-    Result Type is the type of the loaded object. It must be a cooperative
-    matrix type.
-
-    Pointer is a pointer into an array. Its type must be an OpTypePointer whose
-    Type operand is a scalar or vector type. The storage class of Pointer must
-    be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
-    supported) PhysicalStorageBufferEXT.
-
-    Stride is the number of elements in the array in memory between the first
-    component of consecutive rows (or columns) in the result. It must be a
-    scalar integer type.
-
-    ColumnMajor indicates whether the values loaded from memory are arranged in
-    column-major or row-major order. It must be a boolean constant instruction,
-    with false indicating row major and true indicating column major.
-
-    Memory Access must be a Memory Access literal. If not present, it is the
-    same as specifying None.
-
-    If ColumnMajor is false, then elements (row,*) of the result are taken in
-    order from contiguous locations starting at Pointer[row*Stride]. If
-    ColumnMajor is true, then elements (*,col) of the result are taken in order
-    from contiguous locations starting from Pointer[col*Stride]. Any ArrayStride
-    decoration on Pointer is ignored.
-
-    For a given dynamic instance of this instruction, all operands of this
-    instruction must be the same for all invocations in a given scope instance
-    (where the scope is the scope the cooperative matrix type was created with).
-    All invocations in a given scope instance must be active or all must be
-    inactive.
-
-    ### Custom assembly form
-
-    ``` {.ebnf}
-    cooperative-matrixload-op ::= ssa-id `=` `spirv.NV.CooperativeMatrixLoad`
-                              ssa-use `,` ssa-use `,` ssa-use
-                              (`[` memory-access `]`)? ` : `
-                              pointer-type `as`
-                              cooperative-matrix-type
-    ```
-
-    #### Example:
-
-    ```
-    %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor
-         : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-    ```
-  }];
-
-  let availability = [
-    MinVersion<SPIRV_V_1_0>,
-    MaxVersion<SPIRV_V_1_6>,
-    Extension<[SPV_NV_cooperative_matrix]>,
-    Capability<[SPIRV_C_CooperativeMatrixNV]>
-  ];
-
-  let arguments = (ins
-    SPIRV_AnyPtr:$pointer,
-    SPIRV_Integer:$stride,
-    SPIRV_Bool:$columnmajor,
-    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access
-  );
-
-  let results = (outs
-    SPIRV_AnyCooperativeMatrixNV:$result
-  );
-}
-
-// -----
-
-def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAdd",
-  [Pure, AllTypesMatch<["c", "result"]>]> {
-  let summary = "See extension SPV_NV_cooperative_matrix";
-
-  let description = [{
-    Linear-algebraic matrix multiply of A by B and then component-wise add C.
-    The order of the operations is implementation-dependent. The internal
-    precision of floating-point operations is defined by the client API.
-    Integer operations are performed at the precision of the Result Type and are
-    exact unless there is overflow or underflow, in which case the result is
-    undefined.
-
-    Result Type must be a cooperative matrix type with M rows and N columns.
-
-    A is a cooperative matrix with M rows and K columns.
-
-    B is a cooperative matrix with K rows and N columns.
-
-    C is a cooperative matrix with M rows and N columns.
-
-    The values of M, N, and K must be consistent across the result and operands.
-    This is referred to as an MxNxK matrix multiply.
-
-    A, B, C, and Result Type must have the same scope, and this defines the
-    scope of the operation. A, B, C, and Result Type need not necessarily have
-    the same component type, this is defined by the client API.
-
-    If the Component Type of any matrix operand is an integer type, then its
-    components are treated as signed if its Component Type has Signedness of 1
-    and are treated as unsigned otherwise.
-
-    For a given dynamic instance of this instruction, all invocations in a given
-    scope instance must be active or all must be inactive (where the scope is
-    the scope of the operation).
-
-    #### Example:
-
-    ```
-    %0 = spirv.NV.CooperativeMatrixMulAdd %arg0, %arg1, %arg2,  :
-      !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    ```
-  }];
-
-  let assemblyFormat = [{
-    operands attr-dict `:` type($a) `,` type($b) `->` type($c)
-  }];
-
-  let availability = [
-    MinVersion<SPIRV_V_1_0>,
-    MaxVersion<SPIRV_V_1_6>,
-    Extension<[SPV_NV_cooperative_matrix]>,
-    Capability<[SPIRV_C_CooperativeMatrixNV]>
-  ];
-
-  let arguments = (ins
-    SPIRV_AnyCooperativeMatrixNV:$a,
-    SPIRV_AnyCooperativeMatrixNV:$b,
-    SPIRV_AnyCooperativeMatrixNV:$c
-  );
-
-  let results = (outs
-    SPIRV_AnyCooperativeMatrixNV:$result
-  );
-}
-
-// -----
-
-def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore", []> {
-  let summary = "See extension SPV_NV_cooperative_matrix";
-
-  let description = [{
-    Store a cooperative matrix through a pointer.
-
-    Pointer is a pointer into an array. Its type must be an OpTypePointer whose
-    Type operand is a scalar or vector type. The storage class of Pointer must
-    be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
-    supported) PhysicalStorageBufferEXT.
-
-    Object is the object to store. Its type must be an
-    OpTypeCooperativeMatrixNV.
-
-    Stride is the number of elements in the array in memory between the first
-    component of consecutive rows (or columns) in the result. It must be a
-    scalar integer type.
-
-    ColumnMajor indicates whether the values stored to memory are arranged in
-    column-major or row-major order. It must be a boolean constant instruction,
-    with false indicating row major and true indicating column major.
-
-    Memory Access must be a Memory Access literal. If not present, it is the
-    same as specifying None.
-
-    ``` {.ebnf}
-    coop-matrix-store-op ::= `spirv.NV.CooperativeMatrixStore `
-                              ssa-use `, ` ssa-use `, `
-                              ssa-use `, ` ssa-use `, `
-                              (`[` memory-access `]`)? `:`
-                              pointer-type `,` coop-matrix-type
-    ```
-
-    #### Example:
-
-    ```
-      spirv.NV.CooperativeMatrixStore %arg0, %arg2, %arg1, %arg3 :
-        !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-    ```
-  }];
-
-  let availability = [
-    MinVersion<SPIRV_V_1_0>,
-    MaxVersion<SPIRV_V_1_6>,
-    Extension<[SPV_NV_cooperative_matrix]>,
-    Capability<[SPIRV_C_CooperativeMatrixNV]>
-  ];
-
-  let arguments = (ins
-    SPIRV_AnyPtr:$pointer,
-    SPIRV_AnyCooperativeMatrixNV:$object,
-    SPIRV_Integer:$stride,
-    SPIRV_Bool:$columnmajor,
-    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access
-  );
-
-  let results = (outs);
-}
-
 // -----
 
 #endif // MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index d946d936d4e6cf..55f0c787b44403 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,7 +29,6 @@ namespace spirv {
 namespace detail {
 struct ArrayTypeStorage;
 struct CooperativeMatrixTypeStorage;
-struct CooperativeMatrixNVTypeStorage;
 struct ImageTypeStorage;
 struct JointMatrixTypeStorage;
 struct MatrixTypeStorage;
@@ -421,32 +420,6 @@ class CooperativeMatrixType
                        std::optional<StorageClass> storage = std::nullopt);
 };
 
-// SPIR-V NV cooperative matrix type
-class CooperativeMatrixNVType
-    : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
-                            detail::CooperativeMatrixNVTypeStorage> {
-public:
-  using Base::Base;
-
-  static constexpr StringLiteral name = "spirv.NV.coopmatrix";
-
-  static CooperativeMatrixNVType get(Type elementType, Scope scope,
-                                     unsigned rows, unsigned columns);
-  Type getElementType() const;
-
-  /// Returns the scope of the matrix.
-  Scope getScope() const;
-  /// Returns the number of rows of the matrix.
-  unsigned getRows() const;
-  /// Returns the number of columns of the matrix.
-  unsigned getColumns() const;
-
-  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
-                     std::optional<StorageClass> storage = std::nullopt);
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
-};
-
 // SPIR-V joint matrix type
 class JointMatrixINTELType
     : public Type::TypeBase<JointMatrixINTELType, CompositeType,
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 8279b3408a6e66..0dd0e7e21b0553 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -112,18 +112,12 @@ void GPUToSPIRVPass::runOnOperation() {
     SPIRVConversionOptions options;
     options.use64bitIndex = this->use64bitIndex;
     SPIRVTypeConverter typeConverter(targetAttr, options);
-    populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter,
-                                               this->useCoopMatrixNV);
+    populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter);
 
     RewritePatternSet patterns(context);
     populateGPUToSPIRVPatterns(typeConverter, patterns);
-    if (this->useCoopMatrixNV) {
-      populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
-                                                           patterns);
-    } else {
-      populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
-                                                            patterns);
-    }
+    populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
+                                                          patterns);
 
     // TODO: Change SPIR-V conversion to be progressive and remove the following
     // patterns.
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 4a4281aaaf0dbc..92cc0eadb9784c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -32,19 +32,18 @@
 
 namespace mlir {
 //===----------------------------------------------------------------------===//
-// Patterns and helpers used by both the KHR and the NV lowering paths.
+// Patterns and helpers.
 //===----------------------------------------------------------------------===//
 
 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
 /// when the elementwise op directly supports with cooperative matrix type.
 /// Returns false if cannot.
 ///
-/// See SPV_NV_cooperative_matrix for supported elementwise ops.
+/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
 static bool createElementwiseOp(ConversionPatternRewriter &builder,
                                 gpu::SubgroupMmaElementwiseOp op, Type coopType,
                                 ValueRange operands) {
-  assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
-      coopType)));
+  assert((isa<spirv::CooperativeMatrixType>(coopType)));
 
   switch (op.getOpType()) {
   case gpu::MMAElementwiseOp::ADDF:
@@ -89,8 +88,7 @@ bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
           llvm::map_range(operands, [](Value v) { return v.getType(); })))
     return false;
 
-  return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
-      operands.front().getType());
+  return isa<spirv::CooperativeMatrixType>(operands.front().getType());
 }
 
 namespace {
@@ -292,104 +290,6 @@ struct WmmaMmaOpToSPIRVLowering final
 
 } // namespace
 } // namespace khr
-
-//===----------------------------------------------------------------------===//
-// SPV_NV_cooperative_matrix
-//===----------------------------------------------------------------------===//
-
-namespace nv {
-namespace {
-
-/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
-/// dialect.
-struct WmmaLoadOpToSPIRVLowering final
-    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = subgroupMmaLoadMatrixOp->getLoc();
-    auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
-
-    gpu::MMAMatrixType retType =
-        cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
-    auto memrefType =
-        cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType());
-    Value bufferPtr =
-        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
-                             adaptor.getIndices(), loc, rewriter);
-    auto coopType =
-        typeConverter.convertType<spirv::CooperativeMatrixNVType>(retType);
-    if (!coopType)
-      return rewriter.notifyMatchFailure(subgroupMmaLoadMatrixOp,
-                                         "type conversion failed");
-
-    int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
-    auto i32Type = rewriter.getI32Type();
-    auto strideValue = rewriter.create<spirv::ConstantOp>(
-        loc, i32Type, IntegerAttr::get(i32Type, stride));
-    bool isColMajor = static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose());
-    auto columnMajor = rewriter.create<spirv::ConstantOp>(
-        loc, rewriter.getI1Type(), rewriter.getBoolAttr(isColMajor));
-    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
-        subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor,
-        spirv::MemoryAccessAttr());
-    return success();
-  }
-};
-
-/// Converts the GPU MMA StoreOp to NVCooperativeMatrixStore op in the SPIRV
-/// dialect.
-struct WmmaStoreOpToSPIRVLowering final
-    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = subgroupMmaStoreMatrixOp->getLoc();
-    auto memrefType =
-        cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType());
-    Value bufferPtr = spirv::getElementPtr(
-        *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
-        adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter);
-    int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue();
-    auto i32Type = rewriter.getI32Type();
-    auto strideValue = rewriter.create<spirv::ConstantOp>(
-        loc, i32Type, IntegerAttr::get(i32Type, stride));
-    bool useColMajor =
-        static_cast<bool>(subgroupMmaStoreMatrixOp.getTranspose());
-    auto columnMajor = rewriter.create<spirv::ConstantOp>(
-        loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor));
-    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixStoreOp>(
-        subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue,
-        columnMajor, spirv::MemoryAccessAttr());
-    return success();
-  }
-};
-
-/// Converts GPU MMA Compute to
-/// NVCooperativeMatrixMulAdd op in the SPIRV dialect.
-struct WmmaMmaOpToSPIRVLowering final
-    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixMulAddOp>(
-        subgroupMmaComputeOp, adaptor.getOpC().getType(), adaptor.getOpA(),
-        adaptor.getOpB(), adaptor.getOpC());
-    return success();
-  }
-};
-
-} // namespace
-} // namespace nv
 } // namespace mlir
 
 void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
@@ -404,31 +304,8 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
                                                           /*benefit=*/2);
 }
 
-void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
-    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
-  using namespace mlir;
-  MLIRContext *context = patterns.getContext();
-  patterns.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
-               nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
-               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
-  // Give the following patterns higher benefit to prevail over the default one.
-  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
-                                                          /*benefit=*/2);
-}
-
 void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
-    mlir::SPIRVTypeConverter &typeConverter, bool useNVTypes) {
-  if (useNVTypes) {
-    typeConverter.addConversion([](gpu::MMAMatrixType type) {
-      ArrayRef<int64_t> retTypeShape = type.getShape();
-      Type elementType = type.getElementType();
-      return spirv::CooperativeMatrixNVType::get(
-          elementType, spirv::Scope::Subgroup, retTypeShape[0],
-          retTypeShape[1]);
-    });
-    return;
-  }
-
+    mlir::SPIRVTypeConverter &typeConverter) {
   typeConverter.addConversion([](gpu::MMAMatrixType type) {
     ArrayRef<int64_t> retTypeShape = type.getShape();
     Type elementType = type.getElementType();
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index f24da2ca5c3f24..52b4380ed27f7c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -37,7 +37,7 @@ static LogicalResult verifyCastOp(Operation *op,
   auto [operandElemTy, resultElemTy] =
       TypeSwitch<Type, TypePair>(operandType)
           .Case<VectorType, spirv::CooperativeMatrixType,
-                spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType>(
+                spirv::JointMatrixINTELType>(
               [resultType](auto concreteOperandTy) -> TypePair {
                 if (auto concreteResultTy =
                         dyn_cast<decltype(concreteOperandTy)>(resultType)) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index c8b274ceec3e59..d532d466334a56 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -136,156 +136,4 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixLength
-//===----------------------------------------------------------------------===//
-
-LogicalResult NVCooperativeMatrixLengthOp::verify() {
-  if (!isa<CooperativeMatrixNVType>(getCooperativeMatrixType())) {
-    return emitOpError(
-               "type attribute must be a '!spirv.NV.coopmatrix' type, found ")
-           << getCooperativeMatrixType() << " instead";
-  }
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
-
-ParseResult NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
-                                             OperationState &result) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
-  Type strideType = parser.getBuilder().getIntegerType(32);
-  Type columnMajorType = parser.getBuilder().getIntegerType(1);
-  Type ptrType;
-  Type elementType;
-  if (parser.parseOperandList(operandInfo, 3) ||
-      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
-      parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
-    return failure();
-  }
-  if (parser.resolveOperands(operandInfo,
-                             {ptrType, strideType, columnMajorType},
-                             parser.getNameLoc(), result.operands)) {
-    return failure();
-  }
-
-  result.addTypes(elementType);
-  return success();
-}
-
-void NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
-  printer << " " << getPointer() << ", " << getStride() << ", "
-          << getColumnmajor();
-  // Print optional memory access attribute.
-  if (auto memAccess = getMemoryAccess())
-    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
-  printer << " : " << getPointer().getType() << " as " << getType();
-}
-
-static LogicalResult
-verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) {
-  Type pointeeType = llvm::cast<PointerType>(pointer).getPointeeType();
-  if (!llvm::isa<ScalarType>(pointeeType) &&
-      !llvm::isa<VectorType>(pointeeType))
-    return op->emitError(
-               "Pointer must point to a scalar or vector type but provided ")
-           << pointeeType;
-  StorageClass storage = llvm::cast<PointerType>(pointer).getStorageClass();
-  if (storage != StorageClass::Workgroup &&
-      storage != StorageClass::StorageBuffer &&
-      storage != StorageClass::PhysicalStorageBuffer)
-    return op->emitError(
-               "Pointer storage class must be Workgroup, StorageBuffer or "
-               "PhysicalStorageBufferEXT but provided ")
-           << stringifyStorageClass(storage);
-  return success();
-}
-
-LogicalResult NVCooperativeMatrixLoadOp::verify() {
-  return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
-                                          getResult().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixStore
-//===----------------------------------------------------------------------===//
-
-ParseResult NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
-                                              OperationState &result) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
-  Type strideType = parser.getBuilder().getIntegerType(32);
-  Type columnMajorType = parser.getBuilder().getIntegerType(1);
-  Type ptrType;
-  Type elementType;
-  if (parser.parseOperandList(operandInfo, 4) ||
-      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
-      parser.parseType(ptrType) || parser.parseComma() ||
-      parser.parseType(elementType)) {
-    return failure();
-  }
-  if (parser.resolveOperands(
-          operandInfo, {ptrType, elementType, strideType, columnMajorType},
-          parser.getNameLoc(), result.operands)) {
-    return failure();
-  }
-
-  return success();
-}
-
-void NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
-  printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
-          << ", " << getColumnmajor();
-  // Print optional memory access attribute.
-  if (auto memAccess = getMemoryAccess())
-    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
-  printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
-}
-
-LogicalResult NVCooperativeMatrixStoreOp::verify() {
-  return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
-                                          getObject().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixMulAdd
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verifyCoopMatrixMulAddNV(NVCooperativeMatrixMulAddOp op) {
-  if (op.getC().getType() != op.getResult().getType())
-    return op.emitOpError("result and third operand must have the same type");
-  auto typeA = llvm::cast<CooperativeMatrixNVType>(op.getA().getType());
-  auto typeB = llvm::cast<CooperativeMatrixNVType>(op.getB().getType());
-  auto typeC = llvm::cast<CooperativeMatrixNVType>(op.getC().getType());
-  auto typeR = llvm::cast<CooperativeMatrixNVType>(op.getResult().getType());
-  if (typeA.getRows() != typeR.getRows() ||
-      typeA.getColumns() != typeB.getRows() ||
-      typeB.getColumns() != typeR.getColumns())
-    return op.emitOpError("matrix size must match");
-  if (typeR.getScope() != typeA.getScope() ||
-      typeR.getScope() != typeB.getScope() ||
-      typeR.getScope() != typeC.getScope())
-    return op.emitOpError("matrix scope must match");
-  auto elementTypeA = typeA.getElementType();
-  auto elementTypeB = typeB.getElementType();
-  if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
-    if (llvm::cast<IntegerType>(elementTypeA).getWidth() !=
-        llvm::cast<IntegerType>(elementTypeB).getWidth())
-      return op.emitOpError(
-          "matrix A and B integer element types must be the same bit width");
-  } else if (elementTypeA != elementTypeB) {
-    return op.emitOpError(
-        "matrix A and B non-integer element types must match");
-  }
-  if (typeR.getElementType() != typeC.getElementType())
-    return op.emitOpError("matrix accumulator element type must match");
-  return success();
-}
-
-LogicalResult NVCooperativeMatrixMulAddOp::verify() {
-  return verifyCoopMatrixMulAddNV(*this);
-}
-
 } // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8a68decc5878c8..9d4d1aec36709e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -360,37 +360,6 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
   return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
 }
 
-// nv-cooperative-matrix-type ::=
-//   `!spirv.NV.coopmatrix` `<` rows `x` columns `x` element-type `,` scope `>`
-static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,
-                                         DialectAsmParser &parser) {
-  if (parser.parseLess())
-    return Type();
-
-  SmallVector<int64_t, 2> dims;
-  SMLoc countLoc = parser.getCurrentLocation();
-  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
-    return Type();
-
-  if (dims.size() != 2) {
-    parser.emitError(countLoc, "expected rows and columns size");
-    return Type();
-  }
-
-  auto elementTy = parseAndVerifyType(dialect, parser);
-  if (!elementTy)
-    return Type();
-
-  Scope scope;
-  if (parser.parseComma() ||
-      spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
-    return Type();
-
-  if (parser.parseGreater())
-    return Type();
-  return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
-}
-
 // joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
 // element-type
 //                                                       `,` layout `,` scope`>`
@@ -810,8 +779,6 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
     return parseArrayType(*this, parser);
   if (keyword == "coopmatrix")
     return parseCooperativeMatrixType(*this, parser);
-  if (keyword == "NV.coopmatrix")
-    return parseCooperativeMatrixNVType(*this, parser);
   if (keyword == "jointmatrix")
     return parseJointMatrixType(*this, parser);
   if (keyword == "image")
@@ -917,12 +884,6 @@ static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
      << type.getUse() << ">";
 }
 
-static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
-  os << "NV.coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
-  os << type.getElementType() << ", " << stringifyScope(type.getScope());
-  os << ">";
-}
-
 static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
   os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
   os << type.getElementType() << ", "
@@ -937,10 +898,9 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
 
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
-      .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
-            JointMatrixINTELType, PointerType, RuntimeArrayType, ImageType,
-            SampledImageType, StructType, MatrixType>(
-          [&](auto type) { print(type, os); })
+      .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, PointerType,
+            RuntimeArrayType, ImageType, SampledImageType, StructType,
+            MatrixType>([&](auto type) { print(type, os); })
       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..dc558b878b3b76 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -374,8 +374,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
 
   auto coopElementType =
       llvm::TypeSwitch<Type, Type>(getType())
-          .Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
-                spirv::JointMatrixINTELType>(
+          .Case<spirv::CooperativeMatrixType, spirv::JointMatrixINTELType>(
               [](auto coopType) { return coopType.getElementType(); })
           .Default([](Type) { return nullptr; });
 
@@ -1611,8 +1610,7 @@ LogicalResult spirv::VectorShuffleOp::verify() {
 LogicalResult spirv::MatrixTimesScalarOp::verify() {
   Type elementType =
       llvm::TypeSwitch<Type, Type>(getMatrix().getType())
-          .Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
-                spirv::MatrixType>(
+          .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
               [](auto matrixType) { return matrixType.getElementType(); })
           .Default([](Type) { return nullptr; });
 
@@ -1751,7 +1749,7 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
     return emitError("result type must be a composite type, but provided ")
            << getType();
 
-  if (llvm::isa<spirv::CooperativeMatrixNVType>(cType))
-    return emitError("unsupported composite type  ") << cType;
-  if (llvm::isa<spirv::JointMatrixINTELType>(cType))
+  // Cooperative and joint matrix composites are rejected here.
+  if (llvm::isa<spirv::CooperativeMatrixType, spirv::JointMatrixINTELType>(
+          cType))
     return emitError("unsupported composite type  ") << cType;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index f1bac6490837b9..3f25696aa5eb6e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -95,9 +95,8 @@ bool CompositeType::classof(Type type) {
   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
     return isValid(vectorType);
   return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
-                   spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType,
-                   spirv::MatrixType, spirv::RuntimeArrayType,
-                   spirv::StructType>(type);
+                   spirv::JointMatrixINTELType, spirv::MatrixType,
+                   spirv::RuntimeArrayType, spirv::StructType>(type);
 }
 
 bool CompositeType::isValid(VectorType type) {
@@ -108,8 +107,8 @@ bool CompositeType::isValid(VectorType type) {
 
 Type CompositeType::getElementType(unsigned index) const {
   return TypeSwitch<Type, Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
-            JointMatrixINTELType, RuntimeArrayType, VectorType>(
+      .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType,
+            RuntimeArrayType, VectorType>(
           [](auto type) { return type.getElementType(); })
       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
       .Case<StructType>(
@@ -127,7 +126,7 @@ unsigned CompositeType::getNumElements() const {
     return structType.getNumElements();
   if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
     return vectorType.getNumElements();
-  if (llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType>(*this)) {
+  if (llvm::isa<CooperativeMatrixType>(*this)) {
     llvm_unreachable(
         "invalid to query number of elements of spirv Cooperative Matrix type");
   }
@@ -143,16 +142,16 @@ unsigned CompositeType::getNumElements() const {
 }
 
 bool CompositeType::hasCompileTimeKnownNumElements() const {
-  return !llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType,
-                    JointMatrixINTELType, RuntimeArrayType>(*this);
+  return !llvm::isa<CooperativeMatrixType, JointMatrixINTELType,
+                    RuntimeArrayType>(*this);
 }
 
 void CompositeType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
   TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
-            JointMatrixINTELType, MatrixType, RuntimeArrayType, StructType>(
+      .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
+            RuntimeArrayType, StructType>(
           [&](auto type) { type.getExtensions(extensions, storage); })
       .Case<VectorType>([&](VectorType type) {
         return llvm::cast<ScalarType>(type.getElementType())
@@ -165,8 +164,8 @@ void CompositeType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     std::optional<StorageClass> storage) {
   TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
-            JointMatrixINTELType, MatrixType, RuntimeArrayType, StructType>(
+      .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
+            RuntimeArrayType, StructType>(
           [&](auto type) { type.getCapabilities(capabilities, storage); })
       .Case<VectorType>([&](VectorType type) {
         auto vecSize = getNumElements();
@@ -267,70 +266,6 @@ void CooperativeMatrixType::getCapabilities(
   capabilities.push_back(caps);
 }
 
-//===----------------------------------------------------------------------===//
-// CooperativeMatrixNVType
-//===----------------------------------------------------------------------===//
-
-struct spirv::detail::CooperativeMatrixNVTypeStorage : public TypeStorage {
-  using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
-
-  static CooperativeMatrixNVTypeStorage *
-  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
-    return new (allocator.allocate<CooperativeMatrixNVTypeStorage>())
-        CooperativeMatrixNVTypeStorage(key);
-  }
-
-  bool operator==(const KeyTy &key) const {
-    return key == KeyTy(elementType, scope, rows, columns);
-  }
-
-  CooperativeMatrixNVTypeStorage(const KeyTy &key)
-      : elementType(std::get<0>(key)), rows(std::get<2>(key)),
-        columns(std::get<3>(key)), scope(std::get<1>(key)) {}
-
-  Type elementType;
-  unsigned rows;
-  unsigned columns;
-  Scope scope;
-};
-
-CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
-                                                     Scope scope, unsigned rows,
-                                                     unsigned columns) {
-  return Base::get(elementType.getContext(), elementType, scope, rows, columns);
-}
-
-Type CooperativeMatrixNVType::getElementType() const {
-  return getImpl()->elementType;
-}
-
-Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
-
-unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
-
-unsigned CooperativeMatrixNVType::getColumns() const {
-  return getImpl()->columns;
-}
-
-void CooperativeMatrixNVType::getExtensions(
-    SPIRVType::ExtensionArrayRefVector &extensions,
-    std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
-  static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
-  ArrayRef<Extension> ref(exts, std::size(exts));
-  extensions.push_back(ref);
-}
-
-void CooperativeMatrixNVType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType())
-      .getCapabilities(capabilities, storage);
-  static const Capability caps[] = {Capability::CooperativeMatrixNV};
-  ArrayRef<Capability> ref(caps, std::size(caps));
-  capabilities.push_back(ref);
-}
-
 //===----------------------------------------------------------------------===//
 // JointMatrixType
 //===----------------------------------------------------------------------===//
@@ -1312,7 +1247,7 @@ void MatrixType::getCapabilities(
 //===----------------------------------------------------------------------===//
 
 void SPIRVDialect::registerTypes() {
-  addTypes<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType, ImageType,
-           JointMatrixINTELType, MatrixType, PointerType, RuntimeArrayType,
-           SampledImageType, StructType>();
+  addTypes<ArrayType, CooperativeMatrixType, ImageType, JointMatrixINTELType,
+           MatrixType, PointerType, RuntimeArrayType, SampledImageType,
+           StructType>();
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 954aaa98c32998..a678124bf48322 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -165,7 +165,6 @@ LogicalResult spirv::Deserializer::processInstruction(
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
   case spirv::Opcode::OpTypeCooperativeMatrixKHR:
-  case spirv::Opcode::OpTypeCooperativeMatrixNV:
     return processType(opcode, operands);
   case spirv::Opcode::OpTypeForwardPointer:
     return processTypeForwardPointer(operands);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 89e2e7ad52fa7d..948dcfb4885b3f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -778,8 +778,6 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     return processArrayType(operands);
   case spirv::Opcode::OpTypeCooperativeMatrixKHR:
     return processCooperativeMatrixTypeKHR(operands);
-  case spirv::Opcode::OpTypeCooperativeMatrixNV:
-    return processCooperativeMatrixTypeNV(operands);
   case spirv::Opcode::OpTypeFunction:
     return processFunctionType(operands);
   case spirv::Opcode::OpTypeJointMatrixINTEL:
@@ -955,37 +953,6 @@ LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
   return success();
 }
 
-LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(
-    ArrayRef<uint32_t> operands) {
-  if (operands.size() != 5) {
-    return emitError(unknownLoc, "OpTypeCooperativeMatrixNV must have element "
-                                 "type and row x column parameters");
-  }
-
-  Type elementTy = getType(operands[1]);
-  if (!elementTy) {
-    return emitError(unknownLoc,
-                     "OpTypeCooperativeMatrixNV references undefined <id> ")
-           << operands[1];
-  }
-
-  std::optional<spirv::Scope> scope =
-      spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
-  if (!scope) {
-    return emitError(
-               unknownLoc,
-               "OpTypeCooperativeMatrixNV references undefined scope <id> ")
-           << operands[2];
-  }
-
-  unsigned rows = getConstantInt(operands[3]).getInt();
-  unsigned columns = getConstantInt(operands[4]).getInt();
-
-  typeMap[operands[0]] =
-      spirv::CooperativeMatrixNVType::get(elementTy, *scope, rows, columns);
-  return success();
-}
-
 LogicalResult
 spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
   if (operands.size() != 6) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 9e9a16456cc102..08395dd4cf522f 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -632,26 +632,6 @@ LogicalResult Serializer::prepareBasicType(
     return success();
   }
 
-  if (auto cooperativeMatrixType =
-          dyn_cast<spirv::CooperativeMatrixNVType>(type)) {
-    uint32_t elementTypeID = 0;
-    if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
-                               elementTypeID, serializationCtx))) {
-      return failure();
-    }
-    typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
-    auto getConstantOp = [&](uint32_t id) {
-      auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
-      return prepareConstantInt(loc, attr);
-    };
-    llvm::append_values(
-        operands, elementTypeID,
-        getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
-        getConstantOp(cooperativeMatrixType.getRows()),
-        getConstantOp(cooperativeMatrixType.getColumns()));
-    return success();
-  }
-
   if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
     uint32_t elementTypeID = 0;
     if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index f129cc8ce84ec3..477f344b1ae5f4 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" --cse \
+// RUN: mlir-opt --convert-gpu-to-spirv --cse \
 // RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s
 
 module attributes {
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
deleted file mode 100644
index ec7da92704c07c..00000000000000
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
+++ /dev/null
@@ -1,194 +0,0 @@
-// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=true" \
-// RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s
-
-module attributes {
-  gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
-  gpu.module @kernels {
-    // CHECK-LABEL: spirv.func @gpu_wmma_load_op
-    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
-    gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
-      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      %i = arith.constant 16 : index
-      %j = arith.constant 16 : index
-      // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false
-      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] :  !spirv.ptr<f32, StorageBuffer> as !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-      %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK: spirv.Return
-      gpu.return
-    }
-  }
-}
-
-// -----
-
-module attributes {
-  gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
-  gpu.module @kernels {
-    // CHECK-LABEL: spirv.func @gpu_wmma_load_op_transpose
-    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
-    gpu.func @gpu_wmma_load_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
-      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      %i = arith.constant 16 : index
-      %j = arith.constant 16 : index
-      // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true
-      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] :  !spirv.ptr<f32, StorageBuffer> as !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-      %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK: spirv.Return
-      gpu.return
-    }
-  }
-}
-
-// -----
-
-module attributes {
-  gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
-  gpu.module @kernels {
-    // CHECK-LABEL: spirv.func @gpu_wmma_store_op
-    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
-    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-    gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
-      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      %i = arith.constant 16 : index
-      %j = arith.constant 16 : index
-      // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false
-      //  CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,  #spirv.storage_class<StorageBuffer>>
-      // CHECK: spirv.Return
-      gpu.return
-    }
-  }
-}
-
-// -----
-
-module attributes {
-  gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
-  gpu.module @kernels {
-    // CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose
-    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
-    // CHECK-SAME: {{%.*}}: !spirv.NV.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
-    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
-    gpu.func @gpu_wmma_store_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
-      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      %i = arith.constant 16 : index
-      %j = arith.constant 16 : index
-      // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true
-      // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,  #spirv.storage_class<StorageBuffer>>
-      // CHECK: spirv.Return
-      gpu.return
-    }
-  }
-}
-
-// -----
-
-module attributes {
-  gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
-  gpu.module @kernels {
-    // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
-    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-    gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
-      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, !spirv.NV.coopmatrix<16x16xf16, Subgroup> -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-      %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK: spirv.Return
-      gpu.return
-    }
-  }
-}
-
-// -----
-
-module attributes {
-  gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
-  gpu.module @kernels {
-    // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
-    gpu.func @gpu_wmma_constant_op() kernel
-      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      // CHECK: {{%.*}} = spirv.Constant
-      %cst = arith.constant 1.0 : f16
-      // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-      %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK: spirv.Return
-      gpu.return
-    }
-  }
-}
-
-// -----
-
-module attributes {
-  gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
-  gpu.module @kernels {
-    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
-    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-    // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-    gpu.func @gpu_wmma_elementwise_op_default(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
-      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      // CHECK:  {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-      %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK:  {{%.*}} = spirv.FNegate {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-      %D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK:  {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-      %E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK:  {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> to !spirv.NV.coopmatrix<16x16xf32, Subgroup>
-      %F = gpu.subgroup_mma_elementwise extf %E : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
-      // CHECK: spirv.Return
-      gpu.return
-    }
-  }
-}
-
-// -----
-
-module attributes {
-  gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
-  gpu.module @kernels {
-    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
-    // CHECK-SAME:    %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-    // CHECK-SAME:    %[[S:.+]]: f16
-    gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
-      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
-      %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
-      %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK: spirv.Return
-      gpu.return
-    }
-  }
-}
-
-// -----
-
-module attributes {
-  gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
-  gpu.module @kernels {
-    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
-    // CHECK-SAME:    %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-    // CHECK-SAME:    %[[S:.+]]: f16
-    gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
-      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
-      // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
-      // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
-      %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
-      gpu.return
-    }
-  }
-}
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 4f4a72da7c050a..aaee2ccd3cb8c1 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -146,14 +146,6 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou
 
 // -----
 
-func.func @convert_f_to_u_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
-  // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.ConvertFToU %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // spirv.ConvertSToF
 //===----------------------------------------------------------------------===//
@@ -238,14 +230,6 @@ func.func @f_convert_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, M
 
 // -----
 
-func.func @f_convert_coop_matrix_nv(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
-  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
-  %0 = spirv.FConvert %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
-  spirv.Return
-}
-
-// -----
-
 func.func @f_convert_vector(%arg0 : f32) -> f32 {
   // expected-error @+1 {{expected the different bit widths for operand type and result type, but provided 'f32' and 'f32'}}
   %0 = spirv.FConvert %arg0 : f32 to f32
@@ -254,14 +238,6 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {
 
 // -----
 
-func.func @f_convert_coop_matrix_to_nv_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>) {
-  // expected-error @+1 {{incompatible operand and result types}}
-  %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
-  spirv.Return
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // spirv.SConvert
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index b10677f0f5f99f..3fc8dfb2767d1e 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -32,13 +32,6 @@ func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<
   return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
 }
 
-// CHECK-LABEL: func @composite_construct_coopmatrix_nv
-func.func @composite_construct_coopmatrix_nv(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
-  // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-}
-
 // -----
 
 func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
@@ -75,22 +68,6 @@ func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32
 
 // -----
 
-func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
-  // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
-  %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-}
-
-// -----
-
-func.func @composite_construct_NV.coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
-  // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
-  %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-}
-
-// -----
-
 func.func @composite_construct_array(%arg0: f32) -> !spirv.array<4xf32> {
   // expected-error @+1 {{expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
   %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.array<4xf32>
@@ -143,14 +120,6 @@ func.func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 {
 
 // -----
 
-func.func @composite_extract_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) -> f32 {
-  // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  return %0 : f32
-}
-
-// -----
-
 func.func @composite_extract_no_ssa_operand() -> () {
   // expected-error @+1 {{expected SSA operand}}
   %0 = spirv.CompositeExtract [4 : i32, 1 : i32] : !spirv.array<4x!spirv.array<4xf32>>
@@ -271,14 +240,6 @@ func.func @composite_insert_struct(%arg0: !spirv.struct<(!spirv.array<4xf32>, f3
 
 // -----
 
-func.func @composite_insert_NV.coopmatrix(%arg0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.NV.coopmatrix<8x16xi32, Subgroup> {
-  // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  return %0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-}
-
-// -----
-
 func.func @composite_insert_no_indices(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
   // expected-error @+1 {{expected at least one index}}
   %0 = spirv.CompositeInsert %arg1, %arg0[] : f32 into !spirv.array<4xf32>
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 445ab8a48d3ce6..d3e1dbc229ef99 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -13,14 +13,6 @@ spirv.func @cooperative_matrix_length() -> i32 "None" {
 
 // -----
 
-spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
-  // expected-error @+1 {{'cooperative_matrix_type' failed to satisfy constraint: type attribute of any SPIR-V cooperative matrix type}}
-  %0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.ReturnValue %0 : i32
-}
-
-// -----
-
 // CHECK-LABEL: @cooperative_matrix_load
 spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
@@ -118,24 +110,6 @@ spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageB
 
 // -----
 
-spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
-  // expected-error @+1 {{expected '<'}}
-  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, :
-    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
-  // expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}}
-  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
-    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
 spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
   // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <MakePointerAvailable> :
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index f52666af280e4b..372fcc6e514b97 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -9,10 +9,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   }
 
   // CHECK-LABEL: @matrix_times_scalar_2
-  spirv.func @matrix_times_scalar_2(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" {
-    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
-    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
-    spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
+  spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16
+    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16
+    spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
   }
 
   // CHECK-LABEL: @matrix_transpose_1
diff --git a/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir
deleted file mode 100644
index 43cbf61b60ef0b..00000000000000
--- a/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir
+++ /dev/null
@@ -1,177 +0,0 @@
-// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
-
-//===----------------------------------------------------------------------===//
-// NV.CooperativeMatrix
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: @cooperative_matrix_load
-spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_load_memaccess
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
-spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_store
-spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_store_memaccess
-spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_length
-spirv.func @cooperative_matrix_length() -> i32 "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.ReturnValue %0 : i32
-}
-
-// CHECK-LABEL: @cooperative_matrix_muladd
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_add
-spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_sub
-spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_sdiv
-spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_udiv
-spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_fadd
-spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_fsub
-spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_fdiv
-spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-// CHECK-LABEL: @cooperative_matrix_access_chain
-spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
-  %0 = spirv.Constant 0: i32
-  // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
-  %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
-  spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Workgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{matrix A and B non-integer element types must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{matrix A and B integer element types must be the same bit width}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, !spirv.NV.coopmatrix<16x8xsi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // expected-error @+1 {{Pointer must point to a scalar or vector type}}
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
-  // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixLength' op type attribute must be a '!spirv.NV.coopmatrix'}}
-  %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
-  spirv.ReturnValue %0 : i32
-}
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 722e4434aeaf9f..6f6ce1202d1704 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -797,7 +797,7 @@ spirv.module Logical GLSL450 {
 }
 
 //===----------------------------------------------------------------------===//
-// spirv.SpecConstantComposite (spirv.NV.coopmatrix)
+// spirv.SpecConstantComposite (spirv.coopmatrix)
 //===----------------------------------------------------------------------===//
 
 // -----
@@ -805,7 +805,7 @@ spirv.module Logical GLSL450 {
 spirv.module Logical GLSL450 {
   spirv.SpecConstant @sc1 = 1.5 : f32
   // expected-error @+1 {{unsupported composite type}}
-  spirv.SpecConstantComposite @scc (@sc1) : !spirv.NV.coopmatrix<8x16xf32, Device>
+  spirv.SpecConstantComposite @scc (@sc1) : !spirv.coopmatrix<8x16xf32, Device, MatrixA>
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index e10a6fc77e8566..05ab91b6db6bd9 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -479,25 +479,6 @@ func.func private @use_not_integer(!spirv.coopmatrix<8x8xi32, Subgroup, Subgroup
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// NV.CooperativeMatrix
-//===----------------------------------------------------------------------===//
-
-// CHECK: func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>)
-func.func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> ()
-
-// -----
-
-// expected-error @+1 {{expected ','}}
-func.func private @missing_scope(!spirv.NV.coopmatrix<8x16xi32>) -> ()
-
-// -----
-
-// expected-error @+1 {{expected rows and columns size}}
-func.func private @missing_count(!spirv.NV.coopmatrix<8xi32, Subgroup>) -> ()
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // Matrix
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index af8f41a30d24fc..b52c3f4aa2f117 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -23,10 +23,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   }
 
   // CHECK-LABEL: @matrix_times_scalar_3
-  spirv.func @matrix_times_scalar_3(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" {
-    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
-    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
-    spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
+  spirv.func @matrix_times_scalar_3(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> "None" {
+    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+    spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
   }
 
   // CHECK-LABEL: @matrix_transpose_1
diff --git a/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir
deleted file mode 100644
index 2eec99f72691cc..00000000000000
--- a/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir
+++ /dev/null
@@ -1,102 +0,0 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
-  // CHECK-LABEL: @cooperative_matrix_load
-  spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-    %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_load_memaccess
-  spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_store
-  spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
-    // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-    spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_store_memaccess
-  spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-    // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_length
-  spirv.func @cooperative_matrix_length() -> i32 "None" {
-    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    spirv.ReturnValue %0 : i32
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_muladd
-  spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-    %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_add
-  spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_sub
-  spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_sdiv
-  spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_udiv
-  spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_fadd
-  spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-    %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_fsub
-  spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-    %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_fdiv
-  spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-    %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @cooperative_matrix_access_chain
-  spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
-    %0 = spirv.Constant 0: i32
-    // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
-    %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
-    spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
-  }
-}



More information about the Mlir-commits mailing list