[Mlir-commits] [mlir] [mlir][spirv] Fix remaining coop matrix verification corner cases (PR #66137)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Sep 12 13:08:47 PDT 2023
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/66137:
- Check `MakePointer*` load/store attribute values.
- Support coop matrix types in `MatrixTimesScalar` verification.
- Add test cases for all the remaining ops that accept coop matrix types.
- Split NV and KHR tests.
>From dc00e29f35b9b4b4114d6a5a42b903beaf195e9f Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 12 Sep 2023 16:06:07 -0400
Subject: [PATCH] [mlir][spirv] Fix remaining coop matrix verification corner
cases
- Check `MakePointer*` load/store attribute values.
- Support coop matrix types in `MatrixTimesScalar` verification.
- Add test cases for all the remaining ops that accept coop matrix types.
- Split NV and KHR tests.
---
.../mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td | 5 +-
.../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 39 ++-
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 23 +-
...s.mlir => khr-cooperative-matrix-ops.mlir} | 233 +++++++++---------
.../SPIRV/IR/nv-cooperative-matrix-ops.mlir | 177 +++++++++++++
5 files changed, 334 insertions(+), 143 deletions(-)
rename mlir/test/Dialect/SPIRV/IR/{cooperative-matrix-ops.mlir => khr-cooperative-matrix-ops.mlir} (68%)
create mode 100644 mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index a21fc0ce2f9299c..a055cadc756a7e6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -75,10 +75,9 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
let summary = "Scale a floating-point matrix.";
let description = [{
- Result Type must be an OpTypeMatrix whose Column Type is a vector of
- floating-point type.
+ Result Type must be a matrix type with a float component type.
- The type of Matrix must be the same as Result Type. Each component in
+ The type of Matrix must be the same as Result Type. Each component in
each column in Matrix is multiplied by Scalar.
Scalar must have the same type as the Component Type in Result Type.
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index d43f7a1823e912b..c8b274ceec3e59d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -20,9 +20,6 @@
using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
static LogicalResult
verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
@@ -35,13 +32,31 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
<< pointeeType;
}
- // The 'Aligned' memory operand requires an alignment literal to follow, which
- // needs to be implemented on the level of op parsing and (de-)serialization.
- // TODO: Consider adding support for this attribute value.
- if (memoryOperand &&
- spirv::bitEnumContainsAll(memoryOperand.getValue(),
- spirv::MemoryAccess::Aligned)) {
- return op->emitOpError("has unhandled memory operand 'Aligned'");
+ if (memoryOperand) {
+ spirv::MemoryAccess operandSet = memoryOperand.getValue();
+
+ if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) &&
+ spirv::bitEnumContainsAll(operandSet,
+ spirv::MemoryAccess::MakePointerAvailable)) {
+ return op->emitOpError(
+ "not compatible with memory operand 'MakePointerAvailable'");
+ }
+
+ if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) &&
+ spirv::bitEnumContainsAll(operandSet,
+ spirv::MemoryAccess::MakePointerVisible)) {
+ return op->emitOpError(
+ "not compatible with memory operand 'MakePointerVisible'");
+ }
+
+ // The 'Aligned' memory operand requires an alignment literal to follow,
+ // which needs to be implemented on the level of op parsing and
+ // (de-)serialization.
+ // TODO: Consider adding support for this attribute value.
+ if (spirv::bitEnumContainsAll(operandSet,
+ spirv::MemoryAccess::Aligned)) {
+ return op->emitOpError("has unhandled memory operand 'Aligned'");
+ }
}
// TODO: Verify the memory object behind the pointer:
@@ -51,6 +66,10 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
LogicalResult KHRCooperativeMatrixLoadOp::verify() {
return verifyCoopMatrixAccess(*this, getPointer().getType(),
getResult().getType(), getMemoryOperandAttr());
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 6ebd8515caf037d..6cd75ee6d9cba48 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -34,6 +34,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
#include <cassert>
#include <numeric>
#include <optional>
@@ -1604,19 +1605,19 @@ LogicalResult spirv::VectorShuffleOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::MatrixTimesScalarOp::verify() {
- if (auto inputCoopmat = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(
- getMatrix().getType())) {
- if (inputCoopmat.getElementType() != getScalar().getType())
- return emitError("input matrix components' type and scaling value must "
- "have the same type");
- return success();
- }
+ Type elementType =
+ llvm::TypeSwitch<Type, Type>(getMatrix().getType())
+ .Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
+ spirv::MatrixType>(
+ [](auto matrixType) { return matrixType.getElementType(); })
+ .Default([](Type) { return Type(); });
+
+ assert(elementType && "Unhandled type");
// Check that the scalar type is the same as the matrix element type.
- auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
- if (getScalar().getType() != inputMatrix.getElementType())
- return emitError("input matrix components' type and scaling value must "
- "have the same type");
+ if (getScalar().getType() != elementType)
+ return emitOpError("input matrix components' type and scaling value must "
+ "have the same type");
return success();
}
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
similarity index 68%
rename from mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
rename to mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 3adcd711f74a8f8..445ab8a48d3ce64 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
-// CooperativeMatrix (KHR)
+// CooperativeMatrix (KHR) extension ops.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @cooperative_matrix_length
@@ -136,6 +136,15 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuf
// -----
+spirv.func @cooperative_matrix_load_bad_operand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+ // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <MakePointerAvailable> :
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
@@ -184,6 +193,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, Stor
// -----
+spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+ %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+ // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}}
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <MakePointerVisible> :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+ spirv.Return
+}
+
+// -----
+
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
// expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
@@ -406,177 +425,153 @@ spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x1
// -----
//===----------------------------------------------------------------------===//
-// NV.CooperativeMatrix
+// Standard ops that can be used with CooperativeMatrix types
//===----------------------------------------------------------------------===//
-// CHECK-LABEL: @cooperative_matrix_load
-spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- spirv.Return
-}
+!matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+!matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
-// CHECK-LABEL: @cooperative_matrix_load_memaccess
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
+!matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+!matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB>
-// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
-spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
+// These tests are kept in the same order as the list of compatible ops in the
+// SPV_KHR_cooperative_matrix extension spec.
-// CHECK-LABEL: @cooperative_matrix_store
-spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
- spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+// CHECK-LABEL: @snegate
+spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" {
+ // CHECK: spirv.SNegate {{%.*}} : !spirv.coopmatrix
+ // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix
+ %p = spirv.SNegate %a : !matA_i32
+ %q = spirv.SNegate %b : !matB_i32
spirv.Return
}
-// CHECK-LABEL: @cooperative_matrix_store_memaccess
-spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fnegate
+spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" {
+ // CHECK: spirv.FNegate {{%.*}} : !spirv.coopmatrix
+ // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
+ %p = spirv.FNegate %a : !matA_f32
+ %q = spirv.FNegate %b : !matB_f32
spirv.Return
}
-// CHECK-LABEL: @cooperative_matrix_length
-spirv.func @cooperative_matrix_length() -> i32 "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.ReturnValue %0 : i32
-}
-
-// CHECK-LABEL: @cooperative_matrix_muladd
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @iadd
+spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
+ // CHECK: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ %p = spirv.IAdd %a, %a : !matA_i32
+ %q = spirv.IAdd %b, %b : !matB_i32
spirv.Return
}
-// CHECK-LABEL: @cooperative_matrix_add
-spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fadd
+spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" {
+ // CHECK: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ %p = spirv.FAdd %a, %a : !matA_f32
+ %q = spirv.FAdd %b, %b : !matB_f32
spirv.Return
}
-// CHECK-LABEL: @cooperative_matrix_sub
-spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @isub
+spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
+ // CHECK: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ %p = spirv.ISub %a, %a : !matA_i32
+ %q = spirv.ISub %b, %b : !matB_i32
spirv.Return
}
-// CHECK-LABEL: @cooperative_matrix_sdiv
-spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fsub
+spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" {
+ // CHECK: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ %p = spirv.FSub %a, %a : !matA_f32
+ %q = spirv.FSub %b, %b : !matB_f32
spirv.Return
}
-// CHECK-LABEL: @cooperative_matrix_udiv
-spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fmul
+spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" {
+ // CHECK: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ %p = spirv.FMul %a, %a : !matA_f32
+ %q = spirv.FMul %b, %b : !matB_f32
spirv.Return
}
-// CHECK-LABEL: @cooperative_matrix_fadd
-spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @imul
+spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
+ // CHECK: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ %p = spirv.IMul %a, %a : !matA_i32
+ %q = spirv.IMul %b, %b : !matB_i32
spirv.Return
}
-// CHECK-LABEL: @cooperative_matrix_fsub
-spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @fdiv
+spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" {
+ // CHECK: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ %p = spirv.FDiv %a, %a : !matA_f32
+ %q = spirv.FDiv %b, %b : !matB_f32
spirv.Return
}
-// CHECK-LABEL: @cooperative_matrix_fdiv
-spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @sdiv
+spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" {
+ // CHECK: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ %p = spirv.SDiv %a, %a : !matA_i32
+ %q = spirv.SDiv %b, %b : !matB_i32
spirv.Return
}
-// -----
-
-// CHECK-LABEL: @cooperative_matrix_access_chain
-spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
- %0 = spirv.Constant 0: i32
- // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
- %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
- spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @udiv
+spirv.func @udiv(%a: !matA_i32, %b: !matB_i32) "None" {
+ // CHECK: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+ %p = spirv.UDiv %a, %a : !matA_i32
+ %q = spirv.UDiv %b, %b : !matB_i32
spirv.Return
}
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Workgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @matrix_times_scalar
+spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
+ // CHECK: spirv.MatrixTimesScalar {{%.*}} : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, f32
+ %p = spirv.MatrixTimesScalar %a, %b : !matA_f32, f32
spirv.Return
}
// -----
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // expected-error @+1 {{matrix A and B non-integer element types must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- spirv.Return
-}
-
-// -----
+// For binary arithmetic instructions with coop matrix operands, the types must
+// match.
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // expected-error @+1 {{matrix A and B integer element types must be the same bit width}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, !spirv.NV.coopmatrix<16x8xsi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
+ %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
+ // expected-error @+1 {{op requires the same type for all operands and results}}
+ %q = "spirv.IAdd"(%a, %b) :
+ (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
+ -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
spirv.Return
}
// -----
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // expected-error @+1 {{Pointer must point to a scalar or vector type}}
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
+ %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
+ // expected-error @+1 {{op requires the same type for all operands and results}}
+ %q = "spirv.FAdd"(%a, %b) :
+ (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
+ -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>
spirv.Return
}
// -----
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
- // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, %b: f16) "None" {
+ // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
+ %p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
spirv.Return
}
-
-// -----
-
-spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
- // expected-error @+1 {{'spirv.NV.CooperativeMatrixLength' op type attribute must be a '!spirv.NV.coopmatrix'}}
- %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
- spirv.ReturnValue %0 : i32
-}
diff --git a/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir
new file mode 100644
index 000000000000000..43cbf61b60ef0b6
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir
@@ -0,0 +1,177 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// NV.CooperativeMatrix
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @cooperative_matrix_load
+spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_memaccess
+spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
+spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_store
+spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
+ // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+ spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_store_memaccess
+spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
+ // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_length
+spirv.func @cooperative_matrix_length() -> i32 "None" {
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: @cooperative_matrix_muladd
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_add
+spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_sub
+spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_sdiv
+spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_udiv
+spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_fadd
+spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_fsub
+spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ spirv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_fdiv
+spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ spirv.Return
+}
+
+// -----
+
+// CHECK-LABEL: @cooperative_matrix_access_chain
+spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
+ %0 = spirv.Constant 0: i32
+ // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+ %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+ spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+ // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+ // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+ // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Workgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+ // expected-error @+1 {{matrix A and B non-integer element types must match}}
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+ // expected-error @+1 {{matrix A and B integer element types must be the same bit width}}
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, !spirv.NV.coopmatrix<16x8xsi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // expected-error @+1 {{Pointer must point to a scalar or vector type}}
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
+ // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
+ // expected-error @+1 {{'spirv.NV.CooperativeMatrixLength' op type attribute must be a '!spirv.NV.coopmatrix'}}
+ %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+ spirv.ReturnValue %0 : i32
+}
More information about the Mlir-commits
mailing list