[Mlir-commits] [mlir] b8f62dc - [MLIR][SPIRV] Add intel joint matrix ops

Nirvedh Meshram llvmlistbot at llvm.org
Mon Aug 15 16:51:58 PDT 2022


Author: Nirvedh Meshram
Date: 2022-08-15T23:49:45Z
New Revision: b8f62dc22a9333bc016b19921a9dac6be6cb947c

URL: https://github.com/llvm/llvm-project/commit/b8f62dc22a9333bc016b19921a9dac6be6cb947c
DIFF: https://github.com/llvm/llvm-project/commit/b8f62dc22a9333bc016b19921a9dac6be6cb947c.diff

LOG: [MLIR][SPIRV] Add intel joint matrix ops

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D131586

Added: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
    mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
    mlir/test/Target/SPIRV/joint-matrix-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
    mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 7ce34dcd5dedb..f5bf65bf6d0f7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -27,12 +27,12 @@ class SPV_ArithmeticBinaryOp<string mnemonic, Type type,
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
-    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
-    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
+    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand1,
+    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand2
   );
 
   let results = (outs
-    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
+    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$result
   );
   let assemblyFormat = "operands attr-dict `:` type($result)";
 }

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index 791615274d52f..02742278fe57f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -64,6 +64,27 @@ def SPV_CooperativeMatrixPropertiesNVArrayAttr :
     TypedArrayAttrBase<SPV_CooperativeMatrixPropertiesNVAttr,
                        "CooperativeMatrixPropertiesNV array attribute">;
 
+// Describes one supported joint matrix configuration (M/N/K presumably as in
+// https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc
+def SPV_JointMatrixPropertiesINTELAttr :
+    SPV_Attr<"JointMatrixPropertiesINTEL", "joint_matrix_props"> {
+  let parameters = (ins
+    "int":$m_size,                    // M: rows of A and C (A*B+C convention)
+    "int":$n_size,                    // N: columns of B and C
+    "int":$k_size,                    // K: columns of A / rows of B
+    "mlir::Type":$a_type,             // type associated with matrix A
+    "mlir::Type":$b_type,             // type associated with matrix B
+    "mlir::Type":$c_type,             // type associated with matrix C
+    "mlir::Type":$result_type,        // type associated with the result
+    "mlir::spirv::ScopeAttr":$scope   // scope of the matrix operation
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def SPV_JointMatrixPropertiesINTELArrayAttr :
+    TypedArrayAttrBase<SPV_JointMatrixPropertiesINTELAttr,
+                       "JointMatrixPropertiesINTEL array attribute">;
+
 // This attribute specifies the limits for various resources on the target
 // architecture.
 //

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 0cfb3ead3642c..84601dd3eae55 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -387,6 +387,7 @@ def SPV_INTEL_debug_module                       : I32EnumAttrCase<"SPV_INTEL_de
 def SPV_INTEL_fp_fast_math_mode                  : I32EnumAttrCase<"SPV_INTEL_fp_fast_math_mode", 4027>;
 def SPV_INTEL_memory_access_aliasing             : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
 def SPV_INTEL_split_barrier                      : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
+def SPV_INTEL_joint_matrix                       : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>;
 
 def SPV_NV_compute_shader_derivatives    : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
 def SPV_NV_cooperative_matrix            : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -443,7 +444,7 @@ def SPV_ExtensionAttr :
       SPV_INTEL_usm_storage_classes, SPV_INTEL_io_pipes, SPV_INTEL_blocking_pipes,
       SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
       SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
-      SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
+      SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix,
       SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
       SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
       SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@@ -1390,6 +1391,12 @@ def SPV_C_ShaderStereoViewNV                          : I32EnumAttrCase<"ShaderS
   ];
 }
 
+def SPV_C_JointMatrixINTEL                         : I32EnumAttrCase<"JointMatrixINTEL", 6118> {
+  list<Availability> availability = [
+    Extension<[SPV_INTEL_joint_matrix]>
+  ];
+}
+
 def SPV_CapabilityAttr :
     SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
       SPV_C_Matrix, SPV_C_Addresses, SPV_C_Linkage, SPV_C_Kernel, SPV_C_Float16,
@@ -1481,7 +1488,7 @@ def SPV_CapabilityAttr :
       SPV_C_UniformTexelBufferArrayNonUniformIndexing,
       SPV_C_StorageTexelBufferArrayNonUniformIndexing,
       SPV_C_ShaderViewportIndexLayerEXT, SPV_C_ShaderViewportMaskNV,
-      SPV_C_ShaderStereoViewNV
+      SPV_C_ShaderStereoViewNV, SPV_C_JointMatrixINTEL
     ]>;
 
 def SPV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
@@ -3981,6 +3988,16 @@ def SPV_SamplerUseAttr: SPV_I32EnumAttr<
   "image_sampler_use_info",
   [SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>;
 
+def SPV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 0>;
+def SPV_ML_RowMajor    : I32EnumAttrCase<"RowMajor", 1>;
+def SPV_ML_PackedA     : I32EnumAttrCase<"PackedA", 2>;
+def SPV_ML_PackedB     : I32EnumAttrCase<"PackedB", 3>;
+// Layouts carried by !spv.jointmatrix types and the joint matrix load/store ops.
+def SPV_MatrixLayoutAttr  :
+    SPV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [
+      SPV_ML_ColumnMajor, SPV_ML_RowMajor, SPV_ML_PackedA, SPV_ML_PackedB
+    ]>;
+
 //===----------------------------------------------------------------------===//
 // SPIR-V attribute definitions
 //===----------------------------------------------------------------------===//
@@ -4013,6 +4030,8 @@ def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
 def SPV_IsCooperativeMatrixType :
   CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
 def SPV_IsImageType : CPred<"$_self.isa<::mlir::spirv::ImageType>()">;
+def SPV_IsJointMatrixType :
+  CPred<"$_self.isa<::mlir::spirv::JointMatrixINTELType>()">;
 def SPV_IsMatrixType : CPred<"$_self.isa<::mlir::spirv::MatrixType>()">;
 def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
 def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
@@ -4043,6 +4062,8 @@ def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
                                "any SPIR-V cooperative matrix type">;
 def SPV_AnyImage : DialectType<SPIRV_Dialect, SPV_IsImageType,
                                 "any SPIR-V image type">;
+def SPV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPV_IsJointMatrixType,
+                                "any SPIR-V joint matrix type">;
 def SPV_AnyMatrix : DialectType<SPIRV_Dialect, SPV_IsMatrixType,
                                 "any SPIR-V matrix type">;
 def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
@@ -4057,11 +4078,12 @@ def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
 def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
 def SPV_Composite :
     AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
-               SPV_AnyCooperativeMatrix, SPV_AnyMatrix]>;
+               SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix]>;
 def SPV_Type : AnyTypeOf<[
     SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
     SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
-    SPV_AnyCooperativeMatrix, SPV_AnyMatrix, SPV_AnySampledImage
+    SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix,
+    SPV_AnySampledImage
   ]>;
 
 def SPV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4072,6 +4094,11 @@ class SPV_CoopMatrixOfType<list<Type> allowedTypes> :
     "$_self.cast<::mlir::spirv::CooperativeMatrixNVType>().getElementType()",
     "Cooperative Matrix">;
 
+class SPV_JointMatrixOfType<list<Type> allowedTypes> :
+  ContainerType<AnyTypeOf<allowedTypes>, SPV_IsJointMatrixType,
+    "$_self.cast<::mlir::spirv::JointMatrixINTELType>().getElementType()",
+    "Joint Matrix">;  // Joint matrix whose element type is in allowedTypes.
+
 class SPV_ScalarOrVectorOf<Type type> :
     AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
 
@@ -4079,6 +4106,14 @@ class SPV_ScalarOrVectorOrCoopMatrixOf<Type type> :
     AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
                SPV_CoopMatrixOfType<[type]>]>;
 
+class SPV_ScalarOrVectorOrJointMatrixOf<Type type> :
+    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+               SPV_JointMatrixOfType<[type]>]>;
+// Scalar, vector, cooperative matrix, or joint matrix whose element is `type`.
+class SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<Type type> :
+    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+               SPV_CoopMatrixOfType<[type]>, SPV_JointMatrixOfType<[type]> ]>;
+
 def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
 def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
 
@@ -4311,6 +4346,11 @@ def SPV_OC_OpSubgroupBlockReadINTEL    : I32EnumAttrCase<"OpSubgroupBlockReadINT
 def SPV_OC_OpSubgroupBlockWriteINTEL   : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
 def SPV_OC_OpAssumeTrueKHR             : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
 def SPV_OC_OpAtomicFAddEXT             : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>;
+def SPV_OC_OpTypeJointMatrixINTEL      : I32EnumAttrCase<"OpTypeJointMatrixINTEL", 6119>;
+def SPV_OC_OpJointMatrixLoadINTEL      : I32EnumAttrCase<"OpJointMatrixLoadINTEL", 6120>;
+def SPV_OC_OpJointMatrixStoreINTEL     : I32EnumAttrCase<"OpJointMatrixStoreINTEL", 6121>;
+def SPV_OC_OpJointMatrixMadINTEL       : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>;
+def SPV_OC_OpJointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>;
 
 def SPV_OpcodeAttr :
     SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4376,7 +4416,10 @@ def SPV_OpcodeAttr :
       SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
       SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV,
       SPV_OC_OpSubgroupBlockReadINTEL, SPV_OC_OpSubgroupBlockWriteINTEL,
-      SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT
+      SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT,
+      SPV_OC_OpTypeJointMatrixINTEL, SPV_OC_OpJointMatrixLoadINTEL,
+      SPV_OC_OpJointMatrixStoreINTEL, SPV_OC_OpJointMatrixMadINTEL,
+      SPV_OC_OpJointMatrixWorkItemLengthINTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index 5048dd10ae575..27bad3f08e083 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -23,11 +23,11 @@ class SPV_CastOp<string mnemonic, Type resultType, Type operandType,
              !listconcat(traits,
                          [NoSideEffect, SameOperandsAndResultShape])> {
   let arguments = (ins
-    SPV_ScalarOrVectorOrCoopMatrixOf<operandType>:$operand
+    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<operandType>:$operand
   );
 
   let results = (outs
-    SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
+    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<resultType>:$result
   );
   let assemblyFormat = [{
     $operand attr-dict `:` type($operand) `to` type($result)

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
new file mode 100644
index 0000000000000..aa45ef80b5e94
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
@@ -0,0 +1,248 @@
+//===- SPIRVJointMatrixOps.td - joint matmul ---------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of joint matrix multiply extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
+#define MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
+
+// -----
+
+def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL",
+  [NoSideEffect]> {
+  let summary = "See extension SPV_INTEL_joint_matrix";
+
+  let description = [{
+    Return the number of components owned by the current work-item in
+    a joint matrix.
+
+    Result Type must be a 32-bit unsigned integer type scalar.
+
+    Type is a joint matrix type.
+
+    ``` {.ebnf}
+    joint-matrix-length-op ::= ssa-id `=` `spv.JointMatrixWorkItemLengthINTEL`
+                               `:` joint-matrix-type
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    ```
+  }];
+
+  let assemblyFormat = "attr-dict `:` $type";
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_6>,
+    Extension<[SPV_INTEL_joint_matrix]>,
+    Capability<[SPV_C_JointMatrixINTEL]>
+  ];
+
+  let arguments = (ins
+    TypeAttr:$type
+  );
+
+  let results = (outs
+    SPV_Int32:$result
+  );
+  let hasVerifier = 0;
+}
+
+// -----
+
+def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
+  let summary = "See extension SPV_INTEL_joint_matrix";
+
+  let description = [{
+    Load a matrix through a pointer.
+
+    Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
+
+    Pointer is the pointer to load through. It specifies the start of the memory
+    region where elements of the matrix are stored and arranged according to Layout.
+
+    Stride is the number of elements in memory between beginnings of successive rows,
+    columns (or words) in the result. It must be a scalar integer type.
+
+    Layout indicates how the values loaded from memory are arranged. It must be the
+    result of a constant instruction.
+
+    Scope is the synchronization scope for operations on the matrix. It must be the
+    result of a constant instruction with scalar integer type.
+
+    If present, any Memory Operands must begin with a memory operand literal. If not
+    present, it is the same as specifying the memory operand None.
+
+    #### Example:
+    ```mlir
+    %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
+         {memory_access = #spv.memory_access<Volatile>} :
+         (!spv.ptr<i32, CrossWorkgroup>, i32) ->
+         !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $scope $layout operands attr-dict `:` `(` type(operands) `)` `->` type($result)
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_6>,
+    Extension<[SPV_INTEL_joint_matrix]>,
+    Capability<[SPV_C_JointMatrixINTEL]>
+  ];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$scope,
+    SPV_MatrixLayoutAttr:$layout,
+    SPV_AnyPtr:$pointer,
+    SPV_Integer:$stride,
+    OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
+    OptionalAttr<I32Attr>:$alignment
+  );
+
+  let results = (outs
+    SPV_AnyJointMatrix:$result
+  );
+}
+
+// -----
+
+def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
+  [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
+  let summary = "See extension SPV_INTEL_joint_matrix";
+
+  let description = [{
+    Multiply matrix A by matrix B and add matrix C to the result
+    of the multiplication: A*B+C. Here A is an M x K matrix, B is
+    a K x N matrix and C is an M x N matrix.
+
+    Behavior is undefined if sizes of operands do not meet the
+    conditions above. All operands and the Result Type must be
+    OpTypeJointMatrixINTEL.
+
+    A must be an OpTypeJointMatrixINTEL whose Component Type is a
+    signed numerical type, Row Count equal to M and Column Count
+    equal to K.
+
+    B must be an OpTypeJointMatrixINTEL whose Component Type is a
+    signed numerical type, Row Count equal to K and Column Count
+    equal to N.
+
+    C and Result Type must be an OpTypeJointMatrixINTEL with Row
+    Count equal to M and Column Count equal to N.
+
+    Scope is the synchronization scope for operations on the matrix.
+    It must be the result of a constant instruction with scalar
+    integer type.
+
+    #### Example:
+    ```mlir
+    %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
+         !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
+         !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
+         -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+    ```
+
+  }];
+
+  let assemblyFormat = [{
+    $scope operands attr-dict `:` type($a) `,` type($b) `->` type($c)
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_6>,
+    Extension<[SPV_INTEL_joint_matrix]>,
+    Capability<[SPV_C_JointMatrixINTEL]>
+  ];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$scope,
+    SPV_AnyJointMatrix:$a,
+    SPV_AnyJointMatrix:$b,
+    SPV_AnyJointMatrix:$c
+  );
+
+  let results = (outs
+    SPV_AnyJointMatrix:$result
+  );
+}
+
+// -----
+
+def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> {
+  let summary = "See extension SPV_INTEL_joint_matrix";
+
+  let description = [{
+    Store a matrix through a pointer.
+
+    Pointer is the pointer to store through. It specifies
+    the start of the memory region where elements of the matrix must
+    be stored and arranged according to Layout.
+
+    Object is the matrix to store. It must be
+    OpTypeJointMatrixINTEL.
+
+    Stride is the number of elements in memory between beginnings
+    of successive rows, columns (or words) of the Object. It must
+    be a scalar integer type.
+
+    Layout indicates how the values stored to memory are arranged.
+    It must be the result of a constant instruction.
+
+    Scope is the synchronization scope for operations on the matrix.
+    It must be the result of a constant instruction with scalar
+    integer type.
+
+    If present, any Memory Operands must begin with a memory operand
+    literal. If not present, it is the same as specifying the memory
+    operand None.
+
+    #### Example:
+    ```mlir
+    spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
+         {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
+         !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
+    ```
+
+  }];
+
+  let assemblyFormat = [{
+    $scope $layout operands attr-dict `:` `(` type(operands) `)`
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_6>,
+    Extension<[SPV_INTEL_joint_matrix]>,
+    Capability<[SPV_C_JointMatrixINTEL]>
+  ];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$scope,
+    SPV_MatrixLayoutAttr:$layout,
+    SPV_AnyPtr:$pointer,
+    SPV_AnyJointMatrix:$object,
+    SPV_Integer:$stride,
+    OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
+    OptionalAttr<I32Attr>:$alignment
+  );
+
+  let results = (outs);
+}
+
+// -----
+
+#endif // MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index aa87f0e142e80..5e8e5e4c7ce92 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -30,6 +30,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCastOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 40a4acff751b1..d9737e4c2c579 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,6 +29,7 @@ namespace detail {
 struct ArrayTypeStorage;
 struct CooperativeMatrixTypeStorage;
 struct ImageTypeStorage;
+struct JointMatrixTypeStorage;
 struct MatrixTypeStorage;
 struct PointerTypeStorage;
 struct RuntimeArrayTypeStorage;
@@ -420,6 +421,33 @@ class CooperativeMatrixNVType
                        Optional<StorageClass> storage = llvm::None);
 };
 
+// SPIR-V joint matrix type (`!spv.jointmatrix`, from SPV_INTEL_joint_matrix).
+class JointMatrixINTELType
+    : public Type::TypeBase<JointMatrixINTELType, CompositeType,
+                            detail::JointMatrixTypeStorage> {
+public:
+  using Base::Base;
+  /// Get the joint matrix type with the given element type, scope, shape, layout.
+  static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
+                                  unsigned columns, MatrixLayout matrixLayout);
+  Type getElementType() const; ///< Return the element type of the matrix.
+
+  /// Return the scope of the joint matrix.
+  Scope getScope() const;
+  /// Return the number of rows of the matrix.
+  unsigned getRows() const;
+  /// Return the number of columns of the matrix.
+  unsigned getColumns() const;
+
+  /// Return the layout of the matrix.
+  MatrixLayout getMatrixLayout() const;
+  /// Query required SPIR-V extensions/capabilities (presumably appends; verify).
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     Optional<StorageClass> storage = llvm::None);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       Optional<StorageClass> storage = llvm::None);
+};
+
 // SPIR-V matrix type
 class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
                                          detail::MatrixTypeStorage> {

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 88bae6bb43263..d5449a2c09b6c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -348,6 +348,39 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
   return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
 }
 
+// joint-matrix-type ::= `!spv.jointmatrix` `<` rows `x` columns `x` element-type
+//                       `,` layout `,` scope `>`   (called after the keyword)
+static Type parseJointMatrixType(SPIRVDialect const &dialect,
+                                 DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return Type();
+
+  SmallVector<int64_t, 2> dims;
+  SMLoc countLoc = parser.getCurrentLocation();
+  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
+    return Type();
+
+  if (dims.size() != 2) {
+    parser.emitError(countLoc, "expected rows and columns size");
+    return Type();
+  }
+
+  auto elementTy = parseAndVerifyType(dialect, parser);
+  if (!elementTy)
+    return Type();
+  MatrixLayout matrixLayout;
+  if (parser.parseComma() ||
+      parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
+    return Type();
+  Scope scope;
+  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+    return Type();
+  if (parser.parseGreater())
+    return Type();
+  return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1],
+                                   matrixLayout); // get() takes scope first.
+}
+
 // TODO: Reorder methods to be utilities first and parse*Type
 // methods in alphabetical order
 //
@@ -753,6 +786,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
     return parseArrayType(*this, parser);
   if (keyword == "coopmatrix")
     return parseCooperativeMatrixType(*this, parser);
+  if (keyword == "jointmatrix")
+    return parseJointMatrixType(*this, parser);
   if (keyword == "image")
     return parseImageType(*this, parser);
   if (keyword == "ptr")
@@ -859,6 +894,13 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
   os << ">";
 }
 
+static void print(JointMatrixINTELType type, DialectAsmPrinter &os) {
+  os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
+     << type.getElementType() << ", "
+     << stringifyMatrixLayout(type.getMatrixLayout()) << ", "
+     << stringifyScope(type.getScope()) << ">";
+}
+
 static void print(MatrixType type, DialectAsmPrinter &os) {
   os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
   os << ">";
@@ -866,9 +908,9 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
 
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
-      .Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
-            ImageType, SampledImageType, StructType, MatrixType>(
-          [&](auto type) { print(type, os); })
+      .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
+            PointerType, RuntimeArrayType, ImageType, SampledImageType,
+            StructType, MatrixType>([&](auto type) { print(type, os); })
       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
 }
 

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index bf19ce83874e5..8011ddc47f5ff 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -436,6 +436,13 @@ static LogicalResult verifyCastOp(Operation *op,
         resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
   }
 
+  if (auto jointMatrixType =
+          operandType.dyn_cast<spirv::JointMatrixINTELType>()) {
+    operandType = jointMatrixType.getElementType();
+    resultType =
+        resultType.cast<spirv::JointMatrixINTELType>().getElementType();
+  }
+
   auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
   auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
   auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
@@ -1637,6 +1644,17 @@ LogicalResult spirv::CompositeConstructOp::verify() {
     return success();
   }
 
+  if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) {
+    if (constituents.size() != 1)
+      return emitOpError("has incorrect number of operands: expected ")
+             << "1, but provided " << constituents.size();
+    if (jointType.getElementType() != constituents.front().getType())
+      return emitOpError("operand type mismatch: expected operand type ")
+             << jointType.getElementType() << ", but provided "
+             << constituents.front().getType();
+    return success();
+  }
+
   if (constituents.size() == cType.getNumElements()) {
     for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
       if (constituents[index].getType() != cType.getElementType(index)) {
@@ -3893,6 +3911,70 @@ LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
   return verifyCoopMatrixMulAdd(*this);
 }
 
+static LogicalResult
+verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
+  // Only the pointer is inspected here; `jointMatrix` is currently unused.
+  auto ptrType = pointer.cast<spirv::PointerType>();
+  if (!ptrType.getPointeeType().isa<spirv::ScalarType, VectorType>())
+    return op->emitError(
+               "Pointer must point to a scalar or vector type but provided ")
+           << ptrType.getPointeeType();
+  spirv::StorageClass storage = ptrType.getStorageClass();
+  if (storage != spirv::StorageClass::Workgroup &&
+      storage != spirv::StorageClass::CrossWorkgroup)
+    return op->emitError("Pointer storage class must be Workgroup or "
+                         "CrossWorkgroup but provided ")
+           << stringifyStorageClass(storage);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.JointMatrixLoadINTEL
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
+  Type ptrTy = pointer().getType(), matrixTy = result().getType();
+  return verifyPointerAndJointMatrixType(*this, ptrTy, matrixTy);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.JointMatrixStoreINTEL
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
+  Type ptrTy = pointer().getType(), matrixTy = object().getType();
+  return verifyPointerAndJointMatrixType(*this, ptrTy, matrixTy);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.JointMatrixMadINTEL
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
+  if (op.result().getType() != op.c().getType())
+    return op.emitOpError("result and third operand must have the same type");
+  auto aTy = op.a().getType().cast<spirv::JointMatrixINTELType>();
+  auto bTy = op.b().getType().cast<spirv::JointMatrixINTELType>();
+  auto cTy = op.c().getType().cast<spirv::JointMatrixINTELType>();
+  auto resTy = op.result().getType().cast<spirv::JointMatrixINTELType>();
+  if (aTy.getRows() != resTy.getRows() || aTy.getColumns() != bTy.getRows() ||
+      bTy.getColumns() != resTy.getColumns())
+    return op.emitOpError("matrix size must match");
+  auto scope = resTy.getScope();
+  if (aTy.getScope() != scope || bTy.getScope() != scope ||
+      cTy.getScope() != scope)
+    return op.emitOpError("matrix scope must match");
+  // A and B must share an element type, as must C and the result.
+  if (aTy.getElementType() != bTy.getElementType() ||
+      resTy.getElementType() != cTy.getElementType())
+    return op.emitOpError("matrix element type must match");
+  return success();
+}
+
+LogicalResult spirv::JointMatrixMadINTELOp::verify() {
+  return verifyJointMatrixMad(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // spv.MatrixTimesScalar
 //===----------------------------------------------------------------------===//
@@ -4150,6 +4232,8 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
 
   if (cType.isa<spirv::CooperativeMatrixNVType>())
     return emitError("unsupported composite type  ") << cType;
+  if (cType.isa<spirv::JointMatrixINTELType>())
+    return emitError("unsupported composite type  ") << cType;
   if (constituents.size() != cType.getNumElements())
     return emitError("has incorrect number of operands: expected ")
            << cType.getNumElements() << ", but provided "

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 03a45b8c3884a..a4c622b8d8199 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -89,9 +89,9 @@ Optional<int64_t> ArrayType::getSizeInBytes() {
 bool CompositeType::classof(Type type) {
   if (auto vectorType = type.dyn_cast<VectorType>())
     return isValid(vectorType);
-  return type
-      .isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
-           spirv::RuntimeArrayType, spirv::StructType>();
+  return type.isa<spirv::ArrayType, spirv::CooperativeMatrixNVType,
+                  spirv::JointMatrixINTELType, spirv::MatrixType,
+                  spirv::RuntimeArrayType, spirv::StructType>();
 }
 
 bool CompositeType::isValid(VectorType type) {
@@ -110,7 +110,8 @@ bool CompositeType::isValid(VectorType type) {
 
 Type CompositeType::getElementType(unsigned index) const {
   return TypeSwitch<Type, Type>(*this)
-      .Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
+      .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
+            RuntimeArrayType, VectorType>(
           [](auto type) { return type.getElementType(); })
       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
       .Case<StructType>(
@@ -132,6 +133,10 @@ unsigned CompositeType::getNumElements() const {
     llvm_unreachable(
         "invalid to query number of elements of spirv::CooperativeMatrix type");
   }
+  if (isa<JointMatrixINTELType>()) {
+    llvm_unreachable(
+        "invalid to query number of elements of spirv::JointMatrix type");
+  }
   if (isa<RuntimeArrayType>()) {
     llvm_unreachable(
         "invalid to query number of elements of spirv::RuntimeArray type");
@@ -140,15 +145,16 @@ unsigned CompositeType::getNumElements() const {
 }
 
 bool CompositeType::hasCompileTimeKnownNumElements() const {
-  return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
+  return !isa<CooperativeMatrixNVType, JointMatrixINTELType,
+              RuntimeArrayType>();
 }
 
 void CompositeType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     Optional<StorageClass> storage) {
   TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
-            StructType>(
+      .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
+            MatrixType, RuntimeArrayType, StructType>(
           [&](auto type) { type.getExtensions(extensions, storage); })
       .Case<VectorType>([&](VectorType type) {
         return type.getElementType().cast<ScalarType>().getExtensions(
@@ -161,8 +167,8 @@ void CompositeType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     Optional<StorageClass> storage) {
   TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
-            StructType>(
+      .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
+            MatrixType, RuntimeArrayType, StructType>(
           [&](auto type) { type.getCapabilities(capabilities, storage); })
       .Case<VectorType>([&](VectorType type) {
         auto vecSize = getNumElements();
@@ -255,6 +261,74 @@ void CooperativeMatrixNVType::getCapabilities(
   capabilities.push_back(ref);
 }
 
+//===----------------------------------------------------------------------===//
+// JointMatrixType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
+  // Uniquing key: (element type, rows, columns, matrix layout, scope).
+  using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
+
+  JointMatrixTypeStorage(const KeyTy &key)
+      : elementType(std::get<0>(key)), rows(std::get<1>(key)),
+        columns(std::get<2>(key)), matrixLayout(std::get<3>(key)),
+        scope(std::get<4>(key)) {}
+
+  static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator,
+                                           const KeyTy &key) {
+    auto *storage = allocator.allocate<JointMatrixTypeStorage>();
+    return new (storage) JointMatrixTypeStorage(key);
+  }
+  bool operator==(const KeyTy &key) const {
+    return KeyTy(elementType, rows, columns, matrixLayout, scope) == key;
+  }
+
+  // Fields are declared in the same order as the components of KeyTy.
+  Type elementType;
+  unsigned rows, columns;
+  MatrixLayout matrixLayout;
+  Scope scope;
+};
+
+JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope,
+                                               unsigned rows, unsigned columns,
+                                               MatrixLayout matrixLayout) {
+  // The storage key orders the layout before the scope.
+  MLIRContext *context = elementType.getContext();
+  return Base::get(context, elementType, rows, columns, matrixLayout, scope);
+}
+
+// Accessors: every field is read straight from the uniqued storage.
+Type JointMatrixINTELType::getElementType() const {
+  return getImpl()->elementType;
+}
+
+unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; }
+unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; }
+
+Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; }
+MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
+  return getImpl()->matrixLayout;
+}
+
+void JointMatrixINTELType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    Optional<StorageClass> storage) {
+  // Requirements of the element type come first, then the joint matrix one.
+  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+  static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
+  extensions.push_back(ArrayRef<Extension>(exts));
+}
+
+void JointMatrixINTELType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  // Element type capabilities first, followed by JointMatrixINTEL itself.
+  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+  static const Capability caps[] = {Capability::JointMatrixINTEL};
+  capabilities.push_back(ArrayRef<Capability>(caps));
+}
+
 //===----------------------------------------------------------------------===//
 // ImageType
 //===----------------------------------------------------------------------===//
@@ -1172,6 +1246,7 @@ void MatrixType::getCapabilities(
 //===----------------------------------------------------------------------===//
 
 void SPIRVDialect::registerTypes() {
-  addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
-           PointerType, RuntimeArrayType, SampledImageType, StructType>();
+  addTypes<ArrayType, CooperativeMatrixNVType, ImageType, JointMatrixINTELType,
+           MatrixType, PointerType, RuntimeArrayType, SampledImageType,
+           StructType>();
 }

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index c6787d79ffe7b..1165508572be8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -168,6 +168,8 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processType(opcode, operands);
   case spirv::Opcode::OpTypeForwardPointer:
     return processTypeForwardPointer(operands);
+  case spirv::Opcode::OpTypeJointMatrixINTEL:
+    return processType(opcode, operands);
   case spirv::Opcode::OpConstant:
     return processConstant(operands, /*isSpec=*/false);
   case spirv::Opcode::OpSpecConstant:

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index e4cfc4b380e46..84d8d9caf202f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -730,6 +730,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     return processCooperativeMatrixType(operands);
   case spirv::Opcode::OpTypeFunction:
     return processFunctionType(operands);
+  case spirv::Opcode::OpTypeJointMatrixINTEL:
+    return processJointMatrixType(operands);
   case spirv::Opcode::OpTypeImage:
     return processImageType(operands);
   case spirv::Opcode::OpTypeSampledImage:
@@ -888,6 +890,55 @@ spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+/// Processes an OpTypeJointMatrixINTEL instruction. The operands are: the
+/// result <id>, the element type <id>, then the <id>s of the integer constants
+/// holding rows, columns, matrix layout and scope, in that order.
+LogicalResult
+spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 6) {
+    return emitError(unknownLoc, "OpTypeJointMatrix must have element "
+                                 "type and row x column parameters");
+  }
+
+  Type elementTy = getType(operands[1]);
+  if (!elementTy) {
+    return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
+           << operands[1];
+  }
+
+  // Rows, columns, layout and scope are <id>s of integer constants. Check
+  // that each one resolves before calling getInt(): an <id> that is not an
+  // integer constant presumably yields a null IntegerAttr.
+  IntegerAttr rowsAttr = getConstantInt(operands[2]);
+  IntegerAttr columnsAttr = getConstantInt(operands[3]);
+  IntegerAttr layoutAttr = getConstantInt(operands[4]);
+  IntegerAttr scopeAttr = getConstantInt(operands[5]);
+  if (!rowsAttr || !columnsAttr || !layoutAttr || !scopeAttr) {
+    return emitError(unknownLoc, "OpTypeJointMatrix expects rows, columns, "
+                                 "matrix layout and scope to be integer "
+                                 "constant <id>s");
+  }
+
+  auto scope = spirv::symbolizeScope(scopeAttr.getInt());
+  if (!scope) {
+    return emitError(unknownLoc,
+                     "OpTypeJointMatrix references undefined scope <id> ")
+           << operands[5];
+  }
+  auto matrixLayout = spirv::symbolizeMatrixLayout(layoutAttr.getInt());
+  if (!matrixLayout) {
+    return emitError(unknownLoc, "OpTypeJointMatrix references undefined "
+                                 "matrix layout <id> ")
+           << operands[4];
+  }
+  unsigned rows = rowsAttr.getInt();
+  unsigned columns = columnsAttr.getInt();
+
+  typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
+      elementTy, scope.value(), rows, columns, matrixLayout.value());
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 442b8e369d77b..784ec2b3e624c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -257,6 +257,8 @@ class Deserializer {
 
   LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processJointMatrixType(ArrayRef<uint32_t> operands);
+
   LogicalResult processImageType(ArrayRef<uint32_t> operands);
 
   LogicalResult processSampledImageType(ArrayRef<uint32_t> operands);

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index b8be9433e94e8..7c0b9f33f9e77 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -598,6 +598,27 @@ LogicalResult Serializer::prepareBasicType(
     return success();
   }
 
+  if (auto jointMatrixType = type.dyn_cast<spirv::JointMatrixINTELType>()) {
+    uint32_t elementTypeID = 0;
+    if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
+                               elementTypeID, serializationCtx))) {
+      return failure();
+    }
+    typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
+    auto getConstantOp = [&](uint32_t id) {
+      auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
+      return prepareConstantInt(loc, attr);
+    };
+    operands.push_back(elementTypeID);
+    operands.push_back(getConstantOp(jointMatrixType.getRows()));
+    operands.push_back(getConstantOp(jointMatrixType.getColumns()));
+    operands.push_back(getConstantOp(
+        static_cast<uint32_t>(jointMatrixType.getMatrixLayout())));
+    operands.push_back(
+        getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
+    return success();
+  }
+
   if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
     uint32_t elementTypeID = 0;
     if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,

diff  --git a/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
new file mode 100644
index 0000000000000..e6856f3cd9e7c
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
@@ -0,0 +1,158 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @joint_matrix_load
+spv.func @joint_matrix_load(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
+  // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
+  %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
+  spv.Return
+}
+
+// -----
+// CHECK-LABEL: @joint_matrix_load_memaccess
+spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, CrossWorkgroup>, %stride : i32) "None" {
+  // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
+  %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_load_diff_ptr_type
+spv.func @joint_matrix_load_diff_ptr_type(%ptr : !spv.ptr<vector<4xi32>, Workgroup>, %stride : i32) "None" {
+  // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
+  %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_store
+spv.func @joint_matrix_store(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32, %m : !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>) "None" {
+  // CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
+  spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_store_memaccess
+spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %m : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
+  // CHECK: spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> {{%.*}}, {{%.*}}, {{%.*}} {Volatile} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
+  spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride {Volatile} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_length
+spv.func @joint_matrix_length() -> i32 "None" {
+  // CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
+  %0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
+  spv.ReturnValue %0 : i32
+}
+
+// CHECK-LABEL: @joint_matrix_muladd
+spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, %b : !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.JointMatrixMadINTEL <Subgroup> {{%.*}}, {{%.*}}, {{%.*}}  : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+  %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32,  RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_add
+spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  %r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_sub
+spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  %r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_sdiv
+spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  %r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_udiv
+spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  %r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_fadd
+spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  %r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_fsub
+spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  %r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_fdiv
+spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  %r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+// CHECK-LABEL: @joint_matrix_access_chain
+spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
+  %0 = spv.Constant 0: i32
+  // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
+  %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
+  spv.ReturnValue %1 : !spv.ptr<f32, Function>
+}
+
+// -----
+
+spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+  // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}}
+  %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+  // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}}
+  %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+  // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix scope must match}}
+  %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+  // expected-error @+1 {{matrix element type must match}}
+  %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<!spv.struct<(f32 [0])>, Workgroup>, %stride : i32) "None" {
+  // expected-error @+1 {{Pointer must point to a scalar or vector type}}
+  %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<!spv.struct<(f32 [0])>, Workgroup>, i32)-> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, Function>, %stride : i32) "None" {
+  // expected-error @+1 {{Pointer storage class must be Workgroup or CrossWorkgroup}}
+  %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Function>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}

diff  --git a/mlir/test/Target/SPIRV/joint-matrix-ops.mlir b/mlir/test/Target/SPIRV/joint-matrix-ops.mlir
new file mode 100644
index 0000000000000..1c9ef213cbafa
--- /dev/null
+++ b/mlir/test/Target/SPIRV/joint-matrix-ops.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [JointMatrixINTEL], [SPV_INTEL_joint_matrix]> {
+  // CHECK-LABEL: @joint_matrix_load
+  spv.func @joint_matrix_load(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
+    // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
+    %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_load_memaccess
+  spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
+    // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_store
+  spv.func @joint_matrix_store(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32, %m : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>) "None" {
+    // CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
+    spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_store_memaccess
+  spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %m : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
+    // CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
+    spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_length
+  spv.func @joint_matrix_length() -> i32 "None" {
+    // CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.ReturnValue %0 : i32
+  }
+
+  // CHECK-LABEL: @joint_matrix_muladd
+  spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.JointMatrixMadINTEL <Subgroup> {{%.*}}, {{%.*}}, {{%.*}}  : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+    %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_add
+  spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_sub
+  spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_sdiv
+  spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_udiv
+  spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_fadd
+  spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    %r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_fsub
+  spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    %r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_fdiv
+  spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    %r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_access_chain
+  spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
+    %0 = spv.Constant 0: i32
+    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
+    %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
+    spv.ReturnValue %1 : !spv.ptr<f32, Function>
+  }
+}

diff  --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 1db8c96bdc99c..8329cfe18dc0a 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -518,7 +518,8 @@ static void emitAttributeSerialization(const Attribute &attr,
   os << tabs
      << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
   if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
-      attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
+      attr.getAttrDefName() == "SPV_MemorySemanticsAttr" ||
+      attr.getAttrDefName() == "SPV_MatrixLayoutAttr") {
     // These two enums are encoded as <id> to constant values in SPIR-V blob,
     // but we directly use the constant value as attribute in SPIR-V dialect. So
     // need to handle them separately from normal enum attributes.
@@ -810,7 +811,8 @@ static void emitAttributeDeserialization(const Attribute &attr,
                                          StringRef words, StringRef wordIndex,
                                          raw_ostream &os) {
   if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
-      attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
+      attr.getAttrDefName() == "SPV_MemorySemanticsAttr" ||
+      attr.getAttrDefName() == "SPV_MatrixLayoutAttr") {
     // These two enums are encoded as <id> to constant values in SPIR-V blob,
     // but we directly use the constant value as attribute in SPIR-V dialect. So
     // need to handle them separately from normal enum attributes.


        


More information about the Mlir-commits mailing list