[Mlir-commits] [mlir] [amdgpu][mlir] fold memref into global load async to lds (PR #195409)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 1 21:24:44 PDT 2026
https://github.com/efric created https://github.com/llvm/llvm-project/pull/195409
None
From b54277170af298d36a6eec30251b7d95ff9854f6 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 1 May 2026 21:19:56 -0700
Subject: [PATCH] fold memref global async to lds
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
.../AMDGPU/Transforms/FoldMemRefsOps.cpp | 75 ++++++++++++++-----
.../Dialect/AMDGPU/amdgpu-fold-memrefs.mlir | 50 +++++++++++++
2 files changed, 105 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index 8570927a77794..e9c429c70fa1a 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -65,39 +65,73 @@ static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
});
}
+static LogicalResult
+foldMemrefViewOps(PatternRewriter &rewriter, Operation *op, Value source,
+ mlir::OperandRange sourceIndices, Value dest,
+ mlir::OperandRange destIndices,
+ SmallVectorImpl<Value> &resolvedSourceIndices,
+ SmallVectorImpl<Value> &resolvedDestIndices,
+ Value &sourceBase, Value &destBase) {
+ Location loc = op->getLoc();
+
+ LogicalResult didFoldSource =
+ foldMemrefViewOp(rewriter, loc, source, sourceIndices,
+ resolvedSourceIndices, sourceBase, "source");
+ if (failed(didFoldSource)) {
+ sourceBase = source;
+ resolvedSourceIndices.assign(sourceIndices.begin(), sourceIndices.end());
+ }
+
+ LogicalResult didFoldDest =
+ foldMemrefViewOp(rewriter, loc, dest, destIndices, resolvedDestIndices,
+ destBase, "destination");
+ if (failed(didFoldDest)) {
+ destBase = dest;
+ resolvedDestIndices.assign(destIndices.begin(), destIndices.end());
+ }
+
+ if (failed(didFoldSource) && failed(didFoldDest))
+ return rewriter.notifyMatchFailure(op, "no fold found");
+
+ return success();
+}
+
struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
using Base::Base;
LogicalResult matchAndRewrite(GatherToLDSOp op,
PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
-
SmallVector<Value> sourceIndices, destIndices;
Value memrefSource, memrefDest;
- auto foldSrcResult =
- foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
- sourceIndices, memrefSource, "source");
+ if (failed(foldMemrefViewOps(rewriter, op, op.getSrc(), op.getSrcIndices(),
+ op.getDst(), op.getDstIndices(), sourceIndices,
+ destIndices, memrefSource, memrefDest)))
+ return failure();
- if (failed(foldSrcResult)) {
- memrefSource = op.getSrc();
- sourceIndices = op.getSrcIndices();
- }
+ rewriter.replaceOpWithNewOp<GatherToLDSOp>(
+ op, memrefSource, sourceIndices, memrefDest, destIndices,
+ op.getTransferType(), op.getAsync());
- auto foldDstResult =
- foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
- destIndices, memrefDest, "destination");
+ return success();
+ }
+};
- if (failed(foldDstResult)) {
- memrefDest = op.getDst();
- destIndices = op.getDstIndices();
- }
+struct FoldMemRefOpsIntoGlobalLoadAsyncToLDSOp final
+ : OpRewritePattern<GlobalLoadAsyncToLDSOp> {
+ using Base::Base;
+ LogicalResult matchAndRewrite(GlobalLoadAsyncToLDSOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value> sourceIndices, destIndices;
+ Value memrefSource, memrefDest;
- if (failed(foldSrcResult) && failed(foldDstResult))
- return rewriter.notifyMatchFailure(op, "no fold found");
+ if (failed(foldMemrefViewOps(rewriter, op, op.getSrc(), op.getSrcIndices(),
+ op.getDst(), op.getDstIndices(), sourceIndices,
+ destIndices, memrefSource, memrefDest)))
+ return failure();
- rewriter.replaceOpWithNewOp<GatherToLDSOp>(
+ rewriter.replaceOpWithNewOp<GlobalLoadAsyncToLDSOp>(
op, memrefSource, sourceIndices, memrefDest, destIndices,
- op.getTransferType(), op.getAsync());
+ op.getTransferType(), op.getMask());
return success();
}
@@ -160,6 +194,7 @@ struct FoldMemRefOpsIntoTransposeLoadOp final
void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<FoldMemRefOpsIntoGatherToLDSOp,
+ FoldMemRefOpsIntoGlobalLoadAsyncToLDSOp,
FoldMemRefOpsIntoDmaBaseOp<MakeDmaBaseOp>,
FoldMemRefOpsIntoDmaBaseOp<MakeGatherDmaBaseOp>,
FoldMemRefOpsIntoTransposeLoadOp>(patterns.getContext(),
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
index 4fc6bc1846c3d..1274fe59f8be5 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -418,3 +418,53 @@ func.func @test_make_gather_dma_base_nop(%mem: memref<64x128xf16, #gpu_global_ad
: memref<64x128xf16, #gpu_global_addrspace>, memref<64x64xf16, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<f16, i16>
func.return
}
+
+// -----
+
+#gpu_lds_addrspace = #gpu.address_space<workgroup>
+#gpu_global_addrspace = #gpu.address_space<global>
+
+// CHECK: #[[GLOBAL_ASYNC_MAP:.*]] = affine_map<()[s0] -> (s0 + 32)>
+// CHECK: #[[GLOBAL_ASYNC_MAP1:.*]] = affine_map<()[s0] -> (s0 + 64)>
+
+// CHECK: func @test_global_load_async_to_lds_both_fold_masked
+// CHECK-SAME: %[[SRC:.*]]: memref<64x128xf32, #gpu.address_space<global>>, %[[LDS:.*]]: memref<4096xf32, #gpu.address_space<workgroup>>, %[[GLOBAL_I:.*]]: index, %[[GLOBAL_J:.*]]: index, %[[LDS_I:.*]]: index, %[[LDS_J:.*]]: index, %[[MASK:.*]]: i1
+func.func @test_global_load_async_to_lds_both_fold_masked(%src: memref<64x128xf32, #gpu_global_addrspace>, %lds: memref<4096xf32, #gpu_lds_addrspace>, %global_i: index, %global_j: index, %lds_i: index, %lds_j: index, %mask: i1) {
+ // CHECK: %[[GI:.*]] = affine.apply #[[GLOBAL_ASYNC_MAP]]()[%[[GLOBAL_I]]]
+ // CHECK: %[[GJ:.*]] = affine.apply #[[GLOBAL_ASYNC_MAP1]]()[%[[GLOBAL_J]]]
+ // CHECK: %[[LDS_IDX:.*]] = affine.linearize_index [%[[LDS_I]], %[[LDS_J]]] by (64, 64) : index
+ // CHECK: amdgpu.global_load_async_to_lds %[[SRC]][%[[GI]], %[[GJ]]], %[[LDS]][%[[LDS_IDX]]], %[[MASK]]
+ // CHECK-SAME: vector<4xf32>, memref<64x128xf32, #gpu.address_space<global>>, memref<4096xf32, #gpu.address_space<workgroup>>
+
+ %subview = memref.subview %src[32, 64][32, 64][1, 1]
+ : memref<64x128xf32, #gpu_global_addrspace>
+ to memref<32x64xf32, strided<[128, 1], offset: 4160>, #gpu_global_addrspace>
+ %expand_lds = memref.expand_shape %lds [[0, 1]] output_shape [64, 64]
+ : memref<4096xf32, #gpu_lds_addrspace>
+ into memref<64x64xf32, #gpu_lds_addrspace>
+ amdgpu.global_load_async_to_lds %subview[%global_i, %global_j], %expand_lds[%lds_i, %lds_j], %mask
+ : vector<4xf32>, memref<32x64xf32, strided<[128, 1], offset: 4160>, #gpu_global_addrspace>,
+ memref<64x64xf32, #gpu_lds_addrspace>
+ func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = #gpu.address_space<workgroup>
+#gpu_global_addrspace = #gpu.address_space<global>
+
+// CHECK: func @test_global_load_async_to_lds_no_mask_dst_collapse
+// CHECK-SAME: %[[SRC:.*]]: memref<8192xi32, #gpu.address_space<global>>, %[[LDS:.*]]: memref<64x64xi32, #gpu.address_space<workgroup>>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index
+func.func @test_global_load_async_to_lds_no_mask_dst_collapse(%src: memref<8192xi32, #gpu_global_addrspace>, %lds: memref<64x64xi32, #gpu_lds_addrspace>, %src_idx: index, %dst_idx: index) {
+ // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[DST_IDX]] into (64, 64) : index, index
+ // CHECK: amdgpu.global_load_async_to_lds %[[SRC]][%[[SRC_IDX]]], %[[LDS]][%[[INDICES]]#0, %[[INDICES]]#1] :
+ // CHECK-SAME: i32, memref<8192xi32, #gpu.address_space<global>>, memref<64x64xi32, #gpu.address_space<workgroup>>
+
+ %collapse_lds = memref.collapse_shape %lds [[0, 1]]
+ : memref<64x64xi32, #gpu_lds_addrspace>
+ into memref<4096xi32, #gpu_lds_addrspace>
+ amdgpu.global_load_async_to_lds %src[%src_idx], %collapse_lds[%dst_idx]
+ : i32, memref<8192xi32, #gpu_global_addrspace>,
+ memref<4096xi32, #gpu_lds_addrspace>
+ func.return
+}
More information about the Mlir-commits
mailing list