[Mlir-commits] [mlir] [AMDGPU] fold `memref.subview/expand_shape/collapse_shape` into `amdgpu.gather_to_lds` (PR #149851)

Alan Li llvmlistbot at llvm.org
Wed Jul 23 07:22:08 PDT 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/149851

>From 9f6afe18bceeca2b9d6e26368be2e06bbaf870a9 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 21 Jul 2025 16:33:54 +0000
Subject: [PATCH 1/9] [AMDGPU] fold memref.subview into amdgpu.gather_to_lds

---
 .../mlir/Dialect/AMDGPU/Transforms/Passes.h   |  6 +-
 .../mlir/Dialect/AMDGPU/Transforms/Passes.td  | 12 ++++
 .../Dialect/AMDGPU/Transforms/CMakeLists.txt  |  3 +-
 .../AMDGPU/Transforms/FoldSubviewOps.cpp      | 65 +++++++++++++++++++
 .../Dialect/AMDGPU/amdgpu-fold-subviews.mlir  | 50 ++++++++++++++
 5 files changed, 134 insertions(+), 2 deletions(-)
 create mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
 create mode 100644 mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index cc2f543e79f69..a61903609aaff 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -22,8 +22,9 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
-#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
+#define GEN_PASS_DECL_AMDGPUFOLDSUBVIEWOPSPASS
 #define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
+#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -38,6 +39,9 @@ void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
 void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);
 
+void populateAmdgpuFoldSubviewOpsPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
 } // namespace amdgpu
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 8d0e6829ab0cc..7529511b0ea76 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -70,4 +70,16 @@ def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
     "memref::MemRefDialect"
   ];
 }
+
+def AmdgpuFoldSubviewOpsPass : Pass<"amdgpu-fold-subview-ops"> {
+  let summary = "Fold subview operations into their parent operations";
+  let description = [{
+    This pass identifies `memref.subview` source of `GatherToLDSOp` and
+    attempts to fold the source op, potentially simplifying the overall
+    operation and improving performance.
+  }];
+  let dependentDialects = [
+    "memref::MemRefDialect"
+  ];
+}
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 17bbe54ea6c0c..20621ec0d55a4 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,8 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
-  ResolveStridedMetadata.cpp
+  FoldSubviewOps.cpp
   MaskedloadToLoad.cpp
+  ResolveStridedMetadata.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
new file mode 100644
index 0000000000000..a962f7a2526b2
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
@@ -0,0 +1,65 @@
+//===- FoldSubviewOps.cpp - AMDGPU fold subview ops ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUFOLDSUBVIEWOPSPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace mlir::amdgpu
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+namespace {
+struct AmdgpuFoldSubviewOpsPass
+    : public amdgpu::impl::AmdgpuFoldSubviewOpsPassBase<
+          AmdgpuFoldSubviewOpsPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAmdgpuFoldSubviewOpsPatterns(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+struct FoldSubviewIntoGatherToLDSOp : public OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherToLDSOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    // Check if the source is a subview operation:
+    auto subviewOp = dyn_cast<memref::SubViewOp>(op.getSrc().getDefiningOp());
+    if (!subviewOp)
+      return rewriter.notifyMatchFailure(
+          loc, "GatherToLDSOp can only be folded if the source is a SubviewOp");
+
+    SmallVector<Value> sourceIndices;
+    mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, loc, subviewOp.getMixedOffsets(), subviewOp.getMixedStrides(),
+        subviewOp.getDroppedDims(), op.getSrcIndices(), sourceIndices);
+
+    rewriter.replaceOpWithNewOp<GatherToLDSOp>(
+        op, subviewOp.getSource(), sourceIndices, op.getDst(),
+        op.getDstIndices(), op.getTransferType());
+
+    return success();
+  }
+};
+} // namespace
+
+void mlir::amdgpu::populateAmdgpuFoldSubviewOpsPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<FoldSubviewIntoGatherToLDSOp>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir
new file mode 100644
index 0000000000000..d582991c3622f
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -amdgpu-fold-subview-ops -split-input-file %s | FileCheck %s
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_memref
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_memref(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+  // CHECK:  %[[MEM]][%arg0, %arg1], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<64x128xf16>
+  %subview = memref.subview %mem[0, 0][32, 64][1, 1] : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1]>>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %subview[%offset_i, %offset_j], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<32x64xf16, strided<[128, 1]>>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 + 32)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 64)>
+
+// CHECK: func @subview_folding_offset
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+
+  // CHECK: %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+  // CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
+
+  // CHECK:  %[[MEM]][%[[IDX0]], %[[IDX1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<64x128xf16>
+  %subview = memref.subview %mem[32, 64][32, 64][1, 1] : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1], offset: 4160>>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %subview[%offset_i, %offset_j], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<32x64xf16, strided<[128, 1], offset: 4160>>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}

>From 71fe3aa49154184123546c40c72d695680be7133 Mon Sep 17 00:00:00 2001
From: Alan Li <alan.li at me.com>
Date: Mon, 21 Jul 2025 14:21:05 -0400
Subject: [PATCH 2/9] Update
 mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp

Co-authored-by: Copilot <175728472+Copilot at users.noreply.github.com>
---
 mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
index a962f7a2526b2..7b81800f07ab2 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
@@ -43,7 +43,7 @@ struct FoldSubviewIntoGatherToLDSOp : public OpRewritePattern<GatherToLDSOp> {
     auto subviewOp = dyn_cast<memref::SubViewOp>(op.getSrc().getDefiningOp());
     if (!subviewOp)
       return rewriter.notifyMatchFailure(
-          loc, "GatherToLDSOp can only be folded if the source is a SubviewOp");
+          loc, "GatherToLDSOp folding is currently supported only when the source is a SubviewOp. This is one specific pattern, and other scenarios may be added in the future.");
 
     SmallVector<Value> sourceIndices;
     mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(

>From bd4ade5466db59f84e88dc62773c38a40bb05c77 Mon Sep 17 00:00:00 2001
From: Alan Li <alan.li at me.com>
Date: Mon, 21 Jul 2025 14:21:15 -0400
Subject: [PATCH 3/9] Update
 mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Co-authored-by: Copilot <175728472+Copilot at users.noreply.github.com>
---
 mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 7529511b0ea76..fad939ced9877 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -74,8 +74,8 @@ def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
 def AmdgpuFoldSubviewOpsPass : Pass<"amdgpu-fold-subview-ops"> {
   let summary = "Fold subview operations into their parent operations";
   let description = [{
-    This pass identifies `memref.subview` source of `GatherToLDSOp` and
-    attempts to fold the source op, potentially simplifying the overall
+    This pass identifies `memref.subview` sources of `GatherToLDSOp` and
+    attempts to fold the source ops, potentially simplifying the overall
     operation and improving performance.
   }];
   let dependentDialects = [

>From 9552f4ed9b2857c79fedb2faab32cdaddd8dfda1 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 21 Jul 2025 14:49:21 -0400
Subject: [PATCH 4/9] linting

---
 mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
index 7b81800f07ab2..adbdf4b856bd5 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
@@ -43,7 +43,9 @@ struct FoldSubviewIntoGatherToLDSOp : public OpRewritePattern<GatherToLDSOp> {
     auto subviewOp = dyn_cast<memref::SubViewOp>(op.getSrc().getDefiningOp());
     if (!subviewOp)
       return rewriter.notifyMatchFailure(
-          loc, "GatherToLDSOp folding is currently supported only when the source is a SubviewOp. This is one specific pattern, and other scenarios may be added in the future.");
+          loc, "GatherToLDSOp folding is currently supported only when the "
+               "source is a SubviewOp. This is one specific pattern, and other "
+               "scenarios may be added in the future.");
 
     SmallVector<Value> sourceIndices;
     mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(

>From 5d6483db146deb10fc8a769ab613cb6def3ca083 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 21 Jul 2025 19:59:35 +0000
Subject: [PATCH 5/9] updating tests

---
 mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir
index d582991c3622f..a0f02a9bc9340 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir
@@ -8,7 +8,7 @@ func.func @test_memref(%offset_i: index, %offset_j: index) {
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
-  // CHECK:  %[[MEM]][%arg0, %arg1], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK:  amdgpu.gather_to_lds %[[MEM]][%arg0, %arg1], %[[LOCAL]][%[[C0]], %[[C0]]]
   // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
 
   %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
@@ -37,7 +37,7 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
   // CHECK: %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[ARG0]]]
   // CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
 
-  // CHECK:  %[[MEM]][%[[IDX0]], %[[IDX1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK:  amdgpu.gather_to_lds %[[MEM]][%[[IDX0]], %[[IDX1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
   // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
 
   %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>

>From cf50f5f16b797da30f90f843ac42b3aac2ff4f9c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 22 Jul 2025 23:41:48 -0400
Subject: [PATCH 6/9] Support Expandshape and collapse shape.

---
 .../mlir/Dialect/MemRef/Utils/MemRefUtils.h   | 37 ++++++++
 .../AMDGPU/Transforms/FoldSubviewOps.cpp      | 50 ++++++----
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp  | 91 -------------------
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 66 ++++++++++++++
 4 files changed, 137 insertions(+), 107 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 34ad279a07a8b..dd3b3dea6ef26 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -116,6 +116,43 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
 /// the source memref (i.e. implements ViewLikeOpInterface).
 MemrefValue skipViewLikeOps(MemrefValue source);
 
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of an expand_shape op, returns the indices w.r.t. the source memref of the
+/// expand_shape op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = memref.expand_shape %0 [[0, 1], [2]]
+///    : memref<12x42xf32> into memref<2x6x42xf32>
+/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[6 * %i1 + %i2, %i3] :
+///          memref<12x42xf32>
+LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds);
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a collapse_shape op, returns the indices w.r.t. the source memref of
+/// the collapse_shape op. For example
+///
+/// %0 = ... : memref<2x6x42xf32>
+/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
+///    : memref<2x6x42xf32> into memref<12x42xf32>
+/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
+///          memref<2x6x42xf32>
+LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices);
+
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
index adbdf4b856bd5..f005842d83306 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
@@ -11,7 +11,9 @@
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace mlir::amdgpu {
 #define GEN_PASS_DEF_AMDGPUFOLDSUBVIEWOPSPASS
@@ -33,28 +35,44 @@ struct AmdgpuFoldSubviewOpsPass
   }
 };
 
-struct FoldSubviewIntoGatherToLDSOp : public OpRewritePattern<GatherToLDSOp> {
-  using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
+struct FoldSubviewIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherToLDSOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
-    // Check if the source is a subview operation:
-    auto subviewOp = dyn_cast<memref::SubViewOp>(op.getSrc().getDefiningOp());
-    if (!subviewOp)
-      return rewriter.notifyMatchFailure(
-          loc, "GatherToLDSOp folding is currently supported only when the "
-               "source is a SubviewOp. This is one specific pattern, and other "
-               "scenarios may be added in the future.");
-
+    Value memrefSource;
     SmallVector<Value> sourceIndices;
-    mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-        rewriter, loc, subviewOp.getMixedOffsets(), subviewOp.getMixedStrides(),
-        subviewOp.getDroppedDims(), op.getSrcIndices(), sourceIndices);
+    llvm::TypeSwitch<Operation *>(op.getSrc().getDefiningOp())
+        .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+          // If the source is a SubViewOp, we can directly rewrite the
+          // GatherToLDSOp.
+          mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+              rewriter, loc, subviewOp.getMixedOffsets(),
+              subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+              op.getSrcIndices(), sourceIndices);
+          memrefSource = subviewOp.getSource();
+        })
+        .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+          mlir::memref::resolveSourceIndicesExpandShape(
+              loc, rewriter, expandShapeOp, op.getSrcIndices(), sourceIndices,
+              false);
+          memrefSource = expandShapeOp.getViewSource();
+        })
+        .Case<memref::CollapseShapeOp>(
+            [&](memref::CollapseShapeOp collapseShapeOp) {
+              mlir::memref::resolveSourceIndicesCollapseShape(
+                  loc, rewriter, collapseShapeOp, op.getSrcIndices(),
+                  sourceIndices);
+              memrefSource = collapseShapeOp.getViewSource();
+            });
+
+    if (!memrefSource)
+      return failure();
 
-    rewriter.replaceOpWithNewOp<GatherToLDSOp>(
-        op, subviewOp.getSource(), sourceIndices, op.getDst(),
-        op.getDstIndices(), op.getTransferType());
+    rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
+                                               op.getDst(), op.getDstIndices(),
+                                               op.getTransferType());
 
     return success();
   }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 89be188af9129..24da447ad7685 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -44,97 +44,6 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a expand_shape op, returns the indices w.r.t to the source memref of the
-/// expand_shape op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = memref.expand_shape %0 [[0, 1], [2]]
-///    : memref<12x42xf32> into memref<2x6x42xf32>
-/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
-///
-/// could be folded into
-///
-/// %2 = load %0[6 * i1 + i2, %i3] :
-///          memref<12x42xf32>
-static LogicalResult resolveSourceIndicesExpandShape(
-    Location loc, PatternRewriter &rewriter,
-    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
-    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
-  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
-
-  // Traverse all reassociation groups to determine the appropriate indices
-  // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-    if (groupSize == 1) {
-      sourceIndices.push_back(indices[group[0]]);
-      continue;
-    }
-    SmallVector<OpFoldResult> groupBasis =
-        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
-    SmallVector<Value> groupIndices =
-        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
-    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
-        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
-    sourceIndices.push_back(collapsedIndex);
-  }
-  return success();
-}
-
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a collapse_shape op, returns the indices w.r.t to the source memref of
-/// the collapse_shape op. For example
-///
-/// %0 = ... : memref<2x6x42xf32>
-/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
-///    : memref<2x6x42xf32> into memref<12x42xf32>
-/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
-///
-/// could be folded into
-///
-/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
-///          memref<2x6x42xf32>
-static LogicalResult
-resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
-                                  memref::CollapseShapeOp collapseShapeOp,
-                                  ValueRange indices,
-                                  SmallVectorImpl<Value> &sourceIndices) {
-  // Note: collapse_shape requires a strided memref, we can do this.
-  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-      loc, collapseShapeOp.getSrc());
-  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
-  for (auto [index, group] :
-       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-
-    if (groupSize == 1) {
-      sourceIndices.push_back(index);
-      continue;
-    }
-
-    SmallVector<OpFoldResult> basis =
-        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
-    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        loc, index, basis, /*hasOuterBound=*/true);
-    llvm::append_range(sourceIndices, delinearize.getResults());
-  }
-  if (collapseShapeOp.getReassociationIndices().empty()) {
-    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
-    int64_t srcRank =
-        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
-    for (int64_t i = 0; i < srcRank; i++) {
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
-    }
-  }
-  return success();
-}
-
 /// Helpers to access the memref operand for each op.
 template <typename LoadOrStoreOpTy>
 static Value getMemRefOperand(LoadOrStoreOpTy op) {
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index a50b4cfc74708..97fe3cb5b4705 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
@@ -217,5 +218,70 @@ MemrefValue skipViewLikeOps(MemrefValue source) {
   return source;
 }
 
+LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
+
+  // Traverse all reassociation groups to determine the appropriate indices
+  // corresponding to each one of them post op folding.
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
+    }
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
+  }
+  return success();
+}
+
+LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices) {
+  // Note: collapse_shape requires a strided memref, we can do this.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
+    }
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
+  }
+  if (collapseShapeOp.getReassociationIndices().empty()) {
+    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
+    int64_t srcRank =
+        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
+    for (int64_t i = 0; i < srcRank; i++) {
+      sourceIndices.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    }
+  }
+  return success();
+}
+
 } // namespace memref
 } // namespace mlir

>From 3db555db0fa24b54662bf449609479bba8933ea1 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 22 Jul 2025 23:52:39 -0400
Subject: [PATCH 7/9] update tests.

---
 .../AMDGPU/Transforms/FoldSubviewOps.cpp      | 76 +++++++++++--------
 .../Dialect/AMDGPU/amdgpu-fold-subviews.mlir  | 52 ++++++++++++-
 2 files changed, 94 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
index f005842d83306..95ba0a76d4510 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
@@ -18,12 +18,7 @@
 namespace mlir::amdgpu {
 #define GEN_PASS_DEF_AMDGPUFOLDSUBVIEWOPSPASS
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
-} // namespace mlir::amdgpu
-
-using namespace mlir;
-using namespace mlir::amdgpu;
 
-namespace {
 struct AmdgpuFoldSubviewOpsPass
     : public amdgpu::impl::AmdgpuFoldSubviewOpsPassBase<
           AmdgpuFoldSubviewOpsPass> {
@@ -43,32 +38,51 @@ struct FoldSubviewIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
 
     Value memrefSource;
     SmallVector<Value> sourceIndices;
-    llvm::TypeSwitch<Operation *>(op.getSrc().getDefiningOp())
-        .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
-          // If the source is a SubViewOp, we can directly rewrite the
-          // GatherToLDSOp.
-          mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-              rewriter, loc, subviewOp.getMixedOffsets(),
-              subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
-              op.getSrcIndices(), sourceIndices);
-          memrefSource = subviewOp.getSource();
-        })
-        .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
-          mlir::memref::resolveSourceIndicesExpandShape(
-              loc, rewriter, expandShapeOp, op.getSrcIndices(), sourceIndices,
-              false);
-          memrefSource = expandShapeOp.getViewSource();
-        })
-        .Case<memref::CollapseShapeOp>(
-            [&](memref::CollapseShapeOp collapseShapeOp) {
-              mlir::memref::resolveSourceIndicesCollapseShape(
-                  loc, rewriter, collapseShapeOp, op.getSrcIndices(),
-                  sourceIndices);
-              memrefSource = collapseShapeOp.getViewSource();
+    Operation *srcOp = op.getSrc().getDefiningOp();
+    if (!srcOp)
+      return rewriter.notifyMatchFailure(op, "source is a block argument");
+    auto foldResult =
+        llvm::TypeSwitch<Operation *, LogicalResult>(srcOp)
+            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+              // If the source is a SubViewOp, we can directly rewrite the
+              // GatherToLDSOp.
+              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+                  rewriter, loc, subviewOp.getMixedOffsets(),
+                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+                  op.getSrcIndices(), sourceIndices);
+              memrefSource = subviewOp.getSource();
+              return success();
+            })
+            .Case<memref::ExpandShapeOp>(
+                [&](memref::ExpandShapeOp expandShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
+                          sourceIndices, /*startsInbounds=*/false)))
+                    return failure();
+                  memrefSource = expandShapeOp.getViewSource();
+                  return success();
+                })
+            .Case<memref::CollapseShapeOp>(
+                [&](memref::CollapseShapeOp collapseShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
+                          sourceIndices)))
+                    return failure();
+                  memrefSource = collapseShapeOp.getViewSource();
+                  return success();
+                })
+            .Default([&](Operation *) {
+              // If the source is not a SubViewOp, ExpandShapeOp, or
+              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
+              return rewriter.notifyMatchFailure(
+                  op,
+                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
+                  "CollapseShapeOp");
             });
 
-    if (!memrefSource)
+    if (failed(foldResult)) {
       return failure();
+    }
 
     rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
                                                op.getDst(), op.getDstIndices(),
@@ -77,9 +91,9 @@ struct FoldSubviewIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
     return success();
   }
 };
-} // namespace
 
-void mlir::amdgpu::populateAmdgpuFoldSubviewOpsPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+void populateAmdgpuFoldSubviewOpsPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit) {
   patterns.add<FoldSubviewIntoGatherToLDSOp>(patterns.getContext(), benefit);
 }
+} // namespace mlir::amdgpu
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir
index a0f02a9bc9340..2c1b1a652fe1e 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt -amdgpu-fold-subview-ops -split-input-file %s | FileCheck %s
+// RUN: mlir-opt --amdgpu-fold-subview-ops --split-input-file %s | FileCheck %s
 
 #gpu_lds_addrspace = 3
 
-// CHECK: func @test_memref
+// CHECK: func @test_subview_folding
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
-func.func @test_memref(%offset_i: index, %offset_j: index) {
+func.func @test_subview_folding(%offset_i: index, %offset_j: index) {
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
@@ -48,3 +48,49 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
     : vector<8xf16>, memref<32x64xf16, strided<[128, 1], offset: 4160>>, memref<64x64xf16, #gpu_lds_addrspace>
   func.return
 }
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_expand_shape
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
+  // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+  // CHECK:  amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<8192xf16>
+  %expand = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %expand[%offset_i, %offset_j], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_collapse_shape
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_collapse_shape(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+  // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
+  // CHECK:  amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<64x128xf16>
+  %collapse = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %collapse[%offset_i], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}

>From d6746b955d060549a2fb91105f3422a2f9996b03 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 23 Jul 2025 08:36:32 -0400
Subject: [PATCH 8/9] Rename and update

---
 mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td  | 10 +++++-----
 mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt      |  2 +-
 .../{FoldSubviewOps.cpp => FoldMemRefsOps.cpp}         | 10 +++++-----
 ...gpu-fold-subviews.mlir => amdgpu-fold-memrefs.mlir} |  2 +-
 4 files changed, 12 insertions(+), 12 deletions(-)
 rename mlir/lib/Dialect/AMDGPU/Transforms/{FoldSubviewOps.cpp => FoldMemRefsOps.cpp} (94%)
 rename mlir/test/Dialect/AMDGPU/{amdgpu-fold-subviews.mlir => amdgpu-fold-memrefs.mlir} (98%)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index fad939ced9877..76b8c825ac272 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -71,12 +71,12 @@ def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
   ];
 }
 
-def AmdgpuFoldSubviewOpsPass : Pass<"amdgpu-fold-subview-ops"> {
-  let summary = "Fold subview operations into their parent operations";
+def AmdgpuFoldSubviewOpsPass : Pass<"amdgpu-fold-memrefs-ops"> {
+  let summary = "Fold memref operations into their parent operations";
   let description = [{
-    This pass identifies `memref.subview` sources of `GatherToLDSOp` and
-    attempts to fold the source ops, potentially simplifying the overall
-    operation and improving performance.
+    This pass identifies memref operations (subview, expand_shape, collapse_shape)
+    that are sources of `GatherToLDSOp` and attempts to fold the source ops,
+    potentially simplifying the overall operation and improving performance.
   }];
   let dependentDialects = [
     "memref::MemRefDialect"
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 20621ec0d55a4..3b0c072ed1217 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
-  FoldSubviewOps.cpp
+  FoldMemRefsOps.cpp
   MaskedloadToLoad.cpp
   ResolveStridedMetadata.cpp
 
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
similarity index 94%
rename from mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
rename to mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index 95ba0a76d4510..73923af1329db 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -1,4 +1,4 @@
-//===- FoldSubviewOps.cpp - AMDGPU fold subview ops ---------------------===//
+//===- FoldMemRefsOps.cpp - AMDGPU fold memref ops ------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -19,9 +19,9 @@ namespace mlir::amdgpu {
 #define GEN_PASS_DEF_AMDGPUFOLDSUBVIEWOPSPASS
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
-struct AmdgpuFoldSubviewOpsPass
+struct AmdgpuFoldMemRefOpsPass
     : public amdgpu::impl::AmdgpuFoldSubviewOpsPassBase<
-          AmdgpuFoldSubviewOpsPass> {
+          AmdgpuFoldMemRefOpsPass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateAmdgpuFoldSubviewOpsPatterns(patterns);
@@ -30,7 +30,7 @@ struct AmdgpuFoldSubviewOpsPass
   }
 };
 
-struct FoldSubviewIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
+struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherToLDSOp op,
                                 PatternRewriter &rewriter) const override {
@@ -94,6 +94,6 @@ struct FoldSubviewIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
 
 void populateAmdgpuFoldSubviewOpsPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit) {
-  patterns.add<FoldSubviewIntoGatherToLDSOp>(patterns.getContext(), benefit);
+  patterns.add<FoldMemRefOpsIntoGatherToLDSOp>(patterns.getContext(), benefit);
 }
 } // namespace mlir::amdgpu
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
similarity index 98%
rename from mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir
rename to mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
index 2c1b1a652fe1e..a751a4ac1158e 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --amdgpu-fold-subview-ops --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --amdgpu-fold-memrefs-ops --split-input-file %s | FileCheck %s
 
 #gpu_lds_addrspace = 3
 

>From 2d653d05755e960a0f9345cf73530256d8a4bc23 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 23 Jul 2025 10:20:03 -0400
Subject: [PATCH 9/9] Fix according to comments

---
 .../mlir/Dialect/AMDGPU/Transforms/Passes.h    |  6 +++---
 .../mlir/Dialect/AMDGPU/Transforms/Passes.td   |  2 +-
 .../AMDGPU/Transforms/FoldMemRefsOps.cpp       | 18 ++++++++----------
 .../Dialect/AMDGPU/amdgpu-fold-memrefs.mlir    | 18 ++++++++----------
 4 files changed, 20 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index a61903609aaff..58b9c74b2f8e0 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -22,7 +22,7 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
-#define GEN_PASS_DECL_AMDGPUFOLDSUBVIEWOPSPASS
+#define GEN_PASS_DECL_AMDGPUFOLDMEMREFOPSPASS
 #define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
 #define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
 #define GEN_PASS_REGISTRATION
@@ -39,8 +39,8 @@ void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
 void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);
 
-void populateAmdgpuFoldSubviewOpsPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
 
 } // namespace amdgpu
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 76b8c825ac272..8664f971cabde 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -71,7 +71,7 @@ def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
   ];
 }
 
-def AmdgpuFoldSubviewOpsPass : Pass<"amdgpu-fold-memrefs-ops"> {
+def AmdgpuFoldMemRefOpsPass : Pass<"amdgpu-fold-memrefs-ops"> {
   let summary = "Fold memref operations into their parent operations";
   let description = [{
     This pass identifies memref operations (subview, expand_shape, collapse_shape)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index 73923af1329db..a3fdc7ee385ed 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -12,21 +12,19 @@
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace mlir::amdgpu {
-#define GEN_PASS_DEF_AMDGPUFOLDSUBVIEWOPSPASS
+#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
-struct AmdgpuFoldMemRefOpsPass
-    : public amdgpu::impl::AmdgpuFoldSubviewOpsPassBase<
-          AmdgpuFoldMemRefOpsPass> {
+struct AmdgpuFoldMemRefOpsPass final
+    : amdgpu::impl::AmdgpuFoldMemRefOpsPassBase<AmdgpuFoldMemRefOpsPass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateAmdgpuFoldSubviewOpsPatterns(patterns);
-    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
-      signalPassFailure();
+    populateAmdgpuFoldMemRefOpsPatterns(patterns);
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
   }
 };
 
@@ -92,8 +90,8 @@ struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
   }
 };
 
-void populateAmdgpuFoldSubviewOpsPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit) {
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit) {
   patterns.add<FoldMemRefOpsIntoGatherToLDSOp>(patterns.getContext(), benefit);
 }
 } // namespace mlir::amdgpu
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
index a751a4ac1158e..57afa127c9da8 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -5,10 +5,10 @@
 // CHECK: func @test_subview_folding
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
 func.func @test_subview_folding(%offset_i: index, %offset_j: index) {
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
-  // CHECK:  amdgpu.gather_to_lds %[[MEM]][%arg0, %arg1], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[ARG0]], %[[ARG1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
   // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
 
   %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
@@ -30,14 +30,12 @@ func.func @test_subview_folding(%offset_i: index, %offset_j: index) {
 // CHECK: func @subview_folding_offset
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
 func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
-
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[ARG0]]]
   // CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
-
-  // CHECK:  amdgpu.gather_to_lds %[[MEM]][%[[IDX0]], %[[IDX1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX0]], %[[IDX1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
   // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
 
   %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
@@ -56,11 +54,11 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
 // CHECK: func @test_expand_shape
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
 func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
-  // CHECK:  amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
   // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
 
   %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
@@ -79,11 +77,11 @@ func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
 // CHECK: func @test_collapse_shape
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
 func.func @test_collapse_shape(%offset_i: index, %offset_j: index) {
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
-  // CHECK:  amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
   // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
 
   %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>



More information about the Mlir-commits mailing list