[Mlir-commits] [mlir] [mlir][spirv] Add support for structs decorations (PR #149793)

Igor Wodiany llvmlistbot at llvm.org
Mon Jul 21 03:51:02 PDT 2025


https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/149793

An alternative implementation could use an `ArrayRef` of `NamedAttribute`s or a `NamedAttrList` to store struct decorations, as the deserializer uses `NamedAttribute`s for decorations. However, using a custom struct allows us to store each `spirv::Decoration` directly rather than its name in a `StringRef`/`StringAttr`.

>From 9d2abda552eeed2e703225fb0f9cbb5319f3cd10 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Fri, 18 Jul 2025 10:01:26 +0000
Subject: [PATCH] [mlir][spirv] Add support for structs decorations

An alternative implementation could use an ArrayRef of NamedAttributes
or a NamedAttrList to store struct decorations, as the deserializer
uses NamedAttributes for decorations. However, using a custom struct
allows us to store each spirv::Decoration directly rather than its
name in a StringRef.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        |  39 +++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  61 ++++++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 118 ++++++++++++++----
 .../SPIRV/Deserialization/Deserializer.cpp    |  34 +++--
 .../SPIRV/Deserialization/Deserializer.h      |   1 +
 .../Target/SPIRV/Serialization/Serializer.cpp |  28 ++++-
 mlir/test/Dialect/SPIRV/IR/types.mlir         |   6 +
 mlir/test/Target/SPIRV/memory-ops.mlir        |  20 +--
 mlir/test/Target/SPIRV/struct.mlir            |  38 +++---
 mlir/test/Target/SPIRV/undef.mlir             |   6 +-
 10 files changed, 273 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 212cba61d396c..56d09301345f9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -327,10 +327,33 @@ class StructType
     }
   };
 
+  /// A decoration applied to the struct itself rather than to a member.
+  struct StructDecorationInfo {
+    bool hasValue;             // Printed as `Decoration=value` when true.
+    Decoration decoration;     // Which decoration is applied.
+    Attribute decorationValue; // Value attached to the decoration.
+    StructDecorationInfo(bool hasValue, Decoration decoration,
+                         Attribute decorationValue)
+        : hasValue(hasValue), decoration(decoration),
+          decorationValue(decorationValue) {}
+
+    // `hasValue` is compared too, since it selects the printed form.
+    bool operator==(const StructDecorationInfo &other) const {
+      return hasValue == other.hasValue && decoration == other.decoration &&
+             decorationValue == other.decorationValue;
+    }
+
+    bool operator<(const StructDecorationInfo &other) const { // By kind only.
+      return static_cast<uint32_t>(decoration) <
+             static_cast<uint32_t>(other.decoration);
+    }
+  };
+
   /// Construct a literal StructType with at least one member.
   static StructType get(ArrayRef<Type> memberTypes,
                         ArrayRef<OffsetInfo> offsetInfo = {},
-                        ArrayRef<MemberDecorationInfo> memberDecorations = {});
+                        ArrayRef<MemberDecorationInfo> memberDecorations = {},
+                        ArrayRef<StructDecorationInfo> structDecorations = {});
 
   /// Construct an identified StructType. This creates a StructType whose body
   /// (member types, offset info, and decorations) is not set yet. A call to
@@ -364,6 +387,9 @@ class StructType
 
   bool hasOffset() const;
 
+  /// Returns true if the struct has a specified decoration.
+  bool hasDecoration(spirv::Decoration decoration) const;
+
   uint64_t getMemberOffset(unsigned) const;
 
   // Returns in `memberDecorations` the Decorations (apart from Offset)
@@ -377,12 +403,18 @@ class StructType
       unsigned i,
       SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
 
+  // Returns in `structDecorations` the Decorations associated with the
+  // StructType.
+  void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo>
+                                &structDecorations) const;
+
   /// Sets the contents of an incomplete identified StructType. This method must
   /// be called only for identified StructTypes and it must be called only once
   /// per instance. Otherwise, failure() is returned.
   LogicalResult
   trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
-             ArrayRef<MemberDecorationInfo> memberDecorations = {});
+             ArrayRef<MemberDecorationInfo> memberDecorations = {},
+             ArrayRef<StructDecorationInfo> structDecorations = {});
 
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      std::optional<StorageClass> storage = std::nullopt);
@@ -393,6 +425,9 @@ class StructType
 llvm::hash_code
 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
 
+llvm::hash_code
+hash_value(const StructType::StructDecorationInfo &structDecorationInfo);
+
 // SPIR-V KHR cooperative matrix type
 class CooperativeMatrixType
     : public Type::TypeBase<CooperativeMatrixType, CompositeType,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index f32c53b8f0b9e..6121fef7318bb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations(
 //             `!spirv.struct<` (id `,`)?
 //                          `(`
 //                            (spirv-type (`[` struct-member-decoration `]`)?)*
-//                          `)>`
+//                          `)`
+//                            (`,` struct-decoration)?
+//                          `>`
 static Type parseStructType(SPIRVDialect const &dialect,
                             DialectAsmParser &parser) {
   // TODO: This function is quite lengthy. Break it down into smaller chunks.
@@ -767,17 +769,50 @@ static Type parseStructType(SPIRVDialect const &dialect,
     return Type();
   }
 
-  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
+  if (failed(parser.parseRParen()))
+    return Type();
+
+  SmallVector<StructType::StructDecorationInfo, 0> structDecorationInfo;
+
+  auto parseStructDecoration = [&]() {
+    std::optional<spirv::Decoration> decoration =
+        parseAndVerify<spirv::Decoration>(dialect, parser);
+    if (!decoration)
+      return failure();
+
+    // Parse decoration value if it exists.
+    if (succeeded(parser.parseOptionalEqual())) {
+      Attribute decorationValue;
+
+      if (failed(parser.parseAttribute(decorationValue)))
+        return failure();
+
+      structDecorationInfo.emplace_back(true, decoration.value(),
+                                        decorationValue);
+    } else {
+      structDecorationInfo.emplace_back(false, decoration.value(),
+                                        UnitAttr::get(dialect.getContext()));
+    }
+    return success();
+  };
+
+  while (succeeded(parser.parseOptionalComma()))
+    if (failed(parseStructDecoration()))
+      return Type();
+
+  if (failed(parser.parseGreater()))
     return Type();
 
   if (!identifier.empty()) {
     if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
-                                     memberDecorationInfo)))
+                                     memberDecorationInfo,
+                                     structDecorationInfo)))
       return Type();
     return idStructTy;
   }
 
-  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
+                         structDecorationInfo);
 }
 
 // spirv-type ::= array-type
@@ -892,7 +927,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
   };
   llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
                         printMember);
-  os << ")>";
+  os << ")";
+
+  SmallVector<spirv::StructType::StructDecorationInfo, 0> decorations;
+  type.getStructDecorations(decorations);
+  if (!decorations.empty()) {
+    os << ", ";
+    auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
+      os << stringifyDecoration(decoration.decoration);
+      if (decoration.hasValue) {
+        os << "=";
+        os.printAttributeWithoutType(decoration.decorationValue);
+      }
+    };
+    llvm::interleaveComma(decorations, os, eachFn);
+  }
+
+  os << ">";
 }
 
 static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 3799abd6fc743..4bb06b349d040 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -837,12 +837,14 @@ void SampledImageType::getCapabilities(
 /// - for literal structs:
 ///   - a list of member types;
 ///   - a list of member offset info;
-///   - a list of member decoration info.
+///   - a list of member decoration info;
+///   - a list of struct decoration info.
 ///
 /// Identified structures only have a mutable component consisting of:
 /// - a list of member types;
 /// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
 struct spirv::detail::StructTypeStorage : public TypeStorage {
   /// Construct a storage object for an identified struct type. A struct type
   /// associated with such storage must call StructType::trySetBody(...) later
@@ -850,6 +852,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   StructTypeStorage(StringRef identifier)
       : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
         numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
+        numStructDecorations(0), structDecorationsInfo(nullptr),
         identifier(identifier) {}
 
   /// Construct a storage object for a literal struct type. A struct type
@@ -857,10 +860,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   StructTypeStorage(
       unsigned numMembers, Type const *memberTypes,
       StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
-      StructType::MemberDecorationInfo const *memberDecorationsInfo)
+      StructType::MemberDecorationInfo const *memberDecorationsInfo,
+      unsigned numStructDecorations,
+      StructType::StructDecorationInfo const *structDecorationsInfo)
       : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
         numMembers(numMembers), numMemberDecorations(numMemberDecorations),
-        memberDecorationsInfo(memberDecorationsInfo) {}
+        memberDecorationsInfo(memberDecorationsInfo),
+        numStructDecorations(numStructDecorations),
+        structDecorationsInfo(structDecorationsInfo) {}
 
   /// A storage key is divided into 2 parts:
   /// - for identified structs:
@@ -869,16 +876,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   ///   - an ArrayRef<Type> for member types;
   ///   - an ArrayRef<StructType::OffsetInfo> for member offset info;
   ///   - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
+  ///     info;
+  ///   - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
   ///     info.
   ///
   /// An identified struct type is uniqued only by the first part (field 0)
   /// of the key.
   ///
-  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
-  /// 3) of the key. The identifier field (field 0) must be empty.
+  /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
+  /// and 4) of the key. The identifier field (field 0) must be empty.
   using KeyTy =
       std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
-                 ArrayRef<StructType::MemberDecorationInfo>>;
+                 ArrayRef<StructType::MemberDecorationInfo>,
+                 ArrayRef<StructType::StructDecorationInfo>>;
 
   /// For identified structs, return true if the given key contains the same
   /// identifier.
@@ -892,7 +902,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
     }
 
     return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
-                        getMemberDecorationsInfo());
+                        getMemberDecorationsInfo(), getStructDecorationsInfo());
   }
 
   /// If the given key contains a non-empty identifier, this method constructs
@@ -939,9 +949,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
       memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
     }
 
-    return new (allocator.allocate<StructTypeStorage>())
-        StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
-                          numMemberDecorations, memberDecorationList);
+    const StructType::StructDecorationInfo *structDecorationList = nullptr;
+    unsigned numStructDecorations = 0;
+    if (!std::get<4>(key).empty()) {
+      auto keyStructDecorations = std::get<4>(key);
+      numStructDecorations = keyStructDecorations.size();
+      structDecorationList = allocator.copyInto(keyStructDecorations).data();
+    }
+
+    return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
+        keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
+        memberDecorationList, numStructDecorations, structDecorationList);
   }
 
   ArrayRef<Type> getMemberTypes() const {
@@ -963,6 +981,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
     return {};
   }
 
+  /// Returns the stored struct decorations; empty when the pointer is null.
+  ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const {
+    if (!structDecorationsInfo)
+      return {};
+    return {structDecorationsInfo, numStructDecorations};
+  }
+
   StringRef getIdentifier() const { return identifier; }
 
   bool isIdentified() const { return !identifier.empty(); }
@@ -975,17 +1000,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   /// - If called for an identified struct whose body was set before (through a
   /// call to this method) but with different contents from the passed
   /// arguments.
-  LogicalResult mutate(
-      TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
-      ArrayRef<StructType::OffsetInfo> structOffsetInfo,
-      ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
+  LogicalResult
+  mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
+         ArrayRef<StructType::OffsetInfo> structOffsetInfo,
+         ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
+         ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
     if (!isIdentified())
       return failure();
 
     if (memberTypesAndIsBodySet.getInt() &&
         (getMemberTypes() != structMemberTypes ||
          getOffsetInfo() != structOffsetInfo ||
-         getMemberDecorationsInfo() != structMemberDecorationInfo))
+         getMemberDecorationsInfo() != structMemberDecorationInfo ||
+         getStructDecorationsInfo() != structDecorationInfo))
       return failure();
 
     memberTypesAndIsBodySet.setInt(true);
@@ -1009,6 +1036,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
           allocator.copyInto(structMemberDecorationInfo).data();
     }
 
+    if (!structDecorationInfo.empty()) {
+      numStructDecorations = structDecorationInfo.size();
+      structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
+    }
+
     return success();
   }
 
@@ -1017,21 +1049,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   unsigned numMembers;
   unsigned numMemberDecorations;
   StructType::MemberDecorationInfo const *memberDecorationsInfo;
+  unsigned numStructDecorations;
+  StructType::StructDecorationInfo const *structDecorationsInfo;
   StringRef identifier;
 };
 
 StructType
 StructType::get(ArrayRef<Type> memberTypes,
                 ArrayRef<StructType::OffsetInfo> offsetInfo,
-                ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
+                ArrayRef<StructType::MemberDecorationInfo> memberDecorations,
+                ArrayRef<StructType::StructDecorationInfo> structDecorations) {
   assert(!memberTypes.empty() && "Struct needs at least one member type");
   // Sort the decorations.
-  SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
+  SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
       memberDecorations);
-  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
+  llvm::array_pod_sort(sortedMemberDecorations.begin(),
+                       sortedMemberDecorations.end());
+  SmallVector<StructType::StructDecorationInfo, 4> sortedStructDecorations(
+      structDecorations);
+  llvm::array_pod_sort(sortedStructDecorations.begin(),
+                       sortedStructDecorations.end());
+
   return Base::get(memberTypes.vec().front().getContext(),
                    /*identifier=*/StringRef(), memberTypes, offsetInfo,
-                   sortedDecorations);
+                   sortedMemberDecorations, sortedStructDecorations);
 }
 
 StructType StructType::getIdentified(MLIRContext *context,
@@ -1041,18 +1082,21 @@ StructType StructType::getIdentified(MLIRContext *context,
 
   return Base::get(context, identifier, ArrayRef<Type>(),
                    ArrayRef<StructType::OffsetInfo>(),
-                   ArrayRef<StructType::MemberDecorationInfo>());
+                   ArrayRef<StructType::MemberDecorationInfo>(),
+                   ArrayRef<StructType::StructDecorationInfo>());
 }
 
 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
   StructType newStructType = Base::get(
       context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
-      ArrayRef<StructType::MemberDecorationInfo>());
+      ArrayRef<StructType::MemberDecorationInfo>(),
+      ArrayRef<StructType::StructDecorationInfo>());
   // Set an empty body in case this is a identified struct.
   if (newStructType.isIdentified() &&
       failed(newStructType.trySetBody(
           ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
-          ArrayRef<StructType::MemberDecorationInfo>())))
+          ArrayRef<StructType::MemberDecorationInfo>(),
+          ArrayRef<StructType::StructDecorationInfo>())))
     return StructType();
 
   return newStructType;
@@ -1076,6 +1120,15 @@ TypeRange StructType::getElementTypes() const {
 
 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
 
+// Linear scan: reports whether any struct-level decoration entry stored on
+// this type carries the requested `decoration` kind.
+bool StructType::hasDecoration(spirv::Decoration decoration) const {
+  return llvm::any_of(getImpl()->getStructDecorationsInfo(),
+                      [decoration](const StructDecorationInfo &entry) {
+                        return entry.decoration == decoration;
+                      });
+}
+
 uint64_t StructType::getMemberOffset(unsigned index) const {
   assert(getNumElements() > index && "member index out of range");
   return getImpl()->offsetInfo[index];
@@ -1107,11 +1160,21 @@ void StructType::getMemberDecorations(
   }
 }
 
+// Overwrites `structDecorations` with the decorations stored on this struct.
+void StructType::getStructDecorations(
+    SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations)
+    const {
+  ArrayRef<StructDecorationInfo> stored = getImpl()->getStructDecorationsInfo();
+  structDecorations.assign(stored.begin(), stored.end());
+}
+
 LogicalResult
 StructType::trySetBody(ArrayRef<Type> memberTypes,
                        ArrayRef<OffsetInfo> offsetInfo,
-                       ArrayRef<MemberDecorationInfo> memberDecorations) {
-  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
+                       ArrayRef<MemberDecorationInfo> memberDecorations,
+                       ArrayRef<StructDecorationInfo> structDecorations) {
+  return Base::mutate(memberTypes, offsetInfo, memberDecorations,
+                      structDecorations);
 }
 
 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1133,6 +1196,11 @@ llvm::hash_code spirv::hash_value(
                             memberDecorationInfo.decoration);
 }
 
+llvm::hash_code // Hashes the decoration kind together with its value.
+spirv::hash_value(const StructType::StructDecorationInfo &info) {
+  return llvm::hash_combine(info.decoration, info.decorationValue);
+}
+
 //===----------------------------------------------------------------------===//
 // MatrixType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d133d0332e271..c8aa67c8c3b0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";
     }
-    // Block decoration does not affect spirv.struct type, but is still stored
-    // for verification.
-    // TODO: Update StructType to contain this information since
-    // it is needed for many validation rules.
     decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
     break;
   case spirv::Decoration::Location:
@@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
 
       if (failed(structType.trySetBody(
               deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
-              deferredStructIt->memberDecorationsInfo)))
+              deferredStructIt->memberDecorationsInfo,
+              deferredStructIt->structDecorationsInfo)))
         return failure();
 
       deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
@@ -1202,24 +1199,39 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
     }
   }
 
+  SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
+  if (decorations.count(operands[0])) {
+    NamedAttrList &allDecorations = decorations[operands[0]];
+    for (NamedAttribute &decorationAttr : allDecorations) {
+      std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
+          llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
+      assert(decoration.has_value());
+      bool hasValue =
+          (decorationAttr.getValue() != mlir::UnitAttr::get(context));
+      structDecorationsInfo.emplace_back(hasValue, decoration.value(),
+                                         decorationAttr.getValue());
+    }
+  }
+
   uint32_t structID = operands[0];
   std::string structIdentifier = nameMap.lookup(structID).str();
 
   if (structIdentifier.empty()) {
     assert(unresolvedMemberTypes.empty() &&
            "didn't expect unresolved member types");
-    typeMap[structID] =
-        spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
+    typeMap[structID] = spirv::StructType::get(
+        memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
   } else {
     auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
     typeMap[structID] = structTy;
 
     if (!unresolvedMemberTypes.empty())
-      deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
-                                          memberTypes, offsetInfo,
-                                          memberDecorationsInfo});
+      deferredStructTypesInfos.push_back(
+          {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
+           memberDecorationsInfo, structDecorationsInfo});
     else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
-                                        memberDecorationsInfo)))
+                                        memberDecorationsInfo,
+                                        structDecorationsInfo)))
       return failure();
   }
 
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 20482bd2bf501..db1cc3f8d79c2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -95,6 +95,7 @@ struct DeferredStructTypeInfo {
   SmallVector<Type, 4> memberTypes;
   SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
+  SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
 };
 
 /// A struct that collects the info needed to materialize/emit a
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 3400fcf6374e8..70fce56523e90 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -319,6 +319,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
   case spirv::Decoration::RestrictPointer:
   case spirv::Decoration::NoContraction:
   case spirv::Decoration::Constant:
+  case spirv::Decoration::Block:
     // For unit attributes and decoration attributes, the args list
     // has no values so we do nothing.
     if (isa<UnitAttr, DecorationAttr>(attr))
@@ -617,11 +618,16 @@ LogicalResult Serializer::prepareBasicType(
     operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
     operands.push_back(pointeeTypeID);
 
+    // TODO: Now that struct decorations are supported, this code may no longer
+    // be necessary. However, it is kept for backwards compatibility. Ideally,
+    // Block decorations should be inserted when converting to SPIR-V.
     if (isInterfaceStructPtrType(ptrType)) {
-      if (failed(emitDecoration(getTypeID(pointeeStruct),
-                                spirv::Decoration::Block)))
-        return emitError(loc, "cannot decorate ")
-               << pointeeStruct << " with Block decoration";
+      auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
+      if (!structType.hasDecoration(spirv::Decoration::Block))
+        if (failed(emitDecoration(getTypeID(pointeeStruct),
+                                  spirv::Decoration::Block)))
+          return emitError(loc, "cannot decorate ")
+                 << pointeeStruct << " with Block decoration";
     }
 
     return success();
@@ -689,6 +695,20 @@ LogicalResult Serializer::prepareBasicType(
       }
     }
 
+    SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorations;
+    structType.getStructDecorations(structDecorations);
+
+    for (spirv::StructType::StructDecorationInfo &structDecoration :
+         structDecorations) {
+      if (failed(processDecorationAttr(loc, resultID,
+                                       structDecoration.decoration,
+                                       structDecoration.decorationValue))) {
+        return emitError(loc, "cannot decorate struct ")
+               << structType << " with "
+               << stringifyDecoration(structDecoration.decoration);
+      }
+    }
+
     typeEnum = spirv::Opcode::OpTypeStruct;
 
     if (structType.isIdentified())
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 7d45b5ea82643..071389d497fde 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve
 // CHECK: func private @struct_empty(!spirv.struct<()>)
 func.func private @struct_empty(!spirv.struct<()>)
 
+// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+
+// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+
 // -----
 
 // expected-error @+1 {{offset specification must be given for all members}}
diff --git a/mlir/test/Target/SPIRV/memory-ops.mlir b/mlir/test/Target/SPIRV/memory-ops.mlir
index 6b50c3921d427..786d07a218c66 100644
--- a/mlir/test/Target/SPIRV/memory-ops.mlir
+++ b/mlir/test/Target/SPIRV/memory-ops.mlir
@@ -37,32 +37,32 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
 // -----
 
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
-  spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>) "None" {
-    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+  spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : f32
     %0 = spirv.Constant 0 : i32
-    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
     %2 = spirv.Load "StorageBuffer" %1 : f32
 
-    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
     %3 = spirv.Constant 0 : i32
-    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
     spirv.Store "StorageBuffer" %4, %2 : f32
     spirv.Return
   }
 
-  spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>) "None" {
-    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+  spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : i32
     %0 = spirv.Constant 0 : i32
-    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
     %2 = spirv.Load "StorageBuffer" %1 : i32
 
-    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
     %3 = spirv.Constant 0 : i32
-    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
     spirv.Store "StorageBuffer" %4, %2 : i32
     spirv.Return
   }
diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir
index 0db0c0bfa2660..4984ee79f903d 100644
--- a/mlir/test/Target/SPIRV/struct.mlir
+++ b/mlir/test/Target/SPIRV/struct.mlir
@@ -7,23 +7,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
   spirv.GlobalVariable @var1 bind(0, 2) : !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
-  spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
-  spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
-  spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
-  spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
-  spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
-  spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
 
   // CHECK: !spirv.ptr<!spirv.struct<()>, StorageBuffer>
   spirv.GlobalVariable @empty : !spirv.ptr<!spirv.struct<()>, StorageBuffer>
@@ -34,15 +34,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK: !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
   spirv.GlobalVariable @id_var0 : !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
 
+  // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
+  spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
-  spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
+  spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
 
-  // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
-  spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
+  // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
+  spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
 
-  // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
-  spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
+  // CHECK: spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
+  spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
 
   // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input>,
   // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Output>
diff --git a/mlir/test/Target/SPIRV/undef.mlir b/mlir/test/Target/SPIRV/undef.mlir
index b9044fe8b40af..8889b80e86f95 100644
--- a/mlir/test/Target/SPIRV/undef.mlir
+++ b/mlir/test/Target/SPIRV/undef.mlir
@@ -13,10 +13,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     // CHECK: {{%.*}} = spirv.Undef : !spirv.array<4 x !spirv.array<4 x i32>>
     %5 = spirv.Undef : !spirv.array<4x!spirv.array<4xi32>>
     %6 = spirv.CompositeExtract %5[1 : i32, 2 : i32] : !spirv.array<4x!spirv.array<4xi32>>
-    // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
-    %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
+    // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
+    %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
     %8 = spirv.Constant 0 : i32
-    %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
+    %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
     spirv.Return
   }
 }



More information about the Mlir-commits mailing list