[Mlir-commits] [mlir] 915e55c - [mlir][spirv] Add support for matrix type

Lei Zhang llvmlistbot at llvm.org
Tue Jun 2 13:31:13 PDT 2020


Author: HazemAbdelhafez
Date: 2020-06-02T16:30:58-04:00
New Revision: 915e55c9107807cbad9c4085347f027a8ddbc5c1

URL: https://github.com/llvm/llvm-project/commit/915e55c9107807cbad9c4085347f027a8ddbc5c1
DIFF: https://github.com/llvm/llvm-project/commit/915e55c9107807cbad9c4085347f027a8ddbc5c1.diff

LOG: [mlir][spirv] Add support for matrix type

This commit adds basic matrix type support to the SPIR-V dialect,
including the type definition, its IR assembly form (parsing and
printing), and (de)serialization.

Differential Revision: https://reviews.llvm.org/D80594

Added: 
    mlir/test/Dialect/SPIRV/Serialization/matrix.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
    mlir/test/Dialect/SPIRV/types.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index ead6c0341cd6..a95ed18ca4e7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3109,6 +3109,7 @@ def SPV_OC_OpTypeBool                  : I32EnumAttrCase<"OpTypeBool", 20>;
 def SPV_OC_OpTypeInt                   : I32EnumAttrCase<"OpTypeInt", 21>;
 def SPV_OC_OpTypeFloat                 : I32EnumAttrCase<"OpTypeFloat", 22>;
 def SPV_OC_OpTypeVector                : I32EnumAttrCase<"OpTypeVector", 23>;
+def SPV_OC_OpTypeMatrix                : I32EnumAttrCase<"OpTypeMatrix", 24>;
 def SPV_OC_OpTypeArray                 : I32EnumAttrCase<"OpTypeArray", 28>;
 def SPV_OC_OpTypeRuntimeArray          : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
 def SPV_OC_OpTypeStruct                : I32EnumAttrCase<"OpTypeStruct", 30>;
@@ -3250,15 +3251,15 @@ def SPV_OpcodeAttr :
       SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
       SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
       SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
-      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
-      SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
-      SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse,
-      SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull,
-      SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
-      SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
-      SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
-      SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
-      SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
+      SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix,
+      SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
+      SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
+      SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
+      SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
+      SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
+      SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
+      SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
+      SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
       SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
       SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
       SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 71eba72e5e84..b7180399a837 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -13,6 +13,8 @@
 #ifndef MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
 #define MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
 
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeSupport.h"
 #include "mlir/IR/Types.h"
@@ -56,9 +58,11 @@ namespace detail {
 struct ArrayTypeStorage;
 struct CooperativeMatrixTypeStorage;
 struct ImageTypeStorage;
+struct MatrixTypeStorage;
 struct PointerTypeStorage;
 struct RuntimeArrayTypeStorage;
 struct StructTypeStorage;
+
 } // namespace detail
 
 namespace TypeKind {
@@ -66,6 +70,7 @@ enum Kind {
   Array = Type::FIRST_SPIRV_TYPE,
   CooperativeMatrix,
   Image,
+  Matrix,
   Pointer,
   RuntimeArray,
   Struct,
@@ -366,6 +371,36 @@ class CooperativeMatrixNVType
                        Optional<spirv::StorageClass> storage = llvm::None);
 };
 
+/// SPIR-V matrix type: 2-4 columns, each a 1-D vector of 2-4 floats.
+class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
+                                         detail::MatrixTypeStorage> {
+public:
+  using Base::Base;
+
+  static bool kindof(unsigned kind) { return kind == TypeKind::Matrix; }
+  /// Returns the matrix type with `columnCount` columns of type `columnType`.
+  static MatrixType get(Type columnType, uint32_t columnCount);
+  /// Verifying variant of get(); diagnostics are reported at `location`.
+  static MatrixType getChecked(Type columnType, uint32_t columnCount,
+                               Location location);
+  /// Requires 2-4 columns, each a 1-D vector of 2-4 float elements.
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    Type columnType,
+                                                    uint32_t columnCount);
+
+  /// Returns true if `columnType` is a vector type with float elements.
+  static bool isValidColumnType(Type columnType);
+  /// Returns the column type, i.e. the element type of this composite.
+  Type getElementType() const;
+  /// Returns the number of columns.
+  unsigned getNumElements() const;
+  /// Collect the extensions/capabilities needed by this type and its columns.
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     Optional<spirv::StorageClass> storage = llvm::None);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       Optional<spirv::StorageClass> storage = llvm::None);
+};
+
 } // end namespace spirv
 } // end namespace mlir
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 8c4d0ebe99a7..455064f58ce6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -116,8 +116,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
 
 SPIRVDialect::SPIRVDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
-  addTypes<ArrayType, CooperativeMatrixNVType, ImageType, PointerType,
-           RuntimeArrayType, StructType>();
+  addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
+           PointerType, RuntimeArrayType, StructType>();
 
   addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
 
@@ -197,6 +197,42 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
   return type;
 }
 
+/// Parses a matrix column type; it must be a 1-D vector of 2-4 floats.
+static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
+                                     DialectAsmParser &parser) {
+  llvm::SMLoc loc = parser.getCurrentLocation();
+  Type parsedType;
+  if (parser.parseType(parsedType))
+    return Type();
+  auto vecType = parsedType.dyn_cast<VectorType>();
+  if (!vecType) {
+    parser.emitError(loc, "matrix must be composed using vector "
+                          "type, got ")
+        << parsedType;
+    return Type();
+  }
+  if (vecType.getRank() != 1) {
+    parser.emitError(loc, "only 1-D vector allowed but found ") << vecType;
+    return Type();
+  }
+  int64_t numLanes = vecType.getNumElements();
+  if (numLanes < 2 || numLanes > 4) {
+    parser.emitError(loc,
+                     "matrix columns size has to be less than or equal "
+                     "to 4 and greater than or equal 2, but found ")
+        << numLanes;
+    return Type();
+  }
+  Type laneType = vecType.getElementType();
+  if (!laneType.isa<FloatType>()) {
+    parser.emitError(loc, "matrix columns' elements must be of "
+                          "Float type, got ")
+        << laneType;
+    return Type();
+  }
+  return parsedType;
+}
+
 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
 /// missing.
@@ -279,7 +315,7 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
     return Type();
 
   if (dims.size() != 2) {
-    parser.emitError(countLoc, "expected rows and columns size.");
+    parser.emitError(countLoc, "expected rows and columns size");
     return Type();
   }
 
@@ -350,6 +386,40 @@ static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
   return RuntimeArrayType::get(elementType, stride);
 }
 
+// matrix-type ::= `!spv.matrix` `<` integer-literal `x` vector-type `>`
+//
+// The integer literal is the column count; only 2, 3, or 4 is accepted.
+static Type parseMatrixType(SPIRVDialect const &dialect,
+                            DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return Type();
+
+  llvm::SMLoc dimsLoc = parser.getCurrentLocation();
+  SmallVector<int64_t, 1> dims;
+  if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
+    return Type();
+  if (dims.size() != 1) {
+    parser.emitError(dimsLoc, "expected single unsigned "
+                              "integer for number of columns");
+    return Type();
+  }
+
+  int64_t numColumns = dims.front();
+  bool validNumColumns = numColumns >= 2 && numColumns <= 4;
+  if (!validNumColumns) {
+    parser.emitError(dimsLoc, "matrix is expected to have 2, 3, or 4 "
+                              "columns");
+    return Type();
+  }
+
+  // Diagnostics for a bad column type come from parseAndVerifyMatrixType.
+  Type columnType = parseAndVerifyMatrixType(dialect, parser);
+  if (!columnType || parser.parseGreater())
+    return Type();
+
+  return MatrixType::get(columnType, static_cast<uint32_t>(numColumns));
+}
+
 // Specialize this function to parse each of the parameters that define an
 // ImageType. By default it assumes this is an enum type.
 template <typename ValTy>
@@ -567,7 +637,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
     return parseRuntimeArrayType(*this, parser);
   if (keyword == "struct")
     return parseStructType(*this, parser);
-
+  if (keyword == "matrix")
+    return parseMatrixType(*this, parser);
   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
   return Type();
 }
@@ -635,6 +706,11 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
   os << ">";
 }
 
+static void print(MatrixType type, DialectAsmPrinter &os) {
+  os << "matrix<" << type.getNumElements() << " x " // column count
+     << type.getElementType() << ">";               // column type
+}
+
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   switch (type.getKind()) {
   case TypeKind::Array:
@@ -655,6 +731,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   case TypeKind::Struct:
     print(type.cast<StructType>(), os);
     return;
+  case TypeKind::Matrix:
+    print(type.cast<MatrixType>(), os);
+    return;
   default:
     llvm_unreachable("unhandled SPIR-V type");
   }

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 49b39ec78435..4ba17f3a1240 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -159,6 +159,7 @@ bool CompositeType::classof(Type type) {
   switch (type.getKind()) {
   case TypeKind::Array:
   case TypeKind::CooperativeMatrix:
+  case TypeKind::Matrix:
   case TypeKind::RuntimeArray:
   case TypeKind::Struct:
     return true;
@@ -180,6 +181,8 @@ Type CompositeType::getElementType(unsigned index) const {
     return cast<ArrayType>().getElementType();
   case spirv::TypeKind::CooperativeMatrix:
     return cast<CooperativeMatrixNVType>().getElementType();
+  case spirv::TypeKind::Matrix:
+    return cast<MatrixType>().getElementType();
   case spirv::TypeKind::RuntimeArray:
     return cast<RuntimeArrayType>().getElementType();
   case spirv::TypeKind::Struct:
@@ -198,6 +201,8 @@ unsigned CompositeType::getNumElements() const {
   case spirv::TypeKind::CooperativeMatrix:
     llvm_unreachable(
         "invalid to query number of elements of spirv::CooperativeMatrix type");
+  case spirv::TypeKind::Matrix:
+    return cast<MatrixType>().getNumElements();
   case spirv::TypeKind::RuntimeArray:
     llvm_unreachable(
         "invalid to query number of elements of spirv::RuntimeArray type");
@@ -230,6 +235,9 @@ void CompositeType::getExtensions(
   case spirv::TypeKind::CooperativeMatrix:
     cast<CooperativeMatrixNVType>().getExtensions(extensions, storage);
     break;
+  case spirv::TypeKind::Matrix:
+    cast<MatrixType>().getExtensions(extensions, storage);
+    break;
   case spirv::TypeKind::RuntimeArray:
     cast<RuntimeArrayType>().getExtensions(extensions, storage);
     break;
@@ -255,6 +263,9 @@ void CompositeType::getCapabilities(
   case spirv::TypeKind::CooperativeMatrix:
     cast<CooperativeMatrixNVType>().getCapabilities(capabilities, storage);
     break;
+  case spirv::TypeKind::Matrix:
+    cast<MatrixType>().getCapabilities(capabilities, storage);
+    break;
   case spirv::TypeKind::RuntimeArray:
     cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
     break;
@@ -823,10 +834,12 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
     scalarType.getExtensions(extensions, storage);
   } else if (auto compositeType = dyn_cast<CompositeType>()) {
     compositeType.getExtensions(extensions, storage);
-  } else if (auto ptrType = dyn_cast<PointerType>()) {
-    ptrType.getExtensions(extensions, storage);
   } else if (auto imageType = dyn_cast<ImageType>()) {
     imageType.getExtensions(extensions, storage);
+  } else if (auto matrixType = dyn_cast<MatrixType>()) {
+    matrixType.getExtensions(extensions, storage);
+  } else if (auto ptrType = dyn_cast<PointerType>()) {
+    ptrType.getExtensions(extensions, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getExtensions");
   }
@@ -839,10 +852,12 @@ void SPIRVType::getCapabilities(
     scalarType.getCapabilities(capabilities, storage);
   } else if (auto compositeType = dyn_cast<CompositeType>()) {
     compositeType.getCapabilities(capabilities, storage);
-  } else if (auto ptrType = dyn_cast<PointerType>()) {
-    ptrType.getCapabilities(capabilities, storage);
   } else if (auto imageType = dyn_cast<ImageType>()) {
     imageType.getCapabilities(capabilities, storage);
+  } else if (auto matrixType = dyn_cast<MatrixType>()) {
+    matrixType.getCapabilities(capabilities, storage);
+  } else if (auto ptrType = dyn_cast<PointerType>()) {
+    ptrType.getCapabilities(capabilities, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getCapabilities");
   }
@@ -1000,3 +1015,89 @@ void StructType::getCapabilities(
   for (Type elementType : getElementTypes())
     elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
 }
+
+//===----------------------------------------------------------------------===//
+// MatrixType
+//===----------------------------------------------------------------------===//
+
+/// Uniqued storage for MatrixType, keyed on (column type, column count).
+struct spirv::detail::MatrixTypeStorage : public TypeStorage {
+  using KeyTy = std::tuple<Type, uint32_t>;
+
+  MatrixTypeStorage(Type colType, uint32_t colCount)
+      : TypeStorage(), columnType(colType), columnCount(colCount) {}
+
+  /// Placement-constructs a storage instance for `key` using `allocator`.
+  static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
+                                      const KeyTy &key) {
+    void *mem = allocator.allocate<MatrixTypeStorage>();
+    return new (mem) MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return std::get<0>(key) == columnType && std::get<1>(key) == columnCount;
+  }
+
+  Type columnType;
+  const uint32_t columnCount;
+};
+
+MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
+  auto *context = columnType.getContext();
+  return Base::get(context, TypeKind::Matrix, columnType, columnCount);
+}
+/// Verifying counterpart of get(); diagnostics are reported at `location`.
+MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
+                                  Location location) {
+  return Base::getChecked(location, TypeKind::Matrix, columnType, columnCount);
+}
+
+/// Requires 2-4 columns, each a 1-D vector of 2-4 float elements.
+LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
+                                                       Type columnType,
+                                                       uint32_t columnCount) {
+  bool countOk = columnCount >= 2 && columnCount <= 4;
+  if (!countOk)
+    return emitError(loc, "matrix can have 2, 3, or 4 columns only");
+  if (!isValidColumnType(columnType))
+    return emitError(loc, "matrix columns must be vectors of floats");
+
+  // isValidColumnType guarantees a VectorType, so the cast below is safe.
+  ArrayRef<int64_t> shape = columnType.cast<VectorType>().getShape();
+  if (shape.size() != 1)
+    return emitError(loc, "matrix columns must be 1D vectors");
+  int64_t width = shape.front();
+  if (width < 2 || width > 4)
+    return emitError(loc, "matrix columns must be of size 2, 3, or 4");
+  return success();
+}
+
+/// Returns true if `columnType` is a vector whose elements are floats.
+/// Only the element type is inspected here; rank and width are validated
+/// separately by verifyConstructionInvariants.
+bool MatrixType::isValidColumnType(Type columnType) {
+  // A null VectorType tests false, so the element check is skipped.
+  auto vectorType = columnType.dyn_cast<VectorType>();
+  return vectorType && vectorType.getElementType().isa<FloatType>();
+}
+
+Type MatrixType::getElementType() const { return getImpl()->columnType; }
+
+unsigned MatrixType::getNumElements() const { return getImpl()->columnCount; }
+
+/// A matrix needs exactly the extensions required by its column type.
+void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                               Optional<StorageClass> storage) {
+  getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+}
+
+/// A matrix needs the Matrix capability plus those of its column type.
+void MatrixType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  static const Capability matrixCaps[] = {Capability::Matrix};
+  capabilities.push_back(ArrayRef<Capability>(matrixCaps));
+  // Then whatever the column vector type itself requires.
+  auto columnType = getElementType().cast<SPIRVType>();
+  columnType.getCapabilities(capabilities, storage);
+}

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 87f233580b75..750dddfa6dc4 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -225,6 +225,8 @@ class Deserializer {
 
   LogicalResult processStructType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processMatrixType(ArrayRef<uint32_t> operands);
+
   //===--------------------------------------------------------------------===//
   // Constant
   //===--------------------------------------------------------------------===//
@@ -1170,6 +1172,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
     return processRuntimeArrayType(operands);
   case spirv::Opcode::OpTypeStruct:
     return processStructType(operands);
+  case spirv::Opcode::OpTypeMatrix:
+    return processMatrixType(operands);
   default:
     return emitError(unknownLoc, "unhandled type instruction");
   }
@@ -1333,6 +1337,25 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
+  // Handles `OpTypeMatrix <result_id> <column_type> <column_count>`.
+  if (operands.size() != 3)
+    return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
+                                 " (result_id, column_type, and column_count)");
+  Type elementTy = getType(operands[1]);
+  if (!elementTy)
+    return emitError(unknownLoc,
+                     "OpTypeMatrix references undefined column type: ")
+           << operands[1];
+  // getChecked rejects a malformed column type or count from the binary.
+  auto matrixTy =
+      spirv::MatrixType::getChecked(elementTy, operands[2], unknownLoc);
+  if (!matrixTy)
+    return failure();
+  typeMap[operands[0]] = matrixTy;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Constant
 //===----------------------------------------------------------------------===//
@@ -2238,6 +2261,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
   case spirv::Opcode::OpTypeInt:
   case spirv::Opcode::OpTypeFloat:
   case spirv::Opcode::OpTypeVector:
+  case spirv::Opcode::OpTypeMatrix:
   case spirv::Opcode::OpTypeArray:
   case spirv::Opcode::OpTypeFunction:
   case spirv::Opcode::OpTypeRuntimeArray:

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 8ea0c4f4711b..0b1c970589b1 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1111,6 +1111,17 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
     return success();
   }
 
+  if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
+    uint32_t elementTypeID = 0;
+    if (failed(processType(loc, matrixType.getElementType(), elementTypeID))) {
+      return failure();
+    }
+    typeEnum = spirv::Opcode::OpTypeMatrix;
+    operands.push_back(elementTypeID);
+    operands.push_back(matrixType.getNumElements());
+    return success();
+  }
+
   // TODO(ravishankarm) : Handle other types.
   return emitError(loc, "unhandled type in serialization: ") << type;
 }

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
new file mode 100644
index 000000000000..b27702bf50d8
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  spv.func @matrix_type(%arg0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>, %arg1 : i32) "None" {
+    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
+    %2 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  // CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
+  spv.globalVariable @var0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, StorageBuffer>
+
+  // CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<2 x vector<3xf32>>, StorageBuffer>
+  spv.globalVariable @var1 : !spv.ptr<!spv.matrix<2 x vector<3xf32>>, StorageBuffer>
+
+  // CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.matrix<4 x vector<4xf16>>, StorageBuffer>
+  spv.globalVariable @var2 : !spv.ptr<!spv.matrix<4 x vector<4xf16>>, StorageBuffer>
+}

diff  --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir
index 697177b0b98e..1d1a1868ea3c 100644
--- a/mlir/test/Dialect/SPIRV/types.mlir
+++ b/mlir/test/Dialect/SPIRV/types.mlir
@@ -347,3 +347,87 @@ func @missing_scope(!spv.coopmatrix<8x16xi32>) -> ()
 // expected-error @+1 {{expected rows and columns size}}
 func @missing_count(!spv.coopmatrix<8xi32, Subgroup>) -> ()
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// Matrix
+//===----------------------------------------------------------------------===//
+// CHECK: func @matrix_type(!spv.matrix<2 x vector<2xf16>>)
+func @matrix_type(!spv.matrix<2 x vector<2xf16>>) -> ()
+
+// -----
+
+// CHECK: func @matrix_type(!spv.matrix<3 x vector<3xf32>>)
+func @matrix_type(!spv.matrix<3 x vector<3xf32>>) -> ()
+
+// -----
+
+// CHECK: func @matrix_type(!spv.matrix<4 x vector<4xf16>>)
+func @matrix_type(!spv.matrix<4 x vector<4xf16>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
+func @matrix_invalid_size(!spv.matrix<5 x vector<3xf32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}}
+func @matrix_invalid_size(!spv.matrix<1 x vector<3xf32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix columns size has to be less than or equal to 4 and greater than or equal 2, but found 5}}
+func @matrix_invalid_columns_size(!spv.matrix<3 x vector<5xf32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix columns size has to be less than or equal to 4 and greater than or equal 2, but found 1}}
+func @matrix_invalid_columns_size(!spv.matrix<3 x vector<1xf32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected '<'}}
+func @matrix_invalid_format(!spv.matrix 3 x vector<3xf32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{unbalanced ')' character in pretty dialect name}}
+func @matrix_invalid_format(!spv.matrix< 3 x vector<3xf32>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected 'x' in dimension list}}
+func @matrix_invalid_format(!spv.matrix<2 vector<3xi32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix must be composed using vector type, got 'i32'}}
+func @matrix_invalid_type(!spv.matrix< 3 x i32>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix must be composed using vector type, got '!spv.array<16 x f32>'}}
+func @matrix_invalid_type(!spv.matrix< 3 x !spv.array<16 x f32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix must be composed using vector type, got '!spv.rtarray<i32>'}}
+func @matrix_invalid_type(!spv.matrix< 3 x !spv.rtarray<i32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{matrix columns' elements must be of Float type, got 'i32'}}
+func @matrix_invalid_type(!spv.matrix<2 x vector<3xi32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected single unsigned integer for number of columns}}
+func @matrix_size_type(!spv.matrix< x vector<3xi32>>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected single unsigned integer for number of columns}}
+func @matrix_size_type(!spv.matrix<2.0 x vector<3xi32>>) -> ()
+
+// -----
\ No newline at end of file


        


More information about the Mlir-commits mailing list