[Mlir-commits] [mlir] [mlir][memref][spirv] Add SPIR-V Image Lowering (PR #150978)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jul 30 05:09:54 PDT 2025
================
@@ -661,6 +694,86 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
return success();
}
+LogicalResult
+ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
+
+ auto memorySpaceAttr =
+ dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+ if (!memorySpaceAttr)
+ return rewriter.notifyMatchFailure(
+ loadOp, "missing memory space SPIR-V storage class attribute");
+
+ if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to lower memref in non-image storage class to image");
+
+ Value loadPtr = adaptor.getMemref();
+ auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
+ if (failed(memoryRequirements))
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to determine memory requirements");
+
+ const auto [memoryAccess, alignment] = *memoryRequirements;
+
+ if (!loadOp.getMemRefType().hasRank())
+ return rewriter.notifyMatchFailure(
+ loadOp, "cannot lower unranked memrefs to SPIR-V images");
+
+ // We currently only support lowering of scalar memref elements to texels in
+ // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
+ // elements to texels in richer formats.
+ if (!loadOp.getMemRefType().getElementType().isIntOrFloat())
+ return rewriter.notifyMatchFailure(
+ loadOp, "cannot lower memrefs who's element type is not int or float "
+ "to SPIR-V images");
+
+ // We currently only support sampled images since OpImageFetch does not work
+ // for plain images and the OpImageRead instruction needs to be materialized
+ // instead or texels need to be accessed via atomics through a texel pointer.
+ // Future work will generalize support to plain images.
+ auto convertedPointeeType = cast<spirv::PointerType>(
+ getTypeConverter()->convertType(loadOp.getMemRefType()));
+ if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
+ return rewriter.notifyMatchFailure(loadOp,
+ "cannot lower memrefs which do not "
+ "convert to SPIR-V sampled images");
+
+ // Materialize the lowering.
+ auto imageLoadOp = spirv::LoadOp::create(rewriter, loadOp->getLoc(), loadPtr,
+ memoryAccess, alignment);
+ // Extract the image from the sampled image.
+ auto imageOp =
+ spirv::ImageOp::create(rewriter, loadOp->getLoc(), imageLoadOp);
+
+ // Build a vector of coordinates or just a scalar index if we have a 1D image.
+ Value coords;
+ if (memrefType.getRank() != 1) {
+ const auto coordVectorType = VectorType::get(
+ {loadOp.getMemRefType().getRank()}, adaptor.getIndices().getType()[0]);
+ coords = spirv::CompositeConstructOp::create(
+ rewriter, loadOp->getLoc(), coordVectorType, adaptor.getIndices());
----------------
kuhar wrote:
Since we query the load op's location so frequently, maybe hoist it into a local variable, e.g. `Location loc = loadOp.getLoc();`?
https://github.com/llvm/llvm-project/pull/150978
More information about the Mlir-commits
mailing list