[Mlir-commits] [mlir] 1d51597 - [mlir][spirv] Add missing NV prefix/suffix for coop matrix
Jakub Kuderski
llvmlistbot at llvm.org
Mon Jul 10 07:08:19 PDT 2023
Author: Jakub Kuderski
Date: 2023-07-10T10:06:24-04:00
New Revision: 1d515978709cd97d818738d39e699bf5d88dedab
URL: https://github.com/llvm/llvm-project/commit/1d515978709cd97d818738d39e699bf5d88dedab
DIFF: https://github.com/llvm/llvm-project/commit/1d515978709cd97d818738d39e699bf5d88dedab.diff
LOG: [mlir][spirv] Add missing NV prefix/suffix for coop matrix
This is in preparation for adding the KHR version of the cooperative
matrix extension, `SPV_KHR_cooperative_matrix`, which comes with
equivalent ops and types. These are not cross-extension compatible,
so it's better to add prefixes/suffixes to the NVIDIA one,
`SPV_NV_cooperative_matrix`, before adding the KHR counterparts.
In the near future, I plan for these two extensions to coexist in
the SPIR-V dialect, but we may want to remove the NV one at some point.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D154799
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
mlir/test/Dialect/SPIRV/IR/types.mlir
mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir
mlir/test/Target/SPIRV/matrix.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 87d10407739485..2b33327dec1fbc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4069,7 +4069,7 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
!interleave(widths, "/") # "-bit signless/unsigned integer">;
def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
-def SPIRV_IsCooperativeMatrixType :
+def SPIRV_IsCooperativeMatrixNVType :
CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">;
def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
def SPIRV_IsJointMatrixType :
@@ -4100,9 +4100,9 @@ def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
"any SPIR-V pointer type">;
def SPIRV_AnyArray : DialectType<SPIRV_Dialect, SPIRV_IsArrayType,
"any SPIR-V array type">;
-def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
- SPIRV_IsCooperativeMatrixType,
- "any SPIR-V cooperative matrix type">;
+def SPIRV_AnyCooperativeMatrixNV : DialectType<SPIRV_Dialect,
+ SPIRV_IsCooperativeMatrixNVType,
+ "any SPIR-V NV cooperative matrix type">;
def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
"any SPIR-V image type">;
def SPIRV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPIRV_IsJointMatrixType,
@@ -4121,21 +4121,21 @@ def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
+ SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
+ SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
SPIRV_AnySampledImage
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
def SPIRV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
-class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
- ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixType,
+class SPIRV_CoopMatrixNVOfType<list<Type> allowedTypes> :
+ ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixNVType,
"::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()",
- "Cooperative Matrix">;
+ "Cooperative Matrix NV">;
class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsJointMatrixType,
@@ -4147,10 +4147,10 @@ class SPIRV_ScalarOrVectorOf<Type type> :
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
- SPIRV_CoopMatrixOfType<[type]>]>;
+ SPIRV_CoopMatrixNVOfType<[type]>]>;
class SPIRV_MatrixOrCoopMatrixOf<Type type> :
- AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixOfType<[type]>]>;
+ AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixNVOfType<[type]>]>;
def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 22ee3fb301a2d1..71c4f7e17bf013 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -13,6 +13,10 @@
#ifndef MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
#define MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
+//===----------------------------------------------------------------------===//
+// SPV_NV_cooperative_matrix extension ops.
+//===----------------------------------------------------------------------===//
+
// -----
def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLength",
@@ -35,7 +39,7 @@ def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLengt
For example:
```
- %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<Subgroup, i32, 8, 16>
+ %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<Subgroup, i32, 8, 16>
```
}];
@@ -111,7 +115,7 @@ def SPIRV_NVCooperativeMatrixLoadOp : SPIRV_NvVendorOp<"CooperativeMatrixLoad",
```
%0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor
- : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<i32, Workgroup, 16, 8>
+ : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<i32, Workgroup, 16, 8>
```
}];
@@ -130,7 +134,7 @@ def SPIRV_NVCooperativeMatrixLoadOp : SPIRV_NvVendorOp<"CooperativeMatrixLoad",
);
let results = (outs
- SPIRV_AnyCooperativeMatrix:$result
+ SPIRV_AnyCooperativeMatrixNV:$result
);
}
@@ -182,7 +186,7 @@ def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAd
```
%0 = spirv.NV.CooperativeMatrixMulAdd %arg0, %arg1, %arg2, :
- !spirv.coopmatrix<Subgroup, i32, 8, 16>
+ !spirv.NV.coopmatrix<Subgroup, i32, 8, 16>
```
}];
@@ -198,13 +202,13 @@ def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAd
];
let arguments = (ins
- SPIRV_AnyCooperativeMatrix:$a,
- SPIRV_AnyCooperativeMatrix:$b,
- SPIRV_AnyCooperativeMatrix:$c
+ SPIRV_AnyCooperativeMatrixNV:$a,
+ SPIRV_AnyCooperativeMatrixNV:$b,
+ SPIRV_AnyCooperativeMatrixNV:$c
);
let results = (outs
- SPIRV_AnyCooperativeMatrix:$result
+ SPIRV_AnyCooperativeMatrixNV:$result
);
}
@@ -247,7 +251,7 @@ def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore"
```
spirv.NV.CooperativeMatrixStore %arg0, %arg2, %arg1, %arg3 :
- !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<Workgroup, i32, 16, 8>
+ !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<Workgroup, i32, 16, 8>
```
}];
@@ -260,7 +264,7 @@ def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore"
let arguments = (ins
SPIRV_AnyPtr:$pointer,
- SPIRV_AnyCooperativeMatrix:$object,
+ SPIRV_AnyCooperativeMatrixNV:$object,
SPIRV_Integer:$stride,
SPIRV_Bool:$columnmajor,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index daa4d61f7103f1..b5b1f5ad4f52f1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -27,7 +27,7 @@ namespace spirv {
namespace detail {
struct ArrayTypeStorage;
-struct CooperativeMatrixTypeStorage;
+struct CooperativeMatrixNVTypeStorage;
struct ImageTypeStorage;
struct JointMatrixTypeStorage;
struct MatrixTypeStorage;
@@ -398,10 +398,10 @@ class StructType
llvm::hash_code
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
-// SPIR-V cooperative matrix type
+// SPIR-V NV cooperative matrix type
class CooperativeMatrixNVType
: public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
- detail::CooperativeMatrixTypeStorage> {
+ detail::CooperativeMatrixNVTypeStorage> {
public:
using Base::Base;
@@ -409,11 +409,11 @@ class CooperativeMatrixNVType
unsigned rows, unsigned columns);
Type getElementType() const;
- /// Return the scope of the cooperative matrix.
+ /// Returns the scope of the matrix.
Scope getScope() const;
- /// return the number of rows of the matrix.
+ /// Returns the number of rows of the matrix.
unsigned getRows() const;
- /// return the number of columns of the matrix.
+ /// Returns the number of columns of the matrix.
unsigned getColumns() const;
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 37e5d77caa49c4..a0bf1300b183a3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -318,9 +318,8 @@ static Type parseArrayType(SPIRVDialect const &dialect,
return ArrayType::get(elementType, count, stride);
}
-// cooperative-matrix-type ::= `!spirv.coopmatrix` `<` element-type ',' scope
-// ','
-// rows ',' columns>`
+// cooperative-matrix-type ::=
+// `!spirv.NV.coopmatrix` `<` rows `x` columns `x` element-type ',' scope>
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
@@ -786,7 +785,7 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
if (keyword == "array")
return parseArrayType(*this, parser);
- if (keyword == "coopmatrix")
+ if (keyword == "NV.coopmatrix")
return parseCooperativeMatrixType(*this, parser);
if (keyword == "jointmatrix")
return parseJointMatrixType(*this, parser);
@@ -891,7 +890,7 @@ static void print(StructType type, DialectAsmPrinter &os) {
}
static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
- os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
+ os << "NV.coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
os << type.getElementType() << ", " << stringifyScope(type.getScope());
os << ">";
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 30fc3e1d11bb11..1599c5bb74ae06 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -203,23 +203,23 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
}
//===----------------------------------------------------------------------===//
-// CooperativeMatrixType
+// CooperativeMatrixNVType
//===----------------------------------------------------------------------===//
-struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
+struct spirv::detail::CooperativeMatrixNVTypeStorage : public TypeStorage {
using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
- static CooperativeMatrixTypeStorage *
+ static CooperativeMatrixNVTypeStorage *
construct(TypeStorageAllocator &allocator, const KeyTy &key) {
- return new (allocator.allocate<CooperativeMatrixTypeStorage>())
- CooperativeMatrixTypeStorage(key);
+ return new (allocator.allocate<CooperativeMatrixNVTypeStorage>())
+ CooperativeMatrixNVTypeStorage(key);
}
bool operator==(const KeyTy &key) const {
return key == KeyTy(elementType, scope, rows, columns);
}
- CooperativeMatrixTypeStorage(const KeyTy &key)
+ CooperativeMatrixNVTypeStorage(const KeyTy &key)
: elementType(std::get<0>(key)), rows(std::get<2>(key)),
columns(std::get<3>(key)), scope(std::get<1>(key)) {}
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
index 829107f2625bed..12b6a2eb94268c 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -11,7 +11,7 @@ module attributes {
%i = arith.constant 16 : index
%j = arith.constant 16 : index
// CHECK: %[[COLMAJOR:.*]] = spirv.Constant false
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer> as !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer> as !spirv.NV.coopmatrix<16x16xf16, Subgroup>
%0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: spirv.Return
gpu.return
@@ -33,7 +33,7 @@ module attributes {
%i = arith.constant 16 : index
%j = arith.constant 16 : index
// CHECK: %[[COLMAJOR:.*]] = spirv.Constant true
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer> as !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer> as !spirv.NV.coopmatrix<16x16xf16, Subgroup>
%0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: spirv.Return
gpu.return
@@ -49,13 +49,13 @@ module attributes {
gpu.module @kernels {
// CHECK-LABEL: spirv.func @gpu_wmma_store_op
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
- // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
%i = arith.constant 16 : index
%j = arith.constant 16 : index
// CHECK: %[[COLMAJOR:.*]] = spirv.Constant false
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.NV.coopmatrix<16x16xf16, Subgroup>
gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
@@ -71,14 +71,14 @@ module attributes {
gpu.module @kernels {
// CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose
// CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
- // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
+ // CHECK-SAME: {{%.*}}: !spirv.NV.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
// CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
gpu.func @gpu_wmma_store_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
%i = arith.constant 16 : index
%j = arith.constant 16 : index
// CHECK: %[[COLMAJOR:.*]] = spirv.Constant true
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.NV.coopmatrix<16x16xf16, Subgroup>
gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
@@ -93,12 +93,12 @@ module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
gpu.module @kernels {
// CHECK-LABEL: spirv.func @gpu_wmma_mma_op
- // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
- // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
- // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, !spirv.NV.coopmatrix<16x16xf16, Subgroup> -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: spirv.Return
gpu.return
@@ -117,7 +117,7 @@ module attributes {
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: {{%.*}} = spirv.Constant
%cst = arith.constant 1.0 : f16
- // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
%C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: spirv.Return
gpu.return
@@ -132,15 +132,15 @@ module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
gpu.module @kernels {
// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
- // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
- // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
gpu.func @gpu_wmma_elementwise_op_default(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
%C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
%D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
%E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: spirv.Return
gpu.return
@@ -155,14 +155,14 @@ module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
gpu.module @kernels {
// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
- // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
// CHECK-SAME: %[[S:.+]]: f16
gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
%B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+ // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
%C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
+ // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
%D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: spirv.Return
gpu.return
@@ -177,13 +177,13 @@ module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
gpu.module @kernels {
// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
- // CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK-SAME: %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
// CHECK-SAME: %[[S:.+]]: f16
gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
%B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup>
+ // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
%C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index cfbbf9494a00ec..1835c6ae1d5f80 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -138,9 +138,9 @@ func.func @convert_f_to_u_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {
// -----
-func.func @convert_f_to_u_coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) {
- // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.ConvertFToU %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xi32, Subgroup>
+func.func @convert_f_to_u_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
+ // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.ConvertFToU %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
@@ -222,9 +222,9 @@ func.func @f_convert_vector(%arg0 : vector<3xf32>) -> vector<3xf64> {
// -----
-func.func @f_convert_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) {
- // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xf64, Subgroup>
- %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup> to !spirv.coopmatrix<8x16xf64, Subgroup>
+func.func @f_convert_coop_matrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
+ %0 = spirv.FConvert %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index e363d14b29ad17..ce7f6bc6118b31 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -29,10 +29,10 @@ func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2
// -----
-func.func @composite_construct_coopmatrix(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> {
- // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup>
- %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup>
- return %0: !spirv.coopmatrix<8x16xf32, Subgroup>
+func.func @composite_construct_NV.coopmatrix(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
+ // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
}
// -----
@@ -53,18 +53,18 @@ func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg
// -----
-func.func @composite_construct_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup> {
+func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
// expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
- %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup>
- return %0: !spirv.coopmatrix<8x16xf32, Subgroup>
+ %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
}
// -----
-func.func @composite_construct_coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.coopmatrix<8x16xf32, Subgroup> {
+func.func @composite_construct_NV.coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
// expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
- %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup>
- return %0: !spirv.coopmatrix<8x16xf32, Subgroup>
+ %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
}
// -----
@@ -121,9 +121,9 @@ func.func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 {
// -----
-func.func @composite_extract_coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup>) -> f32 {
- // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.coopmatrix<8x16xf32, Subgroup>
- %0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.coopmatrix<8x16xf32, Subgroup>
+func.func @composite_extract_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) -> f32 {
+ // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ %0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
return %0 : f32
}
@@ -249,10 +249,10 @@ func.func @composite_insert_struct(%arg0: !spirv.struct<(!spirv.array<4xf32>, f3
// -----
-func.func @composite_insert_coopmatrix(%arg0: !spirv.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.coopmatrix<8x16xi32, Subgroup> {
- // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.coopmatrix<8x16xi32, Subgroup>
- return %0: !spirv.coopmatrix<8x16xi32, Subgroup>
+func.func @composite_insert_NV.coopmatrix(%arg0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.NV.coopmatrix<8x16xi32, Subgroup> {
+ // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ return %0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>
}
// -----
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index de31458b94771e..2e387403964612 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -2,150 +2,150 @@
// CHECK-LABEL: @cooperative_matrix_load
spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup>
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
spirv.Return
}
// -----
// CHECK-LABEL: @cooperative_matrix_load_memaccess
spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_store
-spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup>
- spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup>
+spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
+ // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+ spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_store_memaccess
-spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup>
- spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup>
+spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
+ // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_length
spirv.func @cooperative_matrix_length() -> i32 "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.ReturnValue %0 : i32
}
// CHECK-LABEL: @cooperative_matrix_muladd
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x32xi8, Subgroup>, !spirv.coopmatrix<32x8xi8, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x32xi8, Subgroup>, !spirv.coopmatrix<32x8xi8, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_add
-spirv.func @cooperative_matrix_add(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.IAdd %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_sub
-spirv.func @cooperative_matrix_sub(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.ISub %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_sdiv
-spirv.func @cooperative_matrix_sdiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.SDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_udiv
-spirv.func @cooperative_matrix_udiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.UDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_fadd
-spirv.func @cooperative_matrix_fadd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FAdd %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_fsub
-spirv.func @cooperative_matrix_fsub(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FSub %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_fdiv
-spirv.func @cooperative_matrix_fdiv(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FDiv %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
spirv.Return
}
// -----
// CHECK-LABEL: @cooperative_matrix_access_chain
-spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
+spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
%0 = spirv.Constant 0: i32
- // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
- %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+ // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+ %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
}
// -----
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
// expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<16x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
spirv.Return
}
// -----
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
// expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
spirv.Return
}
// -----
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
// expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Workgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Workgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
spirv.Return
}
// -----
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
// expected-error @+1 {{matrix A and B non-integer element types must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xf32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
spirv.Return
}
// -----
-spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
// expected-error @+1 {{matrix A and B integer element types must be the same bit width}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xui8, Subgroup>, !spirv.coopmatrix<16x8xsi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, !spirv.NV.coopmatrix<16x8xsi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
spirv.Return
}
@@ -153,7 +153,7 @@ spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xui8, Subgroup>
spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32, %b : i1) "None" {
// expected-error @+1 {{Pointer must point to a scalar or vector type}}
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
@@ -161,6 +161,6 @@ spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f
spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
// expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 8cdf2390d7232a..f52666af280e4b 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -9,10 +9,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
}
// CHECK-LABEL: @matrix_times_scalar_2
- spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" {
- // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
- %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
- spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup>
+ spirv.func @matrix_times_scalar_2(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" {
+ // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
+ %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
+ spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
}
// CHECK-LABEL: @matrix_transpose_1
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 174485f71d21d1..722e4434aeaf9f 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -797,7 +797,7 @@ spirv.module Logical GLSL450 {
}
//===----------------------------------------------------------------------===//
-// spirv.SpecConstantComposite (spirv.coopmatrix)
+// spirv.SpecConstantComposite (spirv.NV.coopmatrix)
//===----------------------------------------------------------------------===//
// -----
@@ -805,7 +805,7 @@ spirv.module Logical GLSL450 {
spirv.module Logical GLSL450 {
spirv.SpecConstant @sc1 = 1.5 : f32
// expected-error @+1 {{unsupported composite type}}
- spirv.SpecConstantComposite @scc (@sc1) : !spirv.coopmatrix<8x16xf32, Device>
+ spirv.SpecConstantComposite @scc (@sc1) : !spirv.NV.coopmatrix<8x16xf32, Device>
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 7e2833e79646ea..06f0ccfd0a3774 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -439,18 +439,18 @@ func.func private @id_struct_recursive(!spirv.struct<a10, (!spirv.ptr<!spirv.str
// CooperativeMatrix
//===----------------------------------------------------------------------===//
-// CHECK: func private @coop_matrix_type(!spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xf32, Workgroup>)
-func.func private @coop_matrix_type(!spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<8x8xf32, Workgroup>) -> ()
+// CHECK: func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>)
+func.func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> ()
// -----
// expected-error @+1 {{expected ','}}
-func.func private @missing_scope(!spirv.coopmatrix<8x16xi32>) -> ()
+func.func private @missing_scope(!spirv.NV.coopmatrix<8x16xi32>) -> ()
// -----
// expected-error @+1 {{expected rows and columns size}}
-func.func private @missing_count(!spirv.coopmatrix<8xi32, Subgroup>) -> ()
+func.func private @missing_count(!spirv.NV.coopmatrix<8xi32, Subgroup>) -> ()
// -----
diff --git a/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir
index 9b060f18d0fc33..2eec99f72691cc 100644
--- a/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir
@@ -3,100 +3,100 @@
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
// CHECK-LABEL: @cooperative_matrix_load
spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup>
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_load_memaccess
spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_store
- spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup>
- spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup>
+ spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
+ // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
+ spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_store_memaccess
- spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup>
- spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup>
+ spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
+ // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_length
spirv.func @cooperative_matrix_length() -> i32 "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.ReturnValue %0 : i32
}
// CHECK-LABEL: @cooperative_matrix_muladd
- spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+ spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_add
- spirv.func @cooperative_matrix_add(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.IAdd %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+ spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_sub
- spirv.func @cooperative_matrix_sub(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.ISub %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+ spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_sdiv
- spirv.func @cooperative_matrix_sdiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.SDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+ spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_udiv
- spirv.func @cooperative_matrix_udiv(%a : !spirv.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.UDiv %a, %b : !spirv.coopmatrix<8x16xi32, Subgroup>
+ spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_fadd
- spirv.func @cooperative_matrix_fadd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FAdd %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+ spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_fsub
- spirv.func @cooperative_matrix_fsub(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FSub %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+ spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_fdiv
- spirv.func @cooperative_matrix_fdiv(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FDiv %a, %b : !spirv.coopmatrix<8x16xf32, Subgroup>
+ spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
+ // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+ %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_access_chain
- spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
+ spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
%0 = spirv.Constant 0: i32
- // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
- %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+ // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+ %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
}
}
diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index 0b71b3f24b19da..af8f41a30d24fc 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -23,10 +23,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
}
// CHECK-LABEL: @matrix_times_scalar_3
- spirv.func @matrix_times_scalar_3(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" {
- // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
- %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
- spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup>
+ spirv.func @matrix_times_scalar_3(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" {
+ // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
+ %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
+ spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
}
// CHECK-LABEL: @matrix_transpose_1
More information about the Mlir-commits mailing list