[Mlir-commits] [mlir] 4c4bdf0 - [mlir][spirv] Fix remaining coop matrix verification corner cases (#66137)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 12 14:04:06 PDT 2023


Author: Jakub Kuderski
Date: 2023-09-12T17:04:03-04:00
New Revision: 4c4bdf0c3ab855f8ef2834ee869b9091c50ed937

URL: https://github.com/llvm/llvm-project/commit/4c4bdf0c3ab855f8ef2834ee869b9091c50ed937
DIFF: https://github.com/llvm/llvm-project/commit/4c4bdf0c3ab855f8ef2834ee869b9091c50ed937.diff

LOG: [mlir][spirv] Fix remaining coop matrix verification corner cases (#66137)

- Check `MakePointer*` load/store attribute values.
- Support coop matrix types in `MatrixTimesScalar` verification.
- Add test cases for all the remaining ops that accept coop matrix types.
- Split NV and KHR tests.

Added: 
    mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
    mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
    mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Removed: 
    mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir


################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index a21fc0ce2f9299c..a055cadc756a7e6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -75,10 +75,9 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
   let summary = "Scale a floating-point matrix.";
 
   let description = [{
-    Result Type must be an OpTypeMatrix whose Column Type is a vector of
-    floating-point type.
+    Result Type must be a matrix type with a float component type.
 
-     The type of Matrix must be the same as Result Type. Each component in
+    The type of Matrix must be the same as Result Type. Each component in
     each column in Matrix is multiplied by Scalar.
 
     Scalar must have the same type as the Component Type in Result Type.

diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index d43f7a1823e912b..c8b274ceec3e59d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -20,9 +20,6 @@
 using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
 
 static LogicalResult
 verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
@@ -35,13 +32,31 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
            << pointeeType;
   }
 
-  // The 'Aligned' memory operand requires an alignment literal to follow, which
-  // needs to be implemented on the level of op parsing and (de-)serialization.
-  // TODO: Consider adding support for this attribute value.
-  if (memoryOperand &&
-      spirv::bitEnumContainsAll(memoryOperand.getValue(),
-                                spirv::MemoryAccess::Aligned)) {
-    return op->emitOpError("has unhandled memory operand 'Aligned'");
+  if (memoryOperand) {
+    spirv::MemoryAccess operandSet = memoryOperand.getValue();
+
+    if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) &&
+        spirv::bitEnumContainsAll(operandSet,
+                                  spirv::MemoryAccess::MakePointerAvailable)) {
+      return op->emitOpError(
+          "not compatible with memory operand 'MakePointerAvailable'");
+    }
+
+    if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) &&
+        spirv::bitEnumContainsAll(operandSet,
+                                  spirv::MemoryAccess::MakePointerVisible)) {
+      return op->emitOpError(
+          "not compatible with memory operand 'MakePointerVisible'");
+    }
+
+    // The 'Aligned' memory operand requires an alignment literal to follow,
+    // which needs to be implemented on the level of op parsing and
+    // (de-)serialization.
+    // TODO: Consider adding support for this attribute value.
+    if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
+                                  spirv::MemoryAccess::Aligned)) {
+      return op->emitOpError("has unhandled memory operand 'Aligned'");
+    }
   }
 
   // TODO: Verify the memory object behind the pointer:
@@ -51,6 +66,10 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
   return verifyCoopMatrixAccess(*this, getPointer().getType(),
                                 getResult().getType(), getMemoryOperandAttr());

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 6ebd8515caf037d..1f07b0b9e85bff6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include <cassert>
 #include <numeric>
 #include <optional>
@@ -1604,19 +1605,19 @@ LogicalResult spirv::VectorShuffleOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::MatrixTimesScalarOp::verify() {
-  if (auto inputCoopmat = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(
-          getMatrix().getType())) {
-    if (inputCoopmat.getElementType() != getScalar().getType())
-      return emitError("input matrix components' type and scaling value must "
-                       "have the same type");
-    return success();
-  }
+  Type elementType =
+      llvm::TypeSwitch<Type, Type>(getMatrix().getType())
+          .Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
+                spirv::MatrixType>(
+              [](auto matrixType) { return matrixType.getElementType(); })
+          .Default([](Type) { return nullptr; });
+
+  assert(elementType && "Unhandled type");
 
   // Check that the scalar type is the same as the matrix element type.
-  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
-  if (getScalar().getType() != inputMatrix.getElementType())
-    return emitError("input matrix components' type and scaling value must "
-                     "have the same type");
+  if (getScalar().getType() != elementType)
+    return emitOpError("input matrix components' type and scaling value must "
+                       "have the same type");
 
   return success();
 }

diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
similarity index 68%
rename from mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
rename to mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 3adcd711f74a8f8..445ab8a48d3ce64 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// CooperativeMatrix (KHR)
+// CooperativeMatrix (KHR) extension ops.
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @cooperative_matrix_length
@@ -136,6 +136,15 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuf
 
 // -----
 
+spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <MakePointerAvailable> :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
@@ -184,6 +193,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, Stor
 
 // -----
 
+spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+                                                 %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <MakePointerVisible> :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
@@ -406,177 +425,153 @@ spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x1
 // -----
 
 //===----------------------------------------------------------------------===//
-// NV.CooperativeMatrix
+// Standard ops that can be used with CooperativeMatrix types
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: @cooperative_matrix_load
-spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  spirv.Return
-}
+!matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+!matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
 
-// CHECK-LABEL: @cooperative_matrix_load_memaccess
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
+!matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+!matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB>
 
-// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
-spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
+// These tests are kept in the same order as the list of compatible ops in the
+// SPV_KHR_cooperative_matrix extension spec.
 
-// CHECK-LABEL: @cooperative_matrix_store
-spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+// CHECK-LABEL: @snegate
+spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.SNegate {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix
+  %p = spirv.SNegate %a : !matA_i32
+  %q = spirv.SNegate %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_store_memaccess
-spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fnegate
+spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FNegate %a : !matA_f32
+  %q = spirv.FNegate %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_length
-spirv.func @cooperative_matrix_length() -> i32 "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.ReturnValue %0 : i32
-}
-
-// CHECK-LABEL: @cooperative_matrix_muladd
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @iadd
+spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.IAdd %a, %a : !matA_i32
+  %q = spirv.IAdd %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_add
-spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fadd
+spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FAdd %a, %a : !matA_f32
+  %q = spirv.FAdd %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_sub
-spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @isub
+spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.ISub %a, %a : !matA_i32
+  %q = spirv.ISub %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_sdiv
-spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fsub
+spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FSub %a, %a : !matA_f32
+  %q = spirv.FSub %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_udiv
-spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fmul
+spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FMul %a, %a : !matA_f32
+  %q = spirv.FMul %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fadd
-spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @imul
+spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.IMul %a, %a : !matA_i32
+  %q = spirv.IMul %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fsub
-spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @fdiv
+spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FDiv %a, %a : !matA_f32
+  %q = spirv.FDiv %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fdiv
-spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @sdiv
+spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.SDiv %a, %a : !matA_i32
+  %q = spirv.SDiv %b, %b : !matB_i32
   spirv.Return
 }
 
-// -----
-
-// CHECK-LABEL: @cooperative_matrix_access_chain
-spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
-  %0 = spirv.Constant 0: i32
-  // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
-  %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
-  spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @udiv
+spirv.func @udiv(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.UDiv %a, %a : !matA_i32
+  %q = spirv.UDiv %b, %b : !matB_i32
   spirv.Return
 }
 
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Workgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @matrix_times_scalar
+spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
+  // CHECK: spirv.MatrixTimesScalar {{%.*}} : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, f32
+  %p = spirv.MatrixTimesScalar %a, %b : !matA_f32, f32
   spirv.Return
 }
 
 // -----
 
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{matrix A and B non-integer element types must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
+// For binary arithmetic instructions with coop matrix operands, the types must
+// match.
 
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{matrix A and B integer element types must be the same bit width}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, !spirv.NV.coopmatrix<16x8xsi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
+                 %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
+  // expected-error @+1 {{op requires the same type for all operands and results}}
+  %q = "spirv.IAdd"(%a, %b) :
+    (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
+    -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
   spirv.Return
 }
 
 // -----
 
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // expected-error @+1 {{Pointer must point to a scalar or vector type}}
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
+                 %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
+  // expected-error @+1 {{op requires the same type for all operands and results}}
+  %q = "spirv.FAdd"(%a, %b) :
+    (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
+    -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>
   spirv.Return
 }
 
 // -----
 
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
-  // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, %b: f16) "None" {
+  // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
+  %p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
   spirv.Return
 }
-
-// -----
-
-spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixLength' op type attribute must be a '!spirv.NV.coopmatrix'}}
-  %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
-  spirv.ReturnValue %0 : i32
-}

diff --git a/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir
new file mode 100644
index 000000000000000..43cbf61b60ef0b6
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir
@@ -0,0 +1,177 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// NV.CooperativeMatrix
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cooperative_matrix_load
+spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
+  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_memaccess
+spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
+spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_store
+spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
+  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_store_memaccess
+spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
+  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_length
+spirv.func @cooperative_matrix_length() -> i32 "None" {
+  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: @cooperative_matrix_muladd
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_add
+spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_sub
+spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_sdiv
+spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_udiv
+spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_fadd
+spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_fsub
+spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_fdiv
+spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+  spirv.Return
+}
+
+// -----
+
+// CHECK-LABEL: @cooperative_matrix_access_chain
+spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
+  %0 = spirv.Constant 0: i32
+  // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+  %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+  spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Workgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // expected-error @+1 {{matrix A and B non-integer element types must match}}
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // expected-error @+1 {{matrix A and B integer element types must be the same bit width}}
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, !spirv.NV.coopmatrix<16x8xsi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+  // expected-error @+1 {{Pointer must point to a scalar or vector type}}
+  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
+  // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
+  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
+  // expected-error @+1 {{'spirv.NV.CooperativeMatrixLength' op type attribute must be a '!spirv.NV.coopmatrix'}}
+  %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+  spirv.ReturnValue %0 : i32
+}


        


More information about the Mlir-commits mailing list