[Mlir-commits] [mlir] 96c8b9e - [mlir][memref][spirv] Add SPIR-V Image Lowering (#150978)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 4 02:16:30 PDT 2025


Author: Jack Frankland
Date: 2025-08-04T10:16:26+01:00
New Revision: 96c8b9e508f72304cbd958be18fdd7b7240f7e63

URL: https://github.com/llvm/llvm-project/commit/96c8b9e508f72304cbd958be18fdd7b7240f7e63
DIFF: https://github.com/llvm/llvm-project/commit/96c8b9e508f72304cbd958be18fdd7b7240f7e63.diff

LOG: [mlir][memref][spirv] Add SPIR-V Image Lowering (#150978)

Adds an initial conversion in the Memref -> SPIR-V lowering for images.
Any memref in the "Image" storage-class/address-space will be considered
for lowering to the `!spirv.image` type during Memref to SPIR-V
conversion. Initially only the reading of sampled images is supported and
images are read via the `OpImageFetch` instruction. Future work should
expand the conversion patterns to target non-sampled images and add
support for image write operations.

Images are supported for fp32, fp16, int32, uint32, int16 and uint16
types and lit tests have been added to verify this is the case along
with negative testing to check the cases where images aren't supported.

---------

Signed-off-by: Jack Frankland <jack.frankland at arm.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 7a705336bf11c..c7ecd8334da42 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -244,6 +244,16 @@ class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts memref.load on Image storage class memrefs to spirv.ImageFetch.
+class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts memref.store to spirv.Store on integers.
 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
 public:
@@ -528,6 +538,17 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (!memrefType.getElementType().isSignlessInteger())
     return failure();
 
+  auto memorySpaceAttr =
+      dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+  if (!memorySpaceAttr)
+    return rewriter.notifyMatchFailure(
+        loadOp, "missing memory space SPIR-V storage class attribute");
+
+  if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
+    return rewriter.notifyMatchFailure(
+        loadOp,
+        "failed to lower memref in image storage class to storage buffer");
+
   const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
   Value accessChain =
       spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
@@ -644,6 +665,18 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
   if (memrefType.getElementType().isSignlessInteger())
     return failure();
+
+  auto memorySpaceAttr =
+      dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+  if (!memorySpaceAttr)
+    return rewriter.notifyMatchFailure(
+        loadOp, "missing memory space SPIR-V storage class attribute");
+
+  if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
+    return rewriter.notifyMatchFailure(
+        loadOp,
+        "failed to lower memref in image storage class to storage buffer");
+
   Value loadPtr = spirv::getElementPtr(
       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
       adaptor.getIndices(), loadOp.getLoc(), rewriter);
@@ -662,6 +695,87 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   return success();
 }
 
+/// Lowers a memref.load on a memref in the Image storage class to spirv.Load,
+/// spirv.Image, spirv.ImageFetch and a spirv.CompositeExtract of element 0.
+LogicalResult
+ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
+
+  auto memorySpaceAttr =
+      dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+  if (!memorySpaceAttr)
+    return rewriter.notifyMatchFailure(
+        loadOp, "missing memory space SPIR-V storage class attribute");
+
+  if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
+    return rewriter.notifyMatchFailure(
+        loadOp, "failed to lower memref in non-image storage class to image");
+
+  Value loadPtr = adaptor.getMemref();
+  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
+  if (failed(memoryRequirements))
+    return rewriter.notifyMatchFailure(
+        loadOp, "failed to determine memory requirements");
+
+  const auto [memoryAccess, alignment] = *memoryRequirements;
+
+  if (!loadOp.getMemRefType().hasRank())
+    return rewriter.notifyMatchFailure(
+        loadOp, "cannot lower unranked memrefs to SPIR-V images");
+
+  // Only scalar elements are supported for now; they map to texels in the
+  // R[16|32][f|i|ui] formats. Vector elements are future work.
+  if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
+    return rewriter.notifyMatchFailure(
+        loadOp,
+        "cannot lower memrefs whose element type is not a SPIR-V scalar type "
+        "to SPIR-V images");
+
+  // Only sampled images are supported for now: OpImageFetch does not work on
+  // plain images, which need OpImageRead or atomic access through a texel
+  // pointer instead. A failed type conversion is also rejected here.
+  auto convertedPointerType = dyn_cast_if_present<spirv::PointerType>(
+      getTypeConverter()->convertType(loadOp.getMemRefType()));
+  if (!convertedPointerType ||
+      !isa<spirv::SampledImageType>(convertedPointerType.getPointeeType()))
+    return rewriter.notifyMatchFailure(loadOp,
+                                       "cannot lower memrefs which do not "
+                                       "convert to SPIR-V sampled images");
+
+  // Materialize the lowering.
+  Location loc = loadOp->getLoc();
+  auto imageLoadOp =
+      spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
+  // Extract the image from the sampled image.
+  auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
+
+  // Build a vector of coordinates or just a scalar index if we have a 1D image.
+  Value coords;
+  if (memrefType.getRank() != 1) {
+    auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
+                                           adaptor.getIndices().getType()[0]);
+    coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
+                                                 adaptor.getIndices());
+  } else {
+    coords = adaptor.getIndices()[0];
+  }
+
+  // Fetch the value out of the image.
+  auto resultVectorType = VectorType::get({4}, loadOp.getType());
+  auto fetchOp = spirv::ImageFetchOp::create(
+      rewriter, loc, resultVectorType, imageOp, coords,
+      mlir::spirv::ImageOperandsAttr{}, ValueRange{});
+
+  // OpImageFetch yields a 4-element vector. Only the R[16|32][f|i|ui] formats
+  // are supported, so the loaded texel is always element 0 (the red channel).
+  auto compositeExtractOp =
+      spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
+
+  rewriter.replaceOp(loadOp, compositeExtractOp);
+  return success();
+}
+
 LogicalResult
 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
@@ -953,11 +1067,11 @@ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
 namespace mlir {
 void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns
-      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
-           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
-           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
-           CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
-          typeConverter, patterns.getContext());
+  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+               DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
+               IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
+               StoreOpPattern, ReinterpretCastPattern, CastPattern,
+               ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
+                                                      patterns.getContext());
 }
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 8f4c4cc027798..49f4ce8de7c76 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -608,6 +608,45 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
   return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
+/// Maps a memref rank to the SPIR-V image dimensionality. The rank must be
+/// 1, 2 or 3; any other value reaches llvm_unreachable.
+static spirv::Dim convertRank(int64_t rank) {
+  if (rank == 1)
+    return spirv::Dim::Dim1D;
+  if (rank == 2)
+    return spirv::Dim::Dim2D;
+  if (rank == 3)
+    return spirv::Dim::Dim3D;
+
+  llvm_unreachable("Invalid memref rank!");
+}
+
+/// Returns the R-channel image format for an f16, f32 or 16/32-bit integer
+/// element type. Signless integers map to the signed (R16i/R32i) formats.
+static spirv::ImageFormat getImageFormat(Type elementType) {
+  return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType)
+      .Case([](Float16Type) { return spirv::ImageFormat::R16f; })
+      .Case([](Float32Type) { return spirv::ImageFormat::R32f; })
+      .Case([](IntegerType integerType) {
+        const bool useSignedFormat = !integerType.isUnsigned();
+        switch (integerType.getWidth()) {
+        case 16:
+          return useSignedFormat ? spirv::ImageFormat::R16i
+                                 : spirv::ImageFormat::R16ui;
+        case 32:
+          return useSignedFormat ? spirv::ImageFormat::R32i
+                                 : spirv::ImageFormat::R32ui;
+        default:
+          llvm_unreachable("Unhandled integer type!");
+        }
+      })
+      .Default([](Type) {
+        llvm_unreachable("Unhandled element type!");
+        // Unreachable, but gives the lambda a deducible return type.
+        return spirv::ImageFormat::R32f;
+      });
+}
+
 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                               const SPIRVConversionOptions &options,
                               MemRefType type) {
@@ -623,6 +662,41 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   }
   spirv::StorageClass storageClass = attr.getValue();
 
+  // Images are a special case since they are an opaque type from which elements
+  // may be accessed via image specific ops or directly through a texture
+  // pointer.
+  if (storageClass == spirv::StorageClass::Image) {
+    const int64_t rank = type.getRank();
+    if (rank < 1 || rank > 3) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << type << " illegal: cannot lower memref of rank " << rank
+                 << " to a SPIR-V Image\n");
+      return nullptr;
+    }
+
+    // Note that we currently only support lowering to single element texels
+    // e.g. R32f.
+    auto elementType = type.getElementType();
+    if (!isa<spirv::ScalarType>(elementType)) {
+      LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
+                              << elementType << " to a  SPIR-V Image\n");
+      return nullptr;
+    }
+
+    // Currently every memref in the image storage class is converted to a
+    // sampled image so we can hardcode the NeedSampler field. Future work
+    // will generalize this to support regular non-sampled images.
+    auto spvImageType = spirv::ImageType::get(
+        elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
+        spirv::ImageArrayedInfo::NonArrayed,
+        spirv::ImageSamplingInfo::SingleSampled,
+        spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
+    auto spvSampledImageType = spirv::SampledImageType::get(spvImageType);
+    auto imagePtrType = spirv::PointerType::get(
+        spvSampledImageType, spirv::StorageClass::UniformConstant);
+    return imagePtrType;
+  }
+
   if (isa<IntegerType>(type.getElementType())) {
     if (type.getElementTypeBitWidth() == 1)
       return convertBoolMemrefType(targetEnv, options, type, storageClass);

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index d0ddac8cd801c..2a7be0be7477a 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -488,3 +488,191 @@ module attributes {
     return
   }
 }
+
+// -----
+
+// Check Image Support.
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [
+    Shader,
+    PhysicalStorageBufferAddresses,
+    Image1D,
+    StorageImageExtendedFormats,
+    Float16,
+    StorageBuffer16BitAccess,
+    StorageUniform16,
+    Int16
+  ], [
+    SPV_KHR_storage_buffer_storage_class,
+    SPV_KHR_physical_storage_buffer,
+    SPV_KHR_16bit_storage
+  ]>, #spirv.resource_limits<>>
+} {
+  // CHECK-LABEL: @load_from_image_1D(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_1D(%arg0: memref<1xf32, #spirv.storage_class<Image>>, %arg1: memref<1xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[COORDS:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i32
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK-NOT: spirv.CompositeConstruct
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, i32 -> vector<4xf32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
+    %0 = memref.load %arg0[%cst] : memref<1xf32, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
+    memref.store %0, %arg1[%cst] : memref<1xf32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D(%arg0: memref<1x1xf32, #spirv.storage_class<Image>>, %arg1: memref<1x1xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<2xi32> -> vector<4xf32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf32, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xf32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_3D(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1x1xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_3D(%arg0: memref<1x1x1xf32, #spirv.storage_class<Image>>, %arg1: memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1x1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}}, %{{.*}} : (i32, i32, i32) -> vector<3xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<3xi32> -> vector<4xf32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
+    %0 = memref.load %arg0[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
+    memref.store %0, %arg1[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_f16(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xf16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_f16(%arg0: memref<1x1xf16, #spirv.storage_class<Image>>, %arg1: memref<1x1xf16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>, vector<2xi32> -> vector<4xf16>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf16>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf16, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f16
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xf16, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_i32(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xi32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_i32(%arg0: memref<1x1xi32, #spirv.storage_class<Image>>, %arg1: memref<1x1xi32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>, vector<2xi32> -> vector<4xi32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi32>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi32, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i32
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xi32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_ui32(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xui32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_ui32(%arg0: memref<1x1xui32, #spirv.storage_class<Image>>, %arg1: memref<1x1xui32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x ui32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>, vector<2xi32> -> vector<4xui32>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui32>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui32, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui32
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xui32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_i16(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xi16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_i16(%arg0: memref<1x1xi16, #spirv.storage_class<Image>>, %arg1: memref<1x1xi16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>, vector<2xi32> -> vector<4xi16>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi16>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi16, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i16
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xi16, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_ui16(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xui16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_ui16(%arg0: memref<1x1xui16, #spirv.storage_class<Image>>, %arg1: memref<1x1xui16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x ui16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>, UniformConstant>
+    %cst = arith.constant 0 : index
+    // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>
+    // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<ui16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16ui>, vector<2xi32> -> vector<4xui16>
+    // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui16>
+    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui16, #spirv.storage_class<Image>>
+    // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui16
+    memref.store %0, %arg1[%cst, %cst] : memref<1x1xui16, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_rank0(
+  func.func @load_from_image_2D_rank0(%arg0: memref<f32, #spirv.storage_class<Image>>, %arg1: memref<f32, #spirv.storage_class<StorageBuffer>>) {
+    %cst = arith.constant 0 : index
+    // CHECK-NOT: spirv.Image
+    // CHECK-NOT: spirv.ImageFetch
+    %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<Image>>
+    memref.store %0, %arg1[] : memref<f32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_rank4(
+  func.func @load_from_image_2D_rank4(%arg0: memref<1x1x1x1xf32, #spirv.storage_class<Image>>, %arg1: memref<1x1x1x1xf32, #spirv.storage_class<StorageBuffer>>) {
+    %cst = arith.constant 0 : index
+    // CHECK-NOT: spirv.Image
+    // CHECK-NOT: spirv.ImageFetch
+    %0 = memref.load %arg0[%cst, %cst, %cst, %cst] : memref<1x1x1x1xf32, #spirv.storage_class<Image>>
+    memref.store %0, %arg1[%cst, %cst, %cst, %cst] : memref<1x1x1x1xf32, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+
+  // CHECK-LABEL: @load_from_image_2D_vector(
+  func.func @load_from_image_2D_vector(%arg0: memref<1xvector<1xf32>, #spirv.storage_class<Image>>, %arg1: memref<1xvector<1xf32>, #spirv.storage_class<StorageBuffer>>) {
+    %cst = arith.constant 0 : index
+    // CHECK-NOT: spirv.Image
+    // CHECK-NOT: spirv.ImageFetch
+    %0 = memref.load %arg0[%cst] : memref<1xvector<1xf32>, #spirv.storage_class<Image>>
+    memref.store %0, %arg1[%cst] : memref<1xvector<1xf32>, #spirv.storage_class<StorageBuffer>>
+    return
+  }
+}


        


More information about the Mlir-commits mailing list