[Mlir-commits] [mlir] ca7c058 - [mlir][spirv] Rework type capability queries (#160113)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 22 08:28:45 PDT 2025


Author: Jakub Kuderski
Date: 2025-09-22T15:28:41Z
New Revision: ca7c058701bbbdd1b9bbdb083cbcb21f2bb47735

URL: https://github.com/llvm/llvm-project/commit/ca7c058701bbbdd1b9bbdb083cbcb21f2bb47735
DIFF: https://github.com/llvm/llvm-project/commit/ca7c058701bbbdd1b9bbdb083cbcb21f2bb47735.diff

LOG: [mlir][spirv] Rework type capability queries (#160113)

* Fix infinite recursion with nested structs.
* Drop the `::getCapabilities` functions from derived types, so that there's
only one entry point that queries type capabilities.
* Move all capability logic to a new helper class -- this way the
`::getCapabilities` functions can't diverge across concrete types and
'convenience types' like CompositeType.

Fixes: #159963

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
    mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
    mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 6beffc17d6d58..475e3f495e065 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -89,9 +89,6 @@ class ScalarType : public SPIRVType {
   /// Returns true if the given float type is valid for the SPIR-V dialect.
   static bool isValid(IntegerType);
 
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
-
   std::optional<int64_t> getSizeInBytes();
 };
 
@@ -116,9 +113,6 @@ class CompositeType : public SPIRVType {
   /// implementation dependent.
   bool hasCompileTimeKnownNumElements() const;
 
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
-
   std::optional<int64_t> getSizeInBytes();
 };
 
@@ -144,9 +138,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
   /// type.
   unsigned getArrayStride() const;
 
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
-
   /// Returns the array size in bytes. Since array type may have an explicit
   /// stride declaration (in bytes), we also include it in the calculation.
   std::optional<int64_t> getSizeInBytes();
@@ -186,9 +177,6 @@ class ImageType
   ImageSamplerUseInfo getSamplerUseInfo() const;
   ImageFormat getImageFormat() const;
   // TODO: Add support for Access qualifier
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 // SPIR-V pointer type
@@ -204,9 +192,6 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
   Type getPointeeType() const;
 
   StorageClass getStorageClass() const;
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 // SPIR-V run-time array type
@@ -228,9 +213,6 @@ class RuntimeArrayType
   /// Returns the array stride in bytes. 0 means no stride decorated on this
   /// type.
   unsigned getArrayStride() const;
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 // SPIR-V sampled image type
@@ -252,10 +234,6 @@ class SampledImageType
                    Type imageType);
 
   Type getImageType() const;
-
-  void
-  getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                  std::optional<spirv::StorageClass> storage = std::nullopt);
 };
 
 /// SPIR-V struct type. Two kinds of struct types are supported:
@@ -405,9 +383,6 @@ class StructType
   trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
              ArrayRef<MemberDecorationInfo> memberDecorations = {},
              ArrayRef<StructDecorationInfo> structDecorations = {});
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 llvm::hash_code
@@ -440,9 +415,6 @@ class CooperativeMatrixType
   /// Returns the use parameter of the cooperative matrix.
   CooperativeMatrixUseKHR getUse() const;
 
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
-
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
 
   ArrayRef<int64_t> getShape() const;
@@ -493,9 +465,6 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
 
   /// Returns the elements' type (i.e, single element type).
   Type getElementType() const;
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 /// SPIR-V TensorARM Type
@@ -531,9 +500,6 @@ class TensorArmType
   ArrayRef<int64_t> getShape() const;
   bool hasRank() const { return !getShape().empty(); }
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
-
-  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
-                       std::optional<StorageClass> storage = std::nullopt);
 };
 
 } // namespace spirv

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 8244e64abba12..7c2f43bea9ddb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -45,17 +45,67 @@ class TypeExtensionVisitor {
       return;
 
     TypeSwitch<SPIRVType>(type)
-        .Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType>(
+        .Case<CooperativeMatrixType, PointerType, ScalarType, TensorArmType>(
             [this](auto concreteType) { addConcrete(concreteType); })
-        .Case<VectorType, ArrayType, RuntimeArrayType, MatrixType, ImageType>(
+        .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
             [this](auto concreteType) { add(concreteType.getElementType()); })
+        .Case<SampledImageType>([this](SampledImageType concreteType) {
+          add(concreteType.getImageType());
+        })
         .Case<StructType>([this](StructType concreteType) {
           for (Type elementType : concreteType.getElementTypes())
             add(elementType);
         })
+        .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+  }
+
+  void add(Type type) { add(cast<SPIRVType>(type)); }
+
+private:
+  // Types that add unique extensions.
+  void addConcrete(CooperativeMatrixType type);
+  void addConcrete(PointerType type);
+  void addConcrete(ScalarType type);
+  void addConcrete(TensorArmType type);
+
+  SPIRVType::ExtensionArrayRefVector &extensions;
+  std::optional<StorageClass> storage;
+  llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
+};
+
+// Helper class to collect capabilities implied by a type by visiting all its
+// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
+//
+// Serves as the source-of-truth for type capability information. All capability
+// logic should be added to this class, while the
+// `SPIRVType::getCapabilities` function should not handle capability-related
+// logic directly and only invoke `TypeCapabilityVisitor::add(SPIRVType)`.
+class TypeCapabilityVisitor {
+public:
+  TypeCapabilityVisitor(SPIRVType::CapabilityArrayRefVector &capabilities,
+                        std::optional<StorageClass> storage)
+      : capabilities(capabilities), storage(storage) {}
+
+  // Main visitor entry point. Adds all capabilities to the vector. Saves
+  // `type` as seen and dispatches on its concrete type.
+  void add(SPIRVType type) {
+    if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
+      return;
+
+    TypeSwitch<SPIRVType>(type)
+        .Case<CooperativeMatrixType, ImageType, MatrixType, PointerType,
+              RuntimeArrayType, ScalarType, TensorArmType, VectorType>(
+            [this](auto concreteType) { addConcrete(concreteType); })
+        .Case<ArrayType>([this](ArrayType concreteType) {
+          add(concreteType.getElementType());
+        })
         .Case<SampledImageType>([this](SampledImageType concreteType) {
           add(concreteType.getImageType());
         })
+        .Case<StructType>([this](StructType concreteType) {
+          for (Type elementType : concreteType.getElementTypes())
+            add(elementType);
+        })
         .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
   }
 
@@ -63,12 +113,16 @@ class TypeExtensionVisitor {
 
 private:
-  // Types that add unique extensions.
+  // Types that add unique capabilities.
-  void addConcrete(ScalarType type);
-  void addConcrete(PointerType type);
   void addConcrete(CooperativeMatrixType type);
+  void addConcrete(ImageType type);
+  void addConcrete(MatrixType type);
+  void addConcrete(PointerType type);
+  void addConcrete(RuntimeArrayType type);
+  void addConcrete(ScalarType type);
   void addConcrete(TensorArmType type);
+  void addConcrete(VectorType type);
 
-  SPIRVType::ExtensionArrayRefVector &extensions;
+  SPIRVType::CapabilityArrayRefVector &capabilities;
   std::optional<StorageClass> storage;
   llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
 };
@@ -118,13 +172,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
 
 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
 
-void ArrayType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType())
-      .getCapabilities(capabilities, storage);
-}
-
 std::optional<int64_t> ArrayType::getSizeInBytes() {
   auto elementType = llvm::cast<SPIRVType>(getElementType());
   std::optional<int64_t> size = elementType.getSizeInBytes();
@@ -188,30 +235,14 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
   return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
 }
 
-void CompositeType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  TypeSwitch<Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
-            StructType>(
-          [&](auto type) { type.getCapabilities(capabilities, storage); })
-      .Case<VectorType>([&](VectorType type) {
-        auto vecSize = getNumElements();
-        if (vecSize == 8 || vecSize == 16) {
-          static const Capability caps[] = {Capability::Vector16};
-          ArrayRef<Capability> ref(caps, std::size(caps));
-          capabilities.push_back(ref);
-        }
-        return llvm::cast<ScalarType>(type.getElementType())
-            .getCapabilities(capabilities, storage);
-      })
-      .Case<TensorArmType>([&](TensorArmType type) {
-        static constexpr Capability cap{Capability::TensorsARM};
-        capabilities.push_back(cap);
-        return llvm::cast<ScalarType>(type.getElementType())
-            .getCapabilities(capabilities, storage);
-      })
-      .Default([](Type) { llvm_unreachable("invalid composite type"); });
+void TypeCapabilityVisitor::addConcrete(VectorType type) {
+  add(type.getElementType());
+
+  int64_t vecSize = type.getNumElements();
+  if (vecSize == 8 || vecSize == 16) {
+    static constexpr auto cap = Capability::Vector16;
+    capabilities.push_back(cap);
+  }
 }
 
 std::optional<int64_t> CompositeType::getSizeInBytes() {
@@ -317,12 +348,9 @@ void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
   extensions.push_back(ext);
 }
 
-void CooperativeMatrixType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType())
-      .getCapabilities(capabilities, storage);
-  static constexpr Capability caps[] = {Capability::CooperativeMatrixKHR};
+void TypeCapabilityVisitor::addConcrete(CooperativeMatrixType type) {
+  add(type.getElementType());
+  static constexpr auto caps = Capability::CooperativeMatrixKHR;
   capabilities.push_back(caps);
 }
 
@@ -428,14 +456,14 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
 
 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
 
-void ImageType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass>) {
-  if (auto dimCaps = spirv::getCapabilities(getDim()))
+void TypeCapabilityVisitor::addConcrete(ImageType type) {
+  if (auto dimCaps = spirv::getCapabilities(type.getDim()))
     capabilities.push_back(*dimCaps);
 
-  if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
+  if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat()))
     capabilities.push_back(*fmtCaps);
+
+  add(type.getElementType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -486,15 +514,15 @@ void TypeExtensionVisitor::addConcrete(PointerType type) {
     extensions.push_back(*scExts);
 }
 
-void PointerType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
+void TypeCapabilityVisitor::addConcrete(PointerType type) {
   // Use this pointer type's storage class because this pointer indicates we are
   // using the pointee type in that specific storage class.
-  llvm::cast<SPIRVType>(getPointeeType())
-      .getCapabilities(capabilities, getStorageClass());
+  std::optional<StorageClass> oldStorageClass = storage;
+  storage = type.getStorageClass();
+  add(type.getPointeeType());
+  storage = oldStorageClass;
 
-  if (auto scCaps = spirv::getCapabilities(getStorageClass()))
+  if (auto scCaps = spirv::getCapabilities(type.getStorageClass()))
     capabilities.push_back(*scCaps);
 }
 
@@ -534,16 +562,10 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
 
 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
 
-void RuntimeArrayType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  {
-    static const Capability caps[] = {Capability::Shader};
-    ArrayRef<Capability> ref(caps, std::size(caps));
-    capabilities.push_back(ref);
-  }
-  llvm::cast<SPIRVType>(getElementType())
-      .getCapabilities(capabilities, storage);
+void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
+  add(type.getElementType());
+  static constexpr auto cap = Capability::Shader;
+  capabilities.push_back(cap);
 }
 
 //===----------------------------------------------------------------------===//
@@ -601,10 +623,8 @@ void TypeExtensionVisitor::addConcrete(ScalarType type) {
   }
 }
 
-void ScalarType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  unsigned bitwidth = getIntOrFloatBitWidth();
+void TypeCapabilityVisitor::addConcrete(ScalarType type) {
+  unsigned bitwidth = type.getIntOrFloatBitWidth();
 
   // 8- or 16-bit integer/floating-point numbers will require extra capabilities
   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
@@ -613,15 +633,13 @@ void ScalarType::getCapabilities(
 #define STORAGE_CASE(storage, cap8, cap16)                                     \
   case StorageClass::storage: {                                                \
     if (bitwidth == 8) {                                                       \
-      static const Capability caps[] = {Capability::cap8};                     \
-      ArrayRef<Capability> ref(caps, std::size(caps));                         \
-      capabilities.push_back(ref);                                             \
+      static constexpr auto cap = Capability::cap8;                            \
+      capabilities.push_back(cap);                                             \
       return;                                                                  \
     }                                                                          \
     if (bitwidth == 16) {                                                      \
-      static const Capability caps[] = {Capability::cap16};                    \
-      ArrayRef<Capability> ref(caps, std::size(caps));                         \
-      capabilities.push_back(ref);                                             \
+      static constexpr auto cap = Capability::cap16;                           \
+      capabilities.push_back(cap);                                             \
       return;                                                                  \
     }                                                                          \
     /* For 64-bit integers/floats, Int64/Float64 enables support for all */    \
@@ -640,9 +658,8 @@ void ScalarType::getCapabilities(
     case StorageClass::Input:
     case StorageClass::Output: {
       if (bitwidth == 16) {
-        static const Capability caps[] = {Capability::StorageInputOutput16};
-        ArrayRef<Capability> ref(caps, std::size(caps));
-        capabilities.push_back(ref);
+        static constexpr auto cap = Capability::StorageInputOutput16;
+        capabilities.push_back(cap);
         return;
       }
       break;
@@ -658,12 +675,11 @@ void ScalarType::getCapabilities(
 
 #define WIDTH_CASE(type, width)                                                \
   case width: {                                                                \
-    static const Capability caps[] = {Capability::type##width};                \
-    ArrayRef<Capability> ref(caps, std::size(caps));                           \
-    capabilities.push_back(ref);                                               \
+    static constexpr auto cap = Capability::type##width;                       \
+    capabilities.push_back(cap);                                               \
   } break
 
-  if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
+  if (auto intType = dyn_cast<IntegerType>(type)) {
     switch (bitwidth) {
       WIDTH_CASE(Int, 8);
       WIDTH_CASE(Int, 16);
@@ -675,14 +691,14 @@ void ScalarType::getCapabilities(
       llvm_unreachable("invalid bitwidth to getCapabilities");
     }
   } else {
-    assert(llvm::isa<FloatType>(*this));
+    assert(isa<FloatType>(type));
     switch (bitwidth) {
     case 16: {
-      if (isa<BFloat16Type>(*this)) {
-        static const Capability cap = Capability::BFloat16TypeKHR;
+      if (isa<BFloat16Type>(type)) {
+        static constexpr auto cap = Capability::BFloat16TypeKHR;
         capabilities.push_back(cap);
       } else {
-        static const Capability cap = Capability::Float16;
+        static constexpr auto cap = Capability::Float16;
         capabilities.push_back(cap);
       }
       break;
@@ -740,23 +756,7 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
 void SPIRVType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     std::optional<StorageClass> storage) {
-  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
-    scalarType.getCapabilities(capabilities, storage);
-  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
-    compositeType.getCapabilities(capabilities, storage);
-  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
-    imageType.getCapabilities(capabilities, storage);
-  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
-    sampledImageType.getCapabilities(capabilities, storage);
-  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
-    matrixType.getCapabilities(capabilities, storage);
-  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
-    ptrType.getCapabilities(capabilities, storage);
-  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
-    tensorArmType.getCapabilities(capabilities, storage);
-  } else {
-    llvm_unreachable("invalid SPIR-V Type to getCapabilities");
-  }
+  TypeCapabilityVisitor{capabilities, storage}.add(*this);
 }
 
 std::optional<int64_t> SPIRVType::getSizeInBytes() {
@@ -814,12 +814,6 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-void SampledImageType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
-}
-
 //===----------------------------------------------------------------------===//
 // StructType
 //===----------------------------------------------------------------------===//
@@ -1172,13 +1166,6 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
                       structDecorations);
 }
 
-void StructType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  for (Type elementType : getElementTypes())
-    llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
-}
-
 llvm::hash_code spirv::hash_value(
     const StructType::MemberDecorationInfo &memberDecorationInfo) {
   return llvm::hash_combine(memberDecorationInfo.memberIndex,
@@ -1271,16 +1258,10 @@ unsigned MatrixType::getNumElements() const {
   return (getImpl()->columnCount) * getNumRows();
 }
 
-void MatrixType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  {
-    static const Capability caps[] = {Capability::Matrix};
-    ArrayRef<Capability> ref(caps, std::size(caps));
-    capabilities.push_back(ref);
-  }
-  // Add any capabilities associated with the underlying vectors (i.e., columns)
-  llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
+void TypeCapabilityVisitor::addConcrete(MatrixType type) {
+  add(type.getColumnType());
+  static constexpr auto cap = Capability::Matrix;
+  capabilities.push_back(cap);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1332,12 +1313,9 @@ void TypeExtensionVisitor::addConcrete(TensorArmType type) {
   extensions.push_back(ext);
 }
 
-void TensorArmType::getCapabilities(
-    SPIRVType::CapabilityArrayRefVector &capabilities,
-    std::optional<StorageClass> storage) {
-  llvm::cast<SPIRVType>(getElementType())
-      .getCapabilities(capabilities, storage);
-  static constexpr Capability cap{Capability::TensorsARM};
+void TypeCapabilityVisitor::addConcrete(TensorArmType type) {
+  add(type.getElementType());
+  static constexpr auto cap = Capability::TensorsARM;
   capabilities.push_back(cap);
 }
 

diff  --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
index d24f37b553bb5..1a1c24a09aa8c 100644
--- a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --convert-scf-to-spirv %s --verify-diagnostics --split-input-file | FileCheck %s
+// RUN: mlir-opt --convert-scf-to-spirv %s | FileCheck %s
 
 // `scf.parallel` conversion is not supported yet.
 // Make sure that we do not accidentally invalidate this function by removing
@@ -19,14 +19,3 @@ func.func @func(%arg0: i64) {
   }
   return
 }
-
-// -----
-
-// Make sure we don't crash on recursive structs.
-// TODO(https://github.com/llvm/llvm-project/issues/159963): Promote this to a `vce-deduction.mlir` testcase.
-
-// expected-error at below {{failed to legalize operation 'spirv.module' that was explicitly marked illegal}}
-spirv.module Physical64 GLSL450 {
-  spirv.GlobalVariable @recursive:
-    !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
-}

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 2d20ae0a13105..7dab87f8081ed 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -232,7 +232,7 @@ spirv.module Logical GLSL450 attributes {
   }
 }
 
-// CHECK: requires #spirv.vce<v1.5, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
+// CHECK: requires #spirv.vce<v1.5, [GraphARM, Int8, TensorsARM, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
 spirv.module Logical Vulkan attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.5, [VulkanMemoryModel, GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph]>,
@@ -242,3 +242,14 @@ spirv.module Logical Vulkan attributes {
       spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi8>
   }
 }
+
+// Check that extension and capability queries handle recursive types.
+// CHECK: requires #spirv.vce<v1.0, [Shader, Addresses, Matrix], [SPV_KHR_storage_buffer_storage_class]>
+spirv.module Physical64 GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.6, [Shader, Addresses], [SPV_KHR_storage_buffer_storage_class]>,
+    #spirv.resource_limits<>>
+} {
+  spirv.GlobalVariable @recursive:
+    !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+}


        


More information about the Mlir-commits mailing list