[Mlir-commits] [mlir] [mlir][spirv] Fix remaining coop matrix verification corner cases (PR #66137)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 12 13:09:50 PDT 2023


llvmbot wrote:

@llvm/pr-subscribers-mlir-spirv

<details>
<summary>Changes</summary>

- Check `MakePointer*` load/store attribute values.
- Support coop matrix types in `MatrixTimesScalar` verification.
- Add test cases for all the remaining ops that accept coop matrix types.
- Split NV and KHR tests.
--

Patch is 33.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66137.diff

5 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td (+2-3) 
- (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+29-10) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+12-11) 
- (renamed) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+114-119) 
- (added) mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir (+177) 


<pre>
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index a21fc0ce2f9299c..a055cadc756a7e6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -75,10 +75,9 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
   let summary = "Scale a floating-point matrix.";
 
   let description = [{
-    Result Type must be an OpTypeMatrix whose Column Type is a vector of
-    floating-point type.
+    Result Type must be a matrix type with a float component type.
 
-     The type of Matrix must be the same as Result Type. Each component in
+    The type of Matrix must be the same as Result Type. Each component in
     each column in Matrix is multiplied by Scalar.
 
     Scalar must have the same type as the Component Type in Result Type.
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index d43f7a1823e912b..c8b274ceec3e59d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -20,9 +20,6 @@
 using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
 
 static LogicalResult
 verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
@@ -35,13 +32,31 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
            << pointeeType;
   }
 
-  // The 'Aligned' memory operand requires an alignment literal to follow, which
-  // needs to be implemented on the level of op parsing and (de-)serialization.
-  // TODO: Consider adding support for this attribute value.
-  if (memoryOperand &&
-      spirv::bitEnumContainsAll(memoryOperand.getValue(),
-                                spirv::MemoryAccess::Aligned)) {
-    return op->emitOpError("has unhandled memory operand 'Aligned'");
+  if (memoryOperand) {
+    spirv::MemoryAccess operandSet = memoryOperand.getValue();
+
+    if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) &&
+        spirv::bitEnumContainsAll(operandSet,
+                                  spirv::MemoryAccess::MakePointerAvailable)) {
+      return op->emitOpError(
+          "not compatible with memory operand 'MakePointerAvailable'");
+    }
+
+    if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) &&
+        spirv::bitEnumContainsAll(operandSet,
+                                  spirv::MemoryAccess::MakePointerVisible)) {
+      return op->emitOpError(
+          "not compatible with memory operand 'MakePointerVisible'");
+    }
+
+    // The 'Aligned' memory operand requires an alignment literal to follow,
+    // which needs to be implemented on the level of op parsing and
+    // (de-)serialization.
+    // TODO: Consider adding support for this attribute value.
+    if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
+                                  spirv::MemoryAccess::Aligned)) {
+      return op->emitOpError("has unhandled memory operand 'Aligned'");
+    }
   }
 
   // TODO: Verify the memory object behind the pointer:
@@ -51,6 +66,10 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
   return verifyCoopMatrixAccess(*this, getPointer().getType(),
                                 getResult().getType(), getMemoryOperandAttr());
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 6ebd8515caf037d..6cd75ee6d9cba48 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include <cassert>
 #include <numeric>
 #include <optional>
@@ -1604,19 +1605,19 @@ LogicalResult spirv::VectorShuffleOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::MatrixTimesScalarOp::verify() {
-  if (auto inputCoopmat = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(
-          getMatrix().getType())) {
-    if (inputCoopmat.getElementType() != getScalar().getType())
-      return emitError("input matrix components' type and scaling value must "
-                       "have the same type");
-    return success();
-  }
+  Type elementType =
+      llvm::TypeSwitch<Type, Type>(getMatrix().getType())
+          .Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
+                spirv::MatrixType>(
+              [](auto matrixType) { return matrixType.getElementType(); })
+          .Default([](Type) { return nullptr; });
+
+  assert(elementType && "Unhandled type");
 
   // Check that the scalar type is the same as the matrix element type.
-  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
-  if (getScalar().getType() != inputMatrix.getElementType())
-    return emitError("input matrix components' type and scaling value must "
-                     "have the same type");
+  if (getScalar().getType() != elementType)
+    return emitOpError("input matrix components' type and scaling value must "
+                       "have the same type");
 
   return success();
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
similarity index 68%
rename from mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
rename to mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 3adcd711f74a8f8..445ab8a48d3ce64 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// CooperativeMatrix (KHR)
+// CooperativeMatrix (KHR) extension ops.
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @cooperative_matrix_length
@@ -136,6 +136,15 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuf
 
 // -----
 
+spirv.func @cooperative_matrix_load_bad_operand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <MakePointerAvailable> :
+    !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
@@ -184,6 +193,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, Stor
 
 // -----
 
+spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+                                                 %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <MakePointerVisible> :
+    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
@@ -406,177 +425,153 @@ spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x1
 // -----
 
 //===----------------------------------------------------------------------===//
-// NV.CooperativeMatrix
+// Standard ops that can be used with CooperativeMatrix types
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: @cooperative_matrix_load
-spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  spirv.Return
-}
+!matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+!matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
 
-// CHECK-LABEL: @cooperative_matrix_load_memaccess
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
+!matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+!matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB>
 
-// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
-spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
+// These tests are kept in the same order as the list of compatible ops in the
+// SPV_KHR_cooperative_matrix extension spec.
 
-// CHECK-LABEL: @cooperative_matrix_store
-spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+// CHECK-LABEL: @snegate
+spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.SNegate {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix
+  %p = spirv.SNegate %a : !matA_i32
+  %q = spirv.SNegate %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_store_memaccess
-spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fnegate
+spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FNegate %a : !matA_f32
+  %q = spirv.FNegate %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_length
-spirv.func @cooperative_matrix_length() -> i32 "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.ReturnValue %0 : i32
-}
-
-// CHECK-LABEL: @cooperative_matrix_muladd
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @iadd
+spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.IAdd %a, %a : !matA_i32
+  %q = spirv.IAdd %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_add
-spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fadd
+spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FAdd %a, %a : !matA_f32
+  %q = spirv.FAdd %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_sub
-spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @isub
+spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.ISub %a, %a : !matA_i32
+  %q = spirv.ISub %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_sdiv
-spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fsub
+spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FSub %a, %a : !matA_f32
+  %q = spirv.FSub %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_udiv
-spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fmul
+spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FMul %a, %a : !matA_f32
+  %q = spirv.FMul %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fadd
-spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @imul
+spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.IMul %a, %a : !matA_i32
+  %q = spirv.IMul %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fsub
-spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @fdiv
+spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FDiv %a, %a : !matA_f32
+  %q = spirv.FDiv %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fdiv
-spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @sdiv
+spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.SDiv %a, %a : !matA_i32
+  %q = spirv.SDiv %b, %b : !matB_i32
   spirv.Return
 }
 
-// -----
-
-// CHECK-LABEL: @cooperative_matrix_access_chain
-spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
-  %0 = spirv.Constant 0: i32
-  // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
-  %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
-  spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @udiv
+spirv.func @udiv(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.UDiv %a, %a : !matA_i32
+  %q = spirv.UDiv %b, %b : !matB_i32
   spirv.Return
 }
 
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup...
<truncated>
</pre>

</details>

https://github.com/llvm/llvm-project/pull/66137


More information about the Mlir-commits mailing list