[Mlir-commits] [mlir] a290c3a - [mlir][spirv] Improve stride support in array types
Lei Zhang
llvmlistbot at llvm.org
Mon Apr 13 11:12:44 PDT 2020
Author: Lei Zhang
Date: 2020-04-13T14:08:17-04:00
New Revision: a290c3af9dd01dd3f13a5231626892f6a658f4e4
URL: https://github.com/llvm/llvm-project/commit/a290c3af9dd01dd3f13a5231626892f6a658f4e4
DIFF: https://github.com/llvm/llvm-project/commit/a290c3af9dd01dd3f13a5231626892f6a658f4e4.diff
LOG: [mlir][spirv] Improve stride support in array types
This commit adds stride support to runtime array types. It also
changes the assembly form for the stride from `[N]` to `stride=N`.
This makes the IR more readable, especially in cases where
array types and struct types are mixed.
Differential Revision: https://reviews.llvm.org/D78034
Added:
Modified:
mlir/docs/Dialects/SPIR-V.md
mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/test/Conversion/GPUToSPIRV/load-store.mlir
mlir/test/Conversion/GPUToSPIRV/simple.mlir
mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
mlir/test/Dialect/SPIRV/Serialization/array.mlir
mlir/test/Dialect/SPIRV/Serialization/constant.mlir
mlir/test/Dialect/SPIRV/Serialization/loop.mlir
mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
mlir/test/Dialect/SPIRV/Serialization/struct.mlir
mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir
mlir/test/Dialect/SPIRV/structure-ops.mlir
mlir/test/Dialect/SPIRV/types.mlir
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index 0ba2115b2a6f..de896b8e7daf 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -287,13 +287,15 @@ element-type ::= integer-type
| vector-type
| spirv-type
-array-type ::= `!spv.array<` integer-literal `x` element-type `>`
+array-type ::= `!spv.array` `<` integer-literal `x` element-type
+ (`,` `stride` `=` integer-literal)? `>`
```
For example,
```mlir
!spv.array<4 x i32>
+!spv.array<4 x i32, stride = 4>
!spv.array<16 x vector<4 x f32>>
```
@@ -351,13 +353,14 @@ For example,
This corresponds to SPIR-V [runtime array type][RuntimeArrayType]. Its syntax is
```
-runtime-array-type ::= `!spv.rtarray<` element-type `>`
+runtime-array-type ::= `!spv.rtarray` `<` element-type (`,` `stride` `=` integer-literal)? `>`
```
For example,
```mlir
!spv.rtarray<i32>
+!spv.rtarray<i32, stride=4>
!spv.rtarray<vector<4 x f32>>
```
diff --git a/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h b/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h
index e7dc6763b62a..fdae7be1eaed 100644
--- a/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h
+++ b/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h
@@ -18,9 +18,11 @@
namespace mlir {
class Type;
class VectorType;
+
namespace spirv {
-class StructType;
class ArrayType;
+class RuntimeArrayType;
+class StructType;
} // namespace spirv
/// According to the Vulkan spec "14.5.4. Offset and Stride Assignment":
@@ -47,10 +49,9 @@ class VulkanLayoutUtils {
public:
using Size = uint64_t;
- /// Returns a new StructType with layout info. Assigns the type size in bytes
- /// to the `size`. Assigns the type alignment in bytes to the `alignment`.
- static spirv::StructType decorateType(spirv::StructType structType,
- Size &size, Size &alignment);
+ /// Returns a new StructType with layout decoration.
+ static spirv::StructType decorateType(spirv::StructType structType);
+
/// Checks whether a type is legal in terms of Vulkan layout info
/// decoration. A type is dynamically illegal if it's a composite type in the
/// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant Storage
@@ -58,10 +59,17 @@ class VulkanLayoutUtils {
static bool isLegalType(Type type);
private:
+ /// Returns a new type with layout decoration. Assigns the type size in bytes
+ /// to the `size`. Assigns the type alignment in bytes to the `alignment`.
static Type decorateType(Type type, Size &size, Size &alignment);
+
static Type decorateType(VectorType vectorType, Size &size, Size &alignment);
static Type decorateType(spirv::ArrayType arrayType, Size &size,
Size &alignment);
+ static Type decorateType(spirv::RuntimeArrayType arrayType, Size &alignment);
+ static spirv::StructType decorateType(spirv::StructType structType,
+ Size &size, Size &alignment);
+
/// Calculates the alignment for the given scalar type.
static Size getScalarTypeAlignment(Type scalarType);
};
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 85b35f73f82c..55911cf10b7b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -147,23 +147,22 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
detail::ArrayTypeStorage> {
public:
using Base::Base;
- // Zero layout specifies that is no layout
- using LayoutInfo = uint64_t;
static bool kindof(unsigned kind) { return kind == TypeKind::Array; }
static ArrayType get(Type elementType, unsigned elementCount);
+ /// Returns an array type with the given stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount,
- LayoutInfo layoutInfo);
+ unsigned stride);
unsigned getNumElements() const;
Type getElementType() const;
- bool hasLayout() const;
-
- uint64_t getArrayStride() const;
+ /// Returns the array stride in bytes. 0 means no stride decorated on this
+ /// type.
+ unsigned getArrayStride() const;
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = llvm::None);
@@ -243,8 +242,15 @@ class RuntimeArrayType
static RuntimeArrayType get(Type elementType);
+ /// Returns a runtime array type with the given stride in bytes.
+ static RuntimeArrayType get(Type elementType, unsigned stride);
+
Type getElementType() const;
+ /// Returns the array stride in bytes. 0 means no stride decorated on this
+ /// type.
+ unsigned getArrayStride() const;
+
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = llvm::None);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
index d4ce17c93706..f953f958d75e 100644
--- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
@@ -16,6 +16,13 @@
using namespace mlir;
+spirv::StructType
+VulkanLayoutUtils::decorateType(spirv::StructType structType) {
+ Size size = 0;
+ Size alignment = 1;
+ return decorateType(structType, size, alignment);
+}
+
spirv::StructType
VulkanLayoutUtils::decorateType(spirv::StructType structType,
VulkanLayoutUtils::Size &size,
@@ -25,21 +32,26 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
}
SmallVector<Type, 4> memberTypes;
- SmallVector<VulkanLayoutUtils::Size, 4> layoutInfo;
+ SmallVector<Size, 4> layoutInfo;
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
- VulkanLayoutUtils::Size structMemberOffset = 0;
- VulkanLayoutUtils::Size maxMemberAlignment = 1;
+ Size structMemberOffset = 0;
+ Size maxMemberAlignment = 1;
for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) {
- VulkanLayoutUtils::Size memberSize = 0;
- VulkanLayoutUtils::Size memberAlignment = 1;
+ Size memberSize = 0;
+ Size memberAlignment = 1;
- auto memberType = VulkanLayoutUtils::decorateType(
- structType.getElementType(i), memberSize, memberAlignment);
+ auto memberType =
+ decorateType(structType.getElementType(i), memberSize, memberAlignment);
structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
memberTypes.push_back(memberType);
layoutInfo.push_back(structMemberOffset);
+ // If the member's size is the max value, it must be the last member and it
+ // must be a runtime array.
+ assert(memberSize != std::numeric_limits<Size>().max() ||
+ (i + 1 == e &&
+ structType.getElementType(i).isa<spirv::RuntimeArrayType>()));
// According to the Vulkan spec:
// "A structure has a base alignment equal to the largest base alignment of
// any of its members."
@@ -60,7 +72,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
VulkanLayoutUtils::Size &alignment) {
if (type.isa<spirv::ScalarType>()) {
- alignment = VulkanLayoutUtils::getScalarTypeAlignment(type);
+ alignment = getScalarTypeAlignment(type);
// Vulkan spec does not specify any padding for a scalar type.
size = alignment;
return type;
@@ -68,14 +80,14 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
switch (type.getKind()) {
case spirv::TypeKind::Struct:
- return VulkanLayoutUtils::decorateType(type.cast<spirv::StructType>(), size,
- alignment);
+ return decorateType(type.cast<spirv::StructType>(), size, alignment);
case spirv::TypeKind::Array:
- return VulkanLayoutUtils::decorateType(type.cast<spirv::ArrayType>(), size,
- alignment);
+ return decorateType(type.cast<spirv::ArrayType>(), size, alignment);
case StandardTypes::Vector:
- return VulkanLayoutUtils::decorateType(type.cast<VectorType>(), size,
- alignment);
+ return decorateType(type.cast<VectorType>(), size, alignment);
+ case spirv::TypeKind::RuntimeArray:
+ size = std::numeric_limits<Size>().max();
+ return decorateType(type.cast<spirv::RuntimeArrayType>(), alignment);
default:
llvm_unreachable("unhandled SPIR-V type");
}
@@ -86,11 +98,10 @@ Type VulkanLayoutUtils::decorateType(VectorType vectorType,
VulkanLayoutUtils::Size &alignment) {
const auto numElements = vectorType.getNumElements();
auto elementType = vectorType.getElementType();
- VulkanLayoutUtils::Size elementSize = 0;
- VulkanLayoutUtils::Size elementAlignment = 1;
+ Size elementSize = 0;
+ Size elementAlignment = 1;
- auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize,
- elementAlignment);
+ auto memberType = decorateType(elementType, elementSize, elementAlignment);
// According to the Vulkan spec:
// 1. "A two-component vector has a base alignment equal to twice its scalar
// alignment."
@@ -106,11 +117,10 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
VulkanLayoutUtils::Size &alignment) {
const auto numElements = arrayType.getNumElements();
auto elementType = arrayType.getElementType();
- spirv::ArrayType::LayoutInfo elementSize = 0;
- VulkanLayoutUtils::Size elementAlignment = 1;
+ Size elementSize = 0;
+ Size elementAlignment = 1;
- auto memberType = VulkanLayoutUtils::decorateType(elementType, elementSize,
- elementAlignment);
+ auto memberType = decorateType(elementType, elementSize, elementAlignment);
// According to the Vulkan spec:
// "An array has a base alignment equal to the base alignment of its element
// type."
@@ -119,6 +129,15 @@ Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
return spirv::ArrayType::get(memberType, numElements, elementSize);
}
+Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
+ VulkanLayoutUtils::Size &alignment) {
+ auto elementType = arrayType.getElementType();
+ Size elementSize = 0;
+
+ auto memberType = decorateType(elementType, elementSize, alignment);
+ return spirv::RuntimeArrayType::get(memberType, elementSize);
+}
+
VulkanLayoutUtils::Size
VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
// According to the Vulkan spec:
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 9478abb690af..e811fe6ec40a 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -149,7 +149,7 @@ Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
DialectAsmParser &parser);
template <>
-Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
+Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
DialectAsmParser &parser);
static Type parseAndVerifyType(SPIRVDialect const &dialect,
@@ -196,13 +196,39 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
return type;
}
+/// Parses an optional `, stride = N` assembly segment. If no parsing failure
+/// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
+/// missing.
+static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
+ DialectAsmParser &parser,
+ unsigned &stride) {
+ if (failed(parser.parseOptionalComma())) {
+ stride = 0;
+ return success();
+ }
+
+ if (parser.parseKeyword("stride") || parser.parseEqual())
+ return failure();
+
+ llvm::SMLoc strideLoc = parser.getCurrentLocation();
+ Optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
+ if (!optStride)
+ return failure();
+
+ if (!(stride = optStride.getValue())) {
+ parser.emitError(strideLoc, "ArrayStride must be greater than zero");
+ return failure();
+ }
+ return success();
+}
+
// element-type ::= integer-type
// | floating-point-type
// | vector-type
// | spirv-type
//
-// array-type ::= `!spv.array<` integer-literal `x` element-type
-// (`[` integer-literal `]`)? `>`
+// array-type ::= `!spv.array` `<` integer-literal `x` element-type
+// (`,` `stride` `=` integer-literal)? `>`
static Type parseArrayType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
@@ -230,25 +256,13 @@ static Type parseArrayType(SPIRVDialect const &dialect,
if (!elementType)
return Type();
- ArrayType::LayoutInfo layoutInfo = 0;
- if (succeeded(parser.parseOptionalLSquare())) {
- llvm::SMLoc layoutLoc = parser.getCurrentLocation();
- auto layout = parseAndVerify<ArrayType::LayoutInfo>(dialect, parser);
- if (!layout)
- return Type();
-
- if (!(layoutInfo = layout.getValue())) {
- parser.emitError(layoutLoc, "ArrayStride must be greater than zero");
- return Type();
- }
-
- if (parser.parseRSquare())
- return Type();
- }
+ unsigned stride = 0;
+ if (failed(parseOptionalArrayStride(dialect, parser, stride)))
+ return Type();
if (parser.parseGreater())
return Type();
- return ArrayType::get(elementType, count, layoutInfo);
+ return ArrayType::get(elementType, count, stride);
}
// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
@@ -285,7 +299,8 @@ static Type parsePointerType(SPIRVDialect const &dialect,
return PointerType::get(pointeeType, *storageClass);
}
-// runtime-array-type ::= `!spv.rtarray<` element-type `>`
+// runtime-array-type ::= `!spv.rtarray` `<` element-type
+// (`,` `stride` `=` integer-literal)? `>`
static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
@@ -295,9 +310,13 @@ static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
if (!elementType)
return Type();
+ unsigned stride = 0;
+ if (failed(parseOptionalArrayStride(dialect, parser, stride)))
+ return Type();
+
if (parser.parseGreater())
return Type();
- return RuntimeArrayType::get(elementType);
+ return RuntimeArrayType::get(elementType, stride);
}
// Specialize this function to parse each of the parameters that define an
@@ -337,9 +356,9 @@ static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
}
template <>
-Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
+Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
- return parseAndVerifyInteger<uint64_t>(dialect, parser);
+ return parseAndVerifyInteger<unsigned>(dialect, parser);
}
namespace {
@@ -526,14 +545,16 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
static void print(ArrayType type, DialectAsmPrinter &os) {
os << "array<" << type.getNumElements() << " x " << type.getElementType();
- if (type.hasLayout()) {
- os << " [" << type.getArrayStride() << "]";
- }
+ if (unsigned stride = type.getArrayStride())
+ os << ", stride=" << stride;
os << ">";
}
static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
- os << "rtarray<" << type.getElementType() << ">";
+ os << "rtarray<" << type.getElementType();
+ if (unsigned stride = type.getArrayStride())
+ os << ", stride=" << stride;
+ os << ">";
}
static void print(PointerType type, DialectAsmPrinter &os) {
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 3f963bd1d8a8..71ca0c3d2bc7 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -100,7 +100,7 @@ spirv::getRecursiveImpliedCapabilities(Capability cap) {
//===----------------------------------------------------------------------===//
struct spirv::detail::ArrayTypeStorage : public TypeStorage {
- using KeyTy = std::tuple<Type, unsigned, ArrayType::LayoutInfo>;
+ using KeyTy = std::tuple<Type, unsigned, unsigned>;
static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
@@ -108,28 +108,28 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
}
bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, getSubclassData(), layoutInfo);
+ return key == KeyTy(elementType, getSubclassData(), stride);
}
ArrayTypeStorage(const KeyTy &key)
: TypeStorage(std::get<1>(key)), elementType(std::get<0>(key)),
- layoutInfo(std::get<2>(key)) {}
+ stride(std::get<2>(key)) {}
Type elementType;
- ArrayType::LayoutInfo layoutInfo;
+ unsigned stride;
};
ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
assert(elementCount && "ArrayType needs at least one element");
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
- elementCount, 0);
+ elementCount, /*stride=*/0);
}
ArrayType ArrayType::get(Type elementType, unsigned elementCount,
- ArrayType::LayoutInfo layoutInfo) {
+ unsigned stride) {
assert(elementCount && "ArrayType needs at least one element");
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
- elementCount, layoutInfo);
+ elementCount, stride);
}
unsigned ArrayType::getNumElements() const {
@@ -138,10 +138,7 @@ unsigned ArrayType::getNumElements() const {
Type ArrayType::getElementType() const { return getImpl()->elementType; }
-// ArrayStride must be greater than zero
-bool ArrayType::hasLayout() const { return getImpl()->layoutInfo; }
-
-uint64_t ArrayType::getArrayStride() const { return getImpl()->layoutInfo; }
+unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
@@ -215,7 +212,7 @@ void CompositeType::getExtensions(
cast<ArrayType>().getExtensions(extensions, storage);
break;
case spirv::TypeKind::RuntimeArray:
- cast<ArrayType>().getExtensions(extensions, storage);
+ cast<RuntimeArrayType>().getExtensions(extensions, storage);
break;
case spirv::TypeKind::Struct:
cast<StructType>().getExtensions(extensions, storage);
@@ -237,7 +234,7 @@ void CompositeType::getCapabilities(
cast<ArrayType>().getCapabilities(capabilities, storage);
break;
case spirv::TypeKind::RuntimeArray:
- cast<ArrayType>().getCapabilities(capabilities, storage);
+ cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
break;
case spirv::TypeKind::Struct:
cast<StructType>().getCapabilities(capabilities, storage);
@@ -523,7 +520,7 @@ void PointerType::getCapabilities(
//===----------------------------------------------------------------------===//
struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
- using KeyTy = Type;
+ using KeyTy = std::pair<Type, unsigned>;
static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
@@ -531,20 +528,32 @@ struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
RuntimeArrayTypeStorage(key);
}
- bool operator==(const KeyTy &key) const { return elementType == key; }
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(elementType, getSubclassData());
+ }
- RuntimeArrayTypeStorage(const KeyTy &key) : elementType(key) {}
+ RuntimeArrayTypeStorage(const KeyTy &key)
+ : TypeStorage(key.second), elementType(key.first) {}
Type elementType;
};
RuntimeArrayType RuntimeArrayType::get(Type elementType) {
return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
- elementType);
+ elementType, /*stride=*/0);
+}
+
+RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
+ return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
+ elementType, stride);
}
Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
+unsigned RuntimeArrayType::getArrayStride() const {
+ return getImpl()->getSubclassData();
+}
+
void RuntimeArrayType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 82be16847594..9c4670932d73 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -1203,7 +1203,8 @@ Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
"OpTypeRuntimeArray references undefined <id> ")
<< operands[1];
}
- typeMap[operands[0]] = spirv::RuntimeArrayType::get(memberType);
+ typeMap[operands[0]] = spirv::RuntimeArrayType::get(
+ memberType, typeDecorations.lookup(operands[0]));
return success();
}
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 5dbc4600db5a..b25a99c61c05 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -676,10 +676,19 @@ namespace {
template <>
LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
Location loc, spirv::ArrayType type, uint32_t resultID) {
- if (type.hasLayout()) {
+ if (unsigned stride = type.getArrayStride()) {
// OpDecorate %arrayTypeSSA ArrayStride strideLiteral
- return emitDecoration(resultID, spirv::Decoration::ArrayStride,
- {static_cast<uint32_t>(type.getArrayStride())});
+ return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
+ }
+ return success();
+}
+
+template <>
+LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
+ Location Loc, spirv::RuntimeArrayType type, uint32_t resultID) {
+ if (unsigned stride = type.getArrayStride()) {
+ // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
+ return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
}
return success();
}
@@ -1011,9 +1020,9 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
elementTypeID))) {
return failure();
}
- operands.push_back(elementTypeID);
typeEnum = spirv::Opcode::OpTypeRuntimeArray;
- return success();
+ operands.push_back(elementTypeID);
+ return processTypeDecoration(loc, runtimeArrayType, resultID);
}
if (auto structType = type.dyn_cast<spirv::StructType>()) {
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
index b4674afbd980..b14bdd152b58 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
@@ -30,14 +30,11 @@ class SPIRVGlobalVariableOpLayoutInfoDecoration
LogicalResult matchAndRewrite(spirv::GlobalVariableOp op,
PatternRewriter &rewriter) const override {
- spirv::StructType::LayoutInfo structSize = 0;
- VulkanLayoutUtils::Size structAlignment = 1;
SmallVector<NamedAttribute, 4> globalVarAttrs;
auto ptrType = op.type().cast<spirv::PointerType>();
auto structType = VulkanLayoutUtils::decorateType(
- ptrType.getPointeeType().cast<spirv::StructType>(), structSize,
- structAlignment);
+ ptrType.getPointeeType().cast<spirv::StructType>());
auto decoratedType =
spirv::PointerType::get(structType, ptrType.getStorageClass());
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 101536d1afe9..d6b32436c0b4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -50,10 +50,8 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
// Set the offset information.
- VulkanLayoutUtils::Size size = 0, alignment = 0;
varPointeeType =
- VulkanLayoutUtils::decorateType(varPointeeType, size, alignment)
- .cast<spirv::StructType>();
+ VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>();
varType =
spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index 05c9d90c498c..13dc621af8b7 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -27,9 +27,9 @@ module attributes {
// CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-LABEL: spv.func @load_store_kernel
- // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32{{[}][}]}}
- // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}}
- // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32{{[}][}]}}
// CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG4:%.*]]: i32 {spv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
index 9cf0f5045d73..0d0b4c891337 100644
--- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
@@ -5,7 +5,7 @@ module attributes {gpu.container_module} {
// CHECK: spv.module Logical GLSL450 {
// CHECK-LABEL: spv.func @basic_module_structure
// CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
- // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}}
+ // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}}
// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>)
attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {
diff --git a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
index c286b4c104db..8e09a199f279 100644
--- a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
+++ b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
@@ -6,12 +6,12 @@
module attributes {gpu.container_module} {
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {
- spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+ spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer>
spv.func @kernel() "None" attributes {workgroup_attributions = 0 : i64} {
- %0 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+ %0 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer>
%2 = spv.constant 0 : i32
- %3 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
- %4 = spv.AccessChain %0[%2, %2] : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+ %3 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer>
+ %4 = spv.AccessChain %0[%2, %2] : !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer>
%5 = spv.Load "StorageBuffer" %4 : f32
spv.Return
}
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
index 59e1c8016a54..6abdde44e3e5 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -312,17 +312,17 @@ func @constant() {
%3 = constant dense<[2, 3]> : vector<2xi32>
// CHECK: spv.constant 1 : i32
%4 = constant 1 : index
- // CHECK: spv.constant dense<1> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+ // CHECK: spv.constant dense<1> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
%5 = constant dense<1> : tensor<2x3xi32>
- // CHECK: spv.constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32 [4]>
+ // CHECK: spv.constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32, stride=4>
%6 = constant dense<1.0> : tensor<2x3xf32>
- // CHECK: spv.constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32 [4]>
+ // CHECK: spv.constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32, stride=4>
%7 = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
- // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+ // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
%8 = constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
- // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+ // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
%9 = constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
- // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+ // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
%10 = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
return
}
@@ -335,7 +335,7 @@ func @constant_16bit() {
%1 = constant 5.0 : f16
// CHECK: spv.constant dense<[2, 3]> : vector<2xi16>
%2 = constant dense<[2, 3]> : vector<2xi16>
- // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf16> : !spv.array<5 x f16 [2]>
+ // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf16> : !spv.array<5 x f16, stride=2>
%3 = constant dense<4.0> : tensor<5xf16>
return
}
@@ -348,7 +348,7 @@ func @constant_64bit() {
%1 = constant 5.0 : f64
// CHECK: spv.constant dense<[2, 3]> : vector<2xi64>
%2 = constant dense<[2, 3]> : vector<2xi64>
- // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf64> : !spv.array<5 x f64 [8]>
+ // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf64> : !spv.array<5 x f64, stride=8>
%3 = constant dense<4.0> : tensor<5xf64>
return
}
@@ -373,9 +373,9 @@ func @constant_16bit() {
%1 = constant 5.0 : f16
// CHECK: spv.constant dense<[2, 3]> : vector<2xi32>
%2 = constant dense<[2, 3]> : vector<2xi16>
- // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32 [4]>
+ // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32, stride=4>
%3 = constant dense<4.0> : tensor<5xf16>
- // CHECK: spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32 [4]>
+ // CHECK: spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32, stride=4>
%4 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
return
}
@@ -388,9 +388,9 @@ func @constant_64bit() {
%1 = constant 5.0 : f64
// CHECK: spv.constant dense<[2, 3]> : vector<2xi32>
%2 = constant dense<[2, 3]> : vector<2xi64>
- // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32 [4]>
+ // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32, stride=4>
%3 = constant dense<4.0> : tensor<5xf64>
- // CHECK: spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32 [4]>
+ // CHECK: spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32, stride=4>
%4 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
return
}
@@ -572,8 +572,8 @@ func @select(%arg0 : i32, %arg1 : i32) {
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @load_store_zero_rank_float
-// CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>,
-// CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>)
+// CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>, StorageBuffer>,
+// CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>, StorageBuffer>)
func @load_store_zero_rank_float(%arg0: memref<f32>, %arg1: memref<f32>) {
// CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
// CHECK: spv.AccessChain [[ARG0]][
@@ -591,8 +591,8 @@ func @load_store_zero_rank_float(%arg0: memref<f32>, %arg1: memref<f32>) {
}
// CHECK-LABEL: @load_store_zero_rank_int
-// CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>,
-// CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>)
+// CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>, StorageBuffer>,
+// CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>, StorageBuffer>)
func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
// CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
// CHECK: spv.AccessChain [[ARG0]][
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
index 8ad15cf11104..89d1fe2eb1e3 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -311,35 +311,35 @@ module attributes {
} {
// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i32 [4]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i32, stride=4> [0]>, StorageBuffer>
func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
// CHECK-LABEL: spv.func @memref_8bit_Uniform
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x si32 [4]> [0]>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x si32, stride=4> [0]>, Uniform>
func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return }
// CHECK-LABEL: spv.func @memref_8bit_PushConstant
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x ui32 [4]> [0]>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x ui32, stride=4> [0]>, PushConstant>
func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return }
// CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i32 [4]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i32, stride=4> [0]>, StorageBuffer>
func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Uniform
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x si32 [4]> [0]>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x si32, stride=4> [0]>, Uniform>
func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }
// CHECK-LABEL: spv.func @memref_16bit_PushConstant
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x ui32 [4]> [0]>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x ui32, stride=4> [0]>, PushConstant>
func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f32 [4]> [0]>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f32, stride=4> [0]>, Input>
func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f32 [4]> [0]>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f32, stride=4> [0]>, Output>
func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
} // end module
@@ -358,12 +358,12 @@ module attributes {
} {
// CHECK-LABEL: spv.func @memref_8bit_PushConstant
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i8 [1]> [0]>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i8, stride=1> [0]>, PushConstant>
func @memref_8bit_PushConstant(%arg0: memref<16xi8, 7>) { return }
// CHECK-LABEL: spv.func @memref_16bit_PushConstant
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16 [2]> [0]>, PushConstant>
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16 [2]> [0]>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16, stride=2> [0]>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16, stride=2> [0]>, PushConstant>
func @memref_16bit_PushConstant(
%arg0: memref<16xi16, 7>,
%arg1: memref<16xf16, 7>
@@ -385,12 +385,12 @@ module attributes {
} {
// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i8 [1]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i8, stride=1> [0]>, StorageBuffer>
func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
// CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16 [2]> [0]>, StorageBuffer>
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16 [2]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16, stride=2> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16, stride=2> [0]>, StorageBuffer>
func @memref_16bit_StorageBuffer(
%arg0: memref<16xi16, 0>,
%arg1: memref<16xf16, 0>
@@ -412,12 +412,12 @@ module attributes {
} {
// CHECK-LABEL: spv.func @memref_8bit_Uniform
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i8 [1]> [0]>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i8, stride=1> [0]>, Uniform>
func @memref_8bit_Uniform(%arg0: memref<16xi8, 4>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Uniform
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16 [2]> [0]>, Uniform>
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16 [2]> [0]>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16, stride=2> [0]>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16, stride=2> [0]>, Uniform>
func @memref_16bit_Uniform(
%arg0: memref<16xi16, 4>,
%arg1: memref<16xf16, 4>
@@ -438,11 +438,11 @@ module attributes {
} {
// CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16 [2]> [0]>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x f16, stride=2> [0]>, Input>
func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16 [2]> [0]>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<16 x i16, stride=2> [0]>, Output>
func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
} // end module
@@ -459,22 +459,22 @@ module attributes {
// CHECK-LABEL: spv.func @memref_offset_strides
func @memref_offset_strides(
-// CHECK-SAME: !spv.array<64 x f32 [4]> [0]>, StorageBuffer>
-// CHECK-SAME: !spv.array<72 x f32 [4]> [0]>, StorageBuffer>
-// CHECK-SAME: !spv.array<256 x f32 [4]> [0]>, StorageBuffer>
-// CHECK-SAME: !spv.array<64 x f32 [4]> [0]>, StorageBuffer>
-// CHECK-SAME: !spv.array<88 x f32 [4]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<64 x f32, stride=4> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<72 x f32, stride=4> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<256 x f32, stride=4> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<64 x f32, stride=4> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<88 x f32, stride=4> [0]>, StorageBuffer>
%arg0: memref<16x4xf32, offset: 0, strides: [4, 1]>, // tightly packed; row major
%arg1: memref<16x4xf32, offset: 8, strides: [4, 1]>, // offset 8
%arg2: memref<16x4xf32, offset: 0, strides: [16, 1]>, // pad 12 after each row
%arg3: memref<16x4xf32, offset: 0, strides: [1, 16]>, // tightly packed; col major
%arg4: memref<16x4xf32, offset: 0, strides: [1, 22]>, // pad 4 after each col
-// CHECK-SAME: !spv.array<64 x f16 [2]> [0]>, StorageBuffer>
-// CHECK-SAME: !spv.array<72 x f16 [2]> [0]>, StorageBuffer>
-// CHECK-SAME: !spv.array<256 x f16 [2]> [0]>, StorageBuffer>
-// CHECK-SAME: !spv.array<64 x f16 [2]> [0]>, StorageBuffer>
-// CHECK-SAME: !spv.array<88 x f16 [2]> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<64 x f16, stride=2> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<72 x f16, stride=2> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<256 x f16, stride=2> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<64 x f16, stride=2> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.array<88 x f16, stride=2> [0]>, StorageBuffer>
%arg5: memref<16x4xf16, offset: 0, strides: [4, 1]>,
%arg6: memref<16x4xf16, offset: 8, strides: [4, 1]>,
%arg7: memref<16x4xf16, offset: 0, strides: [16, 1]>,
@@ -519,10 +519,10 @@ module attributes {
} {
// CHECK-LABEL: spv.func @int_tensor_types
-// CHECK-SAME: !spv.array<32 x i64 [8]>
-// CHECK-SAME: !spv.array<32 x i32 [4]>
-// CHECK-SAME: !spv.array<32 x i16 [2]>
-// CHECK-SAME: !spv.array<32 x i8 [1]>
+// CHECK-SAME: !spv.array<32 x i64, stride=8>
+// CHECK-SAME: !spv.array<32 x i32, stride=4>
+// CHECK-SAME: !spv.array<32 x i16, stride=2>
+// CHECK-SAME: !spv.array<32 x i8, stride=1>
func @int_tensor_types(
%arg0: tensor<8x4xi64>,
%arg1: tensor<8x4xi32>,
@@ -531,9 +531,9 @@ func @int_tensor_types(
) { return }
// CHECK-LABEL: spv.func @float_tensor_types
-// CHECK-SAME: !spv.array<32 x f64 [8]>
-// CHECK-SAME: !spv.array<32 x f32 [4]>
-// CHECK-SAME: !spv.array<32 x f16 [2]>
+// CHECK-SAME: !spv.array<32 x f64, stride=8>
+// CHECK-SAME: !spv.array<32 x f32, stride=4>
+// CHECK-SAME: !spv.array<32 x f16, stride=2>
func @float_tensor_types(
%arg0: tensor<8x4xf64>,
%arg1: tensor<8x4xf32>,
@@ -553,10 +553,10 @@ module attributes {
} {
// CHECK-LABEL: spv.func @int_tensor_types
-// CHECK-SAME: !spv.array<32 x i32 [4]>
-// CHECK-SAME: !spv.array<32 x i32 [4]>
-// CHECK-SAME: !spv.array<32 x i32 [4]>
-// CHECK-SAME: !spv.array<32 x i32 [4]>
+// CHECK-SAME: !spv.array<32 x i32, stride=4>
+// CHECK-SAME: !spv.array<32 x i32, stride=4>
+// CHECK-SAME: !spv.array<32 x i32, stride=4>
+// CHECK-SAME: !spv.array<32 x i32, stride=4>
func @int_tensor_types(
%arg0: tensor<8x4xi64>,
%arg1: tensor<8x4xi32>,
@@ -565,9 +565,9 @@ func @int_tensor_types(
) { return }
// CHECK-LABEL: spv.func @float_tensor_types
-// CHECK-SAME: !spv.array<32 x f32 [4]>
-// CHECK-SAME: !spv.array<32 x f32 [4]>
-// CHECK-SAME: !spv.array<32 x f32 [4]>
+// CHECK-SAME: !spv.array<32 x f32, stride=4>
+// CHECK-SAME: !spv.array<32 x f32, stride=4>
+// CHECK-SAME: !spv.array<32 x f32, stride=4>
func @float_tensor_types(
%arg0: tensor<8x4xf64>,
%arg1: tensor<8x4xf32>,
diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
index cc94c089dfb2..8632d2d9f37d 100644
--- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
@@ -16,7 +16,7 @@ module attributes {
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @fold_static_stride_subview_with_load
-// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<384 x f32 [4]> [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32
+// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<384 x f32, stride=4> [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32
func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) {
// CHECK: [[C2:%.*]] = spv.constant 2
// CHECK: [[C3:%.*]] = spv.constant 3
@@ -38,7 +38,7 @@ func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : in
}
// CHECK-LABEL: @fold_static_stride_subview_with_store
-// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<384 x f32 [4]> [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: f32
+// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<384 x f32, stride=4> [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: f32
func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) {
// CHECK: [[C2:%.*]] = spv.constant 2
// CHECK: [[C3:%.*]] = spv.constant 3
diff --git a/mlir/test/Dialect/SPIRV/Serialization/array.mlir b/mlir/test/Dialect/SPIRV/Serialization/array.mlir
index aa7cc405b5ee..0d14cbf9d3b6 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/array.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/array.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
- spv.func @array_stride(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32 [4]> [128]>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" {
- // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32 [4]> [128]>, StorageBuffer>
- %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32 [4]> [128]>, StorageBuffer>
+ spv.func @array_stride(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32, stride=4>, stride=128>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" {
+ // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32, stride=4>, stride=128>, StorageBuffer>
+ %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32, stride=4>, stride=128>, StorageBuffer>
spv.Return
}
}
@@ -11,8 +11,8 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// -----
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
- // CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.rtarray<f32>, StorageBuffer>
- spv.globalVariable @var0 : !spv.ptr<!spv.rtarray<f32>, StorageBuffer>
+ // CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.rtarray<f32, stride=4>, StorageBuffer>
+ spv.globalVariable @var0 : !spv.ptr<!spv.rtarray<f32, stride=4>, StorageBuffer>
// CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.rtarray<vector<4xf16>>, Input>
spv.globalVariable @var1 : !spv.ptr<!spv.rtarray<vector<4xf16>>, Input>
}
diff --git a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir
index 180bd2b644be..a276e4ee9781 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir
@@ -226,16 +226,16 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
}
// CHECK-LABEL: @multi_dimensions_const
- spv.func @multi_dimensions_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) "None" {
- // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]], {{\[}}[7 : i32, 8 : i32, 9 : i32], [10 : i32, 11 : i32, 12 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
- %0 = spv.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
- spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
+ spv.func @multi_dimensions_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24>) "None" {
+ // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]], {{\[}}[7 : i32, 8 : i32, 9 : i32], [10 : i32, 11 : i32, 12 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24>
+ %0 = spv.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24>
+ spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24>
}
// CHECK-LABEL: @multi_dimensions_splat_const
- spv.func @multi_dimensions_splat_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) "None" {
- // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]], {{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
- %0 = spv.constant dense<1> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
- spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
+ spv.func @multi_dimensions_splat_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24>) "None" {
+ // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]], {{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24>
+ %0 = spv.constant dense<1> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24>
+ spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32, stride=4>, stride=12>, stride=24>
}
}
diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
index e280f21c38f1..1b041a4aa604 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
@@ -60,14 +60,14 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// -----
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
- spv.globalVariable @GV1 bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
- spv.globalVariable @GV2 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+ spv.globalVariable @GV1 bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+ spv.globalVariable @GV2 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
spv.func @loop_kernel() "None" {
- %0 = spv._address_of @GV1 : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+ %0 = spv._address_of @GV1 : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
- %2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
- %3 = spv._address_of @GV2 : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
- %5 = spv.AccessChain %3[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+ %2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+ %3 = spv._address_of @GV2 : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+ %5 = spv.AccessChain %3[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
%6 = spv.constant 4 : i32
%7 = spv.constant 42 : i32
%8 = spv.constant 2 : i32
@@ -84,9 +84,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.BranchConditional %10, ^body, ^merge
// CHECK-NEXT: ^bb2: // pred: ^bb1
^body:
- %11 = spv.AccessChain %2[%9] : !spv.ptr<!spv.array<10 x f32 [4]>, StorageBuffer>
+ %11 = spv.AccessChain %2[%9] : !spv.ptr<!spv.array<10 x f32, stride=4>, StorageBuffer>
%12 = spv.Load "StorageBuffer" %11 : f32
- %13 = spv.AccessChain %5[%9] : !spv.ptr<!spv.array<10 x f32 [4]>, StorageBuffer>
+ %13 = spv.AccessChain %5[%9] : !spv.ptr<!spv.array<10 x f32, stride=4>, StorageBuffer>
spv.Store "StorageBuffer" %13, %12 : f32
// CHECK: %[[ADD:.*]] = spv.IAdd
%14 = spv.IAdd %9, %8 : i32
diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
index d082fa01d9ae..fbe45fa87d20 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
@@ -27,32 +27,32 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// -----
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
- spv.func @load_store_zero_rank_float(%arg0: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>, %arg1: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>) "None" {
- // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>
+ spv.func @load_store_zero_rank_float(%arg0: !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>, StorageBuffer>, %arg1: !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>
// CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32
%0 = spv.constant 0 : i32
- %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>
+ %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>, StorageBuffer>
%2 = spv.Load "StorageBuffer" %1 : f32
- // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>
+ // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>
// CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
%3 = spv.constant 0 : i32
- %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>
+ %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>, StorageBuffer>
spv.Store "StorageBuffer" %4, %2 : f32
spv.Return
}
- spv.func @load_store_zero_rank_int(%arg0: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>, %arg1: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>) "None" {
- // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>
+ spv.func @load_store_zero_rank_int(%arg0: !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>, StorageBuffer>, %arg1: !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>
// CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32
%0 = spv.constant 0 : i32
- %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>
+ %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>, StorageBuffer>
%2 = spv.Load "StorageBuffer" %1 : i32
- // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>
+ // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>
// CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
%3 = spv.constant 0 : i32
- %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>
+ %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>, StorageBuffer>
spv.Store "StorageBuffer" %4, %2 : i32
spv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
index e96cc418615f..3066462fd71b 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir
@@ -1,17 +1,17 @@
// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
- // CHECK: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>
- spv.globalVariable @var0 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>
+ // CHECK: !spv.ptr<!spv.struct<!spv.array<128 x f32, stride=4> [0]>, Input>
+ spv.globalVariable @var0 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<128 x f32, stride=4> [0]>, Input>
- // CHECK: !spv.ptr<!spv.struct<f32 [0], !spv.struct<f32 [0], !spv.array<16 x f32 [4]> [4]> [4]>, Input>
- spv.globalVariable @var1 bind(0, 2) : !spv.ptr<!spv.struct<f32 [0], !spv.struct<f32 [0], !spv.array<16 x f32 [4]> [4]> [4]>, Input>
+ // CHECK: !spv.ptr<!spv.struct<f32 [0], !spv.struct<f32 [0], !spv.array<16 x f32, stride=4> [4]> [4]>, Input>
+ spv.globalVariable @var1 bind(0, 2) : !spv.ptr<!spv.struct<f32 [0], !spv.struct<f32 [0], !spv.array<16 x f32, stride=4> [4]> [4]>, Input>
// CHECK: !spv.ptr<!spv.struct<f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]>, StorageBuffer>
spv.globalVariable @var2 : !spv.ptr<!spv.struct<f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]>, StorageBuffer>
- // CHECK: !spv.ptr<!spv.struct<!spv.array<128 x !spv.struct<!spv.array<128 x f32 [4]> [0]> [4]> [0]>, StorageBuffer>
- spv.globalVariable @var3 : !spv.ptr<!spv.struct<!spv.array<128 x !spv.struct<!spv.array<128 x f32 [4]> [0]> [4]> [0]>, StorageBuffer>
+ // CHECK: !spv.ptr<!spv.struct<!spv.array<128 x !spv.struct<!spv.array<128 x f32, stride=4> [0]>, stride=512> [0]>, StorageBuffer>
+ spv.globalVariable @var3 : !spv.ptr<!spv.struct<!spv.array<128 x !spv.struct<!spv.array<128 x f32, stride=4> [0]>, stride=512> [0]>, StorageBuffer>
// CHECK: !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4]>, StorageBuffer>
spv.globalVariable @var4 : !spv.ptr<!spv.struct<f32 [0, NonWritable], i32 [4]>, StorageBuffer>
@@ -25,9 +25,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK: !spv.ptr<!spv.struct<>, StorageBuffer>
spv.globalVariable @empty : !spv.ptr<!spv.struct<>, StorageBuffer>
- // CHECK: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>,
- // CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Output>
- spv.func @kernel_1(%arg0: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>, %arg1: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Output>) -> () "None" {
+ // CHECK: !spv.ptr<!spv.struct<!spv.array<128 x f32, stride=4> [0]>, Input>,
+ // CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<128 x f32, stride=4> [0]>, Output>
+ spv.func @kernel(%arg0: !spv.ptr<!spv.struct<!spv.array<128 x f32, stride=4> [0]>, Input>, %arg1: !spv.ptr<!spv.struct<!spv.array<128 x f32, stride=4> [0]>, Output>) -> () "None" {
spv.Return
}
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index 3972def985bb..187c5741d7f0 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -10,7 +10,7 @@ module attributes {
// CHECK-LABEL: spv.module
spv.module Logical GLSL450 {
// CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
- // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+ // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer>
// CHECK: spv.func [[FN:@.*]]()
spv.func @kernel(
%arg0: f32
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
index 42ff3f55e1ea..c75c4d0f979c 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
@@ -17,9 +17,9 @@ spv.module Logical GLSL450 {
spv.globalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
spv.globalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
- // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
- // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
- // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
+ // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32, stride=4>, stride=16> [0]>, StorageBuffer>
+ // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32, stride=4>, stride=16> [0]>, StorageBuffer>
+ // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32, stride=4>, stride=16> [0]>, StorageBuffer>
// CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
diff --git a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir
index 1129f89d7d84..975012d3a26a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir
@@ -4,19 +4,19 @@ spv.module Logical GLSL450 {
// CHECK: spv.globalVariable @var0 bind(0, 1) : !spv.ptr<!spv.struct<i32 [0], !spv.struct<f32 [0], i32 [4]> [4], f32 [12]>, Uniform>
spv.globalVariable @var0 bind(0,1) : !spv.ptr<!spv.struct<i32, !spv.struct<f32, i32>, f32>, Uniform>
- // CHECK: spv.globalVariable @var1 bind(0, 2) : !spv.ptr<!spv.struct<!spv.array<64 x i32 [4]> [0], f32 [256]>, StorageBuffer>
+ // CHECK: spv.globalVariable @var1 bind(0, 2) : !spv.ptr<!spv.struct<!spv.array<64 x i32, stride=4> [0], f32 [256]>, StorageBuffer>
spv.globalVariable @var1 bind(0,2) : !spv.ptr<!spv.struct<!spv.array<64xi32>, f32>, StorageBuffer>
- // CHECK: spv.globalVariable @var2 bind(1, 0) : !spv.ptr<!spv.struct<!spv.struct<!spv.array<64 x i32 [4]> [0], f32 [256]> [0], i32 [260]>, StorageBuffer>
+ // CHECK: spv.globalVariable @var2 bind(1, 0) : !spv.ptr<!spv.struct<!spv.struct<!spv.array<64 x i32, stride=4> [0], f32 [256]> [0], i32 [260]>, StorageBuffer>
spv.globalVariable @var2 bind(1,0) : !spv.ptr<!spv.struct<!spv.struct<!spv.array<64xi32>, f32>, i32>, StorageBuffer>
- // CHECK: spv.globalVariable @var3 : !spv.ptr<!spv.struct<!spv.array<16 x !spv.struct<f32 [0], f32 [4], !spv.array<16 x f32 [4]> [8]> [72]> [0], f32 [1152]>, StorageBuffer>
+ // CHECK: spv.globalVariable @var3 : !spv.ptr<!spv.struct<!spv.array<16 x !spv.struct<f32 [0], f32 [4], !spv.array<16 x f32, stride=4> [8]>, stride=72> [0], f32 [1152]>, StorageBuffer>
spv.globalVariable @var3 : !spv.ptr<!spv.struct<!spv.array<16x!spv.struct<f32, f32, !spv.array<16xf32>>>, f32>, StorageBuffer>
// CHECK: spv.globalVariable @var4 bind(1, 2) : !spv.ptr<!spv.struct<!spv.struct<!spv.struct<i1 [0], i8 [1], i16 [2], i32 [4], i64 [8]> [0], f32 [16], i1 [20]> [0], i1 [24]>, StorageBuffer>
spv.globalVariable @var4 bind(1,2) : !spv.ptr<!spv.struct<!spv.struct<!spv.struct<i1, i8, i16, i32, i64>, f32, i1>, i1>, StorageBuffer>
- // CHECK: spv.globalVariable @var5 bind(1, 3) : !spv.ptr<!spv.struct<!spv.array<256 x f32 [4]> [0]>, StorageBuffer>
+ // CHECK: spv.globalVariable @var5 bind(1, 3) : !spv.ptr<!spv.struct<!spv.array<256 x f32, stride=4> [0]>, StorageBuffer>
spv.globalVariable @var5 bind(1,3) : !spv.ptr<!spv.struct<!spv.array<256xf32>>, StorageBuffer>
spv.func @kernel() -> () "None" {
@@ -38,10 +38,10 @@ spv.module Logical GLSL450 {
// CHECK: spv.globalVariable @var1 : !spv.ptr<!spv.struct<!spv.struct<i16 [0], !spv.struct<i1 [0], f64 [8]> [8], f32 [24]> [0], f32 [32]>, Uniform>
spv.globalVariable @var1 : !spv.ptr<!spv.struct<!spv.struct<i16, !spv.struct<i1, f64>, f32>, f32>, Uniform>
- // CHECK: spv.globalVariable @var2 : !spv.ptr<!spv.struct<!spv.struct<i16 [0], !spv.struct<i1 [0], !spv.array<16 x !spv.array<16 x i64 [8]> [128]> [8]> [8], f32 [2064]> [0], f32 [2072]>, Uniform>
+ // CHECK: spv.globalVariable @var2 : !spv.ptr<!spv.struct<!spv.struct<i16 [0], !spv.struct<i1 [0], !spv.array<16 x !spv.array<16 x i64, stride=8>, stride=128> [8]> [8], f32 [2064]> [0], f32 [2072]>, Uniform>
spv.globalVariable @var2 : !spv.ptr<!spv.struct<!spv.struct<i16, !spv.struct<i1, !spv.array<16x!spv.array<16xi64>>>, f32>, f32>, Uniform>
- // CHECK: spv.globalVariable @var3 : !spv.ptr<!spv.struct<!spv.struct<!spv.array<64 x i64 [8]> [0], i1 [512]> [0], i1 [520]>, Uniform>
+ // CHECK: spv.globalVariable @var3 : !spv.ptr<!spv.struct<!spv.struct<!spv.array<64 x i64, stride=8> [0], i1 [512]> [0], i1 [520]>, Uniform>
spv.globalVariable @var3 : !spv.ptr<!spv.struct<!spv.struct<!spv.array<64xi64>, i1>, i1>, Uniform>
// CHECK: spv.globalVariable @var4 : !spv.ptr<!spv.struct<i1 [0], !spv.struct<i64 [0], i1 [8], i1 [9], i1 [10], i1 [11]> [8], i1 [24]>, Uniform>
diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
index 8e0f447223e4..2b1f7a038fe4 100644
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -58,20 +58,20 @@ func @const() -> () {
// CHECK: %2 = spv.constant 5.000000e-01 : f32
// CHECK: %3 = spv.constant dense<[2, 3]> : vector<2xi32>
// CHECK: %4 = spv.constant [dense<3.000000e+00> : vector<2xf32>] : !spv.array<1 x vector<2xf32>>
- // CHECK: %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
- // CHECK: %6 = spv.constant dense<1.000000e+00> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
- // CHECK: %7 = spv.constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
- // CHECK: %8 = spv.constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
+ // CHECK: %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>>
+ // CHECK: %6 = spv.constant dense<1.000000e+00> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>>
+ // CHECK: %7 = spv.constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>>
+ // CHECK: %8 = spv.constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>>
%0 = spv.constant true
%1 = spv.constant 42 : i32
%2 = spv.constant 0.5 : f32
%3 = spv.constant dense<[2, 3]> : vector<2xi32>
%4 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
- %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
- %6 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
- %7 = spv.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
- %8 = spv.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
+ %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>>
+ %6 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>>
+ %7 = spv.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>>
+ %8 = spv.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>>
return
}
@@ -118,14 +118,14 @@ func @value_result_type_mismatch() -> () {
func @value_result_type_mismatch() -> () {
// expected-error @+1 {{result element type ('i32') does not match value element type ('f32')}}
- %0 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
+ %0 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x i32>>
}
// -----
func @value_result_num_elements_mismatch() -> () {
// expected-error @+1 {{result number of elements (6) does not match value number of elements (4)}}
- %0 = spv.constant dense<1.0> : tensor<2x2xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
+ %0 = spv.constant dense<1.0> : tensor<2x2xf32> : !spv.array<2 x !spv.array<3 x f32>>
return
}
diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir
index 3c11d3b88a02..4c1adafce4a8 100644
--- a/mlir/test/Dialect/SPIRV/types.mlir
+++ b/mlir/test/Dialect/SPIRV/types.mlir
@@ -12,8 +12,8 @@ func @scalar_array_type(!spv.array<16xf32>, !spv.array<8 x i32>) -> ()
// CHECK: func @vector_array_type(!spv.array<32 x vector<4xf32>>)
func @vector_array_type(!spv.array< 32 x vector<4xf32> >) -> ()
-// CHECK: func @array_type_stride(!spv.array<4 x !spv.array<4 x f32 [4]> [128]>)
-func @array_type_stride(!spv.array< 4 x !spv.array<4 x f32 [4]> [128]>) -> ()
+// CHECK: func @array_type_stride(!spv.array<4 x !spv.array<4 x f32, stride=4>, stride=128>)
+func @array_type_stride(!spv.array< 4 x !spv.array<4 x f32, stride=4>, stride = 128>) -> ()
// -----
@@ -78,7 +78,7 @@ func @llvm_type(!spv.array<4x!llvm.i32>) -> ()
// -----
// expected-error @+1 {{ArrayStride must be greater than zero}}
-func @array_type_zero_stride(!spv.array<4xi32 [0]>) -> ()
+func @array_type_zero_stride(!spv.array<4xi32, stride=0>) -> ()
// -----
@@ -132,6 +132,9 @@ func @scalar_runtime_array_type(!spv.rtarray<f32>, !spv.rtarray<i32>) -> ()
// CHECK: func @vector_runtime_array_type(!spv.rtarray<vector<4xf32>>)
func @vector_runtime_array_type(!spv.rtarray< vector<4xf32> >) -> ()
+// CHECK: func @runtime_array_type_stride(!spv.rtarray<f32, stride=4>)
+func @runtime_array_type_stride(!spv.rtarray<f32, stride=4>) -> ()
+
// -----
// expected-error @+1 {{expected '<'}}
@@ -149,6 +152,11 @@ func @redundant_count(!spv.rtarray<4xf32>) -> ()
// -----
+// expected-error @+1 {{ArrayStride must be greater than zero}}
+func @runtime_array_type_zero_stride(!spv.rtarray<i32, stride=0>) -> ()
+
+// -----
+
//===----------------------------------------------------------------------===//
// ImageType
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list