[Mlir-commits] [mlir] [mlir][spirv] Add support for structs decorations (PR #149793)
Igor Wodiany
llvmlistbot at llvm.org
Mon Jul 21 03:51:02 PDT 2025
https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/149793
An alternative implementation could use an `ArrayRef` of `NamedAttribute`s or a `NamedAttrList` to store struct decorations, since the deserializer already uses `NamedAttribute`s for decorations. However, a custom struct lets us store the `spirv::Decoration` values directly rather than their names as a `StringRef`/`StringAttr`.
>From 9d2abda552eeed2e703225fb0f9cbb5319f3cd10 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Fri, 18 Jul 2025 10:01:26 +0000
Subject: [PATCH] [mlir][spirv] Add support for structs decorations
An alternative implementation could use an ArrayRef of NamedAttributes
or a NamedAttrList to store struct decorations, since the deserializer
already uses NamedAttributes for decorations. However, a custom struct
lets us store the spirv::Decoration values directly rather than their
names as StringRefs.
---
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 39 +++++-
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 61 ++++++++-
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 118 ++++++++++++++----
.../SPIRV/Deserialization/Deserializer.cpp | 34 +++--
.../SPIRV/Deserialization/Deserializer.h | 1 +
.../Target/SPIRV/Serialization/Serializer.cpp | 28 ++++-
mlir/test/Dialect/SPIRV/IR/types.mlir | 6 +
mlir/test/Target/SPIRV/memory-ops.mlir | 20 +--
mlir/test/Target/SPIRV/struct.mlir | 38 +++---
mlir/test/Target/SPIRV/undef.mlir | 6 +-
10 files changed, 273 insertions(+), 78 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 212cba61d396c..56d09301345f9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -327,10 +327,33 @@ class StructType
}
};
+ /// A decoration attached to the struct type itself (as opposed to one of its
+ /// members), e.g. `Block`.
+ struct StructDecorationInfo {
+ /// True if the decoration was given an explicit value (`Decoration=value`
+ /// in the textual form); false for value-less decorations.
+ bool hasValue;
+ /// The decoration kind.
+ Decoration decoration;
+ /// The decoration's value; a UnitAttr for value-less decorations.
+ Attribute decorationValue;
+
+ StructDecorationInfo(bool hasValue, Decoration decoration,
+ Attribute decorationValue)
+ : hasValue(hasValue), decoration(decoration),
+ decorationValue(decorationValue) {}
+
+ /// Equal when kind, value, and presence of an explicit value all match.
+ /// `hasValue` is compared so that `Block` and `Block=unit` (both holding a
+ /// UnitAttr) are not uniqued to the same type, which would otherwise print
+ /// differently depending on which spelling was created first.
+ bool operator==(const StructDecorationInfo &other) const {
+ return (this->hasValue == other.hasValue) &&
+ (this->decoration == other.decoration) &&
+ (this->decorationValue == other.decorationValue);
+ }
+
+ /// Orders by decoration kind only; used to canonicalize decoration lists.
+ bool operator<(const StructDecorationInfo &other) const {
+ return static_cast<uint32_t>(this->decoration) <
+ static_cast<uint32_t>(other.decoration);
+ }
+ };
+
/// Construct a literal StructType with at least one member.
static StructType get(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo = {},
- ArrayRef<MemberDecorationInfo> memberDecorations = {});
+ ArrayRef<MemberDecorationInfo> memberDecorations = {},
+ ArrayRef<StructDecorationInfo> structDecorations = {});
/// Construct an identified StructType. This creates a StructType whose body
/// (member types, offset info, and decorations) is not set yet. A call to
@@ -364,6 +387,9 @@ class StructType
bool hasOffset() const;
+ /// Returns true if `decoration` is one of the decorations attached to the
+ /// struct type itself. Decorations on individual members are not consulted.
+ bool hasDecoration(spirv::Decoration decoration) const;
+
uint64_t getMemberOffset(unsigned) const;
// Returns in `memberDecorations` the Decorations (apart from Offset)
@@ -377,12 +403,18 @@ class StructType
unsigned i,
SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
+ // Returns in `structDecorations` the Decorations attached to the StructType
+ // itself (not to its members). Any previous contents of `structDecorations`
+ // are discarded.
+ void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo>
+ &structDecorations) const;
+
/// Sets the contents of an incomplete identified StructType. This method must
/// be called only for identified StructTypes and it must be called only once
/// per instance. Otherwise, failure() is returned.
LogicalResult
trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
- ArrayRef<MemberDecorationInfo> memberDecorations = {});
+ ArrayRef<MemberDecorationInfo> memberDecorations = {},
+ ArrayRef<StructDecorationInfo> structDecorations = {});
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
@@ -393,6 +425,9 @@ class StructType
llvm::hash_code
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
+/// Hashes a struct-level decoration. Only the decoration kind is hashed, which
+/// stays consistent with StructDecorationInfo::operator== because equal
+/// decorations always share the same kind.
+llvm::hash_code
+hash_value(const StructType::StructDecorationInfo &structDecorationInfo);
+
// SPIR-V KHR cooperative matrix type
class CooperativeMatrixType
: public Type::TypeBase<CooperativeMatrixType, CompositeType,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index f32c53b8f0b9e..6121fef7318bb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations(
// `!spirv.struct<` (id `,`)?
// `(`
// (spirv-type (`[` struct-member-decoration `]`)?)*
-// `)>`
+// `)`
+// (`,` struct-decoration)?
+// `>`
static Type parseStructType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
// TODO: This function is quite lengthy. Break it down into smaller chunks.
@@ -767,17 +769,50 @@ static Type parseStructType(SPIRVDialect const &dialect,
return Type();
}
- if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
+ if (failed(parser.parseRParen()))
+ return Type();
+
+ SmallVector<StructType::StructDecorationInfo, 0> structDecorationInfo;
+
+ auto parseStructDecoration = [&]() {
+ std::optional<spirv::Decoration> decoration =
+ parseAndVerify<spirv::Decoration>(dialect, parser);
+ if (!decoration)
+ return failure();
+
+ // Parse decoration value if it exists.
+ if (succeeded(parser.parseOptionalEqual())) {
+ Attribute decorationValue;
+
+ if (failed(parser.parseAttribute(decorationValue)))
+ return failure();
+
+ structDecorationInfo.emplace_back(true, decoration.value(),
+ decorationValue);
+ } else {
+ structDecorationInfo.emplace_back(false, decoration.value(),
+ UnitAttr::get(dialect.getContext()));
+ }
+ return success();
+ };
+
+ while (succeeded(parser.parseOptionalComma()))
+ if (failed(parseStructDecoration()))
+ return Type();
+
+ if (failed(parser.parseGreater()))
return Type();
if (!identifier.empty()) {
if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationInfo)))
+ memberDecorationInfo,
+ structDecorationInfo)))
return Type();
return idStructTy;
}
- return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+ return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
+ structDecorationInfo);
}
// spirv-type ::= array-type
@@ -892,7 +927,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
};
llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
printMember);
- os << ")>";
+ os << ")";
+
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> decorations;
+ type.getStructDecorations(decorations);
+ if (!decorations.empty()) {
+ os << ", ";
+ auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
+ os << stringifyDecoration(decoration.decoration);
+ if (decoration.hasValue) {
+ os << "=";
+ os.printAttributeWithoutType(decoration.decorationValue);
+ }
+ };
+ llvm::interleaveComma(decorations, os, eachFn);
+ }
+
+ os << ">";
}
static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 3799abd6fc743..4bb06b349d040 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -837,12 +837,14 @@ void SampledImageType::getCapabilities(
/// - for literal structs:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
///
/// Identified structures only have a mutable component consisting of:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
struct spirv::detail::StructTypeStorage : public TypeStorage {
/// Construct a storage object for an identified struct type. A struct type
/// associated with such storage must call StructType::trySetBody(...) later
@@ -850,6 +852,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(StringRef identifier)
: memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
+ numStructDecorations(0), structDecorationsInfo(nullptr),
identifier(identifier) {}
/// Construct a storage object for a literal struct type. A struct type
@@ -857,10 +860,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(
unsigned numMembers, Type const *memberTypes,
StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
- StructType::MemberDecorationInfo const *memberDecorationsInfo)
+ StructType::MemberDecorationInfo const *memberDecorationsInfo,
+ unsigned numStructDecorations,
+ StructType::StructDecorationInfo const *structDecorationsInfo)
: memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
numMembers(numMembers), numMemberDecorations(numMemberDecorations),
- memberDecorationsInfo(memberDecorationsInfo) {}
+ memberDecorationsInfo(memberDecorationsInfo),
+ numStructDecorations(numStructDecorations),
+ structDecorationsInfo(structDecorationsInfo) {}
/// A storage key is divided into 2 parts:
/// - for identified structs:
@@ -869,16 +876,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - an ArrayRef<Type> for member types;
/// - an ArrayRef<StructType::OffsetInfo> for member offset info;
/// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
+ /// info;
+ /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
/// info.
///
/// An identified struct type is uniqued only by the first part (field 0)
/// of the key.
///
- /// A literal struct type is uniqued only by the second part (fields 1, 2, and
- /// 3) of the key. The identifier field (field 0) must be empty.
+ /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
+ /// and 4) of the key. The identifier field (field 0) must be empty.
using KeyTy =
std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
- ArrayRef<StructType::MemberDecorationInfo>>;
+ ArrayRef<StructType::MemberDecorationInfo>,
+ ArrayRef<StructType::StructDecorationInfo>>;
/// For identified structs, return true if the given key contains the same
/// identifier.
@@ -892,7 +902,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
}
return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
- getMemberDecorationsInfo());
+ getMemberDecorationsInfo(), getStructDecorationsInfo());
}
/// If the given key contains a non-empty identifier, this method constructs
@@ -939,9 +949,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
}
- return new (allocator.allocate<StructTypeStorage>())
- StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
- numMemberDecorations, memberDecorationList);
+ const StructType::StructDecorationInfo *structDecorationList = nullptr;
+ unsigned numStructDecorations = 0;
+ if (!std::get<4>(key).empty()) {
+ auto keyStructDecorations = std::get<4>(key);
+ numStructDecorations = keyStructDecorations.size();
+ structDecorationList = allocator.copyInto(keyStructDecorations).data();
+ }
+
+ return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
+ keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
+ memberDecorationList, numStructDecorations, structDecorationList);
}
ArrayRef<Type> getMemberTypes() const {
@@ -963,6 +981,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
return {};
}
+ /// Returns the decorations attached to the struct itself, or an empty list
+ /// when none have been stored.
+ ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const {
+ if (!structDecorationsInfo)
+ return {};
+ return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
+ numStructDecorations);
+ }
+
StringRef getIdentifier() const { return identifier; }
bool isIdentified() const { return !identifier.empty(); }
@@ -975,17 +1000,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - If called for an identified struct whose body was set before (through a
/// call to this method) but with different contents from the passed
/// arguments.
- LogicalResult mutate(
- TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
- ArrayRef<StructType::OffsetInfo> structOffsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
+ LogicalResult
+ mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
+ ArrayRef<StructType::OffsetInfo> structOffsetInfo,
+ ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
+ ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
if (!isIdentified())
return failure();
if (memberTypesAndIsBodySet.getInt() &&
(getMemberTypes() != structMemberTypes ||
getOffsetInfo() != structOffsetInfo ||
- getMemberDecorationsInfo() != structMemberDecorationInfo))
+ getMemberDecorationsInfo() != structMemberDecorationInfo ||
+ getStructDecorationsInfo() != structDecorationInfo))
return failure();
memberTypesAndIsBodySet.setInt(true);
@@ -1009,6 +1036,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
allocator.copyInto(structMemberDecorationInfo).data();
}
+ if (!structDecorationInfo.empty()) {
+ numStructDecorations = structDecorationInfo.size();
+ structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
+ }
+
return success();
}
@@ -1017,21 +1049,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
unsigned numMembers;
unsigned numMemberDecorations;
StructType::MemberDecorationInfo const *memberDecorationsInfo;
+ unsigned numStructDecorations;
+ StructType::StructDecorationInfo const *structDecorationsInfo;
StringRef identifier;
};
StructType
StructType::get(ArrayRef<Type> memberTypes,
ArrayRef<StructType::OffsetInfo> offsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
+ ArrayRef<StructType::MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructType::StructDecorationInfo> structDecorations) {
assert(!memberTypes.empty() && "Struct needs at least one member type");
// Sort the decorations.
- SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
+ SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
memberDecorations);
- llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
+ llvm::array_pod_sort(sortedMemberDecorations.begin(),
+ sortedMemberDecorations.end());
+ // Sort struct-level decorations by decoration kind so the same set given in
+ // a different order produces the same uniquing key. NOTE(review):
+ // array_pod_sort is not stable, so two entries of the same kind (with
+ // different values) have no guaranteed relative order.
+ SmallVector<StructType::StructDecorationInfo, 4> sortedStructDecorations(
+ structDecorations);
+ llvm::array_pod_sort(sortedStructDecorations.begin(),
+ sortedStructDecorations.end());
+
return Base::get(memberTypes.vec().front().getContext(),
/*identifier=*/StringRef(), memberTypes, offsetInfo,
- sortedDecorations);
+ sortedMemberDecorations, sortedStructDecorations);
}
StructType StructType::getIdentified(MLIRContext *context,
@@ -1041,18 +1082,21 @@ StructType StructType::getIdentified(MLIRContext *context,
return Base::get(context, identifier, ArrayRef<Type>(),
ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
}
StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
StructType newStructType = Base::get(
context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
// Set an empty body in case this is a identified struct.
if (newStructType.isIdentified() &&
failed(newStructType.trySetBody(
ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>())))
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>())))
return StructType();
return newStructType;
@@ -1076,6 +1120,15 @@ TypeRange StructType::getElementTypes() const {
bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
+/// Scans the struct-level decorations for `decoration`; member decorations
+/// are not part of this search.
+bool StructType::hasDecoration(spirv::Decoration decoration) const {
+ ArrayRef<StructType::StructDecorationInfo> structDecorations =
+ getImpl()->getStructDecorationsInfo();
+ for (const StructType::StructDecorationInfo &info : structDecorations) {
+ if (info.decoration == decoration)
+ return true;
+ }
+ return false;
+}
+
uint64_t StructType::getMemberOffset(unsigned index) const {
assert(getNumElements() > index && "member index out of range");
return getImpl()->offsetInfo[index];
@@ -1107,11 +1160,21 @@ void StructType::getMemberDecorations(
}
}
+/// Copies the struct-level decorations into `structDecorations`, replacing
+/// whatever it previously contained.
+void StructType::getStructDecorations(
+ SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations)
+ const {
+ ArrayRef<StructType::StructDecorationInfo> stored =
+ getImpl()->getStructDecorationsInfo();
+ structDecorations.assign(stored.begin(), stored.end());
+}
+
LogicalResult
StructType::trySetBody(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo,
- ArrayRef<MemberDecorationInfo> memberDecorations) {
- return Base::mutate(memberTypes, offsetInfo, memberDecorations);
+ ArrayRef<MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructDecorationInfo> structDecorations) {
+ return Base::mutate(memberTypes, offsetInfo, memberDecorations,
+ structDecorations);
}
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1133,6 +1196,11 @@ llvm::hash_code spirv::hash_value(
memberDecorationInfo.decoration);
}
+/// Hashes only the decoration kind. This is consistent with
+/// StructDecorationInfo::operator==, since decorations that compare equal
+/// always have the same kind; entries of the same kind but different values
+/// collide here and are told apart by the equality check.
+llvm::hash_code spirv::hash_value(
+ const StructType::StructDecorationInfo &structDecorationInfo) {
+ return llvm::hash_value(structDecorationInfo.decoration);
+}
+
//===----------------------------------------------------------------------===//
// MatrixType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d133d0332e271..c8aa67c8c3b0d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
}
- // Block decoration does not affect spirv.struct type, but is still stored
- // for verification.
- // TODO: Update StructType to contain this information since
- // it is needed for many validation rules.
decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
break;
case spirv::Decoration::Location:
@@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
if (failed(structType.trySetBody(
deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
- deferredStructIt->memberDecorationsInfo)))
+ deferredStructIt->memberDecorationsInfo,
+ deferredStructIt->structDecorationsInfo)))
return failure();
deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
@@ -1202,24 +1199,39 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
}
}
+ // Collect the decorations recorded for the struct's result <id>. They are
+ // stored as snake_case attribute names and converted back to the
+ // spirv::Decoration enum here.
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
+ auto structDecorationsIt = decorations.find(operands[0]);
+ if (structDecorationsIt != decorations.end()) {
+ const NamedAttrList &allDecorations = structDecorationsIt->second;
+ for (const NamedAttribute &decorationAttr : allDecorations) {
+ std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
+ llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
+ // A name that does not map back to a Decoration previously hit an assert
+ // (and crashed on .value() in release builds); report it instead.
+ if (!decoration)
+ return emitError(unknownLoc, "unsupported struct decoration: ")
+ << decorationAttr.getName();
+ bool hasValue =
+ (decorationAttr.getValue() != mlir::UnitAttr::get(context));
+ structDecorationsInfo.emplace_back(hasValue, *decoration,
+ decorationAttr.getValue());
+ }
+ }
+
uint32_t structID = operands[0];
std::string structIdentifier = nameMap.lookup(structID).str();
if (structIdentifier.empty()) {
assert(unresolvedMemberTypes.empty() &&
"didn't expect unresolved member types");
- typeMap[structID] =
- spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
+ typeMap[structID] = spirv::StructType::get(
+ memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
} else {
auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
typeMap[structID] = structTy;
if (!unresolvedMemberTypes.empty())
- deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
- memberTypes, offsetInfo,
- memberDecorationsInfo});
+ deferredStructTypesInfos.push_back(
+ {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
+ memberDecorationsInfo, structDecorationsInfo});
else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationsInfo)))
+ memberDecorationsInfo,
+ structDecorationsInfo)))
return failure();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 20482bd2bf501..db1cc3f8d79c2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -95,6 +95,7 @@ struct DeferredStructTypeInfo {
SmallVector<Type, 4> memberTypes;
SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
};
/// A struct that collects the info needed to materialize/emit a
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 3400fcf6374e8..70fce56523e90 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -319,6 +319,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
case spirv::Decoration::Constant:
+ case spirv::Decoration::Block:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
@@ -617,11 +618,16 @@ LogicalResult Serializer::prepareBasicType(
operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
operands.push_back(pointeeTypeID);
+ // TODO: Now struct decorations are supported this code may not be
+ // necessary. However, it is left to support backwards compatibility.
+ // Ideally, Block decorations should be inserted when converting to SPIR-V.
if (isInterfaceStructPtrType(ptrType)) {
- if (failed(emitDecoration(getTypeID(pointeeStruct),
- spirv::Decoration::Block)))
- return emitError(loc, "cannot decorate ")
- << pointeeStruct << " with Block decoration";
+ auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
+ if (!structType.hasDecoration(spirv::Decoration::Block))
+ if (failed(emitDecoration(getTypeID(pointeeStruct),
+ spirv::Decoration::Block)))
+ return emitError(loc, "cannot decorate ")
+ << pointeeStruct << " with Block decoration";
}
return success();
@@ -689,6 +695,20 @@ LogicalResult Serializer::prepareBasicType(
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorations;
+ structType.getStructDecorations(structDecorations);
+
+ for (spirv::StructType::StructDecorationInfo &structDecoration :
+ structDecorations) {
+ if (failed(processDecorationAttr(loc, resultID,
+ structDecoration.decoration,
+ structDecoration.decorationValue))) {
+ return emitError(loc, "cannot decorate struct ")
+ << structType << " with "
+ << stringifyDecoration(structDecoration.decoration);
+ }
+ }
+
typeEnum = spirv::Opcode::OpTypeStruct;
if (structType.isIdentified())
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 7d45b5ea82643..071389d497fde 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve
// CHECK: func private @struct_empty(!spirv.struct<()>)
func.func private @struct_empty(!spirv.struct<()>)
+// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+
+// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+
// -----
// expected-error @+1 {{offset specification must be given for all members}}
diff --git a/mlir/test/Target/SPIRV/memory-ops.mlir b/mlir/test/Target/SPIRV/memory-ops.mlir
index 6b50c3921d427..786d07a218c66 100644
--- a/mlir/test/Target/SPIRV/memory-ops.mlir
+++ b/mlir/test/Target/SPIRV/memory-ops.mlir
@@ -37,32 +37,32 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
- spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>) "None" {
- // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+ spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : f32
%0 = spirv.Constant 0 : i32
- %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
%2 = spirv.Load "StorageBuffer" %1 : f32
- // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+ // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
%3 = spirv.Constant 0 : i32
- %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
spirv.Store "StorageBuffer" %4, %2 : f32
spirv.Return
}
- spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>) "None" {
- // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+ spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : i32
%0 = spirv.Constant 0 : i32
- %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+ %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
%2 = spirv.Load "StorageBuffer" %1 : i32
- // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+ // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
%3 = spirv.Constant 0 : i32
- %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+ %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
spirv.Store "StorageBuffer" %4, %2 : i32
spirv.Return
}
diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir
index 0db0c0bfa2660..4984ee79f903d 100644
--- a/mlir/test/Target/SPIRV/struct.mlir
+++ b/mlir/test/Target/SPIRV/struct.mlir
@@ -7,23 +7,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
spirv.GlobalVariable @var1 bind(0, 2) : !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
- spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
- spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
- spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
- spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
- spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
- spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
// CHECK: !spirv.ptr<!spirv.struct<()>, StorageBuffer>
spirv.GlobalVariable @empty : !spirv.ptr<!spirv.struct<()>, StorageBuffer>
@@ -34,15 +34,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
spirv.GlobalVariable @id_var0 : !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
+ // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
+ spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
- spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
+ spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
- // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
- spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
+ // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
+ spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
- // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
- spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
+ // CHECK: spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
+ spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
// CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input>,
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Output>
diff --git a/mlir/test/Target/SPIRV/undef.mlir b/mlir/test/Target/SPIRV/undef.mlir
index b9044fe8b40af..8889b80e86f95 100644
--- a/mlir/test/Target/SPIRV/undef.mlir
+++ b/mlir/test/Target/SPIRV/undef.mlir
@@ -13,10 +13,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: {{%.*}} = spirv.Undef : !spirv.array<4 x !spirv.array<4 x i32>>
%5 = spirv.Undef : !spirv.array<4x!spirv.array<4xi32>>
%6 = spirv.CompositeExtract %5[1 : i32, 2 : i32] : !spirv.array<4x!spirv.array<4xi32>>
- // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
- %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
+ // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
+ %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
%8 = spirv.Constant 0 : i32
- %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
spirv.Return
}
}
More information about the Mlir-commits mailing list