[Mlir-commits] [mlir] 915e55c - [mlir][spirv] Add support for matrix type
Lei Zhang
llvmlistbot at llvm.org
Tue Jun 2 13:31:13 PDT 2020
Author: HazemAbdelhafez
Date: 2020-06-02T16:30:58-04:00
New Revision: 915e55c9107807cbad9c4085347f027a8ddbc5c1
URL: https://github.com/llvm/llvm-project/commit/915e55c9107807cbad9c4085347f027a8ddbc5c1
DIFF: https://github.com/llvm/llvm-project/commit/915e55c9107807cbad9c4085347f027a8ddbc5c1.diff
LOG: [mlir][spirv] Add support for matrix type
This commit adds basic matrix type support to the SPIR-V dialect
including type definition, IR assembly, parsing, printing, and
(de)serialization.
Differential Revision: https://reviews.llvm.org/D80594
Added:
mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/types.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index ead6c0341cd6..a95ed18ca4e7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3109,6 +3109,7 @@ def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>;
def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>;
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
+def SPV_OC_OpTypeMatrix : I32EnumAttrCase<"OpTypeMatrix", 24>;
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>;
@@ -3250,15 +3251,15 @@ def SPV_OpcodeAttr :
SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
- SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
- SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
- SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse,
- SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull,
- SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
- SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
- SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
- SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
- SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
+ SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix,
+ SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
+ SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
+ SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
+ SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
+ SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
+ SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
+ SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
+ SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 71eba72e5e84..b7180399a837 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -13,6 +13,8 @@
#ifndef MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
#define MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
@@ -56,9 +58,11 @@ namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
+struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct StructTypeStorage;
+
} // namespace detail
namespace TypeKind {
@@ -66,6 +70,7 @@ enum Kind {
Array = Type::FIRST_SPIRV_TYPE,
CooperativeMatrix,
Image,
+ Matrix,
Pointer,
RuntimeArray,
Struct,
@@ -366,6 +371,36 @@ class CooperativeMatrixNVType
Optional<spirv::StorageClass> storage = llvm::None);
};
+/// SPIR-V matrix type (OpTypeMatrix): a composite of 2-4 column vectors.
+class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
+ detail::MatrixTypeStorage> {
+public:
+ using Base::Base;
+
+ static bool kindof(unsigned kind) { return kind == TypeKind::Matrix; }
+ /// Returns the matrix type with `columnCount` columns of type `columnType`.
+ static MatrixType get(Type columnType, uint32_t columnCount);
+ /// Like get(), but diagnoses invalid parameters at `location`.
+ static MatrixType getChecked(Type columnType, uint32_t columnCount,
+ Location location);
+ /// Checks for 2-4 columns, each a 1-D vector of 2-4 float elements.
+ static LogicalResult verifyConstructionInvariants(Location loc,
+ Type columnType,
+ uint32_t columnCount);
+
+ /// Returns true if `columnType` is a vector of floats (shape not checked).
+ static bool isValidColumnType(Type columnType);
+ /// Returns the column type, i.e. the type of each matrix element.
+ Type getElementType() const;
+ /// Returns the number of columns.
+ unsigned getNumElements() const;
+ /// Extensions/capabilities needed by the matrix and its column type.
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<spirv::StorageClass> storage = llvm::None);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<spirv::StorageClass> storage = llvm::None);
+};
+
} // end namespace spirv
} // end namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 8c4d0ebe99a7..455064f58ce6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -116,8 +116,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
SPIRVDialect::SPIRVDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
- addTypes<ArrayType, CooperativeMatrixNVType, ImageType, PointerType,
- RuntimeArrayType, StructType>();
+ addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
+ PointerType, RuntimeArrayType, StructType>();
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
@@ -197,6 +197,51 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
return type;
}
+/// Parses the column type of a `!spv.matrix` and verifies that it is a 1-D
+/// vector of 2 to 4 float elements. On failure, emits an error at the column
+/// type's location and returns a null Type.
+static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
+ DialectAsmParser &parser) {
+ Type type;
+ llvm::SMLoc typeLoc = parser.getCurrentLocation();
+ if (parser.parseType(type))
+ return Type();
+
+ // Bail out early on non-vector types before inspecting the shape.
+ auto columnVecType = type.dyn_cast<VectorType>();
+ if (!columnVecType) {
+ parser.emitError(typeLoc, "matrix must be composed using vector "
+ "type, got ")
+ << type;
+ return Type();
+ }
+
+ if (columnVecType.getRank() != 1) {
+ parser.emitError(typeLoc, "only 1-D vector allowed but found ")
+ << columnVecType;
+ return Type();
+ }
+
+ int64_t numRows = columnVecType.getNumElements();
+ if (numRows < 2 || numRows > 4) {
+ parser.emitError(typeLoc,
+ "matrix columns size has to be less than or equal "
+ "to 4 and greater than or equal 2, but found ")
+ << numRows;
+ return Type();
+ }
+
+ Type rowElementType = columnVecType.getElementType();
+ if (!rowElementType.isa<FloatType>()) {
+ parser.emitError(typeLoc, "matrix columns' elements must be of "
+ "Float type, got ")
+ << rowElementType;
+ return Type();
+ }
+
+ return type;
+}
+
/// Parses an optional `, stride = N` assembly segment. If no parsing failure
/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
/// missing.
@@ -279,7 +315,7 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return Type();
if (dims.size() != 2) {
- parser.emitError(countLoc, "expected rows and columns size.");
+ parser.emitError(countLoc, "expected rows and columns size");
return Type();
}
@@ -350,6 +386,40 @@ static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
return RuntimeArrayType::get(elementType, stride);
}
+// matrix-type ::= `!spv.matrix` `<` integer-literal `x` column-type `>`
+static Type parseMatrixType(SPIRVDialect const &dialect,
+ DialectAsmParser &parser) {
+ if (parser.parseLess())
+ return Type();
+ // The column count is parsed as a single-entry dimension list (`N x`).
+ SmallVector<int64_t, 1> countDims;
+ llvm::SMLoc countLoc = parser.getCurrentLocation();
+ if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
+ return Type();
+ if (countDims.size() != 1) {
+ parser.emitError(countLoc, "expected single unsigned "
+ "integer for number of columns");
+ return Type();
+ }
+
+ int64_t columnCount = countDims[0];
+ // Only matrices with 2, 3, or 4 columns are accepted.
+ if (columnCount < 2 || columnCount > 4) {
+ parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
+ "columns");
+ return Type();
+ }
+ // The column type must be a 1-D vector of 2-4 float elements.
+ Type columnType = parseAndVerifyMatrixType(dialect, parser);
+ if (!columnType)
+ return Type();
+
+ if (parser.parseGreater())
+ return Type();
+
+ return MatrixType::get(columnType, columnCount);
+}
+
// Specialize this function to parse each of the parameters that define an
// ImageType. By default it assumes this is an enum type.
template <typename ValTy>
@@ -567,7 +637,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseRuntimeArrayType(*this, parser);
if (keyword == "struct")
return parseStructType(*this, parser);
-
+ if (keyword == "matrix")
+ return parseMatrixType(*this, parser);
parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
return Type();
}
@@ -635,6 +706,12 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
os << ">";
}
+/// Prints a MatrixType as `matrix<N x column-type>`.
+static void print(MatrixType type, DialectAsmPrinter &os) {
+ os << "matrix<" << type.getNumElements() << " x " << type.getElementType()
+ << ">";
+}
+
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
case TypeKind::Array:
@@ -655,6 +731,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
case TypeKind::Struct:
print(type.cast<StructType>(), os);
return;
+ case TypeKind::Matrix:
+ print(type.cast<MatrixType>(), os);
+ return;
default:
llvm_unreachable("unhandled SPIR-V type");
}
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 49b39ec78435..4ba17f3a1240 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -159,6 +159,7 @@ bool CompositeType::classof(Type type) {
switch (type.getKind()) {
case TypeKind::Array:
case TypeKind::CooperativeMatrix:
+ case TypeKind::Matrix:
case TypeKind::RuntimeArray:
case TypeKind::Struct:
return true;
@@ -180,6 +181,8 @@ Type CompositeType::getElementType(unsigned index) const {
return cast<ArrayType>().getElementType();
case spirv::TypeKind::CooperativeMatrix:
return cast<CooperativeMatrixNVType>().getElementType();
+ case spirv::TypeKind::Matrix:
+ return cast<MatrixType>().getElementType();
case spirv::TypeKind::RuntimeArray:
return cast<RuntimeArrayType>().getElementType();
case spirv::TypeKind::Struct:
@@ -198,6 +201,8 @@ unsigned CompositeType::getNumElements() const {
case spirv::TypeKind::CooperativeMatrix:
llvm_unreachable(
"invalid to query number of elements of spirv::CooperativeMatrix type");
+ case spirv::TypeKind::Matrix:
+ return cast<MatrixType>().getNumElements();
case spirv::TypeKind::RuntimeArray:
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
@@ -230,6 +235,9 @@ void CompositeType::getExtensions(
case spirv::TypeKind::CooperativeMatrix:
cast<CooperativeMatrixNVType>().getExtensions(extensions, storage);
break;
+ case spirv::TypeKind::Matrix:
+ cast<MatrixType>().getExtensions(extensions, storage);
+ break;
case spirv::TypeKind::RuntimeArray:
cast<RuntimeArrayType>().getExtensions(extensions, storage);
break;
@@ -255,6 +263,9 @@ void CompositeType::getCapabilities(
case spirv::TypeKind::CooperativeMatrix:
cast<CooperativeMatrixNVType>().getCapabilities(capabilities, storage);
break;
+ case spirv::TypeKind::Matrix:
+ cast<MatrixType>().getCapabilities(capabilities, storage);
+ break;
case spirv::TypeKind::RuntimeArray:
cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
break;
@@ -823,10 +834,11 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
scalarType.getExtensions(extensions, storage);
} else if (auto compositeType = dyn_cast<CompositeType>()) {
+ // MatrixType is a CompositeType, so it is dispatched through this branch.
compositeType.getExtensions(extensions, storage);
- } else if (auto ptrType = dyn_cast<PointerType>()) {
- ptrType.getExtensions(extensions, storage);
} else if (auto imageType = dyn_cast<ImageType>()) {
imageType.getExtensions(extensions, storage);
+ } else if (auto ptrType = dyn_cast<PointerType>()) {
+ ptrType.getExtensions(extensions, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getExtensions");
}
@@ -839,10 +852,11 @@ void SPIRVType::getCapabilities(
scalarType.getCapabilities(capabilities, storage);
} else if (auto compositeType = dyn_cast<CompositeType>()) {
+ // MatrixType is a CompositeType, so it is dispatched through this branch.
compositeType.getCapabilities(capabilities, storage);
- } else if (auto ptrType = dyn_cast<PointerType>()) {
- ptrType.getCapabilities(capabilities, storage);
} else if (auto imageType = dyn_cast<ImageType>()) {
imageType.getCapabilities(capabilities, storage);
+ } else if (auto ptrType = dyn_cast<PointerType>()) {
+ ptrType.getCapabilities(capabilities, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
}
@@ -1000,3 +1015,89 @@ void StructType::getCapabilities(
for (Type elementType : getElementTypes())
elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
}
+
+//===----------------------------------------------------------------------===//
+// MatrixType
+//===----------------------------------------------------------------------===//
+/// Storage for MatrixType, keyed on (column type, column count).
+struct spirv::detail::MatrixTypeStorage : public TypeStorage {
+ MatrixTypeStorage(Type columnType, uint32_t columnCount)
+ : TypeStorage(), columnType(columnType), columnCount(columnCount) {}
+
+ using KeyTy = std::tuple<Type, uint32_t>;
+
+ static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+
+ // Initialize the memory using placement new.
+ return new (allocator.allocate<MatrixTypeStorage>())
+ MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
+ }
+
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(columnType, columnCount);
+ }
+
+ Type columnType;
+ const uint32_t columnCount;
+};
+/// Gets or creates the MatrixType with the given column type and count.
+MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
+ return Base::get(columnType.getContext(), TypeKind::Matrix, columnType,
+ columnCount);
+}
+/// Like get(), but diagnoses invalid parameters at `location`.
+MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
+ Location location) {
+ return Base::getChecked(location, TypeKind::Matrix, columnType, columnCount);
+}
+/// Requires 2-4 columns, each a 1-D vector of 2-4 float elements.
+LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
+ Type columnType,
+ uint32_t columnCount) {
+ if (columnCount < 2 || columnCount > 4)
+ return emitError(loc, "matrix can have 2, 3, or 4 columns only");
+
+ if (!isValidColumnType(columnType))
+ return emitError(loc, "matrix columns must be vectors of floats");
+
+ // Columns are known to be vectors here; they must be 1-D with 2-4 elements.
+ ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
+ if (columnShape.size() != 1)
+ return emitError(loc, "matrix columns must be 1D vectors");
+
+ if (columnShape[0] < 2 || columnShape[0] > 4)
+ return emitError(loc, "matrix columns must be of size 2, 3, or 4");
+
+ return success();
+}
+
+/// Returns true if `columnType` is a vector of floats (shape not checked).
+bool MatrixType::isValidColumnType(Type columnType) {
+ if (auto vectorType = columnType.dyn_cast<VectorType>()) {
+ if (vectorType.getElementType().isa<FloatType>())
+ return true;
+ }
+ return false;
+}
+/// Returns the column type, i.e. the type of each matrix element.
+Type MatrixType::getElementType() const { return getImpl()->columnType; }
+/// Returns the number of columns.
+unsigned MatrixType::getNumElements() const { return getImpl()->columnCount; }
+/// A matrix's extensions are those reported by its column type.
+void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<StorageClass> storage) {
+ getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+}
+/// A matrix needs the Matrix capability plus those of its column type.
+void MatrixType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<StorageClass> storage) {
+ {
+ static const Capability caps[] = {Capability::Matrix};
+ ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
+ capabilities.push_back(ref);
+ }
+ // Add any capabilities associated with the underlying vectors (i.e., columns)
+ getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+}
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 87f233580b75..750dddfa6dc4 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -225,6 +225,8 @@ class Deserializer {
LogicalResult processStructType(ArrayRef<uint32_t> operands);
+ LogicalResult processMatrixType(ArrayRef<uint32_t> operands);
+
//===--------------------------------------------------------------------===//
// Constant
//===--------------------------------------------------------------------===//
@@ -1170,6 +1172,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
return processRuntimeArrayType(operands);
case spirv::Opcode::OpTypeStruct:
return processStructType(operands);
+ case spirv::Opcode::OpTypeMatrix:
+ return processMatrixType(operands);
default:
return emitError(unknownLoc, "unhandled type instruction");
}
@@ -1333,6 +1337,35 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
return success();
}
+/// Processes an OpTypeMatrix instruction whose operands are
+/// <result-id> <column-type-id> <column-count>.
+LogicalResult Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 3) {
+ // Three operands are needed: result_id, column_type, and column_count
+ return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
+ " (result_id, column_type, and column_count)");
+ }
+ // Resolve the column type; whether it is a valid column (a float vector) is
+ // checked when the matrix type is constructed below.
+ Type elementTy = getType(operands[1]);
+ if (!elementTy) {
+ return emitError(unknownLoc,
+ "OpTypeMatrix references undefined column type <id> ")
+ << operands[1];
+ }
+
+ uint32_t colsCount = operands[2];
+ // Column type and count come straight from the binary, so build the type via
+ // getChecked (checked against MatrixType::verifyConstructionInvariants at
+ // unknownLoc) instead of get(), and reject the module on a null result.
+ auto matrixTy =
+ spirv::MatrixType::getChecked(elementTy, colsCount, unknownLoc);
+ if (!matrixTy)
+ return failure();
+ typeMap[operands[0]] = matrixTy;
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
@@ -2238,6 +2261,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
case spirv::Opcode::OpTypeInt:
case spirv::Opcode::OpTypeFloat:
case spirv::Opcode::OpTypeVector:
+ case spirv::Opcode::OpTypeMatrix:
case spirv::Opcode::OpTypeArray:
case spirv::Opcode::OpTypeFunction:
case spirv::Opcode::OpTypeRuntimeArray:
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 8ea0c4f4711b..0b1c970589b1 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1111,6 +1111,17 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
return success();
}
+ if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
+ uint32_t elementTypeID = 0;
+ if (failed(processType(loc, matrixType.getElementType(), elementTypeID))) {
+ return failure();
+ }
+ typeEnum = spirv::Opcode::OpTypeMatrix;
+ operands.push_back(elementTypeID);
+ operands.push_back(matrixType.getNumElements());
+ return success();
+ }
+
// TODO(ravishankarm) : Handle other types.
return emitError(loc, "unhandled type in serialization: ") << type;
}
diff --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
new file mode 100644
index 000000000000..b27702bf50d8
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+ spv.func @matrix_type(%arg0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>, %arg1 : i32) "None" {
+ // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
+ %2 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
+ spv.Return
+ }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+ // CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
+ spv.globalVariable @var0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
+
+ // CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<2 x vector<3xf32>>, StorageBuffer>
+ spv.globalVariable @var1 : !spv.ptr<!spv.matrix<2 x vector<3xf32>>, StorageBuffer>
+
+ // CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<4 x vector<4xf16>>, StorageBuffer>
+ spv.globalVariable @var2 : !spv.ptr<!spv.matrix<4 x vector<4xf16>>, StorageBuffer>
+}
diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir
index 697177b0b98e..1d1a1868ea3c 100644
--- a/mlir/test/Dialect/SPIRV/types.mlir
+++ b/mlir/test/Dialect/SPIRV/types.mlir
@@ -347,3 +347,87 @@ func @missing_scope(!spv.coopmatrix<8x16xi32>) -> ()
// expected-error @+1 {{expected rows and columns size}}
func @missing_count(!spv.coopmatrix<8xi32, Subgroup>) -> ()
+// -----
+
+//===----------------------------------------------------------------------===//
+// Matrix
+//===----------------------------------------------------------------------===//
+// CHECK: func @matrix_type(!spv.matrix<2 x vector<2xf16>>)
+func @matrix_type(!spv.matrix<2 x vector<2xf16>>) -> ()
+
+// -----
+
+// CHECK: func @matrix_type(!spv.matrix<3 x vector<3xf32>>)
+func @matrix_type(!spv.matrix<3 x vector<3xf32>>) -> ()
+
+// -----
+
+// CHECK: func @matrix_type(!spv.matrix<4 x vector<4xf16>>)
+func @matrix_type(!spv.matrix<4 x vector<4xf16>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
+func @matrix_invalid_size(!spv.matrix<5 x vector<3xf32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
+func @matrix_invalid_size(!spv.matrix<1 x vector<3xf32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix columns size has to be less than or equal to 4 and greater than or equal 2, but found 5}}
+func @matrix_invalid_columns_size(!spv.matrix<3 x vector<5xf32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix columns size has to be less than or equal to 4 and greater than or equal 2, but found 1}}
+func @matrix_invalid_columns_size(!spv.matrix<3 x vector<1xf32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected '<'}}
+func @matrix_invalid_format(!spv.matrix 3 x vector<3xf32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{unbalanced ')' character in pretty dialect name}}
+func @matrix_invalid_format(!spv.matrix< 3 x vector<3xf32>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected 'x' in dimension list}}
+func @matrix_invalid_format(!spv.matrix<2 vector<3xi32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix must be composed using vector type, got 'i32'}}
+func @matrix_invalid_type(!spv.matrix< 3 x i32>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix must be composed using vector type, got '!spv.array<16 x f32>'}}
+func @matrix_invalid_type(!spv.matrix< 3 x !spv.array<16 x f32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix must be composed using vector type, got '!spv.rtarray<i32>'}}
+func @matrix_invalid_type(!spv.matrix< 3 x !spv.rtarray<i32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix columns' elements must be of Float type, got 'i32'}}
+func @matrix_invalid_type(!spv.matrix<2 x vector<3xi32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected single unsigned integer for number of columns}}
+func @matrix_size_type(!spv.matrix< x vector<3xi32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected single unsigned integer for number of columns}}
+func @matrix_size_type(!spv.matrix<2.0 x vector<3xi32>>) -> ()
+
+// -----
\ No newline at end of file
More information about the Mlir-commits
mailing list