[Mlir-commits] [mlir] [mlir][spirv][nfc] Refactor member decorations in StructType (PR #150218)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 23 06:21:32 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Igor Wodiany (IgWod-IMG)
<details>
<summary>Changes</summary>
This patch makes `==` and `<` for `MemberDecorationInfo` friend functions and removes the `hasValue` field. `decorationValue` is also changed to an `mlir::Attribute`, so `UnitAttr` can be used to represent the absence of a value. This is consistent with how `OpDecorate` is handled in the deserializer. Using `Attribute` will also enable handling non-integer values; however, there currently seem to be no such decorations for struct members.
---
Full diff: https://github.com/llvm/llvm-project/pull/150218.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+21-16)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+9-10)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+6-5)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+7-4)
- (modified) mlir/test/Dialect/SPIRV/IR/types.mlir (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 212cba61d396c..d230620e0d57b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -301,30 +301,35 @@ class StructType
static constexpr StringLiteral name = "spirv.struct";
- // Type for specifying the decoration(s) on struct members
+ // Type for specifying the decoration(s) on struct members.
+ // If `decorationValue` is UnitAttr then decoration has no
+ // value.
struct MemberDecorationInfo {
- uint32_t memberIndex : 31;
- uint32_t hasValue : 1;
+ uint32_t memberIndex;
Decoration decoration;
- uint32_t decorationValue;
+ Attribute decorationValue;
- MemberDecorationInfo(uint32_t index, uint32_t hasValue,
- Decoration decoration, uint32_t decorationValue)
- : memberIndex(index), hasValue(hasValue), decoration(decoration),
+ MemberDecorationInfo(uint32_t index, Decoration decoration,
+ Attribute decorationValue)
+ : memberIndex(index), decoration(decoration),
decorationValue(decorationValue) {}
- bool operator==(const MemberDecorationInfo &other) const {
- return (this->memberIndex == other.memberIndex) &&
- (this->decoration == other.decoration) &&
- (this->decorationValue == other.decorationValue);
+ friend bool operator==(const MemberDecorationInfo &lhs,
+ const MemberDecorationInfo &rhs) {
+ return (lhs.memberIndex == rhs.memberIndex) &&
+ (lhs.decoration == rhs.decoration) &&
+ (lhs.decorationValue == rhs.decorationValue);
}
- bool operator<(const MemberDecorationInfo &other) const {
- return this->memberIndex < other.memberIndex ||
- (this->memberIndex == other.memberIndex &&
- static_cast<uint32_t>(this->decoration) <
- static_cast<uint32_t>(other.decoration));
+ friend bool operator<(const MemberDecorationInfo &lhs,
+ const MemberDecorationInfo &rhs) {
+ return lhs.memberIndex < rhs.memberIndex ||
+ (lhs.memberIndex == rhs.memberIndex &&
+ llvm::to_underlying(lhs.decoration) <
+ llvm::to_underlying(rhs.decoration));
}
+
+ bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
};
/// Construct a literal StructType with at least one member.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index f32c53b8f0b9e..4b11395359ed4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -665,19 +665,17 @@ static ParseResult parseStructMemberDecorations(
// Parse member decoration value if it exists.
if (succeeded(parser.parseOptionalEqual())) {
- auto memberDecorationValue =
- parseAndVerifyInteger<uint32_t>(dialect, parser);
-
- if (!memberDecorationValue)
+ Attribute memberDecorationValue;
+ if (failed(parser.parseAttribute(memberDecorationValue)))
return failure();
memberDecorationInfo.emplace_back(
- static_cast<uint32_t>(memberTypes.size() - 1), 1,
- memberDecoration.value(), memberDecorationValue.value());
+ static_cast<uint32_t>(memberTypes.size() - 1),
+ memberDecoration.value(), memberDecorationValue);
} else {
memberDecorationInfo.emplace_back(
- static_cast<uint32_t>(memberTypes.size() - 1), 0,
- memberDecoration.value(), 0);
+ static_cast<uint32_t>(memberTypes.size() - 1),
+ memberDecoration.value(), UnitAttr::get(dialect.getContext()));
}
return success();
};
@@ -882,8 +880,9 @@ static void print(StructType type, DialectAsmPrinter &os) {
}
auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
os << stringifyDecoration(decoration.decoration);
- if (decoration.hasValue) {
- os << "=" << decoration.decorationValue;
+ if (decoration.hasValue()) {
+ os << "=";
+ os.printAttributeWithoutType(decoration.decorationValue);
}
};
llvm::interleaveComma(decorations, os, eachFn);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d133d0332e271..9e369962f7275 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1188,13 +1188,14 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
}
offsetInfo[memberIndex] = memberDecoration.second[0];
} else {
+ auto intType = mlir::IntegerType::get(context, 32);
if (!memberDecoration.second.empty()) {
- memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
- memberDecoration.first,
- memberDecoration.second[0]);
+ memberDecorationsInfo.emplace_back(
+ memberIndex, memberDecoration.first,
+ IntegerAttr::get(intType, memberDecoration.second[0]));
} else {
- memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
- memberDecoration.first, 0);
+ memberDecorationsInfo.emplace_back(
+ memberIndex, memberDecoration.first, UnitAttr::get(context));
}
}
}
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 3400fcf6374e8..22bdf11cdb9cb 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -406,8 +406,9 @@ LogicalResult Serializer::processMemberDecoration(
SmallVector<uint32_t, 4> args(
{structID, memberDecoration.memberIndex,
static_cast<uint32_t>(memberDecoration.decoration)});
- if (memberDecoration.hasValue) {
- args.push_back(memberDecoration.decorationValue);
+ if (memberDecoration.hasValue()) {
+ args.push_back(
+ cast<mlir::IntegerAttr>(memberDecoration.decorationValue).getInt());
}
encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
return success();
@@ -666,10 +667,12 @@ LogicalResult Serializer::prepareBasicType(
}
operands.push_back(elementTypeID);
if (hasOffset) {
+ auto intType = IntegerType::get(structType.getContext(), 32);
// Decorate each struct member with an offset
spirv::StructType::MemberDecorationInfo offsetDecoration{
- elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
- static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
+ elementIndex, spirv::Decoration::Offset,
+ IntegerAttr::get(intType,
+ structType.getMemberOffset(elementIndex))};
if (failed(processMemberDecoration(resultID, offsetDecoration))) {
return emitError(loc, "cannot decorate ")
<< elementIndex << "-th member of " << structType
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 7d45b5ea82643..5d05a65414969 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -353,7 +353,7 @@ func.func private @struct_type_missing_comma(!spirv.struct<(!spirv.matrix<3 x ve
// -----
-// expected-error @+1 {{expected integer value}}
+// expected-error @+1 {{expected attribute value}}
func.func private @struct_missing_member_decorator_value(!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=])>)
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/150218
More information about the Mlir-commits
mailing list