[Mlir-commits] [mlir] 32b1f16 - [mlir][spirv] Rework type extension queries (#160020)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 22 06:08:21 PDT 2025
Author: Jakub Kuderski
Date: 2025-09-22T09:08:18-04:00
New Revision: 32b1f167fbee28debc7527b939a6764575c854a4
URL: https://github.com/llvm/llvm-project/commit/32b1f167fbee28debc7527b939a6764575c854a4
DIFF: https://github.com/llvm/llvm-project/commit/32b1f167fbee28debc7527b939a6764575c854a4.diff
LOG: [mlir][spirv] Rework type extension queries (#160020)
* Fix infinite recursion with nested structs.
* Drop `::getExtensions` function from derived types, so that there's
only one entry point that queries type extensions.
* Move all extension logic to a new helper class -- this way the
`::getExtensions` functions can't diverge across concrete types and
'convenience types' like `CompositeType`.
We should also fix `::getCapabilities` in a similar way and move the
testcase to `vce-deduction.mlir`.
Issue: https://github.com/llvm/llvm-project/issues/159963
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 531feccccb032..6beffc17d6d58 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -89,8 +89,6 @@ class ScalarType : public SPIRVType {
/// Returns true if the given float type is valid for the SPIR-V dialect.
static bool isValid(IntegerType);
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
@@ -118,8 +116,6 @@ class CompositeType : public SPIRVType {
/// implementation dependent.
bool hasCompileTimeKnownNumElements() const;
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
@@ -148,8 +144,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
/// type.
unsigned getArrayStride() const;
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
@@ -193,8 +187,6 @@ class ImageType
ImageFormat getImageFormat() const;
// TODO: Add support for Access qualifier
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
@@ -213,8 +205,6 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
StorageClass getStorageClass() const;
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
@@ -239,8 +229,6 @@ class RuntimeArrayType
/// type.
unsigned getArrayStride() const;
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
@@ -265,8 +253,6 @@ class SampledImageType
Type getImageType() const;
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<spirv::StorageClass> storage = std::nullopt);
void
getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<spirv::StorageClass> storage = std::nullopt);
@@ -420,8 +406,6 @@ class StructType
ArrayRef<MemberDecorationInfo> memberDecorations = {},
ArrayRef<StructDecorationInfo> structDecorations = {});
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
@@ -456,8 +440,6 @@ class CooperativeMatrixType
/// Returns the use parameter of the cooperative matrix.
CooperativeMatrixUseKHR getUse() const;
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
@@ -512,8 +494,6 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
/// Returns the elements' type (i.e, single element type).
Type getElementType() const;
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
@@ -552,8 +532,6 @@ class TensorArmType
bool hasRank() const { return !getShape().empty(); }
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index d890dac96b118..8244e64abba12 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -14,14 +14,67 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
#include <cstdint>
using namespace mlir;
using namespace mlir::spirv;
+namespace {
+// Helper class to collect extensions implied by a type by visiting all its
+// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
+//
+// Serves as the source-of-truth for type extension information. All extension
+// logic should be added to this class, while the
+// `SPIRVType::getExtensions` function should not handle extension-related logic
+// directly and only invoke `TypeExtensionVisitor::add(Type)`.
+class TypeExtensionVisitor {
+public:
+ TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage)
+ : extensions(extensions), storage(storage) {}
+
+ // Main visitor entry point. Adds all extensions to the vector. Saves `type`
+ // as seen, then dispatches to `addConcrete` or recurses into element types.
+ void add(SPIRVType type) {
+ if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
+ return;
+
+ TypeSwitch<SPIRVType>(type)
+ .Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType>(
+ [this](auto concreteType) { addConcrete(concreteType); })
+ .Case<VectorType, ArrayType, RuntimeArrayType, MatrixType, ImageType>(
+ [this](auto concreteType) { add(concreteType.getElementType()); })
+ .Case<StructType>([this](StructType concreteType) {
+ for (Type elementType : concreteType.getElementTypes())
+ add(elementType);
+ })
+ .Case<SampledImageType>([this](SampledImageType concreteType) {
+ add(concreteType.getImageType());
+ })
+ .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+ }
+
+ void add(Type type) { add(cast<SPIRVType>(type)); }
+
+private:
+ // Types that add unique extensions.
+ void addConcrete(ScalarType type);
+ void addConcrete(PointerType type);
+ void addConcrete(CooperativeMatrixType type);
+ void addConcrete(TensorArmType type);
+
+ SPIRVType::ExtensionArrayRefVector &extensions;
+ std::optional<StorageClass> storage;
+ llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
+};
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// ArrayType
//===----------------------------------------------------------------------===//
@@ -65,11 +118,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
-void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
-}
-
void ArrayType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
@@ -140,27 +188,6 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
}
-void CompositeType::getExtensions(
- SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
- TypeSwitch<Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
- StructType>(
- [&](auto type) { type.getExtensions(extensions, storage); })
- .Case<VectorType>([&](VectorType type) {
- return llvm::cast<ScalarType>(type.getElementType())
- .getExtensions(extensions, storage);
- })
- .Case<TensorArmType>([&](TensorArmType type) {
- static constexpr Extension ext{Extension::SPV_ARM_tensors};
- extensions.push_back(ext);
- return llvm::cast<ScalarType>(type.getElementType())
- .getExtensions(extensions, storage);
- })
-
- .Default([](Type) { llvm_unreachable("invalid composite type"); });
-}
-
void CompositeType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
@@ -284,12 +311,10 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
return getImpl()->use;
}
-void CooperativeMatrixType::getExtensions(
- SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
- static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
- extensions.push_back(exts);
+void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
+ add(type.getElementType());
+ static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
+ extensions.push_back(ext);
}
void CooperativeMatrixType::getCapabilities(
@@ -403,11 +428,6 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
-void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
- std::optional<StorageClass>) {
- // Image types do not require extra extensions thus far.
-}
-
void ImageType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass>) {
@@ -454,14 +474,15 @@ StorageClass PointerType::getStorageClass() const {
return getImpl()->storageClass;
}
-void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
+void TypeExtensionVisitor::addConcrete(PointerType type) {
// Use this pointer type's storage class because this pointer indicates we are
// using the pointee type in that specific storage class.
- llvm::cast<SPIRVType>(getPointeeType())
- .getExtensions(extensions, getStorageClass());
+ std::optional<StorageClass> oldStorageClass = storage;
+ storage = type.getStorageClass();
+ add(type.getPointeeType());
+ storage = oldStorageClass;
- if (auto scExts = spirv::getExtensions(getStorageClass()))
+ if (auto scExts = spirv::getExtensions(type.getStorageClass()))
extensions.push_back(*scExts);
}
@@ -513,12 +534,6 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
-void RuntimeArrayType::getExtensions(
- SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
-}
-
void RuntimeArrayType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
@@ -553,10 +568,9 @@ bool ScalarType::isValid(IntegerType type) {
return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
}
-void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
- if (isa<BFloat16Type>(*this)) {
- static const Extension ext = Extension::SPV_KHR_bfloat16;
+void TypeExtensionVisitor::addConcrete(ScalarType type) {
+ if (isa<BFloat16Type>(type)) {
+ static constexpr auto ext = Extension::SPV_KHR_bfloat16;
extensions.push_back(ext);
}
@@ -570,18 +584,16 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
case StorageClass::PushConstant:
case StorageClass::StorageBuffer:
case StorageClass::Uniform:
- if (getIntOrFloatBitWidth() == 8) {
- static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
- ArrayRef<Extension> ref(exts, std::size(exts));
- extensions.push_back(ref);
+ if (type.getIntOrFloatBitWidth() == 8) {
+ static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
+ extensions.push_back(ext);
}
[[fallthrough]];
case StorageClass::Input:
case StorageClass::Output:
- if (getIntOrFloatBitWidth() == 16) {
- static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
- ArrayRef<Extension> ref(exts, std::size(exts));
- extensions.push_back(ref);
+ if (type.getIntOrFloatBitWidth() == 16) {
+ static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
+ extensions.push_back(ext);
}
break;
default:
@@ -722,23 +734,7 @@ bool SPIRVType::isScalarOrVector() {
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
- scalarType.getExtensions(extensions, storage);
- } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
- compositeType.getExtensions(extensions, storage);
- } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
- imageType.getExtensions(extensions, storage);
- } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
- sampledImageType.getExtensions(extensions, storage);
- } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
- matrixType.getExtensions(extensions, storage);
- } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
- ptrType.getExtensions(extensions, storage);
- } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
- tensorArmType.getExtensions(extensions, storage);
- } else {
- llvm_unreachable("invalid SPIR-V Type to getExtensions");
- }
+ TypeExtensionVisitor{extensions, storage}.add(*this);
}
void SPIRVType::getCapabilities(
@@ -818,12 +814,6 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
return success();
}
-void SampledImageType::getExtensions(
- SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
- llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
-}
-
void SampledImageType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
@@ -1182,12 +1172,6 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
structDecorations);
}
-void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
- for (Type elementType : getElementTypes())
- llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
-}
-
void StructType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
@@ -1287,11 +1271,6 @@ unsigned MatrixType::getNumElements() const {
return (getImpl()->columnCount) * getNumRows();
}
-void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
- llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
-}
-
void MatrixType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
@@ -1347,12 +1326,9 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type TensorArmType::getElementType() const { return getImpl()->elementType; }
ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
-void TensorArmType::getExtensions(
- SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage) {
-
- llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
- static constexpr Extension ext{Extension::SPV_ARM_tensors};
+void TypeExtensionVisitor::addConcrete(TensorArmType type) {
+ add(type.getElementType());
+ static constexpr auto ext = Extension::SPV_ARM_tensors;
extensions.push_back(ext);
}
diff --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
index 71bf2f3d918e8..d24f37b553bb5 100644
--- a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt --convert-scf-to-spirv %s --verify-diagnostics --split-input-file | FileCheck %s
// `scf.parallel` conversion is not supported yet.
// Make sure that we do not accidentally invalidate this function by removing
@@ -19,3 +19,14 @@ func.func @func(%arg0: i64) {
}
return
}
+
+// -----
+
+// Make sure we don't crash on recursive structs.
+// TODO(https://github.com/llvm/llvm-project/issues/159963): Promote this to a `vce-deduction.mlir` testcase.
+
+// expected-error@below {{failed to legalize operation 'spirv.module' that was explicitly marked illegal}}
+spirv.module Physical64 GLSL450 {
+ spirv.GlobalVariable @recursive:
+ !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+}
More information about the Mlir-commits
mailing list