[Mlir-commits] [mlir] 4b7aa6c - [mlir][spirv] Enhance structure type member decoration handling

Lei Zhang llvmlistbot at llvm.org
Wed Jun 10 16:25:19 PDT 2020


Author: HazemAbdelhafez
Date: 2020-06-10T19:25:03-04:00
New Revision: 4b7aa6c8c1b0f68c6800225b39b3b389adf31332

URL: https://github.com/llvm/llvm-project/commit/4b7aa6c8c1b0f68c6800225b39b3b389adf31332
DIFF: https://github.com/llvm/llvm-project/commit/4b7aa6c8c1b0f68c6800225b39b3b389adf31332.diff

LOG: [mlir][spirv] Enhance structure type member decoration handling

Modify the structure type in the SPIR-V dialect to support:
1) Multiple decorations per structure member
2) Key-value based decorations (e.g., MatrixStride)

This commit keeps the Offset decoration separate from the members'
decorations container for easier implementation and logical clarity.
As such, all references to the structure's LayoutInfo are now OffsetInfo,
and any layout-defining member decoration (e.g., RowMajor for a matrix)
will be added to the members' decorations container, along with its
value if any.

Differential Revision: https://reviews.llvm.org/D81426

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
    mlir/include/mlir/IR/DialectImplementation.h
    mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/Dialect/SPIRV/Serialization/struct.mlir
    mlir/test/Dialect/SPIRV/types.mlir
    mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index b7180399a837..3addfd338c9b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -276,22 +276,39 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
 public:
   using Base::Base;
 
-  // Layout information used for members in a struct in SPIR-V
-  //
-  // TODO(ravishankarm) : For now this only supports the offset type, so uses
-  // uint64_t value to represent the offset, with
-  // std::numeric_limit<uint64_t>::max indicating no offset. Change this to
-  // something that can hold all the information needed for different member
-  // types
-  using LayoutInfo = uint64_t;
+  // Type for specifying the offset of the struct members
+  using OffsetInfo = uint32_t;
+
+  // Type for specifying the decoration(s) on struct members
+  struct MemberDecorationInfo {
+    uint32_t memberIndex : 31;
+    uint32_t hasValue : 1;
+    Decoration decoration;
+    uint32_t decorationValue;
+
+    MemberDecorationInfo(uint32_t index, uint32_t hasValue,
+                         Decoration decoration, uint32_t decorationValue)
+        : memberIndex(index), hasValue(hasValue), decoration(decoration),
+          decorationValue(decorationValue) {}
+
+    bool operator==(const MemberDecorationInfo &other) const {
+      return (this->memberIndex == other.memberIndex) &&
+             (this->decoration == other.decoration) &&
+             (this->decorationValue == other.decorationValue);
+    }
 
-  using MemberDecorationInfo = std::pair<uint32_t, spirv::Decoration>;
+    bool operator<(const MemberDecorationInfo &other) const {
+      return this->memberIndex < other.memberIndex ||
+             (this->memberIndex == other.memberIndex &&
+              this->decoration < other.decoration);
+    }
+  };
 
   static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
 
   /// Construct a StructType with at least one member.
   static StructType get(ArrayRef<Type> memberTypes,
-                        ArrayRef<LayoutInfo> layoutInfo = {},
+                        ArrayRef<OffsetInfo> offsetInfo = {},
                         ArrayRef<MemberDecorationInfo> memberDecorations = {});
 
   /// Construct a struct with no members.
@@ -323,9 +340,9 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
 
   ElementTypeRange getElementTypes() const;
 
-  bool hasLayout() const;
+  bool hasOffset() const;
 
-  uint64_t getOffset(unsigned) const;
+  uint64_t getMemberOffset(unsigned) const;
 
   // Returns in `allMemberDecorations` the spirv::Decorations (apart from
   // Offset) associated with all members of the StructType.
@@ -334,8 +351,9 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
 
   // Returns in `memberDecorations` all the spirv::Decorations (apart from
   // Offset) associated with the `i`-th member of the StructType.
-  void getMemberDecorations(
-      unsigned i, SmallVectorImpl<spirv::Decoration> &memberDecorations) const;
+  void getMemberDecorations(unsigned i,
+                            SmallVectorImpl<StructType::MemberDecorationInfo>
+                                &memberDecorations) const;
 
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      Optional<spirv::StorageClass> storage = llvm::None);
@@ -343,6 +361,9 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
                        Optional<spirv::StorageClass> storage = llvm::None);
 };
 
+llvm::hash_code
+hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
+
 // SPIR-V cooperative matrix type
 class CooperativeMatrixNVType
     : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,

diff  --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index a891c8c4dbcb..e2d7e2c409c4 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -200,6 +200,9 @@ class DialectAsmParser {
   /// Parse a `=` token.
   virtual ParseResult parseEqual() = 0;
 
+  /// Parse a `=` token if present.
+  virtual ParseResult parseOptionalEqual() = 0;
+
   /// Parse a given keyword.
   ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
     auto loc = getCurrentLocation();

diff  --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
index f953f958d75e..e456f499b745 100644
--- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
@@ -32,7 +32,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
   }
 
   SmallVector<Type, 4> memberTypes;
-  SmallVector<Size, 4> layoutInfo;
+  SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
 
   Size structMemberOffset = 0;
@@ -46,7 +46,8 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
         decorateType(structType.getElementType(i), memberSize, memberAlignment);
     structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
     memberTypes.push_back(memberType);
-    layoutInfo.push_back(structMemberOffset);
+    offsetInfo.push_back(
+        static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
     // If the member's size is the max value, it must be the last member and it
     // must be a runtime array.
     assert(memberSize != std::numeric_limits<Size>().max() ||
@@ -66,7 +67,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
   size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
   alignment = maxMemberAlignment;
   structType.getMemberDecorations(memberDecorations);
-  return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations);
+  return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
 }
 
 Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
@@ -168,7 +169,7 @@ bool VulkanLayoutUtils::isLegalType(Type type) {
   case spirv::StorageClass::StorageBuffer:
   case spirv::StorageClass::PushConstant:
   case spirv::StorageClass::PhysicalStorageBuffer:
-    return structType.hasLayout() || !structType.getNumElements();
+    return structType.hasOffset() || !structType.getNumElements();
   default:
     return true;
   }

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 455064f58ce6..43e70a1bdc63 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -535,30 +535,31 @@ static Type parseImageType(SPIRVDialect const &dialect,
 static ParseResult parseStructMemberDecorations(
     SPIRVDialect const &dialect, DialectAsmParser &parser,
     ArrayRef<Type> memberTypes,
-    SmallVectorImpl<StructType::LayoutInfo> &layoutInfo,
+    SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
 
   // Check if the first element is offset.
-  llvm::SMLoc layoutLoc = parser.getCurrentLocation();
-  StructType::LayoutInfo layout = 0;
-  OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout);
-  if (layoutParseResult.hasValue()) {
-    if (failed(*layoutParseResult))
+  llvm::SMLoc offsetLoc = parser.getCurrentLocation();
+  StructType::OffsetInfo offset = 0;
+  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
+  if (offsetParseResult.hasValue()) {
+    if (failed(*offsetParseResult))
       return failure();
 
-    if (layoutInfo.size() != memberTypes.size() - 1) {
-      return parser.emitError(
-          layoutLoc, "layout specification must be given for all members");
+    if (offsetInfo.size() != memberTypes.size() - 1) {
+      return parser.emitError(offsetLoc,
+                              "offset specification must be given for "
+                              "all members");
     }
-    layoutInfo.push_back(layout);
+    offsetInfo.push_back(offset);
   }
 
   // Check for no spirv::Decorations.
   if (succeeded(parser.parseOptionalRSquare()))
     return success();
 
-  // If there was a layout, make sure to parse the comma.
-  if (layoutParseResult.hasValue() && parser.parseComma())
+  // If there was an offset, make sure to parse the comma.
+  if (offsetParseResult.hasValue() && parser.parseComma())
     return failure();
 
   // Check for spirv::Decorations.
@@ -567,9 +568,23 @@ static ParseResult parseStructMemberDecorations(
     if (!memberDecoration)
       return failure();
 
-    memberDecorationInfo.emplace_back(
-        static_cast<uint32_t>(memberTypes.size() - 1),
-        memberDecoration.getValue());
+    // Parse member decoration value if it exists.
+    if (succeeded(parser.parseOptionalEqual())) {
+      auto memberDecorationValue =
+          parseAndVerifyInteger<uint32_t>(dialect, parser);
+
+      if (!memberDecorationValue)
+        return failure();
+
+      memberDecorationInfo.emplace_back(
+          static_cast<uint32_t>(memberTypes.size() - 1), 1,
+          memberDecoration.getValue(), memberDecorationValue.getValue());
+    } else {
+      memberDecorationInfo.emplace_back(
+          static_cast<uint32_t>(memberTypes.size() - 1), 0,
+          memberDecoration.getValue(), 0);
+    }
+
   } while (succeeded(parser.parseOptionalComma()));
 
   return parser.parseRSquare();
@@ -587,7 +602,7 @@ static Type parseStructType(SPIRVDialect const &dialect,
     return StructType::getEmpty(dialect.getContext());
 
   SmallVector<Type, 4> memberTypes;
-  SmallVector<StructType::LayoutInfo, 4> layoutInfo;
+  SmallVector<StructType::OffsetInfo, 4> offsetInfo;
   SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
 
   do {
@@ -597,21 +612,21 @@ static Type parseStructType(SPIRVDialect const &dialect,
     memberTypes.push_back(memberType);
 
     if (succeeded(parser.parseOptionalLSquare())) {
-      if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo,
+      if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
                                        memberDecorationInfo)) {
         return Type();
       }
     }
   } while (succeeded(parser.parseOptionalComma()));
 
-  if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) {
+  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
     parser.emitError(parser.getNameLoc(),
-                     "layout specification must be given for all members");
+                     "offset specification must be given for all members");
     return Type();
   }
   if (parser.parseGreater())
     return Type();
-  return StructType::get(memberTypes, layoutInfo, memberDecorationInfo);
+  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
 }
 
 // spirv-type ::= array-type
@@ -679,17 +694,20 @@ static void print(StructType type, DialectAsmPrinter &os) {
   os << "struct<";
   auto printMember = [&](unsigned i) {
     os << type.getElementType(i);
-    SmallVector<spirv::Decoration, 0> decorations;
+    SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
     type.getMemberDecorations(i, decorations);
-    if (type.hasLayout() || !decorations.empty()) {
+    if (type.hasOffset() || !decorations.empty()) {
       os << " [";
-      if (type.hasLayout()) {
-        os << type.getOffset(i);
+      if (type.hasOffset()) {
+        os << type.getMemberOffset(i);
         if (!decorations.empty())
           os << ", ";
       }
-      auto eachFn = [&os](spirv::Decoration decoration) {
-        os << stringifyDecoration(decoration);
+      auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
+        os << stringifyDecoration(decoration.decoration);
+        if (decoration.hasValue) {
+          os << "=" << decoration.decorationValue;
+        }
       };
       llvm::interleaveComma(decorations, os, eachFn);
       os << "]";

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 0226f5117540..963f5393c572 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -874,17 +874,17 @@ void SPIRVType::getCapabilities(
 struct spirv::detail::StructTypeStorage : public TypeStorage {
   StructTypeStorage(
       unsigned numMembers, Type const *memberTypes,
-      StructType::LayoutInfo const *layoutInfo, unsigned numMemberDecorations,
+      StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
       StructType::MemberDecorationInfo const *memberDecorationsInfo)
       : TypeStorage(numMembers), memberTypes(memberTypes),
-        layoutInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
+        offsetInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
         memberDecorationsInfo(memberDecorationsInfo) {}
 
-  using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::LayoutInfo>,
+  using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
                            ArrayRef<StructType::MemberDecorationInfo>>;
   bool operator==(const KeyTy &key) const {
     return key ==
-           KeyTy(getMemberTypes(), getLayoutInfo(), getMemberDecorationsInfo());
+           KeyTy(getMemberTypes(), getOffsetInfo(), getMemberDecorationsInfo());
   }
 
   static StructTypeStorage *construct(TypeStorageAllocator &allocator,
@@ -897,13 +897,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
       typesList = allocator.copyInto(keyTypes).data();
     }
 
-    const StructType::LayoutInfo *layoutInfoList = nullptr;
+    const StructType::OffsetInfo *offsetInfoList = nullptr;
     if (!std::get<1>(key).empty()) {
-      ArrayRef<StructType::LayoutInfo> keyLayoutInfo = std::get<1>(key);
-      assert(keyLayoutInfo.size() == keyTypes.size() &&
-             "size of layout information must be same as the size of number of "
+      ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<1>(key);
+      assert(keyOffsetInfo.size() == keyTypes.size() &&
+             "size of offset information must be same as the size of number of "
              "elements");
-      layoutInfoList = allocator.copyInto(keyLayoutInfo).data();
+      offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
     }
 
     const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
@@ -914,7 +914,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
       memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
     }
     return new (allocator.allocate<StructTypeStorage>())
-        StructTypeStorage(keyTypes.size(), typesList, layoutInfoList,
+        StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
                           numMemberDecorations, memberDecorationList);
   }
 
@@ -922,9 +922,9 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
     return ArrayRef<Type>(memberTypes, getSubclassData());
   }
 
-  ArrayRef<StructType::LayoutInfo> getLayoutInfo() const {
-    if (layoutInfo) {
-      return ArrayRef<StructType::LayoutInfo>(layoutInfo, getSubclassData());
+  ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
+    if (offsetInfo) {
+      return ArrayRef<StructType::OffsetInfo>(offsetInfo, getSubclassData());
     }
     return {};
   }
@@ -938,14 +938,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   }
 
   Type const *memberTypes;
-  StructType::LayoutInfo const *layoutInfo;
+  StructType::OffsetInfo const *offsetInfo;
   unsigned numMemberDecorations;
   StructType::MemberDecorationInfo const *memberDecorationsInfo;
 };
 
 StructType
 StructType::get(ArrayRef<Type> memberTypes,
-                ArrayRef<StructType::LayoutInfo> layoutInfo,
+                ArrayRef<StructType::OffsetInfo> offsetInfo,
                 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
   assert(!memberTypes.empty() && "Struct needs at least one member type");
   // Sort the decorations.
@@ -953,12 +953,12 @@ StructType::get(ArrayRef<Type> memberTypes,
       memberDecorations.begin(), memberDecorations.end());
   llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
   return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
-                   memberTypes, layoutInfo, sortedDecorations);
+                   memberTypes, offsetInfo, sortedDecorations);
 }
 
 StructType StructType::getEmpty(MLIRContext *context) {
   return Base::get(context, TypeKind::Struct, ArrayRef<Type>(),
-                   ArrayRef<StructType::LayoutInfo>(),
+                   ArrayRef<StructType::OffsetInfo>(),
                    ArrayRef<StructType::MemberDecorationInfo>());
 }
 
@@ -975,11 +975,11 @@ StructType::ElementTypeRange StructType::getElementTypes() const {
   return ElementTypeRange(getImpl()->memberTypes, getNumElements());
 }
 
-bool StructType::hasLayout() const { return getImpl()->layoutInfo; }
+bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
 
-uint64_t StructType::getOffset(unsigned index) const {
+uint64_t StructType::getMemberOffset(unsigned index) const {
   assert(getNumElements() > index && "member index out of range");
-  return getImpl()->layoutInfo[index];
+  return getImpl()->offsetInfo[index];
 }
 
 void StructType::getMemberDecorations(
@@ -992,15 +992,16 @@ void StructType::getMemberDecorations(
 }
 
 void StructType::getMemberDecorations(
-    unsigned index, SmallVectorImpl<spirv::Decoration> &decorations) const {
+    unsigned index,
+    SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
   assert(getNumElements() > index && "member index out of range");
   auto memberDecorations = getImpl()->getMemberDecorationsInfo();
-  decorations.clear();
-  for (auto &memberDecoration : memberDecorations) {
-    if (memberDecoration.first == index) {
-      decorations.push_back(memberDecoration.second);
+  decorationsInfo.clear();
+  for (const auto &memberDecoration : memberDecorations) {
+    if (memberDecoration.memberIndex == index) {
+      decorationsInfo.push_back(memberDecoration);
     }
-    if (memberDecoration.first > index) {
+    if (memberDecoration.memberIndex > index) {
       // Early exit since the decorations are stored sorted.
       return;
     }
@@ -1020,6 +1021,12 @@ void StructType::getCapabilities(
     elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
 }
 
+llvm::hash_code spirv::hash_value(
+    const StructType::MemberDecorationInfo &memberDecorationInfo) {
+  return llvm::hash_combine(memberDecorationInfo.memberIndex,
+                            memberDecorationInfo.decoration);
+}
+
 //===----------------------------------------------------------------------===//
 // MatrixType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index ecd79d7153ab..4bb9e6d26b97 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -1305,7 +1305,7 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
     memberTypes.push_back(memberType);
   }
 
-  SmallVector<spirv::StructType::LayoutInfo, 0> layoutInfo;
+  SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
   if (memberDecorationMap.count(operands[0])) {
     auto &allMemberDecorations = memberDecorationMap[operands[0]];
@@ -1314,27 +1314,27 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
         for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
           // Check for offset.
           if (memberDecoration.first == spirv::Decoration::Offset) {
-            // If layoutInfo is empty, resize to the number of members;
-            if (layoutInfo.empty()) {
-              layoutInfo.resize(memberTypes.size());
+            // If offset info is empty, resize to the number of members;
+            if (offsetInfo.empty()) {
+              offsetInfo.resize(memberTypes.size());
             }
-            layoutInfo[memberIndex] = memberDecoration.second[0];
+            offsetInfo[memberIndex] = memberDecoration.second[0];
           } else {
             if (!memberDecoration.second.empty()) {
-              return emitError(unknownLoc,
-                               "unhandled OpMemberDecoration with decoration ")
-                     << stringifyDecoration(memberDecoration.first)
-                     << " which has additional operands";
+              memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
+                                                 memberDecoration.first,
+                                                 memberDecoration.second[0]);
+            } else {
+              memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
+                                                 memberDecoration.first, 0);
             }
-            memberDecorationsInfo.emplace_back(memberIndex,
-                                               memberDecoration.first);
           }
         }
       }
     }
   }
   typeMap[operands[0]] =
-      spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo);
+      spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
   // TODO(ravishankarm): Update StructType to have member name as attribute as
   // well.
   return success();

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 81f873281c63..f8641873fd95 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -227,9 +227,9 @@ class Serializer {
   }
 
   /// Process member decoration
-  LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberIndex,
-                                        spirv::Decoration decorationType,
-                                        ArrayRef<uint32_t> values = {});
+  LogicalResult processMemberDecoration(
+      uint32_t structID,
+      const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);
 
   //===--------------------------------------------------------------------===//
   // Types
@@ -736,14 +736,14 @@ LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
   return success();
 }
 
-LogicalResult
-Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex,
-                                    spirv::Decoration decorationType,
-                                    ArrayRef<uint32_t> values) {
+LogicalResult Serializer::processMemberDecoration(
+    uint32_t structID,
+    const spirv::StructType::MemberDecorationInfo &memberDecoration) {
   SmallVector<uint32_t, 4> args(
-      {structID, memberIndex, static_cast<uint32_t>(decorationType)});
-  if (!values.empty()) {
-    args.append(values.begin(), values.end());
+      {structID, memberDecoration.memberIndex,
+       static_cast<uint32_t>(memberDecoration.decoration)});
+  if (memberDecoration.hasValue) {
+    args.push_back(memberDecoration.decorationValue);
   }
   return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
                                args);
@@ -1070,7 +1070,7 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
   }
 
   if (auto structType = type.dyn_cast<spirv::StructType>()) {
-    bool hasLayout = structType.hasLayout();
+    bool hasOffset = structType.hasOffset();
     for (auto elementIndex :
          llvm::seq<uint32_t>(0, structType.getNumElements())) {
       uint32_t elementTypeID = 0;
@@ -1079,11 +1079,12 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
         return failure();
       }
       operands.push_back(elementTypeID);
-      if (hasLayout) {
+      if (hasOffset) {
         // Decorate each struct member with an offset
-        if (failed(processMemberDecoration(
-                resultID, elementIndex, spirv::Decoration::Offset,
-                static_cast<uint32_t>(structType.getOffset(elementIndex))))) {
+        spirv::StructType::MemberDecorationInfo offsetDecoration{
+            elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
+            static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
+        if (failed(processMemberDecoration(resultID, offsetDecoration))) {
           return emitError(loc, "cannot decorate ")
                  << elementIndex << "-th member of " << structType
                  << " with its offset";
@@ -1093,11 +1094,11 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
     SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
     structType.getMemberDecorations(memberDecorations);
     for (auto &memberDecoration : memberDecorations) {
-      if (failed(processMemberDecoration(resultID, memberDecoration.first,
-                                         memberDecoration.second))) {
+      if (failed(processMemberDecoration(resultID, memberDecoration))) {
         return emitError(loc, "cannot decorate ")
-               << memberDecoration.first << "-th member of " << structType
-               << " with " << stringifyDecoration(memberDecoration.second);
+               << static_cast<uint32_t>(memberDecoration.memberIndex)
+               << "-th member of " << structType << " with "
+               << stringifyDecoration(memberDecoration.decoration);
       }
     }
     typeEnum = spirv::Opcode::OpTypeStruct;

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 88e576c36df7..be96a39e6789 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -544,6 +544,11 @@ class CustomDialectAsmParser : public DialectAsmParser {
     return parser.parseToken(Token::equal, "expected '='");
   }
 
+  /// Parse a `=` token if present.
+  ParseResult parseOptionalEqual() override {
+    return success(parser.consumeIf(Token::equal));
+  }
+
   /// Parse a '<' token.
   ParseResult parseLess() override {
     return parser.parseToken(Token::less, "expected '<'");

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
index 3066462fd71b..fff591d2f24e 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
@@ -22,6 +22,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK: !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]>, StorageBuffer>
   spv.globalVariable @var6 : !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]>, StorageBuffer>
 
+  // CHECK: !spv.ptr<!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>, StorageBuffer>
+  spv.globalVariable @var7 : !spv.ptr<!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>, StorageBuffer>
+
   // CHECK: !spv.ptr<!spv.struct<>, StorageBuffer>
   spv.globalVariable @empty : !spv.ptr<!spv.struct<>, StorageBuffer>
 

diff  --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir
index 1d1a1868ea3c..d5eb073c9aa5 100644
--- a/mlir/test/Dialect/SPIRV/types.mlir
+++ b/mlir/test/Dialect/SPIRV/types.mlir
@@ -275,17 +275,23 @@ func @struct_type_with_decoration7(!spv.struct<f32 [0], !spv.struct<i32, f32 [No
 // CHECK: func @struct_type_with_decoration8(!spv.struct<f32, !spv.struct<i32 [0], f32 [4, NonReadable]>>)
 func @struct_type_with_decoration8(!spv.struct<f32, !spv.struct<i32 [0], f32 [4, NonReadable]>>)
 
+// CHECK: func @struct_type_with_matrix_1(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>)
+func @struct_type_with_matrix_1(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>)
+
+// CHECK: func @struct_type_with_matrix_2(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16]>)
+func @struct_type_with_matrix_2(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16]>)
+
 // CHECK: func @struct_empty(!spv.struct<>)
 func @struct_empty(!spv.struct<>)
 
 // -----
 
-// expected-error @+1 {{layout specification must be given for all members}}
+// expected-error @+1 {{offset specification must be given for all members}}
 func @struct_type_missing_offset1((!spv.struct<f32, i32 [4]>) -> ()
 
 // -----
 
-// expected-error @+1 {{layout specification must be given for all members}}
+// expected-error @+1 {{offset specification must be given for all members}}
 func @struct_type_missing_offset2(!spv.struct<f32 [3], i32>) -> ()
 
 // -----
@@ -330,6 +336,16 @@ func @struct_type_missing_comma(!spv.struct<f32 [0, NonWritable NonReadable], i3
 
 // -----
 
+// expected-error @+1 {{expected ']'}}
+func @struct_type_missing_comma(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor MatrixStride=16]>)
+
+// -----
+
+// expected-error @+1 {{expected integer value}}
+func @struct_missing_member_decorator_value(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=]>)
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // CooperativeMatrix
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 06c417ca23f7..340bfd939e1b 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -58,8 +58,8 @@ class SerializationTest : public ::testing::Test {
   Type getFloatStructType() {
     OpBuilder opBuilder(module.body());
     llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
-    llvm::SmallVector<spirv::StructType::LayoutInfo, 1> layoutInfo{0};
-    auto structType = spirv::StructType::get(elementTypes, layoutInfo);
+    llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
+    auto structType = spirv::StructType::get(elementTypes, offsetInfo);
     return structType;
   }
 


        


More information about the Mlir-commits mailing list