[Mlir-commits] [mlir] [AMDGPU] fold memref.subview/expand_shape/collapse_shape into amdgpu.gather_to_lds for DST operand (PR #152277)

Alan Li llvmlistbot at llvm.org
Wed Aug 6 02:45:04 PDT 2025


================
@@ -28,63 +29,78 @@ struct AmdgpuFoldMemRefOpsPass final
   }
 };
 
+
+static LogicalResult foldMemrefViewOp(
+    PatternRewriter &rewriter, Location loc, 
+    Value view, mlir::OperandRange indices, 
+    SmallVectorImpl<Value> &resolvedIndices, 
+    Value &memrefBase, StringRef role) 
+{
+    Operation *defOp = view.getDefiningOp();
+    if (!defOp) {
+        return failure();
+    }
+    return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+        .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+            mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+                rewriter, loc, subviewOp.getMixedOffsets(),
+                subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+                indices, resolvedIndices);
+            memrefBase = subviewOp.getSource();
+            return success();
+        })
+        .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+            if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                    loc, rewriter, expandShapeOp, indices, resolvedIndices, false))) {
+                return failure();
+            }
+            memrefBase = expandShapeOp.getViewSource();
+            return success();
+        })
+        .Case<memref::CollapseShapeOp>([&](memref::CollapseShapeOp collapseShapeOp) {
+            if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                    loc, rewriter, collapseShapeOp, indices, resolvedIndices))) {
+                return failure();
+            }
+            memrefBase = collapseShapeOp.getViewSource();
+            return success();
+        })
+        .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(
+                op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or CollapseShapeOp").str());
+        });
+}
+
+
 struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherToLDSOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
-    Value memrefSource;
-    SmallVector<Value> sourceIndices;
-    auto foldResult =
-        llvm::TypeSwitch<Operation *, LogicalResult>(
-            op.getSrc().getDefiningOp())
-            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
-              // If the source is a SubViewOp, we can directly rewrite the
-              // GatherToLDSOp.
-              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-                  rewriter, loc, subviewOp.getMixedOffsets(),
-                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
-                  op.getSrcIndices(), sourceIndices);
-              memrefSource = subviewOp.getSource();
-              return success();
-            })
-            .Case<memref::ExpandShapeOp>(
-                [&](memref::ExpandShapeOp expandShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
-                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
-                          sourceIndices, false))) {
-                    return failure();
-                  }
-                  memrefSource = expandShapeOp.getViewSource();
-                  return success();
-                })
-            .Case<memref::CollapseShapeOp>(
-                [&](memref::CollapseShapeOp collapseShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
-                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
-                          sourceIndices))) {
-                    return failure();
-                  }
-                  memrefSource = collapseShapeOp.getViewSource();
-                  return success();
-                })
-            .Default([&](Operation *op) {
-              // If the source is not a SubViewOp, ExpandShapeOp, or
-              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
-              return rewriter.notifyMatchFailure(
-                  op,
-                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
-                  "CollapseShapeOp");
-            });
+    SmallVector<Value> sourceIndices, destIndices;
+    Value memrefSource, memrefDest;
 
-    if (failed(foldResult)) {
-      return failure();
+    auto foldSrcResult = foldMemrefViewOp(
+        rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source");
+  
+    if (failed(foldSrcResult)) {
+        memrefSource = op.getSrc();
+        sourceIndices = op.getSrcIndices();
     }
 
+    auto foldDstResult = foldMemrefViewOp(
+        rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination");
+
+    if (failed(foldDstResult)) {
+        memrefDest = op.getDst();
+        destIndices = op.getDstIndices();
+     }
----------------
lialan wrote:

Can you also add a few unit tests that exercise the failure path?

If either `src` or `dst` comes from the function argument list, `getDefiningOp()` will return nullptr, which makes `foldMemrefViewOp` fail for that operand.

https://github.com/llvm/llvm-project/pull/152277


More information about the Mlir-commits mailing list