[Mlir-commits] [mlir] 1cf1486 - Revert "[mlir][spirv] Enhance structure type member decoration handling"

Mehdi Amini llvmlistbot at llvm.org
Wed Jun 10 17:52:10 PDT 2020


Author: Mehdi Amini
Date: 2020-06-11T00:52:03Z
New Revision: 1cf14860db884b48a1a79a9697cc9da40f1ac124

URL: https://github.com/llvm/llvm-project/commit/1cf14860db884b48a1a79a9697cc9da40f1ac124
DIFF: https://github.com/llvm/llvm-project/commit/1cf14860db884b48a1a79a9697cc9da40f1ac124.diff

LOG: Revert "[mlir][spirv] Enhance structure type member decoration handling"

This reverts commit 4b7aa6c8c1b0f68c6800225b39b3b389adf31332.

This broke gcc builds.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
    mlir/test/Dialect/SPIRV/Serialization/struct.mlir
    mlir/test/Dialect/SPIRV/types.mlir
    mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 3addfd338c9b..b7180399a837 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -276,39 +276,22 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
 public:
   using Base::Base;
 
-  // Type for specifying the offset of the struct members
-  using OffsetInfo = uint32_t;
-
-  // Type for specifying the decoration(s) on struct members
-  struct MemberDecorationInfo {
-    uint32_t memberIndex : 31;
-    uint32_t hasValue : 1;
-    Decoration decoration;
-    uint32_t decorationValue;
-
-    MemberDecorationInfo(uint32_t index, uint32_t hasValue,
-                         Decoration decoration, uint32_t decorationValue)
-        : memberIndex(index), hasValue(hasValue), decoration(decoration),
-          decorationValue(decorationValue) {}
-
-    bool operator==(const MemberDecorationInfo &other) const {
-      return (this->memberIndex == other.memberIndex) &&
-             (this->decoration == other.decoration) &&
-             (this->decorationValue == other.decorationValue);
-    }
+  // Layout information used for members in a struct in SPIR-V
+  //
+  // TODO(ravishankarm) : For now this only supports the offset type, so uses
+  // uint64_t value to represent the offset, with
+  // std::numeric_limit<uint64_t>::max indicating no offset. Change this to
+  // something that can hold all the information needed for different member
+  // types
+  using LayoutInfo = uint64_t;
 
-    bool operator<(const MemberDecorationInfo &other) const {
-      return this->memberIndex < other.memberIndex ||
-             (this->memberIndex == other.memberIndex &&
-              this->decoration < other.decoration);
-    }
-  };
+  using MemberDecorationInfo = std::pair<uint32_t, spirv::Decoration>;
 
   static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
 
   /// Construct a StructType with at least one member.
   static StructType get(ArrayRef<Type> memberTypes,
-                        ArrayRef<OffsetInfo> offsetInfo = {},
+                        ArrayRef<LayoutInfo> layoutInfo = {},
                         ArrayRef<MemberDecorationInfo> memberDecorations = {});
 
   /// Construct a struct with no members.
@@ -340,9 +323,9 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
 
   ElementTypeRange getElementTypes() const;
 
-  bool hasOffset() const;
+  bool hasLayout() const;
 
-  uint64_t getMemberOffset(unsigned) const;
+  uint64_t getOffset(unsigned) const;
 
   // Returns in `allMemberDecorations` the spirv::Decorations (apart from
   // Offset) associated with all members of the StructType.
@@ -351,9 +334,8 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
 
   // Returns in `memberDecorations` all the spirv::Decorations (apart from
   // Offset) associated with the `i`-th member of the StructType.
-  void getMemberDecorations(unsigned i,
-                            SmallVectorImpl<StructType::MemberDecorationInfo>
-                                &memberDecorations) const;
+  void getMemberDecorations(
+      unsigned i, SmallVectorImpl<spirv::Decoration> &memberDecorations) const;
 
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      Optional<spirv::StorageClass> storage = llvm::None);
@@ -361,9 +343,6 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
                        Optional<spirv::StorageClass> storage = llvm::None);
 };
 
-llvm::hash_code
-hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
-
 // SPIR-V cooperative matrix type
 class CooperativeMatrixNVType
     : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,

diff  --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
index e456f499b745..f953f958d75e 100644
--- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
@@ -32,7 +32,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
   }
 
   SmallVector<Type, 4> memberTypes;
-  SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
+  SmallVector<Size, 4> layoutInfo;
   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
 
   Size structMemberOffset = 0;
@@ -46,8 +46,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
         decorateType(structType.getElementType(i), memberSize, memberAlignment);
     structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
     memberTypes.push_back(memberType);
-    offsetInfo.push_back(
-        static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
+    layoutInfo.push_back(structMemberOffset);
     // If the member's size is the max value, it must be the last member and it
     // must be a runtime array.
     assert(memberSize != std::numeric_limits<Size>().max() ||
@@ -67,7 +66,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
   size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
   alignment = maxMemberAlignment;
   structType.getMemberDecorations(memberDecorations);
-  return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
+  return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations);
 }
 
 Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
@@ -169,7 +168,7 @@ bool VulkanLayoutUtils::isLegalType(Type type) {
   case spirv::StorageClass::StorageBuffer:
   case spirv::StorageClass::PushConstant:
   case spirv::StorageClass::PhysicalStorageBuffer:
-    return structType.hasOffset() || !structType.getNumElements();
+    return structType.hasLayout() || !structType.getNumElements();
   default:
     return true;
   }

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 43e70a1bdc63..455064f58ce6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -535,31 +535,30 @@ static Type parseImageType(SPIRVDialect const &dialect,
 static ParseResult parseStructMemberDecorations(
     SPIRVDialect const &dialect, DialectAsmParser &parser,
     ArrayRef<Type> memberTypes,
-    SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
+    SmallVectorImpl<StructType::LayoutInfo> &layoutInfo,
     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
 
   // Check if the first element is offset.
-  llvm::SMLoc offsetLoc = parser.getCurrentLocation();
-  StructType::OffsetInfo offset = 0;
-  OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
-  if (offsetParseResult.hasValue()) {
-    if (failed(*offsetParseResult))
+  llvm::SMLoc layoutLoc = parser.getCurrentLocation();
+  StructType::LayoutInfo layout = 0;
+  OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout);
+  if (layoutParseResult.hasValue()) {
+    if (failed(*layoutParseResult))
       return failure();
 
-    if (offsetInfo.size() != memberTypes.size() - 1) {
-      return parser.emitError(offsetLoc,
-                              "offset specification must be given for "
-                              "all members");
+    if (layoutInfo.size() != memberTypes.size() - 1) {
+      return parser.emitError(
+          layoutLoc, "layout specification must be given for all members");
     }
-    offsetInfo.push_back(offset);
+    layoutInfo.push_back(layout);
   }
 
   // Check for no spirv::Decorations.
   if (succeeded(parser.parseOptionalRSquare()))
     return success();
 
-  // If there was an offset, make sure to parse the comma.
-  if (offsetParseResult.hasValue() && parser.parseComma())
+  // If there was a layout, make sure to parse the comma.
+  if (layoutParseResult.hasValue() && parser.parseComma())
     return failure();
 
   // Check for spirv::Decorations.
@@ -568,23 +567,9 @@ static ParseResult parseStructMemberDecorations(
     if (!memberDecoration)
       return failure();
 
-    // Parse member decoration value if it exists.
-    if (succeeded(parser.parseOptionalEqual())) {
-      auto memberDecorationValue =
-          parseAndVerifyInteger<uint32_t>(dialect, parser);
-
-      if (!memberDecorationValue)
-        return failure();
-
-      memberDecorationInfo.emplace_back(
-          static_cast<uint32_t>(memberTypes.size() - 1), 1,
-          memberDecoration.getValue(), memberDecorationValue.getValue());
-    } else {
-      memberDecorationInfo.emplace_back(
-          static_cast<uint32_t>(memberTypes.size() - 1), 0,
-          memberDecoration.getValue(), 0);
-    }
-
+    memberDecorationInfo.emplace_back(
+        static_cast<uint32_t>(memberTypes.size() - 1),
+        memberDecoration.getValue());
   } while (succeeded(parser.parseOptionalComma()));
 
   return parser.parseRSquare();
@@ -602,7 +587,7 @@ static Type parseStructType(SPIRVDialect const &dialect,
     return StructType::getEmpty(dialect.getContext());
 
   SmallVector<Type, 4> memberTypes;
-  SmallVector<StructType::OffsetInfo, 4> offsetInfo;
+  SmallVector<StructType::LayoutInfo, 4> layoutInfo;
   SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
 
   do {
@@ -612,21 +597,21 @@ static Type parseStructType(SPIRVDialect const &dialect,
     memberTypes.push_back(memberType);
 
     if (succeeded(parser.parseOptionalLSquare())) {
-      if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
+      if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo,
                                        memberDecorationInfo)) {
         return Type();
       }
     }
   } while (succeeded(parser.parseOptionalComma()));
 
-  if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
+  if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) {
     parser.emitError(parser.getNameLoc(),
-                     "offset specification must be given for all members");
+                     "layout specification must be given for all members");
     return Type();
   }
   if (parser.parseGreater())
     return Type();
-  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+  return StructType::get(memberTypes, layoutInfo, memberDecorationInfo);
 }
 
 // spirv-type ::= array-type
@@ -694,20 +679,17 @@ static void print(StructType type, DialectAsmPrinter &os) {
   os << "struct<";
   auto printMember = [&](unsigned i) {
     os << type.getElementType(i);
-    SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
+    SmallVector<spirv::Decoration, 0> decorations;
     type.getMemberDecorations(i, decorations);
-    if (type.hasOffset() || !decorations.empty()) {
+    if (type.hasLayout() || !decorations.empty()) {
       os << " [";
-      if (type.hasOffset()) {
-        os << type.getMemberOffset(i);
+      if (type.hasLayout()) {
+        os << type.getOffset(i);
         if (!decorations.empty())
           os << ", ";
       }
-      auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
-        os << stringifyDecoration(decoration.decoration);
-        if (decoration.hasValue) {
-          os << "=" << decoration.decorationValue;
-        }
+      auto eachFn = [&os](spirv::Decoration decoration) {
+        os << stringifyDecoration(decoration);
       };
       llvm::interleaveComma(decorations, os, eachFn);
       os << "]";

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 963f5393c572..0226f5117540 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -874,17 +874,17 @@ void SPIRVType::getCapabilities(
 struct spirv::detail::StructTypeStorage : public TypeStorage {
   StructTypeStorage(
       unsigned numMembers, Type const *memberTypes,
-      StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
+      StructType::LayoutInfo const *layoutInfo, unsigned numMemberDecorations,
       StructType::MemberDecorationInfo const *memberDecorationsInfo)
       : TypeStorage(numMembers), memberTypes(memberTypes),
-        offsetInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
+        layoutInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
         memberDecorationsInfo(memberDecorationsInfo) {}
 
-  using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
+  using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::LayoutInfo>,
                            ArrayRef<StructType::MemberDecorationInfo>>;
   bool operator==(const KeyTy &key) const {
     return key ==
-           KeyTy(getMemberTypes(), getOffsetInfo(), getMemberDecorationsInfo());
+           KeyTy(getMemberTypes(), getLayoutInfo(), getMemberDecorationsInfo());
   }
 
   static StructTypeStorage *construct(TypeStorageAllocator &allocator,
@@ -897,13 +897,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
       typesList = allocator.copyInto(keyTypes).data();
     }
 
-    const StructType::OffsetInfo *offsetInfoList = nullptr;
+    const StructType::LayoutInfo *layoutInfoList = nullptr;
     if (!std::get<1>(key).empty()) {
-      ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<1>(key);
-      assert(keyOffsetInfo.size() == keyTypes.size() &&
-             "size of offset information must be same as the size of number of "
+      ArrayRef<StructType::LayoutInfo> keyLayoutInfo = std::get<1>(key);
+      assert(keyLayoutInfo.size() == keyTypes.size() &&
+             "size of layout information must be same as the size of number of "
              "elements");
-      offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
+      layoutInfoList = allocator.copyInto(keyLayoutInfo).data();
     }
 
     const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
@@ -914,7 +914,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
       memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
     }
     return new (allocator.allocate<StructTypeStorage>())
-        StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
+        StructTypeStorage(keyTypes.size(), typesList, layoutInfoList,
                           numMemberDecorations, memberDecorationList);
   }
 
@@ -922,9 +922,9 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
     return ArrayRef<Type>(memberTypes, getSubclassData());
   }
 
-  ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
-    if (offsetInfo) {
-      return ArrayRef<StructType::OffsetInfo>(offsetInfo, getSubclassData());
+  ArrayRef<StructType::LayoutInfo> getLayoutInfo() const {
+    if (layoutInfo) {
+      return ArrayRef<StructType::LayoutInfo>(layoutInfo, getSubclassData());
     }
     return {};
   }
@@ -938,14 +938,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   }
 
   Type const *memberTypes;
-  StructType::OffsetInfo const *offsetInfo;
+  StructType::LayoutInfo const *layoutInfo;
   unsigned numMemberDecorations;
   StructType::MemberDecorationInfo const *memberDecorationsInfo;
 };
 
 StructType
 StructType::get(ArrayRef<Type> memberTypes,
-                ArrayRef<StructType::OffsetInfo> offsetInfo,
+                ArrayRef<StructType::LayoutInfo> layoutInfo,
                 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
   assert(!memberTypes.empty() && "Struct needs at least one member type");
   // Sort the decorations.
@@ -953,12 +953,12 @@ StructType::get(ArrayRef<Type> memberTypes,
       memberDecorations.begin(), memberDecorations.end());
   llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
   return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
-                   memberTypes, offsetInfo, sortedDecorations);
+                   memberTypes, layoutInfo, sortedDecorations);
 }
 
 StructType StructType::getEmpty(MLIRContext *context) {
   return Base::get(context, TypeKind::Struct, ArrayRef<Type>(),
-                   ArrayRef<StructType::OffsetInfo>(),
+                   ArrayRef<StructType::LayoutInfo>(),
                    ArrayRef<StructType::MemberDecorationInfo>());
 }
 
@@ -975,11 +975,11 @@ StructType::ElementTypeRange StructType::getElementTypes() const {
   return ElementTypeRange(getImpl()->memberTypes, getNumElements());
 }
 
-bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
+bool StructType::hasLayout() const { return getImpl()->layoutInfo; }
 
-uint64_t StructType::getMemberOffset(unsigned index) const {
+uint64_t StructType::getOffset(unsigned index) const {
   assert(getNumElements() > index && "member index out of range");
-  return getImpl()->offsetInfo[index];
+  return getImpl()->layoutInfo[index];
 }
 
 void StructType::getMemberDecorations(
@@ -992,16 +992,15 @@ void StructType::getMemberDecorations(
 }
 
 void StructType::getMemberDecorations(
-    unsigned index,
-    SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
+    unsigned index, SmallVectorImpl<spirv::Decoration> &decorations) const {
   assert(getNumElements() > index && "member index out of range");
   auto memberDecorations = getImpl()->getMemberDecorationsInfo();
-  decorationsInfo.clear();
-  for (const auto &memberDecoration : memberDecorations) {
-    if (memberDecoration.memberIndex == index) {
-      decorationsInfo.push_back(memberDecoration);
+  decorations.clear();
+  for (auto &memberDecoration : memberDecorations) {
+    if (memberDecoration.first == index) {
+      decorations.push_back(memberDecoration.second);
     }
-    if (memberDecoration.memberIndex > index) {
+    if (memberDecoration.first > index) {
       // Early exit since the decorations are stored sorted.
       return;
     }
@@ -1021,12 +1020,6 @@ void StructType::getCapabilities(
     elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
 }
 
-llvm::hash_code spirv::hash_value(
-    const StructType::MemberDecorationInfo &memberDecorationInfo) {
-  return llvm::hash_combine(memberDecorationInfo.memberIndex,
-                            memberDecorationInfo.decoration);
-}
-
 //===----------------------------------------------------------------------===//
 // MatrixType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 4bb9e6d26b97..ecd79d7153ab 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -1305,7 +1305,7 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
     memberTypes.push_back(memberType);
   }
 
-  SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
+  SmallVector<spirv::StructType::LayoutInfo, 0> layoutInfo;
   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
   if (memberDecorationMap.count(operands[0])) {
     auto &allMemberDecorations = memberDecorationMap[operands[0]];
@@ -1314,27 +1314,27 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
         for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
           // Check for offset.
           if (memberDecoration.first == spirv::Decoration::Offset) {
-            // If offset info is empty, resize to the number of members;
-            if (offsetInfo.empty()) {
-              offsetInfo.resize(memberTypes.size());
+            // If layoutInfo is empty, resize to the number of members;
+            if (layoutInfo.empty()) {
+              layoutInfo.resize(memberTypes.size());
             }
-            offsetInfo[memberIndex] = memberDecoration.second[0];
+            layoutInfo[memberIndex] = memberDecoration.second[0];
           } else {
             if (!memberDecoration.second.empty()) {
-              memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
-                                                 memberDecoration.first,
-                                                 memberDecoration.second[0]);
-            } else {
-              memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
-                                                 memberDecoration.first, 0);
+              return emitError(unknownLoc,
+                               "unhandled OpMemberDecoration with decoration ")
+                     << stringifyDecoration(memberDecoration.first)
+                     << " which has additional operands";
             }
+            memberDecorationsInfo.emplace_back(memberIndex,
+                                               memberDecoration.first);
           }
         }
       }
     }
   }
   typeMap[operands[0]] =
-      spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
+      spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo);
   // TODO(ravishankarm): Update StructType to have member name as attribute as
   // well.
   return success();

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index f8641873fd95..81f873281c63 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -227,9 +227,9 @@ class Serializer {
   }
 
   /// Process member decoration
-  LogicalResult processMemberDecoration(
-      uint32_t structID,
-      const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);
+  LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberIndex,
+                                        spirv::Decoration decorationType,
+                                        ArrayRef<uint32_t> values = {});
 
   //===--------------------------------------------------------------------===//
   // Types
@@ -736,14 +736,14 @@ LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
   return success();
 }
 
-LogicalResult Serializer::processMemberDecoration(
-    uint32_t structID,
-    const spirv::StructType::MemberDecorationInfo &memberDecoration) {
+LogicalResult
+Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex,
+                                    spirv::Decoration decorationType,
+                                    ArrayRef<uint32_t> values) {
   SmallVector<uint32_t, 4> args(
-      {structID, memberDecoration.memberIndex,
-       static_cast<uint32_t>(memberDecoration.decoration)});
-  if (memberDecoration.hasValue) {
-    args.push_back(memberDecoration.decorationValue);
+      {structID, memberIndex, static_cast<uint32_t>(decorationType)});
+  if (!values.empty()) {
+    args.append(values.begin(), values.end());
   }
   return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
                                args);
@@ -1070,7 +1070,7 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
   }
 
   if (auto structType = type.dyn_cast<spirv::StructType>()) {
-    bool hasOffset = structType.hasOffset();
+    bool hasLayout = structType.hasLayout();
     for (auto elementIndex :
          llvm::seq<uint32_t>(0, structType.getNumElements())) {
       uint32_t elementTypeID = 0;
@@ -1079,12 +1079,11 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
         return failure();
       }
       operands.push_back(elementTypeID);
-      if (hasOffset) {
+      if (hasLayout) {
         // Decorate each struct member with an offset
-        spirv::StructType::MemberDecorationInfo offsetDecoration{
-            elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
-            static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
-        if (failed(processMemberDecoration(resultID, offsetDecoration))) {
+        if (failed(processMemberDecoration(
+                resultID, elementIndex, spirv::Decoration::Offset,
+                static_cast<uint32_t>(structType.getOffset(elementIndex))))) {
           return emitError(loc, "cannot decorate ")
                  << elementIndex << "-th member of " << structType
                  << " with its offset";
@@ -1094,11 +1093,11 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
     SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
     structType.getMemberDecorations(memberDecorations);
     for (auto &memberDecoration : memberDecorations) {
-      if (failed(processMemberDecoration(resultID, memberDecoration))) {
+      if (failed(processMemberDecoration(resultID, memberDecoration.first,
+                                         memberDecoration.second))) {
         return emitError(loc, "cannot decorate ")
-               << static_cast<uint32_t>(memberDecoration.memberIndex)
-               << "-th member of " << structType << " with "
-               << stringifyDecoration(memberDecoration.decoration);
+               << memberDecoration.first << "-th member of " << structType
+               << " with " << stringifyDecoration(memberDecoration.second);
       }
     }
     typeEnum = spirv::Opcode::OpTypeStruct;

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
index fff591d2f24e..3066462fd71b 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
@@ -22,9 +22,6 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK: !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]>, StorageBuffer>
   spv.globalVariable @var6 : !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]>, StorageBuffer>
 
-  // CHECK: !spv.ptr<!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>, StorageBuffer>
-  spv.globalVariable @var7 : !spv.ptr<!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>, StorageBuffer>
-
   // CHECK: !spv.ptr<!spv.struct<>, StorageBuffer>
   spv.globalVariable @empty : !spv.ptr<!spv.struct<>, StorageBuffer>
 

diff  --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir
index d5eb073c9aa5..1d1a1868ea3c 100644
--- a/mlir/test/Dialect/SPIRV/types.mlir
+++ b/mlir/test/Dialect/SPIRV/types.mlir
@@ -275,23 +275,17 @@ func @struct_type_with_decoration7(!spv.struct<f32 [0], !spv.struct<i32, f32 [No
 // CHECK: func @struct_type_with_decoration8(!spv.struct<f32, !spv.struct<i32 [0], f32 [4, NonReadable]>>)
 func @struct_type_with_decoration8(!spv.struct<f32, !spv.struct<i32 [0], f32 [4, NonReadable]>>)
 
-// CHECK: func @struct_type_with_matrix_1(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>)
-func @struct_type_with_matrix_1(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>)
-
-// CHECK: func @struct_type_with_matrix_2(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16]>)
-func @struct_type_with_matrix_2(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16]>)
-
 // CHECK: func @struct_empty(!spv.struct<>)
 func @struct_empty(!spv.struct<>)
 
 // -----
 
-// expected-error @+1 {{offset specification must be given for all members}}
+// expected-error @+1 {{layout specification must be given for all members}}
 func @struct_type_missing_offset1((!spv.struct<f32, i32 [4]>) -> ()
 
 // -----
 
-// expected-error @+1 {{offset specification must be given for all members}}
+// expected-error @+1 {{layout specification must be given for all members}}
 func @struct_type_missing_offset2(!spv.struct<f32 [3], i32>) -> ()
 
 // -----
@@ -336,16 +330,6 @@ func @struct_type_missing_comma(!spv.struct<f32 [0, NonWritable NonReadable], i3
 
 // -----
 
-// expected-error @+1 {{expected ']'}}
-func @struct_type_missing_comma(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor MatrixStride=16]>)
-
-// -----
-
-// expected-error @+1 {{expected integer value}}
-func @struct_missing_member_decorator_value(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=]>)
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // CooperativeMatrix
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 340bfd939e1b..06c417ca23f7 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -58,8 +58,8 @@ class SerializationTest : public ::testing::Test {
   Type getFloatStructType() {
     OpBuilder opBuilder(module.body());
     llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
-    llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
-    auto structType = spirv::StructType::get(elementTypes, offsetInfo);
+    llvm::SmallVector<spirv::StructType::LayoutInfo, 1> layoutInfo{0};
+    auto structType = spirv::StructType::get(elementTypes, layoutInfo);
     return structType;
   }
 


        


More information about the Mlir-commits mailing list