[Mlir-commits] [mlir] [mlir][spirv] Drop support for SPV_NV_cooperative_matrix (PR #76782)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Jan 2 22:07:26 PST 2024
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/76782
This extension has been superseded by SPV_KHR_cooperative_matrix, which is supported across major GPU vendors like Nvidia, AMD, and Intel.
Given that the KHR version has been supported for nearly half a year, drop the NV-specific extension to reduce the maintenance burden and code duplication.
>From 01987b32f359c871af8f9dba39a6c9cb14b8f745 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 3 Jan 2024 01:02:49 -0500
Subject: [PATCH] [mlir][spirv] Drop support for SPV_NV_cooperative_matrix
This extension has been superseded by SPV_KHR_cooperative_matrix, which is supported across major GPU vendors like Nvidia, AMD,
and Intel.
Given that the KHR version has been supported for nearly half a year,
drop the NV-specific extension to reduce the maintenance burden and code
duplication.
---
.../mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h | 12 +-
mlir/include/mlir/Conversion/Passes.td | 4 -
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 38 +--
.../SPIRV/IR/SPIRVCooperativeMatrixOps.td | 247 ------------------
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 27 --
.../Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp | 12 +-
.../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | 133 +---------
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 2 +-
.../Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 152 -----------
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 46 +---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 8 +-
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 93 +------
.../SPIRV/Deserialization/DeserializeOps.cpp | 1 -
.../SPIRV/Deserialization/Deserializer.cpp | 33 ---
.../Target/SPIRV/Serialization/Serializer.cpp | 20 --
.../wmma-ops-to-spirv-khr-coop-matrix.mlir | 2 +-
.../wmma-ops-to-spirv-nv-coop-matrix.mlir | 194 --------------
mlir/test/Dialect/SPIRV/IR/cast-ops.mlir | 24 --
mlir/test/Dialect/SPIRV/IR/composite-ops.mlir | 39 ---
.../SPIRV/IR/khr-cooperative-matrix-ops.mlir | 26 --
mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir | 8 +-
.../SPIRV/IR/nv-cooperative-matrix-ops.mlir | 177 -------------
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 4 +-
mlir/test/Dialect/SPIRV/IR/types.mlir | 19 --
mlir/test/Target/SPIRV/matrix.mlir | 8 +-
.../SPIRV/nv-cooperative-matrix-ops.mlir | 102 --------
26 files changed, 49 insertions(+), 1382 deletions(-)
delete mode 100644 mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
delete mode 100644 mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir
delete mode 100644 mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index cd650345f1daa2..d34549432161db 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -31,16 +31,10 @@ void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
-/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
-/// using the NV Cooperative Matrix extension.
-void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
- SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
-
-/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix type
-/// conversion to the type converter. Defaults to KHR cooperative matrix types.
-/// When `useNVTypes` is `true`, uses the NV cooperative matrix types.
+/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix KHR type
+/// conversion to the type converter.
void populateMMAToSPIRVCoopMatrixTypeConversion(
- SPIRVTypeConverter &typeConverter, bool useNVTypes = false);
+ SPIRVTypeConverter &typeConverter);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6193aeb545bc6b..71be8841ca7c03 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -564,10 +564,6 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
Option<"use64bitIndex", "use-64bit-index",
"bool", /*default=*/"false",
"Use 64-bit integers to convert index types">,
- Option<"useCoopMatrixNV", "use-coop-matrix-nv",
- "bool", /*default=*/"false",
- "Use the NV cooperative matrix extension insted of the KHR extension"
- " to lower GPU WMMA ops">,
];
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index ee1fbba1e2844e..6ec97e17c5dcc8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -1253,12 +1253,6 @@ def SPIRV_C_RayTracingProvisionalKHR : I32EnumAttrCase<"RayTr
Extension<[SPV_KHR_ray_tracing]>
];
}
-def SPIRV_C_CooperativeMatrixNV : I32EnumAttrCase<"CooperativeMatrixNV", 5357> {
- list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
- list<Availability> availability = [
- Extension<[SPV_NV_cooperative_matrix]>
- ];
-}
def SPIRV_C_FragmentShaderSampleInterlockEXT : I32EnumAttrCase<"FragmentShaderSampleInterlockEXT", 5363> {
list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
list<Availability> availability = [
@@ -1501,7 +1495,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_ShaderNonUniform, SPIRV_C_RuntimeDescriptorArray,
SPIRV_C_StorageTexelBufferArrayDynamicIndexing, SPIRV_C_RayTracingNV,
SPIRV_C_RayTracingMotionBlurNV, SPIRV_C_PhysicalStorageBufferAddresses,
- SPIRV_C_RayTracingProvisionalKHR, SPIRV_C_CooperativeMatrixNV,
+ SPIRV_C_RayTracingProvisionalKHR,
SPIRV_C_FragmentShaderSampleInterlockEXT,
SPIRV_C_FragmentShaderShadingRateInterlockEXT, SPIRV_C_ShaderSMBuiltinsNV,
SPIRV_C_FragmentShaderPixelInterlockEXT, SPIRV_C_DemoteToHelperInvocation,
@@ -4123,8 +4117,6 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
def SPIRV_IsCooperativeMatrixType :
CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
-def SPIRV_IsCooperativeMatrixNVType :
- CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">;
def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
def SPIRV_IsJointMatrixType :
CPred<"::llvm::isa<::mlir::spirv::JointMatrixINTELType>($_self)">;
@@ -4157,9 +4149,6 @@ def SPIRV_AnyArray : DialectType<SPIRV_Dialect, SPIRV_IsArrayType,
def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
SPIRV_IsCooperativeMatrixType,
"any SPIR-V cooperative matrix type">;
-def SPIRV_AnyCooperativeMatrixNV : DialectType<SPIRV_Dialect,
- SPIRV_IsCooperativeMatrixNVType,
- "any SPIR-V NV cooperative matrix type">;
def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
"any SPIR-V image type">;
def SPIRV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPIRV_IsJointMatrixType,
@@ -4178,13 +4167,12 @@ def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
- SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
- SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
+ SPIRV_AnySampledImage
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4195,11 +4183,6 @@ class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
"::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()",
"Cooperative Matrix">;
-class SPIRV_CoopMatrixNVOfType<list<Type> allowedTypes> :
- ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixNVType,
- "::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()",
- "Cooperative Matrix NV">;
-
class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsJointMatrixType,
"::llvm::cast<::mlir::spirv::JointMatrixINTELType>($_self).getElementType()",
@@ -4213,12 +4196,11 @@ class SPIRV_ScalarOrVectorOf<Type type> :
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
AnyTypeOf<[type, SPIRV_VectorOf<type>,
- SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
+ SPIRV_CoopMatrixOfType<[type]>]>;
class SPIRV_MatrixOrCoopMatrixOf<Type type> :
AnyTypeOf<[SPIRV_AnyMatrix,
- SPIRV_CoopMatrixOfType<[type]>,
- SPIRV_CoopMatrixNVOfType<[type]>]>;
+ SPIRV_CoopMatrixOfType<[type]>]>;
def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
@@ -4480,11 +4462,6 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMatrix
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
-def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
-def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
-def SPIRV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
-def SPIRV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>;
-def SPIRV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>;
def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
def SPIRV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
def SPIRV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
@@ -4585,9 +4562,6 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
SPIRV_OC_OpCooperativeMatrixLengthKHR,
- SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
- SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
- SPIRV_OC_OpCooperativeMatrixLengthNV,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR,
SPIRV_OC_OpGroupFMulKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 29ad45bddd5529..46732ba19afed5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -338,253 +338,6 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
];
}
-//===----------------------------------------------------------------------===//
-// SPV_NV_cooperative_matrix extension ops.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLength",
- [Pure]> {
- let summary = "See extension SPV_NV_cooperative_matrix";
-
- let description = [{
- Number of components of a cooperative matrix type accessible to each
- invocation when treated as a composite.
-
- Result Type must be an OpTypeInt with 32-bit Width and 0 Signedness.
-
- Type is a cooperative matrix type.
-
- #### Example:
-
- ```
- %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- ```
- }];
-
- let assemblyFormat = "attr-dict `:` $cooperative_matrix_type";
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_NV_cooperative_matrix]>,
- Capability<[SPIRV_C_CooperativeMatrixNV]>
- ];
-
- let arguments = (ins
- TypeAttr:$cooperative_matrix_type
- );
-
- let results = (outs
- SPIRV_Int32:$result
- );
-}
-
-// -----
-
-def SPIRV_NVCooperativeMatrixLoadOp : SPIRV_NvVendorOp<"CooperativeMatrixLoad", []> {
- let summary = "See extension SPV_NV_cooperative_matrix";
-
- let description = [{
- Load a cooperative matrix through a pointer.
-
- Result Type is the type of the loaded object. It must be a cooperative
- matrix type.
-
- Pointer is a pointer into an array. Its type must be an OpTypePointer whose
- Type operand is a scalar or vector type. The storage class of Pointer must
- be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
- supported) PhysicalStorageBufferEXT.
-
- Stride is the number of elements in the array in memory between the first
- component of consecutive rows (or columns) in the result. It must be a
- scalar integer type.
-
- ColumnMajor indicates whether the values loaded from memory are arranged in
- column-major or row-major order. It must be a boolean constant instruction,
- with false indicating row major and true indicating column major.
-
- Memory Access must be a Memory Access literal. If not present, it is the
- same as specifying None.
-
- If ColumnMajor is false, then elements (row,*) of the result are taken in
- order from contiguous locations starting at Pointer[row*Stride]. If
- ColumnMajor is true, then elements (*,col) of the result are taken in order
- from contiguous locations starting from Pointer[col*Stride]. Any ArrayStride
- decoration on Pointer is ignored.
-
- For a given dynamic instance of this instruction, all operands of this
- instruction must be the same for all invocations in a given scope instance
- (where the scope is the scope the cooperative matrix type was created with).
- All invocations in a given scope instance must be active or all must be
- inactive.
-
- ### Custom assembly form
-
- ``` {.ebnf}
- cooperative-matrixload-op ::= ssa-id `=` `spirv.NV.CooperativeMatrixLoad`
- ssa-use `,` ssa-use `,` ssa-use
- (`[` memory-access `]`)? ` : `
- pointer-type `as`
- cooperative-matrix-type
- ```
-
- #### Example:
-
- ```
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor
- : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- ```
- }];
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_NV_cooperative_matrix]>,
- Capability<[SPIRV_C_CooperativeMatrixNV]>
- ];
-
- let arguments = (ins
- SPIRV_AnyPtr:$pointer,
- SPIRV_Integer:$stride,
- SPIRV_Bool:$columnmajor,
- OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access
- );
-
- let results = (outs
- SPIRV_AnyCooperativeMatrixNV:$result
- );
-}
-
-// -----
-
-def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAdd",
- [Pure, AllTypesMatch<["c", "result"]>]> {
- let summary = "See extension SPV_NV_cooperative_matrix";
-
- let description = [{
- Linear-algebraic matrix multiply of A by B and then component-wise add C.
- The order of the operations is implementation-dependent. The internal
- precision of floating-point operations is defined by the client API.
- Integer operations are performed at the precision of the Result Type and are
- exact unless there is overflow or underflow, in which case the result is
- undefined.
-
- Result Type must be a cooperative matrix type with M rows and N columns.
-
- A is a cooperative matrix with M rows and K columns.
-
- B is a cooperative matrix with K rows and N columns.
-
- C is a cooperative matrix with M rows and N columns.
-
- The values of M, N, and K must be consistent across the result and operands.
- This is referred to as an MxNxK matrix multiply.
-
- A, B, C, and Result Type must have the same scope, and this defines the
- scope of the operation. A, B, C, and Result Type need not necessarily have
- the same component type, this is defined by the client API.
-
- If the Component Type of any matrix operand is an integer type, then its
- components are treated as signed if its Component Type has Signedness of 1
- and are treated as unsigned otherwise.
-
- For a given dynamic instance of this instruction, all invocations in a given
- scope instance must be active or all must be inactive (where the scope is
- the scope of the operation).
-
- #### Example:
-
- ```
- %0 = spirv.NV.CooperativeMatrixMulAdd %arg0, %arg1, %arg2, :
- !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- ```
- }];
-
- let assemblyFormat = [{
- operands attr-dict `:` type($a) `,` type($b) `->` type($c)
- }];
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_NV_cooperative_matrix]>,
- Capability<[SPIRV_C_CooperativeMatrixNV]>
- ];
-
- let arguments = (ins
- SPIRV_AnyCooperativeMatrixNV:$a,
- SPIRV_AnyCooperativeMatrixNV:$b,
- SPIRV_AnyCooperativeMatrixNV:$c
- );
-
- let results = (outs
- SPIRV_AnyCooperativeMatrixNV:$result
- );
-}
-
-// -----
-
-def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore", []> {
- let summary = "See extension SPV_NV_cooperative_matrix";
-
- let description = [{
- Store a cooperative matrix through a pointer.
-
- Pointer is a pointer into an array. Its type must be an OpTypePointer whose
- Type operand is a scalar or vector type. The storage class of Pointer must
- be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
- supported) PhysicalStorageBufferEXT.
-
- Object is the object to store. Its type must be an
- OpTypeCooperativeMatrixNV.
-
- Stride is the number of elements in the array in memory between the first
- component of consecutive rows (or columns) in the result. It must be a
- scalar integer type.
-
- ColumnMajor indicates whether the values stored to memory are arranged in
- column-major or row-major order. It must be a boolean constant instruction,
- with false indicating row major and true indicating column major.
-
- Memory Access must be a Memory Access literal. If not present, it is the
- same as specifying None.
-
- ``` {.ebnf}
- coop-matrix-store-op ::= `spirv.NV.CooperativeMatrixStore `
- ssa-use `, ` ssa-use `, `
- ssa-use `, ` ssa-use `, `
- (`[` memory-access `]`)? `:`
- pointer-type `,` coop-matrix-type
- ```
-
- #### Example:
-
- ```
- spirv.NV.CooperativeMatrixStore %arg0, %arg2, %arg1, %arg3 :
- !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- ```
- }];
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_NV_cooperative_matrix]>,
- Capability<[SPIRV_C_CooperativeMatrixNV]>
- ];
-
- let arguments = (ins
- SPIRV_AnyPtr:$pointer,
- SPIRV_AnyCooperativeMatrixNV:$object,
- SPIRV_Integer:$stride,
- SPIRV_Bool:$columnmajor,
- OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access
- );
-
- let results = (outs);
-}
-
// -----
#endif // MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index d946d936d4e6cf..55f0c787b44403 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,7 +29,6 @@ namespace spirv {
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
-struct CooperativeMatrixNVTypeStorage;
struct ImageTypeStorage;
struct JointMatrixTypeStorage;
struct MatrixTypeStorage;
@@ -421,32 +420,6 @@ class CooperativeMatrixType
std::optional<StorageClass> storage = std::nullopt);
};
-// SPIR-V NV cooperative matrix type
-class CooperativeMatrixNVType
- : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
- detail::CooperativeMatrixNVTypeStorage> {
-public:
- using Base::Base;
-
- static constexpr StringLiteral name = "spirv.NV.coopmatrix";
-
- static CooperativeMatrixNVType get(Type elementType, Scope scope,
- unsigned rows, unsigned columns);
- Type getElementType() const;
-
- /// Returns the scope of the matrix.
- Scope getScope() const;
- /// Returns the number of rows of the matrix.
- unsigned getRows() const;
- /// Returns the number of columns of the matrix.
- unsigned getColumns() const;
-
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
-};
-
// SPIR-V joint matrix type
class JointMatrixINTELType
: public Type::TypeBase<JointMatrixINTELType, CompositeType,
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 8279b3408a6e66..0dd0e7e21b0553 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -112,18 +112,12 @@ void GPUToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.use64bitIndex = this->use64bitIndex;
SPIRVTypeConverter typeConverter(targetAttr, options);
- populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter,
- this->useCoopMatrixNV);
+ populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter);
RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
- if (this->useCoopMatrixNV) {
- populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
- patterns);
- } else {
- populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
- patterns);
- }
+ populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
+ patterns);
// TODO: Change SPIR-V conversion to be progressive and remove the following
// patterns.
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 4a4281aaaf0dbc..92cc0eadb9784c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -32,19 +32,18 @@
namespace mlir {
//===----------------------------------------------------------------------===//
-// Patterns and helpers used by both the KHR and the NV lowering paths.
+// Patterns and helpers.
//===----------------------------------------------------------------------===//
/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op directly supports with cooperative matrix type.
/// Returns false if cannot.
///
-/// See SPV_NV_cooperative_matrix for supported elementwise ops.
+/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
static bool createElementwiseOp(ConversionPatternRewriter &builder,
gpu::SubgroupMmaElementwiseOp op, Type coopType,
ValueRange operands) {
- assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
- coopType)));
+ assert((isa<spirv::CooperativeMatrixType>(coopType)));
switch (op.getOpType()) {
case gpu::MMAElementwiseOp::ADDF:
@@ -89,8 +88,7 @@ bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
llvm::map_range(operands, [](Value v) { return v.getType(); })))
return false;
- return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
- operands.front().getType());
+ return isa<spirv::CooperativeMatrixType>(operands.front().getType());
}
namespace {
@@ -292,104 +290,6 @@ struct WmmaMmaOpToSPIRVLowering final
} // namespace
} // namespace khr
-
-//===----------------------------------------------------------------------===//
-// SPV_NV_cooperative_matrix
-//===----------------------------------------------------------------------===//
-
-namespace nv {
-namespace {
-
-/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
-/// dialect.
-struct WmmaLoadOpToSPIRVLowering final
- : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
- OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = subgroupMmaLoadMatrixOp->getLoc();
- auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
-
- gpu::MMAMatrixType retType =
- cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
- auto memrefType =
- cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType());
- Value bufferPtr =
- spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
- adaptor.getIndices(), loc, rewriter);
- auto coopType =
- typeConverter.convertType<spirv::CooperativeMatrixNVType>(retType);
- if (!coopType)
- return rewriter.notifyMatchFailure(subgroupMmaLoadMatrixOp,
- "type conversion failed");
-
- int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
- auto i32Type = rewriter.getI32Type();
- auto strideValue = rewriter.create<spirv::ConstantOp>(
- loc, i32Type, IntegerAttr::get(i32Type, stride));
- bool isColMajor = static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose());
- auto columnMajor = rewriter.create<spirv::ConstantOp>(
- loc, rewriter.getI1Type(), rewriter.getBoolAttr(isColMajor));
- rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
- subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor,
- spirv::MemoryAccessAttr());
- return success();
- }
-};
-
-/// Converts the GPU MMA StoreOp to NVCooperativeMatrixStore op in the SPIRV
-/// dialect.
-struct WmmaStoreOpToSPIRVLowering final
- : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
- OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = subgroupMmaStoreMatrixOp->getLoc();
- auto memrefType =
- cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType());
- Value bufferPtr = spirv::getElementPtr(
- *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
- adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter);
- int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue();
- auto i32Type = rewriter.getI32Type();
- auto strideValue = rewriter.create<spirv::ConstantOp>(
- loc, i32Type, IntegerAttr::get(i32Type, stride));
- bool useColMajor =
- static_cast<bool>(subgroupMmaStoreMatrixOp.getTranspose());
- auto columnMajor = rewriter.create<spirv::ConstantOp>(
- loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor));
- rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixStoreOp>(
- subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue,
- columnMajor, spirv::MemoryAccessAttr());
- return success();
- }
-};
-
-/// Converts GPU MMA Compute to
-/// NVCooperativeMatrixMulAdd op in the SPIRV dialect.
-struct WmmaMmaOpToSPIRVLowering final
- : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
- OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixMulAddOp>(
- subgroupMmaComputeOp, adaptor.getOpC().getType(), adaptor.getOpA(),
- adaptor.getOpB(), adaptor.getOpC());
- return success();
- }
-};
-
-} // namespace
-} // namespace nv
} // namespace mlir
void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
@@ -404,31 +304,8 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
/*benefit=*/2);
}
-void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
- SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
- using namespace mlir;
- MLIRContext *context = patterns.getContext();
- patterns.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
- nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
- WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
- // Give the following patterns higher benefit to prevail over the default one.
- patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
- /*benefit=*/2);
-}
-
void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
- mlir::SPIRVTypeConverter &typeConverter, bool useNVTypes) {
- if (useNVTypes) {
- typeConverter.addConversion([](gpu::MMAMatrixType type) {
- ArrayRef<int64_t> retTypeShape = type.getShape();
- Type elementType = type.getElementType();
- return spirv::CooperativeMatrixNVType::get(
- elementType, spirv::Scope::Subgroup, retTypeShape[0],
- retTypeShape[1]);
- });
- return;
- }
-
+ mlir::SPIRVTypeConverter &typeConverter) {
typeConverter.addConversion([](gpu::MMAMatrixType type) {
ArrayRef<int64_t> retTypeShape = type.getShape();
Type elementType = type.getElementType();
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index f24da2ca5c3f24..52b4380ed27f7c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -37,7 +37,7 @@ static LogicalResult verifyCastOp(Operation *op,
auto [operandElemTy, resultElemTy] =
TypeSwitch<Type, TypePair>(operandType)
.Case<VectorType, spirv::CooperativeMatrixType,
- spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType>(
+ spirv::JointMatrixINTELType>(
[resultType](auto concreteOperandTy) -> TypePair {
if (auto concreteResultTy =
dyn_cast<decltype(concreteOperandTy)>(resultType)) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index c8b274ceec3e59..d532d466334a56 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -136,156 +136,4 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixLength
-//===----------------------------------------------------------------------===//
-
-LogicalResult NVCooperativeMatrixLengthOp::verify() {
- if (!isa<CooperativeMatrixNVType>(getCooperativeMatrixType())) {
- return emitOpError(
- "type attribute must be a '!spirv.NV.coopmatrix' type, found ")
- << getCooperativeMatrixType() << " instead";
- }
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
-
-ParseResult NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
- Type strideType = parser.getBuilder().getIntegerType(32);
- Type columnMajorType = parser.getBuilder().getIntegerType(1);
- Type ptrType;
- Type elementType;
- if (parser.parseOperandList(operandInfo, 3) ||
- parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
- parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
- return failure();
- }
- if (parser.resolveOperands(operandInfo,
- {ptrType, strideType, columnMajorType},
- parser.getNameLoc(), result.operands)) {
- return failure();
- }
-
- result.addTypes(elementType);
- return success();
-}
-
-void NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
- printer << " " << getPointer() << ", " << getStride() << ", "
- << getColumnmajor();
- // Print optional memory access attribute.
- if (auto memAccess = getMemoryAccess())
- printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
- printer << " : " << getPointer().getType() << " as " << getType();
-}
-
-static LogicalResult
-verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) {
- Type pointeeType = llvm::cast<PointerType>(pointer).getPointeeType();
- if (!llvm::isa<ScalarType>(pointeeType) &&
- !llvm::isa<VectorType>(pointeeType))
- return op->emitError(
- "Pointer must point to a scalar or vector type but provided ")
- << pointeeType;
- StorageClass storage = llvm::cast<PointerType>(pointer).getStorageClass();
- if (storage != StorageClass::Workgroup &&
- storage != StorageClass::StorageBuffer &&
- storage != StorageClass::PhysicalStorageBuffer)
- return op->emitError(
- "Pointer storage class must be Workgroup, StorageBuffer or "
- "PhysicalStorageBufferEXT but provided ")
- << stringifyStorageClass(storage);
- return success();
-}
-
-LogicalResult NVCooperativeMatrixLoadOp::verify() {
- return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
- getResult().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixStore
-//===----------------------------------------------------------------------===//
-
-ParseResult NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
- Type strideType = parser.getBuilder().getIntegerType(32);
- Type columnMajorType = parser.getBuilder().getIntegerType(1);
- Type ptrType;
- Type elementType;
- if (parser.parseOperandList(operandInfo, 4) ||
- parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
- parser.parseType(ptrType) || parser.parseComma() ||
- parser.parseType(elementType)) {
- return failure();
- }
- if (parser.resolveOperands(
- operandInfo, {ptrType, elementType, strideType, columnMajorType},
- parser.getNameLoc(), result.operands)) {
- return failure();
- }
-
- return success();
-}
-
-void NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
- printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
- << ", " << getColumnmajor();
- // Print optional memory access attribute.
- if (auto memAccess = getMemoryAccess())
- printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
- printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
-}
-
-LogicalResult NVCooperativeMatrixStoreOp::verify() {
- return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
- getObject().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixMulAdd
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verifyCoopMatrixMulAddNV(NVCooperativeMatrixMulAddOp op) {
- if (op.getC().getType() != op.getResult().getType())
- return op.emitOpError("result and third operand must have the same type");
- auto typeA = llvm::cast<CooperativeMatrixNVType>(op.getA().getType());
- auto typeB = llvm::cast<CooperativeMatrixNVType>(op.getB().getType());
- auto typeC = llvm::cast<CooperativeMatrixNVType>(op.getC().getType());
- auto typeR = llvm::cast<CooperativeMatrixNVType>(op.getResult().getType());
- if (typeA.getRows() != typeR.getRows() ||
- typeA.getColumns() != typeB.getRows() ||
- typeB.getColumns() != typeR.getColumns())
- return op.emitOpError("matrix size must match");
- if (typeR.getScope() != typeA.getScope() ||
- typeR.getScope() != typeB.getScope() ||
- typeR.getScope() != typeC.getScope())
- return op.emitOpError("matrix scope must match");
- auto elementTypeA = typeA.getElementType();
- auto elementTypeB = typeB.getElementType();
- if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
- if (llvm::cast<IntegerType>(elementTypeA).getWidth() !=
- llvm::cast<IntegerType>(elementTypeB).getWidth())
- return op.emitOpError(
- "matrix A and B integer element types must be the same bit width");
- } else if (elementTypeA != elementTypeB) {
- return op.emitOpError(
- "matrix A and B non-integer element types must match");
- }
- if (typeR.getElementType() != typeC.getElementType())
- return op.emitOpError("matrix accumulator element type must match");
- return success();
-}
-
-LogicalResult NVCooperativeMatrixMulAddOp::verify() {
- return verifyCoopMatrixMulAddNV(*this);
-}
-
} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8a68decc5878c8..9d4d1aec36709e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -360,37 +360,6 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
}
-// nv-cooperative-matrix-type ::=
-// `!spirv.NV.coopmatrix` `<` rows `x` columns `x` element-type `,` scope `>`
-static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,
- DialectAsmParser &parser) {
- if (parser.parseLess())
- return Type();
-
- SmallVector<int64_t, 2> dims;
- SMLoc countLoc = parser.getCurrentLocation();
- if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
- return Type();
-
- if (dims.size() != 2) {
- parser.emitError(countLoc, "expected rows and columns size");
- return Type();
- }
-
- auto elementTy = parseAndVerifyType(dialect, parser);
- if (!elementTy)
- return Type();
-
- Scope scope;
- if (parser.parseComma() ||
- spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
- return Type();
-
- if (parser.parseGreater())
- return Type();
- return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
-}
-
// joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
// element-type
// `,` layout `,` scope`>`
@@ -810,8 +779,6 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseArrayType(*this, parser);
if (keyword == "coopmatrix")
return parseCooperativeMatrixType(*this, parser);
- if (keyword == "NV.coopmatrix")
- return parseCooperativeMatrixNVType(*this, parser);
if (keyword == "jointmatrix")
return parseJointMatrixType(*this, parser);
if (keyword == "image")
@@ -917,12 +884,6 @@ static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
<< type.getUse() << ">";
}
-static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
- os << "NV.coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
- os << type.getElementType() << ", " << stringifyScope(type.getScope());
- os << ">";
-}
-
static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
os << type.getElementType() << ", "
@@ -937,10 +898,9 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
- .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
- JointMatrixINTELType, PointerType, RuntimeArrayType, ImageType,
- SampledImageType, StructType, MatrixType>(
- [&](auto type) { print(type, os); })
+ .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, PointerType,
+ RuntimeArrayType, ImageType, SampledImageType, StructType,
+ MatrixType>([&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..dc558b878b3b76 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -374,8 +374,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
auto coopElementType =
llvm::TypeSwitch<Type, Type>(getType())
- .Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
- spirv::JointMatrixINTELType>(
+ .Case<spirv::CooperativeMatrixType, spirv::JointMatrixINTELType>(
[](auto coopType) { return coopType.getElementType(); })
.Default([](Type) { return nullptr; });
@@ -1611,8 +1610,7 @@ LogicalResult spirv::VectorShuffleOp::verify() {
LogicalResult spirv::MatrixTimesScalarOp::verify() {
Type elementType =
llvm::TypeSwitch<Type, Type>(getMatrix().getType())
- .Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
- spirv::MatrixType>(
+ .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
[](auto matrixType) { return matrixType.getElementType(); })
.Default([](Type) { return nullptr; });
@@ -1751,7 +1749,7 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
return emitError("result type must be a composite type, but provided ")
<< getType();
- if (llvm::isa<spirv::CooperativeMatrixNVType>(cType))
+ if (llvm::isa<spirv::CooperativeMatrixType>(cType))
return emitError("unsupported composite type ") << cType;
if (llvm::isa<spirv::JointMatrixINTELType>(cType))
return emitError("unsupported composite type ") << cType;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index f1bac6490837b9..3f25696aa5eb6e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -95,9 +95,8 @@ bool CompositeType::classof(Type type) {
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return isValid(vectorType);
return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
- spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType,
- spirv::MatrixType, spirv::RuntimeArrayType,
- spirv::StructType>(type);
+ spirv::JointMatrixINTELType, spirv::MatrixType,
+ spirv::RuntimeArrayType, spirv::StructType>(type);
}
bool CompositeType::isValid(VectorType type) {
@@ -108,8 +107,8 @@ bool CompositeType::isValid(VectorType type) {
Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
- JointMatrixINTELType, RuntimeArrayType, VectorType>(
+ .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType,
+ RuntimeArrayType, VectorType>(
[](auto type) { return type.getElementType(); })
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
@@ -127,7 +126,7 @@ unsigned CompositeType::getNumElements() const {
return structType.getNumElements();
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
- if (llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType>(*this)) {
+ if (llvm::isa<CooperativeMatrixType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv Cooperative Matrix type");
}
@@ -143,16 +142,16 @@ unsigned CompositeType::getNumElements() const {
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
- return !llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType,
- JointMatrixINTELType, RuntimeArrayType>(*this);
+ return !llvm::isa<CooperativeMatrixType, JointMatrixINTELType,
+ RuntimeArrayType>(*this);
}
void CompositeType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
- JointMatrixINTELType, MatrixType, RuntimeArrayType, StructType>(
+ .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
+ RuntimeArrayType, StructType>(
[&](auto type) { type.getExtensions(extensions, storage); })
.Case<VectorType>([&](VectorType type) {
return llvm::cast<ScalarType>(type.getElementType())
@@ -165,8 +164,8 @@ void CompositeType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
- JointMatrixINTELType, MatrixType, RuntimeArrayType, StructType>(
+ .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
+ RuntimeArrayType, StructType>(
[&](auto type) { type.getCapabilities(capabilities, storage); })
.Case<VectorType>([&](VectorType type) {
auto vecSize = getNumElements();
@@ -267,70 +266,6 @@ void CooperativeMatrixType::getCapabilities(
capabilities.push_back(caps);
}
-//===----------------------------------------------------------------------===//
-// CooperativeMatrixNVType
-//===----------------------------------------------------------------------===//
-
-struct spirv::detail::CooperativeMatrixNVTypeStorage : public TypeStorage {
- using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
-
- static CooperativeMatrixNVTypeStorage *
- construct(TypeStorageAllocator &allocator, const KeyTy &key) {
- return new (allocator.allocate<CooperativeMatrixNVTypeStorage>())
- CooperativeMatrixNVTypeStorage(key);
- }
-
- bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, scope, rows, columns);
- }
-
- CooperativeMatrixNVTypeStorage(const KeyTy &key)
- : elementType(std::get<0>(key)), rows(std::get<2>(key)),
- columns(std::get<3>(key)), scope(std::get<1>(key)) {}
-
- Type elementType;
- unsigned rows;
- unsigned columns;
- Scope scope;
-};
-
-CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
- Scope scope, unsigned rows,
- unsigned columns) {
- return Base::get(elementType.getContext(), elementType, scope, rows, columns);
-}
-
-Type CooperativeMatrixNVType::getElementType() const {
- return getImpl()->elementType;
-}
-
-Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
-
-unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
-
-unsigned CooperativeMatrixNVType::getColumns() const {
- return getImpl()->columns;
-}
-
-void CooperativeMatrixNVType::getExtensions(
- SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
- static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
- ArrayRef<Extension> ref(exts, std::size(exts));
- extensions.push_back(ref);
-}
-
-void CooperativeMatrixNVType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType())
- .getCapabilities(capabilities, storage);
- static const Capability caps[] = {Capability::CooperativeMatrixNV};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
-}
-
//===----------------------------------------------------------------------===//
// JointMatrixType
//===----------------------------------------------------------------------===//
@@ -1312,7 +1247,7 @@ void MatrixType::getCapabilities(
//===----------------------------------------------------------------------===//
void SPIRVDialect::registerTypes() {
- addTypes<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType, ImageType,
- JointMatrixINTELType, MatrixType, PointerType, RuntimeArrayType,
- SampledImageType, StructType>();
+ addTypes<ArrayType, CooperativeMatrixType, ImageType, JointMatrixINTELType,
+ MatrixType, PointerType, RuntimeArrayType, SampledImageType,
+ StructType>();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 954aaa98c32998..a678124bf48322 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -165,7 +165,6 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
- case spirv::Opcode::OpTypeCooperativeMatrixNV:
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
return processTypeForwardPointer(operands);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 89e2e7ad52fa7d..948dcfb4885b3f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -778,8 +778,6 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processArrayType(operands);
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
return processCooperativeMatrixTypeKHR(operands);
- case spirv::Opcode::OpTypeCooperativeMatrixNV:
- return processCooperativeMatrixTypeNV(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
case spirv::Opcode::OpTypeJointMatrixINTEL:
@@ -955,37 +953,6 @@ LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
return success();
}
-LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(
- ArrayRef<uint32_t> operands) {
- if (operands.size() != 5) {
- return emitError(unknownLoc, "OpTypeCooperativeMatrixNV must have element "
- "type and row x column parameters");
- }
-
- Type elementTy = getType(operands[1]);
- if (!elementTy) {
- return emitError(unknownLoc,
- "OpTypeCooperativeMatrixNV references undefined <id> ")
- << operands[1];
- }
-
- std::optional<spirv::Scope> scope =
- spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
- if (!scope) {
- return emitError(
- unknownLoc,
- "OpTypeCooperativeMatrixNV references undefined scope <id> ")
- << operands[2];
- }
-
- unsigned rows = getConstantInt(operands[3]).getInt();
- unsigned columns = getConstantInt(operands[4]).getInt();
-
- typeMap[operands[0]] =
- spirv::CooperativeMatrixNVType::get(elementTy, *scope, rows, columns);
- return success();
-}
-
LogicalResult
spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
if (operands.size() != 6) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 9e9a16456cc102..08395dd4cf522f 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -632,26 +632,6 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto cooperativeMatrixType =
- dyn_cast<spirv::CooperativeMatrixNVType>(type)) {
- uint32_t elementTypeID = 0;
- if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
- elementTypeID, serializationCtx))) {
- return failure();
- }
- typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
- auto getConstantOp = [&](uint32_t id) {
- auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
- return prepareConstantInt(loc, attr);
- };
- llvm::append_values(
- operands, elementTypeID,
- getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
- getConstantOp(cooperativeMatrixType.getRows()),
- getConstantOp(cooperativeMatrixType.getColumns()));
- return success();
- }
-
if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index f129cc8ce84ec3..477f344b1ae5f4 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" --cse \
+// RUN: mlir-opt --convert-gpu-to-spirv --cse \
// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s
module attributes {
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
deleted file mode 100644
index ec7da92704c07c..00000000000000
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
+++ /dev/null
@@ -1,194 +0,0 @@
-// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=true" \
-// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s
-
-module attributes {
- gpu.container_module,
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
- gpu.module @kernels {
- // CHECK-LABEL: spirv.func @gpu_wmma_load_op
- // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
- gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
- attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- %i = arith.constant 16 : index
- %j = arith.constant 16 : index
- // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer> as !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: spirv.Return
- gpu.return
- }
- }
-}
-
-// -----
-
-module attributes {
- gpu.container_module,
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
- gpu.module @kernels {
- // CHECK-LABEL: spirv.func @gpu_wmma_load_op_transpose
- // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
- // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
- gpu.func @gpu_wmma_load_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
- attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- %i = arith.constant 16 : index
- %j = arith.constant 16 : index
- // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer> as !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: spirv.Return
- gpu.return
- }
- }
-}
-
-// -----
-
-module attributes {
- gpu.container_module,
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
- gpu.module @kernels {
- // CHECK-LABEL: spirv.func @gpu_wmma_store_op
- // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
- // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
- attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- %i = arith.constant 16 : index
- %j = arith.constant 16 : index
- // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
- // CHECK: spirv.Return
- gpu.return
- }
- }
-}
-
-// -----
-
-module attributes {
- gpu.container_module,
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
- gpu.module @kernels {
- // CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose
- // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
- // CHECK-SAME: {{%.*}}: !spirv.NV.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
- // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
- gpu.func @gpu_wmma_store_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
- attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- %i = arith.constant 16 : index
- %j = arith.constant 16 : index
- // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
- // CHECK: spirv.Return
- gpu.return
- }
- }
-}
-
-// -----
-
-module attributes {
- gpu.container_module,
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
- gpu.module @kernels {
- // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
- // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
- attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, !spirv.NV.coopmatrix<16x16xf16, Subgroup> -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: spirv.Return
- gpu.return
- }
- }
-}
-
-// -----
-
-module attributes {
- gpu.container_module,
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
- gpu.module @kernels {
- // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
- gpu.func @gpu_wmma_constant_op() kernel
- attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- // CHECK: {{%.*}} = spirv.Constant
- %cst = arith.constant 1.0 : f16
- // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: spirv.Return
- gpu.return
- }
- }
-}
-
-// -----
-
-module attributes {
- gpu.container_module,
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
- gpu.module @kernels {
- // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
- // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- // CHECK-SAME: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- gpu.func @gpu_wmma_elementwise_op_default(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
- attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- %D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- %E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> to !spirv.NV.coopmatrix<16x16xf32, Subgroup>
- %F = gpu.subgroup_mma_elementwise extf %E : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
- // CHECK: spirv.Return
- gpu.return
- }
- }
-}
-
-// -----
-
-module attributes {
- gpu.container_module,
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
- gpu.module @kernels {
- // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
- // CHECK-SAME: %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- // CHECK-SAME: %[[S:.+]]: f16
- gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
- attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
- %C = gpu.subgroup_mma_elementwise mulf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: %{{.+}} = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
- %D = gpu.subgroup_mma_elementwise mulf %B, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: spirv.Return
- gpu.return
- }
- }
-}
-
-// -----
-
-module attributes {
- gpu.container_module,
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
- gpu.module @kernels {
- // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
- // CHECK-SAME: %[[A:.+]]: !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- // CHECK-SAME: %[[S:.+]]: f16
- gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16) kernel
- attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
- // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
- // CHECK: %{{.+}} = spirv.FAdd %[[A]], %[[SM]] : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
- %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
- gpu.return
- }
- }
-}
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 4f4a72da7c050a..aaee2ccd3cb8c1 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -146,14 +146,6 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou
// -----
-func.func @convert_f_to_u_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
- // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.ConvertFToU %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
-
-// -----
-
//===----------------------------------------------------------------------===//
// spirv.ConvertSToF
//===----------------------------------------------------------------------===//
@@ -238,14 +230,6 @@ func.func @f_convert_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, M
// -----
-func.func @f_convert_coop_matrix_nv(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
- // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
- %0 = spirv.FConvert %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
- spirv.Return
-}
-
-// -----
-
func.func @f_convert_vector(%arg0 : f32) -> f32 {
// expected-error @+1 {{expected the different bit widths for operand type and result type, but provided 'f32' and 'f32'}}
%0 = spirv.FConvert %arg0 : f32 to f32
@@ -254,14 +238,6 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {
// -----
-func.func @f_convert_coop_matrix_to_nv_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>) {
- // expected-error @+1 {{incompatible operand and result types}}
- %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
- spirv.Return
-}
-
-// -----
-
//===----------------------------------------------------------------------===//
// spirv.SConvert
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index b10677f0f5f99f..3fc8dfb2767d1e 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -32,13 +32,6 @@ func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<
return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
}
-// CHECK-LABEL: func @composite_construct_coopmatrix_nv
-func.func @composite_construct_coopmatrix_nv(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
- // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-}
-
// -----
func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
@@ -75,22 +68,6 @@ func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32
// -----
-func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
- // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
- %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-}
-
-// -----
-
-func.func @composite_construct_NV.coopmatrix_incorrect_element_type(%arg0 : i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
- // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
- %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-}
-
-// -----
-
func.func @composite_construct_array(%arg0: f32) -> !spirv.array<4xf32> {
// expected-error @+1 {{expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
%0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.array<4xf32>
@@ -143,14 +120,6 @@ func.func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 {
// -----
-func.func @composite_extract_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) -> f32 {
- // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- %0 = spirv.CompositeExtract %arg0[2 : i32] : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- return %0 : f32
-}
-
-// -----
-
func.func @composite_extract_no_ssa_operand() -> () {
// expected-error @+1 {{expected SSA operand}}
%0 = spirv.CompositeExtract [4 : i32, 1 : i32] : !spirv.array<4x!spirv.array<4xf32>>
@@ -271,14 +240,6 @@ func.func @composite_insert_struct(%arg0: !spirv.struct<(!spirv.array<4xf32>, f3
// -----
-func.func @composite_insert_NV.coopmatrix(%arg0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spirv.NV.coopmatrix<8x16xi32, Subgroup> {
- // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- return %0: !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-}
-
-// -----
-
func.func @composite_insert_no_indices(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
// expected-error @+1 {{expected at least one index}}
%0 = spirv.CompositeInsert %arg1, %arg0[] : f32 into !spirv.array<4xf32>
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 445ab8a48d3ce6..d3e1dbc229ef99 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -13,14 +13,6 @@ spirv.func @cooperative_matrix_length() -> i32 "None" {
// -----
-spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
- // expected-error @+1 {{'cooperative_matrix_type' failed to satisfy constraint: type attribute of any SPIR-V cooperative matrix type}}
- %0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.ReturnValue %0 : i32
-}
-
-// -----
-
// CHECK-LABEL: @cooperative_matrix_load
spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
@@ -118,24 +110,6 @@ spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageB
// -----
-spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
- // expected-error @+1 {{expected '<'}}
- %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, :
- !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>
- spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
- // expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}}
- %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
- !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
-
-// -----
-
spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <MakePointerAvailable> :
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index f52666af280e4b..372fcc6e514b97 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -9,10 +9,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
}
// CHECK-LABEL: @matrix_times_scalar_2
- spirv.func @matrix_times_scalar_2(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" {
- // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
- %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
- spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
+ spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA> "None" {
+ // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16
+ %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16
+ spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
}
// CHECK-LABEL: @matrix_transpose_1
diff --git a/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir
deleted file mode 100644
index 43cbf61b60ef0b..00000000000000
--- a/mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir
+++ /dev/null
@@ -1,177 +0,0 @@
-// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
-
-//===----------------------------------------------------------------------===//
-// NV.CooperativeMatrix
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: @cooperative_matrix_load
-spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_load_memaccess
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
-spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_store
-spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
- spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_store_memaccess
-spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_length
-spirv.func @cooperative_matrix_length() -> i32 "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.ReturnValue %0 : i32
-}
-
-// CHECK-LABEL: @cooperative_matrix_muladd
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_add
-spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_sub
-spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_sdiv
-spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_udiv
-spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_fadd
-spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_fsub
-spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @cooperative_matrix_fdiv
-spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- spirv.Return
-}
-
-// -----
-
-// CHECK-LABEL: @cooperative_matrix_access_chain
-spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
- %0 = spirv.Constant 0: i32
- // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
- %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
- spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Workgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // expected-error @+1 {{matrix A and B non-integer element types must match}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // expected-error @+1 {{matrix A and B integer element types must be the same bit width}}
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xui8, Subgroup>, !spirv.NV.coopmatrix<16x8xsi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // expected-error @+1 {{Pointer must point to a scalar or vector type}}
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
- // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
- // expected-error @+1 {{'spirv.NV.CooperativeMatrixLength' op type attribute must be a '!spirv.NV.coopmatrix'}}
- %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
- spirv.ReturnValue %0 : i32
-}
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 722e4434aeaf9f..6f6ce1202d1704 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -797,7 +797,7 @@ spirv.module Logical GLSL450 {
}
//===----------------------------------------------------------------------===//
-// spirv.SpecConstantComposite (spirv.NV.coopmatrix)
+// spirv.SpecConstantComposite (spirv.coopmatrix)
//===----------------------------------------------------------------------===//
// -----
@@ -805,7 +805,7 @@ spirv.module Logical GLSL450 {
spirv.module Logical GLSL450 {
spirv.SpecConstant @sc1 = 1.5 : f32
// expected-error @+1 {{unsupported composite type}}
- spirv.SpecConstantComposite @scc (@sc1) : !spirv.NV.coopmatrix<8x16xf32, Device>
+ spirv.SpecConstantComposite @scc (@sc1) : !spirv.coopmatrix<8x16xf32, Device, MatrixA>
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index e10a6fc77e8566..05ab91b6db6bd9 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -479,25 +479,6 @@ func.func private @use_not_integer(!spirv.coopmatrix<8x8xi32, Subgroup, Subgroup
// -----
-//===----------------------------------------------------------------------===//
-// NV.CooperativeMatrix
-//===----------------------------------------------------------------------===//
-
-// CHECK: func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>)
-func.func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> ()
-
-// -----
-
-// expected-error @+1 {{expected ','}}
-func.func private @missing_scope(!spirv.NV.coopmatrix<8x16xi32>) -> ()
-
-// -----
-
-// expected-error @+1 {{expected rows and columns size}}
-func.func private @missing_count(!spirv.NV.coopmatrix<8xi32, Subgroup>) -> ()
-
-// -----
-
//===----------------------------------------------------------------------===//
// Matrix
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index af8f41a30d24fc..b52c3f4aa2f117 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -23,10 +23,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
}
// CHECK-LABEL: @matrix_times_scalar_3
- spirv.func @matrix_times_scalar_3(%arg0 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.NV.coopmatrix<16x16xf16, Subgroup> "None" {
- // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
- %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.NV.coopmatrix<16x16xf16, Subgroup>, f16
- spirv.ReturnValue %result : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
+ spirv.func @matrix_times_scalar_3(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> "None" {
+ // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+ %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+ spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
}
// CHECK-LABEL: @matrix_transpose_1
diff --git a/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir
deleted file mode 100644
index 2eec99f72691cc..00000000000000
--- a/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir
+++ /dev/null
@@ -1,102 +0,0 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
- // CHECK-LABEL: @cooperative_matrix_load
- spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_load_memaccess
- spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_store
- spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_store_memaccess
- spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
- // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_length
- spirv.func @cooperative_matrix_length() -> i32 "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.ReturnValue %0 : i32
- }
-
- // CHECK-LABEL: @cooperative_matrix_muladd
- spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_add
- spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_sub
- spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_sdiv
- spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_udiv
- spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_fadd
- spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_fsub
- spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_fdiv
- spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @cooperative_matrix_access_chain
- spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spirv.ptr<f32, Function> "None" {
- %0 = spirv.Constant 0: i32
- // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
- %1 = spirv.AccessChain %a[%0] : !spirv.ptr<!spirv.NV.coopmatrix<8x16xf32, Subgroup>, Function>, i32
- spirv.ReturnValue %1 : !spirv.ptr<f32, Function>
- }
-}
More information about the Mlir-commits
mailing list