[Mlir-commits] [mlir] 8949dc7 - [mlir][amdgpu] fold memref.subview/expand_shape/collapse_shape into amdgpu.gather_to_lds for DST operand (#152277)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 8 05:47:37 PDT 2025


Author: sebvince
Date: 2025-08-08T05:47:33-07:00
New Revision: 8949dc7f9c2405fea2d07bab5bce08576ddb92a4

URL: https://github.com/llvm/llvm-project/commit/8949dc7f9c2405fea2d07bab5bce08576ddb92a4
DIFF: https://github.com/llvm/llvm-project/commit/8949dc7f9c2405fea2d07bab5bce08576ddb92a4.diff

LOG: [mlir][amdgpu] fold memref.subview/expand_shape/collapse_shape into amdgpu.gather_to_lds for DST operand (#152277)

Added: 
    

Modified: 
    mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
    mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index a3fdc7ee385ed..d54751098410b 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -28,62 +28,79 @@ struct AmdgpuFoldMemRefOpsPass final
   }
 };
 
+static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
+                                      Value view, mlir::OperandRange indices,
+                                      SmallVectorImpl<Value> &resolvedIndices,
+                                      Value &memrefBase, StringRef role) {
+  Operation *defOp = view.getDefiningOp();
+  if (!defOp) {
+    return failure();
+  }
+  return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+      .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+        mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+            rewriter, loc, subviewOp.getMixedOffsets(),
+            subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
+            resolvedIndices);
+        memrefBase = subviewOp.getSource();
+        return success();
+      })
+      .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+        if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                loc, rewriter, expandShapeOp, indices, resolvedIndices,
+                false))) {
+          return failure();
+        }
+        memrefBase = expandShapeOp.getViewSource();
+        return success();
+      })
+      .Case<memref::CollapseShapeOp>(
+          [&](memref::CollapseShapeOp collapseShapeOp) {
+            if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                    loc, rewriter, collapseShapeOp, indices,
+                    resolvedIndices))) {
+              return failure();
+            }
+            memrefBase = collapseShapeOp.getViewSource();
+            return success();
+          })
+      .Default([&](Operation *op) {
+        return rewriter.notifyMatchFailure(
+            op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
+                        "CollapseShapeOp")
+                    .str());
+      });
+}
+
 struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherToLDSOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
-    Value memrefSource;
-    SmallVector<Value> sourceIndices;
-    auto foldResult =
-        llvm::TypeSwitch<Operation *, LogicalResult>(
-            op.getSrc().getDefiningOp())
-            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
-              // If the source is a SubViewOp, we can directly rewrite the
-              // GatherToLDSOp.
-              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-                  rewriter, loc, subviewOp.getMixedOffsets(),
-                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
-                  op.getSrcIndices(), sourceIndices);
-              memrefSource = subviewOp.getSource();
-              return success();
-            })
-            .Case<memref::ExpandShapeOp>(
-                [&](memref::ExpandShapeOp expandShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
-                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
-                          sourceIndices, false))) {
-                    return failure();
-                  }
-                  memrefSource = expandShapeOp.getViewSource();
-                  return success();
-                })
-            .Case<memref::CollapseShapeOp>(
-                [&](memref::CollapseShapeOp collapseShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
-                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
-                          sourceIndices))) {
-                    return failure();
-                  }
-                  memrefSource = collapseShapeOp.getViewSource();
-                  return success();
-                })
-            .Default([&](Operation *op) {
-              // If the source is not a SubViewOp, ExpandShapeOp, or
-              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
-              return rewriter.notifyMatchFailure(
-                  op,
-                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
-                  "CollapseShapeOp");
-            });
+    SmallVector<Value> sourceIndices, destIndices;
+    Value memrefSource, memrefDest;
+
+    auto foldSrcResult =
+        foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
+                         sourceIndices, memrefSource, "source");
+
+    if (failed(foldSrcResult)) {
+      memrefSource = op.getSrc();
+      sourceIndices = op.getSrcIndices();
+    }
+
+    auto foldDstResult =
+        foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
+                         destIndices, memrefDest, "destination");
 
-    if (failed(foldResult)) {
-      return failure();
+    if (failed(foldDstResult)) {
+      memrefDest = op.getDst();
+      destIndices = op.getDstIndices();
     }
 
     rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
-                                               op.getDst(), op.getDstIndices(),
+                                               memrefDest, destIndices,
                                                op.getTransferType());
 
     return success();

diff  --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
index 57afa127c9da8..8ca3dd60414df 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -54,18 +54,20 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
 // CHECK: func @test_expand_shape
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
 func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
-  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
-  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+  // CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+  // CHECK: %[[IDXL:.*]] = affine.linearize_index [%[[C0]], %[[C0]]] by (64, 64) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDXM]]], %[[LOCAL]][%[[IDXL]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>
 
-  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
   %mem = memref.alloc() : memref<8192xf16>
-  %expand = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %expand[%offset_i, %offset_j], %alloc[%c0, %c0]
+  amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %expand_alloc[%c0, %c0]
     : vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, #gpu_lds_addrspace>
   func.return
 }
@@ -80,15 +82,82 @@ func.func @test_collapse_shape(%offset_i: index, %offset_j: index) {
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK: %[[INDICES_MEM:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
+  // CHECK: %[[INDICES_LDS:.*]]:2 = affine.delinearize_index %[[ARG1]] into (64, 64) : index, index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES_MEM]]#0, %[[INDICES_MEM]]#1], %[[LOCAL]][%[[INDICES_LDS]]#0, %[[INDICES_LDS]]#1]
   // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
 
   %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %collapse_alloc = memref.collapse_shape %alloc [[0, 1]] : memref<64x64xf16, #gpu_lds_addrspace> into memref<4096xf16, #gpu_lds_addrspace>
   %mem = memref.alloc() : memref<64x128xf16>
-  %collapse = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
+  %collapse_mem = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %collapse[%offset_i], %alloc[%c0, %c0]
+  amdgpu.gather_to_lds %collapse_mem[%offset_i], %collapse_alloc[%offset_j]
+    : vector<8xf16>, memref<8192xf16>, memref<4096xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+
+// CHECK: func @test_expand_shape_src_raw_buffer
+// CHECK-SAME: %[[ARG0:.*]]: memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @test_expand_shape_src_raw_buffer(%mem : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG1]], %[[ARG2]]] by (64, 128) : index
+  // CHECK: amdgpu.gather_to_lds %[[ARG0]][%[[IDXM]]], %[[LOCAL]][%[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, 3>
+
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+  %expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>> into memref<64x128xf16, #amdgpu.address_space<fat_raw_buffer>>
+
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %alloc[%c0]
+    : vector<8xf16>, memref<64x128xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_expand_shape_dst_only
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_expand_shape_dst_only(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDX_LDS:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (64, 64) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[ARG0]]], %[[LOCAL]][%[[IDX_LDS]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>
+
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<8192xf16>
+  %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
+
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %mem[%offset_i], %expand_alloc[%offset_j, %c0]
     : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, #gpu_lds_addrspace>
   func.return
 }
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_nop
+// CHECK-SAME: %[[ARG0:.*]]: memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @test_nop(%mem : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+  // CHECK: amdgpu.gather_to_lds %[[ARG0]][%[[ARG1]]], %[[LOCAL]][%[[ARG2]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, 3>
+
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %mem[%offset_i], %alloc[%offset_j]
+    : vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, #gpu_lds_addrspace>
+  func.return
+}


        


More information about the Mlir-commits mailing list