[Mlir-commits] [mlir] [mlir][memref-to-spirv]: Reverse Image Load Coordinates (PR #160495)

Jack Frankland llvmlistbot at llvm.org
Mon Sep 29 09:53:05 PDT 2025


https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/160495

>From 32b43c867a2d9aceddf102c18eafd44b7757fe78 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Wed, 24 Sep 2025 11:13:59 +0100
Subject: [PATCH 1/7] [mlir][memref-to-spirv]: Reverse Image Load Coordinates

When converting a `memref.load` from the image address space to a
`spirv.ImageFetch` ensure that we reverse the load coordinates when they
are extracted from the load.

This is required because the coordinate operand to the fetch operation
is a vector with the coordinates ordered from innermost (fastest-moving)
to outermost dimension, whereas the load operation lists its indices in
the opposite order, from outermost to innermost. For example, if the
memref.load operation loaded from the coordinates `[%z, %y, %x]` then
`%x` would be the fastest moving dimension followed by `%y` then `%z` so
the spirv.ImageFetch operation would expect a vector `<%x, %y, %z>`.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 .../MemRefToSPIRV/MemRefToSPIRV.cpp           |   4 +-
 .../MemRefToSPIRV/memref-to-spirv.mlir        | 150 +++++++++++-------
 2 files changed, 97 insertions(+), 57 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index f44552c4556c2..da1c2a18414cc 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -758,8 +758,10 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (memrefType.getRank() != 1) {
     auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
                                            adaptor.getIndices().getType()[0]);
+    auto indices = llvm::to_vector(adaptor.getIndices());
+    auto indicesReversed = llvm::to_vector(llvm::reverse(indices));
     coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
-                                                 adaptor.getIndices());
+                                                 indicesReversed);
   } else {
     coords = adaptor.getIndices()[0];
   }
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index e6321e99693ac..56bf4939a0e63 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -550,121 +550,159 @@ module attributes {
   }
 
   // CHECK-LABEL: @load_from_image_2D(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xf32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D(%arg0: memref<1x1xf32, #spirv.storage_class<Image>>, %arg1: memref<1x1xf32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D(%arg0: memref<2x2xf32, #spirv.storage_class<Image>>, %arg1: memref<2x2xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<2xi32> -> vector<4xf32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
-    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x2xf32, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
-    memref.store %0, %arg1[%cst, %cst] : memref<1x1xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x2xf32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_3D(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1x1xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_3D(%arg0: memref<1x1x1xf32, #spirv.storage_class<Image>>, %arg1: memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1x1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<3x3x3xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_3D(%arg0: memref<3x3x3xf32, #spirv.storage_class<Image>>, %arg1: memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<27 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<3x3x3xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
+    // CHECK: %[[Z:.*]] = arith.constant 2 : index
+    // CHECK: %[[Z32:.*]] = builtin.unrealized_conversion_cast %[[Z]] : index to i32
+    %z = arith.constant 2 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}}, %{{.*}} : (i32, i32, i32) -> vector<3xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]], %[[Z32]] : (i32, i32, i32) -> vector<3xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<3xi32> -> vector<4xf32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
-    %0 = memref.load %arg0[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%z, %y, %x] : memref<3x3x3xf32, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
-    memref.store %0, %arg1[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%z, %y, %x] : memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_f16(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xf16, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_f16(%arg0: memref<1x1xf16, #spirv.storage_class<Image>>, %arg1: memref<1x1xf16, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f16, stride=2> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xf16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_f16(%arg0: memref<2x2xf16, #spirv.storage_class<Image>>, %arg1: memref<2x2xf16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xf16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xf16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>, vector<2xi32> -> vector<4xf16>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf16>
-    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf16, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x2xf16, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f16
-    memref.store %0, %arg1[%cst, %cst] : memref<1x1xf16, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x2xf16, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_i32(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xi32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_i32(%arg0: memref<1x1xi32, #spirv.storage_class<Image>>, %arg1: memref<1x1xi32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xi32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_i32(%arg0: memref<2x2xi32, #spirv.storage_class<Image>>, %arg1: memref<2x2xi32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xi32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>, vector<2xi32> -> vector<4xi32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi32>
-    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x2xi32, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i32
-    memref.store %0, %arg1[%cst, %cst] : memref<1x1xi32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x2xi32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_ui32(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xui32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_ui32(%arg0: memref<1x1xui32, #spirv.storage_class<Image>>, %arg1: memref<1x1xui32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x ui32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xui32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xui32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_ui32(%arg0: memref<2x2xui32, #spirv.storage_class<Image>>, %arg1: memref<2x2xui32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xui32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x ui32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xui32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>, vector<2xi32> -> vector<4xui32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui32>
-    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x2xui32, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui32
-    memref.store %0, %arg1[%cst, %cst] : memref<1x1xui32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x2xui32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_i16(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xi16, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_i16(%arg0: memref<1x1xi16, #spirv.storage_class<Image>>, %arg1: memref<1x1xi16, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i16, stride=2> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xi16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_i16(%arg0: memref<2x2xi16, #spirv.storage_class<Image>>, %arg1: memref<2x2xi16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xi16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xi16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>, vector<2xi32> -> vector<4xi16>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi16>
-    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi16, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x2xi16, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i16
-    memref.store %0, %arg1[%cst, %cst] : memref<1x1xi16, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x2xi16, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_ui16(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xui16, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_ui16(%arg0: memref<1x1xui16, #spirv.storage_class<Image>>, %arg1: memref<1x1xui16, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x ui16, stride=2> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xui16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xui16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_ui16(%arg0: memref<2x2xui16, #spirv.storage_class<Image>>, %arg1: memref<2x2xui16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xui16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x ui16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xui16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>, vector<2xi32> -> vector<4xui16>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui16>
-    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui16, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x2xui16, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui16
-    memref.store %0, %arg1[%cst, %cst] : memref<1x1xui16, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x2xui16, #spirv.storage_class<StorageBuffer>>
     return
   }
 

>From dcc20d1545cff05617ece999848249c7d9600f29 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Wed, 24 Sep 2025 14:52:47 +0100
Subject: [PATCH 2/7] Address PR Feedback

Reverse the adaptor indices directly instead of first copying them into an intermediate vector.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index da1c2a18414cc..1b66dab9ee985 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -758,8 +758,7 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (memrefType.getRank() != 1) {
     auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
                                            adaptor.getIndices().getType()[0]);
-    auto indices = llvm::to_vector(adaptor.getIndices());
-    auto indicesReversed = llvm::to_vector(llvm::reverse(indices));
+    auto indicesReversed = llvm::to_vector(llvm::reverse(adaptor.getIndices()));
     coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
                                                  indicesReversed);
   } else {

>From 7a396077a416f54a13fcc616361ea13584ebe9ad Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Thu, 25 Sep 2025 16:06:45 +0100
Subject: [PATCH 3/7] Address Feedback:

* Generalize coordinate mapping to support arbitrary permutations.
* Add lit tests for permutations.
* Make memrefs in lit tests non-square.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 .../MemRefToSPIRV/MemRefToSPIRV.cpp           |  43 ++++-
 .../MemRefToSPIRV/memref-to-spirv.mlir        | 169 ++++++++++++------
 2 files changed, 148 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 1b66dab9ee985..0dc7f693a6cc8 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -699,6 +699,37 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   return success();
 }
 
+template <typename OpAdaptor>
+static FailureOr<SmallVector<Value>>
+extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter) {
+  // Texel coordinates are ordered from inner most to outer most dimension
+  // i.e. u, v, w, a where:
+  // u: Coordinate in the first dimension of an image.
+  // v: Coordinate in the second dimension of an image.
+  // w: Coordinate in the third dimension of an image.
+  // a: Coordinate for array layer.
+  //
+  // The memrefs layout determines the dimension ordering so we need to invert
+  // the map to get the ordering.
+  SmallVector<Value> indices = adaptor.getIndices();
+  auto map = loadOp.getMemRefType().getLayout().getAffineMap();
+  if (!map.isPermutation())
+    return rewriter.notifyMatchFailure(
+        loadOp,
+        "Cannot lower memrefs with memory layout which is not a permutation");
+
+  const unsigned dimCount = map.getNumDims();
+  SmallVector<Value, 3> coords(dimCount);
+  for (unsigned dim = 0; dim < dimCount; ++dim)
+    coords[map.getDimPosition(dim)] = indices[dim];
+
+  // We need to do a final reversal since the image fetch op expects the first
+  // dimension in the 0th element position, 2nd dimension in the 1st element
+  // position etc. which is the opposite to the ordering in the map.
+  return llvm::to_vector(llvm::reverse(coords));
+}
+
 LogicalResult
 ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
@@ -755,14 +786,16 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
 
   // Build a vector of coordinates or just a scalar index if we have a 1D image.
   Value coords;
-  if (memrefType.getRank() != 1) {
+  if (memrefType.getRank() == 1) {
+    coords = adaptor.getIndices()[0];
+  } else {
+    auto maybeCoords = extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
+    if (failed(maybeCoords))
+      return failure();
     auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
                                            adaptor.getIndices().getType()[0]);
-    auto indicesReversed = llvm::to_vector(llvm::reverse(adaptor.getIndices()));
     coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
-                                                 indicesReversed);
-  } else {
-    coords = adaptor.getIndices()[0];
+                                                 maybeCoords.value());
   }
 
   // Fetch the value out of the image.
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 56bf4939a0e63..0f7542cd9c469 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -515,6 +515,10 @@ module attributes {
 
 // Check Image Support.
 
+// CHECK: #[[COLMAJMAP:[a-z_]+]] = affine_map<(d0, d1) -> (d1, d0)>
+#col_major = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[CUSTOMLAYOUTMAP:[a-z0-9_]+]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
+#custom = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
 module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [
     Shader,
@@ -550,13 +554,13 @@ module attributes {
   }
 
   // CHECK-LABEL: @load_from_image_2D(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xf32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D(%arg0: memref<2x2xf32, #spirv.storage_class<Image>>, %arg1: memref<2x2xf32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
-    // CHECK: %[[X:.*]] = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D(%arg0: memref<2x4xf32, #spirv.storage_class<Image>>, %arg1: memref<2x4xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<8 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 3 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
-    %x = arith.constant 0 : index
+    %x = arith.constant 3 : index
     // CHECK: %[[Y:.*]] = arith.constant 1 : index
     // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
     %y = arith.constant 1 : index
@@ -565,45 +569,92 @@ module attributes {
     // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<2xi32> -> vector<4xf32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
-    %0 = memref.load %arg0[%y, %x] : memref<2x2xf32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x4xf32, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
-    memref.store %0, %arg1[%y, %x] : memref<2x2xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x4xf32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
-  // CHECK-LABEL: @load_from_image_3D(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<3x3x3xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_3D(%arg0: memref<3x3x3xf32, #spirv.storage_class<Image>>, %arg1: memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<27 x f32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<3x3x3xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
-    // CHECK: %[[X:.*]] = arith.constant 0 : index
+  // CHECK-LABEL: @load_from_col_major_image_2D(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #[[COLMAJMAP]], #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_col_major_image_2D(%arg0: memref<2x4xf32, #col_major, #spirv.storage_class<Image>>, %arg1: memref<2x4xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<8 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32, #[[COLMAJMAP]], #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 3 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
-    %x = arith.constant 0 : index
+    %x = arith.constant 3 : index
     // CHECK: %[[Y:.*]] = arith.constant 1 : index
     // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
     %y = arith.constant 1 : index
-    // CHECK: %[[Z:.*]] = arith.constant 2 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<2xi32> -> vector<4xf32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
+    %0 = memref.load %arg0[%x, %y] : memref<2x4xf32, #col_major, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
+    memref.store %0, %arg1[%y, %x] : memref<2x4xf32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_3D(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_3D(%arg0: memref<2x3x4xf32, #spirv.storage_class<Image>>, %arg1: memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<24 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3x4xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 3 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 3 : index
+    // CHECK: %[[Y:.*]] = arith.constant 2 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 2 : index
+    // CHECK: %[[Z:.*]] = arith.constant 1 : index
+    // CHECK: %[[Z32:.*]] = builtin.unrealized_conversion_cast %[[Z]] : index to i32
+    %z = arith.constant 1 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]], %[[Z32]] : (i32, i32, i32) -> vector<3xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<3xi32> -> vector<4xf32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
+    %0 = memref.load %arg0[%z, %y, %x] : memref<2x3x4xf32, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
+    memref.store %0, %arg1[%z, %y, %x] : memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_custom_layout_image_3D(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #[[CUSTOMLAYOUTMAP]], #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_custom_layout_image_3D(%arg0: memref<2x3x4xf32, #custom,  #spirv.storage_class<Image>>, %arg1: memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<24 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3x4xf32, #[[CUSTOMLAYOUTMAP]], #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 3 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 3 : index
+    // CHECK: %[[Y:.*]] = arith.constant 2 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 2 : index
+    // CHECK: %[[Z:.*]] = arith.constant 1 : index
     // CHECK: %[[Z32:.*]] = builtin.unrealized_conversion_cast %[[Z]] : index to i32
-    %z = arith.constant 2 : index
+    %z = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
     // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]], %[[Z32]] : (i32, i32, i32) -> vector<3xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<3xi32> -> vector<4xf32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
-    %0 = memref.load %arg0[%z, %y, %x] : memref<3x3x3xf32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%x, %y, %z] : memref<2x3x4xf32, #custom, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
-    memref.store %0, %arg1[%z, %y, %x] : memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%z, %y, %x] : memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_f16(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xf16, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_f16(%arg0: memref<2x2xf16, #spirv.storage_class<Image>>, %arg1: memref<2x2xf16, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xf16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f16, stride=2> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xf16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>, UniformConstant>
-    // CHECK: %[[X:.*]] = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xf16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3xf16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_f16(%arg0: memref<2x3xf16, #spirv.storage_class<Image>>, %arg1: memref<2x3xf16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xf16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x f16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xf16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 2 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
-    %x = arith.constant 0 : index
+    %x = arith.constant 2 : index
     // CHECK: %[[Y:.*]] = arith.constant 1 : index
     // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
     %y = arith.constant 1 : index
@@ -612,20 +663,20 @@ module attributes {
     // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>, vector<2xi32> -> vector<4xf16>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf16>
-    %0 = memref.load %arg0[%y, %x] : memref<2x2xf16, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x3xf16, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f16
-    memref.store %0, %arg1[%y, %x] : memref<2x2xf16, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x3xf16, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_i32(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xi32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_i32(%arg0: memref<2x2xi32, #spirv.storage_class<Image>>, %arg1: memref<2x2xi32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xi32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>, UniformConstant>
-    // CHECK: %[[X:.*]] = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xi32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3xi32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_i32(%arg0: memref<2x3xi32, #spirv.storage_class<Image>>, %arg1: memref<2x3xi32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xi32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 2 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
-    %x = arith.constant 0 : index
+    %x = arith.constant 2 : index
     // CHECK: %[[Y:.*]] = arith.constant 1 : index
     // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
     %y = arith.constant 1 : index
@@ -634,20 +685,20 @@ module attributes {
     // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>, vector<2xi32> -> vector<4xi32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi32>
-    %0 = memref.load %arg0[%y, %x] : memref<2x2xi32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x3xi32, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i32
-    memref.store %0, %arg1[%y, %x] : memref<2x2xi32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x3xi32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_ui32(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xui32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xui32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_ui32(%arg0: memref<2x2xui32, #spirv.storage_class<Image>>, %arg1: memref<2x2xui32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xui32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x ui32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xui32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>, UniformConstant>
-    // CHECK: %[[X:.*]] = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xui32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3xui32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_ui32(%arg0: memref<2x3xui32, #spirv.storage_class<Image>>, %arg1: memref<2x3xui32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xui32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x ui32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xui32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 2 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
-    %x = arith.constant 0 : index
+    %x = arith.constant 2 : index
     // CHECK: %[[Y:.*]] = arith.constant 1 : index
     // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
     %y = arith.constant 1 : index
@@ -656,20 +707,20 @@ module attributes {
     // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>, vector<2xi32> -> vector<4xui32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui32>
-    %0 = memref.load %arg0[%y, %x] : memref<2x2xui32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x3xui32, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui32
-    memref.store %0, %arg1[%y, %x] : memref<2x2xui32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x3xui32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_i16(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xi16, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_i16(%arg0: memref<2x2xi16, #spirv.storage_class<Image>>, %arg1: memref<2x2xi16, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xi16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i16, stride=2> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xi16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>, UniformConstant>
-    // CHECK: %[[X:.*]] = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xi16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3xi16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_i16(%arg0: memref<2x3xi16, #spirv.storage_class<Image>>, %arg1: memref<2x3xi16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xi16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x i16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xi16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 2 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
-    %x = arith.constant 0 : index
+    %x = arith.constant 2 : index
     // CHECK: %[[Y:.*]] = arith.constant 1 : index
     // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
     %y = arith.constant 1 : index
@@ -678,20 +729,20 @@ module attributes {
     // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>, vector<2xi32> -> vector<4xi16>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi16>
-    %0 = memref.load %arg0[%y, %x] : memref<2x2xi16, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x3xi16, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i16
-    memref.store %0, %arg1[%y, %x] : memref<2x2xi16, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x3xi16, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_ui16(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xui16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xui16, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_ui16(%arg0: memref<2x2xui16, #spirv.storage_class<Image>>, %arg1: memref<2x2xui16, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xui16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x ui16, stride=2> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xui16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>, UniformConstant>
-    // CHECK: %[[X:.*]] = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xui16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3xui16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_ui16(%arg0: memref<2x3xui16, #spirv.storage_class<Image>>, %arg1: memref<2x3xui16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xui16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x ui16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xui16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 2 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
-    %x = arith.constant 0 : index
+    %x = arith.constant 2 : index
     // CHECK: %[[Y:.*]] = arith.constant 1 : index
     // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
     %y = arith.constant 1 : index
@@ -700,9 +751,9 @@ module attributes {
     // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>, vector<2xi32> -> vector<4xui16>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui16>
-    %0 = memref.load %arg0[%y, %x] : memref<2x2xui16, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x3xui16, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui16
-    memref.store %0, %arg1[%y, %x] : memref<2x2xui16, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x3xui16, #spirv.storage_class<StorageBuffer>>
     return
   }
 

>From ae1c1303370b3c988347cb50321733f3d40807a6 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Mon, 29 Sep 2025 10:44:32 +0100
Subject: [PATCH 4/7] Address Feedback

Update comments to reflect support for linearly tiled images only.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 0dc7f693a6cc8..0f56c11748da9 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -703,12 +703,11 @@ template <typename OpAdaptor>
 static FailureOr<SmallVector<Value>>
 extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) {
-  // Texel coordinates are ordered from inner most to outer most dimension
-  // i.e. u, v, w, a where:
-  // u: Coordinate in the first dimension of an image.
-  // v: Coordinate in the second dimension of an image.
-  // w: Coordinate in the third dimension of an image.
-  // a: Coordinate for array layer.
+  // At present we only support linear "tiling" as specified in Vulkan, this
+  // means that texels are assumed to be laid out in memory in a row-major
+  // order. This allows us to support any memref layout that is a permutation of
+  // the dimensions. Future work will pass an optional image layout to the
+  // rewrite pattern so that we can support optimized target specific tilings.
   //
   // The memrefs layout determines the dimension ordering so we need to invert
   // the map to get the ordering.

>From a6e360ef79c0d8bb877a1f3ecd8f14a036eadf1c Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Mon, 29 Sep 2025 11:47:02 +0100
Subject: [PATCH 5/7] Address Feedback

Make map variables globally scoped in lit tests.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 .../MemRefToSPIRV/memref-to-spirv.mlir        | 48 +++++++++----------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 0f7542cd9c469..1c95ec2389ed1 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -515,9 +515,9 @@ module attributes {
 
 // Check Image Support.
 
-// CHECK: #[[COLMAJMAP:[a-z_]+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[$COLMAJMAP:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 #col_major = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK: #[[CUSTOMLAYOUTMAP:[a-z0-9_]+]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
+// CHECK: #[[$CUSTOMLAYOUTMAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
 #custom = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
 module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [
@@ -538,8 +538,8 @@ module attributes {
   // CHECK-LABEL: @load_from_image_1D(
   // CHECK-SAME: %[[ARG0:.*]]: memref<1xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1xf32, #spirv.storage_class<StorageBuffer>>
   func.func @load_from_image_1D(%arg0: memref<1xf32, #spirv.storage_class<Image>>, %arg1: memref<1xf32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
     %cst = arith.constant 0 : index
     // CHECK: %[[COORDS:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i32
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
@@ -556,8 +556,8 @@ module attributes {
   // CHECK-LABEL: @load_from_image_2D(
   // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class<StorageBuffer>>
   func.func @load_from_image_2D(%arg0: memref<2x4xf32, #spirv.storage_class<Image>>, %arg1: memref<2x4xf32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<8 x f32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<8 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x4xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
     // CHECK: %[[X:.*]] = arith.constant 3 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
     %x = arith.constant 3 : index
@@ -576,10 +576,10 @@ module attributes {
   }
 
   // CHECK-LABEL: @load_from_col_major_image_2D(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #[[COLMAJMAP]], #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class<StorageBuffer>>
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #[[$COLMAJMAP]], #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class<StorageBuffer>>
   func.func @load_from_col_major_image_2D(%arg0: memref<2x4xf32, #col_major, #spirv.storage_class<Image>>, %arg1: memref<2x4xf32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<8 x f32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32, #[[COLMAJMAP]], #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<8 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x4xf32, #[[$COLMAJMAP]], #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
     // CHECK: %[[X:.*]] = arith.constant 3 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
     %x = arith.constant 3 : index
@@ -600,8 +600,8 @@ module attributes {
   // CHECK-LABEL: @load_from_image_3D(
   // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>
   func.func @load_from_image_3D(%arg0: memref<2x3x4xf32, #spirv.storage_class<Image>>, %arg1: memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<24 x f32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3x4xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<24 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3x4xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
     // CHECK: %[[X:.*]] = arith.constant 3 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
     %x = arith.constant 3 : index
@@ -623,10 +623,10 @@ module attributes {
   }
 
   // CHECK-LABEL: @load_from_custom_layout_image_3D(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #[[CUSTOMLAYOUTMAP]], #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #[[$CUSTOMLAYOUTMAP]], #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>
   func.func @load_from_custom_layout_image_3D(%arg0: memref<2x3x4xf32, #custom,  #spirv.storage_class<Image>>, %arg1: memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<24 x f32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3x4xf32, #[[CUSTOMLAYOUTMAP]], #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<24 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3x4xf32, #[[$CUSTOMLAYOUTMAP]], #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
     // CHECK: %[[X:.*]] = arith.constant 3 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
     %x = arith.constant 3 : index
@@ -650,8 +650,8 @@ module attributes {
   // CHECK-LABEL: @load_from_image_2D_f16(
   // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xf16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3xf16, #spirv.storage_class<StorageBuffer>>
   func.func @load_from_image_2D_f16(%arg0: memref<2x3xf16, #spirv.storage_class<Image>>, %arg1: memref<2x3xf16, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xf16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x f16, stride=2> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xf16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>, UniformConstant>
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xf16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x f16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xf16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>, UniformConstant>
     // CHECK: %[[X:.*]] = arith.constant 2 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
     %x = arith.constant 2 : index
@@ -672,8 +672,8 @@ module attributes {
   // CHECK-LABEL: @load_from_image_2D_i32(
   // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xi32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3xi32, #spirv.storage_class<StorageBuffer>>
   func.func @load_from_image_2D_i32(%arg0: memref<2x3xi32, #spirv.storage_class<Image>>, %arg1: memref<2x3xi32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x i32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xi32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>, UniformConstant>
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xi32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>, UniformConstant>
     // CHECK: %[[X:.*]] = arith.constant 2 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
     %x = arith.constant 2 : index
@@ -694,8 +694,8 @@ module attributes {
   // CHECK-LABEL: @load_from_image_2D_ui32(
   // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xui32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3xui32, #spirv.storage_class<StorageBuffer>>
   func.func @load_from_image_2D_ui32(%arg0: memref<2x3xui32, #spirv.storage_class<Image>>, %arg1: memref<2x3xui32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xui32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x ui32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xui32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>, UniformConstant>
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xui32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x ui32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xui32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>, UniformConstant>
     // CHECK: %[[X:.*]] = arith.constant 2 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
     %x = arith.constant 2 : index
@@ -716,8 +716,8 @@ module attributes {
   // CHECK-LABEL: @load_from_image_2D_i16(
   // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xi16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3xi16, #spirv.storage_class<StorageBuffer>>
   func.func @load_from_image_2D_i16(%arg0: memref<2x3xi16, #spirv.storage_class<Image>>, %arg1: memref<2x3xi16, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xi16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x i16, stride=2> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xi16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>, UniformConstant>
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xi16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x i16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xi16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>, UniformConstant>
     // CHECK: %[[X:.*]] = arith.constant 2 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
     %x = arith.constant 2 : index
@@ -738,8 +738,8 @@ module attributes {
   // CHECK-LABEL: @load_from_image_2D_ui16(
   // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xui16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x3xui16, #spirv.storage_class<StorageBuffer>>
   func.func @load_from_image_2D_ui16(%arg0: memref<2x3xui16, #spirv.storage_class<Image>>, %arg1: memref<2x3xui16, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xui16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x ui16, stride=2> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xui16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>, UniformConstant>
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xui16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<6 x ui16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xui16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>, UniformConstant>
     // CHECK: %[[X:.*]] = arith.constant 2 : index
     // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
     %x = arith.constant 2 : index

>From 94ef93ace623e3a790614aa4c54aa9468533b571 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Mon, 29 Sep 2025 17:22:24 +0100
Subject: [PATCH 6/7] Address Feedback

* Remove `auto` usage
* Expand comment

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 .../Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp   | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 0f56c11748da9..a90dcc8cc3ef1 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -708,24 +708,23 @@ extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
   // order. This allows us to support any memref layout that is a permutation of
   // the dimensions. Future work will pass an optional image layout to the
   // rewrite pattern so that we can support optimized target specific tilings.
-  //
-  // The memrefs layout determines the dimension ordering so we need to invert
-  // the map to get the ordering.
   SmallVector<Value> indices = adaptor.getIndices();
-  auto map = loadOp.getMemRefType().getLayout().getAffineMap();
+  AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
   if (!map.isPermutation())
     return rewriter.notifyMatchFailure(
         loadOp,
         "Cannot lower memrefs with memory layout which is not a permutation");
 
+  // The memref's layout determines the dimension ordering so we need to follow
+  // the map to get the ordering of the dimensions/indices.
   const unsigned dimCount = map.getNumDims();
   SmallVector<Value, 3> coords(dimCount);
   for (unsigned dim = 0; dim < dimCount; ++dim)
     coords[map.getDimPosition(dim)] = indices[dim];
 
-  // We need to do a final reversal since the image fetch op expects the first
-  // dimension in the 0th element position, 2nd dimension in the 1st element
-  // position etc. which is the opposite to the ordering in the map.
+  // We need to reverse the coordinates because the memref layout is slowest to
+  // fastest moving and the vector coordinates for the image op are fastest to
+  // slowest moving.
   return llvm::to_vector(llvm::reverse(coords));
 }
 
@@ -788,7 +787,8 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (memrefType.getRank() == 1) {
     coords = adaptor.getIndices()[0];
   } else {
-    auto maybeCoords = extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
+    FailureOr<SmallVector<Value>> maybeCoords =
+        extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
     if (failed(maybeCoords))
       return failure();
     auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},

>From 83a6276ee1d3c3cdde9a2530d78e7ce0d95a3ed1 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Mon, 29 Sep 2025 17:52:09 +0100
Subject: [PATCH 7/7] Address Feedback

* Add negative test

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 .../Conversion/MemRefToSPIRV/memref-to-spirv.mlir   | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 1c95ec2389ed1..ab3c8b7397e1a 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -519,6 +519,8 @@ module attributes {
 #col_major = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: #[[$CUSTOMLAYOUTMAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
 #custom = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
+// CHECK: #[[$NONPERMMAP:.*]] = affine_map<(d0, d1) -> (d0, d1 mod 2)>
+#non_permutation = affine_map<(d0, d1) -> (d0, d1 mod 2)>
 module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [
     Shader,
@@ -786,4 +788,15 @@ module attributes {
     memref.store %0, %arg1[%cst] : memref<1xvector<1xf32>, #spirv.storage_class<StorageBuffer>>
     return
   }
+
+  // CHECK-LABEL: @load_non_perm_layout(
+  func.func @load_non_perm_layout(%arg0: memref<2x4xf32, #non_permutation, #spirv.storage_class<Image>>, %arg1: memref<2x4xf32, #spirv.storage_class<StorageBuffer>>) {
+    %x = arith.constant 3 : index
+    %y = arith.constant 1 : index
+    // CHECK-NOT: spirv.Image
+    // CHECK-NOT: spirv.ImageFetch
+    %0 = memref.load %arg0[%y, %x] : memref<2x4xf32, #non_permutation, #spirv.storage_class<Image>>
+    memref.store %0, %arg1[%y, %x] : memref<2x4xf32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
 }



More information about the Mlir-commits mailing list