[Mlir-commits] [mlir] [amdgpu][mlir] fold memref into global load async to lds (PR #195409)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 1 21:24:44 PDT 2026


https://github.com/efric created https://github.com/llvm/llvm-project/pull/195409

None

>From b54277170af298d36a6eec30251b7d95ff9854f6 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 1 May 2026 21:19:56 -0700
Subject: [PATCH] fold memref views into global load async to lds

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 .../AMDGPU/Transforms/FoldMemRefsOps.cpp      | 75 ++++++++++++++-----
 .../Dialect/AMDGPU/amdgpu-fold-memrefs.mlir   | 50 +++++++++++++
 2 files changed, 105 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index 8570927a77794..e9c429c70fa1a 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -65,39 +65,83 @@ static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
       });
 }
 
+/// Applies `foldMemrefViewOp` to both the source and the destination memref
+/// operands of `op`, handling each operand independently. When an operand
+/// cannot be folded, its base output is set to the original memref and its
+/// resolved indices to the original indices, so the caller can always rebuild
+/// `op` from the outputs. Returns failure (recording "no fold found" as the
+/// match-failure reason on `op`) only when neither operand was folded.
+static LogicalResult
+foldMemrefViewOps(PatternRewriter &rewriter, Operation *op, Value source,
+                  mlir::OperandRange sourceIndices, Value dest,
+                  mlir::OperandRange destIndices,
+                  SmallVectorImpl<Value> &resolvedSourceIndices,
+                  SmallVectorImpl<Value> &resolvedDestIndices,
+                  Value &sourceBase, Value &destBase) {
+  Location loc = op->getLoc();
+
+  LogicalResult didFoldSource =
+      foldMemrefViewOp(rewriter, loc, source, sourceIndices,
+                       resolvedSourceIndices, sourceBase, "source");
+  if (failed(didFoldSource)) {
+    sourceBase = source;
+    resolvedSourceIndices.assign(sourceIndices.begin(), sourceIndices.end());
+  }
+
+  LogicalResult didFoldDest =
+      foldMemrefViewOp(rewriter, loc, dest, destIndices, resolvedDestIndices,
+                       destBase, "destination");
+  if (failed(didFoldDest)) {
+    destBase = dest;
+    resolvedDestIndices.assign(destIndices.begin(), destIndices.end());
+  }
+
+  if (failed(didFoldSource) && failed(didFoldDest))
+    return rewriter.notifyMatchFailure(op, "no fold found");
+
+  return success();
+}
+
 struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
   using Base::Base;
   LogicalResult matchAndRewrite(GatherToLDSOp op,
                                 PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-
     SmallVector<Value> sourceIndices, destIndices;
     Value memrefSource, memrefDest;
 
-    auto foldSrcResult =
-        foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
-                         sourceIndices, memrefSource, "source");
+    if (failed(foldMemrefViewOps(rewriter, op, op.getSrc(), op.getSrcIndices(),
+                                 op.getDst(), op.getDstIndices(), sourceIndices,
+                                 destIndices, memrefSource, memrefDest)))
+      return failure();
 
-    if (failed(foldSrcResult)) {
-      memrefSource = op.getSrc();
-      sourceIndices = op.getSrcIndices();
-    }
+    rewriter.replaceOpWithNewOp<GatherToLDSOp>(
+        op, memrefSource, sourceIndices, memrefDest, destIndices,
+        op.getTransferType(), op.getAsync());
 
-    auto foldDstResult =
-        foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
-                         destIndices, memrefDest, "destination");
+    return success();
+  }
+};
 
-    if (failed(foldDstResult)) {
-      memrefDest = op.getDst();
-      destIndices = op.getDstIndices();
-    }
+/// Folds memref view ops on the source and/or destination of
+/// `amdgpu.global_load_async_to_lds` into the op itself via
+/// `foldMemrefViewOps`. The transfer type and the optional mask are forwarded
+/// unchanged to the rebuilt op.
+struct FoldMemRefOpsIntoGlobalLoadAsyncToLDSOp final
+    : OpRewritePattern<GlobalLoadAsyncToLDSOp> {
+  using Base::Base;
+  LogicalResult matchAndRewrite(GlobalLoadAsyncToLDSOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> sourceIndices, destIndices;
+    Value memrefSource, memrefDest;
 
-    if (failed(foldSrcResult) && failed(foldDstResult))
-      return rewriter.notifyMatchFailure(op, "no fold found");
+    if (failed(foldMemrefViewOps(rewriter, op, op.getSrc(), op.getSrcIndices(),
+                                 op.getDst(), op.getDstIndices(), sourceIndices,
+                                 destIndices, memrefSource, memrefDest)))
+      return failure();
 
-    rewriter.replaceOpWithNewOp<GatherToLDSOp>(
+    rewriter.replaceOpWithNewOp<GlobalLoadAsyncToLDSOp>(
         op, memrefSource, sourceIndices, memrefDest, destIndices,
-        op.getTransferType(), op.getAsync());
+        op.getTransferType(), op.getMask());
 
     return success();
   }
@@ -160,6 +204,7 @@ struct FoldMemRefOpsIntoTransposeLoadOp final
 void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit) {
   patterns.add<FoldMemRefOpsIntoGatherToLDSOp,
+               FoldMemRefOpsIntoGlobalLoadAsyncToLDSOp,
                FoldMemRefOpsIntoDmaBaseOp<MakeDmaBaseOp>,
                FoldMemRefOpsIntoDmaBaseOp<MakeGatherDmaBaseOp>,
                FoldMemRefOpsIntoTransposeLoadOp>(patterns.getContext(),
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
index 4fc6bc1846c3d..1274fe59f8be5 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -418,3 +418,70 @@ func.func @test_make_gather_dma_base_nop(%mem: memref<64x128xf16, #gpu_global_ad
     : memref<64x128xf16, #gpu_global_addrspace>, memref<64x64xf16, #gpu_lds_addrspace> -> !amdgpu.tdm_gather_base<f16, i16>
   func.return
 }
+
+// -----
+
+#gpu_lds_addrspace = #gpu.address_space<workgroup>
+#gpu_global_addrspace = #gpu.address_space<global>
+
+// CHECK: #[[GLOBAL_ASYNC_MAP:.*]] = affine_map<()[s0] -> (s0 + 32)>
+// CHECK: #[[GLOBAL_ASYNC_MAP1:.*]] = affine_map<()[s0] -> (s0 + 64)>
+
+// CHECK: func @test_global_load_async_to_lds_both_fold_masked
+// CHECK-SAME: %[[SRC:.*]]: memref<64x128xf32, #gpu.address_space<global>>, %[[LDS:.*]]: memref<4096xf32, #gpu.address_space<workgroup>>, %[[GLOBAL_I:.*]]: index, %[[GLOBAL_J:.*]]: index, %[[LDS_I:.*]]: index, %[[LDS_J:.*]]: index, %[[MASK:.*]]: i1
+func.func @test_global_load_async_to_lds_both_fold_masked(%src: memref<64x128xf32, #gpu_global_addrspace>, %lds: memref<4096xf32, #gpu_lds_addrspace>, %global_i: index, %global_j: index, %lds_i: index, %lds_j: index, %mask: i1) {
+  // CHECK: %[[GI:.*]] = affine.apply #[[GLOBAL_ASYNC_MAP]]()[%[[GLOBAL_I]]]
+  // CHECK: %[[GJ:.*]] = affine.apply #[[GLOBAL_ASYNC_MAP1]]()[%[[GLOBAL_J]]]
+  // CHECK: %[[LDS_IDX:.*]] = affine.linearize_index [%[[LDS_I]], %[[LDS_J]]] by (64, 64) : index
+  // CHECK: amdgpu.global_load_async_to_lds %[[SRC]][%[[GI]], %[[GJ]]], %[[LDS]][%[[LDS_IDX]]], %[[MASK]]
+  // CHECK-SAME: vector<4xf32>, memref<64x128xf32, #gpu.address_space<global>>, memref<4096xf32, #gpu.address_space<workgroup>>
+
+  %subview = memref.subview %src[32, 64][32, 64][1, 1]
+    : memref<64x128xf32, #gpu_global_addrspace>
+    to memref<32x64xf32, strided<[128, 1], offset: 4160>, #gpu_global_addrspace>
+  %expand_lds = memref.expand_shape %lds [[0, 1]] output_shape [64, 64]
+    : memref<4096xf32, #gpu_lds_addrspace>
+    into memref<64x64xf32, #gpu_lds_addrspace>
+  amdgpu.global_load_async_to_lds %subview[%global_i, %global_j], %expand_lds[%lds_i, %lds_j], %mask
+    : vector<4xf32>, memref<32x64xf32, strided<[128, 1], offset: 4160>, #gpu_global_addrspace>,
+      memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = #gpu.address_space<workgroup>
+#gpu_global_addrspace = #gpu.address_space<global>
+
+// CHECK: func @test_global_load_async_to_lds_no_mask_dst_collapse
+// CHECK-SAME: %[[SRC:.*]]: memref<8192xi32, #gpu.address_space<global>>, %[[LDS:.*]]: memref<64x64xi32, #gpu.address_space<workgroup>>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index
+func.func @test_global_load_async_to_lds_no_mask_dst_collapse(%src: memref<8192xi32, #gpu_global_addrspace>, %lds: memref<64x64xi32, #gpu_lds_addrspace>, %src_idx: index, %dst_idx: index) {
+  // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[DST_IDX]] into (64, 64) : index, index
+  // CHECK: amdgpu.global_load_async_to_lds %[[SRC]][%[[SRC_IDX]]], %[[LDS]][%[[INDICES]]#0, %[[INDICES]]#1] :
+  // CHECK-SAME: i32, memref<8192xi32, #gpu.address_space<global>>, memref<64x64xi32, #gpu.address_space<workgroup>>
+
+  %collapse_lds = memref.collapse_shape %lds [[0, 1]]
+    : memref<64x64xi32, #gpu_lds_addrspace>
+    into memref<4096xi32, #gpu_lds_addrspace>
+  amdgpu.global_load_async_to_lds %src[%src_idx], %collapse_lds[%dst_idx]
+    : i32, memref<8192xi32, #gpu_global_addrspace>,
+      memref<4096xi32, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = #gpu.address_space<workgroup>
+#gpu_global_addrspace = #gpu.address_space<global>
+
+// Neither operand is produced by a view op, so the pattern must not fire.
+// CHECK: func @test_global_load_async_to_lds_nop
+// CHECK-SAME: %[[SRC:.*]]: memref<8192xi32, #gpu.address_space<global>>, %[[LDS:.*]]: memref<4096xi32, #gpu.address_space<workgroup>>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index
+func.func @test_global_load_async_to_lds_nop(%src: memref<8192xi32, #gpu_global_addrspace>, %lds: memref<4096xi32, #gpu_lds_addrspace>, %src_idx: index, %dst_idx: index) {
+  // CHECK: amdgpu.global_load_async_to_lds %[[SRC]][%[[SRC_IDX]]], %[[LDS]][%[[DST_IDX]]] :
+  // CHECK-SAME: i32, memref<8192xi32, #gpu.address_space<global>>, memref<4096xi32, #gpu.address_space<workgroup>>
+  amdgpu.global_load_async_to_lds %src[%src_idx], %lds[%dst_idx]
+    : i32, memref<8192xi32, #gpu_global_addrspace>,
+      memref<4096xi32, #gpu_lds_addrspace>
+  func.return
+}



More information about the Mlir-commits mailing list