[Mlir-commits] [mlir] b8f62dc - [MLIR][SPIRV] Add intel joint matrix ops
Nirvedh Meshram
llvmlistbot at llvm.org
Mon Aug 15 16:51:58 PDT 2022
Author: Nirvedh Meshram
Date: 2022-08-15T23:49:45Z
New Revision: b8f62dc22a9333bc016b19921a9dac6be6cb947c
URL: https://github.com/llvm/llvm-project/commit/b8f62dc22a9333bc016b19921a9dac6be6cb947c
DIFF: https://github.com/llvm/llvm-project/commit/b8f62dc22a9333bc016b19921a9dac6be6cb947c.diff
LOG: [MLIR][SPIRV] Add intel joint matrix ops
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D131586
Added:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
mlir/test/Target/SPIRV/joint-matrix-ops.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 7ce34dcd5dedb..f5bf65bf6d0f7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -27,12 +27,12 @@ class SPV_ArithmeticBinaryOp<string mnemonic, Type type,
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
- SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
- SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
+ SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand1,
+ SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand2
);
let results = (outs
- SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
+ SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$result
);
let assemblyFormat = "operands attr-dict `:` type($result)";
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index 791615274d52f..02742278fe57f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -64,6 +64,27 @@ def SPV_CooperativeMatrixPropertiesNVArrayAttr :
TypedArrayAttrBase<SPV_CooperativeMatrixPropertiesNVAttr,
"CooperativeMatrixPropertiesNV array attribute">;
+// Joint matrix op configuration: M/N/K sizes, A/B/C/result types, scope. See
+// https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc
+def SPV_JointMatrixPropertiesINTELAttr :
+ SPV_Attr<"JointMatrixPropertiesINTEL", "joint_matrix_props"> {
+ let parameters = (ins
+ "int":$m_size,
+ "int":$n_size,
+ "int":$k_size,
+ "mlir::Type":$a_type,
+ "mlir::Type":$b_type,
+ "mlir::Type":$c_type,
+ "mlir::Type":$result_type,
+ "mlir::spirv::ScopeAttr":$scope
+ );
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def SPV_JointMatrixPropertiesINTELArrayAttr :
+ TypedArrayAttrBase<SPV_JointMatrixPropertiesINTELAttr,
+ "JointMatrixPropertiesINTEL array attribute">;
+
// This attribute specifies the limits for various resources on the target
// architecture.
//
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 0cfb3ead3642c..84601dd3eae55 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -387,6 +387,7 @@ def SPV_INTEL_debug_module : I32EnumAttrCase<"SPV_INTEL_de
def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp_fast_math_mode", 4027>;
def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
+def SPV_INTEL_joint_matrix : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>;
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -443,7 +444,7 @@ def SPV_ExtensionAttr :
SPV_INTEL_usm_storage_classes, SPV_INTEL_io_pipes, SPV_INTEL_blocking_pipes,
SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
- SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
+ SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix,
SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@@ -1390,6 +1391,12 @@ def SPV_C_ShaderStereoViewNV : I32EnumAttrCase<"ShaderS
];
}
+def SPV_C_JointMatrixINTEL : I32EnumAttrCase<"JointMatrixINTEL", 6118> {
+ list<Availability> availability = [
+ Extension<[SPV_INTEL_joint_matrix]>
+ ];
+}
+
def SPV_CapabilityAttr :
SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
SPV_C_Matrix, SPV_C_Addresses, SPV_C_Linkage, SPV_C_Kernel, SPV_C_Float16,
@@ -1481,7 +1488,7 @@ def SPV_CapabilityAttr :
SPV_C_UniformTexelBufferArrayNonUniformIndexing,
SPV_C_StorageTexelBufferArrayNonUniformIndexing,
SPV_C_ShaderViewportIndexLayerEXT, SPV_C_ShaderViewportMaskNV,
- SPV_C_ShaderStereoViewNV
+ SPV_C_ShaderStereoViewNV, SPV_C_JointMatrixINTEL
]>;
def SPV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -3981,6 +3988,16 @@ def SPV_SamplerUseAttr: SPV_I32EnumAttr<
"image_sampler_use_info",
[SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>;
+def SPV_ML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 0>;
+def SPV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 1>;
+def SPV_ML_PackedA : I32EnumAttrCase<"PackedA", 2>;
+def SPV_ML_PackedB : I32EnumAttrCase<"PackedB", 3>;
+
+def SPV_MatrixLayoutAttr :
+ SPV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [
+ SPV_ML_ColumnMajor, SPV_ML_RowMajor, SPV_ML_PackedA, SPV_ML_PackedB
+ ]>;
+
//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
@@ -4013,6 +4030,8 @@ def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
def SPV_IsCooperativeMatrixType :
CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
def SPV_IsImageType : CPred<"$_self.isa<::mlir::spirv::ImageType>()">;
+def SPV_IsJointMatrixType :
+ CPred<"$_self.isa<::mlir::spirv::JointMatrixINTELType>()">;
def SPV_IsMatrixType : CPred<"$_self.isa<::mlir::spirv::MatrixType>()">;
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
@@ -4043,6 +4062,8 @@ def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
"any SPIR-V cooperative matrix type">;
def SPV_AnyImage : DialectType<SPIRV_Dialect, SPV_IsImageType,
"any SPIR-V image type">;
+def SPV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPV_IsJointMatrixType,
+ "any SPIR-V joint matrix type">;
def SPV_AnyMatrix : DialectType<SPIRV_Dialect, SPV_IsMatrixType,
"any SPIR-V matrix type">;
def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
@@ -4057,11 +4078,12 @@ def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
def SPV_Composite :
AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
- SPV_AnyCooperativeMatrix, SPV_AnyMatrix]>;
+ SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix]>;
def SPV_Type : AnyTypeOf<[
SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
- SPV_AnyCooperativeMatrix, SPV_AnyMatrix, SPV_AnySampledImage
+ SPV_AnyCooperativeMatrix, SPV_AnyJointMatrix, SPV_AnyMatrix,
+ SPV_AnySampledImage
]>;
def SPV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4072,6 +4094,11 @@ class SPV_CoopMatrixOfType<list<Type> allowedTypes> :
"$_self.cast<::mlir::spirv::CooperativeMatrixNVType>().getElementType()",
"Cooperative Matrix">;
+class SPV_JointMatrixOfType<list<Type> allowedTypes> :
+ ContainerType<AnyTypeOf<allowedTypes>, SPV_IsJointMatrixType,
+ "$_self.cast<::mlir::spirv::JointMatrixINTELType>().getElementType()",
+ "Joint Matrix">;
+
class SPV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
@@ -4079,6 +4106,14 @@ class SPV_ScalarOrVectorOrCoopMatrixOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
SPV_CoopMatrixOfType<[type]>]>;
+class SPV_ScalarOrVectorOrJointMatrixOf<Type type> :
+ AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+ SPV_JointMatrixOfType<[type]>]>;
+
+class SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<Type type> :
+ AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+ SPV_CoopMatrixOfType<[type]>, SPV_JointMatrixOfType<[type]>]>;
+
def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
@@ -4311,6 +4346,11 @@ def SPV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINT
def SPV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
def SPV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
def SPV_OC_OpAtomicFAddEXT : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>;
+def SPV_OC_OpTypeJointMatrixINTEL : I32EnumAttrCase<"OpTypeJointMatrixINTEL", 6119>;
+def SPV_OC_OpJointMatrixLoadINTEL : I32EnumAttrCase<"OpJointMatrixLoadINTEL", 6120>;
+def SPV_OC_OpJointMatrixStoreINTEL : I32EnumAttrCase<"OpJointMatrixStoreINTEL", 6121>;
+def SPV_OC_OpJointMatrixMadINTEL : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>;
+def SPV_OC_OpJointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>;
def SPV_OpcodeAttr :
SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4376,7 +4416,10 @@ def SPV_OpcodeAttr :
SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV,
SPV_OC_OpSubgroupBlockReadINTEL, SPV_OC_OpSubgroupBlockWriteINTEL,
- SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT
+ SPV_OC_OpAssumeTrueKHR, SPV_OC_OpAtomicFAddEXT,
+ SPV_OC_OpTypeJointMatrixINTEL, SPV_OC_OpJointMatrixLoadINTEL,
+ SPV_OC_OpJointMatrixStoreINTEL, SPV_OC_OpJointMatrixMadINTEL,
+ SPV_OC_OpJointMatrixWorkItemLengthINTEL
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index 5048dd10ae575..27bad3f08e083 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -23,11 +23,11 @@ class SPV_CastOp<string mnemonic, Type resultType, Type operandType,
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultShape])> {
let arguments = (ins
- SPV_ScalarOrVectorOrCoopMatrixOf<operandType>:$operand
+ SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<operandType>:$operand
);
let results = (outs
- SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
+ SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<resultType>:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
new file mode 100644
index 0000000000000..aa45ef80b5e94
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
@@ -0,0 +1,248 @@
+//===- SPIRVJointMatrixOps.td - joint matmul ---------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of joint matrix multiply extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
+#define MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
+
+// -----
+
+def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL",
+ [NoSideEffect]> {
+ let summary = "See extension SPV_INTEL_joint_matrix";
+
+ let description = [{
+ Return the number of components owned by the current work-item in
+ a joint matrix.
+
+ Result Type must be a 32-bit unsigned integer type scalar.
+
+ Type is a joint matrix type.
+
+ ``` {.ebnf}
+ joint-matrix-length-op ::= ssa-id `=` `spv.JointMatrixWorkItemLengthINTEL`
+ `:` joint-matrix-type
+ ```
+
+ For example:
+
+ ```
+ %0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+ ```
+ }];
+
+ let assemblyFormat = "attr-dict `:` $type";
+
+ let availability = [
+ MinVersion<SPV_V_1_0>,
+ MaxVersion<SPV_V_1_6>,
+ Extension<[SPV_INTEL_joint_matrix]>,
+ Capability<[SPV_C_JointMatrixINTEL]>
+ ];
+
+ let arguments = (ins
+ TypeAttr:$type
+ );
+
+ let results = (outs
+ SPV_Int32:$result
+ );
+ let hasVerifier = 0;
+}
+
+// -----
+
+def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
+ let summary = "See extension SPV_INTEL_joint_matrix";
+
+ let description = [{
+ Load a matrix through a pointer.
+
+ Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
+
+ Pointer is the pointer to load through. It specifies the start of the memory region where
+ elements of the matrix are stored and arranged according to Layout.
+
+ Stride is the number of elements in memory between beginnings of successive rows,
+ columns (or words) in the result. It must be a scalar integer type.
+
+ Layout indicates how the values loaded from memory are arranged. It must be the
+ result of a constant instruction.
+
+ Scope is the synchronization scope for operation on the matrix. It must be the result
+ of a constant instruction with scalar integer type.
+
+ If present, any Memory Operands must begin with a memory operand literal. If not
+ present, it is the same as specifying the memory operand None.
+
+ #### Example:
+ ```mlir
+ %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
+ {memory_access = #spv.memory_access<Volatile>} :
+ (!spv.ptr<i32, CrossWorkgroup>, i32) ->
+ !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
+ ```
+ }];
+
+ let assemblyFormat = [{
+ $scope $layout operands attr-dict `:` `(` type(operands) `)` `->` type($result)
+ }];
+
+ let availability = [
+ MinVersion<SPV_V_1_0>,
+ MaxVersion<SPV_V_1_6>,
+ Extension<[SPV_INTEL_joint_matrix]>,
+ Capability<[SPV_C_JointMatrixINTEL]>
+ ];
+
+ let arguments = (ins
+ SPV_ScopeAttr:$scope,
+ SPV_MatrixLayoutAttr:$layout,
+ SPV_AnyPtr:$pointer,
+ SPV_Integer:$stride,
+ OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
+ OptionalAttr<I32Attr>:$alignment
+ );
+
+ let results = (outs
+ SPV_AnyJointMatrix:$result
+ );
+}
+
+// -----
+
+def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
+ [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
+ let summary = "See extension SPV_INTEL_joint_matrix";
+
+ let description = [{
+ Multiply matrix A by matrix B and add matrix C to the result
+ of the multiplication: A*B+C. Here A is an M x K matrix, B is
+ a K x N matrix and C is an M x N matrix.
+
+ Behavior is undefined if sizes of operands do not meet the
+ conditions above. All operands and the Result Type must be
+ OpTypeJointMatrixINTEL.
+
+ A must be an OpTypeJointMatrixINTEL whose Component Type is a
+ signed numerical type, Row Count equal to M and Column Count
+ equal to K.
+
+ B must be an OpTypeJointMatrixINTEL whose Component Type is a
+ signed numerical type, Row Count equal to K and Column Count
+ equal to N.
+
+ C and Result Type must be an OpTypeJointMatrixINTEL with Row
+ Count equal to M and Column Count equal to N.
+
+ Scope is the synchronization scope for operation on the matrix.
+ It must be the result of a constant instruction with scalar
+ integer type.
+
+ #### Example:
+ ```mlir
+ %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
+ !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
+ !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
+ -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+ ```
+
+ }];
+
+ let assemblyFormat = [{
+ $scope operands attr-dict `:` type($a) `,` type($b) `->` type($c)
+ }];
+
+ let availability = [
+ MinVersion<SPV_V_1_0>,
+ MaxVersion<SPV_V_1_6>,
+ Extension<[SPV_INTEL_joint_matrix]>,
+ Capability<[SPV_C_JointMatrixINTEL]>
+ ];
+
+ let arguments = (ins
+ SPV_ScopeAttr:$scope,
+ SPV_AnyJointMatrix:$a,
+ SPV_AnyJointMatrix:$b,
+ SPV_AnyJointMatrix:$c
+ );
+
+ let results = (outs
+ SPV_AnyJointMatrix:$result
+ );
+}
+
+// -----
+
+def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> {
+ let summary = "See extension SPV_INTEL_joint_matrix";
+
+ let description = [{
+ Store a matrix through a pointer.
+
+ Pointer is the pointer to store through. It specifies the
+ start of the memory region where elements of the matrix must
+ be stored and arranged according to Layout.
+
+ Object is the matrix to store. It must be
+ OpTypeJointMatrixINTEL.
+
+ Stride is the number of elements in memory between beginnings
+ of successive rows, columns (or words) of the Object. It must
+ be a scalar integer type.
+
+ Layout indicates how the values stored to memory are arranged.
+ It must be the result of a constant instruction.
+
+ Scope is the synchronization scope for operation on the matrix.
+ It must be the result of a constant instruction with scalar
+ integer type.
+
+ If present, any Memory Operands must begin with a memory operand
+ literal. If not present, it is the same as specifying the memory
+ operand None.
+
+ #### Example:
+ ```mlir
+ spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
+ {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
+ !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
+ ```
+
+ }];
+
+ let assemblyFormat = [{
+ $scope $layout operands attr-dict `:` `(` type(operands) `)`
+ }];
+
+ let availability = [
+ MinVersion<SPV_V_1_0>,
+ MaxVersion<SPV_V_1_6>,
+ Extension<[SPV_INTEL_joint_matrix]>,
+ Capability<[SPV_C_JointMatrixINTEL]>
+ ];
+
+ let arguments = (ins
+ SPV_ScopeAttr:$scope,
+ SPV_MatrixLayoutAttr:$layout,
+ SPV_AnyPtr:$pointer,
+ SPV_AnyJointMatrix:$object,
+ SPV_Integer:$stride,
+ OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
+ OptionalAttr<I32Attr>:$alignment
+ );
+
+ let results = (outs);
+}
+
+// -----
+
+#endif // MLIR_DIALECT_SPIRV_IR_JOINT_MATRIX_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index aa87f0e142e80..5e8e5e4c7ce92 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -30,6 +30,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCastOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 40a4acff751b1..d9737e4c2c579 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,6 +29,7 @@ namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
+struct JointMatrixTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
@@ -420,6 +421,33 @@ class CooperativeMatrixNVType
Optional<StorageClass> storage = llvm::None);
};
+// SPIR-V joint matrix type (`!spv.jointmatrix`), added for SPV_INTEL_joint_matrix.
+class JointMatrixINTELType
+ : public Type::TypeBase<JointMatrixINTELType, CompositeType,
+ detail::JointMatrixTypeStorage> {
+public:
+ using Base::Base;
+
+ static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
+ unsigned columns, MatrixLayout matrixLayout);
+ Type getElementType() const; ///< Return the element (component) type.
+
+ /// Return the scope of the joint matrix.
+ Scope getScope() const;
+ /// Return the number of rows of the matrix.
+ unsigned getRows() const;
+ /// Return the number of columns of the matrix.
+ unsigned getColumns() const;
+
+ /// Return the layout of the matrix.
+ MatrixLayout getMatrixLayout() const;
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<StorageClass> storage = llvm::None);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<StorageClass> storage = llvm::None);
+};
+
// SPIR-V matrix type
class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
detail::MatrixTypeStorage> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 88bae6bb43263..d5449a2c09b6c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -348,6 +348,39 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
}
+// joint-matrix-type ::= `!spv.jointmatrix` `<` rows `x` columns `x` element-type
+//                       `,` layout `,` scope `>`
+static Type parseJointMatrixType(SPIRVDialect const &dialect,
+ DialectAsmParser &parser) {
+ if (parser.parseLess())
+ return Type();
+
+ SmallVector<int64_t, 2> dims;
+ SMLoc countLoc = parser.getCurrentLocation();
+ if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) // "rows x columns x"
+ return Type();
+
+ if (dims.size() != 2) {
+ parser.emitError(countLoc, "expected rows and columns size");
+ return Type();
+ }
+
+ auto elementTy = parseAndVerifyType(dialect, parser);
+ if (!elementTy)
+ return Type();
+ MatrixLayout matrixLayout;
+ if (parser.parseComma() ||
+ parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
+ return Type();
+ Scope scope;
+ if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+ return Type();
+ if (parser.parseGreater())
+ return Type();
+ return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1],
+ matrixLayout);
+}
+
// TODO: Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
@@ -753,6 +786,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseArrayType(*this, parser);
if (keyword == "coopmatrix")
return parseCooperativeMatrixType(*this, parser);
+ if (keyword == "jointmatrix")
+ return parseJointMatrixType(*this, parser);
if (keyword == "image")
return parseImageType(*this, parser);
if (keyword == "ptr")
@@ -859,6 +894,13 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
os << ">";
}
+static void print(JointMatrixINTELType type, DialectAsmPrinter &os) { // Emits the syntax parsed by parseJointMatrixType.
+ os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
+ os << type.getElementType() << ", "
+ << stringifyMatrixLayout(type.getMatrixLayout());
+ os << ", " << stringifyScope(type.getScope()) << ">";
+}
+
static void print(MatrixType type, DialectAsmPrinter &os) {
os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
os << ">";
@@ -866,9 +908,9 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
- .Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
- ImageType, SampledImageType, StructType, MatrixType>(
- [&](auto type) { print(type, os); })
+ .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
+ PointerType, RuntimeArrayType, ImageType, SampledImageType,
+ StructType, MatrixType>([&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index bf19ce83874e5..8011ddc47f5ff 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -436,6 +436,13 @@ static LogicalResult verifyCastOp(Operation *op,
resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
}
+ if (auto jointMatrixType =
+ operandType.dyn_cast<spirv::JointMatrixINTELType>()) {
+ operandType = jointMatrixType.getElementType(); // Bit-width checks below use element types.
+ resultType =
+ resultType.cast<spirv::JointMatrixINTELType>().getElementType();
+ }
+
auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
@@ -1637,6 +1644,17 @@ LogicalResult spirv::CompositeConstructOp::verify() {
return success();
}
+ if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) { // Expect one constituent of the element type.
+ if (constituents.size() != 1)
+ return emitOpError("has incorrect number of operands: expected ")
+ << "1, but provided " << constituents.size();
+ if (jointType.getElementType() != constituents.front().getType())
+ return emitOpError("operand type mismatch: expected operand type ")
+ << jointType.getElementType() << ", but provided "
+ << constituents.front().getType();
+ return success();
+ }
+
if (constituents.size() == cType.getNumElements()) {
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
if (constituents[index].getType() != cType.getElementType(index)) {
@@ -3893,6 +3911,70 @@ LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
return verifyCoopMatrixMulAdd(*this);
}
+static LogicalResult
+verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) { // Checks pointee type and storage class; NOTE(review): jointMatrix is currently unused.
+ Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
+ if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
+ return op->emitError(
+ "Pointer must point to a scalar or vector type but provided ")
+ << pointeeType;
+ spirv::StorageClass storage =
+ pointer.cast<spirv::PointerType>().getStorageClass();
+ if (storage != spirv::StorageClass::Workgroup &&
+ storage != spirv::StorageClass::CrossWorkgroup)
+ return op->emitError("Pointer storage class must be Workgroup or "
+ "CrossWorkgroup but provided ")
+ << stringifyStorageClass(storage);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.JointMatrixLoadINTEL
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
+ return verifyPointerAndJointMatrixType(*this, pointer().getType(),
+ result().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spv.JointMatrixStoreINTEL
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
+ return verifyPointerAndJointMatrixType(*this, pointer().getType(),
+ object().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spv.JointMatrixMadINTEL
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
+ if (op.c().getType() != op.result().getType()) // Also enforced by the AllTypesMatch<["c", "result"]> trait.
+ return op.emitOpError("result and third operand must have the same type");
+ auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
+ auto typeB = op.b().getType().cast<spirv::JointMatrixINTELType>();
+ auto typeC = op.c().getType().cast<spirv::JointMatrixINTELType>();
+ auto typeR = op.result().getType().cast<spirv::JointMatrixINTELType>();
+ if (typeA.getRows() != typeR.getRows() || // A is MxK, B is KxN, result is MxN.
+ typeA.getColumns() != typeB.getRows() ||
+ typeB.getColumns() != typeR.getColumns())
+ return op.emitOpError("matrix size must match");
+ if (typeR.getScope() != typeA.getScope() ||
+ typeR.getScope() != typeB.getScope() ||
+ typeR.getScope() != typeC.getScope())
+ return op.emitOpError("matrix scope must match");
+ if (typeA.getElementType() != typeB.getElementType() ||
+ typeR.getElementType() != typeC.getElementType())
+ return op.emitOpError("matrix element type must match");
+ return success();
+}
+
+LogicalResult spirv::JointMatrixMadINTELOp::verify() {
+ return verifyJointMatrixMad(*this);
+}
+
//===----------------------------------------------------------------------===//
// spv.MatrixTimesScalar
//===----------------------------------------------------------------------===//
@@ -4150,6 +4232,8 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
if (cType.isa<spirv::CooperativeMatrixNVType>())
return emitError("unsupported composite type ") << cType;
+ if (cType.isa<spirv::JointMatrixINTELType>()) // Rejected, like the cooperative matrix case above.
+ return emitError("unsupported composite type ") << cType;
if (constituents.size() != cType.getNumElements())
return emitError("has incorrect number of operands: expected ")
<< cType.getNumElements() << ", but provided "
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 03a45b8c3884a..a4c622b8d8199 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -89,9 +89,9 @@ Optional<int64_t> ArrayType::getSizeInBytes() {
bool CompositeType::classof(Type type) {
if (auto vectorType = type.dyn_cast<VectorType>())
return isValid(vectorType);
- return type
- .isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
- spirv::RuntimeArrayType, spirv::StructType>();
+ return type.isa<spirv::ArrayType, spirv::CooperativeMatrixNVType,
+ spirv::JointMatrixINTELType, spirv::MatrixType,
+ spirv::RuntimeArrayType, spirv::StructType>();
}
bool CompositeType::isValid(VectorType type) {
@@ -110,7 +110,8 @@ bool CompositeType::isValid(VectorType type) {
Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
- .Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
+ .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
+ RuntimeArrayType, VectorType>(
[](auto type) { return type.getElementType(); })
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
@@ -132,6 +133,10 @@ unsigned CompositeType::getNumElements() const {
llvm_unreachable(
"invalid to query number of elements of spirv::CooperativeMatrix type");
}
+ if (isa<JointMatrixINTELType>()) {
+ llvm_unreachable(
+ "invalid to query number of elements of spirv::JointMatrix type");
+ }
if (isa<RuntimeArrayType>()) {
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
@@ -140,15 +145,16 @@ unsigned CompositeType::getNumElements() const {
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
- return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
+ return !isa<CooperativeMatrixNVType, JointMatrixINTELType,
+ RuntimeArrayType>();
}
void CompositeType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
- StructType>(
+ .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
+ MatrixType, RuntimeArrayType, StructType>(
[&](auto type) { type.getExtensions(extensions, storage); })
.Case<VectorType>([&](VectorType type) {
return type.getElementType().cast<ScalarType>().getExtensions(
@@ -161,8 +167,8 @@ void CompositeType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
- StructType>(
+ .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
+ MatrixType, RuntimeArrayType, StructType>(
[&](auto type) { type.getCapabilities(capabilities, storage); })
.Case<VectorType>([&](VectorType type) {
auto vecSize = getNumElements();
@@ -255,6 +261,74 @@ void CooperativeMatrixNVType::getCapabilities(
capabilities.push_back(ref);
}
+//===----------------------------------------------------------------------===//
+// JointMatrixINTELType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::JointMatrixTypeStorage : public TypeStorage {
+  /// Uniquing key: element type, rows, columns, layout, scope.
+  using KeyTy = std::tuple<Type, unsigned, unsigned, MatrixLayout, Scope>;
+  JointMatrixTypeStorage(const KeyTy &key)
+      : elementType(std::get<0>(key)), rows(std::get<1>(key)),
+        columns(std::get<2>(key)), matrixLayout(std::get<3>(key)),
+        scope(std::get<4>(key)) {}
+
+  static JointMatrixTypeStorage *construct(TypeStorageAllocator &allocator,
+                                           const KeyTy &key) {
+    auto *storage = allocator.allocate<JointMatrixTypeStorage>();
+    return new (storage) JointMatrixTypeStorage(key);
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == std::tie(elementType, rows, columns, matrixLayout, scope);
+  }
+
+  Type elementType;
+  unsigned rows;
+  unsigned columns;
+  MatrixLayout matrixLayout;
+  Scope scope;
+};
+
+JointMatrixINTELType JointMatrixINTELType::get(Type elementType, Scope scope,
+                                               unsigned rows, unsigned columns,
+                                               MatrixLayout matrixLayout) {
+  MLIRContext *context = elementType.getContext();
+  return Base::get(context, elementType, rows, columns, matrixLayout, scope);
+}
+
+// Accessors that read the fields of the uniqued JointMatrixTypeStorage.
+Type JointMatrixINTELType::getElementType() const {
+  return getImpl()->elementType;
+}
+
+unsigned JointMatrixINTELType::getRows() const { return getImpl()->rows; }
+unsigned JointMatrixINTELType::getColumns() const { return getImpl()->columns; }
+
+MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
+  return getImpl()->matrixLayout;
+}
+
+Scope JointMatrixINTELType::getScope() const { return getImpl()->scope; }
+
+void JointMatrixINTELType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    Optional<StorageClass> storage) {
+  // Element type's extensions first, then SPV_INTEL_joint_matrix itself.
+  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+  static const Extension required[] = {Extension::SPV_INTEL_joint_matrix};
+  extensions.push_back(ArrayRef<Extension>(required));
+}
+
+void JointMatrixINTELType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  // Element type's capabilities first, then JointMatrixINTEL itself.
+  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+  static const Capability required[] = {Capability::JointMatrixINTEL};
+  capabilities.push_back(ArrayRef<Capability>(required));
+}
+
//===----------------------------------------------------------------------===//
// ImageType
//===----------------------------------------------------------------------===//
@@ -1172,6 +1246,7 @@ void MatrixType::getCapabilities(
//===----------------------------------------------------------------------===//
void SPIRVDialect::registerTypes() {
- addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
- PointerType, RuntimeArrayType, SampledImageType, StructType>();
+  addTypes<ArrayType, CooperativeMatrixNVType, ImageType,
+           JointMatrixINTELType, MatrixType, PointerType, RuntimeArrayType,
+           SampledImageType, StructType>();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index c6787d79ffe7b..1165508572be8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -168,6 +168,8 @@ LogicalResult spirv::Deserializer::processInstruction(
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
return processTypeForwardPointer(operands);
+ case spirv::Opcode::OpTypeJointMatrixINTEL:
+ return processType(opcode, operands);
case spirv::Opcode::OpConstant:
return processConstant(operands, /*isSpec=*/false);
case spirv::Opcode::OpSpecConstant:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index e4cfc4b380e46..84d8d9caf202f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -730,6 +730,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processCooperativeMatrixType(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
+ case spirv::Opcode::OpTypeJointMatrixINTEL:
+ return processJointMatrixType(operands);
case spirv::Opcode::OpTypeImage:
return processImageType(operands);
case spirv::Opcode::OpTypeSampledImage:
@@ -888,6 +890,53 @@ spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult
+spirv::Deserializer::processJointMatrixType(ArrayRef<uint32_t> operands) {
+  // Operands: <result id>, <element type id>, then the <id>s of the integer
+  // constants holding rows, columns, matrix layout, and scope.
+  if (operands.size() != 6) {
+    return emitError(unknownLoc, "OpTypeJointMatrix must have element type, "
+                                 "rows, columns, layout and scope parameters");
+  }
+
+  Type elementTy = getType(operands[1]);
+  if (!elementTy) {
+    return emitError(unknownLoc, "OpTypeJointMatrix references undefined <id> ")
+           << operands[1];
+  }
+
+  // Rows, columns, layout and scope are <id>s of previously deserialized
+  // constants; reject dangling references instead of calling getInt() on a
+  // null attribute.
+  for (uint32_t operand : operands.drop_front(2)) {
+    if (!getConstantInt(operand)) {
+      return emitError(unknownLoc,
+                       "OpTypeJointMatrix references undefined constant <id> ")
+             << operand;
+    }
+  }
+
+  auto matrixLayout =
+      spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
+  if (!matrixLayout) {
+    return emitError(unknownLoc, "OpTypeJointMatrix references undefined "
+                                 "matrix layout <id> ")
+           << operands[4];
+  }
+  auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
+  if (!scope) {
+    return emitError(unknownLoc,
+                     "OpTypeJointMatrix references undefined scope <id> ")
+           << operands[5];
+  }
+  unsigned rows = getConstantInt(operands[2]).getInt();
+  unsigned columns = getConstantInt(operands[3]).getInt();
+
+  typeMap[operands[0]] = spirv::JointMatrixINTELType::get(
+      elementTy, scope.value(), rows, columns, matrixLayout.value());
+  return success();
+}
+
LogicalResult
spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 442b8e369d77b..784ec2b3e624c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -257,6 +257,8 @@ class Deserializer {
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
+ LogicalResult processJointMatrixType(ArrayRef<uint32_t> operands);
+
LogicalResult processImageType(ArrayRef<uint32_t> operands);
LogicalResult processSampledImageType(ArrayRef<uint32_t> operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index b8be9433e94e8..7c0b9f33f9e77 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -598,6 +598,27 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
+ if (auto jointMatrixType = type.dyn_cast<spirv::JointMatrixINTELType>()) {
+ uint32_t elementTypeID = 0;
+ if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
+ elementTypeID, serializationCtx))) {
+ return failure();
+ }
+ typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
+ auto getConstantOp = [&](uint32_t id) {
+ auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
+ return prepareConstantInt(loc, attr);
+ };
+ operands.push_back(elementTypeID);
+ operands.push_back(getConstantOp(jointMatrixType.getRows()));
+ operands.push_back(getConstantOp(jointMatrixType.getColumns()));
+ operands.push_back(getConstantOp(
+ static_cast<uint32_t>(jointMatrixType.getMatrixLayout())));
+ operands.push_back(
+ getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
+ return success();
+ }
+
if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
diff --git a/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
new file mode 100644
index 0000000000000..e6856f3cd9e7c
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
@@ -0,0 +1,158 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @joint_matrix_load
+spv.func @joint_matrix_load(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
+  // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
+  %loaded = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
+  spv.Return
+}
+
+// -----
+// CHECK-LABEL: @joint_matrix_load_memaccess
+spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, CrossWorkgroup>, %stride : i32) "None" {
+  // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
+  %loaded = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_load_diff_ptr_type
+spv.func @joint_matrix_load_diff_ptr_type(%ptr : !spv.ptr<vector<4xi32>, Workgroup>, %stride : i32) "None" {
+  // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
+  %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_store
+spv.func @joint_matrix_store(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32, %matrix : !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>) "None" {
+  // CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
+  spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %matrix, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_store_memaccess
+spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %m : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
+  // CHECK: spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> {{%.*}}, {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
+  spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_length
+spv.func @joint_matrix_length() -> i32 "None" {
+  // CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
+  %len = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
+  spv.ReturnValue %len : i32
+}
+
+// CHECK-LABEL: @joint_matrix_muladd
+spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, %b : !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.JointMatrixMadINTEL <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+  %res = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_add
+spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  %res = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_sub
+spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  %res = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_sdiv
+spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  %res = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_udiv
+spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  %res = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_fadd
+spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  %res = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_fsub
+spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  %res = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @joint_matrix_fdiv
+spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+  // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  %res = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+// CHECK-LABEL: @joint_matrix_access_chain
+spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
+  %zero = spv.Constant 0: i32
+  // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
+  %elem = spv.AccessChain %a[%zero] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
+  spv.ReturnValue %elem : !spv.ptr<f32, Function>
+}
+
+// -----
+
+spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+  // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}}
+  %res = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+  // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}}
+  %res = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+  // expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix scope must match}}
+  %res = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+  // expected-error @+1 {{matrix element type must match}}
+  %res = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<!spv.struct<(f32 [0])>, Workgroup>, %stride : i32) "None" {
+  // expected-error @+1 {{Pointer must point to a scalar or vector type}}
+  %loaded = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<!spv.struct<(f32 [0])>, Workgroup>, i32)-> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, Function>, %stride : i32) "None" {
+  // expected-error @+1 {{Pointer storage class must be Workgroup or CrossWorkgroup}}
+  %loaded = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Function>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+  spv.Return
+}
diff --git a/mlir/test/Target/SPIRV/joint-matrix-ops.mlir b/mlir/test/Target/SPIRV/joint-matrix-ops.mlir
new file mode 100644
index 0000000000000..1c9ef213cbafa
--- /dev/null
+++ b/mlir/test/Target/SPIRV/joint-matrix-ops.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [JointMatrixINTEL], [SPV_INTEL_joint_matrix]> {
+  // CHECK-LABEL: @joint_matrix_load
+  spv.func @joint_matrix_load(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
+    // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
+    %loaded = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_load_memaccess
+  spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
+    // CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %loaded = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_store
+  spv.func @joint_matrix_store(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32, %matrix : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>) "None" {
+    // CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
+    spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %matrix, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_store_memaccess
+  spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %matrix : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
+    // CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
+    spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %matrix, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_length
+  spv.func @joint_matrix_length() -> i32 "None" {
+    // CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %len = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.ReturnValue %len : i32
+  }
+
+  // CHECK-LABEL: @joint_matrix_muladd
+  spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.JointMatrixMadINTEL <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+    %res = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_add
+  spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %res = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_sub
+  spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %res = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_sdiv
+  spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %res = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_udiv
+  spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    %res = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_fadd
+  spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    %res = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_fsub
+  spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    %res = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_fdiv
+  spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
+    // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    %res = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
+    spv.Return
+  }
+
+  // CHECK-LABEL: @joint_matrix_access_chain
+  spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
+    %zero = spv.Constant 0: i32
+    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
+    %elem = spv.AccessChain %a[%zero] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
+    spv.ReturnValue %elem : !spv.ptr<f32, Function>
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 1db8c96bdc99c..8329cfe18dc0a 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -518,7 +518,8 @@ static void emitAttributeSerialization(const Attribute &attr,
os << tabs
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
- attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
+ attr.getAttrDefName() == "SPV_MemorySemanticsAttr" ||
+ attr.getAttrDefName() == "SPV_MatrixLayoutAttr") {
-    // These two enums are encoded as <id> to constant values in SPIR-V blob,
+    // These enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.
@@ -810,7 +811,8 @@ static void emitAttributeDeserialization(const Attribute &attr,
StringRef words, StringRef wordIndex,
raw_ostream &os) {
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
- attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
+ attr.getAttrDefName() == "SPV_MemorySemanticsAttr" ||
+ attr.getAttrDefName() == "SPV_MatrixLayoutAttr") {
-    // These two enums are encoded as <id> to constant values in SPIR-V blob,
+    // These enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.
More information about the Mlir-commits
mailing list