[Mlir-commits] [mlir] [AMDGPU] fold memref.subview/expand_shape/collapse_shape into amdgpu.gather_to_lds for DST operand (PR #152277)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 6 02:20:35 PDT 2025


https://github.com/sebvince created https://github.com/llvm/llvm-project/pull/152277

This PR follows https://github.com/llvm/llvm-project/pull/150334 and applies the same folding of memref.subview/expand_shape/collapse_shape ops to the DST operand.



>From d63ecf5a885580e76ab932b5d0f8de2fae320e72 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Mon, 4 Aug 2025 04:51:38 -0500
Subject: [PATCH 1/3] [AMDGPU] fold dst operand of gather to lds

---
Review notes (between "---" and the diff, so ignored by git am):
- foldMemrefViewOp() hands view.getDefiningOp() straight to TypeSwitch; for a
  block-argument operand that pointer is null. PATCH 2/3 adds the null check.
- With only this patch applied, the rewrite is abandoned unless BOTH src and
  dst fold. Before this patch the dst was passed through unchanged
  (op.getDst()), so src-only folding regresses here. PATCH 2/3 restores it by
  keeping the unfolded operand.
- The new helper is not clang-formatted (4-space indent, opening brace on its
  own line, trailing whitespace on the parameter lines).

 .../AMDGPU/Transforms/FoldMemRefsOps.cpp      | 98 ++++++++++---------
 1 file changed, 52 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index a3fdc7ee385ed..0ca9970bb7992 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -28,63 +29,68 @@ struct AmdgpuFoldMemRefOpsPass final
   }
 };
 
+
+static LogicalResult foldMemrefViewOp(
+    PatternRewriter &rewriter, Location loc, 
+    Value view, mlir::OperandRange indices, 
+    SmallVectorImpl<Value> &resolvedIndices, 
+    Value &memrefBase, StringRef role) 
+{
+    Operation *defOp = view.getDefiningOp();
+    return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+        .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+            mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+                rewriter, loc, subviewOp.getMixedOffsets(),
+                subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+                indices, resolvedIndices);
+            memrefBase = subviewOp.getSource();
+            return success();
+        })
+        .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+            if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                    loc, rewriter, expandShapeOp, indices, resolvedIndices, false))) {
+                return failure();
+            }
+            memrefBase = expandShapeOp.getViewSource();
+            return success();
+        })
+        .Case<memref::CollapseShapeOp>([&](memref::CollapseShapeOp collapseShapeOp) {
+            if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                    loc, rewriter, collapseShapeOp, indices, resolvedIndices))) {
+                return failure();
+            }
+            memrefBase = collapseShapeOp.getViewSource();
+            return success();
+        })
+        .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(
+                op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or CollapseShapeOp").str());
+        });
+}
+
+
 struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherToLDSOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
-    Value memrefSource;
-    SmallVector<Value> sourceIndices;
-    auto foldResult =
-        llvm::TypeSwitch<Operation *, LogicalResult>(
-            op.getSrc().getDefiningOp())
-            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
-              // If the source is a SubViewOp, we can directly rewrite the
-              // GatherToLDSOp.
-              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-                  rewriter, loc, subviewOp.getMixedOffsets(),
-                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
-                  op.getSrcIndices(), sourceIndices);
-              memrefSource = subviewOp.getSource();
-              return success();
-            })
-            .Case<memref::ExpandShapeOp>(
-                [&](memref::ExpandShapeOp expandShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
-                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
-                          sourceIndices, false))) {
-                    return failure();
-                  }
-                  memrefSource = expandShapeOp.getViewSource();
-                  return success();
-                })
-            .Case<memref::CollapseShapeOp>(
-                [&](memref::CollapseShapeOp collapseShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
-                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
-                          sourceIndices))) {
-                    return failure();
-                  }
-                  memrefSource = collapseShapeOp.getViewSource();
-                  return success();
-                })
-            .Default([&](Operation *op) {
-              // If the source is not a SubViewOp, ExpandShapeOp, or
-              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
-              return rewriter.notifyMatchFailure(
-                  op,
-                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
-                  "CollapseShapeOp");
-            });
+    SmallVector<Value> sourceIndices, destIndices;
+    Value memrefSource, memrefDest;
+
+    auto foldSrcResult = foldMemrefViewOp(
+        rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source");
+  
+    auto foldDstResult = foldMemrefViewOp(
+        rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination");
 
-    if (failed(foldResult)) {
+    if (failed(foldSrcResult) || failed(foldDstResult)) {
       return failure();
     }
 
     rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
-                                               op.getDst(), op.getDstIndices(),
-                                               op.getTransferType());
+                                              memrefDest, destIndices,
+                                              op.getTransferType());
 
     return success();
   }

>From b4908bde58802d872eba87e05635187866e7763b Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Tue, 5 Aug 2025 04:53:02 -0500
Subject: [PATCH 2/3] Fix cast issue thanks to me at alanli.org

Return early from foldMemrefViewOp() when the operand has no defining op
(e.g. a block argument) instead of handing a null Operation* to
TypeSwitch, and keep the original operand and indices for whichever side
does not fold. If neither side folds, report a match failure instead of
recreating an identical op and returning success.
---
 .../AMDGPU/Transforms/FoldMemRefsOps.cpp      | 20 ++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index 0ca9970bb7992..4474c5de97b0e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -37,6 +37,9 @@ static LogicalResult foldMemrefViewOp(
     Value &memrefBase, StringRef role) 
 {
     Operation *defOp = view.getDefiningOp();
+    if (!defOp) {
+        return failure();
+    }
     return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
         .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
             mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
@@ -81,12 +84,23 @@ struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
     auto foldSrcResult = foldMemrefViewOp(
         rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source");
  
+    if (failed(foldSrcResult)) {
+      memrefSource = op.getSrc();
+      sourceIndices = op.getSrcIndices();
+    }
+
     auto foldDstResult = foldMemrefViewOp(
         rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination");
 
-    if (failed(foldSrcResult) || failed(foldDstResult)) {
-      return failure();
-    }
+    if (failed(foldDstResult)) {
+      memrefDest = op.getDst();
+      destIndices = op.getDstIndices();
+    }
+
+    // If neither operand was folded, replacing the op would only recreate it
+    // unchanged; report a match failure so the driver sees no progress.
+    if (failed(foldSrcResult) && failed(foldDstResult))
+      return failure();
 
     rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
                                               memrefDest, destIndices,

>From 2af05facca92f238bc6e351f858b55bc3aed6586 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Wed, 6 Aug 2025 03:54:58 -0500
Subject: [PATCH 3/3] Update tests

Also add a test whose LDS destination is a function argument, so that
operand has no defining op and only the source side is folded.
---
 .../Dialect/AMDGPU/amdgpu-fold-memrefs.mlir   | 49 ++++++++++++++-----
 1 file changed, 37 insertions(+), 12 deletions(-)

diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
index 57afa127c9da8..47bd82edd6212 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -54,18 +54,20 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
 // CHECK: func @test_expand_shape
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
 func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
-  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
-  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+  // CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+  // CHECK: %[[IDXL:.*]] = affine.linearize_index [%[[C0]], %[[C0]]] by (64, 64) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDXM]]], %[[LOCAL]][%[[IDXL]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>
 
-  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
   %mem = memref.alloc() : memref<8192xf16>
-  %expand = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %expand[%offset_i, %offset_j], %alloc[%c0, %c0]
+  amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %expand_alloc[%c0, %c0]
     : vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, #gpu_lds_addrspace>
   func.return
 }
@@ -80,15 +82,38 @@ func.func @test_collapse_shape(%offset_i: index, %offset_j: index) {
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK: %[[INDICES_MEM:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
+  // CHECK: %[[INDICES_LDS:.*]]:2 = affine.delinearize_index %[[ARG1]] into (64, 64) : index, index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES_MEM]]#0, %[[INDICES_MEM]]#1], %[[LOCAL]][%[[INDICES_LDS]]#0, %[[INDICES_LDS]]#1]
   // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
 
   %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %collapse_alloc = memref.collapse_shape %alloc [[0, 1]] : memref<64x64xf16, #gpu_lds_addrspace> into memref<4096xf16, #gpu_lds_addrspace>
   %mem = memref.alloc() : memref<64x128xf16>
-  %collapse = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
+  %collapse_mem = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %collapse[%offset_i], %alloc[%c0, %c0]
-    : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %collapse_mem[%offset_i], %collapse_alloc[%offset_j]
+    : vector<8xf16>, memref<8192xf16>, memref<4096xf16, #gpu_lds_addrspace>
   func.return
 }
+
+// -----
+
+// CHECK: func @test_dst_block_argument
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[LDS:.*]]: memref<64x64xf16, 3>
+func.func @test_dst_block_argument(%offset_i: index, %lds: memref<64x64xf16, 3>) {
+  // The destination is a function argument (no defining op): only the
+  // source is folded and the destination operand/indices are kept as-is.
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LDS]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
+
+  %mem = memref.alloc() : memref<64x128xf16>
+  %collapse_mem = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %collapse_mem[%offset_i], %lds[%c0, %c0]
+    : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+  func.return
+}



More information about the Mlir-commits mailing list