[Mlir-commits] [mlir] [mlir][spirv] Drop support for SPV_INTEL_joint_matrix (PR #102332)

Andrea Faulds llvmlistbot at llvm.org
Wed Aug 7 09:51:27 PDT 2024


https://github.com/andfau-amd created https://github.com/llvm/llvm-project/pull/102332

This was a "preview" extension that was never formalized and has since been superseded by SPV_KHR_cooperative_matrix.

>From 2cd9bfdfa0d3a4a2ce813b6a5fb7d17f7883f4e8 Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Wed, 7 Aug 2024 18:46:37 +0200
Subject: [PATCH] [mlir][spirv] Drop support for SPV_INTEL_joint_matrix

This was a "preview" extension that was never formalized and has since
been superseded by SPV_KHR_cooperative_matrix.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVAttributes.td  |  21 --
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  35 +--
 .../Dialect/SPIRV/IR/SPIRVJointMatrixOps.td   | 243 ------------------
 .../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td |   1 -
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        |  30 ---
 mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt      |   1 -
 mlir/lib/Dialect/SPIRV/IR/CastOps.cpp         |   3 +-
 mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp  |  84 ------
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  50 +---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |   4 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      |  96 +------
 .../SPIRV/Deserialization/DeserializeOps.cpp  |   2 -
 .../SPIRV/Deserialization/Deserializer.cpp    |  36 ---
 .../SPIRV/Deserialization/Deserializer.h      |   2 -
 .../Target/SPIRV/Serialization/Serializer.cpp |  19 --
 .../Dialect/SPIRV/IR/joint-matrix-ops.mlir    |  99 -------
 mlir/test/Target/SPIRV/joint-matrix-ops.mlir  |  45 ----
 17 files changed, 19 insertions(+), 752 deletions(-)
 delete mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
 delete mode 100644 mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
 delete mode 100644 mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
 delete mode 100644 mlir/test/Target/SPIRV/joint-matrix-ops.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index 3a11284da05122..f2a12f68d481b8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -101,27 +101,6 @@ def SPIRV_CooperativeMatrixPropertiesNVArrayAttr :
     TypedArrayAttrBase<SPIRV_CooperativeMatrixPropertiesNVAttr,
                        "CooperativeMatrixPropertiesNV array attribute">;
 
-// Description of the supported joint matrix operations. See
-// https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc
-def SPIRV_JointMatrixPropertiesINTELAttr :
-    SPIRV_Attr<"JointMatrixPropertiesINTEL", "joint_matrix_props"> {
-  let parameters = (ins
-    "int":$m_size,
-    "int":$n_size,
-    "int":$k_size,
-    "mlir::Type":$a_type,
-    "mlir::Type":$b_type,
-    "mlir::Type":$c_type,
-    "mlir::Type":$result_type,
-    "mlir::spirv::ScopeAttr":$scope
-  );
-  let assemblyFormat = "`<` struct(params) `>`";
-}
-
-def SPIRV_JointMatrixPropertiesINTELArrayAttr :
-    TypedArrayAttrBase<SPIRV_JointMatrixPropertiesINTELAttr,
-                       "JointMatrixPropertiesINTEL array attribute">;
-
 // This attribute specifies the limits for various resources on the target
 // architecture.
 //
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index b38978272c5bdc..af0b2624feb327 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -399,7 +399,6 @@ def SPV_INTEL_debug_module                       : I32EnumAttrCase<"SPV_INTEL_de
 def SPV_INTEL_fp_fast_math_mode                  : I32EnumAttrCase<"SPV_INTEL_fp_fast_math_mode", 4027>;
 def SPV_INTEL_memory_access_aliasing             : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
 def SPV_INTEL_split_barrier                      : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
-def SPV_INTEL_joint_matrix                       : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>;
 def SPV_INTEL_bfloat16_conversion                : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
 
 def SPV_NV_compute_shader_derivatives    : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
@@ -459,7 +458,7 @@ def SPIRV_ExtensionAttr :
       SPV_INTEL_usm_storage_classes, SPV_INTEL_io_pipes, SPV_INTEL_blocking_pipes,
       SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
       SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
-      SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix,
+      SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
       SPV_INTEL_bfloat16_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
       SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
       SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@@ -1410,12 +1409,6 @@ def SPIRV_C_ShaderStereoViewNV                          : I32EnumAttrCase<"Shade
   ];
 }
 
-def SPIRV_C_JointMatrixINTEL                         : I32EnumAttrCase<"JointMatrixINTEL", 6118> {
-  list<Availability> availability = [
-    Extension<[SPV_INTEL_joint_matrix]>
-  ];
-}
-
 def SPIRV_C_Bfloat16ConversionINTEL                         : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_bfloat16_conversion]>
@@ -1514,7 +1507,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_UniformTexelBufferArrayNonUniformIndexing,
       SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
       SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
-      SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL, SPIRV_C_Bfloat16ConversionINTEL
+      SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL
     ]>;
 
 def SPIRV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
@@ -4131,8 +4124,6 @@ def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
 def SPIRV_IsCooperativeMatrixType :
   CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
 def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
-def SPIRV_IsJointMatrixType :
-  CPred<"::llvm::isa<::mlir::spirv::JointMatrixINTELType>($_self)">;
 def SPIRV_IsMatrixType : CPred<"::llvm::isa<::mlir::spirv::MatrixType>($_self)">;
 def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
 def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
@@ -4164,8 +4155,6 @@ def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
                                   "any SPIR-V cooperative matrix type">;
 def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
                                 "any SPIR-V image type">;
-def SPIRV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPIRV_IsJointMatrixType,
-                                "any SPIR-V joint matrix type">;
 def SPIRV_AnyMatrix : DialectType<SPIRV_Dialect, SPIRV_IsMatrixType,
                                 "any SPIR-V matrix type">;
 def SPIRV_AnyRTArray : DialectType<SPIRV_Dialect, SPIRV_IsRTArrayType,
@@ -4180,12 +4169,11 @@ def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
 def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
 def SPIRV_Composite :
     AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-               SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
+               SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
 def SPIRV_Type : AnyTypeOf<[
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-    SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
-    SPIRV_AnySampledImage
+    SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4196,11 +4184,6 @@ class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
     "::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()",
     "Cooperative Matrix">;
 
-class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
-  ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsJointMatrixType,
-    "::llvm::cast<::mlir::spirv::JointMatrixINTELType>($_self).getElementType()",
-    "Joint Matrix">;
-
 class SPIRV_VectorOf<Type type> :
     VectorOfLengthAndType<[2, 3, 4, 8,16], [type]>;
 
@@ -4482,12 +4465,6 @@ def SPIRV_OC_OpAtomicFAddEXT              : I32EnumAttrCase<"OpAtomicFAddEXT", 6
 def SPIRV_OC_OpGroupIMulKHR               : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
 def SPIRV_OC_OpGroupFMulKHR               : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
 
-def SPIRV_OC_OpTypeJointMatrixINTEL       : I32EnumAttrCase<"OpTypeJointMatrixINTEL", 6119>;
-def SPIRV_OC_OpJointMatrixLoadINTEL       : I32EnumAttrCase<"OpJointMatrixLoadINTEL", 6120>;
-def SPIRV_OC_OpJointMatrixStoreINTEL      : I32EnumAttrCase<"OpJointMatrixStoreINTEL", 6121>;
-def SPIRV_OC_OpJointMatrixMadINTEL        : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>;
-def SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>;
-
 def SPIRV_OC_OpConvertFToBF16INTEL        : I32EnumAttrCase<"OpConvertFToBF16INTEL", 6116>;
 def SPIRV_OC_OpConvertBF16ToFINTEL        : I32EnumAttrCase<"OpConvertBF16ToFINTEL", 6117>;
 
@@ -4579,10 +4556,6 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR,
       SPIRV_OC_OpGroupFMulKHR,
 
-      SPIRV_OC_OpTypeJointMatrixINTEL, SPIRV_OC_OpJointMatrixLoadINTEL,
-      SPIRV_OC_OpJointMatrixStoreINTEL, SPIRV_OC_OpJointMatrixMadINTEL,
-      SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL,
-
       SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL
     ]>;
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
deleted file mode 100644
index f96849de9abb1e..00000000000000
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
+++ /dev/null
@@ -1,243 +0,0 @@
-//===- SPIRVJointMatrixOps.td - joint matmul ---------------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This is the op definition spec of joint matrix multiply extension ops.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
-#define MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
-
-// -----
-
-def SPIRV_INTELJointMatrixWorkItemLengthOp : SPIRV_IntelVendorOp<"JointMatrixWorkItemLength",
-  [Pure]> {
-  let summary = "See extension SPV_INTEL_joint_matrix";
-
-  let description = [{
-    Return number of components owned by the current work-item in
-    a joint matrix.
-
-    Result Type must be an 32-bit unsigned integer type scalar.
-
-    Type is a joint matrix type.
-
-    #### Example:
-
-    ```
-    %0 = spirv.INTEL.JointMatrixWorkItemLength : !spirv.jointmatrix<Subgroup, i32, 8, 16>
-    ```
-  }];
-
-  let assemblyFormat = "attr-dict `:` $joint_matrix_type";
-
-  let availability = [
-    MinVersion<SPIRV_V_1_0>,
-    MaxVersion<SPIRV_V_1_6>,
-    Extension<[SPV_INTEL_joint_matrix]>,
-    Capability<[SPIRV_C_JointMatrixINTEL]>
-  ];
-
-  let arguments = (ins
-    TypeAttr:$joint_matrix_type
-  );
-
-  let results = (outs
-    SPIRV_Int32:$result
-  );
-  let hasVerifier = 0;
-}
-
-// -----
-
-def SPIRV_INTELJointMatrixLoadOp : SPIRV_IntelVendorOp<"JointMatrixLoad", []> {
-  let summary = "See extension SPV_INTEL_joint_matrix";
-
-  let description = [{
-    Load a matrix through a pointer.
-
-    Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
-
-    Pointer is the pointer to load through. It specifies start of memory region where
-    elements of the matrix are stored and arranged according to Layout.
-
-    Stride is the number of elements in memory between beginnings of successive rows,
-    columns (or words) in the result. It must be a scalar integer type.
-
-    Layout indicates how the values loaded from memory are arranged. It must be the
-    result of a constant instruction.
-
-    Scope is syncronization scope for operation on the matrix. It must be the result
-    of a constant instruction with scalar integer type.
-
-    If present, any Memory Operands must begin with a memory operand literal. If not
-    present, it is the same as specifying the memory operand None.
-
-    #### Example:
-    ```mlir
-    %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride
-         {memory_access = #spirv.memory_access<Volatile>} :
-         (!spirv.ptr<i32, CrossWorkgroup>, i32) ->
-         !spirv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
-    ```
-  }];
-
-  let assemblyFormat = [{
-    $scope $layout operands attr-dict `:` `(` type(operands) `)` `->` type($result)
-  }];
-
-  let availability = [
-    MinVersion<SPIRV_V_1_0>,
-    MaxVersion<SPIRV_V_1_6>,
-    Extension<[SPV_INTEL_joint_matrix]>,
-    Capability<[SPIRV_C_JointMatrixINTEL]>
-  ];
-
-  let arguments = (ins
-    SPIRV_AnyPtr:$pointer,
-    SPIRV_Integer:$stride,
-    SPIRV_MatrixLayoutAttr:$layout,
-    SPIRV_ScopeAttr:$scope,
-    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
-    OptionalAttr<I32Attr>:$alignment
-  );
-
-  let results = (outs
-    SPIRV_AnyJointMatrix:$result
-  );
-}
-
-// -----
-
-def SPIRV_INTELJointMatrixMadOp : SPIRV_IntelVendorOp<"JointMatrixMad",
-  [Pure, AllTypesMatch<["c", "result"]>]> {
-  let summary = "See extension SPV_INTEL_joint_matrix";
-
-  let description = [{
-    Multiply matrix A by matrix B and add matrix C to the result
-    of the multiplication: A*B+C. Here A is a M x K matrix, B is
-    a K x N matrix and C is a M x N matrix.
-
-    Behavior is undefined if sizes of operands do not meet the
-    conditions above. All operands and the Result Type must be
-    OpTypeJointMatrixINTEL.
-
-    A must be a OpTypeJointMatrixINTEL whose Component Type is a
-    signed numerical type, Row Count equals to M and Column Count
-    equals to K
-
-    B must be a OpTypeJointMatrixINTEL whose Component Type is a
-    signed numerical type, Row Count equals to K and Column Count
-    equals to N
-
-    C and Result Type must be a OpTypeJointMatrixINTEL with Row
-    Count equals to M and Column Count equals to N
-
-    Scope is syncronization scope for operation on the matrix.
-    It must be the result of a constant instruction with scalar
-    integer type.
-
-    #### Example:
-    ```mlir
-    %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c :
-         !spirv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
-         !spirv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
-         -> !spirv.jointmatrix<8x8xi32,  RowMajor, Subgroup>
-    ```
-
-  }];
-
-  let assemblyFormat = [{
-    $scope operands attr-dict`:` type($a) `,` type($b) `->` type($c)
-  }];
-
-  let availability = [
-    MinVersion<SPIRV_V_1_0>,
-    MaxVersion<SPIRV_V_1_6>,
-    Extension<[SPV_INTEL_joint_matrix]>,
-    Capability<[SPIRV_C_JointMatrixINTEL]>
-  ];
-
-  let arguments = (ins
-    SPIRV_AnyJointMatrix:$a,
-    SPIRV_AnyJointMatrix:$b,
-    SPIRV_AnyJointMatrix:$c,
-    SPIRV_ScopeAttr:$scope
-  );
-
-  let results = (outs
-    SPIRV_AnyJointMatrix:$result
-  );
-}
-
-// -----
-
-def SPIRV_INTELJointMatrixStoreOp : SPIRV_IntelVendorOp<"JointMatrixStore", []> {
-  let summary = "See extension SPV_INTEL_joint_matrix";
-
-  let description = [{
-    Store a matrix through a pointer.
-
-    Pointer is the pointer to store through. It specifies
-    start of memory region where elements of the matrix must
-    be stored and arranged according to Layout.
-
-    Object is the matrix to store. It must be
-    OpTypeJointMatrixINTEL.
-
-    Stride is the number of elements in memory between beginnings
-    of successive rows, columns (or words) of the Object. It must
-    be a scalar integer type.
-
-    Layout indicates how the values stored to memory are arranged.
-    It must be the result of a constant instruction.
-
-    Scope is syncronization scope for operation on the matrix.
-    It must be the result of a constant instruction with scalar
-    integer type.
-
-    If present, any Memory Operands must begin with a memory operand
-    literal. If not present, it is the same as specifying the memory
-    operand None.
-
-    #### Example:
-    ```mlir
-    spirv.INTEL.JointMatrixStore <Subgroup> <ColumnMajor> %ptr, %m, %stride
-    {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, Workgroup>,
-    !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
-    ```
-
-  }];
-
-   let assemblyFormat = [{
-    $scope $layout operands attr-dict `:` `(` type(operands) `)`
-  }];
-
-  let availability = [
-    MinVersion<SPIRV_V_1_0>,
-    MaxVersion<SPIRV_V_1_6>,
-    Extension<[SPV_INTEL_joint_matrix]>,
-    Capability<[SPIRV_C_JointMatrixINTEL]>
-  ];
-
-  let arguments = (ins
-    SPIRV_AnyPtr:$pointer,
-    SPIRV_AnyJointMatrix:$object,
-    SPIRV_Integer:$stride,
-    SPIRV_MatrixLayoutAttr:$layout,
-    SPIRV_ScopeAttr:$scope,
-    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
-    OptionalAttr<I32Attr>:$alignment
-  );
-
-  let results = (outs);
-}
-
-// -----
-
-#endif // MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 13533d1d65b8ff..9912f195ba11e6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -30,7 +30,6 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCastOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
-include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 55f0c787b44403..d00bd818f48cca 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -30,7 +30,6 @@ namespace detail {
 struct ArrayTypeStorage;
 struct CooperativeMatrixTypeStorage;
 struct ImageTypeStorage;
-struct JointMatrixTypeStorage;
 struct MatrixTypeStorage;
 struct PointerTypeStorage;
 struct RuntimeArrayTypeStorage;
@@ -420,35 +419,6 @@ class CooperativeMatrixType
                        std::optional<StorageClass> storage = std::nullopt);
 };
 
-// SPIR-V joint matrix type
-class JointMatrixINTELType
-    : public Type::TypeBase<JointMatrixINTELType, CompositeType,
-                            detail::JointMatrixTypeStorage> {
-public:
-  using Base::Base;
-
-  static constexpr StringLiteral name = "spirv.jointmatrix";
-
-  static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
-                                  unsigned columns, MatrixLayout matrixLayout);
-  Type getElementType() const;
-
-  /// Return the scope of the joint matrix.
-  Scope getScope() const;
-  /// return the number of rows of the matrix.
-  unsigned getRows() const;
-  /// return the number of columns of the matrix.
-  unsigned getColumns() const;
-
-  /// return the layout of the matrix
-  MatrixLayout getMatrixLayout() const;
-
-  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
-                     std::optional<StorageClass> storage = std::nullopt);
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
-};
-
 // SPIR-V matrix type
 class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
                                          detail::MatrixTypeStorage> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index b185264211474f..7d760e0dd80222 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -9,7 +9,6 @@ add_mlir_dialect_library(MLIRSPIRVDialect
   CooperativeMatrixOps.cpp
   GroupOps.cpp
   IntegerDotProductOps.cpp
-  JointMatrixOps.cpp
   MemoryOps.cpp
   SPIRVAttributes.cpp
   SPIRVCanonicalization.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index 52b4380ed27f7c..e27dc274673be4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -36,8 +36,7 @@ static LogicalResult verifyCastOp(Operation *op,
   using TypePair = std::pair<Type, Type>;
   auto [operandElemTy, resultElemTy] =
       TypeSwitch<Type, TypePair>(operandType)
-          .Case<VectorType, spirv::CooperativeMatrixType,
-                spirv::JointMatrixINTELType>(
+          .Case<VectorType, spirv::CooperativeMatrixType>(
               [resultType](auto concreteOperandTy) -> TypePair {
                 if (auto concreteResultTy =
                         dyn_cast<decltype(concreteOperandTy)>(resultType)) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
deleted file mode 100644
index 63305ecdd0c4e9..00000000000000
--- a/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
+++ /dev/null
@@ -1,84 +0,0 @@
-//===- JointMatrixOps.cpp - MLIR SPIR-V Intel Joint Matrix Ops  -----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Defines the Intel Joint Matrix operations in the SPIR-V dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
-
-namespace mlir {
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixLoad
-//===----------------------------------------------------------------------===//
-
-static LogicalResult
-verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
-  Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
-  if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
-      !llvm::isa<VectorType>(pointeeType))
-    return op->emitError(
-               "Pointer must point to a scalar or vector type but provided ")
-           << pointeeType;
-  spirv::StorageClass storage =
-      llvm::cast<spirv::PointerType>(pointer).getStorageClass();
-  if (storage != spirv::StorageClass::Workgroup &&
-      storage != spirv::StorageClass::CrossWorkgroup &&
-      storage != spirv::StorageClass::UniformConstant &&
-      storage != spirv::StorageClass::Generic)
-    return op->emitError("Pointer storage class must be Workgroup or "
-                         "CrossWorkgroup but provided ")
-           << stringifyStorageClass(storage);
-  return success();
-}
-
-LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
-  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
-                                         getResult().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixStore
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
-  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
-                                         getObject().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixMad
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
-  if (op.getC().getType() != op.getResult().getType())
-    return op.emitOpError("result and third operand must have the same type");
-  auto typeA = llvm::cast<spirv::JointMatrixINTELType>(op.getA().getType());
-  auto typeB = llvm::cast<spirv::JointMatrixINTELType>(op.getB().getType());
-  auto typeC = llvm::cast<spirv::JointMatrixINTELType>(op.getC().getType());
-  auto typeR =
-      llvm::cast<spirv::JointMatrixINTELType>(op.getResult().getType());
-  if (typeA.getRows() != typeR.getRows() ||
-      typeA.getColumns() != typeB.getRows() ||
-      typeB.getColumns() != typeR.getColumns())
-    return op.emitOpError("matrix size must match");
-  if (typeR.getScope() != typeA.getScope() ||
-      typeR.getScope() != typeB.getScope() ||
-      typeR.getScope() != typeC.getScope())
-    return op.emitOpError("matrix scope must match");
-  if (typeA.getElementType() != typeB.getElementType() ||
-      typeR.getElementType() != typeC.getElementType())
-    return op.emitOpError("matrix element type must match");
-  return success();
-}
-
-LogicalResult spirv::INTELJointMatrixMadOp::verify() {
-  return verifyJointMatrixMad(*this);
-}
-
-} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 72488d6e5d0b09..48be287ef833b2 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -362,41 +362,6 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
   return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
 }
 
-// joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x`
-// element-type
-//                                                       `,` layout `,` scope`>`
-static Type parseJointMatrixType(SPIRVDialect const &dialect,
-                                 DialectAsmParser &parser) {
-  if (parser.parseLess())
-    return Type();
-
-  SmallVector<int64_t, 2> dims;
-  SMLoc countLoc = parser.getCurrentLocation();
-  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
-    return Type();
-
-  if (dims.size() != 2) {
-    parser.emitError(countLoc, "expected rows and columns size");
-    return Type();
-  }
-
-  auto elementTy = parseAndVerifyType(dialect, parser);
-  if (!elementTy)
-    return Type();
-  MatrixLayout matrixLayout;
-  if (parser.parseComma() ||
-      spirv::parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
-    return Type();
-  Scope scope;
-  if (parser.parseComma() ||
-      spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
-    return Type();
-  if (parser.parseGreater())
-    return Type();
-  return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1],
-                                   matrixLayout);
-}
-
 // TODO: Reorder methods to be utilities first and parse*Type
 // methods in alphabetical order
 //
@@ -781,8 +746,6 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
     return parseArrayType(*this, parser);
   if (keyword == "coopmatrix")
     return parseCooperativeMatrixType(*this, parser);
-  if (keyword == "jointmatrix")
-    return parseJointMatrixType(*this, parser);
   if (keyword == "image")
     return parseImageType(*this, parser);
   if (keyword == "ptr")
@@ -886,13 +849,6 @@ static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
      << type.getUse() << ">";
 }
 
-static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
-  os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
-  os << type.getElementType() << ", "
-     << stringifyMatrixLayout(type.getMatrixLayout());
-  os << ", " << stringifyScope(type.getScope()) << ">";
-}
-
 static void print(MatrixType type, DialectAsmPrinter &os) {
   os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
   os << ">";
@@ -900,9 +856,9 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
 
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
-      .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, PointerType,
-            RuntimeArrayType, ImageType, SampledImageType, StructType,
-            MatrixType>([&](auto type) { print(type, os); })
+      .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
+            ImageType, SampledImageType, StructType, MatrixType>(
+          [&](auto type) { print(type, os); })
       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index d9bc05acddc82b..c8386fecea038a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -373,7 +373,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
 
   auto coopElementType =
       llvm::TypeSwitch<Type, Type>(getType())
-          .Case<spirv::CooperativeMatrixType, spirv::JointMatrixINTELType>(
+          .Case<spirv::CooperativeMatrixType>(
               [](auto coopType) { return coopType.getElementType(); })
           .Default([](Type) { return nullptr; });
 
@@ -1834,8 +1834,6 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
 
   if (llvm::isa<spirv::CooperativeMatrixType>(cType))
     return emitError("unsupported composite type  ") << cType;
-  if (llvm::isa<spirv::JointMatrixINTELType>(cType))
-    return emitError("unsupported composite type  ") << cType;
   if (constituents.size() != cType.getNumElements())
     return emitError("has incorrect number of operands: expected ")
            << cType.getNumElements() << ", but provided "
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 3808620bdffa6d..cd531882cafb08 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -95,8 +95,8 @@ bool CompositeType::classof(Type type) {
   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
     return isValid(vectorType);
   return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
-                   spirv::JointMatrixINTELType, spirv::MatrixType,
-                   spirv::RuntimeArrayType, spirv::StructType>(type);
+                   spirv::MatrixType, spirv::RuntimeArrayType,
+                   spirv::StructType>(type);
 }
 
 bool CompositeType::isValid(VectorType type) {
@@ -107,8 +107,7 @@ bool CompositeType::isValid(VectorType type) {
 
 Type CompositeType::getElementType(unsigned index) const {
   return TypeSwitch<Type, Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType,
-            RuntimeArrayType, VectorType>(
+      .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
           [](auto type) { return type.getElementType(); })
       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
       .Case<StructType>(
@@ -130,10 +129,6 @@ unsigned CompositeType::getNumElements() const {
     llvm_unreachable(
         "invalid to query number of elements of spirv Cooperative Matrix type");
   }
-  if (llvm::isa<JointMatrixINTELType>(*this)) {
-    llvm_unreachable(
-        "invalid to query number of elements of spirv::JointMatrix type");
-  }
   if (llvm::isa<RuntimeArrayType>(*this)) {
     llvm_unreachable(
         "invalid to query number of elements of spirv::RuntimeArray type");
@@ -142,16 +137,15 @@ unsigned CompositeType::getNumElements() const {
 }
 
 bool CompositeType::hasCompileTimeKnownNumElements() const {
-  return !llvm::isa<CooperativeMatrixType, JointMatrixINTELType,
-                    RuntimeArrayType>(*this);
+  return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
 }
 
 void CompositeType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
   TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
-            RuntimeArrayType, StructType>(
+      .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
+            StructType>(
           [&](auto type) { type.getExtensions(extensions, storage); })
       .Case<VectorType>([&](VectorType type) {
         return llvm::cast<ScalarType>(type.getElementType())
@@ -164,8 +158,8 @@ void CompositeType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     std::optional<StorageClass> storage) {
   TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, JointMatrixINTELType, MatrixType,
-            RuntimeArrayType, StructType>(
+      .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
+            StructType>(
           [&](auto type) { type.getCapabilities(capabilities, storage); })
       .Case<VectorType>([&](VectorType type) {
         auto vecSize = getNumElements();
@@ -266,75 +260,6 @@ void CooperativeMatrixType::getCapabilities(
   capabilities.push_back(caps);
 }
 
-//===----------------------------------------------------------------------===//
-// JointMatrixType
-//===----------------------------------------------------------------------===//
-
-struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
-  using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
-
-  static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator,
-                                           const KeyTy &key) {
-    return new (allocator.allocate<JointMatrixTypeStorage>())
-        JointMatrixTypeStorage(key);
-  }
-
-  bool operator==(const KeyTy &key) const {
-    return key == KeyTy(elementType, rows, columns, matrixLayout, scope);
-  }
-
-  JointMatrixTypeStorage(const KeyTy &key)
-      : elementType(std::get<0>(key)), rows(std::get<1>(key)),
-        columns(std::get<2>(key)), scope(std::get<4>(key)),
-        matrixLayout(std::get<3>(key)) {}
-
-  Type elementType;
-  unsigned rows;
-  unsigned columns;
-  Scope scope;
-  MatrixLayout matrixLayout;
-};
-
-JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope,
-                                               unsigned rows, unsigned columns,
-                                               MatrixLayout matrixLayout) {
-  return Base::get(elementType.getContext(), elementType, rows, columns,
-                   matrixLayout, scope);
-}
-
-Type JointMatrixINTELType::getElementType() const {
-  return getImpl()->elementType;
-}
-
-Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; }
-
-unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; }
-
-unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; }
-
-MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
-  return getImpl()->matrixLayout;
-}
-
-void JointMatrixINTELType::getExtensions(
-    SPIRVType::ExtensionArrayRefVector &extensions,
-    std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
-  static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
-  ArrayRef<Extension> ref(exts, std::size(exts));
-  extensions.push_back(ref);
-}
-
-void JointMatrixINTELType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType())
-      .getCapabilities(capabilities, storage);
-  static const Capability caps[] = {Capability::JointMatrixINTEL};
-  ArrayRef<Capability> ref(caps, std::size(caps));
-  capabilities.push_back(ref);
-}
-
 //===----------------------------------------------------------------------===//
 // ImageType
 //===----------------------------------------------------------------------===//
@@ -1247,7 +1172,6 @@ void MatrixType::getCapabilities(
 //===----------------------------------------------------------------------===//
 
 void SPIRVDialect::registerTypes() {
-  addTypes<ArrayType, CooperativeMatrixType, ImageType, JointMatrixINTELType,
-           MatrixType, PointerType, RuntimeArrayType, SampledImageType,
-           StructType>();
+  addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
+           RuntimeArrayType, SampledImageType, StructType>();
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 5b2903824c9e76..b30da773d48967 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -168,8 +168,6 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processType(opcode, operands);
   case spirv::Opcode::OpTypeForwardPointer:
     return processTypeForwardPointer(operands);
-  case spirv::Opcode::OpTypeJointMatrixINTEL:
-    return processType(opcode, operands);
   case spirv::Opcode::OpConstant:
     return processConstant(operands, /*isSpec=*/false);
   case spirv::Opcode::OpSpecConstant:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 12980879b20ab7..38293f7106a05a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -852,8 +852,6 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     return processCooperativeMatrixTypeKHR(operands);
   case spirv::Opcode::OpTypeFunction:
     return processFunctionType(operands);
-  case spirv::Opcode::OpTypeJointMatrixINTEL:
-    return processJointMatrixType(operands);
   case spirv::Opcode::OpTypeImage:
     return processImageType(operands);
   case spirv::Opcode::OpTypeSampledImage:
@@ -1025,40 +1023,6 @@ LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
   return success();
 }
 
-LogicalResult
-spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
-  if (operands.size() != 6) {
-    return emitError(unknownLoc, "OpTypeJointMatrix must have element "
-                                 "type and row x column parameters");
-  }
-
-  Type elementTy = getType(operands[1]);
-  if (!elementTy) {
-    return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
-           << operands[1];
-  }
-
-  auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
-  if (!scope) {
-    return emitError(unknownLoc,
-                     "OpTypeJointMatrix references undefined scope <id> ")
-           << operands[5];
-  }
-  auto matrixLayout =
-      spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
-  if (!matrixLayout) {
-    return emitError(unknownLoc,
-                     "OpTypeJointMatrix references undefined scope <id> ")
-           << operands[4];
-  }
-  unsigned rows = getConstantInt(operands[2]).getInt();
-  unsigned columns = getConstantInt(operands[3]).getInt();
-
-  typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
-      elementTy, scope.value(), rows, columns, matrixLayout.value());
-  return success();
-}
-
 LogicalResult
 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index fc9a8f5f9364b2..264d580c40f097 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -273,8 +273,6 @@ class Deserializer {
 
   LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
 
-  LogicalResult processJointMatrixType(ArrayRef<uint32_t> operands);
-
   LogicalResult processImageType(ArrayRef<uint32_t> operands);
 
   LogicalResult processSampledImageType(ArrayRef<uint32_t> operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 714a3edfb56573..b0feda0517caa6 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -662,25 +662,6 @@ LogicalResult Serializer::prepareBasicType(
     return success();
   }
 
-  if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
-    uint32_t elementTypeID = 0;
-    if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
-                               elementTypeID, serializationCtx))) {
-      return failure();
-    }
-    typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
-    auto getConstantOp = [&](uint32_t id) {
-      auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
-      return prepareConstantInt(loc, attr);
-    };
-    llvm::append_values(
-        operands, elementTypeID, getConstantOp(jointMatrixType.getRows()),
-        getConstantOp(jointMatrixType.getColumns()),
-        getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixLayout())),
-        getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
-    return success();
-  }
-
   if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
     uint32_t elementTypeID = 0;
     if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
diff --git a/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
deleted file mode 100644
index afb856d9b13cd1..00000000000000
--- a/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
+++ /dev/null
@@ -1,99 +0,0 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
-
-// CHECK-LABEL: @joint_matrix_load
-spirv.func @joint_matrix_load(%ptr : !spirv.ptr<i32, Workgroup>, %stride : i32) "None" {
-  // CHECK: {{%.*}} = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>
-  %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>
-  spirv.Return
-}
-
-// -----
-// CHECK-LABEL: @joint_matrix_load_memaccess
-spirv.func @joint_matrix_load_memaccess(%ptr : !spirv.ptr<i32, CrossWorkgroup>, %stride : i32) "None" {
-  // CHECK: {{%.*}} = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, CrossWorkgroup>, i32) -> !spirv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
-  %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, CrossWorkgroup>, i32) -> !spirv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_load_diff_ptr_type
-spirv.func @joint_matrix_load_diff_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, Workgroup>, %stride : i32) "None" {
-  // CHECK: {{%.*}} = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<vector<4xi32>, Workgroup>, i32) -> !spirv.jointmatrix<8x16xi32, RowMajor, Workgroup>
-  %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<vector<4xi32>, Workgroup>, i32) -> !spirv.jointmatrix<8x16xi32, RowMajor, Workgroup>
-  spirv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_store
-spirv.func @joint_matrix_store(%ptr : !spirv.ptr<i32, Workgroup>, %stride : i32, %m : !spirv.jointmatrix<8x16xi32, RowMajor, Workgroup>) "None" {
-  // CHECK: spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
-  spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> %ptr, %m, %stride : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
-  spirv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_store_memaccess
-spirv.func @joint_matrix_store_memaccess(%ptr : !spirv.ptr<i32, Workgroup>, %m : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
-  // CHECK: spirv.INTEL.JointMatrixStore <Subgroup> <ColumnMajor> {{%.*}}, {{%.*}}, {{%.*}} {Volatile} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
-  spirv.INTEL.JointMatrixStore <Subgroup> <ColumnMajor> %ptr, %m, %stride {Volatile} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
-  spirv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_length
-spirv.func @joint_matrix_length() -> i32 "None" {
-  // CHECK: {{%.*}} = spirv.INTEL.JointMatrixWorkItemLength : !spirv.jointmatrix<8x16xi32, PackedB, Subgroup>
-  %0 = spirv.INTEL.JointMatrixWorkItemLength : !spirv.jointmatrix<8x16xi32, PackedB, Subgroup>
-  spirv.ReturnValue %0 : i32
-}
-
-// CHECK-LABEL: @joint_matrix_muladd
-spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<8x32xi8, RowMajor, Subgroup>, %b : !spirv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.INTEL.JointMatrixMad <Subgroup> {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spirv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
-  %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spirv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32,  RowMajor, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<16x16xi32, RowMajor, Subgroup>, %b : !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.INTEL.JointMatrixMad' op matrix size must match}}
-  %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<16x16xi32, RowMajor, Subgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.INTEL.JointMatrixMad' op matrix size must match}}
-  %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.INTEL.JointMatrixMad' op matrix scope must match}}
-  %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
-  // expected-error @+1 {{matrix element type must match}}
-  %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<8x16xf32, RowMajor, Subgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, Workgroup>, %stride : i32) "None" {
-  // expected-error @+1 {{Pointer must point to a scalar or vector type}}
-  %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spirv.ptr<!spirv.struct<(f32 [0])>, Workgroup>, i32)-> !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @joint_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>, %stride : i32) "None" {
-  // expected-error @+1 {{Pointer storage class must be Workgroup or CrossWorkgroup}}
-  %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spirv.ptr<i32, Function>, i32) -> !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  spirv.Return
-}
diff --git a/mlir/test/Target/SPIRV/joint-matrix-ops.mlir b/mlir/test/Target/SPIRV/joint-matrix-ops.mlir
deleted file mode 100644
index a89921c5d0d361..00000000000000
--- a/mlir/test/Target/SPIRV/joint-matrix-ops.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [JointMatrixINTEL], [SPV_INTEL_joint_matrix]> {
-  // CHECK-LABEL: @joint_matrix_load
-  spirv.func @joint_matrix_load(%ptr : !spirv.ptr<i32, Workgroup>, %stride : i32) "None" {
-    // CHECK: {{%.*}} = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>
-    %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_load_memaccess
-  spirv.func @joint_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Workgroup>, %stride : i32) "None" {
-    // CHECK: {{%.*}} = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    %0 = spirv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, Workgroup>, i32) -> !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_store
-  spirv.func @joint_matrix_store(%ptr : !spirv.ptr<i32, Workgroup>, %stride : i32, %m : !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>) "None" {
-    // CHECK: spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
-    spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> %ptr, %m, %stride : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_store_memaccess
-  spirv.func @joint_matrix_store_memaccess(%ptr : !spirv.ptr<i32, Workgroup>, %m : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
-    // CHECK: spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
-    spirv.INTEL.JointMatrixStore <Subgroup> <RowMajor> %ptr, %m, %stride {memory_access = #spirv.memory_access<Volatile>} : (!spirv.ptr<i32, Workgroup>, !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
-    spirv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_length
-  spirv.func @joint_matrix_length() -> i32 "None" {
-    // CHECK: {{%.*}} = spirv.INTEL.JointMatrixWorkItemLength : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    %0 = spirv.INTEL.JointMatrixWorkItemLength : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    spirv.ReturnValue %0 : i32
-  }
-
-  // CHECK-LABEL: @joint_matrix_muladd
-  spirv.func @joint_matrix_muladd(%a : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spirv.INTEL.JointMatrixMad <Subgroup> {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
-    %r = spirv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spirv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spirv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spirv.jointmatrix<8x8xi32, RowMajor, Subgroup>
-    spirv.Return
-  }
-}



More information about the Mlir-commits mailing list