[Mlir-commits] [mlir] [mlir][spirv] Rework type extension queries (PR #160020)

Jakub Kuderski llvmlistbot at llvm.org
Sun Sep 21 19:21:22 PDT 2025


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/160020

>From 901902b287ff692792b6f09aaeb8b643a6a203fe Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 21 Sep 2025 22:01:08 -0400
Subject: [PATCH 1/3] [mlir][spirv] Rework type extension queries

* Fix infinite recursion with recursive (self-referential) structs.
* Move all extension logic to a new helper class -- this way
  `::getExtensions` functions can't diverge across concrete types and
  'convenience types' like `CompositeType`.

We should also fix `::getCapabilities` in a similar way and move the
testcase to `vce-deduction.mlir`.

Issue: https://github.com/llvm/llvm-project/issues/159963
---
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 174 +++++++++++-------
 .../Conversion/SCFToSPIRV/unsupported.mlir    |  13 +-
 2 files changed, 120 insertions(+), 67 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index d890dac96b118..85250700b9bf9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -14,14 +14,73 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #include <cstdint>
 
 using namespace mlir;
 using namespace mlir::spirv;
 
+namespace {
+// Helper class to collect extensions implied by a type by visiting all its
+// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
+//
+// Serves as the source-of-truth for type extension information. All extension
+// logic should be added to this class, while
+// `*Type::getExtensions` functions should not handle extension-related logic
+// directly and only invoke `TypeExtensionVisitor::add(Type *)`.
+class TypeExtensionVisitor {
+  SPIRVType::ExtensionArrayRefVector &extensions;
+  std::optional<StorageClass> storage;
+  DenseSet<Type> seen;
+
+public:
+  TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
+                       std::optional<StorageClass> storage)
+      : extensions(extensions), storage(storage) {}
+
+  // Main visitor entry point. Adds all extensions to the vector. Saves `type`
+  // as seen and dispatches to the right concrete `.add` function.
+  void add(SPIRVType type) {
+    if (auto [_it, inserted] = seen.insert(type); !inserted)
+      return;
+
+    TypeSwitch<SPIRVType>(type)
+        .Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType,
+              VectorType, ArrayType, RuntimeArrayType, StructType, MatrixType,
+              ImageType, SampledImageType>(
+            [this](auto concreteType) { add(concreteType); })
+        .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+  }
+
+  // Convenience overloads for use in `T::getExtensions` functions.
+  void add(Type type) { add(cast<SPIRVType>(type)); }
+  void add(Type *type) { add(cast<SPIRVType>(*type)); }
+
+  // Types that add unique extensions.
+  void add(ScalarType type);
+  void add(PointerType type);
+  void add(CooperativeMatrixType type);
+  void add(TensorArmType type);
+
+  // Trivial passthrough without any new extensions.
+  void add(VectorType type) { add(type.getElementType()); }
+  void add(ArrayType type) { add(type.getElementType()); }
+  void add(RuntimeArrayType type) { add(type.getElementType()); }
+  void add(StructType type) {
+    for (Type elementType : type.getElementTypes())
+      add(elementType);
+  }
+  void add(MatrixType type) { add(type.getElementType()); }
+  void add(ImageType type) { add(type.getElementType()); }
+  void add(SampledImageType type) { add(type.getImageType()); }
+};
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // ArrayType
 //===----------------------------------------------------------------------===//
@@ -67,7 +126,7 @@ unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
 
 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                               std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void ArrayType::getCapabilities(
@@ -143,22 +202,7 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
 void CompositeType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
-  TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
-            StructType>(
-          [&](auto type) { type.getExtensions(extensions, storage); })
-      .Case<VectorType>([&](VectorType type) {
-        return llvm::cast<ScalarType>(type.getElementType())
-            .getExtensions(extensions, storage);
-      })
-      .Case<TensorArmType>([&](TensorArmType type) {
-        static constexpr Extension ext{Extension::SPV_ARM_tensors};
-        extensions.push_back(ext);
-        return llvm::cast<ScalarType>(type.getElementType())
-            .getExtensions(extensions, storage);
-      })
-
-      .Default([](Type) { llvm_unreachable("invalid composite type"); });
+  TypeExtensionVisitor{extensions, storage}.add(cast<SPIRVType>(*this));
 }
 
 void CompositeType::getCapabilities(
@@ -284,12 +328,16 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
   return getImpl()->use;
 }
 
+void TypeExtensionVisitor::add(CooperativeMatrixType type) {
+  add(type.getElementType());
+  static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
+  extensions.push_back(ext);
+}
+
 void CooperativeMatrixType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
-  static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
-  extensions.push_back(exts);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void CooperativeMatrixType::getCapabilities(
@@ -403,9 +451,9 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
 
 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
 
-void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
-                              std::optional<StorageClass>) {
-  // Image types do not require extra extensions thus far.
+void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                              std::optional<StorageClass> storage) {
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void ImageType::getCapabilities(
@@ -454,17 +502,23 @@ StorageClass PointerType::getStorageClass() const {
   return getImpl()->storageClass;
 }
 
-void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
-                                std::optional<StorageClass> storage) {
+void TypeExtensionVisitor::add(PointerType type) {
   // Use this pointer type's storage class because this pointer indicates we are
   // using the pointee type in that specific storage class.
-  llvm::cast<SPIRVType>(getPointeeType())
-      .getExtensions(extensions, getStorageClass());
+  std::optional<StorageClass> oldStorageClass = storage;
+  storage = type.getStorageClass();
+  add(type.getPointeeType());
+  storage = oldStorageClass;
 
-  if (auto scExts = spirv::getExtensions(getStorageClass()))
+  if (auto scExts = spirv::getExtensions(type.getStorageClass()))
     extensions.push_back(*scExts);
 }
 
+void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                                std::optional<StorageClass> storage) {
+  TypeExtensionVisitor{extensions, storage}.add(this);
+}
+
 void PointerType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     std::optional<StorageClass> storage) {
@@ -516,7 +570,7 @@ unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
 void RuntimeArrayType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void RuntimeArrayType::getCapabilities(
@@ -553,10 +607,9 @@ bool ScalarType::isValid(IntegerType type) {
   return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
 }
 
-void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
-                               std::optional<StorageClass> storage) {
-  if (isa<BFloat16Type>(*this)) {
-    static const Extension ext = Extension::SPV_KHR_bfloat16;
+void TypeExtensionVisitor::add(ScalarType type) {
+  if (isa<BFloat16Type>(type)) {
+    static constexpr auto ext = Extension::SPV_KHR_bfloat16;
     extensions.push_back(ext);
   }
 
@@ -570,18 +623,16 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
   case StorageClass::PushConstant:
   case StorageClass::StorageBuffer:
   case StorageClass::Uniform:
-    if (getIntOrFloatBitWidth() == 8) {
-      static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
-      ArrayRef<Extension> ref(exts, std::size(exts));
-      extensions.push_back(ref);
+    if (type.getIntOrFloatBitWidth() == 8) {
+      static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
+      extensions.push_back(ext);
     }
     [[fallthrough]];
   case StorageClass::Input:
   case StorageClass::Output:
-    if (getIntOrFloatBitWidth() == 16) {
-      static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
-      ArrayRef<Extension> ref(exts, std::size(exts));
-      extensions.push_back(ref);
+    if (type.getIntOrFloatBitWidth() == 16) {
+      static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
+      extensions.push_back(ext);
     }
     break;
   default:
@@ -589,6 +640,11 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
   }
 }
 
+void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                               std::optional<StorageClass> storage) {
+  TypeExtensionVisitor{extensions, storage}.add(this);
+}
+
 void ScalarType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     std::optional<StorageClass> storage) {
@@ -722,23 +778,7 @@ bool SPIRVType::isScalarOrVector() {
 
 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                               std::optional<StorageClass> storage) {
-  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
-    scalarType.getExtensions(extensions, storage);
-  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
-    compositeType.getExtensions(extensions, storage);
-  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
-    imageType.getExtensions(extensions, storage);
-  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
-    sampledImageType.getExtensions(extensions, storage);
-  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
-    matrixType.getExtensions(extensions, storage);
-  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
-    ptrType.getExtensions(extensions, storage);
-  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
-    tensorArmType.getExtensions(extensions, storage);
-  } else {
-    llvm_unreachable("invalid SPIR-V Type to getExtensions");
-  }
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void SPIRVType::getCapabilities(
@@ -821,7 +861,7 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
 void SampledImageType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
-  llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void SampledImageType::getCapabilities(
@@ -1184,8 +1224,7 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
 
 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                std::optional<StorageClass> storage) {
-  for (Type elementType : getElementTypes())
-    llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void StructType::getCapabilities(
@@ -1289,7 +1328,7 @@ unsigned MatrixType::getNumElements() const {
 
 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void MatrixType::getCapabilities(
@@ -1347,13 +1386,16 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
 Type TensorArmType::getElementType() const { return getImpl()->elementType; }
 ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
 
+void TypeExtensionVisitor::add(TensorArmType type) {
+  add(type.getElementType());
+  static constexpr auto ext = Extension::SPV_ARM_tensors;
+  extensions.push_back(ext);
+}
+
 void TensorArmType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
-
-  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
-  static constexpr Extension ext{Extension::SPV_ARM_tensors};
-  extensions.push_back(ext);
+  TypeExtensionVisitor{extensions, storage}.add(this);
 }
 
 void TensorArmType::getCapabilities(
diff --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
index 71bf2f3d918e8..d24f37b553bb5 100644
--- a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt --convert-scf-to-spirv %s --verify-diagnostics --split-input-file | FileCheck %s
 
 // `scf.parallel` conversion is not supported yet.
 // Make sure that we do not accidentally invalidate this function by removing
@@ -19,3 +19,14 @@ func.func @func(%arg0: i64) {
   }
   return
 }
+
+// -----
+
+// Make sure we don't crash on recursive structs.
+// TODO(https://github.com/llvm/llvm-project/issues/159963): Promote this to a `vce-deduction.mlir` testcase.
+
+// expected-error@below {{failed to legalize operation 'spirv.module' that was explicitly marked illegal}}
+spirv.module Physical64 GLSL450 {
+  spirv.GlobalVariable @recursive:
+    !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+}

>From d07dc4c7da2389f3df2ae4f3ff7a13733b7cdff0 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 21 Sep 2025 22:15:57 -0400
Subject: [PATCH 2/3] Make concrete add functions private

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 52 +++++++++++-------------
 1 file changed, 24 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 85250700b9bf9..0f13058dafadb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -33,10 +33,6 @@ namespace {
 // `*Type::getExtensions` functions should not handle extension-related logic
 // directly and only invoke `TypeExtensionVisitor::add(Type *)`.
 class TypeExtensionVisitor {
-  SPIRVType::ExtensionArrayRefVector &extensions;
-  std::optional<StorageClass> storage;
-  DenseSet<Type> seen;
-
 public:
   TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
                        std::optional<StorageClass> storage)
@@ -49,10 +45,17 @@ class TypeExtensionVisitor {
       return;
 
     TypeSwitch<SPIRVType>(type)
-        .Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType,
-              VectorType, ArrayType, RuntimeArrayType, StructType, MatrixType,
-              ImageType, SampledImageType>(
-            [this](auto concreteType) { add(concreteType); })
+        .Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType>(
+            [this](auto concreteType) { addConcrete(concreteType); })
+        .Case<VectorType, ArrayType, RuntimeArrayType, MatrixType, ImageType>(
+            [this](auto concreteType) { add(concreteType.getElementType()); })
+        .Case<StructType>([this](StructType concreteType) {
+          for (Type elementType : concreteType.getElementTypes())
+            add(elementType);
+        })
+        .Case<SampledImageType>([this](SampledImageType concreteType) {
+          add(concreteType.getImageType());
+        })
         .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
   }
 
@@ -60,23 +63,16 @@ class TypeExtensionVisitor {
   void add(Type type) { add(cast<SPIRVType>(type)); }
   void add(Type *type) { add(cast<SPIRVType>(*type)); }
 
+private:
   // Types that add unique extensions.
-  void add(ScalarType type);
-  void add(PointerType type);
-  void add(CooperativeMatrixType type);
-  void add(TensorArmType type);
-
-  // Trivial passthrough without any new extensions.
-  void add(VectorType type) { add(type.getElementType()); }
-  void add(ArrayType type) { add(type.getElementType()); }
-  void add(RuntimeArrayType type) { add(type.getElementType()); }
-  void add(StructType type) {
-    for (Type elementType : type.getElementTypes())
-      add(elementType);
-  }
-  void add(MatrixType type) { add(type.getElementType()); }
-  void add(ImageType type) { add(type.getElementType()); }
-  void add(SampledImageType type) { add(type.getImageType()); }
+  void addConcrete(ScalarType type);
+  void addConcrete(PointerType type);
+  void addConcrete(CooperativeMatrixType type);
+  void addConcrete(TensorArmType type);
+
+  SPIRVType::ExtensionArrayRefVector &extensions;
+  std::optional<StorageClass> storage;
+  DenseSet<Type> seen;
 };
 
 } // namespace
@@ -328,7 +324,7 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
   return getImpl()->use;
 }
 
-void TypeExtensionVisitor::add(CooperativeMatrixType type) {
+void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
   add(type.getElementType());
   static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
   extensions.push_back(ext);
@@ -502,7 +498,7 @@ StorageClass PointerType::getStorageClass() const {
   return getImpl()->storageClass;
 }
 
-void TypeExtensionVisitor::add(PointerType type) {
+void TypeExtensionVisitor::addConcrete(PointerType type) {
   // Use this pointer type's storage class because this pointer indicates we are
   // using the pointee type in that specific storage class.
   std::optional<StorageClass> oldStorageClass = storage;
@@ -607,7 +603,7 @@ bool ScalarType::isValid(IntegerType type) {
   return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
 }
 
-void TypeExtensionVisitor::add(ScalarType type) {
+void TypeExtensionVisitor::addConcrete(ScalarType type) {
   if (isa<BFloat16Type>(type)) {
     static constexpr auto ext = Extension::SPV_KHR_bfloat16;
     extensions.push_back(ext);
@@ -1386,7 +1382,7 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
 Type TensorArmType::getElementType() const { return getImpl()->elementType; }
 ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
 
-void TypeExtensionVisitor::add(TensorArmType type) {
+void TypeExtensionVisitor::addConcrete(TensorArmType type) {
   add(type.getElementType());
   static constexpr auto ext = Extension::SPV_ARM_tensors;
   extensions.push_back(ext);

>From 9054751bfc59895aba99743e9003667450878636 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 21 Sep 2025 22:21:13 -0400
Subject: [PATCH 3/3] Use SmallDenseSet

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 0f13058dafadb..dbad7d6cf41ee 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -72,7 +72,7 @@ class TypeExtensionVisitor {
 
   SPIRVType::ExtensionArrayRefVector &extensions;
   std::optional<StorageClass> storage;
-  DenseSet<Type> seen;
+  llvm::SmallDenseSet<Type> seen;
 };
 
 } // namespace



More information about the Mlir-commits mailing list