[Mlir-commits] [mlir] b359bba - [mlir][spirv] First step to support spirv cooperative matrix extension.
Thomas Raoux
llvmlistbot at llvm.org
Tue May 19 19:30:05 PDT 2020
Author: Thomas Raoux
Date: 2020-05-19T19:29:41-07:00
New Revision: b359bbaa8b41a84ae54369e3017ce1a5c7afe1a1
URL: https://github.com/llvm/llvm-project/commit/b359bbaa8b41a84ae54369e3017ce1a5c7afe1a1
DIFF: https://github.com/llvm/llvm-project/commit/b359bbaa8b41a84ae54369e3017ce1a5c7afe1a1.diff
LOG: [mlir][spirv] First step to support spirv cooperative matrix extension.
Add a new type to the SPIR-V dialect for cooperative matrices, and a new op for
cooperative matrix loads. Most instructions needed to support the cooperative
matrix extension are still missing; this is a stop-gap patch to avoid creating
a big review.
Differential Revision: https://reviews.llvm.org/D80043
Added:
mlir/include/mlir/Dialect/SPIRV/ParserUtils.h
mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/types.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/ParserUtils.h b/mlir/include/mlir/Dialect/SPIRV/ParserUtils.h
new file mode 100644
index 000000000000..f368aec45efb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/ParserUtils.h
@@ -0,0 +1,41 @@
+//===------------ ParserUtils.h - Parse text to SPIR-V ops ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines utilities used for parsing types and ops for SPIR-V
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_SPIRV_PARSERUTILS_H_
+#define MLIR_DIALECT_SPIRV_PARSERUTILS_H_
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+/// Parses the next keyword in `parser` as an enumerant of the given
+/// `EnumClass` and stores it in `value`. On an unknown keyword, emits an
+/// "invalid <attrName> attribute specification" error and returns failure.
+template <typename EnumClass, typename ParserType>
+inline ParseResult
+parseEnumKeywordAttr(EnumClass &value, ParserType &parser,
+                     StringRef attrName = spirv::attributeName<EnumClass>()) {
+  StringRef keyword;
+  auto loc = parser.getCurrentLocation();
+  if (parser.parseKeyword(&keyword))
+    return failure();
+  if (Optional<EnumClass> parsed = spirv::symbolizeEnum<EnumClass>(keyword)) {
+    value = parsed.getValue();
+    return success();
+  }
+  return parser.emitError(loc, "invalid ")
+         << attrName << " attribute specification: " << keyword;
+}
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_PARSERUTILS_H_
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 64063cb77d01..b958a10c5952 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -2991,8 +2991,10 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, IOrUI<w>),
StrJoinInt<widths, "/">.result # "-bit signless/unsigned integer">;
-def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
+// Predicate: `$_self` is a SPIR-V cooperative matrix (CooperativeMatrixNVType).
+def SPV_IsCooperativeMatrixType :
+ CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">;
+def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
@@ -3012,6 +3014,9 @@ def SPV_AnyPtr : DialectType<SPIRV_Dialect, SPV_IsPtrType,
"any SPIR-V pointer type">;
def SPV_AnyArray : DialectType<SPIRV_Dialect, SPV_IsArrayType,
"any SPIR-V array type">;
+// Type constraint accepting any SPIR-V cooperative matrix type; the predicate
+// only checks the type kind, not element type, scope, or shape.
+def SPV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
+ SPV_IsCooperativeMatrixType,
+ "any SPIR-V cooperative matrix type">;
def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
"any SPIR-V runtime array type">;
def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType,
@@ -3220,6 +3225,8 @@ def SPV_OC_OpGroupNonUniformSMax : I32EnumAttrCase<"OpGroupNonUniformSMax"
def SPV_OC_OpGroupNonUniformUMax : I32EnumAttrCase<"OpGroupNonUniformUMax", 357>;
def SPV_OC_OpGroupNonUniformFMax : I32EnumAttrCase<"OpGroupNonUniformFMax", 358>;
def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
+def SPV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
+def SPV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
def SPV_OpcodeAttr :
SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -3271,7 +3278,8 @@ def SPV_OpcodeAttr :
SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
- SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR
+ SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
+ SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
new file mode 100644
index 000000000000..931f56f58755
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
@@ -0,0 +1,94 @@
+//===- SPIRVCooperativeMatrixOps.td - cooperative matmul ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of cooperative matrix multiply extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_COOPERATIVE_MATRIX_OPS
+#define SPIRV_COOPERATIVE_MATRIX_OPS
+
+// -----
+
+def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
+  let summary = "See extension SPV_NV_cooperative_matrix";
+
+  let description = [{
+    Load a cooperative matrix through a pointer.
+
+    Result Type is the type of the loaded object. It must be a cooperative
+    matrix type.
+
+    Pointer is a pointer into an array. Its type must be an OpTypePointer whose
+    Type operand is a scalar or vector type. The storage class of Pointer must
+    be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
+    supported) PhysicalStorageBufferEXT.
+
+    Stride is the number of elements in the array in memory between the first
+    component of consecutive rows (or columns) in the result. It must be a
+    scalar integer type.
+
+    ColumnMajor indicates whether the values loaded from memory are arranged in
+    column-major or row-major order. It must be a boolean constant instruction,
+    with false indicating row major and true indicating column major.
+
+    Memory Access must be a Memory Access literal. If not present, it is the
+    same as specifying None.
+
+    If ColumnMajor is false, then elements (row,*) of the result are taken in
+    order from contiguous locations starting at Pointer[row*Stride]. If
+    ColumnMajor is true, then elements (*,col) of the result are taken in order
+    from contiguous locations starting from Pointer[col*Stride]. Any ArrayStride
+    decoration on Pointer is ignored.
+
+    For a given dynamic instance of this instruction, all operands of this
+    instruction must be the same for all invocations in a given scope instance
+    (where the scope is the scope the cooperative matrix type was created with).
+    All invocations in a given scope instance must be active or all must be
+    inactive.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    cooperative-matrixload-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV`
+                                  storage-class ssa-use `,` ssa-use `,` ssa-use
+                                  (`[` memory-access `]`)? `:` cooperative-matrix-type
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %colMajor
+         : !spv.coopmatrix<16x8xi32, Workgroup>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_NV_cooperative_matrix]>,
+    Capability<[SPV_C_CooperativeMatrixNV]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyPtr:$pointer,
+    SPV_Integer:$stride,
+    SPV_Bool:$columnmajor,
+    OptionalAttr<SPV_MemoryAccessAttr>:$memory_access
+  );
+
+  let results = (outs
+    SPV_AnyCooperativeMatrix:$result
+  );
+
+  let verifier = [{ return success(); }];
+}
+
+// -----
+
+#endif // SPIRV_COOPERATIVE_MATRIX_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index 518dca69873d..520ed14c9624 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -28,6 +28,7 @@ include "mlir/Dialect/SPIRV/SPIRVBitOps.td"
include "mlir/Dialect/SPIRV/SPIRVCastOps.td"
include "mlir/Dialect/SPIRV/SPIRVCompositeOps.td"
include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td"
+include "mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td"
include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 3b5a82d239b9..078fb5a67225 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -54,6 +54,7 @@ SmallVector<Capability, 0> getRecursiveImpliedCapabilities(Capability cap);
namespace detail {
struct ArrayTypeStorage;
+struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
@@ -63,6 +64,7 @@ struct StructTypeStorage;
namespace TypeKind {
enum Kind {
Array = Type::FIRST_SPIRV_TYPE,
+ CooperativeMatrix,
Image,
Pointer,
RuntimeArray,
@@ -330,6 +332,34 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
Optional<spirv::StorageClass> storage = llvm::None);
};
+/// SPIR-V cooperative matrix type (SPV_NV_cooperative_matrix extension),
+/// parameterized by element type, scope, number of rows and number of columns.
+class CooperativeMatrixNVType
+ : public Type::TypeBase<CooperativeMatrixNVType, SPIRVType,
+ detail::CooperativeMatrixTypeStorage> {
+public:
+ using Base::Base;
+
+ static bool kindof(unsigned kind) {
+ return kind == TypeKind::CooperativeMatrix;
+ }
+
+ /// Get (or create) the cooperative matrix type with the given parameters in
+ /// the context of `elementType`.
+ static CooperativeMatrixNVType get(Type elementType, spirv::Scope scope,
+ unsigned rows, unsigned columns);
+ /// Return the element type of the matrix.
+ Type getElementType() const;
+
+ /// Return the scope of the cooperative matrix.
+ spirv::Scope getScope() const;
+ /// Return the number of rows of the matrix.
+ unsigned getRows() const;
+ /// Return the number of columns of the matrix.
+ unsigned getColumns() const;
+
+ /// Append the extensions required by this type, including those required by
+ /// its element type, to `extensions`.
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<spirv::StorageClass> storage = llvm::None);
+ /// Append the capabilities required by this type, including those required
+ /// by its element type, to `capabilities`.
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<spirv::StorageClass> storage = llvm::None);
+};
+
} // end namespace spirv
} // end namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index c74698a93bfc..8c4d0ebe99a7 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/ParserUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
@@ -115,7 +116,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
SPIRVDialect::SPIRVDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
- addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
+ addTypes<ArrayType, CooperativeMatrixNVType, ImageType, PointerType,
+ RuntimeArrayType, StructType>();
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
@@ -264,6 +266,36 @@ static Type parseArrayType(SPIRVDialect const &dialect,
return ArrayType::get(elementType, count, stride);
}
+// cooperative-matrix-type ::=
+//     `!spv.coopmatrix` `<` rows `x` columns `x` element-type `,` scope `>`
+static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
+                                       DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return Type();
+
+  // The shape is written as a static `rows x columns x` dimension prefix
+  // directly in front of the element type.
+  SmallVector<int64_t, 2> dims;
+  llvm::SMLoc dimsLoc = parser.getCurrentLocation();
+  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
+    return Type();
+
+  if (dims.size() != 2) {
+    parser.emitError(dimsLoc, "expected rows and columns size");
+    return Type();
+  }
+
+  auto elementTy = parseAndVerifyType(dialect, parser);
+  if (!elementTy)
+    return Type();
+
+  Scope scope;
+  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+    return Type();
+
+  if (parser.parseGreater())
+    return Type();
+  return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
+}
+
// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
@@ -525,6 +557,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
if (keyword == "array")
return parseArrayType(*this, parser);
+ if (keyword == "coopmatrix")
+ return parseCooperativeMatrixType(*this, parser);
if (keyword == "image")
return parseImageType(*this, parser);
if (keyword == "ptr")
@@ -595,11 +629,20 @@ static void print(StructType type, DialectAsmPrinter &os) {
os << ">";
}
+/// Prints a cooperative matrix type as `coopmatrix<RxCxelement-type, scope>`.
+static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
+  const unsigned rows = type.getRows();
+  const unsigned columns = type.getColumns();
+  // Shape prefix first, then the element type and the scope keyword.
+  os << "coopmatrix<" << rows << "x" << columns << "x" << type.getElementType()
+     << ", " << stringifyScope(type.getScope()) << ">";
+}
+
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
case TypeKind::Array:
print(type.cast<ArrayType>(), os);
return;
+ case TypeKind::CooperativeMatrix:
+ print(type.cast<CooperativeMatrixNVType>(), os);
+ return;
case TypeKind::Pointer:
print(type.cast<PointerType>(), os);
return;
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index e7bdfe902804..eed597b1d21c 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/ParserUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
@@ -140,25 +141,6 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
return success();
}
-/// Parses the next keyword in `parser` as an enumerant of the given
-/// `EnumClass`.
-template <typename EnumClass>
-static ParseResult
-parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
- StringRef attrName = spirv::attributeName<EnumClass>()) {
- StringRef keyword;
- SmallVector<NamedAttribute, 1> attr;
- auto loc = parser.getCurrentLocation();
- if (parser.parseKeyword(&keyword))
- return failure();
- if (Optional<EnumClass> attr = spirv::symbolizeEnum<EnumClass>(keyword)) {
- value = attr.getValue();
- return success();
- }
- return parser.emitError(loc, "invalid ")
- << attrName << " attribute specification: " << keyword;
-}
-
/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
/// and inserts the enumerant into `state` as an 32-bit integer attribute with
/// the enum class's name as attribute name.
@@ -2637,6 +2619,49 @@ static LogicalResult verify(spirv::VariableOp varOp) {
return success();
}
+//===----------------------------------------------------------------------===//
+// spv.CooperativeMatrixLoadNV
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom form of spv.CooperativeMatrixLoadNV:
+///   "storage-class" %pointer, %stride, %columnmajor (["memory-access"])?
+///     : !spv.coopmatrix<...>
+/// The pointer type is rebuilt from the storage class and the matrix element
+/// type. The custom form always resolves the stride as i32 and the
+/// column-major flag as i1.
+static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
+                                                  OperationState &state) {
+  spirv::StorageClass storageClass;
+  SmallVector<OpAsmParser::OperandType, 3> operandInfo;
+  Type strideType = parser.getBuilder().getIntegerType(32);
+  Type columnMajorType = parser.getBuilder().getIntegerType(1);
+  if (parseEnumStrAttr(storageClass, parser) ||
+      parser.parseOperandList(operandInfo, 3) ||
+      parseMemoryAccessAttributes(parser, state) || parser.parseColon())
+    return failure();
+
+  Type resultType;
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (parser.parseType(resultType))
+    return failure();
+
+  // The result type is user-supplied: diagnose anything that is not a
+  // cooperative matrix instead of asserting in an unchecked cast.
+  auto matrixType = resultType.dyn_cast<spirv::CooperativeMatrixNVType>();
+  if (!matrixType)
+    return parser.emitError(typeLoc,
+                            "expected cooperative matrix type, but found ")
+           << resultType;
+
+  auto ptrType =
+      spirv::PointerType::get(matrixType.getElementType(), storageClass);
+  SmallVector<Type, 3> operandTypes = {ptrType, strideType, columnMajorType};
+  if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
+                             state.operands))
+    return failure();
+
+  state.addTypes(resultType);
+  return success();
+}
+
+/// Prints spv.CooperativeMatrixLoadNV in its custom form: the pointer's storage
+/// class as a quoted string, the three operands, the optional memory access
+/// attribute, and finally the result type.
+static void print(spirv::CooperativeMatrixLoadNVOp op, OpAsmPrinter &printer) {
+  auto pointerType = op.pointer().getType().cast<spirv::PointerType>();
+  StringRef storageClass = stringifyStorageClass(pointerType.getStorageClass());
+
+  printer << spirv::CooperativeMatrixLoadNVOp::getOperationName();
+  printer << " \"" << storageClass << "\" ";
+  printer << op.pointer() << ", " << op.stride() << ", " << op.columnmajor();
+
+  // The memory access attribute is optional; omit the brackets when absent.
+  if (auto memAccess = op.memory_access())
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
+
+  printer << " : " << op.getType();
+}
+
namespace mlir {
namespace spirv {
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 71ca0c3d2bc7..ce5a6c0c4fd9 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -158,6 +158,7 @@ void ArrayType::getCapabilities(
bool CompositeType::classof(Type type) {
switch (type.getKind()) {
case TypeKind::Array:
+ case TypeKind::CooperativeMatrix:
case TypeKind::RuntimeArray:
case TypeKind::Struct:
return true;
@@ -177,6 +178,8 @@ Type CompositeType::getElementType(unsigned index) const {
switch (getKind()) {
case spirv::TypeKind::Array:
return cast<ArrayType>().getElementType();
+ case spirv::TypeKind::CooperativeMatrix:
+ return cast<CooperativeMatrixNVType>().getElementType();
case spirv::TypeKind::RuntimeArray:
return cast<RuntimeArrayType>().getElementType();
case spirv::TypeKind::Struct:
@@ -192,6 +195,9 @@ unsigned CompositeType::getNumElements() const {
switch (getKind()) {
case spirv::TypeKind::Array:
return cast<ArrayType>().getNumElements();
+ case spirv::TypeKind::CooperativeMatrix:
+ return cast<CooperativeMatrixNVType>().getRows() *
+ cast<CooperativeMatrixNVType>().getColumns();
case spirv::TypeKind::RuntimeArray:
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
@@ -211,6 +217,9 @@ void CompositeType::getExtensions(
case spirv::TypeKind::Array:
cast<ArrayType>().getExtensions(extensions, storage);
break;
+ case spirv::TypeKind::CooperativeMatrix:
+ cast<CooperativeMatrixNVType>().getExtensions(extensions, storage);
+ break;
case spirv::TypeKind::RuntimeArray:
cast<RuntimeArrayType>().getExtensions(extensions, storage);
break;
@@ -233,6 +242,9 @@ void CompositeType::getCapabilities(
case spirv::TypeKind::Array:
cast<ArrayType>().getCapabilities(capabilities, storage);
break;
+ case spirv::TypeKind::CooperativeMatrix:
+ cast<CooperativeMatrixNVType>().getCapabilities(capabilities, storage);
+ break;
case spirv::TypeKind::RuntimeArray:
cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
break;
@@ -248,6 +260,70 @@ void CompositeType::getCapabilities(
}
}
+//===----------------------------------------------------------------------===//
+// CooperativeMatrixType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
+  /// Uniquing key: (element type, scope, rows, columns).
+  using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
+
+  /// The scope is stored in the TypeStorage subclass data rather than in a
+  /// dedicated field.
+  CooperativeMatrixTypeStorage(const KeyTy &key)
+      : TypeStorage(static_cast<unsigned>(std::get<1>(key))),
+        elementType(std::get<0>(key)), rows(std::get<2>(key)),
+        columns(std::get<3>(key)) {}
+
+  static CooperativeMatrixTypeStorage *
+  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
+    auto *memory = allocator.allocate<CooperativeMatrixTypeStorage>();
+    return new (memory) CooperativeMatrixTypeStorage(key);
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return std::get<0>(key) == elementType && std::get<1>(key) == getScope() &&
+           std::get<2>(key) == rows && std::get<3>(key) == columns;
+  }
+
+  Scope getScope() const { return static_cast<Scope>(getSubclassData()); }
+
+  Type elementType;
+  unsigned rows;
+  unsigned columns;
+};
+
+/// Returns the uniqued cooperative matrix type, created in the context of
+/// `elementType`.
+CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
+                                                     Scope scope, unsigned rows,
+                                                     unsigned columns) {
+  auto *context = elementType.getContext();
+  return Base::get(context, TypeKind::CooperativeMatrix, elementType, scope,
+                   rows, columns);
+}
+
+/// Element type of each matrix component.
+Type CooperativeMatrixNVType::getElementType() const {
+  return getImpl()->elementType;
+}
+
+/// Scope the matrix is shared across.
+Scope CooperativeMatrixNVType::getScope() const {
+  return getImpl()->getScope();
+}
+
+/// Number of rows.
+unsigned CooperativeMatrixNVType::getRows() const {
+  return getImpl()->rows;
+}
+
+/// Number of columns.
+unsigned CooperativeMatrixNVType::getColumns() const {
+  return getImpl()->columns;
+}
+
+/// Appends the extensions required by this type: those of the element type,
+/// followed by SPV_NV_cooperative_matrix.
+void CooperativeMatrixNVType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    Optional<StorageClass> storage) {
+  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+  // The vector holds non-owning ArrayRefs, so the referenced list must outlive
+  // this call. Pushing the enumerator directly would build an ArrayRef over a
+  // temporary that dies at the end of the statement.
+  static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
+  extensions.push_back(ArrayRef<Extension>(exts));
+}
+
+/// Appends the capabilities required by this type: those of the element type,
+/// followed by CooperativeMatrixNV.
+void CooperativeMatrixNVType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+  // Same lifetime requirement as in getExtensions: use static storage.
+  static const Capability caps[] = {Capability::CooperativeMatrixNV};
+  capabilities.push_back(ArrayRef<Capability>(caps));
+}
+
//===----------------------------------------------------------------------===//
// ImageType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index a45780ba63a0..87f233580b75 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -217,6 +217,8 @@ class Deserializer {
LogicalResult processArrayType(ArrayRef<uint32_t> operands);
+ LogicalResult processCooperativeMatrixType(ArrayRef<uint32_t> operands);
+
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
@@ -1160,6 +1162,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
} break;
case spirv::Opcode::OpTypeArray:
return processArrayType(operands);
+ case spirv::Opcode::OpTypeCooperativeMatrixNV:
+ return processCooperativeMatrixType(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
case spirv::Opcode::OpTypeRuntimeArray:
@@ -1229,6 +1233,35 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
return success();
}
+/// Deserializes OpTypeCooperativeMatrixNV. Expects the operands
+/// (result <id>, element type <id>, scope, rows, columns). Scope, rows and
+/// columns are read as literal words, mirroring what the Serializer emits.
+// NOTE(review): the SPV_NV_cooperative_matrix spec appears to encode Scope,
+// Rows and Columns as <id>s of constant instructions rather than literals;
+// confirm before consuming binaries produced by other tools.
+LogicalResult
+Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 5) {
+    return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
+                                 "type, scope, and row x column parameters");
+  }
+
+  Type elementTy = getType(operands[1]);
+  if (!elementTy) {
+    return emitError(unknownLoc,
+                     "OpTypeCooperativeMatrix references undefined <id> ")
+           << operands[1];
+  }
+
+  // operands[2] is decoded as a Scope enumerant value, not an <id>.
+  auto scope = spirv::symbolizeScope(operands[2]);
+  if (!scope) {
+    return emitError(unknownLoc, "OpTypeCooperativeMatrix has invalid scope ")
+           << operands[2];
+  }
+
+  unsigned rows = operands[3];
+  unsigned columns = operands[4];
+
+  typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
+      elementTy, scope.getValue(), rows, columns);
+  return success();
+}
+
LogicalResult
Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
@@ -2210,6 +2243,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
case spirv::Opcode::OpTypeRuntimeArray:
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
+ case spirv::Opcode::OpTypeCooperativeMatrixNV:
return processType(opcode, operands);
case spirv::Opcode::OpConstant:
return processConstant(operands, /*isSpec=*/false);
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 2b500ddbf985..8ea0c4f4711b 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1096,6 +1096,21 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
return success();
}
+ if (auto cooperativeMatrixType =
+ type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
+ uint32_t elementTypeID = 0;
+ if (failed(processType(loc, cooperativeMatrixType.getElementType(),
+ elementTypeID))) {
+ return failure();
+ }
+ typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
+ operands.push_back(elementTypeID);
+ operands.push_back(static_cast<uint32_t>(cooperativeMatrixType.getScope()));
+ operands.push_back(cooperativeMatrixType.getRows());
+ operands.push_back(cooperativeMatrixType.getColumns());
+ return success();
+ }
+
// TODO(ravishankarm) : Handle other types.
return emitError(loc, "unhandled type in serialization: ") << type;
}
diff --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
new file mode 100644
index 000000000000..e90996ee24b7
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
+ // CHECK-LABEL: @cooperative_matrix_load
+ spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
+ %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+ spv.Return
+ }
+
+ // CHECK-LABEL: @cooperative_matrix_load_memaccess
+ spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.Return
+ }
+}
diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
new file mode 100644
index 000000000000..c121943acf82
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @cooperative_matrix_load
+spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
+ %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+ spv.Return
+}
+
+// -----
+// CHECK-LABEL: @cooperative_matrix_load_memaccess
+spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.Return
+}
diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir
index 4c1adafce4a8..697177b0b98e 100644
--- a/mlir/test/Dialect/SPIRV/types.mlir
+++ b/mlir/test/Dialect/SPIRV/types.mlir
@@ -327,3 +327,23 @@ func @struct_type_missing_comma(!spv.struct<f32 [0 NonWritable], i32 [4]>)
// expected-error @+1 {{expected ']'}}
func @struct_type_missing_comma(!spv.struct<f32 [0, NonWritable NonReadable], i32 [4]>)
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// CooperativeMatrix
+//===----------------------------------------------------------------------===//
+
+// CHECK: func @coop_matrix_type(!spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xf32, Workgroup>)
+func @coop_matrix_type(!spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xf32, Workgroup>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected ','}}
+func @missing_scope(!spv.coopmatrix<8x16xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected rows and columns size}}
+func @missing_count(!spv.coopmatrix<8xi32, Subgroup>) -> ()
+
More information about the Mlir-commits
mailing list