[Mlir-commits] [mlir] [AMDGPU] fold memref.subview/expand_shape/collapse_shape into amdgpu.gather_to_lds for DST operand (PR #152277)
llvmlistbot at llvm.org
Wed Aug 6 02:57:04 PDT 2025
https://github.com/sebvince updated https://github.com/llvm/llvm-project/pull/152277
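This PR extends the AMDGPU memref-fold pass so that view-like producers (memref.subview, memref.expand_shape, memref.collapse_shape) are folded into the LDS destination operand of amdgpu.gather_to_lds, not only into the global source operand. A minimal before/after sketch of the destination-side fold, adapted from the updated test_expand_shape test in this patch series (value names and the spelled-out workgroup address space are illustrative):

  // Before: the LDS destination goes through an expand_shape view.
  %alloc = memref.alloc() : memref<4096xf16, #gpu.address_space<workgroup>>
  %view = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64]
      : memref<4096xf16, #gpu.address_space<workgroup>> into memref<64x64xf16, #gpu.address_space<workgroup>>
  amdgpu.gather_to_lds %mem[%i, %j], %view[%c0, %c0]
      : vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, #gpu.address_space<workgroup>>

  // After: the view is folded away and its indices are linearized onto the base allocation.
  %idx = affine.linearize_index [%c0, %c0] by (64, 64) : index
  amdgpu.gather_to_lds %mem[%i, %j], %alloc[%idx]
      : vector<8xf16>, memref<64x128xf16>, memref<4096xf16, #gpu.address_space<workgroup>>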
From d63ecf5a885580e76ab932b5d0f8de2fae320e72 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Mon, 4 Aug 2025 04:51:38 -0500
Subject: [PATCH 1/4] [AMDGPU] fold dst operand of gather to lds
---
.../AMDGPU/Transforms/FoldMemRefsOps.cpp | 98 ++++++++++---------
1 file changed, 52 insertions(+), 46 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index a3fdc7ee385ed..0ca9970bb7992 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -28,63 +29,68 @@ struct AmdgpuFoldMemRefOpsPass final
}
};
+
+static LogicalResult foldMemrefViewOp(
+ PatternRewriter &rewriter, Location loc,
+ Value view, mlir::OperandRange indices,
+ SmallVectorImpl<Value> &resolvedIndices,
+ Value &memrefBase, StringRef role)
+{
+ Operation *defOp = view.getDefiningOp();
+ return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+ .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+ mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, loc, subviewOp.getMixedOffsets(),
+ subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+ indices, resolvedIndices);
+ memrefBase = subviewOp.getSource();
+ return success();
+ })
+ .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+ if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+ loc, rewriter, expandShapeOp, indices, resolvedIndices, false))) {
+ return failure();
+ }
+ memrefBase = expandShapeOp.getViewSource();
+ return success();
+ })
+ .Case<memref::CollapseShapeOp>([&](memref::CollapseShapeOp collapseShapeOp) {
+ if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+ loc, rewriter, collapseShapeOp, indices, resolvedIndices))) {
+ return failure();
+ }
+ memrefBase = collapseShapeOp.getViewSource();
+ return success();
+ })
+ .Default([&](Operation *op) {
+ return rewriter.notifyMatchFailure(
+ op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or CollapseShapeOp").str());
+ });
+}
+
+
struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherToLDSOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- Value memrefSource;
- SmallVector<Value> sourceIndices;
- auto foldResult =
- llvm::TypeSwitch<Operation *, LogicalResult>(
- op.getSrc().getDefiningOp())
- .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
- // If the source is a SubViewOp, we can directly rewrite the
- // GatherToLDSOp.
- mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
- rewriter, loc, subviewOp.getMixedOffsets(),
- subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
- op.getSrcIndices(), sourceIndices);
- memrefSource = subviewOp.getSource();
- return success();
- })
- .Case<memref::ExpandShapeOp>(
- [&](memref::ExpandShapeOp expandShapeOp) {
- if (failed(mlir::memref::resolveSourceIndicesExpandShape(
- loc, rewriter, expandShapeOp, op.getSrcIndices(),
- sourceIndices, false))) {
- return failure();
- }
- memrefSource = expandShapeOp.getViewSource();
- return success();
- })
- .Case<memref::CollapseShapeOp>(
- [&](memref::CollapseShapeOp collapseShapeOp) {
- if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
- loc, rewriter, collapseShapeOp, op.getSrcIndices(),
- sourceIndices))) {
- return failure();
- }
- memrefSource = collapseShapeOp.getViewSource();
- return success();
- })
- .Default([&](Operation *op) {
- // If the source is not a SubViewOp, ExpandShapeOp, or
- // CollapseShapeOp, we cannot fold the GatherToLDSOp.
- return rewriter.notifyMatchFailure(
- op,
- "source producer is not one of SubViewOp, ExpandShapeOp, or "
- "CollapseShapeOp");
- });
+ SmallVector<Value> sourceIndices, destIndices;
+ Value memrefSource, memrefDest;
+
+ auto foldSrcResult = foldMemrefViewOp(
+ rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source");
+
+ auto foldDstResult = foldMemrefViewOp(
+ rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination");
- if (failed(foldResult)) {
+ if (failed(foldSrcResult) || failed(foldDstResult)) {
return failure();
}
rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
- op.getDst(), op.getDstIndices(),
- op.getTransferType());
+ memrefDest, destIndices,
+ op.getTransferType());
return success();
}
From b4908bde58802d872eba87e05635187866e7763b Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Tue, 5 Aug 2025 04:53:02 -0500
Subject: [PATCH 2/4] Fix cast issue thanks to me at alanli.org
---
.../Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp | 16 +++++++++++++---
1 file changed, 13 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index 0ca9970bb7992..4474c5de97b0e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -37,6 +37,9 @@ static LogicalResult foldMemrefViewOp(
Value &memrefBase, StringRef role)
{
Operation *defOp = view.getDefiningOp();
+ if (!defOp) {
+ return failure();
+ }
return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
.Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
@@ -81,12 +84,19 @@ struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
auto foldSrcResult = foldMemrefViewOp(
rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source");
+ if (failed(foldSrcResult)) {
+ memrefSource = op.getSrc();
+ sourceIndices = op.getSrcIndices();
+ }
+
auto foldDstResult = foldMemrefViewOp(
rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination");
- if (failed(foldSrcResult) || failed(foldDstResult)) {
- return failure();
- }
+ if (failed(foldDstResult)) {
+ memrefDest = op.getDst();
+ destIndices = op.getDstIndices();
+ }
+
rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
memrefDest, destIndices,
From 2af05facca92f238bc6e351f858b55bc3aed6586 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Wed, 6 Aug 2025 03:54:58 -0500
Subject: [PATCH 3/4] Update tests
---
.../Dialect/AMDGPU/amdgpu-fold-memrefs.mlir | 28 +++++++++++--------
1 file changed, 16 insertions(+), 12 deletions(-)
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
index 57afa127c9da8..47bd82edd6212 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -54,18 +54,20 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
// CHECK: func @test_expand_shape
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
- // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+ // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
- // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
- // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+ // CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+ // CHECK: %[[IDXL:.*]] = affine.linearize_index [%[[C0]], %[[C0]]] by (64, 64) : index
+ // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDXM]]], %[[LOCAL]][%[[IDXL]]]
+ // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>
- %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+ %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
%mem = memref.alloc() : memref<8192xf16>
- %expand = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+ %expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+ %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
%c0 = arith.constant 0 : index
- amdgpu.gather_to_lds %expand[%offset_i, %offset_j], %alloc[%c0, %c0]
+ amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %expand_alloc[%c0, %c0]
: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, #gpu_lds_addrspace>
func.return
}
@@ -80,15 +82,17 @@ func.func @test_collapse_shape(%offset_i: index, %offset_j: index) {
// CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
- // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
+ // CHECK: %[[INDICES_MEM:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
+ // CHECK: %[[INDICES_LDS:.*]]:2 = affine.delinearize_index %[[ARG1]] into (64, 64) : index, index
+ // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES_MEM]]#0, %[[INDICES_MEM]]#1], %[[LOCAL]][%[[INDICES_LDS]]#0, %[[INDICES_LDS]]#1]
// CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
%alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+ %collapse_alloc = memref.collapse_shape %alloc [[0, 1]] : memref<64x64xf16, #gpu_lds_addrspace> into memref<4096xf16, #gpu_lds_addrspace>
%mem = memref.alloc() : memref<64x128xf16>
- %collapse = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
+ %collapse_mem = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
%c0 = arith.constant 0 : index
- amdgpu.gather_to_lds %collapse[%offset_i], %alloc[%c0, %c0]
- : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, #gpu_lds_addrspace>
+ amdgpu.gather_to_lds %collapse_mem[%offset_i], %collapse_alloc[%offset_j]
+ : vector<8xf16>, memref<8192xf16>, memref<4096xf16, #gpu_lds_addrspace>
func.return
}
From 2ae5b13ac80ae130c402b2431b906dff0fb70891 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Wed, 6 Aug 2025 04:55:57 -0500
Subject: [PATCH 4/4] Fix formatting and removed unnecessary include file
---
.../AMDGPU/Transforms/FoldMemRefsOps.cpp | 103 +++++++++---------
1 file changed, 52 insertions(+), 51 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index 4474c5de97b0e..d54751098410b 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
-#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -29,49 +28,50 @@ struct AmdgpuFoldMemRefOpsPass final
}
};
-
-static LogicalResult foldMemrefViewOp(
- PatternRewriter &rewriter, Location loc,
- Value view, mlir::OperandRange indices,
- SmallVectorImpl<Value> &resolvedIndices,
- Value &memrefBase, StringRef role)
-{
- Operation *defOp = view.getDefiningOp();
- if (!defOp) {
- return failure();
- }
- return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
- .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
- mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
- rewriter, loc, subviewOp.getMixedOffsets(),
- subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
- indices, resolvedIndices);
- memrefBase = subviewOp.getSource();
- return success();
- })
- .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
- if (failed(mlir::memref::resolveSourceIndicesExpandShape(
- loc, rewriter, expandShapeOp, indices, resolvedIndices, false))) {
- return failure();
- }
- memrefBase = expandShapeOp.getViewSource();
- return success();
- })
- .Case<memref::CollapseShapeOp>([&](memref::CollapseShapeOp collapseShapeOp) {
+static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
+ Value view, mlir::OperandRange indices,
+ SmallVectorImpl<Value> &resolvedIndices,
+ Value &memrefBase, StringRef role) {
+ Operation *defOp = view.getDefiningOp();
+ if (!defOp) {
+ return failure();
+ }
+ return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+ .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+ mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, loc, subviewOp.getMixedOffsets(),
+ subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
+ resolvedIndices);
+ memrefBase = subviewOp.getSource();
+ return success();
+ })
+ .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+ if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+ loc, rewriter, expandShapeOp, indices, resolvedIndices,
+ false))) {
+ return failure();
+ }
+ memrefBase = expandShapeOp.getViewSource();
+ return success();
+ })
+ .Case<memref::CollapseShapeOp>(
+ [&](memref::CollapseShapeOp collapseShapeOp) {
if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
- loc, rewriter, collapseShapeOp, indices, resolvedIndices))) {
- return failure();
+ loc, rewriter, collapseShapeOp, indices,
+ resolvedIndices))) {
+ return failure();
}
memrefBase = collapseShapeOp.getViewSource();
return success();
- })
- .Default([&](Operation *op) {
- return rewriter.notifyMatchFailure(
- op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or CollapseShapeOp").str());
- });
+ })
+ .Default([&](Operation *op) {
+ return rewriter.notifyMatchFailure(
+ op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
+ "CollapseShapeOp")
+ .str());
+ });
}
-
struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherToLDSOp op,
@@ -81,26 +81,27 @@ struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
SmallVector<Value> sourceIndices, destIndices;
Value memrefSource, memrefDest;
- auto foldSrcResult = foldMemrefViewOp(
- rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source");
-
+ auto foldSrcResult =
+ foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
+ sourceIndices, memrefSource, "source");
+
if (failed(foldSrcResult)) {
- memrefSource = op.getSrc();
- sourceIndices = op.getSrcIndices();
+ memrefSource = op.getSrc();
+ sourceIndices = op.getSrcIndices();
}
- auto foldDstResult = foldMemrefViewOp(
- rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination");
+ auto foldDstResult =
+ foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
+ destIndices, memrefDest, "destination");
if (failed(foldDstResult)) {
- memrefDest = op.getDst();
- destIndices = op.getDstIndices();
- }
-
+ memrefDest = op.getDst();
+ destIndices = op.getDstIndices();
+ }
rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
- memrefDest, destIndices,
- op.getTransferType());
+ memrefDest, destIndices,
+ op.getTransferType());
return success();
}