[Mlir-commits] [mlir] [mlir][SPIR-V] Refine OpTypeImage capability inference (PR #195060)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 30 04:16:34 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Arseniy Obolenskiy (aobolensk)

<details>
<summary>Changes</summary>

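Capability requirements for OpTypeImage are now determined jointly by its Dim, Sampled, MS, and Arrayed operands. A 64-bit integer Sampled Type additionally requires the Int64ImageEXT capability and the SPV_EXT_shader_image_int64 extension.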

Related to the LLVM SPIR-V backend PR https://github.com/llvm/llvm-project/pull/192626.

---
Full diff: https://github.com/llvm/llvm-project/pull/195060.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+95-4) 
- (modified) mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir (+108) 
- (modified) mlir/test/Target/SPIRV/image.mlir (+55-1) 


``````````diff
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index aafc2180761d0..ccc9ce38072d4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -46,9 +46,10 @@ class TypeExtensionVisitor {
       return;
 
     TypeSwitch<SPIRVType>(type)
-        .Case<CooperativeMatrixType, PointerType, ScalarType, TensorArmType>(
+        .Case<CooperativeMatrixType, ImageType, PointerType, ScalarType,
+              TensorArmType>(
             [this](auto concreteType) { addConcrete(concreteType); })
-        .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
+        .Case<ArrayType, MatrixType, RuntimeArrayType, VectorType>(
             [this](auto concreteType) { add(concreteType.getElementType()); })
         .Case([this](SampledImageType concreteType) {
           add(concreteType.getImageType());
@@ -66,6 +67,7 @@ class TypeExtensionVisitor {
 private:
   // Types that add unique extensions.
   void addConcrete(CooperativeMatrixType type);
+  void addConcrete(ImageType type);
   void addConcrete(PointerType type);
   void addConcrete(ScalarType type);
   void addConcrete(TensorArmType type);
@@ -423,13 +425,102 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
 
 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
 
+void TypeExtensionVisitor::addConcrete(ImageType type) {
+  // Images whose Sampled Type is a 64-bit integer additionally need the
+  // SPV_EXT_shader_image_int64 extension (it backs Int64ImageEXT).
+  auto sampledTy = type.getElementType();
+  auto intTy = dyn_cast<IntegerType>(sampledTy);
+  static constexpr auto int64ImageExt = Extension::SPV_EXT_shader_image_int64;
+  if (intTy && intTy.getWidth() == 64)
+    extensions.push_back(int64ImageExt);
+  add(sampledTy);
+}
+
 void TypeCapabilityVisitor::addConcrete(ImageType type) {
-  if (auto dimCaps = spirv::getCapabilities(type.getDim()))
-    capabilities.push_back(*dimCaps);
+  // Capability requirements for OpTypeImage are determined jointly by Dim,
+  // Sampled, MS, and Arrayed - see the SPIR-V spec's "Capabilities" column on
+  // OpTypeImage. Sampled == 2 (NoSampler) is a storage image, Sampled == 1
+  // (NeedSampler) a sampled image; Sampled == 0 means unknown.
+  Dim dim = type.getDim();
+  bool isMultisampled =
+      type.getSamplingInfo() == ImageSamplingInfo::MultiSampled;
+  bool isArrayed = type.getArrayedInfo() == ImageArrayedInfo::Arrayed;
+  ImageSamplerUseInfo sampler = type.getSamplerUseInfo();
+  bool noSampler = sampler == ImageSamplerUseInfo::NoSampler;
+  bool needSampler = sampler == ImageSamplerUseInfo::NeedSampler;
+
+  // Dims whose storage and sampled forms are gated by different capabilities,
+  // as {storage capability, sampled capability} pairs.
+  static constexpr Capability dim1DCaps[] = {Capability::Image1D,
+                                             Capability::Sampled1D};
+  static constexpr Capability rectCaps[] = {Capability::ImageRect,
+                                            Capability::SampledRect};
+  static constexpr Capability bufferCaps[] = {Capability::ImageBuffer,
+                                              Capability::SampledBuffer};
+  // Pushes the storage capability for storage images, the sampled one for
+  // sampled images, and both as alternatives when the usage is unknown. The
+  // pushed ArrayRefs alias the static tables above, so they outlive this call.
+  auto addSamplerDependent = [&](ArrayRef<Capability> storageOrSampled) {
+    if (noSampler)
+      capabilities.push_back(storageOrSampled.take_front());
+    else if (needSampler)
+      capabilities.push_back(storageOrSampled.drop_front());
+    else
+      capabilities.push_back(storageOrSampled);
+  };
+
+  switch (dim) {
+  case Dim::Dim1D:
+    addSamplerDependent(dim1DCaps);
+    break;
+  case Dim::Dim2D:
+    // Only multisampled storage images need extra capabilities; ImageMSArray
+    // is specified for MS images with Sampled == 2 and Arrayed == 1.
+    if (isMultisampled && noSampler) {
+      static constexpr auto cap = Capability::StorageImageMultisample;
+      capabilities.push_back(cap);
+      if (isArrayed) {
+        static constexpr auto arrayCap = Capability::ImageMSArray;
+        capabilities.push_back(arrayCap);
+      }
+    }
+    break;
+  case Dim::Dim3D:
+    break;
+  case Dim::Cube: {
+    static constexpr auto shaderCap = Capability::Shader;
+    capabilities.push_back(shaderCap);
+    // NOTE(review): sampled cube arrays may only need SampledCubeArray per the
+    // spec; ImageCubeArray is requested regardless of Sampled - confirm.
+    if (isArrayed) {
+      static constexpr auto cap = Capability::ImageCubeArray;
+      capabilities.push_back(cap);
+    }
+    break;
+  }
+  case Dim::Rect:
+    addSamplerDependent(rectCaps);
+    break;
+  case Dim::Buffer:
+    addSamplerDependent(bufferCaps);
+    break;
+  case Dim::SubpassData: {
+    static constexpr auto cap = Capability::InputAttachment;
+    capabilities.push_back(cap);
+    break;
+  }
+  }

   if (auto fmtCaps = spirv::getCapabilities(type.getImageFormat()))
     capabilities.push_back(*fmtCaps);

+  // OpTypeImage with a 64-bit integer Sampled Type requires Int64ImageEXT.
+  if (auto intTy = dyn_cast<IntegerType>(type.getElementType());
+      intTy && intTy.getWidth() == 64) {
+    static constexpr auto cap = Capability::Int64ImageEXT;
+    capabilities.push_back(cap);
+  }
+
   add(type.getElementType());
 }
 
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index d382ca5691fab..fd500d70f685e 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -303,3 +303,111 @@ spirv.module Physical64 GLSL450 attributes {
   spirv.GlobalVariable @recursive:
     !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
 }
+
+//===----------------------------------------------------------------------===//
+// Image type capabilities
+//===----------------------------------------------------------------------===//
+
+// 2D + MultiSampled + NoSampler (storage image) requires StorageImageMultisample.
+// CHECK: requires #spirv.vce<v1.0, [Shader, StorageImageMultisample, Matrix], []>
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.5, [Shader, StorageImageMultisample], []>,
+    #spirv.resource_limits<>>
+} {
+  spirv.GlobalVariable @img_2d_ms_storage bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
+}
+
+// 2D + MultiSampled + Arrayed + NoSampler requires both StorageImageMultisample
+// and ImageMSArray.
+// CHECK: requires #spirv.vce<v1.0, [Shader, StorageImageMultisample, ImageMSArray, Matrix], []>
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.5, [Shader, StorageImageMultisample, ImageMSArray], []>,
+    #spirv.resource_limits<>>
+} {
+  spirv.GlobalVariable @img_2d_ms_arrayed_storage bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Dim2D, NoDepth, Arrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
+}
+
+// A SingleSampled 2D storage image requests no multisample-related caps.
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.5, [Shader], []>,
+    #spirv.resource_limits<>>
+} {
+  spirv.GlobalVariable @img_2d_storage bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, UniformConstant>
+}
+
+// Cube without Arrayed requires Shader but not ImageCubeArray.
+// CHECK: requires #spirv.vce<v1.0, [Shader, Matrix], []>
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.5, [Shader], []>,
+    #spirv.resource_limits<>>
+} {
+  spirv.GlobalVariable @img_cube bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Cube, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, UniformConstant>
+}
+
+// Cube + Arrayed + NeedSampler requests ImageCubeArray, which implicitly
+// declares SampledCubeArray.
+// CHECK: requires #spirv.vce<v1.0, [Shader, ImageCubeArray, Matrix, SampledCubeArray], []>
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.5, [Shader, ImageCubeArray], []>,
+    #spirv.resource_limits<>>
+} {
+  spirv.GlobalVariable @img_cube_arrayed bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Cube, NoDepth, Arrayed, SingleSampled, NeedSampler, Unknown>, UniformConstant>
+}
+
+// 1D + NoSampler requires Image1D; Sampled1D is implied by Image1D, so the
+// target_env need not list it.
+// CHECK: requires #spirv.vce<v1.0, [Shader, Image1D, Matrix, Sampled1D], []>
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.5, [Shader, Image1D], []>,
+    #spirv.resource_limits<>>
+} {
+  spirv.GlobalVariable @img_1d_storage bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, UniformConstant>
+}
+
+// 1D + NeedSampler requires Sampled1D.
+// CHECK: requires #spirv.vce<v1.0, [Shader, Sampled1D, Matrix], []>
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.5, [Shader, Sampled1D], []>,
+    #spirv.resource_limits<>>
+} {
+  spirv.GlobalVariable @img_1d_sampled bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>, UniformConstant>
+}
+
+// Buffer + NoSampler requires ImageBuffer, which implicitly declares
+// SampledBuffer.
+// CHECK: requires #spirv.vce<v1.0, [Shader, ImageBuffer, Matrix, SampledBuffer], []>
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.5, [Shader, ImageBuffer], []>,
+    #spirv.resource_limits<>>
+} {
+  spirv.GlobalVariable @img_buffer_storage bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Buffer, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, UniformConstant>
+}
+
+// 64-bit integer sampled type requires Int64ImageEXT and the
+// SPV_EXT_shader_image_int64 extension.
+// CHECK: requires #spirv.vce<v1.0, [Shader, Int64ImageEXT, Int64, Matrix], [SPV_EXT_shader_image_int64]>
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.5, [Shader, Int64, Int64ImageEXT], [SPV_EXT_shader_image_int64]>,
+    #spirv.resource_limits<>>
+} {
+  spirv.GlobalVariable @img_2d_i64 bind(0, 0) :
+    !spirv.ptr<!spirv.image<i64, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, UniformConstant>
+}
diff --git a/mlir/test/Target/SPIRV/image.mlir b/mlir/test/Target/SPIRV/image.mlir
index a0c245c6d55dc..ac3dfc3856426 100644
--- a/mlir/test/Target/SPIRV/image.mlir
+++ b/mlir/test/Target/SPIRV/image.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip %s | FileCheck %s
 
 // RUN: %if spirv-tools %{ rm -rf %t %}
 // RUN: %if spirv-tools %{ mkdir %t %}
@@ -15,3 +15,57 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, Sampled
   // CHECK: !spirv.ptr<!spirv.image<i32, SubpassData, DepthUnknown, Arrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
   spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.image<i32, SubpassData, DepthUnknown, Arrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
 }
+
+// -----
+
+// 2D + MultiSampled + NoSampler storage image - validates StorageImageMultisample.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, StorageImageMultisample], []> {
+  // CHECK: !spirv.ptr<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
+  spirv.GlobalVariable @img_2d_ms_storage bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
+}
+
+// -----
+
+// 2D + MultiSampled + Arrayed + NoSampler - validates StorageImageMultisample + ImageMSArray.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, StorageImageMultisample, ImageMSArray], []> {
+  // CHECK: !spirv.ptr<!spirv.image<f32, Dim2D, NoDepth, Arrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
+  spirv.GlobalVariable @img_2d_ms_arrayed bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Dim2D, NoDepth, Arrayed, MultiSampled, NoSampler, Unknown>, UniformConstant>
+}
+
+// -----
+
+// Cube + Arrayed sampled image - validates ImageCubeArray.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, ImageCubeArray, SampledCubeArray], []> {
+  // CHECK: !spirv.ptr<!spirv.image<f32, Cube, NoDepth, Arrayed, SingleSampled, NeedSampler, Unknown>, UniformConstant>
+  spirv.GlobalVariable @img_cube_arrayed bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Cube, NoDepth, Arrayed, SingleSampled, NeedSampler, Unknown>, UniformConstant>
+}
+
+// -----
+
+// 1D storage image - validates Image1D.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, Image1D, Sampled1D], []> {
+  // CHECK: !spirv.ptr<!spirv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, UniformConstant>
+  spirv.GlobalVariable @img_1d_storage bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Dim1D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, UniformConstant>
+}
+
+// -----
+
+// Buffer storage image - validates ImageBuffer.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, ImageBuffer, SampledBuffer], []> {
+  // CHECK: !spirv.ptr<!spirv.image<f32, Buffer, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, UniformConstant>
+  spirv.GlobalVariable @img_buffer_storage bind(0, 0) :
+    !spirv.ptr<!spirv.image<f32, Buffer, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, UniformConstant>
+}
+
+// -----
+
+// 64-bit integer sampled type - validates Int64ImageEXT + SPV_EXT_shader_image_int64.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, Int64, Int64ImageEXT], [SPV_EXT_shader_image_int64]> {
+  // CHECK: !spirv.ptr<!spirv.image<i64, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, UniformConstant>
+  spirv.GlobalVariable @img_2d_i64 bind(0, 0) :
+    !spirv.ptr<!spirv.image<i64, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>, UniformConstant>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/195060


More information about the Mlir-commits mailing list