[Mlir-commits] [mlir] [mlir][spirv] Drop support for SPV_INTEL_joint_matrix (PR #102332)
Andrea Faulds
llvmlistbot at llvm.org
Wed Aug 7 09:51:27 PDT 2024
https://github.com/andfau-amd created https://github.com/llvm/llvm-project/pull/102332
This was a "preview" extension that was never formalized and has since been superseded by SPV_KHR_cooperative_matrix.
>From 2cd9bfdfa0d3a4a2ce813b6a5fb7d17f7883f4e8 Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Wed, 7 Aug 2024 18:46:37 +0200
Subject: [PATCH] [mlir][spirv] Drop support for SPV_INTEL_joint_matrix
This was a "preview" extension that was never formalized and has since
been superseded by SPV_KHR_cooperative_matrix.
---
.../mlir/Dialect/SPIRV/IR/SPIRVAttributes.td | 21 --
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 35 +--
.../Dialect/SPIRV/IR/SPIRVJointMatrixOps.td | 243 ------------------
.../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td | 1 -
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 30 ---
mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt | 1 -
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 3 +-
mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp | 84 ------
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 50 +---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 4 +-
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 96 +------
.../SPIRV/Deserialization/DeserializeOps.cpp | 2 -
.../SPIRV/Deserialization/Deserializer.cpp | 36 ---
.../SPIRV/Deserialization/Deserializer.h | 2 -
.../Target/SPIRV/Serialization/Serializer.cpp | 19 --
.../Dialect/SPIRV/IR/joint-matrix-ops.mlir | 99 -------
mlir/test/Target/SPIRV/joint-matrix-ops.mlir | 45 ----
17 files changed, 19 insertions(+), 752 deletions(-)
delete mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
delete mode 100644 mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
delete mode 100644 mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
delete mode 100644 mlir/test/Target/SPIRV/joint-matrix-ops.mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index 3a11284da05122..f2a12f68d481b8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -101,27 +101,6 @@ def SPIRV_CooperativeMatrixPropertiesNVArrayAttr :
TypedArrayAttrBase<SPIRV_CooperativeMatrixPropertiesNVAttr,
"CooperativeMatrixPropertiesNV array attribute">;
-// Description of the supported joint matrix operations. See
-// https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc
-def SPIRV_JointMatrixPropertiesINTELAttr :
- SPIRV_Attr<"JointMatrixPropertiesINTEL", "joint_matrix_props"> {
- let parameters = (ins
- "int":$m_size,
- "int":$n_size,
- "int":$k_size,
- "mlir::Type":$a_type,
- "mlir::Type":$b_type,
- "mlir::Type":$c_type,
- "mlir::Type":$result_type,
- "mlir::spirv::ScopeAttr":$scope
- );
- let assemblyFormat = "`<` struct(params) `>`";
-}
-
-def SPIRV_JointMatrixPropertiesINTELArrayAttr :
- TypedArrayAttrBase<SPIRV_JointMatrixPropertiesINTELAttr,
- "JointMatrixPropertiesINTEL array attribute">;
-
// This attribute specifies the limits for various resources on the target
// architecture.
//
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index b38978272c5bdc..af0b2624feb327 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -399,7 +399,6 @@ def SPV_INTEL_debug_module : I32EnumAttrCase<"SPV_INTEL_de
def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp_fast_math_mode", 4027>;
def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
-def SPV_INTEL_joint_matrix : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>;
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
@@ -459,7 +458,7 @@ def SPIRV_ExtensionAttr :
SPV_INTEL_usm_storage_classes, SPV_INTEL_io_pipes, SPV_INTEL_blocking_pipes,
SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
- SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix,
+ SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
SPV_INTEL_bfloat16_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@@ -1410,12 +1409,6 @@ def SPIRV_C_ShaderStereoViewNV : I32EnumAttrCase<"Shade
];
}
-def SPIRV_C_JointMatrixINTEL : I32EnumAttrCase<"JointMatrixINTEL", 6118> {
- list<Availability> availability = [
- Extension<[SPV_INTEL_joint_matrix]>
- ];
-}
-
def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
list<Availability> availability = [
Extension<[SPV_INTEL_bfloat16_conversion]>
@@ -1514,7 +1507,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_UniformTexelBufferArrayNonUniformIndexing,
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
- SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL, SPIRV_C_Bfloat16ConversionINTEL
+ SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -4131,8 +4124,6 @@ def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
def SPIRV_IsCooperativeMatrixType :
CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
-def SPIRV_IsJointMatrixType :
- CPred<"::llvm::isa<::mlir::spirv::JointMatrixINTELType>($_self)">;
def SPIRV_IsMatrixType : CPred<"::llvm::isa<::mlir::spirv::MatrixType>($_self)">;
def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
@@ -4164,8 +4155,6 @@ def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
"any SPIR-V cooperative matrix type">;
def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
"any SPIR-V image type">;
-def SPIRV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPIRV_IsJointMatrixType,
- "any SPIR-V joint matrix type">;
def SPIRV_AnyMatrix : DialectType<SPIRV_Dialect, SPIRV_IsMatrixType,
"any SPIR-V matrix type">;
def SPIRV_AnyRTArray : DialectType<SPIRV_Dialect, SPIRV_IsRTArrayType,
@@ -4180,12 +4169,11 @@ def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
- SPIRV_AnySampledImage
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4196,11 +4184,6 @@ class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
"::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()",
"Cooperative Matrix">;
-class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
- ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsJointMatrixType,
- "::llvm::cast<::mlir::spirv::JointMatrixINTELType>($_self).getElementType()",
- "Joint Matrix">;
-
class SPIRV_VectorOf<Type type> :
VectorOfLengthAndType<[2, 3, 4, 8,16], [type]>;
@@ -4482,12 +4465,6 @@ def SPIRV_OC_OpAtomicFAddEXT : I32EnumAttrCase<"OpAtomicFAddEXT", 6
def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
-def SPIRV_OC_OpTypeJointMatrixINTEL : I32EnumAttrCase<"OpTypeJointMatrixINTEL", 6119>;
-def SPIRV_OC_OpJointMatrixLoadINTEL : I32EnumAttrCase<"OpJointMatrixLoadINTEL", 6120>;
-def SPIRV_OC_OpJointMatrixStoreINTEL : I32EnumAttrCase<"OpJointMatrixStoreINTEL", 6121>;
-def SPIRV_OC_OpJointMatrixMadINTEL : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>;
-def SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>;
-
def SPIRV_OC_OpConvertFToBF16INTEL : I32EnumAttrCase<"OpConvertFToBF16INTEL", 6116>;
def SPIRV_OC_OpConvertBF16ToFINTEL : I32EnumAttrCase<"OpConvertBF16ToFINTEL", 6117>;
@@ -4579,10 +4556,6 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR,
SPIRV_OC_OpGroupFMulKHR,
- SPIRV_OC_OpTypeJointMatrixINTEL, SPIRV_OC_OpJointMatrixLoadINTEL,
- SPIRV_OC_OpJointMatrixStoreINTEL, SPIRV_OC_OpJointMatrixMadINTEL,
- SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL,
-
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL
]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
deleted file mode 100644
index f96849de9abb1e..00000000000000
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
+++ /dev/null
@@ -1,243 +0,0 @@
-//===- SPIRVJointMatrixOps.td - joint matmul ---------------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This is the op definition spec of joint matrix multiply extension ops.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
-#define MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
-
-// -----
-
-def SPIRV_INTELJointMatrixWorkItemLengthOp : SPIRV_IntelVendorOp<"JointMatrixWorkItemLength",
- [Pure]> {
- let summary = "See extension SPV_INTEL_joint_matrix";
-
- let description = [{
- Return number of components owned by the current work-item in
- a joint matrix.
-
- Result Type must be an 32-bit unsigned integer type scalar.
-
- Type is a joint matrix type.
-
- #### Example:
-
- ```
- %0 = spirv.INTEL.JointMatrixWorkItemLength : !spirv.jointmatrix<Subgroup, i32, 8, 16>
- ```
- }];
-
- let assemblyFormat = "attr-dict `:` $joint_matrix_type";
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_INTEL_joint_matrix]>,
- Capability<[SPIRV_C_JointMatrixINTEL]>
- ];
-
- let arguments = (ins
- TypeAttr:$joint_matrix_type
- );
-
- let results = (outs
- SPIRV_Int32:$result
- );
- let hasVerifier = 0;
-}
-
-// -----
-
-def SPIRV_INTELJointMatrixLoadOp : SPIRV_IntelVendorOp<"JointMatrixLoad", []> {
- let summary = "See extension SPV_INTEL_joint_matrix";
-
- let description = [{
- Load a matrix through a pointer.
-
- Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
-
- Pointer is the pointer to load through. It specifies start of memory region where
- elements of the matrix are stored and arranged according to Layout.
-
- Stride is the number of elements in memory between beginnings of successive rows,
- columns (or words) in the result. It must be a scalar integer type.
-
- Layout indicates how the values loaded from memory are arranged. It must be the
- result of a constant instruction.
-
- Scope is syncronization scope for operation on the matrix. It must be the result
- of a constant instruction with scalar integer type.
-
- If present, any Memory Operands must begin with a memory operand literal. If not
- present, it is the same as specifying the memory operand None.
-
- #### Example:
- ```mlir
- %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride
- {memory_access = #spirv.memory_access<Volatile>} :
- (!spirv.ptr<i32, CrossWorkgroup>, i32) ->
- !spirv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
- ```
- }];
-
- let assemblyFormat = [{
- $scope $layout operands attr-dict `:` `(` type(operands) `)` `->` type($result)
- }];
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_INTEL_joint_matrix]>,
- Capability<[SPIRV_C_JointMatrixINTEL]>
- ];
-
- let arguments = (ins
- SPIRV_AnyPtr:$pointer,
- SPIRV_Integer:$stride,
- SPIRV_MatrixLayoutAttr:$layout,
- SPIRV_ScopeAttr:$scope,
- OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
- OptionalAttr<I32Attr>:$alignment
- );
-
- let results = (outs
- SPIRV_AnyJointMatrix:$result
- );
-}
-
-// -----
-
-def SPIRV_INTELJointMatrixMadOp : SPIRV_IntelVendorOp<"JointMatrixMad",
- [Pure, AllTypesMatch<["c", "result"]>]> {
- let summary = "See extension SPV_INTEL_joint_matrix";
-
- let description = [{
- Multiply matrix A by matrix B and add matrix C to the result
- of the multiplication: A*B+C. Here A is a M x K matrix, B is
- a K x N matrix and C is a M x N matrix.
-
- Behavior is undefined if sizes of operands do not meet the
- conditions above. All operands and the Result Type must be
- OpTypeJointMatrixINTEL.
-
- A must be a OpTypeJointMatrixINTEL whose Component Type is a
- signed numerical type, Row Count equals to M and Column Count
- equals to K
-
- B must be a OpTypeJointMatrixINTEL whose Component Type is a
- signed numerical type, Row Count equals to K and Column Count
- equals to N
-
- C and Result Type must be a OpTypeJointMatrixINTEL with Row
- Count equals to M and Column Count equals to N
-
- Scope is syncronization scope for operation on the matrix.
- It must be the result of a constant instruction with scalar
- integer type.
-
- #### Example:
- ```mlir
- %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c :
- !spirv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
- !spirv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
- -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
- ```
-
- }];
-
- let assemblyFormat = [{
- $scope operands attr-dict`:` type($a) `,` type($b) `->` type($c)
- }];
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_INTEL_joint_matrix]>,
- Capability<[SPIRV_C_JointMatrixINTEL]>
- ];
-
- let arguments = (ins
- SPIRV_AnyJointMatrix:$a,
- SPIRV_AnyJointMatrix:$b,
- SPIRV_AnyJointMatrix:$c,
- SPIRV_ScopeAttr:$scope
- );
-
- let results = (outs
- SPIRV_AnyJointMatrix:$result
- );
-}
-
-// -----
-
-def SPIRV_INTELJointMatrixStoreOp : SPIRV_IntelVendorOp<"JointMatrixStore", []> {
- let summary = "See extension SPV_INTEL_joint_matrix";
-
- let description = [{
- Store a matrix through a pointer.
-
- Pointer is the pointer to store through. It specifies
- start of memory region where elements of the matrix must
- be stored and arranged according to Layout.
-
- Object is the matrix to store. It must be
- OpTypeJointMatrixINTEL.
-
- Stride is the number of elements in memory between beginnings
- of successive rows, columns (or words) of the Object. It must
- be a scalar integer type.
-
- Layout indicates how the values stored to memory are arranged.
- It must be the result of a constant instruction.
-
- Scope is syncronization scope for operation on the matrix.
- It must be the result of a constant instruction with scalar
- integer type.
-
- If present, any Memory Operands must begin with a memory operand
- literal. If not present, it is the same as specifying the memory
- operand None.
-
- #### Example:
- ```mlir
- spirv.INTEL.JointMatrixStore <Subgroup> <ColumnMajor> %ptr, %m, %stride
- {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, Workgroup>,
- !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
- ```
-
- }];
-
- let assemblyFormat = [{
- $scope $layout operands attr-dict `:` `(` type(operands) `)`
- }];
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_INTEL_joint_matrix]>,
- Capability<[SPIRV_C_JointMatrixINTEL]>
- ];
-
- let arguments = (ins
- SPIRV_AnyPtr:$pointer,
- SPIRV_AnyJointMatrix:$object,
- SPIRV_Integer:$stride,
- SPIRV_MatrixLayoutAttr:$layout,
- SPIRV_ScopeAttr:$scope,
- OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
- OptionalAttr<I32Attr>:$alignment
- );
-
- let results = (outs);
-}
-
-// -----
-
-#endif // MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 13533d1d65b8ff..9912f195ba11e6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -30,7 +30,6 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCastOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
-include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 55f0c787b44403..d00bd818f48cca 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -30,7 +30,6 @@ namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
-struct JointMatrixTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
@@ -420,35 +419,6 @@ class CooperativeMatrixType
std::optional<StorageClass> storage = std::nullopt);
};
-// SPIR-V joint matrix type
-class JointMatrixINTELType
- : public Type::TypeBase<JointMatrixINTELType, CompositeType,
- detail::JointMatrixTypeStorage> {
-public:
- using Base::Base;
-
- static constexpr StringLiteral name = "spirv.jointmatrix";
-
- static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
- unsigned columns, MatrixLayout matrixLayout);
- Type getElementType() const;
-
- /// Return the scope of the joint matrix.
- Scope getScope() const;
- /// return the number of rows of the matrix.
- unsigned getRows() const;
- /// return the number of columns of the matrix.
- unsigned getColumns() const;
-
- /// return the layout of the matrix
- MatrixLayout getMatrixLayout() const;
-
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
-};
-
// SPIR-V matrix type
class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
detail::MatrixTypeStorage> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index b185264211474f..7d760e0dd80222 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -9,7 +9,6 @@ add_mlir_dialect_library(MLIRSPIRVDialect
CooperativeMatrixOps.cpp
GroupOps.cpp
IntegerDotProductOps.cpp
- JointMatrixOps.cpp
MemoryOps.cpp
SPIRVAttributes.cpp
SPIRVCanonicalization.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index 52b4380ed27f7c..e27dc274673be4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -36,8 +36,7 @@ static LogicalResult verifyCastOp(Operation *op,
using TypePair = std::pair<Type, Type>;
auto [operandElemTy, resultElemTy] =
TypeSwitch<Type, TypePair>(operandType)
- .Case<VectorType, spirv::CooperativeMatrixType,
- spirv::JointMatrixINTELType>(
+ .Case<VectorType, spirv::CooperativeMatrixType>(
[resultType](auto concreteOperandTy) -> TypePair {
if (auto concreteResultTy =
dyn_cast<decltype(concreteOperandTy)>(resultType)) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
deleted file mode 100644
index 63305ecdd0c4e9..00000000000000
--- a/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
+++ /dev/null
@@ -1,84 +0,0 @@
-//===- JointMatrixOps.cpp - MLIR SPIR-V Intel Joint Matrix Ops -----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Defines the Intel Joint Matrix operations in the SPIR-V dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
-
-namespace mlir {
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixLoad
-//===----------------------------------------------------------------------===//
-
-static LogicalResult
-verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
- Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
- if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
- !llvm::isa<VectorType>(pointeeType))
- return op->emitError(
- "Pointer must point to a scalar or vector type but provided ")
- << pointeeType;
- spirv::StorageClass storage =
- llvm::cast<spirv::PointerType>(pointer).getStorageClass();
- if (storage != spirv::StorageClass::Workgroup &&
- storage != spirv::StorageClass::CrossWorkgroup &&
- storage != spirv::StorageClass::UniformConstant &&
- storage != spirv::StorageClass::Generic)
- return op->emitError("Pointer storage class must be Workgroup or "
- "CrossWorkgroup but provided ")
- << stringifyStorageClass(storage);
- return success();
-}
-
-LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
- return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
- getResult().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixStore
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
- return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
- getObject().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixMad
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
- if (op.getC().getType() != op.getResult().getType())
- return op.emitOpError("result and third operand must have the same type");
- auto typeA = llvm::cast<spirv::JointMatrixINTELType>(op.getA().getType());
- auto typeB = llvm::cast<spirv::JointMatrixINTELType>(op.getB().getType());
- auto typeC = llvm::cast<spirv::JointMatrixINTELType>(op.getC().getType());
- auto typeR =
- llvm::cast<spirv::JointMatrixINTELType>(op.getResult().getType());
- if (typeA.getRows() != typeR.getRows() ||
- typeA.getColumns() != typeB.getRows() ||
- typeB.getColumns() != typeR.getColumns())
- return op.emitOpError("matrix size must match");
- if (typeR.getScope() != typeA.getScope() ||
- typeR.getScope() != typeB.getScope() ||
- typeR.getScope() != typeC.getScope())
- return op.emitOpError("matrix scope must match");
- if (typeA.getElementType() != typeB.getElementType() ||
- typeR.getElementType() != typeC.getElementType())
- return op.emitOpError("matrix element type must match");
- return success();
-}
-
-LogicalResult spirv::INTELJointMatrixMadOp::verify() {
- return verifyJointMatrixMad(*this);
-}
-
-} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 72488d6e5d0b09..48be287ef833b2 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -362,41 +362,6 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
}
-// joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
-// element-type
-// `,` layout `,` scope`>`
-static Type parseJointMatrixType(SPIRVDialect const &dialect,
- DialectAsmParser &parser) {
- if (parser.parseLess())
- return Type();
-
- SmallVector<int64_t, 2> dims;
- SMLoc countLoc = parser.getCurrentLocation();
- if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
- return Type();
-
- if (dims.size() != 2) {
- parser.emitError(countLoc, "expected rows and columns size");
- return Type();
- }
-
- auto elementTy = parseAndVerifyType(dialect, parser);
- if (!elementTy)
- return Type();
- MatrixLayout matrixLayout;
- if (parser.parseComma() ||
- spirv::parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
- return Type();
- Scope scope;
- if (parser.parseComma() ||
- spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
- return Type();
- if (parser.parseGreater())
- return Type();
- return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1],
- matrixLayout);
-}
-
// TODO: Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
@@ -781,8 +746,6 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseArrayType(*this, parser);
if (keyword == "coopmatrix")
return parseCooperativeMatrixType(*this, parser);
- if (keyword == "jointmatrix")
- return parseJointMatrixType(*this, parser);
if (keyword == "image")
return parseImageType(*this, parser);
if (keyword == "ptr")
@@ -886,13 +849,6 @@ static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
<< type.getUse() << ">";
}
-static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
- os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
- os << type.getElementType() << ", "
- << stringifyMatrixLayout(type.getMatrixLayout());
- os << ", " << stringifyScope(type.getScope()) << ">";
-}
-
static void print(MatrixType type, DialectAsmPrinter &os) {
os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
os << ">";
@@ -900,9 +856,9 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
- .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, PointerType,
- RuntimeArrayType, ImageType, SampledImageType, StructType,
- MatrixType>([&](auto type) { print(type, os); })
+ .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
+ ImageType, SampledImageType, StructType, MatrixType>(
+ [&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index d9bc05acddc82b..c8386fecea038a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -373,7 +373,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
auto coopElementType =
llvm::TypeSwitch<Type, Type>(getType())
- .Case<spirv::CooperativeMatrixType, spirv::JointMatrixINTELType>(
+ .Case<spirv::CooperativeMatrixType>(
[](auto coopType) { return coopType.getElementType(); })
.Default([](Type) { return nullptr; });
@@ -1834,8 +1834,6 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
if (llvm::isa<spirv::CooperativeMatrixType>(cType))
return emitError("unsupported composite type ") << cType;
- if (llvm::isa<spirv::JointMatrixINTELType>(cType))
- return emitError("unsupported composite type ") << cType;
if (constituents.size() != cType.getNumElements())
return emitError("has incorrect number of operands: expected ")
<< cType.getNumElements() << ", but provided "
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 3808620bdffa6d..cd531882cafb08 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -95,8 +95,8 @@ bool CompositeType::classof(Type type) {
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return isValid(vectorType);
return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
- spirv::JointMatrixINTELType, spirv::MatrixType,
- spirv::RuntimeArrayType, spirv::StructType>(type);
+ spirv::MatrixType, spirv::RuntimeArrayType,
+ spirv::StructType>(type);
}
bool CompositeType::isValid(VectorType type) {
@@ -107,8 +107,7 @@ bool CompositeType::isValid(VectorType type) {
Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType,
- RuntimeArrayType, VectorType>(
+ .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
[](auto type) { return type.getElementType(); })
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
@@ -130,10 +129,6 @@ unsigned CompositeType::getNumElements() const {
llvm_unreachable(
"invalid to query number of elements of spirv Cooperative Matrix type");
}
- if (llvm::isa<JointMatrixINTELType>(*this)) {
- llvm_unreachable(
- "invalid to query number of elements of spirv::JointMatrix type");
- }
if (llvm::isa<RuntimeArrayType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
@@ -142,16 +137,15 @@ unsigned CompositeType::getNumElements() const {
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
- return !llvm::isa<CooperativeMatrixType, JointMatrixINTELType,
- RuntimeArrayType>(*this);
+ return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
}
void CompositeType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
- RuntimeArrayType, StructType>(
+ .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
+ StructType>(
[&](auto type) { type.getExtensions(extensions, storage); })
.Case<VectorType>([&](VectorType type) {
return llvm::cast<ScalarType>(type.getElementType())
@@ -164,8 +158,8 @@ void CompositeType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
- RuntimeArrayType, StructType>(
+ .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
+ StructType>(
[&](auto type) { type.getCapabilities(capabilities, storage); })
.Case<VectorType>([&](VectorType type) {
auto vecSize = getNumElements();
@@ -266,75 +260,6 @@ void CooperativeMatrixType::getCapabilities(
capabilities.push_back(caps);
}
-//===----------------------------------------------------------------------===//
-// JointMatrixType
-//===----------------------------------------------------------------------===//
-
-struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
- using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
-
- static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator,
- const KeyTy &key) {
- return new (allocator.allocate<JointMatrixTypeStorage>())
- JointMatrixTypeStorage(key);
- }
-
- bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, rows, columns, matrixLayout, scope);
- }
-
- JointMatrixTypeStorage(const KeyTy &key)
- : elementType(std::get<0>(key)), rows(std::get<1>(key)),
- columns(std::get<2>(key)), scope(std::get<4>(key)),
- matrixLayout(std::get<3>(key)) {}
-
- Type elementType;
- unsigned rows;
- unsigned columns;
- Scope scope;
- MatrixLayout matrixLayout;
-};
-
-JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope,
- unsigned rows, unsigned columns,
- MatrixLayout matrixLayout) {
- return Base::get(elementType.getContext(), elementType, rows, columns,
- matrixLayout, scope);
-}
-
-Type JointMatrixINTELType::getElementType() const {
- return getImpl()->elementType;
-}
-
-Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; }
-
-unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; }
-
-unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; }
-
-MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
- return getImpl()->matrixLayout;
-}
-
-void JointMatrixINTELType::getExtensions(
- SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
- static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
- ArrayRef<Extension> ref(exts, std::size(exts));
- extensions.push_back(ref);
-}
-
-void JointMatrixINTELType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType())
- .getCapabilities(capabilities, storage);
- static const Capability caps[] = {Capability::JointMatrixINTEL};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
-}
-
//===----------------------------------------------------------------------===//
// ImageType
//===----------------------------------------------------------------------===//
@@ -1247,7 +1172,6 @@ void MatrixType::getCapabilities(
//===----------------------------------------------------------------------===//
void SPIRVDialect::registerTypes() {
- addTypes<ArrayType, CooperativeMatrixType, ImageType, JointMatrixINTELType,
- MatrixType, PointerType, RuntimeArrayType, SampledImageType,
- StructType>();
+ addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
+ RuntimeArrayType, SampledImageType, StructType>();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 5b2903824c9e76..b30da773d48967 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -168,8 +168,6 @@ LogicalResult spirv::Deserializer::processInstruction(
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
return processTypeForwardPointer(operands);
- case spirv::Opcode::OpTypeJointMatrixINTEL:
- return processType(opcode, operands);
case spirv::Opcode::OpConstant:
return processConstant(operands, /*isSpec=*/false);
case spirv::Opcode::OpSpecConstant:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 12980879b20ab7..38293f7106a05a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -852,8 +852,6 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processCooperativeMatrixTypeKHR(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
- case spirv::Opcode::OpTypeJointMatrixINTEL:
- return processJointMatrixType(operands);
case spirv::Opcode::OpTypeImage:
return processImageType(operands);
case spirv::Opcode::OpTypeSampledImage:
@@ -1025,40 +1023,6 @@ LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
return success();
}
-LogicalResult
-spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
- if (operands.size() != 6) {
- return emitError(unknownLoc, "OpTypeJointMatrix must have element "
- "type and row x column parameters");
- }
-
- Type elementTy = getType(operands[1]);
- if (!elementTy) {
- return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
- << operands[1];
- }
-
- auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
- if (!scope) {
- return emitError(unknownLoc,
- "OpTypeJointMatrix references undefined scope <id> ")
- << operands[5];
- }
- auto matrixLayout =
- spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
- if (!matrixLayout) {
- return emitError(unknownLoc,
- "OpTypeJointMatrix references undefined scope <id> ")
- << operands[4];
- }
- unsigned rows = getConstantInt(operands[2]).getInt();
- unsigned columns = getConstantInt(operands[3]).getInt();
-
- typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
- elementTy, scope.value(), rows, columns, matrixLayout.value());
- return success();
-}
-
LogicalResult
spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index fc9a8f5f9364b2..264d580c40f097 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -273,8 +273,6 @@ class Deserializer {
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
- LogicalResult processJointMatrixType(ArrayRef<uint32_t> operands);
-
LogicalResult processImageType(ArrayRef<uint32_t> operands);
LogicalResult processSampledImageType(ArrayRef<uint32_t> operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 714a3edfb56573..b0feda0517caa6 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -662,25 +662,6 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
- uint32_t elementTypeID = 0;
- if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
- elementTypeID, serializationCtx))) {
- return failure();
- }
- typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
- auto getConstantOp = [&](uint32_t id) {
- auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
- return prepareConstantInt(loc, attr);
- };
- llvm::append_values(
- operands, elementTypeID, getConstantOp(jointMatrixType.getRows()),
- getConstantOp(jointMatrixType.getColumns()),
- getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixLayout())),
- getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
- return success();
- }
-
if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
diff --git a/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
deleted file mode 100644
index afb856d9b13cd1..00000000000000
--- a/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
+++ /dev/null
@@ -1,99 +0,0 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
-
-// CHECK-LABEL: @joint_matrix_load
-spirv.func @joint_matrix_load(%ptr : !spirv.ptr<i32, Workgroup>, %stride : i32) "None" {
- // CHECK: {{%.*}} = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>
- %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>
- spirv.Return
-}
-
-// -----
-// CHECK-LABEL: @joint_matrix_load_memaccess
-spirv.func @joint_matrix_load_memaccess(%ptr : !spirv.ptr<i32, CrossWorkgroup>, %stride : i32) "None" {
- // CHECK: {{%.*}} = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, CrossWorkgroup>, i32) -> !spirv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
- %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, CrossWorkgroup>, i32) -> !spirv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_load_diff_ptr_type
-spirv.func @joint_matrix_load_diff_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, Workgroup>, %stride : i32) "None" {
- // CHECK: {{%.*}} = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<vector<4xi32>, Workgroup>, i32) -> !spirv.jointmatrix<8x16xi32, RowMajor, Workgroup>
- %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<vector<4xi32>, Workgroup>, i32) -> !spirv.jointmatrix<8x16xi32, RowMajor, Workgroup>
- spirv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_store
-spirv.func @joint_matrix_store(%ptr : !spirv.ptr<i32, Workgroup>, %stride : i32, %m : !spirv.jointmatrix<8x16xi32, RowMajor, Workgroup>) "None" {
- // CHECK: spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
- spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> %ptr, %m, %stride : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
- spirv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_store_memaccess
-spirv.func @joint_matrix_store_memaccess(%ptr : !spirv.ptr<i32, Workgroup>, %m : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
- // CHECK: spirv.INTEL.JointMatrixStore <Subgroup> <ColumnMajor> {{%.*}}, {{%.*}}, {{%.*}} {Volatile} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
- spirv.INTEL.JointMatrixStore <Subgroup> <ColumnMajor> %ptr, %m, %stride {Volatile} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
- spirv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_length
-spirv.func @joint_matrix_length() -> i32 "None" {
- // CHECK: {{%.*}} = spirv.INTEL.JointMatrixWorkItemLength : !spirv.jointmatrix<8x16xi32, PackedB, Subgroup>
- %0 = spirv.INTEL.JointMatrixWorkItemLength : !spirv.jointmatrix<8x16xi32, PackedB, Subgroup>
- spirv.ReturnValue %0 : i32
-}
-
-// CHECK-LABEL: @joint_matrix_muladd
-spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<8x32xi8, RowMajor, Subgroup>, %b : !spirv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.INTEL.JointMatrixMad <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spirv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spirv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
- %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spirv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<16x16xi32, RowMajor, Subgroup>, %b : !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
- // expected-error @+1 {{'spirv.INTEL.JointMatrixMad' op matrix size must match}}
- %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<16x16xi32, RowMajor, Subgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
- // expected-error @+1 {{'spirv.INTEL.JointMatrixMad' op matrix size must match}}
- %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
- // expected-error @+1 {{'spirv.INTEL.JointMatrixMad' op matrix scope must match}}
- %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
- // expected-error @+1 {{matrix element type must match}}
- %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, Workgroup>, %stride : i32) "None" {
- // expected-error @+1 {{Pointer must point to a scalar or vector type}}
- %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spirv.ptr<!spirv.struct<(f32 [0])>, Workgroup>, i32)-> !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
- spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>, %stride : i32) "None" {
- // expected-error @+1 {{Pointer storage class must be Workgroup or CrossWorkgroup}}
- %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spirv.ptr<i32, Function>, i32) -> !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
- spirv.Return
-}
diff --git a/mlir/test/Target/SPIRV/joint-matrix-ops.mlir b/mlir/test/Target/SPIRV/joint-matrix-ops.mlir
deleted file mode 100644
index a89921c5d0d361..00000000000000
--- a/mlir/test/Target/SPIRV/joint-matrix-ops.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [JointMatrixINTEL], [SPV_INTEL_joint_matrix]> {
- // CHECK-LABEL: @joint_matrix_load
- spirv.func @joint_matrix_load(%ptr : !spirv.ptr<i32, Workgroup>, %stride : i32) "None" {
- // CHECK: {{%.*}} = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>
- %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @joint_matrix_load_memaccess
- spirv.func @joint_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Workgroup>, %stride : i32) "None" {
- // CHECK: {{%.*}} = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
- %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
- spirv.Return
- }
-
- // CHECK-LABEL: @joint_matrix_store
- spirv.func @joint_matrix_store(%ptr : !spirv.ptr<i32, Workgroup>, %stride : i32, %m : !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>) "None" {
- // CHECK: spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
- spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> %ptr, %m, %stride : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
- spirv.Return
- }
-
- // CHECK-LABEL: @joint_matrix_store_memaccess
- spirv.func @joint_matrix_store_memaccess(%ptr : !spirv.ptr<i32, Workgroup>, %m : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
- // CHECK: spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
- spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> %ptr, %m, %stride {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
- spirv.Return
- }
-
- // CHECK-LABEL: @joint_matrix_length
- spirv.func @joint_matrix_length() -> i32 "None" {
- // CHECK: {{%.*}} = spirv.INTEL.JointMatrixWorkItemLength : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
- %0 = spirv.INTEL.JointMatrixWorkItemLength : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
- spirv.ReturnValue %0 : i32
- }
-
- // CHECK-LABEL: @joint_matrix_muladd
- spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
- // CHECK: {{%.*}} = spirv.INTEL.JointMatrixMad <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
- %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
- spirv.Return
- }
-}
More information about the Mlir-commits
mailing list