[Mlir-commits] [mlir] [AMDGPU] fold memref.subview/expand_shape/collapse_shape into amdgpu.gather_to_lds for DST operand (PR #152277)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 6 02:57:04 PDT 2025


https://github.com/sebvince updated https://github.com/llvm/llvm-project/pull/152277

>From d63ecf5a885580e76ab932b5d0f8de2fae320e72 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Mon, 4 Aug 2025 04:51:38 -0500
Subject: [PATCH 1/4] [AMDGPU] fold dst operand of gather to lds

---
 .../AMDGPU/Transforms/FoldMemRefsOps.cpp      | 98 ++++++++++---------
 1 file changed, 52 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index a3fdc7ee385ed..0ca9970bb7992 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -28,63 +29,68 @@ struct AmdgpuFoldMemRefOpsPass final
   }
 };
 
+
+static LogicalResult foldMemrefViewOp(
+    PatternRewriter &rewriter, Location loc, 
+    Value view, mlir::OperandRange indices, 
+    SmallVectorImpl<Value> &resolvedIndices, 
+    Value &memrefBase, StringRef role) 
+{
+    Operation *defOp = view.getDefiningOp();
+    return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+        .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+            mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+                rewriter, loc, subviewOp.getMixedOffsets(),
+                subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+                indices, resolvedIndices);
+            memrefBase = subviewOp.getSource();
+            return success();
+        })
+        .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+            if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                    loc, rewriter, expandShapeOp, indices, resolvedIndices, false))) {
+                return failure();
+            }
+            memrefBase = expandShapeOp.getViewSource();
+            return success();
+        })
+        .Case<memref::CollapseShapeOp>([&](memref::CollapseShapeOp collapseShapeOp) {
+            if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                    loc, rewriter, collapseShapeOp, indices, resolvedIndices))) {
+                return failure();
+            }
+            memrefBase = collapseShapeOp.getViewSource();
+            return success();
+        })
+        .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(
+                op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or CollapseShapeOp").str());
+        });
+}
+
+
 struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherToLDSOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
-    Value memrefSource;
-    SmallVector<Value> sourceIndices;
-    auto foldResult =
-        llvm::TypeSwitch<Operation *, LogicalResult>(
-            op.getSrc().getDefiningOp())
-            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
-              // If the source is a SubViewOp, we can directly rewrite the
-              // GatherToLDSOp.
-              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-                  rewriter, loc, subviewOp.getMixedOffsets(),
-                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
-                  op.getSrcIndices(), sourceIndices);
-              memrefSource = subviewOp.getSource();
-              return success();
-            })
-            .Case<memref::ExpandShapeOp>(
-                [&](memref::ExpandShapeOp expandShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
-                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
-                          sourceIndices, false))) {
-                    return failure();
-                  }
-                  memrefSource = expandShapeOp.getViewSource();
-                  return success();
-                })
-            .Case<memref::CollapseShapeOp>(
-                [&](memref::CollapseShapeOp collapseShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
-                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
-                          sourceIndices))) {
-                    return failure();
-                  }
-                  memrefSource = collapseShapeOp.getViewSource();
-                  return success();
-                })
-            .Default([&](Operation *op) {
-              // If the source is not a SubViewOp, ExpandShapeOp, or
-              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
-              return rewriter.notifyMatchFailure(
-                  op,
-                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
-                  "CollapseShapeOp");
-            });
+    SmallVector<Value> sourceIndices, destIndices;
+    Value memrefSource, memrefDest;
+
+    auto foldSrcResult = foldMemrefViewOp(
+        rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source");
+  
+    auto foldDstResult = foldMemrefViewOp(
+        rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination");
 
-    if (failed(foldResult)) {
+    if (failed(foldSrcResult) || failed(foldDstResult)) {
       return failure();
     }
 
     rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
-                                               op.getDst(), op.getDstIndices(),
-                                               op.getTransferType());
+                                              memrefDest, destIndices,
+                                              op.getTransferType());
 
     return success();
   }

>From b4908bde58802d872eba87e05635187866e7763b Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Tue, 5 Aug 2025 04:53:02 -0500
Subject: [PATCH 2/4] Fix cast issue thanks to me at alanli.org

---
 .../Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index 0ca9970bb7992..4474c5de97b0e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -37,6 +37,9 @@ static LogicalResult foldMemrefViewOp(
     Value &memrefBase, StringRef role) 
 {
     Operation *defOp = view.getDefiningOp();
+    if (!defOp) {
+        return failure();
+    }
     return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
         .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
             mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
@@ -81,12 +84,19 @@ struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
     auto foldSrcResult = foldMemrefViewOp(
         rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source");
   
+    if (failed(foldSrcResult)) {
+        memrefSource = op.getSrc();
+        sourceIndices = op.getSrcIndices();
+    }
+
     auto foldDstResult = foldMemrefViewOp(
         rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination");
 
-    if (failed(foldSrcResult) || failed(foldDstResult)) {
-      return failure();
-    }
+    if (failed(foldDstResult)) {
+        memrefDest = op.getDst();
+        destIndices = op.getDstIndices();
+     }
+  
 
     rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
                                               memrefDest, destIndices,

>From 2af05facca92f238bc6e351f858b55bc3aed6586 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Wed, 6 Aug 2025 03:54:58 -0500
Subject: [PATCH 3/4] Update tests

---
 .../Dialect/AMDGPU/amdgpu-fold-memrefs.mlir   | 28 +++++++++++--------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
index 57afa127c9da8..47bd82edd6212 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -54,18 +54,20 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
 // CHECK: func @test_expand_shape
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
 func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
-  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
-  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+  // CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+  // CHECK: %[[IDXL:.*]] = affine.linearize_index [%[[C0]], %[[C0]]] by (64, 64) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDXM]]], %[[LOCAL]][%[[IDXL]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>
 
-  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
   %mem = memref.alloc() : memref<8192xf16>
-  %expand = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %expand[%offset_i, %offset_j], %alloc[%c0, %c0]
+  amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %expand_alloc[%c0, %c0]
     : vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, #gpu_lds_addrspace>
   func.return
 }
@@ -80,15 +82,17 @@ func.func @test_collapse_shape(%offset_i: index, %offset_j: index) {
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK: %[[INDICES_MEM:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
+  // CHECK: %[[INDICES_LDS:.*]]:2 = affine.delinearize_index %[[ARG1]] into (64, 64) : index, index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES_MEM]]#0, %[[INDICES_MEM]]#1], %[[LOCAL]][%[[INDICES_LDS]]#0, %[[INDICES_LDS]]#1]
   // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
 
   %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %collapse_alloc = memref.collapse_shape %alloc [[0, 1]] : memref<64x64xf16, #gpu_lds_addrspace> into memref<4096xf16, #gpu_lds_addrspace>
   %mem = memref.alloc() : memref<64x128xf16>
-  %collapse = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
+  %collapse_mem = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %collapse[%offset_i], %alloc[%c0, %c0]
-    : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %collapse_mem[%offset_i], %collapse_alloc[%offset_j]
+    : vector<8xf16>, memref<8192xf16>, memref<4096xf16, #gpu_lds_addrspace>
   func.return
 }

>From 2ae5b13ac80ae130c402b2431b906dff0fb70891 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Wed, 6 Aug 2025 04:55:57 -0500
Subject: [PATCH 4/4] Fix formatting and removed unnecessary include file

---
 .../AMDGPU/Transforms/FoldMemRefsOps.cpp      | 103 +++++++++---------
 1 file changed, 52 insertions(+), 51 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index 4474c5de97b0e..d54751098410b 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -12,7 +12,6 @@
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
-#include "mlir/IR/ValueRange.h"
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -29,49 +28,50 @@ struct AmdgpuFoldMemRefOpsPass final
   }
 };
 
-
-static LogicalResult foldMemrefViewOp(
-    PatternRewriter &rewriter, Location loc, 
-    Value view, mlir::OperandRange indices, 
-    SmallVectorImpl<Value> &resolvedIndices, 
-    Value &memrefBase, StringRef role) 
-{
-    Operation *defOp = view.getDefiningOp();
-    if (!defOp) {
-        return failure();
-    }
-    return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
-        .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
-            mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-                rewriter, loc, subviewOp.getMixedOffsets(),
-                subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
-                indices, resolvedIndices);
-            memrefBase = subviewOp.getSource();
-            return success();
-        })
-        .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
-            if (failed(mlir::memref::resolveSourceIndicesExpandShape(
-                    loc, rewriter, expandShapeOp, indices, resolvedIndices, false))) {
-                return failure();
-            }
-            memrefBase = expandShapeOp.getViewSource();
-            return success();
-        })
-        .Case<memref::CollapseShapeOp>([&](memref::CollapseShapeOp collapseShapeOp) {
+static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
+                                      Value view, mlir::OperandRange indices,
+                                      SmallVectorImpl<Value> &resolvedIndices,
+                                      Value &memrefBase, StringRef role) {
+  Operation *defOp = view.getDefiningOp();
+  if (!defOp) {
+    return failure();
+  }
+  return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+      .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+        mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+            rewriter, loc, subviewOp.getMixedOffsets(),
+            subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
+            resolvedIndices);
+        memrefBase = subviewOp.getSource();
+        return success();
+      })
+      .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+        if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                loc, rewriter, expandShapeOp, indices, resolvedIndices,
+                false))) {
+          return failure();
+        }
+        memrefBase = expandShapeOp.getViewSource();
+        return success();
+      })
+      .Case<memref::CollapseShapeOp>(
+          [&](memref::CollapseShapeOp collapseShapeOp) {
             if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
-                    loc, rewriter, collapseShapeOp, indices, resolvedIndices))) {
-                return failure();
+                    loc, rewriter, collapseShapeOp, indices,
+                    resolvedIndices))) {
+              return failure();
             }
             memrefBase = collapseShapeOp.getViewSource();
             return success();
-        })
-        .Default([&](Operation *op) {
-            return rewriter.notifyMatchFailure(
-                op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or CollapseShapeOp").str());
-        });
+          })
+      .Default([&](Operation *op) {
+        return rewriter.notifyMatchFailure(
+            op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
+                        "CollapseShapeOp")
+                    .str());
+      });
 }
 
-
 struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherToLDSOp op,
@@ -81,26 +81,27 @@ struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
     SmallVector<Value> sourceIndices, destIndices;
     Value memrefSource, memrefDest;
 
-    auto foldSrcResult = foldMemrefViewOp(
-        rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source");
-  
+    auto foldSrcResult =
+        foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
+                         sourceIndices, memrefSource, "source");
+
     if (failed(foldSrcResult)) {
-        memrefSource = op.getSrc();
-        sourceIndices = op.getSrcIndices();
+      memrefSource = op.getSrc();
+      sourceIndices = op.getSrcIndices();
     }
 
-    auto foldDstResult = foldMemrefViewOp(
-        rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination");
+    auto foldDstResult =
+        foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
+                         destIndices, memrefDest, "destination");
 
     if (failed(foldDstResult)) {
-        memrefDest = op.getDst();
-        destIndices = op.getDstIndices();
-     }
-  
+      memrefDest = op.getDst();
+      destIndices = op.getDstIndices();
+    }
 
     rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
-                                              memrefDest, destIndices,
-                                              op.getTransferType());
+                                               memrefDest, destIndices,
+                                               op.getTransferType());
 
     return success();
   }



More information about the Mlir-commits mailing list