[Mlir-commits] [mlir] ca7c058 - [mlir][spirv] Rework type capability queries (#160113)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 22 08:28:45 PDT 2025
Author: Jakub Kuderski
Date: 2025-09-22T15:28:41Z
New Revision: ca7c058701bbbdd1b9bbdb083cbcb21f2bb47735
URL: https://github.com/llvm/llvm-project/commit/ca7c058701bbbdd1b9bbdb083cbcb21f2bb47735
DIFF: https://github.com/llvm/llvm-project/commit/ca7c058701bbbdd1b9bbdb083cbcb21f2bb47735.diff
LOG: [mlir][spirv] Rework type capability queries (#160113)
* Fix infinite recursion with recursive (self-referential) structs.
* Drop the `::getCapabilities` functions from derived types, so that there's
only one entry point that queries type capabilities.
* Move all capability logic to a new helper class -- this way the
`::getCapabilities` functions can't diverge across concrete types and
'convenience types' like CompositeType.
Fixes: #159963
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 6beffc17d6d58..475e3f495e065 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -89,9 +89,6 @@ class ScalarType : public SPIRVType {
/// Returns true if the given float type is valid for the SPIR-V dialect.
static bool isValid(IntegerType);
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
-
std::optional<int64_t> getSizeInBytes();
};
@@ -116,9 +113,6 @@ class CompositeType : public SPIRVType {
/// implementation dependent.
bool hasCompileTimeKnownNumElements() const;
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
-
std::optional<int64_t> getSizeInBytes();
};
@@ -144,9 +138,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
/// type.
unsigned getArrayStride() const;
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
-
/// Returns the array size in bytes. Since array type may have an explicit
/// stride declaration (in bytes), we also include it in the calculation.
std::optional<int64_t> getSizeInBytes();
@@ -186,9 +177,6 @@ class ImageType
ImageSamplerUseInfo getSamplerUseInfo() const;
ImageFormat getImageFormat() const;
// TODO: Add support for Access qualifier
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
// SPIR-V pointer type
@@ -204,9 +192,6 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
Type getPointeeType() const;
StorageClass getStorageClass() const;
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
// SPIR-V run-time array type
@@ -228,9 +213,6 @@ class RuntimeArrayType
/// Returns the array stride in bytes. 0 means no stride decorated on this
/// type.
unsigned getArrayStride() const;
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
// SPIR-V sampled image type
@@ -252,10 +234,6 @@ class SampledImageType
Type imageType);
Type getImageType() const;
-
- void
- getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<spirv::StorageClass> storage = std::nullopt);
};
/// SPIR-V struct type. Two kinds of struct types are supported:
@@ -405,9 +383,6 @@ class StructType
trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
ArrayRef<MemberDecorationInfo> memberDecorations = {},
ArrayRef<StructDecorationInfo> structDecorations = {});
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
llvm::hash_code
@@ -440,9 +415,6 @@ class CooperativeMatrixType
/// Returns the use parameter of the cooperative matrix.
CooperativeMatrixUseKHR getUse() const;
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
-
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
ArrayRef<int64_t> getShape() const;
@@ -493,9 +465,6 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
/// Returns the elements' type (i.e, single element type).
Type getElementType() const;
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
/// SPIR-V TensorARM Type
@@ -531,9 +500,6 @@ class TensorArmType
ArrayRef<int64_t> getShape() const;
bool hasRank() const { return !getShape().empty(); }
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
};
} // namespace spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 8244e64abba12..7c2f43bea9ddb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -45,17 +45,67 @@ class TypeExtensionVisitor {
return;
TypeSwitch<SPIRVType>(type)
- .Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType>(
+ .Case<CooperativeMatrixType, PointerType, ScalarType, TensorArmType>(
[this](auto concreteType) { addConcrete(concreteType); })
- .Case<VectorType, ArrayType, RuntimeArrayType, MatrixType, ImageType>(
+ .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
[this](auto concreteType) { add(concreteType.getElementType()); })
+ .Case<SampledImageType>([this](SampledImageType concreteType) {
+ add(concreteType.getImageType());
+ })
.Case<StructType>([this](StructType concreteType) {
for (Type elementType : concreteType.getElementTypes())
add(elementType);
})
+ .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+ }
+
+ void add(Type type) { add(cast<SPIRVType>(type)); }
+
+private:
+ // Types that add unique extensions.
+ void addConcrete(CooperativeMatrixType type);
+ void addConcrete(PointerType type);
+ void addConcrete(ScalarType type);
+ void addConcrete(TensorArmType type);
+
+ SPIRVType::ExtensionArrayRefVector &extensions;
+ std::optional<StorageClass> storage;
+ llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
+};
+
+// Helper class to collect capabilities implied by a type by visiting all its
+// subtypes. Maintains a set of `seen` types to avoid infinite recursion.
+//
+// Serves as the source-of-truth for type capability information. All capability
+// logic should be added to this class, while the
+// `SPIRVType::getCapabilities` function should not handle capability-related
+// logic directly and only invoke `TypeCapabilityVisitor::add(SPIRVType)`.
+class TypeCapabilityVisitor {
+public:
+ TypeCapabilityVisitor(SPIRVType::CapabilityArrayRefVector &capabilities,
+ std::optional<StorageClass> storage)
+ : capabilities(capabilities), storage(storage) {}
+
+ // Main visitor entry point. Adds all capabilities to the vector. Saves `type`
+ // as seen and dispatches to the matching `addConcrete` overload or handler.
+ void add(SPIRVType type) {
+ if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
+ return;
+
+ TypeSwitch<SPIRVType>(type)
+ .Case<CooperativeMatrixType, ImageType, MatrixType, PointerType,
+ RuntimeArrayType, ScalarType, TensorArmType, VectorType>(
+ [this](auto concreteType) { addConcrete(concreteType); })
+ .Case<ArrayType>([this](ArrayType concreteType) {
+ add(concreteType.getElementType());
+ })
.Case<SampledImageType>([this](SampledImageType concreteType) {
add(concreteType.getImageType());
})
+ .Case<StructType>([this](StructType concreteType) {
+ for (Type elementType : concreteType.getElementTypes())
+ add(elementType);
+ })
.Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
}
@@ -63,12 +113,16 @@ class TypeExtensionVisitor {
private:
// Types that add unique extensions.
- void addConcrete(ScalarType type);
- void addConcrete(PointerType type);
void addConcrete(CooperativeMatrixType type);
+ void addConcrete(ImageType type);
+ void addConcrete(MatrixType type);
+ void addConcrete(PointerType type);
+ void addConcrete(RuntimeArrayType type);
+ void addConcrete(ScalarType type);
void addConcrete(TensorArmType type);
+ void addConcrete(VectorType type);
- SPIRVType::ExtensionArrayRefVector &extensions;
+ SPIRVType::CapabilityArrayRefVector &capabilities;
std::optional<StorageClass> storage;
llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
};
@@ -118,13 +172,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
-void ArrayType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType())
- .getCapabilities(capabilities, storage);
-}
-
std::optional<int64_t> ArrayType::getSizeInBytes() {
auto elementType = llvm::cast<SPIRVType>(getElementType());
std::optional<int64_t> size = elementType.getSizeInBytes();
@@ -188,30 +235,14 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
}
-void CompositeType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
- StructType>(
- [&](auto type) { type.getCapabilities(capabilities, storage); })
- .Case<VectorType>([&](VectorType type) {
- auto vecSize = getNumElements();
- if (vecSize == 8 || vecSize == 16) {
- static const Capability caps[] = {Capability::Vector16};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
- }
- return llvm::cast<ScalarType>(type.getElementType())
- .getCapabilities(capabilities, storage);
- })
- .Case<TensorArmType>([&](TensorArmType type) {
- static constexpr Capability cap{Capability::TensorsARM};
- capabilities.push_back(cap);
- return llvm::cast<ScalarType>(type.getElementType())
- .getCapabilities(capabilities, storage);
- })
- .Default([](Type) { llvm_unreachable("invalid composite type"); });
+void TypeCapabilityVisitor::addConcrete(VectorType type) {
+ add(type.getElementType());
+
+ int64_t vecSize = type.getNumElements();
+ if (vecSize == 8 || vecSize == 16) {
+ static constexpr auto cap = Capability::Vector16;
+ capabilities.push_back(cap);
+ }
}
std::optional<int64_t> CompositeType::getSizeInBytes() {
@@ -317,12 +348,9 @@ void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
extensions.push_back(ext);
}
-void CooperativeMatrixType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType())
- .getCapabilities(capabilities, storage);
- static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
+void TypeCapabilityVisitor::addConcrete(CooperativeMatrixType type) {
+ add(type.getElementType());
+ static constexpr auto caps = Capability::CooperativeMatrixKHR;
capabilities.push_back(caps);
}
@@ -428,14 +456,14 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
-void ImageType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass>) {
- if (auto dimCaps = spirv::getCapabilities(getDim()))
+void TypeCapabilityVisitor::addConcrete(ImageType type) {
+ if (auto dimCaps = spirv::getCapabilities(type.getDim()))
capabilities.push_back(*dimCaps);
- if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
+ if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat()))
capabilities.push_back(*fmtCaps);
+
+ add(type.getElementType());
}
//===----------------------------------------------------------------------===//
@@ -486,15 +514,15 @@ void TypeExtensionVisitor::addConcrete(PointerType type) {
extensions.push_back(*scExts);
}
-void PointerType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
+void TypeCapabilityVisitor::addConcrete(PointerType type) {
// Use this pointer type's storage class because this pointer indicates we are
// using the pointee type in that specific storage class.
- llvm::cast<SPIRVType>(getPointeeType())
- .getCapabilities(capabilities, getStorageClass());
+ std::optional<StorageClass> oldStorageClass = storage;
+ storage = type.getStorageClass();
+ add(type.getPointeeType());
+ storage = oldStorageClass;
- if (auto scCaps = spirv::getCapabilities(getStorageClass()))
+ if (auto scCaps = spirv::getCapabilities(type.getStorageClass()))
capabilities.push_back(*scCaps);
}
@@ -534,16 +562,10 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
-void RuntimeArrayType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- {
- static const Capability caps[] = {Capability::Shader};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
- }
- llvm::cast<SPIRVType>(getElementType())
- .getCapabilities(capabilities, storage);
+void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
+ add(type.getElementType());
+ static constexpr auto cap = Capability::Shader;
+ capabilities.push_back(cap);
}
//===----------------------------------------------------------------------===//
@@ -601,10 +623,8 @@ void TypeExtensionVisitor::addConcrete(ScalarType type) {
}
}
-void ScalarType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- unsigned bitwidth = getIntOrFloatBitWidth();
+void TypeCapabilityVisitor::addConcrete(ScalarType type) {
+ unsigned bitwidth = type.getIntOrFloatBitWidth();
// 8- or 16-bit integer/floating-point numbers will require extra capabilities
// to appear in interface storage classes. See SPV_KHR_16bit_storage and
@@ -613,15 +633,13 @@ void ScalarType::getCapabilities(
#define STORAGE_CASE(storage, cap8, cap16) \
case StorageClass::storage: { \
if (bitwidth == 8) { \
- static const Capability caps[] = {Capability::cap8}; \
- ArrayRef<Capability> ref(caps, std::size(caps)); \
- capabilities.push_back(ref); \
+ static constexpr auto cap = Capability::cap8; \
+ capabilities.push_back(cap); \
return; \
} \
if (bitwidth == 16) { \
- static const Capability caps[] = {Capability::cap16}; \
- ArrayRef<Capability> ref(caps, std::size(caps)); \
- capabilities.push_back(ref); \
+ static constexpr auto cap = Capability::cap16; \
+ capabilities.push_back(cap); \
return; \
} \
/* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
@@ -640,9 +658,8 @@ void ScalarType::getCapabilities(
case StorageClass::Input:
case StorageClass::Output: {
if (bitwidth == 16) {
- static const Capability caps[] = {Capability::StorageInputOutput16};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
+ static constexpr auto cap = Capability::StorageInputOutput16;
+ capabilities.push_back(cap);
return;
}
break;
@@ -658,12 +675,11 @@ void ScalarType::getCapabilities(
#define WIDTH_CASE(type, width) \
case width: { \
- static const Capability caps[] = {Capability::type##width}; \
- ArrayRef<Capability> ref(caps, std::size(caps)); \
- capabilities.push_back(ref); \
+ static constexpr auto cap = Capability::type##width; \
+ capabilities.push_back(cap); \
} break
- if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
+ if (auto intType = dyn_cast<IntegerType>(type)) {
switch (bitwidth) {
WIDTH_CASE(Int, 8);
WIDTH_CASE(Int, 16);
@@ -675,14 +691,14 @@ void ScalarType::getCapabilities(
llvm_unreachable("invalid bitwidth to getCapabilities");
}
} else {
- assert(llvm::isa<FloatType>(*this));
+ assert(isa<FloatType>(type));
switch (bitwidth) {
case 16: {
- if (isa<BFloat16Type>(*this)) {
- static const Capability cap = Capability::BFloat16TypeKHR;
+ if (isa<BFloat16Type>(type)) {
+ static constexpr auto cap = Capability::BFloat16TypeKHR;
capabilities.push_back(cap);
} else {
- static const Capability cap = Capability::Float16;
+ static constexpr auto cap = Capability::Float16;
capabilities.push_back(cap);
}
break;
@@ -740,23 +756,7 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
void SPIRVType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
- if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
- scalarType.getCapabilities(capabilities, storage);
- } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
- compositeType.getCapabilities(capabilities, storage);
- } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
- imageType.getCapabilities(capabilities, storage);
- } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
- sampledImageType.getCapabilities(capabilities, storage);
- } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
- matrixType.getCapabilities(capabilities, storage);
- } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
- ptrType.getCapabilities(capabilities, storage);
- } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
- tensorArmType.getCapabilities(capabilities, storage);
- } else {
- llvm_unreachable("invalid SPIR-V Type to getCapabilities");
- }
+ TypeCapabilityVisitor{capabilities, storage}.add(*this);
}
std::optional<int64_t> SPIRVType::getSizeInBytes() {
@@ -814,12 +814,6 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
return success();
}
-void SampledImageType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
-}
-
//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//
@@ -1172,13 +1166,6 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
structDecorations);
}
-void StructType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- for (Type elementType : getElementTypes())
- llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
-}
-
llvm::hash_code spirv::hash_value(
const StructType::MemberDecorationInfo &memberDecorationInfo) {
return llvm::hash_combine(memberDecorationInfo.memberIndex,
@@ -1271,16 +1258,10 @@ unsigned MatrixType::getNumElements() const {
return (getImpl()->columnCount) * getNumRows();
}
-void MatrixType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- {
- static const Capability caps[] = {Capability::Matrix};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
- }
- // Add any capabilities associated with the underlying vectors (i.e., columns)
- llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
+void TypeCapabilityVisitor::addConcrete(MatrixType type) {
+ add(type.getColumnType());
+ static constexpr auto cap = Capability::Matrix;
+ capabilities.push_back(cap);
}
//===----------------------------------------------------------------------===//
@@ -1332,12 +1313,9 @@ void TypeExtensionVisitor::addConcrete(TensorArmType type) {
extensions.push_back(ext);
}
-void TensorArmType::getCapabilities(
- SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType())
- .getCapabilities(capabilities, storage);
- static constexpr Capability cap{Capability::TensorsARM};
+void TypeCapabilityVisitor::addConcrete(TensorArmType type) {
+ add(type.getElementType());
+ static constexpr auto cap = Capability::TensorsARM;
capabilities.push_back(cap);
}
diff --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
index d24f37b553bb5..1a1c24a09aa8c 100644
--- a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --convert-scf-to-spirv %s --verify-diagnostics --split-input-file | FileCheck %s
+// RUN: mlir-opt --convert-scf-to-spirv %s | FileCheck %s
// `scf.parallel` conversion is not supported yet.
// Make sure that we do not accidentally invalidate this function by removing
@@ -19,14 +19,3 @@ func.func @func(%arg0: i64) {
}
return
}
-
-// -----
-
-// Make sure we don't crash on recursive structs.
-// TODO(https://github.com/llvm/llvm-project/issues/159963): Promote this to a `vce-deduction.mlir` testcase.
-
-// expected-error at below {{failed to legalize operation 'spirv.module' that was explicitly marked illegal}}
-spirv.module Physical64 GLSL450 {
- spirv.GlobalVariable @recursive:
- !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
-}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 2d20ae0a13105..7dab87f8081ed 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -232,7 +232,7 @@ spirv.module Logical GLSL450 attributes {
}
}
-// CHECK: requires #spirv.vce<v1.5, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
+// CHECK: requires #spirv.vce<v1.5, [GraphARM, Int8, TensorsARM, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
spirv.module Logical Vulkan attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [VulkanMemoryModel, GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph]>,
@@ -242,3 +242,14 @@ spirv.module Logical Vulkan attributes {
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi8>
}
}
+
+// Check that extension and capability queries handle recursive types.
+// CHECK: requires #spirv.vce<v1.0, [Shader, Addresses, Matrix], [SPV_KHR_storage_buffer_storage_class]>
+spirv.module Physical64 GLSL450 attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.6, [Shader, Addresses], [SPV_KHR_storage_buffer_storage_class]>,
+ #spirv.resource_limits<>>
+} {
+ spirv.GlobalVariable @recursive:
+ !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+}
More information about the Mlir-commits
mailing list