[Mlir-commits] [mlir] 4ba61f5 - [mlir][spirv] Add KHR Cooperative Matrix type and extension
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jul 12 18:13:08 PDT 2023
Author: Jakub Kuderski
Date: 2023-07-12T21:11:08-04:00
New Revision: 4ba61f5a30d2f24670c7d35ce0ad3f2572808240
URL: https://github.com/llvm/llvm-project/commit/4ba61f5a30d2f24670c7d35ce0ad3f2572808240
DIFF: https://github.com/llvm/llvm-project/commit/4ba61f5a30d2f24670c7d35ce0ad3f2572808240.diff
LOG: [mlir][spirv] Add KHR Cooperative Matrix type and extension
Start plumbing through support for the `SPV_KHR_cooperative_matrix`
extension: https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_cooperative_matrix.html.
Register the extension, add new coop matrix type, and add
`spirv.KHR.CooperativeMatrixLength` op to exercise it.
Make sure that mixing the KHR and NV coop matrix types is not allowed
(e.g., casts between them are rejected). Make cast verification more robust.
Reviewed By: antiagainst, qedawkins
Differential Revision: https://reviews.llvm.org/D154877
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
mlir/test/Dialect/SPIRV/IR/types.mlir
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2b33327dec1fbc..885fb7138b1410 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -343,6 +343,7 @@ def SPV_KHR_uniform_group_instructions : I32EnumAttrCase<"SPV_KHR_uniform_
def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup_rotate", 28>;
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
+def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
@@ -435,6 +436,7 @@ def SPIRV_ExtensionAttr :
SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
+ SPV_KHR_cooperative_matrix,
SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
@@ -835,6 +837,12 @@ def SPIRV_C_RayCullMaskKHR : I32EnumAttrCase<"RayCu
Extension<[SPV_KHR_ray_cull_mask]>
];
}
+def SPIRV_C_CooperativeMatrixKHR : I32EnumAttrCase<"CooperativeMatrixKHR", 6022> {
+ list<Availability> availability = [
+ Extension<[SPV_KHR_cooperative_matrix]>,
+ MinVersion<SPIRV_V_1_6>
+ ];
+}
def SPIRV_C_BitInstructions : I32EnumAttrCase<"BitInstructions", 6025> {
list<Availability> availability = [
Extension<[SPV_KHR_bit_instructions]>
@@ -1457,6 +1465,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
+ SPIRV_C_CooperativeMatrixKHR,
SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
@@ -4017,7 +4026,7 @@ def SPIRV_ArrayedAttr : SPIRV_I32EnumAttr<
def SPIRV_ISI_SingleSampled : I32EnumAttrCase<"SingleSampled", 0>;
def SPIRV_ISI_MultiSampled : I32EnumAttrCase<"MultiSampled", 1>;
-def SPIRV_SamplingAttr: SPIRV_I32EnumAttr<
+def SPIRV_SamplingAttr : SPIRV_I32EnumAttr<
"ImageSamplingInfo", "valid SPIR-V Image Sampling specification",
"image_sampling_info", [SPIRV_ISI_SingleSampled, SPIRV_ISI_MultiSampled]>;
@@ -4035,11 +4044,23 @@ def SPIRV_ML_RowMajor : I32EnumAttrCase<"RowMajor", 1>;
def SPIRV_ML_PackedA : I32EnumAttrCase<"PackedA", 2>;
def SPIRV_ML_PackedB : I32EnumAttrCase<"PackedB", 3>;
-def SPIRV_MatrixLayoutAttr :
+def SPIRV_MatrixLayoutAttr :
SPIRV_I32EnumAttr<"MatrixLayout", "valid SPIR-V MatrixLayout", "matrixLayout", [
SPIRV_ML_ColumnMajor, SPIRV_ML_RowMajor, SPIRV_ML_PackedA, SPIRV_ML_PackedB
]>;
+// Cooperative Matrix Use for the SPV_KHR_cooperative_matrix extension.
+def SPIRV_KHR_CMU_MatrixA : I32EnumAttrCase<"MatrixA", 0>;
+def SPIRV_KHR_CMU_MatrixB : I32EnumAttrCase<"MatrixB", 1>;
+def SPIRV_KHR_CMU_MatrixAcc : I32EnumAttrCase<"MatrixAcc", 2>;
+
+def SPIRV_KHR_CooperativeMatrixUseAttr :
+ SPIRV_I32EnumAttr<"CooperativeMatrixUseKHR",
+ "valid SPIR-V Cooperative Matrix Use (KHR)",
+ "coop_matrix_use_khr", [
+ SPIRV_KHR_CMU_MatrixA, SPIRV_KHR_CMU_MatrixB, SPIRV_KHR_CMU_MatrixAcc
+ ]>;
+
//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
@@ -4069,6 +4090,8 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
!interleave(widths, "/") # "-bit signless/unsigned integer">;
def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
+def SPIRV_IsCooperativeMatrixType :
+ CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
def SPIRV_IsCooperativeMatrixNVType :
CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">;
def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
@@ -4100,6 +4123,9 @@ def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
"any SPIR-V pointer type">;
def SPIRV_AnyArray : DialectType<SPIRV_Dialect, SPIRV_IsArrayType,
"any SPIR-V array type">;
+def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
+ SPIRV_IsCooperativeMatrixType,
+ "any SPIR-V cooperative matrix type">;
def SPIRV_AnyCooperativeMatrixNV : DialectType<SPIRV_Dialect,
SPIRV_IsCooperativeMatrixNVType,
"any SPIR-V NV cooperative matrix type">;
@@ -4121,17 +4147,23 @@ def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
+ SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrixNV, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
- SPIRV_AnySampledImage
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
+ SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
def SPIRV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
+class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
+ ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixType,
+ "::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()",
+ "Cooperative Matrix">;
+
class SPIRV_CoopMatrixNVOfType<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixNVType,
"::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()",
@@ -4147,10 +4179,12 @@ class SPIRV_ScalarOrVectorOf<Type type> :
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
- SPIRV_CoopMatrixNVOfType<[type]>]>;
+ SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
class SPIRV_MatrixOrCoopMatrixOf<Type type> :
- AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixNVOfType<[type]>]>;
+ AnyTypeOf<[SPIRV_AnyMatrix,
+ SPIRV_CoopMatrixOfType<[type]>,
+ SPIRV_CoopMatrixNVOfType<[type]>]>;
def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
@@ -4400,6 +4434,8 @@ def SPIRV_OC_OpSUDot : I32EnumAttrCase<"OpSUDot", 4452>;
def SPIRV_OC_OpSDotAccSat : I32EnumAttrCase<"OpSDotAccSat", 4453>;
def SPIRV_OC_OpUDotAccSat : I32EnumAttrCase<"OpUDotAccSat", 4454>;
def SPIRV_OC_OpSUDotAccSat : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
+def SPIRV_OC_OpTypeCooperativeMatrixKHR : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
+def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
def SPIRV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
@@ -4498,7 +4534,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax,
SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
- SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixNV,
+ SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+ SPIRV_OC_OpTypeCooperativeMatrixNV,
SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV,
SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index fada7f64899ad0..e6e2bbb26d094e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -7,12 +7,62 @@
//===----------------------------------------------------------------------===//
//
// This is the op definition spec of cooperative matrix multiply extension ops.
+// We support both cooperative matrix extensions:
+// - SPV_NV_cooperative_matrix
+// - SPV_KHR_cooperative_matrix
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
#define MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
+//===----------------------------------------------------------------------===//
+// SPV_KHR_cooperative_matrix extension ops.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+def SPIRV_KHRCooperativeMatrixLengthOp :
+ SPIRV_KhrVendorOp<"CooperativeMatrixLength", [Pure]> {
+ let summary = "Queries the number of cooperative matrix components";
+
+ let description = [{
+ Number of components of a cooperative matrix type accessible to each
+ invocation when treated as a composite.
+
+ The type attribute must be a cooperative matrix type.
+
+ ``` {.ebnf}
+ cooperative-matrix-length-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixLength
+ ` : ` cooperative-matrix-type
+ ```
+
+ #### Example:
+
+ ```
+ %0 = spirv.KHR.CooperativeMatrixLength :
+ !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+ ```
+ }];
+
+ let assemblyFormat = "attr-dict `:` $cooperative_matrix_type";
+
+ let availability = [
+ MinVersion<SPIRV_V_1_6>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_KHR_cooperative_matrix]>,
+ Capability<[SPIRV_C_CooperativeMatrixKHR]>
+ ];
+
+ let arguments = (ins
+ TypeAttr:$cooperative_matrix_type
+ );
+
+ let results = (outs
+ SPIRV_Int32:$result
+ );
+}
+
//===----------------------------------------------------------------------===//
// SPV_NV_cooperative_matrix extension ops.
//===----------------------------------------------------------------------===//
@@ -59,7 +109,6 @@ def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLengt
let results = (outs
SPIRV_Int32:$result
);
- let hasVerifier = 0;
}
// -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index b5b1f5ad4f52f1..07f2f158ecabb6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -20,6 +20,7 @@
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
+#include <cstdint>
#include <tuple>
namespace mlir {
@@ -27,6 +28,7 @@ namespace spirv {
namespace detail {
struct ArrayTypeStorage;
+struct CooperativeMatrixTypeStorage;
struct CooperativeMatrixNVTypeStorage;
struct ImageTypeStorage;
struct JointMatrixTypeStorage;
@@ -398,6 +400,33 @@ class StructType
llvm::hash_code
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
+// SPIR-V KHR cooperative matrix type
+class CooperativeMatrixType
+ : public Type::TypeBase<CooperativeMatrixType, CompositeType,
+ detail::CooperativeMatrixTypeStorage> {
+public:
+ using Base::Base;
+
+ static CooperativeMatrixType get(Type elementType, uint32_t rows,
+ uint32_t columns, Scope scope,
+ CooperativeMatrixUseKHR use);
+ Type getElementType() const;
+
+ /// Returns the scope of the matrix.
+ Scope getScope() const;
+ /// Returns the number of rows of the matrix.
+ uint32_t getRows() const;
+ /// Returns the number of columns of the matrix.
+ uint32_t getColumns() const;
+ /// Returns the use parameter of the cooperative matrix.
+ CooperativeMatrixUseKHR getUse() const;
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage = std::nullopt);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ std::optional<StorageClass> storage = std::nullopt);
+};
+
// SPIR-V NV cooperative matrix type
class CooperativeMatrixNVType
: public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a0bf1300b183a3..2e1c7923e24126 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -319,9 +319,45 @@ static Type parseArrayType(SPIRVDialect const &dialect,
}
// cooperative-matrix-type ::=
-// `!spirv.NV.coopmatrix` `<` rows `x` columns `x` element-type ',' scope>
+// `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
+// scope `,` use `>`
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
+ if (parser.parseLess())
+ return {};
+
+ SmallVector<int64_t, 2> dims;
+ SMLoc countLoc = parser.getCurrentLocation();
+ if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
+ return {};
+
+ if (dims.size() != 2) {
+ parser.emitError(countLoc, "expected row and column count");
+ return {};
+ }
+
+ auto elementTy = parseAndVerifyType(dialect, parser);
+ if (!elementTy)
+ return {};
+
+ Scope scope;
+ if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+ return {};
+
+ CooperativeMatrixUseKHR use;
+ if (parser.parseComma() || parseEnumKeywordAttr(use, parser, "use <id>"))
+ return {};
+
+ if (parser.parseGreater())
+ return {};
+
+ return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
+}
+
+// nv-cooperative-matrix-type ::=
+// `!spirv.NV.coopmatrix` `<` rows `x` columns `x` element-type `,` scope `>`
+static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,
+ DialectAsmParser &parser) {
if (parser.parseLess())
return Type();
@@ -785,8 +821,10 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
if (keyword == "array")
return parseArrayType(*this, parser);
- if (keyword == "NV.coopmatrix")
+ if (keyword == "coopmatrix")
return parseCooperativeMatrixType(*this, parser);
+ if (keyword == "NV.coopmatrix")
+ return parseCooperativeMatrixNVType(*this, parser);
if (keyword == "jointmatrix")
return parseJointMatrixType(*this, parser);
if (keyword == "image")
@@ -889,6 +927,12 @@ static void print(StructType type, DialectAsmPrinter &os) {
structContext.remove(type.getIdentifier());
}
+static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
+ os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
+ << type.getElementType() << ", " << type.getScope() << ", "
+ << type.getUse() << ">";
+}
+
static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
os << "NV.coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
os << type.getElementType() << ", " << stringifyScope(type.getScope());
@@ -909,9 +953,10 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
- .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
- PointerType, RuntimeArrayType, ImageType, SampledImageType,
- StructType, MatrixType>([&](auto type) { print(type, os); })
+ .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
+ JointMatrixINTELType, PointerType, RuntimeArrayType, ImageType,
+ SampledImageType, StructType, MatrixType>(
+ [&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 6747f75f468a08..5c11fe8d0cffa6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -32,6 +32,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include <cassert>
#include <numeric>
@@ -429,36 +430,38 @@ static LogicalResult verifyCastOp(Operation *op,
Type operandType = op->getOperand(0).getType();
Type resultType = op->getResult(0).getType();
- // ODS checks that result type and operand type have the same shape.
- if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
- operandType = vectorType.getElementType();
- resultType = llvm::cast<VectorType>(resultType).getElementType();
- }
-
- if (auto coopMatrixType =
- llvm::dyn_cast<spirv::CooperativeMatrixNVType>(operandType)) {
- operandType = coopMatrixType.getElementType();
- resultType =
- llvm::cast<spirv::CooperativeMatrixNVType>(resultType).getElementType();
- }
-
- if (auto jointMatrixType =
- llvm::dyn_cast<spirv::JointMatrixINTELType>(operandType)) {
- operandType = jointMatrixType.getElementType();
- resultType =
- llvm::cast<spirv::JointMatrixINTELType>(resultType).getElementType();
- }
-
- auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
- auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
- auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
+ // ODS checks that result type and operand type have the same shape. Check
+ // that composite types match and extract the element types, if any.
+ using TypePair = std::pair<Type, Type>;
+ auto [operandElemTy, resultElemTy] =
+ TypeSwitch<Type, TypePair>(operandType)
+ .Case<VectorType, spirv::CooperativeMatrixType,
+ spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType>(
+ [resultType](auto concreteOperandTy) -> TypePair {
+ if (auto concreteResultTy =
+ dyn_cast<decltype(concreteOperandTy)>(resultType)) {
+ return {concreteOperandTy.getElementType(),
+ concreteResultTy.getElementType()};
+ }
+ return {};
+ })
+ .Default([resultType](Type operandType) -> TypePair {
+ return {operandType, resultType};
+ });
+
+ if (!operandElemTy || !resultElemTy)
+ return op->emitOpError("incompatible operand and result types");
+
+ unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
+ unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
+ bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
if (requireSameBitWidth) {
if (!isSameBitWidth) {
return op->emitOpError(
"expected the same bit widths for operand type and result "
"type, but provided ")
- << operandType << " and " << resultType;
+ << operandElemTy << " and " << resultElemTy;
}
return success();
}
@@ -467,7 +470,7 @@ static LogicalResult verifyCastOp(Operation *op,
return op->emitOpError(
"expected the different bit widths for operand type and result "
"type, but provided ")
- << operandType << " and " << resultType;
+ << operandElemTy << " and " << resultElemTy;
}
return success();
}
@@ -4018,6 +4021,34 @@ LogicalResult spirv::VectorShuffleOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLength
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::KHRCooperativeMatrixLengthOp::verify() {
+ if (!isa<spirv::CooperativeMatrixType>(getCooperativeMatrixType())) {
+ return emitOpError(
+ "type attribute must be a '!spirv.coopmatrix' type, found ")
+ << getCooperativeMatrixType() << " instead";
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NV.CooperativeMatrixLength
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::NVCooperativeMatrixLengthOp::verify() {
+ if (!isa<spirv::CooperativeMatrixNVType>(getCooperativeMatrixType())) {
+ return emitOpError(
+ "type attribute must be a '!spirv.NV.coopmatrix' type, found ")
+ << getCooperativeMatrixType() << " instead";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.NV.CooperativeMatrixLoad
//===----------------------------------------------------------------------===//
@@ -4053,8 +4084,8 @@ void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
printer << " : " << getPointer().getType() << " as " << getType();
}
-static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
- Type coopMatrix) {
+static LogicalResult
+verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) {
Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
!llvm::isa<VectorType>(pointeeType))
@@ -4074,8 +4105,8 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
}
LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
- return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
- getResult().getType());
+ return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
+ getResult().getType());
}
//===----------------------------------------------------------------------===//
@@ -4114,8 +4145,8 @@ void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
- return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
- getObject().getType());
+ return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
+ getObject().getType());
}
//===----------------------------------------------------------------------===//
@@ -4123,7 +4154,7 @@ LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
//===----------------------------------------------------------------------===//
static LogicalResult
-verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
+verifyCoopMatrixMulAddNV(spirv::NVCooperativeMatrixMulAddOp op) {
if (op.getC().getType() != op.getResult().getType())
return op.emitOpError("result and third operand must have the same type");
auto typeA = llvm::cast<spirv::CooperativeMatrixNVType>(op.getA().getType());
@@ -4156,9 +4187,13 @@ verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
}
LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
- return verifyCoopMatrixMulAdd(*this);
+ return verifyCoopMatrixMulAddNV(*this);
}
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.JointMatrixLoad
+//===----------------------------------------------------------------------===//
+
static LogicalResult
verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
@@ -4179,10 +4214,6 @@ verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixLoad
-//===----------------------------------------------------------------------===//
-
LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
getResult().getType());
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 1599c5bb74ae06..01c694de08a9e8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <cstdint>
#include <iterator>
using namespace mlir;
@@ -93,9 +94,10 @@ std::optional<int64_t> ArrayType::getSizeInBytes() {
bool CompositeType::classof(Type type) {
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return isValid(vectorType);
- return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixNVType,
- spirv::JointMatrixINTELType, spirv::MatrixType,
- spirv::RuntimeArrayType, spirv::StructType>(type);
+ return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
+ spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType,
+ spirv::MatrixType, spirv::RuntimeArrayType,
+ spirv::StructType>(type);
}
bool CompositeType::isValid(VectorType type) {
@@ -114,8 +116,8 @@ bool CompositeType::isValid(VectorType type) {
Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
- .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
- RuntimeArrayType, VectorType>(
+ .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
+ JointMatrixINTELType, RuntimeArrayType, VectorType>(
[](auto type) { return type.getElementType(); })
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
@@ -133,9 +135,9 @@ unsigned CompositeType::getNumElements() const {
return structType.getNumElements();
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
- if (llvm::isa<CooperativeMatrixNVType>(*this)) {
+ if (llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType>(*this)) {
llvm_unreachable(
- "invalid to query number of elements of spirv::CooperativeMatrix type");
+ "invalid to query number of elements of spirv Cooperative Matrix type");
}
if (llvm::isa<JointMatrixINTELType>(*this)) {
llvm_unreachable(
@@ -149,16 +151,16 @@ unsigned CompositeType::getNumElements() const {
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
- return !llvm::isa<CooperativeMatrixNVType, JointMatrixINTELType,
- RuntimeArrayType>(*this);
+ return !llvm::isa<CooperativeMatrixType, CooperativeMatrixNVType,
+ JointMatrixINTELType, RuntimeArrayType>(*this);
}
void CompositeType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
- MatrixType, RuntimeArrayType, StructType>(
+ .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
+ JointMatrixINTELType, MatrixType, RuntimeArrayType, StructType>(
[&](auto type) { type.getExtensions(extensions, storage); })
.Case<VectorType>([&](VectorType type) {
return llvm::cast<ScalarType>(type.getElementType())
@@ -171,8 +173,8 @@ void CompositeType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType,
- MatrixType, RuntimeArrayType, StructType>(
+ .Case<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType,
+ JointMatrixINTELType, MatrixType, RuntimeArrayType, StructType>(
[&](auto type) { type.getCapabilities(capabilities, storage); })
.Case<VectorType>([&](VectorType type) {
auto vecSize = getNumElements();
@@ -202,6 +204,77 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
return std::nullopt;
}
+//===----------------------------------------------------------------------===//
+// CooperativeMatrixType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
+ using KeyTy =
+ std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
+
+ static CooperativeMatrixTypeStorage *
+ construct(TypeStorageAllocator &allocator, const KeyTy &key) {
+ return new (allocator.allocate<CooperativeMatrixTypeStorage>())
+ CooperativeMatrixTypeStorage(key);
+ }
+
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(elementType, rows, columns, scope, use);
+ }
+
+ CooperativeMatrixTypeStorage(const KeyTy &key)
+ : elementType(std::get<0>(key)), rows(std::get<1>(key)),
+ columns(std::get<2>(key)), scope(std::get<3>(key)),
+ use(std::get<4>(key)) {}
+
+ Type elementType;
+ uint32_t rows;
+ uint32_t columns;
+ Scope scope;
+ CooperativeMatrixUseKHR use;
+};
+
+CooperativeMatrixType CooperativeMatrixType::get(Type elementType,
+ uint32_t rows,
+ uint32_t columns, Scope scope,
+ CooperativeMatrixUseKHR use) {
+ return Base::get(elementType.getContext(), elementType, rows, columns, scope,
+ use);
+}
+
+Type CooperativeMatrixType::getElementType() const {
+ return getImpl()->elementType;
+}
+
+uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
+
+uint32_t CooperativeMatrixType::getColumns() const {
+ return getImpl()->columns;
+}
+
+Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
+
+CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
+ return getImpl()->use;
+}
+
+void CooperativeMatrixType::getExtensions(
+ SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage) {
+ llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+ static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
+ extensions.push_back(exts);
+}
+
+void CooperativeMatrixType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities,
+ std::optional<StorageClass> storage) {
+ llvm::cast<SPIRVType>(getElementType())
+ .getCapabilities(capabilities, storage);
+ static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
+ capabilities.push_back(caps);
+}
+
//===----------------------------------------------------------------------===//
// CooperativeMatrixNVType
//===----------------------------------------------------------------------===//
@@ -1247,7 +1320,7 @@ void MatrixType::getCapabilities(
//===----------------------------------------------------------------------===//
void SPIRVDialect::registerTypes() {
- addTypes<ArrayType, CooperativeMatrixNVType, ImageType, JointMatrixINTELType,
- MatrixType, PointerType, RuntimeArrayType, SampledImageType,
- StructType>();
+ addTypes<ArrayType, CooperativeMatrixType, CooperativeMatrixNVType, ImageType,
+ JointMatrixINTELType, MatrixType, PointerType, RuntimeArrayType,
+ SampledImageType, StructType>();
}
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 1835c6ae1d5f80..4f4a72da7c050a 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -138,6 +138,14 @@ func.func @convert_f_to_u_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {
// -----
+// ConvertFToU on a KHR cooperative matrix: only the element type changes; the
+// 8x16 shape, Subgroup scope, and MatrixB use are carried through unchanged.
+func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>) {
+ // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> to !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+ %0 = spirv.ConvertFToU %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> to !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+ spirv.Return
+}
+
+// -----
+
func.func @convert_f_to_u_NV.coopmatrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
// CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
%0 = spirv.ConvertFToU %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xi32, Subgroup>
@@ -222,7 +230,15 @@ func.func @f_convert_vector(%arg0 : vector<3xf32>) -> vector<3xf64> {
// -----
-func.func @f_convert_coop_matrix(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
+// FConvert widens the element type of a KHR cooperative matrix (f32 -> f64)
+// while keeping its shape, scope, and MatrixA use.
+func.func @f_convert_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>) {
+ // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xf64, Subgroup, MatrixA>
+ %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xf64, Subgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
+func.func @f_convert_coop_matrix_nv(%arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) {
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
%0 = spirv.FConvert %arg0 : !spirv.NV.coopmatrix<8x16xf32, Subgroup> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
spirv.Return
@@ -238,6 +254,14 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {
// -----
+// Converting between a KHR and an NV cooperative matrix type in one cast is
+// rejected: the two extensions' matrix types must not be mixed.
+func.func @f_convert_coop_matrix_to_nv_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>) {
+ // expected-error @+1 {{incompatible operand and result types}}
+ %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc> to !spirv.NV.coopmatrix<8x16xf64, Subgroup>
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.SConvert
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 2e387403964612..65311e1db9bb49 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -1,4 +1,29 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect --split-input-file --verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// CooperativeMatrix (KHR)
+//===----------------------------------------------------------------------===//
+
+// spirv.KHR.CooperativeMatrixLength takes the matrix type as a type attribute
+// and yields an i32 result.
+// CHECK-LABEL: @cooperative_matrix_length
+spirv.func @cooperative_matrix_length() -> i32 "None" {
+ // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+ spirv.ReturnValue %0 : i32
+}
+
+// -----
+
+// The KHR length op accepts only !spirv.coopmatrix, not the NV matrix type.
+spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
+ // expected-error @+1 {{'spirv.KHR.CooperativeMatrixLength' op type attribute must be a '!spirv.coopmatrix'}}
+ %0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.ReturnValue %0 : i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// NV.CooperativeMatrix
+//===----------------------------------------------------------------------===//
// CHECK-LABEL: @cooperative_matrix_load
spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
@@ -7,7 +32,6 @@ spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stri
spirv.Return
}
-// -----
// CHECK-LABEL: @cooperative_matrix_load_memaccess
spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
// CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
@@ -164,3 +188,11 @@ spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<i32, Function>,
%0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<i32, Function> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.Return
}
+
+// -----
+
+// The NV length op accepts only !spirv.NV.coopmatrix, not the KHR matrix type.
+// Named with an `nv_` prefix so it is distinguishable from the KHR test of the
+// same shape earlier in this file.
+spirv.func @nv_cooperative_matrix_length_wrong_matrix() -> i32 "None" {
+ // expected-error @+1 {{'spirv.NV.CooperativeMatrixLength' op type attribute must be a '!spirv.NV.coopmatrix'}}
+ %0 = spirv.NV.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+ spirv.ReturnValue %0 : i32
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 06f0ccfd0a3774..e10a6fc77e8566 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -436,11 +436,55 @@ func.func private @id_struct_recursive(!spirv.struct<a10, (!spirv.ptr<!spirv.str
// -----
//===----------------------------------------------------------------------===//
-// CooperativeMatrix
+// CooperativeMatrix (KHR)
//===----------------------------------------------------------------------===//
-// CHECK: func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>)
-func.func private @coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> ()
+// A KHR cooperative matrix type is spelled
+// !spirv.coopmatrix<RowsxColumnsxElementType, Scope, Use>.
+// CHECK-LABEL: func private @coop_matrix_types
+// CHECK-SAME: !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+// CHECK-SAME: !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>
+// CHECK-SAME: !spirv.coopmatrix<4x8xf32, Workgroup, MatrixAcc>
+func.func private @coop_matrix_types(!spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
+ !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>,
+ !spirv.coopmatrix<4x8xf32, Workgroup, MatrixAcc>) -> ()
+
+// -----
+
+// An empty scope position is rejected.
+// expected-error @+1 {{expected valid keyword}}
+func.func private @missing_scope(!spirv.coopmatrix<8x8xi32, >) -> ()
+
+// -----
+
+// The use is mandatory; omitting it (and its comma) is rejected.
+// expected-error @+1 {{expected ','}}
+func.func private @missing_use(!spirv.coopmatrix<8x16xi32, Subgroup>) -> ()
+
+// -----
+
+// A trailing comma with no use keyword after it is rejected.
+// expected-error @+1 {{expected valid keyword}}
+func.func private @missing_use2(!spirv.coopmatrix<8x8xi32, Subgroup,>) -> ()
+
+// -----
+
+// Exactly two dimensions (rows and columns) are required; one is too few.
+// expected-error @+1 {{expected row and column count}}
+func.func private @missing_count(!spirv.coopmatrix<8xi32, Subgroup, MatrixA>) -> ()
+
+// -----
+
+// Exactly two dimensions (rows and columns) are required; three is too many.
+// expected-error @+1 {{expected row and column count}}
+func.func private @too_many_dims(!spirv.coopmatrix<8x16x32xi32, Subgroup, MatrixB>) -> ()
+
+// -----
+
+// The last parameter must be a matrix-use keyword (e.g. MatrixA, MatrixB,
+// MatrixAcc); a scope keyword such as Subgroup is not a valid use.
+// expected-error @+1 {{invalid use <id> attribute specification: Subgroup}}
+func.func private @use_not_integer(!spirv.coopmatrix<8x8xi32, Subgroup, Subgroup>) -> ()
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// NV.CooperativeMatrix
+//===----------------------------------------------------------------------===//
+
+// The NV cooperative matrix type carries only a shape, element type, and scope;
+// unlike the KHR type it has no use parameter.
// CHECK: func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>)
func.func private @nv_coop_matrix_type(!spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xf32, Workgroup>) -> ()
// -----
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 7bf7755a8a52f9..ccf4240f8e5608 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -857,7 +857,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
attrList, attrName, words, wordIndex);
- } else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
+ } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
os << tabs
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"TypeAttr::get(getType({2}[{3}++]))));\n",
@@ -866,7 +866,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
PrintFatalError(
loc, llvm::Twine(
"unhandled attribute type in deserialization generation : '") +
- attr.getAttrDefName() + llvm::Twine("'"));
+ attrName + llvm::Twine("'"));
}
}
More information about the Mlir-commits
mailing list