[Mlir-commits] [mlir] [mlir][spirv] Add support for structs decorations (PR #149793)

Igor Wodiany llvmlistbot at llvm.org
Thu Jul 24 07:38:52 PDT 2025


https://github.com/IgWod-IMG updated https://github.com/llvm/llvm-project/pull/149793

>From 3c2a9d2abb14f84cf7750487a8ba52f1b069f8b2 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Fri, 18 Jul 2025 10:01:26 +0000
Subject: [PATCH 1/3] [mlir][spirv] Add support for structs decorations

An alternative implementation could use an ArrayRef of NamedAttributes
or a NamedAttrList to store struct decorations, as the deserializer
uses NamedAttributes for decorations. However, using a custom struct
allows us to store the spirv::Decoration directly rather than its
name in a StringRef.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        |  39 +++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  61 ++++++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 118 ++++++++++++++----
 .../SPIRV/Deserialization/Deserializer.cpp    |  34 +++--
 .../SPIRV/Deserialization/Deserializer.h      |   1 +
 .../Target/SPIRV/Serialization/Serializer.cpp |  28 ++++-
 mlir/test/Dialect/SPIRV/IR/types.mlir         |   6 +
 mlir/test/Target/SPIRV/memory-ops.mlir        |  20 +--
 mlir/test/Target/SPIRV/struct.mlir            |  38 +++---
 mlir/test/Target/SPIRV/undef.mlir             |   6 +-
 10 files changed, 273 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index c691d5901529b..129422e36dac3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -330,10 +330,33 @@ class StructType
     bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
   };
 
+  // Type for specifying the decoration(s) on the struct itself.
+  struct StructDecorationInfo {
+    bool hasValue;
+    Decoration decoration;
+    Attribute decorationValue;
+
+    StructDecorationInfo(bool hasValue, Decoration decoration,
+                         Attribute decorationValue)
+        : hasValue(hasValue), decoration(decoration),
+          decorationValue(decorationValue) {}
+
+    bool operator==(const StructDecorationInfo &other) const {
+      return (this->decoration == other.decoration) &&
+             (this->decorationValue == other.decorationValue);
+    }
+
+    bool operator<(const StructDecorationInfo &other) const {
+      return static_cast<uint32_t>(this->decoration) <
+             static_cast<uint32_t>(other.decoration);
+    }
+  };
+
   /// Construct a literal StructType with at least one member.
   static StructType get(ArrayRef<Type> memberTypes,
                         ArrayRef<OffsetInfo> offsetInfo = {},
-                        ArrayRef<MemberDecorationInfo> memberDecorations = {});
+                        ArrayRef<MemberDecorationInfo> memberDecorations = {},
+                        ArrayRef<StructDecorationInfo> structDecorations = {});
 
   /// Construct an identified StructType. This creates a StructType whose body
   /// (member types, offset info, and decorations) is not set yet. A call to
@@ -367,6 +390,9 @@ class StructType
 
   bool hasOffset() const;
 
+  /// Returns true if the struct has a specified decoration.
+  bool hasDecoration(spirv::Decoration decoration) const;
+
   uint64_t getMemberOffset(unsigned) const;
 
   // Returns in `memberDecorations` the Decorations (apart from Offset)
@@ -380,12 +406,18 @@ class StructType
       unsigned i,
       SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
 
+  // Returns in `structDecorations` the Decorations associated with the
+  // StructType.
+  void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo>
+                                &structDecorations) const;
+
   /// Sets the contents of an incomplete identified StructType. This method must
   /// be called only for identified StructTypes and it must be called only once
   /// per instance. Otherwise, failure() is returned.
   LogicalResult
   trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
-             ArrayRef<MemberDecorationInfo> memberDecorations = {});
+             ArrayRef<MemberDecorationInfo> memberDecorations = {},
+             ArrayRef<StructDecorationInfo> structDecorations = {});
 
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      std::optional<StorageClass> storage = std::nullopt);
@@ -396,6 +428,9 @@ class StructType
 llvm::hash_code
 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
 
+llvm::hash_code
+hash_value(const StructType::StructDecorationInfo &structDecorationInfo);
+
 // SPIR-V KHR cooperative matrix type
 class CooperativeMatrixType
     : public Type::TypeBase<CooperativeMatrixType, CompositeType,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index f2f7f7057a7f6..b3c31cb298451 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -691,7 +691,9 @@ static ParseResult parseStructMemberDecorations(
 //             `!spirv.struct<` (id `,`)?
 //                          `(`
 //                            (spirv-type (`[` struct-member-decoration `]`)?)*
-//                          `)>`
+//                          `)`
+//                            (`,` struct-decoration)?
+//                          `>`
 static Type parseStructType(SPIRVDialect const &dialect,
                             DialectAsmParser &parser) {
   // TODO: This function is quite lengthy. Break it down into smaller chunks.
@@ -765,17 +767,50 @@ static Type parseStructType(SPIRVDialect const &dialect,
     return Type();
   }
 
-  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
+  if (failed(parser.parseRParen()))
+    return Type();
+
+  SmallVector<StructType::StructDecorationInfo, 0> structDecorationInfo;
+
+  auto parseStructDecoration = [&]() {
+    std::optional<spirv::Decoration> decoration =
+        parseAndVerify<spirv::Decoration>(dialect, parser);
+    if (!decoration)
+      return failure();
+
+    // Parse decoration value if it exists.
+    if (succeeded(parser.parseOptionalEqual())) {
+      Attribute decorationValue;
+
+      if (failed(parser.parseAttribute(decorationValue)))
+        return failure();
+
+      structDecorationInfo.emplace_back(true, decoration.value(),
+                                        decorationValue);
+    } else {
+      structDecorationInfo.emplace_back(false, decoration.value(),
+                                        UnitAttr::get(dialect.getContext()));
+    }
+    return success();
+  };
+
+  while (succeeded(parser.parseOptionalComma()))
+    if (failed(parseStructDecoration()))
+      return Type();
+
+  if (failed(parser.parseGreater()))
     return Type();
 
   if (!identifier.empty()) {
     if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
-                                     memberDecorationInfo)))
+                                     memberDecorationInfo,
+                                     structDecorationInfo)))
       return Type();
     return idStructTy;
   }
 
-  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
+                         structDecorationInfo);
 }
 
 // spirv-type ::= array-type
@@ -891,7 +926,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
   };
   llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
                         printMember);
-  os << ")>";
+  os << ")";
+
+  SmallVector<spirv::StructType::StructDecorationInfo, 0> decorations;
+  type.getStructDecorations(decorations);
+  if (!decorations.empty()) {
+    os << ", ";
+    auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
+      os << stringifyDecoration(decoration.decoration);
+      if (decoration.hasValue) {
+        os << "=";
+        os.printAttributeWithoutType(decoration.decorationValue);
+      }
+    };
+    llvm::interleaveComma(decorations, os, eachFn);
+  }
+
+  os << ">";
 }
 
 static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 46739bcd79b8a..eaef205f73321 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -835,12 +835,14 @@ void SampledImageType::getCapabilities(
 /// - for literal structs:
 ///   - a list of member types;
 ///   - a list of member offset info;
-///   - a list of member decoration info.
+///   - a list of member decoration info;
+///   - a list of struct decoration info.
 ///
 /// Identified structures only have a mutable component consisting of:
 /// - a list of member types;
 /// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
 struct spirv::detail::StructTypeStorage : public TypeStorage {
   /// Construct a storage object for an identified struct type. A struct type
   /// associated with such storage must call StructType::trySetBody(...) later
@@ -848,6 +850,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   StructTypeStorage(StringRef identifier)
       : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
         numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
+        numStructDecorations(0), structDecorationsInfo(nullptr),
         identifier(identifier) {}
 
   /// Construct a storage object for a literal struct type. A struct type
@@ -855,10 +858,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   StructTypeStorage(
       unsigned numMembers, Type const *memberTypes,
       StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
-      StructType::MemberDecorationInfo const *memberDecorationsInfo)
+      StructType::MemberDecorationInfo const *memberDecorationsInfo,
+      unsigned numStructDecorations,
+      StructType::StructDecorationInfo const *structDecorationsInfo)
       : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
         numMembers(numMembers), numMemberDecorations(numMemberDecorations),
-        memberDecorationsInfo(memberDecorationsInfo) {}
+        memberDecorationsInfo(memberDecorationsInfo),
+        numStructDecorations(numStructDecorations),
+        structDecorationsInfo(structDecorationsInfo) {}
 
   /// A storage key is divided into 2 parts:
   /// - for identified structs:
@@ -867,16 +874,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   ///   - an ArrayRef<Type> for member types;
   ///   - an ArrayRef<StructType::OffsetInfo> for member offset info;
   ///   - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
+  ///     info;
+  ///   - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
   ///     info.
   ///
   /// An identified struct type is uniqued only by the first part (field 0)
   /// of the key.
   ///
-  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
-  /// 3) of the key. The identifier field (field 0) must be empty.
+  /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
+  /// and 4) of the key. The identifier field (field 0) must be empty.
   using KeyTy =
       std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
-                 ArrayRef<StructType::MemberDecorationInfo>>;
+                 ArrayRef<StructType::MemberDecorationInfo>,
+                 ArrayRef<StructType::StructDecorationInfo>>;
 
   /// For identified structs, return true if the given key contains the same
   /// identifier.
@@ -890,7 +900,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
     }
 
     return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
-                        getMemberDecorationsInfo());
+                        getMemberDecorationsInfo(), getStructDecorationsInfo());
   }
 
   /// If the given key contains a non-empty identifier, this method constructs
@@ -937,9 +947,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
       memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
     }
 
-    return new (allocator.allocate<StructTypeStorage>())
-        StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
-                          numMemberDecorations, memberDecorationList);
+    const StructType::StructDecorationInfo *structDecorationList = nullptr;
+    unsigned numStructDecorations = 0;
+    if (!std::get<4>(key).empty()) {
+      auto keyStructDecorations = std::get<4>(key);
+      numStructDecorations = keyStructDecorations.size();
+      structDecorationList = allocator.copyInto(keyStructDecorations).data();
+    }
+
+    return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
+        keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
+        memberDecorationList, numStructDecorations, structDecorationList);
   }
 
   ArrayRef<Type> getMemberTypes() const {
@@ -961,6 +979,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
     return {};
   }
 
+  ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const {
+    if (structDecorationsInfo)
+      return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
+                                                        numStructDecorations);
+    return {};
+  }
+
   StringRef getIdentifier() const { return identifier; }
 
   bool isIdentified() const { return !identifier.empty(); }
@@ -973,17 +998,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   /// - If called for an identified struct whose body was set before (through a
   /// call to this method) but with different contents from the passed
   /// arguments.
-  LogicalResult mutate(
-      TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
-      ArrayRef<StructType::OffsetInfo> structOffsetInfo,
-      ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
+  LogicalResult
+  mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
+         ArrayRef<StructType::OffsetInfo> structOffsetInfo,
+         ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
+         ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
     if (!isIdentified())
       return failure();
 
     if (memberTypesAndIsBodySet.getInt() &&
         (getMemberTypes() != structMemberTypes ||
          getOffsetInfo() != structOffsetInfo ||
-         getMemberDecorationsInfo() != structMemberDecorationInfo))
+         getMemberDecorationsInfo() != structMemberDecorationInfo ||
+         getStructDecorationsInfo() != structDecorationInfo))
       return failure();
 
     memberTypesAndIsBodySet.setInt(true);
@@ -1007,6 +1034,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
           allocator.copyInto(structMemberDecorationInfo).data();
     }
 
+    if (!structDecorationInfo.empty()) {
+      numStructDecorations = structDecorationInfo.size();
+      structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
+    }
+
     return success();
   }
 
@@ -1015,21 +1047,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   unsigned numMembers;
   unsigned numMemberDecorations;
   StructType::MemberDecorationInfo const *memberDecorationsInfo;
+  unsigned numStructDecorations;
+  StructType::StructDecorationInfo const *structDecorationsInfo;
   StringRef identifier;
 };
 
 StructType
 StructType::get(ArrayRef<Type> memberTypes,
                 ArrayRef<StructType::OffsetInfo> offsetInfo,
-                ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
+                ArrayRef<StructType::MemberDecorationInfo> memberDecorations,
+                ArrayRef<StructType::StructDecorationInfo> structDecorations) {
   assert(!memberTypes.empty() && "Struct needs at least one member type");
   // Sort the decorations.
-  SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
+  SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
       memberDecorations);
-  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
+  llvm::array_pod_sort(sortedMemberDecorations.begin(),
+                       sortedMemberDecorations.end());
+  SmallVector<StructType::StructDecorationInfo, 4> sortedStructDecorations(
+      structDecorations);
+  llvm::array_pod_sort(sortedStructDecorations.begin(),
+                       sortedStructDecorations.end());
+
   return Base::get(memberTypes.vec().front().getContext(),
                    /*identifier=*/StringRef(), memberTypes, offsetInfo,
-                   sortedDecorations);
+                   sortedMemberDecorations, sortedStructDecorations);
 }
 
 StructType StructType::getIdentified(MLIRContext *context,
@@ -1039,18 +1080,21 @@ StructType StructType::getIdentified(MLIRContext *context,
 
   return Base::get(context, identifier, ArrayRef<Type>(),
                    ArrayRef<StructType::OffsetInfo>(),
-                   ArrayRef<StructType::MemberDecorationInfo>());
+                   ArrayRef<StructType::MemberDecorationInfo>(),
+                   ArrayRef<StructType::StructDecorationInfo>());
 }
 
 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
   StructType newStructType = Base::get(
       context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
-      ArrayRef<StructType::MemberDecorationInfo>());
+      ArrayRef<StructType::MemberDecorationInfo>(),
+      ArrayRef<StructType::StructDecorationInfo>());
   // Set an empty body in case this is a identified struct.
   if (newStructType.isIdentified() &&
       failed(newStructType.trySetBody(
           ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
-          ArrayRef<StructType::MemberDecorationInfo>())))
+          ArrayRef<StructType::MemberDecorationInfo>(),
+          ArrayRef<StructType::StructDecorationInfo>())))
     return StructType();
 
   return newStructType;
@@ -1074,6 +1118,15 @@ TypeRange StructType::getElementTypes() const {
 
 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
 
+bool StructType::hasDecoration(spirv::Decoration decoration) const {
+  for (StructType::StructDecorationInfo info :
+       getImpl()->getStructDecorationsInfo())
+    if (info.decoration == decoration)
+      return true;
+
+  return false;
+}
+
 uint64_t StructType::getMemberOffset(unsigned index) const {
   assert(getNumElements() > index && "member index out of range");
   return getImpl()->offsetInfo[index];
@@ -1105,11 +1158,21 @@ void StructType::getMemberDecorations(
   }
 }
 
+void StructType::getStructDecorations(
+    SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations)
+    const {
+  structDecorations.clear();
+  auto implDecorations = getImpl()->getStructDecorationsInfo();
+  structDecorations.append(implDecorations.begin(), implDecorations.end());
+}
+
 LogicalResult
 StructType::trySetBody(ArrayRef<Type> memberTypes,
                        ArrayRef<OffsetInfo> offsetInfo,
-                       ArrayRef<MemberDecorationInfo> memberDecorations) {
-  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
+                       ArrayRef<MemberDecorationInfo> memberDecorations,
+                       ArrayRef<StructDecorationInfo> structDecorations) {
+  return Base::mutate(memberTypes, offsetInfo, memberDecorations,
+                      structDecorations);
 }
 
 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1131,6 +1194,11 @@ llvm::hash_code spirv::hash_value(
                             memberDecorationInfo.decoration);
 }
 
+llvm::hash_code spirv::hash_value(
+    const StructType::StructDecorationInfo &structDecorationInfo) {
+  return llvm::hash_value(structDecorationInfo.decoration);
+}
+
 //===----------------------------------------------------------------------===//
 // MatrixType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index e5934bb9943fd..0410d023bcd17 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";
     }
-    // Block decoration does not affect spirv.struct type, but is still stored
-    // for verification.
-    // TODO: Update StructType to contain this information since
-    // it is needed for many validation rules.
     decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
     break;
   case spirv::Decoration::Location:
@@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
 
       if (failed(structType.trySetBody(
               deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
-              deferredStructIt->memberDecorationsInfo)))
+              deferredStructIt->memberDecorationsInfo,
+              deferredStructIt->structDecorationsInfo)))
         return failure();
 
       deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
@@ -1203,24 +1200,39 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
     }
   }
 
+  SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
+  if (decorations.count(operands[0])) {
+    NamedAttrList &allDecorations = decorations[operands[0]];
+    for (NamedAttribute &decorationAttr : allDecorations) {
+      std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
+          llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
+      assert(decoration.has_value());
+      bool hasValue =
+          (decorationAttr.getValue() != mlir::UnitAttr::get(context));
+      structDecorationsInfo.emplace_back(hasValue, decoration.value(),
+                                         decorationAttr.getValue());
+    }
+  }
+
   uint32_t structID = operands[0];
   std::string structIdentifier = nameMap.lookup(structID).str();
 
   if (structIdentifier.empty()) {
     assert(unresolvedMemberTypes.empty() &&
            "didn't expect unresolved member types");
-    typeMap[structID] =
-        spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
+    typeMap[structID] = spirv::StructType::get(
+        memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
   } else {
     auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
     typeMap[structID] = structTy;
 
     if (!unresolvedMemberTypes.empty())
-      deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
-                                          memberTypes, offsetInfo,
-                                          memberDecorationsInfo});
+      deferredStructTypesInfos.push_back(
+          {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
+           memberDecorationsInfo, structDecorationsInfo});
     else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
-                                        memberDecorationsInfo)))
+                                        memberDecorationsInfo,
+                                        structDecorationsInfo)))
       return failure();
   }
 
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 20482bd2bf501..db1cc3f8d79c2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -95,6 +95,7 @@ struct DeferredStructTypeInfo {
   SmallVector<Type, 4> memberTypes;
   SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
+  SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
 };
 
 /// A struct that collects the info needed to materialize/emit a
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 6aab6bb668c78..468d9ef0ef9c9 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -319,6 +319,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
   case spirv::Decoration::RestrictPointer:
   case spirv::Decoration::NoContraction:
   case spirv::Decoration::Constant:
+  case spirv::Decoration::Block:
     // For unit attributes and decoration attributes, the args list
     // has no values so we do nothing.
     if (isa<UnitAttr, DecorationAttr>(attr))
@@ -618,11 +619,16 @@ LogicalResult Serializer::prepareBasicType(
     operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
     operands.push_back(pointeeTypeID);
 
+    // TODO: Now that struct decorations are supported, this code may not be
+    // necessary. However, it is kept for backwards compatibility.
+    // Ideally, Block decorations should be inserted when converting to SPIR-V.
     if (isInterfaceStructPtrType(ptrType)) {
-      if (failed(emitDecoration(getTypeID(pointeeStruct),
-                                spirv::Decoration::Block)))
-        return emitError(loc, "cannot decorate ")
-               << pointeeStruct << " with Block decoration";
+      auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
+      if (!structType.hasDecoration(spirv::Decoration::Block))
+        if (failed(emitDecoration(getTypeID(pointeeStruct),
+                                  spirv::Decoration::Block)))
+          return emitError(loc, "cannot decorate ")
+                 << pointeeStruct << " with Block decoration";
     }
 
     return success();
@@ -692,6 +698,20 @@ LogicalResult Serializer::prepareBasicType(
       }
     }
 
+    SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorations;
+    structType.getStructDecorations(structDecorations);
+
+    for (spirv::StructType::StructDecorationInfo &structDecoration :
+         structDecorations) {
+      if (failed(processDecorationAttr(loc, resultID,
+                                       structDecoration.decoration,
+                                       structDecoration.decorationValue))) {
+        return emitError(loc, "cannot decorate struct ")
+               << structType << " with "
+               << stringifyDecoration(structDecoration.decoration);
+      }
+    }
+
     typeEnum = spirv::Opcode::OpTypeStruct;
 
     if (structType.isIdentified())
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 5d05a65414969..6d321afebf7f8 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve
 // CHECK: func private @struct_empty(!spirv.struct<()>)
 func.func private @struct_empty(!spirv.struct<()>)
 
+// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+
+// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+
 // -----
 
 // expected-error @+1 {{offset specification must be given for all members}}
diff --git a/mlir/test/Target/SPIRV/memory-ops.mlir b/mlir/test/Target/SPIRV/memory-ops.mlir
index 6b50c3921d427..786d07a218c66 100644
--- a/mlir/test/Target/SPIRV/memory-ops.mlir
+++ b/mlir/test/Target/SPIRV/memory-ops.mlir
@@ -37,32 +37,32 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
 // -----
 
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
-  spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>) "None" {
-    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+  spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : f32
     %0 = spirv.Constant 0 : i32
-    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
     %2 = spirv.Load "StorageBuffer" %1 : f32
 
-    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
     %3 = spirv.Constant 0 : i32
-    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
     spirv.Store "StorageBuffer" %4, %2 : f32
     spirv.Return
   }
 
-  spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>) "None" {
-    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+  spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : i32
     %0 = spirv.Constant 0 : i32
-    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
     %2 = spirv.Load "StorageBuffer" %1 : i32
 
-    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
     %3 = spirv.Constant 0 : i32
-    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
     spirv.Store "StorageBuffer" %4, %2 : i32
     spirv.Return
   }
diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir
index 0db0c0bfa2660..4984ee79f903d 100644
--- a/mlir/test/Target/SPIRV/struct.mlir
+++ b/mlir/test/Target/SPIRV/struct.mlir
@@ -7,23 +7,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
   spirv.GlobalVariable @var1 bind(0, 2) : !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
-  spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
-  spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
-  spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
-  spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
-  spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
-  spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
 
   // CHECK: !spirv.ptr<!spirv.struct<()>, StorageBuffer>
   spirv.GlobalVariable @empty : !spirv.ptr<!spirv.struct<()>, StorageBuffer>
@@ -34,15 +34,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK: !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
   spirv.GlobalVariable @id_var0 : !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
 
+  // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
+  spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
 
-  // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
-  spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
+  spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
 
-  // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
-  spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
+  // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
+  spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
 
-  // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
-  spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
+  // CHECK: spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
+  spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
 
   // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input>,
   // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Output>
diff --git a/mlir/test/Target/SPIRV/undef.mlir b/mlir/test/Target/SPIRV/undef.mlir
index b9044fe8b40af..8889b80e86f95 100644
--- a/mlir/test/Target/SPIRV/undef.mlir
+++ b/mlir/test/Target/SPIRV/undef.mlir
@@ -13,10 +13,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     // CHECK: {{%.*}} = spirv.Undef : !spirv.array<4 x !spirv.array<4 x i32>>
     %5 = spirv.Undef : !spirv.array<4x!spirv.array<4xi32>>
     %6 = spirv.CompositeExtract %5[1 : i32, 2 : i32] : !spirv.array<4x!spirv.array<4xi32>>
-    // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
-    %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
+    // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
+    %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
     %8 = spirv.Constant 0 : i32
-    %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
+    %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
     spirv.Return
   }
 }

>From 725df8adabcac62952cc7ec8dbee2876c7e7195f Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Tue, 22 Jul 2025 12:44:47 +0000
Subject: [PATCH 2/3] Address some feedback

---
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h    | 2 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp         | 4 ++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp           | 2 +-
 mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 2 +-
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 129422e36dac3..e3be01e97d229 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -330,7 +330,7 @@ class StructType
     bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
   };
 
-  // Type for specifying the decoration(s) on the struct itself
+  // Type for specifying the decoration(s) on the struct itself.
   struct StructDecorationInfo {
     bool hasValue;
     Decoration decoration;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index b3c31cb298451..a77d1b2a45af7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -770,7 +770,7 @@ static Type parseStructType(SPIRVDialect const &dialect,
   if (failed(parser.parseRParen()))
     return Type();
 
-  SmallVector<StructType::StructDecorationInfo, 0> structDecorationInfo;
+  SmallVector<StructType::StructDecorationInfo, 1> structDecorationInfo;
 
   auto parseStructDecoration = [&]() {
     std::optional<spirv::Decoration> decoration =
@@ -928,7 +928,7 @@ static void print(StructType type, DialectAsmPrinter &os) {
                         printMember);
   os << ")";
 
-  SmallVector<spirv::StructType::StructDecorationInfo, 0> decorations;
+  SmallVector<spirv::StructType::StructDecorationInfo, 1> decorations;
   type.getStructDecorations(decorations);
   if (!decorations.empty()) {
     os << ", ";
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index eaef205f73321..ddb342621f371 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -1063,7 +1063,7 @@ StructType::get(ArrayRef<Type> memberTypes,
       memberDecorations);
   llvm::array_pod_sort(sortedMemberDecorations.begin(),
                        sortedMemberDecorations.end());
-  SmallVector<StructType::StructDecorationInfo, 4> sortedStructDecorations(
+  SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations(
       structDecorations);
   llvm::array_pod_sort(sortedStructDecorations.begin(),
                        sortedStructDecorations.end());
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 468d9ef0ef9c9..04d698748a796 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -698,7 +698,7 @@ LogicalResult Serializer::prepareBasicType(
       }
     }
 
-    SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorations;
+    SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
     structType.getStructDecorations(structDecorations);
 
     for (spirv::StructType::StructDecorationInfo &structDecoration :

>From 60c2d95728921f63bc1e83dc172dacd3ed3d61d2 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Thu, 24 Jul 2025 14:33:11 +0000
Subject: [PATCH 3/3] Feedback and match new member decorations

---
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        | 23 ++++++++++---------
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  8 +++----
 .../SPIRV/Deserialization/Deserializer.cpp    |  4 +---
 3 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index e3be01e97d229..531feccccb032 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -332,24 +332,25 @@ class StructType
 
   // Type for specifying the decoration(s) on the struct itself.
   struct StructDecorationInfo {
-    bool hasValue;
     Decoration decoration;
     Attribute decorationValue;
 
-    StructDecorationInfo(bool hasValue, Decoration decoration,
-                         Attribute decorationValue)
-        : hasValue(hasValue), decoration(decoration),
-          decorationValue(decorationValue) {}
+    StructDecorationInfo(Decoration decoration, Attribute decorationValue)
+        : decoration(decoration), decorationValue(decorationValue) {}
 
-    bool operator==(const StructDecorationInfo &other) const {
-      return (this->decoration == other.decoration) &&
-             (this->decorationValue == other.decorationValue);
+    friend bool operator==(const StructDecorationInfo &lhs,
+                           const StructDecorationInfo &rhs) {
+      return lhs.decoration == rhs.decoration &&
+             lhs.decorationValue == rhs.decorationValue;
     }
 
-    bool operator<(const StructDecorationInfo &other) const {
-      return static_cast<uint32_t>(this->decoration) <
-             static_cast<uint32_t>(other.decoration);
+    friend bool operator<(const StructDecorationInfo &lhs,
+                          const StructDecorationInfo &rhs) {
+      return llvm::to_underlying(lhs.decoration) <
+             llvm::to_underlying(rhs.decoration);
     }
+
+    bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
   };
 
   /// Construct a literal StructType with at least one member.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a77d1b2a45af7..706552c3ad2d5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -781,14 +781,12 @@ static Type parseStructType(SPIRVDialect const &dialect,
     // Parse decoration value if it exists.
     if (succeeded(parser.parseOptionalEqual())) {
       Attribute decorationValue;
-
       if (failed(parser.parseAttribute(decorationValue)))
         return failure();
 
-      structDecorationInfo.emplace_back(true, decoration.value(),
-                                        decorationValue);
+      structDecorationInfo.emplace_back(decoration.value(), decorationValue);
     } else {
-      structDecorationInfo.emplace_back(false, decoration.value(),
+      structDecorationInfo.emplace_back(decoration.value(),
                                         UnitAttr::get(dialect.getContext()));
     }
     return success();
@@ -934,7 +932,7 @@ static void print(StructType type, DialectAsmPrinter &os) {
     os << ", ";
     auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
       os << stringifyDecoration(decoration.decoration);
-      if (decoration.hasValue) {
+      if (decoration.hasValue()) {
         os << "=";
         os.printAttributeWithoutType(decoration.decorationValue);
       }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 0410d023bcd17..88931b53a6889 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1207,9 +1207,7 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
       std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
           llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
       assert(decoration.has_value());
-      bool hasValue =
-          (decorationAttr.getValue() != mlir::UnitAttr::get(context));
-      structDecorationsInfo.emplace_back(hasValue, decoration.value(),
+      structDecorationsInfo.emplace_back(decoration.value(),
                                          decorationAttr.getValue());
     }
   }



More information about the Mlir-commits mailing list