[Mlir-commits] [mlir] [mlir][memref][spirv] Add SPIR-V Image Lowering (PR #150978)

Jack Frankland llvmlistbot at llvm.org
Mon Jul 28 23:51:14 PDT 2025


https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/150978

>From cab184f34a0518744493eeacf2f9ee1e1bb114de Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Mon, 28 Jul 2025 15:47:12 +0100
Subject: [PATCH 1/2] [mlir][memref][spirv] Add SPIR-V Image Lowering

Adds an initial conversion in the Memref -> SPIR-V lowering for images.
Any memref in the "Image" storage-class/address-space will be considered
for lowering to the `!spirv.image` type during Memref to SPIR-V
conversion. Initially only the reading of sampled images is supported, and
images are read via the `OpImageFetch` instruction. Future work should
expand the conversion patterns to target non-sampled images and add
support for image write operations.

Images are supported for fp32, fp16, int32, uint32, int16 and uint16
types. Lit tests have been added to verify this, along with negative
tests that check the cases where images aren't supported.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 .../MemRefToSPIRV/MemRefToSPIRV.cpp           | 106 +++++++++-
 .../SPIRV/Transforms/SPIRVConversion.cpp      |  77 +++++++
 .../MemRefToSPIRV/memref-to-spirv.mlir        | 188 ++++++++++++++++++
 3 files changed, 365 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index b866afbce98b0..8547b9e133dcd 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -243,6 +243,16 @@ class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts memref.load to spirv.Image + spirv.ImageFetch
+class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
+public:
+  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts memref.store to spirv.Store on integers.
 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
 public:
@@ -527,6 +537,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (!memrefType.getElementType().isSignlessInteger())
     return failure();
 
+  if (memrefType.getMemorySpace() ==
+      spirv::StorageClassAttr::get(rewriter.getContext(),
+                                   spirv::StorageClass::Image))
+    return failure();
+
   const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
   Value accessChain =
       spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
@@ -643,6 +658,12 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
   if (memrefType.getElementType().isSignlessInteger())
     return failure();
+
+  if (memrefType.getMemorySpace() ==
+      spirv::StorageClassAttr::get(rewriter.getContext(),
+                                   spirv::StorageClass::Image))
+    return failure();
+
   Value loadPtr = spirv::getElementPtr(
       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
       adaptor.getIndices(), loadOp.getLoc(), rewriter);
@@ -661,6 +682,79 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   return success();
 }
 
+LogicalResult
+ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  const auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
+  if (memrefType.getMemorySpace() !=
+      spirv::StorageClassAttr::get(rewriter.getContext(),
+                                   spirv::StorageClass::Image))
+    return failure();
+
+  const auto loadPtr = adaptor.getMemref();
+  const auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
+  if (failed(memoryRequirements))
+    return rewriter.notifyMatchFailure(
+        loadOp, "failed to determine memory requirements");
+
+  const auto [memoryAccess, alignment] = *memoryRequirements;
+
+  if (!loadOp.getMemRefType().hasRank())
+    return rewriter.notifyMatchFailure(
+        loadOp, "cannot lower unranked memrefs to SPIR-V images");
+
+  // We currently only support lowering of scalar memref elements to texels in
+  // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
+  // elements to texels in richer formats.
+  if (!loadOp.getMemRefType().getElementType().isIntOrFloat())
+    return rewriter.notifyMatchFailure(
+        loadOp, "cannot lower memrefs who's element type is not int or float "
+                "to SPIR-V images");
+
+  // We currently only support sampled images since OpImageFetch does not work
+  // for plain images and the OpImageRead instruction needs to be materialized
+  // instead or texels need to be accessed via atomics through a texel pointer.
+  // Future work will generalize support to plain images.
+  if (const auto convertedPointeeType = cast<spirv::PointerType>(
+          getTypeConverter()->convertType(loadOp.getMemRefType()));
+      !isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
+    return rewriter.notifyMatchFailure(loadOp,
+                                       "cannot lower memrefs which do not "
+                                       "convert to SPIR-V sampled images");
+
+  // Materialize the lowering: load the sampled image from its pointer.
+  auto imageLoadOp = rewriter.create<spirv::LoadOp>(loadOp->getLoc(), loadPtr,
+                                                    memoryAccess, alignment);
+  // Extract the image from the sampled image.
+  auto imageOp = rewriter.create<spirv::ImageOp>(loadOp->getLoc(), imageLoadOp);
+
+  // Build a vector of coordinates or just a scalar index if we have a 1D image.
+  Value coords;
+  if (memrefType.getRank() != 1) {
+    const auto coordVectorType = VectorType::get(
+        {loadOp.getMemRefType().getRank()}, adaptor.getIndices().getType()[0]);
+    coords = rewriter.create<spirv::CompositeConstructOp>(
+        loadOp->getLoc(), coordVectorType, adaptor.getIndices());
+  } else {
+    coords = adaptor.getIndices()[0];
+  }
+
+  // Fetch the value out of the image.
+  const auto resultVectorType = VectorType::get({4}, loadOp.getType());
+  auto fetchOp = rewriter.create<spirv::ImageFetchOp>(
+      loadOp->getLoc(), resultVectorType, imageOp, coords,
+      mlir::spirv::ImageOperandsAttr{}, ValueRange{});
+
+  // Note that because OpImageFetch returns a 4-element vector we need to
+  // extract the element corresponding to the load. Since we only support the
+  // R[16|32][f|i|ui] formats this will always be the R (red) 0th element.
+  const auto compositeExtractOp =
+      rewriter.create<spirv::CompositeExtractOp>(loadOp->getLoc(), fetchOp, 0);
+
+  rewriter.replaceOp(loadOp, compositeExtractOp);
+  return success();
+}
+
 LogicalResult
 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
@@ -952,11 +1046,11 @@ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
 namespace mlir {
 void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns
-      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
-           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
-           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
-           CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
-          typeConverter, patterns.getContext());
+  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+               DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
+               IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
+               StoreOpPattern, ReinterpretCastPattern, CastPattern,
+               ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
+                                                      patterns.getContext());
 }
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index f70b3325f8725..139db07d33923 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -575,6 +575,83 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   }
   spirv::StorageClass storageClass = attr.getValue();
 
+  // Images are a special case since they are an opaque type from which elements
+  // may be accessed via image specific ops or directly through a texture
+  // pointer.
+  if (storageClass == spirv::StorageClass::Image) {
+    const auto rank = type.getRank();
+    if (rank < 1 || rank > 3) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << type << " illegal: cannot lower memref of rank " << rank
+                 << " to a SPIR-V Image\n");
+      return nullptr;
+    }
+
+    const auto dim = [rank]() {
+#define DIM_CASE(DIM)                                                          \
+  case DIM:                                                                    \
+    return spirv::Dim::Dim##DIM##D
+      switch (rank) {
+        DIM_CASE(1);
+        DIM_CASE(2);
+        DIM_CASE(3);
+      default:
+        llvm_unreachable("Invalid memref rank!");
+      }
+#undef DIM_CASE
+    }();
+
+    // Note that we currently only support lowering to single element texels
+    // e.g. R32f
+    const auto elementType = type.getElementType();
+    if (!elementType.isIntOrFloat()) {
+      LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
+                              << elementType << " to a  SPIR-V Image\n");
+      return nullptr;
+    }
+
+    const auto imageFormat = [&elementType]() {
+      return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType)
+          .Case<Float16Type>(
+              [](Float16Type) { return spirv::ImageFormat::R16f; })
+          .Case<Float32Type>(
+              [](Float32Type) { return spirv::ImageFormat::R32f; })
+          .Case<IntegerType>([](IntegerType intType) {
+            auto const isSigned = intType.isSigned() || intType.isSignless();
+#define BIT_WIDTH_CASE(BIT_WIDTH)                                              \
+  case BIT_WIDTH:                                                              \
+    return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i                      \
+                    : spirv::ImageFormat::R##BIT_WIDTH##ui
+
+            switch (intType.getWidth()) {
+              BIT_WIDTH_CASE(16);
+              BIT_WIDTH_CASE(32);
+            default:
+              llvm_unreachable("Unhandled integer type!");
+            }
+          })
+          .Default([](Type) {
+            llvm_unreachable("Unhandled element type!");
+            // We need to return something here to satisfy the type switch.
+            return spirv::ImageFormat::R32f;
+          });
+#undef BIT_WIDTH_CASE
+    }();
+
+    // Currently every memref in the image storage class is converted to a
+    // sampled image so we can hardcode the NeedSampler field. Future work
+    // will generalize this to support regular non-sampled images.
+    const auto spvImageTy = spirv::ImageType::get(
+        elementType, dim, spirv::ImageDepthInfo::DepthUnknown,
+        spirv::ImageArrayedInfo::NonArrayed,
+        spirv::ImageSamplingInfo::SingleSampled,
+        spirv::ImageSamplerUseInfo::NeedSampler, imageFormat);
+    const auto spvSampledImageTy = spirv::SampledImageType::get(spvImageTy);
+    const auto imagePtrTy = spirv::PointerType::get(
+        spvSampledImageTy, spirv::StorageClass::UniformConstant);
+    return imagePtrTy;
+  }
+
   if (isa<IntegerType>(type.getElementType())) {
     if (type.getElementTypeBitWidth() == 1)
       return convertBoolMemrefType(targetEnv, options, type, storageClass);
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index d0ddac8cd801c..2a7be0be7477a 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -488,3 +488,191 @@ module attributes {
     return
   }
 }
+
+// -----
+
+// Check Image Support.
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [
+    Shader,
+    PhysicalStorageBufferAddresses,
+    Image1D,
+    StorageImageExtendedFormats,
+    Float16,
+    StorageBuffer16BitAccess,
+    StorageUniform16,
+    Int16
+  ], [
+    SPV_KHR_storage_buffer_storage_class,
+    SPV_KHR_physical_storage_buffer,
+    SPV_KHR_16bit_storage
+  ]>, #spirv.resource_limits<>>
+} {
+  // CHECK-LABEL: @load_from_image_1D(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_1D(%arg0: memref<1xf32, #spirv.storage_class<Image>>, %arg1: memref<1xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[COORDS:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i32
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK-NOT: spirv.CompositeConstruct
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, i32 -> vector<4xf32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
+    %0 = memref.load %arg0[%cst] : memref<1xf32, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
+    memref.store %0, %arg1[%cst] : memref<1xf32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D(%arg0: memref<1x1xf32, #spirv.storage_class<Image>>, %arg1: memref<1x1xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<2xi32> -> vector<4xf32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf32, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xf32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_3D(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1x1xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_3D(%arg0: memref<1x1x1xf32, #spirv.storage_class<Image>>, %arg1: memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1x1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}}, %{{.*}} : (i32, i32, i32) -> vector<3xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<3xi32> -> vector<4xf32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
+    %0 = memref.load %arg0[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
+    memref.store %0, %arg1[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_f16(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xf16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_f16(%arg0: memref<1x1xf16, #spirv.storage_class<Image>>, %arg1: memref<1x1xf16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>, vector<2xi32> -> vector<4xf16>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf16>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf16, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f16
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xf16, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_i32(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xi32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_i32(%arg0: memref<1x1xi32, #spirv.storage_class<Image>>, %arg1: memref<1x1xi32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>, vector<2xi32> -> vector<4xi32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi32>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi32, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i32
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xi32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_ui32(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xui32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_ui32(%arg0: memref<1x1xui32, #spirv.storage_class<Image>>, %arg1: memref<1x1xui32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x ui32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>, vector<2xi32> -> vector<4xui32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui32>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui32, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui32
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xui32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_i16(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xi16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_i16(%arg0: memref<1x1xi16, #spirv.storage_class<Image>>, %arg1: memref<1x1xi16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>, vector<2xi32> -> vector<4xi16>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi16>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi16, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i16
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xi16, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_ui16(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xui16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_ui16(%arg0: memref<1x1xui16, #spirv.storage_class<Image>>, %arg1: memref<1x1xui16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x ui16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>, vector<2xi32> -> vector<4xui16>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui16>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui16, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui16
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xui16, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_rank0(
+  func.func @load_from_image_2D_rank0(%arg0: memref<f32, #spirv.storage_class<Image>>, %arg1: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+    %cst = arith.constant 0 : index
+    // CHECK-NOT: spirv.Image
+    // CHECK-NOT: spirv.ImageFetch
+    %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<Image>>
+    memref.store %0, %arg1[] : memref<f32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_rank4(
+  func.func @load_from_image_2D_rank4(%arg0: memref<1x1x1x1xf32, #spirv.storage_class<Image>>, %arg1: memref<1x1x1x1xf32, #spirv.storage_class<StorageBuffer>>) {
+    %cst = arith.constant 0 : index
+    // CHECK-NOT: spirv.Image
+    // CHECK-NOT: spirv.ImageFetch
+    %0 = memref.load %arg0[%cst, %cst, %cst, %cst] : memref<1x1x1x1xf32, #spirv.storage_class<Image>>
+    memref.store %0, %arg1[%cst, %cst, %cst, %cst] : memref<1x1x1x1xf32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_vector(
+  func.func @load_from_image_2D_vector(%arg0: memref<1xvector<1xf32>, #spirv.storage_class<Image>>, %arg1: memref<1xvector<1xf32>, #spirv.storage_class<StorageBuffer>>) {
+    %cst = arith.constant 0 : index
+    // CHECK-NOT: spirv.Image
+    // CHECK-NOT: spirv.ImageFetch
+    %0 = memref.load %arg0[%cst] : memref<1xvector<1xf32>, #spirv.storage_class<Image>>
+    memref.store %0, %arg1[%cst] : memref<1xvector<1xf32>, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+}

>From da38fd21068ad4baba0dc0db7697d59b2baf5981 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Tue, 29 Jul 2025 07:48:00 +0100
Subject: [PATCH 2/2] [mlir][memref][spirv]: Address Feedback

* Remove templated type from using statement.
* Remove `const` from MLIR types.
* Remove `auto` usage where type not clear from source code.
* Make `Type` variable naming consistent.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 .../Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp   | 12 ++++++------
 .../Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 16 ++++++++--------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 8547b9e133dcd..18c5ce009cc6c 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -246,7 +246,7 @@ class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
 /// Converts memref.load to spirv.Image + spirv.ImageFetch
 class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
 public:
-  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
@@ -685,14 +685,14 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
 LogicalResult
 ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
-  const auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
+  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
   if (memrefType.getMemorySpace() !=
       spirv::StorageClassAttr::get(rewriter.getContext(),
                                    spirv::StorageClass::Image))
     return failure();
 
-  const auto loadPtr = adaptor.getMemref();
-  const auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
+  auto loadPtr = adaptor.getMemref();
+  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
   if (failed(memoryRequirements))
     return rewriter.notifyMatchFailure(
         loadOp, "failed to determine memory requirements");
@@ -715,7 +715,7 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   // for plain images and the OpImageRead instruction needs to be materialized
   // instead or texels need to be accessed via atomics through a texel pointer.
   // Future work will generalize support to plain images.
-  if (const auto convertedPointeeType = cast<spirv::PointerType>(
+  if (auto convertedPointeeType = cast<spirv::PointerType>(
           getTypeConverter()->convertType(loadOp.getMemRefType()));
       !isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
     return rewriter.notifyMatchFailure(loadOp,
@@ -740,7 +740,7 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   }
 
   // Fetch the value out of the image.
-  const auto resultVectorType = VectorType::get({4}, loadOp.getType());
+  auto resultVectorType = VectorType::get({4}, loadOp.getType());
   auto fetchOp = rewriter.create<spirv::ImageFetchOp>(
       loadOp->getLoc(), resultVectorType, imageOp, coords,
       mlir::spirv::ImageOperandsAttr{}, ValueRange{});
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 139db07d33923..bdfac4fceb62a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -579,7 +579,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   // may be accessed via image specific ops or directly through a texture
   // pointer.
   if (storageClass == spirv::StorageClass::Image) {
-    const auto rank = type.getRank();
+    const int64_t rank = type.getRank();
     if (rank < 1 || rank > 3) {
       LLVM_DEBUG(llvm::dbgs()
                  << type << " illegal: cannot lower memref of rank " << rank
@@ -602,8 +602,8 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     }();
 
     // Note that we currently only support lowering to single element texels
-    // e.g. R32f
-    const auto elementType = type.getElementType();
+    // e.g. R32f.
+    auto elementType = type.getElementType();
     if (!elementType.isIntOrFloat()) {
       LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
                               << elementType << " to a  SPIR-V Image\n");
@@ -641,15 +641,15 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     // Currently every memref in the image storage class is converted to a
     // sampled image so we can hardcode the NeedSampler field. Future work
     // will generalize this to support regular non-sampled images.
-    const auto spvImageTy = spirv::ImageType::get(
+    auto spvImageType = spirv::ImageType::get(
         elementType, dim, spirv::ImageDepthInfo::DepthUnknown,
         spirv::ImageArrayedInfo::NonArrayed,
         spirv::ImageSamplingInfo::SingleSampled,
         spirv::ImageSamplerUseInfo::NeedSampler, imageFormat);
-    const auto spvSampledImageTy = spirv::SampledImageType::get(spvImageTy);
-    const auto imagePtrTy = spirv::PointerType::get(
-        spvSampledImageTy, spirv::StorageClass::UniformConstant);
-    return imagePtrTy;
+    auto spvSampledImageType = spirv::SampledImageType::get(spvImageType);
+    auto imagePtrType = spirv::PointerType::get(
+        spvSampledImageType, spirv::StorageClass::UniformConstant);
+    return imagePtrType;
   }
 
   if (isa<IntegerType>(type.getElementType())) {



More information about the Mlir-commits mailing list