[Mlir-commits] [mlir] [mlir][spirv] Add support for structs decorations (PR #149793)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 21 03:51:35 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-spirv
Author: Igor Wodiany (IgWod-IMG)
<details>
<summary>Changes</summary>
An alternative implementation could use an `ArrayRef` of `NamedAttribute`s or a `NamedAttrList` to store struct decorations, as the deserializer uses `NamedAttribute`s for decorations. However, using a custom struct allows us to store the `spirv::Decoration` values directly rather than their names in a `StringRef`/`StringAttr`.
---
Patch is 38.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149793.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+37-2)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+56-5)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+93-25)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+23-11)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+1)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+24-4)
- (modified) mlir/test/Dialect/SPIRV/IR/types.mlir (+6)
- (modified) mlir/test/Target/SPIRV/memory-ops.mlir (+10-10)
- (modified) mlir/test/Target/SPIRV/struct.mlir (+20-18)
- (modified) mlir/test/Target/SPIRV/undef.mlir (+3-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 212cba61d396c..56d09301345f9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -327,10 +327,33 @@ class StructType
}
};
+ // Type for specifying the decoration(s) on the struct itself
+ struct StructDecorationInfo {
+ bool hasValue;
+ Decoration decoration;
+ Attribute decorationValue;
+
+ StructDecorationInfo(bool hasValue, Decoration decoration,
+ Attribute decorationValue)
+ : hasValue(hasValue), decoration(decoration),
+ decorationValue(decorationValue) {}
+
+ bool operator==(const StructDecorationInfo &other) const {
+ return (this->decoration == other.decoration) &&
+ (this->decorationValue == other.decorationValue);
+ }
+
+ bool operator<(const StructDecorationInfo &other) const {
+ return static_cast<uint32_t>(this->decoration) <
+ static_cast<uint32_t>(other.decoration);
+ }
+ };
+
/// Construct a literal StructType with at least one member.
static StructType get(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo = {},
- ArrayRef<MemberDecorationInfo> memberDecorations = {});
+ ArrayRef<MemberDecorationInfo> memberDecorations = {},
+ ArrayRef<StructDecorationInfo> structDecorations = {});
/// Construct an identified StructType. This creates a StructType whose body
/// (member types, offset info, and decorations) is not set yet. A call to
@@ -364,6 +387,9 @@ class StructType
bool hasOffset() const;
+ /// Returns true if the struct has a specified decoration.
+ bool hasDecoration(spirv::Decoration decoration) const;
+
uint64_t getMemberOffset(unsigned) const;
// Returns in `memberDecorations` the Decorations (apart from Offset)
@@ -377,12 +403,18 @@ class StructType
unsigned i,
SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
+ // Returns in `structDecorations` the Decorations associated with the
+ // StructType.
+ void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo>
+ &structDecorations) const;
+
/// Sets the contents of an incomplete identified StructType. This method must
/// be called only for identified StructTypes and it must be called only once
/// per instance. Otherwise, failure() is returned.
LogicalResult
trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
- ArrayRef<MemberDecorationInfo> memberDecorations = {});
+ ArrayRef<MemberDecorationInfo> memberDecorations = {},
+ ArrayRef<StructDecorationInfo> structDecorations = {});
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
@@ -393,6 +425,9 @@ class StructType
llvm::hash_code
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
+llvm::hash_code
+hash_value(const StructType::StructDecorationInfo &structDecorationInfo);
+
// SPIR-V KHR cooperative matrix type
class CooperativeMatrixType
: public Type::TypeBase<CooperativeMatrixType, CompositeType,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index f32c53b8f0b9e..6121fef7318bb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations(
// `!spirv.struct<` (id `,`)?
// `(`
// (spirv-type (`[` struct-member-decoration `]`)?)*
-// `)>`
+// `)`
+// (`,` struct-decoration)?
+// `>`
static Type parseStructType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
// TODO: This function is quite lengthy. Break it down into smaller chunks.
@@ -767,17 +769,50 @@ static Type parseStructType(SPIRVDialect const &dialect,
return Type();
}
- if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
+ if (failed(parser.parseRParen()))
+ return Type();
+
+ SmallVector<StructType::StructDecorationInfo, 0> structDecorationInfo;
+
+ auto parseStructDecoration = [&]() {
+ std::optional<spirv::Decoration> decoration =
+ parseAndVerify<spirv::Decoration>(dialect, parser);
+ if (!decoration)
+ return failure();
+
+ // Parse decoration value if it exists.
+ if (succeeded(parser.parseOptionalEqual())) {
+ Attribute decorationValue;
+
+ if (failed(parser.parseAttribute(decorationValue)))
+ return failure();
+
+ structDecorationInfo.emplace_back(true, decoration.value(),
+ decorationValue);
+ } else {
+ structDecorationInfo.emplace_back(false, decoration.value(),
+ UnitAttr::get(dialect.getContext()));
+ }
+ return success();
+ };
+
+ while (succeeded(parser.parseOptionalComma()))
+ if (failed(parseStructDecoration()))
+ return Type();
+
+ if (failed(parser.parseGreater()))
return Type();
if (!identifier.empty()) {
if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationInfo)))
+ memberDecorationInfo,
+ structDecorationInfo)))
return Type();
return idStructTy;
}
- return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+ return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
+ structDecorationInfo);
}
// spirv-type ::= array-type
@@ -892,7 +927,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
};
llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
printMember);
- os << ")>";
+ os << ")";
+
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> decorations;
+ type.getStructDecorations(decorations);
+ if (!decorations.empty()) {
+ os << ", ";
+ auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
+ os << stringifyDecoration(decoration.decoration);
+ if (decoration.hasValue) {
+ os << "=";
+ os.printAttributeWithoutType(decoration.decorationValue);
+ }
+ };
+ llvm::interleaveComma(decorations, os, eachFn);
+ }
+
+ os << ">";
}
static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 3799abd6fc743..4bb06b349d040 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -837,12 +837,14 @@ void SampledImageType::getCapabilities(
/// - for literal structs:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
///
/// Identified structures only have a mutable component consisting of:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
struct spirv::detail::StructTypeStorage : public TypeStorage {
/// Construct a storage object for an identified struct type. A struct type
/// associated with such storage must call StructType::trySetBody(...) later
@@ -850,6 +852,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(StringRef identifier)
: memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
+ numStructDecorations(0), structDecorationsInfo(nullptr),
identifier(identifier) {}
/// Construct a storage object for a literal struct type. A struct type
@@ -857,10 +860,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(
unsigned numMembers, Type const *memberTypes,
StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
- StructType::MemberDecorationInfo const *memberDecorationsInfo)
+ StructType::MemberDecorationInfo const *memberDecorationsInfo,
+ unsigned numStructDecorations,
+ StructType::StructDecorationInfo const *structDecorationsInfo)
: memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
numMembers(numMembers), numMemberDecorations(numMemberDecorations),
- memberDecorationsInfo(memberDecorationsInfo) {}
+ memberDecorationsInfo(memberDecorationsInfo),
+ numStructDecorations(numStructDecorations),
+ structDecorationsInfo(structDecorationsInfo) {}
/// A storage key is divided into 2 parts:
/// - for identified structs:
@@ -869,16 +876,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - an ArrayRef<Type> for member types;
/// - an ArrayRef<StructType::OffsetInfo> for member offset info;
/// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
+ /// info;
+ /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
/// info.
///
/// An identified struct type is uniqued only by the first part (field 0)
/// of the key.
///
- /// A literal struct type is uniqued only by the second part (fields 1, 2, and
- /// 3) of the key. The identifier field (field 0) must be empty.
+ /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
+ /// and 4) of the key. The identifier field (field 0) must be empty.
using KeyTy =
std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
- ArrayRef<StructType::MemberDecorationInfo>>;
+ ArrayRef<StructType::MemberDecorationInfo>,
+ ArrayRef<StructType::StructDecorationInfo>>;
/// For identified structs, return true if the given key contains the same
/// identifier.
@@ -892,7 +902,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
}
return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
- getMemberDecorationsInfo());
+ getMemberDecorationsInfo(), getStructDecorationsInfo());
}
/// If the given key contains a non-empty identifier, this method constructs
@@ -939,9 +949,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
}
- return new (allocator.allocate<StructTypeStorage>())
- StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
- numMemberDecorations, memberDecorationList);
+ const StructType::StructDecorationInfo *structDecorationList = nullptr;
+ unsigned numStructDecorations = 0;
+ if (!std::get<4>(key).empty()) {
+ auto keyStructDecorations = std::get<4>(key);
+ numStructDecorations = keyStructDecorations.size();
+ structDecorationList = allocator.copyInto(keyStructDecorations).data();
+ }
+
+ return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
+ keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
+ memberDecorationList, numStructDecorations, structDecorationList);
}
ArrayRef<Type> getMemberTypes() const {
@@ -963,6 +981,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
return {};
}
+ ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const {
+ if (structDecorationsInfo)
+ return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
+ numStructDecorations);
+ return {};
+ }
+
StringRef getIdentifier() const { return identifier; }
bool isIdentified() const { return !identifier.empty(); }
@@ -975,17 +1000,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - If called for an identified struct whose body was set before (through a
/// call to this method) but with different contents from the passed
/// arguments.
- LogicalResult mutate(
- TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
- ArrayRef<StructType::OffsetInfo> structOffsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
+ LogicalResult
+ mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
+ ArrayRef<StructType::OffsetInfo> structOffsetInfo,
+ ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
+ ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
if (!isIdentified())
return failure();
if (memberTypesAndIsBodySet.getInt() &&
(getMemberTypes() != structMemberTypes ||
getOffsetInfo() != structOffsetInfo ||
- getMemberDecorationsInfo() != structMemberDecorationInfo))
+ getMemberDecorationsInfo() != structMemberDecorationInfo ||
+ getStructDecorationsInfo() != structDecorationInfo))
return failure();
memberTypesAndIsBodySet.setInt(true);
@@ -1009,6 +1036,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
allocator.copyInto(structMemberDecorationInfo).data();
}
+ if (!structDecorationInfo.empty()) {
+ numStructDecorations = structDecorationInfo.size();
+ structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
+ }
+
return success();
}
@@ -1017,21 +1049,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
unsigned numMembers;
unsigned numMemberDecorations;
StructType::MemberDecorationInfo const *memberDecorationsInfo;
+ unsigned numStructDecorations;
+ StructType::StructDecorationInfo const *structDecorationsInfo;
StringRef identifier;
};
StructType
StructType::get(ArrayRef<Type> memberTypes,
ArrayRef<StructType::OffsetInfo> offsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
+ ArrayRef<StructType::MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructType::StructDecorationInfo> structDecorations) {
assert(!memberTypes.empty() && "Struct needs at least one member type");
// Sort the decorations.
- SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
+ SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
memberDecorations);
- llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
+ llvm::array_pod_sort(sortedMemberDecorations.begin(),
+ sortedMemberDecorations.end());
+ SmallVector<StructType::StructDecorationInfo, 4> sortedStructDecorations(
+ structDecorations);
+ llvm::array_pod_sort(sortedStructDecorations.begin(),
+ sortedStructDecorations.end());
+
return Base::get(memberTypes.vec().front().getContext(),
/*identifier=*/StringRef(), memberTypes, offsetInfo,
- sortedDecorations);
+ sortedMemberDecorations, sortedStructDecorations);
}
StructType StructType::getIdentified(MLIRContext *context,
@@ -1041,18 +1082,21 @@ StructType StructType::getIdentified(MLIRContext *context,
return Base::get(context, identifier, ArrayRef<Type>(),
ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
}
StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
StructType newStructType = Base::get(
context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
// Set an empty body in case this is a identified struct.
if (newStructType.isIdentified() &&
failed(newStructType.trySetBody(
ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>())))
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>())))
return StructType();
return newStructType;
@@ -1076,6 +1120,15 @@ TypeRange StructType::getElementTypes() const {
bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
+bool StructType::hasDecoration(spirv::Decoration decoration) const {
+ for (StructType::StructDecorationInfo info :
+ getImpl()->getStructDecorationsInfo())
+ if (info.decoration == decoration)
+ return true;
+
+ return false;
+}
+
uint64_t StructType::getMemberOffset(unsigned index) const {
assert(getNumElements() > index && "member index out of range");
return getImpl()->offsetInfo[index];
@@ -1107,11 +1160,21 @@ void StructType::getMemberDecorations(
}
}
+void StructType::getStructDecorations(
+ SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations)
+ const {
+ structDecorations.clear();
+ auto implDecorations = getImpl()->getStructDecorationsInfo();
+ structDecorations.append(implDecorations.begin(), implDecorations.end());
+}
+
LogicalResult
StructType::trySetBody(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo,
- ArrayRef<MemberDecorationInfo> memberDecorations) {
- return Base::mutate(memberTypes, offsetInfo, memberDecorations);
+ ArrayRef<MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructDecorationInfo> structDecorations) {
+ return Base::mutate(memberTypes, offsetInfo, memberDecorations,
+ structDecorations);
}
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1133,6 +1196,11 @@ llvm::hash_code spirv::hash_value(
memberDecorationInfo.decoration);
}
+llvm::hash_code spirv::hash_value(
+ const StructType::StructDecorationInfo &structDecorationInfo) {
+ return llvm::hash_value(structDecorationInfo.decoration);
+}
+
//===----------------------------------------------------------------------===//
// MatrixType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d133d0332e271..c8aa67c8c3b0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
}
- // Block decoration does not affect spirv.struct type, but is still stored
- // for verification.
- // TODO: Update StructType to contain this information since
- // it is needed for many validation rules.
decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
break;
case spirv::Decoration::Location:
@@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
if (failed(structType.trySetBody(
deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
- deferredStructIt->memberDecorationsInfo)))
+ deferredStructIt->memberDecorationsInfo,
+ deferredStructIt->structDecorationsInfo)))
return failure();
deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
@@ -1202,24 +1199,39 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
+ i...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/149793
More information about the Mlir-commits
mailing list