[Mlir-commits] [mlir] [mlir][spirv] Add support for structs decorations (PR #149793)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 21 03:51:35 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Igor Wodiany (IgWod-IMG)

<details>
<summary>Changes</summary>

An alternative implementation could use an `ArrayRef` of `NamedAttribute`s or a `NamedAttrList` to store struct decorations, as the deserializer uses `NamedAttribute`s for decorations. However, using a custom struct allows us to store each `spirv::Decoration` directly rather than its name in a `StringRef`/`StringAttr`.

---

Patch is 38.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149793.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+37-2) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+56-5) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+93-25) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+23-11) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+1) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+24-4) 
- (modified) mlir/test/Dialect/SPIRV/IR/types.mlir (+6) 
- (modified) mlir/test/Target/SPIRV/memory-ops.mlir (+10-10) 
- (modified) mlir/test/Target/SPIRV/struct.mlir (+20-18) 
- (modified) mlir/test/Target/SPIRV/undef.mlir (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 212cba61d396c..56d09301345f9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -327,10 +327,33 @@ class StructType
     }
   };
 
+  // Type for specifying the decoration(s) on the struct itself
+  struct StructDecorationInfo {
+    bool hasValue;
+    Decoration decoration;
+    Attribute decorationValue;
+
+    StructDecorationInfo(bool hasValue, Decoration decoration,
+                         Attribute decorationValue)
+        : hasValue(hasValue), decoration(decoration),
+          decorationValue(decorationValue) {}
+
+    bool operator==(const StructDecorationInfo &other) const {
+      return (this->decoration == other.decoration) &&
+             (this->decorationValue == other.decorationValue);
+    }
+
+    bool operator<(const StructDecorationInfo &other) const {
+      return static_cast<uint32_t>(this->decoration) <
+             static_cast<uint32_t>(other.decoration);
+    }
+  };
+
   /// Construct a literal StructType with at least one member.
   static StructType get(ArrayRef<Type> memberTypes,
                         ArrayRef<OffsetInfo> offsetInfo = {},
-                        ArrayRef<MemberDecorationInfo> memberDecorations = {});
+                        ArrayRef<MemberDecorationInfo> memberDecorations = {},
+                        ArrayRef<StructDecorationInfo> structDecorations = {});
 
   /// Construct an identified StructType. This creates a StructType whose body
   /// (member types, offset info, and decorations) is not set yet. A call to
@@ -364,6 +387,9 @@ class StructType
 
   bool hasOffset() const;
 
+  /// Returns true if the struct has a specified decoration.
+  bool hasDecoration(spirv::Decoration decoration) const;
+
   uint64_t getMemberOffset(unsigned) const;
 
   // Returns in `memberDecorations` the Decorations (apart from Offset)
@@ -377,12 +403,18 @@ class StructType
       unsigned i,
       SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
 
+  // Returns in `structDecorations` the Decorations associated with the
+  // StructType.
+  void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo>
+                                &structDecorations) const;
+
   /// Sets the contents of an incomplete identified StructType. This method must
   /// be called only for identified StructTypes and it must be called only once
   /// per instance. Otherwise, failure() is returned.
   LogicalResult
   trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
-             ArrayRef<MemberDecorationInfo> memberDecorations = {});
+             ArrayRef<MemberDecorationInfo> memberDecorations = {},
+             ArrayRef<StructDecorationInfo> structDecorations = {});
 
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      std::optional<StorageClass> storage = std::nullopt);
@@ -393,6 +425,9 @@ class StructType
 llvm::hash_code
 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
 
+llvm::hash_code
+hash_value(const StructType::StructDecorationInfo &structDecorationInfo);
+
 // SPIR-V KHR cooperative matrix type
 class CooperativeMatrixType
     : public Type::TypeBase<CooperativeMatrixType, CompositeType,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index f32c53b8f0b9e..6121fef7318bb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations(
 //             `!spirv.struct<` (id `,`)?
 //                          `(`
 //                            (spirv-type (`[` struct-member-decoration `]`)?)*
-//                          `)>`
+//                          `)`
+//                            (`,` struct-decoration)?
+//                          `>`
 static Type parseStructType(SPIRVDialect const &dialect,
                             DialectAsmParser &parser) {
   // TODO: This function is quite lengthy. Break it down into smaller chunks.
@@ -767,17 +769,50 @@ static Type parseStructType(SPIRVDialect const &dialect,
     return Type();
   }
 
-  if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
+  if (failed(parser.parseRParen()))
+    return Type();
+
+  SmallVector<StructType::StructDecorationInfo, 0> structDecorationInfo;
+
+  auto parseStructDecoration = [&]() {
+    std::optional<spirv::Decoration> decoration =
+        parseAndVerify<spirv::Decoration>(dialect, parser);
+    if (!decoration)
+      return failure();
+
+    // Parse decoration value if it exists.
+    if (succeeded(parser.parseOptionalEqual())) {
+      Attribute decorationValue;
+
+      if (failed(parser.parseAttribute(decorationValue)))
+        return failure();
+
+      structDecorationInfo.emplace_back(true, decoration.value(),
+                                        decorationValue);
+    } else {
+      structDecorationInfo.emplace_back(false, decoration.value(),
+                                        UnitAttr::get(dialect.getContext()));
+    }
+    return success();
+  };
+
+  while (succeeded(parser.parseOptionalComma()))
+    if (failed(parseStructDecoration()))
+      return Type();
+
+  if (failed(parser.parseGreater()))
     return Type();
 
   if (!identifier.empty()) {
     if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
-                                     memberDecorationInfo)))
+                                     memberDecorationInfo,
+                                     structDecorationInfo)))
       return Type();
     return idStructTy;
   }
 
-  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+  return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
+                         structDecorationInfo);
 }
 
 // spirv-type ::= array-type
@@ -892,7 +927,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
   };
   llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
                         printMember);
-  os << ")>";
+  os << ")";
+
+  SmallVector<spirv::StructType::StructDecorationInfo, 0> decorations;
+  type.getStructDecorations(decorations);
+  if (!decorations.empty()) {
+    os << ", ";
+    auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
+      os << stringifyDecoration(decoration.decoration);
+      if (decoration.hasValue) {
+        os << "=";
+        os.printAttributeWithoutType(decoration.decorationValue);
+      }
+    };
+    llvm::interleaveComma(decorations, os, eachFn);
+  }
+
+  os << ">";
 }
 
 static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 3799abd6fc743..4bb06b349d040 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -837,12 +837,14 @@ void SampledImageType::getCapabilities(
 /// - for literal structs:
 ///   - a list of member types;
 ///   - a list of member offset info;
-///   - a list of member decoration info.
+///   - a list of member decoration info;
+///   - a list of struct decoration info.
 ///
 /// Identified structures only have a mutable component consisting of:
 /// - a list of member types;
 /// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
 struct spirv::detail::StructTypeStorage : public TypeStorage {
   /// Construct a storage object for an identified struct type. A struct type
   /// associated with such storage must call StructType::trySetBody(...) later
@@ -850,6 +852,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   StructTypeStorage(StringRef identifier)
       : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
         numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
+        numStructDecorations(0), structDecorationsInfo(nullptr),
         identifier(identifier) {}
 
   /// Construct a storage object for a literal struct type. A struct type
@@ -857,10 +860,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   StructTypeStorage(
       unsigned numMembers, Type const *memberTypes,
       StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
-      StructType::MemberDecorationInfo const *memberDecorationsInfo)
+      StructType::MemberDecorationInfo const *memberDecorationsInfo,
+      unsigned numStructDecorations,
+      StructType::StructDecorationInfo const *structDecorationsInfo)
       : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
         numMembers(numMembers), numMemberDecorations(numMemberDecorations),
-        memberDecorationsInfo(memberDecorationsInfo) {}
+        memberDecorationsInfo(memberDecorationsInfo),
+        numStructDecorations(numStructDecorations),
+        structDecorationsInfo(structDecorationsInfo) {}
 
   /// A storage key is divided into 2 parts:
   /// - for identified structs:
@@ -869,16 +876,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   ///   - an ArrayRef<Type> for member types;
   ///   - an ArrayRef<StructType::OffsetInfo> for member offset info;
   ///   - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
+  ///     info;
+  ///   - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
   ///     info.
   ///
   /// An identified struct type is uniqued only by the first part (field 0)
   /// of the key.
   ///
-  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
-  /// 3) of the key. The identifier field (field 0) must be empty.
+  /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
+  /// and 4) of the key. The identifier field (field 0) must be empty.
   using KeyTy =
       std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
-                 ArrayRef<StructType::MemberDecorationInfo>>;
+                 ArrayRef<StructType::MemberDecorationInfo>,
+                 ArrayRef<StructType::StructDecorationInfo>>;
 
   /// For identified structs, return true if the given key contains the same
   /// identifier.
@@ -892,7 +902,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
     }
 
     return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
-                        getMemberDecorationsInfo());
+                        getMemberDecorationsInfo(), getStructDecorationsInfo());
   }
 
   /// If the given key contains a non-empty identifier, this method constructs
@@ -939,9 +949,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
       memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
     }
 
-    return new (allocator.allocate<StructTypeStorage>())
-        StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
-                          numMemberDecorations, memberDecorationList);
+    const StructType::StructDecorationInfo *structDecorationList = nullptr;
+    unsigned numStructDecorations = 0;
+    if (!std::get<4>(key).empty()) {
+      auto keyStructDecorations = std::get<4>(key);
+      numStructDecorations = keyStructDecorations.size();
+      structDecorationList = allocator.copyInto(keyStructDecorations).data();
+    }
+
+    return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
+        keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
+        memberDecorationList, numStructDecorations, structDecorationList);
   }
 
   ArrayRef<Type> getMemberTypes() const {
@@ -963,6 +981,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
     return {};
   }
 
+  ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const {
+    if (structDecorationsInfo)
+      return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
+                                                        numStructDecorations);
+    return {};
+  }
+
   StringRef getIdentifier() const { return identifier; }
 
   bool isIdentified() const { return !identifier.empty(); }
@@ -975,17 +1000,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   /// - If called for an identified struct whose body was set before (through a
   /// call to this method) but with different contents from the passed
   /// arguments.
-  LogicalResult mutate(
-      TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
-      ArrayRef<StructType::OffsetInfo> structOffsetInfo,
-      ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
+  LogicalResult
+  mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
+         ArrayRef<StructType::OffsetInfo> structOffsetInfo,
+         ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
+         ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
     if (!isIdentified())
       return failure();
 
     if (memberTypesAndIsBodySet.getInt() &&
         (getMemberTypes() != structMemberTypes ||
          getOffsetInfo() != structOffsetInfo ||
-         getMemberDecorationsInfo() != structMemberDecorationInfo))
+         getMemberDecorationsInfo() != structMemberDecorationInfo ||
+         getStructDecorationsInfo() != structDecorationInfo))
       return failure();
 
     memberTypesAndIsBodySet.setInt(true);
@@ -1009,6 +1036,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
           allocator.copyInto(structMemberDecorationInfo).data();
     }
 
+    if (!structDecorationInfo.empty()) {
+      numStructDecorations = structDecorationInfo.size();
+      structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
+    }
+
     return success();
   }
 
@@ -1017,21 +1049,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   unsigned numMembers;
   unsigned numMemberDecorations;
   StructType::MemberDecorationInfo const *memberDecorationsInfo;
+  unsigned numStructDecorations;
+  StructType::StructDecorationInfo const *structDecorationsInfo;
   StringRef identifier;
 };
 
 StructType
 StructType::get(ArrayRef<Type> memberTypes,
                 ArrayRef<StructType::OffsetInfo> offsetInfo,
-                ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
+                ArrayRef<StructType::MemberDecorationInfo> memberDecorations,
+                ArrayRef<StructType::StructDecorationInfo> structDecorations) {
   assert(!memberTypes.empty() && "Struct needs at least one member type");
   // Sort the decorations.
-  SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
+  SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
       memberDecorations);
-  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
+  llvm::array_pod_sort(sortedMemberDecorations.begin(),
+                       sortedMemberDecorations.end());
+  SmallVector<StructType::StructDecorationInfo, 4> sortedStructDecorations(
+      structDecorations);
+  llvm::array_pod_sort(sortedStructDecorations.begin(),
+                       sortedStructDecorations.end());
+
   return Base::get(memberTypes.vec().front().getContext(),
                    /*identifier=*/StringRef(), memberTypes, offsetInfo,
-                   sortedDecorations);
+                   sortedMemberDecorations, sortedStructDecorations);
 }
 
 StructType StructType::getIdentified(MLIRContext *context,
@@ -1041,18 +1082,21 @@ StructType StructType::getIdentified(MLIRContext *context,
 
   return Base::get(context, identifier, ArrayRef<Type>(),
                    ArrayRef<StructType::OffsetInfo>(),
-                   ArrayRef<StructType::MemberDecorationInfo>());
+                   ArrayRef<StructType::MemberDecorationInfo>(),
+                   ArrayRef<StructType::StructDecorationInfo>());
 }
 
 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
   StructType newStructType = Base::get(
       context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
-      ArrayRef<StructType::MemberDecorationInfo>());
+      ArrayRef<StructType::MemberDecorationInfo>(),
+      ArrayRef<StructType::StructDecorationInfo>());
   // Set an empty body in case this is a identified struct.
   if (newStructType.isIdentified() &&
       failed(newStructType.trySetBody(
           ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
-          ArrayRef<StructType::MemberDecorationInfo>())))
+          ArrayRef<StructType::MemberDecorationInfo>(),
+          ArrayRef<StructType::StructDecorationInfo>())))
     return StructType();
 
   return newStructType;
@@ -1076,6 +1120,15 @@ TypeRange StructType::getElementTypes() const {
 
 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
 
+bool StructType::hasDecoration(spirv::Decoration decoration) const {
+  for (StructType::StructDecorationInfo info :
+       getImpl()->getStructDecorationsInfo())
+    if (info.decoration == decoration)
+      return true;
+
+  return false;
+}
+
 uint64_t StructType::getMemberOffset(unsigned index) const {
   assert(getNumElements() > index && "member index out of range");
   return getImpl()->offsetInfo[index];
@@ -1107,11 +1160,21 @@ void StructType::getMemberDecorations(
   }
 }
 
+void StructType::getStructDecorations(
+    SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations)
+    const {
+  structDecorations.clear();
+  auto implDecorations = getImpl()->getStructDecorationsInfo();
+  structDecorations.append(implDecorations.begin(), implDecorations.end());
+}
+
 LogicalResult
 StructType::trySetBody(ArrayRef<Type> memberTypes,
                        ArrayRef<OffsetInfo> offsetInfo,
-                       ArrayRef<MemberDecorationInfo> memberDecorations) {
-  return Base::mutate(memberTypes, offsetInfo, memberDecorations);
+                       ArrayRef<MemberDecorationInfo> memberDecorations,
+                       ArrayRef<StructDecorationInfo> structDecorations) {
+  return Base::mutate(memberTypes, offsetInfo, memberDecorations,
+                      structDecorations);
 }
 
 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1133,6 +1196,11 @@ llvm::hash_code spirv::hash_value(
                             memberDecorationInfo.decoration);
 }
 
+llvm::hash_code spirv::hash_value(
+    const StructType::StructDecorationInfo &structDecorationInfo) {
+  return llvm::hash_value(structDecorationInfo.decoration);
+}
+
 //===----------------------------------------------------------------------===//
 // MatrixType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d133d0332e271..c8aa67c8c3b0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";
     }
-    // Block decoration does not affect spirv.struct type, but is still stored
-    // for verification.
-    // TODO: Update StructType to contain this information since
-    // it is needed for many validation rules.
     decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
     break;
   case spirv::Decoration::Location:
@@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
 
       if (failed(structType.trySetBody(
               deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
-              deferredStructIt->memberDecorationsInfo)))
+              deferredStructIt->memberDecorationsInfo,
+              deferredStructIt->structDecorationsInfo)))
         return failure();
 
       deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
@@ -1202,24 +1199,39 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
     }
   }
 
+  SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
+  i...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/149793


More information about the Mlir-commits mailing list