[Mlir-commits] [mlir] b7f889a - [mlir][AMDGPU] Add canonicalizer for folding casts into gather_to_lds (#150503)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 24 16:58:33 PDT 2025


Author: Quinn Dawkins
Date: 2025-07-24T19:58:30-04:00
New Revision: b7f889a29cca10e5227813578ab5b85dc1c81fda

URL: https://github.com/llvm/llvm-project/commit/b7f889a29cca10e5227813578ab5b85dc1c81fda
DIFF: https://github.com/llvm/llvm-project/commit/b7f889a29cca10e5227813578ab5b85dc1c81fda.diff

LOG: [mlir][AMDGPU] Add canonicalizer for folding casts into gather_to_lds (#150503)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
    mlir/test/Dialect/AMDGPU/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index b237f7b5749e7..92aacdaef4136 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -921,6 +921,7 @@ def AMDGPU_GatherToLDSOp :
     $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def AMDGPU_TransposeLoadOp :

diff  --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 18e8270f5aa99..9a0a230e8abca 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -510,6 +510,10 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GatherToLDSOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult GatherToLDSOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
   MemRefType dstType = cast<MemRefType>(getDst().getType());
@@ -546,6 +550,42 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+namespace {
+/// If the source/destination of a GatherToLDSOp is a memref::CastOp that only
+/// removes static information or changes layout, the cast can be skipped.
+struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false; // Set once at least one operand has been rewired.
+    auto foldCast = [&](OpOperand &operand) { // Bypass a foldable cast feeding `operand`.
+      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
+        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) { // Leave other casts alone.
+          rewriter.modifyOpInPlace(gatherOp, // Mutate through the rewriter, not directly.
+                                   [&] { operand.assign(castOp.getSource()); });
+          modified = true;
+        }
+      }
+    };
+
+    foldCast(gatherOp.getSrcMutable()); // Source memref operand.
+    foldCast(gatherOp.getDstMutable()); // Destination memref operand.
+
+    return success(modified); // Failure (no match) when neither operand changed.
+  }
+};
+} // namespace
+
+void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<FoldGatherToLDSOfCast>(context); // Fold memref.cast producers into the op.
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeLoadOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult TransposeLoadOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
 

diff  --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 4559e39cf0569..5501ad42dbd90 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -130,3 +130,32 @@ func.func @dead_atomic_add(%arg0: memref<4xf32>, %arg1: f32) {
   amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_gather_to_lds_of_cast
+func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
+// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
+  %c0 = arith.constant 0 : index
+  %0 = memref.cast %global : memref<128x72xf32, 1> to memref<?x?xf32, 1> // Erases the static shape of the source; expected to be folded away.
+  // CHECK: amdgpu.gather_to_lds %[[GLOBAL]]
+  // CHECK-SAME: : f32, memref<128x72xf32, 1>
+  amdgpu.gather_to_lds %0[%c0, %c0], %lds[%c0, %c0]
+    : f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_gather_to_lds_of_cast_dest
+func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
+// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
+// CHECK-SAME: %[[LDS:[A-Za-z0-9]+]]: memref<64x64xf32, 3>
+  %c0 = arith.constant 0 : index
+  %0 = memref.cast %lds : memref<64x64xf32, 3> to memref<?x?xf32, 3> // Erases the static shape of the destination; expected to be folded away.
+  // CHECK: amdgpu.gather_to_lds %[[GLOBAL]][{{.*}}], %[[LDS]]
+  // CHECK-SAME: : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>
+  amdgpu.gather_to_lds %global[%c0, %c0], %0[%c0, %c0]
+    : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
+  func.return
+}


        


More information about the Mlir-commits mailing list