[Mlir-commits] [mlir] [mlir][spirv][nfc] Refactor member decorations in StructType (PR #150218)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 23 06:21:32 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Igor Wodiany (IgWod-IMG)

<details>
<summary>Changes</summary>

This patch makes the `==` and `<` operators of `MemberDecorationInfo` friend functions and removes the `hasValue` field, replacing it with a `hasValue()` method. `decorationValue` is also changed to an `mlir::Attribute`, so that `UnitAttr` can be used to represent the absence of a value. This is consistent with how `OpDecorate` is handled in the deserializer. Using `Attribute` will also enable handling non-integer values; however, there currently appear to be no such decorations for struct members.

---
Full diff: https://github.com/llvm/llvm-project/pull/150218.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+21-16) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+9-10) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+6-5) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+7-4) 
- (modified) mlir/test/Dialect/SPIRV/IR/types.mlir (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 212cba61d396c..d230620e0d57b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -301,30 +301,35 @@ class StructType
 
   static constexpr StringLiteral name = "spirv.struct";
 
-  // Type for specifying the decoration(s) on struct members
+  // Type for specifying the decoration(s) on struct members.
+  // If `decorationValue` is UnitAttr then decoration has no
+  // value.
   struct MemberDecorationInfo {
-    uint32_t memberIndex : 31;
-    uint32_t hasValue : 1;
+    uint32_t memberIndex;
     Decoration decoration;
-    uint32_t decorationValue;
+    Attribute decorationValue;
 
-    MemberDecorationInfo(uint32_t index, uint32_t hasValue,
-                         Decoration decoration, uint32_t decorationValue)
-        : memberIndex(index), hasValue(hasValue), decoration(decoration),
+    MemberDecorationInfo(uint32_t index, Decoration decoration,
+                         Attribute decorationValue)
+        : memberIndex(index), decoration(decoration),
           decorationValue(decorationValue) {}
 
-    bool operator==(const MemberDecorationInfo &other) const {
-      return (this->memberIndex == other.memberIndex) &&
-             (this->decoration == other.decoration) &&
-             (this->decorationValue == other.decorationValue);
+    friend bool operator==(const MemberDecorationInfo &lhs,
+                           const MemberDecorationInfo &rhs) {
+      return (lhs.memberIndex == rhs.memberIndex) &&
+             (lhs.decoration == rhs.decoration) &&
+             (lhs.decorationValue == rhs.decorationValue);
     }
 
-    bool operator<(const MemberDecorationInfo &other) const {
-      return this->memberIndex < other.memberIndex ||
-             (this->memberIndex == other.memberIndex &&
-              static_cast<uint32_t>(this->decoration) <
-                  static_cast<uint32_t>(other.decoration));
+    friend bool operator<(const MemberDecorationInfo &lhs,
+                          const MemberDecorationInfo &rhs) {
+      return lhs.memberIndex < rhs.memberIndex ||
+             (lhs.memberIndex == rhs.memberIndex &&
+              llvm::to_underlying(lhs.decoration) <
+                  llvm::to_underlying(rhs.decoration));
     }
+
+    bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
   };
 
   /// Construct a literal StructType with at least one member.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index f32c53b8f0b9e..4b11395359ed4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -665,19 +665,17 @@ static ParseResult parseStructMemberDecorations(
 
     // Parse member decoration value if it exists.
     if (succeeded(parser.parseOptionalEqual())) {
-      auto memberDecorationValue =
-          parseAndVerifyInteger<uint32_t>(dialect, parser);
-
-      if (!memberDecorationValue)
+      Attribute memberDecorationValue;
+      if (failed(parser.parseAttribute(memberDecorationValue)))
         return failure();
 
       memberDecorationInfo.emplace_back(
-          static_cast<uint32_t>(memberTypes.size() - 1), 1,
-          memberDecoration.value(), memberDecorationValue.value());
+          static_cast<uint32_t>(memberTypes.size() - 1),
+          memberDecoration.value(), memberDecorationValue);
     } else {
       memberDecorationInfo.emplace_back(
-          static_cast<uint32_t>(memberTypes.size() - 1), 0,
-          memberDecoration.value(), 0);
+          static_cast<uint32_t>(memberTypes.size() - 1),
+          memberDecoration.value(), UnitAttr::get(dialect.getContext()));
     }
     return success();
   };
@@ -882,8 +880,9 @@ static void print(StructType type, DialectAsmPrinter &os) {
       }
       auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
         os << stringifyDecoration(decoration.decoration);
-        if (decoration.hasValue) {
-          os << "=" << decoration.decorationValue;
+        if (decoration.hasValue()) {
+          os << "=";
+          os.printAttributeWithoutType(decoration.decorationValue);
         }
       };
       llvm::interleaveComma(decorations, os, eachFn);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d133d0332e271..9e369962f7275 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1188,13 +1188,14 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
             }
             offsetInfo[memberIndex] = memberDecoration.second[0];
           } else {
+            auto intType = mlir::IntegerType::get(context, 32);
             if (!memberDecoration.second.empty()) {
-              memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
-                                                 memberDecoration.first,
-                                                 memberDecoration.second[0]);
+              memberDecorationsInfo.emplace_back(
+                  memberIndex, memberDecoration.first,
+                  IntegerAttr::get(intType, memberDecoration.second[0]));
             } else {
-              memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
-                                                 memberDecoration.first, 0);
+              memberDecorationsInfo.emplace_back(
+                  memberIndex, memberDecoration.first, UnitAttr::get(context));
             }
           }
         }
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 3400fcf6374e8..22bdf11cdb9cb 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -406,8 +406,9 @@ LogicalResult Serializer::processMemberDecoration(
   SmallVector<uint32_t, 4> args(
       {structID, memberDecoration.memberIndex,
        static_cast<uint32_t>(memberDecoration.decoration)});
-  if (memberDecoration.hasValue) {
-    args.push_back(memberDecoration.decorationValue);
+  if (memberDecoration.hasValue()) {
+    args.push_back(
+        cast<mlir::IntegerAttr>(memberDecoration.decorationValue).getInt());
   }
   encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
   return success();
@@ -666,10 +667,12 @@ LogicalResult Serializer::prepareBasicType(
       }
       operands.push_back(elementTypeID);
       if (hasOffset) {
+        auto intType = IntegerType::get(structType.getContext(), 32);
         // Decorate each struct member with an offset
         spirv::StructType::MemberDecorationInfo offsetDecoration{
-            elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
-            static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
+            elementIndex, spirv::Decoration::Offset,
+            IntegerAttr::get(intType,
+                             structType.getMemberOffset(elementIndex))};
         if (failed(processMemberDecoration(resultID, offsetDecoration))) {
           return emitError(loc, "cannot decorate ")
                  << elementIndex << "-th member of " << structType
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 7d45b5ea82643..5d05a65414969 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -353,7 +353,7 @@ func.func private @struct_type_missing_comma(!spirv.struct<(!spirv.matrix<3 x ve
 
 // -----
 
-// expected-error @+1 {{expected integer value}}
+// expected-error @+1 {{expected attribute value}}
 func.func private @struct_missing_member_decorator_value(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=])>)
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/150218


More information about the Mlir-commits mailing list