[Mlir-commits] [mlir] 6f0ce46 - Revert "[mlir][spirv] Enhance structure type member decoration handling"
Mehdi Amini
llvmlistbot at llvm.org
Thu Jun 11 22:01:31 PDT 2020
Author: Mehdi Amini
Date: 2020-06-12T05:01:24Z
New Revision: 6f0ce46873b609851634b2c77fc76bf8d580c3c6
URL: https://github.com/llvm/llvm-project/commit/6f0ce46873b609851634b2c77fc76bf8d580c3c6
DIFF: https://github.com/llvm/llvm-project/commit/6f0ce46873b609851634b2c77fc76bf8d580c3c6.diff
LOG: Revert "[mlir][spirv] Enhance structure type member decoration handling"
This reverts commit 5d74df5b03e46b7bd3700e3595c7008a6905b288.
This broke the MSVC build: <bits/stdint-uintn.h> is a glibc-internal header and isn't available on Windows
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/struct.mlir
mlir/test/Dialect/SPIRV/types.mlir
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 235c322d7488..b7180399a837 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -19,7 +19,6 @@
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
-#include <bits/stdint-uintn.h>
#include <tuple>
// Forward declare enum classes related to op availability. Their definitions
@@ -277,40 +276,22 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
public:
using Base::Base;
- // Type for specifying the offset of the struct members
- using OffsetInfo = uint32_t;
-
- // Type for specifying the decoration(s) on struct members
- struct MemberDecorationInfo {
- uint32_t memberIndex : 31;
- uint32_t hasValue : 1;
- Decoration decoration;
- uint32_t decorationValue;
-
- MemberDecorationInfo(uint32_t index, uint32_t hasValue,
- Decoration decoration, uint32_t decorationValue)
- : memberIndex(index), hasValue(hasValue), decoration(decoration),
- decorationValue(decorationValue) {}
-
- bool operator==(const MemberDecorationInfo &other) const {
- return (this->memberIndex == other.memberIndex) &&
- (this->decoration == other.decoration) &&
- (this->decorationValue == other.decorationValue);
- }
+ // Layout information used for members in a struct in SPIR-V
+ //
+ // TODO(ravishankarm) : For now this only supports the offset type, so uses
+ // uint64_t value to represent the offset, with
+ // std::numeric_limit<uint64_t>::max indicating no offset. Change this to
+ // something that can hold all the information needed for different member
+ // types
+ using LayoutInfo = uint64_t;
- bool operator<(const MemberDecorationInfo &other) const {
- return this->memberIndex < other.memberIndex ||
- (this->memberIndex == other.memberIndex &&
- static_cast<uint32_t>(this->decoration) <
- static_cast<uint32_t>(other.decoration));
- }
- };
+ using MemberDecorationInfo = std::pair<uint32_t, spirv::Decoration>;
static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
/// Construct a StructType with at least one member.
static StructType get(ArrayRef<Type> memberTypes,
- ArrayRef<OffsetInfo> offsetInfo = {},
+ ArrayRef<LayoutInfo> layoutInfo = {},
ArrayRef<MemberDecorationInfo> memberDecorations = {});
/// Construct a struct with no members.
@@ -342,9 +323,9 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
ElementTypeRange getElementTypes() const;
- bool hasOffset() const;
+ bool hasLayout() const;
- uint64_t getMemberOffset(unsigned) const;
+ uint64_t getOffset(unsigned) const;
// Returns in `allMemberDecorations` the spirv::Decorations (apart from
// Offset) associated with all members of the StructType.
@@ -353,9 +334,8 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
// Returns in `memberDecorations` all the spirv::Decorations (apart from
// Offset) associated with the `i`-th member of the StructType.
- void getMemberDecorations(unsigned i,
- SmallVectorImpl<StructType::MemberDecorationInfo>
- &memberDecorations) const;
+ void getMemberDecorations(
+ unsigned i, SmallVectorImpl<spirv::Decoration> &memberDecorations) const;
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = llvm::None);
@@ -363,9 +343,6 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
Optional<spirv::StorageClass> storage = llvm::None);
};
-llvm::hash_code
-hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
-
// SPIR-V cooperative matrix type
class CooperativeMatrixNVType
: public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
index e456f499b745..f953f958d75e 100644
--- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
@@ -32,7 +32,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
}
SmallVector<Type, 4> memberTypes;
- SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
+ SmallVector<Size, 4> layoutInfo;
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
Size structMemberOffset = 0;
@@ -46,8 +46,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
decorateType(structType.getElementType(i), memberSize, memberAlignment);
structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
memberTypes.push_back(memberType);
- offsetInfo.push_back(
- static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
+ layoutInfo.push_back(structMemberOffset);
// If the member's size is the max value, it must be the last member and it
// must be a runtime array.
assert(memberSize != std::numeric_limits<Size>().max() ||
@@ -67,7 +66,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
alignment = maxMemberAlignment;
structType.getMemberDecorations(memberDecorations);
- return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
+ return spirv::StructType::get(memberTypes, layoutInfo, memberDecorations);
}
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
@@ -169,7 +168,7 @@ bool VulkanLayoutUtils::isLegalType(Type type) {
case spirv::StorageClass::StorageBuffer:
case spirv::StorageClass::PushConstant:
case spirv::StorageClass::PhysicalStorageBuffer:
- return structType.hasOffset() || !structType.getNumElements();
+ return structType.hasLayout() || !structType.getNumElements();
default:
return true;
}
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 43e70a1bdc63..455064f58ce6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -535,31 +535,30 @@ static Type parseImageType(SPIRVDialect const &dialect,
static ParseResult parseStructMemberDecorations(
SPIRVDialect const &dialect, DialectAsmParser &parser,
ArrayRef<Type> memberTypes,
- SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
+ SmallVectorImpl<StructType::LayoutInfo> &layoutInfo,
SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
// Check if the first element is offset.
- llvm::SMLoc offsetLoc = parser.getCurrentLocation();
- StructType::OffsetInfo offset = 0;
- OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
- if (offsetParseResult.hasValue()) {
- if (failed(*offsetParseResult))
+ llvm::SMLoc layoutLoc = parser.getCurrentLocation();
+ StructType::LayoutInfo layout = 0;
+ OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout);
+ if (layoutParseResult.hasValue()) {
+ if (failed(*layoutParseResult))
return failure();
- if (offsetInfo.size() != memberTypes.size() - 1) {
- return parser.emitError(offsetLoc,
- "offset specification must be given for "
- "all members");
+ if (layoutInfo.size() != memberTypes.size() - 1) {
+ return parser.emitError(
+ layoutLoc, "layout specification must be given for all members");
}
- offsetInfo.push_back(offset);
+ layoutInfo.push_back(layout);
}
// Check for no spirv::Decorations.
if (succeeded(parser.parseOptionalRSquare()))
return success();
- // If there was an offset, make sure to parse the comma.
- if (offsetParseResult.hasValue() && parser.parseComma())
+ // If there was a layout, make sure to parse the comma.
+ if (layoutParseResult.hasValue() && parser.parseComma())
return failure();
// Check for spirv::Decorations.
@@ -568,23 +567,9 @@ static ParseResult parseStructMemberDecorations(
if (!memberDecoration)
return failure();
- // Parse member decoration value if it exists.
- if (succeeded(parser.parseOptionalEqual())) {
- auto memberDecorationValue =
- parseAndVerifyInteger<uint32_t>(dialect, parser);
-
- if (!memberDecorationValue)
- return failure();
-
- memberDecorationInfo.emplace_back(
- static_cast<uint32_t>(memberTypes.size() - 1), 1,
- memberDecoration.getValue(), memberDecorationValue.getValue());
- } else {
- memberDecorationInfo.emplace_back(
- static_cast<uint32_t>(memberTypes.size() - 1), 0,
- memberDecoration.getValue(), 0);
- }
-
+ memberDecorationInfo.emplace_back(
+ static_cast<uint32_t>(memberTypes.size() - 1),
+ memberDecoration.getValue());
} while (succeeded(parser.parseOptionalComma()));
return parser.parseRSquare();
@@ -602,7 +587,7 @@ static Type parseStructType(SPIRVDialect const &dialect,
return StructType::getEmpty(dialect.getContext());
SmallVector<Type, 4> memberTypes;
- SmallVector<StructType::OffsetInfo, 4> offsetInfo;
+ SmallVector<StructType::LayoutInfo, 4> layoutInfo;
SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
do {
@@ -612,21 +597,21 @@ static Type parseStructType(SPIRVDialect const &dialect,
memberTypes.push_back(memberType);
if (succeeded(parser.parseOptionalLSquare())) {
- if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
+ if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo,
memberDecorationInfo)) {
return Type();
}
}
} while (succeeded(parser.parseOptionalComma()));
- if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
+ if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) {
parser.emitError(parser.getNameLoc(),
- "offset specification must be given for all members");
+ "layout specification must be given for all members");
return Type();
}
if (parser.parseGreater())
return Type();
- return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+ return StructType::get(memberTypes, layoutInfo, memberDecorationInfo);
}
// spirv-type ::= array-type
@@ -694,20 +679,17 @@ static void print(StructType type, DialectAsmPrinter &os) {
os << "struct<";
auto printMember = [&](unsigned i) {
os << type.getElementType(i);
- SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
+ SmallVector<spirv::Decoration, 0> decorations;
type.getMemberDecorations(i, decorations);
- if (type.hasOffset() || !decorations.empty()) {
+ if (type.hasLayout() || !decorations.empty()) {
os << " [";
- if (type.hasOffset()) {
- os << type.getMemberOffset(i);
+ if (type.hasLayout()) {
+ os << type.getOffset(i);
if (!decorations.empty())
os << ", ";
}
- auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
- os << stringifyDecoration(decoration.decoration);
- if (decoration.hasValue) {
- os << "=" << decoration.decorationValue;
- }
+ auto eachFn = [&os](spirv::Decoration decoration) {
+ os << stringifyDecoration(decoration);
};
llvm::interleaveComma(decorations, os, eachFn);
os << "]";
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 963f5393c572..0226f5117540 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -874,17 +874,17 @@ void SPIRVType::getCapabilities(
struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(
unsigned numMembers, Type const *memberTypes,
- StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
+ StructType::LayoutInfo const *layoutInfo, unsigned numMemberDecorations,
StructType::MemberDecorationInfo const *memberDecorationsInfo)
: TypeStorage(numMembers), memberTypes(memberTypes),
- offsetInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
+ layoutInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
memberDecorationsInfo(memberDecorationsInfo) {}
- using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
+ using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::LayoutInfo>,
ArrayRef<StructType::MemberDecorationInfo>>;
bool operator==(const KeyTy &key) const {
return key ==
- KeyTy(getMemberTypes(), getOffsetInfo(), getMemberDecorationsInfo());
+ KeyTy(getMemberTypes(), getLayoutInfo(), getMemberDecorationsInfo());
}
static StructTypeStorage *construct(TypeStorageAllocator &allocator,
@@ -897,13 +897,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
typesList = allocator.copyInto(keyTypes).data();
}
- const StructType::OffsetInfo *offsetInfoList = nullptr;
+ const StructType::LayoutInfo *layoutInfoList = nullptr;
if (!std::get<1>(key).empty()) {
- ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<1>(key);
- assert(keyOffsetInfo.size() == keyTypes.size() &&
- "size of offset information must be same as the size of number of "
+ ArrayRef<StructType::LayoutInfo> keyLayoutInfo = std::get<1>(key);
+ assert(keyLayoutInfo.size() == keyTypes.size() &&
+ "size of layout information must be same as the size of number of "
"elements");
- offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
+ layoutInfoList = allocator.copyInto(keyLayoutInfo).data();
}
const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
@@ -914,7 +914,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
}
return new (allocator.allocate<StructTypeStorage>())
- StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
+ StructTypeStorage(keyTypes.size(), typesList, layoutInfoList,
numMemberDecorations, memberDecorationList);
}
@@ -922,9 +922,9 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
return ArrayRef<Type>(memberTypes, getSubclassData());
}
- ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
- if (offsetInfo) {
- return ArrayRef<StructType::OffsetInfo>(offsetInfo, getSubclassData());
+ ArrayRef<StructType::LayoutInfo> getLayoutInfo() const {
+ if (layoutInfo) {
+ return ArrayRef<StructType::LayoutInfo>(layoutInfo, getSubclassData());
}
return {};
}
@@ -938,14 +938,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
}
Type const *memberTypes;
- StructType::OffsetInfo const *offsetInfo;
+ StructType::LayoutInfo const *layoutInfo;
unsigned numMemberDecorations;
StructType::MemberDecorationInfo const *memberDecorationsInfo;
};
StructType
StructType::get(ArrayRef<Type> memberTypes,
- ArrayRef<StructType::OffsetInfo> offsetInfo,
+ ArrayRef<StructType::LayoutInfo> layoutInfo,
ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
assert(!memberTypes.empty() && "Struct needs at least one member type");
// Sort the decorations.
@@ -953,12 +953,12 @@ StructType::get(ArrayRef<Type> memberTypes,
memberDecorations.begin(), memberDecorations.end());
llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
- memberTypes, offsetInfo, sortedDecorations);
+ memberTypes, layoutInfo, sortedDecorations);
}
StructType StructType::getEmpty(MLIRContext *context) {
return Base::get(context, TypeKind::Struct, ArrayRef<Type>(),
- ArrayRef<StructType::OffsetInfo>(),
+ ArrayRef<StructType::LayoutInfo>(),
ArrayRef<StructType::MemberDecorationInfo>());
}
@@ -975,11 +975,11 @@ StructType::ElementTypeRange StructType::getElementTypes() const {
return ElementTypeRange(getImpl()->memberTypes, getNumElements());
}
-bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
+bool StructType::hasLayout() const { return getImpl()->layoutInfo; }
-uint64_t StructType::getMemberOffset(unsigned index) const {
+uint64_t StructType::getOffset(unsigned index) const {
assert(getNumElements() > index && "member index out of range");
- return getImpl()->offsetInfo[index];
+ return getImpl()->layoutInfo[index];
}
void StructType::getMemberDecorations(
@@ -992,16 +992,15 @@ void StructType::getMemberDecorations(
}
void StructType::getMemberDecorations(
- unsigned index,
- SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
+ unsigned index, SmallVectorImpl<spirv::Decoration> &decorations) const {
assert(getNumElements() > index && "member index out of range");
auto memberDecorations = getImpl()->getMemberDecorationsInfo();
- decorationsInfo.clear();
- for (const auto &memberDecoration : memberDecorations) {
- if (memberDecoration.memberIndex == index) {
- decorationsInfo.push_back(memberDecoration);
+ decorations.clear();
+ for (auto &memberDecoration : memberDecorations) {
+ if (memberDecoration.first == index) {
+ decorations.push_back(memberDecoration.second);
}
- if (memberDecoration.memberIndex > index) {
+ if (memberDecoration.first > index) {
// Early exit since the decorations are stored sorted.
return;
}
@@ -1021,12 +1020,6 @@ void StructType::getCapabilities(
elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
}
-llvm::hash_code spirv::hash_value(
- const StructType::MemberDecorationInfo &memberDecorationInfo) {
- return llvm::hash_combine(memberDecorationInfo.memberIndex,
- memberDecorationInfo.decoration);
-}
-
//===----------------------------------------------------------------------===//
// MatrixType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 4bb9e6d26b97..ecd79d7153ab 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -1305,7 +1305,7 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
memberTypes.push_back(memberType);
}
- SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
+ SmallVector<spirv::StructType::LayoutInfo, 0> layoutInfo;
SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
if (memberDecorationMap.count(operands[0])) {
auto &allMemberDecorations = memberDecorationMap[operands[0]];
@@ -1314,27 +1314,27 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
// Check for offset.
if (memberDecoration.first == spirv::Decoration::Offset) {
- // If offset info is empty, resize to the number of members;
- if (offsetInfo.empty()) {
- offsetInfo.resize(memberTypes.size());
+ // If layoutInfo is empty, resize to the number of members;
+ if (layoutInfo.empty()) {
+ layoutInfo.resize(memberTypes.size());
}
- offsetInfo[memberIndex] = memberDecoration.second[0];
+ layoutInfo[memberIndex] = memberDecoration.second[0];
} else {
if (!memberDecoration.second.empty()) {
- memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
- memberDecoration.first,
- memberDecoration.second[0]);
- } else {
- memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
- memberDecoration.first, 0);
+ return emitError(unknownLoc,
+ "unhandled OpMemberDecoration with decoration ")
+ << stringifyDecoration(memberDecoration.first)
+ << " which has additional operands";
}
+ memberDecorationsInfo.emplace_back(memberIndex,
+ memberDecoration.first);
}
}
}
}
}
typeMap[operands[0]] =
- spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
+ spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo);
// TODO(ravishankarm): Update StructType to have member name as attribute as
// well.
return success();
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index f8641873fd95..81f873281c63 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -227,9 +227,9 @@ class Serializer {
}
/// Process member decoration
- LogicalResult processMemberDecoration(
- uint32_t structID,
- const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);
+ LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberIndex,
+ spirv::Decoration decorationType,
+ ArrayRef<uint32_t> values = {});
//===--------------------------------------------------------------------===//
// Types
@@ -736,14 +736,14 @@ LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
return success();
}
-LogicalResult Serializer::processMemberDecoration(
- uint32_t structID,
- const spirv::StructType::MemberDecorationInfo &memberDecoration) {
+LogicalResult
+Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex,
+ spirv::Decoration decorationType,
+ ArrayRef<uint32_t> values) {
SmallVector<uint32_t, 4> args(
- {structID, memberDecoration.memberIndex,
- static_cast<uint32_t>(memberDecoration.decoration)});
- if (memberDecoration.hasValue) {
- args.push_back(memberDecoration.decorationValue);
+ {structID, memberIndex, static_cast<uint32_t>(decorationType)});
+ if (!values.empty()) {
+ args.append(values.begin(), values.end());
}
return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
args);
@@ -1070,7 +1070,7 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
}
if (auto structType = type.dyn_cast<spirv::StructType>()) {
- bool hasOffset = structType.hasOffset();
+ bool hasLayout = structType.hasLayout();
for (auto elementIndex :
llvm::seq<uint32_t>(0, structType.getNumElements())) {
uint32_t elementTypeID = 0;
@@ -1079,12 +1079,11 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
return failure();
}
operands.push_back(elementTypeID);
- if (hasOffset) {
+ if (hasLayout) {
// Decorate each struct member with an offset
- spirv::StructType::MemberDecorationInfo offsetDecoration{
- elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
- static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
- if (failed(processMemberDecoration(resultID, offsetDecoration))) {
+ if (failed(processMemberDecoration(
+ resultID, elementIndex, spirv::Decoration::Offset,
+ static_cast<uint32_t>(structType.getOffset(elementIndex))))) {
return emitError(loc, "cannot decorate ")
<< elementIndex << "-th member of " << structType
<< " with its offset";
@@ -1094,11 +1093,11 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
structType.getMemberDecorations(memberDecorations);
for (auto &memberDecoration : memberDecorations) {
- if (failed(processMemberDecoration(resultID, memberDecoration))) {
+ if (failed(processMemberDecoration(resultID, memberDecoration.first,
+ memberDecoration.second))) {
return emitError(loc, "cannot decorate ")
- << static_cast<uint32_t>(memberDecoration.memberIndex)
- << "-th member of " << structType << " with "
- << stringifyDecoration(memberDecoration.decoration);
+ << memberDecoration.first << "-th member of " << structType
+ << " with " << stringifyDecoration(memberDecoration.second);
}
}
typeEnum = spirv::Opcode::OpTypeStruct;
diff --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
index fff591d2f24e..3066462fd71b 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
@@ -22,9 +22,6 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK: !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]>, StorageBuffer>
spv.globalVariable @var6 : !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]>, StorageBuffer>
- // CHECK: !spv.ptr<!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>, StorageBuffer>
- spv.globalVariable @var7 : !spv.ptr<!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>, StorageBuffer>
-
// CHECK: !spv.ptr<!spv.struct<>, StorageBuffer>
spv.globalVariable @empty : !spv.ptr<!spv.struct<>, StorageBuffer>
diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir
index d5eb073c9aa5..1d1a1868ea3c 100644
--- a/mlir/test/Dialect/SPIRV/types.mlir
+++ b/mlir/test/Dialect/SPIRV/types.mlir
@@ -275,23 +275,17 @@ func @struct_type_with_decoration7(!spv.struct<f32 [0], !spv.struct<i32, f32 [No
// CHECK: func @struct_type_with_decoration8(!spv.struct<f32, !spv.struct<i32 [0], f32 [4, NonReadable]>>)
func @struct_type_with_decoration8(!spv.struct<f32, !spv.struct<i32 [0], f32 [4, NonReadable]>>)
-// CHECK: func @struct_type_with_matrix_1(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>)
-func @struct_type_with_matrix_1(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]>)
-
-// CHECK: func @struct_type_with_matrix_2(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16]>)
-func @struct_type_with_matrix_2(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=16]>)
-
// CHECK: func @struct_empty(!spv.struct<>)
func @struct_empty(!spv.struct<>)
// -----
-// expected-error @+1 {{offset specification must be given for all members}}
+// expected-error @+1 {{layout specification must be given for all members}}
func @struct_type_missing_offset1((!spv.struct<f32, i32 [4]>) -> ()
// -----
-// expected-error @+1 {{offset specification must be given for all members}}
+// expected-error @+1 {{layout specification must be given for all members}}
func @struct_type_missing_offset2(!spv.struct<f32 [3], i32>) -> ()
// -----
@@ -336,16 +330,6 @@ func @struct_type_missing_comma(!spv.struct<f32 [0, NonWritable NonReadable], i3
// -----
-// expected-error @+1 {{expected ']'}}
-func @struct_type_missing_comma(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor MatrixStride=16]>)
-
-// -----
-
-// expected-error @+1 {{expected integer value}}
-func @struct_missing_member_decorator_value(!spv.struct<!spv.matrix<3 x vector<3xf32>> [0, RowMajor, MatrixStride=]>)
-
-// -----
-
//===----------------------------------------------------------------------===//
// CooperativeMatrix
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 340bfd939e1b..06c417ca23f7 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -58,8 +58,8 @@ class SerializationTest : public ::testing::Test {
Type getFloatStructType() {
OpBuilder opBuilder(module.body());
llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
- llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
- auto structType = spirv::StructType::get(elementTypes, offsetInfo);
+ llvm::SmallVector<spirv::StructType::LayoutInfo, 1> layoutInfo{0};
+ auto structType = spirv::StructType::get(elementTypes, layoutInfo);
return structType;
}
More information about the Mlir-commits mailing list