[Mlir-commits] [mlir] [mlir][AMDGPU] Add canonicalizer for folding casts into gather_to_lds (PR #150503)
Quinn Dawkins
llvmlistbot at llvm.org
Thu Jul 24 13:14:22 PDT 2025
https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/150503
>From 60d5258c3e23eafd2088cbe66a1c2f3d09ccb842 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 24 Jul 2025 15:02:41 -0400
Subject: [PATCH 1/3] [mlir][AMDGPU] Add canonicalizer for folding casts into
gather_to_lds
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 1 +
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 57 +++++++++++++++++++
mlir/test/Dialect/AMDGPU/canonicalize.mlir | 14 +++++
3 files changed, 72 insertions(+)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b237f7b5749e7..92aacdaef4136 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -921,6 +921,7 @@ def AMDGPU_GatherToLDSOp :
$src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
}];
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
def AMDGPU_TransposeLoadOp :
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 18e8270f5aa99..28eb99600f48b 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -510,6 +510,10 @@ LogicalResult DPPOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// GatherToLDSOp
+//===----------------------------------------------------------------------===//
+
LogicalResult GatherToLDSOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
MemRefType dstType = cast<MemRefType>(getDst().getType());
@@ -546,6 +550,59 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}
+namespace {
+/// If the source/target of a CopyOp is a CastOp that does not modify the shape
+/// and element type, the cast can be skipped. Such CastOps only cast the layout
+/// of the type.
+struct FoldGatherToLDSOfCast : public OpRewritePattern<GatherToLDSOp> {
+ using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
+ PatternRewriter &rewriter) const override {
+ bool modified = false;
+
+ // Check source.
+ if (auto castOp = gatherOp.getSrc().getDefiningOp<memref::CastOp>()) {
+ auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+ auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+
+ if (fromType && toType &&
+ fromType.getElementType() == toType.getElementType()) {
+ rewriter.modifyOpInPlace(gatherOp, [&] {
+ gatherOp.getSrcMutable().assign(castOp.getSource());
+ });
+ modified = true;
+ }
+ }
+
+ // Check target.
+ if (auto castOp = gatherOp.getDst().getDefiningOp<memref::CastOp>()) {
+ auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+ auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+
+ if (fromType && toType &&
+ fromType.getElementType() == toType.getElementType()) {
+ rewriter.modifyOpInPlace(gatherOp, [&] {
+ gatherOp.getDstMutable().assign(castOp.getSource());
+ });
+ modified = true;
+ }
+ }
+
+ return success(modified);
+ }
+};
+} // namespace
+
+void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldGatherToLDSOfCast>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeLoadOp
+//===----------------------------------------------------------------------===//
+
LogicalResult TransposeLoadOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 4559e39cf0569..19f258f439bbf 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -130,3 +130,17 @@ func.func @dead_atomic_add(%arg0: memref<4xf32>, %arg1: f32) {
amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32
func.return
}
+
+// -----
+
+// CHECK-LABEL: func @fold_gather_to_lds_of_cast
+func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
+// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
+ %c0 = arith.constant 0 : index
+ %0 = memref.cast %global : memref<128x72xf32, 1> to memref<?x?xf32, 1>
+ // CHECK: amdgpu.gather_to_lds %[[GLOBAL]]
+ // CHECK-SAME: : f32, memref<128x72xf32, 1>
+ amdgpu.gather_to_lds %0[%c0, %c0], %lds[%c0, %c0]
+ : f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>
+ func.return
+}
>From bc4c650f713f9bc22f18b3b7c701363ab4662813 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 24 Jul 2025 15:38:57 -0400
Subject: [PATCH 2/3] Update comment + address comments + add dest test
---
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 19 +++++++++----------
mlir/test/Dialect/AMDGPU/canonicalize.mlir | 15 +++++++++++++++
2 files changed, 24 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 28eb99600f48b..823f0c041231d 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -551,10 +551,9 @@ LogicalResult GatherToLDSOp::verify() {
}
namespace {
-/// If the source/target of a CopyOp is a CastOp that does not modify the shape
-/// and element type, the cast can be skipped. Such CastOps only cast the layout
-/// of the type.
-struct FoldGatherToLDSOfCast : public OpRewritePattern<GatherToLDSOp> {
+/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
+/// information or changes layout, the cast can be skipped.
+struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
@@ -563,10 +562,10 @@ struct FoldGatherToLDSOfCast : public OpRewritePattern<GatherToLDSOp> {
// Check source.
if (auto castOp = gatherOp.getSrc().getDefiningOp<memref::CastOp>()) {
- auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
- auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+ auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
+ auto toType = dyn_cast<MemRefType>(castOp.getSource().getType());
- if (fromType && toType &&
+ if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType &&
fromType.getElementType() == toType.getElementType()) {
rewriter.modifyOpInPlace(gatherOp, [&] {
gatherOp.getSrcMutable().assign(castOp.getSource());
@@ -577,10 +576,10 @@ struct FoldGatherToLDSOfCast : public OpRewritePattern<GatherToLDSOp> {
// Check target.
if (auto castOp = gatherOp.getDst().getDefiningOp<memref::CastOp>()) {
- auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
- auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+ auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
+ auto toType = dyn_cast<MemRefType>(castOp.getSource().getType());
- if (fromType && toType &&
+ if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType &&
fromType.getElementType() == toType.getElementType()) {
rewriter.modifyOpInPlace(gatherOp, [&] {
gatherOp.getDstMutable().assign(castOp.getSource());
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 19f258f439bbf..5501ad42dbd90 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -144,3 +144,18 @@ func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memr
: f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>
func.return
}
+
+// -----
+
+// CHECK-LABEL: func @fold_gather_to_lds_of_cast_dest
+func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
+// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
+// CHECK-SAME: %[[LDS:[A-Za-z0-9]+]]: memref<64x64xf32, 3>
+ %c0 = arith.constant 0 : index
+ %0 = memref.cast %lds : memref<64x64xf32, 3> to memref<?x?xf32, 3>
+ // CHECK: amdgpu.gather_to_lds %[[GLOBAL]][{{.*}}], %[[LDS]]
+ // CHECK-SAME: : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>
+ amdgpu.gather_to_lds %global[%c0, %c0], %0[%c0, %c0]
+ : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
+ func.return
+}
>From acda939fd53fa2354d8838a2fda1474debefd58d Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 24 Jul 2025 16:14:09 -0400
Subject: [PATCH 3/3] Drop redundant lines
---
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 12 ++----------
1 file changed, 2 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 823f0c041231d..626808d8586f4 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -562,11 +562,7 @@ struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
// Check source.
if (auto castOp = gatherOp.getSrc().getDefiningOp<memref::CastOp>()) {
- auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
- auto toType = dyn_cast<MemRefType>(castOp.getSource().getType());
-
- if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType &&
- fromType.getElementType() == toType.getElementType()) {
+ if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
rewriter.modifyOpInPlace(gatherOp, [&] {
gatherOp.getSrcMutable().assign(castOp.getSource());
});
@@ -576,11 +572,7 @@ struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
// Check target.
if (auto castOp = gatherOp.getDst().getDefiningOp<memref::CastOp>()) {
- auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
- auto toType = dyn_cast<MemRefType>(castOp.getSource().getType());
-
- if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType &&
- fromType.getElementType() == toType.getElementType()) {
+ if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
rewriter.modifyOpInPlace(gatherOp, [&] {
gatherOp.getDstMutable().assign(castOp.getSource());
});
More information about the Mlir-commits
mailing list