[Mlir-commits] [mlir] [AMDGPU] fold memref.subview/expand_shape/collapse_shape into amdgpu.gather_to_lds for DST operand (PR #152277)
Alan Li
llvmlistbot at llvm.org
Wed Aug 6 02:45:04 PDT 2025
================
@@ -28,63 +29,78 @@ struct AmdgpuFoldMemRefOpsPass final
}
};
+
+static LogicalResult foldMemrefViewOp(
+ PatternRewriter &rewriter, Location loc,
+ Value view, mlir::OperandRange indices,
+ SmallVectorImpl<Value> &resolvedIndices,
+ Value &memrefBase, StringRef role)
+{
+ Operation *defOp = view.getDefiningOp();
+ if (!defOp) {
+ return failure();
+ }
+ return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+ .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+ mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, loc, subviewOp.getMixedOffsets(),
+ subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+ indices, resolvedIndices);
+ memrefBase = subviewOp.getSource();
+ return success();
+ })
+ .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
+ if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+ loc, rewriter, expandShapeOp, indices, resolvedIndices, false))) {
+ return failure();
+ }
+ memrefBase = expandShapeOp.getViewSource();
+ return success();
+ })
+ .Case<memref::CollapseShapeOp>([&](memref::CollapseShapeOp collapseShapeOp) {
+ if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+ loc, rewriter, collapseShapeOp, indices, resolvedIndices))) {
+ return failure();
+ }
+ memrefBase = collapseShapeOp.getViewSource();
+ return success();
+ })
+ .Default([&](Operation *op) {
+ return rewriter.notifyMatchFailure(
+ op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or CollapseShapeOp").str());
+ });
+}
+
+
struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherToLDSOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- Value memrefSource;
- SmallVector<Value> sourceIndices;
- auto foldResult =
- llvm::TypeSwitch<Operation *, LogicalResult>(
- op.getSrc().getDefiningOp())
- .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
- // If the source is a SubViewOp, we can directly rewrite the
- // GatherToLDSOp.
- mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
- rewriter, loc, subviewOp.getMixedOffsets(),
- subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
- op.getSrcIndices(), sourceIndices);
- memrefSource = subviewOp.getSource();
- return success();
- })
- .Case<memref::ExpandShapeOp>(
- [&](memref::ExpandShapeOp expandShapeOp) {
- if (failed(mlir::memref::resolveSourceIndicesExpandShape(
- loc, rewriter, expandShapeOp, op.getSrcIndices(),
- sourceIndices, false))) {
- return failure();
- }
- memrefSource = expandShapeOp.getViewSource();
- return success();
- })
- .Case<memref::CollapseShapeOp>(
- [&](memref::CollapseShapeOp collapseShapeOp) {
- if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
- loc, rewriter, collapseShapeOp, op.getSrcIndices(),
- sourceIndices))) {
- return failure();
- }
- memrefSource = collapseShapeOp.getViewSource();
- return success();
- })
- .Default([&](Operation *op) {
- // If the source is not a SubViewOp, ExpandShapeOp, or
- // CollapseShapeOp, we cannot fold the GatherToLDSOp.
- return rewriter.notifyMatchFailure(
- op,
- "source producer is not one of SubViewOp, ExpandShapeOp, or "
- "CollapseShapeOp");
- });
+ SmallVector<Value> sourceIndices, destIndices;
+ Value memrefSource, memrefDest;
- if (failed(foldResult)) {
- return failure();
+ auto foldSrcResult = foldMemrefViewOp(
+ rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source");
+
+ if (failed(foldSrcResult)) {
+ memrefSource = op.getSrc();
+ sourceIndices = op.getSrcIndices();
}
+ auto foldDstResult = foldMemrefViewOp(
+ rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination");
+
+ if (failed(foldDstResult)) {
+ memrefDest = op.getDst();
+ destIndices = op.getDstIndices();
+ }
----------------
lialan wrote:
Can you also add a few unit tests that exercise the failure path?
If either `src` or `dst` is a function argument, `getDefiningOp()` returns nullptr, so folding fails for that operand.
https://github.com/llvm/llvm-project/pull/152277
More information about the Mlir-commits mailing list